google-meridian 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/model/model.py CHANGED
@@ -16,6 +16,7 @@
16
16
 
17
17
  from collections.abc import Mapping, Sequence
18
18
  import functools
19
+ import numbers
19
20
  import os
20
21
  import warnings
21
22
 
@@ -34,6 +35,7 @@ from meridian.model import spec
34
35
  from meridian.model import transformers
35
36
  import numpy as np
36
37
  import tensorflow as tf
38
+ import tensorflow_probability as tfp
37
39
 
38
40
 
39
41
  __all__ = [
@@ -67,6 +69,21 @@ def _warn_setting_national_args(**kwargs):
67
69
  )
68
70
 
69
71
 
72
def _check_for_negative_effect(
    dist: tfp.distributions.Distribution, media_effects_dist: str
):
  """Rejects a media prior that places probability mass at or below zero.

  The check only applies when `media_effects_dist` equals
  `constants.MEDIA_EFFECTS_LOG_NORMAL`. For any other value the function
  returns without inspecting `dist`.

  Args:
    dist: Prior distribution to inspect. Its `cdf` is evaluated at zero.
    media_effects_dist: Name of the distribution used for the media effects.

  Raises:
    ValueError: If any element of `dist.cdf(0)` is strictly positive.
  """
  if media_effects_dist != constants.MEDIA_EFFECTS_LOG_NORMAL:
    return
  # Compare element-wise first and reduce afterwards. The previous form,
  # `np.any(cdf) > 0`, reduced to a single bool before comparing it with zero.
  if np.any(dist.cdf(0) > 0):
    raise ValueError(
        "Media priors must have non-negative support when"
        f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
        f" in {dist.name}."
    )
85
+
86
+
70
87
  class Meridian:
71
88
  """Contains the main functionality for fitting the Meridian MMM model.
72
89
 
@@ -102,6 +119,8 @@ class Meridian:
102
119
  tensors.
103
120
  total_spend: A tensor containing total spend, including
104
121
  `media_tensors.media_spend` and `rf_tensors.rf_spend`.
122
+ total_outcome: A tensor containing the total outcome, aggregated over geos
123
+ and times.
105
124
  controls_transformer: A `ControlsTransformer` to scale controls tensors
106
125
  using the model's controls data.
107
126
  non_media_transformer: A `CenteringAndScalingTransformer` to scale non-media
@@ -146,9 +165,13 @@ class Meridian:
146
165
  unique_sigma_for_each_geo=self.model_spec.unique_sigma_for_each_geo,
147
166
  )
148
167
  self._warn_setting_ignored_priors()
149
- self._validate_paid_media_prior_type()
168
+ self._set_total_media_contribution_prior = False
169
+ self._validate_mroi_priors_non_revenue()
170
+ self._validate_roi_priors_non_revenue()
171
+ self._check_for_negative_effects()
150
172
  self._validate_geo_invariants()
151
173
  self._validate_time_invariants()
174
+ self._validate_kpi_transformer()
152
175
 
153
176
  @property
154
177
  def input_data(self) -> data.InputData:
@@ -212,6 +235,12 @@ class Meridian:
212
235
  self.input_data.get_total_spend(), dtype=tf.float32
213
236
  )
214
237
 
238
+ @functools.cached_property
239
+ def total_outcome(self) -> tf.Tensor:
240
+ return tf.convert_to_tensor(
241
+ self.input_data.get_total_outcome(), dtype=tf.float32
242
+ )
243
+
215
244
  @property
216
245
  def n_geos(self) -> int:
217
246
  return len(self.input_data.geo)
@@ -318,9 +347,17 @@ class Meridian:
318
347
  return self.controls_transformer.forward(self.controls)
319
348
 
320
349
  @functools.cached_property
321
- def non_media_treatments_scaled(self) -> tf.Tensor | None:
350
+ def non_media_treatments_normalized(self) -> tf.Tensor | None:
351
+ """Normalized non-media treatments.
352
+
353
+ The non-media treatments values are scaled by population (for channels where
354
+ `non_media_population_scaling_id` is `True`) and normalized by centering and
355
+ scaling with means and standard deviations.
356
+ """
322
357
  if self.non_media_transformer is not None:
323
- return self.non_media_transformer.forward(self.non_media_treatments) # pytype: disable=attribute-error
358
+ return self.non_media_transformer.forward(
359
+ self.non_media_treatments
360
+ ) # pytype: disable=attribute-error
324
361
  else:
325
362
  return None
326
363
 
@@ -380,12 +417,6 @@ class Meridian:
380
417
  @functools.cached_property
381
418
  def prior_broadcast(self) -> prior_distribution.PriorDistribution:
382
419
  """Returns broadcasted `PriorDistribution` object."""
383
- set_total_media_contribution_prior = (
384
- self.input_data.revenue_per_kpi is None
385
- and self.input_data.kpi_type == constants.NON_REVENUE
386
- and self.model_spec.paid_media_prior_type
387
- == constants.PAID_MEDIA_PRIOR_TYPE_ROI
388
- )
389
420
  total_spend = self.input_data.get_total_spend()
390
421
  # Total spend can have 1, 2 or 3 dimensions. Aggregate by channel.
391
422
  if len(total_spend.shape) == 1:
@@ -407,7 +438,7 @@ class Meridian:
407
438
  sigma_shape=self._sigma_shape,
408
439
  n_knots=self.knot_info.n_knots,
409
440
  is_national=self.is_national,
410
- set_total_media_contribution_prior=set_total_media_contribution_prior,
441
+ set_total_media_contribution_prior=self._set_total_media_contribution_prior,
411
442
  kpi=np.sum(self.input_data.kpi.values),
412
443
  total_spend=agg_total_spend,
413
444
  )
@@ -424,10 +455,91 @@ class Meridian:
424
455
  """A `PosteriorMCMCSampler` callable bound to this model."""
425
456
  return posterior_sampler.PosteriorMCMCSampler(self)
426
457
 
458
+ def compute_non_media_treatments_baseline(
459
+ self,
460
+ non_media_baseline_values: Sequence[str | float] | None = None,
461
+ ) -> tf.Tensor:
462
+ """Computes the baseline for each non-media treatment channel.
463
+
464
+ Args:
465
+ non_media_baseline_values: Optional list of shape
466
+ `(n_non_media_channels,)`. Each element is either a float (which means
467
+ that the fixed value will be used as baseline for the given channel) or
468
+ one of the strings "min" or "max" (which mean that the global minimum or
469
+ maximum value will be used as baseline for the values of the given
470
+ non_media treatment channel). If float values are provided, it is
471
+ expected that they are scaled by population for the channels where
472
+ `model_spec.non_media_population_scaling_id` is `True`. If `None`, the
473
+ `model_spec.non_media_baseline_values` is used, which defaults to the
474
+ minimum value for each non_media treatment channel.
475
+
476
+ Returns:
477
+ A tensor of shape `(n_non_media_channels,)` containing the
478
+ baseline values for each non-media treatment channel.
479
+ """
480
+ if non_media_baseline_values is None:
481
+ non_media_baseline_values = self.model_spec.non_media_baseline_values
482
+
483
+ if self.model_spec.non_media_population_scaling_id is not None:
484
+ scaling_factors = tf.where(
485
+ self.model_spec.non_media_population_scaling_id,
486
+ self.population[:, tf.newaxis, tf.newaxis],
487
+ tf.ones_like(self.population)[:, tf.newaxis, tf.newaxis],
488
+ )
489
+ else:
490
+ scaling_factors = tf.ones_like(self.population)[:, tf.newaxis, tf.newaxis]
491
+
492
+ non_media_treatments_population_scaled = tf.math.divide_no_nan(
493
+ self.non_media_treatments, scaling_factors
494
+ )
495
+
496
+ if non_media_baseline_values is None:
497
+ # If non_media_baseline_values is not provided, use the minimum
498
+ # value for each non_media treatment channel as the baseline.
499
+ non_media_baseline_values_filled = [
500
+ constants.NON_MEDIA_BASELINE_MIN
501
+ ] * non_media_treatments_population_scaled.shape[-1]
502
+ else:
503
+ non_media_baseline_values_filled = non_media_baseline_values
504
+
505
+ if non_media_treatments_population_scaled.shape[-1] != len(
506
+ non_media_baseline_values_filled
507
+ ):
508
+ raise ValueError(
509
+ "The number of non-media channels"
510
+ f" ({non_media_treatments_population_scaled.shape[-1]}) does not"
511
+ " match the number of baseline values"
512
+ f" ({len(non_media_baseline_values_filled)})."
513
+ )
514
+
515
+ baseline_list = []
516
+ for channel in range(non_media_treatments_population_scaled.shape[-1]):
517
+ baseline_value = non_media_baseline_values_filled[channel]
518
+
519
+ if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
520
+ baseline_for_channel = tf.reduce_min(
521
+ non_media_treatments_population_scaled[..., channel], axis=[0, 1]
522
+ )
523
+ elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
524
+ baseline_for_channel = tf.reduce_max(
525
+ non_media_treatments_population_scaled[..., channel], axis=[0, 1]
526
+ )
527
+ elif isinstance(baseline_value, numbers.Number):
528
+ baseline_for_channel = tf.cast(baseline_value, tf.float32)
529
+ else:
530
+ raise ValueError(
531
+ f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
532
+ " float numbers and strings 'min' and 'max' are supported."
533
+ )
534
+
535
+ baseline_list.append(baseline_for_channel)
536
+
537
+ return tf.stack(baseline_list, axis=-1)
538
+
427
539
  def expand_selected_time_dims(
428
540
  self,
429
- start_date: tc.Date | None = None,
430
- end_date: tc.Date | None = None,
541
+ start_date: tc.Date = None,
542
+ end_date: tc.Date = None,
431
543
  ) -> list[str] | None:
432
544
  """Validates and returns time dimension values based on the selected times.
