google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -15,11 +15,21 @@
15
15
  """Common testing utilities for Meridian, designed to be backend-agnostic."""
16
16
 
17
17
  from typing import Any, Optional
18
+
18
19
  from absl.testing import parameterized
20
+ from google.protobuf import descriptor
21
+ from google.protobuf import message
19
22
  from meridian import backend
20
23
  from meridian.backend import config
21
24
  import numpy as np
22
25
 
26
+ from tensorflow.python.util.protobuf import compare
27
+ # pylint: disable=g-direct-tensorflow-import
28
+ from tensorflow.core.framework import tensor_pb2
29
+ # pylint: enable=g-direct-tensorflow-import
30
+
31
+ FieldDescriptor = descriptor.FieldDescriptor
32
+
23
33
  # A type alias for backend-agnostic array-like objects.
24
34
  # We use `Any` here to avoid circular dependencies with the backend module
25
35
  # while still allowing the function to accept backend-specific tensor types.
@@ -70,6 +80,39 @@ def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
70
80
  np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)
71
81
 
72
82
 
83
def assert_seed_allequal(a: Any, b: Any, err_msg: str = ""):
  """Backend-agnostic assertion to check if two seed objects are equal.

  Seeds are compared through the raw data returned by
  `backend.get_seed_data`. Two seeds whose data are both `None` are treated as
  equal.

  Args:
    a: The first seed object.
    b: The second seed object.
    err_msg: Optional context included in the failure message.

  Raises:
    AssertionError: If exactly one seed has `None` data, or if the seed data
      are not element-wise equal.
  """
  data_a = backend.get_seed_data(a)
  data_b = backend.get_seed_data(b)
  if data_a is None and data_b is None:
    return
  # Fail with a readable message instead of letting numpy compare `None`
  # against an array, which yields a confusing object-array mismatch report.
  if data_a is None or data_b is None:
    raise AssertionError(
        f"Seeds are unexpectedly different (exactly one is None). {err_msg}"
    )
  np.testing.assert_array_equal(data_a, data_b, err_msg=err_msg)
90
+
91
+
92
def assert_not_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
  """Asserts that two objects are not element-wise equal.

  This is the exact complement of running `np.testing.assert_array_equal` on
  `np.array(a)` and `np.array(b)`, the comparison `assert_allequal` performs.
  It fails precisely when that assertion would pass. In particular, NaNs in
  matching positions count as equal and scalars are broadcast. Pairs such as
  (`[nan]`, `[nan]`) or (`[1, 1]`, `1`) are therefore considered equal.

  Args:
    a: The first array-like object.
    b: The second array-like object.
    err_msg: Optional context appended to the failure message.

  Raises:
    AssertionError: If `a` and `b` are element-wise equal.
  """
  try:
    np.testing.assert_array_equal(np.array(a), np.array(b))
  except AssertionError:
    # The arrays differ (values, NaN placement, or incompatible shapes).
    return
  raise AssertionError(f"Arrays are unexpectedly equal.\n{err_msg}")
98
+
99
+
100
def assert_seed_not_allequal(a: Any, b: Any, err_msg: str = ""):
  """Fails if two seed objects carry identical seed data.

  The seed data of each object is obtained via `backend.get_seed_data`.
  Two seeds that both have `None` data are considered equal, so this fails.
  When exactly one of them has `None` data, they are considered different and
  this passes. Otherwise the data arrays are compared element-wise.

  Args:
    a: The first seed object.
    b: The second seed object.
    err_msg: Optional context included in the failure message.

  Raises:
    AssertionError: If the seeds are considered equal.
  """
  first, second = backend.get_seed_data(a), backend.get_seed_data(b)
  first_missing = first is None
  second_missing = second is None

  if first_missing and second_missing:
    raise AssertionError(
        f"Seeds are unexpectedly equal (both are None). {err_msg}"
    )
  if first_missing != second_missing:
    # Only one seed has data, so the two cannot be equal.
    return
  if np.array_equal(first, second):
    raise AssertionError(f"Seeds are unexpectedly equal.\n{err_msg}")
114
+
115
+
73
116
  def assert_all_finite(a: ArrayLike, err_msg: str = ""):
