google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
  2. google_meridian-1.5.0.dist-info/RECORD +112 -0
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
  5. meridian/analysis/analyzer.py +558 -398
  6. meridian/analysis/optimizer.py +90 -68
  7. meridian/analysis/review/reviewer.py +4 -1
  8. meridian/analysis/summarizer.py +13 -3
  9. meridian/analysis/test_utils.py +2911 -2102
  10. meridian/analysis/visualizer.py +37 -14
  11. meridian/backend/__init__.py +106 -0
  12. meridian/constants.py +2 -0
  13. meridian/data/input_data.py +30 -52
  14. meridian/data/input_data_builder.py +2 -9
  15. meridian/data/test_utils.py +107 -51
  16. meridian/data/validator.py +48 -0
  17. meridian/mlflow/autolog.py +19 -9
  18. meridian/model/__init__.py +2 -0
  19. meridian/model/adstock_hill.py +3 -5
  20. meridian/model/context.py +1059 -0
  21. meridian/model/eda/constants.py +335 -4
  22. meridian/model/eda/eda_engine.py +723 -312
  23. meridian/model/eda/eda_outcome.py +177 -33
  24. meridian/model/equations.py +418 -0
  25. meridian/model/knots.py +58 -47
  26. meridian/model/model.py +228 -878
  27. meridian/model/model_test_data.py +38 -0
  28. meridian/model/posterior_sampler.py +103 -62
  29. meridian/model/prior_sampler.py +114 -94
  30. meridian/model/spec.py +23 -14
  31. meridian/templates/card.html.jinja +9 -7
  32. meridian/templates/chart.html.jinja +1 -6
  33. meridian/templates/finding.html.jinja +19 -0
  34. meridian/templates/findings.html.jinja +33 -0
  35. meridian/templates/formatter.py +41 -5
  36. meridian/templates/formatter_test.py +127 -0
  37. meridian/templates/style.css +66 -9
  38. meridian/templates/style.scss +85 -4
  39. meridian/templates/table.html.jinja +1 -0
  40. meridian/version.py +1 -1
  41. scenarioplanner/__init__.py +42 -0
  42. scenarioplanner/converters/__init__.py +25 -0
  43. scenarioplanner/converters/dataframe/__init__.py +28 -0
  44. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  45. scenarioplanner/converters/dataframe/common.py +71 -0
  46. scenarioplanner/converters/dataframe/constants.py +137 -0
  47. scenarioplanner/converters/dataframe/converter.py +42 -0
  48. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  49. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  50. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  51. scenarioplanner/converters/mmm.py +743 -0
  52. scenarioplanner/converters/mmm_converter.py +58 -0
  53. scenarioplanner/converters/sheets.py +156 -0
  54. scenarioplanner/converters/test_data.py +714 -0
  55. scenarioplanner/linkingapi/__init__.py +47 -0
  56. scenarioplanner/linkingapi/constants.py +27 -0
  57. scenarioplanner/linkingapi/url_generator.py +131 -0
  58. scenarioplanner/mmm_ui_proto_generator.py +355 -0
  59. schema/__init__.py +5 -2
  60. schema/mmm_proto_generator.py +71 -0
  61. schema/model_consumer.py +133 -0
  62. schema/processors/__init__.py +77 -0
  63. schema/processors/budget_optimization_processor.py +832 -0
  64. schema/processors/common.py +64 -0
  65. schema/processors/marketing_processor.py +1137 -0
  66. schema/processors/model_fit_processor.py +367 -0
  67. schema/processors/model_kernel_processor.py +117 -0
  68. schema/processors/model_processor.py +415 -0
  69. schema/processors/reach_frequency_optimization_processor.py +584 -0
  70. schema/serde/distribution.py +12 -7
  71. schema/serde/hyperparameters.py +54 -107
  72. schema/serde/meridian_serde.py +6 -1
  73. schema/test_data.py +380 -0
  74. schema/utils/__init__.py +2 -0
  75. schema/utils/date_range_bucketing.py +117 -0
  76. schema/utils/proto_enum_converter.py +127 -0
  77. google_meridian-1.3.2.dist-info/RECORD +0 -76
  78. {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
meridian/model/model.py CHANGED
@@ -16,8 +16,8 @@
16
16
 
17
17
  import collections
18
18
  from collections.abc import Mapping, Sequence
19
+ import dataclasses
19
20
  import functools
20
- import numbers
21
21
  import os
22
22
  import warnings
23
23
 
@@ -28,6 +28,8 @@ from meridian import constants
28
28
  from meridian.data import input_data as data
29
29
  from meridian.data import time_coordinates as tc
30
30
  from meridian.model import adstock_hill
31
+ from meridian.model import context
32
+ from meridian.model import equations
31
33
  from meridian.model import knots
32
34
  from meridian.model import media
33
35
  from meridian.model import posterior_sampler
@@ -76,27 +78,15 @@ def _warn_setting_national_args(**kwargs):
76
78
  )
77
79
 
78
80
 
79
- def _check_for_negative_effect(
80
- dist: backend.tfd.Distribution, media_effects_dist: str
81
- ):
82
- """Checks for negative effect in the model."""
83
- if (
84
- media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
85
- and np.any(dist.cdf(0)) > 0
86
- ):
87
- raise ValueError(
88
- "Media priors must have non-negative support when"
89
- f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
90
- f" in {dist.name}."
91
- )
92
-
93
-
94
81
  class Meridian:
95
82
  """Contains the main functionality for fitting the Meridian MMM model.
96
83
 
97
84
  Attributes:
98
85
  input_data: An `InputData` object containing the input data for the model.
99
86
  model_spec: A `ModelSpec` object containing the model specification.
87
+ model_context: A `ModelContext` object containing the model context.
88
+ model_equations: A `ModelEquations` object containing stateless mathematical
89
+ functions and utilities for Meridian MMM.
100
90
  inference_data: A _mutable_ `arviz.InferenceData` object containing the
101
91
  resulting data from fitting the model.
102
92
  eda_engine: An `EDAEngine` object containing the EDA engine.
@@ -168,15 +158,17 @@ class Meridian:
168
158
  ) = None, # for deserializer use only
169
159
  eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
170
160
  ):
171
- self._input_data = input_data
172
- self._model_spec = model_spec if model_spec else spec.ModelSpec()
173
161
  self._inference_data = (
174
162
  inference_data if inference_data else az.InferenceData()
175
163
  )
164
+ self._model_context = context.ModelContext(
165
+ input_data=input_data,
166
+ model_spec=model_spec if model_spec else spec.ModelSpec(),
167
+ )
168
+ self._model_equations = equations.ModelEquations(self._model_context)
176
169
 
177
170
  self._eda_spec = eda_spec
178
171
 
179
- self._validate_data_dependent_model_spec()
180
172
  self._validate_injected_inference_data()
181
173
 
182
174
  if self.is_national:
@@ -184,22 +176,23 @@ class Meridian:
184
176
  media_effects_dist=self.model_spec.media_effects_dist,
185
177
  unique_sigma_for_each_geo=self.model_spec.unique_sigma_for_each_geo,
186
178
  )
187
- self._warn_setting_ignored_priors()
188
- self._set_total_media_contribution_prior = False
189
- self._validate_mroi_priors_non_revenue()
190
- self._validate_roi_priors_non_revenue()
191
- self._check_for_negative_effects()
192
- self._validate_geo_invariants()
193
- self._validate_time_invariants()
194
- self._validate_kpi_transformer()
179
+ self._validate_kpi_variability()
195
180
 
196
181
  @property
197
182
  def input_data(self) -> data.InputData:
198
- return self._input_data
183
+ return self._model_context.input_data
199
184
 
200
185
  @property
201
186
  def model_spec(self) -> spec.ModelSpec:
202
- return self._model_spec
187
+ return self._model_context.model_spec
188
+
189
+ @property
190
+ def model_context(self) -> context.ModelContext:
191
+ return self._model_context
192
+
193
+ @property
194
+ def model_equations(self) -> equations.ModelEquations:
195
+ return self._model_equations
203
196
 
204
197
  @property
205
198
  def inference_data(self) -> az.InferenceData:
@@ -207,190 +200,127 @@ class Meridian:
207
200
 
208
201
  @functools.cached_property
209
202
  def eda_engine(self) -> eda_engine.EDAEngine:
210
- return eda_engine.EDAEngine(self, spec=self._eda_spec)
203
+ return eda_engine.EDAEngine(
204
+ spec=self._eda_spec, model_context=self.model_context
205
+ )
211
206
 
212
207
  @property
213
208
  def eda_spec(self) -> eda_spec_module.EDASpec:
214
209
  return self._eda_spec
215
210
 
216
211
  @property
217
- def eda_outcomes(self) -> Sequence[eda_outcome.EDAOutcome]:
212
+ def eda_outcomes(self) -> eda_outcome.CriticalCheckEDAOutcomes:
218
213
  return self.eda_engine.run_all_critical_checks()
219
214
 
220
- @functools.cached_property
215
+ @property
221
216
  def media_tensors(self) -> media.MediaTensors:
222
- return media.build_media_tensors(self.input_data, self.model_spec)
217
+ return self._model_context.media_tensors
223
218
 
224
- @functools.cached_property
219
+ @property
225
220
  def rf_tensors(self) -> media.RfTensors:
226
- return media.build_rf_tensors(self.input_data, self.model_spec)
221
+ return self._model_context.rf_tensors
227
222
 
228
- @functools.cached_property
223
+ @property
229
224
  def organic_media_tensors(self) -> media.OrganicMediaTensors:
230
- return media.build_organic_media_tensors(self.input_data)
225
+ return self._model_context.organic_media_tensors
231
226
 
232
- @functools.cached_property
227
+ @property
233
228
  def organic_rf_tensors(self) -> media.OrganicRfTensors:
234
- return media.build_organic_rf_tensors(self.input_data)
229
+ return self._model_context.organic_rf_tensors
235
230
 
236
- @functools.cached_property
231
+ @property
237
232
  def kpi(self) -> backend.Tensor:
238
- return backend.to_tensor(self.input_data.kpi, dtype=backend.float32)
233
+ return self._model_context.kpi
239
234
 
240
- @functools.cached_property
235
+ @property
241
236
  def revenue_per_kpi(self) -> backend.Tensor | None:
242
- if self.input_data.revenue_per_kpi is None:
243
- return None
244
- return backend.to_tensor(
245
- self.input_data.revenue_per_kpi, dtype=backend.float32
246
- )
237
+ return self._model_context.revenue_per_kpi
247
238
 
248
- @functools.cached_property
239
+ @property
249
240
  def controls(self) -> backend.Tensor | None:
250
- if self.input_data.controls is None:
251
- return None
252
- return backend.to_tensor(self.input_data.controls, dtype=backend.float32)
241
+ return self._model_context.controls
253
242
 
254
- @functools.cached_property
243
+ @property
255
244
  def non_media_treatments(self) -> backend.Tensor | None:
256
- if self.input_data.non_media_treatments is None:
257
- return None
258
- return backend.to_tensor(
259
- self.input_data.non_media_treatments, dtype=backend.float32
260
- )
245
+ return self._model_context.non_media_treatments
261
246
 
262
- @functools.cached_property
247
+ @property
263
248
  def population(self) -> backend.Tensor:
264
- return backend.to_tensor(self.input_data.population, dtype=backend.float32)
249
+ return self._model_context.population
265
250
 
266
- @functools.cached_property
251
+ @property
267
252
  def total_spend(self) -> backend.Tensor:
268
- return backend.to_tensor(
269
- self.input_data.get_total_spend(), dtype=backend.float32
270
- )
253
+ return self._model_context.total_spend
271
254
 
272
- @functools.cached_property
255
+ @property
273
256
  def total_outcome(self) -> backend.Tensor:
274
- return backend.to_tensor(
275
- self.input_data.get_total_outcome(), dtype=backend.float32
276
- )
257
+ return self._model_context.total_outcome
277
258
 
278
259
  @property
279
260
  def n_geos(self) -> int:
280
- return len(self.input_data.geo)
261
+ return self._model_context.n_geos
281
262
 
282
263
  @property
283
264
  def n_media_channels(self) -> int:
284
- if self.input_data.media_channel is None:
285
- return 0
286
- return len(self.input_data.media_channel)
265
+ return self._model_context.n_media_channels
287
266
 
288
267
  @property
289
268
  def n_rf_channels(self) -> int:
290
- if self.input_data.rf_channel is None:
291
- return 0
292
- return len(self.input_data.rf_channel)
269
+ return self._model_context.n_rf_channels
293
270
 
294
271
  @property
295
272
  def n_organic_media_channels(self) -> int:
296
- if self.input_data.organic_media_channel is None:
297
- return 0
298
- return len(self.input_data.organic_media_channel)
273
+ return self._model_context.n_organic_media_channels
299
274
 
300
275
  @property
301
276
  def n_organic_rf_channels(self) -> int:
302
- if self.input_data.organic_rf_channel is None:
303
- return 0
304
- return len(self.input_data.organic_rf_channel)
277
+ return self._model_context.n_organic_rf_channels
305
278
 
306
279
  @property
307
280
  def n_controls(self) -> int:
308
- if self.input_data.control_variable is None:
309
- return 0
310
- return len(self.input_data.control_variable)
281
+ return self._model_context.n_controls
311
282
 
312
283
  @property
313
284
  def n_non_media_channels(self) -> int:
314
- if self.input_data.non_media_channel is None:
315
- return 0
316
- return len(self.input_data.non_media_channel)
285
+ return self._model_context.n_non_media_channels
317
286
 
318
287
  @property
319
288
  def n_times(self) -> int:
320
- return len(self.input_data.time)
289
+ return self._model_context.n_times
321
290
 
322
291
  @property
323
292
  def n_media_times(self) -> int:
324
- return len(self.input_data.media_time)
293
+ return self._model_context.n_media_times
325
294
 
326
295
  @property
327
296
  def is_national(self) -> bool:
328
- return self.n_geos == 1
297
+ return self._model_context.is_national
329
298
 
330
- @functools.cached_property
299
+ @property
331
300
  def knot_info(self) -> knots.KnotInfo:
332
- return knots.get_knot_info(
333
- n_times=self.n_times,
334
- knots=self.model_spec.knots,
335
- enable_aks=self.model_spec.enable_aks,
336
- data=self.input_data,
337
- is_national=self.is_national,
338
- )
301
+ return self._model_context.knot_info
339
302
 
340
- @functools.cached_property
303
+ @property
341
304
  def controls_transformer(
342
305
  self,
343
306
  ) -> transformers.CenteringAndScalingTransformer | None:
344
- """Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
345
- if self.controls is None:
346
- return None
307
+ return self._model_context.controls_transformer
347
308
 
348
- if self.model_spec.control_population_scaling_id is not None:
349
- controls_population_scaling_id = backend.to_tensor(
350
- self.model_spec.control_population_scaling_id, dtype=backend.bool_
351
- )
352
- else:
353
- controls_population_scaling_id = None
354
-
355
- return transformers.CenteringAndScalingTransformer(
356
- tensor=self.controls,
357
- population=self.population,
358
- population_scaling_id=controls_population_scaling_id,
359
- )
360
-
361
- @functools.cached_property
309
+ @property
362
310
  def non_media_transformer(
363
311
  self,
364
312
  ) -> transformers.CenteringAndScalingTransformer | None:
365
- """Returns a `CenteringAndScalingTransformer` for non-media treatments."""
366
- if self.non_media_treatments is None:
367
- return None
368
- if self.model_spec.non_media_population_scaling_id is not None:
369
- non_media_population_scaling_id = backend.to_tensor(
370
- self.model_spec.non_media_population_scaling_id, dtype=backend.bool_
371
- )
372
- else:
373
- non_media_population_scaling_id = None
313
+ return self._model_context.non_media_transformer
374
314
 
375
- return transformers.CenteringAndScalingTransformer(
376
- tensor=self.non_media_treatments,
377
- population=self.population,
378
- population_scaling_id=non_media_population_scaling_id,
379
- )
380
-
381
- @functools.cached_property
315
+ @property
382
316
  def kpi_transformer(self) -> transformers.KpiTransformer:
383
- return transformers.KpiTransformer(self.kpi, self.population)
317
+ return self._model_context.kpi_transformer
384
318
 
385
- @functools.cached_property
319
+ @property
386
320
  def controls_scaled(self) -> backend.Tensor | None:
387
- if self.controls is not None:
388
- # If `controls` is defined, then `controls_transformer` is also defined.
389
- return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
390
- else:
391
- return None
321
+ return self._model_context.controls_scaled
392
322
 
393
- @functools.cached_property
323
+ @property
394
324
  def non_media_treatments_normalized(self) -> backend.Tensor | None:
395
325
  """Normalized non-media treatments.
396
326
 
@@ -398,130 +328,56 @@ class Meridian:
398
328
  `non_media_population_scaling_id` is `True`) and normalized by centering and
399
329
  scaling with means and standard deviations.
400
330
  """
401
- if self.non_media_transformer is not None:
402
- return self.non_media_transformer.forward(
403
- self.non_media_treatments
404
- ) # pytype: disable=attribute-error
405
- else:
406
- return None
331
+ return self._model_context.non_media_treatments_normalized
407
332
 
408
- @functools.cached_property
333
+ @property
409
334
  def kpi_scaled(self) -> backend.Tensor:
410
- return self.kpi_transformer.forward(self.kpi)
335
+ return self._model_context.kpi_scaled
411
336
 
412
- @functools.cached_property
337
+ @property
413
338
  def media_effects_dist(self) -> str:
414
- if self.is_national:
415
- return constants.NATIONAL_MODEL_SPEC_ARGS[constants.MEDIA_EFFECTS_DIST]
416
- else:
417
- return self.model_spec.media_effects_dist
339
+ return self._model_context.media_effects_dist
418
340
 
419
- @functools.cached_property
341
+ @property
420
342
  def unique_sigma_for_each_geo(self) -> bool:
421
- if self.is_national:
422
- # Should evaluate to False.
423
- return constants.NATIONAL_MODEL_SPEC_ARGS[
424
- constants.UNIQUE_SIGMA_FOR_EACH_GEO
425
- ]
426
- else:
427
- return self.model_spec.unique_sigma_for_each_geo
343
+ return self._model_context.unique_sigma_for_each_geo
428
344
 
429
- @functools.cached_property
345
+ @property
430
346
  def baseline_geo_idx(self) -> int:
431
347
  """Returns the index of the baseline geo."""
432
- if isinstance(self.model_spec.baseline_geo, int):
433
- if (
434
- self.model_spec.baseline_geo < 0
435
- or self.model_spec.baseline_geo >= self.n_geos
436
- ):
437
- raise ValueError(
438
- f"Baseline geo index {self.model_spec.baseline_geo} out of range"
439
- f" [0, {self.n_geos - 1}]."
440
- )
441
- return self.model_spec.baseline_geo
442
- elif isinstance(self.model_spec.baseline_geo, str):
443
- # np.where returns a 1-D tuple, its first element is an array of found
444
- # elements.
445
- index = np.where(self.input_data.geo == self.model_spec.baseline_geo)[0]
446
- if index.size == 0:
447
- raise ValueError(
448
- f"Baseline geo '{self.model_spec.baseline_geo}' not found."
449
- )
450
- # Geos are unique, so index is a 1-element array.
451
- return index[0]
452
- else:
453
- return backend.argmax(self.population)
348
+ return self._model_context.baseline_geo_idx
454
349
 
455
- @functools.cached_property
350
+ @property
456
351
  def holdout_id(self) -> backend.Tensor | None:
457
- if self.model_spec.holdout_id is None:
458
- return None
459
- tensor = backend.to_tensor(self.model_spec.holdout_id, dtype=backend.bool_)
460
- return tensor[backend.newaxis, ...] if self.is_national else tensor
352
+ return self._model_context.holdout_id
461
353
 
462
- @functools.cached_property
354
+ @property
463
355
  def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
464
356
  """Returns `AdstockDecaySpec` object with correctly mapped channels."""
465
- if isinstance(self.model_spec.adstock_decay_spec, str):
466
- return adstock_hill.AdstockDecaySpec.from_consistent_type(
467
- self.model_spec.adstock_decay_spec
468
- )
357
+ return self._model_context.adstock_decay_spec
469
358
 
470
- try:
471
- return self._create_adstock_decay_functions_from_channel_map(
472
- self.model_spec.adstock_decay_spec
473
- )
474
- except KeyError as e:
475
- raise ValueError(
476
- "Unrecognized channel names found in `adstock_decay_spec` keys"
477
- f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
478
- " either contain only channel_names"
479
- f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
480
- " be one or more of {'media', 'rf', 'organic_media',"
481
- " 'organic_rf'}."
482
- ) from e
483
-
484
- @functools.cached_property
359
+ @property
485
360
  def prior_broadcast(self) -> prior_distribution.PriorDistribution:
486
361
  """Returns broadcasted `PriorDistribution` object."""
487
- total_spend = self.input_data.get_total_spend()
488
- # Total spend can have 1, 2 or 3 dimensions. Aggregate by channel.
489
- if len(total_spend.shape) == 1:
490
- # Already aggregated by channel.
491
- agg_total_spend = total_spend
492
- elif len(total_spend.shape) == 2:
493
- agg_total_spend = np.sum(total_spend, axis=(0,))
494
- else:
495
- agg_total_spend = np.sum(total_spend, axis=(0, 1))
496
-
497
- return self.model_spec.prior.broadcast(
498
- n_geos=self.n_geos,
499
- n_media_channels=self.n_media_channels,
500
- n_rf_channels=self.n_rf_channels,
501
- n_organic_media_channels=self.n_organic_media_channels,
502
- n_organic_rf_channels=self.n_organic_rf_channels,
503
- n_controls=self.n_controls,
504
- n_non_media_channels=self.n_non_media_channels,
505
- unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
506
- n_knots=self.knot_info.n_knots,
507
- is_national=self.is_national,
508
- set_total_media_contribution_prior=self._set_total_media_contribution_prior,
509
- kpi=np.sum(self.input_data.kpi.values),
510
- total_spend=agg_total_spend,
511
- )
362
+ return self._model_context.prior_broadcast
512
363
 
513
364
  @functools.cached_property
514
365
  def prior_sampler_callable(self) -> prior_sampler.PriorDistributionSampler:
515
366
  """A `PriorDistributionSampler` callable bound to this model."""
516
- return prior_sampler.PriorDistributionSampler(self)
367
+ return prior_sampler.PriorDistributionSampler(
368
+ model_context=self.model_context,
369
+ )
517
370
 
518
371
  @functools.cached_property
519
372
  def posterior_sampler_callable(
520
373
  self,
521
374
  ) -> posterior_sampler.PosteriorMCMCSampler:
522
375
  """A `PosteriorMCMCSampler` callable bound to this model."""
523
- return posterior_sampler.PosteriorMCMCSampler(self)
376
+ return posterior_sampler.PosteriorMCMCSampler(
377
+ model_context=self.model_context,
378
+ )
524
379
 
