eryn 1.2.4__py3-none-any.whl → 1.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/backends/backend.py +6 -0
- eryn/backends/hdfbackend.py +15 -0
- eryn/ensemble.py +16 -13
- eryn/prior.py +47 -2
- eryn/utils/__init__.py +1 -1
- eryn/utils/periodic.py +19 -3
- eryn/utils/plot.py +1393 -0
- eryn/utils/transform.py +52 -39
- eryn/utils/updates.py +106 -0
- eryn/utils/utility.py +1 -1
- {eryn-1.2.4.dist-info → eryn-1.2.6.dist-info}/METADATA +7 -15
- {eryn-1.2.4.dist-info → eryn-1.2.6.dist-info}/RECORD +13 -12
- {eryn-1.2.4.dist-info → eryn-1.2.6.dist-info}/WHEEL +0 -0
eryn/utils/plot.py
ADDED
|
@@ -0,0 +1,1393 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
import os
|
|
3
|
+
import numpy as np
|
|
4
|
+
import matplotlib as mpl
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
from matplotlib.patches import Ellipse, Rectangle
|
|
7
|
+
|
|
8
|
+
from matplotlib.colors import to_rgba
|
|
9
|
+
|
|
10
|
+
import corner
|
|
11
|
+
import typing
|
|
12
|
+
|
|
13
|
+
from eryn.utils.updates import UpdateStep
|
|
14
|
+
|
|
15
|
+
import pandas as pd
|
|
16
|
+
import seaborn as sns
|
|
17
|
+
|
|
18
|
+
from eryn.utils.utility import stepping_stone_log_evidence, get_integrated_act
|
|
19
|
+
DEFAULT_PALETTE = "icefire"
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import scienceplots
|
|
23
|
+
plt.style.use(['science'])
|
|
24
|
+
except (ImportError, ModuleNotFoundError):
|
|
25
|
+
pass
|
|
26
|
+
|
|
27
|
+
# increase default font size
|
|
28
|
+
mpl.rcParams.update({'font.size': 16})
|
|
29
|
+
|
|
30
|
+
class Backend:
    """Empty stand-in ``Backend`` type, defined so the name can be used in type hints."""
|
|
33
|
+
|
|
34
|
+
def save_or_show(fig, filename=None):
    """
    Write a figure to disk, or display it when no output file is given.

    Args:
        fig (matplotlib.figure.Figure): Figure to save or show.
        filename (str, optional): Output path. When truthy, ``fig`` is saved
            there (200 dpi, tight bounding box) and then closed. Otherwise
            ``plt.show()`` is called.
    """
    if not filename:
        # No target file: hand over to the active matplotlib backend.
        plt.show()
        return

    fig.savefig(filename, dpi=200, bbox_inches='tight')
    # Close after saving so figures do not accumulate in pyplot's registry.
    plt.close(fig)
|
|
47
|
+
|
|
48
|
+
def cov_ellipse(mean, cov, ax, n_std=1.0, **kwargs):
    """
    Plot a covariance ellipse using eigendecomposition.

    The ellipse axes are aligned with the eigenvectors of the covariance matrix,
    and scaled by sqrt(eigenvalue) * n_std.

    Args:
        mean (array-like): Center of the ellipse (mean_x, mean_y).
        cov (np.ndarray): 2x2 covariance matrix.
        ax (matplotlib.axes.Axes): Axes object on which to plot the ellipse.
        n_std (float, optional): Number of standard deviations for ellipse radius. Default is 1.0.
        **kwargs: Additional keyword arguments passed to matplotlib.patches.Ellipse.

    Returns:
        matplotlib.patches.Ellipse: The covariance ellipse added to the axes.
    """
    # Eigendecomposition: eigenvalues are variances along principal axes
    eigenvalues, eigenvectors = np.linalg.eigh(cov)

    # Sort by eigenvalue (largest first) for consistent orientation
    order = np.argsort(eigenvalues)[::-1]
    eigenvectors = eigenvectors[:, order]

    # A singular (e.g. perfectly correlated) covariance can come back with a
    # tiny negative eigenvalue from round-off; clamp it so the square root
    # below yields a degenerate (zero-width) axis instead of NaN.
    eigenvalues = np.clip(eigenvalues[order], 0.0, None)

    # Ellipse dimensions: full axis length is 2 * n_std * sqrt(eigenvalue)
    width, height = 2 * n_std * np.sqrt(eigenvalues)

    # Rotation angle from the first eigenvector (major axis direction)
    angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))

    ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle, **kwargs)
    return ax.add_patch(ellipse)
|
|
81
|
+
|
|
82
|
+
def overlay_fim_covariance(
    fig,
    covariance,
    means=None,
    nsigmas=(1, 2, 3),
    plot_1d=False,
    colors=None,
    linestyles=None,
    linewidths=None,
    alpha=0.7,
    labels=None,
):
    """
    Overlay Fisher Information Matrix confidence contours on corner plot axes.

    For 2D subplots, draws elliptical contours at specified confidence levels.
    For 1D diagonal plots, draws vertical lines at ±nσ from the mean.

    The figure's axes are read with ``fig.get_axes()`` and interpreted as a
    row-major square grid; only the lower triangle (row >= column) of the
    first ``n_params`` rows/columns is drawn on.

    Args:
        fig (matplotlib.figure.Figure): Figure object containing corner plot axes.
        covariance (np.ndarray): Covariance matrix from Fisher analysis, shape (n_params, n_params).
        means (np.ndarray, optional): Mean values for each parameter. If None, uses origin (0, 0, ...).
        nsigmas (sequence, optional): Sigma levels to plot (e.g., [1, 2, 3]). Default (1, 2, 3).
        plot_1d (bool, optional): Whether to plot 1D contours on diagonal plots. Default is False.
        colors (str or list, optional): One color for all levels, or one per sigma level.
            If None, uses the default color cycle ('C0', 'C1', ...).
        linestyles (str or list, optional): One line style for all levels, or one per sigma
            level. If None, uses solid lines.
        linewidths (float or list, optional): One width for all levels, or one per sigma
            level. If None, uses 1.5.
        alpha (float, optional): Transparency of contours. Default 0.7.
        labels (list, optional): One label per sigma level. The length is validated, but
            the labels are not attached to the drawn artists by this function.

    Returns:
        matplotlib.figure.Figure: The figure with overlaid contours.

    Raises:
        ValueError: If ``covariance`` is not a square 2D matrix, if any per-level
            argument has the wrong length, or if the figure has fewer axes than an
            ``n_params`` x ``n_params`` grid requires.
    """
    covariance = np.asarray(covariance)
    if covariance.ndim != 2 or covariance.shape[0] != covariance.shape[1]:
        raise ValueError(f"Covariance matrix must be square, got shape {covariance.shape}")
    n_params = covariance.shape[0]

    # Set default means to origin
    if means is None:
        means = np.zeros(n_params)
    elif len(means) != n_params:
        raise ValueError(f"means must have length {n_params}, got {len(means)}")

    nsigmas = list(nsigmas)
    n_levels = len(nsigmas)

    # Per-level styling: None -> default, scalar -> repeated, sequence -> validated
    if colors is None:
        colors = [f'C{i}' for i in range(n_levels)]
    elif isinstance(colors, str):
        colors = [colors] * n_levels
    elif len(colors) != n_levels:
        raise ValueError(f"colors must have length {n_levels}, got {len(colors)}")

    if linestyles is None:
        linestyles = ['-'] * n_levels
    elif isinstance(linestyles, str):
        # A bare string such as '--' is one style, not a sequence of characters.
        linestyles = [linestyles] * n_levels
    elif len(linestyles) != n_levels:
        raise ValueError(f"linestyles must have length {n_levels}, got {len(linestyles)}")

    if linewidths is None:
        linewidths = [1.5] * n_levels
    elif np.isscalar(linewidths):
        linewidths = [linewidths] * n_levels
    elif len(linewidths) != n_levels:
        raise ValueError(f"linewidths must have length {n_levels}, got {len(linewidths)}")

    if labels is not None and len(labels) != n_levels:
        raise ValueError(f"labels must have length {n_levels}, got {len(labels)}")

    # Standard deviations for the 1D (diagonal) panels
    sigmas = np.sqrt(np.diag(covariance))

    # Arrange the flat axes list as a row-major (row, col) grid.
    axs = np.array(fig.get_axes())
    n_axs = int(np.sqrt(len(axs)))
    if n_axs < n_params:
        raise ValueError(
            f"figure provides a {n_axs}x{n_axs} grid of axes, "
            f"but covariance has {n_params} parameters"
        )
    axs_grid = axs[: n_axs * n_axs].reshape(n_axs, n_axs)

    for n_sigma, color, ls, lw in zip(nsigmas, colors, linestyles, linewidths):
        # Lower triangle: the panel at (row, col) has parameter `col` on x
        # and parameter `row` on y.
        for col in range(n_params):
            for row in range(col, n_params):
                ax = axs_grid[row, col]
                if ax is None:
                    continue

                if row == col:
                    if not plot_1d:
                        continue
                    # 1D diagonal plot - vertical lines at mean ± n*sigma
                    for sign in (-1, 1):
                        ax.axvline(
                            means[col] + sign * n_sigma * sigmas[col],
                            color=color,
                            linestyle=ls,
                            linewidth=lw,
                            alpha=alpha,
                            zorder=10,
                        )
                else:
                    # 2D off-diagonal plot - ellipse from the 2x2 sub-covariance
                    # of the (x, y) = (col, row) parameter pair
                    cov = np.array(
                        (
                            (covariance[col][col], covariance[col][row]),
                            (covariance[row][col], covariance[row][row]),
                        )
                    )
                    mean = np.array((means[col], means[row]))
                    cov_ellipse(
                        mean, cov, ax, n_std=n_sigma,
                        edgecolor=color, facecolor='none', linestyle=ls, linewidth=lw,
                        zorder=10, alpha=alpha,
                    )

    return fig
