google-meridian 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/METADATA +10 -10
  2. google_meridian-1.3.0.dist-info/RECORD +62 -0
  3. meridian/analysis/__init__.py +2 -0
  4. meridian/analysis/analyzer.py +280 -142
  5. meridian/analysis/formatter.py +2 -2
  6. meridian/analysis/optimizer.py +353 -169
  7. meridian/analysis/review/__init__.py +20 -0
  8. meridian/analysis/review/checks.py +721 -0
  9. meridian/analysis/review/configs.py +110 -0
  10. meridian/analysis/review/constants.py +40 -0
  11. meridian/analysis/review/results.py +544 -0
  12. meridian/analysis/review/reviewer.py +186 -0
  13. meridian/analysis/summarizer.py +14 -12
  14. meridian/analysis/templates/chips.html.jinja +12 -0
  15. meridian/analysis/test_utils.py +27 -5
  16. meridian/analysis/visualizer.py +45 -50
  17. meridian/backend/__init__.py +698 -55
  18. meridian/backend/config.py +75 -16
  19. meridian/backend/test_utils.py +127 -1
  20. meridian/constants.py +52 -11
  21. meridian/data/input_data.py +7 -2
  22. meridian/data/test_utils.py +5 -3
  23. meridian/mlflow/autolog.py +2 -2
  24. meridian/model/__init__.py +1 -0
  25. meridian/model/adstock_hill.py +10 -9
  26. meridian/model/eda/__init__.py +3 -0
  27. meridian/model/eda/constants.py +21 -0
  28. meridian/model/eda/eda_engine.py +1580 -84
  29. meridian/model/eda/eda_outcome.py +200 -0
  30. meridian/model/eda/eda_spec.py +84 -0
  31. meridian/model/eda/meridian_eda.py +220 -0
  32. meridian/model/knots.py +56 -50
  33. meridian/model/media.py +10 -8
  34. meridian/model/model.py +79 -16
  35. meridian/model/model_test_data.py +53 -9
  36. meridian/model/posterior_sampler.py +398 -391
  37. meridian/model/prior_distribution.py +114 -39
  38. meridian/model/prior_sampler.py +146 -90
  39. meridian/model/spec.py +7 -8
  40. meridian/model/transformers.py +16 -8
  41. meridian/version.py +1 -1
  42. google_meridian-1.2.0.dist-info/RECORD +0 -52
  43. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/WHEEL +0 -0
  44. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/licenses/LICENSE +0 -0
  45. {google_meridian-1.2.0.dist-info → google_meridian-1.3.0.dist-info}/top_level.txt +0 -0
@@ -15,6 +15,8 @@
15
15
  """Backend configuration for Meridian."""
16
16
 
17
17
  import enum
18
+ import os
19
+ from typing import Union
18
20
  import warnings
19
21
 
20
22
 
@@ -23,35 +25,92 @@ class Backend(enum.Enum):
23
25
  JAX = "jax"
24
26
 
25
27
 
26
# Backend selected when the MERIDIAN_BACKEND environment variable is unset,
# empty, or not a recognized backend name.
_DEFAULT_BACKEND = Backend.TENSORFLOW
27
29
 
28
30
 
29
- def set_backend(backend: Backend) -> None:
31
def _warn_jax_experimental(stacklevel: int = 3) -> None:
  """Issues a warning that the JAX backend is experimental.

  Args:
    stacklevel: Forwarded to `warnings.warn`. Level 1 is this helper and level
      2 is its direct caller (`set_backend` or `_initialize_backend`). The
      default of 3 therefore attributes the warning to whoever called that
      function: the user's call site for `set_backend`, or the module-level
      `_BACKEND = _initialize_backend()` line when the backend comes from the
      `MERIDIAN_BACKEND` environment variable.
  """
  warnings.warn(
      (
          "The JAX backend is currently under development and is not yet"
          " functional. It is intended for internal testing only and should"
          " not be used. Please use the TensorFlow backend."
      ),
      UserWarning,
      # A level of 2 would point inside `set_backend` itself; skip one more
      # frame so users see the line that requested the JAX backend.
      stacklevel=stacklevel,
  )
