google-meridian 1.1.6__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,514 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Backend Abstraction Layer for Meridian."""
16
+
17
+ import os
18
+ from typing import Any, Optional, TYPE_CHECKING, Tuple, Union
19
+
20
+ from meridian.backend import config
21
+ import numpy as np
22
+ from typing_extensions import Literal
23
+
24
+ # The conditional imports in this module are a deliberate design choice for the
25
+ # backend abstraction layer. The TFP-on-JAX substrate provides a nearly
26
+ # identical API to the standard TFP library, making an alias-based approach more
27
+ # pragmatic than a full Abstract Base Class implementation, which would require
28
+ # extensive boilerplate.
29
+ # pylint: disable=g-import-not-at-top,g-bad-import-order
30
+
31
+ if TYPE_CHECKING:
32
+ import dataclasses
33
+ import jax as _jax
34
+ import tensorflow as _tf
35
+
36
+ TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]
37
+
38
+
39
def standardize_dtype(dtype: Any) -> str:
  """Returns the canonical string name of a backend-specific dtype.

  Args:
    dtype: A dtype-like object such as a `tf.DType`, an `np.dtype`, a NumPy
      scalar type, a Python type, or a dtype string. May be `None`.

  Returns:
    The NumPy name of the dtype (e.g. 'float32'). `None` maps to the string
    'None', and anything NumPy cannot interpret falls back to `str(...)`.
  """
  # np.dtype(None) would silently mean float64, so short-circuit it.
  if dtype is None:
    return "None"

  # TensorFlow DTypes expose their NumPy counterpart via `as_numpy_dtype`.
  numpy_like = getattr(dtype, "as_numpy_dtype", dtype)
  try:
    canonical = np.dtype(numpy_like)
  except TypeError:
    return str(numpy_like)
  return canonical.name
61
+
62
+
63
def result_type(*types: Any) -> str:
  """Infers the promoted dtype name for a collection of input types.

  This is the single, backend-agnostic source of truth for type promotion.
  `None` entries are ignored and every other entry is standardized with
  `standardize_dtype`. The result is 'float32' when any standardized name
  contains 'float', and 'int64' otherwise (including when no types are given).

  Args:
    *types: Type objects such as `<class 'int'>`, `np.dtype('float32')` or
      backend dtypes; `None` values are skipped.

  Returns:
    Either 'float32' or 'int64'.
  """
  type_names = []
  for candidate in (t for t in types if t is not None):
    try:
      # Standardize the input type before checking promotion rules.
      type_names.append(standardize_dtype(candidate))
    except Exception:  # pylint: disable=broad-except
      # Fallback if standardization fails for an unexpected type.
      type_names.append(str(candidate))

  has_float = any("float" in name for name in type_names)
  return "float32" if has_float else "int64"
93
+
94
+
95
def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
  """Picks the dtype an operation should produce.

  An explicit `dtype` wins and is standardized. Otherwise the dtype is inferred
  from the non-None `args` via `result_type`. Each argument contributes its
  `.dtype` attribute when it has one, and its Python type otherwise.

  Args:
    dtype: The caller-requested dtype, or None to infer one.
    *args: The operation's inputs, used only for inference.

  Returns:
    The resolved dtype as a string.
  """
  if dtype is None:
    inferred_types = [
        arg.dtype if hasattr(arg, "dtype") else type(arg)
        for arg in args
        if arg is not None
    ]
    return result_type(*inferred_types)
  return standardize_dtype(dtype)
116
+
117
+
118
+ # --- Private Backend-Specific Implementations ---
119
+
120
+
121
def _jax_arange(
    start: Any,
    stop: Optional[Any] = None,
    step: Any = 1,
    dtype: Optional[Any] = None,
) -> "_jax.Array":
  """Evenly spaced values on the JAX backend, using Meridian dtype promotion.

  Args:
    start: First value, or the exclusive end when `stop` is None.
    stop: Exclusive end of the range, or None.
    step: Spacing between consecutive values.
    dtype: Requested dtype. When None, it is inferred from `start`, `stop` and
      `step` via `_resolve_dtype`.

  Returns:
    A 1-D JAX array.
  """
  # Imported locally so the function is self-contained and JAX is only needed
  # when this implementation is actually called.
  import jax.numpy as jnp

  target_dtype = _resolve_dtype(dtype, start, stop, step)
  return jnp.arange(start, stop, step=step, dtype=target_dtype)
135
+
136
+
137
def _tf_arange(
    start: Any,
    stop: Optional[Any] = None,
    step: Any = 1,
    dtype: Optional[Any] = None,
) -> "_tf.Tensor":
  """Evenly spaced values on the TensorFlow backend via `tf.range`.

  Args:
    start: First value, or the exclusive limit when `stop` is None.
    stop: Exclusive limit of the range, or None.
    step: Spacing between consecutive values.
    dtype: Requested dtype. When None, it is inferred via `_resolve_dtype`.

  Returns:
    A 1-D `tf.Tensor`.
  """
  import tensorflow as tf

  target_dtype = _resolve_dtype(dtype, start, stop, step)
  range_kwargs = dict(limit=stop, delta=step)
  try:
    return tf.range(start, dtype=target_dtype, **range_kwargs)
  except tf.errors.NotFoundError:
    # Presumably raised when no Range kernel is registered for the requested
    # dtype: build the range in float32 and cast to the target instead.
    as_float = tf.range(start, dtype=tf.float32, **range_kwargs)
    return tf.cast(as_float, target_dtype)
152
+
153
+
154
+ def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
155
+ """JAX implementation for cast."""
156
+ return x.astype(dtype)
157
+
158
+
159
+ def _jax_divide_no_nan(x, y):
160
+ """JAX implementation for divide_no_nan."""
161
+ import jax.numpy as jnp
162
+
163
+ return jnp.where(y != 0, jnp.divide(x, y), 0.0)
164
+
165
+
166
+ def _jax_numpy_function(*args, **kwargs): # pylint: disable=unused-argument
167
+ raise NotImplementedError(
168
+ "backend.numpy_function is not implemented for the JAX backend."
169
+ )
170
+
171
+
172
+ def _jax_make_tensor_proto(*args, **kwargs): # pylint: disable=unused-argument
173
+ raise NotImplementedError(
174
+ "backend.make_tensor_proto is not implemented for the JAX backend."
175
+ )
176
+
177
+
178
+ def _jax_make_ndarray(*args, **kwargs): # pylint: disable=unused-argument
179
+ raise NotImplementedError(
180
+ "backend.make_ndarray is not implemented for the JAX backend."
181
+ )
182
+
183
+
184
+ def _jax_get_indices_where(condition):
185
+ """JAX implementation for get_indices_where."""
186
+ import jax.numpy as jnp
187
+
188
+ return jnp.stack(jnp.where(condition), axis=-1)
189
+
190
+
191
def _tf_get_indices_where(condition):
  """Coordinates of the True entries of `condition`, via one-argument `tf.where`.

  Args:
    condition: Tensor-like whose truthy entries are located.

  Returns:
    The result of `tf.where(condition)`.
  """
  import tensorflow as tf

  indices = tf.where(condition)
  return indices
196
+
197
+
198
+ def _jax_unique_with_counts(x):
199
+ """JAX implementation for unique_with_counts."""
200
+ import jax.numpy as jnp
201
+
202
+ y, counts = jnp.unique(x, return_counts=True)
203
+ # The TF version returns a tuple of (y, idx, count). The idx is not used in
204
+ # the calling code, so we can return None for it to maintain tuple structure.
205
+ return y, None, counts
206
+
207
+
208
def _tf_unique_with_counts(x):
  """Unique values of `x` with their counts, via `tf.unique_with_counts`.

  Args:
    x: The tensor whose unique values are collected.

  Returns:
    The `(y, idx, count)` result of `tf.unique_with_counts`.
  """
  import tensorflow as tf

  result = tf.unique_with_counts(x=x)
  return result
213
+
214
+
215
+ def _jax_boolean_mask(tensor, mask, axis=None):
216
+ """JAX implementation for boolean_mask that supports an axis argument."""
217
+ import jax.numpy as jnp
218
+
219
+ if axis is None:
220
+ axis = 0
221
+ tensor_swapped = jnp.moveaxis(tensor, axis, 0)
222
+ masked = tensor_swapped[mask]
223
+ return jnp.moveaxis(masked, 0, axis)
224
+
225
+
226
def _tf_boolean_mask(tensor, mask, axis=None):
  """Selects the entries of `tensor` where `mask` is True, via `tf.boolean_mask`.

  Args:
    tensor: The tensor to mask.
    mask: Boolean tensor applied starting at `axis`.
    axis: First masked dimension; passed straight through to TensorFlow.

  Returns:
    The masked tensor produced by `tf.boolean_mask`.
  """
  import tensorflow as tf

  masked = tf.boolean_mask(tensor=tensor, mask=mask, axis=axis)
  return masked
231
+
232
+
233
+ def _jax_gather(params, indices):
234
+ """JAX implementation for gather."""
235
+ # JAX uses standard array indexing for gather operations.
236
+ return params[indices]
237
+
238
+
239
def _tf_gather(params, indices):
  """Gathers slices of `params` at `indices` using `tf.gather` defaults.

  Args:
    params: The tensor to gather from.
    indices: Positions to select.

  Returns:
    The gathered tensor.
  """
  import tensorflow as tf

  gathered = tf.gather(params=params, indices=indices)
  return gathered
244
+
245
+
246
+ def _jax_fill(dims, value):
247
+ """JAX implementation for fill."""
248
+ import jax.numpy as jnp
249
+
250
+ return jnp.full(dims, value)
251
+
252
+
253
def _tf_fill(dims, value):
  """Creates a tensor of shape `dims` filled with `value`, via `tf.fill`.

  Args:
    dims: Target shape.
    value: Scalar fill value.

  Returns:
    The filled `tf.Tensor`.
  """
  import tensorflow as tf

  filled = tf.fill(dims=dims, value=value)
  return filled
258
+
259
+
260
+ def _jax_argmax(tensor, axis=None):
261
+ """JAX implementation for argmax, aligned with TensorFlow's default.
262
+
263
+ This function finds the indices of the maximum values along a specified axis.
264
+ Crucially, it mimics the default behavior of TensorFlow's `tf.argmax`, where
265
+ if `axis` is `None`, the operation defaults to `axis=0`. This differs from
266
+ NumPy's and JAX's native `argmax` behavior, which would flatten the array
267
+ before finding the index.
268
+
269
+ Args:
270
+ tensor: The input JAX array.
271
+ axis: An integer specifying the axis along which to find the index of the
272
+ maximum value. If `None`, it defaults to `0` to match TensorFlow's
273
+ behavior.
274
+
275
+ Returns:
276
+ A JAX array containing the indices of the maximum values.
277
+ """
278
+ import jax.numpy as jnp
279
+
280
+ if axis is None:
281
+ axis = 0
282
+ return jnp.argmax(tensor, axis=axis)
283
+
284
+
285
def _tf_argmax(tensor, axis=None):
  """Index of the maximum along `axis`, delegating to `tf.argmax`.

  Args:
    tensor: The input tensor.
    axis: Dimension to reduce over; `None` is forwarded to TensorFlow as-is.

  Returns:
    The index tensor produced by `tf.argmax`.
  """
  import tensorflow as tf

  indices = tf.argmax(tensor, axis=axis)
  return indices
290
+
291
+
292
+ def _jax_broadcast_dynamic_shape(shape_x, shape_y):
293
+ """JAX implementation for broadcast_dynamic_shape."""
294
+ import jax.numpy as jnp
295
+
296
+ return jnp.broadcast_shapes(shape_x, shape_y)
297
+
298
+
299
+ def _jax_tensor_shape(dims):
300
+ """JAX implementation for TensorShape."""
301
+ if isinstance(dims, int):
302
+ return (dims,)
303
+
304
+ return tuple(dims)
305
+
306
+
307
+ # --- Backend Initialization ---
308
_BACKEND = config.get_backend()

# We expose standardized functions directly at the module level (backend.foo)
# to provide a consistent, NumPy-like API across backends. The '_ops' object
# is a private member for accessing the full, raw backend library if necessary,
# but usage should prefer the top-level standardized functions.
#
# The backend is read once, at import time. Within each branch every public
# alias is bound exactly once and listed alphabetically, so the two branches
# can be compared line by line. (Earlier versions bound many aliases twice.)

if _BACKEND == config.Backend.JAX:
  import jax
  import jax.numpy as jax_ops
  import tensorflow_probability.substrates.jax as tfp_jax

  class ExtensionType:
    """A JAX-compatible stand-in for tf.experimental.ExtensionType."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
      raise NotImplementedError(
          "ExtensionType is not yet implemented for the JAX backend."
      )

  class _JaxErrors:
    # Maps the TensorFlow error names exposed as `errors.<Name>` onto built-in
    # Python exceptions, so the same attribute names exist on both backends.
    # pylint: disable=invalid-name
    ResourceExhaustedError = MemoryError
    InvalidArgumentError = ValueError
    # pylint: enable=invalid-name

  _ops = jax_ops
  errors = _JaxErrors()
  Tensor = jax.Array
  tfd = tfp_jax.distributions
  bijectors = tfp_jax.bijectors
  experimental = tfp_jax.experimental
  random = tfp_jax.random
  mcmc = tfp_jax.mcmc
  _convert_to_tensor = _ops.asarray

  # Standardized Public API
  absolute = _ops.abs
  allclose = _ops.allclose
  arange = _jax_arange
  argmax = _jax_argmax
  boolean_mask = _jax_boolean_mask
  broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
  broadcast_to = _ops.broadcast_to
  cast = _jax_cast
  concatenate = _ops.concatenate
  cumsum = _ops.cumsum
  divide = _ops.divide
  divide_no_nan = _jax_divide_no_nan
  einsum = _ops.einsum
  equal = _ops.equal
  exp = _ops.exp
  expand_dims = _ops.expand_dims
  fill = _jax_fill
  function = jax.jit
  gather = _jax_gather
  get_indices_where = _jax_get_indices_where
  is_nan = _ops.isnan
  log = _ops.log
  make_ndarray = _jax_make_ndarray
  make_tensor_proto = _jax_make_tensor_proto
  numpy_function = _jax_numpy_function
  ones = _ops.ones
  ones_like = _ops.ones_like
  rank = _ops.ndim
  reduce_any = _ops.any
  reduce_max = _ops.max
  reduce_mean = _ops.mean
  reduce_min = _ops.min
  reduce_std = _ops.std
  reduce_sum = _ops.sum
  repeat = _ops.repeat
  reshape = _ops.reshape
  split = _ops.split
  stack = _ops.stack
  tile = _ops.tile
  transpose = _ops.transpose
  unique_with_counts = _jax_unique_with_counts
  where = _ops.where
  zeros = _ops.zeros
  zeros_like = _ops.zeros_like

  float32 = _ops.float32
  bool_ = _ops.bool_
  newaxis = _ops.newaxis
  TensorShape = _jax_tensor_shape

  def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
    raise NotImplementedError(
        "JAX does not support a global, stateful random seed. `set_random_seed`"
        " is not implemented. Instead, you must pass an explicit `seed`"
        " integer directly to the sampling methods (e.g., `sample_prior`),"
        " which will be used to create a JAX PRNGKey internally."
    )