|
|
231
|
+
|
|
232
|
+
def cornerplot(data, *args, means=None, overlay_covariance=None, legend_label='Samples', overlay_label='Information Matrix Covariance', filename=None, **kwargs):
    """
    Create a corner plot with optional Information Matrix covariance overlay. This is centered around the means if provided.
    Wrapper around `corner.corner()` that adds Fisher Information Matrix covariance contours.

    Args:
        data (array-like): Input data for corner plot (e.g., MCMC samples).
        *args: Positional arguments passed to `corner.corner()` after ``data``.
        means (array-like, optional): Mean values for each parameter to center the overlay.
            If None, a single set of 'truths' from kwargs is used; otherwise the mean of data.
        overlay_covariance (np.ndarray, optional): Covariance matrix to overlay. If None, no overlay is added.
        legend_label (str, optional): Label for the sample distribution in the legend. Default is 'Samples'.
        overlay_label (str, optional): Label for the overlay covariance in the legend. Default is 'Information Matrix Covariance', assuming FIM.
        filename (str, optional): If provided, saves the figure to this filename.
        **kwargs: Keyword arguments passed to `corner.corner()`. They override the
            defaults set here. A 2D ``truths`` (one row per truth set) is not passed
            to corner; each row is drawn as extra lines instead.

    Returns:
        None. The figure is saved to ``filename`` (and closed) or shown.
    """

    corner_kwargs = {
        'levels': (1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-4.5)),
        'show_titles': True,
        'title_fmt': '.1e',
        'title_kwargs': {'fontsize': 12},
        'label_kwargs': {'fontsize': 14},
        'hist_kwargs': {'density': True, 'linewidth': 2},
        'plot_datapoints': False,
        'fill_contours': True,
        'color': 'steelblue',
        'truth_color': 'red',
    }

    # User-provided kwargs take precedence over the defaults above
    corner_kwargs.update(kwargs)

    # 2D truths (e.g. several sources in reversible jump) are drawn by hand
    # after the corner call. np.ndim also works when truths is a plain list.
    truths = corner_kwargs.get('truths', None)
    multi_truths = truths is not None and np.ndim(truths) > 1
    if multi_truths:
        truths = np.asarray(truths)
        corner_kwargs['truths'] = None  # add truths later

    # Pass data positionally: combining data=... with *args would bind the
    # first extra positional argument to the same parameter and raise TypeError.
    fig = corner.corner(data, *args, **corner_kwargs)

    # add the other truths as vertical/horizontal lines
    if multi_truths:
        corner_kwargs['truths'] = truths  # add back in for legend
        truth_color = corner_kwargs['truth_color']
        n_params = np.shape(data)[1]
        axs = np.array(fig.get_axes()).reshape((n_params, n_params))
        for truth in truths:
            for i in range(n_params):
                axs[i, i].axvline(truth[i], color=truth_color, linestyle='-', linewidth=1)

            for i in range(n_params):
                for j in range(i):
                    ax = axs[i, j]
                    ax.axhline(truth[i], color=truth_color, linestyle='-', linewidth=1)
                    ax.axvline(truth[j], color=truth_color, linestyle='-', linewidth=1)

    # prepare handles for legend
    handles = [mpl.lines.Line2D([], [], color=corner_kwargs['color'],
                                linestyle='-', linewidth=2)]
    handles_labels = [legend_label]

    if corner_kwargs.get('truths', None) is not None:
        handles.append(mpl.lines.Line2D([], [], color=corner_kwargs['truth_color'],
                                        linestyle='-', linewidth=1))
        handles_labels.append('Truths')

    # Overlay covariance contours if provided
    if overlay_covariance is not None:
        if means is None and truths is not None:
            # Only a single truth vector can serve as the overlay center;
            # with several truth rows fall through to the sample mean.
            truth_rows = np.atleast_2d(truths)
            if truth_rows.shape[0] == 1:
                means = truth_rows[0]

        if means is None:
            # If no means or usable truths provided, default to the mean of the samples
            means = np.mean(data, axis=0)

        overlay_fim_covariance(
            fig,
            overlay_covariance,
            means=means,
            plot_1d=False,
            alpha=0.8,
            nsigmas=[1, 2, 3],
            colors='k'
        )

        handles.append(mpl.lines.Line2D([], [], color='k', linestyle='-', linewidth=1.5))
        handles_labels.append(overlay_label)

    # Add legend to the second axes of the figure (first row, second column);
    # a one-parameter plot has only one axes, so fall back to it.
    all_axes = fig.get_axes()
    ax_legend = all_axes[1] if len(all_axes) > 1 else all_axes[0]
    ax_legend.legend(handles, handles_labels, loc='upper left', fontsize=18)

    save_or_show(fig, filename)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def traceplot(chain, labels=None, truths=None, filename=None):
    """
    Create trace plots for MCMC chains.

    One subplot row is drawn per parameter; within it, every walker's samples
    (all leaves) are plotted against the step index.

    Args:
        chain (np.ndarray): MCMC chain of shape (nsteps, nwalkers, nleaves, ndim).
        labels (list, optional): List of parameter names, used as y-axis labels.
        truths (array-like, optional): True parameter values to overlay as horizontal
            dashed lines. Either one vector of length ndim or a 2D array with one
            row per truth set.
        filename (str, optional): If provided, saves the figure to this filename.

    Returns:
        None. The figure is saved to ``filename`` (and closed) or shown.
    """

    nsteps, nwalkers, nleaves, ndim = chain.shape

    # squeeze=False keeps a 2D axes array even for ndim == 1, where
    # plt.subplots would otherwise return a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(ndim, 1, figsize=(10, 2.5 * ndim), sharex=True, squeeze=False)
    axs = axs[:, 0]

    if truths is not None:
        truths = np.atleast_2d(truths)

    for i in range(ndim):
        for w in range(nwalkers):
            axs[i].plot(chain[:, w, :, i], alpha=0.5, rasterized=True)
        if truths is not None:
            for j in range(truths.shape[0]):
                axs[i].axhline(truths[j, i], color='k', linestyle='--')
        if labels is not None:
            axs[i].set_ylabel(labels[i])

    axs[-1].set_xlabel('Step')
    plt.tight_layout()

    save_or_show(fig, filename)
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def plot_loglikelihood(logl, filename=None):
    """
    Plot the log-likelihood evolution of every walker.

    Two figures are produced: a single panel with one line per walker, and a
    seaborn facet grid with one panel per walker showing the log-likelihood
    relative to the global maximum.

    Args:
        logl (np.ndarray): Log-likelihood values of shape (nsteps, nwalkers).
        filename (str, optional): If provided, the first figure is saved to this
            filename and the facet grid to the same path with ``_facet`` inserted
            before the extension. Otherwise both figures are shown.

    Returns:
        None.
    """
    nsteps, nwalkers = logl.shape

    fig = plt.figure(figsize=(10, 6))
    for j in range(nwalkers):
        plt.plot(logl[:, j], color=f"C{j % 10}", alpha=0.8, rasterized=True)
    plt.xlabel("Sampler Iteration")
    plt.ylabel("Log-Likelihood")

    save_or_show(fig, filename)

    # Facet grid of the log-likelihood evolution for each walker,
    # measured relative to the largest value in the whole array.
    delta_label = r"$\Delta \log\mathcal{L}$"
    facet_logl = logl - np.max(logl, axis=(0, 1))

    # The step/walker index columns are walker-major (all steps of walker 0,
    # then walker 1, ...), so the (nsteps, nwalkers) values must be flattened
    # in the same order: transpose before ravel.
    step = np.tile(np.arange(nsteps), nwalkers)
    walker = np.repeat(np.arange(nwalkers, dtype=int), nsteps)

    df = pd.DataFrame(np.c_[facet_logl.T.ravel(), step, walker],
                      columns=[delta_label, "step", "walker"])

    # Initialize a grid of plots with an Axes for each walker
    grid = sns.FacetGrid(df, col="walker", hue="walker",
                         col_wrap=int(np.floor(np.sqrt(nwalkers))), height=1.5)

    # Draw a line plot to show the trajectory of each walker
    grid.map(plt.plot, "step", delta_label, marker=".", rasterized=True)

    # Horizontal reference line at y=0, i.e. at the maximum log-likelihood
    grid.refline(y=0, linestyle=":")

    # Adjust the arrangement of the plots
    grid.set_titles(col_template="Walker {col_name:.0f}")
    grid.set_axis_labels("Step", delta_label)
    # Disable tight_layout on this grid to avoid a layout warning
    grid.tight_layout = lambda *args, **kwargs: None

    # add overall title
    plt.subplots_adjust(top=0.9)
    grid.figure.suptitle(r"$\Delta \log\mathcal{L}_w = \log\mathcal{L}_w - \max(\log\mathcal{L})$", fontsize=16)

    # Derive the facet filename from the extension, whatever it is, so the
    # second figure never overwrites the first one.
    facet_filename = None
    if filename:
        root, ext = os.path.splitext(filename)
        facet_filename = f"{root}_facet{ext}"

    save_or_show(grid.figure, facet_filename)
