google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,118 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Backend configuration for Meridian."""
16
+
17
+ import enum
18
+ import os
19
+ from typing import Union
20
+ import warnings
21
+
22
+
23
class Backend(enum.Enum):
  """Computational backends Meridian can be configured to use."""

  # Member values are the lowercase strings accepted (case-insensitively) by
  # `set_backend()` and by the MERIDIAN_BACKEND environment variable.
  TENSORFLOW = "tensorflow"
  JAX = "jax"


# Backend used when MERIDIAN_BACKEND is unset/empty or holds an invalid value.
_DEFAULT_BACKEND = Backend.TENSORFLOW
29
+
30
+
31
def _warn_jax_experimental(stacklevel: int = 3) -> None:
  """Issues a warning that the JAX backend is experimental.

  Args:
    stacklevel: Forwarded to `warnings.warn`. Level 1 would be this helper's
      own `warnings.warn` call, and level 2 the line inside the function that
      invoked this helper (e.g. inside `set_backend`). The default of 3
      attributes the warning one frame further out, i.e. to the code that
      called `set_backend`. When reached through import-time initialization
      via `_initialize_backend`, it resolves to the module-level line that
      performs that initialization.
  """
  warnings.warn(
      (
          "The JAX backend is currently under development and is not yet"
          " functional. It is intended for internal testing only and should"
          " not be used. Please use the TensorFlow backend."
      ),
      UserWarning,
      stacklevel=stacklevel,
  )
44
+
45
+
46
def _initialize_backend() -> Backend:
  """Resolves the backend to use at import time.

  Reads the MERIDIAN_BACKEND environment variable (case-insensitive). An unset
  or empty variable selects `_DEFAULT_BACKEND`. An unrecognized value emits a
  RuntimeWarning and also falls back to `_DEFAULT_BACKEND`. Selecting JAX
  emits the experimental-backend warning.

  Returns:
    The resolved `Backend` member.
  """
  env_backend_str = os.environ.get("MERIDIAN_BACKEND")
  if not env_backend_str:
    return _DEFAULT_BACKEND

  try:
    resolved = Backend(env_backend_str.lower())
  except ValueError:
    warnings.warn(
        (
            "Invalid MERIDIAN_BACKEND environment variable:"
            f" '{env_backend_str}'. Supported values are 'tensorflow' and"
            f" 'jax'. Defaulting to {_DEFAULT_BACKEND.value}."
        ),
        RuntimeWarning,
    )
    return _DEFAULT_BACKEND

  if resolved == Backend.JAX:
    _warn_jax_experimental()
  return resolved
68
+
69
+
70
# Module-level backend selection, resolved once at import time from the
# MERIDIAN_BACKEND environment variable (or the default). `set_backend()`
# rebinds it and `get_backend()` reads it.
_BACKEND = _initialize_backend()
71
+
72
+
73
def set_backend(backend: Union[Backend, str]) -> None:
  """Selects the backend Meridian should use.

  **Warning:** Call this at the very start of your program, before any other
  Meridian module is imported or used. Modules imported earlier keep whatever
  backend was active when they were loaded, so switching afterwards can lead
  to unpredictable behavior. For a change to take effect globally at runtime,
  the `meridian.backend` module must be reloaded.

  Note: The JAX backend is still under development and should not be used. A
  warning is emitted whenever the selection switches to JAX from another
  backend.

  Args:
    backend: A `Backend` member, or one of the strings 'tensorflow' / 'jax'
      (matched case-insensitively).

  Raises:
    ValueError: If `backend` is an unrecognized string or is neither a string
      nor a `Backend` member.
  """
  global _BACKEND

  if isinstance(backend, Backend):
    backend_enum = backend
  elif isinstance(backend, str):
    try:
      backend_enum = Backend(backend.lower())
    except ValueError as exc:
      raise ValueError(
          f"Invalid backend string '{backend}'. Must be one of: "
          f"{[b.value for b in Backend]}"
      ) from exc
  else:
    raise ValueError("Backend must be a Backend enum member or a string.")

  # Only warn on a transition into JAX, not when JAX is re-selected.
  switching_to_jax = backend_enum == Backend.JAX and _BACKEND != Backend.JAX
  if switching_to_jax:
    _warn_jax_experimental()

  _BACKEND = backend_enum
114
+
115
+
116
def get_backend() -> Backend:
  """Reports which backend Meridian is currently configured to use.

  Returns:
    The module-level selection, as resolved at import time or as last changed
    by `set_backend()`.
  """
  current = _BACKEND
  return current
@@ -0,0 +1,181 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Common testing utilities for Meridian, designed to be backend-agnostic."""
16
+
17
+ from typing import Any, Optional
18
+ from absl.testing import parameterized
19
+ from meridian import backend
20
+ from meridian.backend import config
21
+ import numpy as np
22
+
23
# A type alias for backend-agnostic array-like objects, e.g. TensorFlow
# tensors, JAX arrays, NumPy arrays, or Python sequences.
# `Any` is used to avoid a circular dependency on the backend module while
# still letting the helpers below accept backend-specific tensor types.
ArrayLike = Any
27
+
28
+
29
def assert_allclose(
    a: ArrayLike,
    b: ArrayLike,
    rtol: float = 1e-6,
    atol: float = 1e-6,
    err_msg: str = "",
):
  """Asserts two array-likes are element-wise equal within tolerances.

  Both arguments are materialized as NumPy arrays before comparison. This means
  TensorFlow tensors, JAX arrays, NumPy arrays and plain Python sequences can
  be mixed freely.

  Args:
    a: The actual values.
    b: The desired values.
    rtol: Relative tolerance.
    atol: Absolute tolerance.
    err_msg: Message included in the failure output.

  Raises:
    AssertionError: If any element pair differs by more than the tolerances.
  """
  actual = np.array(a)
  desired = np.array(b)
  np.testing.assert_allclose(
      actual, desired, rtol=rtol, atol=atol, err_msg=err_msg
  )
