google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,18 +16,20 @@
16
16
 
17
17
  import collections
18
18
  import os
19
+ from typing import NamedTuple
19
20
 
21
+ from meridian import backend
20
22
  from meridian import constants
23
+ from meridian.data import input_data
21
24
  from meridian.data import test_utils
22
- import tensorflow as tf
23
25
  import xarray as xr
24
26
 
25
27
 
26
- def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
27
- """Converts a DataArray to a tf.Tensor with the correct MCMC format.
28
+ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
29
+ """Converts a DataArray to a backend.Tensor with the correct MCMC format.
28
30
 
29
- This function converts a DataArray to tf.Tensor, swaps first two dimensions
30
- and adds the burnin part. This is needed to properly mock the
31
+ This function converts a DataArray to backend.Tensor, swaps first two
32
+ dimensions and adds the burnin part. This is needed to properly mock the
31
33
  _xla_windowed_adaptive_nuts() function output in the sample_posterior
32
34
  tests.
33
35
 
@@ -39,9 +41,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
39
41
  A tensor in the same format as returned by the _xla_windowed_adaptive_nuts()
40
42
  function.
41
43
  """
42
- tensor = tf.convert_to_tensor(array)
44
+ tensor = backend.to_tensor(array)
43
45
  perm = [1, 0] + [i for i in range(2, len(tensor.shape))]
44
- transposed_tensor = tf.transpose(tensor, perm=perm)
46
+ transposed_tensor = backend.transpose(tensor, perm=perm)
45
47
 
46
48
  # Add the "burnin" part to the mocked output of _xla_windowed_adaptive_nuts
47
49
  # to make sure sample_posterior returns the correct "keep" part.
@@ -50,15 +52,23 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor:
50
52
  else:
51
53
  pad_value = 0.0 if array.dtype.kind == "f" else 0
52
54
 
53
- burnin = tf.fill([n_burnin] + transposed_tensor.shape[1:], pad_value)
54
- return tf.concat(
55
+ burnin = backend.fill(
56
+ [n_burnin] + list(transposed_tensor.shape[1:]), pad_value
57
+ )
58
+ return backend.concatenate(
55
59
  [burnin, transposed_tensor],
56
60
  axis=0,
57
61
  )
58
62
 
59
63
 
60
64
  class WithInputDataSamples:
61
- """A mixin to inject test data samples to a unit test class."""
65
+ """A mixin to inject test data samples to a unit test class.
66
+
67
+ The `setup` method is a classmethod because loading and creating the data
68
+ samples can be expensive. The properties return deep copies to ensure
69
+ immutability for the data properties. As a result, it is recommended to use
70
+ local variables to avoid re-loading the data samples multiple times.
71
+ """
62
72
 
63
73
  # TODO: Update the sample data to span over 1 or 2 year(s).
64
74
  _TEST_DIR = os.path.join(os.path.dirname(__file__), "test_data")
@@ -114,169 +124,207 @@ class WithInputDataSamples:
114
124
  _N_MEDIA_CHANNELS = 3
115
125
  _N_RF_CHANNELS = 2
116
126
  _N_CONTROLS = 2
117
- _ROI_CALIBRATION_PERIOD = tf.cast(
118
- tf.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
119
- dtype=tf.bool,
120
- )
121
- _RF_ROI_CALIBRATION_PERIOD = tf.cast(
122
- tf.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
123
- dtype=tf.bool,
124
- )
125
127
  _N_ORGANIC_MEDIA_CHANNELS = 4
126
128
  _N_ORGANIC_RF_CHANNELS = 1
127
129
  _N_NON_MEDIA_CHANNELS = 2
128
130
 
129
- def setup(self):
131
+ _ROI_CALIBRATION_PERIOD: backend.Tensor
132
+ _RF_ROI_CALIBRATION_PERIOD: backend.Tensor
133
+
134
+ # Private class variables to hold the base test data.
135
+ _input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
136
+ _input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
137
+ _input_data_with_media_only: input_data.InputData
138
+ _input_data_with_rf_only: input_data.InputData
139
+ _input_data_with_media_and_rf: input_data.InputData
140
+ _input_data_with_media_and_rf_no_controls: input_data.InputData
141
+ _short_input_data_with_media_only: input_data.InputData
142
+ _short_input_data_with_media_only_no_controls: input_data.InputData
143
+ _short_input_data_with_rf_only: input_data.InputData
144
+ _short_input_data_with_media_and_rf: input_data.InputData
145
+ _national_input_data_media_only: input_data.InputData
146
+ _national_input_data_media_and_rf: input_data.InputData
147
+ _test_dist_media_and_rf: collections.OrderedDict[str, backend.Tensor]
148
+ _test_dist_media_only: collections.OrderedDict[str, backend.Tensor]
149
+ _test_dist_media_only_no_controls: collections.OrderedDict[
150
+ str, backend.Tensor
151
+ ]
152
+ _test_dist_rf_only: collections.OrderedDict[str, backend.Tensor]
153
+ _test_trace: dict[str, backend.Tensor]
154
+ _national_input_data_non_media_and_organic: input_data.InputData
155
+ _input_data_non_media_and_organic: input_data.InputData
156
+ _short_input_data_non_media_and_organic: input_data.InputData
157
+ _short_input_data_non_media: input_data.InputData
158
+ _input_data_non_media_and_organic_same_time_dims: input_data.InputData
159
+
160
+ # The following NamedTuples and their attributes are immutable, so they can
161
+ # be accessed directly.
162
+ test_posterior_states_media_and_rf: NamedTuple
163
+ test_posterior_states_media_only: NamedTuple
164
+ test_posterior_states_media_only_no_controls: NamedTuple
165
+ test_posterior_states_rf_only: NamedTuple
166
+
167
+ @classmethod
168
+ def setup(cls):
130
169
  """Sets up input data samples."""
131
- self.input_data_non_revenue_no_revenue_per_kpi = (
170
+ cls._ROI_CALIBRATION_PERIOD = backend.cast(
171
+ backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_MEDIA_CHANNELS)),
172
+ dtype=backend.bool_,
173
+ )
174
+ cls._RF_ROI_CALIBRATION_PERIOD = backend.cast(
175
+ backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_RF_CHANNELS)),
176
+ dtype=backend.bool_,
177
+ )
178
+
179
+ cls._input_data_non_revenue_no_revenue_per_kpi = (
132
180
  test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
133
- n_geos=self._N_GEOS,
134
- n_times=self._N_TIMES,
135
- n_media_times=self._N_MEDIA_TIMES,
136
- n_controls=self._N_CONTROLS,
137
- n_media_channels=self._N_MEDIA_CHANNELS,
181
+ n_geos=cls._N_GEOS,
182
+ n_times=cls._N_TIMES,
183
+ n_media_times=cls._N_MEDIA_TIMES,
184
+ n_controls=cls._N_CONTROLS,
185
+ n_media_channels=cls._N_MEDIA_CHANNELS,
138
186
  seed=0,
139
187
  )
140
188
  )
141
- self.input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
189
+ cls._input_data_media_and_rf_non_revenue_no_revenue_per_kpi = (
142
190
  test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
143
- n_geos=self._N_GEOS,
144
- n_times=self._N_TIMES,
145
- n_media_times=self._N_MEDIA_TIMES,
146
- n_controls=self._N_CONTROLS,
147
- n_media_channels=self._N_MEDIA_CHANNELS,
148
- n_rf_channels=self._N_RF_CHANNELS,
191
+ n_geos=cls._N_GEOS,
192
+ n_times=cls._N_TIMES,
193
+ n_media_times=cls._N_MEDIA_TIMES,
194
+ n_controls=cls._N_CONTROLS,
195
+ n_media_channels=cls._N_MEDIA_CHANNELS,
196
+ n_rf_channels=cls._N_RF_CHANNELS,
149
197
  seed=0,
150
198
  )
151
199
  )
152
- self.input_data_with_media_only = (
200
+ cls._input_data_with_media_only = (
153
201
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
154
- n_geos=self._N_GEOS,
155
- n_times=self._N_TIMES,
156
- n_media_times=self._N_MEDIA_TIMES,
157
- n_controls=self._N_CONTROLS,
158
- n_media_channels=self._N_MEDIA_CHANNELS,
202
+ n_geos=cls._N_GEOS,
203
+ n_times=cls._N_TIMES,
204
+ n_media_times=cls._N_MEDIA_TIMES,
205
+ n_controls=cls._N_CONTROLS,
206
+ n_media_channels=cls._N_MEDIA_CHANNELS,
159
207
  seed=0,
160
208
  )
161
209
  )
162
- self.input_data_with_rf_only = (
210
+ cls._input_data_with_rf_only = (
163
211
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
164
- n_geos=self._N_GEOS,
165
- n_times=self._N_TIMES,
166
- n_media_times=self._N_MEDIA_TIMES,
167
- n_controls=self._N_CONTROLS,
168
- n_rf_channels=self._N_RF_CHANNELS,
212
+ n_geos=cls._N_GEOS,
213
+ n_times=cls._N_TIMES,
214
+ n_media_times=cls._N_MEDIA_TIMES,
215
+ n_controls=cls._N_CONTROLS,
216
+ n_rf_channels=cls._N_RF_CHANNELS,
169
217
  seed=0,
170
218
  )
171
219
  )
172
- self.input_data_with_media_and_rf = (
220
+ cls._input_data_with_media_and_rf = (
173
221
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
174
- n_geos=self._N_GEOS,
175
- n_times=self._N_TIMES,
176
- n_media_times=self._N_MEDIA_TIMES,
177
- n_controls=self._N_CONTROLS,
178
- n_media_channels=self._N_MEDIA_CHANNELS,
179
- n_rf_channels=self._N_RF_CHANNELS,
222
+ n_geos=cls._N_GEOS,
223
+ n_times=cls._N_TIMES,
224
+ n_media_times=cls._N_MEDIA_TIMES,
225
+ n_controls=cls._N_CONTROLS,
226
+ n_media_channels=cls._N_MEDIA_CHANNELS,
227
+ n_rf_channels=cls._N_RF_CHANNELS,
180
228
  seed=0,
181
229
  )
182
230
  )
183
- self.input_data_with_media_and_rf_no_controls = (
231
+ cls._input_data_with_media_and_rf_no_controls = (
184
232
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
185
- n_geos=self._N_GEOS,
186
- n_times=self._N_TIMES,
187
- n_media_times=self._N_MEDIA_TIMES,
233
+ n_geos=cls._N_GEOS,
234
+ n_times=cls._N_TIMES,
235
+ n_media_times=cls._N_MEDIA_TIMES,
188
236
  n_controls=None,
189
- n_media_channels=self._N_MEDIA_CHANNELS,
190
- n_rf_channels=self._N_RF_CHANNELS,
237
+ n_media_channels=cls._N_MEDIA_CHANNELS,
238
+ n_rf_channels=cls._N_RF_CHANNELS,
191
239
  seed=0,
192
240
  )
193
241
  )
194
- self.short_input_data_with_media_only = (
242
+ cls._short_input_data_with_media_only = (
195
243
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
196
- n_geos=self._N_GEOS,
197
- n_times=self._N_TIMES_SHORT,
198
- n_media_times=self._N_MEDIA_TIMES_SHORT,
199
- n_controls=self._N_CONTROLS,
200
- n_media_channels=self._N_MEDIA_CHANNELS,
244
+ n_geos=cls._N_GEOS,
245
+ n_times=cls._N_TIMES_SHORT,
246
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
247
+ n_controls=cls._N_CONTROLS,
248
+ n_media_channels=cls._N_MEDIA_CHANNELS,
201
249
  seed=0,
202
250
  )
203
251
  )
204
- self.short_input_data_with_media_only_no_controls = (
252
+ cls._short_input_data_with_media_only_no_controls = (
205
253
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
206
- n_geos=self._N_GEOS,
207
- n_times=self._N_TIMES_SHORT,
208
- n_media_times=self._N_MEDIA_TIMES_SHORT,
254
+ n_geos=cls._N_GEOS,
255
+ n_times=cls._N_TIMES_SHORT,
256
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
209
257
  n_controls=0,
210
- n_media_channels=self._N_MEDIA_CHANNELS,
258
+ n_media_channels=cls._N_MEDIA_CHANNELS,
211
259
  seed=0,
212
260
  )
213
261
  )
214
- self.short_input_data_with_rf_only = (
262
+ cls._short_input_data_with_rf_only = (
215
263
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
216
- n_geos=self._N_GEOS,
217
- n_times=self._N_TIMES_SHORT,
218
- n_media_times=self._N_MEDIA_TIMES_SHORT,
219
- n_controls=self._N_CONTROLS,
220
- n_rf_channels=self._N_RF_CHANNELS,
264
+ n_geos=cls._N_GEOS,
265
+ n_times=cls._N_TIMES_SHORT,
266
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
267
+ n_controls=cls._N_CONTROLS,
268
+ n_rf_channels=cls._N_RF_CHANNELS,
221
269
  seed=0,
222
270
  )
223
271
  )
224
- self.short_input_data_with_media_and_rf = (
272
+ cls._short_input_data_with_media_and_rf = (
225
273
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
226
- n_geos=self._N_GEOS,
227
- n_times=self._N_TIMES_SHORT,
228
- n_media_times=self._N_MEDIA_TIMES_SHORT,
229
- n_controls=self._N_CONTROLS,
230
- n_media_channels=self._N_MEDIA_CHANNELS,
231
- n_rf_channels=self._N_RF_CHANNELS,
274
+ n_geos=cls._N_GEOS,
275
+ n_times=cls._N_TIMES_SHORT,
276
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
277
+ n_controls=cls._N_CONTROLS,
278
+ n_media_channels=cls._N_MEDIA_CHANNELS,
279
+ n_rf_channels=cls._N_RF_CHANNELS,
232
280
  seed=0,
233
281
  )
234
282
  )
235
- self.national_input_data_media_only = (
283
+ cls._national_input_data_media_only = (
236
284
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
237
- n_geos=self._N_GEOS_NATIONAL,
238
- n_times=self._N_TIMES,
239
- n_media_times=self._N_MEDIA_TIMES,
240
- n_controls=self._N_CONTROLS,
241
- n_media_channels=self._N_MEDIA_CHANNELS,
285
+ n_geos=cls._N_GEOS_NATIONAL,
286
+ n_times=cls._N_TIMES,
287
+ n_media_times=cls._N_MEDIA_TIMES,
288
+ n_controls=cls._N_CONTROLS,
289
+ n_media_channels=cls._N_MEDIA_CHANNELS,
242
290
  seed=0,
243
291
  )
244
292
  )
245
- self.national_input_data_media_and_rf = (
293
+ cls._national_input_data_media_and_rf = (
246
294
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
247
- n_geos=self._N_GEOS_NATIONAL,
248
- n_times=self._N_TIMES,
249
- n_media_times=self._N_MEDIA_TIMES,
250
- n_controls=self._N_CONTROLS,
251
- n_media_channels=self._N_MEDIA_CHANNELS,
252
- n_rf_channels=self._N_RF_CHANNELS,
295
+ n_geos=cls._N_GEOS_NATIONAL,
296
+ n_times=cls._N_TIMES,
297
+ n_media_times=cls._N_MEDIA_TIMES,
298
+ n_controls=cls._N_CONTROLS,
299
+ n_media_channels=cls._N_MEDIA_CHANNELS,
300
+ n_rf_channels=cls._N_RF_CHANNELS,
253
301
  seed=0,
254
302
  )
255
303
  )
256
304
 
257
305
  test_prior_media_and_rf = xr.open_dataset(
258
- self._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH
306
+ cls._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH
259
307
  )
260
308
  test_prior_media_only = xr.open_dataset(
261
- self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
309
+ cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH
262
310
  )
263
311
  test_prior_media_only_no_controls = xr.open_dataset(
264
- self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
312
+ cls._TEST_SAMPLE_PRIOR_MEDIA_ONLY_NO_CONTROLS_PATH
265
313
  )
266
- test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
267
- self.test_dist_media_and_rf = collections.OrderedDict({
268
- param: tf.convert_to_tensor(test_prior_media_and_rf[param])
314
+ test_prior_rf_only = xr.open_dataset(cls._TEST_SAMPLE_PRIOR_RF_ONLY_PATH)
315
+ cls._test_dist_media_and_rf = collections.OrderedDict({
316
+ param: backend.to_tensor(test_prior_media_and_rf[param])
269
317
  for param in constants.COMMON_PARAMETER_NAMES
270
318
  + constants.MEDIA_PARAMETER_NAMES
271
319
  + constants.RF_PARAMETER_NAMES
272
320
  })
273
- self.test_dist_media_only = collections.OrderedDict({
274
- param: tf.convert_to_tensor(test_prior_media_only[param])
321
+ cls._test_dist_media_only = collections.OrderedDict({
322
+ param: backend.to_tensor(test_prior_media_only[param])
275
323
  for param in constants.COMMON_PARAMETER_NAMES
276
324
  + constants.MEDIA_PARAMETER_NAMES
277
325
  })
278
- self.test_dist_media_only_no_controls = collections.OrderedDict({
279
- param: tf.convert_to_tensor(test_prior_media_only_no_controls[param])
326
+ cls._test_dist_media_only_no_controls = collections.OrderedDict({
327
+ param: backend.to_tensor(test_prior_media_only_no_controls[param])
280
328
  for param in (
281
329
  set(
282
330
  constants.COMMON_PARAMETER_NAMES
@@ -287,27 +335,27 @@ class WithInputDataSamples:
287
335
  )
288
336
  )
289
337
  })
290
- self.test_dist_rf_only = collections.OrderedDict({
291
- param: tf.convert_to_tensor(test_prior_rf_only[param])
338
+ cls._test_dist_rf_only = collections.OrderedDict({
339
+ param: backend.to_tensor(test_prior_rf_only[param])
292
340
  for param in constants.COMMON_PARAMETER_NAMES
293
341
  + constants.RF_PARAMETER_NAMES
294
342
  })
295
343
 
296
344
  test_posterior_media_and_rf = xr.open_dataset(
297
- self._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH
345
+ cls._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH
298
346
  )
299
347
  test_posterior_media_only = xr.open_dataset(
300
- self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
348
+ cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH
301
349
  )
302
350
  test_posterior_media_only_no_controls = xr.open_dataset(
303
- self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
351
+ cls._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_NO_CONTROLS_PATH
304
352
  )
305
353
  test_posterior_rf_only = xr.open_dataset(
306
- self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
354
+ cls._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH
307
355
  )
308
356
  posterior_params_to_tensors_media_and_rf = {
309
357
  param: _convert_with_swap(
310
- test_posterior_media_and_rf[param], n_burnin=self._N_BURNIN
358
+ test_posterior_media_and_rf[param], n_burnin=cls._N_BURNIN
311
359
  )
312
360
  for param in constants.COMMON_PARAMETER_NAMES
313
361
  + constants.MEDIA_PARAMETER_NAMES
@@ -315,7 +363,7 @@ class WithInputDataSamples:
315
363
  }
316
364
  posterior_params_to_tensors_media_only = {
317
365
  param: _convert_with_swap(
318
- test_posterior_media_only[param], n_burnin=self._N_BURNIN
366
+ test_posterior_media_only[param], n_burnin=cls._N_BURNIN
319
367
  )
320
368
  for param in constants.COMMON_PARAMETER_NAMES
321
369
  + constants.MEDIA_PARAMETER_NAMES
@@ -323,7 +371,7 @@ class WithInputDataSamples:
323
371
  posterior_params_to_tensors_media_only_no_controls = {
324
372
  param: _convert_with_swap(
325
373
  test_posterior_media_only_no_controls[param],
326
- n_burnin=self._N_BURNIN,
374
+ n_burnin=cls._N_BURNIN,
327
375
  )
328
376
  for param in (
329
377
  set(
@@ -337,22 +385,22 @@ class WithInputDataSamples:
337
385
  }
338
386
  posterior_params_to_tensors_rf_only = {
339
387
  param: _convert_with_swap(
340
- test_posterior_rf_only[param], n_burnin=self._N_BURNIN
388
+ test_posterior_rf_only[param], n_burnin=cls._N_BURNIN
341
389
  )
342
390
  for param in constants.COMMON_PARAMETER_NAMES
343
391
  + constants.RF_PARAMETER_NAMES
344
392
  }
345
- self.test_posterior_states_media_and_rf = collections.namedtuple(
393
+ cls.test_posterior_states_media_and_rf = collections.namedtuple(
346
394
  "StructTuple",
347
395
  constants.COMMON_PARAMETER_NAMES
348
396
  + constants.MEDIA_PARAMETER_NAMES
349
397
  + constants.RF_PARAMETER_NAMES,
350
398
  )(**posterior_params_to_tensors_media_and_rf)
351
- self.test_posterior_states_media_only = collections.namedtuple(
399
+ cls.test_posterior_states_media_only = collections.namedtuple(
352
400
  "StructTuple",
353
401
  constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES,
354
402
  )(**posterior_params_to_tensors_media_only)
355
- self.test_posterior_states_media_only_no_controls = collections.namedtuple(
403
+ cls.test_posterior_states_media_only_no_controls = collections.namedtuple(
356
404
  "StructTuple",
357
405
  (
358
406
  set(
@@ -364,73 +412,197 @@ class WithInputDataSamples:
364
412
  )
365
413
  ),
366
414
  )(**posterior_params_to_tensors_media_only_no_controls)
367
- self.test_posterior_states_rf_only = collections.namedtuple(
415
+ cls.test_posterior_states_rf_only = collections.namedtuple(
368
416
  "StructTuple",
369
417
  constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES,
370
418
  )(**posterior_params_to_tensors_rf_only)
371
419
 
372
- test_trace = xr.open_dataset(self._TEST_SAMPLE_TRACE_PATH)
373
- self.test_trace = {
374
- param: _convert_with_swap(test_trace[param], n_burnin=self._N_BURNIN)
420
+ test_trace = xr.open_dataset(cls._TEST_SAMPLE_TRACE_PATH)
421
+ cls._test_trace = {
422
+ param: _convert_with_swap(test_trace[param], n_burnin=cls._N_BURNIN)
375
423
  for param in test_trace.data_vars
376
424
  }
377
425
 
378
426
  # The following are input data samples with non-paid channels.
379
427
 
380
- self.national_input_data_non_media_and_organic = (
428
+ cls._national_input_data_non_media_and_organic = (
381
429
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
382
- n_geos=self._N_GEOS_NATIONAL,
383
- n_times=self._N_TIMES,
384
- n_media_times=self._N_MEDIA_TIMES,
385
- n_controls=self._N_CONTROLS,
386
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
387
- n_media_channels=self._N_MEDIA_CHANNELS,
388
- n_rf_channels=self._N_RF_CHANNELS,
389
- n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS,
390
- n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS,
430
+ n_geos=cls._N_GEOS_NATIONAL,
431
+ n_times=cls._N_TIMES,
432
+ n_media_times=cls._N_MEDIA_TIMES,
433
+ n_controls=cls._N_CONTROLS,
434
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
435
+ n_media_channels=cls._N_MEDIA_CHANNELS,
436
+ n_rf_channels=cls._N_RF_CHANNELS,
437
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
438
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
391
439
  seed=0,
392
440
  )
393
441
  )
394
442
 
395
- self.input_data_non_media_and_organic = (
443
+ cls._input_data_non_media_and_organic = (
396
444
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
397
- n_geos=self._N_GEOS,
398
- n_times=self._N_TIMES,
399
- n_media_times=self._N_MEDIA_TIMES,
400
- n_controls=self._N_CONTROLS,
401
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
402
- n_media_channels=self._N_MEDIA_CHANNELS,
403
- n_rf_channels=self._N_RF_CHANNELS,
404
- n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS,
405
- n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS,
445
+ n_geos=cls._N_GEOS,
446
+ n_times=cls._N_TIMES,
447
+ n_media_times=cls._N_MEDIA_TIMES,
448
+ n_controls=cls._N_CONTROLS,
449
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
450
+ n_media_channels=cls._N_MEDIA_CHANNELS,
451
+ n_rf_channels=cls._N_RF_CHANNELS,
452
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
453
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
406
454
  seed=0,
407
455
  )
408
456
  )
409
- self.short_input_data_non_media_and_organic = (
457
+ cls._short_input_data_non_media_and_organic = (
410
458
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
411
- n_geos=self._N_GEOS,
412
- n_times=self._N_TIMES_SHORT,
413
- n_media_times=self._N_MEDIA_TIMES_SHORT,
414
- n_controls=self._N_CONTROLS,
415
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
416
- n_media_channels=self._N_MEDIA_CHANNELS,
417
- n_rf_channels=self._N_RF_CHANNELS,
418
- n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS,
419
- n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS,
459
+ n_geos=cls._N_GEOS,
460
+ n_times=cls._N_TIMES_SHORT,
461
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
462
+ n_controls=cls._N_CONTROLS,
463
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
464
+ n_media_channels=cls._N_MEDIA_CHANNELS,
465
+ n_rf_channels=cls._N_RF_CHANNELS,
466
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
467
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
420
468
  seed=0,
421
469
  )
422
470
  )
423
- self.short_input_data_non_media = (
471
+ cls._short_input_data_non_media = (
424
472
  test_utils.sample_input_data_non_revenue_revenue_per_kpi(
425
- n_geos=self._N_GEOS,
426
- n_times=self._N_TIMES_SHORT,
427
- n_media_times=self._N_MEDIA_TIMES_SHORT,
428
- n_controls=self._N_CONTROLS,
429
- n_non_media_channels=self._N_NON_MEDIA_CHANNELS,
430
- n_media_channels=self._N_MEDIA_CHANNELS,
431
- n_rf_channels=self._N_RF_CHANNELS,
473
+ n_geos=cls._N_GEOS,
474
+ n_times=cls._N_TIMES_SHORT,
475
+ n_media_times=cls._N_MEDIA_TIMES_SHORT,
476
+ n_controls=cls._N_CONTROLS,
477
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
478
+ n_media_channels=cls._N_MEDIA_CHANNELS,
479
+ n_rf_channels=cls._N_RF_CHANNELS,
432
480
  n_organic_media_channels=0,
433
481
  n_organic_rf_channels=0,
434
482
  seed=0,
435
483
  )
436
484
  )
485
+ cls._input_data_non_media_and_organic_same_time_dims = (
486
+ test_utils.sample_input_data_non_revenue_revenue_per_kpi(
487
+ n_geos=cls._N_GEOS,
488
+ n_times=cls._N_TIMES,
489
+ n_media_times=cls._N_TIMES,
490
+ n_controls=cls._N_CONTROLS,
491
+ n_non_media_channels=cls._N_NON_MEDIA_CHANNELS,
492
+ n_media_channels=cls._N_MEDIA_CHANNELS,
493
+ n_rf_channels=cls._N_RF_CHANNELS,
494
+ n_organic_media_channels=cls._N_ORGANIC_MEDIA_CHANNELS,
495
+ n_organic_rf_channels=cls._N_ORGANIC_RF_CHANNELS,
496
+ seed=0,
497
+ )
498
+ )
499
+
500
+ @property
501
+ def input_data_non_revenue_no_revenue_per_kpi(self) -> input_data.InputData:
502
+ return self._input_data_non_revenue_no_revenue_per_kpi.copy(deep=True)
503
+
504
+ @property
505
+ def input_data_media_and_rf_non_revenue_no_revenue_per_kpi(
506
+ self,
507
+ ) -> input_data.InputData:
508
+ return self._input_data_media_and_rf_non_revenue_no_revenue_per_kpi.copy(
509
+ deep=True
510
+ )
511
+
512
+ @property
513
+ def input_data_with_media_only(self) -> input_data.InputData:
514
+ return self._input_data_with_media_only.copy(deep=True)
515
+
516
+ @property
517
+ def input_data_with_rf_only(self) -> input_data.InputData:
518
+ return self._input_data_with_rf_only.copy(deep=True)
519
+
520
+ @property
521
+ def input_data_with_media_and_rf(self) -> input_data.InputData:
522
+ return self._input_data_with_media_and_rf.copy(deep=True)
523
+
524
+ @property
525
+ def input_data_with_media_and_rf_no_controls(
526
+ self,
527
+ ) -> input_data.InputData:
528
+ return self._input_data_with_media_and_rf_no_controls.copy(deep=True)
529
+
530
+ @property
531
+ def short_input_data_with_media_only(self) -> input_data.InputData:
532
+ return self._short_input_data_with_media_only.copy(deep=True)
533
+
534
+ @property
535
+ def short_input_data_with_media_only_no_controls(
536
+ self,
537
+ ) -> input_data.InputData:
538
+ return self._short_input_data_with_media_only_no_controls.copy(deep=True)
539
+
540
+ @property
541
+ def short_input_data_with_rf_only(self) -> input_data.InputData:
542
+ return self._short_input_data_with_rf_only.copy(deep=True)
543
+
544
+ @property
545
+ def short_input_data_with_media_and_rf(
546
+ self,
547
+ ) -> input_data.InputData:
548
+ return self._short_input_data_with_media_and_rf.copy(deep=True)
549
+
550
+ @property
551
+ def national_input_data_media_only(self) -> input_data.InputData:
552
+ return self._national_input_data_media_only.copy(deep=True)
553
+
554
+ @property
555
+ def national_input_data_media_and_rf(self) -> input_data.InputData:
556
+ return self._national_input_data_media_and_rf.copy(deep=True)
557
+
558
+ @property
559
+ def test_dist_media_and_rf(
560
+ self,
561
+ ) -> collections.OrderedDict[str, backend.Tensor]:
562
+ return collections.OrderedDict(self._test_dist_media_and_rf)
563
+
564
+ @property
565
+ def test_dist_media_only(
566
+ self,
567
+ ) -> collections.OrderedDict[str, backend.Tensor]:
568
+ return collections.OrderedDict(self._test_dist_media_only)
569
+
570
+ @property
571
+ def test_dist_media_only_no_controls(
572
+ self,
573
+ ) -> collections.OrderedDict[str, backend.Tensor]:
574
+ return collections.OrderedDict(self._test_dist_media_only_no_controls)
575
+
576
+ @property
577
+ def test_dist_rf_only(self) -> collections.OrderedDict[str, backend.Tensor]:
578
+ return collections.OrderedDict(self._test_dist_rf_only)
579
+
580
+ @property
581
+ def test_trace(self) -> dict[str, backend.Tensor]:
582
+ return self._test_trace.copy()
583
+
584
+ @property
585
+ def national_input_data_non_media_and_organic(
586
+ self,
587
+ ) -> input_data.InputData:
588
+ return self._national_input_data_non_media_and_organic.copy(deep=True)
589
+
590
+ @property
591
+ def input_data_non_media_and_organic(self) -> input_data.InputData:
592
+ return self._input_data_non_media_and_organic.copy(deep=True)
593
+
594
+ @property
595
+ def short_input_data_non_media_and_organic(
596
+ self,
597
+ ) -> input_data.InputData:
598
+ return self._short_input_data_non_media_and_organic.copy(deep=True)
599
+
600
+ @property
601
+ def short_input_data_non_media(self) -> input_data.InputData:
602
+ return self._short_input_data_non_media.copy(deep=True)
603
+
604
+ @property
605
+ def input_data_non_media_and_organic_same_time_dims(
606
+ self,
607
+ ) -> input_data.InputData:
608
+ return self._input_data_non_media_and_organic_same_time_dims.copy(deep=True)