google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/METADATA +8 -2
- google_meridian-1.2.0.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +526 -362
- meridian/analysis/optimizer.py +275 -267
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +37 -49
- meridian/backend/__init__.py +514 -0
- meridian/backend/config.py +59 -0
- meridian/backend/test_utils.py +95 -0
- meridian/constants.py +59 -3
- meridian/data/input_data.py +94 -0
- meridian/data/test_utils.py +144 -12
- meridian/model/adstock_hill.py +279 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +306 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +323 -157
- meridian/model/posterior_sampler.py +84 -77
- meridian/model/prior_distribution.py +538 -168
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +53 -47
- meridian/version.py +1 -1
- google_meridian-1.1.5.dist-info/RECORD +0 -47
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,514 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Backend Abstraction Layer for Meridian."""
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
from typing import Any, Optional, TYPE_CHECKING, Tuple, Union
|
|
19
|
+
|
|
20
|
+
from meridian.backend import config
|
|
21
|
+
import numpy as np
|
|
22
|
+
from typing_extensions import Literal
|
|
23
|
+
|
|
24
|
+
# The conditional imports in this module are a deliberate design choice for the
|
|
25
|
+
# backend abstraction layer. The TFP-on-JAX substrate provides a nearly
|
|
26
|
+
# identical API to the standard TFP library, making an alias-based approach more
|
|
27
|
+
# pragmatic than a full Abstract Base Class implementation, which would require
|
|
28
|
+
# extensive boilerplate.
|
|
29
|
+
# pylint: disable=g-import-not-at-top,g-bad-import-order
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
import dataclasses
|
|
33
|
+
import jax as _jax
|
|
34
|
+
import tensorflow as _tf
|
|
35
|
+
|
|
36
|
+
TensorShapeInstance = Union[_tf.TensorShape, Tuple[int, ...]]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def standardize_dtype(dtype: Any) -> str:
  """Returns a canonical string name for a backend-specific dtype.

  Accepts TF-style dtypes (anything exposing `as_numpy_dtype`), NumPy dtypes,
  Python types, and strings, normalizing them all to NumPy's canonical name.

  Args:
    dtype: A backend-specific dtype object (e.g., tf.DType, np.dtype).

  Returns:
    A canonical string representation of the dtype (e.g., 'float32').
  """
  # `np.dtype(None)` silently defaults to float64, so None is special-cased.
  if dtype is None:
    return str(None)

  # TF DType objects expose their NumPy counterpart via this attribute.
  if hasattr(dtype, "as_numpy_dtype"):
    dtype = dtype.as_numpy_dtype

  try:
    return np.dtype(dtype).name
  except TypeError:
    # Not something NumPy understands; fall back to its string form.
    return str(dtype)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def result_type(*types: Any) -> str:
  """Infers the promoted result dtype for a set of input types.

  Single source of truth for type promotion, kept identical across backends:
  any floating-point input promotes the result to float32; otherwise the
  result is int64, matching NumPy/JAX's default integer precision.

  Args:
    *types: Type objects (e.g., `<class 'int'>`, np.dtype('float32')).

  Returns:
    A string naming the promoted dtype ('float32' or 'int64').
  """
  names = []
  for candidate in types:
    if candidate is None:
      continue
    try:
      # Normalize before applying the promotion rule.
      names.append(standardize_dtype(candidate))
    except Exception:  # pylint: disable=broad-except
      # Standardization failed for an unexpected type; use its repr instead.
      names.append(str(candidate))

  has_float = any("float" in name for name in names)
  return "float32" if has_float else "int64"
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _resolve_dtype(dtype: Optional[Any], *args: Any) -> str:
  """Picks the final dtype for an operation, inferring from inputs if needed.

  An explicitly supplied dtype always wins. Otherwise the dtype is derived
  from the arguments' dtypes (or Python types) via the backend-agnostic
  `result_type` promotion rules.

  Args:
    dtype: The user-provided dtype, or None to infer.
    *args: Operation inputs used for inference when `dtype` is None.

  Returns:
    A string naming the resolved dtype.
  """
  if dtype is not None:
    return standardize_dtype(dtype)

  inferred = [
      getattr(value, "dtype", type(value))
      for value in args
      if value is not None
  ]
  return result_type(*inferred)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# --- Private Backend-Specific Implementations ---
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _jax_arange(
    start: Any,
    stop: Optional[Any] = None,
    step: Any = 1,
    dtype: Optional[Any] = None,
) -> "_jax.Array":
  """JAX implementation for arange."""
  # Imported locally so the module loads even without JAX installed.
  import jax.numpy as jnp

  target_dtype = _resolve_dtype(dtype, start, stop, step)
  return jnp.arange(start, stop, step=step, dtype=target_dtype)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _tf_arange(
    start: Any,
    stop: Optional[Any] = None,
    step: Any = 1,
    dtype: Optional[Any] = None,
) -> "_tf.Tensor":
  """TensorFlow implementation for arange."""
  import tensorflow as tf

  target_dtype = _resolve_dtype(dtype, start, stop, step)
  try:
    return tf.range(start, limit=stop, delta=step, dtype=target_dtype)
  except tf.errors.NotFoundError:
    # No registered range kernel for this dtype: compute in float32 and cast.
    as_float = tf.range(start, limit=stop, delta=step, dtype=tf.float32)
    return tf.cast(as_float, target_dtype)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
  """JAX implementation for cast; mirrors tf.cast via `Array.astype`."""
  return x.astype(dtype)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _jax_divide_no_nan(x, y):
  """JAX implementation for divide_no_nan: x / y, with 0 where y == 0."""
  import jax.numpy as jnp

  quotient = jnp.divide(x, y)
  return jnp.where(y != 0, quotient, 0.0)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
  """Unsupported under JAX; exists only to mirror the TF API surface."""
  raise NotImplementedError(
      "backend.numpy_function is not implemented for the JAX backend."
  )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _jax_make_tensor_proto(*args, **kwargs):  # pylint: disable=unused-argument
  """Unsupported under JAX; exists only to mirror the TF API surface."""
  raise NotImplementedError(
      "backend.make_tensor_proto is not implemented for the JAX backend."
  )
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _jax_make_ndarray(*args, **kwargs):  # pylint: disable=unused-argument
  """Unsupported under JAX; exists only to mirror the TF API surface."""
  raise NotImplementedError(
      "backend.make_ndarray is not implemented for the JAX backend."
  )
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _jax_get_indices_where(condition):
  """JAX implementation for get_indices_where (mirrors one-arg tf.where)."""
  import jax.numpy as jnp

  # jnp.where yields one index array per dimension; stacking on the last
  # axis produces tf.where-style (num_true, ndim) coordinate rows.
  per_axis = jnp.where(condition)
  return jnp.stack(per_axis, axis=-1)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _tf_get_indices_where(condition):
  """TensorFlow implementation for get_indices_where."""
  import tensorflow as tf

  # One-argument tf.where already returns (num_true, ndim) index rows.
  return tf.where(condition)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _jax_unique_with_counts(x):
  """JAX implementation for unique_with_counts."""
  import jax.numpy as jnp

  values, counts = jnp.unique(x, return_counts=True)
  # tf.unique_with_counts returns (y, idx, count). Callers never use idx,
  # so a None placeholder preserves the tuple shape without computing it.
  return values, None, counts
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _tf_unique_with_counts(x):
  """TensorFlow implementation for unique_with_counts."""
  import tensorflow as tf

  return tf.unique_with_counts(x)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _jax_boolean_mask(tensor, mask, axis=None):
  """JAX implementation for boolean_mask with tf-style `axis` support."""
  import jax.numpy as jnp

  target_axis = 0 if axis is None else axis
  # Rotate the masked axis to the front, apply the mask, rotate back.
  leading = jnp.moveaxis(tensor, target_axis, 0)
  selected = leading[mask]
  return jnp.moveaxis(selected, 0, target_axis)
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _tf_boolean_mask(tensor, mask, axis=None):
  """TensorFlow implementation for boolean_mask."""
  import tensorflow as tf

  return tf.boolean_mask(tensor, mask, axis=axis)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _jax_gather(params, indices):
  """JAX implementation for gather: plain advanced indexing."""
  return params[indices]
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def _tf_gather(params, indices):
  """TensorFlow implementation for gather."""
  import tensorflow as tf

  return tf.gather(params, indices)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _jax_fill(dims, value):
  """JAX implementation for fill; mirrors tf.fill via jnp.full."""
  import jax.numpy as jnp

  return jnp.full(dims, value)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _tf_fill(dims, value):
  """TensorFlow implementation for fill."""
  import tensorflow as tf

  return tf.fill(dims, value)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def _jax_argmax(tensor, axis=None):
  """JAX implementation for argmax, aligned with TensorFlow's default.

  Unlike NumPy/JAX — which flatten the input when `axis` is None — this
  mirrors `tf.argmax`, whose default axis is 0.

  Args:
    tensor: The input JAX array.
    axis: Axis along which to take the argmax. `None` means axis 0, matching
      TensorFlow's behavior rather than NumPy's flattening behavior.

  Returns:
    A JAX array containing the indices of the maximum values.
  """
  import jax.numpy as jnp

  effective_axis = 0 if axis is None else axis
  return jnp.argmax(tensor, axis=effective_axis)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def _tf_argmax(tensor, axis=None):
  """TensorFlow implementation for argmax."""
  import tensorflow as tf

  return tf.argmax(tensor, axis=axis)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _jax_broadcast_dynamic_shape(shape_x, shape_y):
  """JAX implementation for broadcast_dynamic_shape."""
  import jax.numpy as jnp

  return jnp.broadcast_shapes(shape_x, shape_y)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _jax_tensor_shape(dims):
  """JAX stand-in for tf.TensorShape: a plain tuple of dimensions."""
  return (dims,) if isinstance(dims, int) else tuple(dims)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# --- Backend Initialization ---
_BACKEND = config.get_backend()

# We expose standardized functions directly at the module level (backend.foo)
# to provide a consistent, NumPy-like API across backends. The '_ops' object
# is a private member for accessing the full, raw backend library if necessary,
# but usage should prefer the top-level standardized functions.
#
# Fix: the alias lists previously bound several names twice in each branch
# (broadcast_to, concatenate, ones, ones_like, repeat, reshape, split, stack,
# tile, transpose, where, zeros, zeros_like). Each name is now bound exactly
# once, alphabetically, with identical final values.

if _BACKEND == config.Backend.JAX:
  import jax
  import jax.numpy as jax_ops
  import tensorflow_probability.substrates.jax as tfp_jax

  class ExtensionType:
    """A JAX-compatible stand-in for tf.experimental.ExtensionType."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
      raise NotImplementedError(
          "ExtensionType is not yet implemented for the JAX backend."
      )

  class _JaxErrors:
    """Maps TF error classes to the closest built-in Python exceptions."""

    # pylint: disable=invalid-name
    ResourceExhaustedError = MemoryError
    InvalidArgumentError = ValueError
    # pylint: enable=invalid-name

  _ops = jax_ops
  errors = _JaxErrors()
  Tensor = jax.Array
  tfd = tfp_jax.distributions
  bijectors = tfp_jax.bijectors
  experimental = tfp_jax.experimental
  random = tfp_jax.random
  mcmc = tfp_jax.mcmc
  _convert_to_tensor = _ops.asarray

  # Standardized Public API (alphabetical; each name bound exactly once).
  absolute = _ops.abs
  allclose = _ops.allclose
  arange = _jax_arange
  argmax = _jax_argmax
  boolean_mask = _jax_boolean_mask
  broadcast_dynamic_shape = _jax_broadcast_dynamic_shape
  broadcast_to = _ops.broadcast_to
  cast = _jax_cast
  concatenate = _ops.concatenate
  cumsum = _ops.cumsum
  divide = _ops.divide
  divide_no_nan = _jax_divide_no_nan
  einsum = _ops.einsum
  equal = _ops.equal
  exp = _ops.exp
  expand_dims = _ops.expand_dims
  fill = _jax_fill
  function = jax.jit
  gather = _jax_gather
  get_indices_where = _jax_get_indices_where
  is_nan = _ops.isnan
  log = _ops.log
  make_ndarray = _jax_make_ndarray
  make_tensor_proto = _jax_make_tensor_proto
  numpy_function = _jax_numpy_function
  ones = _ops.ones
  ones_like = _ops.ones_like
  rank = _ops.ndim
  reduce_any = _ops.any
  reduce_max = _ops.max
  reduce_mean = _ops.mean
  reduce_min = _ops.min
  reduce_std = _ops.std
  reduce_sum = _ops.sum
  repeat = _ops.repeat
  reshape = _ops.reshape
  split = _ops.split
  stack = _ops.stack
  tile = _ops.tile
  transpose = _ops.transpose
  unique_with_counts = _jax_unique_with_counts
  where = _ops.where
  zeros = _ops.zeros
  zeros_like = _ops.zeros_like

  float32 = _ops.float32
  bool_ = _ops.bool_
  newaxis = _ops.newaxis
  TensorShape = _jax_tensor_shape

  def set_random_seed(seed: int) -> None:  # pylint: disable=unused-argument
    """Unsupported on JAX; seeds must be passed explicitly to samplers."""
    raise NotImplementedError(
        "JAX does not support a global, stateful random seed. `set_random_seed`"
        " is not implemented. Instead, you must pass an explicit `seed`"
        " integer directly to the sampling methods (e.g., `sample_prior`),"
        " which will be used to create a JAX PRNGKey internally."
    )

