google-meridian 1.3.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/METADATA +13 -9
- google_meridian-1.4.0.dist-info/RECORD +108 -0
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/__init__.py +1 -2
- meridian/analysis/analyzer.py +0 -1
- meridian/analysis/optimizer.py +5 -3
- meridian/analysis/review/checks.py +81 -30
- meridian/analysis/review/constants.py +4 -0
- meridian/analysis/review/results.py +40 -9
- meridian/analysis/summarizer.py +8 -3
- meridian/analysis/test_utils.py +934 -485
- meridian/analysis/visualizer.py +11 -7
- meridian/backend/__init__.py +53 -5
- meridian/backend/test_utils.py +72 -0
- meridian/constants.py +2 -0
- meridian/data/load.py +2 -0
- meridian/data/test_utils.py +82 -10
- meridian/model/__init__.py +2 -0
- meridian/model/context.py +925 -0
- meridian/model/eda/__init__.py +0 -1
- meridian/model/eda/constants.py +13 -2
- meridian/model/eda/eda_engine.py +299 -37
- meridian/model/eda/eda_outcome.py +21 -1
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +75 -47
- meridian/model/model.py +93 -792
- meridian/{analysis/templates → templates}/card.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chart.html.jinja +1 -1
- meridian/{analysis/templates → templates}/chips.html.jinja +1 -1
- meridian/{analysis → templates}/formatter.py +12 -1
- meridian/templates/formatter_test.py +216 -0
- meridian/{analysis/templates → templates}/insights.html.jinja +1 -1
- meridian/{analysis/templates → templates}/stats.html.jinja +1 -1
- meridian/{analysis/templates → templates}/style.css +1 -1
- meridian/{analysis/templates → templates}/style.scss +1 -1
- meridian/{analysis/templates → templates}/summary.html.jinja +4 -2
- meridian/{analysis/templates → templates}/table.html.jinja +1 -1
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +354 -0
- schema/__init__.py +15 -0
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1136 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +412 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/test_data.py +380 -0
- schema/utils/__init__.py +1 -0
- schema/utils/date_range_bucketing.py +117 -0
- google_meridian-1.3.1.dist-info/RECORD +0 -76
- meridian/model/eda/meridian_eda.py +0 -220
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.3.1.dist-info → google_meridian-1.4.0.dist-info}/licenses/LICENSE +0 -0
meridian/model/model.py
CHANGED
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
import collections
|
|
18
18
|
from collections.abc import Mapping, Sequence
|
|
19
19
|
import functools
|
|
20
|
-
import numbers
|
|
21
20
|
import os
|
|
22
21
|
import warnings
|
|
23
22
|
|
|
@@ -28,6 +27,8 @@ from meridian import constants
|
|
|
28
27
|
from meridian.data import input_data as data
|
|
29
28
|
from meridian.data import time_coordinates as tc
|
|
30
29
|
from meridian.model import adstock_hill
|
|
30
|
+
from meridian.model import context
|
|
31
|
+
from meridian.model import equations
|
|
31
32
|
from meridian.model import knots
|
|
32
33
|
from meridian.model import media
|
|
33
34
|
from meridian.model import posterior_sampler
|
|
@@ -76,27 +77,15 @@ def _warn_setting_national_args(**kwargs):
|
|
|
76
77
|
)
|
|
77
78
|
|
|
78
79
|
|
|
79
|
-
def _check_for_negative_effect(
|
|
80
|
-
dist: backend.tfd.Distribution, media_effects_dist: str
|
|
81
|
-
):
|
|
82
|
-
"""Checks for negative effect in the model."""
|
|
83
|
-
if (
|
|
84
|
-
media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
|
|
85
|
-
and np.any(dist.cdf(0)) > 0
|
|
86
|
-
):
|
|
87
|
-
raise ValueError(
|
|
88
|
-
"Media priors must have non-negative support when"
|
|
89
|
-
f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
|
|
90
|
-
f" in {dist.name}."
|
|
91
|
-
)
|
|
92
|
-
|
|
93
|
-
|
|
94
80
|
class Meridian:
|
|
95
81
|
"""Contains the main functionality for fitting the Meridian MMM model.
|
|
96
82
|
|
|
97
83
|
Attributes:
|
|
98
84
|
input_data: An `InputData` object containing the input data for the model.
|
|
99
85
|
model_spec: A `ModelSpec` object containing the model specification.
|
|
86
|
+
model_context: A `ModelContext` object containing the model context.
|
|
87
|
+
equations: A `ModelEquations` object containing stateless mathematical
|
|
88
|
+
functions and utilities for Meridian MMM.
|
|
100
89
|
inference_data: A _mutable_ `arviz.InferenceData` object containing the
|
|
101
90
|
resulting data from fitting the model.
|
|
102
91
|
eda_engine: An `EDAEngine` object containing the EDA engine.
|
|
@@ -168,15 +157,17 @@ class Meridian:
|
|
|
168
157
|
) = None, # for deserializer use only
|
|
169
158
|
eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
|
|
170
159
|
):
|
|
171
|
-
self._input_data = input_data
|
|
172
|
-
self._model_spec = model_spec if model_spec else spec.ModelSpec()
|
|
173
160
|
self._inference_data = (
|
|
174
161
|
inference_data if inference_data else az.InferenceData()
|
|
175
162
|
)
|
|
163
|
+
self._model_context = context.ModelContext(
|
|
164
|
+
input_data=input_data,
|
|
165
|
+
model_spec=model_spec if model_spec else spec.ModelSpec(),
|
|
166
|
+
)
|
|
167
|
+
self._equations = equations.ModelEquations(self._model_context)
|
|
176
168
|
|
|
177
169
|
self._eda_spec = eda_spec
|
|
178
170
|
|
|
179
|
-
self._validate_data_dependent_model_spec()
|
|
180
171
|
self._validate_injected_inference_data()
|
|
181
172
|
|
|
182
173
|
if self.is_national:
|
|
@@ -184,22 +175,23 @@ class Meridian:
|
|
|
184
175
|
media_effects_dist=self.model_spec.media_effects_dist,
|
|
185
176
|
unique_sigma_for_each_geo=self.model_spec.unique_sigma_for_each_geo,
|
|
186
177
|
)
|
|
187
|
-
self.
|
|
188
|
-
self._set_total_media_contribution_prior = False
|
|
189
|
-
self._validate_mroi_priors_non_revenue()
|
|
190
|
-
self._validate_roi_priors_non_revenue()
|
|
191
|
-
self._check_for_negative_effects()
|
|
192
|
-
self._validate_geo_invariants()
|
|
193
|
-
self._validate_time_invariants()
|
|
194
|
-
self._validate_kpi_transformer()
|
|
178
|
+
self._validate_kpi_variability()
|
|
195
179
|
|
|
196
180
|
@property
|
|
197
181
|
def input_data(self) -> data.InputData:
|
|
198
|
-
return self.
|
|
182
|
+
return self._model_context.input_data
|
|
199
183
|
|
|
200
184
|
@property
|
|
201
185
|
def model_spec(self) -> spec.ModelSpec:
|
|
202
|
-
return self.
|
|
186
|
+
return self._model_context.model_spec
|
|
187
|
+
|
|
188
|
+
@property
|
|
189
|
+
def model_context(self) -> context.ModelContext:
|
|
190
|
+
return self._model_context
|
|
191
|
+
|
|
192
|
+
@property
|
|
193
|
+
def equations(self) -> equations.ModelEquations:
|
|
194
|
+
return self._equations
|
|
203
195
|
|
|
204
196
|
@property
|
|
205
197
|
def inference_data(self) -> az.InferenceData:
|
|
@@ -219,176 +211,111 @@ class Meridian:
|
|
|
219
211
|
|
|
220
212
|
@functools.cached_property
|
|
221
213
|
def media_tensors(self) -> media.MediaTensors:
|
|
222
|
-
return
|
|
214
|
+
return self._model_context.media_tensors
|
|
223
215
|
|
|
224
216
|
@functools.cached_property
|
|
225
217
|
def rf_tensors(self) -> media.RfTensors:
|
|
226
|
-
return
|
|
218
|
+
return self._model_context.rf_tensors
|
|
227
219
|
|
|
228
220
|
@functools.cached_property
|
|
229
221
|
def organic_media_tensors(self) -> media.OrganicMediaTensors:
|
|
230
|
-
return
|
|
222
|
+
return self._model_context.organic_media_tensors
|
|
231
223
|
|
|
232
224
|
@functools.cached_property
|
|
233
225
|
def organic_rf_tensors(self) -> media.OrganicRfTensors:
|
|
234
|
-
return
|
|
226
|
+
return self._model_context.organic_rf_tensors
|
|
235
227
|
|
|
236
228
|
@functools.cached_property
|
|
237
229
|
def kpi(self) -> backend.Tensor:
|
|
238
|
-
return
|
|
230
|
+
return self._model_context.kpi
|
|
239
231
|
|
|
240
232
|
@functools.cached_property
|
|
241
233
|
def revenue_per_kpi(self) -> backend.Tensor | None:
|
|
242
|
-
|
|
243
|
-
return None
|
|
244
|
-
return backend.to_tensor(
|
|
245
|
-
self.input_data.revenue_per_kpi, dtype=backend.float32
|
|
246
|
-
)
|
|
234
|
+
return self._model_context.revenue_per_kpi
|
|
247
235
|
|
|
248
236
|
@functools.cached_property
|
|
249
237
|
def controls(self) -> backend.Tensor | None:
|
|
250
|
-
|
|
251
|
-
return None
|
|
252
|
-
return backend.to_tensor(self.input_data.controls, dtype=backend.float32)
|
|
238
|
+
return self._model_context.controls
|
|
253
239
|
|
|
254
240
|
@functools.cached_property
|
|
255
241
|
def non_media_treatments(self) -> backend.Tensor | None:
|
|
256
|
-
|
|
257
|
-
return None
|
|
258
|
-
return backend.to_tensor(
|
|
259
|
-
self.input_data.non_media_treatments, dtype=backend.float32
|
|
260
|
-
)
|
|
242
|
+
return self._model_context.non_media_treatments
|
|
261
243
|
|
|
262
244
|
@functools.cached_property
|
|
263
245
|
def population(self) -> backend.Tensor:
|
|
264
|
-
return
|
|
246
|
+
return self._model_context.population
|
|
265
247
|
|
|
266
248
|
@functools.cached_property
|
|
267
249
|
def total_spend(self) -> backend.Tensor:
|
|
268
|
-
return
|
|
269
|
-
self.input_data.get_total_spend(), dtype=backend.float32
|
|
270
|
-
)
|
|
250
|
+
return self._model_context.total_spend
|
|
271
251
|
|
|
272
252
|
@functools.cached_property
|
|
273
253
|
def total_outcome(self) -> backend.Tensor:
|
|
274
|
-
return
|
|
275
|
-
self.input_data.get_total_outcome(), dtype=backend.float32
|
|
276
|
-
)
|
|
254
|
+
return self._model_context.total_outcome
|
|
277
255
|
|
|
278
256
|
@property
|
|
279
257
|
def n_geos(self) -> int:
|
|
280
|
-
return
|
|
258
|
+
return self._model_context.n_geos
|
|
281
259
|
|
|
282
260
|
@property
|
|
283
261
|
def n_media_channels(self) -> int:
|
|
284
|
-
|
|
285
|
-
return 0
|
|
286
|
-
return len(self.input_data.media_channel)
|
|
262
|
+
return self._model_context.n_media_channels
|
|
287
263
|
|
|
288
264
|
@property
|
|
289
265
|
def n_rf_channels(self) -> int:
|
|
290
|
-
|
|
291
|
-
return 0
|
|
292
|
-
return len(self.input_data.rf_channel)
|
|
266
|
+
return self._model_context.n_rf_channels
|
|
293
267
|
|
|
294
268
|
@property
|
|
295
269
|
def n_organic_media_channels(self) -> int:
|
|
296
|
-
|
|
297
|
-
return 0
|
|
298
|
-
return len(self.input_data.organic_media_channel)
|
|
270
|
+
return self._model_context.n_organic_media_channels
|
|
299
271
|
|
|
300
272
|
@property
|
|
301
273
|
def n_organic_rf_channels(self) -> int:
|
|
302
|
-
|
|
303
|
-
return 0
|
|
304
|
-
return len(self.input_data.organic_rf_channel)
|
|
274
|
+
return self._model_context.n_organic_rf_channels
|
|
305
275
|
|
|
306
276
|
@property
|
|
307
277
|
def n_controls(self) -> int:
|
|
308
|
-
|
|
309
|
-
return 0
|
|
310
|
-
return len(self.input_data.control_variable)
|
|
278
|
+
return self._model_context.n_controls
|
|
311
279
|
|
|
312
280
|
@property
|
|
313
281
|
def n_non_media_channels(self) -> int:
|
|
314
|
-
|
|
315
|
-
return 0
|
|
316
|
-
return len(self.input_data.non_media_channel)
|
|
282
|
+
return self._model_context.n_non_media_channels
|
|
317
283
|
|
|
318
284
|
@property
|
|
319
285
|
def n_times(self) -> int:
|
|
320
|
-
return
|
|
286
|
+
return self._model_context.n_times
|
|
321
287
|
|
|
322
288
|
@property
|
|
323
289
|
def n_media_times(self) -> int:
|
|
324
|
-
return
|
|
290
|
+
return self._model_context.n_media_times
|
|
325
291
|
|
|
326
292
|
@property
|
|
327
293
|
def is_national(self) -> bool:
|
|
328
|
-
return self.
|
|
294
|
+
return self._model_context.is_national
|
|
329
295
|
|
|
330
296
|
@functools.cached_property
|
|
331
297
|
def knot_info(self) -> knots.KnotInfo:
|
|
332
|
-
return
|
|
333
|
-
n_times=self.n_times,
|
|
334
|
-
knots=self.model_spec.knots,
|
|
335
|
-
enable_aks=self.model_spec.enable_aks,
|
|
336
|
-
data=self.input_data,
|
|
337
|
-
is_national=self.is_national,
|
|
338
|
-
)
|
|
298
|
+
return self._model_context.knot_info
|
|
339
299
|
|
|
340
300
|
@functools.cached_property
|
|
341
301
|
def controls_transformer(
|
|
342
302
|
self,
|
|
343
303
|
) -> transformers.CenteringAndScalingTransformer | None:
|
|
344
|
-
|
|
345
|
-
if self.controls is None:
|
|
346
|
-
return None
|
|
347
|
-
|
|
348
|
-
if self.model_spec.control_population_scaling_id is not None:
|
|
349
|
-
controls_population_scaling_id = backend.to_tensor(
|
|
350
|
-
self.model_spec.control_population_scaling_id, dtype=backend.bool_
|
|
351
|
-
)
|
|
352
|
-
else:
|
|
353
|
-
controls_population_scaling_id = None
|
|
354
|
-
|
|
355
|
-
return transformers.CenteringAndScalingTransformer(
|
|
356
|
-
tensor=self.controls,
|
|
357
|
-
population=self.population,
|
|
358
|
-
population_scaling_id=controls_population_scaling_id,
|
|
359
|
-
)
|
|
304
|
+
return self._model_context.controls_transformer
|
|
360
305
|
|
|
361
306
|
@functools.cached_property
|
|
362
307
|
def non_media_transformer(
|
|
363
308
|
self,
|
|
364
309
|
) -> transformers.CenteringAndScalingTransformer | None:
|
|
365
|
-
|
|
366
|
-
if self.non_media_treatments is None:
|
|
367
|
-
return None
|
|
368
|
-
if self.model_spec.non_media_population_scaling_id is not None:
|
|
369
|
-
non_media_population_scaling_id = backend.to_tensor(
|
|
370
|
-
self.model_spec.non_media_population_scaling_id, dtype=backend.bool_
|
|
371
|
-
)
|
|
372
|
-
else:
|
|
373
|
-
non_media_population_scaling_id = None
|
|
374
|
-
|
|
375
|
-
return transformers.CenteringAndScalingTransformer(
|
|
376
|
-
tensor=self.non_media_treatments,
|
|
377
|
-
population=self.population,
|
|
378
|
-
population_scaling_id=non_media_population_scaling_id,
|
|
379
|
-
)
|
|
310
|
+
return self._model_context.non_media_transformer
|
|
380
311
|
|
|
381
312
|
@functools.cached_property
|
|
382
313
|
def kpi_transformer(self) -> transformers.KpiTransformer:
|
|
383
|
-
return
|
|
314
|
+
return self._model_context.kpi_transformer
|
|
384
315
|
|
|
385
316
|
@functools.cached_property
|
|
386
317
|
def controls_scaled(self) -> backend.Tensor | None:
|
|
387
|
-
|
|
388
|
-
# If `controls` is defined, then `controls_transformer` is also defined.
|
|
389
|
-
return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
|
|
390
|
-
else:
|
|
391
|
-
return None
|
|
318
|
+
return self._model_context.controls_scaled
|
|
392
319
|
|
|
393
320
|
@functools.cached_property
|
|
394
321
|
def non_media_treatments_normalized(self) -> backend.Tensor | None:
|
|
@@ -398,117 +325,38 @@ class Meridian:
|
|
|
398
325
|
`non_media_population_scaling_id` is `True`) and normalized by centering and
|
|
399
326
|
scaling with means and standard deviations.
|
|
400
327
|
"""
|
|
401
|
-
|
|
402
|
-
return self.non_media_transformer.forward(
|
|
403
|
-
self.non_media_treatments
|
|
404
|
-
) # pytype: disable=attribute-error
|
|
405
|
-
else:
|
|
406
|
-
return None
|
|
328
|
+
return self._model_context.non_media_treatments_normalized
|
|
407
329
|
|
|
408
330
|
@functools.cached_property
|
|
409
331
|
def kpi_scaled(self) -> backend.Tensor:
|
|
410
|
-
return self.
|
|
332
|
+
return self._model_context.kpi_scaled
|
|
411
333
|
|
|
412
334
|
@functools.cached_property
|
|
413
335
|
def media_effects_dist(self) -> str:
|
|
414
|
-
|
|
415
|
-
return constants.NATIONAL_MODEL_SPEC_ARGS[constants.MEDIA_EFFECTS_DIST]
|
|
416
|
-
else:
|
|
417
|
-
return self.model_spec.media_effects_dist
|
|
336
|
+
return self._model_context.media_effects_dist
|
|
418
337
|
|
|
419
338
|
@functools.cached_property
|
|
420
339
|
def unique_sigma_for_each_geo(self) -> bool:
|
|
421
|
-
|
|
422
|
-
# Should evaluate to False.
|
|
423
|
-
return constants.NATIONAL_MODEL_SPEC_ARGS[
|
|
424
|
-
constants.UNIQUE_SIGMA_FOR_EACH_GEO
|
|
425
|
-
]
|
|
426
|
-
else:
|
|
427
|
-
return self.model_spec.unique_sigma_for_each_geo
|
|
340
|
+
return self._model_context.unique_sigma_for_each_geo
|
|
428
341
|
|
|
429
342
|
@functools.cached_property
|
|
430
343
|
def baseline_geo_idx(self) -> int:
|
|
431
344
|
"""Returns the index of the baseline geo."""
|
|
432
|
-
|
|
433
|
-
if (
|
|
434
|
-
self.model_spec.baseline_geo < 0
|
|
435
|
-
or self.model_spec.baseline_geo >= self.n_geos
|
|
436
|
-
):
|
|
437
|
-
raise ValueError(
|
|
438
|
-
f"Baseline geo index {self.model_spec.baseline_geo} out of range"
|
|
439
|
-
f" [0, {self.n_geos - 1}]."
|
|
440
|
-
)
|
|
441
|
-
return self.model_spec.baseline_geo
|
|
442
|
-
elif isinstance(self.model_spec.baseline_geo, str):
|
|
443
|
-
# np.where returns a 1-D tuple, its first element is an array of found
|
|
444
|
-
# elements.
|
|
445
|
-
index = np.where(self.input_data.geo == self.model_spec.baseline_geo)[0]
|
|
446
|
-
if index.size == 0:
|
|
447
|
-
raise ValueError(
|
|
448
|
-
f"Baseline geo '{self.model_spec.baseline_geo}' not found."
|
|
449
|
-
)
|
|
450
|
-
# Geos are unique, so index is a 1-element array.
|
|
451
|
-
return index[0]
|
|
452
|
-
else:
|
|
453
|
-
return backend.argmax(self.population)
|
|
345
|
+
return self._model_context.baseline_geo_idx
|
|
454
346
|
|
|
455
347
|
@functools.cached_property
|
|
456
348
|
def holdout_id(self) -> backend.Tensor | None:
|
|
457
|
-
|
|
458
|
-
return None
|
|
459
|
-
tensor = backend.to_tensor(self.model_spec.holdout_id, dtype=backend.bool_)
|
|
460
|
-
return tensor[backend.newaxis, ...] if self.is_national else tensor
|
|
349
|
+
return self._model_context.holdout_id
|
|
461
350
|
|
|
462
351
|
@functools.cached_property
|
|
463
352
|
def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
|
|
464
353
|
"""Returns `AdstockDecaySpec` object with correctly mapped channels."""
|
|
465
|
-
|
|
466
|
-
return adstock_hill.AdstockDecaySpec.from_consistent_type(
|
|
467
|
-
self.model_spec.adstock_decay_spec
|
|
468
|
-
)
|
|
469
|
-
|
|
470
|
-
try:
|
|
471
|
-
return self._create_adstock_decay_functions_from_channel_map(
|
|
472
|
-
self.model_spec.adstock_decay_spec
|
|
473
|
-
)
|
|
474
|
-
except KeyError as e:
|
|
475
|
-
raise ValueError(
|
|
476
|
-
"Unrecognized channel names found in `adstock_decay_spec` keys"
|
|
477
|
-
f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
|
|
478
|
-
" either contain only channel_names"
|
|
479
|
-
f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
|
|
480
|
-
" be one or more of {'media', 'rf', 'organic_media',"
|
|
481
|
-
" 'organic_rf'}."
|
|
482
|
-
) from e
|
|
354
|
+
return self._model_context.adstock_decay_spec
|
|
483
355
|
|
|
484
356
|
@functools.cached_property
|
|
485
357
|
def prior_broadcast(self) -> prior_distribution.PriorDistribution:
|
|
486
358
|
"""Returns broadcasted `PriorDistribution` object."""
|
|
487
|
-
|
|
488
|
-
# Total spend can have 1, 2 or 3 dimensions. Aggregate by channel.
|
|
489
|
-
if len(total_spend.shape) == 1:
|
|
490
|
-
# Already aggregated by channel.
|
|
491
|
-
agg_total_spend = total_spend
|
|
492
|
-
elif len(total_spend.shape) == 2:
|
|
493
|
-
agg_total_spend = np.sum(total_spend, axis=(0,))
|
|
494
|
-
else:
|
|
495
|
-
agg_total_spend = np.sum(total_spend, axis=(0, 1))
|
|
496
|
-
|
|
497
|
-
return self.model_spec.prior.broadcast(
|
|
498
|
-
n_geos=self.n_geos,
|
|
499
|
-
n_media_channels=self.n_media_channels,
|
|
500
|
-
n_rf_channels=self.n_rf_channels,
|
|
501
|
-
n_organic_media_channels=self.n_organic_media_channels,
|
|
502
|
-
n_organic_rf_channels=self.n_organic_rf_channels,
|
|
503
|
-
n_controls=self.n_controls,
|
|
504
|
-
n_non_media_channels=self.n_non_media_channels,
|
|
505
|
-
unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
|
|
506
|
-
n_knots=self.knot_info.n_knots,
|
|
507
|
-
is_national=self.is_national,
|
|
508
|
-
set_total_media_contribution_prior=self._set_total_media_contribution_prior,
|
|
509
|
-
kpi=np.sum(self.input_data.kpi.values),
|
|
510
|
-
total_spend=agg_total_spend,
|
|
511
|
-
)
|
|
359
|
+
return self._model_context.prior_broadcast
|
|
512
360
|
|
|
513
361
|
@functools.cached_property
|
|
514
362
|
def prior_sampler_callable(self) -> prior_sampler.PriorDistributionSampler:
|
|
@@ -522,6 +370,8 @@ class Meridian:
|
|
|
522
370
|
"""A `PosteriorMCMCSampler` callable bound to this model."""
|
|
523
371
|
return posterior_sampler.PosteriorMCMCSampler(self)
|
|
524
372
|
|
|
373
|
+
# TODO: Deprecate this method in favor of the one in
|
|
374
|
+
# `equations.py`.
|
|
525
375
|
def compute_non_media_treatments_baseline(
|
|
526
376
|
self,
|
|
527
377
|
non_media_baseline_values: Sequence[str | float] | None = None,
|
|
@@ -544,70 +394,10 @@ class Meridian:
|
|
|
544
394
|
A tensor of shape `(n_non_media_channels,)` containing the
|
|
545
395
|
baseline values for each non-media treatment channel.
|
|
546
396
|
"""
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
no_op_scaling_factor = backend.ones_like(self.population)[
|
|
551
|
-
:, backend.newaxis, backend.newaxis
|
|
552
|
-
]
|
|
553
|
-
if self.model_spec.non_media_population_scaling_id is not None:
|
|
554
|
-
scaling_factors = backend.where(
|
|
555
|
-
self.model_spec.non_media_population_scaling_id,
|
|
556
|
-
self.population[:, backend.newaxis, backend.newaxis],
|
|
557
|
-
no_op_scaling_factor,
|
|
558
|
-
)
|
|
559
|
-
else:
|
|
560
|
-
scaling_factors = no_op_scaling_factor
|
|
561
|
-
|
|
562
|
-
non_media_treatments_population_scaled = backend.divide_no_nan(
|
|
563
|
-
self.non_media_treatments, scaling_factors
|
|
397
|
+
return self.equations.compute_non_media_treatments_baseline(
|
|
398
|
+
non_media_baseline_values=non_media_baseline_values
|
|
564
399
|
)
|
|
565
400
|
|
|
566
|
-
if non_media_baseline_values is None:
|
|
567
|
-
# If non_media_baseline_values is not provided, use the minimum
|
|
568
|
-
# value for each non_media treatment channel as the baseline.
|
|
569
|
-
non_media_baseline_values_filled = [
|
|
570
|
-
constants.NON_MEDIA_BASELINE_MIN
|
|
571
|
-
] * non_media_treatments_population_scaled.shape[-1]
|
|
572
|
-
else:
|
|
573
|
-
non_media_baseline_values_filled = non_media_baseline_values
|
|
574
|
-
|
|
575
|
-
if non_media_treatments_population_scaled.shape[-1] != len(
|
|
576
|
-
non_media_baseline_values_filled
|
|
577
|
-
):
|
|
578
|
-
raise ValueError(
|
|
579
|
-
"The number of non-media channels"
|
|
580
|
-
f" ({non_media_treatments_population_scaled.shape[-1]}) does not"
|
|
581
|
-
" match the number of baseline values"
|
|
582
|
-
f" ({len(non_media_baseline_values_filled)})."
|
|
583
|
-
)
|
|
584
|
-
|
|
585
|
-
baseline_list = []
|
|
586
|
-
for channel in range(non_media_treatments_population_scaled.shape[-1]):
|
|
587
|
-
baseline_value = non_media_baseline_values_filled[channel]
|
|
588
|
-
|
|
589
|
-
if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
|
|
590
|
-
baseline_for_channel = backend.reduce_min(
|
|
591
|
-
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
592
|
-
)
|
|
593
|
-
elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
|
|
594
|
-
baseline_for_channel = backend.reduce_max(
|
|
595
|
-
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
596
|
-
)
|
|
597
|
-
elif isinstance(baseline_value, numbers.Number):
|
|
598
|
-
baseline_for_channel = backend.to_tensor(
|
|
599
|
-
baseline_value, dtype=backend.float32
|
|
600
|
-
)
|
|
601
|
-
else:
|
|
602
|
-
raise ValueError(
|
|
603
|
-
f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
|
|
604
|
-
" float numbers and strings 'min' and 'max' are supported."
|
|
605
|
-
)
|
|
606
|
-
|
|
607
|
-
baseline_list.append(baseline_for_channel)
|
|
608
|
-
|
|
609
|
-
return backend.stack(baseline_list, axis=-1)
|
|
610
|
-
|
|
611
401
|
def expand_selected_time_dims(
|
|
612
402
|
self,
|
|
613
403
|
start_date: tc.Date = None,
|
|
@@ -752,427 +542,8 @@ class Meridian:
|
|
|
752
542
|
self.n_non_media_channels,
|
|
753
543
|
)
|
|
754
544
|
|
|
755
|
-
def
|
|
756
|
-
"""Validates
|
|
757
|
-
|
|
758
|
-
if (
|
|
759
|
-
self.model_spec.roi_calibration_period is not None
|
|
760
|
-
and self.model_spec.roi_calibration_period.shape
|
|
761
|
-
!= (
|
|
762
|
-
self.n_media_times,
|
|
763
|
-
self.n_media_channels,
|
|
764
|
-
)
|
|
765
|
-
):
|
|
766
|
-
raise ValueError(
|
|
767
|
-
"The shape of `roi_calibration_period`"
|
|
768
|
-
f" {self.model_spec.roi_calibration_period.shape} is different from"
|
|
769
|
-
f" `(n_media_times, n_media_channels) = ({self.n_media_times},"
|
|
770
|
-
f" {self.n_media_channels})`."
|
|
771
|
-
)
|
|
772
|
-
|
|
773
|
-
if (
|
|
774
|
-
self.model_spec.rf_roi_calibration_period is not None
|
|
775
|
-
and self.model_spec.rf_roi_calibration_period.shape
|
|
776
|
-
!= (
|
|
777
|
-
self.n_media_times,
|
|
778
|
-
self.n_rf_channels,
|
|
779
|
-
)
|
|
780
|
-
):
|
|
781
|
-
raise ValueError(
|
|
782
|
-
"The shape of `rf_roi_calibration_period`"
|
|
783
|
-
f" {self.model_spec.rf_roi_calibration_period.shape} is different"
|
|
784
|
-
f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
|
|
785
|
-
f" {self.n_rf_channels})`."
|
|
786
|
-
)
|
|
787
|
-
|
|
788
|
-
if self.model_spec.holdout_id is not None:
|
|
789
|
-
if self.is_national and (
|
|
790
|
-
self.model_spec.holdout_id.shape != (self.n_times,)
|
|
791
|
-
):
|
|
792
|
-
raise ValueError(
|
|
793
|
-
f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
|
|
794
|
-
f" different from `(n_times,) = ({self.n_times},)`."
|
|
795
|
-
)
|
|
796
|
-
elif not self.is_national and (
|
|
797
|
-
self.model_spec.holdout_id.shape
|
|
798
|
-
!= (
|
|
799
|
-
self.n_geos,
|
|
800
|
-
self.n_times,
|
|
801
|
-
)
|
|
802
|
-
):
|
|
803
|
-
raise ValueError(
|
|
804
|
-
f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
|
|
805
|
-
f" different from `(n_geos, n_times) = ({self.n_geos},"
|
|
806
|
-
f" {self.n_times})`."
|
|
807
|
-
)
|
|
808
|
-
|
|
809
|
-
if self.model_spec.control_population_scaling_id is not None and (
|
|
810
|
-
self.model_spec.control_population_scaling_id.shape
|
|
811
|
-
!= (self.n_controls,)
|
|
812
|
-
):
|
|
813
|
-
raise ValueError(
|
|
814
|
-
"The shape of `control_population_scaling_id`"
|
|
815
|
-
f" {self.model_spec.control_population_scaling_id.shape} is different"
|
|
816
|
-
f" from `(n_controls,) = ({self.n_controls},)`."
|
|
817
|
-
)
|
|
818
|
-
|
|
819
|
-
if self.model_spec.non_media_population_scaling_id is not None and (
|
|
820
|
-
self.model_spec.non_media_population_scaling_id.shape
|
|
821
|
-
!= (self.n_non_media_channels,)
|
|
822
|
-
):
|
|
823
|
-
raise ValueError(
|
|
824
|
-
"The shape of `non_media_population_scaling_id`"
|
|
825
|
-
f" {self.model_spec.non_media_population_scaling_id.shape} is"
|
|
826
|
-
" different from `(n_non_media_channels,) ="
|
|
827
|
-
f" ({self.n_non_media_channels},)`."
|
|
828
|
-
)
|
|
829
|
-
|
|
830
|
-
def _create_adstock_decay_functions_from_channel_map(
|
|
831
|
-
self, channel_function_map: Mapping[str, str]
|
|
832
|
-
) -> adstock_hill.AdstockDecaySpec:
|
|
833
|
-
"""Create `AdstockDecaySpec` from mapping from channels to decay functions."""
|
|
834
|
-
|
|
835
|
-
for channel in channel_function_map:
|
|
836
|
-
if channel not in self.input_data.get_all_adstock_hill_channels():
|
|
837
|
-
raise KeyError(f"Channel {channel} not found in data.")
|
|
838
|
-
|
|
839
|
-
if self.input_data.media_channel is not None:
|
|
840
|
-
media_channel_builder = self.input_data.get_paid_media_channels_argument_builder().with_default_value(
|
|
841
|
-
constants.GEOMETRIC_DECAY
|
|
842
|
-
)
|
|
843
|
-
media_adstock_function = media_channel_builder(**channel_function_map)
|
|
844
|
-
else:
|
|
845
|
-
media_adstock_function = constants.GEOMETRIC_DECAY
|
|
846
|
-
|
|
847
|
-
if self.input_data.rf_channel is not None:
|
|
848
|
-
rf_channel_builder = self.input_data.get_paid_rf_channels_argument_builder().with_default_value(
|
|
849
|
-
constants.GEOMETRIC_DECAY
|
|
850
|
-
)
|
|
851
|
-
rf_adstock_function = rf_channel_builder(**channel_function_map)
|
|
852
|
-
else:
|
|
853
|
-
rf_adstock_function = constants.GEOMETRIC_DECAY
|
|
854
|
-
|
|
855
|
-
if self.input_data.organic_media_channel is not None:
|
|
856
|
-
organic_media_channel_builder = self.input_data.get_organic_media_channels_argument_builder().with_default_value(
|
|
857
|
-
constants.GEOMETRIC_DECAY
|
|
858
|
-
)
|
|
859
|
-
organic_media_adstock_function = organic_media_channel_builder(
|
|
860
|
-
**channel_function_map
|
|
861
|
-
)
|
|
862
|
-
else:
|
|
863
|
-
organic_media_adstock_function = constants.GEOMETRIC_DECAY
|
|
864
|
-
|
|
865
|
-
if self.input_data.organic_rf_channel is not None:
|
|
866
|
-
organic_rf_channel_builder = self.input_data.get_organic_rf_channels_argument_builder().with_default_value(
|
|
867
|
-
constants.GEOMETRIC_DECAY
|
|
868
|
-
)
|
|
869
|
-
organic_rf_adstock_function = organic_rf_channel_builder(
|
|
870
|
-
**channel_function_map
|
|
871
|
-
)
|
|
872
|
-
else:
|
|
873
|
-
organic_rf_adstock_function = constants.GEOMETRIC_DECAY
|
|
874
|
-
|
|
875
|
-
return adstock_hill.AdstockDecaySpec(
|
|
876
|
-
media=media_adstock_function,
|
|
877
|
-
rf=rf_adstock_function,
|
|
878
|
-
organic_media=organic_media_adstock_function,
|
|
879
|
-
organic_rf=organic_rf_adstock_function,
|
|
880
|
-
)
|
|
881
|
-
|
|
882
|
-
def _warn_setting_ignored_priors(self):
|
|
883
|
-
"""Raises a warning if ignored priors are set."""
|
|
884
|
-
default_distribution = prior_distribution.PriorDistribution()
|
|
885
|
-
for ignored_priors_dict, prior_type, prior_type_name in (
|
|
886
|
-
(
|
|
887
|
-
constants.IGNORED_PRIORS_MEDIA,
|
|
888
|
-
self.model_spec.effective_media_prior_type,
|
|
889
|
-
"media_prior_type",
|
|
890
|
-
),
|
|
891
|
-
(
|
|
892
|
-
constants.IGNORED_PRIORS_RF,
|
|
893
|
-
self.model_spec.effective_rf_prior_type,
|
|
894
|
-
"rf_prior_type",
|
|
895
|
-
),
|
|
896
|
-
):
|
|
897
|
-
ignored_custom_priors = []
|
|
898
|
-
for prior in ignored_priors_dict.get(prior_type, []):
|
|
899
|
-
self_prior = getattr(self.model_spec.prior, prior)
|
|
900
|
-
default_prior = getattr(default_distribution, prior)
|
|
901
|
-
if not prior_distribution.distributions_are_equal(
|
|
902
|
-
self_prior, default_prior
|
|
903
|
-
):
|
|
904
|
-
ignored_custom_priors.append(prior)
|
|
905
|
-
if ignored_custom_priors:
|
|
906
|
-
ignored_priors_str = ", ".join(ignored_custom_priors)
|
|
907
|
-
warnings.warn(
|
|
908
|
-
f"Custom prior(s) `{ignored_priors_str}` are ignored when"
|
|
909
|
-
f' `{prior_type_name}` is set to "{prior_type}".'
|
|
910
|
-
)
|
|
911
|
-
|
|
912
|
-
def _validate_mroi_priors_non_revenue(self):
|
|
913
|
-
"""Validates mroi priors in the non-revenue outcome case."""
|
|
914
|
-
if (
|
|
915
|
-
self.input_data.kpi_type == constants.NON_REVENUE
|
|
916
|
-
and self.input_data.revenue_per_kpi is None
|
|
917
|
-
):
|
|
918
|
-
default_distribution = prior_distribution.PriorDistribution()
|
|
919
|
-
if (
|
|
920
|
-
self.n_media_channels > 0
|
|
921
|
-
and (
|
|
922
|
-
self.model_spec.effective_media_prior_type
|
|
923
|
-
== constants.TREATMENT_PRIOR_TYPE_MROI
|
|
924
|
-
)
|
|
925
|
-
and prior_distribution.distributions_are_equal(
|
|
926
|
-
self.model_spec.prior.mroi_m, default_distribution.mroi_m
|
|
927
|
-
)
|
|
928
|
-
):
|
|
929
|
-
raise ValueError(
|
|
930
|
-
f"Custom priors should be set on `{constants.MROI_M}` when"
|
|
931
|
-
' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
|
|
932
|
-
" kpi data is missing."
|
|
933
|
-
)
|
|
934
|
-
if (
|
|
935
|
-
self.n_rf_channels > 0
|
|
936
|
-
and (
|
|
937
|
-
self.model_spec.effective_rf_prior_type
|
|
938
|
-
== constants.TREATMENT_PRIOR_TYPE_MROI
|
|
939
|
-
)
|
|
940
|
-
and prior_distribution.distributions_are_equal(
|
|
941
|
-
self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
|
|
942
|
-
)
|
|
943
|
-
):
|
|
944
|
-
raise ValueError(
|
|
945
|
-
f"Custom priors should be set on `{constants.MROI_RF}` when"
|
|
946
|
-
' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
|
|
947
|
-
" data is missing."
|
|
948
|
-
)
|
|
949
|
-
|
|
950
|
-
def _validate_roi_priors_non_revenue(self):
|
|
951
|
-
"""Validates roi priors in the non-revenue outcome case."""
|
|
952
|
-
if (
|
|
953
|
-
self.input_data.kpi_type == constants.NON_REVENUE
|
|
954
|
-
and self.input_data.revenue_per_kpi is None
|
|
955
|
-
):
|
|
956
|
-
default_distribution = prior_distribution.PriorDistribution()
|
|
957
|
-
default_roi_m_used = (
|
|
958
|
-
self.model_spec.effective_media_prior_type
|
|
959
|
-
== constants.TREATMENT_PRIOR_TYPE_ROI
|
|
960
|
-
and prior_distribution.distributions_are_equal(
|
|
961
|
-
self.model_spec.prior.roi_m, default_distribution.roi_m
|
|
962
|
-
)
|
|
963
|
-
)
|
|
964
|
-
default_roi_rf_used = (
|
|
965
|
-
self.model_spec.effective_rf_prior_type
|
|
966
|
-
== constants.TREATMENT_PRIOR_TYPE_ROI
|
|
967
|
-
and prior_distribution.distributions_are_equal(
|
|
968
|
-
self.model_spec.prior.roi_rf, default_distribution.roi_rf
|
|
969
|
-
)
|
|
970
|
-
)
|
|
971
|
-
# If ROI priors are used with the default prior distribution for all paid
|
|
972
|
-
# channels (media and RF), then use the "total paid media contribution
|
|
973
|
-
# prior" procedure.
|
|
974
|
-
if (
|
|
975
|
-
(default_roi_m_used and default_roi_rf_used)
|
|
976
|
-
or (self.n_media_channels == 0 and default_roi_rf_used)
|
|
977
|
-
or (self.n_rf_channels == 0 and default_roi_m_used)
|
|
978
|
-
):
|
|
979
|
-
self._set_total_media_contribution_prior = True
|
|
980
|
-
warnings.warn(
|
|
981
|
-
"Consider setting custom ROI priors, as kpi_type was specified as"
|
|
982
|
-
" `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
|
|
983
|
-
" total media contribution prior will be used with"
|
|
984
|
-
f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
|
|
985
|
-
" documentation available at "
|
|
986
|
-
" https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
|
|
987
|
-
)
|
|
988
|
-
elif self.n_media_channels > 0 and default_roi_m_used:
|
|
989
|
-
raise ValueError(
|
|
990
|
-
f"Custom priors should be set on `{constants.ROI_M}` when"
|
|
991
|
-
' `media_prior_type` is "roi", custom priors are assigned on'
|
|
992
|
-
' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
|
|
993
|
-
" non-revenue and revenue per kpi data is missing."
|
|
994
|
-
)
|
|
995
|
-
elif self.n_rf_channels > 0 and default_roi_rf_used:
|
|
996
|
-
raise ValueError(
|
|
997
|
-
f"Custom priors should be set on `{constants.ROI_RF}` when"
|
|
998
|
-
' `rf_prior_type` is "roi", custom priors are assigned on'
|
|
999
|
-
' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
|
|
1000
|
-
" non-revenue and revenue per kpi data is missing."
|
|
1001
|
-
)
|
|
1002
|
-
|
|
1003
|
-
def _check_for_negative_effects(self):
|
|
1004
|
-
prior = self.model_spec.prior
|
|
1005
|
-
if self.n_media_channels > 0:
|
|
1006
|
-
_check_for_negative_effect(prior.roi_m, self.media_effects_dist)
|
|
1007
|
-
_check_for_negative_effect(prior.mroi_m, self.media_effects_dist)
|
|
1008
|
-
if self.n_rf_channels > 0:
|
|
1009
|
-
_check_for_negative_effect(prior.roi_rf, self.media_effects_dist)
|
|
1010
|
-
_check_for_negative_effect(prior.mroi_rf, self.media_effects_dist)
|
|
1011
|
-
|
|
1012
|
-
def _validate_geo_invariants(self):
|
|
1013
|
-
"""Validates non-national model invariants."""
|
|
1014
|
-
if self.is_national:
|
|
1015
|
-
return
|
|
1016
|
-
|
|
1017
|
-
if self.input_data.controls is not None:
|
|
1018
|
-
self._check_if_no_geo_variation(
|
|
1019
|
-
self.controls_scaled,
|
|
1020
|
-
constants.CONTROLS,
|
|
1021
|
-
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
1022
|
-
)
|
|
1023
|
-
if self.input_data.non_media_treatments is not None:
|
|
1024
|
-
self._check_if_no_geo_variation(
|
|
1025
|
-
self.non_media_treatments_normalized,
|
|
1026
|
-
constants.NON_MEDIA_TREATMENTS,
|
|
1027
|
-
self.input_data.non_media_treatments.coords[
|
|
1028
|
-
constants.NON_MEDIA_CHANNEL
|
|
1029
|
-
].values,
|
|
1030
|
-
)
|
|
1031
|
-
if self.input_data.media is not None:
|
|
1032
|
-
self._check_if_no_geo_variation(
|
|
1033
|
-
self.media_tensors.media_scaled,
|
|
1034
|
-
constants.MEDIA,
|
|
1035
|
-
self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
|
|
1036
|
-
)
|
|
1037
|
-
if self.input_data.reach is not None:
|
|
1038
|
-
self._check_if_no_geo_variation(
|
|
1039
|
-
self.rf_tensors.reach_scaled,
|
|
1040
|
-
constants.REACH,
|
|
1041
|
-
self.input_data.reach.coords[constants.RF_CHANNEL].values,
|
|
1042
|
-
)
|
|
1043
|
-
if self.input_data.organic_media is not None:
|
|
1044
|
-
self._check_if_no_geo_variation(
|
|
1045
|
-
self.organic_media_tensors.organic_media_scaled,
|
|
1046
|
-
"organic_media",
|
|
1047
|
-
self.input_data.organic_media.coords[
|
|
1048
|
-
constants.ORGANIC_MEDIA_CHANNEL
|
|
1049
|
-
].values,
|
|
1050
|
-
)
|
|
1051
|
-
if self.input_data.organic_reach is not None:
|
|
1052
|
-
self._check_if_no_geo_variation(
|
|
1053
|
-
self.organic_rf_tensors.organic_reach_scaled,
|
|
1054
|
-
"organic_reach",
|
|
1055
|
-
self.input_data.organic_reach.coords[
|
|
1056
|
-
constants.ORGANIC_RF_CHANNEL
|
|
1057
|
-
].values,
|
|
1058
|
-
)
|
|
1059
|
-
|
|
1060
|
-
def _check_if_no_geo_variation(
|
|
1061
|
-
self,
|
|
1062
|
-
scaled_data: backend.Tensor,
|
|
1063
|
-
data_name: str,
|
|
1064
|
-
data_dims: Sequence[str],
|
|
1065
|
-
epsilon=1e-4,
|
|
1066
|
-
):
|
|
1067
|
-
"""Raise an error if `n_knots == n_time` and data lacks geo variation."""
|
|
1068
|
-
|
|
1069
|
-
# Result shape: [n, d], where d is the number of axes of condition.
|
|
1070
|
-
col_idx_full = backend.get_indices_where(
|
|
1071
|
-
backend.reduce_std(scaled_data, axis=0) < epsilon
|
|
1072
|
-
)[:, 1]
|
|
1073
|
-
col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
|
|
1074
|
-
# We use the shape of scaled_data (instead of `n_time`) because the data may
|
|
1075
|
-
# be padded to account for lagged effects.
|
|
1076
|
-
data_n_time = scaled_data.shape[1]
|
|
1077
|
-
mask = backend.equal(counts, data_n_time)
|
|
1078
|
-
col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
|
|
1079
|
-
dims_bad = backend.gather(data_dims, col_idx_bad)
|
|
1080
|
-
|
|
1081
|
-
if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
|
|
1082
|
-
raise ValueError(
|
|
1083
|
-
f"The following {data_name} variables do not vary across geos, making"
|
|
1084
|
-
f" a model with n_knots=n_time unidentifiable: {dims_bad}. This can"
|
|
1085
|
-
" lead to poor model convergence. Since these variables only vary"
|
|
1086
|
-
" across time and not across geo, they are collinear with time and"
|
|
1087
|
-
" redundant in a model with a parameter for each time period. To"
|
|
1088
|
-
" address this, you can either: (1) decrease the number of knots"
|
|
1089
|
-
" (n_knots < n_time), or (2) drop the listed variables that do not"
|
|
1090
|
-
" vary across geos."
|
|
1091
|
-
)
|
|
1092
|
-
|
|
1093
|
-
def _validate_time_invariants(self):
|
|
1094
|
-
"""Validates model time invariants."""
|
|
1095
|
-
if self.input_data.controls is not None:
|
|
1096
|
-
self._check_if_no_time_variation(
|
|
1097
|
-
self.controls_scaled,
|
|
1098
|
-
constants.CONTROLS,
|
|
1099
|
-
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
1100
|
-
)
|
|
1101
|
-
if self.input_data.non_media_treatments is not None:
|
|
1102
|
-
self._check_if_no_time_variation(
|
|
1103
|
-
self.non_media_treatments_normalized,
|
|
1104
|
-
constants.NON_MEDIA_TREATMENTS,
|
|
1105
|
-
self.input_data.non_media_treatments.coords[
|
|
1106
|
-
constants.NON_MEDIA_CHANNEL
|
|
1107
|
-
].values,
|
|
1108
|
-
)
|
|
1109
|
-
if self.input_data.media is not None:
|
|
1110
|
-
self._check_if_no_time_variation(
|
|
1111
|
-
self.media_tensors.media_scaled,
|
|
1112
|
-
constants.MEDIA,
|
|
1113
|
-
self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
|
|
1114
|
-
)
|
|
1115
|
-
if self.input_data.reach is not None:
|
|
1116
|
-
self._check_if_no_time_variation(
|
|
1117
|
-
self.rf_tensors.reach_scaled,
|
|
1118
|
-
constants.REACH,
|
|
1119
|
-
self.input_data.reach.coords[constants.RF_CHANNEL].values,
|
|
1120
|
-
)
|
|
1121
|
-
if self.input_data.organic_media is not None:
|
|
1122
|
-
self._check_if_no_time_variation(
|
|
1123
|
-
self.organic_media_tensors.organic_media_scaled,
|
|
1124
|
-
constants.ORGANIC_MEDIA,
|
|
1125
|
-
self.input_data.organic_media.coords[
|
|
1126
|
-
constants.ORGANIC_MEDIA_CHANNEL
|
|
1127
|
-
].values,
|
|
1128
|
-
)
|
|
1129
|
-
if self.input_data.organic_reach is not None:
|
|
1130
|
-
self._check_if_no_time_variation(
|
|
1131
|
-
self.organic_rf_tensors.organic_reach_scaled,
|
|
1132
|
-
constants.ORGANIC_REACH,
|
|
1133
|
-
self.input_data.organic_reach.coords[
|
|
1134
|
-
constants.ORGANIC_RF_CHANNEL
|
|
1135
|
-
].values,
|
|
1136
|
-
)
|
|
1137
|
-
|
|
1138
|
-
def _check_if_no_time_variation(
|
|
1139
|
-
self,
|
|
1140
|
-
scaled_data: backend.Tensor,
|
|
1141
|
-
data_name: str,
|
|
1142
|
-
data_dims: Sequence[str],
|
|
1143
|
-
epsilon=1e-4,
|
|
1144
|
-
):
|
|
1145
|
-
"""Raise an error if data lacks time variation."""
|
|
1146
|
-
|
|
1147
|
-
# Result shape: [n, d], where d is the number of axes of condition.
|
|
1148
|
-
col_idx_full = backend.get_indices_where(
|
|
1149
|
-
backend.reduce_std(scaled_data, axis=1) < epsilon
|
|
1150
|
-
)[:, 1]
|
|
1151
|
-
col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
|
|
1152
|
-
mask = backend.equal(counts, self.n_geos)
|
|
1153
|
-
col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
|
|
1154
|
-
dims_bad = backend.gather(data_dims, col_idx_bad)
|
|
1155
|
-
if col_idx_bad.shape[0]:
|
|
1156
|
-
if self.is_national:
|
|
1157
|
-
raise ValueError(
|
|
1158
|
-
f"The following {data_name} variables do not vary across time,"
|
|
1159
|
-
" which is equivalent to no signal at all in a national model:"
|
|
1160
|
-
f" {dims_bad}. This can lead to poor model convergence. To address"
|
|
1161
|
-
" this, drop the listed variables that do not vary across time."
|
|
1162
|
-
)
|
|
1163
|
-
else:
|
|
1164
|
-
raise ValueError(
|
|
1165
|
-
f"The following {data_name} variables do not vary across time,"
|
|
1166
|
-
f" making a model with geo main effects unidentifiable: {dims_bad}."
|
|
1167
|
-
" This can lead to poor model convergence. Since these variables"
|
|
1168
|
-
" only vary across geo and not across time, they are collinear"
|
|
1169
|
-
" with geo and redundant in a model with geo main effects. To"
|
|
1170
|
-
" address this, drop the listed variables that do not vary across"
|
|
1171
|
-
" time."
|
|
1172
|
-
)
|
|
1173
|
-
|
|
1174
|
-
def _validate_kpi_transformer(self):
|
|
1175
|
-
"""Validates the KPI transformer."""
|
|
545
|
+
def _validate_kpi_variability(self):
|
|
546
|
+
"""Validates the KPI variability."""
|
|
1176
547
|
if self.eda_engine.kpi_has_variability:
|
|
1177
548
|
return
|
|
1178
549
|
kpi = self.eda_engine.kpi_scaled_da.name
|
|
@@ -1227,6 +598,7 @@ class Meridian:
|
|
|
1227
598
|
f' "{self.model_spec.non_media_treatments_prior_type}".'
|
|
1228
599
|
)
|
|
1229
600
|
|
|
601
|
+
# TODO: Deprecate in favor of ModelEquations.adstock_hill_rf.
|
|
1230
602
|
def linear_predictor_counterfactual_difference_media(
|
|
1231
603
|
self,
|
|
1232
604
|
media_transformed: backend.Tensor,
|
|
@@ -1255,21 +627,15 @@ class Meridian:
|
|
|
1255
627
|
The linear predictor difference between the treatment variable and its
|
|
1256
628
|
counterfactual.
|
|
1257
629
|
"""
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
ec_m,
|
|
1264
|
-
slope_m,
|
|
1265
|
-
decay_functions=self.adstock_decay_spec.media,
|
|
1266
|
-
)
|
|
1267
|
-
# Absolute values is needed because the difference is negative for mROI
|
|
1268
|
-
# priors and positive for ROI and contribution priors.
|
|
1269
|
-
return backend.absolute(
|
|
1270
|
-
media_transformed - media_transformed_counterfactual
|
|
630
|
+
return self.equations.linear_predictor_counterfactual_difference_media(
|
|
631
|
+
media_transformed=media_transformed,
|
|
632
|
+
alpha_m=alpha_m,
|
|
633
|
+
ec_m=ec_m,
|
|
634
|
+
slope_m=slope_m,
|
|
1271
635
|
)
|
|
1272
636
|
|
|
637
|
+
# TODO: Deprecate in favor of
|
|
638
|
+
# ModelEquations.linear_predictor_counterfactual_difference_rf.
|
|
1273
639
|
def linear_predictor_counterfactual_difference_rf(
|
|
1274
640
|
self,
|
|
1275
641
|
rf_transformed: backend.Tensor,
|
|
@@ -1298,20 +664,14 @@ class Meridian:
|
|
|
1298
664
|
The linear predictor difference between the treatment variable and its
|
|
1299
665
|
counterfactual.
|
|
1300
666
|
"""
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
alpha=alpha_rf,
|
|
1307
|
-
ec=ec_rf,
|
|
1308
|
-
slope=slope_rf,
|
|
1309
|
-
decay_functions=self.adstock_decay_spec.rf,
|
|
667
|
+
return self.equations.linear_predictor_counterfactual_difference_rf(
|
|
668
|
+
rf_transformed=rf_transformed,
|
|
669
|
+
alpha_rf=alpha_rf,
|
|
670
|
+
ec_rf=ec_rf,
|
|
671
|
+
slope_rf=slope_rf,
|
|
1310
672
|
)
|
|
1311
|
-
# Absolute values is needed because the difference is negative for mROI
|
|
1312
|
-
# priors and positive for ROI and contribution priors.
|
|
1313
|
-
return backend.absolute(rf_transformed - rf_transformed_counterfactual)
|
|
1314
673
|
|
|
674
|
+
# TODO: Deprecate in favor of ModelEquations.calculate_beta_x.
|
|
1315
675
|
def calculate_beta_x(
|
|
1316
676
|
self,
|
|
1317
677
|
is_non_media: bool,
|
|
@@ -1357,45 +717,15 @@ class Meridian:
|
|
|
1357
717
|
The coefficient mean parameter of the treatment variable, which has
|
|
1358
718
|
dimension equal to the number of treatment channels..
|
|
1359
719
|
"""
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
if self.revenue_per_kpi is None:
|
|
1367
|
-
revenue_per_kpi = backend.ones(
|
|
1368
|
-
[self.n_geos, self.n_times], dtype=backend.float32
|
|
1369
|
-
)
|
|
1370
|
-
else:
|
|
1371
|
-
revenue_per_kpi = self.revenue_per_kpi
|
|
1372
|
-
incremental_outcome_gx_over_beta_gx = backend.einsum(
|
|
1373
|
-
"...gtx,gt,g,->...gx",
|
|
1374
|
-
linear_predictor_counterfactual_difference,
|
|
1375
|
-
revenue_per_kpi,
|
|
1376
|
-
self.population,
|
|
1377
|
-
self.kpi_transformer.population_scaled_stdev,
|
|
1378
|
-
)
|
|
1379
|
-
if random_effects_normal:
|
|
1380
|
-
numerator_term_x = backend.einsum(
|
|
1381
|
-
"...gx,...gx,...x->...x",
|
|
1382
|
-
incremental_outcome_gx_over_beta_gx,
|
|
1383
|
-
beta_gx_dev,
|
|
1384
|
-
eta_x,
|
|
1385
|
-
)
|
|
1386
|
-
denominator_term_x = backend.einsum(
|
|
1387
|
-
"...gx->...x", incremental_outcome_gx_over_beta_gx
|
|
1388
|
-
)
|
|
1389
|
-
return (incremental_outcome_x - numerator_term_x) / denominator_term_x
|
|
1390
|
-
# For log-normal random effects, beta_x and eta_x are not mean & std.
|
|
1391
|
-
# The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
|
|
1392
|
-
denominator_term_x = backend.einsum(
|
|
1393
|
-
"...gx,...gx->...x",
|
|
1394
|
-
incremental_outcome_gx_over_beta_gx,
|
|
1395
|
-
backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
|
|
720
|
+
return self.equations.calculate_beta_x(
|
|
721
|
+
is_non_media=is_non_media,
|
|
722
|
+
incremental_outcome_x=incremental_outcome_x,
|
|
723
|
+
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
724
|
+
eta_x=eta_x,
|
|
725
|
+
beta_gx_dev=beta_gx_dev,
|
|
1396
726
|
)
|
|
1397
|
-
return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)
|
|
1398
727
|
|
|
728
|
+
# TODO: Deprecate in favor of ModelEquations.adstock_hill_media.
|
|
1399
729
|
def adstock_hill_media(
|
|
1400
730
|
self,
|
|
1401
731
|
media: backend.Tensor, # pylint: disable=redefined-outer-name
|
|
@@ -1426,34 +756,16 @@ class Meridian:
|
|
|
1426
756
|
Tensor with dimensions `[..., n_geos, n_times, n_media_channels]`
|
|
1427
757
|
representing Adstock and Hill-transformed media.
|
|
1428
758
|
"""
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
elif n_times_output is None:
|
|
1432
|
-
raise ValueError(
|
|
1433
|
-
"n_times_output is required. This argument is only optional when "
|
|
1434
|
-
"`media` has a number of time periods equal to `self.n_media_times`."
|
|
1435
|
-
)
|
|
1436
|
-
adstock_transformer = adstock_hill.AdstockTransformer(
|
|
759
|
+
return self.equations.adstock_hill_media(
|
|
760
|
+
media=media,
|
|
1437
761
|
alpha=alpha,
|
|
1438
|
-
max_lag=self.model_spec.max_lag,
|
|
1439
|
-
n_times_output=n_times_output,
|
|
1440
|
-
decay_functions=decay_functions,
|
|
1441
|
-
)
|
|
1442
|
-
hill_transformer = adstock_hill.HillTransformer(
|
|
1443
762
|
ec=ec,
|
|
1444
763
|
slope=slope,
|
|
764
|
+
decay_functions=decay_functions,
|
|
765
|
+
n_times_output=n_times_output,
|
|
1445
766
|
)
|
|
1446
|
-
transformers_list = (
|
|
1447
|
-
[hill_transformer, adstock_transformer]
|
|
1448
|
-
if self.model_spec.hill_before_adstock
|
|
1449
|
-
else [adstock_transformer, hill_transformer]
|
|
1450
|
-
)
|
|
1451
|
-
|
|
1452
|
-
media_out = media
|
|
1453
|
-
for transformer in transformers_list:
|
|
1454
|
-
media_out = transformer.forward(media_out)
|
|
1455
|
-
return media_out
|
|
1456
767
|
|
|
768
|
+
# TODO: Deprecate in favor of ModelEquations.adstock_hill_rf.
|
|
1457
769
|
def adstock_hill_rf(
|
|
1458
770
|
self,
|
|
1459
771
|
reach: backend.Tensor,
|
|
@@ -1485,27 +797,15 @@ class Meridian:
|
|
|
1485
797
|
Tensor with dimensions `[..., n_geos, n_times, n_rf_channels]`
|
|
1486
798
|
representing Hill and Adstock-transformed RF.
|
|
1487
799
|
"""
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
|
-
"n_times_output is required. This argument is only optional when "
|
|
1493
|
-
"`reach` has a number of time periods equal to `self.n_media_times`."
|
|
1494
|
-
)
|
|
1495
|
-
hill_transformer = adstock_hill.HillTransformer(
|
|
800
|
+
return self.equations.adstock_hill_rf(
|
|
801
|
+
reach=reach,
|
|
802
|
+
frequency=frequency,
|
|
803
|
+
alpha=alpha,
|
|
1496
804
|
ec=ec,
|
|
1497
805
|
slope=slope,
|
|
1498
|
-
)
|
|
1499
|
-
adstock_transformer = adstock_hill.AdstockTransformer(
|
|
1500
|
-
alpha=alpha,
|
|
1501
|
-
max_lag=self.model_spec.max_lag,
|
|
1502
|
-
n_times_output=n_times_output,
|
|
1503
806
|
decay_functions=decay_functions,
|
|
807
|
+
n_times_output=n_times_output,
|
|
1504
808
|
)
|
|
1505
|
-
adj_frequency = hill_transformer.forward(frequency)
|
|
1506
|
-
rf_out = adstock_transformer.forward(reach * adj_frequency)
|
|
1507
|
-
|
|
1508
|
-
return rf_out
|
|
1509
809
|
|
|
1510
810
|
def populate_cached_properties(self):
|
|
1511
811
|
"""Eagerly activates all cached properties.
|
|
@@ -1515,6 +815,7 @@ class Meridian:
|
|
|
1515
815
|
internal state mutations are problematic, and so this method freezes the
|
|
1516
816
|
object's states before the computation graph is created.
|
|
1517
817
|
"""
|
|
818
|
+
self._model_context.populate_cached_properties()
|
|
1518
819
|
cls = self.__class__
|
|
1519
820
|
# "Freeze" all @cached_property attributes by simply accessing them (with
|
|
1520
821
|
# `getattr()`).
|