google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
  2. google_meridian-1.3.0.dist-info/RECORD +62 -0
  3. meridian/analysis/__init__.py +2 -0
  4. meridian/analysis/analyzer.py +280 -142
  5. meridian/analysis/formatter.py +2 -2
  6. meridian/analysis/optimizer.py +353 -169
  7. meridian/analysis/review/__init__.py +20 -0
  8. meridian/analysis/review/checks.py +721 -0
  9. meridian/analysis/review/configs.py +110 -0
  10. meridian/analysis/review/constants.py +40 -0
  11. meridian/analysis/review/results.py +544 -0
  12. meridian/analysis/review/reviewer.py +186 -0
  13. meridian/analysis/summarizer.py +14 -12
  14. meridian/analysis/templates/chips.html.jinja +12 -0
  15. meridian/analysis/test_utils.py +27 -5
  16. meridian/analysis/visualizer.py +45 -50
  17. meridian/backend/__init__.py +698 -55
  18. meridian/backend/config.py +75 -16
  19. meridian/backend/test_utils.py +127 -1
  20. meridian/constants.py +52 -11
  21. meridian/data/input_data.py +7 -2
  22. meridian/data/test_utils.py +5 -3
  23. meridian/mlflow/autolog.py +2 -2
  24. meridian/model/__init__.py +1 -0
  25. meridian/model/adstock_hill.py +10 -9
  26. meridian/model/eda/__init__.py +3 -0
  27. meridian/model/eda/constants.py +21 -0
  28. meridian/model/eda/eda_engine.py +1580 -84
  29. meridian/model/eda/eda_outcome.py +200 -0
  30. meridian/model/eda/eda_spec.py +84 -0
  31. meridian/model/eda/meridian_eda.py +220 -0
  32. meridian/model/knots.py +56 -50
  33. meridian/model/media.py +10 -8
  34. meridian/model/model.py +79 -16
  35. meridian/model/model_test_data.py +53 -9
  36. meridian/model/posterior_sampler.py +398 -391
  37. meridian/model/prior_distribution.py +114 -39
  38. meridian/model/prior_sampler.py +146 -90
  39. meridian/model/spec.py +7 -8
  40. meridian/model/transformers.py +16 -8
  41. meridian/version.py +1 -1
  42. google_meridian-1.2.0.dist-info/RECORD +0 -52
  43. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
  44. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
  45. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
meridian/model/media.py CHANGED
@@ -63,8 +63,8 @@ class MediaTensors:
63
63
  media_spend: A tensor constructed from `InputData.media_spend`.
64
64
  media_transformer: A `MediaTransformer` to scale media tensors using the
65
65
  model's media data.
66
- media_scaled: The media tensor normalized by population and by the median
67
- value.
66
+ media_scaled: The media tensor after pre-modeling transformations including
67
+ population scaling and scaling by the median non-zero value.
68
68
  prior_media_scaled_counterfactual: A tensor containing `media_scaled` values
69
69
  corresponding to the counterfactual scenario required for the prior
70
70
  calculation. For ROI priors, the counterfactual scenario is where media is
@@ -169,8 +169,9 @@ class OrganicMediaTensors:
169
169
  organic_media: A tensor constructed from `InputData.organic_media`.
170
170
  organic_media_transformer: A `MediaTransformer` to scale media tensors using
171
171
  the model's organic media data.
172
- organic_media_scaled: The organic media tensor normalized by population and
173
- by the median value.
172
+ organic_media_scaled: The organic media tensor after pre-modeling
173
+ transformations including population scaling and scaling by the median
174
+ non-zero value.
174
175
  """
175
176
 
176
177
  organic_media: backend.Tensor | None = None
@@ -214,8 +215,8 @@ class RfTensors:
214
215
  rf_spend: A tensor constructed from `InputData.rf_spend`.
215
216
  reach_transformer: A `MediaTransformer` to scale RF tensors using the
216
217
  model's RF data.
217
- reach_scaled: A reach tensor normalized by population and by the median
218
- value.
218
+ reach_scaled: A reach tensor after pre-modeling transformations including
219
+ population scaling and scaling by the median non-zero value.
219
220
  prior_reach_scaled_counterfactual: A tensor containing `reach_scaled` values
220
221
  corresponding to the counterfactual scenario required for the prior
221
222
  calculation. For ROI priors, the counterfactual scenario is where reach is
@@ -324,8 +325,9 @@ class OrganicRfTensors:
324
325
  organic_frequency: A tensor constructed from `InputData.organic_frequency`.
325
326
  organic_reach_transformer: A `MediaTransformer` to scale organic RF tensors
326
327
  using the model's organic RF data.
327
- organic_reach_scaled: An organic reach tensor normalized by population and
328
- by the median value.
328
+ organic_reach_scaled: An organic reach tensor after pre-modeling
329
+ transformations including population scaling and scaling by the median
330
+ non-zero value.
329
331
  """
