google-meridian 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/constants.py CHANGED
@@ -51,6 +51,8 @@ GREY_300 = '#DADCE0'
51
51
 
52
52
  # Example: "2024-01-09"
53
53
  DATE_FORMAT = '%Y-%m-%d'
54
+ # Example: "2024 Apr"
55
+ QUARTER_FORMAT = '%Y %b'
54
56
 
55
57
  # Input data variables.
56
58
  KPI = 'kpi'
@@ -95,12 +97,8 @@ POSSIBLE_INPUT_DATA_ARRAY_NAMES = (
95
97
  + MEDIA_INPUT_DATA_ARRAY_NAMES
96
98
  + RF_INPUT_DATA_ARRAY_NAMES
97
99
  )
98
- PAID_DATA = (
99
- MEDIA,
100
- REACH,
101
- FREQUENCY,
102
- REVENUE_PER_KPI,
103
- )
100
+ PAID_CHANNELS = (MEDIA, REACH, FREQUENCY)
101
+ PAID_DATA = PAID_CHANNELS + (REVENUE_PER_KPI,)
104
102
  NON_PAID_DATA = (
105
103
  ORGANIC_MEDIA,
106
104
  ORGANIC_REACH,
@@ -112,11 +110,7 @@ SPEND_DATA = (
112
110
  RF_SPEND,
113
111
  )
114
112
  PERFORMANCE_DATA = PAID_DATA + SPEND_DATA
115
- IMPRESSIONS_DATA = (
116
- MEDIA,
117
- REACH,
118
- FREQUENCY,
119
- ) + NON_PAID_DATA
113
+ IMPRESSIONS_DATA = PAID_CHANNELS + NON_PAID_DATA
120
114
  RF_DATA = (
121
115
  REACH,
122
116
  FREQUENCY,
@@ -200,17 +194,23 @@ RF_ROI_CALIBRATION_PERIOD = 'rf_roi_calibration_period'
200
194
  KNOTS = 'knots'
201
195
  BASELINE_GEO = 'baseline_geo'
202
196
 
203
- # Media prior types.
204
- PAID_MEDIA_PRIOR_TYPE_ROI = 'roi'
205
- PAID_MEDIA_PRIOR_TYPE_MROI = 'mroi'
206
- PAID_MEDIA_PRIOR_TYPE_COEFFICIENT = 'coefficient'
207
- PAID_MEDIA_PRIOR_TYPES = frozenset({
208
- PAID_MEDIA_PRIOR_TYPE_ROI,
209
- PAID_MEDIA_PRIOR_TYPE_MROI,
210
- PAID_MEDIA_PRIOR_TYPE_COEFFICIENT,
197
+ # Treatment prior types.
198
+ TREATMENT_PRIOR_TYPE_ROI = 'roi'
199
+ TREATMENT_PRIOR_TYPE_MROI = 'mroi'
200
+ TREATMENT_PRIOR_TYPE_COEFFICIENT = 'coefficient'
201
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION = 'contribution'
202
+ PAID_TREATMENT_PRIOR_TYPES = frozenset({
203
+ TREATMENT_PRIOR_TYPE_ROI,
204
+ TREATMENT_PRIOR_TYPE_MROI,
205
+ TREATMENT_PRIOR_TYPE_COEFFICIENT,
206
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION,
207
+ })
208
+ NON_PAID_TREATMENT_PRIOR_TYPES = frozenset({
209
+ TREATMENT_PRIOR_TYPE_COEFFICIENT,
210
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION,
211
211
  })
212
212
  PAID_MEDIA_ROI_PRIOR_TYPES = frozenset(
213
- {PAID_MEDIA_PRIOR_TYPE_ROI, PAID_MEDIA_PRIOR_TYPE_MROI}
213
+ {TREATMENT_PRIOR_TYPE_ROI, TREATMENT_PRIOR_TYPE_MROI}
214
214
  )
215
215
  # Represents a 1% increase in spend.
216
216
  MROI_FACTOR = 1.01
@@ -239,6 +239,11 @@ ROI_M = 'roi_m'
239
239
  ROI_RF = 'roi_rf'
240
240
  MROI_M = 'mroi_m'
241
241
  MROI_RF = 'mroi_rf'
242
+ CONTRIBUTION_M = 'contribution_m'
243
+ CONTRIBUTION_RF = 'contribution_rf'
244
+ CONTRIBUTION_OM = 'contribution_om'
245
+ CONTRIBUTION_ORF = 'contribution_orf'
246
+ CONTRIBUTION_N = 'contribution_n'
242
247
  GAMMA_C = 'gamma_c'
243
248
  GAMMA_N = 'gamma_n'
244
249
  XI_C = 'xi_c'
@@ -307,7 +312,29 @@ RF_PARAMETER_NAMES = (
307
312
  BETA_RF,
308
313
  BETA_GRF,
309
314
  )
315
+
316
+ MEDIA_PARAMETERS = (
317
+ ROI_M,
318
+ MROI_M,
319
+ CONTRIBUTION_M,
320
+ BETA_M,
321
+ ETA_M,
322
+ ALPHA_M,
323
+ EC_M,
324
+ SLOPE_M,
325
+ )
326
+ RF_PARAMETERS = (
327
+ ROI_RF,
328
+ MROI_RF,
329
+ CONTRIBUTION_RF,
330
+ BETA_RF,
331
+ ETA_RF,
332
+ ALPHA_RF,
333
+ EC_RF,
334
+ SLOPE_RF,
335
+ )
310
336
  ORGANIC_MEDIA_PARAMETERS = (
337
+ CONTRIBUTION_OM,
311
338
  BETA_OM,
312
339
  ETA_OM,
313
340
  ALPHA_OM,
@@ -315,6 +342,7 @@ ORGANIC_MEDIA_PARAMETERS = (
315
342
  SLOPE_OM,
316
343
  )
317
344
  ORGANIC_RF_PARAMETERS = (
345
+ CONTRIBUTION_ORF,
318
346
  BETA_ORF,
319
347
  ETA_ORF,
320
348
  ALPHA_ORF,
@@ -322,13 +350,12 @@ ORGANIC_RF_PARAMETERS = (
322
350
  SLOPE_ORF,
323
351
  )
324
352
  NON_MEDIA_PARAMETERS = (
353
+ CONTRIBUTION_N,
325
354
  GAMMA_N,
326
355
  XI_N,
327
356
  )
328
357
 
329
358
  KNOTS_PARAMETERS = (KNOT_VALUES,)
330
- MEDIA_PARAMETERS = (ETA_M, BETA_M, ALPHA_M, EC_M, SLOPE_M, ROI_M, MROI_M)
331
- RF_PARAMETERS = (ETA_RF, BETA_RF, ALPHA_RF, EC_RF, SLOPE_RF, ROI_RF, MROI_RF)
332
359
  CONTROL_PARAMETERS = (GAMMA_C, XI_C)
333
360
  SIGMA_PARAMETERS = (SIGMA,)
334
361
  GEO_PARAMETERS = (
@@ -366,10 +393,61 @@ UNSAVED_PARAMETERS = (
366
393
  GAMMA_GN_DEV,
367
394
  TAU_G_EXCL_BASELINE, # Used to derive TAU_G.
368
395
  )
369
- IGNORED_PRIORS = immutabledict.immutabledict({
370
- PAID_MEDIA_PRIOR_TYPE_ROI: (BETA_M, BETA_RF, MROI_M, MROI_RF),
371
- PAID_MEDIA_PRIOR_TYPE_MROI: (BETA_M, BETA_RF, ROI_M, ROI_RF),
372
- PAID_MEDIA_PRIOR_TYPE_COEFFICIENT: (ROI_M, ROI_RF, MROI_M, MROI_RF),
396
+ IGNORED_PRIORS_MEDIA = immutabledict.immutabledict({
397
+ TREATMENT_PRIOR_TYPE_ROI: (
398
+ BETA_M,
399
+ MROI_M,
400
+ CONTRIBUTION_M,
401
+ ),
402
+ TREATMENT_PRIOR_TYPE_MROI: (
403
+ BETA_M,
404
+ ROI_M,
405
+ CONTRIBUTION_M,
406
+ ),
407
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION: (
408
+ BETA_M,
409
+ ROI_M,
410
+ MROI_M,
411
+ ),
412
+ TREATMENT_PRIOR_TYPE_COEFFICIENT: (
413
+ ROI_M,
414
+ MROI_M,
415
+ CONTRIBUTION_M,
416
+ ),
417
+ })
418
+ IGNORED_PRIORS_RF = immutabledict.immutabledict({
419
+ TREATMENT_PRIOR_TYPE_ROI: (
420
+ BETA_RF,
421
+ MROI_RF,
422
+ CONTRIBUTION_RF,
423
+ ),
424
+ TREATMENT_PRIOR_TYPE_MROI: (
425
+ BETA_RF,
426
+ ROI_RF,
427
+ CONTRIBUTION_RF,
428
+ ),
429
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION: (
430
+ BETA_RF,
431
+ ROI_RF,
432
+ MROI_RF,
433
+ ),
434
+ TREATMENT_PRIOR_TYPE_COEFFICIENT: (
435
+ ROI_RF,
436
+ MROI_RF,
437
+ CONTRIBUTION_RF,
438
+ ),
439
+ })
440
+ IGNORED_PRIORS_ORGANIC_MEDIA = immutabledict.immutabledict({
441
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION: (BETA_OM,),
442
+ TREATMENT_PRIOR_TYPE_COEFFICIENT: (CONTRIBUTION_OM,),
443
+ })
444
+ IGNORED_PRIORS_ORGANIC_RF = immutabledict.immutabledict({
445
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION: (BETA_ORF,),
446
+ TREATMENT_PRIOR_TYPE_COEFFICIENT: (CONTRIBUTION_ORF,),
447
+ })
448
+ IGNORED_PRIORS_NON_MEDIA_TREATMENTS = immutabledict.immutabledict({
449
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION: (GAMMA_N,),
450
+ TREATMENT_PRIOR_TYPE_COEFFICIENT: (CONTRIBUTION_N,),
373
451
  })
374
452
 
375
453
  # Inference data dimensions.
@@ -622,3 +700,10 @@ CARD_STATS = 'stats'
622
700
  # VegaLite common params.
623
701
  VEGALITE_FACET_DEFAULT_WIDTH = 400
624
702
  VEGALITE_FACET_LARGE_WIDTH = 500
703
+ VEGALITE_FACET_EXTRA_LARGE_WIDTH = 700
704
+
705
+ # Time Granularity Constants
706
+ WEEKLY = 'weekly'
707
+ QUARTERLY = 'quarterly'
708
+ TIME_GRANULARITIES = frozenset({WEEKLY, QUARTERLY})
709
+ QUARTERLY_SUMMARY_THRESHOLD_WEEKS = 52
@@ -60,7 +60,7 @@ def _check_dim_collection(
60
60
  )
61
61
 
62
62
 
63
- def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray]):
63
+ def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray | None]):
64
64
  """Verifies that the dimensions of the appropriate arrays match."""
65
65
  lengths = [len(array.coords[dim]) for array in arrays if array is not None]
66
66
  names = [array.name for array in arrays if array is not None]
@@ -83,6 +83,31 @@ def _check_coords_match(dim: str, arrays: Sequence[xr.DataArray]):
83
83
  )
84
84
 
85
85
 
86
+ def _aggregate_spend(
87
+ spend: xr.DataArray, calibration_period: np.ndarray | None
88
+ ) -> np.ndarray | None:
89
+ """Aggregates spend for each channel over the calibration period.
90
+
91
+ Args:
92
+ spend: An array with shape `(n_geos, n_times, n_channels)` to aggregate.
93
+ calibration_period: An optional boolean array of shape `(n_media_times,
94
+ n_channels)`. If provided, spend is filtered according to this period.
95
+
96
+ Returns:
97
+ A 1-D array of aggregated media spend per channel, or `None` if `spend` is
98
+ `None`.
99
+ """
100
+ if spend is None:
101
+ return None
102
+
103
+ if calibration_period is None:
104
+ return np.sum(spend, axis=(0, 1))
105
+
106
+ # Select the last `n_times` from the `calibration_period`
107
+ factors = np.where(calibration_period[-spend.shape[1] :, :], 1, 0)
108
+ return np.einsum("gtm,tm->m", spend, factors)
109
+
110
+
86
111
  @dataclasses.dataclass
87
112
  class InputData:
88
113
  """A data container for advertising data in a format supported by Meridian.
@@ -120,8 +145,14 @@ class InputData:
120
145
  in the same order. If either of these arguments is passed, then the other
121
146
  is not optional.
122
147
  media_spend: An optional `DataArray` containing the cost of each media
123
- channel. This is used as the denominator for ROI calculations. The
124
- DataArray shape can be `(n_geos, n_times, n_media_channels)` or
148
+ channel. This is used as the denominator for ROI calculations. It is also
149
+ used to calculate an assumed cost per media unit for post-modeling
150
+ analysis such as response curves and budget optimization. Only the
151
+ aggregate spend (across geos and time periods) is required for these
152
+ calculations. However, a spend breakdown by geo and time period is
153
+ required if `roi_calibration_period` is specified or if conducting
154
+ post-modeling analysis on a specific subset of geos and/or time periods.
155
+ The DataArray shape can be `(n_geos, n_times, n_media_channels)` or
125
156
  `(n_media_channels,)` if the data is aggregated over `geo` and `time`
126
157
  dimensions. We recommend that the spend total aligns with the time window
127
158
  of the `kpi` and `controls` data, which is the time window over which
@@ -131,7 +162,9 @@ class InputData:
131
162
  time window of media executed during the time window. `media` and
132
163
  `media_spend` must contain the same number of media channels in the same
133
164
  order. If either of these arguments is passed, then the other is not
134
- optional.
165
+ optional. If a tensor of shape `(n_media_channels,)` is passed as
166
+ `media_spend`, then it will be automatically allocated across geos and
167
+ times proportionally to `media`.
135
168
  reach: An optional `DataArray` of dimensions `(n_geos, n_media_times,
136
169
  n_rf_channels)` containing non-negative `reach` values. It is required
137
170
  that `n_media_times` ≥ `n_times`, and the final `n_times` time periods
@@ -164,18 +197,26 @@ class InputData:
164
197
  others are not optional.
165
198
  rf_spend: An optional `DataArray` containing the cost of each reach and
166
199
  frequency channel. This is used as the denominator for ROI calculations.
167
- The DataArray shape can be `(n_rf_channels,)`, `(n_geos, n_times,
168
- n_rf_channels)`, or `(n_geos, n_rf_channels)`. The spend should be
169
- aggregated over geo and/or time dimensions that are not represented. We
170
- recommend that the spend total aligns with the time window of the `kpi`
171
- and `controls` data, which is the time window over which incremental
172
- outcome of the ROI numerator is calculated. However, note that incremental
173
- outcome is influenced by media execution prior to this time window,
174
- through lagged effects, and excludes lagged effects beyond the time window
175
- of media executed during the time window. If only `media` data is used,
176
- `rf_spend` will be `None`. `reach`, `frequency`, and `rf_spend` must
177
- contain the same number of media channels in the same order. If any of
178
- these arguments is passed, then the others are not optional.
200
+ It is also used to calculate an assumed cost per media unit for
201
+ post-modeling analysis such as response curves and budget optimization.
202
+ Only the aggregate spend (across geos and time periods) is required for
203
+ these calculations. However, a spend breakdown by geo and time period is
204
+ required if `rf_roi_calibration_period` is specified or if conducting
205
+ post-modeling analysis on a specific subset of geos and/or time periods.
206
+ The DataArray shape can be `(n_rf_channels,)` or `(n_geos, n_times,
207
+ n_rf_channels)`. The spend should be aggregated over geo and/or time
208
+ dimensions that are not represented. We recommend that the spend total
209
+ aligns with the time window of the `kpi` and `controls` data, which is the
210
+ time window over which incremental outcome of the ROI numerator is
211
+ calculated. However, note that incremental outcome is influenced by media
212
+ execution prior to this time window, through lagged effects, and excludes
213
+ lagged effects beyond the time window of media executed during the time
214
+ window. If only `media` data is used, `rf_spend` will be `None`. `reach`,
215
+ `frequency`, and `rf_spend` must contain the same number of media channels
216
+ in the same order. If any of these arguments is passed, then the others
217
+ are not optional. If a tensor of shape `(n_rf_channels,)` is passed as
218
+ `rf_spend`, then it will be automatically allocated across geos and times
219
+ proportionally to `(reach * frequency)`.
179
220
  organic_media: An optional `DataArray` of dimensions `(n_geos,
180
221
  n_media_times, n_organic_media_channels)` containing non-negative organic
181
222
  media values. Organic media variables are media activities that have no
@@ -265,6 +306,40 @@ class InputData:
265
306
  if isinstance(array, xr.DataArray) and constants.GEO in array.dims:
266
307
  array.coords[constants.GEO] = array.coords[constants.GEO].astype(str)
267
308
 
309
+ # TODO: b/416775065 - Combine with Analyzer._impute_and_aggregate_spend
310
+ @functools.cached_property
311
+ def allocated_media_spend(self) -> xr.DataArray | None:
312
+ """Returns the allocated media spend for each geo and time."""
313
+ if self.media_spend is not None and len(self.media_spend.shape) == 1:
314
+ return self._allocate_spend(self.media_spend, self.media)
315
+ else:
316
+ return self.media_spend
317
+
318
+ @property
319
+ def allocated_rf_spend(self) -> xr.DataArray | None:
320
+ """Returns the allocated RF spend for each geo and time."""
321
+ if self.rf_spend is not None and len(self.rf_spend.shape) == 1:
322
+ return self._allocate_spend(self.rf_spend, self.reach * self.frequency)
323
+ else:
324
+ return self.rf_spend
325
+
326
+ def aggregate_media_spend(
327
+ self, calibration_period: np.ndarray | None = None
328
+ ) -> np.ndarray | None:
329
+ """Aggregates media spend by channel over the calibration period."""
330
+ return _aggregate_spend(
331
+ spend=self.allocated_media_spend, calibration_period=calibration_period
332
+ )
333
+
334
+ def aggregate_rf_spend(
335
+ self, calibration_period: np.ndarray | None = None
336
+ ) -> np.ndarray | None:
337
+ """Aggregates RF spend by channel over the calibration period."""
338
+ return _aggregate_spend(
339
+ spend=self.allocated_rf_spend,
340
+ calibration_period=calibration_period,
341
+ )
342
+
268
343
  @property
269
344
  def geo(self) -> xr.DataArray:
270
345
  """Returns the geo dimension."""
@@ -424,7 +499,8 @@ class InputData:
424
499
 
425
500
  def _validate_names(self):
426
501
  """Verifies that the names of the data arrays are correct."""
427
- arrays = [
502
+ # Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
503
+ arrays = (
428
504
  self.kpi,
429
505
  self.controls,
430
506
  self.population,
@@ -438,7 +514,7 @@ class InputData:
438
514
  self.reach,
439
515
  self.frequency,
440
516
  self.rf_spend,
441
- ]
517
+ )
442
518
 
443
519
  for array, name in zip(arrays, constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES):
444
520
  if array is not None and array.name != name:
@@ -479,7 +555,6 @@ class InputData:
479
555
  [
480
556
  [constants.RF_CHANNEL],
481
557
  [constants.GEO, constants.TIME, constants.RF_CHANNEL],
482
- [constants.GEO, constants.RF_CHANNEL],
483
558
  ],
484
559
  )
485
560
  _check_dim_collection(
@@ -848,3 +923,24 @@ class InputData:
848
923
  return self.media_spend.values
849
924
  else:
850
925
  raise ValueError("Both RF and Media are missing.")
926
+
927
+ def get_total_outcome(self) -> np.ndarray:
928
+ """Returns total outcome, aggregated over geos and times."""
929
+ if self.revenue_per_kpi is None:
930
+ return np.sum(self.kpi.values)
931
+ return np.sum(self.kpi.values * self.revenue_per_kpi.values)
932
+
933
+ def _allocate_spend(self, spend: xr.DataArray, media_units: xr.DataArray):
934
+ """Allocates spend across geo and time proportionally to media units."""
935
+ n_times = len(self.kpi.coords[constants.TIME])
936
+ selected_media_units = media_units.isel(media_time=slice(-n_times, None))
937
+ total_media_units_per_channel = selected_media_units.sum(
938
+ dim=["geo", "media_time"]
939
+ )
940
+ proportions = selected_media_units / total_media_units_per_channel
941
+ expanded_spend = spend.expand_dims({
942
+ "geo": selected_media_units["geo"],
943
+ "media_time": selected_media_units["media_time"],
944
+ })
945
+ allocated_spend = expanded_spend * proportions
946
+ return allocated_spend.rename({"media_time": "time"})
@@ -65,12 +65,24 @@ _REQUIRED_COORDS = immutabledict.immutabledict({
65
65
  c.MEDIA_TIME: _sample_times(n_times=3),
66
66
  c.CONTROL_VARIABLE: ['control_0', 'control_1'],
67
67
  })
68
+ _NON_MEDIA_COORDS = immutabledict.immutabledict(
69
+ {c.NON_MEDIA_CHANNEL: ['non_media_channel_0', 'non_media_channel_1']}
70
+ )
68
71
  _MEDIA_COORDS = immutabledict.immutabledict(
69
72
  {c.MEDIA_CHANNEL: ['media_channel_0', 'media_channel_1', 'media_channel_2']}
70
73
  )
74
+ _ORGANIC_MEDIA_COORDS = immutabledict.immutabledict({
75
+ c.ORGANIC_MEDIA_CHANNEL: [
76
+ 'organic_media_channel_0',
77
+ 'organic_media_channel_1',
78
+ ]
79
+ })
71
80
  _RF_COORDS = immutabledict.immutabledict(
72
81
  {c.RF_CHANNEL: ['rf_channel_0', 'rf_channel_1']}
73
82
  )
83
+ _ORGANIC_RF_COORDS = immutabledict.immutabledict(
84
+ {c.ORGANIC_RF_CHANNEL: ['organic_rf_channel_0', 'organic_rf_channel_1']}
85
+ )
74
86
 
75
87
  _REQUIRED_DATA_VARS = immutabledict.immutabledict({
76
88
  c.KPI: (['geo', 'time'], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
@@ -376,6 +388,70 @@ DATASET_WITHOUT_TIME_VARIATION_IN_REACH = xr.Dataset(
376
388
  },
377
389
  )
378
390
 
391
+ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
392
+ coords=_REQUIRED_COORDS
393
+ | _MEDIA_COORDS
394
+ | _RF_COORDS
395
+ | _ORGANIC_MEDIA_COORDS,
396
+ data_vars=_REQUIRED_DATA_VARS
397
+ | _MEDIA_DATA_VARS
398
+ | _RF_DATA_VARS
399
+ | _OPTIONAL_DATA_VARS
400
+ | {
401
+ c.ORGANIC_MEDIA: (
402
+ ['geo', 'media_time', 'organic_media_channel'],
403
+ [
404
+ [[2.1, 2.2], [2.1, 2.21], [2.1, 2.2]],
405
+ [[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
406
+ ],
407
+ ),
408
+ },
409
+ )
410
+
411
+ DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
412
+ coords=_REQUIRED_COORDS
413
+ | _MEDIA_COORDS
414
+ | _RF_COORDS
415
+ | _ORGANIC_RF_COORDS,
416
+ data_vars=_REQUIRED_DATA_VARS
417
+ | _MEDIA_DATA_VARS
418
+ | _RF_DATA_VARS
419
+ | _OPTIONAL_DATA_VARS
420
+ | {
421
+ c.ORGANIC_REACH: (
422
+ ['geo', 'media_time', 'organic_rf_channel'],
423
+ [
424
+ [[2.1, 2.2], [2.11, 2.2], [2.1, 2.2]],
425
+ [[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
426
+ ],
427
+ ),
428
+ c.ORGANIC_FREQUENCY: (
429
+ ['geo', 'media_time', 'organic_rf_channel'],
430
+ [
431
+ [[7.1, 7.2], [7.3, 7.4], [7.5, 7.6]],
432
+ [[7.11, 7.21], [7.31, 7.41], [7.51, 7.61]],
433
+ ],
434
+ ),
435
+ },
436
+ )
437
+
438
+ DATASET_WITHOUT_TIME_VARIATION_IN_NON_MEDIA_TREATMENTS = xr.Dataset(
439
+ coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _NON_MEDIA_COORDS,
440
+ data_vars=_REQUIRED_DATA_VARS
441
+ | _MEDIA_DATA_VARS
442
+ | _RF_DATA_VARS
443
+ | _OPTIONAL_DATA_VARS
444
+ | {
445
+ c.NON_MEDIA_TREATMENTS: (
446
+ ['geo', 'time', 'non_media_channel'],
447
+ [
448
+ [[2.1, 2.2], [2.1, 2.2], [2.1, 2.2]],
449
+ [[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
450
+ ],
451
+ ),
452
+ },
453
+ )
454
+
379
455
  _NATIONAL_COORDS = immutabledict.immutabledict({
380
456
  c.TIME: [
381
457
  _SAMPLE_START_DATE.strftime(c.DATE_FORMAT),
@@ -1491,17 +1567,52 @@ def sample_input_data_from_dataset(
1491
1567
  dataset: xr.Dataset, kpi_type: str
1492
1568
  ) -> input_data.InputData:
1493
1569
  """Generates a sample `InputData` from a full xarray Dataset."""
1570
+ media = dataset.media if c.MEDIA in dataset.data_vars.keys() else None
1571
+ media_spend = (
1572
+ dataset.media_spend if c.MEDIA_SPEND in dataset.data_vars.keys() else None
1573
+ )
1574
+ reach = dataset.reach if c.REACH in dataset.data_vars.keys() else None
1575
+ frequency = (
1576
+ dataset.frequency if c.FREQUENCY in dataset.data_vars.keys() else None
1577
+ )
1578
+ rf_spend = (
1579
+ dataset.rf_spend if c.RF_SPEND in dataset.data_vars.keys() else None
1580
+ )
1581
+ organic_media = (
1582
+ dataset.organic_media
1583
+ if c.ORGANIC_MEDIA in dataset.data_vars.keys()
1584
+ else None
1585
+ )
1586
+ organic_reach = (
1587
+ dataset.organic_reach
1588
+ if c.ORGANIC_REACH in dataset.data_vars.keys()
1589
+ else None
1590
+ )
1591
+ organic_frequency = (
1592
+ dataset.organic_frequency
1593
+ if c.ORGANIC_FREQUENCY in dataset.data_vars.keys()
1594
+ else None
1595
+ )
1596
+ non_media_treatments = (
1597
+ dataset.non_media_treatments
1598
+ if c.NON_MEDIA_TREATMENTS in dataset.data_vars.keys()
1599
+ else None
1600
+ )
1494
1601
  return input_data.InputData(
1495
1602
  kpi=dataset.kpi,
1496
1603
  kpi_type=kpi_type,
1497
1604
  revenue_per_kpi=dataset.revenue_per_kpi,
1498
1605
  population=dataset.population,
1499
1606
  controls=dataset.controls,
1500
- media=dataset.media,
1501
- media_spend=dataset.media_spend,
1502
- reach=dataset.reach,
1503
- frequency=dataset.frequency,
1504
- rf_spend=dataset.rf_spend,
1607
+ media=media,
1608
+ media_spend=media_spend,
1609
+ reach=reach,
1610
+ frequency=frequency,
1611
+ rf_spend=rf_spend,
1612
+ organic_media=organic_media,
1613
+ organic_reach=organic_reach,
1614
+ organic_frequency=organic_frequency,
1615
+ non_media_treatments=non_media_treatments,
1505
1616
  )
1506
1617
 
1507
1618
 
@@ -36,7 +36,7 @@ __all__ = [
36
36
 
37
37
 
38
38
  # A type alias for a polymorphic "date" type.
39
- Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64
39
+ Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64 | None
40
40
 
41
41
  # A type alias for a polymorphic "date interval" type. In all variants it is
42
42
  # always a tuple of (start_date, end_date).
@@ -236,8 +236,8 @@ class TimeCoordinates:
236
236
 
237
237
  def expand_selected_time_dims(
238
238
  self,
239
- start_date: Date | None = None,
240
- end_date: Date | None = None,
239
+ start_date: Date = None,
240
+ end_date: Date = None,
241
241
  ) -> list[datetime.date] | None:
242
242
  """Validates and returns time dimension values based on the selected times.
243
243