google-meridian 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@ def _check_dim_collection(
60
60
  )
61
61
 
62
62
 
63
- def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray]):
63
+ def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray | None]):
64
64
  """Verifies that the dimensions of the appropriate arrays match."""
65
65
  lengths = [len(array.coords[dim]) for array in arrays if array is not None]
66
66
  names = [array.name for array in arrays if array is not None]
@@ -83,6 +83,31 @@ def _check_coords_match(dim: str, arrays: Sequence[xr.DataArray]):
83
83
  )
84
84
 
85
85
 
86
+ def _aggregate_spend(
87
+ spend: xr.DataArray, calibration_period: np.ndarray | None
88
+ ) -> np.ndarray | None:
89
+ """Aggregates spend for each channel over the calibration period.
90
+
91
+ Args:
92
+ spend: An array with shape `(n_geos, n_times, n_channels)` to aggregate.
93
+ calibration_period: An optional boolean array of shape `(n_media_times,
94
+ n_channels)`. If provided, spend is filtered according to this period.
95
+
96
+ Returns:
97
+ A 1-D array of aggregated media spend per channel, or `None` if `spend` is
98
+ `None`.
99
+ """
100
+ if spend is None:
101
+ return None
102
+
103
+ if calibration_period is None:
104
+ return np.sum(spend, axis=(0, 1))
105
+
106
+ # Select the last `n_times` from the `calibration_period`
107
+ factors = np.where(calibration_period[-spend.shape[1] :, :], 1, 0)
108
+ return np.einsum("gtm,tm->m", spend, factors)
109
+
110
+
86
111
  @dataclasses.dataclass
87
112
  class InputData:
88
113
  """A data container for advertising data in a format supported by Meridian.
@@ -96,11 +121,11 @@ class InputData:
96
121
  `revenue_per_kpi` exists, ROI calibration is used and the analysis is run
97
122
  on revenue. When the `revenue_per_kpi` doesn't exist for the same
98
123
  `kpi_type`, custom ROI calibration is used and the analysis is run on KPI.
99
- controls: A DataArray of dimensions `(n_geos, n_times, n_controls)`
100
- containing control variable values.
101
124
  population: A DataArray of dimensions `(n_geos,)` containing the population
102
125
  of each group. This variable is used to scale the KPI and media for
103
126
  modeling.
127
+ controls: An optional DataArray of dimensions `(n_geos, n_times,
128
+ n_controls)` containing control variable values.
104
129
  revenue_per_kpi: An optional DataArray of dimensions `(n_geos, n_times)`
105
130
  containing the average revenue amount per KPI unit. Although modeling is
106
131
  done on `kpi`, model analysis and optimization are done on `KPI *
@@ -120,8 +145,14 @@ class InputData:
120
145
  in the same order. If either of these arguments is passed, then the other
121
146
  is not optional.
122
147
  media_spend: An optional `DataArray` containing the cost of each media
123
- channel. This is used as the denominator for ROI calculations. The
124
- DataArray shape can be `(n_geos, n_times, n_media_channels)` or
148
+ channel. This is used as the denominator for ROI calculations. It is also
149
+ used to calculate an assumed cost per media unit for post-modeling
150
+ analysis such as response curves and budget optimization. Only the
151
+ aggregate spend (across geos and time periods) is required for these
152
+ calculations. However, a spend breakdown by geo and time period is
153
+ required if `roi_calibration_period` is specified or if conducting
154
+ post-modeling analysis on a specific subset of geos and/or time periods.
155
+ The DataArray shape can be `(n_geos, n_times, n_media_channels)` or
125
156
  `(n_media_channels,)` if the data is aggregated over `geo` and `time`
126
157
  dimensions. We recommend that the spend total aligns with the time window
127
158
  of the `kpi` and `controls` data, which is the time window over which
@@ -131,7 +162,9 @@ class InputData:
131
162
  time window of media executed during the time window. `media` and
132
163
  `media_spend` must contain the same number of media channels in the same
133
164
  order. If either of these arguments is passed, then the other is not
134
- optional.
165
+ optional. If a tensor of shape `(n_media_channels,)` is passed as
166
+ `media_spend`, then it will be automatically allocated across geos and
167
+ times proportionally to `media`.
135
168
  reach: An optional `DataArray` of dimensions `(n_geos, n_media_times,
136
169
  n_rf_channels)` containing non-negative `reach` values. It is required
137
170
  that `n_media_times` ≥ `n_times`, and the final `n_times` time periods
@@ -164,18 +197,26 @@ class InputData:
164
197
  others are not optional.
165
198
  rf_spend: An optional `DataArray` containing the cost of each reach and
166
199
  frequency channel. This is used as the denominator for ROI calculations.
167
- The DataArray shape can be `(n_rf_channels,)`, `(n_geos, n_times,
168
- n_rf_channels)`, or `(n_geos, n_rf_channels)`. The spend should be
169
- aggregated over geo and/or time dimensions that are not represented. We
170
- recommend that the spend total aligns with the time window of the `kpi`
171
- and `controls` data, which is the time window over which incremental
172
- outcome of the ROI numerator is calculated. However, note that incremental
173
- outcome is influenced by media execution prior to this time window,
174
- through lagged effects, and excludes lagged effects beyond the time window
175
- of media executed during the time window. If only `media` data is used,
176
- `rf_spend` will be `None`. `reach`, `frequency`, and `rf_spend` must
177
- contain the same number of media channels in the same order. If any of
178
- these arguments is passed, then the others are not optional.
200
+ It is also used to calculate an assumed cost per media unit for
201
+ post-modeling analysis such as response curves and budget optimization.
202
+ Only the aggregate spend (across geos and time periods) is required for
203
+ these calculations. However, a spend breakdown by geo and time period is
204
+ required if `rf_roi_calibration_period` is specified or if conducting
205
+ post-modeling analysis on a specific subset of geos and/or time periods.
206
+ The DataArray shape can be `(n_rf_channels,)` or `(n_geos, n_times,
207
+ n_rf_channels)`. The spend should be aggregated over geo and/or time
208
+ dimensions that are not represented. We recommend that the spend total
209
+ aligns with the time window of the `kpi` and `controls` data, which is the
210
+ time window over which incremental outcome of the ROI numerator is
211
+ calculated. However, note that incremental outcome is influenced by media
212
+ execution prior to this time window, through lagged effects, and excludes
213
+ lagged effects beyond the time window of media executed during the time
214
+ window. If only `media` data is used, `rf_spend` will be `None`. `reach`,
215
+ `frequency`, and `rf_spend` must contain the same number of media channels
216
+ in the same order. If any of these arguments is passed, then the others
217
+ are not optional. If a tensor of shape `(n_rf_channels,)` is passed as
218
+ `rf_spend`, then it will be automatically allocated across geos and times
219
+ proportionally to `(reach * frequency)`.
179
220
  organic_media: An optional `DataArray` of dimensions `(n_geos,
180
221
  n_media_times, n_organic_media_channels)` containing non-negative organic
181
222
  media values. Organic media variables are media activities that have no
@@ -234,8 +275,8 @@ class InputData:
234
275
 
235
276
  kpi: xr.DataArray
236
277
  kpi_type: str
237
- controls: xr.DataArray
238
278
  population: xr.DataArray
279
+ controls: xr.DataArray | None = None
239
280
  revenue_per_kpi: xr.DataArray | None = None
240
281
  media: xr.DataArray | None = None
241
282
  media_spend: xr.DataArray | None = None
@@ -265,6 +306,40 @@ class InputData:
265
306
  if isinstance(array, xr.DataArray) and constants.GEO in array.dims:
266
307
  array.coords[constants.GEO] = array.coords[constants.GEO].astype(str)
267
308
 
309
+ # TODO: b/416775065 - Combine with Analyzer._impute_and_aggregate_spend
310
+ @functools.cached_property
311
+ def allocated_media_spend(self) -> xr.DataArray | None:
312
+ """Returns the allocated media spend for each geo and time."""
313
+ if self.media_spend is not None and len(self.media_spend.shape) == 1:
314
+ return self._allocate_spend(self.media_spend, self.media)
315
+ else:
316
+ return self.media_spend
317
+
318
+ @property
319
+ def allocated_rf_spend(self) -> xr.DataArray | None:
320
+ """Returns the allocated RF spend for each geo and time."""
321
+ if self.rf_spend is not None and len(self.rf_spend.shape) == 1:
322
+ return self._allocate_spend(self.rf_spend, self.reach * self.frequency)
323
+ else:
324
+ return self.rf_spend
325
+
326
+ def aggregate_media_spend(
327
+ self, calibration_period: np.ndarray | None = None
328
+ ) -> np.ndarray | None:
329
+ """Aggregates media spend by channel over the calibration period."""
330
+ return _aggregate_spend(
331
+ spend=self.allocated_media_spend, calibration_period=calibration_period
332
+ )
333
+
334
+ def aggregate_rf_spend(
335
+ self, calibration_period: np.ndarray | None = None
336
+ ) -> np.ndarray | None:
337
+ """Aggregates RF spend by channel over the calibration period."""
338
+ return _aggregate_spend(
339
+ spend=self.allocated_rf_spend,
340
+ calibration_period=calibration_period,
341
+ )
342
+
268
343
  @property
269
344
  def geo(self) -> xr.DataArray:
270
345
  """Returns the geo dimension."""
@@ -334,9 +409,12 @@ class InputData:
334
409
  return None
335
410
 
336
411
  @property
337
- def control_variable(self) -> xr.DataArray:
412
+ def control_variable(self) -> xr.DataArray | None:
338
413
  """Returns the control variable dimension."""
339
- return self.controls[constants.CONTROL_VARIABLE]
414
+ if self.controls is not None:
415
+ return self.controls[constants.CONTROL_VARIABLE]
416
+ else:
417
+ return None
340
418
 
341
419
  @property
342
420
  def media_spend_has_geo_dimension(self) -> bool:
@@ -424,10 +502,11 @@ class InputData:
424
502
 
425
503
  def _validate_names(self):
426
504
  """Verifies that the names of the data arrays are correct."""
427
- arrays = [
505
+ # Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
506
+ arrays = (
428
507
  self.kpi,
429
- self.controls,
430
508
  self.population,
509
+ self.controls,
431
510
  self.revenue_per_kpi,
432
511
  self.organic_media,
433
512
  self.organic_reach,
@@ -438,7 +517,7 @@ class InputData:
438
517
  self.reach,
439
518
  self.frequency,
440
519
  self.rf_spend,
441
- ]
520
+ )
442
521
 
443
522
  for array, name in zip(arrays, constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES):
444
523
  if array is not None and array.name != name:
@@ -479,7 +558,6 @@ class InputData:
479
558
  [
480
559
  [constants.RF_CHANNEL],
481
560
  [constants.GEO, constants.TIME, constants.RF_CHANNEL],
482
- [constants.GEO, constants.RF_CHANNEL],
483
561
  ],
484
562
  )
485
563
  _check_dim_collection(
@@ -711,9 +789,10 @@ class InputData:
711
789
  """Returns data as a single `xarray.Dataset` object."""
712
790
  data = [
713
791
  self.kpi,
714
- self.controls,
715
792
  self.population,
716
793
  ]
794
+ if self.controls is not None:
795
+ data.append(self.controls)
717
796
  if self.revenue_per_kpi is not None:
718
797
  data.append(self.revenue_per_kpi)
719
798
  if self.media is not None:
@@ -848,3 +927,24 @@ class InputData:
848
927
  return self.media_spend.values
849
928
  else:
850
929
  raise ValueError("Both RF and Media are missing.")
930
+
931
+ def get_total_outcome(self) -> np.ndarray:
932
+ """Returns total outcome, aggregated over geos and times."""
933
+ if self.revenue_per_kpi is None:
934
+ return np.sum(self.kpi.values)
935
+ return np.sum(self.kpi.values * self.revenue_per_kpi.values)
936
+
937
+ def _allocate_spend(self, spend: xr.DataArray, media_units: xr.DataArray):
938
+ """Allocates spend across geo and time proportionally to media units."""
939
+ n_times = len(self.kpi.coords[constants.TIME])
940
+ selected_media_units = media_units.isel(media_time=slice(-n_times, None))
941
+ total_media_units_per_channel = selected_media_units.sum(
942
+ dim=["geo", "media_time"]
943
+ )
944
+ proportions = selected_media_units / total_media_units_per_channel
945
+ expanded_spend = spend.expand_dims({
946
+ "geo": selected_media_units["geo"],
947
+ "media_time": selected_media_units["media_time"],
948
+ })
949
+ allocated_spend = expanded_spend * proportions
950
+ return allocated_spend.rename({"media_time": "time"})
meridian/data/load.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@ class XrDatasetDataLoader(InputDataLoader):
79
79
  """Constructor.
