google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
- google_meridian-1.3.0.dist-info/RECORD +62 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +280 -142
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +353 -169
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +14 -12
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +45 -50
- meridian/backend/__init__.py +698 -55
- meridian/backend/config.py +75 -16
- meridian/backend/test_utils.py +127 -1
- meridian/constants.py +52 -11
- meridian/data/input_data.py +7 -2
- meridian/data/test_utils.py +5 -3
- meridian/mlflow/autolog.py +2 -2
- meridian/model/__init__.py +1 -0
- meridian/model/adstock_hill.py +10 -9
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1580 -84
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +56 -50
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -9
- meridian/model/posterior_sampler.py +398 -391
- meridian/model/prior_distribution.py +114 -39
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +16 -8
- meridian/version.py +1 -1
- google_meridian-1.2.0.dist-info/RECORD +0 -52
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
meridian/model/media.py
CHANGED
|
@@ -63,8 +63,8 @@ class MediaTensors:
|
|
|
63
63
|
media_spend: A tensor constructed from `InputData.media_spend`.
|
|
64
64
|
media_transformer: A `MediaTransformer` to scale media tensors using the
|
|
65
65
|
model's media data.
|
|
66
|
-
media_scaled: The media tensor
|
|
67
|
-
value.
|
|
66
|
+
media_scaled: The media tensor after pre-modeling transformations including
|
|
67
|
+
population scaling and scaling by the median non-zero value.
|
|
68
68
|
prior_media_scaled_counterfactual: A tensor containing `media_scaled` values
|
|
69
69
|
corresponding to the counterfactual scenario required for the prior
|
|
70
70
|
calculation. For ROI priors, the counterfactual scenario is where media is
|
|
@@ -169,8 +169,9 @@ class OrganicMediaTensors:
|
|
|
169
169
|
organic_media: A tensor constructed from `InputData.organic_media`.
|
|
170
170
|
organic_media_transformer: A `MediaTransformer` to scale media tensors using
|
|
171
171
|
the model's organic media data.
|
|
172
|
-
organic_media_scaled: The organic media tensor
|
|
173
|
-
by the
|
|
172
|
+
organic_media_scaled: The organic media tensor after pre-modeling
|
|
173
|
+
transformations including population scaling and scaling by the median
|
|
174
|
+
non-zero value.
|
|
174
175
|
"""
|
|
175
176
|
|
|
176
177
|
organic_media: backend.Tensor | None = None
|
|
@@ -214,8 +215,8 @@ class RfTensors:
|
|
|
214
215
|
rf_spend: A tensor constructed from `InputData.rf_spend`.
|
|
215
216
|
reach_transformer: A `MediaTransformer` to scale RF tensors using the
|
|
216
217
|
model's RF data.
|
|
217
|
-
reach_scaled: A reach tensor
|
|
218
|
-
value.
|
|
218
|
+
reach_scaled: A reach tensor after pre-modeling transformations including
|
|
219
|
+
population scaling and scaling by the median non-zero value.
|
|
219
220
|
prior_reach_scaled_counterfactual: A tensor containing `reach_scaled` values
|
|
220
221
|
corresponding to the counterfactual scenario required for the prior
|
|
221
222
|
calculation. For ROI priors, the counterfactual scenario is where reach is
|
|
@@ -324,8 +325,9 @@ class OrganicRfTensors:
|
|
|
324
325
|
organic_frequency: A tensor constructed from `InputData.organic_frequency`.
|
|
325
326
|
organic_reach_transformer: A `MediaTransformer` to scale organic RF tensors
|
|
326
327
|
using the model's organic RF data.
|
|
327
|
-
organic_reach_scaled: An organic reach tensor
|
|
328
|
-
by the median
|
|
328
|
+
organic_reach_scaled: An organic reach tensor after pre-modeling
|
|
329
|
+
transformations including population scaling and scaling by the median
|
|
330
|
+
non-zero value.
|
|
329
331
|
"""
|
|
330
332
|
|
|
331
333
|
organic_reach: backend.Tensor | None = None
|
meridian/model/model.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
"""Meridian module for the geo-level Bayesian hierarchical media mix model."""
|
|
16
16
|
|
|
17
|
+
import collections
|
|
17
18
|
from collections.abc import Mapping, Sequence
|
|
18
19
|
import functools
|
|
19
20
|
import numbers
|
|
@@ -34,19 +35,26 @@ from meridian.model import prior_distribution
|
|
|
34
35
|
from meridian.model import prior_sampler
|
|
35
36
|
from meridian.model import spec
|
|
36
37
|
from meridian.model import transformers
|
|
38
|
+
from meridian.model.eda import eda_engine
|
|
39
|
+
from meridian.model.eda import eda_outcome
|
|
40
|
+
from meridian.model.eda import eda_spec as eda_spec_module
|
|
37
41
|
import numpy as np
|
|
38
42
|
|
|
39
|
-
|
|
40
43
|
__all__ = [
|
|
41
44
|
"MCMCSamplingError",
|
|
42
45
|
"MCMCOOMError",
|
|
43
46
|
"Meridian",
|
|
47
|
+
"ModelFittingError",
|
|
44
48
|
"NotFittedModelError",
|
|
45
49
|
"save_mmm",
|
|
46
50
|
"load_mmm",
|
|
47
51
|
]
|
|
48
52
|
|
|
49
53
|
|
|
54
|
+
class ModelFittingError(Exception):
|
|
55
|
+
"""Model has critical issues preventing fitting."""
|
|
56
|
+
|
|
57
|
+
|
|
50
58
|
class NotFittedModelError(Exception):
|
|
51
59
|
"""Model has not been fitted."""
|
|
52
60
|
|
|
@@ -91,6 +99,10 @@ class Meridian:
|
|
|
91
99
|
model_spec: A `ModelSpec` object containing the model specification.
|
|
92
100
|
inference_data: A _mutable_ `arviz.InferenceData` object containing the
|
|
93
101
|
resulting data from fitting the model.
|
|
102
|
+
eda_engine: An `EDAEngine` object containing the EDA engine.
|
|
103
|
+
eda_spec: An `EDASpec` object containing the EDA specification.
|
|
104
|
+
eda_outcomes: A list of `EDAOutcome` objects containing the outcomes from
|
|
105
|
+
running critical EDA checks.
|
|
94
106
|
n_geos: Number of geos in the data.
|
|
95
107
|
n_media_channels: Number of media channels in the data.
|
|
96
108
|
n_rf_channels: Number of reach and frequency (RF) channels in the data.
|
|
@@ -126,11 +138,17 @@ class Meridian:
|
|
|
126
138
|
treatment tensors using the model's non-media treatment data.
|
|
127
139
|
kpi_transformer: A `KpiTransformer` to scale KPI tensors using the model's
|
|
128
140
|
KPI data.
|
|
129
|
-
controls_scaled: The controls tensor
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
141
|
+
controls_scaled: The controls tensor after pre-modeling transformations
|
|
142
|
+
including population scaling (for variables with
|
|
143
|
+
`ModelSpec.control_population_scaling_id` set to `True`), centering by the
|
|
144
|
+
mean, and scaling by the standard deviation.
|
|
145
|
+
non_media_treatments_scaled: The non-media treatment tensor after
|
|
146
|
+
pre-modeling transformations including population scaling (for variables
|
|
147
|
+
with `ModelSpec.non_media_population_scaling_id` set to `True`), centering
|
|
148
|
+
by the mean, and scaling by the standard deviation.
|
|
149
|
+
kpi_scaled: The KPI tensor after pre-modeling transformations including
|
|
150
|
+
population scaling, centering by the mean, and scaling by the standard
|
|
151
|
+
deviation.
|
|
134
152
|
media_effects_dist: A string to specify the distribution of media random
|
|
135
153
|
effects across geos.
|
|
136
154
|
unique_sigma_for_each_geo: A boolean indicating whether to use a unique
|
|
@@ -148,6 +166,7 @@ class Meridian:
|
|
|
148
166
|
inference_data: (
|
|
149
167
|
az.InferenceData | None
|
|
150
168
|
) = None, # for deserializer use only
|
|
169
|
+
eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
|
|
151
170
|
):
|
|
152
171
|
self._input_data = input_data
|
|
153
172
|
self._model_spec = model_spec if model_spec else spec.ModelSpec()
|
|
@@ -155,6 +174,8 @@ class Meridian:
|
|
|
155
174
|
inference_data if inference_data else az.InferenceData()
|
|
156
175
|
)
|
|
157
176
|
|
|
177
|
+
self._eda_spec = eda_spec
|
|
178
|
+
|
|
158
179
|
self._validate_data_dependent_model_spec()
|
|
159
180
|
self._validate_injected_inference_data()
|
|
160
181
|
|
|
@@ -184,6 +205,18 @@ class Meridian:
|
|
|
184
205
|
def inference_data(self) -> az.InferenceData:
|
|
185
206
|
return self._inference_data
|
|
186
207
|
|
|
208
|
+
@functools.cached_property
|
|
209
|
+
def eda_engine(self) -> eda_engine.EDAEngine:
|
|
210
|
+
return eda_engine.EDAEngine(self, spec=self._eda_spec)
|
|
211
|
+
|
|
212
|
+
@property
|
|
213
|
+
def eda_spec(self) -> eda_spec_module.EDASpec:
|
|
214
|
+
return self._eda_spec
|
|
215
|
+
|
|
216
|
+
@property
|
|
217
|
+
def eda_outcomes(self) -> Sequence[eda_outcome.EDAOutcome]:
|
|
218
|
+
return self.eda_engine.run_all_critical_checks()
|
|
219
|
+
|
|
187
220
|
@functools.cached_property
|
|
188
221
|
def media_tensors(self) -> media.MediaTensors:
|
|
189
222
|
return media.build_media_tensors(self.input_data, self.model_spec)
|
|
@@ -444,7 +477,8 @@ class Meridian:
|
|
|
444
477
|
f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
|
|
445
478
|
" either contain only channel_names"
|
|
446
479
|
f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
|
|
447
|
-
" be one or more of {'media', 'rf', 'organic_media',
|
|
480
|
+
" be one or more of {'media', 'rf', 'organic_media',"
|
|
481
|
+
" 'organic_rf'}."
|
|
448
482
|
) from e
|
|
449
483
|
|
|
450
484
|
@functools.cached_property
|
|
@@ -561,7 +595,9 @@ class Meridian:
|
|
|
561
595
|
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
562
596
|
)
|
|
563
597
|
elif isinstance(baseline_value, numbers.Number):
|
|
564
|
-
baseline_for_channel = backend.
|
|
598
|
+
baseline_for_channel = backend.to_tensor(
|
|
599
|
+
baseline_value, dtype=backend.float32
|
|
600
|
+
)
|
|
565
601
|
else:
|
|
566
602
|
raise ValueError(
|
|
567
603
|
f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
|
|
@@ -1135,16 +1171,11 @@ class Meridian:
|
|
|
1135
1171
|
" time."
|
|
1136
1172
|
)
|
|
1137
1173
|
|
|
1138
|
-
def _kpi_has_variability(self):
|
|
1139
|
-
"""Returns True if the KPI has variability across geos and times."""
|
|
1140
|
-
return self.kpi_transformer.population_scaled_stdev != 0
|
|
1141
|
-
|
|
1142
1174
|
def _validate_kpi_transformer(self):
|
|
1143
1175
|
"""Validates the KPI transformer."""
|
|
1144
|
-
if self.
|
|
1176
|
+
if self.eda_engine.kpi_has_variability:
|
|
1145
1177
|
return
|
|
1146
|
-
|
|
1147
|
-
kpi = "kpi" if self.is_national else "population_scaled_kpi"
|
|
1178
|
+
kpi = self.eda_engine.kpi_scaled_da.name
|
|
1148
1179
|
|
|
1149
1180
|
if (
|
|
1150
1181
|
self.n_media_channels > 0
|
|
@@ -1569,6 +1600,36 @@ class Meridian:
|
|
|
1569
1600
|
"""
|
|
1570
1601
|
self.prior_sampler_callable(n_draws=n_draws, seed=seed)
|
|
1571
1602
|
|
|
1603
|
+
def _run_model_fitting_guardrail(self):
|
|
1604
|
+
"""Raises an error if the model has critical EDA issues."""
|
|
1605
|
+
error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
|
|
1606
|
+
collections.defaultdict(list)
|
|
1607
|
+
)
|
|
1608
|
+
for outcome in self.eda_outcomes:
|
|
1609
|
+
error_findings = [
|
|
1610
|
+
finding
|
|
1611
|
+
for finding in outcome.findings
|
|
1612
|
+
if finding.severity == eda_outcome.EDASeverity.ERROR
|
|
1613
|
+
]
|
|
1614
|
+
if error_findings:
|
|
1615
|
+
error_findings_by_type[outcome.check_type].extend(
|
|
1616
|
+
[finding.explanation for finding in error_findings]
|
|
1617
|
+
)
|
|
1618
|
+
|
|
1619
|
+
if error_findings_by_type:
|
|
1620
|
+
error_message_lines = [
|
|
1621
|
+
"Model has critical EDA issues. Please fix before running"
|
|
1622
|
+
" `sample_posterior`.\n"
|
|
1623
|
+
]
|
|
1624
|
+
for check_type, explanations in error_findings_by_type.items():
|
|
1625
|
+
error_message_lines.append(f"Check type: {check_type.name}")
|
|
1626
|
+
for explanation in explanations:
|
|
1627
|
+
error_message_lines.append(f"- {explanation}")
|
|
1628
|
+
error_message_lines.append(
|
|
1629
|
+
"For further details, please refer to `Meridian.eda_outcomes`."
|
|
1630
|
+
)
|
|
1631
|
+
raise ModelFittingError("\n".join(error_message_lines))
|
|
1632
|
+
|
|
1572
1633
|
def sample_posterior(
|
|
1573
1634
|
self,
|
|
1574
1635
|
n_chains: Sequence[int] | int,
|
|
@@ -1644,8 +1705,10 @@ class Meridian:
|
|
|
1644
1705
|
a list of integers as `n_chains` to sample chains serially. For more
|
|
1645
1706
|
information, see
|
|
1646
1707
|
[ResourceExhaustedError when running Meridian.sample_posterior]
|
|
1647
|
-
(https://developers.google.com/meridian/docs/
|
|
1708
|
+
(https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error).
|
|
1648
1709
|
"""
|
|
1710
|
+
self._run_model_fitting_guardrail()
|
|
1711
|
+
|
|
1649
1712
|
self.posterior_sampler_callable(
|
|
1650
1713
|
n_chains=n_chains,
|
|
1651
1714
|
n_adapt=n_adapt,
|
|
@@ -52,7 +52,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
|
|
|
52
52
|
else:
|
|
53
53
|
pad_value = 0.0 if array.dtype.kind == "f" else 0
|
|
54
54
|
|
|
55
|
-
burnin = backend.fill(
|
|
55
|
+
burnin = backend.fill(
|
|
56
|
+
[n_burnin] + list(transposed_tensor.shape[1:]), pad_value
|
|
57
|
+
)
|
|
56
58
|
return backend.concatenate(
|
|
57
59
|
[burnin, transposed_tensor],
|
|
58
60
|
axis=0,
|
|
@@ -122,18 +124,13 @@ class WithInputDataSamples:
|
|
|
122
124
|
_N_MEDIA_CHANNELS = 3
|
|
123
125
|
_N_RF_CHANNELS = 2
|
|
124
126
|
_N_CONTROLS = 2
|
|
125
|
-
_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
126
|
-
backend.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
|
|
127
|
-
dtype=backend.bool_,
|
|
128
|
-
)
|
|
129
|
-
_RF_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
130
|
-
backend.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
|
|
131
|
-
dtype=backend.bool_,
|
|
132
|
-
)
|
|
133
127
|
_N_ORGANIC_MEDIA_CHANNELS = 4
|
|
134
128
|
_N_ORGANIC_RF_CHANNELS = 1
|
|
135
129
|
_N_NON_MEDIA_CHANNELS = 2
|
|
136
130
|
|
|
131
|
+
_ROI_CALIBRATION_PERIOD: backend.Tensor
|
|
132
|
+
_RF_ROI_CALIBRATION_PERIOD: backend.Tensor
|
|
133
|
+
|
|
137
134
|
# Private class variables to hold the base test data.
|
|
138
135
|
_input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
139
136
|
_input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
@@ -159,6 +156,8 @@ class WithInputDataSamples:
|
|
|
159
156
|
_short_input_data_non_media_and_organic: input_data.InputData
|
|
160
157
|
_short_input_data_non_media: input_data.InputData
|
|
161
158
|
_input_data_non_media_and_organic_same_time_dims: input_data.InputData
|
|
159
|
+
_input_data_organic_only: input_data.InputData
|
|
160
|
+
_national_input_data_organic_only: input_data.InputData
|
|
162
161
|
|
|
163
162
|
# The following NamedTuples and their attributes are immutable, so they can
|
|
164
163
|
# be accessed directly.
|
|
@@ -170,6 +169,15 @@ class WithInputDataSamples:
|
|
|
170
169
|
@classmethod
|
|
171
170
|
def setup(cls):
|
|
172
171
|
"""Sets up input data samples."""
|
|
172
|
+
cls._ROI_CALIBRATION_PERIOD = backend.cast(
|
|
173
|
+
backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_MEDIA_CHANNELS)),
|
|
174
|
+
dtype=backend.bool_,
|
|
175
|
+
)
|
|
176
|
+
cls._RF_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
177
|
+
backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_RF_CHANNELS)),
|
|
178
|
+
dtype=backend.bool_,
|
|
179
|
+
)
|
|
180
|
+
|
|
173
181
|
cls._input_data_non_revenue_no_revenue_per_kpi = (
|
|
174
182
|
test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
|
|
175
183
|
n_geos=cls._N_GEOS,
|
|
@@ -490,6 +498,34 @@ class WithInputDataSamples:
|
|
|
490
498
|
seed=0,
|
|
491
499
|
)
|
|
492
500
|
)
|
|
501
|
+
cls._input_data_organic_only = (
|
|
502
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
503
|
+
n_geos=cls._N_GEOS,
|
|
504
|
+
n_times=cls._N_TIMES,
|
|
505
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
506
|
+
n_controls=cls._N_CONTROLS,
|
|
507
|
+
n_non_media_channels=0,
|
|
508
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
509
|
+
n_rf_channels=0,
|
|
510
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
511
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
512
|
+
seed=0,
|
|
513
|
+
)
|
|
514
|
+
)
|
|
515
|
+
cls._national_input_data_organic_only = (
|
|
516
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
517
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
518
|
+
n_times=cls._N_TIMES,
|
|
519
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
520
|
+
n_controls=cls._N_CONTROLS,
|
|
521
|
+
n_non_media_channels=0,
|
|
522
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
523
|
+
n_rf_channels=0,
|
|
524
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
525
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
526
|
+
seed=0,
|
|
527
|
+
)
|
|
528
|
+
)
|
|
493
529
|
|
|
494
530
|
@property
|
|
495
531
|
def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
|
|
@@ -600,3 +636,11 @@ class WithInputDataSamples:
|
|
|
600
636
|
self,
|
|
601
637
|
) -> input_data.InputData:
|
|
602
638
|
return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)
|
|
639
|
+
|
|
640
|
+
@property
|
|
641
|
+
def input_data_organic_only(self) -> input_data.InputData:
|
|
642
|
+
return self._input_data_organic_only.copy(deep=True)
|
|
643
|
+
|
|
644
|
+
@property
|
|
645
|
+
def national_input_data_organic_only(self) -> input_data.InputData:
|
|
646
|
+
return self._national_input_data_organic_only.copy(deep=True)
|