google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,110 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Configurations for the Model Quality Checks."""
16
+
17
+ import dataclasses
18
+
19
+
20
@dataclasses.dataclass(frozen=True)
class BaseConfig:
  """Common parent of every check configuration.

  Declares no fields of its own; the concrete configuration dataclasses derive
  from it.
  """
23
+
24
+
25
@dataclasses.dataclass(frozen=True)
class ConvergenceConfig(BaseConfig):
  """Settings for the Convergence Check.

  Both fields are cut-offs on the R-hat statistic; together they split it into
  three bands.

  Attributes:
    convergence_threshold: R-hat values below this cut-off are considered
      converged.
    not_fully_convergence_threshold: R-hat values between
      `convergence_threshold` and this cut-off are considered not fully
      converged (but potentially acceptable); values above it are considered
      not converged.
  """

  convergence_threshold: float = 1.2
  not_fully_convergence_threshold: float = 10.0
41
+
42
+
43
@dataclasses.dataclass(frozen=True)
class ROIConsistencyConfig(BaseConfig):
  """Settings for the ROI Consistency Check.

  The check verifies whether the posterior median of the ROI lies inside a
  reasonable range of the prior distribution. The two quantiles below define
  that range.

  Attributes:
    prior_lower_quantile: Quantile of the ROI prior distribution that defines
      the lower bound of the reasonable range.
    prior_upper_quantile: Quantile of the ROI prior distribution that defines
      the upper bound of the reasonable range.
  """

  prior_lower_quantile: float = 0.01
  prior_upper_quantile: float = 0.99
59
+
60
+
61
@dataclasses.dataclass(frozen=True)
class BaselineConfig(BaseConfig):
  """Settings for the Baseline Check.

  The check warns when a negative baseline is highly probable.

  Attributes:
    negative_baseline_prob_review_threshold: A review is issued when the
      probability of a negative baseline is above this value.
    negative_baseline_prob_fail_threshold: The check fails when the probability
      of a negative baseline is above this value.
  """

  negative_baseline_prob_review_threshold: float = 0.2
  negative_baseline_prob_fail_threshold: float = 0.8
78
+
79
+
80
@dataclasses.dataclass(frozen=True)
class BayesianPPPConfig(BaseConfig):
  """Settings for the Bayesian Posterior Predictive P-value Check.

  Attributes:
    ppp_threshold: P-value threshold applied by the posterior predictive check.
  """

  ppp_threshold: float = 0.05
89
+
90
+
91
@dataclasses.dataclass(frozen=True)
class GoodnessOfFitConfig(BaseConfig):
  """Field-less config for the Goodness of Fit Check."""
94
+
95
+
96
@dataclasses.dataclass(frozen=True)
class PriorPosteriorShiftConfig(BaseConfig):
  """Settings for the Prior-Posterior Shift Check.

  Attributes:
    n_bootstraps: How many bootstrap samples are used when calculating
      posterior statistics.
    alpha: Significance level used to detect a shift between the prior and
      posterior distributions.
    seed: Random seed that makes the bootstrap sampling reproducible.
  """

  n_bootstraps: int = 1000
  alpha: float = 0.05
  seed: int = 42
@@ -0,0 +1,40 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants for model review."""
16
+
17
# R-hat / convergence keys.
RHAT = "rhat"
PARAMETER = "parameter"
CONVERGENCE_THRESHOLD = "convergence_threshold"

# ROI-related keys.
CHANNELS_LOW_HIGH = "channels_low_high"
PRIOR_ROI_LO = "prior_roi_lo"
PRIOR_ROI_HI = "prior_roi_hi"
POSTERIOR_ROI_MEAN = "posterior_roi_mean"

# Message-fragment keys.
QUANTILE_NOT_DEFINED_MSG = "quantile_not_defined_msg"
INF_CHANNELS_MSG = "inf_channels_msg"
LOW_HIGH_CHANNELS_MSG = "low_high_channels_msg"

# Negative-baseline keys.
NEGATIVE_BASELINE_PROB = "negative_baseline_prob"
NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD = "negative_baseline_prob_fail_threshold"
NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD = (
    "negative_baseline_prob_review_threshold"
)

