google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/checks.py +118 -116
  7. meridian/analysis/review/constants.py +3 -3
  8. meridian/analysis/review/results.py +131 -68
  9. meridian/analysis/review/reviewer.py +8 -23
  10. meridian/analysis/summarizer.py +6 -1
  11. meridian/analysis/test_utils.py +2898 -2538
  12. meridian/analysis/visualizer.py +28 -9
  13. meridian/backend/__init__.py +106 -0
  14. meridian/constants.py +1 -0
  15. meridian/data/input_data.py +30 -52
  16. meridian/data/input_data_builder.py +2 -9
  17. meridian/data/test_utils.py +25 -41
  18. meridian/data/validator.py +48 -0
  19. meridian/mlflow/autolog.py +19 -9
  20. meridian/model/adstock_hill.py +3 -5
  21. meridian/model/context.py +134 -0
  22. meridian/model/eda/constants.py +334 -4
  23. meridian/model/eda/eda_engine.py +724 -312
  24. meridian/model/eda/eda_outcome.py +177 -33
  25. meridian/model/model.py +159 -110
  26. meridian/model/model_test_data.py +38 -0
  27. meridian/model/posterior_sampler.py +103 -62
  28. meridian/model/prior_sampler.py +114 -94
  29. meridian/model/spec.py +23 -14
  30. meridian/templates/card.html.jinja +9 -7
  31. meridian/templates/chart.html.jinja +1 -6
  32. meridian/templates/finding.html.jinja +19 -0
  33. meridian/templates/findings.html.jinja +33 -0
  34. meridian/templates/formatter.py +41 -5
  35. meridian/templates/formatter_test.py +127 -0
  36. meridian/templates/style.css +66 -9
  37. meridian/templates/style.scss +85 -4
  38. meridian/templates/table.html.jinja +1 -0
  39. meridian/version.py +1 -1
  40. scenarioplanner/linkingapi/constants.py +1 -1
  41. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  42. schema/processors/marketing_processor.py +11 -10
  43. schema/processors/model_processor.py +4 -1
  44. schema/serde/distribution.py +12 -7
  45. schema/serde/hyperparameters.py +54 -107
  46. schema/serde/meridian_serde.py +12 -3
  47. schema/utils/__init__.py +1 -0
  48. schema/utils/proto_enum_converter.py +127 -0
  49. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
  50. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
@@ -15,7 +15,7 @@
15
15
  """Implementation of the Model Quality Checks."""
16
16
 
17
17
  import abc
18
- from collections.abc import Sequence
18
+ from collections.abc import MutableMapping, Sequence
19
19
  import dataclasses
20
20
  from typing import Generic, TypeVar
21
21
  import warnings
