google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,975 @@
+# Copyright 2025 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Backend Abstraction Layer for Meridian."""
+
+import abc
+import functools
+import os
+from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+
+from meridian.backend import config
+import numpy as np
+from typing_extensions import Literal
+
+
+# The conditional imports in this module are a deliberate design choice for the
+# backend abstraction layer. The TFP-on-JAX substrate provides a nearly
+# identical API to the standard TFP library, making an alias-based approach more
+# pragmatic than a full Abstract Base Class implementation, which would require
+# extensive boilerplate.
+# pylint: disable=g-import-not-at-top,g-bad-import-order
+
+_DEFAULT_FLOAT = "float32"
+_DEFAULT_INT = "int64"
+
+_TENSORFLOW_TILE_KEYWORD = "multiples"
+_JAX_TILE_KEYWORD = "reps"
+
+_ARG_JIT_COMPILE = "jit_compile"
+_ARG_STATIC_ARGNUMS = "static_argnums"
+_ARG_STATIC_ARGNAMES = "static_argnames"
+
+_DEFAULT_SEED_DTYPE = "int32"
+_MAX_INT32 = np.iinfo(np.int32).max
+
+if TYPE_CHECKING:
+  import dataclasses
+  import jax as _jax
+  import tensorflow as _tf
+
+  TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]
+
+SeedType = Any
+
+
+def standardize_dtype(dtype: Any) -> str:
+  """Converts a backend-specific dtype to a standard string representation.
+
+  Args:
+    dtype: A backend-specific dtype object (e.g., tf.DType, np.dtype).
+
+  Returns:
+    A canonical string representation of the dtype (e.g., 'float32').
+  """
+  # Handle None explicitly, as np.dtype(None) defaults to float64.
+  if dtype is None:
+    return str(None)
+
+  if hasattr(dtype, "as_numpy_dtype"):
+    dtype = dtype.as_numpy_dtype
+
+  try:
+    return np.dtype(dtype).name
+  except TypeError:
+    return str(dtype)
+
+
+def result_type(*types: Any) -> str:
+  """Infers the result dtype from a list of input types, backend-agnostically.
+
+  This acts as the single source of truth for type promotion rules. The
+  promotion logic is designed to be consistent across all backends.
+
+  Rule: If any input is a float, the result is float32. Otherwise, the result
+  is int64, matching NumPy's default integer precision. (JAX defaults to
+  int32 unless 64-bit mode is enabled, so pinning the dtype explicitly keeps
+  the backends consistent.)
+
+  Args:
+    *types: A variable number of type objects (e.g., `<class 'int'>`,
+      np.dtype('float32')).
+
+  Returns:
+    A string representing the promoted dtype.
+  """
+  standardized_types = []
+  for t in types:
+    if t is None:
+      continue
+    try:
+      # Standardize the input type before checking promotion rules.
+      standardized_types.append(standardize_dtype(t))
+    except Exception:  # pylint: disable=broad-except
+      # Fallback if standardization fails for an unexpected type.
+      standardized_types.append(str(t))
+
+  if any("float" in t for t in standardized_types):
+    return _DEFAULT_FLOAT
+  return _DEFAULT_INT
+
+
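For illustration, here is a standalone sketch of the promotion rule above in plain NumPy (not the library API): a float anywhere in the inputs promotes the result to float32, and all-integer inputs yield int64.

```python
import numpy as np


def promote(*types):
  # Mirrors result_type: any float input -> "float32", else "int64".
  names = [np.dtype(t).name for t in types if t is not None]
  return "float32" if any("float" in n for n in names) else "int64"


assert promote(int, np.float16) == "float32"  # a float input wins
assert promote(int, np.int32) == "int64"      # all-integer inputs
```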
+def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
+  """Resolves the final dtype for an operation.
+
+  If a dtype is explicitly provided, it is returned. Otherwise, the dtype is
+  inferred from the input arguments using the backend-agnostic `result_type`
+  promotion rules.
+
+  Args:
+    dtype: The user-provided dtype, which may be None.
+    *args: The input arguments to the operation, used for dtype inference.
+
+  Returns:
+    A string representing the resolved dtype.
+  """
+  if dtype is not None:
+    return standardize_dtype(dtype)
+
+  input_types = [
+      getattr(arg, "dtype", type(arg)) for arg in args if arg is not None
+  ]
+  return result_type(*input_types)
+
+
+# --- Private Backend-Specific Implementations ---
+
+
+def _jax_arange(
+    start: Any,
+    stop: Optional[Any] = None,
+    step: Any = 1,
+    dtype: Optional[Any] = None,
+) -> "_jax.Array":
+  """JAX implementation for arange."""
+  # Import locally to make the function self-contained.
+  import jax.numpy as jnp
+
+  resolved_dtype = _resolve_dtype(dtype, start, stop, step)
+  return jnp.arange(start, stop, step=step, dtype=resolved_dtype)
+
+
+def _tf_arange(
+    start: Any,
+    stop: Optional[Any] = None,
+    step: Any = 1,
+    dtype: Optional[Any] = None,
+) -> "_tf.Tensor":
+  """TensorFlow implementation for arange."""
+  import tensorflow as tf
+
+  resolved_dtype = _resolve_dtype(dtype, start, stop, step)
+  try:
+    return tf.range(start, limit=stop, delta=step, dtype=resolved_dtype)
+  except tf.errors.NotFoundError:
+    # Some dtypes lack a registered range kernel; fall back to float32 and
+    # cast the result to the resolved dtype.
+    result = tf.range(start, limit=stop, delta=step, dtype=tf.float32)
+    return tf.cast(result, resolved_dtype)
+
+
+def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
+  """JAX implementation for cast."""
+  return x.astype(dtype)
+
+
+def _jax_divide_no_nan(x, y):
+  """JAX implementation for divide_no_nan."""
+  import jax.numpy as jnp
+
+  return jnp.where(y != 0, jnp.divide(x, y), 0.0)
+
+
+def _jax_function_wrapper(func=None, **kwargs):
+  """A wrapper for jax.jit that handles TF-like args and static args.
+
+  This wrapper provides compatibility with TensorFlow's `tf.function`
+  arguments and improves ergonomics when decorating class methods in JAX.
+
+  By default, if `static_argnums` is not provided, it is set to `(0,)`. This
+  assumes the function is a method where the first argument (`self` or `cls`)
+  should be treated as static.
+
+  To disable this behavior for plain functions, explicitly provide an empty
+  tuple: `@backend.function(static_argnums=())`.
+
+  Args:
+    func: The function to wrap.
+    **kwargs: Keyword arguments passed to jax.jit. TF-specific arguments (like
+      `jit_compile`) are ignored.
+
+  Returns:
+    The wrapped function or a decorator.
+  """
+  # Import locally to make the function self-contained.
+  import jax
+
+  jit_kwargs = kwargs.copy()
+  jit_kwargs.pop(_ARG_JIT_COMPILE, None)
+
+  if _ARG_STATIC_ARGNUMS in jit_kwargs:
+    if not jit_kwargs[_ARG_STATIC_ARGNUMS]:
+      jit_kwargs.pop(_ARG_STATIC_ARGNUMS)
+  else:
+    # Default to treating the first argument (`self`/`cls`) as static.
+    jit_kwargs[_ARG_STATIC_ARGNUMS] = (0,)
+
+  decorator = functools.partial(jax.jit, **jit_kwargs)
+
+  if func:
+    return decorator(func)
+  return decorator
+
+
+def _tf_function_wrapper(func=None, **kwargs):
+  """A wrapper for tf.function that ignores JAX-specific arguments."""
+  import tensorflow as tf
+
+  kwargs.pop(_ARG_STATIC_ARGNAMES, None)
+  kwargs.pop(_ARG_STATIC_ARGNUMS, None)
+
+  decorator = tf.function(**kwargs)
+
+  if func:
+    return decorator(func)
+  return decorator
+
+
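For context, a minimal standalone sketch (plain `jax.jit`, not this wrapper) of why `static_argnums=(0,)` matters when jitting methods: the instance is treated as a compile-time constant rather than traced.

```python
import functools

import jax
import jax.numpy as jnp


class Scaler:
  """Sketch: a method jitted with `self` marked static."""

  def __init__(self, factor: float):
    self.factor = factor

  # Marking `self` static mirrors the wrapper's default behavior.
  @functools.partial(jax.jit, static_argnums=(0,))
  def apply(self, x):
    return x * self.factor


print(Scaler(2.0).apply(jnp.ones(3)))  # [2. 2. 2.]
```

Note that static arguments must be hashable, and each new instance triggers a fresh compilation.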
+def _jax_nanmedian(a, axis=None):
+  """JAX implementation for nanmedian."""
+  import jax.numpy as jnp
+
+  return jnp.nanmedian(a, axis=axis)
+
+
+def _tf_nanmedian(a, axis=None):
+  """TensorFlow implementation for nanmedian using numpy_function."""
+  import tensorflow as tf
+
+  return tf.numpy_function(
+      lambda x: np.nanmedian(x, axis=axis).astype(x.dtype), [a], a.dtype
+  )
+
+
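A small usage sketch of the `tf.numpy_function` escape hatch used above. The wrapped callable runs eagerly in NumPy, so the resulting op has no gradient and cannot be serialized into a SavedModel.

```python
import numpy as np
import tensorflow as tf

x = tf.constant([1.0, float("nan"), 3.0])
# Wrap a NumPy computation as a TF op; Tout declares the output dtype.
med = tf.numpy_function(
    lambda a: np.nanmedian(a).astype(a.dtype), [x], x.dtype
)
print(med.numpy())  # 2.0 (the NaN is ignored)
```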
+def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
+  raise NotImplementedError(
+      "backend.numpy_function is not implemented for the JAX backend."
+  )
+
+
+def _jax_make_tensor_proto(*args, **kwargs):  # pylint: disable=unused-argument
+  raise NotImplementedError(
+      "backend.make_tensor_proto is not implemented for the JAX backend."
+  )
+
+
+def _jax_make_ndarray(*args, **kwargs):  # pylint: disable=unused-argument
+  raise NotImplementedError(
+      "backend.make_ndarray is not implemented for the JAX backend."
+  )
+
+
+def _jax_get_indices_where(condition):
+  """JAX implementation for get_indices_where."""
+  import jax.numpy as jnp
+
+  return jnp.stack(jnp.where(condition), axis=-1)
+
+
+def _tf_get_indices_where(condition):
+  """TensorFlow implementation for get_indices_where."""
+  import tensorflow as tf
+
+  return tf.where(condition)
+
+
+def _jax_split(value, num_or_size_splits, axis=0):
+  """JAX implementation for split that accepts size splits."""
+  import jax.numpy as jnp
+
+  if not isinstance(num_or_size_splits, int):
+    # Convert TF-style section sizes into the cumulative cut indices that
+    # jnp.split expects.
+    indices = jnp.cumsum(jnp.array(num_or_size_splits))[:-1]
+    return jnp.split(value, indices, axis=axis)
+
+  return jnp.split(value, num_or_size_splits, axis=axis)
+
+
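To illustrate the conversion above: TF's `tf.split` takes section sizes, while `jnp.split` takes cut indices; the cumulative sums of the sizes (dropping the last) bridge the two.

```python
import jax.numpy as jnp

x = jnp.arange(10)
sizes = [2, 3, 5]
# Sizes [2, 3, 5] become cut indices [2, 5].
parts = jnp.split(x, jnp.cumsum(jnp.array(sizes))[:-1])
assert [int(p.shape[0]) for p in parts] == sizes
```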
+def _jax_tile(*args, **kwargs):
+  """JAX wrapper for tile that supports the `multiples` keyword argument."""
+  import jax.numpy as jnp
+
+  if _TENSORFLOW_TILE_KEYWORD in kwargs:
+    kwargs[_JAX_TILE_KEYWORD] = kwargs.pop(_TENSORFLOW_TILE_KEYWORD)
+  return jnp.tile(*args, **kwargs)
+
+
+def _jax_unique_with_counts(x):
+  """JAX implementation for unique_with_counts."""
+  import jax.numpy as jnp
+
+  y, counts = jnp.unique(x, return_counts=True)
+  # The TF version returns a tuple of (y, idx, count). The idx is not used in
+  # the calling code, so we return None for it to maintain the tuple
+  # structure.
+  return y, None, counts
+
+
+def _tf_unique_with_counts(x):
+  """TensorFlow implementation for unique_with_counts."""
+  import tensorflow as tf
+
+  return tf.unique_with_counts(x)
+
+
+def _jax_boolean_mask(tensor, mask, axis=None):
+  """JAX implementation for boolean_mask that supports an axis argument."""
+  import jax.numpy as jnp
+
+  if axis is None:
+    axis = 0
+  # Move the masked axis to the front, apply the mask, then move it back.
+  tensor_swapped = jnp.moveaxis(tensor, axis, 0)
+  masked = tensor_swapped[mask]
+  return jnp.moveaxis(masked, 0, axis)
+
+
+def _tf_boolean_mask(tensor, mask, axis=None):
+  """TensorFlow implementation for boolean_mask."""
+  import tensorflow as tf
+
+  return tf.boolean_mask(tensor, mask, axis=axis)
+
+
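A quick check of the moveaxis trick above: masking axis 1 of a (3, 4) array with a length-4 mask keeps the selected columns.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)
mask = jnp.array([True, False, True, False])  # length matches axis 1
# Same dance as _jax_boolean_mask with axis=1.
kept = jnp.moveaxis(jnp.moveaxis(x, 1, 0)[mask], 0, 1)
assert kept.shape == (3, 2)  # columns 0 and 2 survive
```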
+def _jax_gather(params, indices):
+  """JAX implementation for gather."""
+  # JAX uses standard array indexing for gather operations.
+  return params[indices]
+
+
+def _tf_gather(params, indices):
+  """TensorFlow implementation for gather."""
+  import tensorflow as tf
+
+  return tf.gather(params, indices)
+
+
+def _jax_fill(dims, value):
+  """JAX implementation for fill."""
+  import jax.numpy as jnp
+
+  return jnp.full(dims, value)
+
+
+def _tf_fill(dims, value):
+  """TensorFlow implementation for fill."""
+  import tensorflow as tf
+
+  return tf.fill(dims, value)
+
+
+def _jax_argmax(tensor, axis=None):
+  """JAX implementation for argmax, aligned with TensorFlow's default.
+
+  This function finds the indices of the maximum values along a specified
+  axis. Crucially, it mimics the default behavior of TensorFlow's `tf.argmax`,
+  where if `axis` is `None` the operation defaults to `axis=0`. This differs
+  from NumPy's and JAX's native `argmax` behavior, which would flatten the
+  array before finding the index.
+
+  Args:
+    tensor: The input JAX array.
+    axis: An integer specifying the axis along which to find the index of the
+      maximum value. If `None`, it defaults to `0` to match TensorFlow's
+      behavior.
+
+  Returns:
+    A JAX array containing the indices of the maximum values.
+  """
+  import jax.numpy as jnp
+
+  if axis is None:
+    axis = 0
+  return jnp.argmax(tensor, axis=axis)
+
+
+def _tf_argmax(tensor, axis=None):
+  """TensorFlow implementation for argmax."""
+  import tensorflow as tf
+
+  return tf.argmax(tensor, axis=axis)
+
+
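The difference in defaults is easy to see on a small matrix:

```python
import jax.numpy as jnp
import numpy as np

x = np.array([[1, 9],
              [8, 2]])
print(jnp.argmax(x, axis=0))  # [1 0] -- TF-style default: per-column indices
print(np.argmax(x))           # 1    -- NumPy default: index into the flattened array
```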
+def _jax_broadcast_dynamic_shape(shape_x, shape_y):
+  """JAX implementation for broadcast_dynamic_shape."""
+  import jax.numpy as jnp
+
+  return jnp.broadcast_shapes(shape_x, shape_y)
+
+
+def _jax_tensor_shape(dims):
+  """JAX implementation for TensorShape."""
+  if isinstance(dims, int):
+    return (dims,)
+  return tuple(dims)
+
+
+def _jax_transpose(a, perm=None):
+  """JAX wrapper for transpose to support the 'perm' keyword argument."""
+  import jax.numpy as jnp
+
+  return jnp.transpose(a, axes=perm)
+
+
+# --- Backend Initialization ---
+_BACKEND = config.get_backend()
+
+# We expose standardized functions directly at the module level (backend.foo)
+# to provide a consistent, NumPy-like API across backends. The '_ops' object
+# is a private member for accessing the full, raw backend library if necessary,
+# but usage should prefer the top-level standardized functions.
+
+if _BACKEND == config.Backend.JAX:
+  import jax
+  import jax.numpy as jax_ops
+  import tensorflow_probability.substrates.jax as tfp_jax
+
+  class ExtensionType:
+    """A JAX-compatible stand-in for tf.experimental.ExtensionType."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+      raise NotImplementedError(
+          "ExtensionType is not yet implemented for the JAX backend."
+      )
+
+  class _JaxErrors:
+    # pylint: disable=invalid-name
+    ResourceExhaustedError = MemoryError
+    InvalidArgumentError = ValueError
+    # pylint: enable=invalid-name
+
+  class _JaxRandom:
+    """Provides JAX-based random number generation utilities.
+
+    This class mirrors the structure needed by `RNGHandler` for JAX.
+    """
+
+    @staticmethod
+    def prng_key(seed):
+      return jax.random.PRNGKey(seed)
+
+    @staticmethod
+    def split(key):
+      return jax.random.split(key)
+
+    @staticmethod
+    def generator_from_seed(seed):
+      raise NotImplementedError("JAX backend does not use Generators.")
+
+    @staticmethod
+    def stateless_split(seed: Any, num: int = 2):
+      raise NotImplementedError(
+          "Direct stateless splitting from an integer seed is not the primary"
+          " pattern used in the JAX backend. Use `backend.random.split(key)`"
+          " instead."
+      )
+
+    @staticmethod
+    def stateless_randint(key, shape, minval, maxval, dtype=jax_ops.int32):
+      """Wrapper for jax.random.randint."""
+      return jax.random.randint(
+          key, shape, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def stateless_uniform(
+        key, shape, dtype=jax_ops.float32, minval=0.0, maxval=1.0
+    ):
+      """Replacement for tfp_jax.random.stateless_uniform using jax.random.uniform."""
+      return jax.random.uniform(
+          key, shape=shape, dtype=dtype, minval=minval, maxval=maxval
+      )
+
+    @staticmethod
+    def sanitize_seed(seed):
+      return tfp_jax.random.sanitize_seed(seed)
+
+  random = _JaxRandom()
+
+  _ops = jax_ops
+  errors = _JaxErrors()
+  Tensor = jax.Array
+  tfd = tfp_jax.distributions
+  bijectors = tfp_jax.bijectors
+  experimental = tfp_jax.experimental
+  mcmc = tfp_jax.mcmc
+  _convert_to_tensor = _ops.asarray
+
+  # Standardized Public API
+  absolute = _ops.abs
+  allclose = _ops.allclose
+  arange = _jax_arange
+  argmax = _jax_argmax
+  boolean_mask = _jax_boolean_mask
+  broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
+  broadcast_to = _ops.broadcast_to
+  cast = _jax_cast
+  concatenate = _ops.concatenate
+  cumsum = _ops.cumsum
+  divide = _ops.divide
+  divide_no_nan = _jax_divide_no_nan
+  einsum = _ops.einsum
+  equal = _ops.equal
+  exp = _ops.exp
+  expand_dims = _ops.expand_dims
+  fill = _jax_fill
+  function = _jax_function_wrapper
+  gather = _jax_gather
+  get_indices_where = _jax_get_indices_where
+  is_nan = _ops.isnan
+  log = _ops.log
+  make_ndarray = _jax_make_ndarray
+  make_tensor_proto = _jax_make_tensor_proto
+  nanmedian = _jax_nanmedian
+  numpy_function = _jax_numpy_function
+  ones = _ops.ones
+  ones_like = _ops.ones_like
+  rank = _ops.ndim
+  reduce_any = _ops.any
+  reduce_max = _ops.max
+  reduce_mean = _ops.mean
+  reduce_min = _ops.min
+  reduce_std = _ops.std
+  reduce_sum = _ops.sum
+  repeat = _ops.repeat
+  reshape = _ops.reshape
+  split = _jax_split
+  stack = _ops.stack
+  tile = _jax_tile
+  transpose = _jax_transpose
+  unique_with_counts = _jax_unique_with_counts
+  where = _ops.where
+  zeros = _ops.zeros
+  zeros_like = _ops.zeros_like
+
+  float32 = _ops.float32
+  bool_ = _ops.bool_
+  int32 = _ops.int32
+  newaxis = _ops.newaxis
+  TensorShape = _jax_tensor_shape
+
+  def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
+    raise NotImplementedError(
+        "JAX does not support a global, stateful random seed."
+        " `set_random_seed` is not implemented. Instead, you must pass an"
+        " explicit `seed` integer directly to the sampling methods (e.g.,"
+        " `sample_prior`), which will be used to create a JAX PRNGKey"
+        " internally."
+    )
+
+elif _BACKEND == config.Backend.TENSORFLOW:
+  import tensorflow as tf_backend
+  import tensorflow_probability as tfp
+
+  _ops = tf_backend
+  errors = _ops.errors
+
+  Tensor = tf_backend.Tensor
+  ExtensionType = _ops.experimental.ExtensionType
+
+  class _TfRandom:
+    """Provides TensorFlow-based random number generation utilities.
+
+    This class mirrors the structure needed by `RNGHandler` for TensorFlow.
+    """
+
+    @staticmethod
+    def prng_key(seed):
+      raise NotImplementedError(
+          "TensorFlow backend does not use explicit PRNG keys for standard"
+          " sampling."
+      )
+
+    @staticmethod
+    def split(key):
+      raise NotImplementedError(
+          "TensorFlow backend does not implement explicit key splitting."
+      )
+
+    @staticmethod
+    def generator_from_seed(seed):
+      return tf_backend.random.Generator.from_seed(seed)
+
+    @staticmethod
+    def stateless_split(
+        seed: "tf_backend.Tensor", num: int = 2
+    ) -> "tf_backend.Tensor":
+      return tf_backend.random.experimental.stateless_split(seed, num=num)
+
+    @staticmethod
+    def stateless_randint(
+        seed: "tf_backend.Tensor",
+        shape: "TensorShapeInstance",
+        minval: int,
+        maxval: int,
+        dtype: Any = _DEFAULT_SEED_DTYPE,
+    ) -> "tf_backend.Tensor":
+      # tf.random.stateless_uniform acts as a stateless randint when given an
+      # integer dtype together with explicit minval/maxval bounds.
+      return tf_backend.random.stateless_uniform(
+          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def stateless_uniform(
+        seed, shape, minval=0, maxval=None, dtype=tf_backend.float32
+    ):
+      return tf_backend.random.stateless_uniform(
+          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def sanitize_seed(seed):
+      return tfp.random.sanitize_seed(seed)
+
+  random = _TfRandom()
+
+  tfd = tfp.distributions
+  bijectors = tfp.bijectors
+  experimental = tfp.experimental
+  mcmc = tfp.mcmc
+
+  _convert_to_tensor = _ops.convert_to_tensor
+
+  # Standardized Public API
+  absolute = _ops.math.abs
+  allclose = _ops.experimental.numpy.allclose
+  arange = _tf_arange
+  argmax = _tf_argmax
+  boolean_mask = _tf_boolean_mask
+  broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
+  broadcast_to = _ops.broadcast_to
+  cast = _ops.cast
+  concatenate = _ops.concat
+  cumsum = _ops.cumsum
+  divide = _ops.divide
+  divide_no_nan = _ops.math.divide_no_nan
+  einsum = _ops.einsum
+  equal = _ops.equal
+  exp = _ops.math.exp
+  expand_dims = _ops.expand_dims
+  fill = _tf_fill
+  function = _tf_function_wrapper
+  gather = _tf_gather
+  get_indices_where = _tf_get_indices_where
+  is_nan = _ops.math.is_nan
+  log = _ops.math.log
+  make_ndarray = _ops.make_ndarray
+  make_tensor_proto = _ops.make_tensor_proto
+  nanmedian = _tf_nanmedian
+  numpy_function = _ops.numpy_function
+  ones = _ops.ones
+  ones_like = _ops.ones_like
+  rank = _ops.rank
+  reduce_any = _ops.reduce_any
+  reduce_max = _ops.reduce_max
+  reduce_mean = _ops.reduce_mean
+  reduce_min = _ops.reduce_min
+  reduce_std = _ops.math.reduce_std
+  reduce_sum = _ops.reduce_sum
+  repeat = _ops.repeat
+  reshape = _ops.reshape
+  set_random_seed = tf_backend.keras.utils.set_random_seed
+  split = _ops.split
+  stack = _ops.stack
+  tile = _ops.tile
+  transpose = _ops.transpose
+  unique_with_counts = _tf_unique_with_counts
+  where = _ops.where
+  zeros = _ops.zeros
+  zeros_like = _ops.zeros_like
+
+  float32 = _ops.float32
+  bool_ = _ops.bool
+  int32 = _ops.int32
+  newaxis = _ops.newaxis
+  TensorShape = _ops.TensorShape
+
+else:
+  raise ValueError(f"Unsupported backend: {_BACKEND}")
+# pylint: enable=g-import-not-at-top,g-bad-import-order
+
+
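Illustrative usage of the resulting module-level API (a hypothetical sketch; the exact import path depends on how the package exposes this module). The same code runs unchanged under either backend:

```python
from meridian import backend  # assumed import path

x = backend.to_tensor([[1.0, 2.0], [3.0, 4.0]])
col_sums = backend.reduce_sum(x, axis=0)
safe = backend.divide_no_nan(x, backend.zeros_like(x))  # 0 where y == 0
```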
+def _extract_int_seed(s: Any) -> Optional[int]:
+  """Attempts to extract a scalar Python integer from various input types.
+
+  Args:
+    s: The input seed, which can be an int, Tensor, or array-like object.
+
+  Returns:
+    A Python integer if a scalar integer can be extracted, otherwise None.
+  """
+  try:
+    if isinstance(s, int):
+      return s
+
+    value = np.asarray(s)
+
+    if value.ndim == 0 and np.issubdtype(value.dtype, np.integer):
+      return int(value)
+  # A broad exception is used here because the input `s` can be of many types
+  # (e.g., JAX PRNGKey) which may cause np.asarray or other operations to fail
+  # in unpredictable ways. The goal is to safely attempt extraction and fail
+  # gracefully.
+  except Exception:  # pylint: disable=broad-except
+    pass
+  return None
+
+
+class _BaseRNGHandler(abc.ABC):
+  """A backend-agnostic abstract base class for random number generation state.
+
+  This handler provides a stateful-style interface for consuming randomness,
+  abstracting away the differences between JAX's stateless PRNG keys and
+  TensorFlow's paradigms.
+
+  Attributes:
+    _seed_input: The original seed object provided during initialization.
+    _int_seed: A Python integer extracted from the seed, if possible.
+  """
+
+  def __init__(self, seed: SeedType):
+    """Initializes the RNG handler.
+
+    Args:
+      seed: The initial seed. The accepted type depends on the backend. For
+        JAX, this must be an integer. For TensorFlow, this can be an integer,
+        a sequence of two integers, or a Tensor. If None, the handler becomes
+        a no-op, returning None for all seed requests.
+    """
+    self._seed_input = seed
+    self._int_seed: Optional[int] = _extract_int_seed(seed)
+
+  @abc.abstractmethod
+  def get_kernel_seed(self) -> Any:
+    """Provides a backend-appropriate sanitized seed/key for an MCMC kernel.
+
+    This method exposes the current state of the handler in the format
+    expected by the backend's MCMC machinery. It does not advance the internal
+    state.
+
+    Returns:
+      A backend-specific seed object (e.g., a JAX PRNGKey or a TF Tensor).
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def get_next_seed(self) -> Any:
+    """Provides the appropriate seed object for the next sequential operation.
+
+    This is primarily used for prior sampling and typically advances the
+    internal state.
+
+    Returns:
+      A backend-specific seed object for a single random operation.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def advance_handler(self) -> "_BaseRNGHandler":
+    """Creates a new, independent RNGHandler for a subsequent operation.
+
+    This method is used to generate a new handler derived deterministically
+    from the current handler's state.
+
+    Returns:
+      A new, independent `RNGHandler` instance.
+    """
+    raise NotImplementedError
+
+
+class _JaxRNGHandler(_BaseRNGHandler):
+  """JAX implementation of the RNGHandler using explicit key splitting."""
+
+  def __init__(self, seed: SeedType):
+    """Initializes the JAX RNG handler.
+
+    Args:
+      seed: The initial seed, which must be a Python integer, a scalar integer
+        Tensor/array, or None.
+
+    Raises:
+      ValueError: If the provided seed is not a scalar integer or None.
+    """
+    super().__init__(seed)
+    self._key: Optional["_jax.Array"] = None
+
+    if seed is None:
+      return
+
+    if self._int_seed is None:
+      raise ValueError(
+          "JAX backend requires a seed that is an integer or a scalar array,"
+          f" but got: {type(seed)} with value {seed!r}"
+      )
+
+    self._key = random.prng_key(self._int_seed)
+
+  def get_kernel_seed(self) -> Any:
+    if self._key is None:
+      return None
+    return random.sanitize_seed(self._key)
+
+  def get_next_seed(self) -> Any:
+    if self._key is None:
+      return None
+    self._key, subkey = random.split(self._key)
+    return subkey
+
+  def advance_handler(self) -> "_JaxRNGHandler":
+    if self._key is None:
+      return _JaxRNGHandler(None)
+
+    self._key, subkey_for_new_handler = random.split(self._key)
+    new_seed_tensor = random.stateless_randint(
+        key=subkey_for_new_handler,
+        shape=(),
+        minval=0,
+        maxval=_MAX_INT32,
+        dtype=_DEFAULT_SEED_DTYPE,
+    )
+    return _JaxRNGHandler(np.asarray(new_seed_tensor).item())
+
+
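The handler's state discipline is plain `jax.random` key splitting; a standalone sketch of the same pattern:

```python
import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)        # advance the state, hand out subkey
first = jax.random.uniform(subkey, (2,))
key, subkey = jax.random.split(key)
second = jax.random.uniform(subkey, (2,))  # independent of `first`
```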
+# TODO: Replace with _TFRNGHandler
+class _TFLegacyRNGHandler(_BaseRNGHandler):
+  """TensorFlow implementation of the RNGHandler.
+
+  TODO: This class should be removed and replaced with a correct,
+  stateful `tf.random.Generator`-based implementation.
+  """
+
+  def __init__(self, seed: SeedType, *, _sanitized_seed: Optional[Any] = None):
+    """Initializes the TensorFlow legacy RNG handler.
+
+    Args:
+      seed: The initial seed. Can be an integer, a sequence of two integers, a
+        corresponding Tensor, or None.
+      _sanitized_seed: For internal use only. If provided, this pre-computed
+        seed tensor is used directly, and the standard initialization logic
+        for the public `seed` argument is bypassed.
+
+    Raises:
+      ValueError: If `seed` is a sequence with a length other than 2.
+    """
+    super().__init__(seed)
+    self._tf_sanitized_seed: Optional[Any] = None
+
+    if _sanitized_seed is not None:
+      # Internal path: a pre-sanitized seed was provided by a trusted source,
+      # so we adopt it directly.
+      self._tf_sanitized_seed = _sanitized_seed
+      return
+
+    if seed is None:
+      return
+
+    if isinstance(seed, Sequence) and len(seed) != 2:
+      raise ValueError(
+          "Invalid seed: Must be either a single integer (stateful seed) or a"
+          " pair of two integers (stateless seed). See"
+          " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
+          " for details."
+      )
+
+    if isinstance(seed, int):
+      # Duplicate the scalar into a (seed, seed) pair so it sanitizes into a
+      # stateless-style seed tensor.
+      seed_to_sanitize = (seed, seed)
+    else:
+      seed_to_sanitize = seed
+
+    self._tf_sanitized_seed = random.sanitize_seed(seed_to_sanitize)
+
+  def get_kernel_seed(self) -> Any:
+    return self._tf_sanitized_seed
+
+  def get_next_seed(self) -> Any:
+    """Returns the original integer seed to preserve prior sampling behavior.
+
+    Returns:
+      The original integer seed provided during initialization.
+
+    Raises:
+      RuntimeError: If the handler was not initialized with a scalar integer
+        seed, which is required for the legacy prior sampling path.
+    """
+    if self._seed_input is None:
+      return None
+
+    if self._int_seed is None:
+      raise RuntimeError(
+          "RNGHandler was not initialized with a scalar integer seed, cannot"
+          " provide seed for TensorFlow prior sampling."
+      )
+    return self._int_seed
+
+  def advance_handler(self) -> "_TFLegacyRNGHandler":
+    """Creates a new handler by incrementing the sanitized seed by 1.
+
+    Returns:
+      A new `_TFLegacyRNGHandler` instance with an incremented seed state.
+
+    Raises:
+      RuntimeError: If the handler's sanitized seed was not initialized.
+    """
+    if self._seed_input is None:
+      return _TFLegacyRNGHandler(None)
+
+    if self._tf_sanitized_seed is None:
+      # Should be caught during init, but included for defensive programming.
+      raise RuntimeError("RNGHandler sanitized seed not initialized.")
+
+    new_sanitized_seed = self._tf_sanitized_seed + 1
+
+    # Create a new handler instance, passing the original seed input (to
+    # preserve state like `_int_seed`) and injecting the new sanitized seed
+    # via the private constructor argument.
+    return _TFLegacyRNGHandler(
+        self._seed_input, _sanitized_seed=new_sanitized_seed
+    )
+
+
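For reference, a small sketch using TFP directly that mirrors the scalar-seed handling above: the scalar is duplicated into a two-element stateless seed, and advancing simply increments that tensor.

```python
import tensorflow_probability as tfp

seed = tfp.random.sanitize_seed((7, 7))  # scalar 7 duplicated into a pair
advanced = seed + 1                      # deterministic successor state
```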
+if _BACKEND == config.Backend.JAX:
+  RNGHandler = _JaxRNGHandler
+elif _BACKEND == config.Backend.TENSORFLOW:
+  RNGHandler = _TFLegacyRNGHandler  # TODO: Replace with _TFRNGHandler
+else:
+  raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
+
+
+def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor:  # type: ignore
+  """Converts input data to the currently active backend tensor type.
+
+  Args:
+    data: The data to convert.
+    dtype: The desired data type of the resulting tensor. The accepted types
+      depend on the active backend (e.g., jax.numpy.dtype or tf.DType).
+
+  Returns:
+    A tensor representation of the data for the active backend.
+  """
+  return _convert_to_tensor(data, dtype=dtype)
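A short usage sketch (assumed import path; dtype strings are accepted by both `jnp.asarray` and `tf.convert_to_tensor`):

```python
import numpy as np
from meridian import backend  # assumed import path

t = backend.to_tensor(np.arange(3), dtype="float32")
print(backend.standardize_dtype(t.dtype))  # 'float32' under either backend
```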