google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,17 +15,200 @@
15
15
  """Function definitions for Adstock and Hill calculations."""
16
16
 
17
17
  import abc
18
- import tensorflow as tf
18
+ from collections.abc import Sequence
19
+ import dataclasses
20
+ from meridian import backend
21
+ from meridian import constants
22
+
19
23
 
20
24
  __all__ = [
25
+ 'AdstockDecaySpec',
21
26
  'AdstockHillTransformer',
22
27
  'AdstockTransformer',
23
28
  'HillTransformer',
29
+ 'transform_non_negative_reals_distribution',
30
+ 'compute_decay_weights',
24
31
  ]
25
32
 
26
33
 
34
@dataclasses.dataclass(frozen=True)
class AdstockDecaySpec:
  """Chooses the adstock decay function(s) for each kind of channel.

  Every attribute holds either one decay function name, or a sequence of
  decay function names. Each name is validated when the instance is created.

  Attributes:
    media: Decay function name(s) for media channels.
    rf: Decay function name(s) for reach and frequency channels.
    organic_media: Decay function name(s) for organic media channels.
    organic_rf: Decay function name(s) for organic reach and frequency
      channels.
  """

  media: str | Sequence[str] = constants.GEOMETRIC_DECAY
  rf: str | Sequence[str] = constants.GEOMETRIC_DECAY
  organic_media: str | Sequence[str] = constants.GEOMETRIC_DECAY
  organic_rf: str | Sequence[str] = constants.GEOMETRIC_DECAY

  @classmethod
  def from_consistent_type(
      cls,
      consistent_decay_function: str = constants.GEOMETRIC_DECAY,
  ) -> 'AdstockDecaySpec':
    """Builds a spec that uses one decay function for every channel kind.

    Args:
      consistent_decay_function: Name of the adstock decay function to assign
        to every entry of `constants.ADSTOCK_CHANNELS`.

    Returns:
      An `AdstockDecaySpec` constructed with `consistent_decay_function` as
      the value of each keyword named in `constants.ADSTOCK_CHANNELS`.

    Raises:
      ValueError: If `consistent_decay_function` is not one of
        `constants.ADSTOCK_DECAY_FUNCTIONS`.
    """
    # NOTE(review): presumably the entries of `constants.ADSTOCK_CHANNELS`
    # are exactly this dataclass's field names; confirm in `constants`.
    return cls(**{
        channel: consistent_decay_function
        for channel in constants.ADSTOCK_CHANNELS
    })

  def __post_init__(self):
    """Validates every configured decay function name."""
    configured = (self.media, self.rf, self.organic_media, self.organic_rf)
    for entry in configured:
      # A bare string is a single name; anything else is iterated as names.
      names = (entry,) if isinstance(entry, str) else entry
      for name in names:
        _validate_adstock_decay_function(name)
91
+
92
+
93
def _validate_adstock_decay_function(adstock_decay_func: str):
  """Checks that a decay function name is recognized.

  Args:
    adstock_decay_func: The decay function name to check.

  Raises:
    ValueError: If `adstock_decay_func` is not a member of
      `constants.ADSTOCK_DECAY_FUNCTIONS`.
  """
  if adstock_decay_func in constants.ADSTOCK_DECAY_FUNCTIONS:
    return

  raise ValueError(
      'Unrecognized adstock decay function value '
      f'("{adstock_decay_func}"), expected one of '
      f'{constants.ADSTOCK_DECAY_FUNCTIONS}.'
  )
100
+
101
+
102
def compute_decay_weights(
    alpha: backend.Tensor,
    l_range: backend.Tensor,
    window_size: int,
    decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
    normalize: bool = True,
) -> backend.Tensor:
  """Computes adstock decay weights, geometric and/or binomial.

  The lag dimension (`l_range`) is always broadcast to the trailing axis of
  the result.

  Args:
    alpha: The parameter for the adstock decay function.
    l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
      0]`.
    window_size: The number of time periods that go into the adstock weighted
      average for each output time period.
    decay_functions: Either one decay function name, applied everywhere, or a
      sequence of names. Allowed values are 'geometric' and 'binomial'.
    normalize: Whether to normalize the weights. Default: `True`.

  Returns:
    A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.

  Raises:
    ValueError: If the shape of `decay_functions` is not broadcastable to
      the shape of `alpha`.
  """

  def _weights_for(decay_function: str) -> backend.Tensor:
    # All arguments other than the decay function are shared by every call.
    return _compute_single_decay_function_weights(
        alpha, l_range, window_size, decay_function, normalize,
    )

  if isinstance(decay_functions, str):
    # A single name means the same decay function is used throughout.
    return _weights_for(decay_functions)

  # With a sequence of names, both candidate weight tensors are computed and
  # the mask below picks between them element-wise.
  binomial_weights = _weights_for(constants.BINOMIAL_DECAY)
  geometric_weights = _weights_for(constants.GEOMETRIC_DECAY)

  is_binomial = backend.to_tensor(decay_functions) == constants.BINOMIAL_DECAY
  # The trailing unit axis lets each entry of the mask broadcast over the lag
  # axis of the weights.
  binomial_decay_mask = backend.reshape(is_binomial, (-1, 1))

  try:
    # pytype: disable=bad-return-type
    return backend.where(
        binomial_decay_mask, binomial_weights, geometric_weights
    )
    # pytype: enable=bad-return-type
  except (backend.errors.InvalidArgumentError, ValueError) as e:
    raise ValueError(
        f'The shape of `alpha` ({alpha.shape}) is incompatible with the length'
        f' of `decay_functions` ({len(decay_functions)})'
    ) from e
