google-meridian 1.3.2__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/METADATA +8 -4
  2. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/RECORD +49 -17
  3. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
  4. meridian/analysis/summarizer.py +7 -2
  5. meridian/analysis/test_utils.py +934 -485
  6. meridian/analysis/visualizer.py +10 -6
  7. meridian/constants.py +1 -0
  8. meridian/data/test_utils.py +82 -10
  9. meridian/model/__init__.py +2 -0
  10. meridian/model/context.py +925 -0
  11. meridian/model/eda/constants.py +1 -0
  12. meridian/model/equations.py +418 -0
  13. meridian/model/knots.py +58 -47
  14. meridian/model/model.py +93 -792
  15. meridian/version.py +1 -1
  16. scenarioplanner/__init__.py +42 -0
  17. scenarioplanner/converters/__init__.py +25 -0
  18. scenarioplanner/converters/dataframe/__init__.py +28 -0
  19. scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
  20. scenarioplanner/converters/dataframe/common.py +71 -0
  21. scenarioplanner/converters/dataframe/constants.py +137 -0
  22. scenarioplanner/converters/dataframe/converter.py +42 -0
  23. scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
  24. scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
  25. scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
  26. scenarioplanner/converters/mmm.py +743 -0
  27. scenarioplanner/converters/mmm_converter.py +58 -0
  28. scenarioplanner/converters/sheets.py +156 -0
  29. scenarioplanner/converters/test_data.py +714 -0
  30. scenarioplanner/linkingapi/__init__.py +47 -0
  31. scenarioplanner/linkingapi/constants.py +27 -0
  32. scenarioplanner/linkingapi/url_generator.py +131 -0
  33. scenarioplanner/mmm_ui_proto_generator.py +354 -0
  34. schema/__init__.py +5 -2
  35. schema/mmm_proto_generator.py +71 -0
  36. schema/model_consumer.py +133 -0
  37. schema/processors/__init__.py +77 -0
  38. schema/processors/budget_optimization_processor.py +832 -0
  39. schema/processors/common.py +64 -0
  40. schema/processors/marketing_processor.py +1136 -0
  41. schema/processors/model_fit_processor.py +367 -0
  42. schema/processors/model_kernel_processor.py +117 -0
  43. schema/processors/model_processor.py +412 -0
  44. schema/processors/reach_frequency_optimization_processor.py +584 -0
  45. schema/test_data.py +380 -0
  46. schema/utils/__init__.py +1 -0
  47. schema/utils/date_range_bucketing.py +117 -0
  48. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
  49. {google_meridian-1.3.2.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
meridian/model/model.py CHANGED
@@ -17,7 +17,6 @@
17
17
  import collections
18
18
  from collections.abc import Mapping, Sequence
19
19
  import functools
20
- import numbers
21
20
  import os
22
21
  import warnings
23
22
 
@@ -28,6 +27,8 @@ from meridian import constants
28
27
  from meridian.data import input_data as data
29
28
  from meridian.data import time_coordinates as tc
30
29
  from meridian.model import adstock_hill
30
+ from meridian.model import context
31
+ from meridian.model import equations
31
32
  from meridian.model import knots
32
33
  from meridian.model import media
33
34
  from meridian.model import posterior_sampler
@@ -76,27 +77,15 @@ def _warn_setting_national_args(**kwargs):
76
77
  )
77
78
 
78
79
 
79
- def _check_for_negative_effect(
80
- dist: backend.tfd.Distribution, media_effects_dist: str
81
- ):
82
- """Checks for negative effect in the model."""
83
- if (
84
- media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
85
- and np.any(dist.cdf(0)) > 0
86
- ):
87
- raise ValueError(
88
- "Media priors must have non-negative support when"
89
- f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
90
- f" in {dist.name}."
91
- )
92
-
93
-
94
80
  class Meridian:
95
81
  """Contains the main functionality for fitting the Meridian MMM model.
96
82
 
97
83
  Attributes:
98
84
  input_data: An `InputData` object containing the input data for the model.
99
85
  model_spec: A `ModelSpec` object containing the model specification.
86
+ model_context: A `ModelContext` object containing the model context.
87
+ equations: A `ModelEquations` object containing stateless mathematical
88
+ functions and utilities for Meridian MMM.
100
89
  inference_data: A _mutable_ `arviz.InferenceData` object containing the
101
90
  resulting data from fitting the model.
102
91
  eda_engine: An `EDAEngine` object containing the EDA engine.
@@ -168,15 +157,17 @@ class Meridian:
168
157
  ) = None, # for deserializer use only
169
158
  eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
170
159
  ):
171
- self._input_data = input_data
172
- self._model_spec = model_spec if model_spec else spec.ModelSpec()
173
160
  self._inference_data = (
174
161
  inference_data if inference_data else az.InferenceData()
175
162
  )
163
+ self._model_context = context.ModelContext(
164
+ input_data=input_data,
165
+ model_spec=model_spec if model_spec else spec.ModelSpec(),
166
+ )
167
+ self._equations = equations.ModelEquations(self._model_context)
176
168
 
177
169
  self._eda_spec = eda_spec
178
170
 
179
- self._validate_data_dependent_model_spec()
180
171
  self._validate_injected_inference_data()
181
172
 
182
173
  if self.is_national:
@@ -184,22 +175,23 @@ class Meridian:
184
175
  media_effects_dist=self.model_spec.media_effects_dist,
185
176
  unique_sigma_for_each_geo=self.model_spec.unique_sigma_for_each_geo,
186
177
  )
187
- self._warn_setting_ignored_priors()
188
- self._set_total_media_contribution_prior = False
189
- self._validate_mroi_priors_non_revenue()
190
- self._validate_roi_priors_non_revenue()
191
- self._check_for_negative_effects()
192
- self._validate_geo_invariants()
193
- self._validate_time_invariants()
194
- self._validate_kpi_transformer()
178
+ self._validate_kpi_variability()
195
179
 
196
180
  @property
197
181
  def input_data(self) -> data.InputData:
198
- return self._input_data
182
+ return self._model_context.input_data
199
183
 
200
184
  @property
201
185
  def model_spec(self) -> spec.ModelSpec:
202
- return self._model_spec
186
+ return self._model_context.model_spec
187
+
188
+ @property
189
+ def model_context(self) -> context.ModelContext:
190
+ return self._model_context
191
+
192
+ @property
193
+ def equations(self) -> equations.ModelEquations:
194
+ return self._equations
203
195
 
204
196
  @property
205
197
  def inference_data(self) -> az.InferenceData:
@@ -219,176 +211,111 @@ class Meridian:
219
211
 
220
212
  @functools.cached_property
221
213
  def media_tensors(self) -> media.MediaTensors:
222
- return media.build_media_tensors(self.input_data, self.model_spec)
214
+ return self._model_context.media_tensors
223
215
 
224
216
  @functools.cached_property
225
217
  def rf_tensors(self) -> media.RfTensors:
226
- return media.build_rf_tensors(self.input_data, self.model_spec)
218
+ return self._model_context.rf_tensors
227
219
 
228
220
  @functools.cached_property
229
221
  def organic_media_tensors(self) -> media.OrganicMediaTensors:
230
- return media.build_organic_media_tensors(self.input_data)
222
+ return self._model_context.organic_media_tensors
231
223
 
232
224
  @functools.cached_property
233
225
  def organic_rf_tensors(self) -> media.OrganicRfTensors:
234
- return media.build_organic_rf_tensors(self.input_data)
226
+ return self._model_context.organic_rf_tensors
235
227
 
236
228
  @functools.cached_property
237
229
  def kpi(self) -> backend.Tensor:
238
- return backend.to_tensor(self.input_data.kpi, dtype=backend.float32)
230
+ return self._model_context.kpi
239
231
 
240
232
  @functools.cached_property
241
233
  def revenue_per_kpi(self) -> backend.Tensor | None:
242
- if self.input_data.revenue_per_kpi is None:
243
- return None
244
- return backend.to_tensor(
245
- self.input_data.revenue_per_kpi, dtype=backend.float32
246
- )
234
+ return self._model_context.revenue_per_kpi
247
235
 
248
236
  @functools.cached_property
249
237
  def controls(self) -> backend.Tensor | None:
250
- if self.input_data.controls is None:
251
- return None
252
- return backend.to_tensor(self.input_data.controls, dtype=backend.float32)
238
+ return self._model_context.controls
253
239
 
254
240
  @functools.cached_property
255
241
  def non_media_treatments(self) -> backend.Tensor | None:
256
- if self.input_data.non_media_treatments is None:
257
- return None
258
- return backend.to_tensor(
259
- self.input_data.non_media_treatments, dtype=backend.float32
260
- )
242
+ return self._model_context.non_media_treatments
261
243
 
262
244
  @functools.cached_property
263
245
  def population(self) -> backend.Tensor:
264
- return backend.to_tensor(self.input_data.population, dtype=backend.float32)
246
+ return self._model_context.population
265
247
 
266
248
  @functools.cached_property
267
249
  def total_spend(self) -> backend.Tensor:
268
- return backend.to_tensor(
269
- self.input_data.get_total_spend(), dtype=backend.float32
270
- )
250
+ return self._model_context.total_spend
271
251
 
272
252
  @functools.cached_property
273
253
  def total_outcome(self) -> backend.Tensor:
274
- return backend.to_tensor(
275
- self.input_data.get_total_outcome(), dtype=backend.float32
276
- )
254
+ return self._model_context.total_outcome
277
255
 
278
256
  @property
279
257
  def n_geos(self) -> int:
280
- return len(self.input_data.geo)
258
+ return self._model_context.n_geos
281
259
 
282
260
  @property
283
261
  def n_media_channels(self) -> int:
284
- if self.input_data.media_channel is None:
285
- return 0
286
- return len(self.input_data.media_channel)
262
+ return self._model_context.n_media_channels
287
263
 
288
264
  @property
289
265
  def n_rf_channels(self) -> int:
290
- if self.input_data.rf_channel is None:
291
- return 0
292
- return len(self.input_data.rf_channel)
266
+ return self._model_context.n_rf_channels
293
267
 
294
268
  @property
295
269
  def n_organic_media_channels(self) -> int:
296
- if self.input_data.organic_media_channel is None:
297
- return 0
298
- return len(self.input_data.organic_media_channel)
270
+ return self._model_context.n_organic_media_channels
299
271
 
300
272
  @property
301
273
  def n_organic_rf_channels(self) -> int:
302
- if self.input_data.organic_rf_channel is None:
303
- return 0
304
- return len(self.input_data.organic_rf_channel)
274
+ return self._model_context.n_organic_rf_channels
305
275
 
306
276
  @property
307
277
  def n_controls(self) -> int:
308
- if self.input_data.control_variable is None:
309
- return 0
310
- return len(self.input_data.control_variable)
278
+ return self._model_context.n_controls
311
279
 
312
280
  @property
313
281
  def n_non_media_channels(self) -> int:
314
- if self.input_data.non_media_channel is None:
315
- return 0
316
- return len(self.input_data.non_media_channel)
282
+ return self._model_context.n_non_media_channels
317
283
 
318
284
  @property
319
285
  def n_times(self) -> int:
320
- return len(self.input_data.time)
286
+ return self._model_context.n_times
321
287
 
322
288
  @property
323
289
  def n_media_times(self) -> int:
324
- return len(self.input_data.media_time)
290
+ return self._model_context.n_media_times
325
291
 
326
292
  @property
327
293
  def is_national(self) -> bool:
328
- return self.n_geos == 1
294
+ return self._model_context.is_national
329
295
 
330
296
  @functools.cached_property
331
297
  def knot_info(self) -> knots.KnotInfo:
332
- return knots.get_knot_info(
333
- n_times=self.n_times,
334
- knots=self.model_spec.knots,
335
- enable_aks=self.model_spec.enable_aks,
336
- data=self.input_data,
337
- is_national=self.is_national,
338
- )
298
+ return self._model_context.knot_info
339
299
 
340
300
  @functools.cached_property
341
301
  def controls_transformer(
342
302
  self,
343
303
  ) -> transformers.CenteringAndScalingTransformer | None:
344
- """Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
345
- if self.controls is None:
346
- return None
347
-
348
- if self.model_spec.control_population_scaling_id is not None:
349
- controls_population_scaling_id = backend.to_tensor(
350
- self.model_spec.control_population_scaling_id, dtype=backend.bool_
351
- )
352
- else:
353
- controls_population_scaling_id = None
354
-
355
- return transformers.CenteringAndScalingTransformer(
356
- tensor=self.controls,
357
- population=self.population,
358
- population_scaling_id=controls_population_scaling_id,
359
- )
304
+ return self._model_context.controls_transformer
360
305
 