# Goodness-of-fit metric keys.
R_SQUARED = "r_squared"
MAPE = "mape"
WMAPE = "wmape"

# Summary-statistic keys.
MEAN = "mean"
VARIANCE = "variance"
MEDIAN = "median"
Q1 = "q1"
Q3 = "q3"

# Bayesian posterior predictive p-value key.
BAYESIAN_PPP = "bayesian_ppp"
@@ -0,0 +1,544 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Data structures for the Model Quality Checks results."""
16
+
17
+ import dataclasses
18
+ import enum
19
+ from typing import Any
20
+ from meridian.analysis.review import constants
21
+
22
+
23
+ # ==============================================================================
24
+ # Base classes
25
+ # ==============================================================================
26
@enum.unique
class Status(enum.Enum):
  """Outcome levels a check can report."""

  PASS = enum.auto()
  REVIEW = enum.auto()
  FAIL = enum.auto()
31
+
32
+
33
class BaseCase:
  """Root type of every check case; it carries only a status."""

  status: Status

  def __init__(self, status: Status):
    """Stores `status` on the new case.

    Args:
      status: The status associated with this case.
    """
    self.status = status
41
+
42
+
43
class ModelCheckCase(BaseCase):
  """Root type of every model-level check case.

  On top of the status it holds a message template and an optional
  recommendation text.
  """

  message_template: str
  recommendation: str | None = None

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None = None,
  ):
    """Initializes the case.

    Args:
      status: The status associated with this case.
      message_template: Message text stored as `message_template`.
      recommendation: Optional recommendation text; defaults to `None`.
    """
    super().__init__(status)
    self.recommendation = recommendation
    self.message_template = message_template
58
+
59
+
60
@dataclasses.dataclass(frozen=True)
class BaseResultData:
  """Immutable pairing of a check case with its detail values.

  Attributes:
    case: The case this result corresponds to.
    details: Arbitrary values keyed by name.
  """

  case: BaseCase
  details: dict[str, Any]
66
+
67
+
68
@dataclasses.dataclass(frozen=True)
class ChannelResult(BaseResultData):
  """Root type of channel-level check results.

  Attributes:
    channel_name: Name of the channel the result belongs to.
  """

  channel_name: str
73
+
74
+
75
@dataclasses.dataclass(frozen=True)
class CheckResult(BaseResultData):
  """Root type of model-level check results.

  Attributes:
    case: The model-level case this result corresponds to.
  """

  case: ModelCheckCase

  @property
  def recommendation(self) -> str:
    """Returns the formatted message, followed by the recommendation if any.

    The case's message template is formatted with `details`; when the case has
    a non-empty recommendation it is appended after a single space.
    """
    message = self.case.message_template.format(**self.details)
    advice = self.case.recommendation
    return f"{message} {advice}" if advice else message
88
+
89
+
90
+ # ==============================================================================
91
+ # Check: Convergence
92
+ # ==============================================================================
93
# Recommendation attached to the NOT_FULLY_CONVERGED convergence case.
NOT_FULLY_CONVERGED_RECOMMENDATION = (
    "Manually inspect the parameters with high R-hat values to determine if "
    "the results are acceptable for your use case, and consider increasing "
    "MCMC iterations or investigating model misspecification."
)

# Recommendation attached to the NOT_CONVERGED convergence case.
NOT_CONVERGED_RECOMMENDATION = (
    "We recommend increasing MCMC iterations or investigating model "
    "misspecification (e.g., priors, multicollinearity) before proceeding."
)
103
+
104
+
105
@enum.unique
class ConvergenceCases(ModelCheckCase, enum.Enum):
  """Possible outcomes of the Convergence Check.

  Each member is declared as a `(status, message_template, recommendation)`
  tuple whose items are handed to `__init__`.
  """

  # Template placeholder: `convergence_threshold`.
  CONVERGED = (
      Status.PASS,
      "The model has likely converged, as all parameters have R-hat values "
      "< {convergence_threshold}.",
      None,
  )
  # Template placeholders: `parameter`, `rhat`.
  NOT_FULLY_CONVERGED = (
      Status.FAIL,
      "The model hasn't fully converged, and the `max_r_hat` for parameter "
      "`{parameter}` is {rhat:.2f}.",
      NOT_FULLY_CONVERGED_RECOMMENDATION,
  )
  # Template placeholders: `parameter`, `rhat`.
  NOT_CONVERGED = (
      Status.FAIL,
      "The model hasn't converged, and the `max_r_hat` for parameter "
      "`{parameter}` is {rhat:.2f}.",
      NOT_CONVERGED_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    """Forwards the member's tuple items to `ModelCheckCase.__init__`."""
    super().__init__(status, message_template, recommendation)
