google-meridian 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.0.8.dist-info → google_meridian-1.1.0.dist-info}/METADATA +2 -2
- google_meridian-1.1.0.dist-info/RECORD +41 -0
- {google_meridian-1.0.8.dist-info → google_meridian-1.1.0.dist-info}/WHEEL +1 -1
- meridian/__init__.py +1 -1
- meridian/analysis/analyzer.py +303 -207
- meridian/analysis/optimizer.py +431 -82
- meridian/analysis/summarizer.py +25 -7
- meridian/analysis/test_utils.py +81 -81
- meridian/analysis/visualizer.py +81 -39
- meridian/constants.py +111 -26
- meridian/data/input_data.py +115 -19
- meridian/data/test_utils.py +116 -5
- meridian/data/time_coordinates.py +3 -3
- meridian/model/media.py +133 -98
- meridian/model/model.py +457 -52
- meridian/model/model_test_data.py +11 -0
- meridian/model/posterior_sampler.py +120 -43
- meridian/model/prior_distribution.py +95 -29
- meridian/model/prior_sampler.py +179 -209
- meridian/model/spec.py +196 -36
- meridian/model/transformers.py +15 -3
- google_meridian-1.0.8.dist-info/RECORD +0 -41
- {google_meridian-1.0.8.dist-info → google_meridian-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.0.8.dist-info → google_meridian-1.1.0.dist-info}/top_level.txt +0 -0
meridian/constants.py
CHANGED
|
@@ -51,6 +51,8 @@ GREY_300 = '#DADCE0'
|
|
|
51
51
|
|
|
52
52
|
# Example: "2024-01-09"
|
|
53
53
|
DATE_FORMAT = '%Y-%m-%d'
|
|
54
|
+
# Example: "2024 Apr"
|
|
55
|
+
QUARTER_FORMAT = '%Y %b'
|
|
54
56
|
|
|
55
57
|
# Input data variables.
|
|
56
58
|
KPI = 'kpi'
|
|
@@ -95,12 +97,8 @@ POSSIBLE_INPUT_DATA_ARRAY_NAMES = (
|
|
|
95
97
|
+ MEDIA_INPUT_DATA_ARRAY_NAMES
|
|
96
98
|
+ RF_INPUT_DATA_ARRAY_NAMES
|
|
97
99
|
)
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
REACH,
|
|
101
|
-
FREQUENCY,
|
|
102
|
-
REVENUE_PER_KPI,
|
|
103
|
-
)
|
|
100
|
+
PAID_CHANNELS = (MEDIA, REACH, FREQUENCY)
|
|
101
|
+
PAID_DATA = PAID_CHANNELS + (REVENUE_PER_KPI,)
|
|
104
102
|
NON_PAID_DATA = (
|
|
105
103
|
ORGANIC_MEDIA,
|
|
106
104
|
ORGANIC_REACH,
|
|
@@ -112,11 +110,7 @@ SPEND_DATA = (
|
|
|
112
110
|
RF_SPEND,
|
|
113
111
|
)
|
|
114
112
|
PERFORMANCE_DATA = PAID_DATA + SPEND_DATA
|
|
115
|
-
IMPRESSIONS_DATA =
|
|
116
|
-
MEDIA,
|
|
117
|
-
REACH,
|
|
118
|
-
FREQUENCY,
|
|
119
|
-
) + NON_PAID_DATA
|
|
113
|
+
IMPRESSIONS_DATA = PAID_CHANNELS + NON_PAID_DATA
|
|
120
114
|
RF_DATA = (
|
|
121
115
|
REACH,
|
|
122
116
|
FREQUENCY,
|
|
@@ -200,17 +194,23 @@ RF_ROI_CALIBRATION_PERIOD = 'rf_roi_calibration_period'
|
|
|
200
194
|
KNOTS = 'knots'
|
|
201
195
|
BASELINE_GEO = 'baseline_geo'
|
|
202
196
|
|
|
203
|
-
#
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
197
|
+
# Treatment prior types.
|
|
198
|
+
TREATMENT_PRIOR_TYPE_ROI = 'roi'
|
|
199
|
+
TREATMENT_PRIOR_TYPE_MROI = 'mroi'
|
|
200
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT = 'coefficient'
|
|
201
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION = 'contribution'
|
|
202
|
+
PAID_TREATMENT_PRIOR_TYPES = frozenset({
|
|
203
|
+
TREATMENT_PRIOR_TYPE_ROI,
|
|
204
|
+
TREATMENT_PRIOR_TYPE_MROI,
|
|
205
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT,
|
|
206
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION,
|
|
207
|
+
})
|
|
208
|
+
NON_PAID_TREATMENT_PRIOR_TYPES = frozenset({
|
|
209
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT,
|
|
210
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION,
|
|
211
211
|
})
|
|
212
212
|
PAID_MEDIA_ROI_PRIOR_TYPES = frozenset(
|
|
213
|
-
{
|
|
213
|
+
{TREATMENT_PRIOR_TYPE_ROI, TREATMENT_PRIOR_TYPE_MROI}
|
|
214
214
|
)
|
|
215
215
|
# Represents a 1% increase in spend.
|
|
216
216
|
MROI_FACTOR = 1.01
|
|
@@ -239,6 +239,11 @@ ROI_M = 'roi_m'
|
|
|
239
239
|
ROI_RF = 'roi_rf'
|
|
240
240
|
MROI_M = 'mroi_m'
|
|
241
241
|
MROI_RF = 'mroi_rf'
|
|
242
|
+
CONTRIBUTION_M = 'contribution_m'
|
|
243
|
+
CONTRIBUTION_RF = 'contribution_rf'
|
|
244
|
+
CONTRIBUTION_OM = 'contribution_om'
|
|
245
|
+
CONTRIBUTION_ORF = 'contribution_orf'
|
|
246
|
+
CONTRIBUTION_N = 'contribution_n'
|
|
242
247
|
GAMMA_C = 'gamma_c'
|
|
243
248
|
GAMMA_N = 'gamma_n'
|
|
244
249
|
XI_C = 'xi_c'
|
|
@@ -307,7 +312,29 @@ RF_PARAMETER_NAMES = (
|
|
|
307
312
|
BETA_RF,
|
|
308
313
|
BETA_GRF,
|
|
309
314
|
)
|
|
315
|
+
|
|
316
|
+
MEDIA_PARAMETERS = (
|
|
317
|
+
ROI_M,
|
|
318
|
+
MROI_M,
|
|
319
|
+
CONTRIBUTION_M,
|
|
320
|
+
BETA_M,
|
|
321
|
+
ETA_M,
|
|
322
|
+
ALPHA_M,
|
|
323
|
+
EC_M,
|
|
324
|
+
SLOPE_M,
|
|
325
|
+
)
|
|
326
|
+
RF_PARAMETERS = (
|
|
327
|
+
ROI_RF,
|
|
328
|
+
MROI_RF,
|
|
329
|
+
CONTRIBUTION_RF,
|
|
330
|
+
BETA_RF,
|
|
331
|
+
ETA_RF,
|
|
332
|
+
ALPHA_RF,
|
|
333
|
+
EC_RF,
|
|
334
|
+
SLOPE_RF,
|
|
335
|
+
)
|
|
310
336
|
ORGANIC_MEDIA_PARAMETERS = (
|
|
337
|
+
CONTRIBUTION_OM,
|
|
311
338
|
BETA_OM,
|
|
312
339
|
ETA_OM,
|
|
313
340
|
ALPHA_OM,
|
|
@@ -315,6 +342,7 @@ ORGANIC_MEDIA_PARAMETERS = (
|
|
|
315
342
|
SLOPE_OM,
|
|
316
343
|
)
|
|
317
344
|
ORGANIC_RF_PARAMETERS = (
|
|
345
|
+
CONTRIBUTION_ORF,
|
|
318
346
|
BETA_ORF,
|
|
319
347
|
ETA_ORF,
|
|
320
348
|
ALPHA_ORF,
|
|
@@ -322,13 +350,12 @@ ORGANIC_RF_PARAMETERS = (
|
|
|
322
350
|
SLOPE_ORF,
|
|
323
351
|
)
|
|
324
352
|
NON_MEDIA_PARAMETERS = (
|
|
353
|
+
CONTRIBUTION_N,
|
|
325
354
|
GAMMA_N,
|
|
326
355
|
XI_N,
|
|
327
356
|
)
|
|
328
357
|
|
|
329
358
|
KNOTS_PARAMETERS = (KNOT_VALUES,)
|
|
330
|
-
MEDIA_PARAMETERS = (ETA_M, BETA_M, ALPHA_M, EC_M, SLOPE_M, ROI_M, MROI_M)
|
|
331
|
-
RF_PARAMETERS = (ETA_RF, BETA_RF, ALPHA_RF, EC_RF, SLOPE_RF, ROI_RF, MROI_RF)
|
|
332
359
|
CONTROL_PARAMETERS = (GAMMA_C, XI_C)
|
|
333
360
|
SIGMA_PARAMETERS = (SIGMA,)
|
|
334
361
|
GEO_PARAMETERS = (
|
|
@@ -366,10 +393,61 @@ UNSAVED_PARAMETERS = (
|
|
|
366
393
|
GAMMA_GN_DEV,
|
|
367
394
|
TAU_G_EXCL_BASELINE, # Used to derive TAU_G.
|
|
368
395
|
)
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
396
|
+
IGNORED_PRIORS_MEDIA = immutabledict.immutabledict({
|
|
397
|
+
TREATMENT_PRIOR_TYPE_ROI: (
|
|
398
|
+
BETA_M,
|
|
399
|
+
MROI_M,
|
|
400
|
+
CONTRIBUTION_M,
|
|
401
|
+
),
|
|
402
|
+
TREATMENT_PRIOR_TYPE_MROI: (
|
|
403
|
+
BETA_M,
|
|
404
|
+
ROI_M,
|
|
405
|
+
CONTRIBUTION_M,
|
|
406
|
+
),
|
|
407
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION: (
|
|
408
|
+
BETA_M,
|
|
409
|
+
ROI_M,
|
|
410
|
+
MROI_M,
|
|
411
|
+
),
|
|
412
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT: (
|
|
413
|
+
ROI_M,
|
|
414
|
+
MROI_M,
|
|
415
|
+
CONTRIBUTION_M,
|
|
416
|
+
),
|
|
417
|
+
})
|
|
418
|
+
IGNORED_PRIORS_RF = immutabledict.immutabledict({
|
|
419
|
+
TREATMENT_PRIOR_TYPE_ROI: (
|
|
420
|
+
BETA_RF,
|
|
421
|
+
MROI_RF,
|
|
422
|
+
CONTRIBUTION_RF,
|
|
423
|
+
),
|
|
424
|
+
TREATMENT_PRIOR_TYPE_MROI: (
|
|
425
|
+
BETA_RF,
|
|
426
|
+
ROI_RF,
|
|
427
|
+
CONTRIBUTION_RF,
|
|
428
|
+
),
|
|
429
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION: (
|
|
430
|
+
BETA_RF,
|
|
431
|
+
ROI_RF,
|
|
432
|
+
MROI_RF,
|
|
433
|
+
),
|
|
434
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT: (
|
|
435
|
+
ROI_RF,
|
|
436
|
+
MROI_RF,
|
|
437
|
+
CONTRIBUTION_RF,
|
|
438
|
+
),
|
|
439
|
+
})
|
|
440
|
+
IGNORED_PRIORS_ORGANIC_MEDIA = immutabledict.immutabledict({
|
|
441
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION: (BETA_OM,),
|
|
442
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT: (CONTRIBUTION_OM,),
|
|
443
|
+
})
|
|
444
|
+
IGNORED_PRIORS_ORGANIC_RF = immutabledict.immutabledict({
|
|
445
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION: (BETA_ORF,),
|
|
446
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT: (CONTRIBUTION_ORF,),
|
|
447
|
+
})
|
|
448
|
+
IGNORED_PRIORS_NON_MEDIA_TREATMENTS = immutabledict.immutabledict({
|
|
449
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION: (GAMMA_N,),
|
|
450
|
+
TREATMENT_PRIOR_TYPE_COEFFICIENT: (CONTRIBUTION_N,),
|
|
373
451
|
})
|
|
374
452
|
|
|
375
453
|
# Inference data dimensions.
|
|
@@ -622,3 +700,10 @@ CARD_STATS = 'stats'
|
|
|
622
700
|
# VegaLite common params.
|
|
623
701
|
VEGALITE_FACET_DEFAULT_WIDTH = 400
|
|
624
702
|
VEGALITE_FACET_LARGE_WIDTH = 500
|
|
703
|
+
VEGALITE_FACET_EXTRA_LARGE_WIDTH = 700
|
|
704
|
+
|
|
705
|
+
# Time Granularity Constants
|
|
706
|
+
WEEKLY = 'weekly'
|
|
707
|
+
QUARTERLY = 'quarterly'
|
|
708
|
+
TIME_GRANULARITIES = frozenset({WEEKLY, QUARTERLY})
|
|
709
|
+
QUARTERLY_SUMMARY_THRESHOLD_WEEKS = 52
|
meridian/data/input_data.py
CHANGED
|
@@ -60,7 +60,7 @@ def _check_dim_collection(
|
|
|
60
60
|
)
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray]):
|
|
63
|
+
def _check_dim_match(dim: str, arrays: Sequence[xr.DataArray | None]):
|
|
64
64
|
"""Verifies that the dimensions of the appropriate arrays match."""
|
|
65
65
|
lengths = [len(array.coords[dim]) for array in arrays if array is not None]
|
|
66
66
|
names = [array.name for array in arrays if array is not None]
|
|
@@ -83,6 +83,31 @@ def _check_coords_match(dim: str, arrays: Sequence[xr.DataArray]):
|
|
|
83
83
|
)
|
|
84
84
|
|
|
85
85
|
|
|
86
|
+
def _aggregate_spend(
|
|
87
|
+
spend: xr.DataArray, calibration_period: np.ndarray | None
|
|
88
|
+
) -> np.ndarray | None:
|
|
89
|
+
"""Aggregates spend for each channel over the calibration period.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
spend: An array with shape `(n_geos, n_times, n_channels)` to aggregate.
|
|
93
|
+
calibration_period: An optional boolean array of shape `(n_media_times,
|
|
94
|
+
n_channels)`. If provided, spend is filtered according to this period.
|
|
95
|
+
|
|
96
|
+
Returns:
|
|
97
|
+
A 1-D array of aggregated media spend per channel, or `None` if `spend` is
|
|
98
|
+
`None`.
|
|
99
|
+
"""
|
|
100
|
+
if spend is None:
|
|
101
|
+
return None
|
|
102
|
+
|
|
103
|
+
if calibration_period is None:
|
|
104
|
+
return np.sum(spend, axis=(0, 1))
|
|
105
|
+
|
|
106
|
+
# Select the last `n_times` from the `calibration_period`
|
|
107
|
+
factors = np.where(calibration_period[-spend.shape[1] :, :], 1, 0)
|
|
108
|
+
return np.einsum("gtm,tm->m", spend, factors)
|
|
109
|
+
|
|
110
|
+
|
|
86
111
|
@dataclasses.dataclass
|
|
87
112
|
class InputData:
|
|
88
113
|
"""A data container for advertising data in a format supported by Meridian.
|
|
@@ -120,8 +145,14 @@ class InputData:
|
|
|
120
145
|
in the same order. If either of these arguments is passed, then the other
|
|
121
146
|
is not optional.
|
|
122
147
|
media_spend: An optional `DataArray` containing the cost of each media
|
|
123
|
-
channel. This is used as the denominator for ROI calculations.
|
|
124
|
-
|
|
148
|
+
channel. This is used as the denominator for ROI calculations. It is also
|
|
149
|
+
used to calculate an assumed cost per media unit for post-modeling
|
|
150
|
+
analysis such as response curves and budget optimization. Only the
|
|
151
|
+
aggregate spend (across geos and time periods) is required for these
|
|
152
|
+
calculations. However, a spend breakdown by geo and time period is
|
|
153
|
+
required if `roi_calibration_period` is specified or if conducting
|
|
154
|
+
post-modeling analysis on a specific subset of geos and/or time periods.
|
|
155
|
+
The DataArray shape can be `(n_geos, n_times, n_media_channels)` or
|
|
125
156
|
`(n_media_channels,)` if the data is aggregated over `geo` and `time`
|
|
126
157
|
dimensions. We recommend that the spend total aligns with the time window
|
|
127
158
|
of the `kpi` and `controls` data, which is the time window over which
|
|
@@ -131,7 +162,9 @@ class InputData:
|
|
|
131
162
|
time window of media executed during the time window. `media` and
|
|
132
163
|
`media_spend` must contain the same number of media channels in the same
|
|
133
164
|
order. If either of these arguments is passed, then the other is not
|
|
134
|
-
optional.
|
|
165
|
+
optional. If a tensor of shape `(n_media_channels,)` is passed as
|
|
166
|
+
`media_spend`, then it will be automatically allocated across geos and
|
|
167
|
+
times proportionally to `media`.
|
|
135
168
|
reach: An optional `DataArray` of dimensions `(n_geos, n_media_times,
|
|
136
169
|
n_rf_channels)` containing non-negative `reach` values. It is required
|
|
137
170
|
that `n_media_times` ≥ `n_times`, and the final `n_times` time periods
|
|
@@ -164,18 +197,26 @@ class InputData:
|
|
|
164
197
|
others are not optional.
|
|
165
198
|
rf_spend: An optional `DataArray` containing the cost of each reach and
|
|
166
199
|
frequency channel. This is used as the denominator for ROI calculations.
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
200
|
+
It is also used to calculate an assumed cost per media unit for
|
|
201
|
+
post-modeling analysis such as response curves and budget optimization.
|
|
202
|
+
Only the aggregate spend (across geos and time periods) is required for
|
|
203
|
+
these calculations. However, a spend breakdown by geo and time period is
|
|
204
|
+
required if `rf_roi_calibration_period` is specified or if conducting
|
|
205
|
+
post-modeling analysis on a specific subset of geos and/or time periods.
|
|
206
|
+
The DataArray shape can be `(n_rf_channels,)` or `(n_geos, n_times,
|
|
207
|
+
n_rf_channels)`. The spend should be aggregated over geo and/or time
|
|
208
|
+
dimensions that are not represented. We recommend that the spend total
|
|
209
|
+
aligns with the time window of the `kpi` and `controls` data, which is the
|
|
210
|
+
time window over which incremental outcome of the ROI numerator is
|
|
211
|
+
calculated. However, note that incremental outcome is influenced by media
|
|
212
|
+
execution prior to this time window, through lagged effects, and excludes
|
|
213
|
+
lagged effects beyond the time window of media executed during the time
|
|
214
|
+
window. If only `media` data is used, `rf_spend` will be `None`. `reach`,
|
|
215
|
+
`frequency`, and `rf_spend` must contain the same number of media channels
|
|
216
|
+
in the same order. If any of these arguments is passed, then the others
|
|
217
|
+
are not optional. If a tensor of shape `(n_rf_channels,)` is passed as
|
|
218
|
+
`rf_spend`, then it will be automatically allocated across geos and times
|
|
219
|
+
proportionally to `(reach * frequency)`.
|
|
179
220
|
organic_media: An optional `DataArray` of dimensions `(n_geos,
|
|
180
221
|
n_media_times, n_organic_media_channels)` containing non-negative organic
|
|
181
222
|
media values. Organic media variables are media activities that have no
|
|
@@ -265,6 +306,40 @@ class InputData:
|
|
|
265
306
|
if isinstance(array, xr.DataArray) and constants.GEO in array.dims:
|
|
266
307
|
array.coords[constants.GEO] = array.coords[constants.GEO].astype(str)
|
|
267
308
|
|
|
309
|
+
# TODO: b/416775065 - Combine with Analyzer._impute_and_aggregate_spend
|
|
310
|
+
@functools.cached_property
|
|
311
|
+
def allocated_media_spend(self) -> xr.DataArray | None:
|
|
312
|
+
"""Returns the allocated media spend for each geo and time."""
|
|
313
|
+
if self.media_spend is not None and len(self.media_spend.shape) == 1:
|
|
314
|
+
return self._allocate_spend(self.media_spend, self.media)
|
|
315
|
+
else:
|
|
316
|
+
return self.media_spend
|
|
317
|
+
|
|
318
|
+
@property
|
|
319
|
+
def allocated_rf_spend(self) -> xr.DataArray | None:
|
|
320
|
+
"""Returns the allocated RF spend for each geo and time."""
|
|
321
|
+
if self.rf_spend is not None and len(self.rf_spend.shape) == 1:
|
|
322
|
+
return self._allocate_spend(self.rf_spend, self.reach * self.frequency)
|
|
323
|
+
else:
|
|
324
|
+
return self.rf_spend
|
|
325
|
+
|
|
326
|
+
def aggregate_media_spend(
|
|
327
|
+
self, calibration_period: np.ndarray | None = None
|
|
328
|
+
) -> np.ndarray | None:
|
|
329
|
+
"""Aggregates media spend by channel over the calibration period."""
|
|
330
|
+
return _aggregate_spend(
|
|
331
|
+
spend=self.allocated_media_spend, calibration_period=calibration_period
|
|
332
|
+
)
|
|
333
|
+
|
|
334
|
+
def aggregate_rf_spend(
|
|
335
|
+
self, calibration_period: np.ndarray | None = None
|
|
336
|
+
) -> np.ndarray | None:
|
|
337
|
+
"""Aggregates RF spend by channel over the calibration period."""
|
|
338
|
+
return _aggregate_spend(
|
|
339
|
+
spend=self.allocated_rf_spend,
|
|
340
|
+
calibration_period=calibration_period,
|
|
341
|
+
)
|
|
342
|
+
|
|
268
343
|
@property
|
|
269
344
|
def geo(self) -> xr.DataArray:
|
|
270
345
|
"""Returns the geo dimension."""
|
|
@@ -424,7 +499,8 @@ class InputData:
|
|
|
424
499
|
|
|
425
500
|
def _validate_names(self):
|
|
426
501
|
"""Verifies that the names of the data arrays are correct."""
|
|
427
|
-
|
|
502
|
+
# Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
|
|
503
|
+
arrays = (
|
|
428
504
|
self.kpi,
|
|
429
505
|
self.controls,
|
|
430
506
|
self.population,
|
|
@@ -438,7 +514,7 @@ class InputData:
|
|
|
438
514
|
self.reach,
|
|
439
515
|
self.frequency,
|
|
440
516
|
self.rf_spend,
|
|
441
|
-
|
|
517
|
+
)
|
|
442
518
|
|
|
443
519
|
for array, name in zip(arrays, constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES):
|
|
444
520
|
if array is not None and array.name != name:
|
|
@@ -479,7 +555,6 @@ class InputData:
|
|
|
479
555
|
[
|
|
480
556
|
[constants.RF_CHANNEL],
|
|
481
557
|
[constants.GEO, constants.TIME, constants.RF_CHANNEL],
|
|
482
|
-
[constants.GEO, constants.RF_CHANNEL],
|
|
483
558
|
],
|
|
484
559
|
)
|
|
485
560
|
_check_dim_collection(
|
|
@@ -848,3 +923,24 @@ class InputData:
|
|
|
848
923
|
return self.media_spend.values
|
|
849
924
|
else:
|
|
850
925
|
raise ValueError("Both RF and Media are missing.")
|
|
926
|
+
|
|
927
|
+
def get_total_outcome(self) -> np.ndarray:
|
|
928
|
+
"""Returns total outcome, aggregated over geos and times."""
|
|
929
|
+
if self.revenue_per_kpi is None:
|
|
930
|
+
return np.sum(self.kpi.values)
|
|
931
|
+
return np.sum(self.kpi.values * self.revenue_per_kpi.values)
|
|
932
|
+
|
|
933
|
+
def _allocate_spend(self, spend: xr.DataArray, media_units: xr.DataArray):
|
|
934
|
+
"""Allocates spend across geo and time proportionally to media units."""
|
|
935
|
+
n_times = len(self.kpi.coords[constants.TIME])
|
|
936
|
+
selected_media_units = media_units.isel(media_time=slice(-n_times, None))
|
|
937
|
+
total_media_units_per_channel = selected_media_units.sum(
|
|
938
|
+
dim=["geo", "media_time"]
|
|
939
|
+
)
|
|
940
|
+
proportions = selected_media_units / total_media_units_per_channel
|
|
941
|
+
expanded_spend = spend.expand_dims({
|
|
942
|
+
"geo": selected_media_units["geo"],
|
|
943
|
+
"media_time": selected_media_units["media_time"],
|
|
944
|
+
})
|
|
945
|
+
allocated_spend = expanded_spend * proportions
|
|
946
|
+
return allocated_spend.rename({"media_time": "time"})
|
meridian/data/test_utils.py
CHANGED
|
@@ -65,12 +65,24 @@ _REQUIRED_COORDS = immutabledict.immutabledict({
|
|
|
65
65
|
c.MEDIA_TIME: _sample_times(n_times=3),
|
|
66
66
|
c.CONTROL_VARIABLE: ['control_0', 'control_1'],
|
|
67
67
|
})
|
|
68
|
+
_NON_MEDIA_COORDS = immutabledict.immutabledict(
|
|
69
|
+
{c.NON_MEDIA_CHANNEL: ['non_media_channel_0', 'non_media_channel_1']}
|
|
70
|
+
)
|
|
68
71
|
_MEDIA_COORDS = immutabledict.immutabledict(
|
|
69
72
|
{c.MEDIA_CHANNEL: ['media_channel_0', 'media_channel_1', 'media_channel_2']}
|
|
70
73
|
)
|
|
74
|
+
_ORGANIC_MEDIA_COORDS = immutabledict.immutabledict({
|
|
75
|
+
c.ORGANIC_MEDIA_CHANNEL: [
|
|
76
|
+
'organic_media_channel_0',
|
|
77
|
+
'organic_media_channel_1',
|
|
78
|
+
]
|
|
79
|
+
})
|
|
71
80
|
_RF_COORDS = immutabledict.immutabledict(
|
|
72
81
|
{c.RF_CHANNEL: ['rf_channel_0', 'rf_channel_1']}
|
|
73
82
|
)
|
|
83
|
+
_ORGANIC_RF_COORDS = immutabledict.immutabledict(
|
|
84
|
+
{c.ORGANIC_RF_CHANNEL: ['organic_rf_channel_0', 'organic_rf_channel_1']}
|
|
85
|
+
)
|
|
74
86
|
|
|
75
87
|
_REQUIRED_DATA_VARS = immutabledict.immutabledict({
|
|
76
88
|
c.KPI: (['geo', 'time'], [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
|
|
@@ -376,6 +388,70 @@ DATASET_WITHOUT_TIME_VARIATION_IN_REACH = xr.Dataset(
|
|
|
376
388
|
},
|
|
377
389
|
)
|
|
378
390
|
|
|
391
|
+
DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_MEDIA = xr.Dataset(
|
|
392
|
+
coords=_REQUIRED_COORDS
|
|
393
|
+
| _MEDIA_COORDS
|
|
394
|
+
| _RF_COORDS
|
|
395
|
+
| _ORGANIC_MEDIA_COORDS,
|
|
396
|
+
data_vars=_REQUIRED_DATA_VARS
|
|
397
|
+
| _MEDIA_DATA_VARS
|
|
398
|
+
| _RF_DATA_VARS
|
|
399
|
+
| _OPTIONAL_DATA_VARS
|
|
400
|
+
| {
|
|
401
|
+
c.ORGANIC_MEDIA: (
|
|
402
|
+
['geo', 'media_time', 'organic_media_channel'],
|
|
403
|
+
[
|
|
404
|
+
[[2.1, 2.2], [2.1, 2.21], [2.1, 2.2]],
|
|
405
|
+
[[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
|
|
406
|
+
],
|
|
407
|
+
),
|
|
408
|
+
},
|
|
409
|
+
)
|
|
410
|
+
|
|
411
|
+
DATASET_WITHOUT_TIME_VARIATION_IN_ORGANIC_REACH = xr.Dataset(
|
|
412
|
+
coords=_REQUIRED_COORDS
|
|
413
|
+
| _MEDIA_COORDS
|
|
414
|
+
| _RF_COORDS
|
|
415
|
+
| _ORGANIC_RF_COORDS,
|
|
416
|
+
data_vars=_REQUIRED_DATA_VARS
|
|
417
|
+
| _MEDIA_DATA_VARS
|
|
418
|
+
| _RF_DATA_VARS
|
|
419
|
+
| _OPTIONAL_DATA_VARS
|
|
420
|
+
| {
|
|
421
|
+
c.ORGANIC_REACH: (
|
|
422
|
+
['geo', 'media_time', 'organic_rf_channel'],
|
|
423
|
+
[
|
|
424
|
+
[[2.1, 2.2], [2.11, 2.2], [2.1, 2.2]],
|
|
425
|
+
[[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
|
|
426
|
+
],
|
|
427
|
+
),
|
|
428
|
+
c.ORGANIC_FREQUENCY: (
|
|
429
|
+
['geo', 'media_time', 'organic_rf_channel'],
|
|
430
|
+
[
|
|
431
|
+
[[7.1, 7.2], [7.3, 7.4], [7.5, 7.6]],
|
|
432
|
+
[[7.11, 7.21], [7.31, 7.41], [7.51, 7.61]],
|
|
433
|
+
],
|
|
434
|
+
),
|
|
435
|
+
},
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
DATASET_WITHOUT_TIME_VARIATION_IN_NON_MEDIA_TREATMENTS = xr.Dataset(
|
|
439
|
+
coords=_REQUIRED_COORDS | _MEDIA_COORDS | _RF_COORDS | _NON_MEDIA_COORDS,
|
|
440
|
+
data_vars=_REQUIRED_DATA_VARS
|
|
441
|
+
| _MEDIA_DATA_VARS
|
|
442
|
+
| _RF_DATA_VARS
|
|
443
|
+
| _OPTIONAL_DATA_VARS
|
|
444
|
+
| {
|
|
445
|
+
c.NON_MEDIA_TREATMENTS: (
|
|
446
|
+
['geo', 'time', 'non_media_channel'],
|
|
447
|
+
[
|
|
448
|
+
[[2.1, 2.2], [2.1, 2.2], [2.1, 2.2]],
|
|
449
|
+
[[2.7, 2.8], [2.7, 2.8], [2.7, 2.8]],
|
|
450
|
+
],
|
|
451
|
+
),
|
|
452
|
+
},
|
|
453
|
+
)
|
|
454
|
+
|
|
379
455
|
_NATIONAL_COORDS = immutabledict.immutabledict({
|
|
380
456
|
c.TIME: [
|
|
381
457
|
_SAMPLE_START_DATE.strftime(c.DATE_FORMAT),
|
|
@@ -1491,17 +1567,52 @@ def sample_input_data_from_dataset(
|
|
|
1491
1567
|
dataset: xr.Dataset, kpi_type: str
|
|
1492
1568
|
) -> input_data.InputData:
|
|
1493
1569
|
"""Generates a sample `InputData` from a full xarray Dataset."""
|
|
1570
|
+
media = dataset.media if c.MEDIA in dataset.data_vars.keys() else None
|
|
1571
|
+
media_spend = (
|
|
1572
|
+
dataset.media_spend if c.MEDIA_SPEND in dataset.data_vars.keys() else None
|
|
1573
|
+
)
|
|
1574
|
+
reach = dataset.reach if c.REACH in dataset.data_vars.keys() else None
|
|
1575
|
+
frequency = (
|
|
1576
|
+
dataset.frequency if c.FREQUENCY in dataset.data_vars.keys() else None
|
|
1577
|
+
)
|
|
1578
|
+
rf_spend = (
|
|
1579
|
+
dataset.rf_spend if c.RF_SPEND in dataset.data_vars.keys() else None
|
|
1580
|
+
)
|
|
1581
|
+
organic_media = (
|
|
1582
|
+
dataset.organic_media
|
|
1583
|
+
if c.ORGANIC_MEDIA in dataset.data_vars.keys()
|
|
1584
|
+
else None
|
|
1585
|
+
)
|
|
1586
|
+
organic_reach = (
|
|
1587
|
+
dataset.organic_reach
|
|
1588
|
+
if c.ORGANIC_REACH in dataset.data_vars.keys()
|
|
1589
|
+
else None
|
|
1590
|
+
)
|
|
1591
|
+
organic_frequency = (
|
|
1592
|
+
dataset.organic_frequency
|
|
1593
|
+
if c.ORGANIC_FREQUENCY in dataset.data_vars.keys()
|
|
1594
|
+
else None
|
|
1595
|
+
)
|
|
1596
|
+
non_media_treatments = (
|
|
1597
|
+
dataset.non_media_treatments
|
|
1598
|
+
if c.NON_MEDIA_TREATMENTS in dataset.data_vars.keys()
|
|
1599
|
+
else None
|
|
1600
|
+
)
|
|
1494
1601
|
return input_data.InputData(
|
|
1495
1602
|
kpi=dataset.kpi,
|
|
1496
1603
|
kpi_type=kpi_type,
|
|
1497
1604
|
revenue_per_kpi=dataset.revenue_per_kpi,
|
|
1498
1605
|
population=dataset.population,
|
|
1499
1606
|
controls=dataset.controls,
|
|
1500
|
-
media=
|
|
1501
|
-
media_spend=
|
|
1502
|
-
reach=
|
|
1503
|
-
frequency=
|
|
1504
|
-
rf_spend=
|
|
1607
|
+
media=media,
|
|
1608
|
+
media_spend=media_spend,
|
|
1609
|
+
reach=reach,
|
|
1610
|
+
frequency=frequency,
|
|
1611
|
+
rf_spend=rf_spend,
|
|
1612
|
+
organic_media=organic_media,
|
|
1613
|
+
organic_reach=organic_reach,
|
|
1614
|
+
organic_frequency=organic_frequency,
|
|
1615
|
+
non_media_treatments=non_media_treatments,
|
|
1505
1616
|
)
|
|
1506
1617
|
|
|
1507
1618
|
|
|
meridian/data/time_coordinates.py
CHANGED
|
@@ -36,7 +36,7 @@ __all__ = [
|
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
# A type alias for a polymorphic "date" type.
|
|
39
|
-
Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64
|
|
39
|
+
Date: TypeAlias = str | datetime.datetime | datetime.date | np.datetime64 | None
|
|
40
40
|
|
|
41
41
|
# A type alias for a polymorphic "date interval" type. In all variants it is
|
|
42
42
|
# always a tuple of (start_date, end_date).
|
|
@@ -236,8 +236,8 @@ class TimeCoordinates:
|
|
|
236
236
|
|
|
237
237
|
def expand_selected_time_dims(
|
|
238
238
|
self,
|
|
239
|
-
start_date: Date
|
|
240
|
-
end_date: Date
|
|
239
|
+
start_date: Date = None,
|
|
240
|
+
end_date: Date = None,
|
|
241
241
|
) -> list[datetime.date] | None:
|
|
242
242
|
"""Validates and returns time dimension values based on the selected times.
|
|
243
243
|
|