google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
- google_meridian-1.3.0.dist-info/RECORD +62 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +280 -142
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +353 -169
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +14 -12
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +45 -50
- meridian/backend/__init__.py +698 -55
- meridian/backend/config.py +75 -16
- meridian/backend/test_utils.py +127 -1
- meridian/constants.py +52 -11
- meridian/data/input_data.py +7 -2
- meridian/data/test_utils.py +5 -3
- meridian/mlflow/autolog.py +2 -2
- meridian/model/__init__.py +1 -0
- meridian/model/adstock_hill.py +10 -9
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1580 -84
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +56 -50
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -9
- meridian/model/posterior_sampler.py +398 -391
- meridian/model/prior_distribution.py +114 -39
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +16 -8
- meridian/version.py +1 -1
- google_meridian-1.2.0.dist-info/RECORD +0 -52
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
meridian/backend/__init__.py
CHANGED
@@ -14,13 +14,16 @@

 """Backend Abstraction Layer for Meridian."""

+import abc
+import functools
 import os
-from typing import Any, Optional,
-
+from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+import warnings
 from meridian.backend import config
 import numpy as np
 from typing_extensions import Literal

+
 # The conditional imports in this module are a deliberate design choice for the
 # backend abstraction layer. The TFP-on-JAX substrate provides a nearly
 # identical API to the standard TFP library, making an alias-based approach more
@@ -28,6 +31,20 @@ from typing_extensions import Literal
 # extensive boilerplate.
 # pylint: disable=g-import-not-at-top,g-bad-import-order

+_DEFAULT_FLOAT = "float32"
+_DEFAULT_INT = "int64"
+
+_TENSORFLOW_TILE_KEYWORD = "multiples"
+_JAX_TILE_KEYWORD = "reps"
+
+_ARG_AUTOGRAPH = "autograph"
+_ARG_JIT_COMPILE = "jit_compile"
+_ARG_STATIC_ARGNUMS = "static_argnums"
+_ARG_STATIC_ARGNAMES = "static_argnames"
+
+_DEFAULT_SEED_DTYPE = "int32"
+_MAX_INT32 = np.iinfo(np.int32).max
+
 if TYPE_CHECKING:
   import dataclasses
   import jax as _jax
@@ -35,6 +52,8 @@ if TYPE_CHECKING:

   TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]

+SeedType = Any
+

 def standardize_dtype(dtype: Any) -> str:
   """Converts a backend-specific dtype to a standard string representation.
@@ -88,8 +107,8 @@ def result_type(*types: Any) -> str:
     standardized_types.append(str(t))

   if any("float" in t for t in standardized_types):
-    return
-    return
+    return _DEFAULT_FLOAT
+  return _DEFAULT_INT


 def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
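Note on the `result_type` hunk above: the inline literals are replaced by the new module defaults, so any floating input promotes to `_DEFAULT_FLOAT` and everything else falls through to `_DEFAULT_INT`. A minimal sketch of that rule as read off the branch shown (the calls are illustrative, not from the package's tests):

```python
# Promotion rule visible in the hunk above: any "float" among the standardized
# dtype strings -> _DEFAULT_FLOAT ("float32"), otherwise _DEFAULT_INT ("int64").
from meridian import backend

assert backend.result_type("int64", "float64") == "float32"
assert backend.result_type("int32", "bool") == "int64"
```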
@@ -116,6 +135,54 @@ def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:


 # --- Private Backend-Specific Implementations ---
+def _jax_stabilize_rf_roi_grid(
+    spend_grid: np.ndarray,
+    outcome_grid: np.ndarray,
+    n_rf_channels: int,
+) -> np.ndarray:
+  """Stabilizes the RF ROI grid for JAX using a stable index lookup."""
+  new_outcome_grid = outcome_grid.copy()
+  rf_slice = slice(-n_rf_channels, None)
+  rf_spend_grid = spend_grid[:, rf_slice]
+  rf_outcome_grid = new_outcome_grid[:, rf_slice]
+
+  last_valid_indices = np.sum(~np.isnan(rf_spend_grid), axis=0) - 1
+  channel_indices = np.arange(n_rf_channels)
+
+  ref_spend = rf_spend_grid[last_valid_indices, channel_indices]
+  ref_outcome = rf_outcome_grid[last_valid_indices, channel_indices]
+
+  rf_roi = np.divide(
+      ref_outcome,
+      ref_spend,
+      out=np.zeros_like(ref_outcome, dtype=np.float64),
+      where=(ref_spend != 0),
+  )
+
+  new_outcome_grid[:, rf_slice] = rf_roi * rf_spend_grid
+  return new_outcome_grid
+
+
+def _tf_stabilize_rf_roi_grid(
+    spend_grid: np.ndarray,
+    outcome_grid: np.ndarray,
+    n_rf_channels: int,
+) -> np.ndarray:
+  """Stabilizes the RF ROI grid for TF using nanmax logic."""
+  new_outcome_grid = outcome_grid.copy()
+  rf_slice = slice(-n_rf_channels, None)
+  rf_outcome_max = np.nanmax(new_outcome_grid[:, rf_slice], axis=0)
+  rf_spend_max = np.nanmax(spend_grid[:, rf_slice], axis=0)
+
+  rf_roi = np.divide(
+      rf_outcome_max,
+      rf_spend_max,
+      out=np.zeros_like(rf_outcome_max, dtype=np.float64),
+      where=(rf_spend_max != 0),
+  )
+
+  new_outcome_grid[:, rf_slice] = rf_roi * spend_grid[:, rf_slice]
+  return new_outcome_grid


 def _jax_arange(
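Note on the two `stabilize_rf_roi_grid` variants above: the JAX version uses the last non-NaN row of each RF channel as the ROI reference, while the TF version uses the per-channel `nanmax`. For a monotonically increasing, NaN-padded grid like the one below, both pick the same reference point. A small NumPy check with hypothetical data (assumes the two helpers above are in scope):

```python
import numpy as np

# Two RF channels; channel 0's grid is one row shorter (trailing NaN padding).
spend = np.array([[100.0, 10.0],
                  [200.0, 20.0],
                  [np.nan, 30.0]])
outcome = np.array([[150.0, 40.0],
                    [260.0, 70.0],
                    [np.nan, 90.0]])

# Both variants rescale each RF column onto a constant-ROI line through the
# reference point (ROI 1.3 for channel 0, ROI 3.0 for channel 1).
jax_grid = _jax_stabilize_rf_roi_grid(spend, outcome, n_rf_channels=2)
tf_grid = _tf_stabilize_rf_roi_grid(spend, outcome, n_rf_channels=2)
np.testing.assert_allclose(jax_grid, tf_grid)
```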
@@ -163,6 +230,75 @@ def _jax_divide_no_nan(x, y):
   return jnp.where(y != 0, jnp.divide(x, y), 0.0)


+def _jax_function_wrapper(func=None, **kwargs):
+  """A wrapper for jax.jit that handles TF-like args and static args.
+
+  This wrapper provides compatibility with TensorFlow's `tf.function` arguments
+  and improves ergonomics when decorating class methods in JAX.
+
+  By default, if neither `static_argnums` nor `static_argnames` are provided, it
+  defaults `static_argnums` to `(0,)`. This assumes the function is a method
+  where the first argument (`self` or `cls`) should be treated as static.
+
+  To disable this behavior for plain functions, explicitly provide an empty
+  tuple: `@backend.function(static_argnums=())`.
+
+  Args:
+    func: The function to wrap.
+    **kwargs: Keyword arguments passed to jax.jit. TF-specific arguments (like
+      `jit_compile`) are ignored.
+
+  Returns:
+    The wrapped function or a decorator.
+  """
+  jit_kwargs = kwargs.copy()
+
+  jit_kwargs.pop(_ARG_JIT_COMPILE, None)
+  jit_kwargs.pop(_ARG_AUTOGRAPH, None)
+
+  if _ARG_STATIC_ARGNUMS in jit_kwargs:
+    if not jit_kwargs[_ARG_STATIC_ARGNUMS]:
+      jit_kwargs.pop(_ARG_STATIC_ARGNUMS)
+  else:
+    jit_kwargs[_ARG_STATIC_ARGNUMS] = (0,)
+
+  decorator = functools.partial(jax.jit, **jit_kwargs)
+
+  if func:
+    return decorator(func)
+  return decorator
+
+
+def _tf_function_wrapper(func=None, **kwargs):
+  """A wrapper for tf.function that ignores JAX-specific arguments."""
+  import tensorflow as tf
+
+  kwargs.pop(_ARG_STATIC_ARGNAMES, None)
+  kwargs.pop(_ARG_STATIC_ARGNUMS, None)
+
+  decorator = tf.function(**kwargs)
+
+  if func:
+    return decorator(func)
+  return decorator
+
+
+def _jax_nanmedian(a, axis=None):
+  """JAX implementation for nanmedian."""
+  import jax.numpy as jnp
+
+  return jnp.nanmedian(a, axis=axis)
+
+
+def _tf_nanmedian(a, axis=None):
+  """TensorFlow implementation for nanmedian using numpy_function."""
+  import tensorflow as tf
+
+  return tf.numpy_function(
+      lambda x: np.nanmedian(x, axis=axis).astype(x.dtype), [a], a.dtype
+  )
+
+
 def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
   raise NotImplementedError(
       "backend.numpy_function is not implemented for the JAX backend."
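Note on `backend.function` above: under JAX, the wrapper marks argument 0 static unless told otherwise, which matches the decorate-a-method case its docstring describes. A usage sketch of the two modes (hypothetical class and function names; JAX backend assumed):

```python
import jax.numpy as jnp
from meridian import backend

class Sampler:
  @backend.function  # no static args given: argnum 0 (`self`) is made static
  def step(self, x):
    return x * 2.0

@backend.function(static_argnums=())  # plain function: opt out of the default
def scale(x):
  return x * 2.0

print(Sampler().step(jnp.ones(3)))  # jitted with `self` as a static argument
print(scale(jnp.ones(3)))           # jitted with no static arguments
```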
@@ -195,6 +331,26 @@ def _tf_get_indices_where(condition):
   return tf.where(condition)


+def _jax_split(value, num_or_size_splits, axis=0):
+  """JAX implementation for split that accepts size splits."""
+  import jax.numpy as jnp
+
+  if not isinstance(num_or_size_splits, int):
+    indices = jnp.cumsum(jnp.array(num_or_size_splits))[:-1]
+    return jnp.split(value, indices, axis=axis)
+
+  return jnp.split(value, num_or_size_splits, axis=axis)
+
+
+def _jax_tile(*args, **kwargs):
+  """JAX wrapper for tile that supports the `multiples` keyword argument."""
+  import jax.numpy as jnp
+
+  if _TENSORFLOW_TILE_KEYWORD in kwargs:
+    kwargs[_JAX_TILE_KEYWORD] = kwargs.pop(_TENSORFLOW_TILE_KEYWORD)
+  return jnp.tile(*args, **kwargs)
+
+
 def _jax_unique_with_counts(x):
   """JAX implementation for unique_with_counts."""
   import jax.numpy as jnp
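Note: `tf.split` takes section *sizes* while `jnp.split` takes *cut indices*, so `_jax_split` converts sizes via `cumsum`; `_jax_tile` likewise maps TF's `multiples=` keyword onto JAX's `reps=`. A quick eager-mode sketch (assumes the two helpers above are in scope):

```python
import jax.numpy as jnp

x = jnp.arange(6)
parts = _jax_split(x, [2, 3, 1])  # TF-style sizes [2, 3, 1] -> cut indices [2, 5]
assert [int(p.shape[0]) for p in parts] == [2, 3, 1]

y = _jax_tile(jnp.ones(2), multiples=(3,))  # forwarded as jnp.tile(..., reps=(3,))
assert y.shape == (6,)
```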
@@ -218,6 +374,7 @@ def _jax_boolean_mask(tensor, mask, axis=None):

   if axis is None:
     axis = 0
+  mask = jnp.asarray(mask)
   tensor_swapped = jnp.moveaxis(tensor, axis, 0)
   masked = tensor_swapped[mask]
   return jnp.moveaxis(masked, 0, axis)
@@ -230,17 +387,26 @@ def _tf_boolean_mask(tensor, mask, axis=None):
   return tf.boolean_mask(tensor, mask, axis=axis)


-def _jax_gather(params, indices):
-  """JAX implementation for gather."""
-
-
+def _jax_gather(params, indices, axis=0):
+  """JAX implementation for gather with axis support."""
+  import jax.numpy as jnp
+
+  if isinstance(params, (list, tuple)):
+    params = np.array(params)
+
+  # JAX can't JIT-compile operations on string or object arrays. We detect
+  # these types and fall back to standard NumPy operations.
+  if isinstance(params, np.ndarray) and params.dtype.kind in ("S", "U", "O"):
+    return np.take(params, np.asarray(indices), axis=axis)
+
+  return jnp.take(params, indices, axis=axis)


-def _tf_gather(params, indices):
+def _tf_gather(params, indices, axis=0):
   """TensorFlow implementation for gather."""
   import tensorflow as tf

-  return tf.gather(params, indices)
+  return tf.gather(params, indices, axis=axis)


 def _jax_fill(dims, value):
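Note on the `_jax_gather` fallback above: string parameters (e.g., channel-name arrays) bypass JAX entirely and come back as plain NumPy, while numeric parameters take the `jnp.take` path. A behavior sketch with hypothetical data (assumes `_jax_gather` is in scope):

```python
import numpy as np
import jax.numpy as jnp

names = np.array(["tv", "radio", "search"])
picked = _jax_gather(names, [2, 0])
assert list(picked) == ["search", "tv"]  # NumPy fallback, no JAX tracing

vals = jnp.array([[1.0, 2.0], [3.0, 4.0]])
col = _jax_gather(vals, 1, axis=1)       # jnp.take path, returns a jax.Array
assert col.shape == (2,)
```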
@@ -304,6 +470,93 @@ def _jax_tensor_shape(dims):
   return tuple(dims)


+def _jax_transpose(a, perm=None):
+  """JAX wrapper for transpose to support the 'perm' keyword argument."""
+  import jax.numpy as jnp
+
+  return jnp.transpose(a, axes=perm)
+
+
+def _jax_get_seed_data(seed: Any) -> Optional[np.ndarray]:
+  """Extracts the underlying numerical data from a JAX PRNGKey."""
+  if seed is None:
+    return None
+
+  return np.array(jax.random.key_data(seed))
+
+
+def _tf_get_seed_data(seed: Any) -> Optional[np.ndarray]:
+  """Converts a TensorFlow-style seed into a NumPy array."""
+  if seed is None:
+    return None
+  return np.array(seed)
+
+
+def _jax_convert_to_tensor(data, dtype=None):
+  """Converts data to a JAX array, handling strings as NumPy arrays."""
+  # JAX does not natively support string tensors in the same way TF does.
+  # If a string dtype is requested, or if the data is inherently strings,
+  # we fall back to a standard NumPy array.
+  if dtype == np.str_ or (
+      dtype is None
+      and isinstance(data, (list, np.ndarray))
+      and np.array(data).dtype.kind in ("S", "U")
+  ):
+    return np.array(data, dtype=np.str_)
+  return jax_ops.asarray(data, dtype=dtype)
+
+
+def _tf_nanmean(a, axis=None, keepdims=False):
+  import tensorflow.experimental.numpy as tnp
+
+  return tf_backend.convert_to_tensor(
+      tnp.nanmean(a, axis=axis, keepdims=keepdims)
+  )
+
+
+def _tf_nansum(a, axis=None, keepdims=False):
+  import tensorflow.experimental.numpy as tnp
+
+  return tf_backend.convert_to_tensor(
+      tnp.nansum(a, axis=axis, keepdims=keepdims)
+  )
+
+
+def _tf_nanvar(a, axis=None, keepdims=False):
+  """Calculates variance ignoring NaNs, strictly returning a Tensor."""
+  import tensorflow as tf
+  import tensorflow.experimental.numpy as tnp
+
+  # We implement two-pass variance to correctly handle NaNs and ensure
+  # all operations remain within the TF graph (maintaining differentiability).
+  a_tensor = tf.convert_to_tensor(a)
+  mean = tnp.nanmean(a_tensor, axis=axis, keepdims=True)
+  sq_diff = tf.math.squared_difference(a_tensor, mean)
+  var = tnp.nanmean(sq_diff, axis=axis, keepdims=keepdims)
+  return tf.convert_to_tensor(var)
+
+
+def _jax_one_hot(*args, **kwargs):  # pylint: disable=unused-argument
+  """JAX implementation for one_hot."""
+  raise NotImplementedError(
+      "backend.one_hot is not implemented for the JAX backend."
+  )
+
+
+def _jax_roll(*args, **kwargs):  # pylint: disable=unused-argument
+  """JAX implementation for roll."""
+  raise NotImplementedError(
+      "backend.roll is not implemented for the JAX backend."
+  )
+
+
+def _jax_enable_op_determinism():
+  """No-op for JAX. Determinism is handled via stateless PRNGKeys."""
+  warnings.warn(
+      "op determinism is a TensorFlow-specific concept and has no effect when"
+      " using the JAX backend."
+  )
+
+
 # --- Backend Initialization ---
 _BACKEND = config.get_backend()

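Note on `_tf_nanvar` above: the two-pass mean/squared-difference formulation is the population (`ddof=0`) variance, so in eager mode it should agree with `np.nanvar`. An illustrative check with hypothetical data (assumes the helper above is in scope):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([[1.0, np.nan, 3.0],
                 [4.0, 5.0, np.nan]])
np.testing.assert_allclose(
    _tf_nanvar(x, axis=0).numpy(),
    np.nanvar(x.numpy(), axis=0),  # ddof=0, matching nanmean of squared diffs
    rtol=1e-6,
)
```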
@@ -316,30 +569,149 @@ if _BACKEND == config.Backend.JAX:
   import jax
   import jax.numpy as jax_ops
   import tensorflow_probability.substrates.jax as tfp_jax
+  from jax import tree_util

   class ExtensionType:
-    """A JAX-compatible stand-in for tf.experimental.ExtensionType.
-
-
-
-
+    """A JAX-compatible stand-in for tf.experimental.ExtensionType.
+
+    This class registers itself as a JAX Pytree node, allowing it to be passed
+    through JIT-compiled functions.
+    """
+
+    def __init_subclass__(cls, **kwargs):
+      super().__init_subclass__(**kwargs)
+      tree_util.register_pytree_node(
+          cls,
+          cls._tree_flatten,
+          cls._tree_unflatten,
       )

+    def _tree_flatten(self):
+      """Flattens the object for JAX tracing.
+
+      Fields containing JAX arrays (or convertibles) are treated as children.
+      Fields containing strings or NumPy string arrays must be treated as
+      auxiliary data because JAX cannot trace non-numeric/non-boolean data.
+
+      Returns:
+        A tuple of (children, aux_data), where children are the
+        tracer-compatible parts of the object, and aux_data contains auxiliary
+        information needed for unflattening.
+      """
+      d = vars(self)
+      all_keys = sorted(d.keys())
+      children = []
+      aux = {}
+      children_keys = []
+
+      for k in all_keys:
+        v = d[k]
+        # Identify string data to prevent JAX tracing errors.
+        # 'S' is zero-terminated bytes (fixed-width), 'U' is unicode string.
+        is_numpy_string = isinstance(v, np.ndarray) and v.dtype.kind in (
+            "S",
+            "U",
+        )
+        is_plain_string = isinstance(v, str)
+
+        if is_numpy_string or is_plain_string:
+          aux[k] = v
+        else:
+          children.append(v)
+          children_keys.append(k)
+      return children, (aux, children_keys)
+
+    @classmethod
+    def _tree_unflatten(cls, aux_and_keys, children):
+      aux, children_keys = aux_and_keys
+      obj = cls.__new__(cls)
+      vars(obj).update(aux)
+      for k, v in zip(children_keys, children):
+        setattr(obj, k, v)
+      return obj
+
   class _JaxErrors:
     # pylint: disable=invalid-name
     ResourceExhaustedError = MemoryError
     InvalidArgumentError = ValueError
     # pylint: enable=invalid-name

+  class _JaxRandom:
+    """Provides JAX-based random number generation utilities.
+
+    This class mirrors the structure needed by `RNGHandler` for JAX.
+    """
+
+    @staticmethod
+    def prng_key(seed):
+      return jax.random.PRNGKey(seed)
+
+    @staticmethod
+    def split(key):
+      return jax.random.split(key)
+
+    @staticmethod
+    def generator_from_seed(seed):
+      raise NotImplementedError("JAX backend does not use Generators.")
+
+    @staticmethod
+    def stateless_split(seed: Any, num: int = 2):
+      raise NotImplementedError(
+          "Direct stateless splitting from an integer seed is not the primary"
+          " pattern used in the JAX backend. Use `backend.random.split(key)`"
+          " instead."
+      )
+
+    @staticmethod
+    def stateless_randint(key, shape, minval, maxval, dtype=jax_ops.int32):
+      """Wrapper for jax.random.randint."""
+      return jax.random.randint(
+          key, shape, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def stateless_uniform(
+        key, shape, dtype=jax_ops.float32, minval=0.0, maxval=1.0
+    ):
+      """Replacement for tfp_jax.random.stateless_uniform using jax.random.uniform."""
+      return jax.random.uniform(
+          key, shape=shape, dtype=dtype, minval=minval, maxval=maxval
+      )
+
+    @staticmethod
+    def sanitize_seed(seed):
+      return tfp_jax.random.sanitize_seed(seed)
+
+  random = _JaxRandom()
+
+  @functools.partial(
+      jax.jit,
+      static_argnames=[
+          "joint_dist",
+          "n_chains",
+          "n_draws",
+          "num_adaptation_steps",
+          "dual_averaging_kwargs",
+          "max_tree_depth",
+          "unrolled_leapfrog_steps",
+          "parallel_iterations",
+      ],
+  )
+  def _jax_xla_windowed_adaptive_nuts(**kwargs):
+    """JAX-specific JIT wrapper for the NUTS sampler."""
+    kwargs["seed"] = random.prng_key(kwargs["seed"])
+    return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+  xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
+
   _ops = jax_ops
   errors = _JaxErrors()
   Tensor = jax.Array
   tfd = tfp_jax.distributions
   bijectors = tfp_jax.bijectors
   experimental = tfp_jax.experimental
-  random = tfp_jax.random
   mcmc = tfp_jax.mcmc
-  _convert_to_tensor =
+  _convert_to_tensor = _jax_convert_to_tensor

   # Standardized Public API
   absolute = _ops.abs
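Note on the JAX `ExtensionType` stand-in above: every subclass auto-registers as a pytree whose string-valued fields ride along as auxiliary data while everything else becomes leaves. A flatten/unflatten round-trip sketch (hypothetical subclass; JAX backend assumed, with the class from the hunk above in scope):

```python
import jax
import jax.numpy as jnp

class Media(ExtensionType):
  def __init__(self, values, name):
    self.values = values  # array-like -> pytree child (leaf)
    self.name = name      # str -> auxiliary data, invisible to tracing

m = Media(jnp.ones(3), name="tv")
leaves, treedef = jax.tree_util.tree_flatten(m)
m2 = jax.tree_util.tree_unflatten(treedef, [leaf * 2 for leaf in leaves])
print(m2.values, m2.name)  # doubled values; "tv" carried through unchanged
```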
@@ -347,19 +719,6 @@ if _BACKEND == config.Backend.JAX:
   arange = _jax_arange
   argmax = _jax_argmax
   boolean_mask = _jax_boolean_mask
-  concatenate = _ops.concatenate
-  stack = _ops.stack
-  split = _ops.split
-  zeros = _ops.zeros
-  zeros_like = _ops.zeros_like
-  ones = _ops.ones
-  ones_like = _ops.ones_like
-  repeat = _ops.repeat
-  reshape = _ops.reshape
-  tile = _ops.tile
-  where = _ops.where
-  transpose = _ops.transpose
-  broadcast_to = _ops.broadcast_to
   broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
   broadcast_to = _ops.broadcast_to
   cast = _jax_cast
@@ -368,18 +727,25 @@ if _BACKEND == config.Backend.JAX:
   divide = _ops.divide
   divide_no_nan = _jax_divide_no_nan
   einsum = _ops.einsum
+  enable_op_determinism = _jax_enable_op_determinism
   equal = _ops.equal
   exp = _ops.exp
   expand_dims = _ops.expand_dims
   fill = _jax_fill
-  function =
+  function = _jax_function_wrapper
   gather = _jax_gather
   get_indices_where = _jax_get_indices_where
+  get_seed_data = _jax_get_seed_data
   is_nan = _ops.isnan
   log = _ops.log
   make_ndarray = _jax_make_ndarray
   make_tensor_proto = _jax_make_tensor_proto
+  nanmean = jax_ops.nanmean
+  nanmedian = _jax_nanmedian
+  nansum = jax_ops.nansum
+  nanvar = jax_ops.nanvar
   numpy_function = _jax_numpy_function
+  one_hot = _jax_one_hot
   ones = _ops.ones
   ones_like = _ops.ones_like
   rank = _ops.ndim
@@ -391,10 +757,11 @@ if _BACKEND == config.Backend.JAX:
   reduce_sum = _ops.sum
   repeat = _ops.repeat
   reshape = _ops.reshape
-
+  roll = _jax_roll
+  split = _jax_split
   stack = _ops.stack
-  tile =
-  transpose =
+  tile = _jax_tile
+  transpose = _jax_transpose
   unique_with_counts = _jax_unique_with_counts
   where = _ops.where
   zeros = _ops.zeros
@@ -404,13 +771,15 @@ if _BACKEND == config.Backend.JAX:
   bool_ = _ops.bool_
   newaxis = _ops.newaxis
   TensorShape = _jax_tensor_shape
+  int32 = _ops.int32
+  string = np.str_
+
+  stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid

   def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
-
-    "
-    "
-    " integer directly to the sampling methods (e.g., `sample_prior`),"
-    " which will be used to create a JAX PRNGKey internally."
+    warnings.warn(
+        "backend.set_random_seed is a no-op in JAX. Randomness is managed "
+        "statelessly via PRNGKeys passed to sampling methods."
    )

 elif _BACKEND == config.Backend.TENSORFLOW:
@@ -423,31 +792,80 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   Tensor = tf_backend.Tensor
   ExtensionType = _ops.experimental.ExtensionType

+  class _TfRandom:
+    """Provides TensorFlow-based random number generation utilities.
+
+    This class mirrors the structure needed by `RNGHandler` for TensorFlow.
+    """
+
+    @staticmethod
+    def prng_key(seed):
+      raise NotImplementedError(
+          "TensorFlow backend does not use explicit PRNG keys for"
+          " standard sampling."
+      )
+
+    @staticmethod
+    def split(key):
+      raise NotImplementedError(
+          "TensorFlow backend does not implement explicit key splitting."
+      )
+
+    @staticmethod
+    def generator_from_seed(seed):
+      return tf_backend.random.Generator.from_seed(seed)
+
+    @staticmethod
+    def stateless_split(
+        seed: "tf_backend.Tensor", num: int = 2
+    ) -> "tf_backend.Tensor":
+      return tf_backend.random.experimental.stateless_split(seed, num=num)
+
+    @staticmethod
+    def stateless_randint(
+        seed: "tf_backend.Tensor",
+        shape: "TensorShapeInstance",
+        minval: int,
+        maxval: int,
+        dtype: Any = _DEFAULT_SEED_DTYPE,
+    ) -> "tf_backend.Tensor":
+      return tf_backend.random.stateless_uniform(
+          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def stateless_uniform(
+        seed, shape, minval=0, maxval=None, dtype=tf_backend.float32
+    ):
+      return tf_backend.random.stateless_uniform(
+          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
+      )
+
+    @staticmethod
+    def sanitize_seed(seed):
+      return tfp.random.sanitize_seed(seed)
+
+  random = _TfRandom()
+
+  @tf_backend.function(autograph=False, jit_compile=True)
+  def _tf_xla_windowed_adaptive_nuts(**kwargs):
+    """TensorFlow-specific XLA wrapper for the NUTS sampler."""
+    return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+  xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
+
   tfd = tfp.distributions
   bijectors = tfp.bijectors
   experimental = tfp.experimental
-  random = tfp.random
   mcmc = tfp.mcmc
-
   _convert_to_tensor = _ops.convert_to_tensor
+
+  # Standardized Public API
   absolute = _ops.math.abs
   allclose = _ops.experimental.numpy.allclose
   arange = _tf_arange
   argmax = _tf_argmax
   boolean_mask = _tf_boolean_mask
-  concatenate = _ops.concat
-  stack = _ops.stack
-  split = _ops.split
-  zeros = _ops.zeros
-  zeros_like = _ops.zeros_like
-  ones = _ops.ones
-  ones_like = _ops.ones_like
-  repeat = _ops.repeat
-  reshape = _ops.reshape
-  tile = _ops.tile
-  where = _ops.where
-  transpose = _ops.transpose
-  broadcast_to = _ops.broadcast_to
   broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
   broadcast_to = _ops.broadcast_to
   cast = _ops.cast
@@ -456,18 +874,25 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   divide = _ops.divide
   divide_no_nan = _ops.math.divide_no_nan
   einsum = _ops.einsum
+  enable_op_determinism = _ops.config.experimental.enable_op_determinism
   equal = _ops.equal
   exp = _ops.math.exp
   expand_dims = _ops.expand_dims
   fill = _tf_fill
-  function =
+  function = _tf_function_wrapper
   gather = _tf_gather
   get_indices_where = _tf_get_indices_where
+  get_seed_data = _tf_get_seed_data
   is_nan = _ops.math.is_nan
   log = _ops.math.log
   make_ndarray = _ops.make_ndarray
   make_tensor_proto = _ops.make_tensor_proto
+  nanmean = _tf_nanmean
+  nanmedian = _tf_nanmedian
+  nansum = _tf_nansum
+  nanvar = _tf_nanvar
   numpy_function = _ops.numpy_function
+  one_hot = _ops.one_hot
   ones = _ops.ones
   ones_like = _ops.ones_like
   rank = _ops.rank
@@ -479,6 +904,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   reduce_sum = _ops.reduce_sum
   repeat = _ops.repeat
   reshape = _ops.reshape
+  roll = _ops.roll
   set_random_seed = tf_backend.keras.utils.set_random_seed
   split = _ops.split
   stack = _ops.stack
@@ -489,16 +915,233 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   zeros = _ops.zeros
   zeros_like = _ops.zeros_like

+  stabilize_rf_roi_grid = _tf_stabilize_rf_roi_grid
+
   float32 = _ops.float32
   bool_ = _ops.bool
   newaxis = _ops.newaxis
   TensorShape = _ops.TensorShape
+  int32 = _ops.int32
+  string = _ops.string

 else:
   raise ValueError(f"Unsupported backend: {_BACKEND}")
 # pylint: enable=g-import-not-at-top,g-bad-import-order


+def _extract_int_seed(s: Any) -> Optional[int]:
+  """Attempts to extract a scalar Python integer from various input types.
+
+  Args:
+    s: The input seed, which can be an int, Tensor, or array-like object.
+
+  Returns:
+    A Python integer if a scalar integer can be extracted, otherwise None.
+  """
+  try:
+    if isinstance(s, int):
+      return s
+
+    value = np.asarray(s)
+
+    if value.ndim == 0 and np.issubdtype(value.dtype, np.integer):
+      return int(value)
+  # A broad exception is used here because the input `s` can be of many types
+  # (e.g., JAX PRNGKey) which may cause np.asarray or other operations to fail
+  # in unpredictable ways. The goal is to safely attempt extraction and fail
+  # gracefully.
+  except Exception:  # pylint: disable=broad-except
+    return None
+  return None
+
+
+class _BaseRNGHandler(abc.ABC):
+  """A backend-agnostic abstract base class for random number generation state.
+
+  This handler provides a stateful-style interface for consuming randomness,
+  abstracting away the differences between JAX's stateless PRNG keys and
+  TensorFlow's paradigms.
+
+  Attributes:
+    _seed_input: The original seed object provided during initialization.
+    _int_seed: A Python integer extracted from the seed, if possible.
+  """
+
+  def __init__(self, seed: SeedType):
+    """Initializes the RNG handler.
+
+    Args:
+      seed: The initial seed. The accepted type depends on the backend. For JAX,
+        this must be an integer. For TensorFlow, this can be an integer, a
+        sequence of two integers, or a Tensor. If None, the handler becomes a
+        no-op, returning None for all seed requests.
+    """
+    self._seed_input = seed
+    self._int_seed: Optional[int] = _extract_int_seed(seed)
+
+  @abc.abstractmethod
+  def get_kernel_seed(self) -> Any:
+    """Provides a backend-appropriate sanitized seed/key for an MCMC kernel.
+
+    This method exposes the current state of the handler in the format expected
+    by the backend's MCMC machinery. It does not advance the internal state.
+
+    Returns:
+      A backend-specific seed object (e.g., a JAX PRNGKey or a TF Tensor).
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def get_next_seed(self) -> Any:
+    """Provides the appropriate seed object for the next sequential operation.
+
+    This is primarily used for prior sampling and typically advances the
+    internal state.
+
+    Returns:
+      A backend-specific seed object for a single random operation.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def advance_handler(self) -> "_BaseRNGHandler":
+    """Creates a new, independent RNGHandler for a subsequent operation.
+
+    This method is used to generate a new handler derived deterministically
+    from the current handler's state.
+
+    Returns:
+      A new, independent `RNGHandler` instance.
+    """
+    raise NotImplementedError
+
+
+class _JaxRNGHandler(_BaseRNGHandler):
+  """JAX implementation of the RNGHandler using explicit key splitting."""
+
+  def __init__(self, seed: SeedType):
+    """Initializes the JAX RNG handler.
+
+    Args:
+      seed: The initial seed, which must be a Python integer, a scalar integer
+        Tensor/array, or None.
+
+    Raises:
+      ValueError: If the provided seed is not a scalar integer or None.
+    """
+    super().__init__(seed)
+    self._key: Optional["_jax.Array"] = None
+
+    if seed is None:
+      return
+
+    if (
+        isinstance(seed, jax.Array)  # pylint: disable=undefined-variable
+        and seed.shape == (2,)
+        and seed.dtype == jax_ops.uint32  # pylint: disable=undefined-variable
+    ):
+      self._key = seed
+      return
+
+    if self._int_seed is None:
+      raise ValueError(
+          "JAX backend requires a seed that is an integer or a scalar array,"
+          f" but got: {type(seed)} with value {seed!r}"
+      )
+
+    self._key = random.prng_key(self._int_seed)
+
+  def get_kernel_seed(self) -> Any:
+    if self._key is None:
+      return None
+    _, subkey = random.split(self._key)
+    return int(jax.random.randint(subkey, (), 0, 2**31 - 1))  # pylint: disable=undefined-variable
+
+  def get_next_seed(self) -> Any:
+    if self._key is None:
+      return None
+    self._key, subkey = random.split(self._key)
+    return subkey
+
+  def advance_handler(self) -> "_JaxRNGHandler":
+    if self._key is None:
+      return _JaxRNGHandler(None)
+
+    self._key, subkey_for_new_handler = random.split(self._key)
+    new_seed_tensor = random.stateless_randint(
+        key=subkey_for_new_handler,
+        shape=(),
+        minval=0,
+        maxval=_MAX_INT32,
+        dtype=_DEFAULT_SEED_DTYPE,
+    )
+    return _JaxRNGHandler(np.asarray(new_seed_tensor).item())
+
+
+class _TFRNGHandler(_BaseRNGHandler):
+  """A stateless-style RNG handler for TensorFlow.
+
+  This handler canonicalizes any seed input into a single stateless seed Tensor.
+  """
+
+  def __init__(self, seed: SeedType):
+    """Initializes the TensorFlow RNG handler.
+
+    Args:
+      seed: The initial seed. Can be an integer, a sequence of two integers, a
+        corresponding Tensor, or None. It will be sanitized and stored
+        internally as a single stateless seed Tensor.
+    """
+    super().__init__(seed)
+    self._seed_state: Optional["_tf.Tensor"] = None
+
+    if seed is None:
+      return
+
+    if isinstance(seed, Sequence) and len(seed) != 2:
+      raise ValueError(
+          "Invalid seed: Must be either a single integer (stateful seed) or a"
+          " pair of two integers (stateless seed). See"
+          " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
+          " for details."
+      )
+
+    if isinstance(seed, int):
+      seed_to_sanitize = (seed, seed)
+    else:
+      seed_to_sanitize = seed
+
+    self._seed_state = random.sanitize_seed(seed_to_sanitize)
+
+  def get_kernel_seed(self) -> Any:
+    """Provides the current stateless seed of the handler without advancing it."""
+    return self._seed_state
+
+  def get_next_seed(self) -> Any:
+    """Returns a new unique seed and advances the handler's internal state."""
+    if self._seed_state is None:
+      return None
+    new_seed, next_state = random.stateless_split(self._seed_state)
+    self._seed_state = next_state
+    return new_seed
+
+  def advance_handler(self) -> "_TFRNGHandler":
+    """Creates a new handler from a new seed and advances the current handler."""
+    if self._seed_state is None:
+      return _TFRNGHandler(None)
+    seed_for_new_handler, next_state = random.stateless_split(self._seed_state)
+    self._seed_state = next_state
+    return _TFRNGHandler(seed_for_new_handler)
+
+
+if _BACKEND == config.Backend.JAX:
+  RNGHandler = _JaxRNGHandler
+elif _BACKEND == config.Backend.TENSORFLOW:
+  RNGHandler = _TFRNGHandler
+else:
+  raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
+
+
 def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor:  # type: ignore
   """Converts input data to the currently active backend tensor type.

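Taken together, `_extract_int_seed`, the two handler classes, and the backend switch give callers a single seed facade. A usage sketch of the public alias (behavior as described in the docstrings above; the real call sites live in the prior and posterior samplers):

```python
from meridian import backend

handler = backend.RNGHandler(42)

kernel_seed = handler.get_kernel_seed()  # peek: does not advance the state
s1 = handler.get_next_seed()             # JAX: fresh subkey; TF: stateless pair
s2 = handler.get_next_seed()             # advances again -> an independent seed

stage2 = handler.advance_handler()       # derived, independent handler

noop = backend.RNGHandler(None)          # None -> the handler is a no-op
assert noop.get_next_seed() is None
```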