eryn 1.2.4__py3-none-any.whl → 1.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/utils/plot.py ADDED
@@ -0,0 +1,1393 @@
1
+ # *-- coding: utf-8 --*
2
+ import os
3
+ import numpy as np
4
+ import matplotlib as mpl
5
+ import matplotlib.pyplot as plt
6
+ from matplotlib.patches import Ellipse, Rectangle
7
+
8
+ from matplotlib.colors import to_rgba
9
+
10
+ import corner
11
+ import typing
12
+
13
+ from eryn.utils.updates import UpdateStep
14
+
15
+ import pandas as pd
16
+ import seaborn as sns
17
+
18
+ from eryn.utils.utility import stepping_stone_log_evidence, get_integrated_act
19
# Seaborn palette used by every temperature-coloured plot in this module.
DEFAULT_PALETTE = "icefire"

# Use the SciencePlots "science" style when the optional package is present;
# importing ``scienceplots`` is what registers that style with matplotlib.
try:
    import scienceplots
    plt.style.use(["science"])
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    pass

# Larger default font for every figure created after this module is imported.
mpl.rcParams["font.size"] = 16
29
+
30
class Backend:
    """Placeholder ``Backend`` type, present only for type hinting; it has no behaviour."""
33
+
34
def save_or_show(fig, filename=None):
    """
    Write ``fig`` to disk, or display it when no filename is given.

    Args:
        fig (matplotlib.figure.Figure): Figure to output.
        filename (str, optional): Destination path. When truthy, the figure is
            written at 200 dpi with a tight bounding box and then closed with
            ``plt.close`` to release it. When ``None`` or empty, ``plt.show()``
            is called instead and the figure is left open.
    """
    if not filename:
        plt.show()
        return
    fig.savefig(filename, dpi=200, bbox_inches='tight')
    plt.close(fig)
47
+
48
def cov_ellipse(mean, cov, ax, n_std=1.0, **kwargs):
    """
    Plot a covariance ellipse using eigendecomposition.

    The ellipse axes are aligned with the eigenvectors of the covariance matrix
    and have semi-axis lengths ``n_std * sqrt(eigenvalue)``.

    Args:
        mean (array-like): Center of the ellipse (mean_x, mean_y).
        cov (array-like): 2x2 symmetric covariance matrix (``np.linalg.eigh``
            only reads one triangle, so symmetry is assumed).
        ax (matplotlib.axes.Axes): Axes on which to draw the ellipse.
        n_std (float, optional): Number of standard deviations for the ellipse
            radius. Default is 1.0.
        **kwargs: Additional keyword arguments passed to
            ``matplotlib.patches.Ellipse``.

    Returns:
        matplotlib.patches.Ellipse: The covariance ellipse added to the axes.

    Raises:
        ValueError: If ``cov`` is not a 2x2 matrix.
    """
    cov = np.asarray(cov, dtype=float)
    if cov.shape != (2, 2):
        raise ValueError(f"cov must be a 2x2 matrix, got shape {cov.shape}")

    # Eigendecomposition: eigenvalues are the variances along the principal axes.
    eigenvalues, eigenvectors = np.linalg.eigh(cov)

    # Largest eigenvalue first so the first eigenvector is the major axis.
    order = np.argsort(eigenvalues)[::-1]
    eigenvalues = eigenvalues[order]
    eigenvectors = eigenvectors[:, order]

    # A covariance obtained by inverting a Fisher matrix can carry tiny negative
    # eigenvalues from round-off; sqrt() of those is NaN and the ellipse would
    # silently vanish. Clamp to zero so a degenerate (flat) ellipse is drawn instead.
    eigenvalues = np.clip(eigenvalues, 0.0, None)

    # Full width/height are twice the semi-axes.
    width, height = 2.0 * n_std * np.sqrt(eigenvalues)

    # Rotation angle of the major axis, in degrees as matplotlib expects.
    angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))

    ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle, **kwargs)
    return ax.add_patch(ellipse)
81
+
82
def overlay_fim_covariance(
    fig,
    covariance,
    means=None,
    nsigmas=(1, 2, 3),
    plot_1d=False,
    colors=None,
    linestyles=None,
    linewidths=None,
    alpha=0.7,
    labels=None,
):
    """
    Overlay Fisher Information Matrix confidence contours on corner plot axes.

    For 2D (lower-triangle) subplots, draws one ellipse per sigma level using
    ``cov_ellipse``. For 1D diagonal plots (only when ``plot_1d`` is True),
    draws vertical lines at ``mean ± n*sigma``.

    The axes of ``fig`` are expected in the layout ``corner.corner`` produces:
    ``K*K`` axes in row-major order, where the panel at row ``r``, column ``c``
    (``r > c``) shows parameter ``c`` on x and parameter ``r`` on y.

    Args:
        fig (matplotlib.figure.Figure): Figure object containing corner plot axes.
        covariance (array-like): Covariance matrix, shape (n_params, n_params).
        means (array-like, optional): Centre for each parameter, shape (n_params,).
            ``None`` entries become NaN. If None, uses the origin (0, 0, ...).
        nsigmas (sequence or float, optional): Sigma levels to draw. Default (1, 2, 3).
        plot_1d (bool, optional): Whether to draw lines on diagonal plots. Default False.
        colors (color or list, optional): One colour for all levels, or one per level.
            If None, uses the colour cycle ``C0, C1, ...``.
        linestyles (str or list, optional): One style for all levels, or one per level.
            If None, uses solid lines.
        linewidths (float or list, optional): One width for all levels, or one per level.
            If None, uses 1.5.
        alpha (float, optional): Transparency of contours. Default 0.7.
        labels (str or list, optional): Legend label per level. If None, uses
            "nσ FIM". Each label is attached only to the first artist drawn for
            its level, so a figure-level legend shows one entry per level.

    Returns:
        matplotlib.figure.Figure: ``fig`` with the overlaid contours.

    Raises:
        ValueError: If ``covariance`` is not square, ``means`` has the wrong shape,
            a per-level option has the wrong length, or ``fig`` has fewer than
            ``n_params**2`` axes.
    """
    covariance = np.asarray(covariance, dtype=float)
    if covariance.ndim != 2 or covariance.shape[0] != covariance.shape[1]:
        raise ValueError(f"Covariance matrix must be square, got shape {covariance.shape}")
    n_params = covariance.shape[0]

    # Default means to the origin; otherwise require exactly one value per parameter.
    if means is None:
        means = np.zeros(n_params)
    else:
        means = np.asarray(means, dtype=float)
        if means.shape != (n_params,):
            raise ValueError(f"means must have shape ({n_params},), got {means.shape}")

    nsigmas = [nsigmas] if np.isscalar(nsigmas) else list(nsigmas)
    n_levels = len(nsigmas)

    def _per_level(value, name):
        # Broadcast a single setting to every level, or validate a per-level sequence.
        if isinstance(value, str) or np.isscalar(value):
            return [value] * n_levels
        value = list(value)
        if len(value) != n_levels:
            raise ValueError(f"{name} must have length {n_levels}, got {len(value)}")
        return value

    if colors is None:
        colors = [f'C{i}' for i in range(n_levels)]
    elif mpl.colors.is_color_like(colors):
        # A single colour, including an RGB(A) tuple, applies to every level.
        colors = [colors] * n_levels
    else:
        colors = _per_level(colors, "colors")

    linestyles = _per_level('-' if linestyles is None else linestyles, "linestyles")
    linewidths = _per_level(1.5 if linewidths is None else linewidths, "linewidths")
    if labels is None:
        labels = [f'{n}$\\sigma$ FIM' for n in nsigmas]
    else:
        labels = _per_level(labels, "labels")

    axes = np.asarray(fig.get_axes(), dtype=object)
    n_axs = int(np.sqrt(axes.size))
    if n_axs < n_params:
        raise ValueError(
            f"fig has {axes.size} axes; at least {n_params ** 2} are needed "
            f"for {n_params} parameters"
        )
    # Axes come back row-major; transposing makes axs_grid[i, j] the panel at
    # (row j, column i), i.e. parameter i on x and parameter j on y.
    axs_grid = axes[: n_axs * n_axs].reshape(n_axs, n_axs).T

    # Marginal standard deviations for the 1D lines.
    sigmas = np.sqrt(np.diag(covariance))

    for n_sigma, color, ls, lw, label in zip(nsigmas, colors, linestyles, linewidths, labels):
        style = dict(linestyle=ls, linewidth=lw, alpha=alpha, zorder=10)
        pending_label = label  # consumed by the first artist drawn for this level
        for i in range(n_params):
            for j in range(i, n_params):
                ax = axs_grid[i, j]
                if i == j:
                    if not plot_1d:
                        continue
                    # 1D diagonal plot: vertical lines at mean ± n*sigma.
                    for sign in (-1, 1):
                        ax.axvline(
                            means[i] + sign * n_sigma * sigmas[i],
                            color=color,
                            label=pending_label,
                            **style,
                        )
                        pending_label = "_nolegend_"
                else:
                    # 2D off-diagonal plot: ellipse from the (i, j) 2x2 sub-covariance.
                    sub_cov = covariance[np.ix_((i, j), (i, j))]
                    cov_ellipse(
                        np.array((means[i], means[j])),
                        sub_cov,
                        ax,
                        n_std=n_sigma,
                        edgecolor=color,
                        facecolor='none',
                        label=pending_label,
                        **style,
                    )
                    pending_label = "_nolegend_"

    return fig
