google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,17 +15,201 @@
15
15
  """Function definitions for Adstock and Hill calculations."""
16
16
 
17
17
  import abc
18
- import tensorflow as tf
18
+ from collections.abc import Sequence
19
+ import dataclasses
20
+ from meridian import backend
21
+ from meridian import constants
22
+
19
23
 
20
24
  __all__ = [
25
+ 'AdstockDecaySpec',
21
26
  'AdstockHillTransformer',
22
27
  'AdstockTransformer',
23
28
  'HillTransformer',
29
+ 'transform_non_negative_reals_distribution',
30
+ 'compute_decay_weights',
24
31
  ]
25
32
 
26
33
 
34
+ @dataclasses.dataclass(frozen=True)
35
+ class AdstockDecaySpec:
36
+ """Specification for each channel's adstock decay function.
37
+
38
+ This class contains the adstock decay function(s) to use for each channel
39
+ that the adstock transformation is applied to.
40
+
41
+ Attributes:
42
+ media: A string or sequence of strings specifying the adstock function(s)
43
+ to use for media channels.
44
+ rf: A string or sequence of strings specifying the adstock function(s)
45
+ to use for reach and frequency channels.
46
+ organic_media: A string or sequence of strings specifying the adstock
47
+ function(s) to use for organic media channels.
48
+ organic_rf: A string or sequence of strings specifying the adstock
49
+ function(s) to use for organic reach and frequency channels.
50
+ """
51
+
52
+ media: str | Sequence[str] = constants.GEOMETRIC_DECAY
53
+ rf: str | Sequence[str] = constants.GEOMETRIC_DECAY
54
+ organic_media: str | Sequence[str] = constants.GEOMETRIC_DECAY
55
+ organic_rf: str | Sequence[str] = constants.GEOMETRIC_DECAY
56
+
57
+ @classmethod
58
+ def from_consistent_type(
59
+ cls,
60
+ consistent_decay_function: str = constants.GEOMETRIC_DECAY,
61
+ ) -> 'AdstockDecaySpec':
62
+ """Create an `AdstockDecaySpec` with the same decay function for all channels.
63
+
64
+ Arguments:
65
+ consistent_decay_function: A string denoting the adstock decay function
66
+ to use for all channels that the Adstock transformation is applied to.
67
+
68
+ Raises:
69
+ ValueError: If `consistent_decay_function` is not 'geometric' or
70
+ 'binomial'.
71
+ """
72
+ adstock_decay_functions = dict.fromkeys(
73
+ constants.ADSTOCK_CHANNELS, consistent_decay_function
74
+ )
75
+ return cls(**adstock_decay_functions)
76
+
77
+ def __post_init__(self):
78
+ adstock_decay_functions = [
79
+ self.media,
80
+ self.rf,
81
+ self.organic_media,
82
+ self.organic_rf,
83
+ ]
84
+
85
+ for v in adstock_decay_functions:
86
+ if isinstance(v, str):
87
+ _validate_adstock_decay_function(v)
88
+ else:
89
+ for vi in v:
90
+ _validate_adstock_decay_function(vi)
91
+
92
+
93
+ def _validate_adstock_decay_function(adstock_decay_func: str):
94
+ if adstock_decay_func not in constants.ADSTOCK_DECAY_FUNCTIONS:
95
+ raise ValueError(
96
+ 'Unrecognized adstock decay function value '
97
+ f'("{adstock_decay_func}"), expected one of '
98
+ f'{constants.ADSTOCK_DECAY_FUNCTIONS}.'
99
+ )
100
+
101
+
102
+ def compute_decay_weights(
103
+ alpha: backend.Tensor,
104
+ l_range: backend.Tensor,
105
+ window_size: int,
106
+ decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
107
+ normalize: bool = True,
108
+ ) -> backend.Tensor:
109
+ """Computes decay weights using geometric and/or binomial decay.
110
+
111
+ This function always broadcasts the lag dimension (`l_range`) to the
112
+ trailing axis of the output tensor.
113
+
114
+ Args:
115
+ alpha: The parameter for the adstock decay function.
116
+ l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
117
+ 0]`.
118
+ window_size: The number of time periods that go into the adstock weighted
119
+ average for each output time period.
120
+ decay_functions: String or sequence of strings indicating the decay
121
+ function(s) to use for the Adstock calculation. Allowed values
122
+ are 'geometric' and 'binomial'.
123
+ normalize: A boolean indicating whether to normalize the weights. Default:
124
+ `True`.
125
+
126
+ Returns:
127
+ A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.
128
+
129
+ Raises:
130
+ ValueError: If the shape of `decay_functions` is not broadcastable to
131
+ the shape of `alpha`.
132
+
133
+ """
134
+
135
+ if isinstance(decay_functions, str):
136
+ # Same decay function for all channels
137
+ return _compute_single_decay_function_weights(
138
+ alpha, l_range, window_size, decay_functions, normalize,
139
+ )
140
+
141
+ binomial_weights = _compute_single_decay_function_weights(
142
+ alpha, l_range, window_size, constants.BINOMIAL_DECAY, normalize,
143
+ )
144
+ geometric_weights = _compute_single_decay_function_weights(
145
+ alpha, l_range, window_size, constants.GEOMETRIC_DECAY, normalize,
146
+ )
147
+
148
+ is_binomial = [s == constants.BINOMIAL_DECAY for s in decay_functions]
149
+ binomial_decay_mask = backend.reshape(
150
+ backend.to_tensor(is_binomial, dtype=backend.bool_),
151
+ (-1, 1),
152
+ )
153
+
154
+ try:
155
+ # pytype: disable=bad-return-type
156
+ return backend.where(
157
+ binomial_decay_mask, binomial_weights, geometric_weights
158
+ )
159
+ # pytype: enable=bad-return-type
160
+ except (backend.errors.InvalidArgumentError, ValueError) as e:
161
+ raise ValueError(
162
+ f'The shape of `alpha` ({alpha.shape}) is incompatible with the length'
163
+ f' of `decay_functions` ({len(decay_functions)})'
164
+ ) from e
165
+
166
+
167
+ def _compute_single_decay_function_weights(
168
+ alpha: backend.Tensor,
169
+ l_range: backend.Tensor,
170
+ window_size: int,
171
+ decay_function: str,
172
+ normalize: bool,
173
+ ) -> backend.Tensor:
174
+ """Computes decay weights using geometric decay.
175
+
176
+ This function always broadcasts the lag dimension (`l_range`) to the
177
+ trailing axis of the output tensor.
178
+
179
+ Args:
180
+ alpha: The parameter for the adstock decay function.
181
+ l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
182
+ 0]`.
183
+ window_size: The number of time periods that go into the adstock weighted
184
+ average for each output time period.
185
+ decay_function: String indicating the decay function to use for the
186
+ Adstock calculation. Allowed values are 'geometric' and 'binomial'.
187
+ normalize: A boolean indicating whether to normalize the weights.
188
+
189
+ Returns:
190
+ A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.
191
+ """
192
+ expanded_alpha = backend.expand_dims(alpha, -1)
193
+
194
+ if decay_function == constants.GEOMETRIC_DECAY:
195
+ weights = expanded_alpha**l_range
196
+ elif decay_function == constants.BINOMIAL_DECAY:
197
+ mapped_alpha_binomial = _map_alpha_for_binomial_decay(expanded_alpha)
198
+ weights = (1 - l_range / window_size) ** mapped_alpha_binomial
199
+ else:
200
+ raise ValueError(f'Unsupported decay function: {decay_function}')
201
+
202
+ if normalize:
203
+ normalization_factors = backend.reduce_sum(weights, axis=-1, keepdims=True)
204
+ return backend.divide(weights, normalization_factors)
205
+ return weights
206
+
207
+
27
208
  def _validate_arguments(
28
- media: tf.Tensor, alpha: tf.Tensor, max_lag: int, n_times_output: int
209
+ media: backend.Tensor,
210
+ alpha: backend.Tensor,
211
+ max_lag: int,
212
+ n_times_output: int,
29
213
  ) -> None:
30
214
  batch_dims = alpha.shape[:-1]
31
215
  n_media_times = media.shape[-2]
@@ -35,7 +219,7 @@ def _validate_arguments(
35
219
  '`n_times_output` cannot exceed number of time periods in the media'
36
220
  ' data.'
37
221
  )
38
- if media.shape[:-3] not in [tf.TensorShape([]), tf.TensorShape(batch_dims)]:
222
+ if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
39
223
  raise ValueError(
40
224
  '`media` batch dims do not match `alpha` batch dims. If `media` '
41
225
  'has batch dims, then they must match `alpha`.'
@@ -51,11 +235,12 @@ def _validate_arguments(
51
235
 
52
236
 
53
237
  def _adstock(
54
- media: tf.Tensor,
55
- alpha: tf.Tensor,
238
+ media: backend.Tensor,
239
+ alpha: backend.Tensor,
56
240
  max_lag: int,
57
241
  n_times_output: int,
58
- ) -> tf.Tensor:
242
+ decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
243
+ ) -> backend.Tensor:
59
244
  """Computes the Adstock function."""
60
245
  _validate_arguments(
61
246
  media=media, alpha=alpha, max_lag=max_lag, n_times_output=n_times_output
@@ -91,34 +276,43 @@ def _adstock(
91
276
  + (required_n_media_times - n_media_times,)
92
277
  + (media.shape[-1],)
93
278
  )
94
- media = tf.concat([tf.zeros(pad_shape), media], axis=-2)
279
+ media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
95
280
 
96
281
  # Adstock calculation.
97
282
  window_list = [None] * window_size
98
283
  for i in range(window_size):
99
- window_list[i] = media[..., i:i+n_times_output, :]
100
- windowed = tf.stack(window_list)
101
- l_range = tf.range(window_size - 1, -1, -1, dtype=tf.float32)
102
- weights = tf.expand_dims(alpha, -1) ** l_range
103
- normalization_factors = tf.expand_dims(
104
- (1 - alpha ** (window_size)) / (1 - alpha), -1
284
+ window_list[i] = media[..., i : i + n_times_output, :]
285
+ windowed = backend.stack(window_list)
286
+ l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
287
+ weights = compute_decay_weights(
288
+ alpha=alpha,
289
+ l_range=l_range,
290
+ window_size=window_size,
291
+ decay_functions=decay_functions,
292
+ normalize=True,
105
293
  )
106
- weights = tf.divide(weights, normalization_factors)
107
- return tf.einsum('...mw,w...gtm->...gtm', weights, windowed)
294
+ return backend.einsum('...mw,w...gtm->...gtm', weights, windowed)
295
+
296
+
297
+ def _map_alpha_for_binomial_decay(x: backend.Tensor):
298
+ # Map x -> 1/x - 1 to map [0, 1] to [0, +inf].
299
+ # 0 -> +inf is a valid mapping and reflects the "no adstock" case.
300
+
301
+ return 1 / x - 1
108
302
 
109
303
 
110
304
  def _hill(
111
- media: tf.Tensor,
112
- ec: tf.Tensor,
113
- slope: tf.Tensor,
114
- ) -> tf.Tensor:
305
+ media: backend.Tensor,
306
+ ec: backend.Tensor,
307
+ slope: backend.Tensor,
308
+ ) -> backend.Tensor:
115
309
  """Computes the Hill function."""
116
310
  batch_dims = slope.shape[:-1]
117
311
 
118
312
  # Argument checks.
119
313
  if slope.shape != ec.shape:
120
314
  raise ValueError('`slope` and `ec` dimensions do not match.')
121
- if media.shape[:-3] not in [tf.TensorShape([]), tf.TensorShape(batch_dims)]:
315
+ if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
122
316
  raise ValueError(
123
317
  '`media` batch dims do not match `slope` and `ec` batch dims. '
124
318
  'If `media` has batch dims, then they must match `slope` and '
@@ -129,8 +323,8 @@ def _hill(
129
323
  '`media` contains a different number of channels than `slope` and `ec`.'
130
324
  )
131
325
 
132
- t1 = media ** slope[..., tf.newaxis, tf.newaxis, :]
133
- t2 = (ec**slope)[..., tf.newaxis, tf.newaxis, :]
326
+ t1 = media ** slope[..., backend.newaxis, backend.newaxis, :]
327
+ t2 = (ec**slope)[..., backend.newaxis, backend.newaxis, :]
134
328
  return t1 / (t1 + t2)
135
329
 
136
330
 
@@ -138,24 +332,28 @@ class AdstockHillTransformer(metaclass=abc.ABCMeta):
138
332
  """Abstract class to compute the Adstock and Hill transformation of media."""
139
333
 
140
334
  @abc.abstractmethod
141
- def forward(self, media: tf.Tensor) -> tf.Tensor:
335
+ def forward(self, media: backend.Tensor) -> backend.Tensor:
142
336
  """Computes the Adstock and Hill transformation of a given media tensor."""
143
337
  pass
144
338
 
145
339
 
146
340
  class AdstockTransformer(AdstockHillTransformer):
147
- """Computes the Adstock transformation of media."""
148
-
149
- def __init__(self, alpha: tf.Tensor, max_lag: int, n_times_output: int):
341
+ """Class to compute the Adstock transformation of media."""
342
+
343
+ def __init__(
344
+ self,
345
+ alpha: backend.Tensor,
346
+ max_lag: int,
347
+ n_times_output: int,
348
+ decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
349
+ ):
150
350
  """Initializes this transformer based on Adstock function parameters.
151
351
 
152
352
  Args:
153
- alpha: Tensor of `alpha` parameters taking values `[0, 1)` with
353
+ alpha: Tensor of `alpha` parameters taking values in `[0, 1]` with
154
354
  dimensions `[..., n_media_channels]`. Batch dimensions `(...)` are
155
355
  optional. Note that `alpha = 0` is allowed, so it is possible to put a
156
- point mass prior at zero (effectively no Adstock). However, `alpha = 1`
157
- is not allowed since the geometric sum formula is not defined, and there
158
- is no practical reason to have point mass at `alpha = 1`.
356
+ point mass prior at zero (effectively no Adstock).
159
357
  max_lag: Integer indicating the maximum number of lag periods (≥ `0`) to
160
358
  include in the Adstock calculation.
161
359
  n_times_output: Integer indicating the number of time periods to include
@@ -164,12 +362,16 @@ class AdstockTransformer(AdstockHillTransformer):
164
362
  correspond to the most recent time periods of the media argument. For
165
363
  example, `media[..., -n_times_output:, :]` represents the media
166
364
  execution of the output weeks.
365
+ decay_functions: String or list of strings indicating the decay
366
+ function(s) to use for the Adstock calculation for each channel.
367
+ Default is geometric decay for all channels.
167
368
  """
168
369
  self._alpha = alpha
169
370
  self._max_lag = max_lag
170
371
  self._n_times_output = n_times_output
372
+ self._decay_functions = decay_functions
171
373
 
172
- def forward(self, media: tf.Tensor) -> tf.Tensor:
374
+ def forward(self, media: backend.Tensor) -> backend.Tensor:
173
375
  """Computes the Adstock transformation of a given `media` tensor.
174
376
 
175
377
  For geo `g`, time period `t`, and media channel `m`, Adstock is calculated
@@ -196,13 +398,14 @@ class AdstockTransformer(AdstockHillTransformer):
196
398
  alpha=self._alpha,
197
399
  max_lag=self._max_lag,
198
400
  n_times_output=self._n_times_output,
401
+ decay_functions=self._decay_functions,
199
402
  )
200
403
 
201
404
 
202
405
  class HillTransformer(AdstockHillTransformer):
203
406
  """Class to compute the Hill transformation of media."""
204
407
 
205
- def __init__(self, ec: tf.Tensor, slope: tf.Tensor):
408
+ def __init__(self, ec: backend.Tensor, slope: backend.Tensor):
206
409
  """Initializes the instance based on the Hill function parameters.
207
410
 
208
411
  Args:
@@ -216,7 +419,7 @@ class HillTransformer(AdstockHillTransformer):
216
419
  self._ec = ec
217
420
  self._slope = slope
218
421
 
219
- def forward(self, media: tf.Tensor) -> tf.Tensor:
422
+ def forward(self, media: backend.Tensor) -> backend.Tensor:
220
423
  """Computes the Hill transformation of a given `media` tensor.
221
424
 
222
425
  Calculates results for the Hill function, which accounts for the diminishing
@@ -234,3 +437,47 @@ class HillTransformer(AdstockHillTransformer):
234
437
  representing Hill-transformed media.
235
438
  """
236
439
  return _hill(media=media, ec=self._ec, slope=self._slope)
440
+
441
+
442
+ def transform_non_negative_reals_distribution(
443
+ distribution: backend.tfd.Distribution,
444
+ ) -> backend.tfd.TransformedDistribution:
445
+ """Transforms a distribution with support on `[0, infinity)` to `(0, 1]`.
446
+
447
+ This allows for defining a prior on `alpha_*`, the exponent of the binomial
448
+ Adstock decay function, directly, and then translating it to a distribution
449
+ defined on the unit interval as Meridian expects. This transformation
450
+ `(x -> 1 / (1 + x))` is the inverse of the interval mapping the Meridian
451
+ performs `(x -> 1 / x - 1)` on alpha to define the binomial Adstock
452
+ decay function's exponent.
453
+
454
+ For example, to define a `LogNormal(0.2, 0.9)` prior on `alpha_*`:
455
+
456
+ ```python
457
+ from meridian import backend
458
+ alpha_star_prior = backend.tfd.LogNormal(0.2, 0.9)
459
+ alpha_prior = transform_non_negative_reals_distribution(alpha_star_prior)
460
+ prior = prior_distribution.PriorDistribution(
461
+ alpha_m=alpha_prior,
462
+ ...
463
+ )
464
+ ```
465
+
466
+ Args:
467
+ distribution: A Tensorflow Probability distribution with support on `[0,
468
+ infinity)`.
469
+
470
+ Returns:
471
+ A Tensorflow Probability `TransformedDistribution` with support on `(0, 1]`,
472
+ such that the resultant prior on `alpha_*` is the input distribution.
473
+ """
474
+
475
+ bijector = backend.bijectors.Chain(
476
+ [backend.bijectors.Reciprocal(), backend.bijectors.Shift(1)]
477
+ )
478
+
479
+ return backend.tfd.TransformedDistribution(
480
+ distribution=distribution,
481
+ bijector=bijector,
482
+ name=f'{distribution.name}UnitIntervalMapped',
483
+ )
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """The Meridian API module that performs EDA checks."""
16
+
17
+ from meridian.model.eda import eda_engine