80
80
 
81
81
  The coordinates of the input dataset should be: `time`, `media_time`,
82
- `control_variable`, `geo` (optional for a national model),
82
+ `control_variable` (optional), `geo` (optional for a national model),
83
83
  `non_media_channel` (optional), `organic_media_channel` (optional),
84
84
  `organic_rf_channel` (optional), and
85
85
  either `media_channel`, `rf_channel`, or both.
@@ -93,7 +93,7 @@ class XrDatasetDataLoader(InputDataLoader):
93
93
 
94
94
  * `kpi`: `(geo, time)`
95
95
  * `revenue_per_kpi`: `(geo, time)`
96
- * `controls`: `(geo, time, control_variable)`
96
+ * `controls`: `(geo, time, control_variable)` - optional
97
97
  * `population`: `(geo)`
98
98
  * `media`: `(geo, media_time, media_channel)` - optional
99
99
  * `media_spend`: `(geo, time, media_channel)`, `(1, time, media_channel)`,
@@ -113,7 +113,7 @@ class XrDatasetDataLoader(InputDataLoader):
113
113
 
114
114
  * `kpi`: `([1,] time)`
115
115
  * `revenue_per_kpi`: `([1,] time)`
116
- * `controls`: `([1,] time, control_variable)`
116
+ * `controls`: `([1,] time, control_variable)` - optional
117
117
  * `population`: `([1],)` - this array is optional for national data
