google-meridian 1.3.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/METADATA +18 -11
- google_meridian-1.5.0.dist-info/RECORD +112 -0
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/WHEEL +1 -1
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/top_level.txt +1 -0
- meridian/analysis/analyzer.py +558 -398
- meridian/analysis/optimizer.py +90 -68
- meridian/analysis/review/reviewer.py +4 -1
- meridian/analysis/summarizer.py +13 -3
- meridian/analysis/test_utils.py +2911 -2102
- meridian/analysis/visualizer.py +37 -14
- meridian/backend/__init__.py +106 -0
- meridian/constants.py +2 -0
- meridian/data/input_data.py +30 -52
- meridian/data/input_data_builder.py +2 -9
- meridian/data/test_utils.py +107 -51
- meridian/data/validator.py +48 -0
- meridian/mlflow/autolog.py +19 -9
- meridian/model/__init__.py +2 -0
- meridian/model/adstock_hill.py +3 -5
- meridian/model/context.py +1059 -0
- meridian/model/eda/constants.py +335 -4
- meridian/model/eda/eda_engine.py +723 -312
- meridian/model/eda/eda_outcome.py +177 -33
- meridian/model/equations.py +418 -0
- meridian/model/knots.py +58 -47
- meridian/model/model.py +228 -878
- meridian/model/model_test_data.py +38 -0
- meridian/model/posterior_sampler.py +103 -62
- meridian/model/prior_sampler.py +114 -94
- meridian/model/spec.py +23 -14
- meridian/templates/card.html.jinja +9 -7
- meridian/templates/chart.html.jinja +1 -6
- meridian/templates/finding.html.jinja +19 -0
- meridian/templates/findings.html.jinja +33 -0
- meridian/templates/formatter.py +41 -5
- meridian/templates/formatter_test.py +127 -0
- meridian/templates/style.css +66 -9
- meridian/templates/style.scss +85 -4
- meridian/templates/table.html.jinja +1 -0
- meridian/version.py +1 -1
- scenarioplanner/__init__.py +42 -0
- scenarioplanner/converters/__init__.py +25 -0
- scenarioplanner/converters/dataframe/__init__.py +28 -0
- scenarioplanner/converters/dataframe/budget_opt_converters.py +383 -0
- scenarioplanner/converters/dataframe/common.py +71 -0
- scenarioplanner/converters/dataframe/constants.py +137 -0
- scenarioplanner/converters/dataframe/converter.py +42 -0
- scenarioplanner/converters/dataframe/dataframe_model_converter.py +70 -0
- scenarioplanner/converters/dataframe/marketing_analyses_converters.py +543 -0
- scenarioplanner/converters/dataframe/rf_opt_converters.py +314 -0
- scenarioplanner/converters/mmm.py +743 -0
- scenarioplanner/converters/mmm_converter.py +58 -0
- scenarioplanner/converters/sheets.py +156 -0
- scenarioplanner/converters/test_data.py +714 -0
- scenarioplanner/linkingapi/__init__.py +47 -0
- scenarioplanner/linkingapi/constants.py +27 -0
- scenarioplanner/linkingapi/url_generator.py +131 -0
- scenarioplanner/mmm_ui_proto_generator.py +355 -0
- schema/__init__.py +5 -2
- schema/mmm_proto_generator.py +71 -0
- schema/model_consumer.py +133 -0
- schema/processors/__init__.py +77 -0
- schema/processors/budget_optimization_processor.py +832 -0
- schema/processors/common.py +64 -0
- schema/processors/marketing_processor.py +1137 -0
- schema/processors/model_fit_processor.py +367 -0
- schema/processors/model_kernel_processor.py +117 -0
- schema/processors/model_processor.py +415 -0
- schema/processors/reach_frequency_optimization_processor.py +584 -0
- schema/serde/distribution.py +12 -7
- schema/serde/hyperparameters.py +54 -107
- schema/serde/meridian_serde.py +6 -1
- schema/test_data.py +380 -0
- schema/utils/__init__.py +2 -0
- schema/utils/date_range_bucketing.py +117 -0
- schema/utils/proto_enum_converter.py +127 -0
- google_meridian-1.3.2.dist-info/RECORD +0 -76
- {google_meridian-1.3.2.dist-info → google_meridian-1.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1059 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Defines ModelContext class for Meridian."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import Mapping, Sequence
|
|
18
|
+
import functools
|
|
19
|
+
import warnings
|
|
20
|
+
|
|
21
|
+
from meridian import backend
|
|
22
|
+
from meridian import constants
|
|
23
|
+
from meridian.data import input_data as data
|
|
24
|
+
from meridian.data import time_coordinates as tc
|
|
25
|
+
from meridian.model import adstock_hill
|
|
26
|
+
from meridian.model import knots
|
|
27
|
+
from meridian.model import media
|
|
28
|
+
from meridian.model import prior_distribution
|
|
29
|
+
from meridian.model import spec
|
|
30
|
+
from meridian.model import transformers
|
|
31
|
+
import numpy as np
|
|
32
|
+
|
|
33
|
+
# Names exported by `from meridian.model.context import *`.
__all__ = ["ModelContext"]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ModelContext:
|
|
39
|
+
"""Model context for Meridian.
|
|
40
|
+
|
|
41
|
+
This class contains all model parameters that do not change between the runs
|
|
42
|
+
of Meridian.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def __init__(
    self,
    input_data: data.InputData,
    model_spec: spec.ModelSpec,
):
  """Stores the inputs and runs the construction-time validations.

  Args:
    input_data: The model's input data.
    model_spec: The model specification.

  Raises:
    ValueError: If one of the validation helpers rejects the combination of
      `input_data` and `model_spec`.
  """
  self._input_data = input_data
  self._model_spec = model_spec

  # Shape checks on the data-dependent arrays carried by the model spec.
  self._validate_data_dependent_model_spec()
  self._validate_model_spec_shapes()

  self._set_total_media_contribution_prior = False

  # The remaining checks run in exactly this order; the first one to fail
  # determines which error the caller sees.
  remaining_checks = (
      self._warn_setting_ignored_priors,
      self._validate_mroi_priors_non_revenue,
      self._validate_roi_priors_non_revenue,
      self._check_media_prior_support,
      self._validate_geo_invariants,
      self._validate_time_invariants,
      self._validate_media_spend_for_paid_channels,
      self._validate_rf_spend_for_paid_channels,
  )
  for run_check in remaining_checks:
    run_check()
|
|
65
|
+
|
|
66
|
+
def _validate_data_dependent_model_spec(self):
  """Checks data-dependent `ModelSpec` arrays against the input dimensions.

  Each of `roi_calibration_period`, `rf_roi_calibration_period`,
  `holdout_id`, `control_population_scaling_id` and
  `non_media_population_scaling_id` is checked only when it is set.

  Raises:
    ValueError: If a set attribute's shape disagrees with the dimensions
      derived from the input data.
  """
  ms = self._model_spec

  roi_period = ms.roi_calibration_period
  if roi_period is not None and roi_period.shape != (
      self.n_media_times,
      self.n_media_channels,
  ):
    raise ValueError(
        "The shape of `roi_calibration_period`"
        f" {roi_period.shape} is different from"
        f" `(n_media_times, n_media_channels) = ({self.n_media_times},"
        f" {self.n_media_channels})`."
    )

  rf_roi_period = ms.rf_roi_calibration_period
  if rf_roi_period is not None and rf_roi_period.shape != (
      self.n_media_times,
      self.n_rf_channels,
  ):
    raise ValueError(
        "The shape of `rf_roi_calibration_period`"
        f" {rf_roi_period.shape} is different"
        f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
        f" {self.n_rf_channels})`."
    )

  holdout = ms.holdout_id
  if holdout is not None:
    if self.is_national:
      # National models carry a 1-D (time-only) holdout mask.
      if holdout.shape != (self.n_times,):
        raise ValueError(
            f"The shape of `holdout_id` {holdout.shape} is"
            f" different from `(n_times,) = ({self.n_times},)`."
        )
    elif holdout.shape != (self.n_geos, self.n_times):
      raise ValueError(
          f"The shape of `holdout_id` {holdout.shape} is"
          f" different from `(n_geos, n_times) = ({self.n_geos},"
          f" {self.n_times})`."
      )

  control_ids = ms.control_population_scaling_id
  if control_ids is not None and control_ids.shape != (self.n_controls,):
    raise ValueError(
        "The shape of `control_population_scaling_id`"
        f" {control_ids.shape} is"
        f" different from `(n_controls,) = ({self.n_controls},)`."
    )

  non_media_ids = ms.non_media_population_scaling_id
  if non_media_ids is not None and non_media_ids.shape != (
      self.n_non_media_channels,
  ):
    raise ValueError(
        "The shape of `non_media_population_scaling_id`"
        f" {non_media_ids.shape} is"
        " different from `(n_non_media_channels,) ="
        f" ({self.n_non_media_channels},)`."
    )
|
|
138
|
+
|
|
139
|
+
def _validate_model_spec_shapes(self):
  """Validates the shapes of array-valued `ModelSpec` attributes.

  Covers `roi_calibration_period`, `rf_roi_calibration_period`,
  `holdout_id`, `control_population_scaling_id` and
  `non_media_population_scaling_id`. Attributes that are `None` are skipped.

  Raises:
    ValueError: If a set attribute's shape does not match the dimensions
      derived from the input data.
  """
  if self._model_spec.roi_calibration_period is not None:
    if self._model_spec.roi_calibration_period.shape != (
        self.n_media_times,
        self.n_media_channels,
    ):
      raise ValueError(
          "The shape of `roi_calibration_period`"
          f" {self._model_spec.roi_calibration_period.shape} is different"
          f" from `(n_media_times, n_media_channels) = ({self.n_media_times},"
          f" {self.n_media_channels})`."
      )

  if self._model_spec.rf_roi_calibration_period is not None:
    if self._model_spec.rf_roi_calibration_period.shape != (
        self.n_media_times,
        self.n_rf_channels,
    ):
      raise ValueError(
          "The shape of `rf_roi_calibration_period`"
          f" {self._model_spec.rf_roi_calibration_period.shape} is different"
          f" from `(n_media_times, n_rf_channels) = ({self.n_media_times},"
          f" {self.n_rf_channels})`."
      )

  if self._model_spec.holdout_id is not None:
    # National models use a time-only mask; geo models use (geo, time).
    expected_shape = (
        (self.n_times,) if self.is_national else (self.n_geos, self.n_times)
    )
    if self._model_spec.holdout_id.shape != expected_shape:
      raise ValueError(
          f"The shape of `holdout_id` {self._model_spec.holdout_id.shape} is"
          " different from"
          f" {'`(n_times,)`' if self.is_national else '`(n_geos, n_times)`'} ="
          f" {expected_shape}."
      )

  if self._model_spec.control_population_scaling_id is not None:
    if self._model_spec.control_population_scaling_id.shape != (
        self.n_controls,
    ):
      raise ValueError(
          "The shape of `control_population_scaling_id`"
          f" {self._model_spec.control_population_scaling_id.shape} is"
          f" different from `(n_controls,) = ({self.n_controls},)`."
      )

  # Also checked here so this method covers every shape-constrained spec
  # attribute that `_validate_data_dependent_model_spec` covers.
  if self._model_spec.non_media_population_scaling_id is not None:
    if self._model_spec.non_media_population_scaling_id.shape != (
        self.n_non_media_channels,
    ):
      raise ValueError(
          "The shape of `non_media_population_scaling_id`"
          f" {self._model_spec.non_media_population_scaling_id.shape} is"
          " different from `(n_non_media_channels,) ="
          f" ({self.n_non_media_channels},)`."
      )
|
|
186
|
+
|
|
187
|
+
def _validate_geo_invariants(self):
  """Validates non-national model invariants.

  Skipped for national models. For each data type present in the input
  data, this delegates to `_check_if_no_geo_variation` with the scaled tensor
  and the variable labels from the matching coordinate.

  Raises:
    ValueError: From `_check_if_no_geo_variation`, if some variable does not
      vary across geos while `n_knots == n_times`.
  """
  if self.is_national:
    return

  idata = self._input_data

  if idata.controls is not None:
    self._check_if_no_geo_variation(
        self.controls_scaled,
        constants.CONTROLS,
        idata.controls.coords[constants.CONTROL_VARIABLE].values,
    )
  if idata.non_media_treatments is not None:
    self._check_if_no_geo_variation(
        self.non_media_treatments_normalized,
        constants.NON_MEDIA_TREATMENTS,
        idata.non_media_treatments.coords[constants.NON_MEDIA_CHANNEL].values,
    )
  if idata.media is not None:
    self._check_if_no_geo_variation(
        self.media_tensors.media_scaled,
        constants.MEDIA,
        idata.media.coords[constants.MEDIA_CHANNEL].values,
    )
  if idata.reach is not None:
    self._check_if_no_geo_variation(
        self.rf_tensors.reach_scaled,
        constants.REACH,
        idata.reach.coords[constants.RF_CHANNEL].values,
    )
  # The organic data names use the shared constants, consistent with
  # `_validate_time_invariants`, instead of hard-coded string literals.
  if idata.organic_media is not None:
    self._check_if_no_geo_variation(
        self.organic_media_tensors.organic_media_scaled,
        constants.ORGANIC_MEDIA,
        idata.organic_media.coords[constants.ORGANIC_MEDIA_CHANNEL].values,
    )
  if idata.organic_reach is not None:
    self._check_if_no_geo_variation(
        self.organic_rf_tensors.organic_reach_scaled,
        constants.ORGANIC_REACH,
        idata.organic_reach.coords[constants.ORGANIC_RF_CHANNEL].values,
    )
|
|
234
|
+
|
|
235
|
+
def _check_if_no_geo_variation(
    self,
    scaled_data: backend.Tensor,
    data_name: str,
    data_dims: Sequence[str],
    epsilon=1e-4,
):
  """Raises if some variables lack geo variation and `n_knots == n_time`.

  A variable is flagged when, at every index along axis 1 of `scaled_data`,
  its standard deviation over axis 0 is below `epsilon`. NOTE(review): this
  presumably expects `scaled_data` laid out as (geo, time, variable); confirm.

  Args:
    scaled_data: Scaled data tensor to inspect.
    data_name: Name of the data type, used in the error message.
    data_dims: Variable labels, indexed by the flagged column positions.
    epsilon: Standard-deviation threshold below which a value counts as
      constant across geos.

  Raises:
    ValueError: If any variable is flagged and the knot count equals the
      number of time periods.
  """
  below_threshold = backend.reduce_std(scaled_data, axis=0) < epsilon
  # Result shape: [n, d], where d is the number of axes of condition; column 1
  # holds the variable index of each flagged entry.
  flagged_cols = backend.get_indices_where(below_threshold)[:, 1]
  unique_cols, _, times_flagged = backend.unique_with_counts(flagged_cols)
  # The time length comes from `scaled_data` rather than `n_time` because the
  # data may be padded to account for lagged effects.
  n_data_times = scaled_data.shape[1]
  flat_everywhere = backend.equal(times_flagged, n_data_times)
  bad_cols = backend.boolean_mask(unique_cols, flat_everywhere)
  bad_dims = backend.gather(data_dims, bad_cols)

  # `knot_info` is only consulted when at least one column is flagged.
  if not bad_cols.shape[0] or self.knot_info.n_knots != self.n_times:
    return

  raise ValueError(
      f"The following {data_name} variables do not vary across geos, making"
      f" a model with n_knots=n_time unidentifiable: {bad_dims}. This can"
      " lead to poor model convergence. Since these variables only vary"
      " across time and not across geo, they are collinear with time and"
      " redundant in a model with a parameter for each time period. To"
      " address this, you can either: (1) decrease the number of knots"
      " (n_knots < n_time), or (2) drop the listed variables that do not"
      " vary across geos."
  )
|
|
267
|
+
|
|
268
|
+
def _validate_time_invariants(self):
  """Checks every present data type for variables constant over time.

  The checks run in a fixed order: controls, non-media treatments, media,
  reach, organic media, organic reach. Each one delegates to
  `_check_if_no_time_variation`.

  Raises:
    ValueError: From `_check_if_no_time_variation`.
  """
  idata = self._input_data
  # Each entry is (raw data, lazy accessor for the scaled tensor, data name
  # for messages, coordinate holding the variable labels). The accessors are
  # lambdas so that scaled tensors are built only for data that is present.
  checks = (
      (
          idata.controls,
          lambda: self.controls_scaled,
          constants.CONTROLS,
          constants.CONTROL_VARIABLE,
      ),
      (
          idata.non_media_treatments,
          lambda: self.non_media_treatments_normalized,
          constants.NON_MEDIA_TREATMENTS,
          constants.NON_MEDIA_CHANNEL,
      ),
      (
          idata.media,
          lambda: self.media_tensors.media_scaled,
          constants.MEDIA,
          constants.MEDIA_CHANNEL,
      ),
      (
          idata.reach,
          lambda: self.rf_tensors.reach_scaled,
          constants.REACH,
          constants.RF_CHANNEL,
      ),
      (
          idata.organic_media,
          lambda: self.organic_media_tensors.organic_media_scaled,
          constants.ORGANIC_MEDIA,
          constants.ORGANIC_MEDIA_CHANNEL,
      ),
      (
          idata.organic_reach,
          lambda: self.organic_rf_tensors.organic_reach_scaled,
          constants.ORGANIC_REACH,
          constants.ORGANIC_RF_CHANNEL,
      ),
  )
  for raw, get_scaled, name, label_dim in checks:
    if raw is None:
      continue
    self._check_if_no_time_variation(
        get_scaled(), name, raw.coords[label_dim].values
    )
|
|
312
|
+
|
|
313
|
+
def _validate_media_spend_for_paid_channels(self) -> None:
  """Rejects paid media channels whose aggregated spend is zero.

  Raises:
    ValueError: Via `_validate_spend_for_paid_channels`.
  """
  media_spend = self.input_data.aggregate_media_spend()
  self._validate_spend_for_paid_channels(media_spend, constants.MEDIA_CHANNEL)
|
|
317
|
+
|
|
318
|
+
def _validate_rf_spend_for_paid_channels(self) -> None:
  """Rejects reach-and-frequency channels whose aggregated spend is zero.

  Raises:
    ValueError: Via `_validate_spend_for_paid_channels`.
  """
  rf_spend = self.input_data.aggregate_rf_spend()
  self._validate_spend_for_paid_channels(rf_spend, constants.RF_CHANNEL)
|
|
322
|
+
|
|
323
|
+
def _validate_spend_for_paid_channels(
    self,
    spend: np.ndarray | None,
    dim: str,
) -> None:
  """Validates non-zero total spend for paid channels.

  Args:
    spend: Aggregated spend per channel, or `None` to skip the check. It is
      accessed through `spend.coords[dim]` and `.where(..., drop=True)`, so it
      must be a labeled array. NOTE(review): that API is xarray's; the
      `np.ndarray` annotation looks inaccurate — confirm against
      `InputData.aggregate_media_spend` / `aggregate_rf_spend`.
    dim: Name of the channel coordinate of `spend`.

  Raises:
    ValueError: If any paid channel has zero total spend.
  """
  if spend is None:
    return
  zero_spend_channels = spend.coords[dim].where(spend == 0, drop=True).values

  if zero_spend_channels.size > 0:
    # Labels are stringified so that non-string coordinate labels do not turn
    # this ValueError into a TypeError raised by `str.join`.
    channel_list = ", ".join(str(channel) for channel in zero_spend_channels)
    raise ValueError(
        "Zero total spend detected for paid channels:"
        f" {channel_list}. If data is correct and this is"
        " expected, please consider modeling the data as organic media."
    )
|
|
347
|
+
|
|
348
|
+
def _check_if_no_time_variation(
    self,
    scaled_data: backend.Tensor,
    data_name: str,
    data_dims: Sequence[str],
    epsilon=1e-4,
):
  """Raises if some variables are constant over time in every geo.

  A variable is flagged when its standard deviation along axis 1 of
  `scaled_data` is below `epsilon` in all `n_geos` rows. NOTE(review): this
  presumably expects `scaled_data` laid out as (geo, time, variable); confirm.

  Args:
    scaled_data: Scaled data tensor to inspect.
    data_name: Name of the data type, used in the error message.
    data_dims: Variable labels, indexed by the flagged column positions.
    epsilon: Standard-deviation threshold below which a value counts as
      constant over time.

  Raises:
    ValueError: If any variable is flagged. The message differs between
      national and geo-level models.
  """
  below_threshold = backend.reduce_std(scaled_data, axis=1) < epsilon
  # Result shape: [n, d], where d is the number of axes of condition; column 1
  # holds the variable index of each flagged entry.
  flagged_cols = backend.get_indices_where(below_threshold)[:, 1]
  unique_cols, _, geos_flagged = backend.unique_with_counts(flagged_cols)
  flat_in_every_geo = backend.equal(geos_flagged, self.n_geos)
  bad_cols = backend.boolean_mask(unique_cols, flat_in_every_geo)
  bad_dims = backend.gather(data_dims, bad_cols)
  if not bad_cols.shape[0]:
    return

  if self.is_national:
    message = (
        f"The following {data_name} variables do not vary across time,"
        " which is equivalent to no signal at all in a national model:"
        f" {bad_dims}. This can lead to poor model convergence. To address"
        " this, drop the listed variables that do not vary across time."
    )
  else:
    message = (
        f"The following {data_name} variables do not vary across time,"
        f" making a model with geo main effects unidentifiable: {bad_dims}."
        " This can lead to poor model convergence. Since these variables"
        " only vary across geo and not across time, they are collinear"
        " with geo and redundant in a model with geo main effects. To"
        " address this, drop the listed variables that do not vary across"
        " time."
    )
  raise ValueError(message)
|
|
383
|
+
|
|
384
|
+
@property
def input_data(self) -> data.InputData:
  """The `InputData` this context was constructed with (read-only)."""
  stored = self._input_data
  return stored
|
|
387
|
+
|
|
388
|
+
@property
def model_spec(self) -> spec.ModelSpec:
  """The `ModelSpec` this context was constructed with (read-only)."""
  stored = self._model_spec
  return stored
|
|
391
|
+
|
|
392
|
+
@functools.cached_property
def media_tensors(self) -> media.MediaTensors:
  """Paid media tensors from `media.build_media_tensors`, computed once."""
  idata, ms = self._input_data, self._model_spec
  return media.build_media_tensors(idata, ms)
|
|
395
|
+
|
|
396
|
+
@functools.cached_property
def rf_tensors(self) -> media.RfTensors:
  """Reach-and-frequency tensors from `media.build_rf_tensors`, computed once."""
  idata, ms = self._input_data, self._model_spec
  return media.build_rf_tensors(idata, ms)
|
|
399
|
+
|
|
400
|
+
@functools.cached_property
def organic_media_tensors(self) -> media.OrganicMediaTensors:
  """Organic media tensors from `media.build_organic_media_tensors` (cached)."""
  idata = self._input_data
  return media.build_organic_media_tensors(idata)
|
|
403
|
+
|
|
404
|
+
@functools.cached_property
def organic_rf_tensors(self) -> media.OrganicRfTensors:
  """Organic RF tensors from `media.build_organic_rf_tensors` (cached)."""
  idata = self._input_data
  return media.build_organic_rf_tensors(idata)
|
|
407
|
+
|
|
408
|
+
@functools.cached_property
def kpi(self) -> backend.Tensor:
  """KPI data converted to a float32 backend tensor (cached)."""
  raw_kpi = self._input_data.kpi
  return backend.to_tensor(raw_kpi, dtype=backend.float32)
|
|
411
|
+
|
|
412
|
+
@functools.cached_property
def revenue_per_kpi(self) -> backend.Tensor | None:
  """Revenue per KPI as a float32 tensor, or `None` when it is not provided."""
  raw = self._input_data.revenue_per_kpi
  if raw is not None:
    return backend.to_tensor(raw, dtype=backend.float32)
  return None
|
|
419
|
+
|
|
420
|
+
@functools.cached_property
def controls(self) -> backend.Tensor | None:
  """Control variables as a float32 tensor, or `None` when absent."""
  raw = self._input_data.controls
  if raw is not None:
    return backend.to_tensor(raw, dtype=backend.float32)
  return None
|
|
425
|
+
|
|
426
|
+
@functools.cached_property
def non_media_treatments(self) -> backend.Tensor | None:
  """Non-media treatments as a float32 tensor, or `None` when absent."""
  raw = self._input_data.non_media_treatments
  if raw is not None:
    return backend.to_tensor(raw, dtype=backend.float32)
  return None
|
|
433
|
+
|
|
434
|
+
@functools.cached_property
def population(self) -> backend.Tensor:
  """Population data converted to a float32 backend tensor (cached)."""
  raw_population = self._input_data.population
  return backend.to_tensor(raw_population, dtype=backend.float32)
|
|
437
|
+
|
|
438
|
+
@functools.cached_property
def total_spend(self) -> backend.Tensor:
  """`InputData.get_total_spend()` as a float32 backend tensor (cached)."""
  spend = self._input_data.get_total_spend()
  return backend.to_tensor(spend, dtype=backend.float32)
|
|
443
|
+
|
|
444
|
+
@functools.cached_property
def total_outcome(self) -> backend.Tensor:
  """`InputData.get_total_outcome()` as a float32 backend tensor (cached)."""
  outcome = self._input_data.get_total_outcome()
  return backend.to_tensor(outcome, dtype=backend.float32)
|
|
449
|
+
|
|
450
|
+
@property
def n_geos(self) -> int:
  """Number of geos in the input data."""
  geos = self._input_data.geo
  return len(geos)
|
|
453
|
+
|
|
454
|
+
@property
def n_media_channels(self) -> int:
  """Number of paid media channels, or 0 when there are none."""
  channels = self._input_data.media_channel
  return 0 if channels is None else len(channels)
|
|
459
|
+
|
|
460
|
+
@property
def n_rf_channels(self) -> int:
  """Number of reach-and-frequency channels, or 0 when there are none."""
  channels = self._input_data.rf_channel
  return 0 if channels is None else len(channels)
|
|
465
|
+
|
|
466
|
+
@property
def n_organic_media_channels(self) -> int:
  """Number of organic media channels, or 0 when there are none."""
  channels = self._input_data.organic_media_channel
  return 0 if channels is None else len(channels)
|
|
471
|
+
|
|
472
|
+
@property
def n_organic_rf_channels(self) -> int:
  """Number of organic reach-and-frequency channels, or 0 when there are none."""
  channels = self._input_data.organic_rf_channel
  return 0 if channels is None else len(channels)
|
|
477
|
+
|
|
478
|
+
@property
def n_controls(self) -> int:
  """Number of control variables, or 0 when there are none."""
  variables = self._input_data.control_variable
  return 0 if variables is None else len(variables)
|
|
483
|
+
|
|
484
|
+
@property
def n_non_media_channels(self) -> int:
  """Number of non-media treatment channels, or 0 when there are none."""
  channels = self._input_data.non_media_channel
  return 0 if channels is None else len(channels)
|
|
489
|
+
|
|
490
|
+
@property
def n_times(self) -> int:
  """Number of entries in the input data's `time` coordinate."""
  times = self._input_data.time
  return len(times)
|
|
493
|
+
|
|
494
|
+
@property
def n_media_times(self) -> int:
  """Number of entries in the input data's `media_time` coordinate."""
  media_times = self._input_data.media_time
  return len(media_times)
|
|
497
|
+
|
|
498
|
+
@property
def is_national(self) -> bool:
  """Whether the input data contains exactly one geo."""
  single_geo = 1 == self.n_geos
  return single_geo
|
|
501
|
+
|
|
502
|
+
@functools.cached_property
def knot_info(self) -> knots.KnotInfo:
  """Knot information from `knots.get_knot_info`, computed once.

  Built from the number of time periods, the spec's `knots` and `enable_aks`
  settings, the input data, and whether the model is national.
  """
  ms = self._model_spec
  return knots.get_knot_info(
      n_times=self.n_times,
      knots=ms.knots,
      enable_aks=ms.enable_aks,
      data=self._input_data,
      is_national=self.is_national,
  )
|
|
511
|
+
|
|
512
|
+
@functools.cached_property
def controls_transformer(
    self,
) -> transformers.CenteringAndScalingTransformer | None:
  """Centering-and-scaling transformer for controls, or `None` without them.

  If the spec sets `control_population_scaling_id`, it is converted to a
  boolean tensor and passed through as `population_scaling_id`.
  """
  if self.controls is None:
    return None

  raw_ids = self._model_spec.control_population_scaling_id
  scaling_ids = (
      None
      if raw_ids is None
      else backend.to_tensor(raw_ids, dtype=backend.bool_)
  )
  return transformers.CenteringAndScalingTransformer(
      tensor=self.controls,
      population=self.population,
      population_scaling_id=scaling_ids,
  )
|
|
532
|
+
|
|
533
|
+
@functools.cached_property
def non_media_transformer(
    self,
) -> transformers.CenteringAndScalingTransformer | None:
  """Centering-and-scaling transformer for non-media treatments, if present.

  If the spec sets `non_media_population_scaling_id`, it is converted to a
  boolean tensor and passed through as `population_scaling_id`.
  """
  if self.non_media_treatments is None:
    return None

  raw_ids = self._model_spec.non_media_population_scaling_id
  scaling_ids = (
      None
      if raw_ids is None
      else backend.to_tensor(raw_ids, dtype=backend.bool_)
  )
  return transformers.CenteringAndScalingTransformer(
      tensor=self.non_media_treatments,
      population=self.population,
      population_scaling_id=scaling_ids,
  )
|
|
552
|
+
|
|
553
|
+
@functools.cached_property
def kpi_transformer(self) -> transformers.KpiTransformer:
  """`KpiTransformer` built from the KPI and population tensors (cached)."""
  kpi_tensor, population_tensor = self.kpi, self.population
  return transformers.KpiTransformer(kpi_tensor, population_tensor)
|
|
556
|
+
|
|
557
|
+
@functools.cached_property
def controls_scaled(self) -> backend.Tensor | None:
  """Controls passed through `controls_transformer.forward`, or `None`."""
  raw_controls = self.controls
  if raw_controls is None:
    return None
  # `controls_transformer` is only `None` when `controls` is `None`.
  transformer = self.controls_transformer
  return transformer.forward(raw_controls)  # pytype: disable=attribute-error
|
|
564
|
+
|
|
565
|
+
@functools.cached_property
def non_media_treatments_normalized(self) -> backend.Tensor | None:
  """Normalized non-media treatments, or `None` when there are none.

  The values are scaled by population (for channels whose
  `non_media_population_scaling_id` is `True`), then centered and scaled
  using means and standard deviations.
  """
  transformer = self.non_media_transformer
  if transformer is None:
    return None
  return transformer.forward(self.non_media_treatments)
|
|
579
|
+
|
|
580
|
+
@functools.cached_property
def kpi_scaled(self) -> backend.Tensor:
  """KPI tensor after `kpi_transformer.forward` (cached)."""
  transformer = self.kpi_transformer
  return transformer.forward(self.kpi)
|
|
583
|
+
|
|
584
|
+
@functools.cached_property
def media_effects_dist(self) -> str:
  """Media effects distribution.

  Geo models use the spec's value. National models always use the value
  stored in `constants.NATIONAL_MODEL_SPEC_ARGS`.
  """
  if not self.is_national:
    return self._model_spec.media_effects_dist
  return constants.NATIONAL_MODEL_SPEC_ARGS[constants.MEDIA_EFFECTS_DIST]
|
|
590
|
+
|
|
591
|
+
@functools.cached_property
def unique_sigma_for_each_geo(self) -> bool:
  """Whether each geo gets its own sigma.

  Geo models use the spec's value. National models use the value stored in
  `constants.NATIONAL_MODEL_SPEC_ARGS`, which should evaluate to False.
  """
  if not self.is_national:
    return self._model_spec.unique_sigma_for_each_geo
  return constants.NATIONAL_MODEL_SPEC_ARGS[
      constants.UNIQUE_SIGMA_FOR_EACH_GEO
  ]
|
|
600
|
+
|
|
601
|
+
@functools.cached_property
def baseline_geo_idx(self) -> int:
  """Returns the index of the baseline geo.

  Resolution depends on `model_spec.baseline_geo`:
    * an integer (Python `int` or NumPy integer) is range-checked against
      `[0, n_geos - 1]` and returned as a Python `int`;
    * a string is looked up among the input data's geo labels;
    * anything else falls back to `backend.argmax(self.population)`.

  Raises:
    ValueError: If an integer index is out of range or a geo name is not
      found.
  """
  baseline_geo = self._model_spec.baseline_geo
  # NumPy integers are accepted as well; otherwise e.g. `np.int64(2)` would
  # be silently routed to the population-based fallback below.
  if isinstance(baseline_geo, (int, np.integer)):
    if baseline_geo < 0 or baseline_geo >= self.n_geos:
      raise ValueError(
          f"Baseline geo index {baseline_geo} out of range"
          f" [0, {self.n_geos - 1}]."
      )
    return int(baseline_geo)
  if isinstance(baseline_geo, str):
    # np.where returns a 1-D tuple, its first element is an array of found
    # elements.
    index = np.where(self._input_data.geo == baseline_geo)[0]
    if index.size == 0:
      raise ValueError(f"Baseline geo '{baseline_geo}' not found.")
    # Geos are unique, so index is a 1-element array. Convert the NumPy
    # scalar to a Python int to match the annotation.
    return int(index[0])
  # NOTE(review): this branch returns a backend tensor rather than a Python
  # int, despite the annotation; left unchanged for backend compatibility.
  return backend.argmax(self.population)
|
|
626
|
+
|
|
627
|
+
@functools.cached_property
def holdout_id(self) -> backend.Tensor | None:
  """Holdout mask as a boolean tensor, or `None` if the spec sets none.

  For national models a leading axis of size 1 is prepended, so the result
  is always geo-by-time shaped.
  """
  raw_holdout = self._model_spec.holdout_id
  if raw_holdout is None:
    return None
  mask = backend.to_tensor(raw_holdout, dtype=backend.bool_)
  if self.is_national:
    return mask[backend.newaxis, ...]
  return mask
|
|
633
|
+
|
|
634
|
+
def _warn_setting_ignored_priors(self):
  """Warns about custom priors that the chosen prior type ignores.

  For the media and RF prior types, this takes the prior names listed under
  the effective prior type in `constants.IGNORED_PRIORS_MEDIA` /
  `constants.IGNORED_PRIORS_RF`. It then emits one warning per group that
  names every such prior differing from the `PriorDistribution()` default.
  """
  defaults = prior_distribution.PriorDistribution()
  groups = (
      (
          constants.IGNORED_PRIORS_MEDIA,
          self._model_spec.effective_media_prior_type,
          "media_prior_type",
      ),
      (
          constants.IGNORED_PRIORS_RF,
          self._model_spec.effective_rf_prior_type,
          "rf_prior_type",
      ),
  )
  for ignored_by_type, prior_type, prior_type_name in groups:
    ignored_custom_priors = [
        prior_name
        for prior_name in ignored_by_type.get(prior_type, [])
        if not prior_distribution.distributions_are_equal(
            getattr(self._model_spec.prior, prior_name),
            getattr(defaults, prior_name),
        )
    ]
    if not ignored_custom_priors:
      continue
    ignored_priors_str = ", ".join(ignored_custom_priors)
    warnings.warn(
        f"Custom prior(s) `{ignored_priors_str}` are ignored when"
        f' `{prior_type_name}` is set to "{prior_type}".'
    )
|
|
663
|
+
|
|
664
|
+
def _validate_mroi_priors_non_revenue(self):
  """Validates mROI priors when the KPI is non-revenue with no revenue per KPI.

  Does nothing unless `kpi_type` is non-revenue and `revenue_per_kpi` is
  missing. In that case, a channel family (media or RF) that is present, uses
  the "mroi" prior type, and still has the default mROI prior is rejected.

  Raises:
    ValueError: If media or RF channels rely on the default mROI prior in the
      non-revenue, no-revenue-per-KPI setting.
  """
  if (
      self._input_data.kpi_type != constants.NON_REVENUE
      or self._input_data.revenue_per_kpi is not None
  ):
    return

  defaults = prior_distribution.PriorDistribution()
  mroi_type = constants.TREATMENT_PRIOR_TYPE_MROI

  media_uses_default_mroi = (
      self.n_media_channels > 0
      and self._model_spec.effective_media_prior_type == mroi_type
      and prior_distribution.distributions_are_equal(
          self._model_spec.prior.mroi_m, defaults.mroi_m
      )
  )
  if media_uses_default_mroi:
    raise ValueError(
        f"Custom priors should be set on `{constants.MROI_M}` when"
        ' `media_prior_type` is "mroi", KPI is non-revenue and revenue per'
        " kpi data is missing."
    )

  rf_uses_default_mroi = (
      self.n_rf_channels > 0
      and self._model_spec.effective_rf_prior_type == mroi_type
      and prior_distribution.distributions_are_equal(
          self._model_spec.prior.mroi_rf, defaults.mroi_rf
      )
  )
  if rf_uses_default_mroi:
    raise ValueError(
        f"Custom priors should be set on `{constants.MROI_RF}` when"
        ' `rf_prior_type` is "mroi", KPI is non-revenue and revenue per kpi'
        " data is missing."
    )
|
|
701
|
+
|
|
702
|
+
def _validate_roi_priors_non_revenue(self):
  """Validates ROI priors when the KPI is non-revenue with no revenue per KPI.

  Does nothing unless `kpi_type` is non-revenue and `revenue_per_kpi` is
  missing. In that case:

  * If every paid channel family that is present uses the "roi" prior type
    with the default ROI prior, the total paid media contribution prior is
    enabled (`self._set_total_media_contribution_prior = True`) and a warning
    suggesting custom ROI priors is issued.
  * Otherwise, if media (or RF) channels are present and still use the
    default ROI prior under the "roi" prior type, a `ValueError` is raised.

  Raises:
    ValueError: If one paid channel family relies on the default ROI prior
      while the total-contribution fallback does not apply.
  """
  if (
      self._input_data.kpi_type == constants.NON_REVENUE
      and self._input_data.revenue_per_kpi is None
  ):
    default_distribution = prior_distribution.PriorDistribution()
    default_roi_m_used = (
        self._model_spec.effective_media_prior_type
        == constants.TREATMENT_PRIOR_TYPE_ROI
        and prior_distribution.distributions_are_equal(
            self._model_spec.prior.roi_m, default_distribution.roi_m
        )
    )
    default_roi_rf_used = (
        self._model_spec.effective_rf_prior_type
        == constants.TREATMENT_PRIOR_TYPE_ROI
        and prior_distribution.distributions_are_equal(
            self._model_spec.prior.roi_rf, default_distribution.roi_rf
        )
    )
    # If ROI priors are used with the default prior distribution for all paid
    # channels (media and RF), then use the "total paid media contribution
    # prior" procedure.
    if (
        (default_roi_m_used and default_roi_rf_used)
        or (self.n_media_channels == 0 and default_roi_rf_used)
        or (self.n_rf_channels == 0 and default_roi_m_used)
    ):
      self._set_total_media_contribution_prior = True
      warnings.warn(
          "Consider setting custom ROI priors, as kpi_type was specified as"
          " `non_revenue` with no `revenue_per_kpi` being set. Otherwise, the"
          " total media contribution prior will be used with"
          f" `p_mean={constants.P_MEAN}` and `p_sd={constants.P_SD}`. Further"
          " documentation available at "
          " https://developers.google.com/meridian/docs/advanced-modeling/unknown-revenue-kpi-custom#set-total-paid-media-contribution-prior",
      )
    elif self.n_media_channels > 0 and default_roi_m_used:
      # The `f` prefix on every piece that contains `{...}` is required;
      # without it the placeholder is emitted literally in the message.
      raise ValueError(
          f"Custom priors should be set on `{constants.ROI_M}` when"
          ' `media_prior_type` is "roi", custom priors are assigned on'
          f' `{constants.ROI_RF}` or `rf_prior_type` is not "roi", KPI is'
          " non-revenue and revenue per kpi data is missing."
      )
    elif self.n_rf_channels > 0 and default_roi_rf_used:
      raise ValueError(
          f"Custom priors should be set on `{constants.ROI_RF}` when"
          ' `rf_prior_type` is "roi", custom priors are assigned on'
          f' `{constants.ROI_M}` or `media_prior_type` is not "roi", KPI is'
          " non-revenue and revenue per kpi data is missing."
      )
|
|
754
|
+
|
|
755
|
+
def _check_media_prior_support(self) -> None:
  """Checks ROI, mROI, and Contribution prior support when random effects are log-normal.

  Priors for ROI, mROI, and Contribution can only have negative support if the
  random effects follow a normal distribution. This check enforces that priors
  have non-negative support when random effects follow a log-normal
  distribution. This check only applies to geo-level models with log-normal
  random effects since national models do not have random effects.

  Paid media priors are checked against `effective_media_prior_type`, and
  RF priors against `effective_rf_prior_type`. These are the same prior-type
  properties the other prior validations in this file use. This ensures that
  the prior actually in effect for each channel family is the one inspected.
  """
  prior = self._model_spec.prior
  if self.n_media_channels > 0:
    media_prior_type = self._model_spec.effective_media_prior_type
    for dist, prior_type in (
        (prior.roi_m, constants.TREATMENT_PRIOR_TYPE_ROI),
        (prior.mroi_m, constants.TREATMENT_PRIOR_TYPE_MROI),
        (prior.contribution_m, constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION),
    ):
      self._check_for_negative_support(
          dist,
          self.media_effects_dist,
          prior_type,
          active_prior_type=media_prior_type,
      )
  if self.n_rf_channels > 0:
    # RF channels have their own prior type; comparing RF priors against the
    # media prior type would inspect the wrong RF prior when the two differ.
    rf_prior_type = self._model_spec.effective_rf_prior_type
    for dist, prior_type in (
        (prior.roi_rf, constants.TREATMENT_PRIOR_TYPE_ROI),
        (prior.mroi_rf, constants.TREATMENT_PRIOR_TYPE_MROI),
        (prior.contribution_rf, constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION),
    ):
      self._check_for_negative_support(
          dist,
          self.media_effects_dist,
          prior_type,
          active_prior_type=rf_prior_type,
      )
  # NOTE(review): organic priors are still compared against
  # `self._model_spec.media_prior_type` (unchanged behavior). If the model
  # spec exposes a dedicated organic prior-type setting, it should probably be
  # passed as `active_prior_type` here — confirm.
  if self.n_organic_media_channels > 0:
    self._check_for_negative_support(
        prior.contribution_om,
        self.media_effects_dist,
        constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
    )
  if self.n_organic_rf_channels > 0:
    self._check_for_negative_support(
        prior.contribution_orf,
        self.media_effects_dist,
        constants.TREATMENT_PRIOR_TYPE_CONTRIBUTION,
    )

def _check_for_negative_support(
    self,
    dist: backend.tfd.Distribution,
    media_effects_dist: str,
    prior_type: str,
    active_prior_type: str | None = None,
) -> None:
  """Checks for negative support in prior distributions.

  When `media_effects_dist` is `MEDIA_EFFECTS_LOG_NORMAL`, prior distributions
  for media effects must be non-negative. This function raises a ValueError if
  any part of the distribution's CDF is greater than 0 at 0, indicating some
  probability mass below zero. Only a prior whose `prior_type` matches the
  active prior type is checked, since priors of other types are not used.

  Args:
    dist: The distribution to check.
    media_effects_dist: The type of media effects distribution.
    prior_type: The prior type that corresponds with current prior under test.
    active_prior_type: The prior type in effect for the channel family that
      `dist` belongs to. Defaults to `self._model_spec.media_prior_type`,
      which preserves the behavior for callers that do not pass it.

  Raises:
    ValueError: If the prior distribution has negative support when
      `media_effects_dist` is `MEDIA_EFFECTS_LOG_NORMAL`.
  """
  if active_prior_type is None:
    active_prior_type = self._model_spec.media_prior_type
  if (
      prior_type == active_prior_type
      and media_effects_dist == constants.MEDIA_EFFECTS_LOG_NORMAL
      and np.any(dist.cdf(0) > 0)
  ):
    raise ValueError(
        "Media priors must have non-negative support when"
        f' `media_effects_dist`="{media_effects_dist}". Found negative prior'
        f" distribution support for {dist.name}."
    )
|
|
842
|
+
|
|
843
|
+
@functools.cached_property
def prior_broadcast(self) -> prior_distribution.PriorDistribution:
  """Returns the model spec's `PriorDistribution` broadcast to model shapes.

  The spend handed to `broadcast` is aggregated by channel. A 1-D
  `get_total_spend()` result is passed through unchanged. A 2-D result is
  summed over axis 0. Anything of higher rank is summed over axes 0 and 1.
  The KPI handed to `broadcast` is the sum of all KPI values.
  """
  spend = self._input_data.get_total_spend()
  spend_rank = len(spend.shape)
  if spend_rank == 1:
    # Already one value per channel.
    spend_by_channel = spend
  else:
    leading_axes = (0,) if spend_rank == 2 else (0, 1)
    spend_by_channel = np.sum(spend, axis=leading_axes)

  prior = self._model_spec.prior
  return prior.broadcast(
      n_geos=self.n_geos,
      n_media_channels=self.n_media_channels,
      n_rf_channels=self.n_rf_channels,
      n_organic_media_channels=self.n_organic_media_channels,
      n_organic_rf_channels=self.n_organic_rf_channels,
      n_controls=self.n_controls,
      n_non_media_channels=self.n_non_media_channels,
      unique_sigma_for_each_geo=self.unique_sigma_for_each_geo,
      n_knots=self.knot_info.n_knots,
      is_national=self.is_national,
      set_total_media_contribution_prior=self._set_total_media_contribution_prior,
      kpi=np.sum(self._input_data.kpi.values),
      total_spend=spend_by_channel,
  )
|
|
871
|
+
|
|
872
|
+
@functools.cached_property
def adstock_decay_spec(self) -> adstock_hill.AdstockDecaySpec:
  """Returns `AdstockDecaySpec` object with correctly mapped channels.

  A string spec is applied uniformly via `from_consistent_type`. Any other
  spec is treated as a channel-to-decay-function mapping.

  Raises:
    ValueError: If the mapping references channels not found in the data.
  """
  spec = self._model_spec.adstock_decay_spec
  if not isinstance(spec, str):
    try:
      return self._create_adstock_decay_functions_from_channel_map(spec)
    except KeyError as e:
      raise ValueError(
          "Unrecognized channel names found in `adstock_decay_spec` keys"
          f" {tuple(spec.keys())}. Keys should"
          " either contain only channel_names"
          f" {tuple(self._input_data.get_all_adstock_hill_channels().tolist())} or"
          " be one or more of {'media', 'rf', 'organic_media',"
          " 'organic_rf'}."
      ) from e
  return adstock_hill.AdstockDecaySpec.from_consistent_type(spec)
|
|
893
|
+
|
|
894
|
+
def _create_adstock_decay_functions_from_channel_map(
    self, channel_function_map: Mapping[str, str]
) -> adstock_hill.AdstockDecaySpec:
  """Create `AdstockDecaySpec` from mapping from channels to decay functions.

  Args:
    channel_function_map: Mapping from channel name to decay function name.
      Every key must be one of the input data's adstock/hill channels.

  Returns:
    An `AdstockDecaySpec` with one entry per channel family (media, rf,
    organic_media, organic_rf). A family with no channels in the input data is
    assigned `constants.GEOMETRIC_DECAY`. Otherwise, the family's argument
    builder is created with `GEOMETRIC_DECAY` as its default value and is
    called with the whole mapping.

  Raises:
    KeyError: If a key of `channel_function_map` is not a channel in the data.
  """
  for name in channel_function_map:
    if name not in self._input_data.get_all_adstock_hill_channels():
      raise KeyError(f"Channel {name} not found in data.")

  def decay_for(channels, make_builder):
    if channels is None:
      return constants.GEOMETRIC_DECAY
    builder = make_builder().with_default_value(constants.GEOMETRIC_DECAY)
    return builder(**channel_function_map)

  data = self._input_data
  return adstock_hill.AdstockDecaySpec(
      media=decay_for(
          data.media_channel,
          data.get_paid_media_channels_argument_builder,
      ),
      rf=decay_for(
          data.rf_channel,
          data.get_paid_rf_channels_argument_builder,
      ),
      organic_media=decay_for(
          data.organic_media_channel,
          data.get_organic_media_channels_argument_builder,
      ),
      organic_rf=decay_for(
          data.organic_rf_channel,
          data.get_organic_rf_channels_argument_builder,
      ),
  )
|
|
945
|
+
|
|
946
|
+
def create_inference_data_coords(
    self, n_chains: int, n_draws: int
) -> Mapping[str, np.ndarray | Sequence[str]]:
  """Creates data coordinates for inference data.

  Args:
    n_chains: Number of chains; the chain coordinate is `0..n_chains-1`.
    n_draws: Number of draws; the draw coordinate is `0..n_draws-1`.

  Returns:
    A mapping from coordinate name to coordinate values. Channel and control
    coordinates that are absent from the input data become an empty array.
  """
  data = self.input_data

  def or_empty(names):
    # Absent channel families still need a (zero-length) coordinate.
    return np.array([]) if names is None else names

  return {
      constants.CHAIN: np.arange(n_chains),
      constants.DRAW: np.arange(n_draws),
      constants.GEO: data.geo,
      constants.TIME: data.time,
      constants.MEDIA_TIME: data.media_time,
      constants.KNOTS: np.arange(self.knot_info.n_knots),
      constants.CONTROL_VARIABLE: or_empty(data.control_variable),
      constants.NON_MEDIA_CHANNEL: or_empty(data.non_media_channel),
      constants.MEDIA_CHANNEL: or_empty(data.media_channel),
      constants.RF_CHANNEL: or_empty(data.rf_channel),
      constants.ORGANIC_MEDIA_CHANNEL: or_empty(data.organic_media_channel),
      constants.ORGANIC_RF_CHANNEL: or_empty(data.organic_rf_channel),
  }
|
|
994
|
+
|
|
995
|
+
def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]:
  """Creates data dimensions for inference data.

  Returns:
    For every parameter in `constants.INFERENCE_DIMS`, its dimension names
    prefixed with the chain and draw dimensions. `sigma` is given a geo
    dimension only when `unique_sigma_for_each_geo` is set; otherwise it has
    no dimensions beyond chain and draw.
  """
  dims_by_param = dict(constants.INFERENCE_DIMS)
  dims_by_param[constants.SIGMA] = (
      [constants.GEO] if self.unique_sigma_for_each_geo else []
  )
  sample_dims = [constants.CHAIN, constants.DRAW]
  return {
      name: sample_dims + list(param_dims)
      for name, param_dims in dims_by_param.items()
  }
