google-meridian 1.5.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: google-meridian
3
- Version: 1.5.0
3
+ Version: 1.5.1
4
4
  Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
5
5
  Author-email: The Meridian Authors <no-reply@google.com>
6
6
  Project-URL: homepage, https://github.com/google/meridian
@@ -210,7 +210,7 @@ To cite this repository:
210
210
  author = {Google Meridian Marketing Mix Modeling Team},
211
211
  title = {Meridian: Marketing Mix Modeling},
212
212
  url = {https://github.com/google/meridian},
213
- version = {1.5.0},
213
+ version = {1.5.1},
214
214
  year = {2025},
215
215
  }
216
216
  ```
@@ -1,7 +1,7 @@
1
- google_meridian-1.5.0.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
1
+ google_meridian-1.5.1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
2
2
  meridian/__init__.py,sha256=0fOT5oNZF7-pbiWWGUefV-ysafttieG079m1ijMFQO8,861
3
3
  meridian/constants.py,sha256=3OTX4TKp_lLuuO6V2fA2icgoJQxQUHJjIDxH8yBD4Kw,20359
4
- meridian/version.py,sha256=y8_0XVrWHR9rLHxSjY_gIn3J8QixGcPiCDPA5JnNYSY,644
4
+ meridian/version.py,sha256=kAZOoCSOwL13YeWOWHzABMmvsYhX2egAYQqEeUsPS2M,644
5
5
  meridian/analysis/__init__.py,sha256=AM7xpqoeC-mmY4tPIyHisjQ2MICI7v3jSri--DhDqXA,874
6
6
  meridian/analysis/analyzer.py,sha256=8x6yrDnk_Sy_-fp9M9ZDZzKXqrszrZ_xmXcA78Tf3AY,224888
7
7
  meridian/analysis/optimizer.py,sha256=2_saKikIXuQuwGzNJpkuX2hYpdgcoBYR6_-m1MbbkSM,127223
@@ -10,11 +10,11 @@ meridian/analysis/summary_text.py,sha256=I_smDkZJYp2j77ea-9AIbgeraDa7-qUYyb-IthP
10
10
  meridian/analysis/test_utils.py,sha256=pQQPhKertGawgH2ry1hxiV1aAOlVcV5aHoafqQPzS6s,98743
11
11
  meridian/analysis/visualizer.py,sha256=MprZXNMOAF1BTJi5zdtHRnSwNZ20YYEe8UL2jGJk64k,94850
12
12
  meridian/analysis/review/__init__.py,sha256=cF24EbhiVSs-tvtRf59uVin39tu6aCTTCaeEdv6ISZ8,804
13
- meridian/analysis/review/checks.py,sha256=QEUVwC8L9Pif3y0B_OVAUpZeN6EnFGKH2oXA_dsdgbc,26893
13
+ meridian/analysis/review/checks.py,sha256=Q-niQrgyird1orFgkjgEqQDef57ooRRKR8IhYFPMXsc,27155
14
14
  meridian/analysis/review/configs.py,sha256=5JJ8v6n22GNBmE78xNX6jwdjkZz2qar4Q9YTcVqzcoI,3653
15
- meridian/analysis/review/constants.py,sha256=9tnnc_vAaJi9mnZ0GrRr86RsVq5fyhBaEtIvOLNmn8A,1498
16
- meridian/analysis/review/results.py,sha256=ZZiAdFrqySzcbjrCEacBLQS_ddiklZ34yD1BWns5SYI,17418
17
- meridian/analysis/review/reviewer.py,sha256=6nza-toZhDNWs_x-A6CdN41vYfyhXMk4h_Ckota7Rxg,6766
15
+ meridian/analysis/review/constants.py,sha256=DM6mgDXiLXcyA89EYthWOHHCN6_CP3ey_DuPhP3ZWu4,1497
16
+ meridian/analysis/review/results.py,sha256=HtKW3qw8T2wJ_Ei4wv-qhWTizczD3H_dS0tc9Nik5-4,19252
17
+ meridian/analysis/review/reviewer.py,sha256=llE4ssH4QK4xVZwKcVzrXqEsef2pnCWTKHnmA4sT46M,6009
18
18
  meridian/backend/__init__.py,sha256=DaFTfvsqYtkheFvgV2kdPsyJoz8c-X2_ISSMlleHbVk,45411
19
19
  meridian/backend/config.py,sha256=B9VQnhBfg9RW04GNbt7F5uCugByenoJzt-keFLLYEp8,3561
20
20
  meridian/backend/test_utils.py,sha256=oJNosF_x_BzNuia8LzLFb_YfjGWHRCzR5FXNN5KQ8sw,13738
@@ -45,7 +45,7 @@ meridian/model/spec.py,sha256=hmVz1LZlE1un3Lt2Hx6L8FR7iG8OtL1i6XScCXqvVzE,19684
45
45
  meridian/model/transformers.py,sha256=HxlVJitxP-wu-NOHU0tArFUZ4NAO3c7adAYj4Zvqnvo,8363
46
46
  meridian/model/eda/__init__.py,sha256=bMj9kd2LWU_LQZAjQv54FFggzdv4CKRYblvc-0cHXc4,768
47
47
  meridian/model/eda/constants.py,sha256=maaZ0suGwhWbHIoNqQis9mV4LwlNexyADYx92U2Mrew,15124
48
- meridian/model/eda/eda_engine.py,sha256=MGrnw2-AocaNW-g4jA1N4wm0di5Hl48MWnJiZbu96JM,88467
48
+ meridian/model/eda/eda_engine.py,sha256=-lz2PLAhujkCfq2IMhDnIFePPEF-oZAScnuvFmgJK-Y,88516
49
49
  meridian/model/eda/eda_outcome.py,sha256=xCy0sl92Vge0ANnMqLuadjFTeZlyAs0rBh0zRBUrpzM,11328
50
50
  meridian/model/eda/eda_spec.py,sha256=diieYyZH0ee3ZLy0rGFMcWrrgiUrz2HctMwOrmtJR6w,2871
51
51
  meridian/templates/card.html.jinja,sha256=AWgKGLPf7qTFVNy-vylXm-_tatzW9ngPk1mJm9sTCPg,1332
@@ -99,14 +99,14 @@ schema/serde/function_registry.py,sha256=GbgC5_9NDcA9Y7nqmdJ-4-LK5JPhhfI50Lmfy5Z
99
99
  schema/serde/hyperparameters.py,sha256=0Lgep_lT5Ro6svvLPdR6OyL_qCb0-bRrxJVxsmySmJs,12176
100
100
  schema/serde/inference_data.py,sha256=DrwE9hU8LMrl0z8W_sUSIaPrRdym_lu0iOqpT4KZxsA,3623
101
101
  schema/serde/marketing_data.py,sha256=yb-fRTe84Sjg7-v3wsvYRRXvrxLSFWSenO0_ikMvUpk,44845
102
- schema/serde/meridian_serde.py,sha256=VQG1eakJ3nZhT_gndfss_nZlxTlGe5EIWLSEgcaHqV8,16134
102
+ schema/serde/meridian_serde.py,sha256=5q2AkZ52Ew0SJUH9g4VXqWHSwjzVJ_-ChCx6B5FA8CE,16246
103
103
  schema/serde/serde.py,sha256=8vUqhJxvZgX9UY3rXTyWJznRgapwDzzaHXDHwV_kKTA,1612
104
104
  schema/serde/test_data.py,sha256=7hfEWyvZ9WcAkVAOXt6elX8stJlsfhfd-ASlHo9SRb8,107342
105
105
  schema/utils/__init__.py,sha256=OzDmXWCpogCt6EkremIShzTowsZF8dHzfEjkJkE9qfk,767
106
106
  schema/utils/date_range_bucketing.py,sha256=14vcRGf3odWT9mBdCykRNmVCEiuUI_1SvVygNzvqBuM,3809
107
107
  schema/utils/proto_enum_converter.py,sha256=vCKGQGWfCt6W7GZy7QQRFAj3XqLUQwt_eWZzsX6pA0E,4021
108
108
  schema/utils/time_record.py,sha256=-KzHFjvSBUUXsfESPAfcJP_VFxaFLqj90Ac0kgKWfpI,4624
109
- google_meridian-1.5.0.dist-info/METADATA,sha256=wW-a2DyNSLr6ySi0ola2R-hP2sIt9yt96MFJLrbGGP4,10024
110
- google_meridian-1.5.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
111
- google_meridian-1.5.0.dist-info/top_level.txt,sha256=oAi0z-fUuo6p8SnJ0WrojGR2mKOWDz43yr6EjzaXqy8,32
112
- google_meridian-1.5.0.dist-info/RECORD,,
109
+ google_meridian-1.5.1.dist-info/METADATA,sha256=SWHh9POzkljYJxkNgvQaG0GMHK9BHsjsW0ksaoIjqaI,10024
110
+ google_meridian-1.5.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
111
+ google_meridian-1.5.1.dist-info/top_level.txt,sha256=oAi0z-fUuo6p8SnJ0WrojGR2mKOWDz43yr6EjzaXqy8,32
112
+ google_meridian-1.5.1.dist-info/RECORD,,
@@ -15,7 +15,7 @@
15
15
  """Implementation of the Model Quality Checks."""
16
16
 
17
17
  import abc
18
- from collections.abc import Sequence
18
+ from collections.abc import MutableMapping, Sequence
19
19
  import dataclasses
20
20
  from typing import Generic, TypeVar
21
21
  import warnings
@@ -77,31 +77,16 @@ class ConvergenceCheck(
77
77
  if not valid_rhat_items:
78
78
  return results.ConvergenceCheckResult(
79
79
  case=results.ConvergenceCases.CONVERGED,
80
- details={
81
- review_constants.RHAT: np.nan,
82
- review_constants.PARAMETER: np.nan,
83
- review_constants.CONVERGENCE_THRESHOLD: (
84
- self._config.convergence_threshold
85
- ),
86
- },
80
+ config=self._config,
81
+ max_rhat=np.nan,
82
+ max_parameter=np.nan,
87
83
  )
88
84
 
89
85
  max_parameter, max_rhat = max(max_rhats.items(), key=lambda item: item[1])
90
86
 
91
- details = {
92
- review_constants.RHAT: max_rhat,
93
- review_constants.PARAMETER: max_parameter,
94
- review_constants.CONVERGENCE_THRESHOLD: (
95
- self._config.convergence_threshold
96
- ),
97
- }
98
-
99
87
  # Case 1: Converged.
100
88
  if max_rhat < self._config.convergence_threshold:
101
- return results.ConvergenceCheckResult(
102
- case=results.ConvergenceCases.CONVERGED,
103
- details=details,
104
- )
89
+ case = results.ConvergenceCases.CONVERGED
105
90
 
106
91
  # Case 2: Not fully converged, but potentially acceptable.
107
92
  elif (
@@ -109,17 +94,18 @@ class ConvergenceCheck(
109
94
  <= max_rhat
110
95
  < self._config.not_fully_convergence_threshold
111
96
  ):
112
- return results.ConvergenceCheckResult(
113
- case=results.ConvergenceCases.NOT_FULLY_CONVERGED,
114
- details=details,
115
- )
97
+ case = results.ConvergenceCases.NOT_FULLY_CONVERGED
116
98
 
117
99
  # Case 3: Not converged and unacceptable.
118
100
  else: # max_rhat >= divergence_threshold
119
- return results.ConvergenceCheckResult(
120
- case=results.ConvergenceCases.NOT_CONVERGED,
121
- details=details,
122
- )
101
+ case = results.ConvergenceCases.NOT_CONVERGED
102
+
103
+ return results.ConvergenceCheckResult(
104
+ case=case,
105
+ config=self._config,
106
+ max_rhat=max_rhat,
107
+ max_parameter=max_parameter,
108
+ )
123
109
 
124
110
 
125
111
  # ==============================================================================
@@ -131,33 +117,25 @@ class BaselineCheck(
131
117
  """Checks for negative baseline probability."""
132
118
 
133
119
  def run(self) -> results.BaselineCheckResult:
134
- prob = self._analyzer.negative_baseline_probability()
135
- details = {
136
- review_constants.NEGATIVE_BASELINE_PROB: prob,
137
- review_constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: (
138
- self._config.negative_baseline_prob_fail_threshold
139
- ),
140
- review_constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
141
- self._config.negative_baseline_prob_review_threshold
142
- ),
143
- }
120
+ prob = float(self._analyzer.negative_baseline_probability())
121
+
144
122
  # Case 1: FAIL
145
123
  if prob > self._config.negative_baseline_prob_fail_threshold:
146
- return results.BaselineCheckResult(
147
- case=results.BaselineCases.FAIL,
148
- details=details,
149
- )
124
+ case = results.BaselineCases.FAIL
125
+
150
126
  # Case 2: REVIEW
151
127
  elif prob >= self._config.negative_baseline_prob_review_threshold:
152
- return results.BaselineCheckResult(
153
- case=results.BaselineCases.REVIEW,
154
- details=details,
155
- )
128
+ case = results.BaselineCases.REVIEW
129
+
156
130
  # Case 3: PASS
157
131
  else:
158
- return results.BaselineCheckResult(
159
- case=results.BaselineCases.PASS, details=details
160
- )
132
+ case = results.BaselineCases.PASS
133
+
134
+ return results.BaselineCheckResult(
135
+ case=case,
136
+ config=self._config,
137
+ negative_baseline_prob=prob,
138
+ )
161
139
 
162
140
 
163
141
  # ==============================================================================
@@ -189,45 +167,40 @@ class BayesianPPPCheck(
189
167
  >= np.abs(total_outcome_actual - total_outcome_expected_mean)
190
168
  )
191
169
 
192
- details = {
193
- review_constants.BAYESIAN_PPP: bayesian_ppp,
194
- }
195
-
196
170
  if bayesian_ppp >= self._config.ppp_threshold:
197
- return results.BayesianPPPCheckResult(
198
- case=results.BayesianPPPCases.PASS,
199
- details=details,
200
- )
171
+ case = results.BayesianPPPCases.PASS
201
172
  else:
202
- return results.BayesianPPPCheckResult(
203
- case=results.BayesianPPPCases.FAIL,
204
- details=details,
205
- )
173
+ case = results.BayesianPPPCases.FAIL
174
+
175
+ return results.BayesianPPPCheckResult(
176
+ case=case,
177
+ config=self._config,
178
+ bayesian_ppp=bayesian_ppp,
179
+ )
206
180
 
207
181
 
208
182
  # ==============================================================================
209
183
  # Check: Goodness of Fit
210
184
  # ==============================================================================
211
- def _set_details_from_gof_dataframe(
212
- details: dict[str, float],
185
+ def _set_metrics_from_gof_dataframe(
186
+ metrics: MutableMapping[str, float],
213
187
  gof_df: pd.DataFrame,
214
188
  geo_granularity: str,
215
- suffix: str | None = None,
189
+ suffix: str,
216
190
  ) -> None:
217
- """Sets the `details` variable of the GoodnessOfFitCheckResult.
191
+ """Sets the `metrics` variable of the GoodnessOfFitCheckResult.
218
192
 
219
193
  This method takes a DataFrame containing goodness of fit metrics and pivots it
220
- to a Series, which is then added to the `details` variable of the
194
+ to a Series, which is then added to the `metrics` variable of the
221
195
  `GoodnessOfFitCheckResult`.
222
196
 
223
197
  Args:
224
- details: A dictionary to store the goodness of fit metrics in.
198
+ metrics: A dictionary to store the goodness of fit metrics in.
225
199
  gof_df: A DataFrame containing predictive accuracy of the whole data (if
226
200
  holdout set is not used) of filtered to a single evaluation set ("all",
227
201
  "train", or "test").
228
202
  geo_granularity: The geo granularity of the data ("geo" or "national").
229
- suffix: A suffix to add to the metric names (e.g., "all", "train", "test").
230
- If None, the metrics are added without a suffix.
203
+ suffix: A suffix to add to the metric names (e.g., "_train", "_test").
231
204
  """
232
205
  gof_metrics_pivoted = gof_df.pivot(
233
206
  index=constants.GEO_GRANULARITY,
@@ -235,22 +208,15 @@ def _set_details_from_gof_dataframe(
235
208
  values=constants.VALUE,
236
209
  )
237
210
  gof_metrics_series = gof_metrics_pivoted.loc[geo_granularity]
238
- if suffix is not None:
239
- details[f"{review_constants.R_SQUARED}_{suffix}"] = gof_metrics_series[
240
- constants.R_SQUARED
241
- ]
242
- details[f"{review_constants.MAPE}_{suffix}"] = gof_metrics_series[
243
- constants.MAPE
244
- ]
245
- details[f"{review_constants.WMAPE}_{suffix}"] = gof_metrics_series[
246
- constants.WMAPE
247
- ]
248
- else:
249
- details[review_constants.R_SQUARED] = gof_metrics_series[
250
- constants.R_SQUARED
251
- ]
252
- details[review_constants.MAPE] = gof_metrics_series[constants.MAPE]
253
- details[review_constants.WMAPE] = gof_metrics_series[constants.WMAPE]
211
+ metrics[f"{review_constants.R_SQUARED}{suffix}"] = gof_metrics_series[
212
+ constants.R_SQUARED
213
+ ]
214
+ metrics[f"{review_constants.MAPE}{suffix}"] = gof_metrics_series[
215
+ constants.MAPE
216
+ ]
217
+ metrics[f"{review_constants.WMAPE}{suffix}"] = gof_metrics_series[
218
+ constants.WMAPE
219
+ ]
254
220
 
255
221
 
256
222
  class GoodnessOfFitCheck(
@@ -269,7 +235,7 @@ class GoodnessOfFitCheck(
269
235
  gof_metrics = gof_df[gof_df[constants.GEO_GRANULARITY] == geo_granularity]
270
236
  is_holdout = constants.EVALUATION_SET_VAR in gof_df.columns
271
237
 
272
- details = {}
238
+ metrics_dict = {}
273
239
  case = results.GoodnessOfFitCases.PASS
274
240
 
275
241
  if is_holdout:
@@ -281,29 +247,71 @@ class GoodnessOfFitCheck(
281
247
  set_metrics = gof_metrics[
282
248
  gof_metrics[constants.EVALUATION_SET_VAR] == evaluation_set
283
249
  ]
284
- _set_details_from_gof_dataframe(
285
- details=details,
250
+ _set_metrics_from_gof_dataframe(
251
+ metrics=metrics_dict,
286
252
  gof_df=set_metrics,
287
253
  geo_granularity=geo_granularity,
288
254
  suffix=suffix,
289
255
  )
290
- if details[f"{review_constants.R_SQUARED}_{suffix}"] <= 0:
256
+ if metrics_dict[f"{review_constants.R_SQUARED}{suffix}"] <= 0:
291
257
  case = results.GoodnessOfFitCases.REVIEW
258
+ return results.GoodnessOfFitCheckResult(
259
+ case=case,
260
+ metrics=results.GoodnessOfFitMetrics(
261
+ r_squared=metrics_dict[
262
+ f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
263
+ ],
264
+ mape=metrics_dict[
265
+ f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
266
+ ],
267
+ wmape=metrics_dict[
268
+ f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
269
+ ],
270
+ r_squared_train=metrics_dict[
271
+ f"{review_constants.R_SQUARED}{review_constants.TRAIN_SUFFIX}"
272
+ ],
273
+ mape_train=metrics_dict[
274
+ f"{review_constants.MAPE}{review_constants.TRAIN_SUFFIX}"
275
+ ],
276
+ wmape_train=metrics_dict[
277
+ f"{review_constants.WMAPE}{review_constants.TRAIN_SUFFIX}"
278
+ ],
279
+ r_squared_test=metrics_dict[
280
+ f"{review_constants.R_SQUARED}{review_constants.TEST_SUFFIX}"
281
+ ],
282
+ mape_test=metrics_dict[
283
+ f"{review_constants.MAPE}{review_constants.TEST_SUFFIX}"
284
+ ],
285
+ wmape_test=metrics_dict[
286
+ f"{review_constants.WMAPE}{review_constants.TEST_SUFFIX}"
287
+ ],
288
+ ),
289
+ is_holdout=is_holdout,
290
+ )
292
291
  else:
293
- _set_details_from_gof_dataframe(
294
- details=details,
292
+ _set_metrics_from_gof_dataframe(
293
+ metrics=metrics_dict,
295
294
  gof_df=gof_metrics,
296
295
  geo_granularity=geo_granularity,
297
- suffix=None,
296
+ suffix=review_constants.ALL_SUFFIX,
298
297
  )
299
- if details[review_constants.R_SQUARED] <= 0:
298
+ if metrics_dict[review_constants.R_SQUARED] <= 0:
300
299
  case = results.GoodnessOfFitCases.REVIEW
301
-
302
- return results.GoodnessOfFitCheckResult(
303
- case=case,
304
- details=details,
305
- is_holdout=is_holdout,
306
- )
300
+ return results.GoodnessOfFitCheckResult(
301
+ case=case,
302
+ metrics=results.GoodnessOfFitMetrics(
303
+ r_squared=metrics_dict[
304
+ f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
305
+ ],
306
+ mape=metrics_dict[
307
+ f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
308
+ ],
309
+ wmape=metrics_dict[
310
+ f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
311
+ ],
312
+ ),
313
+ is_holdout=is_holdout,
314
+ )
307
315
 
308
316
 
309
317
  # ==============================================================================
@@ -475,8 +483,10 @@ def _compute_channel_results(
475
483
  channel_results.append(
476
484
  results.ROIConsistencyChannelResult(
477
485
  case=case,
478
- details={},
479
486
  channel_name=channel,
487
+ prior_roi_lo=np.nan,
488
+ prior_roi_hi=np.nan,
489
+ posterior_roi_mean=np.nan,
480
490
  )
481
491
  )
482
492
  for i, channel in enumerate(channel_data.all_channels):
@@ -491,14 +501,10 @@ def _compute_channel_results(
491
501
  channel_results.append(
492
502
  results.ROIConsistencyChannelResult(
493
503
  case=case,
494
- details={
495
- review_constants.PRIOR_ROI_LO: channel_data.prior_roi_los[i],
496
- review_constants.PRIOR_ROI_HI: channel_data.prior_roi_his[i],
497
- review_constants.POSTERIOR_ROI_MEAN: (
498
- channel_data.posterior_means[i]
499
- ),
500
- },
501
504
  channel_name=channel,
505
+ prior_roi_lo=channel_data.prior_roi_los[i],
506
+ prior_roi_hi=channel_data.prior_roi_his[i],
507
+ posterior_roi_mean=channel_data.posterior_means[i],
502
508
  )
503
509
  )
504
510
  return channel_results
@@ -558,7 +564,7 @@ def _compute_aggregate_result(
558
564
 
559
565
  return results.ROIConsistencyCheckResult(
560
566
  case=aggregate_case,
561
- details=aggregate_details,
567
+ aggregate_details=aggregate_details,
562
568
  channel_results=channel_results,
563
569
  )
564
570
 
@@ -734,7 +740,7 @@ class PriorPosteriorShiftCheck(
734
740
  no_shift_channels.append(channel_name)
735
741
  channel_results.append(
736
742
  results.PriorPosteriorShiftChannelResult(
737
- case=case, details={}, channel_name=channel_name
743
+ case=case, channel_name=channel_name
738
744
  )
739
745
  )
740
746
  return channel_results, no_shift_channels
@@ -752,17 +758,13 @@ class PriorPosteriorShiftCheck(
752
758
 
753
759
  if no_shift_channels:
754
760
  agg_case = results.PriorPosteriorShiftAggregateCases.REVIEW
755
- final_details = {
756
- "channels_str": ", ".join(
757
- f"`{channel}`" for channel in no_shift_channels
758
- )
759
- }
760
761
  else:
761
762
  agg_case = results.PriorPosteriorShiftAggregateCases.PASS
762
- final_details = {}
763
763
 
764
764
  return results.PriorPosteriorShiftCheckResult(
765
- case=agg_case, details=final_details, channel_results=channel_results
765
+ case=agg_case,
766
+ channel_results=channel_results,
767
+ no_shift_channels=no_shift_channels,
766
768
  )
767
769
 
768
770
  def run(self) -> results.PriorPosteriorShiftCheckResult:
@@ -32,9 +32,9 @@ NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD = (
32
32
  R_SQUARED = "r_squared"
33
33
  MAPE = "mape"
34
34
  WMAPE = "wmape"
35
- ALL_SUFFIX = "all"
36
- TRAIN_SUFFIX = "train"
37
- TEST_SUFFIX = "test"
35
+ ALL_SUFFIX = ""
36
+ TRAIN_SUFFIX = "_train"
37
+ TEST_SUFFIX = "_test"
38
38
  EVALUATION_SET_SUFFIXES = (ALL_SUFFIX, TRAIN_SUFFIX, TEST_SUFFIX)
39
39
  MEAN = "mean"
40
40
  VARIANCE = "variance"
@@ -14,9 +14,13 @@
14
14
 
15
15
  """Data structures for the Model Quality Checks results."""
16
16
 
17
+ import abc
18
+ from collections.abc import Mapping
17
19
  import dataclasses
18
20
  import enum
19
21
  from typing import Any
22
+
23
+ from meridian.analysis.review import configs
20
24
  from meridian.analysis.review import constants
21
25
 
22
26
 
@@ -58,11 +62,16 @@ class ModelCheckCase(BaseCase):
58
62
 
59
63
 
60
64
  @dataclasses.dataclass(frozen=True)
61
- class BaseResultData:
65
+ class BaseResultData(abc.ABC):
62
66
  """Base class for check result data."""
63
67
 
64
68
  case: BaseCase
65
- details: dict[str, Any]
69
+
70
+ @property
71
+ @abc.abstractmethod
72
+ def details(self) -> Mapping[str, Any]:
73
+ """Returns the details for message formatting."""
74
+ raise NotImplementedError
66
75
 
67
76
 
68
77
  @dataclasses.dataclass(frozen=True)
@@ -145,17 +154,18 @@ class ConvergenceCheckResult(CheckResult):
145
154
  """The immutable result of the Convergence Check."""
146
155
 
147
156
  case: ConvergenceCases
157
+ config: configs.ConvergenceConfig
158
+ max_rhat: float
159
+ max_parameter: str
148
160
 
149
- def __post_init__(self):
150
- if self.case == ConvergenceCases.CONVERGED and (
151
- constants.CONVERGENCE_THRESHOLD not in self.details
152
- ):
153
- raise ValueError(
154
- "The message template 'The model has likely converged, as all"
155
- " parameters have R-hat values < {convergence_threshold}'. is"
156
- " missing required formatting arguments: convergence_threshold."
157
- f" Details: {self.details}."
158
- )
161
+ @property
162
+ def details(self) -> Mapping[str, Any]:
163
+ """The check result details."""
164
+ return {
165
+ constants.RHAT: self.max_rhat,
166
+ constants.PARAMETER: self.max_parameter,
167
+ constants.CONVERGENCE_THRESHOLD: self.config.convergence_threshold,
168
+ }
159
169
 
160
170
 
161
171
  # ==============================================================================
@@ -223,24 +233,21 @@ class BaselineCheckResult(CheckResult):
223
233
  """The immutable result of the Baseline Check."""
224
234
 
225
235
  case: BaselineCases
236
+ config: configs.BaselineConfig
237
+ negative_baseline_prob: float
226
238
 
227
- def __post_init__(self):
228
- if self.case is BaselineCases.PASS:
229
- return
230
- if any(
231
- key not in self.details
232
- for key in (
233
- constants.NEGATIVE_BASELINE_PROB,
234
- constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD,
235
- constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD,
236
- )
237
- ):
238
- raise ValueError(
239
- "The message template is missing required formatting arguments:"
240
- " negative_baseline_prob, negative_baseline_prob_fail_threshold,"
241
- " negative_baseline_prob_review_threshold. Details:"
242
- f" {self.details}."
243
- )
239
+ @property
240
+ def details(self) -> Mapping[str, Any]:
241
+ """The check result details."""
242
+ return {
243
+ constants.NEGATIVE_BASELINE_PROB: self.negative_baseline_prob,
244
+ constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: (
245
+ self.config.negative_baseline_prob_fail_threshold
246
+ ),
247
+ constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
248
+ self.config.negative_baseline_prob_review_threshold
249
+ ),
250
+ }
244
251
 
245
252
 
246
253
  # ==============================================================================
@@ -287,14 +294,15 @@ class BayesianPPPCheckResult(CheckResult):
287
294
  """The immutable result of the Bayesian Posterior Predictive P-value Check."""
288
295
 
289
296
  case: BayesianPPPCases
297
+ config: configs.BayesianPPPConfig
298
+ bayesian_ppp: float
290
299
 
291
- def __post_init__(self):
292
- if constants.BAYESIAN_PPP not in self.details:
293
- raise ValueError(
294
- "The message template is missing required formatting arguments:"
295
- " bayesian_ppp. Details:"
296
- f" {self.details}."
297
- )
300
+ @property
301
+ def details(self) -> Mapping[str, Any]:
302
+ """The check result details."""
303
+ return {
304
+ constants.BAYESIAN_PPP: self.bayesian_ppp,
305
+ }
298
306
 
299
307
 
300
308
  # ==============================================================================
@@ -337,55 +345,77 @@ class GoodnessOfFitCases(ModelCheckCase, enum.Enum):
337
345
  super().__init__(status, message_template, recommendation)
338
346
 
339
347
 
348
+ @dataclasses.dataclass(frozen=True)
349
+ class GoodnessOfFitMetrics:
350
+ """The metrics for the Goodness of Fit Check."""
351
+
352
+ r_squared: float
353
+ mape: float
354
+ wmape: float
355
+ r_squared_train: float | None = None
356
+ mape_train: float | None = None
357
+ wmape_train: float | None = None
358
+ r_squared_test: float | None = None
359
+ mape_test: float | None = None
360
+ wmape_test: float | None = None
361
+
362
+
340
363
  @dataclasses.dataclass(frozen=True)
341
364
  class GoodnessOfFitCheckResult(CheckResult):
342
365
  """The immutable result of the Goodness of Fit Check."""
343
366
 
344
367
  case: GoodnessOfFitCases
368
+ metrics: GoodnessOfFitMetrics
345
369
  is_holdout: bool = False
346
370
 
347
371
  def __post_init__(self):
348
372
  if self.is_holdout:
349
- required_keys = []
350
- for suffix in [
351
- constants.ALL_SUFFIX,
352
- constants.TRAIN_SUFFIX,
353
- constants.TEST_SUFFIX,
354
- ]:
355
- required_keys.extend([
356
- f"{constants.R_SQUARED}_{suffix}",
357
- f"{constants.MAPE}_{suffix}",
358
- f"{constants.WMAPE}_{suffix}",
359
- ])
360
- if any(key not in self.details for key in required_keys):
373
+ if any(
374
+ metric is None
375
+ for metric in (
376
+ self.metrics.r_squared_train,
377
+ self.metrics.mape_train,
378
+ self.metrics.wmape_train,
379
+ self.metrics.r_squared_test,
380
+ self.metrics.mape_test,
381
+ self.metrics.wmape_test,
382
+ )
383
+ ):
361
384
  raise ValueError(
362
385
  "The message template is missing required formatting arguments for"
363
- f" holdout case. Required keys: {required_keys}. Details:"
364
- f" {self.details}."
365
- )
366
- elif any(
367
- key not in self.details
368
- for key in (
369
- constants.R_SQUARED,
370
- constants.MAPE,
371
- constants.WMAPE,
386
+ " holdout case. Required keys: r_squared_train, mape_train,"
387
+ " wmape_train, r_squared_test, mape_test, wmape_test. Metrics:"
388
+ f" {self.metrics}."
372
389
  )
373
- ):
374
- raise ValueError(
375
- "The message template is missing required formatting arguments:"
376
- " r_squared, mape, wmape. Details:"
377
- f" {self.details}."
378
- )
390
+
391
+ @property
392
+ def details(self) -> Mapping[str, Any]:
393
+ """The check result details."""
394
+ return {
395
+ f"{constants.R_SQUARED}{constants.ALL_SUFFIX}": self.metrics.r_squared,
396
+ f"{constants.MAPE}{constants.ALL_SUFFIX}": self.metrics.mape,
397
+ f"{constants.WMAPE}{constants.ALL_SUFFIX}": self.metrics.wmape,
398
+ f"{constants.R_SQUARED}{constants.TRAIN_SUFFIX}": (
399
+ self.metrics.r_squared_train
400
+ ),
401
+ f"{constants.MAPE}{constants.TRAIN_SUFFIX}": self.metrics.mape_train,
402
+ f"{constants.WMAPE}{constants.TRAIN_SUFFIX}": self.metrics.wmape_train,
403
+ f"{constants.R_SQUARED}{constants.TEST_SUFFIX}": (
404
+ self.metrics.r_squared_test
405
+ ),
406
+ f"{constants.MAPE}{constants.TEST_SUFFIX}": self.metrics.mape_test,
407
+ f"{constants.WMAPE}{constants.TEST_SUFFIX}": self.metrics.wmape_test,
408
+ }
379
409
 
380
410
  @property
381
411
  def recommendation(self) -> str:
382
- """Returns the check result message."""
412
+ """The check result message."""
383
413
  if self.is_holdout:
384
414
  report_str = (
385
- "R-squared = {r_squared_all:.4f} (All),"
415
+ "R-squared = {r_squared:.4f} (All),"
386
416
  " {r_squared_train:.4f} (Train), {r_squared_test:.4f} (Test); MAPE"
387
- " = {mape_all:.4f} (All), {mape_train:.4f} (Train),"
388
- " {mape_test:.4f} (Test); wMAPE = {wmape_all:.4f} (All),"
417
+ " = {mape:.4f} (All), {mape_train:.4f} (Train),"
418
+ " {mape_test:.4f} (Test); wMAPE = {wmape:.4f} (All),"
389
419
  " {wmape_train:.4f} (Train), {wmape_test:.4f} (Test)".format(
390
420
  **self.details
391
421
  )
@@ -450,6 +480,18 @@ class ROIConsistencyChannelResult(ChannelResult):
450
480
  """The immutable result of ROI Consistency Check for a single channel."""
451
481
 
452
482
  case: ROIConsistencyChannelCases
483
+ prior_roi_lo: float
484
+ prior_roi_hi: float
485
+ posterior_roi_mean: float
486
+
487
+ @property
488
+ def details(self) -> Mapping[str, Any]:
489
+ """Returns the check result details."""
490
+ return {
491
+ constants.PRIOR_ROI_LO: self.prior_roi_lo,
492
+ constants.PRIOR_ROI_HI: self.prior_roi_hi,
493
+ constants.POSTERIOR_ROI_MEAN: self.posterior_roi_mean,
494
+ }
453
495
 
454
496
 
455
497
  @dataclasses.dataclass(frozen=True)
@@ -458,6 +500,12 @@ class ROIConsistencyCheckResult(CheckResult):
458
500
 
459
501
  case: ROIConsistencyAggregateCases
460
502
  channel_results: list[ROIConsistencyChannelResult]
503
+ aggregate_details: Mapping[str, Any]
504
+
505
+ @property
506
+ def details(self) -> Mapping[str, Any]:
507
+ """Returns the check result details."""
508
+ return self.aggregate_details
461
509
 
462
510
 
463
511
  # ==============================================================================
@@ -517,6 +565,11 @@ class PriorPosteriorShiftChannelResult(ChannelResult):
517
565
 
518
566
  case: PriorPosteriorShiftChannelCases
519
567
 
568
+ @property
569
+ def details(self) -> Mapping[str, Any]:
570
+ """Returns the check result details."""
571
+ return {}
572
+
520
573
 
521
574
  @dataclasses.dataclass(frozen=True)
522
575
  class PriorPosteriorShiftCheckResult(CheckResult):
@@ -524,6 +577,16 @@ class PriorPosteriorShiftCheckResult(CheckResult):
524
577
 
525
578
  case: PriorPosteriorShiftAggregateCases
526
579
  channel_results: list[PriorPosteriorShiftChannelResult]
580
+ no_shift_channels: list[str]
581
+
582
+ @property
583
+ def details(self) -> Mapping[str, Any]:
584
+ """Returns the check result details."""
585
+ return {
586
+ "channels_str": ", ".join(
587
+ f"`{channel}`" for channel in self.no_shift_channels
588
+ )
589
+ }
527
590
 
528
591
 
529
592
  # ==============================================================================
@@ -567,7 +630,7 @@ class ReviewSummary:
567
630
  return "\n".join(report)
568
631
 
569
632
  @property
570
- def checks_status(self) -> dict[str, str]:
633
+ def checks_status(self) -> Mapping[str, str]:
571
634
  """Returns a dictionary of check names and statuses."""
572
635
  return {
573
636
  result.__class__.__name__: result.case.status.name
@@ -29,7 +29,7 @@ CheckType = typing.Type[checks.BaseCheck]
29
29
  ConfigInstance = configs.BaseConfig
30
30
  ChecksBattery = immutabledict.immutabledict[CheckType, ConfigInstance]
31
31
 
32
- _DEFAULT_POST_CONVERGENCE_CHECKS: ChecksBattery = immutabledict.immutabledict({
32
+ _POST_CONVERGENCE_CHECKS: ChecksBattery = immutabledict.immutabledict({
33
33
  checks.BaselineCheck: configs.BaselineConfig(),
34
34
  checks.BayesianPPPCheck: configs.BayesianPPPConfig(),
35
35
  checks.GoodnessOfFitCheck: configs.GoodnessOfFitConfig(),
@@ -39,39 +39,22 @@ _DEFAULT_POST_CONVERGENCE_CHECKS: ChecksBattery = immutabledict.immutabledict({
39
39
 
40
40
 
41
41
  class ModelReviewer:
42
- """Executes a series of quality checks on a Meridian model.
42
+ """A tool for executing a series of quality checks on a Meridian model.
43
43
 
44
44
  The reviewer first runs a convergence check. If the model has converged, it
45
45
  proceeds to run a battery of post-convergence checks.
46
46
 
47
- The default battery of post-convergence checks includes:
47
+ The battery of post-convergence checks includes:
48
48
  - BaselineCheck
49
49
  - BayesianPPPCheck
50
50
  - GoodnessOfFitCheck
51
51
  - PriorPosteriorShiftCheck
52
52
  - ROIConsistencyCheck
53
- Each with its default configuration.
54
-
55
- This battery of checks can be customized by passing a dictionary to the
56
- `post_convergence_checks` argument of the constructor, mapping check
57
- classes to their configuration instances. For example, to run only the
58
- BaselineCheck with a non-default configuration:
59
-
60
- ```python
61
- my_checks = {
62
- checks.BaselineCheck: configs.BaselineConfig(
63
- negative_baseline_prob_review_threshold=0.1,
64
- negative_baseline_prob_fail_threshold=0.5,
65
- )
66
- }
67
- reviewer = ModelReviewer(meridian_model, post_convergence_checks=my_checks)
68
- ```
69
53
  """
70
54
 
71
55
  def __init__(
72
56
  self,
73
57
  meridian,
74
- post_convergence_checks: ChecksBattery = _DEFAULT_POST_CONVERGENCE_CHECKS,
75
58
  ):
76
59
  self._meridian = meridian
77
60
  self._results: list[results.CheckResult] = []
@@ -79,7 +62,6 @@ class ModelReviewer:
79
62
  model_context=meridian.model_context,
80
63
  inference_data=meridian.inference_data,
81
64
  )
82
- self._post_convergence_checks = post_convergence_checks
83
65
 
84
66
  def _run_and_handle(self, check_class, config):
85
67
  instance = check_class(self._meridian, self._analyzer, config) # pytype: disable=not-instantiable
@@ -139,7 +121,7 @@ class ModelReviewer:
139
121
  )
140
122
 
141
123
  # Run all other checks in sequence.
142
- for check_class, config in self._post_convergence_checks.items():
124
+ for check_class, config in _POST_CONVERGENCE_CHECKS.items():
143
125
  if (
144
126
  check_class == checks.PriorPosteriorShiftCheck
145
127
  and not self._uses_roi_priors()
@@ -550,6 +550,7 @@ class EDAEngine:
550
550
 
551
551
  def __init__(
552
552
  self,
553
+ # TODO: b/476230365 - Remove meridian arg.
553
554
  meridian: model.Meridian | None = None,
554
555
  spec: eda_spec.EDASpec = eda_spec.EDASpec(),
555
556
  *,
meridian/version.py CHANGED
@@ -14,4 +14,4 @@
14
14
 
15
15
  """Module for Meridian version."""
16
16
 
17
- __version__ = "1.5.0"
17
+ __version__ = "1.5.1"
@@ -357,7 +357,9 @@ def save_meridian(
357
357
  if not _file_exists(os.path.dirname(file_path)):
358
358
  _make_dirs(os.path.dirname(file_path))
359
359
 
360
- with _file_open(file_path, 'wb') as f:
360
+ mode = 'wb' if file_path.endswith('.binpb') else 'w'
361
+
362
+ with _file_open(file_path, mode) as f:
361
363
  # Creates an MmmKernel.
362
364
  serialized_kernel = MeridianSerde().serialize(
363
365
  mmm,
@@ -402,7 +404,9 @@ def load_meridian(
402
404
  Returns:
403
405
  Model object loaded from the file path.
404
406
  """
405
- with _file_open(file_path, 'rb') as f:
407
+ mode = 'rb' if file_path.endswith('.binpb') else 'r'
408
+
409
+ with _file_open(file_path, mode) as f:
406
410
  if file_path.endswith('.binpb'):
407
411
  serialized_model = kernel_pb.MmmKernel.FromString(f.read())
408
412
  elif file_path.endswith('.textproto') or file_path.endswith('.txtpb'):