google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,18 +16,20 @@
16
16
 
17
17
  import collections
18
18
  import os
19
+ from typing import NamedTuple
19
20
 
21
+ from meridian import backend
20
22
  from meridian import constants
23
+ from meridian.data import input_data
21
24
  from meridian.data import test_utils
22
- import tensorflow as tf
23
25
  import xarray as xr
24
26
 
25
27
 
26
- def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
27
- """Converts a DataArray to a tf.Tensor with the correct MCMC format.
28
+ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
29
+ """Converts a DataArray to a backend.Tensor with the correct MCMC format.
28
30
 
29
- This function converts a DataArray to tf.Tensor, swaps first two dimensions
30
- and adds the burnin part. This is needed to properly mock the
31
+ This function converts a DataArray to backend.Tensor, swaps first two
32
+ dimensions and adds the burnin part. This is needed to properly mock the
31
33
  _xla_windowed_adaptive_nuts() function output in the sample_posterior
32
34
  tests.
33
35
 
@@ -39,9 +41,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
39
41
  A tensor in the same format as returned by the _xla_windowed_adaptive_nuts()
40
42
  function.
41
43
  """
42
- tensor = tf.convert_to_tensor(array)
44
+ tensor = backend.to_tensor(array)
43
45
  perm = [1, 0] + [i for i in range(2, len(tensor.shape))]
44
- transposed_tensor = tf.transpose(tensor, perm=perm)
46
+ transposed_tensor = backend.transpose(tensor, perm=perm)
45
47
 
46
48
  # Add the "burnin" part to the mocked output of _xla_windowed_adaptive_nuts
47
49
  # to make sure sample_posterior returns the correct "keep" part.
@@ -50,15 +52,21 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
50
52
  else:
51
53
  pad_value = 0.0 if array.dtype.kind == "f" else 0
52
54
 
53
- burnin = tf.fill([n_burnin] + transposed_tensor.shape[1:], pad_value)
54
- return tf.concat(
55
+ burnin = backend.fill([n_burnin] + transposed_tensor.shape[1:], pad_value)
56
+ return backend.concatenate(
55
57
  [burnin, transposed_tensor],
56
58
  axis=0,
57
59
  )
58
60
 
59
61
 
60
62
  class WithInputDataSamples:
61
- """A mixin to inject test data samples to a unit test class."""
63
+ """A mixin to inject test data samples to a unit test class.
64
+
65
+ The `setup` method is a classmethod because loading and creating the data
66
+ samples can be expensive. The properties return deep copies to ensure
67
+ immutability for the data properties. As a result, it is recommended to use
68
+ local variables to avoid re-loading the data samples multiple times.
69
+ """
62
70
 
63
71
  # TODO: Update the sample data to span over 1 or 2 year(s).
64
72
  _TEST_DIR = os.path.join(os.path.dirname(__file__), "test_data")
@@ -114,169 +122,203 @@ class WithInputDataSamples:
114
122
  _N_MEDIA_CHANNELS = 3
115
123
  _N_RF_CHANNELS = 2
116
124
  _N_CONTROLS = 2
117
- _ROI_CALIBRATION_PERIOD = tf.cast(
118
- tf.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
119
- dtype=tf.bool,
125
+ _ROI_CALIBRATION_PERIOD = backend.cast(
126
+ backend.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
127
+ dtype=backend.bool_,
120
128
  )
121
- _RF_ROI_CALIBRATION_PERIOD = tf.cast(
122
- tf.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
123
- dtype=tf.bool,
129
+ _RF_ROI_CALIBRATION_PERIOD = backend.cast(
130
+ backend.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
131
+ dtype=backend.bool_,
124
132
  )
125
133
  _N_ORGANIC_MEDIA_CHANNELS = 4
126
134
  _N_ORGANIC_RF_CHANNELS = 1
127
135
  _N_NON_MEDIA_CHANNELS = 2
128
136
 
129
- def setup(self):
137
+ # Private class variables to hold the base test data.
138
+ _input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
139
+ _input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
140
+ _input_data_with_media_only: input_data.InputData
141
+ _input_data_with_rf_only: input_data.InputData
142
+ _input_data_with_media_and_rf: input_data.InputData
143
+ _input_data_with_media_and_rf_no_controls: input_data.InputData
144
+ _short_input_data_with_media_only: input_data.InputData
145
+ _short_input_data_with_media_only_no_controls: input_data.InputData
146
+ _short_input_data_with_rf_only: input_data.InputData
147
+ _short_input_data_with_media_and_rf: input_data.InputData
148
+ _national_input_data_media_only: input_data.InputData
149
+ _national_input_data_media_and_rf: input_data.InputData
150
+ _test_dist_media_and_rf: collections.OrderedDict[str, backend.Tensor]
151
+ _test_dist_media_only: collections.OrderedDict[str, backend.Tensor]
152
+ _test_dist_media_only_no_controls: collections.OrderedDict[
153
+ str, backend.Tensor
154
+ ]
155
+ _test_dist_rf_only: collections.OrderedDict[str, backend.Tensor]
156
+ _test_trace: dict[str, backend.Tensor]
157
+ _national_input_data_non_media_and_organic: input_data.InputData
158
+ _input_data_non_media_and_organic: input_data.InputData
159
+ _short_input_data_non_media_and_organic: input_data.InputData
160
+ _short_input_data_non_media: input_data.InputData
161
+ _input_data_non_media_and_organic_same_time_dims: input_data.InputData
162
+
163
+ # The following NamedTuples and their attributes are immutable, so they can
164
+ # be accessed directly.
165
+ test_posterior_states_media_and_rf: NamedTuple
166
+ test_posterior_states_media_only: NamedTuple
167
+ test_posterior_states_media_only_no_controls: NamedTuple
168
+ test_posterior_states_rf_only: NamedTuple
169
+
170
+ @classmethod
171
+ def setup(cls):
130
172
  """Sets up input data samples."""
131
- self.input_data_non_revenue_no_revenue_per_kpi = (
173
+ cls._input_data_non_revenue_no_revenue_per_kpi = (
132
174
  test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
133
- n_geos=self._N_GEOS,
134
- n_times=self._N_TIMES,
135
- n_media_times=self._N_MEDIA_TIMES,
136
- n_controls=self._N_CONTROLS,
137
- n_media_channels=self._N_MEDIA_CHANNELS,
175
+ n_geos=cls._N_GEOS,
176
+ n_times=cls._N_TIMES,
177
+ n_media_times=cls._N_MEDIA_TIMES,
178
+ n_controls=cls._N_CONTROLS,
179
+ n_media_channels=cls._N_MEDIA_CHANNELS,
138
180
  seed=0,
139
181
  )
140
182
  )
141
- self.input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
183
+ cls._input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
142
184
  test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
143
- n_geos=self._N_GEOS,
144
- n_times=self._N_TIMES,
145
- n_media_times=self._N_MEDIA_TIMES,
146
- n_controls=self._N_CONTROLS,
147
- n_media_channels=self._N_MEDIA_CHANNELS,
148
- n_rf_channels=self._N_RF_CHANNELS,
185
+ n_geos=cls._N_GEOS,
186
+ n_times=cls._N_TIMES,
187
+ n_media_times=cls._N_MEDIA_TIMES,
188
+ n_controls=cls._N_CONTROLS,
189
+ n_media_channels=cls._N_MEDIA_CHANNELS,
190
+ n_rf_channels=cls._N_RF_CHANNELS,
149
191
  seed=0,
150
192
  )
151
193
  )
152
- self.input_data_with_media_only = (
194
+ cls._input_data_with_media_only = (
153
195
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
154
- n_geos=self._N_GEOS,
155
- n_times=self._N_TIMES,
156
- n_media_times=self._N_MEDIA_TIMES,
157
- n_controls=self._N_CONTROLS,
158
- n_media_channels=self._N_MEDIA_CHANNELS,
196
+ n_geos=cls._N_GEOS,
197
+ n_times=cls._N_TIMES,
198
+ n_media_times=cls._N_MEDIA_TIMES,
199
+ n_controls=cls._N_CONTROLS,
200
+ n_media_channels=cls._N_MEDIA_CHANNELS,
159
201
  seed=0,
160
202
  )
161
203
  )
162
- self.input_data_with_rf_only = (
204
+ cls._input_data_with_rf_only = (
163
205
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
164
- n_geos=self._N_GEOS,
165
- n_times=self._N_TIMES,
166
- n_media_times=self._N_MEDIA_TIMES,
167
- n_controls=self._N_CONTROLS,
168
- n_rf_channels=self._N_RF_CHANNELS,
206
+ n_geos=cls._N_GEOS,
207
+ n_times=cls._N_TIMES,
208
+ n_media_times=cls._N_MEDIA_TIMES,
209
+ n_controls=cls._N_CONTROLS,
210
+ n_rf_channels=cls._N_RF_CHANNELS,
169
211
  seed=0,
170
212
  )
171
213
  )
172
- self.input_data_with_media_and_rf = (
214
+ cls._input_data_with_media_and_rf = (
173
215
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
174
- n_geos=self._N_GEOS,
175
- n_times=self._N_TIMES,
176
- n_media_times=self._N_MEDIA_TIMES,
177
- n_controls=self._N_CONTROLS,
178
- n_media_channels=self._N_MEDIA_CHANNELS,
179
- n_rf_channels=self._N_RF_CHANNELS,
216
+ n_geos=cls._N_GEOS,
217
+ n_times=cls._N_TIMES,
218
+ n_media_times=cls._N_MEDIA_TIMES,
219
+ n_controls=cls._N_CONTROLS,
220
+ n_media_channels=cls._N_MEDIA_CHANNELS,
221
+ n_rf_channels=cls._N_RF_CHANNELS,
180
222
  seed=0,
181
223
  )
182
224
  )
183
- self.input_data_with_media_and_rf_no_controls = (
225
+ cls._input_data_with_media_and_rf_no_controls = (
184
226
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
185
- n_geos=self._N_GEOS,
186
- n_times=self._N_TIMES,
187
- n_media_times=self._N_MEDIA_TIMES,
227
+ n_geos=cls._N_GEOS,
228
+ n_times=cls._N_TIMES,
229
+ n_media_times=cls._N_MEDIA_TIMES,
188
230
  n_controls=None,
189
- n_media_channels=self._N_MEDIA_CHANNELS,
190
- n_rf_channels=self._N_RF_CHANNELS,
231
+ n_media_channels=cls._N_MEDIA_CHANNELS,
232
+ n_rf_channels=cls._N_RF_CHANNELS,
191
233
  seed=0,
192
234
  )
193
235
  )
194
- self.short_input_data_with_media_only = (
236
+ cls._short_input_data_with_media_only = (
195
237
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
196
- n_geos=self._N_GEOS,
197
- n_times=self._N_TIMES_SHORT,
198
- n_media_times=self._N_MEDIA_TIMES_SHORT,
199
- n_controls=self._N_CONTROLS,
200
- n_media_channels=self._N_MEDIA_CHANNELS,
238
+ n_geos=cls._N_GEOS,
239
+ n_times=cls._N_TIMES_SHORT,
240
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
241
+ n_controls=cls._N_CONTROLS,
242
+ n_media_channels=cls._N_MEDIA_CHANNELS,
201
243
  seed=0,
202
244
  )
203
245
  )
204
- self.short_input_data_with_media_only_no_controls = (
246
+ cls._short_input_data_with_media_only_no_controls = (
205
247
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
206
- n_geos=self._N_GEOS,
207
- n_times=self._N_TIMES_SHORT,
208
- n_media_times=self._N_MEDIA_TIMES_SHORT,
248
+ n_geos=cls._N_GEOS,
249
+ n_times=cls._N_TIMES_SHORT,
250
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
209
251
  n_controls=0,
210
- n_media_channels=self._N_MEDIA_CHANNELS,
252
+ n_media_channels=cls._N_MEDIA_CHANNELS,
211
253
  seed=0,
212
254
  )
213
255
  )
214
- self.short_input_data_with_rf_only = (
256
+ cls._short_input_data_with_rf_only = (
215
257
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
216
- n_geos=self._N_GEOS,
217
- n_times=self._N_TIMES_SHORT,
218
- n_media_times=self._N_MEDIA_TIMES_SHORT,
219
- n_controls=self._N_CONTROLS,
220
- n_rf_channels=self._N_RF_CHANNELS,
258
+ n_geos=cls._N_GEOS,
259
+ n_times=cls._N_TIMES_SHORT,
260
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
261
+ n_controls=cls._N_CONTROLS,
262
+ n_rf_channels=cls._N_RF_CHANNELS,
221
263
  seed=0,
222
264
  )
223
265
  )
224
- self.short_input_data_with_media_and_rf = (
266
+ cls._short_input_data_with_media_and_rf = (
225
267
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
226
- n_geos=self._N_GEOS,
227
- n_times=self._N_TIMES_SHORT,
228
- n_media_times=self._N_MEDIA_TIMES_SHORT,
229
- n_controls=self._N_CONTROLS,
230
- n_media_channels=self._N_MEDIA_CHANNELS,
231
- n_rf_channels=self._N_RF_CHANNELS,
268
+ n_geos=cls._N_GEOS,
269
+ n_times=cls._N_TIMES_SHORT,
270
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
271
+ n_controls=cls._N_CONTROLS,
272
+ n_media_channels=cls._N_MEDIA_CHANNELS,
273
+ n_rf_channels=cls._N_RF_CHANNELS,
232
274
  seed=0,
233
275
  )
234
276
  )
235
- self.national_input_data_media_only = (
277
+ cls._national_input_data_media_only = (
236
278
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
237
- n_geos=self._N_GEOS_NATIONAL,
238
- n_times=self._N_TIMES,
239
- n_media_times=self._N_MEDIA_TIMES,
240
- n_controls=self._N_CONTROLS,
241
- n_media_channels=self._N_MEDIA_CHANNELS,
279
+ n_geos=cls._N_GEOS_NATIONAL,
280
+ n_times=cls._N_TIMES,
281
+ n_media_times=cls._N_MEDIA_TIMES,
282
+ n_controls=cls._N_CONTROLS,
283
+ n_media_channels=cls._N_MEDIA_CHANNELS,
242
284
  seed=0,
243
285
  )
244
286
  )
245
- self.national_input_data_media_and_rf = (
287
+ cls._national_input_data_media_and_rf = (
246
288
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
247
- n_geos=self._N_GEOS_NATIONAL,
248
- n_times=self._N_TIMES,
249
- n_media_times=self._N_MEDIA_TIMES,
250
- n_controls=self._N_CONTROLS,
251
- n_media_channels=self._N_MEDIA_CHANNELS,
252
- n_rf_channels=self._N_RF_CHANNELS,
289
+ n_geos=cls._N_GEOS_NATIONAL,
290
+ n_times=cls._N_TIMES,
291
+ n_media_times=cls._N_MEDIA_TIMES,
292
+ n_controls=cls._N_CONTROLS,
293
+ n_media_channels=cls._N_MEDIA_CHANNELS,
294
+ n_rf_channels=cls._N_RF_CHANNELS,
253
295
  seed=0,
254
296
  )
255
297
  )
256
298
 
257
299
  test_prior_media_and_rf = xr.open_dataset(
258
- self._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH
300
+ cls._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH
259
301
  )
260
302
  test_prior_media_only = xr.open_dataset(
261
- self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
303
+ cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
262
304
  )
263
305
  test_prior_media_only_no_controls = xr.open_dataset(
264
- self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
306
+ cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
265
307
  )
266
- test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
267
- self.test_dist_media_and_rf = collections.OrderedDict({
268
- param: tf.convert_to_tensor(test_prior_media_and_rf[param])
308
+ test_prior_rf_only = xr.open_dataset(cls._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
309
+ cls._test_dist_media_and_rf = collections.OrderedDict({
310
+ param: backend.to_tensor(test_prior_media_and_rf[param])
269
311
  for param in constants.COMMON_PARAMETER_NAMES
270
312
  + constants.MEDIA_PARAMETER_NAMES
271
313
  + constants.RF_PARAMETER_NAMES
272
314
  })
273
- self.test_dist_media_only = collections.OrderedDict({
274
- param: tf.convert_to_tensor(test_prior_media_only[param])
315
+ cls._test_dist_media_only = collections.OrderedDict({
316
+ param: backend.to_tensor(test_prior_media_only[param])
275
317
  for param in constants.COMMON_PARAMETER_NAMES
276
318
  + constants.MEDIA_PARAMETER_NAMES
277
319
  })
278
- self.test_dist_media_only_no_controls = collections.OrderedDict({
279
- param: tf.convert_to_tensor(test_prior_media_only_no_controls[param])
320
+ cls._test_dist_media_only_no_controls = collections.OrderedDict({
321
+ param: backend.to_tensor(test_prior_media_only_no_controls[param])
280
322
  for param in (
281
323
  set(
282
324
  constants.COMMON_PARAMETER_NAMES
@@ -287,27 +329,27 @@ class WithInputDataSamples:
287
329
  )
288
330
  )
289
331
  })
290
- self.test_dist_rf_only = collections.OrderedDict({
291
- param: tf.convert_to_tensor(test_prior_rf_only[param])
332
+ cls._test_dist_rf_only = collections.OrderedDict({
333
+ param: backend.to_tensor(test_prior_rf_only[param])
292
334
  for param in constants.COMMON_PARAMETER_NAMES
293
335
  + constants.RF_PARAMETER_NAMES
294
336
  })
295
337
 
296
338
  test_posterior_media_and_rf = xr.open_dataset(
297
- self._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH
339
+ cls._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH
298
340
  )
299
341
  test_posterior_media_only = xr.open_dataset(
300
- self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
342
+ cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
301
343
  )
302
344
  test_posterior_media_only_no_controls = xr.open_dataset(
303
- self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
345
+ cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
304
346
  )
305
347
  test_posterior_rf_only = xr.open_dataset(
306
- self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
348
+ cls._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
307
349
  )
308
350
  posterior_params_to_tensors_media_and_rf = {
309
351
  param: _convert_with_swap(
310
- test_posterior_media_and_rf[param], n_burnin=self._N_BURNIN
352
+ test_posterior_media_and_rf[param], n_burnin=cls._N_BURNIN
311
353
  )
312
354
  for param in constants.COMMON_PARAMETER_NAMES
313
355
  + constants.MEDIA_PARAMETER_NAMES
@@ -315,7 +357,7 @@ class WithInputDataSamples:
315
357
  }
316
358
  posterior_params_to_tensors_media_only = {
317
359
  param: _convert_with_swap(
318
- test_posterior_media_only[param], n_burnin=self._N_BURNIN
360
+ test_posterior_media_only[param], n_burnin=cls._N_BURNIN
319
361
  )
320
362
  for param in constants.COMMON_PARAMETER_NAMES
321
363
  + constants.MEDIA_PARAMETER_NAMES
@@ -323,7 +365,7 @@ class WithInputDataSamples:
323
365
  posterior_params_to_tensors_media_only_no_controls = {
324
366
  param: _convert_with_swap(
325
367
  test_posterior_media_only_no_controls[param],
326
- n_burnin=self._N_BURNIN,
368
+ n_burnin=cls._N_BURNIN,
327
369
  )
328
370
  for param in (
329
371
  set(
@@ -337,22 +379,22 @@ class WithInputDataSamples:
337
379
  }
338
380
  posterior_params_to_tensors_rf_only = {
339
381
  param: _convert_with_swap(
340
- test_posterior_rf_only[param], n_burnin=self._N_BURNIN
382
+ test_posterior_rf_only[param], n_burnin=cls._N_BURNIN
341
383
  )
342
384
  for param in constants.COMMON_PARAMETER_NAMES
343
385
  + constants.RF_PARAMETER_NAMES
344
386
  }
345
- self.test_posterior_states_media_and_rf = collections.namedtuple(
387
+ cls.test_posterior_states_media_and_rf = collections.namedtuple(
346
388
  "StructTuple",
347
389
  constants.COMMON_PARAMETER_NAMES
348
390
  + constants.MEDIA_PARAMETER_NAMES
349
391
  + constants.RF_PARAMETER_NAMES,
350
392
  )(**posterior_params_to_tensors_media_and_rf)
351
- self.test_posterior_states_media_only = collections.namedtuple(
393
+ cls.test_posterior_states_media_only = collections.namedtuple(
352
394
  "StructTuple",
353
395
  constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
354
396
  )(**posterior_params_to_tensors_media_only)
355
- self.test_posterior_states_media_only_no_controls = collections.namedtuple(
397
+ cls.test_posterior_states_media_only_no_controls = collections.namedtuple(
356
398
  "StructTuple",
357
399
  (
358
400
  set(
@@ -364,73 +406,197 @@ class WithInputDataSamples:
364
406
  )
365
407
  ),
366
408
  )(**posterior_params_to_tensors_media_only_no_controls)
367
- self.test_posterior_states_rf_only = collections.namedtuple(
409
+ cls.test_posterior_states_rf_only = collections.namedtuple(
368
410
  "StructTuple",
369
411
  constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
370
412
  )(**posterior_params_to_tensors_rf_only)
371
413
 
372
- test_trace = xr.open_dataset(self._TEST_SAMPLE_TRACE_PATH)
373
- self.test_trace = {
374
- param: _convert_with_swap(test_trace[param], n_burnin=self._N_BURNIN)
414
+ test_trace = xr.open_dataset(cls._TEST_SAMPLE_TRACE_PATH)
415
+ cls._test_trace = {
416
+ param: _convert_with_swap(test_trace[param], n_burnin=cls._N_BURNIN)
375
417
  for param in test_trace.data_vars
376
418
  }
377
419
 
378
420
  # The following are input data samples with non-paid channels.
379
421
 
380
- self.national_input_data_non_media_and_organic = (
422
+ cls._national_input_data_non_media_and_organic = (
381
423
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
382
- n_geos=self._N_GEOS_NATIONAL,
383
- n_times=self._N_TIMES,
384
- n_media_times=self._N_MEDIA_TIMES,
385
- n_controls=self._N_CONTROLS,
386
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
387
- n_media_channels=self._N_MEDIA_CHANNELS,
388
- n_rf_channels=self._N_RF_CHANNELS,
389
- n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS,
390
- n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS,
424
+ n_geos=cls._N_GEOS_NATIONAL,
425
+ n_times=cls._N_TIMES,
426
+ n_media_times=cls._N_MEDIA_TIMES,
427
+ n_controls=cls._N_CONTROLS,
428
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
429
+ n_media_channels=cls._N_MEDIA_CHANNELS,
430
+ n_rf_channels=cls._N_RF_CHANNELS,
431
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
432
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
391
433
  seed=0,
392
434
  )
393
435
  )
394
436
 
395
- self.input_data_non_media_and_organic = (
437
+ cls._input_data_non_media_and_organic = (
396
438
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
397
- n_geos=self._N_GEOS,
398
- n_times=self._N_TIMES,
399
- n_media_times=self._N_MEDIA_TIMES,
400
- n_controls=self._N_CONTROLS,
401
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
402
- n_media_channels=self._N_MEDIA_CHANNELS,
403
- n_rf_channels=self._N_RF_CHANNELS,
404
- n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS,
405
- n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS,
439
+ n_geos=cls._N_GEOS,
440
+ n_times=cls._N_TIMES,
441
+ n_media_times=cls._N_MEDIA_TIMES,
442
+ n_controls=cls._N_CONTROLS,
443
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
444
+ n_media_channels=cls._N_MEDIA_CHANNELS,
445
+ n_rf_channels=cls._N_RF_CHANNELS,
446
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
447
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
406
448
  seed=0,
407
449
  )
408
450
  )
409
- self.short_input_data_non_media_and_organic = (
451
+ cls._short_input_data_non_media_and_organic = (
410
452
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
411
- n_geos=self._N_GEOS,
412
- n_times=self._N_TIMES_SHORT,
413
- n_media_times=self._N_MEDIA_TIMES_SHORT,
414
- n_controls=self._N_CONTROLS,
415
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
416
- n_media_channels=self._N_MEDIA_CHANNELS,
417
- n_rf_channels=self._N_RF_CHANNELS,
418
- n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS,
419
- n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS,
453
+ n_geos=cls._N_GEOS,
454
+ n_times=cls._N_TIMES_SHORT,
455
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
456
+ n_controls=cls._N_CONTROLS,
457
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
458
+ n_media_channels=cls._N_MEDIA_CHANNELS,
459
+ n_rf_channels=cls._N_RF_CHANNELS,
460
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
461
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
420
462
  seed=0,
421
463
  )
422
464
  )
423
- self.short_input_data_non_media = (
465
+ cls._short_input_data_non_media = (
424
466
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
425
- n_geos=self._N_GEOS,
426
- n_times=self._N_TIMES_SHORT,
427
- n_media_times=self._N_MEDIA_TIMES_SHORT,
428
- n_controls=self._N_CONTROLS,
429
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
430
- n_media_channels=self._N_MEDIA_CHANNELS,
431
- n_rf_channels=self._N_RF_CHANNELS,
467
+ n_geos=cls._N_GEOS,
468
+ n_times=cls._N_TIMES_SHORT,
469
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
470
+ n_controls=cls._N_CONTROLS,
471
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
472
+ n_media_channels=cls._N_MEDIA_CHANNELS,
473
+ n_rf_channels=cls._N_RF_CHANNELS,
432
474
  n_organic_media_channels=0,
433
475
  n_organic_rf_channels=0,
434
476
  seed=0,
435
477
  )
436
478
  )
479
+ cls._input_data_non_media_and_organic_same_time_dims = (
480
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
481
+ n_geos=cls._N_GEOS,
482
+ n_times=cls._N_TIMES,
483
+ n_media_times=cls._N_TIMES,
484
+ n_controls=cls._N_CONTROLS,
485
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
486
+ n_media_channels=cls._N_MEDIA_CHANNELS,
487
+ n_rf_channels=cls._N_RF_CHANNELS,
488
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
489
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
490
+ seed=0,
491
+ )
492
+ )
493
+
494
+ @property
495
+ def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
496
+ return self._input_data_non_revenue_no_revenue_per_kpi.copy(deep=True)
497
+
498
+ @property
499
+ def input_data_media_and_rf_non_revenue_no_revenue_per_kpi(
500
+ self,
501
+ ) -> input_data.InputData:
502
+ return self._input_data_media_and_rf_non_revenue_no_revenue_per_kpi.copy(
503
+ deep=True
504
+ )
505
+
506
+ @property
507
+ def input_data_with_media_only(self) -> input_data.InputData:
508
+ return self._input_data_with_media_only.copy(deep=True)
509
+
510
+ @property
511
+ def input_data_with_rf_only(self) -> input_data.InputData:
512
+ return self._input_data_with_rf_only.copy(deep=True)
513
+
514
+ @property
515
+ def input_data_with_media_and_rf(self) -> input_data.InputData:
516
+ return self._input_data_with_media_and_rf.copy(deep=True)
517
+
518
+ @property
519
+ def input_data_with_media_and_rf_no_controls(
520
+ self,
521
+ ) -> input_data.InputData:
522
+ return self._input_data_with_media_and_rf_no_controls.copy(deep=True)
523
+
524
+ @property
525
+ def short_input_data_with_media_only(self) -> input_data.InputData:
526
+ return self._short_input_data_with_media_only.copy(deep=True)
527
+
528
+ @property
529
+ def short_input_data_with_media_only_no_controls(
530
+ self,
531
+ ) -> input_data.InputData:
532
+ return self._short_input_data_with_media_only_no_controls.copy(deep=True)
533
+
534
+ @property
535
+ def short_input_data_with_rf_only(self) -> input_data.InputData:
536
+ return self._short_input_data_with_rf_only.copy(deep=True)
537
+
538
+ @property
539
+ def short_input_data_with_media_and_rf(
540
+ self,
541
+ ) -> input_data.InputData:
542
+ return self._short_input_data_with_media_and_rf.copy(deep=True)
543
+
544
+ @property
545
+ def national_input_data_media_only(self) -> input_data.InputData:
546
+ return self._national_input_data_media_only.copy(deep=True)
547
+
548
+ @property
549
+ def national_input_data_media_and_rf(self) -> input_data.InputData:
550
+ return self._national_input_data_media_and_rf.copy(deep=True)
551
+
552
+ @property
553
+ def test_dist_media_and_rf(
554
+ self,
555
+ ) -> collections.OrderedDict[str, backend.Tensor]:
556
+ return collections.OrderedDict(self._test_dist_media_and_rf)
557
+
558
+ @property
559
+ def test_dist_media_only(
560
+ self,
561
+ ) -> collections.OrderedDict[str, backend.Tensor]:
562
+ return collections.OrderedDict(self._test_dist_media_only)
563
+
564
+ @property
565
+ def test_dist_media_only_no_controls(
566
+ self,
567
+ ) -> collections.OrderedDict[str, backend.Tensor]:
568
+ return collections.OrderedDict(self._test_dist_media_only_no_controls)
569
+
570
+ @property
571
+ def test_dist_rf_only(self) -> collections.OrderedDict[str, backend.Tensor]:
572
+ return collections.OrderedDict(self._test_dist_rf_only)
573
+
574
+ @property
575
+ def test_trace(self) -> dict[str, backend.Tensor]:
576
+ return self._test_trace.copy()
577
+
578
+ @property
579
+ def national_input_data_non_media_and_organic(
580
+ self,
581
+ ) -> input_data.InputData:
582
+ return self._national_input_data_non_media_and_organic.copy(deep=True)
583
+
584
+ @property
585
+ def input_data_non_media_and_organic(self) -> input_data.InputData:
586
+ return self._input_data_non_media_and_organic.copy(deep=True)
587
+
588
+ @property
589
+ def short_input_data_non_media_and_organic(
590
+ self,
591
+ ) -> input_data.InputData:
592
+ return self._short_input_data_non_media_and_organic.copy(deep=True)
593
+
594
+ @property
595
+ def short_input_data_non_media(self) -> input_data.InputData:
596
+ return self._short_input_data_non_media.copy(deep=True)
597
+
598
+ @property
599
+ def input_data_non_media_and_organic_same_time_dims(
600
+ self,
601
+ ) -> input_data.InputData:
602
+ return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)