google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,186 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Implementation of the runner of the Model Quality Checks."""
16
+
17
+ import typing
18
+
19
+ import immutabledict
20
+ from meridian import constants
21
+ from meridian.analysis import analyzer as analyzer_module
22
+ from meridian.analysis.review import checks
23
+ from meridian.analysis.review import configs
24
+ from meridian.analysis.review import results
25
+ from meridian.model import prior_distribution
26
+
27
+
28
# A check *class* (not an instance), e.g. `checks.BaselineCheck`.
CheckType = typing.Type[checks.BaseCheck]
# The configuration instance paired with a check class.
ConfigInstance = configs.BaseConfig
# Read-only mapping from a check class to the config it is run with.
ChecksBattery = immutabledict.immutabledict[CheckType, ConfigInstance]

# Post-convergence checks run by `ModelReviewer` when no custom battery is
# given. Each check uses its default configuration; the pairs are listed in
# the order the reviewer iterates over them.
_DEFAULT_POST_CONVERGENCE_CHECKS: ChecksBattery = immutabledict.immutabledict([
    (checks.BaselineCheck, configs.BaselineConfig()),
    (checks.BayesianPPPCheck, configs.BayesianPPPConfig()),
    (checks.GoodnessOfFitCheck, configs.GoodnessOfFitConfig()),
    (checks.PriorPosteriorShiftCheck, configs.PriorPosteriorShiftConfig()),
    (checks.ROIConsistencyCheck, configs.ROIConsistencyConfig()),
])
39
+
40
+
41
class ModelReviewer:
  """Executes a series of quality checks on a Meridian model.

  The reviewer first runs a convergence check. If the model has converged, it
  proceeds to run a battery of post-convergence checks.

  The default battery of post-convergence checks includes:
    - BaselineCheck
    - BayesianPPPCheck
    - GoodnessOfFitCheck
    - PriorPosteriorShiftCheck
    - ROIConsistencyCheck
  Each with its default configuration.

  Whatever the battery, `PriorPosteriorShiftCheck` is skipped when the model
  uses no ROI priors, and `ROIConsistencyCheck` is skipped when the ROI priors
  in use are all equal to the library defaults.

  This battery of checks can be customized by passing a dictionary to the
  `post_convergence_checks` argument of the constructor, mapping check
  classes to their configuration instances. Checks run in the mapping's
  iteration order. For example, to run only the BaselineCheck with a
  non-default configuration:

  ```python
  my_checks = {
      checks.BaselineCheck: configs.BaselineConfig(
          negative_baseline_prob_review_threshold=0.1,
          negative_baseline_prob_fail_threshold=0.5,
      )
  }
  reviewer = ModelReviewer(meridian_model, post_convergence_checks=my_checks)
  ```
  """

  def __init__(
      self,
      meridian,
      post_convergence_checks: ChecksBattery = _DEFAULT_POST_CONVERGENCE_CHECKS,
  ):
    """Initializes the reviewer.

    Args:
      meridian: The Meridian model to review. A single `Analyzer` is built
        from it here and shared by every check.
      post_convergence_checks: Mapping from check class to the configuration
        instance used to run it once the convergence check has passed.
    """
    self._meridian = meridian
    # Accumulates the results of the current `run()` call; cleared on entry.
    self._results: list[results.CheckResult] = []
    self._analyzer = analyzer_module.Analyzer(meridian)
    self._post_convergence_checks = post_convergence_checks

  def _run_and_handle(self, check_class, config):
    """Instantiates `check_class` with `config`, runs it, records the result."""
    instance = check_class(self._meridian, self._analyzer, config)  # pytype: disable=not-instantiable
    self._results.append(instance.run())

  def _media_uses_roi_prior(self) -> bool:
    """Returns True if the model has media channels with an ROI prior type."""
    return (
        self._meridian.n_media_channels > 0
        and self._meridian.model_spec.effective_media_prior_type
        == constants.TREATMENT_PRIOR_TYPE_ROI
    )

  def _rf_uses_roi_prior(self) -> bool:
    """Returns True if the model has R&F channels with an ROI prior type."""
    return (
        self._meridian.n_rf_channels > 0
        and self._meridian.model_spec.effective_rf_prior_type
        == constants.TREATMENT_PRIOR_TYPE_ROI
    )

  def _uses_roi_priors(self):
    """Checks if the model uses ROI priors for media or R&F channels."""
    return self._media_uses_roi_prior() or self._rf_uses_roi_prior()

  def _has_custom_roi_priors(self):
    """Checks if any ROI prior in use differs from the default distribution.

    Only channel groups that actually use an ROI prior type are compared:
    `roi_m` for media channels and `roi_rf` for R&F channels.
    """
    default_distribution = prior_distribution.PriorDistribution()
    if self._media_uses_roi_prior() and not (
        prior_distribution.distributions_are_equal(
            self._meridian.model_spec.prior.roi_m, default_distribution.roi_m
        )
    ):
      return True
    if self._rf_uses_roi_prior() and not (
        prior_distribution.distributions_are_equal(
            self._meridian.model_spec.prior.roi_rf, default_distribution.roi_rf
        )
    ):
      return True
    return False

  def _should_skip(self, check_class) -> bool:
    """Returns True if `check_class` does not apply to this model."""
    if check_class == checks.PriorPosteriorShiftCheck:
      # The Prior-Posterior Shift check is only meaningful with ROI priors.
      return not self._uses_roi_priors()
    if check_class == checks.ROIConsistencyCheck:
      # The ROI Consistency check only runs when custom ROI priors are given.
      return not self._has_custom_roi_priors()
    return False

  def _overall_status_and_message(self) -> tuple[results.Status, str]:
    """Aggregates the recorded results into an overall status and message.

    The overall status is FAIL if any check failed, otherwise PASS; the message
    additionally reflects whether any check asked for review.
    """
    has_failures = any(
        res.case.status is results.Status.FAIL for res in self._results
    )
    has_reviews = any(
        res.case.status is results.Status.REVIEW for res in self._results
    )

    if has_failures and has_reviews:
      return results.Status.FAIL, (
          "Failed: Quality issues were detected in your model. Follow"
          " recommendations to address any failed checks and review"
          " results to determine if further action is needed."
      )
    if has_failures:
      return results.Status.FAIL, (
          "Failed: Quality issues were detected in your model. Address failed"
          " checks before proceeding."
      )
    if has_reviews:
      return results.Status.PASS, "Passed with reviews: Review is needed."
    return (
        results.Status.PASS,
        "Passed: No major quality issues were identified.",
    )

  def run(self) -> results.ReviewSummary:
    """Executes all checks and generates the final summary.

    Returns:
      A `ReviewSummary` with the overall status, a summary message, and the
      individual check results. The results list is a snapshot, so calling
      `run()` again does not alter summaries that were returned earlier.
    """
    self._results.clear()
    self._run_and_handle(checks.ConvergenceCheck, configs.ConvergenceConfig())

    # Stop if the model did not converge.
    if (
        self._results
        and self._results[0].case is results.ConvergenceCases.NOT_CONVERGED
    ):
      return results.ReviewSummary(
          overall_status=results.Status.FAIL,
          summary_message=(
              "Failed: Model did not converge. Other checks were skipped."
          ),
          # Copy: `self._results` is cleared and reused by the next `run()`.
          results=list(self._results),
      )

    # Run all other checks in sequence, skipping those that do not apply.
    for check_class, config in self._post_convergence_checks.items():
      if self._should_skip(check_class):
        continue
      self._run_and_handle(check_class, config)

    overall_status, summary_message = self._overall_status_and_message()
    return results.ReviewSummary(
        overall_status=overall_status,
        summary_message=summary_message,
        # Copy: `self._results` is cleared and reused by the next `run()`.
        results=list(self._results),
    )
@@ -20,6 +20,7 @@ import os
20
20
 
21
21
  import jinja2
22
22
  from meridian import constants as c
23
+ from meridian.analysis import analyzer
23
24
  from meridian.analysis import formatter
24
25
  from meridian.analysis import summary_text
25
26
  from meridian.analysis import visualizer
@@ -59,17 +60,18 @@ RESPONSE_CURVES_CARD_SPEC = formatter.CardSpec(
59
60
  class Summarizer:
60
61
  """Generates HTML summary visualizations from the model fitting."""
61
62
 
62
- def __init__(self, meridian: model.Meridian):
63
+ def __init__(self, meridian: model.Meridian, use_kpi: bool = False):
63
64
  """Initialize the visualizer classes that are not time-dependent."""
64
65
  self._meridian = meridian
66
+ self._use_kpi = analyzer.Analyzer(meridian)._use_kpi(use_kpi)
65
67
 
66
68
  @functools.cached_property
67
69
  def _model_fit(self):
68
- return visualizer.ModelFit(self._meridian)
70
+ return visualizer.ModelFit(self._meridian, use_kpi=self._use_kpi)
69
71
 
70
72
  @functools.cached_property
71
73
  def _model_diagnostics(self):
72
- return visualizer.ModelDiagnostics(self._meridian)
74
+ return visualizer.ModelDiagnostics(self._meridian, use_kpi=self._use_kpi)
73
75
 
74
76
  def output_model_results_summary(
75
77
  self,
@@ -77,7 +79,6 @@ class Summarizer:
77
79
  filepath: str,
78
80
  start_date: tc.Date = None,
79
81
  end_date: tc.Date = None,
80
- use_kpi: bool = False,
81
82
  ):
82
83
  """Generates and saves the HTML results summary output.
