google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
meridian/backend/test_utils.py
CHANGED
|
@@ -15,11 +15,21 @@
|
|
|
15
15
|
"""Common testing utilities for Meridian, designed to be backend-agnostic."""
|
|
16
16
|
|
|
17
17
|
from typing import Any, Optional
|
|
18
|
+
|
|
18
19
|
from absl.testing import parameterized
|
|
20
|
+
from google.protobuf import descriptor
|
|
21
|
+
from google.protobuf import message
|
|
19
22
|
from meridian import backend
|
|
20
23
|
from meridian.backend import config
|
|
21
24
|
import numpy as np
|
|
22
25
|
|
|
26
|
+
from tensorflow.python.util.protobuf import compare
|
|
27
|
+
# pylint: disable=g-direct-tensorflow-import
|
|
28
|
+
from tensorflow.core.framework import tensor_pb2
|
|
29
|
+
# pylint: enable=g-direct-tensorflow-import
|
|
30
|
+
|
|
31
|
+
# Short alias for the protobuf field-descriptor class. The normalization
# helpers compare field `type` and `label` values against its constants
# (TYPE_MESSAGE, LABEL_REPEATED) while walking message fields.
FieldDescriptor = descriptor.FieldDescriptor
|
|
32
|
+
|
|
23
33
|
# A type alias for backend-agnostic array-like objects.
|
|
24
34
|
# We use `Any` here to avoid circular dependencies with the backend module
|
|
25
35
|
# while still allowing the function to accept backend-specific tensor types.
|
|
@@ -70,6 +80,39 @@ def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
|
|
|
70
80
|
np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)
|
|
71
81
|
|
|
72
82
|
|
|
83
|
+
def assert_seed_allequal(a: Any, b: Any, err_msg: str = ""):
  """Backend-agnostic assertion to check if two seed objects are equal.

  Both seeds are reduced to raw data via `backend.get_seed_data`. Two seeds
  whose data is `None` are treated as equal. If only one of them is `None`,
  the assertion fails with an explicit message, mirroring the `None` handling
  in `assert_seed_not_allequal`.

  Args:
    a: The first seed object.
    b: The second seed object.
    err_msg: Optional text appended to the failure message.

  Raises:
    AssertionError: If the seeds' data differ.
  """
  data_a = backend.get_seed_data(a)
  data_b = backend.get_seed_data(b)
  if data_a is None and data_b is None:
    return
  if data_a is None or data_b is None:
    # Report the None/non-None mismatch directly instead of letting numpy
    # compare `None` against an array, which yields a confusing diff.
    raise AssertionError(
        f"Seeds are unexpectedly unequal (exactly one is None). {err_msg}"
    )
  np.testing.assert_array_equal(data_a, data_b, err_msg=err_msg)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def assert_not_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
  """Asserts that two array-like objects differ.

  Both inputs are converted with `np.array` and compared with
  `np.array_equal`. Arrays of different shapes therefore also count as
  different.

  Args:
    a: First array-like value.
    b: Second array-like value.
    err_msg: Optional text appended to the failure message.

  Raises:
    AssertionError: If `a` and `b` have the same shape and equal elements.
  """
  if np.array_equal(np.array(a), np.array(b)):
    raise AssertionError(f"Arrays are unexpectedly equal.\n{err_msg}")
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def assert_seed_not_allequal(a: Any, b: Any, err_msg: str = ""):
  """Asserts that two seed objects do not carry the same seed data.

  Seed data is obtained via `backend.get_seed_data`. Two `None` values count
  as equal, so they fail the assertion. A single `None` paired with real data
  counts as different, so it passes.

  Args:
    a: The first seed object.
    b: The second seed object.
    err_msg: Optional text appended to the failure message.

  Raises:
    AssertionError: If both seeds are `None`, or both hold equal data.
  """
  seed_data = (backend.get_seed_data(a), backend.get_seed_data(b))
  is_missing = [data is None for data in seed_data]

  if all(is_missing):
    raise AssertionError(
        f"Seeds are unexpectedly equal (both are None). {err_msg}"
    )
  if any(is_missing):
    # Exactly one side has data, so the seeds differ.
    return

  first, second = seed_data
  if np.array_equal(first, second):
    raise AssertionError(f"Seeds are unexpectedly equal.\n{err_msg}")
