google-meridian 1.1.3__py3-none-any.whl → 1.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.4.dist-info}/METADATA +2 -2
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.4.dist-info}/RECORD +12 -12
- meridian/analysis/analyzer.py +18 -11
- meridian/analysis/optimizer.py +292 -47
- meridian/constants.py +2 -0
- meridian/data/data_frame_input_data_builder.py +41 -0
- meridian/data/input_data_builder.py +3 -1
- meridian/data/load.py +210 -350
- meridian/version.py +1 -1
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.4.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.4.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.3.dist-info → google_meridian-1.1.4.dist-info}/top_level.txt +0 -0
meridian/data/load.py
CHANGED
@@ -22,14 +22,11 @@ object.
 import abc
 from collections.abc import Mapping, Sequence
 import dataclasses
-import datetime as dt
-import warnings
-
 import immutabledict
 from meridian import constants
 from meridian.data import data_frame_input_data_builder
 from meridian.data import input_data
-import numpy as np
+from meridian.data import input_data_builder
 import pandas as pd
 import xarray as xr
 
@@ -202,346 +199,124 @@ class XrDatasetDataLoader(InputDataLoader):
     if (constants.GEO) not in self.dataset.sizes.keys():
       self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
 
-    if len(self.dataset.coords[constants.GEO]) == 1:
-      if constants.POPULATION in self.dataset.data_vars.keys():
-        warnings.warn(
-            'The `population` argument is ignored in a nationally aggregated'
-            ' model. It will be reset to [1]'
-        )
-        self.dataset = self.dataset.drop_vars(names=[constants.POPULATION])
-
-      # Add a default `population` [1].
-      national_population_darray = xr.DataArray(
-          [constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE],
-          dims=[constants.GEO],
-          coords={
-              constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME],
-          },
-          name=constants.POPULATION,
-      )
-      self.dataset = xr.combine_by_coords(
-          [
-              national_population_darray,
-              self.dataset.assign_coords(
-                  {constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]}
-              ),
-          ],
-          compat='override',
-      )
-
     if constants.MEDIA_TIME not in self.dataset.sizes.keys():
-      self._add_media_time()
-    self._normalize_time_coordinates(constants.TIME)
-    self._normalize_time_coordinates(constants.MEDIA_TIME)
-    self._validate_dataset()
-
-  def _normalize_time_coordinates(self, dim: str):
-    if self.dataset.coords.dtypes[dim] == np.dtype('datetime64[ns]'):
-      date_strvalues = np.datetime_as_string(self.dataset.coords[dim], unit='D')
-      self.dataset = self.dataset.assign_coords({dim: date_strvalues})
-
-    # Assume that the time coordinate labels are date-formatted strings.
-    # We don't currently support other, arbitrary object types in the loaders.
-    for time in self.dataset.coords[dim].values:
-      try:
-        _ = dt.datetime.strptime(time, constants.DATE_FORMAT)
-      except ValueError as exc:
-        raise ValueError(
-            f"Invalid time label: '{time}'. Expected format:"
-            f" '{constants.DATE_FORMAT}'"
-        ) from exc
+      na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
 
-    ...
-            " Please use the 'name_mapping' argument to rename the coordinates."
+      if constants.CONTROLS in self.dataset.data_vars.keys():
+        na_mask |= (
+            self.dataset[constants.CONTROLS]
+            .isnull()
+            .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
         )
 
-    ...
+      if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
+        na_mask |= (
+            self.dataset[constants.NON_MEDIA_TREATMENTS]
+            .isnull()
+            .any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
        )
 
-    ...
-    if ...
-    ...
-    if ...
-    ...
-    if missing_media_input and missing_rf_input:
-      raise ValueError(
-          "Some required data is missing. Please use the 'name_mapping'"
-          ' argument to rename the coordinates/arrays. It is required to have'
-          ' at least one of media or reach and frequency.'
-      )
-
-    if missing_media_input and len(missing_media_input) != len(
-        constants.MEDIA_INPUT_DATA_COORD_NAMES
-    ) + len(constants.MEDIA_INPUT_DATA_ARRAY_NAMES):
-      raise ValueError(
-          f"Media data is partially missing. '{missing_media_input}' not found"
-          " in dataset's coordinates/arrays. Please use the 'name_mapping'"
-          ' argument to rename the coordinates/arrays.'
-      )
-
-    if missing_rf_input and len(missing_rf_input) != len(
-        constants.RF_INPUT_DATA_COORD_NAMES
-    ) + len(constants.RF_INPUT_DATA_ARRAY_NAMES):
-      raise ValueError(
-          f"RF data is partially missing. '{missing_rf_input}' not found in"
-          " dataset's coordinates/arrays. Please use the 'name_mapping'"
-          ' argument to rename the coordinates/arrays.'
-      )
-
-  def _add_media_time(self):
-    """Creates the `media_time` coordinate if it is not provided directly.
-
-    The user can either create both `time` and `media_time` coordinates directly
-    and use them to provide the lagged data for `media`, `reach` and `frequency`
-    arrays, or use the `time` coordinate for all arrays. In the second case,
-    the lagged period will be determined and the `media_time` and `time`
-    coordinates will be created based on the missing values in the other arrays:
-    `kpi`, `revenue_per_kpi`, `controls`, `media_spend`, `rf_spend`. The
-    analogous mechanism to determine the lagged period is used in
-    `DataFrameDataLoader` and `CsvDataLoader`.
-    """
-    # Check if there are no NAs in media.
-    if constants.MEDIA in self.dataset.data_vars.keys():
-      if self.dataset.media.isnull().any(axis=None):
-        raise ValueError('NA values found in the media array.')
-
-    # Check if there are no NAs in reach & frequency.
-    if constants.REACH in self.dataset.data_vars.keys():
-      if self.dataset.reach.isnull().any(axis=None):
-        raise ValueError('NA values found in the reach array.')
-    if constants.FREQUENCY in self.dataset.data_vars.keys():
-      if self.dataset.frequency.isnull().any(axis=None):
-        raise ValueError('NA values found in the frequency array.')
-
-    # Check if there are no NAs in organic media.
-    if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
-      if self.dataset.organic_media.isnull().any(axis=None):
-        raise ValueError('NA values found in the organic media array.')
-
-    # Check if there are no NAs in organic reach & frequency.
-    if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
-      if self.dataset.organic_reach.isnull().any(axis=None):
-        raise ValueError('NA values found in the organic reach array.')
-    if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
-      if self.dataset.organic_frequency.isnull().any(axis=None):
-        raise ValueError('NA values found in the organic frequency array.')
-
-    # Arrays in which NAs are expected in the lagged-media period.
-    na_arrays = [
-        constants.KPI,
-    ]
-
-    na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
-
-    if constants.CONTROLS in self.dataset.data_vars.keys():
-      na_arrays.append(constants.CONTROLS)
-      na_mask |= (
-          self.dataset[constants.CONTROLS]
-          .isnull()
-          .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
-      )
-
-    if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
-      na_arrays.append(constants.NON_MEDIA_TREATMENTS)
-      na_mask |= (
-          self.dataset[constants.NON_MEDIA_TREATMENTS]
-          .isnull()
-          .any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
-      )
-
-    if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
-      na_arrays.append(constants.REVENUE_PER_KPI)
-      na_mask |= (
-          self.dataset[constants.REVENUE_PER_KPI]
-          .isnull()
-          .any(dim=constants.GEO)
-      )
-    if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
-      na_arrays.append(constants.MEDIA_SPEND)
-      na_mask |= (
-          self.dataset[constants.MEDIA_SPEND]
-          .isnull()
-          .any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
-      )
-    if constants.RF_SPEND in self.dataset.data_vars.keys():
-      na_arrays.append(constants.RF_SPEND)
-      na_mask |= (
-          self.dataset[constants.RF_SPEND]
-          .isnull()
-          .any(dim=[constants.GEO, constants.RF_CHANNEL])
-      )
-
-    # Dates with at least one non-NA value in non-media columns
-    no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
-
-    # Dates with 100% NA values in all non-media columns.
-    na_period = self.dataset[constants.TIME].isel(time=na_mask).values
-
-    # Check if na_period is a continuous window starting from the earliest time
-    # period.
-    if not np.all(
-        np.sort(na_period)
-        == np.sort(np.unique(self.dataset[constants.TIME]))[: len(na_period)]
-    ):
-      raise ValueError(
-          "The 'lagged media' period (period with 100% NA values in all"
-          f' non-media columns) {na_period} is not a continuous window starting'
-          ' from the earliest time period.'
-      )
-
-    # Check if for the non-lagged period, there are no NAs in non-media data
-    for array in na_arrays:
-      if np.any(np.isnan(self.dataset[array].isel(time=~na_mask))):
-        raise ValueError(
-            'NA values found in other than media columns outside the'
-            f' lagged-media period {na_period} (continuous window of 100% NA'
-            ' values in all other than media columns).'
+      if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
+        na_mask |= (
+            self.dataset[constants.REVENUE_PER_KPI]
+            .isnull()
+            .any(dim=constants.GEO)
+        )
+      if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
+        na_mask |= (
+            self.dataset[constants.MEDIA_SPEND]
+            .isnull()
+            .any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
+        )
+      if constants.RF_SPEND in self.dataset.data_vars.keys():
+        na_mask |= (
+            self.dataset[constants.RF_SPEND]
+            .isnull()
+            .any(dim=[constants.GEO, constants.RF_CHANNEL])
        )
 
-    ...
+      # Dates with at least one non-NA value in non-media columns
+      no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
 
-    ...
-    )
+      # Create new `time` and `media_time` coordinates.
+      new_time = 'new_time'
 
-    ...
-        .dropna(dim=constants.TIME)
-        .rename({constants.TIME: new_time})
-    )
-    if constants.CONTROLS in new_dataset.data_vars.keys():
-      new_dataset[constants.CONTROLS] = (
-          new_dataset[constants.CONTROLS]
-          .dropna(dim=constants.TIME)
-          .rename({constants.TIME: new_time})
-      )
-    if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
-      new_dataset[constants.NON_MEDIA_TREATMENTS] = (
-          new_dataset[constants.NON_MEDIA_TREATMENTS]
-          .dropna(dim=constants.TIME)
-          .rename({constants.TIME: new_time})
+      new_dataset = self.dataset.assign_coords(
+          new_time=(new_time, no_na_period),
      )
 
-    ...
-          new_dataset[constants.REVENUE_PER_KPI]
+      new_dataset[constants.KPI] = (
+          new_dataset[constants.KPI]
          .dropna(dim=constants.TIME)
          .rename({constants.TIME: new_time})
      )
+      if constants.CONTROLS in new_dataset.data_vars.keys():
+        new_dataset[constants.CONTROLS] = (
+            new_dataset[constants.CONTROLS]
+            .dropna(dim=constants.TIME)
+            .rename({constants.TIME: new_time})
+        )
+      if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
+        new_dataset[constants.NON_MEDIA_TREATMENTS] = (
+            new_dataset[constants.NON_MEDIA_TREATMENTS]
+            .dropna(dim=constants.TIME)
+            .rename({constants.TIME: new_time})
+        )
 
-    ...
+      if constants.REVENUE_PER_KPI in new_dataset.data_vars.keys():
+        new_dataset[constants.REVENUE_PER_KPI] = (
+            new_dataset[constants.REVENUE_PER_KPI]
+            .dropna(dim=constants.TIME)
+            .rename({constants.TIME: new_time})
+        )
 
-    ...
+      if constants.MEDIA_SPEND in new_dataset.data_vars.keys():
+        new_dataset[constants.MEDIA_SPEND] = (
+            new_dataset[constants.MEDIA_SPEND]
+            .dropna(dim=constants.TIME)
+            .rename({constants.TIME: new_time})
+        )
 
-    ...
+      if constants.RF_SPEND in new_dataset.data_vars.keys():
+        new_dataset[constants.RF_SPEND] = (
+            new_dataset[constants.RF_SPEND]
+            .dropna(dim=constants.TIME)
+            .rename({constants.TIME: new_time})
+        )
+
+      self.dataset = new_dataset.rename(
+          {constants.TIME: constants.MEDIA_TIME, new_time: constants.TIME}
+      )
 
   def load(self) -> input_data.InputData:
     """Returns an `InputData` object containing the data from the dataset."""
-    ...
-    )
-    ...
-    )
-    ...
-    )
-    ...
-        if constants.FREQUENCY in self.dataset.data_vars.keys()
-        else None
-    )
-    rf_spend = (
-        self.dataset.rf_spend
-        if constants.RF_SPEND in self.dataset.data_vars.keys()
-        else None
-    )
-    non_media_treatments = (
-        self.dataset.non_media_treatments
-        if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys()
-        else None
-    )
-    organic_media = (
-        self.dataset.organic_media
-        if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys()
-        else None
-    )
-    organic_reach = (
-        self.dataset.organic_reach
-        if constants.ORGANIC_REACH in self.dataset.data_vars.keys()
-        else None
-    )
-    organic_frequency = (
-        self.dataset.organic_frequency
-        if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys()
-        else None
-    )
-    return input_data.InputData(
-        kpi=self.dataset.kpi,
-        kpi_type=self.kpi_type,
-        population=self.dataset.population,
-        controls=controls,
-        revenue_per_kpi=revenue_per_kpi,
-        media=media,
-        media_spend=media_spend,
-        reach=reach,
-        frequency=frequency,
-        rf_spend=rf_spend,
-        non_media_treatments=non_media_treatments,
-        organic_media=organic_media,
-        organic_reach=organic_reach,
-        organic_frequency=organic_frequency,
-    )
+    builder = input_data_builder.InputDataBuilder(self.kpi_type)
+    builder.kpi = self.dataset.kpi
+    if constants.POPULATION in self.dataset.data_vars.keys():
+      builder.population = self.dataset.population
+    if constants.CONTROLS in self.dataset.data_vars.keys():
+      builder.controls = self.dataset.controls
+    if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
+      builder.revenue_per_kpi = self.dataset.revenue_per_kpi
+    if constants.MEDIA in self.dataset.data_vars.keys():
+      builder.media = self.dataset.media
+    if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
+      builder.media_spend = self.dataset.media_spend
+    if constants.REACH in self.dataset.data_vars.keys():
+      builder.reach = self.dataset.reach
+    if constants.FREQUENCY in self.dataset.data_vars.keys():
+      builder.frequency = self.dataset.frequency
+    if constants.RF_SPEND in self.dataset.data_vars.keys():
+      builder.rf_spend = self.dataset.rf_spend
+    if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
+      builder.non_media_treatments = self.dataset.non_media_treatments
+    if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
+      builder.organic_media = self.dataset.organic_media
+    if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
+      builder.organic_reach = self.dataset.organic_reach
+    if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
+      builder.organic_frequency = self.dataset.organic_frequency
+    return builder.build()
 
 
 @dataclasses.dataclass(frozen=True)
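The main change in this hunk: `XrDatasetDataLoader` drops the national-model population rewrite and the `_normalize_time_coordinates`, `_validate_dataset`, and `_add_media_time` helpers. When the dataset has no explicit `media_time` dimension, the loader now builds an NA mask over `time` from the KPI and, when present, controls, non-media treatments, revenue per KPI, and spend; dates that are entirely NA form the lagged `media_time`-only window, the remaining dates become `time`, and `load()` hands the arrays to `input_data_builder.InputDataBuilder` instead of constructing `InputData` directly. A minimal sketch of the masking idea, using plain xarray rather than Meridian's classes (the array and coordinate names here are invented):

# Sketch only: dates where every non-media value is NA are treated as the
# lagged "media_time"-only window; the rest become the model "time" window.
import numpy as np
import xarray as xr

times = ['2024-01-01', '2024-01-08', '2024-01-15', '2024-01-22']
kpi = xr.DataArray(
    [[np.nan, np.nan, 10.0, 12.0]],  # the first two dates carry media only
    dims=['geo', 'time'],
    coords={'geo': ['geo_0'], 'time': times},
    name='kpi',
)

# True where the KPI is NA across all geos; the real loader ORs in controls,
# non-media treatments, revenue per KPI, and spend the same way.
na_mask = kpi.isnull().any(dim='geo')

media_time = kpi['time'].values                # all four dates
time = kpi['time'].isel(time=~na_mask).values  # only the last two dates

print(list(media_time))
print(list(time))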
@@ -607,6 +382,9 @@ class CoordToColumns:
           ' both.'
       )
 
+    if self.revenue_per_kpi is not None and not self.revenue_per_kpi.strip():
+      raise ValueError('`revenue_per_kpi` should not be empty if provided.')
+
 
 @dataclasses.dataclass
 class DataFrameDataLoader(InputDataLoader):
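`CoordToColumns.__post_init__` gains one guard: a `revenue_per_kpi` value that is an empty or whitespace-only string is now rejected instead of being passed through to the loader. A hypothetical construction showing the new failure mode (the column names are invented, and the media arguments are only there to satisfy the existing media/RF checks):

from meridian.data import load

try:
  load.CoordToColumns(
      time='time',
      geo='geo',
      kpi='conversions',
      revenue_per_kpi='   ',        # blank string -> rejected in 1.1.4
      media=['tv_impressions'],
      media_spend=['tv_spend'],
  )
except ValueError as err:
  print(err)  # `revenue_per_kpi` should not be empty if provided.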
@@ -830,6 +608,7 @@ class DataFrameDataLoader(InputDataLoader):
             f'The {channel_dict} keys must have the same set of values as'
             f' the {coord_name} columns.'
         )
+
     if (
         self.media_to_channel is not None
         and self.media_spend_to_channel is not None
@@ -841,6 +620,27 @@ class DataFrameDataLoader(InputDataLoader):
             'The media and media_spend columns must have the same set of'
             ' channels.'
         )
+
+      # The columns listed in `media` and `media_spend` must correspond to the
+      # same channels, in user-given order!
+      # For example, this is invalid:
+      # media = ['impressions_tv', 'impressions_yt']
+      # media_spend = ['spend_yt', 'spend_tv']
+      # But we can only detect this after we map each `media` and `media_spend`
+      # column to its canonical channel name.
+      media_channels = [
+          self.media_to_channel[c] for c in self.coord_to_columns.media
+      ]
+      media_spend_channels = [
+          self.media_spend_to_channel[c]
+          for c in self.coord_to_columns.media_spend
+      ]
+      if media_channels != media_spend_channels:
+        raise ValueError(
+            'The `media` and `media_spend` columns must correspond to the same'
+            ' channels, in user order.'
+        )
+
     if (
         self.reach_to_channel is not None
         and self.frequency_to_channel is not None
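The new `__post_init__` check is stricter than the existing set comparison: it maps the `media` and `media_spend` column lists through their `*_to_channel` dicts and requires the resulting channel sequences to match element-wise, so spend columns listed in a different order than their impression columns now fail at construction time instead of being silently paired with the wrong channel. A standalone sketch of the same check, reusing the column names from the comment in the diff:

media_to_channel = {'impressions_tv': 'tv', 'impressions_yt': 'yt'}
media_spend_to_channel = {'spend_tv': 'tv', 'spend_yt': 'yt'}

media_cols = ['impressions_tv', 'impressions_yt']
media_spend_cols = ['spend_yt', 'spend_tv']  # swapped relative to media_cols

media_channels = [media_to_channel[c] for c in media_cols]                    # ['tv', 'yt']
media_spend_channels = [media_spend_to_channel[c] for c in media_spend_cols]  # ['yt', 'tv']

# The sets match, so a set-based check passes; the ordered comparison catches the swap.
assert set(media_channels) == set(media_spend_channels)
if media_channels != media_spend_channels:
  raise ValueError(
      'The `media` and `media_spend` columns must correspond to the same'
      ' channels, in user order.'
  )

The two hunks that follow apply the same ordered comparison to `reach`/`frequency`/`rf_spend` and to the organic reach and frequency columns.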
@@ -855,6 +655,23 @@ class DataFrameDataLoader(InputDataLoader):
             'The reach, frequency, and rf_spend columns must have the same set'
             ' of channels.'
         )
+
+      # Same channel ordering concerns as for `media` and `media_spend`.
+      reach_channels = [
+          self.reach_to_channel[c] for c in self.coord_to_columns.reach
+      ]
+      frequency_channels = [
+          self.frequency_to_channel[c] for c in self.coord_to_columns.frequency
+      ]
+      rf_spend_channels = [
+          self.rf_spend_to_channel[c] for c in self.coord_to_columns.rf_spend
+      ]
+      if not (reach_channels == frequency_channels == rf_spend_channels):
+        raise ValueError(
+            'The `reach`, `frequency`, and `rf_spend` columns must correspond'
+            ' to the same channels, in user order.'
+        )
+
     if (
         self.organic_reach_to_channel is not None
         and self.organic_frequency_to_channel is not None
@@ -867,6 +684,21 @@ class DataFrameDataLoader(InputDataLoader):
             ' same set of channels.'
         )
 
+      # Same channel ordering concerns as for `media` and `media_spend`.
+      organic_reach_channels = [
+          self.organic_reach_to_channel[c]
+          for c in self.coord_to_columns.organic_reach
+      ]
+      organic_frequency_channels = [
+          self.organic_frequency_to_channel[c]
+          for c in self.coord_to_columns.organic_frequency
+      ]
+      if organic_reach_channels != organic_frequency_channels:
+        raise ValueError(
+            'The `organic_reach` and `organic_frequency` columns must'
+            ' correspond to the same channels, in user order.'
+        )
+
   def load(self) -> input_data.InputData:
     """Reads data from a dataframe and returns an InputData object."""
 
@@ -878,66 +710,86 @@ class DataFrameDataLoader(InputDataLoader):
         self.coord_to_columns.time,
         self.coord_to_columns.geo,
     )
+
     if self.coord_to_columns.population in self.df.columns:
       builder.with_population(
           self.df, self.coord_to_columns.population, self.coord_to_columns.geo
       )
 
+    if self.coord_to_columns.controls:
       builder.with_controls(
           self.df,
           list(self.coord_to_columns.controls),
           self.coord_to_columns.time,
           self.coord_to_columns.geo,
       )
 
+    if self.coord_to_columns.non_media_treatments:
       builder.with_non_media_treatments(
           self.df,
           list(self.coord_to_columns.non_media_treatments),
           self.coord_to_columns.time,
           self.coord_to_columns.geo,
       )
 
+    if self.coord_to_columns.revenue_per_kpi:
       builder.with_revenue_per_kpi(
           self.df,
           self.coord_to_columns.revenue_per_kpi,
           self.coord_to_columns.time,
           self.coord_to_columns.geo,
       )
+
     if (
         self.media_to_channel is not None
         and self.media_spend_to_channel is not None
     ):
-      ...
+      # Based on the invariant rule enforced in `__post_init__`, the columns
+      # listed in `media` and `media_spend` are already validated to correspond
+      # to the same channels, in user-given order.
+      media_execution_columns = list(self.coord_to_columns.media)
+      media_spend_columns = list(self.coord_to_columns.media_spend)
+      # So now we can use one of the channel mapper dicts to get the canonical
+      # channel names for each column.
+      media_channel_names = [
+          self.media_to_channel[c] for c in self.coord_to_columns.media
+      ]
       builder.with_media(
           self.df,
-          ...
+          media_execution_columns,
+          media_spend_columns,
+          media_channel_names,
           self.coord_to_columns.time,
           self.coord_to_columns.geo,
       )
+
     if (
         self.reach_to_channel is not None
         and self.frequency_to_channel is not None
         and self.rf_spend_to_channel is not None
     ):
-      ...
+      # Based on the invariant rule enforced in `__post_init__`, the columns
+      # listed in `reach`, `frequency`, and `rf_spend` are already validated
+      # to correspond to the same channels, in user-given order.
+      reach_columns = list(self.coord_to_columns.reach)
+      frequency_columns = list(self.coord_to_columns.frequency)
+      rf_spend_columns = list(self.coord_to_columns.rf_spend)
+      # So now we can use one of the channel mapper dicts to get the canonical
+      # channel names for each column.
+      rf_channel_names = [
+          self.reach_to_channel[c] for c in self.coord_to_columns.reach
+      ]
       builder.with_reach(
           self.df,
-          ...
+          reach_columns,
+          frequency_columns,
+          rf_spend_columns,
+          rf_channel_names,
           self.coord_to_columns.time,
           self.coord_to_columns.geo,
       )
 
+    if self.coord_to_columns.organic_media:
       builder.with_organic_media(
           self.df,
           list(self.coord_to_columns.organic_media),
@@ -945,23 +797,31 @@ class DataFrameDataLoader(InputDataLoader):
           self.coord_to_columns.time,
           self.coord_to_columns.geo,
       )
+
     if (
         self.organic_reach_to_channel is not None
         and self.organic_frequency_to_channel is not None
     ):
-      ...
+      # Based on the invariant rule enforced in `__post_init__`, the columns
+      # listed in `organic_reach` and `organic_frequency` are already
+      # validated to correspond to the same channels, in user-given order.
+      organic_reach_columns = list(self.coord_to_columns.organic_reach)
+      organic_frequency_columns = list(self.coord_to_columns.organic_frequency)
+      # So now we can use one of the channel mapper dicts to get the canonical
+      # channel names for each column.
+      organic_rf_channel_names = [
+          self.organic_reach_to_channel[c]
+          for c in self.coord_to_columns.organic_reach
+      ]
       builder.with_organic_reach(
           self.df,
-          ...
+          organic_reach_columns,
+          organic_frequency_columns,
+          organic_rf_channel_names,
          self.coord_to_columns.time,
          self.coord_to_columns.geo,
      )
+
     return builder.build()
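End to end, the rewritten `DataFrameDataLoader.load()` passes the user-ordered column lists plus their mapped channel names straight into the builder's `with_media`, `with_reach`, and `with_organic_reach` calls, relying on the ordering invariant established in `__post_init__`; the caller-facing entry point is unchanged. A hedged usage sketch against the public Meridian API; the DataFrame, its column names, and the single `tv` channel are invented, and a real dataset would need many more time periods than this toy frame:

import pandas as pd
from meridian.data import load

# Two geos x three weeks of made-up data; column names are arbitrary.
df = pd.DataFrame({
    'geo': ['G0', 'G0', 'G0', 'G1', 'G1', 'G1'],
    'time': ['2024-01-01', '2024-01-08', '2024-01-15'] * 2,
    'conversions': [10, 12, 11, 20, 22, 21],
    'revenue_per_conversion': [1.5] * 6,
    'population': [1000, 1000, 1000, 2000, 2000, 2000],
    'tv_impressions': [5.0, 6.0, 5.5, 9.0, 9.5, 9.2],
    'tv_spend': [100.0, 120.0, 110.0, 180.0, 190.0, 185.0],
})

coord_to_columns = load.CoordToColumns(
    time='time',
    geo='geo',
    kpi='conversions',
    revenue_per_kpi='revenue_per_conversion',
    population='population',
    media=['tv_impressions'],
    media_spend=['tv_spend'],
)

loader = load.DataFrameDataLoader(
    df=df,
    kpi_type='non_revenue',
    coord_to_columns=coord_to_columns,
    media_to_channel={'tv_impressions': 'tv'},
    media_spend_to_channel={'tv_spend': 'tv'},
)
data = loader.load()  # InputData assembled via the DataFrame input data builder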