|
|
408
|
+
|
|
409
|
+
def tempering_ridgeplot(chain, labels=None, palette=None,
                        bw_adjust=0.5, aspect=5, height=0.5, hspace=-0.25,
                        max_samples=10000, filename=None):
    """
    Create ridge plots of tempered distributions using overlapping KDE plots for all parameters.

    Each temperature level is shown as a separate row with overlapping density
    estimates, and each parameter as a column, in a single FacetGrid figure.
    The x-range of every column is taken from the samples at temperature
    index 0, widened by a margin.

    Args:
        chain (np.ndarray): MCMC chain of shape (nsteps, ntemps, nwalkers, nleaves, ndim).
        labels (list, optional): List of parameter names. If None, ``$x_{i}$`` is used.
        palette (str or list, optional): Seaborn color palette name or list of colors.
            Default uses the module-level ``DEFAULT_PALETTE``.
        bw_adjust (float, optional): Bandwidth adjustment factor for KDE. Default is 0.5.
        aspect (float, optional): Aspect ratio of each facet. Default is 5.
        height (float, optional): Height of each temperature row in inches. Default is 0.5.
        hspace (float, optional): Vertical spacing between temperature rows (negative for overlap). Default is -0.25.
        max_samples (int, optional): Maximum number of samples per temperature for KDE. Default is 10000.
            Larger sample sets are randomly subsampled (fixed seed) without replacement.
        filename (str, optional): If provided, saves figure to this filename.

    Returns:
        None. The figure is saved to ``filename`` (and closed) or shown.
    """
    import warnings

    # Use seaborn context manager to temporarily set theme
    with sns.axes_style("white", {"axes.facecolor": (0, 0, 0, 0)}):

        nsteps, ntemps, nwalkers, nleaves, ndim = chain.shape

        # One color per temperature index
        if palette is None:
            pal = sns.color_palette(DEFAULT_PALETTE, ntemps)
        elif isinstance(palette, str):
            pal = sns.color_palette(palette, ntemps)
        else:
            pal = palette

        # Subsampling RNG (fixed seed so repeated calls give the same figure)
        rng = np.random.default_rng(42)

        # Build dataframe with samples from all temperatures and parameters
        data_list = []
        for param_idx in range(ndim):
            if labels is not None:
                param_label = labels[param_idx]
            else:
                # Braces keep multi-digit indices inside the subscript.
                param_label = fr'$x_{{{param_idx}}}$'

            for t in range(ntemps):
                # Flatten samples across steps, walkers, leaves for the selected parameter
                samples = chain[:, t, :, :, param_idx].reshape(-1)

                # Remove NaNs
                samples = samples[~np.isnan(samples)]

                # Subsample if needed for faster KDE
                if len(samples) > max_samples:
                    samples = rng.choice(samples, size=max_samples, replace=False)

                temp_label = rf"$\beta_{{{t}}}$"
                data_list.append(pd.DataFrame({
                    'x': samples,
                    'temp': temp_label,
                    'temp_idx': t,
                    'param': param_label,
                    'param_idx': param_idx
                }))

        df = pd.concat(data_list, ignore_index=True)

        # Row order: temperature index ascending (index 0 in the top row)
        temp_order = df.drop_duplicates('temp_idx').sort_values('temp_idx', ascending=True)['temp'].tolist()

        # Column order: parameter index ascending
        param_order = df.drop_duplicates('param_idx').sort_values('param_idx', ascending=True)['param'].tolist()

        # Compute x-axis limits from temperature index 0 for each parameter, so that
        # distribution stays visible even when other temperatures are much wider
        # NOTE(review): index 0 is presumably the cold chain (beta=1) - confirm with caller
        xlims = {}
        cold_df = df[df['temp_idx'] == 0]
        for param in param_order:
            param_data = cold_df[cold_df['param'] == param]['x']
            q_low, q_high = param_data.quantile([0.001, 0.999])
            margin = (q_high - q_low) * 0.3  # Add 30% margin to show some broadening
            xlims[param] = (q_low - margin, q_high + margin)

        # Initialize the FacetGrid with row=temp, col=param
        # Suppress tight_layout warnings during initialization
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=".*Tight layout.*")
            g = sns.FacetGrid(df, row="temp", col="param", hue="temp",
                              aspect=aspect, height=height,
                              palette=pal, row_order=temp_order, col_order=param_order,
                              sharex=False, sharey=False)

        # Disable tight_layout to avoid warning with negative hspace
        g.tight_layout = lambda *args, **kwargs: None

        # Custom plotting function that clips KDE to the parameter's xlim
        def plot_kde_clipped(x, color, label, **kwargs):
            ax = plt.gca()
            # Get the parameter for this column
            col_idx = ax.get_subplotspec().colspan.start
            param = param_order[col_idx]
            clip = xlims[param]

            # Filled density first, then a white outline on top of it
            sns.kdeplot(x, ax=ax, bw_adjust=bw_adjust, clip_on=False,
                        fill=True, alpha=1, linewidth=1.5, color=color, clip=clip)
            sns.kdeplot(x, ax=ax, clip_on=False, color="w", lw=2, bw_adjust=bw_adjust, clip=clip)
            ax.set_xlim(clip)

        g.map(plot_kde_clipped, "x")

        # Add horizontal reference line at y=0
        g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)

        # Label function for temperature labels (only on first column)
        def label_temp(x, color, label):
            ax = plt.gca()
            # Only add label if this is the first column
            if ax.get_subplotspec().colspan.start == 0:
                ax.text(-0.2, 0.2, label, fontweight="bold", color=color,
                        ha="left", va="center", transform=ax.transAxes, fontsize=14)

        g.map(label_temp, "x")

        # Set subplots to overlap vertically
        g.figure.subplots_adjust(hspace=hspace, wspace=0.1)

        # use parameter labels in the x axes (on the last row)
        for ax, param in zip(g.axes[-1], param_order):
            ax.set_xlabel(param, fontsize=14)

        # Remove xticks from all rows except the bottom
        for row_idx in range(len(temp_order) - 1):
            for col_idx in range(len(param_order)):
                g.axes[row_idx, col_idx].set_xticks([])

                # set xlims again to ensure consistency
                param = param_order[col_idx]
                g.axes[row_idx, col_idx].set_xlim(xlims[param])

        # Remove axes details that don't work well with overlap
        g.set_titles("")
        g.set(yticks=[], ylabel="")
        g.despine(bottom=True, left=True)

        # Add overall title
        g.figure.suptitle('Tempered Distributions', y=1.02, fontsize=14)

        save_or_show(g.figure, filename)