164
+
165
+
166
def _compute_single_decay_function_weights(
    alpha: backend.Tensor,
    l_range: backend.Tensor,
    window_size: int,
    decay_function: str,
    normalize: bool,
) -> backend.Tensor:
  """Computes decay weights for one decay function (geometric or binomial).

  The lag dimension (`l_range`) is always broadcast to the trailing axis of
  the result.

  Args:
    alpha: The parameter for the adstock decay function.
    l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
      0]`.
    window_size: The number of time periods that go into the adstock weighted
      average for each output time period.
    decay_function: Name of the decay function to use. Allowed values are
      'geometric' and 'binomial'.
    normalize: Whether to divide the weights by their sum over the lag axis.

  Returns:
    A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.

  Raises:
    ValueError: If `decay_function` is neither `constants.GEOMETRIC_DECAY` nor
      `constants.BINOMIAL_DECAY`.
  """
  # Add a trailing axis to `alpha` so it broadcasts against `l_range`.
  alpha_with_lag_axis = backend.expand_dims(alpha, -1)

  if decay_function == constants.GEOMETRIC_DECAY:
    weights = alpha_with_lag_axis**l_range
  elif decay_function == constants.BINOMIAL_DECAY:
    exponent = _map_alpha_for_binomial_decay(alpha_with_lag_axis)
    weights = (1 - l_range / window_size) ** exponent
  else:
    raise ValueError(f'Unsupported decay function: {decay_function}')

  if not normalize:
    return weights

  # Sum over the lag axis, keeping it so the division broadcasts.
  totals = backend.reduce_sum(weights, axis=-1, keepdims=True)
  return backend.divide(weights, totals)
205
+
206
+
27
207
  def _validate_arguments(
28
- media: tf.Tensor, alpha: tf.Tensor, max_lag: int, n_times_output: int
208
+ media: backend.Tensor,
209
+ alpha: backend.Tensor,
210
+ max_lag: int,
211
+ n_times_output: int,
29
212
  ) -> None:
30
213
  batch_dims = alpha.shape[:-1]
31
214
  n_media_times = media.shape[-2]
@@ -35,7 +218,7 @@ def _validate_arguments(
35
218
  '`n_times_output` cannot exceed number of time periods in the media'
36
219
  ' data.'
37
220
  )