141
+
142
+
143
@dataclasses.dataclass(frozen=True)
class ConvergenceCheckResult(CheckResult):
  """Immutable outcome of the Convergence Check.

  Attributes:
    case: The convergence case this result corresponds to.
  """

  case: ConvergenceCases

  def __post_init__(self):
    """Validates `details` for the CONVERGED case.

    Raises:
      ValueError: If the case is CONVERGED and `details` has no
        `convergence_threshold` entry.
    """
    # Only the CONVERGED case is validated here.
    if self.case != ConvergenceCases.CONVERGED:
      return
    if constants.CONVERGENCE_THRESHOLD in self.details:
      return
    # The leading pieces are plain (non-f) strings on purpose: the braces
    # around `convergence_threshold` must appear literally in the message.
    raise ValueError(
        "The message template 'The model has likely converged, as all"
        " parameters have R-hat values < {convergence_threshold}'. is"
        " missing required formatting arguments: convergence_threshold."
        f" Details: {self.details}."
    )
159
+
160
+
161
+ # ==============================================================================
162
+ # Check: Baseline
163
+ # ==============================================================================
164
# Recommendation texts used by the Baseline Check cases, one per status.
_BASELINE_FAIL_RECOMMENDATION = (
    "This high probability points to a statistical error and is a clear "
    "signal that the model requires adjustment. The model is likely "
    "over-crediting your treatments. Consider adjusting the model's "
    "settings, data, or priors to correct this issue."
)
_BASELINE_REVIEW_RECOMMENDATION = (
    "This indicates that the baseline time series occasionally dips into "
    "negative values. We recommend visually inspecting the baseline time "
    "series in the Model Fit charts, but don't be overly concerned. An "
    "occasional, small dip may indicate minor statistical error, which is "
    "inherent in any model."
)
_BASELINE_PASS_RECOMMENDATION = (
    "We recommend visually inspecting the baseline time series in the "
    "Model Fit charts to confirm this."
)
181
+
182
+
183
@enum.unique
class BaselineCases(ModelCheckCase, enum.Enum):
  """Possible outcomes of the Baseline Check.

  All three members share the same message template (placeholder:
  `negative_baseline_prob`) and differ in status and recommendation.
  """

  PASS = (
      Status.PASS,
      "The posterior probability that the baseline is negative is "
      "{negative_baseline_prob:.2f}.",
      _BASELINE_PASS_RECOMMENDATION,
  )
  REVIEW = (
      Status.REVIEW,
      "The posterior probability that the baseline is negative is "
      "{negative_baseline_prob:.2f}.",
      _BASELINE_REVIEW_RECOMMENDATION,
  )
  FAIL = (
      Status.FAIL,
      "The posterior probability that the baseline is negative is "
      "{negative_baseline_prob:.2f}.",
      _BASELINE_FAIL_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    """Forwards the member's tuple items to `ModelCheckCase.__init__`."""
    super().__init__(status, message_template, recommendation)
219
+
220
+
221
@dataclasses.dataclass(frozen=True)
class BaselineCheckResult(CheckResult):
  """Immutable outcome of the Baseline Check.

  Attributes:
    case: The baseline case this result corresponds to.
  """

  case: BaselineCases

  def __post_init__(self):
    """Validates `details` for every case except PASS.

    Raises:
      ValueError: If the case is not PASS and `details` lacks any of
        `negative_baseline_prob`, `negative_baseline_prob_fail_threshold` or
        `negative_baseline_prob_review_threshold`.
    """
    # NOTE(review): PASS is exempt from validation although its message
    # template also formats `negative_baseline_prob` -- confirm this is
    # intended.
    if self.case is BaselineCases.PASS:
      return
    required_keys = (
        constants.NEGATIVE_BASELINE_PROB,
        constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD,
        constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD,
    )
    if all(key in self.details for key in required_keys):
      return
    raise ValueError(
        "The message template is missing required formatting arguments:"
        " negative_baseline_prob, negative_baseline_prob_fail_threshold,"
        " negative_baseline_prob_review_threshold. Details:"
        f" {self.details}."
    )
244
+
245
+
246
+ # ==============================================================================
247
+ # Check: Bayesian Posterior Predictive P-value
248
+ # ==============================================================================
249
# Recommendation texts used by the Bayesian PPP Check cases.
_BAYESIAN_PPP_FAIL_RECOMMENDATION = (
    "The observed total outcome is an extreme outlier compared to the "
    "model's expected total outcomes, which suggests a systematic lack of "
    "fit. We recommend reviewing input data quality and re-examining the "
    "model specification (e.g., priors, transformations) to resolve this "
    "issue."
)
_BAYESIAN_PPP_PASS_RECOMMENDATION = (
    "The observed total outcome is consistent with the model's posterior "
    "predictive distribution."
)
259
+
260
+
261
@enum.unique
class BayesianPPPCases(ModelCheckCase, enum.Enum):
  """Possible outcomes of the Bayesian Posterior Predictive P-value Check.

  Both members share one message template (placeholder: `bayesian_ppp`).
  """

  PASS = (
      Status.PASS,
      "The Bayesian posterior predictive p-value is {bayesian_ppp:.2f}.",
      _BAYESIAN_PPP_PASS_RECOMMENDATION,
  )
  FAIL = (
      Status.FAIL,
      "The Bayesian posterior predictive p-value is {bayesian_ppp:.2f}.",
      _BAYESIAN_PPP_FAIL_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    """Forwards the member's tuple items to `ModelCheckCase.__init__`."""
    super().__init__(status, message_template, recommendation)
283
+
284
+
285
@dataclasses.dataclass(frozen=True)
class BayesianPPPCheckResult(CheckResult):
  """Immutable outcome of the Bayesian Posterior Predictive P-value Check.

  Attributes:
    case: The Bayesian PPP case this result corresponds to.
  """

  case: BayesianPPPCases

  def __post_init__(self):
    """Validates that `details` can fill the message template.

    Raises:
      ValueError: If `details` has no `bayesian_ppp` entry.
    """
    if constants.BAYESIAN_PPP in self.details:
      return
    raise ValueError(
        "The message template is missing required formatting arguments:"
        " bayesian_ppp. Details:"
        f" {self.details}."
    )
298
+
299
+
300
+ # ==============================================================================
301
+ # Check: Goodness of Fit
302
+ # ==============================================================================
303
# Recommendation texts used by the Goodness of Fit Check cases.
_GOODNESS_OF_FIT_REVIEW_RECOMMENDATION = (
    "A negative R-squared signals a potential conflict between your priors "
    "and the data, and it warrants investigation. If this conflict is "
    "intentional (due to an informative prior), no further action is "
    "needed. If it's unintentional, we recommend relaxing your priors to be "
    "less restrictive."
)

_GOODNESS_OF_FIT_PASS_RECOMMENDATION = (
    "These goodness-of-fit metrics are intended for guidance and relative "
    "comparison."
)
314
+
315
+
316
@enum.unique
class GoodnessOfFitCases(ModelCheckCase, enum.Enum):
  """Possible outcomes of the Goodness of Fit Check.

  Both members share one message template (placeholders: `r_squared`, `mape`,
  `wmape`).
  """

  PASS = (
      Status.PASS,
      "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, and wMAPE = "
      "{wmape:.4f}.",
      _GOODNESS_OF_FIT_PASS_RECOMMENDATION,
  )
  REVIEW = (
      Status.REVIEW,
      "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, and wMAPE = "
      "{wmape:.4f}.",
      _GOODNESS_OF_FIT_REVIEW_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    """Forwards the member's tuple items to `ModelCheckCase.__init__`."""
    super().__init__(status, message_template, recommendation)
344
+
345
+
346
@dataclasses.dataclass(frozen=True)
class GoodnessOfFitCheckResult(CheckResult):
  """Immutable outcome of the Goodness of Fit Check.

  Attributes:
    case: The goodness-of-fit case this result corresponds to.
  """

  case: GoodnessOfFitCases

  def __post_init__(self):
    """Validates that `details` can fill the message template.

    Raises:
      ValueError: If `details` lacks any of `r_squared`, `mape` or `wmape`.
    """
    required_keys = (constants.R_SQUARED, constants.MAPE, constants.WMAPE)
    if all(key in self.details for key in required_keys):
      return
    raise ValueError(
        "The message template is missing required formatting arguments:"
        " r_squared, mape, wmape. Details:"
        f" {self.details}."
    )
366
+
367
+
368
+ # ==============================================================================
369
+ # Check: ROI Consistency
370
+ # ==============================================================================
371
# Recommendation text attached to the ROI Consistency REVIEW case.
_ROI_CONSISTENCY_RECOMMENDATION = (
    "Please review this result to determine if it is reasonable within "
    "your business context."
)
375
+
376
+
377
@enum.unique
class ROIConsistencyChannelCases(BaseCase, enum.Enum):
  """Per-channel outcomes of the ROI Consistency Check.

  Members are declared as `(status, unique_id)` tuples. Several members share
  a status, so the second item is presumably there only to keep their values
  distinct under `@enum.unique`.
  """

  ROI_PASS = (Status.PASS, enum.auto())
  ROI_LOW = (Status.REVIEW, enum.auto())
  ROI_HIGH = (Status.REVIEW, enum.auto())
  PRIOR_ROI_QUANTILE_INF = (Status.REVIEW, enum.auto())
  QUANTILE_NOT_DEFINED = (Status.REVIEW, enum.auto())

  def __init__(self, status: Status, unique_id: Any):
    """Stores the status; `unique_id` is accepted but not used."""
    super().__init__(status)
389
+
390
+
391
@enum.unique
class ROIConsistencyAggregateCases(ModelCheckCase, enum.Enum):
  """Model-level (aggregate) outcomes of the ROI Consistency Check.

  Each member is declared as a `(status, message_template, recommendation)`
  tuple whose items are handed to `__init__`. The class is decorated with
  `@enum.unique` like the other case enums in this module, so an accidental
  duplicate member value fails at import time instead of silently becoming an
  alias.
  """

  PASS = (
      Status.PASS,
      (
          "The posterior distribution of the ROI is within a reasonable range,"
          " aligning with the custom priors you provided."
      ),
      None,
  )
  # The REVIEW message is assembled entirely from three message fragments
  # supplied through the result's `details`.
  REVIEW = (
      Status.REVIEW,
      "{quantile_not_defined_msg}{inf_channels_msg}{low_high_channels_msg}",
      _ROI_CONSISTENCY_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    """Forwards the member's tuple items to `ModelCheckCase.__init__`."""
    super().__init__(status, message_template, recommendation)
415
+
416
+
417
@dataclasses.dataclass(frozen=True)
class ROIConsistencyChannelResult(ChannelResult):
  """Immutable ROI Consistency Check outcome for one channel.

  Attributes:
    case: The per-channel ROI consistency case.
  """

  case: ROIConsistencyChannelCases
422
+
423
+
424
@dataclasses.dataclass(frozen=True)
class ROIConsistencyCheckResult(CheckResult):
  """Immutable model-level outcome of the ROI Consistency Check.

  Attributes:
    case: The aggregate ROI consistency case.
    channel_results: The per-channel results.
  """

  case: ROIConsistencyAggregateCases
  channel_results: list[ROIConsistencyChannelResult]
430
+
431
+
432
+ # ==============================================================================
433
+ # Check: Prior-Posterior Shift
434
+ # ==============================================================================
435
# Recommendation text attached to the Prior-Posterior Shift REVIEW case.
# Wording fix: "due to a strong priors" -> "due to strong priors".
_PPS_REVIEW_RECOMMENDATION = (
    "Please review these channels to see if this is expected (due to strong"
    " priors) or problematic (due to a weak signal)."
)
439
+
440
+
441
@enum.unique
class PriorPosteriorShiftChannelCases(BaseCase, enum.Enum):
  """Per-channel outcomes of the Prior-Posterior Shift Check.

  Members are declared as `(status, unique_id)` tuples; only the status is
  stored.
  """

  SHIFT = (Status.PASS, enum.auto())
  NO_SHIFT = (Status.REVIEW, enum.auto())

  def __init__(self, status: Status, unique_id: Any):
    """Stores the status; `unique_id` is accepted but not used."""
    super().__init__(status)
450
+
451
+
452
@enum.unique
class PriorPosteriorShiftAggregateCases(ModelCheckCase, enum.Enum):
  """Model-level (aggregate) outcomes of the Prior-Posterior Shift Check.

  Each member is declared as a `(status, message_template, recommendation)`
  tuple whose items are handed to `__init__`. The class is decorated with
  `@enum.unique` like the other case enums in this module, so an accidental
  duplicate member value fails at import time instead of silently becoming an
  alias.
  """

  PASS = (
      Status.PASS,
      (
          "The model has successfully learned from the data. This is a positive"
          " sign that your data was informative."
      ),
      None,
  )
  # Template placeholder: `channels_str`.
  REVIEW = (
      Status.REVIEW,
      (
          "We've detected channel(s) {channels_str} where the posterior"
          " distribution did not significantly shift from the prior. This"
          " suggests the data signal for these channels was not strong enough"
          " to update the model's beliefs."
      ),
      _PPS_REVIEW_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    """Forwards the member's tuple items to `ModelCheckCase.__init__`."""
    super().__init__(status, message_template, recommendation)
481
+
482
+
483
@dataclasses.dataclass(frozen=True)
class PriorPosteriorShiftChannelResult(ChannelResult):
  """Prior-Posterior Shift Check outcome for one channel.

  Attributes:
    case: The per-channel prior-posterior shift case.
  """

  case: PriorPosteriorShiftChannelCases
488
+
489
+
490
@dataclasses.dataclass(frozen=True)
class PriorPosteriorShiftCheckResult(CheckResult):
  """Immutable model-level outcome of the Prior-Posterior Shift Check.

  Attributes:
    case: The aggregate prior-posterior shift case.
    channel_results: The per-channel results.
  """

  case: PriorPosteriorShiftAggregateCases
  channel_results: list[PriorPosteriorShiftChannelResult]
496
+
497
+
498
+ # ==============================================================================
499
+ # Review Summary
500
+ # ==============================================================================
501
@dataclasses.dataclass(frozen=True)
class ReviewSummary:
  """Aggregated outcome of all model quality checks.

  Attributes:
    overall_status: Status covering the checks as a whole.
    summary_message: Text summarizing the checks.
    results: The individual check results.
  """

  overall_status: Status
  summary_message: str
  results: list[CheckResult]

  def __repr__(self) -> str:
    """Renders a multi-line, human-readable report of the review."""
    banner = "=" * 40
    divider = "-" * 40
    lines = [
        banner,
        "Model Quality Checks",
        banner,
        f"Overall Status: {self.overall_status.name}",
        f"Summary: {self.summary_message}",
        "\nCheck Results:",
    ]

    for check in self.results:
      # The section title is the result's class name without a trailing
      # "CheckResult"; other class names are used unchanged.
      title = check.__class__.__name__.removesuffix("CheckResult")
      lines.extend((
          divider,
          f"{title} Check:",
          f" Status: {check.case.status.name}",
          f" Recommendation: {check.recommendation}",
      ))

    return "\n".join(lines)

  @property
  def checks_status(self) -> dict[str, str]:
    """Maps each result's class name to the name of its case status."""
    statuses = {}
    for check in self.results:
      # Results of the same class share a key; the later one wins.
      statuses[check.__class__.__name__] = check.case.status.name
    return statuses