google-meridian 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@
14
14
 
15
15
  """Meridian EDA Engine."""
16
16
 
17
+ import dataclasses
17
18
  import functools
18
19
  from typing import Callable, Dict, Optional, TypeAlias
19
20
  from meridian import constants
@@ -28,11 +29,64 @@ _DEFAULT_DA_VAR_AGG_FUNCTION = np.sum
28
29
  AggregationMap: TypeAlias = Dict[str, Callable[[xr.DataArray], np.ndarray]]
29
30
 
30
31
 
32
+ @dataclasses.dataclass(frozen=True, kw_only=True)
33
+ class ReachFrequencyData:
34
+ """Holds reach and frequency data arrays.
35
+
36
+ Attributes:
37
+ reach_raw_da: Raw reach data.
38
+ reach_scaled_da: Scaled reach data.
39
+ reach_raw_da_national: National raw reach data.
40
+ reach_scaled_da_national: National scaled reach data.
41
+ frequency_da: Frequency data.
42
+ frequency_da_national: National frequency data.
43
+ rf_impressions_scaled_da: Scaled reach * frequency impressions data.
44
+ rf_impressions_scaled_da_national: National scaled reach * frequency
45
+ impressions data.
46
+ rf_impressions_raw_da: Raw reach * frequency impressions data.
47
+ rf_impressions_raw_da_national: National raw reach * frequency impressions
48
+ data.
49
+ """
50
+
51
+ reach_raw_da: xr.DataArray
52
+ reach_scaled_da: xr.DataArray
53
+ reach_raw_da_national: xr.DataArray
54
+ reach_scaled_da_national: xr.DataArray
55
+ frequency_da: xr.DataArray
56
+ frequency_da_national: xr.DataArray
57
+ rf_impressions_scaled_da: xr.DataArray
58
+ rf_impressions_scaled_da_national: xr.DataArray
59
+ rf_impressions_raw_da: xr.DataArray
60
+ rf_impressions_raw_da_national: xr.DataArray
61
+
62
+
63
+ @dataclasses.dataclass(frozen=True, kw_only=True)
64
+ class AggregationConfig:
65
+ """Configuration for custom aggregation functions.
66
+
67
+ Attributes:
68
+ control_variables: A dictionary mapping control variable names to
69
+ aggregation functions. Defaults to `np.sum` if a variable is not
70
+ specified.
71
+ non_media_treatments: A dictionary mapping non-media variable names to
72
+ aggregation functions. Defaults to `np.sum` if a variable is not
73
+ specified.
74
+ """
75
+
76
+ control_variables: AggregationMap = dataclasses.field(default_factory=dict)
77
+ non_media_treatments: AggregationMap = dataclasses.field(default_factory=dict)
78
+
79
+
31
80
  class EDAEngine:
32
81
  """Meridian EDA Engine."""
33
82
 
34
- def __init__(self, meridian: model.Meridian):
83
+ def __init__(
84
+ self,
85
+ meridian: model.Meridian,
86
+ agg_config: AggregationConfig = AggregationConfig(),
87
+ ):
35
88
  self._meridian = meridian
89
+ self._agg_config = agg_config
36
90
 
37
91
  @functools.cached_property
38
92
  def controls_scaled_da(self) -> xr.DataArray | None:
@@ -44,6 +98,26 @@ class EDAEngine:
44
98
  )
45
99
  return controls_scaled_da
46
100
 
101
+ @functools.cached_property
102
+ def controls_scaled_da_national(self) -> xr.DataArray | None:
103
+ """Returns the national controls data array."""
104
+ if self._meridian.input_data.controls is None:
105
+ return None
106
+ if self._meridian.is_national:
107
+ if self.controls_scaled_da is None:
108
+ # This case should be impossible given the check above.
109
+ raise RuntimeError(
110
+ 'controls_scaled_da is None when controls is not None.'
111
+ )
112
+ return self.controls_scaled_da.squeeze(constants.GEO)
113
+ else:
114
+ return self._aggregate_and_scale_geo_da(
115
+ self._meridian.input_data.controls,
116
+ transformers.CenteringAndScalingTransformer,
117
+ constants.CONTROL_VARIABLE,
118
+ self._agg_config.control_variables,
119
+ )
120
+
47
121
  @functools.cached_property
48
122
  def media_raw_da(self) -> xr.DataArray | None:
49
123
  if self._meridian.input_data.media is None:
@@ -71,13 +145,32 @@ class EDAEngine:
71
145
  # No need to truncate the media time for media spend.
72
146
  return media_spend_da
73
147
 
148
+ @functools.cached_property
149
+ def media_spend_da_national(self) -> xr.DataArray | None:
150
+ """Returns the national media spend data array."""
151
+ if self._meridian.input_data.media_spend is None:
152
+ return None
153
+ if self._meridian.is_national:
154
+ if self.media_spend_da is None:
155
+ # This case should be impossible given the check above.
156
+ raise RuntimeError(
157
+ 'media_spend_da is None when media_spend is not None.'
158
+ )
159
+ return self.media_spend_da.squeeze(constants.GEO)
160
+ else:
161
+ return self._aggregate_and_scale_geo_da(
162
+ self._meridian.input_data.media_spend,
163
+ None,
164
+ )
165
+
74
166
  @functools.cached_property
75
167
  def media_raw_da_national(self) -> xr.DataArray | None:
76
168
  if self.media_raw_da is None:
77
169
  return None
78
170
  if self._meridian.is_national:
79
- return self.media_raw_da
171
+ return self.media_raw_da.squeeze(constants.GEO)
80
172
  else:
173
+ # Note that media is summable by assumption.
81
174
  return self._aggregate_and_scale_geo_da(
82
175
  self.media_raw_da,
83
176
  None,
@@ -88,8 +181,9 @@ class EDAEngine:
88
181
  if self.media_scaled_da is None:
89
182
  return None
90
183
  if self._meridian.is_national:
91
- return self.media_scaled_da
184
+ return self.media_scaled_da.squeeze(constants.GEO)
92
185
  else:
186
+ # Note that media is summable by assumption.
93
187
  return self._aggregate_and_scale_geo_da(
94
188
  self.media_raw_da,
95
189
  transformers.MediaTransformer,
@@ -116,8 +210,9 @@ class EDAEngine:
116
210
  if self.organic_media_raw_da is None:
117
211
  return None
118
212
  if self._meridian.is_national:
119
- return self.organic_media_raw_da
213
+ return self.organic_media_raw_da.squeeze(constants.GEO)
120
214
  else:
215
+ # Note that organic media is summable by assumption.
121
216
  return self._aggregate_and_scale_geo_da(self.organic_media_raw_da, None)
122
217
 
123
218
  @functools.cached_property
@@ -125,8 +220,9 @@ class EDAEngine:
125
220
  if self.organic_media_scaled_da is None:
126
221
  return None
127
222
  if self._meridian.is_national:
128
- return self.organic_media_scaled_da
223
+ return self.organic_media_scaled_da.squeeze(constants.GEO)
129
224
  else:
225
+ # Note that organic media is summable by assumption.
130
226
  return self._aggregate_and_scale_geo_da(
131
227
  self.organic_media_raw_da,
132
228
  transformers.MediaTransformer,
@@ -142,6 +238,26 @@ class EDAEngine:
142
238
  )
143
239
  return non_media_scaled_da
144
240
 
241
+ @functools.cached_property
242
+ def non_media_scaled_da_national(self) -> xr.DataArray | None:
243
+ """Returns the national non-media treatment data array."""
244
+ if self._meridian.input_data.non_media_treatments is None:
245
+ return None
246
+ if self._meridian.is_national:
247
+ if self.non_media_scaled_da is None:
248
+ # This case should be impossible given the check above.
249
+ raise RuntimeError(
250
+ 'non_media_scaled_da is None when non_media_treatments is not None.'
251
+ )
252
+ return self.non_media_scaled_da.squeeze(constants.GEO)
253
+ else:
254
+ return self._aggregate_and_scale_geo_da(
255
+ self._meridian.input_data.non_media_treatments,
256
+ transformers.CenteringAndScalingTransformer,
257
+ constants.NON_MEDIA_CHANNEL,
258
+ self._agg_config.non_media_treatments,
259
+ )
260
+
145
261
  @functools.cached_property
146
262
  def rf_spend_da(self) -> xr.DataArray | None:
147
263
  if self._meridian.input_data.rf_spend is None:
@@ -157,12 +273,226 @@ class EDAEngine:
157
273
  if self._meridian.input_data.rf_spend is None:
158
274
  return None
159
275
  if self._meridian.is_national:
160
- return self.rf_spend_da
276
+ if self.rf_spend_da is None:
277
+ # This case should be impossible given the check above.
278
+ raise RuntimeError('rf_spend_da is None when rf_spend is not None.')
279
+ return self.rf_spend_da.squeeze(constants.GEO)
161
280
  else:
162
281
  return self._aggregate_and_scale_geo_da(
163
282
  self._meridian.input_data.rf_spend, None
164
283
  )
165
284
 
285
+ @functools.cached_property
286
+ def _rf_data(self) -> ReachFrequencyData | None:
287
+ if self._meridian.input_data.reach is None:
288
+ return None
289
+ return self._get_rf_data(
290
+ self._meridian.input_data.reach,
291
+ self._meridian.input_data.frequency,
292
+ is_organic=False,
293
+ )
294
+
295
+ @property
296
+ def reach_raw_da(self) -> xr.DataArray | None:
297
+ if self._rf_data is None:
298
+ return None
299
+ return self._rf_data.reach_raw_da
300
+
301
+ @property
302
+ def reach_scaled_da(self) -> xr.DataArray | None:
303
+ if self._rf_data is None:
304
+ return None
305
+ return self._rf_data.reach_scaled_da
306
+
307
+ @property
308
+ def reach_raw_da_national(self) -> xr.DataArray | None:
309
+ if self._rf_data is None:
310
+ return None
311
+ return self._rf_data.reach_raw_da_national
312
+
313
+ @property
314
+ def reach_scaled_da_national(self) -> xr.DataArray | None:
315
+ if self._rf_data is None:
316
+ return None
317
+ return self._rf_data.reach_scaled_da_national
318
+
319
+ @property
320
+ def frequency_da(self) -> xr.DataArray | None:
321
+ if self._rf_data is None:
322
+ return None
323
+ return self._rf_data.frequency_da
324
+
325
+ @property
326
+ def frequency_da_national(self) -> xr.DataArray | None:
327
+ if self._rf_data is None:
328
+ return None
329
+ return self._rf_data.frequency_da_national
330
+
331
+ @property
332
+ def rf_impressions_raw_da(self) -> xr.DataArray | None:
333
+ if self._rf_data is None:
334
+ return None
335
+ return self._rf_data.rf_impressions_raw_da
336
+
337
+ @property
338
+ def rf_impressions_raw_da_national(self) -> xr.DataArray | None:
339
+ if self._rf_data is None:
340
+ return None
341
+ return self._rf_data.rf_impressions_raw_da_national
342
+
343
+ @property
344
+ def rf_impressions_scaled_da(self) -> xr.DataArray | None:
345
+ if self._rf_data is None:
346
+ return None
347
+ return self._rf_data.rf_impressions_scaled_da
348
+
349
+ @property
350
+ def rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
351
+ if self._rf_data is None:
352
+ return None
353
+ return self._rf_data.rf_impressions_scaled_da_national
354
+
355
+ @functools.cached_property
356
+ def _organic_rf_data(self) -> ReachFrequencyData | None:
357
+ if self._meridian.input_data.organic_reach is None:
358
+ return None
359
+ return self._get_rf_data(
360
+ self._meridian.input_data.organic_reach,
361
+ self._meridian.input_data.organic_frequency,
362
+ is_organic=True,
363
+ )
364
+
365
+ @property
366
+ def organic_reach_raw_da(self) -> xr.DataArray | None:
367
+ if self._organic_rf_data is None:
368
+ return None
369
+ return self._organic_rf_data.reach_raw_da
370
+
371
+ @property
372
+ def organic_reach_scaled_da(self) -> xr.DataArray | None:
373
+ if self._organic_rf_data is None:
374
+ return None
375
+ return self._organic_rf_data.reach_scaled_da
376
+
377
+ @property
378
+ def organic_reach_raw_da_national(self) -> xr.DataArray | None:
379
+ if self._organic_rf_data is None:
380
+ return None
381
+ return self._organic_rf_data.reach_raw_da_national
382
+
383
+ @property
384
+ def organic_reach_scaled_da_national(self) -> xr.DataArray | None:
385
+ if self._organic_rf_data is None:
386
+ return None
387
+ return self._organic_rf_data.reach_scaled_da_national
388
+
389
+ @property
390
+ def organic_rf_impressions_scaled_da(self) -> xr.DataArray | None:
391
+ if self._organic_rf_data is None:
392
+ return None
393
+ return self._organic_rf_data.rf_impressions_scaled_da
394
+
395
+ @property
396
+ def organic_rf_impressions_scaled_da_national(self) -> xr.DataArray | None:
397
+ if self._organic_rf_data is None:
398
+ return None
399
+ return self._organic_rf_data.rf_impressions_scaled_da_national
400
+
401
+ @property
402
+ def organic_frequency_da(self) -> xr.DataArray | None:
403
+ if self._organic_rf_data is None:
404
+ return None
405
+ return self._organic_rf_data.frequency_da
406
+
407
+ @property
408
+ def organic_frequency_da_national(self) -> xr.DataArray | None:
409
+ if self._organic_rf_data is None:
410
+ return None
411
+ return self._organic_rf_data.frequency_da_national
412
+
413
+ @property
414
+ def organic_rf_impressions_raw_da(self) -> xr.DataArray | None:
415
+ if self._organic_rf_data is None:
416
+ return None
417
+ return self._organic_rf_data.rf_impressions_raw_da
418
+
419
+ @property
420
+ def organic_rf_impressions_raw_da_national(self) -> xr.DataArray | None:
421
+ if self._organic_rf_data is None:
422
+ return None
423
+ return self._organic_rf_data.rf_impressions_raw_da_national
424
+
425
+ @functools.cached_property
426
+ def geo_population_da(self) -> xr.DataArray | None:
427
+ if self._meridian.is_national:
428
+ return None
429
+ return xr.DataArray(
430
+ self._meridian.population,
431
+ coords={constants.GEO: self._meridian.input_data.geo.values},
432
+ dims=[constants.GEO],
433
+ name=constants.POPULATION,
434
+ )
435
+
436
+ @functools.cached_property
437
+ def kpi_scaled_da(self) -> xr.DataArray:
438
+ return _data_array_like(
439
+ da=self._meridian.input_data.kpi,
440
+ values=self._meridian.kpi_scaled,
441
+ )
442
+
443
+ @functools.cached_property
444
+ def kpi_scaled_da_national(self) -> xr.DataArray:
445
+ if self._meridian.is_national:
446
+ return self.kpi_scaled_da.squeeze(constants.GEO)
447
+ else:
448
+ # Note that kpi is summable by assumption.
449
+ return self._aggregate_and_scale_geo_da(
450
+ self._meridian.input_data.kpi,
451
+ transformers.CenteringAndScalingTransformer,
452
+ )
453
+
454
+ @functools.cached_property
455
+ def treatment_control_scaled_ds(self) -> xr.Dataset:
456
+ """Returns a Dataset containing all scaled treatments and controls.
457
+
458
+ This includes media, RF impressions, organic media, organic RF impressions,
459
+ non-media treatments, and control variables, all at the geo level.
460
+ """
461
+ to_merge = [
462
+ da
463
+ for da in [
464
+ self.media_scaled_da,
465
+ self.rf_impressions_scaled_da,
466
+ self.organic_media_scaled_da,
467
+ self.organic_rf_impressions_scaled_da,
468
+ self.controls_scaled_da,
469
+ self.non_media_scaled_da,
470
+ ]
471
+ if da is not None
472
+ ]
473
+ return xr.merge(to_merge, join='inner')
474
+
475
+ @functools.cached_property
476
+ def treatment_control_scaled_ds_national(self) -> xr.Dataset:
477
+ """Returns a Dataset containing all scaled treatments and controls.
478
+
479
+ This includes media, RF impressions, organic media, organic RF impressions,
480
+ non-media treatments, and control variables, all at the national level.
481
+ """
482
+ to_merge_national = [
483
+ da
484
+ for da in [
485
+ self.media_scaled_da_national,
486
+ self.rf_impressions_scaled_da_national,
487
+ self.organic_media_scaled_da_national,
488
+ self.organic_rf_impressions_scaled_da_national,
489
+ self.controls_scaled_da_national,
490
+ self.non_media_scaled_da_national,
491
+ ]
492
+ if da is not None
493
+ ]
494
+ return xr.merge(to_merge_national, join='inner')
495
+
166
496
  def _truncate_media_time(self, da: xr.DataArray) -> xr.DataArray:
167
497
  """Truncates the first `start` elements of the media time of a variable."""
168
498
  # This should not happen. If it does, it means this function is mis-used.
@@ -183,15 +513,17 @@ class EDAEngine:
183
513
  population: tf.Tensor = tf.constant([1.0], dtype=tf.float32),
184
514
  ):
185
515
  """Scales xarray values with a TensorTransformer."""
516
+ da = xarray.copy()
517
+
186
518
  if transformer_class is None:
187
- return xarray
519
+ return da
188
520
  elif transformer_class is transformers.CenteringAndScalingTransformer:
189
521
  xarray_transformer = transformers.CenteringAndScalingTransformer(
190
- tensor=xarray.values, population=population
522
+ tensor=da.values, population=population
191
523
  )
192
524
  elif transformer_class is transformers.MediaTransformer:
193
525
  xarray_transformer = transformers.MediaTransformer(
194
- media=xarray.values, population=population
526
+ media=da.values, population=population
195
527
  )
196
528
  else:
197
529
  raise ValueError(
@@ -200,8 +532,8 @@ class EDAEngine:
200
532
  + '.\nMust be one of: CenteringAndScalingTransformer or'
201
533
  ' MediaTransformer.'
202
534
  )
203
- xarray.values = xarray_transformer.forward(xarray.values)
204
- return xarray
535
+ da.values = xarray_transformer.forward(da.values)
536
+ return da
205
537
 
206
538
  def _aggregate_variables(
207
539
  self,
@@ -282,6 +614,103 @@ class EDAEngine:
282
614
 
283
615
  return da_national.sel({constants.GEO: temp_geo_dim}, drop=True)
284
616
 
617
+ def _get_rf_data(
618
+ self,
619
+ reach_raw_da: xr.DataArray,
620
+ freq_raw_da: xr.DataArray,
621
+ is_organic: bool,
622
+ ) -> ReachFrequencyData:
623
+ """Get impressions and frequencies data arrays for RF channels."""
624
+ if is_organic:
625
+ scaled_reach_values = (
626
+ self._meridian.organic_rf_tensors.organic_reach_scaled
627
+ )
628
+ else:
629
+ scaled_reach_values = self._meridian.rf_tensors.reach_scaled
630
+ reach_scaled_da = _data_array_like(
631
+ da=reach_raw_da, values=scaled_reach_values
632
+ )
633
+ # Truncate the media time for reach and scaled reach.
634
+ reach_raw_da = self._truncate_media_time(reach_raw_da)
635
+ reach_scaled_da = self._truncate_media_time(reach_scaled_da)
636
+
637
+ # The geo level frequency
638
+ frequency_da = self._truncate_media_time(freq_raw_da)
639
+
640
+ # The raw geo level impression
641
+ # It's equal to reach * frequency.
642
+ impressions_raw_da = reach_raw_da * frequency_da
643
+ impressions_raw_da.name = (
644
+ constants.ORGANIC_RF_IMPRESSIONS
645
+ if is_organic
646
+ else constants.RF_IMPRESSIONS
647
+ )
648
+ impressions_raw_da.values = tf.cast(impressions_raw_da.values, tf.float32)
649
+
650
+ if self._meridian.is_national:
651
+ reach_raw_da_national = reach_raw_da.squeeze(constants.GEO)
652
+ reach_scaled_da_national = reach_scaled_da.squeeze(constants.GEO)
653
+ impressions_raw_da_national = impressions_raw_da.squeeze(constants.GEO)
654
+ frequency_da_national = frequency_da.squeeze(constants.GEO)
655
+
656
+ # Scaled impressions
657
+ impressions_scaled_da = self._scale_xarray(
658
+ impressions_raw_da, transformers.MediaTransformer
659
+ )
660
+ impressions_scaled_da_national = impressions_scaled_da.squeeze(
661
+ constants.GEO
662
+ )
663
+ else:
664
+ reach_raw_da_national = self._aggregate_and_scale_geo_da(
665
+ reach_raw_da, None
666
+ )
667
+ reach_scaled_da_national = self._aggregate_and_scale_geo_da(
668
+ reach_raw_da, transformers.MediaTransformer
669
+ )
670
+ impressions_raw_da_national = self._aggregate_and_scale_geo_da(
671
+ impressions_raw_da, None
672
+ )
673
+
674
+ # National frequency is a weighted average of geo frequencies,
675
+ # weighted by reach.
676
+ frequency_da_national = xr.where(
677
+ reach_raw_da_national == 0.0,
678
+ 0.0,
679
+ impressions_raw_da_national / reach_raw_da_national,
680
+ )
681
+ frequency_da_national.name = (
682
+ constants.ORGANIC_PREFIX if is_organic else ''
683
+ ) + constants.FREQUENCY
684
+ frequency_da_national.values = tf.cast(
685
+ frequency_da_national.values, tf.float32
686
+ )
687
+
688
+ # Scale the impressions by population
689
+ impressions_scaled_da = self._scale_xarray(
690
+ impressions_raw_da,
691
+ transformers.MediaTransformer,
692
+ population=self._meridian.population,
693
+ )
694
+
695
+ # Scale the national impressions
696
+ impressions_scaled_da_national = self._aggregate_and_scale_geo_da(
697
+ impressions_raw_da,
698
+ transformers.MediaTransformer,
699
+ )
700
+
701
+ return ReachFrequencyData(
702
+ reach_raw_da=reach_raw_da,
703
+ reach_scaled_da=reach_scaled_da,
704
+ reach_raw_da_national=reach_raw_da_national,
705
+ reach_scaled_da_national=reach_scaled_da_national,
706
+ frequency_da=frequency_da,
707
+ frequency_da_national=frequency_da_national,
708
+ rf_impressions_scaled_da=impressions_scaled_da,
709
+ rf_impressions_scaled_da_national=impressions_scaled_da_national,
710
+ rf_impressions_raw_da=impressions_raw_da,
711
+ rf_impressions_raw_da_national=impressions_raw_da_national,
712
+ )
713
+
285
714
 
286
715
  def _data_array_like(
287
716
  *, da: xr.DataArray, values: np.ndarray | tf.Tensor
meridian/model/knots.py CHANGED
@@ -276,7 +276,7 @@ class AKS:
276
276
  feasible_idx = np.where(
277
277
  (n_knots >= min_internal_knots) & (n_knots <= max_internal_knots)
278
278
  )[0]
279
- information_criterion = aspline[constants.EBIC][feasible_idx]
279
+ information_criterion = aspline[constants.AIC][feasible_idx]
280
280
  knots_sel = [aspline[constants.KNOTS_SELECTED][i] for i in feasible_idx]
281
281
  model = [aspline[constants.MODEL][i] for i in feasible_idx]
282
282
  opt_idx = max(
@@ -52,7 +52,9 @@ def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> backend.Tensor:
52
52
  else:
53
53
  pad_value = 0.0 if array.dtype.kind == "f" else 0
54
54
 
55
- burnin = backend.fill([n_burnin] + transposed_tensor.shape[1:], pad_value)
55
+ burnin = backend.fill(
56
+ [n_burnin] + list(transposed_tensor.shape[1:]), pad_value
57
+ )
56
58
  return backend.concatenate(
57
59
  [burnin, transposed_tensor],
58
60
  axis=0,
@@ -122,18 +124,13 @@ class WithInputDataSamples:
122
124
  _N_MEDIA_CHANNELS = 3
123
125
  _N_RF_CHANNELS = 2
124
126
  _N_CONTROLS = 2
125
- _ROI_CALIBRATION_PERIOD = backend.cast(
126
- backend.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)),
127
- dtype=backend.bool_,
128
- )
129
- _RF_ROI_CALIBRATION_PERIOD = backend.cast(
130
- backend.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)),
131
- dtype=backend.bool_,
132
- )
133
127
  _N_ORGANIC_MEDIA_CHANNELS = 4
134
128
  _N_ORGANIC_RF_CHANNELS = 1
135
129
  _N_NON_MEDIA_CHANNELS = 2
136
130
 
131
+ _ROI_CALIBRATION_PERIOD: backend.Tensor
132
+ _RF_ROI_CALIBRATION_PERIOD: backend.Tensor
133
+
137
134
  # Private class variables to hold the base test data.
138
135
  _input_data_non_revenue_no_revenue_per_kpi: input_data.InputData
139
136
  _input_data_media_and_rf_non_revenue_no_revenue_per_kpi: input_data.InputData
@@ -170,6 +167,15 @@ class WithInputDataSamples:
170
167
  @classmethod
171
168
  def setup(cls):
172
169
  """Sets up input data samples."""
170
+ cls._ROI_CALIBRATION_PERIOD = backend.cast(
171
+ backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_MEDIA_CHANNELS)),
172
+ dtype=backend.bool_,
173
+ )
174
+ cls._RF_ROI_CALIBRATION_PERIOD = backend.cast(
175
+ backend.ones((cls._N_MEDIA_TIMES_SHORT, cls._N_RF_CHANNELS)),
176
+ dtype=backend.bool_,
177
+ )
178
+
173
179
  cls._input_data_non_revenue_no_revenue_per_kpi = (
174
180
  test_utils.sample_input_data_non_revenue_no_revenue_per_kpi(
175
181
  n_geos=cls._N_GEOS,