google-meridian 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/model/model.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
16
16
 
17
17
  from collections.abc import Mapping, Sequence
18
18
  import functools
19
+ import numbers
19
20
  import os
20
21
  import warnings
21
22
 
@@ -34,6 +35,7 @@ from meridian.model import spec
34
35
  from meridian.model import transformers
35
36
  import numpy as np
36
37
  import tensorflow as tf
38
+ import tensorflow_probability as tfp
37
39
 
38
40
 
39
41
  __all__ = [
@@ -67,6 +69,21 @@ def _warn_setting_national_args(**kwargs):
67
69
  )
68
70
 
69
71
 
72
def _check_for_negative_effect(
    dist: tfp.distributions.Distribution, media_effects_dist: str
):
  """Rejects a media prior that places probability mass at or below zero.

  The check only applies when `media_effects_dist` equals
  `constants.MEDIA_EFFECTS_LOG_NORMAL`; for any other setting the function
  returns without inspecting `dist`.

  Args:
    dist: The prior distribution to inspect. Its `cdf` is evaluated at zero.
    media_effects_dist: The media effects distribution setting of the model.

  Raises:
    ValueError: If the log-normal setting is active and `dist.cdf(0)` is
      strictly positive for at least one element.
  """
  if media_effects_dist != constants.MEDIA_EFFECTS_LOG_NORMAL:
    return
  # Compare element-wise first and reduce afterwards. Reducing first and then
  # comparing the resulting boolean against zero only tests truthiness, which
  # also treats NaN CDF values as "negative support".
  if np.any(dist.cdf(0) > 0):
    raise ValueError(
        "Media priors must have non-negative support when"
        f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
        f" in {dist.name}."
    )
85
+
86
+
70
87
  class Meridian:
