google-meridian 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
  2. google_meridian-1.1.2.dist-info/RECORD +46 -0
  3. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
  4. meridian/__init__.py +2 -2
  5. meridian/analysis/__init__.py +1 -1
  6. meridian/analysis/analyzer.py +29 -22
  7. meridian/analysis/formatter.py +1 -1
  8. meridian/analysis/optimizer.py +70 -44
  9. meridian/analysis/summarizer.py +1 -1
  10. meridian/analysis/summary_text.py +1 -1
  11. meridian/analysis/test_utils.py +1 -1
  12. meridian/analysis/visualizer.py +17 -8
  13. meridian/constants.py +3 -3
  14. meridian/data/__init__.py +4 -1
  15. meridian/data/arg_builder.py +1 -1
  16. meridian/data/data_frame_input_data_builder.py +614 -0
  17. meridian/data/input_data.py +12 -8
  18. meridian/data/input_data_builder.py +817 -0
  19. meridian/data/load.py +121 -428
  20. meridian/data/nd_array_input_data_builder.py +509 -0
  21. meridian/data/test_utils.py +60 -43
  22. meridian/data/time_coordinates.py +1 -1
  23. meridian/mlflow/__init__.py +17 -0
  24. meridian/mlflow/autolog.py +54 -0
  25. meridian/model/__init__.py +1 -1
  26. meridian/model/adstock_hill.py +1 -1
  27. meridian/model/knots.py +1 -1
  28. meridian/model/media.py +1 -1
  29. meridian/model/model.py +65 -37
  30. meridian/model/model_test_data.py +75 -1
  31. meridian/model/posterior_sampler.py +19 -15
  32. meridian/model/prior_distribution.py +1 -1
  33. meridian/model/prior_sampler.py +32 -26
  34. meridian/model/spec.py +18 -8
  35. meridian/model/transformers.py +1 -1
  36. google_meridian-1.1.0.dist-info/RECORD +0 -41
  37. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
  38. {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
meridian/data/load.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The Meridian Authors.
1
+ # Copyright 2025 The Meridian Authors.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ import warnings
27
27
 
28
28
  import immutabledict
29
29
  from meridian import constants
30
+ from meridian.data import data_frame_input_data_builder
30
31
  from meridian.data import input_data
31
32
  import numpy as np
32
33
  import pandas as pd
@@ -79,7 +80,7 @@ class XrDatasetDataLoader(InputDataLoader):
79
80
  """Constructor.
80
81
 
81
82
  The coordinates of the input dataset should be: `time`, `media_time`,
82
- `control_variable`, `geo` (optional for a national model),
83
+ `control_variable` (optional), `geo` (optional for a national model),
83
84
  `non_media_channel` (optional), `organic_media_channel` (optional),
84
85
  `organic_rf_channel` (optional), and
85
86
  either `media_channel`, `rf_channel`, or both.
@@ -93,7 +94,7 @@ class XrDatasetDataLoader(InputDataLoader):
93
94
 
94
95
  * `kpi`: `(geo, time)`
95
96
  * `revenue_per_kpi`: `(geo, time)`
96
- * `controls`: `(geo, time, control_variable)`
97
+ * `controls`: `(geo, time, control_variable)` - optional
97
98
  * `population`: `(geo)`
98
99
  * `media`: `(geo, media_time, media_channel)` - optional
99
100
  * `media_spend`: `(geo, time, media_channel)`, `(1, time, media_channel)`,
@@ -113,7 +114,7 @@ class XrDatasetDataLoader(InputDataLoader):
113
114
 
114
115
  * `kpi`: `([1,] time)`
115
116
  * `revenue_per_kpi`: `([1,] time)`
116
- * `controls`: `([1,] time, control_variable)`
117
+ * `controls`: `([1,] time, control_variable)` - optional
117
118
  * `population`: `([1],)` - this array is optional for national data
118
119
  * `media`: `([1,] media_time, media_channel)` - optional
119
120
  * `media_spend`: `([1,] time, media_channel)` or
@@ -198,7 +199,7 @@ class XrDatasetDataLoader(InputDataLoader):
198
199
  self.dataset = dataset.rename(name_mapping)
199
200
 
200
201
  # Add a `geo` dimension if it is not already present.
201
- if (constants.GEO) not in self.dataset.dims.keys():
202
+ if (constants.GEO) not in self.dataset.sizes.keys():
202
203
  self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
203
204
 
204
205
  if len(self.dataset.coords[constants.GEO]) == 1:
@@ -228,7 +229,7 @@ class XrDatasetDataLoader(InputDataLoader):
228
229
  compat='override',
229
230
  )
230
231
 
231
- if constants.MEDIA_TIME not in self.dataset.dims.keys():
232
+ if constants.MEDIA_TIME not in self.dataset.sizes.keys():
232
233
  self._add_media_time()
233
234
  self._normalize_time_coordinates(constants.TIME)
234
235
  self._normalize_time_coordinates(constants.MEDIA_TIME)
@@ -349,14 +350,17 @@ class XrDatasetDataLoader(InputDataLoader):
349
350
  # Arrays in which NAs are expected in the lagged-media period.
350
351
  na_arrays = [
351
352
  constants.KPI,
352
- constants.CONTROLS,
353
353
  ]
354
354
 
355
- na_mask = self.dataset[constants.KPI].isnull().any(
356
- dim=constants.GEO
357
- ) | self.dataset[constants.CONTROLS].isnull().any(
358
- dim=[constants.GEO, constants.CONTROL_VARIABLE]
359
- )
355
+ na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
356
+
357
+ if constants.CONTROLS in self.dataset.data_vars.keys():
358
+ na_arrays.append(constants.CONTROLS)
359
+ na_mask |= (
360
+ self.dataset[constants.CONTROLS]
361
+ .isnull()
362
+ .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
363
+ )
360
364
 
361
365
  if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
362
366
  na_arrays.append(constants.NON_MEDIA_TREATMENTS)
@@ -427,11 +431,12 @@ class XrDatasetDataLoader(InputDataLoader):
427
431
  .dropna(dim=constants.TIME)
428
432
  .rename({constants.TIME: new_time})
429
433
  )
430
- new_dataset[constants.CONTROLS] = (
431
- new_dataset[constants.CONTROLS]
432
- .dropna(dim=constants.TIME)
433
- .rename({constants.TIME: new_time})
434
- )
434
+ if constants.CONTROLS in new_dataset.data_vars.keys():
435
+ new_dataset[constants.CONTROLS] = (
436
+ new_dataset[constants.CONTROLS]
437
+ .dropna(dim=constants.TIME)
438
+ .rename({constants.TIME: new_time})
439
+ )
435
440
  if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
436
441
  new_dataset[constants.NON_MEDIA_TREATMENTS] = (
437
442
  new_dataset[constants.NON_MEDIA_TREATMENTS]
@@ -466,6 +471,11 @@ class XrDatasetDataLoader(InputDataLoader):
466
471
 
467
472
  def load(self) -> input_data.InputData:
468
473
  """Returns an `InputData` object containing the data from the dataset."""
474
+ controls = (
475
+ self.dataset.controls
476
+ if constants.CONTROLS in self.dataset.data_vars.keys()
477
+ else None
478
+ )
469
479
  revenue_per_kpi = (
470
480
  self.dataset.revenue_per_kpi
471
481
  if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
@@ -519,9 +529,9 @@ class XrDatasetDataLoader(InputDataLoader):
519
529
  return input_data.InputData(
520
530
  kpi=self.dataset.kpi,
521
531
  kpi_type=self.kpi_type,
522
- revenue_per_kpi=revenue_per_kpi,
523
- controls=self.dataset.controls,
524
532
  population=self.dataset.population,
533
+ controls=controls,
534
+ revenue_per_kpi=revenue_per_kpi,
525
535
  media=media,
526
536
  media_spend=media_spend,
527
537
  reach=reach,
@@ -539,14 +549,14 @@ class CoordToColumns:
539
549
  """A mapping between the desired and actual column names in the input data.
540
550
 
541
551
  Attributes:
542
- controls: List of column names containing `controls` values in the input
543
- data.
544
552
  time: Name of column containing `time` values in the input data.
545
- kpi: Name of column containing `kpi` values in the input data.
546
- revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
547
- input data.
548
553
  geo: Name of column containing `geo` values in the input data. This field
549
554
  is optional for a national model.
555
+ kpi: Name of column containing `kpi` values in the input data.
556
+ controls: List of column names containing `controls` values in the input
557
+ data. Optional.
558
+ revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
559
+ input data. Optional. Will be overridden if model KPI type is "revenue".
550
560
  population: Name of column containing `population` values in the input data.
551
561
  This field is optional for a national model.
552
562
  media: List of column names containing `media` values in the input data.
@@ -567,11 +577,11 @@ class CoordToColumns:
567
577
  values in the input data.
568
578
  """
569
579
 
570
- controls: Sequence[str]
571
580
  time: str = constants.TIME
581
+ geo: str = constants.GEO
572
582
  kpi: str = constants.KPI
583
+ controls: Sequence[str] | None = None
573
584
  revenue_per_kpi: str | None = None
574
- geo: str = constants.GEO
575
585
  population: str = constants.POPULATION
576
586
  # Media data
577
587
  media: Sequence[str] | None = None
@@ -607,7 +617,7 @@ class DataFrameDataLoader(InputDataLoader):
607
617
  to the DataFrame column names if they are different. The fields are:
608
618
 
609
619
  * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
610
- * `controls` (multiple columns)
620
+ * `controls` (multiple columns, optional)
611
621
  * (1) `media`, `media_spend` (multiple columns)
612
622
  * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
613
623
  * `non_media_treatments` (multiple columns, optional)
@@ -792,110 +802,20 @@ class DataFrameDataLoader(InputDataLoader):
792
802
  organic_reach_to_channel: Mapping[str, str] | None = None
793
803
  organic_frequency_to_channel: Mapping[str, str] | None = None
794
804
 
795
- # If [key] in the following dict exists as an attribute in `coord_to_columns`,
796
- # then the corresponding attribute must exist in this loader instance.
797
- _required_mappings = immutabledict.immutabledict({
798
- 'media': 'media_to_channel',
799
- 'media_spend': 'media_spend_to_channel',
800
- 'reach': 'reach_to_channel',
801
- 'frequency': 'frequency_to_channel',
802
- 'rf_spend': 'rf_spend_to_channel',
803
- 'organic_reach': 'organic_reach_to_channel',
804
- 'organic_frequency': 'organic_frequency_to_channel',
805
- })
806
-
807
805
  def __post_init__(self):
808
- self._validate_and_normalize_time_values()
809
- self._expand_if_national()
810
- self._validate_column_names()
811
- self._validate_required_mappings()
812
- self._validate_geo_and_time()
813
- self._validate_nas()
814
-
815
- def _validate_and_normalize_time_values(self):
816
- """Validates that time values are in the conventional Meridian format.
817
-
818
- Time values are expected to be (a) strings formatted in `"yyyy-mm-dd"` or
819
- (b) `datetime` values as numpy's `datetime64` types. All other types are
820
- not currently supported.
821
-
822
- In (b) case, `datetime` coordinate values will be normalized as formatted
823
- strings.
824
- """
825
- time_column_name = self.coord_to_columns.time
826
-
827
- if self.df.dtypes[time_column_name] == np.dtype('datetime64[ns]'):
828
- self.df[time_column_name] = self.df[time_column_name].map(
829
- lambda time: time.strftime(constants.DATE_FORMAT)
830
- )
831
- else:
832
- # Assume that the `time` column values are strings formatted as dates.
833
- for _, time in self.df[time_column_name].items():
834
- try:
835
- _ = dt.datetime.strptime(time, constants.DATE_FORMAT)
836
- except ValueError as exc:
837
- raise ValueError(
838
- f"Invalid time label: '{time}'. Expected format:"
839
- f" '{constants.DATE_FORMAT}'"
840
- ) from exc
841
-
842
- def _validate_column_names(self):
843
- """Validates the column names in `df` and `coord_to_columns`."""
844
-
845
- desired_columns = []
846
- for field in dataclasses.fields(self.coord_to_columns):
847
- value = getattr(self.coord_to_columns, field.name)
848
- if isinstance(value, str):
849
- desired_columns.append(value)
850
- elif isinstance(value, Sequence):
851
- for column in value:
852
- desired_columns.append(column)
853
- desired_columns = sorted(desired_columns)
854
-
855
- actual_columns = sorted(self.df.columns.to_list())
856
- if any(d not in actual_columns for d in desired_columns):
857
- raise ValueError(
858
- f'Values of the `coord_to_columns` object {desired_columns}'
859
- f' should map to the DataFrame column names {actual_columns}.'
860
- )
861
-
862
- def _expand_if_national(self):
863
- """Adds geo/population columns in a national model if necessary."""
864
-
865
- geo_column_name = self.coord_to_columns.geo
866
- population_column_name = self.coord_to_columns.population
867
-
868
- def set_default_population_with_lag_periods():
869
- """Sets the `population` column.
870
-
871
- The `population` column is set to the default value for non-lag periods,
872
- and None for lag-periods. The lag periods are inferred from the Nan values
873
- in the other non-media columns.
874
- """
875
- non_lagged_idx = self.df.isna().idxmin().max()
876
- self.df[population_column_name] = (
877
- constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE
878
- )
879
- self.df.loc[: non_lagged_idx - 1, population_column_name] = None
880
-
881
- if geo_column_name not in self.df.columns:
882
- self.df[geo_column_name] = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
883
-
884
- if self.df[geo_column_name].nunique() == 1:
885
- self.df[geo_column_name] = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
886
- if population_column_name in self.df.columns:
887
- warnings.warn(
888
- 'The `population` argument is ignored in a nationally aggregated'
889
- ' model. It will be reset to [1, 1, ..., 1]'
890
- )
891
- set_default_population_with_lag_periods()
892
-
893
- if population_column_name not in self.df.columns:
894
- set_default_population_with_lag_periods()
895
-
896
- def _validate_required_mappings(self):
897
- """Validates required mappings in `coord_to_columns`."""
898
- for coord_name, channel_dict in self._required_mappings.items():
806
+ # If [key] in the following dict exists as an attribute in
807
+ # `coord_to_columns`, then the corresponding attribute must exist in this
808
+ # loader instance.
809
+ required_mappings = immutabledict.immutabledict({
810
+ 'media': 'media_to_channel',
811
+ 'media_spend': 'media_spend_to_channel',
812
+ 'reach': 'reach_to_channel',
813
+ 'frequency': 'frequency_to_channel',
814
+ 'rf_spend': 'rf_spend_to_channel',
815
+ 'organic_reach': 'organic_reach_to_channel',
816
+ 'organic_frequency': 'organic_frequency_to_channel',
817
+ })
818
+ for coord_name, channel_dict in required_mappings.items():
899
819
  if (
900
820
  getattr(self.coord_to_columns, coord_name, None) is not None
901
821
  and getattr(self, channel_dict, None) is None
@@ -904,316 +824,89 @@ class DataFrameDataLoader(InputDataLoader):
904
824
  f"When {coord_name} data is provided, '{channel_dict}' is required."
905
825
  )
906
826
 
907
- def _validate_geo_and_time(self):
908
- """Validates that for every geo the list of `time`s is the same."""
909
- geo_column_name = self.coord_to_columns.geo
910
- time_column_name = self.coord_to_columns.time
911
-
912
- df_grouped = self.df.sort_values(time_column_name).groupby(
913
- geo_column_name, sort=False
914
- )[time_column_name]
915
- if any(df_grouped.count() != df_grouped.nunique()):
916
- raise ValueError("Duplicate entries found in the 'time' column.")
917
-
918
- times_by_geo = df_grouped.apply(list).reset_index(drop=True)
919
- if any(t != times_by_geo[0] for t in times_by_geo[1:]):
920
- raise ValueError(
921
- "Values in the 'time' column not consistent across different geos."
922
- )
923
-
924
- def _validate_nas(self):
925
- """Validates that the only NAs are in the lagged-media period."""
926
- # Check if there are no NAs in media.
927
- if self.coord_to_columns.media is not None:
928
- if self.df[self.coord_to_columns.media].isna().any(axis=None):
929
- raise ValueError('NA values found in the media columns.')
930
-
931
- # Check if there are no NAs in reach & frequency.
932
- if self.coord_to_columns.reach is not None:
933
- if self.df[self.coord_to_columns.reach].isna().any(axis=None):
934
- raise ValueError('NA values found in the reach columns.')
935
- if self.coord_to_columns.frequency is not None:
936
- if self.df[self.coord_to_columns.frequency].isna().any(axis=None):
937
- raise ValueError('NA values found in the frequency columns.')
938
-
939
- # Check if ther are no NAs in organic_media.
940
- if self.coord_to_columns.organic_media is not None:
941
- if self.df[self.coord_to_columns.organic_media].isna().any(axis=None):
942
- raise ValueError('NA values found in the organic_media columns.')
943
-
944
- # Check if there are no NAs in organic_reach & organic_frequency.
945
- if self.coord_to_columns.organic_reach is not None:
946
- if self.df[self.coord_to_columns.organic_reach].isna().any(axis=None):
947
- raise ValueError('NA values found in the organic_reach columns.')
948
- if self.coord_to_columns.organic_frequency is not None:
949
- if self.df[self.coord_to_columns.organic_frequency].isna().any(axis=None):
950
- raise ValueError('NA values found in the organic_frequency columns.')
951
-
952
- # Determine columns in which NAs are expected in the lagged-media period.
953
- not_lagged_columns = []
954
- coords = [
955
- constants.KPI,
956
- constants.CONTROLS,
957
- constants.POPULATION,
958
- ]
959
- if self.coord_to_columns.revenue_per_kpi is not None:
960
- coords.append(constants.REVENUE_PER_KPI)
961
- if self.coord_to_columns.media_spend is not None:
962
- coords.append(constants.MEDIA_SPEND)
963
- if self.coord_to_columns.rf_spend is not None:
964
- coords.append(constants.RF_SPEND)
965
- if self.coord_to_columns.non_media_treatments is not None:
966
- coords.append(constants.NON_MEDIA_TREATMENTS)
967
- for coord in coords:
968
- columns = getattr(self.coord_to_columns, coord)
969
- columns = [columns] if isinstance(columns, str) else columns
970
- not_lagged_columns.extend(columns)
971
-
972
- # Dates with at least one non-NA value in columns different from media,
973
- # reach, frequency, organic_media, organic_reach, and organic_frequency.
974
- time_column_name = self.coord_to_columns.time
975
- no_na_period = self.df[(~self.df[not_lagged_columns].isna()).any(axis=1)][
976
- time_column_name
977
- ].unique()
978
-
979
- # Dates with 100% NA values in all columns different from media, reach,
980
- # frequency, organic_media, organic_reach, and organic_frequency.
981
- na_period = [
982
- t for t in self.df[time_column_name].unique() if t not in no_na_period
983
- ]
984
-
985
- # Check if na_period is a continuous window starting from the earliest time
986
- # period.
987
- if not np.all(
988
- np.sort(na_period)
989
- == np.sort(self.df[time_column_name].unique())[: len(na_period)]
990
- ):
991
- raise ValueError(
992
- "The 'lagged media' period (period with 100% NA values in all"
993
- f' non-media columns) {na_period} is not a continuous window starting'
994
- ' from the earliest time period.'
995
- )
996
-
997
- # Check if for the non-lagged period, there are no NAs in data different
998
- # from media, reach, frequency, organic_media, organic_reach, and
999
- # organic_frequency.
1000
- not_lagged_data = self.df.loc[
1001
- self.df[time_column_name].isin(no_na_period),
1002
- not_lagged_columns,
1003
- ]
1004
- if not_lagged_data.isna().any(axis=None):
1005
- incorrect_columns = []
1006
- for column in not_lagged_columns:
1007
- if not_lagged_data[column].isna().any(axis=None):
1008
- incorrect_columns.append(column)
1009
- raise ValueError(
1010
- f'NA values found in columns {incorrect_columns} within the modeling'
1011
- ' time window (time periods where the KPI is modeled).'
1012
- )
1013
-
1014
827
  def load(self) -> input_data.InputData:
1015
828
  """Reads data from a dataframe and returns an InputData object."""
1016
829
 
1017
- # Change geo strings to numbers to keep the order of geos. The .to_xarray()
1018
- # method from Pandas sorts lexicographically by the key columns, so if the
1019
- # geos were unsorted strings, it would change their order.
1020
- geo_column_name = self.coord_to_columns.geo
1021
- time_column_name = self.coord_to_columns.time
1022
- geo_names = self.df[geo_column_name].unique()
1023
- self.df[geo_column_name] = self.df[geo_column_name].replace(
1024
- dict(zip(geo_names, np.arange(len(geo_names))))
830
+ builder = data_frame_input_data_builder.DataFrameInputDataBuilder(
831
+ kpi_type=self.kpi_type
832
+ ).with_kpi(
833
+ self.df,
834
+ self.coord_to_columns.kpi,
835
+ self.coord_to_columns.time,
836
+ self.coord_to_columns.geo,
1025
837
  )
1026
- df_indexed = self.df.set_index([geo_column_name, time_column_name])
1027
-
1028
- kpi_xr = (
1029
- df_indexed[self.coord_to_columns.kpi]
1030
- .dropna()
1031
- .rename(constants.KPI)
1032
- .rename_axis([constants.GEO, constants.TIME])
1033
- .to_frame()
1034
- .to_xarray()
1035
- )
1036
- population_xr = (
1037
- df_indexed[self.coord_to_columns.population]
1038
- .groupby(geo_column_name)
1039
- .mean()
1040
- .rename(constants.POPULATION)
1041
- .rename_axis([constants.GEO])
1042
- .to_frame()
1043
- .to_xarray()
1044
- )
1045
- controls_xr = (
1046
- df_indexed[self.coord_to_columns.controls]
1047
- .stack()
1048
- .rename(constants.CONTROLS)
1049
- .rename_axis(
1050
- [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
1051
- )
1052
- .to_frame()
1053
- .to_xarray()
1054
- )
1055
- dataset = xr.combine_by_coords([kpi_xr, population_xr, controls_xr])
1056
-
1057
- if self.coord_to_columns.non_media_treatments is not None:
1058
- non_media_xr = (
1059
- df_indexed[self.coord_to_columns.non_media_treatments]
1060
- .stack()
1061
- .rename(constants.NON_MEDIA_TREATMENTS)
1062
- .rename_axis(
1063
- [constants.GEO, constants.TIME, constants.NON_MEDIA_CHANNEL]
1064
- )
1065
- .to_frame()
1066
- .to_xarray()
838
+ if self.coord_to_columns.population in self.df.columns:
839
+ builder.with_population(
840
+ self.df, self.coord_to_columns.population, self.coord_to_columns.geo
1067
841
  )
1068
- dataset = xr.combine_by_coords([dataset, non_media_xr])
1069
-
1070
- if self.coord_to_columns.revenue_per_kpi is not None:
1071
- revenue_per_kpi_xr = (
1072
- df_indexed[self.coord_to_columns.revenue_per_kpi]
1073
- .dropna()
1074
- .rename(constants.REVENUE_PER_KPI)
1075
- .rename_axis([constants.GEO, constants.TIME])
1076
- .to_frame()
1077
- .to_xarray()
1078
- )
1079
- dataset = xr.combine_by_coords([dataset, revenue_per_kpi_xr])
1080
- if self.coord_to_columns.media is not None:
1081
- media_xr = (
1082
- df_indexed[self.coord_to_columns.media]
1083
- .stack()
1084
- .rename(constants.MEDIA)
1085
- .rename_axis(
1086
- [constants.GEO, constants.MEDIA_TIME, constants.MEDIA_CHANNEL]
1087
- )
1088
- .to_frame()
1089
- .to_xarray()
1090
- )
1091
- media_xr.coords[constants.MEDIA_CHANNEL] = [
1092
- self.media_to_channel[x]
1093
- for x in media_xr.coords[constants.MEDIA_CHANNEL].values
1094
- ]
1095
-
1096
- media_spend_xr = (
1097
- df_indexed[self.coord_to_columns.media_spend]
1098
- .stack()
1099
- .rename(constants.MEDIA_SPEND)
1100
- .rename_axis([constants.GEO, constants.TIME, constants.MEDIA_CHANNEL])
1101
- .to_frame()
1102
- .to_xarray()
842
+ if self.coord_to_columns.controls is not None:
843
+ builder.with_controls(
844
+ self.df,
845
+ list(self.coord_to_columns.controls),
846
+ self.coord_to_columns.time,
847
+ self.coord_to_columns.geo,
1103
848
  )
1104
- media_spend_xr.coords[constants.MEDIA_CHANNEL] = [
1105
- self.media_spend_to_channel[x]
1106
- for x in media_spend_xr.coords[constants.MEDIA_CHANNEL].values
1107
- ]
1108
- dataset = xr.combine_by_coords([dataset, media_xr, media_spend_xr])
1109
-
1110
- if self.coord_to_columns.reach is not None:
1111
- reach_xr = (
1112
- df_indexed[self.coord_to_columns.reach]
1113
- .stack()
1114
- .rename(constants.REACH)
1115
- .rename_axis(
1116
- [constants.GEO, constants.MEDIA_TIME, constants.RF_CHANNEL]
1117
- )
1118
- .to_frame()
1119
- .to_xarray()
1120
- )
1121
- reach_xr.coords[constants.RF_CHANNEL] = [
1122
- self.reach_to_channel[x]
1123
- for x in reach_xr.coords[constants.RF_CHANNEL].values
1124
- ]
1125
-
1126
- frequency_xr = (
1127
- df_indexed[self.coord_to_columns.frequency]
1128
- .stack()
1129
- .rename(constants.FREQUENCY)
1130
- .rename_axis(
1131
- [constants.GEO, constants.MEDIA_TIME, constants.RF_CHANNEL]
1132
- )
1133
- .to_frame()
1134
- .to_xarray()
849
+ if self.coord_to_columns.non_media_treatments is not None:
850
+ builder.with_non_media_treatments(
851
+ self.df,
852
+ list(self.coord_to_columns.non_media_treatments),
853
+ self.coord_to_columns.time,
854
+ self.coord_to_columns.geo,
1135
855
  )
1136
- frequency_xr.coords[constants.RF_CHANNEL] = [
1137
- self.frequency_to_channel[x]
1138
- for x in frequency_xr.coords[constants.RF_CHANNEL].values
1139
- ]
1140
-
1141
- rf_spend_xr = (
1142
- df_indexed[self.coord_to_columns.rf_spend]
1143
- .stack()
1144
- .rename(constants.RF_SPEND)
1145
- .rename_axis([constants.GEO, constants.TIME, constants.RF_CHANNEL])
1146
- .to_frame()
1147
- .to_xarray()
856
+ if self.coord_to_columns.revenue_per_kpi is not None:
857
+ builder.with_revenue_per_kpi(
858
+ self.df,
859
+ self.coord_to_columns.revenue_per_kpi,
860
+ self.coord_to_columns.time,
861
+ self.coord_to_columns.geo,
1148
862
  )
1149
- rf_spend_xr.coords[constants.RF_CHANNEL] = [
1150
- self.rf_spend_to_channel[x]
1151
- for x in rf_spend_xr.coords[constants.RF_CHANNEL].values
1152
- ]
1153
- dataset = xr.combine_by_coords(
1154
- [dataset, reach_xr, frequency_xr, rf_spend_xr]
863
+ if (
864
+ self.coord_to_columns.media is not None
865
+ and self.media_to_channel is not None
866
+ ):
867
+ builder.with_media(
868
+ self.df,
869
+ list(self.coord_to_columns.media),
870
+ list(self.coord_to_columns.media_spend),
871
+ list(self.media_to_channel.values()),
872
+ self.coord_to_columns.time,
873
+ self.coord_to_columns.geo,
1155
874
  )
1156
875
 
1157
- if self.coord_to_columns.organic_media is not None:
1158
- organic_media_xr = (
1159
- df_indexed[self.coord_to_columns.organic_media]
1160
- .stack()
1161
- .rename(constants.ORGANIC_MEDIA)
1162
- .rename_axis([
1163
- constants.GEO,
1164
- constants.MEDIA_TIME,
1165
- constants.ORGANIC_MEDIA_CHANNEL,
1166
- ])
1167
- .to_frame()
1168
- .to_xarray()
1169
- )
1170
- dataset = xr.combine_by_coords([dataset, organic_media_xr])
1171
-
1172
- if self.coord_to_columns.organic_reach is not None:
1173
- organic_reach_xr = (
1174
- df_indexed[self.coord_to_columns.organic_reach]
1175
- .stack()
1176
- .rename(constants.ORGANIC_REACH)
1177
- .rename_axis([
1178
- constants.GEO,
1179
- constants.MEDIA_TIME,
1180
- constants.ORGANIC_RF_CHANNEL,
1181
- ])
1182
- .to_frame()
1183
- .to_xarray()
876
+ if (
877
+ self.coord_to_columns.reach is not None
878
+ and self.reach_to_channel is not None
879
+ ):
880
+ builder.with_reach(
881
+ self.df,
882
+ list(self.coord_to_columns.reach),
883
+ list(self.coord_to_columns.frequency),
884
+ list(self.coord_to_columns.rf_spend),
885
+ list(self.reach_to_channel.values()),
886
+ self.coord_to_columns.time,
887
+ self.coord_to_columns.geo,
1184
888
  )
1185
- organic_reach_xr.coords[constants.ORGANIC_RF_CHANNEL] = [
1186
- self.organic_reach_to_channel[x]
1187
- for x in organic_reach_xr.coords[constants.ORGANIC_RF_CHANNEL].values
1188
- ]
1189
- organic_frequency_xr = (
1190
- df_indexed[self.coord_to_columns.organic_frequency]
1191
- .stack()
1192
- .rename(constants.ORGANIC_FREQUENCY)
1193
- .rename_axis([
1194
- constants.GEO,
1195
- constants.MEDIA_TIME,
1196
- constants.ORGANIC_RF_CHANNEL,
1197
- ])
1198
- .to_frame()
1199
- .to_xarray()
889
+ if self.coord_to_columns.organic_media is not None:
890
+ builder.with_organic_media(
891
+ self.df,
892
+ list(self.coord_to_columns.organic_media),
893
+ list(self.coord_to_columns.organic_media),
894
+ self.coord_to_columns.time,
895
+ self.coord_to_columns.geo,
1200
896
  )
1201
- organic_frequency_xr.coords[constants.ORGANIC_RF_CHANNEL] = [
1202
- self.organic_frequency_to_channel[x]
1203
- for x in organic_frequency_xr.coords[
1204
- constants.ORGANIC_RF_CHANNEL
1205
- ].values
1206
- ]
1207
- dataset = xr.combine_by_coords(
1208
- [dataset, organic_reach_xr, organic_frequency_xr]
897
+ if (
898
+ self.coord_to_columns.organic_reach is not None
899
+ and self.organic_reach_to_channel is not None
900
+ ):
901
+ builder.with_organic_reach(
902
+ self.df,
903
+ list(self.coord_to_columns.organic_reach),
904
+ list(self.coord_to_columns.organic_frequency),
905
+ list(self.organic_reach_to_channel.values()),
906
+ self.coord_to_columns.time,
907
+ self.coord_to_columns.geo,
1209
908
  )
1210
-
1211
- # Change back to geo names
1212
- self.df[geo_column_name] = self.df[geo_column_name].replace(
1213
- dict(zip(np.arange(len(geo_names)), geo_names))
1214
- )
1215
- dataset.coords[constants.GEO] = geo_names
1216
- return XrDatasetDataLoader(dataset, kpi_type=self.kpi_type).load()
909
+ return builder.build()
1217
910
 
1218
911
 
1219
912
  class CsvDataLoader(InputDataLoader):
@@ -1224,7 +917,7 @@ class CsvDataLoader(InputDataLoader):
1224
917
  CSV column names, if they are different. The fields are:
1225
918
 
1226
919
  * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
1227
- * `controls` (multiple columns)
920
+ * `controls` (multiple columns, optional)
1228
921
  * (1) `media`, `media_spend` (multiple columns)
1229
922
  * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
1230
923
  * `non_media_treatments` (multiple columns, optional)