44
+
45
+
46
def _initialize_backend() -> Backend:
  """Resolves the starting backend from the environment.

  Reads the `MERIDIAN_BACKEND` environment variable. When it is unset or
  empty, `_DEFAULT_BACKEND` is returned. Otherwise its lower-cased value is
  parsed as a `Backend`. Selecting JAX emits the experimental-backend warning.
  An unrecognized value emits a `RuntimeWarning` and falls back to
  `_DEFAULT_BACKEND`.

  Returns:
    The `Backend` to activate when this module is imported.
  """
  requested = os.environ.get("MERIDIAN_BACKEND")
  if not requested:
    return _DEFAULT_BACKEND

  try:
    resolved = Backend(requested.lower())
  except ValueError:
    warnings.warn(
        (
            "Invalid MERIDIAN_BACKEND environment variable:"
            f" '{requested}'. Supported values are 'tensorflow' and"
            f" 'jax'. Defaulting to {_DEFAULT_BACKEND.value}."
        ),
        RuntimeWarning,
    )
    return _DEFAULT_BACKEND

  if resolved == Backend.JAX:
    _warn_jax_experimental()
  return resolved


# Active backend. It is resolved when this module is first imported and may be
# replaced later by `set_backend`.
_BACKEND = _initialize_backend()
71
+
72
+
73
def set_backend(backend: Union[Backend, str]) -> None:
  """Selects the computational backend used by Meridian.

  **Warning:** Call this at the very start of your program, before any other
  Meridian module is imported or used. Already-imported modules do not pick
  up a backend change, so switching later can lead to unpredictable behavior.
  Changing the backend at runtime requires reloading the `meridian.backend`
  module for the change to take effect globally.

  Note: The JAX backend is currently under development and should not be used.
  Switching to JAX from a different backend emits a `UserWarning`.

  Args:
    backend: Either a `Backend` enum member or one of the strings
      'tensorflow' or 'jax' (matched case-insensitively).

  Raises:
    ValueError: If `backend` is an unrecognized string, or is neither a string
      nor a `Backend` member.
  """
  global _BACKEND

  if isinstance(backend, str):
    try:
      selected = Backend(backend.lower())
    except ValueError as exc:
      valid_values = [member.value for member in Backend]
      raise ValueError(
          f"Invalid backend string '{backend}'. Must be one of: "
          f"{valid_values}"
      ) from exc
  elif isinstance(backend, Backend):
    selected = backend
  else:
    raise ValueError("Backend must be a Backend enum member or a string.")

  # Only warn on an actual switch to JAX, not when JAX is already active.
  switching_to_jax = selected == Backend.JAX and _BACKEND != Backend.JAX
  if switching_to_jax:
    _warn_jax_experimental()

  _BACKEND = selected
55
114
 
56
115
 
57
116
  def get_backend() -> Backend:
@@ -14,7 +14,10 @@
14
14
 
15
15
  """Common testing utilities for Meridian, designed to be backend-agnostic."""
16
16
 
17
- from typing import Any
17
+ from typing import Any, Optional
18
+ from absl.testing import parameterized
19
+ from meridian import backend
20
+ from meridian.backend import config
18
21
  import numpy as np
19
22
 
20
23
  # A type alias for backend-agnostic array-like objects.
@@ -67,6 +70,39 @@ def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
67
70
  np.testing.assert_array_equal(np.array(a), np.array(b), err_msg=err_msg)
68
71
 
69
72
 
73
def assert_seed_allequal(a: Any, b: Any, err_msg: str = ""):
  """Asserts that two backend seed objects carry identical seed data.

  Both seeds are converted with `backend.get_seed_data`. If both conversions
  yield `None`, the seeds are treated as equal. Otherwise the extracted data
  are compared element-wise with `np.testing.assert_array_equal`.

  Args:
    a: First seed object.
    b: Second seed object.
    err_msg: Extra text included in the failure message.

  Raises:
    AssertionError: If the extracted seed data differ.
  """
  seed_data = [backend.get_seed_data(seed) for seed in (a, b)]
  if all(item is None for item in seed_data):
    return
  np.testing.assert_array_equal(*seed_data, err_msg=err_msg)