330
332
 
331
333
  organic_reach: backend.Tensor | None = None
meridian/model/model.py CHANGED
@@ -14,6 +14,7 @@
14
14
 
15
15
  """Meridian module for the geo-level Bayesian hierarchical media mix model."""
16
16
 
17
+ import collections
17
18
  from collections.abc import Mapping, Sequence
18
19
  import functools
19
20
  import numbers
@@ -34,19 +35,26 @@ from meridian.model import prior_distribution
34
35
  from meridian.model import prior_sampler
35
36
  from meridian.model import spec
36
37
  from meridian.model import transformers
38
+ from meridian.model.eda import eda_engine
39
+ from meridian.model.eda import eda_outcome
40
+ from meridian.model.eda import eda_spec as eda_spec_module
37
41
  import numpy as np
38
42
 
39
-
40
43
  __all__ = [
41
44
  "MCMCSamplingError",
42
45
  "MCMCOOMError",
43
46
  "Meridian",
47
+ "ModelFittingError",
44
48
  "NotFittedModelError",
45
49
  "save_mmm",
46
50
  "load_mmm",
47
51
  ]
48
52
 
49
53
 
54
+ class ModelFittingError(Exception):
55
+ """Model has critical issues preventing fitting."""
56
+
57
+
50
58
  class NotFittedModelError(Exception):
51
59
  """Model has not been fitted."""
52
60
 
@@ -91,6 +99,10 @@ class Meridian:
91
99
  model_spec: A `ModelSpec` object containing the model specification.
92
100
  inference_data: A _mutable_ `arviz.InferenceData` object containing the
93
101
  resulting data from fitting the model.
102
+ eda_engine: An `EDAEngine` object containing the EDA engine.
103
+ eda_spec: An `EDASpec` object containing the EDA specification.
104
+ eda_outcomes: A list of `EDAOutcome` objects containing the outcomes from
105
+ running critical EDA checks.
94
106
  n_geos: Number of geos in the data.
95
107
  n_media_channels: Number of media channels in the data.
96
108
  n_rf_channels: Number of reach and frequency (RF) channels in the data.
@@ -126,11 +138,17 @@ class Meridian:
126
138
  treatment tensors using the model's non-media treatment data.
127
139
  kpi_transformer: A `KpiTransformer` to scale KPI tensors using the model's
128
140
  KPI data.
129
- controls_scaled: The controls tensor normalized by population and by the
130
- median value.
131
- non_media_treatments_scaled: The non-media treatment tensor normalized by
132
- population and by the median value.
133
- kpi_scaled: The KPI tensor normalized by population and by the median value.
141
+ controls_scaled: The controls tensor after pre-modeling transformations
142
+ including population scaling (for variables with
143
+ `ModelSpec.control_population_scaling_id` set to `True`), centering by the
144
+ mean, and scaling by the standard deviation.
145
+ non_media_treatments_scaled: The non-media treatment tensor after
146
+ pre-modeling transformations including population scaling (for variables
147
+ with `ModelSpec.non_media_population_scaling_id` set to `True`), centering
148
+ by the mean, and scaling by the standard deviation.
149
+ kpi_scaled: The KPI tensor after pre-modeling transformations including
150
+ population scaling, centering by the mean, and scaling by the standard
151
+ deviation.
134
152
  media_effects_dist: A string to specify the distribution of media random
135
153
  effects across geos.
136
154
  unique_sigma_for_each_geo: A boolean indicating whether to use a unique
@@ -148,6 +166,7 @@ class Meridian:
148
166
  inference_data: (
149
167
  az.InferenceData | None
150
168
  ) = None, # for deserializer use only
169
+ eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
151
170
  ):
152
171
  self._input_data = input_data
153
172
  self._model_spec = model_spec if model_spec else spec.ModelSpec()
@@ -155,6 +174,8 @@ class Meridian:
155
174
  inference_data if inference_data else az.InferenceData()
156
175
  )
157
176
 
177
+ self._eda_spec = eda_spec
178
+
158
179
  self._validate_data_dependent_model_spec()
159
180
  self._validate_injected_inference_data()
160
181
 
@@ -184,6 +205,18 @@ class Meridian:
184
205
  def inference_data(self) -> az.InferenceData:
185
206
  return self._inference_data
186
207
 
208
+ @functools.cached_property
209
+ def eda_engine(self) -> eda_engine.EDAEngine:
210
+ return eda_engine.EDAEngine(self, spec=self._eda_spec)
211
+
212
+ @property
213
+ def eda_spec(self) -> eda_spec_module.EDASpec:
214
+ return self._eda_spec
215
+
216
+ @property
217
+ def eda_outcomes(self) -> Sequence[eda_outcome.EDAOutcome]:
218
+ return self.eda_engine.run_all_critical_checks()
219
+
187
220
  @functools.cached_property
188
221
  def media_tensors(self) -> media.MediaTensors:
189
222
  return media.build_media_tensors(self.input_data, self.model_spec)
@@ -444,7 +477,8 @@ class Meridian:
444
477
  f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
445
478
  " either contain only channel_names"
446
479
  f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
447
- " be one or more of {'media', 'rf', 'organic_media', 'organic_rf'}."
480
+ " be one or more of {'media', 'rf', 'organic_media',"
481
+ " 'organic_rf'}."
448
482
  ) from e
449
483
 
450
484
  @functools.cached_property
@@ -561,7 +595,9 @@ class Meridian:
561
595
  non_media_treatments_population_scaled[..., channel], axis=[0, 1]
562
596
  )
563
597
  elif isinstance(baseline_value, numbers.Number):
564
- baseline_for_channel = backend.cast(baseline_value, backend.float32)
598
+ baseline_for_channel = backend.to_tensor(
599
+ baseline_value, dtype=backend.float32
600
+ )
565
601
  else:
566
602
  raise ValueError(
567
603
  f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
@@ -1135,16 +1171,11 @@ class Meridian:
1135
1171
  " time."
1136
1172
  )
1137
1173
 
1138
- def _kpi_has_variability(self):
1139
- """Returns True if the KPI has variability across geos and times."""
1140
- return self.kpi_transformer.population_scaled_stdev != 0
1141
-
1142
1174
  def _validate_kpi_transformer(self):
1143
1175
  """Validates the KPI transformer."""
1144
- if self._kpi_has_variability():
1176
+ if self.eda_engine.kpi_has_variability:
1145
1177
  return
1146
-
1147
- kpi = "kpi" if self.is_national else "population_scaled_kpi"
1178
+ kpi = self.eda_engine.kpi_scaled_da.name
1148
1179
 
