google-meridian 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/METADATA +14 -11
  2. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/RECORD +47 -43
  3. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
  4. meridian/analysis/analyzer.py +558 -398
  5. meridian/analysis/optimizer.py +90 -68
  6. meridian/analysis/review/reviewer.py +4 -1
  7. meridian/analysis/summarizer.py +6 -1
  8. meridian/analysis/test_utils.py +2898 -2538
  9. meridian/analysis/visualizer.py +28 -9
  10. meridian/backend/__init__.py +106 -0
  11. meridian/constants.py +1 -0
  12. meridian/data/input_data.py +30 -52
  13. meridian/data/input_data_builder.py +2 -9
  14. meridian/data/test_utils.py +25 -41
  15. meridian/data/validator.py +48 -0
  16. meridian/mlflow/autolog.py +19 -9
  17. meridian/model/adstock_hill.py +3 -5
  18. meridian/model/context.py +134 -0
  19. meridian/model/eda/constants.py +334 -4
  20. meridian/model/eda/eda_engine.py +723 -312
  21. meridian/model/eda/eda_outcome.py +177 -33
  22. meridian/model/model.py +159 -110
  23. meridian/model/model_test_data.py +38 -0
  24. meridian/model/posterior_sampler.py +103 -62
  25. meridian/model/prior_sampler.py +114 -94
  26. meridian/model/spec.py +23 -14
  27. meridian/templates/card.html.jinja +9 -7
  28. meridian/templates/chart.html.jinja +1 -6
  29. meridian/templates/finding.html.jinja +19 -0
  30. meridian/templates/findings.html.jinja +33 -0
  31. meridian/templates/formatter.py +41 -5
  32. meridian/templates/formatter_test.py +127 -0
  33. meridian/templates/style.css +66 -9
  34. meridian/templates/style.scss +85 -4
  35. meridian/templates/table.html.jinja +1 -0
  36. meridian/version.py +1 -1
  37. scenarioplanner/linkingapi/constants.py +1 -1
  38. scenarioplanner/mmm_ui_proto_generator.py +1 -0
  39. schema/processors/marketing_processor.py +11 -10
  40. schema/processors/model_processor.py +4 -1
  41. schema/serde/distribution.py +12 -7
  42. schema/serde/hyperparameters.py +54 -107
  43. schema/serde/meridian_serde.py +6 -1
  44. schema/utils/__init__.py +1 -0
  45. schema/utils/proto_enum_converter.py +127 -0
  46. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
  47. {google_meridian-1.4.0.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +0 -0
meridian/model/model.py CHANGED
@@ -16,6 +16,7 @@
16
16
 
17
17
  import collections
18
18
  from collections.abc import Mapping, Sequence
19
+ import dataclasses
19
20
  import functools
20
21
  import os
21
22
  import warnings
@@ -84,7 +85,7 @@ class Meridian:
84
85
  input_data: An `InputData` object containing the input data for the model.
85
86
  model_spec: A `ModelSpec` object containing the model specification.
86
87
  model_context: A `ModelContext` object containing the model context.
87
- equations: A `ModelEquations` object containing stateless mathematical
88
+ model_equations: A `ModelEquations` object containing stateless mathematical
88
89
  functions and utilities for Meridian MMM.
89
90
  inference_data: A _mutable_ `arviz.InferenceData` object containing the
90
91
  resulting data from fitting the model.
@@ -164,7 +165,7 @@ class Meridian:
164
165
  input_data=input_data,
165
166
  model_spec=model_spec if model_spec else spec.ModelSpec(),
166
167
  )
167
- self._equations = equations.ModelEquations(self._model_context)
168
+ self._model_equations = equations.ModelEquations(self._model_context)
168
169
 
169
170
  self._eda_spec = eda_spec
170
171
 
@@ -190,8 +191,8 @@ class Meridian:
190
191
  return self._model_context
191
192
 
192
193
  @property
193
- def equations(self) -> equations.ModelEquations:
194
- return self._equations
194
+ def model_equations(self) -> equations.ModelEquations:
195
+ return self._model_equations
195
196
 