|
|
562
|
+
|
|
563
|
+
# # Properly restore the original seaborn style
|
|
564
|
+
# sns.set_theme(style=original_style['style'], rc=original_style)
|
|
565
|
+
|
|
566
|
+
# return g.figure
|
|
567
|
+
|
|
568
|
+
def plot_swap_acceptance(swap_acceptance_fraction, palette=None, filename=None):
    """
    Draw a bar chart of swap acceptance fractions between neighbouring temperatures.

    Args:
        swap_acceptance_fraction (np.ndarray): Swap acceptance fraction between adjacent
            temperatures, shape (ntemps-1,). Element i corresponds to swaps between
            temperature i and i+1.
        palette (str or list, optional): Seaborn color palette name or list of colors.
            Falls back to the module-level ``DEFAULT_PALETTE`` when None.
        filename (str, optional): If provided, saves the figure to this filename.

    Returns:
        matplotlib.figure.Figure: The swap acceptance plot figure.
    """
    n_pairs = swap_acceptance_fraction.shape[0]

    fig, ax = plt.subplots(figsize=(8, 5))

    # One bar per adjacent pair (i, i+1), positioned at integer i
    positions = np.arange(n_pairs)
    pair_names = [fr'{lower}$\leftrightarrow${lower+1}' for lower in range(n_pairs)]

    # One color per pair
    if palette is None:
        palette = DEFAULT_PALETTE
    bar_colors = sns.color_palette(palette, n_pairs)

    bars = ax.bar(
        positions,
        swap_acceptance_fraction,
        color=bar_colors,
        edgecolor='white',
        linewidth=0.5,
    )

    # Dashed guide at an acceptance fraction of 0.25
    ax.axhline(y=0.25, color='gray', linestyle='--', linewidth=1, alpha=0.7, label='0.25')

    # Axis ticks, labels and title
    ax.set_xticks(positions)
    ax.set_xticklabels(pair_names, fontsize=10, rotation=45, ha='right')
    ax.set_xlabel('Temperature Pair (index)', fontsize=12)
    ax.set_ylabel('Swap Acceptance Fraction', fontsize=12)
    ax.set_title('Temperature Swap Acceptance', fontsize=14)
    ax.set_ylim(0, 1.1)
    ax.legend(loc='upper left', fontsize=10)

    # Print each value just above its bar
    for bar, fraction in zip(bars, swap_acceptance_fraction):
        bar_center = bar.get_x() + bar.get_width() / 2
        ax.annotate(
            f'{fraction:.2f}',
            xy=(bar_center, bar.get_height()),
            xytext=(0, 3),  # 3 points vertical offset
            textcoords="offset points",
            ha='center', va='bottom', fontsize=9,
        )

    plt.tight_layout()

    save_or_show(fig, filename)

    return fig
|
|
632
|
+
|
|
633
|
+
def plot_logl_betas(betas: np.ndarray,
                    logl: np.ndarray,
                    palette: str = None,
                    filename: str = None
                    ):
    """
    Plot the mean log-likelihood of each temperature against its inverse temperature.

    The x-axis is logarithmic. The title reports the log-evidence and its
    uncertainty as returned by ``stepping_stone_log_evidence``.

    Args:
        betas (numpy.ndarray): 2D array of inverse temperatures; the second axis
            indexes temperatures and only the last row is used.
        logl (numpy.ndarray): Array of log-likelihood values whose second axis
            indexes temperatures.
        palette (str, optional): Seaborn color palette name or list of colors.
            Falls back to the module-level ``DEFAULT_PALETTE`` when None.
        filename (str, optional): If provided, saves the figure to this filename.
    Returns:
        None
    """
    fig = plt.figure(figsize=(10, 6))

    ntemp = betas.shape[1]
    final_betas = betas[-1]

    chosen_palette = DEFAULT_PALETTE if palette is None else palette
    tempcolors = sns.color_palette(chosen_palette, ntemp)

    # One marker per temperature: its final beta vs. its mean log-likelihood
    for temp in range(ntemp):
        mean_logl = np.mean(logl[:, temp])
        plt.semilogx(final_betas[temp], mean_logl, '.', c=tempcolors[temp], label=f'$T_{temp}$')

    logZ, dlogZ = stepping_stone_log_evidence(final_betas, logl)

    plt.ylabel(r'$<\log{\mathcal{L}}>_{\beta}$')
    plt.xlabel(r'$\beta$')
    plt.title(r'$\log{\mathcal{Z}} = %.2f \pm %.2f$' % (logZ, dlogZ))

    save_or_show(fig, filename)
