google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/checks.py +118 -116
  7. meridian/analysis/review/constants.py +3 -3
  8. meridian/analysis/review/results.py +131 -68
  9. meridian/analysis/review/reviewer.py +8 -23
  10. meridian/analysis/summarizer.py +6 -1
  11. meridian/analysis/test_utils.py +2898 -2538
  12. meridian/analysis/visualizer.py +28 -9
  13. meridian/backend/__init__.py +106 -0
  14. meridian/constants.py +1 -0
  15. meridian/data/input_data.py +30 -52
  16. meridian/data/input_data_builder.py +2 -9
  17. meridian/data/test_utils.py +25 -41
  18. meridian/data/validator.py +48 -0
  19. meridian/mlflow/autolog.py +19 -9
  20. meridian/model/adstock_hill.py +3 -5
  21. meridian/model/context.py +134 -0
  22. meridian/model/eda/constants.py +334 -4
  23. meridian/model/eda/eda_engine.py +724 -312
  24. meridian/model/eda/eda_outcome.py +177 -33
  25. meridian/model/model.py +159 -110
  26. meridian/model/model_test_data.py +38 -0
  27. meridian/model/posterior_sampler.py +103 -62
  28. meridian/model/prior_sampler.py +114 -94
  29. meridian/model/spec.py +23 -14
  30. meridian/templates/card.html.jinja +9 -7
  31. meridian/templates/chart.html.jinja +1 -6
  32. meridian/templates/finding.html.jinja +19 -0
  33. meridian/templates/findings.html.jinja +33 -0
  34. meridian/templates/formatter.py +41 -5
  35. meridian/templates/formatter_test.py +127 -0
  36. meridian/templates/style.css +66 -9
  37. meridian/templates/style.scss +85 -4
  38. meridian/templates/table.html.jinja +1 -0
  39. meridian/version.py +1 -1
  40. scenarioplanner/linkingapi/constants.py +1 -1
  41. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  42. schema/processors/marketing_processor.py +11 -10
  43. schema/processors/model_processor.py +4 -1
  44. schema/serde/distribution.py +12 -7
  45. schema/serde/hyperparameters.py +54 -107
  46. schema/serde/meridian_serde.py +12 -3
  47. schema/utils/__init__.py +1 -0
  48. schema/utils/proto_enum_converter.py +127 -0
  49. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
  50. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
@@ -15,11 +15,15 @@
15
15
  """Module for MCMC sampling of posterior distributions in a Meridian model."""
16
16
 
17
17
  from collections.abc import Mapping, Sequence
18
- from typing import TYPE_CHECKING
18
+ import functools
19
+ from typing import Optional, TYPE_CHECKING
20
+ import warnings
19
21
 
20
22
  import arviz as az
21
23
  from meridian import backend
22
24
  from meridian import constants
25
+ from meridian.model import context
26
+ from meridian.model import equations
23
27
  import numpy as np
24
28
 
25
29
  if TYPE_CHECKING:
