google-meridian 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,566 @@
+ # Copyright 2024 The Meridian Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Module for MCMC sampling of posterior distributions in a Meridian model."""
+
+ from collections.abc import Mapping, Sequence
+ from typing import TYPE_CHECKING
+
+ import arviz as az
+ from meridian import constants
+ import numpy as np
+ import tensorflow as tf
+ import tensorflow_probability as tfp
+
+ if TYPE_CHECKING:
+   from meridian.model import model  # pylint: disable=g-bad-import-order,g-import-not-at-top
+
+
+ __all__ = [
+     "MCMCSamplingError",
+     "MCMCOOMError",
+     "PosteriorMCMCSampler",
+ ]
+
+
+ class MCMCSamplingError(Exception):
+   """The Markov Chain Monte Carlo (MCMC) sampling failed."""
+
+
+ class MCMCOOMError(Exception):
+   """The Markov Chain Monte Carlo (MCMC) sampling exceeds memory limits."""
+
+
+ def _get_tau_g(
+     tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int
+ ) -> tfp.distributions.Distribution:
+   """Computes `tau_g` from `tau_g_excl_baseline`.
+
+   This function computes `tau_g` by inserting a column of zeros at the
+   `baseline_geo` position in `tau_g_excl_baseline`.
+
+   Args:
+     tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the
+       user-defined dimensions of the `tau_g` parameter distribution.
+     baseline_geo_idx: The index of the baseline geo to be set to zero.
+
+   Returns:
+     A `Deterministic` distribution over a tensor of shape `[..., n_geos]` for
+     the final `tau_g` parameter, with zero at position `baseline_geo_idx` and
+     matching `tau_g_excl_baseline` elsewhere.
+   """
+   rank = len(tau_g_excl_baseline.shape)
+   shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1
+   tau_g = tf.concat(
+       [
+           tau_g_excl_baseline[..., :baseline_geo_idx],
+           tf.zeros(shape, dtype=tau_g_excl_baseline.dtype),
+           tau_g_excl_baseline[..., baseline_geo_idx:],
+       ],
+       axis=rank - 1,
+   )
+   return tfp.distributions.Deterministic(tau_g, name="tau_g")
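+
+ # For illustration: with `n_geos = 3` and `baseline_geo_idx = 1`, an input
+ # `tau_g_excl_baseline` of `[t0, t2]` yields `tau_g = [t0, 0.0, t2]`; the
+ # baseline geo's intercept is pinned to zero while the remaining geos keep
+ # their sampled values.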
+
+
+ @tf.function(autograph=False, jit_compile=True)
+ def _xla_windowed_adaptive_nuts(**kwargs):
+   """XLA wrapper for windowed_adaptive_nuts."""
+   return tfp.experimental.mcmc.windowed_adaptive_nuts(**kwargs)
+
+
+ class PosteriorMCMCSampler:
+   """A callable that samples from posterior distributions using MCMC."""
+
+   def __init__(self, meridian: "model.Meridian"):
+     self._meridian = meridian
+
+   def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution:
+     """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC."""
+     mmm = self._meridian
+     mmm.populate_cached_properties()
+
+     # This lists all the derived properties and states of this Meridian
+     # object that are referenced by the joint distribution coroutine; that
+     # is, this is the list of captured parameters.
+     prior_broadcast = mmm.prior_broadcast
+     baseline_geo_idx = mmm.baseline_geo_idx
+     knot_info = mmm.knot_info
+     n_geos = mmm.n_geos
+     n_times = mmm.n_times
+     n_media_channels = mmm.n_media_channels
+     n_rf_channels = mmm.n_rf_channels
+     n_organic_media_channels = mmm.n_organic_media_channels
+     n_organic_rf_channels = mmm.n_organic_rf_channels
+     n_controls = mmm.n_controls
+     n_non_media_channels = mmm.n_non_media_channels
+     holdout_id = mmm.holdout_id
+     media_tensors = mmm.media_tensors
+     rf_tensors = mmm.rf_tensors
+     organic_media_tensors = mmm.organic_media_tensors
+     organic_rf_tensors = mmm.organic_rf_tensors
+     controls_scaled = mmm.controls_scaled
+     non_media_treatments_scaled = mmm.non_media_treatments_scaled
+     media_effects_dist = mmm.media_effects_dist
+     adstock_hill_media_fn = mmm.adstock_hill_media
+     adstock_hill_rf_fn = mmm.adstock_hill_rf
+     get_roi_prior_beta_m_value_fn = (
+         mmm.prior_sampler_callable.get_roi_prior_beta_m_value
+     )
+     get_roi_prior_beta_rf_value_fn = (
+         mmm.prior_sampler_callable.get_roi_prior_beta_rf_value
+     )
+
+     @tfp.distributions.JointDistributionCoroutineAutoBatched
+     def joint_dist_unpinned():
+       # Sample directly from prior.
+       knot_values = yield prior_broadcast.knot_values
+       gamma_c = yield prior_broadcast.gamma_c
+       xi_c = yield prior_broadcast.xi_c
+       sigma = yield prior_broadcast.sigma
+
+       tau_g_excl_baseline = yield tfp.distributions.Sample(
+           prior_broadcast.tau_g_excl_baseline,
+           name=constants.TAU_G_EXCL_BASELINE,
+       )
+       tau_g = yield _get_tau_g(
+           tau_g_excl_baseline=tau_g_excl_baseline,
+           baseline_geo_idx=baseline_geo_idx,
+       )
+       mu_t = yield tfp.distributions.Deterministic(
+           tf.einsum(
+               "k,kt->t",
+               knot_values,
+               tf.convert_to_tensor(knot_info.weights),
+           ),
+           name=constants.MU_T,
+       )
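+       # The einsum "k,kt->t" computes a weighted sum over knots, so each
+       # entry `mu_t[t]` equals `sum_k knot_values[k] * weights[k, t]`.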
+
+       tau_gt = tau_g[:, tf.newaxis] + mu_t
+       combined_media_transformed = tf.zeros(
+           shape=(n_geos, n_times, 0), dtype=tf.float32
+       )
+       combined_beta = tf.zeros(shape=(n_geos, 0), dtype=tf.float32)
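+       # Transformed media tensors and their geo-level effects are accumulated
+       # across channel groups (paid media, paid R&F, organic media, organic
+       # R&F) by concatenating along the trailing channel axis, starting from
+       # the empty zero-width tensors above.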
+       if media_tensors.media is not None:
+         alpha_m = yield prior_broadcast.alpha_m
+         ec_m = yield prior_broadcast.ec_m
+         eta_m = yield prior_broadcast.eta_m
+         slope_m = yield prior_broadcast.slope_m
+         beta_gm_dev = yield tfp.distributions.Sample(
+             tfp.distributions.Normal(0, 1),
+             [n_geos, n_media_channels],
+             name=constants.BETA_GM_DEV,
+         )
+         media_transformed = adstock_hill_media_fn(
+             media=media_tensors.media_scaled,
+             alpha=alpha_m,
+             ec=ec_m,
+             slope=slope_m,
+         )
+         prior_type = mmm.model_spec.paid_media_prior_type
+         if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
+           if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
+             roi_or_mroi_m = yield prior_broadcast.roi_m
+           else:
+             roi_or_mroi_m = yield prior_broadcast.mroi_m
+           beta_m_value = get_roi_prior_beta_m_value_fn(
+               alpha_m,
+               beta_gm_dev,
+               ec_m,
+               eta_m,
+               roi_or_mroi_m,
+               slope_m,
+               media_transformed,
+           )
+           beta_m = yield tfp.distributions.Deterministic(
+               beta_m_value, name=constants.BETA_M
+           )
+         else:
+           beta_m = yield prior_broadcast.beta_m
+
+         beta_eta_combined = beta_m + eta_m * beta_gm_dev
+         beta_gm_value = (
+             beta_eta_combined
+             if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+             else tf.math.exp(beta_eta_combined)
+         )
+         beta_gm = yield tfp.distributions.Deterministic(
+             beta_gm_value, name=constants.BETA_GM
+         )
+         combined_media_transformed = tf.concat(
+             [combined_media_transformed, media_transformed], axis=-1
+         )
+         combined_beta = tf.concat([combined_beta, beta_gm], axis=-1)
+
+       if rf_tensors.reach is not None:
+         alpha_rf = yield prior_broadcast.alpha_rf
+         ec_rf = yield prior_broadcast.ec_rf
+         eta_rf = yield prior_broadcast.eta_rf
+         slope_rf = yield prior_broadcast.slope_rf
+         beta_grf_dev = yield tfp.distributions.Sample(
+             tfp.distributions.Normal(0, 1),
+             [n_geos, n_rf_channels],
+             name=constants.BETA_GRF_DEV,
+         )
+         rf_transformed = adstock_hill_rf_fn(
+             reach=rf_tensors.reach_scaled,
+             frequency=rf_tensors.frequency,
+             alpha=alpha_rf,
+             ec=ec_rf,
+             slope=slope_rf,
+         )
+
+         prior_type = mmm.model_spec.paid_media_prior_type
+         if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES:
+           if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI:
+             roi_or_mroi_rf = yield prior_broadcast.roi_rf
+           else:
+             roi_or_mroi_rf = yield prior_broadcast.mroi_rf
+           beta_rf_value = get_roi_prior_beta_rf_value_fn(
+               alpha_rf,
+               beta_grf_dev,
+               ec_rf,
+               eta_rf,
+               roi_or_mroi_rf,
+               slope_rf,
+               rf_transformed,
+           )
+           beta_rf = yield tfp.distributions.Deterministic(
+               beta_rf_value,
+               name=constants.BETA_RF,
+           )
+         else:
+           beta_rf = yield prior_broadcast.beta_rf
+
+         beta_eta_combined = beta_rf + eta_rf * beta_grf_dev
+         beta_grf_value = (
+             beta_eta_combined
+             if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+             else tf.math.exp(beta_eta_combined)
+         )
+         beta_grf = yield tfp.distributions.Deterministic(
+             beta_grf_value, name=constants.BETA_GRF
+         )
+         combined_media_transformed = tf.concat(
+             [combined_media_transformed, rf_transformed], axis=-1
+         )
+         combined_beta = tf.concat([combined_beta, beta_grf], axis=-1)
+
+       if organic_media_tensors.organic_media is not None:
+         alpha_om = yield prior_broadcast.alpha_om
+         ec_om = yield prior_broadcast.ec_om
+         eta_om = yield prior_broadcast.eta_om
+         slope_om = yield prior_broadcast.slope_om
+         beta_gom_dev = yield tfp.distributions.Sample(
+             tfp.distributions.Normal(0, 1),
+             [n_geos, n_organic_media_channels],
+             name=constants.BETA_GOM_DEV,
+         )
+         organic_media_transformed = adstock_hill_media_fn(
+             media=organic_media_tensors.organic_media_scaled,
+             alpha=alpha_om,
+             ec=ec_om,
+             slope=slope_om,
+         )
+         beta_om = yield prior_broadcast.beta_om
+
+         beta_eta_combined = beta_om + eta_om * beta_gom_dev
+         beta_gom_value = (
+             beta_eta_combined
+             if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+             else tf.math.exp(beta_eta_combined)
+         )
+         beta_gom = yield tfp.distributions.Deterministic(
+             beta_gom_value, name=constants.BETA_GOM
+         )
+         combined_media_transformed = tf.concat(
+             [combined_media_transformed, organic_media_transformed], axis=-1
+         )
+         combined_beta = tf.concat([combined_beta, beta_gom], axis=-1)
+
+       if organic_rf_tensors.organic_reach is not None:
+         alpha_orf = yield prior_broadcast.alpha_orf
+         ec_orf = yield prior_broadcast.ec_orf
+         eta_orf = yield prior_broadcast.eta_orf
+         slope_orf = yield prior_broadcast.slope_orf
+         beta_gorf_dev = yield tfp.distributions.Sample(
+             tfp.distributions.Normal(0, 1),
+             [n_geos, n_organic_rf_channels],
+             name=constants.BETA_GORF_DEV,
+         )
+         organic_rf_transformed = adstock_hill_rf_fn(
+             reach=organic_rf_tensors.organic_reach_scaled,
+             frequency=organic_rf_tensors.organic_frequency,
+             alpha=alpha_orf,
+             ec=ec_orf,
+             slope=slope_orf,
+         )
+
+         beta_orf = yield prior_broadcast.beta_orf
+
+         beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev
+         beta_gorf_value = (
+             beta_eta_combined
+             if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL
+             else tf.math.exp(beta_eta_combined)
+         )
+         beta_gorf = yield tfp.distributions.Deterministic(
+             beta_gorf_value, name=constants.BETA_GORF
+         )
+         combined_media_transformed = tf.concat(
+             [combined_media_transformed, organic_rf_transformed], axis=-1
+         )
+         combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1)
+
+       sigma_gt = tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos]))
+       gamma_gc_dev = yield tfp.distributions.Sample(
+           tfp.distributions.Normal(0, 1),
+           [n_geos, n_controls],
+           name=constants.GAMMA_GC_DEV,
+       )
+       gamma_gc = yield tfp.distributions.Deterministic(
+           gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC
+       )
+       y_pred_combined_media = (
+           tau_gt
+           + tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta)
+           + tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc)
+       )
+
+       if mmm.non_media_treatments is not None:
+         gamma_n = yield prior_broadcast.gamma_n
+         xi_n = yield prior_broadcast.xi_n
+         gamma_gn_dev = yield tfp.distributions.Sample(
+             tfp.distributions.Normal(0, 1),
+             [n_geos, n_non_media_channels],
+             name=constants.GAMMA_GN_DEV,
+         )
+         gamma_gn = yield tfp.distributions.Deterministic(
+             gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN
+         )
+         y_pred = y_pred_combined_media + tf.einsum(
+             "gtn,gn->gt", non_media_treatments_scaled, gamma_gn
+         )
+       else:
+         y_pred = y_pred_combined_media
+
+       # If there are any holdout observations, the holdout KPI values will
+       # be replaced with zeros using `experimental_pin`. For these
+       # observations, we set the posterior mean equal to zero and standard
+       # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the
+       # sampled posterior parameter values.
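+       # (For a normal distribution, log N(0 | 0, sigma) = -log(sigma) -
+       # 0.5 * log(2 * pi), which is exactly zero when sigma = 1/sqrt(2pi).)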
+       if holdout_id is not None:
+         y_pred_holdout = tf.where(holdout_id, 0.0, y_pred)
+         test_sd = tf.cast(1.0 / np.sqrt(2.0 * np.pi), tf.float32)
+         sigma_gt_holdout = tf.where(holdout_id, test_sd, sigma_gt)
+         yield tfp.distributions.Normal(
+             y_pred_holdout, sigma_gt_holdout, name="y"
+         )
+       else:
+         yield tfp.distributions.Normal(y_pred, sigma_gt, name="y")
+
+     return joint_dist_unpinned
+
+   def _get_joint_dist(self) -> tfp.distributions.Distribution:
+     mmm = self._meridian
+     y = (
+         tf.where(mmm.holdout_id, 0.0, mmm.kpi_scaled)
+         if mmm.holdout_id is not None
+         else mmm.kpi_scaled
+     )
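+     # Pinning `y` conditions the joint distribution on the observed (scaled)
+     # KPI, so MCMC targets the posterior over the remaining parameters.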
+     return self._get_joint_dist_unpinned().experimental_pin(y=y)
+
+   def __call__(
+       self,
+       n_chains: Sequence[int] | int,
+       n_adapt: int,
+       n_burnin: int,
+       n_keep: int,
+       current_state: Mapping[str, tf.Tensor] | None = None,
+       init_step_size: int | None = None,
+       dual_averaging_kwargs: Mapping[str, int] | None = None,
+       max_tree_depth: int = 10,
+       max_energy_diff: float = 500.0,
+       unrolled_leapfrog_steps: int = 1,
+       parallel_iterations: int = 10,
+       seed: Sequence[int] | None = None,
+       **pins,
+   ) -> az.InferenceData:
+     """Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions.
+
+     For more information about the arguments, see
+     [`windowed_adaptive_nuts`](https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/mcmc/windowed_adaptive_nuts).
+
+     Args:
+       n_chains: Number of MCMC chains. Given a sequence of integers,
+         `windowed_adaptive_nuts` is called once for each element, with the
+         `n_chains` argument of that call set to the respective integer.
+         Passing a list of integers therefore splits the chains of a single
+         `windowed_adaptive_nuts` call into multiple calls with fewer chains
+         per call, which can reduce memory usage but might require an
+         increased number of adaptation steps for convergence, as the
+         adaptation occurs across fewer chains per sampling call.
+       n_adapt: Number of adaptation draws per chain.
+       n_burnin: Number of burn-in draws per chain. Burn-in draws occur after
+         adaptation draws and before the kept draws.
+       n_keep: Integer number of draws per chain to keep for inference.
+       current_state: Optional structure of tensors at which to initialize
+         sampling. Use the same shape and structure as
+         `model.experimental_pin(**pins).sample(n_chains)`.
+       init_step_size: Optional integer determining where to initialize the
+         step size for the leapfrog integrator. The structure must broadcast
+         with `current_state`. For example, if the initial state is
+         `{'a': tf.zeros(n_chains), 'b': tf.zeros([n_chains, n_features])}`,
+         then any of `1.`, `{'a': 1., 'b': 1.}`, or
+         `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}`
+         will work. Defaults to the dimension of the log density to the ¼
+         power.
+       dual_averaging_kwargs: Optional dict of keyword arguments to pass to
+         `tfp.mcmc.DualAveragingStepSizeAdaptation`. By default, a
+         `target_accept_prob` of `0.85` is set, acceptance probabilities
+         across chains are reduced using a harmonic mean, and the class
+         defaults are used otherwise.
+       max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
+         maximum number of leapfrog steps is bounded by `2**max_tree_depth`,
+         that is, the number of nodes in a binary tree `max_tree_depth` levels
+         deep. The default setting of `10` takes up to 1024 leapfrog steps.
+       max_energy_diff: Scalar threshold of energy differences at each
+         leapfrog; divergent samples are defined as leapfrog steps that exceed
+         this threshold. Default is `500`.
+       unrolled_leapfrog_steps: The number of leapfrog steps to unroll per
+         tree expansion step. Applies a direct linear multiplier to the
+         maximum trajectory length implied by `max_tree_depth`. Default is
+         `1`.
+       parallel_iterations: Number of iterations allowed to run in parallel.
+         Must be a positive integer. For more information, see
+         `tf.while_loop`.
+       seed: Used to set the seed for reproducible results. For more
+         information, see
+         [PRNGS and seeds](https://github.com/tensorflow/probability/blob/main/PRNGS.md).
+       **pins: These are used to condition the provided joint distribution,
+         and are passed directly to `joint_dist.experimental_pin(**pins)`.
+
+     Returns:
+       An Arviz `InferenceData` object containing the posterior draws, along
+       with trace metrics and sampling statistics.
+
+     Raises:
+       MCMCOOMError: If the model is out of memory. Try reducing `n_keep` or
+         pass a list of integers as `n_chains` to sample chains serially. For
+         more information, see
+         [ResourceExhaustedError when running Meridian.sample_posterior](https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error).
+     """
+     seed = tfp.random.sanitize_seed(seed) if seed else None
+     n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains
+     total_chains = np.sum(n_chains_list)
+
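+     # Each entry of `n_chains_list` becomes one `windowed_adaptive_nuts`
+     # call below; sampling the batches serially trades wall-clock time for
+     # lower peak memory.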
+     states = []
+     traces = []
+     for n_chains_batch in n_chains_list:
+       try:
+         mcmc = _xla_windowed_adaptive_nuts(
+             n_draws=n_burnin + n_keep,
+             joint_dist=self._get_joint_dist(),
+             n_chains=n_chains_batch,
+             num_adaptation_steps=n_adapt,
+             current_state=current_state,
+             init_step_size=init_step_size,
+             dual_averaging_kwargs=dual_averaging_kwargs,
+             max_tree_depth=max_tree_depth,
+             max_energy_diff=max_energy_diff,
+             unrolled_leapfrog_steps=unrolled_leapfrog_steps,
+             parallel_iterations=parallel_iterations,
+             seed=seed,
+             **pins,
+         )
+       except tf.errors.ResourceExhaustedError as error:
+         raise MCMCOOMError(
+             "ERROR: Out of memory. Try reducing `n_keep` or pass a list of"
+             " integers as `n_chains` to sample chains serially (see"
+             " https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error)"
+         ) from error
+       states.append(mcmc.all_states._asdict())
+       traces.append(mcmc.trace)
+
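+     # Concatenate the per-batch states along the chain axis, drop the
+     # burn-in draws, and transpose the leading (draw, chain) axes into the
+     # (chain, draw, ...) layout that Arviz expects.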
+     mcmc_states = {
+         k: tf.einsum(
+             "ij...->ji...",
+             tf.concat([state[k] for state in states], axis=1)[n_burnin:, ...],
+         )
+         for k in states[0].keys()
+         if k not in constants.UNSAVED_PARAMETERS
+     }
+     # Create Arviz InferenceData for posterior draws.
+     posterior_coords = self._meridian.create_inference_data_coords(
+         total_chains, n_keep
+     )
+     posterior_dims = self._meridian.create_inference_data_dims()
+     infdata_posterior = az.convert_to_inference_data(
+         mcmc_states, coords=posterior_coords, dims=posterior_dims
+     )
+
+     # Save trace metrics in InferenceData.
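+     # Some trace metrics (e.g. a shared step size) lack a chain dimension,
+     # so each batch is transposed and broadcast to a full [n_chains, n_keep]
+     # shape before the batches are concatenated along the chain axis.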
+     mcmc_trace = {}
+     for k in traces[0].keys():
+       if k not in constants.IGNORED_TRACE_METRICS:
+         mcmc_trace[k] = tf.concat(
+             [
+                 tf.broadcast_to(
+                     tf.transpose(trace[k][n_burnin:, ...]),
+                     [n_chains_list[i], n_keep],
+                 )
+                 for i, trace in enumerate(traces)
+             ],
+             axis=0,
+         )
+
+     trace_coords = {
+         constants.CHAIN: np.arange(total_chains),
+         constants.DRAW: np.arange(n_keep),
+     }
+     trace_dims = {
+         k: [constants.CHAIN, constants.DRAW] for k in mcmc_trace.keys()
+     }
+     infdata_trace = az.convert_to_inference_data(
+         mcmc_trace, coords=trace_coords, dims=trace_dims, group="trace"
+     )
+
+     # Create Arviz InferenceData for divergent transitions and other sampling
+     # statistics. Note that InferenceData uses a different naming convention
+     # than TensorFlow, and only certain variables are recognized:
+     # https://arviz-devs.github.io/arviz/schema/schema.html#sample-stats
+     # The values returned by windowed_adaptive_nuts() are: 'step_size',
+     # 'tune', 'target_log_prob', 'diverging', 'accept_ratio',
+     # 'variance_scaling', 'n_steps', 'is_accepted'.
+
+     sample_stats = {
+         constants.SAMPLE_STATS_METRICS[k]: v
+         for k, v in mcmc_trace.items()
+         if k in constants.SAMPLE_STATS_METRICS
+     }
+     sample_stats_dims = {
+         constants.SAMPLE_STATS_METRICS[k]: v
+         for k, v in trace_dims.items()
+         if k in constants.SAMPLE_STATS_METRICS
+     }
+     # TensorFlow does not include a "draw" dimension on the step size metric
+     # if the same step size is used for all chains. The step size must be
+     # broadcast to the correct shape.
+     sample_stats[constants.STEP_SIZE] = tf.broadcast_to(
+         sample_stats[constants.STEP_SIZE], [total_chains, n_keep]
+     )
+     sample_stats_dims[constants.STEP_SIZE] = [constants.CHAIN, constants.DRAW]
+     infdata_sample_stats = az.convert_to_inference_data(
+         sample_stats,
+         coords=trace_coords,
+         dims=sample_stats_dims,
+         group="sample_stats",
+     )
+     return az.concat(infdata_posterior, infdata_trace, infdata_sample_stats)
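+
+
+ # Example usage (illustrative sketch, not part of the packaged module;
+ # `meridian_model` stands in for a constructed `model.Meridian` instance and
+ # the argument values below are arbitrary, not recommendations):
+ #
+ #   sampler = PosteriorMCMCSampler(meridian_model)
+ #   inference_data = sampler(
+ #       n_chains=[5, 5],  # two serial batches of 5 chains to limit memory
+ #       n_adapt=500,
+ #       n_burnin=500,
+ #       n_keep=1000,
+ #       seed=[0, 1],
+ #   )
+ #   # `inference_data.posterior` then holds draws with shape
+ #   # (chain=10, draw=1000, ...) for each saved parameter.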