google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Backend configuration for Meridian."""
|
|
16
|
+
|
|
17
|
+
import enum
|
|
18
|
+
import os
|
|
19
|
+
from typing import Union
|
|
20
|
+
import warnings
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Backend(enum.Enum):
  """Computational backends that Meridian can be configured to use.

  Each member's value is the lowercase string form of the backend name, which
  is what `Backend(<str>)` lookups (as done by `set_backend` and by the
  `MERIDIAN_BACKEND` environment variable handling) match against.
  """

  TENSORFLOW = "tensorflow"
  JAX = "jax"


# Backend selected when no (valid) backend has been requested.
_DEFAULT_BACKEND = Backend.TENSORFLOW
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _warn_jax_experimental() -> None:
  """Issues a `UserWarning` that the JAX backend is not yet functional.

  The warning is attributed to the caller of the function that invoked this
  helper, so that it points at the code that requested the JAX backend rather
  than at Meridian's own configuration code.
  """
  warnings.warn(
      (
          "The JAX backend is currently under development and is not yet"
          " functional. It is intended for internal testing only and should"
          " not be used. Please use the TensorFlow backend."
      ),
      UserWarning,
      # stacklevel=1 is this `warn` call and stacklevel=2 is the function that
      # called this helper (e.g. `set_backend`). stacklevel=3 attributes the
      # warning to that function's caller: the code that called `set_backend`,
      # or the module-level initialization when the backend is selected via
      # the environment variable.
      stacklevel=3,
  )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _initialize_backend() -> Backend:
  """Picks the initial backend from `MERIDIAN_BACKEND`, else the default.

  The environment variable is matched case-insensitively against the `Backend`
  values. An unset or empty variable yields the default backend. A value that
  matches no backend emits a `RuntimeWarning` and also yields the default.
  Selecting JAX emits the experimental-backend warning.

  Returns:
    The `Backend` member to use as the initial backend.
  """
  requested = os.environ.get("MERIDIAN_BACKEND")
  if not requested:
    return _DEFAULT_BACKEND

  try:
    selected = Backend(requested.lower())
  except ValueError:
    # Fall back to the default instead of failing at import time.
    warnings.warn(
        (
            "Invalid MERIDIAN_BACKEND environment variable:"
            f" '{requested}'. Supported values are 'tensorflow' and"
            f" 'jax'. Defaulting to {_DEFAULT_BACKEND.value}."
        ),
        RuntimeWarning,
    )
    return _DEFAULT_BACKEND

  if selected is Backend.JAX:
    _warn_jax_experimental()
  return selected


# Module-level backend state, resolved once when this module is first imported.
_BACKEND = _initialize_backend()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def set_backend(backend: Union[Backend, str]) -> None:
  """Selects the backend Meridian should use.

  **Warning:** Call this at the very start of your program, before any other
  Meridian modules are imported or used. Modules that were imported earlier do
  not pick up the change, so switching backends later can lead to
  unpredictable behavior. For a runtime change to take effect globally, the
  `meridian.backend` module has to be reloaded.

  Note: The JAX backend is currently under development and should not be used.
  Switching to it from a different backend emits a warning.

  Args:
    backend: Either a `Backend` enum member or one of the strings
      'tensorflow' / 'jax' (matched case-insensitively).

  Raises:
    ValueError: If `backend` is a string that names no known backend, or is
      neither a string nor a `Backend` member.
  """
  global _BACKEND

  if isinstance(backend, Backend):
    requested = backend
  elif isinstance(backend, str):
    try:
      requested = Backend(backend.lower())
    except ValueError as exc:
      raise ValueError(
          f"Invalid backend string '{backend}'. Must be one of: "
          f"{[b.value for b in Backend]}"
      ) from exc
  else:
    raise ValueError("Backend must be a Backend enum member or a string.")

  # Only warn on an actual transition to JAX, not when JAX is already active.
  switching_to_jax = requested == Backend.JAX and _BACKEND != Backend.JAX
  if switching_to_jax:
    _warn_jax_experimental()

  _BACKEND = requested
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_backend() -> Backend:
  """Returns the backend Meridian is currently configured to use.

  This is the module-level value resolved at import time, unless it has since
  been replaced through `set_backend`.
  """
  return _BACKEND
|
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Common testing utilities for Meridian, designed to be backend-agnostic."""
|
|
16
|
+
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
from absl.testing import parameterized
|
|
19
|
+
from meridian import backend
|
|
20
|
+
from meridian.backend import config
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
# Alias for any backend-agnostic array-like value. `Any` is used so this
# module does not need to depend on the backend module's concrete tensor
# types (avoiding a circular dependency) while still accepting them.
ArrayLike = Any