433
545
 
@@ -650,51 +762,132 @@ class Meridian:
650
762
  def _warn_setting_ignored_priors(self):
651
763
  """Raises a warning if ignored priors are set."""
652
764
  default_distribution = prior_distribution.PriorDistribution()
653
- prior_type = self.model_spec.paid_media_prior_type
654
-
655
- ignored_custom_priors = []
656
- for prior in constants.IGNORED_PRIORS.get(prior_type, []):
657
- self_prior = getattr(self.model_spec.prior, prior)
658
- default_prior = getattr(default_distribution, prior)
659
- if not prior_distribution.distributions_are_equal(
660
- self_prior, default_prior
661
- ):
662
- ignored_custom_priors.append(prior)
663
- if ignored_custom_priors:
664
- ignored_priors_str = ", ".join(ignored_custom_priors)
665
- warnings.warn(
666
- f"Custom prior(s) `{ignored_priors_str}` are ignored when"
667
- " `paid_media_prior_type` is set to"
668
- f' "{prior_type}".'
669
- )
765
+ for ignored_priors_dict, prior_type, prior_type_name in (
766
+ (
767
+ constants.IGNORED_PRIORS_MEDIA,
768
+ self.model_spec.effective_media_prior_type,
769
+ "media_prior_type",
770
+ ),
771
+ (
772
+ constants.IGNORED_PRIORS_RF,
773
+ self.model_spec.effective_rf_prior_type,
774
+ "rf_prior_type",
775
+ ),
776
+ ):
777
+ ignored_custom_priors = []
778
+ for prior in ignored_priors_dict.get(prior_type, []):
779
+ self_prior = getattr(self.model_spec.prior, prior)
780
+ default_prior = getattr(default_distribution, prior)
781
+ if not prior_distribution.distributions_are_equal(
782
+ self_prior, default_prior
783
+ ):
784
+ ignored_custom_priors.append(prior)
785
+ if ignored_custom_priors:
786
+ ignored_priors_str = ", ".join(ignored_custom_priors)
787
+ warnings.warn(
788
+ f"Custom prior(s) `{ignored_priors_str}` are ignored when"
789
+ f' `{prior_type_name}` is set to "{prior_type}".'
790
+ )
670
791
 