80
+
81
+
82
def assert_not_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
  """Asserts that two objects are not element-wise equal.

  NaNs in matching positions count as equal, consistent with
  `assert_allequal`, which relies on `np.testing.assert_array_equal`.
  Without this, two arrays sharing NaNs would satisfy both `assert_allequal`
  and `assert_not_allequal`. Arrays of different shapes are considered not
  equal.

  Args:
    a: First array-like object.
    b: Second array-like object.
    err_msg: Extra text appended to the failure message.

  Raises:
    AssertionError: If `a` and `b` have the same shape and equal elements.
  """
  arr_a, arr_b = np.array(a), np.array(b)
  try:
    # Treat matching NaNs as equal so this check mirrors `assert_allequal`.
    unexpectedly_equal = np.array_equal(arr_a, arr_b, equal_nan=True)
  except TypeError:
    # `equal_nan=True` needs dtypes that `np.isnan` supports (not strings or
    # generic objects); such dtypes cannot hold NaN anyway.
    unexpectedly_equal = np.array_equal(arr_a, arr_b)
  np.testing.assert_(
      not unexpectedly_equal,
      msg=f"Arrays are unexpectedly equal.\n{err_msg}",
  )
88
+
89
+
90
def assert_seed_not_allequal(a: Any, b: Any, err_msg: str = ""):
  """Asserts that two backend seed objects carry different seed data.

  Seeds are converted with `backend.get_seed_data`. Two `None` results are
  treated as equal and fail the assertion. Exactly one `None` counts as
  different and passes. Otherwise the extracted data must not be equal
  element-wise.

  Args:
    a: First seed object.
    b: Second seed object.
    err_msg: Extra text included in the failure message.

  Raises:
    AssertionError: If both seeds yield `None` or their data are equal.
  """
  data_a = backend.get_seed_data(a)
  data_b = backend.get_seed_data(b)
  missing = (data_a is None, data_b is None)
  if all(missing):
    raise AssertionError(
        f"Seeds are unexpectedly equal (both are None). {err_msg}"
    )
  if any(missing):
    return
  np.testing.assert_(
      not np.array_equal(data_a, data_b),
      msg=f"Seeds are unexpectedly equal.\n{err_msg}",
  )
104
+
105
+
70
106
  def assert_all_finite(a: ArrayLike, err_msg: str = ""):
71
107
  """Backend-agnostic assertion to check if all elements in an array are finite.
72
108
 
@@ -93,3 +129,93 @@ def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
93
129
  """
94
130
  if not np.all(np.array(a) >= 0):
95
131
  raise AssertionError(err_msg or "Array contains negative values.")