def assert_allclose(
    a: ArrayLike,
    b: ArrayLike,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    err_msg: str = "",
):
  """Asserts that two array-like objects are element-wise close.

  Both inputs are first converted to NumPy arrays, so the check works the same
  way for TensorFlow Tensors, JAX Arrays, NumPy arrays and plain Python lists.

  Args:
    a: The first array-like object to compare.
    b: The second array-like object to compare.
    rtol: The relative tolerance parameter.
    atol: The absolute tolerance parameter.
    err_msg: The error message to be printed in case of failure.

  Raises:
    AssertionError: If the two arrays are not equal within the given tolerance.
  """
  actual = np.array(a)
  desired = np.array(b)
  np.testing.assert_allclose(
      actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
  )
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
  """Asserts that two array-like objects are exactly equal.

  Both inputs are converted to NumPy arrays before the comparison, which keeps
  the check independent of the active backend.

  Args:
    a: The first array-like object to compare.
    b: The second array-like object to compare.
    err_msg: The error message to be printed in case of failure.

  Raises:
    AssertionError: If the two arrays are not equal.
  """
  left, right = np.array(a), np.array(b)
  np.testing.assert_array_equal(left, right, err_msg=err_msg)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def assert_all_finite(a: ArrayLike, err_msg: str = ""):
  """Asserts that every element of an array-like object is finite.

  The input is converted to a NumPy array before it is checked.

  Args:
    a: The array-like object to check.
    err_msg: The error message to be printed in case of failure. A generic
      message is used when this is empty.

  Raises:
    AssertionError: If the array contains non-finite values.
  """
  values = np.array(a)
  if np.isfinite(values).all():
    return
  raise AssertionError(err_msg or "Array contains non-finite values.")
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
  """Asserts that every element of an array-like object is `>= 0`.

  The input is converted to a NumPy array before it is checked.

  Args:
    a: The array-like object to check.
    err_msg: The error message to be printed in case of failure. A generic
      message is used when this is empty.

  Raises:
    AssertionError: If the array contains negative values.
  """
  values = np.array(a)
  if (values >= 0).all():
    return
  raise AssertionError(err_msg or "Array contains negative values.")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class MeridianTestCase(parameterized.TestCase):
  """Base test class for Meridian providing backend-aware utilities.

  It addresses initialization timing (important for JAX, by doing tensor/RNG
  setup inside `setUp`) and offers one way of handling random number
  generation across backends (Stateful TF vs Stateless JAX).
  """

  def setUp(self):
    super().setUp()
    # Default seed. A subclass that wants a different seed sets `self.seed`
    # and then calls `_initialize_rng()` again.
    self.seed = 42
    self._jax_key = None
    self._initialize_rng()

  def _initialize_rng(self):
    """Initializes the RNG state or key based on `self.seed`.

    Raises:
      ValueError: If the configured backend is neither TensorFlow nor JAX.
    """
    active = config.get_backend()

    if active == config.Backend.JAX:
      # JAX needs explicitly managed PRNGKeys. JAX is imported locally so it
      # is not a hard dependency when TF is the active backend, and so that
      # its initialization happens after absltest.main() starts.
      # pylint: disable=g-import-not-at-top
      import jax
      # pylint: enable=g-import-not-at-top
      self._jax_key = jax.random.PRNGKey(self.seed)
      return

    if active != config.Backend.TENSORFLOW:
      raise ValueError(f"Unknown backend: {active}")

    # TF relies on the global stateful seed for test reproducibility.
    try:
      backend.set_random_seed(self.seed)
    except NotImplementedError:
      # Deliberately ignored: the backend might be misconfigured during the
      # transition between backends.
      pass

  def get_next_rng_seed_or_key(self) -> Optional[Any]:
    """Gets the next available seed or key for backend operations.

    The result is meant to be passed as the `seed` argument of TFP sampling
    methods.

    Returns:
      A fresh JAX PRNGKey (split off the internally tracked key) when a JAX
      key has been initialized; otherwise None, so that TensorFlow falls back
      to its global seed state.
    """
    if self._jax_key is None:
      # Stateful TF: None lets TFP/TF use the global seed.
      return None

    # JAX requires splitting the key for each use.
    # pylint: disable=g-import-not-at-top
    import jax
    # pylint: enable=g-import-not-at-top
    carried_key, subkey = jax.random.split(self._jax_key)
    self._jax_key = carried_key
    return subkey

  def sample(
      self,
      distribution: backend.tfd.Distribution,
      sample_shape: Any = (),
      **kwargs: Any,
  ) -> backend.Tensor:
    """Draws a sample from a distribution in a backend-agnostic way.

    Callers do not have to manage seeds themselves: with the JAX backend a
    PRNGKey taken from the test's managed key state is supplied automatically,
    while with TensorFlow a standard (globally seeded) sample is performed.

    Args:
      distribution: The TFP distribution object to sample from.
      sample_shape: The shape of the desired sample.
      **kwargs: Additional keyword arguments to pass to the underlying `sample`
        method (e.g., `name`).

    Returns:
      A tensor containing the sampled values.
    """
    rng = self.get_next_rng_seed_or_key()
    return distribution.sample(sample_shape=sample_shape, seed=rng, **kwargs)