@@ -77,31 +77,16 @@ class ConvergenceCheck(
77
77
  if not valid_rhat_items:
78
78
  return results.ConvergenceCheckResult(
79
79
  case=results.ConvergenceCases.CONVERGED,
80
- details={
81
- review_constants.RHAT: np.nan,
82
- review_constants.PARAMETER: np.nan,
83
- review_constants.CONVERGENCE_THRESHOLD: (
84
- self._config.convergence_threshold
85
- ),
86
- },
80
+ config=self._config,
81
+ max_rhat=np.nan,
82
+ max_parameter=np.nan,
87
83
  )
88
84
 
89
85
  max_parameter, max_rhat = max(max_rhats.items(), key=lambda item: item[1])
90
86
 
91
- details = {
92
- review_constants.RHAT: max_rhat,
93
- review_constants.PARAMETER: max_parameter,
94
- review_constants.CONVERGENCE_THRESHOLD: (
95
- self._config.convergence_threshold
96
- ),
97
- }
98
-
99
87
  # Case 1: Converged.
100
88
  if max_rhat < self._config.convergence_threshold:
101
- return results.ConvergenceCheckResult(
102
- case=results.ConvergenceCases.CONVERGED,
103
- details=details,
104
- )
89
+ case = results.ConvergenceCases.CONVERGED
105
90
 
106
91
  # Case 2: Not fully converged, but potentially acceptable.
107
92
  elif (
@@ -109,17 +94,18 @@ class ConvergenceCheck(
109
94
  <= max_rhat
110
95
  < self._config.not_fully_convergence_threshold
111
96
  ):
112
- return results.ConvergenceCheckResult(
113
- case=results.ConvergenceCases.NOT_FULLY_CONVERGED,
114
- details=details,
115
- )
97
+ case = results.ConvergenceCases.NOT_FULLY_CONVERGED
116
98
 
117
99
  # Case 3: Not converged and unacceptable.
118
100
  else: # max_rhat >= divergence_threshold
119
- return results.ConvergenceCheckResult(
120
- case=results.ConvergenceCases.NOT_CONVERGED,
121
- details=details,
122
- )
101
+ case = results.ConvergenceCases.NOT_CONVERGED
102
+
103
+ return results.ConvergenceCheckResult(
104
+ case=case,
105
+ config=self._config,
106
+ max_rhat=max_rhat,
107
+ max_parameter=max_parameter,
108
+ )
123
109
 
124
110
 
125
111
  # ==============================================================================
@@ -131,33 +117,25 @@ class BaselineCheck(
131
117
  """Checks for negative baseline probability."""
132
118
 
133
119
  def run(self) -> results.BaselineCheckResult:
134
- prob = self._analyzer.negative_baseline_probability()
135
- details = {
136
- review_constants.NEGATIVE_BASELINE_PROB: prob,
137
- review_constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: (
138
- self._config.negative_baseline_prob_fail_threshold
139
- ),
140
- review_constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
141
- self._config.negative_baseline_prob_review_threshold
142
- ),
143
- }
120
+ prob = float(self._analyzer.negative_baseline_probability())
121
+
144
122
  # Case 1: FAIL
145
123
  if prob > self._config.negative_baseline_prob_fail_threshold:
146
- return results.BaselineCheckResult(
147
- case=results.BaselineCases.FAIL,
148
- details=details,
149
- )
124
+ case = results.BaselineCases.FAIL
125
+
150
126
  # Case 2: REVIEW
151
127
  elif prob >= self._config.negative_baseline_prob_review_threshold:
152
- return results.BaselineCheckResult(
153
- case=results.BaselineCases.REVIEW,
154
- details=details,
155
- )
128
+ case = results.BaselineCases.REVIEW
129
+
156
130
  # Case 3: PASS
157
131
  else:
158
- return results.BaselineCheckResult(
159
- case=results.BaselineCases.PASS, details=details
160
- )
132
+ case = results.BaselineCases.PASS
133
+
134
+ return results.BaselineCheckResult(
135
+ case=case,
136
+ config=self._config,
137
+ negative_baseline_prob=prob,
138
+ )
161
139
 
162
140
 
163
141
  # ==============================================================================
@@ -189,45 +167,40 @@ class BayesianPPPCheck(
189
167
  >= np.abs(total_outcome_actual - total_outcome_expected_mean)
190
168
  )
191
169
 
192
- details = {
193
- review_constants.BAYESIAN_PPP: bayesian_ppp,
194
- }
195
-
196
170
  if bayesian_ppp >= self._config.ppp_threshold:
197
- return results.BayesianPPPCheckResult(
198
- case=results.BayesianPPPCases.PASS,
199
- details=details,
200
- )
171
+ case = results.BayesianPPPCases.PASS
201
172
  else:
202
- return results.BayesianPPPCheckResult(
203
- case=results.BayesianPPPCases.FAIL,
204
- details=details,
205
- )
173
+ case = results.BayesianPPPCases.FAIL
174
+
175
+ return results.BayesianPPPCheckResult(
176
+ case=case,
177
+ config=self._config,
178
+ bayesian_ppp=bayesian_ppp,
179
+ )
206
180
 
207
181
 
208
182
  # ==============================================================================
209
183
  # Check: Goodness of Fit
210
184
  # ==============================================================================
211
- def _set_details_from_gof_dataframe(
212
- details: dict[str, float],
185
+ def _set_metrics_from_gof_dataframe(
186
+ metrics: MutableMapping[str, float],
213
187
  gof_df: pd.DataFrame,
214
188
  geo_granularity: str,
215
- suffix: str | None = None,
189
+ suffix: str,
216
190
  ) -> None:
217
- """Sets the `details` variable of the GoodnessOfFitCheckResult.
191
+ """Sets the `metrics` variable of the GoodnessOfFitCheckResult.
218
192
 
219
193
  This method takes a DataFrame containing goodness of fit metrics and pivots it
220
- to a Series, which is then added to the `details` variable of the
194
+ to a Series, which is then added to the `metrics` variable of the
221
195
  `GoodnessOfFitCheckResult`.
222
196
 
223
197
  Args:
224
- details: A dictionary to store the goodness of fit metrics in.
198
+ metrics: A dictionary to store the goodness of fit metrics in.
225
199
  gof_df: A DataFrame containing predictive accuracy of the whole data (if
226
200
  holdout set is not used) of filtered to a single evaluation set ("all",
227
201
  "train", or "test").
228
202
  geo_granularity: The geo granularity of the data ("geo" or "national").
229
- suffix: A suffix to add to the metric names (e.g., "all", "train", "test").
230
- If None, the metrics are added without a suffix.
203
+ suffix: A suffix to add to the metric names (e.g., "_train", "_test").
231
204
  """
232
205
  gof_metrics_pivoted = gof_df.pivot(
233
206
  index=constants.GEO_GRANULARITY,
@@ -235,22 +208,15 @@ def _set_details_from_gof_dataframe(
235
208
  values=constants.VALUE,
236
209
  )
237
210
  gof_metrics_series = gof_metrics_pivoted.loc[geo_granularity]
238
- if suffix is not None:
239
- details[f"{review_constants.R_SQUARED}_{suffix}"] = gof_metrics_series[
240
- constants.R_SQUARED
241
- ]
242
- details[f"{review_constants.MAPE}_{suffix}"] = gof_metrics_series[
243
- constants.MAPE
244
- ]
245
- details[f"{review_constants.WMAPE}_{suffix}"] = gof_metrics_series[
246
- constants.WMAPE
247
- ]
248
- else:
249
- details[review_constants.R_SQUARED] = gof_metrics_series[
250
- constants.R_SQUARED
251
- ]
252
- details[review_constants.MAPE] = gof_metrics_series[constants.MAPE]
253
- details[review_constants.WMAPE] = gof_metrics_series[constants.WMAPE]
211
+ metrics[f"{review_constants.R_SQUARED}{suffix}"] = gof_metrics_series[
212
+ constants.R_SQUARED
213
+ ]
214
+ metrics[f"{review_constants.MAPE}{suffix}"] = gof_metrics_series[
215
+ constants.MAPE
216
+ ]
217
+ metrics[f"{review_constants.WMAPE}{suffix}"] = gof_metrics_series[
218
+ constants.WMAPE
219
+ ]
254
220
 
255
221
 
256
222
  class GoodnessOfFitCheck(
@@ -269,7 +235,7 @@ class GoodnessOfFitCheck(
269
235
  gof_metrics = gof_df[gof_df[constants.GEO_GRANULARITY] == geo_granularity]
270
236
  is_holdout = constants.EVALUATION_SET_VAR in gof_df.columns
271
237
 
272
- details = {}
238
+ metrics_dict = {}
273
239
  case = results.GoodnessOfFitCases.PASS
274
240
 
275
241
  if is_holdout:
@@ -281,29 +247,71 @@ class GoodnessOfFitCheck(
281
247
  set_metrics = gof_metrics[
282
248
  gof_metrics[constants.EVALUATION_SET_VAR] == evaluation_set
283
249
  ]
284
- _set_details_from_gof_dataframe(
285
- details=details,
250
+ _set_metrics_from_gof_dataframe(
251
+ metrics=metrics_dict,
286
252
  gof_df=set_metrics,
287
253
  geo_granularity=geo_granularity,
288
254
  suffix=suffix,
289
255
  )
290
- if details[f"{review_constants.R_SQUARED}_{suffix}"] <= 0:
256
+ if metrics_dict[f"{review_constants.R_SQUARED}{suffix}"] <= 0:
291
257
  case = results.GoodnessOfFitCases.REVIEW
258
+ return results.GoodnessOfFitCheckResult(
259
+ case=case,
260
+ metrics=results.GoodnessOfFitMetrics(
261
+ r_squared=metrics_dict[
262
+ f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
263
+ ],
264
+ mape=metrics_dict[
265
+ f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
266
+ ],
267
+ wmape=metrics_dict[
268
+ f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
269
+ ],
270
+ r_squared_train=metrics_dict[
271
+ f"{review_constants.R_SQUARED}{review_constants.TRAIN_SUFFIX}"
272
+ ],
273
+ mape_train=metrics_dict[
274
+ f"{review_constants.MAPE}{review_constants.TRAIN_SUFFIX}"
275
+ ],
276
+ wmape_train=metrics_dict[
277
+ f"{review_constants.WMAPE}{review_constants.TRAIN_SUFFIX}"
278
+ ],
279
+ r_squared_test=metrics_dict[
280
+ f"{review_constants.R_SQUARED}{review_constants.TEST_SUFFIX}"
281
+ ],
282
+ mape_test=metrics_dict[
283
+ f"{review_constants.MAPE}{review_constants.TEST_SUFFIX}"
284
+ ],
285
+ wmape_test=metrics_dict[
286
+ f"{review_constants.WMAPE}{review_constants.TEST_SUFFIX}"
287
+ ],
288
+ ),
289
+ is_holdout=is_holdout,
290
+ )
292
291
  else:
293
- _set_details_from_gof_dataframe(
294
- details=details,
292
+ _set_metrics_from_gof_dataframe(
293
+ metrics=metrics_dict,
295
294
  gof_df=gof_metrics,
296
295
  geo_granularity=geo_granularity,
297
- suffix=None,
296
+ suffix=review_constants.ALL_SUFFIX,
298
297
  )
299
- if details[review_constants.R_SQUARED] <= 0:
298
+ if metrics_dict[review_constants.R_SQUARED] <= 0:
300
299
  case = results.GoodnessOfFitCases.REVIEW
301
-
302
- return results.GoodnessOfFitCheckResult(
303
- case=case,
304
- details=details,
305
- is_holdout=is_holdout,
306
- )
300
+ return results.GoodnessOfFitCheckResult(
301
+ case=case,
302
+ metrics=results.GoodnessOfFitMetrics(
303
+ r_squared=metrics_dict[
304
+ f"{review_constants.R_SQUARED}{review_constants.ALL_SUFFIX}"
305
+ ],
306
+ mape=metrics_dict[
307
+ f"{review_constants.MAPE}{review_constants.ALL_SUFFIX}"
308
+ ],
309
+ wmape=metrics_dict[
310
+ f"{review_constants.WMAPE}{review_constants.ALL_SUFFIX}"
311
+ ],
312
+ ),
313
+ is_holdout=is_holdout,
314
+ )
307
315
 
308
316
 
309
317
  # ==============================================================================
@@ -475,8 +483,10 @@ def _compute_channel_results(
475
483
  channel_results.append(
476
484
  results.ROIConsistencyChannelResult(
477
485
  case=case,
478
- details={},
479
486
  channel_name=channel,
487
+ prior_roi_lo=np.nan,
488
+ prior_roi_hi=np.nan,
489
+ posterior_roi_mean=np.nan,
480
490
  )
481
491
  )
482
492
  for i, channel in enumerate(channel_data.all_channels):
@@ -491,14 +501,10 @@ def _compute_channel_results(
491
501
  channel_results.append(
492
502
  results.ROIConsistencyChannelResult(
493
503
  case=case,
494
- details={
495
- review_constants.PRIOR_ROI_LO: channel_data.prior_roi_los[i],
496
- review_constants.PRIOR_ROI_HI: channel_data.prior_roi_his[i],
497
- review_constants.POSTERIOR_ROI_MEAN: (
498
- channel_data.posterior_means[i]
499
- ),
500
- },
501
504
  channel_name=channel,
505
+ prior_roi_lo=channel_data.prior_roi_los[i],
506
+ prior_roi_hi=channel_data.prior_roi_his[i],
507
+ posterior_roi_mean=channel_data.posterior_means[i],
502
508
  )
503
509
  )
504
510
  return channel_results
@@ -558,7 +564,7 @@ def _compute_aggregate_result(
558
564
 
559
565
  return results.ROIConsistencyCheckResult(
560
566
  case=aggregate_case,
561
- details=aggregate_details,
567
+ aggregate_details=aggregate_details,
562
568
  channel_results=channel_results,
563
569
  )
564
570
 
@@ -734,7 +740,7 @@ class PriorPosteriorShiftCheck(
734
740
  no_shift_channels.append(channel_name)
735
741
  channel_results.append(
736
742
  results.PriorPosteriorShiftChannelResult(
737
- case=case, details={}, channel_name=channel_name
743
+ case=case, channel_name=channel_name
738
744
  )
739
745
  )
740
746
  return channel_results, no_shift_channels
@@ -752,17 +758,13 @@ class PriorPosteriorShiftCheck(
752
758
 
753
759
  if no_shift_channels:
754
760
  agg_case = results.PriorPosteriorShiftAggregateCases.REVIEW
755
- final_details = {
756
- "channels_str": ", ".join(
757
- f"`{channel}`" for channel in no_shift_channels
758
- )
759
- }
760
761
  else:
761
762
  agg_case = results.PriorPosteriorShiftAggregateCases.PASS
762
- final_details = {}
763
763
 
764
764
  return results.PriorPosteriorShiftCheckResult(
765
- case=agg_case, details=final_details, channel_results=channel_results
765
+ case=agg_case,
766
+ channel_results=channel_results,
767
+ no_shift_channels=no_shift_channels,
766
768
  )
767
769
 
768
770
  def run(self) -> results.PriorPosteriorShiftCheckResult:
@@ -32,9 +32,9 @@ NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD = (
32
32
  R_SQUARED = "r_squared"
33
33
  MAPE = "mape"
34
34
  WMAPE = "wmape"
35
- ALL_SUFFIX = "all"
36
- TRAIN_SUFFIX = "train"
37
- TEST_SUFFIX = "test"
35
+ ALL_SUFFIX = ""
36
+ TRAIN_SUFFIX = "_train"
37
+ TEST_SUFFIX = "_test"
38
38
  EVALUATION_SET_SUFFIXES = (ALL_SUFFIX, TRAIN_SUFFIX, TEST_SUFFIX)
39
39
  MEAN = "mean"
40
40
  VARIANCE = "variance"
@@ -14,9 +14,13 @@
14
14
 
15
15
  """Data structures for the Model Quality Checks results."""
16
16
 
17
+ import abc
18
+ from collections.abc import Mapping
17
19
  import dataclasses
18
20
  import enum
19
21
  from typing import Any
22
+
23
+ from meridian.analysis.review import configs
20
24
  from meridian.analysis.review import constants
21
25
 
22
26
 
@@ -58,11 +62,16 @@ class ModelCheckCase(BaseCase):
58
62
 
59
63
 
60
64
  @dataclasses.dataclass(frozen=True)
61
- class BaseResultData:
65
+ class BaseResultData(abc.ABC):
62
66
  """Base class for check result data."""
63
67
 
64
68
  case: BaseCase
65
- details: dict[str, Any]
69
+
70
+ @property
71
+ @abc.abstractmethod
72
+ def details(self) -> Mapping[str, Any]:
73
+ """Returns the details for message formatting."""
74
+ raise NotImplementedError
66
75
 
67
76
 
68
77
  @dataclasses.dataclass(frozen=True)
@@ -145,17 +154,18 @@ class ConvergenceCheckResult(CheckResult):
145
154
  """The immutable result of the Convergence Check."""
146
155
 
147
156
  case: ConvergenceCases
157
+ config: configs.ConvergenceConfig
158
+ max_rhat: float
159
+ max_parameter: str
148
160
 
149
- def __post_init__(self):
150
- if self.case == ConvergenceCases.CONVERGED and (
151
- constants.CONVERGENCE_THRESHOLD not in self.details
152
- ):
153
- raise ValueError(
154
- "The message template 'The model has likely converged, as all"
155
- " parameters have R-hat values < {convergence_threshold}'. is"
156
- " missing required formatting arguments: convergence_threshold."
157
- f" Details: {self.details}."
158
- )
161
+ @property
162
+ def details(self) -> Mapping[str, Any]:
163
+ """The check result details."""
164
+ return {
165
+ constants.RHAT: self.max_rhat,
166
+ constants.PARAMETER: self.max_parameter,
167
+ constants.CONVERGENCE_THRESHOLD: self.config.convergence_threshold,
168
+ }
159
169
 
160
170
 
161
171
  # ==============================================================================
@@ -223,24 +233,21 @@ class BaselineCheckResult(CheckResult):
223
233
  """The immutable result of the Baseline Check."""
224
234
 
225
235
  case: BaselineCases
236
+ config: configs.BaselineConfig
237
+ negative_baseline_prob: float
226
238
 
227
- def __post_init__(self):
228
- if self.case is BaselineCases.PASS:
229
- return
230
- if any(
231
- key not in self.details
232
- for key in (
233
- constants.NEGATIVE_BASELINE_PROB,
234
- constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD,
235
- constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD,
236
- )
237
- ):
238
- raise ValueError(
239
- "The message template is missing required formatting arguments:"
240
- " negative_baseline_prob, negative_baseline_prob_fail_threshold,"
241
- " negative_baseline_prob_review_threshold. Details:"
242
- f" {self.details}."
243
- )
239
+ @property
240
+ def details(self) -> Mapping[str, Any]:
241
+ """The check result details."""
242
+ return {
243
+ constants.NEGATIVE_BASELINE_PROB: self.negative_baseline_prob,
244
+ constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: (
245
+ self.config.negative_baseline_prob_fail_threshold
246
+ ),
247
+ constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
248
+ self.config.negative_baseline_prob_review_threshold
249
+ ),
250
+ }
244
251
 
245
252
 
246
253
  # ==============================================================================
@@ -287,14 +294,15 @@ class BayesianPPPCheckResult(CheckResult):
287
294
  """The immutable result of the Bayesian Posterior Predictive P-value Check."""
288
295
 
289
296
  case: BayesianPPPCases
297
+ config: configs.BayesianPPPConfig
298
+ bayesian_ppp: float
290
299
 
291
- def __post_init__(self):
292
- if constants.BAYESIAN_PPP not in self.details:
293
- raise ValueError(
294
- "The message template is missing required formatting arguments:"
295
- " bayesian_ppp. Details:"
296
- f" {self.details}."
297
- )
300
+ @property
301
+ def details(self) -> Mapping[str, Any]:
302
+ """The check result details."""
303
+ return {
304
+ constants.BAYESIAN_PPP: self.bayesian_ppp,
305
+ }
298
306
 
299
307
 
300
308
  # ==============================================================================
@@ -337,55 +345,77 @@ class GoodnessOfFitCases(ModelCheckCase, enum.Enum):
337
345
  super().__init__(status, message_template, recommendation)
338
346
 
339
347
 
348
+ @dataclasses.dataclass(frozen=True)
349
+ class GoodnessOfFitMetrics:
350
+ """The metrics for the Goodness of Fit Check."""
351
+
352
+ r_squared: float
353
+ mape: float
354
+ wmape: float
355
+ r_squared_train: float | None = None
356
+ mape_train: float | None = None
357
+ wmape_train: float | None = None
358
+ r_squared_test: float | None = None
359
+ mape_test: float | None = None
360
+ wmape_test: float | None = None
361
+
362
+
340
363
  @dataclasses.dataclass(frozen=True)
341
364
  class GoodnessOfFitCheckResult(CheckResult):
342
365
  """The immutable result of the Goodness of Fit Check."""
343
366
 
344
367
  case: GoodnessOfFitCases
368
+ metrics: GoodnessOfFitMetrics
345
369
  is_holdout: bool = False
346
370
 
347
371
  def __post_init__(self):
348
372
  if self.is_holdout:
349
- required_keys = []
350
- for suffix in [
351
- constants.ALL_SUFFIX,
352
- constants.TRAIN_SUFFIX,
353
- constants.TEST_SUFFIX,
354
- ]:
355
- required_keys.extend([
356
- f"{constants.R_SQUARED}_{suffix}",
357
- f"{constants.MAPE}_{suffix}",
358
- f"{constants.WMAPE}_{suffix}",
359
- ])
360
- if any(key not in self.details for key in required_keys):
373
+ if any(
374
+ metric is None
375
+ for metric in (
376
+ self.metrics.r_squared_train,
377
+ self.metrics.mape_train,
378
+ self.metrics.wmape_train,
379
+ self.metrics.r_squared_test,
380
+ self.metrics.mape_test,
381
+ self.metrics.wmape_test,
382
+ )
383
+ ):
361
384
  raise ValueError(
362
385
  "The message template is missing required formatting arguments for"
363
- f" holdout case. Required keys: {required_keys}. Details:"
364
- f" {self.details}."
365
- )
366
- elif any(
367
- key not in self.details
368
- for key in (
369
- constants.R_SQUARED,
370
- constants.MAPE,
371
- constants.WMAPE,
386
+ " holdout case. Required keys: r_squared_train, mape_train,"
387
+ " wmape_train, r_squared_test, mape_test, wmape_test. Metrics:"
388
+ f" {self.metrics}."
372
389
  )
373
- ):
374
- raise ValueError(
375
- "The message template is missing required formatting arguments:"
376
- " r_squared, mape, wmape. Details:"
377
- f" {self.details}."
378
- )
390
+
391
+ @property
392
+ def details(self) -> Mapping[str, Any]:
393
+ """The check result details."""
394
+ return {
395
+ f"{constants.R_SQUARED}{constants.ALL_SUFFIX}": self.metrics.r_squared,
396
+ f"{constants.MAPE}{constants.ALL_SUFFIX}": self.metrics.mape,
397
+ f"{constants.WMAPE}{constants.ALL_SUFFIX}": self.metrics.wmape,
398
+ f"{constants.R_SQUARED}{constants.TRAIN_SUFFIX}": (
399
+ self.metrics.r_squared_train
400
+ ),
401
+ f"{constants.MAPE}{constants.TRAIN_SUFFIX}": self.metrics.mape_train,
402
+ f"{constants.WMAPE}{constants.TRAIN_SUFFIX}": self.metrics.wmape_train,
403
+ f"{constants.R_SQUARED}{constants.TEST_SUFFIX}": (
404
+ self.metrics.r_squared_test
405
+ ),
406
+ f"{constants.MAPE}{constants.TEST_SUFFIX}": self.metrics.mape_test,
407
+ f"{constants.WMAPE}{constants.TEST_SUFFIX}": self.metrics.wmape_test,
408
+ }
379
409
 
380
410
  @property
381
411
  def recommendation(self) -> str:
382
- """Returns the check result message."""
412
+ """The check result message."""
383
413
  if self.is_holdout:
384
414
  report_str = (
385
- "R-squared = {r_squared_all:.4f} (All),"
415
+ "R-squared = {r_squared:.4f} (All),"
386
416
  " {r_squared_train:.4f} (Train), {r_squared_test:.4f} (Test); MAPE"
387
- " = {mape_all:.4f} (All), {mape_train:.4f} (Train),"
388
- " {mape_test:.4f} (Test); wMAPE = {wmape_all:.4f} (All),"
417
+ " = {mape:.4f} (All), {mape_train:.4f} (Train),"
418
+ " {mape_test:.4f} (Test); wMAPE = {wmape:.4f} (All),"
389
419
  " {wmape_train:.4f} (Train), {wmape_test:.4f} (Test)".format(
390
420
  **self.details
391
421
  )
@@ -450,6 +480,18 @@ class ROIConsistencyChannelResult(ChannelResult):
450
480
  """The immutable result of ROI Consistency Check for a single channel."""
451
481
 
452
482
  case: ROIConsistencyChannelCases
483
+ prior_roi_lo: float
484
+ prior_roi_hi: float
485
+ posterior_roi_mean: float
486
+
487
+ @property
488
+ def details(self) -> Mapping[str, Any]:
489
+ """Returns the check result details."""
490
+ return {
491
+ constants.PRIOR_ROI_LO: self.prior_roi_lo,
492
+ constants.PRIOR_ROI_HI: self.prior_roi_hi,
493
+ constants.POSTERIOR_ROI_MEAN: self.posterior_roi_mean,
494
+ }
453
495
 
454
496
 
455
497
  @dataclasses.dataclass(frozen=True)
@@ -458,6 +500,12 @@ class ROIConsistencyCheckResult(CheckResult):
458
500
 
459
501
  case: ROIConsistencyAggregateCases
460
502
  channel_results: list[ROIConsistencyChannelResult]
503
+ aggregate_details: Mapping[str, Any]
504
+
505
+ @property
506
+ def details(self) -> Mapping[str, Any]:
507
+ """Returns the check result details."""
508
+ return self.aggregate_details
461
509
 
462
510
 
463
511
  # ==============================================================================
@@ -517,6 +565,11 @@ class PriorPosteriorShiftChannelResult(ChannelResult):
517
565
 
518
566
  case: PriorPosteriorShiftChannelCases
519
567
 
568
+ @property
569
+ def details(self) -> Mapping[str, Any]:
570
+ """Returns the check result details."""
571
+ return {}
572
+
520
573
 
521
574
  @dataclasses.dataclass(frozen=True)
522
575
  class PriorPosteriorShiftCheckResult(CheckResult):
@@ -524,6 +577,16 @@ class PriorPosteriorShiftCheckResult(CheckResult):
524
577
 
525
578
  case: PriorPosteriorShiftAggregateCases
526
579
  channel_results: list[PriorPosteriorShiftChannelResult]
580
+ no_shift_channels: list[str]
581
+
582
+ @property
583
+ def details(self) -> Mapping[str, Any]:
584
+ """Returns the check result details."""
585
+ return {
586
+ "channels_str": ", ".join(
587
+ f"`{channel}`" for channel in self.no_shift_channels
588
+ )
589
+ }
527
590
 
528
591
 
529
592
  # ==============================================================================
@@ -567,7 +630,7 @@ class ReviewSummary:
567
630
  return "\n".join(report)
568
631
 
569
632
  @property
570
- def checks_status(self) -> dict[str, str]:
633
+ def checks_status(self) -> Mapping[str, str]:
571
634
  """Returns a dictionary of check names and statuses."""
572
635
  return {
573
636
  result.__class__.__name__: result.case.status.name