google-meridian 1.0.9__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.0.dist-info}/METADATA +2 -2
- google_meridian-1.1.0.dist-info/RECORD +41 -0
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.0.dist-info}/WHEEL +1 -1
- meridian/__init__.py +1 -1
- meridian/analysis/analyzer.py +195 -189
- meridian/analysis/optimizer.py +263 -65
- meridian/analysis/summarizer.py +4 -4
- meridian/analysis/test_utils.py +81 -81
- meridian/analysis/visualizer.py +12 -16
- meridian/constants.py +100 -16
- meridian/data/input_data.py +115 -19
- meridian/data/test_utils.py +116 -5
- meridian/data/time_coordinates.py +3 -3
- meridian/model/media.py +133 -98
- meridian/model/model.py +447 -57
- meridian/model/model_test_data.py +11 -0
- meridian/model/posterior_sampler.py +120 -43
- meridian/model/prior_distribution.py +96 -51
- meridian/model/prior_sampler.py +179 -209
- meridian/model/spec.py +196 -36
- meridian/model/transformers.py +15 -3
- google_meridian-1.0.9.dist-info/RECORD +0 -41
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.0.9.dist-info → google_meridian-1.1.0.dist-info}/top_level.txt +0 -0
meridian/data/input_data.py
CHANGED
|
@@ -60,7 +60,7 @@ def _check_dim_collection(
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray]):
|
|
63
|
+
def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray | None]):
|
|
64
64
|
"""Verifies that the dimensions of the appropriate arrays match."""
|
|
65
65
|
lengths = [len(array.coords[dim]) for array in arrays if array is not None]
|
|
66
66
|
names = [array.name for array in arrays if array is not None]
|
|
@@ -83,6 +83,31 @@ def _check_coords_match(dim: str, arrays: Sequence[xr.DataArray]):
|
|
|
83
83
|
)
|
|
84
84
|
|
|
85
85
|
|
|
86
|
+
def _aggregate_spend(
|
|
87
|
+
spend: xr.DataArray, calibration_period: np.ndarray | None
|
|
88
|
+
) -> np.ndarray | None:
|
|
89
|
+
"""Aggregates spend for each channel over the calibration period.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
spend: An array with shape `(n_geos, n_times, n_channels)` to aggregate.
|
|
93
|
+
calibration_period: An optional boolean array of shape `(n_media_times,
|
|
94
|
+
n_channels)`. If provided, spend is filtered according to this period.
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
A 1-D array of aggregated media spend per channel, or `None` if `spend` is
|
|
98
|
+
`None`.
|
|
99
|
+
"""
|
|
100
|
+
if spend is None:
|
|
101
|
+
return None
|
|
102
|
+
|
|
103
|
+
if calibration_period is None:
|
|
104
|
+
return np.sum(spend, axis=(0, 1))
|
|
105
|
+
|
|
106
|
+
# Select the last `n_times` from the `calibration_period`
|
|
107
|
+
factors = np.where(calibration_period[-spend.shape[1] :, :], 1, 0)
|
|
108
|
+
return np.einsum("gtm,tm->m", spend, factors)
|
|
109
|
+
|
|
110
|
+
|
|
86
111
|
@dataclasses.dataclass
|
|
87
112
|
class InputData:
|
|
88
113
|
"""A data container for advertising data in a format supported by Meridian.
|
|
@@ -120,8 +145,14 @@ class InputData:
|
|
|
120
145
|
in the same order. If either of these arguments is passed, then the other
|
|
121
146
|
is not optional.
|
|
122
147
|
media_spend: An optional `DataArray` containing the cost of each media
|
|
123
|
-
channel. This is used as the denominator for ROI calculations.
|
|
124
|
-
|
|
148
|
+
channel. This is used as the denominator for ROI calculations. It is also
|
|
149
|
+
used to calculate an assumed cost per media unit for post-modeling
|
|
150
|
+
analysis such as response curves and budget optimization. Only the
|
|
151
|
+
aggregate spend (across geos and time periods) is required for these
|
|
152
|
+
calculations. However, a spend breakdown by geo and time period is
|
|
153
|
+
required if `roi_calibration_period` is specified or if conducting
|
|
154
|
+
post-modeling analysis on a specific subset of geos and/or time periods.
|
|
155
|
+
The DataArray shape can be `(n_geos, n_times, n_media_channels)` or
|
|
125
156
|
`(n_media_channels,)` if the data is aggregated over `geo` and `time`
|
|
126
157
|
dimensions. We recommend that the spend total aligns with the time window
|
|
127
158
|
of the `kpi` and `controls` data, which is the time window over which
|
|
@@ -131,7 +162,9 @@ class InputData:
|
|
|
131
162
|
time window of media executed during the time window. `media` and
|
|
132
163
|
`media_spend` must contain the same number of media channels in the same
|
|
133
164
|
order. If either of these arguments is passed, then the other is not
|
|
134
|
-
optional.
|
|
165
|
+
optional. If a tensor of shape `(n_media_channels,)` is passed as
|
|
166
|
+
`media_spend`, then it will be automatically allocated across geos and
|
|
167
|
+
times proportionally to `media`.
|
|
135
168
|
reach: An optional `DataArray` of dimensions `(n_geos, n_media_times,
|
|
136
169
|
n_rf_channels)` containing non-negative `reach` values. It is required
|
|
137
170
|
that `n_media_times` ≥ `n_times`, and the final `n_times` time periods
|
|
@@ -164,18 +197,26 @@ class InputData:
|
|
|
164
197
|
others are not optional.
|
|
165
198
|
rf_spend: An optional `DataArray` containing the cost of each reach and
|
|
166
199
|
frequency channel. This is used as the denominator for ROI calculations.
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
200
|
+
It is also used to calculate an assumed cost per media unit for
|
|
201
|
+
post-modeling analysis such as response curves and budget optimization.
|
|
202
|
+
Only the aggregate spend (across geos and time periods) is required for
|
|
203
|
+
these calculations. However, a spend breakdown by geo and time period is
|
|
204
|
+
required if `rf_roi_calibration_period` is specified or if conducting
|
|
205
|
+
post-modeling analysis on a specific subset of geos and/or time periods.
|
|
206
|
+
The DataArray shape can be `(n_rf_channels,)` or `(n_geos, n_times,
|
|
207
|
+
n_rf_channels)`. The spend should be aggregated over geo and/or time
|
|
208
|
+
dimensions that are not represented. We recommend that the spend total
|
|
209
|
+
aligns with the time window of the `kpi` and `controls` data, which is the
|
|
210
|
+
time window over which incremental outcome of the ROI numerator is
|
|
211
|
+
calculated. However, note that incremental outcome is influenced by media
|
|
212
|
+
execution prior to this time window, through lagged effects, and excludes
|
|
213
|
+
lagged effects beyond the time window of media executed during the time
|
|
214
|
+
window. If only `media` data is used, `rf_spend` will be `None`. `reach`,
|
|
215
|
+
`frequency`, and `rf_spend` must contain the same number of media channels
|
|
216
|
+
in the same order. If any of these arguments is passed, then the others
|
|
217
|
+
are not optional. If a tensor of shape `(n_rf_channels,)` is passed as
|
|
218
|
+
`rf_spend`, then it will be automatically allocated across geos and times
|
|
219
|
+
proportionally to `(reach * frequency)`.
|
|
179
220
|
organic_media: An optional `DataArray` of dimensions `(n_geos,
|
|
180
221
|
n_media_times, n_organic_media_channels)` containing non-negative organic
|
|
181
222
|
media values. Organic media variables are media activities that have no
|
|
@@ -265,6 +306,40 @@ class InputData:
|
|
|
265
306
|
if isinstance(array, xr.DataArray) and constants.GEO in array.dims:
|
|
266
307
|
array.coords[constants.GEO] = array.coords[constants.GEO].astype(str)
|
|
267
308
|
|
|
309
|
+
# TODO: b/416775065 - Combine with Analyzer._impute_and_aggregate_spend
|
|
310
|
+
@functools.cached_property
|
|
311
|
+
def allocated_media_spend(self) -> xr.DataArray | None:
|
|
312
|
+
"""Returns the allocated media spend for each geo and time."""
|
|
313
|
+
if self.media_spend is not None and len(self.media_spend.shape) == 1:
|
|
314
|
+
return self._allocate_spend(self.media_spend, self.media)
|
|
315
|
+
else:
|
|
316
|
+
return self.media_spend
|
|
317
|
+
|
|
318
|
+
@property
|
|
319
|
+
def allocated_rf_spend(self) -> xr.DataArray | None:
|
|
320
|
+
"""Returns the allocated RF spend for each geo and time."""
|
|
321
|
+
if self.rf_spend is not None and len(self.rf_spend.shape) == 1:
|
|
322
|
+
return self._allocate_spend(self.rf_spend, self.reach * self.frequency)
|
|
323
|
+
else:
|
|
324
|
+
return self.rf_spend
|
|
325
|
+
|
|
326
|
+
def aggregate_media_spend(
|
|
327
|
+
self, calibration_period: np.ndarray | None = None
|
|
328
|
+
) -> np.ndarray | None:
|
|
329
|
+
"""Aggregates media spend by channel over the calibration period."""
|
|
330
|
+
return _aggregate_spend(
|
|
331
|
+
spend=self.allocated_media_spend, calibration_period=calibration_period
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
def aggregate_rf_spend(
|
|
335
|
+
self, calibration_period: np.ndarray | None = None
|
|
336
|
+
) -> np.ndarray | None:
|
|
337
|
+
"""Aggregates RF spend by channel over the calibration period."""
|
|
338
|
+
return _aggregate_spend(
|
|
339
|
+
spend=self.allocated_rf_spend,
|
|
340
|
+
calibration_period=calibration_period,
|
|
341
|
+
)
|
|
342
|
+
|
|
268
343
|
@property
|
|
269
344
|
def geo(self) -> xr.DataArray:
|
|
270
345
|
"""Returns the geo dimension."""
|
|
@@ -424,7 +499,8 @@ class InputData:
|
|
|
424
499
|
|
|
425
500
|
def _validate_names(self):
|
|
426
501
|
"""Verifies that the names of the data arrays are correct."""
|
|
427
|
-
|
|
502
|
+
# Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
|
|
503
|
+
arrays = (
|
|
428
504
|
self.kpi,
|
|
429
505
|
self.controls,
|
|
430
506
|
self.population,
|
|
@@ -438,7 +514,7 @@ class InputData:
|
|
|
438
514
|
self.reach,
|
|
439
515
|
self.frequency,
|
|
440
516
|
self.rf_spend,
|
|
441
|
-
|
|
517
|
+
)
|
|
442
518
|
|
|
443
519
|
for array, name in zip(arrays, constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES):
|
|
444
520
|
if array is not None and array.name != name:
|
|
@@ -479,7 +555,6 @@ class InputData:
|
|
|
479
555
|
[
|
|
480
556
|
[constants.RF_CHANNEL],
|
|
481
557
|
[constants.GEO, constants.TIME, constants.RF_CHANNEL],
|
|
482
|
-
[constants.GEO, constants.RF_CHANNEL],
|
|
483
558
|
],
|
|
484
559
|
)
|
|
485
560
|
_check_dim_collection(
|
|
@@ -848,3 +923,24 @@ class InputData:
|
|
|
848
923
|
return self.media_spend.values
|
|
849
924
|
else:
|
|
850
925
|
raise ValueError("Both RF and Media are missing.")
|
|
926
|
+
|
|
927
|
+
def get_total_outcome(self) -> np.ndarray:
|
|
928
|
+
"""Returns total outcome, aggregated over geos and times."""
|
|
929
|
+
if self.revenue_per_kpi is None:
|
|
930
|
+
return np.sum(self.kpi.values)
|
|
931
|
+
return np.sum(self.kpi.values * self.revenue_per_kpi.values)
|
|
932
|
+
|
|
933
|
+
def _allocate_spend(self, spend: xr.DataArray, media_units: xr.DataArray):
|
|
934
|
+
"""Allocates spend across geo and time proportionally to media units."""
|
|
935
|
+
n_times = len(self.kpi.coords[constants.TIME])
|
|
936
|
+
selected_media_units = media_units.isel(media_time=slice(-n_times, None))
|
|
937
|
+
total_media_units_per_channel = selected_media_units.sum(
|
|
938
|
+
dim=["geo", "media_time"]
|
|
939
|
+
)
|
|
940
|
+
proportions = selected_media_units / total_media_units_per_channel
|
|
941
|
+
expanded_spend = spend.expand_dims({
|
|
942
|
+
"geo": selected_media_units["geo"],
|
|
943
|
+
"media_time": selected_media_units["media_time"],
|
|
944
|
+
})
|
|
945
|
+
allocated_spend = expanded_spend * proportions
|
|
946
|
+
return allocated_spend.rename({"media_time": "time"})
|
meridian/data/test_utils.py
CHANGED
|
@@ -65,12 +65,24 @@ _REQUIRED_COORDS = immutabledict.immutabledict({
|
|
|
65
65
|
c.MEDIA_TIME: _sample_times(n_times=3),
|
|
66
66
|
c.CONTROL_VARIABLE: ['control_0', 'control_1'],
|
|
67
67
|
})
|
|
68
|
+
_NON_MEDIA_COORDS = immutabledict.immutabledict(
|
|
69
|
+
{c.NON_MEDIA_CHANNEL: ['non_media_channel_0', 'non_media_channel_1']}
|
|
70
|
+
)
|
|
68
71
|
_MEDIA_COORDS = immutabledict.immutabledict(
|
|
69
72
|
{c.MEDIA_CHANNEL: ['media_channel_0', 'media_channel_1', 'media_channel_2']}
|
|
70
73
|
)
|
|
74
|
+
_ORGANIC_MEDIA_COORDS = immutabledict.immutabledict({
|
|
75
|
+
c.ORGANIC_MEDIA_CHANNEL: [
|
|
76
|
+
'organic_media_channel_0',
|
|
77
|
+
'organic_media_channel_1',
|
|
78
|
+
]
|
|
79
|
+
})
|
|
71
80
|
_RF_COORDS = immutabledict.immutabledict(
|
|
72
81
|
{c.RF_CHANNEL: ['rf_channel_0', 'rf_channel_1']}
|
|
73
82
|
)
|
|
83
|
+
_ORGANIC_RF_COORDS = immutabledict.immutabledict(
|
|
84
|
+
{c.ORGANIC_RF_CHANNEL: ['organic_rf_channel_0', 'organic_rf_channel_1']}
|
|
85
|
+
)
|
|
74
86
|
|
|
75
87
|
_REQUIRED_DATA_VARS = immutabledict.immutabledict({
|
|
76
88
|
c.KPI: (['geo', 'time'], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
|
|
@@ -376,6 +388,70 @@ DATASET_WITHOUT_TIME_VARIATION_IN_REACH = xr.Dataset(
|
|
|
376
388
|
},
|
|
377
389
|
)
|
|
378
390
|
|
|
391
|
+
DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
|
|
392
|
+
coords=_REQUIRED_COORDS
|
|
393
|
+
| _MEDIA_COORDS
|
|
394
|
+
| _RF_COORDS
|
|
395
|
+
| _ORGANIC_MEDIA_COORDS,
|
|
396
|
+
data_vars=_REQUIRED_DATA_VARS
|
|
397
|
+
| _MEDIA_DATA_VARS
|
|
398
|
+
| _RF_DATA_VARS
|
|
399
|
+
| _OPTIONAL_DATA_VARS
|
|
400
|
+
| {
|
|
401
|
+
c.ORGANIC_MEDIA: (
|
|
402
|
+
['geo', 'media_time', 'organic_media_channel'],
|
|
403
|
+
[
|
|
404
|
+
[[2.1, 2.2], [2.1, 2.21], [2.1, 2.2]],
|
|
405
|
+
[[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
|
|
406
|
+
],
|
|
407
|
+
),
|
|
408
|
+
},
|
|
409
|
+
)
|
|
410
|
+
|
|
411
|
+
DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
|
|
412
|
+
coords=_REQUIRED_COORDS
|
|
413
|
+
| _MEDIA_COORDS
|
|
414
|
+
| _RF_COORDS
|
|
415
|
+
| _ORGANIC_RF_COORDS,
|
|
416
|
+
data_vars=_REQUIRED_DATA_VARS
|
|
417
|
+
| _MEDIA_DATA_VARS
|
|
418
|
+
| _RF_DATA_VARS
|
|
419
|
+
| _OPTIONAL_DATA_VARS
|
|
420
|
+
| {
|
|
421
|
+
c.ORGANIC_REACH: (
|
|
422
|
+
['geo', 'media_time', 'organic_rf_channel'],
|
|
423
|
+
[
|
|
424
|
+
[[2.1, 2.2], [2.11, 2.2], [2.1, 2.2]],
|
|
425
|
+
[[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
|
|
426
|
+
],
|
|
427
|
+
),
|
|
428
|
+
c.ORGANIC_FREQUENCY: (
|
|
429
|
+
['geo', 'media_time', 'organic_rf_channel'],
|
|
430
|
+
[
|
|
431
|
+
[[7.1, 7.2], [7.3, 7.4], [7.5, 7.6]],
|
|
432
|
+
[[7.11, 7.21], [7.31, 7.41], [7.51, 7.61]],
|
|
433
|
+
],
|
|
434
|
+
),
|
|
435
|
+
},
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
DATASET_WITHOUT_TIME_VARIATION_IN_NON_MEDIA_TREATMENTS = xr.Dataset(
|
|
439
|
+
coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _NON_MEDIA_COORDS,
|
|
440
|
+
data_vars=_REQUIRED_DATA_VARS
|
|
441
|
+
| _MEDIA_DATA_VARS
|
|
442
|
+
| _RF_DATA_VARS
|
|
443
|
+
| _OPTIONAL_DATA_VARS
|
|
444
|
+
| {
|
|
445
|
+
c.NON_MEDIA_TREATMENTS: (
|
|
446
|
+
['geo', 'time', 'non_media_channel'],
|
|
447
|
+
[
|
|
448
|
+
[[2.1, 2.2], [2.1, 2.2], [2.1, 2.2]],
|
|
449
|
+
[[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
|
|
450
|
+
],
|
|
451
|
+
),
|
|
452
|
+
},
|
|
453
|
+
)
|
|
454
|
+
|
|
379
455
|
_NATIONAL_COORDS = immutabledict.immutabledict({
|
|
380
456
|
c.TIME: [
|
|
381
457
|
_SAMPLE_START_DATE.strftime(c.DATE_FORMAT),
|
|
@@ -1491,17 +1567,52 @@ def sample_input_data_from_dataset(
|
|
|
1491
1567
|
dataset: xr.Dataset, kpi_type: str
|
|
1492
1568
|
) -> input_data.InputData:
|
|
1493
1569
|
"""Generates a sample `InputData` from a full xarray Dataset."""
|
|
1570
|
+
media = dataset.media if c.MEDIA in dataset.data_vars.keys() else None
|
|
1571
|
+
media_spend = (
|
|
1572
|
+
dataset.media_spend if c.MEDIA_SPEND in dataset.data_vars.keys() else None
|
|
1573
|
+
)
|
|
1574
|
+
reach = dataset.reach if c.REACH in dataset.data_vars.keys() else None
|
|
1575
|
+
frequency = (
|
|
1576
|
+
dataset.frequency if c.FREQUENCY in dataset.data_vars.keys() else None
|
|
1577
|
+
)
|
|
1578
|
+
rf_spend = (
|
|
1579
|
+
dataset.rf_spend if c.RF_SPEND in dataset.data_vars.keys() else None
|
|
1580
|
+
)
|
|
1581
|
+
organic_media = (
|
|
1582
|
+
dataset.organic_media
|
|
1583
|
+
if c.ORGANIC_MEDIA in dataset.data_vars.keys()
|
|
1584
|
+
else None
|
|
1585
|
+
)
|
|
1586
|
+
organic_reach = (
|
|
1587
|
+
dataset.organic_reach
|
|
1588
|
+
if c.ORGANIC_REACH in dataset.data_vars.keys()
|
|
1589
|
+
else None
|
|
1590
|
+
)
|
|
1591
|
+
organic_frequency = (
|
|
1592
|
+
dataset.organic_frequency
|
|
1593
|
+
if c.ORGANIC_FREQUENCY in dataset.data_vars.keys()
|
|
1594
|
+
else None
|
|
1595
|
+
)
|
|
1596
|
+
non_media_treatments = (
|
|
1597
|
+
dataset.non_media_treatments
|
|
1598
|
+
if c.NON_MEDIA_TREATMENTS in dataset.data_vars.keys()
|
|
1599
|
+
else None
|
|
1600
|
+
)
|
|
1494
1601
|
return input_data.InputData(
|
|
1495
1602
|
kpi=dataset.kpi,
|
|
1496
1603
|
kpi_type=kpi_type,
|
|
1497
1604
|
revenue_per_kpi=dataset.revenue_per_kpi,
|
|
1498
1605
|
population=dataset.population,
|
|
1499
1606
|
controls=dataset.controls,
|
|
1500
|
-
media=
|
|
1501
|
-
media_spend=
|
|
1502
|
-
reach=
|
|
1503
|
-
frequency=
|
|
1504
|
-
rf_spend=
|
|
1607
|
+
media=media,
|
|
1608
|
+
media_spend=media_spend,
|
|
1609
|
+
reach=reach,
|
|
1610
|
+
frequency=frequency,
|
|
1611
|
+
rf_spend=rf_spend,
|
|
1612
|
+
organic_media=organic_media,
|
|
1613
|
+
organic_reach=organic_reach,
|
|
1614
|
+
organic_frequency=organic_frequency,
|
|
1615
|
+
non_media_treatments=non_media_treatments,
|
|
1505
1616
|
)
|
|
1506
1617
|
|
|
1507
1618
|
|
meridian/data/time_coordinates.py
CHANGED
|
@@ -36,7 +36,7 @@ __all__ = [
|
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
# A type alias for a polymorphic "date" type.
|
|
39
|
-
Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64
|
|
39
|
+
Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64 | None
|
|
40
40
|
|
|
41
41
|
# A type alias for a polymorphic "date interval" type. In all variants it is
|
|
42
42
|
# always a tuple of (start_date, end_date).
|
|
@@ -236,8 +236,8 @@ class TimeCoordinates:
|
|
|
236
236
|
|
|
237
237
|
def expand_selected_time_dims(
|
|
238
238
|
self,
|
|
239
|
-
start_date: Date
|
|
240
|
-
end_date: Date
|
|
239
|
+
start_date: Date = None,
|
|
240
|
+
end_date: Date = None,
|
|
241
241
|
) -> list[datetime.date] | None:
|
|
242
242
|
"""Validates and returns time dimension values based on the selected times.
|
|
243
243
|
|