google-meridian 1.1.3__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/data/load.py CHANGED
@@ -22,14 +22,11 @@ object.
22
22
  import abc
23
23
  from collections.abc import Mapping, Sequence
24
24
  import dataclasses
25
- import datetime as dt
26
- import warnings
27
-
28
25
  import immutabledict
29
26
  from meridian import constants
30
27
  from meridian.data import data_frame_input_data_builder
31
28
  from meridian.data import input_data
32
- import numpy as np
29
+ from meridian.data import input_data_builder
33
30
  import pandas as pd
34
31
  import xarray as xr
35
32
 
@@ -202,346 +199,124 @@ class XrDatasetDataLoader(InputDataLoader):
202
199
  if (constants.GEO) not in self.dataset.sizes.keys():
203
200
  self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
204
201
 
205
- if len(self.dataset.coords[constants.GEO]) == 1:
206
- if constants.POPULATION in self.dataset.data_vars.keys():
207
- warnings.warn(
208
- 'The `population` argument is ignored in a nationally aggregated'
209
- ' model. It will be reset to [1]'
210
- )
211
- self.dataset = self.dataset.drop_vars(names=[constants.POPULATION])
212
-
213
- # Add a default `population` [1].
214
- national_population_darray = xr.DataArray(
215
- [constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE],
216
- dims=[constants.GEO],
217
- coords={
218
- constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME],
219
- },
220
- name=constants.POPULATION,
221
- )
222
- self.dataset = xr.combine_by_coords(
223
- [
224
- national_population_darray,
225
- self.dataset.assign_coords(
226
- {constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]}
227
- ),
228
- ],
229
- compat='override',
230
- )
231
-
232
202
  if constants.MEDIA_TIME not in self.dataset.sizes.keys():
233
- self._add_media_time()
234
- self._normalize_time_coordinates(constants.TIME)
235
- self._normalize_time_coordinates(constants.MEDIA_TIME)
236
- self._validate_dataset()
237
-
238
- def _normalize_time_coordinates(self, dim: str):
239
- if self.dataset.coords.dtypes[dim] == np.dtype('datetime64[ns]'):
240
- date_strvalues = np.datetime_as_string(self.dataset.coords[dim], unit='D')
241
- self.dataset = self.dataset.assign_coords({dim: date_strvalues})
242
-
243
- # Assume that the time coordinate labels are date-formatted strings.
244
- # We don't currently support other, arbitrary object types in the loaders.
245
- for time in self.dataset.coords[dim].values:
246
- try:
247
- _ = dt.datetime.strptime(time, constants.DATE_FORMAT)
248
- except ValueError as exc:
249
- raise ValueError(
250
- f"Invalid time label: '{time}'. Expected format:"
251
- f" '{constants.DATE_FORMAT}'"
252
- ) from exc
203
+ na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
253
204
 
