google-meridian 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
meridian/data/load.py CHANGED
@@ -27,6 +27,7 @@ import warnings
27
27
 
28
28
  import immutabledict
29
29
  from meridian import constants
30
+ from meridian.data import data_frame_input_data_builder
30
31
  from meridian.data import input_data
31
32
  import numpy as np
32
33
  import pandas as pd
@@ -801,432 +802,167 @@ class DataFrameDataLoader(InputDataLoader):
801
802
  organic_reach_to_channel: Mapping[str, str] | None = None
802
803
  organic_frequency_to_channel: Mapping[str, str] | None = None
803
804
 
804
- # If [key] in the following dict exists as an attribute in `coord_to_columns`,
805
- # then the corresponding attribute must exist in this loader instance.
806
- _required_mappings = immutabledict.immutabledict({
807
- 'media': 'media_to_channel',
808
- 'media_spend': 'media_spend_to_channel',
809
- 'reach': 'reach_to_channel',
810
- 'frequency': 'frequency_to_channel',
811
- 'rf_spend': 'rf_spend_to_channel',
812
- 'organic_reach': 'organic_reach_to_channel',
813
- 'organic_frequency': 'organic_frequency_to_channel',
814
- })
815
-
816
805
  def __post_init__(self):
817
- self._validate_and_normalize_time_values()
818
- self._expand_if_national()
819
- self._validate_column_names()
820
- self._validate_required_mappings()
821
- self._validate_geo_and_time()
822
- self._validate_nas()
823
-
824
- def _validate_and_normalize_time_values(self):
825
- """Validates that time values are in the conventional Meridian format.
826
-
827
- Time values are expected to be (a) strings formatted in `"yyyy-mm-dd"` or
828
- (b) `datetime` values as numpy's `datetime64` types. All other types are
829
- not currently supported.
830
-
831
- In (b) case, `datetime` coordinate values will be normalized as formatted
832
- strings.
833
- """
834
- time_column_name = self.coord_to_columns.time
835
-
836
- if self.df.dtypes[time_column_name] == np.dtype('datetime64[ns]'):
837
- self.df[time_column_name] = self.df[time_column_name].map(
838
- lambda time: time.strftime(constants.DATE_FORMAT)
839
- )
840
- else:
841
- # Assume that the `time` column values are strings formatted as dates.
842
- for _, time in self.df[time_column_name].items():
843
- try:
844
- _ = dt.datetime.strptime(time, constants.DATE_FORMAT)
845
- except ValueError as exc:
806
+ # If [key] in the following dict exists as an attribute in
807
+ # `coord_to_columns`, then the corresponding attribute must exist in this
808
+ # loader instance.
809
+ required_mappings = immutabledict.immutabledict({
810
+ 'media': 'media_to_channel',
811
+ 'media_spend': 'media_spend_to_channel',
812
+ 'reach': 'reach_to_channel',
813
+ 'frequency': 'frequency_to_channel',
814
+ 'rf_spend': 'rf_spend_to_channel',
815
+ 'organic_reach': 'organic_reach_to_channel',
816
+ 'organic_frequency': 'organic_frequency_to_channel',
817
+ })
818
+ for coord_name, channel_dict in required_mappings.items():
819
+ if getattr(self.coord_to_columns, coord_name, None) is not None:
820
+ if getattr(self, channel_dict, None) is None:
846
821
  raise ValueError(
847
- f"Invalid time label: '{time}'. Expected format:"
848
- f" '{constants.DATE_FORMAT}'"
849
- ) from exc
850
-
851
- def _validate_column_names(self):
852
- """Validates the column names in `df` and `coord_to_columns`."""
853
-
854
- desired_columns = []
855
- for field in dataclasses.fields(self.coord_to_columns):
856
- value = getattr(self.coord_to_columns, field.name)
857
- if isinstance(value, str):
858
- desired_columns.append(value)
859
- elif isinstance(value, Sequence):
860
- for column in value:
861
- desired_columns.append(column)
862
- desired_columns = sorted(desired_columns)
863
-
864
- actual_columns = sorted(self.df.columns.to_list())
865
- if any(d not in actual_columns for d in desired_columns):
866
- raise ValueError(
867
- f'Values of the `coord_to_columns` object {desired_columns}'
868
- f' should map to the DataFrame column names {actual_columns}.'
869
- )
870
-
871
- def _expand_if_national(self):
872
- """Adds geo/population columns in a national model if necessary."""
873
-
874
- geo_column_name = self.coord_to_columns.geo
875
- population_column_name = self.coord_to_columns.population
876
-
877
- def set_default_population_with_lag_periods():
878
- """Sets the `population` column.
879
-
880
- The `population` column is set to the default value for non-lag periods,
881
- and None for lag-periods. The lag periods are inferred from the Nan values
882
- in the other non-media columns.
883
- """
884
- non_lagged_idx = self.df.isna().idxmin().max()
885
- self.df[population_column_name] = (
886
- constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE
887
- )
888
- self.df.loc[: non_lagged_idx - 1, population_column_name] = None
889
-
890
- if geo_column_name not in self.df.columns:
891
- self.df[geo_column_name] = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
892
-
893
- if self.df[geo_column_name].nunique() == 1:
894
- self.df[geo_column_name] = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
895
- if population_column_name in self.df.columns:
896
- warnings.warn(
897
- 'The `population` argument is ignored in a nationally aggregated'
898
- ' model. It will be reset to [1, 1, ..., 1]'
822
+ f"When {coord_name} data is provided, '{channel_dict}' is"
823
+ ' required.'
824
+ )
825
+ else:
826
+ if set(getattr(self, channel_dict)) != set(
827
+ getattr(self.coord_to_columns, coord_name)
828
+ ):
829
+ raise ValueError(
830
+ f'The {channel_dict} keys must have the same set of values as'
831
+ f' the {coord_name} columns.'
832
+ )
833
+ if (
834
+ self.media_to_channel is not None
835
+ and self.media_spend_to_channel is not None
836
+ ):
837
+ if set(self.media_to_channel.values()) != set(
838
+ self.media_spend_to_channel.values()
839
+ ):
840
+ raise ValueError(
841
+ 'The media and media_spend columns must have the same set of'
842
+ ' channels.'
899
843
  )