231
+
232
def cornerplot(data, *args, means=None, overlay_covariance=None, legend_label='Samples',
               overlay_label='Information Matrix Covariance', filename=None, **kwargs):
    """
    Create a corner plot with an optional covariance (e.g. Fisher-matrix) overlay.

    Wrapper around ``corner.corner()`` with opinionated defaults. A 2-D ``truths``
    array (one row per truth set, e.g. for reversible-jump runs) is supported:
    ``corner`` accepts a single truth vector only, so those rows are drawn
    manually as lines after the figure is built.

    Args:
        data (array-like): Input data for the corner plot (e.g. MCMC samples).
        *args: Extra positional arguments forwarded to ``corner.corner()`` after ``data``.
        means (array-like, optional): Centre of the overlay contours. If None, a 1-D
            ``truths`` keyword is used; otherwise the sample mean of ``data``.
        overlay_covariance (np.ndarray, optional): Covariance matrix drawn as
            1/2/3-sigma ellipses. If None, no overlay is added.
        legend_label (str, optional): Legend label for the samples. Default 'Samples'.
        overlay_label (str, optional): Legend label for the overlay. Default
            'Information Matrix Covariance', assuming an FIM.
        filename (str, optional): If provided, the figure is saved there and closed;
            otherwise it is shown.
        **kwargs: Keyword arguments for ``corner.corner()``; they override the defaults.

    Returns:
        None. The figure is handed to ``save_or_show``.
    """
    corner_kwargs = {
        # 2D contours enclosing the 1, 2 and 3 sigma probability mass of a Gaussian.
        'levels': (1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-4.5)),
        'show_titles': True,
        'title_fmt': '.1e',
        'title_kwargs': {'fontsize': 12},
        'label_kwargs': {'fontsize': 14},
        'hist_kwargs': {'density': True, 'linewidth': 2},
        'plot_datapoints': False,
        'fill_contours': True,
        'color': 'steelblue',
        'truth_color': 'red',
    }
    corner_kwargs.update(kwargs)

    truths = corner_kwargs.get('truths', None)
    multi_truths = truths is not None and np.ndim(truths) > 1
    if multi_truths:
        truths = np.asarray(truths, dtype=float)
        corner_kwargs['truths'] = None  # corner cannot draw these; added below

    # Pass data positionally so extra *args (e.g. bins) do not collide with it.
    fig = corner.corner(data, *args, **corner_kwargs)

    if multi_truths:
        corner_kwargs['truths'] = truths  # restored so the legend gets a 'Truths' entry
        axes = fig.get_axes()
        n_params = int(np.sqrt(len(axes)))
        axs = np.asarray(axes[: n_params * n_params], dtype=object).reshape((n_params, n_params))
        truth_color = corner_kwargs['truth_color']
        for truth in truths:
            for i in range(n_params):
                axs[i, i].axvline(truth[i], color=truth_color, linestyle='-', linewidth=1)
                # Lower-triangle panel (row i, col j): x is parameter j, y is parameter i.
                for j in range(i):
                    axs[i, j].axhline(truth[i], color=truth_color, linestyle='-', linewidth=1)
                    axs[i, j].axvline(truth[j], color=truth_color, linestyle='-', linewidth=1)

    # Legend proxies.
    handles = [mpl.lines.Line2D([], [], color=corner_kwargs['color'], linestyle='-', linewidth=2)]
    handles_labels = [legend_label]
    if corner_kwargs.get('truths') is not None:
        handles.append(mpl.lines.Line2D([], [], color=corner_kwargs['truth_color'],
                                        linestyle='-', linewidth=1))
        handles_labels.append('Truths')

    if overlay_covariance is not None:
        # Only a single (1-D) truth vector can serve as the contour centre.
        if means is None and truths is not None and not multi_truths:
            means = truths
        if means is None:
            means = np.mean(data, axis=0)

        overlay_fim_covariance(
            fig,
            overlay_covariance,
            means=means,
            plot_1d=False,
            alpha=0.8,
            nsigmas=[1, 2, 3],
            colors='k',
        )

        handles.append(mpl.lines.Line2D([], [], color='k', linestyle='-', linewidth=1.5))
        handles_labels.append(overlay_label)

    # Put the legend in the (empty) second panel of the top row; fall back to
    # the only axes when there is a single parameter.
    axes = fig.get_axes()
    ax_legend = axes[1] if len(axes) > 1 else axes[0]
    ax_legend.legend(handles, handles_labels, loc='upper left', fontsize=18)

    save_or_show(fig, filename)
331
+
332
+
333
def traceplot(chain, labels=None, truths=None, filename=None):
    """
    Create trace plots for MCMC chains, one row per parameter.

    Args:
        chain (np.ndarray): MCMC chain of shape (nsteps, nwalkers, nleaves, ndim).
        labels (list, optional): Parameter names used as y-axis labels.
        truths (array-like, optional): True values, shape (ndim,) or (ntruths, ndim);
            each row is drawn as dashed horizontal lines.
        filename (str, optional): If provided, saves the figure to this filename;
            otherwise it is shown.

    Returns:
        None. The figure is handed to ``save_or_show``.
    """
    chain = np.asarray(chain)
    _, nwalkers, _, ndim = chain.shape

    # squeeze=False keeps an array of Axes even when ndim == 1, so indexing always works.
    fig, axs = plt.subplots(ndim, 1, figsize=(10, 2.5 * ndim), sharex=True, squeeze=False)
    axs = axs[:, 0]

    if truths is not None:
        truths = np.atleast_2d(truths)

    for i, ax in enumerate(axs):
        # One call per walker draws all of its leaves for parameter i.
        for w in range(nwalkers):
            ax.plot(chain[:, w, :, i], alpha=0.5, rasterized=True)
        if truths is not None:
            for truth_row in truths:
                ax.axhline(truth_row[i], color='k', linestyle='--')
        if labels is not None:
            ax.set_ylabel(labels[i])

    axs[-1].set_xlabel('Step')
    fig.tight_layout()

    save_or_show(fig, filename)
