google-meridian 1.1.5__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/METADATA +8 -2
- google_meridian-1.2.0.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +526 -362
- meridian/analysis/optimizer.py +275 -267
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +37 -49
- meridian/backend/__init__.py +514 -0
- meridian/backend/config.py +59 -0
- meridian/backend/test_utils.py +95 -0
- meridian/constants.py +59 -3
- meridian/data/input_data.py +94 -0
- meridian/data/test_utils.py +144 -12
- meridian/model/adstock_hill.py +279 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +306 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +323 -157
- meridian/model/posterior_sampler.py +84 -77
- meridian/model/prior_distribution.py +538 -168
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +53 -47
- meridian/version.py +1 -1
- google_meridian-1.1.5.dist-info/RECORD +0 -47
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.5.dist-info → google_meridian-1.2.0.dist-info}/top_level.txt +0 -0
meridian/model/adstock_hill.py
CHANGED
|
@@ -15,17 +15,200 @@
|
|
|
15
15
|
"""Function definitions for Adstock and Hill calculations."""
|
|
16
16
|
|
|
17
17
|
import abc
|
|
18
|
-
|
|
18
|
+
from collections.abc import Sequence
|
|
19
|
+
import dataclasses
|
|
20
|
+
from meridian import backend
|
|
21
|
+
from meridian import constants
|
|
22
|
+
|
|
19
23
|
|
|
20
24
|
__all__ = [
|
|
25
|
+
'AdstockDecaySpec',
|
|
21
26
|
'AdstockHillTransformer',
|
|
22
27
|
'AdstockTransformer',
|
|
23
28
|
'HillTransformer',
|
|
29
|
+
'transform_non_negative_reals_distribution',
|
|
30
|
+
'compute_decay_weights',
|
|
24
31
|
]
|
|
25
32
|
|
|
26
33
|
|
|
34
|
+
@dataclasses.dataclass(frozen=True)
class AdstockDecaySpec:
  """Choice of Adstock decay function(s) for each family of channels.

  Each attribute names the decay function(s) that the Adstock transformation
  uses for one family of channels. A value is either a single string or a
  sequence of strings (presumably one entry per channel of that family —
  confirm against the model code that consumes this spec). Every value is
  checked against `constants.ADSTOCK_DECAY_FUNCTIONS` when the instance is
  created, and a `ValueError` is raised for unrecognized names.

  Attributes:
    media: A string or sequence of strings specifying the adstock function(s)
      to use for media channels.
    rf: A string or sequence of strings specifying the adstock function(s)
      to use for reach and frequency channels.
    organic_media: A string or sequence of strings specifying the adstock
      function(s) to use for organic media channels.
    organic_rf: A string or sequence of strings specifying the adstock
      function(s) to use for organic reach and frequency channels.
  """

  media: str | Sequence[str] = constants.GEOMETRIC_DECAY
  rf: str | Sequence[str] = constants.GEOMETRIC_DECAY
  organic_media: str | Sequence[str] = constants.GEOMETRIC_DECAY
  organic_rf: str | Sequence[str] = constants.GEOMETRIC_DECAY

  @classmethod
  def from_consistent_type(
      cls,
      consistent_decay_function: str = constants.GEOMETRIC_DECAY,
  ) -> 'AdstockDecaySpec':
    """Builds an `AdstockDecaySpec` using one decay function for all channels.

    Args:
      consistent_decay_function: A string naming the adstock decay function
        to use for all channels that the Adstock transformation is applied to.

    Returns:
      An `AdstockDecaySpec` whose fields named in `constants.ADSTOCK_CHANNELS`
      are all set to `consistent_decay_function`.

    Raises:
      ValueError: If `consistent_decay_function` is not 'geometric' or
        'binomial'.
    """
    per_channel_type = {
        channel_type: consistent_decay_function
        for channel_type in constants.ADSTOCK_CHANNELS
    }
    return cls(**per_channel_type)

  def __post_init__(self):
    # Treat a bare string as a one-element collection so that both accepted
    # forms go through the same validation loop.
    for field_value in (self.media, self.rf, self.organic_media,
                        self.organic_rf):
      names = (field_value,) if isinstance(field_value, str) else field_value
      for name in names:
        _validate_adstock_decay_function(name)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _validate_adstock_decay_function(adstock_decay_func: str):
  """Checks that `adstock_decay_func` names a known Adstock decay function.

  Args:
    adstock_decay_func: Candidate decay function name.

  Raises:
    ValueError: If `adstock_decay_func` is not one of
      `constants.ADSTOCK_DECAY_FUNCTIONS`.
  """
  allowed = constants.ADSTOCK_DECAY_FUNCTIONS
  if adstock_decay_func in allowed:
    return
  raise ValueError(
      'Unrecognized adstock decay function value '
      f'("{adstock_decay_func}"), expected one of '
      f'{allowed}.'
  )
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def compute_decay_weights(
    alpha: backend.Tensor,
    l_range: backend.Tensor,
    window_size: int,
    decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
    normalize: bool = True,
) -> backend.Tensor:
  """Computes decay weights using geometric and/or binomial decay.

  This function always broadcasts the lag dimension (`l_range`) to the
  trailing axis of the output tensor.

  Args:
    alpha: The parameter for the adstock decay function.
    l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
      0]`.
    window_size: The number of time periods that go into the adstock weighted
      average for each output time period.
    decay_functions: String or sequence of strings indicating the decay
      function(s) to use for the Adstock calculation. Allowed values are
      'geometric' and 'binomial'. When a sequence is given, entry `i` selects
      the decay function for position `i` along the last axis of `alpha`.
    normalize: A boolean indicating whether to normalize the weights. Default:
      `True`.

  Returns:
    A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.

  Raises:
    ValueError: If `decay_functions` is an empty sequence, contains a value
      other than 'geometric' or 'binomial', or if its length is not
      broadcastable to the shape of `alpha`.
  """
  if isinstance(decay_functions, str):
    # Same decay function for all channels. Unsupported values are rejected
    # by `_compute_single_decay_function_weights`.
    return _compute_single_decay_function_weights(
        alpha, l_range, window_size, decay_functions, normalize,
    )

  decay_functions = tuple(decay_functions)
  if not decay_functions:
    raise ValueError('`decay_functions` must not be empty.')

  # Validate every entry up front. Without this check, any value other than
  # 'binomial' (e.g. a typo) would silently be treated as geometric decay by
  # the mask below, while the single-string path raises for the same value.
  supported = (constants.GEOMETRIC_DECAY, constants.BINOMIAL_DECAY)
  for decay_function in decay_functions:
    if decay_function not in supported:
      raise ValueError(f'Unsupported decay function: {decay_function}')

  binomial_weights = _compute_single_decay_function_weights(
      alpha, l_range, window_size, constants.BINOMIAL_DECAY, normalize,
  )
  geometric_weights = _compute_single_decay_function_weights(
      alpha, l_range, window_size, constants.GEOMETRIC_DECAY, normalize,
  )

  # Compare names in Python so that no string-typed tensor has to be created
  # (not every array backend supports string dtypes). The mask is reshaped to
  # `(n, 1)` so it lines up with alpha's last axis and broadcasts over lags.
  binomial_decay_mask = backend.reshape(
      backend.to_tensor(
          [f == constants.BINOMIAL_DECAY for f in decay_functions]
      ),
      (-1, 1),
  )

  try:
    # pytype: disable=bad-return-type
    return backend.where(
        binomial_decay_mask, binomial_weights, geometric_weights
    )
    # pytype: enable=bad-return-type
  except (backend.errors.InvalidArgumentError, ValueError) as e:
    raise ValueError(
        f'The shape of `alpha` ({alpha.shape}) is incompatible with the length'
        f' of `decay_functions` ({len(decay_functions)})'
    ) from e
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _compute_single_decay_function_weights(
    alpha: backend.Tensor,
    l_range: backend.Tensor,
    window_size: int,
    decay_function: str,
    normalize: bool,
) -> backend.Tensor:
  """Computes decay weights for one decay function (geometric or binomial).

  The lag dimension (`l_range`) always becomes the trailing axis of the
  result. Per lag value `l`:

  * geometric: `alpha ** l`
  * binomial:  `(1 - l / window_size) ** (1 / alpha - 1)`

  Args:
    alpha: The parameter for the adstock decay function.
    l_range: A 1D tensor representing the lag range, e.g., `[w-1, w-2, ...,
      0]`.
    window_size: The number of time periods that go into the adstock weighted
      average for each output time period.
    decay_function: String indicating the decay function to use for the
      Adstock calculation. Allowed values are 'geometric' and 'binomial'.
    normalize: Whether to rescale the weights so they sum to one along the
      lag axis.

  Returns:
    A tensor of weights with a shape of `(*alpha.shape, len(l_range))`.

  Raises:
    ValueError: If `decay_function` is neither 'geometric' nor 'binomial'.
  """
  # A trailing singleton axis lets `alpha` broadcast against the 1D lag range.
  alpha_col = backend.expand_dims(alpha, -1)

  if decay_function == constants.GEOMETRIC_DECAY:
    weights = alpha_col**l_range
  elif decay_function == constants.BINOMIAL_DECAY:
    exponent = _map_alpha_for_binomial_decay(alpha_col)
    weights = (1 - l_range / window_size) ** exponent
  else:
    raise ValueError(f'Unsupported decay function: {decay_function}')

  if not normalize:
    return weights
  lag_totals = backend.reduce_sum(weights, axis=-1, keepdims=True)
  return backend.divide(weights, lag_totals)