671
- def _validate_paid_media_prior_type(self):
672
- """Validates the media prior type."""
673
- default_distribution = prior_distribution.PriorDistribution()
674
- mroi_m_not_set = (
675
- self.n_media_channels > 0
676
- and prior_distribution.distributions_are_equal(
677
- self.model_spec.prior.mroi_m, default_distribution.mroi_m
792
+ def _validate_mroi_priors_non_revenue(self):
793
+ """Validates mroi priors in the non-revenue outcome case."""
794
+ if (
795
+ self.input_data.kpi_type == constants.NON_REVENUE
796
+ and self.input_data.revenue_per_kpi is None
797
+ ):
798
+ default_distribution = prior_distribution.PriorDistribution()
799
+ if (
800
+ self.n_media_channels > 0
801
+ and (
802
+ self.model_spec.effective_media_prior_type
803
+ == constants.TREATMENT_PRIOR_TYPE_MROI
804
+ )
805
+ and prior_distribution.distributions_are_equal(
806
+ self.model_spec.prior.mroi_m, default_distribution.mroi_m
807
+ )
808
+ ):
809
+ raise ValueError(
810
+ f"Custom priors should be set on `{constants.MROI_M}` when"
811
+ ' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
812
+ " kpi data is missing."
678
813
  )
679
- )
680
- mroi_rf_not_set = (
681
- self.n_rf_channels > 0
682
- and prior_distribution.distributions_are_equal(
683
- self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
814
+ if (
815
+ self.n_rf_channels > 0
816
+ and (
817
+ self.model_spec.effective_rf_prior_type
818
+ == constants.TREATMENT_PRIOR_TYPE_MROI
819
+ )
820
+ and prior_distribution.distributions_are_equal(
821
+ self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
822
+ )
823
+ ):
824
+ raise ValueError(
825
+ f"Custom priors should be set on `{constants.MROI_RF}` when"
826
+ ' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
827
+ " data is missing."
684
828
  )
685
- )
829
+
830
+ def _validate_roi_priors_non_revenue(self):
831
+ """Validates roi priors in the non-revenue outcome case."""
686
832
  if (
687
- self.input_data.revenue_per_kpi is None
688
- and self.input_data.kpi_type == constants.NON_REVENUE
689
- and self.model_spec.paid_media_prior_type
690
- == constants.PAID_MEDIA_PRIOR_TYPE_MROI
691
- and (mroi_m_not_set or mroi_rf_not_set)
833
+ self.input_data.kpi_type == constants.NON_REVENUE
834
+ and self.input_data.revenue_per_kpi is None
692
835
  ):
693
- raise ValueError(
694
- f"Custom priors should be set on `{constants.MROI_M}` and"
695
- f" `{constants.MROI_RF}` when KPI is non-revenue and revenue per kpi"
696
- " data is missing."
836
+ default_distribution = prior_distribution.PriorDistribution()
837
+ default_roi_m_used = (
838
+ self.model_spec.effective_media_prior_type
839
+ == constants.TREATMENT_PRIOR_TYPE_ROI
840
+ and prior_distribution.distributions_are_equal(
841
+ self.model_spec.prior.roi_m, default_distribution.roi_m
842
+ )
843
+ )
844
+ default_roi_rf_used = (
845
+ self.model_spec.effective_rf_prior_type
846
+ == constants.TREATMENT_PRIOR_TYPE_ROI
847
+ and prior_distribution.distributions_are_equal(
848
+ self.model_spec.prior.roi_rf, default_distribution.roi_rf
849
+ )
697
850
  )
851
+ # If ROI priors are used with the default prior distribution for all paid
852
+ # channels (media and RF), then use the "total paid media contribution
853
+ # prior" procedure.
854
+ if (
855
+ (default_roi_m_used and default_roi_rf_used)
856
+ or (self.n_media_channels == 0 and default_roi_rf_used)
857
+ or (self.n_rf_channels == 0 and default_roi_m_used)
858
+ ):
859
+ self._set_total_media_contribution_prior = True
860
+ warnings.warn(
861
+ "Consider setting custom ROI priors, as kpi_type was specified as"
862
+ " `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
863
+ " total media contribution prior will be used with"
864
+ f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
865
+ " documentation available at "
866
+ " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
867
+ )
868
+ elif self.n_media_channels > 0 and default_roi_m_used:
869
+ raise ValueError(
870
+ f"Custom priors should be set on `{constants.ROI_M}` when"
871
+ ' `media_prior_type` is "roi", custom priors are assigned on'
872
+ ' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
873
+ " non-revenue and revenue per kpi data is missing."
874
+ )
875
+ elif self.n_rf_channels > 0 and default_roi_rf_used:
876
+ raise ValueError(
877
+ f"Custom priors should be set on `{constants.ROI_RF}` when"
878
+ ' `rf_prior_type` is "roi", custom priors are assigned on'
879
+ ' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
880
+ " non-revenue and revenue per kpi data is missing."
881
+ )
882
+
883
+ def _check_for_negative_effects(self):
884
+ prior = self.model_spec.prior
885
+ if self.n_media_channels > 0:
886
+ _check_for_negative_effect(prior.roi_m, self.media_effects_dist)
887
+ _check_for_negative_effect(prior.mroi_m, self.media_effects_dist)
888
+ if self.n_rf_channels > 0:
889
+ _check_for_negative_effect(prior.roi_rf, self.media_effects_dist)
890
+ _check_for_negative_effect(prior.mroi_rf, self.media_effects_dist)
698
891
 
699
892
  def _validate_geo_invariants(self):
700
893
  """Validates non-national model invariants."""
@@ -708,7 +901,7 @@ class Meridian:
708
901
  )
709
902
  if self.input_data.non_media_treatments is not None:
710
903
  self._check_if_no_geo_variation(
711
- self.non_media_treatments_scaled,
904
+ self.non_media_treatments_normalized,
712
905
  constants.NON_MEDIA_TREATMENTS,
713
906
  self.input_data.non_media_treatments.coords[
714
907
  constants.NON_MEDIA_CHANNEL
@@ -784,6 +977,14 @@ class Meridian:
784
977
  constants.CONTROLS,
785
978
  self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
786
979
  )
980
+ if self.input_data.non_media_treatments is not None:
981
+ self._check_if_no_time_variation(
982
+ self.non_media_treatments_normalized,
983
+ constants.NON_MEDIA_TREATMENTS,
984
+ self.input_data.non_media_treatments.coords[
985
+ constants.NON_MEDIA_CHANNEL
986
+ ].values,
987
+ )
787
988
  if self.input_data.media is not None:
788
989
  self._check_if_no_time_variation(
789
990
  self.media_tensors.media_scaled,
@@ -796,6 +997,22 @@ class Meridian:
796
997
  constants.REACH,
797
998
  self.input_data.reach.coords[constants.RF_CHANNEL].values,
798
999
  )
1000
+ if self.input_data.organic_media is not None:
1001
+ self._check_if_no_time_variation(
1002
+ self.organic_media_tensors.organic_media_scaled,
1003
+ constants.ORGANIC_MEDIA,
1004
+ self.input_data.organic_media.coords[
1005
+ constants.ORGANIC_MEDIA_CHANNEL
1006
+ ].values,
1007
+ )
1008
+ if self.input_data.organic_reach is not None:
1009
+ self._check_if_no_time_variation(
1010
+ self.organic_rf_tensors.organic_reach_scaled,
1011
+ constants.ORGANIC_REACH,
1012
+ self.input_data.organic_reach.coords[
1013
+ constants.ORGANIC_RF_CHANNEL
1014
+ ].values,
1015
+ )
799
1016
 
800
1017
  def _check_if_no_time_variation(
801
1018
  self,
@@ -825,6 +1042,194 @@ class Meridian:
825
1042
  " the listed variables that do not vary across time."
826
1043
  )
827
1044
 
1045
+ def _validate_kpi_transformer(self):
1046
+ """Validates the KPI transformer."""
1047
+ kpi = "kpi" if self.is_national else "population_scaled_kpi"
1048
+ if (
1049
+ self.n_media_channels > 0
1050
+ and self.kpi_transformer.population_scaled_stdev == 0
1051
+ and self.model_spec.effective_media_prior_type
1052
+ in constants.PAID_MEDIA_ROI_PRIOR_TYPES
1053
+ ):
1054
+ raise ValueError(
1055
+ f"`{kpi}` cannot be constant with"
1056
+ " `media_prior_type` ="
1057
+ f' "{self.model_spec.effective_media_prior_type}".'
1058
+ )
1059
+ if (
1060
+ self.n_rf_channels > 0
1061
+ and self.kpi_transformer.population_scaled_stdev == 0
1062
+ and self.model_spec.effective_rf_prior_type
1063
+ in constants.PAID_MEDIA_ROI_PRIOR_TYPES
1064
+ ):
1065
+ raise ValueError(
1066
+ f"`{kpi}` cannot be constant with"
1067
+ f' `rf_prior_type` = "{self.model_spec.effective_rf_prior_type}".'
1068
+ )
1069
+
1070
+ def linear_predictor_counterfactual_difference_media(
1071
+ self,
1072
+ media_transformed: tf.Tensor,
1073
+ alpha_m: tf.Tensor,
1074
+ ec_m: tf.Tensor,
1075
+ slope_m: tf.Tensor,
1076
+ ) -> tf.Tensor:
1077
+ """Calculates linear predictor counterfactual difference for non-RF media.
1078
+
1079
+ For non-RF media variables (paid or organic), this function calculates the
1080
+ linear predictor difference between the treatment variable and its
1081
+ counterfactual. "Linear predictor" refers to the output of the hill/adstock
1082
+ function, which is multiplied by the geo-level coefficient.
1083
+
1084
+ This function does the calculation efficiently by only calculating calling
1085
+ the hill/adstock function if the prior counterfactual is not all zeros.
1086
+
1087
+ Args:
1088
+ media_transformed: The output of the hill/adstock function for actual
1089
+ historical media data.
1090
+ alpha_m: The adstock alpha parameter values.
1091
+ ec_m: The adstock ec parameter values.
1092
+ slope_m: The adstock hill slope parameter values.
1093
+
1094
+ Returns:
1095
+ The linear predictor difference between the treatment variable and its
1096
+ counterfactual.
1097
+ """
1098
+ if self.media_tensors.prior_media_scaled_counterfactual is None:
1099
+ return media_transformed
1100
+ media_transformed_counterfactual = self.adstock_hill_media(
1101
+ self.media_tensors.prior_media_scaled_counterfactual,
1102
+ alpha_m,
1103
+ ec_m,
1104
+ slope_m,
1105
+ )
1106
+ # Absolute values is needed because the difference is negative for mROI
1107
+ # priors and positive for ROI and contribution priors.
1108
+ return tf.abs(media_transformed - media_transformed_counterfactual)
1109
+
1110
+ def linear_predictor_counterfactual_difference_rf(
1111
+ self,
1112
+ rf_transformed: tf.Tensor,
1113
+ alpha_rf: tf.Tensor,
1114
+ ec_rf: tf.Tensor,
1115
+ slope_rf: tf.Tensor,
1116
+ ) -> tf.Tensor:
1117
+ """Calculates linear predictor counterfactual difference for RF media.
1118
+
1119
+ For RF media variables (paid or organic), this function calculates the
1120
+ linear predictor difference between the treatment variable and its
1121
+ counterfactual. "Linear predictor" refers to the output of the hill/adstock
1122
+ function, which is multiplied by the geo-level coefficient.
1123
+
1124
+ This function does the calculation efficiently by only calculating calling
1125
+ the hill/adstock function if the prior counterfactual is not all zeros.
1126
+
1127
+ Args:
1128
+ rf_transformed: The output of the hill/adstock function for actual
1129
+ historical media data.
1130
+ alpha_rf: The adstock alpha parameter values.
1131
+ ec_rf: The adstock ec parameter values.
1132
+ slope_rf: The adstock hill slope parameter values.
1133
+
1134
+ Returns:
1135
+ The linear predictor difference between the treatment variable and its
1136
+ counterfactual.
1137
+ """
1138
+ if self.rf_tensors.prior_reach_scaled_counterfactual is None:
1139
+ return rf_transformed
1140
+ rf_transformed_counterfactual = self.adstock_hill_rf(
1141
+ reach=self.rf_tensors.prior_reach_scaled_counterfactual,
1142
+ frequency=self.rf_tensors.frequency,
1143
+ alpha=alpha_rf,
1144
+ ec=ec_rf,
1145
+ slope=slope_rf,
1146
+ )
1147
+ # Absolute values is needed because the difference is negative for mROI
1148
+ # priors and positive for ROI and contribution priors.
1149
+ return tf.abs(rf_transformed - rf_transformed_counterfactual)
1150
+
1151
+ def calculate_beta_x(
1152
+ self,
1153
+ is_non_media: bool,
1154
+ incremental_outcome_x: tf.Tensor,
1155
+ linear_predictor_counterfactual_difference: tf.Tensor,
1156
+ eta_x: tf.Tensor,
1157
+ beta_gx_dev: tf.Tensor,
1158
+ ) -> tf.Tensor:
1159
+ """Calculates coefficient mean parameter for any treatment variable type.
1160
+
1161
+ The "beta_x" in the function name refers to the coefficient mean parameter
1162
+ of any treatment variable. The "x" can represent "m", "rf", "om", or "orf".
1163
+ This function can also be used to calculate "gamma_n" for any non-media
1164
+ treatments.
1165
+
1166
+ Args:
1167
+ is_non_media: Boolean indicating whether the treatment variable is a
1168
+ non-media treatment. This argument is used to determine whether the
1169
+ coefficient random effects are normal or log-normal. If `True`, then
1170
+ random effects are assumed to be normal. Otherwise, the distribution is
1171
+ inferred from `self.media_effects_dist`.
1172
+ incremental_outcome_x: The incremental outcome of the treatment variable,
1173
+ which depends on the parameter values of a particular prior or posterior
1174
+ draw. The "_x" indicates that this is a tensor with length equal to the
1175
+ dimension of the treatment variable.
1176
+ linear_predictor_counterfactual_difference: The difference between the
1177
+ treatment variable and its counterfactual on the linear predictor scale.
1178
+ "Linear predictor" refers to the quantity that is multiplied by the
1179
+ geo-level coefficient. For media variables, this is the output of the
1180
+ hill/adstock transformation function. For non-media treatments, this is
1181
+ simply the treatment variable after centering/scaling transformations.
1182
+ This tensor has dimensions for geo, time, and channel.
1183
+ eta_x: The random effect standard deviation parameter values. For media
1184
+ variables, the "x" represents "m", "rf", "om", or "orf". For non-media
1185
+ treatments, this argument should be set to `xi_n`, which is analogous to
1186
+ "eta".
1187
+ beta_gx_dev: The latent standard normal parameter values of the geo-level
1188
+ coefficients. For media variables, the "x" represents "m", "rf", "om",
1189
+ or "orf". For non-media treatments, this argument should be set to
1190
+ `gamma_gn_dev`, which is analogous to "beta_gx_dev".
1191
+
1192
+ Returns:
1193
+ The coefficient mean parameter of the treatment variable, which has
1194
+ dimension equal to the number of treatment channels..
1195
+ """
1196
+ if is_non_media:
1197
+ random_effects_normal = True
1198
+ else:
1199
+ random_effects_normal = (
1200
+ self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
1201
+ )
1202
+ if self.revenue_per_kpi is None:
1203
+ revenue_per_kpi = tf.ones([self.n_geos, self.n_times], dtype=tf.float32)
1204
+ else:
1205
+ revenue_per_kpi = self.revenue_per_kpi
1206
+ incremental_outcome_gx_over_beta_gx = tf.einsum(
1207
+ "...gtx,gt,g,->...gx",
1208
+ linear_predictor_counterfactual_difference,
1209
+ revenue_per_kpi,
1210
+ self.population,
1211
+ self.kpi_transformer.population_scaled_stdev,
1212
+ )
1213
+ if random_effects_normal:
1214
+ numerator_term_x = tf.einsum(
1215
+ "...gx,...gx,...x->...x",
1216
+ incremental_outcome_gx_over_beta_gx,
1217
+ beta_gx_dev,
1218
+ eta_x,
1219
+ )
1220
+ denominator_term_x = tf.einsum(
1221
+ "...gx->...x", incremental_outcome_gx_over_beta_gx
1222
+ )
1223
+ return (incremental_outcome_x - numerator_term_x) / denominator_term_x
1224
+ # For log-normal random effects, beta_x and eta_x are not mean & std.
1225
+ # The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
1226
+ denominator_term_x = tf.einsum(
1227
+ "...gx,...gx->...x",
1228
+ incremental_outcome_gx_over_beta_gx,
1229
+ tf.math.exp(beta_gx_dev * eta_x[..., tf.newaxis, :]),
1230
+ )
1231
+ return tf.math.log(incremental_outcome_x) - tf.math.log(denominator_term_x)
1232
+
828
1233
  def adstock_hill_media(
829
1234
  self,
830
1235
  media: tf.Tensor, # pylint: disable=redefined-outer-name
@@ -130,6 +130,17 @@ class WithInputDataSamples:
130
130
  seed=0,
131
131
  )
132
132
  )
133
+ self.input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
134
+ test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
135
+ n_geos=self._N_GEOS,
136
+ n_times=self._N_TIMES,
137
+ n_media_times=self._N_MEDIA_TIMES,
138
+ n_controls=self._N_CONTROLS,
139
+ n_media_channels=self._N_MEDIA_CHANNELS,
140
+ n_rf_channels=self._N_RF_CHANNELS,
141
+ seed=0,
142
+ )
143
+ )
133
144
  self.input_data_with_media_only = (
134
145
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
135
146
  n_geos=self._N_GEOS,