364
+
365
+
366
def plot_loglikelihood(logl, filename=None):
    """
    Plot per-walker log-likelihood traces, plus a per-walker facet grid.

    Two figures are produced:

    1. Every walker's log-likelihood versus sampler iteration on one axes.
    2. A seaborn ``FacetGrid`` with one panel per walker showing
       ``logl - max(logl)``, the maximum being taken over all steps and walkers.

    Args:
        logl (np.ndarray): Log-likelihood values of shape (nsteps, nwalkers).
        filename (str, optional): If provided, the first figure is saved to
            ``filename`` and the facet grid to the same path with ``_facet``
            inserted before the extension (``out.png`` -> ``out_facet.png``).
            Otherwise both figures are shown.

    Returns:
        None
    """
    logl = np.asarray(logl)
    nsteps, nwalkers = logl.shape

    fig = plt.figure(figsize=(10, 6))
    for j in range(nwalkers):
        plt.plot(logl[:, j], color=f"C{j % 10}", alpha=0.8, rasterized=True)
    plt.xlabel("Sampler Iteration")
    plt.ylabel("Log-Likelihood")

    save_or_show(fig, filename)

    delta_col = r"$\Delta \log\mathcal{L}$"
    facet_logl = logl - np.max(logl)

    # Long-format table, walker by walker: the step column repeats 0..nsteps-1
    # for each walker, so the values must be flattened in Fortran order
    # (steps vary fastest) to line up with the step/walker columns.
    step = np.tile(np.arange(nsteps), nwalkers)
    walker = np.repeat(np.arange(nwalkers, dtype=int), nsteps)
    df = pd.DataFrame(np.c_[facet_logl.ravel(order="F"), step, walker],
                      columns=[delta_col, "step", "walker"])

    # One panel per walker.
    grid = sns.FacetGrid(df, col="walker", hue="walker",
                         col_wrap=int(np.floor(np.sqrt(nwalkers))), height=1.5)
    grid.map(plt.plot, "step", delta_col, marker=".", rasterized=True)

    # y = 0 marks the global maximum log-likelihood (over all steps and walkers).
    grid.refline(y=0, linestyle=":")

    grid.set_titles(col_template="Walker {col_name:.0f}")
    grid.set_axis_labels("Step", delta_col)
    # Turn FacetGrid.tight_layout into a no-op to avoid tight-layout warnings
    # and keep the manual top margin set below.
    grid.tight_layout = lambda *args, **kwargs: None

    grid.figure.subplots_adjust(top=0.9)
    grid.figure.suptitle(r"$\Delta \log\mathcal{L}_w = \log\mathcal{L}_w - \max(\log\mathcal{L})$", fontsize=16)

    facet_filename = None
    if filename:
        root, ext = os.path.splitext(filename)
        facet_filename = f"{root}_facet{ext}"
    save_or_show(grid.figure, facet_filename)
408
+
409
def tempering_ridgeplot(chain, labels=None, palette=None,
                        bw_adjust=0.5, aspect=5, height=0.5, hspace=-0.25,
                        max_samples=10000, filename=None):
    """
    Create ridge plots of tempered distributions using overlapping KDE plots for all parameters.

    Each temperature is a row and each parameter a column of one seaborn
    ``FacetGrid``. Rows overlap vertically (``hspace < 0``) to form a "joy plot"
    showing how the distribution broadens as the temperature rises. The x-range
    of each column is derived from the cold chain (temperature index 0), so the
    cold posterior stays visible even when hot distributions are much wider.

    Args:
        chain (np.ndarray): MCMC chain of shape (nsteps, ntemps, nwalkers, nleaves, ndim).
            Non-finite entries (e.g. NaN padding) are ignored.
        labels (list, optional): Parameter names, one per ``ndim``; used as x-axis
            labels. Defaults to ``$x_{i}$``.
        palette (str or list, optional): Seaborn palette name or explicit list of
            colours (one per temperature). Defaults to ``DEFAULT_PALETTE``.
        bw_adjust (float, optional): Bandwidth adjustment factor for KDE. Default is 0.5.
        aspect (float, optional): Aspect ratio of each facet. Default is 5.
        height (float, optional): Height of each temperature row in inches. Default is 0.5.
        hspace (float, optional): Vertical spacing between temperature rows
            (negative for overlap). Default is -0.25.
        max_samples (int, optional): Maximum number of samples per (temperature,
            parameter) used for the KDE; larger sets are randomly subsampled with a
            fixed seed. Default is 10000.
        filename (str, optional): If provided, saves the figure to this filename;
            otherwise it is shown.

    Returns:
        None. The figure is handed to ``save_or_show``.
    """
    import warnings

    # Transparent axes so overlapping rows do not hide each other; the context
    # manager restores the previous seaborn style on exit.
    with sns.axes_style("white", {"axes.facecolor": (0, 0, 0, 0)}):
        _, ntemps, _, _, ndim = chain.shape

        if palette is None:
            pal = sns.color_palette(DEFAULT_PALETTE, ntemps)
        elif isinstance(palette, str):
            pal = sns.color_palette(palette, ntemps)
        else:
            pal = palette

        rng = np.random.default_rng(42)  # fixed seed -> reproducible subsampling

        # Long-format table of samples for every (parameter, temperature) pair.
        frames = []
        for param_idx in range(ndim):
            # Braces keep multi-digit indices inside the mathtext subscript.
            param_label = labels[param_idx] if labels is not None else rf'$x_{{{param_idx}}}$'
            for t in range(ntemps):
                samples = chain[:, t, :, :, param_idx].reshape(-1)
                samples = samples[np.isfinite(samples)]
                if samples.size == 0:
                    continue
                if samples.size > max_samples:
                    samples = rng.choice(samples, size=max_samples, replace=False)
                frames.append(pd.DataFrame({
                    'x': samples,
                    'temp': rf"$\beta_{{{t}}}$",
                    'temp_idx': t,
                    'param': param_label,
                    'param_idx': param_idx,
                }))

        df = pd.concat(frames, ignore_index=True)

        # Row order: coldest (index 0) at the top; column order: parameter index.
        temp_order = df.drop_duplicates('temp_idx').sort_values('temp_idx', ascending=True)['temp'].tolist()
        param_order = df.drop_duplicates('param_idx').sort_values('param_idx', ascending=True)['param'].tolist()

        # Per-parameter x-limits from the cold chain's 0.1%-99.9% range plus a 30% margin.
        xlims = {}
        for param in param_order:
            param_data = df.loc[df['param'] == param]
            ref = param_data.loc[param_data['temp_idx'] == 0, 'x']
            if ref.empty:
                # No finite cold-chain samples for this parameter: use all
                # temperatures instead of producing NaN limits.
                ref = param_data['x']
            q_low, q_high = ref.quantile([0.001, 0.999])
            margin = (q_high - q_low) * 0.3
            xlims[param] = (q_low - margin, q_high + margin)

        # Suppress tight-layout warnings emitted while the grid is built.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*Tight layout.*")
            g = sns.FacetGrid(df, row="temp", col="param", hue="temp",
                              aspect=aspect, height=height,
                              palette=pal, row_order=temp_order, col_order=param_order,
                              sharex=False, sharey=False)

        # tight_layout does not cope with negative hspace; make it a no-op.
        g.tight_layout = lambda *args, **kwargs: None

        def plot_kde_clipped(x, color, label, **kwargs):
            # Filled KDE plus a white outline, clipped to this column's x-limits.
            ax = plt.gca()
            clip = xlims[param_order[ax.get_subplotspec().colspan.start]]
            sns.kdeplot(x, ax=ax, bw_adjust=bw_adjust, clip_on=False,
                        fill=True, alpha=1, linewidth=1.5, color=color, clip=clip)
            sns.kdeplot(x, ax=ax, clip_on=False, color="w", lw=2, bw_adjust=bw_adjust, clip=clip)
            ax.set_xlim(clip)

        g.map(plot_kde_clipped, "x")

        # Baseline at y=0 for every ridge.
        g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)

        def label_temp(x, color, label):
            # Temperature label, drawn only in the first column.
            ax = plt.gca()
            if ax.get_subplotspec().colspan.start == 0:
                ax.text(-0.2, 0.2, label, fontweight="bold", color=color,
                        ha="left", va="center", transform=ax.transAxes, fontsize=14)

        g.map(label_temp, "x")

        # Overlap the rows.
        g.figure.subplots_adjust(hspace=hspace, wspace=0.1)

        # Parameter names as x-labels on the bottom row.
        for ax, param in zip(g.axes[-1], param_order):
            ax.set_xlabel(param, fontsize=14)

        # Hide x ticks on all but the bottom row and re-apply the x-limits on every
        # facet (including ones g.map skipped because they had no data).
        n_rows = len(temp_order)
        for row_idx in range(n_rows):
            for col_idx, param in enumerate(param_order):
                ax = g.axes[row_idx, col_idx]
                if row_idx < n_rows - 1:
                    ax.set_xticks([])
                ax.set_xlim(xlims[param])

        # Remove axes details that don't work well with overlap.
        g.set_titles("")
        g.set(yticks=[], ylabel="")
        g.despine(bottom=True, left=True)

        g.figure.suptitle('Tempered Distributions', y=1.02, fontsize=14)

        save_or_show(g.figure, filename)