|
|
114
|
+
|
|
115
|
+
|
|
73
116
|
def assert_all_finite(a: ArrayLike, err_msg: str = ""):
|
|
74
117
|
"""Backend-agnostic assertion to check if all elements in an array are finite.
|
|
75
118
|
|
|
@@ -98,6 +141,118 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
|
|
|
98
141
|
raise AssertionError(err_msg or "Array contains negative values.")
|
|
99
142
|
|
|
100
143
|
|
|
144
|
+
# --- Proto Utilities ---
|
|
145
|
+
def normalize_tensor_protos(proto: message.Message):
  """Recursively normalizes TensorProto messages within a proto (in place).

  This ensures a consistent serialization format across different backends
  (e.g., JAX vs TF) by repacking TensorProtos using the current backend's
  canonical method (backend.make_tensor_proto). This handles differences
  like using `bool_val` versus `tensor_content` for boolean tensors.

  If `proto` is itself a `TensorProto`, it is repacked directly. Otherwise
  every message-typed field (singular, repeated, or map) is traversed, and
  any TensorProto found is repacked.

  Args:
    proto: The protobuf message object to normalize. This object is modified
      in place. Non-message inputs are ignored.

  Raises:
    ValueError: If a TensorProto cannot be deserialized by the backend.
  """
  if not isinstance(proto, message.Message):
    return

  # The field walk below only repacks TensorProtos found in fields, so a
  # top-level TensorProto would otherwise be left un-normalized.
  if isinstance(proto, tensor_pb2.TensorProto):
    _repack_tensor_proto(proto)
    return

  for desc, value in proto.ListFields():
    if desc.type != FieldDescriptor.TYPE_MESSAGE:
      continue

    # A map is defined as a repeated field whose message type has the
    # map_entry option set.
    is_map = (
        desc.label == FieldDescriptor.LABEL_REPEATED
        and desc.message_type.has_options
        and desc.message_type.GetOptions().map_entry
    )

    if is_map:
      for item in value.values():
        # Map values may be scalars; the helper ignores non-messages.
        _process_message_for_normalization(item)
    elif desc.label == FieldDescriptor.LABEL_REPEATED:
      # Standard repeated message fields.
      for item in value:
        _process_message_for_normalization(item)
    else:
      # Singular message fields.
      _process_message_for_normalization(value)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _process_message_for_normalization(msg: Any):
  """Dispatches one value encountered while normalizing TensorProtos.

  Non-message values end the recursion. TensorProtos are repacked, and any
  other message is walked recursively via `normalize_tensor_protos`.

  Args:
    msg: A field value taken from a message being normalized.
  """
  # Scalars (e.g. the values of a map<string, string>) are left untouched.
  if not isinstance(msg, message.Message):
    return

  handler = (
      _repack_tensor_proto
      if isinstance(msg, tensor_pb2.TensorProto)
      else normalize_tensor_protos
  )
  handler(msg)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _repack_tensor_proto(tensor_proto: "tensor_pb2.TensorProto"):
  """Rewrites a TensorProto in place using the backend's canonical encoding.

  An empty proto is left unchanged. Otherwise the tensor is decoded with
  `backend.make_ndarray` and re-encoded with `backend.make_tensor_proto`, so
  equivalent tensors end up with the same serialized form.

  Args:
    tensor_proto: The TensorProto to rewrite; modified in place.

  Raises:
    ValueError: If the proto cannot be decoded into an array.
  """
  if tensor_proto.ByteSize() == 0:
    return

  try:
    decoded = backend.make_ndarray(tensor_proto)
  except Exception as err:  # Re-raised with the offending proto attached.
    raise ValueError(
        "Failed to deserialize TensorProto during normalization:"
        f" {err}\nProto content:\n{tensor_proto}"
    ) from err

  canonical = backend.make_tensor_proto(decoded)
  tensor_proto.Clear()
  tensor_proto.CopyFrom(canonical)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def assert_normalized_proto_equal(
    test_case: parameterized.TestCase,
    expected: message.Message,
    actual: message.Message,
    msg: Optional[str] = None,
    **kwargs: Any,
):
  """Compares two protos after normalizing TensorProto fields.

  Use this instead of `compare.assertProtoEqual` when protos contain tensors,
  to ensure a backend-agnostic comparison. The inputs are not modified;
  normalization is applied to copies.

  Args:
    test_case: The TestCase instance (self).
    expected: The expected protobuf message.
    actual: The actual protobuf message.
    msg: An optional message to display on failure.
    **kwargs: Additional keyword arguments forwarded to
      `compare.assertProtoEqual` (e.g., `check_initialized`,
      `normalize_numbers`).
  """
  # Work on copies so normalization does not mutate the caller's messages.
  expected_copy = expected.__class__()
  expected_copy.CopyFrom(expected)
  actual_copy = actual.__class__()
  actual_copy.CopyFrom(actual)

  try:
    normalize_tensor_protos(expected_copy)
    normalize_tensor_protos(actual_copy)
  except ValueError as e:
    # Only append the caller's message when one was given; otherwise the
    # failure text would end with a literal "None".
    suffix = f" {msg}" if msg else ""
    test_case.fail(f"Proto normalization failed: {e}.{suffix}")

  compare.assertProtoEqual(
      test_case, expected_copy, actual_copy, msg=msg, **kwargs
  )
|
|
254
|
+
|
|
255
|
+
|
|
101
256
|
class MeridianTestCase(parameterized.TestCase):
|
|
102
257
|
"""Base test class for Meridian providing backend-aware utilities.
|
|
103
258
|
|
|
@@ -106,6 +261,13 @@ class MeridianTestCase(parameterized.TestCase):
|
|
|
106
261
|
number generation across backends (Stateful TF vs Stateless JAX).
|
|
107
262
|
"""
|
|
108
263
|
|
|
264
|
+
@classmethod
def setUpClass(cls):
  """Runs the parent class setup, then enables deterministic backend ops.

  This happens once per test class, before any of its tests run.
  """
  super().setUpClass()
  # Makes TensorFlow ops deterministic; under the JAX backend this call is a
  # no-op that only emits a warning.
  backend.enable_op_determinism()
|
|
270
|
+
|
|
109
271
|
def setUp(self):
|
|
110
272
|
super().setUp()
|
|
111
273
|
# Default seed, can be overridden by subclasses before calling
|
meridian/constants.py
CHANGED
|
@@ -55,6 +55,7 @@ DATE_FORMAT = '%Y-%m-%d'
|
|
|
55
55
|
# strftime-style format: four-digit year plus abbreviated month name.
QUARTER_FORMAT = '%Y %b'

# Prefix prepended to a base name to build its organic variant, e.g.
# ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED.
ORGANIC_PREFIX = 'organic_'
# Prefix prepended to a base name to build the NATIONAL_* variants, e.g.
# NATIONAL_MEDIA = f'{NATIONAL_PREFIX}{MEDIA}'.
NATIONAL_PREFIX = 'national_'

# Input data variables.
KPI = 'kpi'
|
|
@@ -123,7 +124,6 @@ RF_DATA = (
|
|
|
123
124
|
RF_SPEND,
|
|
124
125
|
REVENUE_PER_KPI,
|
|
125
126
|
)
|
|
126
|
-
NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
|
|
127
127
|
|
|
128
128
|
# Scaled input data variables.
# Names for the scaled counterparts of the input data variables.
MEDIA_SCALED = 'media_scaled'
|
|
@@ -132,6 +132,9 @@ ORGANIC_MEDIA_SCALED = ORGANIC_PREFIX + MEDIA_SCALED
|
|
|
132
132
|
ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED
NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
CONTROLS_SCALED = 'controls_scaled'
# Names derived from base names defined earlier in this module; e.g.
# KPI_SCALED evaluates to 'kpi_scaled' because KPI is 'kpi'.
KPI_SCALED = f'{KPI}_scaled'
POPULATION_SCALED_KPI = f'{POPULATION}_scaled_{KPI}'
RF_IMPRESSIONS_SCALED = f'{RF_IMPRESSIONS}_scaled'

# Non-media treatments baseline value constants.
NON_MEDIA_BASELINE_MIN = 'min'
|
|
@@ -174,8 +177,40 @@ POSSIBLE_INPUT_DATA_COORDS_AND_ARRAYS_SET = frozenset(
|
|
|
174
177
|
POSSIBLE_INPUT_DATA_COORD_NAMES + POSSIBLE_INPUT_DATA_ARRAY_NAMES
|
|
175
178
|
)
|
|
176
179
|
|
|
177
|
-
# EDA
|
|
180
|
+
# EDA Engine properties
# NOTE(review): presumably these are the variable names used by the EDA
# engine (meridian/model/eda); confirm against eda_engine.py.
ORGANIC_RF_IMPRESSIONS = ORGANIC_PREFIX + RF_IMPRESSIONS
ORGANIC_RF_IMPRESSIONS_SCALED = f'{ORGANIC_RF_IMPRESSIONS}_scaled'
TREATMENT_CONTROL_SCALED = 'treatment_control_scaled'
# Each NATIONAL_* name below is its base name with NATIONAL_PREFIX
# ('national_') prepended, e.g. NATIONAL_MEDIA == 'national_' + MEDIA.
NATIONAL_TREATMENT_CONTROL_SCALED = (
    f'{NATIONAL_PREFIX}{TREATMENT_CONTROL_SCALED}'
)
NATIONAL_CONTROLS_SCALED = f'{NATIONAL_PREFIX}{CONTROLS_SCALED}'
NATIONAL_MEDIA_SPEND = f'{NATIONAL_PREFIX}{MEDIA_SPEND}'
NATIONAL_MEDIA = f'{NATIONAL_PREFIX}{MEDIA}'
NATIONAL_MEDIA_SCALED = f'{NATIONAL_PREFIX}{MEDIA_SCALED}'
NATIONAL_ORGANIC_MEDIA = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA}'
NATIONAL_ORGANIC_MEDIA_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA_SCALED}'
NATIONAL_NON_MEDIA_TREATMENTS_SCALED = (
    f'{NATIONAL_PREFIX}{NON_MEDIA_TREATMENTS_SCALED}'
)
NATIONAL_RF_SPEND = f'{NATIONAL_PREFIX}{RF_SPEND}'
NATIONAL_KPI_SCALED = f'{NATIONAL_PREFIX}{KPI_SCALED}'
NATIONAL_REACH = f'{NATIONAL_PREFIX}{REACH}'
NATIONAL_REACH_SCALED = f'{NATIONAL_PREFIX}{REACH_SCALED}'
NATIONAL_ORGANIC_REACH = f'{NATIONAL_PREFIX}{ORGANIC_REACH}'
NATIONAL_ORGANIC_REACH_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_REACH_SCALED}'
NATIONAL_FREQUENCY = f'{NATIONAL_PREFIX}{FREQUENCY}'
NATIONAL_ORGANIC_FREQUENCY = f'{NATIONAL_PREFIX}{ORGANIC_FREQUENCY}'
NATIONAL_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS}'
NATIONAL_ORGANIC_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS}'
NATIONAL_RF_IMPRESSIONS_SCALED = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS_SCALED}'
NATIONAL_ORGANIC_RF_IMPRESSIONS_SCALED = (
    f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS_SCALED}'
)
# NOTE(review): the "all_" names look like combined (paid + organic) reach
# and frequency variables; confirm in eda_engine.py.
ALL_REACH_SCALED = 'all_reach_scaled'
ALL_FREQUENCY = 'all_frequency'
NATIONAL_ALL_REACH_SCALED = f'{NATIONAL_PREFIX}{ALL_REACH_SCALED}'
NATIONAL_ALL_FREQUENCY = f'{NATIONAL_PREFIX}{ALL_FREQUENCY}'
|
|
179
214
|
|
|
180
215
|
|
|
181
216
|
# National model constants.
|
|
@@ -359,7 +394,6 @@ ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
|
|
|
359
394
|
ETA_ORF,
|
|
360
395
|
)
|
|
361
396
|
|
|
362
|
-
|
|
363
397
|
MEDIA_PARAMETERS = (
|
|
364
398
|
ROI_M,
|
|
365
399
|
MROI_M,
|
|
@@ -747,6 +781,8 @@ HILL_NUM_STEPS = 500
|
|
|
747
781
|
# Summary template params.
START_DATE = 'start_date'
END_DATE = 'end_date'
# NOTE(review): presumably the currency symbol used in summaries when none is
# specified; confirm against summarizer.py.
DEFAULT_CURRENCY = '$'
SELECTED_GEOS = 'selected_geos'
CARD_INSIGHTS = 'insights'
CARD_CHARTS = 'charts'
CARD_STATS = 'stats'
|
meridian/model/__init__.py
CHANGED
meridian/model/eda/__init__.py
CHANGED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Constants specific to MeridianEDA."""
|
|
16
|
+
|
|
17
|
+
# EDA Plotting properties
# NOTE(review): these look like column/coordinate names for pairwise
# variable-correlation plots (var1, var2, correlation); confirm in
# meridian_eda.py.
VARIABLE_1 = 'var1'
VARIABLE_2 = 'var2'
VARIABLE = 'var'
CORRELATION = 'correlation'
|