google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google_meridian-1.3.1.dist-info/METADATA +209 -0
- google_meridian-1.3.1.dist-info/RECORD +76 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +179 -105
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +227 -87
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +21 -34
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +41 -57
- meridian/backend/__init__.py +457 -118
- meridian/backend/test_utils.py +162 -0
- meridian/constants.py +39 -3
- meridian/model/__init__.py +1 -0
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1309 -196
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +55 -49
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -0
- meridian/model/posterior_sampler.py +39 -32
- meridian/model/prior_distribution.py +12 -2
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +11 -3
- meridian/version.py +1 -1
- schema/__init__.py +18 -0
- schema/serde/__init__.py +26 -0
- schema/serde/constants.py +48 -0
- schema/serde/distribution.py +515 -0
- schema/serde/eda_spec.py +192 -0
- schema/serde/function_registry.py +143 -0
- schema/serde/hyperparameters.py +363 -0
- schema/serde/inference_data.py +105 -0
- schema/serde/marketing_data.py +1321 -0
- schema/serde/meridian_serde.py +413 -0
- schema/serde/serde.py +47 -0
- schema/serde/test_data.py +4608 -0
- schema/utils/__init__.py +17 -0
- schema/utils/time_record.py +156 -0
- google_meridian-1.2.1.dist-info/METADATA +0 -409
- google_meridian-1.2.1.dist-info/RECORD +0 -52
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,721 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Implementation of the Model Quality Checks."""
|
|
16
|
+
|
|
17
|
+
import abc
|
|
18
|
+
from collections.abc import Sequence
|
|
19
|
+
import dataclasses
|
|
20
|
+
from typing import Generic, TypeVar
|
|
21
|
+
import warnings
|
|
22
|
+
|
|
23
|
+
from meridian import backend
|
|
24
|
+
from meridian import constants
|
|
25
|
+
from meridian.analysis import analyzer as analyzer_module
|
|
26
|
+
from meridian.analysis.review import configs
|
|
27
|
+
from meridian.analysis.review import constants as review_constants
|
|
28
|
+
from meridian.analysis.review import results
|
|
29
|
+
from meridian.model import model
|
|
30
|
+
import numpy as np
|
|
31
|
+
|
|
32
|
+
ConfigType = TypeVar("ConfigType", bound=configs.BaseConfig)
|
|
33
|
+
ResultType = TypeVar("ResultType", bound=results.CheckResult)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BaseCheck(abc.ABC, Generic[ConfigType, ResultType]):
  """Abstract, generic base class for one runnable model quality check.

  Concrete checks bind `ConfigType` and `ResultType` to their own configuration
  and result classes and implement `run`.
  """

  def __init__(
      self,
      meridian: model.Meridian,
      analyzer: analyzer_module.Analyzer,
      config: ConfigType,
  ):
    """Stores the collaborators that the check reads from.

    Args:
      meridian: The Meridian model object the check inspects.
      analyzer: The analyzer the check queries for diagnostics.
      config: The configuration holding this check's thresholds and options.
    """
    self._config = config
    self._analyzer = analyzer
    self._meridian = meridian

  @abc.abstractmethod
  def run(self) -> ResultType:
    """Executes the check.

    The return annotation is the generic `ResultType`, so each subclass returns
    its own specific result class.
    """
    raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# ==============================================================================
|
|
60
|
+
# Check: Convergence
|
|
61
|
+
# ==============================================================================
|
|
62
|
+
class ConvergenceCheck(
    BaseCheck[configs.ConvergenceConfig, results.ConvergenceCheckResult]
):
  """Checks for model convergence."""

  def run(self) -> results.ConvergenceCheckResult:
    """Classifies convergence from the largest R-hat over all parameters.

    Parameters whose R-hat values are all NaN are ignored when searching for
    the largest R-hat. If no parameter has a non-NaN R-hat, the check reports
    `CONVERGED` with NaN for both the R-hat and the parameter name.

    Returns:
      A `ConvergenceCheckResult` whose case is `CONVERGED` when the largest
      R-hat is below `convergence_threshold`, `NOT_FULLY_CONVERGED` when it is
      at least `convergence_threshold` but below
      `not_fully_convergence_threshold`, and `NOT_CONVERGED` otherwise. The
      details hold the largest R-hat, the parameter it belongs to, and the
      convergence threshold.
    """
    rhats = self._analyzer.get_rhat()
    with warnings.catch_warnings():
      # `np.nanmax` warns (RuntimeWarning) on all-NaN input; such parameters
      # are filtered out below, so the warning is just noise here.
      warnings.filterwarnings("ignore", category=RuntimeWarning)
      max_rhats = {k: np.nanmax(v) for k, v in rhats.items()}

    valid_rhat_items = [
        item for item in max_rhats.items() if not np.isnan(item[1])
    ]
    if not valid_rhat_items:
      return results.ConvergenceCheckResult(
          case=results.ConvergenceCases.CONVERGED,
          details={
              review_constants.RHAT: np.nan,
              review_constants.PARAMETER: np.nan,
              review_constants.CONVERGENCE_THRESHOLD: (
                  self._config.convergence_threshold
              ),
          },
      )

    # Search only the NaN-free entries: every comparison against NaN is False,
    # so a NaN entry could otherwise be returned by `max` and be misreported
    # as the worst R-hat.
    max_parameter, max_rhat = max(valid_rhat_items, key=lambda item: item[1])

    details = {
        review_constants.RHAT: max_rhat,
        review_constants.PARAMETER: max_parameter,
        review_constants.CONVERGENCE_THRESHOLD: (
            self._config.convergence_threshold
        ),
    }

    # Case 1: Converged.
    if max_rhat < self._config.convergence_threshold:
      case = results.ConvergenceCases.CONVERGED
    # Case 2: Not fully converged, but potentially acceptable.
    elif max_rhat < self._config.not_fully_convergence_threshold:
      case = results.ConvergenceCases.NOT_FULLY_CONVERGED
    # Case 3: Not converged and unacceptable, i.e.
    # max_rhat >= not_fully_convergence_threshold.
    else:
      case = results.ConvergenceCases.NOT_CONVERGED

    return results.ConvergenceCheckResult(case=case, details=details)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# ==============================================================================
|
|
125
|
+
# Check: Baseline
|
|
126
|
+
# ==============================================================================
|
|
127
|
+
class BaselineCheck(
    BaseCheck[configs.BaselineConfig, results.BaselineCheckResult]
):
  """Checks for negative baseline probability."""

  def run(self) -> results.BaselineCheckResult:
    """Grades the negative baseline probability against two thresholds.

    Returns:
      A `BaselineCheckResult` whose case is `FAIL` when the probability is
      strictly above the fail threshold, `REVIEW` when it is at least the
      review threshold, and `PASS` otherwise. The details carry the
      probability and both thresholds.
    """
    negative_prob = self._analyzer.negative_baseline_probability()
    fail_threshold = self._config.negative_baseline_prob_fail_threshold
    review_threshold = self._config.negative_baseline_prob_review_threshold

    if negative_prob > fail_threshold:
      case = results.BaselineCases.FAIL
    elif negative_prob >= review_threshold:
      case = results.BaselineCases.REVIEW
    else:
      case = results.BaselineCases.PASS

    return results.BaselineCheckResult(
        case=case,
        details={
            review_constants.NEGATIVE_BASELINE_PROB: negative_prob,
            review_constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: (
                fail_threshold
            ),
            review_constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
                review_threshold
            ),
        },
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# ==============================================================================
|
|
163
|
+
# Check: Bayesian Posterior Predictive P-value
|
|
164
|
+
# ==============================================================================
|
|
165
|
+
class BayesianPPPCheck(
    BaseCheck[configs.BayesianPPPConfig, results.BayesianPPPCheckResult]
):
  """Checks for Bayesian Posterior Predictive P-value."""

  def run(self) -> results.BayesianPPPCheckResult:
    """Compares the actual total outcome with the posterior expected outcome.

    The statistic is the fraction of posterior draws of the total expected
    outcome whose absolute deviation from the posterior mean is at least as
    large as the absolute deviation of the actual total outcome from that
    mean.

    Returns:
      A `BayesianPPPCheckResult` whose case is `PASS` when the statistic is at
      least `ppp_threshold` and `FAIL` otherwise. The details carry the
      statistic.
    """
    mmm = self._meridian
    analyzer = self._analyzer

    outcome = mmm.kpi
    if mmm.revenue_per_kpi is not None:
      # Build a new product instead of using `*=`: an in-place multiply would
      # rescale the model's own KPI data whenever `mmm.kpi` is a mutable array.
      outcome = outcome * mmm.revenue_per_kpi
    total_outcome_actual = np.sum(outcome)

    total_outcome_posterior = analyzer.expected_outcome(
        aggregate_times=True, aggregate_geos=True
    )
    # One total per posterior draw, flattened across all leading dimensions.
    total_outcome_expected = np.asarray(total_outcome_posterior).flatten()

    total_outcome_expected_mean = np.mean(total_outcome_expected)

    bayesian_ppp = np.mean(
        np.abs(total_outcome_expected - total_outcome_expected_mean)
        >= np.abs(total_outcome_actual - total_outcome_expected_mean)
    )

    details = {
        review_constants.BAYESIAN_PPP: bayesian_ppp,
    }

    if bayesian_ppp >= self._config.ppp_threshold:
      case = results.BayesianPPPCases.PASS
    else:
      case = results.BayesianPPPCases.FAIL

    return results.BayesianPPPCheckResult(case=case, details=details)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# ==============================================================================
|
|
208
|
+
# Check: Goodness of Fit
|
|
209
|
+
# ==============================================================================
|
|
210
|
+
class GoodnessOfFitCheck(
    BaseCheck[configs.GoodnessOfFitConfig, results.GoodnessOfFitCheckResult]
):
  """Checks for goodness of fit of the model."""

  def run(self) -> results.GoodnessOfFitCheckResult:
    """Extracts R-squared, MAPE and wMAPE and grades the fit on R-squared.

    The metrics are read at national granularity when the model has a single
    geo and at geo granularity otherwise. When the accuracy table has an
    evaluation-set column, only the rows for all data are used.

    Returns:
      A `GoodnessOfFitCheckResult` whose case is `PASS` when R-squared is
      strictly positive and `REVIEW` otherwise. The details carry the three
      metrics.
    """
    accuracy_df = (
        self._analyzer.predictive_accuracy().to_dataframe().reset_index()
    )

    if self._meridian.n_geos == 1:
      granularity = constants.NATIONAL
    else:
      granularity = constants.GEO

    selected = accuracy_df[
        accuracy_df[constants.GEO_GRANULARITY] == granularity
    ]
    if constants.EVALUATION_SET_VAR in accuracy_df.columns:
      uses_all_data = selected[constants.EVALUATION_SET_VAR] == constants.ALL_DATA
      selected = selected[uses_all_data]

    # One row per granularity and one column per metric; keep the single row.
    metrics = selected.pivot(
        index=constants.GEO_GRANULARITY,
        columns=constants.METRIC,
        values=constants.VALUE,
    ).loc[granularity]

    r_squared = metrics[constants.R_SQUARED]
    details = {
        review_constants.R_SQUARED: r_squared,
        review_constants.MAPE: metrics[constants.MAPE],
        review_constants.WMAPE: metrics[constants.WMAPE],
    }

    case = (
        results.GoodnessOfFitCases.PASS
        if r_squared > 0
        else results.GoodnessOfFitCases.REVIEW
    )
    return results.GoodnessOfFitCheckResult(case=case, details=details)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
# ==============================================================================
|
|
259
|
+
# Check: ROI Consistency
|
|
260
|
+
# ==============================================================================
|
|
261
|
+
def _format_roi_channels_msg(channels: np.ndarray, direction: str) -> str:
|
|
262
|
+
if channels.size == 0:
|
|
263
|
+
return ""
|
|
264
|
+
plural = "s" if channels.size > 1 else ""
|
|
265
|
+
return (
|
|
266
|
+
f"an unusually {direction} ROI estimate (for channel{plural} "
|
|
267
|
+
f"{', '.join(f'`{channel}`' for channel in channels)})"
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def _inf_prior_quantiles_channels(
|
|
272
|
+
channels: np.ndarray,
|
|
273
|
+
lo_roi_quantiles: np.ndarray,
|
|
274
|
+
hi_roi_quantiles: np.ndarray,
|
|
275
|
+
) -> np.ndarray:
|
|
276
|
+
"""Returns channels with infinite prior quantiles.
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
channels: The names of the channels.
|
|
280
|
+
lo_roi_quantiles: The lower quantiles of the ROI prior.
|
|
281
|
+
hi_roi_quantiles: The upper quantiles of the ROI prior.
|
|
282
|
+
|
|
283
|
+
Returns:
|
|
284
|
+
An array of channel names with infinite prior quantiles.
|
|
285
|
+
"""
|
|
286
|
+
inf_mask = np.isinf(lo_roi_quantiles) | np.isinf(hi_roi_quantiles)
|
|
287
|
+
return channels[inf_mask]
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@dataclasses.dataclass
class _ROIConsistencyChannelData:
  """Intermediate per-channel data used by the ROI Consistency Check.

  The first four arrays are aligned: entry `i` of each describes channel
  `all_channels[i]`.

  Attributes:
    prior_roi_los: Lower quantile of the ROI prior for each channel in
      `all_channels`.
    prior_roi_his: Upper quantile of the ROI prior for each channel in
      `all_channels`.
    posterior_means: Mean of the ROI posterior for each channel in
      `all_channels`.
    all_channels: Names of the channels whose prior quantiles could be
      computed, with media channels (`roi_m`) first and reach and frequency
      (RF) channels (`roi_rf`) after them. Channels whose quantiles could not
      be computed are left out.
    inf_channels: Names of the channels with an infinite prior quantile.
    low_roi_channels: Names of the channels whose posterior mean lies below
      the lower quantile of their prior.
    high_roi_channels: Names of the channels whose posterior mean lies above
      the upper quantile of their prior.
    quantile_not_defined_channels: Names of the channels whose quantiles could
      not be computed.
    quantile_not_defined_parameters: The parameters for which the quantile
      method is not implemented.
  """

  # Aligned per-channel arrays.
  prior_roi_los: np.ndarray
  prior_roi_his: np.ndarray
  posterior_means: np.ndarray
  all_channels: np.ndarray
  # Subsets of channel names, grouped by finding.
  inf_channels: np.ndarray
  low_roi_channels: np.ndarray
  high_roi_channels: np.ndarray
  quantile_not_defined_channels: np.ndarray
  quantile_not_defined_parameters: list[backend.tfd.Distribution] = (
      dataclasses.field(default_factory=list)
  )
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def _get_roi_consistency_channel_data(
    prior_rois: Sequence[backend.tfd.Distribution],
    posterior_rois: Sequence[backend.tfd.Distribution],
    channels_names: Sequence[Sequence[str]],
    prior_lower_quantile: float,
    prior_upper_quantile: float,
) -> _ROIConsistencyChannelData:
  """Collects the channel-level data for the ROI Consistency Check.

  The three sequences are walked in lockstep, one entry per channel group. A
  group whose prior raises `NotImplementedError` while its data is being
  computed is recorded as "quantile not defined" and contributes nothing to
  the aligned per-channel arrays.

  Args:
    prior_rois: The ROI priors, in the same order as `channels_names`.
    posterior_rois: The ROI posteriors, in the same order as `channels_names`.
      Each is averaged over its first two axes.
    channels_names: The channel names of every group, with media channels
      (`roi_m`) followed by any reach and frequency (RF) channels (`roi_rf`).
    prior_lower_quantile: The lower quantile of the ROI prior.
    prior_upper_quantile: The upper quantile of the ROI prior.

  Returns:
    A `_ROIConsistencyChannelData` object holding the channel-level data for
    the ROI Consistency Check.
  """
  lo_chunks, hi_chunks, mean_chunks, name_chunks = [], [], [], []
  unsupported_priors = []
  unsupported_channels = []

  groups = zip(prior_rois, posterior_rois, channels_names)
  for prior, posterior, names in groups:
    try:
      lo = prior.quantile(prior_lower_quantile)
      hi = prior.quantile(prior_upper_quantile)
      post_mean = np.mean(posterior, axis=(0, 1))

      # Expand the quantiles to one value per channel of the group.
      per_channel_shape = (len(names),)
      lo = np.broadcast_to(lo, shape=per_channel_shape)
      hi = np.broadcast_to(hi, shape=per_channel_shape)
    except NotImplementedError:
      unsupported_priors.append(prior)
      unsupported_channels.extend(names)
    else:
      lo_chunks.append(lo)
      hi_chunks.append(hi)
      mean_chunks.append(post_mean)
      name_chunks.append(names)

  if lo_chunks:
    prior_roi_los, prior_roi_his, posterior_means, all_channels = (
        np.concatenate(chunks)
        for chunks in (lo_chunks, hi_chunks, mean_chunks, name_chunks)
    )
  else:
    # No group produced quantiles: use four distinct empty arrays.
    prior_roi_los, prior_roi_his, posterior_means, all_channels = (
        np.array([]) for _ in range(4)
    )

  below_prior = posterior_means < prior_roi_los
  above_prior = posterior_means > prior_roi_his

  return _ROIConsistencyChannelData(
      prior_roi_los=prior_roi_los,
      prior_roi_his=prior_roi_his,
      posterior_means=posterior_means,
      all_channels=all_channels,
      inf_channels=_inf_prior_quantiles_channels(
          channels=all_channels,
          lo_roi_quantiles=prior_roi_los,
          hi_roi_quantiles=prior_roi_his,
      ),
      low_roi_channels=all_channels[below_prior],
      high_roi_channels=all_channels[above_prior],
      quantile_not_defined_parameters=unsupported_priors,
      quantile_not_defined_channels=np.array(unsupported_channels),
  )
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def _compute_channel_results(
    channel_data: _ROIConsistencyChannelData,
) -> list[results.ROIConsistencyChannelResult]:
  """Builds one channel-level ROI Consistency result per channel.

  Channels without defined quantiles come first, with empty details. They are
  followed by every channel in `all_channels`, classified with this
  precedence: infinite prior quantile, low ROI, high ROI, pass.

  Args:
    channel_data: The channel-level data to convert into results.

  Returns:
    The list of channel results.
  """
  cases = results.ROIConsistencyChannelCases

  def classify(name):
    """Returns the case for a channel whose quantiles are defined."""
    if name in channel_data.inf_channels:
      return cases.PRIOR_ROI_QUANTILE_INF
    if name in channel_data.low_roi_channels:
      return cases.ROI_LOW
    if name in channel_data.high_roi_channels:
      return cases.ROI_HIGH
    return cases.ROI_PASS

  per_channel = [
      results.ROIConsistencyChannelResult(
          case=cases.QUANTILE_NOT_DEFINED,
          details={},
          channel_name=name,
      )
      for name in channel_data.quantile_not_defined_channels
  ]

  for idx, name in enumerate(channel_data.all_channels):
    stats = {
        review_constants.PRIOR_ROI_LO: channel_data.prior_roi_los[idx],
        review_constants.PRIOR_ROI_HI: channel_data.prior_roi_his[idx],
        review_constants.POSTERIOR_ROI_MEAN: channel_data.posterior_means[idx],
    }
    per_channel.append(
        results.ROIConsistencyChannelResult(
            case=classify(name),
            details=stats,
            channel_name=name,
        )
    )
  return per_channel
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def _compute_aggregate_result(
    channel_data: _ROIConsistencyChannelData,
) -> results.ROIConsistencyCheckResult:
  """Builds the aggregate result for the ROI Consistency Check.

  Three messages are always present in the details and are empty when they do
  not apply: quantile not defined, infinite prior quantiles, and low/high ROI.
  The aggregate case is `REVIEW` when any message is non-empty and `PASS`
  otherwise.

  Args:
    channel_data: The channel-level data to summarize.

  Returns:
    The aggregate result, including the channel-level results.
  """
  channel_results = _compute_channel_results(channel_data=channel_data)

  # Channel Case 5: QUANTILE_NOT_DEFINED
  quantile_msg = ""
  if channel_data.quantile_not_defined_parameters:
    quantile_msg = (
        "The quantile method is not defined for the following parameters:"
        f" {channel_data.quantile_not_defined_parameters}. The ROI"
        " Consistency check cannot be performed for these parameters."
    )

  # Channel Case 4: PRIOR_ROI_QUANTILE_INF
  inf_msg = ""
  if channel_data.inf_channels.size > 0:
    inf_msg = (
        "Prior ROI quantiles are infinite for channels:"
        f" {', '.join(channel_data.inf_channels)}"
    )

  # Channel Cases 2-3: ROI_LOW, ROI_HIGH
  low_high_msg = ""
  has_low = channel_data.low_roi_channels.size > 0
  if has_low or channel_data.high_roi_channels.size > 0:
    fragments = (
        _format_roi_channels_msg(channel_data.low_roi_channels, "low"),
        _format_roi_channels_msg(channel_data.high_roi_channels, "high"),
    )
    channels_low_high = " and ".join(part for part in fragments if part)
    low_high_msg = (
        f"We've detected {channels_low_high} where the posterior point"
        " estimate falls into the extreme tail of your custom prior."
    )

  aggregate_details = {
      review_constants.QUANTILE_NOT_DEFINED_MSG: quantile_msg,
      review_constants.INF_CHANNELS_MSG: inf_msg,
      review_constants.LOW_HIGH_CHANNELS_MSG: low_high_msg,
  }

  if any(aggregate_details.values()):
    aggregate_case = results.ROIConsistencyAggregateCases.REVIEW
  else:
    aggregate_case = results.ROIConsistencyAggregateCases.PASS

  return results.ROIConsistencyCheckResult(
      case=aggregate_case,
      details=aggregate_details,
      channel_results=channel_results,
  )
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
class ROIConsistencyCheck(
    BaseCheck[configs.ROIConsistencyConfig, results.ROIConsistencyCheckResult]
):
  """Checks if ROI posterior mean is in tails of ROI prior."""

  def run(self) -> results.ROIConsistencyCheckResult:
    """Gathers ROI priors and posteriors per channel type and compares them.

    Media channels are included when the posterior has a media channel
    coordinate, and reach and frequency channels when it has an RF channel
    coordinate. Media channels always come first.

    Returns:
      The aggregate `ROIConsistencyCheckResult` for all gathered channels.
    """
    posterior = self._meridian.inference_data.posterior
    prior_rois, posterior_rois, channel_names = [], [], []

    if constants.MEDIA_CHANNEL in posterior.coords:
      prior_rois.append(self._meridian.model_spec.prior.roi_m)
      posterior_rois.append(posterior.roi_m)
      channel_names.append(posterior.media_channel.values)

    if constants.RF_CHANNEL in posterior.coords:
      prior_rois.append(self._meridian.model_spec.prior.roi_rf)
      posterior_rois.append(posterior.roi_rf)
      channel_names.append(posterior.rf_channel.values)

    return _compute_aggregate_result(
        channel_data=_get_roi_consistency_channel_data(
            prior_rois=prior_rois,
            posterior_rois=posterior_rois,
            channels_names=channel_names,
            prior_lower_quantile=self._config.prior_lower_quantile,
            prior_upper_quantile=self._config.prior_upper_quantile,
        )
    )
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
# ==============================================================================
|
|
552
|
+
# Check: Prior-Posterior Shift
|
|
553
|
+
# ==============================================================================
|
|
554
|
+
def _bootstrap(x: np.ndarray, n_bootstraps: int) -> np.ndarray:
|
|
555
|
+
"""Performs non-parametric bootstrap resampling on the columns of x."""
|
|
556
|
+
n_rows, n_cols = x.shape
|
|
557
|
+
x_bs = np.empty((n_bootstraps, n_rows, n_cols))
|
|
558
|
+
for i in range(n_bootstraps):
|
|
559
|
+
col_indices = np.random.choice(n_cols, n_cols, replace=True)
|
|
560
|
+
x_bs[i, :, :] = x[:, col_indices]
|
|
561
|
+
return x_bs
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def _calculate_new_statistics_from_samples(
    mmm: model.Meridian, n_bootstraps: int, var_name: str, n_channels: int
) -> dict[str, np.ndarray]:
  """Computes bootstrapped Mean, Median, Q1, and Q3 of posterior samples.

  Args:
    mmm: The model whose `inference_data.posterior` holds the samples.
    n_bootstraps: The number of bootstrap replicates to draw.
    var_name: The name of the posterior variable to summarize.
    n_channels: The number of channels; the samples are reshaped to
      `(n_chains * n_draws, n_channels)`.

  Returns:
    A dict with the mean, median, first-quartile and third-quartile keys. Each
    value holds that statistic taken over the resampled samples, one entry per
    bootstrap replicate and channel.
  """
  posterior = mmm.inference_data.posterior
  n_posterior_samples = len(posterior.coords[constants.CHAIN]) * len(
      posterior.coords[constants.DRAW]
  )

  flat_samples = np.reshape(
      posterior.variables[var_name].values,
      (n_posterior_samples, n_channels),
  )
  # Channels become rows so that `_bootstrap` resamples along the sample axis;
  # the result is (bootstraps, channels, samples).
  resampled = _bootstrap(flat_samples.T, n_bootstraps)

  statistics = {review_constants.MEAN: np.mean(resampled, axis=-1)}
  quantile_levels = (
      (review_constants.MEDIAN, 0.5),
      (review_constants.Q1, 0.25),
      (review_constants.Q3, 0.75),
  )
  for key, level in quantile_levels:
    statistics[key] = np.quantile(resampled, q=level, axis=-1)
  return statistics
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def _get_shifted_mask(
|
|
596
|
+
posterior_stat: np.ndarray, prior_stat: np.ndarray, alpha: float
|
|
597
|
+
) -> np.ndarray:
|
|
598
|
+
"""Returns a boolean mask indicating which channels have a significant shift."""
|
|
599
|
+
prior_stat_b = prior_stat[np.newaxis, ...]
|
|
600
|
+
shift_1 = np.mean(posterior_stat > prior_stat_b, axis=0) < alpha
|
|
601
|
+
shift_2 = np.mean(posterior_stat < prior_stat_b, axis=0) < alpha
|
|
602
|
+
return shift_1 | shift_2
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
class PriorPosteriorShiftCheck(
    BaseCheck[
        configs.PriorPosteriorShiftConfig,
        results.PriorPosteriorShiftCheckResult,
    ]
):
  """Checks for a significant shift between prior and posterior of ROI."""

  # Tuple of (channel_results, no_shift_channels)
  _CHANNEL_TYPE_RESULT = tuple[
      list[results.PriorPosteriorShiftChannelResult],
      list[str],
  ]

  def _run_for_channel_type(self, channel_type: str) -> _CHANNEL_TYPE_RESULT:
    """Runs the prior-posterior shift check for a given channel type.

    A channel counts as shifted when any prior statistic that can be computed
    (mean, median, first quartile, third quartile) differs significantly from
    the bootstrapped posterior statistic. Statistics whose computation raises
    `NotImplementedError` are skipped.

    Args:
      channel_type: The channel type ('media_channel' or 'rf_channel') to run
        the check for.

    Returns:
      A tuple of (`channel_results`, `no_shift_channels`). Both are empty when
      the posterior has no coordinate named `channel_type`.
    """
    posterior = self._meridian.inference_data.posterior
    if channel_type not in posterior.coords:
      return [], []

    channel_names = posterior[channel_type].values
    n_channels = len(channel_names)

    if channel_type == constants.MEDIA_CHANNEL:
      var_name = constants.ROI_M
      prior_dist = self._meridian.model_spec.prior.roi_m
    else:
      var_name = constants.ROI_RF
      prior_dist = self._meridian.model_spec.prior.roi_rf

    prior_stat_fns = (
        (review_constants.MEAN, prior_dist.mean),
        (review_constants.MEDIAN, lambda: prior_dist.quantile(0.5)),
        (review_constants.Q1, lambda: prior_dist.quantile(0.25)),
        (review_constants.Q3, lambda: prior_dist.quantile(0.75)),
    )
    prior_stats = {}
    for key, compute_stat in prior_stat_fns:
      try:
        prior_stats[key] = compute_stat()
      except NotImplementedError:
        # Deliberate best effort: this prior does not provide the statistic,
        # so the comparison is made on the remaining ones only.
        continue

    post_stats = _calculate_new_statistics_from_samples(
        self._meridian, self._config.n_bootstraps, var_name, n_channels
    )

    alpha = self._config.alpha
    any_shift = np.zeros(n_channels, dtype=bool)
    for key, prior_stat in prior_stats.items():
      any_shift = any_shift | _get_shifted_mask(
          post_stats[key], prior_stat, alpha
      )

    channel_results = []
    no_shift_channels = []
    for channel_name, shifted in zip(channel_names, any_shift):
      if shifted:
        case = results.PriorPosteriorShiftChannelCases.SHIFT
      else:
        case = results.PriorPosteriorShiftChannelCases.NO_SHIFT
        no_shift_channels.append(channel_name)
      channel_results.append(
          results.PriorPosteriorShiftChannelResult(
              case=case, details={}, channel_name=channel_name
          )
      )
    return channel_results, no_shift_channels

  def _aggregate(
      self,
      *channel_type_results: _CHANNEL_TYPE_RESULT,
  ) -> results.PriorPosteriorShiftCheckResult:
    """Aggregates results from multiple channel types.

    Args:
      *channel_type_results: One (`channel_results`, `no_shift_channels`)
        tuple per channel type.

    Returns:
      A result whose case is `REVIEW` when at least one channel shows no shift
      (the details then list those channels in backticks) and `PASS`
      otherwise. The channel results of all types are concatenated in order.
    """
    channel_results = []
    no_shift_channels = []
    for results_part, channels_part in channel_type_results:
      channel_results.extend(results_part)
      no_shift_channels.extend(channels_part)

    if no_shift_channels:
      agg_case = results.PriorPosteriorShiftAggregateCases.REVIEW
      final_details = {
          "channels_str": ", ".join(
              f"`{channel}`" for channel in no_shift_channels
          )
      }
    else:
      agg_case = results.PriorPosteriorShiftAggregateCases.PASS
      final_details = {}

    return results.PriorPosteriorShiftCheckResult(
        case=agg_case, details=final_details, channel_results=channel_results
    )

  def run(self) -> results.PriorPosteriorShiftCheckResult:
    """Runs the check for media channels, then RF channels, and aggregates.

    The bootstrap draws come from NumPy's global random state, seeded with the
    configured seed. The caller's global state is saved beforehand and
    restored afterwards, so running the check does not reseed the random
    number stream of the surrounding program.

    Returns:
      The aggregate `PriorPosteriorShiftCheckResult`.
    """
    caller_rng_state = np.random.get_state()
    np.random.seed(self._config.seed)
    try:
      media_results = self._run_for_channel_type(constants.MEDIA_CHANNEL)
      rf_results = self._run_for_channel_type(constants.RF_CHANNEL)
    finally:
      np.random.set_state(caller_rng_state)
    return self._aggregate(media_results, rf_results)
|