google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
- google_meridian-1.3.0.dist-info/RECORD +62 -0
- meridian/analysis/__init__.py +2 -0
- meridian/analysis/analyzer.py +280 -142
- meridian/analysis/formatter.py +2 -2
- meridian/analysis/optimizer.py +353 -169
- meridian/analysis/review/__init__.py +20 -0
- meridian/analysis/review/checks.py +721 -0
- meridian/analysis/review/configs.py +110 -0
- meridian/analysis/review/constants.py +40 -0
- meridian/analysis/review/results.py +544 -0
- meridian/analysis/review/reviewer.py +186 -0
- meridian/analysis/summarizer.py +14 -12
- meridian/analysis/templates/chips.html.jinja +12 -0
- meridian/analysis/test_utils.py +27 -5
- meridian/analysis/visualizer.py +45 -50
- meridian/backend/__init__.py +698 -55
- meridian/backend/config.py +75 -16
- meridian/backend/test_utils.py +127 -1
- meridian/constants.py +52 -11
- meridian/data/input_data.py +7 -2
- meridian/data/test_utils.py +5 -3
- meridian/mlflow/autolog.py +2 -2
- meridian/model/__init__.py +1 -0
- meridian/model/adstock_hill.py +10 -9
- meridian/model/eda/__init__.py +3 -0
- meridian/model/eda/constants.py +21 -0
- meridian/model/eda/eda_engine.py +1580 -84
- meridian/model/eda/eda_outcome.py +200 -0
- meridian/model/eda/eda_spec.py +84 -0
- meridian/model/eda/meridian_eda.py +220 -0
- meridian/model/knots.py +56 -50
- meridian/model/media.py +10 -8
- meridian/model/model.py +79 -16
- meridian/model/model_test_data.py +53 -9
- meridian/model/posterior_sampler.py +398 -391
- meridian/model/prior_distribution.py +114 -39
- meridian/model/prior_sampler.py +146 -90
- meridian/model/spec.py +7 -8
- meridian/model/transformers.py +16 -8
- meridian/version.py +1 -1
- google_meridian-1.2.0.dist-info/RECORD +0 -52
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
meridian/backend/config.py
CHANGED
|
@@ -15,6 +15,8 @@
|
|
|
15
15
|
"""Backend configuration for Meridian."""
|
|
16
16
|
|
|
17
17
|
import enum
|
|
18
|
+
import os
|
|
19
|
+
from typing import Union
|
|
18
20
|
import warnings
|
|
19
21
|
|
|
20
22
|
|
|
@@ -23,35 +25,92 @@ class Backend(enum.Enum):
|
|
|
23
25
|
JAX = "jax"
|
|
24
26
|
|
|
25
27
|
|
|
26
|
-
|
|
28
|
+
_DEFAULT_BACKEND = Backend.TENSORFLOW
|
|
27
29
|
|
|
28
30
|
|
|
29
|
-
def
|
|
31
|
+
def _warn_jax_experimental() -> None:
|
|
32
|
+
"""Issues a warning that the JAX backend is experimental."""
|
|
33
|
+
warnings.warn(
|
|
34
|
+
(
|
|
35
|
+
"The JAX backend is currently under development and is not yet"
|
|
36
|
+
" functional. It is intended for internal testing only and should"
|
|
37
|
+
" not be used. Please use the TensorFlow backend."
|
|
38
|
+
),
|
|
39
|
+
UserWarning,
|
|
40
|
+
# Set stacklevel=2 so the warning points to the caller of set_backend
|
|
41
|
+
# or the location where the module is imported if initialized via env var.
|
|
42
|
+
stacklevel=2,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _initialize_backend() -> Backend:
|
|
47
|
+
"""Initializes the backend based on environment variables or defaults."""
|
|
48
|
+
env_backend_str = os.environ.get("MERIDIAN_BACKEND")
|
|
49
|
+
|
|
50
|
+
if not env_backend_str:
|
|
51
|
+
return _DEFAULT_BACKEND
|
|
52
|
+
|
|
53
|
+
try:
|
|
54
|
+
backend = Backend(env_backend_str.lower())
|
|
55
|
+
if backend == Backend.JAX:
|
|
56
|
+
_warn_jax_experimental()
|
|
57
|
+
return backend
|
|
58
|
+
except ValueError:
|
|
59
|
+
warnings.warn(
|
|
60
|
+
(
|
|
61
|
+
"Invalid MERIDIAN_BACKEND environment variable:"
|
|
62
|
+
f" '{env_backend_str}'. Supported values are 'tensorflow' and"
|
|
63
|
+
f" 'jax'. Defaulting to {_DEFAULT_BACKEND.value}."
|
|
64
|
+
),
|
|
65
|
+
RuntimeWarning,
|
|
66
|
+
)
|
|
67
|
+
return _DEFAULT_BACKEND
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
_BACKEND = _initialize_backend()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def set_backend(backend: Union[Backend, str]) -> None:
|
|
30
74
|
"""Sets the backend for Meridian.
|
|
31
75
|
|
|
76
|
+
**Warning:** This function should ideally be called at the beginning of your
|
|
77
|
+
program, before any other Meridian modules are imported or used.
|
|
78
|
+
|
|
79
|
+
Changing the backend after Meridian's functions or classes have been
|
|
80
|
+
imported can lead to unpredictable behavior. This is because already-imported
|
|
81
|
+
modules will not reflect the backend change.
|
|
82
|
+
|
|
32
83
|
Note: The JAX backend is currently under development and should not be used.
|
|
33
84
|
|
|
85
|
+
Changing the backend at runtime requires reloading the `meridian.backend`
|
|
86
|
+
module for the changes to take effect globally.
|
|
87
|
+
|
|
34
88
|
Args:
|
|
35
|
-
backend: The backend to use, must be a member of the `Backend` enum
|
|
89
|
+
backend: The backend to use, must be a member of the `Backend` enum or a
|
|
90
|
+
valid string ('tensorflow', 'jax').
|
|
36
91
|
|
|
37
92
|
Raises:
|
|
38
|
-
ValueError: If the provided backend is not
|
|
93
|
+
ValueError: If the provided backend is not valid.
|
|
39
94
|
"""
|
|
40
95
|
global _BACKEND
|
|
41
|
-
if not isinstance(backend, Backend):
|
|
42
|
-
raise ValueError("Backend must be a member of the Backend enum.")
|
|
43
96
|
|
|
44
|
-
if backend
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
97
|
+
if isinstance(backend, str):
|
|
98
|
+
try:
|
|
99
|
+
backend_enum = Backend(backend.lower())
|
|
100
|
+
except ValueError as exc:
|
|
101
|
+
raise ValueError(
|
|
102
|
+
f"Invalid backend string '{backend}'. Must be one of: "
|
|
103
|
+
f"{[b.value for b in Backend]}"
|
|
104
|
+
) from exc
|
|
105
|
+
elif isinstance(backend, Backend):
|
|
106
|
+
backend_enum = backend
|
|
107
|
+
else:
|
|
108
|
+
raise ValueError("Backend must be a Backend enum member or a string.")
|
|
109
|
+
|
|
110
|
+
if backend_enum == Backend.JAX and _BACKEND != Backend.JAX:
|
|
111
|
+
_warn_jax_experimental()
|
|
53
112
|
|
|
54
|
-
_BACKEND =
|
|
113
|
+
_BACKEND = backend_enum
|
|
55
114
|
|
|
56
115
|
|
|
57
116
|
def get_backend() -> Backend:
|
meridian/backend/test_utils.py
CHANGED
|
@@ -14,7 +14,10 @@
|
|
|
14
14
|
|
|
15
15
|
"""Common testing utilities for Meridian, designed to be backend-agnostic."""
|
|
16
16
|
|
|
17
|
-
from typing import Any
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
from absl.testing import parameterized
|
|
19
|
+
from meridian import backend
|
|
20
|
+
from meridian.backend import config
|
|
18
21
|
import numpy as np
|
|
19
22
|
|
|
20
23
|
# A type alias for backend-agnostic array-like objects.
|
|
@@ -67,6 +70,39 @@ def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
|
|
|
67
70
|
np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)
|
|
68
71
|
|
|
69
72
|
|
|
73
|
+
def assert_seed_allequal(a: Any, b: Any, err_msg: str = ""):
|
|
74
|
+
"""Backend-agnostic assertion to check if two seed objects are equal."""
|
|
75
|
+
data_a = backend.get_seed_data(a)
|
|
76
|
+
data_b = backend.get_seed_data(b)
|
|
77
|
+
if data_a is None and data_b is None:
|
|
78
|
+
return
|
|
79
|
+
np.testing.assert_array_equal(data_a, data_b, err_msg=err_msg)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def assert_not_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
|
|
83
|
+
"""Asserts that two objects are not element-wise equal."""
|
|
84
|
+
np.testing.assert_(
|
|
85
|
+
not np.array_equal(np.array(a), np.array(b)),
|
|
86
|
+
msg=f"Arrays are unexpectedly equal.\n{err_msg}",
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def assert_seed_not_allequal(a: Any, b: Any, err_msg: str = ""):
|
|
91
|
+
"""Asserts that two seed objects are not element-wise equal."""
|
|
92
|
+
data_a = backend.get_seed_data(a)
|
|
93
|
+
data_b = backend.get_seed_data(b)
|
|
94
|
+
if data_a is None and data_b is None:
|
|
95
|
+
raise AssertionError(
|
|
96
|
+
f"Seeds are unexpectedly equal (both are None). {err_msg}"
|
|
97
|
+
)
|
|
98
|
+
if data_a is None or data_b is None:
|
|
99
|
+
return
|
|
100
|
+
np.testing.assert_(
|
|
101
|
+
not np.array_equal(data_a, data_b),
|
|
102
|
+
msg=f"Seeds are unexpectedly equal.\n{err_msg}",
|
|
103
|
+
)
|
|
104
|
+
|
|
105
|
+
|
|
70
106
|
def assert_all_finite(a: ArrayLike, err_msg: str = ""):
|
|
71
107
|
"""Backend-agnostic assertion to check if all elements in an array are finite.
|
|
72
108
|
|
|
@@ -93,3 +129,93 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
|
|
|
93
129
|
"""
|
|
94
130
|
if not np.all(np.array(a) >= 0):
|
|
95
131
|
raise AssertionError(err_msg or "Array contains negative values.")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class MeridianTestCase(parameterized.TestCase):
|
|
135
|
+
"""Base test class for Meridian providing backend-aware utilities.
|
|
136
|
+
|
|
137
|
+
This class handles initialization timing issues (crucial for JAX by forcing
|
|
138
|
+
tensor operations into setUp) and provides a unified way to handle random
|
|
139
|
+
number generation across backends (Stateful TF vs Stateless JAX).
|
|
140
|
+
"""
|
|
141
|
+
|
|
142
|
+
@classmethod
|
|
143
|
+
def setUpClass(cls):
|
|
144
|
+
super().setUpClass()
|
|
145
|
+
# Enforce determinism for TensorFlow tests before any tests are run.
|
|
146
|
+
# This is a no-op with a warning for the JAX backend.
|
|
147
|
+
backend.enable_op_determinism()
|
|
148
|
+
|
|
149
|
+
def setUp(self):
|
|
150
|
+
super().setUp()
|
|
151
|
+
# Default seed, can be overridden by subclasses before calling
|
|
152
|
+
# _initialize_rng().
|
|
153
|
+
self.seed = 42
|
|
154
|
+
self._jax_key = None
|
|
155
|
+
self._initialize_rng()
|
|
156
|
+
|
|
157
|
+
def _initialize_rng(self):
|
|
158
|
+
"""Initializes the RNG state or key based on self.seed."""
|
|
159
|
+
current_backend = config.get_backend()
|
|
160
|
+
|
|
161
|
+
if current_backend == config.Backend.TENSORFLOW:
|
|
162
|
+
# In TF, we use the global stateful seed for test reproducibility.
|
|
163
|
+
try:
|
|
164
|
+
backend.set_random_seed(self.seed)
|
|
165
|
+
except NotImplementedError:
|
|
166
|
+
# Handle cases where backend might be misconfigured during transition.
|
|
167
|
+
pass
|
|
168
|
+
elif current_backend == config.Backend.JAX:
|
|
169
|
+
# In JAX, we must manage PRNGKeys explicitly.
|
|
170
|
+
# Import JAX locally to avoid hard dependency if TF is the active backend,
|
|
171
|
+
# and to ensure initialization happens after absltest.main() starts.
|
|
172
|
+
# pylint: disable=g-import-not-at-top
|
|
173
|
+
import jax
|
|
174
|
+
# pylint: enable=g-import-not-at-top
|
|
175
|
+
self._jax_key = jax.random.PRNGKey(self.seed)
|
|
176
|
+
else:
|
|
177
|
+
raise ValueError(f"Unknown backend: {current_backend}")
|
|
178
|
+
|
|
179
|
+
def get_next_rng_seed_or_key(self) -> Optional[Any]:
|
|
180
|
+
"""Gets the next available seed or key for backend operations.
|
|
181
|
+
|
|
182
|
+
This should be passed to the `seed` argument of TFP sampling methods.
|
|
183
|
+
|
|
184
|
+
Returns:
|
|
185
|
+
A JAX PRNGKey if the backend is JAX (splitting the internal key).
|
|
186
|
+
None if the backend is TensorFlow (relying on the global state).
|
|
187
|
+
"""
|
|
188
|
+
if self._jax_key is not None:
|
|
189
|
+
# JAX requires splitting the key for each use.
|
|
190
|
+
# pylint: disable=g-import-not-at-top
|
|
191
|
+
import jax
|
|
192
|
+
# pylint: enable=g-import-not-at-top
|
|
193
|
+
self._jax_key, subkey = jax.random.split(self._jax_key)
|
|
194
|
+
return subkey
|
|
195
|
+
else:
|
|
196
|
+
# For stateful TF, returning None allows TFP/TF to use the global seed.
|
|
197
|
+
return None
|
|
198
|
+
|
|
199
|
+
def sample(
|
|
200
|
+
self,
|
|
201
|
+
distribution: backend.tfd.Distribution,
|
|
202
|
+
sample_shape: Any = (),
|
|
203
|
+
**kwargs: Any,
|
|
204
|
+
) -> backend.Tensor:
|
|
205
|
+
"""Performs a backend-agnostic sample from a distribution.
|
|
206
|
+
|
|
207
|
+
This method abstracts away the need for explicit seed management in JAX.
|
|
208
|
+
When the JAX backend is active, it automatically provides a PRNGKey from
|
|
209
|
+
the test's managed key state. In TensorFlow, it performs a standard sample.
|
|
210
|
+
|
|
211
|
+
Args:
|
|
212
|
+
distribution: The TFP distribution object to sample from.
|
|
213
|
+
sample_shape: The shape of the desired sample.
|
|
214
|
+
**kwargs: Additional keyword arguments to pass to the underlying `sample`
|
|
215
|
+
method (e.g., `name`).
|
|
216
|
+
|
|
217
|
+
Returns:
|
|
218
|
+
A tensor containing the sampled values.
|
|
219
|
+
"""
|
|
220
|
+
seed = self.get_next_rng_seed_or_key()
|
|
221
|
+
return distribution.sample(sample_shape=sample_shape, seed=seed, **kwargs)
|
meridian/constants.py
CHANGED
|
@@ -54,6 +54,9 @@ DATE_FORMAT = '%Y-%m-%d'
|
|
|
54
54
|
# Example: "2024 Apr"
|
|
55
55
|
QUARTER_FORMAT = '%Y %b'
|
|
56
56
|
|
|
57
|
+
ORGANIC_PREFIX = 'organic_'
|
|
58
|
+
NATIONAL_PREFIX = 'national_'
|
|
59
|
+
|
|
57
60
|
# Input data variables.
|
|
58
61
|
KPI = 'kpi'
|
|
59
62
|
REVENUE_PER_KPI = 'revenue_per_kpi'
|
|
@@ -65,10 +68,10 @@ REACH = 'reach'
|
|
|
65
68
|
FREQUENCY = 'frequency'
|
|
66
69
|
RF_IMPRESSIONS = 'rf_impressions'
|
|
67
70
|
RF_SPEND = 'rf_spend'
|
|
68
|
-
ORGANIC_MEDIA =
|
|
69
|
-
ORGANIC_RF
|
|
70
|
-
ORGANIC_REACH =
|
|
71
|
-
ORGANIC_FREQUENCY =
|
|
71
|
+
ORGANIC_MEDIA = ORGANIC_PREFIX + MEDIA
|
|
72
|
+
# ORGANIC_RF is defined below.
|
|
73
|
+
ORGANIC_REACH = ORGANIC_PREFIX + REACH
|
|
74
|
+
ORGANIC_FREQUENCY = ORGANIC_PREFIX + FREQUENCY
|
|
72
75
|
NON_MEDIA_TREATMENTS = 'non_media_treatments'
|
|
73
76
|
REVENUE = 'revenue'
|
|
74
77
|
NON_REVENUE = 'non_revenue'
|
|
@@ -121,15 +124,17 @@ RF_DATA = (
|
|
|
121
124
|
RF_SPEND,
|
|
122
125
|
REVENUE_PER_KPI,
|
|
123
126
|
)
|
|
124
|
-
NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
|
|
125
127
|
|
|
126
128
|
# Scaled input data variables.
|
|
127
129
|
MEDIA_SCALED = 'media_scaled'
|
|
128
130
|
REACH_SCALED = 'reach_scaled'
|
|
129
|
-
ORGANIC_MEDIA_SCALED =
|
|
130
|
-
ORGANIC_REACH_SCALED =
|
|
131
|
+
ORGANIC_MEDIA_SCALED = ORGANIC_PREFIX + MEDIA_SCALED
|
|
132
|
+
ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED
|
|
131
133
|
NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
|
|
132
134
|
CONTROLS_SCALED = 'controls_scaled'
|
|
135
|
+
KPI_SCALED = f'{KPI}_scaled'
|
|
136
|
+
POPULATION_SCALED_KPI = f'{POPULATION}_scaled_{KPI}'
|
|
137
|
+
RF_IMPRESSIONS_SCALED = f'{RF_IMPRESSIONS}_scaled'
|
|
133
138
|
|
|
134
139
|
# Non-media treatments baseline value constants.
|
|
135
140
|
NON_MEDIA_BASELINE_MIN = 'min'
|
|
@@ -144,9 +149,9 @@ MEDIA_CHANNEL = 'media_channel'
|
|
|
144
149
|
RF_CHANNEL = 'rf_channel'
|
|
145
150
|
CHANNEL = 'channel'
|
|
146
151
|
RF = 'rf'
|
|
147
|
-
ORGANIC_RF =
|
|
148
|
-
ORGANIC_MEDIA_CHANNEL =
|
|
149
|
-
ORGANIC_RF_CHANNEL =
|
|
152
|
+
ORGANIC_RF = ORGANIC_PREFIX + RF
|
|
153
|
+
ORGANIC_MEDIA_CHANNEL = ORGANIC_PREFIX + MEDIA_CHANNEL
|
|
154
|
+
ORGANIC_RF_CHANNEL = ORGANIC_PREFIX + RF_CHANNEL
|
|
150
155
|
NON_MEDIA_CHANNEL = 'non_media_channel'
|
|
151
156
|
CONTROL_VARIABLE = 'control_variable'
|
|
152
157
|
REQUIRED_INPUT_DATA_COORD_NAMES = (
|
|
@@ -172,6 +177,41 @@ POSSIBLE_INPUT_DATA_COORDS_AND_ARRAYS_SET = frozenset(
|
|
|
172
177
|
POSSIBLE_INPUT_DATA_COORD_NAMES + POSSIBLE_INPUT_DATA_ARRAY_NAMES
|
|
173
178
|
)
|
|
174
179
|
|
|
180
|
+
# EDA Engine properties
|
|
181
|
+
ORGANIC_RF_IMPRESSIONS = ORGANIC_PREFIX + RF_IMPRESSIONS
|
|
182
|
+
ORGANIC_RF_IMPRESSIONS_SCALED = f'{ORGANIC_RF_IMPRESSIONS}_scaled'
|
|
183
|
+
TREATMENT_CONTROL_SCALED = 'treatment_control_scaled'
|
|
184
|
+
NATIONAL_TREATMENT_CONTROL_SCALED = (
|
|
185
|
+
f'{NATIONAL_PREFIX}{TREATMENT_CONTROL_SCALED}'
|
|
186
|
+
)
|
|
187
|
+
NATIONAL_CONTROLS_SCALED = f'{NATIONAL_PREFIX}{CONTROLS_SCALED}'
|
|
188
|
+
NATIONAL_MEDIA_SPEND = f'{NATIONAL_PREFIX}{MEDIA_SPEND}'
|
|
189
|
+
NATIONAL_MEDIA = f'{NATIONAL_PREFIX}{MEDIA}'
|
|
190
|
+
NATIONAL_MEDIA_SCALED = f'{NATIONAL_PREFIX}{MEDIA_SCALED}'
|
|
191
|
+
NATIONAL_ORGANIC_MEDIA = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA}'
|
|
192
|
+
NATIONAL_ORGANIC_MEDIA_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA_SCALED}'
|
|
193
|
+
NATIONAL_NON_MEDIA_TREATMENTS_SCALED = (
|
|
194
|
+
f'{NATIONAL_PREFIX}{NON_MEDIA_TREATMENTS_SCALED}'
|
|
195
|
+
)
|
|
196
|
+
NATIONAL_RF_SPEND = f'{NATIONAL_PREFIX}{RF_SPEND}'
|
|
197
|
+
NATIONAL_KPI_SCALED = f'{NATIONAL_PREFIX}{KPI_SCALED}'
|
|
198
|
+
NATIONAL_REACH = f'{NATIONAL_PREFIX}{REACH}'
|
|
199
|
+
NATIONAL_REACH_SCALED = f'{NATIONAL_PREFIX}{REACH_SCALED}'
|
|
200
|
+
NATIONAL_ORGANIC_REACH = f'{NATIONAL_PREFIX}{ORGANIC_REACH}'
|
|
201
|
+
NATIONAL_ORGANIC_REACH_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_REACH_SCALED}'
|
|
202
|
+
NATIONAL_FREQUENCY = f'{NATIONAL_PREFIX}{FREQUENCY}'
|
|
203
|
+
NATIONAL_ORGANIC_FREQUENCY = f'{NATIONAL_PREFIX}{ORGANIC_FREQUENCY}'
|
|
204
|
+
NATIONAL_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS}'
|
|
205
|
+
NATIONAL_ORGANIC_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS}'
|
|
206
|
+
NATIONAL_RF_IMPRESSIONS_SCALED = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS_SCALED}'
|
|
207
|
+
NATIONAL_ORGANIC_RF_IMPRESSIONS_SCALED = (
|
|
208
|
+
f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS_SCALED}'
|
|
209
|
+
)
|
|
210
|
+
ALL_REACH_SCALED = 'all_reach_scaled'
|
|
211
|
+
ALL_FREQUENCY = 'all_frequency'
|
|
212
|
+
NATIONAL_ALL_REACH_SCALED = f'{NATIONAL_PREFIX}{ALL_REACH_SCALED}'
|
|
213
|
+
NATIONAL_ALL_FREQUENCY = f'{NATIONAL_PREFIX}{ALL_FREQUENCY}'
|
|
214
|
+
|
|
175
215
|
|
|
176
216
|
# National model constants.
|
|
177
217
|
NATIONAL = 'national'
|
|
@@ -354,7 +394,6 @@ ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
|
|
|
354
394
|
ETA_ORF,
|
|
355
395
|
)
|
|
356
396
|
|
|
357
|
-
|
|
358
397
|
MEDIA_PARAMETERS = (
|
|
359
398
|
ROI_M,
|
|
360
399
|
MROI_M,
|
|
@@ -742,6 +781,8 @@ HILL_NUM_STEPS = 500
|
|
|
742
781
|
# Summary template params.
|
|
743
782
|
START_DATE = 'start_date'
|
|
744
783
|
END_DATE = 'end_date'
|
|
784
|
+
DEFAULT_CURRENCY = '$'
|
|
785
|
+
SELECTED_GEOS = 'selected_geos'
|
|
745
786
|
CARD_INSIGHTS = 'insights'
|
|
746
787
|
CARD_CHARTS = 'charts'
|
|
747
788
|
CARD_STATS = 'stats'
|
meridian/data/input_data.py
CHANGED
|
@@ -454,14 +454,19 @@ class InputData:
|
|
|
454
454
|
population = self.population.values[:, np.newaxis]
|
|
455
455
|
|
|
456
456
|
population_scaled_kpi = np.divide(
|
|
457
|
-
kpi,
|
|
457
|
+
kpi,
|
|
458
|
+
population,
|
|
459
|
+
out=np.zeros_like(kpi, dtype=float),
|
|
460
|
+
where=(population != 0),
|
|
458
461
|
)
|
|
459
462
|
population_scaled_mean = np.mean(population_scaled_kpi)
|
|
460
463
|
population_scaled_stdev = np.std(population_scaled_kpi)
|
|
461
464
|
kpi_scaled = np.divide(
|
|
462
465
|
population_scaled_kpi - population_scaled_mean,
|
|
463
466
|
population_scaled_stdev,
|
|
464
|
-
out=np.zeros_like(
|
|
467
|
+
out=np.zeros_like(
|
|
468
|
+
population_scaled_kpi - population_scaled_mean, dtype=float
|
|
469
|
+
),
|
|
465
470
|
where=(population_scaled_stdev != 0),
|
|
466
471
|
)
|
|
467
472
|
return kpi_scaled - np.mean(kpi_scaled, axis=1, keepdims=True)
|
meridian/data/test_utils.py
CHANGED
|
@@ -1898,10 +1898,12 @@ def sample_input_data_for_aks_with_expected_knot_info() -> (
|
|
|
1898
1898
|
'non_revenue',
|
|
1899
1899
|
)
|
|
1900
1900
|
expected_knot_info = knots.KnotInfo(
|
|
1901
|
-
n_knots=
|
|
1902
|
-
knot_locations=np.array(
|
|
1901
|
+
n_knots=13,
|
|
1902
|
+
knot_locations=np.array(
|
|
1903
|
+
[11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90]
|
|
1904
|
+
),
|
|
1903
1905
|
weights=knots.l1_distance_weights(
|
|
1904
|
-
117, np.array([38, 39, 41, 48, 50, 55])
|
|
1906
|
+
117, np.array([11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90])
|
|
1905
1907
|
),
|
|
1906
1908
|
)
|
|
1907
1909
|
return data, expected_knot_info
|
meridian/mlflow/autolog.py
CHANGED
|
@@ -72,6 +72,7 @@ import json
|
|
|
72
72
|
from typing import Any, Callable
|
|
73
73
|
|
|
74
74
|
import arviz as az
|
|
75
|
+
from meridian import backend
|
|
75
76
|
from meridian.analysis import visualizer
|
|
76
77
|
import mlflow
|
|
77
78
|
from mlflow.utils.autologging_utils import autologging_integration, safe_patch
|
|
@@ -81,7 +82,6 @@ from meridian.model import prior_sampler
|
|
|
81
82
|
from meridian.model import spec
|
|
82
83
|
from meridian.version import __version__
|
|
83
84
|
import numpy as np
|
|
84
|
-
import tensorflow_probability as tfp
|
|
85
85
|
|
|
86
86
|
|
|
87
87
|
FLAVOR_NAME = "meridian"
|
|
@@ -123,7 +123,7 @@ def _log_priors(model_spec: spec.ModelSpec) -> None:
|
|
|
123
123
|
field_value = getattr(priors, field.name)
|
|
124
124
|
|
|
125
125
|
# Stringify Distributions and numpy arrays.
|
|
126
|
-
if isinstance(field_value,
|
|
126
|
+
if isinstance(field_value, backend.tfd.Distribution):
|
|
127
127
|
field_value = str(field_value)
|
|
128
128
|
elif isinstance(field_value, np.ndarray):
|
|
129
129
|
field_value = json.dumps(field_value.tolist())
|
meridian/model/__init__.py
CHANGED
meridian/model/adstock_hill.py
CHANGED
|
@@ -145,8 +145,9 @@ def compute_decay_weights(
|
|
|
145
145
|
alpha, l_range, window_size, constants.GEOMETRIC_DECAY, normalize,
|
|
146
146
|
)
|
|
147
147
|
|
|
148
|
+
is_binomial = [s == constants.BINOMIAL_DECAY for s in decay_functions]
|
|
148
149
|
binomial_decay_mask = backend.reshape(
|
|
149
|
-
backend.to_tensor(
|
|
150
|
+
backend.to_tensor(is_binomial, dtype=backend.bool_),
|
|
150
151
|
(-1, 1),
|
|
151
152
|
)
|
|
152
153
|
|
|
@@ -189,14 +190,14 @@ def _compute_single_decay_function_weights(
|
|
|
189
190
|
A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.
|
|
190
191
|
"""
|
|
191
192
|
expanded_alpha = backend.expand_dims(alpha, -1)
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
193
|
+
|
|
194
|
+
if decay_function == constants.GEOMETRIC_DECAY:
|
|
195
|
+
weights = expanded_alpha**l_range
|
|
196
|
+
elif decay_function == constants.BINOMIAL_DECAY:
|
|
197
|
+
mapped_alpha_binomial = _map_alpha_for_binomial_decay(expanded_alpha)
|
|
198
|
+
weights = (1 - l_range / window_size) ** mapped_alpha_binomial
|
|
199
|
+
else:
|
|
200
|
+
raise ValueError(f'Unsupported decay function: {decay_function}')
|
|
200
201
|
|
|
201
202
|
if normalize:
|
|
202
203
|
normalization_factors = backend.reduce_sum(weights, axis=-1, keepdims=True)
|
meridian/model/eda/__init__.py
CHANGED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Constants specific to MeridianEDA."""
|
|
16
|
+
|
|
17
|
+
# EDA Plotting properties
|
|
18
|
+
VARIABLE_1 = 'var1'
|
|
19
|
+
VARIABLE_2 = 'var2'
|
|
20
|
+
VARIABLE = 'var'
|
|
21
|
+
CORRELATION = 'correlation'
|