71
88
  """Contains the main functionality for fitting the Meridian MMM model.
72
89
 
@@ -102,6 +119,8 @@ class Meridian:
102
119
  tensors.
103
120
  total_spend: A tensor containing total spend, including
104
121
  `media_tensors.media_spend` and `rf_tensors.rf_spend`.
122
+ total_outcome: A tensor containing the total outcome, aggregated over geos
123
+ and times.
105
124
  controls_transformer: A `ControlsTransformer` to scale controls tensors
106
125
  using the model's controls data.
107
126
  non_media_transformer: A `CenteringAndScalingTransformer` to scale non-media
@@ -146,7 +165,10 @@ class Meridian:
146
165
  unique_sigma_for_each_geo=self.model_spec.unique_sigma_for_each_geo,
147
166
  )
148
167
  self._warn_setting_ignored_priors()
149
- self._validate_paid_media_prior_type()
168
+ self._set_total_media_contribution_prior = False
169
+ self._validate_mroi_priors_non_revenue()
170
+ self._validate_roi_priors_non_revenue()
171
+ self._check_for_negative_effects()
150
172
  self._validate_geo_invariants()
151
173
  self._validate_time_invariants()
152
174
  self._validate_kpi_transformer()
@@ -192,7 +214,9 @@ class Meridian:
192
214
  )
193
215
 
194
216
  @functools.cached_property
195
- def controls(self) -> tf.Tensor:
217
+ def controls(self) -> tf.Tensor | None:
218
+ if self.input_data.controls is None:
219
+ return None
196
220
  return tf.convert_to_tensor(self.input_data.controls, dtype=tf.float32)
197
221
 
198
222
  @functools.cached_property
@@ -213,6 +237,12 @@ class Meridian:
213
237
  self.input_data.get_total_spend(), dtype=tf.float32
214
238
  )
215
239
 
240
+ @functools.cached_property
241
+ def total_outcome(self) -> tf.Tensor:
242
+ return tf.convert_to_tensor(
243
+ self.input_data.get_total_outcome(), dtype=tf.float32
244
+ )
245
+
216
246
  @property
217
247
  def n_geos(self) -> int:
218
248
  return len(self.input_data.geo)
@@ -243,6 +273,8 @@ class Meridian:
243
273
 
244
274
  @property
245
275
  def n_controls(self) -> int:
276
+ if self.input_data.control_variable is None:
277
+ return 0
246
278
  return len(self.input_data.control_variable)
247
279
 
248
280
  @property
@@ -276,7 +308,13 @@ class Meridian:
276
308
  )
277
309
 
278
310
  @functools.cached_property
279
- def controls_transformer(self) -> transformers.CenteringAndScalingTransformer:
311
+ def controls_transformer(
312
+ self,
313
+ ) -> transformers.CenteringAndScalingTransformer | None:
314
+ """Returns a `CenteringAndScalingTransformer` for controls, if it exists."""
315
+ if self.controls is None:
316
+ return None
317
+
280
318
  if self.model_spec.control_population_scaling_id is not None:
281
319
  controls_population_scaling_id = tf.convert_to_tensor(
282
320
  self.model_spec.control_population_scaling_id, dtype=bool
@@ -315,13 +353,25 @@ class Meridian:
315
353
  return transformers.KpiTransformer(self.kpi, self.population)
316
354
 
317
355
  @functools.cached_property
318
- def controls_scaled(self) -> tf.Tensor:
319
- return self.controls_transformer.forward(self.controls)
356
+ def controls_scaled(self) -> tf.Tensor | None:
357
+ if self.controls is not None:
358
+ # If `controls` is defined, then `controls_transformer` is also defined.
359
+ return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
360
+ else:
361
+ return None
320
362
 
321
363
  @functools.cached_property
322
- def non_media_treatments_scaled(self) -> tf.Tensor | None:
364
+ def non_media_treatments_normalized(self) -> tf.Tensor | None:
365
+ """Normalized non-media treatments.
366
+
367
+ The non-media treatments values are scaled by population (for channels where
368
+ `non_media_population_scaling_id` is `True`) and normalized by centering and
369
+ scaling with means and standard deviations.
370
+ """
323
371
  if self.non_media_transformer is not None:
324
- return self.non_media_transformer.forward(self.non_media_treatments) # pytype: disable=attribute-error
372
+ return self.non_media_transformer.forward(
373
+ self.non_media_treatments
374
+ ) # pytype: disable=attribute-error
325
375
  else:
326
376
  return None
327
377
 
@@ -381,12 +431,6 @@ class Meridian:
381
431
  @functools.cached_property
382
432
  def prior_broadcast(self) -> prior_distribution.PriorDistribution:
383
433
  """Returns broadcasted `PriorDistribution` object."""
384
- set_total_media_contribution_prior = (
385
- self.input_data.revenue_per_kpi is None
386
- and self.input_data.kpi_type == constants.NON_REVENUE
387
- and self.model_spec.paid_media_prior_type
388
- == constants.PAID_MEDIA_PRIOR_TYPE_ROI
389
- )
390
434
  total_spend = self.input_data.get_total_spend()
391
435
  # Total spend can have 1, 2 or 3 dimensions. Aggregate by channel.
392
436
  if len(total_spend.shape) == 1:
@@ -408,10 +452,9 @@ class Meridian:
408
452
  sigma_shape=self._sigma_shape,
409
453
  n_knots=self.knot_info.n_knots,
410
454
  is_national=self.is_national,
411
- set_total_media_contribution_prior=set_total_media_contribution_prior,
455
+ set_total_media_contribution_prior=self._set_total_media_contribution_prior,
412
456
  kpi=np.sum(self.input_data.kpi.values),
413
457
  total_spend=agg_total_spend,
414
- media_effects_dist=self.media_effects_dist,
415
458
  )
416
459
 
417
460
  @functools.cached_property
@@ -426,10 +469,91 @@ class Meridian:
426
469
  """A `PosteriorMCMCSampler` callable bound to this model."""
427
470
  return posterior_sampler.PosteriorMCMCSampler(self)
428
471
 
472
+ def compute_non_media_treatments_baseline(
473
+ self,
474
+ non_media_baseline_values: Sequence[str | float] | None = None,
475
+ ) -> tf.Tensor:
476
+ """Computes the baseline for each non-media treatment channel.
477
+
478
+ Args:
479
+ non_media_baseline_values: Optional list of shape
480
+ `(n_non_media_channels,)`. Each element is either a float (which means
481
+ that the fixed value will be used as baseline for the given channel) or
482
+ one of the strings "min" or "max" (which mean that the global minimum or
483
+ maximum value will be used as baseline for the values of the given
484
+ non_media treatment channel). If float values are provided, it is
485
+ expected that they are scaled by population for the channels where
486
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
487
+ `model_spec.non_media_baseline_values` is used, which defaults to the
488
+ minimum value for each non_media treatment channel.
489
+
490
+ Returns:
491
+ A tensor of shape `(n_non_media_channels,)` containing the
492
+ baseline values for each non-media treatment channel.
493
+ """
494
+ if non_media_baseline_values is None:
495
+ non_media_baseline_values = self.model_spec.non_media_baseline_values
496
+
497
+ if self.model_spec.non_media_population_scaling_id is not None:
498
+ scaling_factors = tf.where(
499
+ self.model_spec.non_media_population_scaling_id,
500
+ self.population[:, tf.newaxis, tf.newaxis],
501
+ tf.ones_like(self.population)[:, tf.newaxis, tf.newaxis],
502
+ )
503
+ else:
504
+ scaling_factors = tf.ones_like(self.population)[:, tf.newaxis, tf.newaxis]
505
+
506
+ non_media_treatments_population_scaled = tf.math.divide_no_nan(
507
+ self.non_media_treatments, scaling_factors
508
+ )
509
+
510
+ if non_media_baseline_values is None:
511
+ # If non_media_baseline_values is not provided, use the minimum
512
+ # value for each non_media treatment channel as the baseline.
513
+ non_media_baseline_values_filled = [
514
+ constants.NON_MEDIA_BASELINE_MIN
515
+ ] * non_media_treatments_population_scaled.shape[-1]
516
+ else:
517
+ non_media_baseline_values_filled = non_media_baseline_values
518
+
519
+ if non_media_treatments_population_scaled.shape[-1] != len(
520
+ non_media_baseline_values_filled
521
+ ):
522
+ raise ValueError(
523
+ "The number of non-media channels"
524
+ f" ({non_media_treatments_population_scaled.shape[-1]}) does not"
525
+ " match the number of baseline values"
526
+ f" ({len(non_media_baseline_values_filled)})."
527
+ )
528
+
529
+ baseline_list = []
530
+ for channel in range(non_media_treatments_population_scaled.shape[-1]):
531
+ baseline_value = non_media_baseline_values_filled[channel]
532
+
533
+ if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
534
+ baseline_for_channel = tf.reduce_min(
535
+ non_media_treatments_population_scaled[..., channel], axis=[0, 1]
536
+ )
537
+ elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
538
+ baseline_for_channel = tf.reduce_max(
539
+ non_media_treatments_population_scaled[..., channel], axis=[0, 1]
540
+ )
541
+ elif isinstance(baseline_value, numbers.Number):
542
+ baseline_for_channel = tf.cast(baseline_value, tf.float32)
543
+ else:
544
+ raise ValueError(
545
+ f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
546
+ " float numbers and strings 'min' and 'max' are supported."
547
+ )
548
+
549
+ baseline_list.append(baseline_for_channel)
550
+
551
+ return tf.stack(baseline_list, axis=-1)
552
+
429
553
  def expand_selected_time_dims(
430
554
  self,
431
- start_date: tc.Date | None = None,
432
- end_date: tc.Date | None = None,
555
+ start_date: tc.Date = None,
556
+ end_date: tc.Date = None,
433
557
  ) -> list[str] | None:
434
558
  """Validates and returns time dimension values based on the selected times.
