google-meridian 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/METADATA +2 -2
- google_meridian-1.1.1.dist-info/RECORD +41 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/WHEEL +1 -1
- meridian/__init__.py +2 -2
- meridian/analysis/__init__.py +1 -1
- meridian/analysis/analyzer.py +18 -17
- meridian/analysis/formatter.py +1 -1
- meridian/analysis/optimizer.py +1 -1
- meridian/analysis/summarizer.py +1 -1
- meridian/analysis/summary_text.py +1 -1
- meridian/analysis/test_utils.py +1 -1
- meridian/analysis/visualizer.py +2 -3
- meridian/constants.py +3 -3
- meridian/data/__init__.py +1 -1
- meridian/data/arg_builder.py +1 -1
- meridian/data/input_data.py +12 -8
- meridian/data/load.py +53 -40
- meridian/data/test_utils.py +60 -43
- meridian/data/time_coordinates.py +1 -1
- meridian/model/__init__.py +1 -1
- meridian/model/adstock_hill.py +1 -1
- meridian/model/knots.py +1 -1
- meridian/model/media.py +1 -1
- meridian/model/model.py +47 -27
- meridian/model/model_test_data.py +75 -1
- meridian/model/posterior_sampler.py +19 -15
- meridian/model/prior_distribution.py +1 -1
- meridian/model/prior_sampler.py +32 -26
- meridian/model/spec.py +1 -1
- meridian/model/transformers.py +1 -1
- google_meridian-1.1.0.dist-info/RECORD +0 -41
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.0.dist-info → google_meridian-1.1.1.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: google-meridian
|
|
3
|
-
Version: 1.1.0
|
|
3
|
+
Version: 1.1.1
|
|
4
4
|
Summary: Google's open source mixed marketing model library, helps you understand your return on investment and direct your ad spend with confidence.
|
|
5
5
|
Author-email: The Meridian Authors <no-reply@google.com>
|
|
6
6
|
License:
|
|
@@ -393,7 +393,7 @@ To cite this repository:
|
|
|
393
393
|
author = {Google Meridian Marketing Mix Modeling Team},
|
|
394
394
|
title = {Meridian: Marketing Mix Modeling},
|
|
395
395
|
url = {https://github.com/google/meridian},
|
|
396
|
-
version = {1.1.0},
|
|
396
|
+
version = {1.1.1},
|
|
397
397
|
year = {2025},
|
|
398
398
|
}
|
|
399
399
|
```
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
google_meridian-1.1.1.dist-info/licenses/LICENSE,sha256=WNHhf_5RCaeuKWyq_K39vmp9F28LxKsB4SpomwSZ2L0,11357
|
|
2
|
+
meridian/__init__.py,sha256=v7cNJABthU3UGBjqzcBs5J7MInPPxRkCUZChVo2pw3M,714
|
|
3
|
+
meridian/constants.py,sha256=AWhDEP9VcyQtPCbZhM6cPXHeWuz19wjaqB5lGz6qBsw,17161
|
|
4
|
+
meridian/analysis/__init__.py,sha256=nGBYz7k9FVdadO_WVGMKJcfq7Yy_TuuP8zgee4i9pSA,836
|
|
5
|
+
meridian/analysis/analyzer.py,sha256=VBEQYP28G23F2EXoEOqGrWJRmCr_ez-qWD3brQlqZI4,204098
|
|
6
|
+
meridian/analysis/formatter.py,sha256=ENIdR1CRiaVqIGEXx1HcnsA4ewgDD_nhsYCweJAThaw,7270
|
|
7
|
+
meridian/analysis/optimizer.py,sha256=ZmO05reNjlFOy8i3E8M9dDMYCIzNnQjLdH99zSorkqw,106122
|
|
8
|
+
meridian/analysis/summarizer.py,sha256=IthOUTMufGvAvbxiDhaKwe7uYCyiTyiQ8vgdmUtdevs,18855
|
|
9
|
+
meridian/analysis/summary_text.py,sha256=I_smDkZJYp2j77ea-9AIbgeraDa7-qUYyb-IthP2qO4,12438
|
|
10
|
+
meridian/analysis/test_utils.py,sha256=ES1r1akhRjD4pf2oTaGqzDfGNu9weAcLv6UZRuIkfEc,77699
|
|
11
|
+
meridian/analysis/visualizer.py,sha256=VHgbvGnRmloawilU_I7FPsqZcAYpZq5ODl3cHy2eiDo,93728
|
|
12
|
+
meridian/analysis/templates/card.html.jinja,sha256=pv4MVbQ25CcvtZY-LH7bFW0OSeHobkeEkAleB1sfQ14,1284
|
|
13
|
+
meridian/analysis/templates/chart.html.jinja,sha256=87i0xnXHRBoLLxBpKv2i960TLToWq4r1aVQZqaXIeMQ,1086
|
|
14
|
+
meridian/analysis/templates/chips.html.jinja,sha256=Az0tQwF_-b03JDLyOzpeH-8fb-6jgJgbNfnUUSm-q6E,645
|
|
15
|
+
meridian/analysis/templates/insights.html.jinja,sha256=6hEWipbOMiMzs9QGZ6dcB_73tNkj0ZtNiC8E89a98zg,606
|
|
16
|
+
meridian/analysis/templates/stats.html.jinja,sha256=9hQOG02FX1IHVIvdWS_-LI2bbSaqdyHEtCZkiArwAg0,772
|
|
17
|
+
meridian/analysis/templates/style.css,sha256=RODTWc2pXcG9zW3q9SEJpVXgeD-WwQgzLpmFcbXPhLg,5492
|
|
18
|
+
meridian/analysis/templates/style.scss,sha256=nSrZOpcIrVyiL4eC9jLUlxIZtAKZ0Rt8pwfk4H1nMrs,5076
|
|
19
|
+
meridian/analysis/templates/summary.html.jinja,sha256=LuENVDHYIpNo4pzloYaCR2K9XN1Ow6_9oQOcOwD9nGg,1707
|
|
20
|
+
meridian/analysis/templates/table.html.jinja,sha256=mvLMZx92RcD2JAS2w2eZtfYG-6WdfwYVo7pM8TbHp4g,1176
|
|
21
|
+
meridian/data/__init__.py,sha256=4F6_dCnDOic08yMw6_nIDR03B9cF_4STDFb430XvZR4,774
|
|
22
|
+
meridian/data/arg_builder.py,sha256=Kqlt88bOqFj6D3xNwvWo4MBwNwcDFHzd-wMfEOmLoPU,3741
|
|
23
|
+
meridian/data/input_data.py,sha256=teJPKTBfW-AzBWgf_fEO_S_Z1J_veqQkCvctINaid6I,39749
|
|
24
|
+
meridian/data/load.py,sha256=iFdNq9J89qlmOIrvMER1ci8LzZD87gHl6NTW49h7ZFE,55260
|
|
25
|
+
meridian/data/test_utils.py,sha256=6GJrPmeaF4uzMxxRgzERGv4g1XMUHwI0s7qDVMZUjuI,55565
|
|
26
|
+
meridian/data/time_coordinates.py,sha256=C5A5fscSLjPH6G9YT8OspgIlCrkMY7y8dMFEt3tNSnE,9874
|
|
27
|
+
meridian/model/__init__.py,sha256=9NFfqUE5WgFc-9lQMkbfkwwV-bQIz0tsQ_3Jyq0A4SU,982
|
|
28
|
+
meridian/model/adstock_hill.py,sha256=20A_6rbDUAADEkkHspB7JpCm5tYfYS1FQ6hJMLu21Pk,9283
|
|
29
|
+
meridian/model/knots.py,sha256=KPEgnb-UdQQ4QBugOYEke-zBgEghgTmeCMoeiJ30meY,8054
|
|
30
|
+
meridian/model/media.py,sha256=R0LnMUNTuGzXD2lzNRRORA4-p21xpdhkVVsvFaWtEK0,13819
|
|
31
|
+
meridian/model/model.py,sha256=JXHCcxpUDXqJQ9hI0YkY5PfGbpt8d3jAKR1TbCP08PI,61110
|
|
32
|
+
meridian/model/model_test_data.py,sha256=hDDTEzm72LknW9c5E_dNsy4Mm4Tfs6AirhGf_QxykFs,15552
|
|
33
|
+
meridian/model/posterior_sampler.py,sha256=jjLqcYEAorVJ_2nmhpkVUjCGAyNUZYPTEXVTDHufbqA,27727
|
|
34
|
+
meridian/model/prior_distribution.py,sha256=IEDU1rabcmKNY8lxwbbO4OUAlMHPIMa7flM_zsu3DLM,42417
|
|
35
|
+
meridian/model/prior_sampler.py,sha256=jSaxFmJzyN2OKqKyU059Ar4Yr565w4zlInPl4zxjGZk,23212
|
|
36
|
+
meridian/model/spec.py,sha256=b6nYj39L-Yy5j2i2IHdZHY2trRvjEA-9i_c3b__63A8,17239
|
|
37
|
+
meridian/model/transformers.py,sha256=nRjzq1fQG0ypldxboM7Gqok6WSAXAS1witRXoAzeH9Q,7763
|
|
38
|
+
google_meridian-1.1.1.dist-info/METADATA,sha256=5yywzNt-Pe3h9GLYo-0MfmOku5tHg2J5XrcJtUTp3Gk,22055
|
|
39
|
+
google_meridian-1.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
40
|
+
google_meridian-1.1.1.dist-info/top_level.txt,sha256=nwaCebZvvU34EopTKZsjK0OMTFjVnkf4FfnBN_TAc0g,9
|
|
41
|
+
google_meridian-1.1.1.dist-info/RECORD,,
|
meridian/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
|
|
15
15
|
"""Meridian API."""
|
|
16
16
|
|
|
17
|
-
__version__ = "1.1.0"
|
|
17
|
+
__version__ = "1.1.1"
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
from meridian import analysis
|
meridian/analysis/__init__.py
CHANGED
meridian/analysis/analyzer.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -788,7 +788,7 @@ class Analyzer:
|
|
|
788
788
|
tensors are expected to be scaled by their corresponding transformers.
|
|
789
789
|
dist_tensors: A `DistributionTensors` container with the distribution
|
|
790
790
|
tensors for media, RF, organic media, organic RF, non-media treatments,
|
|
791
|
-
and controls.
|
|
791
|
+
and controls (if available).
|
|
792
792
|
|
|
793
793
|
Returns:
|
|
794
794
|
Tensor representing computed kpi means.
|
|
@@ -803,17 +803,15 @@ class Analyzer:
|
|
|
803
803
|
)
|
|
804
804
|
)
|
|
805
805
|
|
|
806
|
-
result = (
|
|
807
|
-
|
|
808
|
-
+ tf.einsum(
|
|
809
|
-
"...gtm,...gm->...gt", combined_media_transformed, combined_beta
|
|
810
|
-
)
|
|
811
|
-
+ tf.einsum(
|
|
812
|
-
"...gtc,...gc->...gt",
|
|
813
|
-
data_tensors.controls,
|
|
814
|
-
dist_tensors.gamma_gc,
|
|
815
|
-
)
|
|
806
|
+
result = tau_gt + tf.einsum(
|
|
807
|
+
"...gtm,...gm->...gt", combined_media_transformed, combined_beta
|
|
816
808
|
)
|
|
809
|
+
if self._meridian.controls is not None:
|
|
810
|
+
result += tf.einsum(
|
|
811
|
+
"...gtc,...gc->...gt",
|
|
812
|
+
data_tensors.controls,
|
|
813
|
+
dist_tensors.gamma_gc,
|
|
814
|
+
)
|
|
817
815
|
if data_tensors.non_media_treatments is not None:
|
|
818
816
|
result += tf.einsum(
|
|
819
817
|
"...gtm,...gm->...gt",
|
|
@@ -1464,11 +1462,14 @@ class Analyzer:
|
|
|
1464
1462
|
(n_chains, 0, self._meridian.n_geos, self._meridian.n_times)
|
|
1465
1463
|
)
|
|
1466
1464
|
batch_starting_indices = np.arange(n_draws, step=batch_size)
|
|
1467
|
-
param_list =
|
|
1468
|
-
|
|
1469
|
-
|
|
1470
|
-
|
|
1471
|
-
|
|
1465
|
+
param_list = (
|
|
1466
|
+
[
|
|
1467
|
+
constants.MU_T,
|
|
1468
|
+
constants.TAU_G,
|
|
1469
|
+
]
|
|
1470
|
+
+ ([constants.GAMMA_GC] if self._meridian.n_controls else [])
|
|
1471
|
+
+ self._get_causal_param_names(include_non_paid_channels=True)
|
|
1472
|
+
)
|
|
1472
1473
|
outcome_means_temps = []
|
|
1473
1474
|
for start_index in batch_starting_indices:
|
|
1474
1475
|
stop_index = np.min([n_draws, start_index + batch_size])
|
meridian/analysis/formatter.py
CHANGED
meridian/analysis/optimizer.py
CHANGED
meridian/analysis/summarizer.py
CHANGED
meridian/analysis/test_utils.py
CHANGED
meridian/analysis/visualizer.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -1493,8 +1493,7 @@ class MediaSummary:
|
|
|
1493
1493
|
Returns:
|
|
1494
1494
|
An `xarray.Dataset` containing the following:
|
|
1495
1495
|
- **Coordinates:** `channel`, `metric` (`mean`, `median`, `ci_lo`,
|
|
1496
|
-
|
|
1497
|
-
`distribution` (`prior`, `posterior`)
|
|
1496
|
+
`ci_hi`), `distribution` (`prior`, `posterior`)
|
|
1498
1497
|
- **Data variables:** `incremental_outcome`, `pct_of_contribution`,
|
|
1499
1498
|
`effectiveness`.
|
|
1500
1499
|
"""
|
meridian/constants.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -72,10 +72,10 @@ REVENUE = 'revenue'
|
|
|
72
72
|
NON_REVENUE = 'non_revenue'
|
|
73
73
|
REQUIRED_INPUT_DATA_ARRAY_NAMES = (
|
|
74
74
|
KPI,
|
|
75
|
-
CONTROLS,
|
|
76
75
|
POPULATION,
|
|
77
76
|
)
|
|
78
77
|
OPTIONAL_INPUT_DATA_ARRAY_NAMES = (
|
|
78
|
+
CONTROLS,
|
|
79
79
|
REVENUE_PER_KPI,
|
|
80
80
|
ORGANIC_MEDIA,
|
|
81
81
|
ORGANIC_REACH,
|
|
@@ -148,7 +148,6 @@ REQUIRED_INPUT_DATA_COORD_NAMES = (
|
|
|
148
148
|
GEO,
|
|
149
149
|
TIME,
|
|
150
150
|
MEDIA_TIME,
|
|
151
|
-
CONTROL_VARIABLE,
|
|
152
151
|
)
|
|
153
152
|
NON_PAID_MEDIA_INPUT_DATA_COORD_NAMES = (
|
|
154
153
|
ORGANIC_MEDIA_CHANNEL,
|
|
@@ -159,6 +158,7 @@ MEDIA_INPUT_DATA_COORD_NAMES = (MEDIA_CHANNEL,)
|
|
|
159
158
|
RF_INPUT_DATA_COORD_NAMES = (RF_CHANNEL,)
|
|
160
159
|
POSSIBLE_INPUT_DATA_COORD_NAMES = (
|
|
161
160
|
REQUIRED_INPUT_DATA_COORD_NAMES
|
|
161
|
+
+ (CONTROL_VARIABLE,)
|
|
162
162
|
+ NON_PAID_MEDIA_INPUT_DATA_COORD_NAMES
|
|
163
163
|
+ MEDIA_INPUT_DATA_COORD_NAMES
|
|
164
164
|
+ RF_INPUT_DATA_COORD_NAMES
|
meridian/data/__init__.py
CHANGED
meridian/data/arg_builder.py
CHANGED
meridian/data/input_data.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -121,11 +121,11 @@ class InputData:
|
|
|
121
121
|
`revenue_per_kpi` exists, ROI calibration is used and the analysis is run
|
|
122
122
|
on revenue. When the `revenue_per_kpi` doesn't exist for the same
|
|
123
123
|
`kpi_type`, custom ROI calibration is used and the analysis is run on KPI.
|
|
124
|
-
controls: A DataArray of dimensions `(n_geos, n_times, n_controls)`
|
|
125
|
-
containing control variable values.
|
|
126
124
|
population: A DataArray of dimensions `(n_geos,)` containing the population
|
|
127
125
|
of each group. This variable is used to scale the KPI and media for
|
|
128
126
|
modeling.
|
|
127
|
+
controls: An optional DataArray of dimensions `(n_geos, n_times,
|
|
128
|
+
n_controls)` containing control variable values.
|
|
129
129
|
revenue_per_kpi: An optional DataArray of dimensions `(n_geos, n_times)`
|
|
130
130
|
containing the average revenue amount per KPI unit. Although modeling is
|
|
131
131
|
done on `kpi`, model analysis and optimization are done on `KPI *
|
|
@@ -275,8 +275,8 @@ class InputData:
|
|
|
275
275
|
|
|
276
276
|
kpi: xr.DataArray
|
|
277
277
|
kpi_type: str
|
|
278
|
-
controls: xr.DataArray
|
|
279
278
|
population: xr.DataArray
|
|
279
|
+
controls: xr.DataArray | None = None
|
|
280
280
|
revenue_per_kpi: xr.DataArray | None = None
|
|
281
281
|
media: xr.DataArray | None = None
|
|
282
282
|
media_spend: xr.DataArray | None = None
|
|
@@ -409,9 +409,12 @@ class InputData:
|
|
|
409
409
|
return None
|
|
410
410
|
|
|
411
411
|
@property
|
|
412
|
-
def control_variable(self) -> xr.DataArray:
|
|
412
|
+
def control_variable(self) -> xr.DataArray | None:
|
|
413
413
|
"""Returns the control variable dimension."""
|
|
414
|
-
|
|
414
|
+
if self.controls is not None:
|
|
415
|
+
return self.controls[constants.CONTROL_VARIABLE]
|
|
416
|
+
else:
|
|
417
|
+
return None
|
|
415
418
|
|
|
416
419
|
@property
|
|
417
420
|
def media_spend_has_geo_dimension(self) -> bool:
|
|
@@ -502,8 +505,8 @@ class InputData:
|
|
|
502
505
|
# Must match the order of constants.POSSIBLE_INPUT_DATA_ARRAY_NAMES!
|
|
503
506
|
arrays = (
|
|
504
507
|
self.kpi,
|
|
505
|
-
self.controls,
|
|
506
508
|
self.population,
|
|
509
|
+
self.controls,
|
|
507
510
|
self.revenue_per_kpi,
|
|
508
511
|
self.organic_media,
|
|
509
512
|
self.organic_reach,
|
|
@@ -786,9 +789,10 @@ class InputData:
|
|
|
786
789
|
"""Returns data as a single `xarray.Dataset` object."""
|
|
787
790
|
data = [
|
|
788
791
|
self.kpi,
|
|
789
|
-
self.controls,
|
|
790
792
|
self.population,
|
|
791
793
|
]
|
|
794
|
+
if self.controls is not None:
|
|
795
|
+
data.append(self.controls)
|
|
792
796
|
if self.revenue_per_kpi is not None:
|
|
793
797
|
data.append(self.revenue_per_kpi)
|
|
794
798
|
if self.media is not None:
|
meridian/data/load.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -79,7 +79,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
79
79
|
"""Constructor.
|
|
80
80
|
|
|
81
81
|
The coordinates of the input dataset should be: `time`, `media_time`,
|
|
82
|
-
`control_variable
|
|
82
|
+
`control_variable` (optional), `geo` (optional for a national model),
|
|
83
83
|
`non_media_channel` (optional), `organic_media_channel` (optional),
|
|
84
84
|
`organic_rf_channel` (optional), and
|
|
85
85
|
either `media_channel`, `rf_channel`, or both.
|
|
@@ -93,7 +93,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
93
93
|
|
|
94
94
|
* `kpi`: `(geo, time)`
|
|
95
95
|
* `revenue_per_kpi`: `(geo, time)`
|
|
96
|
-
* `controls`: `(geo, time, control_variable)`
|
|
96
|
+
* `controls`: `(geo, time, control_variable)` - optional
|
|
97
97
|
* `population`: `(geo)`
|
|
98
98
|
* `media`: `(geo, media_time, media_channel)` - optional
|
|
99
99
|
* `media_spend`: `(geo, time, media_channel)`, `(1, time, media_channel)`,
|
|
@@ -113,7 +113,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
113
113
|
|
|
114
114
|
* `kpi`: `([1,] time)`
|
|
115
115
|
* `revenue_per_kpi`: `([1,] time)`
|
|
116
|
-
* `controls`: `([1,] time, control_variable)`
|
|
116
|
+
* `controls`: `([1,] time, control_variable)` - optional
|
|
117
117
|
* `population`: `([1],)` - this array is optional for national data
|
|
118
118
|
* `media`: `([1,] media_time, media_channel)` - optional
|
|
119
119
|
* `media_spend`: `([1,] time, media_channel)` or
|
|
@@ -198,7 +198,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
198
198
|
self.dataset = dataset.rename(name_mapping)
|
|
199
199
|
|
|
200
200
|
# Add a `geo` dimension if it is not already present.
|
|
201
|
-
if (constants.GEO) not in self.dataset.
|
|
201
|
+
if (constants.GEO) not in self.dataset.sizes.keys():
|
|
202
202
|
self.dataset = self.dataset.expand_dims(dim=[constants.GEO], axis=0)
|
|
203
203
|
|
|
204
204
|
if len(self.dataset.coords[constants.GEO]) == 1:
|
|
@@ -228,7 +228,7 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
228
228
|
compat='override',
|
|
229
229
|
)
|
|
230
230
|
|
|
231
|
-
if constants.MEDIA_TIME not in self.dataset.
|
|
231
|
+
if constants.MEDIA_TIME not in self.dataset.sizes.keys():
|
|
232
232
|
self._add_media_time()
|
|
233
233
|
self._normalize_time_coordinates(constants.TIME)
|
|
234
234
|
self._normalize_time_coordinates(constants.MEDIA_TIME)
|
|
@@ -349,14 +349,17 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
349
349
|
# Arrays in which NAs are expected in the lagged-media period.
|
|
350
350
|
na_arrays = [
|
|
351
351
|
constants.KPI,
|
|
352
|
-
constants.CONTROLS,
|
|
353
352
|
]
|
|
354
353
|
|
|
355
|
-
na_mask = self.dataset[constants.KPI].isnull().any(
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
354
|
+
na_mask = self.dataset[constants.KPI].isnull().any(dim=constants.GEO)
|
|
355
|
+
|
|
356
|
+
if constants.CONTROLS in self.dataset.data_vars.keys():
|
|
357
|
+
na_arrays.append(constants.CONTROLS)
|
|
358
|
+
na_mask |= (
|
|
359
|
+
self.dataset[constants.CONTROLS]
|
|
360
|
+
.isnull()
|
|
361
|
+
.any(dim=[constants.GEO, constants.CONTROL_VARIABLE])
|
|
362
|
+
)
|
|
360
363
|
|
|
361
364
|
if constants.NON_MEDIA_TREATMENTS in self.dataset.data_vars.keys():
|
|
362
365
|
na_arrays.append(constants.NON_MEDIA_TREATMENTS)
|
|
@@ -427,11 +430,12 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
427
430
|
.dropna(dim=constants.TIME)
|
|
428
431
|
.rename({constants.TIME: new_time})
|
|
429
432
|
)
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
433
|
+
if constants.CONTROLS in new_dataset.data_vars.keys():
|
|
434
|
+
new_dataset[constants.CONTROLS] = (
|
|
435
|
+
new_dataset[constants.CONTROLS]
|
|
436
|
+
.dropna(dim=constants.TIME)
|
|
437
|
+
.rename({constants.TIME: new_time})
|
|
438
|
+
)
|
|
435
439
|
if constants.NON_MEDIA_TREATMENTS in new_dataset.data_vars.keys():
|
|
436
440
|
new_dataset[constants.NON_MEDIA_TREATMENTS] = (
|
|
437
441
|
new_dataset[constants.NON_MEDIA_TREATMENTS]
|
|
@@ -466,6 +470,11 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
466
470
|
|
|
467
471
|
def load(self) -> input_data.InputData:
|
|
468
472
|
"""Returns an `InputData` object containing the data from the dataset."""
|
|
473
|
+
controls = (
|
|
474
|
+
self.dataset.controls
|
|
475
|
+
if constants.CONTROLS in self.dataset.data_vars.keys()
|
|
476
|
+
else None
|
|
477
|
+
)
|
|
469
478
|
revenue_per_kpi = (
|
|
470
479
|
self.dataset.revenue_per_kpi
|
|
471
480
|
if constants.REVENUE_PER_KPI in self.dataset.data_vars.keys()
|
|
@@ -519,9 +528,9 @@ class XrDatasetDataLoader(InputDataLoader):
|
|
|
519
528
|
return input_data.InputData(
|
|
520
529
|
kpi=self.dataset.kpi,
|
|
521
530
|
kpi_type=self.kpi_type,
|
|
522
|
-
revenue_per_kpi=revenue_per_kpi,
|
|
523
|
-
controls=self.dataset.controls,
|
|
524
531
|
population=self.dataset.population,
|
|
532
|
+
controls=controls,
|
|
533
|
+
revenue_per_kpi=revenue_per_kpi,
|
|
525
534
|
media=media,
|
|
526
535
|
media_spend=media_spend,
|
|
527
536
|
reach=reach,
|
|
@@ -539,14 +548,14 @@ class CoordToColumns:
|
|
|
539
548
|
"""A mapping between the desired and actual column names in the input data.
|
|
540
549
|
|
|
541
550
|
Attributes:
|
|
542
|
-
controls: List of column names containing `controls` values in the input
|
|
543
|
-
data.
|
|
544
551
|
time: Name of column containing `time` values in the input data.
|
|
545
|
-
kpi: Name of column containing `kpi` values in the input data.
|
|
546
|
-
revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
|
|
547
|
-
input data.
|
|
548
552
|
geo: Name of column containing `geo` values in the input data. This field
|
|
549
553
|
is optional for a national model.
|
|
554
|
+
kpi: Name of column containing `kpi` values in the input data.
|
|
555
|
+
controls: List of column names containing `controls` values in the input
|
|
556
|
+
data. Optional.
|
|
557
|
+
revenue_per_kpi: Name of column containing `revenue_per_kpi` values in the
|
|
558
|
+
input data. Optional. Will be overridden if model KPI type is "revenue".
|
|
550
559
|
population: Name of column containing `population` values in the input data.
|
|
551
560
|
This field is optional for a national model.
|
|
552
561
|
media: List of column names containing `media` values in the input data.
|
|
@@ -567,11 +576,11 @@ class CoordToColumns:
|
|
|
567
576
|
values in the input data.
|
|
568
577
|
"""
|
|
569
578
|
|
|
570
|
-
controls: Sequence[str]
|
|
571
579
|
time: str = constants.TIME
|
|
580
|
+
geo: str = constants.GEO
|
|
572
581
|
kpi: str = constants.KPI
|
|
582
|
+
controls: Sequence[str] | None = None
|
|
573
583
|
revenue_per_kpi: str | None = None
|
|
574
|
-
geo: str = constants.GEO
|
|
575
584
|
population: str = constants.POPULATION
|
|
576
585
|
# Media data
|
|
577
586
|
media: Sequence[str] | None = None
|
|
@@ -607,7 +616,7 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
607
616
|
to the DataFrame column names if they are different. The fields are:
|
|
608
617
|
|
|
609
618
|
* `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
|
|
610
|
-
* `controls` (multiple columns)
|
|
619
|
+
* `controls` (multiple columns, optional)
|
|
611
620
|
* (1) `media`, `media_spend` (multiple columns)
|
|
612
621
|
* (2) `reach`, `frequency`, `rf_spend` (multiple columns)
|
|
613
622
|
* `non_media_treatments` (multiple columns, optional)
|
|
@@ -953,9 +962,10 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
953
962
|
not_lagged_columns = []
|
|
954
963
|
coords = [
|
|
955
964
|
constants.KPI,
|
|
956
|
-
constants.CONTROLS,
|
|
957
965
|
constants.POPULATION,
|
|
958
966
|
]
|
|
967
|
+
if self.coord_to_columns.controls is not None:
|
|
968
|
+
coords.append(constants.CONTROLS)
|
|
959
969
|
if self.coord_to_columns.revenue_per_kpi is not None:
|
|
960
970
|
coords.append(constants.REVENUE_PER_KPI)
|
|
961
971
|
if self.coord_to_columns.media_spend is not None:
|
|
@@ -1042,17 +1052,20 @@ class DataFrameDataLoader(InputDataLoader):
|
|
|
1042
1052
|
.to_frame()
|
|
1043
1053
|
.to_xarray()
|
|
1044
1054
|
)
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1048
|
-
|
|
1049
|
-
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
|
|
1055
|
+
dataset = xr.combine_by_coords([kpi_xr, population_xr])
|
|
1056
|
+
|
|
1057
|
+
if self.coord_to_columns.controls is not None:
|
|
1058
|
+
controls_xr = (
|
|
1059
|
+
df_indexed[self.coord_to_columns.controls]
|
|
1060
|
+
.stack()
|
|
1061
|
+
.rename(constants.CONTROLS)
|
|
1062
|
+
.rename_axis(
|
|
1063
|
+
[constants.GEO, constants.TIME, constants.CONTROL_VARIABLE]
|
|
1064
|
+
)
|
|
1065
|
+
.to_frame()
|
|
1066
|
+
.to_xarray()
|
|
1067
|
+
)
|
|
1068
|
+
dataset = xr.combine_by_coords([dataset, controls_xr])
|
|
1056
1069
|
|
|
1057
1070
|
if self.coord_to_columns.non_media_treatments is not None:
|
|
1058
1071
|
non_media_xr = (
|
|
@@ -1224,7 +1237,7 @@ class CsvDataLoader(InputDataLoader):
|
|
|
1224
1237
|
CSV column names, if they are different. The fields are:
|
|
1225
1238
|
|
|
1226
1239
|
* `geo`, `time`, `kpi`, `revenue_per_kpi`, `population` (single column)
|
|
1227
|
-
* `controls` (multiple columns)
|
|
1240
|
+
* `controls` (multiple columns, optional)
|
|
1228
1241
|
* (1) `media`, `media_spend` (multiple columns)
|
|
1229
1242
|
* (2) `reach`, `frequency`, `rf_spend` (multiple columns)
|
|
1230
1243
|
* `non_media_treatments` (multiple columns, optional)
|