google-meridian 1.1.2__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/data/load.py CHANGED
@@ -22,14 +22,11 @@ object.
22
22
  import abc
23
23
  from collections.abc import Mapping, Sequence
24
24
  import dataclasses
25
- import datetime as dt
26
- import warnings
27
-
28
25
  import immutabledict
29
26
  from meridian import constants
30
27
  from meridian.data import data_frame_input_data_builder
31
28
  from meridian.data import input_data
32
- import numpy as np
29
+ from meridian.data import input_data_builder
33
30
  import pandas as pd
34
31
  import xarray as xr
35
32
 
@@ -202,346 +199,124 @@ class XrDatasetDataLoader(InputDataLoader):
202
199
  if (constants.GEO) not in self.dataset.sizes.keys():
203
200
  self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
204
201
 
205
- if len(self.dataset.coords[constants.GEO]) == 1:
206
- if constants.POPULATION in self.dataset.data_vars.keys():
207
- warnings.warn(
208
- 'The `population` argument is ignored in a nationally aggregated'
209
- ' model. It will be reset to [1]'
210
- )
211
- self.dataset = self.dataset.drop_vars(names=[constants.POPULATION])
212
-
213
- # Add a default `population` [1].
214
- national_population_darray = xr.DataArray(
215
- [constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE],
216
- dims=[constants.GEO],
217
- coords={
218
- constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME],
219
- },
220
- name=constants.POPULATION,
221
- )
222
- self.dataset = xr.combine_by_coords(
223
- [
224
- national_population_darray,
225
- self.dataset.assign_coords(
226
- {constants.GEO: [constants.NATIONAL_MODEL_DEFAULT_GEO_NAME]}
227
- ),
228
- ],
229
- compat='override',
230
- )
231
-
232
202
  if constants.MEDIA_TIME not in self.dataset.sizes.keys():
233
- self._add_media_time()
234
- self._normalize_time_coordinates(constants.TIME)
235
- self._normalize_time_coordinates(constants.MEDIA_TIME)
236
- self._validate_dataset()
237
-
238
- def _normalize_time_coordinates(self, dim: str):
239
- if self.dataset.coords.dtypes[dim] == np.dtype('datetime64[ns]'):
240
- date_strvalues = np.datetime_as_string(self.dataset.coords[dim], unit='D')
241
- self.dataset = self.dataset.assign_coords({dim: date_strvalues})
242
-
243
- # Assume that the time coordinate labels are date-formatted strings.
244
- # We don't currently support other, arbitrary object types in the loaders.
245
- for time in self.dataset.coords[dim].values:
246
- try:
247
- _ = dt.datetime.strptime(time, constants.DATE_FORMAT)
248
- except ValueError as exc:
249
- raise ValueError(
250
- f"Invalid time label: '{time}'. Expected format:"
251
- f" '{constants.DATE_FORMAT}'"
252
- ) from exc
203
+ na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
253
204
 