567
+
568
def plot_swap_acceptance(swap_acceptance_fraction, palette=None, filename=None):
    """
    Plot the temperature swap acceptance fraction between adjacent temperature levels.

    Args:
        swap_acceptance_fraction (array-like): Swap acceptance fraction between
            adjacent temperatures, shape (ntemps-1,). Element i corresponds to
            swaps between temperature i and i+1.
        palette (str or list, optional): Seaborn color palette name or list of
            colors. Defaults to ``DEFAULT_PALETTE``.
        filename (str, optional): If provided, saves the figure to this filename;
            otherwise it is shown.

    Returns:
        matplotlib.figure.Figure: The swap acceptance plot figure (already closed
        when ``filename`` is given).
    """
    # Accept lists/tuples as well as arrays.
    swap_acceptance_fraction = np.asarray(swap_acceptance_fraction, dtype=float)
    n_pairs = swap_acceptance_fraction.shape[0]

    fig, ax = plt.subplots(figsize=(8, 5))

    # One bar per adjacent temperature pair (i, i+1).
    x = np.arange(n_pairs)
    labels = [fr'{i}$\leftrightarrow${i+1}' for i in range(n_pairs)]

    colors = sns.color_palette(palette if palette is not None else DEFAULT_PALETTE, n_pairs)
    bars = ax.bar(x, swap_acceptance_fraction, color=colors, edgecolor='white', linewidth=0.5)

    # Reference line at a common target acceptance (0.2-0.4 is often good).
    ax.axhline(y=0.25, color='gray', linestyle='--', linewidth=1, alpha=0.7, label='0.25')

    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=10, rotation=45, ha='right')
    ax.set_xlabel('Temperature Pair (index)', fontsize=12)
    ax.set_ylabel('Swap Acceptance Fraction', fontsize=12)
    ax.set_title('Temperature Swap Acceptance', fontsize=14)
    ax.set_ylim(0, 1.1)
    ax.legend(loc='upper left', fontsize=10)

    # Value annotation just above each bar.
    for bar, val in zip(bars, swap_acceptance_fraction):
        ax.annotate(f'{val:.2f}',
                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=9)

    fig.tight_layout()

    save_or_show(fig, filename)

    return fig
632
+
633
def plot_logl_betas(betas: np.ndarray,
                    logl: np.ndarray,
                    palette: str = None,
                    filename: str = None
                    ):
    """
    Plot the mean log-likelihood at each temperature against its final beta.

    The title reports the stepping-stone evidence estimate computed from the
    final temperature ladder (``betas[-1]``) and ``logl``.

    Args:
        betas (numpy.ndarray): Inverse temperatures of shape (nsteps, ntemps);
            only the last row is used.
        logl (numpy.ndarray): Log-likelihood values indexed as ``logl[:, temp]``
            (presumably (nsteps, ntemps, nwalkers); confirm against
            ``stepping_stone_log_evidence``). Each temperature's points are averaged.
        palette (str, optional): Seaborn color palette name or list of colors.
            Defaults to ``DEFAULT_PALETTE``.
        filename (str, optional): If provided, saves the figure to this filename;
            otherwise it is shown.

    Returns:
        None
    """
    betas = np.asarray(betas)
    logl = np.asarray(logl)
    final_betas = betas[-1]

    fig = plt.figure(figsize=(10, 6))
    ntemp = betas.shape[1]
    tempcolors = sns.color_palette(palette if palette is not None else DEFAULT_PALETTE, ntemp)
    for temp in range(ntemp):
        # Braces keep multi-digit indices inside the mathtext subscript (T_{10}).
        plt.semilogx(final_betas[temp], np.mean(logl[:, temp]), '.', c=tempcolors[temp],
                     label=rf'$T_{{{temp}}}$')

    logZ, dlogZ = stepping_stone_log_evidence(final_betas, logl)

    plt.ylabel(r'$<\log{\mathcal{L}}>_{\beta}$')
    plt.xlabel(r'$\beta$')
    plt.title(r'$\log{\mathcal{Z}} = %.2f \pm %.2f$' % (logZ, dlogZ))

    save_or_show(fig, filename)
662
+
663
def plot_betas_evolution(betas: np.ndarray, palette: str = None, filename: str = None):
    """
    Plot the trajectory of every inverse temperature (beta) across the run.

    A small colour bar in the upper-right corner maps colours to temperatures,
    with ``T_max`` on the left and ``T_0`` on the right.

    Args:
        betas (numpy.ndarray): Inverse temperatures of shape (nsteps, ntemps).
        palette (str, optional): Seaborn palette name or list of colours, one per
            temperature. Defaults to ``DEFAULT_PALETTE``.
        filename (str, optional): Output path; saved if given, otherwise shown.

    Returns:
        None
    """
    n_iter, n_temps = betas.shape
    colours = sns.color_palette(DEFAULT_PALETTE if palette is None else palette, n_temps)
    iterations = range(n_iter)

    fig = plt.figure(figsize=(10, 6))
    for idx in range(n_temps):
        plt.plot(iterations, betas[:, idx], color=colours[idx], linewidth=1.5, alpha=0.8, rasterized=True)

    plt.xlabel('Sampler Iteration')
    plt.ylabel(r'Inverse Temperature ($\beta$)')
    plt.title('Evolution of Inverse Temperatures')

    # Colour-bar legend in axes coordinates.
    ax = plt.gca()
    left, bottom = 0.75, 0.9
    total_w, bar_h = 0.15, 0.03

    # Swatches run from the last palette colour (hottest, T_max) on the left
    # to the first (T_0) on the right.
    for k, colour in enumerate(reversed(colours)):
        ax.add_patch(Rectangle(
            (left + k * total_w / n_temps, bottom),
            total_w / n_temps, bar_h,
            transform=ax.transAxes,
            facecolor=to_rgba(colour, 0.7),
            edgecolor='none'
        ))

    # Thin outline around the whole bar.
    ax.add_patch(Rectangle(
        (left, bottom), total_w, bar_h,
        transform=ax.transAxes,
        facecolor='none',
        edgecolor='black',
        linewidth=0.5
    ))

    mid_y = bottom + bar_h / 2
    ax.text(left - 0.01, mid_y, r'$T_{\rm max}$',
            transform=ax.transAxes, ha='right', va='center', fontsize=11, fontweight='normal', antialiased=True)
    ax.text(left + total_w + 0.01, mid_y, r'$T_0$',
            transform=ax.transAxes, ha='left', va='center', fontsize=11, fontweight='normal', antialiased=True)

    save_or_show(fig, filename)
