google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -18,6 +18,7 @@ import abc
18
18
  import functools
19
19
  import os
20
20
  from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
21
+ import warnings
21
22
 
22
23
  from meridian.backend import config
23
24
  import numpy as np
@@ -37,6 +38,7 @@ _DEFAULT_INT = "int64"
37
38
  _TENSORFLOW_TILE_KEYWORD = "multiples"
38
39
  _JAX_TILE_KEYWORD = "reps"
39
40
 
41
+ _ARG_AUTOGRAPH = "autograph"
40
42
  _ARG_JIT_COMPILE = "jit_compile"
41
43
  _ARG_STATIC_ARGNUMS = "static_argnums"
42
44
  _ARG_STATIC_ARGNAMES = "static_argnames"
@@ -134,6 +136,54 @@ def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
134
136
 
135
137
 
136
138
  # --- Private Backend-Specific Implementations ---
139
def _jax_stabilize_rf_roi_grid(
    spend_grid: np.ndarray,
    outcome_grid: np.ndarray,
    n_rf_channels: int,
) -> np.ndarray:
  """Stabilizes the RF ROI grid for JAX using a stable index lookup.

  For each of the trailing `n_rf_channels` columns, a reference ROI is taken
  from the last non-NaN row of `spend_grid` (and the same row of
  `outcome_grid`). The column's outcomes are then replaced by
  `roi * spend_grid`. A zero reference spend yields an ROI of 0.

  Args:
    spend_grid: 2-D spend grid; the last `n_rf_channels` columns are treated
      as RF channels. NaN marks grid entries that are not used.
    outcome_grid: 2-D outcome grid aligned with `spend_grid`.
    n_rf_channels: Number of trailing RF columns to stabilize. A value of 0
      (or less) means there are no RF channels; the grid is returned
      unchanged.

  Returns:
    A new outcome grid. The inputs are not modified.
  """
  new_outcome_grid = outcome_grid.copy()
  # `slice(-0, None)` would select *every* column, so an empty RF set has to
  # be handled explicitly.
  if n_rf_channels <= 0:
    return new_outcome_grid

  rf_slice = slice(-n_rf_channels, None)
  rf_spend_grid = spend_grid[:, rf_slice]
  rf_outcome_grid = new_outcome_grid[:, rf_slice]

  # Row index of the last non-NaN spend in each RF column. Searching from the
  # end (instead of counting non-NaN entries) stays correct even when a NaN
  # precedes the last valid entry. All-NaN columns resolve to the last row.
  n_rows = rf_spend_grid.shape[0]
  is_valid_reversed = ~np.isnan(rf_spend_grid[::-1])
  last_valid_indices = n_rows - 1 - np.argmax(is_valid_reversed, axis=0)
  channel_indices = np.arange(rf_spend_grid.shape[1])

  ref_spend = rf_spend_grid[last_valid_indices, channel_indices]
  ref_outcome = rf_outcome_grid[last_valid_indices, channel_indices]

  rf_roi = np.divide(
      ref_outcome,
      ref_spend,
      out=np.zeros_like(ref_outcome, dtype=np.float64),
      where=(ref_spend != 0),
  )

  new_outcome_grid[:, rf_slice] = rf_roi * rf_spend_grid
  return new_outcome_grid
165
+
166
+
167
def _tf_stabilize_rf_roi_grid(
    spend_grid: np.ndarray,
    outcome_grid: np.ndarray,
    n_rf_channels: int,
) -> np.ndarray:
  """Stabilizes the RF ROI grid for TF using nanmax logic.

  For each of the trailing `n_rf_channels` columns, the reference ROI is the
  column's NaN-ignoring maximum outcome divided by its NaN-ignoring maximum
  spend. The column's outcomes are then replaced by `roi * spend_grid`. A zero
  maximum spend yields an ROI of 0.

  Args:
    spend_grid: 2-D spend grid; the last `n_rf_channels` columns are treated
      as RF channels.
    outcome_grid: 2-D outcome grid aligned with `spend_grid`.
    n_rf_channels: Number of trailing RF columns to stabilize. A value of 0
      (or less) means there are no RF channels; the grid is returned
      unchanged.

  Returns:
    A new outcome grid. The inputs are not modified.
  """
  new_outcome_grid = outcome_grid.copy()
  # `slice(-0, None)` would select *every* column and silently rescale all of
  # them, so an empty RF set has to be handled explicitly.
  if n_rf_channels <= 0:
    return new_outcome_grid

  rf_slice = slice(-n_rf_channels, None)
  rf_outcome_max = np.nanmax(new_outcome_grid[:, rf_slice], axis=0)
  rf_spend_max = np.nanmax(spend_grid[:, rf_slice], axis=0)

  rf_roi = np.divide(
      rf_outcome_max,
      rf_spend_max,
      out=np.zeros_like(rf_outcome_max, dtype=np.float64),
      where=(rf_spend_max != 0),
  )

  new_outcome_grid[:, rf_slice] = rf_roi * spend_grid[:, rf_slice]
  return new_outcome_grid
137
187
 
138
188
 