74
117
  """Backend-agnostic assertion to check if all elements in an array are finite.
75
118
 
@@ -98,6 +141,118 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
98
141
  raise AssertionError(err_msg or "Array contains negative values.")
99
142
 
100
143
 
144
+ # --- Proto Utilities ---
145
def normalize_tensor_protos(proto: message.Message):
  """Recursively normalizes TensorProto messages within a proto (in place).

  This ensures a consistent serialization format across different backends
  (e.g., JAX vs TF) by repacking TensorProtos using the current backend's
  canonical method (backend.make_tensor_proto). This handles differences
  like using `bool_val` versus `tensor_content` for boolean tensors.

  Only nested fields are visited; if `proto` is itself a TensorProto, the
  top-level message is not repacked.

  Args:
    proto: The protobuf message object to normalize. This object is modified in
      place. Non-message values are ignored.
  """
  if not isinstance(proto, message.Message):
    return

  for field_desc, value in proto.ListFields():
    # Both TYPE_MESSAGE and TYPE_GROUP (delimited-encoded) fields hold
    # sub-messages; checking the C++ type covers both kinds.
    if field_desc.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE:
      continue

    if field_desc.label != FieldDescriptor.LABEL_REPEATED:
      # Singular message field.
      _process_message_for_normalization(value)
      continue

    # A map is a repeated field whose entry message type has the `map_entry`
    # option set; its values live in `.values()` rather than the field itself.
    entry_type = field_desc.message_type
    is_map = entry_type.has_options and entry_type.GetOptions().map_entry
    items = value.values() if is_map else value
    for item in items:
      # The helper skips non-message map values (e.g. map<string, string>).
      _process_message_for_normalization(item)
184
+
185
+
186
def _process_message_for_normalization(msg: Any):
  """Normalizes `msg` when it is a proto message; otherwise does nothing.

  A TensorProto is repacked directly via `_repack_tensor_proto`. Any other
  message is traversed with `normalize_tensor_protos`. Non-message values,
  such as the scalar values of a map<string, string>, end the recursion.

  Args:
    msg: A value taken from a proto field; it may or may not be a message.
  """
  if isinstance(msg, tensor_pb2.TensorProto):
    _repack_tensor_proto(msg)
  elif isinstance(msg, message.Message):
    # Some other message type: descend into its fields.
    normalize_tensor_protos(msg)
198
+
199
+
200
def _repack_tensor_proto(tensor_proto: "tensor_pb2.TensorProto"):
  """Repacks a TensorProto in place to use a consistent serialization format.

  The tensor is decoded with `backend.make_ndarray` and re-encoded with
  `backend.make_tensor_proto`. The stored encoding therefore becomes whatever
  the active backend produces. Protos with an empty serialization
  (`ByteSize() == 0`) are left untouched.

  Args:
    tensor_proto: The TensorProto to repack. Modified in place.

  Raises:
    ValueError: If the backend fails to decode or re-encode the proto.
  """
  if not tensor_proto.ByteSize():
    return

  try:
    data_array = backend.make_ndarray(tensor_proto)
  except Exception as e:
    raise ValueError(
        "Failed to deserialize TensorProto during normalization:"
        f" {e}\nProto content:\n{tensor_proto}"
    ) from e

  # Re-encoding failures are also surfaced as ValueError so that callers
  # (e.g. assert_normalized_proto_equal) can report them uniformly.
  try:
    new_tensor_proto = backend.make_tensor_proto(data_array)
  except Exception as e:
    raise ValueError(
        "Failed to re-serialize TensorProto during normalization:"
        f" {e}\nProto content:\n{tensor_proto}"
    ) from e

  # CopyFrom clears the destination before merging, so no explicit Clear().
  tensor_proto.CopyFrom(new_tensor_proto)
217
+
218
+
219
def assert_normalized_proto_equal(
    test_case: parameterized.TestCase,
    expected: message.Message,
    actual: message.Message,
    msg: Optional[str] = None,
    **kwargs: Any,
):
  """Compares two protos after normalizing TensorProto fields.

  Use this instead of compare.assertProtoEqual when protos contain tensors
  to ensure backend-agnostic comparison. `expected` and `actual` may
  themselves be TensorProtos; they are normalized just like nested ones.

  Args:
    test_case: The TestCase instance (self).
    expected: The expected protobuf message. Not modified.
    actual: The actual protobuf message. Not modified.
    msg: An optional message to display on failure.
    **kwargs: Additional keyword arguments forwarded to
      compare.assertProtoEqual (e.g., precision).
  """
  # Work on copies to avoid mutating the caller's messages.
  expected_copy = type(expected)()
  expected_copy.CopyFrom(expected)
  actual_copy = type(actual)()
  actual_copy.CopyFrom(actual)

  try:
    # Route through the shared helper so a top-level TensorProto is repacked
    # as well; normalize_tensor_protos alone only visits nested fields.
    _process_message_for_normalization(expected_copy)
    _process_message_for_normalization(actual_copy)
  except ValueError as e:
    failure = f"Proto normalization failed: {e}."
    # Only append the user message when one was given (avoid a stray "None").
    test_case.fail(f"{failure} {msg}" if msg else failure)

  compare.assertProtoEqual(
      test_case, expected_copy, actual_copy, msg=msg, **kwargs
  )
254
+
255
+
101
256
  class MeridianTestCase(parameterized.TestCase):
102
257
  """Base test class for Meridian providing backend-aware utilities.