254
- def _validate_dataset(self):
255
- for coord_name in constants.REQUIRED_INPUT_DATA_COORD_NAMES:
256
- if coord_name not in self.dataset.coords:
257
- raise ValueError(
258
- f"Coordinate '{coord_name}' not found in dataset's coordinates."
259
- " Please use the 'name_mapping' argument to rename the coordinates."
205
+ if constants.CONTROLS in self.dataset.data_vars.keys():
206
+ na_mask |= (
207
+ self.dataset[constants.CONTROLS]
208
+ .isnull()
209
+ .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
260
210
  )
261
211
 
262
- for array_name in constants.REQUIRED_INPUT_DATA_ARRAY_NAMES:
263
- if array_name not in self.dataset.data_vars:
264
- raise ValueError(
265
- f"Array '{array_name}' not found in dataset's arrays."
266
- " Please use the 'name_mapping' argument to rename the arrays."
212
+ if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
213
+ na_mask |= (
214
+ self.dataset[constants.NON_MEDIA_TREATMENTS]
215
+ .isnull()
216
+ .any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
267
217
  )
268
218
 
269
- # Check for media.
270
- missing_media_input = []
271
- for coord_name in constants.MEDIA_INPUT_DATA_COORD_NAMES:
272
- if coord_name not in self.dataset.coords:
273
- missing_media_input.append(coord_name)
274
- for array_name in constants.MEDIA_INPUT_DATA_ARRAY_NAMES:
275
- if array_name not in self.dataset.data_vars:
276
- missing_media_input.append(array_name)
277
-
278
- # Check for RF.
279
- missing_rf_input = []
280
- for coord_name in constants.RF_INPUT_DATA_COORD_NAMES:
281
- if coord_name not in self.dataset.coords:
282
- missing_rf_input.append(coord_name)
283
- for array_name in constants.RF_INPUT_DATA_ARRAY_NAMES:
284
- if array_name not in self.dataset.data_vars:
285
- missing_rf_input.append(array_name)
286
-
287
- if missing_media_input and missing_rf_input:
288
- raise ValueError(
289
- "Some required data is missing. Please use the 'name_mapping'"
290
- ' argument to rename the coordinates/arrays. It is required to have'
291
- ' at least one of media or reach and frequency.'
292
- )
293
-
294
- if missing_media_input and len(missing_media_input) != len(
295
- constants.MEDIA_INPUT_DATA_COORD_NAMES
296
- ) + len(constants.MEDIA_INPUT_DATA_ARRAY_NAMES):
297
- raise ValueError(
298
- f"Media data is partially missing. '{missing_media_input}' not found"
299
- " in dataset's coordinates/arrays. Please use the 'name_mapping'"
300
- ' argument to rename the coordinates/arrays.'
301
- )
302
-
303
- if missing_rf_input and len(missing_rf_input) != len(
304
- constants.RF_INPUT_DATA_COORD_NAMES
305
- ) + len(constants.RF_INPUT_DATA_ARRAY_NAMES):
306
- raise ValueError(
307
- f"RF data is partially missing. '{missing_rf_input}' not found in"
308
- " dataset's coordinates/arrays. Please use the 'name_mapping'"
309
- ' argument to rename the coordinates/arrays.'
310
- )
311
-
312
- def _add_media_time(self):
313
- """Creates the `media_time` coordinate if it is not provided directly.
314
-
315
- The user can either create both `time` and `media_time` coordinates directly
316
- and use them to provide the lagged data for `media`, `reach` and `frequency`
317
- arrays, or use the `time` coordinate for all arrays. In the second case,
318
- the lagged period will be determined and the `media_time` and `time`
319
- coordinates will be created based on the missing values in the other arrays:
320
- `kpi`, `revenue_per_kpi`, `controls`, `media_spend`, `rf_spend`. The
321
- analogous mechanism to determine the lagged period is used in
322
- `DataFrameDataLoader` and `CsvDataLoader`.
323
- """
324
- # Check if there are no NAs in media.
325
- if constants.MEDIA in self.dataset.data_vars.keys():
326
- if self.dataset.media.isnull().any(axis=None):
327
- raise ValueError('NA values found in the media array.')
328
-
329
- # Check if there are no NAs in reach & frequency.
330
- if constants.REACH in self.dataset.data_vars.keys():
331
- if self.dataset.reach.isnull().any(axis=None):
332
- raise ValueError('NA values found in the reach array.')
333
- if constants.FREQUENCY in self.dataset.data_vars.keys():
334
- if self.dataset.frequency.isnull().any(axis=None):
335
- raise ValueError('NA values found in the frequency array.')
336
-
337
- # Check if there are no NAs in organic media.
338
- if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
339
- if self.dataset.organic_media.isnull().any(axis=None):
340
- raise ValueError('NA values found in the organic media array.')
341
-
342
- # Check if there are no NAs in organic reach & frequency.
343
- if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
344
- if self.dataset.organic_reach.isnull().any(axis=None):
345
- raise ValueError('NA values found in the organic reach array.')
346
- if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
347
- if self.dataset.organic_frequency.isnull().any(axis=None):
348
- raise ValueError('NA values found in the organic frequency array.')
349
-
350
- # Arrays in which NAs are expected in the lagged-media period.
351
- na_arrays = [
352
- constants.KPI,
353
- ]
354
-
355
- na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
356
-
357
- if constants.CONTROLS in self.dataset.data_vars.keys():
358
- na_arrays.append(constants.CONTROLS)
359
- na_mask |= (
360
- self.dataset[constants.CONTROLS]
361
- .isnull()
362
- .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
363
- )
364
-
365
- if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
366
- na_arrays.append(constants.NON_MEDIA_TREATMENTS)
367
- na_mask |= (
368
- self.dataset[constants.NON_MEDIA_TREATMENTS]
369
- .isnull()
370
- .any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
371
- )
372
-
373
- if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
374
- na_arrays.append(constants.REVENUE_PER_KPI)
375
- na_mask |= (
376
- self.dataset[constants.REVENUE_PER_KPI]
377
- .isnull()
378
- .any(dim=constants.GEO)
379
- )
380
- if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
381
- na_arrays.append(constants.MEDIA_SPEND)
382
- na_mask |= (
383
- self.dataset[constants.MEDIA_SPEND]
384
- .isnull()
385
- .any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
386
- )
387
- if constants.RF_SPEND in self.dataset.data_vars.keys():
388
- na_arrays.append(constants.RF_SPEND)
389
- na_mask |= (
390
- self.dataset[constants.RF_SPEND]
391
- .isnull()
392
- .any(dim=[constants.GEO, constants.RF_CHANNEL])
393
- )
394
-
395
- # Dates with at least one non-NA value in non-media columns
396
- no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
397
-
398
- # Dates with 100% NA values in all non-media columns.
399
- na_period = self.dataset[constants.TIME].isel(time=na_mask).values
400
-
401
- # Check if na_period is a continuous window starting from the earliest time
402
- # period.
403
- if not np.all(
404
- np.sort(na_period)
405
- == np.sort(np.unique(self.dataset[constants.TIME]))[: len(na_period)]
406
- ):
407
- raise ValueError(
408
- "The 'lagged media' period (period with 100% NA values in all"
409
- f' non-media columns) {na_period} is not a continuous window starting'
410
- ' from the earliest time period.'
411
- )
412
-
413
- # Check if for the non-lagged period, there are no NAs in non-media data
414
- for array in na_arrays:
415
- if np.any(np.isnan(self.dataset[array].isel(time=~na_mask))):
416
- raise ValueError(
417
- 'NA values found in other than media columns outside the'
418
- f' lagged-media period {na_period} (continuous window of 100% NA'
419
- ' values in all other than media columns).'
219
+ if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
220
+ na_mask |= (
221
+ self.dataset[constants.REVENUE_PER_KPI]
222
+ .isnull()
223
+ .any(dim=constants.GEO)
224
+ )
225
+ if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
226
+ na_mask |= (
227
+ self.dataset[constants.MEDIA_SPEND]
228
+ .isnull()
229
+ .any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
230
+ )
231
+ if constants.RF_SPEND in self.dataset.data_vars.keys():
232
+ na_mask |= (
233
+ self.dataset[constants.RF_SPEND]
234
+ .isnull()
235
+ .any(dim=[constants.GEO, constants.RF_CHANNEL])
420
236
  )
421
237
 
422
- # Create new `time` and `media_time` coordinates.
423
- new_time = 'new_time'
238
+ # Dates with at least one non-NA value in non-media columns
239
+ no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
424
240
 
425
- new_dataset = self.dataset.assign_coords(
426
- new_time=(new_time, no_na_period),
427
- )
241
+ # Create new `time` and `media_time` coordinates.
242
+ new_time = 'new_time'
428
243
 
429
- new_dataset[constants.KPI] = (
430
- new_dataset[constants.KPI]
431
- .dropna(dim=constants.TIME)
432
- .rename({constants.TIME: new_time})
433
- )
434
- if constants.CONTROLS in new_dataset.data_vars.keys():
435
- new_dataset[constants.CONTROLS] = (
436
- new_dataset[constants.CONTROLS]
437
- .dropna(dim=constants.TIME)
438
- .rename({constants.TIME: new_time})
439
- )
440
- if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
441
- new_dataset[constants.NON_MEDIA_TREATMENTS] = (
442
- new_dataset[constants.NON_MEDIA_TREATMENTS]
443
- .dropna(dim=constants.TIME)
444
- .rename({constants.TIME: new_time})
244
+ new_dataset = self.dataset.assign_coords(
245
+ new_time=(new_time, no_na_period),
445
246
  )
446
247
 
447
- if constants.REVENUE_PER_KPI in new_dataset.data_vars.keys():
448
- new_dataset[constants.REVENUE_PER_KPI] = (
449
- new_dataset[constants.REVENUE_PER_KPI]
248
+ new_dataset[constants.KPI] = (
249
+ new_dataset[constants.KPI]
450
250
  .dropna(dim=constants.TIME)
451
251
  .rename({constants.TIME: new_time})
452
252
  )
253
+ if constants.CONTROLS in new_dataset.data_vars.keys():
254
+ new_dataset[constants.CONTROLS] = (
255
+ new_dataset[constants.CONTROLS]
256
+ .dropna(dim=constants.TIME)
257
+ .rename({constants.TIME: new_time})
258
+ )
259
+ if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
260
+ new_dataset[constants.NON_MEDIA_TREATMENTS] = (
261
+ new_dataset[constants.NON_MEDIA_TREATMENTS]
262
+ .dropna(dim=constants.TIME)
263
+ .rename({constants.TIME: new_time})
264
+ )
453
265
 
454
- if constants.MEDIA_SPEND in new_dataset.data_vars.keys():
455
- new_dataset[constants.MEDIA_SPEND] = (
456
- new_dataset[constants.MEDIA_SPEND]
457
- .dropna(dim=constants.TIME)
458
- .rename({constants.TIME: new_time})
459
- )
266
+ if constants.REVENUE_PER_KPI in new_dataset.data_vars.keys():
267
+ new_dataset[constants.REVENUE_PER_KPI] = (
268
+ new_dataset[constants.REVENUE_PER_KPI]
269
+ .dropna(dim=constants.TIME)
270
+ .rename({constants.TIME: new_time})
271
+ )
460
272
 
461
- if constants.RF_SPEND in new_dataset.data_vars.keys():
462
- new_dataset[constants.RF_SPEND] = (
463
- new_dataset[constants.RF_SPEND]
464
- .dropna(dim=constants.TIME)
465
- .rename({constants.TIME: new_time})
466
- )
273
+ if constants.MEDIA_SPEND in new_dataset.data_vars.keys():
274
+ new_dataset[constants.MEDIA_SPEND] = (
275
+ new_dataset[constants.MEDIA_SPEND]
276
+ .dropna(dim=constants.TIME)
277
+ .rename({constants.TIME: new_time})
278
+ )
467
279
 
468
- self.dataset = new_dataset.rename(
469
- {constants.TIME: constants.MEDIA_TIME, new_time: constants.TIME}
470
- )
280
+ if constants.RF_SPEND in new_dataset.data_vars.keys():
281
+ new_dataset[constants.RF_SPEND] = (
282
+ new_dataset[constants.RF_SPEND]
283
+ .dropna(dim=constants.TIME)
284
+ .rename({constants.TIME: new_time})
285
+ )
286
+
287
+ self.dataset = new_dataset.rename(
288
+ {constants.TIME: constants.MEDIA_TIME, new_time: constants.TIME}
289
+ )
471
290
 
472
291
  def load(self) -> input_data.InputData:
473
292
  """Returns an `InputData` object containing the data from the dataset."""
474
- controls = (
475
- self.dataset.controls
476
- if constants.CONTROLS in self.dataset.data_vars.keys()
477
- else None
478
- )
479
- revenue_per_kpi = (
480
- self.dataset.revenue_per_kpi
481
- if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
482
- else None
483
- )
484
- media = (
485
- self.dataset.media
486
- if constants.MEDIA in self.dataset.data_vars.keys()
487
- else None
488
- )
489
- media_spend = (
490
- self.dataset.media_spend
491
- if constants.MEDIA in self.dataset.data_vars.keys()
492
- else None
493
- )
494
- reach = (
495
- self.dataset.reach
496
- if constants.REACH in self.dataset.data_vars.keys()
497
- else None
498
- )
499
- frequency = (
500
- self.dataset.frequency
501
- if constants.FREQUENCY in self.dataset.data_vars.keys()
502
- else None
503
- )
504
- rf_spend = (
505
- self.dataset.rf_spend
506
- if constants.RF_SPEND in self.dataset.data_vars.keys()
507
- else None
508
- )
509
- non_media_treatments = (
510
- self.dataset.non_media_treatments
511
- if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys()
512
- else None
513
- )
514
- organic_media = (
515
- self.dataset.organic_media
516
- if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys()
517
- else None
518
- )
519
- organic_reach = (
520
- self.dataset.organic_reach
521
- if constants.ORGANIC_REACH in self.dataset.data_vars.keys()
522
- else None
523
- )
524
- organic_frequency = (
525
- self.dataset.organic_frequency
526
- if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys()
527
- else None
528
- )
529
- return input_data.InputData(
530
- kpi=self.dataset.kpi,
531
- kpi_type=self.kpi_type,
532
- population=self.dataset.population,
533
- controls=controls,
534
- revenue_per_kpi=revenue_per_kpi,
535
- media=media,
536
- media_spend=media_spend,
537
- reach=reach,
538
- frequency=frequency,
539
- rf_spend=rf_spend,
540
- non_media_treatments=non_media_treatments,
541
- organic_media=organic_media,
542
- organic_reach=organic_reach,
543
- organic_frequency=organic_frequency,
544
- )
293
+ builder = input_data_builder.InputDataBuilder(self.kpi_type)
294
+ builder.kpi = self.dataset.kpi
295
+ if constants.POPULATION in self.dataset.data_vars.keys():
296
+ builder.population = self.dataset.population
297
+ if constants.CONTROLS in self.dataset.data_vars.keys():
298
+ builder.controls = self.dataset.controls
299
+ if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
300
+ builder.revenue_per_kpi = self.dataset.revenue_per_kpi
301
+ if constants.MEDIA in self.dataset.data_vars.keys():
302
+ builder.media = self.dataset.media
303
+ if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
304
+ builder.media_spend = self.dataset.media_spend
305
+ if constants.REACH in self.dataset.data_vars.keys():
306
+ builder.reach = self.dataset.reach
307
+ if constants.FREQUENCY in self.dataset.data_vars.keys():
308
+ builder.frequency = self.dataset.frequency
309
+ if constants.RF_SPEND in self.dataset.data_vars.keys():
310
+ builder.rf_spend = self.dataset.rf_spend
311
+ if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
312
+ builder.non_media_treatments = self.dataset.non_media_treatments
313
+ if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
314
+ builder.organic_media = self.dataset.organic_media
315
+ if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
316
+ builder.organic_reach = self.dataset.organic_reach
317
+ if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
318
+ builder.organic_frequency = self.dataset.organic_frequency
319
+ return builder.build()
545
320
 
546
321
 
547
322
  @dataclasses.dataclass(frozen=True)
@@ -607,6 +382,9 @@ class CoordToColumns:
607
382
  ' both.'
608
383
  )
609
384
 
385
+ if self.revenue_per_kpi is not None and not self.revenue_per_kpi.strip():
386
+ raise ValueError('`revenue_per_kpi` should not be empty if provided.')
387
+
610
388
 
611
389
  @dataclasses.dataclass
612
390
  class DataFrameDataLoader(InputDataLoader):
@@ -830,6 +608,7 @@ class DataFrameDataLoader(InputDataLoader):
830
608
  f'The {channel_dict} keys must have the same set of values as'
831
609
  f' the {coord_name} columns.'
832
610
  )
611
+
833
612
  if (
834
613
  self.media_to_channel is not None
835
614
  and self.media_spend_to_channel is not None
@@ -841,6 +620,27 @@ class DataFrameDataLoader(InputDataLoader):
841
620
  'The media and media_spend columns must have the same set of'
842
621
  ' channels.'
843
622
  )
623
+
624
+ # The columns listed in `media` and `media_spend` must correspond to the
625
+ # same channels, in user-given order!
626
+ # For example, this is invalid:
627
+ # media = ['impressions_tv', 'impressions_yt']
628
+ # media_spend = ['spend_yt', 'spend_tv']
629
+ # But we can only detect this after we map each `media` and `media_spend`
630
+ # column to its canonical channel name.
631
+ media_channels = [
632
+ self.media_to_channel[c] for c in self.coord_to_columns.media
633
+ ]
634
+ media_spend_channels = [
635
+ self.media_spend_to_channel[c]
636
+ for c in self.coord_to_columns.media_spend
637
+ ]
638
+ if media_channels != media_spend_channels:
639
+ raise ValueError(
640
+ 'The `media` and `media_spend` columns must correspond to the same'
641
+ ' channels, in user order.'
642
+ )
643
+
844
644
  if (
845
645
  self.reach_to_channel is not None
846
646
  and self.frequency_to_channel is not None
@@ -855,6 +655,23 @@ class DataFrameDataLoader(InputDataLoader):
855
655
  'The reach, frequency, and rf_spend columns must have the same set'
856
656
  ' of channels.'
857
657
  )
658
+
659
+ # Same channel ordering concerns as for `media` and `media_spend`.
660
+ reach_channels = [
661
+ self.reach_to_channel[c] for c in self.coord_to_columns.reach
662
+ ]
663
+ frequency_channels = [
664
+ self.frequency_to_channel[c] for c in self.coord_to_columns.frequency
665
+ ]
666
+ rf_spend_channels = [
667
+ self.rf_spend_to_channel[c] for c in self.coord_to_columns.rf_spend
668
+ ]
669
+ if not (reach_channels == frequency_channels == rf_spend_channels):
670
+ raise ValueError(
671
+ 'The `reach`, `frequency`, and `rf_spend` columns must correspond'
672
+ ' to the same channels, in user order.'
673
+ )
674
+
858
675
  if (
859
676
  self.organic_reach_to_channel is not None
860
677
  and self.organic_frequency_to_channel is not None
@@ -867,6 +684,21 @@ class DataFrameDataLoader(InputDataLoader):
867
684
  ' same set of channels.'
868
685
  )
869
686
 
687
+ # Same channel ordering concerns as for `media` and `media_spend`.
688
+ organic_reach_channels = [
689
+ self.organic_reach_to_channel[c]
690
+ for c in self.coord_to_columns.organic_reach
691
+ ]
692
+ organic_frequency_channels = [
693
+ self.organic_frequency_to_channel[c]
694
+ for c in self.coord_to_columns.organic_frequency
695
+ ]
696
+ if organic_reach_channels != organic_frequency_channels:
697
+ raise ValueError(
698
+ 'The `organic_reach` and `organic_frequency` columns must'
699
+ ' correspond to the same channels, in user order.'
700
+ )
701
+
870
702
  def load(self) -> input_data.InputData:
871
703
  """Reads data from a dataframe and returns an InputData object."""
872
704
 
@@ -878,66 +710,86 @@ class DataFrameDataLoader(InputDataLoader):
878
710
  self.coord_to_columns.time,
879
711
  self.coord_to_columns.geo,
880
712
  )
713
+
881
714
  if self.coord_to_columns.population in self.df.columns:
882
715
  builder.with_population(
883
716
  self.df, self.coord_to_columns.population, self.coord_to_columns.geo
884
717
  )
885
- if self.coord_to_columns.controls is not None:
718
+
719
+ if self.coord_to_columns.controls:
886
720
  builder.with_controls(
887
721
  self.df,
888
722
  list(self.coord_to_columns.controls),
889
723
  self.coord_to_columns.time,
890
724
  self.coord_to_columns.geo,
891
725
  )
892
- if self.coord_to_columns.non_media_treatments is not None:
726
+
727
+ if self.coord_to_columns.non_media_treatments:
893
728
  builder.with_non_media_treatments(
894
729
  self.df,
895
730
  list(self.coord_to_columns.non_media_treatments),
896
731
  self.coord_to_columns.time,
897
732
  self.coord_to_columns.geo,
898
733
  )
899
- if self.coord_to_columns.revenue_per_kpi is not None:
734
+
735
+ if self.coord_to_columns.revenue_per_kpi:
900
736
  builder.with_revenue_per_kpi(
901
737
  self.df,
902
738
  self.coord_to_columns.revenue_per_kpi,
903
739
  self.coord_to_columns.time,
904
740
  self.coord_to_columns.geo,
905
741
  )
742
+
906
743
  if (
907
744
  self.media_to_channel is not None
908
745
  and self.media_spend_to_channel is not None
909
746
  ):
910
- sorted_channels = sorted(self.media_to_channel.values())
911
- inv_media_map = {v: k for k, v in self.media_to_channel.items()}
912
- inv_spend_map = {v: k for k, v in self.media_spend_to_channel.items()}
913
-
747
+ # Based on the invariant rule enforced in `__post_init__`, the columns
748
+ # listed in `media` and `media_spend` are already validated to correspond
749
+ # to the same channels, in user-given order.
750
+ media_execution_columns = list(self.coord_to_columns.media)
751
+ media_spend_columns = list(self.coord_to_columns.media_spend)
752
+ # So now we can use one of the channel mapper dicts to get the canonical
753
+ # channel names for each column.
754
+ media_channel_names = [
755
+ self.media_to_channel[c] for c in self.coord_to_columns.media
756
+ ]
914
757
  builder.with_media(
915
758
  self.df,
916
- [inv_media_map[ch] for ch in sorted_channels],
917
- [inv_spend_map[ch] for ch in sorted_channels],
918
- sorted_channels,
759
+ media_execution_columns,
760
+ media_spend_columns,
761
+ media_channel_names,
919
762
  self.coord_to_columns.time,
920
763
  self.coord_to_columns.geo,
921
764
  )
765
+
922
766
  if (
923
767
  self.reach_to_channel is not None
924
768
  and self.frequency_to_channel is not None
925
769
  and self.rf_spend_to_channel is not None
926
770
  ):
927
- sorted_channels = sorted(self.reach_to_channel.values())
928
- inv_reach_map = {v: k for k, v in self.reach_to_channel.items()}
929
- inv_freq_map = {v: k for k, v in self.frequency_to_channel.items()}
930
- inv_rf_spend_map = {v: k for k, v in self.rf_spend_to_channel.items()}
771
+ # Based on the invariant rule enforced in `__post_init__`, the columns
772
+ # listed in `reach`, `frequency`, and `rf_spend` are already validated
773
+ # to correspond to the same channels, in user-given order.
774
+ reach_columns = list(self.coord_to_columns.reach)
775
+ frequency_columns = list(self.coord_to_columns.frequency)
776
+ rf_spend_columns = list(self.coord_to_columns.rf_spend)
777
+ # So now we can use one of the channel mapper dicts to get the canonical
778
+ # channel names for each column.
779
+ rf_channel_names = [
780
+ self.reach_to_channel[c] for c in self.coord_to_columns.reach
781
+ ]
931
782
  builder.with_reach(
932
783
  self.df,
933
- [inv_reach_map[ch] for ch in sorted_channels],
934
- [inv_freq_map[ch] for ch in sorted_channels],
935
- [inv_rf_spend_map[ch] for ch in sorted_channels],
936
- sorted_channels,
784
+ reach_columns,
785
+ frequency_columns,
786
+ rf_spend_columns,
787
+ rf_channel_names,
937
788
  self.coord_to_columns.time,
938
789
  self.coord_to_columns.geo,
939
790
  )
940
- if self.coord_to_columns.organic_media is not None:
791
+
792
+ if self.coord_to_columns.organic_media:
941
793
  builder.with_organic_media(
942
794
  self.df,
943
795
  list(self.coord_to_columns.organic_media),
@@ -945,23 +797,31 @@ class DataFrameDataLoader(InputDataLoader):
945
797
  self.coord_to_columns.time,
946
798
  self.coord_to_columns.geo,
947
799
  )
800
+
948
801
  if (
949
802
  self.organic_reach_to_channel is not None
950
803
  and self.organic_frequency_to_channel is not None
951
804
  ):
952
- sorted_channels = sorted(self.organic_reach_to_channel.values())
953
- inv_reach_map = {v: k for k, v in self.organic_reach_to_channel.items()}
954
- inv_freq_map = {
955
- v: k for k, v in self.organic_frequency_to_channel.items()
956
- }
805
+ # Based on the invariant rule enforced in `__post_init__`, the columns
806
+ # listed in `organic_reach` and `organic_frequency` are already
807
+ # validated to correspond to the same channels, in user-given order.
808
+ organic_reach_columns = list(self.coord_to_columns.organic_reach)
809
+ organic_frequency_columns = list(self.coord_to_columns.organic_frequency)
810
+ # So now we can use one of the channel mapper dicts to get the canonical
811
+ # channel names for each column.
812
+ organic_rf_channel_names = [
813
+ self.organic_reach_to_channel[c]
814
+ for c in self.coord_to_columns.organic_reach
815
+ ]
957
816
  builder.with_organic_reach(
958
817
  self.df,
959
- [inv_reach_map[ch] for ch in sorted_channels],
960
- [inv_freq_map[ch] for ch in sorted_channels],
961
- sorted_channels,
818
+ organic_reach_columns,
819
+ organic_frequency_columns,
820
+ organic_rf_channel_names,
962
821
  self.coord_to_columns.time,
963
822
  self.coord_to_columns.geo,
964
823
  )
824
+
965
825
  return builder.build()
966
826
 
967
827