139
189
  def _jax_arange(
@@ -171,7 +221,7 @@ def _tf_arange(
171
221
 
def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
  """Casts `x` to `dtype` as a JAX array (JAX counterpart of `tf.cast`).

  `asarray` also accepts non-array inputs such as Python lists and scalars.
  """
  converted = jax_ops.asarray(x, dtype=dtype)
  return converted
175
225
 
176
226
 
177
227
  def _jax_divide_no_nan(x, y):
@@ -205,6 +255,7 @@ def _jax_function_wrapper(func=None, **kwargs):
205
255
  jit_kwargs = kwargs.copy()
206
256
 
207
257
  jit_kwargs.pop(_ARG_JIT_COMPILE, None)
258
+ jit_kwargs.pop(_ARG_AUTOGRAPH, None)
208
259
 
209
260
  if _ARG_STATIC_ARGNUMS in jit_kwargs:
210
261
  if not jit_kwargs[_ARG_STATIC_ARGNUMS]:
@@ -255,17 +306,132 @@ def _jax_numpy_function(*args, **kwargs): # pylint: disable=unused-argument
255
306
  )
256
307
 
257
308
 
258
def _jax_make_tensor_proto(values, dtype=None, shape=None):
  """JAX implementation for make_tensor_proto.

  Builds a `TensorProto` from array-like `values` without running TensorFlow
  ops. Numeric and boolean data are stored as raw row-major bytes in
  `tensor_content`. String data is stored element-wise in `string_val`, the
  layout TensorFlow uses for `DT_STRING`, so the proto is readable by both
  `tf.make_ndarray` and the JAX `make_ndarray`.

  Args:
    values: Array-like data to serialize.
    dtype: Optional NumPy-compatible dtype to cast `values` to first.
    shape: Optional target shape. Values with exactly that many elements are
      reshaped. Fewer elements (e.g. a scalar) are padded by repeating the last
      value, and no elements yield zeros, matching what TensorFlow materializes
      for such protos.

  Returns:
    A `tensorflow.core.framework.tensor_pb2.TensorProto`.

  Raises:
    TypeError: If the dtype has no TensorProto equivalent.
    ValueError: If `values` has more elements than `shape` can hold.
  """
  # pylint: disable=g-direct-tensorflow-import
  from tensorflow.core.framework import tensor_pb2
  from tensorflow.core.framework import tensor_shape_pb2
  from tensorflow.core.framework import types_pb2
  # pylint: enable=g-direct-tensorflow-import

  if not isinstance(values, np.ndarray):
    values = np.array(values)

  # Compare against None explicitly: `np.dtype` defines `__len__` (number of
  # fields), so a plain dtype object can evaluate as falsy.
  if dtype is not None:
    numpy_dtype = np.dtype(dtype)
    values = values.astype(numpy_dtype)
  else:
    numpy_dtype = values.dtype

  if shape is not None:
    target_shape = tuple(int(d) for d in shape)
    # np.prod(()) == 1.0, so a scalar shape counts as one element.
    target_size = int(np.prod(target_shape))
    if values.size > target_size:
      raise ValueError(
          f"Too many elements provided. Takes at most {target_size:d}, but"
          f" got {values.size:d}."
      )
    if values.size == target_size:
      values = values.reshape(target_shape)
    elif values.size == 0:
      values = np.zeros(target_shape, dtype=values.dtype)
    else:
      flat = values.ravel()
      padding = np.repeat(flat[-1:], target_size - flat.size)
      values = np.concatenate([flat, padding]).reshape(target_shape)

  dtype_map = {
      np.dtype(np.float16): types_pb2.DT_HALF,
      np.dtype(np.float32): types_pb2.DT_FLOAT,
      np.dtype(np.float64): types_pb2.DT_DOUBLE,
      np.dtype(np.int32): types_pb2.DT_INT32,
      np.dtype(np.uint8): types_pb2.DT_UINT8,
      np.dtype(np.uint16): types_pb2.DT_UINT16,
      np.dtype(np.uint32): types_pb2.DT_UINT32,
      np.dtype(np.uint64): types_pb2.DT_UINT64,
      np.dtype(np.int16): types_pb2.DT_INT16,
      np.dtype(np.int8): types_pb2.DT_INT8,
      np.dtype(np.int64): types_pb2.DT_INT64,
      np.dtype(np.complex64): types_pb2.DT_COMPLEX64,
      np.dtype(np.complex128): types_pb2.DT_COMPLEX128,
      np.dtype(np.bool_): types_pb2.DT_BOOL,
      # Note: String types are handled outside the map.
  }
  proto_dtype = dtype_map.get(numpy_dtype)
  if proto_dtype is None and numpy_dtype.kind in ("S", "U"):
    proto_dtype = types_pb2.DT_STRING

  if proto_dtype is None:
    raise TypeError(
        f"Unsupported dtype for TensorProto conversion: {numpy_dtype}"
    )

  proto = tensor_pb2.TensorProto(
      dtype=proto_dtype,
      tensor_shape=tensor_shape_pb2.TensorShapeProto(
          dim=[
              tensor_shape_pb2.TensorShapeProto.Dim(size=d)
              for d in values.shape
          ]
      ),
  )

  if proto_dtype == types_pb2.DT_STRING:
    # Raw `tobytes()` of a NumPy string array is not TF-readable, and for
    # unicode ("U") arrays it is UCS-4 encoded, so it cannot round-trip.
    # Store one UTF-8 / raw bytes entry per element instead.
    proto.string_val.extend(
        item.encode("utf-8") if isinstance(item, str) else bytes(item)
        for item in values.ravel().tolist()
    )
  else:
    proto.tensor_content = values.tobytes()
  return proto
364
+
365
+
366
def _jax_make_ndarray(proto):
  """JAX implementation for make_ndarray.

  Converts a `TensorProto` back into a NumPy array, following the rules of
  `tf.make_ndarray`:

  * If `tensor_content` is set, it is read as raw row-major bytes.
  * Otherwise values come from the dtype's typed `*_val` field. TensorFlow may
    store fewer values than elements (e.g. one value for a constant tensor);
    the remainder is filled by repeating the last value, and an empty field
    yields zeros (empty bytes for strings).

  Args:
    proto: A `tensorflow.core.framework.tensor_pb2.TensorProto`.

  Returns:
    A NumPy array with the proto's shape. String tensors come back as a NumPy
    bytes array.

  Raises:
    TypeError: If the proto's dtype is not supported.
    ValueError: If the stored data does not fit the declared shape.
  """
  # pylint: disable=g-direct-tensorflow-import
  from tensorflow.core.framework import types_pb2
  # pylint: enable=g-direct-tensorflow-import

  dtype_map = {
      types_pb2.DT_HALF: np.float16,
      types_pb2.DT_FLOAT: np.float32,
      types_pb2.DT_DOUBLE: np.float64,
      types_pb2.DT_INT32: np.int32,
      types_pb2.DT_UINT8: np.uint8,
      types_pb2.DT_UINT16: np.uint16,
      types_pb2.DT_UINT32: np.uint32,
      types_pb2.DT_UINT64: np.uint64,
      types_pb2.DT_INT16: np.int16,
      types_pb2.DT_INT8: np.int8,
      types_pb2.DT_INT64: np.int64,
      types_pb2.DT_COMPLEX64: np.complex64,
      types_pb2.DT_COMPLEX128: np.complex128,
      types_pb2.DT_BOOL: np.bool_,
      types_pb2.DT_STRING: np.bytes_,
  }
  if proto.dtype not in dtype_map:
    raise TypeError(f"Unsupported TensorProto dtype: {proto.dtype}")

  shape = [d.size for d in proto.tensor_shape.dim]
  dtype = dtype_map[proto.dtype]
  # np.prod([]) == 1.0, so a scalar correctly counts as one element.
  num_elements = int(np.prod(shape))

  if proto.tensor_content:
    # When deserializing a string from tensor_content, the itemsize is not
    # explicitly stored. We must infer it from the content length and shape.
    if dtype == np.bytes_ and num_elements > 0:
      content_len = len(proto.tensor_content)
      itemsize = content_len // num_elements
      if itemsize * num_elements != content_len:
        raise ValueError(
            "Tensor content size is not a multiple of the number of elements"
            " for string dtype."
        )
      dtype = np.dtype(f"S{itemsize}")

    return (
        np.frombuffer(proto.tensor_content, dtype=dtype).copy().reshape(shape)
    )

  # Fallback for protos that store data in val fields instead of
  # tensor_content (TensorFlow does this e.g. for scalars and strings).
  if proto.dtype == types_pb2.DT_STRING:
    values = np.array(list(proto.string_val), dtype=np.bytes_)
  elif proto.dtype == types_pb2.DT_HALF:
    # `half_val` holds the raw IEEE-754 half-precision bit patterns.
    values = np.array(list(proto.half_val), dtype=np.uint16).view(np.float16)
  elif proto.dtype == types_pb2.DT_COMPLEX64:
    # Complex values are stored as interleaved (real, imag) pairs.
    values = np.array(list(proto.scomplex_val), dtype=np.float32).view(
        np.complex64
    )
  elif proto.dtype == types_pb2.DT_COMPLEX128:
    values = np.array(list(proto.dcomplex_val), dtype=np.float64).view(
        np.complex128
    )
  else:
    val_field_names = {
        types_pb2.DT_FLOAT: "float_val",
        types_pb2.DT_DOUBLE: "double_val",
        types_pb2.DT_INT8: "int_val",
        types_pb2.DT_INT16: "int_val",
        types_pb2.DT_INT32: "int_val",
        types_pb2.DT_UINT8: "int_val",
        types_pb2.DT_UINT16: "int_val",
        types_pb2.DT_UINT32: "uint32_val",
        types_pb2.DT_UINT64: "uint64_val",
        types_pb2.DT_INT64: "int64_val",
        types_pb2.DT_BOOL: "bool_val",
    }
    values = np.array(
        list(getattr(proto, val_field_names[proto.dtype])), dtype=dtype
    )

  if values.size == 0:
    return np.zeros(shape, dtype=values.dtype)
  if values.size < num_elements:
    # Compressed representation: repeat the last stored value.
    padding = np.repeat(values[-1:], num_elements - values.size)
    values = np.concatenate([values, padding])
  return values.reshape(shape)
434
+
269
435
 
270
436
  def _jax_get_indices_where(condition):
271
437
  """JAX implementation for get_indices_where."""
@@ -324,6 +490,7 @@ def _jax_boolean_mask(tensor, mask, axis=None):
324
490
 
325
491
  if axis is None:
326
492
  axis = 0
493
+ mask = jnp.asarray(mask)
327
494
  tensor_swapped = jnp.moveaxis(tensor, axis, 0)
328
495
  masked = tensor_swapped[mask]
329
496
  return jnp.moveaxis(masked, 0, axis)
@@ -336,17 +503,26 @@ def _tf_boolean_mask(tensor, mask, axis=None):
336
503
  return tf.boolean_mask(tensor, mask, axis=axis)
337
504
 
338
505
 
339
def _jax_gather(params, indices, axis=0):
  """Gathers slices of `params` at `indices` along `axis` (JAX backend).

  Python lists/tuples are converted to NumPy arrays first. JAX cannot
  JIT-compile operations on string or object arrays, so those are gathered
  with NumPy instead of `jnp.take`.
  """
  import jax.numpy as jnp

  source = np.array(params) if isinstance(params, (list, tuple)) else params

  needs_numpy = isinstance(source, np.ndarray) and source.dtype.kind in (
      "S",
      "U",
      "O",
  )
  if needs_numpy:
    return np.take(source, np.asarray(indices), axis=axis)
  return jnp.take(source, indices, axis=axis)
343
519
 
344
520
 
345
def _tf_gather(params, indices, axis=0):
  """Gathers slices of `params` at `indices` along `axis` via `tf.gather`."""
  import tensorflow

  return tensorflow.gather(params, indices, axis=axis)
350
526
 
351
527
 
352
528
  def _jax_fill(dims, value):
@@ -417,6 +593,127 @@ def _jax_transpose(a, perm=None):
417
593
  return jnp.transpose(a, axes=perm)
418
594
 
419
595
 
596
def _jax_get_seed_data(seed: Any) -> Optional[np.ndarray]:
  """Returns the raw numeric contents of a JAX PRNG key, or None for no seed."""
  return None if seed is None else np.array(jax.random.key_data(seed))
602
+
603
+
604
def _tf_get_seed_data(seed: Any) -> Optional[np.ndarray]:
  """Returns a TensorFlow-style seed as a NumPy array, or None for no seed."""
  if seed is not None:
    return np.array(seed)
  return None
609
+
610
+
611
def _jax_convert_to_tensor(data, dtype=None):
  """Converts data to a JAX array, handling strings as NumPy arrays.

  JAX has no string tensors, so string data (or a string target `dtype`) is
  returned as a plain NumPy array. Everything else goes through
  `jax.numpy.asarray`.

  Args:
    data: Array-like input: a NumPy/JAX array, list, tuple, scalar, or a
      Python `str`/`bytes` value.
    dtype: Optional target dtype.

  Returns:
    A NumPy array for string data or string `dtype`, otherwise a JAX array.
  """
  is_string_target = False
  if dtype is not None:
    try:
      is_string_target = np.dtype(dtype).kind in ("S", "U")
    except TypeError:
      # This can happen if dtype is not a valid dtype specifier,
      # let jax.asarray handle it.
      pass

  data_kind = None
  if isinstance(data, np.ndarray):
    # Read the dtype directly; converting would copy the whole array.
    data_kind = data.dtype.kind
  elif isinstance(data, (list, tuple, str, bytes)):
    try:
      data_kind = np.asarray(data).dtype.kind
    except (TypeError, ValueError):
      # E.g. a list of JAX tracers under `jax.jit` cannot become a NumPy
      # array; such data is never string data, so defer to `asarray`.
      data_kind = None
  is_string_data = data_kind in ("S", "U")

  if is_string_target or (dtype is None and is_string_data):
    return np.array(data, dtype=dtype)

  return jax_ops.asarray(data, dtype=dtype)
634
+
635
+
636
def _tf_nanmean(a, axis=None, keepdims=False):
  """NaN-ignoring mean of `a` along `axis`, returned as a TF tensor."""
  import tensorflow.experimental.numpy as tnp

  mean_ignoring_nan = tnp.nanmean(a, axis=axis, keepdims=keepdims)
  return tf_backend.convert_to_tensor(mean_ignoring_nan)
642
+
643
+
644
def _tf_nansum(a, axis=None, keepdims=False):
  """NaN-ignoring sum of `a` along `axis`, returned as a TF tensor."""
  import tensorflow.experimental.numpy as tnp

  sum_ignoring_nan = tnp.nansum(a, axis=axis, keepdims=keepdims)
  return tf_backend.convert_to_tensor(sum_ignoring_nan)
650
+
651
+
652
def _tf_nanvar(a, axis=None, keepdims=False):
  """Calculates variance ignoring NaNs, strictly returning a Tensor.

  Uses a two-pass computation: the NaN-ignoring mean of the squared deviations
  from the NaN-ignoring mean. Every step is a TF op, so the result stays in the
  graph and remains differentiable.
  """
  import tensorflow as tf
  import tensorflow.experimental.numpy as tnp

  values = tf.convert_to_tensor(a)
  # keepdims=True so the mean broadcasts back against `values`.
  centre = tnp.nanmean(values, axis=axis, keepdims=True)
  # NaN inputs stay NaN here and are skipped by the second nanmean.
  squared_deviation = tf.math.squared_difference(values, centre)
  variance = tnp.nanmean(squared_deviation, axis=axis, keepdims=keepdims)
  return tf.convert_to_tensor(variance)
663
+
664
+
665
def _jax_one_hot(
    indices, depth, on_value=None, off_value=None, axis=None, dtype=None
):
  """JAX implementation for one_hot (mirrors `tf.one_hot`).

  Args:
    indices: Integer class indices.
    depth: Number of classes.
    on_value: Value for the "hot" position; defaults to 1.
    off_value: Value for all other positions; defaults to 0.
    axis: Axis along which the one-hot dimension is inserted; defaults to the
      last axis.
    dtype: Output dtype; resolved from `on_value`/`off_value` when omitted.

  Returns:
    A JAX array of one-hot encodings.
  """
  import jax.numpy as jnp

  resolved_dtype = jnp.dtype(_resolve_dtype(dtype, on_value, off_value, 1, 0))
  jax_axis = -1 if axis is None else axis

  one_hot_result = jax.nn.one_hot(
      indices, num_classes=depth, dtype=resolved_dtype, axis=jax_axis
  )

  # Decide statically. Comparing `on_value`/`off_value` against 1/0 would call
  # `bool()` on arrays, which fails for traced values under `jax.jit` and for
  # non-scalar arrays.
  if on_value is None and off_value is None:
    return one_hot_result

  on_val = 1 if on_value is None else on_value
  off_val = 0 if off_value is None else off_value
  on_tensor = jnp.array(on_val, dtype=resolved_dtype)
  off_tensor = jnp.array(off_val, dtype=resolved_dtype)

  return jnp.where(one_hot_result == 1, on_tensor, off_tensor)
688
+
689
+
690
def _jax_roll(a, shift, axis=None):
  """Rolls `a` by `shift` along `axis` using `numpy.roll` semantics (JAX)."""
  from jax import numpy as jnp

  return jnp.roll(a, shift, axis=axis)
695
+
696
+
697
def _tf_roll(a, shift: Union[int, Sequence[int]], axis=None):
  """TensorFlow implementation for roll that handles axis=None.

  Mirrors `numpy.roll`. With `axis=None` the tensor is flattened, rolled and
  reshaped back. In that case several shifts are applied to the single flat
  axis, which is the same as one shift by their sum.

  Args:
    a: Input tensor.
    shift: Number of places to shift; an int or a sequence of ints.
    axis: Axis or axes to roll along. `None` rolls the flattened tensor.

  Returns:
    The rolled tensor, with the same shape as `a`.
  """
  import tensorflow as tf

  if axis is None:
    original_shape = tf.shape(a)
    flat_tensor = tf.reshape(a, [-1])
    # `tf.roll` requires `shift` and `axis` of equal length, so a sequence of
    # shifts is collapsed onto the single flattened axis.
    total_shift = tf.reduce_sum(shift)
    rolled_flat = tf.roll(flat_tensor, shift=total_shift, axis=0)
    return tf.reshape(rolled_flat, original_shape)
  return tf.roll(a, shift, axis=axis)
707
+
708
+
709
def _jax_enable_op_determinism():
  """No-op for JAX. Determinism is handled via stateless PRNGKeys.

  Emits a `UserWarning`, attributed to the caller's line via `stacklevel=2`,
  explaining that the call has no effect under the JAX backend.
  """
  warnings.warn(
      "op determinism is a TensorFlow-specific concept and has no effect when"
      " using the JAX backend.",
      stacklevel=2,
  )
715
+
716
+
420
717
  # --- Backend Initialization ---
421
718
  _BACKEND = config.get_backend()
422
719
 
@@ -429,15 +726,67 @@ if _BACKEND == config.Backend.JAX:
429
726
  import jax
430
727
  import jax.numpy as jax_ops
431
728
  import tensorflow_probability.substrates.jax as tfp_jax
729
+ from jax import tree_util
432
730
 
433
731
  class ExtensionType:
434
- """A JAX-compatible stand-in for tf.experimental.ExtensionType."""
732
+ """A JAX-compatible stand-in for tf.experimental.ExtensionType.
435
733
 
436
- def __init__(self, *args: Any, **kwargs: Any) -> None:
437
- raise NotImplementedError(
438
- "ExtensionType is not yet implemented for the JAX backend."
734
+ This class registers itself as a JAX Pytree node, allowing it to be passed
735
+ through JIT-compiled functions.
736
+ """
737
+
738
+ def __init_subclass__(cls, **kwargs):
739
+ super().__init_subclass__(**kwargs)
740
+ tree_util.register_pytree_node(
741
+ cls,
742
+ cls._tree_flatten,
743
+ cls._tree_unflatten,
439
744
  )
440
745
 
746
+ def _tree_flatten(self):
747
+ """Flattens the object for JAX tracing.
748
+
749
+ Fields containing JAX arrays (or convertibles) are treated as children.
750
+ Fields containing strings or NumPy string arrays must be treated as
751
+ auxiliary data because JAX cannot trace non-numeric/non-boolean data.
752
+
753
+ Returns:
754
+ A tuple of (children, aux_data), where children are the
755
+ tracer-compatible parts of the object, and aux_data contains auxiliary
756
+ information needed for unflattening.
757
+ """
758
+ d = vars(self)
759
+ all_keys = sorted(d.keys())
760
+ children = []
761
+ aux = {}
762
+ children_keys = []
763
+
764
+ for k in all_keys:
765
+ v = d[k]
766
+ # Identify string data to prevent JAX tracing errors.
767
+ # 'S' is zero-terminated bytes (fixed-width), 'U' is unicode string.
768
+ is_numpy_string = isinstance(v, np.ndarray) and v.dtype.kind in (
769
+ "S",
770
+ "U",
771
+ )
772
+ is_plain_string = isinstance(v, str)
773
+
774
+ if is_numpy_string or is_plain_string:
775
+ aux[k] = v
776
+ else:
777
+ children.append(v)
778
+ children_keys.append(k)
779
+ return children, (aux, children_keys)
780
+
781
+ @classmethod
782
+ def _tree_unflatten(cls, aux_and_keys, children):
783
+ aux, children_keys = aux_and_keys
784
+ obj = cls.__new__(cls)
785
+ vars(obj).update(aux)
786
+ for k, v in zip(children_keys, children):
787
+ setattr(obj, k, v)
788
+ return obj
789
+
441
790
  class _JaxErrors:
442
791
  # pylint: disable=invalid-name
443
792
  ResourceExhaustedError = MemoryError
@@ -492,6 +841,26 @@ if _BACKEND == config.Backend.JAX:
492
841
 
493
842
  random = _JaxRandom()
494
843
 
844
+ @functools.partial(
845
+ jax.jit,
846
+ static_argnames=[
847
+ "joint_dist",
848
+ "n_chains",
849
+ "n_draws",
850
+ "num_adaptation_steps",
851
+ "dual_averaging_kwargs",
852
+ "max_tree_depth",
853
+ "unrolled_leapfrog_steps",
854
+ "parallel_iterations",
855
+ ],
856
+ )
857
+ def _jax_xla_windowed_adaptive_nuts(**kwargs):
858
+ """JAX-specific JIT wrapper for the NUTS sampler."""
859
+ kwargs["seed"] = random.prng_key(kwargs["seed"])
860
+ return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
861
+
862
+ xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
863
+
495
864
  _ops = jax_ops
496
865
  errors = _JaxErrors()
497
866
  Tensor = jax.Array
@@ -499,7 +868,7 @@ if _BACKEND == config.Backend.JAX:
499
868
  bijectors = tfp_jax.bijectors
500
869
  experimental = tfp_jax.experimental
501
870
  mcmc = tfp_jax.mcmc
502
- _convert_to_tensor = _ops.asarray
871
+ _convert_to_tensor = _jax_convert_to_tensor
503
872
 
504
873
  # Standardized Public API
505
874
  absolute = _ops.abs
@@ -507,18 +876,6 @@ if _BACKEND == config.Backend.JAX:
507
876
  arange = _jax_arange
508
877
  argmax = _jax_argmax
509
878
  boolean_mask = _jax_boolean_mask
510
- concatenate = _ops.concatenate
511
- stack = _ops.stack
512
- split = _jax_split
513
- zeros = _ops.zeros
514
- zeros_like = _ops.zeros_like
515
- ones = _ops.ones
516
- ones_like = _ops.ones_like
517
- repeat = _ops.repeat
518
- reshape = _ops.reshape
519
- tile = _ops.tile
520
- where = _ops.where
521
- broadcast_to = _ops.broadcast_to
522
879
  broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
523
880
  broadcast_to = _ops.broadcast_to
524
881
  cast = _jax_cast
@@ -527,6 +884,7 @@ if _BACKEND == config.Backend.JAX:
527
884
  divide = _ops.divide
528
885
  divide_no_nan = _jax_divide_no_nan
529
886
  einsum = _ops.einsum
887
+ enable_op_determinism = _jax_enable_op_determinism
530
888
  equal = _ops.equal
531
889
  exp = _ops.exp
532
890
  expand_dims = _ops.expand_dims
@@ -534,12 +892,17 @@ if _BACKEND == config.Backend.JAX:
534
892
  function = _jax_function_wrapper
535
893
  gather = _jax_gather
536
894
  get_indices_where = _jax_get_indices_where
895
+ get_seed_data = _jax_get_seed_data
537
896
  is_nan = _ops.isnan
538
897
  log = _ops.log
539
898
  make_ndarray = _jax_make_ndarray
540
899
  make_tensor_proto = _jax_make_tensor_proto
900
+ nanmean = jax_ops.nanmean
541
901
  nanmedian = _jax_nanmedian
902
+ nansum = jax_ops.nansum
903
+ nanvar = jax_ops.nanvar
542
904
  numpy_function = _jax_numpy_function
905
+ one_hot = _jax_one_hot
543
906
  ones = _ops.ones
544
907
  ones_like = _ops.ones_like
545
908
  rank = _ops.ndim
@@ -551,6 +914,8 @@ if _BACKEND == config.Backend.JAX:
551
914
  reduce_sum = _ops.sum
552
915
  repeat = _ops.repeat
553
916
  reshape = _ops.reshape
917
+ roll = _jax_roll
918
+ split = _jax_split
554
919
  stack = _ops.stack
555
920
  tile = _jax_tile
556
921
  transpose = _jax_transpose
@@ -564,13 +929,14 @@ if _BACKEND == config.Backend.JAX:
564
929
  newaxis = _ops.newaxis
565
930
  TensorShape = _jax_tensor_shape
566
931
  int32 = _ops.int32
932
+ string = np.bytes_
933
+
934
+ stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid
567
935
 
568
936
  def set_random_seed(seed: int) -> None: # pylint: disable=unused-argument
569
- raise NotImplementedError(
570
- "JAX does not support a global, stateful random seed. `set_random_seed`"
571
- " is not implemented. Instead, you must pass an explicit `seed`"
572
- " integer directly to the sampling methods (e.g., `sample_prior`),"
573
- " which will be used to create a JAX PRNGKey internally."
937
+ warnings.warn(
938
+ "backend.set_random_seed is a no-op in JAX. Randomness is managed "
939
+ "statelessly via PRNGKeys passed to sampling methods."
574
940
  )
575
941
 
576
942
  elif _BACKEND == config.Backend.TENSORFLOW:
@@ -638,29 +1004,25 @@ elif _BACKEND == config.Backend.TENSORFLOW:
638
1004
 
639
1005
  random = _TfRandom()
640
1006
 
1007
+ @tf_backend.function(autograph=False, jit_compile=True)
1008
+ def _tf_xla_windowed_adaptive_nuts(**kwargs):
1009
+ """TensorFlow-specific XLA wrapper for the NUTS sampler."""
1010
+ return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
1011
+
1012
+ xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
1013
+
641
1014
  tfd = tfp.distributions
642
1015
  bijectors = tfp.bijectors
643
1016
  experimental = tfp.experimental
644
1017
  mcmc = tfp.mcmc
645
-
646
1018
  _convert_to_tensor = _ops.convert_to_tensor
1019
+
1020
+ # Standardized Public API
647
1021
  absolute = _ops.math.abs
648
1022
  allclose = _ops.experimental.numpy.allclose
649
1023
  arange = _tf_arange
650
1024
  argmax = _tf_argmax
651
1025
  boolean_mask = _tf_boolean_mask
652
- concatenate = _ops.concat
653
- stack = _ops.stack
654
- split = _ops.split
655
- zeros = _ops.zeros
656
- zeros_like = _ops.zeros_like
657
- ones = _ops.ones
658
- ones_like = _ops.ones_like
659
- repeat = _ops.repeat
660
- reshape = _ops.reshape
661
- tile = _ops.tile
662
- where = _ops.where
663
- broadcast_to = _ops.broadcast_to
664
1026
  broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
665
1027
  broadcast_to = _ops.broadcast_to
666
1028
  cast = _ops.cast
@@ -669,6 +1031,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
669
1031
  divide = _ops.divide
670
1032
  divide_no_nan = _ops.math.divide_no_nan
671
1033
  einsum = _ops.einsum
1034
+ enable_op_determinism = _ops.config.experimental.enable_op_determinism
672
1035
  equal = _ops.equal
673
1036
  exp = _ops.math.exp
674
1037
  expand_dims = _ops.expand_dims
@@ -676,12 +1039,17 @@ elif _BACKEND == config.Backend.TENSORFLOW:
676
1039
  function = _tf_function_wrapper
677
1040
  gather = _tf_gather
678
1041
  get_indices_where = _tf_get_indices_where
1042
+ get_seed_data = _tf_get_seed_data
679
1043
  is_nan = _ops.math.is_nan
680
1044
  log = _ops.math.log
681
1045
  make_ndarray = _ops.make_ndarray
682
1046
  make_tensor_proto = _ops.make_tensor_proto
1047
+ nanmean = _tf_nanmean
683
1048
  nanmedian = _tf_nanmedian
1049
+ nansum = _tf_nansum
1050
+ nanvar = _tf_nanvar
684
1051
  numpy_function = _ops.numpy_function
1052
+ one_hot = _ops.one_hot
685
1053
  ones = _ops.ones
686
1054
  ones_like = _ops.ones_like
687
1055
  rank = _ops.rank
@@ -693,7 +1061,9 @@ elif _BACKEND == config.Backend.TENSORFLOW:
693
1061
  reduce_sum = _ops.reduce_sum
694
1062
  repeat = _ops.repeat
695
1063
  reshape = _ops.reshape
1064
+ roll = _tf_roll
696
1065
  set_random_seed = tf_backend.keras.utils.set_random_seed
1066
+ split = _ops.split
697
1067
  stack = _ops.stack
698
1068
  tile = _ops.tile
699
1069
  transpose = _ops.transpose
@@ -702,11 +1072,14 @@ elif _BACKEND == config.Backend.TENSORFLOW:
702
1072
  zeros = _ops.zeros
703
1073
  zeros_like = _ops.zeros_like
704
1074
 
1075
+ stabilize_rf_roi_grid = _tf_stabilize_rf_roi_grid
1076
+
705
1077
  float32 = _ops.float32
706
1078
  bool_ = _ops.bool
707
1079
  newaxis = _ops.newaxis
708
1080
  TensorShape = _ops.TensorShape
709
1081
  int32 = _ops.int32
1082
+ string = _ops.string
710
1083
 
711
1084
  else:
712
1085
  raise ValueError(f"Unsupported backend: {_BACKEND}")
@@ -735,7 +1108,7 @@ def _extract_int_seed(s: Any) -> Optional[int]:
735
1108
  # in unpredictable ways. The goal is to safely attempt extraction and fail
736
1109
  # gracefully.
737
1110
  except Exception: # pylint: disable=broad-except
738
- pass
1111
+ return None
739
1112
  return None
740
1113
 
741
1114
 
@@ -819,6 +1192,14 @@ class _JaxRNGHandler(_BaseRNGHandler):
819
1192
  if seed is None:
820
1193
  return
821
1194
 
1195
+ if (
1196
+ isinstance(seed, jax.Array) # pylint: disable=undefined-variable
1197
+ and seed.shape == (2,)
1198
+ and seed.dtype == jax_ops.uint32 # pylint: disable=undefined-variable
1199
+ ):
1200
+ self._key = seed
1201
+ return
1202
+
822
1203
  if self._int_seed is None:
823
1204
  raise ValueError(
824
1205
  "JAX backend requires a seed that is an integer or a scalar array,"
@@ -830,7 +1211,8 @@ class _JaxRNGHandler(_BaseRNGHandler):
830
1211
  def get_kernel_seed(self) -> Any:
831
1212
  if self._key is None:
832
1213
  return None
833
- return random.sanitize_seed(self._key)
1214
+ _, subkey = random.split(self._key)
1215
+ return int(jax.random.randint(subkey, (), 0, 2**31 - 1)) # pylint: disable=undefined-variable
834
1216
 
835
1217
  def get_next_seed(self) -> Any:
836
1218
  if self._key is None:
@@ -853,35 +1235,22 @@ class _JaxRNGHandler(_BaseRNGHandler):
853
1235
  return _JaxRNGHandler(np.asarray(new_seed_tensor).item())
854
1236
 
855
1237
 
856
- # TODO: Replace with _TFRNGHandler
857
- class _TFLegacyRNGHandler(_BaseRNGHandler):
858
- """TensorFlow implementation.
1238
+ class _TFRNGHandler(_BaseRNGHandler):
1239
+ """A stateless-style RNG handler for TensorFlow.
859
1240
 
860
- TODO: This class should be removed and replaced with a correct,
861
- stateful `tf.random.Generator`-based implementation.
1241
+ This handler canonicalizes any seed input into a single stateless seed Tensor.
862
1242
  """
863
1243
 
864
- def __init__(self, seed: SeedType, *, _sanitized_seed: Optional[Any] = None):
865
- """Initializes the TensorFlow legacy RNG handler.
1244
+ def __init__(self, seed: SeedType):
1245
+ """Initializes the TensorFlow RNG handler.
866
1246
 
867
1247
  Args:
868
1248
  seed: The initial seed. Can be an integer, a sequence of two integers, a
869
- corresponding Tensor, or None.
870
- _sanitized_seed: For internal use only. If provided, this pre-computed
871
- seed tensor is used directly, and the standard initialization logic for
872
- the public `seed` argument is bypassed.
873
-
874
- Raises:
875
- ValueError: If `seed` is a sequence with a length other than 2.
1249
+ corresponding Tensor, or None. It will be sanitized and stored
1250
+ internally as a single stateless seed Tensor.
876
1251
  """
877
1252
  super().__init__(seed)
878
- self._tf_sanitized_seed: Optional[Any] = None
879
-
880
- if _sanitized_seed is not None:
881
- # Internal path: A pre-sanitized seed was provided by a trusted source
882
- # so we adopt it directly.
883
- self._tf_sanitized_seed = _sanitized_seed
884
- return
1253
+ self._seed_state: Optional["_tf.Tensor"] = None
885
1254
 
886
1255
  if seed is None:
887
1256
  return
@@ -899,63 +1268,33 @@ class _TFLegacyRNGHandler(_BaseRNGHandler):
899
1268
  else:
900
1269
  seed_to_sanitize = seed
901
1270
 
902
- self._tf_sanitized_seed = random.sanitize_seed(seed_to_sanitize)
1271
+ self._seed_state = random.sanitize_seed(seed_to_sanitize)
903
1272
 
904
1273
  def get_kernel_seed(self) -> Any:
905
- return self._tf_sanitized_seed
1274
+ """Provides the current stateless seed of the handler without advancing it."""
1275
+ return self._seed_state
906
1276
 
907
1277
  def get_next_seed(self) -> Any:
908
- """Returns the original integer seed to preserve prior sampling behavior.
909
-
910
- Returns:
911
- The original integer seed provided during initialization.
912
-
913
- Raises:
914
- RuntimeError: If the handler was not initialized with a scalar integer
915
- seed, which is required for the legacy prior sampling path.
916
- """
917
- if self._seed_input is None:
1278
+ """Returns a new unique seed and advances the handler's internal state."""
1279
+ if self._seed_state is None:
918
1280
  return None
1281
+ new_seed, next_state = random.stateless_split(self._seed_state)
1282
+ self._seed_state = next_state
1283
+ return new_seed
919
1284
 
920
- if self._int_seed is None:
921
- raise RuntimeError(
922
- "RNGHandler was not initialized with a scalar integer seed, cannot"
923
- " provide seed for TensorFlow prior sampling."
924
- )
925
- return self._int_seed
926
-
927
- def advance_handler(self) -> "_TFLegacyRNGHandler":
928
- """Creates a new handler by incrementing the sanitized seed by 1.
929
-
930
- Returns:
931
- A new `_TFLegacyRNGHandler` instance with an incremented seed state.
932
-
933
- Raises:
934
- RuntimeError: If the handler's sanitized seed was not initialized.
935
- """
936
- if self._seed_input is None:
937
- return _TFLegacyRNGHandler(None)
938
-
939
- if self._tf_sanitized_seed is None:
940
- # Should be caught during init, but included for defensive programming.
941
- raise RuntimeError("RNGHandler sanitized seed not initialized.")
942
-
943
- new_sanitized_seed = self._tf_sanitized_seed + 1
944
-
945
- # Create a new handler instance, passing the original seed input (to
946
- # preserve state like `_int_seed`) and injecting the new sanitized seed
947
- # via the private constructor argument.
948
- return _TFLegacyRNGHandler(
949
- self._seed_input, _sanitized_seed=new_sanitized_seed
950
- )
1285
+ def advance_handler(self) -> "_TFRNGHandler":
1286
+ """Creates a new handler from a new seed and advances the current handler."""
1287
+ if self._seed_state is None:
1288
+ return _TFRNGHandler(None)
1289
+ seed_for_new_handler, next_state = random.stateless_split(self._seed_state)
1290
+ self._seed_state = next_state
1291
+ return _TFRNGHandler(seed_for_new_handler)
951
1292
 
952
1293
 
if _BACKEND not in (config.Backend.JAX, config.Backend.TENSORFLOW):
  raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
# Select the RNG handler implementation that matches the active backend.
RNGHandler = (
    _JaxRNGHandler if _BACKEND == config.Backend.JAX else _TFRNGHandler
)
961
1300