55
+
56
+
57
def assert_allequal(a: ArrayLike, b: ArrayLike, err_msg: str = ""):
  """Asserts two array-likes are exactly equal, element by element.

  Both arguments are materialized as NumPy arrays before comparison.

  Args:
    a: The actual values.
    b: The desired values.
    err_msg: Message included in the failure output.

  Raises:
    AssertionError: If the arrays differ in shape or in any element.
  """
  actual = np.array(a)
  desired = np.array(b)
  np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
71
+
72
+
73
def assert_all_finite(a: ArrayLike, err_msg: str = ""):
  """Asserts that every element of an array-like is finite (no NaN or inf).

  Args:
    a: The array-like to inspect; it is converted to a NumPy array first.
    err_msg: Custom failure message. A default message is used when empty.

  Raises:
    AssertionError: If any element is NaN or infinite.
  """
  finite_mask = np.isfinite(np.array(a))
  if np.all(finite_mask):
    return
  raise AssertionError(err_msg or "Array contains non-finite values.")
85
+
86
+
87
def assert_all_non_negative(a: ArrayLike, err_msg: str = ""):
  """Asserts that every element of an array-like is >= 0.

  Elements for which `x >= 0` is False (including NaN) cause a failure.

  Args:
    a: The array-like to inspect; it is converted to a NumPy array first.
    err_msg: Custom failure message. A default message is used when empty.

  Raises:
    AssertionError: If any element fails the `>= 0` check.
  """
  non_negative_mask = np.array(a) >= 0
  if np.all(non_negative_mask):
    return
  raise AssertionError(err_msg or "Array contains negative values.")