|
meridian/constants.py
CHANGED
|
@@ -54,6 +54,8 @@ DATE_FORMAT = '%Y-%m-%d'
|
|
|
54
54
|
# Example: "2024 Apr"
|
|
55
55
|
QUARTER_FORMAT = '%Y %b'
|
|
56
56
|
|
|
57
|
+
ORGANIC_PREFIX = 'organic_'
|
|
58
|
+
|
|
57
59
|
# Input data variables.
|
|
58
60
|
KPI = 'kpi'
|
|
59
61
|
REVENUE_PER_KPI = 'revenue_per_kpi'
|
|
@@ -65,9 +67,10 @@ REACH = 'reach'
|
|
|
65
67
|
FREQUENCY = 'frequency'
|
|
66
68
|
RF_IMPRESSIONS = 'rf_impressions'
|
|
67
69
|
RF_SPEND = 'rf_spend'
|
|
68
|
-
ORGANIC_MEDIA =
|
|
69
|
-
|
|
70
|
-
|
|
70
|
+
ORGANIC_MEDIA = ORGANIC_PREFIX + MEDIA
|
|
71
|
+
# ORGANIC_RF is defined below.
|
|
72
|
+
ORGANIC_REACH = ORGANIC_PREFIX + REACH
|
|
73
|
+
ORGANIC_FREQUENCY = ORGANIC_PREFIX + FREQUENCY
|
|
71
74
|
NON_MEDIA_TREATMENTS = 'non_media_treatments'
|
|
72
75
|
REVENUE = 'revenue'
|
|
73
76
|
NON_REVENUE = 'non_revenue'
|
|
@@ -125,8 +128,8 @@ NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
|
|
|
125
128
|
# Scaled input data variables.
|
|
126
129
|
MEDIA_SCALED = 'media_scaled'
|
|
127
130
|
REACH_SCALED = 'reach_scaled'
|
|
128
|
-
ORGANIC_MEDIA_SCALED =
|
|
129
|
-
ORGANIC_REACH_SCALED =
|
|
131
|
+
ORGANIC_MEDIA_SCALED = ORGANIC_PREFIX + MEDIA_SCALED
|
|
132
|
+
ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED
|
|
130
133
|
NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
|
|
131
134
|
CONTROLS_SCALED = 'controls_scaled'
|
|
132
135
|
|
|
@@ -143,8 +146,9 @@ MEDIA_CHANNEL = 'media_channel'
|
|
|
143
146
|
RF_CHANNEL = 'rf_channel'
|
|
144
147
|
CHANNEL = 'channel'
|
|
145
148
|
RF = 'rf'
|
|
146
|
-
|
|
147
|
-
|
|
149
|
+
ORGANIC_RF = ORGANIC_PREFIX + RF
|
|
150
|
+
ORGANIC_MEDIA_CHANNEL = ORGANIC_PREFIX + MEDIA_CHANNEL
|
|
151
|
+
ORGANIC_RF_CHANNEL = ORGANIC_PREFIX + RF_CHANNEL
|
|
148
152
|
NON_MEDIA_CHANNEL = 'non_media_channel'
|
|
149
153
|
CONTROL_VARIABLE = 'control_variable'
|
|
150
154
|
REQUIRED_INPUT_DATA_COORD_NAMES = (
|
|
@@ -170,6 +174,9 @@ POSSIBLE_INPUT_DATA_COORDS_AND_ARRAYS_SET = frozenset(
|
|
|
170
174
|
POSSIBLE_INPUT_DATA_COORD_NAMES + POSSIBLE_INPUT_DATA_ARRAY_NAMES
|
|
171
175
|
)
|
|
172
176
|
|
|
177
|
+
# EDA property constants
|
|
178
|
+
ORGANIC_RF_IMPRESSIONS = ORGANIC_PREFIX + RF_IMPRESSIONS
|
|
179
|
+
|
|
173
180
|
|
|
174
181
|
# National model constants.
|
|
175
182
|
NATIONAL = 'national'
|
|
@@ -212,9 +219,11 @@ NON_PAID_TREATMENT_PRIOR_TYPES = frozenset({
|
|
|
212
219
|
TREATMENT_PRIOR_TYPE_COEFFICIENT,
|
|
213
220
|
TREATMENT_PRIOR_TYPE_CONTRIBUTION,
|
|
214
221
|
})
|
|
215
|
-
PAID_MEDIA_ROI_PRIOR_TYPES = frozenset(
|
|
216
|
-
|
|
217
|
-
|
|
222
|
+
PAID_MEDIA_ROI_PRIOR_TYPES = frozenset({
|
|
223
|
+
TREATMENT_PRIOR_TYPE_ROI,
|
|
224
|
+
TREATMENT_PRIOR_TYPE_MROI,
|
|
225
|
+
TREATMENT_PRIOR_TYPE_CONTRIBUTION,
|
|
226
|
+
})
|
|
218
227
|
# Represents a 1% increase in spend.
|
|
219
228
|
MROI_FACTOR = 1.01
|
|
220
229
|
|
|
@@ -315,6 +324,41 @@ RF_PARAMETER_NAMES = (
|
|
|
315
324
|
BETA_RF,
|
|
316
325
|
BETA_GRF,
|
|
317
326
|
)
|
|
327
|
+
ORGANIC_MEDIA_PARAMETER_NAMES = (
|
|
328
|
+
CONTRIBUTION_OM,
|
|
329
|
+
BETA_OM,
|
|
330
|
+
ETA_OM,
|
|
331
|
+
ALPHA_OM,
|
|
332
|
+
EC_OM,
|
|
333
|
+
SLOPE_OM,
|
|
334
|
+
BETA_GOM,
|
|
335
|
+
)
|
|
336
|
+
ORGANIC_RF_PARAMETER_NAMES = (
|
|
337
|
+
CONTRIBUTION_ORF,
|
|
338
|
+
BETA_ORF,
|
|
339
|
+
ETA_ORF,
|
|
340
|
+
ALPHA_ORF,
|
|
341
|
+
EC_ORF,
|
|
342
|
+
SLOPE_ORF,
|
|
343
|
+
BETA_GORF,
|
|
344
|
+
)
|
|
345
|
+
NON_MEDIA_PARAMETER_NAMES = (
|
|
346
|
+
CONTRIBUTION_N,
|
|
347
|
+
GAMMA_N,
|
|
348
|
+
XI_N,
|
|
349
|
+
GAMMA_GN,
|
|
350
|
+
)
|
|
351
|
+
ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
|
|
352
|
+
SLOPE_M,
|
|
353
|
+
SLOPE_OM,
|
|
354
|
+
XI_N,
|
|
355
|
+
XI_C,
|
|
356
|
+
ETA_M,
|
|
357
|
+
ETA_RF,
|
|
358
|
+
ETA_OM,
|
|
359
|
+
ETA_ORF,
|
|
360
|
+
)
|
|
361
|
+
|
|
318
362
|
|
|
319
363
|
MEDIA_PARAMETERS = (
|
|
320
364
|
ROI_M,
|
|
@@ -501,10 +545,17 @@ ADSTOCK_HILL_FUNCTIONS = frozenset({
|
|
|
501
545
|
'hill',
|
|
502
546
|
})
|
|
503
547
|
|
|
548
|
+
# Adstock decay functions.
|
|
549
|
+
GEOMETRIC_DECAY = 'geometric'
|
|
550
|
+
BINOMIAL_DECAY = 'binomial'
|
|
551
|
+
|
|
552
|
+
ADSTOCK_DECAY_FUNCTIONS = frozenset({GEOMETRIC_DECAY, BINOMIAL_DECAY})
|
|
553
|
+
ADSTOCK_CHANNELS = (MEDIA, RF, ORGANIC_MEDIA, ORGANIC_RF)
|
|
504
554
|
|
|
505
555
|
# Distribution constants.
|
|
506
556
|
DISTRIBUTION = 'distribution'
|
|
507
557
|
DISTRIBUTION_TYPE = 'distribution_type'
|
|
558
|
+
INDEPENDENT_MULTIVARIATE = 'IndependentMultivariate'
|
|
508
559
|
PRIOR = 'prior'
|
|
509
560
|
POSTERIOR = 'posterior'
|
|
510
561
|
# Prior mean proportion of KPI incremental due to all media.
|
|
@@ -710,3 +761,13 @@ WEEKLY = 'weekly'
|
|
|
710
761
|
QUARTERLY = 'quarterly'
|
|
711
762
|
TIME_GRANULARITIES = frozenset({WEEKLY, QUARTERLY})
|
|
712
763
|
QUARTERLY_SUMMARY_THRESHOLD_WEEKS = 52
|
|
764
|
+
|
|
765
|
+
# Automatic Knot Selection constants
|
|
766
|
+
KNOTS_SELECTED = 'knots_selected'
|
|
767
|
+
SELECTION_COEFS = 'selection_coefs'
|
|
768
|
+
MODEL = 'model'
|
|
769
|
+
REGRESSION_COEFS = 'regression_coefs'
|
|
770
|
+
SELECTED_MATRIX = 'selected_matrix'
|
|
771
|
+
AIC = 'aic'
|
|
772
|
+
BIC = 'bic'
|
|
773
|
+
EBIC = 'ebic'
|
meridian/data/input_data.py
CHANGED
|
@@ -442,6 +442,59 @@ class InputData:
|
|
|
442
442
|
"""Checks whether the `rf_spend` array has a time dimension."""
|
|
443
443
|
return self.rf_spend is not None and constants.TIME in self.rf_spend.coords
|
|
444
444
|
|
|
445
|
+
@property
def scaled_centered_kpi(self) -> np.ndarray:
  """Computes a population-scaled, standardized and row-centered KPI array.

  The KPI values are divided by population (entries with zero population
  become 0), standardized with the mean and standard deviation taken over the
  whole scaled array (left at 0 when the standard deviation is 0), and
  finally each row has its own mean subtracted.

  Returns:
    An array with the same shape as the KPI values. Rows presumably
    correspond to geos, matching "mean-centered by geo" -- confirm against
    the KPI array layout.
  """
  kpi_values = self.kpi.values
  # Column vector so that each population entry broadcasts across its row.
  population_column = self.population.values[:, np.newaxis]

  per_capita = np.zeros_like(kpi_values, dtype=float)
  np.divide(
      kpi_values,
      population_column,
      out=per_capita,
      where=(population_column != 0),
  )

  deviations = per_capita - np.mean(per_capita)
  spread = np.std(per_capita)
  standardized = np.zeros_like(deviations, dtype=float)
  np.divide(deviations, spread, out=standardized, where=(spread != 0))

  row_means = np.mean(standardized, axis=1, keepdims=True)
  return standardized - row_means
|
|
473
|
+
|
|
474
|
+
def copy(self, deep: bool = True) -> "InputData":
  """Returns a copy of this InputData instance.

  Args:
    deep: If True, every field holding an `xarray.DataArray` is deep-copied;
      fields of any other type are carried over as-is. If False, a shallow
      copy is made.

  Returns:
    A new InputData instance.
  """
  if not deep:
    return dataclasses.replace(self)

  def _duplicate(value):
    # Only DataArrays are deep-copied; other values are shared with `self`.
    if isinstance(value, xr.DataArray):
      return value.copy(deep=True)
    return value

  duplicated_fields = {
      spec.name: _duplicate(getattr(self, spec.name))
      for spec in dataclasses.fields(self)
  }
  return InputData(**duplicated_fields)
|
|
497
|
+
|
|
445
498
|
def _validate_scenarios(self):
|
|
446
499
|
"""Verifies that calibration and analysis is set correctly."""
|
|
447
500
|
n_geos = len(self.kpi.coords[constants.GEO])
|
|
@@ -848,6 +901,32 @@ class InputData:
|
|
|
848
901
|
raise ValueError("Both RF and media channel values are missing.")
|
|
849
902
|
# pytype: enable=attribute-error
|
|
850
903
|
|
|
904
|
+
def get_all_adstock_hill_channels(self) -> np.ndarray:
  """Returns all channel dimensions that adstock hill is applied to.

  The result is the flattened concatenation of whichever channel groups are
  present, always in the order: media, RF, organic media, organic RF.

  Raises:
    ValueError: If none of the four channel groups is present.
  """
  channel_groups = (
      self.media_channel,
      self.rf_channel,
      self.organic_media_channel,
      self.organic_rf_channel,
  )
  present = [group.values for group in channel_groups if group is not None]

  if not present:
    raise ValueError("Media, RF, organic media and organic RF channels are "
                     "all missing.")

  return np.concatenate(present, axis=None)
|
|
929
|
+
|
|
851
930
|
def get_paid_channels_argument_builder(
|
|
852
931
|
self,
|
|
853
932
|
) -> arg_builder.OrderedListArgumentBuilder:
|
|
@@ -870,6 +949,26 @@ class InputData:
|
|
|
870
949
|
raise ValueError("There are no RF channels in the input data.")
|
|
871
950
|
return arg_builder.OrderedListArgumentBuilder(self.rf_channel.values)
|
|
872
951
|
|
|
952
|
+
def get_organic_media_channels_argument_builder(
    self,
) -> arg_builder.OrderedListArgumentBuilder:
  """Returns an argument builder for *organic* media channels *only*.

  Raises:
    ValueError: If the input data has no organic media channels.
  """
  organic_media_channels = self.organic_media_channel
  if organic_media_channels is None:
    raise ValueError("There are no organic media channels in the input data.")

  channel_names = organic_media_channels.values
  return arg_builder.OrderedListArgumentBuilder(channel_names)
|
|
961
|
+
|
|
962
|
+
def get_organic_rf_channels_argument_builder(
    self,
) -> arg_builder.OrderedListArgumentBuilder:
  """Returns an argument builder for *organic* RF channels *only*.

  Raises:
    ValueError: If the input data has no organic RF channels.
  """
  organic_rf_channels = self.organic_rf_channel
  if organic_rf_channels is None:
    raise ValueError("There are no organic RF channels in the input data.")

  channel_names = organic_rf_channels.values
  return arg_builder.OrderedListArgumentBuilder(channel_names)
|
|
971
|
+
|
|
873
972
|
def get_all_channels(self) -> np.ndarray:
|
|
874
973
|
"""Returns all the channel dimensions.
|
|
875
974
|
|