google-meridian 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.4.dist-info}/METADATA +2 -2
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.4.dist-info}/RECORD +18 -17
- meridian/__init__.py +6 -4
- meridian/analysis/analyzer.py +68 -25
- meridian/analysis/optimizer.py +298 -48
- meridian/constants.py +3 -0
- meridian/data/data_frame_input_data_builder.py +41 -0
- meridian/data/input_data_builder.py +12 -4
- meridian/data/load.py +262 -346
- meridian/mlflow/autolog.py +158 -6
- meridian/model/media.py +7 -0
- meridian/model/model.py +14 -16
- meridian/model/posterior_sampler.py +13 -9
- meridian/model/prior_sampler.py +4 -6
- meridian/version.py +17 -0
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.4.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.4.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.2.dist-info → google_meridian-1.1.4.dist-info}/top_level.txt +0 -0
meridian/data/load.py
CHANGED
|
@@ -22,14 +22,11 @@ object.
|
|
|
22
22
|
import abc
|
|
23
23
|
from collections.abc import Mapping, Sequence
|
|
24
24
|
import dataclasses
|
|
25
|
-
import datetime as dt
|
|
26
|
-
import warnings
|
|
27
|
-
|
|
28
25
|
import immutabledict
|
|
29
26
|
from meridian import constants
|
|
30
27
|
from meridian.data import data_frame_input_data_builder
|
|
31
28
|
from meridian.data import input_data
|
|
32
|
-
|
|
29
|
+
from meridian.data import input_data_builder
|
|
33
30
|
import pandas as pd
|
|
34
31
|
import xarray as xr
|
|
35
32
|
|
|
@@ -202,346 +199,124 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
202
199
|
if (constants.GEO) not in self.dataset.sizes.keys():
|
|
203
200
|
self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
|
|
204
201
|
|
|
205
|
-
if len(self.dataset.coords[constants.GEO]) == 1:
|
|
206
|
-
if constants.POPULATION in self.dataset.data_vars.keys():
|
|
207
|
-
warnings.warn(
|
|
208
|
-
'The `population` argument is ignored in a nationally aggregated'
|
|
209
|
-
' model. It will be reset to [1]'
|
|
210
|
-
)
|
|
211
|
-
self.dataset = self.dataset.drop_vars(names=[constants.POPULATION])
|
|
212
|
-
|
|
213
|
-
# Add a default `population` [1].
|
|
214
|
-
national_population_darray = xr.DataArray(
|
|
215
|
-
[constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE],
|
|
216
|
-
dims=[constants.GEO],
|
|
217
|
-
coords={
|
|
218
|
-
constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME],
|
|
219
|
-
},
|
|
220
|
-
name=constants.POPULATION,
|
|
221
|
-
)
|
|
222
|
-
self.dataset = xr.combine_by_coords(
|
|
223
|
-
[
|
|
224
|
-
national_population_darray,
|
|
225
|
-
self.dataset.assign_coords(
|
|
226
|
-
{constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]}
|
|
227
|
-
),
|
|
228
|
-
],
|
|
229
|
-
compat='override',
|
|
230
|
-
)
|
|
231
|
-
|
|
232
202
|
if constants.MEDIA_TIME not in self.dataset.sizes.keys():
|
|
233
|
-
self.
|
|
234
|
-
self._normalize_time_coordinates(constants.TIME)
|
|
235
|
-
self._normalize_time_coordinates(constants.MEDIA_TIME)
|
|
236
|
-
self._validate_dataset()
|
|
237
|
-
|
|
238
|
-
def _normalize_time_coordinates(self, dim: str):
|
|
239
|
-
if self.dataset.coords.dtypes[dim] == np.dtype('datetime64[ns]'):
|
|
240
|
-
date_strvalues = np.datetime_as_string(self.dataset.coords[dim], unit='D')
|
|
241
|
-
self.dataset = self.dataset.assign_coords({dim: date_strvalues})
|
|
242
|
-
|
|
243
|
-
# Assume that the time coordinate labels are date-formatted strings.
|
|
244
|
-
# We don't currently support other, arbitrary object types in the loaders.
|
|
245
|
-
for time in self.dataset.coords[dim].values:
|
|
246
|
-
try:
|
|
247
|
-
_ = dt.datetime.strptime(time, constants.DATE_FORMAT)
|
|
248
|
-
except ValueError as exc:
|
|
249
|
-
raise ValueError(
|
|
250
|
-
f"Invalid time label: '{time}'. Expected format:"
|
|
251
|
-
f" '{constants.DATE_FORMAT}'"
|
|
252
|
-
) from exc
|
|
203
|
+
na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
|
|
253
204
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
" Please use the 'name_mapping' argument to rename the coordinates."
|
|
205
|
+
if constants.CONTROLS in self.dataset.data_vars.keys():
|
|
206
|
+
na_mask |= (
|
|
207
|
+
self.dataset[constants.CONTROLS]
|
|
208
|
+
.isnull()
|
|
209
|
+
.any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
|
|
260
210
|
)
|
|
261
211
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
212
|
+
if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
|
|
213
|
+
na_mask |= (
|
|
214
|
+
self.dataset[constants.NON_MEDIA_TREATMENTS]
|
|
215
|
+
.isnull()
|
|
216
|
+
.any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
|
|
267
217
|
)
|
|
268
218
|
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
if
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
if
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
if missing_media_input and missing_rf_input:
|
|
288
|
-
raise ValueError(
|
|
289
|
-
"Some required data is missing. Please use the 'name_mapping'"
|
|
290
|
-
' argument to rename the coordinates/arrays. It is required to have'
|
|
291
|
-
' at least one of media or reach and frequency.'
|
|
292
|
-
)
|
|
293
|
-
|
|
294
|
-
if missing_media_input and len(missing_media_input) != len(
|
|
295
|
-
constants.MEDIA_INPUT_DATA_COORD_NAMES
|
|
296
|
-
) + len(constants.MEDIA_INPUT_DATA_ARRAY_NAMES):
|
|
297
|
-
raise ValueError(
|
|
298
|
-
f"Media data is partially missing. '{missing_media_input}' not found"
|
|
299
|
-
" in dataset's coordinates/arrays. Please use the 'name_mapping'"
|
|
300
|
-
' argument to rename the coordinates/arrays.'
|
|
301
|
-
)
|
|
302
|
-
|
|
303
|
-
if missing_rf_input and len(missing_rf_input) != len(
|
|
304
|
-
constants.RF_INPUT_DATA_COORD_NAMES
|
|
305
|
-
) + len(constants.RF_INPUT_DATA_ARRAY_NAMES):
|
|
306
|
-
raise ValueError(
|
|
307
|
-
f"RF data is partially missing. '{missing_rf_input}' not found in"
|
|
308
|
-
" dataset's coordinates/arrays. Please use the 'name_mapping'"
|
|
309
|
-
' argument to rename the coordinates/arrays.'
|
|
310
|
-
)
|
|
311
|
-
|
|
312
|
-
def _add_media_time(self):
|
|
313
|
-
"""Creates the `media_time` coordinate if it is not provided directly.
|
|
314
|
-
|
|
315
|
-
The user can either create both `time` and `media_time` coordinates directly
|
|
316
|
-
and use them to provide the lagged data for `media`, `reach` and `frequency`
|
|
317
|
-
arrays, or use the `time` coordinate for all arrays. In the second case,
|
|
318
|
-
the lagged period will be determined and the `media_time` and `time`
|
|
319
|
-
coordinates will be created based on the missing values in the other arrays:
|
|
320
|
-
`kpi`, `revenue_per_kpi`, `controls`, `media_spend`, `rf_spend`. The
|
|
321
|
-
analogous mechanism to determine the lagged period is used in
|
|
322
|
-
`DataFrameDataLoader` and `CsvDataLoader`.
|
|
323
|
-
"""
|
|
324
|
-
# Check if there are no NAs in media.
|
|
325
|
-
if constants.MEDIA in self.dataset.data_vars.keys():
|
|
326
|
-
if self.dataset.media.isnull().any(axis=None):
|
|
327
|
-
raise ValueError('NA values found in the media array.')
|
|
328
|
-
|
|
329
|
-
# Check if there are no NAs in reach & frequency.
|
|
330
|
-
if constants.REACH in self.dataset.data_vars.keys():
|
|
331
|
-
if self.dataset.reach.isnull().any(axis=None):
|
|
332
|
-
raise ValueError('NA values found in the reach array.')
|
|
333
|
-
if constants.FREQUENCY in self.dataset.data_vars.keys():
|
|
334
|
-
if self.dataset.frequency.isnull().any(axis=None):
|
|
335
|
-
raise ValueError('NA values found in the frequency array.')
|
|
336
|
-
|
|
337
|
-
# Check if there are no NAs in organic media.
|
|
338
|
-
if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
|
|
339
|
-
if self.dataset.organic_media.isnull().any(axis=None):
|
|
340
|
-
raise ValueError('NA values found in the organic media array.')
|
|
341
|
-
|
|
342
|
-
# Check if there are no NAs in organic reach & frequency.
|
|
343
|
-
if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
|
|
344
|
-
if self.dataset.organic_reach.isnull().any(axis=None):
|
|
345
|
-
raise ValueError('NA values found in the organic reach array.')
|
|
346
|
-
if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
|
|
347
|
-
if self.dataset.organic_frequency.isnull().any(axis=None):
|
|
348
|
-
raise ValueError('NA values found in the organic frequency array.')
|
|
349
|
-
|
|
350
|
-
# Arrays in which NAs are expected in the lagged-media period.
|
|
351
|
-
na_arrays = [
|
|
352
|
-
constants.KPI,
|
|
353
|
-
]
|
|
354
|
-
|
|
355
|
-
na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
|
|
356
|
-
|
|
357
|
-
if constants.CONTROLS in self.dataset.data_vars.keys():
|
|
358
|
-
na_arrays.append(constants.CONTROLS)
|
|
359
|
-
na_mask |= (
|
|
360
|
-
self.dataset[constants.CONTROLS]
|
|
361
|
-
.isnull()
|
|
362
|
-
.any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
|
|
363
|
-
)
|
|
364
|
-
|
|
365
|
-
if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
|
|
366
|
-
na_arrays.append(constants.NON_MEDIA_TREATMENTS)
|
|
367
|
-
na_mask |= (
|
|
368
|
-
self.dataset[constants.NON_MEDIA_TREATMENTS]
|
|
369
|
-
.isnull()
|
|
370
|
-
.any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
|
|
371
|
-
)
|
|
372
|
-
|
|
373
|
-
if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
|
|
374
|
-
na_arrays.append(constants.REVENUE_PER_KPI)
|
|
375
|
-
na_mask |= (
|
|
376
|
-
self.dataset[constants.REVENUE_PER_KPI]
|
|
377
|
-
.isnull()
|
|
378
|
-
.any(dim=constants.GEO)
|
|
379
|
-
)
|
|
380
|
-
if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
|
|
381
|
-
na_arrays.append(constants.MEDIA_SPEND)
|
|
382
|
-
na_mask |= (
|
|
383
|
-
self.dataset[constants.MEDIA_SPEND]
|
|
384
|
-
.isnull()
|
|
385
|
-
.any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
|
|
386
|
-
)
|
|
387
|
-
if constants.RF_SPEND in self.dataset.data_vars.keys():
|
|
388
|
-
na_arrays.append(constants.RF_SPEND)
|
|
389
|
-
na_mask |= (
|
|
390
|
-
self.dataset[constants.RF_SPEND]
|
|
391
|
-
.isnull()
|
|
392
|
-
.any(dim=[constants.GEO, constants.RF_CHANNEL])
|
|
393
|
-
)
|
|
394
|
-
|
|
395
|
-
# Dates with at least one non-NA value in non-media columns
|
|
396
|
-
no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
|
|
397
|
-
|
|
398
|
-
# Dates with 100% NA values in all non-media columns.
|
|
399
|
-
na_period = self.dataset[constants.TIME].isel(time=na_mask).values
|
|
400
|
-
|
|
401
|
-
# Check if na_period is a continuous window starting from the earliest time
|
|
402
|
-
# period.
|
|
403
|
-
if not np.all(
|
|
404
|
-
np.sort(na_period)
|
|
405
|
-
== np.sort(np.unique(self.dataset[constants.TIME]))[: len(na_period)]
|
|
406
|
-
):
|
|
407
|
-
raise ValueError(
|
|
408
|
-
"The 'lagged media' period (period with 100% NA values in all"
|
|
409
|
-
f' non-media columns) {na_period} is not a continuous window starting'
|
|
410
|
-
' from the earliest time period.'
|
|
411
|
-
)
|
|
412
|
-
|
|
413
|
-
# Check if for the non-lagged period, there are no NAs in non-media data
|
|
414
|
-
for array in na_arrays:
|
|
415
|
-
if np.any(np.isnan(self.dataset[array].isel(time=~na_mask))):
|
|
416
|
-
raise ValueError(
|
|
417
|
-
'NA values found in other than media columns outside the'
|
|
418
|
-
f' lagged-media period {na_period} (continuous window of 100% NA'
|
|
419
|
-
' values in all other than media columns).'
|
|
219
|
+
if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
|
|
220
|
+
na_mask |= (
|
|
221
|
+
self.dataset[constants.REVENUE_PER_KPI]
|
|
222
|
+
.isnull()
|
|
223
|
+
.any(dim=constants.GEO)
|
|
224
|
+
)
|
|
225
|
+
if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
|
|
226
|
+
na_mask |= (
|
|
227
|
+
self.dataset[constants.MEDIA_SPEND]
|
|
228
|
+
.isnull()
|
|
229
|
+
.any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
|
|
230
|
+
)
|
|
231
|
+
if constants.RF_SPEND in self.dataset.data_vars.keys():
|
|
232
|
+
na_mask |= (
|
|
233
|
+
self.dataset[constants.RF_SPEND]
|
|
234
|
+
.isnull()
|
|
235
|
+
.any(dim=[constants.GEO, constants.RF_CHANNEL])
|
|
420
236
|
)
|
|
421
237
|
|
|
422
|
-
|
|
423
|
-
|
|
238
|
+
# Dates with at least one non-NA value in non-media columns
|
|
239
|
+
no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
|
|
424
240
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
)
|
|
241
|
+
# Create new `time` and `media_time` coordinates.
|
|
242
|
+
new_time = 'new_time'
|
|
428
243
|
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
.dropna(dim=constants.TIME)
|
|
432
|
-
.rename({constants.TIME: new_time})
|
|
433
|
-
)
|
|
434
|
-
if constants.CONTROLS in new_dataset.data_vars.keys():
|
|
435
|
-
new_dataset[constants.CONTROLS] = (
|
|
436
|
-
new_dataset[constants.CONTROLS]
|
|
437
|
-
.dropna(dim=constants.TIME)
|
|
438
|
-
.rename({constants.TIME: new_time})
|
|
439
|
-
)
|
|
440
|
-
if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
|
|
441
|
-
new_dataset[constants.NON_MEDIA_TREATMENTS] = (
|
|
442
|
-
new_dataset[constants.NON_MEDIA_TREATMENTS]
|
|
443
|
-
.dropna(dim=constants.TIME)
|
|
444
|
-
.rename({constants.TIME: new_time})
|
|
244
|
+
new_dataset = self.dataset.assign_coords(
|
|
245
|
+
new_time=(new_time, no_na_period),
|
|
445
246
|
)
|
|
446
247
|
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
new_dataset[constants.REVENUE_PER_KPI]
|
|
248
|
+
new_dataset[constants.KPI] = (
|
|
249
|
+
new_dataset[constants.KPI]
|
|
450
250
|
.dropna(dim=constants.TIME)
|
|
451
251
|
.rename({constants.TIME: new_time})
|
|
452
252
|
)
|
|
253
|
+
if constants.CONTROLS in new_dataset.data_vars.keys():
|
|
254
|
+
new_dataset[constants.CONTROLS] = (
|
|
255
|
+
new_dataset[constants.CONTROLS]
|
|
256
|
+
.dropna(dim=constants.TIME)
|
|
257
|
+
.rename({constants.TIME: new_time})
|
|
258
|
+
)
|
|
259
|
+
if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
|
|
260
|
+
new_dataset[constants.NON_MEDIA_TREATMENTS] = (
|
|
261
|
+
new_dataset[constants.NON_MEDIA_TREATMENTS]
|
|
262
|
+
.dropna(dim=constants.TIME)
|
|
263
|
+
.rename({constants.TIME: new_time})
|
|
264
|
+
)
|
|
453
265
|
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
266
|
+
if constants.REVENUE_PER_KPI in new_dataset.data_vars.keys():
|
|
267
|
+
new_dataset[constants.REVENUE_PER_KPI] = (
|
|
268
|
+
new_dataset[constants.REVENUE_PER_KPI]
|
|
269
|
+
.dropna(dim=constants.TIME)
|
|
270
|
+
.rename({constants.TIME: new_time})
|
|
271
|
+
)
|
|
460
272
|
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
273
|
+
if constants.MEDIA_SPEND in new_dataset.data_vars.keys():
|
|
274
|
+
new_dataset[constants.MEDIA_SPEND] = (
|
|
275
|
+
new_dataset[constants.MEDIA_SPEND]
|
|
276
|
+
.dropna(dim=constants.TIME)
|
|
277
|
+
.rename({constants.TIME: new_time})
|
|
278
|
+
)
|
|
467
279
|
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
280
|
+
if constants.RF_SPEND in new_dataset.data_vars.keys():
|
|
281
|
+
new_dataset[constants.RF_SPEND] = (
|
|
282
|
+
new_dataset[constants.RF_SPEND]
|
|
283
|
+
.dropna(dim=constants.TIME)
|
|
284
|
+
.rename({constants.TIME: new_time})
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
self.dataset = new_dataset.rename(
|
|
288
|
+
{constants.TIME: constants.MEDIA_TIME, new_time: constants.TIME}
|
|
289
|
+
)
|
|
471
290
|
|
|
472
291
|
def load(self) -> input_data.InputData:
|
|
473
292
|
"""Returns an `InputData` object containing the data from the dataset."""
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
)
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
)
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
)
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
if constants.FREQUENCY in self.dataset.data_vars.keys()
|
|
502
|
-
else None
|
|
503
|
-
)
|
|
504
|
-
rf_spend = (
|
|
505
|
-
self.dataset.rf_spend
|
|
506
|
-
if constants.RF_SPEND in self.dataset.data_vars.keys()
|
|
507
|
-
else None
|
|
508
|
-
)
|
|
509
|
-
non_media_treatments = (
|
|
510
|
-
self.dataset.non_media_treatments
|
|
511
|
-
if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys()
|
|
512
|
-
else None
|
|
513
|
-
)
|
|
514
|
-
organic_media = (
|
|
515
|
-
self.dataset.organic_media
|
|
516
|
-
if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys()
|
|
517
|
-
else None
|
|
518
|
-
)
|
|
519
|
-
organic_reach = (
|
|
520
|
-
self.dataset.organic_reach
|
|
521
|
-
if constants.ORGANIC_REACH in self.dataset.data_vars.keys()
|
|
522
|
-
else None
|
|
523
|
-
)
|
|
524
|
-
organic_frequency = (
|
|
525
|
-
self.dataset.organic_frequency
|
|
526
|
-
if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys()
|
|
527
|
-
else None
|
|
528
|
-
)
|
|
529
|
-
return input_data.InputData(
|
|
530
|
-
kpi=self.dataset.kpi,
|
|
531
|
-
kpi_type=self.kpi_type,
|
|
532
|
-
population=self.dataset.population,
|
|
533
|
-
controls=controls,
|
|
534
|
-
revenue_per_kpi=revenue_per_kpi,
|
|
535
|
-
media=media,
|
|
536
|
-
media_spend=media_spend,
|
|
537
|
-
reach=reach,
|
|
538
|
-
frequency=frequency,
|
|
539
|
-
rf_spend=rf_spend,
|
|
540
|
-
non_media_treatments=non_media_treatments,
|
|
541
|
-
organic_media=organic_media,
|
|
542
|
-
organic_reach=organic_reach,
|
|
543
|
-
organic_frequency=organic_frequency,
|
|
544
|
-
)
|
|
293
|
+
builder = input_data_builder.InputDataBuilder(self.kpi_type)
|
|
294
|
+
builder.kpi = self.dataset.kpi
|
|
295
|
+
if constants.POPULATION in self.dataset.data_vars.keys():
|
|
296
|
+
builder.population = self.dataset.population
|
|
297
|
+
if constants.CONTROLS in self.dataset.data_vars.keys():
|
|
298
|
+
builder.controls = self.dataset.controls
|
|
299
|
+
if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
|
|
300
|
+
builder.revenue_per_kpi = self.dataset.revenue_per_kpi
|
|
301
|
+
if constants.MEDIA in self.dataset.data_vars.keys():
|
|
302
|
+
builder.media = self.dataset.media
|
|
303
|
+
if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
|
|
304
|
+
builder.media_spend = self.dataset.media_spend
|
|
305
|
+
if constants.REACH in self.dataset.data_vars.keys():
|
|
306
|
+
builder.reach = self.dataset.reach
|
|
307
|
+
if constants.FREQUENCY in self.dataset.data_vars.keys():
|
|
308
|
+
builder.frequency = self.dataset.frequency
|
|
309
|
+
if constants.RF_SPEND in self.dataset.data_vars.keys():
|
|
310
|
+
builder.rf_spend = self.dataset.rf_spend
|
|
311
|
+
if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
|
|
312
|
+
builder.non_media_treatments = self.dataset.non_media_treatments
|
|
313
|
+
if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
|
|
314
|
+
builder.organic_media = self.dataset.organic_media
|
|
315
|
+
if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
|
|
316
|
+
builder.organic_reach = self.dataset.organic_reach
|
|
317
|
+
if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
|
|
318
|
+
builder.organic_frequency = self.dataset.organic_frequency
|
|
319
|
+
return builder.build()
|
|
545
320
|
|
|
546
321
|
|
|
547
322
|
@dataclasses.dataclass(frozen=True)
|
|
@@ -607,6 +382,9 @@ class CoordToColumns:
|
|
|
607
382
|
' both.'
|
|
608
383
|
)
|
|
609
384
|
|
|
385
|
+
if self.revenue_per_kpi is not None and not self.revenue_per_kpi.strip():
|
|
386
|
+
raise ValueError('`revenue_per_kpi` should not be empty if provided.')
|
|
387
|
+
|
|
610
388
|
|
|
611
389
|
@dataclasses.dataclass
|
|
612
390
|
class DataFrameDataLoader(InputDataLoader):
|
|
@@ -816,12 +594,109 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
816
594
|
'organic_frequency': 'organic_frequency_to_channel',
|
|
817
595
|
})
|
|
818
596
|
for coord_name, channel_dict in required_mappings.items():
|
|
597
|
+
if getattr(self.coord_to_columns, coord_name, None) is not None:
|
|
598
|
+
if getattr(self, channel_dict, None) is None:
|
|
599
|
+
raise ValueError(
|
|
600
|
+
f"When {coord_name} data is provided, '{channel_dict}' is"
|
|
601
|
+
' required.'
|
|
602
|
+
)
|
|
603
|
+
else:
|
|
604
|
+
if set(getattr(self, channel_dict)) != set(
|
|
605
|
+
getattr(self.coord_to_columns, coord_name)
|
|
606
|
+
):
|
|
607
|
+
raise ValueError(
|
|
608
|
+
f'The {channel_dict} keys must have the same set of values as'
|
|
609
|
+
f' the {coord_name} columns.'
|
|
610
|
+
)
|
|
611
|
+
|
|
612
|
+
if (
|
|
613
|
+
self.media_to_channel is not None
|
|
614
|
+
and self.media_spend_to_channel is not None
|
|
615
|
+
):
|
|
616
|
+
if set(self.media_to_channel.values()) != set(
|
|
617
|
+
self.media_spend_to_channel.values()
|
|
618
|
+
):
|
|
619
|
+
raise ValueError(
|
|
620
|
+
'The media and media_spend columns must have the same set of'
|
|
621
|
+
' channels.'
|
|
622
|
+
)
|
|
623
|
+
|
|
624
|
+
# The columns listed in `media` and `media_spend` must correspond to the
|
|
625
|
+
# same channels, in user-given order!
|
|
626
|
+
# For example, this is invalid:
|
|
627
|
+
# media = ['impressions_tv', 'impressions_yt']
|
|
628
|
+
# media_spend = ['spend_yt', 'spend_tv']
|
|
629
|
+
# But we can only detect this after we map each `media` and `media_spend`
|
|
630
|
+
# column to its canonical channel name.
|
|
631
|
+
media_channels = [
|
|
632
|
+
self.media_to_channel[c] for c in self.coord_to_columns.media
|
|
633
|
+
]
|
|
634
|
+
media_spend_channels = [
|
|
635
|
+
self.media_spend_to_channel[c]
|
|
636
|
+
for c in self.coord_to_columns.media_spend
|
|
637
|
+
]
|
|
638
|
+
if media_channels != media_spend_channels:
|
|
639
|
+
raise ValueError(
|
|
640
|
+
'The `media` and `media_spend` columns must correspond to the same'
|
|
641
|
+
' channels, in user order.'
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
if (
|
|
645
|
+
self.reach_to_channel is not None
|
|
646
|
+
and self.frequency_to_channel is not None
|
|
647
|
+
and self.rf_spend_to_channel is not None
|
|
648
|
+
):
|
|
819
649
|
if (
|
|
820
|
-
|
|
821
|
-
|
|
650
|
+
set(self.reach_to_channel.values())
|
|
651
|
+
!= set(self.frequency_to_channel.values())
|
|
652
|
+
!= set(self.rf_spend_to_channel.values())
|
|
822
653
|
):
|
|
823
654
|
raise ValueError(
|
|
824
|
-
|
|
655
|
+
'The reach, frequency, and rf_spend columns must have the same set'
|
|
656
|
+
' of channels.'
|
|
657
|
+
)
|
|
658
|
+
|
|
659
|
+
# Same channel ordering concerns as for `media` and `media_spend`.
|
|
660
|
+
reach_channels = [
|
|
661
|
+
self.reach_to_channel[c] for c in self.coord_to_columns.reach
|
|
662
|
+
]
|
|
663
|
+
frequency_channels = [
|
|
664
|
+
self.frequency_to_channel[c] for c in self.coord_to_columns.frequency
|
|
665
|
+
]
|
|
666
|
+
rf_spend_channels = [
|
|
667
|
+
self.rf_spend_to_channel[c] for c in self.coord_to_columns.rf_spend
|
|
668
|
+
]
|
|
669
|
+
if not (reach_channels == frequency_channels == rf_spend_channels):
|
|
670
|
+
raise ValueError(
|
|
671
|
+
'The `reach`, `frequency`, and `rf_spend` columns must correspond'
|
|
672
|
+
' to the same channels, in user order.'
|
|
673
|
+
)
|
|
674
|
+
|
|
675
|
+
if (
|
|
676
|
+
self.organic_reach_to_channel is not None
|
|
677
|
+
and self.organic_frequency_to_channel is not None
|
|
678
|
+
):
|
|
679
|
+
if set(self.organic_reach_to_channel.values()) != set(
|
|
680
|
+
self.organic_frequency_to_channel.values()
|
|
681
|
+
):
|
|
682
|
+
raise ValueError(
|
|
683
|
+
'The organic_reach and organic_frequency columns must have the'
|
|
684
|
+
' same set of channels.'
|
|
685
|
+
)
|
|
686
|
+
|
|
687
|
+
# Same channel ordering concerns as for `media` and `media_spend`.
|
|
688
|
+
organic_reach_channels = [
|
|
689
|
+
self.organic_reach_to_channel[c]
|
|
690
|
+
for c in self.coord_to_columns.organic_reach
|
|
691
|
+
]
|
|
692
|
+
organic_frequency_channels = [
|
|
693
|
+
self.organic_frequency_to_channel[c]
|
|
694
|
+
for c in self.coord_to_columns.organic_frequency
|
|
695
|
+
]
|
|
696
|
+
if organic_reach_channels != organic_frequency_channels:
|
|
697
|
+
raise ValueError(
|
|
698
|
+
'The `organic_reach` and `organic_frequency` columns must'
|
|
699
|
+
' correspond to the same channels, in user order.'
|
|
825
700
|
)
|
|
826
701
|
|
|
827
702
|
def load(self) -> input_data.InputData:
|
|
@@ -835,58 +710,86 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
835
710
|
self.coord_to_columns.time,
|
|
836
711
|
self.coord_to_columns.geo,
|
|
837
712
|
)
|
|
713
|
+
|
|
838
714
|
if self.coord_to_columns.population in self.df.columns:
|
|
839
715
|
builder.with_population(
|
|
840
716
|
self.df, self.coord_to_columns.population, self.coord_to_columns.geo
|
|
841
717
|
)
|
|
842
|
-
|
|
718
|
+
|
|
719
|
+
if self.coord_to_columns.controls:
|
|
843
720
|
builder.with_controls(
|
|
844
721
|
self.df,
|
|
845
722
|
list(self.coord_to_columns.controls),
|
|
846
723
|
self.coord_to_columns.time,
|
|
847
724
|
self.coord_to_columns.geo,
|
|
848
725
|
)
|
|
849
|
-
|
|
726
|
+
|
|
727
|
+
if self.coord_to_columns.non_media_treatments:
|
|
850
728
|
builder.with_non_media_treatments(
|
|
851
729
|
self.df,
|
|
852
730
|
list(self.coord_to_columns.non_media_treatments),
|
|
853
731
|
self.coord_to_columns.time,
|
|
854
732
|
self.coord_to_columns.geo,
|
|
855
733
|
)
|
|
856
|
-
|
|
734
|
+
|
|
735
|
+
if self.coord_to_columns.revenue_per_kpi:
|
|
857
736
|
builder.with_revenue_per_kpi(
|
|
858
737
|
self.df,
|
|
859
738
|
self.coord_to_columns.revenue_per_kpi,
|
|
860
739
|
self.coord_to_columns.time,
|
|
861
740
|
self.coord_to_columns.geo,
|
|
862
741
|
)
|
|
742
|
+
|
|
863
743
|
if (
|
|
864
|
-
self.
|
|
865
|
-
and self.
|
|
744
|
+
self.media_to_channel is not None
|
|
745
|
+
and self.media_spend_to_channel is not None
|
|
866
746
|
):
|
|
747
|
+
# Based on the invariant rule enforced in `__post_init__`, the columns
|
|
748
|
+
# listed in `media` and `media_spend` are already validated to correspond
|
|
749
|
+
# to the same channels, in user-given order.
|
|
750
|
+
media_execution_columns = list(self.coord_to_columns.media)
|
|
751
|
+
media_spend_columns = list(self.coord_to_columns.media_spend)
|
|
752
|
+
# So now we can use one of the channel mapper dicts to get the canonical
|
|
753
|
+
# channel names for each column.
|
|
754
|
+
media_channel_names = [
|
|
755
|
+
self.media_to_channel[c] for c in self.coord_to_columns.media
|
|
756
|
+
]
|
|
867
757
|
builder.with_media(
|
|
868
758
|
self.df,
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
759
|
+
media_execution_columns,
|
|
760
|
+
media_spend_columns,
|
|
761
|
+
media_channel_names,
|
|
872
762
|
self.coord_to_columns.time,
|
|
873
763
|
self.coord_to_columns.geo,
|
|
874
764
|
)
|
|
875
765
|
|
|
876
766
|
if (
|
|
877
|
-
self.
|
|
878
|
-
and self.
|
|
767
|
+
self.reach_to_channel is not None
|
|
768
|
+
and self.frequency_to_channel is not None
|
|
769
|
+
and self.rf_spend_to_channel is not None
|
|
879
770
|
):
|
|
771
|
+
# Based on the invariant rule enforced in `__post_init__`, the columns
|
|
772
|
+
# listed in `reach`, `frequency`, and `rf_spend` are already validated
|
|
773
|
+
# to correspond to the same channels, in user-given order.
|
|
774
|
+
reach_columns = list(self.coord_to_columns.reach)
|
|
775
|
+
frequency_columns = list(self.coord_to_columns.frequency)
|
|
776
|
+
rf_spend_columns = list(self.coord_to_columns.rf_spend)
|
|
777
|
+
# So now we can use one of the channel mapper dicts to get the canonical
|
|
778
|
+
# channel names for each column.
|
|
779
|
+
rf_channel_names = [
|
|
780
|
+
self.reach_to_channel[c] for c in self.coord_to_columns.reach
|
|
781
|
+
]
|
|
880
782
|
builder.with_reach(
|
|
881
783
|
self.df,
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
784
|
+
reach_columns,
|
|
785
|
+
frequency_columns,
|
|
786
|
+
rf_spend_columns,
|
|
787
|
+
rf_channel_names,
|
|
886
788
|
self.coord_to_columns.time,
|
|
887
789
|
self.coord_to_columns.geo,
|
|
888
790
|
)
|
|
889
|
-
|
|
791
|
+
|
|
792
|
+
if self.coord_to_columns.organic_media:
|
|
890
793
|
builder.with_organic_media(
|
|
891
794
|
self.df,
|
|
892
795
|
list(self.coord_to_columns.organic_media),
|
|
@@ -894,18 +797,31 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
894
797
|
self.coord_to_columns.time,
|
|
895
798
|
self.coord_to_columns.geo,
|
|
896
799
|
)
|
|
800
|
+
|
|
897
801
|
if (
|
|
898
|
-
self.
|
|
899
|
-
and self.
|
|
802
|
+
self.organic_reach_to_channel is not None
|
|
803
|
+
and self.organic_frequency_to_channel is not None
|
|
900
804
|
):
|
|
805
|
+
# Based on the invariant rule enforced in `__post_init__`, the columns
|
|
806
|
+
# listed in `organic_reach` and `organic_frequency` are already
|
|
807
|
+
# validated to correspond to the same channels, in user-given order.
|
|
808
|
+
organic_reach_columns = list(self.coord_to_columns.organic_reach)
|
|
809
|
+
organic_frequency_columns = list(self.coord_to_columns.organic_frequency)
|
|
810
|
+
# So now we can use one of the channel mapper dicts to get the canonical
|
|
811
|
+
# channel names for each column.
|
|
812
|
+
organic_rf_channel_names = [
|
|
813
|
+
self.organic_reach_to_channel[c]
|
|
814
|
+
for c in self.coord_to_columns.organic_reach
|
|
815
|
+
]
|
|
901
816
|
builder.with_organic_reach(
|
|
902
817
|
self.df,
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
818
|
+
organic_reach_columns,
|
|
819
|
+
organic_frequency_columns,
|
|
820
|
+
organic_rf_channel_names,
|
|
906
821
|
self.coord_to_columns.time,
|
|
907
822
|
self.coord_to_columns.geo,
|
|
908
823
|
)
|
|
824
|
+
|
|
909
825
|
return builder.build()
|
|
910
826
|
|
|
911
827
|
|