254
- def _validate_dataset(self):
255
- for coord_name in constants.REQUIRED_INPUT_DATA_COORD_NAMES:
256
- if coord_name not in self.dataset.coords:
257
- raise ValueError(
258
- f"Coordinate '{coord_name}' not found in dataset's coordinates."
259
- " Please use the 'name_mapping' argument to rename the coordinates."
205
+ if constants.CONTROLS in self.dataset.data_vars.keys():
206
+ na_mask |= (
207
+ self.dataset[constants.CONTROLS]
208
+ .isnull()
209
+ .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
260
210
  )
261
211
 
262
- for array_name in constants.REQUIRED_INPUT_DATA_ARRAY_NAMES:
263
- if array_name not in self.dataset.data_vars:
264
- raise ValueError(
265
- f"Array '{array_name}' not found in dataset's arrays."
266
- " Please use the 'name_mapping' argument to rename the arrays."
212
+ if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
213
+ na_mask |= (
214
+ self.dataset[constants.NON_MEDIA_TREATMENTS]
215
+ .isnull()
216
+ .any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
267
217
  )
268
218
 
269
- # Check for media.
270
- missing_media_input = []
271
- for coord_name in constants.MEDIA_INPUT_DATA_COORD_NAMES:
272
- if coord_name not in self.dataset.coords:
273
- missing_media_input.append(coord_name)
274
- for array_name in constants.MEDIA_INPUT_DATA_ARRAY_NAMES:
275
- if array_name not in self.dataset.data_vars:
276
- missing_media_input.append(array_name)
277
-
278
- # Check for RF.
279
- missing_rf_input = []
280
- for coord_name in constants.RF_INPUT_DATA_COORD_NAMES:
281
- if coord_name not in self.dataset.coords:
282
- missing_rf_input.append(coord_name)
283
- for array_name in constants.RF_INPUT_DATA_ARRAY_NAMES:
284
- if array_name not in self.dataset.data_vars:
285
- missing_rf_input.append(array_name)
286
-
287
- if missing_media_input and missing_rf_input:
288
- raise ValueError(
289
- "Some required data is missing. Please use the 'name_mapping'"
290
- ' argument to rename the coordinates/arrays. It is required to have'
291
- ' at least one of media or reach and frequency.'
292
- )
293
-
294
- if missing_media_input and len(missing_media_input) != len(
295
- constants.MEDIA_INPUT_DATA_COORD_NAMES
296
- ) + len(constants.MEDIA_INPUT_DATA_ARRAY_NAMES):
297
- raise ValueError(
298
- f"Media data is partially missing. '{missing_media_input}' not found"
299
- " in dataset's coordinates/arrays. Please use the 'name_mapping'"
300
- ' argument to rename the coordinates/arrays.'
301
- )
302
-
303
- if missing_rf_input and len(missing_rf_input) != len(
304
- constants.RF_INPUT_DATA_COORD_NAMES
305
- ) + len(constants.RF_INPUT_DATA_ARRAY_NAMES):
306
- raise ValueError(
307
- f"RF data is partially missing. '{missing_rf_input}' not found in"
308
- " dataset's coordinates/arrays. Please use the 'name_mapping'"
309
- ' argument to rename the coordinates/arrays.'
310
- )
311
-
312
- def _add_media_time(self):
313
- """Creates the `media_time` coordinate if it is not provided directly.
314
-
315
- The user can either create both `time` and `media_time` coordinates directly
316
- and use them to provide the lagged data for `media`, `reach` and `frequency`
317
- arrays, or use the `time` coordinate for all arrays. In the second case,
318
- the lagged period will be determined and the `media_time` and `time`
319
- coordinates will be created based on the missing values in the other arrays:
320
- `kpi`, `revenue_per_kpi`, `controls`, `media_spend`, `rf_spend`. The
321
- analogous mechanism to determine the lagged period is used in
322
- `DataFrameDataLoader` and `CsvDataLoader`.
323
- """
324
- # Check if there are no NAs in media.
325
- if constants.MEDIA in self.dataset.data_vars.keys():
326
- if self.dataset.media.isnull().any(axis=None):
327
- raise ValueError('NA values found in the media array.')
328
-
329
- # Check if there are no NAs in reach & frequency.
330
- if constants.REACH in self.dataset.data_vars.keys():
331
- if self.dataset.reach.isnull().any(axis=None):
332
- raise ValueError('NA values found in the reach array.')
333
- if constants.FREQUENCY in self.dataset.data_vars.keys():
334
- if self.dataset.frequency.isnull().any(axis=None):
335
- raise ValueError('NA values found in the frequency array.')
336
-
337
- # Check if there are no NAs in organic media.
338
- if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
339
- if self.dataset.organic_media.isnull().any(axis=None):
340
- raise ValueError('NA values found in the organic media array.')
341
-
342
- # Check if there are no NAs in organic reach & frequency.
343
- if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
344
- if self.dataset.organic_reach.isnull().any(axis=None):
345
- raise ValueError('NA values found in the organic reach array.')
346
- if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
347
- if self.dataset.organic_frequency.isnull().any(axis=None):
348
- raise ValueError('NA values found in the organic frequency array.')
349
-
350
- # Arrays in which NAs are expected in the lagged-media period.
351
- na_arrays = [
352
- constants.KPI,
353
- ]
354
-
355
- na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
356
-
357
- if constants.CONTROLS in self.dataset.data_vars.keys():
358
- na_arrays.append(constants.CONTROLS)
359
- na_mask |= (
360
- self.dataset[constants.CONTROLS]
361
- .isnull()
362
- .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
363
- )
364
-
365
- if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
366
- na_arrays.append(constants.NON_MEDIA_TREATMENTS)
367
- na_mask |= (
368
- self.dataset[constants.NON_MEDIA_TREATMENTS]
369
- .isnull()
370
- .any(dim=[constants.GEO, constants.NON_MEDIA_CHANNEL])
371
- )
372
-
373
- if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
374
- na_arrays.append(constants.REVENUE_PER_KPI)
375
- na_mask |= (
376
- self.dataset[constants.REVENUE_PER_KPI]
377
- .isnull()
378
- .any(dim=constants.GEO)
379
- )
380
- if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
381
- na_arrays.append(constants.MEDIA_SPEND)
382
- na_mask |= (
383
- self.dataset[constants.MEDIA_SPEND]
384
- .isnull()
385
- .any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
386
- )
387
- if constants.RF_SPEND in self.dataset.data_vars.keys():
388
- na_arrays.append(constants.RF_SPEND)
389
- na_mask |= (
390
- self.dataset[constants.RF_SPEND]
391
- .isnull()
392
- .any(dim=[constants.GEO, constants.RF_CHANNEL])
393
- )
394
-
395
- # Dates with at least one non-NA value in non-media columns
396
- no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
397
-
398
- # Dates with 100% NA values in all non-media columns.
399
- na_period = self.dataset[constants.TIME].isel(time=na_mask).values
400
-
401
- # Check if na_period is a continuous window starting from the earliest time
402
- # period.
403
- if not np.all(
404
- np.sort(na_period)
405
- == np.sort(np.unique(self.dataset[constants.TIME]))[: len(na_period)]
406
- ):
407
- raise ValueError(
408
- "The 'lagged media' period (period with 100% NA values in all"
409
- f' non-media columns) {na_period} is not a continuous window starting'
410
- ' from the earliest time period.'
411
- )
412
-
413
- # Check if for the non-lagged period, there are no NAs in non-media data
414
- for array in na_arrays:
415
- if np.any(np.isnan(self.dataset[array].isel(time=~na_mask))):
416
- raise ValueError(
417
- 'NA values found in other than media columns outside the'
418
- f' lagged-media period {na_period} (continuous window of 100% NA'
419
- ' values in all other than media columns).'
219
+ if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
220
+ na_mask |= (
221
+ self.dataset[constants.REVENUE_PER_KPI]
222
+ .isnull()
223
+ .any(dim=constants.GEO)
224
+ )
225
+ if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
226
+ na_mask |= (
227
+ self.dataset[constants.MEDIA_SPEND]
228
+ .isnull()
229
+ .any(dim=[constants.GEO, constants.MEDIA_CHANNEL])
230
+ )
231
+ if constants.RF_SPEND in self.dataset.data_vars.keys():
232
+ na_mask |= (
233
+ self.dataset[constants.RF_SPEND]
234
+ .isnull()
235
+ .any(dim=[constants.GEO, constants.RF_CHANNEL])
420
236
  )
421
237
 
422
- # Create new `time` and `media_time` coordinates.
423
- new_time = 'new_time'
238
+ # Dates with at least one non-NA value in non-media columns
239
+ no_na_period = self.dataset[constants.TIME].isel(time=~na_mask).values
424
240
 
425
- new_dataset = self.dataset.assign_coords(
426
- new_time=(new_time, no_na_period),
427
- )
241
+ # Create new `time` and `media_time` coordinates.
242
+ new_time = 'new_time'
428
243
 
429
- new_dataset[constants.KPI] = (
430
- new_dataset[constants.KPI]
431
- .dropna(dim=constants.TIME)
432
- .rename({constants.TIME: new_time})
433
- )
434
- if constants.CONTROLS in new_dataset.data_vars.keys():
435
- new_dataset[constants.CONTROLS] = (
436
- new_dataset[constants.CONTROLS]
437
- .dropna(dim=constants.TIME)
438
- .rename({constants.TIME: new_time})
439
- )
440
- if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
441
- new_dataset[constants.NON_MEDIA_TREATMENTS] = (
442
- new_dataset[constants.NON_MEDIA_TREATMENTS]
443
- .dropna(dim=constants.TIME)
444
- .rename({constants.TIME: new_time})
244
+ new_dataset = self.dataset.assign_coords(
245
+ new_time=(new_time, no_na_period),
445
246
  )
446
247
 
447
- if constants.REVENUE_PER_KPI in new_dataset.data_vars.keys():
448
- new_dataset[constants.REVENUE_PER_KPI] = (
449
- new_dataset[constants.REVENUE_PER_KPI]
248
+ new_dataset[constants.KPI] = (
249
+ new_dataset[constants.KPI]
450
250
  .dropna(dim=constants.TIME)
451
251
  .rename({constants.TIME: new_time})
452
252
  )
253
+ if constants.CONTROLS in new_dataset.data_vars.keys():
254
+ new_dataset[constants.CONTROLS] = (
255
+ new_dataset[constants.CONTROLS]
256
+ .dropna(dim=constants.TIME)
257
+ .rename({constants.TIME: new_time})
258
+ )
259
+ if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
260
+ new_dataset[constants.NON_MEDIA_TREATMENTS] = (
261
+ new_dataset[constants.NON_MEDIA_TREATMENTS]
262
+ .dropna(dim=constants.TIME)
263
+ .rename({constants.TIME: new_time})
264
+ )
453
265
 
454
- if constants.MEDIA_SPEND in new_dataset.data_vars.keys():
455
- new_dataset[constants.MEDIA_SPEND] = (
456
- new_dataset[constants.MEDIA_SPEND]
457
- .dropna(dim=constants.TIME)
458
- .rename({constants.TIME: new_time})
459
- )
266
+ if constants.REVENUE_PER_KPI in new_dataset.data_vars.keys():
267
+ new_dataset[constants.REVENUE_PER_KPI] = (
268
+ new_dataset[constants.REVENUE_PER_KPI]
269
+ .dropna(dim=constants.TIME)
270
+ .rename({constants.TIME: new_time})
271
+ )
460
272
 
461
- if constants.RF_SPEND in new_dataset.data_vars.keys():
462
- new_dataset[constants.RF_SPEND] = (
463
- new_dataset[constants.RF_SPEND]
464
- .dropna(dim=constants.TIME)
465
- .rename({constants.TIME: new_time})
466
- )
273
+ if constants.MEDIA_SPEND in new_dataset.data_vars.keys():
274
+ new_dataset[constants.MEDIA_SPEND] = (
275
+ new_dataset[constants.MEDIA_SPEND]
276
+ .dropna(dim=constants.TIME)
277
+ .rename({constants.TIME: new_time})
278
+ )
467
279
 
468
- self.dataset = new_dataset.rename(
469
- {constants.TIME: constants.MEDIA_TIME, new_time: constants.TIME}
470
- )
280
+ if constants.RF_SPEND in new_dataset.data_vars.keys():
281
+ new_dataset[constants.RF_SPEND] = (
282
+ new_dataset[constants.RF_SPEND]
283
+ .dropna(dim=constants.TIME)
284
+ .rename({constants.TIME: new_time})
285
+ )
286
+
287
+ self.dataset = new_dataset.rename(
288
+ {constants.TIME: constants.MEDIA_TIME, new_time: constants.TIME}
289
+ )
471
290
 
472
291
  def load(self) -> input_data.InputData:
473
292
  """Returns an `InputData` object containing the data from the dataset."""
474
- controls = (
475
- self.dataset.controls
476
- if constants.CONTROLS in self.dataset.data_vars.keys()
477
- else None
478
- )
479
- revenue_per_kpi = (
480
- self.dataset.revenue_per_kpi
481
- if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
482
- else None
483
- )
484
- media = (
485
- self.dataset.media
486
- if constants.MEDIA in self.dataset.data_vars.keys()
487
- else None
488
- )
489
- media_spend = (
490
- self.dataset.media_spend
491
- if constants.MEDIA in self.dataset.data_vars.keys()
492
- else None
493
- )
494
- reach = (
495
- self.dataset.reach
496
- if constants.REACH in self.dataset.data_vars.keys()
497
- else None
498
- )
499
- frequency = (
500
- self.dataset.frequency
501
- if constants.FREQUENCY in self.dataset.data_vars.keys()
502
- else None
503
- )
504
- rf_spend = (
505
- self.dataset.rf_spend
506
- if constants.RF_SPEND in self.dataset.data_vars.keys()
507
- else None
508
- )
509
- non_media_treatments = (
510
- self.dataset.non_media_treatments
511
- if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys()
512
- else None
513
- )
514
- organic_media = (
515
- self.dataset.organic_media
516
- if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys()
517
- else None
518
- )
519
- organic_reach = (
520
- self.dataset.organic_reach
521
- if constants.ORGANIC_REACH in self.dataset.data_vars.keys()
522
- else None
523
- )
524
- organic_frequency = (
525
- self.dataset.organic_frequency
526
- if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys()
527
- else None
528
- )
529
- return input_data.InputData(
530
- kpi=self.dataset.kpi,
531
- kpi_type=self.kpi_type,
532
- population=self.dataset.population,
533
- controls=controls,
534
- revenue_per_kpi=revenue_per_kpi,
535
- media=media,
536
- media_spend=media_spend,
537
- reach=reach,
538
- frequency=frequency,
539
- rf_spend=rf_spend,
540
- non_media_treatments=non_media_treatments,
541
- organic_media=organic_media,
542
- organic_reach=organic_reach,
543
- organic_frequency=organic_frequency,
544
- )
293
+ builder = input_data_builder.InputDataBuilder(self.kpi_type)
294
+ builder.kpi = self.dataset.kpi
295
+ if constants.POPULATION in self.dataset.data_vars.keys():
296
+ builder.population = self.dataset.population
297
+ if constants.CONTROLS in self.dataset.data_vars.keys():
298
+ builder.controls = self.dataset.controls
299
+ if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys():
300
+ builder.revenue_per_kpi = self.dataset.revenue_per_kpi
301
+ if constants.MEDIA in self.dataset.data_vars.keys():
302
+ builder.media = self.dataset.media
303
+ if constants.MEDIA_SPEND in self.dataset.data_vars.keys():
304
+ builder.media_spend = self.dataset.media_spend
305
+ if constants.REACH in self.dataset.data_vars.keys():
306
+ builder.reach = self.dataset.reach
307
+ if constants.FREQUENCY in self.dataset.data_vars.keys():
308
+ builder.frequency = self.dataset.frequency
309
+ if constants.RF_SPEND in self.dataset.data_vars.keys():
310
+ builder.rf_spend = self.dataset.rf_spend
311
+ if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
312
+ builder.non_media_treatments = self.dataset.non_media_treatments
313
+ if constants.ORGANIC_MEDIA in self.dataset.data_vars.keys():
314
+ builder.organic_media = self.dataset.organic_media
315
+ if constants.ORGANIC_REACH in self.dataset.data_vars.keys():
316
+ builder.organic_reach = self.dataset.organic_reach
317
+ if constants.ORGANIC_FREQUENCY in self.dataset.data_vars.keys():
318
+ builder.organic_frequency = self.dataset.organic_frequency
319
+ return builder.build()
545
320
 
546
321
 
547
322
  @dataclasses.dataclass(frozen=True)
@@ -607,6 +382,9 @@ class CoordToColumns:
607
382
  ' both.'
608
383
  )
609
384
 
385
+ if self.revenue_per_kpi is not None and not self.revenue_per_kpi.strip():
386
+ raise ValueError('`revenue_per_kpi` should not be empty if provided.')
387
+
610
388
 
611
389
  @dataclasses.dataclass
612
390
  class DataFrameDataLoader(InputDataLoader):
@@ -816,12 +594,109 @@ class DataFrameDataLoader(InputDataLoader):
816
594
  'organic_frequency': 'organic_frequency_to_channel',
817
595
  })
818
596
  for coord_name, channel_dict in required_mappings.items():
597
+ if getattr(self.coord_to_columns, coord_name, None) is not None:
598
+ if getattr(self, channel_dict, None) is None:
599
+ raise ValueError(
600
+ f"When {coord_name} data is provided, '{channel_dict}' is"
601
+ ' required.'
602
+ )
603
+ else:
604
+ if set(getattr(self, channel_dict)) != set(
605
+ getattr(self.coord_to_columns, coord_name)
606
+ ):
607
+ raise ValueError(
608
+ f'The {channel_dict} keys must have the same set of values as'
609
+ f' the {coord_name} columns.'
610
+ )
611
+
612
+ if (
613
+ self.media_to_channel is not None
614
+ and self.media_spend_to_channel is not None
615
+ ):
616
+ if set(self.media_to_channel.values()) != set(
617
+ self.media_spend_to_channel.values()
618
+ ):
619
+ raise ValueError(
620
+ 'The media and media_spend columns must have the same set of'
621
+ ' channels.'
622
+ )
623
+
624
+ # The columns listed in `media` and `media_spend` must correspond to the
625
+ # same channels, in user-given order!
626
+ # For example, this is invalid:
627
+ # media = ['impressions_tv', 'impressions_yt']
628
+ # media_spend = ['spend_yt', 'spend_tv']
629
+ # But we can only detect this after we map each `media` and `media_spend`
630
+ # column to its canonical channel name.
631
+ media_channels = [
632
+ self.media_to_channel[c] for c in self.coord_to_columns.media
633
+ ]
634
+ media_spend_channels = [
635
+ self.media_spend_to_channel[c]
636
+ for c in self.coord_to_columns.media_spend
637
+ ]
638
+ if media_channels != media_spend_channels:
639
+ raise ValueError(
640
+ 'The `media` and `media_spend` columns must correspond to the same'
641
+ ' channels, in user order.'
642
+ )
643
+
644
+ if (
645
+ self.reach_to_channel is not None
646
+ and self.frequency_to_channel is not None
647
+ and self.rf_spend_to_channel is not None
648
+ ):
819
649
  if (
820
- getattr(self.coord_to_columns, coord_name, None) is not None
821
- and getattr(self, channel_dict, None) is None
650
+ set(self.reach_to_channel.values())
651
+ != set(self.frequency_to_channel.values())
652
+ != set(self.rf_spend_to_channel.values())
822
653
  ):
823
654
  raise ValueError(
824
- f"When {coord_name} data is provided, '{channel_dict}' is required."
655
+ 'The reach, frequency, and rf_spend columns must have the same set'
656
+ ' of channels.'
657
+ )
658
+
659
+ # Same channel ordering concerns as for `media` and `media_spend`.
660
+ reach_channels = [
661
+ self.reach_to_channel[c] for c in self.coord_to_columns.reach
662
+ ]
663
+ frequency_channels = [
664
+ self.frequency_to_channel[c] for c in self.coord_to_columns.frequency
665
+ ]
666
+ rf_spend_channels = [
667
+ self.rf_spend_to_channel[c] for c in self.coord_to_columns.rf_spend
668
+ ]
669
+ if not (reach_channels == frequency_channels == rf_spend_channels):
670
+ raise ValueError(
671
+ 'The `reach`, `frequency`, and `rf_spend` columns must correspond'
672
+ ' to the same channels, in user order.'
673
+ )
674
+
675
+ if (
676
+ self.organic_reach_to_channel is not None
677
+ and self.organic_frequency_to_channel is not None
678
+ ):
679
+ if set(self.organic_reach_to_channel.values()) != set(
680
+ self.organic_frequency_to_channel.values()
681
+ ):
682
+ raise ValueError(
683
+ 'The organic_reach and organic_frequency columns must have the'
684
+ ' same set of channels.'
685
+ )
686
+
687
+ # Same channel ordering concerns as for `media` and `media_spend`.
688
+ organic_reach_channels = [
689
+ self.organic_reach_to_channel[c]
690
+ for c in self.coord_to_columns.organic_reach
691
+ ]
692
+ organic_frequency_channels = [
693
+ self.organic_frequency_to_channel[c]
694
+ for c in self.coord_to_columns.organic_frequency
695
+ ]
696
+ if organic_reach_channels != organic_frequency_channels:
697
+ raise ValueError(
698
+ 'The `organic_reach` and `organic_frequency` columns must'
699
+ ' correspond to the same channels, in user order.'
825
700
  )
826
701
 
827
702
  def load(self) -> input_data.InputData:
@@ -835,58 +710,86 @@ class DataFrameDataLoader(InputDataLoader):
835
710
  self.coord_to_columns.time,
836
711
  self.coord_to_columns.geo,
837
712
  )
713
+
838
714
  if self.coord_to_columns.population in self.df.columns:
839
715
  builder.with_population(
840
716
  self.df, self.coord_to_columns.population, self.coord_to_columns.geo
841
717
  )
842
- if self.coord_to_columns.controls is not None:
718
+
719
+ if self.coord_to_columns.controls:
843
720
  builder.with_controls(
844
721
  self.df,
845
722
  list(self.coord_to_columns.controls),
846
723
  self.coord_to_columns.time,
847
724
  self.coord_to_columns.geo,
848
725
  )
849
- if self.coord_to_columns.non_media_treatments is not None:
726
+
727
+ if self.coord_to_columns.non_media_treatments:
850
728
  builder.with_non_media_treatments(
851
729
  self.df,
852
730
  list(self.coord_to_columns.non_media_treatments),
853
731
  self.coord_to_columns.time,
854
732
  self.coord_to_columns.geo,
855
733
  )
856
- if self.coord_to_columns.revenue_per_kpi is not None:
734
+
735
+ if self.coord_to_columns.revenue_per_kpi:
857
736
  builder.with_revenue_per_kpi(
858
737
  self.df,
859
738
  self.coord_to_columns.revenue_per_kpi,
860
739
  self.coord_to_columns.time,
861
740
  self.coord_to_columns.geo,
862
741
  )
742
+
863
743
  if (
864
- self.coord_to_columns.media is not None
865
- and self.media_to_channel is not None
744
+ self.media_to_channel is not None
745
+ and self.media_spend_to_channel is not None
866
746
  ):
747
+ # Based on the invariant rule enforced in `__post_init__`, the columns
748
+ # listed in `media` and `media_spend` are already validated to correspond
749
+ # to the same channels, in user-given order.
750
+ media_execution_columns = list(self.coord_to_columns.media)
751
+ media_spend_columns = list(self.coord_to_columns.media_spend)
752
+ # So now we can use one of the channel mapper dicts to get the canonical
753
+ # channel names for each column.
754
+ media_channel_names = [
755
+ self.media_to_channel[c] for c in self.coord_to_columns.media
756
+ ]
867
757
  builder.with_media(
868
758
  self.df,
869
- list(self.coord_to_columns.media),
870
- list(self.coord_to_columns.media_spend),
871
- list(self.media_to_channel.values()),
759
+ media_execution_columns,
760
+ media_spend_columns,
761
+ media_channel_names,
872
762
  self.coord_to_columns.time,
873
763
  self.coord_to_columns.geo,
874
764
  )
875
765
 
876
766
  if (
877
- self.coord_to_columns.reach is not None
878
- and self.reach_to_channel is not None
767
+ self.reach_to_channel is not None
768
+ and self.frequency_to_channel is not None
769
+ and self.rf_spend_to_channel is not None
879
770
  ):
771
+ # Based on the invariant rule enforced in `__post_init__`, the columns
772
+ # listed in `reach`, `frequency`, and `rf_spend` are already validated
773
+ # to correspond to the same channels, in user-given order.
774
+ reach_columns = list(self.coord_to_columns.reach)
775
+ frequency_columns = list(self.coord_to_columns.frequency)
776
+ rf_spend_columns = list(self.coord_to_columns.rf_spend)
777
+ # So now we can use one of the channel mapper dicts to get the canonical
778
+ # channel names for each column.
779
+ rf_channel_names = [
780
+ self.reach_to_channel[c] for c in self.coord_to_columns.reach
781
+ ]
880
782
  builder.with_reach(
881
783
  self.df,
882
- list(self.coord_to_columns.reach),
883
- list(self.coord_to_columns.frequency),
884
- list(self.coord_to_columns.rf_spend),
885
- list(self.reach_to_channel.values()),
784
+ reach_columns,
785
+ frequency_columns,
786
+ rf_spend_columns,
787
+ rf_channel_names,
886
788
  self.coord_to_columns.time,
887
789
  self.coord_to_columns.geo,
888
790
  )
889
- if self.coord_to_columns.organic_media is not None:
791
+
792
+ if self.coord_to_columns.organic_media:
890
793
  builder.with_organic_media(
891
794
  self.df,
892
795
  list(self.coord_to_columns.organic_media),
@@ -894,18 +797,31 @@ class DataFrameDataLoader(InputDataLoader):
894
797
  self.coord_to_columns.time,
895
798
  self.coord_to_columns.geo,
896
799
  )
800
+
897
801
  if (
898
- self.coord_to_columns.organic_reach is not None
899
- and self.organic_reach_to_channel is not None
802
+ self.organic_reach_to_channel is not None
803
+ and self.organic_frequency_to_channel is not None
900
804
  ):
805
+ # Based on the invariant rule enforced in `__post_init__`, the columns
806
+ # listed in `organic_reach` and `organic_frequency` are already
807
+ # validated to correspond to the same channels, in user-given order.
808
+ organic_reach_columns = list(self.coord_to_columns.organic_reach)
809
+ organic_frequency_columns = list(self.coord_to_columns.organic_frequency)
810
+ # So now we can use one of the channel mapper dicts to get the canonical
811
+ # channel names for each column.
812
+ organic_rf_channel_names = [
813
+ self.organic_reach_to_channel[c]
814
+ for c in self.coord_to_columns.organic_reach
815
+ ]
901
816
  builder.with_organic_reach(
902
817
  self.df,
903
- list(self.coord_to_columns.organic_reach),
904
- list(self.coord_to_columns.organic_frequency),
905
- list(self.organic_reach_to_channel.values()),
818
+ organic_reach_columns,
819
+ organic_frequency_columns,
820
+ organic_rf_channel_names,
906
821
  self.coord_to_columns.time,
907
822
  self.coord_to_columns.geo,
908
823
  )
824
+
909
825
  return builder.build()
910
826
 
911
827