719
+
720
+
721
+ # RJ plots
722
def plot_leaves(nleaves: np.ndarray,
                nleaves_min: int,
                nleaves_max: int,
                palette: str = None,
                iteration: int = 0,
                filename: str = None):
    """
    Histogram the number of leaves at every temperature.

    One step-filled, density-normalised histogram is drawn per temperature.
    Bins are centred on the integers ``nleaves_min`` .. ``nleaves_max``, and
    lower temperature indices are drawn on top. A small colour bar in the upper
    right maps colours to temperatures (``T_max`` left, ``T_0`` right), and the
    step number is written in the lower-left corner of the figure.

    Args:
        nleaves (np.ndarray): Leaf counts indexed as ``nleaves[:, temp]``; axis 1
            is the temperature axis (presumably (nsteps, ntemps, nwalkers) —
            confirm against the caller).
        nleaves_min (int): Smallest leaf count to bin.
        nleaves_max (int): Largest leaf count to bin.
        palette (str, optional): Seaborn palette name or list of colours.
            Defaults to ``DEFAULT_PALETTE``.
        iteration (int, optional): Step number shown in the figure annotation.
        filename (str, optional): Output path; saved if given, otherwise shown.

    Returns:
        None
    """
    edges = np.arange(nleaves_min, nleaves_max + 2) - 0.5
    n_temps = nleaves.shape[1]
    colours = sns.color_palette(DEFAULT_PALETTE if palette is None else palette, n_temps)

    fig = plt.figure(figsize=(8, 5))

    for t in range(len(colours)):
        c = colours[t]
        plt.hist(nleaves[:, t].flatten(), bins=edges, histtype="stepfilled", edgecolor=c,
                 facecolor=to_rgba(c, 0.2), density=True, ls='-', zorder=100 - t, rasterized=True)

    plt.xlabel('Number of leaves')
    plt.ylabel('Density')

    # Colour-bar legend in axes coordinates.
    ax = plt.gca()
    left, bottom = 0.75, 0.9
    total_w, bar_h = 0.15, 0.03
    n_cells = len(colours)

    # Swatches run from the hottest colour (T_max) on the left to T_0 on the right.
    for k, c in enumerate(colours[::-1]):
        ax.add_patch(Rectangle(
            (left + k * total_w / n_cells, bottom),
            total_w / n_cells, bar_h,
            transform=ax.transAxes,
            facecolor=to_rgba(c, 0.7),
            edgecolor='none'
        ))

    # Thin outline around the whole bar.
    ax.add_patch(Rectangle(
        (left, bottom), total_w, bar_h,
        transform=ax.transAxes,
        facecolor='none',
        edgecolor='black',
        linewidth=0.5
    ))

    mid_y = bottom + bar_h / 2
    ax.text(left - 0.01, mid_y, r'$T_{\rm max}$',
            transform=ax.transAxes, ha='right', va='center', fontsize=11, fontweight='normal', antialiased=True)
    ax.text(left + total_w + 0.01, mid_y, r'$T_0$',
            transform=ax.transAxes, ha='left', va='center', fontsize=11, fontweight='normal', antialiased=True)

    fig.text(0.07, 0.08, f"Step: {iteration}", ha='left', va='top', fontfamily='serif', c='k')
    save_or_show(fig, filename)
793
+
794
def plot_leaves_evolution(nleaves: np.ndarray,
                          filename: str = None):
    """
    Plot each cold-chain walker's number of leaves against sampler iteration.

    Args:
        nleaves (np.ndarray): Leaf counts in the cold chain, shape (nsteps, nwalkers).
        filename (str, optional): Output path; saved if given, otherwise shown.

    Returns:
        None
    """
    n_iter, n_walkers = nleaves.shape
    iterations = range(n_iter)

    fig = plt.figure(figsize=(10, 6))
    for walker_idx in range(n_walkers):
        # Colours cycle through the default C0..C9 sequence.
        plt.plot(iterations, nleaves[:, walker_idx], color=f"C{walker_idx % 10}",
                 linewidth=1.5, alpha=0.8, rasterized=True)

    plt.xlabel('Sampler Iteration')
    plt.ylabel('Number of Leaves')
    plt.title('Evolution of Number of Leaves in Cold Chain')
    save_or_show(fig, filename)
812
+
813
def plot_acceptance_fraction(steps: typing.Union[np.ndarray, list],
                             total_acceptance_fraction: np.ndarray,
                             moves_acceptance_fraction: dict,
                             filename: str = None):
    """
    Plot the cold-chain acceptance fraction over sampling steps, in total and per move.

    Every curve is the walker-averaged acceptance fraction of temperature index 0
    (the cold chain). A dashed horizontal reference line is drawn at 0.234.

    Args:
        steps (np.ndarray or list): Sampler iterations at which the acceptance
            fractions were recorded, length ``nrecords``.
        total_acceptance_fraction (np.ndarray): Total acceptance fractions with
            shape (nrecords, ntemps, nwalkers).
        moves_acceptance_fraction (dict): Mapping from move name to an array of
            shape (nrecords_move, ntemps, nwalkers). May be empty or None. When a
            move has a different number of records than ``steps`` (e.g. it was only
            tracked on some calls), the most recent records are aligned with the
            most recent steps.
        filename (str, optional): If provided, saves the figure to this filename.

    Returns:
        None
    """
    step_axis = np.asarray(steps)

    fig = plt.figure(figsize=(10, 6))
    # cold chain total acceptance fraction, averaged over walkers
    plt.plot(step_axis, total_acceptance_fraction[:, 0].mean(axis=1), label='Total', color='black', linewidth=2)

    for move, acc_fraction in (moves_acceptance_fraction or {}).items():
        acc_fraction = np.asarray(acc_fraction)
        # tail-align so x and y always have the same length
        n_common = min(len(step_axis), len(acc_fraction))
        if n_common == 0:
            continue
        plt.plot(step_axis[len(step_axis) - n_common:],
                 acc_fraction[len(acc_fraction) - n_common:, 0].mean(axis=1),
                 marker='o', label=move)

    plt.axhline(y=0.234, color='gray', linestyle='--', linewidth=1, alpha=0.7, label='0.234')
    plt.legend()
    plt.xlabel('Sampler Iteration')
    plt.ylabel('Acceptance Fraction')
    plt.title('Acceptance Fraction Over Time')

    save_or_show(fig, filename)