elif _BACKEND == config.Backend.TENSORFLOW:
  import tensorflow as tf_backend
  import tensorflow_probability as tfp

  _ops = tf_backend
  errors = _ops.errors

  Tensor = tf_backend.Tensor
  ExtensionType = _ops.experimental.ExtensionType

  tfd = tfp.distributions
  bijectors = tfp.bijectors
  experimental = tfp.experimental
  random = tfp.random
  mcmc = tfp.mcmc

  _convert_to_tensor = _ops.convert_to_tensor

  # Standardized Public API (alphabetical; each name bound exactly once).
  absolute = _ops.math.abs
  allclose = _ops.experimental.numpy.allclose
  arange = _tf_arange
  argmax = _tf_argmax
  boolean_mask = _tf_boolean_mask
  broadcast_dynamic_shape = _ops.broadcast_dynamic_shape
  broadcast_to = _ops.broadcast_to
  cast = _ops.cast
  concatenate = _ops.concat
  cumsum = _ops.cumsum
  divide = _ops.divide
  divide_no_nan = _ops.math.divide_no_nan
  einsum = _ops.einsum
  equal = _ops.equal
  exp = _ops.math.exp
  expand_dims = _ops.expand_dims
  fill = _tf_fill
  function = _ops.function
  gather = _tf_gather
  get_indices_where = _tf_get_indices_where
  is_nan = _ops.math.is_nan
  log = _ops.math.log
  make_ndarray = _ops.make_ndarray
  make_tensor_proto = _ops.make_tensor_proto
  numpy_function = _ops.numpy_function
  ones = _ops.ones
  ones_like = _ops.ones_like
  rank = _ops.rank
  reduce_any = _ops.reduce_any
  reduce_max = _ops.reduce_max
  reduce_mean = _ops.reduce_mean
  reduce_min = _ops.reduce_min
  reduce_std = _ops.math.reduce_std
  reduce_sum = _ops.reduce_sum
  repeat = _ops.repeat
  reshape = _ops.reshape
  set_random_seed = tf_backend.keras.utils.set_random_seed
  split = _ops.split
  stack = _ops.stack
  tile = _ops.tile
  transpose = _ops.transpose
  unique_with_counts = _tf_unique_with_counts
  where = _ops.where
  zeros = _ops.zeros
  zeros_like = _ops.zeros_like

  float32 = _ops.float32
  bool_ = _ops.bool
  newaxis = _ops.newaxis
  TensorShape = _ops.TensorShape