435
559
 
@@ -652,65 +776,147 @@ class Meridian:
652
776
  def _warn_setting_ignored_priors(self):
653
777
  """Raises a warning if ignored priors are set."""
654
778
  default_distribution = prior_distribution.PriorDistribution()
655
- prior_type = self.model_spec.paid_media_prior_type
656
-
657
- ignored_custom_priors = []
658
- for prior in constants.IGNORED_PRIORS.get(prior_type, []):
659
- self_prior = getattr(self.model_spec.prior, prior)
660
- default_prior = getattr(default_distribution, prior)
661
- if not prior_distribution.distributions_are_equal(
662
- self_prior, default_prior
663
- ):
664
- ignored_custom_priors.append(prior)
665
- if ignored_custom_priors:
666
- ignored_priors_str = ", ".join(ignored_custom_priors)
667
- warnings.warn(
668
- f"Custom prior(s) `{ignored_priors_str}` are ignored when"
669
- " `paid_media_prior_type` is set to"
670
- f' "{prior_type}".'
671
- )
779
+ for ignored_priors_dict, prior_type, prior_type_name in (
780
+ (
781
+ constants.IGNORED_PRIORS_MEDIA,
782
+ self.model_spec.effective_media_prior_type,
783
+ "media_prior_type",
784
+ ),
785
+ (
786
+ constants.IGNORED_PRIORS_RF,
787
+ self.model_spec.effective_rf_prior_type,
788
+ "rf_prior_type",
789
+ ),
790
+ ):
791
+ ignored_custom_priors = []
792
+ for prior in ignored_priors_dict.get(prior_type, []):
793
+ self_prior = getattr(self.model_spec.prior, prior)
794
+ default_prior = getattr(default_distribution, prior)
795
+ if not prior_distribution.distributions_are_equal(
796
+ self_prior, default_prior
797
+ ):
798
+ ignored_custom_priors.append(prior)
799
+ if ignored_custom_priors:
800
+ ignored_priors_str = ", ".join(ignored_custom_priors)
801
+ warnings.warn(
802
+ f"Custom prior(s) `{ignored_priors_str}` are ignored when"
803
+ f' `{prior_type_name}` is set to "{prior_type}".'
804
+ )
672
805
 
