google-meridian 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/METADATA +2 -2
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/RECORD +24 -24
- meridian/analysis/analyzer.py +101 -37
- meridian/analysis/optimizer.py +132 -88
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/visualizer.py +16 -5
- meridian/backend/__init__.py +475 -14
- meridian/backend/config.py +75 -16
- meridian/backend/test_utils.py +87 -1
- meridian/constants.py +14 -9
- meridian/data/input_data.py +7 -2
- meridian/data/test_utils.py +5 -3
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +10 -9
- meridian/model/eda/eda_engine.py +440 -11
- meridian/model/knots.py +1 -1
- meridian/model/model_test_data.py +15 -9
- meridian/model/posterior_sampler.py +365 -365
- meridian/model/prior_distribution.py +104 -39
- meridian/model/transformers.py +5 -5
- meridian/version.py +1 -1
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
meridian/backend/__init__.py
CHANGED
|
@@ -14,13 +14,16 @@
|
|
|
14
14
|
|
|
15
15
|
"""Backend Abstraction Layer for Meridian."""
|
|
16
16
|
|
|
17
|
+
import abc
|
|
18
|
+
import functools
|
|
17
19
|
import os
|
|
18
|
-
from typing import Any, Optional,
|
|
20
|
+
from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
|
19
21
|
|
|
20
22
|
from meridian.backend import config
|
|
21
23
|
import numpy as np
|
|
22
24
|
from typing_extensions import Literal
|
|
23
25
|
|
|
26
|
+
|
|
24
27
|
# The conditional imports in this module are a deliberate design choice for the
|
|
25
28
|
# backend abstraction layer. The TFP-on-JAX substrate provides a nearly
|
|
26
29
|
# identical API to the standard TFP library, making an alias-based approach more
|
|
@@ -28,6 +31,19 @@ from typing_extensions import Literal
|
|
|
28
31
|
# extensive boilerplate.
|
|
29
32
|
# pylint: disable=g-import-not-at-top,g-bad-import-order
|
|
30
33
|
|
|
34
|
+
# Canonical dtype names returned by dtype promotion for float / integer inputs.
_DEFAULT_FLOAT = "float32"
_DEFAULT_INT = "int64"

# Keyword-argument names understood by the `function` wrappers: `jit_compile`
# is a `tf.function` option, the `static_*` names are `jax.jit` options.
_ARG_JIT_COMPILE = "jit_compile"
_ARG_STATIC_ARGNUMS = "static_argnums"
_ARG_STATIC_ARGNAMES = "static_argnames"

# The repetition keyword of `tile`: TensorFlow says `multiples`, jnp says `reps`.
_TENSORFLOW_TILE_KEYWORD = "multiples"
_JAX_TILE_KEYWORD = "reps"