132
+
133
+
134
class MeridianTestCase(parameterized.TestCase):
  """Backend-aware base class for Meridian tests.

  Tensor work that must not run at import time (crucial for JAX) is deferred
  to `setUp`. Random number generation is unified across backends:
  TensorFlow relies on a stateful global seed, while JAX threads an explicit
  PRNG key that is split on every use.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Make TensorFlow ops deterministic before any test in the class runs.
    # With the JAX backend this is a no-op that emits a warning.
    backend.enable_op_determinism()

  def setUp(self):
    super().setUp()
    # Subclasses may change `self.seed` and then call `_initialize_rng()`
    # again to reseed.
    self.seed = 42
    self._jax_key = None
    self._initialize_rng()

  def _initialize_rng(self):
    """Seeds the active backend's RNG from `self.seed`.

    Raises:
      ValueError: If the configured backend is neither TensorFlow nor JAX.
    """
    active = config.get_backend()

    if active == config.Backend.TENSORFLOW:
      try:
        # TensorFlow uses the global, stateful seed for reproducibility.
        backend.set_random_seed(self.seed)
      except NotImplementedError:
        # Tolerated on purpose: the backend may be misconfigured while the
        # backend transition is in progress.
        pass
      return

    if active != config.Backend.JAX:
      raise ValueError(f"Unknown backend: {active}")

    # JAX keeps no global RNG state, so an explicit key is stored instead.
    # JAX is imported locally to avoid a hard dependency when TensorFlow is
    # active, and so initialization happens after absltest.main() starts.
    import jax  # pylint: disable=g-import-not-at-top

    self._jax_key = jax.random.PRNGKey(self.seed)

  def get_next_rng_seed_or_key(self) -> Optional[Any]:
    """Returns the value to pass as `seed` to the next sampling call.

    Pass the result to the `seed` argument of TFP sampling methods.

    Returns:
      A fresh JAX PRNG subkey when a JAX key has been initialized. The stored
      key is advanced by splitting. Otherwise `None`, which lets TensorFlow
      fall back to its global seed.
    """
    if self._jax_key is None:
      # Stateful TensorFlow: no explicit seed is needed.
      return None

    import jax  # pylint: disable=g-import-not-at-top

    # A JAX key must never be reused: advance the stored key, hand out a
    # subkey.
    self._jax_key, subkey = jax.random.split(self._jax_key)
    return subkey

  def sample(
      self,
      distribution: backend.tfd.Distribution,
      sample_shape: Any = (),
      **kwargs: Any,
  ) -> backend.Tensor:
    """Draws a sample from `distribution` with backend-appropriate seeding.

    Under JAX the seed is a PRNG key taken from the test's managed key state.
    Under TensorFlow the seed is `None`, so the global seed applies.

    Args:
      distribution: The TFP distribution object to sample from.
      sample_shape: The shape of the desired sample.
      **kwargs: Additional keyword arguments forwarded to the distribution's
        `sample` method (e.g., `name`).

    Returns:
      A tensor containing the sampled values.
    """
    rng = self.get_next_rng_seed_or_key()
    return distribution.sample(sample_shape=sample_shape, seed=rng, **kwargs)
meridian/constants.py CHANGED
@@ -54,6 +54,9 @@ DATE_FORMAT = '%Y-%m-%d'
54
54
  # Example: "2024 Apr"
55
55
  QUARTER_FORMAT = '%Y %b'
56
56
 
57
+ ORGANIC_PREFIX = 'organic_'
58
+ NATIONAL_PREFIX = 'national_'
59
+
57
60
  # Input data variables.
58
61
  KPI = 'kpi'
59
62
  REVENUE_PER_KPI = 'revenue_per_kpi'
@@ -65,10 +68,10 @@ REACH = 'reach'
65
68
  FREQUENCY = 'frequency'
66
69
  RF_IMPRESSIONS = 'rf_impressions'
67
70
  RF_SPEND = 'rf_spend'
68
- ORGANIC_MEDIA = 'organic_media'
69
- ORGANIC_RF = 'organic_rf'
70
- ORGANIC_REACH = 'organic_reach'
71
- ORGANIC_FREQUENCY = 'organic_frequency'
71
+ ORGANIC_MEDIA = ORGANIC_PREFIX + MEDIA
72
+ # ORGANIC_RF is defined below.
73
+ ORGANIC_REACH = ORGANIC_PREFIX + REACH
74
+ ORGANIC_FREQUENCY = ORGANIC_PREFIX + FREQUENCY
72
75
  NON_MEDIA_TREATMENTS = 'non_media_treatments'
73
76
  REVENUE = 'revenue'
74
77
  NON_REVENUE = 'non_revenue'
@@ -121,15 +124,17 @@ RF_DATA = (
121
124
  RF_SPEND,
122
125
  REVENUE_PER_KPI,
123
126
  )
124
- NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
125
127
 
126
128
  # Scaled input data variables.
127
129
  MEDIA_SCALED = 'media_scaled'
128
130
  REACH_SCALED = 'reach_scaled'
129
- ORGANIC_MEDIA_SCALED = 'organic_media_scaled'
130
- ORGANIC_REACH_SCALED = 'organic_reach_scaled'
131
+ ORGANIC_MEDIA_SCALED = ORGANIC_PREFIX + MEDIA_SCALED
132
+ ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED
131
133
  NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
132
134
  CONTROLS_SCALED = 'controls_scaled'
135
+ KPI_SCALED = f'{KPI}_scaled'
136
+ POPULATION_SCALED_KPI = f'{POPULATION}_scaled_{KPI}'
137
+ RF_IMPRESSIONS_SCALED = f'{RF_IMPRESSIONS}_scaled'
133
138
 
134
139
  # Non-media treatments baseline value constants.
135
140
  NON_MEDIA_BASELINE_MIN = 'min'
@@ -144,9 +149,9 @@ MEDIA_CHANNEL = 'media_channel'
144
149
  RF_CHANNEL = 'rf_channel'
145
150
  CHANNEL = 'channel'
146
151
  RF = 'rf'
147
- ORGANIC_RF = 'organic_rf'
148
- ORGANIC_MEDIA_CHANNEL = 'organic_media_channel'
149
- ORGANIC_RF_CHANNEL = 'organic_rf_channel'
152
+ ORGANIC_RF = ORGANIC_PREFIX + RF
153
+ ORGANIC_MEDIA_CHANNEL = ORGANIC_PREFIX + MEDIA_CHANNEL
154
+ ORGANIC_RF_CHANNEL = ORGANIC_PREFIX + RF_CHANNEL
150
155
  NON_MEDIA_CHANNEL = 'non_media_channel'
151
156
  CONTROL_VARIABLE = 'control_variable'
152
157
  REQUIRED_INPUT_DATA_COORD_NAMES = (
@@ -172,6 +177,41 @@ POSSIBLE_INPUT_DATA_COORDS_AND_ARRAYS_SET = frozenset(
172
177
  POSSIBLE_INPUT_DATA_COORD_NAMES + POSSIBLE_INPUT_DATA_ARRAY_NAMES
173
178
  )
174
179
 
180
+ # EDA Engine properties
181
+ ORGANIC_RF_IMPRESSIONS = ORGANIC_PREFIX + RF_IMPRESSIONS
182
+ ORGANIC_RF_IMPRESSIONS_SCALED = f'{ORGANIC_RF_IMPRESSIONS}_scaled'
183
+ TREATMENT_CONTROL_SCALED = 'treatment_control_scaled'
184
+ NATIONAL_TREATMENT_CONTROL_SCALED = (
185
+ f'{NATIONAL_PREFIX}{TREATMENT_CONTROL_SCALED}'
186
+ )
187
+ NATIONAL_CONTROLS_SCALED = f'{NATIONAL_PREFIX}{CONTROLS_SCALED}'
188
+ NATIONAL_MEDIA_SPEND = f'{NATIONAL_PREFIX}{MEDIA_SPEND}'
189
+ NATIONAL_MEDIA = f'{NATIONAL_PREFIX}{MEDIA}'
190
+ NATIONAL_MEDIA_SCALED = f'{NATIONAL_PREFIX}{MEDIA_SCALED}'
191
+ NATIONAL_ORGANIC_MEDIA = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA}'
192
+ NATIONAL_ORGANIC_MEDIA_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_MEDIA_SCALED}'
193
+ NATIONAL_NON_MEDIA_TREATMENTS_SCALED = (
194
+ f'{NATIONAL_PREFIX}{NON_MEDIA_TREATMENTS_SCALED}'
195
+ )
196
+ NATIONAL_RF_SPEND = f'{NATIONAL_PREFIX}{RF_SPEND}'
197
+ NATIONAL_KPI_SCALED = f'{NATIONAL_PREFIX}{KPI_SCALED}'
198
+ NATIONAL_REACH = f'{NATIONAL_PREFIX}{REACH}'
199
+ NATIONAL_REACH_SCALED = f'{NATIONAL_PREFIX}{REACH_SCALED}'
200
+ NATIONAL_ORGANIC_REACH = f'{NATIONAL_PREFIX}{ORGANIC_REACH}'
201
+ NATIONAL_ORGANIC_REACH_SCALED = f'{NATIONAL_PREFIX}{ORGANIC_REACH_SCALED}'
202
+ NATIONAL_FREQUENCY = f'{NATIONAL_PREFIX}{FREQUENCY}'
203
+ NATIONAL_ORGANIC_FREQUENCY = f'{NATIONAL_PREFIX}{ORGANIC_FREQUENCY}'
204
+ NATIONAL_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS}'
205
+ NATIONAL_ORGANIC_RF_IMPRESSIONS = f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS}'
206
+ NATIONAL_RF_IMPRESSIONS_SCALED = f'{NATIONAL_PREFIX}{RF_IMPRESSIONS_SCALED}'
207
+ NATIONAL_ORGANIC_RF_IMPRESSIONS_SCALED = (
208
+ f'{NATIONAL_PREFIX}{ORGANIC_RF_IMPRESSIONS_SCALED}'
209
+ )
210
+ ALL_REACH_SCALED = 'all_reach_scaled'
211
+ ALL_FREQUENCY = 'all_frequency'
212
+ NATIONAL_ALL_REACH_SCALED = f'{NATIONAL_PREFIX}{ALL_REACH_SCALED}'
213
+ NATIONAL_ALL_FREQUENCY = f'{NATIONAL_PREFIX}{ALL_FREQUENCY}'
214
+
175
215
 
176
216
  # National model constants.
177
217
  NATIONAL = 'national'
@@ -354,7 +394,6 @@ ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
354
394
  ETA_ORF,
355
395
  )
356
396
 
357
-
358
397
  MEDIA_PARAMETERS = (
359
398
  ROI_M,
360
399
  MROI_M,
@@ -742,6 +781,8 @@ HILL_NUM_STEPS = 500
742
781
  # Summary template params.
743
782
  START_DATE = 'start_date'
744
783
  END_DATE = 'end_date'
784
+ DEFAULT_CURRENCY = '$'
785
+ SELECTED_GEOS = 'selected_geos'
745
786
  CARD_INSIGHTS = 'insights'
746
787
  CARD_CHARTS = 'charts'
747
788
  CARD_STATS = 'stats'
@@ -454,14 +454,19 @@ class InputData:
454
454
  population = self.population.values[:, np.newaxis]
455
455
 
456
456
  population_scaled_kpi = np.divide(
457
- kpi, population, out=np.zeros_like(kpi), where=(population != 0)
457
+ kpi,
458
+ population,
459
+ out=np.zeros_like(kpi, dtype=float),
460
+ where=(population != 0),
458
461
  )
459
462
  population_scaled_mean = np.mean(population_scaled_kpi)
460
463
  population_scaled_stdev = np.std(population_scaled_kpi)
461
464
  kpi_scaled = np.divide(
462
465
  population_scaled_kpi - population_scaled_mean,
463
466
  population_scaled_stdev,
464
- out=np.zeros_like(population_scaled_kpi - population_scaled_mean),
467
+ out=np.zeros_like(
468
+ population_scaled_kpi - population_scaled_mean, dtype=float
469
+ ),
465
470
  where=(population_scaled_stdev != 0),
466
471
  )
467
472
  return kpi_scaled - np.mean(kpi_scaled, axis=1, keepdims=True)
@@ -1898,10 +1898,12 @@ def sample_input_data_for_aks_with_expected_knot_info() -> (
1898
1898
  'non_revenue',
1899
1899
  )
1900
1900
  expected_knot_info = knots.KnotInfo(
1901
- n_knots=6,
1902
- knot_locations=np.array([38, 39, 41, 48, 50, 55]),
1901
+ n_knots=13,
1902
+ knot_locations=np.array(
1903
+ [11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90]
1904
+ ),
1903
1905
  weights=knots.l1_distance_weights(
1904
- 117, np.array([38, 39, 41, 48, 50, 55])
1906
+ 117, np.array([11, 14, 38, 39, 41, 43, 45, 48, 50, 55, 87, 89, 90])
1905
1907
  ),
1906
1908
  )
1907
1909
  return data, expected_knot_info
@@ -72,6 +72,7 @@ import json
72
72
  from typing import Any, Callable
73
73
 
74
74
  import arviz as az
75
+ from meridian import backend
75
76
  from meridian.analysis import visualizer
76
77
  import mlflow
77
78
  from mlflow.utils.autologging_utils import autologging_integration, safe_patch
@@ -81,7 +82,6 @@ from meridian.model import prior_sampler
81
82
  from meridian.model import spec
82
83
  from meridian.version import __version__
83
84
  import numpy as np
84
- import tensorflow_probability as tfp
85
85
 
86
86
 
87
87
  FLAVOR_NAME = "meridian"
@@ -123,7 +123,7 @@ def _log_priors(model_spec: spec.ModelSpec) -> None:
123
123
  field_value = getattr(priors, field.name)
124
124
 
125
125
  # Stringify Distributions and numpy arrays.
126
- if isinstance(field_value, tfp.distributions.Distribution):
126
+ if isinstance(field_value, backend.tfd.Distribution):
127
127
  field_value = str(field_value)
128
128
  elif isinstance(field_value, np.ndarray):
129
129
  field_value = json.dumps(field_value.tolist())
@@ -15,6 +15,7 @@
15
15
  """The Meridian API module that models the data."""
16
16
 
17
17
  from meridian.model import adstock_hill
18
+ from meridian.model import eda
18
19
  from meridian.model import knots
19
20
  from meridian.model import media
20
21
  from meridian.model import model
@@ -145,8 +145,9 @@ def compute_decay_weights(
145
145
  alpha, l_range, window_size, constants.GEOMETRIC_DECAY, normalize,
146
146
  )
147
147
 
148
+ is_binomial = [s == constants.BINOMIAL_DECAY for s in decay_functions]
148
149
  binomial_decay_mask = backend.reshape(
149
- backend.to_tensor(decay_functions) == constants.BINOMIAL_DECAY,
150
+ backend.to_tensor(is_binomial, dtype=backend.bool_),
150
151
  (-1, 1),
151
152
  )
152
153
 
@@ -189,14 +190,14 @@ def _compute_single_decay_function_weights(
189
190
  A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.
190
191
  """
