google-meridian 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. google_meridian-1.3.2.dist-info/METADATA +209 -0
  2. google_meridian-1.3.2.dist-info/RECORD +76 -0
  3. {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +1 -2
  5. meridian/analysis/analyzer.py +0 -1
  6. meridian/analysis/optimizer.py +5 -3
  7. meridian/analysis/review/checks.py +81 -30
  8. meridian/analysis/review/constants.py +4 -0
  9. meridian/analysis/review/results.py +40 -9
  10. meridian/analysis/summarizer.py +1 -1
  11. meridian/analysis/visualizer.py +1 -1
  12. meridian/backend/__init__.py +229 -24
  13. meridian/backend/test_utils.py +194 -0
  14. meridian/constants.py +1 -0
  15. meridian/data/load.py +2 -0
  16. meridian/model/eda/__init__.py +0 -1
  17. meridian/model/eda/constants.py +12 -2
  18. meridian/model/eda/eda_engine.py +353 -45
  19. meridian/model/eda/eda_outcome.py +21 -1
  20. meridian/model/knots.py +17 -0
  21. meridian/model/model_test_data.py +15 -0
  22. meridian/{analysis/templates → templates}/card.html.jinja +1 -1
  23. meridian/{analysis/templates → templates}/chart.html.jinja +1 -1
  24. meridian/{analysis/templates → templates}/chips.html.jinja +1 -1
  25. meridian/{analysis → templates}/formatter.py +12 -1
  26. meridian/templates/formatter_test.py +216 -0
  27. meridian/{analysis/templates → templates}/insights.html.jinja +1 -1
  28. meridian/{analysis/templates → templates}/stats.html.jinja +1 -1
  29. meridian/{analysis/templates → templates}/style.css +1 -1
  30. meridian/{analysis/templates → templates}/style.scss +1 -1
  31. meridian/{analysis/templates → templates}/summary.html.jinja +4 -2
  32. meridian/{analysis/templates → templates}/table.html.jinja +1 -1
  33. meridian/version.py +1 -1
  34. schema/__init__.py +30 -0
  35. schema/serde/__init__.py +26 -0
  36. schema/serde/constants.py +48 -0
  37. schema/serde/distribution.py +515 -0
  38. schema/serde/eda_spec.py +192 -0
  39. schema/serde/function_registry.py +143 -0
  40. schema/serde/hyperparameters.py +363 -0
  41. schema/serde/inference_data.py +105 -0
  42. schema/serde/marketing_data.py +1321 -0
  43. schema/serde/meridian_serde.py +413 -0
  44. schema/serde/serde.py +47 -0
  45. schema/serde/test_data.py +4608 -0
  46. schema/utils/__init__.py +17 -0
  47. schema/utils/time_record.py +156 -0
  48. google_meridian-1.3.0.dist-info/METADATA +0 -409
  49. google_meridian-1.3.0.dist-info/RECORD +0 -62
  50. meridian/model/eda/meridian_eda.py +0 -220
  51. {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/WHEEL +0 -0
  52. {google_meridian-1.3.0.dist-info → google_meridian-1.3.2.dist-info}/licenses/LICENSE +0 -0
meridian/backend/__init__.py CHANGED
@@ -19,6 +19,7 @@ import functools
 import os
 from typing import Any, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 import warnings
+
 from meridian.backend import config
 import numpy as np
 from typing_extensions import Literal
@@ -220,7 +221,7 @@ def _tf_arange(

 def _jax_cast(x: Any, dtype: Any) -> "_jax.Array":
   """JAX implementation for cast."""
-  return x.astype(dtype)
+  return jax_ops.asarray(x, dtype=dtype)


 def _jax_divide_no_nan(x, y):
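
The `_jax_cast` change swaps `x.astype(dtype)` for `jax_ops.asarray(x, dtype=dtype)`, which also accepts inputs that have no `.astype` method, such as Python scalars and lists. A minimal sketch of the difference, using plain `jax.numpy` in place of the module-internal `jax_ops` alias:

```python
import jax.numpy as jnp

# asarray-based casting handles any array-like input:
print(jnp.asarray(3, dtype=jnp.float32))       # 3.0
print(jnp.asarray([1, 2], dtype=jnp.float32))  # [1. 2.]

# The old astype-based path required an input with an .astype method;
# a plain Python int or list would raise AttributeError.
```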
@@ -305,17 +306,132 @@ def _jax_numpy_function(*args, **kwargs):  # pylint: disable=unused-argument
   )


-def _jax_make_tensor_proto(*args, **kwargs):  # pylint: disable=unused-argument
-  raise NotImplementedError(
-      "backend.make_tensor_proto is not implemented for the JAX backend."
-  )
+def _jax_make_tensor_proto(values, dtype=None, shape=None):  # pylint: disable=unused-argument
+  """JAX implementation for make_tensor_proto."""
+  # pylint: disable=g-direct-tensorflow-import
+  from tensorflow.core.framework import tensor_pb2
+  from tensorflow.core.framework import tensor_shape_pb2
+  from tensorflow.core.framework import types_pb2
+  # pylint: enable=g-direct-tensorflow-import

+  if not isinstance(values, np.ndarray):
+    values = np.array(values)

-def _jax_make_ndarray(*args, **kwargs):  # pylint: disable=unused-argument
-  raise NotImplementedError(
-      "backend.make_ndarray is not implemented for the JAX backend."
+  if dtype:
+    numpy_dtype = np.dtype(dtype)
+    values = values.astype(numpy_dtype)
+  else:
+    numpy_dtype = values.dtype
+
+  dtype_map = {
+      np.dtype(np.float16): types_pb2.DT_HALF,
+      np.dtype(np.float32): types_pb2.DT_FLOAT,
+      np.dtype(np.float64): types_pb2.DT_DOUBLE,
+      np.dtype(np.int32): types_pb2.DT_INT32,
+      np.dtype(np.uint8): types_pb2.DT_UINT8,
+      np.dtype(np.uint16): types_pb2.DT_UINT16,
+      np.dtype(np.uint32): types_pb2.DT_UINT32,
+      np.dtype(np.uint64): types_pb2.DT_UINT64,
+      np.dtype(np.int16): types_pb2.DT_INT16,
+      np.dtype(np.int8): types_pb2.DT_INT8,
+      np.dtype(np.int64): types_pb2.DT_INT64,
+      np.dtype(np.complex64): types_pb2.DT_COMPLEX64,
+      np.dtype(np.complex128): types_pb2.DT_COMPLEX128,
+      np.dtype(np.bool_): types_pb2.DT_BOOL,
+      # Note: String types are handled outside the map.
+  }
+  proto_dtype = dtype_map.get(numpy_dtype)
+  if proto_dtype is None and numpy_dtype.kind in ("S", "U"):
+    proto_dtype = types_pb2.DT_STRING
+
+  if proto_dtype is None:
+    raise TypeError(
+        f"Unsupported dtype for TensorProto conversion: {numpy_dtype}"
+    )
+
+  proto = tensor_pb2.TensorProto(
+      dtype=proto_dtype,
+      tensor_shape=tensor_shape_pb2.TensorShapeProto(
+          dim=[
+              tensor_shape_pb2.TensorShapeProto.Dim(size=d)
+              for d in values.shape
+          ]
+      ),
   )

+  proto.tensor_content = values.tobytes()
+  return proto
+
+
+def _jax_make_ndarray(proto):
+  """JAX implementation for make_ndarray."""
+  # pylint: disable=g-direct-tensorflow-import
+  from tensorflow.core.framework import types_pb2
+  # pylint: enable=g-direct-tensorflow-import
+
+  dtype_map = {
+      types_pb2.DT_HALF: np.float16,
+      types_pb2.DT_FLOAT: np.float32,
+      types_pb2.DT_DOUBLE: np.float64,
+      types_pb2.DT_INT32: np.int32,
+      types_pb2.DT_UINT8: np.uint8,
+      types_pb2.DT_UINT16: np.uint16,
+      types_pb2.DT_UINT32: np.uint32,
+      types_pb2.DT_UINT64: np.uint64,
+      types_pb2.DT_INT16: np.int16,
+      types_pb2.DT_INT8: np.int8,
+      types_pb2.DT_INT64: np.int64,
+      types_pb2.DT_COMPLEX64: np.complex64,
+      types_pb2.DT_COMPLEX128: np.complex128,
+      types_pb2.DT_BOOL: np.bool_,
+      types_pb2.DT_STRING: np.bytes_,
+  }
+  if proto.dtype not in dtype_map:
+    raise TypeError(f"Unsupported TensorProto dtype: {proto.dtype}")
+
+  shape = [d.size for d in proto.tensor_shape.dim]
+  dtype = dtype_map[proto.dtype]
+
+  if proto.tensor_content:
+    num_elements = np.prod(shape).item() if shape else 0
+    # When deserializing a string from tensor_content, the itemsize is not
+    # explicitly stored. We must infer it from the content length and shape.
+    if dtype == np.bytes_ and num_elements > 0:
+      content_len = len(proto.tensor_content)
+      itemsize = content_len // num_elements
+      if itemsize * num_elements != content_len:
+        raise ValueError(
+            "Tensor content size is not a multiple of the number of elements"
+            " for string dtype."
+        )
+      dtype = np.dtype(f"S{itemsize}")
+
+    return (
+        np.frombuffer(proto.tensor_content, dtype=dtype).copy().reshape(shape)
+    )
+
+  # Fallback for protos that store data in val fields instead of tensor_content.
+  if dtype == np.float32:
+    val_field = proto.float_val
+  elif dtype == np.float64:
+    val_field = proto.double_val
+  elif dtype == np.int32:
+    val_field = proto.int_val
+  elif dtype == np.int64:
+    val_field = proto.int64_val
+  elif dtype == np.bool_:
+    val_field = proto.bool_val
+  else:
+    if proto.string_val:
+      return np.array(proto.string_val, dtype=np.bytes_).reshape(shape)
+    if not any(shape):
+      return np.array([], dtype=dtype).reshape(shape)
+    raise TypeError(
+        f"Unsupported dtype for TensorProto value field fallback: {dtype}"
+    )
+
+  return np.array(val_field, dtype=dtype).reshape(shape)
+

 def _jax_get_indices_where(condition):
   """JAX implementation for get_indices_where."""
@@ -493,16 +609,75 @@ def _tf_get_seed_data(seed: Any) -> Optional[np.ndarray]:


 def _jax_convert_to_tensor(data, dtype=None):
-  """Converts data to a JAX array, handling strings as NumPy arrays."""
+  """Converts data to a JAX array, handling strings as NumPy arrays.
+
+  This function explicitly unwraps objects with a `.values` attribute (e.g.,
+  pandas.DataFrame, xarray.DataArray) to access the underlying NumPy array,
+  provided that `.values` is not a method. This takes precedence over the
+  `__array__` protocol.
+
+  It also handles precision mismatches: if `data` is float64 and `dtype` is
+  not specified, and JAX x64 mode is disabled (default), it issues a warning
+  and explicitly casts to float32 to match the backend default and prevent
+  silent precision loss or type errors in downstream operations.
+
+  Args:
+    data: The data to convert.
+    dtype: The desired data type.
+
+  Returns:
+    A JAX array, or a NumPy array if the dtype is a string type.
+  """
+  # Unwrap xarray.DataArray, pandas.Series, and pandas.DataFrame objects.
+  # These objects wrap the underlying NumPy array in a .values attribute.
+  if hasattr(data, "values") and not callable(data.values):
+    data = data.values
+
+  # Convert to numpy array upfront to simplify dtype inspection below.
+  # A standard Python float is 64-bit, and this conversion allows the
+  # subsequent logic to correctly detect and handle potential float64
+  # downcasting for scalar inputs.
+  if isinstance(data, (list, tuple, float)):
+    data = np.array(data)
+
   # JAX does not natively support string tensors in the same way TF does.
   # If a string dtype is requested, or if the data is inherently strings,
   # we fall back to a standard NumPy array.
-  if dtype == np.str_ or (
-      dtype is None
-      and isinstance(data, (list, np.ndarray))
-      and np.array(data).dtype.kind in ("S", "U")
-  ):
-    return np.array(data, dtype=np.str_)
+  is_string_target = False
+  if dtype is not None:
+    try:
+      if np.dtype(dtype).kind in ("S", "U"):
+        is_string_target = True
+    except TypeError:
+      # This can happen if dtype is not a valid dtype specifier;
+      # let jax.asarray handle it.
+      pass
+
+  is_string_data = isinstance(data, np.ndarray) and data.dtype.kind in (
+      "S",
+      "U",
+  )
+
+  if is_string_target:
+    return np.array(data, dtype=dtype)
+
+  if dtype is None and is_string_data:
+    return data
+
+  # If the user provides float64 data but does not request a specific dtype,
+  # and JAX 64-bit mode is disabled (default), JAX would implicitly truncate.
+  # We cast to float32 and warn the user to prevent silent mismatches.
+  if dtype is None:
+    is_float64_input = hasattr(data, "dtype") and data.dtype == np.float64
+    if is_float64_input:
+      if not jax.config.jax_enable_x64:
+        warnings.warn(
+            "Input data is float64. Casting to float32 to match backend "
+            "default precision.",
+            UserWarning,
+        )
+        dtype = jax_ops.float32
+
   return jax_ops.asarray(data, dtype=dtype)

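
The float64 handling above guards against a well-known JAX default: with `jax_enable_x64` off, `jnp.asarray` silently downcasts float64 input to float32. A small illustration of the pitfall itself:

```python
import numpy as np
import jax.numpy as jnp

data = np.array([1.0, 2.0])     # NumPy defaults to float64
print(data.dtype)               # float64
print(jnp.asarray(data).dtype)  # float32 -- silent truncation by default;
                                # the wrapper above now warns before casting.
```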
@@ -535,18 +710,48 @@ def _tf_nanvar(a, axis=None, keepdims=False):
   return tf.convert_to_tensor(var)


-def _jax_one_hot(*args, **kwargs):  # pylint: disable=unused-argument
+def _jax_one_hot(
+    indices, depth, on_value=None, off_value=None, axis=None, dtype=None
+):
   """JAX implementation for one_hot."""
-  raise NotImplementedError(
-      "backend.one_hot is not implemented for the JAX backend."
+  import jax.numpy as jnp
+
+  resolved_dtype = _resolve_dtype(dtype, on_value, off_value, 1, 0)
+  jax_axis = -1 if axis is None else axis
+
+  one_hot_result = jax.nn.one_hot(
+      indices, num_classes=depth, dtype=jnp.dtype(resolved_dtype), axis=jax_axis
   )

+  on_val = 1 if on_value is None else on_value
+  off_val = 0 if off_value is None else off_value
+
+  if on_val == 1 and off_val == 0:
+    return one_hot_result

-def _jax_roll(*args, **kwargs):  # pylint: disable=unused-argument
+  on_tensor = jnp.array(on_val, dtype=jnp.dtype(resolved_dtype))
+  off_tensor = jnp.array(off_val, dtype=jnp.dtype(resolved_dtype))
+
+  return jnp.where(one_hot_result == 1, on_tensor, off_tensor)
+
+
+def _jax_roll(a, shift, axis=None):
   """JAX implementation for roll."""
-  raise NotImplementedError(
-      "backend.roll is not implemented for the JAX backend."
-  )
+  import jax.numpy as jnp
+
+  return jnp.roll(a, shift, axis=axis)
+
+
+def _tf_roll(a, shift: Sequence[int], axis=None):
+  """TensorFlow implementation for roll that handles axis=None."""
+  import tensorflow as tf
+
+  if axis is None:
+    original_shape = tf.shape(a)
+    flat_tensor = tf.reshape(a, [-1])
+    rolled_flat = tf.roll(flat_tensor, shift=shift, axis=0)
+    return tf.reshape(rolled_flat, original_shape)
+  return tf.roll(a, shift, axis=axis)


 def _jax_enable_op_determinism():
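
The custom on/off handling in `_jax_one_hot` mirrors `tf.one_hot`'s `on_value`/`off_value` parameters, which `jax.nn.one_hot` lacks. A standalone sketch of the same where-based substitution (the values are illustrative, not from the diff):

```python
import jax
import jax.numpy as jnp

indices = jnp.array([0, 2, 1])
base = jax.nn.one_hot(indices, num_classes=3)  # rows of 1.0 / 0.0
custom = jnp.where(base == 1, jnp.array(5.0), jnp.array(-1.0))
# custom:
# [[ 5. -1. -1.]
#  [-1. -1.  5.]
#  [-1.  5. -1.]]
```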
@@ -772,7 +977,7 @@ if _BACKEND == config.Backend.JAX:
   newaxis = _ops.newaxis
   TensorShape = _jax_tensor_shape
   int32 = _ops.int32
-  string = np.str_
+  string = np.bytes_

   stabilize_rf_roi_grid = _jax_stabilize_rf_roi_grid

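
Switching `string` from `np.str_` to `np.bytes_` keeps the alias consistent with `_jax_make_ndarray`, which decodes `DT_STRING` tensor content as fixed-width byte strings. A plain-NumPy sketch of why the itemsize has to be inferred on the way back:

```python
import numpy as np

arr = np.array([b"ab", b"cd"])  # fixed-width dtype S2
raw = arr.tobytes()             # b"abcd" -- the itemsize is not stored
restored = np.frombuffer(raw, dtype=np.dtype("S2")).reshape(arr.shape)
assert (restored == arr).all()  # itemsize = len(raw) // num_elements, as above
```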
@@ -904,7 +1109,7 @@ elif _BACKEND == config.Backend.TENSORFLOW:
   reduce_sum = _ops.reduce_sum
   repeat = _ops.repeat
   reshape = _ops.reshape
-  roll = _ops.roll
+  roll = _tf_roll
   set_random_seed = tf_backend.keras.utils.set_random_seed
   split = _ops.split
   stack = _ops.stack
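
Re-pointing the TensorFlow backend's `roll` to `_tf_roll` adds NumPy-style `axis=None` semantics (flatten, roll, reshape), which bare `tf.roll` does not accept. The NumPy reference behavior being matched:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
print(np.roll(x, 1))           # axis=None: [[5 0 1], [2 3 4]] (flattened roll)
print(np.roll(x, 1, axis=1))   # per-row:   [[2 0 1], [5 3 4]]
```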
meridian/backend/test_utils.py CHANGED
@@ -14,12 +14,25 @@

 """Common testing utilities for Meridian, designed to be backend-agnostic."""

+import dataclasses
 from typing import Any, Optional
+
 from absl.testing import parameterized
+from google.protobuf import descriptor
+from google.protobuf import message
 from meridian import backend
 from meridian.backend import config
 import numpy as np

+from tensorflow.python.util.protobuf import compare
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.core.framework import tensor_pb2
+
+
+# pylint: enable=g-direct-tensorflow-import
+
+FieldDescriptor = descriptor.FieldDescriptor
+
 # A type alias for backend-agnostic array-like objects.
 # We use `Any` here to avoid circular dependencies with the backend module
 # while still allowing the function to accept backend-specific tensor types.
@@ -70,6 +83,75 @@ def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
   np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)


+def assert_deep_equals(
+    test_case,
+    obj1: Any,
+    obj2: Any,
+    msg: str = "",
+    rtol: float = 1e-5,
+    atol: float = 1e-5,
+):
+  """Recursive equality check handling Dataclasses, Lists, and Backend Tensors.
+
+  Args:
+    test_case: The unittest.TestCase instance (self) to use for assertions.
+    obj1: The first object to compare.
+    obj2: The second object to compare.
+    msg: Optional error message prefix.
+    rtol: Relative tolerance for float comparison.
+    atol: Absolute tolerance for float comparison.
+  """
+  if obj1 is None or obj2 is None:
+    test_case.assertEqual(obj1, obj2, msg=msg)
+    return
+
+  if (
+      hasattr(obj1, "__array__")
+      or hasattr(obj1, "numpy")
+      or isinstance(obj1, (np.ndarray, backend.Tensor))
+  ):
+    arr1 = np.array(obj1)
+    arr2 = np.array(obj2)
+
+    # Check for non-numeric types where atol/rtol don't apply
+    if arr1.dtype.kind in ("U", "S", "O", "b"):
+      np.testing.assert_array_equal(arr1, arr2, err_msg=msg)
+    else:
+      np.testing.assert_allclose(arr1, arr2, err_msg=msg, rtol=rtol, atol=atol)
+    return
+
+  if dataclasses.is_dataclass(obj1):
+    test_case.assertIs(
+        type(obj1),
+        type(obj2),
+        msg=f"{msg} Type mismatch: {type(obj1)} vs {type(obj2)}",
+    )
+    for field in dataclasses.fields(obj1):
+      val1 = getattr(obj1, field.name)
+      val2 = getattr(obj2, field.name)
+      assert_deep_equals(
+          test_case,
+          val1,
+          val2,
+          msg=f"{msg}.{field.name}",
+          rtol=rtol,
+          atol=atol,
+      )
+    return
+
+  if isinstance(obj1, (list, tuple)):
+    test_case.assertIsInstance(obj2, (list, tuple), msg=f"{msg} Type mismatch")
+    test_case.assertEqual(len(obj1), len(obj2), msg=f"{msg} Length mismatch")
+    for i, (item1, item2) in enumerate(zip(obj1, obj2)):
+      assert_deep_equals(
+          test_case, item1, item2, msg=f"{msg}[{i}]", rtol=rtol, atol=atol
+      )
+    return
+
+  # Fallback to standard equality for primitives (int, str, float, etc.)
+  test_case.assertEqual(obj1, obj2, msg=msg)
+
+
 def assert_seed_allequal(a: Any, b: Any, err_msg: str = ""):
   """Backend-agnostic assertion to check if two seed objects are equal."""
   data_a = backend.get_seed_data(a)
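
A hedged usage sketch for `assert_deep_equals` inside a test case; the dataclass and values below are invented for illustration:

```python
import dataclasses

import numpy as np
from absl.testing import parameterized
from meridian.backend import test_utils


@dataclasses.dataclass
class FitResult:  # hypothetical container, not part of Meridian
  name: str
  values: np.ndarray


class FitResultTest(parameterized.TestCase):

  def test_deep_equals(self):
    a = FitResult("fit", np.array([1.0, 2.0]))
    b = FitResult("fit", np.array([1.0, 2.0 + 1e-7]))  # within default rtol
    # Recurses through dataclass fields; arrays compare with allclose.
    test_utils.assert_deep_equals(self, a, b)
```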
@@ -131,6 +213,118 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
     raise AssertionError(err_msg or "Array contains negative values.")


+# --- Proto Utilities ---
+def normalize_tensor_protos(proto: message.Message):
+  """Recursively normalizes TensorProto messages within a proto (in place).
+
+  This ensures a consistent serialization format across different backends
+  (e.g., JAX vs TF) by repacking TensorProtos using the current backend's
+  canonical method (backend.make_tensor_proto). This handles differences
+  like using `bool_val` versus `tensor_content` for boolean tensors.
+
+  Args:
+    proto: The protobuf message object to normalize. This object is modified
+      in place.
+  """
+  if not isinstance(proto, message.Message):
+    return
+
+  for desc, value in proto.ListFields():
+    if desc.type != FieldDescriptor.TYPE_MESSAGE:
+      continue
+
+    # A map is defined as a repeated field whose message type has the
+    # map_entry option set.
+    is_map = (
+        desc.label == FieldDescriptor.LABEL_REPEATED
+        and desc.message_type.has_options
+        and desc.message_type.GetOptions().map_entry
+    )
+
+    if is_map:
+      for item in value.values():
+        # Helper checks if values are scalars or messages.
+        _process_message_for_normalization(item)
+
+    elif desc.label == FieldDescriptor.LABEL_REPEATED:
+      # Handle standard repeated message fields.
+      for item in value:
+        _process_message_for_normalization(item)
+    else:
+      # Handle singular message fields.
+      _process_message_for_normalization(value)
+
+
+def _process_message_for_normalization(msg: Any):
+  """Helper to process a potential message during normalization traversal."""
+  # Ensure we only process message objects.
+  # If msg is a scalar (e.g., string from map<string, string>), stop recursion.
+  if not isinstance(msg, message.Message):
+    return
+
+  if isinstance(msg, tensor_pb2.TensorProto):
+    _repack_tensor_proto(msg)
+  else:
+    # If it's another message type, recurse into its fields.
+    normalize_tensor_protos(msg)
+
+
+def _repack_tensor_proto(tensor_proto: "tensor_pb2.TensorProto"):
+  """Repacks a TensorProto in place to use a consistent serialization format."""
+  if not tensor_proto.ByteSize():
+    return
+
+  try:
+    data_array = backend.make_ndarray(tensor_proto)
+  except Exception as e:
+    raise ValueError(
+        "Failed to deserialize TensorProto during normalization:"
+        f" {e}\nProto content:\n{tensor_proto}"
+    ) from e
+
+  new_tensor_proto = backend.make_tensor_proto(data_array)
+
+  tensor_proto.Clear()
+  tensor_proto.CopyFrom(new_tensor_proto)
+
+
+def assert_normalized_proto_equal(
+    test_case: parameterized.TestCase,
+    expected: message.Message,
+    actual: message.Message,
+    msg: Optional[str] = None,
+    **kwargs: Any,
+):
+  """Compares two protos after normalizing TensorProto fields.
+
+  Use this instead of compare.assertProtoEqual when protos contain tensors
+  to ensure backend-agnostic comparison.
+
+  Args:
+    test_case: The TestCase instance (self).
+    expected: The expected protobuf message.
+    actual: The actual protobuf message.
+    msg: An optional message to display on failure.
+    **kwargs: Additional keyword arguments passed to assertProto2Equal (e.g.,
+      precision).
+  """
+  # Work on copies to avoid mutating the original objects.
+  expected_copy = expected.__class__()
+  expected_copy.CopyFrom(expected)
+  actual_copy = actual.__class__()
+  actual_copy.CopyFrom(actual)
+
+  try:
+    normalize_tensor_protos(expected_copy)
+    normalize_tensor_protos(actual_copy)
+  except ValueError as e:
+    test_case.fail(f"Proto normalization failed: {e}. {msg}")
+
+  compare.assertProtoEqual(
+      test_case, expected_copy, actual_copy, msg=msg, **kwargs
+  )
+
+
 class MeridianTestCase(parameterized.TestCase):
   """Base test class for Meridian providing backend-aware utilities.

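
A self-contained sketch of the encoding skew that `normalize_tensor_protos` smooths over: the same boolean tensor can legally be serialized via `bool_val` or via `tensor_content`, and the two protos compare unequal until both are repacked:

```python
import numpy as np
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2

shape = tensor_shape_pb2.TensorShapeProto(
    dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=2)]
)
via_val = tensor_pb2.TensorProto(
    dtype=types_pb2.DT_BOOL, tensor_shape=shape, bool_val=[True, False]
)
via_content = tensor_pb2.TensorProto(dtype=types_pb2.DT_BOOL, tensor_shape=shape)
via_content.tensor_content = np.array([True, False]).tobytes()

assert via_val != via_content  # same tensor, different wire fields
# normalize_tensor_protos repacks both through backend.make_tensor_proto,
# after which assert_normalized_proto_equal compares them field-for-field.
```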
meridian/constants.py CHANGED
@@ -662,6 +662,7 @@ CURRENT_SPEND = 'current_spend'

 # Media summary metrics.
 SPEND = 'spend'
+COST = 'cost'
 IMPRESSIONS = 'impressions'
 ROI = 'roi'
 OPTIMIZED_ROI = 'optimized_roi'
meridian/data/load.py CHANGED
@@ -35,6 +35,8 @@ __all__ = [
     'InputDataLoader',
     'XrDatasetDataLoader',
     'DataFrameDataLoader',
+    'CoordToColumns',
+    'CsvDataLoader',
 ]


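
With the expanded `__all__`, the loader entry points below are now part of the module's public surface; a minimal import check:

```python
# Both names resolve directly from the module as of this release.
from meridian.data.load import CoordToColumns, CsvDataLoader
```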
meridian/model/eda/__init__.py CHANGED
@@ -17,4 +17,3 @@
 from meridian.model.eda import eda_engine
 from meridian.model.eda import eda_outcome
 from meridian.model.eda import eda_spec
-from meridian.model.eda import meridian_eda
meridian/model/eda/constants.py CHANGED
@@ -14,8 +14,18 @@

 """Constants specific to MeridianEDA."""

-# EDA Plotting properties
+# EDA Engine constants
+COST_PER_MEDIA_UNIT = 'cost_per_media_unit'
+
 VARIABLE_1 = 'var1'
 VARIABLE_2 = 'var2'
-VARIABLE = 'var'
 CORRELATION = 'correlation'
+ABS_CORRELATION_COL_NAME = 'abs_correlation'
+
+# EDA Plotting properties
+VARIABLE = 'var'
+VALUE = 'value'
+NATIONALIZE = 'nationalize'
+MEDIA_IMPRESSIONS_SCALED = 'media_impressions_scaled'
+IMPRESSION_SHARE_SCALED = 'impression_share_scaled'
+SPEND_SHARE = 'spend_share'