google-meridian 1.1.6__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/METADATA +8 -2
- google_meridian-1.2.0.dist-info/RECORD +52 -0
- meridian/__init__.py +1 -0
- meridian/analysis/analyzer.py +526 -362
- meridian/analysis/optimizer.py +275 -267
- meridian/analysis/test_utils.py +96 -94
- meridian/analysis/visualizer.py +37 -49
- meridian/backend/__init__.py +514 -0
- meridian/backend/config.py +59 -0
- meridian/backend/test_utils.py +95 -0
- meridian/constants.py +59 -3
- meridian/data/input_data.py +94 -0
- meridian/data/test_utils.py +144 -12
- meridian/model/adstock_hill.py +279 -33
- meridian/model/eda/__init__.py +17 -0
- meridian/model/eda/eda_engine.py +306 -0
- meridian/model/knots.py +525 -2
- meridian/model/media.py +62 -54
- meridian/model/model.py +224 -97
- meridian/model/model_test_data.py +323 -157
- meridian/model/posterior_sampler.py +81 -76
- meridian/model/prior_distribution.py +538 -168
- meridian/model/prior_sampler.py +65 -65
- meridian/model/spec.py +23 -3
- meridian/model/transformers.py +53 -47
- meridian/version.py +1 -1
- google_meridian-1.1.6.dist-info/RECORD +0 -47
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/WHEEL +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {google_meridian-1.1.6.dist-info → google_meridian-1.2.0.dist-info}/top_level.txt +0 -0
|
@@ -18,10 +18,9 @@ from collections.abc import Mapping, Sequence
|
|
|
18
18
|
from typing import TYPE_CHECKING
|
|
19
19
|
|
|
20
20
|
import arviz as az
|
|
21
|
+
from meridian import backend
|
|
21
22
|
from meridian import constants
|
|
22
23
|
import numpy as np
|
|
23
|
-
import tensorflow as tf
|
|
24
|
-
import tensorflow_probability as tfp
|
|
25
24
|
|
|
26
25
|
if TYPE_CHECKING:
|
|
27
26
|
from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
|
|
@@ -43,8 +42,8 @@ class MCMCOOMError(Exception):
|
|
|
43
42
|
|
|
44
43
|
|
|
45
44
|
def _get_tau_g(
|
|
46
|
-
tau_g_excl_baseline:
|
|
47
|
-
) ->
|
|
45
|
+
tau_g_excl_baseline: backend.Tensor, baseline_geo_idx: int
|
|
46
|
+
) -> backend.tfd.Distribution:
|
|
48
47
|
"""Computes `tau_g` from `tau_g_excl_baseline`.
|
|
49
48
|
|
|
50
49
|
This function computes `tau_g` by inserting a column of zeros at the
|
|
@@ -62,21 +61,21 @@ def _get_tau_g(
|
|
|
62
61
|
"""
|
|
63
62
|
rank = len(tau_g_excl_baseline.shape)
|
|
64
63
|
shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1
|
|
65
|
-
tau_g =
|
|
64
|
+
tau_g = backend.concatenate(
|
|
66
65
|
[
|
|
67
66
|
tau_g_excl_baseline[..., :baseline_geo_idx],
|
|
68
|
-
|
|
67
|
+
backend.zeros(shape, dtype=tau_g_excl_baseline.dtype),
|
|
69
68
|
tau_g_excl_baseline[..., baseline_geo_idx:],
|
|
70
69
|
],
|
|
71
70
|
axis=rank - 1,
|
|
72
71
|
)
|
|
73
|
-
return
|
|
72
|
+
return backend.tfd.Deterministic(tau_g, name="tau_g")
|
|
74
73
|
|
|
75
74
|
|
|
76
|
-
@
|
|
75
|
+
@backend.function(autograph=False, jit_compile=True)
|
|
77
76
|
def _xla_windowed_adaptive_nuts(**kwargs):
|
|
78
77
|
"""XLA wrapper for windowed_adaptive_nuts."""
|
|
79
|
-
return
|
|
78
|
+
return backend.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
|
|
80
79
|
|
|
81
80
|
|
|
82
81
|
class PosteriorMCMCSampler:
|
|
@@ -89,7 +88,7 @@ class PosteriorMCMCSampler:
|
|
|
89
88
|
def model(self) -> "model.Meridian":
|
|
90
89
|
return self._meridian
|
|
91
90
|
|
|
92
|
-
def _get_joint_dist_unpinned(self) ->
|
|
91
|
+
def _get_joint_dist_unpinned(self) -> backend.tfd.Distribution:
|
|
93
92
|
"""Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
|
|
94
93
|
mmm = self.model
|
|
95
94
|
mmm.populate_cached_properties()
|
|
@@ -120,13 +119,13 @@ class PosteriorMCMCSampler:
|
|
|
120
119
|
adstock_hill_rf_fn = mmm.adstock_hill_rf
|
|
121
120
|
total_outcome = mmm.total_outcome
|
|
122
121
|
|
|
123
|
-
@
|
|
122
|
+
@backend.tfd.JointDistributionCoroutineAutoBatched
|
|
124
123
|
def joint_dist_unpinned():
|
|
125
124
|
# Sample directly from prior.
|
|
126
125
|
knot_values = yield prior_broadcast.knot_values
|
|
127
126
|
sigma = yield prior_broadcast.sigma
|
|
128
127
|
|
|
129
|
-
tau_g_excl_baseline = yield
|
|
128
|
+
tau_g_excl_baseline = yield backend.tfd.Sample(
|
|
130
129
|
prior_broadcast.tau_g_excl_baseline,
|
|
131
130
|
name=constants.TAU_G_EXCL_BASELINE,
|
|
132
131
|
)
|
|
@@ -134,27 +133,27 @@ class PosteriorMCMCSampler:
|
|
|
134
133
|
tau_g_excl_baseline=tau_g_excl_baseline,
|
|
135
134
|
baseline_geo_idx=baseline_geo_idx,
|
|
136
135
|
)
|
|
137
|
-
mu_t = yield
|
|
138
|
-
|
|
136
|
+
mu_t = yield backend.tfd.Deterministic(
|
|
137
|
+
backend.einsum(
|
|
139
138
|
"k,kt->t",
|
|
140
139
|
knot_values,
|
|
141
|
-
|
|
140
|
+
backend.to_tensor(knot_info.weights),
|
|
142
141
|
),
|
|
143
142
|
name=constants.MU_T,
|
|
144
143
|
)
|
|
145
144
|
|
|
146
|
-
tau_gt = tau_g[:,
|
|
147
|
-
combined_media_transformed =
|
|
148
|
-
shape=(n_geos, n_times, 0), dtype=
|
|
145
|
+
tau_gt = tau_g[:, backend.newaxis] + mu_t
|
|
146
|
+
combined_media_transformed = backend.zeros(
|
|
147
|
+
shape=(n_geos, n_times, 0), dtype=backend.float32
|
|
149
148
|
)
|
|
150
|
-
combined_beta =
|
|
149
|
+
combined_beta = backend.zeros(shape=(n_geos, 0), dtype=backend.float32)
|
|
151
150
|
if media_tensors.media is not None:
|
|
152
151
|
alpha_m = yield prior_broadcast.alpha_m
|
|
153
152
|
ec_m = yield prior_broadcast.ec_m
|
|
154
153
|
eta_m = yield prior_broadcast.eta_m
|
|
155
154
|
slope_m = yield prior_broadcast.slope_m
|
|
156
|
-
beta_gm_dev = yield
|
|
157
|
-
|
|
155
|
+
beta_gm_dev = yield backend.tfd.Sample(
|
|
156
|
+
backend.tfd.Normal(0, 1),
|
|
158
157
|
[n_geos, n_media_channels],
|
|
159
158
|
name=constants.BETA_GM_DEV,
|
|
160
159
|
)
|
|
@@ -163,6 +162,7 @@ class PosteriorMCMCSampler:
|
|
|
163
162
|
alpha=alpha_m,
|
|
164
163
|
ec=ec_m,
|
|
165
164
|
slope=slope_m,
|
|
165
|
+
decay_functions=mmm.adstock_decay_spec.media,
|
|
166
166
|
)
|
|
167
167
|
prior_type = mmm.model_spec.effective_media_prior_type
|
|
168
168
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
@@ -194,7 +194,7 @@ class PosteriorMCMCSampler:
|
|
|
194
194
|
eta_x=eta_m,
|
|
195
195
|
beta_gx_dev=beta_gm_dev,
|
|
196
196
|
)
|
|
197
|
-
beta_m = yield
|
|
197
|
+
beta_m = yield backend.tfd.Deterministic(
|
|
198
198
|
beta_m_value, name=constants.BETA_M
|
|
199
199
|
)
|
|
200
200
|
|
|
@@ -202,23 +202,23 @@ class PosteriorMCMCSampler:
|
|
|
202
202
|
beta_gm_value = (
|
|
203
203
|
beta_eta_combined
|
|
204
204
|
if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
205
|
-
else
|
|
205
|
+
else backend.exp(beta_eta_combined)
|
|
206
206
|
)
|
|
207
|
-
beta_gm = yield
|
|
207
|
+
beta_gm = yield backend.tfd.Deterministic(
|
|
208
208
|
beta_gm_value, name=constants.BETA_GM
|
|
209
209
|
)
|
|
210
|
-
combined_media_transformed =
|
|
210
|
+
combined_media_transformed = backend.concatenate(
|
|
211
211
|
[combined_media_transformed, media_transformed], axis=-1
|
|
212
212
|
)
|
|
213
|
-
combined_beta =
|
|
213
|
+
combined_beta = backend.concatenate([combined_beta, beta_gm], axis=-1)
|
|
214
214
|
|
|
215
215
|
if rf_tensors.reach is not None:
|
|
216
216
|
alpha_rf = yield prior_broadcast.alpha_rf
|
|
217
217
|
ec_rf = yield prior_broadcast.ec_rf
|
|
218
218
|
eta_rf = yield prior_broadcast.eta_rf
|
|
219
219
|
slope_rf = yield prior_broadcast.slope_rf
|
|
220
|
-
beta_grf_dev = yield
|
|
221
|
-
|
|
220
|
+
beta_grf_dev = yield backend.tfd.Sample(
|
|
221
|
+
backend.tfd.Normal(0, 1),
|
|
222
222
|
[n_geos, n_rf_channels],
|
|
223
223
|
name=constants.BETA_GRF_DEV,
|
|
224
224
|
)
|
|
@@ -228,6 +228,7 @@ class PosteriorMCMCSampler:
|
|
|
228
228
|
alpha=alpha_rf,
|
|
229
229
|
ec=ec_rf,
|
|
230
230
|
slope=slope_rf,
|
|
231
|
+
decay_functions=mmm.adstock_decay_spec.rf,
|
|
231
232
|
)
|
|
232
233
|
|
|
233
234
|
prior_type = mmm.model_spec.effective_rf_prior_type
|
|
@@ -260,7 +261,7 @@ class PosteriorMCMCSampler:
|
|
|
260
261
|
eta_x=eta_rf,
|
|
261
262
|
beta_gx_dev=beta_grf_dev,
|
|
262
263
|
)
|
|
263
|
-
beta_rf = yield
|
|
264
|
+
beta_rf = yield backend.tfd.Deterministic(
|
|
264
265
|
beta_rf_value, name=constants.BETA_RF
|
|
265
266
|
)
|
|
266
267
|
|
|
@@ -268,23 +269,23 @@ class PosteriorMCMCSampler:
|
|
|
268
269
|
beta_grf_value = (
|
|
269
270
|
beta_eta_combined
|
|
270
271
|
if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
271
|
-
else
|
|
272
|
+
else backend.exp(beta_eta_combined)
|
|
272
273
|
)
|
|
273
|
-
beta_grf = yield
|
|
274
|
+
beta_grf = yield backend.tfd.Deterministic(
|
|
274
275
|
beta_grf_value, name=constants.BETA_GRF
|
|
275
276
|
)
|
|
276
|
-
combined_media_transformed =
|
|
277
|
+
combined_media_transformed = backend.concatenate(
|
|
277
278
|
[combined_media_transformed, rf_transformed], axis=-1
|
|
278
279
|
)
|
|
279
|
-
combined_beta =
|
|
280
|
+
combined_beta = backend.concatenate([combined_beta, beta_grf], axis=-1)
|
|
280
281
|
|
|
281
282
|
if organic_media_tensors.organic_media is not None:
|
|
282
283
|
alpha_om = yield prior_broadcast.alpha_om
|
|
283
284
|
ec_om = yield prior_broadcast.ec_om
|
|
284
285
|
eta_om = yield prior_broadcast.eta_om
|
|
285
286
|
slope_om = yield prior_broadcast.slope_om
|
|
286
|
-
beta_gom_dev = yield
|
|
287
|
-
|
|
287
|
+
beta_gom_dev = yield backend.tfd.Sample(
|
|
288
|
+
backend.tfd.Normal(0, 1),
|
|
288
289
|
[n_geos, n_organic_media_channels],
|
|
289
290
|
name=constants.BETA_GOM_DEV,
|
|
290
291
|
)
|
|
@@ -293,6 +294,7 @@ class PosteriorMCMCSampler:
|
|
|
293
294
|
alpha=alpha_om,
|
|
294
295
|
ec=ec_om,
|
|
295
296
|
slope=slope_om,
|
|
297
|
+
decay_functions=mmm.adstock_decay_spec.organic_media,
|
|
296
298
|
)
|
|
297
299
|
prior_type = mmm.model_spec.organic_media_prior_type
|
|
298
300
|
if prior_type == constants.TREATMENT_PRIOR_TYPE_COEFFICIENT:
|
|
@@ -307,7 +309,7 @@ class PosteriorMCMCSampler:
|
|
|
307
309
|
eta_x=eta_om,
|
|
308
310
|
beta_gx_dev=beta_gom_dev,
|
|
309
311
|
)
|
|
310
|
-
beta_om = yield
|
|
312
|
+
beta_om = yield backend.tfd.Deterministic(
|
|
311
313
|
beta_om_value, name=constants.BETA_OM
|
|
312
314
|
)
|
|
313
315
|
else:
|
|
@@ -317,23 +319,23 @@ class PosteriorMCMCSampler:
|
|
|
317
319
|
beta_gom_value = (
|
|
318
320
|
beta_eta_combined
|
|
319
321
|
if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
320
|
-
else
|
|
322
|
+
else backend.exp(beta_eta_combined)
|
|
321
323
|
)
|
|
322
|
-
beta_gom = yield
|
|
324
|
+
beta_gom = yield backend.tfd.Deterministic(
|
|
323
325
|
beta_gom_value, name=constants.BETA_GOM
|
|
324
326
|
)
|
|
325
|
-
combined_media_transformed =
|
|
327
|
+
combined_media_transformed = backend.concatenate(
|
|
326
328
|
[combined_media_transformed, organic_media_transformed], axis=-1
|
|
327
329
|
)
|
|
328
|
-
combined_beta =
|
|
330
|
+
combined_beta = backend.concatenate([combined_beta, beta_gom], axis=-1)
|
|
329
331
|
|
|
330
332
|
if organic_rf_tensors.organic_reach is not None:
|
|
331
333
|
alpha_orf = yield prior_broadcast.alpha_orf
|
|
332
334
|
ec_orf = yield prior_broadcast.ec_orf
|
|
333
335
|
eta_orf = yield prior_broadcast.eta_orf
|
|
334
336
|
slope_orf = yield prior_broadcast.slope_orf
|
|
335
|
-
beta_gorf_dev = yield
|
|
336
|
-
|
|
337
|
+
beta_gorf_dev = yield backend.tfd.Sample(
|
|
338
|
+
backend.tfd.Normal(0, 1),
|
|
337
339
|
[n_geos, n_organic_rf_channels],
|
|
338
340
|
name=constants.BETA_GORF_DEV,
|
|
339
341
|
)
|
|
@@ -343,6 +345,7 @@ class PosteriorMCMCSampler:
|
|
|
343
345
|
alpha=alpha_orf,
|
|
344
346
|
ec=ec_orf,
|
|
345
347
|
slope=slope_orf,
|
|
348
|
+
decay_functions=mmm.adstock_decay_spec.organic_rf,
|
|
346
349
|
)
|
|
347
350
|
|
|
348
351
|
prior_type = mmm.model_spec.organic_rf_prior_type
|
|
@@ -358,7 +361,7 @@ class PosteriorMCMCSampler:
|
|
|
358
361
|
eta_x=eta_orf,
|
|
359
362
|
beta_gx_dev=beta_gorf_dev,
|
|
360
363
|
)
|
|
361
|
-
beta_orf = yield
|
|
364
|
+
beta_orf = yield backend.tfd.Deterministic(
|
|
362
365
|
beta_orf_value, name=constants.BETA_ORF
|
|
363
366
|
)
|
|
364
367
|
else:
|
|
@@ -368,18 +371,20 @@ class PosteriorMCMCSampler:
|
|
|
368
371
|
beta_gorf_value = (
|
|
369
372
|
beta_eta_combined
|
|
370
373
|
if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
|
|
371
|
-
else
|
|
374
|
+
else backend.exp(beta_eta_combined)
|
|
372
375
|
)
|
|
373
|
-
beta_gorf = yield
|
|
376
|
+
beta_gorf = yield backend.tfd.Deterministic(
|
|
374
377
|
beta_gorf_value, name=constants.BETA_GORF
|
|
375
378
|
)
|
|
376
|
-
combined_media_transformed =
|
|
379
|
+
combined_media_transformed = backend.concatenate(
|
|
377
380
|
[combined_media_transformed, organic_rf_transformed], axis=-1
|
|
378
381
|
)
|
|
379
|
-
combined_beta =
|
|
382
|
+
combined_beta = backend.concatenate([combined_beta, beta_gorf], axis=-1)
|
|
380
383
|
|
|
381
|
-
sigma_gt =
|
|
382
|
-
|
|
384
|
+
sigma_gt = backend.transpose(
|
|
385
|
+
backend.broadcast_to(sigma, [n_times, n_geos])
|
|
386
|
+
)
|
|
387
|
+
y_pred_combined_media = tau_gt + backend.einsum(
|
|
383
388
|
"gtm,gm->gt", combined_media_transformed, combined_beta
|
|
384
389
|
)
|
|
385
390
|
# Omit gamma_c, xi_c, and gamma_gc from joint distribution output if
|
|
@@ -387,22 +392,22 @@ class PosteriorMCMCSampler:
|
|
|
387
392
|
if n_controls:
|
|
388
393
|
gamma_c = yield prior_broadcast.gamma_c
|
|
389
394
|
xi_c = yield prior_broadcast.xi_c
|
|
390
|
-
gamma_gc_dev = yield
|
|
391
|
-
|
|
395
|
+
gamma_gc_dev = yield backend.tfd.Sample(
|
|
396
|
+
backend.tfd.Normal(0, 1),
|
|
392
397
|
[n_geos, n_controls],
|
|
393
398
|
name=constants.GAMMA_GC_DEV,
|
|
394
399
|
)
|
|
395
|
-
gamma_gc = yield
|
|
400
|
+
gamma_gc = yield backend.tfd.Deterministic(
|
|
396
401
|
gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
|
|
397
402
|
)
|
|
398
|
-
y_pred_combined_media +=
|
|
403
|
+
y_pred_combined_media += backend.einsum(
|
|
399
404
|
"gtc,gc->gt", controls_scaled, gamma_gc
|
|
400
405
|
)
|
|
401
406
|
|
|
402
407
|
if mmm.non_media_treatments is not None:
|
|
403
408
|
xi_n = yield prior_broadcast.xi_n
|
|
404
|
-
gamma_gn_dev = yield
|
|
405
|
-
|
|
409
|
+
gamma_gn_dev = yield backend.tfd.Sample(
|
|
410
|
+
backend.tfd.Normal(0, 1),
|
|
406
411
|
[n_geos, n_non_media_channels],
|
|
407
412
|
name=constants.GAMMA_GN_DEV,
|
|
408
413
|
)
|
|
@@ -425,15 +430,15 @@ class PosteriorMCMCSampler:
|
|
|
425
430
|
eta_x=xi_n,
|
|
426
431
|
beta_gx_dev=gamma_gn_dev,
|
|
427
432
|
)
|
|
428
|
-
gamma_n = yield
|
|
433
|
+
gamma_n = yield backend.tfd.Deterministic(
|
|
429
434
|
gamma_n_value, name=constants.GAMMA_N
|
|
430
435
|
)
|
|
431
436
|
else:
|
|
432
437
|
raise ValueError(f"Unsupported prior type: {prior_type}")
|
|
433
|
-
gamma_gn = yield
|
|
438
|
+
gamma_gn = yield backend.tfd.Deterministic(
|
|
434
439
|
gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
|
|
435
440
|
)
|
|
436
|
-
y_pred = y_pred_combined_media +
|
|
441
|
+
y_pred = y_pred_combined_media + backend.einsum(
|
|
437
442
|
"gtn,gn->gt", non_media_treatments_normalized, gamma_gn
|
|
438
443
|
)
|
|
439
444
|
else:
|
|
@@ -445,21 +450,19 @@ class PosteriorMCMCSampler:
|
|
|
445
450
|
# deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the
|
|
446
451
|
# sampled posterior parameter values.
|
|
447
452
|
if holdout_id is not None:
|
|
448
|
-
y_pred_holdout =
|
|
449
|
-
test_sd =
|
|
450
|
-
sigma_gt_holdout =
|
|
451
|
-
yield
|
|
452
|
-
y_pred_holdout, sigma_gt_holdout, name="y"
|
|
453
|
-
)
|
|
453
|
+
y_pred_holdout = backend.where(holdout_id, 0.0, y_pred)
|
|
454
|
+
test_sd = backend.cast(1.0 / np.sqrt(2.0 * np.pi), backend.float32)
|
|
455
|
+
sigma_gt_holdout = backend.where(holdout_id, test_sd, sigma_gt)
|
|
456
|
+
yield backend.tfd.Normal(y_pred_holdout, sigma_gt_holdout, name="y")
|
|
454
457
|
else:
|
|
455
|
-
yield
|
|
458
|
+
yield backend.tfd.Normal(y_pred, sigma_gt, name="y")
|
|
456
459
|
|
|
457
460
|
return joint_dist_unpinned
|
|
458
461
|
|
|
459
|
-
def _get_joint_dist(self) ->
|
|
462
|
+
def _get_joint_dist(self) -> backend.tfd.Distribution:
|
|
460
463
|
mmm = self.model
|
|
461
464
|
y = (
|
|
462
|
-
|
|
465
|
+
backend.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
|
|
463
466
|
if mmm.holdout_id is not None
|
|
464
467
|
else mmm.kpi_scaled
|
|
465
468
|
)
|
|
@@ -471,7 +474,7 @@ class PosteriorMCMCSampler:
|
|
|
471
474
|
n_adapt: int,
|
|
472
475
|
n_burnin: int,
|
|
473
476
|
n_keep: int,
|
|
474
|
-
current_state: Mapping[str,
|
|
477
|
+
current_state: Mapping[str, backend.Tensor] | None = None,
|
|
475
478
|
init_step_size: int | None = None,
|
|
476
479
|
dual_averaging_kwargs: Mapping[str, int] | None = None,
|
|
477
480
|
max_tree_depth: int = 10,
|
|
@@ -549,7 +552,7 @@ class PosteriorMCMCSampler:
|
|
|
549
552
|
)
|
|
550
553
|
if seed is not None and isinstance(seed, int):
|
|
551
554
|
seed = (seed, seed)
|
|
552
|
-
seed =
|
|
555
|
+
seed = backend.random.sanitize_seed(seed) if seed is not None else None
|
|
553
556
|
n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
|
|
554
557
|
total_chains = np.sum(n_chains_list)
|
|
555
558
|
|
|
@@ -572,7 +575,7 @@ class PosteriorMCMCSampler:
|
|
|
572
575
|
seed=seed,
|
|
573
576
|
**pins,
|
|
574
577
|
)
|
|
575
|
-
except
|
|
578
|
+
except backend.errors.ResourceExhaustedError as error:
|
|
576
579
|
raise MCMCOOMError(
|
|
577
580
|
"ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
|
|
578
581
|
" integers as `n_chains` to sample chains serially (see"
|
|
@@ -584,9 +587,11 @@ class PosteriorMCMCSampler:
|
|
|
584
587
|
traces.append(mcmc.trace)
|
|
585
588
|
|
|
586
589
|
mcmc_states = {
|
|
587
|
-
k:
|
|
590
|
+
k: backend.einsum(
|
|
588
591
|
"ij...->ji...",
|
|
589
|
-
|
|
592
|
+
backend.concatenate([state[k] for state in states], axis=1)[
|
|
593
|
+
n_burnin:, ...
|
|
594
|
+
],
|
|
590
595
|
)
|
|
591
596
|
for k in states[0].keys()
|
|
592
597
|
if k not in constants.UNSAVED_PARAMETERS
|
|
@@ -604,10 +609,10 @@ class PosteriorMCMCSampler:
|
|
|
604
609
|
mcmc_trace = {}
|
|
605
610
|
for k in traces[0].keys():
|
|
606
611
|
if k not in constants.IGNORED_TRACE_METRICS:
|
|
607
|
-
mcmc_trace[k] =
|
|
612
|
+
mcmc_trace[k] = backend.concatenate(
|
|
608
613
|
[
|
|
609
|
-
|
|
610
|
-
|
|
614
|
+
backend.broadcast_to(
|
|
615
|
+
backend.transpose(trace[k][n_burnin:, ...]),
|
|
611
616
|
[n_chains_list[i], n_keep],
|
|
612
617
|
)
|
|
613
618
|
for i, trace in enumerate(traces)
|
|
@@ -647,7 +652,7 @@ class PosteriorMCMCSampler:
|
|
|
647
652
|
# Tensorflow does not include a "draw" dimension on step size metric if same
|
|
648
653
|
# step size is used for all chains. Step size must be broadcast to the
|
|
649
654
|
# correct shape.
|
|
650
|
-
sample_stats[constants.STEP_SIZE] =
|
|
655
|
+
sample_stats[constants.STEP_SIZE] = backend.broadcast_to(
|
|
651
656
|
sample_stats[constants.STEP_SIZE], [total_chains, n_keep]
|
|
652
657
|
)
|
|
653
658
|
sample_stats_dims[constants.STEP_SIZE] = [constants.CHAIN, constants.DRAW]
|