|
|
662
|
+
|
|
663
|
+
def plot_betas_evolution(betas: np.ndarray, palette: str = None, filename: str = None):
    """
    Draw one line per temperature showing how its inverse temperature changes with the step index.

    A small color strip in the upper-right corner of the axes acts as the legend.

    Args:
        betas (numpy.ndarray): Inverse temperatures of shape (nsteps, ntemps).
        palette (str, optional): Seaborn color palette name or list of colors.
            ``DEFAULT_PALETTE`` is used when None.
        filename (str, optional): Forwarded to ``save_or_show`` together with the figure.
    Returns:
        None
    """
    nsteps, ntemps = betas.shape
    chosen_palette = DEFAULT_PALETTE if palette is None else palette
    tempcolors = sns.color_palette(chosen_palette, ntemps)

    fig = plt.figure(figsize=(10, 6))
    step_axis = range(nsteps)
    for line_color, beta_trace in zip(tempcolors, betas.T):
        plt.plot(step_axis, beta_trace, color=line_color, linewidth=1.5, alpha=0.8, rasterized=True)
    plt.xlabel('Sampler Iteration')
    plt.ylabel(r'Inverse Temperature ($\beta$)')
    plt.title('Evolution of Inverse Temperatures')

    # Geometry of the color-strip legend, in axes-fraction coordinates.
    ax = plt.gca()
    legend_x, legend_y = 0.75, 0.9
    legend_width, legend_height = 0.15, 0.03

    # The palette is walked backwards, so the color of the last temperature
    # ends up on the left (next to the T_max label) and temperature 0 on the right.
    for i, color in enumerate(reversed(tempcolors)):
        cell_origin = (legend_x + i * legend_width / ntemps, legend_y)
        ax.add_patch(Rectangle(
            cell_origin,
            legend_width / ntemps, legend_height,
            transform=ax.transAxes,
            facecolor=to_rgba(color, 0.7),
            edgecolor='none'
        ))

    # Thin black outline around the whole strip.
    ax.add_patch(Rectangle(
        (legend_x, legend_y), legend_width, legend_height,
        transform=ax.transAxes,
        facecolor='none',
        edgecolor='black',
        linewidth=0.5
    ))

    # End labels, vertically centered on the strip.
    label_y = legend_y + legend_height / 2
    label_style = dict(transform=ax.transAxes, va='center', fontsize=11,
                       fontweight='normal', antialiased=True)
    ax.text(legend_x - 0.01, label_y, r'$T_{\rm max}$', ha='right', **label_style)
    ax.text(legend_x + legend_width + 0.01, label_y, r'$T_0$', ha='left', **label_style)

    save_or_show(fig, filename)
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
# RJ plots
|
|
722
|
+
def plot_leaves(nleaves: np.ndarray,
                nleaves_min: int,
                nleaves_max: int,
                palette: str = None,
                iteration: int = 0,
                filename: str = None):
    """
    Plot, for every temperature, a normalized histogram of the number of leaves.

    Histograms are drawn as filled steps with unit-width bins centered on the
    integers from ``nleaves_min`` to ``nleaves_max``. A color strip in the
    upper-right corner of the axes acts as the temperature legend, and the
    current step is printed in the lower-left corner of the figure.

    Args:
        nleaves (np.ndarray): Leaf counts; the second axis is the temperature.
        nleaves_min (int): Smallest leaf count covered by the bins.
        nleaves_max (int): Largest leaf count covered by the bins.
        palette (str, optional): Seaborn color palette name or list of colors.
            ``DEFAULT_PALETTE`` is used when None.
        iteration (int, optional): Step number written on the figure. Default is 0.
        filename (str, optional): Forwarded to ``save_or_show`` together with the figure.
    Returns:
        None
    """
    # Bin edges at half-integers so that each integer count has its own bin.
    bin_edges = np.arange(nleaves_min, nleaves_max + 2) - 0.5
    ntemps = nleaves.shape[1]
    chosen_palette = DEFAULT_PALETTE if palette is None else palette
    tempcolors = sns.color_palette(chosen_palette, ntemps)

    fig = plt.figure(figsize=(8, 5))

    for temp, tempcolor in enumerate(tempcolors):
        counts = nleaves[:, temp].flatten()
        # Decreasing zorder keeps lower temperature indices drawn on top.
        plt.hist(counts, bins=bin_edges, histtype="stepfilled", edgecolor=tempcolor,
                 facecolor=to_rgba(tempcolor, 0.2), density=True, ls='-',
                 zorder=100 - temp, rasterized=True)

    plt.xlabel('Number of leaves')
    plt.ylabel('Density')

    # Geometry of the color-strip legend, in axes-fraction coordinates.
    ax = plt.gca()
    legend_x, legend_y = 0.75, 0.9
    legend_width, legend_height = 0.15, 0.03

    # The palette is walked backwards, so the color of the last temperature
    # sits on the left (next to the T_max label) and temperature 0 on the right.
    n_gradient = len(tempcolors)
    for i, color in enumerate(reversed(tempcolors)):
        cell_origin = (legend_x + i * legend_width / n_gradient, legend_y)
        ax.add_patch(Rectangle(
            cell_origin,
            legend_width / n_gradient, legend_height,
            transform=ax.transAxes,
            facecolor=to_rgba(color, 0.7),
            edgecolor='none'
        ))

    # Thin black outline around the whole strip.
    ax.add_patch(Rectangle(
        (legend_x, legend_y), legend_width, legend_height,
        transform=ax.transAxes,
        facecolor='none',
        edgecolor='black',
        linewidth=0.5
    ))

    # End labels, vertically centered on the strip.
    label_y = legend_y + legend_height / 2
    label_style = dict(transform=ax.transAxes, va='center', fontsize=11,
                       fontweight='normal', antialiased=True)
    ax.text(legend_x - 0.01, label_y, r'$T_{\rm max}$', ha='right', **label_style)
    ax.text(legend_x + legend_width + 0.01, label_y, r'$T_0$', ha='left', **label_style)

    fig.text(0.07, 0.08, f"Step: {iteration}", ha='left', va='top', fontfamily='serif', c='k')
    save_or_show(fig, filename)
|
|
793
|
+
|
|
794
|
+
def plot_leaves_evolution(nleaves: np.ndarray,
                          filename: str = None):
    """
    Draw one line per walker showing its number of leaves as a function of the step index.

    Args:
        nleaves (np.ndarray): Number of leaves in the cold chain, shape (nsteps, nwalkers).
        filename (str, optional): Forwarded to ``save_or_show`` together with the figure.
    Returns:
        None
    """
    nsteps, _ = nleaves.shape
    step_axis = range(nsteps)

    fig = plt.figure(figsize=(10, 6))
    for walker, leaf_trace in enumerate(nleaves.T):
        # Cycle through the ten default matplotlib colors C0..C9.
        plt.plot(step_axis, leaf_trace, color=f"C{walker % 10}",
                 linewidth=1.5, alpha=0.8, rasterized=True)

    plt.xlabel('Sampler Iteration')
    plt.ylabel('Number of Leaves')
    plt.title('Evolution of Number of Leaves in Cold Chain')
    save_or_show(fig, filename)
|
|
812
|
+
|
|
813
|
+
def plot_acceptance_fraction(steps: typing.Union[np.ndarray, list],
                             total_acceptance_fraction: np.ndarray,
                             moves_acceptance_fraction: dict,
                             filename: str = None):
    """
    Plot the walker-averaged acceptance fraction of temperature index 0 over sampling steps.

    The overall acceptance fraction is drawn in black; every entry of
    ``moves_acceptance_fraction`` adds one more curve labelled with its key.
    A dashed horizontal line marks the value 0.234.

    Args:
        steps (np.ndarray or list): Values used on the x axis, one per recorded step.
        total_acceptance_fraction (np.ndarray): Acceptance fractions; after selecting
            temperature index 0 along the second axis, the values are averaged over axis 1.
        moves_acceptance_fraction (dict): Maps a move label to an array indexed
            like ``total_acceptance_fraction``. May be empty.
        filename (str, optional): Forwarded to ``save_or_show`` together with the figure.
    Returns:
        None
    """
    fig = plt.figure(figsize=(10, 6))

    # Overall acceptance fraction of the first temperature, averaged over walkers.
    cold_total = total_acceptance_fraction[:, 0]
    plt.plot(steps, cold_total.mean(axis=1), label='Total', color='black', linewidth=2)

    # One extra curve per move; nothing is drawn when the dictionary is empty.
    if len(moves_acceptance_fraction) != 0:
        for move, acc_fraction in moves_acceptance_fraction.items():
            cold_move = acc_fraction[:, 0]
            plt.plot(steps, cold_move.mean(axis=1), marker='o', label=move)

    plt.axhline(y=0.234, color='gray', linestyle='--', linewidth=1, alpha=0.7, label='0.234')
    plt.legend()
    plt.xlabel('Sampler Iteration')
    plt.ylabel('Acceptance Fraction')
    plt.title('Acceptance Fraction Over Time')

    save_or_show(fig, filename)
|
|
840
|
+
|
|
841
|
+
def plot_tempered_acceptance_fraction(steps: typing.Union[np.ndarray, list],
                                      total_acceptance_fraction: np.ndarray,
                                      palette: str = None,
                                      filename: str = None):
    """
    Plot the walker-averaged acceptance fraction of every temperature over sampling steps.

    Each temperature gets one curve in its palette color. A color strip in the
    upper-right corner of the axes acts as the legend, so no standard matplotlib
    legend is created (none of the curves carries a label).

    Args:
        steps (np.ndarray or list): Array of sampling steps.
        total_acceptance_fraction (np.ndarray): Array of total acceptance fractions, shape (nsteps, ntemps, nwalkers).
        palette (str or list, optional): Seaborn color palette name or list of colors.
            ``DEFAULT_PALETTE`` is used when None.
        filename (str, optional): If provided, saves the figure to this filename.
    Returns:
        None
    """
    ntemps = total_acceptance_fraction.shape[1]
    tempcolors = sns.color_palette(palette if palette is not None else DEFAULT_PALETTE, ntemps)

    fig = plt.figure(figsize=(10, 6))

    for temp in range(ntemps):
        plt.plot(steps, total_acceptance_fraction[:, temp].mean(axis=1), color=tempcolors[temp],
                 linewidth=1.5, marker='o', alpha=0.8, rasterized=True)

    # Geometry of the color-strip legend, in axes-fraction coordinates.
    ax = plt.gca()
    legend_width = 0.15
    legend_height = 0.03
    legend_x = 0.75
    legend_y = 0.9

    # The palette is walked backwards, so the color of the last temperature
    # sits on the left (next to the T_max label) and temperature 0 on the right.
    for i, color in enumerate(tempcolors[::-1]):
        rect = Rectangle(
            (legend_x + i * legend_width / ntemps, legend_y),
            legend_width / ntemps, legend_height,
            transform=ax.transAxes,
            facecolor=to_rgba(color, 0.7),
            edgecolor='none'
        )
        ax.add_patch(rect)

    # Thin black outline around the whole strip.
    border = Rectangle(
        (legend_x, legend_y), legend_width, legend_height,
        transform=ax.transAxes,
        facecolor='none',
        edgecolor='black',
        linewidth=0.5
    )
    ax.add_patch(border)

    # End labels, vertically centered on the strip.
    ax.text(legend_x - 0.01, legend_y + legend_height / 2, r'$T_{\rm max}$',
            transform=ax.transAxes, ha='right', va='center', fontsize=11, fontweight='normal', antialiased=True)
    ax.text(legend_x + legend_width + 0.01, legend_y + legend_height / 2, r'$T_0$',
            transform=ax.transAxes, ha='left', va='center', fontsize=11, fontweight='normal', antialiased=True)

    # NOTE: plt.legend() is intentionally not called here. No artist in this
    # figure has a label, so the call only produced matplotlib's
    # "No artists with labels found to put in legend" warning.
    plt.xlabel('Sampler Iteration')
    plt.ylabel('Acceptance Fraction')

    # Add 20% headroom so the curves do not run underneath the color strip.
    ymin, ymax = plt.ylim()
    plt.ylim(ymin, 1.2 * ymax)

    plt.title('Acceptance Fraction Over Time')

    save_or_show(fig, filename)