@@ -72,34 +76,39 @@ def _get_tau_g(
72
76
  return backend.tfd.Deterministic(tau_g, name="tau_g")
73
77
 
74
78
 
75
- def _joint_dist_unpinned(mmm: "model.Meridian"):
79
+ def _joint_dist_unpinned(
80
+ model_context: context.ModelContext,
81
+ model_equations: equations.ModelEquations,
82
+ ):
76
83
  """Returns unpinned joint distribution."""
77
84
 
78
85
  # This lists all the derived properties and states of this Meridian object
79
86
  # that are referenced by the joint distribution coroutine.
80
87
  # That is, these are the list of captured parameters.
81
- prior_broadcast = mmm.prior_broadcast
82
- baseline_geo_idx = mmm.baseline_geo_idx
83
- knot_info = mmm.knot_info
84
- n_geos = mmm.n_geos
85
- n_times = mmm.n_times
86
- n_media_channels = mmm.n_media_channels
87
- n_rf_channels = mmm.n_rf_channels
88
- n_organic_media_channels = mmm.n_organic_media_channels
89
- n_organic_rf_channels = mmm.n_organic_rf_channels
90
- n_controls = mmm.n_controls
91
- n_non_media_channels = mmm.n_non_media_channels
92
- holdout_id = mmm.holdout_id
93
- media_tensors = mmm.media_tensors
94
- rf_tensors = mmm.rf_tensors
95
- organic_media_tensors = mmm.organic_media_tensors
96
- organic_rf_tensors = mmm.organic_rf_tensors
97
- controls_scaled = mmm.controls_scaled
98
- non_media_treatments_normalized = mmm.non_media_treatments_normalized
99
- media_effects_dist = mmm.media_effects_dist
100
- adstock_hill_media_fn = mmm.adstock_hill_media
101
- adstock_hill_rf_fn = mmm.adstock_hill_rf
102
- total_outcome = mmm.total_outcome
88
+ prior_broadcast = model_context.prior_broadcast
89
+ baseline_geo_idx = model_context.baseline_geo_idx
90
+ knot_info = model_context.knot_info
91
+ n_geos = model_context.n_geos
92
+ n_times = model_context.n_times
93
+ n_media_channels = model_context.n_media_channels
94
+ n_rf_channels = model_context.n_rf_channels
95
+ n_organic_media_channels = model_context.n_organic_media_channels
96
+ n_organic_rf_channels = model_context.n_organic_rf_channels
97
+ n_controls = model_context.n_controls
98
+ n_non_media_channels = model_context.n_non_media_channels
99
+ holdout_id = model_context.holdout_id
100
+ media_tensors = model_context.media_tensors
101
+ rf_tensors = model_context.rf_tensors
102
+ organic_media_tensors = model_context.organic_media_tensors
103
+ organic_rf_tensors = model_context.organic_rf_tensors
104
+ controls_scaled = model_context.controls_scaled
105
+ non_media_treatments_normalized = (
106
+ model_context.non_media_treatments_normalized
107
+ )
108
+ media_effects_dist = model_context.media_effects_dist
109
+ adstock_hill_media_fn = model_equations.adstock_hill_media
110
+ adstock_hill_rf_fn = model_equations.adstock_hill_rf
111
+ total_outcome = model_context.total_outcome
103
112
 
104
113
  # Sample directly from prior.
105
114
  knot_values = yield prior_broadcast.knot_values
@@ -142,9 +151,9 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
142
151
  alpha=alpha_m,
143
152
  ec=ec_m,
144
153
  slope=slope_m,
145
- decay_functions=mmm.adstock_decay_spec.media,
154
+ decay_functions=model_context.adstock_decay_spec.media,
146
155
  )
147
- prior_type = mmm.model_spec.effective_media_prior_type
156
+ prior_type = model_context.model_spec.effective_media_prior_type
148
157
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
149
158
  beta_m = yield prior_broadcast.beta_m
150
159
  else:
@@ -160,14 +169,14 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
160
169
  treatment_parameter_m * media_tensors.prior_denominator
161
170
  )