118
118
  * `media`: `([1,] media_time, media_channel)` - optional
119
119
  * `media_spend`: `([1,] time, media_channel)` or
@@ -198,7 +198,7 @@ class XrDatasetDataLoader(InputDataLoader):
198
198
  self.dataset = dataset.rename(name_mapping)
199
199
 
200
200
  # Add a `geo` dimension if it is not already present.
201
- if (constants.GEO) not in self.dataset.dims.keys():
201
+ if (constants.GEO) not in self.dataset.sizes.keys():
202
202
  self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
203
203
 
204
204
  if len(self.dataset.coords[constants.GEO]) == 1:
@@ -228,7 +228,7 @@ class XrDatasetDataLoader(InputDataLoader):
228
228
  compat='override',
229
229
  )
230
230
 
231
- if constants.MEDIA_TIME not in self.dataset.dims.keys():
231
+ if constants.MEDIA_TIME not in self.dataset.sizes.keys():
232
232
  self._add_media_time()
233
233
  self._normalize_time_coordinates(constants.TIME)
234
234
  self._normalize_time_coordinates(constants.MEDIA_TIME)
@@ -349,14 +349,17 @@ class XrDatasetDataLoader(InputDataLoader):
349
349
  # Arrays in which NAs are expected in the lagged-media period.
350
350
  na_arrays = [
351
351
  constants.KPI,
352
- constants.CONTROLS,
353
352
  ]
354
353
 
355
- na_mask = self.dataset[constants.KPI].isnull().any(
356
- dim=constants.GEO
357
- ) | self.dataset[constants.CONTROLS].isnull().any(
358
- dim=[constants.GEO, constants.CONTROL_VARIABLE]
359
- )
354
+ na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
355
+
356
+ if constants.CONTROLS in self.dataset.data_vars.keys():
357
+ na_arrays.append(constants.CONTROLS)
358
+ na_mask |= (
359
+ self.dataset[constants.CONTROLS]
360
+ .isnull()
361
+ .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
362
+ )
360
363
 
361
364
  if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
362
365
  na_arrays.append(constants.NON_MEDIA_TREATMENTS)
@@ -427,11 +430,12 @@ class XrDatasetDataLoader(InputDataLoader):
427
430
  .dropna(dim=constants.TIME)
428
431
  .rename({constants.TIME: new_time})
429
432
  )
430
- new_dataset[constants.CONTROLS] = (
431
- new_dataset[constants.CONTROLS]
432
- .dropna(dim=constants.TIME)
433
- .rename({constants.TIME: new_time})
434
- )
433
+ if constants.CONTROLS in new_dataset.data_vars.keys():
434
+ new_dataset[constants.CONTROLS] = (
435
+ new_dataset[constants.CONTROLS]
436
+ .dropna(dim=constants.TIME)
437
+ .rename({constants.TIME: new_time})
438
+ )
435
439
  if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
