google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (45)
  1. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
  2. google_meridian-1.3.0.dist-info/RECORD +62 -0
  3. meridian/analysis/__init__.py +2 -0
  4. meridian/analysis/analyzer.py +280 -142
  5. meridian/analysis/formatter.py +2 -2
  6. meridian/analysis/optimizer.py +353 -169
  7. meridian/analysis/review/__init__.py +20 -0
  8. meridian/analysis/review/checks.py +721 -0
  9. meridian/analysis/review/configs.py +110 -0
  10. meridian/analysis/review/constants.py +40 -0
  11. meridian/analysis/review/results.py +544 -0
  12. meridian/analysis/review/reviewer.py +186 -0
  13. meridian/analysis/summarizer.py +14 -12
  14. meridian/analysis/templates/chips.html.jinja +12 -0
  15. meridian/analysis/test_utils.py +27 -5
  16. meridian/analysis/visualizer.py +45 -50
  17. meridian/backend/__init__.py +698 -55
  18. meridian/backend/config.py +75 -16
  19. meridian/backend/test_utils.py +127 -1
  20. meridian/constants.py +52 -11
  21. meridian/data/input_data.py +7 -2
  22. meridian/data/test_utils.py +5 -3
  23. meridian/mlflow/autolog.py +2 -2
  24. meridian/model/__init__.py +1 -0
  25. meridian/model/adstock_hill.py +10 -9
  26. meridian/model/eda/__init__.py +3 -0
  27. meridian/model/eda/constants.py +21 -0
  28. meridian/model/eda/eda_engine.py +1580 -84
  29. meridian/model/eda/eda_outcome.py +200 -0
  30. meridian/model/eda/eda_spec.py +84 -0
  31. meridian/model/eda/meridian_eda.py +220 -0
  32. meridian/model/knots.py +56 -50
  33. meridian/model/media.py +10 -8
  34. meridian/model/model.py +79 -16
  35. meridian/model/model_test_data.py +53 -9
  36. meridian/model/posterior_sampler.py +398 -391
  37. meridian/model/prior_distribution.py +114 -39
  38. meridian/model/prior_sampler.py +146 -90
  39. meridian/model/spec.py +7 -8
  40. meridian/model/transformers.py +16 -8
  41. meridian/version.py +1 -1
  42. google_meridian-1.2.0.dist-info/RECORD +0 -52
  43. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
  44. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
  45. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
@@ -14,13 +14,16 @@
 
 """Backend Abstraction Layer for Meridian."""
 
+import abc
+import functools
 import os
-from typing import Any, Optional, TYPE_CHECKING, Tuple, Union
-
+from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+import warnings
 from meridian.backend import config
 import numpy as np
 from typing_extensions import Literal
 
+
 # The conditional imports in this module are a deliberate design choice for the
 # backend abstraction layer. The TFP-on-JAX substrate provides a nearly
 # identical API to the standard TFP library, making an alias-based approach more
@@ -28,6 +31,20 @@ from typing_extensions import Literal
 # extensive boilerplate.
 # pylint: disable=g-import-not-at-top,g-bad-import-order
 
+_DEFAULT_FLOAT = "float32"
+_DEFAULT_INT = "int64"
+
+_TENSORFLOW_TILE_KEYWORD = "multiples"
+_JAX_TILE_KEYWORD = "reps"
+
+_ARG_AUTOGRAPH = "autograph"
+_ARG_JIT_COMPILE = "jit_compile"
+_ARG_STATIC_ARGNUMS = "static_argnums"
+_ARG_STATIC_ARGNAMES = "static_argnames"
+
+_DEFAULT_SEED_DTYPE = "int32"
+_MAX_INT32 = np.iinfo(np.int32).max
+
 if TYPE_CHECKING:
   import dataclasses
   import jax as _jax
@@ -35,6 +52,8 @@ if TYPE_CHECKING:
 
   TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]
 
+SeedType = Any
+
 
 def standardize_dtype(dtype: Any) -> str:
   """Converts a backend-specific dtype to a standard string representation.
@@ -88,8 +107,8 @@ def result_type(*types: Any) -> str:
     standardized_types.append(str(t))
 
   if any("float" in t for t in standardized_types):
-    return "float32"
-  return "int64"
+    return _DEFAULT_FLOAT
+  return _DEFAULT_INT
 
 
 def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
@@ -116,6 +135,54 @@ def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
 
 
 # --- Private Backend-Specific Implementations ---
+def _jax_stabilize_rf_roi_grid(
+    spend_grid: np.ndarray,
+    outcome_grid: np.ndarray,
+    n_rf_channels: int,
+) -> np.ndarray:
+  """Stabilizes the RF ROI grid for JAX using a stable index lookup."""
+  new_outcome_grid = outcome_grid.copy()
+  rf_slice = slice(-n_rf_channels, None)
+  rf_spend_grid = spend_grid[:, rf_slice]
+  rf_outcome_grid = new_outcome_grid[:, rf_slice]
+
+  last_valid_indices = np.sum(~np.isnan(rf_spend_grid), axis=0) - 1
+  channel_indices = np.arange(n_rf_channels)
+
+  ref_spend = rf_spend_grid[last_valid_indices, channel_indices]
+  ref_outcome = rf_outcome_grid[last_valid_indices, channel_indices]
+
+  rf_roi = np.divide(
+      ref_outcome,
+      ref_spend,
+      out=np.zeros_like(ref_outcome, dtype=np.float64),
+      where=(ref_spend != 0),
+  )
+
+  new_outcome_grid[:, rf_slice] = rf_roi * rf_spend_grid
+  return new_outcome_grid
+
+
+def _tf_stabilize_rf_roi_grid(
+    spend_grid: np.ndarray,
+    outcome_grid: np.ndarray,
+    n_rf_channels: int,
+) -> np.ndarray:
+  """Stabilizes the RF ROI grid for TF using nanmax logic."""
+  new_outcome_grid = outcome_grid.copy()
+  rf_slice = slice(-n_rf_channels, None)
+  rf_outcome_max = np.nanmax(new_outcome_grid[:, rf_slice], axis=0)
+  rf_spend_max = np.nanmax(spend_grid[:, rf_slice], axis=0)
+
+  rf_roi = np.divide(
+      rf_outcome_max,
+      rf_spend_max,
+      out=np.zeros_like(rf_outcome_max, dtype=np.float64),
+      where=(rf_spend_max != 0),
+  )
+
+  new_outcome_grid[:, rf_slice] = rf_roi * spend_grid[:, rf_slice]
+  return new_outcome_grid
 
 
 def _jax_arange(
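
A minimal numeric sketch of what the two stabilization helpers above compute (the grid values here are hypothetical): each reach-and-frequency channel's outcome column is replaced by a line through the origin whose slope is a reference ROI, taken at the last non-NaN grid point in the JAX variant and at the nanmax in the TF variant.

    import numpy as np

    # One NaN-padded RF channel: spend and modeled outcome per grid point.
    spend = np.array([0.0, 100.0, 200.0, np.nan])
    outcome = np.array([0.0, 260.0, 390.0, np.nan])

    last_valid = np.sum(~np.isnan(spend)) - 1          # index 2
    ref_roi = outcome[last_valid] / spend[last_valid]  # 390 / 200 = 1.95

    # The stabilized column is ROI * spend: linear in spend by construction.
    stabilized = ref_roi * spend  # [0.0, 195.0, 390.0, nan]
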
@@ -163,6 +230,75 @@ def _jax_divide_no_nan(x, y):
   return jnp.where(y != 0, jnp.divide(x, y), 0.0)
 
 
+def _jax_function_wrapper(func=None, **kwargs):
+  """A wrapper for jax.jit that handles TF-like args and static args.
+
+  This wrapper provides compatibility with TensorFlow's `tf.function` arguments
+  and improves ergonomics when decorating class methods in JAX.
+
+  By default, if neither `static_argnums` nor `static_argnames` are provided, it
+  defaults `static_argnums` to `(0,)`. This assumes the function is a method
+  where the first argument (`self` or `cls`) should be treated as static.
+
+  To disable this behavior for plain functions, explicitly provide an empty
+  tuple: `@backend.function(static_argnums=())`.
+
+  Args:
+    func: The function to wrap.
+    **kwargs: Keyword arguments passed to jax.jit. TF-specific arguments (like
+      `jit_compile`) are ignored.
+
+  Returns:
+    The wrapped function or a decorator.
+  """
+  jit_kwargs = kwargs.copy()
+
+  jit_kwargs.pop(_ARG_JIT_COMPILE, None)
+  jit_kwargs.pop(_ARG_AUTOGRAPH, None)
+
+  if _ARG_STATIC_ARGNUMS in jit_kwargs:
+    if not jit_kwargs[_ARG_STATIC_ARGNUMS]:
+      jit_kwargs.pop(_ARG_STATIC_ARGNUMS)
+  else:
+    jit_kwargs[_ARG_STATIC_ARGNUMS] = (0,)
+
+  decorator = functools.partial(jax.jit, **jit_kwargs)
+
+  if func:
+    return decorator(func)
+  return decorator
+
+
+def _tf_function_wrapper(func=None, **kwargs):
+  """A wrapper for tf.function that ignores JAX-specific arguments."""
+  import tensorflow as tf
+
+  kwargs.pop(_ARG_STATIC_ARGNAMES, None)
+  kwargs.pop(_ARG_STATIC_ARGNUMS, None)
+
+  decorator = tf.function(**kwargs)
+
+  if func:
+    return decorator(func)
+  return decorator
+
+
+def _jax_nanmedian(a, axis=None):
+  """JAX implementation for nanmedian."""
+  import jax.numpy as jnp
+
+  return jnp.nanmedian(a, axis=axis)
+
+
+def _tf_nanmedian(a, axis=None):
+  """TensorFlow implementation for nanmedian using numpy_function."""
+  import tensorflow as tf
+
+  return tf.numpy_function(
+      lambda x: np.nanmedian(x, axis=axis).astype(x.dtype), [a], a.dtype
+  )
+
+
 def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
   raise NotImplementedError(
       "backend.numpy_function is not implemented for the JAX backend."
@@ -195,6 +331,26 @@ def _tf_get_indices_where(condition):
   return tf.where(condition)
 
 
+def _jax_split(value, num_or_size_splits, axis=0):
+  """JAX implementation for split that accepts size splits."""
+  import jax.numpy as jnp
+
+  if not isinstance(num_or_size_splits, int):
+    indices = jnp.cumsum(jnp.array(num_or_size_splits))[:-1]
+    return jnp.split(value, indices, axis=axis)
+
+  return jnp.split(value, num_or_size_splits, axis=axis)
+
+
+def _jax_tile(*args, **kwargs):
+  """JAX wrapper for tile that supports the `multiples` keyword argument."""
+  import jax.numpy as jnp
+
+  if _TENSORFLOW_TILE_KEYWORD in kwargs:
+    kwargs[_JAX_TILE_KEYWORD] = kwargs.pop(_TENSORFLOW_TILE_KEYWORD)
+  return jnp.tile(*args, **kwargs)
+
+
 def _jax_unique_with_counts(x):
   """JAX implementation for unique_with_counts."""
   import jax.numpy as jnp
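
The size-split branch of `_jax_split` above rests on a cumulative-sum conversion; the same arithmetic in plain NumPy (`jnp.split` mirrors `np.split` here):

    import numpy as np

    sizes = [2, 3, 5]                # TF-style section sizes
    indices = np.cumsum(sizes)[:-1]  # -> [2, 5], the JAX-style split points
    parts = np.split(np.arange(10), indices)
    # parts have lengths 2, 3, and 5, matching tf.split semantics.
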
@@ -218,6 +374,7 @@ def _jax_boolean_mask(tensor, mask, axis=None):
 
   if axis is None:
     axis = 0
+  mask = jnp.asarray(mask)
   tensor_swapped = jnp.moveaxis(tensor, axis, 0)
   masked = tensor_swapped[mask]
   return jnp.moveaxis(masked, 0, axis)
@@ -230,17 +387,26 @@ def _tf_boolean_mask(tensor, mask, axis=None):
   return tf.boolean_mask(tensor, mask, axis=axis)
 
 
-def _jax_gather(params, indices):
-  """JAX implementation for gather."""
-  # JAX uses standard array indexing for gather operations.
-  return params[indices]
+def _jax_gather(params, indices, axis=0):
+  """JAX implementation for gather with axis support."""
+  import jax.numpy as jnp
+
+  if isinstance(params, (list, tuple)):
+    params = np.array(params)
+
+  # JAX can't JIT-compile operations on string or object arrays. We detect
+  # these types and fall back to standard NumPy operations.
+  if isinstance(params, np.ndarray) and params.dtype.kind in ("S", "U", "O"):
+    return np.take(params, np.asarray(indices), axis=axis)
+
+  return jnp.take(params, indices, axis=axis)
 
 
-def _tf_gather(params, indices):
+def _tf_gather(params, indices, axis=0):
   """TensorFlow implementation for gather."""
   import tensorflow as tf
 
-  return tf.gather(params, indices)
+  return tf.gather(params, indices, axis=axis)
 
 
 def _jax_fill(dims, value):
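
The string fallback in `_jax_gather` above is an ordinary `np.take`; a tiny sketch with hypothetical channel labels:

    import numpy as np

    labels = np.array(["tv", "radio", "search"])          # dtype.kind == "U"
    picked = np.take(labels, np.asarray([2, 0]), axis=0)  # ['search', 'tv']
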
@@ -304,6 +470,93 @@ def _jax_tensor_shape(dims):
   return tuple(dims)
 
 
+def _jax_transpose(a, perm=None):
+  """JAX wrapper for transpose to support the 'perm' keyword argument."""
+  import jax.numpy as jnp
+
+  return jnp.transpose(a, axes=perm)
+
+
+def _jax_get_seed_data(seed: Any) -> Optional[np.ndarray]:
+  """Extracts the underlying numerical data from a JAX PRNGKey."""
+  if seed is None:
+    return None
+
+  return np.array(jax.random.key_data(seed))
+
+
+def _tf_get_seed_data(seed: Any) -> Optional[np.ndarray]:
+  """Converts a TensorFlow-style seed into a NumPy array."""
+  if seed is None:
+    return None
+  return np.array(seed)
+
+
+def _jax_convert_to_tensor(data, dtype=None):
+  """Converts data to a JAX array, handling strings as NumPy arrays."""
+  # JAX does not natively support string tensors in the same way TF does.
+  # If a string dtype is requested, or if the data is inherently strings,
+  # we fall back to a standard NumPy array.
+  if dtype == np.str_ or (
+      dtype is None
+      and isinstance(data, (list, np.ndarray))
+      and np.array(data).dtype.kind in ("S", "U")
+  ):
+    return np.array(data, dtype=np.str_)
+  return jax_ops.asarray(data, dtype=dtype)
+
+
+def _tf_nanmean(a, axis=None, keepdims=False):
+  import tensorflow.experimental.numpy as tnp
+
+  return tf_backend.convert_to_tensor(
+      tnp.nanmean(a, axis=axis, keepdims=keepdims)
+  )
+
+
+def _tf_nansum(a, axis=None, keepdims=False):
+  import tensorflow.experimental.numpy as tnp
+
+  return tf_backend.convert_to_tensor(
+      tnp.nansum(a, axis=axis, keepdims=keepdims)
+  )
+
+
+def _tf_nanvar(a, axis=None, keepdims=False):
+  """Calculates variance ignoring NaNs, strictly returning a Tensor."""
+  import tensorflow as tf
+  import tensorflow.experimental.numpy as tnp
+  # We implement two-pass variance to correctly handle NaNs and ensure
+  # all operations remain within the TF graph (maintaining differentiability).
+  a_tensor = tf.convert_to_tensor(a)
+  mean = tnp.nanmean(a_tensor, axis=axis, keepdims=True)
+  sq_diff = tf.math.squared_difference(a_tensor, mean)
+  var = tnp.nanmean(sq_diff, axis=axis, keepdims=keepdims)
+  return tf.convert_to_tensor(var)
+
+
+def _jax_one_hot(*args, **kwargs):  # pylint: disable=unused-argument
+  """JAX implementation for one_hot."""
+  raise NotImplementedError(
+      "backend.one_hot is not implemented for the JAX backend."
+  )
+
+
+def _jax_roll(*args, **kwargs):  # pylint: disable=unused-argument
+  """JAX implementation for roll."""
+  raise NotImplementedError(
+      "backend.roll is not implemented for the JAX backend."
+  )
+
+
+def _jax_enable_op_determinism():
+  """No-op for JAX. Determinism is handled via stateless PRNGKeys."""
+  warnings.warn(
+      "op determinism is a TensorFlow-specific concept and has no effect when"
+      " using the JAX backend."
+  )
+
+
 # --- Backend Initialization ---
 _BACKEND = config.get_backend()
 
@@ -316,30 +569,149 @@ if _BACKEND == config.Backend.JAX:
   import jax
   import jax.numpy as jax_ops
   import tensorflow_probability.substrates.jax as tfp_jax
+  from jax import tree_util
 
   class ExtensionType:
-    """A JAX-compatible stand-in for tf.experimental.ExtensionType."""
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-      raise NotImplementedError(
-          "ExtensionType is not yet implemented for the JAX backend."
+    """A JAX-compatible stand-in for tf.experimental.ExtensionType.
+
+    This class registers itself as a JAX Pytree node, allowing it to be passed
+    through JIT-compiled functions.
+    """
+
+    def __init_subclass__(cls, **kwargs):
+      super().__init_subclass__(**kwargs)
+      tree_util.register_pytree_node(
+          cls,
+          cls._tree_flatten,
+          cls._tree_unflatten,
       )
 
+    def _tree_flatten(self):
+      """Flattens the object for JAX tracing.
+
+      Fields containing JAX arrays (or convertibles) are treated as children.
+      Fields containing strings or NumPy string arrays must be treated as
+      auxiliary data because JAX cannot trace non-numeric/non-boolean data.
+
+      Returns:
+        A tuple of (children, aux_data), where children are the
+        tracer-compatible parts of the object, and aux_data contains auxiliary
+        information needed for unflattening.
+      """
+      d = vars(self)
+      all_keys = sorted(d.keys())
+      children = []
+      aux = {}
+      children_keys = []
+
+      for k in all_keys:
+        v = d[k]
+        # Identify string data to prevent JAX tracing errors.
+        # 'S' is zero-terminated bytes (fixed-width), 'U' is unicode string.
+        is_numpy_string = isinstance(v, np.ndarray) and v.dtype.kind in (
+            "S",
+            "U",
+        )
+        is_plain_string = isinstance(v, str)
+
+        if is_numpy_string or is_plain_string:
+          aux[k] = v
+        else:
+          children.append(v)
+          children_keys.append(k)
+      return children, (aux, children_keys)
+
+    @classmethod
+    def _tree_unflatten(cls, aux_and_keys, children):
+      aux, children_keys = aux_and_keys
+      obj = cls.__new__(cls)
+      vars(obj).update(aux)
+      for k, v in zip(children_keys, children):
+        setattr(obj, k, v)
+      return obj
+
   class _JaxErrors:
     # pylint: disable=invalid-name
     ResourceExhaustedError = MemoryError
     InvalidArgumentError = ValueError
     # pylint: enable=invalid-name
 
+  class _JaxRandom:
+    """Provides JAX-based random number generation utilities.
+
+    This class mirrors the structure needed by `RNGHandler` for JAX.
+    """
+
+    @staticmethod
+    def prng_key(seed):
+      return jax.random.PRNGKey(seed)
+
+    @staticmethod
+    def split(key):
+      return jax.random.split(key)
+
+    @staticmethod
+    def generator_from_seed(seed):
+      raise NotImplementedError("JAX backend does not use Generators.")
+
+    @staticmethod
+    def stateless_split(seed: Any, num: int = 2):
+      raise NotImplementedError(
+          "Direct stateless splitting from an integer seed is not the primary"
+          " pattern used in the JAX backend. Use `backend.random.split(key)`"
+          " instead."
+      )
+
+    @staticmethod
+    def stateless_randint(key, shape, minval, maxval, dtype=jax_ops.int32):
+      """Wrapper for jax.random.randint."""
+      return jax.random.randint(
+          key, shape, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def stateless_uniform(
+        key, shape, dtype=jax_ops.float32, minval=0.0, maxval=1.0
+    ):
+      """Replacement for tfp_jax.random.stateless_uniform using jax.random.uniform."""
+      return jax.random.uniform(
+          key, shape=shape, dtype=dtype, minval=minval, maxval=maxval
+      )
+
+    @staticmethod
+    def sanitize_seed(seed):
+      return tfp_jax.random.sanitize_seed(seed)
+
+  random = _JaxRandom()
+
+  @functools.partial(
+      jax.jit,
+      static_argnames=[
+          "joint_dist",
+          "n_chains",
+          "n_draws",
+          "num_adaptation_steps",
+          "dual_averaging_kwargs",
+          "max_tree_depth",
+          "unrolled_leapfrog_steps",
+          "parallel_iterations",
+      ],
+  )
+  def _jax_xla_windowed_adaptive_nuts(**kwargs):
+    """JAX-specific JIT wrapper for the NUTS sampler."""
+    kwargs["seed"] = random.prng_key(kwargs["seed"])
+    return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+  xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
+
   _ops = jax_ops
   errors = _JaxErrors()
   Tensor = jax.Array
   tfd = tfp_jax.distributions
   bijectors = tfp_jax.bijectors
   experimental = tfp_jax.experimental
-  random = tfp_jax.random
   mcmc = tfp_jax.mcmc
-  _convert_to_tensor = _ops.asarray
+  _convert_to_tensor = _jax_convert_to_tensor
 
   # Standardized Public API
   absolute = _ops.abs
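
A sketch of the Pytree behavior the new `ExtensionType` above describes, assuming the JAX backend is active; `Params` is a hypothetical subclass. Numeric fields flatten into traceable leaves, while string fields ride along as auxiliary data:

    import jax
    import numpy as np
    from meridian import backend

    class Params(backend.ExtensionType):  # registered via __init_subclass__

      def __init__(self, draws, name):
        self.draws = draws  # numeric: becomes a pytree child
        self.name = name    # str: stored as aux data, hidden from tracing

    p = Params(np.ones(3), "media")
    jax.tree_util.tree_leaves(p)  # -> [array([1., 1., 1.])]
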
@@ -347,19 +719,6 @@ if _BACKEND == config.Backend.JAX:
   arange = _jax_arange
   argmax = _jax_argmax
   boolean_mask = _jax_boolean_mask
-  concatenate = _ops.concatenate
-  stack = _ops.stack
-  split = _ops.split
-  zeros = _ops.zeros
-  zeros_like = _ops.zeros_like
-  ones = _ops.ones
-  ones_like = _ops.ones_like
-  repeat = _ops.repeat
-  reshape = _ops.reshape
-  tile = _ops.tile
-  where = _ops.where
-  transpose = _ops.transpose
-  broadcast_to = _ops.broadcast_to
   broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
   broadcast_to = _ops.broadcast_to
   cast = _jax_cast
@@ -368,18 +727,25 @@ if _BACKEND == config.Backend.JAX:
   divide = _ops.divide
   divide_no_nan = _jax_divide_no_nan
   einsum = _ops.einsum
+  enable_op_determinism = _jax_enable_op_determinism
   equal = _ops.equal
   exp = _ops.exp
   expand_dims = _ops.expand_dims
   fill = _jax_fill
-  function = jax.jit
+  function = _jax_function_wrapper
   gather = _jax_gather
   get_indices_where = _jax_get_indices_where
+  get_seed_data = _jax_get_seed_data
   is_nan = _ops.isnan
   log = _ops.log
   make_ndarray = _jax_make_ndarray
   make_tensor_proto = _jax_make_tensor_proto
+  nanmean = jax_ops.nanmean
+  nanmedian = _jax_nanmedian
+  nansum = jax_ops.nansum
+  nanvar = jax_ops.nanvar
   numpy_function = _jax_numpy_function
+  one_hot = _jax_one_hot
   ones = _ops.ones
   ones_like = _ops.ones_like
   rank = _ops.ndim
@@ -391,10 +757,11 @@ if _BACKEND == config.Backend.JAX:
   reduce_sum = _ops.sum
   repeat = _ops.repeat
   reshape = _ops.reshape
-  split = _ops.split
+  roll = _jax_roll
+  split = _jax_split
   stack = _ops.stack
-  tile = _ops.tile
-  transpose = _ops.transpose
+  tile = _jax_tile
+  transpose = _jax_transpose
   unique_with_counts = _jax_unique_with_counts
   where = _ops.where
   zeros = _ops.zeros
@@ -404,13 +771,15 @@ if _BACKEND == config.Backend.JAX:
   bool_ = _ops.bool_
   newaxis = _ops.newaxis
   TensorShape = _jax_tensor_shape
+  int32 = _ops.int32
+  string = np.str_
+
+  stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid
 
   def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
-    raise NotImplementedError(
-        "JAX does not support a global, stateful random seed. `set_random_seed`"
-        " is not implemented. Instead, you must pass an explicit `seed`"
-        " integer directly to the sampling methods (e.g., `sample_prior`),"
-        " which will be used to create a JAX PRNGKey internally."
+    warnings.warn(
+        "backend.set_random_seed is a no-op in JAX. Randomness is managed "
+        "statelessly via PRNGKeys passed to sampling methods."
     )
 
 elif _BACKEND == config.Backend.TENSORFLOW:
@@ -423,31 +792,80 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   Tensor = tf_backend.Tensor
   ExtensionType = _ops.experimental.ExtensionType
 
+  class _TfRandom:
+    """Provides TensorFlow-based random number generation utilities.
+
+    This class mirrors the structure needed by `RNGHandler` for TensorFlow.
+    """
+
+    @staticmethod
+    def prng_key(seed):
+      raise NotImplementedError(
+          "TensorFlow backend does not use explicit PRNG keys for"
+          " standard sampling."
+      )
+
+    @staticmethod
+    def split(key):
+      raise NotImplementedError(
+          "TensorFlow backend does not implement explicit key splitting."
+      )
+
+    @staticmethod
+    def generator_from_seed(seed):
+      return tf_backend.random.Generator.from_seed(seed)
+
+    @staticmethod
+    def stateless_split(
+        seed: "tf_backend.Tensor", num: int = 2
+    ) -> "tf_backend.Tensor":
+      return tf_backend.random.experimental.stateless_split(seed, num=num)
+
+    @staticmethod
+    def stateless_randint(
+        seed: "tf_backend.Tensor",
+        shape: "TensorShapeInstance",
+        minval: int,
+        maxval: int,
+        dtype: Any = _DEFAULT_SEED_DTYPE,
+    ) -> "tf_backend.Tensor":
+      return tf_backend.random.stateless_uniform(
+          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def stateless_uniform(
+        seed, shape, minval=0, maxval=None, dtype=tf_backend.float32
+    ):
+      return tf_backend.random.stateless_uniform(
+          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def sanitize_seed(seed):
+      return tfp.random.sanitize_seed(seed)
+
+  random = _TfRandom()
+
+  @tf_backend.function(autograph=False, jit_compile=True)
+  def _tf_xla_windowed_adaptive_nuts(**kwargs):
+    """TensorFlow-specific XLA wrapper for the NUTS sampler."""
+    return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+  xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
+
   tfd = tfp.distributions
   bijectors = tfp.bijectors
   experimental = tfp.experimental
-  random = tfp.random
   mcmc = tfp.mcmc
-
   _convert_to_tensor = _ops.convert_to_tensor
+
+  # Standardized Public API
   absolute = _ops.math.abs
   allclose = _ops.experimental.numpy.allclose
   arange = _tf_arange
   argmax = _tf_argmax
   boolean_mask = _tf_boolean_mask
-  concatenate = _ops.concat
-  stack = _ops.stack
-  split = _ops.split
-  zeros = _ops.zeros
-  zeros_like = _ops.zeros_like
-  ones = _ops.ones
-  ones_like = _ops.ones_like
-  repeat = _ops.repeat
-  reshape = _ops.reshape
-  tile = _ops.tile
-  where = _ops.where
-  transpose = _ops.transpose
-  broadcast_to = _ops.broadcast_to
   broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
   broadcast_to = _ops.broadcast_to
   cast = _ops.cast
@@ -456,18 +874,25 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   divide = _ops.divide
   divide_no_nan = _ops.math.divide_no_nan
   einsum = _ops.einsum
+  enable_op_determinism = _ops.config.experimental.enable_op_determinism
   equal = _ops.equal
   exp = _ops.math.exp
   expand_dims = _ops.expand_dims
   fill = _tf_fill
-  function = _ops.function
+  function = _tf_function_wrapper
   gather = _tf_gather
   get_indices_where = _tf_get_indices_where
+  get_seed_data = _tf_get_seed_data
   is_nan = _ops.math.is_nan
   log = _ops.math.log
   make_ndarray = _ops.make_ndarray
   make_tensor_proto = _ops.make_tensor_proto
+  nanmean = _tf_nanmean
+  nanmedian = _tf_nanmedian
+  nansum = _tf_nansum
+  nanvar = _tf_nanvar
   numpy_function = _ops.numpy_function
+  one_hot = _ops.one_hot
   ones = _ops.ones
   ones_like = _ops.ones_like
   rank = _ops.rank
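
The TF nan-aware aliases bound above (`nanmean`, `nanvar`, ...) follow NumPy semantics; a NumPy-only sketch of the two-pass variance implemented by `_tf_nanvar`:

    import numpy as np

    a = np.array([1.0, 2.0, np.nan, 4.0])
    mean = np.nanmean(a)               # first pass: 7/3
    var = np.nanmean((a - mean) ** 2)  # second pass on squared differences
    assert np.isclose(var, np.nanvar(a))  # population variance, NaNs ignored
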
@@ -479,6 +904,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   reduce_sum = _ops.reduce_sum
   repeat = _ops.repeat
   reshape = _ops.reshape
+  roll = _ops.roll
   set_random_seed = tf_backend.keras.utils.set_random_seed
   split = _ops.split
   stack = _ops.stack
@@ -489,16 +915,233 @@
   zeros = _ops.zeros
   zeros_like = _ops.zeros_like
 
+  stabilize_rf_roi_grid = _tf_stabilize_rf_roi_grid
+
   float32 = _ops.float32
   bool_ = _ops.bool
   newaxis = _ops.newaxis
   TensorShape = _ops.TensorShape
+  int32 = _ops.int32
+  string = _ops.string
 
 else:
   raise ValueError(f"Unsupported backend: {_BACKEND}")
 # pylint: enable=g-import-not-at-top,g-bad-import-order
 
 
+def _extract_int_seed(s: Any) -> Optional[int]:
+  """Attempts to extract a scalar Python integer from various input types.
+
+  Args:
+    s: The input seed, which can be an int, Tensor, or array-like object.
+
+  Returns:
+    A Python integer if a scalar integer can be extracted, otherwise None.
+  """
+  try:
+    if isinstance(s, int):
+      return s
+
+    value = np.asarray(s)
+
+    if value.ndim == 0 and np.issubdtype(value.dtype, np.integer):
+      return int(value)
+  # A broad exception is used here because the input `s` can be of many types
+  # (e.g., JAX PRNGKey) which may cause np.asarray or other operations to fail
+  # in unpredictable ways. The goal is to safely attempt extraction and fail
+  # gracefully.
+  except Exception:  # pylint: disable=broad-except
+    return None
+  return None
+
+
+class _BaseRNGHandler(abc.ABC):
+  """A backend-agnostic abstract base class for random number generation state.
+
+  This handler provides a stateful-style interface for consuming randomness,
+  abstracting away the differences between JAX's stateless PRNG keys and
+  TensorFlow's paradigms.
+
+  Attributes:
+    _seed_input: The original seed object provided during initialization.
+    _int_seed: A Python integer extracted from the seed, if possible.
+  """
+
+  def __init__(self, seed: SeedType):
+    """Initializes the RNG handler.
+
+    Args:
+      seed: The initial seed. The accepted type depends on the backend. For
+        JAX, this must be an integer. For TensorFlow, this can be an integer,
+        a sequence of two integers, or a Tensor. If None, the handler becomes
+        a no-op, returning None for all seed requests.
+    """
+    self._seed_input = seed
+    self._int_seed: Optional[int] = _extract_int_seed(seed)
+
+  @abc.abstractmethod
+  def get_kernel_seed(self) -> Any:
+    """Provides a backend-appropriate sanitized seed/key for an MCMC kernel.
+
+    This method exposes the current state of the handler in the format
+    expected by the backend's MCMC machinery. It does not advance the
+    internal state.
+
+    Returns:
+      A backend-specific seed object (e.g., a JAX PRNGKey or a TF Tensor).
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def get_next_seed(self) -> Any:
+    """Provides the appropriate seed object for the next sequential operation.
+
+    This is primarily used for prior sampling and typically advances the
+    internal state.
+
+    Returns:
+      A backend-specific seed object for a single random operation.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def advance_handler(self) -> "_BaseRNGHandler":
+    """Creates a new, independent RNGHandler for a subsequent operation.
+
+    This method is used to generate a new handler derived deterministically
+    from the current handler's state.
+
+    Returns:
+      A new, independent `RNGHandler` instance.
+    """
+    raise NotImplementedError
+
+
+class _JaxRNGHandler(_BaseRNGHandler):
+  """JAX implementation of the RNGHandler using explicit key splitting."""
+
+  def __init__(self, seed: SeedType):
+    """Initializes the JAX RNG handler.
+
+    Args:
+      seed: The initial seed, which must be a Python integer, a scalar
+        integer Tensor/array, or None.
+
+    Raises:
+      ValueError: If the provided seed is not a scalar integer or None.
+    """
+    super().__init__(seed)
+    self._key: Optional["_jax.Array"] = None
+
+    if seed is None:
+      return
+
+    if (
+        isinstance(seed, jax.Array)  # pylint: disable=undefined-variable
+        and seed.shape == (2,)
+        and seed.dtype == jax_ops.uint32  # pylint: disable=undefined-variable
+    ):
+      self._key = seed
+      return
+
+    if self._int_seed is None:
+      raise ValueError(
+          "JAX backend requires a seed that is an integer or a scalar array,"
+          f" but got: {type(seed)} with value {seed!r}"
+      )
+
+    self._key = random.prng_key(self._int_seed)
+
+  def get_kernel_seed(self) -> Any:
+    if self._key is None:
+      return None
+    _, subkey = random.split(self._key)
+    return int(jax.random.randint(subkey, (), 0, 2**31 - 1))  # pylint: disable=undefined-variable
+
+  def get_next_seed(self) -> Any:
+    if self._key is None:
+      return None
+    self._key, subkey = random.split(self._key)
+    return subkey
+
+  def advance_handler(self) -> "_JaxRNGHandler":
+    if self._key is None:
+      return _JaxRNGHandler(None)
+
+    self._key, subkey_for_new_handler = random.split(self._key)
+    new_seed_tensor = random.stateless_randint(
+        key=subkey_for_new_handler,
+        shape=(),
+        minval=0,
+        maxval=_MAX_INT32,
+        dtype=_DEFAULT_SEED_DTYPE,
+    )
+    return _JaxRNGHandler(np.asarray(new_seed_tensor).item())
+
+
+class _TFRNGHandler(_BaseRNGHandler):
+  """A stateless-style RNG handler for TensorFlow.
+
+  This handler canonicalizes any seed input into a single stateless seed
+  Tensor.
+  """
+
+  def __init__(self, seed: SeedType):
+    """Initializes the TensorFlow RNG handler.
+
+    Args:
+      seed: The initial seed. Can be an integer, a sequence of two integers,
+        a corresponding Tensor, or None. It will be sanitized and stored
+        internally as a single stateless seed Tensor.
+    """
+    super().__init__(seed)
+    self._seed_state: Optional["_tf.Tensor"] = None
+
+    if seed is None:
+      return
+
+    if isinstance(seed, Sequence) and len(seed) != 2:
+      raise ValueError(
+          "Invalid seed: Must be either a single integer (stateful seed) or a"
+          " pair of two integers (stateless seed). See"
+          " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
+          " for details."
+      )
+
+    if isinstance(seed, int):
+      seed_to_sanitize = (seed, seed)
+    else:
+      seed_to_sanitize = seed
+
+    self._seed_state = random.sanitize_seed(seed_to_sanitize)
+
+  def get_kernel_seed(self) -> Any:
+    """Provides the current stateless seed of the handler without advancing it."""
+    return self._seed_state
+
+  def get_next_seed(self) -> Any:
+    """Returns a new unique seed and advances the handler's internal state."""
+    if self._seed_state is None:
+      return None
+    new_seed, next_state = random.stateless_split(self._seed_state)
+    self._seed_state = next_state
+    return new_seed
+
+  def advance_handler(self) -> "_TFRNGHandler":
+    """Creates a new handler from a new seed and advances the current handler."""
+    if self._seed_state is None:
+      return _TFRNGHandler(None)
+    seed_for_new_handler, next_state = random.stateless_split(self._seed_state)
+    self._seed_state = next_state
+    return _TFRNGHandler(seed_for_new_handler)
+
+
+if _BACKEND == config.Backend.JAX:
+  RNGHandler = _JaxRNGHandler
+elif _BACKEND == config.Backend.TENSORFLOW:
+  RNGHandler = _TFRNGHandler
+else:
+  raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
+
+
 def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor:  # type: ignore
   """Converts input data to the currently active backend tensor type.
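
Finally, a sketch of the `RNGHandler` contract the new classes implement, using an integer seed, which the base-class docstring says both backends accept:

    from meridian import backend

    handler = backend.RNGHandler(seed=0)
    s1 = handler.get_next_seed()           # seed for the first sampling op
    s2 = handler.get_next_seed()           # a new seed; state has advanced
    child = handler.advance_handler()      # independent, derived handler
    kernel_seed = child.get_kernel_seed()  # kernel seed; state not advanced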