361
306
  @functools.cached_property
362
307
  def non_media_transformer(
363
308
  self,
364
309
  ) -> transformers.CenteringAndScalingTransformer | None:
365
- """Returns a `CenteringAndScalingTransformer` for non-media treatments."""
366
- if self.non_media_treatments is None:
367
- return None
368
- if self.model_spec.non_media_population_scaling_id is not None:
369
- non_media_population_scaling_id = backend.to_tensor(
370
- self.model_spec.non_media_population_scaling_id, dtype=backend.bool_
371
- )
372
- else:
373
- non_media_population_scaling_id = None
374
-
375
- return transformers.CenteringAndScalingTransformer(
376
- tensor=self.non_media_treatments,
377
- population=self.population,
378
- population_scaling_id=non_media_population_scaling_id,
379
- )
310
+ return self._model_context.non_media_transformer
380
311
 
381
312
  @functools.cached_property
382
313
  def kpi_transformer(self) -> transformers.KpiTransformer:
383
- return transformers.KpiTransformer(self.kpi, self.population)
314
+ return self._model_context.kpi_transformer
384
315
 
385
316
  @functools.cached_property
386
317
  def controls_scaled(self) -> backend.Tensor | None:
387
- if self.controls is not None:
388
- # If `controls` is defined, then `controls_transformer` is also defined.
389
- return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
390
- else:
391
- return None
318
+ return self._model_context.controls_scaled
392
319
 
393
320
  @functools.cached_property
394
321
  def non_media_treatments_normalized(self) -> backend.Tensor | None:
@@ -398,117 +325,38 @@ class Meridian:
398
325
  `non_media_population_scaling_id` is `True`) and normalized by centering and
399
326
  scaling with means and standard deviations.
400
327
  """
401
- if self.non_media_transformer is not None:
402
- return self.non_media_transformer.forward(
403
- self.non_media_treatments
404
- ) # pytype: disable=attribute-error
405
- else:
406
- return None
328
+ return self._model_context.non_media_treatments_normalized
407
329
 
408
330
  @functools.cached_property
409
331
  def kpi_scaled(self) -> backend.Tensor:
410
- return self.kpi_transformer.forward(self.kpi)
332
+ return self._model_context.kpi_scaled
411
333
 
412
334
  @functools.cached_property
413
335
  def media_effects_dist(self) -> str:
414
- if self.is_national:
415
- return constants.NATIONAL_MODEL_SPEC_ARGS[constants.MEDIA_EFFECTS_DIST]
416
- else:
417
- return self.model_spec.media_effects_dist
336
+ return self._model_context.media_effects_dist
418
337
 
419
338
  @functools.cached_property
420
339
  def unique_sigma_for_each_geo(self) -> bool:
421
- if self.is_national:
422
- # Should evaluate to False.
423
- return constants.NATIONAL_MODEL_SPEC_ARGS[
424
- constants.UNIQUE_SIGMA_FOR_EACH_GEO
425
- ]
426
- else:
427
- return self.model_spec.unique_sigma_for_each_geo
340
+ return self._model_context.unique_sigma_for_each_geo
428
341
 
429
342
  @functools.cached_property
430
343
  def baseline_geo_idx(self) -> int:
431
344
  """Returns the index of the baseline geo."""
432
- if isinstance(self.model_spec.baseline_geo, int):
433
- if (
434
- self.model_spec.baseline_geo < 0
435
- or self.model_spec.baseline_geo >= self.n_geos
436
- ):
437
- raise ValueError(
438
- f"Baseline geo index {self.model_spec.baseline_geo} out of range"
439
- f" [0, {self.n_geos - 1}]."
440
- )
441
- return self.model_spec.baseline_geo
442
- elif isinstance(self.model_spec.baseline_geo, str):
443
- # np.where returns a 1-D tuple, its first element is an array of found
444
- # elements.
445
- index = np.where(self.input_data.geo == self.model_spec.baseline_geo)[0]
446
- if index.size == 0:
447
- raise ValueError(
448
- f"Baseline geo '{self.model_spec.baseline_geo}' not found."
449
- )
450
- # Geos are unique, so index is a 1-element array.
451
- return index[0]
452
- else:
453
- return backend.argmax(self.population)
345
+ return self._model_context.baseline_geo_idx
454
346
 
455
347
  @functools.cached_property
456
348
  def holdout_id(self) -> backend.Tensor | None:
457
- if self.model_spec.holdout_id is None:
458
- return None
459
- tensor = backend.to_tensor(self.model_spec.holdout_id, dtype=backend.bool_)
460
- return tensor[backend.newaxis, ...] if self.is_national else tensor
349
+ return self._model_context.holdout_id
461
350
 
462
351
  @functools.cached_property
463
352
  def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
464
353
  """Returns `AdstockDecaySpec` object with correctly mapped channels."""
465
- if isinstance(self.model_spec.adstock_decay_spec, str):
466
- return adstock_hill.AdstockDecaySpec.from_consistent_type(
467
- self.model_spec.adstock_decay_spec
468
- )
469
-
470
- try:
471
- return self._create_adstock_decay_functions_from_channel_map(
472
- self.model_spec.adstock_decay_spec
473
- )
474
- except KeyError as e:
475
- raise ValueError(
476
- "Unrecognized channel names found in `adstock_decay_spec` keys"
477
- f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
478
- " either contain only channel_names"
479
- f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
480
- " be one or more of {'media', 'rf', 'organic_media',"
481
- " 'organic_rf'}."
482
- ) from e
354
+ return self._model_context.adstock_decay_spec
483
355
 
484
356
  @functools.cached_property
485
357
  def prior_broadcast(self) -> prior_distribution.PriorDistribution:
486
358
  """Returns broadcasted `PriorDistribution` object."""
487
- total_spend = self.input_data.get_total_spend()
488
- # Total spend can have 1, 2 or 3 dimensions. Aggregate by channel.
489
- if len(total_spend.shape) == 1:
490
- # Already aggregated by channel.
491
- agg_total_spend = total_spend
492
- elif len(total_spend.shape) == 2:
493
- agg_total_spend = np.sum(total_spend, axis=(0,))
494
- else:
495
- agg_total_spend = np.sum(total_spend, axis=(0, 1))
496
-
497
- return self.model_spec.prior.broadcast(
498
- n_geos=self.n_geos,
499
- n_media_channels=self.n_media_channels,
500
- n_rf_channels=self.n_rf_channels,
501
- n_organic_media_channels=self.n_organic_media_channels,
502
- n_organic_rf_channels=self.n_organic_rf_channels,
503
- n_controls=self.n_controls,
504
- n_non_media_channels=self.n_non_media_channels,
505
- unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
506
- n_knots=self.knot_info.n_knots,
507
- is_national=self.is_national,
508
- set_total_media_contribution_prior=self._set_total_media_contribution_prior,
509
- kpi=np.sum(self.input_data.kpi.values),
510
- total_spend=agg_total_spend,
511
- )
359
+ return self._model_context.prior_broadcast
512
360
 
513
361
  @functools.cached_property
514
362
  def prior_sampler_callable(self) -> prior_sampler.PriorDistributionSampler:
@@ -522,6 +370,8 @@ class Meridian:
522
370
  """A `PosteriorMCMCSampler` callable bound to this model."""
523
371
  return posterior_sampler.PosteriorMCMCSampler(self)
524
372
 
373
+ # TODO: Deprecate this method in favor of the one in
374
+ # `equations.py`.
525
375
  def compute_non_media_treatments_baseline(
526
376
  self,
527
377
  non_media_baseline_values: Sequence[str | float] | None = None,
@@ -544,70 +394,10 @@ class Meridian:
544
394
  A tensor of shape `(n_non_media_channels,)` containing the
545
395
  baseline values for each non-media treatment channel.
546
396
  """