162
171
  linear_predictor_counterfactual_difference = (
163
- mmm.linear_predictor_counterfactual_difference_media(
172
+ model_equations.linear_predictor_counterfactual_difference_media(
164
173
  media_transformed=media_transformed,
165
174
  alpha_m=alpha_m,
166
175
  ec_m=ec_m,
167
176
  slope_m=slope_m,
168
177
  )
169
178
  )
170
- beta_m_value = mmm.calculate_beta_x(
179
+ beta_m_value = model_equations.calculate_beta_x(
171
180
  is_non_media=False,
172
181
  incremental_outcome_x=incremental_outcome_m,
173
182
  linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
@@ -208,10 +217,10 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
208
217
  alpha=alpha_rf,
209
218
  ec=ec_rf,
210
219
  slope=slope_rf,
211
- decay_functions=mmm.adstock_decay_spec.rf,
220
+ decay_functions=model_context.adstock_decay_spec.rf,
212
221
  )
213
222
 
214
- prior_type = mmm.model_spec.effective_rf_prior_type
223
+ prior_type = model_context.model_spec.effective_rf_prior_type
215
224
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
216
225
  beta_rf = yield prior_broadcast.beta_rf
217
226
  else:
@@ -227,14 +236,14 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
227
236
  treatment_parameter_rf * rf_tensors.prior_denominator
228
237
  )
229
238
  linear_predictor_counterfactual_difference = (
230
- mmm.linear_predictor_counterfactual_difference_rf(
239
+ model_equations.linear_predictor_counterfactual_difference_rf(
231
240
  rf_transformed=rf_transformed,
232
241
  alpha_rf=alpha_rf,
233
242
  ec_rf=ec_rf,
234
243
  slope_rf=slope_rf,
235
244
  )
236
245
  )
237
- beta_rf_value = mmm.calculate_beta_x(
246
+ beta_rf_value = model_equations.calculate_beta_x(
238
247
  is_non_media=False,
239
248
  incremental_outcome_x=incremental_outcome_rf,
240
249
  linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
@@ -274,15 +283,15 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
274
283
  alpha=alpha_om,
275
284
  ec=ec_om,
276
285
  slope=slope_om,
277
- decay_functions=mmm.adstock_decay_spec.organic_media,
286
+ decay_functions=model_context.adstock_decay_spec.organic_media,
278
287
  )
279
- prior_type = mmm.model_spec.organic_media_prior_type
288
+ prior_type = model_context.model_spec.organic_media_prior_type
280
289
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
281
290
  beta_om = yield prior_broadcast.beta_om
282
291
  elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
283
292
  contribution_om = yield prior_broadcast.contribution_om
284
293
  incremental_outcome_om = contribution_om * total_outcome
285
- beta_om_value = mmm.calculate_beta_x(
294
+ beta_om_value = model_equations.calculate_beta_x(
286
295
  is_non_media=False,
287
296
  incremental_outcome_x=incremental_outcome_om,
288
297
  linear_predictor_counterfactual_difference=organic_media_transformed,
@@ -325,16 +334,16 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
325
334
  alpha=alpha_orf,
326
335
  ec=ec_orf,
327
336
  slope=slope_orf,
328
- decay_functions=mmm.adstock_decay_spec.organic_rf,
337
+ decay_functions=model_context.adstock_decay_spec.organic_rf,
329
338
  )
330
339
 
331
- prior_type = mmm.model_spec.organic_rf_prior_type
340
+ prior_type = model_context.model_spec.organic_rf_prior_type
332
341
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
333
342
  beta_orf = yield prior_broadcast.beta_orf
334
343
  elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
335
344
  contribution_orf = yield prior_broadcast.contribution_orf
336
345
  incremental_outcome_orf = contribution_orf * total_outcome
337
- beta_orf_value = mmm.calculate_beta_x(
346
+ beta_orf_value = model_equations.calculate_beta_x(
338
347
  is_non_media=False,
339
348
  incremental_outcome_x=incremental_outcome_orf,
340
349
  linear_predictor_counterfactual_difference=organic_rf_transformed,
@@ -382,26 +391,26 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
382
391
  "gtc,gc->gt", controls_scaled, gamma_gc
383
392
  )
384
393
 
385
- if mmm.non_media_treatments is not None:
394
+ if model_context.non_media_treatments is not None:
386
395
  xi_n = yield prior_broadcast.xi_n
387
396
  gamma_gn_dev = yield backend.tfd.Sample(
388
397
  backend.tfd.Normal(0, 1),
389
398
  [n_geos, n_non_media_channels],
390
399
  name=constants.GAMMA_GN_DEV,
391
400
  )
392
- prior_type = mmm.model_spec.non_media_treatments_prior_type
401
+ prior_type = model_context.model_spec.non_media_treatments_prior_type
393
402
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
394
403
  gamma_n = yield prior_broadcast.gamma_n
395
404
  elif prior_type == constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION:
396
405
  contribution_n = yield prior_broadcast.contribution_n
397
406
  incremental_outcome_n = contribution_n * total_outcome
398
- baseline_scaled = mmm.non_media_transformer.forward( # pytype: disable=attribute-error
399
- mmm.compute_non_media_treatments_baseline()
407
+ baseline_scaled = model_context.non_media_transformer.forward( # pytype: disable=attribute-error
408
+ model_equations.compute_non_media_treatments_baseline()
400
409
  )
401
410
  linear_predictor_counterfactual_difference = (
402
411
  non_media_treatments_normalized - baseline_scaled
403
412
  )
404
- gamma_n_value = mmm.calculate_beta_x(
413
+ gamma_n_value = model_equations.calculate_beta_x(
405
414
  is_non_media=True,
406
415
  incremental_outcome_x=incremental_outcome_n,
407
416
  linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
@@ -439,10 +448,36 @@ def _joint_dist_unpinned(mmm: "model.Meridian"):
439
448
  class PosteriorMCMCSampler:
440
449
  """A callable that samples from posterior distributions using MCMC."""
441
450
 
442
- def __init__(self, meridian: "model.Meridian"):
443
- self._meridian = meridian
451
+ # TODO: Deprecate direct injection of `model.Meridian`.
452
+ def __init__(
453
+ self,
454
+ meridian: Optional["model.Meridian"] = None,
455
+ *,
456
+ model_context: context.ModelContext | None = None,
457
+ ):
458
+ if meridian is not None:
459
+ warnings.warn(
460
+ "Initializing PosteriorMCMCSampler with a Meridian object is"
461
+ " deprecated and will be removed in a future version. Please use"
462
+ " `model_context` instead.",
463
+ DeprecationWarning,
464
+ stacklevel=2,
465
+ )
466
+ self._meridian = meridian
467
+ self._model_context = meridian.model_context
468
+ elif model_context is not None:
469
+ self._meridian = None
470
+ self._model_context = model_context
471
+ else:
472
+ raise ValueError(
473
+ "Either `meridian` or `model_context` must be provided."
474
+ )
444
475
  self._joint_dist = None
445
476
 
477
+ @functools.cached_property
478
+ def _model_equations(self) -> equations.ModelEquations:
479
+ return equations.ModelEquations(self._model_context)
480
+
446
481
  def __getstate__(self):
447
482
  state = self.__dict__.copy()
448
483
  # Exclude unpickleable objects.
@@ -454,29 +489,33 @@ class PosteriorMCMCSampler:
454
489
  self.__dict__.update(state)
455
490
  self._joint_dist = None
456
491
 
492
+ # TODO: Remove this property in favor of using `model_context`
493
+ # and `model_equations` directly.
457
494
  @property
458
- def model(self) -> "model.Meridian":
495
+ def model(self) -> Optional["model.Meridian"]:
459
496
  return self._meridian
460
497
 
461
498
  def _joint_dist_unpinned_fn(self):
462
- return _joint_dist_unpinned(self.model)
499
+ return _joint_dist_unpinned(self._model_context, self._model_equations)
463
500
 
464
501
  def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
465
502
  """Builds a `JointDistributionCoroutineAutoBatched` function for MCMC."""
466
- mmm = self.model
467
- mmm.populate_cached_properties()
503
+ self._model_context.populate_cached_properties()
468
504
  return backend.tfd.JointDistributionCoroutineAutoBatched(
469
505
  self._joint_dist_unpinned_fn
470
506
  )
471
507
 
472
508
  def _get_joint_dist(self) -> backend.tfd.Distribution:
509
+ """Returns a joint distribution for MCMC sampling."""
473
510
  if self._joint_dist is None:
474
- mmm = self.model
475
- y = (
476
- backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
477
- if mmm.holdout_id is not None
478
- else mmm.kpi_scaled
479
- )
511
+ if self._model_context.holdout_id is not None:
512
+ y = backend.where(
513
+ self._model_context.holdout_id,
514
+ 0.0,
515
+ self._model_context.kpi_scaled,
516
+ )
517
+ else:
518
+ y = self._model_context.kpi_scaled
480
519
  self._joint_dist = self._get_joint_dist_unpinned().experimental_pin(y=y)
481
520
  return self._joint_dist
482
521
 
@@ -495,7 +534,7 @@ class PosteriorMCMCSampler:
495
534
  parallel_iterations: int = 10,
496
535
  seed: Sequence[int] | int | None = None,
497
536
  **pins,
498
- ) -> None:
537
+ ) -> az.InferenceData:
499
538
  """Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.
500
539
 
501
540
  For more information about the arguments, see [`windowed_adaptive_nuts`]
@@ -548,6 +587,10 @@ class PosteriorMCMCSampler:
548
587
  **pins: These are used to condition the provided joint distribution, and
549
588
  are passed directly to `joint_dist.experimental_pin(**pins)`.
550
589
 
590
+ Returns:
591
+ An `arviz.InferenceData` object containing the posterior samples, trace
592
+ metrics, and sampling statistics.
593
+
551
594
  Throws:
552
595
  MCMCOOMError: If the model is out of memory. Try reducing `n_keep` or pass
553
596
  a list of integers as `n_chains` to sample chains serially. For more
@@ -604,10 +647,10 @@ class PosteriorMCMCSampler:
604
647
  if k not in constants.UNSAVED_PARAMETERS
605
648
  }
606
649
  # Create Arviz InferenceData for posterior draws.
607
- posterior_coords = self.model.create_inference_data_coords(
650
+ posterior_coords = self._model_context.create_inference_data_coords(
608
651
  total_chains, n_keep
609
652
  )
610
- posterior_dims = self.model.create_inference_data_dims()
653
+ posterior_dims = self._model_context.create_inference_data_dims()
611
654
  infdata_posterior = az.convert_to_inference_data(
612
655
  mcmc_states, coords=posterior_coords, dims=posterior_dims
613
656
  )
@@ -669,7 +712,5 @@ class PosteriorMCMCSampler:
669
712
  dims=sample_stats_dims,
670
713
  group="sample_stats",
671
714
  )
672
- posterior_inference_data = az.concat(
673
- infdata_posterior, infdata_trace, infdata_sample_stats
674
- )
675
- self.model.inference_data.extend(posterior_inference_data, join="right")
715
+
716
+ return az.concat(infdata_posterior, infdata_trace, infdata_sample_stats)