google-meridian 1.4.0__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/RECORD +50 -46
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/checks.py +118 -116
  7. meridian/analysis/review/constants.py +3 -3
  8. meridian/analysis/review/results.py +131 -68
  9. meridian/analysis/review/reviewer.py +8 -23
  10. meridian/analysis/summarizer.py +6 -1
  11. meridian/analysis/test_utils.py +2898 -2538
  12. meridian/analysis/visualizer.py +28 -9
  13. meridian/backend/__init__.py +106 -0
  14. meridian/constants.py +1 -0
  15. meridian/data/input_data.py +30 -52
  16. meridian/data/input_data_builder.py +2 -9
  17. meridian/data/test_utils.py +25 -41
  18. meridian/data/validator.py +48 -0
  19. meridian/mlflow/autolog.py +19 -9
  20. meridian/model/adstock_hill.py +3 -5
  21. meridian/model/context.py +134 -0
  22. meridian/model/eda/constants.py +334 -4
  23. meridian/model/eda/eda_engine.py +724 -312
  24. meridian/model/eda/eda_outcome.py +177 -33
  25. meridian/model/model.py +159 -110
  26. meridian/model/model_test_data.py +38 -0
  27. meridian/model/posterior_sampler.py +103 -62
  28. meridian/model/prior_sampler.py +114 -94
  29. meridian/model/spec.py +23 -14
  30. meridian/templates/card.html.jinja +9 -7
  31. meridian/templates/chart.html.jinja +1 -6
  32. meridian/templates/finding.html.jinja +19 -0
  33. meridian/templates/findings.html.jinja +33 -0
  34. meridian/templates/formatter.py +41 -5
  35. meridian/templates/formatter_test.py +127 -0
  36. meridian/templates/style.css +66 -9
  37. meridian/templates/style.scss +85 -4
  38. meridian/templates/table.html.jinja +1 -0
  39. meridian/version.py +1 -1
  40. scenarioplanner/linkingapi/constants.py +1 -1
  41. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  42. schema/processors/marketing_processor.py +11 -10
  43. schema/processors/model_processor.py +4 -1
  44. schema/serde/distribution.py +12 -7
  45. schema/serde/hyperparameters.py +54 -107
  46. schema/serde/meridian_serde.py +12 -3
  47. schema/utils/__init__.py +1 -0
  48. schema/utils/proto_enum_converter.py +127 -0
  49. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/licenses/LICENSE +0 -0
  50. {google_meridian-1.4.0.dist-info → google_meridian-1.5.1.dist-info}/top_level.txt +0 -0
@@ -16,14 +16,18 @@

  from collections.abc import Mapping, Sequence
  import dataclasses
+ import functools
  import itertools
  import numbers
  from typing import Any, Optional
  import warnings

+ import arviz as az
  from meridian import backend
  from meridian import constants
  from meridian.model import adstock_hill
+ from meridian.model import context
+ from meridian.model import equations
  from meridian.model import model
  from meridian.model import transformers
  import numpy as np
@@ -53,6 +57,27 @@ def _validate_non_media_baseline_values_numbers(
  )


+ # TODO: Remove this method.
+ def _get_model_context(
+ meridian: model.Meridian | None,
+ model_context: context.ModelContext | None,
+ ) -> context.ModelContext:
+ """Gets `model_context`, handling the deprecated `meridian` argument."""
+ if meridian is not None:
+ warnings.warn(
+ (
+ "The `meridian` argument is deprecated and will be removed in a"
+ " future version. Use `model_context` instead."
+ ),
+ DeprecationWarning,
+ stacklevel=3,
+ )
+ return meridian.model_context
+ if model_context is None:
+ raise ValueError("Either `meridian` or `model_context` must be provided.")
+ return model_context
+
+
  @dataclasses.dataclass
  class DataTensors(backend.ExtensionType):
  """Container for data variable arguments of Analyzer methods.
@@ -218,30 +243,38 @@ class DataTensors(backend.ExtensionType):
  backend.concatenate(spend_tensors, axis=-1) if spend_tensors else None
  )

- def get_modified_times(self, meridian: model.Meridian) -> int | None:
+ def get_modified_times(
+ self,
+ meridian: model.Meridian | None = None,
+ model_context: context.ModelContext | None = None,
+ ) -> int | None:
  """Returns `n_times` of any tensor where `n_times` has been modified.

  This method compares the time dimensions of the attributes in the
- `DataTensors` object with the corresponding tensors in the `meridian`
+ `DataTensors` object with the corresponding tensors in the `model_context`
  object. If any of the time dimensions are different, then this method
  returns the modified number of time periods of the tensor in the
  `DataTensors` object. If all time dimensions are the same, returns `None`.

  Args:
- meridian: A Meridian object to validate against and get the original data
- tensors from.
+ meridian: Deprecated. A Meridian object to validate against and get the
+ original data tensors from. This argument is deprecated and will be
+ removed in a future version. Use `model_context` instead.
+ model_context: A ModelContext object to validate against and get the
+ original data tensors from.

  Returns:
  The `n_times` of any tensor where `n_times` is different from the times
- of the corresponding tensor in the `meridian` object. If all time
+ of the corresponding tensor in the `model_context` object. If all time
  dimensions are the same, returns `None`.
  """
+ model_context = _get_model_context(meridian, model_context)
  for field in dataclasses.fields(self):
  new_tensor = getattr(self, field.name)
  if field.name == constants.RF_IMPRESSIONS:
- old_tensor = getattr(meridian.rf_tensors, field.name)
+ old_tensor = getattr(model_context.rf_tensors, field.name)
  else:
- old_tensor = getattr(meridian.input_data, field.name)
+ old_tensor = getattr(model_context.input_data, field.name)
  # The time dimension is always the second dimension, except for when spend
  # data is provided with only one dimension of (n_channels).
  if (
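`get_modified_times` now takes the model state by keyword. A short sketch of the flexible-time behaviour described in the docstring above, assuming a fitted model `mmm` and a `new_media` array whose time dimension differs from the training data (both hypothetical names):

    # Hypothetical objects: `mmm` is a fitted model.Meridian; `new_media` has shape
    # (n_geos, n_media_times + 4, n_media_channels).
    from meridian.analysis import analyzer

    tensors = analyzer.DataTensors(media=new_media)
    # Returns the modified time length (new_media.shape[1]); returns None when every
    # provided tensor matches the original time dimensions.
    modified = tensors.get_modified_times(model_context=mmm.model_context)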
@@ -264,24 +297,28 @@ class DataTensors(backend.ExtensionType):
  def validate_and_fill_missing_data(
  self,
  required_tensors_names: Sequence[str],
- meridian: model.Meridian,
+ meridian: model.Meridian | None = None,
+ model_context: context.ModelContext | None = None,
  allow_modified_times: bool = True,
  ) -> Self:
  """Fills missing data tensors with their original values from the model.

  This method uses the collection of data tensors set in the DataTensor class
  and fills in the missing tensors with their original values from the
- Meridian object that is passed in. For example, if `required_tensors_names =
- ["media", "reach", "frequency"]` and only `media` is set in the DataTensors
- class, then this method will output a new DataTensors object with the
- `media` value in this object plus the values of the `reach` and `frequency`
- from the `meridian` object.
+ ModelContext object that is passed in. For example, if
+ `required_tensors_names = ["media", "reach", "frequency"]` and only `media`
+ is set in the DataTensors class, then this method will output a new
+ DataTensors object with the `media` value in this object plus the values of
+ the `reach` and `frequency` from the `model_context` object.

  Args:
  required_tensors_names: A sequence of data tensors names to validate and
- fill in with the original values from the `meridian` object.
- meridian: The Meridian object to validate against and get the original
- data tensors from.
+ fill in with the original values from the `model_context` object.
+ meridian: Deprecated. A Meridian object to validate against and get the
+ original data tensors from. This argument is deprecated and will be
+ removed in a future version. Use `model_context` instead.
+ model_context: A ModelContext object to validate against and get the
+ original data tensors from.
  allow_modified_times: A boolean flag indicating whether to allow modified
  time dimensions in the new data tensors. If False, an error will be
  raised if the time dimensions of any tensor is modified.
@@ -290,15 +327,30 @@ class DataTensors(backend.ExtensionType):
  A `DataTensors` container with the original values from the Meridian
  object filled in for the missing data tensors.
  """
- self._validate_correct_variables_filled(required_tensors_names, meridian)
- self._validate_geo_dims(required_tensors_names, meridian)
- self._validate_channel_dims(required_tensors_names, meridian)
+ model_context = _get_model_context(meridian, model_context)
+ self._validate_correct_variables_filled(
+ required_variables=required_tensors_names,
+ model_context=model_context,
+ )
+ self._validate_geo_dims(
+ required_fields=required_tensors_names, model_context=model_context
+ )
+ self._validate_channel_dims(
+ required_fields=required_tensors_names, model_context=model_context
+ )
  if allow_modified_times:
- self._validate_time_dims_flexible_times(required_tensors_names, meridian)
+ self._validate_time_dims_flexible_times(
+ required_fields=required_tensors_names, model_context=model_context
+ )
  else:
- self._validate_time_dims(required_tensors_names, meridian)
+ self._validate_time_dims(
+ required_fields=required_tensors_names, model_context=model_context
+ )

- return self._fill_default_values(required_tensors_names, meridian)
+ return self._fill_default_values(
+ required_fields=required_tensors_names,
+ model_context=model_context,
+ )

  def _validate_n_dims(self):
  """Raises an error if the tensors have the wrong number of dimensions."""
@@ -320,18 +372,20 @@ class DataTensors(backend.ExtensionType):
  _check_n_dims(tensor, field.name, 3)

  def _validate_correct_variables_filled(
- self, required_variables: Sequence[str], meridian: model.Meridian
- ):
+ self,
+ required_variables: Sequence[str],
+ model_context: context.ModelContext,
+ ) -> None:
  """Validates that the correct variables are filled.

  Args:
  required_variables: A sequence of data tensors names that are required to
  be filled in.
- meridian: The Meridian object to validate against.
+ model_context: A `ModelContext` object to validate against.

  Raises:
  ValueError: If an attribute exists in the `DataTensors` object that is not
- in the `meridian` object, it is not allowed to be used in analysis.
+ in the `model_context` object, it is not allowed to be used in analysis.
  Warning: If an attribute exists in the `DataTensors` object that is not in
  the `required_variables` list, it will be ignored.
  """
@@ -346,44 +400,48 @@ class DataTensors(backend.ExtensionType):
  )
  if field.name in required_variables:
  if field.name == constants.RF_IMPRESSIONS:
- if meridian.n_rf_channels == 0:
+ if model_context.n_rf_channels == 0:
  raise ValueError(
  "New `rf_impressions` is not allowed because there are no R&F"
  " channels in the Meridian model."
  )
- elif getattr(meridian.input_data, field.name) is None:
+ elif getattr(model_context.input_data, field.name) is None:
  raise ValueError(
  f"New `{field.name}` is not allowed because the input data to the"
  f" Meridian model does not contain `{field.name}`."
  )

  def _validate_geo_dims(
- self, required_fields: Sequence[str], meridian: model.Meridian
- ):
+ self,
+ required_fields: Sequence[str],
+ model_context: context.ModelContext,
+ ) -> None:
  """Validates the geo dimension of the specified data variables."""
  for var_name in required_fields:
  new_tensor = getattr(self, var_name)
- if new_tensor is not None and new_tensor.shape[0] != meridian.n_geos:
+ if new_tensor is not None and new_tensor.shape[0] != model_context.n_geos:
  # Skip spend and time data with only 1 dimension.
  if new_tensor.ndim == 1:
  continue
  raise ValueError(
- f"New `{var_name}` is expected to have {meridian.n_geos}"
+ f"New `{var_name}` is expected to have {model_context.n_geos}"
  f" geos. Found {new_tensor.shape[0]} geos."
  )

  def _validate_channel_dims(
- self, required_fields: Sequence[str], meridian: model.Meridian
- ):
+ self,
+ required_fields: Sequence[str],
+ model_context: context.ModelContext,
+ ) -> None:
  """Validates the channel dimension of the specified data variables."""
  for var_name in required_fields:
  if var_name in [constants.REVENUE_PER_KPI, constants.TIME]:
  continue
  new_tensor = getattr(self, var_name)
  if var_name == constants.RF_IMPRESSIONS:
- old_tensor = getattr(meridian.rf_tensors, var_name)
+ old_tensor = getattr(model_context.rf_tensors, var_name)
  else:
- old_tensor = getattr(meridian.input_data, var_name)
+ old_tensor = getattr(model_context.input_data, var_name)
  if new_tensor is not None:
  assert old_tensor is not None
  if new_tensor.shape[-1] != old_tensor.shape[-1]:
@@ -393,15 +451,17 @@ class DataTensors(backend.ExtensionType):
  )

  def _validate_time_dims(
- self, required_fields: Sequence[str], meridian: model.Meridian
+ self,
+ required_fields: Sequence[str],
+ model_context: context.ModelContext,
  ):
  """Validates the time dimension of the specified data variables."""
  for var_name in required_fields:
  new_tensor = getattr(self, var_name)
  if var_name == constants.RF_IMPRESSIONS:
- old_tensor = getattr(meridian.rf_tensors, var_name)
+ old_tensor = getattr(model_context.rf_tensors, var_name)
  else:
- old_tensor = getattr(meridian.input_data, var_name)
+ old_tensor = getattr(model_context.input_data, var_name)

  # Skip spend data with only 1 dimension of (n_channels).
  if (
@@ -428,10 +488,12 @@ class DataTensors(backend.ExtensionType):
  )

  def _validate_time_dims_flexible_times(
- self, required_fields: Sequence[str], meridian: model.Meridian
+ self,
+ required_fields: Sequence[str],
+ model_context: context.ModelContext,
  ):
  """Validates the time dimension for the flexible times case."""
- new_n_times = self.get_modified_times(meridian)
+ new_n_times = self.get_modified_times(model_context=model_context)
  # If no times were modified, then there is nothing more to validate.
  if new_n_times is None:
  return
@@ -440,9 +502,9 @@ class DataTensors(backend.ExtensionType):
  for var_name in required_fields:
  new_tensor = getattr(self, var_name)
  if var_name == constants.RF_IMPRESSIONS:
- old_tensor = getattr(meridian.rf_tensors, var_name)
+ old_tensor = getattr(model_context.rf_tensors, var_name)
  else:
- old_tensor = getattr(meridian.input_data, var_name)
+ old_tensor = getattr(model_context.input_data, var_name)

  if old_tensor is None:
  continue
@@ -484,7 +546,9 @@ class DataTensors(backend.ExtensionType):
  )

  def _fill_default_values(
- self, required_fields: Sequence[str], meridian: model.Meridian
+ self,
+ required_fields: Sequence[str],
+ model_context: context.ModelContext,
  ) -> Self:
  """Fills default values and returns a new DataTensors object."""
  output = {}
@@ -493,23 +557,23 @@ class DataTensors(backend.ExtensionType):
  if var_name not in required_fields:
  continue

- if hasattr(meridian.media_tensors, var_name):
- old_tensor = getattr(meridian.media_tensors, var_name)
- elif hasattr(meridian.rf_tensors, var_name):
- old_tensor = getattr(meridian.rf_tensors, var_name)
- elif hasattr(meridian.organic_media_tensors, var_name):
- old_tensor = getattr(meridian.organic_media_tensors, var_name)
- elif hasattr(meridian.organic_rf_tensors, var_name):
- old_tensor = getattr(meridian.organic_rf_tensors, var_name)
+ if hasattr(model_context.media_tensors, var_name):
+ old_tensor = getattr(model_context.media_tensors, var_name)
+ elif hasattr(model_context.rf_tensors, var_name):
+ old_tensor = getattr(model_context.rf_tensors, var_name)
+ elif hasattr(model_context.organic_media_tensors, var_name):
+ old_tensor = getattr(model_context.organic_media_tensors, var_name)
+ elif hasattr(model_context.organic_rf_tensors, var_name):
+ old_tensor = getattr(model_context.organic_rf_tensors, var_name)
  elif var_name == constants.NON_MEDIA_TREATMENTS:
- old_tensor = meridian.non_media_treatments
+ old_tensor = model_context.non_media_treatments
  elif var_name == constants.CONTROLS:
- old_tensor = meridian.controls
+ old_tensor = model_context.controls
  elif var_name == constants.REVENUE_PER_KPI:
- old_tensor = meridian.revenue_per_kpi
+ old_tensor = model_context.revenue_per_kpi
  elif var_name == constants.TIME:
  old_tensor = backend.to_tensor(
- meridian.input_data.time.values.tolist(), dtype=backend.string
+ model_context.input_data.time.values.tolist(), dtype=backend.string
  )
  else:
  continue
@@ -858,12 +922,38 @@ def _central_tendency_and_ci_by_prior_and_posterior(
  class Analyzer:
  """Runs calculations to analyze the raw data after fitting the model."""

- def __init__(self, meridian: model.Meridian):
- self._meridian = meridian
+ def __init__(
+ self,
+ # TODO: Remove this argument.
+ meridian: model.Meridian | None = None,
+ *,
+ model_context: context.ModelContext | None = None,
+ inference_data: az.InferenceData | None = None,
+ ):
  # Make the meridian object ready for methods in this analyzer that create
  # backend.function computation graphs: it should be frozen for no more
  # internal states mutation before those graphs execute.
- self._meridian.populate_cached_properties()
+ self.model_context = _get_model_context(meridian, model_context)
+
+ if meridian is not None:
+ self._inference_data = inference_data or meridian.inference_data
+ elif model_context is None or inference_data is None:
+ raise ValueError(
+ "If `meridian` is not provided, then `model_context` and"
+ " `inference_data` must be provided."
+ )
+ else:
+ self._inference_data = inference_data
+
+ self.model_context.populate_cached_properties()
+
+ @functools.cached_property
+ def _model_equations(self) -> equations.ModelEquations:
+ return equations.ModelEquations(self.model_context)
+
+ @property
+ def inference_data(self) -> az.InferenceData:
+ return self._inference_data

  @backend.function(jit_compile=True)
  def _get_kpi_means(
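`Analyzer` can now be constructed either from a fitted `Meridian` model (deprecated path) or from a `ModelContext` plus an ArviZ `InferenceData`. A sketch of both paths, assuming a fitted model `mmm` (hypothetical name):

    # `mmm` is an assumed fitted meridian.model.Meridian instance.
    from meridian.analysis import analyzer

    legacy = analyzer.Analyzer(mmm)  # still accepted; emits a DeprecationWarning

    preferred = analyzer.Analyzer(
        model_context=mmm.model_context,
        inference_data=mmm.inference_data,
    )
    assert preferred.inference_data is mmm.inference_data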
@@ -902,7 +992,7 @@ class Analyzer:
  result = tau_gt + backend.einsum(
  "...gtm,...gm->...gt", combined_media_transformed, combined_beta
  )
- if self._meridian.controls is not None:
+ if self.model_context.controls is not None:
  result += backend.einsum(
  "...gtc,...gc->...gt",
  data_tensors.controls,
@@ -933,20 +1023,20 @@ class Analyzer:
  UserWarning: If the KPI type is revenue and use_kpi is True or if
  `use_kpi=False` but `revenue_per_kpi` is not available.
  """
- if use_kpi and self._meridian.input_data.kpi_type == constants.REVENUE:
+ if use_kpi and self.model_context.input_data.kpi_type == constants.REVENUE:
  warnings.warn(
  "Setting `use_kpi=True` has no effect when `kpi_type=REVENUE`"
  " since in this case, KPI is equal to revenue."
  )
  return False

- if not use_kpi and self._meridian.input_data.revenue_per_kpi is None:
+ if not use_kpi and self.model_context.input_data.revenue_per_kpi is None:
  warnings.warn(
  "Revenue analysis is not available when `revenue_per_kpi` is"
  " unknown. Defaulting to KPI analysis."
  )

- return use_kpi or self._meridian.input_data.revenue_per_kpi is None
+ return use_kpi or self.model_context.input_data.revenue_per_kpi is None

  def _get_adstock_dataframe(
  self,
@@ -972,36 +1062,37 @@ class Analyzer:
  ci_lo, and mean decayed effects for either media or RF channel types.
  """
  window_size = min(
- self._meridian.model_spec.max_lag + 1, self._meridian.n_media_times
+ self.model_context.model_spec.max_lag + 1,
+ self.model_context.n_media_times,
  )
  if channel_type == constants.MEDIA:
- prior = self._meridian.inference_data.prior.alpha_m.values[0]
+ prior = self._inference_data.prior.alpha_m.values[0]
  posterior = np.reshape(
- self._meridian.inference_data.posterior.alpha_m.values,
- (-1, self._meridian.n_media_channels),
+ self._inference_data.posterior.alpha_m.values,
+ (-1, self.model_context.n_media_channels),
  )
- decay_functions = self._meridian.adstock_decay_spec.media
+ decay_functions = self.model_context.adstock_decay_spec.media
  elif channel_type == constants.RF:
- prior = self._meridian.inference_data.prior.alpha_rf.values[0]
+ prior = self._inference_data.prior.alpha_rf.values[0]
  posterior = np.reshape(
- self._meridian.inference_data.posterior.alpha_rf.values,
- (-1, self._meridian.n_rf_channels),
+ self._inference_data.posterior.alpha_rf.values,
+ (-1, self.model_context.n_rf_channels),
  )
- decay_functions = self._meridian.adstock_decay_spec.rf
+ decay_functions = self.model_context.adstock_decay_spec.rf
  elif channel_type == constants.ORGANIC_MEDIA:
- prior = self._meridian.inference_data.prior.alpha_om.values[0]
+ prior = self._inference_data.prior.alpha_om.values[0]
  posterior = np.reshape(
- self._meridian.inference_data.posterior.alpha_om.values,
- (-1, self._meridian.n_organic_media_channels),
+ self._inference_data.posterior.alpha_om.values,
+ (-1, self.model_context.n_organic_media_channels),
  )
- decay_functions = self._meridian.adstock_decay_spec.organic_media
+ decay_functions = self.model_context.adstock_decay_spec.organic_media
  elif channel_type == constants.ORGANIC_RF:
- prior = self._meridian.inference_data.prior.alpha_orf.values[0]
+ prior = self._inference_data.prior.alpha_orf.values[0]
  posterior = np.reshape(
- self._meridian.inference_data.posterior.alpha_orf.values,
- (-1, self._meridian.n_organic_rf_channels),
+ self._inference_data.posterior.alpha_orf.values,
+ (-1, self.model_context.n_organic_rf_channels),
  )
- decay_functions = self._meridian.adstock_decay_spec.organic_rf
+ decay_functions = self.model_context.adstock_decay_spec.organic_rf
  else:
  raise ValueError(
  f"Unsupported channel type for adstock decay: '{channel_type}'. "
@@ -1103,65 +1194,65 @@ class Analyzer:
  """
  if new_data is None:
  return DataTensors(
- media=self._meridian.media_tensors.media_scaled,
- reach=self._meridian.rf_tensors.reach_scaled,
- frequency=self._meridian.rf_tensors.frequency,
- organic_media=self._meridian.organic_media_tensors.organic_media_scaled,
- organic_reach=self._meridian.organic_rf_tensors.organic_reach_scaled,
- organic_frequency=self._meridian.organic_rf_tensors.organic_frequency,
- non_media_treatments=self._meridian.non_media_treatments_normalized,
- controls=self._meridian.controls_scaled,
- revenue_per_kpi=self._meridian.revenue_per_kpi,
+ media=self.model_context.media_tensors.media_scaled,
+ reach=self.model_context.rf_tensors.reach_scaled,
+ frequency=self.model_context.rf_tensors.frequency,
+ organic_media=self.model_context.organic_media_tensors.organic_media_scaled,
+ organic_reach=self.model_context.organic_rf_tensors.organic_reach_scaled,
+ organic_frequency=self.model_context.organic_rf_tensors.organic_frequency,
+ non_media_treatments=self.model_context.non_media_treatments_normalized,
+ controls=self.model_context.controls_scaled,
+ revenue_per_kpi=self.model_context.revenue_per_kpi,
  )
  media_scaled = _transformed_new_or_scaled(
  new_variable=new_data.media,
- transformer=self._meridian.media_tensors.media_transformer,
- scaled_variable=self._meridian.media_tensors.media_scaled,
+ transformer=self.model_context.media_tensors.media_transformer,
+ scaled_variable=self.model_context.media_tensors.media_scaled,
  )

  reach_scaled = _transformed_new_or_scaled(
  new_variable=new_data.reach,
- transformer=self._meridian.rf_tensors.reach_transformer,
- scaled_variable=self._meridian.rf_tensors.reach_scaled,
+ transformer=self.model_context.rf_tensors.reach_transformer,
+ scaled_variable=self.model_context.rf_tensors.reach_scaled,
  )

  frequency = (
  new_data.frequency
  if new_data.frequency is not None
- else self._meridian.rf_tensors.frequency
+ else self.model_context.rf_tensors.frequency
  )

  controls_scaled = _transformed_new_or_scaled(
  new_variable=new_data.controls,
- transformer=self._meridian.controls_transformer,
- scaled_variable=self._meridian.controls_scaled,
+ transformer=self.model_context.controls_transformer,
+ scaled_variable=self.model_context.controls_scaled,
  )
  revenue_per_kpi = (
  new_data.revenue_per_kpi
  if new_data.revenue_per_kpi is not None
- else self._meridian.revenue_per_kpi
+ else self.model_context.revenue_per_kpi
  )

  if include_non_paid_channels:
  organic_media_scaled = _transformed_new_or_scaled(
  new_variable=new_data.organic_media,
- transformer=self._meridian.organic_media_tensors.organic_media_transformer,
- scaled_variable=self._meridian.organic_media_tensors.organic_media_scaled,
+ transformer=self.model_context.organic_media_tensors.organic_media_transformer,
+ scaled_variable=self.model_context.organic_media_tensors.organic_media_scaled,
  )
  organic_reach_scaled = _transformed_new_or_scaled(
  new_variable=new_data.organic_reach,
- transformer=self._meridian.organic_rf_tensors.organic_reach_transformer,
- scaled_variable=self._meridian.organic_rf_tensors.organic_reach_scaled,
+ transformer=self.model_context.organic_rf_tensors.organic_reach_transformer,
+ scaled_variable=self.model_context.organic_rf_tensors.organic_reach_scaled,
  )
  organic_frequency = (
  new_data.organic_frequency
  if new_data.organic_frequency is not None
- else self._meridian.organic_rf_tensors.organic_frequency
+ else self.model_context.organic_rf_tensors.organic_frequency
  )
  non_media_treatments_normalized = _transformed_new_or_scaled(
  new_variable=new_data.non_media_treatments,
- transformer=self._meridian.non_media_transformer,
- scaled_variable=self._meridian.non_media_treatments_normalized,
+ transformer=self.model_context.non_media_transformer,
+ scaled_variable=self.model_context.non_media_treatments_normalized,
  )
  return DataTensors(
  media=media_scaled,
@@ -1198,14 +1289,14 @@ class Analyzer:
  and organic RF parameters names in inference data.
  """
  params = []
- if self._meridian.media_tensors.media is not None:
+ if self.model_context.media_tensors.media is not None:
  params.extend([
  constants.EC_M,
  constants.SLOPE_M,
  constants.ALPHA_M,
  constants.BETA_GM,
  ])
- if self._meridian.rf_tensors.reach is not None:
+ if self.model_context.rf_tensors.reach is not None:
  params.extend([
  constants.EC_RF,
  constants.SLOPE_RF,
@@ -1213,21 +1304,21 @@ class Analyzer:
  constants.BETA_GRF,
  ])
  if include_non_paid_channels:
- if self._meridian.organic_media_tensors.organic_media is not None:
+ if self.model_context.organic_media_tensors.organic_media is not None:
  params.extend([
  constants.EC_OM,
  constants.SLOPE_OM,
  constants.ALPHA_OM,
  constants.BETA_GOM,
  ])
- if self._meridian.organic_rf_tensors.organic_reach is not None:
+ if self.model_context.organic_rf_tensors.organic_reach is not None:
  params.extend([
  constants.EC_ORF,
  constants.SLOPE_ORF,
  constants.ALPHA_ORF,
  constants.BETA_GORF,
  ])
- if self._meridian.non_media_treatments is not None:
+ if self.model_context.non_media_treatments is not None:
  params.extend([
  constants.GAMMA_GN,
  ])
@@ -1261,12 +1352,12 @@ class Analyzer:
  combined_betas = []
  if data_tensors.media is not None:
  combined_medias.append(
- self._meridian.adstock_hill_media(
+ self._model_equations.adstock_hill_media(
  media=data_tensors.media,
  alpha=dist_tensors.alpha_m,
  ec=dist_tensors.ec_m,
  slope=dist_tensors.slope_m,
- decay_functions=self._meridian.adstock_decay_spec.media,
+ decay_functions=self.model_context.adstock_decay_spec.media,
  n_times_output=n_times_output,
  )
  )
@@ -1274,38 +1365,38 @@ class Analyzer:

  if data_tensors.reach is not None:
  combined_medias.append(
- self._meridian.adstock_hill_rf(
+ self._model_equations.adstock_hill_rf(
  reach=data_tensors.reach,
  frequency=data_tensors.frequency,
  alpha=dist_tensors.alpha_rf,
  ec=dist_tensors.ec_rf,
  slope=dist_tensors.slope_rf,
- decay_functions=self._meridian.adstock_decay_spec.rf,
+ decay_functions=self.model_context.adstock_decay_spec.rf,
  n_times_output=n_times_output,
  )
  )
  combined_betas.append(dist_tensors.beta_grf)
  if data_tensors.organic_media is not None:
  combined_medias.append(
- self._meridian.adstock_hill_media(
+ self._model_equations.adstock_hill_media(
  media=data_tensors.organic_media,
  alpha=dist_tensors.alpha_om,
  ec=dist_tensors.ec_om,
  slope=dist_tensors.slope_om,
- decay_functions=self._meridian.adstock_decay_spec.organic_media,
+ decay_functions=self.model_context.adstock_decay_spec.organic_media,
  n_times_output=n_times_output,
  )
  )
  combined_betas.append(dist_tensors.beta_gom)
  if data_tensors.organic_reach is not None:
  combined_medias.append(
- self._meridian.adstock_hill_rf(
+ self._model_equations.adstock_hill_rf(
  reach=data_tensors.organic_reach,
  frequency=data_tensors.organic_frequency,
  alpha=dist_tensors.alpha_orf,
  ec=dist_tensors.ec_orf,
  slope=dist_tensors.slope_orf,
- decay_functions=self._meridian.adstock_decay_spec.organic_rf,
+ decay_functions=self.model_context.adstock_decay_spec.organic_rf,
  n_times_output=n_times_output,
  )
  )
@@ -1355,7 +1446,7 @@ class Analyzer:
  Returns:
  A tensor with filtered and/or aggregated geo and time dimensions.
  """
- mmm = self._meridian
+ m_context = self.model_context

  # Validate the tensor shape and determine if it has a media dimension.
  if flexible_time_dim:
@@ -1367,17 +1458,17 @@ class Analyzer:
  )
  n_times = tensor.shape[-2] if has_media_dim else tensor.shape[-1]
  else:
- n_times = mmm.n_times
+ n_times = m_context.n_times
  # Allowed subsets of channels: media, RF, media+RF, all channels.
  allowed_n_channels = [
- mmm.n_media_channels,
- mmm.n_rf_channels,
- mmm.n_media_channels + mmm.n_rf_channels,
- mmm.n_media_channels
- + mmm.n_rf_channels
- + mmm.n_non_media_channels
- + mmm.n_organic_media_channels
- + mmm.n_organic_rf_channels,
+ m_context.n_media_channels,
+ m_context.n_rf_channels,
+ m_context.n_media_channels + m_context.n_rf_channels,
+ m_context.n_media_channels
+ + m_context.n_rf_channels
+ + m_context.n_non_media_channels
+ + m_context.n_organic_media_channels
+ + m_context.n_organic_rf_channels,
  ]
  # Allow extra channel if aggregated (All_Channels) value is included.
  allowed_channel_dim = allowed_n_channels + [
@@ -1386,10 +1477,10 @@ class Analyzer:
  expected_shapes_w_media = [
  backend.TensorShape(shape)
  for shape in itertools.product(
- [mmm.n_geos], [n_times], allowed_channel_dim
+ [m_context.n_geos], [n_times], allowed_channel_dim
  )
  ]
- expected_shape_wo_media = backend.TensorShape([mmm.n_geos, n_times])
+ expected_shape_wo_media = backend.TensorShape([m_context.n_geos, n_times])
  if not flexible_time_dim:
  if tensor.shape[-3:] in expected_shapes_w_media:
  has_media_dim = True
@@ -1417,13 +1508,15 @@ class Analyzer:

  # Validate the selected geo and time dimensions and create a mask.
  if selected_geos is not None:
- if any(geo not in mmm.input_data.geo for geo in selected_geos):
+ if any(geo not in m_context.input_data.geo for geo in selected_geos):
  raise ValueError(
  "`selected_geos` must match the geo dimension names from "
  "meridian.InputData."
  )
  geo_indices = [
- i for i, x in enumerate(mmm.input_data.geo) if x in selected_geos
+ i
+ for i, x in enumerate(m_context.input_data.geo)
+ if x in selected_geos
  ]
  tensor = backend.gather(
  tensor,
@@ -1434,14 +1527,16 @@ class Analyzer:
  if selected_times is not None:
  _validate_selected_times(
  selected_times=selected_times,
- input_times=mmm.input_data.time,
+ input_times=m_context.input_data.time,
  n_times=tensor.shape[time_dim],
  arg_name="selected_times",
  comparison_arg_name="`tensor`",
  )
  if _is_str_list(selected_times):
  time_indices = [
- i for i, x in enumerate(mmm.input_data.time) if x in selected_times
+ i
+ for i, x in enumerate(m_context.input_data.time)
+ if x in selected_times
  ]
  tensor = backend.gather(
  tensor,
@@ -1558,13 +1653,13 @@ class Analyzer:
  """
  use_kpi = self._use_kpi(use_kpi)
  self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
- if self._meridian.is_national:
+ if self.model_context.is_national:
  _warn_if_geo_arg_in_kwargs(
  aggregate_geos=aggregate_geos,
  selected_geos=selected_geos,
  )
  dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
- if dist_type not in self._meridian.inference_data.groups():
+ if dist_type not in self._inference_data.groups():
  raise model.NotFittedModelError(
  f"sample_{dist_type}() must be called prior to calling"
  " `expected_outcome()`."
@@ -1577,14 +1672,14 @@ class Analyzer:
  )
  filled_tensors = new_data.validate_and_fill_missing_data(
  required_tensors_names=required_fields,
- meridian=self._meridian,
+ model_context=self.model_context,
  allow_modified_times=False,
  )

  params = (
- self._meridian.inference_data.posterior
+ self._inference_data.posterior
  if use_posterior
- else self._meridian.inference_data.prior
+ else self._inference_data.prior
  )
  # We always compute the expected outcome of all channels, including non-paid
  # channels.
@@ -1596,7 +1691,7 @@ class Analyzer:
  n_draws = params.draw.size
  n_chains = params.chain.size
  outcome_means = backend.zeros(
- (n_chains, 0, self._meridian.n_geos, self._meridian.n_times)
+ (n_chains, 0, self.model_context.n_geos, self.model_context.n_times)
  )
  batch_starting_indices = np.arange(n_draws, step=batch_size)
  param_list = (
@@ -1604,7 +1699,7 @@ class Analyzer:
  constants.MU_T,
  constants.TAU_G,
  ]
- + ([constants.GAMMA_GC] if self._meridian.n_controls else [])
+ + ([constants.GAMMA_GC] if self.model_context.n_controls else [])
  + self._get_causal_param_names(include_non_paid_channels=True)
  )
  outcome_means_temps = []
@@ -1626,7 +1721,7 @@ class Analyzer:
  [outcome_means, *outcome_means_temps], axis=1
  )
  if inverse_transform_outcome:
- outcome_means = self._meridian.kpi_transformer.inverse(outcome_means)
+ outcome_means = self.model_context.kpi_transformer.inverse(outcome_means)
  if not use_kpi:
  outcome_means *= filled_tensors.revenue_per_kpi

@@ -1700,7 +1795,7 @@ class Analyzer:
  " `_get_incremental_kpi` when `non_media_treatments` data is"
  " present."
  )
- n_media_times = self._meridian.n_media_times
+ n_media_times = self.model_context.n_media_times
  if data_tensors.media is not None:
  n_times = data_tensors.media.shape[1]  # pytype: disable=attribute-error
  n_times_output = n_times if n_times != n_media_times else None
@@ -1759,11 +1854,11 @@ class Analyzer:
  """
  use_kpi = self._use_kpi(use_kpi)
  if revenue_per_kpi is None:
- revenue_per_kpi = self._meridian.revenue_per_kpi
- t1 = self._meridian.kpi_transformer.inverse(
+ revenue_per_kpi = self.model_context.revenue_per_kpi
+ t1 = self.model_context.kpi_transformer.inverse(
  backend.einsum("...m->m...", modeled_incremental_outcome)
  )
- t2 = self._meridian.kpi_transformer.inverse(backend.zeros_like(t1))
+ t2 = self.model_context.kpi_transformer.inverse(backend.zeros_like(t1))
  kpi = backend.einsum("m...->...m", t1 - t2)

  if use_kpi:
@@ -1884,7 +1979,7 @@ class Analyzer:
  has_media_dim=True,
  )

- # TODO: b/407847021 - Add support for `new_data.time`.
+ # TODO: Add support for `new_data.time`.
  def incremental_outcome(
  self,
  use_posterior: bool = True,
@@ -2050,10 +2145,10 @@ class Analyzer:
  dimension and not all treatment variables are provided in `new_data`
  with matching time dimensions.
  """
- mmm = self._meridian
+ m_context = self.model_context
  use_kpi = self._use_kpi(use_kpi)
  self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
- if self._meridian.is_national:
+ if m_context.is_national:
  _warn_if_geo_arg_in_kwargs(
  aggregate_geos=aggregate_geos,
  selected_geos=selected_geos,
@@ -2061,7 +2156,7 @@ class Analyzer:
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
  dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR

- if dist_type not in mmm.inference_data.groups():
+ if dist_type not in self.inference_data.groups():
  raise model.NotFittedModelError(
  f"sample_{dist_type}() must be called prior to calling this method."
  )
@@ -2084,23 +2179,26 @@ class Analyzer:
  if include_non_paid_channels:
  required_params += constants.NON_PAID_DATA
  data_tensors = new_data.validate_and_fill_missing_data(
- required_tensors_names=required_params, meridian=self._meridian
+ required_tensors_names=required_params,
+ model_context=self.model_context,
+ )
+ new_n_media_times = data_tensors.get_modified_times(
+ model_context=self.model_context
  )
- new_n_media_times = data_tensors.get_modified_times(self._meridian)

  if new_n_media_times is None:
- new_n_media_times = mmm.n_media_times
+ new_n_media_times = m_context.n_media_times
  _validate_selected_times(
  selected_times=selected_times,
- input_times=mmm.input_data.time,
- n_times=mmm.n_times,
+ input_times=m_context.input_data.time,
+ n_times=m_context.n_times,
  arg_name="selected_times",
  comparison_arg_name="the input data",
  )
  _validate_selected_times(
  selected_times=media_selected_times,
- input_times=mmm.input_data.media_time,
- n_times=mmm.n_media_times,
+ input_times=m_context.input_data.media_time,
+ n_times=m_context.n_media_times,
  arg_name="media_selected_times",
  comparison_arg_name="the media tensors",
  )
@@ -2115,7 +2213,7 @@ class Analyzer:
  else:
  if all(isinstance(time, str) for time in media_selected_times):
  media_selected_times = [
- x in media_selected_times for x in mmm.input_data.media_time
+ x in media_selected_times for x in m_context.input_data.media_time
  ]

  # Set counterfactual tensors based on the scaling factors and the media
@@ -2129,11 +2227,11 @@ class Analyzer:

  if data_tensors.non_media_treatments is not None:
  non_media_treatments_baseline_scaled = (
- self._meridian.compute_non_media_treatments_baseline(
+ self._model_equations.compute_non_media_treatments_baseline(
  non_media_baseline_values=non_media_baseline_values,
  )
  )
- non_media_treatments_baseline_normalized = self._meridian.non_media_transformer.forward(  # pytype: disable=attribute-error
+ non_media_treatments_baseline_normalized = self.model_context.non_media_transformer.forward(  # pytype: disable=attribute-error
  non_media_treatments_baseline_scaled,
  apply_population_scaling=False,
  )
@@ -2160,7 +2258,7 @@ class Analyzer:
  new_data=incremented_data0,
  include_non_paid_channels=include_non_paid_channels,
  )
- # TODO: b/415198977 - Verify the computation of outcome of non-media
+ # TODO: Verify the computation of outcome of non-media
  # treatments with `media_selected_times` and scale factors.

  data_tensors0 = DataTensors(
@@ -2181,9 +2279,9 @@ class Analyzer:

  # Calculate incremental outcome in batches.
  params = (
- self._meridian.inference_data.posterior
+ self.inference_data.posterior
  if use_posterior
- else self._meridian.inference_data.prior
+ else self.inference_data.prior
  )
  n_draws = params.draw.size
  batch_starting_indices = np.arange(n_draws, step=batch_size)
@@ -2252,23 +2350,23 @@ class Analyzer:
  ValueError: If the geo or time granularity arguments are not valid for the
  ROI analysis.
  """
- if self._meridian.is_national:
+ if self.model_context.is_national:
  _warn_if_geo_arg_in_kwargs(
  aggregate_geos=aggregate_geos,
  selected_geos=selected_geos,
  )
  if selected_geos is not None or not aggregate_geos:
  if (
- self._meridian.media_tensors.media_spend is not None
- and not self._meridian.input_data.media_spend_has_geo_dimension
+ self.model_context.media_tensors.media_spend is not None
+ and not self.model_context.input_data.media_spend_has_geo_dimension
  ):
  raise ValueError(
  "`selected_geos` and `aggregate_geos=False` are not allowed because"
  " Meridian `media_spend` data does not have a geo dimension."
  )
  if (
- self._meridian.rf_tensors.rf_spend is not None
- and not self._meridian.input_data.rf_spend_has_geo_dimension
+ self.model_context.rf_tensors.rf_spend is not None
+ and not self.model_context.input_data.rf_spend_has_geo_dimension
  ):
  raise ValueError(
  "`selected_geos` and `aggregate_geos=False` are not allowed because"
@@ -2277,16 +2375,16 @@ class Analyzer:

  if selected_times is not None:
  if (
- self._meridian.media_tensors.media_spend is not None
- and not self._meridian.input_data.media_spend_has_time_dimension
+ self.model_context.media_tensors.media_spend is not None
+ and not self.model_context.input_data.media_spend_has_time_dimension
  ):
  raise ValueError(
  "`selected_times` is not allowed because Meridian `media_spend`"
  " data does not have a time dimension."
  )
  if (
- self._meridian.rf_tensors.rf_spend is not None
- and not self._meridian.input_data.rf_spend_has_time_dimension
+ self.model_context.rf_tensors.rf_spend is not None
+ and not self.model_context.input_data.rf_spend_has_time_dimension
  ):
  raise ValueError(
  "`selected_times` is not allowed because Meridian `rf_spend` data"
@@ -2379,7 +2477,7 @@ class Analyzer:
  new_data = DataTensors()
  filled_data = new_data.validate_and_fill_missing_data(
  required_tensors_names=required_values,
- meridian=self._meridian,
+ model_context=self.model_context,
  )
  numerator = self.incremental_outcome(
  new_data=filled_data.filter_fields(constants.PAID_DATA),
@@ -2396,25 +2494,26 @@ class Analyzer:
  )
  spend_inc = filled_data.total_spend() * incremental_increase
  if spend_inc is not None and spend_inc.ndim == 3:
- denominator = self.filter_and_aggregate_geos_and_times(
- spend_inc,
- aggregate_times=True,
- flexible_time_dim=True,
- has_media_dim=True,
- **dim_kwargs,
+ return backend.divide(
+ numerator,
+ self.filter_and_aggregate_geos_and_times(
+ spend_inc,
+ aggregate_times=True,
+ flexible_time_dim=True,
+ has_media_dim=True,
+ **dim_kwargs,
+ ),
  )
- else:
- if not aggregate_geos:
- # This check should not be reachable. It is here to protect against
- # future changes to self._validate_geo_and_time_granularity. If
- # spend_inc.ndim is not 3 and `aggregate_geos` is `False`, then
- # self._validate_geo_and_time_granularity should raise an error.
- raise ValueError(
- "aggregate_geos must be True if spend does not have a geo "
- "dimension."
- )
- denominator = spend_inc
- return backend.divide_no_nan(numerator, denominator)
+
+ if not aggregate_geos:
+ # This check should not be reachable. It is here to protect against
+ # future changes to self._validate_geo_and_time_granularity. If
+ # spend_inc.ndim is not 3 and `aggregate_geos` is `False`, then
+ # self._validate_geo_and_time_granularity should raise an error.
+ raise ValueError(
+ "aggregate_geos must be True if spend does not have a geo dimension."
+ )
+ return backend.divide(numerator, spend_inc)

  def roi(
  self,
@@ -2501,7 +2600,7 @@ class Analyzer:
  new_data = DataTensors()
  filled_data = new_data.validate_and_fill_missing_data(
  required_tensors_names=required_values,
- meridian=self._meridian,
+ model_context=self.model_context,
  )
  incremental_outcome = self.incremental_outcome(
  new_data=filled_data.filter_fields(constants.PAID_DATA),
@@ -2511,26 +2610,27 @@ class Analyzer:
  )

  spend = filled_data.total_spend()
  if spend is not None and spend.ndim == 3:
- denominator = self.filter_and_aggregate_geos_and_times(
- spend,
- aggregate_times=True,
- flexible_time_dim=True,
- has_media_dim=True,
- **dim_kwargs,
+ return backend.divide(
+ incremental_outcome,
+ self.filter_and_aggregate_geos_and_times(
+ spend,
+ aggregate_times=True,
+ flexible_time_dim=True,
+ has_media_dim=True,
+ **dim_kwargs,
+ ),
  )
- else:
- if not aggregate_geos:
- # This check should not be reachable. It is here to protect against
- # future changes to self._validate_geo_and_time_granularity. If
- # spend_inc.ndim is not 3 and either of `aggregate_geos` or
- # `aggregate_times` is `False`, then
- # self._validate_geo_and_time_granularity should raise an error.
- raise ValueError(
- "aggregate_geos must be True if spend does not have a geo "
- "dimension."
- )
- denominator = spend
- return backend.divide_no_nan(incremental_outcome, denominator)
+
+ if not aggregate_geos:
+ # This check should not be reachable. It is here to protect against
+ # future changes to self._validate_geo_and_time_granularity. If
+ # spend_inc.ndim is not 3 and either of `aggregate_geos` or
+ # `aggregate_times` is `False`, then
+ # self._validate_geo_and_time_granularity should raise an error.
+ raise ValueError(
+ "aggregate_geos must be True if spend does not have a geo dimension."
+ )
+ return backend.divide(incremental_outcome, spend)

  def cpik(
  self,
@@ -2605,7 +2705,7 @@ class Analyzer:
  aggregate_geos=aggregate_geos,
  batch_size=batch_size,
  )
- return backend.divide_no_nan(1.0, roi)
+ return backend.divide(1, roi)

  def _mean_and_ci_by_eval_set(
  self,
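Note that `mroi`, `roi`, and `cpik` now divide with `backend.divide` rather than `backend.divide_no_nan`. Assuming the backend wrappers follow NumPy/TensorFlow division semantics (an assumption, not stated in this diff), a zero spend denominator now surfaces as inf/nan instead of being silently mapped to 0. A small NumPy illustration of the behavioural difference:

    # Illustrative only; mirrors divide vs. divide_no_nan semantics with NumPy.
    import numpy as np

    incremental_outcome = np.array([120.0, 80.0, 50.0])
    spend = np.array([40.0, 0.0, 25.0])

    with np.errstate(divide="ignore", invalid="ignore"):
        roi = np.divide(incremental_outcome, spend)   # new-style: [3., inf, 2.]
    roi_no_nan = np.where(spend == 0, 0.0, roi)        # old-style: [3., 0., 2.]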
@@ -2647,8 +2747,12 @@ class Analyzer:
2647
2747
  draws, confidence_level=confidence_level
2648
2748
  )
2649
2749
 
2650
- train_draws = np.where(self._meridian.model_spec.holdout_id, np.nan, draws)
2651
- test_draws = np.where(self._meridian.model_spec.holdout_id, draws, np.nan)
2750
+ train_draws = np.where(
2751
+ self.model_context.model_spec.holdout_id, np.nan, draws
2752
+ )
2753
+ test_draws = np.where(
2754
+ self.model_context.model_spec.holdout_id, draws, np.nan
2755
+ )
2652
2756
  draws_by_evaluation_set = np.stack(
2653
2757
  [train_draws, test_draws, draws], axis=0
2654
2758
  ) # shape (n_evaluation_sets(=3), n_chains, n_draws, n_geos, n_times)
@@ -2669,13 +2773,14 @@ class Analyzer:
2669
2773
 
2670
2774
  def _can_split_by_holdout_id(self, split_by_holdout_id: bool) -> bool:
2671
2775
  """Returns whether the data can be split by holdout_id."""
2672
- if split_by_holdout_id and self._meridian.model_spec.holdout_id is None:
2776
+ if split_by_holdout_id and self.model_context.model_spec.holdout_id is None:
2673
2777
  warnings.warn(
2674
2778
  "`split_by_holdout_id` is True but `holdout_id` is `None`. Data will"
2675
2779
  " not be split."
2676
2780
  )
2677
2781
  return (
2678
- split_by_holdout_id and self._meridian.model_spec.holdout_id is not None
2782
+ split_by_holdout_id
2783
+ and self.model_context.model_spec.holdout_id is not None
2679
2784
  )
2680
2785
 
2681
2786
  def expected_vs_actual_data(
@@ -2713,7 +2818,7 @@ class Analyzer:
2713
2818
  """
2714
2819
  _validate_non_media_baseline_values_numbers(non_media_baseline_values)
2715
2820
  use_kpi = self._use_kpi(use_kpi)
2716
- mmm = self._meridian
2821
+ m_context = self.model_context
2717
2822
  can_split_by_holdout = self._can_split_by_holdout_id(split_by_holdout_id)
2718
2823
  expected_outcome = self.expected_outcome(
2719
2824
  aggregate_geos=False, aggregate_times=False, use_kpi=use_kpi
@@ -2742,7 +2847,9 @@ class Analyzer:
2742
2847
  )
2743
2848
  actual = np.asarray(
2744
2849
  self.filter_and_aggregate_geos_and_times(
2745
- mmm.kpi if use_kpi else mmm.kpi * mmm.revenue_per_kpi,
2850
+ m_context.kpi
2851
+ if use_kpi
2852
+ else m_context.kpi * m_context.revenue_per_kpi,
2746
2853
  aggregate_geos=aggregate_geos,
2747
2854
  aggregate_times=aggregate_times,
2748
2855
  )
@@ -2754,9 +2861,9 @@ class Analyzer:
2754
2861
  }
2755
2862
 
2756
2863
  if not aggregate_geos:
2757
- coords[constants.GEO] = mmm.input_data.geo.data
2864
+ coords[constants.GEO] = m_context.input_data.geo.data
2758
2865
  if not aggregate_times:
2759
- coords[constants.TIME] = mmm.input_data.time.data
2866
+ coords[constants.TIME] = m_context.input_data.time.data
2760
2867
  if can_split_by_holdout:
2761
2868
  coords[constants.EVALUATION_SET_VAR] = list(constants.EVALUATION_SET)
2762
2869
 
@@ -2816,56 +2923,57 @@ class Analyzer:
2816
2923
  n_draws, n_geos, n_times)`. The `n_geos` and `n_times` dimensions is
2817
2924
  dropped if `aggregate_geos=True` or `aggregate_time=True`, respectively.
2818
2925
  """
2926
+ ctx = self.model_context
2819
2927
  new_media = (
2820
- backend.zeros_like(self._meridian.media_tensors.media)
2821
- if self._meridian.media_tensors.media is not None
2928
+ backend.zeros_like(ctx.media_tensors.media)
2929
+ if ctx.media_tensors.media is not None
2822
2930
  else None
2823
2931
  )
2824
2932
  # Frequency is not needed because the reach is zero.
2825
2933
  new_reach = (
2826
- backend.zeros_like(self._meridian.rf_tensors.reach)
2827
- if self._meridian.rf_tensors.reach is not None
2934
+ backend.zeros_like(ctx.rf_tensors.reach)
2935
+ if ctx.rf_tensors.reach is not None
2828
2936
  else None
2829
2937
  )
2830
2938
  new_organic_media = (
2831
- backend.zeros_like(self._meridian.organic_media_tensors.organic_media)
2832
- if self._meridian.organic_media_tensors.organic_media is not None
2939
+ backend.zeros_like(ctx.organic_media_tensors.organic_media)
2940
+ if ctx.organic_media_tensors.organic_media is not None
2833
2941
  else None
2834
2942
  )
2835
2943
  new_organic_reach = (
2836
- backend.zeros_like(self._meridian.organic_rf_tensors.organic_reach)
2837
- if self._meridian.organic_rf_tensors.organic_reach is not None
2944
+ backend.zeros_like(ctx.organic_rf_tensors.organic_reach)
2945
+ if ctx.organic_rf_tensors.organic_reach is not None
2838
2946
  else None
2839
2947
  )
2840
- if self._meridian.non_media_treatments is not None:
2841
- if self._meridian.model_spec.non_media_population_scaling_id is not None:
2948
+ if ctx.non_media_treatments is not None:
2949
+ if ctx.model_spec.non_media_population_scaling_id is not None:
2842
2950
  scaling_factors = backend.where(
2843
- self._meridian.model_spec.non_media_population_scaling_id,
2844
- self._meridian.population[:, backend.newaxis, backend.newaxis],
2845
- backend.ones_like(self._meridian.population)[
2951
+ ctx.model_spec.non_media_population_scaling_id,
2952
+ ctx.population[:, backend.newaxis, backend.newaxis],
2953
+ backend.ones_like(ctx.population)[
2846
2954
  :, backend.newaxis, backend.newaxis
2847
2955
  ],
2848
2956
  )
2849
2957
  else:
2850
- scaling_factors = backend.ones_like(self._meridian.population)[
2958
+ scaling_factors = backend.ones_like(ctx.population)[
2851
2959
  :, backend.newaxis, backend.newaxis
2852
2960
  ]
2853
2961
 
2854
- baseline = self._meridian.compute_non_media_treatments_baseline(
2962
+ baseline = self._model_equations.compute_non_media_treatments_baseline(
2855
2963
  non_media_baseline_values=non_media_baseline_values,
2856
2964
  )
2857
2965
  new_non_media_treatments_population_scaled = backend.broadcast_to(
2858
2966
  backend.to_tensor(baseline, dtype=backend.float32)[
2859
2967
  backend.newaxis, backend.newaxis, :
2860
2968
  ],
2861
- self._meridian.non_media_treatments.shape,
2969
+ ctx.non_media_treatments.shape,
2862
2970
  )
2863
2971
  new_non_media_treatments = (
2864
2972
  new_non_media_treatments_population_scaled * scaling_factors
2865
2973
  )
2866
2974
  else:
2867
2975
  new_non_media_treatments = None
2868
- new_controls = self._meridian.controls
2976
+ new_controls = ctx.controls
2869
2977
 
2870
2978
  new_data = DataTensors(
2871
2979
  media=new_media,
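This hunk builds the counterfactual inputs used for the baseline: paid and organic media (and reach) are zeroed, while non-media treatments are reset to their baseline values, rescaled by population only for the channels flagged in `non_media_population_scaling_id`. A rough NumPy sketch of that scaling step, under the assumption that the flag is a boolean vector with one entry per non-media channel:

import numpy as np

population = np.array([1000.0, 250.0])                 # (n_geos,)
scaling_id = np.array([True, False, True])             # one flag per non-media channel
baseline = np.array([0.0, 5.0, 1.0])                   # baseline value per channel

# Scale by population only for the flagged channels; otherwise use a factor of 1.
scaling_factors = np.where(scaling_id, population[:, None, None], 1.0)
new_non_media = baseline[None, None, :] * scaling_factors  # (n_geos, 1, n_channels)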
@@ -3134,23 +3242,25 @@ class Analyzer:
3134
3242
  + (constants.CHANNEL,)
3135
3243
  )
3136
3244
  channels = (
3137
- self._meridian.input_data.get_all_channels()
3245
+ self.model_context.input_data.get_all_channels()
3138
3246
  if include_non_paid_channels
3139
- else self._meridian.input_data.get_all_paid_channels()
3247
+ else self.model_context.input_data.get_all_paid_channels()
3140
3248
  )
3141
3249
  xr_coords = {constants.CHANNEL: list(channels) + [constants.ALL_CHANNELS]}
3142
3250
  if not aggregate_geos:
3143
3251
  geo_dims = (
3144
- self._meridian.input_data.geo.data
3252
+ self.model_context.input_data.geo.data
3145
3253
  if selected_geos is None
3146
3254
  else selected_geos
3147
3255
  )
3148
3256
  xr_coords[constants.GEO] = geo_dims
3149
3257
  if not aggregate_times:
3150
3258
  # Get the time coordinates for flexible time dimensions.
3151
- modified_times = new_data.get_modified_times(self._meridian)
3259
+ modified_times = new_data.get_modified_times(
3260
+ model_context=self.model_context
3261
+ )
3152
3262
  if modified_times is None:
3153
- times = self._meridian.input_data.time.data
3263
+ times = self.model_context.input_data.time.data
3154
3264
  else:
3155
3265
  times = np.arange(modified_times)
3156
3266
 
@@ -3199,7 +3309,7 @@ class Analyzer:
3199
3309
  # channels.
3200
3310
  ).where(lambda ds: ds.channel != constants.ALL_CHANNELS)
3201
3311
 
3202
- if new_data.get_modified_times(self._meridian) is None:
3312
+ if new_data.get_modified_times(model_context=self.model_context) is None:
3203
3313
  expected_outcome_fields = list(
3204
3314
  constants.PAID_DATA + constants.NON_PAID_DATA + (constants.CONTROLS,)
3205
3315
  )
@@ -3239,26 +3349,35 @@ class Analyzer:
3239
3349
  "Effectiveness is not reported because it does not have a clear"
3240
3350
  " interpretation by time period."
3241
3351
  )
3242
- return xr.merge([
3243
- incremental_outcome,
3244
- pct_of_contribution,
3245
- ])
3352
+ return xr.merge(
3353
+ [
3354
+ incremental_outcome,
3355
+ pct_of_contribution,
3356
+ ],
3357
+ compat="no_conflicts",
3358
+ )
3246
3359
  else:
3247
- return xr.merge([
3248
- incremental_outcome,
3249
- pct_of_contribution,
3250
- effectiveness,
3251
- ])
3360
+ return xr.merge(
3361
+ [
3362
+ incremental_outcome,
3363
+ pct_of_contribution,
3364
+ effectiveness,
3365
+ ],
3366
+ compat="no_conflicts",
3367
+ )
3252
3368
 
3253
3369
  # If non-paid channels are not included, return all metrics, paid and
3254
3370
  # non-paid.
3255
3371
  spend_list = []
3256
3372
  new_spend_tensors = new_data.filter_fields(
3257
3373
  constants.SPEND_DATA
3258
- ).validate_and_fill_missing_data(constants.SPEND_DATA, self._meridian)
3259
- if self._meridian.n_media_channels > 0:
3374
+ ).validate_and_fill_missing_data(
3375
+ required_tensors_names=constants.SPEND_DATA,
3376
+ model_context=self.model_context,
3377
+ )
3378
+ if self.model_context.n_media_channels > 0:
3260
3379
  spend_list.append(new_spend_tensors.media_spend)
3261
- if self._meridian.n_rf_channels > 0:
3380
+ if self.model_context.n_rf_channels > 0:
3262
3381
  spend_list.append(new_spend_tensors.rf_spend)
3263
3382
  # TODO Add support for 1-dimensional spend.
3264
3383
  aggregated_spend = self.filter_and_aggregate_geos_and_times(
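The merged summary datasets are now combined with `compat="no_conflicts"`, which allows the inputs to share variable names as long as their values only disagree where one of them is NaN; passing it explicitly pins that behavior rather than relying on the library default. A minimal xarray example of what this mode permits:

import numpy as np
import xarray as xr

a = xr.Dataset({"roi": ("channel", [1.2, np.nan])}, coords={"channel": ["tv", "search"]})
b = xr.Dataset({"roi": ("channel", [np.nan, 0.8])}, coords={"channel": ["tv", "search"]})

# The NaN in each dataset is filled from the other; a true value conflict would raise.
merged = xr.merge([a, b], compat="no_conflicts")   # roi == [1.2, 0.8]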
@@ -3288,11 +3407,14 @@ class Analyzer:
3288
3407
  "ROI, mROI, Effectiveness, and CPIK are not reported because they "
3289
3408
  "do not have a clear interpretation by time period."
3290
3409
  )
3291
- return xr.merge([
3292
- spend_data,
3293
- incremental_outcome,
3294
- pct_of_contribution,
3295
- ])
3410
+ return xr.merge(
3411
+ [
3412
+ spend_data,
3413
+ incremental_outcome,
3414
+ pct_of_contribution,
3415
+ ],
3416
+ compat="no_conflicts",
3417
+ )
3296
3418
  else:
3297
3419
  roi = self._compute_roi_aggregate(
3298
3420
  incremental_outcome_prior=incremental_outcome_prior,
@@ -3339,15 +3461,18 @@ class Analyzer:
3339
3461
  xr_coords=xr_coords_with_ci_and_distribution,
3340
3462
  confidence_level=confidence_level,
3341
3463
  )
3342
- return xr.merge([
3343
- spend_data,
3344
- incremental_outcome,
3345
- pct_of_contribution,
3346
- roi,
3347
- effectiveness,
3348
- mroi,
3349
- cpik,
3350
- ])
3464
+ return xr.merge(
3465
+ [
3466
+ spend_data,
3467
+ incremental_outcome,
3468
+ pct_of_contribution,
3469
+ roi,
3470
+ effectiveness,
3471
+ mroi,
3472
+ cpik,
3473
+ ],
3474
+ compat="no_conflicts",
3475
+ )
3351
3476
 
3352
3477
  def get_aggregated_impressions(
3353
3478
  self,
@@ -3401,17 +3526,18 @@ class Analyzer:
3401
3526
  if new_data is None:
3402
3527
  new_data = DataTensors()
3403
3528
  data_tensors = new_data.validate_and_fill_missing_data(
3404
- tensor_names_list, self._meridian
3529
+ required_tensors_names=tensor_names_list,
3530
+ model_context=self.model_context,
3405
3531
  )
3406
3532
  n_times = (
3407
- data_tensors.get_modified_times(self._meridian)
3408
- or self._meridian.n_times
3533
+ data_tensors.get_modified_times(model_context=self.model_context)
3534
+ or self.model_context.n_times
3409
3535
  )
3410
3536
  impressions_list = []
3411
- if self._meridian.n_media_channels > 0:
3537
+ if self.model_context.n_media_channels > 0:
3412
3538
  impressions_list.append(data_tensors.media[:, -n_times:, :])
3413
3539
 
3414
- if self._meridian.n_rf_channels > 0:
3540
+ if self.model_context.n_rf_channels > 0:
3415
3541
  if optimal_frequency is None:
3416
3542
  new_frequency = data_tensors.frequency
3417
3543
  else:
@@ -3423,9 +3549,9 @@ class Analyzer:
3423
3549
  )
3424
3550
 
3425
3551
  if include_non_paid_channels:
3426
- if self._meridian.n_organic_media_channels > 0:
3552
+ if self.model_context.n_organic_media_channels > 0:
3427
3553
  impressions_list.append(data_tensors.organic_media[:, -n_times:, :])
3428
- if self._meridian.n_organic_rf_channels > 0:
3554
+ if self.model_context.n_organic_rf_channels > 0:
3429
3555
  if optimal_frequency is None:
3430
3556
  new_organic_frequency = data_tensors.organic_frequency
3431
3557
  else:
@@ -3437,7 +3563,7 @@ class Analyzer:
3437
3563
  data_tensors.organic_reach[:, -n_times:, :]
3438
3564
  * new_organic_frequency[:, -n_times:, :]
3439
3565
  )
3440
- if self._meridian.n_non_media_channels > 0:
3566
+ if self.model_context.n_non_media_channels > 0:
3441
3567
  impressions_list.append(data_tensors.non_media_treatments)
3442
3568
 
3443
3569
  return self.filter_and_aggregate_geos_and_times(
@@ -3512,14 +3638,14 @@ class Analyzer:
3512
3638
  xr_coords = {constants.CHANNEL: [constants.BASELINE]}
3513
3639
  if not aggregate_geos:
3514
3640
  geo_dims = (
3515
- self._meridian.input_data.geo.data
3641
+ self.model_context.input_data.geo.data
3516
3642
  if selected_geos is None
3517
3643
  else selected_geos
3518
3644
  )
3519
3645
  xr_coords[constants.GEO] = geo_dims
3520
3646
  if not aggregate_times:
3521
3647
  time_dims = (
3522
- self._meridian.input_data.time.data
3648
+ self.model_context.input_data.time.data
3523
3649
  if selected_times is None
3524
3650
  else selected_times
3525
3651
  )
@@ -3585,10 +3711,13 @@ class Analyzer:
3585
3711
  confidence_level=confidence_level,
3586
3712
  ).sel(channel=constants.BASELINE)
3587
3713
 
3588
- return xr.merge([
3589
- baseline_outcome,
3590
- baseline_pct_of_contribution,
3591
- ])
3714
+ return xr.merge(
3715
+ [
3716
+ baseline_outcome,
3717
+ baseline_pct_of_contribution,
3718
+ ],
3719
+ compat="no_conflicts",
3720
+ )
3592
3721
 
3593
3722
  def optimal_freq(
3594
3723
  self,
@@ -3680,43 +3809,48 @@ class Analyzer:
3680
3809
  dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
3681
3810
  use_kpi = self._use_kpi(use_kpi)
3682
3811
  new_data = new_data or DataTensors()
3683
- if self._meridian.n_rf_channels == 0:
3812
+ if self.model_context.n_rf_channels == 0:
3684
3813
  raise ValueError(
3685
3814
  "Must have at least one channel with reach and frequency data."
3686
3815
  )
3687
- if dist_type not in self._meridian.inference_data.groups():
3816
+ if dist_type not in self.inference_data.groups():
3688
3817
  raise model.NotFittedModelError(
3689
3818
  f"sample_{dist_type}() must be called prior to calling this method."
3690
3819
  )
3691
3820
 
3692
3821
  filled_data = new_data.validate_and_fill_missing_data(
3693
- [
3822
+ required_tensors_names=[
3694
3823
  constants.RF_IMPRESSIONS,
3695
3824
  constants.RF_SPEND,
3696
3825
  constants.REVENUE_PER_KPI,
3697
3826
  ],
3698
- self._meridian,
3827
+ model_context=self.model_context,
3699
3828
  )
3700
3829
  # TODO: Once treatment type filtering is added, remove adding
3701
3830
  # dummy media and media spend to `roi()` and `summary_metrics()`. This is a
3702
3831
  # hack to use `roi()` and `summary_metrics()` for RF only analysis.
3703
- has_media = self._meridian.n_media_channels > 0
3832
+ has_media = self.model_context.n_media_channels > 0
3704
3833
  n_media_times = (
3705
- filled_data.get_modified_times(self._meridian)
3706
- or self._meridian.n_media_times
3834
+ filled_data.get_modified_times(model_context=self.model_context)
3835
+ or self.model_context.n_media_times
3707
3836
  )
3708
3837
  n_times = (
3709
- filled_data.get_modified_times(self._meridian) or self._meridian.n_times
3710
- )
3711
- dummy_media = backend.ones(
3712
- (self._meridian.n_geos, n_media_times, self._meridian.n_media_channels)
3713
- )
3714
- dummy_media_spend = backend.ones(
3715
- (self._meridian.n_geos, n_times, self._meridian.n_media_channels)
3838
+ filled_data.get_modified_times(model_context=self.model_context)
3839
+ or self.model_context.n_times
3716
3840
  )
3841
+ dummy_media = backend.ones((
3842
+ self.model_context.n_geos,
3843
+ n_media_times,
3844
+ self.model_context.n_media_channels,
3845
+ ))
3846
+ dummy_media_spend = backend.ones((
3847
+ self.model_context.n_geos,
3848
+ n_times,
3849
+ self.model_context.n_media_channels,
3850
+ ))
3717
3851
 
3718
3852
  max_freq = max_frequency or np.max(
3719
- np.array(self._meridian.rf_tensors.frequency)
3853
+ np.array(self.model_context.rf_tensors.frequency)
3720
3854
  )
3721
3855
  if freq_grid is None:
3722
3856
  freq_grid = np.arange(1, max_freq, 0.1)
@@ -3724,7 +3858,9 @@ class Analyzer:
3724
3858
  # Create a frequency grid for shape (len(freq_grid), n_rf_channels, 4) where
3725
3859
  # the last argument is for the mean, median, lower and upper confidence
3726
3860
  # intervals.
3727
- metric_grid = np.zeros((len(freq_grid), self._meridian.n_rf_channels, 4))
3861
+ metric_grid = np.zeros(
3862
+ (len(freq_grid), self.model_context.n_rf_channels, 4)
3863
+ )
3728
3864
 
3729
3865
  for i, freq in enumerate(freq_grid):
3730
3866
  new_frequency = backend.ones_like(filled_data.rf_impressions) * freq
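The optimal-frequency search evaluates a simple grid: by default every 0.1 between 1 and the maximum observed frequency, and the frequency with the highest mean metric per RF channel is picked later with `np.nanargmax`. A compact NumPy sketch of that search, with a toy metric standing in for the posterior-mean ROI/CPIK evaluation:

import numpy as np

max_freq = 6.0
freq_grid = np.arange(1, max_freq, 0.1)                  # candidate frequencies

def metric_mean(freq: float) -> np.ndarray:
  # Toy stand-in for the per-channel posterior-mean metric at this frequency.
  return np.array([-(freq - 2.5) ** 2, -(freq - 4.0) ** 2])

metric_grid = np.stack([metric_mean(f) for f in freq_grid])  # (n_freqs, n_rf_channels)
optimal_idx = np.nanargmax(metric_grid, axis=0)
optimal_frequency = freq_grid[optimal_idx]               # roughly [2.5, 4.0]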
@@ -3744,15 +3880,15 @@ class Analyzer:
3744
3880
  selected_times=selected_times,
3745
3881
  aggregate_geos=True,
3746
3882
  use_kpi=use_kpi,
3747
- )[..., -self._meridian.n_rf_channels :]
3883
+ )[..., -self.model_context.n_rf_channels :]
3748
3884
  metric_grid[i, :] = get_central_tendency_and_ci(
3749
3885
  metric_grid_temp, confidence_level, include_median=True
3750
3886
  )
3751
3887
 
3752
3888
  optimal_freq_idx = np.nanargmax(metric_grid[:, :, 0], axis=0)
3753
3889
  rf_channel_values = (
3754
- self._meridian.input_data.rf_channel.values
3755
- if self._meridian.input_data.rf_channel is not None
3890
+ self.model_context.input_data.rf_channel.values
3891
+ if self.model_context.input_data.rf_channel is not None
3756
3892
  else []
3757
3893
  )
3758
3894
 
@@ -3901,7 +4037,7 @@ class Analyzer:
3901
4037
  three metrics are computed for each.
3902
4038
  """
3903
4039
  use_kpi = self._use_kpi(use_kpi)
3904
- if self._meridian.is_national:
4040
+ if self.model_context.is_national:
3905
4041
  _warn_if_geo_arg_in_kwargs(
3906
4042
  selected_geos=selected_geos,
3907
4043
  )
@@ -3922,9 +4058,9 @@ class Analyzer:
3922
4058
  constants.GEO_GRANULARITY: [constants.GEO, constants.NATIONAL],
3923
4059
  }
3924
4060
  if use_kpi:
3925
- input_tensor = self._meridian.kpi
4061
+ input_tensor = self.model_context.kpi
3926
4062
  else:
3927
- input_tensor = self._meridian.kpi * self._meridian.revenue_per_kpi
4063
+ input_tensor = self.model_context.kpi * self.model_context.revenue_per_kpi
3928
4064
  actual = np.asarray(
3929
4065
  self.filter_and_aggregate_geos_and_times(
3930
4066
  tensor=input_tensor,
@@ -3941,7 +4077,7 @@ class Analyzer:
3941
4077
  rsquared_national, mape_national, wmape_national = (
3942
4078
  self._predictive_accuracy_helper(np.sum(actual, 0), np.sum(expected, 0))
3943
4079
  )
3944
- if self._meridian.model_spec.holdout_id is None:
4080
+ if self.model_context.model_spec.holdout_id is None:
3945
4081
  rsquared_arr = [rsquared, rsquared_national]
3946
4082
  mape_arr = [mape, mape_national]
3947
4083
  wmape_arr = [wmape, wmape_national]
@@ -3955,7 +4091,9 @@ class Analyzer:
3955
4091
  xr_coords[constants.EVALUATION_SET_VAR] = list(constants.EVALUATION_SET)
3956
4092
 
3957
4093
  holdout_id = self._filter_holdout_id_for_selected_geos_and_times(
3958
- self._meridian.model_spec.holdout_id, selected_geos, selected_times
4094
+ self.model_context.model_spec.holdout_id,
4095
+ selected_geos,
4096
+ selected_times,
3959
4097
  )
3960
4098
 
3961
4099
  nansum = lambda x: np.where(
@@ -3985,7 +4123,7 @@ class Analyzer:
3985
4123
  )
3986
4124
  xr_data = {constants.VALUE: (xr_dims, stacked_total)}
3987
4125
  dataset = xr.Dataset(data_vars=xr_data, coords=xr_coords)
3988
- if self._meridian.is_national:
4126
+ if self.model_context.is_national:
3989
4127
  # Remove the geo-level coordinate.
3990
4128
  dataset = dataset.sel(geo_granularity=[constants.NATIONAL])
3991
4129
  return dataset
@@ -4023,14 +4161,16 @@ class Analyzer:
4023
4161
  ) -> np.ndarray:
4024
4162
  """Filters the holdout_id array for selected times and geos."""
4025
4163
 
4026
- if selected_geos is not None and not self._meridian.is_national:
4027
- geo_mask = [x in selected_geos for x in self._meridian.input_data.geo]
4164
+ if selected_geos is not None and not self.model_context.is_national:
4165
+ geo_mask = [x in selected_geos for x in self.model_context.input_data.geo]
4028
4166
  holdout_id = holdout_id[geo_mask]
4029
4167
 
4030
4168
  if selected_times is not None:
4031
- time_mask = [x in selected_times for x in self._meridian.input_data.time]
4169
+ time_mask = [
4170
+ x in selected_times for x in self.model_context.input_data.time
4171
+ ]
4032
4172
  # If model is national, holdout_id will have only 1 dimension.
4033
- if self._meridian.is_national:
4173
+ if self.model_context.is_national:
4034
4174
  holdout_id = holdout_id[time_mask]
4035
4175
  else:
4036
4176
  holdout_id = holdout_id[:, time_mask]
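Filtering the holdout mask down to the requested geos and times is plain boolean indexing on the `(n_geos, n_times)` array (or `(n_times,)` for a national model). A NumPy sketch of the geo-level case with illustrative labels:

import numpy as np

geos = np.array(["geo_1", "geo_2", "geo_3"])
times = np.array(["2024-01-01", "2024-01-08"])
holdout_id = np.array([[True, False], [False, False], [True, True]])  # (n_geos, n_times)

selected_geos = ["geo_1", "geo_3"]
selected_times = ["2024-01-08"]
geo_mask = [g in selected_geos for g in geos]
time_mask = [t in selected_times for t in times]
filtered = holdout_id[geo_mask][:, time_mask]  # rows for selected geos, columns for selected times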
@@ -4048,7 +4188,7 @@ class Analyzer:
4048
4188
  NotFittedModelError: If self.sample_posterior() is not called before
4049
4189
  calling this method.
4050
4190
  """
4051
- if constants.POSTERIOR not in self._meridian.inference_data.groups():
4191
+ if constants.POSTERIOR not in self._inference_data.groups():
4052
4192
  raise model.NotFittedModelError(
4053
4193
  "sample_posterior() must be called prior to calling this method."
4054
4194
  )
@@ -4059,11 +4199,10 @@ class Analyzer:
4059
4199
  perm = [1, 0] + list(range(2, n_dim))
4060
4200
  return backend.transpose(x_tensor, perm)
4061
4201
 
4062
- rhat = backend.mcmc.potential_scale_reduction({
4202
+ return backend.mcmc.potential_scale_reduction({
4063
4203
  k: _transpose_first_two_dims(v)
4064
- for k, v in self._meridian.inference_data.posterior.data_vars.items()
4204
+ for k, v in self._inference_data.posterior.data_vars.items()
4065
4205
  })
4066
- return rhat
4067
4206
 
4068
4207
  def rhat_summary(self, bad_rhat_threshold: float = 1.2) -> pd.DataFrame:
4069
4208
  """Computes a summary of the R-hat values for each parameter in the model.
@@ -4110,8 +4249,15 @@ class Analyzer:
4110
4249
 
4111
4250
  rhat_summary = []
4112
4251
  for param in rhat:
4252
+ # `tau_g` and `tau_g_excl_baseline` are the only parameters that have
4253
+ # inconsistent names in the prior and the posterior. Here, we ensure that
4254
+ # the `has_deterministic_param` takes the right parameter name.
4255
+ param_name = (
4256
+ constants.TAU_G_EXCL_BASELINE if param == constants.TAU_G else param
4257
+ )
4258
+
4113
4259
  # Skip if parameter is deterministic according to the prior.
4114
- if self._meridian.prior_broadcast.has_deterministic_param(param):
4260
+ if self.model_context.prior_broadcast.has_deterministic_param(param_name):
4115
4261
  continue
4116
4262
 
4117
4263
  if rhat[param].ndim == 2:
@@ -4186,8 +4332,8 @@ class Analyzer:
4186
4332
  selected_times: Optional list containing a subset of dates to include. If
4187
4333
  `new_data` is provided with modified time periods, then `selected_times`
4188
4334
  must be a subset of `new_data.times`. Otherwise, `selected_times` must
4189
- be a subset of `self._meridian.input_data.time`. By default, all time
4190
- periods are included.
4335
+ be a subset of `self._model_context.input_data.time`. By default, all
4336
+ time periods are included.
4191
4337
  by_reach: Boolean. For channels with reach and frequency. If `True`, plots
4192
4338
  the response curve by reach. If `False`, plots the response curve by
4193
4339
  frequency.
@@ -4206,7 +4352,7 @@ class Analyzer:
4206
4352
  An `xarray.Dataset` containing the data needed to visualize response
4207
4353
  curves.
4208
4354
  """
4209
- if self._meridian.is_national:
4355
+ if self.model_context.is_national:
4210
4356
  _warn_if_geo_arg_in_kwargs(
4211
4357
  selected_geos=selected_geos,
4212
4358
  )
@@ -4218,20 +4364,22 @@ class Analyzer:
4218
4364
  }
4219
4365
  if new_data is None:
4220
4366
  new_data = DataTensors()
4221
- # TODO: b/442920356 - Support flexible time without providing exact dates.
4367
+ # TODO: Support flexible time without providing exact dates.
4222
4368
  required_tensors_names = constants.PERFORMANCE_DATA + (constants.TIME,)
4223
4369
  filled_data = new_data.validate_and_fill_missing_data(
4224
4370
  required_tensors_names=required_tensors_names,
4225
- meridian=self._meridian,
4371
+ model_context=self.model_context,
4226
4372
  allow_modified_times=True,
4227
4373
  )
4228
- new_n_media_times = filled_data.get_modified_times(self._meridian)
4374
+ new_n_media_times = filled_data.get_modified_times(
4375
+ model_context=self.model_context
4376
+ )
4229
4377
 
4230
4378
  if new_n_media_times is None:
4231
4379
  _validate_selected_times(
4232
4380
  selected_times=selected_times,
4233
- input_times=self._meridian.input_data.time,
4234
- n_times=self._meridian.n_times,
4381
+ input_times=self.model_context.input_data.time,
4382
+ n_times=self.model_context.n_times,
4235
4383
  arg_name="selected_times",
4236
4384
  comparison_arg_name="the input data",
4237
4385
  )
@@ -4243,12 +4391,12 @@ class Analyzer:
4243
4391
  new_n_media_times=new_n_media_times,
4244
4392
  new_time=new_time,
4245
4393
  )
4246
- # TODO: b/407847021 - Switch to Sequence[str] once it is supported.
4394
+ # TODO: Switch to Sequence[str] once it is supported.
4247
4395
  if selected_times is not None:
4248
4396
  selected_times = [x in selected_times for x in new_time]
4249
4397
  dim_kwargs["selected_times"] = selected_times
4250
4398
 
4251
- if self._meridian.n_rf_channels > 0 and use_optimal_frequency:
4399
+ if self.model_context.n_rf_channels > 0 and use_optimal_frequency:
4252
4400
  opt_freq_data = DataTensors(
4253
4401
  media=filled_data.media,
4254
4402
  rf_impressions=filled_data.reach * filled_data.frequency,
@@ -4265,7 +4413,7 @@ class Analyzer:
4265
4413
  ).optimal_frequency,
4266
4414
  dtype=backend.float32,
4267
4415
  )
4268
- reach = backend.divide_no_nan(
4416
+ reach = backend.divide(
4269
4417
  filled_data.reach * filled_data.frequency,
4270
4418
  frequency,
4271
4419
  )
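Note the switch from `divide_no_nan` to `divide` when recovering reach from impressions and the optimal frequency. Assuming the backend ops follow the TensorFlow semantics these names suggest, `divide_no_nan` silently returns 0 wherever the denominator is 0, while a plain division lets the resulting inf/NaN surface. A NumPy illustration of the two behaviors:

import numpy as np

impressions = np.array([600.0, 300.0, 0.0])
frequency = np.array([3.0, 0.0, 2.0])

plain = impressions / frequency   # [200., inf, 0.], with a divide-by-zero warning
safe = np.divide(
    impressions, frequency, out=np.zeros_like(impressions), where=frequency != 0
)                                 # [200., 0., 0.], the "no_nan" behavior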
@@ -4276,13 +4424,13 @@ class Analyzer:
4276
4424
  spend_multipliers = list(np.arange(0, 2.2, 0.2))
4277
4425
  incremental_outcome = np.zeros((
4278
4426
  len(spend_multipliers),
4279
- len(self._meridian.input_data.get_all_paid_channels()),
4427
+ len(self.model_context.input_data.get_all_paid_channels()),
4280
4428
  3,
4281
4429
  ))
4282
4430
  for i, multiplier in enumerate(spend_multipliers):
4283
4431
  if multiplier == 0:
4284
4432
  incremental_outcome[i, :, :] = backend.zeros(
4285
- (len(self._meridian.input_data.get_all_paid_channels()), 3)
4433
+ (len(self.model_context.input_data.get_all_paid_channels()), 3)
4286
4434
  ) # Last dimension = 3 for the mean, ci_lo and ci_hi.
4287
4435
  continue
4288
4436
  scaled_data = _scale_tensors_by_multiplier(
@@ -4317,7 +4465,9 @@ class Analyzer:
4317
4465
  )
4318
4466
  spend_einsum = backend.einsum("k,m->km", np.array(spend_multipliers), spend)
4319
4467
  xr_coords = {
4320
- constants.CHANNEL: self._meridian.input_data.get_all_paid_channels(),
4468
+ constants.CHANNEL: (
4469
+ self.model_context.input_data.get_all_paid_channels()
4470
+ ),
4321
4471
  constants.METRIC: [
4322
4472
  constants.MEAN,
4323
4473
  constants.CI_LO,
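The response-curve spend grid is an outer product of the spend multipliers and the historical per-channel spend, written with `einsum("k,m->km", ...)`. The equivalent NumPy computation:

import numpy as np

spend_multipliers = np.arange(0, 2.2, 0.2)   # same grid as above
spend = np.array([100.0, 40.0, 60.0])        # illustrative per-channel spend
spend_grid = np.einsum("k,m->km", spend_multipliers, spend)
# spend_grid[k, m] == spend_multipliers[k] * spend[m]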
@@ -4352,8 +4502,8 @@ class Analyzer:
4352
4502
  `ci_hi`, `ci_lo`, and `mean` for the Adstock function.
4353
4503
  """
4354
4504
  if (
4355
- constants.PRIOR not in self._meridian.inference_data.groups()
4356
- or constants.POSTERIOR not in self._meridian.inference_data.groups()
4505
+ constants.PRIOR not in self._inference_data.groups()
4506
+ or constants.POSTERIOR not in self._inference_data.groups()
4357
4507
  ):
4358
4508
  raise model.NotFittedModelError(
4359
4509
  "sample_prior() and sample_posterior() must be called prior to"
@@ -4362,7 +4512,7 @@ class Analyzer:
4362
4512
 
4363
4513
  # Choose a step_size such that time_unit has consecutive integers defined
4364
4514
  # throughout.
4365
- max_lag = max(self._meridian.model_spec.max_lag, 1)
4515
+ max_lag = max(self.model_context.model_spec.max_lag, 1)
4366
4516
  steps_per_time_period_max_lag = (
4367
4517
  constants.ADSTOCK_DECAY_MAX_TOTAL_STEPS // max_lag
4368
4518
  )
@@ -4406,23 +4556,23 @@ class Analyzer:
4406
4556
  final_df_list.append(adstock_df)
4407
4557
 
4408
4558
  _add_adstock_decay_for_channel(
4409
- self._meridian.n_media_channels,
4410
- self._meridian.input_data.media_channel,
4559
+ self.model_context.n_media_channels,
4560
+ self.model_context.input_data.media_channel,
4411
4561
  constants.MEDIA,
4412
4562
  )
4413
4563
  _add_adstock_decay_for_channel(
4414
- self._meridian.n_rf_channels,
4415
- self._meridian.input_data.rf_channel,
4564
+ self.model_context.n_rf_channels,
4565
+ self.model_context.input_data.rf_channel,
4416
4566
  constants.RF,
4417
4567
  )
4418
4568
  _add_adstock_decay_for_channel(
4419
- self._meridian.n_organic_media_channels,
4420
- self._meridian.input_data.organic_media_channel,
4569
+ self.model_context.n_organic_media_channels,
4570
+ self.model_context.input_data.organic_media_channel,
4421
4571
  constants.ORGANIC_MEDIA,
4422
4572
  )
4423
4573
  _add_adstock_decay_for_channel(
4424
- self._meridian.n_organic_rf_channels,
4425
- self._meridian.input_data.organic_rf_channel,
4574
+ self.model_context.n_organic_rf_channels,
4575
+ self.model_context.input_data.organic_rf_channel,
4426
4576
  constants.ORGANIC_RF,
4427
4577
  )
4428
4578
 
@@ -4468,50 +4618,52 @@ class Analyzer:
4468
4618
  """
4469
4619
  if (
4470
4620
  channel_type == constants.MEDIA
4471
- and self._meridian.input_data.media_channel is not None
4621
+ and self.model_context.input_data.media_channel is not None
4472
4622
  ):
4473
4623
  ec = constants.EC_M
4474
4624
  slope = constants.SLOPE_M
4475
- channels = self._meridian.input_data.media_channel.values
4476
- transformer = self._meridian.media_tensors.media_transformer
4625
+ channels = self.model_context.input_data.media_channel.values
4626
+ transformer = self.model_context.media_tensors.media_transformer
4477
4627
  linspace_max_values = np.max(
4478
- np.array(self._meridian.media_tensors.media_scaled), axis=(0, 1)
4628
+ np.array(self.model_context.media_tensors.media_scaled), axis=(0, 1)
4479
4629
  )
4480
4630
  elif (
4481
4631
  channel_type == constants.RF
4482
- and self._meridian.input_data.rf_channel is not None
4632
+ and self.model_context.input_data.rf_channel is not None
4483
4633
  ):
4484
4634
  ec = constants.EC_RF
4485
4635
  slope = constants.SLOPE_RF
4486
- channels = self._meridian.input_data.rf_channel.values
4636
+ channels = self.model_context.input_data.rf_channel.values
4487
4637
  transformer = None
4488
4638
  linspace_max_values = np.max(
4489
- np.array(self._meridian.rf_tensors.frequency), axis=(0, 1)
4639
+ np.array(self.model_context.rf_tensors.frequency), axis=(0, 1)
4490
4640
  )
4491
4641
  elif (
4492
4642
  channel_type == constants.ORGANIC_MEDIA
4493
- and self._meridian.input_data.organic_media_channel is not None
4643
+ and self.model_context.input_data.organic_media_channel is not None
4494
4644
  ):
4495
4645
  ec = constants.EC_OM
4496
4646
  slope = constants.SLOPE_OM
4497
- channels = self._meridian.input_data.organic_media_channel.values
4647
+ channels = self.model_context.input_data.organic_media_channel.values
4498
4648
  transformer = (
4499
- self._meridian.organic_media_tensors.organic_media_transformer
4649
+ self.model_context.organic_media_tensors.organic_media_transformer
4500
4650
  )
4501
4651
  linspace_max_values = np.max(
4502
- np.array(self._meridian.organic_media_tensors.organic_media_scaled),
4652
+ np.array(
4653
+ self.model_context.organic_media_tensors.organic_media_scaled
4654
+ ),
4503
4655
  axis=(0, 1),
4504
4656
  )
4505
4657
  elif (
4506
4658
  channel_type == constants.ORGANIC_RF
4507
- and self._meridian.input_data.organic_rf_channel is not None
4659
+ and self.model_context.input_data.organic_rf_channel is not None
4508
4660
  ):
4509
4661
  ec = constants.EC_ORF
4510
4662
  slope = constants.SLOPE_ORF
4511
- channels = self._meridian.input_data.organic_rf_channel.values
4663
+ channels = self.model_context.input_data.organic_rf_channel.values
4512
4664
  transformer = None
4513
4665
  linspace_max_values = np.max(
4514
- np.array(self._meridian.organic_rf_tensors.organic_frequency),
4666
+ np.array(self.model_context.organic_rf_tensors.organic_frequency),
4515
4667
  axis=(0, 1),
4516
4668
  )
4517
4669
  else:
@@ -4546,12 +4698,12 @@ class Analyzer:
4546
4698
  # n_draws, n_geos, n_times, n_channels), and we want to plot the
4547
4699
  # dependency on time only.
4548
4700
  hill_vals_prior = adstock_hill.HillTransformer(
4549
- self._meridian.inference_data.prior[ec].values,
4550
- self._meridian.inference_data.prior[slope].values,
4701
+ self._inference_data.prior[ec].values,
4702
+ self._inference_data.prior[slope].values,
4551
4703
  ).forward(expanded_linspace)[:, :, 0, :, :]
4552
4704
  hill_vals_posterior = adstock_hill.HillTransformer(
4553
- self._meridian.inference_data.posterior[ec].values,
4554
- self._meridian.inference_data.posterior[slope].values,
4705
+ self._inference_data.posterior[ec].values,
4706
+ self._inference_data.posterior[slope].values,
4555
4707
  ).forward(expanded_linspace)[:, :, 0, :, :]
4556
4708
 
4557
4709
  hill_dataset = _central_tendency_and_ci_by_prior_and_posterior(
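The Hill curves are traced by evaluating every prior and posterior draw of `ec` and `slope` over a per-channel linspace of media units (or frequencies), then summarizing across draws. As a point of reference, and assuming the usual Hill saturation form used in media mix models, 1 / (1 + (x / ec)^(-slope)), a small sketch of evaluating such a curve over a grid:

import numpy as np

def hill(x: np.ndarray, ec: float, slope: float) -> np.ndarray:
  # Saturation curve in [0, 1): half-saturation at x == ec, steepness set by slope.
  return 1.0 / (1.0 + (x / ec) ** (-slope))

media_units = np.linspace(1e-6, 3.0, 100)   # start above 0 to keep the power well-defined
draws_ec = np.array([0.8, 1.0, 1.2])        # e.g. a few posterior draws of ec
draws_slope = np.array([1.5, 2.0, 2.5])

curves = np.stack([hill(media_units, e, s) for e, s in zip(draws_ec, draws_slope)])
mean_curve = curves.mean(axis=0)            # central tendency across draws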
@@ -4698,77 +4850,81 @@ class Analyzer:
4698
4850
  distribution of media units per capita for media channels or average
4699
4851
  frequency for RF channels over weeks and geos for the Hill plots.
4700
4852
  """
4701
- n_geos = self._meridian.n_geos
4702
- n_media_times = self._meridian.n_media_times
4853
+ n_geos = self.model_context.n_geos
4854
+ n_media_times = self.model_context.n_media_times
4703
4855
 
4704
4856
  df_list = []
4705
4857
 
4706
4858
  # RF.
4707
- if self._meridian.input_data.rf_channel is not None:
4708
- frequency = self._meridian.rf_tensors.frequency
4859
+ if self.model_context.input_data.rf_channel is not None:
4860
+ frequency = self.model_context.rf_tensors.frequency
4709
4861
  if frequency is not None:
4710
4862
  reshaped_frequency = backend.reshape(
4711
- frequency, (n_geos * n_media_times, self._meridian.n_rf_channels)
4863
+ frequency,
4864
+ (n_geos * n_media_times, self.model_context.n_rf_channels),
4712
4865
  )
4713
4866
  rf_hist_data = self._get_channel_hill_histogram_dataframe(
4714
4867
  channel_type=constants.RF,
4715
4868
  data_to_histogram=reshaped_frequency,
4716
- channel_names=self._meridian.input_data.rf_channel.values,
4869
+ channel_names=self.model_context.input_data.rf_channel.values,
4717
4870
  n_bins=n_bins,
4718
4871
  )
4719
4872
  df_list.append(pd.DataFrame(rf_hist_data))
4720
4873
 
4721
4874
  # Media.
4722
- if self._meridian.input_data.media_channel is not None:
4723
- transformer = self._meridian.media_tensors.media_transformer
4724
- scaled = self._meridian.media_tensors.media_scaled
4875
+ if self.model_context.input_data.media_channel is not None:
4876
+ transformer = self.model_context.media_tensors.media_transformer
4877
+ scaled = self.model_context.media_tensors.media_scaled
4725
4878
  if transformer is not None and scaled is not None:
4726
4879
  population_scaled_median = transformer.population_scaled_median_m
4727
4880
  scaled_media_units = scaled * population_scaled_median
4728
4881
  reshaped_scaled_media_units = backend.reshape(
4729
4882
  scaled_media_units,
4730
- (n_geos * n_media_times, self._meridian.n_media_channels),
4883
+ (n_geos * n_media_times, self.model_context.n_media_channels),
4731
4884
  )
4732
4885
  media_hist_data = self._get_channel_hill_histogram_dataframe(
4733
4886
  channel_type=constants.MEDIA,
4734
4887
  data_to_histogram=reshaped_scaled_media_units,
4735
- channel_names=self._meridian.input_data.media_channel.values,
4888
+ channel_names=self.model_context.input_data.media_channel.values,
4736
4889
  n_bins=n_bins,
4737
4890
  )
4738
4891
  df_list.append(pd.DataFrame(media_hist_data))
4739
4892
  # Organic media.
4740
- if self._meridian.input_data.organic_media_channel is not None:
4893
+ if self.model_context.input_data.organic_media_channel is not None:
4741
4894
  transformer_om = (
4742
- self._meridian.organic_media_tensors.organic_media_transformer
4895
+ self.model_context.organic_media_tensors.organic_media_transformer
4743
4896
  )
4744
- scaled_om = self._meridian.organic_media_tensors.organic_media_scaled
4897
+ scaled_om = self.model_context.organic_media_tensors.organic_media_scaled
4745
4898
  if transformer_om is not None and scaled_om is not None:
4746
4899
  population_scaled_median_om = transformer_om.population_scaled_median_m
4747
4900
  scaled_organic_media_units = scaled_om * population_scaled_median_om
4748
4901
  reshaped_scaled_organic_media_units = backend.reshape(
4749
4902
  scaled_organic_media_units,
4750
- (n_geos * n_media_times, self._meridian.n_organic_media_channels),
4903
+ (
4904
+ n_geos * n_media_times,
4905
+ self.model_context.n_organic_media_channels,
4906
+ ),
4751
4907
  )
4752
4908
  organic_media_hist_data = self._get_channel_hill_histogram_dataframe(
4753
4909
  channel_type=constants.ORGANIC_MEDIA,
4754
4910
  data_to_histogram=reshaped_scaled_organic_media_units,
4755
- channel_names=self._meridian.input_data.organic_media_channel.values,
4911
+ channel_names=self.model_context.input_data.organic_media_channel.values,
4756
4912
  n_bins=n_bins,
4757
4913
  )
4758
4914
  df_list.append(pd.DataFrame(organic_media_hist_data))
4759
4915
 
4760
4916
  # Organic RF.
4761
- if self._meridian.input_data.organic_rf_channel is not None:
4762
- frequency = self._meridian.organic_rf_tensors.organic_frequency
4917
+ if self.model_context.input_data.organic_rf_channel is not None:
4918
+ frequency = self.model_context.organic_rf_tensors.organic_frequency
4763
4919
  if frequency is not None:
4764
4920
  reshaped_frequency = backend.reshape(
4765
4921
  frequency,
4766
- (n_geos * n_media_times, self._meridian.n_organic_rf_channels),
4922
+ (n_geos * n_media_times, self.model_context.n_organic_rf_channels),
4767
4923
  )
4768
4924
  organic_rf_hist_data = self._get_channel_hill_histogram_dataframe(
4769
4925
  channel_type=constants.ORGANIC_RF,
4770
4926
  data_to_histogram=reshaped_frequency,
4771
- channel_names=self._meridian.input_data.organic_rf_channel.values,
4927
+ channel_names=self.model_context.input_data.organic_rf_channel.values,
4772
4928
  n_bins=n_bins,
4773
4929
  )
4774
4930
  df_list.append(pd.DataFrame(organic_rf_hist_data))
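For the Hill histograms, each `(n_geos, n_media_times, n_channels)` tensor is flattened to `(n_geos * n_media_times, n_channels)` so every channel's column can be binned independently. A NumPy sketch of that reshape plus binning for one channel:

import numpy as np

n_geos, n_media_times, n_channels = 3, 10, 2
frequency = np.random.default_rng(1).uniform(1.0, 5.0, size=(n_geos, n_media_times, n_channels))

flat = frequency.reshape(n_geos * n_media_times, n_channels)
counts, bin_edges = np.histogram(flat[:, 0], bins=25)   # histogram for the first channel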
@@ -4811,8 +4967,8 @@ class Analyzer:
4811
4967
  for a histogram bin.
4812
4968
  """
4813
4969
  if (
4814
- constants.PRIOR not in self._meridian.inference_data.groups()
4815
- or constants.POSTERIOR not in self._meridian.inference_data.groups()
4970
+ constants.PRIOR not in self._inference_data.groups()
4971
+ or constants.POSTERIOR not in self._inference_data.groups()
4816
4972
  ):
4817
4973
  raise model.NotFittedModelError(
4818
4974
  "sample_prior() and sample_posterior() must be called prior to"
@@ -4821,10 +4977,10 @@ class Analyzer:
4821
4977
 
4822
4978
  final_dfs = [pd.DataFrame()]
4823
4979
  for n_channels, channel_type in [
4824
- (self._meridian.n_media_channels, constants.MEDIA),
4825
- (self._meridian.n_rf_channels, constants.RF),
4826
- (self._meridian.n_organic_media_channels, constants.ORGANIC_MEDIA),
4827
- (self._meridian.n_organic_rf_channels, constants.ORGANIC_RF),
4980
+ (self.model_context.n_media_channels, constants.MEDIA),
4981
+ (self.model_context.n_rf_channels, constants.RF),
4982
+ (self.model_context.n_organic_media_channels, constants.ORGANIC_MEDIA),
4983
+ (self.model_context.n_organic_rf_channels, constants.ORGANIC_RF),
4828
4984
  ]:
4829
4985
  if n_channels > 0:
4830
4986
  hill_df = self._get_hill_curves_dataframe(
@@ -4967,6 +5123,7 @@ class Analyzer:
4967
5123
  include_median=True,
4968
5124
  )
4969
5125
 
5126
+ # TODO: Remove this method.
4970
5127
  def get_historical_spend(
4971
5128
  self,
4972
5129
  selected_times: Sequence[str] | None = None,
@@ -5048,7 +5205,8 @@ class Analyzer:
5048
5205
  new_data = new_data or DataTensors()
5049
5206
  required_tensors_names = constants.PAID_CHANNELS + constants.SPEND_DATA
5050
5207
  filled_data = new_data.validate_and_fill_missing_data(
5051
- required_tensors_names, self._meridian
5208
+ required_tensors_names=required_tensors_names,
5209
+ model_context=self.model_context,
5052
5210
  )
5053
5211
 
5054
5212
  empty_da = xr.DataArray(
@@ -5058,9 +5216,9 @@ class Analyzer:
5058
5216
  if not include_media:
5059
5217
  aggregated_media_spend = empty_da
5060
5218
  elif (
5061
- self._meridian.media_tensors.media is None
5062
- or self._meridian.media_tensors.media_spend is None
5063
- or self._meridian.input_data.media_channel is None
5219
+ self.model_context.media_tensors.media is None
5220
+ or self.model_context.media_tensors.media_spend is None
5221
+ or self.model_context.input_data.media_channel is None
5064
5222
  ):
5065
5223
  warnings.warn(
5066
5224
  "Requested spends for paid media channels that do not have R&F"
@@ -5073,16 +5231,18 @@ class Analyzer:
5073
5231
  selected_times=selected_times,
5074
5232
  media_execution_values=filled_data.media,
5075
5233
  channel_spend=filled_data.media_spend,
5076
- channel_names=list(self._meridian.input_data.media_channel.values),
5234
+ channel_names=list(
5235
+ self.model_context.input_data.media_channel.values
5236
+ ),
5077
5237
  )
5078
5238
 
5079
5239
  if not include_rf:
5080
5240
  aggregated_rf_spend = empty_da
5081
5241
  elif (
5082
- self._meridian.input_data.rf_channel is None
5083
- or self._meridian.rf_tensors.reach is None
5084
- or self._meridian.rf_tensors.frequency is None
5085
- or self._meridian.rf_tensors.rf_spend is None
5242
+ self.model_context.input_data.rf_channel is None
5243
+ or self.model_context.rf_tensors.reach is None
5244
+ or self.model_context.rf_tensors.frequency is None
5245
+ or self.model_context.rf_tensors.rf_spend is None
5086
5246
  ):
5087
5247
  warnings.warn(
5088
5248
  "Requested spends for paid media channels with R&F data, but the"
@@ -5096,7 +5256,7 @@ class Analyzer:
5096
5256
  selected_times=selected_times,
5097
5257
  media_execution_values=rf_execution_values,
5098
5258
  channel_spend=filled_data.rf_spend,
5099
- channel_names=list(self._meridian.input_data.rf_channel.values),
5259
+ channel_names=list(self.model_context.input_data.rf_channel.values),
5100
5260
  )
5101
5261
 
5102
5262
  return xr.concat(
@@ -5157,9 +5317,9 @@ class Analyzer:
5157
5317
  # channel_spend.ndim can only be 3 or 1.
5158
5318
  else:
5159
5319
  # media spend can have more time points than the model time points
5160
- if media_execution_values.shape[1] == self._meridian.n_media_times:
5320
+ if media_execution_values.shape[1] == self.model_context.n_media_times:
5161
5321
  media_exe_values = media_execution_values[
5162
- :, -self._meridian.n_times :, :
5322
+ :, -self.model_context.n_times :, :
5163
5323
  ]
5164
5324
  else:
5165
5325
  media_exe_values = media_execution_values
@@ -5169,7 +5329,7 @@ class Analyzer:
5169
5329
  media_exe_values,
5170
5330
  **dim_kwargs,
5171
5331
  )
5172
- imputed_cpmu = backend.divide_no_nan(
5332
+ imputed_cpmu = backend.divide(
5173
5333
  channel_spend,
5174
5334
  np.sum(media_exe_values, (0, 1)),
5175
5335
  )