547
- if non_media_baseline_values is None:
548
- non_media_baseline_values = self.model_spec.non_media_baseline_values
549
-
550
- no_op_scaling_factor = backend.ones_like(self.population)[
551
- :, backend.newaxis, backend.newaxis
552
- ]
553
- if self.model_spec.non_media_population_scaling_id is not None:
554
- scaling_factors = backend.where(
555
- self.model_spec.non_media_population_scaling_id,
556
- self.population[:, backend.newaxis, backend.newaxis],
557
- no_op_scaling_factor,
558
- )
559
- else:
560
- scaling_factors = no_op_scaling_factor
561
-
562
- non_media_treatments_population_scaled = backend.divide_no_nan(
563
- self.non_media_treatments, scaling_factors
397
+ return self.equations.compute_non_media_treatments_baseline(
398
+ non_media_baseline_values=non_media_baseline_values
564
399
  )
565
400
 
566
- if non_media_baseline_values is None:
567
- # If non_media_baseline_values is not provided, use the minimum
568
- # value for each non_media treatment channel as the baseline.
569
- non_media_baseline_values_filled = [
570
- constants.NON_MEDIA_BASELINE_MIN
571
- ] * non_media_treatments_population_scaled.shape[-1]
572
- else:
573
- non_media_baseline_values_filled = non_media_baseline_values
574
-
575
- if non_media_treatments_population_scaled.shape[-1] != len(
576
- non_media_baseline_values_filled
577
- ):
578
- raise ValueError(
579
- "The number of non-media channels"
580
- f" ({non_media_treatments_population_scaled.shape[-1]}) does not"
581
- " match the number of baseline values"
582
- f" ({len(non_media_baseline_values_filled)})."
583
- )
584
-
585
- baseline_list = []
586
- for channel in range(non_media_treatments_population_scaled.shape[-1]):
587
- baseline_value = non_media_baseline_values_filled[channel]
588
-
589
- if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
590
- baseline_for_channel = backend.reduce_min(
591
- non_media_treatments_population_scaled[..., channel], axis=[0, 1]
592
- )
593
- elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
594
- baseline_for_channel = backend.reduce_max(
595
- non_media_treatments_population_scaled[..., channel], axis=[0, 1]
596
- )
597
- elif isinstance(baseline_value, numbers.Number):
598
- baseline_for_channel = backend.to_tensor(
599
- baseline_value, dtype=backend.float32
600
- )
601
- else:
602
- raise ValueError(
603
- f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
604
- " float numbers and strings 'min' and 'max' are supported."
605
- )
606
-
607
- baseline_list.append(baseline_for_channel)
608
-
609
- return backend.stack(baseline_list, axis=-1)
610
-
611
401
  def expand_selected_time_dims(
612
402
  self,
613
403
  start_date: tc.Date = None,
@@ -752,427 +542,8 @@ class Meridian:
752
542
  self.n_non_media_channels,
753
543
  )
754
544
 
755
- def _validate_data_dependent_model_spec(self):
756
- """Validates that the data dependent model specs have correct shapes."""
757
-
758
- if (
759
- self.model_spec.roi_calibration_period is not None
760
- and self.model_spec.roi_calibration_period.shape
761
- != (
762
- self.n_media_times,
763
- self.n_media_channels,
764
- )
765
- ):
766
- raise ValueError(
767
- "The shape of `roi_calibration_period`"
768
- f" {self.model_spec.roi_calibration_period.shape} is different from"
769
- f" `(n_media_times, n_media_channels) = ({self.n_media_times},"
770
- f" {self.n_media_channels})`."
771
- )
772
-
773
- if (
774
- self.model_spec.rf_roi_calibration_period is not None
775
- and self.model_spec.rf_roi_calibration_period.shape
776
- != (
777
- self.n_media_times,
778
- self.n_rf_channels,
779
- )
780
- ):
781
- raise ValueError(
782
- "The shape of `rf_roi_calibration_period`"
783
- f" {self.model_spec.rf_roi_calibration_period.shape} is different"
784
- f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
785
- f" {self.n_rf_channels})`."
786
- )
787
-
788
- if self.model_spec.holdout_id is not None:
789
- if self.is_national and (
790
- self.model_spec.holdout_id.shape != (self.n_times,)
791
- ):
792
- raise ValueError(
793
- f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
794
- f" different from `(n_times,) = ({self.n_times},)`."
795
- )
796
- elif not self.is_national and (
797
- self.model_spec.holdout_id.shape
798
- != (
799
- self.n_geos,
800
- self.n_times,
801
- )
802
- ):
803
- raise ValueError(
804
- f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
805
- f" different from `(n_geos, n_times) = ({self.n_geos},"
806
- f" {self.n_times})`."
807
- )
808
-
809
- if self.model_spec.control_population_scaling_id is not None and (
810
- self.model_spec.control_population_scaling_id.shape
811
- != (self.n_controls,)
812
- ):
813
- raise ValueError(
814
- "The shape of `control_population_scaling_id`"
815
- f" {self.model_spec.control_population_scaling_id.shape} is different"
816
- f" from `(n_controls,) = ({self.n_controls},)`."
817
- )
818
-
819
- if self.model_spec.non_media_population_scaling_id is not None and (
820
- self.model_spec.non_media_population_scaling_id.shape
821
- != (self.n_non_media_channels,)
822
- ):
823
- raise ValueError(
824
- "The shape of `non_media_population_scaling_id`"
825
- f" {self.model_spec.non_media_population_scaling_id.shape} is"
826
- " different from `(n_non_media_channels,) ="
827
- f" ({self.n_non_media_channels},)`."
828
- )
829
-
830
- def _create_adstock_decay_functions_from_channel_map(
831
- self, channel_function_map: Mapping[str, str]
832
- ) -> adstock_hill.AdstockDecaySpec:
833
- """Create `AdstockDecaySpec` from mapping from channels to decay functions."""
834
-
835
- for channel in channel_function_map:
836
- if channel not in self.input_data.get_all_adstock_hill_channels():
837
- raise KeyError(f"Channel {channel} not found in data.")
838
-
839
- if self.input_data.media_channel is not None:
840
- media_channel_builder = self.input_data.get_paid_media_channels_argument_builder().with_default_value(
841
- constants.GEOMETRIC_DECAY
842
- )
843
- media_adstock_function = media_channel_builder(**channel_function_map)
844
- else:
845
- media_adstock_function = constants.GEOMETRIC_DECAY
846
-
847
- if self.input_data.rf_channel is not None:
848
- rf_channel_builder = self.input_data.get_paid_rf_channels_argument_builder().with_default_value(
849
- constants.GEOMETRIC_DECAY
850
- )
851
- rf_adstock_function = rf_channel_builder(**channel_function_map)
852
- else:
853
- rf_adstock_function = constants.GEOMETRIC_DECAY
854
-
855
- if self.input_data.organic_media_channel is not None:
856
- organic_media_channel_builder = self.input_data.get_organic_media_channels_argument_builder().with_default_value(
857
- constants.GEOMETRIC_DECAY
858
- )
859
- organic_media_adstock_function = organic_media_channel_builder(
860
- **channel_function_map
861
- )
862
- else:
863
- organic_media_adstock_function = constants.GEOMETRIC_DECAY
864
-
865
- if self.input_data.organic_rf_channel is not None:
866
- organic_rf_channel_builder = self.input_data.get_organic_rf_channels_argument_builder().with_default_value(
867
- constants.GEOMETRIC_DECAY
868
- )
869
- organic_rf_adstock_function = organic_rf_channel_builder(
870
- **channel_function_map
871
- )
872
- else:
873
- organic_rf_adstock_function = constants.GEOMETRIC_DECAY
874
-
875
- return adstock_hill.AdstockDecaySpec(
876
- media=media_adstock_function,
877
- rf=rf_adstock_function,
878
- organic_media=organic_media_adstock_function,
879
- organic_rf=organic_rf_adstock_function,
880
- )
881
-
882
- def _warn_setting_ignored_priors(self):
883
- """Raises a warning if ignored priors are set."""
884
- default_distribution = prior_distribution.PriorDistribution()
885
- for ignored_priors_dict, prior_type, prior_type_name in (
886
- (
887
- constants.IGNORED_PRIORS_MEDIA,
888
- self.model_spec.effective_media_prior_type,
889
- "media_prior_type",
890
- ),
891
- (
892
- constants.IGNORED_PRIORS_RF,
893
- self.model_spec.effective_rf_prior_type,
894
- "rf_prior_type",
895
- ),
896
- ):
897
- ignored_custom_priors = []
898
- for prior in ignored_priors_dict.get(prior_type, []):
899
- self_prior = getattr(self.model_spec.prior, prior)
900
- default_prior = getattr(default_distribution, prior)
901
- if not prior_distribution.distributions_are_equal(
902
- self_prior, default_prior
903
- ):
904
- ignored_custom_priors.append(prior)
905
- if ignored_custom_priors:
906
- ignored_priors_str = ", ".join(ignored_custom_priors)
907
- warnings.warn(
908
- f"Custom prior(s) `{ignored_priors_str}` are ignored when"
909
- f' `{prior_type_name}` is set to "{prior_type}".'
910
- )
911
-
912
- def _validate_mroi_priors_non_revenue(self):
913
- """Validates mroi priors in the non-revenue outcome case."""
914
- if (
915
- self.input_data.kpi_type == constants.NON_REVENUE
916
- and self.input_data.revenue_per_kpi is None
917
- ):
918
- default_distribution = prior_distribution.PriorDistribution()
919
- if (
920
- self.n_media_channels > 0
921
- and (
922
- self.model_spec.effective_media_prior_type
923
- == constants.TREATMENT_PRIOR_TYPE_MROI
924
- )
925
- and prior_distribution.distributions_are_equal(
926
- self.model_spec.prior.mroi_m, default_distribution.mroi_m
927
- )
928
- ):
929
- raise ValueError(
930
- f"Custom priors should be set on `{constants.MROI_M}` when"
931
- ' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
932
- " kpi data is missing."
933
- )
934
- if (
935
- self.n_rf_channels > 0
936
- and (
937
- self.model_spec.effective_rf_prior_type
938
- == constants.TREATMENT_PRIOR_TYPE_MROI
939
- )
940
- and prior_distribution.distributions_are_equal(
941
- self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
942
- )
943
- ):
944
- raise ValueError(
945
- f"Custom priors should be set on `{constants.MROI_RF}` when"
946
- ' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
947
- " data is missing."
948
- )
949
-
950
- def _validate_roi_priors_non_revenue(self):
951
- """Validates roi priors in the non-revenue outcome case."""
952
- if (
953
- self.input_data.kpi_type == constants.NON_REVENUE
954
- and self.input_data.revenue_per_kpi is None
955
- ):
956
- default_distribution = prior_distribution.PriorDistribution()
957
- default_roi_m_used = (
958
- self.model_spec.effective_media_prior_type
959
- == constants.TREATMENT_PRIOR_TYPE_ROI
960
- and prior_distribution.distributions_are_equal(
961
- self.model_spec.prior.roi_m, default_distribution.roi_m
962
- )
963
- )
964
- default_roi_rf_used = (
965
- self.model_spec.effective_rf_prior_type
966
- == constants.TREATMENT_PRIOR_TYPE_ROI
967
- and prior_distribution.distributions_are_equal(
968
- self.model_spec.prior.roi_rf, default_distribution.roi_rf
969
- )
970
- )
971
- # If ROI priors are used with the default prior distribution for all paid
972
- # channels (media and RF), then use the "total paid media contribution
973
- # prior" procedure.
974
- if (
975
- (default_roi_m_used and default_roi_rf_used)
976
- or (self.n_media_channels == 0 and default_roi_rf_used)
977
- or (self.n_rf_channels == 0 and default_roi_m_used)
978
- ):
979
- self._set_total_media_contribution_prior = True
980
- warnings.warn(
981
- "Consider setting custom ROI priors, as kpi_type was specified as"
982
- " `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
983
- " total media contribution prior will be used with"
984
- f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
985
- " documentation available at "
986
- " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
987
- )
988
- elif self.n_media_channels > 0 and default_roi_m_used:
989
- raise ValueError(
990
- f"Custom priors should be set on `{constants.ROI_M}` when"
991
- ' `media_prior_type` is "roi", custom priors are assigned on'
992
- ' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
993
- " non-revenue and revenue per kpi data is missing."
994
- )
995
- elif self.n_rf_channels > 0 and default_roi_rf_used:
996
- raise ValueError(
997
- f"Custom priors should be set on `{constants.ROI_RF}` when"
998
- ' `rf_prior_type` is "roi", custom priors are assigned on'
999
- ' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
1000
- " non-revenue and revenue per kpi data is missing."
1001
- )
1002
-
1003
- def _check_for_negative_effects(self):
1004
- prior = self.model_spec.prior
1005
- if self.n_media_channels > 0:
1006
- _check_for_negative_effect(prior.roi_m, self.media_effects_dist)
1007
- _check_for_negative_effect(prior.mroi_m, self.media_effects_dist)
1008
- if self.n_rf_channels > 0:
1009
- _check_for_negative_effect(prior.roi_rf, self.media_effects_dist)
1010
- _check_for_negative_effect(prior.mroi_rf, self.media_effects_dist)
1011
-
1012
- def _validate_geo_invariants(self):
1013
- """Validates non-national model invariants."""
1014
- if self.is_national:
1015
- return
1016
-
1017
- if self.input_data.controls is not None:
1018
- self._check_if_no_geo_variation(
1019
- self.controls_scaled,
1020
- constants.CONTROLS,
1021
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
1022
- )
1023
- if self.input_data.non_media_treatments is not None:
1024
- self._check_if_no_geo_variation(
1025
- self.non_media_treatments_normalized,
1026
- constants.NON_MEDIA_TREATMENTS,
1027
- self.input_data.non_media_treatments.coords[
1028
- constants.NON_MEDIA_CHANNEL
1029
- ].values,
1030
- )
1031
- if self.input_data.media is not None:
1032
- self._check_if_no_geo_variation(
1033
- self.media_tensors.media_scaled,
1034
- constants.MEDIA,
1035
- self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
1036
- )
1037
- if self.input_data.reach is not None:
1038
- self._check_if_no_geo_variation(
1039
- self.rf_tensors.reach_scaled,
1040
- constants.REACH,
1041
- self.input_data.reach.coords[constants.RF_CHANNEL].values,
1042
- )
1043
- if self.input_data.organic_media is not None:
1044
- self._check_if_no_geo_variation(
1045
- self.organic_media_tensors.organic_media_scaled,
1046
- "organic_media",
1047
- self.input_data.organic_media.coords[
1048
- constants.ORGANIC_MEDIA_CHANNEL
1049
- ].values,
1050
- )
1051
- if self.input_data.organic_reach is not None:
1052
- self._check_if_no_geo_variation(
1053
- self.organic_rf_tensors.organic_reach_scaled,
1054
- "organic_reach",
1055
- self.input_data.organic_reach.coords[
1056
- constants.ORGANIC_RF_CHANNEL
1057
- ].values,
1058
- )
1059
-
1060
- def _check_if_no_geo_variation(
1061
- self,
1062
- scaled_data: backend.Tensor,
1063
- data_name: str,
1064
- data_dims: Sequence[str],
1065
- epsilon=1e-4,
1066
- ):
1067
- """Raise an error if `n_knots == n_time` and data lacks geo variation."""
1068
-
1069
- # Result shape: [n, d], where d is the number of axes of condition.
1070
- col_idx_full = backend.get_indices_where(
1071
- backend.reduce_std(scaled_data, axis=0) < epsilon
1072
- )[:, 1]
1073
- col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
1074
- # We use the shape of scaled_data (instead of `n_time`) because the data may
1075
- # be padded to account for lagged effects.
1076
- data_n_time = scaled_data.shape[1]
1077
- mask = backend.equal(counts, data_n_time)
1078
- col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
1079
- dims_bad = backend.gather(data_dims, col_idx_bad)
1080
-
1081
- if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
1082
- raise ValueError(
1083
- f"The following {data_name} variables do not vary across geos, making"
1084
- f" a model with n_knots=n_time unidentifiable: {dims_bad}. This can"
1085
- " lead to poor model convergence. Since these variables only vary"
1086
- " across time and not across geo, they are collinear with time and"
1087
- " redundant in a model with a parameter for each time period. To"
1088
- " address this, you can either: (1) decrease the number of knots"
1089
- " (n_knots < n_time), or (2) drop the listed variables that do not"
1090
- " vary across geos."
1091
- )
1092
-
1093
- def _validate_time_invariants(self):
1094
- """Validates model time invariants."""
1095
- if self.input_data.controls is not None:
1096
- self._check_if_no_time_variation(
1097
- self.controls_scaled,
1098
- constants.CONTROLS,
1099
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
1100
- )
1101
- if self.input_data.non_media_treatments is not None:
1102
- self._check_if_no_time_variation(
1103
- self.non_media_treatments_normalized,
1104
- constants.NON_MEDIA_TREATMENTS,
1105
- self.input_data.non_media_treatments.coords[
1106
- constants.NON_MEDIA_CHANNEL
1107
- ].values,
1108
- )
1109
- if self.input_data.media is not None:
1110
- self._check_if_no_time_variation(
1111
- self.media_tensors.media_scaled,
1112
- constants.MEDIA,
1113
- self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
1114
- )
1115
- if self.input_data.reach is not None:
1116
- self._check_if_no_time_variation(
1117
- self.rf_tensors.reach_scaled,
1118
- constants.REACH,
1119
- self.input_data.reach.coords[constants.RF_CHANNEL].values,
1120
- )
1121
- if self.input_data.organic_media is not None:
1122
- self._check_if_no_time_variation(
1123
- self.organic_media_tensors.organic_media_scaled,
1124
- constants.ORGANIC_MEDIA,
1125
- self.input_data.organic_media.coords[
1126
- constants.ORGANIC_MEDIA_CHANNEL
1127
- ].values,
1128
- )
1129
- if self.input_data.organic_reach is not None:
1130
- self._check_if_no_time_variation(
1131
- self.organic_rf_tensors.organic_reach_scaled,
1132
- constants.ORGANIC_REACH,
1133
- self.input_data.organic_reach.coords[
1134
- constants.ORGANIC_RF_CHANNEL
1135
- ].values,
1136
- )
1137
-
1138
- def _check_if_no_time_variation(
1139
- self,
1140
- scaled_data: backend.Tensor,
1141
- data_name: str,
1142
- data_dims: Sequence[str],
1143
- epsilon=1e-4,
1144
- ):
1145
- """Raise an error if data lacks time variation."""
1146
-
1147
- # Result shape: [n, d], where d is the number of axes of condition.
1148
- col_idx_full = backend.get_indices_where(
1149
- backend.reduce_std(scaled_data, axis=1) < epsilon
1150
- )[:, 1]
1151
- col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
1152
- mask = backend.equal(counts, self.n_geos)
1153
- col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
1154
- dims_bad = backend.gather(data_dims, col_idx_bad)
1155
- if col_idx_bad.shape[0]:
1156
- if self.is_national:
1157
- raise ValueError(
1158
- f"The following {data_name} variables do not vary across time,"
1159
- " which is equivalent to no signal at all in a national model:"
1160
- f" {dims_bad}. This can lead to poor model convergence. To address"
1161
- " this, drop the listed variables that do not vary across time."
1162
- )
1163
- else:
1164
- raise ValueError(
1165
- f"The following {data_name} variables do not vary across time,"
1166
- f" making a model with geo main effects unidentifiable: {dims_bad}."
1167
- " This can lead to poor model convergence. Since these variables"
1168
- " only vary across geo and not across time, they are collinear"
1169
- " with geo and redundant in a model with geo main effects. To"
1170
- " address this, drop the listed variables that do not vary across"
1171
- " time."
1172
- )
1173
-
1174
- def _validate_kpi_transformer(self):
1175
- """Validates the KPI transformer."""
545
+ def _validate_kpi_variability(self):
546
+ """Validates the KPI variability."""
1176
547
  if self.eda_engine.kpi_has_variability:
1177
548
  return
1178
549
  kpi = self.eda_engine.kpi_scaled_da.name
@@ -1227,6 +598,7 @@ class Meridian:
1227
598
  f' "{self.model_spec.non_media_treatments_prior_type}".'
1228
599
  )
