google-meridian 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,8 @@
15
15
  """Backend configuration for Meridian."""
16
16
 
17
17
  import enum
18
+ import os
19
+ from typing import Union
18
20
  import warnings
19
21
 
20
22
 
@@ -23,35 +25,92 @@ class Backend(enum.Enum):
23
25
  JAX = "jax"
24
26
 
25
27
 
26
- _BACKEND = Backend.TENSORFLOW
28
+ _DEFAULT_BACKEND = Backend.TENSORFLOW
27
29
 
28
30
 
29
- def set_backend(backend: Backend) -> None:
31
+ def _warn_jax_experimental() -> None:
32
+ """Issues a warning that the JAX backend is experimental."""
33
+ warnings.warn(
34
+ (
35
+ "The JAX backend is currently under development and is not yet"
36
+ " functional. It is intended for internal testing only and should"
37
+ " not be used. Please use the TensorFlow backend."
38
+ ),
39
+ UserWarning,
40
+ # Set stacklevel=2 so the warning points to the caller of set_backend
41
+ # or the location where the module is imported if initialized via env var.
42
+ stacklevel=2,
43
+ )
44
+
45
+
46
+ def _initialize_backend() -> Backend:
47
+ """Initializes the backend based on environment variables or defaults."""
48
+ env_backend_str = os.environ.get("MERIDIAN_BACKEND")
49
+
50
+ if not env_backend_str:
51
+ return _DEFAULT_BACKEND
52
+
53
+ try:
54
+ backend = Backend(env_backend_str.lower())
55
+ if backend == Backend.JAX:
56
+ _warn_jax_experimental()
57
+ return backend
58
+ except ValueError:
59
+ warnings.warn(
60
+ (
61
+ "Invalid MERIDIAN_BACKEND environment variable:"
62
+ f" '{env_backend_str}'. Supported values are 'tensorflow' and"
63
+ f" 'jax'. Defaulting to {_DEFAULT_BACKEND.value}."
64
+ ),
65
+ RuntimeWarning,
66
+ )
67
+ return _DEFAULT_BACKEND
68
+
69
+
70
+ _BACKEND = _initialize_backend()
71
+
72
+
73
+ def set_backend(backend: Union[Backend, str]) -> None:
30
74
  """Sets the backend for Meridian.
31
75
 
76
+ **Warning:** This function should ideally be called at the beginning of your
77
+ program, before any other Meridian modules are imported or used.
78
+
79
+ Changing the backend after Meridian's functions or classes have been
80
+ imported can lead to unpredictable behavior. This is because already-imported
81
+ modules will not reflect the backend change.
82
+
32
83
  Note: The JAX backend is currently under development and should not be used.
33
84
 
85
+ Changing the backend at runtime requires reloading the `meridian.backend`
86
+ module for the changes to take effect globally.
87
+
34
88
  Args:
35
- backend: The backend to use, must be a member of the `Backend` enum.
89
+ backend: The backend to use, must be a member of the `Backend` enum or a
90
+ valid string ('tensorflow', 'jax').
36
91
 
37
92
  Raises:
38
- ValueError: If the provided backend is not a valid `Backend` enum member.
93
+ ValueError: If the provided backend is not valid.
39
94
  """
40
95
  global _BACKEND
41
- if not isinstance(backend, Backend):
42
- raise ValueError("Backend must be a member of the Backend enum.")
43
96
 
44
- if backend == Backend.JAX:
45
- warnings.warn(
46
- (
47
- "The JAX backend is currently under development and is not yet"
48
- " functional. It is intended for internal testing only and should"
49
- " not be used. Please use the TensorFlow backend."
50
- ),
51
- UserWarning,
52
- )
97
+ if isinstance(backend, str):
98
+ try:
99
+ backend_enum = Backend(backend.lower())
100
+ except ValueError as exc:
101
+ raise ValueError(
102
+ f"Invalid backend string '{backend}'. Must be one of: "
103
+ f"{[b.value for b in Backend]}"
104
+ ) from exc
105
+ elif isinstance(backend, Backend):
106
+ backend_enum = backend
107
+ else:
108
+ raise ValueError("Backend must be a Backend enum member or a string.")
109
+
110
+ if backend_enum == Backend.JAX and _BACKEND != Backend.JAX:
111
+ _warn_jax_experimental()
53
112
 