191
192
  expanded_alpha = backend.expand_dims(alpha, -1)
192
- match decay_function:
193
- case constants.GEOMETRIC_DECAY:
194
- weights = expanded_alpha**l_range
195
- case constants.BINOMIAL_DECAY:
196
- mapped_alpha_binomial = _map_alpha_for_binomial_decay(expanded_alpha)
197
- weights = (1 - l_range / window_size) ** mapped_alpha_binomial
198
- case _:
199
- raise ValueError(f'Unsupported decay function: {decay_function}')
193
+
194
+ if decay_function == constants.GEOMETRIC_DECAY:
195
+ weights = expanded_alpha**l_range
196
+ elif decay_function == constants.BINOMIAL_DECAY:
197
+ mapped_alpha_binomial = _map_alpha_for_binomial_decay(expanded_alpha)
198
+ weights = (1 - l_range / window_size) ** mapped_alpha_binomial
199
+ else:
200
+ raise ValueError(f'Unsupported decay function: {decay_function}')
200
201
 
201
202
  if normalize:
202
203
  normalization_factors = backend.reduce_sum(weights, axis=-1, keepdims=True)
@@ -15,3 +15,6 @@
15
15
  """The Meridian API module that performs EDA checks."""
16
16
 
17
17
  from meridian.model.eda import eda_engine
18
+ from meridian.model.eda import eda_outcome
19
+ from meridian.model.eda import eda_spec
20
+ from meridian.model.eda import meridian_eda
@@ -0,0 +1,21 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants specific to MeridianEDA."""
16
+
17
+ # EDA Plotting properties
18
+ VARIABLE_1 = 'var1'
19
+ VARIABLE_2 = 'var2'
20
+ VARIABLE = 'var'
21
+ CORRELATION = 'correlation'