google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -15,12 +15,16 @@
15
15
  """Module for sampling prior distributions in a Meridian model."""
16
16
 
17
17
  from collections.abc import Mapping
18
- from typing import TYPE_CHECKING
18
+ import functools
19
+ from typing import Optional, TYPE_CHECKING
20
+ import warnings
19
21
 
20
- import arviz as az
21
22
  from meridian import backend
22
23
  from meridian import constants
24
+ from meridian.model import context
25
+ from meridian.model import equations
23
26
 
27
+ # TODO: Break this circular dependency.
24
28
  if TYPE_CHECKING:
25
29
  from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
26
30
 
@@ -63,8 +67,34 @@ def _get_tau_g(
63
67
  class PriorDistributionSampler:
64
68
  """A callable that samples from a model spec's prior distributions."""
65
69
 
66
- def __init__(self, meridian: "model.Meridian"):
67
- self._meridian = meridian
70
+ # TODO: Deprecate direct injection of `model.Meridian`.
71
+ def __init__(
72
+ self,
73
+ meridian: Optional["model.Meridian"] = None,
74
+ *,
75
+ model_context: context.ModelContext | None = None,
76
+ ):
77
+ if meridian is not None:
78
+ warnings.warn(
79
+ "Initializing PriorDistributionSampler with a Meridian object is"
80
+ " deprecated and will be removed in a future version. Please use"
81
+ " `model_context` instead.",
82
+ DeprecationWarning,
83
+ stacklevel=2,
84
+ )
85
+ self._meridian = meridian
86
+ self._model_context = meridian.model_context
87
+ elif model_context is not None:
88
+ self._meridian = None
89
+ self._model_context = model_context
90
+ else:
91
+ raise ValueError(
92
+ "Either `meridian` or `model_context` must be provided."
93
+ )
94
+
95
+ @functools.cached_property
96
+ def _model_equations(self) -> equations.ModelEquations:
97
+ return equations.ModelEquations(self._model_context)
68
98
 
69
99
  def _sample_media_priors(
70
100
  self,
@@ -82,9 +112,9 @@ class PriorDistributionSampler:
82
112
  n_media_channels]` or `[n_draws, n_media_channels]` containing the
83
113
  samples.
84
114
  """
85
- mmm = self._meridian
115
+ ctx = self._model_context
86
116
 
87
- prior = mmm.prior_broadcast
117
+ prior = ctx.prior_broadcast
88
118
  sample_shape = [1, n_draws]
89
119
 
90
120
  media_vars = {
@@ -103,11 +133,11 @@ class PriorDistributionSampler:
103
133
  }
104
134
  beta_gm_dev = backend.tfd.Sample(
105
135
  backend.tfd.Normal(0, 1),
106
- [mmm.n_geos, mmm.n_media_channels],
136
+ [ctx.n_geos, ctx.n_media_channels],
107
137
  name=constants.BETA_GM_DEV,
108
138
  ).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
109
139
 
110
- prior_type = mmm.model_spec.effective_media_prior_type
140
+ prior_type = ctx.model_spec.effective_media_prior_type
111
141
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
112
142
  media_vars[constants.BETA_M] = prior.beta_m.sample(
113
143
  sample_shape=sample_shape, seed=rng_handler.get_next_seed()
@@ -131,24 +161,22 @@ class PriorDistributionSampler:
131
161
  else:
132
162
  raise ValueError(f"Unsupported prior type: {prior_type}")
133
163
  incremental_outcome_m = (
134
- treatment_parameter_m * mmm.media_tensors.prior_denominator
164
+ treatment_parameter_m * ctx.media_tensors.prior_denominator
135
165
  )
136
- media_transformed = mmm.adstock_hill_media(
137
- media=mmm.media_tensors.media_scaled,
166
+ media_transformed = self._model_equations.adstock_hill_media(
167
+ media=ctx.media_tensors.media_scaled,
138
168
  alpha=media_vars[constants.ALPHA_M],
139
169
  ec=media_vars[constants.EC_M],
140
170
  slope=media_vars[constants.SLOPE_M],
141
- decay_functions=mmm.adstock_decay_spec.media,
171
+ decay_functions=ctx.adstock_decay_spec.media,
142
172
  )
143
- linear_predictor_counterfactual_difference = (
144
- mmm.linear_predictor_counterfactual_difference_media(
145
- media_transformed=media_transformed,
146
- alpha_m=media_vars[constants.ALPHA_M],
147
- ec_m=media_vars[constants.EC_M],
148
- slope_m=media_vars[constants.SLOPE_M],
149
- )
173
+ linear_predictor_counterfactual_difference = self._model_equations.linear_predictor_counterfactual_difference_media(
174
+ media_transformed=media_transformed,
175
+ alpha_m=media_vars[constants.ALPHA_M],
176
+ ec_m=media_vars[constants.EC_M],
177
+ slope_m=media_vars[constants.SLOPE_M],
150
178
  )
151
- beta_m_value = mmm.calculate_beta_x(
179
+ beta_m_value = self._model_equations.calculate_beta_x(
152
180
  is_non_media=False,
153
181
  incremental_outcome_x=incremental_outcome_m,
154
182
  linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
@@ -165,7 +193,7 @@ class PriorDistributionSampler:
165
193
  )
166
194
  beta_gm_value = (
167
195
  beta_eta_combined
168
- if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
196
+ if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
169
197
  else backend.exp(beta_eta_combined)
170
198
  )
171
199
  media_vars[constants.BETA_GM] = backend.tfd.Deterministic(
@@ -190,9 +218,9 @@ class PriorDistributionSampler:
190
218
  `[n_draws, n_geos, n_rf_channels]` or `[n_draws, n_rf_channels]`
191
219
  containing the samples.
192
220
  """
193
- mmm = self._meridian
221
+ ctx = self._model_context
194
222
 
195
- prior = mmm.prior_broadcast
223
+ prior = ctx.prior_broadcast
196
224
  sample_shape = [1, n_draws]
197
225
 
198
226
  rf_vars = {
@@ -211,11 +239,11 @@ class PriorDistributionSampler:
211
239
  }
212
240
  beta_grf_dev = backend.tfd.Sample(
213
241
  backend.tfd.Normal(0, 1),
214
- [mmm.n_geos, mmm.n_rf_channels],
242
+ [ctx.n_geos, ctx.n_rf_channels],
215
243
  name=constants.BETA_GRF_DEV,
216
244
  ).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
217
245
 
218
- prior_type = mmm.model_spec.effective_rf_prior_type
246
+ prior_type = ctx.model_spec.effective_rf_prior_type
219
247
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
220
248
  rf_vars[constants.BETA_RF] = prior.beta_rf.sample(
221
249
  sample_shape=sample_shape, seed=rng_handler.get_next_seed()
@@ -239,25 +267,25 @@ class PriorDistributionSampler:
239
267
  else:
240
268
  raise ValueError(f"Unsupported prior type: {prior_type}")
241
269
  incremental_outcome_rf = (
242
- treatment_parameter_rf * mmm.rf_tensors.prior_denominator
270
+ treatment_parameter_rf * ctx.rf_tensors.prior_denominator
243
271
  )
244
- rf_transformed = mmm.adstock_hill_rf(
245
- reach=mmm.rf_tensors.reach_scaled,
246
- frequency=mmm.rf_tensors.frequency,
272
+ rf_transformed = self._model_equations.adstock_hill_rf(
273
+ reach=ctx.rf_tensors.reach_scaled,
274
+ frequency=ctx.rf_tensors.frequency,
247
275
  alpha=rf_vars[constants.ALPHA_RF],
248
276
  ec=rf_vars[constants.EC_RF],
249
277
  slope=rf_vars[constants.SLOPE_RF],
250
- decay_functions=mmm.adstock_decay_spec.rf,
278
+ decay_functions=ctx.adstock_decay_spec.rf,
251
279
  )
252
280
  linear_predictor_counterfactual_difference = (
253
- mmm.linear_predictor_counterfactual_difference_rf(
281
+ self._model_equations.linear_predictor_counterfactual_difference_rf(
254
282
  rf_transformed=rf_transformed,
255
283
  alpha_rf=rf_vars[constants.ALPHA_RF],
256
284
  ec_rf=rf_vars[constants.EC_RF],
257
285
  slope_rf=rf_vars[constants.SLOPE_RF],
258
286
  )
259
287
  )
260
- beta_rf_value = mmm.calculate_beta_x(
288
+ beta_rf_value = self._model_equations.calculate_beta_x(
261
289
  is_non_media=False,
262
290
  incremental_outcome_x=incremental_outcome_rf,
263
291
  linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
@@ -275,7 +303,7 @@ class PriorDistributionSampler:
275
303
  )
276
304
  beta_grf_value = (
277
305
  beta_eta_combined
278
- if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
306
+ if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
279
307
  else backend.exp(beta_eta_combined)
280
308
  )
281
309
  rf_vars[constants.BETA_GRF] = backend.tfd.Deterministic(
@@ -300,9 +328,9 @@ class PriorDistributionSampler:
300
328
  `[n_draws, n_geos, n_organic_media_channels]` or
301
329
  `[n_draws, n_organic_media_channels]` containing the samples.
302
330
  """
303
- mmm = self._meridian
331
+ ctx = self._model_context
304
332
 
305
- prior = mmm.prior_broadcast
333
+ prior = ctx.prior_broadcast
306
334
  sample_shape = [1, n_draws]
307
335
 
308
336
  organic_media_vars = {
@@ -321,11 +349,11 @@ class PriorDistributionSampler:
321
349
  }
322
350
  beta_gom_dev = backend.tfd.Sample(
323
351
  backend.tfd.Normal(0, 1),
324
- [mmm.n_geos, mmm.n_organic_media_channels],
352
+ [ctx.n_geos, ctx.n_organic_media_channels],
325
353
  name=constants.BETA_GOM_DEV,
326
354
  ).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
327
355
 
328
- prior_type = mmm.model_spec.organic_media_prior_type
356
+ prior_type = ctx.model_spec.organic_media_prior_type
329
357
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
330
358
  organic_media_vars[constants.BETA_OM] = prior.beta_om.sample(
331
359
  sample_shape=sample_shape, seed=rng_handler.get_next_seed()
@@ -337,16 +365,16 @@ class PriorDistributionSampler:
337
365
  )
338
366
  )
339
367
  incremental_outcome_om = (
340
- organic_media_vars[constants.CONTRIBUTION_OM] * mmm.total_outcome
368
+ organic_media_vars[constants.CONTRIBUTION_OM] * ctx.total_outcome
341
369
  )
342
- organic_media_transformed = mmm.adstock_hill_media(
343
- media=mmm.organic_media_tensors.organic_media_scaled,
370
+ organic_media_transformed = self._model_equations.adstock_hill_media(
371
+ media=ctx.organic_media_tensors.organic_media_scaled,
344
372
  alpha=organic_media_vars[constants.ALPHA_OM],
345
373
  ec=organic_media_vars[constants.EC_OM],
346
374
  slope=organic_media_vars[constants.SLOPE_OM],
347
- decay_functions=mmm.adstock_decay_spec.organic_media,
375
+ decay_functions=ctx.adstock_decay_spec.organic_media,
348
376
  )
349
- beta_om_value = mmm.calculate_beta_x(
377
+ beta_om_value = self._model_equations.calculate_beta_x(
350
378
  is_non_media=False,
351
379
  incremental_outcome_x=incremental_outcome_om,
352
380
  linear_predictor_counterfactual_difference=organic_media_transformed,
@@ -367,7 +395,7 @@ class PriorDistributionSampler:
367
395
  )
368
396
  beta_gom_value = (
369
397
  beta_eta_combined
370
- if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
398
+ if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
371
399
  else backend.exp(beta_eta_combined)
372
400
  )
373
401
  organic_media_vars[constants.BETA_GOM] = backend.tfd.Deterministic(
@@ -392,9 +420,9 @@ class PriorDistributionSampler:
392
420
  `[n_draws, n_geos, n_organic_rf_channels]` or
393
421
  `[n_draws, n_organic_rf_channels]` containing the samples.
394
422
  """
395
- mmm = self._meridian
423
+ ctx = self._model_context
396
424
 
397
- prior = mmm.prior_broadcast
425
+ prior = ctx.prior_broadcast
398
426
  sample_shape = [1, n_draws]
399
427
 
400
428
  organic_rf_vars = {
@@ -413,11 +441,11 @@ class PriorDistributionSampler:
413
441
  }
414
442
  beta_gorf_dev = backend.tfd.Sample(
415
443
  backend.tfd.Normal(0, 1),
416
- [mmm.n_geos, mmm.n_organic_rf_channels],
444
+ [ctx.n_geos, ctx.n_organic_rf_channels],
417
445
  name=constants.BETA_GORF_DEV,
418
446
  ).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
419
447
 
420
- prior_type = mmm.model_spec.organic_media_prior_type
448
+ prior_type = ctx.model_spec.organic_media_prior_type
421
449
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
422
450
  organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(
423
451
  sample_shape=sample_shape, seed=rng_handler.get_next_seed()
@@ -429,17 +457,17 @@ class PriorDistributionSampler:
429
457
  )
430
458
  )
431
459
  incremental_outcome_orf = (
432
- organic_rf_vars[constants.CONTRIBUTION_ORF] * mmm.total_outcome
460
+ organic_rf_vars[constants.CONTRIBUTION_ORF] * ctx.total_outcome
433
461
  )
434
- organic_rf_transformed = mmm.adstock_hill_rf(
435
- reach=mmm.organic_rf_tensors.organic_reach_scaled,
436
- frequency=mmm.organic_rf_tensors.organic_frequency,
462
+ organic_rf_transformed = self._model_equations.adstock_hill_rf(
463
+ reach=ctx.organic_rf_tensors.organic_reach_scaled,
464
+ frequency=ctx.organic_rf_tensors.organic_frequency,
437
465
  alpha=organic_rf_vars[constants.ALPHA_ORF],
438
466
  ec=organic_rf_vars[constants.EC_ORF],
439
467
  slope=organic_rf_vars[constants.SLOPE_ORF],
440
- decay_functions=mmm.adstock_decay_spec.organic_rf,
468
+ decay_functions=ctx.adstock_decay_spec.organic_rf,
441
469
  )
442
- beta_orf_value = mmm.calculate_beta_x(
470
+ beta_orf_value = self._model_equations.calculate_beta_x(
443
471
  is_non_media=False,
444
472
  incremental_outcome_x=incremental_outcome_orf,
445
473
  linear_predictor_counterfactual_difference=organic_rf_transformed,
@@ -460,7 +488,7 @@ class PriorDistributionSampler:
460
488
  )
461
489
  beta_gorf_value = (
462
490
  beta_eta_combined
463
- if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
491
+ if ctx.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
464
492
  else backend.exp(beta_eta_combined)
465
493
  )
466
494
  organic_rf_vars[constants.BETA_GORF] = backend.tfd.Deterministic(
@@ -485,9 +513,9 @@ class PriorDistributionSampler:
485
513
  `[n_draws, n_geos, n_non_media_channels]` or
486
514
  `[n_draws, n_non_media_channels]` containing the samples.
487
515
  """
488
- mmm = self._meridian
516
+ ctx = self._model_context
489
517
 
490
- prior = mmm.prior_broadcast
518
+ prior = ctx.prior_broadcast
491
519
  sample_shape = [1, n_draws]
492
520
 
493
521
  non_media_treatments_vars = {
@@ -497,10 +525,10 @@ class PriorDistributionSampler:
497
525
  }
498
526
  gamma_gn_dev = backend.tfd.Sample(
499
527
  backend.tfd.Normal(0, 1),
500
- [mmm.n_geos, mmm.n_non_media_channels],
528
+ [ctx.n_geos, ctx.n_non_media_channels],
501
529
  name=constants.GAMMA_GN_DEV,
502
530
  ).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
503
- prior_type = mmm.model_spec.non_media_treatments_prior_type
531
+ prior_type = ctx.model_spec.non_media_treatments_prior_type
504
532
  if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
505
533
  non_media_treatments_vars[constants.GAMMA_N] = prior.gamma_n.sample(
506
534
  sample_shape=sample_shape, seed=rng_handler.get_next_seed()
@@ -513,15 +541,15 @@ class PriorDistributionSampler:
513
541
  )
514
542
  incremental_outcome_n = (
515
543
  non_media_treatments_vars[constants.CONTRIBUTION_N]
516
- * mmm.total_outcome
544
+ * ctx.total_outcome
517
545
  )
518
- baseline_scaled = mmm.non_media_transformer.forward( # pytype: disable=attribute-error
519
- mmm.compute_non_media_treatments_baseline()
546
+ baseline_scaled = ctx.non_media_transformer.forward( # pytype: disable=attribute-error
547
+ self._model_equations.compute_non_media_treatments_baseline()
520
548
  )
521
549
  linear_predictor_counterfactual_difference = (
522
- mmm.non_media_treatments_normalized - baseline_scaled
550
+ ctx.non_media_treatments_normalized - baseline_scaled
523
551
  )
524
- gamma_n_value = mmm.calculate_beta_x(
552
+ gamma_n_value = self._model_equations.calculate_beta_x(
525
553
  is_non_media=True,
526
554
  incremental_outcome_x=incremental_outcome_n,
527
555
  linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
@@ -541,13 +569,23 @@ class PriorDistributionSampler:
541
569
  ).sample(seed=rng_handler.get_next_seed())
542
570
  return non_media_treatments_vars
543
571
 
544
- def _sample_prior(
572
+ def __call__(
545
573
  self,
546
574
  n_draws: int,
547
575
  seed: int | None = None,
548
576
  ) -> Mapping[str, backend.Tensor]:
549
- """Returns a mapping of prior parameters to tensors of the samples."""
550
- mmm = self._meridian
577
+ """Draws samples from prior distributions.
578
+
579
+ Args:
580
+ n_draws: Number of samples drawn from the prior distribution.
581
+ seed: Used to set the seed for reproducible results. For more information,
582
+ see [PRNGS and seeds]
583
+ (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
584
+
585
+ Returns:
586
+ A mapping of prior parameter names to tensors of the samples.
587
+ """
588
+ ctx = self._model_context
551
589
 
552
590
  # For stateful sampling, the random seed must be set to ensure that any
553
591
  # random numbers that are generated are deterministic.
@@ -556,7 +594,7 @@ class PriorDistributionSampler:
556
594
 
557
595
  rng_handler = backend.RNGHandler(seed)
558
596
 
559
- prior = mmm.prior_broadcast
597
+ prior = ctx.prior_broadcast
560
598
  # `sample_shape` is prepended to the shape of each BatchBroadcast in `prior`
561
599
  # when it is sampled.
562
600
  sample_shape = [1, n_draws]
@@ -574,7 +612,7 @@ class PriorDistributionSampler:
574
612
  constants.TAU_G: (
575
613
  _get_tau_g(
576
614
  tau_g_excl_baseline=tau_g_excl_baseline,
577
- baseline_geo_idx=mmm.baseline_geo_idx,
615
+ baseline_geo_idx=ctx.baseline_geo_idx,
578
616
  ).sample(seed=rng_handler.get_next_seed())
579
617
  ),
580
618
  }
@@ -583,14 +621,14 @@ class PriorDistributionSampler:
583
621
  backend.einsum(
584
622
  "...k,kt->...t",
585
623
  base_vars[constants.KNOT_VALUES],
586
- backend.to_tensor(mmm.knot_info.weights),
624
+ backend.to_tensor(ctx.knot_info.weights),
587
625
  ),
588
626
  name=constants.MU_T,
589
627
  ).sample(seed=rng_handler.get_next_seed())
590
628
 
591
629
  # Omit gamma_c, xi_c, and gamma_gc parameters from sampled distributions if
592
630
  # there are no control variables in the model.
593
- if mmm.n_controls:
631
+ if ctx.n_controls:
594
632
  base_vars |= {
595
633
  constants.GAMMA_C: prior.gamma_c.sample(
596
634
  sample_shape=sample_shape, seed=rng_handler.get_next_seed()
@@ -602,7 +640,7 @@ class PriorDistributionSampler:
602
640
 
603
641
  gamma_gc_dev = backend.tfd.Sample(
604
642
  backend.tfd.Normal(0, 1),
605
- [mmm.n_geos, mmm.n_controls],
643
+ [ctx.n_geos, ctx.n_controls],
606
644
  name=constants.GAMMA_GC_DEV,
607
645
  ).sample(sample_shape=sample_shape, seed=rng_handler.get_next_seed())
608
646
  base_vars[constants.GAMMA_GC] = backend.tfd.Deterministic(
@@ -613,27 +651,27 @@ class PriorDistributionSampler:
613
651
 
614
652
  media_vars = (
615
653
  self._sample_media_priors(n_draws, rng_handler)
616
- if mmm.media_tensors.media is not None
654
+ if ctx.media_tensors.media is not None
617
655
  else {}
618
656
  )
619
657
  rf_vars = (
620
658
  self._sample_rf_priors(n_draws, rng_handler)
621
- if mmm.rf_tensors.reach is not None
659
+ if ctx.rf_tensors.reach is not None
622
660
  else {}
623
661
  )
624
662
  organic_media_vars = (
625
663
  self._sample_organic_media_priors(n_draws, rng_handler)
626
- if mmm.organic_media_tensors.organic_media is not None
664
+ if ctx.organic_media_tensors.organic_media is not None
627
665
  else {}
628
666
  )
629
667
  organic_rf_vars = (
630
668
  self._sample_organic_rf_priors(n_draws, rng_handler)
631
- if mmm.organic_rf_tensors.organic_reach is not None
669
+ if ctx.organic_rf_tensors.organic_reach is not None
632
670
  else {}
633
671
  )
634
672
  non_media_treatments_vars = (
635
673
  self._sample_non_media_treatments_priors(n_draws, rng_handler)
636
- if mmm.non_media_treatments_normalized is not None
674
+ if ctx.non_media_treatments_normalized is not None
637
675
  else {}
638
676
  )
639
677
 
@@ -645,21 +683,3 @@ class PriorDistributionSampler:
645
683
  | organic_rf_vars
646
684
  | non_media_treatments_vars
647
685
  )
648
-
649
- def __call__(self, n_draws: int, seed: int | None = None) -> None:
650
- """Draws samples from prior distributions.
651
-
652
- Args:
653
- n_draws: Number of samples drawn from the prior distribution.
654
- seed: Used to set the seed for reproducible results. For more information,
655
- see [PRNGS and seeds]
656
- (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
657
- """
658
- prior_draws = self._sample_prior(n_draws=n_draws, seed=seed)
659
- # Create Arviz InferenceData for prior draws.
660
- prior_coords = self._meridian.create_inference_data_coords(1, n_draws)
661
- prior_dims = self._meridian.create_inference_data_dims()
662
- prior_inference_data = az.convert_to_inference_data(
663
- prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR
664
- )
665
- self._meridian.inference_data.extend(prior_inference_data, join="right")
meridian/model/spec.py CHANGED
@@ -14,11 +14,10 @@
14
14
 
15
15
  """Defines model specification parameters for Meridian."""
16
16
 
17
- from collections.abc import Mapping
17
+ from collections.abc import Collection, Mapping
18
18
  import dataclasses
19
19
  from typing import Sequence
20
20
  import warnings
21
-
22
21
  from meridian import constants
23
22
  from meridian.model import prior_distribution
24
23
  import numpy as np
@@ -166,17 +165,17 @@ class ModelSpec:
166
165
  given non_media treatments channel). If `None`, the minimum value is used
167
166
  as baseline for each non-media treatments channel. This attribute is used
168
167
  as the default value for the corresponding argument to `Analyzer` methods.
169
- knots: An optional integer or list of integers indicating the knots used to
170
- estimate time effects. When `knots` is a list of integers, the knot
171
- locations are provided by that list. Zero corresponds to a knot at the
172
- first time period, one corresponds to a knot at the second time period,
173
- ..., and `(n_times - 1)` corresponds to a knot at the last time period).
174
- Typically, we recommend including knots at `0` and `(n_times - 1)`, but
175
- this is not required. When `knots` is an integer, then there are knots
176
- with locations equally spaced across the time periods, (including knots at
177
- zero and `(n_times - 1)`. When `knots` is` 1`, there is a single common
178
- regression coefficient used for all time periods. If `knots` is set to
179
- `None`, then the numbers of knots used is equal to the number of time
168
+ knots: An optional integer or collection of integers indicating the knots
169
+ used to estimate time effects. When `knots` is a collection of integers,
170
+ the knot locations are provided by that list. Zero corresponds to a knot
171
+ at the first time period, one corresponds to a knot at the second time
172
+ period, ..., and `(n_times - 1)` corresponds to a knot at the last time
173
+ period). Typically, we recommend including knots at `0` and `(n_times -
174
+ 1)`, but this is not required. When `knots` is an integer, then there are
175
+ knots with locations equally spaced across the time periods, (including
176
+ knots at zero and `(n_times - 1)`. When `knots` is` 1`, there is a single
177
+ common regression coefficient used for all time periods. If `knots` is set
178
+ to `None`, then the numbers of knots used is equal to the number of time
180
179
  periods in the case of a geo model. This is equivalent to each time period
181
180
  having its own regression coefficient. If `knots` is set to `None` in the
182
181
  case of a national model, then the number of knots used is `1`. Default:
@@ -235,7 +234,7 @@ class ModelSpec:
235
234
  constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION
236
235
  )
237
236
  non_media_baseline_values: Sequence[float | str] | None = None
238
- knots: int | list[int] | None = None
237
+ knots: int | Collection[int] | None = None
239
238
  baseline_geo: int | str | None = None
240
239
  holdout_id: np.ndarray | None = None
241
240
  control_population_scaling_id: np.ndarray | None = None
@@ -321,6 +320,12 @@ class ModelSpec:
321
320
  prior_type_name="rf_prior_type",
322
321
  )
323
322
 
323
+ if isinstance(self.knots, Collection):
324
+ knots_list = list(self.knots)
325
+ if not all(isinstance(x, (int, np.integer)) for x in knots_list):
326
+ raise ValueError("`knots` must be a sequence of integers.")
327
+ object.__setattr__(self, "knots", [int(x) for x in knots_list])
328
+
324
329
  # Validate knots.
325
330
  if isinstance(self.knots, list) and not self.knots:
326
331
  raise ValueError("The `knots` parameter cannot be an empty list.")
@@ -330,6 +335,10 @@ class ModelSpec:
330
335
  raise ValueError(
331
336
  "The `knots` parameter cannot be set when `enable_aks` is True."
332
337
  )
338
+ if not (self.knots is None or isinstance(self.knots, (int, list))):
339
+ raise ValueError(
340
+ f"Unsupported type for `knots` parameter: {type(self.knots)}."
341
+ )
333
342
 
334
343
  @property
335
344
  def effective_media_prior_type(self) -> str:
@@ -18,13 +18,15 @@ limitations under the License.
18
18
  <card-title>
19
19
  {{ title }}
20
20
  </card-title>
21
- <card-insights>
22
- <card-insights-icon>
23
- <img src="https://www.gstatic.com/images/icons/material/system/svg/insights_24px.svg"
24
- alt="insights icon" />
25
- </card-insights-icon>
26
- {{ insights }} {# `insights` is a pre-rendered HTML snippet. #}
27
- </card-insights>
21
+ {% if insights %}
22
+ <card-insights>
23
+ <card-insights-icon>
24
+ <img src="https://www.gstatic.com/images/icons/material/system/svg/insights_24px.svg"
25
+ alt="insights icon" />
26
+ </card-insights-icon>
27
+ {{ insights }} {# `insights` is a pre-rendered HTML snippet. #}
28
+ </card-insights>
29
+ {% endif %}
28
30
  {% if stats %}
29
31
  <stats-section>
30
32
  {% for item in stats %} {# Each `item` is a pre-rendered HTML snippet. #}
@@ -15,6 +15,7 @@ limitations under the License.
15
15
  #}
16
16
 
17
17
  <chart>
18
+ {% include "findings.html.jinja" %}
18
19
  <chart-embed id="{{ id }}"></chart-embed>
19
20
  {% if description %}
20
21
  <chart-description>
@@ -25,13 +26,7 @@ limitations under the License.
25
26
 
26
27
  <script type="text/javascript">
27
28
  (() => {
28
- const opt = {
29
- mode: 'vega-lite',
30
- width: 'container',
31
- autosize: { type: 'fit', contains: 'padding' }
32
- };
33
29
  const spec = JSON.parse({{ chart_json|tojson }});
34
- const chartDiv = document.getElementById('{{ id }}');
35
30
  vegaEmbed('#{{ id }}', spec).catch(console.error);
36
31
  })();
37
32
  </script>
@@ -0,0 +1,19 @@
1
+ {#
2
+ Copyright 2026 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ #}
16
+
17
+ <finding class="{{ finding_class }}">
18
+ {{ text }}
19
+ </finding>
@@ -0,0 +1,33 @@
1
+ {#
2
+ Copyright 2025 Google LLC
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ #}
16
+
17
+ {% macro render_findings(finding_type, findings, icon_alt, icon_url) %}
18
+ {% if findings %}
19
+ {% for finding in findings %}
20
+ <{{ finding_type }}s>
21
+ <{{ finding_type }}s-icon>
22
+ <img src="{{ icon_url }}"
23
+ alt="{{ icon_alt }}" />
24
+ </{{ finding_type }}s-icon>
25
+ <p class="{{ finding_type }}-text">{{ finding }}</p>
26
+ </{{ finding_type }}s>
27
+ {% endfor %}
28
+ {% endif %}
29
+ {% endmacro %}
30
+
31
+ {{ render_findings('error', errors, 'error icon', 'https://www.gstatic.com/images/icons/material/system/svg/error_outline_24px.svg') }}
32
+ {{ render_findings('warning', warnings, 'warning icon', 'https://www.gstatic.com/images/icons/material/system/svg/warning_amber_24px.svg') }}
33
+ {{ render_findings('info', infos, 'info icon', 'https://www.gstatic.com/images/icons/material/system/svg/info_outline_24px.svg') }}