99
+
100
+
101
class MeridianTestCase(parameterized.TestCase):
  """Backend-aware base class for Meridian tests.

  Random-number state is created in `setUp`, not at class-definition time. This
  matters for JAX, where tensor operations should not run before the test
  framework starts. The class also unifies RNG handling across backends:
  TensorFlow relies on a global stateful seed, while JAX threads an explicit
  PRNG key that is split on every request.
  """

  def setUp(self):
    super().setUp()
    # Subclasses may assign a different `self.seed` afterwards and call
    # `_initialize_rng()` again to reseed.
    self.seed = 42
    self._jax_key = None
    self._initialize_rng()

  def _initialize_rng(self):
    """Seeds TensorFlow's global RNG, or creates a JAX PRNG key, from self.seed.

    Raises:
      ValueError: If the configured backend is neither TensorFlow nor JAX.
    """
    current_backend = config.get_backend()

    if current_backend == config.Backend.JAX:
      # JAX is imported lazily so TensorFlow-only environments do not need it.
      # This also keeps initialization after absltest.main() has started.
      import jax  # pylint: disable=g-import-not-at-top
      self._jax_key = jax.random.PRNGKey(self.seed)
      return

    if current_backend != config.Backend.TENSORFLOW:
      raise ValueError(f"Unknown backend: {current_backend}")

    # TensorFlow: seed the global stateful RNG for test reproducibility.
    try:
      backend.set_random_seed(self.seed)
    except NotImplementedError:
      # Deliberately best-effort: tolerate a backend that is misconfigured
      # during the TF -> JAX transition.
      pass

  def get_next_rng_seed_or_key(self) -> Optional[Any]:
    """Returns the seed argument to pass to TFP sampling methods.

    Returns:
      A fresh JAX PRNG subkey when a JAX key is being managed (the stored key
      is advanced). Otherwise None, so TensorFlow falls back to its global
      seed.
    """
    if self._jax_key is None:
      return None

    import jax  # pylint: disable=g-import-not-at-top
    # JAX keys must not be reused: keep one half, hand out the other.
    self._jax_key, subkey = jax.random.split(self._jax_key)
    return subkey

  def sample(
      self,
      distribution: backend.tfd.Distribution,
      sample_shape: Any = (),
      **kwargs: Any,
  ) -> backend.Tensor:
    """Draws a sample from `distribution` with backend-appropriate seeding.

    Under JAX a subkey from the test's managed key state is supplied
    automatically. Under TensorFlow `seed=None` is passed, so the global seed
    applies.

    Args:
      distribution: The TFP distribution to sample from.
      sample_shape: Shape of the requested sample.
      **kwargs: Extra keyword arguments forwarded to `distribution.sample`
        (e.g. `name`).

    Returns:
      A tensor of sampled values.
    """
    return distribution.sample(
        sample_shape=sample_shape,
        seed=self.get_next_rng_seed_or_key(),
        **kwargs,
    )
meridian/constants.py CHANGED
@@ -54,6 +54,8 @@ DATE_FORMAT = '%Y-%m-%d'
54
54
  # Example: "2024 Apr"
55
55
  QUARTER_FORMAT = '%Y %b'
56
56
 
57
+ ORGANIC_PREFIX = 'organic_'
58
+
57
59
  # Input data variables.
58
60
  KPI = 'kpi'
59
61
  REVENUE_PER_KPI = 'revenue_per_kpi'
@@ -65,9 +67,10 @@ REACH = 'reach'
65
67
  FREQUENCY = 'frequency'
66
68
  RF_IMPRESSIONS = 'rf_impressions'
67
69
  RF_SPEND = 'rf_spend'
68
- ORGANIC_MEDIA = 'organic_media'
69
- ORGANIC_REACH = 'organic_reach'
70
- ORGANIC_FREQUENCY = 'organic_frequency'
70
+ ORGANIC_MEDIA = ORGANIC_PREFIX + MEDIA
71
+ # ORGANIC_RF is defined below.
72
+ ORGANIC_REACH = ORGANIC_PREFIX + REACH
73
+ ORGANIC_FREQUENCY = ORGANIC_PREFIX + FREQUENCY
71
74
  NON_MEDIA_TREATMENTS = 'non_media_treatments'
72
75
  REVENUE = 'revenue'
73
76
  NON_REVENUE = 'non_revenue'