840
+
841
def plot_tempered_acceptance_fraction(steps: typing.Union[np.ndarray, list],
                                      total_acceptance_fraction: np.ndarray,
                                      palette: str = None,
                                      filename: str = None):
    """
    Plot the walker-averaged total acceptance fraction of every temperature over sampling steps.

    One line is drawn per temperature, colored with ``palette``. Instead of a regular
    legend, a small color-gradient bar in the upper right of the axes runs from the
    hottest temperature (``T_max``, left) to temperature index 0 (``T_0``, right).

    Args:
        steps (np.ndarray or list): Array of sampling steps.
        total_acceptance_fraction (np.ndarray): Array of total acceptance fractions, shape (nsteps, ntemps, nwalkers).
        palette (str or list, optional): Seaborn color palette name or list of colors. Defaults to ``DEFAULT_PALETTE``.
        filename (str, optional): If provided, saves the figure to this filename.

    Returns:
        None
    """
    ntemps = total_acceptance_fraction.shape[1]
    tempcolors = sns.color_palette(palette if palette is not None else DEFAULT_PALETTE, ntemps)

    fig = plt.figure(figsize=(10, 6))

    for temp in range(ntemps):
        # average over walkers -> one value per recorded step
        plt.plot(steps, total_acceptance_fraction[:, temp].mean(axis=1), color=tempcolors[temp], linewidth=1.5, marker='o', alpha=0.8, rasterized=True)

    ax = plt.gca()
    # gradient bar geometry, in axes coordinates
    legend_width = 0.15
    legend_height = 0.03
    legend_x = 0.75
    legend_y = 0.9

    # Reversed palette: hottest temperature on the left, temperature 0 on the right
    for i, color in enumerate(tempcolors[::-1]):
        rect = Rectangle(
            (legend_x + i * legend_width / ntemps, legend_y),
            legend_width / ntemps, legend_height,
            transform=ax.transAxes,
            facecolor=to_rgba(color, 0.7),
            edgecolor='none'
        )
        ax.add_patch(rect)

    # Add border and labels
    border = Rectangle(
        (legend_x, legend_y), legend_width, legend_height,
        transform=ax.transAxes,
        facecolor='none',
        edgecolor='black',
        linewidth=0.5
    )
    ax.add_patch(border)

    # Add text labels
    ax.text(legend_x - 0.01, legend_y + legend_height / 2, r'$T_{\rm max}$',
            transform=ax.transAxes, ha='right', va='center', fontsize=11, fontweight='normal', antialiased=True)
    ax.text(legend_x + legend_width + 0.01, legend_y + legend_height / 2, r'$T_0$',
            transform=ax.transAxes, ha='left', va='center', fontsize=11, fontweight='normal', antialiased=True)

    # No plt.legend(): the lines carry no labels, so matplotlib would only warn
    # about an empty legend. The gradient bar above plays that role.
    plt.xlabel('Sampler Iteration')
    plt.ylabel('Acceptance Fraction')

    # extra headroom, presumably so the gradient bar does not cover the curves
    ymin, ymax = plt.ylim()
    plt.ylim(ymin, 1.2 * ymax)

    plt.title('Acceptance Fraction Over Time')

    save_or_show(fig, filename)
905
+
906
def plot_act_evolution(chain: dict,
                       iteration: int = 0,
                       parent_folder: str = '.'):
    """
    Plot the evolution of the integrated autocorrelation time (ACT) of the cold chain.

    For every branch the ACT is estimated with ``get_integrated_act`` on growing
    prefixes of the cold chain. The prefix lengths are ``NPOINTS`` log-spaced step
    counts between ``min(100, iteration)`` and ``iteration``.

    Outputs:

    * a per-parameter figure, saved to ``<parent_folder>/<branch>/act_evolution.png``;
    * a summary figure of the maximum ACT per branch, saved to
      ``<parent_folder>/max_act_evolution.png``.

    Both figures include the reference line tau = Nsteps / 50.

    A branch is skipped, with a printed message, if either:

    * its cold chain contains NaNs, or
    * it is shorter than the first evaluation point.

    Args:
        chain (Dict): Mapping from branch name to samples of shape
            (nsteps, ntemps, nwalkers, nleaves, ndim).
        iteration (int, optional): Current iteration number; sets the range of the
            evaluation points. Nothing is plotted when it is smaller than 1. Default is 0.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.

    Returns:
        None
    """
    if iteration < 1:
        # the log-spaced evaluation points need a positive upper bound
        return

    NPOINTS = 10
    points = np.exp(np.linspace(np.log(min(100, iteration)), np.log(iteration), NPOINTS)).astype(int)
    # x-values of the tau = Nsteps / 50 reference line, shared by all figures
    xaxis = np.logspace(0, np.log10(iteration), 100)

    taus = {}
    for branch, samples in chain.items():
        # create branch folder
        branch_folder = os.path.join(parent_folder, branch)
        os.makedirs(branch_folder, exist_ok=True)

        nsteps, ntemps, nwalkers, nleaves, ndim = samples.shape
        # cold chain only, with leaves flattened into the parameter axis
        cold_chain = samples[:, 0, :, :, :].reshape(nsteps, 1, nwalkers, nleaves * ndim)
        if np.isnan(cold_chain).any():
            print(f"Skipping ACT plot for branch {branch} due to NaNs in the cold chain.")
            continue

        tmp = []
        for point in points:
            if point > cold_chain.shape[0]:
                continue
            tau = get_integrated_act(cold_chain[:point], average=True)
            # keep a parameter axis even when there is a single parameter
            tmp.append(np.atleast_1d(np.squeeze(tau)))

        if not tmp:
            print(f"Skipping ACT plot for branch {branch}: chain is shorter than {points[0]} steps.")
            continue

        taus[branch] = np.array(tmp)
        # points is non-decreasing, so the evaluated points are a prefix of it
        used_points = points[:len(taus[branch])]

        fig = plt.figure(figsize=(10, 6))
        # NOTE(review): only the first `ndim` columns are drawn; with nleaves > 1 the
        # remaining flattened leaf parameters are not shown - confirm this is intended.
        for d in range(ndim):
            plt.loglog(used_points, taus[branch][:, d], marker='o', label=fr'$x_{{{d}}}$')

        # add tau = iteration / 50 line for reference
        plt.loglog(xaxis, xaxis / 50, linestyle='--', color='k', lw=2, label=r'$\tau$ = Nsteps / 50')
        plt.xlabel('Number of Steps')
        plt.ylabel('Integrated Autocorrelation Time')
        plt.title(f'Autocorrelation Time Evolution - Branch: {branch}')

        ymax = np.max(taus[branch]) * 1.2
        ymin = np.min(taus[branch]) * 0.5
        if ymin < ymax:
            plt.ylim(ymin, ymax)
        else:
            plt.ylim(0, ymax)
        plt.xlim(min(points) * 0.9, iteration * 1.1)
        plt.legend()
        save_or_show(fig, os.path.join(branch_folder, 'act_evolution.png'))

    # plot the maximum ACT across all parameters for each branch
    if len(taus) > 0:
        fig = plt.figure(figsize=(10, 6))
        for branch in taus.keys():
            max_tau = np.max(taus[branch], axis=1)
            plt.loglog(points[:len(max_tau)], max_tau, marker='o', label=f'Branch: {branch}')

        plt.loglog(xaxis, xaxis / 50, linestyle='--', color='k', lw=2, label=r'$\tau$ = Nsteps / 50')
        plt.xlabel('Number of Steps')
        plt.ylabel('Maximum Integrated Autocorrelation Time')
        plt.title('Maximum Autocorrelation Time Evolution Across Branches')
        ymax = max([np.max(np.max(taus[branch], axis=1)) for branch in taus.keys()]) * 1.2
        ymin = min([np.min(np.max(taus[branch], axis=1)) for branch in taus.keys()]) * 0.5
        if ymin < ymax:
            plt.ylim(ymin, ymax)
        else:
            plt.ylim(0, ymax)
        plt.xlim(min(points) * 0.9, iteration * 1.1)
        plt.legend()
        save_or_show(fig, os.path.join(parent_folder, 'max_act_evolution.png'))
