google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/METADATA +8 -2
- google_meridian-1.2.0.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +526 -362
- meridian/analysis/optimizer.py +275 -267
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +37 -49
- meridian/backend/__init__.py +514 -0
- meridian/backend/config.py +59 -0
- meridian/backend/test_utils.py +95 -0
- meridian/constants.py +59 -3
- meridian/data/input_data.py +94 -0
- meridian/data/test_utils.py +144 -12
- meridian/model/adstock_hill.py +279 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +306 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +323 -157
- meridian/model/posterior_sampler.py +84 -77
- meridian/model/prior_distribution.py +538 -168
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +53 -47
- meridian/version.py +1 -1
- google_meridian-1.1.5.dist-info/RECORD +0 -47
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -16,18 +16,20 @@
|
|
|
16
16
|
|
|
17
17
|
import collections
|
|
18
18
|
import os
|
|
19
|
+
from typing import NamedTuple
|
|
19
20
|
|
|
21
|
+
from meridian import backend
|
|
20
22
|
from meridian import constants
|
|
23
|
+
from meridian.data import input_data
|
|
21
24
|
from meridian.data import test_utils
|
|
22
|
-
import tensorflow as tf
|
|
23
25
|
import xarray as xr
|
|
24
26
|
|
|
25
27
|
|
|
26
|
-
def _convert_with_swap(array: xr.DataArray, n_burnin: int) ->
|
|
27
|
-
"""Converts a DataArray to a
|
|
28
|
+
def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
|
|
29
|
+
"""Converts a DataArray to a backend.Tensor with the correct MCMC format.
|
|
28
30
|
|
|
29
|
-
This function converts a DataArray to
|
|
30
|
-
and adds the burnin part. This is needed to properly mock the
|
|
31
|
+
This function converts a DataArray to backend.Tensor, swaps first two
|
|
32
|
+
dimensions and adds the burnin part. This is needed to properly mock the
|
|
31
33
|
_xla_windowed_adaptive_nuts() function output in the sample_posterior
|
|
32
34
|
tests.
|
|
33
35
|
|
|
@@ -39,9 +41,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
|
|
|
39
41
|
A tensor in the same format as returned by the _xla_windowed_adaptive_nuts()
|
|
40
42
|
function.
|
|
41
43
|
"""
|
|
42
|
-
tensor =
|
|
44
|
+
tensor = backend.to_tensor(array)
|
|
43
45
|
perm = [1, 0] + [i for i in range(2, len(tensor.shape))]
|
|
44
|
-
transposed_tensor =
|
|
46
|
+
transposed_tensor = backend.transpose(tensor, perm=perm)
|
|
45
47
|
|
|
46
48
|
# Add the "burnin" part to the mocked output of _xla_windowed_adaptive_nuts
|
|
47
49
|
# to make sure sample_posterior returns the correct "keep" part.
|
|
@@ -50,15 +52,21 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
|
|
|
50
52
|
else:
|
|
51
53
|
pad_value = 0.0 if array.dtype.kind == "f" else 0
|
|
52
54
|
|
|
53
|
-
burnin =
|
|
54
|
-
return
|
|
55
|
+
burnin = backend.fill([n_burnin] + transposed_tensor.shape[1:], pad_value)
|
|
56
|
+
return backend.concatenate(
|
|
55
57
|
[burnin, transposed_tensor],
|
|
56
58
|
axis=0,
|
|
57
59
|
)
|
|
58
60
|
|
|
59
61
|
|
|
60
62
|
class WithInputDataSamples:
|
|
61
|
-
"""A mixin to inject test data samples to a unit test class.
|
|
63
|
+
"""A mixin to inject test data samples to a unit test class.
|
|
64
|
+
|
|
65
|
+
The `setup` method is a classmethod because loading and creating the data
|
|
66
|
+
samples can be expensive. The properties return deep copies to ensure
|
|
67
|
+
immutability for the data properties. As a result, it is recommended to use
|
|
68
|
+
local variables to avoid re-loading the data samples multiple times.
|
|
69
|
+
"""
|
|
62
70
|
|
|
63
71
|
# TODO: Update the sample data to span over 1 or 2 year(s).
|
|
64
72
|
_TEST_DIR = os.path.join(os.path.dirname(__file__), "test_data")
|
|
@@ -114,169 +122,203 @@ class WithInputDataSamples:
|
|
|
114
122
|
_N_MEDIA_CHANNELS = 3
|
|
115
123
|
_N_RF_CHANNELS = 2
|
|
116
124
|
_N_CONTROLS = 2
|
|
117
|
-
_ROI_CALIBRATION_PERIOD =
|
|
118
|
-
|
|
119
|
-
dtype=
|
|
125
|
+
_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
126
|
+
backend.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
|
|
127
|
+
dtype=backend.bool_,
|
|
120
128
|
)
|
|
121
|
-
_RF_ROI_CALIBRATION_PERIOD =
|
|
122
|
-
|
|
123
|
-
dtype=
|
|
129
|
+
_RF_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
130
|
+
backend.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
|
|
131
|
+
dtype=backend.bool_,
|
|
124
132
|
)
|
|
125
133
|
_N_ORGANIC_MEDIA_CHANNELS = 4
|
|
126
134
|
_N_ORGANIC_RF_CHANNELS = 1
|
|
127
135
|
_N_NON_MEDIA_CHANNELS = 2
|
|
128
136
|
|
|
129
|
-
|
|
137
|
+
# Private class variables to hold the base test data.
|
|
138
|
+
_input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
139
|
+
_input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
140
|
+
_input_data_with_media_only: input_data.InputData
|
|
141
|
+
_input_data_with_rf_only: input_data.InputData
|
|
142
|
+
_input_data_with_media_and_rf: input_data.InputData
|
|
143
|
+
_input_data_with_media_and_rf_no_controls: input_data.InputData
|
|
144
|
+
_short_input_data_with_media_only: input_data.InputData
|
|
145
|
+
_short_input_data_with_media_only_no_controls: input_data.InputData
|
|
146
|
+
_short_input_data_with_rf_only: input_data.InputData
|
|
147
|
+
_short_input_data_with_media_and_rf: input_data.InputData
|
|
148
|
+
_national_input_data_media_only: input_data.InputData
|
|
149
|
+
_national_input_data_media_and_rf: input_data.InputData
|
|
150
|
+
_test_dist_media_and_rf: collections.OrderedDict[str, backend.Tensor]
|
|
151
|
+
_test_dist_media_only: collections.OrderedDict[str, backend.Tensor]
|
|
152
|
+
_test_dist_media_only_no_controls: collections.OrderedDict[
|
|
153
|
+
str, backend.Tensor
|
|
154
|
+
]
|
|
155
|
+
_test_dist_rf_only: collections.OrderedDict[str, backend.Tensor]
|
|
156
|
+
_test_trace: dict[str, backend.Tensor]
|
|
157
|
+
_national_input_data_non_media_and_organic: input_data.InputData
|
|
158
|
+
_input_data_non_media_and_organic: input_data.InputData
|
|
159
|
+
_short_input_data_non_media_and_organic: input_data.InputData
|
|
160
|
+
_short_input_data_non_media: input_data.InputData
|
|
161
|
+
_input_data_non_media_and_organic_same_time_dims: input_data.InputData
|
|
162
|
+
|
|
163
|
+
# The following NamedTuples and their attributes are immutable, so they can
|
|
164
|
+
# be accessed directly.
|
|
165
|
+
test_posterior_states_media_and_rf: NamedTuple
|
|
166
|
+
test_posterior_states_media_only: NamedTuple
|
|
167
|
+
test_posterior_states_media_only_no_controls: NamedTuple
|
|
168
|
+
test_posterior_states_rf_only: NamedTuple
|
|
169
|
+
|
|
170
|
+
@classmethod
|
|
171
|
+
def setup(cls):
|
|
130
172
|
"""Sets up input data samples."""
|
|
131
|
-
|
|
173
|
+
cls._input_data_non_revenue_no_revenue_per_kpi = (
|
|
132
174
|
test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
|
|
133
|
-
n_geos=
|
|
134
|
-
n_times=
|
|
135
|
-
n_media_times=
|
|
136
|
-
n_controls=
|
|
137
|
-
n_media_channels=
|
|
175
|
+
n_geos=cls._N_GEOS,
|
|
176
|
+
n_times=cls._N_TIMES,
|
|
177
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
178
|
+
n_controls=cls._N_CONTROLS,
|
|
179
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
138
180
|
seed=0,
|
|
139
181
|
)
|
|
140
182
|
)
|
|
141
|
-
|
|
183
|
+
cls._input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
|
|
142
184
|
test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
|
|
143
|
-
n_geos=
|
|
144
|
-
n_times=
|
|
145
|
-
n_media_times=
|
|
146
|
-
n_controls=
|
|
147
|
-
n_media_channels=
|
|
148
|
-
n_rf_channels=
|
|
185
|
+
n_geos=cls._N_GEOS,
|
|
186
|
+
n_times=cls._N_TIMES,
|
|
187
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
188
|
+
n_controls=cls._N_CONTROLS,
|
|
189
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
190
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
149
191
|
seed=0,
|
|
150
192
|
)
|
|
151
193
|
)
|
|
152
|
-
|
|
194
|
+
cls._input_data_with_media_only = (
|
|
153
195
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
154
|
-
n_geos=
|
|
155
|
-
n_times=
|
|
156
|
-
n_media_times=
|
|
157
|
-
n_controls=
|
|
158
|
-
n_media_channels=
|
|
196
|
+
n_geos=cls._N_GEOS,
|
|
197
|
+
n_times=cls._N_TIMES,
|
|
198
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
199
|
+
n_controls=cls._N_CONTROLS,
|
|
200
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
159
201
|
seed=0,
|
|
160
202
|
)
|
|
161
203
|
)
|
|
162
|
-
|
|
204
|
+
cls._input_data_with_rf_only = (
|
|
163
205
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
164
|
-
n_geos=
|
|
165
|
-
n_times=
|
|
166
|
-
n_media_times=
|
|
167
|
-
n_controls=
|
|
168
|
-
n_rf_channels=
|
|
206
|
+
n_geos=cls._N_GEOS,
|
|
207
|
+
n_times=cls._N_TIMES,
|
|
208
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
209
|
+
n_controls=cls._N_CONTROLS,
|
|
210
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
169
211
|
seed=0,
|
|
170
212
|
)
|
|
171
213
|
)
|
|
172
|
-
|
|
214
|
+
cls._input_data_with_media_and_rf = (
|
|
173
215
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
174
|
-
n_geos=
|
|
175
|
-
n_times=
|
|
176
|
-
n_media_times=
|
|
177
|
-
n_controls=
|
|
178
|
-
n_media_channels=
|
|
179
|
-
n_rf_channels=
|
|
216
|
+
n_geos=cls._N_GEOS,
|
|
217
|
+
n_times=cls._N_TIMES,
|
|
218
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
219
|
+
n_controls=cls._N_CONTROLS,
|
|
220
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
221
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
180
222
|
seed=0,
|
|
181
223
|
)
|
|
182
224
|
)
|
|
183
|
-
|
|
225
|
+
cls._input_data_with_media_and_rf_no_controls = (
|
|
184
226
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
185
|
-
n_geos=
|
|
186
|
-
n_times=
|
|
187
|
-
n_media_times=
|
|
227
|
+
n_geos=cls._N_GEOS,
|
|
228
|
+
n_times=cls._N_TIMES,
|
|
229
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
188
230
|
n_controls=None,
|
|
189
|
-
n_media_channels=
|
|
190
|
-
n_rf_channels=
|
|
231
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
232
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
191
233
|
seed=0,
|
|
192
234
|
)
|
|
193
235
|
)
|
|
194
|
-
|
|
236
|
+
cls._short_input_data_with_media_only = (
|
|
195
237
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
196
|
-
n_geos=
|
|
197
|
-
n_times=
|
|
198
|
-
n_media_times=
|
|
199
|
-
n_controls=
|
|
200
|
-
n_media_channels=
|
|
238
|
+
n_geos=cls._N_GEOS,
|
|
239
|
+
n_times=cls._N_TIMES_SHORT,
|
|
240
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
241
|
+
n_controls=cls._N_CONTROLS,
|
|
242
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
201
243
|
seed=0,
|
|
202
244
|
)
|
|
203
245
|
)
|
|
204
|
-
|
|
246
|
+
cls._short_input_data_with_media_only_no_controls = (
|
|
205
247
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
206
|
-
n_geos=
|
|
207
|
-
n_times=
|
|
208
|
-
n_media_times=
|
|
248
|
+
n_geos=cls._N_GEOS,
|
|
249
|
+
n_times=cls._N_TIMES_SHORT,
|
|
250
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
209
251
|
n_controls=0,
|
|
210
|
-
n_media_channels=
|
|
252
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
211
253
|
seed=0,
|
|
212
254
|
)
|
|
213
255
|
)
|
|
214
|
-
|
|
256
|
+
cls._short_input_data_with_rf_only = (
|
|
215
257
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
216
|
-
n_geos=
|
|
217
|
-
n_times=
|
|
218
|
-
n_media_times=
|
|
219
|
-
n_controls=
|
|
220
|
-
n_rf_channels=
|
|
258
|
+
n_geos=cls._N_GEOS,
|
|
259
|
+
n_times=cls._N_TIMES_SHORT,
|
|
260
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
261
|
+
n_controls=cls._N_CONTROLS,
|
|
262
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
221
263
|
seed=0,
|
|
222
264
|
)
|
|
223
265
|
)
|
|
224
|
-
|
|
266
|
+
cls._short_input_data_with_media_and_rf = (
|
|
225
267
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
226
|
-
n_geos=
|
|
227
|
-
n_times=
|
|
228
|
-
n_media_times=
|
|
229
|
-
n_controls=
|
|
230
|
-
n_media_channels=
|
|
231
|
-
n_rf_channels=
|
|
268
|
+
n_geos=cls._N_GEOS,
|
|
269
|
+
n_times=cls._N_TIMES_SHORT,
|
|
270
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
271
|
+
n_controls=cls._N_CONTROLS,
|
|
272
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
273
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
232
274
|
seed=0,
|
|
233
275
|
)
|
|
234
276
|
)
|
|
235
|
-
|
|
277
|
+
cls._national_input_data_media_only = (
|
|
236
278
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
237
|
-
n_geos=
|
|
238
|
-
n_times=
|
|
239
|
-
n_media_times=
|
|
240
|
-
n_controls=
|
|
241
|
-
n_media_channels=
|
|
279
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
280
|
+
n_times=cls._N_TIMES,
|
|
281
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
282
|
+
n_controls=cls._N_CONTROLS,
|
|
283
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
242
284
|
seed=0,
|
|
243
285
|
)
|
|
244
286
|
)
|
|
245
|
-
|
|
287
|
+
cls._national_input_data_media_and_rf = (
|
|
246
288
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
247
|
-
n_geos=
|
|
248
|
-
n_times=
|
|
249
|
-
n_media_times=
|
|
250
|
-
n_controls=
|
|
251
|
-
n_media_channels=
|
|
252
|
-
n_rf_channels=
|
|
289
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
290
|
+
n_times=cls._N_TIMES,
|
|
291
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
292
|
+
n_controls=cls._N_CONTROLS,
|
|
293
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
294
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
253
295
|
seed=0,
|
|
254
296
|
)
|
|
255
297
|
)
|
|
256
298
|
|
|
257
299
|
test_prior_media_and_rf = xr.open_dataset(
|
|
258
|
-
|
|
300
|
+
cls._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH
|
|
259
301
|
)
|
|
260
302
|
test_prior_media_only = xr.open_dataset(
|
|
261
|
-
|
|
303
|
+
cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
|
|
262
304
|
)
|
|
263
305
|
test_prior_media_only_no_controls = xr.open_dataset(
|
|
264
|
-
|
|
306
|
+
cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
265
307
|
)
|
|
266
|
-
test_prior_rf_only = xr.open_dataset(
|
|
267
|
-
|
|
268
|
-
param:
|
|
308
|
+
test_prior_rf_only = xr.open_dataset(cls._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
|
|
309
|
+
cls._test_dist_media_and_rf = collections.OrderedDict({
|
|
310
|
+
param: backend.to_tensor(test_prior_media_and_rf[param])
|
|
269
311
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
270
312
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
271
313
|
+ constants.RF_PARAMETER_NAMES
|
|
272
314
|
})
|
|
273
|
-
|
|
274
|
-
param:
|
|
315
|
+
cls._test_dist_media_only = collections.OrderedDict({
|
|
316
|
+
param: backend.to_tensor(test_prior_media_only[param])
|
|
275
317
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
276
318
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
277
319
|
})
|
|
278
|
-
|
|
279
|
-
param:
|
|
320
|
+
cls._test_dist_media_only_no_controls = collections.OrderedDict({
|
|
321
|
+
param: backend.to_tensor(test_prior_media_only_no_controls[param])
|
|
280
322
|
for param in (
|
|
281
323
|
set(
|
|
282
324
|
constants.COMMON_PARAMETER_NAMES
|
|
@@ -287,27 +329,27 @@ class WithInputDataSamples:
|
|
|
287
329
|
)
|
|
288
330
|
)
|
|
289
331
|
})
|
|
290
|
-
|
|
291
|
-
param:
|
|
332
|
+
cls._test_dist_rf_only = collections.OrderedDict({
|
|
333
|
+
param: backend.to_tensor(test_prior_rf_only[param])
|
|
292
334
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
293
335
|
+ constants.RF_PARAMETER_NAMES
|
|
294
336
|
})
|
|
295
337
|
|
|
296
338
|
test_posterior_media_and_rf = xr.open_dataset(
|
|
297
|
-
|
|
339
|
+
cls._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH
|
|
298
340
|
)
|
|
299
341
|
test_posterior_media_only = xr.open_dataset(
|
|
300
|
-
|
|
342
|
+
cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
|
|
301
343
|
)
|
|
302
344
|
test_posterior_media_only_no_controls = xr.open_dataset(
|
|
303
|
-
|
|
345
|
+
cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
304
346
|
)
|
|
305
347
|
test_posterior_rf_only = xr.open_dataset(
|
|
306
|
-
|
|
348
|
+
cls._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
|
|
307
349
|
)
|
|
308
350
|
posterior_params_to_tensors_media_and_rf = {
|
|
309
351
|
param: _convert_with_swap(
|
|
310
|
-
test_posterior_media_and_rf[param], n_burnin=
|
|
352
|
+
test_posterior_media_and_rf[param], n_burnin=cls._N_BURNIN
|
|
311
353
|
)
|
|
312
354
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
313
355
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
@@ -315,7 +357,7 @@ class WithInputDataSamples:
|
|
|
315
357
|
}
|
|
316
358
|
posterior_params_to_tensors_media_only = {
|
|
317
359
|
param: _convert_with_swap(
|
|
318
|
-
test_posterior_media_only[param], n_burnin=
|
|
360
|
+
test_posterior_media_only[param], n_burnin=cls._N_BURNIN
|
|
319
361
|
)
|
|
320
362
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
321
363
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
@@ -323,7 +365,7 @@ class WithInputDataSamples:
|
|
|
323
365
|
posterior_params_to_tensors_media_only_no_controls = {
|
|
324
366
|
param: _convert_with_swap(
|
|
325
367
|
test_posterior_media_only_no_controls[param],
|
|
326
|
-
n_burnin=
|
|
368
|
+
n_burnin=cls._N_BURNIN,
|
|
327
369
|
)
|
|
328
370
|
for param in (
|
|
329
371
|
set(
|
|
@@ -337,22 +379,22 @@ class WithInputDataSamples:
|
|
|
337
379
|
}
|
|
338
380
|
posterior_params_to_tensors_rf_only = {
|
|
339
381
|
param: _convert_with_swap(
|
|
340
|
-
test_posterior_rf_only[param], n_burnin=
|
|
382
|
+
test_posterior_rf_only[param], n_burnin=cls._N_BURNIN
|
|
341
383
|
)
|
|
342
384
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
343
385
|
+ constants.RF_PARAMETER_NAMES
|
|
344
386
|
}
|
|
345
|
-
|
|
387
|
+
cls.test_posterior_states_media_and_rf = collections.namedtuple(
|
|
346
388
|
"StructTuple",
|
|
347
389
|
constants.COMMON_PARAMETER_NAMES
|
|
348
390
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
349
391
|
+ constants.RF_PARAMETER_NAMES,
|
|
350
392
|
)(**posterior_params_to_tensors_media_and_rf)
|
|
351
|
-
|
|
393
|
+
cls.test_posterior_states_media_only = collections.namedtuple(
|
|
352
394
|
"StructTuple",
|
|
353
395
|
constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
|
|
354
396
|
)(**posterior_params_to_tensors_media_only)
|
|
355
|
-
|
|
397
|
+
cls.test_posterior_states_media_only_no_controls = collections.namedtuple(
|
|
356
398
|
"StructTuple",
|
|
357
399
|
(
|
|
358
400
|
set(
|
|
@@ -364,73 +406,197 @@ class WithInputDataSamples:
|
|
|
364
406
|
)
|
|
365
407
|
),
|
|
366
408
|
)(**posterior_params_to_tensors_media_only_no_controls)
|
|
367
|
-
|
|
409
|
+
cls.test_posterior_states_rf_only = collections.namedtuple(
|
|
368
410
|
"StructTuple",
|
|
369
411
|
constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
|
|
370
412
|
)(**posterior_params_to_tensors_rf_only)
|
|
371
413
|
|
|
372
|
-
test_trace = xr.open_dataset(
|
|
373
|
-
|
|
374
|
-
param: _convert_with_swap(test_trace[param], n_burnin=
|
|
414
|
+
test_trace = xr.open_dataset(cls._TEST_SAMPLE_TRACE_PATH)
|
|
415
|
+
cls._test_trace = {
|
|
416
|
+
param: _convert_with_swap(test_trace[param], n_burnin=cls._N_BURNIN)
|
|
375
417
|
for param in test_trace.data_vars
|
|
376
418
|
}
|
|
377
419
|
|
|
378
420
|
# The following are input data samples with non-paid channels.
|
|
379
421
|
|
|
380
|
-
|
|
422
|
+
cls._national_input_data_non_media_and_organic = (
|
|
381
423
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
382
|
-
n_geos=
|
|
383
|
-
n_times=
|
|
384
|
-
n_media_times=
|
|
385
|
-
n_controls=
|
|
386
|
-
n_non_media_channels=
|
|
387
|
-
n_media_channels=
|
|
388
|
-
n_rf_channels=
|
|
389
|
-
n_organic_media_channels=
|
|
390
|
-
n_organic_rf_channels=
|
|
424
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
425
|
+
n_times=cls._N_TIMES,
|
|
426
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
427
|
+
n_controls=cls._N_CONTROLS,
|
|
428
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
429
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
430
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
431
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
432
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
391
433
|
seed=0,
|
|
392
434
|
)
|
|
393
435
|
)
|
|
394
436
|
|
|
395
|
-
|
|
437
|
+
cls._input_data_non_media_and_organic = (
|
|
396
438
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
397
|
-
n_geos=
|
|
398
|
-
n_times=
|
|
399
|
-
n_media_times=
|
|
400
|
-
n_controls=
|
|
401
|
-
n_non_media_channels=
|
|
402
|
-
n_media_channels=
|
|
403
|
-
n_rf_channels=
|
|
404
|
-
n_organic_media_channels=
|
|
405
|
-
n_organic_rf_channels=
|
|
439
|
+
n_geos=cls._N_GEOS,
|
|
440
|
+
n_times=cls._N_TIMES,
|
|
441
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
442
|
+
n_controls=cls._N_CONTROLS,
|
|
443
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
444
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
445
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
446
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
447
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
406
448
|
seed=0,
|
|
407
449
|
)
|
|
408
450
|
)
|
|
409
|
-
|
|
451
|
+
cls._short_input_data_non_media_and_organic = (
|
|
410
452
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
411
|
-
n_geos=
|
|
412
|
-
n_times=
|
|
413
|
-
n_media_times=
|
|
414
|
-
n_controls=
|
|
415
|
-
n_non_media_channels=
|
|
416
|
-
n_media_channels=
|
|
417
|
-
n_rf_channels=
|
|
418
|
-
n_organic_media_channels=
|
|
419
|
-
n_organic_rf_channels=
|
|
453
|
+
n_geos=cls._N_GEOS,
|
|
454
|
+
n_times=cls._N_TIMES_SHORT,
|
|
455
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
456
|
+
n_controls=cls._N_CONTROLS,
|
|
457
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
458
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
459
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
460
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
461
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
420
462
|
seed=0,
|
|
421
463
|
)
|
|
422
464
|
)
|
|
423
|
-
|
|
465
|
+
cls._short_input_data_non_media = (
|
|
424
466
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
425
|
-
n_geos=
|
|
426
|
-
n_times=
|
|
427
|
-
n_media_times=
|
|
428
|
-
n_controls=
|
|
429
|
-
n_non_media_channels=
|
|
430
|
-
n_media_channels=
|
|
431
|
-
n_rf_channels=
|
|
467
|
+
n_geos=cls._N_GEOS,
|
|
468
|
+
n_times=cls._N_TIMES_SHORT,
|
|
469
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
470
|
+
n_controls=cls._N_CONTROLS,
|
|
471
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
472
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
473
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
432
474
|
n_organic_media_channels=0,
|
|
433
475
|
n_organic_rf_channels=0,
|
|
434
476
|
seed=0,
|
|
435
477
|
)
|
|
436
478
|
)
|
|
479
|
+
cls._input_data_non_media_and_organic_same_time_dims = (
|
|
480
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
481
|
+
n_geos=cls._N_GEOS,
|
|
482
|
+
n_times=cls._N_TIMES,
|
|
483
|
+
n_media_times=cls._N_TIMES,
|
|
484
|
+
n_controls=cls._N_CONTROLS,
|
|
485
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
486
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
487
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
488
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
489
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
490
|
+
seed=0,
|
|
491
|
+
)
|
|
492
|
+
)
|
|
493
|
+
|
|
494
|
+
@property
|
|
495
|
+
def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
|
|
496
|
+
return self._input_data_non_revenue_no_revenue_per_kpi.copy(deep=True)
|
|
497
|
+
|
|
498
|
+
@property
|
|
499
|
+
def input_data_media_and_rf_non_revenue_no_revenue_per_kpi(
|
|
500
|
+
self,
|
|
501
|
+
) -> input_data.InputData:
|
|
502
|
+
return self._input_data_media_and_rf_non_revenue_no_revenue_per_kpi.copy(
|
|
503
|
+
deep=True
|
|
504
|
+
)
|
|
505
|
+
|
|
506
|
+
@property
|
|
507
|
+
def input_data_with_media_only(self) -> input_data.InputData:
|
|
508
|
+
return self._input_data_with_media_only.copy(deep=True)
|
|
509
|
+
|
|
510
|
+
@property
|
|
511
|
+
def input_data_with_rf_only(self) -> input_data.InputData:
|
|
512
|
+
return self._input_data_with_rf_only.copy(deep=True)
|
|
513
|
+
|
|
514
|
+
@property
|
|
515
|
+
def input_data_with_media_and_rf(self) -> input_data.InputData:
|
|
516
|
+
return self._input_data_with_media_and_rf.copy(deep=True)
|
|
517
|
+
|
|
518
|
+
@property
|
|
519
|
+
def input_data_with_media_and_rf_no_controls(
|
|
520
|
+
self,
|
|
521
|
+
) -> input_data.InputData:
|
|
522
|
+
return self._input_data_with_media_and_rf_no_controls.copy(deep=True)
|
|
523
|
+
|
|
524
|
+
@property
|
|
525
|
+
def short_input_data_with_media_only(self) -> input_data.InputData:
|
|
526
|
+
return self._short_input_data_with_media_only.copy(deep=True)
|
|
527
|
+
|
|
528
|
+
@property
|
|
529
|
+
def short_input_data_with_media_only_no_controls(
|
|
530
|
+
self,
|
|
531
|
+
) -> input_data.InputData:
|
|
532
|
+
return self._short_input_data_with_media_only_no_controls.copy(deep=True)
|
|
533
|
+
|
|
534
|
+
@property
|
|
535
|
+
def short_input_data_with_rf_only(self) -> input_data.InputData:
|
|
536
|
+
return self._short_input_data_with_rf_only.copy(deep=True)
|
|
537
|
+
|
|
538
|
+
@property
|
|
539
|
+
def short_input_data_with_media_and_rf(
|
|
540
|
+
self,
|
|
541
|
+
) -> input_data.InputData:
|
|
542
|
+
return self._short_input_data_with_media_and_rf.copy(deep=True)
|
|
543
|
+
|
|
544
|
+
@property
|
|
545
|
+
def national_input_data_media_only(self) -> input_data.InputData:
|
|
546
|
+
return self._national_input_data_media_only.copy(deep=True)
|
|
547
|
+
|
|
548
|
+
@property
|
|
549
|
+
def national_input_data_media_and_rf(self) -> input_data.InputData:
|
|
550
|
+
return self._national_input_data_media_and_rf.copy(deep=True)
|
|
551
|
+
|
|
552
|
+
@property
|
|
553
|
+
def test_dist_media_and_rf(
|
|
554
|
+
self,
|
|
555
|
+
) -> collections.OrderedDict[str, backend.Tensor]:
|
|
556
|
+
return collections.OrderedDict(self._test_dist_media_and_rf)
|
|
557
|
+
|
|
558
|
+
@property
|
|
559
|
+
def test_dist_media_only(
|
|
560
|
+
self,
|
|
561
|
+
) -> collections.OrderedDict[str, backend.Tensor]:
|
|
562
|
+
return collections.OrderedDict(self._test_dist_media_only)
|
|
563
|
+
|
|
564
|
+
@property
|
|
565
|
+
def test_dist_media_only_no_controls(
|
|
566
|
+
self,
|
|
567
|
+
) -> collections.OrderedDict[str, backend.Tensor]:
|
|
568
|
+
return collections.OrderedDict(self._test_dist_media_only_no_controls)
|
|
569
|
+
|
|
570
|
+
@property
|
|
571
|
+
def test_dist_rf_only(self) -> collections.OrderedDict[str, backend.Tensor]:
|
|
572
|
+
return collections.OrderedDict(self._test_dist_rf_only)
|
|
573
|
+
|
|
574
|
+
@property
|
|
575
|
+
def test_trace(self) -> dict[str, backend.Tensor]:
|
|
576
|
+
return self._test_trace.copy()
|
|
577
|
+
|
|
578
|
+
@property
|
|
579
|
+
def national_input_data_non_media_and_organic(
|
|
580
|
+
self,
|
|
581
|
+
) -> input_data.InputData:
|
|
582
|
+
return self._national_input_data_non_media_and_organic.copy(deep=True)
|
|
583
|
+
|
|
584
|
+
@property
|
|
585
|
+
def input_data_non_media_and_organic(self) -> input_data.InputData:
|
|
586
|
+
return self._input_data_non_media_and_organic.copy(deep=True)
|
|
587
|
+
|
|
588
|
+
@property
|
|
589
|
+
def short_input_data_non_media_and_organic(
|
|
590
|
+
self,
|
|
591
|
+
) -> input_data.InputData:
|
|
592
|
+
return self._short_input_data_non_media_and_organic.copy(deep=True)
|
|
593
|
+
|
|
594
|
+
@property
|
|
595
|
+
def short_input_data_non_media(self) -> input_data.InputData:
|
|
596
|
+
return self._short_input_data_non_media.copy(deep=True)
|
|
597
|
+
|
|
598
|
+
@property
|
|
599
|
+
def input_data_non_media_and_organic_same_time_dims(
|
|
600
|
+
self,
|
|
601
|
+
) -> input_data.InputData:
|
|
602
|
+
return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)
|