|
|
905
|
+
|
|
906
|
+
def plot_act_evolution(chain: dict,
                       iteration: int = 0,
                       parent_folder: str = '.'):

    """
    Plot the evolution of the integrated autocorrelation time (ACT) for each branch in the chain.

    The ACT is evaluated on growing prefixes of the cold chain (temperature
    index 0) at up to ``NPOINTS`` log-spaced checkpoints between
    ``min(100, iteration)`` and ``iteration``. For every branch a figure
    ``<parent_folder>/<branch>/act_evolution.png`` is produced, and a summary
    figure ``<parent_folder>/max_act_evolution.png`` compares the maximum ACT
    of every branch. Both include the reference line tau = Nsteps / 50.

    Branches are skipped (with a printed message) when their cold chain
    contains NaNs or when it is shorter than the first checkpoint. Nothing is
    plotted when ``iteration`` is smaller than 1, because the log-spaced
    checkpoints are undefined in that case.

    Args:
        chain (Dict): Dictionary of MCMC chains for different branches. Each value is
            unpacked as (nsteps, ntemps, nwalkers, nleaves, ndim).
        iteration (int, optional): Current iteration number; sets the checkpoint range
            and the x-axis limits. Default is 0.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
    Returns:
        None
    """

    NPOINTS = 10

    # log(iteration) is needed below; bail out instead of propagating -inf/NaN
    # checkpoints into the slicing and plotting code.
    if iteration < 1:
        return

    points = np.exp(np.linspace(np.log(min(100, iteration)), np.log(iteration), NPOINTS)).astype(int)

    # x values of the tau = Nsteps / 50 reference line (shared by all figures)
    xaxis = np.logspace(0, np.log10(iteration), 100)

    taus = {}
    for branch, samples in chain.items():
        # create branch folder
        branch_folder = os.path.join(parent_folder, branch)
        os.makedirs(branch_folder, exist_ok=True)

        nsteps, ntemps, nwalkers, nleaves, ndim = samples.shape
        # keep only temperature index 0 and merge the leaf and parameter axes
        cold_chain = samples[:, 0, :, :, :].reshape(nsteps, 1, nwalkers, nleaves * ndim)
        if np.isnan(cold_chain).any():
            print(f"Skipping ACT plot for branch {branch} due to NaNs in the cold chain.")
            continue

        # Checkpoints are sorted, so the usable ones form a prefix of `points`.
        # The chain can be shorter than `iteration` (e.g. after discarding samples).
        usable_points = points[points <= cold_chain.shape[0]]
        if usable_points.size == 0:
            print(f"Skipping ACT plot for branch {branch}: the chain is shorter than the first checkpoint.")
            continue

        # reshape(-1) (rather than squeeze) keeps a 1-D row even when there is a
        # single parameter, so the stacked result is always 2-D.
        branch_taus = np.array([
            np.asarray(get_integrated_act(cold_chain[:point], average=True)).reshape(-1)
            for point in usable_points
        ])
        taus[branch] = branch_taus

        fig = plt.figure(figsize=(10, 6))
        # NOTE(review): only the first `ndim` columns are drawn, i.e. the
        # parameters of the first leaf when nleaves > 1 — confirm this is intended.
        for d in range(ndim):
            # braces around the index so that two-digit subscripts render correctly
            plt.loglog(usable_points, branch_taus[:, d], marker='o', label=fr'$x_{{{d}}}$')

        # add tau = iteration / 50 line for reference
        plt.loglog(xaxis, xaxis / 50, linestyle='--', color='k', lw=2, label=r'$\tau$ = Nsteps / 50')
        plt.xlabel('Number of Steps')
        plt.ylabel('Integrated Autocorrelation Time')
        plt.title(f'Autocorrelation Time Evolution - Branch: {branch}')

        ymax = np.max(branch_taus) * 1.2
        ymin = np.min(branch_taus) * 0.5
        if ymin < ymax:
            plt.ylim(ymin, ymax)
        else:
            plt.ylim(0, ymax)
        plt.xlim(min(points) * 0.9, iteration * 1.1)
        plt.legend()
        save_or_show(fig, os.path.join(branch_folder, 'act_evolution.png'))

    # plot the maximum ACT across all parameters for each branch
    if len(taus) > 0:
        max_taus = {branch: np.max(branch_taus, axis=1) for branch, branch_taus in taus.items()}

        fig = plt.figure(figsize=(10, 6))
        for branch, max_tau in max_taus.items():
            plt.loglog(points[:len(max_tau)], max_tau, marker='o', label=f'Branch: {branch}')

        plt.loglog(xaxis, xaxis / 50, linestyle='--', color='k', lw=2, label=r'$\tau$ = Nsteps / 50')
        plt.xlabel('Number of Steps')
        plt.ylabel('Maximum Integrated Autocorrelation Time')
        plt.title('Maximum Autocorrelation Time Evolution Across Branches')
        ymax = max(np.max(max_tau) for max_tau in max_taus.values()) * 1.2
        ymin = min(np.min(max_tau) for max_tau in max_taus.values()) * 0.5
        if ymin < ymax:
            plt.ylim(ymin, ymax)
        else:
            plt.ylim(0, ymax)
        plt.xlim(min(points) * 0.9, iteration * 1.1)
        plt.legend()
        save_or_show(fig, os.path.join(parent_folder, 'max_act_evolution.png'))
|
|
987
|
+
|
|
988
|
+
|
|
989
|
+
def produce_base_plots(chain: dict,
                       logl: np.ndarray,
                       truths: dict = None,
                       overlay_covariance: dict = None,
                       labels: dict = None,
                       iteration: int = 0,
                       parent_folder: str = '.',
                       ):

    """
    Produce the standard set of diagnostic plots:

    * a corner plot of the cold chain for every branch (``<branch>/cornerplot.png``),
    * a trace plot of the cold chain for every branch (``<branch>/traceplot.png``),
    * the log-likelihood evolution of the cold chain (``loglikelihood.png``).

    Args:
        chain (Dict): Dictionary of MCMC chains for different branches. Each value is
            unpacked as (nsteps, ntemps, nwalkers, nleaves, ndim).
        logl (np.ndarray): Log-likelihood array of shape (nsteps, ntemperatures, nwalkers).
        truths (Dict, optional): Dictionary of true parameter values for different branches.
            The per-branch entry is handed to the corner plot as both ``means`` and ``truths``.
        overlay_covariance (Dict, optional): Dictionary of covariance matrices to overlay on corner plots.
        labels (Dict, optional): Dictionary of parameter labels for different branches.
        iteration (int, optional): Step number used in the corner-plot legend label. Default is 0.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
    """

    legend_label = 'Samples at step %d' % iteration

    # one shade of blue per branch, assigned in the order the branches appear
    shades = sns.color_palette('Blues', n_colors=len(chain))
    branch_colors = dict(zip(chain.keys(), shades))

    for branch, samples in chain.items():
        # per-branch options; a missing or empty dictionary yields None
        branch_labels = labels.get(branch, None) if labels else None
        branch_truths = truths.get(branch, None) if truths else None
        branch_cov = overlay_covariance.get(branch, None) if overlay_covariance else None
        overlay_label = None if branch_cov is None else 'Information Matrix Covariance'

        # every branch gets its own output folder
        branch_folder = os.path.join(parent_folder, branch)
        os.makedirs(branch_folder, exist_ok=True)

        _, _, _, _, ndim = samples.shape
        cold_samples = samples[:, 0, :, :, :]

        # flatten to (n_samples, ndim) and drop every row that contains a NaN
        flat_cold = cold_samples.reshape(-1, ndim)
        has_nan = np.isnan(flat_cold).any(axis=1)
        flat_cold = flat_cold[~has_nan]

        cornerplot(
            flat_cold,
            means=branch_truths,
            overlay_covariance=branch_cov,
            legend_label=legend_label,
            truths=branch_truths,
            overlay_label=overlay_label,
            labels=branch_labels,
            color=branch_colors[branch],
            filename=os.path.join(branch_folder, 'cornerplot.png')
        )

        traceplot(
            cold_samples,
            labels=branch_labels,
            truths=branch_truths,
            filename=os.path.join(branch_folder, 'traceplot.png')
        )

    # the log-likelihood is not branch specific, so it is plotted once
    plot_loglikelihood(
        logl[:, 0, :],
        filename=os.path.join(parent_folder, 'loglikelihood.png')
    )
