google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
meridian/model/model.py
CHANGED
|
@@ -22,6 +22,7 @@ import warnings
|
|
|
22
22
|
|
|
23
23
|
import arviz as az
|
|
24
24
|
import joblib
|
|
25
|
+
from meridian import backend
|
|
25
26
|
from meridian import constants
|
|
26
27
|
from meridian.data import input_data as data
|
|
27
28
|
from meridian.data import time_coordinates as tc
|
|
@@ -34,8 +35,6 @@ from meridian.model import prior_sampler
|
|
|
34
35
|
from meridian.model import spec
|
|
35
36
|
from meridian.model import transformers
|
|
36
37
|
import numpy as np
|
|
37
|
-
import tensorflow as tf
|
|
38
|
-
import tensorflow_probability as tfp
|
|
39
38
|
|
|
40
39
|
|
|
41
40
|
__all__ = [
|
|
@@ -70,7 +69,7 @@ def _warn_setting_national_args(**kwargs):
|
|
|
70
69
|
|
|
71
70
|
|
|
72
71
|
def _check_for_negative_effect(
|
|
73
|
-
dist:
|
|
72
|
+
dist: backend.tfd.Distribution, media_effects_dist: str
|
|
74
73
|
):
|
|
75
74
|
"""Checks for negative effect in the model."""
|
|
76
75
|
if (
|
|
@@ -202,45 +201,45 @@ class Meridian:
|
|
|
202
201
|
return media.build_organic_rf_tensors(self.input_data)
|
|
203
202
|
|
|
204
203
|
@functools.cached_property
|
|
205
|
-
def kpi(self) ->
|
|
206
|
-
return
|
|
204
|
+
def kpi(self) -> backend.Tensor:
|
|
205
|
+
return backend.to_tensor(self.input_data.kpi, dtype=backend.float32)
|
|
207
206
|
|
|
208
207
|
@functools.cached_property
|
|
209
|
-
def revenue_per_kpi(self) ->
|
|
208
|
+
def revenue_per_kpi(self) -> backend.Tensor | None:
|
|
210
209
|
if self.input_data.revenue_per_kpi is None:
|
|
211
210
|
return None
|
|
212
|
-
return
|
|
213
|
-
self.input_data.revenue_per_kpi, dtype=
|
|
211
|
+
return backend.to_tensor(
|
|
212
|
+
self.input_data.revenue_per_kpi, dtype=backend.float32
|
|
214
213
|
)
|
|
215
214
|
|
|
216
215
|
@functools.cached_property
|
|
217
|
-
def controls(self) ->
|
|
216
|
+
def controls(self) -> backend.Tensor | None:
|
|
218
217
|
if self.input_data.controls is None:
|
|
219
218
|
return None
|
|
220
|
-
return
|
|
219
|
+
return backend.to_tensor(self.input_data.controls, dtype=backend.float32)
|
|
221
220
|
|
|
222
221
|
@functools.cached_property
|
|
223
|
-
def non_media_treatments(self) ->
|
|
222
|
+
def non_media_treatments(self) -> backend.Tensor | None:
|
|
224
223
|
if self.input_data.non_media_treatments is None:
|
|
225
224
|
return None
|
|
226
|
-
return
|
|
227
|
-
self.input_data.non_media_treatments, dtype=
|
|
225
|
+
return backend.to_tensor(
|
|
226
|
+
self.input_data.non_media_treatments, dtype=backend.float32
|
|
228
227
|
)
|
|
229
228
|
|
|
230
229
|
@functools.cached_property
|
|
231
|
-
def population(self) ->
|
|
232
|
-
return
|
|
230
|
+
def population(self) -> backend.Tensor:
|
|
231
|
+
return backend.to_tensor(self.input_data.population, dtype=backend.float32)
|
|
233
232
|
|
|
234
233
|
@functools.cached_property
|
|
235
|
-
def total_spend(self) ->
|
|
236
|
-
return
|
|
237
|
-
self.input_data.get_total_spend(), dtype=
|
|
234
|
+
def total_spend(self) -> backend.Tensor:
|
|
235
|
+
return backend.to_tensor(
|
|
236
|
+
self.input_data.get_total_spend(), dtype=backend.float32
|
|
238
237
|
)
|
|
239
238
|
|
|
240
239
|
@functools.cached_property
|
|
241
|
-
def total_outcome(self) ->
|
|
242
|
-
return
|
|
243
|
-
self.input_data.get_total_outcome(), dtype=
|
|
240
|
+
def total_outcome(self) -> backend.Tensor:
|
|
241
|
+
return backend.to_tensor(
|
|
242
|
+
self.input_data.get_total_outcome(), dtype=backend.float32
|
|
244
243
|
)
|
|
245
244
|
|
|
246
245
|
@property
|
|
@@ -300,6 +299,8 @@ class Meridian:
|
|
|
300
299
|
return knots.get_knot_info(
|
|
301
300
|
n_times=self.n_times,
|
|
302
301
|
knots=self.model_spec.knots,
|
|
302
|
+
enable_aks=self.model_spec.enable_aks,
|
|
303
|
+
data=self.input_data,
|
|
303
304
|
is_national=self.is_national,
|
|
304
305
|
)
|
|
305
306
|
|
|
@@ -312,8 +313,8 @@ class Meridian:
|
|
|
312
313
|
return None
|
|
313
314
|
|
|
314
315
|
if self.model_spec.control_population_scaling_id is not None:
|
|
315
|
-
controls_population_scaling_id =
|
|
316
|
-
self.model_spec.control_population_scaling_id, dtype=
|
|
316
|
+
controls_population_scaling_id = backend.to_tensor(
|
|
317
|
+
self.model_spec.control_population_scaling_id, dtype=backend.bool_
|
|
317
318
|
)
|
|
318
319
|
else:
|
|
319
320
|
controls_population_scaling_id = None
|
|
@@ -332,8 +333,8 @@ class Meridian:
|
|
|
332
333
|
if self.non_media_treatments is None:
|
|
333
334
|
return None
|
|
334
335
|
if self.model_spec.non_media_population_scaling_id is not None:
|
|
335
|
-
non_media_population_scaling_id =
|
|
336
|
-
self.model_spec.non_media_population_scaling_id, dtype=
|
|
336
|
+
non_media_population_scaling_id = backend.to_tensor(
|
|
337
|
+
self.model_spec.non_media_population_scaling_id, dtype=backend.bool_
|
|
337
338
|
)
|
|
338
339
|
else:
|
|
339
340
|
non_media_population_scaling_id = None
|
|
@@ -349,7 +350,7 @@ class Meridian:
|
|
|
349
350
|
return transformers.KpiTransformer(self.kpi, self.population)
|
|
350
351
|
|
|
351
352
|
@functools.cached_property
|
|
352
|
-
def controls_scaled(self) ->
|
|
353
|
+
def controls_scaled(self) -> backend.Tensor | None:
|
|
353
354
|
if self.controls is not None:
|
|
354
355
|
# If `controls` is defined, then `controls_transformer` is also defined.
|
|
355
356
|
return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
|
|
@@ -357,7 +358,7 @@ class Meridian:
|
|
|
357
358
|
return None
|
|
358
359
|
|
|
359
360
|
@functools.cached_property
|
|
360
|
-
def non_media_treatments_normalized(self) ->
|
|
361
|
+
def non_media_treatments_normalized(self) -> backend.Tensor | None:
|
|
361
362
|
"""Normalized non-media treatments.
|
|
362
363
|
|
|
363
364
|
The non-media treatments values are scaled by population (for channels where
|
|
@@ -372,7 +373,7 @@ class Meridian:
|
|
|
372
373
|
return None
|
|
373
374
|
|
|
374
375
|
@functools.cached_property
|
|
375
|
-
def kpi_scaled(self) ->
|
|
376
|
+
def kpi_scaled(self) -> backend.Tensor:
|
|
376
377
|
return self.kpi_transformer.forward(self.kpi)
|
|
377
378
|
|
|
378
379
|
@functools.cached_property
|
|
@@ -416,14 +417,35 @@ class Meridian:
|
|
|
416
417
|
# Geos are unique, so index is a 1-element array.
|
|
417
418
|
return index[0]
|
|
418
419
|
else:
|
|
419
|
-
return
|
|
420
|
+
return backend.argmax(self.population)
|
|
420
421
|
|
|
421
422
|
@functools.cached_property
|
|
422
|
-
def holdout_id(self) ->
|
|
423
|
+
def holdout_id(self) -> backend.Tensor | None:
|
|
423
424
|
if self.model_spec.holdout_id is None:
|
|
424
425
|
return None
|
|
425
|
-
tensor =
|
|
426
|
-
return tensor[
|
|
426
|
+
tensor = backend.to_tensor(self.model_spec.holdout_id, dtype=backend.bool_)
|
|
427
|
+
return tensor[backend.newaxis, ...] if self.is_national else tensor
|
|
428
|
+
|
|
429
|
+
@functools.cached_property
|
|
430
|
+
def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
|
|
431
|
+
"""Returns `AdstockDecaySpec` object with correctly mapped channels."""
|
|
432
|
+
if isinstance(self.model_spec.adstock_decay_spec, str):
|
|
433
|
+
return adstock_hill.AdstockDecaySpec.from_consistent_type(
|
|
434
|
+
self.model_spec.adstock_decay_spec
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
try:
|
|
438
|
+
return self._create_adstock_decay_functions_from_channel_map(
|
|
439
|
+
self.model_spec.adstock_decay_spec
|
|
440
|
+
)
|
|
441
|
+
except KeyError as e:
|
|
442
|
+
raise ValueError(
|
|
443
|
+
"Unrecognized channel names found in `adstock_decay_spec` keys"
|
|
444
|
+
f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
|
|
445
|
+
" either contain only channel_names"
|
|
446
|
+
f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
|
|
447
|
+
" be one or more of {'media', 'rf', 'organic_media', 'organic_rf'}."
|
|
448
|
+
) from e
|
|
427
449
|
|
|
428
450
|
@functools.cached_property
|
|
429
451
|
def prior_broadcast(self) -> prior_distribution.PriorDistribution:
|
|
@@ -469,7 +491,7 @@ class Meridian:
|
|
|
469
491
|
def compute_non_media_treatments_baseline(
|
|
470
492
|
self,
|
|
471
493
|
non_media_baseline_values: Sequence[str | float] | None = None,
|
|
472
|
-
) ->
|
|
494
|
+
) -> backend.Tensor:
|
|
473
495
|
"""Computes the baseline for each non-media treatment channel.
|
|
474
496
|
|
|
475
497
|
Args:
|
|
@@ -491,16 +513,19 @@ class Meridian:
|
|
|
491
513
|
if non_media_baseline_values is None:
|
|
492
514
|
non_media_baseline_values = self.model_spec.non_media_baseline_values
|
|
493
515
|
|
|
516
|
+
no_op_scaling_factor = backend.ones_like(self.population)[
|
|
517
|
+
:, backend.newaxis, backend.newaxis
|
|
518
|
+
]
|
|
494
519
|
if self.model_spec.non_media_population_scaling_id is not None:
|
|
495
|
-
scaling_factors =
|
|
520
|
+
scaling_factors = backend.where(
|
|
496
521
|
self.model_spec.non_media_population_scaling_id,
|
|
497
|
-
self.population[:,
|
|
498
|
-
|
|
522
|
+
self.population[:, backend.newaxis, backend.newaxis],
|
|
523
|
+
no_op_scaling_factor,
|
|
499
524
|
)
|
|
500
525
|
else:
|
|
501
|
-
scaling_factors =
|
|
526
|
+
scaling_factors = no_op_scaling_factor
|
|
502
527
|
|
|
503
|
-
non_media_treatments_population_scaled =
|
|
528
|
+
non_media_treatments_population_scaled = backend.divide_no_nan(
|
|
504
529
|
self.non_media_treatments, scaling_factors
|
|
505
530
|
)
|
|
506
531
|
|
|
@@ -528,15 +553,15 @@ class Meridian:
|
|
|
528
553
|
baseline_value = non_media_baseline_values_filled[channel]
|
|
529
554
|
|
|
530
555
|
if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
|
|
531
|
-
baseline_for_channel =
|
|
556
|
+
baseline_for_channel = backend.reduce_min(
|
|
532
557
|
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
533
558
|
)
|
|
534
559
|
elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
|
|
535
|
-
baseline_for_channel =
|
|
560
|
+
baseline_for_channel = backend.reduce_max(
|
|
536
561
|
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
537
562
|
)
|
|
538
563
|
elif isinstance(baseline_value, numbers.Number):
|
|
539
|
-
baseline_for_channel =
|
|
564
|
+
baseline_for_channel = backend.cast(baseline_value, backend.float32)
|
|
540
565
|
else:
|
|
541
566
|
raise ValueError(
|
|
542
567
|
f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
|
|
@@ -545,7 +570,7 @@ class Meridian:
|
|
|
545
570
|
|
|
546
571
|
baseline_list.append(baseline_for_channel)
|
|
547
572
|
|
|
548
|
-
return
|
|
573
|
+
return backend.stack(baseline_list, axis=-1)
|
|
549
574
|
|
|
550
575
|
def expand_selected_time_dims(
|
|
551
576
|
self,
|
|
@@ -766,6 +791,58 @@ class Meridian:
|
|
|
766
791
|
f" ({self.n_non_media_channels},)`."
|
|
767
792
|
)
|
|
768
793
|
|
|
794
|
+
def _create_adstock_decay_functions_from_channel_map(
|
|
795
|
+
self, channel_function_map: Mapping[str, str]
|
|
796
|
+
) -> adstock_hill.AdstockDecaySpec:
|
|
797
|
+
"""Create `AdstockDecaySpec` from mapping from channels to decay functions."""
|
|
798
|
+
|
|
799
|
+
for channel in channel_function_map:
|
|
800
|
+
if channel not in self.input_data.get_all_adstock_hill_channels():
|
|
801
|
+
raise KeyError(f"Channel {channel} not found in data.")
|
|
802
|
+
|
|
803
|
+
if self.input_data.media_channel is not None:
|
|
804
|
+
media_channel_builder = self.input_data.get_paid_media_channels_argument_builder().with_default_value(
|
|
805
|
+
constants.GEOMETRIC_DECAY
|
|
806
|
+
)
|
|
807
|
+
media_adstock_function = media_channel_builder(**channel_function_map)
|
|
808
|
+
else:
|
|
809
|
+
media_adstock_function = constants.GEOMETRIC_DECAY
|
|
810
|
+
|
|
811
|
+
if self.input_data.rf_channel is not None:
|
|
812
|
+
rf_channel_builder = self.input_data.get_paid_rf_channels_argument_builder().with_default_value(
|
|
813
|
+
constants.GEOMETRIC_DECAY
|
|
814
|
+
)
|
|
815
|
+
rf_adstock_function = rf_channel_builder(**channel_function_map)
|
|
816
|
+
else:
|
|
817
|
+
rf_adstock_function = constants.GEOMETRIC_DECAY
|
|
818
|
+
|
|
819
|
+
if self.input_data.organic_media_channel is not None:
|
|
820
|
+
organic_media_channel_builder = self.input_data.get_organic_media_channels_argument_builder().with_default_value(
|
|
821
|
+
constants.GEOMETRIC_DECAY
|
|
822
|
+
)
|
|
823
|
+
organic_media_adstock_function = organic_media_channel_builder(
|
|
824
|
+
**channel_function_map
|
|
825
|
+
)
|
|
826
|
+
else:
|
|
827
|
+
organic_media_adstock_function = constants.GEOMETRIC_DECAY
|
|
828
|
+
|
|
829
|
+
if self.input_data.organic_rf_channel is not None:
|
|
830
|
+
organic_rf_channel_builder = self.input_data.get_organic_rf_channels_argument_builder().with_default_value(
|
|
831
|
+
constants.GEOMETRIC_DECAY
|
|
832
|
+
)
|
|
833
|
+
organic_rf_adstock_function = organic_rf_channel_builder(
|
|
834
|
+
**channel_function_map
|
|
835
|
+
)
|
|
836
|
+
else:
|
|
837
|
+
organic_rf_adstock_function = constants.GEOMETRIC_DECAY
|
|
838
|
+
|
|
839
|
+
return adstock_hill.AdstockDecaySpec(
|
|
840
|
+
media=media_adstock_function,
|
|
841
|
+
rf=rf_adstock_function,
|
|
842
|
+
organic_media=organic_media_adstock_function,
|
|
843
|
+
organic_rf=organic_rf_adstock_function,
|
|
844
|
+
)
|
|
845
|
+
|
|
769
846
|
def _warn_setting_ignored_priors(self):
|
|
770
847
|
"""Raises a warning if ignored priors are set."""
|
|
771
848
|
default_distribution = prior_distribution.PriorDistribution()
|
|
@@ -946,7 +1023,7 @@ class Meridian:
|
|
|
946
1023
|
|
|
947
1024
|
def _check_if_no_geo_variation(
|
|
948
1025
|
self,
|
|
949
|
-
scaled_data:
|
|
1026
|
+
scaled_data: backend.Tensor,
|
|
950
1027
|
data_name: str,
|
|
951
1028
|
data_dims: Sequence[str],
|
|
952
1029
|
epsilon=1e-4,
|
|
@@ -954,16 +1031,16 @@ class Meridian:
|
|
|
954
1031
|
"""Raise an error if `n_knots == n_time` and data lacks geo variation."""
|
|
955
1032
|
|
|
956
1033
|
# Result shape: [n, d], where d is the number of axes of condition.
|
|
957
|
-
col_idx_full =
|
|
958
|
-
|
|
959
|
-
]
|
|
960
|
-
col_idx_unique, _, counts =
|
|
1034
|
+
col_idx_full = backend.get_indices_where(
|
|
1035
|
+
backend.reduce_std(scaled_data, axis=0) < epsilon
|
|
1036
|
+
)[:, 1]
|
|
1037
|
+
col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
|
|
961
1038
|
# We use the shape of scaled_data (instead of `n_time`) because the data may
|
|
962
1039
|
# be padded to account for lagged effects.
|
|
963
1040
|
data_n_time = scaled_data.shape[1]
|
|
964
|
-
mask =
|
|
965
|
-
col_idx_bad =
|
|
966
|
-
dims_bad =
|
|
1041
|
+
mask = backend.equal(counts, data_n_time)
|
|
1042
|
+
col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
|
|
1043
|
+
dims_bad = backend.gather(data_dims, col_idx_bad)
|
|
967
1044
|
|
|
968
1045
|
if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
|
|
969
1046
|
raise ValueError(
|
|
@@ -1024,7 +1101,7 @@ class Meridian:
|
|
|
1024
1101
|
|
|
1025
1102
|
def _check_if_no_time_variation(
|
|
1026
1103
|
self,
|
|
1027
|
-
scaled_data:
|
|
1104
|
+
scaled_data: backend.Tensor,
|
|
1028
1105
|
data_name: str,
|
|
1029
1106
|
data_dims: Sequence[str],
|
|
1030
1107
|
epsilon=1e-4,
|
|
@@ -1032,13 +1109,13 @@ class Meridian:
|
|
|
1032
1109
|
"""Raise an error if data lacks time variation."""
|
|
1033
1110
|
|
|
1034
1111
|
# Result shape: [n, d], where d is the number of axes of condition.
|
|
1035
|
-
col_idx_full =
|
|
1036
|
-
|
|
1037
|
-
]
|
|
1038
|
-
col_idx_unique, _, counts =
|
|
1039
|
-
mask =
|
|
1040
|
-
col_idx_bad =
|
|
1041
|
-
dims_bad =
|
|
1112
|
+
col_idx_full = backend.get_indices_where(
|
|
1113
|
+
backend.reduce_std(scaled_data, axis=1) < epsilon
|
|
1114
|
+
)[:, 1]
|
|
1115
|
+
col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
|
|
1116
|
+
mask = backend.equal(counts, self.n_geos)
|
|
1117
|
+
col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
|
|
1118
|
+
dims_bad = backend.gather(data_dims, col_idx_bad)
|
|
1042
1119
|
if col_idx_bad.shape[0]:
|
|
1043
1120
|
if self.is_national:
|
|
1044
1121
|
raise ValueError(
|
|
@@ -1058,12 +1135,19 @@ class Meridian:
|
|
|
1058
1135
|
" time."
|
|
1059
1136
|
)
|
|
1060
1137
|
|
|
1138
|
+
def _kpi_has_variability(self):
|
|
1139
|
+
"""Returns True if the KPI has variability across geos and times."""
|
|
1140
|
+
return self.kpi_transformer.population_scaled_stdev != 0
|
|
1141
|
+
|
|
1061
1142
|
def _validate_kpi_transformer(self):
|
|
1062
1143
|
"""Validates the KPI transformer."""
|
|
1144
|
+
if self._kpi_has_variability():
|
|
1145
|
+
return
|
|
1146
|
+
|
|
1063
1147
|
kpi = "kpi" if self.is_national else "population_scaled_kpi"
|
|
1148
|
+
|
|
1064
1149
|
if (
|
|
1065
1150
|
self.n_media_channels > 0
|
|
1066
|
-
and self.kpi_transformer.population_scaled_stdev == 0
|
|
1067
1151
|
and self.model_spec.effective_media_prior_type
|
|
1068
1152
|
in constants.PAID_MEDIA_ROI_PRIOR_TYPES
|
|
1069
1153
|
):
|
|
@@ -1074,7 +1158,6 @@ class Meridian:
|
|
|
1074
1158
|
)
|
|
1075
1159
|
if (
|
|
1076
1160
|
self.n_rf_channels > 0
|
|
1077
|
-
and self.kpi_transformer.population_scaled_stdev == 0
|
|
1078
1161
|
and self.model_spec.effective_rf_prior_type
|
|
1079
1162
|
in constants.PAID_MEDIA_ROI_PRIOR_TYPES
|
|
1080
1163
|
):
|
|
@@ -1082,14 +1165,44 @@ class Meridian:
|
|
|
1082
1165
|
f"`{kpi}` cannot be constant with"
|
|
1083
1166
|
f' `rf_prior_type` = "{self.model_spec.effective_rf_prior_type}".'
|
|
1084
1167
|
)
|
|
1168
|
+
if (
|
|
1169
|
+
self.n_organic_media_channels > 0
|
|
1170
|
+
and self.model_spec.organic_media_prior_type
|
|
1171
|
+
in [constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION]
|
|
1172
|
+
):
|
|
1173
|
+
raise ValueError(
|
|
1174
|
+
f"`{kpi}` cannot be constant with"
|
|
1175
|
+
" `organic_media_prior_type` ="
|
|
1176
|
+
f' "{self.model_spec.organic_media_prior_type}".'
|
|
1177
|
+
)
|
|
1178
|
+
if (
|
|
1179
|
+
self.n_organic_rf_channels > 0
|
|
1180
|
+
and self.model_spec.organic_rf_prior_type
|
|
1181
|
+
in [constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION]
|
|
1182
|
+
):
|
|
1183
|
+
raise ValueError(
|
|
1184
|
+
f"`{kpi}` cannot be constant with"
|
|
1185
|
+
" `organic_rf_prior_type` ="
|
|
1186
|
+
f' "{self.model_spec.organic_rf_prior_type}".'
|
|
1187
|
+
)
|
|
1188
|
+
if (
|
|
1189
|
+
self.n_non_media_channels > 0
|
|
1190
|
+
and self.model_spec.non_media_treatments_prior_type
|
|
1191
|
+
in [constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION]
|
|
1192
|
+
):
|
|
1193
|
+
raise ValueError(
|
|
1194
|
+
f"`{kpi}` cannot be constant with"
|
|
1195
|
+
" `non_media_treatments_prior_type` ="
|
|
1196
|
+
f' "{self.model_spec.non_media_treatments_prior_type}".'
|
|
1197
|
+
)
|
|
1085
1198
|
|
|
1086
1199
|
def linear_predictor_counterfactual_difference_media(
|
|
1087
1200
|
self,
|
|
1088
|
-
media_transformed:
|
|
1089
|
-
alpha_m:
|
|
1090
|
-
ec_m:
|
|
1091
|
-
slope_m:
|
|
1092
|
-
) ->
|
|
1201
|
+
media_transformed: backend.Tensor,
|
|
1202
|
+
alpha_m: backend.Tensor,
|
|
1203
|
+
ec_m: backend.Tensor,
|
|
1204
|
+
slope_m: backend.Tensor,
|
|
1205
|
+
) -> backend.Tensor:
|
|
1093
1206
|
"""Calculates linear predictor counterfactual difference for non-RF media.
|
|
1094
1207
|
|
|
1095
1208
|
For non-RF media variables (paid or organic), this function calculates the
|
|
@@ -1118,18 +1231,21 @@ class Meridian:
|
|
|
1118
1231
|
alpha_m,
|
|
1119
1232
|
ec_m,
|
|
1120
1233
|
slope_m,
|
|
1234
|
+
decay_functions=self.adstock_decay_spec.media,
|
|
1121
1235
|
)
|
|
1122
1236
|
# Absolute values is needed because the difference is negative for mROI
|
|
1123
1237
|
# priors and positive for ROI and contribution priors.
|
|
1124
|
-
return
|
|
1238
|
+
return backend.absolute(
|
|
1239
|
+
media_transformed - media_transformed_counterfactual
|
|
1240
|
+
)
|
|
1125
1241
|
|
|
1126
1242
|
def linear_predictor_counterfactual_difference_rf(
|
|
1127
1243
|
self,
|
|
1128
|
-
rf_transformed:
|
|
1129
|
-
alpha_rf:
|
|
1130
|
-
ec_rf:
|
|
1131
|
-
slope_rf:
|
|
1132
|
-
) ->
|
|
1244
|
+
rf_transformed: backend.Tensor,
|
|
1245
|
+
alpha_rf: backend.Tensor,
|
|
1246
|
+
ec_rf: backend.Tensor,
|
|
1247
|
+
slope_rf: backend.Tensor,
|
|
1248
|
+
) -> backend.Tensor:
|
|
1133
1249
|
"""Calculates linear predictor counterfactual difference for RF media.
|
|
1134
1250
|
|
|
1135
1251
|
For RF media variables (paid or organic), this function calculates the
|
|
@@ -1159,19 +1275,20 @@ class Meridian:
|
|
|
1159
1275
|
alpha=alpha_rf,
|
|
1160
1276
|
ec=ec_rf,
|
|
1161
1277
|
slope=slope_rf,
|
|
1278
|
+
decay_functions=self.adstock_decay_spec.rf,
|
|
1162
1279
|
)
|
|
1163
1280
|
# Absolute values is needed because the difference is negative for mROI
|
|
1164
1281
|
# priors and positive for ROI and contribution priors.
|
|
1165
|
-
return
|
|
1282
|
+
return backend.absolute(rf_transformed - rf_transformed_counterfactual)
|
|
1166
1283
|
|
|
1167
1284
|
def calculate_beta_x(
|
|
1168
1285
|
self,
|
|
1169
1286
|
is_non_media: bool,
|
|
1170
|
-
incremental_outcome_x:
|
|
1171
|
-
linear_predictor_counterfactual_difference:
|
|
1172
|
-
eta_x:
|
|
1173
|
-
beta_gx_dev:
|
|
1174
|
-
) ->
|
|
1287
|
+
incremental_outcome_x: backend.Tensor,
|
|
1288
|
+
linear_predictor_counterfactual_difference: backend.Tensor,
|
|
1289
|
+
eta_x: backend.Tensor,
|
|
1290
|
+
beta_gx_dev: backend.Tensor,
|
|
1291
|
+
) -> backend.Tensor:
|
|
1175
1292
|
"""Calculates coefficient mean parameter for any treatment variable type.
|
|
1176
1293
|
|
|
1177
1294
|
The "beta_x" in the function name refers to the coefficient mean parameter
|
|
@@ -1216,10 +1333,12 @@ class Meridian:
|
|
|
1216
1333
|
self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
1217
1334
|
)
|
|
1218
1335
|
if self.revenue_per_kpi is None:
|
|
1219
|
-
revenue_per_kpi =
|
|
1336
|
+
revenue_per_kpi = backend.ones(
|
|
1337
|
+
[self.n_geos, self.n_times], dtype=backend.float32
|
|
1338
|
+
)
|
|
1220
1339
|
else:
|
|
1221
1340
|
revenue_per_kpi = self.revenue_per_kpi
|
|
1222
|
-
incremental_outcome_gx_over_beta_gx =
|
|
1341
|
+
incremental_outcome_gx_over_beta_gx = backend.einsum(
|
|
1223
1342
|
"...gtx,gt,g,->...gx",
|
|
1224
1343
|
linear_predictor_counterfactual_difference,
|
|
1225
1344
|
revenue_per_kpi,
|
|
@@ -1227,34 +1346,35 @@ class Meridian:
|
|
|
1227
1346
|
self.kpi_transformer.population_scaled_stdev,
|
|
1228
1347
|
)
|
|
1229
1348
|
if random_effects_normal:
|
|
1230
|
-
numerator_term_x =
|
|
1349
|
+
numerator_term_x = backend.einsum(
|
|
1231
1350
|
"...gx,...gx,...x->...x",
|
|
1232
1351
|
incremental_outcome_gx_over_beta_gx,
|
|
1233
1352
|
beta_gx_dev,
|
|
1234
1353
|
eta_x,
|
|
1235
1354
|
)
|
|
1236
|
-
denominator_term_x =
|
|
1355
|
+
denominator_term_x = backend.einsum(
|
|
1237
1356
|
"...gx->...x", incremental_outcome_gx_over_beta_gx
|
|
1238
1357
|
)
|
|
1239
1358
|
return (incremental_outcome_x - numerator_term_x) / denominator_term_x
|
|
1240
1359
|
# For log-normal random effects, beta_x and eta_x are not mean & std.
|
|
1241
1360
|
# The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
|
|
1242
|
-
denominator_term_x =
|
|
1361
|
+
denominator_term_x = backend.einsum(
|
|
1243
1362
|
"...gx,...gx->...x",
|
|
1244
1363
|
incremental_outcome_gx_over_beta_gx,
|
|
1245
|
-
|
|
1364
|
+
backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
|
|
1246
1365
|
)
|
|
1247
|
-
return
|
|
1366
|
+
return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)
|
|
1248
1367
|
|
|
1249
1368
|
def adstock_hill_media(
|
|
1250
1369
|
self,
|
|
1251
|
-
media:
|
|
1252
|
-
alpha:
|
|
1253
|
-
ec:
|
|
1254
|
-
slope:
|
|
1370
|
+
media: backend.Tensor, # pylint: disable=redefined-outer-name
|
|
1371
|
+
alpha: backend.Tensor,
|
|
1372
|
+
ec: backend.Tensor,
|
|
1373
|
+
slope: backend.Tensor,
|
|
1374
|
+
decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
|
|
1255
1375
|
n_times_output: int | None = None,
|
|
1256
|
-
) ->
|
|
1257
|
-
"""Transforms media using Adstock and Hill functions in the desired order.
|
|
1376
|
+
) -> backend.Tensor:
|
|
1377
|
+
"""Transforms media or using Adstock and Hill functions in the desired order.
|
|
1258
1378
|
|
|
1259
1379
|
Args:
|
|
1260
1380
|
media: Tensor of dimensions `(n_geos, n_media_times, n_media_channels)`
|
|
@@ -1264,6 +1384,8 @@ class Meridian:
|
|
|
1264
1384
|
alpha: Uniform distribution for Adstock and Hill calculations.
|
|
1265
1385
|
ec: Shifted half-normal distribution for Adstock and Hill calculations.
|
|
1266
1386
|
slope: Deterministic distribution for Adstock and Hill calculations.
|
|
1387
|
+
decay_functions: String or sequence of strings denoting the adstock decay
|
|
1388
|
+
function(s) for each channel. Default: 'geometric'.
|
|
1267
1389
|
n_times_output: Number of time periods to output. This argument is
|
|
1268
1390
|
optional when the number of time periods in `media` equals
|
|
1269
1391
|
`self.n_media_times`, in which case `n_times_output` defaults to
|
|
@@ -1284,6 +1406,7 @@ class Meridian:
|
|
|
1284
1406
|
alpha=alpha,
|
|
1285
1407
|
max_lag=self.model_spec.max_lag,
|
|
1286
1408
|
n_times_output=n_times_output,
|
|
1409
|
+
decay_functions=decay_functions,
|
|
1287
1410
|
)
|
|
1288
1411
|
hill_transformer = adstock_hill.HillTransformer(
|
|
1289
1412
|
ec=ec,
|
|
@@ -1302,13 +1425,14 @@ class Meridian:
|
|
|
1302
1425
|
|
|
1303
1426
|
def adstock_hill_rf(
|
|
1304
1427
|
self,
|
|
1305
|
-
reach:
|
|
1306
|
-
frequency:
|
|
1307
|
-
alpha:
|
|
1308
|
-
ec:
|
|
1309
|
-
slope:
|
|
1428
|
+
reach: backend.Tensor,
|
|
1429
|
+
frequency: backend.Tensor,
|
|
1430
|
+
alpha: backend.Tensor,
|
|
1431
|
+
ec: backend.Tensor,
|
|
1432
|
+
slope: backend.Tensor,
|
|
1433
|
+
decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
|
|
1310
1434
|
n_times_output: int | None = None,
|
|
1311
|
-
) ->
|
|
1435
|
+
) -> backend.Tensor:
|
|
1312
1436
|
"""Transforms reach and frequency (RF) using Hill and Adstock functions.
|
|
1313
1437
|
|
|
1314
1438
|
Args:
|
|
@@ -1319,6 +1443,8 @@ class Meridian:
|
|
|
1319
1443
|
alpha: Uniform distribution for Adstock and Hill calculations.
|
|
1320
1444
|
ec: Shifted half-normal distribution for Adstock and Hill calculations.
|
|
1321
1445
|
slope: Deterministic distribution for Adstock and Hill calculations.
|
|
1446
|
+
decay_functions: String or sequence of strings denoting the adstock decay
|
|
1447
|
+
function(s) for each channel. Default: 'geometric'.
|
|
1322
1448
|
n_times_output: Number of time periods to output. This argument is
|
|
1323
1449
|
optional when the number of time periods in `reach` equals
|
|
1324
1450
|
`self.n_media_times`, in which case `n_times_output` defaults to
|
|
@@ -1343,6 +1469,7 @@ class Meridian:
|
|
|
1343
1469
|
alpha=alpha,
|
|
1344
1470
|
max_lag=self.model_spec.max_lag,
|
|
1345
1471
|
n_times_output=n_times_output,
|
|
1472
|
+
decay_functions=decay_functions,
|
|
1346
1473
|
)
|
|
1347
1474
|
adj_frequency = hill_transformer.forward(frequency)
|
|
1348
1475
|
rf_out = adstock_transformer.forward(reach * adj_frequency)
|
|
@@ -1448,7 +1575,7 @@ class Meridian:
|
|
|
1448
1575
|
n_adapt: int,
|
|
1449
1576
|
n_burnin: int,
|
|
1450
1577
|
n_keep: int,
|
|
1451
|
-
current_state: Mapping[str,
|
|
1578
|
+
current_state: Mapping[str, backend.Tensor] | None = None,
|
|
1452
1579
|
init_step_size: int | None = None,
|
|
1453
1580
|
dual_averaging_kwargs: Mapping[str, int] | None = None,
|
|
1454
1581
|
max_tree_depth: int = 10,
|