436
440
  new_dataset[constants.NON_MEDIA_TREATMENTS] = (
437
441
  new_dataset[constants.NON_MEDIA_TREATMENTS]
@@ -466,6 +470,11 @@ class XrDatasetDataLoader(InputDataLoader):
466
470
 
467
471
  def load(self) -> input_data.InputData:
468
472
  """Returns an `InputData` object containing the data from the dataset."""
473
+ controls = (
474
+ self.dataset.controls
475
+ if constants.CONTROLS in self.dataset.data_vars.keys()
476
+ else None
477
+ )
469
478
  revenue_per_kpi = (
470
479
  self.dataset.revenue_per_kpi
471
480
  if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
@@ -519,9 +528,9 @@ class XrDatasetDataLoader(InputDataLoader):
519
528
  return input_data.InputData(
520
529
  kpi=self.dataset.kpi,
521
530
  kpi_type=self.kpi_type,
522
- revenue_per_kpi=revenue_per_kpi,
523
- controls=self.dataset.controls,
524
531
  population=self.dataset.population,
532
+ controls=controls,
533
+ revenue_per_kpi=revenue_per_kpi,
525
534
  media=media,
526
535
  media_spend=media_spend,
527
536
  reach=reach,
@@ -539,14 +548,14 @@ class CoordToColumns:
539
548
  """A mapping between the desired and actual column names in the input data.
540
549
 
541
550
  Attributes:
542
- controls: List of column names containing `controls` values in the input
543
- data.
544
551
  time: Name of column containing `time` values in the input data.
545
- kpi: Name of column containing `kpi` values in the input data.
546
- revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
547
- input data.
548
552
  geo: Name of column containing `geo` values in the input data. This field
549
553
  is optional for a national model.
554
+ kpi: Name of column containing `kpi` values in the input data.
555
+ controls: List of column names containing `controls` values in the input
556
+ data. Optional.
557
+ revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
558
+ input data. Optional. Will be overridden if model KPI type is "revenue".
550
559
  population: Name of column containing `population` values in the input data.
551
560
  This field is optional for a national model.
552
561
  media: List of column names containing `media` values in the input data.
@@ -567,11 +576,11 @@ class CoordToColumns:
567
576
  values in the input data.
568
577
  """
569
578
 
570
- controls: Sequence[str]
571
579
  time: str = constants.TIME
580
+ geo: str = constants.GEO
572
581
  kpi: str = constants.KPI
582
+ controls: Sequence[str] | None = None
573
583
  revenue_per_kpi: str | None = None
574
- geo: str = constants.GEO
575
584
  population: str = constants.POPULATION
576
585
  # Media data
577
586
  media: Sequence[str] | None = None
@@ -607,7 +616,7 @@ class DataFrameDataLoader(InputDataLoader):
607
616
  to the DataFrame column names if they are different. The fields are:
608
617
 
609
618
  * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
610
- * `controls` (multiple columns)
619
+ * `controls` (multiple columns, optional)
611
620
  * (1) `media`, `media_spend` (multiple columns)
612
621
  * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
613
622
  * `non_media_treatments` (multiple columns, optional)
@@ -953,9 +962,10 @@ class DataFrameDataLoader(InputDataLoader):
953
962
  not_lagged_columns = []
954
963
  coords = [
955
964
  constants.KPI,
956
- constants.CONTROLS,
957
965
  constants.POPULATION,
958
966
  ]
967
+ if self.coord_to_columns.controls is not None:
968
+ coords.append(constants.CONTROLS)
959
969
  if self.coord_to_columns.revenue_per_kpi is not None:
960
970
  coords.append(constants.REVENUE_PER_KPI)
961
971
  if self.coord_to_columns.media_spend is not None:
@@ -1042,17 +1052,20 @@ class DataFrameDataLoader(InputDataLoader):
1042
1052
  .to_frame()
1043
1053
  .to_xarray()
1044
1054
  )
1045
- controls_xr = (
1046
- df_indexed[self.coord_to_columns.controls]
1047
- .stack()
1048
- .rename(constants.CONTROLS)
1049
- .rename_axis(
1050
- [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
1051
- )
1052
- .to_frame()
1053
- .to_xarray()
1054
- )
1055
- dataset = xr.combine_by_coords([kpi_xr, population_xr, controls_xr])
1055
+ dataset = xr.combine_by_coords([kpi_xr, population_xr])
1056
+
1057
+ if self.coord_to_columns.controls is not None:
1058
+ controls_xr = (
1059
+ df_indexed[self.coord_to_columns.controls]
1060
+ .stack()
1061
+ .rename(constants.CONTROLS)
1062
+ .rename_axis(
1063
+ [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
1064
+ )
1065
+ .to_frame()
1066
+ .to_xarray()
1067
+ )
1068
+ dataset = xr.combine_by_coords([dataset, controls_xr])
1056
1069
 
1057
1070
  if self.coord_to_columns.non_media_treatments is not None:
1058
1071
  non_media_xr = (
@@ -1224,7 +1237,7 @@ class CsvDataLoader(InputDataLoader):
1224
1237
  CSV column names, if they are different. The fields are:
1225
1238
 
1226
1239
  * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
1227
- * `controls` (multiple columns)
1240
+ * `controls` (multiple columns, optional)
1228
1241
  * (1) `media`, `media_spend` (multiple columns)
1229
1242
  * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
1230
1243
  * `non_media_treatments` (multiple columns, optional)