google-meridian 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.1.dist-info}/METADATA +2 -2
- google_meridian-1.1.1.dist-info/RECORD +41 -0
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.1.dist-info}/WHEEL +1 -1
- meridian/__init__.py +2 -2
- meridian/analysis/__init__.py +1 -1
- meridian/analysis/analyzer.py +213 -206
- meridian/analysis/formatter.py +1 -1
- meridian/analysis/optimizer.py +264 -66
- meridian/analysis/summarizer.py +5 -5
- meridian/analysis/summary_text.py +1 -1
- meridian/analysis/test_utils.py +82 -82
- meridian/analysis/visualizer.py +14 -19
- meridian/constants.py +103 -19
- meridian/data/__init__.py +1 -1
- meridian/data/arg_builder.py +1 -1
- meridian/data/input_data.py +127 -27
- meridian/data/load.py +53 -40
- meridian/data/test_utils.py +172 -44
- meridian/data/time_coordinates.py +4 -4
- meridian/model/__init__.py +1 -1
- meridian/model/adstock_hill.py +1 -1
- meridian/model/knots.py +1 -1
- meridian/model/media.py +134 -99
- meridian/model/model.py +494 -84
- meridian/model/model_test_data.py +86 -1
- meridian/model/posterior_sampler.py +139 -58
- meridian/model/prior_distribution.py +97 -52
- meridian/model/prior_sampler.py +209 -233
- meridian/model/spec.py +197 -37
- meridian/model/transformers.py +16 -4
- google_meridian-1.0.9.dist-info/RECORD +0 -41
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.1.dist-info}/top_level.txt +0 -0
meridian/data/input_data.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -60,7 +60,7 @@ def _check_dim_collection(
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray]):
|
|
63
|
+
def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray | None]):
|
|
64
64
|
"""Verifies that the dimensions of the appropriate arrays match."""
|
|
65
65
|
lengths = [len(array.coords[dim]) for array in arrays if array is not None]
|
|
66
66
|
names = [array.name for array in arrays if array is not None]
|
|
@@ -83,6 +83,31 @@ def _check_coords_match(dim: str, arrays: Sequence[xr.DataArray]):
|
|
|
83
83
|
)
|
|
84
84
|
|
|
85
85
|
|
|
86
|
+
def _aggregate_spend(
|
|
87
|
+
spend: xr.DataArray, calibration_period: np.ndarray | None
|
|
88
|
+
) -> np.ndarray | None:
|
|
89
|
+
"""Aggregates spend for each channel over the calibration period.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
spend: An array with shape `(n_geos, n_times, n_channels)` to aggregate.
|
|
93
|
+
calibration_period: An optional boolean array of shape `(n_media_times,
|
|
94
|
+
n_channels)`. If provided, spend is filtered according to this period.
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
A 1-D array of aggregated media spend per channel, or `None` if `spend` is
|
|
98
|
+
`None`.
|
|
99
|
+
"""
|
|
100
|
+
if spend is None:
|
|
101
|
+
return None
|
|
102
|
+
|
|
103
|
+
if calibration_period is None:
|
|
104
|
+
return np.sum(spend, axis=(0, 1))
|
|
105
|
+
|
|
106
|
+
# Select the last `n_times` from the `calibration_period`
|
|
107
|
+
factors = np.where(calibration_period[-spend.shape[1] :, :], 1, 0)
|
|
108
|
+
return np.einsum("gtm,tm->m", spend, factors)
|
|
109
|
+
|
|
110
|
+
|
|
86
111
|
@dataclasses.dataclass
|
|
87
112
|
class InputData:
|
|
88
113
|
"""A data container for advertising data in a format supported by Meridian.
|
|
@@ -96,11 +121,11 @@ class InputData:
|
|
|
96
121
|
`revenue_per_kpi` exists, ROI calibration is used and the analysis is run
|
|
97
122
|
on revenue. When the `revenue_per_kpi` doesn't exist for the same
|
|
98
123
|
`kpi_type`, custom ROI calibration is used and the analysis is run on KPI.
|
|
99
|
-
controls: A DataArray of dimensions `(n_geos, n_times, n_controls)`
|
|
100
|
-
containing control variable values.
|
|
101
124
|
population: A DataArray of dimensions `(n_geos,)` containing the population
|
|
102
125
|
of each group. This variable is used to scale the KPI and media for
|
|
103
126
|
modeling.
|
|
127
|
+
controls: An optional DataArray of dimensions `(n_geos, n_times,
|
|
128
|
+
n_controls)` containing control variable values.
|
|
104
129
|
revenue_per_kpi: An optional DataArray of dimensions `(n_geos, n_times)`
|
|
105
130
|
containing the average revenue amount per KPI unit. Although modeling is
|
|
106
131
|
done on `kpi`, model analysis and optimization are done on `KPI *
|
|
@@ -120,8 +145,14 @@ class InputData:
|
|
|
120
145
|
in the same order. If either of these arguments is passed, then the other
|
|
121
146
|
is not optional.
|
|
122
147
|
media_spend: An optional `DataArray` containing the cost of each media
|
|
123
|
-
channel. This is used as the denominator for ROI calculations.
|
|
124
|
-
|
|
148
|
+
channel. This is used as the denominator for ROI calculations. It is also
|
|
149
|
+
used to calculate an assumed cost per media unit for post-modeling
|
|
150
|
+
analysis such as response curves and budget optimization. Only the
|
|
151
|
+
aggregate spend (across geos and time periods) is required for these
|
|
152
|
+
calculations. However, a spend breakdown by geo and time period is
|
|
153
|
+
required if `roi_calibration_period` is specified or if conducting
|
|
154
|
+
post-modeling analysis on a specific subset of geos and/or time periods.
|
|
155
|
+
The DataArray shape can be `(n_geos, n_times, n_media_channels)` or
|
|
125
156
|
`(n_media_channels,)` if the data is aggregated over `geo` and `time`
|
|
126
157
|
dimensions. We recommend that the spend total aligns with the time window
|
|
127
158
|
of the `kpi` and `controls` data, which is the time window over which
|
|
@@ -131,7 +162,9 @@ class InputData:
|
|
|
131
162
|
time window of media executed during the time window. `media` and
|
|
132
163
|
`media_spend` must contain the same number of media channels in the same
|
|
133
164
|
order. If either of these arguments is passed, then the other is not
|
|
134
|
-
optional.
|
|
165
|
+
optional. If a tensor of shape `(n_media_channels,)` is passed as
|
|
166
|
+
`media_spend`, then it will be automatically allocated across geos and
|
|
167
|
+
times proportionally to `media`.
|
|
135
168
|
reach: An optional `DataArray` of dimensions `(n_geos, n_media_times,
|
|
136
169
|
n_rf_channels)` containing non-negative `reach` values. It is required
|
|
137
170
|
that `n_media_times` ≥ `n_times`, and the final `n_times` time periods
|
|
@@ -164,18 +197,26 @@ class InputData:
|
|
|
164
197
|
others are not optional.
|
|
165
198
|
rf_spend: An optional `DataArray` containing the cost of each reach and
|
|
166
199
|
frequency channel. This is used as the denominator for ROI calculations.
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
200
|
+
It is also used to calculate an assumed cost per media unit for
|
|
201
|
+
post-modeling analysis such as response curves and budget optimization.
|
|
202
|
+
Only the aggregate spend (across geos and time periods) is required for
|
|
203
|
+
these calculations. However, a spend breakdown by geo and time period is
|
|
204
|
+
required if `rf_roi_calibration_period` is specified or if conducting
|
|
205
|
+
post-modeling analysis on a specific subset of geos and/or time periods.
|
|
206
|
+
The DataArray shape can be `(n_rf_channels,)` or `(n_geos, n_times,
|
|
207
|
+
n_rf_channels)`. The spend should be aggregated over geo and/or time
|
|
208
|
+
dimensions that are not represented. We recommend that the spend total
|
|
209
|
+
aligns with the time window of the `kpi` and `controls` data, which is the
|
|
210
|
+
time window over which incremental outcome of the ROI numerator is
|
|
211
|
+
calculated. However, note that incremental outcome is influenced by media
|
|
212
|
+
execution prior to this time window, through lagged effects, and excludes
|
|
213
|
+
lagged effects beyond the time window of media executed during the time
|
|
214
|
+
window. If only `media` data is used, `rf_spend` will be `None`. `reach`,
|
|
215
|
+
`frequency`, and `rf_spend` must contain the same number of media channels
|
|
216
|
+
in the same order. If any of these arguments is passed, then the others
|
|
217
|
+
are not optional. If a tensor of shape `(n_rf_channels,)` is passed as
|
|
218
|
+
`rf_spend`, then it will be automatically allocated across geos and times
|
|
219
|
+
proportionally to `(reach * frequency)`.
|
|
179
220
|
organic_media: An optional `DataArray` of dimensions `(n_geos,
|
|
180
221
|
n_media_times, n_organic_media_channels)` containing non-negative organic
|
|
181
222
|
media values. Organic media variables are media activities that have no
|
|
@@ -234,8 +275,8 @@ class InputData:
|
|
|
234
275
|
|
|
235
276
|
kpi: xr.DataArray
|
|
236
277
|
kpi_type: str
|
|
237
|
-
controls: xr.DataArray
|
|
238
278
|
population: xr.DataArray
|
|
279
|
+
controls: xr.DataArray | None = None
|
|
239
280
|
revenue_per_kpi: xr.DataArray | None = None
|
|
240
281
|
media: xr.DataArray | None = None
|
|
241
282
|
media_spend: xr.DataArray | None = None
|
|
@@ -265,6 +306,40 @@ class InputData:
|
|
|
265
306
|
if isinstance(array, xr.DataArray) and constants.GEO in array.dims:
|
|
266
307
|
array.coords[constants.GEO] = array.coords[constants.GEO].astype(str)
|
|
267
308
|
|
|
309
|
+
# TODO: b/416775065 - Combine with Analyzer._impute_and_aggregate_spend
|
|
310
|
+
@functools.cached_property
|
|
311
|
+
def allocated_media_spend(self) -> xr.DataArray | None:
|
|
312
|
+
"""Returns the allocated media spend for each geo and time."""
|
|
313
|
+
if self.media_spend is not None and len(self.media_spend.shape) == 1:
|
|
314
|
+
return self._allocate_spend(self.media_spend, self.media)
|
|
315
|
+
else:
|
|
316
|
+
return self.media_spend
|
|
317
|
+
|
|
318
|
+
@property
|
|
319
|
+
def allocated_rf_spend(self) -> xr.DataArray | None:
|
|
320
|
+
"""Returns the allocated RF spend for each geo and time."""
|
|
321
|
+
if self.rf_spend is not None and len(self.rf_spend.shape) == 1:
|
|
322
|
+
return self._allocate_spend(self.rf_spend, self.reach * self.frequency)
|
|
323
|
+
else:
|
|
324
|
+
return self.rf_spend
|
|
325
|
+
|
|
326
|
+
def aggregate_media_spend(
|
|
327
|
+
self, calibration_period: np.ndarray | None = None
|
|
328
|
+
) -> np.ndarray | None:
|
|
329
|
+
"""Aggregates media spend by channel over the calibration period."""
|
|
330
|
+
return _aggregate_spend(
|
|
331
|
+
spend=self.allocated_media_spend, calibration_period=calibration_period
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
def aggregate_rf_spend(
|
|
335
|
+
self, calibration_period: np.ndarray | None = None
|
|
336
|
+
) -> np.ndarray | None:
|
|
337
|
+
"""Aggregates RF spend by channel over the calibration period."""
|
|
338
|
+
return _aggregate_spend(
|
|
339
|
+
spend=self.allocated_rf_spend,
|
|
340
|
+
calibration_period=calibration_period,
|
|
341
|
+
)
|
|
342
|
+
|
|
268
343
|
@property
|
|
269
344
|
def geo(self) -> xr.DataArray:
|
|
270
345
|
"""Returns the geo dimension."""
|
|
@@ -334,9 +409,12 @@ class InputData:
|
|
|
334
409
|
return None
|
|
335
410
|
|
|
336
411
|
@property
|
|
337
|
-
def control_variable(self) -> xr.DataArray:
|
|
412
|
+
def control_variable(self) -> xr.DataArray | None:
|
|
338
413
|
"""Returns the control variable dimension."""
|
|
339
|
-
|
|
414
|
+
if self.controls is not None:
|
|
415
|
+
return self.controls[constants.CONTROL_VARIABLE]
|
|
416
|
+
else:
|
|
417
|
+
return None
|
|
340
418
|
|
|
341
419
|
@property
|
|
342
420
|
def media_spend_has_geo_dimension(self) -> bool:
|
|
@@ -424,10 +502,11 @@ class InputData:
|
|
|
424
502
|
|
|
425
503
|
def _validate_names(self):
|
|
426
504
|
"""Verifies that the names of the data arrays are correct."""
|
|
427
|
-
|
|
505
|
+
# Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
|
|
506
|
+
arrays = (
|
|
428
507
|
self.kpi,
|
|
429
|
-
self.controls,
|
|
430
508
|
self.population,
|
|
509
|
+
self.controls,
|
|
431
510
|
self.revenue_per_kpi,
|
|
432
511
|
self.organic_media,
|
|
433
512
|
self.organic_reach,
|
|
@@ -438,7 +517,7 @@ class InputData:
|
|
|
438
517
|
self.reach,
|
|
439
518
|
self.frequency,
|
|
440
519
|
self.rf_spend,
|
|
441
|
-
|
|
520
|
+
)
|
|
442
521
|
|
|
443
522
|
for array, name in zip(arrays, constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES):
|
|
444
523
|
if array is not None and array.name != name:
|
|
@@ -479,7 +558,6 @@ class InputData:
|
|
|
479
558
|
[
|
|
480
559
|
[constants.RF_CHANNEL],
|
|
481
560
|
[constants.GEO, constants.TIME, constants.RF_CHANNEL],
|
|
482
|
-
[constants.GEO, constants.RF_CHANNEL],
|
|
483
561
|
],
|
|
484
562
|
)
|
|
485
563
|
_check_dim_collection(
|
|
@@ -711,9 +789,10 @@ class InputData:
|
|
|
711
789
|
"""Returns data as a single `xarray.Dataset` object."""
|
|
712
790
|
data = [
|
|
713
791
|
self.kpi,
|
|
714
|
-
self.controls,
|
|
715
792
|
self.population,
|
|
716
793
|
]
|
|
794
|
+
if self.controls is not None:
|
|
795
|
+
data.append(self.controls)
|
|
717
796
|
if self.revenue_per_kpi is not None:
|
|
718
797
|
data.append(self.revenue_per_kpi)
|
|
719
798
|
if self.media is not None:
|
|
@@ -848,3 +927,24 @@ class InputData:
|
|
|
848
927
|
return self.media_spend.values
|
|
849
928
|
else:
|
|
850
929
|
raise ValueError("Both RF and Media are missing.")
|
|
930
|
+
|
|
931
|
+
def get_total_outcome(self) -> np.ndarray:
|
|
932
|
+
"""Returns total outcome, aggregated over geos and times."""
|
|
933
|
+
if self.revenue_per_kpi is None:
|
|
934
|
+
return np.sum(self.kpi.values)
|
|
935
|
+
return np.sum(self.kpi.values * self.revenue_per_kpi.values)
|
|
936
|
+
|
|
937
|
+
def _allocate_spend(self, spend: xr.DataArray, media_units: xr.DataArray):
|
|
938
|
+
"""Allocates spend across geo and time proportionally to media units."""
|
|
939
|
+
n_times = len(self.kpi.coords[constants.TIME])
|
|
940
|
+
selected_media_units = media_units.isel(media_time=slice(-n_times, None))
|
|
941
|
+
total_media_units_per_channel = selected_media_units.sum(
|
|
942
|
+
dim=["geo", "media_time"]
|
|
943
|
+
)
|
|
944
|
+
proportions = selected_media_units / total_media_units_per_channel
|
|
945
|
+
expanded_spend = spend.expand_dims({
|
|
946
|
+
"geo": selected_media_units["geo"],
|
|
947
|
+
"media_time": selected_media_units["media_time"],
|
|
948
|
+
})
|
|
949
|
+
allocated_spend = expanded_spend * proportions
|
|
950
|
+
return allocated_spend.rename({"media_time": "time"})
|
meridian/data/load.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -79,7 +79,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
79
79
|
"""Constructor.
|
|
80
80
|
|
|
81
81
|
The coordinates of the input dataset should be: `time`, `media_time`,
|
|
82
|
-
`control_variable
|
|
82
|
+
`control_variable` (optional), `geo` (optional for a national model),
|
|
83
83
|
`non_media_channel` (optional), `organic_media_channel` (optional),
|
|
84
84
|
`organic_rf_channel` (optional), and
|
|
85
85
|
either `media_channel`, `rf_channel`, or both.
|
|
@@ -93,7 +93,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
93
93
|
|
|
94
94
|
* `kpi`: `(geo, time)`
|
|
95
95
|
* `revenue_per_kpi`: `(geo, time)`
|
|
96
|
-
* `controls`: `(geo, time, control_variable)`
|
|
96
|
+
* `controls`: `(geo, time, control_variable)` - optional
|
|
97
97
|
* `population`: `(geo)`
|
|
98
98
|
* `media`: `(geo, media_time, media_channel)` - optional
|
|
99
99
|
* `media_spend`: `(geo, time, media_channel)`, `(1, time, media_channel)`,
|
|
@@ -113,7 +113,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
113
113
|
|
|
114
114
|
* `kpi`: `([1,] time)`
|
|
115
115
|
* `revenue_per_kpi`: `([1,] time)`
|
|
116
|
-
* `controls`: `([1,] time, control_variable)`
|
|
116
|
+
* `controls`: `([1,] time, control_variable)` - optional
|
|
117
117
|
* `population`: `([1],)` - this array is optional for national data
|
|
118
118
|
* `media`: `([1,] media_time, media_channel)` - optional
|
|
119
119
|
* `media_spend`: `([1,] time, media_channel)` or
|
|
@@ -198,7 +198,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
198
198
|
self.dataset = dataset.rename(name_mapping)
|
|
199
199
|
|
|
200
200
|
# Add a `geo` dimension if it is not already present.
|
|
201
|
-
if (constants.GEO) not in self.dataset.
|
|
201
|
+
if (constants.GEO) not in self.dataset.sizes.keys():
|
|
202
202
|
self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
|
|
203
203
|
|
|
204
204
|
if len(self.dataset.coords[constants.GEO]) == 1:
|
|
@@ -228,7 +228,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
228
228
|
compat='override',
|
|
229
229
|
)
|
|
230
230
|
|
|
231
|
-
if constants.MEDIA_TIME not in self.dataset.
|
|
231
|
+
if constants.MEDIA_TIME not in self.dataset.sizes.keys():
|
|
232
232
|
self._add_media_time()
|
|
233
233
|
self._normalize_time_coordinates(constants.TIME)
|
|
234
234
|
self._normalize_time_coordinates(constants.MEDIA_TIME)
|
|
@@ -349,14 +349,17 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
349
349
|
# Arrays in which NAs are expected in the lagged-media period.
|
|
350
350
|
na_arrays = [
|
|
351
351
|
constants.KPI,
|
|
352
|
-
constants.CONTROLS,
|
|
353
352
|
]
|
|
354
353
|
|
|
355
|
-
na_mask = self.dataset[constants.KPI].isnull().any(
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
354
|
+
na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
|
|
355
|
+
|
|
356
|
+
if constants.CONTROLS in self.dataset.data_vars.keys():
|
|
357
|
+
na_arrays.append(constants.CONTROLS)
|
|
358
|
+
na_mask |= (
|
|
359
|
+
self.dataset[constants.CONTROLS]
|
|
360
|
+
.isnull()
|
|
361
|
+
.any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
|
|
362
|
+
)
|
|
360
363
|
|
|
361
364
|
if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
|
|
362
365
|
na_arrays.append(constants.NON_MEDIA_TREATMENTS)
|
|
@@ -427,11 +430,12 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
427
430
|
.dropna(dim=constants.TIME)
|
|
428
431
|
.rename({constants.TIME: new_time})
|
|
429
432
|
)
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
433
|
+
if constants.CONTROLS in new_dataset.data_vars.keys():
|
|
434
|
+
new_dataset[constants.CONTROLS] = (
|
|
435
|
+
new_dataset[constants.CONTROLS]
|
|
436
|
+
.dropna(dim=constants.TIME)
|
|
437
|
+
.rename({constants.TIME: new_time})
|
|
438
|
+
)
|
|
435
439
|
if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
|
|
436
440
|
new_dataset[constants.NON_MEDIA_TREATMENTS] = (
|
|
437
441
|
new_dataset[constants.NON_MEDIA_TREATMENTS]
|
|
@@ -466,6 +470,11 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
466
470
|
|
|
467
471
|
def load(self) -> input_data.InputData:
|
|
468
472
|
"""Returns an `InputData` object containing the data from the dataset."""
|
|
473
|
+
controls = (
|
|
474
|
+
self.dataset.controls
|
|
475
|
+
if constants.CONTROLS in self.dataset.data_vars.keys()
|
|
476
|
+
else None
|
|
477
|
+
)
|
|
469
478
|
revenue_per_kpi = (
|
|
470
479
|
self.dataset.revenue_per_kpi
|
|
471
480
|
if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
|
|
@@ -519,9 +528,9 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
519
528
|
return input_data.InputData(
|
|
520
529
|
kpi=self.dataset.kpi,
|
|
521
530
|
kpi_type=self.kpi_type,
|
|
522
|
-
revenue_per_kpi=revenue_per_kpi,
|
|
523
|
-
controls=self.dataset.controls,
|
|
524
531
|
population=self.dataset.population,
|
|
532
|
+
controls=controls,
|
|
533
|
+
revenue_per_kpi=revenue_per_kpi,
|
|
525
534
|
media=media,
|
|
526
535
|
media_spend=media_spend,
|
|
527
536
|
reach=reach,
|
|
@@ -539,14 +548,14 @@ class CoordToColumns:
|
|
|
539
548
|
"""A mapping between the desired and actual column names in the input data.
|
|
540
549
|
|
|
541
550
|
Attributes:
|
|
542
|
-
controls: List of column names containing `controls` values in the input
|
|
543
|
-
data.
|
|
544
551
|
time: Name of column containing `time` values in the input data.
|
|
545
|
-
kpi: Name of column containing `kpi` values in the input data.
|
|
546
|
-
revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
|
|
547
|
-
input data.
|
|
548
552
|
geo: Name of column containing `geo` values in the input data. This field
|
|
549
553
|
is optional for a national model.
|
|
554
|
+
kpi: Name of column containing `kpi` values in the input data.
|
|
555
|
+
controls: List of column names containing `controls` values in the input
|
|
556
|
+
data. Optional.
|
|
557
|
+
revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
|
|
558
|
+
input data. Optional. Will be overridden if model KPI type is "revenue".
|
|
550
559
|
population: Name of column containing `population` values in the input data.
|
|
551
560
|
This field is optional for a national model.
|
|
552
561
|
media: List of column names containing `media` values in the input data.
|
|
@@ -567,11 +576,11 @@ class CoordToColumns:
|
|
|
567
576
|
values in the input data.
|
|
568
577
|
"""
|
|
569
578
|
|
|
570
|
-
controls: Sequence[str]
|
|
571
579
|
time: str = constants.TIME
|
|
580
|
+
geo: str = constants.GEO
|
|
572
581
|
kpi: str = constants.KPI
|
|
582
|
+
controls: Sequence[str] | None = None
|
|
573
583
|
revenue_per_kpi: str | None = None
|
|
574
|
-
geo: str = constants.GEO
|
|
575
584
|
population: str = constants.POPULATION
|
|
576
585
|
# Media data
|
|
577
586
|
media: Sequence[str] | None = None
|
|
@@ -607,7 +616,7 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
607
616
|
to the DataFrame column names if they are different. The fields are:
|
|
608
617
|
|
|
609
618
|
* `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
|
|
610
|
-
* `controls` (multiple columns)
|
|
619
|
+
* `controls` (multiple columns, optional)
|
|
611
620
|
* (1) `media`, `media_spend` (multiple columns)
|
|
612
621
|
* (2) `reach`, `frequency`, `rf_spend` (multiple columns)
|
|
613
622
|
* `non_media_treatments` (multiple columns, optional)
|
|
@@ -953,9 +962,10 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
953
962
|
not_lagged_columns = []
|
|
954
963
|
coords = [
|
|
955
964
|
constants.KPI,
|
|
956
|
-
constants.CONTROLS,
|
|
957
965
|
constants.POPULATION,
|
|
958
966
|
]
|
|
967
|
+
if self.coord_to_columns.controls is not None:
|
|
968
|
+
coords.append(constants.CONTROLS)
|
|
959
969
|
if self.coord_to_columns.revenue_per_kpi is not None:
|
|
960
970
|
coords.append(constants.REVENUE_PER_KPI)
|
|
961
971
|
if self.coord_to_columns.media_spend is not None:
|
|
@@ -1042,17 +1052,20 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
1042
1052
|
.to_frame()
|
|
1043
1053
|
.to_xarray()
|
|
1044
1054
|
)
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1055
|
+
dataset = xr.combine_by_coords([kpi_xr, population_xr])
|
|
1056
|
+
|
|
1057
|
+
if self.coord_to_columns.controls is not None:
|
|
1058
|
+
controls_xr = (
|
|
1059
|
+
df_indexed[self.coord_to_columns.controls]
|
|
1060
|
+
.stack()
|
|
1061
|
+
.rename(constants.CONTROLS)
|
|
1062
|
+
.rename_axis(
|
|
1063
|
+
[constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
|
|
1064
|
+
)
|
|
1065
|
+
.to_frame()
|
|
1066
|
+
.to_xarray()
|
|
1067
|
+
)
|
|
1068
|
+
dataset = xr.combine_by_coords([dataset, controls_xr])
|
|
1056
1069
|
|
|
1057
1070
|
if self.coord_to_columns.non_media_treatments is not None:
|
|
1058
1071
|
non_media_xr = (
|
|
@@ -1224,7 +1237,7 @@ class CsvDataLoader(InputDataLoader):
|
|
|
1224
1237
|
CSV column names, if they are different. The fields are:
|
|
1225
1238
|
|
|
1226
1239
|
* `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
|
|
1227
|
-
* `controls` (multiple columns)
|
|
1240
|
+
* `controls` (multiple columns, optional)
|
|
1228
1241
|
* (1) `media`, `media_spend` (multiple columns)
|
|
1229
1242
|
* (2) `reach`, `frequency`, `rf_spend` (multiple columns)
|
|
1230
1243
|
* `non_media_treatments` (multiple columns, optional)
|