google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
meridian/model/model.py
CHANGED
|
@@ -16,8 +16,8 @@
|
|
|
16
16
|
|
|
17
17
|
import collections
|
|
18
18
|
from collections.abc import Mapping, Sequence
|
|
19
|
+
import dataclasses
|
|
19
20
|
import functools
|
|
20
|
-
import numbers
|
|
21
21
|
import os
|
|
22
22
|
import warnings
|
|
23
23
|
|
|
@@ -28,6 +28,8 @@ from meridian import constants
|
|
|
28
28
|
from meridian.data import input_data as data
|
|
29
29
|
from meridian.data import time_coordinates as tc
|
|
30
30
|
from meridian.model import adstock_hill
|
|
31
|
+
from meridian.model import context
|
|
32
|
+
from meridian.model import equations
|
|
31
33
|
from meridian.model import knots
|
|
32
34
|
from meridian.model import media
|
|
33
35
|
from meridian.model import posterior_sampler
|
|
@@ -76,27 +78,15 @@ def _warn_setting_national_args(**kwargs):
|
|
|
76
78
|
)
|
|
77
79
|
|
|
78
80
|
|
|
79
|
-
def _check_for_negative_effect(
|
|
80
|
-
dist: backend.tfd.Distribution, media_effects_dist: str
|
|
81
|
-
):
|
|
82
|
-
"""Checks for negative effect in the model."""
|
|
83
|
-
if (
|
|
84
|
-
media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
|
|
85
|
-
and np.any(dist.cdf(0)) > 0
|
|
86
|
-
):
|
|
87
|
-
raise ValueError(
|
|
88
|
-
"Media priors must have non-negative support when"
|
|
89
|
-
f' `media_effects_dist`="{media_effects_dist}". Found negative effect'
|
|
90
|
-
f" in {dist.name}."
|
|
91
|
-
)
|
|
92
|
-
|
|
93
|
-
|
|
94
81
|
class Meridian:
|
|
95
82
|
"""Contains the main functionality for fitting the Meridian MMM model.
|
|
96
83
|
|
|
97
84
|
Attributes:
|
|
98
85
|
input_data: An `InputData` object containing the input data for the model.
|
|
99
86
|
model_spec: A `ModelSpec` object containing the model specification.
|
|
87
|
+
model_context: A `ModelContext` object containing the model context.
|
|
88
|
+
model_equations: A `ModelEquations` object containing stateless mathematical
|
|
89
|
+
functions and utilities for Meridian MMM.
|
|
100
90
|
inference_data: A _mutable_ `arviz.InferenceData` object containing the
|
|
101
91
|
resulting data from fitting the model.
|
|
102
92
|
eda_engine: An `EDAEngine` object containing the EDA engine.
|
|
@@ -168,15 +158,17 @@ class Meridian:
|
|
|
168
158
|
) = None, # for deserializer use only
|
|
169
159
|
eda_spec: eda_spec_module.EDASpec = eda_spec_module.EDASpec(),
|
|
170
160
|
):
|
|
171
|
-
self._input_data = input_data
|
|
172
|
-
self._model_spec = model_spec if model_spec else spec.ModelSpec()
|
|
173
161
|
self._inference_data = (
|
|
174
162
|
inference_data if inference_data else az.InferenceData()
|
|
175
163
|
)
|
|
164
|
+
self._model_context = context.ModelContext(
|
|
165
|
+
input_data=input_data,
|
|
166
|
+
model_spec=model_spec if model_spec else spec.ModelSpec(),
|
|
167
|
+
)
|
|
168
|
+
self._model_equations = equations.ModelEquations(self._model_context)
|
|
176
169
|
|
|
177
170
|
self._eda_spec = eda_spec
|
|
178
171
|
|
|
179
|
-
self._validate_data_dependent_model_spec()
|
|
180
172
|
self._validate_injected_inference_data()
|
|
181
173
|
|
|
182
174
|
if self.is_national:
|
|
@@ -184,22 +176,23 @@ class Meridian:
|
|
|
184
176
|
media_effects_dist=self.model_spec.media_effects_dist,
|
|
185
177
|
unique_sigma_for_each_geo=self.model_spec.unique_sigma_for_each_geo,
|
|
186
178
|
)
|
|
187
|
-
self.
|
|
188
|
-
self._set_total_media_contribution_prior = False
|
|
189
|
-
self._validate_mroi_priors_non_revenue()
|
|
190
|
-
self._validate_roi_priors_non_revenue()
|
|
191
|
-
self._check_for_negative_effects()
|
|
192
|
-
self._validate_geo_invariants()
|
|
193
|
-
self._validate_time_invariants()
|
|
194
|
-
self._validate_kpi_transformer()
|
|
179
|
+
self._validate_kpi_variability()
|
|
195
180
|
|
|
196
181
|
@property
|
|
197
182
|
def input_data(self) -> data.InputData:
|
|
198
|
-
return self.
|
|
183
|
+
return self._model_context.input_data
|
|
199
184
|
|
|
200
185
|
@property
|
|
201
186
|
def model_spec(self) -> spec.ModelSpec:
|
|
202
|
-
return self.
|
|
187
|
+
return self._model_context.model_spec
|
|
188
|
+
|
|
189
|
+
@property
|
|
190
|
+
def model_context(self) -> context.ModelContext:
|
|
191
|
+
return self._model_context
|
|
192
|
+
|
|
193
|
+
@property
|
|
194
|
+
def model_equations(self) -> equations.ModelEquations:
|
|
195
|
+
return self._model_equations
|
|
203
196
|
|
|
204
197
|
@property
|
|
205
198
|
def inference_data(self) -> az.InferenceData:
|
|
@@ -207,190 +200,127 @@ class Meridian:
|
|
|
207
200
|
|
|
208
201
|
@functools.cached_property
|
|
209
202
|
def eda_engine(self) -> eda_engine.EDAEngine:
|
|
210
|
-
return eda_engine.EDAEngine(
|
|
203
|
+
return eda_engine.EDAEngine(
|
|
204
|
+
spec=self._eda_spec, model_context=self.model_context
|
|
205
|
+
)
|
|
211
206
|
|
|
212
207
|
@property
|
|
213
208
|
def eda_spec(self) -> eda_spec_module.EDASpec:
|
|
214
209
|
return self._eda_spec
|
|
215
210
|
|
|
216
211
|
@property
|
|
217
|
-
def eda_outcomes(self) ->
|
|
212
|
+
def eda_outcomes(self) -> eda_outcome.CriticalCheckEDAOutcomes:
|
|
218
213
|
return self.eda_engine.run_all_critical_checks()
|
|
219
214
|
|
|
220
|
-
@
|
|
215
|
+
@property
|
|
221
216
|
def media_tensors(self) -> media.MediaTensors:
|
|
222
|
-
return
|
|
217
|
+
return self._model_context.media_tensors
|
|
223
218
|
|
|
224
|
-
@
|
|
219
|
+
@property
|
|
225
220
|
def rf_tensors(self) -> media.RfTensors:
|
|
226
|
-
return
|
|
221
|
+
return self._model_context.rf_tensors
|
|
227
222
|
|
|
228
|
-
@
|
|
223
|
+
@property
|
|
229
224
|
def organic_media_tensors(self) -> media.OrganicMediaTensors:
|
|
230
|
-
return
|
|
225
|
+
return self._model_context.organic_media_tensors
|
|
231
226
|
|
|
232
|
-
@
|
|
227
|
+
@property
|
|
233
228
|
def organic_rf_tensors(self) -> media.OrganicRfTensors:
|
|
234
|
-
return
|
|
229
|
+
return self._model_context.organic_rf_tensors
|
|
235
230
|
|
|
236
|
-
@
|
|
231
|
+
@property
|
|
237
232
|
def kpi(self) -> backend.Tensor:
|
|
238
|
-
return
|
|
233
|
+
return self._model_context.kpi
|
|
239
234
|
|
|
240
|
-
@
|
|
235
|
+
@property
|
|
241
236
|
def revenue_per_kpi(self) -> backend.Tensor | None:
|
|
242
|
-
|
|
243
|
-
return None
|
|
244
|
-
return backend.to_tensor(
|
|
245
|
-
self.input_data.revenue_per_kpi, dtype=backend.float32
|
|
246
|
-
)
|
|
237
|
+
return self._model_context.revenue_per_kpi
|
|
247
238
|
|
|
248
|
-
@
|
|
239
|
+
@property
|
|
249
240
|
def controls(self) -> backend.Tensor | None:
|
|
250
|
-
|
|
251
|
-
return None
|
|
252
|
-
return backend.to_tensor(self.input_data.controls, dtype=backend.float32)
|
|
241
|
+
return self._model_context.controls
|
|
253
242
|
|
|
254
|
-
@
|
|
243
|
+
@property
|
|
255
244
|
def non_media_treatments(self) -> backend.Tensor | None:
|
|
256
|
-
|
|
257
|
-
return None
|
|
258
|
-
return backend.to_tensor(
|
|
259
|
-
self.input_data.non_media_treatments, dtype=backend.float32
|
|
260
|
-
)
|
|
245
|
+
return self._model_context.non_media_treatments
|
|
261
246
|
|
|
262
|
-
@
|
|
247
|
+
@property
|
|
263
248
|
def population(self) -> backend.Tensor:
|
|
264
|
-
return
|
|
249
|
+
return self._model_context.population
|
|
265
250
|
|
|
266
|
-
@
|
|
251
|
+
@property
|
|
267
252
|
def total_spend(self) -> backend.Tensor:
|
|
268
|
-
return
|
|
269
|
-
self.input_data.get_total_spend(), dtype=backend.float32
|
|
270
|
-
)
|
|
253
|
+
return self._model_context.total_spend
|
|
271
254
|
|
|
272
|
-
@
|
|
255
|
+
@property
|
|
273
256
|
def total_outcome(self) -> backend.Tensor:
|
|
274
|
-
return
|
|
275
|
-
self.input_data.get_total_outcome(), dtype=backend.float32
|
|
276
|
-
)
|
|
257
|
+
return self._model_context.total_outcome
|
|
277
258
|
|
|
278
259
|
@property
|
|
279
260
|
def n_geos(self) -> int:
|
|
280
|
-
return
|
|
261
|
+
return self._model_context.n_geos
|
|
281
262
|
|
|
282
263
|
@property
|
|
283
264
|
def n_media_channels(self) -> int:
|
|
284
|
-
|
|
285
|
-
return 0
|
|
286
|
-
return len(self.input_data.media_channel)
|
|
265
|
+
return self._model_context.n_media_channels
|
|
287
266
|
|
|
288
267
|
@property
|
|
289
268
|
def n_rf_channels(self) -> int:
|
|
290
|
-
|
|
291
|
-
return 0
|
|
292
|
-
return len(self.input_data.rf_channel)
|
|
269
|
+
return self._model_context.n_rf_channels
|
|
293
270
|
|
|
294
271
|
@property
|
|
295
272
|
def n_organic_media_channels(self) -> int:
|
|
296
|
-
|
|
297
|
-
return 0
|
|
298
|
-
return len(self.input_data.organic_media_channel)
|
|
273
|
+
return self._model_context.n_organic_media_channels
|
|
299
274
|
|
|
300
275
|
@property
|
|
301
276
|
def n_organic_rf_channels(self) -> int:
|
|
302
|
-
|
|
303
|
-
return 0
|
|
304
|
-
return len(self.input_data.organic_rf_channel)
|
|
277
|
+
return self._model_context.n_organic_rf_channels
|
|
305
278
|
|
|
306
279
|
@property
|
|
307
280
|
def n_controls(self) -> int:
|
|
308
|
-
|
|
309
|
-
return 0
|
|
310
|
-
return len(self.input_data.control_variable)
|
|
281
|
+
return self._model_context.n_controls
|
|
311
282
|
|
|
312
283
|
@property
|
|
313
284
|
def n_non_media_channels(self) -> int:
|
|
314
|
-
|
|
315
|
-
return 0
|
|
316
|
-
return len(self.input_data.non_media_channel)
|
|
285
|
+
return self._model_context.n_non_media_channels
|
|
317
286
|
|
|
318
287
|
@property
|
|
319
288
|
def n_times(self) -> int:
|
|
320
|
-
return
|
|
289
|
+
return self._model_context.n_times
|
|
321
290
|
|
|
322
291
|
@property
|
|
323
292
|
def n_media_times(self) -> int:
|
|
324
|
-
return
|
|
293
|
+
return self._model_context.n_media_times
|
|
325
294
|
|
|
326
295
|
@property
|
|
327
296
|
def is_national(self) -> bool:
|
|
328
|
-
return self.
|
|
297
|
+
return self._model_context.is_national
|
|
329
298
|
|
|
330
|
-
@
|
|
299
|
+
@property
|
|
331
300
|
def knot_info(self) -> knots.KnotInfo:
|
|
332
|
-
return
|
|
333
|
-
n_times=self.n_times,
|
|
334
|
-
knots=self.model_spec.knots,
|
|
335
|
-
enable_aks=self.model_spec.enable_aks,
|
|
336
|
-
data=self.input_data,
|
|
337
|
-
is_national=self.is_national,
|
|
338
|
-
)
|
|
301
|
+
return self._model_context.knot_info
|
|
339
302
|
|
|
340
|
-
@
|
|
303
|
+
@property
|
|
341
304
|
def controls_transformer(
|
|
342
305
|
self,
|
|
343
306
|
) -> transformers.CenteringAndScalingTransformer | None:
|
|
344
|
-
|
|
345
|
-
if self.controls is None:
|
|
346
|
-
return None
|
|
307
|
+
return self._model_context.controls_transformer
|
|
347
308
|
|
|
348
|
-
|
|
349
|
-
controls_population_scaling_id = backend.to_tensor(
|
|
350
|
-
self.model_spec.control_population_scaling_id, dtype=backend.bool_
|
|
351
|
-
)
|
|
352
|
-
else:
|
|
353
|
-
controls_population_scaling_id = None
|
|
354
|
-
|
|
355
|
-
return transformers.CenteringAndScalingTransformer(
|
|
356
|
-
tensor=self.controls,
|
|
357
|
-
population=self.population,
|
|
358
|
-
population_scaling_id=controls_population_scaling_id,
|
|
359
|
-
)
|
|
360
|
-
|
|
361
|
-
@functools.cached_property
|
|
309
|
+
@property
|
|
362
310
|
def non_media_transformer(
|
|
363
311
|
self,
|
|
364
312
|
) -> transformers.CenteringAndScalingTransformer | None:
|
|
365
|
-
|
|
366
|
-
if self.non_media_treatments is None:
|
|
367
|
-
return None
|
|
368
|
-
if self.model_spec.non_media_population_scaling_id is not None:
|
|
369
|
-
non_media_population_scaling_id = backend.to_tensor(
|
|
370
|
-
self.model_spec.non_media_population_scaling_id, dtype=backend.bool_
|
|
371
|
-
)
|
|
372
|
-
else:
|
|
373
|
-
non_media_population_scaling_id = None
|
|
313
|
+
return self._model_context.non_media_transformer
|
|
374
314
|
|
|
375
|
-
|
|
376
|
-
tensor=self.non_media_treatments,
|
|
377
|
-
population=self.population,
|
|
378
|
-
population_scaling_id=non_media_population_scaling_id,
|
|
379
|
-
)
|
|
380
|
-
|
|
381
|
-
@functools.cached_property
|
|
315
|
+
@property
|
|
382
316
|
def kpi_transformer(self) -> transformers.KpiTransformer:
|
|
383
|
-
return
|
|
317
|
+
return self._model_context.kpi_transformer
|
|
384
318
|
|
|
385
|
-
@
|
|
319
|
+
@property
|
|
386
320
|
def controls_scaled(self) -> backend.Tensor | None:
|
|
387
|
-
|
|
388
|
-
# If `controls` is defined, then `controls_transformer` is also defined.
|
|
389
|
-
return self.controls_transformer.forward(self.controls) # pytype: disable=attribute-error
|
|
390
|
-
else:
|
|
391
|
-
return None
|
|
321
|
+
return self._model_context.controls_scaled
|
|
392
322
|
|
|
393
|
-
@
|
|
323
|
+
@property
|
|
394
324
|
def non_media_treatments_normalized(self) -> backend.Tensor | None:
|
|
395
325
|
"""Normalized non-media treatments.
|
|
396
326
|
|
|
@@ -398,130 +328,56 @@ class Meridian:
|
|
|
398
328
|
`non_media_population_scaling_id` is `True`) and normalized by centering and
|
|
399
329
|
scaling with means and standard deviations.
|
|
400
330
|
"""
|
|
401
|
-
|
|
402
|
-
return self.non_media_transformer.forward(
|
|
403
|
-
self.non_media_treatments
|
|
404
|
-
) # pytype: disable=attribute-error
|
|
405
|
-
else:
|
|
406
|
-
return None
|
|
331
|
+
return self._model_context.non_media_treatments_normalized
|
|
407
332
|
|
|
408
|
-
@
|
|
333
|
+
@property
|
|
409
334
|
def kpi_scaled(self) -> backend.Tensor:
|
|
410
|
-
return self.
|
|
335
|
+
return self._model_context.kpi_scaled
|
|
411
336
|
|
|
412
|
-
@
|
|
337
|
+
@property
|
|
413
338
|
def media_effects_dist(self) -> str:
|
|
414
|
-
|
|
415
|
-
return constants.NATIONAL_MODEL_SPEC_ARGS[constants.MEDIA_EFFECTS_DIST]
|
|
416
|
-
else:
|
|
417
|
-
return self.model_spec.media_effects_dist
|
|
339
|
+
return self._model_context.media_effects_dist
|
|
418
340
|
|
|
419
|
-
@
|
|
341
|
+
@property
|
|
420
342
|
def unique_sigma_for_each_geo(self) -> bool:
|
|
421
|
-
|
|
422
|
-
# Should evaluate to False.
|
|
423
|
-
return constants.NATIONAL_MODEL_SPEC_ARGS[
|
|
424
|
-
constants.UNIQUE_SIGMA_FOR_EACH_GEO
|
|
425
|
-
]
|
|
426
|
-
else:
|
|
427
|
-
return self.model_spec.unique_sigma_for_each_geo
|
|
343
|
+
return self._model_context.unique_sigma_for_each_geo
|
|
428
344
|
|
|
429
|
-
@
|
|
345
|
+
@property
|
|
430
346
|
def baseline_geo_idx(self) -> int:
|
|
431
347
|
"""Returns the index of the baseline geo."""
|
|
432
|
-
|
|
433
|
-
if (
|
|
434
|
-
self.model_spec.baseline_geo < 0
|
|
435
|
-
or self.model_spec.baseline_geo >= self.n_geos
|
|
436
|
-
):
|
|
437
|
-
raise ValueError(
|
|
438
|
-
f"Baseline geo index {self.model_spec.baseline_geo} out of range"
|
|
439
|
-
f" [0, {self.n_geos - 1}]."
|
|
440
|
-
)
|
|
441
|
-
return self.model_spec.baseline_geo
|
|
442
|
-
elif isinstance(self.model_spec.baseline_geo, str):
|
|
443
|
-
# np.where returns a 1-D tuple, its first element is an array of found
|
|
444
|
-
# elements.
|
|
445
|
-
index = np.where(self.input_data.geo == self.model_spec.baseline_geo)[0]
|
|
446
|
-
if index.size == 0:
|
|
447
|
-
raise ValueError(
|
|
448
|
-
f"Baseline geo '{self.model_spec.baseline_geo}' not found."
|
|
449
|
-
)
|
|
450
|
-
# Geos are unique, so index is a 1-element array.
|
|
451
|
-
return index[0]
|
|
452
|
-
else:
|
|
453
|
-
return backend.argmax(self.population)
|
|
348
|
+
return self._model_context.baseline_geo_idx
|
|
454
349
|
|
|
455
|
-
@
|
|
350
|
+
@property
|
|
456
351
|
def holdout_id(self) -> backend.Tensor | None:
|
|
457
|
-
|
|
458
|
-
return None
|
|
459
|
-
tensor = backend.to_tensor(self.model_spec.holdout_id, dtype=backend.bool_)
|
|
460
|
-
return tensor[backend.newaxis, ...] if self.is_national else tensor
|
|
352
|
+
return self._model_context.holdout_id
|
|
461
353
|
|
|
462
|
-
@
|
|
354
|
+
@property
|
|
463
355
|
def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
|
|
464
356
|
"""Returns `AdstockDecaySpec` object with correctly mapped channels."""
|
|
465
|
-
|
|
466
|
-
return adstock_hill.AdstockDecaySpec.from_consistent_type(
|
|
467
|
-
self.model_spec.adstock_decay_spec
|
|
468
|
-
)
|
|
357
|
+
return self._model_context.adstock_decay_spec
|
|
469
358
|
|
|
470
|
-
|
|
471
|
-
return self._create_adstock_decay_functions_from_channel_map(
|
|
472
|
-
self.model_spec.adstock_decay_spec
|
|
473
|
-
)
|
|
474
|
-
except KeyError as e:
|
|
475
|
-
raise ValueError(
|
|
476
|
-
"Unrecognized channel names found in `adstock_decay_spec` keys"
|
|
477
|
-
f" {tuple(self.model_spec.adstock_decay_spec.keys())}. Keys should"
|
|
478
|
-
" either contain only channel_names"
|
|
479
|
-
f" {tuple(self.input_data.get_all_adstock_hill_channels().tolist())} or"
|
|
480
|
-
" be one or more of {'media', 'rf', 'organic_media',"
|
|
481
|
-
" 'organic_rf'}."
|
|
482
|
-
) from e
|
|
483
|
-
|
|
484
|
-
@functools.cached_property
|
|
359
|
+
@property
|
|
485
360
|
def prior_broadcast(self) -> prior_distribution.PriorDistribution:
|
|
486
361
|
"""Returns broadcasted `PriorDistribution` object."""
|
|
487
|
-
|
|
488
|
-
# Total spend can have 1, 2 or 3 dimensions. Aggregate by channel.
|
|
489
|
-
if len(total_spend.shape) == 1:
|
|
490
|
-
# Already aggregated by channel.
|
|
491
|
-
agg_total_spend = total_spend
|
|
492
|
-
elif len(total_spend.shape) == 2:
|
|
493
|
-
agg_total_spend = np.sum(total_spend, axis=(0,))
|
|
494
|
-
else:
|
|
495
|
-
agg_total_spend = np.sum(total_spend, axis=(0, 1))
|
|
496
|
-
|
|
497
|
-
return self.model_spec.prior.broadcast(
|
|
498
|
-
n_geos=self.n_geos,
|
|
499
|
-
n_media_channels=self.n_media_channels,
|
|
500
|
-
n_rf_channels=self.n_rf_channels,
|
|
501
|
-
n_organic_media_channels=self.n_organic_media_channels,
|
|
502
|
-
n_organic_rf_channels=self.n_organic_rf_channels,
|
|
503
|
-
n_controls=self.n_controls,
|
|
504
|
-
n_non_media_channels=self.n_non_media_channels,
|
|
505
|
-
unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
|
|
506
|
-
n_knots=self.knot_info.n_knots,
|
|
507
|
-
is_national=self.is_national,
|
|
508
|
-
set_total_media_contribution_prior=self._set_total_media_contribution_prior,
|
|
509
|
-
kpi=np.sum(self.input_data.kpi.values),
|
|
510
|
-
total_spend=agg_total_spend,
|
|
511
|
-
)
|
|
362
|
+
return self._model_context.prior_broadcast
|
|
512
363
|
|
|
513
364
|
@functools.cached_property
|
|
514
365
|
def prior_sampler_callable(self) -> prior_sampler.PriorDistributionSampler:
|
|
515
366
|
"""A `PriorDistributionSampler` callable bound to this model."""
|
|
516
|
-
return prior_sampler.PriorDistributionSampler(
|
|
367
|
+
return prior_sampler.PriorDistributionSampler(
|
|
368
|
+
model_context=self.model_context,
|
|
369
|
+
)
|
|
517
370
|
|
|
518
371
|
@functools.cached_property
|
|
519
372
|
def posterior_sampler_callable(
|
|
520
373
|
self,
|
|
521
374
|
) -> posterior_sampler.PosteriorMCMCSampler:
|
|
522
375
|
"""A `PosteriorMCMCSampler` callable bound to this model."""
|
|
523
|
-
return posterior_sampler.PosteriorMCMCSampler(
|
|
376
|
+
return posterior_sampler.PosteriorMCMCSampler(
|
|
377
|
+
model_context=self.model_context,
|
|
378
|
+
)
|
|
524
379
|
|
|
380
|
+
# TODO: Remove this method.
|
|
525
381
|
def compute_non_media_treatments_baseline(
|
|
526
382
|
self,
|
|
527
383
|
non_media_baseline_values: Sequence[str | float] | None = None,
|
|
@@ -544,70 +400,18 @@ class Meridian:
|
|
|
544
400
|
A tensor of shape `(n_non_media_channels,)` containing the
|
|
545
401
|
baseline values for each non-media treatment channel.
|
|
546
402
|
"""
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
self.population[:, backend.newaxis, backend.newaxis],
|
|
557
|
-
no_op_scaling_factor,
|
|
558
|
-
)
|
|
559
|
-
else:
|
|
560
|
-
scaling_factors = no_op_scaling_factor
|
|
561
|
-
|
|
562
|
-
non_media_treatments_population_scaled = backend.divide_no_nan(
|
|
563
|
-
self.non_media_treatments, scaling_factors
|
|
403
|
+
warnings.warn(
|
|
404
|
+
"Meridian.compute_non_media_treatments_baseline() is deprecated and"
|
|
405
|
+
" will be removed in a future version. Use"
|
|
406
|
+
" `ModelEquations.compute_non_media_treatments_baseline()` instead.",
|
|
407
|
+
DeprecationWarning,
|
|
408
|
+
stacklevel=2,
|
|
409
|
+
)
|
|
410
|
+
return self.model_equations.compute_non_media_treatments_baseline(
|
|
411
|
+
non_media_baseline_values=non_media_baseline_values
|
|
564
412
|
)
|
|
565
413
|
|
|
566
|
-
|
|
567
|
-
# If non_media_baseline_values is not provided, use the minimum
|
|
568
|
-
# value for each non_media treatment channel as the baseline.
|
|
569
|
-
non_media_baseline_values_filled = [
|
|
570
|
-
constants.NON_MEDIA_BASELINE_MIN
|
|
571
|
-
] * non_media_treatments_population_scaled.shape[-1]
|
|
572
|
-
else:
|
|
573
|
-
non_media_baseline_values_filled = non_media_baseline_values
|
|
574
|
-
|
|
575
|
-
if non_media_treatments_population_scaled.shape[-1] != len(
|
|
576
|
-
non_media_baseline_values_filled
|
|
577
|
-
):
|
|
578
|
-
raise ValueError(
|
|
579
|
-
"The number of non-media channels"
|
|
580
|
-
f" ({non_media_treatments_population_scaled.shape[-1]}) does not"
|
|
581
|
-
" match the number of baseline values"
|
|
582
|
-
f" ({len(non_media_baseline_values_filled)})."
|
|
583
|
-
)
|
|
584
|
-
|
|
585
|
-
baseline_list = []
|
|
586
|
-
for channel in range(non_media_treatments_population_scaled.shape[-1]):
|
|
587
|
-
baseline_value = non_media_baseline_values_filled[channel]
|
|
588
|
-
|
|
589
|
-
if baseline_value == constants.NON_MEDIA_BASELINE_MIN:
|
|
590
|
-
baseline_for_channel = backend.reduce_min(
|
|
591
|
-
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
592
|
-
)
|
|
593
|
-
elif baseline_value == constants.NON_MEDIA_BASELINE_MAX:
|
|
594
|
-
baseline_for_channel = backend.reduce_max(
|
|
595
|
-
non_media_treatments_population_scaled[..., channel], axis=[0, 1]
|
|
596
|
-
)
|
|
597
|
-
elif isinstance(baseline_value, numbers.Number):
|
|
598
|
-
baseline_for_channel = backend.to_tensor(
|
|
599
|
-
baseline_value, dtype=backend.float32
|
|
600
|
-
)
|
|
601
|
-
else:
|
|
602
|
-
raise ValueError(
|
|
603
|
-
f"Invalid non_media_baseline_values value: '{baseline_value}'. Only"
|
|
604
|
-
" float numbers and strings 'min' and 'max' are supported."
|
|
605
|
-
)
|
|
606
|
-
|
|
607
|
-
baseline_list.append(baseline_for_channel)
|
|
608
|
-
|
|
609
|
-
return backend.stack(baseline_list, axis=-1)
|
|
610
|
-
|
|
414
|
+
# TODO: Remove this method.
|
|
611
415
|
def expand_selected_time_dims(
|
|
612
416
|
self,
|
|
613
417
|
start_date: tc.Date = None,
|
|
@@ -635,12 +439,16 @@ class Meridian:
|
|
|
635
439
|
ValueError if `start_date` or `end_date` is not in the input data time
|
|
636
440
|
dimensions.
|
|
637
441
|
"""
|
|
638
|
-
|
|
442
|
+
warnings.warn(
|
|
443
|
+
"Meridian.expand_selected_time_dims() is deprecated and will be removed"
|
|
444
|
+
" in a future version. Use `ModelContext.expand_selected_time_dims()`"
|
|
445
|
+
" instead.",
|
|
446
|
+
DeprecationWarning,
|
|
447
|
+
stacklevel=2,
|
|
448
|
+
)
|
|
449
|
+
return self._model_context.expand_selected_time_dims(
|
|
639
450
|
start_date=start_date, end_date=end_date
|
|
640
451
|
)
|
|
641
|
-
if expanded is None:
|
|
642
|
-
return None
|
|
643
|
-
return [date.strftime(constants.DATE_FORMAT) for date in expanded]
|
|
644
452
|
|
|
645
453
|
def _validate_injected_inference_data(self):
|
|
646
454
|
"""Validates that the injected inference data has correct shapes.
|
|
@@ -752,427 +560,8 @@ class Meridian:
|
|
|
752
560
|
self.n_non_media_channels,
|
|
753
561
|
)
|
|
754
562
|
|
|
755
|
-
def
|
|
756
|
-
"""Validates
|
|
757
|
-
|
|
758
|
-
if (
|
|
759
|
-
self.model_spec.roi_calibration_period is not None
|
|
760
|
-
and self.model_spec.roi_calibration_period.shape
|
|
761
|
-
!= (
|
|
762
|
-
self.n_media_times,
|
|
763
|
-
self.n_media_channels,
|
|
764
|
-
)
|
|
765
|
-
):
|
|
766
|
-
raise ValueError(
|
|
767
|
-
"The shape of `roi_calibration_period`"
|
|
768
|
-
f" {self.model_spec.roi_calibration_period.shape} is different from"
|
|
769
|
-
f" `(n_media_times, n_media_channels) = ({self.n_media_times},"
|
|
770
|
-
f" {self.n_media_channels})`."
|
|
771
|
-
)
|
|
772
|
-
|
|
773
|
-
if (
|
|
774
|
-
self.model_spec.rf_roi_calibration_period is not None
|
|
775
|
-
and self.model_spec.rf_roi_calibration_period.shape
|
|
776
|
-
!= (
|
|
777
|
-
self.n_media_times,
|
|
778
|
-
self.n_rf_channels,
|
|
779
|
-
)
|
|
780
|
-
):
|
|
781
|
-
raise ValueError(
|
|
782
|
-
"The shape of `rf_roi_calibration_period`"
|
|
783
|
-
f" {self.model_spec.rf_roi_calibration_period.shape} is different"
|
|
784
|
-
f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
|
|
785
|
-
f" {self.n_rf_channels})`."
|
|
786
|
-
)
|
|
787
|
-
|
|
788
|
-
if self.model_spec.holdout_id is not None:
|
|
789
|
-
if self.is_national and (
|
|
790
|
-
self.model_spec.holdout_id.shape != (self.n_times,)
|
|
791
|
-
):
|
|
792
|
-
raise ValueError(
|
|
793
|
-
f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
|
|
794
|
-
f" different from `(n_times,) = ({self.n_times},)`."
|
|
795
|
-
)
|
|
796
|
-
elif not self.is_national and (
|
|
797
|
-
self.model_spec.holdout_id.shape
|
|
798
|
-
!= (
|
|
799
|
-
self.n_geos,
|
|
800
|
-
self.n_times,
|
|
801
|
-
)
|
|
802
|
-
):
|
|
803
|
-
raise ValueError(
|
|
804
|
-
f"The shape of `holdout_id` {self.model_spec.holdout_id.shape} is"
|
|
805
|
-
f" different from `(n_geos, n_times) = ({self.n_geos},"
|
|
806
|
-
f" {self.n_times})`."
|
|
807
|
-
)
|
|
808
|
-
|
|
809
|
-
if self.model_spec.control_population_scaling_id is not None and (
|
|
810
|
-
self.model_spec.control_population_scaling_id.shape
|
|
811
|
-
!= (self.n_controls,)
|
|
812
|
-
):
|
|
813
|
-
raise ValueError(
|
|
814
|
-
"The shape of `control_population_scaling_id`"
|
|
815
|
-
f" {self.model_spec.control_population_scaling_id.shape} is different"
|
|
816
|
-
f" from `(n_controls,) = ({self.n_controls},)`."
|
|
817
|
-
)
|
|
818
|
-
|
|
819
|
-
if self.model_spec.non_media_population_scaling_id is not None and (
|
|
820
|
-
self.model_spec.non_media_population_scaling_id.shape
|
|
821
|
-
!= (self.n_non_media_channels,)
|
|
822
|
-
):
|
|
823
|
-
raise ValueError(
|
|
824
|
-
"The shape of `non_media_population_scaling_id`"
|
|
825
|
-
f" {self.model_spec.non_media_population_scaling_id.shape} is"
|
|
826
|
-
" different from `(n_non_media_channels,) ="
|
|
827
|
-
f" ({self.n_non_media_channels},)`."
|
|
828
|
-
)
|
|
829
|
-
|
|
830
|
-
def _create_adstock_decay_functions_from_channel_map(
|
|
831
|
-
self, channel_function_map: Mapping[str, str]
|
|
832
|
-
) -> adstock_hill.AdstockDecaySpec:
|
|
833
|
-
"""Create `AdstockDecaySpec` from mapping from channels to decay functions."""
|
|
834
|
-
|
|
835
|
-
for channel in channel_function_map:
|
|
836
|
-
if channel not in self.input_data.get_all_adstock_hill_channels():
|
|
837
|
-
raise KeyError(f"Channel {channel} not found in data.")
|
|
838
|
-
|
|
839
|
-
if self.input_data.media_channel is not None:
|
|
840
|
-
media_channel_builder = self.input_data.get_paid_media_channels_argument_builder().with_default_value(
|
|
841
|
-
constants.GEOMETRIC_DECAY
|
|
842
|
-
)
|
|
843
|
-
media_adstock_function = media_channel_builder(**channel_function_map)
|
|
844
|
-
else:
|
|
845
|
-
media_adstock_function = constants.GEOMETRIC_DECAY
|
|
846
|
-
|
|
847
|
-
if self.input_data.rf_channel is not None:
|
|
848
|
-
rf_channel_builder = self.input_data.get_paid_rf_channels_argument_builder().with_default_value(
|
|
849
|
-
constants.GEOMETRIC_DECAY
|
|
850
|
-
)
|
|
851
|
-
rf_adstock_function = rf_channel_builder(**channel_function_map)
|
|
852
|
-
else:
|
|
853
|
-
rf_adstock_function = constants.GEOMETRIC_DECAY
|
|
854
|
-
|
|
855
|
-
if self.input_data.organic_media_channel is not None:
|
|
856
|
-
organic_media_channel_builder = self.input_data.get_organic_media_channels_argument_builder().with_default_value(
|
|
857
|
-
constants.GEOMETRIC_DECAY
|
|
858
|
-
)
|
|
859
|
-
organic_media_adstock_function = organic_media_channel_builder(
|
|
860
|
-
**channel_function_map
|
|
861
|
-
)
|
|
862
|
-
else:
|
|
863
|
-
organic_media_adstock_function = constants.GEOMETRIC_DECAY
|
|
864
|
-
|
|
865
|
-
if self.input_data.organic_rf_channel is not None:
|
|
866
|
-
organic_rf_channel_builder = self.input_data.get_organic_rf_channels_argument_builder().with_default_value(
|
|
867
|
-
constants.GEOMETRIC_DECAY
|
|
868
|
-
)
|
|
869
|
-
organic_rf_adstock_function = organic_rf_channel_builder(
|
|
870
|
-
**channel_function_map
|
|
871
|
-
)
|
|
872
|
-
else:
|
|
873
|
-
organic_rf_adstock_function = constants.GEOMETRIC_DECAY
|
|
874
|
-
|
|
875
|
-
return adstock_hill.AdstockDecaySpec(
|
|
876
|
-
media=media_adstock_function,
|
|
877
|
-
rf=rf_adstock_function,
|
|
878
|
-
organic_media=organic_media_adstock_function,
|
|
879
|
-
organic_rf=organic_rf_adstock_function,
|
|
880
|
-
)
|
|
881
|
-
|
|
882
|
-
def _warn_setting_ignored_priors(self):
|
|
883
|
-
"""Raises a warning if ignored priors are set."""
|
|
884
|
-
default_distribution = prior_distribution.PriorDistribution()
|
|
885
|
-
for ignored_priors_dict, prior_type, prior_type_name in (
|
|
886
|
-
(
|
|
887
|
-
constants.IGNORED_PRIORS_MEDIA,
|
|
888
|
-
self.model_spec.effective_media_prior_type,
|
|
889
|
-
"media_prior_type",
|
|
890
|
-
),
|
|
891
|
-
(
|
|
892
|
-
constants.IGNORED_PRIORS_RF,
|
|
893
|
-
self.model_spec.effective_rf_prior_type,
|
|
894
|
-
"rf_prior_type",
|
|
895
|
-
),
|
|
896
|
-
):
|
|
897
|
-
ignored_custom_priors = []
|
|
898
|
-
for prior in ignored_priors_dict.get(prior_type, []):
|
|
899
|
-
self_prior = getattr(self.model_spec.prior, prior)
|
|
900
|
-
default_prior = getattr(default_distribution, prior)
|
|
901
|
-
if not prior_distribution.distributions_are_equal(
|
|
902
|
-
self_prior, default_prior
|
|
903
|
-
):
|
|
904
|
-
ignored_custom_priors.append(prior)
|
|
905
|
-
if ignored_custom_priors:
|
|
906
|
-
ignored_priors_str = ", ".join(ignored_custom_priors)
|
|
907
|
-
warnings.warn(
|
|
908
|
-
f"Custom prior(s) `{ignored_priors_str}` are ignored when"
|
|
909
|
-
f' `{prior_type_name}` is set to "{prior_type}".'
|
|
910
|
-
)
|
|
911
|
-
|
|
912
|
-
def _validate_mroi_priors_non_revenue(self):
|
|
913
|
-
"""Validates mroi priors in the non-revenue outcome case."""
|
|
914
|
-
if (
|
|
915
|
-
self.input_data.kpi_type == constants.NON_REVENUE
|
|
916
|
-
and self.input_data.revenue_per_kpi is None
|
|
917
|
-
):
|
|
918
|
-
default_distribution = prior_distribution.PriorDistribution()
|
|
919
|
-
if (
|
|
920
|
-
self.n_media_channels > 0
|
|
921
|
-
and (
|
|
922
|
-
self.model_spec.effective_media_prior_type
|
|
923
|
-
== constants.TREATMENT_PRIOR_TYPE_MROI
|
|
924
|
-
)
|
|
925
|
-
and prior_distribution.distributions_are_equal(
|
|
926
|
-
self.model_spec.prior.mroi_m, default_distribution.mroi_m
|
|
927
|
-
)
|
|
928
|
-
):
|
|
929
|
-
raise ValueError(
|
|
930
|
-
f"Custom priors should be set on `{constants.MROI_M}` when"
|
|
931
|
-
' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
|
|
932
|
-
" kpi data is missing."
|
|
933
|
-
)
|
|
934
|
-
if (
|
|
935
|
-
self.n_rf_channels > 0
|
|
936
|
-
and (
|
|
937
|
-
self.model_spec.effective_rf_prior_type
|
|
938
|
-
== constants.TREATMENT_PRIOR_TYPE_MROI
|
|
939
|
-
)
|
|
940
|
-
and prior_distribution.distributions_are_equal(
|
|
941
|
-
self.model_spec.prior.mroi_rf, default_distribution.mroi_rf
|
|
942
|
-
)
|
|
943
|
-
):
|
|
944
|
-
raise ValueError(
|
|
945
|
-
f"Custom priors should be set on `{constants.MROI_RF}` when"
|
|
946
|
-
' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
|
|
947
|
-
" data is missing."
|
|
948
|
-
)
|
|
949
|
-
|
|
950
|
-
def _validate_roi_priors_non_revenue(self):
|
|
951
|
-
"""Validates roi priors in the non-revenue outcome case."""
|
|
952
|
-
if (
|
|
953
|
-
self.input_data.kpi_type == constants.NON_REVENUE
|
|
954
|
-
and self.input_data.revenue_per_kpi is None
|
|
955
|
-
):
|
|
956
|
-
default_distribution = prior_distribution.PriorDistribution()
|
|
957
|
-
default_roi_m_used = (
|
|
958
|
-
self.model_spec.effective_media_prior_type
|
|
959
|
-
== constants.TREATMENT_PRIOR_TYPE_ROI
|
|
960
|
-
and prior_distribution.distributions_are_equal(
|
|
961
|
-
self.model_spec.prior.roi_m, default_distribution.roi_m
|
|
962
|
-
)
|
|
963
|
-
)
|
|
964
|
-
default_roi_rf_used = (
|
|
965
|
-
self.model_spec.effective_rf_prior_type
|
|
966
|
-
== constants.TREATMENT_PRIOR_TYPE_ROI
|
|
967
|
-
and prior_distribution.distributions_are_equal(
|
|
968
|
-
self.model_spec.prior.roi_rf, default_distribution.roi_rf
|
|
969
|
-
)
|
|
970
|
-
)
|
|
971
|
-
# If ROI priors are used with the default prior distribution for all paid
|
|
972
|
-
# channels (media and RF), then use the "total paid media contribution
|
|
973
|
-
# prior" procedure.
|
|
974
|
-
if (
|
|
975
|
-
(default_roi_m_used and default_roi_rf_used)
|
|
976
|
-
or (self.n_media_channels == 0 and default_roi_rf_used)
|
|
977
|
-
or (self.n_rf_channels == 0 and default_roi_m_used)
|
|
978
|
-
):
|
|
979
|
-
self._set_total_media_contribution_prior = True
|
|
980
|
-
warnings.warn(
|
|
981
|
-
"Consider setting custom ROI priors, as kpi_type was specified as"
|
|
982
|
-
" `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
|
|
983
|
-
" total media contribution prior will be used with"
|
|
984
|
-
f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
|
|
985
|
-
" documentation available at "
|
|
986
|
-
" https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
|
|
987
|
-
)
|
|
988
|
-
elif self.n_media_channels > 0 and default_roi_m_used:
|
|
989
|
-
raise ValueError(
|
|
990
|
-
f"Custom priors should be set on `{constants.ROI_M}` when"
|
|
991
|
-
' `media_prior_type` is "roi", custom priors are assigned on'
|
|
992
|
-
' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
|
|
993
|
-
" non-revenue and revenue per kpi data is missing."
|
|
994
|
-
)
|
|
995
|
-
elif self.n_rf_channels > 0 and default_roi_rf_used:
|
|
996
|
-
raise ValueError(
|
|
997
|
-
f"Custom priors should be set on `{constants.ROI_RF}` when"
|
|
998
|
-
' `rf_prior_type` is "roi", custom priors are assigned on'
|
|
999
|
-
' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
|
|
1000
|
-
" non-revenue and revenue per kpi data is missing."
|
|
1001
|
-
)
|
|
1002
|
-
|
|
1003
|
-
def _check_for_negative_effects(self):
|
|
1004
|
-
prior = self.model_spec.prior
|
|
1005
|
-
if self.n_media_channels > 0:
|
|
1006
|
-
_check_for_negative_effect(prior.roi_m, self.media_effects_dist)
|
|
1007
|
-
_check_for_negative_effect(prior.mroi_m, self.media_effects_dist)
|
|
1008
|
-
if self.n_rf_channels > 0:
|
|
1009
|
-
_check_for_negative_effect(prior.roi_rf, self.media_effects_dist)
|
|
1010
|
-
_check_for_negative_effect(prior.mroi_rf, self.media_effects_dist)
|
|
1011
|
-
|
|
1012
|
-
def _validate_geo_invariants(self):
|
|
1013
|
-
"""Validates non-national model invariants."""
|
|
1014
|
-
if self.is_national:
|
|
1015
|
-
return
|
|
1016
|
-
|
|
1017
|
-
if self.input_data.controls is not None:
|
|
1018
|
-
self._check_if_no_geo_variation(
|
|
1019
|
-
self.controls_scaled,
|
|
1020
|
-
constants.CONTROLS,
|
|
1021
|
-
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
1022
|
-
)
|
|
1023
|
-
if self.input_data.non_media_treatments is not None:
|
|
1024
|
-
self._check_if_no_geo_variation(
|
|
1025
|
-
self.non_media_treatments_normalized,
|
|
1026
|
-
constants.NON_MEDIA_TREATMENTS,
|
|
1027
|
-
self.input_data.non_media_treatments.coords[
|
|
1028
|
-
constants.NON_MEDIA_CHANNEL
|
|
1029
|
-
].values,
|
|
1030
|
-
)
|
|
1031
|
-
if self.input_data.media is not None:
|
|
1032
|
-
self._check_if_no_geo_variation(
|
|
1033
|
-
self.media_tensors.media_scaled,
|
|
1034
|
-
constants.MEDIA,
|
|
1035
|
-
self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
|
|
1036
|
-
)
|
|
1037
|
-
if self.input_data.reach is not None:
|
|
1038
|
-
self._check_if_no_geo_variation(
|
|
1039
|
-
self.rf_tensors.reach_scaled,
|
|
1040
|
-
constants.REACH,
|
|
1041
|
-
self.input_data.reach.coords[constants.RF_CHANNEL].values,
|
|
1042
|
-
)
|
|
1043
|
-
if self.input_data.organic_media is not None:
|
|
1044
|
-
self._check_if_no_geo_variation(
|
|
1045
|
-
self.organic_media_tensors.organic_media_scaled,
|
|
1046
|
-
"organic_media",
|
|
1047
|
-
self.input_data.organic_media.coords[
|
|
1048
|
-
constants.ORGANIC_MEDIA_CHANNEL
|
|
1049
|
-
].values,
|
|
1050
|
-
)
|
|
1051
|
-
if self.input_data.organic_reach is not None:
|
|
1052
|
-
self._check_if_no_geo_variation(
|
|
1053
|
-
self.organic_rf_tensors.organic_reach_scaled,
|
|
1054
|
-
"organic_reach",
|
|
1055
|
-
self.input_data.organic_reach.coords[
|
|
1056
|
-
constants.ORGANIC_RF_CHANNEL
|
|
1057
|
-
].values,
|
|
1058
|
-
)
|
|
1059
|
-
|
|
1060
|
-
def _check_if_no_geo_variation(
|
|
1061
|
-
self,
|
|
1062
|
-
scaled_data: backend.Tensor,
|
|
1063
|
-
data_name: str,
|
|
1064
|
-
data_dims: Sequence[str],
|
|
1065
|
-
epsilon=1e-4,
|
|
1066
|
-
):
|
|
1067
|
-
"""Raise an error if `n_knots == n_time` and data lacks geo variation."""
|
|
1068
|
-
|
|
1069
|
-
# Result shape: [n, d], where d is the number of axes of condition.
|
|
1070
|
-
col_idx_full = backend.get_indices_where(
|
|
1071
|
-
backend.reduce_std(scaled_data, axis=0) < epsilon
|
|
1072
|
-
)[:, 1]
|
|
1073
|
-
col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
|
|
1074
|
-
# We use the shape of scaled_data (instead of `n_time`) because the data may
|
|
1075
|
-
# be padded to account for lagged effects.
|
|
1076
|
-
data_n_time = scaled_data.shape[1]
|
|
1077
|
-
mask = backend.equal(counts, data_n_time)
|
|
1078
|
-
col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
|
|
1079
|
-
dims_bad = backend.gather(data_dims, col_idx_bad)
|
|
1080
|
-
|
|
1081
|
-
if col_idx_bad.shape[0] and self.knot_info.n_knots == self.n_times:
|
|
1082
|
-
raise ValueError(
|
|
1083
|
-
f"The following {data_name} variables do not vary across geos, making"
|
|
1084
|
-
f" a model with n_knots=n_time unidentifiable: {dims_bad}. This can"
|
|
1085
|
-
" lead to poor model convergence. Since these variables only vary"
|
|
1086
|
-
" across time and not across geo, they are collinear with time and"
|
|
1087
|
-
" redundant in a model with a parameter for each time period. To"
|
|
1088
|
-
" address this, you can either: (1) decrease the number of knots"
|
|
1089
|
-
" (n_knots < n_time), or (2) drop the listed variables that do not"
|
|
1090
|
-
" vary across geos."
|
|
1091
|
-
)
|
|
1092
|
-
|
|
1093
|
-
def _validate_time_invariants(self):
|
|
1094
|
-
"""Validates model time invariants."""
|
|
1095
|
-
if self.input_data.controls is not None:
|
|
1096
|
-
self._check_if_no_time_variation(
|
|
1097
|
-
self.controls_scaled,
|
|
1098
|
-
constants.CONTROLS,
|
|
1099
|
-
self.input_data.controls.coords[constants.CONTROL_VARIABLE].values,
|
|
1100
|
-
)
|
|
1101
|
-
if self.input_data.non_media_treatments is not None:
|
|
1102
|
-
self._check_if_no_time_variation(
|
|
1103
|
-
self.non_media_treatments_normalized,
|
|
1104
|
-
constants.NON_MEDIA_TREATMENTS,
|
|
1105
|
-
self.input_data.non_media_treatments.coords[
|
|
1106
|
-
constants.NON_MEDIA_CHANNEL
|
|
1107
|
-
].values,
|
|
1108
|
-
)
|
|
1109
|
-
if self.input_data.media is not None:
|
|
1110
|
-
self._check_if_no_time_variation(
|
|
1111
|
-
self.media_tensors.media_scaled,
|
|
1112
|
-
constants.MEDIA,
|
|
1113
|
-
self.input_data.media.coords[constants.MEDIA_CHANNEL].values,
|
|
1114
|
-
)
|
|
1115
|
-
if self.input_data.reach is not None:
|
|
1116
|
-
self._check_if_no_time_variation(
|
|
1117
|
-
self.rf_tensors.reach_scaled,
|
|
1118
|
-
constants.REACH,
|
|
1119
|
-
self.input_data.reach.coords[constants.RF_CHANNEL].values,
|
|
1120
|
-
)
|
|
1121
|
-
if self.input_data.organic_media is not None:
|
|
1122
|
-
self._check_if_no_time_variation(
|
|
1123
|
-
self.organic_media_tensors.organic_media_scaled,
|
|
1124
|
-
constants.ORGANIC_MEDIA,
|
|
1125
|
-
self.input_data.organic_media.coords[
|
|
1126
|
-
constants.ORGANIC_MEDIA_CHANNEL
|
|
1127
|
-
].values,
|
|
1128
|
-
)
|
|
1129
|
-
if self.input_data.organic_reach is not None:
|
|
1130
|
-
self._check_if_no_time_variation(
|
|
1131
|
-
self.organic_rf_tensors.organic_reach_scaled,
|
|
1132
|
-
constants.ORGANIC_REACH,
|
|
1133
|
-
self.input_data.organic_reach.coords[
|
|
1134
|
-
constants.ORGANIC_RF_CHANNEL
|
|
1135
|
-
].values,
|
|
1136
|
-
)
|
|
1137
|
-
|
|
1138
|
-
def _check_if_no_time_variation(
|
|
1139
|
-
self,
|
|
1140
|
-
scaled_data: backend.Tensor,
|
|
1141
|
-
data_name: str,
|
|
1142
|
-
data_dims: Sequence[str],
|
|
1143
|
-
epsilon=1e-4,
|
|
1144
|
-
):
|
|
1145
|
-
"""Raise an error if data lacks time variation."""
|
|
1146
|
-
|
|
1147
|
-
# Result shape: [n, d], where d is the number of axes of condition.
|
|
1148
|
-
col_idx_full = backend.get_indices_where(
|
|
1149
|
-
backend.reduce_std(scaled_data, axis=1) < epsilon
|
|
1150
|
-
)[:, 1]
|
|
1151
|
-
col_idx_unique, _, counts = backend.unique_with_counts(col_idx_full)
|
|
1152
|
-
mask = backend.equal(counts, self.n_geos)
|
|
1153
|
-
col_idx_bad = backend.boolean_mask(col_idx_unique, mask)
|
|
1154
|
-
dims_bad = backend.gather(data_dims, col_idx_bad)
|
|
1155
|
-
if col_idx_bad.shape[0]:
|
|
1156
|
-
if self.is_national:
|
|
1157
|
-
raise ValueError(
|
|
1158
|
-
f"The following {data_name} variables do not vary across time,"
|
|
1159
|
-
" which is equivalent to no signal at all in a national model:"
|
|
1160
|
-
f" {dims_bad}. This can lead to poor model convergence. To address"
|
|
1161
|
-
" this, drop the listed variables that do not vary across time."
|
|
1162
|
-
)
|
|
1163
|
-
else:
|
|
1164
|
-
raise ValueError(
|
|
1165
|
-
f"The following {data_name} variables do not vary across time,"
|
|
1166
|
-
f" making a model with geo main effects unidentifiable: {dims_bad}."
|
|
1167
|
-
" This can lead to poor model convergence. Since these variables"
|
|
1168
|
-
" only vary across geo and not across time, they are collinear"
|
|
1169
|
-
" with geo and redundant in a model with geo main effects. To"
|
|
1170
|
-
" address this, drop the listed variables that do not vary across"
|
|
1171
|
-
" time."
|
|
1172
|
-
)
|
|
1173
|
-
|
|
1174
|
-
def _validate_kpi_transformer(self):
|
|
1175
|
-
"""Validates the KPI transformer."""
|
|
563
|
+
def _validate_kpi_variability(self):
|
|
564
|
+
"""Validates the KPI variability."""
|
|
1176
565
|
if self.eda_engine.kpi_has_variability:
|
|
1177
566
|
return
|
|
1178
567
|
kpi = self.eda_engine.kpi_scaled_da.name
|
|
@@ -1227,6 +616,7 @@ class Meridian:
|
|
|
1227
616
|
f' "{self.model_spec.non_media_treatments_prior_type}".'
|
|
1228
617
|
)
|
|
1229
618
|
|
|
619
|
+
# TODO: Remove this method.
|
|
1230
620
|
def linear_predictor_counterfactual_difference_media(
|
|
1231
621
|
self,
|
|
1232
622
|
media_transformed: backend.Tensor,
|
|
@@ -1255,21 +645,24 @@ class Meridian:
|
|
|
1255
645
|
The linear predictor difference between the treatment variable and its
|
|
1256
646
|
counterfactual.
|
|
1257
647
|
"""
|
|
1258
|
-
|
|
1259
|
-
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
decay_functions=self.adstock_decay_spec.media,
|
|
648
|
+
warnings.warn(
|
|
649
|
+
"Meridian.linear_predictor_counterfactual_difference_media() is"
|
|
650
|
+
" deprecated and will be removed in a future version. Use "
|
|
651
|
+
"`ModelEquations.linear_predictor_counterfactual_difference_media()`"
|
|
652
|
+
" instead.",
|
|
653
|
+
DeprecationWarning,
|
|
654
|
+
stacklevel=2,
|
|
1266
655
|
)
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
656
|
+
return (
|
|
657
|
+
self.model_equations.linear_predictor_counterfactual_difference_media(
|
|
658
|
+
media_transformed=media_transformed,
|
|
659
|
+
alpha_m=alpha_m,
|
|
660
|
+
ec_m=ec_m,
|
|
661
|
+
slope_m=slope_m,
|
|
662
|
+
)
|
|
1271
663
|
)
|
|
1272
664
|
|
|
665
|
+
# TODO: Remove this method.
|
|
1273
666
|
def linear_predictor_counterfactual_difference_rf(
|
|
1274
667
|
self,
|
|
1275
668
|
rf_transformed: backend.Tensor,
|
|
@@ -1298,20 +691,21 @@ class Meridian:
|
|
|
1298
691
|
The linear predictor difference between the treatment variable and its
|
|
1299
692
|
counterfactual.
|
|
1300
693
|
"""
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
694
|
+
warnings.warn(
|
|
695
|
+
"Meridian.linear_predictor_counterfactual_difference_rf() is deprecated"
|
|
696
|
+
" and will be removed in a future version. Use `ModelEquations."
|
|
697
|
+
"linear_predictor_counterfactual_difference_rf()` instead.",
|
|
698
|
+
DeprecationWarning,
|
|
699
|
+
stacklevel=2,
|
|
700
|
+
)
|
|
701
|
+
return self.model_equations.linear_predictor_counterfactual_difference_rf(
|
|
702
|
+
rf_transformed=rf_transformed,
|
|
703
|
+
alpha_rf=alpha_rf,
|
|
704
|
+
ec_rf=ec_rf,
|
|
705
|
+
slope_rf=slope_rf,
|
|
1310
706
|
)
|
|
1311
|
-
# Absolute values is needed because the difference is negative for mROI
|
|
1312
|
-
# priors and positive for ROI and contribution priors.
|
|
1313
|
-
return backend.absolute(rf_transformed - rf_transformed_counterfactual)
|
|
1314
707
|
|
|
708
|
+
# TODO: Remove this method.
|
|
1315
709
|
def calculate_beta_x(
|
|
1316
710
|
self,
|
|
1317
711
|
is_non_media: bool,
|
|
@@ -1357,45 +751,22 @@ class Meridian:
|
|
|
1357
751
|
The coefficient mean parameter of the treatment variable, which has
|
|
1358
752
|
dimension equal to the number of treatment channels..
|
|
1359
753
|
"""
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
if self.revenue_per_kpi is None:
|
|
1367
|
-
revenue_per_kpi = backend.ones(
|
|
1368
|
-
[self.n_geos, self.n_times], dtype=backend.float32
|
|
1369
|
-
)
|
|
1370
|
-
else:
|
|
1371
|
-
revenue_per_kpi = self.revenue_per_kpi
|
|
1372
|
-
incremental_outcome_gx_over_beta_gx = backend.einsum(
|
|
1373
|
-
"...gtx,gt,g,->...gx",
|
|
1374
|
-
linear_predictor_counterfactual_difference,
|
|
1375
|
-
revenue_per_kpi,
|
|
1376
|
-
self.population,
|
|
1377
|
-
self.kpi_transformer.population_scaled_stdev,
|
|
754
|
+
warnings.warn(
|
|
755
|
+
"Meridian.calculate_beta_x() is deprecated and will be removed in a"
|
|
756
|
+
" future version. Use `ModelEquations.calculate_beta_x()`"
|
|
757
|
+
" instead.",
|
|
758
|
+
DeprecationWarning,
|
|
759
|
+
stacklevel=2,
|
|
1378
760
|
)
|
|
1379
|
-
|
|
1380
|
-
|
|
1381
|
-
|
|
1382
|
-
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
)
|
|
1386
|
-
denominator_term_x = backend.einsum(
|
|
1387
|
-
"...gx->...x", incremental_outcome_gx_over_beta_gx
|
|
1388
|
-
)
|
|
1389
|
-
return (incremental_outcome_x - numerator_term_x) / denominator_term_x
|
|
1390
|
-
# For log-normal random effects, beta_x and eta_x are not mean & std.
|
|
1391
|
-
# The parameterization is beta_gx ~ exp(beta_x + eta_x * N(0, 1)).
|
|
1392
|
-
denominator_term_x = backend.einsum(
|
|
1393
|
-
"...gx,...gx->...x",
|
|
1394
|
-
incremental_outcome_gx_over_beta_gx,
|
|
1395
|
-
backend.exp(beta_gx_dev * eta_x[..., backend.newaxis, :]),
|
|
761
|
+
return self.model_equations.calculate_beta_x(
|
|
762
|
+
is_non_media=is_non_media,
|
|
763
|
+
incremental_outcome_x=incremental_outcome_x,
|
|
764
|
+
linear_predictor_counterfactual_difference=linear_predictor_counterfactual_difference,
|
|
765
|
+
eta_x=eta_x,
|
|
766
|
+
beta_gx_dev=beta_gx_dev,
|
|
1396
767
|
)
|
|
1397
|
-
return backend.log(incremental_outcome_x) - backend.log(denominator_term_x)
|
|
1398
768
|
|
|
769
|
+
# TODO: Remove this method.
|
|
1399
770
|
def adstock_hill_media(
|
|
1400
771
|
self,
|
|
1401
772
|
media: backend.Tensor, # pylint: disable=redefined-outer-name
|
|
@@ -1426,34 +797,23 @@ class Meridian:
|
|
|
1426
797
|
Tensor with dimensions `[..., n_geos, n_times, n_media_channels]`
|
|
1427
798
|
representing Adstock and Hill-transformed media.
|
|
1428
799
|
"""
|
|
1429
|
-
|
|
1430
|
-
|
|
1431
|
-
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
)
|
|
1436
|
-
adstock_transformer = adstock_hill.AdstockTransformer(
|
|
1437
|
-
alpha=alpha,
|
|
1438
|
-
max_lag=self.model_spec.max_lag,
|
|
1439
|
-
n_times_output=n_times_output,
|
|
1440
|
-
decay_functions=decay_functions,
|
|
800
|
+
warnings.warn(
|
|
801
|
+
"Meridian.adstock_hill_media() is deprecated and will be removed in a"
|
|
802
|
+
" future version. Use `ModelEquations.adstock_hill_media()`"
|
|
803
|
+
" instead.",
|
|
804
|
+
DeprecationWarning,
|
|
805
|
+
stacklevel=2,
|
|
1441
806
|
)
|
|
1442
|
-
|
|
807
|
+
return self.model_equations.adstock_hill_media(
|
|
808
|
+
media=media,
|
|
809
|
+
alpha=alpha,
|
|
1443
810
|
ec=ec,
|
|
1444
811
|
slope=slope,
|
|
812
|
+
decay_functions=decay_functions,
|
|
813
|
+
n_times_output=n_times_output,
|
|
1445
814
|
)
|
|
1446
|
-
transformers_list = (
|
|
1447
|
-
[hill_transformer, adstock_transformer]
|
|
1448
|
-
if self.model_spec.hill_before_adstock
|
|
1449
|
-
else [adstock_transformer, hill_transformer]
|
|
1450
|
-
)
|
|
1451
|
-
|
|
1452
|
-
media_out = media
|
|
1453
|
-
for transformer in transformers_list:
|
|
1454
|
-
media_out = transformer.forward(media_out)
|
|
1455
|
-
return media_out
|
|
1456
815
|
|
|
816
|
+
# TODO: Remove this method.
|
|
1457
817
|
def adstock_hill_rf(
|
|
1458
818
|
self,
|
|
1459
819
|
reach: backend.Tensor,
|
|
@@ -1485,27 +845,22 @@ class Meridian:
|
|
|
1485
845
|
Tensor with dimensions `[..., n_geos, n_times, n_rf_channels]`
|
|
1486
846
|
representing Hill and Adstock-transformed RF.
|
|
1487
847
|
"""
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
)
|
|
1495
|
-
hill_transformer = adstock_hill.HillTransformer(
|
|
1496
|
-
ec=ec,
|
|
1497
|
-
slope=slope,
|
|
848
|
+
warnings.warn(
|
|
849
|
+
"Meridian.adstock_hill_rf() is deprecated and will be removed in a"
|
|
850
|
+
" future version. Use `ModelEquations.adstock_hill_rf()`"
|
|
851
|
+
" instead.",
|
|
852
|
+
DeprecationWarning,
|
|
853
|
+
stacklevel=2,
|
|
1498
854
|
)
|
|
1499
|
-
|
|
855
|
+
return self.model_equations.adstock_hill_rf(
|
|
856
|
+
reach=reach,
|
|
857
|
+
frequency=frequency,
|
|
1500
858
|
alpha=alpha,
|
|
1501
|
-
|
|
1502
|
-
|
|
859
|
+
ec=ec,
|
|
860
|
+
slope=slope,
|
|
1503
861
|
decay_functions=decay_functions,
|
|
862
|
+
n_times_output=n_times_output,
|
|
1504
863
|
)
|
|
1505
|
-
adj_frequency = hill_transformer.forward(frequency)
|
|
1506
|
-
rf_out = adstock_transformer.forward(reach * adj_frequency)
|
|
1507
|
-
|
|
1508
|
-
return rf_out
|
|
1509
864
|
|
|
1510
865
|
def populate_cached_properties(self):
|
|
1511
866
|
"""Eagerly activates all cached properties.
|
|
@@ -1515,6 +870,7 @@ class Meridian:
|
|
|
1515
870
|
internal state mutations are problematic, and so this method freezes the
|
|
1516
871
|
object's states before the computation graph is created.
|
|
1517
872
|
"""
|
|
873
|
+
self._model_context.populate_cached_properties()
|
|
1518
874
|
cls = self.__class__
|
|
1519
875
|
# "Freeze" all @cached_property attributes by simply accessing them (with
|
|
1520
876
|
# `getattr()`).
|
|
@@ -1526,66 +882,30 @@ class Meridian:
|
|
|
1526
882
|
for attr in cached_properties:
|
|
1527
883
|
_ = getattr(self, attr)
|
|
1528
884
|
|
|
885
|
+
# TODO: Remove this method.
|
|
1529
886
|
def create_inference_data_coords(
|
|
1530
887
|
self, n_chains: int, n_draws: int
|
|
1531
888
|
) -> Mapping[str, np.ndarray | Sequence[str]]:
|
|
1532
889
|
"""Creates data coordinates for inference data."""
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
self.input_data.rf_channel
|
|
1540
|
-
if self.input_data.rf_channel is not None
|
|
1541
|
-
else np.array([])
|
|
1542
|
-
)
|
|
1543
|
-
organic_media_channel_names = (
|
|
1544
|
-
self.input_data.organic_media_channel
|
|
1545
|
-
if self.input_data.organic_media_channel is not None
|
|
1546
|
-
else np.array([])
|
|
1547
|
-
)
|
|
1548
|
-
organic_rf_channel_names = (
|
|
1549
|
-
self.input_data.organic_rf_channel
|
|
1550
|
-
if self.input_data.organic_rf_channel is not None
|
|
1551
|
-
else np.array([])
|
|
1552
|
-
)
|
|
1553
|
-
non_media_channel_names = (
|
|
1554
|
-
self.input_data.non_media_channel
|
|
1555
|
-
if self.input_data.non_media_channel is not None
|
|
1556
|
-
else np.array([])
|
|
890
|
+
warnings.warn(
|
|
891
|
+
"Meridian.create_inference_data_coords() is deprecated and will be"
|
|
892
|
+
" removed in a future version. Use"
|
|
893
|
+
" `ModelContext.create_inference_data_coords()` instead.",
|
|
894
|
+
DeprecationWarning,
|
|
895
|
+
stacklevel=2,
|
|
1557
896
|
)
|
|
1558
|
-
|
|
1559
|
-
self.input_data.control_variable
|
|
1560
|
-
if self.input_data.control_variable is not None
|
|
1561
|
-
else np.array([])
|
|
1562
|
-
)
|
|
1563
|
-
return {
|
|
1564
|
-
constants.CHAIN: np.arange(n_chains),
|
|
1565
|
-
constants.DRAW: np.arange(n_draws),
|
|
1566
|
-
constants.GEO: self.input_data.geo,
|
|
1567
|
-
constants.TIME: self.input_data.time,
|
|
1568
|
-
constants.MEDIA_TIME: self.input_data.media_time,
|
|
1569
|
-
constants.KNOTS: np.arange(self.knot_info.n_knots),
|
|
1570
|
-
constants.CONTROL_VARIABLE: control_variable_names,
|
|
1571
|
-
constants.NON_MEDIA_CHANNEL: non_media_channel_names,
|
|
1572
|
-
constants.MEDIA_CHANNEL: media_channel_names,
|
|
1573
|
-
constants.RF_CHANNEL: rf_channel_names,
|
|
1574
|
-
constants.ORGANIC_MEDIA_CHANNEL: organic_media_channel_names,
|
|
1575
|
-
constants.ORGANIC_RF_CHANNEL: organic_rf_channel_names,
|
|
1576
|
-
}
|
|
897
|
+
return self._model_context.create_inference_data_coords(n_chains, n_draws)
|
|
1577
898
|
|
|
899
|
+
# TODO: Remove this method.
|
|
1578
900
|
def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
|
|
1584
|
-
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
for param, dims in inference_dims.items()
|
|
1588
|
-
}
|
|
901
|
+
warnings.warn(
|
|
902
|
+
"Meridian.create_inference_data_dims() is deprecated and will be"
|
|
903
|
+
" removed in a future version. Use"
|
|
904
|
+
" `ModelContext.create_inference_data_dims()` instead.",
|
|
905
|
+
DeprecationWarning,
|
|
906
|
+
stacklevel=2,
|
|
907
|
+
)
|
|
908
|
+
return self._model_context.create_inference_data_dims()
|
|
1589
909
|
|
|
1590
910
|
def sample_prior(self, n_draws: int, seed: int | None = None):
|
|
1591
911
|
"""Draws samples from the prior distributions.
|
|
@@ -1598,14 +918,25 @@ class Meridian:
|
|
|
1598
918
|
see [PRNGS and seeds]
|
|
1599
919
|
(https://github.com/tensorflow/probability/blob/main/PRNGS.md).
|
|
1600
920
|
"""
|
|
1601
|
-
self.prior_sampler_callable(n_draws=n_draws, seed=seed)
|
|
921
|
+
prior_draws = self.prior_sampler_callable(n_draws=n_draws, seed=seed)
|
|
922
|
+
# Create Arviz InferenceData for prior draws.
|
|
923
|
+
prior_coords = self._model_context.create_inference_data_coords(1, n_draws)
|
|
924
|
+
prior_dims = self._model_context.create_inference_data_dims()
|
|
925
|
+
prior_inference_data = az.convert_to_inference_data(
|
|
926
|
+
prior_draws,
|
|
927
|
+
coords=prior_coords,
|
|
928
|
+
dims=prior_dims,
|
|
929
|
+
group=constants.PRIOR,
|
|
930
|
+
)
|
|
931
|
+
self.inference_data.extend(prior_inference_data, join="right")
|
|
1602
932
|
|
|
1603
933
|
def _run_model_fitting_guardrail(self):
|
|
1604
934
|
"""Raises an error if the model has critical EDA issues."""
|
|
1605
935
|
error_findings_by_type: dict[eda_outcome.EDACheckType, list[str]] = (
|
|
1606
936
|
collections.defaultdict(list)
|
|
1607
937
|
)
|
|
1608
|
-
for
|
|
938
|
+
for field in dataclasses.fields(self.eda_outcomes):
|
|
939
|
+
outcome = getattr(self.eda_outcomes, field.name)
|
|
1609
940
|
error_findings = [
|
|
1610
941
|
finding
|
|
1611
942
|
for finding in outcome.findings
|
|
@@ -1709,7 +1040,7 @@ class Meridian:
|
|
|
1709
1040
|
"""
|
|
1710
1041
|
self._run_model_fitting_guardrail()
|
|
1711
1042
|
|
|
1712
|
-
self.posterior_sampler_callable(
|
|
1043
|
+
posterior_inference_data = self.posterior_sampler_callable(
|
|
1713
1044
|
n_chains=n_chains,
|
|
1714
1045
|
n_adapt=n_adapt,
|
|
1715
1046
|
n_burnin=n_burnin,
|
|
@@ -1724,6 +1055,7 @@ class Meridian:
|
|
|
1724
1055
|
seed=seed,
|
|
1725
1056
|
**pins,
|
|
1726
1057
|
)
|
|
1058
|
+
self.inference_data.extend(posterior_inference_data, join="right")
|
|
1727
1059
|
|
|
1728
1060
|
|
|
1729
1061
|
def save_mmm(mmm: Meridian, file_path: str):
|
|
@@ -1737,6 +1069,15 @@ def save_mmm(mmm: Meridian, file_path: str):
|
|
|
1737
1069
|
mmm: Model object to save.
|
|
1738
1070
|
file_path: File path to save a pickled model object.
|
|
1739
1071
|
"""
|
|
1072
|
+
warnings.warn(
|
|
1073
|
+
"save_mmm is deprecated and will be removed in a future release. Please"
|
|
1074
|
+
" use `schema.serde.meridian_serde.save_meridian` instead. See"
|
|
1075
|
+
" https://developers.google.com/meridian/docs/user-guide/saving-model-object"
|
|
1076
|
+
" for details.",
|
|
1077
|
+
DeprecationWarning,
|
|
1078
|
+
stacklevel=2,
|
|
1079
|
+
)
|
|
1080
|
+
|
|
1740
1081
|
if not os.path.exists(os.path.dirname(file_path)):
|
|
1741
1082
|
os.makedirs(os.path.dirname(file_path))
|
|
1742
1083
|
|
|
@@ -1760,6 +1101,15 @@ def load_mmm(file_path: str) -> Meridian:
|
|
|
1760
1101
|
Raises:
|
|
1761
1102
|
FileNotFoundError: If `file_path` does not exist.
|
|
1762
1103
|
"""
|
|
1104
|
+
warnings.warn(
|
|
1105
|
+
"load_mmm is deprecated and will be removed in a future release. Please"
|
|
1106
|
+
" use `meridian.schema.serde.meridian_serde.load_meridian` instead. See"
|
|
1107
|
+
" https://developers.google.com/meridian/docs/user-guide/saving-model-object"
|
|
1108
|
+
" for details.",
|
|
1109
|
+
DeprecationWarning,
|
|
1110
|
+
stacklevel=2,
|
|
1111
|
+
)
|
|
1112
|
+
|
|
1763
1113
|
try:
|
|
1764
1114
|
with open(file_path, "rb") as f:
|
|
1765
1115
|
mmm = joblib.load(f)
|