196
197
  @property
197
198
  def inference_data(self) -> az.InferenceData:
@@ -199,57 +200,59 @@ class Meridian:
199
200
 
200
201
  @functools.cached_property
201
202
  def eda_engine(self) -> eda_engine.EDAEngine:
202
- return eda_engine.EDAEngine(self, spec=self._eda_spec)
203
+ return eda_engine.EDAEngine(
204
+ spec=self._eda_spec, model_context=self.model_context
205
+ )
203
206
 
204
207
  @property
205
208
  def eda_spec(self) -> eda_spec_module.EDASpec:
206
209
  return self._eda_spec
207
210
 
208
211
  @property
209
- def eda_outcomes(self) -> Sequence[eda_outcome.EDAOutcome]:
212
+ def eda_outcomes(self) -> eda_outcome.CriticalCheckEDAOutcomes:
210
213
  return self.eda_engine.run_all_critical_checks()
211
214
 
212
- @functools.cached_property
215
+ @property
213
216
  def media_tensors(self) -> media.MediaTensors:
214
217
  return self._model_context.media_tensors
215
218
 
216
- @functools.cached_property
219
+ @property
217
220
  def rf_tensors(self) -> media.RfTensors:
218
221
  return self._model_context.rf_tensors
219
222
 
220
- @functools.cached_property
223
+ @property
221
224
  def organic_media_tensors(self) -> media.OrganicMediaTensors:
222
225
  return self._model_context.organic_media_tensors
223
226
 
224
- @functools.cached_property
227
+ @property
225
228
  def organic_rf_tensors(self) -> media.OrganicRfTensors:
226
229
  return self._model_context.organic_rf_tensors
227
230
 
228
- @functools.cached_property
231
+ @property
229
232
  def kpi(self) -> backend.Tensor:
230
233
  return self._model_context.kpi
231
234
 
232
- @functools.cached_property
235
+ @property
233
236
  def revenue_per_kpi(self) -> backend.Tensor | None:
234
237
  return self._model_context.revenue_per_kpi
235
238
 
236
- @functools.cached_property
239
+ @property
237
240
  def controls(self) -> backend.Tensor | None:
238
241
  return self._model_context.controls
239
242
 
240
- @functools.cached_property
243
+ @property
241
244
  def non_media_treatments(self) -> backend.Tensor | None:
242
245
  return self._model_context.non_media_treatments
243
246
 
244
- @functools.cached_property
247
+ @property
245
248
  def population(self) -> backend.Tensor:
246
249
  return self._model_context.population
247
250
 
248
- @functools.cached_property
251
+ @property
249
252
  def total_spend(self) -> backend.Tensor:
250
253
  return self._model_context.total_spend
251
254
 
252
- @functools.cached_property
255
+ @property
253
256
  def total_outcome(self) -> backend.Tensor:
254
257
  return self._model_context.total_outcome
255
258
 
@@ -293,31 +296,31 @@ class Meridian:
293
296
  def is_national(self) -> bool:
294
297
  return self._model_context.is_national
295
298
 
296
- @functools.cached_property
299
+ @property
297
300
  def knot_info(self) -> knots.KnotInfo:
298
301
  return self._model_context.knot_info
299
302
 
300
- @functools.cached_property
303
+ @property
301
304
  def controls_transformer(
302
305
  self,
303
306
  ) -> transformers.CenteringAndScalingTransformer | None:
304
307
  return self._model_context.controls_transformer
305
308
 
306
- @functools.cached_property
309
+ @property
307
310
  def non_media_transformer(
308
311
  self,
309
312
  ) -> transformers.CenteringAndScalingTransformer | None:
310
313
  return self._model_context.non_media_transformer
311
314
 
312
- @functools.cached_property
315
+ @property
313
316
  def kpi_transformer(self) -> transformers.KpiTransformer:
314
317
  return self._model_context.kpi_transformer
315
318
 
316
- @functools.cached_property
319
+ @property
317
320
  def controls_scaled(self) -> backend.Tensor | None:
318
321
  return self._model_context.controls_scaled
319
322
 
320
- @functools.cached_property
323
+ @property
321
324
  def non_media_treatments_normalized(self) -> backend.Tensor | None:
322
325
  """Normalized non-media treatments.
323
326
 
@@ -327,33 +330,33 @@ class Meridian:
327
330
  """
328
331
  return self._model_context.non_media_treatments_normalized
329
332
 
330
- @functools.cached_property
333
+ @property
331
334
  def kpi_scaled(self) -> backend.Tensor:
332
335
  return self._model_context.kpi_scaled
333
336
 
334
- @functools.cached_property
337
+ @property
335
338
  def media_effects_dist(self) -> str:
336
339
  return self._model_context.media_effects_dist
337
340
 
338
- @functools.cached_property
341
+ @property
339
342
  def unique_sigma_for_each_geo(self) -> bool:
340
343
  return self._model_context.unique_sigma_for_each_geo
341
344
 
342
- @functools.cached_property
345
+ @property
343
346
  def baseline_geo_idx(self) -> int:
344
347
  """Returns the index of the baseline geo."""
345
348
  return self._model_context.baseline_geo_idx
346
349
 
347
- @functools.cached_property
350
+ @property
348
351
  def holdout_id(self) -> backend.Tensor | None:
349
352
  return self._model_context.holdout_id
350
353
 
351
- @functools.cached_property
354
+ @property
352
355
  def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
353
356
  """Returns `AdstockDecaySpec` object with correctly mapped channels."""
354
357
  return self._model_context.adstock_decay_spec
355
358
 
356
- @functools.cached_property
359
+ @property
357
360
  def prior_broadcast(self) -> prior_distribution.PriorDistribution:
358
361
  """Returns broadcasted `PriorDistribution` object."""
359
362
  return self._model_context.prior_broadcast
@@ -361,17 +364,20 @@ class Meridian:
361
364
  @functools.cached_property
362
365
  def prior_sampler_callable(self) -> prior_sampler.PriorDistributionSampler:
363
366
  """A `PriorDistributionSampler` callable bound to this model."""
364
- return prior_sampler.PriorDistributionSampler(self)
367
+ return prior_sampler.PriorDistributionSampler(
368
+ model_context=self.model_context,
369
+ )
365
370
 
366
371
  @functools.cached_property
367
372
  def posterior_sampler_callable(
368
373
  self,
369
374
  ) -> posterior_sampler.PosteriorMCMCSampler:
370
375
  """A `PosteriorMCMCSampler` callable bound to this model."""
371
- return posterior_sampler.PosteriorMCMCSampler(self)
376
+ return posterior_sampler.PosteriorMCMCSampler(
377
+ model_context=self.model_context,
378
+ )
372
379
 
