google-meridian 1.1.6__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/METADATA +8 -2
- google_meridian-1.2.1.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +621 -393
- meridian/analysis/optimizer.py +403 -351
- meridian/analysis/summarizer.py +31 -16
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +53 -54
- meridian/backend/__init__.py +975 -0
- meridian/backend/config.py +118 -0
- meridian/backend/test_utils.py +181 -0
- meridian/constants.py +71 -10
- meridian/data/input_data.py +99 -0
- meridian/data/test_utils.py +146 -12
- meridian/mlflow/autolog.py +2 -2
- meridian/model/adstock_hill.py +280 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +735 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +331 -159
- meridian/model/posterior_sampler.py +388 -383
- meridian/model/prior_distribution.py +612 -177
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +55 -49
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.1.dist-info}/top_level.txt +0 -0
meridian/model/adstock_hill.py
CHANGED
|
@@ -15,17 +15,201 @@
|
|
|
15
15
|
"""Function definitions for Adstock and Hill calculations."""
|
|
16
16
|
|
|
17
17
|
import abc
|
|
18
|
-
|
|
18
|
+
from collections.abc import Sequence
|
|
19
|
+
import dataclasses
|
|
20
|
+
from meridian import backend
|
|
21
|
+
from meridian import constants
|
|
22
|
+
|
|
19
23
|
|
|
20
24
|
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    'AdstockDecaySpec',
    'AdstockHillTransformer',
    'AdstockTransformer',
    'HillTransformer',
    'transform_non_negative_reals_distribution',
    'compute_decay_weights',
]
|
|
25
32
|
|
|
26
33
|
|
|
34
|
+
@dataclasses.dataclass(frozen=True)
class AdstockDecaySpec:
  """Choice of Adstock decay function(s) for each family of channels.

  Each field names the decay function(s) used for one family of channels that
  the Adstock transformation is applied to. A field holds either a single
  string, used for every channel of that family, or a sequence of strings,
  with one entry per channel. Every name is checked against
  `constants.ADSTOCK_DECAY_FUNCTIONS` when an instance is created.

  Attributes:
    media: Decay function name(s) for media channels.
    rf: Decay function name(s) for reach and frequency channels.
    organic_media: Decay function name(s) for organic media channels.
    organic_rf: Decay function name(s) for organic reach and frequency
      channels.
  """

  media: str | Sequence[str] = constants.GEOMETRIC_DECAY
  rf: str | Sequence[str] = constants.GEOMETRIC_DECAY
  organic_media: str | Sequence[str] = constants.GEOMETRIC_DECAY
  organic_rf: str | Sequence[str] = constants.GEOMETRIC_DECAY

  @classmethod
  def from_consistent_type(
      cls,
      consistent_decay_function: str = constants.GEOMETRIC_DECAY,
  ) -> 'AdstockDecaySpec':
    """Builds a spec that uses one decay function for every channel family.

    Args:
      consistent_decay_function: Name of the decay function assigned to each
        field listed in `constants.ADSTOCK_CHANNELS`.

    Returns:
      A new `AdstockDecaySpec` with every field named in
      `constants.ADSTOCK_CHANNELS` set to `consistent_decay_function`.

    Raises:
      ValueError: If `consistent_decay_function` is not one of
        `constants.ADSTOCK_DECAY_FUNCTIONS` (raised during construction).
    """
    per_channel = {
        channel: consistent_decay_function
        for channel in constants.ADSTOCK_CHANNELS
    }
    return cls(**per_channel)

  def __post_init__(self):
    # Validate eagerly so that a bad spec fails at construction time.
    for field_value in (
        self.media,
        self.rf,
        self.organic_media,
        self.organic_rf,
    ):
      # A bare string is itself a Sequence, so wrap it instead of iterating
      # over its characters.
      names = [field_value] if isinstance(field_value, str) else field_value
      for name in names:
        _validate_adstock_decay_function(name)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _validate_adstock_decay_function(adstock_decay_func: str):
  """Checks that an Adstock decay function name is supported.

  Args:
    adstock_decay_func: Candidate decay function name.

  Raises:
    ValueError: If `adstock_decay_func` is not in
      `constants.ADSTOCK_DECAY_FUNCTIONS`.
  """
  supported = constants.ADSTOCK_DECAY_FUNCTIONS
  if adstock_decay_func in supported:
    return
  raise ValueError(
      'Unrecognized adstock decay function value '
      f'("{adstock_decay_func}"), expected one of '
      f'{supported}.'
  )
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def compute_decay_weights(
    alpha: backend.Tensor,
    l_range: backend.Tensor,
    window_size: int,
    decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
    normalize: bool = True,
) -> backend.Tensor:
  """Computes decay weights using geometric and/or binomial decay.

  This function always broadcasts the lag dimension (`l_range`) to the
  trailing axis of the output tensor.

  Args:
    alpha: The parameter for the adstock decay function.
    l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
      0]`.
    window_size: The number of time periods that go into the adstock weighted
      average for each output time period.
    decay_functions: String or sequence of strings indicating the decay
      function(s) to use for the Adstock calculation. A single string applies
      to every channel. A sequence gives one entry per channel, aligned with
      the last axis of `alpha`. Allowed values are 'geometric' and
      'binomial'.
    normalize: A boolean indicating whether to normalize the weights. Default:
      `True`.

  Returns:
    A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.

  Raises:
    ValueError: If `decay_functions` (or any entry of it) is not a recognized
      decay function, or if the length of `decay_functions` is not
      broadcastable to the channel axis of `alpha`.
  """
  if isinstance(decay_functions, str):
    # Same decay function for all channels.
    return _compute_single_decay_function_weights(
        alpha, l_range, window_size, decay_functions, normalize,
    )

  # Reject unknown names up front. Without this check, any entry that is not
  # exactly 'binomial' (including typos) would silently fall back to
  # geometric decay through the mask below. The single-string path above
  # raises in that situation, so this keeps both paths consistent.
  for decay_function in decay_functions:
    _validate_adstock_decay_function(decay_function)

  # Compute both candidate weightings for every channel, then select per
  # channel with the mask.
  binomial_weights = _compute_single_decay_function_weights(
      alpha, l_range, window_size, constants.BINOMIAL_DECAY, normalize,
  )
  geometric_weights = _compute_single_decay_function_weights(
      alpha, l_range, window_size, constants.GEOMETRIC_DECAY, normalize,
  )

  # Shape (n_entries, 1): aligns with the channel axis of the weights and
  # broadcasts across the trailing lag axis.
  is_binomial = [s == constants.BINOMIAL_DECAY for s in decay_functions]
  binomial_decay_mask = backend.reshape(
      backend.to_tensor(is_binomial, dtype=backend.bool_),
      (-1, 1),
  )

  try:
    # pytype: disable=bad-return-type
    return backend.where(
        binomial_decay_mask, binomial_weights, geometric_weights
    )
    # pytype: enable=bad-return-type
  except (backend.errors.InvalidArgumentError, ValueError) as e:
    raise ValueError(
        f'The shape of `alpha` ({alpha.shape}) is incompatible with the length'
        f' of `decay_functions` ({len(decay_functions)})'
    ) from e
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def _compute_single_decay_function_weights(
    alpha: backend.Tensor,
    l_range: backend.Tensor,
    window_size: int,
    decay_function: str,
    normalize: bool,
) -> backend.Tensor:
  """Computes decay weights for a single decay function (geometric or binomial).

  For lag `l` (taken from `l_range`), the unnormalized weight is:

  * geometric: `alpha ** l`
  * binomial: `(1 - l / window_size) ** (1 / alpha - 1)`

  This function always broadcasts the lag dimension (`l_range`) to the
  trailing axis of the output tensor.

  Args:
    alpha: The parameter for the adstock decay function.
    l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
      0]`.
    window_size: The number of time periods that go into the adstock weighted
      average for each output time period.
    decay_function: String indicating the decay function to use for the
      Adstock calculation. Allowed values are 'geometric' and 'binomial'.
    normalize: A boolean indicating whether to normalize the weights so that
      they sum to one along the lag axis.

  Returns:
    A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.

  Raises:
    ValueError: If `decay_function` is neither `constants.GEOMETRIC_DECAY` nor
      `constants.BINOMIAL_DECAY`.
  """
  # Trailing singleton axis so alpha broadcasts against the lag axis.
  expanded_alpha = backend.expand_dims(alpha, -1)

  if decay_function == constants.GEOMETRIC_DECAY:
    weights = expanded_alpha**l_range
  elif decay_function == constants.BINOMIAL_DECAY:
    mapped_alpha_binomial = _map_alpha_for_binomial_decay(expanded_alpha)
    weights = (1 - l_range / window_size) ** mapped_alpha_binomial
  else:
    raise ValueError(f'Unsupported decay function: {decay_function}')

  if normalize:
    normalization_factors = backend.reduce_sum(weights, axis=-1, keepdims=True)
    return backend.divide(weights, normalization_factors)
  return weights