987
+
988
+
989
def produce_base_plots(chain: dict,
                       logl: np.ndarray,
                       truths: dict = None,
                       overlay_covariance: dict = None,
                       labels: dict = None,
                       iteration: int = 0,
                       parent_folder: str = '.',
                       ):
    """
    Write the standard diagnostic figures for a run.

    For every branch in ``chain`` a sub-folder ``<parent_folder>/<branch>`` is created.
    It holds a corner plot (NaN rows removed) and a trace plot of the cold chain
    (temperature index 0). A single log-likelihood plot of the cold chain is written
    to ``<parent_folder>/loglikelihood.png``.

    Args:
        chain (Dict): Mapping from branch name to samples of shape
            (nsteps, ntemps, nwalkers, nleaves, ndim).
        logl (np.ndarray): Log-likelihood array of shape (nsteps, ntemperatures, nwalkers).
        truths (Dict, optional): Per-branch true parameter values; also passed to
            ``cornerplot`` as ``means``.
        overlay_covariance (Dict, optional): Per-branch covariance matrices to overlay on corner plots.
        labels (Dict, optional): Per-branch parameter labels.
        iteration (int, optional): Current iteration, shown in the corner-plot legend.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
    """
    legend_label = 'Samples at step %d' % iteration

    def _per_branch(mapping, key):
        # tolerate a missing (None / empty) mapping as well as a missing key
        return mapping.get(key, None) if mapping else None

    # one shade of the 'Blues' palette per branch
    branch_names = list(chain.keys())
    shades = sns.color_palette('Blues', n_colors=len(branch_names))
    shade_for = dict(zip(branch_names, shades))

    for branch, samples in chain.items():
        branch_labels = _per_branch(labels, branch)
        branch_truths = _per_branch(truths, branch)
        branch_cov = _per_branch(overlay_covariance, branch)

        out_dir = os.path.join(parent_folder, branch)
        os.makedirs(out_dir, exist_ok=True)

        # unpacking also enforces the 5-D layout
        _, _, _, _, ndim = samples.shape
        cold_samples = samples[:, 0, :, :, :]
        flat_cold = cold_samples.reshape(-1, ndim)
        # drop any sample row that contains a NaN
        flat_cold = flat_cold[~np.isnan(flat_cold).any(axis=1)]

        if branch_cov is not None:
            overlay_label = 'Information Matrix Covariance'
        else:
            overlay_label = None

        cornerplot(
            flat_cold,
            means=branch_truths,
            overlay_covariance=branch_cov,
            legend_label=legend_label,
            truths=branch_truths,
            overlay_label=overlay_label,
            labels=branch_labels,
            color=shade_for[branch],
            filename=os.path.join(out_dir, 'cornerplot.png'),
        )

        traceplot(
            cold_samples,
            labels=branch_labels,
            truths=branch_truths,
            filename=os.path.join(out_dir, 'traceplot.png'),
        )

    plot_loglikelihood(
        logl[:, 0, :],
        filename=os.path.join(parent_folder, 'loglikelihood.png'),
    )
1059
+
1060
def produce_tempering_plots(chain: dict,
                            betas: np.ndarray,
                            logl: np.ndarray,
                            swap_acceptance_fraction: np.ndarray,
                            labels: dict = None,
                            parent_folder: str = '.',
                            palette: str = None
                            ):
    """
    Create the parallel-tempering diagnostic figures.

    Written files:

    * ``<parent_folder>/<branch>/tempering_ridgeplot.png``: per-branch ridge plots of
      the tempered distributions of each parameter;
    * ``<parent_folder>/swap_acceptance.png``: swap acceptance fraction between
      adjacent temperatures;
    * ``<parent_folder>/logl_betas.png``: averaged log-likelihood versus betas;
    * ``<parent_folder>/betas_evolution.png``: evolution of betas over sampling steps.

    Args:
        chain (Dict): Mapping from branch name to MCMC samples.
        betas (np.ndarray): Inverse temperatures of shape (nsteps, ntemps).
        logl (np.ndarray): Log-likelihood array of shape (nsteps, ntemperatures, nwalkers).
        swap_acceptance_fraction (np.ndarray): Swap acceptance fraction between adjacent temperatures.
        labels (Dict, optional): Per-branch parameter labels.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
        palette (str or list, optional): Seaborn color palette name or list of colors.
    """
    for branch, samples in chain.items():
        if labels:
            branch_labels = labels.get(branch, None)
        else:
            branch_labels = None

        out_dir = os.path.join(parent_folder, branch)
        os.makedirs(out_dir, exist_ok=True)

        ridge_file = os.path.join(out_dir, 'tempering_ridgeplot.png')
        tempering_ridgeplot(samples, labels=branch_labels, palette=palette, filename=ridge_file)

    swap_file = os.path.join(parent_folder, 'swap_acceptance.png')
    plot_swap_acceptance(swap_acceptance_fraction, palette=palette, filename=swap_file)

    logl_betas_file = os.path.join(parent_folder, 'logl_betas.png')
    plot_logl_betas(betas, logl, palette=palette, filename=logl_betas_file)

    betas_file = os.path.join(parent_folder, 'betas_evolution.png')
    plot_betas_evolution(betas, palette=palette, filename=betas_file)
1118
+
1119
def produce_advanced_plots(steps: typing.Union[np.ndarray, list],
                           total_acceptance_fraction: np.ndarray,
                           moves_acceptance_fraction: dict,
                           palette: str = None,
                           iteration: int = 0,
                           chain: dict = None,
                           parent_folder: str = '.'):
    """
    Produce advanced diagnostic plots. These include:

    * the acceptance fraction evolution over steps in the cold chain (both overall and per move),
    * the overall acceptance fraction evolution over steps per temperature,
    * autocorrelation time evolution per parameter per branch in the cold chain,
    * the comparison of the maximum autocorrelation time in each branch against the number of steps.

    The autocorrelation-time plots need the samples and are skipped when ``chain`` is None.

    Args:
        steps (Union[np.ndarray, list]): Array or list of sampling steps.
        total_acceptance_fraction (np.ndarray): Total acceptance fraction array of shape (nsteps, ntemps, nwalkers).
        moves_acceptance_fraction (Dict): Dictionary of acceptance fractions for different moves.
        palette (str or list, optional): Seaborn color palette name or list of colors for the per-temperature plot.
        iteration (int, optional): Current iteration number, used for the autocorrelation-time plots.
        chain (Dict, optional): Dictionary of MCMC chains for different branches. If None, the
            autocorrelation-time plots are not produced.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.

    Returns:
        None
    """

    plot_acceptance_fraction(
        steps,
        total_acceptance_fraction,
        moves_acceptance_fraction,
        filename=os.path.join(parent_folder, 'acceptance_fraction.png')
    )

    plot_tempered_acceptance_fraction(
        steps,
        total_acceptance_fraction,
        palette=palette,
        filename=os.path.join(parent_folder, 'tempered_acceptance_fraction.png')
    )

    # `chain` defaults to None; plot_act_evolution would fail on `None.items()`
    if chain is not None:
        plot_act_evolution(
            chain,
            iteration=iteration,
            parent_folder=parent_folder
        )
1160
+
1161
def produce_rj_plots(nleaves: dict,
                     nleaves_min: dict,
                     nleaves_max: dict,
                     palette: str = None,
                     parent_folder: str = '.',
                     iteration: int = 0):
    """
    Write reversible-jump diagnostics for every branch whose leaf count can change.

    Currently this is only the histogram of the number of leaves across temperatures,
    saved to ``<parent_folder>/<branch>/rj_nleaves.png``. A branch is skipped when its
    maximum leaf count is not larger than its minimum; both default to 1 when the
    branch is missing from the corresponding dictionary.

    Args:
        nleaves (Dict): Per-branch arrays with the number of leaves.
        nleaves_min (Dict): Per-branch minimum number of leaves.
        nleaves_max (Dict): Per-branch maximum number of leaves.
        palette (str or list, optional): Seaborn color palette name or list of colors.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
        iteration (int, optional): Current iteration number for labeling plots.
    """
    for branch, leaf_counts in nleaves.items():
        lower = nleaves_min.get(branch, 1)
        upper = nleaves_max.get(branch, 1)

        if upper <= lower:
            # fixed number of leaves: nothing reversible-jump related to show
            continue

        out_dir = os.path.join(parent_folder, branch)
        os.makedirs(out_dir, exist_ok=True)

        histogram_file = os.path.join(out_dir, 'rj_nleaves.png')
        plot_leaves(
            leaf_counts,
            lower,
            upper,
            palette=palette,
            iteration=iteration,
            filename=histogram_file,
        )
        # TODO: optionally also plot_leaves_evolution(leaf_counts[:, 0], ...) into
        # '<branch>/rj_nleaves_evolution.png' (previously left commented out).
