google_meridian-1.0.3-py3-none-any.whl → google_meridian-1.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/METADATA +26 -21
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/RECORD +20 -16
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/WHEEL +1 -1
- meridian/__init__.py +1 -1
- meridian/analysis/analyzer.py +347 -512
- meridian/analysis/formatter.py +18 -0
- meridian/analysis/optimizer.py +259 -145
- meridian/analysis/summarizer.py +2 -2
- meridian/analysis/visualizer.py +21 -2
- meridian/data/__init__.py +1 -0
- meridian/data/arg_builder.py +107 -0
- meridian/data/input_data.py +23 -0
- meridian/data/test_utils.py +6 -4
- meridian/model/__init__.py +2 -0
- meridian/model/model.py +42 -984
- meridian/model/model_test_data.py +351 -0
- meridian/model/posterior_sampler.py +566 -0
- meridian/model/prior_sampler.py +633 -0
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/LICENSE +0 -0
- {google_meridian-1.0.3.dist-info → google_meridian-1.0.5.dist-info}/top_level.txt +0 -0
meridian/model/posterior_sampler.py
@@ -0,0 +1,566 @@
+# Copyright 2024 The Meridian Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for MCMC sampling of posterior distributions in a Meridian model."""
+
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING
+
+import arviz as az
+from meridian import constants
+import numpy as np
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+if TYPE_CHECKING:
+  from meridian.model import model  # pylint: disable=g-bad-import-order,g-import-not-at-top
+
+
+__all__ = [
+    "MCMCSamplingError",
+    "MCMCOOMError",
+    "PosteriorMCMCSampler",
+]
+
+
+class MCMCSamplingError(Exception):
+  """The Markov Chain Monte Carlo (MCMC) sampling failed."""
+
+
+class MCMCOOMError(Exception):
+  """The Markov Chain Monte Carlo (MCMC) sampling exceeds memory limits."""
+
+
+def _get_tau_g(
+    tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int
+) -> tfp.distributions.Distribution:
+  """Computes `tau_g` from `tau_g_excl_baseline`.
+
+  This function computes `tau_g` by inserting a column of zeros at the
+  `baseline_geo` position in `tau_g_excl_baseline`.
+
+  Args:
+    tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the
+      user-defined dimensions of the `tau_g` parameter distribution.
+    baseline_geo_idx: The index of the baseline geo to be set to zero.
+
+  Returns:
+    A deterministic distribution over a tensor of shape `[..., n_geos]` for
+    the `tau_g` parameter, with zero at position `baseline_geo_idx` and
+    matching `tau_g_excl_baseline` elsewhere.
+  """
+  rank = len(tau_g_excl_baseline.shape)
+  shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1
+  tau_g = tf.concat(
+      [
+          tau_g_excl_baseline[..., :baseline_geo_idx],
+          tf.zeros(shape, dtype=tau_g_excl_baseline.dtype),
+          tau_g_excl_baseline[..., baseline_geo_idx:],
+      ],
+      axis=rank - 1,
+  )
+  return tfp.distributions.Deterministic(tau_g, name="tau_g")
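A standalone sketch of the zero-insertion `_get_tau_g` performs, with hypothetical values (four geos, baseline geo at index 1); the names and numbers here are illustrative, not part of the package:

    import tensorflow as tf

    # Rank-1 input: intercepts for the 3 non-baseline geos.
    tau_g_excl_baseline = tf.constant([0.3, -0.2, 0.5])
    baseline_geo_idx = 1

    tau_g = tf.concat(
        [
            tau_g_excl_baseline[:baseline_geo_idx],
            tf.zeros(1),  # the baseline geo's intercept is pinned to zero
            tau_g_excl_baseline[baseline_geo_idx:],
        ],
        axis=0,
    )
    # tau_g -> [0.3, 0.0, -0.2, 0.5], shape [n_geos]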
+
+
+@tf.function(autograph=False, jit_compile=True)
+def _xla_windowed_adaptive_nuts(**kwargs):
+  """XLA wrapper for windowed_adaptive_nuts."""
+  return tfp.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+
+class PosteriorMCMCSampler:
+  """A callable that samples from posterior distributions using MCMC."""
+
+  def __init__(self, meridian: "model.Meridian"):
+    self._meridian = meridian
+
+  def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution:
+    """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
+    mmm = self._meridian
+    mmm.populate_cached_properties()
+
+    # This lists all the derived properties and states of this Meridian object
+    # that are referenced by the joint distribution coroutine.
+    # That is, this is the list of captured parameters.
+    prior_broadcast = mmm.prior_broadcast
+    baseline_geo_idx = mmm.baseline_geo_idx
+    knot_info = mmm.knot_info
+    n_geos = mmm.n_geos
+    n_times = mmm.n_times
+    n_media_channels = mmm.n_media_channels
+    n_rf_channels = mmm.n_rf_channels
+    n_organic_media_channels = mmm.n_organic_media_channels
+    n_organic_rf_channels = mmm.n_organic_rf_channels
+    n_controls = mmm.n_controls
+    n_non_media_channels = mmm.n_non_media_channels
+    holdout_id = mmm.holdout_id
+    media_tensors = mmm.media_tensors
+    rf_tensors = mmm.rf_tensors
+    organic_media_tensors = mmm.organic_media_tensors
+    organic_rf_tensors = mmm.organic_rf_tensors
+    controls_scaled = mmm.controls_scaled
+    non_media_treatments_scaled = mmm.non_media_treatments_scaled
+    media_effects_dist = mmm.media_effects_dist
+    adstock_hill_media_fn = mmm.adstock_hill_media
+    adstock_hill_rf_fn = mmm.adstock_hill_rf
+    get_roi_prior_beta_m_value_fn = (
+        mmm.prior_sampler_callable.get_roi_prior_beta_m_value
+    )
+    get_roi_prior_beta_rf_value_fn = (
+        mmm.prior_sampler_callable.get_roi_prior_beta_rf_value
+    )
+
+    @tfp.distributions.JointDistributionCoroutineAutoBatched
+    def joint_dist_unpinned():
+      # Sample directly from prior.
+      knot_values = yield prior_broadcast.knot_values
+      gamma_c = yield prior_broadcast.gamma_c
+      xi_c = yield prior_broadcast.xi_c
+      sigma = yield prior_broadcast.sigma
+
+      tau_g_excl_baseline = yield tfp.distributions.Sample(
+          prior_broadcast.tau_g_excl_baseline,
+          name=constants.TAU_G_EXCL_BASELINE,
+      )
+      tau_g = yield _get_tau_g(
+          tau_g_excl_baseline=tau_g_excl_baseline,
+          baseline_geo_idx=baseline_geo_idx,
+      )
+      mu_t = yield tfp.distributions.Deterministic(
+          tf.einsum(
+              "k,kt->t",
+              knot_values,
+              tf.convert_to_tensor(knot_info.weights),
+          ),
+          name=constants.MU_T,
+      )
+
+      tau_gt = tau_g[:, tf.newaxis] + mu_t
+      combined_media_transformed = tf.zeros(
+          shape=(n_geos, n_times, 0), dtype=tf.float32
+      )
+      combined_beta = tf.zeros(shape=(n_geos, 0), dtype=tf.float32)
+      if media_tensors.media is not None:
+        alpha_m = yield prior_broadcast.alpha_m
+        ec_m = yield prior_broadcast.ec_m
+        eta_m = yield prior_broadcast.eta_m
+        slope_m = yield prior_broadcast.slope_m
+        beta_gm_dev = yield tfp.distributions.Sample(
+            tfp.distributions.Normal(0, 1),
+            [n_geos, n_media_channels],
+            name=constants.BETA_GM_DEV,
+        )
+        media_transformed = adstock_hill_media_fn(
+            media=media_tensors.media_scaled,
+            alpha=alpha_m,
+            ec=ec_m,
+            slope=slope_m,
+        )
+        prior_type = mmm.model_spec.paid_media_prior_type
+        if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
+          if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
+            roi_or_mroi_m = yield prior_broadcast.roi_m
+          else:
+            roi_or_mroi_m = yield prior_broadcast.mroi_m
+          beta_m_value = get_roi_prior_beta_m_value_fn(
+              alpha_m,
+              beta_gm_dev,
+              ec_m,
+              eta_m,
+              roi_or_mroi_m,
+              slope_m,
+              media_transformed,
+          )
+          beta_m = yield tfp.distributions.Deterministic(
+              beta_m_value, name=constants.BETA_M
+          )
+        else:
+          beta_m = yield prior_broadcast.beta_m
+
+        beta_eta_combined = beta_m + eta_m * beta_gm_dev
+        beta_gm_value = (
+            beta_eta_combined
+            if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+            else tf.math.exp(beta_eta_combined)
+        )
+        beta_gm = yield tfp.distributions.Deterministic(
+            beta_gm_value, name=constants.BETA_GM
+        )
+        combined_media_transformed = tf.concat(
+            [combined_media_transformed, media_transformed], axis=-1
+        )
+        combined_beta = tf.concat([combined_beta, beta_gm], axis=-1)
+
+      if rf_tensors.reach is not None:
+        alpha_rf = yield prior_broadcast.alpha_rf
+        ec_rf = yield prior_broadcast.ec_rf
+        eta_rf = yield prior_broadcast.eta_rf
+        slope_rf = yield prior_broadcast.slope_rf
+        beta_grf_dev = yield tfp.distributions.Sample(
+            tfp.distributions.Normal(0, 1),
+            [n_geos, n_rf_channels],
+            name=constants.BETA_GRF_DEV,
+        )
+        rf_transformed = adstock_hill_rf_fn(
+            reach=rf_tensors.reach_scaled,
+            frequency=rf_tensors.frequency,
+            alpha=alpha_rf,
+            ec=ec_rf,
+            slope=slope_rf,
+        )
+
+        prior_type = mmm.model_spec.paid_media_prior_type
+        if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
+          if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
+            roi_or_mroi_rf = yield prior_broadcast.roi_rf
+          else:
+            roi_or_mroi_rf = yield prior_broadcast.mroi_rf
+          beta_rf_value = get_roi_prior_beta_rf_value_fn(
+              alpha_rf,
+              beta_grf_dev,
+              ec_rf,
+              eta_rf,
+              roi_or_mroi_rf,
+              slope_rf,
+              rf_transformed,
+          )
+          beta_rf = yield tfp.distributions.Deterministic(
+              beta_rf_value,
+              name=constants.BETA_RF,
+          )
+        else:
+          beta_rf = yield prior_broadcast.beta_rf
+
+        beta_eta_combined = beta_rf + eta_rf * beta_grf_dev
+        beta_grf_value = (
+            beta_eta_combined
+            if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+            else tf.math.exp(beta_eta_combined)
+        )
+        beta_grf = yield tfp.distributions.Deterministic(
+            beta_grf_value, name=constants.BETA_GRF
+        )
+        combined_media_transformed = tf.concat(
+            [combined_media_transformed, rf_transformed], axis=-1
+        )
+        combined_beta = tf.concat([combined_beta, beta_grf], axis=-1)
+
+      if organic_media_tensors.organic_media is not None:
+        alpha_om = yield prior_broadcast.alpha_om
+        ec_om = yield prior_broadcast.ec_om
+        eta_om = yield prior_broadcast.eta_om
+        slope_om = yield prior_broadcast.slope_om
+        beta_gom_dev = yield tfp.distributions.Sample(
+            tfp.distributions.Normal(0, 1),
+            [n_geos, n_organic_media_channels],
+            name=constants.BETA_GOM_DEV,
+        )
+        organic_media_transformed = adstock_hill_media_fn(
+            media=organic_media_tensors.organic_media_scaled,
+            alpha=alpha_om,
+            ec=ec_om,
+            slope=slope_om,
+        )
+        beta_om = yield prior_broadcast.beta_om
+
+        beta_eta_combined = beta_om + eta_om * beta_gom_dev
+        beta_gom_value = (
+            beta_eta_combined
+            if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+            else tf.math.exp(beta_eta_combined)
+        )
+        beta_gom = yield tfp.distributions.Deterministic(
+            beta_gom_value, name=constants.BETA_GOM
+        )
+        combined_media_transformed = tf.concat(
+            [combined_media_transformed, organic_media_transformed], axis=-1
+        )
+        combined_beta = tf.concat([combined_beta, beta_gom], axis=-1)
+
+      if organic_rf_tensors.organic_reach is not None:
+        alpha_orf = yield prior_broadcast.alpha_orf
+        ec_orf = yield prior_broadcast.ec_orf
+        eta_orf = yield prior_broadcast.eta_orf
+        slope_orf = yield prior_broadcast.slope_orf
+        beta_gorf_dev = yield tfp.distributions.Sample(
+            tfp.distributions.Normal(0, 1),
+            [n_geos, n_organic_rf_channels],
+            name=constants.BETA_GORF_DEV,
+        )
+        organic_rf_transformed = adstock_hill_rf_fn(
+            reach=organic_rf_tensors.organic_reach_scaled,
+            frequency=organic_rf_tensors.organic_frequency,
+            alpha=alpha_orf,
+            ec=ec_orf,
+            slope=slope_orf,
+        )
+
+        beta_orf = yield prior_broadcast.beta_orf
+
+        beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev
+        beta_gorf_value = (
+            beta_eta_combined
+            if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+            else tf.math.exp(beta_eta_combined)
+        )
+        beta_gorf = yield tfp.distributions.Deterministic(
+            beta_gorf_value, name=constants.BETA_GORF
+        )
+        combined_media_transformed = tf.concat(
+            [combined_media_transformed, organic_rf_transformed], axis=-1
+        )
+        combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1)
+
+      sigma_gt = tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos]))
+      gamma_gc_dev = yield tfp.distributions.Sample(
+          tfp.distributions.Normal(0, 1),
+          [n_geos, n_controls],
+          name=constants.GAMMA_GC_DEV,
+      )
+      gamma_gc = yield tfp.distributions.Deterministic(
+          gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
+      )
+      y_pred_combined_media = (
+          tau_gt
+          + tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta)
+          + tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc)
+      )
+
+      if mmm.non_media_treatments is not None:
+        gamma_n = yield prior_broadcast.gamma_n
+        xi_n = yield prior_broadcast.xi_n
+        gamma_gn_dev = yield tfp.distributions.Sample(
+            tfp.distributions.Normal(0, 1),
+            [n_geos, n_non_media_channels],
+            name=constants.GAMMA_GN_DEV,
+        )
+        gamma_gn = yield tfp.distributions.Deterministic(
+            gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
+        )
+        y_pred = y_pred_combined_media + tf.einsum(
+            "gtn,gn->gt", non_media_treatments_scaled, gamma_gn
+        )
+      else:
+        y_pred = y_pred_combined_media
+
+      # If there are any holdout observations, the holdout KPI values will
+      # be replaced with zeros using `experimental_pin`. For these
+      # observations, we set the posterior mean equal to zero and standard
+      # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the
+      # sampled posterior parameter values.
+      if holdout_id is not None:
+        y_pred_holdout = tf.where(holdout_id, 0.0, y_pred)
+        test_sd = tf.cast(1.0 / np.sqrt(2.0 * np.pi), tf.float32)
+        sigma_gt_holdout = tf.where(holdout_id, test_sd, sigma_gt)
+        yield tfp.distributions.Normal(
+            y_pred_holdout, sigma_gt_holdout, name="y"
+        )
+      else:
+        yield tfp.distributions.Normal(y_pred, sigma_gt, name="y")
+
+    return joint_dist_unpinned
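The coroutine above assembles the optional channel groups (paid media, reach-and-frequency, organic media, organic reach-and-frequency) with one pattern: start from a zero-width tensor and concatenate whichever groups are present along the last axis. A minimal sketch of that pattern with made-up shapes, not package code:

    import tensorflow as tf

    n_geos, n_times = 2, 3
    combined = tf.zeros((n_geos, n_times, 0))         # no channels yet
    media = tf.ones((n_geos, n_times, 4))             # e.g. 4 paid media channels
    combined = tf.concat([combined, media], axis=-1)  # shape (2, 3, 4)
    rf = tf.ones((n_geos, n_times, 2))                # e.g. 2 R&F channels
    combined = tf.concat([combined, rf], axis=-1)     # shape (2, 3, 6)

The same trick builds `combined_beta`, so the final `tf.einsum("gtm,gm->gt", ...)` sees a single concatenated channel axis regardless of which groups exist.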
+
+  def _get_joint_dist(self) -> tfp.distributions.Distribution:
+    mmm = self._meridian
+    y = (
+        tf.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
+        if mmm.holdout_id is not None
+        else mmm.kpi_scaled
+    )
+    return self._get_joint_dist_unpinned().experimental_pin(y=y)
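`experimental_pin` conditions the joint distribution on the observed (scaled) KPI, so the sampler targets the posterior rather than the joint. A minimal sketch of the same pattern on a toy model, using only standard TensorFlow Probability API:

    import tensorflow_probability as tfp

    tfd = tfp.distributions

    @tfd.JointDistributionCoroutineAutoBatched
    def toy_joint():
      mu = yield tfd.Normal(0.0, 1.0, name="mu")
      yield tfd.Normal(mu, 1.0, name="y")

    # Pinning `y` at its observed value leaves a distribution over `mu` whose
    # unnormalized log-prob is the posterior target that MCMC samples from.
    pinned = toy_joint.experimental_pin(y=2.0)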
+
+  def __call__(
+      self,
+      n_chains: Sequence[int] | int,
+      n_adapt: int,
+      n_burnin: int,
+      n_keep: int,
+      current_state: Mapping[str, tf.Tensor] | None = None,
+      init_step_size: int | None = None,
+      dual_averaging_kwargs: Mapping[str, int] | None = None,
+      max_tree_depth: int = 10,
+      max_energy_diff: float = 500.0,
+      unrolled_leapfrog_steps: int = 1,
+      parallel_iterations: int = 10,
+      seed: Sequence[int] | None = None,
+      **pins,
+  ) -> az.InferenceData:
+    """Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.
+
+    For more information about the arguments, see [`windowed_adaptive_nuts`]
+    (https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/mcmc/windowed_adaptive_nuts).
+
+    Args:
+      n_chains: Number of MCMC chains. Given a sequence of integers,
+        `windowed_adaptive_nuts` will be called once for each element. The
+        `n_chains` argument of each `windowed_adaptive_nuts` call will be equal
+        to the respective integer element. Using a list of integers, one can
+        split the chains of a `windowed_adaptive_nuts` call into multiple calls
+        with fewer chains per call. This can reduce memory usage. This might
+        require an increased number of adaptation steps for convergence, as the
+        optimization is occurring across fewer chains per sampling call.
+      n_adapt: Number of adaptation draws per chain.
+      n_burnin: Number of burn-in draws per chain. Burn-in draws occur after
+        adaptation draws and before the kept draws.
+      n_keep: Integer number of draws per chain to keep for inference.
+      current_state: Optional structure of tensors at which to initialize
+        sampling. Use the same shape and structure as
+        `model.experimental_pin(**pins).sample(n_chains)`.
+      init_step_size: Optional integer determining where to initialize the step
+        size for the leapfrog integrator. The structure must broadcast with
+        `current_state`. For example, if the initial state is
+        `{'a': tf.zeros(n_chains), 'b': tf.zeros([n_chains, n_features])}`,
+        then any of `1.`, `{'a': 1., 'b': 1.}`, or `{'a': tf.ones(n_chains),
+        'b': tf.ones([n_chains, n_features])}` will work. Defaults to the
+        dimension of the log density to the ¼ power.
+      dual_averaging_kwargs: Optional dict keyword arguments to pass to
+        `tfp.mcmc.DualAveragingStepSizeAdaptation`. By default, a
+        `target_accept_prob` of `0.85` is set, acceptance probabilities across
+        chains are reduced using a harmonic mean, and the class defaults are
+        used otherwise.
+      max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
+        maximum number of leapfrog steps is bounded by `2**max_tree_depth`,
+        that is, the number of nodes in a binary tree `max_tree_depth` nodes
+        deep. The default setting of `10` takes up to 1024 leapfrog steps.
+      max_energy_diff: Scalar threshold of energy differences at each leapfrog;
+        divergence samples are defined as leapfrog steps that exceed this
+        threshold. Default is `500.0`.
+      unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
+        expansion step. Applies a direct linear multiplier to the maximum
+        trajectory length implied by `max_tree_depth`. Default is `1`.
+      parallel_iterations: Number of iterations allowed to run in parallel.
+        Must be a positive integer. For more information, see `tf.while_loop`.
+      seed: Used to set the seed for reproducible results. For more
+        information, see [PRNGS and seeds]
+        (https://github.com/tensorflow/probability/blob/main/PRNGS.md).
+      **pins: These are used to condition the provided joint distribution, and
+        are passed directly to `joint_dist.experimental_pin(**pins)`.
+
+    Returns:
+      An Arviz `InferenceData` object containing posterior samples only.
+
+    Raises:
+      MCMCOOMError: If the model is out of memory. Try reducing `n_keep` or
+        pass a list of integers as `n_chains` to sample chains serially. For
+        more information, see
+        [ResourceExhaustedError when running Meridian.sample_posterior]
+        (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
+    """
+    seed = tfp.random.sanitize_seed(seed) if seed else None
+    n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
+    total_chains = np.sum(n_chains_list)
+
+    states = []
+    traces = []
+    for n_chains_batch in n_chains_list:
+      try:
+        mcmc = _xla_windowed_adaptive_nuts(
+            n_draws=n_burnin + n_keep,
+            joint_dist=self._get_joint_dist(),
+            n_chains=n_chains_batch,
+            num_adaptation_steps=n_adapt,
+            current_state=current_state,
+            init_step_size=init_step_size,
+            dual_averaging_kwargs=dual_averaging_kwargs,
+            max_tree_depth=max_tree_depth,
+            max_energy_diff=max_energy_diff,
+            unrolled_leapfrog_steps=unrolled_leapfrog_steps,
+            parallel_iterations=parallel_iterations,
+            seed=seed,
+            **pins,
+        )
+      except tf.errors.ResourceExhaustedError as error:
+        raise MCMCOOMError(
+            "ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
+            " integers as `n_chains` to sample chains serially (see"
+            " https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error)"
+        ) from error
+      states.append(mcmc.all_states._asdict())
+      traces.append(mcmc.trace)
+
+    mcmc_states = {
+        k: tf.einsum(
+            "ij...->ji...",
+            tf.concat([state[k] for state in states], axis=1)[n_burnin:, ...],
+        )
+        for k in states[0].keys()
+        if k not in constants.UNSAVED_PARAMETERS
+    }
+    # Create Arviz InferenceData for posterior draws.
+    posterior_coords = self._meridian.create_inference_data_coords(
+        total_chains, n_keep
+    )
+    posterior_dims = self._meridian.create_inference_data_dims()
+    infdata_posterior = az.convert_to_inference_data(
+        mcmc_states, coords=posterior_coords, dims=posterior_dims
+    )
+
+    # Save trace metrics in InferenceData.
+    mcmc_trace = {}
+    for k in traces[0].keys():
+      if k not in constants.IGNORED_TRACE_METRICS:
+        mcmc_trace[k] = tf.concat(
+            [
+                tf.broadcast_to(
+                    tf.transpose(trace[k][n_burnin:, ...]),
+                    [n_chains_list[i], n_keep],
+                )
+                for i, trace in enumerate(traces)
+            ],
+            axis=0,
+        )
+
+    trace_coords = {
+        constants.CHAIN: np.arange(total_chains),
+        constants.DRAW: np.arange(n_keep),
+    }
+    trace_dims = {
+        k: [constants.CHAIN, constants.DRAW] for k in mcmc_trace.keys()
+    }
+    infdata_trace = az.convert_to_inference_data(
+        mcmc_trace, coords=trace_coords, dims=trace_dims, group="trace"
+    )
+
+    # Create Arviz InferenceData for divergent transitions and other sampling
+    # statistics. Note that InferenceData has a different naming convention
+    # than TensorFlow, and only certain variables are recognized.
+    # https://arviz-devs.github.io/arviz/schema/schema.html#sample-stats
+    # The list of values returned by windowed_adaptive_nuts() is the following:
+    # 'step_size', 'tune', 'target_log_prob', 'diverging', 'accept_ratio',
+    # 'variance_scaling', 'n_steps', 'is_accepted'.
+
+    sample_stats = {
+        constants.SAMPLE_STATS_METRICS[k]: v
+        for k, v in mcmc_trace.items()
+        if k in constants.SAMPLE_STATS_METRICS
+    }
+    sample_stats_dims = {
+        constants.SAMPLE_STATS_METRICS[k]: v
+        for k, v in trace_dims.items()
+        if k in constants.SAMPLE_STATS_METRICS
+    }
+    # TensorFlow does not include a "draw" dimension on the step size metric
+    # if the same step size is used for all chains. Step size must be
+    # broadcast to the correct shape.
+    sample_stats[constants.STEP_SIZE] = tf.broadcast_to(
+        sample_stats[constants.STEP_SIZE], [total_chains, n_keep]
+    )
+    sample_stats_dims[constants.STEP_SIZE] = [constants.CHAIN, constants.DRAW]
+    infdata_sample_stats = az.convert_to_inference_data(
+        sample_stats,
+        coords=trace_coords,
+        dims=sample_stats_dims,
+        group="sample_stats",
+    )
+    return az.concat(infdata_posterior, infdata_trace, infdata_sample_stats)
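To close, a hedged usage sketch. It assumes `mmm` is an already-constructed `meridian.model.Meridian`; in practice this sampler is normally reached through `Meridian.sample_posterior` rather than instantiated directly, and the argument values below are placeholders:

    from meridian.model import posterior_sampler

    sampler = posterior_sampler.PosteriorMCMCSampler(mmm)
    try:
      inference_data = sampler(
          n_chains=[4, 3],  # two serial sampling calls: 4 chains, then 3
          n_adapt=500,
          n_burnin=500,
          n_keep=1000,
          seed=[0, 1],
      )
    except posterior_sampler.MCMCOOMError:
      # Retry with a smaller n_keep or fewer chains per call.
      raise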