|
|
205
|
+
|
|
206
|
+
|
|
27
207
|
def _validate_arguments(
|
|
28
|
-
media:
|
|
208
|
+
media: backend.Tensor,
|
|
209
|
+
alpha: backend.Tensor,
|
|
210
|
+
max_lag: int,
|
|
211
|
+
n_times_output: int,
|
|
29
212
|
) -> None:
|
|
30
213
|
batch_dims = alpha.shape[:-1]
|
|
31
214
|
n_media_times = media.shape[-2]
|
|
@@ -35,7 +218,7 @@ def _validate_arguments(
|
|
|
35
218
|
'`n_times_output` cannot exceed number of time periods in the media'
|
|
36
219
|
' data.'
|
|
37
220
|
)
|
|
38
|
-
if media.shape[:-3] not in [
|
|
221
|
+
if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
|
|
39
222
|
raise ValueError(
|
|
40
223
|
'`media` batch dims do not match `alpha` batch dims. If `media` '
|
|
41
224
|
'has batch dims, then they must match `alpha`.'
|
|
@@ -51,11 +234,12 @@ def _validate_arguments(
|
|
|
51
234
|
|
|
52
235
|
|
|
53
236
|
def _adstock(
|
|
54
|
-
media:
|
|
55
|
-
alpha:
|
|
237
|
+
media: backend.Tensor,
|
|
238
|
+
alpha: backend.Tensor,
|
|
56
239
|
max_lag: int,
|
|
57
240
|
n_times_output: int,
|
|
58
|
-
|
|
241
|
+
decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
|
|
242
|
+
) -> backend.Tensor:
|
|
59
243
|
"""Computes the Adstock function."""
|
|
60
244
|
_validate_arguments(
|
|
61
245
|
media=media, alpha=alpha, max_lag=max_lag, n_times_output=n_times_output
|
|
@@ -91,34 +275,43 @@ def _adstock(
|
|
|
91
275
|
+ (required_n_media_times - n_media_times,)
|
|
92
276
|
+ (media.shape[-1],)
|
|
93
277
|
)
|
|
94
|
-
media =
|
|
278
|
+
media = backend.concatenate([backend.zeros(pad_shape), media], axis=-2)
|
|
95
279
|
|
|
96
280
|
# Adstock calculation.
|
|
97
281
|
window_list = [None] * window_size
|
|
98
282
|
for i in range(window_size):
|
|
99
|
-
window_list[i] = media[..., i:i+n_times_output, :]
|
|
100
|
-
windowed =
|
|
101
|
-
l_range =
|
|
102
|
-
weights =
|
|
103
|
-
|
|
104
|
-
|
|
283
|
+
window_list[i] = media[..., i : i + n_times_output, :]
|
|
284
|
+
windowed = backend.stack(window_list)
|
|
285
|
+
l_range = backend.arange(window_size - 1, -1, -1, dtype=backend.float32)
|
|
286
|
+
weights = compute_decay_weights(
|
|
287
|
+
alpha=alpha,
|
|
288
|
+
l_range=l_range,
|
|
289
|
+
window_size=window_size,
|
|
290
|
+
decay_functions=decay_functions,
|
|
291
|
+
normalize=True,
|
|
105
292
|
)
|
|
106
|
-
|
|
107
|
-
|
|
293
|
+
return backend.einsum('...mw,w...gtm->...gtm', weights, windowed)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _map_alpha_for_binomial_decay(x: backend.Tensor):
  """Maps `alpha` in `[0, 1]` to the binomial decay exponent in `[0, +inf]`.

  Applies `x -> 1 / x - 1`. `x = 1` maps to `0`, and `x = 0` maps to `+inf`.
  The latter is a valid mapping and reflects the "no adstock" case.

  Args:
    x: `alpha` values taking values in `[0, 1]`.

  Returns:
    The element-wise value of `1 / x - 1`.
  """
  reciprocal = 1 / x
  return reciprocal - 1
|
|
108
301
|
|
|
109
302
|
|
|
110
303
|
def _hill(
|
|
111
|
-
media:
|
|
112
|
-
ec:
|
|
113
|
-
slope:
|
|
114
|
-
) ->
|
|
304
|
+
media: backend.Tensor,
|
|
305
|
+
ec: backend.Tensor,
|
|
306
|
+
slope: backend.Tensor,
|
|
307
|
+
) -> backend.Tensor:
|
|
115
308
|
"""Computes the Hill function."""
|
|
116
309
|
batch_dims = slope.shape[:-1]
|
|
117
310
|
|
|
118
311
|
# Argument checks.
|
|
119
312
|
if slope.shape != ec.shape:
|
|
120
313
|
raise ValueError('`slope` and `ec` dimensions do not match.')
|
|
121
|
-
if media.shape[:-3] not in [
|
|
314
|
+
if tuple(media.shape[:-3]) not in [(), tuple(batch_dims)]:
|
|
122
315
|
raise ValueError(
|
|
123
316
|
'`media` batch dims do not match `slope` and `ec` batch dims. '
|
|
124
317
|
'If `media` has batch dims, then they must match `slope` and '
|
|
@@ -129,8 +322,8 @@ def _hill(
|
|
|
129
322
|
'`media` contains a different number of channels than `slope` and `ec`.'
|
|
130
323
|
)
|
|
131
324
|
|
|
132
|
-
t1 = media ** slope[...,
|
|
133
|
-
t2 = (ec**slope)[...,
|
|
325
|
+
t1 = media ** slope[..., backend.newaxis, backend.newaxis, :]
|
|
326
|
+
t2 = (ec**slope)[..., backend.newaxis, backend.newaxis, :]
|
|
134
327
|
return t1 / (t1 + t2)
|
|
135
328
|
|
|
136
329
|
|
|
@@ -138,24 +331,28 @@ class AdstockHillTransformer(metaclass=abc.ABCMeta):
|
|
|
138
331
|
"""Abstract class to compute the Adstock and Hill transformation of media."""
|
|
139
332
|
|
|
140
333
|
@abc.abstractmethod
|
|
141
|
-
def forward(self, media:
|
|
334
|
+
def forward(self, media: backend.Tensor) -> backend.Tensor:
|
|
142
335
|
"""Computes the Adstock and Hill transformation of a given media tensor."""
|
|
143
336
|
pass
|
|
144
337
|
|
|
145
338
|
|
|
146
339
|
class AdstockTransformer(AdstockHillTransformer):
|
|
147
|
-
"""
|
|
148
|
-
|
|
149
|
-
def __init__(
|
|
340
|
+
"""Class to compute the Adstock transformation of media."""
|
|
341
|
+
|
|
342
|
+
def __init__(
|
|
343
|
+
self,
|
|
344
|
+
alpha: backend.Tensor,
|
|
345
|
+
max_lag: int,
|
|
346
|
+
n_times_output: int,
|
|
347
|
+
decay_functions: str | Sequence[str] = constants.GEOMETRIC_DECAY,
|
|
348
|
+
):
|
|
150
349
|
"""Initializes this transformer based on Adstock function parameters.
|
|
151
350
|
|
|
152
351
|
Args:
|
|
153
|
-
alpha: Tensor of `alpha` parameters taking values
|
|
352
|
+
alpha: Tensor of `alpha` parameters taking values in `[0, 1]` with
|
|
154
353
|
dimensions `[..., n_media_channels]`. Batch dimensions `(...)` are
|
|
155
354
|
optional. Note that `alpha = 0` is allowed, so it is possible to put a
|
|
156
|
-
point mass prior at zero (effectively no Adstock).
|
|
157
|
-
is not allowed since the geometric sum formula is not defined, and there
|
|
158
|
-
is no practical reason to have point mass at `alpha = 1`.
|
|
355
|
+
point mass prior at zero (effectively no Adstock).
|
|
159
356
|
max_lag: Integer indicating the maximum number of lag periods (≥ `0`) to
|
|
160
357
|
include in the Adstock calculation.
|
|
161
358
|
n_times_output: Integer indicating the number of time periods to include
|
|
@@ -164,12 +361,16 @@ class AdstockTransformer(AdstockHillTransformer):
|
|
|
164
361
|
correspond to the most recent time periods of the media argument. For
|
|
165
362
|
example, `media[..., -n_times_output:, :]` represents the media
|
|
166
363
|
execution of the output weeks.
|
|
364
|
+
decay_functions: String or list of strings indicating the decay
|
|
365
|
+
function(s) to use for the Adstock calculation for each channel.
|
|
366
|
+
Default is geometric decay for all channels.
|
|
167
367
|
"""
|
|
168
368
|
self._alpha = alpha
|
|
169
369
|
self._max_lag = max_lag
|
|
170
370
|
self._n_times_output = n_times_output
|
|
371
|
+
self._decay_functions = decay_functions
|
|
171
372
|
|
|
172
|
-
def forward(self, media:
|
|
373
|
+
def forward(self, media: backend.Tensor) -> backend.Tensor:
|
|
173
374
|
"""Computes the Adstock transformation of a given `media` tensor.
|
|
174
375
|
|
|
175
376
|
For geo `g`, time period `t`, and media channel `m`, Adstock is calculated
|
|
@@ -196,13 +397,14 @@ class AdstockTransformer(AdstockHillTransformer):
|
|
|
196
397
|
alpha=self._alpha,
|
|
197
398
|
max_lag=self._max_lag,
|
|
198
399
|
n_times_output=self._n_times_output,
|
|
400
|
+
decay_functions=self._decay_functions,
|
|
199
401
|
)
|
|
200
402
|
|
|
201
403
|
|
|
202
404
|
class HillTransformer(AdstockHillTransformer):
|
|
203
405
|
"""Class to compute the Hill transformation of media."""
|
|
204
406
|
|
|
205
|
-
def __init__(self, ec:
|
|
407
|
+
def __init__(self, ec: backend.Tensor, slope: backend.Tensor):
|
|
206
408
|
"""Initializes the instance based on the Hill function parameters.
|
|
207
409
|
|
|
208
410
|
Args:
|
|
@@ -216,7 +418,7 @@ class HillTransformer(AdstockHillTransformer):
|
|
|
216
418
|
self._ec = ec
|
|
217
419
|
self._slope = slope
|
|
218
420
|
|
|
219
|
-
def forward(self, media:
|
|
421
|
+
def forward(self, media: backend.Tensor) -> backend.Tensor:
|
|
220
422
|
"""Computes the Hill transformation of a given `media` tensor.
|
|
221
423
|
|
|
222
424
|
Calculates results for the Hill function, which accounts for the diminishing
|
|
@@ -234,3 +436,47 @@ class HillTransformer(AdstockHillTransformer):
|
|
|
234
436
|
representing Hill-transformed media.
|
|
235
437
|
"""
|
|
236
438
|
return _hill(media=media, ec=self._ec, slope=self._slope)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def transform_non_negative_reals_distribution(
    distribution: backend.tfd.Distribution,
) -> backend.tfd.TransformedDistribution:
  """Transforms a distribution with support on `[0, infinity)` to `(0, 1]`.

  This lets you define a prior directly on `alpha_*`, the exponent of the
  binomial Adstock decay function. The prior is then translated into a
  distribution on the unit interval, which is what Meridian expects for
  `alpha`. The transformation `x -> 1 / (1 + x)` is the inverse of the mapping
  `x -> 1 / x - 1` that Meridian applies to `alpha` to obtain the binomial
  Adstock decay exponent.

  For example, to define a `LogNormal(0.2, 0.9)` prior on `alpha_*`:

  ```python
  from meridian import backend
  from meridian.model import prior_distribution

  alpha_star_prior = backend.tfd.LogNormal(0.2, 0.9)
  alpha_prior = transform_non_negative_reals_distribution(alpha_star_prior)
  prior = prior_distribution.PriorDistribution(
      alpha_m=alpha_prior,
      ...
  )
  ```

  Args:
    distribution: A TensorFlow Probability distribution with support on
      `[0, infinity)`.

  Returns:
    A TensorFlow Probability `TransformedDistribution` with support on
    `(0, 1]`, named `f'{distribution.name}UnitIntervalMapped'`, such that the
    resulting prior on `alpha_*` is the input distribution.
  """
  # `Chain` applies its bijectors right to left: shift by one first, then take
  # the reciprocal, giving x -> 1 / (x + 1).
  shift_then_invert = backend.bijectors.Chain(
      [backend.bijectors.Reciprocal(), backend.bijectors.Shift(1)]
  )
  mapped_name = f'{distribution.name}UnitIntervalMapped'
  return backend.tfd.TransformedDistribution(
      distribution=distribution,
      bijector=shift_then_invert,
      name=mapped_name,
  )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2025 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""The Meridian API module that performs EDA checks."""
|
|
16
|
+
|
|
17
|
+
from meridian.model.eda import eda_engine
|