54
- _BACKEND = backend
113
+ _BACKEND = backend_enum
55
114
 
56
115
 
57
116
  def get_backend() -> Backend:
@@ -14,7 +14,10 @@
14
14
 
15
15
  """Common testing utilities for Meridian, designed to be backend-agnostic."""
16
16
 
17
- from typing import Any
17
+ from typing import Any, Optional
18
+ from absl.testing import parameterized
19
+ from meridian import backend
20
+ from meridian.backend import config
18
21
  import numpy as np
19
22
 
20
23
  # A type alias for backend-agnostic array-like objects.
@@ -93,3 +96,86 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
93
96
  """
94
97
  if not np.all(np.array(a) >= 0):
95
98
  raise AssertionError(err_msg or "Array contains negative values.")
99
+
100
+
101
+ class MeridianTestCase(parameterized.TestCase):
102
+ """Base test class for Meridian providing backend-aware utilities.
103
+
104
+ This class handles initialization timing issues (crucial for JAX by forcing
105
+ tensor operations into setUp) and provides a unified way to handle random
106
+ number generation across backends (Stateful TF vs Stateless JAX).
107
+ """
108
+
109
+ def setUp(self):
110
+ super().setUp()
111
+ # Default seed, can be overridden by subclasses before calling
112
+ # _initialize_rng().
113
+ self.seed = 42
114
+ self._jax_key = None
115
+ self._initialize_rng()
116
+
117
+ def _initialize_rng(self):
118
+ """Initializes the RNG state or key based on self.seed."""
119
+ current_backend = config.get_backend()
120
+
121
+ if current_backend == config.Backend.TENSORFLOW:
122
+ # In TF, we use the global stateful seed for test reproducibility.
123
+ try:
124
+ backend.set_random_seed(self.seed)
125
+ except NotImplementedError:
126
+ # Handle cases where backend might be misconfigured during transition.
127
+ pass
128
+ elif current_backend == config.Backend.JAX:
129
+ # In JAX, we must manage PRNGKeys explicitly.
130
+ # Import JAX locally to avoid hard dependency if TF is the active backend,
131
+ # and to ensure initialization happens after absltest.main() starts.
132
+ # pylint: disable=g-import-not-at-top
133
+ import jax
134
+ # pylint: enable=g-import-not-at-top
135
+ self._jax_key = jax.random.PRNGKey(self.seed)
136
+ else:
137
+ raise ValueError(f"Unknown backend: {current_backend}")
138
+
139
+ def get_next_rng_seed_or_key(self) -> Optional[Any]:
140
+ """Gets the next available seed or key for backend operations.
141
+
142
+ This should be passed to the `seed` argument of TFP sampling methods.
143
+
144
+ Returns:
145
+ A JAX PRNGKey if the backend is JAX (splitting the internal key).
146
+ None if the backend is TensorFlow (relying on the global state).
147
+ """
148
+ if self._jax_key is not None:
149
+ # JAX requires splitting the key for each use.
150
+ # pylint: disable=g-import-not-at-top
151
+ import jax
152
+ # pylint: enable=g-import-not-at-top
153
+ self._jax_key, subkey = jax.random.split(self._jax_key)
154
+ return subkey
155
+ else:
156
+ # For stateful TF, returning None allows TFP/TF to use the global seed.
157
+ return None
158
+
159
+ def sample(
160
+ self,
161
+ distribution: backend.tfd.Distribution,
162
+ sample_shape: Any = (),
163
+ **kwargs: Any,
164
+ ) -> backend.Tensor:
165
+ """Performs a backend-agnostic sample from a distribution.
166
+
167
+ This method abstracts away the need for explicit seed management in JAX.
168
+ When the JAX backend is active, it automatically provides a PRNGKey from
169
+ the test's managed key state. In TensorFlow, it performs a standard sample.
170
+
171
+ Args:
172
+ distribution: The TFP distribution object to sample from.
173
+ sample_shape: The shape of the desired sample.
174
+ **kwargs: Additional keyword arguments to pass to the underlying `sample`
175
+ method (e.g., `name`).
176
+
177
+ Returns:
178
+ A tensor containing the sampled values.
179
+ """
180
+ seed = self.get_next_rng_seed_or_key()
181
+ return distribution.sample(sample_shape=sample_shape, seed=seed, **kwargs)
meridian/constants.py CHANGED
@@ -54,6 +54,8 @@ DATE_FORMAT = '%Y-%m-%d'
54
54
  # Example: "2024 Apr"
55
55
  QUARTER_FORMAT = '%Y %b'
56
56
 
57
+ ORGANIC_PREFIX = 'organic_'
58
+
57
59
  # Input data variables.
58
60
  KPI = 'kpi'
59
61
  REVENUE_PER_KPI = 'revenue_per_kpi'
@@ -65,10 +67,10 @@ REACH = 'reach'
65
67
  FREQUENCY = 'frequency'
66
68
  RF_IMPRESSIONS = 'rf_impressions'
67
69
  RF_SPEND = 'rf_spend'
68
- ORGANIC_MEDIA = 'organic_media'
69
- ORGANIC_RF = 'organic_rf'
70
- ORGANIC_REACH = 'organic_reach'
71
- ORGANIC_FREQUENCY = 'organic_frequency'
70
+ ORGANIC_MEDIA = ORGANIC_PREFIX + MEDIA
71
+ # ORGANIC_RF is defined below.
72
+ ORGANIC_REACH = ORGANIC_PREFIX + REACH
73
+ ORGANIC_FREQUENCY = ORGANIC_PREFIX + FREQUENCY
72
74
  NON_MEDIA_TREATMENTS = 'non_media_treatments'
73
75
  REVENUE = 'revenue'
74
76
  NON_REVENUE = 'non_revenue'
@@ -126,8 +128,8 @@ NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
126
128
  # Scaled input data variables.
127
129
  MEDIA_SCALED = 'media_scaled'
128
130
  REACH_SCALED = 'reach_scaled'
129
- ORGANIC_MEDIA_SCALED = 'organic_media_scaled'
130
- ORGANIC_REACH_SCALED = 'organic_reach_scaled'
131
+ ORGANIC_MEDIA_SCALED = ORGANIC_PREFIX + MEDIA_SCALED
132
+ ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED
131
133
  NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
132
134
  CONTROLS_SCALED = 'controls_scaled'
133
135
 
@@ -144,9 +146,9 @@ MEDIA_CHANNEL = 'media_channel'
144
146
  RF_CHANNEL = 'rf_channel'
145
147
  CHANNEL = 'channel'
146
148
  RF = 'rf'
147
- ORGANIC_RF = 'organic_rf'
148
- ORGANIC_MEDIA_CHANNEL = 'organic_media_channel'
149
- ORGANIC_RF_CHANNEL = 'organic_rf_channel'
149
+ ORGANIC_RF = ORGANIC_PREFIX + RF
150
+ ORGANIC_MEDIA_CHANNEL = ORGANIC_PREFIX + MEDIA_CHANNEL
151
+ ORGANIC_RF_CHANNEL = ORGANIC_PREFIX + RF_CHANNEL
150
152
  NON_MEDIA_CHANNEL = 'non_media_channel'
151
153
  CONTROL_VARIABLE = 'control_variable'
152
154
  REQUIRED_INPUT_DATA_COORD_NAMES = (
@@ -172,6 +174,9 @@ POSSIBLE_INPUT_DATA_COORDS_AND_ARRAYS_SET = frozenset(
172
174
  POSSIBLE_INPUT_DATA_COORD_NAMES + POSSIBLE_INPUT_DATA_ARRAY_NAMES
173
175
  )
174
176
 
177
+ # EDA property constants
178
+ ORGANIC_RF_IMPRESSIONS = ORGANIC_PREFIX + RF_IMPRESSIONS
179
+
175
180
 
176
181
  # National model constants.
177
182
  NATIONAL = 'national'
@@ -454,14 +454,19 @@ class InputData:
454
454
  population = self.population.values[:, np.newaxis]
455
455
 
456
456
  population_scaled_kpi = np.divide(
457
- kpi, population, out=np.zeros_like(kpi), where=(population != 0)
457
+ kpi,
458
+ population,
459
+ out=np.zeros_like(kpi, dtype=float),
460
+ where=(population != 0),
458
461
  )
459
462
  population_scaled_mean = np.mean(population_scaled_kpi)
460
463
  population_scaled_stdev = np.std(population_scaled_kpi)
461
464
  kpi_scaled = np.divide(
462
465
  population_scaled_kpi - population_scaled_mean,
463
466
  population_scaled_stdev,
464
- out=np.zeros_like(population_scaled_kpi - population_scaled_mean),
467
+ out=np.zeros_like(
468
+ population_scaled_kpi - population_scaled_mean, dtype=float
469
+ ),
465
470
  where=(population_scaled_stdev != 0),
466
471
  )
467
472
  return kpi_scaled - np.mean(kpi_scaled, axis=1, keepdims=True)
@@ -1898,10 +1898,12 @@ def sample_input_data_for_aks_with_expected_knot_info() -> (
1898
1898
  'non_revenue',
1899
1899
  )
1900
1900
  expected_knot_info = knots.KnotInfo(
1901
- n_knots=6,
1902
- knot_locations=np.array([38, 39, 41, 48, 50, 55]),
1901
+ n_knots=13,
1902
+ knot_locations=np.array(
1903
+ [11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90]
1904
+ ),
1903
1905
  weights=knots.l1_distance_weights(
1904
- 117, np.array([38, 39, 41, 48, 50, 55])
1906
+ 117, np.array([11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90])
1905
1907
  ),
1906
1908
  )
1907
1909
  return data, expected_knot_info
@@ -72,6 +72,7 @@ import json
72
72
  from typing import Any, Callable
73
73
 
74
74
  import arviz as az
75
+ from meridian import backend
75
76
  from meridian.analysis import visualizer
76
77
  import mlflow
77
78
  from mlflow.utils.autologging_utils import autologging_integration, safe_patch
@@ -81,7 +82,6 @@ from meridian.model import prior_sampler
81
82
  from meridian.model import spec
82
83
  from meridian.version import __version__
83
84
  import numpy as np
84
- import tensorflow_probability as tfp
85
85
 
86
86
 
87
87
  FLAVOR_NAME = "meridian"
@@ -123,7 +123,7 @@ def _log_priors(model_spec: spec.ModelSpec) -> None:
123
123
  field_value = getattr(priors, field.name)
124
124
 
125
125
  # Stringify Distributions and numpy arrays.
126
- if isinstance(field_value, tfp.distributions.Distribution):
126
+ if isinstance(field_value, backend.tfd.Distribution):
127
127
  field_value = str(field_value)
128
128
  elif isinstance(field_value, np.ndarray):
129
129
  field_value = json.dumps(field_value.tolist())
@@ -145,8 +145,9 @@ def compute_decay_weights(
145
145
  alpha, l_range, window_size, constants.GEOMETRIC_DECAY, normalize,
146
146
  )
147
147
 
148
+ is_binomial = [s == constants.BINOMIAL_DECAY for s in decay_functions]
148
149
  binomial_decay_mask = backend.reshape(
149
- backend.to_tensor(decay_functions) == constants.BINOMIAL_DECAY,
150
+ backend.to_tensor(is_binomial, dtype=backend.bool_),
150
151
  (-1, 1),
151
152
  )
152
153
 
@@ -189,14 +190,14 @@ def _compute_single_decay_function_weights(
189
190
  A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.
190
191
  """
191
192
  expanded_alpha = backend.expand_dims(alpha, -1)
192
- match decay_function:
193
- case constants.GEOMETRIC_DECAY:
194
- weights = expanded_alpha**l_range
195
- case constants.BINOMIAL_DECAY:
196
- mapped_alpha_binomial = _map_alpha_for_binomial_decay(expanded_alpha)
197
- weights = (1 - l_range / window_size) ** mapped_alpha_binomial
198
- case _:
199
- raise ValueError(f'Unsupported decay function: {decay_function}')
193
+
194
+ if decay_function == constants.GEOMETRIC_DECAY:
195
+ weights = expanded_alpha**l_range
196
+ elif decay_function == constants.BINOMIAL_DECAY:
197
+ mapped_alpha_binomial = _map_alpha_for_binomial_decay(expanded_alpha)
198
+ weights = (1 - l_range / window_size) ** mapped_alpha_binomial
199
+ else:
200
+ raise ValueError(f'Unsupported decay function: {decay_function}')
200
201
 
201
202
  if normalize:
202
203
  normalization_factors = backend.reduce_sum(weights, axis=-1, keepdims=True)