|
|
1059
|
+
|
|
1060
|
+
def produce_tempering_plots(chain: dict,
                            betas: np.ndarray,
                            logl: np.ndarray,
                            swap_acceptance_fraction: np.ndarray,
                            labels: dict = None,
                            parent_folder: str = '.',
                            palette: str = None
                            ):
    """
    Produce the parallel-tempering diagnostic plots:

    * a ridge plot of the tempered distributions for every branch
      (``<branch>/tempering_ridgeplot.png``),
    * the swap acceptance fraction between adjacent temperatures (``swap_acceptance.png``),
    * the averaged log-likelihood vs. betas plot (``logl_betas.png``),
    * the evolution of betas over sampling steps (``betas_evolution.png``).

    Args:
        chain (Dict): Dictionary of MCMC chains for different branches.
        betas (np.ndarray): Inverse temperatures of shape (nsteps, ntemps).
        logl (np.ndarray): Log-likelihood array of shape (nsteps, ntemperatures, nwalkers).
        swap_acceptance_fraction (np.ndarray): Swap acceptance fraction between adjacent temperatures.
        labels (Dict, optional): Dictionary of parameter labels for different branches.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
        palette (str or list, optional): Seaborn color palette name or list of colors.
    """

    def _in_parent(name):
        # path of a figure stored directly in the parent folder
        return os.path.join(parent_folder, name)

    for branch, samples in chain.items():
        # a missing or empty labels dictionary yields None
        branch_labels = labels.get(branch, None) if labels else None

        # every branch gets its own output folder
        branch_folder = os.path.join(parent_folder, branch)
        os.makedirs(branch_folder, exist_ok=True)

        ridge_file = os.path.join(branch_folder, 'tempering_ridgeplot.png')
        tempering_ridgeplot(samples, labels=branch_labels, palette=palette, filename=ridge_file)

    # the remaining figures are not branch specific
    plot_swap_acceptance(swap_acceptance_fraction, palette=palette,
                         filename=_in_parent('swap_acceptance.png'))

    plot_logl_betas(betas, logl, palette=palette,
                    filename=_in_parent('logl_betas.png'))

    plot_betas_evolution(betas, palette=palette,
                         filename=_in_parent('betas_evolution.png'))
|
|
1118
|
+
|
|
1119
|
+
def produce_advanced_plots(steps: typing.Union[np.ndarray, list],
                           total_acceptance_fraction: np.ndarray,
                           moves_acceptance_fraction: dict,
                           palette: str = None,
                           iteration: int = 0,
                           chain: dict = None,
                           parent_folder: str = '.'):
    """
    Produce advanced diagnostic plots. These include:

    * the acceptance fraction evolution over steps in the cold chain (both overall and per move),
    * the overall acceptance fraction evolution over steps per temperature,
    * autocorrelation time evolution per parameter per branch in the cold chain,
    * the comparison of the maximum autocorrelation time in each branch against the number of steps.

    The two autocorrelation-time plots need the samples and are only produced
    when ``chain`` is provided.

    Args:
        steps (Union[np.ndarray, list]): Array or list of sampling steps.
        total_acceptance_fraction (np.ndarray): Total acceptance fraction array of shape (nsteps, ntemps, nwalkers).
        moves_acceptance_fraction (Dict): Dictionary of acceptance fractions for different moves.
        palette (str or list, optional): Seaborn color palette name or list of colors,
            used for the per-temperature acceptance plot.
        iteration (int, optional): Current iteration number, forwarded to the
            autocorrelation-time plot. Default is 0.
        chain (Dict, optional): Dictionary of MCMC chains for different branches.
            When None (the default) the autocorrelation-time plots are skipped.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
    """

    plot_acceptance_fraction(
        steps,
        total_acceptance_fraction,
        moves_acceptance_fraction,
        filename=os.path.join(parent_folder, 'acceptance_fraction.png')
    )

    plot_tempered_acceptance_fraction(
        steps,
        total_acceptance_fraction,
        palette=palette,
        filename=os.path.join(parent_folder, 'tempered_acceptance_fraction.png')
    )

    # `chain` defaults to None; plot_act_evolution iterates over chain.items(),
    # so calling it without samples would raise an AttributeError.
    if chain is not None:
        plot_act_evolution(
            chain,
            iteration=iteration,
            parent_folder=parent_folder
        )
|
|
1160
|
+
|
|
1161
|
+
def produce_rj_plots(nleaves: dict,
                     nleaves_min: dict,
                     nleaves_max: dict,
                     palette: str = None,
                     parent_folder: str = '.',
                     iteration: int = 0):

    """
    Produce the reversible-jump (RJ) diagnostic plots, one folder per branch.

    Currently this is a single figure per branch, ``<branch>/rj_nleaves.png``,
    showing the histogram of the number of leaves across temperatures. Branches
    whose maximum number of leaves does not exceed the minimum are skipped.

    Args:
        nleaves (Dict): Dictionary of number of leaves arrays for different branches.
        nleaves_min (Dict): Dictionary of minimum number of leaves for different branches.
            Branches without an entry default to 1.
        nleaves_max (Dict): Dictionary of maximum number of leaves for different branches.
            Branches without an entry default to 1.
        palette (str or list, optional): Seaborn color palette name or list of colors.
        parent_folder (str, optional): Folder to save the plots. Default is current directory.
        iteration (int, optional): Current iteration number for labeling plots.
    """

    for branch, leaves in nleaves.items():
        lower = nleaves_min.get(branch, 1)
        upper = nleaves_max.get(branch, 1)

        # a fixed number of leaves means there is no reversible-jump sampling to show
        if upper <= lower:
            continue

        # every plotted branch gets its own output folder
        branch_folder = os.path.join(parent_folder, branch)
        os.makedirs(branch_folder, exist_ok=True)

        histogram_file = os.path.join(branch_folder, 'rj_nleaves.png')
        plot_leaves(leaves, lower, upper,
                    palette=palette,
                    iteration=iteration,
                    filename=histogram_file)