1149
1180
  if (
1150
1181
  self.n_media_channels > 0
@@ -1569,6 +1600,36 @@ class Meridian:
1569
1600
  """
1570
1601
  self.prior_sampler_callable(n_draws=n_draws, seed=seed)
1571
1602
 
1603
+ def _run_model_fitting_guardrail(self):
1604
+ """Raises an error if the model has critical EDA issues."""
1605
+ error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
1606
+ collections.defaultdict(list)
1607
+ )
1608
+ for outcome in self.eda_outcomes:
1609
+ error_findings = [
1610
+ finding
1611
+ for finding in outcome.findings
1612
+ if finding.severity == eda_outcome.EDASeverity.ERROR
1613
+ ]
1614
+ if error_findings:
1615
+ error_findings_by_type[outcome.check_type].extend(
1616
+ [finding.explanation for finding in error_findings]
1617
+ )
1618
+
1619
+ if error_findings_by_type:
1620
+ error_message_lines = [
1621
+ "Model has critical EDA issues. Please fix before running"
1622
+ " `sample_posterior`.\n"
1623
+ ]
1624
+ for check_type, explanations in error_findings_by_type.items():
1625
+ error_message_lines.append(f"Check type: {check_type.name}")
1626
+ for explanation in explanations:
1627
+ error_message_lines.append(f"- {explanation}")
1628
+ error_message_lines.append(
1629
+ "For further details, please refer to `Meridian.eda_outcomes`."
1630
+ )
1631
+ raise ModelFittingError("\n".join(error_message_lines))
1632
+
1572
1633
  def sample_posterior(
1573
1634
  self,
1574
1635
  n_chains: Sequence[int] | int,
@@ -1644,8 +1705,10 @@ class Meridian:
1644
1705
  a list of integers as `n_chains` to sample chains serially. For more
1645
1706
  information, see
1646
1707
  [ResourceExhaustedError when running Meridian.sample_posterior]
1647
- (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
1708
+ (https://developers.google.com/meridian/docs/post-modeling/model-debugging#gpu-oom-error).
1648
1709
  """
1710
+ self._run_model_fitting_guardrail()
1711
+
1649
1712
  self.posterior_sampler_callable(
1650
1713
  n_chains=n_chains,
1651
1714
  n_adapt=n_adapt,
@@ -52,7 +52,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
52
52
  else:
53
53
  pad_value = 0.0 if array.dtype.kind == "f" else 0
54
54
 
55
- burnin = backend.fill([n_burnin] + transposed_tensor.shape[1:], pad_value)
55
+ burnin = backend.fill(
56
+ [n_burnin] + list(transposed_tensor.shape[1:]), pad_value
57
+ )
56
58
  return backend.concatenate(
57
59
  [burnin, transposed_tensor],
58
60
  axis=0,
@@ -122,18 +124,13 @@ class WithInputDataSamples:
122
124
  _N_MEDIA_CHANNELS = 3
123
125
  _N_RF_CHANNELS = 2
124
126
  _N_CONTROLS = 2
125
- _ROI_CALIBRATION_PERIOD = backend.cast(
126
- backend.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
127
- dtype=backend.bool_,
128
- )
129
- _RF_ROI_CALIBRATION_PERIOD = backend.cast(
130
- backend.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
131
- dtype=backend.bool_,
132
- )
133
127
  _N_ORGANIC_MEDIA_CHANNELS = 4
134
128
  _N_ORGANIC_RF_CHANNELS = 1
135
129
  _N_NON_MEDIA_CHANNELS = 2
136
130
 
131
+ _ROI_CALIBRATION_PERIOD: backend.Tensor
132
+ _RF_ROI_CALIBRATION_PERIOD: backend.Tensor
133
+
137
134
  # Private class variables to hold the base test data.
138
135
  _input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
139
136
  _input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
@@ -159,6 +156,8 @@ class WithInputDataSamples:
159
156
  _short_input_data_non_media_and_organic: input_data.InputData
160
157
  _short_input_data_non_media: input_data.InputData
161
158
  _input_data_non_media_and_organic_same_time_dims: input_data.InputData
159
+ _input_data_organic_only: input_data.InputData
160
+ _national_input_data_organic_only: input_data.InputData
162
161
 
163
162
  # The following NamedTuples and their attributes are immutable, so they can
164
163
  # be accessed directly.
@@ -170,6 +169,15 @@ class WithInputDataSamples:
170
169
  @classmethod
171
170
  def setup(cls):
172
171
  """Sets up input data samples."""
172
+ cls._ROI_CALIBRATION_PERIOD = backend.cast(
173
+ backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_MEDIA_CHANNELS)),
174
+ dtype=backend.bool_,
175
+ )
176
+ cls._RF_ROI_CALIBRATION_PERIOD = backend.cast(
177
+ backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_RF_CHANNELS)),
178
+ dtype=backend.bool_,
179
+ )
180
+
173
181
  cls._input_data_non_revenue_no_revenue_per_kpi = (
174
182
  test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
175
183
  n_geos=cls._N_GEOS,
@@ -490,6 +498,34 @@ class WithInputDataSamples:
490
498
  seed=0,
491
499
  )
492
500
  )
501
+ cls._input_data_organic_only = (
502
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
503
+ n_geos=cls._N_GEOS,
504
+ n_times=cls._N_TIMES,
505
+ n_media_times=cls._N_MEDIA_TIMES,
506
+ n_controls=cls._N_CONTROLS,
507
+ n_non_media_channels=0,
508
+ n_media_channels=cls._N_MEDIA_CHANNELS,
509
+ n_rf_channels=0,
510
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
511
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
512
+ seed=0,
513
+ )
514
+ )
515
+ cls._national_input_data_organic_only = (
516
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
517
+ n_geos=cls._N_GEOS_NATIONAL,
518
+ n_times=cls._N_TIMES,
519
+ n_media_times=cls._N_MEDIA_TIMES,
520
+ n_controls=cls._N_CONTROLS,
521
+ n_non_media_channels=0,
522
+ n_media_channels=cls._N_MEDIA_CHANNELS,
523
+ n_rf_channels=0,
524
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
525
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
526
+ seed=0,
527
+ )
528
+ )
493
529
 
494
530
  @property
495
531
  def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
@@ -600,3 +636,11 @@ class WithInputDataSamples:
600
636
  self,
601
637
  ) -> input_data.InputData:
602
638
  return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)
639
+
640
+ @property
641
+ def input_data_organic_only(self) -> input_data.InputData:
642
+ return self._input_data_organic_only.copy(deep=True)
643
+
644
+ @property
645
+ def national_input_data_organic_only(self) -> input_data.InputData:
646
+ return self._national_input_data_organic_only.copy(deep=True)