google-meridian 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.2.dist-info/METADATA +209 -0
- google_meridian-1.3.2.dist-info/RECORD +76 -0
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +1 -2
- meridian/analysis/analyzer.py +0 -1
- meridian/analysis/optimizer.py +5 -3
- meridian/analysis/review/checks.py +81 -30
- meridian/analysis/review/constants.py +4 -0
- meridian/analysis/review/results.py +40 -9
- meridian/analysis/summarizer.py +1 -1
- meridian/analysis/visualizer.py +1 -1
- meridian/backend/__init__.py +229 -24
- meridian/backend/test_utils.py +194 -0
- meridian/constants.py +1 -0
- meridian/data/load.py +2 -0
- meridian/model/eda/__init__.py +0 -1
- meridian/model/eda/constants.py +12 -2
- meridian/model/eda/eda_engine.py +353 -45
- meridian/model/eda/eda_outcome.py +21 -1
- meridian/model/knots.py +17 -0
- meridian/model/model_test_data.py +15 -0
- meridian/{analysis/templates → templates}/card.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chart.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chips.html.jinja +1 -1
- meridian/{analysis → templates}/formatter.py +12 -1
- meridian/templates/formatter_test.py +216 -0
- meridian/{analysis/templates → templates}/insights.html.jinja +1 -1
- meridian/{analysis/templates → templates}/stats.html.jinja +1 -1
- meridian/{analysis/templates → templates}/style.css +1 -1
- meridian/{analysis/templates → templates}/style.scss +1 -1
- meridian/{analysis/templates → templates}/summary.html.jinja +4 -2
- meridian/{analysis/templates → templates}/table.html.jinja +1 -1
- meridian/version.py +1 -1
- schema/__init__.py +30 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.3.0.dist-info/METADATA +0 -409
- google_meridian-1.3.0.dist-info/RECORD +0 -62
- meridian/model/eda/meridian_eda.py +0 -220
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/licenses/LICENSE +0 -0
meridian/backend/__init__.py
CHANGED
@@ -19,6 +19,7 @@ import functools
 import os
 from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 import warnings
+
 from meridian.backend import config
 import numpy as np
 from typing_extensions import Literal
@@ -220,7 +221,7 @@ def _tf_arange(
 
 def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
   """JAX implementation for cast."""
-  return
+  return jax_ops.asarray(x, dtype=dtype)
 
 
 def _jax_divide_no_nan(x, y):
@@ -305,17 +306,132 @@ def _jax_numpy_function(*args, **kwargs): # pylint: disable=unused-argument
 )
 
 
-def _jax_make_tensor_proto(
-
-
-
+def _jax_make_tensor_proto(values, dtype=None, shape=None):  # pylint: disable=unused-argument
+  """JAX implementation for make_tensor_proto."""
+  # pylint: disable=g-direct-tensorflow-import
+  from tensorflow.core.framework import tensor_pb2
+  from tensorflow.core.framework import tensor_shape_pb2
+  from tensorflow.core.framework import types_pb2
+  # pylint: enable=g-direct-tensorflow-import
 
+  if not isinstance(values, np.ndarray):
+    values = np.array(values)
 
-
-
-
+  if dtype:
+    numpy_dtype = np.dtype(dtype)
+    values = values.astype(numpy_dtype)
+  else:
+    numpy_dtype = values.dtype
+
+  dtype_map = {
+      np.dtype(np.float16): types_pb2.DT_HALF,
+      np.dtype(np.float32): types_pb2.DT_FLOAT,
+      np.dtype(np.float64): types_pb2.DT_DOUBLE,
+      np.dtype(np.int32): types_pb2.DT_INT32,
+      np.dtype(np.uint8): types_pb2.DT_UINT8,
+      np.dtype(np.uint16): types_pb2.DT_UINT16,
+      np.dtype(np.uint32): types_pb2.DT_UINT32,
+      np.dtype(np.uint64): types_pb2.DT_UINT64,
+      np.dtype(np.int16): types_pb2.DT_INT16,
+      np.dtype(np.int8): types_pb2.DT_INT8,
+      np.dtype(np.int64): types_pb2.DT_INT64,
+      np.dtype(np.complex64): types_pb2.DT_COMPLEX64,
+      np.dtype(np.complex128): types_pb2.DT_COMPLEX128,
+      np.dtype(np.bool_): types_pb2.DT_BOOL,
+      # Note: String types are handled outside the map.
+  }
+  proto_dtype = dtype_map.get(numpy_dtype)
+  if proto_dtype is None and numpy_dtype.kind in ("S", "U"):
+    proto_dtype = types_pb2.DT_STRING
+
+  if proto_dtype is None:
+    raise TypeError(
+        f"Unsupported dtype for TensorProto conversion: {numpy_dtype}"
+    )
+
+  proto = tensor_pb2.TensorProto(
+      dtype=proto_dtype,
+      tensor_shape=tensor_shape_pb2.TensorShapeProto(
+          dim=[
+              tensor_shape_pb2.TensorShapeProto.Dim(size=d)
+              for d in values.shape
+          ]
+      ),
   )
 
+  proto.tensor_content = values.tobytes()
+  return proto
+
+
+def _jax_make_ndarray(proto):
+  """JAX implementation for make_ndarray."""
+  # pylint: disable=g-direct-tensorflow-import
+  from tensorflow.core.framework import types_pb2
+  # pylint: enable=g-direct-tensorflow-import
+
+  dtype_map = {
+      types_pb2.DT_HALF: np.float16,
+      types_pb2.DT_FLOAT: np.float32,
+      types_pb2.DT_DOUBLE: np.float64,
+      types_pb2.DT_INT32: np.int32,
+      types_pb2.DT_UINT8: np.uint8,
+      types_pb2.DT_UINT16: np.uint16,
+      types_pb2.DT_UINT32: np.uint32,
+      types_pb2.DT_UINT64: np.uint64,
+      types_pb2.DT_INT16: np.int16,
+      types_pb2.DT_INT8: np.int8,
+      types_pb2.DT_INT64: np.int64,
+      types_pb2.DT_COMPLEX64: np.complex64,
+      types_pb2.DT_COMPLEX128: np.complex128,
+      types_pb2.DT_BOOL: np.bool_,
+      types_pb2.DT_STRING: np.bytes_,
+  }
+  if proto.dtype not in dtype_map:
+    raise TypeError(f"Unsupported TensorProto dtype: {proto.dtype}")
+
+  shape = [d.size for d in proto.tensor_shape.dim]
+  dtype = dtype_map[proto.dtype]
+
+  if proto.tensor_content:
+    num_elements = np.prod(shape).item() if shape else 0
+    # When deserializing a string from tensor_content, the itemsize is not
+    # explicitly stored. We must infer it from the content length and shape.
+    if dtype == np.bytes_ and num_elements > 0:
+      content_len = len(proto.tensor_content)
+      itemsize = content_len // num_elements
+      if itemsize * num_elements != content_len:
+        raise ValueError(
+            "Tensor content size is not a multiple of the number of elements"
+            " for string dtype."
+        )
+      dtype = np.dtype(f"S{itemsize}")
+
+    return (
+        np.frombuffer(proto.tensor_content, dtype=dtype).copy().reshape(shape)
+    )
+
+  # Fallback for protos that store data in val fields instead of tensor_content.
+  if dtype == np.float32:
+    val_field = proto.float_val
+  elif dtype == np.float64:
+    val_field = proto.double_val
+  elif dtype == np.int32:
+    val_field = proto.int_val
+  elif dtype == np.int64:
+    val_field = proto.int64_val
+  elif dtype == np.bool_:
+    val_field = proto.bool_val
+  else:
+    if proto.string_val:
+      return np.array(proto.string_val, dtype=np.bytes_).reshape(shape)
+    if not any(shape):
+      return np.array([], dtype=dtype).reshape(shape)
+    raise TypeError(
+        f"Unsupported dtype for TensorProto value field fallback: {dtype}"
+    )
+
+  return np.array(val_field, dtype=dtype).reshape(shape)
+
 
 def _jax_get_indices_where(condition):
   """JAX implementation for get_indices_where."""
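Note: the two helpers above give the JAX backend a TensorProto round trip without calling into TF's own make_tensor_proto/make_ndarray. A minimal sketch of that round trip, assuming a build where the JAX backend is active so that the public backend.make_tensor_proto and backend.make_ndarray aliases dispatch to these functions:

import numpy as np
from meridian import backend

# Serialize: dtype and shape go into the proto; the data itself is packed
# as raw bytes in tensor_content.
original = np.arange(6, dtype=np.float32).reshape(2, 3)
proto = backend.make_tensor_proto(original)

# Deserialize: the inverse dtype map restores the same array and shape.
restored = backend.make_ndarray(proto)
np.testing.assert_array_equal(original, restored)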
@@ -493,16 +609,75 @@ def _tf_get_seed_data(seed: Any) -> Optional[np.ndarray]:
 
 
 def _jax_convert_to_tensor(data, dtype=None):
-  """Converts data to a JAX array, handling strings as NumPy arrays.
+  """Converts data to a JAX array, handling strings as NumPy arrays.
+
+  This function explicitly unwraps objects with a `.values` attribute (e.g.,
+  pandas.DataFrame, xarray.DataArray) to access the underlying NumPy array,
+  provided that `.values` is not a method. This takes precedence over the
+  `__array__` protocol.
+
+  It also handles precision mismatches: if `data` is float64 and `dtype` is
+  not specified, and JAX x64 mode is disabled (default), it issues a warning
+  and explicitly casts to float32 to match the backend default and prevent
+  silent precision loss or type errors in downstream operations.
+
+  Args:
+    data: The data to convert.
+    dtype: The desired data type.
+
+  Returns:
+    A JAX array, or a NumPy array if the dtype is a string type.
+  """
+  # Unwrap xarray.DataArray, pandas.Series, and pandas.DataFrame objects.
+  # These objects wrap the underlying NumPy array in a .values attribute.
+  if hasattr(data, "values") and not callable(data.values):
+    data = data.values
+
+  # Convert to numpy array upfront to simplify dtype inspection below.
+  # A standard Python float is 64-bit, and this conversion allows the
+  # subsequent logic to correctly detect and handle potential float64
+  # downcasting for scalar inputs.
+  if isinstance(data, (list, tuple, float)):
+    data = np.array(data)
+
   # JAX does not natively support string tensors in the same way TF does.
   # If a string dtype is requested, or if the data is inherently strings,
   # we fall back to a standard NumPy array.
-
-
-
-
-
-
+  is_string_target = False
+  if dtype is not None:
+    try:
+      if np.dtype(dtype).kind in ("S", "U"):
+        is_string_target = True
+    except TypeError:
+      # This can happen if dtype is not a valid dtype specifier,
+      # let jax.asarray handle it.
+      pass
+
+  is_string_data = isinstance(data, np.ndarray) and data.dtype.kind in (
+      "S",
+      "U",
+  )
+
+  if is_string_target:
+    return np.array(data, dtype=dtype)
+
+  if dtype is None and is_string_data:
+    return data
+
+  # If the user provides float64 data but does not request a specific dtype,
+  # and JAX 64-bit mode is disabled (default), JAX would implicitly truncate.
+  # We cast to float32 and warn the user to prevent silent mismatches.
+  if dtype is None:
+    is_float64_input = hasattr(data, "dtype") and data.dtype == np.float64
+    if is_float64_input:
+      if not jax.config.jax_enable_x64:
+        warnings.warn(
+            "Input data is float64. Casting to float32 to match backend "
+            "default precision.",
+            UserWarning,
+        )
+        dtype = jax_ops.float32
+
   return jax_ops.asarray(data, dtype=dtype)
 
 
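Note: a sketch of the new conversion rules, assuming a JAX-backend install with the default x64-disabled configuration; _jax_convert_to_tensor is module-internal and is imported here purely for illustration:

import warnings
import numpy as np
import pandas as pd
from meridian.backend import _jax_convert_to_tensor

df = pd.DataFrame({"spend": [1.0, 2.0, 3.0]})  # pandas stores float64

with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter("always")
  tensor = _jax_convert_to_tensor(df)  # .values unwrap, then float64 -> float32

assert tensor.dtype == np.float32
assert any("float64" in str(w.message) for w in caught)

# String data falls back to a plain NumPy array instead of a JAX array.
labels = _jax_convert_to_tensor(np.array(["geo_a", "geo_b"]))
assert isinstance(labels, np.ndarray)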
@@ -535,18 +710,48 @@ def _tf_nanvar(a, axis=None, keepdims=False):
   return tf.convert_to_tensor(var)
 
 
-def _jax_one_hot(
+def _jax_one_hot(
+    indices, depth, on_value=None, off_value=None, axis=None, dtype=None
+):
   """JAX implementation for one_hot."""
-
-
+  import jax.numpy as jnp
+
+  resolved_dtype = _resolve_dtype(dtype, on_value, off_value, 1, 0)
+  jax_axis = -1 if axis is None else axis
+
+  one_hot_result = jax.nn.one_hot(
+      indices, num_classes=depth, dtype=jnp.dtype(resolved_dtype), axis=jax_axis
   )
 
+  on_val = 1 if on_value is None else on_value
+  off_val = 0 if off_value is None else off_value
+
+  if on_val == 1 and off_val == 0:
+    return one_hot_result
 
-
+  on_tensor = jnp.array(on_val, dtype=jnp.dtype(resolved_dtype))
+  off_tensor = jnp.array(off_val, dtype=jnp.dtype(resolved_dtype))
+
+  return jnp.where(one_hot_result == 1, on_tensor, off_tensor)
+
+
+def _jax_roll(a, shift, axis=None):
   """JAX implementation for roll."""
-
-
-  )
+  import jax.numpy as jnp
+
+  return jnp.roll(a, shift, axis=axis)
+
+
+def _tf_roll(a, shift: Sequence[int], axis=None):
+  """TensorFlow implementation for roll that handles axis=None."""
+  import tensorflow as tf
+
+  if axis is None:
+    original_shape = tf.shape(a)
+    flat_tensor = tf.reshape(a, [-1])
+    rolled_flat = tf.roll(flat_tensor, shift=shift, axis=0)
+    return tf.reshape(rolled_flat, original_shape)
+  return tf.roll(a, shift, axis=axis)
 
 
 def _jax_enable_op_determinism():
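Note: _tf_roll exists because tf.roll has no axis=None form, while np.roll treats a missing axis as flatten, roll, reshape. A small equivalence sketch, assuming TensorFlow is installed:

import numpy as np
import tensorflow as tf

x = np.arange(6).reshape(2, 3)

# What the wrapper does for axis=None: flatten, roll, restore the shape.
flat = tf.reshape(x, [-1])
rolled = tf.reshape(tf.roll(flat, shift=2, axis=0), tf.shape(x))

# Matches NumPy's axis=None semantics.
np.testing.assert_array_equal(rolled.numpy(), np.roll(x, 2))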
@@ -772,7 +977,7 @@ if _BACKEND == config.Backend.JAX:
   newaxis = _ops.newaxis
   TensorShape = _jax_tensor_shape
   int32 = _ops.int32
-  string = np.
+  string = np.bytes_
 
   stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid
 
@@ -904,7 +1109,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   reduce_sum = _ops.reduce_sum
   repeat = _ops.repeat
   reshape = _ops.reshape
-  roll =
+  roll = _tf_roll
   set_random_seed = tf_backend.keras.utils.set_random_seed
   split = _ops.split
   stack = _ops.stack
meridian/backend/test_utils.py
CHANGED
@@ -14,12 +14,25 @@
 
 """Common testing utilities for Meridian, designed to be backend-agnostic."""
 
+import dataclasses
 from typing import Any, Optional
+
 from absl.testing import parameterized
+from google.protobuf import descriptor
+from google.protobuf import message
 from meridian import backend
 from meridian.backend import config
 import numpy as np
 
+from tensorflow.python.util.protobuf import compare
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.core.framework import tensor_pb2
+
+
+# pylint: enable=g-direct-tensorflow-import
+
+FieldDescriptor = descriptor.FieldDescriptor
+
 # A type alias for backend-agnostic array-like objects.
 # We use `Any` here to avoid circular dependencies with the backend module
 # while still allowing the function to accept backend-specific tensor types.
@@ -70,6 +83,75 @@ def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
   np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)
 
 
+def assert_deep_equals(
+    test_case,
+    obj1: Any,
+    obj2: Any,
+    msg: str = "",
+    rtol: float = 1e-5,
+    atol: float = 1e-5,
+):
+  """Recursive equality check handling Dataclasses, Lists, and Backend Tensors.
+
+  Args:
+    test_case: The unittest.TestCase instance (self) to use for assertions.
+    obj1: The first object to compare.
+    obj2: The second object to compare.
+    msg: Optional error message prefix.
+    rtol: Relative tolerance for float comparison.
+    atol: Absolute tolerance for float comparison.
+  """
+  if obj1 is None or obj2 is None:
+    test_case.assertEqual(obj1, obj2, msg=msg)
+    return
+
+  if (
+      hasattr(obj1, "__array__")
+      or hasattr(obj1, "numpy")
+      or isinstance(obj1, (np.ndarray, backend.Tensor))
+  ):
+    arr1 = np.array(obj1)
+    arr2 = np.array(obj2)
+
+    # Check for non-numeric types where atol/rtol don't apply
+    if arr1.dtype.kind in ("U", "S", "O", "b"):
+      np.testing.assert_array_equal(arr1, arr2, err_msg=msg)
+    else:
+      np.testing.assert_allclose(arr1, arr2, err_msg=msg, rtol=rtol, atol=atol)
+    return
+
+  if dataclasses.is_dataclass(obj1):
+    test_case.assertIs(
+        type(obj1),
+        type(obj2),
+        msg=f"{msg} Type mismatch: {type(obj1)} vs {type(obj2)}",
+    )
+    for field in dataclasses.fields(obj1):
+      val1 = getattr(obj1, field.name)
+      val2 = getattr(obj2, field.name)
+      assert_deep_equals(
+          test_case,
+          val1,
+          val2,
+          msg=f"{msg}.{field.name}",
+          rtol=rtol,
+          atol=atol,
+      )
+    return
+
+  if isinstance(obj1, (list, tuple)):
+    test_case.assertIsInstance(obj2, (list, tuple), msg=f"{msg} Type mismatch")
+    test_case.assertEqual(len(obj1), len(obj2), msg=f"{msg} Length mismatch")
+    for i, (item1, item2) in enumerate(zip(obj1, obj2)):
+      assert_deep_equals(
+          test_case, item1, item2, msg=f"{msg}[{i}]", rtol=rtol, atol=atol
+      )
+    return
+
+  # Fallback to standard equality for primitives (int, str, float, etc.)
+  test_case.assertEqual(obj1, obj2, msg=msg)
+
+
 def assert_seed_allequal(a: Any, b: Any, err_msg: str = ""):
   """Backend-agnostic assertion to check if two seed objects are equal."""
   data_a = backend.get_seed_data(a)
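Note: a usage sketch for assert_deep_equals; the dataclass and values are hypothetical, for illustration only:

import dataclasses
import numpy as np
from absl.testing import parameterized
from meridian.backend import test_utils


@dataclasses.dataclass
class FitResult:  # hypothetical container
  name: str
  coefficients: np.ndarray


class FitResultTest(parameterized.TestCase):

  def test_results_match_within_tolerance(self):
    a = FitResult("geo", np.array([1.0, 2.0]))
    b = FitResult("geo", np.array([1.0, 2.0 + 1e-7]))
    # Recurses through dataclass fields: arrays are compared with rtol/atol,
    # strings are compared exactly.
    test_utils.assert_deep_equals(self, a, b)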
@@ -131,6 +213,118 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
     raise AssertionError(err_msg or "Array contains negative values.")
 
 
+# --- Proto Utilities ---
+def normalize_tensor_protos(proto: message.Message):
+  """Recursively normalizes TensorProto messages within a proto (In-place).
+
+  This ensures a consistent serialization format across different backends
+  (e.g., JAX vs TF) by repacking TensorProtos using the current backend's
+  canonical method (backend.make_tensor_proto). This handles differences
+  like using `bool_val` versus `tensor_content` for boolean tensors.
+
+  Args:
+    proto: The protobuf message object to normalize. This object is modified in
+      place.
+  """
+  if not isinstance(proto, message.Message):
+    return
+
+  for desc, value in proto.ListFields():
+    if desc.type != FieldDescriptor.TYPE_MESSAGE:
+      continue
+
+    # A map is defined as a repeated field whose message type has the
+    # map_entry option set.
+    is_map = (
+        desc.label == FieldDescriptor.LABEL_REPEATED
+        and desc.message_type.has_options
+        and desc.message_type.GetOptions().map_entry
+    )
+
+    if is_map:
+      for item in value.values():
+        # Helper checks if values are scalars or messages.
+        _process_message_for_normalization(item)
+
+    elif desc.label == FieldDescriptor.LABEL_REPEATED:
+      # Handle standard repeated message fields.
+      for item in value:
+        _process_message_for_normalization(item)
+    else:
+      # Handle singular message fields.
+      _process_message_for_normalization(value)
+
+
+def _process_message_for_normalization(msg: Any):
+  """Helper to process a potential message during normalization traversal."""
+  # Ensure we only process message objects.
+  # If msg is a scalar (e.g., string from map<string, string>), stop recursion.
+  if not isinstance(msg, message.Message):
+    return
+
+  if isinstance(msg, tensor_pb2.TensorProto):
+    _repack_tensor_proto(msg)
+  else:
+    # If it's another message type, recurse into its fields.
+    normalize_tensor_protos(msg)
+
+
+def _repack_tensor_proto(tensor_proto: "tensor_pb2.TensorProto"):
+  """Repacks a TensorProto in place to use a consistent serialization format."""
+  if not tensor_proto.ByteSize():
+    return
+
+  try:
+    data_array = backend.make_ndarray(tensor_proto)
+  except Exception as e:
+    raise ValueError(
+        "Failed to deserialize TensorProto during normalization:"
+        f" {e}\nProto content:\n{tensor_proto}"
+    ) from e
+
+  new_tensor_proto = backend.make_tensor_proto(data_array)
+
+  tensor_proto.Clear()
+  tensor_proto.CopyFrom(new_tensor_proto)
+
+
+def assert_normalized_proto_equal(
+    test_case: parameterized.TestCase,
+    expected: message.Message,
+    actual: message.Message,
+    msg: Optional[str] = None,
+    **kwargs: Any,
+):
+  """Compares two protos after normalizing TensorProto fields.
+
+  Use this instead of compare.assertProtoEqual when protos contain tensors
+  to ensure backend-agnostic comparison.
+
+  Args:
+    test_case: The TestCase instance (self).
+    expected: The expected protobuf message.
+    actual: The actual protobuf message.
+    msg: An optional message to display on failure.
+    **kwargs: Additional keyword arguments passed to assertProto2Equal (e.g.,
+      precision).
+  """
+  # Work on copies to avoid mutating the original objects
+  expected_copy = expected.__class__()
+  expected_copy.CopyFrom(expected)
+  actual_copy = actual.__class__()
+  actual_copy.CopyFrom(actual)
+
+  try:
+    normalize_tensor_protos(expected_copy)
+    normalize_tensor_protos(actual_copy)
+  except ValueError as e:
+    test_case.fail(f"Proto normalization failed: {e}. {msg}")
+
+  compare.assertProtoEqual(
+      test_case, expected_copy, actual_copy, msg=msg, **kwargs
+  )
+
+
 class MeridianTestCase(parameterized.TestCase):
   """Base test class for Meridian providing backend-aware utilities.
 
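Note: normalization is needed because proto allows the same tensor to be encoded two ways, and a byte-level comparison would flag them as different. A minimal sketch constructing both encodings by hand (import paths as used in the diff above):

import numpy as np
from meridian import backend
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2

shape = tensor_shape_pb2.TensorShapeProto(
    dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=2)]
)
# Encoding 1: values in the typed repeated field.
via_val = tensor_pb2.TensorProto(
    dtype=types_pb2.DT_BOOL, tensor_shape=shape, bool_val=[True, False]
)
# Encoding 2: the same values packed into raw bytes.
via_content = tensor_pb2.TensorProto(
    dtype=types_pb2.DT_BOOL,
    tensor_shape=shape,
    tensor_content=np.array([True, False]).tobytes(),
)

# Different wire bytes, same logical tensor: repacking both through
# make_ndarray/make_tensor_proto is what makes them comparable.
assert via_val.SerializeToString() != via_content.SerializeToString()
np.testing.assert_array_equal(
    backend.make_ndarray(via_val), backend.make_ndarray(via_content)
)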
meridian/constants.py
CHANGED
meridian/data/load.py
CHANGED
meridian/model/eda/__init__.py
CHANGED
meridian/model/eda/constants.py
CHANGED
@@ -14,8 +14,18 @@
 
 """Constants specific to MeridianEDA."""
 
-# EDA
+# EDA Engine constants
+COST_PER_MEDIA_UNIT = 'cost_per_media_unit'
+
 VARIABLE_1 = 'var1'
 VARIABLE_2 = 'var2'
-VARIABLE = 'var'
 CORRELATION = 'correlation'
+ABS_CORRELATION_COL_NAME = 'abs_correlation'
+
+# EDA Plotting properties
+VARIABLE = 'var'
+VALUE = 'value'
+NATIONALIZE = 'nationalize'
+MEDIA_IMPRESSIONS_SCALED = 'media_impressions_scaled'
+IMPRESSION_SHARE_SCALED = 'impression_share_scaled'
+SPEND_SHARE = 'spend_share'