373
- # TODO: Deprecate this method in favor of the one in
374
- # `equations.py`.
380
+ # TODO: Remove this method.
375
381
  def compute_non_media_treatments_baseline(
376
382
  self,
377
383
  non_media_baseline_values: Sequence[str | float] | None = None,
@@ -394,10 +400,18 @@ class Meridian:
394
400
  A tensor of shape `(n_non_media_channels,)` containing the
395
401
  baseline values for each non-media treatment channel.
396
402
  """
397
- return self.equations.compute_non_media_treatments_baseline(
403
+ warnings.warn(
404
+ "Meridian.compute_non_media_treatments_baseline() is deprecated and"
405
+ " will be removed in a future version. Use"
406
+ " `ModelEquations.compute_non_media_treatments_baseline()` instead.",
407
+ DeprecationWarning,
408
+ stacklevel=2,
409
+ )
410
+ return self.model_equations.compute_non_media_treatments_baseline(
398
411
  non_media_baseline_values=non_media_baseline_values
399
412
  )
400
413
 
414
+ # TODO: Remove this method.
401
415
  def expand_selected_time_dims(
402
416
  self,
403
417
  start_date: tc.Date = None,
@@ -425,12 +439,16 @@ class Meridian:
425
439
  ValueError if `start_date` or `end_date` is not in the input data time
426
440
  dimensions.
427
441
  """
428
- expanded = self.input_data.time_coordinates.expand_selected_time_dims(
442
+ warnings.warn(
443
+ "Meridian.expand_selected_time_dims() is deprecated and will be removed"
444
+ " in a future version. Use `ModelContext.expand_selected_time_dims()`"
445
+ " instead.",
446
+ DeprecationWarning,
447
+ stacklevel=2,
448
+ )
449
+ return self._model_context.expand_selected_time_dims(
429
450
  start_date=start_date, end_date=end_date
430
451
  )
431
- if expanded is None:
432
- return None
433
- return [date.strftime(constants.DATE_FORMAT) for date in expanded]
434
452
 
435
453
  def _validate_injected_inference_data(self):
436
454
  """Validates that the injected inference data has correct shapes.
@@ -598,7 +616,7 @@ class Meridian:
598
616
  f' "{self.model_spec.non_media_treatments_prior_type}".'
599
617
  )
600
618
 
601
- # TODO: Deprecate in favor of ModelEquations.adstock_hill_rf.
619
+ # TODO: Remove this method.
602
620
  def linear_predictor_counterfactual_difference_media(
603
621
  self,
604
622
  media_transformed: backend.Tensor,
@@ -627,15 +645,24 @@ class Meridian:
627
645
  The linear predictor difference between the treatment variable and its
628
646
  counterfactual.
629
647
  """
630
- return self.equations.linear_predictor_counterfactual_difference_media(
631
- media_transformed=media_transformed,
632
- alpha_m=alpha_m,
633
- ec_m=ec_m,
634
- slope_m=slope_m,
648
+ warnings.warn(
649
+ "Meridian.linear_predictor_counterfactual_difference_media() is"
650
+ " deprecated and will be removed in a future version. Use "
651
+ "`ModelEquations.linear_predictor_counterfactual_difference_media()`"
652
+ " instead.",
653
+ DeprecationWarning,
654
+ stacklevel=2,
655
+ )
656
+ return (
657
+ self.model_equations.linear_predictor_counterfactual_difference_media(
658
+ media_transformed=media_transformed,
659
+ alpha_m=alpha_m,
660
+ ec_m=ec_m,
661
+ slope_m=slope_m,
662
+ )
635
663
  )
636
664
 
637
- # TODO: Deprecate in favor of
638
- # ModelEquations.linear_predictor_counterfactual_difference_rf.
665
+ # TODO: Remove this method.
639
666
  def linear_predictor_counterfactual_difference_rf(
640
667
  self,
641
668
  rf_transformed: backend.Tensor,
@@ -664,14 +691,21 @@ class Meridian:
664
691
  The linear predictor difference between the treatment variable and its
665
692
  counterfactual.
666
693
  """
667
- return self.equations.linear_predictor_counterfactual_difference_rf(
694
+ warnings.warn(
695
+ "Meridian.linear_predictor_counterfactual_difference_rf() is deprecated"
696
+ " and will be removed in a future version. Use `ModelEquations."
697
+ "linear_predictor_counterfactual_difference_rf()` instead.",
698
+ DeprecationWarning,
699
+ stacklevel=2,
700
+ )
701
+ return self.model_equations.linear_predictor_counterfactual_difference_rf(
668
702
  rf_transformed=rf_transformed,
669
703
  alpha_rf=alpha_rf,
670
704
  ec_rf=ec_rf,
671
705
  slope_rf=slope_rf,
672
706
  )
673
707
 
674
- # TODO: Deprecate in favor of ModelEquations.calculate_beta_x.
708
+ # TODO: Remove this method.
675
709
  def calculate_beta_x(
676
710
  self,
677
711
  is_non_media: bool,
@@ -717,7 +751,14 @@ class Meridian:
717
751
  The coefficient mean parameter of the treatment variable, which has
718
752
  dimension equal to the number of treatment channels..
719
753
  """
720
- return self.equations.calculate_beta_x(
754
+ warnings.warn(
755
+ "Meridian.calculate_beta_x() is deprecated and will be removed in a"
756
+ " future version. Use `ModelEquations.calculate_beta_x()`"
757
+ " instead.",
758
+ DeprecationWarning,
759
+ stacklevel=2,
760
+ )
761
+ return self.model_equations.calculate_beta_x(
721
762
  is_non_media=is_non_media,
722
763
  incremental_outcome_x=incremental_outcome_x,
723
764
  linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
@@ -725,7 +766,7 @@ class Meridian:
725
766
  beta_gx_dev=beta_gx_dev,
726
767
  )
727
768
 
728
- # TODO: Deprecate in favor of ModelEquations.adstock_hill_media.
769
+ # TODO: Remove this method.
729
770
  def adstock_hill_media(
730
771
  self,
731
772
  media: backend.Tensor, # pylint: disable=redefined-outer-name
@@ -756,7 +797,14 @@ class Meridian:
756
797
  Tensor with dimensions `[..., n_geos, n_times, n_media_channels]`
757
798
  representing Adstock and Hill-transformed media.
758
799
  """
759
- return self.equations.adstock_hill_media(
800
+ warnings.warn(
801
+ "Meridian.adstock_hill_media() is deprecated and will be removed in a"
802
+ " future version. Use `ModelEquations.adstock_hill_media()`"
803
+ " instead.",
804
+ DeprecationWarning,
805
+ stacklevel=2,
806
+ )
807
+ return self.model_equations.adstock_hill_media(
760
808
  media=media,
761
809
  alpha=alpha,
762
810
  ec=ec,
@@ -765,7 +813,7 @@ class Meridian:
765
813
  n_times_output=n_times_output,
766
814
  )
767
815
 
768
- # TODO: Deprecate in favor of ModelEquations.adstock_hill_rf.
816
+ # TODO: Remove this method.
769
817
  def adstock_hill_rf(
770
818
  self,
771
819
  reach: backend.Tensor,
@@ -797,7 +845,14 @@ class Meridian:
797
845
  Tensor with dimensions `[..., n_geos, n_times, n_rf_channels]`
798
846
  representing Hill and Adstock-transformed RF.
799
847
  """
800
- return self.equations.adstock_hill_rf(
848
+ warnings.warn(
849
+ "Meridian.adstock_hill_rf() is deprecated and will be removed in a"
850
+ " future version. Use `ModelEquations.adstock_hill_rf()`"
851
+ " instead.",
852
+ DeprecationWarning,
853
+ stacklevel=2,
854
+ )
855
+ return self.model_equations.adstock_hill_rf(
801
856
  reach=reach,
802
857
  frequency=frequency,
803
858
  alpha=alpha,
@@ -827,66 +882,30 @@ class Meridian:
827
882
  for attr in cached_properties:
828
883
  _ = getattr(self, attr)
829
884
 
885
+ # TODO: Remove this method.
830
886
  def create_inference_data_coords(
831
887
  self, n_chains: int, n_draws: int
832
888
  ) -> Mapping[str, np.ndarray | Sequence[str]]:
833
889
  """Creates data coordinates for inference data."""
834
- media_channel_names = (
835
- self.input_data.media_channel
836
- if self.input_data.media_channel is not None
837
- else np.array([])
838
- )
839
- rf_channel_names = (
840
- self.input_data.rf_channel
841
- if self.input_data.rf_channel is not None
842
- else np.array([])
843
- )
844
- organic_media_channel_names = (
845
- self.input_data.organic_media_channel
846
- if self.input_data.organic_media_channel is not None
847
- else np.array([])
890
+ warnings.warn(
891
+ "Meridian.create_inference_data_coords() is deprecated and will be"
892
+ " removed in a future version. Use"
893
+ " `ModelContext.create_inference_data_coords()` instead.",
894
+ DeprecationWarning,
895
+ stacklevel=2,
848
896
  )
849
- organic_rf_channel_names = (
850
- self.input_data.organic_rf_channel
851
- if self.input_data.organic_rf_channel is not None
852
- else np.array([])
853
- )
854
- non_media_channel_names = (
855
- self.input_data.non_media_channel
856
- if self.input_data.non_media_channel is not None
857
- else np.array([])
858
- )
859
- control_variable_names = (
860
- self.input_data.control_variable
861
- if self.input_data.control_variable is not None
862
- else np.array([])
863
- )
864
- return {
865
- constants.CHAIN: np.arange(n_chains),
866
- constants.DRAW: np.arange(n_draws),
867
- constants.GEO: self.input_data.geo,
868
- constants.TIME: self.input_data.time,
869
- constants.MEDIA_TIME: self.input_data.media_time,
870
- constants.KNOTS: np.arange(self.knot_info.n_knots),
871
- constants.CONTROL_VARIABLE: control_variable_names,
872
- constants.NON_MEDIA_CHANNEL: non_media_channel_names,
873
- constants.MEDIA_CHANNEL: media_channel_names,
874
- constants.RF_CHANNEL: rf_channel_names,
875
- constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
876
- constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
877
- }
897
+ return self._model_context.create_inference_data_coords(n_chains, n_draws)
878
898
 
899
+ # TODO: Remove this method.
879
900
  def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
880
- inference_dims = dict(constants.INFERENCE_DIMS)
881
- if self.unique_sigma_for_each_geo:
882
- inference_dims[constants.SIGMA] = [constants.GEO]
883
- else:
884
- inference_dims[constants.SIGMA] = []
885
-
886
- return {
887
- param: [constants.CHAIN, constants.DRAW] + list(dims)
888
- for param, dims in inference_dims.items()
889
- }
901
+ warnings.warn(
902
+ "Meridian.create_inference_data_dims() is deprecated and will be"
903
+ " removed in a future version. Use"
904
+ " `ModelContext.create_inference_data_dims()` instead.",
905
+ DeprecationWarning,
906
+ stacklevel=2,
907
+ )
908
+ return self._model_context.create_inference_data_dims()
890
909
 
891
910
  def sample_prior(self, n_draws: int, seed: int | None = None):
892
911
  """Draws samples from the prior distributions.
@@ -899,14 +918,25 @@ class Meridian:
899
918
  see [PRNGS and seeds]
900
919
  (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
901
920
  """
902
- self.prior_sampler_callable(n_draws=n_draws, seed=seed)
921
+ prior_draws = self.prior_sampler_callable(n_draws=n_draws, seed=seed)
922
+ # Create Arviz InferenceData for prior draws.
923
+ prior_coords = self._model_context.create_inference_data_coords(1, n_draws)
924
+ prior_dims = self._model_context.create_inference_data_dims()
925
+ prior_inference_data = az.convert_to_inference_data(
926
+ prior_draws,
927
+ coords=prior_coords,
928
+ dims=prior_dims,
929
+ group=constants.PRIOR,
930
+ )
931
+ self.inference_data.extend(prior_inference_data, join="right")
903
932
 
904
933
  def _run_model_fitting_guardrail(self):
905
934
  """Raises an error if the model has critical EDA issues."""
906
935
  error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
907
936
  collections.defaultdict(list)
908
937
  )
909
- for outcome in self.eda_outcomes:
938
+ for field in dataclasses.fields(self.eda_outcomes):
939
+ outcome = getattr(self.eda_outcomes, field.name)
910
940
  error_findings = [
911
941
  finding
912
942
  for finding in outcome.findings
@@ -1010,7 +1040,7 @@ class Meridian:
1010
1040
  """
1011
1041
  self._run_model_fitting_guardrail()
1012
1042
 
1013
- self.posterior_sampler_callable(
1043
+ posterior_inference_data = self.posterior_sampler_callable(
1014
1044
  n_chains=n_chains,
1015
1045
  n_adapt=n_adapt,
1016
1046
  n_burnin=n_burnin,
@@ -1025,6 +1055,7 @@ class Meridian:
1025
1055
  seed=seed,
1026
1056
  **pins,
1027
1057
  )
1058
+ self.inference_data.extend(posterior_inference_data, join="right")
1028
1059
 
1029
1060
 
1030
1061
  def save_mmm(mmm: Meridian, file_path: str):
@@ -1038,6 +1069,15 @@ def save_mmm(mmm: Meridian, file_path: str):
1038
1069
  mmm: Model object to save.
1039
1070
  file_path: File path to save a pickled model object.
1040
1071
  """
1072
+ warnings.warn(
1073
+ "save_mmm is deprecated and will be removed in a future release. Please"
1074
+ " use `schema.serde.meridian_serde.save_meridian` instead. See"
1075
+ " https://developers.google.com/meridian/docs/user-guide/saving-model-object"
1076
+ " for details.",
1077
+ DeprecationWarning,
1078
+ stacklevel=2,
1079
+ )
1080
+
1041
1081
  if not os.path.exists(os.path.dirname(file_path)):
1042
1082
  os.makedirs(os.path.dirname(file_path))
1043
1083
 
@@ -1061,6 +1101,15 @@ def load_mmm(file_path: str) -> Meridian:
1061
1101
  Raises:
1062
1102
  FileNotFoundError: If `file_path` does not exist.
1063
1103
  """
1104
+ warnings.warn(
1105
+ "load_mmm is deprecated and will be removed in a future release. Please"
1106
+ " use `meridian.schema.serde.meridian_serde.load_meridian` instead. See"
1107
+ " https://developers.google.com/meridian/docs/user-guide/saving-model-object"
1108
+ " for details.",
1109
+ DeprecationWarning,
1110
+ stacklevel=2,
1111
+ )
1112
+
1064
1113
  try:
1065
1114
  with open(file_path, "rb") as f:
1066
1115
  mmm = joblib.load(f)
meridian/model/model_test_data.py CHANGED
@@ -159,6 +159,8 @@ class WithInputDataSamples:
159
159
  _input_data_non_media_and_organic_same_time_dims: input_data.InputData
160
160
  _input_data_organic_only: input_data.InputData
161
161
  _national_input_data_organic_only: input_data.InputData
162
+ _input_data_non_media_only: input_data.InputData
163
+ _national_input_data_non_media_only: input_data.InputData
162
164
 
163
165
  # The following NamedTuples and their attributes are immutable, so they can
164
166
  # be accessed directly.
@@ -537,6 +539,34 @@ class WithInputDataSamples:
537
539
  seed=0,
538
540
  )
539
541
  )
542
+ cls._input_data_non_media_only = (
543
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
544
+ n_geos=cls._N_GEOS,
545
+ n_times=cls._N_TIMES,
546
+ n_media_times=cls._N_MEDIA_TIMES,
547
+ n_controls=0,
548
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
549
+ n_media_channels=cls._N_MEDIA_CHANNELS,
550
+ n_rf_channels=0,
551
+ n_organic_media_channels=0,
552
+ n_organic_rf_channels=0,
553
+ seed=0,
554
+ )
555
+ )
556
+ cls._national_input_data_non_media_only = (
557
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
558
+ n_geos=cls._N_GEOS_NATIONAL,
559
+ n_times=cls._N_TIMES,
560
+ n_media_times=cls._N_MEDIA_TIMES,
561
+ n_controls=0,
562
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
563
+ n_media_channels=cls._N_MEDIA_CHANNELS,
564
+ n_rf_channels=0,
565
+ n_organic_media_channels=0,
566
+ n_organic_rf_channels=0,
567
+ seed=0,
568
+ )
569
+ )
540
570
 
541
571
  @property
542
572
  def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
@@ -659,3 +689,11 @@ class WithInputDataSamples:
659
689
  @property
660
690
  def national_input_data_organic_only(self) -> input_data.InputData:
661
691
  return self._national_input_data_organic_only.copy(deep=True)
692
+
693
+ @property
694
+ def input_data_non_media_only(self) -> input_data.InputData:
695
+ return self._input_data_non_media_only.copy(deep=True)
696
+
697
+ @property
698
+ def national_input_data_non_media_only(self) -> input_data.InputData:
699
+ return self._national_input_data_non_media_only.copy(deep=True)