103
258
 
@@ -106,6 +261,13 @@ class MeridianTestCase(parameterized.TestCase):
106
261
  number generation across backends (Stateful TF vs Stateless JAX).
107
262
  """
108
263
 
264
+ @classmethod
265
+ def setUpClass(cls):
266
+ super().setUpClass()
267
+ # Enforce determinism for TensorFlow tests before any tests are run.
268
+ # This is a no-op with a warning for the JAX backend.
269
+ backend.enable_op_determinism()
270
+
109
271
  def setUp(self):
110
272
  super().setUp()
111
273
  # Default seed, can be overridden by subclasses before calling
meridian/constants.py CHANGED
@@ -55,6 +55,7 @@ DATE_FORMAT = '%Y-%m-%d'
55
55
  QUARTER_FORMAT = '%Y %b'
56
56
 
57
57
  ORGANIC_PREFIX = 'organic_'
58
+ NATIONAL_PREFIX = 'national_'
58
59
 
59
60
  # Input data variables.
60
61
  KPI = 'kpi'
@@ -123,7 +124,6 @@ RF_DATA = (
123
124
  RF_SPEND,
124
125
  REVENUE_PER_KPI,
125
126
  )
126
- NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
127
127
 
128
128
  # Scaled input data variables.
129
129
  MEDIA_SCALED = 'media_scaled'
@@ -132,6 +132,9 @@ ORGANIC_MEDIA_SCALED = ORGANIC_PREFIX + MEDIA_SCALED
132
132
  ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED
133
133
  NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
134
134
  CONTROLS_SCALED = 'controls_scaled'
135
+ KPI_SCALED = f'{KPI}_scaled'
136
+ POPULATION_SCALED_KPI = f'{POPULATION}_scaled_{KPI}'
137
+ RF_IMPRESSIONS_SCALED = f'{RF_IMPRESSIONS}_scaled'
135
138
 
136
139
  # Non-media treatments baseline value constants.
137
140
  NON_MEDIA_BASELINE_MIN = 'min'
@@ -174,8 +177,40 @@ POSSIBLE_INPUT_DATA_COORDS_AND_ARRAYS_SET = frozenset(
174
177
  POSSIBLE_INPUT_DATA_COORD_NAMES + POSSIBLE_INPUT_DATA_ARRAY_NAMES
175
178
  )
176
179
 
177
- # EDA property constants
180
+ # EDA Engine properties
178
181
  ORGANIC_RF_IMPRESSIONS = ORGANIC_PREFIX + RF_IMPRESSIONS
182
+ ORGANIC_RF_IMPRESSIONS_SCALED = f'{ORGANIC_RF_IMPRESSIONS}_scaled'
183
+ TREATMENT_CONTROL_SCALED = 'treatment_control_scaled'
184
+ NATIONAL_TREATMENT_CONTROL_SCALED = (
185
+ f'{NATIONAL_PREFIX}{TREATMENT_CONTROL_SCALED}'
186
+ )
187
+ NATIONAL_CONTROLS_SCALED = f'{NATIONAL_PREFIX}{CONTROLS_SCALED}'
188
+ NATIONAL_MEDIA_SPEND = f'{NATIONAL_PREFIX}{MEDIA_SPEND}'
189
+ NATIONAL_MEDIA = f'{NATIONAL_PREFIX}{MEDIA}'
190
+ NATIONAL_MEDIA_SCALED = f'{NATIONAL_PREFIX}{MEDIA_SCALED}'
191
+ NATIONAL_ORGANIC_MEDIA = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA}'
192
+ NATIONAL_ORGANIC_MEDIA_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA_SCALED}'
193
+ NATIONAL_NON_MEDIA_TREATMENTS_SCALED = (
194
+ f'{NATIONAL_PREFIX}{NON_MEDIA_TREATMENTS_SCALED}'
195
+ )
196
+ NATIONAL_RF_SPEND = f'{NATIONAL_PREFIX}{RF_SPEND}'
197
+ NATIONAL_KPI_SCALED = f'{NATIONAL_PREFIX}{KPI_SCALED}'
198
+ NATIONAL_REACH = f'{NATIONAL_PREFIX}{REACH}'
199
+ NATIONAL_REACH_SCALED = f'{NATIONAL_PREFIX}{REACH_SCALED}'
200
+ NATIONAL_ORGANIC_REACH = f'{NATIONAL_PREFIX}{ORGANIC_REACH}'
201
+ NATIONAL_ORGANIC_REACH_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_REACH_SCALED}'
202
+ NATIONAL_FREQUENCY = f'{NATIONAL_PREFIX}{FREQUENCY}'
203
+ NATIONAL_ORGANIC_FREQUENCY = f'{NATIONAL_PREFIX}{ORGANIC_FREQUENCY}'
204
+ NATIONAL_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS}'
205
+ NATIONAL_ORGANIC_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS}'
206
+ NATIONAL_RF_IMPRESSIONS_SCALED = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS_SCALED}'
207
+ NATIONAL_ORGANIC_RF_IMPRESSIONS_SCALED = (
208
+ f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS_SCALED}'
209
+ )
210
+ ALL_REACH_SCALED = 'all_reach_scaled'
211
+ ALL_FREQUENCY = 'all_frequency'
212
+ NATIONAL_ALL_REACH_SCALED = f'{NATIONAL_PREFIX}{ALL_REACH_SCALED}'
213
+ NATIONAL_ALL_FREQUENCY = f'{NATIONAL_PREFIX}{ALL_FREQUENCY}'
179
214
 
180
215
 
181
216
  # National model constants.
@@ -359,7 +394,6 @@ ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
359
394
  ETA_ORF,
360
395
  )
361
396
 
362
-
363
397
  MEDIA_PARAMETERS = (
364
398
  ROI_M,
365
399
  MROI_M,
@@ -747,6 +781,8 @@ HILL_NUM_STEPS = 500
747
781
  # Summary template params.
748
782
  START_DATE = 'start_date'
749
783
  END_DATE = 'end_date'
784
+ DEFAULT_CURRENCY = '$'
785
+ SELECTED_GEOS = 'selected_geos'
750
786
  CARD_INSIGHTS = 'insights'
751
787
  CARD_CHARTS = 'charts'
752
788
  CARD_STATS = 'stats'
@@ -15,6 +15,7 @@
15
15
  """The Meridian API module that models the data."""
16
16
 
17
17
  from meridian.model import adstock_hill
18
+ from meridian.model import eda
18
19
  from meridian.model import knots
19
20
  from meridian.model import media
20
21
  from meridian.model import model
@@ -15,3 +15,6 @@
15
15
  """The Meridian API module that performs EDA checks."""
16
16
 
17
17
  from meridian.model.eda import eda_engine
18
+ from meridian.model.eda import eda_outcome
19
+ from meridian.model.eda import eda_spec
20
+ from meridian.model.eda import meridian_eda
@@ -0,0 +1,21 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants specific to MeridianEDA."""
16
+
17
+ # EDA Plotting properties
18
+ VARIABLE_1 = 'var1'
19
+ VARIABLE_2 = 'var2'
20
+ VARIABLE = 'var'
21
+ CORRELATION = 'correlation'