1204
+
1205
+
1206
+
1207
+
1208
class PlotContainer:
    """
    Generates diagnostic plots from a sampler backend.

    Call :meth:`produce_plots` (e.g. periodically during sampling). Each call writes
    the selected plot families into ``<parent_folder>/<plot type>/``.

    Args:
        backend (Backend, optional): Backend providing chains, log-likelihoods, betas
            and acceptance counters. Must be set before calling :meth:`produce_plots`.
        plots (list or str): Plot types to generate. Options are 'base', 'tempering',
            'advanced', 'rj', or 'all'. To request several types, pass a list of strings.
        branches (list, optional): Branch names to generate plots for. If None, all branches are used.
        truths (dict, optional): Dictionary of true parameter values for different branches.
        overlay_covariance (dict, optional): Dictionary of covariance matrices to overlay on corner plots.
        tempering_palette (str or list, optional): Seaborn color palette name or list of colors
            for tempering plots. If None, it defaults to `icefire`.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
            It is created on construction.
        discard (float, optional): Number of initial samples to discard from the chain before
            plotting. Values below 1 are treated as a fraction of the current iteration. Default is 0.
        stop (int, optional): No plots are produced once the backend iteration exceeds this value.
            Default is 10000.

    Raises:
        ValueError: If an unknown plot type is requested.
    """

    def __init__(self,
                 backend: "Backend" = None,
                 plots: typing.Union[list, str] = 'base',
                 branches: list = None,
                 truths: dict = None,
                 overlay_covariance: dict = None,
                 tempering_palette: str = None,
                 parent_folder: str = '.',
                 discard: float = 0,
                 stop: int = int(1e4),
                 ):
        """
        Initialize the PlotContainer and create ``parent_folder``.
        """

        self.backend = backend

        self.parent_folder = parent_folder
        os.makedirs(self.parent_folder, exist_ok=True)

        allowable_plots = ['base', 'tempering', 'advanced', 'rj']
        self.branches = branches

        if isinstance(plots, str):
            if plots == 'all':
                plots = allowable_plots
            else:
                plots = [plots]

        for plot in plots:
            if plot not in allowable_plots:
                raise ValueError(f"Plot type '{plot}' not recognized. Allowable types: {allowable_plots}")
        self.plots = plots

        self.truths = truths
        self.overlay_covariance = overlay_covariance
        self.discard = discard
        self.tempering_palette = tempering_palette

        # history accumulated across calls for the 'advanced' plots
        self.steps = []
        self.total_acceptance_fraction = None
        self.move_acceptance_fractions = {}

        self.stop = stop

    @property
    def backend(self):
        return self._backend

    @backend.setter
    def backend(self, value):
        self._backend = value

    @property
    def truths(self):
        return self._truths

    @truths.setter
    def truths(self, value):
        self._truths = value

    @property
    def overlay_covariance(self):
        return self._overlay_covariance

    @overlay_covariance.setter
    def overlay_covariance(self, value):
        self._overlay_covariance = value

    def _filter_branches(self, data: dict) -> dict:
        """Restrict a per-branch dictionary to ``self.branches``; return it unchanged when ``branches`` is None."""
        if self.branches is None:
            return data
        return {branch: data[branch] for branch in self.branches if branch in data}

    def _get_moves(self, sampler=None):
        """Return ``sampler.moves`` if a sampler is given, else ``self.backend.moves`` if present, else None."""
        if sampler is not None:
            return sampler.moves
        return getattr(self.backend, 'moves', None)

    def produce_plots(self, sampler=None) -> None:
        """
        Generate the configured diagnostic plots for the current backend state.

        Args:
            sampler: The sampler object. If not provided, only ``self.backend`` is used;
                in that case per-move acceptance fractions and ``nleaves_min`` may be
                unavailable or replaced by defaults.

        Returns:
            None
        """

        iteration = self.backend.iteration
        if iteration > self.stop:
            return

        labels = self.backend.key_order

        discard = int(self.discard) if self.discard >= 1 else int(self.discard * iteration)
        chain = self._filter_branches(self.backend.get_chain(discard=discard))
        # log-likelihood and betas are arrays (not per-branch dicts): never branch-filter them
        logl = self.backend.get_log_like(discard=discard)
        betas = self.backend.get_betas(discard=discard)

        for plot in self.plots:
            base_folder = os.path.join(self.parent_folder, plot)
            os.makedirs(base_folder, exist_ok=True)

            if plot == 'base':
                produce_base_plots(
                    chain=chain,
                    logl=logl,
                    truths=self.truths,
                    overlay_covariance=self.overlay_covariance,
                    iteration=iteration,
                    labels=labels,
                    parent_folder=base_folder
                )
            elif plot == 'tempering':
                swap_acceptance_fraction = self.backend.swaps_accepted / float(iteration * self.backend.nwalkers)
                produce_tempering_plots(
                    chain=chain,
                    betas=betas,
                    logl=logl,
                    swap_acceptance_fraction=swap_acceptance_fraction,
                    labels=labels,
                    parent_folder=base_folder,
                    palette=self.tempering_palette
                )

            elif plot == 'advanced':
                self.steps.append(iteration)
                current_acceptance = (self.backend.accepted / float(iteration))[np.newaxis, ...]
                if self.total_acceptance_fraction is None:
                    self.total_acceptance_fraction = current_acceptance
                else:
                    # shape (niterations, ntemps, nwalkers)
                    self.total_acceptance_fraction = np.vstack((self.total_acceptance_fraction, current_acceptance))

                moves = self._get_moves(sampler)
                if moves is not None:
                    for move in moves:
                        name = move.__class__.__name__
                        move_acceptance = move.acceptance_fraction[np.newaxis, ...]
                        if name not in self.move_acceptance_fractions:
                            self.move_acceptance_fractions[name] = move_acceptance
                        else:
                            # shape (niterations, ntemps, nwalkers)
                            self.move_acceptance_fractions[name] = np.vstack((self.move_acceptance_fractions[name], move_acceptance))

                # ACT evolution is computed on the whole chain, restricted to the selected branches
                full_chain = self._filter_branches(self.backend.get_chain(discard=0)) if discard > 0 else chain

                produce_advanced_plots(steps=self.steps,
                                       total_acceptance_fraction=self.total_acceptance_fraction,
                                       moves_acceptance_fraction=self.move_acceptance_fractions,
                                       palette=self.tempering_palette,
                                       iteration=iteration,
                                       chain=full_chain,
                                       parent_folder=base_folder
                                       )

            elif plot == 'rj':
                if self.backend.rj is False:
                    continue

                nleaves = self._filter_branches(self.backend.get_nleaves(discard=discard))

                nleaves_min = sampler.nleaves_min if sampler is not None else dict(zip(self.backend.rj_branches, [0] * len(self.backend.rj_branches)))
                nleaves_max = self.backend.nleaves_max

                produce_rj_plots(
                    nleaves=nleaves,
                    nleaves_min=nleaves_min,
                    nleaves_max=nleaves_max,
                    palette=self.tempering_palette,
                    parent_folder=base_folder,
                    iteration=iteration
                )
1391
+
1392
+
1393
+