elif _BACKEND == config.Backend.TENSORFLOW:
  import tensorflow as tf_backend
  import tensorflow_probability as tfp

  _ops = tf_backend
  errors = _ops.errors

  Tensor = tf_backend.Tensor
  ExtensionType = _ops.experimental.ExtensionType

  tfd = tfp.distributions
  bijectors = tfp.bijectors
  experimental = tfp.experimental
  random = tfp.random
  mcmc = tfp.mcmc

  _convert_to_tensor = _ops.convert_to_tensor

  # Standardized Public API
  absolute = _ops.math.abs
  allclose = _ops.experimental.numpy.allclose
  arange = _tf_arange
  argmax = _tf_argmax
  boolean_mask = _tf_boolean_mask
  broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
  broadcast_to = _ops.broadcast_to
  cast = _ops.cast
  concatenate = _ops.concat
  cumsum = _ops.cumsum
  divide = _ops.divide
  divide_no_nan = _ops.math.divide_no_nan
  einsum = _ops.einsum
  equal = _ops.equal
  exp = _ops.math.exp
  expand_dims = _ops.expand_dims
  fill = _tf_fill
  function = _ops.function
  gather = _tf_gather
  get_indices_where = _tf_get_indices_where
  is_nan = _ops.math.is_nan
  log = _ops.math.log
  make_ndarray = _ops.make_ndarray
  make_tensor_proto = _ops.make_tensor_proto
  numpy_function = _ops.numpy_function
  ones = _ops.ones
  ones_like = _ops.ones_like
  rank = _ops.rank
  reduce_any = _ops.reduce_any
  reduce_max = _ops.reduce_max
  reduce_mean = _ops.reduce_mean
  reduce_min = _ops.reduce_min
  reduce_std = _ops.math.reduce_std
  reduce_sum = _ops.reduce_sum
  repeat = _ops.repeat
  reshape = _ops.reshape
  set_random_seed = tf_backend.keras.utils.set_random_seed
  split = _ops.split
  stack = _ops.stack
  tile = _ops.tile
  transpose = _ops.transpose
  unique_with_counts = _tf_unique_with_counts
  where = _ops.where
  zeros = _ops.zeros
  zeros_like = _ops.zeros_like

  float32 = _ops.float32
  bool_ = _ops.bool
  newaxis = _ops.newaxis
  TensorShape = _ops.TensorShape