1229
600
 
601
+ # TODO: Deprecate in favor of ModelEquations.adstock_hill_rf.
1230
602
  def linear_predictor_counterfactual_difference_media(
1231
603
  self,
1232
604
  media_transformed: backend.Tensor,
@@ -1255,21 +627,15 @@ class Meridian:
1255
627
  The linear predictor difference between the treatment variable and its
1256
628
  counterfactual.
1257
629
  """
1258
- if self.media_tensors.prior_media_scaled_counterfactual is None:
1259
- return media_transformed
1260
- media_transformed_counterfactual = self.adstock_hill_media(
1261
- self.media_tensors.prior_media_scaled_counterfactual,
1262
- alpha_m,
1263
- ec_m,
1264
- slope_m,
1265
- decay_functions=self.adstock_decay_spec.media,
1266
- )
1267
- # Absolute values is needed because the difference is negative for mROI
1268
- # priors and positive for ROI and contribution priors.
1269
- return backend.absolute(
1270
- media_transformed - media_transformed_counterfactual
630
+ return self.equations.linear_predictor_counterfactual_difference_media(
631
+ media_transformed=media_transformed,
632
+ alpha_m=alpha_m,
633
+ ec_m=ec_m,
634
+ slope_m=slope_m,
1271
635
  )
1272
636
 
637
+ # TODO: Deprecate in favor of
638
+ # ModelEquations.linear_predictor_counterfactual_difference_rf.
1273
639
  def linear_predictor_counterfactual_difference_rf(
1274
640
  self,
1275
641
  rf_transformed: backend.Tensor,
@@ -1298,20 +664,14 @@ class Meridian:
1298
664
  The linear predictor difference between the treatment variable and its
1299
665
  counterfactual.
1300
666
  """
1301
- if self.rf_tensors.prior_reach_scaled_counterfactual is None:
1302
- return rf_transformed
1303
- rf_transformed_counterfactual = self.adstock_hill_rf(
1304
- reach=self.rf_tensors.prior_reach_scaled_counterfactual,
1305
- frequency=self.rf_tensors.frequency,
1306
- alpha=alpha_rf,
1307
- ec=ec_rf,
1308
- slope=slope_rf,
1309
- decay_functions=self.adstock_decay_spec.rf,
667
+ return self.equations.linear_predictor_counterfactual_difference_rf(
668
+ rf_transformed=rf_transformed,
669
+ alpha_rf=alpha_rf,
670
+ ec_rf=ec_rf,
671
+ slope_rf=slope_rf,
1310
672
  )
1311
- # Absolute values is needed because the difference is negative for mROI
1312
- # priors and positive for ROI and contribution priors.
1313
- return backend.absolute(rf_transformed - rf_transformed_counterfactual)
1314
673
 
674
+ # TODO: Deprecate in favor of ModelEquations.calculate_beta_x.
1315
675
  def calculate_beta_x(
1316
676
  self,
1317
677
  is_non_media: bool,
@@ -1357,45 +717,15 @@ class Meridian:
1357
717
  The coefficient mean parameter of the treatment variable, which has
1358
718
  dimension equal to the number of treatment channels..
1359
719
  """
1360
- if is_non_media:
1361
- random_effects_normal = True
1362
- else:
1363
- random_effects_normal = (
1364
- self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1365
- )
1366
- if self.revenue_per_kpi is None:
1367
- revenue_per_kpi = backend.ones(
1368
- [self.n_geos, self.n_times], dtype=backend.float32
1369
- )
1370
- else:
1371
- revenue_per_kpi = self.revenue_per_kpi
1372
- incremental_outcome_gx_over_beta_gx = backend.einsum(
1373
- "...gtx,gt,g,->...gx",
1374
- linear_predictor_counterfactual_difference,
1375
- revenue_per_kpi,
1376
- self.population,
1377
- self.kpi_transformer.population_scaled_stdev,
1378
- )
1379
- if random_effects_normal:
1380
- numerator_term_x = backend.einsum(
1381
- "...gx,...gx,...x->...x",
1382
- incremental_outcome_gx_over_beta_gx,
1383
- beta_gx_dev,
1384
- eta_x,
1385
- )
1386
- denominator_term_x = backend.einsum(
1387
- "...gx->...x", incremental_outcome_gx_over_beta_gx
1388
- )
1389
- return (incremental_outcome_x - numerator_term_x) / denominator_term_x
1390
- # For log-normal random effects, beta_x and eta_x are not mean & std.
1391
- # The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
1392
- denominator_term_x = backend.einsum(
1393
- "...gx,...gx->...x",
1394
- incremental_outcome_gx_over_beta_gx,
1395
- backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
720
+ return self.equations.calculate_beta_x(
721
+ is_non_media=is_non_media,
722
+ incremental_outcome_x=incremental_outcome_x,
723
+ linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
724
+ eta_x=eta_x,
725
+ beta_gx_dev=beta_gx_dev,
1396
726
  )
1397
- return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)
1398
727
 
728
+ # TODO: Deprecate in favor of ModelEquations.adstock_hill_media.
1399
729
  def adstock_hill_media(
1400
730
  self,
1401
731
  media: backend.Tensor, # pylint: disable=redefined-outer-name
@@ -1426,34 +756,16 @@ class Meridian:
1426
756
  Tensor with dimensions `[..., n_geos, n_times, n_media_channels]`
1427
757
  representing Adstock and Hill-transformed media.
1428
758
  """
1429
- if n_times_output is None and (media.shape[1] == self.n_media_times):
1430
- n_times_output = self.n_times
1431
- elif n_times_output is None:
1432
- raise ValueError(
1433
- "n_times_output is required. This argument is only optional when "
1434
- "`media` has a number of time periods equal to `self.n_media_times`."
1435
- )
1436
- adstock_transformer = adstock_hill.AdstockTransformer(
759
+ return self.equations.adstock_hill_media(
760
+ media=media,
1437
761
  alpha=alpha,
1438
- max_lag=self.model_spec.max_lag,
1439
- n_times_output=n_times_output,
1440
- decay_functions=decay_functions,
1441
- )
1442
- hill_transformer = adstock_hill.HillTransformer(
1443
762
  ec=ec,
1444
763
  slope=slope,
764
+ decay_functions=decay_functions,
765
+ n_times_output=n_times_output,
1445
766
  )
1446
- transformers_list = (
1447
- [hill_transformer, adstock_transformer]
1448
- if self.model_spec.hill_before_adstock
1449
- else [adstock_transformer, hill_transformer]
1450
- )
1451
-
1452
- media_out = media
1453
- for transformer in transformers_list:
1454
- media_out = transformer.forward(media_out)
1455
- return media_out
1456
767
 
768
+ # TODO: Deprecate in favor of ModelEquations.adstock_hill_rf.
1457
769
  def adstock_hill_rf(
1458
770
  self,
1459
771
  reach: backend.Tensor,
@@ -1485,27 +797,15 @@ class Meridian:
1485
797
  Tensor with dimensions `[..., n_geos, n_times, n_rf_channels]`
1486
798
  representing Hill and Adstock-transformed RF.
1487
799
  """
1488
- if n_times_output is None and (reach.shape[1] == self.n_media_times):
1489
- n_times_output = self.n_times
1490
- elif n_times_output is None:
1491
- raise ValueError(
1492
- "n_times_output is required. This argument is only optional when "
1493
- "`reach` has a number of time periods equal to `self.n_media_times`."
1494
- )
1495
- hill_transformer = adstock_hill.HillTransformer(
800
+ return self.equations.adstock_hill_rf(
801
+ reach=reach,
802
+ frequency=frequency,
803
+ alpha=alpha,
1496
804
  ec=ec,
1497
805
  slope=slope,
1498
- )
1499
- adstock_transformer = adstock_hill.AdstockTransformer(
1500
- alpha=alpha,
1501
- max_lag=self.model_spec.max_lag,
1502
- n_times_output=n_times_output,
1503
806
  decay_functions=decay_functions,
807
+ n_times_output=n_times_output,
1504
808
  )
1505
- adj_frequency = hill_transformer.forward(frequency)
1506
- rf_out = adstock_transformer.forward(reach * adj_frequency)
1507
-
1508
- return rf_out
1509
809
 
1510
810
  def populate_cached_properties(self):
1511
811
  """Eagerly activates all cached properties.
@@ -1515,6 +815,7 @@ class Meridian:
1515
815
  internal state mutations are problematic, and so this method freezes the
1516
816
  object's states before the computation graph is created.
1517
817
  """
818
+ self._model_context.populate_cached_properties()
1518
819
  cls = self.__class__
1519
820
  # "Freeze" all @cached_property attributes by simply accessing them (with
1520
821
  # `getattr()`).