google-meridian 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,13 +14,16 @@
14
14
 
15
15
  """Backend Abstraction Layer for Meridian."""
16
16
 
17
+ import abc
18
+ import functools
17
19
  import os
18
- from typing import Any, Optional, TYPE_CHECKING, Tuple, Union
20
+ from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
19
21
 
20
22
  from meridian.backend import config
21
23
  import numpy as np
22
24
  from typing_extensions import Literal
23
25
 
26
+
24
27
  # The conditional imports in this module are a deliberate design choice for the
25
28
  # backend abstraction layer. The TFP-on-JAX substrate provides a nearly
26
29
  # identical API to the standard TFP library, making an alias-based approach more
@@ -28,6 +31,19 @@ from typing_extensions import Literal
28
31
  # extensive boilerplate.
29
32
  # pylint: disable=g-import-not-at-top,g-bad-import-order
30
33
 
34
+ _DEFAULT_FLOAT = "float32"
35
+ _DEFAULT_INT = "int64"
36
+
37
+ _TENSORFLOW_TILE_KEYWORD = "multiples"
38
+ _JAX_TILE_KEYWORD = "reps"
39
+
40
+ _ARG_JIT_COMPILE = "jit_compile"
41
+ _ARG_STATIC_ARGNUMS = "static_argnums"
42
+ _ARG_STATIC_ARGNAMES = "static_argnames"
43
+
44
+ _DEFAULT_SEED_DTYPE = "int32"
45
+ _MAX_INT32 = np.iinfo(np.int32).max
46
+
31
47
  if TYPE_CHECKING:
32
48
  import dataclasses
33
49
  import jax as _jax
@@ -35,6 +51,8 @@ if TYPE_CHECKING:
35
51
 
36
52
  TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]
37
53
 
54
+ SeedType = Any
55
+
38
56
 
39
57
  def standardize_dtype(dtype: Any) -> str:
40
58
  """Converts a backend-specific dtype to a standard string representation.
@@ -88,8 +106,8 @@ def result_type(*types: Any) -> str:
88
106
  standardized_types.append(str(t))
89
107
 
90
108
  if any("float" in t for t in standardized_types):
91
- return "float32"
92
- return "int64"
109
+ return _DEFAULT_FLOAT
110
+ return _DEFAULT_INT
93
111
 
94
112
 
95
113
  def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
@@ -163,6 +181,74 @@ def _jax_divide_no_nan(x, y):
163
181
  return jnp.where(y != 0, jnp.divide(x, y), 0.0)
164
182
 
165
183
 
184
+ def _jax_function_wrapper(func=None, **kwargs):
185
+ """A wrapper for jax.jit that handles TF-like args and static args.
186
+
187
+ This wrapper provides compatibility with TensorFlow's `tf.function` arguments
188
+ and improves ergonomics when decorating class methods in JAX.
189
+
190
+ By default, if neither `static_argnums` nor `static_argnames` are provided, it
191
+ defaults `static_argnums` to `(0,)`. This assumes the function is a method
192
+ where the first argument (`self` or `cls`) should be treated as static.
193
+
194
+ To disable this behavior for plain functions, explicitly provide an empty
195
+ tuple: `@backend.function(static_argnums=())`.
196
+
197
+ Args:
198
+ func: The function to wrap.
199
+ **kwargs: Keyword arguments passed to jax.jit. TF-specific arguments (like
200
+ `jit_compile`) are ignored.
201
+
202
+ Returns:
203
+ The wrapped function or a decorator.
204
+ """
205
+ jit_kwargs = kwargs.copy()
206
+
207
+ jit_kwargs.pop(_ARG_JIT_COMPILE, None)
208
+
209
+ if _ARG_STATIC_ARGNUMS in jit_kwargs:
210
+ if not jit_kwargs[_ARG_STATIC_ARGNUMS]:
211
+ jit_kwargs.pop(_ARG_STATIC_ARGNUMS)
212
+ else:
213
+ jit_kwargs[_ARG_STATIC_ARGNUMS] = (0,)
214
+
215
+ decorator = functools.partial(jax.jit, **jit_kwargs)
216
+
217
+ if func:
218
+ return decorator(func)
219
+ return decorator
220
+
221
+
222
def _tf_function_wrapper(func=None, **kwargs):
  """A wrapper for tf.function that ignores JAX-specific arguments."""
  import tensorflow as tf

  # Filter out jax.jit-only keywords before handing the rest to tf.function.
  tf_kwargs = {
      key: value
      for key, value in kwargs.items()
      if key not in (_ARG_STATIC_ARGNAMES, _ARG_STATIC_ARGNUMS)
  }

  decorator = tf.function(**tf_kwargs)
  return decorator(func) if func else decorator
234
+
235
+
236
+ def _jax_nanmedian(a, axis=None):
237
+ """JAX implementation for nanmedian."""
238
+ import jax.numpy as jnp
239
+
240
+ return jnp.nanmedian(a, axis=axis)
241
+
242
+
243
def _tf_nanmedian(a, axis=None):
  """TensorFlow implementation for nanmedian using numpy_function."""
  import tensorflow as tf

  def _nanmedian_np(x):
    # Cast back to the input dtype: np.nanmedian may promote to float64.
    return np.nanmedian(x, axis=axis).astype(x.dtype)

  # Evaluates the NumPy computation eagerly from within the TF graph.
  return tf.numpy_function(_nanmedian_np, [a], a.dtype)
250
+
251
+
166
252
  def _jax_numpy_function(*args, **kwargs): # pylint: disable=unused-argument
167
253
  raise NotImplementedError(
168
254
  "backend.numpy_function is not implemented for the JAX backend."
@@ -195,6 +281,26 @@ def _tf_get_indices_where(condition):
195
281
  return tf.where(condition)
196
282
 
197
283
 
284
+ def _jax_split(value, num_or_size_splits, axis=0):
285
+ """JAX implementation for split that accepts size splits."""
286
+ import jax.numpy as jnp
287
+
288
+ if not isinstance(num_or_size_splits, int):
289
+ indices = jnp.cumsum(jnp.array(num_or_size_splits))[:-1]
290
+ return jnp.split(value, indices, axis=axis)
291
+
292
+ return jnp.split(value, num_or_size_splits, axis=axis)
293
+
294
+
295
+ def _jax_tile(*args, **kwargs):
296
+ """JAX wrapper for tile that supports the `multiples` keyword argument."""
297
+ import jax.numpy as jnp
298
+
299
+ if _TENSORFLOW_TILE_KEYWORD in kwargs:
300
+ kwargs[_JAX_TILE_KEYWORD] = kwargs.pop(_TENSORFLOW_TILE_KEYWORD)
301
+ return jnp.tile(*args, **kwargs)
302
+
303
+
198
304
  def _jax_unique_with_counts(x):
199
305
  """JAX implementation for unique_with_counts."""
200
306
  import jax.numpy as jnp
@@ -304,6 +410,13 @@ def _jax_tensor_shape(dims):
304
410
  return tuple(dims)
305
411
 
306
412
 
413
+ def _jax_transpose(a, perm=None):
414
+ """JAX wrapper for transpose to support the 'perm' keyword argument."""
415
+ import jax.numpy as jnp
416
+
417
+ return jnp.transpose(a, axes=perm)
418
+
419
+
307
420
  # --- Backend Initialization ---
308
421
  _BACKEND = config.get_backend()
309
422
 
@@ -331,13 +444,60 @@ if _BACKEND == config.Backend.JAX:
331
444
  InvalidArgumentError = ValueError
332
445
  # pylint: enable=invalid-name
333
446
 
447
  class _JaxRandom:
    """Provides JAX-based random number generation utilities.

    This class mirrors the structure needed by `RNGHandler` for JAX: explicit
    PRNG keys created from integer seeds and consumed by splitting.
    """

    @staticmethod
    def prng_key(seed):
      # Builds a fresh PRNG key from a scalar integer seed.
      return jax.random.PRNGKey(seed)

    @staticmethod
    def split(key):
      # jax.random.split defaults to num=2, yielding a (2, ...) key array
      # that callers unpack as (new_state, subkey).
      return jax.random.split(key)

    @staticmethod
    def generator_from_seed(seed):
      # Stateful Generators are a TensorFlow concept; the JAX backend relies
      # on explicit keys instead.
      raise NotImplementedError("JAX backend does not use Generators.")

    @staticmethod
    def stateless_split(seed: Any, num: int = 2):
      raise NotImplementedError(
          "Direct stateless splitting from an integer seed is not the primary"
          " pattern used in the JAX backend. Use `backend.random.split(key)`"
          " instead."
      )

    @staticmethod
    def stateless_randint(key, shape, minval, maxval, dtype=jax_ops.int32):
      """Wrapper for jax.random.randint."""
      # `maxval` is exclusive, matching jax.random.randint semantics.
      return jax.random.randint(
          key, shape, minval=minval, maxval=maxval, dtype=dtype
      )

    @staticmethod
    def stateless_uniform(
        key, shape, dtype=jax_ops.float32, minval=0.0, maxval=1.0
    ):
      """Replacement for tfp_jax.random.stateless_uniform using jax.random.uniform."""
      return jax.random.uniform(
          key, shape=shape, dtype=dtype, minval=minval, maxval=maxval
      )

    @staticmethod
    def sanitize_seed(seed):
      # Delegates to TFP-on-JAX to normalize the seed into canonical key form.
      return tfp_jax.random.sanitize_seed(seed)

  random = _JaxRandom()
494
+
334
495
  _ops = jax_ops
335
496
  errors = _JaxErrors()
336
497
  Tensor = jax.Array
337
498
  tfd = tfp_jax.distributions
338
499
  bijectors = tfp_jax.bijectors
339
500
  experimental = tfp_jax.experimental
340
- random = tfp_jax.random
341
501
  mcmc = tfp_jax.mcmc
342
502
  _convert_to_tensor = _ops.asarray
343
503
 
@@ -349,7 +509,7 @@ if _BACKEND == config.Backend.JAX:
349
509
  boolean_mask = _jax_boolean_mask
350
510
  concatenate = _ops.concatenate
351
511
  stack = _ops.stack
352
- split = _ops.split
512
+ split = _jax_split
353
513
  zeros = _ops.zeros
354
514
  zeros_like = _ops.zeros_like
355
515
  ones = _ops.ones
@@ -358,7 +518,6 @@ if _BACKEND == config.Backend.JAX:
358
518
  reshape = _ops.reshape
359
519
  tile = _ops.tile
360
520
  where = _ops.where
361
- transpose = _ops.transpose
362
521
  broadcast_to = _ops.broadcast_to
363
522
  broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
364
523
  broadcast_to = _ops.broadcast_to
@@ -372,13 +531,14 @@ if _BACKEND == config.Backend.JAX:
372
531
  exp = _ops.exp
373
532
  expand_dims = _ops.expand_dims
374
533
  fill = _jax_fill
375
- function = jax.jit
534
+ function = _jax_function_wrapper
376
535
  gather = _jax_gather
377
536
  get_indices_where = _jax_get_indices_where
378
537
  is_nan = _ops.isnan
379
538
  log = _ops.log
380
539
  make_ndarray = _jax_make_ndarray
381
540
  make_tensor_proto = _jax_make_tensor_proto
541
+ nanmedian = _jax_nanmedian
382
542
  numpy_function = _jax_numpy_function
383
543
  ones = _ops.ones
384
544
  ones_like = _ops.ones_like
@@ -391,10 +551,9 @@ if _BACKEND == config.Backend.JAX:
391
551
  reduce_sum = _ops.sum
392
552
  repeat = _ops.repeat
393
553
  reshape = _ops.reshape
394
- split = _ops.split
395
554
  stack = _ops.stack
396
- tile = _ops.tile
397
- transpose = _ops.transpose
555
+ tile = _jax_tile
556
+ transpose = _jax_transpose
398
557
  unique_with_counts = _jax_unique_with_counts
399
558
  where = _ops.where
400
559
  zeros = _ops.zeros
@@ -404,6 +563,7 @@ if _BACKEND == config.Backend.JAX:
404
563
  bool_ = _ops.bool_
405
564
  newaxis = _ops.newaxis
406
565
  TensorShape = _jax_tensor_shape
566
+ int32 = _ops.int32
407
567
 
408
568
  def set_random_seed(seed: int) -> None: # pylint: disable=unused-argument
409
569
  raise NotImplementedError(
@@ -423,10 +583,64 @@ elif _BACKEND == config.Backend.TENSORFLOW:
423
583
  Tensor = tf_backend.Tensor
424
584
  ExtensionType = _ops.experimental.ExtensionType
425
585
 
586
  class _TfRandom:
    """Provides TensorFlow-based random number generation utilities.

    This class mirrors the structure needed by `RNGHandler` for TensorFlow.
    Key-splitting entry points raise, since TF sampling does not use explicit
    PRNG keys the way JAX does.
    """

    @staticmethod
    def prng_key(seed):
      raise NotImplementedError(
          "TensorFlow backend does not use explicit PRNG keys for"
          " standard sampling."
      )

    @staticmethod
    def split(key):
      raise NotImplementedError(
          "TensorFlow backend does not implement explicit key splitting."
      )

    @staticmethod
    def generator_from_seed(seed):
      # Stateful generator for draws that need a reproducible stream.
      return tf_backend.random.Generator.from_seed(seed)

    @staticmethod
    def stateless_split(
        seed: "tf_backend.Tensor", num: int = 2
    ) -> "tf_backend.Tensor":
      # Derives `num` new stateless seed pairs from one [2]-shaped seed.
      return tf_backend.random.experimental.stateless_split(seed, num=num)

    @staticmethod
    def stateless_randint(
        seed: "tf_backend.Tensor",
        shape: "TensorShapeInstance",
        minval: int,
        maxval: int,
        dtype: Any = _DEFAULT_SEED_DTYPE,
    ) -> "tf_backend.Tensor":
      # NOTE(review): integer draws are routed through stateless_uniform,
      # which accepts integer dtypes only when both minval and maxval are
      # given — both are required parameters here, so that holds. Confirm the
      # "int32" default matches the kernel-seed dtype callers expect.
      return tf_backend.random.stateless_uniform(
          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
      )

    @staticmethod
    def stateless_uniform(
        seed, shape, minval=0, maxval=None, dtype=tf_backend.float32
    ):
      return tf_backend.random.stateless_uniform(
          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
      )

    @staticmethod
    def sanitize_seed(seed):
      # Normalizes int or (int, int) seeds into TFP's canonical seed tensor.
      return tfp.random.sanitize_seed(seed)

  random = _TfRandom()
640
+
426
641
  tfd = tfp.distributions
427
642
  bijectors = tfp.bijectors
428
643
  experimental = tfp.experimental
429
- random = tfp.random
430
644
  mcmc = tfp.mcmc
431
645
 
432
646
  _convert_to_tensor = _ops.convert_to_tensor
@@ -446,7 +660,6 @@ elif _BACKEND == config.Backend.TENSORFLOW:
446
660
  reshape = _ops.reshape
447
661
  tile = _ops.tile
448
662
  where = _ops.where
449
- transpose = _ops.transpose
450
663
  broadcast_to = _ops.broadcast_to
451
664
  broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
452
665
  broadcast_to = _ops.broadcast_to
@@ -460,13 +673,14 @@ elif _BACKEND == config.Backend.TENSORFLOW:
460
673
  exp = _ops.math.exp
461
674
  expand_dims = _ops.expand_dims
462
675
  fill = _tf_fill
463
- function = _ops.function
676
+ function = _tf_function_wrapper
464
677
  gather = _tf_gather
465
678
  get_indices_where = _tf_get_indices_where
466
679
  is_nan = _ops.math.is_nan
467
680
  log = _ops.math.log
468
681
  make_ndarray = _ops.make_ndarray
469
682
  make_tensor_proto = _ops.make_tensor_proto
683
+ nanmedian = _tf_nanmedian
470
684
  numpy_function = _ops.numpy_function
471
685
  ones = _ops.ones
472
686
  ones_like = _ops.ones_like
@@ -480,7 +694,6 @@ elif _BACKEND == config.Backend.TENSORFLOW:
480
694
  repeat = _ops.repeat
481
695
  reshape = _ops.reshape
482
696
  set_random_seed = tf_backend.keras.utils.set_random_seed
483
- split = _ops.split
484
697
  stack = _ops.stack
485
698
  tile = _ops.tile
486
699
  transpose = _ops.transpose
@@ -493,12 +706,260 @@ elif _BACKEND == config.Backend.TENSORFLOW:
493
706
  bool_ = _ops.bool
494
707
  newaxis = _ops.newaxis
495
708
  TensorShape = _ops.TensorShape
709
+ int32 = _ops.int32
496
710
 
497
711
  else:
498
712
  raise ValueError(f"Unsupported backend: {_BACKEND}")
499
713
  # pylint: enable=g-import-not-at-top,g-bad-import-order
500
714
 
501
715
 
716
+ def _extract_int_seed(s: Any) -> Optional[int]:
717
+ """Attempts to extract a scalar Python integer from various input types.
718
+
719
+ Args:
720
+ s: The input seed, which can be an int, Tensor, or array-like object.
721
+
722
+ Returns:
723
+ A Python integer if a scalar integer can be extracted, otherwise None.
724
+ """
725
+ try:
726
+ if isinstance(s, int):
727
+ return s
728
+
729
+ value = np.asarray(s)
730
+
731
+ if value.ndim == 0 and np.issubdtype(value.dtype, np.integer):
732
+ return int(value)
733
+ # A broad exception is used here because the input `s` can be of many types
734
+ # (e.g., JAX PRNGKey) which may cause np.asarray or other operations to fail
735
+ # in unpredictable ways. The goal is to safely attempt extraction and fail
736
+ # gracefully.
737
+ except Exception: # pylint: disable=broad-except
738
+ pass
739
+ return None
740
+
741
+
742
class _BaseRNGHandler(abc.ABC):
  """A backend-agnostic abstract base class for random number generation state.

  This handler provides a stateful-style interface for consuming randomness,
  abstracting away the differences between JAX's stateless PRNG keys and
  TensorFlow's paradigms.

  Attributes:
    _seed_input: The original seed object provided during initialization.
    _int_seed: A Python integer extracted from the seed, if possible.
  """

  # NOTE(review): `SeedType` is declared under `if TYPE_CHECKING:`, so the
  # annotation is quoted to avoid evaluating a name that may not exist at
  # runtime — confirm the module's annotation-evaluation setup.
  def __init__(self, seed: "SeedType"):
    """Initializes the RNG handler.

    Args:
      seed: The initial seed. The accepted type depends on the backend. For JAX,
        this must be an integer. For TensorFlow, this can be an integer, a
        sequence of two integers, or a Tensor. If None, the handler becomes a
        no-op, returning None for all seed requests.
    """
    self._seed_input = seed
    # None whenever the seed is not a scalar integer (e.g. a stateless pair).
    self._int_seed: Optional[int] = _extract_int_seed(seed)

  @abc.abstractmethod
  def get_kernel_seed(self) -> Any:
    """Provides a backend-appropriate sanitized seed/key for an MCMC kernel.

    This method exposes the current state of the handler in the format expected
    by the backend's MCMC machinery. It does not advance the internal state.

    Returns:
      A backend-specific seed object (e.g., a JAX PRNGKey or a TF Tensor).
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_next_seed(self) -> Any:
    """Provides the appropriate seed object for the next sequential operation.

    This is primarily used for prior sampling and typically advances the
    internal state.

    Returns:
      A backend-specific seed object for a single random operation.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def advance_handler(self) -> "_BaseRNGHandler":
    """Creates a new, independent RNGHandler for a subsequent operation.

    This method is used to generate a new handler derived deterministically
    from the current handler's state.

    Returns:
      A new, independent `RNGHandler` instance.
    """
    raise NotImplementedError
+
803
class _JaxRNGHandler(_BaseRNGHandler):
  """JAX RNGHandler: consumes randomness by splitting an explicit PRNG key."""

  def __init__(self, seed: SeedType):
    """Builds the handler from a scalar integer seed (or None for a no-op).

    Args:
      seed: The initial seed, which must be a Python integer, a scalar integer
        Tensor/array, or None.

    Raises:
      ValueError: If the provided seed is not a scalar integer or None.
    """
    super().__init__(seed)
    self._key: Optional["_jax.Array"] = None

    if seed is None:
      # No-op handler: every seed request will return None.
      return

    if self._int_seed is None:
      raise ValueError(
          "JAX backend requires a seed that is an integer or a scalar array,"
          f" but got: {type(seed)} with value {seed!r}"
      )

    self._key = random.prng_key(self._int_seed)

  def get_kernel_seed(self) -> Any:
    # Sanitize the current key without consuming it.
    return None if self._key is None else random.sanitize_seed(self._key)

  def get_next_seed(self) -> Any:
    if self._key is None:
      return None
    # Advance the internal state and hand out the freshly split subkey.
    next_state, consumed_key = random.split(self._key)
    self._key = next_state
    return consumed_key

  def advance_handler(self) -> "_JaxRNGHandler":
    if self._key is None:
      return _JaxRNGHandler(None)

    # Derive a fresh integer seed from a split of the current key, then build
    # an independent handler around it.
    next_state, derivation_key = random.split(self._key)
    self._key = next_state
    derived_seed = random.stateless_randint(
        key=derivation_key,
        shape=(),
        minval=0,
        maxval=_MAX_INT32,
        dtype=_DEFAULT_SEED_DTYPE,
    )
    return _JaxRNGHandler(np.asarray(derived_seed).item())
854
+
855
+
856
# TODO: Replace with _TFRNGHandler
class _TFLegacyRNGHandler(_BaseRNGHandler):
  """TensorFlow implementation.

  TODO: This class should be removed and replaced with a correct,
  stateful `tf.random.Generator`-based implementation.
  """

  # NOTE(review): `SeedType` is declared under `if TYPE_CHECKING:`; the
  # annotation is quoted so it is not evaluated at runtime — confirm the
  # module's annotation-evaluation setup.
  def __init__(self, seed: "SeedType", *, _sanitized_seed: Optional[Any] = None):
    """Initializes the TensorFlow legacy RNG handler.

    Args:
      seed: The initial seed. Can be an integer, a sequence of two integers, a
        corresponding Tensor, or None.
      _sanitized_seed: For internal use only. If provided, this pre-computed
        seed tensor is used directly, and the standard initialization logic for
        the public `seed` argument is bypassed.

    Raises:
      ValueError: If `seed` is a sequence with a length other than 2.
    """
    super().__init__(seed)
    self._tf_sanitized_seed: Optional[Any] = None

    if _sanitized_seed is not None:
      # Internal path: A pre-sanitized seed was provided by a trusted source
      # so we adopt it directly.
      self._tf_sanitized_seed = _sanitized_seed
      return

    if seed is None:
      # No-op handler: all seed requests return None.
      return

    if isinstance(seed, Sequence) and len(seed) != 2:
      raise ValueError(
          "Invalid seed: Must be either a single integer (stateful seed) or a"
          " pair of two integers (stateless seed). See"
          " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
          " for details."
      )

    if isinstance(seed, int):
      # A scalar int is duplicated into a (seed, seed) pair so it can be
      # sanitized into the stateless-seed tensor format.
      seed_to_sanitize = (seed, seed)
    else:
      seed_to_sanitize = seed

    self._tf_sanitized_seed = random.sanitize_seed(seed_to_sanitize)

  def get_kernel_seed(self) -> Any:
    # Exposes the sanitized stateless seed tensor (or None for a no-op
    # handler) without advancing any state.
    return self._tf_sanitized_seed

  def get_next_seed(self) -> Any:
    """Returns the original integer seed to preserve prior sampling behavior.

    Returns:
      The original integer seed provided during initialization.

    Raises:
      RuntimeError: If the handler was not initialized with a scalar integer
        seed, which is required for the legacy prior sampling path.
    """
    if self._seed_input is None:
      return None

    if self._int_seed is None:
      raise RuntimeError(
          "RNGHandler was not initialized with a scalar integer seed, cannot"
          " provide seed for TensorFlow prior sampling."
      )
    return self._int_seed

  def advance_handler(self) -> "_TFLegacyRNGHandler":
    """Creates a new handler by incrementing the sanitized seed by 1.

    Returns:
      A new `_TFLegacyRNGHandler` instance with an incremented seed state.

    Raises:
      RuntimeError: If the handler's sanitized seed was not initialized.
    """
    if self._seed_input is None:
      return _TFLegacyRNGHandler(None)

    if self._tf_sanitized_seed is None:
      # Should be caught during init, but included for defensive programming.
      raise RuntimeError("RNGHandler sanitized seed not initialized.")

    new_sanitized_seed = self._tf_sanitized_seed + 1

    # Create a new handler instance, passing the original seed input (to
    # preserve state like `_int_seed`) and injecting the new sanitized seed
    # via the private constructor argument.
    return _TFLegacyRNGHandler(
        self._seed_input, _sanitized_seed=new_sanitized_seed
    )
951
+
952
+
953
# Select the backend-appropriate RNG handler implementation at import time.
if _BACKEND == config.Backend.JAX:
  RNGHandler = _JaxRNGHandler
elif _BACKEND == config.Backend.TENSORFLOW:
  RNGHandler = (
      _TFLegacyRNGHandler  # TODO: Replace with _TFRNGHandler
  )
else:
  # Defensive: config.get_backend() should already have rejected other values.
  raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
961
+
962
+
502
963
  def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor: # type: ignore
503
964
  """Converts input data to the currently active backend tensor type.
504
965