900
- set_default_population_with_lag_periods()
901
-
902
- if population_column_name not in self.df.columns:
903
- set_default_population_with_lag_periods()
904
-
905
- def _validate_required_mappings(self):
906
- """Validates required mappings in `coord_to_columns`."""
907
- for coord_name, channel_dict in self._required_mappings.items():
844
+ if (
845
+ self.reach_to_channel is not None
846
+ and self.frequency_to_channel is not None
847
+ and self.rf_spend_to_channel is not None
848
+ ):
908
849
  if (
909
- getattr(self.coord_to_columns, coord_name, None) is not None
910
- and getattr(self, channel_dict, None) is None
850
+ set(self.reach_to_channel.values())
851
+ != set(self.frequency_to_channel.values())
852
+ != set(self.rf_spend_to_channel.values())
911
853
  ):
912
854
  raise ValueError(
913
- f"When {coord_name} data is provided, '{channel_dict}' is required."
855
+ 'The reach, frequency, and rf_spend columns must have the same set'
856
+ ' of channels.'
914
857
  )
915
-
916
- def _validate_geo_and_time(self):
917
- """Validates that for every geo the list of `time`s is the same."""
918
- geo_column_name = self.coord_to_columns.geo
919
- time_column_name = self.coord_to_columns.time
920
-
921
- df_grouped = self.df.sort_values(time_column_name).groupby(
922
- geo_column_name, sort=False
923
- )[time_column_name]
924
- if any(df_grouped.count() != df_grouped.nunique()):
925
- raise ValueError("Duplicate entries found in the 'time' column.")
926
-
927
- times_by_geo = df_grouped.apply(list).reset_index(drop=True)
928
- if any(t != times_by_geo[0] for t in times_by_geo[1:]):
929
- raise ValueError(
930
- "Values in the 'time' column not consistent across different geos."
931
- )
932
-
933
- def _validate_nas(self):
934
- """Validates that the only NAs are in the lagged-media period."""
935
- # Check if there are no NAs in media.
936
- if self.coord_to_columns.media is not None:
937
- if self.df[self.coord_to_columns.media].isna().any(axis=None):
938
- raise ValueError('NA values found in the media columns.')
939
-
940
- # Check if there are no NAs in reach & frequency.
941
- if self.coord_to_columns.reach is not None:
942
- if self.df[self.coord_to_columns.reach].isna().any(axis=None):
943
- raise ValueError('NA values found in the reach columns.')
944
- if self.coord_to_columns.frequency is not None:
945
- if self.df[self.coord_to_columns.frequency].isna().any(axis=None):
946
- raise ValueError('NA values found in the frequency columns.')
947
-
948
- # Check if ther are no NAs in organic_media.
949
- if self.coord_to_columns.organic_media is not None:
950
- if self.df[self.coord_to_columns.organic_media].isna().any(axis=None):
951
- raise ValueError('NA values found in the organic_media columns.')
952
-
953
- # Check if there are no NAs in organic_reach & organic_frequency.
954
- if self.coord_to_columns.organic_reach is not None:
955
- if self.df[self.coord_to_columns.organic_reach].isna().any(axis=None):
956
- raise ValueError('NA values found in the organic_reach columns.')
957
- if self.coord_to_columns.organic_frequency is not None:
958
- if self.df[self.coord_to_columns.organic_frequency].isna().any(axis=None):
959
- raise ValueError('NA values found in the organic_frequency columns.')
960
-
961
- # Determine columns in which NAs are expected in the lagged-media period.
962
- not_lagged_columns = []
963
- coords = [
964
- constants.KPI,
965
- constants.POPULATION,
966
- ]
967
- if self.coord_to_columns.controls is not None:
968
- coords.append(constants.CONTROLS)
969
- if self.coord_to_columns.revenue_per_kpi is not None:
970
- coords.append(constants.REVENUE_PER_KPI)
971
- if self.coord_to_columns.media_spend is not None:
972
- coords.append(constants.MEDIA_SPEND)
973
- if self.coord_to_columns.rf_spend is not None:
974
- coords.append(constants.RF_SPEND)
975
- if self.coord_to_columns.non_media_treatments is not None:
976
- coords.append(constants.NON_MEDIA_TREATMENTS)
977
- for coord in coords:
978
- columns = getattr(self.coord_to_columns, coord)
979
- columns = [columns] if isinstance(columns, str) else columns
980
- not_lagged_columns.extend(columns)
981
-
982
- # Dates with at least one non-NA value in columns different from media,
983
- # reach, frequency, organic_media, organic_reach, and organic_frequency.
984
- time_column_name = self.coord_to_columns.time
985
- no_na_period = self.df[(~self.df[not_lagged_columns].isna()).any(axis=1)][
986
- time_column_name
987
- ].unique()
988
-
989
- # Dates with 100% NA values in all columns different from media, reach,
990
- # frequency, organic_media, organic_reach, and organic_frequency.
991
- na_period = [
992
- t for t in self.df[time_column_name].unique() if t not in no_na_period
993
- ]
994
-
995
- # Check if na_period is a continuous window starting from the earliest time
996
- # period.
997
- if not np.all(
998
- np.sort(na_period)
999
- == np.sort(self.df[time_column_name].unique())[: len(na_period)]
858
+ if (
859
+ self.organic_reach_to_channel is not None
860
+ and self.organic_frequency_to_channel is not None
1000
861
  ):
1001
- raise ValueError(
1002
- "The 'lagged media' period (period with 100% NA values in all"
1003
- f' non-media columns) {na_period} is not a continuous window starting'
1004
- ' from the earliest time period.'
1005
- )
1006
-
1007
- # Check if for the non-lagged period, there are no NAs in data different
1008
- # from media, reach, frequency, organic_media, organic_reach, and
1009
- # organic_frequency.
1010
- not_lagged_data = self.df.loc[
1011
- self.df[time_column_name].isin(no_na_period),
1012
- not_lagged_columns,
1013
- ]
1014
- if not_lagged_data.isna().any(axis=None):
1015
- incorrect_columns = []
1016
- for column in not_lagged_columns:
1017
- if not_lagged_data[column].isna().any(axis=None):
1018
- incorrect_columns.append(column)
1019
- raise ValueError(
1020
- f'NA values found in columns {incorrect_columns} within the modeling'
1021
- ' time window (time periods where the KPI is modeled).'
1022
- )
862
+ if set(self.organic_reach_to_channel.values()) != set(
863
+ self.organic_frequency_to_channel.values()
864
+ ):
865
+ raise ValueError(
866
+ 'The organic_reach and organic_frequency columns must have the'
867
+ ' same set of channels.'
868
+ )
1023
869
 
1024
870
  def load(self) -> input_data.InputData:
1025
871
  """Reads data from a dataframe and returns an InputData object."""
1026
872
 
1027
- # Change geo strings to numbers to keep the order of geos. The .to_xarray()
1028
- # method from Pandas sorts lexicographically by the key columns, so if the
1029
- # geos were unsorted strings, it would change their order.
1030
- geo_column_name = self.coord_to_columns.geo
1031
- time_column_name = self.coord_to_columns.time
1032
- geo_names = self.df[geo_column_name].unique()
1033
- self.df[geo_column_name] = self.df[geo_column_name].replace(
1034
- dict(zip(geo_names, np.arange(len(geo_names))))
1035
- )
1036
- df_indexed = self.df.set_index([geo_column_name, time_column_name])
1037
-
1038
- kpi_xr = (
1039
- df_indexed[self.coord_to_columns.kpi]
1040
- .dropna()
1041
- .rename(constants.KPI)
1042
- .rename_axis([constants.GEO, constants.TIME])
1043
- .to_frame()
1044
- .to_xarray()
1045
- )
1046
- population_xr = (
1047
- df_indexed[self.coord_to_columns.population]
1048
- .groupby(geo_column_name)
1049
- .mean()
1050
- .rename(constants.POPULATION)
1051
- .rename_axis([constants.GEO])
1052
- .to_frame()
1053
- .to_xarray()
873
+ builder = data_frame_input_data_builder.DataFrameInputDataBuilder(
874
+ kpi_type=self.kpi_type
875
+ ).with_kpi(
876
+ self.df,
877
+ self.coord_to_columns.kpi,
878
+ self.coord_to_columns.time,
879
+ self.coord_to_columns.geo,
1054
880
  )
1055
- dataset = xr.combine_by_coords([kpi_xr, population_xr])
1056
-
881
+ if self.coord_to_columns.population in self.df.columns:
882
+ builder.with_population(
883
+ self.df, self.coord_to_columns.population, self.coord_to_columns.geo
884
+ )
1057
885
  if self.coord_to_columns.controls is not None:
1058
- controls_xr = (
1059
- df_indexed[self.coord_to_columns.controls]
1060
- .stack()
1061
- .rename(constants.CONTROLS)
1062
- .rename_axis(
1063
- [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
1064
- )
1065
- .to_frame()
1066
- .to_xarray()
886
+ builder.with_controls(
887
+ self.df,
888
+ list(self.coord_to_columns.controls),
889
+ self.coord_to_columns.time,
890
+ self.coord_to_columns.geo,
1067
891
  )
1068
- dataset = xr.combine_by_coords([dataset, controls_xr])
1069
-
1070
892
  if self.coord_to_columns.non_media_treatments is not None:
1071
- non_media_xr = (
1072
- df_indexed[self.coord_to_columns.non_media_treatments]
1073
- .stack()
1074
- .rename(constants.NON_MEDIA_TREATMENTS)
1075
- .rename_axis(
1076
- [constants.GEO, constants.TIME, constants.NON_MEDIA_CHANNEL]
1077
- )
1078
- .to_frame()
1079
- .to_xarray()
893
+ builder.with_non_media_treatments(
894
+ self.df,
895
+ list(self.coord_to_columns.non_media_treatments),
896
+ self.coord_to_columns.time,
897
+ self.coord_to_columns.geo,
1080
898
  )
1081
- dataset = xr.combine_by_coords([dataset, non_media_xr])
1082
-
1083
899
  if self.coord_to_columns.revenue_per_kpi is not None:
1084
- revenue_per_kpi_xr = (
1085
- df_indexed[self.coord_to_columns.revenue_per_kpi]
1086
- .dropna()
1087
- .rename(constants.REVENUE_PER_KPI)
1088
- .rename_axis([constants.GEO, constants.TIME])
1089
- .to_frame()
1090
- .to_xarray()
1091
- )
1092
- dataset = xr.combine_by_coords([dataset, revenue_per_kpi_xr])
1093
- if self.coord_to_columns.media is not None:
1094
- media_xr = (
1095
- df_indexed[self.coord_to_columns.media]
1096
- .stack()
1097
- .rename(constants.MEDIA)
1098
- .rename_axis(
1099
- [constants.GEO, constants.MEDIA_TIME, constants.MEDIA_CHANNEL]
1100
- )
1101
- .to_frame()
1102
- .to_xarray()
1103
- )
1104
- media_xr.coords[constants.MEDIA_CHANNEL] = [
1105
- self.media_to_channel[x]
1106
- for x in media_xr.coords[constants.MEDIA_CHANNEL].values
1107
- ]
1108
-
1109
- media_spend_xr = (
1110
- df_indexed[self.coord_to_columns.media_spend]
1111
- .stack()
1112
- .rename(constants.MEDIA_SPEND)
1113
- .rename_axis([constants.GEO, constants.TIME, constants.MEDIA_CHANNEL])
1114
- .to_frame()
1115
- .to_xarray()
1116
- )
1117
- media_spend_xr.coords[constants.MEDIA_CHANNEL] = [
1118
- self.media_spend_to_channel[x]
1119
- for x in media_spend_xr.coords[constants.MEDIA_CHANNEL].values
1120
- ]
1121
- dataset = xr.combine_by_coords([dataset, media_xr, media_spend_xr])
1122
-
1123
- if self.coord_to_columns.reach is not None:
1124
- reach_xr = (
1125
- df_indexed[self.coord_to_columns.reach]
1126
- .stack()
1127
- .rename(constants.REACH)
1128
- .rename_axis(
1129
- [constants.GEO, constants.MEDIA_TIME, constants.RF_CHANNEL]
1130
- )
1131
- .to_frame()
1132
- .to_xarray()
1133
- )
1134
- reach_xr.coords[constants.RF_CHANNEL] = [
1135
- self.reach_to_channel[x]
1136
- for x in reach_xr.coords[constants.RF_CHANNEL].values
1137
- ]
1138
-
1139
- frequency_xr = (
1140
- df_indexed[self.coord_to_columns.frequency]
1141
- .stack()
1142
- .rename(constants.FREQUENCY)
1143
- .rename_axis(
1144
- [constants.GEO, constants.MEDIA_TIME, constants.RF_CHANNEL]
1145
- )
1146
- .to_frame()
1147
- .to_xarray()
900
+ builder.with_revenue_per_kpi(
901
+ self.df,
902
+ self.coord_to_columns.revenue_per_kpi,
903
+ self.coord_to_columns.time,
904
+ self.coord_to_columns.geo,
1148
905
  )
1149
- frequency_xr.coords[constants.RF_CHANNEL] = [
1150
- self.frequency_to_channel[x]
1151
- for x in frequency_xr.coords[constants.RF_CHANNEL].values
1152
- ]
1153
-
1154
- rf_spend_xr = (
1155
- df_indexed[self.coord_to_columns.rf_spend]
1156
- .stack()
1157
- .rename(constants.RF_SPEND)
1158
- .rename_axis([constants.GEO, constants.TIME, constants.RF_CHANNEL])
1159
- .to_frame()
1160
- .to_xarray()
906
+ if (
907
+ self.media_to_channel is not None
908
+ and self.media_spend_to_channel is not None
909
+ ):
910
+ sorted_channels = sorted(self.media_to_channel.values())
911
+ inv_media_map = {v: k for k, v in self.media_to_channel.items()}
912
+ inv_spend_map = {v: k for k, v in self.media_spend_to_channel.items()}
913
+
914
+ builder.with_media(
915
+ self.df,
916
+ [inv_media_map[ch] for ch in sorted_channels],
917
+ [inv_spend_map[ch] for ch in sorted_channels],
918
+ sorted_channels,
919
+ self.coord_to_columns.time,
920
+ self.coord_to_columns.geo,
1161
921
  )
1162
- rf_spend_xr.coords[constants.RF_CHANNEL] = [
1163
- self.rf_spend_to_channel[x]
1164
- for x in rf_spend_xr.coords[constants.RF_CHANNEL].values
1165
- ]
1166
- dataset = xr.combine_by_coords(
1167
- [dataset, reach_xr, frequency_xr, rf_spend_xr]
922
+ if (
923
+ self.reach_to_channel is not None
924
+ and self.frequency_to_channel is not None
925
+ and self.rf_spend_to_channel is not None
926
+ ):
927
+ sorted_channels = sorted(self.reach_to_channel.values())
928
+ inv_reach_map = {v: k for k, v in self.reach_to_channel.items()}
929
+ inv_freq_map = {v: k for k, v in self.frequency_to_channel.items()}
930
+ inv_rf_spend_map = {v: k for k, v in self.rf_spend_to_channel.items()}
931
+ builder.with_reach(
932
+ self.df,
933
+ [inv_reach_map[ch] for ch in sorted_channels],
934
+ [inv_freq_map[ch] for ch in sorted_channels],
935
+ [inv_rf_spend_map[ch] for ch in sorted_channels],
936
+ sorted_channels,
937
+ self.coord_to_columns.time,
938
+ self.coord_to_columns.geo,
1168
939
  )
1169
-
1170
940
  if self.coord_to_columns.organic_media is not None:
1171
- organic_media_xr = (
1172
- df_indexed[self.coord_to_columns.organic_media]
1173
- .stack()
1174
- .rename(constants.ORGANIC_MEDIA)
1175
- .rename_axis([
1176
- constants.GEO,
1177
- constants.MEDIA_TIME,
1178
- constants.ORGANIC_MEDIA_CHANNEL,
1179
- ])
1180
- .to_frame()
1181
- .to_xarray()
1182
- )
1183
- dataset = xr.combine_by_coords([dataset, organic_media_xr])
1184
-
1185
- if self.coord_to_columns.organic_reach is not None:
1186
- organic_reach_xr = (
1187
- df_indexed[self.coord_to_columns.organic_reach]
1188
- .stack()
1189
- .rename(constants.ORGANIC_REACH)
1190
- .rename_axis([
1191
- constants.GEO,
1192
- constants.MEDIA_TIME,
1193
- constants.ORGANIC_RF_CHANNEL,
1194
- ])
1195
- .to_frame()
1196
- .to_xarray()
941
+ builder.with_organic_media(
942
+ self.df,
943
+ list(self.coord_to_columns.organic_media),
944
+ list(self.coord_to_columns.organic_media),
945
+ self.coord_to_columns.time,
946
+ self.coord_to_columns.geo,
1197
947
  )
1198
- organic_reach_xr.coords[constants.ORGANIC_RF_CHANNEL] = [
1199
- self.organic_reach_to_channel[x]
1200
- for x in organic_reach_xr.coords[constants.ORGANIC_RF_CHANNEL].values
1201
- ]
1202
- organic_frequency_xr = (
1203
- df_indexed[self.coord_to_columns.organic_frequency]
1204
- .stack()
1205
- .rename(constants.ORGANIC_FREQUENCY)
1206
- .rename_axis([
1207
- constants.GEO,
1208
- constants.MEDIA_TIME,
1209
- constants.ORGANIC_RF_CHANNEL,
1210
- ])
1211
- .to_frame()
1212
- .to_xarray()
1213
- )
1214
- organic_frequency_xr.coords[constants.ORGANIC_RF_CHANNEL] = [
1215
- self.organic_frequency_to_channel[x]
1216
- for x in organic_frequency_xr.coords[
1217
- constants.ORGANIC_RF_CHANNEL
1218
- ].values
1219
- ]
1220
- dataset = xr.combine_by_coords(
1221
- [dataset, organic_reach_xr, organic_frequency_xr]
948
+ if (
949
+ self.organic_reach_to_channel is not None
950
+ and self.organic_frequency_to_channel is not None
951
+ ):
952
+ sorted_channels = sorted(self.organic_reach_to_channel.values())
953
+ inv_reach_map = {v: k for k, v in self.organic_reach_to_channel.items()}
954
+ inv_freq_map = {
955
+ v: k for k, v in self.organic_frequency_to_channel.items()
956
+ }
957
+ builder.with_organic_reach(
958
+ self.df,
959
+ [inv_reach_map[ch] for ch in sorted_channels],
960
+ [inv_freq_map[ch] for ch in sorted_channels],
961
+ sorted_channels,
962
+ self.coord_to_columns.time,
963
+ self.coord_to_columns.geo,
1222
964
  )
1223
-
1224
- # Change back to geo names
1225
- self.df[geo_column_name] = self.df[geo_column_name].replace(
1226
- dict(zip(np.arange(len(geo_names)), geo_names))
1227
- )
1228
- dataset.coords[constants.GEO] = geo_names
1229
- return XrDatasetDataLoader(dataset, kpi_type=self.kpi_type).load()
965
+ return builder.build()
1230
966
 
1231
967
 
1232
968
  class CsvDataLoader(InputDataLoader):