@@ -125,8 +128,8 @@ NON_REVENUE_DATA = IMPRESSIONS_DATA + (CONTROLS,)
125
128
  # Scaled input data variables.
126
129
  MEDIA_SCALED = 'media_scaled'
127
130
  REACH_SCALED = 'reach_scaled'
128
- ORGANIC_MEDIA_SCALED = 'organic_media_scaled'
129
- ORGANIC_REACH_SCALED = 'organic_reach_scaled'
131
+ ORGANIC_MEDIA_SCALED = ORGANIC_PREFIX + MEDIA_SCALED
132
+ ORGANIC_REACH_SCALED = ORGANIC_PREFIX + REACH_SCALED
130
133
  NON_MEDIA_TREATMENTS_SCALED = 'non_media_treatments_scaled'
131
134
  CONTROLS_SCALED = 'controls_scaled'
132
135
 
@@ -143,8 +146,9 @@ MEDIA_CHANNEL = 'media_channel'
143
146
  RF_CHANNEL = 'rf_channel'
144
147
  CHANNEL = 'channel'
145
148
  RF = 'rf'
146
- ORGANIC_MEDIA_CHANNEL = 'organic_media_channel'
147
- ORGANIC_RF_CHANNEL = 'organic_rf_channel'
149
+ ORGANIC_RF = ORGANIC_PREFIX + RF
150
+ ORGANIC_MEDIA_CHANNEL = ORGANIC_PREFIX + MEDIA_CHANNEL
151
+ ORGANIC_RF_CHANNEL = ORGANIC_PREFIX + RF_CHANNEL
148
152
  NON_MEDIA_CHANNEL = 'non_media_channel'
149
153
  CONTROL_VARIABLE = 'control_variable'