83
84
 
@@ -87,18 +88,15 @@ class Summarizer:
87
88
  start_date: Optional start date selector, *inclusive*, in _yyyy-mm-dd_
88
89
  format.
89
90
  end_date: Optional end date selector, *inclusive* in _yyyy-mm-dd_ format.
90
- use_kpi: If `True`, calculate the incremental KPI. Otherwise, calculate
91
- the incremental revenue using the revenue per KPI (if available).
92
91
  """
93
92
  os.makedirs(filepath, exist_ok=True)
94
93
  with open(os.path.join(filepath, filename), 'w') as f:
95
- f.write(self._gen_model_results_summary(start_date, end_date, use_kpi))
94
+ f.write(self._gen_model_results_summary(start_date, end_date))
96
95
 
97
96
  def _gen_model_results_summary(
98
97
  self,
99
98
  start_date: tc.Date = None,
100
99
  end_date: tc.Date = None,
101
- use_kpi: bool = False,
102
100
  ) -> str:
103
101
  """Generate HTML results summary output (as sanitized content str)."""
104
102
  all_dates = self._meridian.input_data.time_coordinates.all_dates
@@ -144,7 +142,6 @@ class Summarizer:
144
142
  cards_htmls = self._create_cards_htmls(
145
143
  template_env,
146
144
  selected_times=selected_times,
147
- use_kpi=use_kpi,
148
145
  )
149
146
 
150
147
  return html_template.render(
@@ -155,29 +152,29 @@ class Summarizer:
155
152
  self,
156
153
  template_env: jinja2.Environment,
157
154
  selected_times: Sequence[str] | None,
158
- use_kpi: bool,
159
155
  ) -> Sequence[str]:
160
156
  """Creates the HTML snippets for cards in the summary page."""
161
157
  media_summary = visualizer.MediaSummary(
162
- self._meridian, selected_times=selected_times
158
+ self._meridian, selected_times=selected_times, use_kpi=self._use_kpi
159
+ )
160
+ media_effects = visualizer.MediaEffects(
161
+ self._meridian, use_kpi=self._use_kpi
163
162
  )
164
- media_effects = visualizer.MediaEffects(self._meridian)
165
163
  reach_frequency = (
166
164
  visualizer.ReachAndFrequency(
167
- self._meridian, selected_times=selected_times
165
+ self._meridian, selected_times=selected_times, use_kpi=self._use_kpi
168
166
  )
169
167
  if self._meridian.n_rf_channels > 0
170
168
  else None
171
169
  )
172
170
  cards = [
173
171
  self._create_model_fit_card_html(
174
- template_env, selected_times=selected_times, use_kpi=use_kpi
172
+ template_env, selected_times=selected_times
175
173
  ),
176
174
  self._create_outcome_contrib_card_html(
177
175
  template_env,
178
176
  media_summary,
179
177
  selected_times=selected_times,
180
- use_kpi=use_kpi,
181
178
  ),
182
179
  self._create_performance_breakdown_card_html(
183
180
  template_env, media_summary
@@ -188,17 +185,16 @@ class Summarizer:
188
185
  media_summary=media_summary,
189
186
  media_effects=media_effects,
190
187
  reach_frequency=reach_frequency,
191
- use_kpi=use_kpi,
192
188
  ),
193
189
  ]
194
190
  return cards
195
191
 
196
192
  def _create_model_fit_card_html(
197
- self, template_env: jinja2.Environment, use_kpi: bool, **kwargs
193
+ self, template_env: jinja2.Environment, **kwargs
198
194
  ) -> str:
199
195
  """Creates the HTML snippet for the Model Fit card."""
200
196
  model_fit = self._model_fit
201
- outcome = self._kpi_or_revenue(use_kpi)
197
+ outcome = self._kpi_or_revenue()
202
198
  expected_actual_outcome_chart = formatter.ChartSpec(
203
199
  id=summary_text.EXPECTED_ACTUAL_OUTCOME_CHART_ID,
204
200
  description=summary_text.EXPECTED_ACTUAL_OUTCOME_CHART_DESCRIPTION_FORMAT.format(
@@ -207,9 +203,7 @@ class Summarizer:
207
203
  chart_json=model_fit.plot_model_fit(**kwargs).to_json(),
208
204
  )
209
205
 
210
- predictive_accuracy_table = self._predictive_accuracy_table_spec(
211
- use_kpi=use_kpi, **kwargs
212
- )
206
+ predictive_accuracy_table = self._predictive_accuracy_table_spec(**kwargs)
213
207
  insights = summary_text.MODEL_FIT_INSIGHTS_FORMAT
214
208
 
215
209
  return formatter.create_card_html(
@@ -219,11 +213,9 @@ class Summarizer:
219
213
  [expected_actual_outcome_chart, predictive_accuracy_table],
220
214
  )
221
215
 
222
- def _predictive_accuracy_table_spec(
223
- self, use_kpi: bool, **kwargs
224
- ) -> formatter.TableSpec:
216
+ def _predictive_accuracy_table_spec(self, **kwargs) -> formatter.TableSpec:
225
217
  """Creates the HTML snippet for the predictive accuracy table."""
226
- outcome = self._kpi_or_revenue(use_kpi)
218
+ outcome = self._kpi_or_revenue()
227
219
  model_diag = self._model_diagnostics
228
220
  table = model_diag.predictive_accuracy_table(column_var=c.METRIC, **kwargs)
229
221
 
@@ -284,10 +276,9 @@ class Summarizer:
284
276
  template_env: jinja2.Environment,
285
277
  media_summary: visualizer.MediaSummary,
286
278
  selected_times: Sequence[str] | None,
287
- use_kpi: bool,
288
279
  ) -> str:
289
280
  """Creates the HTML snippet for the Outcome Contrib card."""
290
- outcome = self._kpi_or_revenue(use_kpi)
281
+ outcome = self._kpi_or_revenue()
291
282
 
292
283
  num_selected_times = (
293
284
  self._meridian.n_times
@@ -457,10 +448,9 @@ class Summarizer:
457
448
  media_summary: visualizer.MediaSummary,
458
449
  media_effects: visualizer.MediaEffects,
459
450
  reach_frequency: visualizer.ReachAndFrequency | None,
460
- use_kpi: bool,
461
451
  ) -> str:
462
452
  """Creates the HTML snippet for the Optimal Analyst card."""
463
- outcome = self._kpi_or_revenue(use_kpi)
453
+ outcome = self._kpi_or_revenue()
464
454
  charts = []
465
455
  charts.append(
466
456
  formatter.ChartSpec(
@@ -473,7 +463,6 @@ class Summarizer:
473
463
  selected_times=(
474
464
  frozenset(selected_times) if selected_times else None
475
465
  ),
476
- use_kpi=use_kpi,
477
466
  plot_separately=False,
478
467
  include_ci=False,
479
468
  num_channels_displayed=7,
@@ -541,7 +530,5 @@ class Summarizer:
541
530
  rf_channel=most_spend_rf_channel
542
531
  ).optimal_frequency
543
532
 
544
- def _kpi_or_revenue(self, use_kpi: bool) -> str:
545
- if use_kpi or self._meridian.input_data.revenue_per_kpi is None:
546
- return c.KPI.upper()
547
- return c.REVENUE
533
+ def _kpi_or_revenue(self) -> str:
534
+ return c.KPI.upper() if self._use_kpi else c.REVENUE
@@ -18,4 +18,16 @@ limitations under the License.
18
18
  <chip>
19
19
  Time period: {{ start_date }} - {{ end_date }}
20
20
  </chip>
21
+ {% if selected_geos %}
22
+ {%- set geo_count = selected_geos | length %}
23
+ {%- set full_list = selected_geos | join(', ') %}
24
+ <chip title="Selected: {{ full_list }}">
25
+ Selected geos:
26
+ {%- if geo_count > 5 %}
27
+ {{ selected_geos[:5] | join(', ') }}, + {{ geo_count - 5 }} more
28
+ {%- else %}
29
+ {{ full_list }}
30
+ {%- endif %}
31
+ </chip>
32
+ {% endif %}
21
33
  </chips>
@@ -726,8 +726,6 @@ INC_OUTCOME_NON_PAID_USE_PRIOR = np.array([[
726
726
  2.224e05,
727
727
  ],
728
728
  ]])
729
-
730
-
731
729
  INC_OUTCOME_NON_PAID_USE_POSTERIOR = np.array([
732
730
  [
733
731
  [
@@ -1055,7 +1053,6 @@ INC_OUTCOME_NON_PAID_USE_POSTERIOR = np.array([
1055
1053
  ],
1056
1054
  ])
1057
1055
 
1058
-
1059
1056
  INC_OUTCOME_NON_MEDIA_MAX = np.array([
1060
1057
  [
1061
1058
  [
@@ -1282,7 +1279,6 @@ INC_OUTCOME_NON_MEDIA_MAX = np.array([
1282
1279
  ],
1283
1280
  ],
1284
1281
  ])
1285
-
1286
1282
  INC_OUTCOME_NON_MEDIA_MIX = np.array([
1287
1283
  [
1288
1284
  [
@@ -1509,7 +1505,6 @@ INC_OUTCOME_NON_MEDIA_MIX = np.array([
1509
1505
  ],
1510
1506
  ],
1511
1507
  ])
1512
-
1513
1508
  INC_OUTCOME_NON_MEDIA_FIXED = np.array([
1514
1509
  [
1515
1510
  [
@@ -1737,6 +1732,33 @@ INC_OUTCOME_NON_MEDIA_FIXED = np.array([
1737
1732
  ],
1738
1733
  ])
1739
1734
 
1735
+ EXP_OUTCOME_MEDIA_AND_RF = np.array([
1736
+ [
1737
+ 32254.154,
1738
+ 32296.66,
1739
+ 32331.918,
1740
+ 32391.336,
1741
+ 32797.367,
1742
+ 33077.375,
1743
+ 33469.82,
1744
+ 33316.707,
1745
+ 33251.33,
1746
+ 33261.21,
1747
+ ],
1748
+ [
1749
+ 53146.668,
1750
+ 53151.16,
1751
+ 53404.375,
1752
+ 53814.035,
1753
+ 54136.48,
1754
+ 54141.82,
1755
+ 54581.93,
1756
+ 54777.164,
1757
+ 54912.758,
1758
+ 54915.656,
1759
+ ],
1760
+ ])
1761
+
1740
1762
  MROI_MEDIA_AND_RF_USE_PRIOR = np.array([[
1741
1763
  [0.3399, 1.7045, 3.1300, 2.7845, 0.3523],
1742
1764
  [0.4540, 1.3445, 0.9966, 0.2985, 0.4917],