google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
meridian/backend/__init__.py CHANGED

@@ -18,6 +18,7 @@ import abc
 import functools
 import os
 from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+import warnings

 from meridian.backend import config
 import numpy as np
@@ -37,6 +38,7 @@ _DEFAULT_INT = "int64"
 _TENSORFLOW_TILE_KEYWORD = "multiples"
 _JAX_TILE_KEYWORD = "reps"

+_ARG_AUTOGRAPH = "autograph"
 _ARG_JIT_COMPILE = "jit_compile"
 _ARG_STATIC_ARGNUMS = "static_argnums"
 _ARG_STATIC_ARGNAMES = "static_argnames"
@@ -134,6 +136,54 @@ def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:


 # --- Private Backend-Specific Implementations ---
+def _jax_stabilize_rf_roi_grid(
+    spend_grid: np.ndarray,
+    outcome_grid: np.ndarray,
+    n_rf_channels: int,
+) -> np.ndarray:
+  """Stabilizes the RF ROI grid for JAX using a stable index lookup."""
+  new_outcome_grid = outcome_grid.copy()
+  rf_slice = slice(-n_rf_channels, None)
+  rf_spend_grid = spend_grid[:, rf_slice]
+  rf_outcome_grid = new_outcome_grid[:, rf_slice]
+
+  last_valid_indices = np.sum(~np.isnan(rf_spend_grid), axis=0) - 1
+  channel_indices = np.arange(n_rf_channels)
+
+  ref_spend = rf_spend_grid[last_valid_indices, channel_indices]
+  ref_outcome = rf_outcome_grid[last_valid_indices, channel_indices]
+
+  rf_roi = np.divide(
+      ref_outcome,
+      ref_spend,
+      out=np.zeros_like(ref_outcome, dtype=np.float64),
+      where=(ref_spend != 0),
+  )
+
+  new_outcome_grid[:, rf_slice] = rf_roi * rf_spend_grid
+  return new_outcome_grid
+
+
+def _tf_stabilize_rf_roi_grid(
+    spend_grid: np.ndarray,
+    outcome_grid: np.ndarray,
+    n_rf_channels: int,
+) -> np.ndarray:
+  """Stabilizes the RF ROI grid for TF using nanmax logic."""
+  new_outcome_grid = outcome_grid.copy()
+  rf_slice = slice(-n_rf_channels, None)
+  rf_outcome_max = np.nanmax(new_outcome_grid[:, rf_slice], axis=0)
+  rf_spend_max = np.nanmax(spend_grid[:, rf_slice], axis=0)
+
+  rf_roi = np.divide(
+      rf_outcome_max,
+      rf_spend_max,
+      out=np.zeros_like(rf_outcome_max, dtype=np.float64),
+      where=(rf_spend_max != 0),
+  )
+
+  new_outcome_grid[:, rf_slice] = rf_roi * spend_grid[:, rf_slice]
+  return new_outcome_grid


 def _jax_arange(
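Note on the two stabilization strategies above: both flatten the RF columns of the outcome grid onto a constant-ROI line, the JAX variant anchoring at the last non-NaN grid point per channel and the TF variant at the per-channel nanmax; for a monotone response curve they coincide. A minimal sketch of my own (not part of the diff), assuming meridian 1.3.1 is installed so the active backend exports `stabilize_rf_roi_grid`:

    import numpy as np
    from meridian import backend

    # One media channel plus one RF channel (last column); the RF column is
    # NaN-padded because its grid has one fewer valid point.
    spend_grid = np.array([
        [100.0, 50.0],
        [200.0, 100.0],
        [300.0, np.nan],
    ])
    outcome_grid = np.array([
        [110.0, 55.0],
        [205.0, 101.0],
        [280.0, np.nan],
    ])

    # The last valid RF point is (spend=100, outcome=101), so ROI = 1.01 and
    # the RF outcome column becomes 1.01 * spend: [50.5, 101.0, nan].
    out = backend.stabilize_rf_roi_grid(spend_grid, outcome_grid, n_rf_channels=1)
    print(out[:, 1])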
@@ -171,7 +221,7 @@ def _tf_arange(

 def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
   """JAX implementation for cast."""
-  return
+  return jax_ops.asarray(x, dtype=dtype)


 def _jax_divide_no_nan(x, y):
@@ -205,6 +255,7 @@ def _jax_function_wrapper(func=None, **kwargs):
   jit_kwargs = kwargs.copy()

   jit_kwargs.pop(_ARG_JIT_COMPILE, None)
+  jit_kwargs.pop(_ARG_AUTOGRAPH, None)

   if _ARG_STATIC_ARGNUMS in jit_kwargs:
     if not jit_kwargs[_ARG_STATIC_ARGNUMS]:
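The extra `pop` above is what lets call sites pass TF-only flags uniformly: under JAX the wrapper strips `autograph` (and `jit_compile`) before handing the rest to `jax.jit`, so the same decoration works on both backends. A hedged usage sketch of mine, not from the diff:

    from meridian import backend

    # autograph/jit_compile are honored by tf.function under the TF backend
    # and silently discarded before jax.jit under the JAX backend.
    @backend.function(autograph=False, jit_compile=True)
    def double(x):
      return x * 2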
@@ -255,17 +306,132 @@ def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
   )


-def _jax_make_tensor_proto(
-
-
-
+def _jax_make_tensor_proto(values, dtype=None, shape=None):  # pylint: disable=unused-argument
+  """JAX implementation for make_tensor_proto."""
+  # pylint: disable=g-direct-tensorflow-import
+  from tensorflow.core.framework import tensor_pb2
+  from tensorflow.core.framework import tensor_shape_pb2
+  from tensorflow.core.framework import types_pb2
+  # pylint: enable=g-direct-tensorflow-import

+  if not isinstance(values, np.ndarray):
+    values = np.array(values)

-
-
-
+  if dtype:
+    numpy_dtype = np.dtype(dtype)
+    values = values.astype(numpy_dtype)
+  else:
+    numpy_dtype = values.dtype
+
+  dtype_map = {
+      np.dtype(np.float16): types_pb2.DT_HALF,
+      np.dtype(np.float32): types_pb2.DT_FLOAT,
+      np.dtype(np.float64): types_pb2.DT_DOUBLE,
+      np.dtype(np.int32): types_pb2.DT_INT32,
+      np.dtype(np.uint8): types_pb2.DT_UINT8,
+      np.dtype(np.uint16): types_pb2.DT_UINT16,
+      np.dtype(np.uint32): types_pb2.DT_UINT32,
+      np.dtype(np.uint64): types_pb2.DT_UINT64,
+      np.dtype(np.int16): types_pb2.DT_INT16,
+      np.dtype(np.int8): types_pb2.DT_INT8,
+      np.dtype(np.int64): types_pb2.DT_INT64,
+      np.dtype(np.complex64): types_pb2.DT_COMPLEX64,
+      np.dtype(np.complex128): types_pb2.DT_COMPLEX128,
+      np.dtype(np.bool_): types_pb2.DT_BOOL,
+      # Note: String types are handled outside the map.
+  }
+  proto_dtype = dtype_map.get(numpy_dtype)
+  if proto_dtype is None and numpy_dtype.kind in ("S", "U"):
+    proto_dtype = types_pb2.DT_STRING
+
+  if proto_dtype is None:
+    raise TypeError(
+        f"Unsupported dtype for TensorProto conversion: {numpy_dtype}"
+    )
+
+  proto = tensor_pb2.TensorProto(
+      dtype=proto_dtype,
+      tensor_shape=tensor_shape_pb2.TensorShapeProto(
+          dim=[
+              tensor_shape_pb2.TensorShapeProto.Dim(size=d)
+              for d in values.shape
+          ]
+      ),
   )

+  proto.tensor_content = values.tobytes()
+  return proto
+
+
+def _jax_make_ndarray(proto):
+  """JAX implementation for make_ndarray."""
+  # pylint: disable=g-direct-tensorflow-import
+  from tensorflow.core.framework import types_pb2
+  # pylint: enable=g-direct-tensorflow-import
+
+  dtype_map = {
+      types_pb2.DT_HALF: np.float16,
+      types_pb2.DT_FLOAT: np.float32,
+      types_pb2.DT_DOUBLE: np.float64,
+      types_pb2.DT_INT32: np.int32,
+      types_pb2.DT_UINT8: np.uint8,
+      types_pb2.DT_UINT16: np.uint16,
+      types_pb2.DT_UINT32: np.uint32,
+      types_pb2.DT_UINT64: np.uint64,
+      types_pb2.DT_INT16: np.int16,
+      types_pb2.DT_INT8: np.int8,
+      types_pb2.DT_INT64: np.int64,
+      types_pb2.DT_COMPLEX64: np.complex64,
+      types_pb2.DT_COMPLEX128: np.complex128,
+      types_pb2.DT_BOOL: np.bool_,
+      types_pb2.DT_STRING: np.bytes_,
+  }
+  if proto.dtype not in dtype_map:
+    raise TypeError(f"Unsupported TensorProto dtype: {proto.dtype}")
+
+  shape = [d.size for d in proto.tensor_shape.dim]
+  dtype = dtype_map[proto.dtype]
+
+  if proto.tensor_content:
+    num_elements = np.prod(shape).item() if shape else 0
+    # When deserializing a string from tensor_content, the itemsize is not
+    # explicitly stored. We must infer it from the content length and shape.
+    if dtype == np.bytes_ and num_elements > 0:
+      content_len = len(proto.tensor_content)
+      itemsize = content_len // num_elements
+      if itemsize * num_elements != content_len:
+        raise ValueError(
+            "Tensor content size is not a multiple of the number of elements"
+            " for string dtype."
+        )
+      dtype = np.dtype(f"S{itemsize}")
+
+    return (
+        np.frombuffer(proto.tensor_content, dtype=dtype).copy().reshape(shape)
+    )
+
+  # Fallback for protos that store data in val fields instead of tensor_content.
+  if dtype == np.float32:
+    val_field = proto.float_val
+  elif dtype == np.float64:
+    val_field = proto.double_val
+  elif dtype == np.int32:
+    val_field = proto.int_val
+  elif dtype == np.int64:
+    val_field = proto.int64_val
+  elif dtype == np.bool_:
+    val_field = proto.bool_val
+  else:
+    if proto.string_val:
+      return np.array(proto.string_val, dtype=np.bytes_).reshape(shape)
+    if not any(shape):
+      return np.array([], dtype=dtype).reshape(shape)
+    raise TypeError(
+        f"Unsupported dtype for TensorProto value field fallback: {dtype}"
+    )
+
+  return np.array(val_field, dtype=dtype).reshape(shape)
+

 def _jax_get_indices_where(condition):
   """JAX implementation for get_indices_where."""
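A round-trip sketch for the two helpers above (my example, assuming the JAX backend is active; under TF these names map to the stock `tf.make_tensor_proto`/`tf.make_ndarray`):

    import numpy as np
    from meridian import backend

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    proto = backend.make_tensor_proto(x)   # DT_FLOAT, shape (2, 3), raw bytes
    y = backend.make_ndarray(proto)        # decoded via np.frombuffer
    assert np.array_equal(x, y) and y.dtype == np.float32

    # Fixed-width byte strings round-trip too; as the in-code comment notes,
    # the itemsize is inferred from len(tensor_content) / num_elements.
    s = np.array([b"ab", b"cd"])
    assert np.array_equal(backend.make_ndarray(backend.make_tensor_proto(s)), s)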
@@ -324,6 +490,7 @@ def _jax_boolean_mask(tensor, mask, axis=None):

   if axis is None:
     axis = 0
+  mask = jnp.asarray(mask)
   tensor_swapped = jnp.moveaxis(tensor, axis, 0)
   masked = tensor_swapped[mask]
   return jnp.moveaxis(masked, 0, axis)
@@ -336,17 +503,26 @@ def _tf_boolean_mask(tensor, mask, axis=None):
   return tf.boolean_mask(tensor, mask, axis=axis)


-def _jax_gather(params, indices):
-  """JAX implementation for gather."""
-
-
+def _jax_gather(params, indices, axis=0):
+  """JAX implementation for gather with axis support."""
+  import jax.numpy as jnp
+
+  if isinstance(params, (list, tuple)):
+    params = np.array(params)
+
+  # JAX can't JIT-compile operations on string or object arrays. We detect
+  # these types and fall back to standard NumPy operations.
+  if isinstance(params, np.ndarray) and params.dtype.kind in ("S", "U", "O"):
+    return np.take(params, np.asarray(indices), axis=axis)
+
+  return jnp.take(params, indices, axis=axis)


-def _tf_gather(params, indices):
+def _tf_gather(params, indices, axis=0):
   """TensorFlow implementation for gather."""
   import tensorflow as tf

-  return tf.gather(params, indices)
+  return tf.gather(params, indices, axis=axis)


 def _jax_fill(dims, value):
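With the widened signature, `gather` now accepts an `axis` on both backends, and the JAX path degrades gracefully to NumPy for string arrays, which `jnp.take` cannot trace. A small sketch of mine:

    import numpy as np
    from meridian import backend

    m = np.array([[1, 2, 3], [4, 5, 6]])
    print(backend.gather(m, [2, 0], axis=1))   # columns 2 and 0

    labels = np.array([b"tv", b"radio", b"search"])
    print(backend.gather(labels, [1, 1]))      # NumPy fallback under JAX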
@@ -417,6 +593,127 @@ def _jax_transpose(a, perm=None):
   return jnp.transpose(a, axes=perm)


+def _jax_get_seed_data(seed: Any) -> Optional[np.ndarray]:
+  """Extracts the underlying numerical data from a JAX PRNGKey."""
+  if seed is None:
+    return None
+
+  return np.array(jax.random.key_data(seed))
+
+
+def _tf_get_seed_data(seed: Any) -> Optional[np.ndarray]:
+  """Converts a TensorFlow-style seed into a NumPy array."""
+  if seed is None:
+    return None
+  return np.array(seed)
+
+
+def _jax_convert_to_tensor(data, dtype=None):
+  """Converts data to a JAX array, handling strings as NumPy arrays."""
+  # JAX does not natively support string tensors in the same way TF does.
+  # If a string dtype is requested, or if the data is inherently strings,
+  # we fall back to a standard NumPy array.
+  is_string_target = False
+  if dtype is not None:
+    try:
+      if np.dtype(dtype).kind in ("S", "U"):
+        is_string_target = True
+    except TypeError:
+      # This can happen if dtype is not a valid dtype specifier,
+      # let jax.asarray handle it.
+      pass
+
+  is_string_data = isinstance(data, (list, np.ndarray)) and np.array(
+      data
+  ).dtype.kind in ("S", "U")
+
+  if is_string_target or (dtype is None and is_string_data):
+    return np.array(data, dtype=dtype)
+
+  return jax_ops.asarray(data, dtype=dtype)
+
+
+def _tf_nanmean(a, axis=None, keepdims=False):
+  import tensorflow.experimental.numpy as tnp
+
+  return tf_backend.convert_to_tensor(
+      tnp.nanmean(a, axis=axis, keepdims=keepdims)
+  )
+
+
+def _tf_nansum(a, axis=None, keepdims=False):
+  import tensorflow.experimental.numpy as tnp
+
+  return tf_backend.convert_to_tensor(
+      tnp.nansum(a, axis=axis, keepdims=keepdims)
+  )
+
+
+def _tf_nanvar(a, axis=None, keepdims=False):
+  """Calculates variance ignoring NaNs, strictly returning a Tensor."""
+  import tensorflow as tf
+  import tensorflow.experimental.numpy as tnp
+  # We implement two-pass variance to correctly handle NaNs and ensure
+  # all operations remain within the TF graph (maintaining differentiability).
+  a_tensor = tf.convert_to_tensor(a)
+  mean = tnp.nanmean(a_tensor, axis=axis, keepdims=True)
+  sq_diff = tf.math.squared_difference(a_tensor, mean)
+  var = tnp.nanmean(sq_diff, axis=axis, keepdims=keepdims)
+  return tf.convert_to_tensor(var)
+
+
+def _jax_one_hot(
+    indices, depth, on_value=None, off_value=None, axis=None, dtype=None
+):
+  """JAX implementation for one_hot."""
+  import jax.numpy as jnp
+
+  resolved_dtype = _resolve_dtype(dtype, on_value, off_value, 1, 0)
+  jax_axis = -1 if axis is None else axis
+
+  one_hot_result = jax.nn.one_hot(
+      indices, num_classes=depth, dtype=jnp.dtype(resolved_dtype), axis=jax_axis
+  )
+
+  on_val = 1 if on_value is None else on_value
+  off_val = 0 if off_value is None else off_value
+
+  if on_val == 1 and off_val == 0:
+    return one_hot_result
+
+  on_tensor = jnp.array(on_val, dtype=jnp.dtype(resolved_dtype))
+  off_tensor = jnp.array(off_val, dtype=jnp.dtype(resolved_dtype))
+
+  return jnp.where(one_hot_result == 1, on_tensor, off_tensor)
+
+
+def _jax_roll(a, shift, axis=None):
+  """JAX implementation for roll."""
+  import jax.numpy as jnp
+
+  return jnp.roll(a, shift, axis=axis)
+
+
+def _tf_roll(a, shift: Sequence[int], axis=None):
+  """TensorFlow implementation for roll that handles axis=None."""
+  import tensorflow as tf
+
+  if axis is None:
+    original_shape = tf.shape(a)
+    flat_tensor = tf.reshape(a, [-1])
+    rolled_flat = tf.roll(flat_tensor, shift=shift, axis=0)
+    return tf.reshape(rolled_flat, original_shape)
+  return tf.roll(a, shift, axis=axis)
+
+
+def _jax_enable_op_determinism():
+  """No-op for JAX. Determinism is handled via stateless PRNGKeys."""
+  warnings.warn(
+      "op determinism is a TensorFlow-specific concept and has no effect when"
+      " using the JAX backend."
+  )
+
+
 # --- Backend Initialization ---
 _BACKEND = config.get_backend()

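These additions round out the backend-neutral API: `nanmean`/`nansum`/`nanvar` map to `jnp` under JAX and to `tf.experimental.numpy`-based wrappers under TF, and `roll` reproduces `np.roll`'s flatten-roll-reshape behavior when `axis=None`. A quick sketch of mine:

    import numpy as np
    from meridian import backend

    a = np.array([[1.0, np.nan], [3.0, 4.0]])
    print(backend.nanmean(a, axis=0))  # [2.0, 4.0]
    print(backend.nanvar(a, axis=0))   # [1.0, 0.0]

    b = np.arange(6).reshape(2, 3)
    print(backend.roll(b, shift=1))          # flattened roll: [[5, 0, 1], [2, 3, 4]]
    print(backend.roll(b, shift=1, axis=1))  # per-row roll:   [[2, 0, 1], [5, 3, 4]]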
@@ -429,15 +726,67 @@ if _BACKEND == config.Backend.JAX:
   import jax
   import jax.numpy as jax_ops
   import tensorflow_probability.substrates.jax as tfp_jax
+  from jax import tree_util

   class ExtensionType:
-    """A JAX-compatible stand-in for tf.experimental.ExtensionType.
+    """A JAX-compatible stand-in for tf.experimental.ExtensionType.

-
-
-
+    This class registers itself as a JAX Pytree node, allowing it to be passed
+    through JIT-compiled functions.
+    """
+
+    def __init_subclass__(cls, **kwargs):
+      super().__init_subclass__(**kwargs)
+      tree_util.register_pytree_node(
+          cls,
+          cls._tree_flatten,
+          cls._tree_unflatten,
       )

+    def _tree_flatten(self):
+      """Flattens the object for JAX tracing.
+
+      Fields containing JAX arrays (or convertibles) are treated as children.
+      Fields containing strings or NumPy string arrays must be treated as
+      auxiliary data because JAX cannot trace non-numeric/non-boolean data.
+
+      Returns:
+        A tuple of (children, aux_data), where children are the
+        tracer-compatible parts of the object, and aux_data contains auxiliary
+        information needed for unflattening.
+      """
+      d = vars(self)
+      all_keys = sorted(d.keys())
+      children = []
+      aux = {}
+      children_keys = []
+
+      for k in all_keys:
+        v = d[k]
+        # Identify string data to prevent JAX tracing errors.
+        # 'S' is zero-terminated bytes (fixed-width), 'U' is unicode string.
+        is_numpy_string = isinstance(v, np.ndarray) and v.dtype.kind in (
+            "S",
+            "U",
+        )
+        is_plain_string = isinstance(v, str)
+
+        if is_numpy_string or is_plain_string:
+          aux[k] = v
+        else:
+          children.append(v)
+          children_keys.append(k)
+      return children, (aux, children_keys)
+
+    @classmethod
+    def _tree_unflatten(cls, aux_and_keys, children):
+      aux, children_keys = aux_and_keys
+      obj = cls.__new__(cls)
+      vars(obj).update(aux)
+      for k, v in zip(children_keys, children):
+        setattr(obj, k, v)
+      return obj
+
   class _JaxErrors:
     # pylint: disable=invalid-name
     ResourceExhaustedError = MemoryError
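What the pytree registration buys: any `ExtensionType` subclass can be flattened and rebuilt by JAX's tree machinery, with numeric fields as leaves and string fields carried as static aux data. A hedged sketch of my own (the `Media` class here is hypothetical, not from Meridian):

    import jax
    import jax.numpy as jnp
    from meridian import backend  # with the JAX backend selected

    class Media(backend.ExtensionType):  # registered as a pytree node on subclassing
      def __init__(self, values, name):
        self.values = values  # JAX array -> pytree child
        self.name = name      # str -> auxiliary (static) data

    m = Media(jnp.ones(3), "tv")
    leaves, treedef = jax.tree_util.tree_flatten(m)
    print(leaves)                       # [Array([1., 1., 1.], dtype=float32)]
    m2 = jax.tree_util.tree_unflatten(treedef, [x * 2 for x in leaves])
    print(m2.values, m2.name)           # doubled values, name carried through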
@@ -492,6 +841,26 @@ if _BACKEND == config.Backend.JAX:

   random = _JaxRandom()

+  @functools.partial(
+      jax.jit,
+      static_argnames=[
+          "joint_dist",
+          "n_chains",
+          "n_draws",
+          "num_adaptation_steps",
+          "dual_averaging_kwargs",
+          "max_tree_depth",
+          "unrolled_leapfrog_steps",
+          "parallel_iterations",
+      ],
+  )
+  def _jax_xla_windowed_adaptive_nuts(**kwargs):
+    """JAX-specific JIT wrapper for the NUTS sampler."""
+    kwargs["seed"] = random.prng_key(kwargs["seed"])
+    return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+  xla_windowed_adaptive_nuts = _jax_xla_windowed_adaptive_nuts
+
   _ops = jax_ops
   errors = _JaxErrors()
   Tensor = jax.Array
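The `static_argnames` list above marks the sampler's configuration (chain counts, tree depth, the joint distribution itself) as static so `jax.jit` specializes on those values instead of tracing them; changing any of them triggers a recompile. The same pattern in miniature (my sketch, not Meridian code):

    import functools
    import jax

    @functools.partial(jax.jit, static_argnames=["n_draws"])
    def draw(seed, n_draws):
      # n_draws fixes the output shape, so it must be static under jit.
      return jax.random.normal(jax.random.PRNGKey(seed), (n_draws,))

    print(draw(0, 5).shape)  # (5,); calling with n_draws=6 recompiles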
@@ -499,7 +868,7 @@ if _BACKEND == config.Backend.JAX:
   bijectors = tfp_jax.bijectors
   experimental = tfp_jax.experimental
   mcmc = tfp_jax.mcmc
-  _convert_to_tensor =
+  _convert_to_tensor = _jax_convert_to_tensor

   # Standardized Public API
   absolute = _ops.abs
@@ -507,18 +876,6 @@ if _BACKEND == config.Backend.JAX:
   arange = _jax_arange
   argmax = _jax_argmax
   boolean_mask = _jax_boolean_mask
-  concatenate = _ops.concatenate
-  stack = _ops.stack
-  split = _jax_split
-  zeros = _ops.zeros
-  zeros_like = _ops.zeros_like
-  ones = _ops.ones
-  ones_like = _ops.ones_like
-  repeat = _ops.repeat
-  reshape = _ops.reshape
-  tile = _ops.tile
-  where = _ops.where
-  broadcast_to = _ops.broadcast_to
   broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
   broadcast_to = _ops.broadcast_to
   cast = _jax_cast
@@ -527,6 +884,7 @@
   divide = _ops.divide
   divide_no_nan = _jax_divide_no_nan
   einsum = _ops.einsum
+  enable_op_determinism = _jax_enable_op_determinism
   equal = _ops.equal
   exp = _ops.exp
   expand_dims = _ops.expand_dims
@@ -534,12 +892,17 @@
   function = _jax_function_wrapper
   gather = _jax_gather
   get_indices_where = _jax_get_indices_where
+  get_seed_data = _jax_get_seed_data
   is_nan = _ops.isnan
   log = _ops.log
   make_ndarray = _jax_make_ndarray
   make_tensor_proto = _jax_make_tensor_proto
+  nanmean = jax_ops.nanmean
   nanmedian = _jax_nanmedian
+  nansum = jax_ops.nansum
+  nanvar = jax_ops.nanvar
   numpy_function = _jax_numpy_function
+  one_hot = _jax_one_hot
   ones = _ops.ones
   ones_like = _ops.ones_like
   rank = _ops.ndim
@@ -551,6 +914,8 @@
   reduce_sum = _ops.sum
   repeat = _ops.repeat
   reshape = _ops.reshape
+  roll = _jax_roll
+  split = _jax_split
   stack = _ops.stack
   tile = _jax_tile
   transpose = _jax_transpose
@@ -564,13 +929,14 @@
   newaxis = _ops.newaxis
   TensorShape = _jax_tensor_shape
   int32 = _ops.int32
+  string = np.bytes_
+
+  stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid

   def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
-
-    "
-    "
-    " integer directly to the sampling methods (e.g., `sample_prior`),"
-    " which will be used to create a JAX PRNGKey internally."
+    warnings.warn(
+        "backend.set_random_seed is a no-op in JAX. Randomness is managed "
+        "statelessly via PRNGKeys passed to sampling methods."
     )

 elif _BACKEND == config.Backend.TENSORFLOW:
@@ -638,29 +1004,25 @@ elif _BACKEND == config.Backend.TENSORFLOW:

   random = _TfRandom()

+  @tf_backend.function(autograph=False, jit_compile=True)
+  def _tf_xla_windowed_adaptive_nuts(**kwargs):
+    """TensorFlow-specific XLA wrapper for the NUTS sampler."""
+    return experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+  xla_windowed_adaptive_nuts = _tf_xla_windowed_adaptive_nuts
+
   tfd = tfp.distributions
   bijectors = tfp.bijectors
   experimental = tfp.experimental
   mcmc = tfp.mcmc
-
   _convert_to_tensor = _ops.convert_to_tensor
+
+  # Standardized Public API
   absolute = _ops.math.abs
   allclose = _ops.experimental.numpy.allclose
   arange = _tf_arange
   argmax = _tf_argmax
   boolean_mask = _tf_boolean_mask
-  concatenate = _ops.concat
-  stack = _ops.stack
-  split = _ops.split
-  zeros = _ops.zeros
-  zeros_like = _ops.zeros_like
-  ones = _ops.ones
-  ones_like = _ops.ones_like
-  repeat = _ops.repeat
-  reshape = _ops.reshape
-  tile = _ops.tile
-  where = _ops.where
-  broadcast_to = _ops.broadcast_to
   broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
   broadcast_to = _ops.broadcast_to
   cast = _ops.cast
@@ -669,6 +1031,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   divide = _ops.divide
   divide_no_nan = _ops.math.divide_no_nan
   einsum = _ops.einsum
+  enable_op_determinism = _ops.config.experimental.enable_op_determinism
   equal = _ops.equal
   exp = _ops.math.exp
   expand_dims = _ops.expand_dims
@@ -676,12 +1039,17 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   function = _tf_function_wrapper
   gather = _tf_gather
   get_indices_where = _tf_get_indices_where
+  get_seed_data = _tf_get_seed_data
   is_nan = _ops.math.is_nan
   log = _ops.math.log
   make_ndarray = _ops.make_ndarray
   make_tensor_proto = _ops.make_tensor_proto
+  nanmean = _tf_nanmean
   nanmedian = _tf_nanmedian
+  nansum = _tf_nansum
+  nanvar = _tf_nanvar
   numpy_function = _ops.numpy_function
+  one_hot = _ops.one_hot
   ones = _ops.ones
   ones_like = _ops.ones_like
   rank = _ops.rank
@@ -693,7 +1061,9 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   reduce_sum = _ops.reduce_sum
   repeat = _ops.repeat
   reshape = _ops.reshape
+  roll = _tf_roll
   set_random_seed = tf_backend.keras.utils.set_random_seed
+  split = _ops.split
   stack = _ops.stack
   tile = _ops.tile
   transpose = _ops.transpose
@@ -702,11 +1072,14 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   zeros = _ops.zeros
   zeros_like = _ops.zeros_like

+  stabilize_rf_roi_grid = _tf_stabilize_rf_roi_grid
+
   float32 = _ops.float32
   bool_ = _ops.bool
   newaxis = _ops.newaxis
   TensorShape = _ops.TensorShape
   int32 = _ops.int32
+  string = _ops.string

 else:
   raise ValueError(f"Unsupported backend: {_BACKEND}")
@@ -735,7 +1108,7 @@ def _extract_int_seed(s: Any) -> Optional[int]:
   # in unpredictable ways. The goal is to safely attempt extraction and fail
   # gracefully.
   except Exception:  # pylint: disable=broad-except
-
+    return None
   return None


@@ -819,6 +1192,14 @@ class _JaxRNGHandler(_BaseRNGHandler):
     if seed is None:
       return

+    if (
+        isinstance(seed, jax.Array)  # pylint: disable=undefined-variable
+        and seed.shape == (2,)
+        and seed.dtype == jax_ops.uint32  # pylint: disable=undefined-variable
+    ):
+      self._key = seed
+      return
+
     if self._int_seed is None:
       raise ValueError(
           "JAX backend requires a seed that is an integer or a scalar array,"
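The new branch lets callers hand the handler a raw JAX PRNGKey (a uint32 array of shape (2,), as returned by `jax.random.PRNGKey`) and have it adopted as the key state directly, instead of requiring an integer seed. A usage sketch of mine:

    import jax
    from meridian import backend  # with the JAX backend selected

    h1 = backend.RNGHandler(42)                     # integer seed, as before
    h2 = backend.RNGHandler(jax.random.PRNGKey(0))  # raw key, new fast path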
@@ -830,7 +1211,8 @@ class _JaxRNGHandler(_BaseRNGHandler):
   def get_kernel_seed(self) -> Any:
     if self._key is None:
       return None
-
+    _, subkey = random.split(self._key)
+    return int(jax.random.randint(subkey, (), 0, 2**31 - 1))  # pylint: disable=undefined-variable

   def get_next_seed(self) -> Any:
     if self._key is None:
@@ -853,35 +1235,22 @@ class _JaxRNGHandler(_BaseRNGHandler):
     return _JaxRNGHandler(np.asarray(new_seed_tensor).item())


-
-
-  """TensorFlow implementation.
+class _TFRNGHandler(_BaseRNGHandler):
+  """A stateless-style RNG handler for TensorFlow.

-
-  stateful `tf.random.Generator`-based implementation.
+  This handler canonicalizes any seed input into a single stateless seed Tensor.
   """

-  def __init__(self, seed: SeedType
-    """Initializes the TensorFlow
+  def __init__(self, seed: SeedType):
+    """Initializes the TensorFlow RNG handler.

     Args:
       seed: The initial seed. Can be an integer, a sequence of two integers, a
-        corresponding Tensor, or None.
-
-        seed tensor is used directly, and the standard initialization logic for
-        the public `seed` argument is bypassed.
-
-    Raises:
-      ValueError: If `seed` is a sequence with a length other than 2.
+        corresponding Tensor, or None. It will be sanitized and stored
+        internally as a single stateless seed Tensor.
     """
     super().__init__(seed)
-    self.
-
-    if _sanitized_seed is not None:
-      # Internal path: A pre-sanitized seed was provided by a trusted source
-      # so we adopt it directly.
-      self._tf_sanitized_seed = _sanitized_seed
-      return
+    self._seed_state: Optional["_tf.Tensor"] = None

     if seed is None:
       return
@@ -899,63 +1268,33 @@ class _TFLegacyRNGHandler(_BaseRNGHandler):
     else:
       seed_to_sanitize = seed

-    self.
+    self._seed_state = random.sanitize_seed(seed_to_sanitize)

   def get_kernel_seed(self) -> Any:
-
+    """Provides the current stateless seed of the handler without advancing it."""
+    return self._seed_state

   def get_next_seed(self) -> Any:
-    """Returns
-
-    Returns:
-      The original integer seed provided during initialization.
-
-    Raises:
-      RuntimeError: If the handler was not initialized with a scalar integer
-        seed, which is required for the legacy prior sampling path.
-    """
-    if self._seed_input is None:
+    """Returns a new unique seed and advances the handler's internal state."""
+    if self._seed_state is None:
       return None
+    new_seed, next_state = random.stateless_split(self._seed_state)
+    self._seed_state = next_state
+    return new_seed

-
-
-
-
-
-
-
-  def advance_handler(self) -> "_TFLegacyRNGHandler":
-    """Creates a new handler by incrementing the sanitized seed by 1.
-
-    Returns:
-      A new `_TFLegacyRNGHandler` instance with an incremented seed state.
-
-    Raises:
-      RuntimeError: If the handler's sanitized seed was not initialized.
-    """
-    if self._seed_input is None:
-      return _TFLegacyRNGHandler(None)
-
-    if self._tf_sanitized_seed is None:
-      # Should be caught during init, but included for defensive programming.
-      raise RuntimeError("RNGHandler sanitized seed not initialized.")
-
-    new_sanitized_seed = self._tf_sanitized_seed + 1
-
-    # Create a new handler instance, passing the original seed input (to
-    # preserve state like `_int_seed`) and injecting the new sanitized seed
-    # via the private constructor argument.
-    return _TFLegacyRNGHandler(
-        self._seed_input, _sanitized_seed=new_sanitized_seed
-    )
+  def advance_handler(self) -> "_TFRNGHandler":
+    """Creates a new handler from a new seed and advances the current handler."""
+    if self._seed_state is None:
+      return _TFRNGHandler(None)
+    seed_for_new_handler, next_state = random.stateless_split(self._seed_state)
+    self._seed_state = next_state
+    return _TFRNGHandler(seed_for_new_handler)


 if _BACKEND == config.Backend.JAX:
   RNGHandler = _JaxRNGHandler
 elif _BACKEND == config.Backend.TENSORFLOW:
-  RNGHandler =
-      _TFLegacyRNGHandler  # TODO: Replace with _TFRNGHandler
-  )
+  RNGHandler = _TFRNGHandler
 else:
   raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
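Net effect of this last change: the legacy handler's increment-by-one seed advancement is gone, and `RNGHandler` on the TF backend now mirrors the JAX discipline, deriving every new seed via `stateless_split` so no two consumers share one. A sketch of the intended usage (mine, under the TF backend):

    from meridian import backend

    handler = backend.RNGHandler(42)
    s1 = handler.get_next_seed()       # splits the state, returns a fresh seed
    s2 = handler.get_next_seed()       # different from s1
    child = handler.advance_handler()  # independent handler for a sub-task
    k = handler.get_kernel_seed()      # current state, without advancing it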