38
- if media.shape[:-3] not in [tf.TensorShape([]), tf.TensorShape(batch_dims)]:
221
+ if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
39
222
  raise ValueError(
40
223
  '`media` batch dims do not match `alpha` batch dims. If `media` '
41
224
  'has batch dims, then they must match `alpha`.'
@@ -51,11 +234,12 @@ def _validate_arguments(
51
234
 
52
235
 
53
236
  def _adstock(
54
- media: tf.Tensor,
55
- alpha: tf.Tensor,
237
+ media: backend.Tensor,
238
+ alpha: backend.Tensor,
56
239
  max_lag: int,
57
240
  n_times_output: int,
58
- ) -> tf.Tensor:
241
+ decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
242
+ ) -> backend.Tensor:
59
243
  """Computes the Adstock function."""
60
244
  _validate_arguments(
61
245
  media=media, alpha=alpha, max_lag=max_lag, n_times_output=n_times_output
@@ -91,34 +275,43 @@ def _adstock(
91
275
  + (required_n_media_times - n_media_times,)
92
276
  + (media.shape[-1],)
93
277
  )
94
- media = tf.concat([tf.zeros(pad_shape), media], axis=-2)
278
+ media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
95
279
 
96
280
  # Adstock calculation.
97
281
  window_list = [None] * window_size
98
282
  for i in range(window_size):
99
- window_list[i] = media[..., i:i+n_times_output, :]
100
- windowed = tf.stack(window_list)
101
- l_range = tf.range(window_size - 1, -1, -1, dtype=tf.float32)
102
- weights = tf.expand_dims(alpha, -1) ** l_range
103
- normalization_factors = tf.expand_dims(
104
- (1 - alpha ** (window_size)) / (1 - alpha), -1
283
+ window_list[i] = media[..., i : i + n_times_output, :]
284
+ windowed = backend.stack(window_list)
285
+ l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
286
+ weights = compute_decay_weights(
287
+ alpha=alpha,
288
+ l_range=l_range,
289
+ window_size=window_size,
290
+ decay_functions=decay_functions,
291
+ normalize=True,
105
292
  )
106
- weights = tf.divide(weights, normalization_factors)
107
- return tf.einsum('...mw,w...gtm->...gtm', weights, windowed)
293
+ return backend.einsum('...mw,w...gtm->...gtm', weights, windowed)
294
+
295
+
296
+ def _map_alpha_for_binomial_decay(x: backend.Tensor):
297
+ # Map x -> 1/x - 1 to map [0, 1] to [0, +inf].
298
+ # 0 -> +inf is a valid mapping and reflects the "no adstock" case.
299
+
300
+ return 1 / x - 1
108
301
 
109
302
 
110
303
  def _hill(
111
- media: tf.Tensor,
112
- ec: tf.Tensor,
113
- slope: tf.Tensor,
114
- ) -> tf.Tensor:
304
+ media: backend.Tensor,
305
+ ec: backend.Tensor,
306
+ slope: backend.Tensor,
307
+ ) -> backend.Tensor:
115
308
  """Computes the Hill function."""
116
309
  batch_dims = slope.shape[:-1]
117
310
 
118
311
  # Argument checks.
119
312
  if slope.shape != ec.shape:
120
313
  raise ValueError('`slope` and `ec` dimensions do not match.')
121
- if media.shape[:-3] not in [tf.TensorShape([]), tf.TensorShape(batch_dims)]:
314
+ if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
122
315
  raise ValueError(
123
316
  '`media` batch dims do not match `slope` and `ec` batch dims. '
124
317
  'If `media` has batch dims, then they must match `slope` and '
@@ -129,8 +322,8 @@ def _hill(
129
322
  '`media` contains a different number of channels than `slope` and `ec`.'
130
323
  )
131
324
 
132
- t1 = media ** slope[..., tf.newaxis, tf.newaxis, :]
133
- t2 = (ec**slope)[..., tf.newaxis, tf.newaxis, :]
325
+ t1 = media ** slope[..., backend.newaxis, backend.newaxis, :]
326
+ t2 = (ec**slope)[..., backend.newaxis, backend.newaxis, :]
134
327
  return t1 / (t1 + t2)
135
328
 
136
329
 
@@ -138,24 +331,28 @@ class AdstockHillTransformer(metaclass=abc.ABCMeta):
138
331
  """Abstract class to compute the Adstock and Hill transformation of media."""
139
332
 
140
333
  @abc.abstractmethod
141
- def forward(self, media: tf.Tensor) -> tf.Tensor:
334
+ def forward(self, media: backend.Tensor) -> backend.Tensor:
142
335
  """Computes the Adstock and Hill transformation of a given media tensor."""
143
336
  pass
144
337
 
145
338
 
146
339
  class AdstockTransformer(AdstockHillTransformer):
147
- """Computes the Adstock transformation of media."""
148
-
149
- def __init__(self, alpha: tf.Tensor, max_lag: int, n_times_output: int):
340
+ """Class to compute the Adstock transformation of media."""
341
+
342
+ def __init__(
343
+ self,
344
+ alpha: backend.Tensor,
345
+ max_lag: int,
346
+ n_times_output: int,
347
+ decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
348
+ ):
150
349
  """Initializes this transformer based on Adstock function parameters.
151
350
 
152
351
  Args:
153
- alpha: Tensor of `alpha` parameters taking values `[0, 1)` with
352
+ alpha: Tensor of `alpha` parameters taking values in `[0, 1]` with
154
353
  dimensions `[..., n_media_channels]`. Batch dimensions `(...)` are
155
354
  optional. Note that `alpha = 0` is allowed, so it is possible to put a
156
- point mass prior at zero (effectively no Adstock). However, `alpha = 1`
157
- is not allowed since the geometric sum formula is not defined, and there
158
- is no practical reason to have point mass at `alpha = 1`.
355
+ point mass prior at zero (effectively no Adstock).
159
356
  max_lag: Integer indicating the maximum number of lag periods (≥ `0`) to
160
357
  include in the Adstock calculation.
161
358
  n_times_output: Integer indicating the number of time periods to include
@@ -164,12 +361,16 @@ class AdstockTransformer(AdstockHillTransformer):
164
361
  correspond to the most recent time periods of the media argument. For
165
362
  example, `media[..., -n_times_output:, :]` represents the media
166
363
  execution of the output weeks.
364
+ decay_functions: String or list of strings indicating the decay
365
+ function(s) to use for the Adstock calculation for each channel.
366
+ Default is geometric decay for all channels.
167
367
  """
168
368
  self._alpha = alpha
169
369
  self._max_lag = max_lag
170
370
  self._n_times_output = n_times_output
371
+ self._decay_functions = decay_functions
171
372
 
172
- def forward(self, media: tf.Tensor) -> tf.Tensor:
373
+ def forward(self, media: backend.Tensor) -> backend.Tensor:
173
374
  """Computes the Adstock transformation of a given `media` tensor.
174
375
 
175
376
  For geo `g`, time period `t`, and media channel `m`, Adstock is calculated
@@ -196,13 +397,14 @@ class AdstockTransformer(AdstockHillTransformer):
196
397
  alpha=self._alpha,
197
398
  max_lag=self._max_lag,
198
399
  n_times_output=self._n_times_output,
400
+ decay_functions=self._decay_functions,
199
401
  )
200
402
 
201
403
 
202
404
  class HillTransformer(AdstockHillTransformer):
203
405
  """Class to compute the Hill transformation of media."""
204
406
 
205
- def __init__(self, ec: tf.Tensor, slope: tf.Tensor):
407
+ def __init__(self, ec: backend.Tensor, slope: backend.Tensor):
206
408
  """Initializes the instance based on the Hill function parameters.
207
409
 
208
410
  Args:
@@ -216,7 +418,7 @@ class HillTransformer(AdstockHillTransformer):
216
418
  self._ec = ec
217
419
  self._slope = slope
218
420
 
219
- def forward(self, media: tf.Tensor) -> tf.Tensor:
421
+ def forward(self, media: backend.Tensor) -> backend.Tensor:
220
422
  """Computes the Hill transformation of a given `media` tensor.
221
423
 
222
424
  Calculates results for the Hill function, which accounts for the diminishing
@@ -234,3 +436,47 @@ class HillTransformer(AdstockHillTransformer):
234
436
  representing Hill-transformed media.
235
437
  """
236
438
  return _hill(media=media, ec=self._ec, slope=self._slope)
439
+
440
+
441
def transform_non_negative_reals_distribution(
    distribution: backend.tfd.Distribution,
) -> backend.tfd.TransformedDistribution:
  """Maps a distribution supported on `[0, infinity)` onto `(0, 1]`.

  Use this to place a prior directly on `alpha_*`, the exponent of the
  binomial Adstock decay function, and then convert it into a distribution on
  the unit interval, which is what Meridian expects. The transformation
  applied here, `(x -> 1 / (1 + x))`, is the inverse of the interval mapping
  `(x -> 1 / x - 1)` that Meridian performs on alpha to obtain the binomial
  Adstock decay function's exponent.

  For example, to define a `LogNormal(0.2, 0.9)` prior on `alpha_*`:

  ```python
  from meridian import backend
  alpha_star_prior = backend.tfd.LogNormal(0.2, 0.9)
  alpha_prior = transform_non_negative_reals_distribution(alpha_star_prior)
  prior = prior_distribution.PriorDistribution(
      alpha_m=alpha_prior,
      ...
  )
  ```

  Args:
    distribution: A Tensorflow Probability distribution with support on `[0,
      infinity)`.

  Returns:
    A Tensorflow Probability `TransformedDistribution` with support on
    `(0, 1]`, such that the resultant prior on `alpha_*` is the input
    distribution. Its name is the input distribution's name followed by
    `UnitIntervalMapped`.
  """
  take_reciprocal = backend.bijectors.Reciprocal()
  add_one = backend.bijectors.Shift(1)
  # `Chain` applies the last bijector in the list first: shift by one, then
  # take the reciprocal, giving x -> 1 / (1 + x).
  to_unit_interval = backend.bijectors.Chain([take_reciprocal, add_one])

  return backend.tfd.TransformedDistribution(
      distribution=distribution,
      bijector=to_unit_interval,
      name=f'{distribution.name}UnitIntervalMapped',
  )
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 The Meridian Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """The Meridian API module that performs EDA checks."""
16
+
17
+ from meridian.model.eda import eda_engine