380
+ # TODO: Remove this method.
525
381
  def compute_non_media_treatments_baseline(
526
382
  self,
527
383
  non_media_baseline_values: Sequence[str | float] | None = None,
@@ -544,70 +400,18 @@ class Meridian:
544
400
  A tensor of shape `(n_non_media_channels,)` containing the
545
401
  baseline values for each non-media treatment channel.
546
402
  """
547
- if non_media_baseline_values is None:
548
- non_media_baseline_values = self.model_spec.non_media_baseline_values
549
-
550
- no_op_scaling_factor = backend.ones_like(self.population)[
551
- :, backend.newaxis, backend.newaxis
552
- ]
553
- if self.model_spec.non_media_population_scaling_id is not None:
554
- scaling_factors = backend.where(
555
- self.model_spec.non_media_population_scaling_id,
556
- self.population[:, backend.newaxis, backend.newaxis],
557
- no_op_scaling_factor,
558
- )
559
- else:
560
- scaling_factors = no_op_scaling_factor
561
-
562
- non_media_treatments_population_scaled = backend.divide_no_nan(
563
- self.non_media_treatments, scaling_factors
403
+ warnings.warn(
404
+ "Meridian.compute_non_media_treatments_baseline() is deprecated and"
405
+ " will be removed in a future version. Use"
406
+ " `ModelEquations.compute_non_media_treatments_baseline()` instead.",
407
+ DeprecationWarning,
408
+ stacklevel=2,
409
+ )
410
+ return self.model_equations.compute_non_media_treatments_baseline(
411
+ non_media_baseline_values=non_media_baseline_values
564
412
  )
565
413
 
566
- if non_media_baseline_values is None:
567
- # If non_media_baseline_values is not provided, use the minimum
568
- # value for each non_media treatment channel as the baseline.
569
- non_media_baseline_values_filled = [
570
- constants.NON_MEDIA_BASELINE_MIN
571
- ] * non_media_treatments_population_scaled.shape[-1]
572
- else:
573
- non_media_baseline_values_filled = non_media_baseline_values
574
-
575
- if non_media_treatments_population_scaled.shape[-1] != len(
576
- non_media_baseline_values_filled
577
- ):
578
- raise ValueError(
579
- "The number of non-media channels"
580
- f" ({non_media_treatments_population_scaled.shape[-1]}) does not"
581
- " match the number of baseline values"
582
- f" ({len(non_media_baseline_values_filled)})."
583
- )
584
-
585
- baseline_list = []
586
- for channel in range(non_media_treatments_population_scaled.shape[-1]):
587
- baseline_value = non_media_baseline_values_filled[channel]
588
-
589
- if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
590
- baseline_for_channel = backend.reduce_min(
591
- non_media_treatments_population_scaled[..., channel], axis=[0, 1]
592
- )
593
- elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
594
- baseline_for_channel = backend.reduce_max(
595
- non_media_treatments_population_scaled[..., channel], axis=[0, 1]
596
- )
597
- elif isinstance(baseline_value, numbers.Number):
598
- baseline_for_channel = backend.to_tensor(
599
- baseline_value, dtype=backend.float32
600
- )
601
- else:
602
- raise ValueError(
603
- f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
604
- " float numbers and strings 'min' and 'max' are supported."
605
- )
606
-
607
- baseline_list.append(baseline_for_channel)
608
-
609
- return backend.stack(baseline_list, axis=-1)
610
-
414
+ # TODO: Remove this method.
611
415
  def expand_selected_time_dims(
612
416
  self,
613
417
  start_date: tc.Date = None,
@@ -635,12 +439,16 @@ class Meridian:
635
439
  ValueError if `start_date` or `end_date` is not in the input data time
636
440
  dimensions.
637
441
  """
638
- expanded = self.input_data.time_coordinates.expand_selected_time_dims(
442
+ warnings.warn(
443
+ "Meridian.expand_selected_time_dims() is deprecated and will be removed"
444
+ " in a future version. Use `ModelContext.expand_selected_time_dims()`"
445
+ " instead.",
446
+ DeprecationWarning,
447
+ stacklevel=2,
448
+ )
449
+ return self._model_context.expand_selected_time_dims(
639
450
  start_date=start_date, end_date=end_date
640
451
  )
641
- if expanded is None:
642
- return None
643
- return [date.strftime(constants.DATE_FORMAT) for date in expanded]
644
452
 
645
453
  def _validate_injected_inference_data(self):
646
454
  """Validates that the injected inference data has correct shapes.
@@ -752,427 +560,8 @@ class Meridian:
752
560
  self.n_non_media_channels,
753
561
  )
754
562
 
755
- def _validate_data_dependent_model_spec(self):
756
- """Validates that the data dependent model specs have correct shapes."""
757
-
758
- if (
759
- self.model_spec.roi_calibration_period is not None
760
- and self.model_spec.roi_calibration_period.shape
761
- != (
762
- self.n_media_times,
763
- self.n_media_channels,
764
- )
765
- ):
766
- raise ValueError(
767
- "The shape of `roi_calibration_period`"
768
- f" {self.model_spec.roi_calibration_period.shape} is different from"
769
- f" `(n_media_times, n_media_channels) = ({self.n_media_times},"
770
- f" {self.n_media_channels})`."
771
- )
772
-
773
- if (
774
- self.model_spec.rf_roi_calibration_period is not None
775
- and self.model_spec.rf_roi_calibration_period.shape
776
- != (
777
- self.n_media_times,
778
- self.n_rf_channels,
779
- )
780
- ):
781
- raise ValueError(
782
- "The shape of `rf_roi_calibration_period`"
783
- f" {self.model_spec.rf_roi_calibration_period.shape} is different"
784
- f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
785
- f" {self.n_rf_channels})`."
786
- )
787
-
788
- if self.model_spec.holdout_id is not None:
789
- if self.is_national and (
790
- self.model_spec.holdout_id.shape != (self.n_times,)
791
- ):
792
- raise ValueError(
793
- f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
794
- f" different from `(n_times,) = ({self.n_times},)`."
795
- )
796
- elif not self.is_national and (
797
- self.model_spec.holdout_id.shape
798
- != (
799
- self.n_geos,
800
- self.n_times,
801
- )
802
- ):
803
- raise ValueError(
804
- f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
805
- f" different from `(n_geos, n_times) = ({self.n_geos},"
806
- f" {self.n_times})`."
807
- )
808
-
809
- if self.model_spec.control_population_scaling_id is not None and (
810
- self.model_spec.control_population_scaling_id.shape
811
- != (self.n_controls,)
812
- ):
813
- raise ValueError(
814
- "The shape of `control_population_scaling_id`"
815
- f" {self.model_spec.control_population_scaling_id.shape} is different"
816
- f" from `(n_controls,) = ({self.n_controls},)`."
817
- )
818
-
819
- if self.model_spec.non_media_population_scaling_id is not None and (
820
- self.model_spec.non_media_population_scaling_id.shape
821
- != (self.n_non_media_channels,)
822
- ):
823
- raise ValueError(
824
- "The shape of `non_media_population_scaling_id`"
825
- f" {self.model_spec.non_media_population_scaling_id.shape} is"
826
- " different from `(n_non_media_channels,) ="
827
- f" ({self.n_non_media_channels},)`."
828
- )
829
-
830
- def _create_adstock_decay_functions_from_channel_map(
831
- self, channel_function_map: Mapping[str, str]
832
- ) -> adstock_hill.AdstockDecaySpec:
833
- """Create `AdstockDecaySpec` from mapping from channels to decay functions."""
834
-
835
- for channel in channel_function_map:
836
- if channel not in self.input_data.get_all_adstock_hill_channels():
837
- raise KeyError(f"Channel {channel} not found in data.")
838
-
839
- if self.input_data.media_channel is not None:
840
- media_channel_builder = self.input_data.get_paid_media_channels_argument_builder().with_default_value(
841
- constants.GEOMETRIC_DECAY
842
- )
843
- media_adstock_function = media_channel_builder(**channel_function_map)
844
- else:
845
- media_adstock_function = constants.GEOMETRIC_DECAY
846
-
847
- if self.input_data.rf_channel is not None:
848
- rf_channel_builder = self.input_data.get_paid_rf_channels_argument_builder().with_default_value(
849
- constants.GEOMETRIC_DECAY
850
- )
851
- rf_adstock_function = rf_channel_builder(**channel_function_map)
852
- else:
853
- rf_adstock_function = constants.GEOMETRIC_DECAY
854
-
855
- if self.input_data.organic_media_channel is not None:
856
- organic_media_channel_builder = self.input_data.get_organic_media_channels_argument_builder().with_default_value(
857
- constants.GEOMETRIC_DECAY
858
- )
859
- organic_media_adstock_function = organic_media_channel_builder(
860
- **channel_function_map
861
- )
862
- else:
863
- organic_media_adstock_function = constants.GEOMETRIC_DECAY
864
-
865
- if self.input_data.organic_rf_channel is not None:
866
- organic_rf_channel_builder = self.input_data.get_organic_rf_channels_argument_builder().with_default_value(
867
- constants.GEOMETRIC_DECAY
868
- )
869
- organic_rf_adstock_function = organic_rf_channel_builder(
870
- **channel_function_map
871
- )
872
- else:
873
- organic_rf_adstock_function = constants.GEOMETRIC_DECAY
874
-
875
- return adstock_hill.AdstockDecaySpec(
876
- media=media_adstock_function,
877
- rf=rf_adstock_function,
878
- organic_media=organic_media_adstock_function,
879
- organic_rf=organic_rf_adstock_function,
880
- )
881
-
882
- def _warn_setting_ignored_priors(self):
883
- """Raises a warning if ignored priors are set."""
884
- default_distribution = prior_distribution.PriorDistribution()
885
- for ignored_priors_dict, prior_type, prior_type_name in (
886
- (
887
- constants.IGNORED_PRIORS_MEDIA,
888
- self.model_spec.effective_media_prior_type,
889
- "media_prior_type",
890
- ),
891
- (
892
- constants.IGNORED_PRIORS_RF,
893
- self.model_spec.effective_rf_prior_type,
894
- "rf_prior_type",
895
- ),
896
- ):
897
- ignored_custom_priors = []
898
- for prior in ignored_priors_dict.get(prior_type, []):
899
- self_prior = getattr(self.model_spec.prior, prior)
900
- default_prior = getattr(default_distribution, prior)
901
- if not prior_distribution.distributions_are_equal(
902
- self_prior, default_prior
903
- ):
904
- ignored_custom_priors.append(prior)
905
- if ignored_custom_priors:
906
- ignored_priors_str = ", ".join(ignored_custom_priors)
907
- warnings.warn(
908
- f"Custom prior(s) `{ignored_priors_str}` are ignored when"
909
- f' `{prior_type_name}` is set to "{prior_type}".'
910
- )
911
-
912
- def _validate_mroi_priors_non_revenue(self):
913
- """Validates mroi priors in the non-revenue outcome case."""
914
- if (
915
- self.input_data.kpi_type == constants.NON_REVENUE
916
- and self.input_data.revenue_per_kpi is None
917
- ):
918
- default_distribution = prior_distribution.PriorDistribution()
919
- if (
920
- self.n_media_channels > 0
921
- and (
922
- self.model_spec.effective_media_prior_type
923
- == constants.TREATMENT_PRIOR_TYPE_MROI
924
- )
925
- and prior_distribution.distributions_are_equal(
926
- self.model_spec.prior.mroi_m, default_distribution.mroi_m
927
- )
928
- ):
929
- raise ValueError(
930
- f"Custom priors should be set on `{constants.MROI_M}` when"
931
- ' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
932
- " kpi data is missing."
933
- )
934
- if (
935
- self.n_rf_channels > 0
936
- and (
937
- self.model_spec.effective_rf_prior_type
938
- == constants.TREATMENT_PRIOR_TYPE_MROI
939
- )
940
- and prior_distribution.distributions_are_equal(
941
- self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
942
- )
943
- ):
944
- raise ValueError(
945
- f"Custom priors should be set on `{constants.MROI_RF}` when"
946
- ' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
947
- " data is missing."
948
- )
949
-
950
- def _validate_roi_priors_non_revenue(self):
951
- """Validates roi priors in the non-revenue outcome case."""
952
- if (
953
- self.input_data.kpi_type == constants.NON_REVENUE
954
- and self.input_data.revenue_per_kpi is None
955
- ):
956
- default_distribution = prior_distribution.PriorDistribution()
957
- default_roi_m_used = (
958
- self.model_spec.effective_media_prior_type
959
- == constants.TREATMENT_PRIOR_TYPE_ROI
960
- and prior_distribution.distributions_are_equal(
961
- self.model_spec.prior.roi_m, default_distribution.roi_m
962
- )
963
- )
964
- default_roi_rf_used = (
965
- self.model_spec.effective_rf_prior_type
966
- == constants.TREATMENT_PRIOR_TYPE_ROI
967
- and prior_distribution.distributions_are_equal(
968
- self.model_spec.prior.roi_rf, default_distribution.roi_rf
969
- )
970
- )
971
- # If ROI priors are used with the default prior distribution for all paid
972
- # channels (media and RF), then use the "total paid media contribution
973
- # prior" procedure.
974
- if (
975
- (default_roi_m_used and default_roi_rf_used)
976
- or (self.n_media_channels == 0 and default_roi_rf_used)
977
- or (self.n_rf_channels == 0 and default_roi_m_used)
978
- ):
979
- self._set_total_media_contribution_prior = True
980
- warnings.warn(
981
- "Consider setting custom ROI priors, as kpi_type was specified as"
982
- " `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
983
- " total media contribution prior will be used with"
984
- f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
985
- " documentation available at "
986
- " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
987
- )
988
- elif self.n_media_channels > 0 and default_roi_m_used:
989
- raise ValueError(
990
- f"Custom priors should be set on `{constants.ROI_M}` when"
991
- ' `media_prior_type` is "roi", custom priors are assigned on'
992
- ' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
993
- " non-revenue and revenue per kpi data is missing."
994
- )
995
- elif self.n_rf_channels > 0 and default_roi_rf_used:
996
- raise ValueError(
997
- f"Custom priors should be set on `{constants.ROI_RF}` when"
998
- ' `rf_prior_type` is "roi", custom priors are assigned on'
999
- ' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
1000
- " non-revenue and revenue per kpi data is missing."
1001
- )
1002
-
1003
- def _check_for_negative_effects(self):
1004
- prior = self.model_spec.prior
1005
- if self.n_media_channels > 0:
1006
- _check_for_negative_effect(prior.roi_m, self.media_effects_dist)
1007
- _check_for_negative_effect(prior.mroi_m, self.media_effects_dist)
1008
- if self.n_rf_channels > 0:
1009
- _check_for_negative_effect(prior.roi_rf, self.media_effects_dist)
1010
- _check_for_negative_effect(prior.mroi_rf, self.media_effects_dist)
1011
-
1012
- def _validate_geo_invariants(self):
1013
- """Validates non-national model invariants."""
1014
- if self.is_national:
1015
- return
1016
-
1017
- if self.input_data.controls is not None:
1018
- self._check_if_no_geo_variation(
1019
- self.controls_scaled,
1020
- constants.CONTROLS,
1021
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
1022
- )
1023
- if self.input_data.non_media_treatments is not None:
1024
- self._check_if_no_geo_variation(
1025
- self.non_media_treatments_normalized,
1026
- constants.NON_MEDIA_TREATMENTS,
1027
- self.input_data.non_media_treatments.coords[
1028
- constants.NON_MEDIA_CHANNEL
1029
- ].values,
1030
- )
1031
- if self.input_data.media is not None:
1032
- self._check_if_no_geo_variation(
1033
- self.media_tensors.media_scaled,
1034
- constants.MEDIA,
1035
- self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
1036
- )
1037
- if self.input_data.reach is not None:
1038
- self._check_if_no_geo_variation(
1039
- self.rf_tensors.reach_scaled,
1040
- constants.REACH,
1041
- self.input_data.reach.coords[constants.RF_CHANNEL].values,
1042
- )
1043
- if self.input_data.organic_media is not None:
1044
- self._check_if_no_geo_variation(
1045
- self.organic_media_tensors.organic_media_scaled,
1046
- "organic_media",
1047
- self.input_data.organic_media.coords[
1048
- constants.ORGANIC_MEDIA_CHANNEL
1049
- ].values,
1050
- )
1051
- if self.input_data.organic_reach is not None:
1052
- self._check_if_no_geo_variation(
1053
- self.organic_rf_tensors.organic_reach_scaled,
1054
- "organic_reach",
1055
- self.input_data.organic_reach.coords[
1056
- constants.ORGANIC_RF_CHANNEL
1057
- ].values,
1058
- )
1059
-
1060
- def _check_if_no_geo_variation(
1061
- self,
1062
- scaled_data: backend.Tensor,
1063
- data_name: str,
1064
- data_dims: Sequence[str],
1065
- epsilon=1e-4,
1066
- ):
1067
- """Raise an error if `n_knots == n_time` and data lacks geo variation."""
1068
-
1069
- # Result shape: [n, d], where d is the number of axes of condition.
1070
- col_idx_full = backend.get_indices_where(
1071
- backend.reduce_std(scaled_data, axis=0) < epsilon
1072
- )[:, 1]
1073
- col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
1074
- # We use the shape of scaled_data (instead of `n_time`) because the data may
1075
- # be padded to account for lagged effects.
1076
- data_n_time = scaled_data.shape[1]
1077
- mask = backend.equal(counts, data_n_time)
1078
- col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
1079
- dims_bad = backend.gather(data_dims, col_idx_bad)
1080
-
1081
- if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
1082
- raise ValueError(
1083
- f"The following {data_name} variables do not vary across geos, making"
1084
- f" a model with n_knots=n_time unidentifiable: {dims_bad}. This can"
1085
- " lead to poor model convergence. Since these variables only vary"
1086
- " across time and not across geo, they are collinear with time and"
1087
- " redundant in a model with a parameter for each time period. To"
1088
- " address this, you can either: (1) decrease the number of knots"
1089
- " (n_knots < n_time), or (2) drop the listed variables that do not"
1090
- " vary across geos."
1091
- )
1092
-
1093
- def _validate_time_invariants(self):
1094
- """Validates model time invariants."""
1095
- if self.input_data.controls is not None:
1096
- self._check_if_no_time_variation(
1097
- self.controls_scaled,
1098
- constants.CONTROLS,
1099
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
1100
- )
1101
- if self.input_data.non_media_treatments is not None:
1102
- self._check_if_no_time_variation(
1103
- self.non_media_treatments_normalized,
1104
- constants.NON_MEDIA_TREATMENTS,
1105
- self.input_data.non_media_treatments.coords[
1106
- constants.NON_MEDIA_CHANNEL
1107
- ].values,
1108
- )
1109
- if self.input_data.media is not None:
1110
- self._check_if_no_time_variation(
1111
- self.media_tensors.media_scaled,
1112
- constants.MEDIA,
1113
- self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
1114
- )
1115
- if self.input_data.reach is not None:
1116
- self._check_if_no_time_variation(
1117
- self.rf_tensors.reach_scaled,
1118
- constants.REACH,
1119
- self.input_data.reach.coords[constants.RF_CHANNEL].values,
1120
- )
1121
- if self.input_data.organic_media is not None:
1122
- self._check_if_no_time_variation(
1123
- self.organic_media_tensors.organic_media_scaled,
1124
- constants.ORGANIC_MEDIA,
1125
- self.input_data.organic_media.coords[
1126
- constants.ORGANIC_MEDIA_CHANNEL
1127
- ].values,
1128
- )
1129
- if self.input_data.organic_reach is not None:
1130
- self._check_if_no_time_variation(
1131
- self.organic_rf_tensors.organic_reach_scaled,
1132
- constants.ORGANIC_REACH,
1133
- self.input_data.organic_reach.coords[
1134
- constants.ORGANIC_RF_CHANNEL
1135
- ].values,
1136
- )
1137
-
1138
- def _check_if_no_time_variation(
1139
- self,
1140
- scaled_data: backend.Tensor,
1141
- data_name: str,
1142
- data_dims: Sequence[str],
1143
- epsilon=1e-4,
1144
- ):
1145
- """Raise an error if data lacks time variation."""
1146
-
1147
- # Result shape: [n, d], where d is the number of axes of condition.
1148
- col_idx_full = backend.get_indices_where(
1149
- backend.reduce_std(scaled_data, axis=1) < epsilon
1150
- )[:, 1]
1151
- col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
1152
- mask = backend.equal(counts, self.n_geos)
1153
- col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
1154
- dims_bad = backend.gather(data_dims, col_idx_bad)
1155
- if col_idx_bad.shape[0]:
1156
- if self.is_national:
1157
- raise ValueError(
1158
- f"The following {data_name} variables do not vary across time,"
1159
- " which is equivalent to no signal at all in a national model:"
1160
- f" {dims_bad}. This can lead to poor model convergence. To address"
1161
- " this, drop the listed variables that do not vary across time."
1162
- )
1163
- else:
1164
- raise ValueError(
1165
- f"The following {data_name} variables do not vary across time,"
1166
- f" making a model with geo main effects unidentifiable: {dims_bad}."
1167
- " This can lead to poor model convergence. Since these variables"
1168
- " only vary across geo and not across time, they are collinear"
1169
- " with geo and redundant in a model with geo main effects. To"
1170
- " address this, drop the listed variables that do not vary across"
1171
- " time."
1172
- )
1173
-
1174
- def _validate_kpi_transformer(self):
1175
- """Validates the KPI transformer."""
563
+ def _validate_kpi_variability(self):
564
+ """Validates the KPI variability."""
1176
565
  if self.eda_engine.kpi_has_variability:
1177
566
  return
1178
567
  kpi = self.eda_engine.kpi_scaled_da.name
@@ -1227,6 +616,7 @@ class Meridian:
1227
616
  f' "{self.model_spec.non_media_treatments_prior_type}".'
1228
617
  )
1229
618
 
619
+ # TODO: Remove this method.
1230
620
  def linear_predictor_counterfactual_difference_media(
1231
621
  self,
1232
622
  media_transformed: backend.Tensor,
@@ -1255,21 +645,24 @@ class Meridian:
1255
645
  The linear predictor difference between the treatment variable and its
1256
646
  counterfactual.
1257
647
  """
1258
- if self.media_tensors.prior_media_scaled_counterfactual is None:
1259
- return media_transformed
1260
- media_transformed_counterfactual = self.adstock_hill_media(
1261
- self.media_tensors.prior_media_scaled_counterfactual,
1262
- alpha_m,
1263
- ec_m,
1264
- slope_m,
1265
- decay_functions=self.adstock_decay_spec.media,
648
+ warnings.warn(
649
+ "Meridian.linear_predictor_counterfactual_difference_media() is"
650
+ " deprecated and will be removed in a future version. Use "
651
+ "`ModelEquations.linear_predictor_counterfactual_difference_media()`"
652
+ " instead.",
653
+ DeprecationWarning,
654
+ stacklevel=2,
1266
655
  )
1267
- # Absolute values is needed because the difference is negative for mROI
1268
- # priors and positive for ROI and contribution priors.
1269
- return backend.absolute(
1270
- media_transformed - media_transformed_counterfactual
656
+ return (
657
+ self.model_equations.linear_predictor_counterfactual_difference_media(
658
+ media_transformed=media_transformed,
659
+ alpha_m=alpha_m,
660
+ ec_m=ec_m,
661
+ slope_m=slope_m,
662
+ )
1271
663
  )
1272
664
 
665
+ # TODO: Remove this method.
1273
666
  def linear_predictor_counterfactual_difference_rf(
1274
667
  self,
1275
668
  rf_transformed: backend.Tensor,
@@ -1298,20 +691,21 @@ class Meridian:
1298
691
  The linear predictor difference between the treatment variable and its
1299
692
  counterfactual.
1300
693
  """
1301
- if self.rf_tensors.prior_reach_scaled_counterfactual is None:
1302
- return rf_transformed
1303
- rf_transformed_counterfactual = self.adstock_hill_rf(
1304
- reach=self.rf_tensors.prior_reach_scaled_counterfactual,
1305
- frequency=self.rf_tensors.frequency,
1306
- alpha=alpha_rf,
1307
- ec=ec_rf,
1308
- slope=slope_rf,
1309
- decay_functions=self.adstock_decay_spec.rf,
694
+ warnings.warn(
695
+ "Meridian.linear_predictor_counterfactual_difference_rf() is deprecated"
696
+ " and will be removed in a future version. Use `ModelEquations."
697
+ "linear_predictor_counterfactual_difference_rf()` instead.",
698
+ DeprecationWarning,
699
+ stacklevel=2,
700
+ )
701
+ return self.model_equations.linear_predictor_counterfactual_difference_rf(
702
+ rf_transformed=rf_transformed,
703
+ alpha_rf=alpha_rf,
704
+ ec_rf=ec_rf,
705
+ slope_rf=slope_rf,
1310
706
  )
1311
- # Absolute values is needed because the difference is negative for mROI
1312
- # priors and positive for ROI and contribution priors.
1313
- return backend.absolute(rf_transformed - rf_transformed_counterfactual)
1314
707
 
708
+ # TODO: Remove this method.
1315
709
  def calculate_beta_x(
1316
710
  self,
1317
711
  is_non_media: bool,
@@ -1357,45 +751,22 @@ class Meridian:
1357
751
  The coefficient mean parameter of the treatment variable, which has
1358
752
  dimension equal to the number of treatment channels..
1359
753
  """
1360
- if is_non_media:
1361
- random_effects_normal = True
1362
- else:
1363
- random_effects_normal = (
1364
- self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1365
- )
1366
- if self.revenue_per_kpi is None:
1367
- revenue_per_kpi = backend.ones(
1368
- [self.n_geos, self.n_times], dtype=backend.float32
1369
- )
1370
- else:
1371
- revenue_per_kpi = self.revenue_per_kpi
1372
- incremental_outcome_gx_over_beta_gx = backend.einsum(
1373
- "...gtx,gt,g,->...gx",
1374
- linear_predictor_counterfactual_difference,
1375
- revenue_per_kpi,
1376
- self.population,
1377
- self.kpi_transformer.population_scaled_stdev,
754
+ warnings.warn(
755
+ "Meridian.calculate_beta_x() is deprecated and will be removed in a"
756
+ " future version. Use `ModelEquations.calculate_beta_x()`"
757
+ " instead.",
758
+ DeprecationWarning,
759
+ stacklevel=2,
1378
760
  )
1379
- if random_effects_normal:
1380
- numerator_term_x = backend.einsum(
1381
- "...gx,...gx,...x->...x",
1382
- incremental_outcome_gx_over_beta_gx,
1383
- beta_gx_dev,
1384
- eta_x,
1385
- )
1386
- denominator_term_x = backend.einsum(
1387
- "...gx->...x", incremental_outcome_gx_over_beta_gx
1388
- )
1389
- return (incremental_outcome_x - numerator_term_x) / denominator_term_x
1390
- # For log-normal random effects, beta_x and eta_x are not mean & std.
1391
- # The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
1392
- denominator_term_x = backend.einsum(
1393
- "...gx,...gx->...x",
1394
- incremental_outcome_gx_over_beta_gx,
1395
- backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
761
+ return self.model_equations.calculate_beta_x(
762
+ is_non_media=is_non_media,
763
+ incremental_outcome_x=incremental_outcome_x,
764
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
765
+ eta_x=eta_x,
766
+ beta_gx_dev=beta_gx_dev,
1396
767
  )
1397
- return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)
1398
768
 
769
+ # TODO: Remove this method.
1399
770
  def adstock_hill_media(
1400
771
  self,
1401
772
  media: backend.Tensor, # pylint: disable=redefined-outer-name
@@ -1426,34 +797,23 @@ class Meridian:
1426
797
  Tensor with dimensions `[..., n_geos, n_times, n_media_channels]`
1427
798
  representing Adstock and Hill-transformed media.
1428
799
  """
1429
- if n_times_output is None and (media.shape[1] == self.n_media_times):
1430
- n_times_output = self.n_times
1431
- elif n_times_output is None:
1432
- raise ValueError(
1433
- "n_times_output is required. This argument is only optional when "
1434
- "`media` has a number of time periods equal to `self.n_media_times`."
1435
- )
1436
- adstock_transformer = adstock_hill.AdstockTransformer(
1437
- alpha=alpha,
1438
- max_lag=self.model_spec.max_lag,
1439
- n_times_output=n_times_output,
1440
- decay_functions=decay_functions,
800
+ warnings.warn(
801
+ "Meridian.adstock_hill_media() is deprecated and will be removed in a"
802
+ " future version. Use `ModelEquations.adstock_hill_media()`"
803
+ " instead.",
804
+ DeprecationWarning,
805
+ stacklevel=2,
1441
806
  )
1442
- hill_transformer = adstock_hill.HillTransformer(
807
+ return self.model_equations.adstock_hill_media(
808
+ media=media,
809
+ alpha=alpha,
1443
810
  ec=ec,
1444
811
  slope=slope,
812
+ decay_functions=decay_functions,
813
+ n_times_output=n_times_output,
1445
814
  )
1446
- transformers_list = (
1447
- [hill_transformer, adstock_transformer]
1448
- if self.model_spec.hill_before_adstock
1449
- else [adstock_transformer, hill_transformer]
1450
- )
1451
-
1452
- media_out = media
1453
- for transformer in transformers_list:
1454
- media_out = transformer.forward(media_out)
1455
- return media_out
1456
815
 
816
+ # TODO: Remove this method.
1457
817
  def adstock_hill_rf(
1458
818
  self,
1459
819
  reach: backend.Tensor,
@@ -1485,27 +845,22 @@ class Meridian:
1485
845
  Tensor with dimensions `[..., n_geos, n_times, n_rf_channels]`
1486
846
  representing Hill and Adstock-transformed RF.
1487
847
  """
1488
- if n_times_output is None and (reach.shape[1] == self.n_media_times):
1489
- n_times_output = self.n_times
1490
- elif n_times_output is None:
1491
- raise ValueError(
1492
- "n_times_output is required. This argument is only optional when "
1493
- "`reach` has a number of time periods equal to `self.n_media_times`."
1494
- )
1495
- hill_transformer = adstock_hill.HillTransformer(
1496
- ec=ec,
1497
- slope=slope,
848
+ warnings.warn(
849
+ "Meridian.adstock_hill_rf() is deprecated and will be removed in a"
850
+ " future version. Use `ModelEquations.adstock_hill_rf()`"
851
+ " instead.",
852
+ DeprecationWarning,
853
+ stacklevel=2,
1498
854
  )
1499
- adstock_transformer = adstock_hill.AdstockTransformer(
855
+ return self.model_equations.adstock_hill_rf(
856
+ reach=reach,
857
+ frequency=frequency,
1500
858
  alpha=alpha,
1501
- max_lag=self.model_spec.max_lag,
1502
- n_times_output=n_times_output,
859
+ ec=ec,
860
+ slope=slope,
1503
861
  decay_functions=decay_functions,
862
+ n_times_output=n_times_output,
1504
863
  )
1505
- adj_frequency = hill_transformer.forward(frequency)
1506
- rf_out = adstock_transformer.forward(reach * adj_frequency)
1507
-
1508
- return rf_out
1509
864
 
1510
865
  def populate_cached_properties(self):
1511
866
  """Eagerly activates all cached properties.
@@ -1515,6 +870,7 @@ class Meridian:
1515
870
  internal state mutations are problematic, and so this method freezes the
1516
871
  object's states before the computation graph is created.
1517
872
  """
873
+ self._model_context.populate_cached_properties()
1518
874
  cls = self.__class__
1519
875
  # "Freeze" all @cached_property attributes by simply accessing them (with
1520
876
  # `getattr()`).
@@ -1526,66 +882,30 @@ class Meridian:
1526
882
  for attr in cached_properties:
1527
883
  _ = getattr(self, attr)
1528
884
 
885
+ # TODO: Remove this method.
1529
886
  def create_inference_data_coords(
1530
887
  self, n_chains: int, n_draws: int
1531
888
  ) -> Mapping[str, np.ndarray | Sequence[str]]:
1532
889
  """Creates data coordinates for inference data."""
1533
- media_channel_names = (
1534
- self.input_data.media_channel
1535
- if self.input_data.media_channel is not None
1536
- else np.array([])
1537
- )
1538
- rf_channel_names = (
1539
- self.input_data.rf_channel
1540
- if self.input_data.rf_channel is not None
1541
- else np.array([])
1542
- )
1543
- organic_media_channel_names = (
1544
- self.input_data.organic_media_channel
1545
- if self.input_data.organic_media_channel is not None
1546
- else np.array([])
1547
- )
1548
- organic_rf_channel_names = (
1549
- self.input_data.organic_rf_channel
1550
- if self.input_data.organic_rf_channel is not None
1551
- else np.array([])
1552
- )
1553
- non_media_channel_names = (
1554
- self.input_data.non_media_channel
1555
- if self.input_data.non_media_channel is not None
1556
- else np.array([])
890
+ warnings.warn(
891
+ "Meridian.create_inference_data_coords() is deprecated and will be"
892
+ " removed in a future version. Use"
893
+ " `ModelContext.create_inference_data_coords()` instead.",
894
+ DeprecationWarning,
895
+ stacklevel=2,
1557
896
  )
1558
- control_variable_names = (
1559
- self.input_data.control_variable
1560
- if self.input_data.control_variable is not None
1561
- else np.array([])
1562
- )
1563
- return {
1564
- constants.CHAIN: np.arange(n_chains),
1565
- constants.DRAW: np.arange(n_draws),
1566
- constants.GEO: self.input_data.geo,
1567
- constants.TIME: self.input_data.time,
1568
- constants.MEDIA_TIME: self.input_data.media_time,
1569
- constants.KNOTS: np.arange(self.knot_info.n_knots),
1570
- constants.CONTROL_VARIABLE: control_variable_names,
1571
- constants.NON_MEDIA_CHANNEL: non_media_channel_names,
1572
- constants.MEDIA_CHANNEL: media_channel_names,
1573
- constants.RF_CHANNEL: rf_channel_names,
1574
- constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
1575
- constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
1576
- }
897
+ return self._model_context.create_inference_data_coords(n_chains, n_draws)
1577
898
 
899
+ # TODO: Remove this method.
1578
900
  def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
1579
- inference_dims = dict(constants.INFERENCE_DIMS)
1580
- if self.unique_sigma_for_each_geo:
1581
- inference_dims[constants.SIGMA] = [constants.GEO]
1582
- else:
1583
- inference_dims[constants.SIGMA] = []
1584
-
1585
- return {
1586
- param: [constants.CHAIN, constants.DRAW] + list(dims)
1587
- for param, dims in inference_dims.items()
1588
- }
901
+ warnings.warn(
902
+ "Meridian.create_inference_data_dims() is deprecated and will be"
903
+ " removed in a future version. Use"
904
+ " `ModelContext.create_inference_data_dims()` instead.",
905
+ DeprecationWarning,
906
+ stacklevel=2,
907
+ )
908
+ return self._model_context.create_inference_data_dims()
1589
909
 
1590
910
  def sample_prior(self, n_draws: int, seed: int | None = None):
1591
911
  """Draws samples from the prior distributions.
@@ -1598,14 +918,25 @@ class Meridian:
1598
918
  see [PRNGS and seeds]
1599
919
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
1600
920
  """
1601
- self.prior_sampler_callable(n_draws=n_draws, seed=seed)
921
+ prior_draws = self.prior_sampler_callable(n_draws=n_draws, seed=seed)
922
+ # Create Arviz InferenceData for prior draws.
923
+ prior_coords = self._model_context.create_inference_data_coords(1, n_draws)
924
+ prior_dims = self._model_context.create_inference_data_dims()
925
+ prior_inference_data = az.convert_to_inference_data(
926
+ prior_draws,
927
+ coords=prior_coords,
928
+ dims=prior_dims,
929
+ group=constants.PRIOR,
930
+ )
931
+ self.inference_data.extend(prior_inference_data, join="right")
1602
932
 
1603
933
  def _run_model_fitting_guardrail(self):
1604
934
  """Raises an error if the model has critical EDA issues."""
1605
935
  error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
1606
936
  collections.defaultdict(list)
1607
937
  )
1608
- for outcome in self.eda_outcomes:
938
+ for field in dataclasses.fields(self.eda_outcomes):
939
+ outcome = getattr(self.eda_outcomes, field.name)
1609
940
  error_findings = [
1610
941
  finding
1611
942
  for finding in outcome.findings
@@ -1709,7 +1040,7 @@ class Meridian:
1709
1040
  """
1710
1041
  self._run_model_fitting_guardrail()
1711
1042
 
1712
- self.posterior_sampler_callable(
1043
+ posterior_inference_data = self.posterior_sampler_callable(
1713
1044
  n_chains=n_chains,
1714
1045
  n_adapt=n_adapt,
1715
1046
  n_burnin=n_burnin,
@@ -1724,6 +1055,7 @@ class Meridian:
1724
1055
  seed=seed,
1725
1056
  **pins,
1726
1057
  )
1058
+ self.inference_data.extend(posterior_inference_data, join="right")
1727
1059
 
1728
1060
 
1729
1061
  def save_mmm(mmm: Meridian, file_path: str):
@@ -1737,6 +1069,15 @@ def save_mmm(mmm: Meridian, file_path: str):
1737
1069
  mmm: Model object to save.
1738
1070
  file_path: File path to save a pickled model object.
1739
1071
  """
1072
+ warnings.warn(
1073
+ "save_mmm is deprecated and will be removed in a future release. Please"
1074
+ " use `schema.serde.meridian_serde.save_meridian` instead. See"
1075
+ " https://developers.google.com/meridian/docs/user-guide/saving-model-object"
1076
+ " for details.",
1077
+ DeprecationWarning,
1078
+ stacklevel=2,
1079
+ )
1080
+
1740
1081
  if not os.path.exists(os.path.dirname(file_path)):
1741
1082
  os.makedirs(os.path.dirname(file_path))
1742
1083
 
@@ -1760,6 +1101,15 @@ def load_mmm(file_path: str) -> Meridian:
1760
1101
  Raises:
1761
1102
  FileNotFoundError: If `file_path` does not exist.
1762
1103
  """
1104
+ warnings.warn(
1105
+ "load_mmm is deprecated and will be removed in a future release. Please"
1106
+ " use `meridian.schema.serde.meridian_serde.load_meridian` instead. See"
1107
+ " https://developers.google.com/meridian/docs/user-guide/saving-model-object"
1108
+ " for details.",
1109
+ DeprecationWarning,
1110
+ stacklevel=2,
1111
+ )
1112
+
1763
1113
  try:
1764
1114
  with open(file_path, "rb") as f:
1765
1115
  mmm = joblib.load(f)