673
- def _validate_paid_media_prior_type(self):
674
- """Validates the media prior type."""
675
- default_distribution = prior_distribution.PriorDistribution()
676
- mroi_m_not_set = (
677
- self.n_media_channels > 0
678
- and prior_distribution.distributions_are_equal(
679
- self.model_spec.prior.mroi_m, default_distribution.mroi_m
806
+ def _validate_mroi_priors_non_revenue(self):
807
+ """Validates mroi priors in the non-revenue outcome case."""
808
+ if (
809
+ self.input_data.kpi_type == constants.NON_REVENUE
810
+ and self.input_data.revenue_per_kpi is None
811
+ ):
812
+ default_distribution = prior_distribution.PriorDistribution()
813
+ if (
814
+ self.n_media_channels > 0
815
+ and (
816
+ self.model_spec.effective_media_prior_type
817
+ == constants.TREATMENT_PRIOR_TYPE_MROI
818
+ )
819
+ and prior_distribution.distributions_are_equal(
820
+ self.model_spec.prior.mroi_m, default_distribution.mroi_m
821
+ )
822
+ ):
823
+ raise ValueError(
824
+ f"Custom priors should be set on `{constants.MROI_M}` when"
825
+ ' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
826
+ " kpi data is missing."
680
827
  )
681
- )
682
- mroi_rf_not_set = (
683
- self.n_rf_channels > 0
684
- and prior_distribution.distributions_are_equal(
685
- self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
828
+ if (
829
+ self.n_rf_channels > 0
830
+ and (
831
+ self.model_spec.effective_rf_prior_type
832
+ == constants.TREATMENT_PRIOR_TYPE_MROI
833
+ )
834
+ and prior_distribution.distributions_are_equal(
835
+ self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
836
+ )
837
+ ):
838
+ raise ValueError(
839
+ f"Custom priors should be set on `{constants.MROI_RF}` when"
840
+ ' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
841
+ " data is missing."
686
842
  )
687
- )
843
+
844
+ def _validate_roi_priors_non_revenue(self):
845
+ """Validates roi priors in the non-revenue outcome case."""
688
846
  if (
689
- self.input_data.revenue_per_kpi is None
690
- and self.input_data.kpi_type == constants.NON_REVENUE
691
- and self.model_spec.paid_media_prior_type
692
- == constants.PAID_MEDIA_PRIOR_TYPE_MROI
693
- and (mroi_m_not_set or mroi_rf_not_set)
847
+ self.input_data.kpi_type == constants.NON_REVENUE
848
+ and self.input_data.revenue_per_kpi is None
694
849
  ):
695
- raise ValueError(
696
- f"Custom priors should be set on `{constants.MROI_M}` and"
697
- f" `{constants.MROI_RF}` when KPI is non-revenue and revenue per kpi"
698
- " data is missing."
850
+ default_distribution = prior_distribution.PriorDistribution()
851
+ default_roi_m_used = (
852
+ self.model_spec.effective_media_prior_type
853
+ == constants.TREATMENT_PRIOR_TYPE_ROI
854
+ and prior_distribution.distributions_are_equal(
855
+ self.model_spec.prior.roi_m, default_distribution.roi_m
856
+ )
699
857
  )
858
+ default_roi_rf_used = (
859
+ self.model_spec.effective_rf_prior_type
860
+ == constants.TREATMENT_PRIOR_TYPE_ROI
861
+ and prior_distribution.distributions_are_equal(
862
+ self.model_spec.prior.roi_rf, default_distribution.roi_rf
863
+ )
864
+ )
865
+ # If ROI priors are used with the default prior distribution for all paid
866
+ # channels (media and RF), then use the "total paid media contribution
867
+ # prior" procedure.
868
+ if (
869
+ (default_roi_m_used and default_roi_rf_used)
870
+ or (self.n_media_channels == 0 and default_roi_rf_used)
871
+ or (self.n_rf_channels == 0 and default_roi_m_used)
872
+ ):
873
+ self._set_total_media_contribution_prior = True
874
+ warnings.warn(
875
+ "Consider setting custom ROI priors, as kpi_type was specified as"
876
+ " `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
877
+ " total media contribution prior will be used with"
878
+ f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
879
+ " documentation available at "
880
+ " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
881
+ )
882
+ elif self.n_media_channels > 0 and default_roi_m_used:
883
+ raise ValueError(
884
+ f"Custom priors should be set on `{constants.ROI_M}` when"
885
+ ' `media_prior_type` is "roi", custom priors are assigned on'
886
+ ' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
887
+ " non-revenue and revenue per kpi data is missing."
888
+ )
889
+ elif self.n_rf_channels > 0 and default_roi_rf_used:
890
+ raise ValueError(
891
+ f"Custom priors should be set on `{constants.ROI_RF}` when"
892
+ ' `rf_prior_type` is "roi", custom priors are assigned on'
893
+ ' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
894
+ " non-revenue and revenue per kpi data is missing."
895
+ )
896
+
897
+ def _check_for_negative_effects(self):
898
+ prior = self.model_spec.prior
899
+ if self.n_media_channels > 0:
900
+ _check_for_negative_effect(prior.roi_m, self.media_effects_dist)
901
+ _check_for_negative_effect(prior.mroi_m, self.media_effects_dist)
902
+ if self.n_rf_channels > 0:
903
+ _check_for_negative_effect(prior.roi_rf, self.media_effects_dist)
904
+ _check_for_negative_effect(prior.mroi_rf, self.media_effects_dist)
700
905
 
701
906
  def _validate_geo_invariants(self):
702
907
  """Validates non-national model invariants."""
703
908
  if self.is_national:
704
909
  return
705
910
 
706
- self._check_if_no_geo_variation(
707
- self.controls_scaled,
708
- constants.CONTROLS,
709
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
710
- )
911
+ if self.input_data.controls is not None:
912
+ self._check_if_no_geo_variation(
913
+ self.controls_scaled,
914
+ constants.CONTROLS,
915
+ self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
916
+ )
711
917
  if self.input_data.non_media_treatments is not None:
712
918
  self._check_if_no_geo_variation(
713
- self.non_media_treatments_scaled,
919
+ self.non_media_treatments_normalized,
714
920
  constants.NON_MEDIA_TREATMENTS,
715
921
  self.input_data.non_media_treatments.coords[
716
922
  constants.NON_MEDIA_CHANNEL
@@ -780,12 +986,20 @@ class Meridian:
780
986
 
781
987
  def _validate_time_invariants(self):
782
988
  """Validates model time invariants."""
783
-
784
- self._check_if_no_time_variation(
785
- self.controls_scaled,
786
- constants.CONTROLS,
787
- self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
788
- )
989
+ if self.input_data.controls is not None:
990
+ self._check_if_no_time_variation(
991
+ self.controls_scaled,
992
+ constants.CONTROLS,
993
+ self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
994
+ )
995
+ if self.input_data.non_media_treatments is not None:
996
+ self._check_if_no_time_variation(
997
+ self.non_media_treatments_normalized,
998
+ constants.NON_MEDIA_TREATMENTS,
999
+ self.input_data.non_media_treatments.coords[
1000
+ constants.NON_MEDIA_CHANNEL
1001
+ ].values,
1002
+ )
789
1003
  if self.input_data.media is not None:
790
1004
  self._check_if_no_time_variation(
791
1005
  self.media_tensors.media_scaled,
@@ -798,6 +1012,22 @@ class Meridian:
798
1012
  constants.REACH,
799
1013
  self.input_data.reach.coords[constants.RF_CHANNEL].values,
800
1014
  )
1015
+ if self.input_data.organic_media is not None:
1016
+ self._check_if_no_time_variation(
1017
+ self.organic_media_tensors.organic_media_scaled,
1018
+ constants.ORGANIC_MEDIA,
1019
+ self.input_data.organic_media.coords[
1020
+ constants.ORGANIC_MEDIA_CHANNEL
1021
+ ].values,
1022
+ )
1023
+ if self.input_data.organic_reach is not None:
1024
+ self._check_if_no_time_variation(
1025
+ self.organic_rf_tensors.organic_reach_scaled,
1026
+ constants.ORGANIC_REACH,
1027
+ self.input_data.organic_reach.coords[
1028
+ constants.ORGANIC_RF_CHANNEL
1029
+ ].values,
1030
+ )
801
1031
 
802
1032
  def _check_if_no_time_variation(
803
1033
  self,
@@ -829,17 +1059,192 @@ class Meridian:
829
1059
 
830
1060
  def _validate_kpi_transformer(self):
831
1061
  """Validates the KPI transformer."""
1062
+ kpi = "kpi" if self.is_national else "population_scaled_kpi"
1063
+ if (
1064
+ self.n_media_channels > 0
1065
+ and self.kpi_transformer.population_scaled_stdev == 0
1066
+ and self.model_spec.effective_media_prior_type
1067
+ in constants.PAID_MEDIA_ROI_PRIOR_TYPES
1068
+ ):
1069
+ raise ValueError(
1070
+ f"`{kpi}` cannot be constant with"
1071
+ " `media_prior_type` ="
1072
+ f' "{self.model_spec.effective_media_prior_type}".'
1073
+ )
832
1074
  if (
833
- self.kpi_transformer.population_scaled_stdev == 0
834
- and self.model_spec.paid_media_prior_type
1075
+ self.n_rf_channels > 0
1076
+ and self.kpi_transformer.population_scaled_stdev == 0
1077
+ and self.model_spec.effective_rf_prior_type
835
1078
  in constants.PAID_MEDIA_ROI_PRIOR_TYPES
836
1079
  ):
837
- kpi = "kpi" if self.is_national else "population_scaled_kpi"
838
1080
  raise ValueError(
839
1081
  f"`{kpi}` cannot be constant with"
840
- f" {self.model_spec.paid_media_prior_type} prior type."
1082
+ f' `rf_prior_type` = "{self.model_spec.effective_rf_prior_type}".'
841
1083
  )
842
1084
 
1085
+ def linear_predictor_counterfactual_difference_media(
1086
+ self,
1087
+ media_transformed: tf.Tensor,
1088
+ alpha_m: tf.Tensor,
1089
+ ec_m: tf.Tensor,
1090
+ slope_m: tf.Tensor,
1091
+ ) -> tf.Tensor:
1092
+ """Calculates linear predictor counterfactual difference for non-RF media.
1093
+
1094
+ For non-RF media variables (paid or organic), this function calculates the
1095
+ linear predictor difference between the treatment variable and its
1096
+ counterfactual. "Linear predictor" refers to the output of the hill/adstock
1097
+ function, which is multiplied by the geo-level coefficient.
1098
+
1099
+ This function does the calculation efficiently by only calculating calling
1100
+ the hill/adstock function if the prior counterfactual is not all zeros.
1101
+
1102
+ Args:
1103
+ media_transformed: The output of the hill/adstock function for actual
1104
+ historical media data.
1105
+ alpha_m: The adstock alpha parameter values.
1106
+ ec_m: The adstock ec parameter values.
1107
+ slope_m: The adstock hill slope parameter values.
1108
+
1109
+ Returns:
1110
+ The linear predictor difference between the treatment variable and its
1111
+ counterfactual.
1112
+ """
1113
+ if self.media_tensors.prior_media_scaled_counterfactual is None:
1114
+ return media_transformed
1115
+ media_transformed_counterfactual = self.adstock_hill_media(
1116
+ self.media_tensors.prior_media_scaled_counterfactual,
1117
+ alpha_m,
1118
+ ec_m,
1119
+ slope_m,
1120
+ )
1121
+ # Absolute values is needed because the difference is negative for mROI
1122
+ # priors and positive for ROI and contribution priors.
1123
+ return tf.abs(media_transformed - media_transformed_counterfactual)
1124
+
1125
+ def linear_predictor_counterfactual_difference_rf(
1126
+ self,
1127
+ rf_transformed: tf.Tensor,
1128
+ alpha_rf: tf.Tensor,
1129
+ ec_rf: tf.Tensor,
1130
+ slope_rf: tf.Tensor,
1131
+ ) -> tf.Tensor:
1132
+ """Calculates linear predictor counterfactual difference for RF media.
1133
+
1134
+ For RF media variables (paid or organic), this function calculates the
1135
+ linear predictor difference between the treatment variable and its
1136
+ counterfactual. "Linear predictor" refers to the output of the hill/adstock
1137
+ function, which is multiplied by the geo-level coefficient.
1138
+
1139
+ This function does the calculation efficiently by only calculating calling
1140
+ the hill/adstock function if the prior counterfactual is not all zeros.
1141
+
1142
+ Args:
1143
+ rf_transformed: The output of the hill/adstock function for actual
1144
+ historical media data.
1145
+ alpha_rf: The adstock alpha parameter values.
1146
+ ec_rf: The adstock ec parameter values.
1147
+ slope_rf: The adstock hill slope parameter values.
1148
+
1149
+ Returns:
1150
+ The linear predictor difference between the treatment variable and its
1151
+ counterfactual.
1152
+ """
1153
+ if self.rf_tensors.prior_reach_scaled_counterfactual is None:
1154
+ return rf_transformed
1155
+ rf_transformed_counterfactual = self.adstock_hill_rf(
1156
+ reach=self.rf_tensors.prior_reach_scaled_counterfactual,
1157
+ frequency=self.rf_tensors.frequency,
1158
+ alpha=alpha_rf,
1159
+ ec=ec_rf,
1160
+ slope=slope_rf,
1161
+ )
1162
+ # Absolute values is needed because the difference is negative for mROI
1163
+ # priors and positive for ROI and contribution priors.
1164
+ return tf.abs(rf_transformed - rf_transformed_counterfactual)
1165
+
1166
  def calculate_beta_x(
      self,
      is_non_media: bool,
      incremental_outcome_x: tf.Tensor,
      linear_predictor_counterfactual_difference: tf.Tensor,
      eta_x: tf.Tensor,
      beta_gx_dev: tf.Tensor,
  ) -> tf.Tensor:
    """Calculates coefficient mean parameter for any treatment variable type.

    The "beta_x" in the function name refers to the coefficient mean parameter
    of any treatment variable. The "x" can represent "m", "rf", "om", or "orf".
    This function can also be used to calculate "gamma_n" for any non-media
    treatments.

    Args:
      is_non_media: Boolean indicating whether the treatment variable is a
        non-media treatment. This argument is used to determine whether the
        coefficient random effects are normal or log-normal. If `True`, then
        random effects are assumed to be normal. Otherwise, the distribution is
        inferred from `self.media_effects_dist`.
      incremental_outcome_x: The incremental outcome of the treatment variable,
        which depends on the parameter values of a particular prior or posterior
        draw. The "_x" indicates that this is a tensor with length equal to the
        dimension of the treatment variable.
      linear_predictor_counterfactual_difference: The difference between the
        treatment variable and its counterfactual on the linear predictor scale.
        "Linear predictor" refers to the quantity that is multiplied by the
        geo-level coefficient. For media variables, this is the output of the
        hill/adstock transformation function. For non-media treatments, this is
        simply the treatment variable after centering/scaling transformations.
        This tensor has dimensions for geo, time, and channel.
      eta_x: The random effect standard deviation parameter values. For media
        variables, the "x" represents "m", "rf", "om", or "orf". For non-media
        treatments, this argument should be set to `xi_n`, which is analogous to
        "eta".
      beta_gx_dev: The latent standard normal parameter values of the geo-level
        coefficients. For media variables, the "x" represents "m", "rf", "om",
        or "orf". For non-media treatments, this argument should be set to
        `gamma_gn_dev`, which is analogous to "beta_gx_dev".

    Returns:
      The coefficient mean parameter of the treatment variable, which has
      dimension equal to the number of treatment channels.
    """
    if is_non_media:
      random_effects_normal = True
    else:
      random_effects_normal = (
          self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
      )
    # Without revenue-per-KPI data, use a factor of one for every geo and time.
    if self.revenue_per_kpi is None:
      revenue_per_kpi = tf.ones([self.n_geos, self.n_times], dtype=tf.float32)
    else:
      revenue_per_kpi = self.revenue_per_kpi
    # Sums over the time axis while multiplying by revenue per KPI, population
    # and the KPI standard deviation (the operand with the empty subscript).
    incremental_outcome_gx_over_beta_gx = tf.einsum(
        "...gtx,gt,g,->...gx",
        linear_predictor_counterfactual_difference,
        revenue_per_kpi,
        self.population,
        self.kpi_transformer.population_scaled_stdev,
    )
    if random_effects_normal:
      numerator_term_x = tf.einsum(
          "...gx,...gx,...x->...x",
          incremental_outcome_gx_over_beta_gx,
          beta_gx_dev,
          eta_x,
      )
      denominator_term_x = tf.einsum(
          "...gx->...x", incremental_outcome_gx_over_beta_gx
      )
      return (incremental_outcome_x - numerator_term_x) / denominator_term_x
    # For log-normal random effects, beta_x and eta_x are not mean & std.
    # The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
    denominator_term_x = tf.einsum(
        "...gx,...gx->...x",
        incremental_outcome_gx_over_beta_gx,
        tf.math.exp(beta_gx_dev * eta_x[..., tf.newaxis, :]),
    )
    return tf.math.log(incremental_outcome_x) - tf.math.log(denominator_term_x)
1247
+
843
1248
  def adstock_hill_media(
844
1249
  self,
845
1250
  media: tf.Tensor, # pylint: disable=redefined-outer-name
@@ -966,31 +1371,36 @@ class Meridian:
966
1371
  self, n_chains: int, n_draws: int
967
1372
  ) -> Mapping[str, np.ndarray | Sequence[str]]:
968
1373
  """Creates data coordinates for inference data."""
969
- media_channel_values = (
1374
+ media_channel_names = (
970
1375
  self.input_data.media_channel
971
1376
  if self.input_data.media_channel is not None
972
1377
  else np.array([])
973
1378
  )
974
- rf_channel_values = (
1379
+ rf_channel_names = (
975
1380
  self.input_data.rf_channel
976
1381
  if self.input_data.rf_channel is not None
977
1382
  else np.array([])
978
1383
  )
979
- organic_media_channel_values = (
1384
+ organic_media_channel_names = (
980
1385
  self.input_data.organic_media_channel
981
1386
  if self.input_data.organic_media_channel is not None
982
1387
  else np.array([])
983
1388
  )
984
- organic_rf_channel_values = (
1389
+ organic_rf_channel_names = (
985
1390
  self.input_data.organic_rf_channel
986
1391
  if self.input_data.organic_rf_channel is not None
987
1392
  else np.array([])
988
1393
  )
989
- non_media_channel_values = (
1394
+ non_media_channel_names = (
990
1395
  self.input_data.non_media_channel
991
1396
  if self.input_data.non_media_channel is not None
992
1397
  else np.array([])
993
1398
  )
1399
+ control_variable_names = (
1400
+ self.input_data.control_variable
1401
+ if self.input_data.control_variable is not None
1402
+ else np.array([])
1403
+ )
994
1404
  return {
995
1405
  constants.CHAIN: np.arange(n_chains),
996
1406
  constants.DRAW: np.arange(n_draws),
@@ -998,12 +1408,12 @@ class Meridian:
998
1408
  constants.TIME: self.input_data.time,
999
1409
  constants.MEDIA_TIME: self.input_data.media_time,
1000
1410
  constants.KNOTS: np.arange(self.knot_info.n_knots),
1001
- constants.CONTROL_VARIABLE: self.input_data.control_variable,
1002
- constants.NON_MEDIA_CHANNEL: non_media_channel_values,
1003
- constants.MEDIA_CHANNEL: media_channel_values,
1004
- constants.RF_CHANNEL: rf_channel_values,
1005
- constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_values,
1006
- constants.ORGANIC_RF_CHANNEL: organic_rf_channel_values,
1411
+ constants.CONTROL_VARIABLE: control_variable_names,
1412
+ constants.NON_MEDIA_CHANNEL: non_media_channel_names,
1413
+ constants.MEDIA_CHANNEL: media_channel_names,
1414
+ constants.RF_CHANNEL: rf_channel_names,
1415
+ constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
1416
+ constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
1007
1417
  }
1008
1418
 
1009
1419
  def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]: