google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
- google_meridian-1.3.0.dist-info/RECORD +62 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +280 -142
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +353 -169
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +14 -12
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +45 -50
- meridian/backend/__init__.py +698 -55
- meridian/backend/config.py +75 -16
- meridian/backend/test_utils.py +127 -1
- meridian/constants.py +52 -11
- meridian/data/input_data.py +7 -2
- meridian/data/test_utils.py +5 -3
- meridian/mlflow/autolog.py +2 -2
- meridian/model/__init__.py +1 -0
- meridian/model/adstock_hill.py +10 -9
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1580 -84
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +56 -50
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -9
- meridian/model/posterior_sampler.py +398 -391
- meridian/model/prior_distribution.py +114 -39
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +16 -8
- meridian/version.py +1 -1
- google_meridian-1.2.0.dist-info/RECORD +0 -52
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Configurations for the Model Quality Checks."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclasses.dataclass(frozen=True)
class BaseConfig:
  """Common immutable base that every check configuration derives from.

  It declares no fields of its own; each concrete check configuration
  subclasses it and adds its own frozen dataclass fields.
  """
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclasses.dataclass(frozen=True)
class ConvergenceConfig(BaseConfig):
  """Configuration for the Convergence Check.

  Attributes:
    convergence_threshold: The threshold for the R-hat statistic to determine if
      the model has converged. R-hat values below this are considered converged.
    not_fully_convergence_threshold: The threshold for the R-hat statistic to
      determine if the model is not fully converged but potentially acceptable.
      R-hat values between `convergence_threshold` and this value are considered
      not fully converged. R-hat values above this threshold are considered not
      converged.

  Raises:
    ValueError: If `convergence_threshold` is greater than
      `not_fully_convergence_threshold`, which would invert the "not fully
      converged" band described above.
  """

  convergence_threshold: float = 1.2
  not_fully_convergence_threshold: float = 10.0

  def __post_init__(self):
    # The two thresholds delimit consecutive R-hat bands, so they must be
    # ordered for the "not fully converged" band to be well defined.
    if self.convergence_threshold > self.not_fully_convergence_threshold:
      raise ValueError(
          "`convergence_threshold` must not exceed"
          " `not_fully_convergence_threshold`, got"
          f" convergence_threshold={self.convergence_threshold} and"
          " not_fully_convergence_threshold="
          f"{self.not_fully_convergence_threshold}."
      )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclasses.dataclass(frozen=True)
class ROIConsistencyConfig(BaseConfig):
  """Configuration for the ROI Consistency Check.

  This check verifies if the posterior median of the ROI falls within a
  reasonable range of the prior distribution.

  Attributes:
    prior_lower_quantile: The lower quantile of the ROI prior distribution to
      define the lower bound of the reasonable range.
    prior_upper_quantile: The upper quantile of the ROI prior distribution to
      define the upper bound of the reasonable range.

  Raises:
    ValueError: Unless `0 <= prior_lower_quantile <= prior_upper_quantile <= 1`.
  """

  prior_lower_quantile: float = 0.01
  prior_upper_quantile: float = 0.99

  def __post_init__(self):
    # Quantile levels are probabilities, and the lower one bounds the range
    # from below, so both must lie in [0, 1] and be ordered.
    if not (
        0.0 <= self.prior_lower_quantile <= self.prior_upper_quantile <= 1.0
    ):
      raise ValueError(
          "Expected 0 <= prior_lower_quantile <= prior_upper_quantile <= 1,"
          f" got prior_lower_quantile={self.prior_lower_quantile} and"
          f" prior_upper_quantile={self.prior_upper_quantile}."
      )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclasses.dataclass(frozen=True)