else:
  raise ValueError(f"Unsupported backend: {_BACKEND}")
# pylint: enable=g-import-not-at-top,g-bad-import-order
500
+
501
+
502
def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor:  # type: ignore
  """Turns `data` into a tensor of whichever backend is active.

  The conversion is delegated to the backend's converter bound at import time
  (`jnp.asarray` for JAX, `tf.convert_to_tensor` for TensorFlow).

  Args:
    data: Any value the active backend's converter understands.
    dtype: Optional target dtype, in a form the active backend accepts (e.g. a
      `jax.numpy` dtype or a `tf.DType`). `None` lets the backend infer it.

  Returns:
    The converted backend tensor.
  """
  converter = _convert_to_tensor
  return converter(data, dtype=dtype)
@@ -0,0 +1,59 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Backend configuration for Meridian."""
16
+
17
+ import enum
18
+ import warnings
19
+
20
+
21
class Backend(enum.Enum):
  """Computation libraries Meridian can be configured to run on."""

  TENSORFLOW = "tensorflow"
  JAX = "jax"


# Module-level backend selection. It starts as TensorFlow and is reassigned
# only by `set_backend`.
_BACKEND = Backend.TENSORFLOW
27
+
28
+
29
def set_backend(backend: Backend) -> None:
  """Sets the backend for Meridian.

  Note: The JAX backend is currently under development and should not be used.

  The backend abstraction module calls `config.get_backend()` once, when it is
  first imported, to decide which library to bind. Call this function before
  that import. Changing the selection afterwards does not rebind aliases that
  were already created.

  Args:
    backend: The backend to use, must be a member of the `Backend` enum.

  Raises:
    ValueError: If the provided backend is not a valid `Backend` enum member.
  """
  global _BACKEND
  if not isinstance(backend, Backend):
    raise ValueError("Backend must be a member of the Backend enum.")

  if backend == Backend.JAX:
    warnings.warn(
        (
            "The JAX backend is currently under development and is not yet"
            " functional. It is intended for internal testing only and should"
            " not be used. Please use the TensorFlow backend."
        ),
        UserWarning,
        # Attribute the warning to the caller of `set_backend`, not to this
        # module.
        stacklevel=2,
    )

  _BACKEND = backend
55
+
56
+
57
def get_backend() -> Backend:
  """Reports which `Backend` Meridian is currently configured to use.

  Returns:
    The module-level selection: `Backend.TENSORFLOW` unless `set_backend` has
    stored a different member.
  """
  current_backend = _BACKEND
  return current_backend
@@ -0,0 +1,95 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Common testing utilities for Meridian, designed to be backend-agnostic."""
16
+
17
+ from typing import Any
18
+ import numpy as np
19
+
20
# Alias for any array-like value: NumPy arrays, Python sequences, TensorFlow
# tensors or JAX arrays. It is deliberately just `Any` so this module does not
# have to import the backend module, which avoids a circular dependency.
ArrayLike = Any
24
+
25
+
26
def assert_allclose(
    a: ArrayLike,
    b: ArrayLike,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    err_msg: str = "",
):
  """Asserts that two array-likes are element-wise close, on any backend.

  Both inputs are first materialized with `np.array`, so TensorFlow tensors,
  JAX arrays, NumPy arrays and plain Python sequences can be mixed freely. The
  comparison itself is `np.testing.assert_allclose`.

  Args:
    a: The first array-like object to compare.
    b: The second array-like object to compare.
    rtol: Relative tolerance.
    atol: Absolute tolerance.
    err_msg: Message included in the failure report.

  Raises:
    AssertionError: If the arrays differ by more than the tolerances allow.
  """
  actual = np.array(a)
  desired = np.array(b)
  np.testing.assert_allclose(
      actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
  )
