google-meridian 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/METADATA +26 -21
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/RECORD +20 -16
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/WHEEL +1 -1
- meridian/__init__.py +1 -1
- meridian/analysis/analyzer.py +347 -512
- meridian/analysis/formatter.py +18 -0
- meridian/analysis/optimizer.py +259 -145
- meridian/analysis/summarizer.py +2 -2
- meridian/analysis/visualizer.py +21 -2
- meridian/data/__init__.py +1 -0
- meridian/data/arg_builder.py +107 -0
- meridian/data/input_data.py +23 -0
- meridian/data/test_utils.py +6 -4
- meridian/model/__init__.py +2 -0
- meridian/model/model.py +42 -984
- meridian/model/model_test_data.py +351 -0
- meridian/model/posterior_sampler.py +566 -0
- meridian/model/prior_sampler.py +633 -0
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/LICENSE +0 -0
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,633 @@
|
|
|
1
|
+
# Copyright 2024 The Meridian Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Module for sampling prior distributions in a Meridian model."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import Mapping
|
|
18
|
+
from typing import TYPE_CHECKING
|
|
19
|
+
|
|
20
|
+
import arviz as az
|
|
21
|
+
from meridian import constants
|
|
22
|
+
import tensorflow as tf
|
|
23
|
+
import tensorflow_probability as tfp
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Names exported by `from meridian.model.prior_sampler import *`.
__all__ = ["PriorDistributionSampler"]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _get_tau_g(
    tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int
) -> tfp.distributions.Distribution:
  """Builds the `tau_g` distribution from `tau_g_excl_baseline`.

  `tau_g` is formed by inserting a zero at position `baseline_geo_idx` along the
  last axis of `tau_g_excl_baseline`.

  Args:
    tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the
      user-defined dimensions of the `tau_g` parameter distribution.
    baseline_geo_idx: The index of the baseline geo to be set to zero.

  Returns:
    A `tfp.distributions.Deterministic` distribution named `"tau_g"` whose value
    is a tensor of shape `[..., n_geos]`: zero at position `baseline_geo_idx`
    and matching `tau_g_excl_baseline` elsewhere.
  """
  # The zero column keeps every leading dimension and has a trailing axis of
  # length 1. It is built from the dynamic shape, so it also works when leading
  # dimensions are not statically known. It also works for a rank-1 input and
  # for an empty trailing axis (a single geo).
  zeros_shape = tf.concat([tf.shape(tau_g_excl_baseline)[:-1], [1]], axis=0)
  baseline_zeros = tf.zeros(zeros_shape, dtype=tau_g_excl_baseline.dtype)
  tau_g = tf.concat(
      [
          tau_g_excl_baseline[..., :baseline_geo_idx],
          baseline_zeros,
          tau_g_excl_baseline[..., baseline_geo_idx:],
      ],
      axis=-1,
  )
  return tfp.distributions.Deterministic(tau_g, name="tau_g")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class PriorDistributionSampler:
  """A callable that samples from a model spec's prior distributions.

  Calling an instance with `n_draws` (and optionally `seed`) draws from every
  prior that applies to the wrapped Meridian model. The draws are returned as
  an Arviz `InferenceData` object under the `constants.PRIOR` group.
  """

  def __init__(self, meridian: "model.Meridian"):
    """Initializes the sampler.

    Args:
      meridian: The Meridian model whose broadcast priors, data tensors and
        transformations are used for sampling.
    """
    self._meridian = meridian

  @staticmethod
  def _sample_kwargs(n_draws: int, seed: int | None) -> dict[str, object]:
    """Returns the keyword arguments shared by every prior `sample()` call.

    The sample shape is `[1, n_draws]`. The leading `1` matches the single
    chain passed to `create_inference_data_coords` in `__call__`.
    """
    return {constants.SAMPLE_SHAPE: [1, n_draws], constants.SEED: seed}

  def _geo_level_value(
      self, mean: tf.Tensor, scale: tf.Tensor, dev_g: tf.Tensor
  ) -> tf.Tensor:
    """Combines channel-level parameters with per-geo deviations.

    Computes `mean + scale * dev_g`, inserting a geo axis into `mean` and
    `scale` so they broadcast against `dev_g`. If the model does not use
    `constants.MEDIA_EFFECTS_NORMAL`, the result is exponentiated.

    Args:
      mean: Channel-level location draws, shape `[..., n_channels]`.
      scale: Channel-level scale draws, shape `[..., n_channels]`.
      dev_g: Per-geo deviations, shape `[..., n_geos, n_channels]`.

    Returns:
      A tensor of shape `[..., n_geos, n_channels]`.
    """
    combined = mean[..., tf.newaxis, :] + scale[..., tf.newaxis, :] * dev_g
    if self._meridian.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
      return combined
    return tf.math.exp(combined)

  def get_roi_prior_beta_m_value(
      self,
      alpha_m: tf.Tensor,
      beta_gm_dev: tf.Tensor,
      ec_m: tf.Tensor,
      eta_m: tf.Tensor,
      roi_or_mroi_m: tf.Tensor,
      slope_m: tf.Tensor,
      media_transformed: tf.Tensor,
  ) -> tf.Tensor:
    """Returns the `beta_m` value implied by an ROI or mROI prior draw.

    Solves for `beta_m` so that the media contribution implied by `beta_m`,
    `eta_m` and `beta_gm_dev` equals `roi_or_mroi_m` times the spend
    difference.
    - Normal media effects: `beta_m = (inc_revenue - random_effect) /
      contribution`.
    - Otherwise: `beta_m = log(inc_revenue) - log(random_effect)`.

    Args:
      alpha_m: Draws passed as `alpha` to `adstock_hill_media`.
      beta_gm_dev: Per-geo deviations used for `beta_gm`.
      ec_m: Draws passed as `ec` to `adstock_hill_media`.
      eta_m: Draws scaling the per-geo deviations.
      roi_or_mroi_m: ROI or marginal ROI prior draws, depending on the paid
        media prior type.
      slope_m: Draws passed as `slope` to `adstock_hill_media`.
      media_transformed: Transformed media indexed `[..., geo, time, channel]`.

    Returns:
      A tensor indexed `[..., channel]` to be used in `beta_m`.
    """
    mmm = self._meridian

    # The `roi_or_mroi_m` parameter represents either ROI or mROI. For reach &
    # frequency channels, marginal ROI priors are defined as "mROI by reach",
    # which is equivalent to ROI.
    media_spend = mmm.media_tensors.media_spend
    media_spend_counterfactual = mmm.media_tensors.media_spend_counterfactual
    media_counterfactual_scaled = mmm.media_tensors.media_counterfactual_scaled
    # If we got here, then we should already have media tensors derived from
    # non-None InputData.media data.
    assert media_spend is not None
    assert media_spend_counterfactual is not None
    assert media_counterfactual_scaled is not None

    # Use absolute value here because this difference will be negative for
    # marginal ROI priors.
    inc_revenue_m = roi_or_mroi_m * tf.reduce_sum(
        tf.abs(media_spend - media_spend_counterfactual),
        range(media_spend.ndim - 1),
    )

    if (
        mmm.model_spec.roi_calibration_period is None
        and mmm.model_spec.paid_media_prior_type
        == constants.PAID_MEDIA_PRIOR_TYPE_ROI
    ):
      # We can skip the adstock/hill computation step in this case.
      media_counterfactual_transformed = tf.zeros_like(media_transformed)
    else:
      media_counterfactual_transformed = mmm.adstock_hill_media(
          media=media_counterfactual_scaled,
          alpha=alpha_m,
          ec=ec_m,
          slope=slope_m,
      )

    revenue_per_kpi = mmm.revenue_per_kpi
    if mmm.input_data.revenue_per_kpi is None:
      revenue_per_kpi = tf.ones([mmm.n_geos, mmm.n_times], dtype=tf.float32)
    # Note: use absolute value here because this difference will be negative
    # for marginal ROI priors.
    media_contrib_gm = tf.einsum(
        "...gtm,g,,gt->...gm",
        tf.abs(media_transformed - media_counterfactual_transformed),
        mmm.population,
        mmm.kpi_transformer.population_scaled_stdev,
        revenue_per_kpi,
    )

    if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
      media_contrib_m = tf.einsum("...gm->...m", media_contrib_gm)
      random_effect_m = tf.einsum(
          "...m,...gm,...gm->...m", eta_m, beta_gm_dev, media_contrib_gm
      )
      return (inc_revenue_m - random_effect_m) / media_contrib_m
    else:
      # For log_normal, beta_m and eta_m are not mean & std.
      # The parameterization is beta_gm ~ exp(beta_m + eta_m * N(0, 1)).
      random_effect_m = tf.einsum(
          "...gm,...gm->...m",
          tf.math.exp(beta_gm_dev * eta_m[..., tf.newaxis, :]),
          media_contrib_gm,
      )
      return tf.math.log(inc_revenue_m) - tf.math.log(random_effect_m)

  def get_roi_prior_beta_rf_value(
      self,
      alpha_rf: tf.Tensor,
      beta_grf_dev: tf.Tensor,
      ec_rf: tf.Tensor,
      eta_rf: tf.Tensor,
      roi_or_mroi_rf: tf.Tensor,
      slope_rf: tf.Tensor,
      rf_transformed: tf.Tensor,
  ) -> tf.Tensor:
    """Returns the `beta_rf` value implied by an ROI or mROI prior draw.

    This is the reach & frequency counterpart of `get_roi_prior_beta_m_value`.
    The counterfactual RF is run through `adstock_hill_rf` only when
    `rf_roi_calibration_period` is set; otherwise it is zero.

    Args:
      alpha_rf: Draws passed as `alpha` to `adstock_hill_rf`.
      beta_grf_dev: Per-geo deviations used for `beta_grf`.
      ec_rf: Draws passed as `ec` to `adstock_hill_rf`.
      eta_rf: Draws scaling the per-geo deviations.
      roi_or_mroi_rf: ROI or marginal ROI prior draws, depending on the paid
        media prior type.
      slope_rf: Draws passed as `slope` to `adstock_hill_rf`.
      rf_transformed: Transformed RF media indexed `[..., geo, time, channel]`.

    Returns:
      A tensor indexed `[..., channel]` to be used in `beta_rf`.
    """
    mmm = self._meridian

    rf_spend = mmm.rf_tensors.rf_spend
    rf_spend_counterfactual = mmm.rf_tensors.rf_spend_counterfactual
    reach_counterfactual_scaled = mmm.rf_tensors.reach_counterfactual_scaled
    frequency = mmm.rf_tensors.frequency
    # If we got here, then we should already have RF media tensors derived from
    # non-None InputData.reach data.
    assert rf_spend is not None
    assert rf_spend_counterfactual is not None
    assert reach_counterfactual_scaled is not None
    assert frequency is not None

    # NOTE(review): unlike the media path, no `tf.abs` is applied here. This
    # presumably relies on RF mROI being "mROI by reach", so the difference is
    # non-negative. Confirm before changing.
    inc_revenue_rf = roi_or_mroi_rf * tf.reduce_sum(
        rf_spend - rf_spend_counterfactual,
        range(rf_spend.ndim - 1),
    )
    if mmm.model_spec.rf_roi_calibration_period is not None:
      rf_counterfactual_transformed = mmm.adstock_hill_rf(
          reach=reach_counterfactual_scaled,
          frequency=frequency,
          alpha=alpha_rf,
          ec=ec_rf,
          slope=slope_rf,
      )
    else:
      rf_counterfactual_transformed = tf.zeros_like(rf_transformed)
    revenue_per_kpi = mmm.revenue_per_kpi
    if mmm.input_data.revenue_per_kpi is None:
      revenue_per_kpi = tf.ones([mmm.n_geos, mmm.n_times], dtype=tf.float32)

    media_contrib_grf = tf.einsum(
        "...gtm,g,,gt->...gm",
        rf_transformed - rf_counterfactual_transformed,
        mmm.population,
        mmm.kpi_transformer.population_scaled_stdev,
        revenue_per_kpi,
    )
    if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL:
      media_contrib_rf = tf.einsum("...gm->...m", media_contrib_grf)
      random_effect_rf = tf.einsum(
          "...m,...gm,...gm->...m", eta_rf, beta_grf_dev, media_contrib_grf
      )
      return (inc_revenue_rf - random_effect_rf) / media_contrib_rf
    else:
      # For log_normal, beta_rf and eta_rf are not mean & std.
      # The parameterization is beta_grf ~ exp(beta_rf + eta_rf * N(0, 1)).
      random_effect_rf = tf.einsum(
          "...gm,...gm->...m",
          tf.math.exp(beta_grf_dev * eta_rf[..., tf.newaxis, :]),
          media_contrib_grf,
      )
      return tf.math.log(inc_revenue_rf) - tf.math.log(random_effect_rf)

  def _sample_media_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of the media variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of media parameter names to tensors of samples with leading
      dimensions `[1, n_draws]`. Geo-level parameters (`beta_gm`) have shape
      `[1, n_draws, n_geos, n_media_channels]`. Channel-level parameters have
      shape `[1, n_draws, n_media_channels]`, assuming the broadcast priors have
      batch shape `[n_media_channels]`.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._sample_kwargs(n_draws, seed)

    media_vars = {
        constants.ALPHA_M: prior.alpha_m.sample(**sample_kwargs),
        constants.EC_M: prior.ec_m.sample(**sample_kwargs),
        constants.ETA_M: prior.eta_m.sample(**sample_kwargs),
        constants.SLOPE_M: prior.slope_m.sample(**sample_kwargs),
    }
    beta_gm_dev = tfp.distributions.Sample(
        tfp.distributions.Normal(0, 1),
        [mmm.n_geos, mmm.n_media_channels],
        name=constants.BETA_GM_DEV,
    ).sample(**sample_kwargs)
    media_transformed = mmm.adstock_hill_media(
        media=mmm.media_tensors.media_scaled,
        alpha=media_vars[constants.ALPHA_M],
        ec=media_vars[constants.EC_M],
        slope=media_vars[constants.SLOPE_M],
    )

    # ROI and mROI priors differ only in which prior is sampled and which key
    # stores the draw; `beta_m` is then derived from that draw.
    prior_type = mmm.model_spec.paid_media_prior_type
    if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
      roi_key = constants.ROI_M
      roi_or_mroi_m = prior.roi_m.sample(**sample_kwargs)
    elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
      roi_key = constants.MROI_M
      roi_or_mroi_m = prior.mroi_m.sample(**sample_kwargs)
    else:
      roi_key = None
      roi_or_mroi_m = None

    if roi_key is None:
      media_vars[constants.BETA_M] = prior.beta_m.sample(**sample_kwargs)
    else:
      # Call this before adding the ROI/mROI draw to `media_vars`: the dict is
      # unpacked into the keyword arguments of the call.
      beta_m_value = self.get_roi_prior_beta_m_value(
          beta_gm_dev=beta_gm_dev,
          media_transformed=media_transformed,
          roi_or_mroi_m=roi_or_mroi_m,
          **media_vars,
      )
      media_vars[roi_key] = roi_or_mroi_m
      media_vars[constants.BETA_M] = tfp.distributions.Deterministic(
          beta_m_value, name=constants.BETA_M
      ).sample()

    media_vars[constants.BETA_GM] = tfp.distributions.Deterministic(
        self._geo_level_value(
            media_vars[constants.BETA_M],
            media_vars[constants.ETA_M],
            beta_gm_dev,
        ),
        name=constants.BETA_GM,
    ).sample()

    return media_vars

  def _sample_rf_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of the RF variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of RF parameter names to tensors of samples with leading
      dimensions `[1, n_draws]`. Geo-level parameters (`beta_grf`) have shape
      `[1, n_draws, n_geos, n_rf_channels]`. Channel-level parameters have
      shape `[1, n_draws, n_rf_channels]`, assuming the broadcast priors have
      batch shape `[n_rf_channels]`.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._sample_kwargs(n_draws, seed)

    rf_vars = {
        constants.ALPHA_RF: prior.alpha_rf.sample(**sample_kwargs),
        constants.EC_RF: prior.ec_rf.sample(**sample_kwargs),
        constants.ETA_RF: prior.eta_rf.sample(**sample_kwargs),
        constants.SLOPE_RF: prior.slope_rf.sample(**sample_kwargs),
    }
    beta_grf_dev = tfp.distributions.Sample(
        tfp.distributions.Normal(0, 1),
        [mmm.n_geos, mmm.n_rf_channels],
        name=constants.BETA_GRF_DEV,
    ).sample(**sample_kwargs)
    rf_transformed = mmm.adstock_hill_rf(
        reach=mmm.rf_tensors.reach_scaled,
        frequency=mmm.rf_tensors.frequency,
        alpha=rf_vars[constants.ALPHA_RF],
        ec=rf_vars[constants.EC_RF],
        slope=rf_vars[constants.SLOPE_RF],
    )

    # ROI and mROI priors differ only in which prior is sampled and which key
    # stores the draw; `beta_rf` is then derived from that draw.
    prior_type = mmm.model_spec.paid_media_prior_type
    if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
      roi_key = constants.ROI_RF
      roi_or_mroi_rf = prior.roi_rf.sample(**sample_kwargs)
    elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI:
      roi_key = constants.MROI_RF
      roi_or_mroi_rf = prior.mroi_rf.sample(**sample_kwargs)
    else:
      roi_key = None
      roi_or_mroi_rf = None

    if roi_key is None:
      rf_vars[constants.BETA_RF] = prior.beta_rf.sample(**sample_kwargs)
    else:
      # Call this before adding the ROI/mROI draw to `rf_vars`: the dict is
      # unpacked into the keyword arguments of the call.
      beta_rf_value = self.get_roi_prior_beta_rf_value(
          beta_grf_dev=beta_grf_dev,
          rf_transformed=rf_transformed,
          roi_or_mroi_rf=roi_or_mroi_rf,
          **rf_vars,
      )
      rf_vars[roi_key] = roi_or_mroi_rf
      rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic(
          beta_rf_value,
          name=constants.BETA_RF,
      ).sample()

    rf_vars[constants.BETA_GRF] = tfp.distributions.Deterministic(
        self._geo_level_value(
            rf_vars[constants.BETA_RF],
            rf_vars[constants.ETA_RF],
            beta_grf_dev,
        ),
        name=constants.BETA_GRF,
    ).sample()

    return rf_vars

  def _sample_organic_media_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of organic media variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of organic media parameter names to tensors of samples with
      leading dimensions `[1, n_draws]`. Geo-level parameters (`beta_gom`) have
      shape `[1, n_draws, n_geos, n_organic_media_channels]`. Channel-level
      parameters have shape `[1, n_draws, n_organic_media_channels]`, assuming
      the broadcast priors have batch shape `[n_organic_media_channels]`.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._sample_kwargs(n_draws, seed)

    organic_media_vars = {
        constants.ALPHA_OM: prior.alpha_om.sample(**sample_kwargs),
        constants.EC_OM: prior.ec_om.sample(**sample_kwargs),
        constants.ETA_OM: prior.eta_om.sample(**sample_kwargs),
        constants.SLOPE_OM: prior.slope_om.sample(**sample_kwargs),
    }
    beta_gom_dev = tfp.distributions.Sample(
        tfp.distributions.Normal(0, 1),
        [mmm.n_geos, mmm.n_organic_media_channels],
        name=constants.BETA_GOM_DEV,
    ).sample(**sample_kwargs)

    organic_media_vars[constants.BETA_OM] = prior.beta_om.sample(
        **sample_kwargs
    )

    organic_media_vars[constants.BETA_GOM] = tfp.distributions.Deterministic(
        self._geo_level_value(
            organic_media_vars[constants.BETA_OM],
            organic_media_vars[constants.ETA_OM],
            beta_gom_dev,
        ),
        name=constants.BETA_GOM,
    ).sample()

    return organic_media_vars

  def _sample_organic_rf_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws samples from the prior distributions of the organic RF variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of organic RF parameter names to tensors of samples with
      leading dimensions `[1, n_draws]`. Geo-level parameters (`beta_gorf`)
      have shape `[1, n_draws, n_geos, n_organic_rf_channels]`. Channel-level
      parameters have shape `[1, n_draws, n_organic_rf_channels]`, assuming the
      broadcast priors have batch shape `[n_organic_rf_channels]`.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._sample_kwargs(n_draws, seed)

    organic_rf_vars = {
        constants.ALPHA_ORF: prior.alpha_orf.sample(**sample_kwargs),
        constants.EC_ORF: prior.ec_orf.sample(**sample_kwargs),
        constants.ETA_ORF: prior.eta_orf.sample(**sample_kwargs),
        constants.SLOPE_ORF: prior.slope_orf.sample(**sample_kwargs),
    }
    beta_gorf_dev = tfp.distributions.Sample(
        tfp.distributions.Normal(0, 1),
        [mmm.n_geos, mmm.n_organic_rf_channels],
        name=constants.BETA_GORF_DEV,
    ).sample(**sample_kwargs)

    organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(**sample_kwargs)

    organic_rf_vars[constants.BETA_GORF] = tfp.distributions.Deterministic(
        self._geo_level_value(
            organic_rf_vars[constants.BETA_ORF],
            organic_rf_vars[constants.ETA_ORF],
            beta_gorf_dev,
        ),
        name=constants.BETA_GORF,
    ).sample()

    return organic_rf_vars

  def _sample_non_media_treatments_priors(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Draws from the prior distributions of the non-media treatment variables.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      A mapping of non-media treatment parameter names to tensors of samples
      with leading dimensions `[1, n_draws]`. Geo-level parameters (`gamma_gn`)
      have shape `[1, n_draws, n_geos, n_non_media_channels]`. Channel-level
      parameters have shape `[1, n_draws, n_non_media_channels]`, assuming the
      broadcast priors have batch shape `[n_non_media_channels]`.
    """
    mmm = self._meridian
    prior = mmm.prior_broadcast
    sample_kwargs = self._sample_kwargs(n_draws, seed)

    non_media_treatments_vars = {
        constants.GAMMA_N: prior.gamma_n.sample(**sample_kwargs),
        constants.XI_N: prior.xi_n.sample(**sample_kwargs),
    }
    gamma_gn_dev = tfp.distributions.Sample(
        tfp.distributions.Normal(0, 1),
        [mmm.n_geos, mmm.n_non_media_channels],
        name=constants.GAMMA_GN_DEV,
    ).sample(**sample_kwargs)
    # Unlike the media coefficients, gamma_gn is never exponentiated.
    non_media_treatments_vars[constants.GAMMA_GN] = (
        tfp.distributions.Deterministic(
            non_media_treatments_vars[constants.GAMMA_N][..., tf.newaxis, :]
            + non_media_treatments_vars[constants.XI_N][..., tf.newaxis, :]
            * gamma_gn_dev,
            name=constants.GAMMA_GN,
        ).sample()
    )
    return non_media_treatments_vars

  def _sample_prior(
      self,
      n_draws: int,
      seed: int | None = None,
  ) -> Mapping[str, tf.Tensor]:
    """Returns a mapping of prior parameter names to tensors of samples.

    The base parameters are always sampled: knot values, `gamma_c`, `xi_c`,
    `sigma` and `tau_g`, plus the derived `mu_t` and `gamma_gc`. Media, RF,
    organic media, organic RF and non-media treatment parameters are sampled
    only when the model has the corresponding data.
    """
    mmm = self._meridian

    # For stateful sampling, the random seed must be set to ensure that any
    # random numbers that are generated are deterministic.
    # NOTE(review): the global seed is fixed to 1, not `seed`. The user's
    # `seed` only reaches the individual `sample()` calls. Changing this would
    # change previously reproducible draws, so confirm the intent first.
    if seed is not None:
      tf.keras.utils.set_random_seed(1)

    prior = mmm.prior_broadcast
    sample_kwargs = self._sample_kwargs(n_draws, seed)

    tau_g_excl_baseline = prior.tau_g_excl_baseline.sample(**sample_kwargs)
    base_vars = {
        constants.KNOT_VALUES: prior.knot_values.sample(**sample_kwargs),
        constants.GAMMA_C: prior.gamma_c.sample(**sample_kwargs),
        constants.XI_C: prior.xi_c.sample(**sample_kwargs),
        constants.SIGMA: prior.sigma.sample(**sample_kwargs),
        constants.TAU_G: _get_tau_g(
            tau_g_excl_baseline=tau_g_excl_baseline,
            baseline_geo_idx=mmm.baseline_geo_idx,
        ).sample(),
    }
    # Contract the knot axis (k) of the knot values with the knot weights to
    # get one value per time period (t).
    base_vars[constants.MU_T] = tfp.distributions.Deterministic(
        tf.einsum(
            "...k,kt->...t",
            base_vars[constants.KNOT_VALUES],
            tf.convert_to_tensor(mmm.knot_info.weights),
        ),
        name=constants.MU_T,
    ).sample()

    gamma_gc_dev = tfp.distributions.Sample(
        tfp.distributions.Normal(0, 1),
        [mmm.n_geos, mmm.n_controls],
        name=constants.GAMMA_GC_DEV,
    ).sample(**sample_kwargs)
    base_vars[constants.GAMMA_GC] = tfp.distributions.Deterministic(
        base_vars[constants.GAMMA_C][..., tf.newaxis, :]
        + base_vars[constants.XI_C][..., tf.newaxis, :] * gamma_gc_dev,
        name=constants.GAMMA_GC,
    ).sample()

    media_vars = (
        self._sample_media_priors(n_draws, seed)
        if mmm.media_tensors.media is not None
        else {}
    )
    rf_vars = (
        self._sample_rf_priors(n_draws, seed)
        if mmm.rf_tensors.reach is not None
        else {}
    )
    organic_media_vars = (
        self._sample_organic_media_priors(n_draws, seed)
        if mmm.organic_media_tensors.organic_media is not None
        else {}
    )
    organic_rf_vars = (
        self._sample_organic_rf_priors(n_draws, seed)
        if mmm.organic_rf_tensors.organic_reach is not None
        else {}
    )
    non_media_treatments_vars = (
        self._sample_non_media_treatments_priors(n_draws, seed)
        if mmm.non_media_treatments_scaled is not None
        else {}
    )

    return (
        base_vars
        | media_vars
        | rf_vars
        | organic_media_vars
        | organic_rf_vars
        | non_media_treatments_vars
    )

  def __call__(self, n_draws: int, seed: int | None = None) -> az.InferenceData:
    """Draws samples from prior distributions.

    Args:
      n_draws: Number of samples drawn from the prior distribution.
      seed: Used to set the seed for reproducible results. For more information,
        see [PRNGS and seeds]
        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).

    Returns:
      An Arviz `InferenceData` object containing prior samples only.
    """
    prior_draws = self._sample_prior(n_draws, seed=seed)
    # Create Arviz InferenceData for prior draws; the `1` is the single chain
    # matching the leading sample dimension.
    prior_coords = self._meridian.create_inference_data_coords(1, n_draws)
    prior_dims = self._meridian.create_inference_data_dims()
    return az.convert_to_inference_data(
        prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR
    )
|
|
File without changes
|
|
File without changes
|