google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
|
@@ -16,18 +16,20 @@
|
|
|
16
16
|
|
|
17
17
|
import collections
|
|
18
18
|
import os
|
|
19
|
+
from typing import NamedTuple
|
|
19
20
|
|
|
21
|
+
from meridian import backend
|
|
20
22
|
from meridian import constants
|
|
23
|
+
from meridian.data import input_data
|
|
21
24
|
from meridian.data import test_utils
|
|
22
|
-
import tensorflow as tf
|
|
23
25
|
import xarray as xr
|
|
24
26
|
|
|
25
27
|
|
|
26
|
-
def _convert_with_swap(array: xr.DataArray, n_burnin: int) ->
|
|
27
|
-
"""Converts a DataArray to a
|
|
28
|
+
def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
|
|
29
|
+
"""Converts a DataArray to a backend.Tensor with the correct MCMC format.
|
|
28
30
|
|
|
29
|
-
This function converts a DataArray to
|
|
30
|
-
and adds the burnin part. This is needed to properly mock the
|
|
31
|
+
This function converts a DataArray to backend.Tensor, swaps first two
|
|
32
|
+
dimensions and adds the burnin part. This is needed to properly mock the
|
|
31
33
|
_xla_windowed_adaptive_nuts() function output in the sample_posterior
|
|
32
34
|
tests.
|
|
33
35
|
|
|
@@ -39,9 +41,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
|
|
|
39
41
|
A tensor in the same format as returned by the _xla_windowed_adaptive_nuts()
|
|
40
42
|
function.
|
|
41
43
|
"""
|
|
42
|
-
tensor =
|
|
44
|
+
tensor = backend.to_tensor(array)
|
|
43
45
|
perm = [1, 0] + [i for i in range(2, len(tensor.shape))]
|
|
44
|
-
transposed_tensor =
|
|
46
|
+
transposed_tensor = backend.transpose(tensor, perm=perm)
|
|
45
47
|
|
|
46
48
|
# Add the "burnin" part to the mocked output of _xla_windowed_adaptive_nuts
|
|
47
49
|
# to make sure sample_posterior returns the correct "keep" part.
|
|
@@ -50,15 +52,23 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
|
|
|
50
52
|
else:
|
|
51
53
|
pad_value = 0.0 if array.dtype.kind == "f" else 0
|
|
52
54
|
|
|
53
|
-
burnin =
|
|
54
|
-
|
|
55
|
+
burnin = backend.fill(
|
|
56
|
+
[n_burnin] + list(transposed_tensor.shape[1:]), pad_value
|
|
57
|
+
)
|
|
58
|
+
return backend.concatenate(
|
|
55
59
|
[burnin, transposed_tensor],
|
|
56
60
|
axis=0,
|
|
57
61
|
)
|
|
58
62
|
|
|
59
63
|
|
|
60
64
|
class WithInputDataSamples:
|
|
61
|
-
"""A mixin to inject test data samples to a unit test class.
|
|
65
|
+
"""A mixin to inject test data samples to a unit test class.
|
|
66
|
+
|
|
67
|
+
The `setup` method is a classmethod because loading and creating the data
|
|
68
|
+
samples can be expensive. The properties return deep copies to ensure
|
|
69
|
+
immutability for the data properties. As a result, it is recommended to use
|
|
70
|
+
local variables to avoid re-loading the data samples multiple times.
|
|
71
|
+
"""
|
|
62
72
|
|
|
63
73
|
# TODO: Update the sample data to span over 1 or 2 year(s).
|
|
64
74
|
_TEST_DIR = os.path.join(os.path.dirname(__file__), "test_data")
|
|
@@ -114,169 +124,207 @@ class WithInputDataSamples:
|
|
|
114
124
|
_N_MEDIA_CHANNELS = 3
|
|
115
125
|
_N_RF_CHANNELS = 2
|
|
116
126
|
_N_CONTROLS = 2
|
|
117
|
-
_ROI_CALIBRATION_PERIOD = tf.cast(
|
|
118
|
-
tf.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
|
|
119
|
-
dtype=tf.bool,
|
|
120
|
-
)
|
|
121
|
-
_RF_ROI_CALIBRATION_PERIOD = tf.cast(
|
|
122
|
-
tf.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
|
|
123
|
-
dtype=tf.bool,
|
|
124
|
-
)
|
|
125
127
|
_N_ORGANIC_MEDIA_CHANNELS = 4
|
|
126
128
|
_N_ORGANIC_RF_CHANNELS = 1
|
|
127
129
|
_N_NON_MEDIA_CHANNELS = 2
|
|
128
130
|
|
|
129
|
-
|
|
131
|
+
_ROI_CALIBRATION_PERIOD: backend.Tensor
|
|
132
|
+
_RF_ROI_CALIBRATION_PERIOD: backend.Tensor
|
|
133
|
+
|
|
134
|
+
# Private class variables to hold the base test data.
|
|
135
|
+
_input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
136
|
+
_input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
|
|
137
|
+
_input_data_with_media_only: input_data.InputData
|
|
138
|
+
_input_data_with_rf_only: input_data.InputData
|
|
139
|
+
_input_data_with_media_and_rf: input_data.InputData
|
|
140
|
+
_input_data_with_media_and_rf_no_controls: input_data.InputData
|
|
141
|
+
_short_input_data_with_media_only: input_data.InputData
|
|
142
|
+
_short_input_data_with_media_only_no_controls: input_data.InputData
|
|
143
|
+
_short_input_data_with_rf_only: input_data.InputData
|
|
144
|
+
_short_input_data_with_media_and_rf: input_data.InputData
|
|
145
|
+
_national_input_data_media_only: input_data.InputData
|
|
146
|
+
_national_input_data_media_and_rf: input_data.InputData
|
|
147
|
+
_test_dist_media_and_rf: collections.OrderedDict[str, backend.Tensor]
|
|
148
|
+
_test_dist_media_only: collections.OrderedDict[str, backend.Tensor]
|
|
149
|
+
_test_dist_media_only_no_controls: collections.OrderedDict[
|
|
150
|
+
str, backend.Tensor
|
|
151
|
+
]
|
|
152
|
+
_test_dist_rf_only: collections.OrderedDict[str, backend.Tensor]
|
|
153
|
+
_test_trace: dict[str, backend.Tensor]
|
|
154
|
+
_national_input_data_non_media_and_organic: input_data.InputData
|
|
155
|
+
_input_data_non_media_and_organic: input_data.InputData
|
|
156
|
+
_short_input_data_non_media_and_organic: input_data.InputData
|
|
157
|
+
_short_input_data_non_media: input_data.InputData
|
|
158
|
+
_input_data_non_media_and_organic_same_time_dims: input_data.InputData
|
|
159
|
+
|
|
160
|
+
# The following NamedTuples and their attributes are immutable, so they can
|
|
161
|
+
# be accessed directly.
|
|
162
|
+
test_posterior_states_media_and_rf: NamedTuple
|
|
163
|
+
test_posterior_states_media_only: NamedTuple
|
|
164
|
+
test_posterior_states_media_only_no_controls: NamedTuple
|
|
165
|
+
test_posterior_states_rf_only: NamedTuple
|
|
166
|
+
|
|
167
|
+
@classmethod
|
|
168
|
+
def setup(cls):
|
|
130
169
|
"""Sets up input data samples."""
|
|
131
|
-
|
|
170
|
+
cls._ROI_CALIBRATION_PERIOD = backend.cast(
|
|
171
|
+
backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_MEDIA_CHANNELS)),
|
|
172
|
+
dtype=backend.bool_,
|
|
173
|
+
)
|
|
174
|
+
cls._RF_ROI_CALIBRATION_PERIOD = backend.cast(
|
|
175
|
+
backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_RF_CHANNELS)),
|
|
176
|
+
dtype=backend.bool_,
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
cls._input_data_non_revenue_no_revenue_per_kpi = (
|
|
132
180
|
test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
|
|
133
|
-
n_geos=
|
|
134
|
-
n_times=
|
|
135
|
-
n_media_times=
|
|
136
|
-
n_controls=
|
|
137
|
-
n_media_channels=
|
|
181
|
+
n_geos=cls._N_GEOS,
|
|
182
|
+
n_times=cls._N_TIMES,
|
|
183
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
184
|
+
n_controls=cls._N_CONTROLS,
|
|
185
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
138
186
|
seed=0,
|
|
139
187
|
)
|
|
140
188
|
)
|
|
141
|
-
|
|
189
|
+
cls._input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
|
|
142
190
|
test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
|
|
143
|
-
n_geos=
|
|
144
|
-
n_times=
|
|
145
|
-
n_media_times=
|
|
146
|
-
n_controls=
|
|
147
|
-
n_media_channels=
|
|
148
|
-
n_rf_channels=
|
|
191
|
+
n_geos=cls._N_GEOS,
|
|
192
|
+
n_times=cls._N_TIMES,
|
|
193
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
194
|
+
n_controls=cls._N_CONTROLS,
|
|
195
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
196
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
149
197
|
seed=0,
|
|
150
198
|
)
|
|
151
199
|
)
|
|
152
|
-
|
|
200
|
+
cls._input_data_with_media_only = (
|
|
153
201
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
154
|
-
n_geos=
|
|
155
|
-
n_times=
|
|
156
|
-
n_media_times=
|
|
157
|
-
n_controls=
|
|
158
|
-
n_media_channels=
|
|
202
|
+
n_geos=cls._N_GEOS,
|
|
203
|
+
n_times=cls._N_TIMES,
|
|
204
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
205
|
+
n_controls=cls._N_CONTROLS,
|
|
206
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
159
207
|
seed=0,
|
|
160
208
|
)
|
|
161
209
|
)
|
|
162
|
-
|
|
210
|
+
cls._input_data_with_rf_only = (
|
|
163
211
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
164
|
-
n_geos=
|
|
165
|
-
n_times=
|
|
166
|
-
n_media_times=
|
|
167
|
-
n_controls=
|
|
168
|
-
n_rf_channels=
|
|
212
|
+
n_geos=cls._N_GEOS,
|
|
213
|
+
n_times=cls._N_TIMES,
|
|
214
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
215
|
+
n_controls=cls._N_CONTROLS,
|
|
216
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
169
217
|
seed=0,
|
|
170
218
|
)
|
|
171
219
|
)
|
|
172
|
-
|
|
220
|
+
cls._input_data_with_media_and_rf = (
|
|
173
221
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
174
|
-
n_geos=
|
|
175
|
-
n_times=
|
|
176
|
-
n_media_times=
|
|
177
|
-
n_controls=
|
|
178
|
-
n_media_channels=
|
|
179
|
-
n_rf_channels=
|
|
222
|
+
n_geos=cls._N_GEOS,
|
|
223
|
+
n_times=cls._N_TIMES,
|
|
224
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
225
|
+
n_controls=cls._N_CONTROLS,
|
|
226
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
227
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
180
228
|
seed=0,
|
|
181
229
|
)
|
|
182
230
|
)
|
|
183
|
-
|
|
231
|
+
cls._input_data_with_media_and_rf_no_controls = (
|
|
184
232
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
185
|
-
n_geos=
|
|
186
|
-
n_times=
|
|
187
|
-
n_media_times=
|
|
233
|
+
n_geos=cls._N_GEOS,
|
|
234
|
+
n_times=cls._N_TIMES,
|
|
235
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
188
236
|
n_controls=None,
|
|
189
|
-
n_media_channels=
|
|
190
|
-
n_rf_channels=
|
|
237
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
238
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
191
239
|
seed=0,
|
|
192
240
|
)
|
|
193
241
|
)
|
|
194
|
-
|
|
242
|
+
cls._short_input_data_with_media_only = (
|
|
195
243
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
196
|
-
n_geos=
|
|
197
|
-
n_times=
|
|
198
|
-
n_media_times=
|
|
199
|
-
n_controls=
|
|
200
|
-
n_media_channels=
|
|
244
|
+
n_geos=cls._N_GEOS,
|
|
245
|
+
n_times=cls._N_TIMES_SHORT,
|
|
246
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
247
|
+
n_controls=cls._N_CONTROLS,
|
|
248
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
201
249
|
seed=0,
|
|
202
250
|
)
|
|
203
251
|
)
|
|
204
|
-
|
|
252
|
+
cls._short_input_data_with_media_only_no_controls = (
|
|
205
253
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
206
|
-
n_geos=
|
|
207
|
-
n_times=
|
|
208
|
-
n_media_times=
|
|
254
|
+
n_geos=cls._N_GEOS,
|
|
255
|
+
n_times=cls._N_TIMES_SHORT,
|
|
256
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
209
257
|
n_controls=0,
|
|
210
|
-
n_media_channels=
|
|
258
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
211
259
|
seed=0,
|
|
212
260
|
)
|
|
213
261
|
)
|
|
214
|
-
|
|
262
|
+
cls._short_input_data_with_rf_only = (
|
|
215
263
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
216
|
-
n_geos=
|
|
217
|
-
n_times=
|
|
218
|
-
n_media_times=
|
|
219
|
-
n_controls=
|
|
220
|
-
n_rf_channels=
|
|
264
|
+
n_geos=cls._N_GEOS,
|
|
265
|
+
n_times=cls._N_TIMES_SHORT,
|
|
266
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
267
|
+
n_controls=cls._N_CONTROLS,
|
|
268
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
221
269
|
seed=0,
|
|
222
270
|
)
|
|
223
271
|
)
|
|
224
|
-
|
|
272
|
+
cls._short_input_data_with_media_and_rf = (
|
|
225
273
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
226
|
-
n_geos=
|
|
227
|
-
n_times=
|
|
228
|
-
n_media_times=
|
|
229
|
-
n_controls=
|
|
230
|
-
n_media_channels=
|
|
231
|
-
n_rf_channels=
|
|
274
|
+
n_geos=cls._N_GEOS,
|
|
275
|
+
n_times=cls._N_TIMES_SHORT,
|
|
276
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
277
|
+
n_controls=cls._N_CONTROLS,
|
|
278
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
279
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
232
280
|
seed=0,
|
|
233
281
|
)
|
|
234
282
|
)
|
|
235
|
-
|
|
283
|
+
cls._national_input_data_media_only = (
|
|
236
284
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
237
|
-
n_geos=
|
|
238
|
-
n_times=
|
|
239
|
-
n_media_times=
|
|
240
|
-
n_controls=
|
|
241
|
-
n_media_channels=
|
|
285
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
286
|
+
n_times=cls._N_TIMES,
|
|
287
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
288
|
+
n_controls=cls._N_CONTROLS,
|
|
289
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
242
290
|
seed=0,
|
|
243
291
|
)
|
|
244
292
|
)
|
|
245
|
-
|
|
293
|
+
cls._national_input_data_media_and_rf = (
|
|
246
294
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
247
|
-
n_geos=
|
|
248
|
-
n_times=
|
|
249
|
-
n_media_times=
|
|
250
|
-
n_controls=
|
|
251
|
-
n_media_channels=
|
|
252
|
-
n_rf_channels=
|
|
295
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
296
|
+
n_times=cls._N_TIMES,
|
|
297
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
298
|
+
n_controls=cls._N_CONTROLS,
|
|
299
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
300
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
253
301
|
seed=0,
|
|
254
302
|
)
|
|
255
303
|
)
|
|
256
304
|
|
|
257
305
|
test_prior_media_and_rf = xr.open_dataset(
|
|
258
|
-
|
|
306
|
+
cls._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH
|
|
259
307
|
)
|
|
260
308
|
test_prior_media_only = xr.open_dataset(
|
|
261
|
-
|
|
309
|
+
cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
|
|
262
310
|
)
|
|
263
311
|
test_prior_media_only_no_controls = xr.open_dataset(
|
|
264
|
-
|
|
312
|
+
cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
265
313
|
)
|
|
266
|
-
test_prior_rf_only = xr.open_dataset(
|
|
267
|
-
|
|
268
|
-
param:
|
|
314
|
+
test_prior_rf_only = xr.open_dataset(cls._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
|
|
315
|
+
cls._test_dist_media_and_rf = collections.OrderedDict({
|
|
316
|
+
param: backend.to_tensor(test_prior_media_and_rf[param])
|
|
269
317
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
270
318
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
271
319
|
+ constants.RF_PARAMETER_NAMES
|
|
272
320
|
})
|
|
273
|
-
|
|
274
|
-
param:
|
|
321
|
+
cls._test_dist_media_only = collections.OrderedDict({
|
|
322
|
+
param: backend.to_tensor(test_prior_media_only[param])
|
|
275
323
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
276
324
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
277
325
|
})
|
|
278
|
-
|
|
279
|
-
param:
|
|
326
|
+
cls._test_dist_media_only_no_controls = collections.OrderedDict({
|
|
327
|
+
param: backend.to_tensor(test_prior_media_only_no_controls[param])
|
|
280
328
|
for param in (
|
|
281
329
|
set(
|
|
282
330
|
constants.COMMON_PARAMETER_NAMES
|
|
@@ -287,27 +335,27 @@ class WithInputDataSamples:
|
|
|
287
335
|
)
|
|
288
336
|
)
|
|
289
337
|
})
|
|
290
|
-
|
|
291
|
-
param:
|
|
338
|
+
cls._test_dist_rf_only = collections.OrderedDict({
|
|
339
|
+
param: backend.to_tensor(test_prior_rf_only[param])
|
|
292
340
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
293
341
|
+ constants.RF_PARAMETER_NAMES
|
|
294
342
|
})
|
|
295
343
|
|
|
296
344
|
test_posterior_media_and_rf = xr.open_dataset(
|
|
297
|
-
|
|
345
|
+
cls._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH
|
|
298
346
|
)
|
|
299
347
|
test_posterior_media_only = xr.open_dataset(
|
|
300
|
-
|
|
348
|
+
cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
|
|
301
349
|
)
|
|
302
350
|
test_posterior_media_only_no_controls = xr.open_dataset(
|
|
303
|
-
|
|
351
|
+
cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
|
|
304
352
|
)
|
|
305
353
|
test_posterior_rf_only = xr.open_dataset(
|
|
306
|
-
|
|
354
|
+
cls._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
|
|
307
355
|
)
|
|
308
356
|
posterior_params_to_tensors_media_and_rf = {
|
|
309
357
|
param: _convert_with_swap(
|
|
310
|
-
test_posterior_media_and_rf[param], n_burnin=
|
|
358
|
+
test_posterior_media_and_rf[param], n_burnin=cls._N_BURNIN
|
|
311
359
|
)
|
|
312
360
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
313
361
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
@@ -315,7 +363,7 @@ class WithInputDataSamples:
|
|
|
315
363
|
}
|
|
316
364
|
posterior_params_to_tensors_media_only = {
|
|
317
365
|
param: _convert_with_swap(
|
|
318
|
-
test_posterior_media_only[param], n_burnin=
|
|
366
|
+
test_posterior_media_only[param], n_burnin=cls._N_BURNIN
|
|
319
367
|
)
|
|
320
368
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
321
369
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
@@ -323,7 +371,7 @@ class WithInputDataSamples:
|
|
|
323
371
|
posterior_params_to_tensors_media_only_no_controls = {
|
|
324
372
|
param: _convert_with_swap(
|
|
325
373
|
test_posterior_media_only_no_controls[param],
|
|
326
|
-
n_burnin=
|
|
374
|
+
n_burnin=cls._N_BURNIN,
|
|
327
375
|
)
|
|
328
376
|
for param in (
|
|
329
377
|
set(
|
|
@@ -337,22 +385,22 @@ class WithInputDataSamples:
|
|
|
337
385
|
}
|
|
338
386
|
posterior_params_to_tensors_rf_only = {
|
|
339
387
|
param: _convert_with_swap(
|
|
340
|
-
test_posterior_rf_only[param], n_burnin=
|
|
388
|
+
test_posterior_rf_only[param], n_burnin=cls._N_BURNIN
|
|
341
389
|
)
|
|
342
390
|
for param in constants.COMMON_PARAMETER_NAMES
|
|
343
391
|
+ constants.RF_PARAMETER_NAMES
|
|
344
392
|
}
|
|
345
|
-
|
|
393
|
+
cls.test_posterior_states_media_and_rf = collections.namedtuple(
|
|
346
394
|
"StructTuple",
|
|
347
395
|
constants.COMMON_PARAMETER_NAMES
|
|
348
396
|
+ constants.MEDIA_PARAMETER_NAMES
|
|
349
397
|
+ constants.RF_PARAMETER_NAMES,
|
|
350
398
|
)(**posterior_params_to_tensors_media_and_rf)
|
|
351
|
-
|
|
399
|
+
cls.test_posterior_states_media_only = collections.namedtuple(
|
|
352
400
|
"StructTuple",
|
|
353
401
|
constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
|
|
354
402
|
)(**posterior_params_to_tensors_media_only)
|
|
355
|
-
|
|
403
|
+
cls.test_posterior_states_media_only_no_controls = collections.namedtuple(
|
|
356
404
|
"StructTuple",
|
|
357
405
|
(
|
|
358
406
|
set(
|
|
@@ -364,73 +412,197 @@ class WithInputDataSamples:
|
|
|
364
412
|
)
|
|
365
413
|
),
|
|
366
414
|
)(**posterior_params_to_tensors_media_only_no_controls)
|
|
367
|
-
|
|
415
|
+
cls.test_posterior_states_rf_only = collections.namedtuple(
|
|
368
416
|
"StructTuple",
|
|
369
417
|
constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
|
|
370
418
|
)(**posterior_params_to_tensors_rf_only)
|
|
371
419
|
|
|
372
|
-
test_trace = xr.open_dataset(
|
|
373
|
-
|
|
374
|
-
param: _convert_with_swap(test_trace[param], n_burnin=
|
|
420
|
+
test_trace = xr.open_dataset(cls._TEST_SAMPLE_TRACE_PATH)
|
|
421
|
+
cls._test_trace = {
|
|
422
|
+
param: _convert_with_swap(test_trace[param], n_burnin=cls._N_BURNIN)
|
|
375
423
|
for param in test_trace.data_vars
|
|
376
424
|
}
|
|
377
425
|
|
|
378
426
|
# The following are input data samples with non-paid channels.
|
|
379
427
|
|
|
380
|
-
|
|
428
|
+
cls._national_input_data_non_media_and_organic = (
|
|
381
429
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
382
|
-
n_geos=
|
|
383
|
-
n_times=
|
|
384
|
-
n_media_times=
|
|
385
|
-
n_controls=
|
|
386
|
-
n_non_media_channels=
|
|
387
|
-
n_media_channels=
|
|
388
|
-
n_rf_channels=
|
|
389
|
-
n_organic_media_channels=
|
|
390
|
-
n_organic_rf_channels=
|
|
430
|
+
n_geos=cls._N_GEOS_NATIONAL,
|
|
431
|
+
n_times=cls._N_TIMES,
|
|
432
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
433
|
+
n_controls=cls._N_CONTROLS,
|
|
434
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
435
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
436
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
437
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
438
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
391
439
|
seed=0,
|
|
392
440
|
)
|
|
393
441
|
)
|
|
394
442
|
|
|
395
|
-
|
|
443
|
+
cls._input_data_non_media_and_organic = (
|
|
396
444
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
397
|
-
n_geos=
|
|
398
|
-
n_times=
|
|
399
|
-
n_media_times=
|
|
400
|
-
n_controls=
|
|
401
|
-
n_non_media_channels=
|
|
402
|
-
n_media_channels=
|
|
403
|
-
n_rf_channels=
|
|
404
|
-
n_organic_media_channels=
|
|
405
|
-
n_organic_rf_channels=
|
|
445
|
+
n_geos=cls._N_GEOS,
|
|
446
|
+
n_times=cls._N_TIMES,
|
|
447
|
+
n_media_times=cls._N_MEDIA_TIMES,
|
|
448
|
+
n_controls=cls._N_CONTROLS,
|
|
449
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
450
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
451
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
452
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
453
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
406
454
|
seed=0,
|
|
407
455
|
)
|
|
408
456
|
)
|
|
409
|
-
|
|
457
|
+
cls._short_input_data_non_media_and_organic = (
|
|
410
458
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
411
|
-
n_geos=
|
|
412
|
-
n_times=
|
|
413
|
-
n_media_times=
|
|
414
|
-
n_controls=
|
|
415
|
-
n_non_media_channels=
|
|
416
|
-
n_media_channels=
|
|
417
|
-
n_rf_channels=
|
|
418
|
-
n_organic_media_channels=
|
|
419
|
-
n_organic_rf_channels=
|
|
459
|
+
n_geos=cls._N_GEOS,
|
|
460
|
+
n_times=cls._N_TIMES_SHORT,
|
|
461
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
462
|
+
n_controls=cls._N_CONTROLS,
|
|
463
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
464
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
465
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
466
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
467
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
420
468
|
seed=0,
|
|
421
469
|
)
|
|
422
470
|
)
|
|
423
|
-
|
|
471
|
+
cls._short_input_data_non_media = (
|
|
424
472
|
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
425
|
-
n_geos=
|
|
426
|
-
n_times=
|
|
427
|
-
n_media_times=
|
|
428
|
-
n_controls=
|
|
429
|
-
n_non_media_channels=
|
|
430
|
-
n_media_channels=
|
|
431
|
-
n_rf_channels=
|
|
473
|
+
n_geos=cls._N_GEOS,
|
|
474
|
+
n_times=cls._N_TIMES_SHORT,
|
|
475
|
+
n_media_times=cls._N_MEDIA_TIMES_SHORT,
|
|
476
|
+
n_controls=cls._N_CONTROLS,
|
|
477
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
478
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
479
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
432
480
|
n_organic_media_channels=0,
|
|
433
481
|
n_organic_rf_channels=0,
|
|
434
482
|
seed=0,
|
|
435
483
|
)
|
|
436
484
|
)
|
|
485
|
+
cls._input_data_non_media_and_organic_same_time_dims = (
|
|
486
|
+
test_utils.sample_input_data_non_revenue_revenue_per_kpi(
|
|
487
|
+
n_geos=cls._N_GEOS,
|
|
488
|
+
n_times=cls._N_TIMES,
|
|
489
|
+
n_media_times=cls._N_TIMES,
|
|
490
|
+
n_controls=cls._N_CONTROLS,
|
|
491
|
+
n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
|
|
492
|
+
n_media_channels=cls._N_MEDIA_CHANNELS,
|
|
493
|
+
n_rf_channels=cls._N_RF_CHANNELS,
|
|
494
|
+
n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
|
|
495
|
+
n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
|
|
496
|
+
seed=0,
|
|
497
|
+
)
|
|
498
|
+
)
|
|
499
|
+
|
|
500
|
+
@property
|
|
501
|
+
def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
|
|
502
|
+
return self._input_data_non_revenue_no_revenue_per_kpi.copy(deep=True)
|
|
503
|
+
|
|
504
|
+
@property
|
|
505
|
+
def input_data_media_and_rf_non_revenue_no_revenue_per_kpi(
|
|
506
|
+
self,
|
|
507
|
+
) -> input_data.InputData:
|
|
508
|
+
return self._input_data_media_and_rf_non_revenue_no_revenue_per_kpi.copy(
|
|
509
|
+
deep=True
|
|
510
|
+
)
|
|
511
|
+
|
|
512
|
+
@property
|
|
513
|
+
def input_data_with_media_only(self) -> input_data.InputData:
|
|
514
|
+
return self._input_data_with_media_only.copy(deep=True)
|
|
515
|
+
|
|
516
|
+
@property
|
|
517
|
+
def input_data_with_rf_only(self) -> input_data.InputData:
|
|
518
|
+
return self._input_data_with_rf_only.copy(deep=True)
|
|
519
|
+
|
|
520
|
+
@property
|
|
521
|
+
def input_data_with_media_and_rf(self) -> input_data.InputData:
|
|
522
|
+
return self._input_data_with_media_and_rf.copy(deep=True)
|
|
523
|
+
|
|
524
|
+
@property
|
|
525
|
+
def input_data_with_media_and_rf_no_controls(
|
|
526
|
+
self,
|
|
527
|
+
) -> input_data.InputData:
|
|
528
|
+
return self._input_data_with_media_and_rf_no_controls.copy(deep=True)
|
|
529
|
+
|
|
530
|
+
@property
|
|
531
|
+
def short_input_data_with_media_only(self) -> input_data.InputData:
|
|
532
|
+
return self._short_input_data_with_media_only.copy(deep=True)
|
|
533
|
+
|
|
534
|
+
@property
|
|
535
|
+
def short_input_data_with_media_only_no_controls(
|
|
536
|
+
self,
|
|
537
|
+
) -> input_data.InputData:
|
|
538
|
+
return self._short_input_data_with_media_only_no_controls.copy(deep=True)
|
|
539
|
+
|
|
540
|
+
@property
|
|
541
|
+
def short_input_data_with_rf_only(self) -> input_data.InputData:
|
|
542
|
+
return self._short_input_data_with_rf_only.copy(deep=True)
|
|
543
|
+
|
|
544
|
+
@property
|
|
545
|
+
def short_input_data_with_media_and_rf(
|
|
546
|
+
self,
|
|
547
|
+
) -> input_data.InputData:
|
|
548
|
+
return self._short_input_data_with_media_and_rf.copy(deep=True)
|
|
549
|
+
|
|
550
|
+
@property
|
|
551
|
+
def national_input_data_media_only(self) -> input_data.InputData:
|
|
552
|
+
return self._national_input_data_media_only.copy(deep=True)
|
|
553
|
+
|
|
554
|
+
@property
|
|
555
|
+
def national_input_data_media_and_rf(self) -> input_data.InputData:
|
|
556
|
+
return self._national_input_data_media_and_rf.copy(deep=True)
|
|
557
|
+
|
|
558
|
+
@property
|
|
559
|
+
def test_dist_media_and_rf(
|
|
560
|
+
self,
|
|
561
|
+
) -> collections.OrderedDict[str, backend.Tensor]:
|
|
562
|
+
return collections.OrderedDict(self._test_dist_media_and_rf)
|
|
563
|
+
|
|
564
|
+
@property
|
|
565
|
+
def test_dist_media_only(
|
|
566
|
+
self,
|
|
567
|
+
) -> collections.OrderedDict[str, backend.Tensor]:
|
|
568
|
+
return collections.OrderedDict(self._test_dist_media_only)
|
|
569
|
+
|
|
570
|
+
@property
|
|
571
|
+
def test_dist_media_only_no_controls(
|
|
572
|
+
self,
|
|
573
|
+
) -> collections.OrderedDict[str, backend.Tensor]:
|
|
574
|
+
return collections.OrderedDict(self._test_dist_media_only_no_controls)
|
|
575
|
+
|
|
576
|
+
@property
|
|
577
|
+
def test_dist_rf_only(self) -> collections.OrderedDict[str, backend.Tensor]:
|
|
578
|
+
return collections.OrderedDict(self._test_dist_rf_only)
|
|
579
|
+
|
|
580
|
+
@property
|
|
581
|
+
def test_trace(self) -> dict[str, backend.Tensor]:
|
|
582
|
+
return self._test_trace.copy()
|
|
583
|
+
|
|
584
|
+
@property
|
|
585
|
+
def national_input_data_non_media_and_organic(
|
|
586
|
+
self,
|
|
587
|
+
) -> input_data.InputData:
|
|
588
|
+
return self._national_input_data_non_media_and_organic.copy(deep=True)
|
|
589
|
+
|
|
590
|
+
@property
|
|
591
|
+
def input_data_non_media_and_organic(self) -> input_data.InputData:
|
|
592
|
+
return self._input_data_non_media_and_organic.copy(deep=True)
|
|
593
|
+
|
|
594
|
+
@property
|
|
595
|
+
def short_input_data_non_media_and_organic(
|
|
596
|
+
self,
|
|
597
|
+
) -> input_data.InputData:
|
|
598
|
+
return self._short_input_data_non_media_and_organic.copy(deep=True)
|
|
599
|
+
|
|
600
|
+
@property
|
|
601
|
+
def short_input_data_non_media(self) -> input_data.InputData:
|
|
602
|
+
return self._short_input_data_non_media.copy(deep=True)
|
|
603
|
+
|
|
604
|
+
@property
|
|
605
|
+
def input_data_non_media_and_organic_same_time_dims(
|
|
606
|
+
self,
|
|
607
|
+
) -> input_data.InputData:
|
|
608
|
+
return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)
|