52
+
53
+
54
def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
  """Asserts that two array-likes are exactly equal, on any backend.

  Both inputs are materialized with `np.array` and compared with
  `np.testing.assert_array_equal`.

  Args:
    a: The first array-like object to compare.
    b: The second array-like object to compare.
    err_msg: Message included in the failure report.

  Raises:
    AssertionError: If the arrays are not equal.
  """
  lhs = np.array(a)
  rhs = np.array(b)
  np.testing.assert_array_equal(lhs, rhs, err_msg=err_msg)
68
+
69
+
70
def assert_all_finite(a: ArrayLike, err_msg: str = ""):
  """Asserts that every element of an array-like is finite (no NaN or inf).

  Args:
    a: The array-like object to check; it is materialized with `np.array`.
    err_msg: Custom failure message. When empty, a default message is used.

  Raises:
    AssertionError: If any element is NaN or infinite.
  """
  values = np.array(a)
  if np.isfinite(values).all():
    return
  raise AssertionError(err_msg or "Array contains non-finite values.")
82
+
83
+
84
def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
  """Asserts that every element of an array-like is >= 0.

  The check is `values >= 0`, so NaN entries (for which that comparison is
  False) also trigger the assertion.

  Args:
    a: The array-like object to check; it is materialized with `np.array`.
    err_msg: Custom failure message. When empty, a default message is used.

  Raises:
    AssertionError: If any element is negative (or NaN).
  """
  values = np.array(a)
  all_non_negative = np.all(values >= 0)
  if all_non_negative:
    return
  raise AssertionError(err_msg or "Array contains negative values.")