google-meridian 1.2.1__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. google_meridian-1.3.1.dist-info/METADATA +209 -0
  2. google_meridian-1.3.1.dist-info/RECORD +76 -0
  3. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/__init__.py +2 -0
  5. meridian/analysis/analyzer.py +179 -105
  6. meridian/analysis/formatter.py +2 -2
  7. meridian/analysis/optimizer.py +227 -87
  8. meridian/analysis/review/__init__.py +20 -0
  9. meridian/analysis/review/checks.py +721 -0
  10. meridian/analysis/review/configs.py +110 -0
  11. meridian/analysis/review/constants.py +40 -0
  12. meridian/analysis/review/results.py +544 -0
  13. meridian/analysis/review/reviewer.py +186 -0
  14. meridian/analysis/summarizer.py +21 -34
  15. meridian/analysis/templates/chips.html.jinja +12 -0
  16. meridian/analysis/test_utils.py +27 -5
  17. meridian/analysis/visualizer.py +41 -57
  18. meridian/backend/__init__.py +457 -118
  19. meridian/backend/test_utils.py +162 -0
  20. meridian/constants.py +39 -3
  21. meridian/model/__init__.py +1 -0
  22. meridian/model/eda/__init__.py +3 -0
  23. meridian/model/eda/constants.py +21 -0
  24. meridian/model/eda/eda_engine.py +1309 -196
  25. meridian/model/eda/eda_outcome.py +200 -0
  26. meridian/model/eda/eda_spec.py +84 -0
  27. meridian/model/eda/meridian_eda.py +220 -0
  28. meridian/model/knots.py +55 -49
  29. meridian/model/media.py +10 -8
  30. meridian/model/model.py +79 -16
  31. meridian/model/model_test_data.py +53 -0
  32. meridian/model/posterior_sampler.py +39 -32
  33. meridian/model/prior_distribution.py +12 -2
  34. meridian/model/prior_sampler.py +146 -90
  35. meridian/model/spec.py +7 -8
  36. meridian/model/transformers.py +11 -3
  37. meridian/version.py +1 -1
  38. schema/__init__.py +18 -0
  39. schema/serde/__init__.py +26 -0
  40. schema/serde/constants.py +48 -0
  41. schema/serde/distribution.py +515 -0
  42. schema/serde/eda_spec.py +192 -0
  43. schema/serde/function_registry.py +143 -0
  44. schema/serde/hyperparameters.py +363 -0
  45. schema/serde/inference_data.py +105 -0
  46. schema/serde/marketing_data.py +1321 -0
  47. schema/serde/meridian_serde.py +413 -0
  48. schema/serde/serde.py +47 -0
  49. schema/serde/test_data.py +4608 -0
  50. schema/utils/__init__.py +17 -0
  51. schema/utils/time_record.py +156 -0
  52. google_meridian-1.2.1.dist-info/METADATA +0 -409
  53. google_meridian-1.2.1.dist-info/RECORD +0 -52
  54. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/WHEEL +0 -0
  55. {google_meridian-1.2.1.dist-info → google_meridian-1.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,721 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Implementation of the Model Quality Checks."""
16
+
17
+ import abc
18
+ from collections.abc import Sequence
19
+ import dataclasses
20
+ from typing import Generic, TypeVar
21
+ import warnings
22
+
23
+ from meridian import backend
24
+ from meridian import constants
25
+ from meridian.analysis import analyzer as analyzer_module
26
+ from meridian.analysis.review import configs
27
+ from meridian.analysis.review import constants as review_constants
28
+ from meridian.analysis.review import results
29
+ from meridian.model import model
30
+ import numpy as np
31
+
32
ConfigType = TypeVar("ConfigType", bound=configs.BaseConfig)
ResultType = TypeVar("ResultType", bound=results.CheckResult)


class BaseCheck(abc.ABC, Generic[ConfigType, ResultType]):
  """Abstract, generic base for one runnable model-quality check.

  Subclasses are parameterized by the config type they read thresholds from
  and by the result type that their `run` method produces.
  """

  def __init__(
      self,
      meridian: model.Meridian,
      analyzer: analyzer_module.Analyzer,
      config: ConfigType,
  ):
    # Collaborators are kept under private names that subclasses read directly.
    self._meridian, self._analyzer, self._config = meridian, analyzer, config

  @abc.abstractmethod
  def run(self) -> ResultType:
    """Runs the check and returns its subclass-specific result.

    Raises:
      NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError()
57
+
58
+
59
+ # ==============================================================================
60
+ # Check: Convergence
61
+ # ==============================================================================
62
class ConvergenceCheck(
    BaseCheck[configs.ConvergenceConfig, results.ConvergenceCheckResult]
):
  """Checks for model convergence via the largest R-hat across parameters.

  For every parameter returned by `Analyzer.get_rhat()`, the largest non-NaN
  R-hat value is taken. The parameter with the overall largest value decides
  the case:

  * `CONVERGED` if it is below `convergence_threshold`, or if no parameter
    has any non-NaN R-hat at all;
  * `NOT_FULLY_CONVERGED` if it lies in
    `[convergence_threshold, not_fully_convergence_threshold)`;
  * `NOT_CONVERGED` otherwise.
  """

  def run(self) -> results.ConvergenceCheckResult:
    rhats = self._analyzer.get_rhat()
    with warnings.catch_warnings():
      # `np.nanmax` emits a RuntimeWarning (and returns NaN) for parameters
      # whose R-hat values are all NaN; those parameters are dropped below.
      warnings.filterwarnings("ignore", category=RuntimeWarning)
      max_rhats = {k: np.nanmax(v) for k, v in rhats.items()}

    valid_rhat_items = [
        item for item in max_rhats.items() if not np.isnan(item[1])
    ]
    if not valid_rhat_items:
      # No parameter has a usable R-hat; report CONVERGED with NaN details.
      return results.ConvergenceCheckResult(
          case=results.ConvergenceCases.CONVERGED,
          details={
              review_constants.RHAT: np.nan,
              review_constants.PARAMETER: np.nan,
              review_constants.CONVERGENCE_THRESHOLD: (
                  self._config.convergence_threshold
              ),
          },
      )

    # Take the maximum over the NaN-free items only. `max` never replaces a
    # leading NaN key (every comparison against NaN is False), so maximizing
    # over all of `max_rhats` could report NOT_CONVERGED with an R-hat of NaN.
    max_parameter, max_rhat = max(valid_rhat_items, key=lambda item: item[1])

    details = {
        review_constants.RHAT: max_rhat,
        review_constants.PARAMETER: max_parameter,
        review_constants.CONVERGENCE_THRESHOLD: (
            self._config.convergence_threshold
        ),
    }

    if max_rhat < self._config.convergence_threshold:
      # Case 1: Converged.
      case = results.ConvergenceCases.CONVERGED
    elif (
        self._config.convergence_threshold
        <= max_rhat
        < self._config.not_fully_convergence_threshold
    ):
      # Case 2: Not fully converged, but potentially acceptable.
      case = results.ConvergenceCases.NOT_FULLY_CONVERGED
    else:
      # Case 3: max_rhat >= not_fully_convergence_threshold; unacceptable.
      case = results.ConvergenceCases.NOT_CONVERGED
    return results.ConvergenceCheckResult(case=case, details=details)
122
+
123
+
124
+ # ==============================================================================
125
+ # Check: Baseline
126
+ # ==============================================================================
127
class BaselineCheck(
    BaseCheck[configs.BaselineConfig, results.BaselineCheckResult]
):
  """Flags models with a high probability of a negative baseline."""

  def run(self) -> results.BaselineCheckResult:
    """Grades the negative-baseline probability against two thresholds.

    Returns:
      A result whose case is FAIL when the probability exceeds the fail
      threshold, REVIEW when it is at least the review threshold, and PASS
      otherwise. The details echo the probability and both thresholds.
    """
    prob = self._analyzer.negative_baseline_probability()
    cfg = self._config
    fail_threshold = cfg.negative_baseline_prob_fail_threshold
    review_threshold = cfg.negative_baseline_prob_review_threshold

    details = {
        review_constants.NEGATIVE_BASELINE_PROB: prob,
        review_constants.NEGATIVE_BASELINE_PROB_FAIL_THRESHOLD: fail_threshold,
        review_constants.NEGATIVE_BASELINE_PROB_REVIEW_THRESHOLD: (
            review_threshold
        ),
    }

    if prob > fail_threshold:
      case = results.BaselineCases.FAIL
    elif prob >= review_threshold:
      case = results.BaselineCases.REVIEW
    else:
      case = results.BaselineCases.PASS
    return results.BaselineCheckResult(case=case, details=details)
160
+
161
+
162
+ # ==============================================================================
163
+ # Check: Bayesian Posterior Predictive P-value
164
+ # ==============================================================================
165
class BayesianPPPCheck(
    BaseCheck[configs.BayesianPPPConfig, results.BayesianPPPCheckResult]
):
  """Checks the Bayesian posterior predictive p-value of the total outcome."""

  def run(self) -> results.BayesianPPPCheckResult:
    """Computes a two-sided Bayesian PPP for the total outcome.

    The observed total is the sum of KPI (multiplied by revenue per KPI when
    available). The p-value is the fraction of posterior draws of the expected
    total outcome (aggregated over geos and times) that are at least as far
    from their mean as the observed total is. The check passes when the
    p-value is at least `config.ppp_threshold`.
    """
    mmm = self._meridian

    # Build a new array instead of multiplying in place: `outcome *= ...` would
    # write into `mmm.kpi` itself whenever it is a mutable array.
    outcome = mmm.kpi
    if mmm.revenue_per_kpi is not None:
      outcome = outcome * mmm.revenue_per_kpi
    total_outcome_actual = np.sum(outcome)

    total_outcome_posterior = self._analyzer.expected_outcome(
        aggregate_times=True, aggregate_geos=True
    )
    total_outcome_expected = np.asarray(total_outcome_posterior).flatten()
    total_outcome_expected_mean = np.mean(total_outcome_expected)

    bayesian_ppp = np.mean(
        np.abs(total_outcome_expected - total_outcome_expected_mean)
        >= np.abs(total_outcome_actual - total_outcome_expected_mean)
    )

    details = {review_constants.BAYESIAN_PPP: bayesian_ppp}
    if bayesian_ppp >= self._config.ppp_threshold:
      case = results.BayesianPPPCases.PASS
    else:
      case = results.BayesianPPPCases.FAIL
    return results.BayesianPPPCheckResult(case=case, details=details)
205
+
206
+
207
+ # ==============================================================================
208
+ # Check: Goodness of Fit
209
+ # ==============================================================================
210
class GoodnessOfFitCheck(
    BaseCheck[configs.GoodnessOfFitConfig, results.GoodnessOfFitCheckResult]
):
  """Checks the model's goodness of fit using predictive-accuracy metrics."""

  def run(self) -> results.GoodnessOfFitCheckResult:
    """Reports R-squared, MAPE and wMAPE; passes only when R-squared > 0.

    Metrics are read at national granularity for single-geo models and at geo
    granularity otherwise, restricted to the all-data evaluation set when the
    accuracy report has an evaluation-set column.
    """
    accuracy_df = (
        self._analyzer.predictive_accuracy().to_dataframe().reset_index()
    )
    granularity = (
        constants.GEO if self._meridian.n_geos != 1 else constants.NATIONAL
    )

    rows = accuracy_df[accuracy_df[constants.GEO_GRANULARITY] == granularity]
    # Only present when the report is split by evaluation set; keep the rows
    # computed over all data.
    if constants.EVALUATION_SET_VAR in accuracy_df.columns:
      rows = rows[rows[constants.EVALUATION_SET_VAR] == constants.ALL_DATA]

    by_metric = rows.pivot(
        index=constants.GEO_GRANULARITY,
        columns=constants.METRIC,
        values=constants.VALUE,
    ).loc[granularity]

    r_squared = by_metric[constants.R_SQUARED]
    details = {
        review_constants.R_SQUARED: r_squared,
        review_constants.MAPE: by_metric[constants.MAPE],
        review_constants.WMAPE: by_metric[constants.WMAPE],
    }
    case = (
        results.GoodnessOfFitCases.PASS
        if r_squared > 0
        else results.GoodnessOfFitCases.REVIEW
    )
    return results.GoodnessOfFitCheckResult(case=case, details=details)
256
+
257
+
258
+ # ==============================================================================
259
+ # Check: ROI Consistency
260
+ # ==============================================================================
261
+ def _format_roi_channels_msg(channels: np.ndarray, direction: str) -> str:
262
+ if channels.size == 0:
263
+ return ""
264
+ plural = "s" if channels.size > 1 else ""
265
+ return (
266
+ f"an unusually {direction} ROI estimate (for channel{plural} "
267
+ f"{', '.join(f'`{channel}`' for channel in channels)})"
268
+ )
269
+
270
+
271
+ def _inf_prior_quantiles_channels(
272
+ channels: np.ndarray,
273
+ lo_roi_quantiles: np.ndarray,
274
+ hi_roi_quantiles: np.ndarray,
275
+ ) -> np.ndarray:
276
+ """Returns channels with infinite prior quantiles.
277
+
278
+ Args:
279
+ channels: The names of the channels.
280
+ lo_roi_quantiles: The lower quantiles of the ROI prior.
281
+ hi_roi_quantiles: The upper quantiles of the ROI prior.
282
+
283
+ Returns:
284
+ An array of channel names with infinite prior quantiles.
285
+ """
286
+ inf_mask = np.isinf(lo_roi_quantiles) | np.isinf(hi_roi_quantiles)
287
+ return channels[inf_mask]
288
+
289
+
290
@dataclasses.dataclass
class _ROIConsistencyChannelData:
  """Per-channel intermediate values used by the ROI Consistency Check.

  Attributes:
    prior_roi_los: Lower ROI-prior quantile for each entry of `all_channels`.
    prior_roi_his: Upper ROI-prior quantile for each entry of `all_channels`.
    posterior_means: Posterior ROI mean for each entry of `all_channels`.
    all_channels: Names of the channels whose prior quantiles could be
      computed, media channels (`roi_m`) first, then reach-and-frequency
      channels (`roi_rf`). Channels whose quantiles failed are not included.
    inf_channels: Channels whose lower or upper prior quantile is infinite.
    low_roi_channels: Channels whose posterior mean is below the prior's lower
      quantile.
    high_roi_channels: Channels whose posterior mean is above the prior's upper
      quantile.
    quantile_not_defined_channels: Names of channels whose prior has no
      implemented quantile method.
    quantile_not_defined_parameters: The priors for which the quantile method
      is not implemented.
  """

  # Arrays aligned element-wise with `all_channels`.
  prior_roi_los: np.ndarray
  prior_roi_his: np.ndarray
  posterior_means: np.ndarray
  all_channels: np.ndarray
  # Channel-name groupings used to classify each channel.
  inf_channels: np.ndarray
  low_roi_channels: np.ndarray
  high_roi_channels: np.ndarray
  # Channels and priors that were skipped because `quantile` is unavailable.
  quantile_not_defined_channels: np.ndarray
  quantile_not_defined_parameters: list[backend.tfd.Distribution] = (
      dataclasses.field(default_factory=list)
  )
327
+
328
+
329
def _get_roi_consistency_channel_data(
    prior_rois: Sequence[backend.tfd.Distribution],
    posterior_rois: Sequence[backend.tfd.Distribution],
    channels_names: Sequence[Sequence[str]],
    prior_lower_quantile: float,
    prior_upper_quantile: float,
) -> _ROIConsistencyChannelData:
  """Returns the channel-level data for the ROI Consistency Check.

  Channel groups are processed in order. For every group whose prior supports
  `quantile`, the lower/upper prior quantiles are collected, along with the
  posterior mean for each channel. The quantiles are broadcast to one value
  per channel. Groups whose prior raises `NotImplementedError` from `quantile`
  are recorded as "quantile not defined" and otherwise skipped.

  Args:
    prior_rois: The ROI priors, one per channel group, in the same order as
      `channels_names`.
    posterior_rois: The ROI posterior samples, one per channel group, in the
      same order as `channels_names`. They are averaged over their first two
      axes. NOTE(review): annotated as distributions, but `ROIConsistencyCheck`
      passes posterior variables (`inference_data.posterior.roi_m`/`roi_rf`).
    channels_names: The channel names per group, with media channels
      (`roi_m`) followed by any reach and frequency (RF) channels (`roi_rf`).
    prior_lower_quantile: The lower quantile level of the ROI prior.
    prior_upper_quantile: The upper quantile level of the ROI prior.

  Returns:
    A _ROIConsistencyChannelData object containing the channel-level data for
    the ROI Consistency Check.
  """
  prior_roi_los_parts = []
  prior_roi_his_parts = []
  posterior_means_parts = []
  all_channels_parts = []
  quantile_not_defined_parameters = []
  quantile_not_defined_channels = []

  for prior_roi, posterior_roi, channels in zip(
      prior_rois, posterior_rois, channels_names
  ):
    # Only the quantile calls may legitimately be unsupported. Keep the `try`
    # narrow so that a NotImplementedError raised by unrelated code is not
    # misreported as "quantile not defined".
    try:
      prior_roi_lo = prior_roi.quantile(prior_lower_quantile)
      prior_roi_hi = prior_roi.quantile(prior_upper_quantile)
    except NotImplementedError:
      quantile_not_defined_parameters.append(prior_roi)
      quantile_not_defined_channels.extend(channels)
      continue

    # One posterior mean per channel, averaging over the first two axes.
    posterior_mean = np.mean(posterior_roi, axis=(0, 1))

    # Quantiles may be a single shared value or one per channel; broadcast so
    # every channel has its own entry.
    n_channels = len(channels)
    prior_roi_los_parts.append(
        np.broadcast_to(prior_roi_lo, shape=(n_channels,))
    )
    prior_roi_his_parts.append(
        np.broadcast_to(prior_roi_hi, shape=(n_channels,))
    )
    posterior_means_parts.append(posterior_mean)
    all_channels_parts.append(channels)

  if prior_roi_los_parts:
    prior_roi_los = np.concatenate(prior_roi_los_parts)
    prior_roi_his = np.concatenate(prior_roi_his_parts)
    posterior_means = np.concatenate(posterior_means_parts)
    all_channels = np.concatenate(all_channels_parts)
  else:
    # No group had computable quantiles.
    prior_roi_los = np.array([])
    prior_roi_his = np.array([])
    posterior_means = np.array([])
    all_channels = np.array([])

  inf_channels = _inf_prior_quantiles_channels(
      channels=all_channels,
      lo_roi_quantiles=prior_roi_los,
      hi_roi_quantiles=prior_roi_his,
  )
  low_roi_channels = all_channels[posterior_means < prior_roi_los]
  high_roi_channels = all_channels[posterior_means > prior_roi_his]

  return _ROIConsistencyChannelData(
      prior_roi_los=prior_roi_los,
      prior_roi_his=prior_roi_his,
      posterior_means=posterior_means,
      all_channels=all_channels,
      inf_channels=inf_channels,
      low_roi_channels=low_roi_channels,
      high_roi_channels=high_roi_channels,
      quantile_not_defined_parameters=quantile_not_defined_parameters,
      quantile_not_defined_channels=np.array(quantile_not_defined_channels),
  )
414
+
415
+
416
def _compute_channel_results(
    channel_data: _ROIConsistencyChannelData,
) -> list[results.ROIConsistencyChannelResult]:
  """Builds one ROI Consistency result per channel.

  Channels whose prior has no quantile method come first, with empty details.
  Every remaining channel is classified by the first matching rule, in this
  priority order: infinite prior quantile, ROI below the prior's lower
  quantile, ROI above its upper quantile, otherwise pass. The details carry
  the prior quantiles and posterior mean for that channel.
  """
  cases = results.ROIConsistencyChannelCases
  out = [
      results.ROIConsistencyChannelResult(
          case=cases.QUANTILE_NOT_DEFINED,
          details={},
          channel_name=name,
      )
      for name in channel_data.quantile_not_defined_channels
  ]

  # Ordered (members, case) rules; the first rule containing the channel wins.
  rules = (
      (channel_data.inf_channels, cases.PRIOR_ROI_QUANTILE_INF),
      (channel_data.low_roi_channels, cases.ROI_LOW),
      (channel_data.high_roi_channels, cases.ROI_HIGH),
  )
  for idx, name in enumerate(channel_data.all_channels):
    case = next(
        (rule_case for members, rule_case in rules if name in members),
        cases.ROI_PASS,
    )
    out.append(
        results.ROIConsistencyChannelResult(
            case=case,
            details={
                review_constants.PRIOR_ROI_LO: channel_data.prior_roi_los[idx],
                review_constants.PRIOR_ROI_HI: channel_data.prior_roi_his[idx],
                review_constants.POSTERIOR_ROI_MEAN: (
                    channel_data.posterior_means[idx]
                ),
            },
            channel_name=name,
        )
    )
  return out
454
+
455
+
456
def _compute_aggregate_result(
    channel_data: _ROIConsistencyChannelData,
) -> results.ROIConsistencyCheckResult:
  """Rolls channel-level findings up into the ROI Consistency Check result.

  The details always contain three messages, each empty when not applicable:
  priors without a quantile method, channels with infinite prior quantiles,
  and channels whose posterior mean falls outside the prior quantile range.
  The aggregate case is REVIEW if any message is non-empty, otherwise PASS.
  """
  channel_results = _compute_channel_results(channel_data=channel_data)

  # Channel Case 5: QUANTILE_NOT_DEFINED
  quantile_msg = ""
  if channel_data.quantile_not_defined_parameters:
    quantile_msg = (
        "The quantile method is not defined for the following parameters:"
        f" {channel_data.quantile_not_defined_parameters}. The ROI"
        " Consistency check cannot be performed for these parameters."
    )

  # Channel Case 4: PRIOR_ROI_QUANTILE_INF
  inf_msg = ""
  if channel_data.inf_channels.size:
    inf_msg = (
        "Prior ROI quantiles are infinite for channels:"
        f" {', '.join(channel_data.inf_channels)}"
    )

  # Channel Cases 2-3: ROI_LOW, ROI_HIGH
  low_high_msg = ""
  if channel_data.low_roi_channels.size or channel_data.high_roi_channels.size:
    pieces = [
        msg
        for msg in (
            _format_roi_channels_msg(channel_data.low_roi_channels, "low"),
            _format_roi_channels_msg(channel_data.high_roi_channels, "high"),
        )
        if msg
    ]
    low_high_msg = (
        f"We've detected {' and '.join(pieces)} where the posterior point"
        " estimate falls into the extreme tail of your custom prior."
    )

  aggregate_details = {
      review_constants.QUANTILE_NOT_DEFINED_MSG: quantile_msg,
      review_constants.INF_CHANNELS_MSG: inf_msg,
      review_constants.LOW_HIGH_CHANNELS_MSG: low_high_msg,
  }
  needs_review = any(aggregate_details.values())
  aggregate_case = (
      results.ROIConsistencyAggregateCases.REVIEW
      if needs_review
      else results.ROIConsistencyAggregateCases.PASS
  )
  return results.ROIConsistencyCheckResult(
      case=aggregate_case,
      details=aggregate_details,
      channel_results=channel_results,
  )
513
+
514
+
515
class ROIConsistencyCheck(
    BaseCheck[configs.ROIConsistencyConfig, results.ROIConsistencyCheckResult]
):
  """Flags channels whose posterior ROI mean sits in the tails of its prior."""

  def run(self) -> results.ROIConsistencyCheckResult:
    """Collects media and RF ROI priors/posteriors and grades them.

    Media channels are included when the posterior has a media-channel
    coordinate, and RF channels when it has an RF-channel coordinate. Media
    entries always precede RF entries.
    """
    posterior = self._meridian.inference_data.posterior
    prior_rois = []
    posterior_rois = []
    channel_names = []

    if constants.MEDIA_CHANNEL in posterior.coords:
      prior_rois.append(self._meridian.model_spec.prior.roi_m)
      posterior_rois.append(posterior.roi_m)
      channel_names.append(posterior.media_channel.values)

    if constants.RF_CHANNEL in posterior.coords:
      prior_rois.append(self._meridian.model_spec.prior.roi_rf)
      posterior_rois.append(posterior.roi_rf)
      channel_names.append(posterior.rf_channel.values)

    return _compute_aggregate_result(
        channel_data=_get_roi_consistency_channel_data(
            prior_rois=prior_rois,
            posterior_rois=posterior_rois,
            channels_names=channel_names,
            prior_lower_quantile=self._config.prior_lower_quantile,
            prior_upper_quantile=self._config.prior_upper_quantile,
        )
    )
549
+
550
+
551
+ # ==============================================================================
552
+ # Check: Prior-Posterior Shift
553
+ # ==============================================================================
554
+ def _bootstrap(x: np.ndarray, n_bootstraps: int) -> np.ndarray:
555
+ """Performs non-parametric bootstrap resampling on the columns of x."""
556
+ n_rows, n_cols = x.shape
557
+ x_bs = np.empty((n_bootstraps, n_rows, n_cols))
558
+ for i in range(n_bootstraps):
559
+ col_indices = np.random.choice(n_cols, n_cols, replace=True)
560
+ x_bs[i, :, :] = x[:, col_indices]
561
+ return x_bs
562
+
563
+
564
def _calculate_new_statistics_from_samples(
    mmm: model.Meridian, n_bootstraps: int, var_name: str, n_channels: int
) -> dict[str, np.ndarray]:
  """Bootstraps the mean, median, Q1 and Q3 of a posterior variable.

  Args:
    mmm: Fitted model whose `inference_data.posterior` holds `var_name`.
    n_bootstraps: Number of bootstrap replicates.
    var_name: Name of the posterior variable (e.g. an ROI variable).
    n_channels: Number of channels along the variable's channel axis.

  Returns:
    A dict keyed by `review_constants.MEAN`, `MEDIAN`, `Q1` and `Q3` (in that
    order). Each value has shape (n_bootstraps, n_channels).
  """
  posterior = mmm.inference_data.posterior
  n_samples = len(posterior.coords[constants.CHAIN]) * len(
      posterior.coords[constants.DRAW]
  )

  # Flatten all draws into one sample axis and put channels first, so that the
  # bootstrap resamples draws (columns). Assumes the variable is laid out as
  # (chain, draw, channel) — TODO confirm.
  flat = np.reshape(posterior.variables[var_name].values, (n_samples, n_channels))
  samples_by_channel = flat.T
  resampled = _bootstrap(samples_by_channel, n_bootstraps)  # (bs, ch, samples)

  stats = {review_constants.MEAN: np.mean(resampled, axis=-1)}
  for key, level in (
      (review_constants.MEDIAN, 0.5),
      (review_constants.Q1, 0.25),
      (review_constants.Q3, 0.75),
  ):
    stats[key] = np.quantile(resampled, q=level, axis=-1)
  return stats
593
+
594
+
595
+ def _get_shifted_mask(
596
+ posterior_stat: np.ndarray, prior_stat: np.ndarray, alpha: float
597
+ ) -> np.ndarray:
598
+ """Returns a boolean mask indicating which channels have a significant shift."""
599
+ prior_stat_b = prior_stat[np.newaxis, ...]
600
+ shift_1 = np.mean(posterior_stat > prior_stat_b, axis=0) < alpha
601
+ shift_2 = np.mean(posterior_stat < prior_stat_b, axis=0) < alpha
602
+ return shift_1 | shift_2
603
+
604
+
605
class PriorPosteriorShiftCheck(
    BaseCheck[
        configs.PriorPosteriorShiftConfig,
        results.PriorPosteriorShiftCheckResult,
    ]
):
  """Checks for a significant shift between prior and posterior of ROI.

  For each channel, bootstrap replicates of posterior statistics (mean,
  median, Q1, Q3) are compared with the same statistics of the ROI prior.
  A channel counts as shifted when, for any statistic the prior can compute,
  fewer than `config.alpha` of the replicates lie above (or below) the prior
  value. Any channel without a detected shift makes the aggregate REVIEW.
  """

  # Tuple of (channel_results, no_shift_channels)
  _CHANNEL_TYPE_RESULT = tuple[
      list[results.PriorPosteriorShiftChannelResult],
      list[str],
  ]

  @staticmethod
  def _prior_statistics(
      prior_dist: backend.tfd.Distribution,
  ) -> dict[str, object]:
    """Returns the prior's mean, median, Q1 and Q3, where computable.

    Statistics whose underlying method raises `NotImplementedError` are
    omitted, so the check only compares statistics the prior supports.
    """
    getters = (
        (review_constants.MEAN, prior_dist.mean),
        (review_constants.MEDIAN, lambda: prior_dist.quantile(0.5)),
        (review_constants.Q1, lambda: prior_dist.quantile(0.25)),
        (review_constants.Q3, lambda: prior_dist.quantile(0.75)),
    )
    stats = {}
    for key, compute in getters:
      try:
        stats[key] = compute()
      except NotImplementedError:
        continue
    return stats

  def _run_for_channel_type(self, channel_type: str) -> _CHANNEL_TYPE_RESULT:
    """Runs the prior-posterior shift check for a given channel type.

    Args:
      channel_type: The channel type ('media_channel' or 'rf_channel') to run
        the check for.

    Returns:
      A tuple of (`channel_results`, `no_shift_channels`). Both lists are empty
      when the posterior has no coordinate for `channel_type`.
    """
    posterior = self._meridian.inference_data.posterior
    if channel_type not in posterior.coords:
      return [], []

    channel_names = posterior[channel_type].values
    n_channels = len(channel_names)
    if channel_type == constants.MEDIA_CHANNEL:
      var_name = constants.ROI_M
      prior_dist = self._meridian.model_spec.prior.roi_m
    else:
      var_name = constants.ROI_RF
      prior_dist = self._meridian.model_spec.prior.roi_rf

    prior_stats = self._prior_statistics(prior_dist)
    post_stats = _calculate_new_statistics_from_samples(
        self._meridian, self._config.n_bootstraps, var_name, n_channels
    )

    # NOTE(review): if the prior supports none of the statistics, no channel
    # can be flagged and every channel is reported as NO_SHIFT.
    alpha = self._config.alpha
    any_shift = np.zeros(n_channels, dtype=bool)
    for key, prior_stat in prior_stats.items():
      current_shift = _get_shifted_mask(post_stats[key], prior_stat, alpha)
      any_shift = any_shift | current_shift

    channel_results = []
    no_shift_channels = []
    for i, channel_name in enumerate(channel_names):
      shifted = any_shift[i]
      case = (
          results.PriorPosteriorShiftChannelCases.SHIFT
          if shifted
          else results.PriorPosteriorShiftChannelCases.NO_SHIFT
      )
      if not shifted:
        no_shift_channels.append(channel_name)
      channel_results.append(
          results.PriorPosteriorShiftChannelResult(
              case=case, details={}, channel_name=channel_name
          )
      )
    return channel_results, no_shift_channels

  def _aggregate(
      self,
      *channel_type_results: _CHANNEL_TYPE_RESULT,
  ) -> results.PriorPosteriorShiftCheckResult:
    """Aggregates results from multiple channel types.

    The aggregate case is REVIEW, listing the no-shift channels in backticks,
    when any channel shows no shift; otherwise it is PASS with empty details.
    """
    channel_results = []
    no_shift_channels = []
    for results_part, channels_part in channel_type_results:
      channel_results.extend(results_part)
      no_shift_channels.extend(channels_part)

    if no_shift_channels:
      agg_case = results.PriorPosteriorShiftAggregateCases.REVIEW
      final_details = {
          "channels_str": ", ".join(
              f"`{channel}`" for channel in no_shift_channels
          )
      }
    else:
      agg_case = results.PriorPosteriorShiftAggregateCases.PASS
      final_details = {}

    return results.PriorPosteriorShiftCheckResult(
        case=agg_case, details=final_details, channel_results=channel_results
    )

  def run(self) -> results.PriorPosteriorShiftCheckResult:
    """Runs the check for media and RF channels and aggregates the results.

    The bootstrap draws from NumPy's global RNG. That RNG is seeded with
    `config.seed` for reproducibility, and the caller's RNG state is restored
    afterwards. Running this check therefore leaves process-wide NumPy
    randomness untouched.
    """
    saved_rng_state = np.random.get_state()
    np.random.seed(self._config.seed)
    try:
      media_results = self._run_for_channel_type(constants.MEDIA_CHANNEL)
      rf_results = self._run_for_channel_type(constants.RF_CHANNEL)
    finally:
      np.random.set_state(saved_rng_state)
    return self._aggregate(media_results, rf_results)