|
|
1199
|
+
|
|
1200
|
+
# plot_leaves_evolution(
|
|
1201
|
+
# leaves[:, 0],
|
|
1202
|
+
# filename=os.path.join(branch_folder, f'rj_nleaves_evolution.png')
|
|
1203
|
+
# )
|
|
1204
|
+
|
|
1205
|
+
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
class PlotContainer:
|
|
1209
|
+
"""
|
|
1210
|
+
An Update that generates diagnostic plots at specified intervals
|
|
1211
|
+
|
|
1212
|
+
Args:
|
|
1213
|
+
plots (list or str): List of plot types to generate. Options are 'base', 'tempering', 'advanced', 'rj', or 'all'. If multiple plot types are desired, provide a list of strings.
|
|
1214
|
+
branches (list, optional): List of branch names to generate plots for. If None, all branches are used.
|
|
1215
|
+
truths (dict, optional): Dictionary of true parameter values for different branches.
|
|
1216
|
+
overlay_covariance (dict, optional): Dictionary of covariance matrices to overlay on corner plots.
|
|
1217
|
+
tempering_palette (str or list, optional): Seaborn color palette name or list of colors for tempering plots. If None, it defaults to `icefire`.
|
|
1218
|
+
parent_folder (str, optional): Folder to save the plots. Default is current directory.
|
|
1219
|
+
discard (float, optional): Number of initial samples to discard from the chain before plotting. If between 0 and 1, it is treated as fraction of total samples. Default is 0.
|
|
1220
|
+
stop (int, optional): Maximum number of steps to generate plots for. Default is 10000.
|
|
1221
|
+
"""
|
|
1222
|
+
|
|
1223
|
+
def __init__(self,
             backend: Backend = None,
             plots : typing.Union[list, str] = 'base',
             branches : list = None,
             truths: dict = None,
             overlay_covariance: dict = None,
             tempering_palette: str = None,
             parent_folder: str = '.',
             discard: float = 0,
             stop: int = int(1e4),
             ):
    """
    Store the plotting configuration and create the output folder.

    Args:
        backend (Backend, optional): Backend the plots are generated from.
        plots (list or str): Plot type(s) to generate: 'base', 'tempering',
            'advanced', 'rj', a list of those, or 'all' for every type.
        branches (list, optional): Branch names to generate plots for.
        truths (dict, optional): True parameter values per branch.
        overlay_covariance (dict, optional): Covariance matrices per branch.
        tempering_palette (str or list, optional): Palette for tempering plots.
        parent_folder (str, optional): Folder to save the plots; created if missing.
        discard (float, optional): Number (or, below 1, fraction) of initial samples to discard.
        stop (int, optional): Iteration count after which no more plots are produced.

    Raises:
        ValueError: If a requested plot type is not one of the allowable types.
    """

    allowable_plots = ['base', 'tempering', 'advanced', 'rj']

    self.backend = backend

    # the output folder is created up front, before the plot types are validated
    self.parent_folder = parent_folder
    os.makedirs(self.parent_folder, exist_ok=True)

    self.branches = branches

    # normalize a single name (or the 'all' shortcut) to a list of plot types
    if isinstance(plots, str):
        plots = allowable_plots if plots == 'all' else [plots]

    for requested in plots:
        if requested in allowable_plots:
            continue
        raise ValueError(f"Plot type '{requested}' not recognized. Allowable types: {allowable_plots}")
    self.plots = plots

    self.truths = truths
    self.overlay_covariance = overlay_covariance
    self.discard = discard
    self.tempering_palette = tempering_palette

    # bookkeeping containers, empty until filled elsewhere
    self.steps = []
    self.total_acceptance_fraction = None
    self.move_acceptance_fractions = {}

    self.stop = stop
|
|
1268
|
+
|
|
1269
|
+
@property
def backend(self):
    """Backend object stored on this container."""
    return self._backend

@backend.setter
def backend(self, value):
    """Store ``value`` as the backend."""
    self._backend = value
|
|
1275
|
+
|
|
1276
|
+
@property
def truths(self):
    """Dictionary of true parameter values stored on this container (may be None)."""
    return self._truths

@truths.setter
def truths(self, value):
    """Store ``value`` as the truths dictionary."""
    self._truths = value
|
|
1282
|
+
|
|
1283
|
+
@property
def overlay_covariance(self):
    """Dictionary of overlay covariance matrices stored on this container (may be None)."""
    return self._overlay_covariance

@overlay_covariance.setter
def overlay_covariance(self, value):
    """Store ``value`` as the overlay covariance dictionary."""
    self._overlay_covariance = value
|
|
1289
|
+
|
|
1290
|
+
def produce_plots(self, sampler=None) -> None:
    """
    Generate the diagnostic plots selected in ``self.plots``.

    For every requested plot type a sub-folder of ``self.parent_folder`` is
    created and handed to the matching ``produce_*_plots`` function. Nothing
    is done once ``self.backend.iteration`` exceeds ``self.stop``.

    Args:
        sampler: The sampler object. If not provided, only ``self.backend``
            is used; in that case not all plots may be available
            (per-move acceptance fractions need a ``moves`` attribute, and
            ``nleaves_min`` falls back to 0 for every reversible-jump branch).

    Returns:
        None
    """

    iteration = self.backend.iteration
    if iteration > self.stop:
        return

    labels = self.backend.key_order

    # ``discard`` >= 1 is an absolute number of steps; a value below 1 is
    # interpreted as a fraction of the iterations completed so far.
    discard = int(self.discard) if self.discard >= 1 else int(self.discard * iteration)
    chain = self.backend.get_chain(discard=discard)
    logl = self.backend.get_log_like(discard=discard)
    betas = self.backend.get_betas(discard=discard)

    if self.branches is not None:
        def _select_branches(data):
            # Only branch-keyed dicts can be restricted to ``self.branches``.
            # NOTE(review): log-likelihoods and betas appear to be plain
            # arrays in eryn backends (confirm); filtering those by branch
            # name would silently replace them with an empty dict.
            if not isinstance(data, dict):
                return data
            return {branch: data[branch] for branch in self.branches if branch in data}

        chain = _select_branches(chain)
        logl = _select_branches(logl)
        betas = _select_branches(betas)

    for plot in self.plots:
        base_folder = os.path.join(self.parent_folder, plot)
        os.makedirs(base_folder, exist_ok=True)

        if plot == 'base':
            produce_base_plots(
                chain=chain,
                logl=logl,
                truths=self.truths,
                overlay_covariance=self.overlay_covariance,
                iteration=iteration,
                labels=labels,
                parent_folder=base_folder
            )

        elif plot == 'tempering':
            swap_acceptance_fraction = self.backend.swaps_accepted / float(iteration * self.backend.nwalkers)
            produce_tempering_plots(
                chain=chain,
                betas=betas,
                logl=logl,
                swap_acceptance_fraction=swap_acceptance_fraction,
                labels=labels,
                parent_folder=base_folder,
                palette=self.tempering_palette
            )

        elif plot == 'advanced':
            self.steps.append(iteration)

            # Append the current acceptance fraction as a new leading-axis
            # entry, so one row is accumulated per call.
            current_fraction = (self.backend.accepted / float(iteration))[np.newaxis, ...]
            if self.total_acceptance_fraction is None:
                self.total_acceptance_fraction = current_fraction
            else:
                self.total_acceptance_fraction = np.vstack(
                    (self.total_acceptance_fraction, current_fraction)
                )

            # Prefer the sampler's moves; otherwise fall back to the backend
            # if it exposes them. The attribute name must be a string here:
            # passing the (still unassigned) local ``moves`` raised
            # UnboundLocalError.
            if sampler is not None:
                moves = sampler.moves
            else:
                moves = getattr(self.backend, 'moves', None)

            if moves is not None:
                for move in moves:
                    name = move.__class__.__name__
                    move_fraction = move.acceptance_fraction[np.newaxis, ...]
                    if name not in self.move_acceptance_fractions:
                        self.move_acceptance_fractions[name] = move_fraction
                    else:
                        self.move_acceptance_fractions[name] = np.vstack(
                            (self.move_acceptance_fractions[name], move_fraction)
                        )

            # The advanced plots use the full chain, without discarded steps.
            full_chain = self.backend.get_chain(discard=0) if discard > 0 else chain

            produce_advanced_plots(
                steps=self.steps,
                total_acceptance_fraction=self.total_acceptance_fraction,
                moves_acceptance_fraction=self.move_acceptance_fractions,
                palette=self.tempering_palette,
                iteration=iteration,
                chain=full_chain,
                parent_folder=base_folder
            )

        elif plot == 'rj':
            if self.backend.rj is False:
                continue

            nleaves = self.backend.get_nleaves(discard=discard)

            # Without a sampler the minimum leaf count is unknown; use 0 for
            # every reversible-jump branch.
            if sampler is not None:
                nleaves_min = sampler.nleaves_min
            else:
                nleaves_min = dict.fromkeys(self.backend.rj_branches, 0)
            nleaves_max = self.backend.nleaves_max

            produce_rj_plots(
                nleaves=nleaves,
                nleaves_min=nleaves_min,
                nleaves_max=nleaves_max,
                palette=self.tempering_palette,
                parent_folder=base_folder,
                iteration=iteration
            )
|
|
1391
|
+
|
|
1392
|
+
|
|
1393
|
+
|