google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
meridian/analysis/analyzer.py
CHANGED
|
@@ -16,14 +16,18 @@
|
|
|
16
16
|
|
|
17
17
|
from collections.abc import Mapping, Sequence
|
|
18
18
|
import dataclasses
|
|
19
|
+
import functools
|
|
19
20
|
import itertools
|
|
20
21
|
import numbers
|
|
21
22
|
from typing import Any, Optional
|
|
22
23
|
import warnings
|
|
23
24
|
|
|
25
|
+
import arviz as az
|
|
24
26
|
from meridian import backend
|
|
25
27
|
from meridian import constants
|
|
26
28
|
from meridian.model import adstock_hill
|
|
29
|
+
from meridian.model import context
|
|
30
|
+
from meridian.model import equations
|
|
27
31
|
from meridian.model import model
|
|
28
32
|
from meridian.model import transformers
|
|
29
33
|
import numpy as np
|
|
@@ -53,6 +57,27 @@ def _validate_non_media_baseline_values_numbers(
|
|
|
53
57
|
)
|
|
54
58
|
|
|
55
59
|
|
|
60
|
+
# TODO: Remove this method.
|
|
61
|
+
def _get_model_context(
|
|
62
|
+
meridian: model.Meridian | None,
|
|
63
|
+
model_context: context.ModelContext | None,
|
|
64
|
+
) -> context.ModelContext:
|
|
65
|
+
"""Gets `model_context`, handling the deprecated `meridian` argument."""
|
|
66
|
+
if meridian is not None:
|
|
67
|
+
warnings.warn(
|
|
68
|
+
(
|
|
69
|
+
"The `meridian` argument is deprecated and will be removed in a"
|
|
70
|
+
" future version. Use `model_context` instead."
|
|
71
|
+
),
|
|
72
|
+
DeprecationWarning,
|
|
73
|
+
stacklevel=3,
|
|
74
|
+
)
|
|
75
|
+
return meridian.model_context
|
|
76
|
+
if model_context is None:
|
|
77
|
+
raise ValueError("Either `meridian` or `model_context` must be provided.")
|
|
78
|
+
return model_context
|
|
79
|
+
|
|
80
|
+
|
|
56
81
|
@dataclasses.dataclass
|
|
57
82
|
class DataTensors(backend.ExtensionType):
|
|
58
83
|
"""Container for data variable arguments of Analyzer methods.
|
|
@@ -218,30 +243,38 @@ class DataTensors(backend.ExtensionType):
|
|
|
218
243
|
backend.concatenate(spend_tensors, axis=-1) if spend_tensors else None
|
|
219
244
|
)
|
|
220
245
|
|
|
221
|
-
def get_modified_times(
|
|
246
|
+
def get_modified_times(
|
|
247
|
+
self,
|
|
248
|
+
meridian: model.Meridian | None = None,
|
|
249
|
+
model_context: context.ModelContext | None = None,
|
|
250
|
+
) -> int | None:
|
|
222
251
|
"""Returns `n_times` of any tensor where `n_times` has been modified.
|
|
223
252
|
|
|
224
253
|
This method compares the time dimensions of the attributes in the
|
|
225
|
-
`DataTensors` object with the corresponding tensors in the `
|
|
254
|
+
`DataTensors` object with the corresponding tensors in the `model_context`
|
|
226
255
|
object. If any of the time dimensions are different, then this method
|
|
227
256
|
returns the modified number of time periods of the tensor in the
|
|
228
257
|
`DataTensors` object. If all time dimensions are the same, returns `None`.
|
|
229
258
|
|
|
230
259
|
Args:
|
|
231
|
-
meridian: A Meridian object to validate against and get the
|
|
232
|
-
tensors from.
|
|
260
|
+
meridian: Deprecated. A Meridian object to validate against and get the
|
|
261
|
+
original data tensors from. This argument is deprecated and will be
|
|
262
|
+
removed in a future version. Use `model_context` instead.
|
|
263
|
+
model_context: A ModelContext object to validate against and get the
|
|
264
|
+
original data tensors from.
|
|
233
265
|
|
|
234
266
|
Returns:
|
|
235
267
|
The `n_times` of any tensor where `n_times` is different from the times
|
|
236
|
-
of the corresponding tensor in the `
|
|
268
|
+
of the corresponding tensor in the `model_context` object. If all time
|
|
237
269
|
dimensions are the same, returns `None`.
|
|
238
270
|
"""
|
|
271
|
+
model_context = _get_model_context(meridian, model_context)
|
|
239
272
|
for field in dataclasses.fields(self):
|
|
240
273
|
new_tensor = getattr(self, field.name)
|
|
241
274
|
if field.name == constants.RF_IMPRESSIONS:
|
|
242
|
-
old_tensor = getattr(
|
|
275
|
+
old_tensor = getattr(model_context.rf_tensors, field.name)
|
|
243
276
|
else:
|
|
244
|
-
old_tensor = getattr(
|
|
277
|
+
old_tensor = getattr(model_context.input_data, field.name)
|
|
245
278
|
# The time dimension is always the second dimension, except for when spend
|
|
246
279
|
# data is provided with only one dimension of (n_channels).
|
|
247
280
|
if (
|
|
@@ -264,24 +297,28 @@ class DataTensors(backend.ExtensionType):
|
|
|
264
297
|
def validate_and_fill_missing_data(
|
|
265
298
|
self,
|
|
266
299
|
required_tensors_names: Sequence[str],
|
|
267
|
-
meridian: model.Meridian,
|
|
300
|
+
meridian: model.Meridian | None = None,
|
|
301
|
+
model_context: context.ModelContext | None = None,
|
|
268
302
|
allow_modified_times: bool = True,
|
|
269
303
|
) -> Self:
|
|
270
304
|
"""Fills missing data tensors with their original values from the model.
|
|
271
305
|
|
|
272
306
|
This method uses the collection of data tensors set in the DataTensor class
|
|
273
307
|
and fills in the missing tensors with their original values from the
|
|
274
|
-
|
|
275
|
-
["media", "reach", "frequency"]` and only `media`
|
|
276
|
-
class, then this method will output a new
|
|
277
|
-
`media` value in this object plus the values of
|
|
278
|
-
from the `
|
|
308
|
+
ModelContext object that is passed in. For example, if
|
|
309
|
+
`required_tensors_names = ["media", "reach", "frequency"]` and only `media`
|
|
310
|
+
is set in the DataTensors class, then this method will output a new
|
|
311
|
+
DataTensors object with the `media` value in this object plus the values of
|
|
312
|
+
the `reach` and `frequency` from the `model_context` object.
|
|
279
313
|
|
|
280
314
|
Args:
|
|
281
315
|
required_tensors_names: A sequence of data tensors names to validate and
|
|
282
|
-
fill in with the original values from the `
|
|
283
|
-
meridian:
|
|
284
|
-
data tensors from.
|
|
316
|
+
fill in with the original values from the `model_context` object.
|
|
317
|
+
meridian: Deprecated. A Meridian object to validate against and get the
|
|
318
|
+
original data tensors from. This argument is deprecated and will be
|
|
319
|
+
removed in a future version. Use `model_context` instead.
|
|
320
|
+
model_context: A ModelContext object to validate against and get the
|
|
321
|
+
original data tensors from.
|
|
285
322
|
allow_modified_times: A boolean flag indicating whether to allow modified
|
|
286
323
|
time dimensions in the new data tensors. If False, an error will be
|
|
287
324
|
raised if the time dimensions of any tensor is modified.
|
|
@@ -290,15 +327,30 @@ class DataTensors(backend.ExtensionType):
|
|
|
290
327
|
A `DataTensors` container with the original values from the Meridian
|
|
291
328
|
object filled in for the missing data tensors.
|
|
292
329
|
"""
|
|
293
|
-
|
|
294
|
-
self.
|
|
295
|
-
|
|
330
|
+
model_context = _get_model_context(meridian, model_context)
|
|
331
|
+
self._validate_correct_variables_filled(
|
|
332
|
+
required_variables=required_tensors_names,
|
|
333
|
+
model_context=model_context,
|
|
334
|
+
)
|
|
335
|
+
self._validate_geo_dims(
|
|
336
|
+
required_fields=required_tensors_names, model_context=model_context
|
|
337
|
+
)
|
|
338
|
+
self._validate_channel_dims(
|
|
339
|
+
required_fields=required_tensors_names, model_context=model_context
|
|
340
|
+
)
|
|
296
341
|
if allow_modified_times:
|
|
297
|
-
self._validate_time_dims_flexible_times(
|
|
342
|
+
self._validate_time_dims_flexible_times(
|
|
343
|
+
required_fields=required_tensors_names, model_context=model_context
|
|
344
|
+
)
|
|
298
345
|
else:
|
|
299
|
-
self._validate_time_dims(
|
|
346
|
+
self._validate_time_dims(
|
|
347
|
+
required_fields=required_tensors_names, model_context=model_context
|
|
348
|
+
)
|
|
300
349
|
|
|
301
|
-
return self._fill_default_values(
|
|
350
|
+
return self._fill_default_values(
|
|
351
|
+
required_fields=required_tensors_names,
|
|
352
|
+
model_context=model_context,
|
|
353
|
+
)
|
|
302
354
|
|
|
303
355
|
def _validate_n_dims(self):
|
|
304
356
|
"""Raises an error if the tensors have the wrong number of dimensions."""
|
|
@@ -320,18 +372,20 @@ class DataTensors(backend.ExtensionType):
|
|
|
320
372
|
_check_n_dims(tensor, field.name, 3)
|
|
321
373
|
|
|
322
374
|
def _validate_correct_variables_filled(
|
|
323
|
-
self,
|
|
324
|
-
|
|
375
|
+
self,
|
|
376
|
+
required_variables: Sequence[str],
|
|
377
|
+
model_context: context.ModelContext,
|
|
378
|
+
) -> None:
|
|
325
379
|
"""Validates that the correct variables are filled.
|
|
326
380
|
|
|
327
381
|
Args:
|
|
328
382
|
required_variables: A sequence of data tensors names that are required to
|
|
329
383
|
be filled in.
|
|
330
|
-
|
|
384
|
+
model_context: A `ModelContext` object to validate against.
|
|
331
385
|
|
|
332
386
|
Raises:
|
|
333
387
|
ValueError: If an attribute exists in the `DataTensors` object that is not
|
|
334
|
-
in the `
|
|
388
|
+
in the `model_context` object, it is not allowed to be used in analysis.
|
|
335
389
|
Warning: If an attribute exists in the `DataTensors` object that is not in
|
|
336
390
|
the `required_variables` list, it will be ignored.
|
|
337
391
|
"""
|
|
@@ -346,44 +400,48 @@ class DataTensors(backend.ExtensionType):
|
|
|
346
400
|
)
|
|
347
401
|
if field.name in required_variables:
|
|
348
402
|
if field.name == constants.RF_IMPRESSIONS:
|
|
349
|
-
if
|
|
403
|
+
if model_context.n_rf_channels == 0:
|
|
350
404
|
raise ValueError(
|
|
351
405
|
"New `rf_impressions` is not allowed because there are no R&F"
|
|
352
406
|
" channels in the Meridian model."
|
|
353
407
|
)
|
|
354
|
-
elif getattr(
|
|
408
|
+
elif getattr(model_context.input_data, field.name) is None:
|
|
355
409
|
raise ValueError(
|
|
356
410
|
f"New `{field.name}` is not allowed because the input data to the"
|
|
357
411
|
f" Meridian model does not contain `{field.name}`."
|
|
358
412
|
)
|
|
359
413
|
|
|
360
414
|
def _validate_geo_dims(
|
|
361
|
-
self,
|
|
362
|
-
|
|
415
|
+
self,
|
|
416
|
+
required_fields: Sequence[str],
|
|
417
|
+
model_context: context.ModelContext,
|
|
418
|
+
) -> None:
|
|
363
419
|
"""Validates the geo dimension of the specified data variables."""
|
|
364
420
|
for var_name in required_fields:
|
|
365
421
|
new_tensor = getattr(self, var_name)
|
|
366
|
-
if new_tensor is not None and new_tensor.shape[0] !=
|
|
422
|
+
if new_tensor is not None and new_tensor.shape[0] != model_context.n_geos:
|
|
367
423
|
# Skip spend and time data with only 1 dimension.
|
|
368
424
|
if new_tensor.ndim == 1:
|
|
369
425
|
continue
|
|
370
426
|
raise ValueError(
|
|
371
|
-
f"New `{var_name}` is expected to have {
|
|
427
|
+
f"New `{var_name}` is expected to have {model_context.n_geos}"
|
|
372
428
|
f" geos. Found {new_tensor.shape[0]} geos."
|
|
373
429
|
)
|
|
374
430
|
|
|
375
431
|
def _validate_channel_dims(
|
|
376
|
-
self,
|
|
377
|
-
|
|
432
|
+
self,
|
|
433
|
+
required_fields: Sequence[str],
|
|
434
|
+
model_context: context.ModelContext,
|
|
435
|
+
) -> None:
|
|
378
436
|
"""Validates the channel dimension of the specified data variables."""
|
|
379
437
|
for var_name in required_fields:
|
|
380
438
|
if var_name in [constants.REVENUE_PER_KPI, constants.TIME]:
|
|
381
439
|
continue
|
|
382
440
|
new_tensor = getattr(self, var_name)
|
|
383
441
|
if var_name == constants.RF_IMPRESSIONS:
|
|
384
|
-
old_tensor = getattr(
|
|
442
|
+
old_tensor = getattr(model_context.rf_tensors, var_name)
|
|
385
443
|
else:
|
|
386
|
-
old_tensor = getattr(
|
|
444
|
+
old_tensor = getattr(model_context.input_data, var_name)
|
|
387
445
|
if new_tensor is not None:
|
|
388
446
|
assert old_tensor is not None
|
|
389
447
|
if new_tensor.shape[-1] != old_tensor.shape[-1]:
|
|
@@ -393,15 +451,17 @@ class DataTensors(backend.ExtensionType):
|
|
|
393
451
|
)
|
|
394
452
|
|
|
395
453
|
def _validate_time_dims(
|
|
396
|
-
self,
|
|
454
|
+
self,
|
|
455
|
+
required_fields: Sequence[str],
|
|
456
|
+
model_context: context.ModelContext,
|
|
397
457
|
):
|
|
398
458
|
"""Validates the time dimension of the specified data variables."""
|
|
399
459
|
for var_name in required_fields:
|
|
400
460
|
new_tensor = getattr(self, var_name)
|
|
401
461
|
if var_name == constants.RF_IMPRESSIONS:
|
|
402
|
-
old_tensor = getattr(
|
|
462
|
+
old_tensor = getattr(model_context.rf_tensors, var_name)
|
|
403
463
|
else:
|
|
404
|
-
old_tensor = getattr(
|
|
464
|
+
old_tensor = getattr(model_context.input_data, var_name)
|
|
405
465
|
|
|
406
466
|
# Skip spend data with only 1 dimension of (n_channels).
|
|
407
467
|
if (
|
|
@@ -428,10 +488,12 @@ class DataTensors(backend.ExtensionType):
|
|
|
428
488
|
)
|
|
429
489
|
|
|
430
490
|
def _validate_time_dims_flexible_times(
|
|
431
|
-
self,
|
|
491
|
+
self,
|
|
492
|
+
required_fields: Sequence[str],
|
|
493
|
+
model_context: context.ModelContext,
|
|
432
494
|
):
|
|
433
495
|
"""Validates the time dimension for the flexible times case."""
|
|
434
|
-
new_n_times = self.get_modified_times(
|
|
496
|
+
new_n_times = self.get_modified_times(model_context=model_context)
|
|
435
497
|
# If no times were modified, then there is nothing more to validate.
|
|
436
498
|
if new_n_times is None:
|
|
437
499
|
return
|
|
@@ -440,9 +502,9 @@ class DataTensors(backend.ExtensionType):
|
|
|
440
502
|
for var_name in required_fields:
|
|
441
503
|
new_tensor = getattr(self, var_name)
|
|
442
504
|
if var_name == constants.RF_IMPRESSIONS:
|
|
443
|
-
old_tensor = getattr(
|
|
505
|
+
old_tensor = getattr(model_context.rf_tensors, var_name)
|
|
444
506
|
else:
|
|
445
|
-
old_tensor = getattr(
|
|
507
|
+
old_tensor = getattr(model_context.input_data, var_name)
|
|
446
508
|
|
|
447
509
|
if old_tensor is None:
|
|
448
510
|
continue
|
|
@@ -484,7 +546,9 @@ class DataTensors(backend.ExtensionType):
|
|
|
484
546
|
)
|
|
485
547
|
|
|
486
548
|
def _fill_default_values(
|
|
487
|
-
self,
|
|
549
|
+
self,
|
|
550
|
+
required_fields: Sequence[str],
|
|
551
|
+
model_context: context.ModelContext,
|
|
488
552
|
) -> Self:
|
|
489
553
|
"""Fills default values and returns a new DataTensors object."""
|
|
490
554
|
output = {}
|
|
@@ -493,23 +557,23 @@ class DataTensors(backend.ExtensionType):
|
|
|
493
557
|
if var_name not in required_fields:
|
|
494
558
|
continue
|
|
495
559
|
|
|
496
|
-
if hasattr(
|
|
497
|
-
old_tensor = getattr(
|
|
498
|
-
elif hasattr(
|
|
499
|
-
old_tensor = getattr(
|
|
500
|
-
elif hasattr(
|
|
501
|
-
old_tensor = getattr(
|
|
502
|
-
elif hasattr(
|
|
503
|
-
old_tensor = getattr(
|
|
560
|
+
if hasattr(model_context.media_tensors, var_name):
|
|
561
|
+
old_tensor = getattr(model_context.media_tensors, var_name)
|
|
562
|
+
elif hasattr(model_context.rf_tensors, var_name):
|
|
563
|
+
old_tensor = getattr(model_context.rf_tensors, var_name)
|
|
564
|
+
elif hasattr(model_context.organic_media_tensors, var_name):
|
|
565
|
+
old_tensor = getattr(model_context.organic_media_tensors, var_name)
|
|
566
|
+
elif hasattr(model_context.organic_rf_tensors, var_name):
|
|
567
|
+
old_tensor = getattr(model_context.organic_rf_tensors, var_name)
|
|
504
568
|
elif var_name == constants.NON_MEDIA_TREATMENTS:
|
|
505
|
-
old_tensor =
|
|
569
|
+
old_tensor = model_context.non_media_treatments
|
|
506
570
|
elif var_name == constants.CONTROLS:
|
|
507
|
-
old_tensor =
|
|
571
|
+
old_tensor = model_context.controls
|
|
508
572
|
elif var_name == constants.REVENUE_PER_KPI:
|
|
509
|
-
old_tensor =
|
|
573
|
+
old_tensor = model_context.revenue_per_kpi
|
|
510
574
|
elif var_name == constants.TIME:
|
|
511
575
|
old_tensor = backend.to_tensor(
|
|
512
|
-
|
|
576
|
+
model_context.input_data.time.values.tolist(), dtype=backend.string
|
|
513
577
|
)
|
|
514
578
|
else:
|
|
515
579
|
continue
|
|
@@ -858,12 +922,38 @@ def _central_tendency_and_ci_by_prior_and_posterior(
|
|
|
858
922
|
class Analyzer:
|
|
859
923
|
"""Runs calculations to analyze the raw data after fitting the model."""
|
|
860
924
|
|
|
861
|
-
def __init__(
|
|
862
|
-
|
|
925
|
+
def __init__(
|
|
926
|
+
self,
|
|
927
|
+
# TODO: Remove this argument.
|
|
928
|
+
meridian: model.Meridian | None = None,
|
|
929
|
+
*,
|
|
930
|
+
model_context: context.ModelContext | None = None,
|
|
931
|
+
inference_data: az.InferenceData | None = None,
|
|
932
|
+
):
|
|
863
933
|
# Make the meridian object ready for methods in this analyzer that create
|
|
864
934
|
# backend.function computation graphs: it should be frozen for no more
|
|
865
935
|
# internal states mutation before those graphs execute.
|
|
866
|
-
self.
|
|
936
|
+
self.model_context = _get_model_context(meridian, model_context)
|
|
937
|
+
|
|
938
|
+
if meridian is not None:
|
|
939
|
+
self._inference_data = inference_data or meridian.inference_data
|
|
940
|
+
elif model_context is None or inference_data is None:
|
|
941
|
+
raise ValueError(
|
|
942
|
+
"If `meridian` is not provided, then `model_context` and"
|
|
943
|
+
" `inference_data` must be provided."
|
|
944
|
+
)
|
|
945
|
+
else:
|
|
946
|
+
self._inference_data = inference_data
|
|
947
|
+
|
|
948
|
+
self.model_context.populate_cached_properties()
|
|
949
|
+
|
|
950
|
+
@functools.cached_property
|
|
951
|
+
def _model_equations(self) -> equations.ModelEquations:
|
|
952
|
+
return equations.ModelEquations(self.model_context)
|
|
953
|
+
|
|
954
|
+
@property
|
|
955
|
+
def inference_data(self) -> az.InferenceData:
|
|
956
|
+
return self._inference_data
|
|
867
957
|
|
|
868
958
|
@backend.function(jit_compile=True)
|
|
869
959
|
def _get_kpi_means(
|
|
@@ -902,7 +992,7 @@ class Analyzer:
|
|
|
902
992
|
result = tau_gt + backend.einsum(
|
|
903
993
|
"...gtm,...gm->...gt", combined_media_transformed, combined_beta
|
|
904
994
|
)
|
|
905
|
-
if self.
|
|
995
|
+
if self.model_context.controls is not None:
|
|
906
996
|
result += backend.einsum(
|
|
907
997
|
"...gtc,...gc->...gt",
|
|
908
998
|
data_tensors.controls,
|
|
@@ -933,20 +1023,20 @@ class Analyzer:
|
|
|
933
1023
|
UserWarning: If the KPI type is revenue and use_kpi is True or if
|
|
934
1024
|
`use_kpi=False` but `revenue_per_kpi` is not available.
|
|
935
1025
|
"""
|
|
936
|
-
if use_kpi and self.
|
|
1026
|
+
if use_kpi and self.model_context.input_data.kpi_type == constants.REVENUE:
|
|
937
1027
|
warnings.warn(
|
|
938
1028
|
"Setting `use_kpi=True` has no effect when `kpi_type=REVENUE`"
|
|
939
1029
|
" since in this case, KPI is equal to revenue."
|
|
940
1030
|
)
|
|
941
1031
|
return False
|
|
942
1032
|
|
|
943
|
-
if not use_kpi and self.
|
|
1033
|
+
if not use_kpi and self.model_context.input_data.revenue_per_kpi is None:
|
|
944
1034
|
warnings.warn(
|
|
945
1035
|
"Revenue analysis is not available when `revenue_per_kpi` is"
|
|
946
1036
|
" unknown. Defaulting to KPI analysis."
|
|
947
1037
|
)
|
|
948
1038
|
|
|
949
|
-
return use_kpi or self.
|
|
1039
|
+
return use_kpi or self.model_context.input_data.revenue_per_kpi is None
|
|
950
1040
|
|
|
951
1041
|
def _get_adstock_dataframe(
|
|
952
1042
|
self,
|
|
@@ -972,36 +1062,37 @@ class Analyzer:
|
|
|
972
1062
|
ci_lo, and mean decayed effects for either media or RF channel types.
|
|
973
1063
|
"""
|
|
974
1064
|
window_size = min(
|
|
975
|
-
self.
|
|
1065
|
+
self.model_context.model_spec.max_lag + 1,
|
|
1066
|
+
self.model_context.n_media_times,
|
|
976
1067
|
)
|
|
977
1068
|
if channel_type == constants.MEDIA:
|
|
978
|
-
prior = self.
|
|
1069
|
+
prior = self._inference_data.prior.alpha_m.values[0]
|
|
979
1070
|
posterior = np.reshape(
|
|
980
|
-
self.
|
|
981
|
-
(-1, self.
|
|
1071
|
+
self._inference_data.posterior.alpha_m.values,
|
|
1072
|
+
(-1, self.model_context.n_media_channels),
|
|
982
1073
|
)
|
|
983
|
-
decay_functions = self.
|
|
1074
|
+
decay_functions = self.model_context.adstock_decay_spec.media
|
|
984
1075
|
elif channel_type == constants.RF:
|
|
985
|
-
prior = self.
|
|
1076
|
+
prior = self._inference_data.prior.alpha_rf.values[0]
|
|
986
1077
|
posterior = np.reshape(
|
|
987
|
-
self.
|
|
988
|
-
(-1, self.
|
|
1078
|
+
self._inference_data.posterior.alpha_rf.values,
|
|
1079
|
+
(-1, self.model_context.n_rf_channels),
|
|
989
1080
|
)
|
|
990
|
-
decay_functions = self.
|
|
1081
|
+
decay_functions = self.model_context.adstock_decay_spec.rf
|
|
991
1082
|
elif channel_type == constants.ORGANIC_MEDIA:
|
|
992
|
-
prior = self.
|
|
1083
|
+
prior = self._inference_data.prior.alpha_om.values[0]
|
|
993
1084
|
posterior = np.reshape(
|
|
994
|
-
self.
|
|
995
|
-
(-1, self.
|
|
1085
|
+
self._inference_data.posterior.alpha_om.values,
|
|
1086
|
+
(-1, self.model_context.n_organic_media_channels),
|
|
996
1087
|
)
|
|
997
|
-
decay_functions = self.
|
|
1088
|
+
decay_functions = self.model_context.adstock_decay_spec.organic_media
|
|
998
1089
|
elif channel_type == constants.ORGANIC_RF:
|
|
999
|
-
prior = self.
|
|
1090
|
+
prior = self._inference_data.prior.alpha_orf.values[0]
|
|
1000
1091
|
posterior = np.reshape(
|
|
1001
|
-
self.
|
|
1002
|
-
(-1, self.
|
|
1092
|
+
self._inference_data.posterior.alpha_orf.values,
|
|
1093
|
+
(-1, self.model_context.n_organic_rf_channels),
|
|
1003
1094
|
)
|
|
1004
|
-
decay_functions = self.
|
|
1095
|
+
decay_functions = self.model_context.adstock_decay_spec.organic_rf
|
|
1005
1096
|
else:
|
|
1006
1097
|
raise ValueError(
|
|
1007
1098
|
f"Unsupported channel type for adstock decay: '{channel_type}'. "
|
|
@@ -1103,65 +1194,65 @@ class Analyzer:
|
|
|
1103
1194
|
"""
|
|
1104
1195
|
if new_data is None:
|
|
1105
1196
|
return DataTensors(
|
|
1106
|
-
media=self.
|
|
1107
|
-
reach=self.
|
|
1108
|
-
frequency=self.
|
|
1109
|
-
organic_media=self.
|
|
1110
|
-
organic_reach=self.
|
|
1111
|
-
organic_frequency=self.
|
|
1112
|
-
non_media_treatments=self.
|
|
1113
|
-
controls=self.
|
|
1114
|
-
revenue_per_kpi=self.
|
|
1197
|
+
media=self.model_context.media_tensors.media_scaled,
|
|
1198
|
+
reach=self.model_context.rf_tensors.reach_scaled,
|
|
1199
|
+
frequency=self.model_context.rf_tensors.frequency,
|
|
1200
|
+
organic_media=self.model_context.organic_media_tensors.organic_media_scaled,
|
|
1201
|
+
organic_reach=self.model_context.organic_rf_tensors.organic_reach_scaled,
|
|
1202
|
+
organic_frequency=self.model_context.organic_rf_tensors.organic_frequency,
|
|
1203
|
+
non_media_treatments=self.model_context.non_media_treatments_normalized,
|
|
1204
|
+
controls=self.model_context.controls_scaled,
|
|
1205
|
+
revenue_per_kpi=self.model_context.revenue_per_kpi,
|
|
1115
1206
|
)
|
|
1116
1207
|
media_scaled = _transformed_new_or_scaled(
|
|
1117
1208
|
new_variable=new_data.media,
|
|
1118
|
-
transformer=self.
|
|
1119
|
-
scaled_variable=self.
|
|
1209
|
+
transformer=self.model_context.media_tensors.media_transformer,
|
|
1210
|
+
scaled_variable=self.model_context.media_tensors.media_scaled,
|
|
1120
1211
|
)
|
|
1121
1212
|
|
|
1122
1213
|
reach_scaled = _transformed_new_or_scaled(
|
|
1123
1214
|
new_variable=new_data.reach,
|
|
1124
|
-
transformer=self.
|
|
1125
|
-
scaled_variable=self.
|
|
1215
|
+
transformer=self.model_context.rf_tensors.reach_transformer,
|
|
1216
|
+
scaled_variable=self.model_context.rf_tensors.reach_scaled,
|
|
1126
1217
|
)
|
|
1127
1218
|
|
|
1128
1219
|
frequency = (
|
|
1129
1220
|
new_data.frequency
|
|
1130
1221
|
if new_data.frequency is not None
|
|
1131
|
-
else self.
|
|
1222
|
+
else self.model_context.rf_tensors.frequency
|
|
1132
1223
|
)
|
|
1133
1224
|
|
|
1134
1225
|
controls_scaled = _transformed_new_or_scaled(
|
|
1135
1226
|
new_variable=new_data.controls,
|
|
1136
|
-
transformer=self.
|
|
1137
|
-
scaled_variable=self.
|
|
1227
|
+
transformer=self.model_context.controls_transformer,
|
|
1228
|
+
scaled_variable=self.model_context.controls_scaled,
|
|
1138
1229
|
)
|
|
1139
1230
|
revenue_per_kpi = (
|
|
1140
1231
|
new_data.revenue_per_kpi
|
|
1141
1232
|
if new_data.revenue_per_kpi is not None
|
|
1142
|
-
else self.
|
|
1233
|
+
else self.model_context.revenue_per_kpi
|
|
1143
1234
|
)
|
|
1144
1235
|
|
|
1145
1236
|
if include_non_paid_channels:
|
|
1146
1237
|
organic_media_scaled = _transformed_new_or_scaled(
|
|
1147
1238
|
new_variable=new_data.organic_media,
|
|
1148
|
-
transformer=self.
|
|
1149
|
-
scaled_variable=self.
|
|
1239
|
+
transformer=self.model_context.organic_media_tensors.organic_media_transformer,
|
|
1240
|
+
scaled_variable=self.model_context.organic_media_tensors.organic_media_scaled,
|
|
1150
1241
|
)
|
|
1151
1242
|
organic_reach_scaled = _transformed_new_or_scaled(
|
|
1152
1243
|
new_variable=new_data.organic_reach,
|
|
1153
|
-
transformer=self.
|
|
1154
|
-
scaled_variable=self.
|
|
1244
|
+
transformer=self.model_context.organic_rf_tensors.organic_reach_transformer,
|
|
1245
|
+
scaled_variable=self.model_context.organic_rf_tensors.organic_reach_scaled,
|
|
1155
1246
|
)
|
|
1156
1247
|
organic_frequency = (
|
|
1157
1248
|
new_data.organic_frequency
|
|
1158
1249
|
if new_data.organic_frequency is not None
|
|
1159
|
-
else self.
|
|
1250
|
+
else self.model_context.organic_rf_tensors.organic_frequency
|
|
1160
1251
|
)
|
|
1161
1252
|
non_media_treatments_normalized = _transformed_new_or_scaled(
|
|
1162
1253
|
new_variable=new_data.non_media_treatments,
|
|
1163
|
-
transformer=self.
|
|
1164
|
-
scaled_variable=self.
|
|
1254
|
+
transformer=self.model_context.non_media_transformer,
|
|
1255
|
+
scaled_variable=self.model_context.non_media_treatments_normalized,
|
|
1165
1256
|
)
|
|
1166
1257
|
return DataTensors(
|
|
1167
1258
|
media=media_scaled,
|
|
@@ -1198,14 +1289,14 @@ class Analyzer:
|
|
|
1198
1289
|
and organic RF parameters names in inference data.
|
|
1199
1290
|
"""
|
|
1200
1291
|
params = []
|
|
1201
|
-
if self.
|
|
1292
|
+
if self.model_context.media_tensors.media is not None:
|
|
1202
1293
|
params.extend([
|
|
1203
1294
|
constants.EC_M,
|
|
1204
1295
|
constants.SLOPE_M,
|
|
1205
1296
|
constants.ALPHA_M,
|
|
1206
1297
|
constants.BETA_GM,
|
|
1207
1298
|
])
|
|
1208
|
-
if self.
|
|
1299
|
+
if self.model_context.rf_tensors.reach is not None:
|
|
1209
1300
|
params.extend([
|
|
1210
1301
|
constants.EC_RF,
|
|
1211
1302
|
constants.SLOPE_RF,
|
|
@@ -1213,21 +1304,21 @@ class Analyzer:
|
|
|
1213
1304
|
constants.BETA_GRF,
|
|
1214
1305
|
])
|
|
1215
1306
|
if include_non_paid_channels:
|
|
1216
|
-
if self.
|
|
1307
|
+
if self.model_context.organic_media_tensors.organic_media is not None:
|
|
1217
1308
|
params.extend([
|
|
1218
1309
|
constants.EC_OM,
|
|
1219
1310
|
constants.SLOPE_OM,
|
|
1220
1311
|
constants.ALPHA_OM,
|
|
1221
1312
|
constants.BETA_GOM,
|
|
1222
1313
|
])
|
|
1223
|
-
if self.
|
|
1314
|
+
if self.model_context.organic_rf_tensors.organic_reach is not None:
|
|
1224
1315
|
params.extend([
|
|
1225
1316
|
constants.EC_ORF,
|
|
1226
1317
|
constants.SLOPE_ORF,
|
|
1227
1318
|
constants.ALPHA_ORF,
|
|
1228
1319
|
constants.BETA_GORF,
|
|
1229
1320
|
])
|
|
1230
|
-
if self.
|
|
1321
|
+
if self.model_context.non_media_treatments is not None:
|
|
1231
1322
|
params.extend([
|
|
1232
1323
|
constants.GAMMA_GN,
|
|
1233
1324
|
])
|
|
@@ -1261,12 +1352,12 @@ class Analyzer:
|
|
|
1261
1352
|
combined_betas = []
|
|
1262
1353
|
if data_tensors.media is not None:
|
|
1263
1354
|
combined_medias.append(
|
|
1264
|
-
self.
|
|
1355
|
+
self._model_equations.adstock_hill_media(
|
|
1265
1356
|
media=data_tensors.media,
|
|
1266
1357
|
alpha=dist_tensors.alpha_m,
|
|
1267
1358
|
ec=dist_tensors.ec_m,
|
|
1268
1359
|
slope=dist_tensors.slope_m,
|
|
1269
|
-
decay_functions=self.
|
|
1360
|
+
decay_functions=self.model_context.adstock_decay_spec.media,
|
|
1270
1361
|
n_times_output=n_times_output,
|
|
1271
1362
|
)
|
|
1272
1363
|
)
|
|
@@ -1274,38 +1365,38 @@ class Analyzer:
|
|
|
1274
1365
|
|
|
1275
1366
|
if data_tensors.reach is not None:
|
|
1276
1367
|
combined_medias.append(
|
|
1277
|
-
self.
|
|
1368
|
+
self._model_equations.adstock_hill_rf(
|
|
1278
1369
|
reach=data_tensors.reach,
|
|
1279
1370
|
frequency=data_tensors.frequency,
|
|
1280
1371
|
alpha=dist_tensors.alpha_rf,
|
|
1281
1372
|
ec=dist_tensors.ec_rf,
|
|
1282
1373
|
slope=dist_tensors.slope_rf,
|
|
1283
|
-
decay_functions=self.
|
|
1374
|
+
decay_functions=self.model_context.adstock_decay_spec.rf,
|
|
1284
1375
|
n_times_output=n_times_output,
|
|
1285
1376
|
)
|
|
1286
1377
|
)
|
|
1287
1378
|
combined_betas.append(dist_tensors.beta_grf)
|
|
1288
1379
|
if data_tensors.organic_media is not None:
|
|
1289
1380
|
combined_medias.append(
|
|
1290
|
-
self.
|
|
1381
|
+
self._model_equations.adstock_hill_media(
|
|
1291
1382
|
media=data_tensors.organic_media,
|
|
1292
1383
|
alpha=dist_tensors.alpha_om,
|
|
1293
1384
|
ec=dist_tensors.ec_om,
|
|
1294
1385
|
slope=dist_tensors.slope_om,
|
|
1295
|
-
decay_functions=self.
|
|
1386
|
+
decay_functions=self.model_context.adstock_decay_spec.organic_media,
|
|
1296
1387
|
n_times_output=n_times_output,
|
|
1297
1388
|
)
|
|
1298
1389
|
)
|
|
1299
1390
|
combined_betas.append(dist_tensors.beta_gom)
|
|
1300
1391
|
if data_tensors.organic_reach is not None:
|
|
1301
1392
|
combined_medias.append(
|
|
1302
|
-
self.
|
|
1393
|
+
self._model_equations.adstock_hill_rf(
|
|
1303
1394
|
reach=data_tensors.organic_reach,
|
|
1304
1395
|
frequency=data_tensors.organic_frequency,
|
|
1305
1396
|
alpha=dist_tensors.alpha_orf,
|
|
1306
1397
|
ec=dist_tensors.ec_orf,
|
|
1307
1398
|
slope=dist_tensors.slope_orf,
|
|
1308
|
-
decay_functions=self.
|
|
1399
|
+
decay_functions=self.model_context.adstock_decay_spec.organic_rf,
|
|
1309
1400
|
n_times_output=n_times_output,
|
|
1310
1401
|
)
|
|
1311
1402
|
)
|
|
@@ -1355,7 +1446,7 @@ class Analyzer:
|
|
|
1355
1446
|
Returns:
|
|
1356
1447
|
A tensor with filtered and/or aggregated geo and time dimensions.
|
|
1357
1448
|
"""
|
|
1358
|
-
|
|
1449
|
+
m_context = self.model_context
|
|
1359
1450
|
|
|
1360
1451
|
# Validate the tensor shape and determine if it has a media dimension.
|
|
1361
1452
|
if flexible_time_dim:
|
|
@@ -1367,17 +1458,17 @@ class Analyzer:
|
|
|
1367
1458
|
)
|
|
1368
1459
|
n_times = tensor.shape[-2] if has_media_dim else tensor.shape[-1]
|
|
1369
1460
|
else:
|
|
1370
|
-
n_times =
|
|
1461
|
+
n_times = m_context.n_times
|
|
1371
1462
|
# Allowed subsets of channels: media, RF, media+RF, all channels.
|
|
1372
1463
|
allowed_n_channels = [
|
|
1373
|
-
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
+
|
|
1378
|
-
+
|
|
1379
|
-
+
|
|
1380
|
-
+
|
|
1464
|
+
m_context.n_media_channels,
|
|
1465
|
+
m_context.n_rf_channels,
|
|
1466
|
+
m_context.n_media_channels + m_context.n_rf_channels,
|
|
1467
|
+
m_context.n_media_channels
|
|
1468
|
+
+ m_context.n_rf_channels
|
|
1469
|
+
+ m_context.n_non_media_channels
|
|
1470
|
+
+ m_context.n_organic_media_channels
|
|
1471
|
+
+ m_context.n_organic_rf_channels,
|
|
1381
1472
|
]
|
|
1382
1473
|
# Allow extra channel if aggregated (All_Channels) value is included.
|
|
1383
1474
|
allowed_channel_dim = allowed_n_channels + [
|
|
@@ -1386,10 +1477,10 @@ class Analyzer:
|
|
|
1386
1477
|
expected_shapes_w_media = [
|
|
1387
1478
|
backend.TensorShape(shape)
|
|
1388
1479
|
for shape in itertools.product(
|
|
1389
|
-
[
|
|
1480
|
+
[m_context.n_geos], [n_times], allowed_channel_dim
|
|
1390
1481
|
)
|
|
1391
1482
|
]
|
|
1392
|
-
expected_shape_wo_media = backend.TensorShape([
|
|
1483
|
+
expected_shape_wo_media = backend.TensorShape([m_context.n_geos, n_times])
|
|
1393
1484
|
if not flexible_time_dim:
|
|
1394
1485
|
if tensor.shape[-3:] in expected_shapes_w_media:
|
|
1395
1486
|
has_media_dim = True
|
|
@@ -1417,13 +1508,15 @@ class Analyzer:
|
|
|
1417
1508
|
|
|
1418
1509
|
# Validate the selected geo and time dimensions and create a mask.
|
|
1419
1510
|
if selected_geos is not None:
|
|
1420
|
-
if any(geo not in
|
|
1511
|
+
if any(geo not in m_context.input_data.geo for geo in selected_geos):
|
|
1421
1512
|
raise ValueError(
|
|
1422
1513
|
"`selected_geos` must match the geo dimension names from "
|
|
1423
1514
|
"meridian.InputData."
|
|
1424
1515
|
)
|
|
1425
1516
|
geo_indices = [
|
|
1426
|
-
i
|
|
1517
|
+
i
|
|
1518
|
+
for i, x in enumerate(m_context.input_data.geo)
|
|
1519
|
+
if x in selected_geos
|
|
1427
1520
|
]
|
|
1428
1521
|
tensor = backend.gather(
|
|
1429
1522
|
tensor,
|
|
@@ -1434,14 +1527,16 @@ class Analyzer:
|
|
|
1434
1527
|
if selected_times is not None:
|
|
1435
1528
|
_validate_selected_times(
|
|
1436
1529
|
selected_times=selected_times,
|
|
1437
|
-
input_times=
|
|
1530
|
+
input_times=m_context.input_data.time,
|
|
1438
1531
|
n_times=tensor.shape[time_dim],
|
|
1439
1532
|
arg_name="selected_times",
|
|
1440
1533
|
comparison_arg_name="`tensor`",
|
|
1441
1534
|
)
|
|
1442
1535
|
if _is_str_list(selected_times):
|
|
1443
1536
|
time_indices = [
|
|
1444
|
-
i
|
|
1537
|
+
i
|
|
1538
|
+
for i, x in enumerate(m_context.input_data.time)
|
|
1539
|
+
if x in selected_times
|
|
1445
1540
|
]
|
|
1446
1541
|
tensor = backend.gather(
|
|
1447
1542
|
tensor,
|
|
@@ -1558,13 +1653,13 @@ class Analyzer:
|
|
|
1558
1653
|
"""
|
|
1559
1654
|
use_kpi = self._use_kpi(use_kpi)
|
|
1560
1655
|
self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
|
|
1561
|
-
if self.
|
|
1656
|
+
if self.model_context.is_national:
|
|
1562
1657
|
_warn_if_geo_arg_in_kwargs(
|
|
1563
1658
|
aggregate_geos=aggregate_geos,
|
|
1564
1659
|
selected_geos=selected_geos,
|
|
1565
1660
|
)
|
|
1566
1661
|
dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
|
|
1567
|
-
if dist_type not in self.
|
|
1662
|
+
if dist_type not in self._inference_data.groups():
|
|
1568
1663
|
raise model.NotFittedModelError(
|
|
1569
1664
|
f"sample_{dist_type}() must be called prior to calling"
|
|
1570
1665
|
" `expected_outcome()`."
|
|
@@ -1577,14 +1672,14 @@ class Analyzer:
|
|
|
1577
1672
|
)
|
|
1578
1673
|
filled_tensors = new_data.validate_and_fill_missing_data(
|
|
1579
1674
|
required_tensors_names=required_fields,
|
|
1580
|
-
|
|
1675
|
+
model_context=self.model_context,
|
|
1581
1676
|
allow_modified_times=False,
|
|
1582
1677
|
)
|
|
1583
1678
|
|
|
1584
1679
|
params = (
|
|
1585
|
-
self.
|
|
1680
|
+
self._inference_data.posterior
|
|
1586
1681
|
if use_posterior
|
|
1587
|
-
else self.
|
|
1682
|
+
else self._inference_data.prior
|
|
1588
1683
|
)
|
|
1589
1684
|
# We always compute the expected outcome of all channels, including non-paid
|
|
1590
1685
|
# channels.
|
|
@@ -1596,7 +1691,7 @@ class Analyzer:
|
|
|
1596
1691
|
n_draws = params.draw.size
|
|
1597
1692
|
n_chains = params.chain.size
|
|
1598
1693
|
outcome_means = backend.zeros(
|
|
1599
|
-
(n_chains, 0, self.
|
|
1694
|
+
(n_chains, 0, self.model_context.n_geos, self.model_context.n_times)
|
|
1600
1695
|
)
|
|
1601
1696
|
batch_starting_indices = np.arange(n_draws, step=batch_size)
|
|
1602
1697
|
param_list = (
|
|
@@ -1604,7 +1699,7 @@ class Analyzer:
|
|
|
1604
1699
|
constants.MU_T,
|
|
1605
1700
|
constants.TAU_G,
|
|
1606
1701
|
]
|
|
1607
|
-
+ ([constants.GAMMA_GC] if self.
|
|
1702
|
+
+ ([constants.GAMMA_GC] if self.model_context.n_controls else [])
|
|
1608
1703
|
+ self._get_causal_param_names(include_non_paid_channels=True)
|
|
1609
1704
|
)
|
|
1610
1705
|
outcome_means_temps = []
|
|
@@ -1626,7 +1721,7 @@ class Analyzer:
|
|
|
1626
1721
|
[outcome_means, *outcome_means_temps], axis=1
|
|
1627
1722
|
)
|
|
1628
1723
|
if inverse_transform_outcome:
|
|
1629
|
-
outcome_means = self.
|
|
1724
|
+
outcome_means = self.model_context.kpi_transformer.inverse(outcome_means)
|
|
1630
1725
|
if not use_kpi:
|
|
1631
1726
|
outcome_means *= filled_tensors.revenue_per_kpi
|
|
1632
1727
|
|
|
@@ -1700,7 +1795,7 @@ class Analyzer:
|
|
|
1700
1795
|
" `_get_incremental_kpi` when `non_media_treatments` data is"
|
|
1701
1796
|
" present."
|
|
1702
1797
|
)
|
|
1703
|
-
n_media_times = self.
|
|
1798
|
+
n_media_times = self.model_context.n_media_times
|
|
1704
1799
|
if data_tensors.media is not None:
|
|
1705
1800
|
n_times = data_tensors.media.shape[1] # pytype: disable=attribute-error
|
|
1706
1801
|
n_times_output = n_times if n_times != n_media_times else None
|
|
@@ -1759,11 +1854,11 @@ class Analyzer:
|
|
|
1759
1854
|
"""
|
|
1760
1855
|
use_kpi = self._use_kpi(use_kpi)
|
|
1761
1856
|
if revenue_per_kpi is None:
|
|
1762
|
-
revenue_per_kpi = self.
|
|
1763
|
-
t1 = self.
|
|
1857
|
+
revenue_per_kpi = self.model_context.revenue_per_kpi
|
|
1858
|
+
t1 = self.model_context.kpi_transformer.inverse(
|
|
1764
1859
|
backend.einsum("...m->m...", modeled_incremental_outcome)
|
|
1765
1860
|
)
|
|
1766
|
-
t2 = self.
|
|
1861
|
+
t2 = self.model_context.kpi_transformer.inverse(backend.zeros_like(t1))
|
|
1767
1862
|
kpi = backend.einsum("m...->...m", t1 - t2)
|
|
1768
1863
|
|
|
1769
1864
|
if use_kpi:
|
|
@@ -1884,7 +1979,7 @@ class Analyzer:
|
|
|
1884
1979
|
has_media_dim=True,
|
|
1885
1980
|
)
|
|
1886
1981
|
|
|
1887
|
-
# TODO:
|
|
1982
|
+
# TODO: Add support for `new_data.time`.
|
|
1888
1983
|
def incremental_outcome(
|
|
1889
1984
|
self,
|
|
1890
1985
|
use_posterior: bool = True,
|
|
@@ -2050,10 +2145,10 @@ class Analyzer:
|
|
|
2050
2145
|
dimension and not all treatment variables are provided in `new_data`
|
|
2051
2146
|
with matching time dimensions.
|
|
2052
2147
|
"""
|
|
2053
|
-
|
|
2148
|
+
m_context = self.model_context
|
|
2054
2149
|
use_kpi = self._use_kpi(use_kpi)
|
|
2055
2150
|
self._check_kpi_transformation(inverse_transform_outcome, use_kpi)
|
|
2056
|
-
if
|
|
2151
|
+
if m_context.is_national:
|
|
2057
2152
|
_warn_if_geo_arg_in_kwargs(
|
|
2058
2153
|
aggregate_geos=aggregate_geos,
|
|
2059
2154
|
selected_geos=selected_geos,
|
|
@@ -2061,7 +2156,7 @@ class Analyzer:
|
|
|
2061
2156
|
_validate_non_media_baseline_values_numbers(non_media_baseline_values)
|
|
2062
2157
|
dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
|
|
2063
2158
|
|
|
2064
|
-
if dist_type not in
|
|
2159
|
+
if dist_type not in self.inference_data.groups():
|
|
2065
2160
|
raise model.NotFittedModelError(
|
|
2066
2161
|
f"sample_{dist_type}() must be called prior to calling this method."
|
|
2067
2162
|
)
|
|
@@ -2084,23 +2179,26 @@ class Analyzer:
|
|
|
2084
2179
|
if include_non_paid_channels:
|
|
2085
2180
|
required_params += constants.NON_PAID_DATA
|
|
2086
2181
|
data_tensors = new_data.validate_and_fill_missing_data(
|
|
2087
|
-
required_tensors_names=required_params,
|
|
2182
|
+
required_tensors_names=required_params,
|
|
2183
|
+
model_context=self.model_context,
|
|
2184
|
+
)
|
|
2185
|
+
new_n_media_times = data_tensors.get_modified_times(
|
|
2186
|
+
model_context=self.model_context
|
|
2088
2187
|
)
|
|
2089
|
-
new_n_media_times = data_tensors.get_modified_times(self._meridian)
|
|
2090
2188
|
|
|
2091
2189
|
if new_n_media_times is None:
|
|
2092
|
-
new_n_media_times =
|
|
2190
|
+
new_n_media_times = m_context.n_media_times
|
|
2093
2191
|
_validate_selected_times(
|
|
2094
2192
|
selected_times=selected_times,
|
|
2095
|
-
input_times=
|
|
2096
|
-
n_times=
|
|
2193
|
+
input_times=m_context.input_data.time,
|
|
2194
|
+
n_times=m_context.n_times,
|
|
2097
2195
|
arg_name="selected_times",
|
|
2098
2196
|
comparison_arg_name="the input data",
|
|
2099
2197
|
)
|
|
2100
2198
|
_validate_selected_times(
|
|
2101
2199
|
selected_times=media_selected_times,
|
|
2102
|
-
input_times=
|
|
2103
|
-
n_times=
|
|
2200
|
+
input_times=m_context.input_data.media_time,
|
|
2201
|
+
n_times=m_context.n_media_times,
|
|
2104
2202
|
arg_name="media_selected_times",
|
|
2105
2203
|
comparison_arg_name="the media tensors",
|
|
2106
2204
|
)
|
|
@@ -2115,7 +2213,7 @@ class Analyzer:
|
|
|
2115
2213
|
else:
|
|
2116
2214
|
if all(isinstance(time, str) for time in media_selected_times):
|
|
2117
2215
|
media_selected_times = [
|
|
2118
|
-
x in media_selected_times for x in
|
|
2216
|
+
x in media_selected_times for x in m_context.input_data.media_time
|
|
2119
2217
|
]
|
|
2120
2218
|
|
|
2121
2219
|
# Set counterfactual tensors based on the scaling factors and the media
|
|
@@ -2129,11 +2227,11 @@ class Analyzer:
|
|
|
2129
2227
|
|
|
2130
2228
|
if data_tensors.non_media_treatments is not None:
|
|
2131
2229
|
non_media_treatments_baseline_scaled = (
|
|
2132
|
-
self.
|
|
2230
|
+
self._model_equations.compute_non_media_treatments_baseline(
|
|
2133
2231
|
non_media_baseline_values=non_media_baseline_values,
|
|
2134
2232
|
)
|
|
2135
2233
|
)
|
|
2136
|
-
non_media_treatments_baseline_normalized = self.
|
|
2234
|
+
non_media_treatments_baseline_normalized = self.model_context.non_media_transformer.forward( # pytype: disable=attribute-error
|
|
2137
2235
|
non_media_treatments_baseline_scaled,
|
|
2138
2236
|
apply_population_scaling=False,
|
|
2139
2237
|
)
|
|
@@ -2160,7 +2258,7 @@ class Analyzer:
|
|
|
2160
2258
|
new_data=incremented_data0,
|
|
2161
2259
|
include_non_paid_channels=include_non_paid_channels,
|
|
2162
2260
|
)
|
|
2163
|
-
# TODO:
|
|
2261
|
+
# TODO: Verify the computation of outcome of non-media
|
|
2164
2262
|
# treatments with `media_selected_times` and scale factors.
|
|
2165
2263
|
|
|
2166
2264
|
data_tensors0 = DataTensors(
|
|
@@ -2181,9 +2279,9 @@ class Analyzer:
|
|
|
2181
2279
|
|
|
2182
2280
|
# Calculate incremental outcome in batches.
|
|
2183
2281
|
params = (
|
|
2184
|
-
self.
|
|
2282
|
+
self.inference_data.posterior
|
|
2185
2283
|
if use_posterior
|
|
2186
|
-
else self.
|
|
2284
|
+
else self.inference_data.prior
|
|
2187
2285
|
)
|
|
2188
2286
|
n_draws = params.draw.size
|
|
2189
2287
|
batch_starting_indices = np.arange(n_draws, step=batch_size)
|
|
@@ -2252,23 +2350,23 @@ class Analyzer:
|
|
|
2252
2350
|
ValueError: If the geo or time granularity arguments are not valid for the
|
|
2253
2351
|
ROI analysis.
|
|
2254
2352
|
"""
|
|
2255
|
-
if self.
|
|
2353
|
+
if self.model_context.is_national:
|
|
2256
2354
|
_warn_if_geo_arg_in_kwargs(
|
|
2257
2355
|
aggregate_geos=aggregate_geos,
|
|
2258
2356
|
selected_geos=selected_geos,
|
|
2259
2357
|
)
|
|
2260
2358
|
if selected_geos is not None or not aggregate_geos:
|
|
2261
2359
|
if (
|
|
2262
|
-
self.
|
|
2263
|
-
and not self.
|
|
2360
|
+
self.model_context.media_tensors.media_spend is not None
|
|
2361
|
+
and not self.model_context.input_data.media_spend_has_geo_dimension
|
|
2264
2362
|
):
|
|
2265
2363
|
raise ValueError(
|
|
2266
2364
|
"`selected_geos` and `aggregate_geos=False` are not allowed because"
|
|
2267
2365
|
" Meridian `media_spend` data does not have a geo dimension."
|
|
2268
2366
|
)
|
|
2269
2367
|
if (
|
|
2270
|
-
self.
|
|
2271
|
-
and not self.
|
|
2368
|
+
self.model_context.rf_tensors.rf_spend is not None
|
|
2369
|
+
and not self.model_context.input_data.rf_spend_has_geo_dimension
|
|
2272
2370
|
):
|
|
2273
2371
|
raise ValueError(
|
|
2274
2372
|
"`selected_geos` and `aggregate_geos=False` are not allowed because"
|
|
@@ -2277,16 +2375,16 @@ class Analyzer:
|
|
|
2277
2375
|
|
|
2278
2376
|
if selected_times is not None:
|
|
2279
2377
|
if (
|
|
2280
|
-
self.
|
|
2281
|
-
and not self.
|
|
2378
|
+
self.model_context.media_tensors.media_spend is not None
|
|
2379
|
+
and not self.model_context.input_data.media_spend_has_time_dimension
|
|
2282
2380
|
):
|
|
2283
2381
|
raise ValueError(
|
|
2284
2382
|
"`selected_times` is not allowed because Meridian `media_spend`"
|
|
2285
2383
|
" data does not have a time dimension."
|
|
2286
2384
|
)
|
|
2287
2385
|
if (
|
|
2288
|
-
self.
|
|
2289
|
-
and not self.
|
|
2386
|
+
self.model_context.rf_tensors.rf_spend is not None
|
|
2387
|
+
and not self.model_context.input_data.rf_spend_has_time_dimension
|
|
2290
2388
|
):
|
|
2291
2389
|
raise ValueError(
|
|
2292
2390
|
"`selected_times` is not allowed because Meridian `rf_spend` data"
|
|
@@ -2379,7 +2477,7 @@ class Analyzer:
|
|
|
2379
2477
|
new_data = DataTensors()
|
|
2380
2478
|
filled_data = new_data.validate_and_fill_missing_data(
|
|
2381
2479
|
required_tensors_names=required_values,
|
|
2382
|
-
|
|
2480
|
+
model_context=self.model_context,
|
|
2383
2481
|
)
|
|
2384
2482
|
numerator = self.incremental_outcome(
|
|
2385
2483
|
new_data=filled_data.filter_fields(constants.PAID_DATA),
|
|
@@ -2396,25 +2494,26 @@ class Analyzer:
|
|
|
2396
2494
|
)
|
|
2397
2495
|
spend_inc = filled_data.total_spend() * incremental_increase
|
|
2398
2496
|
if spend_inc is not None and spend_inc.ndim == 3:
|
|
2399
|
-
|
|
2400
|
-
|
|
2401
|
-
|
|
2402
|
-
|
|
2403
|
-
|
|
2404
|
-
|
|
2497
|
+
return backend.divide(
|
|
2498
|
+
numerator,
|
|
2499
|
+
self.filter_and_aggregate_geos_and_times(
|
|
2500
|
+
spend_inc,
|
|
2501
|
+
aggregate_times=True,
|
|
2502
|
+
flexible_time_dim=True,
|
|
2503
|
+
has_media_dim=True,
|
|
2504
|
+
**dim_kwargs,
|
|
2505
|
+
),
|
|
2405
2506
|
)
|
|
2406
|
-
|
|
2407
|
-
|
|
2408
|
-
|
|
2409
|
-
|
|
2410
|
-
|
|
2411
|
-
|
|
2412
|
-
|
|
2413
|
-
|
|
2414
|
-
|
|
2415
|
-
|
|
2416
|
-
denominator = spend_inc
|
|
2417
|
-
return backend.divide_no_nan(numerator, denominator)
|
|
2507
|
+
|
|
2508
|
+
if not aggregate_geos:
|
|
2509
|
+
# This check should not be reachable. It is here to protect against
|
|
2510
|
+
# future changes to self._validate_geo_and_time_granularity. If
|
|
2511
|
+
# spend_inc.ndim is not 3 and `aggregate_geos` is `False`, then
|
|
2512
|
+
# self._validate_geo_and_time_granularity should raise an error.
|
|
2513
|
+
raise ValueError(
|
|
2514
|
+
"aggregate_geos must be True if spend does not have a geo dimension."
|
|
2515
|
+
)
|
|
2516
|
+
return backend.divide(numerator, spend_inc)
|
|
2418
2517
|
|
|
2419
2518
|
def roi(
|
|
2420
2519
|
self,
|
|
@@ -2501,7 +2600,7 @@ class Analyzer:
|
|
|
2501
2600
|
new_data = DataTensors()
|
|
2502
2601
|
filled_data = new_data.validate_and_fill_missing_data(
|
|
2503
2602
|
required_tensors_names=required_values,
|
|
2504
|
-
|
|
2603
|
+
model_context=self.model_context,
|
|
2505
2604
|
)
|
|
2506
2605
|
incremental_outcome = self.incremental_outcome(
|
|
2507
2606
|
new_data=filled_data.filter_fields(constants.PAID_DATA),
|
|
@@ -2511,26 +2610,27 @@ class Analyzer:
|
|
|
2511
2610
|
|
|
2512
2611
|
spend = filled_data.total_spend()
|
|
2513
2612
|
if spend is not None and spend.ndim == 3:
|
|
2514
|
-
|
|
2515
|
-
|
|
2516
|
-
|
|
2517
|
-
|
|
2518
|
-
|
|
2519
|
-
|
|
2613
|
+
return backend.divide(
|
|
2614
|
+
incremental_outcome,
|
|
2615
|
+
self.filter_and_aggregate_geos_and_times(
|
|
2616
|
+
spend,
|
|
2617
|
+
aggregate_times=True,
|
|
2618
|
+
flexible_time_dim=True,
|
|
2619
|
+
has_media_dim=True,
|
|
2620
|
+
**dim_kwargs,
|
|
2621
|
+
),
|
|
2520
2622
|
)
|
|
2521
|
-
|
|
2522
|
-
|
|
2523
|
-
|
|
2524
|
-
|
|
2525
|
-
|
|
2526
|
-
|
|
2527
|
-
|
|
2528
|
-
|
|
2529
|
-
|
|
2530
|
-
|
|
2531
|
-
|
|
2532
|
-
denominator = spend
|
|
2533
|
-
return backend.divide_no_nan(incremental_outcome, denominator)
|
|
2623
|
+
|
|
2624
|
+
if not aggregate_geos:
|
|
2625
|
+
# This check should not be reachable. It is here to protect against
|
|
2626
|
+
# future changes to self._validate_geo_and_time_granularity. If
|
|
2627
|
+
# spend_inc.ndim is not 3 and either of `aggregate_geos` or
|
|
2628
|
+
# `aggregate_times` is `False`, then
|
|
2629
|
+
# self._validate_geo_and_time_granularity should raise an error.
|
|
2630
|
+
raise ValueError(
|
|
2631
|
+
"aggregate_geos must be True if spend does not have a geo dimension."
|
|
2632
|
+
)
|
|
2633
|
+
return backend.divide(incremental_outcome, spend)
|
|
2534
2634
|
|
|
2535
2635
|
def cpik(
|
|
2536
2636
|
self,
|
|
@@ -2605,7 +2705,7 @@ class Analyzer:
|
|
|
2605
2705
|
aggregate_geos=aggregate_geos,
|
|
2606
2706
|
batch_size=batch_size,
|
|
2607
2707
|
)
|
|
2608
|
-
return backend.
|
|
2708
|
+
return backend.divide(1, roi)
|
|
2609
2709
|
|
|
2610
2710
|
def _mean_and_ci_by_eval_set(
|
|
2611
2711
|
self,
|
|
@@ -2647,8 +2747,12 @@ class Analyzer:
|
|
|
2647
2747
|
draws, confidence_level=confidence_level
|
|
2648
2748
|
)
|
|
2649
2749
|
|
|
2650
|
-
train_draws = np.where(
|
|
2651
|
-
|
|
2750
|
+
train_draws = np.where(
|
|
2751
|
+
self.model_context.model_spec.holdout_id, np.nan, draws
|
|
2752
|
+
)
|
|
2753
|
+
test_draws = np.where(
|
|
2754
|
+
self.model_context.model_spec.holdout_id, draws, np.nan
|
|
2755
|
+
)
|
|
2652
2756
|
draws_by_evaluation_set = np.stack(
|
|
2653
2757
|
[train_draws, test_draws, draws], axis=0
|
|
2654
2758
|
) # shape (n_evaluation_sets(=3), n_chains, n_draws, n_geos, n_times)
|
|
@@ -2669,13 +2773,14 @@ class Analyzer:
|
|
|
2669
2773
|
|
|
2670
2774
|
def _can_split_by_holdout_id(self, split_by_holdout_id: bool) -> bool:
|
|
2671
2775
|
"""Returns whether the data can be split by holdout_id."""
|
|
2672
|
-
if split_by_holdout_id and self.
|
|
2776
|
+
if split_by_holdout_id and self.model_context.model_spec.holdout_id is None:
|
|
2673
2777
|
warnings.warn(
|
|
2674
2778
|
"`split_by_holdout_id` is True but `holdout_id` is `None`. Data will"
|
|
2675
2779
|
" not be split."
|
|
2676
2780
|
)
|
|
2677
2781
|
return (
|
|
2678
|
-
split_by_holdout_id
|
|
2782
|
+
split_by_holdout_id
|
|
2783
|
+
and self.model_context.model_spec.holdout_id is not None
|
|
2679
2784
|
)
|
|
2680
2785
|
|
|
2681
2786
|
def expected_vs_actual_data(
|
|
@@ -2713,7 +2818,7 @@ class Analyzer:
|
|
|
2713
2818
|
"""
|
|
2714
2819
|
_validate_non_media_baseline_values_numbers(non_media_baseline_values)
|
|
2715
2820
|
use_kpi = self._use_kpi(use_kpi)
|
|
2716
|
-
|
|
2821
|
+
m_context = self.model_context
|
|
2717
2822
|
can_split_by_holdout = self._can_split_by_holdout_id(split_by_holdout_id)
|
|
2718
2823
|
expected_outcome = self.expected_outcome(
|
|
2719
2824
|
aggregate_geos=False, aggregate_times=False, use_kpi=use_kpi
|
|
@@ -2742,7 +2847,9 @@ class Analyzer:
|
|
|
2742
2847
|
)
|
|
2743
2848
|
actual = np.asarray(
|
|
2744
2849
|
self.filter_and_aggregate_geos_and_times(
|
|
2745
|
-
|
|
2850
|
+
m_context.kpi
|
|
2851
|
+
if use_kpi
|
|
2852
|
+
else m_context.kpi * m_context.revenue_per_kpi,
|
|
2746
2853
|
aggregate_geos=aggregate_geos,
|
|
2747
2854
|
aggregate_times=aggregate_times,
|
|
2748
2855
|
)
|
|
@@ -2754,9 +2861,9 @@ class Analyzer:
|
|
|
2754
2861
|
}
|
|
2755
2862
|
|
|
2756
2863
|
if not aggregate_geos:
|
|
2757
|
-
coords[constants.GEO] =
|
|
2864
|
+
coords[constants.GEO] = m_context.input_data.geo.data
|
|
2758
2865
|
if not aggregate_times:
|
|
2759
|
-
coords[constants.TIME] =
|
|
2866
|
+
coords[constants.TIME] = m_context.input_data.time.data
|
|
2760
2867
|
if can_split_by_holdout:
|
|
2761
2868
|
coords[constants.EVALUATION_SET_VAR] = list(constants.EVALUATION_SET)
|
|
2762
2869
|
|
|
@@ -2816,56 +2923,57 @@ class Analyzer:
|
|
|
2816
2923
|
n_draws, n_geos, n_times)`. The `n_geos` and `n_times` dimensions is
|
|
2817
2924
|
dropped if `aggregate_geos=True` or `aggregate_time=True`, respectively.
|
|
2818
2925
|
"""
|
|
2926
|
+
ctx = self.model_context
|
|
2819
2927
|
new_media = (
|
|
2820
|
-
backend.zeros_like(
|
|
2821
|
-
if
|
|
2928
|
+
backend.zeros_like(ctx.media_tensors.media)
|
|
2929
|
+
if ctx.media_tensors.media is not None
|
|
2822
2930
|
else None
|
|
2823
2931
|
)
|
|
2824
2932
|
# Frequency is not needed because the reach is zero.
|
|
2825
2933
|
new_reach = (
|
|
2826
|
-
backend.zeros_like(
|
|
2827
|
-
if
|
|
2934
|
+
backend.zeros_like(ctx.rf_tensors.reach)
|
|
2935
|
+
if ctx.rf_tensors.reach is not None
|
|
2828
2936
|
else None
|
|
2829
2937
|
)
|
|
2830
2938
|
new_organic_media = (
|
|
2831
|
-
backend.zeros_like(
|
|
2832
|
-
if
|
|
2939
|
+
backend.zeros_like(ctx.organic_media_tensors.organic_media)
|
|
2940
|
+
if ctx.organic_media_tensors.organic_media is not None
|
|
2833
2941
|
else None
|
|
2834
2942
|
)
|
|
2835
2943
|
new_organic_reach = (
|
|
2836
|
-
backend.zeros_like(
|
|
2837
|
-
if
|
|
2944
|
+
backend.zeros_like(ctx.organic_rf_tensors.organic_reach)
|
|
2945
|
+
if ctx.organic_rf_tensors.organic_reach is not None
|
|
2838
2946
|
else None
|
|
2839
2947
|
)
|
|
2840
|
-
if
|
|
2841
|
-
if
|
|
2948
|
+
if ctx.non_media_treatments is not None:
|
|
2949
|
+
if ctx.model_spec.non_media_population_scaling_id is not None:
|
|
2842
2950
|
scaling_factors = backend.where(
|
|
2843
|
-
|
|
2844
|
-
|
|
2845
|
-
backend.ones_like(
|
|
2951
|
+
ctx.model_spec.non_media_population_scaling_id,
|
|
2952
|
+
ctx.population[:, backend.newaxis, backend.newaxis],
|
|
2953
|
+
backend.ones_like(ctx.population)[
|
|
2846
2954
|
:, backend.newaxis, backend.newaxis
|
|
2847
2955
|
],
|
|
2848
2956
|
)
|
|
2849
2957
|
else:
|
|
2850
|
-
scaling_factors = backend.ones_like(
|
|
2958
|
+
scaling_factors = backend.ones_like(ctx.population)[
|
|
2851
2959
|
:, backend.newaxis, backend.newaxis
|
|
2852
2960
|
]
|
|
2853
2961
|
|
|
2854
|
-
baseline = self.
|
|
2962
|
+
baseline = self._model_equations.compute_non_media_treatments_baseline(
|
|
2855
2963
|
non_media_baseline_values=non_media_baseline_values,
|
|
2856
2964
|
)
|
|
2857
2965
|
new_non_media_treatments_population_scaled = backend.broadcast_to(
|
|
2858
2966
|
backend.to_tensor(baseline, dtype=backend.float32)[
|
|
2859
2967
|
backend.newaxis, backend.newaxis, :
|
|
2860
2968
|
],
|
|
2861
|
-
|
|
2969
|
+
ctx.non_media_treatments.shape,
|
|
2862
2970
|
)
|
|
2863
2971
|
new_non_media_treatments = (
|
|
2864
2972
|
new_non_media_treatments_population_scaled * scaling_factors
|
|
2865
2973
|
)
|
|
2866
2974
|
else:
|
|
2867
2975
|
new_non_media_treatments = None
|
|
2868
|
-
new_controls =
|
|
2976
|
+
new_controls = ctx.controls
|
|
2869
2977
|
|
|
2870
2978
|
new_data = DataTensors(
|
|
2871
2979
|
media=new_media,
|
|
@@ -3134,23 +3242,25 @@ class Analyzer:
|
|
|
3134
3242
|
+ (constants.CHANNEL,)
|
|
3135
3243
|
)
|
|
3136
3244
|
channels = (
|
|
3137
|
-
self.
|
|
3245
|
+
self.model_context.input_data.get_all_channels()
|
|
3138
3246
|
if include_non_paid_channels
|
|
3139
|
-
else self.
|
|
3247
|
+
else self.model_context.input_data.get_all_paid_channels()
|
|
3140
3248
|
)
|
|
3141
3249
|
xr_coords = {constants.CHANNEL: list(channels) + [constants.ALL_CHANNELS]}
|
|
3142
3250
|
if not aggregate_geos:
|
|
3143
3251
|
geo_dims = (
|
|
3144
|
-
self.
|
|
3252
|
+
self.model_context.input_data.geo.data
|
|
3145
3253
|
if selected_geos is None
|
|
3146
3254
|
else selected_geos
|
|
3147
3255
|
)
|
|
3148
3256
|
xr_coords[constants.GEO] = geo_dims
|
|
3149
3257
|
if not aggregate_times:
|
|
3150
3258
|
# Get the time coordinates for flexible time dimensions.
|
|
3151
|
-
modified_times = new_data.get_modified_times(
|
|
3259
|
+
modified_times = new_data.get_modified_times(
|
|
3260
|
+
model_context=self.model_context
|
|
3261
|
+
)
|
|
3152
3262
|
if modified_times is None:
|
|
3153
|
-
times = self.
|
|
3263
|
+
times = self.model_context.input_data.time.data
|
|
3154
3264
|
else:
|
|
3155
3265
|
times = np.arange(modified_times)
|
|
3156
3266
|
|
|
@@ -3199,7 +3309,7 @@ class Analyzer:
|
|
|
3199
3309
|
# channels.
|
|
3200
3310
|
).where(lambda ds: ds.channel != constants.ALL_CHANNELS)
|
|
3201
3311
|
|
|
3202
|
-
if new_data.get_modified_times(self.
|
|
3312
|
+
if new_data.get_modified_times(model_context=self.model_context) is None:
|
|
3203
3313
|
expected_outcome_fields = list(
|
|
3204
3314
|
constants.PAID_DATA + constants.NON_PAID_DATA + (constants.CONTROLS,)
|
|
3205
3315
|
)
|
|
@@ -3239,26 +3349,35 @@ class Analyzer:
|
|
|
3239
3349
|
"Effectiveness is not reported because it does not have a clear"
|
|
3240
3350
|
" interpretation by time period."
|
|
3241
3351
|
)
|
|
3242
|
-
return xr.merge(
|
|
3243
|
-
|
|
3244
|
-
|
|
3245
|
-
|
|
3352
|
+
return xr.merge(
|
|
3353
|
+
[
|
|
3354
|
+
incremental_outcome,
|
|
3355
|
+
pct_of_contribution,
|
|
3356
|
+
],
|
|
3357
|
+
compat="no_conflicts",
|
|
3358
|
+
)
|
|
3246
3359
|
else:
|
|
3247
|
-
return xr.merge(
|
|
3248
|
-
|
|
3249
|
-
|
|
3250
|
-
|
|
3251
|
-
|
|
3360
|
+
return xr.merge(
|
|
3361
|
+
[
|
|
3362
|
+
incremental_outcome,
|
|
3363
|
+
pct_of_contribution,
|
|
3364
|
+
effectiveness,
|
|
3365
|
+
],
|
|
3366
|
+
compat="no_conflicts",
|
|
3367
|
+
)
|
|
3252
3368
|
|
|
3253
3369
|
# If non-paid channels are not included, return all metrics, paid and
|
|
3254
3370
|
# non-paid.
|
|
3255
3371
|
spend_list = []
|
|
3256
3372
|
new_spend_tensors = new_data.filter_fields(
|
|
3257
3373
|
constants.SPEND_DATA
|
|
3258
|
-
).validate_and_fill_missing_data(
|
|
3259
|
-
|
|
3374
|
+
).validate_and_fill_missing_data(
|
|
3375
|
+
required_tensors_names=constants.SPEND_DATA,
|
|
3376
|
+
model_context=self.model_context,
|
|
3377
|
+
)
|
|
3378
|
+
if self.model_context.n_media_channels > 0:
|
|
3260
3379
|
spend_list.append(new_spend_tensors.media_spend)
|
|
3261
|
-
if self.
|
|
3380
|
+
if self.model_context.n_rf_channels > 0:
|
|
3262
3381
|
spend_list.append(new_spend_tensors.rf_spend)
|
|
3263
3382
|
# TODO Add support for 1-dimensional spend.
|
|
3264
3383
|
aggregated_spend = self.filter_and_aggregate_geos_and_times(
|
|
@@ -3288,11 +3407,14 @@ class Analyzer:
|
|
|
3288
3407
|
"ROI, mROI, Effectiveness, and CPIK are not reported because they "
|
|
3289
3408
|
"do not have a clear interpretation by time period."
|
|
3290
3409
|
)
|
|
3291
|
-
return xr.merge(
|
|
3292
|
-
|
|
3293
|
-
|
|
3294
|
-
|
|
3295
|
-
|
|
3410
|
+
return xr.merge(
|
|
3411
|
+
[
|
|
3412
|
+
spend_data,
|
|
3413
|
+
incremental_outcome,
|
|
3414
|
+
pct_of_contribution,
|
|
3415
|
+
],
|
|
3416
|
+
compat="no_conflicts",
|
|
3417
|
+
)
|
|
3296
3418
|
else:
|
|
3297
3419
|
roi = self._compute_roi_aggregate(
|
|
3298
3420
|
incremental_outcome_prior=incremental_outcome_prior,
|
|
@@ -3339,15 +3461,18 @@ class Analyzer:
|
|
|
3339
3461
|
xr_coords=xr_coords_with_ci_and_distribution,
|
|
3340
3462
|
confidence_level=confidence_level,
|
|
3341
3463
|
)
|
|
3342
|
-
return xr.merge(
|
|
3343
|
-
|
|
3344
|
-
|
|
3345
|
-
|
|
3346
|
-
|
|
3347
|
-
|
|
3348
|
-
|
|
3349
|
-
|
|
3350
|
-
|
|
3464
|
+
return xr.merge(
|
|
3465
|
+
[
|
|
3466
|
+
spend_data,
|
|
3467
|
+
incremental_outcome,
|
|
3468
|
+
pct_of_contribution,
|
|
3469
|
+
roi,
|
|
3470
|
+
effectiveness,
|
|
3471
|
+
mroi,
|
|
3472
|
+
cpik,
|
|
3473
|
+
],
|
|
3474
|
+
compat="no_conflicts",
|
|
3475
|
+
)
|
|
3351
3476
|
|
|
3352
3477
|
def get_aggregated_impressions(
|
|
3353
3478
|
self,
|
|
@@ -3401,17 +3526,18 @@ class Analyzer:
|
|
|
3401
3526
|
if new_data is None:
|
|
3402
3527
|
new_data = DataTensors()
|
|
3403
3528
|
data_tensors = new_data.validate_and_fill_missing_data(
|
|
3404
|
-
tensor_names_list,
|
|
3529
|
+
required_tensors_names=tensor_names_list,
|
|
3530
|
+
model_context=self.model_context,
|
|
3405
3531
|
)
|
|
3406
3532
|
n_times = (
|
|
3407
|
-
data_tensors.get_modified_times(self.
|
|
3408
|
-
or self.
|
|
3533
|
+
data_tensors.get_modified_times(model_context=self.model_context)
|
|
3534
|
+
or self.model_context.n_times
|
|
3409
3535
|
)
|
|
3410
3536
|
impressions_list = []
|
|
3411
|
-
if self.
|
|
3537
|
+
if self.model_context.n_media_channels > 0:
|
|
3412
3538
|
impressions_list.append(data_tensors.media[:, -n_times:, :])
|
|
3413
3539
|
|
|
3414
|
-
if self.
|
|
3540
|
+
if self.model_context.n_rf_channels > 0:
|
|
3415
3541
|
if optimal_frequency is None:
|
|
3416
3542
|
new_frequency = data_tensors.frequency
|
|
3417
3543
|
else:
|
|
@@ -3423,9 +3549,9 @@ class Analyzer:
|
|
|
3423
3549
|
)
|
|
3424
3550
|
|
|
3425
3551
|
if include_non_paid_channels:
|
|
3426
|
-
if self.
|
|
3552
|
+
if self.model_context.n_organic_media_channels > 0:
|
|
3427
3553
|
impressions_list.append(data_tensors.organic_media[:, -n_times:, :])
|
|
3428
|
-
if self.
|
|
3554
|
+
if self.model_context.n_organic_rf_channels > 0:
|
|
3429
3555
|
if optimal_frequency is None:
|
|
3430
3556
|
new_organic_frequency = data_tensors.organic_frequency
|
|
3431
3557
|
else:
|
|
@@ -3437,7 +3563,7 @@ class Analyzer:
|
|
|
3437
3563
|
data_tensors.organic_reach[:, -n_times:, :]
|
|
3438
3564
|
* new_organic_frequency[:, -n_times:, :]
|
|
3439
3565
|
)
|
|
3440
|
-
if self.
|
|
3566
|
+
if self.model_context.n_non_media_channels > 0:
|
|
3441
3567
|
impressions_list.append(data_tensors.non_media_treatments)
|
|
3442
3568
|
|
|
3443
3569
|
return self.filter_and_aggregate_geos_and_times(
|
|
@@ -3512,14 +3638,14 @@ class Analyzer:
|
|
|
3512
3638
|
xr_coords = {constants.CHANNEL: [constants.BASELINE]}
|
|
3513
3639
|
if not aggregate_geos:
|
|
3514
3640
|
geo_dims = (
|
|
3515
|
-
self.
|
|
3641
|
+
self.model_context.input_data.geo.data
|
|
3516
3642
|
if selected_geos is None
|
|
3517
3643
|
else selected_geos
|
|
3518
3644
|
)
|
|
3519
3645
|
xr_coords[constants.GEO] = geo_dims
|
|
3520
3646
|
if not aggregate_times:
|
|
3521
3647
|
time_dims = (
|
|
3522
|
-
self.
|
|
3648
|
+
self.model_context.input_data.time.data
|
|
3523
3649
|
if selected_times is None
|
|
3524
3650
|
else selected_times
|
|
3525
3651
|
)
|
|
@@ -3585,10 +3711,13 @@ class Analyzer:
|
|
|
3585
3711
|
confidence_level=confidence_level,
|
|
3586
3712
|
).sel(channel=constants.BASELINE)
|
|
3587
3713
|
|
|
3588
|
-
return xr.merge(
|
|
3589
|
-
|
|
3590
|
-
|
|
3591
|
-
|
|
3714
|
+
return xr.merge(
|
|
3715
|
+
[
|
|
3716
|
+
baseline_outcome,
|
|
3717
|
+
baseline_pct_of_contribution,
|
|
3718
|
+
],
|
|
3719
|
+
compat="no_conflicts",
|
|
3720
|
+
)
|
|
3592
3721
|
|
|
3593
3722
|
def optimal_freq(
|
|
3594
3723
|
self,
|
|
@@ -3680,43 +3809,48 @@ class Analyzer:
|
|
|
3680
3809
|
dist_type = constants.POSTERIOR if use_posterior else constants.PRIOR
|
|
3681
3810
|
use_kpi = self._use_kpi(use_kpi)
|
|
3682
3811
|
new_data = new_data or DataTensors()
|
|
3683
|
-
if self.
|
|
3812
|
+
if self.model_context.n_rf_channels == 0:
|
|
3684
3813
|
raise ValueError(
|
|
3685
3814
|
"Must have at least one channel with reach and frequency data."
|
|
3686
3815
|
)
|
|
3687
|
-
if dist_type not in self.
|
|
3816
|
+
if dist_type not in self.inference_data.groups():
|
|
3688
3817
|
raise model.NotFittedModelError(
|
|
3689
3818
|
f"sample_{dist_type}() must be called prior to calling this method."
|
|
3690
3819
|
)
|
|
3691
3820
|
|
|
3692
3821
|
filled_data = new_data.validate_and_fill_missing_data(
|
|
3693
|
-
[
|
|
3822
|
+
required_tensors_names=[
|
|
3694
3823
|
constants.RF_IMPRESSIONS,
|
|
3695
3824
|
constants.RF_SPEND,
|
|
3696
3825
|
constants.REVENUE_PER_KPI,
|
|
3697
3826
|
],
|
|
3698
|
-
self.
|
|
3827
|
+
model_context=self.model_context,
|
|
3699
3828
|
)
|
|
3700
3829
|
# TODO: Once treatment type filtering is added, remove adding
|
|
3701
3830
|
# dummy media and media spend to `roi()` and `summary_metrics()`. This is a
|
|
3702
3831
|
# hack to use `roi()` and `summary_metrics()` for RF only analysis.
|
|
3703
|
-
has_media = self.
|
|
3832
|
+
has_media = self.model_context.n_media_channels > 0
|
|
3704
3833
|
n_media_times = (
|
|
3705
|
-
filled_data.get_modified_times(self.
|
|
3706
|
-
or self.
|
|
3834
|
+
filled_data.get_modified_times(model_context=self.model_context)
|
|
3835
|
+
or self.model_context.n_media_times
|
|
3707
3836
|
)
|
|
3708
3837
|
n_times = (
|
|
3709
|
-
filled_data.get_modified_times(self.
|
|
3710
|
-
|
|
3711
|
-
dummy_media = backend.ones(
|
|
3712
|
-
(self._meridian.n_geos, n_media_times, self._meridian.n_media_channels)
|
|
3713
|
-
)
|
|
3714
|
-
dummy_media_spend = backend.ones(
|
|
3715
|
-
(self._meridian.n_geos, n_times, self._meridian.n_media_channels)
|
|
3838
|
+
filled_data.get_modified_times(model_context=self.model_context)
|
|
3839
|
+
or self.model_context.n_times
|
|
3716
3840
|
)
|
|
3841
|
+
dummy_media = backend.ones((
|
|
3842
|
+
self.model_context.n_geos,
|
|
3843
|
+
n_media_times,
|
|
3844
|
+
self.model_context.n_media_channels,
|
|
3845
|
+
))
|
|
3846
|
+
dummy_media_spend = backend.ones((
|
|
3847
|
+
self.model_context.n_geos,
|
|
3848
|
+
n_times,
|
|
3849
|
+
self.model_context.n_media_channels,
|
|
3850
|
+
))
|
|
3717
3851
|
|
|
3718
3852
|
max_freq = max_frequency or np.max(
|
|
3719
|
-
np.array(self.
|
|
3853
|
+
np.array(self.model_context.rf_tensors.frequency)
|
|
3720
3854
|
)
|
|
3721
3855
|
if freq_grid is None:
|
|
3722
3856
|
freq_grid = np.arange(1, max_freq, 0.1)
|
|
@@ -3724,7 +3858,9 @@ class Analyzer:
|
|
|
3724
3858
|
# Create a frequency grid for shape (len(freq_grid), n_rf_channels, 4) where
|
|
3725
3859
|
# the last argument is for the mean, median, lower and upper confidence
|
|
3726
3860
|
# intervals.
|
|
3727
|
-
metric_grid = np.zeros(
|
|
3861
|
+
metric_grid = np.zeros(
|
|
3862
|
+
(len(freq_grid), self.model_context.n_rf_channels, 4)
|
|
3863
|
+
)
|
|
3728
3864
|
|
|
3729
3865
|
for i, freq in enumerate(freq_grid):
|
|
3730
3866
|
new_frequency = backend.ones_like(filled_data.rf_impressions) * freq
|
|
@@ -3744,15 +3880,15 @@ class Analyzer:
|
|
|
3744
3880
|
selected_times=selected_times,
|
|
3745
3881
|
aggregate_geos=True,
|
|
3746
3882
|
use_kpi=use_kpi,
|
|
3747
|
-
)[..., -self.
|
|
3883
|
+
)[..., -self.model_context.n_rf_channels :]
|
|
3748
3884
|
metric_grid[i, :] = get_central_tendency_and_ci(
|
|
3749
3885
|
metric_grid_temp, confidence_level, include_median=True
|
|
3750
3886
|
)
|
|
3751
3887
|
|
|
3752
3888
|
optimal_freq_idx = np.nanargmax(metric_grid[:, :, 0], axis=0)
|
|
3753
3889
|
rf_channel_values = (
|
|
3754
|
-
self.
|
|
3755
|
-
if self.
|
|
3890
|
+
self.model_context.input_data.rf_channel.values
|
|
3891
|
+
if self.model_context.input_data.rf_channel is not None
|
|
3756
3892
|
else []
|
|
3757
3893
|
)
|
|
3758
3894
|
|
|
@@ -3901,7 +4037,7 @@ class Analyzer:
|
|
|
3901
4037
|
three metrics are computed for each.
|
|
3902
4038
|
"""
|
|
3903
4039
|
use_kpi = self._use_kpi(use_kpi)
|
|
3904
|
-
if self.
|
|
4040
|
+
if self.model_context.is_national:
|
|
3905
4041
|
_warn_if_geo_arg_in_kwargs(
|
|
3906
4042
|
selected_geos=selected_geos,
|
|
3907
4043
|
)
|
|
@@ -3922,9 +4058,9 @@ class Analyzer:
|
|
|
3922
4058
|
constants.GEO_GRANULARITY: [constants.GEO, constants.NATIONAL],
|
|
3923
4059
|
}
|
|
3924
4060
|
if use_kpi:
|
|
3925
|
-
input_tensor = self.
|
|
4061
|
+
input_tensor = self.model_context.kpi
|
|
3926
4062
|
else:
|
|
3927
|
-
input_tensor = self.
|
|
4063
|
+
input_tensor = self.model_context.kpi * self.model_context.revenue_per_kpi
|
|
3928
4064
|
actual = np.asarray(
|
|
3929
4065
|
self.filter_and_aggregate_geos_and_times(
|
|
3930
4066
|
tensor=input_tensor,
|
|
@@ -3941,7 +4077,7 @@ class Analyzer:
|
|
|
3941
4077
|
rsquared_national, mape_national, wmape_national = (
|
|
3942
4078
|
self._predictive_accuracy_helper(np.sum(actual, 0), np.sum(expected, 0))
|
|
3943
4079
|
)
|
|
3944
|
-
if self.
|
|
4080
|
+
if self.model_context.model_spec.holdout_id is None:
|
|
3945
4081
|
rsquared_arr = [rsquared, rsquared_national]
|
|
3946
4082
|
mape_arr = [mape, mape_national]
|
|
3947
4083
|
wmape_arr = [wmape, wmape_national]
|
|
@@ -3955,7 +4091,9 @@ class Analyzer:
|
|
|
3955
4091
|
xr_coords[constants.EVALUATION_SET_VAR] = list(constants.EVALUATION_SET)
|
|
3956
4092
|
|
|
3957
4093
|
holdout_id = self._filter_holdout_id_for_selected_geos_and_times(
|
|
3958
|
-
self.
|
|
4094
|
+
self.model_context.model_spec.holdout_id,
|
|
4095
|
+
selected_geos,
|
|
4096
|
+
selected_times,
|
|
3959
4097
|
)
|
|
3960
4098
|
|
|
3961
4099
|
nansum = lambda x: np.where(
|
|
@@ -3985,7 +4123,7 @@ class Analyzer:
|
|
|
3985
4123
|
)
|
|
3986
4124
|
xr_data = {constants.VALUE: (xr_dims, stacked_total)}
|
|
3987
4125
|
dataset = xr.Dataset(data_vars=xr_data, coords=xr_coords)
|
|
3988
|
-
if self.
|
|
4126
|
+
if self.model_context.is_national:
|
|
3989
4127
|
# Remove the geo-level coordinate.
|
|
3990
4128
|
dataset = dataset.sel(geo_granularity=[constants.NATIONAL])
|
|
3991
4129
|
return dataset
|
|
@@ -4023,14 +4161,16 @@ class Analyzer:
|
|
|
4023
4161
|
) -> np.ndarray:
|
|
4024
4162
|
"""Filters the holdout_id array for selected times and geos."""
|
|
4025
4163
|
|
|
4026
|
-
if selected_geos is not None and not self.
|
|
4027
|
-
geo_mask = [x in selected_geos for x in self.
|
|
4164
|
+
if selected_geos is not None and not self.model_context.is_national:
|
|
4165
|
+
geo_mask = [x in selected_geos for x in self.model_context.input_data.geo]
|
|
4028
4166
|
holdout_id = holdout_id[geo_mask]
|
|
4029
4167
|
|
|
4030
4168
|
if selected_times is not None:
|
|
4031
|
-
time_mask = [
|
|
4169
|
+
time_mask = [
|
|
4170
|
+
x in selected_times for x in self.model_context.input_data.time
|
|
4171
|
+
]
|
|
4032
4172
|
# If model is national, holdout_id will have only 1 dimension.
|
|
4033
|
-
if self.
|
|
4173
|
+
if self.model_context.is_national:
|
|
4034
4174
|
holdout_id = holdout_id[time_mask]
|
|
4035
4175
|
else:
|
|
4036
4176
|
holdout_id = holdout_id[:, time_mask]
|
|
@@ -4048,7 +4188,7 @@ class Analyzer:
|
|
|
4048
4188
|
NotFittedModelError: If self.sample_posterior() is not called before
|
|
4049
4189
|
calling this method.
|
|
4050
4190
|
"""
|
|
4051
|
-
if constants.POSTERIOR not in self.
|
|
4191
|
+
if constants.POSTERIOR not in self._inference_data.groups():
|
|
4052
4192
|
raise model.NotFittedModelError(
|
|
4053
4193
|
"sample_posterior() must be called prior to calling this method."
|
|
4054
4194
|
)
|
|
@@ -4059,11 +4199,10 @@ class Analyzer:
|
|
|
4059
4199
|
perm = [1, 0] + list(range(2, n_dim))
|
|
4060
4200
|
return backend.transpose(x_tensor, perm)
|
|
4061
4201
|
|
|
4062
|
-
|
|
4202
|
+
return backend.mcmc.potential_scale_reduction({
|
|
4063
4203
|
k: _transpose_first_two_dims(v)
|
|
4064
|
-
for k, v in self.
|
|
4204
|
+
for k, v in self._inference_data.posterior.data_vars.items()
|
|
4065
4205
|
})
|
|
4066
|
-
return rhat
|
|
4067
4206
|
|
|
4068
4207
|
def rhat_summary(self, bad_rhat_threshold: float = 1.2) -> pd.DataFrame:
|
|
4069
4208
|
"""Computes a summary of the R-hat values for each parameter in the model.
|
|
@@ -4110,8 +4249,15 @@ class Analyzer:
|
|
|
4110
4249
|
|
|
4111
4250
|
rhat_summary = []
|
|
4112
4251
|
for param in rhat:
|
|
4252
|
+
# `tau_g` and `tau_g_excl_baseline` are the only parameters that have
|
|
4253
|
+
# inconsistent names in the prior and the posterior. Here, we ensure that
|
|
4254
|
+
# the `has_deterministic_param` takes the right parameter name.
|
|
4255
|
+
param_name = (
|
|
4256
|
+
constants.TAU_G_EXCL_BASELINE if param == constants.TAU_G else param
|
|
4257
|
+
)
|
|
4258
|
+
|
|
4113
4259
|
# Skip if parameter is deterministic according to the prior.
|
|
4114
|
-
if self.
|
|
4260
|
+
if self.model_context.prior_broadcast.has_deterministic_param(param_name):
|
|
4115
4261
|
continue
|
|
4116
4262
|
|
|
4117
4263
|
if rhat[param].ndim == 2:
|
|
@@ -4186,8 +4332,8 @@ class Analyzer:
|
|
|
4186
4332
|
selected_times: Optional list containing a subset of dates to include. If
|
|
4187
4333
|
`new_data` is provided with modified time periods, then `selected_times`
|
|
4188
4334
|
must be a subset of `new_data.times`. Otherwise, `selected_times` must
|
|
4189
|
-
be a subset of `self.
|
|
4190
|
-
periods are included.
|
|
4335
|
+
be a subset of `self._model_context.input_data.time`. By default, all
|
|
4336
|
+
time periods are included.
|
|
4191
4337
|
by_reach: Boolean. For channels with reach and frequency. If `True`, plots
|
|
4192
4338
|
the response curve by reach. If `False`, plots the response curve by
|
|
4193
4339
|
frequency.
|
|
@@ -4206,7 +4352,7 @@ class Analyzer:
|
|
|
4206
4352
|
An `xarray.Dataset` containing the data needed to visualize response
|
|
4207
4353
|
curves.
|
|
4208
4354
|
"""
|
|
4209
|
-
if self.
|
|
4355
|
+
if self.model_context.is_national:
|
|
4210
4356
|
_warn_if_geo_arg_in_kwargs(
|
|
4211
4357
|
selected_geos=selected_geos,
|
|
4212
4358
|
)
|
|
@@ -4218,20 +4364,22 @@ class Analyzer:
|
|
|
4218
4364
|
}
|
|
4219
4365
|
if new_data is None:
|
|
4220
4366
|
new_data = DataTensors()
|
|
4221
|
-
# TODO:
|
|
4367
|
+
# TODO: Support flexible time without providing exact dates.
|
|
4222
4368
|
required_tensors_names = constants.PERFORMANCE_DATA + (constants.TIME,)
|
|
4223
4369
|
filled_data = new_data.validate_and_fill_missing_data(
|
|
4224
4370
|
required_tensors_names=required_tensors_names,
|
|
4225
|
-
|
|
4371
|
+
model_context=self.model_context,
|
|
4226
4372
|
allow_modified_times=True,
|
|
4227
4373
|
)
|
|
4228
|
-
new_n_media_times = filled_data.get_modified_times(
|
|
4374
|
+
new_n_media_times = filled_data.get_modified_times(
|
|
4375
|
+
model_context=self.model_context
|
|
4376
|
+
)
|
|
4229
4377
|
|
|
4230
4378
|
if new_n_media_times is None:
|
|
4231
4379
|
_validate_selected_times(
|
|
4232
4380
|
selected_times=selected_times,
|
|
4233
|
-
input_times=self.
|
|
4234
|
-
n_times=self.
|
|
4381
|
+
input_times=self.model_context.input_data.time,
|
|
4382
|
+
n_times=self.model_context.n_times,
|
|
4235
4383
|
arg_name="selected_times",
|
|
4236
4384
|
comparison_arg_name="the input data",
|
|
4237
4385
|
)
|
|
@@ -4243,12 +4391,12 @@ class Analyzer:
|
|
|
4243
4391
|
new_n_media_times=new_n_media_times,
|
|
4244
4392
|
new_time=new_time,
|
|
4245
4393
|
)
|
|
4246
|
-
# TODO:
|
|
4394
|
+
# TODO: Switch to Sequence[str] once it is supported.
|
|
4247
4395
|
if selected_times is not None:
|
|
4248
4396
|
selected_times = [x in selected_times for x in new_time]
|
|
4249
4397
|
dim_kwargs["selected_times"] = selected_times
|
|
4250
4398
|
|
|
4251
|
-
if self.
|
|
4399
|
+
if self.model_context.n_rf_channels > 0 and use_optimal_frequency:
|
|
4252
4400
|
opt_freq_data = DataTensors(
|
|
4253
4401
|
media=filled_data.media,
|
|
4254
4402
|
rf_impressions=filled_data.reach * filled_data.frequency,
|
|
@@ -4265,7 +4413,7 @@ class Analyzer:
|
|
|
4265
4413
|
).optimal_frequency,
|
|
4266
4414
|
dtype=backend.float32,
|
|
4267
4415
|
)
|
|
4268
|
-
reach = backend.
|
|
4416
|
+
reach = backend.divide(
|
|
4269
4417
|
filled_data.reach * filled_data.frequency,
|
|
4270
4418
|
frequency,
|
|
4271
4419
|
)
|
|
@@ -4276,13 +4424,13 @@ class Analyzer:
|
|
|
4276
4424
|
spend_multipliers = list(np.arange(0, 2.2, 0.2))
|
|
4277
4425
|
incremental_outcome = np.zeros((
|
|
4278
4426
|
len(spend_multipliers),
|
|
4279
|
-
len(self.
|
|
4427
|
+
len(self.model_context.input_data.get_all_paid_channels()),
|
|
4280
4428
|
3,
|
|
4281
4429
|
))
|
|
4282
4430
|
for i, multiplier in enumerate(spend_multipliers):
|
|
4283
4431
|
if multiplier == 0:
|
|
4284
4432
|
incremental_outcome[i, :, :] = backend.zeros(
|
|
4285
|
-
(len(self.
|
|
4433
|
+
(len(self.model_context.input_data.get_all_paid_channels()), 3)
|
|
4286
4434
|
) # Last dimension = 3 for the mean, ci_lo and ci_hi.
|
|
4287
4435
|
continue
|
|
4288
4436
|
scaled_data = _scale_tensors_by_multiplier(
|
|
@@ -4317,7 +4465,9 @@ class Analyzer:
|
|
|
4317
4465
|
)
|
|
4318
4466
|
spend_einsum = backend.einsum("k,m->km", np.array(spend_multipliers), spend)
|
|
4319
4467
|
xr_coords = {
|
|
4320
|
-
constants.CHANNEL:
|
|
4468
|
+
constants.CHANNEL: (
|
|
4469
|
+
self.model_context.input_data.get_all_paid_channels()
|
|
4470
|
+
),
|
|
4321
4471
|
constants.METRIC: [
|
|
4322
4472
|
constants.MEAN,
|
|
4323
4473
|
constants.CI_LO,
|
|
@@ -4352,8 +4502,8 @@ class Analyzer:
|
|
|
4352
4502
|
`ci_hi`, `ci_lo`, and `mean` for the Adstock function.
|
|
4353
4503
|
"""
|
|
4354
4504
|
if (
|
|
4355
|
-
constants.PRIOR not in self.
|
|
4356
|
-
or constants.POSTERIOR not in self.
|
|
4505
|
+
constants.PRIOR not in self._inference_data.groups()
|
|
4506
|
+
or constants.POSTERIOR not in self._inference_data.groups()
|
|
4357
4507
|
):
|
|
4358
4508
|
raise model.NotFittedModelError(
|
|
4359
4509
|
"sample_prior() and sample_posterior() must be called prior to"
|
|
@@ -4362,7 +4512,7 @@ class Analyzer:
|
|
|
4362
4512
|
|
|
4363
4513
|
# Choose a step_size such that time_unit has consecutive integers defined
|
|
4364
4514
|
# throughout.
|
|
4365
|
-
max_lag = max(self.
|
|
4515
|
+
max_lag = max(self.model_context.model_spec.max_lag, 1)
|
|
4366
4516
|
steps_per_time_period_max_lag = (
|
|
4367
4517
|
constants.ADSTOCK_DECAY_MAX_TOTAL_STEPS // max_lag
|
|
4368
4518
|
)
|
|
@@ -4406,23 +4556,23 @@ class Analyzer:
|
|
|
4406
4556
|
final_df_list.append(adstock_df)
|
|
4407
4557
|
|
|
4408
4558
|
_add_adstock_decay_for_channel(
|
|
4409
|
-
self.
|
|
4410
|
-
self.
|
|
4559
|
+
self.model_context.n_media_channels,
|
|
4560
|
+
self.model_context.input_data.media_channel,
|
|
4411
4561
|
constants.MEDIA,
|
|
4412
4562
|
)
|
|
4413
4563
|
_add_adstock_decay_for_channel(
|
|
4414
|
-
self.
|
|
4415
|
-
self.
|
|
4564
|
+
self.model_context.n_rf_channels,
|
|
4565
|
+
self.model_context.input_data.rf_channel,
|
|
4416
4566
|
constants.RF,
|
|
4417
4567
|
)
|
|
4418
4568
|
_add_adstock_decay_for_channel(
|
|
4419
|
-
self.
|
|
4420
|
-
self.
|
|
4569
|
+
self.model_context.n_organic_media_channels,
|
|
4570
|
+
self.model_context.input_data.organic_media_channel,
|
|
4421
4571
|
constants.ORGANIC_MEDIA,
|
|
4422
4572
|
)
|
|
4423
4573
|
_add_adstock_decay_for_channel(
|
|
4424
|
-
self.
|
|
4425
|
-
self.
|
|
4574
|
+
self.model_context.n_organic_rf_channels,
|
|
4575
|
+
self.model_context.input_data.organic_rf_channel,
|
|
4426
4576
|
constants.ORGANIC_RF,
|
|
4427
4577
|
)
|
|
4428
4578
|
|
|
@@ -4468,50 +4618,52 @@ class Analyzer:
|
|
|
4468
4618
|
"""
|
|
4469
4619
|
if (
|
|
4470
4620
|
channel_type == constants.MEDIA
|
|
4471
|
-
and self.
|
|
4621
|
+
and self.model_context.input_data.media_channel is not None
|
|
4472
4622
|
):
|
|
4473
4623
|
ec = constants.EC_M
|
|
4474
4624
|
slope = constants.SLOPE_M
|
|
4475
|
-
channels = self.
|
|
4476
|
-
transformer = self.
|
|
4625
|
+
channels = self.model_context.input_data.media_channel.values
|
|
4626
|
+
transformer = self.model_context.media_tensors.media_transformer
|
|
4477
4627
|
linspace_max_values = np.max(
|
|
4478
|
-
np.array(self.
|
|
4628
|
+
np.array(self.model_context.media_tensors.media_scaled), axis=(0, 1)
|
|
4479
4629
|
)
|
|
4480
4630
|
elif (
|
|
4481
4631
|
channel_type == constants.RF
|
|
4482
|
-
and self.
|
|
4632
|
+
and self.model_context.input_data.rf_channel is not None
|
|
4483
4633
|
):
|
|
4484
4634
|
ec = constants.EC_RF
|
|
4485
4635
|
slope = constants.SLOPE_RF
|
|
4486
|
-
channels = self.
|
|
4636
|
+
channels = self.model_context.input_data.rf_channel.values
|
|
4487
4637
|
transformer = None
|
|
4488
4638
|
linspace_max_values = np.max(
|
|
4489
|
-
np.array(self.
|
|
4639
|
+
np.array(self.model_context.rf_tensors.frequency), axis=(0, 1)
|
|
4490
4640
|
)
|
|
4491
4641
|
elif (
|
|
4492
4642
|
channel_type == constants.ORGANIC_MEDIA
|
|
4493
|
-
and self.
|
|
4643
|
+
and self.model_context.input_data.organic_media_channel is not None
|
|
4494
4644
|
):
|
|
4495
4645
|
ec = constants.EC_OM
|
|
4496
4646
|
slope = constants.SLOPE_OM
|
|
4497
|
-
channels = self.
|
|
4647
|
+
channels = self.model_context.input_data.organic_media_channel.values
|
|
4498
4648
|
transformer = (
|
|
4499
|
-
self.
|
|
4649
|
+
self.model_context.organic_media_tensors.organic_media_transformer
|
|
4500
4650
|
)
|
|
4501
4651
|
linspace_max_values = np.max(
|
|
4502
|
-
np.array(
|
|
4652
|
+
np.array(
|
|
4653
|
+
self.model_context.organic_media_tensors.organic_media_scaled
|
|
4654
|
+
),
|
|
4503
4655
|
axis=(0, 1),
|
|
4504
4656
|
)
|
|
4505
4657
|
elif (
|
|
4506
4658
|
channel_type == constants.ORGANIC_RF
|
|
4507
|
-
and self.
|
|
4659
|
+
and self.model_context.input_data.organic_rf_channel is not None
|
|
4508
4660
|
):
|
|
4509
4661
|
ec = constants.EC_ORF
|
|
4510
4662
|
slope = constants.SLOPE_ORF
|
|
4511
|
-
channels = self.
|
|
4663
|
+
channels = self.model_context.input_data.organic_rf_channel.values
|
|
4512
4664
|
transformer = None
|
|
4513
4665
|
linspace_max_values = np.max(
|
|
4514
|
-
np.array(self.
|
|
4666
|
+
np.array(self.model_context.organic_rf_tensors.organic_frequency),
|
|
4515
4667
|
axis=(0, 1),
|
|
4516
4668
|
)
|
|
4517
4669
|
else:
|
|
@@ -4546,12 +4698,12 @@ class Analyzer:
|
|
|
4546
4698
|
# n_draws, n_geos, n_times, n_channels), and we want to plot the
|
|
4547
4699
|
# dependency on time only.
|
|
4548
4700
|
hill_vals_prior = adstock_hill.HillTransformer(
|
|
4549
|
-
self.
|
|
4550
|
-
self.
|
|
4701
|
+
self._inference_data.prior[ec].values,
|
|
4702
|
+
self._inference_data.prior[slope].values,
|
|
4551
4703
|
).forward(expanded_linspace)[:, :, 0, :, :]
|
|
4552
4704
|
hill_vals_posterior = adstock_hill.HillTransformer(
|
|
4553
|
-
self.
|
|
4554
|
-
self.
|
|
4705
|
+
self._inference_data.posterior[ec].values,
|
|
4706
|
+
self._inference_data.posterior[slope].values,
|
|
4555
4707
|
).forward(expanded_linspace)[:, :, 0, :, :]
|
|
4556
4708
|
|
|
4557
4709
|
hill_dataset = _central_tendency_and_ci_by_prior_and_posterior(
|
|
@@ -4698,77 +4850,81 @@ class Analyzer:
|
|
|
4698
4850
|
distribution of media units per capita for media channels or average
|
|
4699
4851
|
frequency for RF channels over weeks and geos for the Hill plots.
|
|
4700
4852
|
"""
|
|
4701
|
-
n_geos = self.
|
|
4702
|
-
n_media_times = self.
|
|
4853
|
+
n_geos = self.model_context.n_geos
|
|
4854
|
+
n_media_times = self.model_context.n_media_times
|
|
4703
4855
|
|
|
4704
4856
|
df_list = []
|
|
4705
4857
|
|
|
4706
4858
|
# RF.
|
|
4707
|
-
if self.
|
|
4708
|
-
frequency = self.
|
|
4859
|
+
if self.model_context.input_data.rf_channel is not None:
|
|
4860
|
+
frequency = self.model_context.rf_tensors.frequency
|
|
4709
4861
|
if frequency is not None:
|
|
4710
4862
|
reshaped_frequency = backend.reshape(
|
|
4711
|
-
frequency,
|
|
4863
|
+
frequency,
|
|
4864
|
+
(n_geos * n_media_times, self.model_context.n_rf_channels),
|
|
4712
4865
|
)
|
|
4713
4866
|
rf_hist_data = self._get_channel_hill_histogram_dataframe(
|
|
4714
4867
|
channel_type=constants.RF,
|
|
4715
4868
|
data_to_histogram=reshaped_frequency,
|
|
4716
|
-
channel_names=self.
|
|
4869
|
+
channel_names=self.model_context.input_data.rf_channel.values,
|
|
4717
4870
|
n_bins=n_bins,
|
|
4718
4871
|
)
|
|
4719
4872
|
df_list.append(pd.DataFrame(rf_hist_data))
|
|
4720
4873
|
|
|
4721
4874
|
# Media.
|
|
4722
|
-
if self.
|
|
4723
|
-
transformer = self.
|
|
4724
|
-
scaled = self.
|
|
4875
|
+
if self.model_context.input_data.media_channel is not None:
|
|
4876
|
+
transformer = self.model_context.media_tensors.media_transformer
|
|
4877
|
+
scaled = self.model_context.media_tensors.media_scaled
|
|
4725
4878
|
if transformer is not None and scaled is not None:
|
|
4726
4879
|
population_scaled_median = transformer.population_scaled_median_m
|
|
4727
4880
|
scaled_media_units = scaled * population_scaled_median
|
|
4728
4881
|
reshaped_scaled_media_units = backend.reshape(
|
|
4729
4882
|
scaled_media_units,
|
|
4730
|
-
(n_geos * n_media_times, self.
|
|
4883
|
+
(n_geos * n_media_times, self.model_context.n_media_channels),
|
|
4731
4884
|
)
|
|
4732
4885
|
media_hist_data = self._get_channel_hill_histogram_dataframe(
|
|
4733
4886
|
channel_type=constants.MEDIA,
|
|
4734
4887
|
data_to_histogram=reshaped_scaled_media_units,
|
|
4735
|
-
channel_names=self.
|
|
4888
|
+
channel_names=self.model_context.input_data.media_channel.values,
|
|
4736
4889
|
n_bins=n_bins,
|
|
4737
4890
|
)
|
|
4738
4891
|
df_list.append(pd.DataFrame(media_hist_data))
|
|
4739
4892
|
# Organic media.
|
|
4740
|
-
if self.
|
|
4893
|
+
if self.model_context.input_data.organic_media_channel is not None:
|
|
4741
4894
|
transformer_om = (
|
|
4742
|
-
self.
|
|
4895
|
+
self.model_context.organic_media_tensors.organic_media_transformer
|
|
4743
4896
|
)
|
|
4744
|
-
scaled_om = self.
|
|
4897
|
+
scaled_om = self.model_context.organic_media_tensors.organic_media_scaled
|
|
4745
4898
|
if transformer_om is not None and scaled_om is not None:
|
|
4746
4899
|
population_scaled_median_om = transformer_om.population_scaled_median_m
|
|
4747
4900
|
scaled_organic_media_units = scaled_om * population_scaled_median_om
|
|
4748
4901
|
reshaped_scaled_organic_media_units = backend.reshape(
|
|
4749
4902
|
scaled_organic_media_units,
|
|
4750
|
-
(
|
|
4903
|
+
(
|
|
4904
|
+
n_geos * n_media_times,
|
|
4905
|
+
self.model_context.n_organic_media_channels,
|
|
4906
|
+
),
|
|
4751
4907
|
)
|
|
4752
4908
|
organic_media_hist_data = self._get_channel_hill_histogram_dataframe(
|
|
4753
4909
|
channel_type=constants.ORGANIC_MEDIA,
|
|
4754
4910
|
data_to_histogram=reshaped_scaled_organic_media_units,
|
|
4755
|
-
channel_names=self.
|
|
4911
|
+
channel_names=self.model_context.input_data.organic_media_channel.values,
|
|
4756
4912
|
n_bins=n_bins,
|
|
4757
4913
|
)
|
|
4758
4914
|
df_list.append(pd.DataFrame(organic_media_hist_data))
|
|
4759
4915
|
|
|
4760
4916
|
# Organic RF.
|
|
4761
|
-
if self.
|
|
4762
|
-
frequency = self.
|
|
4917
|
+
if self.model_context.input_data.organic_rf_channel is not None:
|
|
4918
|
+
frequency = self.model_context.organic_rf_tensors.organic_frequency
|
|
4763
4919
|
if frequency is not None:
|
|
4764
4920
|
reshaped_frequency = backend.reshape(
|
|
4765
4921
|
frequency,
|
|
4766
|
-
(n_geos * n_media_times, self.
|
|
4922
|
+
(n_geos * n_media_times, self.model_context.n_organic_rf_channels),
|
|
4767
4923
|
)
|
|
4768
4924
|
organic_rf_hist_data = self._get_channel_hill_histogram_dataframe(
|
|
4769
4925
|
channel_type=constants.ORGANIC_RF,
|
|
4770
4926
|
data_to_histogram=reshaped_frequency,
|
|
4771
|
-
channel_names=self.
|
|
4927
|
+
channel_names=self.model_context.input_data.organic_rf_channel.values,
|
|
4772
4928
|
n_bins=n_bins,
|
|
4773
4929
|
)
|
|
4774
4930
|
df_list.append(pd.DataFrame(organic_rf_hist_data))
|
|
@@ -4811,8 +4967,8 @@ class Analyzer:
|
|
|
4811
4967
|
for a histogram bin.
|
|
4812
4968
|
"""
|
|
4813
4969
|
if (
|
|
4814
|
-
constants.PRIOR not in self.
|
|
4815
|
-
or constants.POSTERIOR not in self.
|
|
4970
|
+
constants.PRIOR not in self._inference_data.groups()
|
|
4971
|
+
or constants.POSTERIOR not in self._inference_data.groups()
|
|
4816
4972
|
):
|
|
4817
4973
|
raise model.NotFittedModelError(
|
|
4818
4974
|
"sample_prior() and sample_posterior() must be called prior to"
|
|
@@ -4821,10 +4977,10 @@ class Analyzer:
|
|
|
4821
4977
|
|
|
4822
4978
|
final_dfs = [pd.DataFrame()]
|
|
4823
4979
|
for n_channels, channel_type in [
|
|
4824
|
-
(self.
|
|
4825
|
-
(self.
|
|
4826
|
-
(self.
|
|
4827
|
-
(self.
|
|
4980
|
+
(self.model_context.n_media_channels, constants.MEDIA),
|
|
4981
|
+
(self.model_context.n_rf_channels, constants.RF),
|
|
4982
|
+
(self.model_context.n_organic_media_channels, constants.ORGANIC_MEDIA),
|
|
4983
|
+
(self.model_context.n_organic_rf_channels, constants.ORGANIC_RF),
|
|
4828
4984
|
]:
|
|
4829
4985
|
if n_channels > 0:
|
|
4830
4986
|
hill_df = self._get_hill_curves_dataframe(
|
|
@@ -4967,6 +5123,7 @@ class Analyzer:
|
|
|
4967
5123
|
include_median=True,
|
|
4968
5124
|
)
|
|
4969
5125
|
|
|
5126
|
+
# TODO: Remove this method.
|
|
4970
5127
|
def get_historical_spend(
|
|
4971
5128
|
self,
|
|
4972
5129
|
selected_times: Sequence[str] | None = None,
|
|
@@ -5048,7 +5205,8 @@ class Analyzer:
|
|
|
5048
5205
|
new_data = new_data or DataTensors()
|
|
5049
5206
|
required_tensors_names = constants.PAID_CHANNELS + constants.SPEND_DATA
|
|
5050
5207
|
filled_data = new_data.validate_and_fill_missing_data(
|
|
5051
|
-
required_tensors_names,
|
|
5208
|
+
required_tensors_names=required_tensors_names,
|
|
5209
|
+
model_context=self.model_context,
|
|
5052
5210
|
)
|
|
5053
5211
|
|
|
5054
5212
|
empty_da = xr.DataArray(
|
|
@@ -5058,9 +5216,9 @@ class Analyzer:
|
|
|
5058
5216
|
if not include_media:
|
|
5059
5217
|
aggregated_media_spend = empty_da
|
|
5060
5218
|
elif (
|
|
5061
|
-
self.
|
|
5062
|
-
or self.
|
|
5063
|
-
or self.
|
|
5219
|
+
self.model_context.media_tensors.media is None
|
|
5220
|
+
or self.model_context.media_tensors.media_spend is None
|
|
5221
|
+
or self.model_context.input_data.media_channel is None
|
|
5064
5222
|
):
|
|
5065
5223
|
warnings.warn(
|
|
5066
5224
|
"Requested spends for paid media channels that do not have R&F"
|
|
@@ -5073,16 +5231,18 @@ class Analyzer:
|
|
|
5073
5231
|
selected_times=selected_times,
|
|
5074
5232
|
media_execution_values=filled_data.media,
|
|
5075
5233
|
channel_spend=filled_data.media_spend,
|
|
5076
|
-
channel_names=list(
|
|
5234
|
+
channel_names=list(
|
|
5235
|
+
self.model_context.input_data.media_channel.values
|
|
5236
|
+
),
|
|
5077
5237
|
)
|
|
5078
5238
|
|
|
5079
5239
|
if not include_rf:
|
|
5080
5240
|
aggregated_rf_spend = empty_da
|
|
5081
5241
|
elif (
|
|
5082
|
-
self.
|
|
5083
|
-
or self.
|
|
5084
|
-
or self.
|
|
5085
|
-
or self.
|
|
5242
|
+
self.model_context.input_data.rf_channel is None
|
|
5243
|
+
or self.model_context.rf_tensors.reach is None
|
|
5244
|
+
or self.model_context.rf_tensors.frequency is None
|
|
5245
|
+
or self.model_context.rf_tensors.rf_spend is None
|
|
5086
5246
|
):
|
|
5087
5247
|
warnings.warn(
|
|
5088
5248
|
"Requested spends for paid media channels with R&F data, but the"
|
|
@@ -5096,7 +5256,7 @@ class Analyzer:
|
|
|
5096
5256
|
selected_times=selected_times,
|
|
5097
5257
|
media_execution_values=rf_execution_values,
|
|
5098
5258
|
channel_spend=filled_data.rf_spend,
|
|
5099
|
-
channel_names=list(self.
|
|
5259
|
+
channel_names=list(self.model_context.input_data.rf_channel.values),
|
|
5100
5260
|
)
|
|
5101
5261
|
|
|
5102
5262
|
return xr.concat(
|
|
@@ -5157,9 +5317,9 @@ class Analyzer:
|
|
|
5157
5317
|
# channel_spend.ndim can only be 3 or 1.
|
|
5158
5318
|
else:
|
|
5159
5319
|
# media spend can have more time points than the model time points
|
|
5160
|
-
if media_execution_values.shape[1] == self.
|
|
5320
|
+
if media_execution_values.shape[1] == self.model_context.n_media_times:
|
|
5161
5321
|
media_exe_values = media_execution_values[
|
|
5162
|
-
:, -self.
|
|
5322
|
+
:, -self.model_context.n_times :, :
|
|
5163
5323
|
]
|
|
5164
5324
|
else:
|
|
5165
5325
|
media_exe_values = media_execution_values
|
|
@@ -5169,7 +5329,7 @@ class Analyzer:
|
|
|
5169
5329
|
media_exe_values,
|
|
5170
5330
|
**dim_kwargs,
|
|
5171
5331
|
)
|
|
5172
|
-
imputed_cpmu = backend.
|
|
5332
|
+
imputed_cpmu = backend.divide(
|
|
5173
5333
|
channel_spend,
|
|
5174
5334
|
np.sum(media_exe_values, (0, 1)),
|
|
5175
5335
|
)
|