class BaselineConfig(BaseConfig):
  """Configuration for the Baseline Check.

  This check warns if there is a high probability of a negative baseline.

  Attributes:
    negative_baseline_prob_review_threshold: Probability threshold for a
      review. If the probability of a negative baseline is above this value, a
      review is issued.
    negative_baseline_prob_fail_threshold: Probability threshold for a failure.
      If the probability of a negative baseline is above this value, the check
      fails.

  Raises:
    ValueError: Unless `0 <= negative_baseline_prob_review_threshold <=
      negative_baseline_prob_fail_threshold <= 1`.
  """

  negative_baseline_prob_review_threshold: float = 0.2
  negative_baseline_prob_fail_threshold: float = 0.8

  def __post_init__(self):
    review = self.negative_baseline_prob_review_threshold
    fail = self.negative_baseline_prob_fail_threshold
    # Both thresholds are probabilities. If the review threshold were above
    # the fail threshold, the REVIEW band between them would be empty.
    if not 0.0 <= review <= fail <= 1.0:
      raise ValueError(
          "Expected 0 <= negative_baseline_prob_review_threshold <="
          " negative_baseline_prob_fail_threshold <= 1, got"
          f" negative_baseline_prob_review_threshold={review} and"
          f" negative_baseline_prob_fail_threshold={fail}."
      )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclasses.dataclass(frozen=True)
class BayesianPPPConfig(BaseConfig):
  """Settings for the Bayesian Posterior Predictive P-value Check.

  Attributes:
    ppp_threshold: The p-value threshold applied by the posterior predictive
      check.
  """

  ppp_threshold: float = 0.05
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclasses.dataclass(frozen=True)
class GoodnessOfFitConfig(BaseConfig):
  """Configuration for the Goodness of Fit Check; it defines no fields."""
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclasses.dataclass(frozen=True)
class PriorPosteriorShiftConfig(BaseConfig):
  """Settings for the Prior-Posterior Shift Check.

  Attributes:
    n_bootstraps: How many bootstrap samples are drawn when computing the
      posterior statistics.
    alpha: The significance level used to decide whether the posterior has
      shifted away from the prior.
    seed: The random seed that makes the bootstrap sampling reproducible.
  """

  n_bootstraps: int = 1000
  alpha: float = 0.05
  seed: int = 42
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Constants for model review."""
|
|
16
|
+
|
|
17
|
+
# Keys formatted by the Convergence Check message templates.
RHAT = "rhat"
PARAMETER = "parameter"
CONVERGENCE_THRESHOLD = "convergence_threshold"

# Keys presumably used by the ROI Consistency Check; the three `*_MSG` keys
# match the placeholders of its aggregate REVIEW template.
CHANNELS_LOW_HIGH = "channels_low_high"
PRIOR_ROI_LO = "prior_roi_lo"
PRIOR_ROI_HI = "prior_roi_hi"
POSTERIOR_ROI_MEAN = "posterior_roi_mean"
QUANTILE_NOT_DEFINED_MSG = "quantile_not_defined_msg"
INF_CHANNELS_MSG = "inf_channels_msg"
LOW_HIGH_CHANNELS_MSG = "low_high_channels_msg"

# Keys required in Baseline Check result details.
NEGATIVE_BASELINE_PROB = "negative_baseline_prob"
NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD = "negative_baseline_prob_fail_threshold"
NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD = (
    "negative_baseline_prob_review_threshold"
)

# Keys required in Goodness of Fit Check result details.
R_SQUARED = "r_squared"
MAPE = "mape"
WMAPE = "wmape"

# Summary-statistic names (callers not visible here; TODO confirm which check
# consumes them).
MEAN = "mean"
VARIANCE = "variance"
MEDIAN = "median"
Q1 = "q1"
Q3 = "q3"

# Key required in Bayesian Posterior Predictive P-value Check result details.
BAYESIAN_PPP = "bayesian_ppp"
|
|
@@ -0,0 +1,544 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Data structures for the Model Quality Checks results."""
|
|
16
|
+
|
|
17
|
+
import dataclasses
|
|
18
|
+
import enum
|
|
19
|
+
from typing import Any
|
|
20
|
+
from meridian.analysis.review import constants
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# ==============================================================================
|
|
24
|
+
# Base classes
|
|
25
|
+
# ==============================================================================
|
|
26
|
+
@enum.unique
class Status(enum.Enum):
  """The possible outcomes of a quality check: PASS, REVIEW, or FAIL."""

  PASS = enum.auto()
  REVIEW = enum.auto()
  FAIL = enum.auto()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class BaseCase:
  """Common base for every check case; it records only the case's status."""

  status: Status

  def __init__(self, status: Status):
    """Stores `status` on the case."""
    self.status = status
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ModelCheckCase(BaseCase):
  """Base class for model-level check cases.

  Besides its status, a model-level case has a message template and an
  optional recommendation. `CheckResult.recommendation` formats the template
  with the result's details and appends the recommendation to it.
  """

  message_template: str
  recommendation: str | None = None

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None = None,
  ):
    super().__init__(status=status)
    self.message_template = message_template
    self.recommendation = recommendation
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclasses.dataclass(frozen=True)
class BaseResultData:
  """Immutable pairing of a check case with the details that describe it.

  Attributes:
    case: The case the check resolved to.
    details: Values describing the outcome. Model-level results pass them as
      keyword arguments when formatting the case's message template.
  """

  case: BaseCase
  details: dict[str, Any]
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclasses.dataclass(frozen=True)
class ChannelResult(BaseResultData):
  """Common base for check results that refer to a single channel.

  Attributes:
    channel_name: The name of the channel this result refers to.
  """

  channel_name: str
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclasses.dataclass(frozen=True)
class CheckResult(BaseResultData):
  """Base class for model-level check results.

  Attributes:
    case: The model-level case the check resolved to.
  """

  case: ModelCheckCase

  @property
  def recommendation(self) -> str:
    """Returns the formatted check message plus the case's recommendation.

    The case's message template is formatted with `details` as keyword
    arguments. If the case has a non-empty recommendation, it is appended
    after a single space.
    """
    message = self.case.message_template.format(**self.details)
    advice = self.case.recommendation
    return f"{message} {advice}" if advice else message
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# ==============================================================================
|
|
91
|
+
# Check: Convergence
|
|
92
|
+
# ==============================================================================
|
|
93
|
+
# Advice appended to the NOT_FULLY_CONVERGED case message.
NOT_FULLY_CONVERGED_RECOMMENDATION = (
    "Manually inspect the parameters with high R-hat values to determine "
    "if the results are acceptable for your use case, and consider "
    "increasing MCMC iterations or investigating model misspecification."
)