else:
  raise ValueError(f"Unsupported backend: {_BACKEND}")
# pylint: enable=g-import-not-at-top,g-bad-import-order
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def to_tensor(data: Any, dtype: Optional[Any] = None) -> Tensor:  # type: ignore
  """Converts `data` into a tensor of the currently active backend.

  Args:
    data: The data to convert.
    dtype: Optional target dtype; accepted values are backend-specific
      (e.g., jax.numpy dtypes under JAX, tf.DType under TensorFlow).

  Returns:
    A tensor representation of the data for the active backend.
  """
  # `_convert_to_tensor` is bound during backend initialization to the
  # active backend's converter (jnp.asarray or tf.convert_to_tensor).
  return _convert_to_tensor(data, dtype=dtype)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Backend configuration for Meridian."""
|
|
16
|
+
|
|
17
|
+
import enum
|
|
18
|
+
import warnings
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Backend(enum.Enum):
  """Enumerates the compute backends Meridian can run on."""

  TENSORFLOW = "tensorflow"
  JAX = "jax"


# Module-level state holding the active backend; TensorFlow is the default.
_BACKEND = Backend.TENSORFLOW
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def set_backend(backend: Backend) -> None:
  """Sets the backend for Meridian.

  Note: The JAX backend is currently under development and should not be used.

  Args:
    backend: The backend to use, must be a member of the `Backend` enum.

  Raises:
    ValueError: If the provided backend is not a valid `Backend` enum member.
  """
  global _BACKEND

  if not isinstance(backend, Backend):
    raise ValueError("Backend must be a member of the Backend enum.")

  # JAX support is still experimental; warn loudly but allow selection so
  # internal testing can proceed.
  if backend is Backend.JAX:
    warnings.warn(
        (
            "The JAX backend is currently under development and is not yet"
            " functional. It is intended for internal testing only and should"
            " not be used. Please use the TensorFlow backend."
        ),
        UserWarning,
    )

  _BACKEND = backend
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_backend() -> Backend:
  """Returns the backend Meridian is currently configured to use."""
  return _BACKEND
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Common testing utilities for Meridian, designed to be backend-agnostic."""
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
# Backend-agnostic alias for anything convertible to a NumPy array.
# `Any` is used (rather than a union of concrete tensor types) to avoid a
# circular dependency on the backend module while still accepting
# backend-specific tensors.
ArrayLike = Any
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def assert_allclose(
    a: ArrayLike,
    b: ArrayLike,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    err_msg: str = "",
):
  """Asserts two array-like objects are element-wise close, any backend.

  Both inputs are materialized as NumPy arrays first, so TensorFlow Tensors,
  JAX Arrays, Python lists, and NumPy arrays can all be compared directly.

  Args:
    a: The first array-like object to compare.
    b: The second array-like object to compare.
    rtol: The relative tolerance parameter.
    atol: The absolute tolerance parameter.
    err_msg: The error message to be printed in case of failure.

  Raises:
    AssertionError: If the two arrays are not equal within the given tolerance.
  """
  lhs, rhs = np.array(a), np.array(b)
  np.testing.assert_allclose(lhs, rhs, rtol=rtol, atol=atol, err_msg=err_msg)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
  """Asserts two array-like objects are element-wise equal, any backend.

  Both inputs are converted to NumPy arrays before comparison.

  Args:
    a: The first array-like object to compare.
    b: The second array-like object to compare.
    err_msg: The error message to be printed in case of failure.

  Raises:
    AssertionError: If the two arrays are not equal.
  """
  lhs, rhs = np.array(a), np.array(b)
  np.testing.assert_array_equal(lhs, rhs, err_msg=err_msg)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def assert_all_finite(a: ArrayLike, err_msg: str = ""):
  """Asserts every element of an array-like object is finite (no NaN/inf).

  Args:
    a: The array-like object to check.
    err_msg: The error message to be printed in case of failure.

  Raises:
    AssertionError: If the array contains non-finite values.
  """
  values = np.array(a)
  if not np.isfinite(values).all():
    raise AssertionError(err_msg or "Array contains non-finite values.")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
  """Asserts every element of an array-like object is non-negative.

  Args:
    a: The array-like object to check.
    err_msg: The error message to be printed in case of failure.

  Raises:
    AssertionError: If the array contains negative values.
  """
  values = np.array(a)
  # NOTE: `not (values >= 0).all()` also fails on NaN entries, since NaN
  # comparisons are False; this matches the original behavior.
  if not (values >= 0).all():
    raise AssertionError(err_msg or "Array contains negative values.")
|