|
|
1007
|
+
|
|
1008
|
+
def populate_cached_properties(self):
  """Eagerly evaluates every `functools.cached_property` of this object.

  A cached property stores its value on the instance the first time it is
  read. Reading all of them up front means later reads no longer mutate the
  object. This matters when the object is captured in the closure of a
  `tf.function` computation graph, where internal state mutations are
  problematic, so this is meant to be called before the graph is created.
  """
  owner = self.__class__

  def is_cached_property(name: str) -> bool:
    # Cached properties are descriptors defined on the class, not instance
    # attributes; names missing from the class are simply not matches.
    return isinstance(getattr(owner, name, None), functools.cached_property)

  # Collect every name first, then read them, so evaluation happens after
  # the full set of names is known.
  for name in list(filter(is_cached_property, dir(self))):
    getattr(self, name)
|
|
1026
|
+
|
|
1027
|
+
def expand_selected_time_dims(
    self,
    start_date: tc.Date = None,
    end_date: tc.Date = None,
) -> list[str] | None:
  """Validates and returns time dimension values based on the selected times.

  If both `start_date` and `end_date` are None, returns None. If specified,
  both `start_date` and `end_date` are inclusive, and must be present in the
  time coordinates of the input data.

  Args:
    start_date: Start date of the selected time period. If None, implies the
      earliest time dimension value in the input data.
    end_date: End date of the selected time period. If None, implies the
      latest time dimension value in the input data.

  Returns:
    A list of time dimension values (as Meridian-formatted strings) in the
    input data within the selected time period. Returns None, passing the
    selection through unchanged, if both arguments are None or if
    `start_date` and `end_date` cover the entire time range of the input data.

  Raises:
    ValueError if `start_date` or `end_date` is not in the input data time
    dimensions.
  """
  selected_dates = self.input_data.time_coordinates.expand_selected_time_dims(
      start_date=start_date, end_date=end_date
  )
  if selected_dates is not None:
    return [d.strftime(constants.DATE_FORMAT) for d in selected_dates]
  return None
|