google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0

meridian/backend/__init__.py (new file)
@@ -0,0 +1,975 @@
# Copyright 2025 The Meridian Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Backend Abstraction Layer for Meridian."""

import abc
import functools
import os
from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union

from meridian.backend import config
import numpy as np
from typing_extensions import Literal


# The conditional imports in this module are a deliberate design choice for the
# backend abstraction layer. The TFP-on-JAX substrate provides a nearly
# identical API to the standard TFP library, making an alias-based approach
# more pragmatic than a full Abstract Base Class implementation, which would
# require extensive boilerplate.
# pylint: disable=g-import-not-at-top,g-bad-import-order

_DEFAULT_FLOAT = "float32"
_DEFAULT_INT = "int64"

_TENSORFLOW_TILE_KEYWORD = "multiples"
_JAX_TILE_KEYWORD = "reps"

_ARG_JIT_COMPILE = "jit_compile"
_ARG_STATIC_ARGNUMS = "static_argnums"
_ARG_STATIC_ARGNAMES = "static_argnames"

_DEFAULT_SEED_DTYPE = "int32"
_MAX_INT32 = np.iinfo(np.int32).max

if TYPE_CHECKING:
  import dataclasses
  import jax as _jax
  import tensorflow as _tf

  TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]

SeedType = Any


def standardize_dtype(dtype: Any) -> str:
  """Converts a backend-specific dtype to a standard string representation.

  Args:
    dtype: A backend-specific dtype object (e.g., tf.DType, np.dtype).

  Returns:
    A canonical string representation of the dtype (e.g., 'float32').
  """
  # Handle None explicitly, as np.dtype(None) defaults to float64.
  if dtype is None:
    return str(None)

  if hasattr(dtype, "as_numpy_dtype"):
    dtype = dtype.as_numpy_dtype

  try:
    return np.dtype(dtype).name
  except TypeError:
    return str(dtype)
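
# Example (editor's illustration, not part of the released file). Doctest-style
# lines showing the canonicalization above; the NumPy cases run as-is, and a
# tf.DType input would be routed through `as_numpy_dtype` first:
#
#   >>> standardize_dtype(np.float32)
#   'float32'
#   >>> standardize_dtype("int64")
#   'int64'
#   >>> standardize_dtype(None)  # None is preserved, not promoted to float64
#   'None'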


def result_type(*types: Any) -> str:
  """Infers the result dtype from a list of input types, backend-agnostically.

  This acts as the single source of truth for type promotion rules. The
  promotion logic is designed to be consistent across all backends.

  Rule: If any input is a float, the result is float32. Otherwise, the result
  is int64, matching NumPy's and JAX's default integer precision.

  Args:
    *types: A variable number of type objects (e.g., `<class 'int'>`,
      np.dtype('float32')).

  Returns:
    A string representing the promoted dtype.
  """
  standardized_types = []
  for t in types:
    if t is None:
      continue
    try:
      # Standardize the input type before checking promotion rules.
      standardized_types.append(standardize_dtype(t))
    except Exception:  # pylint: disable=broad-except
      # Fall back if standardization fails for an unexpected type.
      standardized_types.append(str(t))

  if any("float" in t for t in standardized_types):
    return _DEFAULT_FLOAT
  return _DEFAULT_INT
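
# Example (editor's illustration, not part of the released file). The promotion
# rule collapses everything to the two defaults; note that even float64 inputs
# map to float32:
#
#   >>> result_type(int, float)
#   'float32'
#   >>> result_type(np.dtype("float64"))
#   'float32'
#   >>> result_type(int, np.int32)
#   'int64'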


def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
  """Resolves the final dtype for an operation.

  If a dtype is explicitly provided, it's returned. Otherwise, it infers the
  dtype from the input arguments using the backend-agnostic `result_type`
  promotion rules.

  Args:
    dtype: The user-provided dtype, which may be None.
    *args: The input arguments to the operation, used for dtype inference.

  Returns:
    A string representing the resolved dtype.
  """
  if dtype is not None:
    return standardize_dtype(dtype)

  input_types = [
      getattr(arg, "dtype", type(arg)) for arg in args if arg is not None
  ]
  return result_type(*input_types)
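
# Example (editor's illustration, not part of the released file). An explicit
# dtype always wins; otherwise the arguments drive inference:
#
#   >>> _resolve_dtype("float32", 0, 10)  # explicit dtype
#   'float32'
#   >>> _resolve_dtype(None, 0, 10, 1)    # all ints -> default int
#   'int64'
#   >>> _resolve_dtype(None, 0.0, 10)     # any float -> default float
#   'float32'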


# --- Private Backend-Specific Implementations ---


def _jax_arange(
    start: Any,
    stop: Optional[Any] = None,
    step: Any = 1,
    dtype: Optional[Any] = None,
) -> "_jax.Array":
  """JAX implementation for arange."""
  # Import locally to make the function self-contained.
  import jax.numpy as jnp

  resolved_dtype = _resolve_dtype(dtype, start, stop, step)
  return jnp.arange(start, stop, step=step, dtype=resolved_dtype)


def _tf_arange(
    start: Any,
    stop: Optional[Any] = None,
    step: Any = 1,
    dtype: Optional[Any] = None,
) -> "_tf.Tensor":
  """TensorFlow implementation for arange."""
  import tensorflow as tf

  resolved_dtype = _resolve_dtype(dtype, start, stop, step)
  try:
    return tf.range(start, limit=stop, delta=step, dtype=resolved_dtype)
  except tf.errors.NotFoundError:
    # No kernel is registered for the resolved dtype on this device, so
    # compute in float32 and cast the result instead.
    result = tf.range(start, limit=stop, delta=step, dtype=tf.float32)
    return tf.cast(result, resolved_dtype)


def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
  """JAX implementation for cast."""
  return x.astype(dtype)


def _jax_divide_no_nan(x, y):
  """JAX implementation for divide_no_nan: returns 0 wherever `y` is 0."""
  import jax.numpy as jnp

  return jnp.where(y != 0, jnp.divide(x, y), 0.0)
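
# Example (editor's illustration, not part of the released file). This mirrors
# tf.math.divide_no_nan on the JAX backend:
#
#   >>> import jax.numpy as jnp
#   >>> _jax_divide_no_nan(jnp.array([1.0, 2.0]), jnp.array([2.0, 0.0]))
#   Array([0.5, 0. ], dtype=float32)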


def _jax_function_wrapper(func=None, **kwargs):
  """A wrapper for jax.jit that handles TF-like args and static args.

  This wrapper provides compatibility with TensorFlow's `tf.function`
  arguments and improves ergonomics when decorating class methods in JAX.

  By default, if neither `static_argnums` nor `static_argnames` is provided,
  it defaults `static_argnums` to `(0,)`. This assumes the function is a
  method where the first argument (`self` or `cls`) should be treated as
  static.

  To disable this behavior for plain functions, explicitly provide an empty
  tuple: `@backend.function(static_argnums=())`.

  Args:
    func: The function to wrap.
    **kwargs: Keyword arguments passed to jax.jit. TF-specific arguments (like
      `jit_compile`) are ignored.

  Returns:
    The wrapped function or a decorator.
  """
  # Import locally to make the function self-contained.
  import jax

  jit_kwargs = kwargs.copy()
  # `jit_compile` is a tf.function argument with no jax.jit equivalent.
  jit_kwargs.pop(_ARG_JIT_COMPILE, None)

  if _ARG_STATIC_ARGNUMS in jit_kwargs:
    # An explicit empty tuple means "no static arguments": drop the key so
    # jax.jit applies its own default.
    if not jit_kwargs[_ARG_STATIC_ARGNUMS]:
      jit_kwargs.pop(_ARG_STATIC_ARGNUMS)
  elif _ARG_STATIC_ARGNAMES not in jit_kwargs:
    jit_kwargs[_ARG_STATIC_ARGNUMS] = (0,)

  decorator = functools.partial(jax.jit, **jit_kwargs)

  if func:
    return decorator(func)
  return decorator
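
# Example (editor's illustration, not part of the released file). The default
# static first argument suits methods; plain functions opt out:
#
#   class Sampler:
#
#     @backend.function  # `self` (argument 0) is treated as static
#     def step(self, x):
#       return x * 2.0
#
#   @backend.function(static_argnums=())  # plain function: disable the default
#   def double(x):
#     return x * 2.0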


def _tf_function_wrapper(func=None, **kwargs):
  """A wrapper for tf.function that ignores JAX-specific arguments."""
  import tensorflow as tf

  kwargs.pop(_ARG_STATIC_ARGNAMES, None)
  kwargs.pop(_ARG_STATIC_ARGNUMS, None)

  decorator = tf.function(**kwargs)

  if func:
    return decorator(func)
  return decorator


def _jax_nanmedian(a, axis=None):
  """JAX implementation for nanmedian."""
  import jax.numpy as jnp

  return jnp.nanmedian(a, axis=axis)


def _tf_nanmedian(a, axis=None):
  """TensorFlow implementation for nanmedian using numpy_function."""
  import tensorflow as tf

  return tf.numpy_function(
      lambda x: np.nanmedian(x, axis=axis).astype(x.dtype), [a], a.dtype
  )


def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
  raise NotImplementedError(
      "backend.numpy_function is not implemented for the JAX backend."
  )


def _jax_make_tensor_proto(*args, **kwargs):  # pylint: disable=unused-argument
  raise NotImplementedError(
      "backend.make_tensor_proto is not implemented for the JAX backend."
  )


def _jax_make_ndarray(*args, **kwargs):  # pylint: disable=unused-argument
  raise NotImplementedError(
      "backend.make_ndarray is not implemented for the JAX backend."
  )


def _jax_get_indices_where(condition):
  """JAX implementation for get_indices_where."""
  import jax.numpy as jnp

  return jnp.stack(jnp.where(condition), axis=-1)


def _tf_get_indices_where(condition):
  """TensorFlow implementation for get_indices_where."""
  import tensorflow as tf

  return tf.where(condition)


def _jax_split(value, num_or_size_splits, axis=0):
  """JAX implementation for split that accepts size splits."""
  import jax.numpy as jnp

  if not isinstance(num_or_size_splits, int):
    # Convert TF-style section sizes into the cut indices jnp.split expects.
    indices = jnp.cumsum(jnp.array(num_or_size_splits))[:-1]
    return jnp.split(value, indices, axis=axis)

  return jnp.split(value, num_or_size_splits, axis=axis)
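
# Example (editor's illustration, not part of the released file). Matching
# tf.split semantics: an int gives equal sections, a list gives sizes:
#
#   >>> import jax.numpy as jnp
#   >>> [p.shape for p in _jax_split(jnp.zeros(6), 3)]
#   [(2,), (2,), (2,)]
#   >>> [p.shape for p in _jax_split(jnp.zeros(5), [2, 3])]
#   [(2,), (3,)]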


def _jax_tile(*args, **kwargs):
  """JAX wrapper for tile that supports the `multiples` keyword argument."""
  import jax.numpy as jnp

  if _TENSORFLOW_TILE_KEYWORD in kwargs:
    kwargs[_JAX_TILE_KEYWORD] = kwargs.pop(_TENSORFLOW_TILE_KEYWORD)
  return jnp.tile(*args, **kwargs)
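
# Example (editor's illustration, not part of the released file). TF-style
# calls keep working on the JAX backend:
#
#   >>> import jax.numpy as jnp
#   >>> _jax_tile(jnp.array([1, 2]), multiples=(3,)).tolist()
#   [1, 2, 1, 2, 1, 2]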


def _jax_unique_with_counts(x):
  """JAX implementation for unique_with_counts."""
  import jax.numpy as jnp

  y, counts = jnp.unique(x, return_counts=True)
  # The TF version returns a tuple of (y, idx, count). The idx is not used in
  # the calling code, so None is returned in its place to maintain the tuple
  # structure.
  return y, None, counts


def _tf_unique_with_counts(x):
  """TensorFlow implementation for unique_with_counts."""
  import tensorflow as tf

  return tf.unique_with_counts(x)


def _jax_boolean_mask(tensor, mask, axis=None):
  """JAX implementation for boolean_mask that supports an axis argument."""
  import jax.numpy as jnp

  if axis is None:
    axis = 0
  # Move the masked axis to the front, index with the boolean mask, then move
  # the axis back to its original position.
  tensor_swapped = jnp.moveaxis(tensor, axis, 0)
  masked = tensor_swapped[mask]
  return jnp.moveaxis(masked, 0, axis)
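
# Example (editor's illustration, not part of the released file). Masking
# along a non-leading axis, as tf.boolean_mask allows:
#
#   >>> import jax.numpy as jnp
#   >>> x = jnp.arange(6).reshape(2, 3)
#   >>> _jax_boolean_mask(x, jnp.array([True, False, True]), axis=1).shape
#   (2, 2)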


def _tf_boolean_mask(tensor, mask, axis=None):
  """TensorFlow implementation for boolean_mask."""
  import tensorflow as tf

  return tf.boolean_mask(tensor, mask, axis=axis)


def _jax_gather(params, indices):
  """JAX implementation for gather."""
  # JAX uses standard array indexing for gather operations.
  return params[indices]


def _tf_gather(params, indices):
  """TensorFlow implementation for gather."""
  import tensorflow as tf

  return tf.gather(params, indices)


def _jax_fill(dims, value):
  """JAX implementation for fill."""
  import jax.numpy as jnp

  return jnp.full(dims, value)


def _tf_fill(dims, value):
  """TensorFlow implementation for fill."""
  import tensorflow as tf

  return tf.fill(dims, value)


def _jax_argmax(tensor, axis=None):
  """JAX implementation for argmax, aligned with TensorFlow's default.

  This function finds the indices of the maximum values along a specified
  axis. Crucially, it mimics the default behavior of TensorFlow's `tf.argmax`,
  where if `axis` is `None`, the operation defaults to `axis=0`. This differs
  from NumPy's and JAX's native `argmax` behavior, which would flatten the
  array before finding the index.

  Args:
    tensor: The input JAX array.
    axis: An integer specifying the axis along which to find the index of the
      maximum value. If `None`, it defaults to `0` to match TensorFlow's
      behavior.

  Returns:
    A JAX array containing the indices of the maximum values.
  """
  import jax.numpy as jnp

  if axis is None:
    axis = 0
  return jnp.argmax(tensor, axis=axis)
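
# Example (editor's illustration, not part of the released file). The
# TF-compatible default matters for 2-D inputs:
#
#   >>> import jax.numpy as jnp
#   >>> x = jnp.array([[1, 9], [8, 2]])
#   >>> _jax_argmax(x).tolist()  # axis=None -> axis=0, like tf.argmax
#   [1, 0]
#   >>> int(jnp.argmax(x))       # native JAX would flatten instead
#   1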


def _tf_argmax(tensor, axis=None):
  """TensorFlow implementation for argmax."""
  import tensorflow as tf

  return tf.argmax(tensor, axis=axis)


def _jax_broadcast_dynamic_shape(shape_x, shape_y):
  """JAX implementation for broadcast_dynamic_shape."""
  import jax.numpy as jnp

  return jnp.broadcast_shapes(shape_x, shape_y)


def _jax_tensor_shape(dims):
  """JAX implementation for TensorShape."""
  if isinstance(dims, int):
    return (dims,)

  return tuple(dims)


def _jax_transpose(a, perm=None):
  """JAX wrapper for transpose to support the 'perm' keyword argument."""
  import jax.numpy as jnp

  return jnp.transpose(a, axes=perm)


# --- Backend Initialization ---
_BACKEND = config.get_backend()

# We expose standardized functions directly at the module level (backend.foo)
# to provide a consistent, NumPy-like API across backends. The '_ops' object
# is a private member for accessing the full, raw backend library if
# necessary, but usage should prefer the top-level standardized functions.

if _BACKEND == config.Backend.JAX:
  import jax
  import jax.numpy as jax_ops
  import tensorflow_probability.substrates.jax as tfp_jax

  class ExtensionType:
    """A JAX-compatible stand-in for tf.experimental.ExtensionType."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
      raise NotImplementedError(
          "ExtensionType is not yet implemented for the JAX backend."
      )

  class _JaxErrors:
    # Map TF error classes to the closest built-in Python exceptions.
    # pylint: disable=invalid-name
    ResourceExhaustedError = MemoryError
    InvalidArgumentError = ValueError
    # pylint: enable=invalid-name

  class _JaxRandom:
    """Provides JAX-based random number generation utilities.

    This class mirrors the structure needed by `RNGHandler` for JAX.
    """

    @staticmethod
    def prng_key(seed):
      return jax.random.PRNGKey(seed)

    @staticmethod
    def split(key):
      return jax.random.split(key)

    @staticmethod
    def generator_from_seed(seed):
      raise NotImplementedError("JAX backend does not use Generators.")

    @staticmethod
    def stateless_split(seed: Any, num: int = 2):
      raise NotImplementedError(
          "Direct stateless splitting from an integer seed is not the primary"
          " pattern used in the JAX backend. Use `backend.random.split(key)`"
          " instead."
      )

    @staticmethod
    def stateless_randint(key, shape, minval, maxval, dtype=jax_ops.int32):
      """Wrapper for jax.random.randint."""
      return jax.random.randint(
          key, shape, minval=minval, maxval=maxval, dtype=dtype
      )

    @staticmethod
    def stateless_uniform(
        key, shape, dtype=jax_ops.float32, minval=0.0, maxval=1.0
    ):
      """Replacement for tfp_jax.random.stateless_uniform using jax.random.uniform."""
      return jax.random.uniform(
          key, shape=shape, dtype=dtype, minval=minval, maxval=maxval
      )

    @staticmethod
    def sanitize_seed(seed):
      return tfp_jax.random.sanitize_seed(seed)

  random = _JaxRandom()

  _ops = jax_ops
  errors = _JaxErrors()
  Tensor = jax.Array
  tfd = tfp_jax.distributions
  bijectors = tfp_jax.bijectors
  experimental = tfp_jax.experimental
  mcmc = tfp_jax.mcmc
  _convert_to_tensor = _ops.asarray

  # Standardized Public API
  absolute = _ops.abs
  allclose = _ops.allclose
  arange = _jax_arange
  argmax = _jax_argmax
  boolean_mask = _jax_boolean_mask
  broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
  broadcast_to = _ops.broadcast_to
  cast = _jax_cast
  concatenate = _ops.concatenate
  cumsum = _ops.cumsum
  divide = _ops.divide
  divide_no_nan = _jax_divide_no_nan
  einsum = _ops.einsum
  equal = _ops.equal
  exp = _ops.exp
  expand_dims = _ops.expand_dims
  fill = _jax_fill
  function = _jax_function_wrapper
  gather = _jax_gather
  get_indices_where = _jax_get_indices_where
  is_nan = _ops.isnan
  log = _ops.log
  make_ndarray = _jax_make_ndarray
  make_tensor_proto = _jax_make_tensor_proto
  nanmedian = _jax_nanmedian
  numpy_function = _jax_numpy_function
  ones = _ops.ones
  ones_like = _ops.ones_like
  rank = _ops.ndim
  reduce_any = _ops.any
  reduce_max = _ops.max
  reduce_mean = _ops.mean
  reduce_min = _ops.min
  reduce_std = _ops.std
  reduce_sum = _ops.sum
  repeat = _ops.repeat
  reshape = _ops.reshape
  split = _jax_split
  stack = _ops.stack
  tile = _jax_tile
  transpose = _jax_transpose
  unique_with_counts = _jax_unique_with_counts
  where = _ops.where
  zeros = _ops.zeros
  zeros_like = _ops.zeros_like

  float32 = _ops.float32
  bool_ = _ops.bool_
  newaxis = _ops.newaxis
  TensorShape = _jax_tensor_shape
  int32 = _ops.int32

  def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
    raise NotImplementedError(
        "JAX does not support a global, stateful random seed."
        " `set_random_seed` is not implemented. Instead, you must pass an"
        " explicit `seed` integer directly to the sampling methods (e.g.,"
        " `sample_prior`), which will be used to create a JAX PRNGKey"
        " internally."
    )

elif _BACKEND == config.Backend.TENSORFLOW:
  import tensorflow as tf_backend
  import tensorflow_probability as tfp

  _ops = tf_backend
  errors = _ops.errors

  Tensor = tf_backend.Tensor
  ExtensionType = _ops.experimental.ExtensionType

  class _TfRandom:
    """Provides TensorFlow-based random number generation utilities.

    This class mirrors the structure needed by `RNGHandler` for TensorFlow.
    """

    @staticmethod
    def prng_key(seed):
      raise NotImplementedError(
          "TensorFlow backend does not use explicit PRNG keys for"
          " standard sampling."
      )

    @staticmethod
    def split(key):
      raise NotImplementedError(
          "TensorFlow backend does not implement explicit key splitting."
      )

    @staticmethod
    def generator_from_seed(seed):
      return tf_backend.random.Generator.from_seed(seed)

    @staticmethod
    def stateless_split(
        seed: "tf_backend.Tensor", num: int = 2
    ) -> "tf_backend.Tensor":
      return tf_backend.random.experimental.stateless_split(seed, num=num)

    @staticmethod
    def stateless_randint(
        seed: "tf_backend.Tensor",
        shape: "TensorShapeInstance",
        minval: int,
        maxval: int,
        dtype: Any = _DEFAULT_SEED_DTYPE,
    ) -> "tf_backend.Tensor":
      return tf_backend.random.stateless_uniform(
          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
      )

    @staticmethod
    def stateless_uniform(
        seed, shape, minval=0, maxval=None, dtype=tf_backend.float32
    ):
      return tf_backend.random.stateless_uniform(
          shape=shape, seed=seed, minval=minval, maxval=maxval, dtype=dtype
      )

    @staticmethod
    def sanitize_seed(seed):
      return tfp.random.sanitize_seed(seed)

  random = _TfRandom()

  tfd = tfp.distributions
  bijectors = tfp.bijectors
  experimental = tfp.experimental
  mcmc = tfp.mcmc

  _convert_to_tensor = _ops.convert_to_tensor

  # Standardized Public API
  absolute = _ops.math.abs
  allclose = _ops.experimental.numpy.allclose
  arange = _tf_arange
  argmax = _tf_argmax
  boolean_mask = _tf_boolean_mask
  broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
  broadcast_to = _ops.broadcast_to
  cast = _ops.cast
  concatenate = _ops.concat
  cumsum = _ops.cumsum
  divide = _ops.divide
  divide_no_nan = _ops.math.divide_no_nan
  einsum = _ops.einsum
  equal = _ops.equal
  exp = _ops.math.exp
  expand_dims = _ops.expand_dims
  fill = _tf_fill
  function = _tf_function_wrapper
  gather = _tf_gather
  get_indices_where = _tf_get_indices_where
  is_nan = _ops.math.is_nan
  log = _ops.math.log
  make_ndarray = _ops.make_ndarray
  make_tensor_proto = _ops.make_tensor_proto
  nanmedian = _tf_nanmedian
  numpy_function = _ops.numpy_function
  ones = _ops.ones
  ones_like = _ops.ones_like
  rank = _ops.rank
  reduce_any = _ops.reduce_any
  reduce_max = _ops.reduce_max
  reduce_mean = _ops.reduce_mean
  reduce_min = _ops.reduce_min
  reduce_std = _ops.math.reduce_std
  reduce_sum = _ops.reduce_sum
  repeat = _ops.repeat
  reshape = _ops.reshape
  set_random_seed = tf_backend.keras.utils.set_random_seed
  split = _ops.split
  stack = _ops.stack
  tile = _ops.tile
  transpose = _ops.transpose
  unique_with_counts = _tf_unique_with_counts
  where = _ops.where
  zeros = _ops.zeros
  zeros_like = _ops.zeros_like

  float32 = _ops.float32
  bool_ = _ops.bool
  newaxis = _ops.newaxis
  TensorShape = _ops.TensorShape
  int32 = _ops.int32

else:
  raise ValueError(f"Unsupported backend: {_BACKEND}")
# pylint: enable=g-import-not-at-top,g-bad-import-order


def _extract_int_seed(s: Any) -> Optional[int]:
  """Attempts to extract a scalar Python integer from various input types.

  Args:
    s: The input seed, which can be an int, Tensor, or array-like object.

  Returns:
    A Python integer if a scalar integer can be extracted, otherwise None.
  """
  try:
    if isinstance(s, int):
      return s

    value = np.asarray(s)

    if value.ndim == 0 and np.issubdtype(value.dtype, np.integer):
      return int(value)
  # A broad exception is used here because the input `s` can be of many types
  # (e.g., JAX PRNGKey) which may cause np.asarray or other operations to fail
  # in unpredictable ways. The goal is to safely attempt extraction and fail
  # gracefully.
  except Exception:  # pylint: disable=broad-except
    pass
  return None
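
# Example (editor's illustration, not part of the released file):
#
#   >>> _extract_int_seed(7)
#   7
#   >>> _extract_int_seed(np.int32(42))
#   42
#   >>> print(_extract_int_seed((1, 2)))  # not a scalar
#   None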


class _BaseRNGHandler(abc.ABC):
  """A backend-agnostic abstract base class for random number generation state.

  This handler provides a stateful-style interface for consuming randomness,
  abstracting away the differences between JAX's stateless PRNG keys and
  TensorFlow's paradigms.

  Attributes:
    _seed_input: The original seed object provided during initialization.
    _int_seed: A Python integer extracted from the seed, if possible.
  """

  def __init__(self, seed: SeedType):
    """Initializes the RNG handler.

    Args:
      seed: The initial seed. The accepted type depends on the backend. For
        JAX, this must be an integer. For TensorFlow, this can be an integer,
        a sequence of two integers, or a Tensor. If None, the handler becomes
        a no-op, returning None for all seed requests.
    """
    self._seed_input = seed
    self._int_seed: Optional[int] = _extract_int_seed(seed)

  @abc.abstractmethod
  def get_kernel_seed(self) -> Any:
    """Provides a backend-appropriate sanitized seed/key for an MCMC kernel.

    This method exposes the current state of the handler in the format
    expected by the backend's MCMC machinery. It does not advance the internal
    state.

    Returns:
      A backend-specific seed object (e.g., a JAX PRNGKey or a TF Tensor).
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_next_seed(self) -> Any:
    """Provides the appropriate seed object for the next sequential operation.

    This is primarily used for prior sampling and typically advances the
    internal state.

    Returns:
      A backend-specific seed object for a single random operation.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def advance_handler(self) -> "_BaseRNGHandler":
    """Creates a new, independent RNGHandler for a subsequent operation.

    This method is used to generate a new handler derived deterministically
    from the current handler's state.

    Returns:
      A new, independent `RNGHandler` instance.
    """
    raise NotImplementedError


class _JaxRNGHandler(_BaseRNGHandler):
  """JAX implementation of the RNGHandler using explicit key splitting."""

  def __init__(self, seed: SeedType):
    """Initializes the JAX RNG handler.

    Args:
      seed: The initial seed, which must be a Python integer, a scalar integer
        Tensor/array, or None.

    Raises:
      ValueError: If the provided seed is not a scalar integer or None.
    """
    super().__init__(seed)
    self._key: Optional["_jax.Array"] = None

    if seed is None:
      return

    if self._int_seed is None:
      raise ValueError(
          "JAX backend requires a seed that is an integer or a scalar array,"
          f" but got: {type(seed)} with value {seed!r}"
      )

    self._key = random.prng_key(self._int_seed)

  def get_kernel_seed(self) -> Any:
    if self._key is None:
      return None
    return random.sanitize_seed(self._key)

  def get_next_seed(self) -> Any:
    if self._key is None:
      return None
    # Split the current key: keep one half as the new state and hand the
    # other half to the caller.
    self._key, subkey = random.split(self._key)
    return subkey

  def advance_handler(self) -> "_JaxRNGHandler":
    if self._key is None:
      return _JaxRNGHandler(None)

    self._key, subkey_for_new_handler = random.split(self._key)
    new_seed_tensor = random.stateless_randint(
        key=subkey_for_new_handler,
        shape=(),
        minval=0,
        maxval=_MAX_INT32,
        dtype=_DEFAULT_SEED_DTYPE,
    )
    return _JaxRNGHandler(np.asarray(new_seed_tensor).item())
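
# Example (editor's illustration, not part of the released file). Each call to
# `get_next_seed` consumes one split; `advance_handler` derives a fresh,
# independent handler (requires the JAX backend to be active):
#
#   handler = _JaxRNGHandler(seed=0)
#   k1 = handler.get_next_seed()       # subkey for the first sampling op
#   k2 = handler.get_next_seed()       # a different subkey
#   child = handler.advance_handler()  # new handler with a derived int seed
#   _JaxRNGHandler(None).get_next_seed() is None  # no-op handler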


# TODO: Replace with _TFRNGHandler
class _TFLegacyRNGHandler(_BaseRNGHandler):
  """TensorFlow implementation of the RNGHandler.

  TODO: This class should be removed and replaced with a correct,
  stateful `tf.random.Generator`-based implementation.
  """

  def __init__(self, seed: SeedType, *, _sanitized_seed: Optional[Any] = None):
    """Initializes the TensorFlow legacy RNG handler.

    Args:
      seed: The initial seed. Can be an integer, a sequence of two integers, a
        corresponding Tensor, or None.
      _sanitized_seed: For internal use only. If provided, this pre-computed
        seed tensor is used directly, and the standard initialization logic
        for the public `seed` argument is bypassed.

    Raises:
      ValueError: If `seed` is a sequence with a length other than 2.
    """
    super().__init__(seed)
    self._tf_sanitized_seed: Optional[Any] = None

    if _sanitized_seed is not None:
      # Internal path: a pre-sanitized seed was provided by a trusted source,
      # so we adopt it directly.
      self._tf_sanitized_seed = _sanitized_seed
      return

    if seed is None:
      return

    if isinstance(seed, Sequence) and len(seed) != 2:
      raise ValueError(
          "Invalid seed: Must be either a single integer (stateful seed) or a"
          " pair of two integers (stateless seed). See"
          " [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed)"
          " for details."
      )

    if isinstance(seed, int):
      # Promote a scalar seed to a (seed, seed) pair before sanitizing.
      seed_to_sanitize = (seed, seed)
    else:
      seed_to_sanitize = seed

    self._tf_sanitized_seed = random.sanitize_seed(seed_to_sanitize)

  def get_kernel_seed(self) -> Any:
    return self._tf_sanitized_seed

  def get_next_seed(self) -> Any:
    """Returns the original integer seed to preserve prior sampling behavior.

    Returns:
      The original integer seed provided during initialization.

    Raises:
      RuntimeError: If the handler was not initialized with a scalar integer
        seed, which is required for the legacy prior sampling path.
    """
    if self._seed_input is None:
      return None

    if self._int_seed is None:
      raise RuntimeError(
          "RNGHandler was not initialized with a scalar integer seed, cannot"
          " provide seed for TensorFlow prior sampling."
      )
    return self._int_seed

  def advance_handler(self) -> "_TFLegacyRNGHandler":
    """Creates a new handler by incrementing the sanitized seed by 1.

    Returns:
      A new `_TFLegacyRNGHandler` instance with an incremented seed state.

    Raises:
      RuntimeError: If the handler's sanitized seed was not initialized.
    """
    if self._seed_input is None:
      return _TFLegacyRNGHandler(None)

    if self._tf_sanitized_seed is None:
      # Should be caught during init, but included for defensive programming.
      raise RuntimeError("RNGHandler sanitized seed not initialized.")

    new_sanitized_seed = self._tf_sanitized_seed + 1

    # Create a new handler instance, passing the original seed input (to
    # preserve state like `_int_seed`) and injecting the new sanitized seed
    # via the private constructor argument.
    return _TFLegacyRNGHandler(
        self._seed_input, _sanitized_seed=new_sanitized_seed
    )


if _BACKEND == config.Backend.JAX:
  RNGHandler = _JaxRNGHandler
elif _BACKEND == config.Backend.TENSORFLOW:
  RNGHandler = _TFLegacyRNGHandler  # TODO: Replace with _TFRNGHandler
else:
  raise ImportError(f"RNGHandler not implemented for backend: {_BACKEND}")
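
# Example (editor's illustration, not part of the released file). Accepted
# seed forms for the active backend:
#
#   RNGHandler(7)        # TF: sanitizes the pair (7, 7); JAX: PRNGKey(7)
#   RNGHandler((1, 2))   # TF only: a stateless seed pair; JAX raises
#   RNGHandler(None)     # both backends: a no-op handler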


def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor:  # type: ignore
  """Converts input data to the currently active backend tensor type.

  Args:
    data: The data to convert.
    dtype: The desired data type of the resulting tensor. The accepted types
      depend on the active backend (e.g., jax.numpy.dtype or tf.DType).

  Returns:
    A tensor representation of the data for the active backend.
  """
  return _convert_to_tensor(data, dtype=dtype)
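
A minimal usage sketch of the layer above (editor's illustration, not part of
the diff; assumes the default TensorFlow backend is installed and active):

    from meridian import backend

    x = backend.to_tensor([1.0, 2.0, 3.0], dtype=backend.float32)
    total = backend.reduce_sum(x)  # tf.reduce_sum here; jnp.sum under JAX
    safe = backend.divide_no_nan(x, backend.zeros_like(x))  # zeros, never NaN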