# Dtype and (exclusive) upper bound used for integer seed draws.
_DEFAULT_SEED_DTYPE = "int32"
_MAX_INT32 = np.iinfo(np.int32).max
|
|
46
|
+
|
|
31
47
|
if TYPE_CHECKING:
|
|
32
48
|
import dataclasses
|
|
33
49
|
import jax as _jax
|
|
@@ -35,6 +51,8 @@ if TYPE_CHECKING:
|
|
|
35
51
|
|
|
36
52
|
TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]
|
|
37
53
|
|
|
54
|
+
SeedType = Any
|
|
55
|
+
|
|
38
56
|
|
|
39
57
|
def standardize_dtype(dtype: Any) -> str:
|
|
40
58
|
"""Converts a backend-specific dtype to a standard string representation.
|
|
@@ -88,8 +106,8 @@ def result_type(*types: Any) -> str:
|
|
|
88
106
|
standardized_types.append(str(t))
|
|
89
107
|
|
|
90
108
|
if any("float" in t for t in standardized_types):
|
|
91
|
-
return
|
|
92
|
-
return
|
|
109
|
+
return _DEFAULT_FLOAT
|
|
110
|
+
return _DEFAULT_INT
|
|
93
111
|
|
|
94
112
|
|
|
95
113
|
def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
|
|
@@ -163,6 +181,74 @@ def _jax_divide_no_nan(x, y):
|
|
|
163
181
|
return jnp.where(y != 0, jnp.divide(x, y), 0.0)
|
|
164
182
|
|
|
165
183
|
|
|
184
|
+
def _jax_function_wrapper(func=None, **kwargs):
  """A wrapper for jax.jit that handles TF-like args and static args.

  This wrapper provides compatibility with TensorFlow's `tf.function` arguments
  and improves ergonomics when decorating class methods in JAX.

  Unless `static_argnums` is passed explicitly, it defaults to `(0,)`. This
  assumes the function is a method whose first argument (`self` or `cls`)
  should be treated as static. The default also applies when only
  `static_argnames` is given.

  To disable this behavior for plain functions, explicitly provide an empty
  tuple (or None): `@backend.function(static_argnums=())`. Any other value,
  including a bare integer such as `static_argnums=0`, is forwarded to
  `jax.jit` unchanged.

  Args:
    func: The function to wrap. If None, a decorator is returned instead.
    **kwargs: Keyword arguments passed to jax.jit. TF-specific arguments (like
      `jit_compile`) are ignored.

  Returns:
    The wrapped function if `func` is given, otherwise a decorator.
  """
  jit_kwargs = dict(kwargs)

  # `jit_compile` is a tf.function option, not a jax.jit argument.
  jit_kwargs.pop(_ARG_JIT_COMPILE, None)

  if _ARG_STATIC_ARGNUMS in jit_kwargs:
    static_argnums = jit_kwargs[_ARG_STATIC_ARGNUMS]
    # Only None or an empty collection means "no static positional args". A
    # bare integer is a valid jax.jit value and must be kept even when it is
    # falsy (static_argnums=0 marks the first argument as static).
    opt_out = static_argnums is None or (
        not isinstance(static_argnums, (int, np.integer)) and not static_argnums
    )
    if opt_out:
      jit_kwargs.pop(_ARG_STATIC_ARGNUMS)
  else:
    jit_kwargs[_ARG_STATIC_ARGNUMS] = (0,)

  decorator = functools.partial(jax.jit, **jit_kwargs)

  if func is not None:
    return decorator(func)
  return decorator
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _tf_function_wrapper(func=None, **kwargs):
  """Wraps `tf.function`, dropping options that only `jax.jit` understands.

  Args:
    func: The function to wrap. When falsy, a decorator is returned instead.
    **kwargs: Options forwarded to `tf.function`; `static_argnums` and
      `static_argnames` are discarded.

  Returns:
    The compiled function, or a decorator when `func` is not supplied.
  """
  import tensorflow as tf

  # The static-argument options are jax.jit concepts with no tf.function
  # equivalent, so they are filtered out before forwarding.
  jax_only = (_ARG_STATIC_ARGNAMES, _ARG_STATIC_ARGNUMS)
  tf_kwargs = {
      option: setting
      for option, setting in kwargs.items()
      if option not in jax_only
  }

  decorator = tf.function(**tf_kwargs)
  return decorator(func) if func else decorator
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _jax_nanmedian(a, axis=None):
  """Returns the median of `a` along `axis`, skipping NaNs (JAX backend)."""
  from jax import numpy as jnp

  median = jnp.nanmedian(a, axis=axis)
  return median
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _tf_nanmedian(a, axis=None):
  """Returns the NaN-ignoring median of `a` along `axis` (TensorFlow backend).

  TensorFlow has no built-in nanmedian here, so the computation is delegated to
  `np.nanmedian` through `tf.numpy_function`. The result is cast back to the
  input's dtype.
  """
  import tensorflow as tf

  def _numpy_nanmedian(x):
    # Cast so the output matches the `Tout` dtype declared below.
    return np.nanmedian(x, axis=axis).astype(x.dtype)

  return tf.numpy_function(_numpy_nanmedian, [a], a.dtype)
|
|
250
|
+
|
|
251
|
+
|
|
166
252
|
def _jax_numpy_function(*args, **kwargs): # pylint: disable=unused-argument
|
|
167
253
|
raise NotImplementedError(
|
|
168
254
|
"backend.numpy_function is not implemented for the JAX backend."
|
|
@@ -195,6 +281,26 @@ def _tf_get_indices_where(condition):
|
|
|
195
281
|
return tf.where(condition)
|
|
196
282
|
|
|
197
283
|
|
|
284
|
+
def _jax_split(value, num_or_size_splits, axis=0):
  """JAX implementation of `split` accepting a split count or explicit sizes.

  Follows the TensorFlow `split` contract:
    * A scalar integer (Python int, NumPy integer, or 0-d integer array) splits
      `value` into that many equal pieces along `axis`.
    * A 1-D sequence of sizes splits `value` into pieces of those lengths. At
      most one entry may be -1; that piece absorbs the rest of the axis.

  Split points are computed with NumPy from concrete Python values. If they
  were computed with `jnp`, they would be staged into the trace under
  `jax.jit`, and `jnp.split` requires concrete (static) indices.

  Args:
    value: The array to split.
    num_or_size_splits: Number of equal splits, or a sequence of piece sizes.
    axis: The axis along which to split.

  Returns:
    A list of arrays, as returned by `jnp.split`.

  Raises:
    ValueError: If more than one entry of `num_or_size_splits` is -1.
  """
  import jax.numpy as jnp

  # Scalar: number of equal sections.
  if np.ndim(num_or_size_splits) == 0:
    return jnp.split(value, int(num_or_size_splits), axis=axis)

  sizes = [int(size) for size in np.asarray(num_or_size_splits).reshape(-1)]

  if sizes.count(-1) > 1:
    raise ValueError(
        "At most one entry of `num_or_size_splits` may be -1, got"
        f" {num_or_size_splits!r}."
    )
  if -1 in sizes:
    # Infer the missing size from the (static) length of the split axis.
    known_total = sum(size for size in sizes if size != -1)
    sizes[sizes.index(-1)] = np.shape(value)[axis] - known_total

  # Convert piece sizes to split points; keep them as plain Python ints so they
  # remain static under jax.jit.
  indices = np.cumsum(sizes)[:-1].tolist()
  return jnp.split(value, indices, axis=axis)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _jax_tile(*args, **kwargs):
  """Calls `jnp.tile`, also accepting TensorFlow's `multiples` keyword.

  `tf.tile` names the repetition argument `multiples` while `jnp.tile` calls it
  `reps`. A `multiples` keyword is renamed before delegating; positional
  arguments are passed through untouched.
  """
  from jax import numpy as jnp

  if _TENSORFLOW_TILE_KEYWORD not in kwargs:
    return jnp.tile(*args, **kwargs)

  repetitions = kwargs.pop(_TENSORFLOW_TILE_KEYWORD)
  kwargs[_JAX_TILE_KEYWORD] = repetitions
  return jnp.tile(*args, **kwargs)
|
|
302
|
+
|
|
303
|
+
|
|
198
304
|
def _jax_unique_with_counts(x):
|
|
199
305
|
"""JAX implementation for unique_with_counts."""
|
|
200
306
|
import jax.numpy as jnp
|
|
@@ -304,6 +410,13 @@ def _jax_tensor_shape(dims):
|
|
|
304
410
|
return tuple(dims)
|
|
305
411
|
|
|
306
412
|
|
|
413
|
+
def _jax_transpose(a, perm=None):
  """Transposes `a` with `jnp.transpose`, using TensorFlow's `perm` keyword.

  `jnp.transpose` names the permutation `axes`; `perm=None` reverses all axes.
  """
  from jax import numpy as jnp

  axes = perm
  return jnp.transpose(a, axes=axes)
|
|
418
|
+
|
|
419
|
+
|
|
307
420
|
# --- Backend Initialization ---
|
|
308
421
|
_BACKEND = config.get_backend()
|
|
309
422
|
|
|
@@ -331,13 +444,60 @@ if _BACKEND == config.Backend.JAX:
|
|
|
331
444
|
InvalidArgumentError = ValueError
|
|
332
445
|
# pylint: enable=invalid-name
|
|
333
446
|
|
|
447
|
+
class _JaxRandom:
  """JAX-based random number utilities used by `RNGHandler`.

  Every member is a `staticmethod`, so an instance serves purely as a
  namespace. Operations tied to TensorFlow's seeding model raise
  `NotImplementedError`.
  """

  @staticmethod
  def prng_key(seed):
    """Builds a key with `jax.random.PRNGKey(seed)`."""
    key = jax.random.PRNGKey(seed)
    return key

  @staticmethod
  def split(key):
    """Splits `key` with `jax.random.split` using its default subkey count."""
    subkeys = jax.random.split(key)
    return subkeys

  @staticmethod
  def generator_from_seed(seed):
    """Unsupported: stateful Generators are not used with JAX."""
    raise NotImplementedError("JAX backend does not use Generators.")

  @staticmethod
  def stateless_split(seed: Any, num: int = 2):
    """Unsupported for JAX; split an existing key with `split` instead."""
    message = (
        "Direct stateless splitting from an integer seed is not the primary"
        " pattern used in the JAX backend. Use `backend.random.split(key)`"
        " instead."
    )
    raise NotImplementedError(message)

  @staticmethod
  def stateless_randint(key, shape, minval, maxval, dtype=jax_ops.int32):
    """Draws integers in [minval, maxval) with `jax.random.randint`."""
    return jax.random.randint(
        key, shape, dtype=dtype, minval=minval, maxval=maxval
    )

  @staticmethod
  def stateless_uniform(
      key, shape, dtype=jax_ops.float32, minval=0.0, maxval=1.0
  ):
    """Draws uniform samples with `jax.random.uniform`.

    NOTE(review): the positional order here (`dtype` third) differs from the
    TensorFlow variant `(seed, shape, minval, maxval, dtype)`. Pass `dtype`,
    `minval` and `maxval` by keyword for portable calls.
    """
    samples = jax.random.uniform(
        key, shape=shape, minval=minval, maxval=maxval, dtype=dtype
    )
    return samples

  @staticmethod
  def sanitize_seed(seed):
    """Delegates to `tfp_jax.random.sanitize_seed`."""
    return tfp_jax.random.sanitize_seed(seed)
|
|
492
|
+
|
|
493
|
+
random = _JaxRandom()
|
|
494
|
+
|
|
334
495
|
_ops = jax_ops
|
|
335
496
|
errors = _JaxErrors()
|
|
336
497
|
Tensor = jax.Array
|
|
337
498
|
tfd = tfp_jax.distributions
|
|
338
499
|
bijectors = tfp_jax.bijectors
|
|
339
500
|
experimental = tfp_jax.experimental
|
|
340
|
-
random = tfp_jax.random
|
|
341
501
|
mcmc = tfp_jax.mcmc
|
|
342
502
|
_convert_to_tensor = _ops.asarray
|
|
343
503
|
|
|
@@ -349,7 +509,7 @@ if _BACKEND == config.Backend.JAX:
|
|
|
349
509
|
boolean_mask = _jax_boolean_mask
|
|
350
510
|
concatenate = _ops.concatenate
|
|
351
511
|
stack = _ops.stack
|
|
352
|
-
split =
|
|
512
|
+
split = _jax_split
|
|
353
513
|
zeros = _ops.zeros
|
|
354
514
|
zeros_like = _ops.zeros_like
|
|
355
515
|
ones = _ops.ones
|
|
@@ -358,7 +518,6 @@ if _BACKEND == config.Backend.JAX:
|
|
|
358
518
|
reshape = _ops.reshape
|
|
359
519
|
tile = _ops.tile
|
|
360
520
|
where = _ops.where
|
|
361
|
-
transpose = _ops.transpose
|
|
362
521
|
broadcast_to = _ops.broadcast_to
|
|
363
522
|
broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
|
|
364
523
|
broadcast_to = _ops.broadcast_to
|
|
@@ -372,13 +531,14 @@ if _BACKEND == config.Backend.JAX:
|
|
|
372
531
|
exp = _ops.exp
|
|
373
532
|
expand_dims = _ops.expand_dims
|
|
374
533
|
fill = _jax_fill
|
|
375
|
-
function =
|
|
534
|
+
function = _jax_function_wrapper
|
|
376
535
|
gather = _jax_gather
|
|
377
536
|
get_indices_where = _jax_get_indices_where
|
|
378
537
|
is_nan = _ops.isnan
|
|
379
538
|
log = _ops.log
|
|
380
539
|
make_ndarray = _jax_make_ndarray
|
|
381
540
|
make_tensor_proto = _jax_make_tensor_proto
|
|
541
|
+
nanmedian = _jax_nanmedian
|
|
382
542
|
numpy_function = _jax_numpy_function
|
|
383
543
|
ones = _ops.ones
|
|
384
544
|
ones_like = _ops.ones_like
|
|
@@ -391,10 +551,9 @@ if _BACKEND == config.Backend.JAX:
|
|
|
391
551
|
reduce_sum = _ops.sum
|
|
392
552
|
repeat = _ops.repeat
|
|
393
553
|
reshape = _ops.reshape
|
|
394
|
-
split = _ops.split
|
|
395
554
|
stack = _ops.stack
|
|
396
|
-
tile =
|
|
397
|
-
transpose =
|
|
555
|
+
tile = _jax_tile
|
|
556
|
+
transpose = _jax_transpose
|
|
398
557
|
unique_with_counts = _jax_unique_with_counts
|
|
399
558
|
where = _ops.where
|
|
400
559
|
zeros = _ops.zeros
|
|
@@ -404,6 +563,7 @@ if _BACKEND == config.Backend.JAX:
|
|
|
404
563
|
bool_ = _ops.bool_
|
|
405
564
|
newaxis = _ops.newaxis
|
|
406
565
|
TensorShape = _jax_tensor_shape
|
|
566
|
+
int32 = _ops.int32
|
|
407
567
|
|
|
408
568
|
def set_random_seed(seed: int) -> None: # pylint: disable=unused-argument
|
|
409
569
|
raise NotImplementedError(
|
|
@@ -423,10 +583,64 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
423
583
|
Tensor = tf_backend.Tensor
|
|
424
584
|
ExtensionType = _ops.experimental.ExtensionType
|
|
425
585
|
|
|
586
|
+
class _TfRandom:
  """TensorFlow-based random number utilities used by `RNGHandler`.

  Every member is a `staticmethod`, so an instance serves purely as a
  namespace. Explicit PRNG-key operations raise `NotImplementedError` because
  this backend does not use them.
  """

  @staticmethod
  def prng_key(seed):
    """Unsupported: this backend does not sample from explicit PRNG keys."""
    message = (
        "TensorFlow backend does not use explicit PRNG keys for"
        " standard sampling."
    )
    raise NotImplementedError(message)

  @staticmethod
  def split(key):
    """Unsupported: this backend does not split explicit keys."""
    message = "TensorFlow backend does not implement explicit key splitting."
    raise NotImplementedError(message)

  @staticmethod
  def generator_from_seed(seed):
    """Returns `tf.random.Generator.from_seed(seed)`."""
    generator_cls = tf_backend.random.Generator
    return generator_cls.from_seed(seed)

  @staticmethod
  def stateless_split(
      seed: "tf_backend.Tensor", num: int = 2
  ) -> "tf_backend.Tensor":
    """Derives `num` new stateless seeds from `seed`."""
    experimental = tf_backend.random.experimental
    return experimental.stateless_split(seed, num=num)

  @staticmethod
  def stateless_randint(
      seed: "tf_backend.Tensor",
      shape: "TensorShapeInstance",
      minval: int,
      maxval: int,
      dtype: Any = _DEFAULT_SEED_DTYPE,
  ) -> "tf_backend.Tensor":
    """Draws integers in [minval, maxval) deterministically from `seed`.

    Integer draws go through `stateless_uniform`, which produces integers when
    `dtype` is an integer type.
    """
    return tf_backend.random.stateless_uniform(
        seed=seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
    )

  @staticmethod
  def stateless_uniform(
      seed, shape, minval=0, maxval=None, dtype=tf_backend.float32
  ):
    """Draws uniform samples deterministically from `seed`."""
    return tf_backend.random.stateless_uniform(
        seed=seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval
    )

  @staticmethod
  def sanitize_seed(seed):
    """Delegates to `tfp.random.sanitize_seed`."""
    return tfp.random.sanitize_seed(seed)
|
|
638
|
+
|
|
639
|
+
random = _TfRandom()
|
|
640
|
+
|
|
426
641
|
tfd = tfp.distributions
|
|
427
642
|
bijectors = tfp.bijectors
|
|
428
643
|
experimental = tfp.experimental
|
|
429
|
-
random = tfp.random
|
|
430
644
|
mcmc = tfp.mcmc
|
|
431
645
|
|
|
432
646
|
_convert_to_tensor = _ops.convert_to_tensor
|
|
@@ -446,7 +660,6 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
446
660
|
reshape = _ops.reshape
|
|
447
661
|
tile = _ops.tile
|
|
448
662
|
where = _ops.where
|
|
449
|
-
transpose = _ops.transpose
|
|
450
663
|
broadcast_to = _ops.broadcast_to
|
|
451
664
|
broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
|
|
452
665
|
broadcast_to = _ops.broadcast_to
|
|
@@ -460,13 +673,14 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
460
673
|
exp = _ops.math.exp
|
|
461
674
|
expand_dims = _ops.expand_dims
|
|
462
675
|
fill = _tf_fill
|
|
463
|
-
function =
|
|
676
|
+
function = _tf_function_wrapper
|
|
464
677
|
gather = _tf_gather
|
|
465
678
|
get_indices_where = _tf_get_indices_where
|
|
466
679
|
is_nan = _ops.math.is_nan
|
|
467
680
|
log = _ops.math.log
|
|
468
681
|
make_ndarray = _ops.make_ndarray
|
|
469
682
|
make_tensor_proto = _ops.make_tensor_proto
|
|
683
|
+
nanmedian = _tf_nanmedian
|
|
470
684
|
numpy_function = _ops.numpy_function
|
|
471
685
|
ones = _ops.ones
|
|
472
686
|
ones_like = _ops.ones_like
|
|
@@ -480,7 +694,6 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
480
694
|
repeat = _ops.repeat
|
|
481
695
|
reshape = _ops.reshape
|
|
482
696
|
set_random_seed = tf_backend.keras.utils.set_random_seed
|
|
483
|
-
split = _ops.split
|
|
484
697
|
stack = _ops.stack
|
|
485
698
|
tile = _ops.tile
|
|
486
699
|
transpose = _ops.transpose
|
|
@@ -493,12 +706,260 @@ elif _BACKEND == config.Backend.TENSORFLOW:
|
|
|
493
706
|
bool_ = _ops.bool
|
|
494
707
|
newaxis = _ops.newaxis
|
|
495
708
|
TensorShape = _ops.TensorShape
|
|
709
|
+
int32 = _ops.int32
|
|
496
710
|
|
|
497
711
|
else:
|
|
498
712
|
raise ValueError(f"Unsupported backend: {_BACKEND}")
|
|
499
713
|
# pylint: enable=g-import-not-at-top,g-bad-import-order
|
|
500
714
|
|
|
501
715
|
|
|
716
|
+
def _extract_int_seed(s: Any) -> Optional[int]:
  """Attempts to extract a scalar Python integer from various input types.

  Booleans are rejected even though `bool` subclasses `int`. This matches the
  NumPy path below, where boolean scalars/arrays are not integer dtypes.

  Args:
    s: The input seed, which can be an int, Tensor, or array-like object.

  Returns:
    A plain Python `int` if a scalar integer can be extracted, otherwise None.
  """
  if isinstance(s, bool):
    return None
  if isinstance(s, int):
    # Normalize int subclasses (e.g. IntEnum) to a plain `int`.
    return int(s)

  try:
    value = np.asarray(s)

    if value.ndim == 0 and np.issubdtype(value.dtype, np.integer):
      return int(value)
  # A broad exception is used here because the input `s` can be of many types
  # (e.g., JAX PRNGKey) which may cause np.asarray or other operations to fail
  # in unpredictable ways. The goal is to safely attempt extraction and fail
  # gracefully.
  except Exception:  # pylint: disable=broad-except
    pass
  return None
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
class _BaseRNGHandler(abc.ABC):
  """Abstract base for backend-agnostic random-number-generation state.

  Concrete subclasses offer a stateful-looking interface for consuming
  randomness. They hide whether the backend works with explicit JAX PRNG keys
  or with TensorFlow seeds.

  Attributes:
    _seed_input: The seed object exactly as given to the constructor.
    _int_seed: `_seed_input` as a Python integer when `_extract_int_seed` can
      derive one, otherwise None.
  """

  def __init__(self, seed: SeedType):
    """Records the raw seed together with its scalar-integer form.

    Args:
      seed: The initial seed; which types are accepted depends on the backend.
        JAX requires an integer. TensorFlow also accepts a sequence of two
        integers or a Tensor. With None the handler is a no-op that returns
        None for every seed request.
    """
    extracted: Optional[int] = _extract_int_seed(seed)
    self._seed_input = seed
    self._int_seed = extracted

  @abc.abstractmethod
  def get_kernel_seed(self) -> Any:
    """Returns the current seed/key, sanitized for the backend's MCMC kernel.

    Implementations must leave the handler's internal state unchanged.

    Returns:
      A backend-specific seed object (e.g., a JAX PRNGKey or a TF Tensor).
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_next_seed(self) -> Any:
    """Returns the seed object for the next sequential random operation.

    Used mainly for prior sampling; implementations typically advance the
    internal state.

    Returns:
      A backend-specific seed object for a single random operation.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def advance_handler(self) -> "_BaseRNGHandler":
    """Derives a new, independent handler from this handler's state.

    The derivation must be deterministic given the current state.

    Returns:
      A new, independent `RNGHandler` instance.
    """
    raise NotImplementedError
|
|
801
|
+
|
|
802
|
+
|
|
803
|
+
class _JaxRNGHandler(_BaseRNGHandler):
  """JAX RNG handler that derives all randomness by splitting a single key."""

  def __init__(self, seed: SeedType):
    """Builds the handler's PRNG key from an integer seed.

    Args:
      seed: A Python integer, a scalar integer Tensor/array, or None. With None
        no key is created and every seed request yields None.

    Raises:
      ValueError: If `seed` is neither None nor a scalar integer.
    """
    super().__init__(seed)
    self._key: Optional["_jax.Array"] = None

    if seed is not None:
      if self._int_seed is None:
        raise ValueError(
            "JAX backend requires a seed that is an integer or a scalar array,"
            f" but got: {type(seed)} with value {seed!r}"
        )
      self._key = random.prng_key(self._int_seed)

  def get_kernel_seed(self) -> Any:
    """Returns the sanitized current key without consuming it, or None."""
    return None if self._key is None else random.sanitize_seed(self._key)

  def get_next_seed(self) -> Any:
    """Splits the current key, keeps one half, and hands out the other."""
    if self._key is None:
      return None
    retained_key, fresh_subkey = random.split(self._key)
    self._key = retained_key
    return fresh_subkey

  def advance_handler(self) -> "_JaxRNGHandler":
    """Spawns a child handler seeded by an integer drawn from a split subkey."""
    if self._key is None:
      return _JaxRNGHandler(None)

    self._key, child_key = random.split(self._key)
    # Draw the child's integer seed in [0, _MAX_INT32) so it round-trips
    # through the integer-seed constructor path.
    child_seed = random.stateless_randint(
        key=child_key,
        shape=(),
        minval=0,
        maxval=_MAX_INT32,
        dtype=_DEFAULT_SEED_DTYPE,
    )
    return _JaxRNGHandler(np.asarray(child_seed).item())
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
# TODO: Replace with _TFRNGHandler
class _TFLegacyRNGHandler(_BaseRNGHandler):
  """TensorFlow implementation.

  TODO: This class should be removed and replaced with a correct,
  stateful `tf.random.Generator`-based implementation.
  """

  def __init__(self, seed: SeedType, *, _sanitized_seed: Optional[Any] = None):
    """Initializes the TensorFlow legacy RNG handler.

    Args:
      seed: The initial seed. Can be a scalar integer (a Python int, a NumPy
        integer scalar, or a 0-d integer array/Tensor), a sequence of two
        integers, a corresponding Tensor, or None.
      _sanitized_seed: For internal use only. If provided, this pre-computed
        seed tensor is used directly, and the standard initialization logic for
        the public `seed` argument is bypassed.

    Raises:
      ValueError: If `seed` is a sequence with a length other than 2.
    """
    super().__init__(seed)
    self._tf_sanitized_seed: Optional[Any] = None

    if _sanitized_seed is not None:
      # Internal path: A pre-sanitized seed was provided by a trusted source
      # so we adopt it directly.
      self._tf_sanitized_seed = _sanitized_seed
      return

    if seed is None:
      return

    if isinstance(seed, Sequence) and len(seed) != 2:
      raise ValueError(
          "Invalid seed: Must be either a single integer (stateful seed) or a"
          " pair of two integers (stateless seed). See"
          " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
          " for details."
      )

    if self._int_seed is not None:
      # Any scalar integer seed is expanded into the pair `(seed, seed)`.
      # `_int_seed` covers Python ints as well as NumPy integer scalars and 0-d
      # integer arrays/Tensors. Those must be paired the same way rather than
      # reaching `sanitize_seed` as a bare scalar.
      seed_to_sanitize = (self._int_seed, self._int_seed)
    else:
      seed_to_sanitize = seed

    self._tf_sanitized_seed = random.sanitize_seed(seed_to_sanitize)

  def get_kernel_seed(self) -> Any:
    """Returns the sanitized seed (None if the handler was not seeded)."""
    return self._tf_sanitized_seed

  def get_next_seed(self) -> Any:
    """Returns the original integer seed to preserve prior sampling behavior.

    Returns:
      The original integer seed provided during initialization, or None if the
      handler was created without a seed.

    Raises:
      RuntimeError: If the handler was not initialized with a scalar integer
        seed, which is required for the legacy prior sampling path.
    """
    if self._seed_input is None:
      return None

    if self._int_seed is None:
      raise RuntimeError(
          "RNGHandler was not initialized with a scalar integer seed, cannot"
          " provide seed for TensorFlow prior sampling."
      )
    return self._int_seed

  def advance_handler(self) -> "_TFLegacyRNGHandler":
    """Creates a new handler by incrementing the sanitized seed by 1.

    Returns:
      A new `_TFLegacyRNGHandler` instance with an incremented seed state.

    Raises:
      RuntimeError: If the handler's sanitized seed was not initialized.
    """
    if self._seed_input is None:
      return _TFLegacyRNGHandler(None)

    if self._tf_sanitized_seed is None:
      # Should be caught during init, but included for defensive programming.
      raise RuntimeError("RNGHandler sanitized seed not initialized.")

    new_sanitized_seed = self._tf_sanitized_seed + 1

    # Create a new handler instance, passing the original seed input (to
    # preserve state like `_int_seed`) and injecting the new sanitized seed
    # via the private constructor argument.
    return _TFLegacyRNGHandler(
        self._seed_input, _sanitized_seed=new_sanitized_seed
    )
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
# Bind the public `RNGHandler` name to the implementation for the backend
# selected at import time.
if _BACKEND == config.Backend.TENSORFLOW:
  # TODO: Replace with _TFRNGHandler
  RNGHandler = _TFLegacyRNGHandler
elif _BACKEND == config.Backend.JAX:
  RNGHandler = _JaxRNGHandler
else:
  raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
|
|
961
|
+
|
|
962
|
+
|
|
502
963
|
def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor: # type: ignore
|
|
503
964
|
"""Converts input data to the currently active backend tensor type.
|
|
504
965
|
|