# Advice appended to the NOT_CONVERGED case message.
NOT_CONVERGED_RECOMMENDATION = (
    "We recommend increasing MCMC iterations or investigating model "
    "misspecification (e.g., priors, multicollinearity) before proceeding."
)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@enum.unique
class ConvergenceCases(ModelCheckCase, enum.Enum):
  """Cases for the Convergence Check.

  Each member's value is a `(status, message_template, recommendation)` tuple.
  Enum passes the tuple's items to `__init__` as separate arguments.
  """

  CONVERGED = (
      Status.PASS,
      (
          "The model has likely converged, as all parameters have R-hat "
          "values < {convergence_threshold}."
      ),
      None,
  )
  # NOTE(review): this case reports FAIL, although the config and the
  # recommendation describe it as potentially acceptable; confirm that
  # REVIEW was not intended.
  NOT_FULLY_CONVERGED = (
      Status.FAIL,
      (
          "The model hasn't fully converged, and the `max_r_hat` for "
          "parameter `{parameter}` is {rhat:.2f}."
      ),
      NOT_FULLY_CONVERGED_RECOMMENDATION,
  )
  NOT_CONVERGED = (
      Status.FAIL,
      (
          "The model hasn't converged, and the `max_r_hat` for parameter "
          "`{parameter}` is {rhat:.2f}."
      ),
      NOT_CONVERGED_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    super().__init__(
        status=status,
        message_template=message_template,
        recommendation=recommendation,
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
@dataclasses.dataclass(frozen=True)
class ConvergenceCheckResult(CheckResult):
  """The immutable result of the Convergence Check.

  At construction time, it checks that `details` holds every key the selected
  case's message template formats. Reading `recommendation` later therefore
  cannot fail with a `KeyError`.

  Raises:
    ValueError: If `details` lacks `convergence_threshold` for the CONVERGED
      case, or lacks `parameter` or `rhat` for the NOT_FULLY_CONVERGED and
      NOT_CONVERGED cases.
  """

  case: ConvergenceCases

  def __post_init__(self):
    if self.case is ConvergenceCases.CONVERGED:
      if constants.CONVERGENCE_THRESHOLD not in self.details:
        raise ValueError(
            "The message template 'The model has likely converged, as all"
            " parameters have R-hat values < {convergence_threshold}'. is"
            " missing required formatting arguments: convergence_threshold."
            f" Details: {self.details}."
        )
      return
    # Both non-converged templates format `{parameter}` and `{rhat:.2f}`.
    missing = [
        key
        for key in (constants.PARAMETER, constants.RHAT)
        if key not in self.details
    ]
    if missing:
      raise ValueError(
          "The message template is missing required formatting arguments:"
          f" {', '.join(missing)}. Details: {self.details}."
      )
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# ==============================================================================
|
|
162
|
+
# Check: Baseline
|
|
163
|
+
# ==============================================================================
|
|
164
|
+
# Advice appended to the Baseline Check FAIL message.
_BASELINE_FAIL_RECOMMENDATION = (
    "This high probability points to a statistical error and is a clear "
    "signal that the model requires adjustment. The model is likely "
    "over-crediting your treatments. Consider adjusting the model's "
    "settings, data, or priors to correct this issue."
)
# Advice appended to the Baseline Check REVIEW message.
_BASELINE_REVIEW_RECOMMENDATION = (
    "This indicates that the baseline time series occasionally dips into "
    "negative values. We recommend visually inspecting the baseline time "
    "series in the Model Fit charts, but don't be overly concerned. An "
    "occasional, small dip may indicate minor statistical error, which is "
    "inherent in any model."
)
# Advice appended to the Baseline Check PASS message.
_BASELINE_PASS_RECOMMENDATION = (
    "We recommend visually inspecting the baseline time series in the"
    " Model Fit charts to confirm this."
)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@enum.unique
class BaselineCases(ModelCheckCase, enum.Enum):
  """Cases for the Baseline Check.

  All three cases share the same message template and differ only in status
  and recommendation. That keeps their values distinct under `@enum.unique`.
  """

  PASS = (
      Status.PASS,
      (
          "The posterior probability that the baseline is negative "
          "is {negative_baseline_prob:.2f}."
      ),
      _BASELINE_PASS_RECOMMENDATION,
  )
  REVIEW = (
      Status.REVIEW,
      (
          "The posterior probability that the baseline is negative "
          "is {negative_baseline_prob:.2f}."
      ),
      _BASELINE_REVIEW_RECOMMENDATION,
  )
  FAIL = (
      Status.FAIL,
      (
          "The posterior probability that the baseline is negative "
          "is {negative_baseline_prob:.2f}."
      ),
      _BASELINE_FAIL_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    super().__init__(
        status=status,
        message_template=message_template,
        recommendation=recommendation,
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@dataclasses.dataclass(frozen=True)
class BaselineCheckResult(CheckResult):
  """The immutable result of the Baseline Check.

  Raises:
    ValueError: If `details` lacks `negative_baseline_prob`, which every
      case's message template formats. For the REVIEW and FAIL cases, it is
      also raised if either negative-baseline probability threshold is
      missing.
  """

  case: BaselineCases

  def __post_init__(self):
    if self.case is BaselineCases.PASS:
      # The PASS template still formats `{negative_baseline_prob:.2f}`.
      # Without this check, a missing key would only surface later as a
      # KeyError when `recommendation` is read.
      if constants.NEGATIVE_BASELINE_PROB not in self.details:
        raise ValueError(
            "The message template is missing required formatting arguments:"
            f" negative_baseline_prob. Details: {self.details}."
        )
      return
    if any(
        key not in self.details
        for key in (
            constants.NEGATIVE_BASELINE_PROB,
            constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD,
            constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD,
        )
    ):
      raise ValueError(
          "The message template is missing required formatting arguments:"
          " negative_baseline_prob, negative_baseline_prob_fail_threshold,"
          " negative_baseline_prob_review_threshold. Details:"
          f" {self.details}."
      )
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
# ==============================================================================
|
|
247
|
+
# Check: Bayesian Posterior Predictive P-value
|
|
248
|
+
# ==============================================================================
|
|
249
|
+
# Advice appended to the Bayesian PPP FAIL message.
_BAYESIAN_PPP_FAIL_RECOMMENDATION = (
    "The observed total outcome is an extreme outlier compared to the "
    "model's expected total outcomes, which suggests a systematic lack of "
    "fit. We recommend reviewing input data quality and re-examining the "
    "model specification (e.g., priors, transformations) to resolve this "
    "issue."
)
# Advice appended to the Bayesian PPP PASS message.
_BAYESIAN_PPP_PASS_RECOMMENDATION = (
    "The observed total outcome is consistent with the model's "
    "posterior predictive distribution."
)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
@enum.unique
class BayesianPPPCases(ModelCheckCase, enum.Enum):
  """Cases for the Bayesian Posterior Predictive P-value Check.

  Both cases format `{bayesian_ppp:.2f}`; they differ in status and
  recommendation.
  """

  PASS = (
      Status.PASS,
      (
          "The Bayesian posterior predictive p-value is "
          "{bayesian_ppp:.2f}."
      ),
      _BAYESIAN_PPP_PASS_RECOMMENDATION,
  )
  FAIL = (
      Status.FAIL,
      (
          "The Bayesian posterior predictive p-value is "
          "{bayesian_ppp:.2f}."
      ),
      _BAYESIAN_PPP_FAIL_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    super().__init__(
        status=status,
        message_template=message_template,
        recommendation=recommendation,
    )
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
@dataclasses.dataclass(frozen=True)
class BayesianPPPCheckResult(CheckResult):
  """The immutable result of the Bayesian Posterior Predictive P-value Check."""

  case: BayesianPPPCases

  def __post_init__(self):
    # Every case's template formats `{bayesian_ppp}`, so it is always required.
    if constants.BAYESIAN_PPP in self.details:
      return
    raise ValueError(
        "The message template is missing required formatting arguments: "
        f"bayesian_ppp. Details: {self.details}."
    )
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
# ==============================================================================
|
|
301
|
+
# Check: Goodness of Fit
|
|
302
|
+
# ==============================================================================
|
|
303
|
+
# Advice appended to the Goodness of Fit REVIEW message.
_GOODNESS_OF_FIT_REVIEW_RECOMMENDATION = (
    "A negative R-squared signals a potential conflict between your priors "
    "and the data, and it warrants investigation. If this conflict is "
    "intentional (due to an informative prior), no further action is "
    "needed. If it's unintentional, we recommend relaxing your priors to be "
    "less restrictive."
)

# Advice appended to the Goodness of Fit PASS message.
_GOODNESS_OF_FIT_PASS_RECOMMENDATION = (
    "These goodness-of-fit metrics are intended for guidance and "
    "relative comparison."
)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
@enum.unique
class GoodnessOfFitCases(ModelCheckCase, enum.Enum):
  """Cases for the Goodness of Fit Check.

  Both cases format `{r_squared}`, `{mape}` and `{wmape}`; they differ in
  status and recommendation.
  """

  PASS = (
      Status.PASS,
      (
          "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, "
          "and wMAPE = {wmape:.4f}."
      ),
      _GOODNESS_OF_FIT_PASS_RECOMMENDATION,
  )
  REVIEW = (
      Status.REVIEW,
      (
          "R-squared = {r_squared:.4f}, MAPE = {mape:.4f}, "
          "and wMAPE = {wmape:.4f}."
      ),
      _GOODNESS_OF_FIT_REVIEW_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    super().__init__(
        status=status,
        message_template=message_template,
        recommendation=recommendation,
    )
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@dataclasses.dataclass(frozen=True)
class GoodnessOfFitCheckResult(CheckResult):
  """The immutable result of the Goodness of Fit Check."""

  case: GoodnessOfFitCases

  def __post_init__(self):
    # Every case's template formats all three metrics.
    required_keys = (constants.R_SQUARED, constants.MAPE, constants.WMAPE)
    if not all(key in self.details for key in required_keys):
      raise ValueError(
          "The message template is missing required formatting arguments: "
          f"r_squared, mape, wmape. Details: {self.details}."
      )
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
# ==============================================================================
|
|
369
|
+
# Check: ROI Consistency
|
|
370
|
+
# ==============================================================================
|
|
371
|
+
# Advice appended to the aggregate ROI Consistency REVIEW message.
_ROI_CONSISTENCY_RECOMMENDATION = (
    "Please review this result to determine if it is reasonable "
    "within your business context."
)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
@enum.unique
class ROIConsistencyChannelCases(BaseCase, enum.Enum):
  """Cases for ROI Consistency Check per channel.

  Members carry only a status. Several members share a status, so each value
  pairs it with `enum.auto()`; this keeps values distinct under `@enum.unique`.
  """

  ROI_PASS = (Status.PASS, enum.auto())
  ROI_LOW = (Status.REVIEW, enum.auto())
  ROI_HIGH = (Status.REVIEW, enum.auto())
  PRIOR_ROI_QUANTILE_INF = (Status.REVIEW, enum.auto())
  QUANTILE_NOT_DEFINED = (Status.REVIEW, enum.auto())

  def __init__(self, status: Status, unique_id: Any):
    del unique_id  # Only present to make member values distinct.
    super().__init__(status=status)
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@enum.unique
class ROIConsistencyAggregateCases(ModelCheckCase, enum.Enum):
  """Cases for ROI Consistency Check aggregate result.

  `@enum.unique` matches the other case enums. It guards against two members
  silently becoming aliases if their tuples ever become equal.

  The REVIEW template concatenates the `quantile_not_defined_msg`,
  `inf_channels_msg` and `low_high_channels_msg` details.
  """

  PASS = (
      Status.PASS,
      (
          "The posterior distribution of the ROI is within a reasonable range,"
          " aligning with the custom priors you provided."
      ),
      None,
  )
  REVIEW = (
      Status.REVIEW,
      "{quantile_not_defined_msg}{inf_channels_msg}{low_high_channels_msg}",
      _ROI_CONSISTENCY_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    super().__init__(status, message_template, recommendation)
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
@dataclasses.dataclass(frozen=True)
class ROIConsistencyChannelResult(ChannelResult):
  """The frozen ROI Consistency Check outcome for one channel.

  Attributes:
    case: The per-channel ROI Consistency case.
  """

  case: ROIConsistencyChannelCases
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@dataclasses.dataclass(frozen=True)
class ROIConsistencyCheckResult(CheckResult):
  """The immutable result of model-level ROI Consistency Check.

  Attributes:
    case: The aggregate ROI Consistency case.
    channel_results: The per-channel ROI Consistency results.

  Raises:
    ValueError: If the case is REVIEW and `details` lacks any of the message
      fragments that the REVIEW template formats.
  """

  case: ROIConsistencyAggregateCases
  channel_results: list[ROIConsistencyChannelResult]

  def __post_init__(self):
    # The PASS template has no placeholders; only REVIEW needs details. This
    # matches the up-front validation done by the other check results.
    if self.case is not ROIConsistencyAggregateCases.REVIEW:
      return
    required_keys = (
        constants.QUANTILE_NOT_DEFINED_MSG,
        constants.INF_CHANNELS_MSG,
        constants.LOW_HIGH_CHANNELS_MSG,
    )
    if any(key not in self.details for key in required_keys):
      raise ValueError(
          "The message template is missing required formatting arguments:"
          " quantile_not_defined_msg, inf_channels_msg, low_high_channels_msg."
          f" Details: {self.details}."
      )
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
# ==============================================================================
|
|
433
|
+
# Check: Prior-Posterior Shift
|
|
434
|
+
# ==============================================================================
|
|
435
|
+
# Advice appended to the aggregate Prior-Posterior Shift REVIEW message.
# Fixes the grammatical error "(due to a strong priors)".
_PPS_REVIEW_RECOMMENDATION = (
    "Please review these channels to see if this is expected (due to strong"
    " priors) or problematic (due to a weak signal)."
)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
@enum.unique
class PriorPosteriorShiftChannelCases(BaseCase, enum.Enum):
  """Cases for Prior-Posterior Shift Check per channel.

  SHIFT carries PASS and NO_SHIFT carries REVIEW. The `enum.auto()` element
  of each value is accepted and ignored by `__init__`.
  """

  SHIFT = (Status.PASS, enum.auto())
  NO_SHIFT = (Status.REVIEW, enum.auto())

  def __init__(self, status: Status, unique_id: Any):
    del unique_id  # Only present to make member values distinct.
    super().__init__(status=status)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@enum.unique
class PriorPosteriorShiftAggregateCases(ModelCheckCase, enum.Enum):
  """Cases for Prior-Posterior Shift Check aggregate result.

  `@enum.unique` matches the other case enums. It guards against two members
  silently becoming aliases if their tuples ever become equal.

  The REVIEW template formats a `channels_str` detail.
  """

  PASS = (
      Status.PASS,
      (
          "The model has successfully learned from the data. This is a positive"
          " sign that your data was informative."
      ),
      None,
  )
  REVIEW = (
      Status.REVIEW,
      (
          "We've detected channel(s) {channels_str} where the posterior"
          " distribution did not significantly shift from the prior. This"
          " suggests the data signal for these channels was not strong enough"
          " to update the model's beliefs."
      ),
      _PPS_REVIEW_RECOMMENDATION,
  )

  def __init__(
      self,
      status: Status,
      message_template: str,
      recommendation: str | None,
  ):
    super().__init__(status, message_template, recommendation)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@dataclasses.dataclass(frozen=True)
class PriorPosteriorShiftChannelResult(ChannelResult):
  """The Prior-Posterior Shift Check outcome for one channel.

  Attributes:
    case: The per-channel Prior-Posterior Shift case.
  """

  case: PriorPosteriorShiftChannelCases
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
@dataclasses.dataclass(frozen=True)
class PriorPosteriorShiftCheckResult(CheckResult):
  """The frozen model-level outcome of the Prior-Posterior Shift Check.

  Attributes:
    case: The aggregate Prior-Posterior Shift case.
    channel_results: The per-channel Prior-Posterior Shift results.
  """

  case: PriorPosteriorShiftAggregateCases
  channel_results: list[PriorPosteriorShiftChannelResult]
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
# ==============================================================================
|
|
499
|
+
# Review Summary
|
|
500
|
+
# ==============================================================================
|
|
501
|
+
@dataclasses.dataclass(frozen=True)
class ReviewSummary:
  """The final summary of all model quality checks.

  Attributes:
    overall_status: The overall status of all checks.
    summary_message: A summary message of all checks.
    results: A list of all check results.
  """

  overall_status: Status
  summary_message: str
  results: list[CheckResult]

  def __repr__(self) -> str:
    """Renders a plain-text report: a header, then one section per result."""
    divider = "=" * 40
    lines = [
        divider,
        "Model Quality Checks",
        divider,
        f"Overall Status: {self.overall_status.name}",
        f"Summary: {self.summary_message}",
        "\nCheck Results:",
    ]
    for result in self.results:
      # e.g. "BaselineCheckResult" -> "Baseline"; other names pass unchanged.
      title = result.__class__.__name__.removesuffix("CheckResult")
      lines += [
          "-" * 40,
          f"{title} Check:",
          f" Status: {result.case.status.name}",
          f" Recommendation: {result.recommendation}",
      ]
    return "\n".join(lines)

  @property
  def checks_status(self) -> dict[str, str]:
    """Maps each result's class name to the name of its case's status."""
    return {r.__class__.__name__: r.case.status.name for r in self.results}
|