150
154
  REQUIRED_INPUT_DATA_COORD_NAMES = (
@@ -170,6 +174,9 @@ POSSIBLE_INPUT_DATA_COORDS_AND_ARRAYS_SET = frozenset(
170
174
  POSSIBLE_INPUT_DATA_COORD_NAMES + POSSIBLE_INPUT_DATA_ARRAY_NAMES
171
175
  )
172
176
 
177
+ # EDA property constants
178
+ ORGANIC_RF_IMPRESSIONS = ORGANIC_PREFIX + RF_IMPRESSIONS
179
+
173
180
 
174
181
  # National model constants.
175
182
  NATIONAL = 'national'
@@ -212,9 +219,11 @@ NON_PAID_TREATMENT_PRIOR_TYPES = frozenset({
212
219
  TREATMENT_PRIOR_TYPE_COEFFICIENT,
213
220
  TREATMENT_PRIOR_TYPE_CONTRIBUTION,
214
221
  })
215
- PAID_MEDIA_ROI_PRIOR_TYPES = frozenset(
216
- {TREATMENT_PRIOR_TYPE_ROI, TREATMENT_PRIOR_TYPE_MROI}
217
- )
222
+ PAID_MEDIA_ROI_PRIOR_TYPES = frozenset({
223
+ TREATMENT_PRIOR_TYPE_ROI,
224
+ TREATMENT_PRIOR_TYPE_MROI,
225
+ TREATMENT_PRIOR_TYPE_CONTRIBUTION,
226
+ })
218
227
  # Represents a 1% increase in spend.
219
228
  MROI_FACTOR = 1.01
220
229
 
@@ -315,6 +324,41 @@ RF_PARAMETER_NAMES = (
315
324
  BETA_RF,
316
325
  BETA_GRF,
317
326
  )
327
# Names of the model parameters associated with organic media channels.
ORGANIC_MEDIA_PARAMETER_NAMES = (
    CONTRIBUTION_OM,
    BETA_OM,
    ETA_OM,
    ALPHA_OM,
    EC_OM,
    SLOPE_OM,
    BETA_GOM,
)
# Names of the model parameters associated with organic RF channels.
ORGANIC_RF_PARAMETER_NAMES = (
    CONTRIBUTION_ORF,
    BETA_ORF,
    ETA_ORF,
    ALPHA_ORF,
    EC_ORF,
    SLOPE_ORF,
    BETA_GORF,
)
# Names of the model parameters associated with non-media treatments.
NON_MEDIA_PARAMETER_NAMES = (
    CONTRIBUTION_N,
    GAMMA_N,
    XI_N,
    GAMMA_GN,
)
# NOTE(review): presumably the parameters that are deterministic (not
# sampled) in national models — confirm against the model code that reads it.
ALL_NATIONAL_DETERMINISTIC_PARAMETER_NAMES = (
    SLOPE_M,
    SLOPE_OM,
    XI_N,
    XI_C,
    ETA_M,
    ETA_RF,
    ETA_OM,
    ETA_ORF,
)
+
318
362
 
319
363
  MEDIA_PARAMETERS = (
320
364
  ROI_M,
@@ -501,10 +545,17 @@ ADSTOCK_HILL_FUNCTIONS = frozenset({
501
545
  'hill',
502
546
  })
503
547
 
548
# Adstock decay functions.
GEOMETRIC_DECAY = 'geometric'
BINOMIAL_DECAY = 'binomial'

# Set of all supported adstock decay function names.
ADSTOCK_DECAY_FUNCTIONS = frozenset({GEOMETRIC_DECAY, BINOMIAL_DECAY})
# Channel groups that adstock applies to: media, RF, organic media and
# organic RF.
ADSTOCK_CHANNELS = (MEDIA, RF, ORGANIC_MEDIA, ORGANIC_RF)
504
554
 
505
555
  # Distribution constants.
506
556
  DISTRIBUTION = 'distribution'
507
557
  DISTRIBUTION_TYPE = 'distribution_type'
558
+ INDEPENDENT_MULTIVARIATE = 'IndependentMultivariate'
508
559
  PRIOR = 'prior'
509
560
  POSTERIOR = 'posterior'
510
561
  # Prior mean proportion of KPI incremental due to all media.
@@ -710,3 +761,13 @@ WEEKLY = 'weekly'
710
761
  QUARTERLY = 'quarterly'
711
762
  TIME_GRANULARITIES = frozenset({WEEKLY, QUARTERLY})
712
763
  QUARTERLY_SUMMARY_THRESHOLD_WEEKS = 52
764
+
765
# Automatic Knot Selection constants
# NOTE(review): these look like result keys / criterion names used by the
# automatic knot selection code — confirm against their consumers.
KNOTS_SELECTED = 'knots_selected'
SELECTION_COEFS = 'selection_coefs'
MODEL = 'model'
REGRESSION_COEFS = 'regression_coefs'
SELECTED_MATRIX = 'selected_matrix'
# Presumably identifiers of model-selection information criteria.
AIC = 'aic'
BIC = 'bic'
EBIC = 'ebic'
@@ -442,6 +442,59 @@ class InputData:
442
442
  """Checks whether the `rf_spend` array has a time dimension."""
443
443
  return self.rf_spend is not None and constants.TIME in self.rf_spend.coords
444
444
 
445
+ @property
446
+ def scaled_centered_kpi(self) -> np.ndarray:
447
+ """Calculates scaled and centered KPI values.
448
+
449
+ Returns:
450
+ An array of KPI values that have been population-scaled and
451
+ mean-centered by geo.
452
+ """
453
+ kpi = self.kpi.values
454
+ population = self.population.values[:, np.newaxis]
455
+
456
+ population_scaled_kpi = np.divide(
457
+ kpi,
458
+ population,
459
+ out=np.zeros_like(kpi, dtype=float),
460
+ where=(population != 0),
461
+ )
462
+ population_scaled_mean = np.mean(population_scaled_kpi)
463
+ population_scaled_stdev = np.std(population_scaled_kpi)
464
+ kpi_scaled = np.divide(
465
+ population_scaled_kpi - population_scaled_mean,
466
+ population_scaled_stdev,
467
+ out=np.zeros_like(
468
+ population_scaled_kpi - population_scaled_mean, dtype=float
469
+ ),
470
+ where=(population_scaled_stdev != 0),
471
+ )
472
+ return kpi_scaled - np.mean(kpi_scaled, axis=1, keepdims=True)
473
+
474
+ def copy(self, deep: bool = True) -> "InputData":
475
+ """Returns a copy of the InputData instance.
476
+
477
+ Args:
478
+ deep: If True, a deep copy is made, meaning all xarray.DataArray objects
479
+ are also deepcopied. If False, a shallow copy is made.
480
+
481
+ Returns:
482
+ A new InputData instance.
483
+ """
484
+ if not deep:
485
+ return dataclasses.replace(self)
486
+
487
+ copied_fields = {}
488
+ for field in dataclasses.fields(self):
489
+ value = getattr(self, field.name)
490
+ if isinstance(value, xr.DataArray):
491
+ copied_fields[field.name] = value.copy(deep=True)
492
+ else:
493
+ # For other types, dataclasses.replace does a shallow copy.
494
+ copied_fields[field.name] = value
495
+
496
+ return InputData(**copied_fields)
497
+
445
498
  def _validate_scenarios(self):
446
499
  """Verifies that calibration and analysis is set correctly."""
447
500
  n_geos = len(self.kpi.coords[constants.GEO])
@@ -848,6 +901,32 @@ class InputData:
848
901
  raise ValueError("Both RF and media channel values are missing.")
849
902
  # pytype: enable=attribute-error
850
903
 
904
+ def get_all_adstock_hill_channels(self) -> np.ndarray:
905
+ """Returns all channel dimensions that adstock hill is applied to.
906
+
907
+ RF, organic media and organic RF channels are concatenated to the end of the
908
+ media channels if they are present.
909
+ """
910
+ adstock_hill_channels = []
911
+
912
+ if self.media_channel is not None:
913
+ adstock_hill_channels.append(self.media_channel.values)
914
+
915
+ if self.rf_channel is not None:
916
+ adstock_hill_channels.append(self.rf_channel.values)
917
+
918
+ if self.organic_media_channel is not None:
919
+ adstock_hill_channels.append(self.organic_media_channel.values)
920
+
921
+ if self.organic_rf_channel is not None:
922
+ adstock_hill_channels.append(self.organic_rf_channel.values)
923
+
924
+ if not adstock_hill_channels:
925
+ raise ValueError("Media, RF, organic media and organic RF channels are "
926
+ "all missing.")
927
+
928
+ return np.concatenate(adstock_hill_channels, axis=None)
929
+
851
930
  def get_paid_channels_argument_builder(
852
931
  self,
853
932
  ) -> arg_builder.OrderedListArgumentBuilder:
@@ -870,6 +949,26 @@ class InputData:
870
949
  raise ValueError("There are no RF channels in the input data.")
871
950
  return arg_builder.OrderedListArgumentBuilder(self.rf_channel.values)
872
951
 
952
+ def get_organic_media_channels_argument_builder(
953
+ self
954
+ ) -> arg_builder.OrderedListArgumentBuilder:
955
+ """Returns an argument builder for *organic* media channels *only*."""
956
+ if self.organic_media_channel is None:
957
+ raise ValueError("There are no organic media channels in the input data.")
958
+ return arg_builder.OrderedListArgumentBuilder(
959
+ self.organic_media_channel.values
960
+ )
961
+
962
+ def get_organic_rf_channels_argument_builder(
963
+ self
964
+ ) -> arg_builder.OrderedListArgumentBuilder:
965
+ """Returns an argument builder for *organic* RF channels *only*."""
966
+ if self.organic_rf_channel is None:
967
+ raise ValueError("There are no organic RF channels in the input data.")
968
+ return arg_builder.OrderedListArgumentBuilder(
969
+ self.organic_rf_channel.values
970
+ )
971
+
873
972
  def get_all_channels(self) -> np.ndarray:
874
973
  """Returns all the channel dimensions.
875
974