google-meridian 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/METADATA +6 -2
- google_meridian-1.1.2.dist-info/RECORD +46 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/WHEEL +1 -1
- meridian/__init__.py +2 -2
- meridian/analysis/__init__.py +1 -1
- meridian/analysis/analyzer.py +29 -22
- meridian/analysis/formatter.py +1 -1
- meridian/analysis/optimizer.py +70 -44
- meridian/analysis/summarizer.py +1 -1
- meridian/analysis/summary_text.py +1 -1
- meridian/analysis/test_utils.py +1 -1
- meridian/analysis/visualizer.py +17 -8
- meridian/constants.py +3 -3
- meridian/data/__init__.py +4 -1
- meridian/data/arg_builder.py +1 -1
- meridian/data/data_frame_input_data_builder.py +614 -0
- meridian/data/input_data.py +12 -8
- meridian/data/input_data_builder.py +817 -0
- meridian/data/load.py +121 -428
- meridian/data/nd_array_input_data_builder.py +509 -0
- meridian/data/test_utils.py +60 -43
- meridian/data/time_coordinates.py +1 -1
- meridian/mlflow/__init__.py +17 -0
- meridian/mlflow/autolog.py +54 -0
- meridian/model/__init__.py +1 -1
- meridian/model/adstock_hill.py +1 -1
- meridian/model/knots.py +1 -1
- meridian/model/media.py +1 -1
- meridian/model/model.py +65 -37
- meridian/model/model_test_data.py +75 -1
- meridian/model/posterior_sampler.py +19 -15
- meridian/model/prior_distribution.py +1 -1
- meridian/model/prior_sampler.py +32 -26
- meridian/model/spec.py +18 -8
- meridian/model/transformers.py +1 -1
- google_meridian-1.1.0.dist-info/RECORD +0 -41
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.2.dist-info}/top_level.txt +0 -0
meridian/data/load.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The Meridian Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ import warnings
 
 import immutabledict
 from meridian import constants
+from meridian.data import data_frame_input_data_builder
 from meridian.data import input_data
 import numpy as np
 import pandas as pd
@@ -79,7 +80,7 @@ class XrDatasetDataLoader(InputDataLoader):
     """Constructor.
 
     The coordinates of the input dataset should be: `time`, `media_time`,
-    `control_variable
+    `control_variable` (optional), `geo` (optional for a national model),
     `non_media_channel` (optional), `organic_media_channel` (optional),
     `organic_rf_channel` (optional), and
     either `media_channel`, `rf_channel`, or both.
@@ -93,7 +94,7 @@ class XrDatasetDataLoader(InputDataLoader):
 
     * `kpi`: `(geo, time)`
     * `revenue_per_kpi`: `(geo, time)`
-    * `controls`: `(geo, time, control_variable)`
+    * `controls`: `(geo, time, control_variable)` - optional
     * `population`: `(geo)`
     * `media`: `(geo, media_time, media_channel)` - optional
     * `media_spend`: `(geo, time, media_channel)`, `(1, time, media_channel)`,
@@ -113,7 +114,7 @@ class XrDatasetDataLoader(InputDataLoader):
 
     * `kpi`: `([1,] time)`
     * `revenue_per_kpi`: `([1,] time)`
-    * `controls`: `([1,] time, control_variable)`
+    * `controls`: `([1,] time, control_variable)` - optional
     * `population`: `([1],)` - this array is optional for national data
     * `media`: `([1,] media_time, media_channel)` - optional
     * `media_spend`: `([1,] time, media_channel)` or
@@ -198,7 +199,7 @@ class XrDatasetDataLoader(InputDataLoader):
     self.dataset = dataset.rename(name_mapping)
 
     # Add a `geo` dimension if it is not already present.
-    if (constants.GEO) not in self.dataset.
+    if (constants.GEO) not in self.dataset.sizes.keys():
       self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
 
     if len(self.dataset.coords[constants.GEO]) == 1:
@@ -228,7 +229,7 @@ class XrDatasetDataLoader(InputDataLoader):
         compat='override',
     )
 
-    if constants.MEDIA_TIME not in self.dataset.
+    if constants.MEDIA_TIME not in self.dataset.sizes.keys():
      self._add_media_time()
     self._normalize_time_coordinates(constants.TIME)
     self._normalize_time_coordinates(constants.MEDIA_TIME)
@@ -349,14 +350,17 @@ class XrDatasetDataLoader(InputDataLoader):
     # Arrays in which NAs are expected in the lagged-media period.
     na_arrays = [
         constants.KPI,
-        constants.CONTROLS,
     ]
 
-    na_mask = self.dataset[constants.KPI].isnull().any(
-
-
-
-
+    na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
+
+    if constants.CONTROLS in self.dataset.data_vars.keys():
+      na_arrays.append(constants.CONTROLS)
+      na_mask |= (
+          self.dataset[constants.CONTROLS]
+          .isnull()
+          .any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
+      )
 
     if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
       na_arrays.append(constants.NON_MEDIA_TREATMENTS)
@@ -427,11 +431,12 @@ class XrDatasetDataLoader(InputDataLoader):
          .dropna(dim=constants.TIME)
          .rename({constants.TIME: new_time})
      )
-
-
-
-
-
+      if constants.CONTROLS in new_dataset.data_vars.keys():
+        new_dataset[constants.CONTROLS] = (
+            new_dataset[constants.CONTROLS]
+            .dropna(dim=constants.TIME)
+            .rename({constants.TIME: new_time})
+        )
      if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
        new_dataset[constants.NON_MEDIA_TREATMENTS] = (
            new_dataset[constants.NON_MEDIA_TREATMENTS]
@@ -466,6 +471,11 @@ class XrDatasetDataLoader(InputDataLoader):
 
   def load(self) -> input_data.InputData:
     """Returns an `InputData` object containing the data from the dataset."""
+    controls = (
+        self.dataset.controls
+        if constants.CONTROLS in self.dataset.data_vars.keys()
+        else None
+    )
     revenue_per_kpi = (
         self.dataset.revenue_per_kpi
         if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
@@ -519,9 +529,9 @@ class XrDatasetDataLoader(InputDataLoader):
     return input_data.InputData(
         kpi=self.dataset.kpi,
         kpi_type=self.kpi_type,
-        revenue_per_kpi=revenue_per_kpi,
-        controls=self.dataset.controls,
         population=self.dataset.population,
+        controls=controls,
+        revenue_per_kpi=revenue_per_kpi,
         media=media,
         media_spend=media_spend,
         reach=reach,
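With 1.1.2 the `controls` argument of `InputData` is only filled from the dataset when a `control_variable` coordinate is present, so a dataset with no controls now loads. A minimal sketch, not part of the diff; the coordinate values, variable shapes, and `kpi_type` below are illustrative:

import numpy as np
import xarray as xr
from meridian.data import load

times = ["2024-01-01", "2024-01-08", "2024-01-15", "2024-01-22"]
dataset = xr.Dataset(
    data_vars={
        "kpi": (("geo", "time"), np.random.rand(2, 4)),
        "revenue_per_kpi": (("geo", "time"), np.ones((2, 4))),
        "population": (("geo",), np.array([1000.0, 2000.0])),
        "media": (("geo", "media_time", "media_channel"), np.random.rand(2, 4, 1)),
        "media_spend": (("geo", "time", "media_channel"), np.random.rand(2, 4, 1)),
    },
    coords={
        "geo": ["geo_a", "geo_b"],
        "time": times,
        "media_time": times,
        "media_channel": ["channel_1"],
    },
)
# No `controls` variable in the dataset: the loader now passes controls=None
# through to InputData instead of failing on a missing variable.
data = load.XrDatasetDataLoader(dataset, kpi_type="non_revenue").load()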
@@ -539,14 +549,14 @@ class CoordToColumns:
   """A mapping between the desired and actual column names in the input data.
 
   Attributes:
-    controls: List of column names containing `controls` values in the input
-      data.
     time: Name of column containing `time` values in the input data.
-    kpi: Name of column containing `kpi` values in the input data.
-    revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
-      input data.
     geo: Name of column containing `geo` values in the input data. This field
       is optional for a national model.
+    kpi: Name of column containing `kpi` values in the input data.
+    controls: List of column names containing `controls` values in the input
+      data. Optional.
+    revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
+      input data. Optional. Will be overridden if model KPI type is "revenue".
     population: Name of column containing `population` values in the input data.
       This field is optional for a national model.
     media: List of column names containing `media` values in the input data.
@@ -567,11 +577,11 @@ class CoordToColumns:
       values in the input data.
   """
 
-  controls: Sequence[str]
   time: str = constants.TIME
+  geo: str = constants.GEO
   kpi: str = constants.KPI
+  controls: Sequence[str] | None = None
   revenue_per_kpi: str | None = None
-  geo: str = constants.GEO
   population: str = constants.POPULATION
   # Media data
   media: Sequence[str] | None = None
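Because `controls` now defaults to None (and `geo` gains a default alongside the other fields), a `CoordToColumns` can be constructed without any control columns. A short sketch with hypothetical DataFrame column names:

from meridian.data import load

# In 1.1.0 `controls` was a required positional field; in 1.1.2 it can simply
# be omitted.
coord_to_columns = load.CoordToColumns(
    time="date_week",
    geo="dma",
    kpi="conversions",
    population="population",
    media=["impressions_tv", "impressions_search"],
    media_spend=["spend_tv", "spend_search"],
)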
@@ -607,7 +617,7 @@ class DataFrameDataLoader(InputDataLoader):
     to the DataFrame column names if they are different. The fields are:
 
     * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
-    * `controls` (multiple columns)
+    * `controls` (multiple columns, optional)
     * (1) `media`, `media_spend` (multiple columns)
     * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
     * `non_media_treatments` (multiple columns, optional)
@@ -792,110 +802,20 @@ class DataFrameDataLoader(InputDataLoader):
   organic_reach_to_channel: Mapping[str, str] | None = None
   organic_frequency_to_channel: Mapping[str, str] | None = None
 
-  # If [key] in the following dict exists as an attribute in `coord_to_columns`,
-  # then the corresponding attribute must exist in this loader instance.
-  _required_mappings = immutabledict.immutabledict({
-      'media': 'media_to_channel',
-      'media_spend': 'media_spend_to_channel',
-      'reach': 'reach_to_channel',
-      'frequency': 'frequency_to_channel',
-      'rf_spend': 'rf_spend_to_channel',
-      'organic_reach': 'organic_reach_to_channel',
-      'organic_frequency': 'organic_frequency_to_channel',
-  })
-
   def __post_init__(self):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    In (b) case, `datetime` coordinate values will be normalized as formatted
-    strings.
-    """
-    time_column_name = self.coord_to_columns.time
-
-    if self.df.dtypes[time_column_name] == np.dtype('datetime64[ns]'):
-      self.df[time_column_name] = self.df[time_column_name].map(
-          lambda time: time.strftime(constants.DATE_FORMAT)
-      )
-    else:
-      # Assume that the `time` column values are strings formatted as dates.
-      for _, time in self.df[time_column_name].items():
-        try:
-          _ = dt.datetime.strptime(time, constants.DATE_FORMAT)
-        except ValueError as exc:
-          raise ValueError(
-              f"Invalid time label: '{time}'. Expected format:"
-              f" '{constants.DATE_FORMAT}'"
-          ) from exc
-
-  def _validate_column_names(self):
-    """Validates the column names in `df` and `coord_to_columns`."""
-
-    desired_columns = []
-    for field in dataclasses.fields(self.coord_to_columns):
-      value = getattr(self.coord_to_columns, field.name)
-      if isinstance(value, str):
-        desired_columns.append(value)
-      elif isinstance(value, Sequence):
-        for column in value:
-          desired_columns.append(column)
-    desired_columns = sorted(desired_columns)
-
-    actual_columns = sorted(self.df.columns.to_list())
-    if any(d not in actual_columns for d in desired_columns):
-      raise ValueError(
-          f'Values of the `coord_to_columns` object {desired_columns}'
-          f' should map to the DataFrame column names {actual_columns}.'
-      )
-
-  def _expand_if_national(self):
-    """Adds geo/population columns in a national model if necessary."""
-
-    geo_column_name = self.coord_to_columns.geo
-    population_column_name = self.coord_to_columns.population
-
-    def set_default_population_with_lag_periods():
-      """Sets the `population` column.
-
-      The `population` column is set to the default value for non-lag periods,
-      and None for lag-periods. The lag periods are inferred from the Nan values
-      in the other non-media columns.
-      """
-      non_lagged_idx = self.df.isna().idxmin().max()
-      self.df[population_column_name] = (
-          constants.NATIONAL_MODEL_DEFAULT_POPULATION_VALUE
-      )
-      self.df.loc[: non_lagged_idx - 1, population_column_name] = None
-
-    if geo_column_name not in self.df.columns:
-      self.df[geo_column_name] = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
-
-    if self.df[geo_column_name].nunique() == 1:
-      self.df[geo_column_name] = constants.NATIONAL_MODEL_DEFAULT_GEO_NAME
-      if population_column_name in self.df.columns:
-        warnings.warn(
-            'The `population` argument is ignored in a nationally aggregated'
-            ' model. It will be reset to [1, 1, ..., 1]'
-        )
-        set_default_population_with_lag_periods()
-
-    if population_column_name not in self.df.columns:
-      set_default_population_with_lag_periods()
-
-  def _validate_required_mappings(self):
-    """Validates required mappings in `coord_to_columns`."""
-    for coord_name, channel_dict in self._required_mappings.items():
+    # If [key] in the following dict exists as an attribute in
+    # `coord_to_columns`, then the corresponding attribute must exist in this
+    # loader instance.
+    required_mappings = immutabledict.immutabledict({
+        'media': 'media_to_channel',
+        'media_spend': 'media_spend_to_channel',
+        'reach': 'reach_to_channel',
+        'frequency': 'frequency_to_channel',
+        'rf_spend': 'rf_spend_to_channel',
+        'organic_reach': 'organic_reach_to_channel',
+        'organic_frequency': 'organic_frequency_to_channel',
+    })
+    for coord_name, channel_dict in required_mappings.items():
       if (
           getattr(self.coord_to_columns, coord_name, None) is not None
           and getattr(self, channel_dict, None) is None
@@ -904,316 +824,89 @@ class DataFrameDataLoader(InputDataLoader):
            f"When {coord_name} data is provided, '{channel_dict}' is required."
        )
 
-  def _validate_geo_and_time(self):
-    """Validates that for every geo the list of `time`s is the same."""
-    geo_column_name = self.coord_to_columns.geo
-    time_column_name = self.coord_to_columns.time
-
-    df_grouped = self.df.sort_values(time_column_name).groupby(
-        geo_column_name, sort=False
-    )[time_column_name]
-    if any(df_grouped.count() != df_grouped.nunique()):
-      raise ValueError("Duplicate entries found in the 'time' column.")
-
-    times_by_geo = df_grouped.apply(list).reset_index(drop=True)
-    if any(t != times_by_geo[0] for t in times_by_geo[1:]):
-      raise ValueError(
-          "Values in the 'time' column not consistent across different geos."
-      )
-
-  def _validate_nas(self):
-    """Validates that the only NAs are in the lagged-media period."""
-    # Check if there are no NAs in media.
-    if self.coord_to_columns.media is not None:
-      if self.df[self.coord_to_columns.media].isna().any(axis=None):
-        raise ValueError('NA values found in the media columns.')
-
-    # Check if there are no NAs in reach & frequency.
-    if self.coord_to_columns.reach is not None:
-      if self.df[self.coord_to_columns.reach].isna().any(axis=None):
-        raise ValueError('NA values found in the reach columns.')
-    if self.coord_to_columns.frequency is not None:
-      if self.df[self.coord_to_columns.frequency].isna().any(axis=None):
-        raise ValueError('NA values found in the frequency columns.')
-
-    # Check if ther are no NAs in organic_media.
-    if self.coord_to_columns.organic_media is not None:
-      if self.df[self.coord_to_columns.organic_media].isna().any(axis=None):
-        raise ValueError('NA values found in the organic_media columns.')
-
-    # Check if there are no NAs in organic_reach & organic_frequency.
-    if self.coord_to_columns.organic_reach is not None:
-      if self.df[self.coord_to_columns.organic_reach].isna().any(axis=None):
-        raise ValueError('NA values found in the organic_reach columns.')
-    if self.coord_to_columns.organic_frequency is not None:
-      if self.df[self.coord_to_columns.organic_frequency].isna().any(axis=None):
-        raise ValueError('NA values found in the organic_frequency columns.')
-
-    # Determine columns in which NAs are expected in the lagged-media period.
-    not_lagged_columns = []
-    coords = [
-        constants.KPI,
-        constants.CONTROLS,
-        constants.POPULATION,
-    ]
-    if self.coord_to_columns.revenue_per_kpi is not None:
-      coords.append(constants.REVENUE_PER_KPI)
-    if self.coord_to_columns.media_spend is not None:
-      coords.append(constants.MEDIA_SPEND)
-    if self.coord_to_columns.rf_spend is not None:
-      coords.append(constants.RF_SPEND)
-    if self.coord_to_columns.non_media_treatments is not None:
-      coords.append(constants.NON_MEDIA_TREATMENTS)
-    for coord in coords:
-      columns = getattr(self.coord_to_columns, coord)
-      columns = [columns] if isinstance(columns, str) else columns
-      not_lagged_columns.extend(columns)
-
-    # Dates with at least one non-NA value in columns different from media,
-    # reach, frequency, organic_media, organic_reach, and organic_frequency.
-    time_column_name = self.coord_to_columns.time
-    no_na_period = self.df[(~self.df[not_lagged_columns].isna()).any(axis=1)][
-        time_column_name
-    ].unique()
-
-    # Dates with 100% NA values in all columns different from media, reach,
-    # frequency, organic_media, organic_reach, and organic_frequency.
-    na_period = [
-        t for t in self.df[time_column_name].unique() if t not in no_na_period
-    ]
-
-    # Check if na_period is a continuous window starting from the earliest time
-    # period.
-    if not np.all(
-        np.sort(na_period)
-        == np.sort(self.df[time_column_name].unique())[: len(na_period)]
-    ):
-      raise ValueError(
-          "The 'lagged media' period (period with 100% NA values in all"
-          f' non-media columns) {na_period} is not a continuous window starting'
-          ' from the earliest time period.'
-      )
-
-    # Check if for the non-lagged period, there are no NAs in data different
-    # from media, reach, frequency, organic_media, organic_reach, and
-    # organic_frequency.
-    not_lagged_data = self.df.loc[
-        self.df[time_column_name].isin(no_na_period),
-        not_lagged_columns,
-    ]
-    if not_lagged_data.isna().any(axis=None):
-      incorrect_columns = []
-      for column in not_lagged_columns:
-        if not_lagged_data[column].isna().any(axis=None):
-          incorrect_columns.append(column)
-      raise ValueError(
-          f'NA values found in columns {incorrect_columns} within the modeling'
-          ' time window (time periods where the KPI is modeled).'
-      )
-
   def load(self) -> input_data.InputData:
     """Reads data from a dataframe and returns an InputData object."""
 
-
-
-
-
-
-
-
-        dict(zip(geo_names, np.arange(len(geo_names))))
+    builder = data_frame_input_data_builder.DataFrameInputDataBuilder(
+        kpi_type=self.kpi_type
+    ).with_kpi(
+        self.df,
+        self.coord_to_columns.kpi,
+        self.coord_to_columns.time,
+        self.coord_to_columns.geo,
     )
-
-
-
-        df_indexed[self.coord_to_columns.kpi]
-        .dropna()
-        .rename(constants.KPI)
-        .rename_axis([constants.GEO, constants.TIME])
-        .to_frame()
-        .to_xarray()
-    )
-    population_xr = (
-        df_indexed[self.coord_to_columns.population]
-        .groupby(geo_column_name)
-        .mean()
-        .rename(constants.POPULATION)
-        .rename_axis([constants.GEO])
-        .to_frame()
-        .to_xarray()
-    )
-    controls_xr = (
-        df_indexed[self.coord_to_columns.controls]
-        .stack()
-        .rename(constants.CONTROLS)
-        .rename_axis(
-            [constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
-        )
-        .to_frame()
-        .to_xarray()
-    )
-    dataset = xr.combine_by_coords([kpi_xr, population_xr, controls_xr])
-
-    if self.coord_to_columns.non_media_treatments is not None:
-      non_media_xr = (
-          df_indexed[self.coord_to_columns.non_media_treatments]
-          .stack()
-          .rename(constants.NON_MEDIA_TREATMENTS)
-          .rename_axis(
-              [constants.GEO, constants.TIME, constants.NON_MEDIA_CHANNEL]
-          )
-          .to_frame()
-          .to_xarray()
+    if self.coord_to_columns.population in self.df.columns:
+      builder.with_population(
+          self.df, self.coord_to_columns.population, self.coord_to_columns.geo
       )
-
-
-
-
-
-          .
-          .rename(constants.REVENUE_PER_KPI)
-          .rename_axis([constants.GEO, constants.TIME])
-          .to_frame()
-          .to_xarray()
-      )
-      dataset = xr.combine_by_coords([dataset, revenue_per_kpi_xr])
-    if self.coord_to_columns.media is not None:
-      media_xr = (
-          df_indexed[self.coord_to_columns.media]
-          .stack()
-          .rename(constants.MEDIA)
-          .rename_axis(
-              [constants.GEO, constants.MEDIA_TIME, constants.MEDIA_CHANNEL]
-          )
-          .to_frame()
-          .to_xarray()
-      )
-      media_xr.coords[constants.MEDIA_CHANNEL] = [
-          self.media_to_channel[x]
-          for x in media_xr.coords[constants.MEDIA_CHANNEL].values
-      ]
-
-      media_spend_xr = (
-          df_indexed[self.coord_to_columns.media_spend]
-          .stack()
-          .rename(constants.MEDIA_SPEND)
-          .rename_axis([constants.GEO, constants.TIME, constants.MEDIA_CHANNEL])
-          .to_frame()
-          .to_xarray()
+    if self.coord_to_columns.controls is not None:
+      builder.with_controls(
+          self.df,
+          list(self.coord_to_columns.controls),
+          self.coord_to_columns.time,
+          self.coord_to_columns.geo,
       )
-
-
-
-
-
-
-    if self.coord_to_columns.reach is not None:
-      reach_xr = (
-          df_indexed[self.coord_to_columns.reach]
-          .stack()
-          .rename(constants.REACH)
-          .rename_axis(
-              [constants.GEO, constants.MEDIA_TIME, constants.RF_CHANNEL]
-          )
-          .to_frame()
-          .to_xarray()
-      )
-      reach_xr.coords[constants.RF_CHANNEL] = [
-          self.reach_to_channel[x]
-          for x in reach_xr.coords[constants.RF_CHANNEL].values
-      ]
-
-      frequency_xr = (
-          df_indexed[self.coord_to_columns.frequency]
-          .stack()
-          .rename(constants.FREQUENCY)
-          .rename_axis(
-              [constants.GEO, constants.MEDIA_TIME, constants.RF_CHANNEL]
-          )
-          .to_frame()
-          .to_xarray()
+    if self.coord_to_columns.non_media_treatments is not None:
+      builder.with_non_media_treatments(
+          self.df,
+          list(self.coord_to_columns.non_media_treatments),
+          self.coord_to_columns.time,
+          self.coord_to_columns.geo,
       )
-
-
-
-
-
-
-          df_indexed[self.coord_to_columns.rf_spend]
-          .stack()
-          .rename(constants.RF_SPEND)
-          .rename_axis([constants.GEO, constants.TIME, constants.RF_CHANNEL])
-          .to_frame()
-          .to_xarray()
+    if self.coord_to_columns.revenue_per_kpi is not None:
+      builder.with_revenue_per_kpi(
+          self.df,
+          self.coord_to_columns.revenue_per_kpi,
+          self.coord_to_columns.time,
+          self.coord_to_columns.geo,
      )
-
-
-
-
-
-
+    if (
+        self.coord_to_columns.media is not None
+        and self.media_to_channel is not None
+    ):
+      builder.with_media(
+          self.df,
+          list(self.coord_to_columns.media),
+          list(self.coord_to_columns.media_spend),
+          list(self.media_to_channel.values()),
+          self.coord_to_columns.time,
+          self.coord_to_columns.geo,
      )
 
-    if
-
-
-
-
-      .
-
-
-
-
-      .
-      .
-      )
-      dataset = xr.combine_by_coords([dataset, organic_media_xr])
-
-    if self.coord_to_columns.organic_reach is not None:
-      organic_reach_xr = (
-          df_indexed[self.coord_to_columns.organic_reach]
-          .stack()
-          .rename(constants.ORGANIC_REACH)
-          .rename_axis([
-              constants.GEO,
-              constants.MEDIA_TIME,
-              constants.ORGANIC_RF_CHANNEL,
-          ])
-          .to_frame()
-          .to_xarray()
+    if (
+        self.coord_to_columns.reach is not None
+        and self.reach_to_channel is not None
+    ):
+      builder.with_reach(
+          self.df,
+          list(self.coord_to_columns.reach),
+          list(self.coord_to_columns.frequency),
+          list(self.coord_to_columns.rf_spend),
+          list(self.reach_to_channel.values()),
+          self.coord_to_columns.time,
+          self.coord_to_columns.geo,
      )
-
-
-
-
-
-
-      .
-      .rename(constants.ORGANIC_FREQUENCY)
-      .rename_axis([
-          constants.GEO,
-          constants.MEDIA_TIME,
-          constants.ORGANIC_RF_CHANNEL,
-      ])
-      .to_frame()
-      .to_xarray()
+    if self.coord_to_columns.organic_media is not None:
+      builder.with_organic_media(
+          self.df,
+          list(self.coord_to_columns.organic_media),
+          list(self.coord_to_columns.organic_media),
+          self.coord_to_columns.time,
+          self.coord_to_columns.geo,
      )
-
-
-
-
-
-
-
-
+    if (
+        self.coord_to_columns.organic_reach is not None
+        and self.organic_reach_to_channel is not None
+    ):
+      builder.with_organic_reach(
+          self.df,
+          list(self.coord_to_columns.organic_reach),
+          list(self.coord_to_columns.organic_frequency),
+          list(self.organic_reach_to_channel.values()),
+          self.coord_to_columns.time,
+          self.coord_to_columns.geo,
      )
-
-    # Change back to geo names
-    self.df[geo_column_name] = self.df[geo_column_name].replace(
-        dict(zip(np.arange(len(geo_names)), geo_names))
-    )
-    dataset.coords[constants.GEO] = geo_names
-    return XrDatasetDataLoader(dataset, kpi_type=self.kpi_type).load()
+    return builder.build()
 
 
 class CsvDataLoader(InputDataLoader):
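The rewritten load() above delegates to the new DataFrameInputDataBuilder from meridian/data/data_frame_input_data_builder.py. A usage sketch mirroring those calls, with hypothetical DataFrame column and channel names and a hypothetical input file (argument order follows the added lines in the hunk above):

import pandas as pd
from meridian.data import data_frame_input_data_builder

df = pd.read_csv("marketing_data.csv")  # hypothetical input file

# with_kpi() is chained off the constructor, as in the refactored load().
builder = data_frame_input_data_builder.DataFrameInputDataBuilder(
    kpi_type="non_revenue"
).with_kpi(df, "conversions", "date_week", "dma")
builder.with_population(df, "population", "dma")
builder.with_controls(df, ["gqv", "temperature"], "date_week", "dma")
builder.with_media(
    df,
    ["impressions_tv", "impressions_search"],  # media metric columns
    ["spend_tv", "spend_search"],              # media spend columns
    ["TV", "Search"],                          # channel names
    "date_week",
    "dma",
)
data = builder.build()  # returns an input_data.InputData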
@@ -1224,7 +917,7 @@ class CsvDataLoader(InputDataLoader):
     CSV column names, if they are different. The fields are:
 
     * `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
-    * `controls` (multiple columns)
+    * `controls` (multiple columns, optional)
     * (1) `media`, `media_spend` (multiple columns)
     * (2) `reach`, `frequency`, `rf_spend` (multiple columns)
     * `non_media_treatments` (multiple columns, optional)