|
|
206
|
+
|
|
207
|
+
|
|
27
208
|
def _validate_arguments(
|
|
28
|
-
media:
|
|
209
|
+
media: backend.Tensor,
|
|
210
|
+
alpha: backend.Tensor,
|
|
211
|
+
max_lag: int,
|
|
212
|
+
n_times_output: int,
|
|
29
213
|
) -> None:
|
|
30
214
|
batch_dims = alpha.shape[:-1]
|
|
31
215
|
n_media_times = media.shape[-2]
|
|
@@ -35,7 +219,7 @@ def _validate_arguments(
|
|
|
35
219
|
'`n_times_output` cannot exceed number of time periods in the media'
|
|
36
220
|
' data.'
|
|
37
221
|
)
|
|
38
|
-
if media.shape[:-3] not in [
|
|
222
|
+
if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
|
|
39
223
|
raise ValueError(
|
|
40
224
|
'`media` batch dims do not match `alpha` batch dims. If `media` '
|
|
41
225
|
'has batch dims, then they must match `alpha`.'
|
|
@@ -51,11 +235,12 @@ def _validate_arguments(
|
|
|
51
235
|
|
|
52
236
|
|
|
53
237
|
def _adstock(
|
|
54
|
-
media:
|
|
55
|
-
alpha:
|
|
238
|
+
media: backend.Tensor,
|
|
239
|
+
alpha: backend.Tensor,
|
|
56
240
|
max_lag: int,
|
|
57
241
|
n_times_output: int,
|
|
58
|
-
|
|
242
|
+
decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
|
|
243
|
+
) -> backend.Tensor:
|
|
59
244
|
"""Computes the Adstock function."""
|
|
60
245
|
_validate_arguments(
|
|
61
246
|
media=media, alpha=alpha, max_lag=max_lag, n_times_output=n_times_output
|
|
@@ -91,34 +276,43 @@ def _adstock(
|
|
|
91
276
|
+ (required_n_media_times - n_media_times,)
|
|
92
277
|
+ (media.shape[-1],)
|
|
93
278
|
)
|
|
94
|
-
media =
|
|
279
|
+
media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
|
|
95
280
|
|
|
96
281
|
# Adstock calculation.
|
|
97
282
|
window_list = [None] * window_size
|
|
98
283
|
for i in range(window_size):
|
|
99
|
-
window_list[i] = media[..., i:i+n_times_output, :]
|
|
100
|
-
windowed =
|
|
101
|
-
l_range =
|
|
102
|
-
weights =
|
|
103
|
-
|
|
104
|
-
|
|
284
|
+
window_list[i] = media[..., i : i + n_times_output, :]
|
|
285
|
+
windowed = backend.stack(window_list)
|
|
286
|
+
l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
|
|
287
|
+
weights = compute_decay_weights(
|
|
288
|
+
alpha=alpha,
|
|
289
|
+
l_range=l_range,
|
|
290
|
+
window_size=window_size,
|
|
291
|
+
decay_functions=decay_functions,
|
|
292
|
+
normalize=True,
|
|
105
293
|
)
|
|
106
|
-
|
|
107
|
-
|
|
294
|
+
return backend.einsum('...mw,w...gtm->...gtm', weights, windowed)
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _map_alpha_for_binomial_decay(x: backend.Tensor):
  """Converts `alpha` into the exponent of the binomial decay function.

  Applies `x -> 1 / x - 1`, which sends the unit interval `[0, 1]` onto
  `[0, +inf]`: `x = 1` maps to `0`, and `x = 0` maps to `+inf` (for
  floating-point tensors). The `+inf` case is intentional and represents the
  "no Adstock" case.
  """
  reciprocal = 1 / x
  return reciprocal - 1
|
|
108
302
|
|
|
109
303
|
|
|
110
304
|
def _hill(
|
|
111
|
-
media:
|
|
112
|
-
ec:
|
|
113
|
-
slope:
|
|
114
|
-
) ->
|
|
305
|
+
media: backend.Tensor,
|
|
306
|
+
ec: backend.Tensor,
|
|
307
|
+
slope: backend.Tensor,
|
|
308
|
+
) -> backend.Tensor:
|
|
115
309
|
"""Computes the Hill function."""
|
|
116
310
|
batch_dims = slope.shape[:-1]
|
|
117
311
|
|
|
118
312
|
# Argument checks.
|
|
119
313
|
if slope.shape != ec.shape:
|
|
120
314
|
raise ValueError('`slope` and `ec` dimensions do not match.')
|
|
121
|
-
if media.shape[:-3] not in [
|
|
315
|
+
if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
|
|
122
316
|
raise ValueError(
|
|
123
317
|
'`media` batch dims do not match `slope` and `ec` batch dims. '
|
|
124
318
|
'If `media` has batch dims, then they must match `slope` and '
|
|
@@ -129,8 +323,8 @@ def _hill(
|
|
|
129
323
|
'`media` contains a different number of channels than `slope` and `ec`.'
|
|
130
324
|
)
|
|
131
325
|
|
|
132
|
-
t1 = media ** slope[...,
|
|
133
|
-
t2 = (ec**slope)[...,
|
|
326
|
+
t1 = media ** slope[..., backend.newaxis, backend.newaxis, :]
|
|
327
|
+
t2 = (ec**slope)[..., backend.newaxis, backend.newaxis, :]
|
|
134
328
|
return t1 / (t1 + t2)
|
|
135
329
|
|
|
136
330
|
|
|
@@ -138,24 +332,28 @@ class AdstockHillTransformer(metaclass=abc.ABCMeta):
|
|
|
138
332
|
"""Abstract class to compute the Adstock and Hill transformation of media."""
|
|
139
333
|
|
|
140
334
|
@abc.abstractmethod
|
|
141
|
-
def forward(self, media:
|
|
335
|
+
def forward(self, media: backend.Tensor) -> backend.Tensor:
|
|
142
336
|
"""Computes the Adstock and Hill transformation of a given media tensor."""
|
|
143
337
|
pass
|
|
144
338
|
|
|
145
339
|
|
|
146
340
|
class AdstockTransformer(AdstockHillTransformer):
|
|
147
|
-
"""
|
|
148
|
-
|
|
149
|
-
def __init__(
|
|
341
|
+
"""Class to compute the Adstock transformation of media."""
|
|
342
|
+
|
|
343
|
+
def __init__(
|
|
344
|
+
self,
|
|
345
|
+
alpha: backend.Tensor,
|
|
346
|
+
max_lag: int,
|
|
347
|
+
n_times_output: int,
|
|
348
|
+
decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
|
|
349
|
+
):
|
|
150
350
|
"""Initializes this transformer based on Adstock function parameters.
|
|
151
351
|
|
|
152
352
|
Args:
|
|
153
|
-
alpha: Tensor of `alpha` parameters taking values
|
|
353
|
+
alpha: Tensor of `alpha` parameters taking values in `[0, 1]` with
|
|
154
354
|
dimensions `[..., n_media_channels]`. Batch dimensions `(...)` are
|
|
155
355
|
optional. Note that `alpha = 0` is allowed, so it is possible to put a
|
|
156
|
-
point mass prior at zero (effectively no Adstock).
|
|
157
|
-
is not allowed since the geometric sum formula is not defined, and there
|
|
158
|
-
is no practical reason to have point mass at `alpha = 1`.
|
|
356
|
+
point mass prior at zero (effectively no Adstock).
|
|
159
357
|
max_lag: Integer indicating the maximum number of lag periods (≥ `0`) to
|
|
160
358
|
include in the Adstock calculation.
|
|
161
359
|
n_times_output: Integer indicating the number of time periods to include
|
|
@@ -164,12 +362,16 @@ class AdstockTransformer(AdstockHillTransformer):
|
|
|
164
362
|
correspond to the most recent time periods of the media argument. For
|
|
165
363
|
example, `media[..., -n_times_output:, :]` represents the media
|
|
166
364
|
execution of the output weeks.
|
|
365
|
+
decay_functions: String or list of strings indicating the decay
|
|
366
|
+
function(s) to use for the Adstock calculation for each channel.
|
|
367
|
+
Default is geometric decay for all channels.
|
|
167
368
|
"""
|
|
168
369
|
self._alpha = alpha
|
|
169
370
|
self._max_lag = max_lag
|
|
170
371
|
self._n_times_output = n_times_output
|
|
372
|
+
self._decay_functions = decay_functions
|
|
171
373
|
|
|
172
|
-
def forward(self, media:
|
|
374
|
+
def forward(self, media: backend.Tensor) -> backend.Tensor:
|
|
173
375
|
"""Computes the Adstock transformation of a given `media` tensor.
|
|
174
376
|
|
|
175
377
|
For geo `g`, time period `t`, and media channel `m`, Adstock is calculated
|
|
@@ -196,13 +398,14 @@ class AdstockTransformer(AdstockHillTransformer):
|
|
|
196
398
|
alpha=self._alpha,
|
|
197
399
|
max_lag=self._max_lag,
|
|
198
400
|
n_times_output=self._n_times_output,
|
|
401
|
+
decay_functions=self._decay_functions,
|
|
199
402
|
)
|
|
200
403
|
|
|
201
404
|
|
|
202
405
|
class HillTransformer(AdstockHillTransformer):
|
|
203
406
|
"""Class to compute the Hill transformation of media."""
|
|
204
407
|
|
|
205
|
-
def __init__(self, ec:
|
|
408
|
+
def __init__(self, ec: backend.Tensor, slope: backend.Tensor):
|
|
206
409
|
"""Initializes the instance based on the Hill function parameters.
|
|
207
410
|
|
|
208
411
|
Args:
|
|
@@ -216,7 +419,7 @@ class HillTransformer(AdstockHillTransformer):
|
|
|
216
419
|
self._ec = ec
|
|
217
420
|
self._slope = slope
|
|
218
421
|
|
|
219
|
-
def forward(self, media:
|
|
422
|
+
def forward(self, media: backend.Tensor) -> backend.Tensor:
|
|
220
423
|
"""Computes the Hill transformation of a given `media` tensor.
|
|
221
424
|
|
|
222
425
|
Calculates results for the Hill function, which accounts for the diminishing
|
|
@@ -234,3 +437,47 @@ class HillTransformer(AdstockHillTransformer):
|
|
|
234
437
|
representing Hill-transformed media.
|
|
235
438
|
"""
|
|
236
439
|
return _hill(media=media, ec=self._ec, slope=self._slope)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def transform_non_negative_reals_distribution(
    distribution: backend.tfd.Distribution,
) -> backend.tfd.TransformedDistribution:
  """Transforms a distribution with support on `[0, infinity)` to `(0, 1]`.

  Binomial Adstock decay raises its base to the exponent
  `alpha_* = 1 / alpha - 1`, while Meridian expects `alpha` itself to be
  defined on the unit interval. This function lets you specify a prior
  directly on `alpha_*` and turns it into the equivalent prior on `alpha`. It
  pushes the input distribution through `x -> 1 / (1 + x)`, which is the
  inverse of the mapping `alpha -> 1 / alpha - 1` that Meridian applies to
  `alpha` to obtain the binomial decay exponent.

  For example, to define a `LogNormal(0.2, 0.9)` prior on `alpha_*`:

  ```python
  from meridian import backend
  from meridian.model import prior_distribution

  alpha_star_prior = backend.tfd.LogNormal(0.2, 0.9)
  alpha_prior = transform_non_negative_reals_distribution(alpha_star_prior)
  prior = prior_distribution.PriorDistribution(
      alpha_m=alpha_prior,
      ...
  )
  ```

  Args:
    distribution: A TensorFlow Probability distribution with support on `[0,
      infinity)`.

  Returns:
    A TensorFlow Probability `TransformedDistribution` with support on
    `(0, 1]`, such that the implied prior on `alpha_*` is the input
    distribution.
  """
  # `Chain` applies its bijectors right-to-left, so the forward map is
  # Reciprocal(Shift(1)(x)) = 1 / (x + 1).
  bijector = backend.bijectors.Chain(
      [backend.bijectors.Reciprocal(), backend.bijectors.Shift(1)]
  )

  return backend.tfd.TransformedDistribution(
      distribution=distribution,
      bijector=bijector,
      name=f'{distribution.name}UnitIntervalMapped',
  )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""The Meridian API module that performs EDA checks."""
|
|
16
|
+
|
|
17
|
+
from meridian.model.eda import eda_engine
|