google-meridian 1.1.6__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/model/model.py CHANGED
@@ -22,6 +22,7 @@ import warnings
 
 import arviz as az
 import joblib
+from meridian import backend
 from meridian import constants
 from meridian.data import input_data as data
 from meridian.data import time_coordinates as tc
@@ -34,8 +35,6 @@ from meridian.model import prior_sampler
 from meridian.model import spec
 from meridian.model import transformers
 import numpy as np
-import tensorflow as tf
-import tensorflow_probability as tfp
 
 
 __all__ = [
@@ -70,7 +69,7 @@ def _warn_setting_national_args(**kwargs):
 
 
 def _check_for_negative_effect(
-    dist: tfp.distributions.Distribution, media_effects_dist: str
+    dist: backend.tfd.Distribution, media_effects_dist: str
 ):
   """Checks for negative effect in the model."""
   if (
@@ -202,45 +201,45 @@ class Meridian:
     return media.build_organic_rf_tensors(self.input_data)
 
   @functools.cached_property
-  def kpi(self) -> tf.Tensor:
-    return tf.convert_to_tensor(self.input_data.kpi, dtype=tf.float32)
+  def kpi(self) -> backend.Tensor:
+    return backend.to_tensor(self.input_data.kpi, dtype=backend.float32)
 
   @functools.cached_property
-  def revenue_per_kpi(self) -> tf.Tensor | None:
+  def revenue_per_kpi(self) -> backend.Tensor | None:
     if self.input_data.revenue_per_kpi is None:
       return None
-    return tf.convert_to_tensor(
-        self.input_data.revenue_per_kpi, dtype=tf.float32
+    return backend.to_tensor(
+        self.input_data.revenue_per_kpi, dtype=backend.float32
     )
 
   @functools.cached_property
-  def controls(self) -> tf.Tensor | None:
+  def controls(self) -> backend.Tensor | None:
     if self.input_data.controls is None:
       return None
-    return tf.convert_to_tensor(self.input_data.controls, dtype=tf.float32)
+    return backend.to_tensor(self.input_data.controls, dtype=backend.float32)
 
   @functools.cached_property
-  def non_media_treatments(self) -> tf.Tensor | None:
+  def non_media_treatments(self) -> backend.Tensor | None:
     if self.input_data.non_media_treatments is None:
       return None
-    return tf.convert_to_tensor(
-        self.input_data.non_media_treatments, dtype=tf.float32
+    return backend.to_tensor(
+        self.input_data.non_media_treatments, dtype=backend.float32
     )
 
   @functools.cached_property
-  def population(self) -> tf.Tensor:
-    return tf.convert_to_tensor(self.input_data.population, dtype=tf.float32)
+  def population(self) -> backend.Tensor:
+    return backend.to_tensor(self.input_data.population, dtype=backend.float32)
 
   @functools.cached_property
-  def total_spend(self) -> tf.Tensor:
-    return tf.convert_to_tensor(
-        self.input_data.get_total_spend(), dtype=tf.float32
+  def total_spend(self) -> backend.Tensor:
+    return backend.to_tensor(
+        self.input_data.get_total_spend(), dtype=backend.float32
     )
 
   @functools.cached_property
-  def total_outcome(self) -> tf.Tensor:
-    return tf.convert_to_tensor(
-        self.input_data.get_total_outcome(), dtype=tf.float32
+  def total_outcome(self) -> backend.Tensor:
+    return backend.to_tensor(
+        self.input_data.get_total_outcome(), dtype=backend.float32
     )
 
   @property
@@ -300,6 +299,8 @@ class Meridian:
     return knots.get_knot_info(
         n_times=self.n_times,
         knots=self.model_spec.knots,
+        enable_aks=self.model_spec.enable_aks,
+        data=self.input_data,
         is_national=self.is_national,
     )
 
@@ -312,8 +313,8 @@ class Meridian:
       return None
 
     if self.model_spec.control_population_scaling_id is not None:
-      controls_population_scaling_id = tf.convert_to_tensor(
-          self.model_spec.control_population_scaling_id, dtype=bool
+      controls_population_scaling_id = backend.to_tensor(
+          self.model_spec.control_population_scaling_id, dtype=backend.bool_
      )
    else:
      controls_population_scaling_id = None
@@ -332,8 +333,8 @@ class Meridian:
    if self.non_media_treatments is None:
      return None
    if self.model_spec.non_media_population_scaling_id is not None:
-      non_media_population_scaling_id = tf.convert_to_tensor(
-          self.model_spec.non_media_population_scaling_id, dtype=bool
+      non_media_population_scaling_id = backend.to_tensor(
+          self.model_spec.non_media_population_scaling_id, dtype=backend.bool_
      )
    else:
      non_media_population_scaling_id = None
@@ -349,7 +350,7 @@ class Meridian:
    return transformers.KpiTransformer(self.kpi, self.population)
 
   @functools.cached_property
-  def controls_scaled(self) -> tf.Tensor | None:
+  def controls_scaled(self) -> backend.Tensor | None:
    if self.controls is not None:
      # If `controls` is defined, then `controls_transformer` is also defined.
      return self.controls_transformer.forward(self.controls)  # pytype: disable=attribute-error
@@ -357,7 +358,7 @@ class Meridian:
      return None
 
   @functools.cached_property
-  def non_media_treatments_normalized(self) -> tf.Tensor | None:
+  def non_media_treatments_normalized(self) -> backend.Tensor | None:
    """Normalized non-media treatments.
 
    The non-media treatments values are scaled by population (for channels where
@@ -372,7 +373,7 @@ class Meridian:
      return None
 
   @functools.cached_property
-  def kpi_scaled(self) -> tf.Tensor:
+  def kpi_scaled(self) -> backend.Tensor:
    return self.kpi_transformer.forward(self.kpi)
 
   @functools.cached_property
@@ -416,14 +417,35 @@ class Meridian:
      # Geos are unique, so index is a 1-element array.
      return index[0]
    else:
-      return tf.argmax(self.population)
+      return backend.argmax(self.population)
 
   @functools.cached_property
-  def holdout_id(self) -> tf.Tensor | None:
+  def holdout_id(self) -> backend.Tensor | None:
    if self.model_spec.holdout_id is None:
      return None
-    tensor = tf.convert_to_tensor(self.model_spec.holdout_id, dtype=bool)
-    return tensor[tf.newaxis, ...] if self.is_national else tensor
+    tensor = backend.to_tensor(self.model_spec.holdout_id, dtype=backend.bool_)
+    return tensor[backend.newaxis, ...] if self.is_national else tensor
+
+  @functools.cached_property
+  def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
+    """Returns `AdstockDecaySpec` object with correctly mapped channels."""
+    if isinstance(self.model_spec.adstock_decay_spec, str):
+      return adstock_hill.AdstockDecaySpec.from_consistent_type(
+          self.model_spec.adstock_decay_spec
+      )
+
+    try:
+      return self._create_adstock_decay_functions_from_channel_map(
+          self.model_spec.adstock_decay_spec
+      )
+    except KeyError as e:
+      raise ValueError(
+          "Unrecognized channel names found in `adstock_decay_spec` keys"
+          f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
+          " either contain only channel_names"
+          f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
+          " be one or more of {'media', 'rf', 'organic_media', 'organic_rf'}."
+      ) from e
 
   @functools.cached_property
   def prior_broadcast(self) -> prior_distribution.PriorDistribution:
@@ -469,7 +491,7 @@ class Meridian:
   def compute_non_media_treatments_baseline(
      self,
      non_media_baseline_values: Sequence[str | float] | None = None,
-  ) -> tf.Tensor:
+  ) -> backend.Tensor:
    """Computes the baseline for each non-media treatment channel.
 
    Args:
@@ -491,16 +513,19 @@ class Meridian:
    if non_media_baseline_values is None:
      non_media_baseline_values = self.model_spec.non_media_baseline_values
 
+    no_op_scaling_factor = backend.ones_like(self.population)[
+        :, backend.newaxis, backend.newaxis
+    ]
    if self.model_spec.non_media_population_scaling_id is not None:
-      scaling_factors = tf.where(
+      scaling_factors = backend.where(
          self.model_spec.non_media_population_scaling_id,
-          self.population[:, tf.newaxis, tf.newaxis],
-          tf.ones_like(self.population)[:, tf.newaxis, tf.newaxis],
+          self.population[:, backend.newaxis, backend.newaxis],
+          no_op_scaling_factor,
      )
    else:
-      scaling_factors = tf.ones_like(self.population)[:, tf.newaxis, tf.newaxis]
+      scaling_factors = no_op_scaling_factor
 
-    non_media_treatments_population_scaled = tf.math.divide_no_nan(
+    non_media_treatments_population_scaled = backend.divide_no_nan(
        self.non_media_treatments, scaling_factors
    )
 
@@ -528,15 +553,15 @@ class Meridian:
      baseline_value = non_media_baseline_values_filled[channel]
 
      if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
-        baseline_for_channel = tf.reduce_min(
+        baseline_for_channel = backend.reduce_min(
            non_media_treatments_population_scaled[..., channel], axis=[0, 1]
        )
      elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
-        baseline_for_channel = tf.reduce_max(
+        baseline_for_channel = backend.reduce_max(
            non_media_treatments_population_scaled[..., channel], axis=[0, 1]
        )
      elif isinstance(baseline_value, numbers.Number):
-        baseline_for_channel = tf.cast(baseline_value, tf.float32)
+        baseline_for_channel = backend.cast(baseline_value, backend.float32)
      else:
        raise ValueError(
            f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
@@ -545,7 +570,7 @@ class Meridian:
 
      baseline_list.append(baseline_for_channel)
 
-    return tf.stack(baseline_list, axis=-1)
+    return backend.stack(baseline_list, axis=-1)
 
   def expand_selected_time_dims(
      self,
@@ -766,6 +791,58 @@ class Meridian:
          f" ({self.n_non_media_channels},)`."
      )
 
+  def _create_adstock_decay_functions_from_channel_map(
+      self, channel_function_map: Mapping[str, str]
+  ) -> adstock_hill.AdstockDecaySpec:
+    """Create `AdstockDecaySpec` from mapping from channels to decay functions."""
+
+    for channel in channel_function_map:
+      if channel not in self.input_data.get_all_adstock_hill_channels():
+        raise KeyError(f"Channel {channel} not found in data.")
+
+    if self.input_data.media_channel is not None:
+      media_channel_builder = self.input_data.get_paid_media_channels_argument_builder().with_default_value(
+          constants.GEOMETRIC_DECAY
+      )
+      media_adstock_function = media_channel_builder(**channel_function_map)
+    else:
+      media_adstock_function = constants.GEOMETRIC_DECAY
+
+    if self.input_data.rf_channel is not None:
+      rf_channel_builder = self.input_data.get_paid_rf_channels_argument_builder().with_default_value(
+          constants.GEOMETRIC_DECAY
+      )
+      rf_adstock_function = rf_channel_builder(**channel_function_map)
+    else:
+      rf_adstock_function = constants.GEOMETRIC_DECAY
+
+    if self.input_data.organic_media_channel is not None:
+      organic_media_channel_builder = self.input_data.get_organic_media_channels_argument_builder().with_default_value(
+          constants.GEOMETRIC_DECAY
+      )
+      organic_media_adstock_function = organic_media_channel_builder(
+          **channel_function_map
+      )
+    else:
+      organic_media_adstock_function = constants.GEOMETRIC_DECAY
+
+    if self.input_data.organic_rf_channel is not None:
+      organic_rf_channel_builder = self.input_data.get_organic_rf_channels_argument_builder().with_default_value(
+          constants.GEOMETRIC_DECAY
+      )
+      organic_rf_adstock_function = organic_rf_channel_builder(
+          **channel_function_map
+      )
+    else:
+      organic_rf_adstock_function = constants.GEOMETRIC_DECAY
+
+    return adstock_hill.AdstockDecaySpec(
+        media=media_adstock_function,
+        rf=rf_adstock_function,
+        organic_media=organic_media_adstock_function,
+        organic_rf=organic_rf_adstock_function,
+    )
+
   def _warn_setting_ignored_priors(self):
    """Raises a warning if ignored priors are set."""
    default_distribution = prior_distribution.PriorDistribution()
@@ -946,7 +1023,7 @@ class Meridian:
 
   def _check_if_no_geo_variation(
      self,
-      scaled_data: tf.Tensor,
+      scaled_data: backend.Tensor,
      data_name: str,
      data_dims: Sequence[str],
      epsilon=1e-4,
@@ -954,16 +1031,16 @@ class Meridian:
    """Raise an error if `n_knots == n_time` and data lacks geo variation."""
 
    # Result shape: [n, d], where d is the number of axes of condition.
-    col_idx_full = tf.where(tf.math.reduce_std(scaled_data, axis=0) < epsilon)[
-        :, 1
-    ]
-    col_idx_unique, _, counts = tf.unique_with_counts(col_idx_full)
+    col_idx_full = backend.get_indices_where(
+        backend.reduce_std(scaled_data, axis=0) < epsilon
+    )[:, 1]
+    col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
    # We use the shape of scaled_data (instead of `n_time`) because the data may
    # be padded to account for lagged effects.
    data_n_time = scaled_data.shape[1]
-    mask = tf.equal(counts, data_n_time)
-    col_idx_bad = tf.boolean_mask(col_idx_unique, mask)
-    dims_bad = tf.gather(data_dims, col_idx_bad)
+    mask = backend.equal(counts, data_n_time)
+    col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
+    dims_bad = backend.gather(data_dims, col_idx_bad)
 
    if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
      raise ValueError(
@@ -1024,7 +1101,7 @@ class Meridian:
 
   def _check_if_no_time_variation(
      self,
-      scaled_data: tf.Tensor,
+      scaled_data: backend.Tensor,
      data_name: str,
      data_dims: Sequence[str],
      epsilon=1e-4,
@@ -1032,13 +1109,13 @@ class Meridian:
    """Raise an error if data lacks time variation."""
 
    # Result shape: [n, d], where d is the number of axes of condition.
-    col_idx_full = tf.where(tf.math.reduce_std(scaled_data, axis=1) < epsilon)[
-        :, 1
-    ]
-    col_idx_unique, _, counts = tf.unique_with_counts(col_idx_full)
-    mask = tf.equal(counts, self.n_geos)
-    col_idx_bad = tf.boolean_mask(col_idx_unique, mask)
-    dims_bad = tf.gather(data_dims, col_idx_bad)
+    col_idx_full = backend.get_indices_where(
+        backend.reduce_std(scaled_data, axis=1) < epsilon
+    )[:, 1]
+    col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
+    mask = backend.equal(counts, self.n_geos)
+    col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
+    dims_bad = backend.gather(data_dims, col_idx_bad)
    if col_idx_bad.shape[0]:
      if self.is_national:
        raise ValueError(
@@ -1058,12 +1135,19 @@ class Meridian:
            " time."
        )
 
+  def _kpi_has_variability(self):
+    """Returns True if the KPI has variability across geos and times."""
+    return self.kpi_transformer.population_scaled_stdev != 0
+
   def _validate_kpi_transformer(self):
    """Validates the KPI transformer."""
+    if self._kpi_has_variability():
+      return
+
    kpi = "kpi" if self.is_national else "population_scaled_kpi"
+
    if (
        self.n_media_channels > 0
-        and self.kpi_transformer.population_scaled_stdev == 0
        and self.model_spec.effective_media_prior_type
        in constants.PAID_MEDIA_ROI_PRIOR_TYPES
    ):
@@ -1074,7 +1158,6 @@ class Meridian:
      )
    if (
        self.n_rf_channels > 0
-        and self.kpi_transformer.population_scaled_stdev == 0
        and self.model_spec.effective_rf_prior_type
        in constants.PAID_MEDIA_ROI_PRIOR_TYPES
    ):
@@ -1082,14 +1165,44 @@ class Meridian:
          f"`{kpi}` cannot be constant with"
          f' `rf_prior_type` = "{self.model_spec.effective_rf_prior_type}".'
      )
+    if (
+        self.n_organic_media_channels > 0
+        and self.model_spec.organic_media_prior_type
+        in [constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION]
+    ):
+      raise ValueError(
+          f"`{kpi}` cannot be constant with"
+          " `organic_media_prior_type` ="
+          f' "{self.model_spec.organic_media_prior_type}".'
+      )
+    if (
+        self.n_organic_rf_channels > 0
+        and self.model_spec.organic_rf_prior_type
+        in [constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION]
+    ):
+      raise ValueError(
+          f"`{kpi}` cannot be constant with"
+          " `organic_rf_prior_type` ="
+          f' "{self.model_spec.organic_rf_prior_type}".'
+      )
+    if (
+        self.n_non_media_channels > 0
+        and self.model_spec.non_media_treatments_prior_type
+        in [constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION]
+    ):
+      raise ValueError(
+          f"`{kpi}` cannot be constant with"
+          " `non_media_treatments_prior_type` ="
+          f' "{self.model_spec.non_media_treatments_prior_type}".'
+      )
 
   def linear_predictor_counterfactual_difference_media(
      self,
-      media_transformed: tf.Tensor,
-      alpha_m: tf.Tensor,
-      ec_m: tf.Tensor,
-      slope_m: tf.Tensor,
-  ) -> tf.Tensor:
+      media_transformed: backend.Tensor,
+      alpha_m: backend.Tensor,
+      ec_m: backend.Tensor,
+      slope_m: backend.Tensor,
+  ) -> backend.Tensor:
    """Calculates linear predictor counterfactual difference for non-RF media.
 
    For non-RF media variables (paid or organic), this function calculates the
@@ -1118,18 +1231,21 @@ class Meridian:
        alpha_m,
        ec_m,
        slope_m,
+        decay_functions=self.adstock_decay_spec.media,
    )
    # Absolute values is needed because the difference is negative for mROI
    # priors and positive for ROI and contribution priors.
-    return tf.abs(media_transformed - media_transformed_counterfactual)
+    return backend.absolute(
+        media_transformed - media_transformed_counterfactual
+    )
 
   def linear_predictor_counterfactual_difference_rf(
      self,
-      rf_transformed: tf.Tensor,
-      alpha_rf: tf.Tensor,
-      ec_rf: tf.Tensor,
-      slope_rf: tf.Tensor,
-  ) -> tf.Tensor:
+      rf_transformed: backend.Tensor,
+      alpha_rf: backend.Tensor,
+      ec_rf: backend.Tensor,
+      slope_rf: backend.Tensor,
+  ) -> backend.Tensor:
    """Calculates linear predictor counterfactual difference for RF media.
 
    For RF media variables (paid or organic), this function calculates the
@@ -1159,19 +1275,20 @@ class Meridian:
        alpha=alpha_rf,
        ec=ec_rf,
        slope=slope_rf,
+        decay_functions=self.adstock_decay_spec.rf,
    )
    # Absolute values is needed because the difference is negative for mROI
    # priors and positive for ROI and contribution priors.
-    return tf.abs(rf_transformed - rf_transformed_counterfactual)
+    return backend.absolute(rf_transformed - rf_transformed_counterfactual)
 
   def calculate_beta_x(
      self,
      is_non_media: bool,
-      incremental_outcome_x: tf.Tensor,
-      linear_predictor_counterfactual_difference: tf.Tensor,
-      eta_x: tf.Tensor,
-      beta_gx_dev: tf.Tensor,
-  ) -> tf.Tensor:
+      incremental_outcome_x: backend.Tensor,
+      linear_predictor_counterfactual_difference: backend.Tensor,
+      eta_x: backend.Tensor,
+      beta_gx_dev: backend.Tensor,
+  ) -> backend.Tensor:
    """Calculates coefficient mean parameter for any treatment variable type.
 
    The "beta_x" in the function name refers to the coefficient mean parameter
@@ -1216,10 +1333,12 @@ class Meridian:
        self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
    )
    if self.revenue_per_kpi is None:
-      revenue_per_kpi = tf.ones([self.n_geos, self.n_times], dtype=tf.float32)
+      revenue_per_kpi = backend.ones(
+          [self.n_geos, self.n_times], dtype=backend.float32
+      )
    else:
      revenue_per_kpi = self.revenue_per_kpi
-    incremental_outcome_gx_over_beta_gx = tf.einsum(
+    incremental_outcome_gx_over_beta_gx = backend.einsum(
        "...gtx,gt,g,->...gx",
        linear_predictor_counterfactual_difference,
        revenue_per_kpi,
@@ -1227,34 +1346,35 @@ class Meridian:
        self.kpi_transformer.population_scaled_stdev,
    )
    if random_effects_normal:
-      numerator_term_x = tf.einsum(
+      numerator_term_x = backend.einsum(
          "...gx,...gx,...x->...x",
          incremental_outcome_gx_over_beta_gx,
          beta_gx_dev,
          eta_x,
      )
-      denominator_term_x = tf.einsum(
+      denominator_term_x = backend.einsum(
          "...gx->...x", incremental_outcome_gx_over_beta_gx
      )
      return (incremental_outcome_x - numerator_term_x) / denominator_term_x
    # For log-normal random effects, beta_x and eta_x are not mean & std.
    # The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
-    denominator_term_x = tf.einsum(
+    denominator_term_x = backend.einsum(
        "...gx,...gx->...x",
        incremental_outcome_gx_over_beta_gx,
-        tf.math.exp(beta_gx_dev * eta_x[..., tf.newaxis, :]),
+        backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
    )
-    return tf.math.log(incremental_outcome_x) - tf.math.log(denominator_term_x)
+    return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)
 
   def adstock_hill_media(
      self,
-      media: tf.Tensor,  # pylint: disable=redefined-outer-name
-      alpha: tf.Tensor,
-      ec: tf.Tensor,
-      slope: tf.Tensor,
+      media: backend.Tensor,  # pylint: disable=redefined-outer-name
+      alpha: backend.Tensor,
+      ec: backend.Tensor,
+      slope: backend.Tensor,
+      decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
      n_times_output: int | None = None,
-  ) -> tf.Tensor:
-    """Transforms media using Adstock and Hill functions in the desired order.
+  ) -> backend.Tensor:
+    """Transforms media or using Adstock and Hill functions in the desired order.
 
    Args:
      media: Tensor of dimensions `(n_geos, n_media_times, n_media_channels)`
@@ -1264,6 +1384,8 @@ class Meridian:
      alpha: Uniform distribution for Adstock and Hill calculations.
      ec: Shifted half-normal distribution for Adstock and Hill calculations.
      slope: Deterministic distribution for Adstock and Hill calculations.
+      decay_functions: String or sequence of strings denoting the adstock decay
+        function(s) for each channel. Default: 'geometric'.
      n_times_output: Number of time periods to output. This argument is
        optional when the number of time periods in `media` equals
        `self.n_media_times`, in which case `n_times_output` defaults to
@@ -1284,6 +1406,7 @@ class Meridian:
        alpha=alpha,
        max_lag=self.model_spec.max_lag,
        n_times_output=n_times_output,
+        decay_functions=decay_functions,
    )
    hill_transformer = adstock_hill.HillTransformer(
        ec=ec,
@@ -1302,13 +1425,14 @@ class Meridian:
 
   def adstock_hill_rf(
      self,
-      reach: tf.Tensor,
-      frequency: tf.Tensor,
-      alpha: tf.Tensor,
-      ec: tf.Tensor,
-      slope: tf.Tensor,
+      reach: backend.Tensor,
+      frequency: backend.Tensor,
+      alpha: backend.Tensor,
+      ec: backend.Tensor,
+      slope: backend.Tensor,
+      decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
      n_times_output: int | None = None,
-  ) -> tf.Tensor:
+  ) -> backend.Tensor:
    """Transforms reach and frequency (RF) using Hill and Adstock functions.
 
    Args:
@@ -1319,6 +1443,8 @@ class Meridian:
      alpha: Uniform distribution for Adstock and Hill calculations.
      ec: Shifted half-normal distribution for Adstock and Hill calculations.
      slope: Deterministic distribution for Adstock and Hill calculations.
+      decay_functions: String or sequence of strings denoting the adstock decay
+        function(s) for each channel. Default: 'geometric'.
      n_times_output: Number of time periods to output. This argument is
        optional when the number of time periods in `reach` equals
        `self.n_media_times`, in which case `n_times_output` defaults to
@@ -1343,6 +1469,7 @@ class Meridian:
        alpha=alpha,
        max_lag=self.model_spec.max_lag,
        n_times_output=n_times_output,
+        decay_functions=decay_functions,
    )
    adj_frequency = hill_transformer.forward(frequency)
    rf_out = adstock_transformer.forward(reach * adj_frequency)
@@ -1448,7 +1575,7 @@ class Meridian:
      n_adapt: int,
      n_burnin: int,
      n_keep: int,
-      current_state: Mapping[str, tf.Tensor] | None = None,
+      current_state: Mapping[str, backend.Tensor] | None = None,
      init_step_size: int | None = None,
      dual_averaging_kwargs: Mapping[str, int] | None = None,
      max_tree_depth: int = 10,
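
Most of the changes above are a mechanical migration from direct `tensorflow`/`tensorflow_probability` calls to the new `meridian.backend` shim imported at the top of the file. A minimal sketch of the call-site pattern, assuming `backend` re-exports the names that appear in this diff (`to_tensor`, `float32`, `einsum`, `tfd`); the values are illustrative:

from meridian import backend

# 1.1.6 style: tf.convert_to_tensor([1.0, 2.0], dtype=tf.float32)
x = backend.to_tensor([1.0, 2.0], dtype=backend.float32)

# 1.1.6 style: tf.einsum("i->", x)
total = backend.einsum("i->", x)

# 1.1.6 style: tfp.distributions.Normal(0.0, 1.0)
dist = backend.tfd.Normal(0.0, 1.0)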
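
The other functional addition is per-channel adstock decay: `Meridian.adstock_decay_spec` reads `ModelSpec.adstock_decay_spec`, which may be a single decay name or a mapping whose keys are channel names (or, per the error message above, one or more of 'media', 'rf', 'organic_media', 'organic_rf'), with `constants.GEOMETRIC_DECAY` as the fallback. A hypothetical configuration sketch; the channel names are placeholders, and it assumes `ModelSpec` exposes `adstock_decay_spec` as a constructor argument:

from meridian import constants
from meridian.model import spec

# String form: one decay function applied to every channel.
model_spec = spec.ModelSpec(adstock_decay_spec=constants.GEOMETRIC_DECAY)

# Mapping form: keys are channel names from the input data ("tv" and
# "search" are placeholders); unmapped channels fall back to
# constants.GEOMETRIC_DECAY via the argument builders shown in the diff.
model_spec = spec.ModelSpec(
    adstock_decay_spec={
        "tv": constants.GEOMETRIC_DECAY,
        "search": constants.GEOMETRIC_DECAY,
    }
)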