eryn-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/backends/hdfbackend.py
@@ -0,0 +1,819 @@
# -*- coding: utf-8 -*-

from __future__ import division, print_function

__all__ = ["HDFBackend", "TempHDFBackend", "does_hdf5_support_longdouble"]

import os
import time
from tempfile import NamedTemporaryFile

import numpy as np

from .. import __version__
from .backend import Backend


try:
    import h5py
except ImportError:
    h5py = None


def does_hdf5_support_longdouble():
    if h5py is None:
        return False
    with NamedTemporaryFile(
        prefix="emcee-temporary-hdf5", suffix=".hdf5", delete=False
    ) as f:
        f.close()

        with h5py.File(f.name, "w") as hf:
            g = hf.create_group("group")
            g.create_dataset("data", data=np.ones(1, dtype=np.longdouble))
            if g["data"].dtype != np.longdouble:
                return False
        with h5py.File(f.name, "r") as hf:
            if hf["group"]["data"].dtype != np.longdouble:
                return False
    return True

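# Usage sketch (editor's illustration, not part of the released file): choose
# a storage dtype based on longdouble support before building a backend. The
# file name "chains.h5" is made up.
#
#     from eryn.backends.hdfbackend import HDFBackend, does_hdf5_support_longdouble
#
#     dtype = np.longdouble if does_hdf5_support_longdouble() else np.float64
#     backend = HDFBackend("chains.h5", dtype=dtype)
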
class HDFBackend(Backend):
    """A backend that stores the chain in an HDF5 file using h5py

    .. note:: You must install `h5py <http://www.h5py.org/>`_ to use this
        backend.

    Args:
        filename (str): The name of the HDF5 file where the chain will be
            saved.
        name (str, optional): The name of the group where the chain will
            be saved. (default: ``"mcmc"``)
        read_only (bool, optional): If ``True``, the backend will throw a
            ``RuntimeError`` if the file is opened with write access.
            (default: ``False``)
        dtype (dtype, optional): Dtype to use for data storage. If ``None``,
            ``np.float64`` is used. (default: ``None``)
        compression (str, optional): Compression type for the h5 file. See the
            `h5py documentation <https://docs.h5py.org/en/stable/high/dataset.html#filter-pipeline>`_
            for more information. (default: ``None``)
        compression_opts (int, optional): Compression level for the h5 file. See the
            `h5py documentation <https://docs.h5py.org/en/stable/high/dataset.html#filter-pipeline>`_
            for more information. (default: ``None``)
        store_missing_leaves (double, optional): Number to store for leaves that are not
            used in a specific step. (default: ``np.nan``)

    """

    def __init__(
        self,
        filename,
        name="mcmc",
        read_only=False,
        dtype=None,
        compression=None,
        compression_opts=None,
        store_missing_leaves=np.nan,
    ):
        if h5py is None:
            raise ImportError("you must install 'h5py' to use the HDFBackend")

        # store all necessary quantities
        self.filename = filename
        self.name = name
        self.read_only = read_only
        self.compression = compression
        self.compression_opts = compression_opts
        if dtype is None:
            self.dtype_set = False
            self.dtype = np.float64
        else:
            self.dtype_set = True
            self.dtype = dtype

        self.store_missing_leaves = store_missing_leaves

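    # Usage sketch (editor's illustration): a compressed, double-precision
    # backend; the file name and gzip level are arbitrary choices.
    #
    #     backend = HDFBackend(
    #         "chains.h5",
    #         name="mcmc",
    #         compression="gzip",
    #         compression_opts=4,
    #     )
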
    @property
    def initialized(self):
        """Check if backend file has been initialized properly."""
        if not os.path.exists(self.filename):
            return False
        try:
            with self.open() as f:
                return self.name in f
        except (OSError, IOError):
            return False

    def open(self, mode="r"):
        """Opens the h5 file in the proper mode.

        Args:
            mode (str, optional): Mode to open h5 file.

        Returns:
            H5 file object: Opened file.

        Raises:
            RuntimeError: If backend is opened for writing when it is read-only.

        """

        if self.read_only and mode != "r":
            raise RuntimeError(
                "The backend has been loaded in read-only "
                "mode. Set `read_only = False` to make "
                "changes."
            )

        # open the file
        file_opened = False

        try_num = 0
        max_tries = 100
        while not file_opened:
            try:
                f = h5py.File(self.filename, mode)
                file_opened = True

            except BlockingIOError:
                try_num += 1
                if try_num >= max_tries:
                    raise BlockingIOError("Max tries exceeded trying to open h5 file.")
                print("Failed to open h5 file. Trying again.")
                time.sleep(10.0)

        # get the data type and store it if it is not previously set
        if not self.dtype_set and self.name in f:
            # get the group from the file
            g = f[self.name]
            if "chain" in g:
                # get the model names in chain
                keys = list(g["chain"])

                # they all have the same dtype so use the first one
                try:
                    self.dtype = g["chain"][keys[0]].dtype

                    # we now have it
                    self.dtype_set = True
                # catch error if the chain has not been initialized yet
                except IndexError:
                    pass

        return f

    def reset(
        self,
        nwalkers,
        ndims,
        nleaves_max=1,
        ntemps=1,
        branch_names=None,
        nbranches=1,
        rj=False,
        moves=None,
        **info,
    ):
        """Clear the state of the chain and empty the backend

        Args:
            nwalkers (int): The size of the ensemble.
            ndims (int, list of ints, or dict): The number of dimensions for each branch. If
                ``dict``, keys should be the branch names and values the associated dimensionality.
            nleaves_max (int, list of ints, or dict, optional): Maximum allowable leaf count for each branch.
                It should have the same length as the number of branches.
                If ``dict``, keys should be the branch names and values the associated maximal leaf value.
                (default: ``1``)
            ntemps (int, optional): Number of rungs in the temperature ladder.
                (default: ``1``)
            branch_names (str or list of str, optional): Names of the branches used. If not given,
                branches will be named ``model_0``, ..., ``model_n`` for ``n`` branches.
                (default: ``None``)
            nbranches (int, optional): Number of branches. This is only used if ``branch_names is None``.
                (default: ``1``)
            rj (bool, optional): If True, reversible-jump techniques are used.
                (default: ``False``)
            moves (list, optional): List of all of the move classes input into the sampler.
                (default: ``None``)
            **info (dict, optional): Any other key-value pairs to be added
                as attributes to the backend. These are also added to the HDF5 file.

        """

        # open file in append mode
        with self.open("a") as f:
            # we are resetting so if self.name is in the file we need to delete it
            if self.name in f:
                del f[self.name]

            # turn things into lists/dicts if needed
            if branch_names is not None:
                if isinstance(branch_names, str):
                    branch_names = [branch_names]

                elif not isinstance(branch_names, list):
                    raise ValueError("branch_names must be a string or a list of strings.")

            else:
                branch_names = ["model_{}".format(i) for i in range(nbranches)]

            nbranches = len(branch_names)

            if isinstance(ndims, int):
                assert len(branch_names) == 1
                ndims = {branch_names[0]: ndims}

            elif isinstance(ndims, list) or isinstance(ndims, np.ndarray):
                assert len(branch_names) == len(ndims)
                ndims = {bn: nd for bn, nd in zip(branch_names, ndims)}

            elif isinstance(ndims, dict):
                assert len(list(ndims.keys())) == len(branch_names)
                for key in ndims:
                    if key not in branch_names:
                        raise ValueError(
                            f"{key} is in ndims but does not appear in branch_names: {branch_names}."
                        )
            else:
                raise ValueError("ndims must be an int, a list of ints, or a dict.")

            if isinstance(nleaves_max, int):
                assert len(branch_names) == 1
                nleaves_max = {branch_names[0]: nleaves_max}

            elif isinstance(nleaves_max, list) or isinstance(nleaves_max, np.ndarray):
                assert len(branch_names) == len(nleaves_max)
                nleaves_max = {bn: nl for bn, nl in zip(branch_names, nleaves_max)}

            elif isinstance(nleaves_max, dict):
                assert len(list(nleaves_max.keys())) == len(branch_names)
                for key in nleaves_max:
                    if key not in branch_names:
                        raise ValueError(
                            f"{key} is in nleaves_max but does not appear in branch_names: {branch_names}."
                        )
            else:
                raise ValueError("nleaves_max must be an int, a list of ints, or a dict.")

            # store all the info needed in memory and in the file

            g = f.create_group(self.name)

            g.attrs["version"] = __version__
            g.attrs["nbranches"] = len(branch_names)
            g.attrs["branch_names"] = branch_names
            g.attrs["ntemps"] = ntemps
            g.attrs["nwalkers"] = nwalkers
            g.attrs["has_blobs"] = False
            g.attrs["rj"] = rj
            g.attrs["iteration"] = 0

            # create info group
            g.create_group("info")
            # load info into class and into file
            for key, value in info.items():
                setattr(self, key, value)
                g["info"].attrs[key] = value

            # store nleaves max and ndims dicts
            g.create_group("ndims")
            for key, value in ndims.items():
                g["ndims"].attrs[key] = value

            g.create_group("nleaves_max")
            for key, value in nleaves_max.items():
                g["nleaves_max"].attrs[key] = value

            # prepare all the data sets

            g.create_dataset(
                "accepted",
                data=np.zeros((ntemps, nwalkers)),
                compression=self.compression,
                compression_opts=self.compression_opts,
            )

            g.create_dataset(
                "swaps_accepted",
                data=np.zeros((ntemps - 1,)),
                compression=self.compression,
                compression_opts=self.compression_opts,
            )

            # use the local rj flag; the self.rj property would reopen the
            # file, which is already open for writing here
            if rj:
                g.create_dataset(
                    "rj_accepted",
                    data=np.zeros((ntemps, nwalkers)),
                    compression=self.compression,
                    compression_opts=self.compression_opts,
                )

            g.create_dataset(
                "log_like",
                (0, ntemps, nwalkers),
                maxshape=(None, ntemps, nwalkers),
                dtype=self.dtype,
                compression=self.compression,
                compression_opts=self.compression_opts,
            )

            g.create_dataset(
                "log_prior",
                (0, ntemps, nwalkers),
                maxshape=(None, ntemps, nwalkers),
                dtype=self.dtype,
                compression=self.compression,
                compression_opts=self.compression_opts,
            )

            g.create_dataset(
                "betas",
                (0, ntemps),
                maxshape=(None, ntemps),
                dtype=self.dtype,
                compression=self.compression,
                compression_opts=self.compression_opts,
            )

            # setup data sets for branch-specific items

            chain = g.create_group("chain")
            inds = g.create_group("inds")

            for name in branch_names:
                # the local dicts built above avoid reopening the file
                nleaves = nleaves_max[name]
                ndim = ndims[name]
                chain.create_dataset(
                    name,
                    (0, ntemps, nwalkers, nleaves, ndim),
                    maxshape=(None, ntemps, nwalkers, nleaves, ndim),
                    dtype=self.dtype,
                    compression=self.compression,
                    compression_opts=self.compression_opts,
                )

                inds.create_dataset(
                    name,
                    (0, ntemps, nwalkers, nleaves),
                    maxshape=(None, ntemps, nwalkers, nleaves),
                    dtype=bool,
                    compression=self.compression,
                    compression_opts=self.compression_opts,
                )

            # store move specific information
            if moves is not None:
                move_group = g.create_group("moves")
                # setup info and keys
                for full_move_name in moves:

                    single_move = move_group.create_group(full_move_name)

                    # prepare information dictionary
                    single_move.create_dataset(
                        "acceptance_fraction",
                        (ntemps, nwalkers),
                        maxshape=(ntemps, nwalkers),
                        dtype=self.dtype,
                        compression=self.compression,
                        compression_opts=self.compression_opts,
                    )

            else:
                self.move_info = None

        self.blobs = None

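    # Usage sketch (editor's illustration): configuring a fresh file for a
    # two-branch reversible-jump run; branch names, shapes, and counts are
    # made up.
    #
    #     backend = HDFBackend("chains.h5")
    #     backend.reset(
    #         32,                                  # nwalkers
    #         {"gauss": 3, "sine": 2},             # ndims per branch
    #         nleaves_max={"gauss": 5, "sine": 1},
    #         ntemps=4,
    #         branch_names=["gauss", "sine"],
    #         rj=True,
    #     )
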
    @property
    def nwalkers(self):
        """Get nwalkers from h5 file."""
        with self.open() as f:
            return f[self.name].attrs["nwalkers"]

    @property
    def ntemps(self):
        """Get ntemps from h5 file."""
        with self.open() as f:
            return f[self.name].attrs["ntemps"]

    @property
    def rj(self):
        """Get rj from h5 file."""
        with self.open() as f:
            return f[self.name].attrs["rj"]

    @property
    def nleaves_max(self):
        """Get nleaves_max from h5 file."""
        with self.open() as f:
            return {
                key: f[self.name]["nleaves_max"].attrs[key]
                for key in f[self.name]["nleaves_max"].attrs
            }

    @property
    def ndims(self):
        """Get ndims from h5 file."""
        with self.open() as f:
            return {
                key: f[self.name]["ndims"].attrs[key]
                for key in f[self.name]["ndims"].attrs
            }

    @property
    def move_keys(self):
        """Get move_keys from h5 file."""
        with self.open() as f:
            return list(f[self.name]["moves"])

    @property
    def branch_names(self):
        """Get branch names from h5 file."""
        with self.open() as f:
            return f[self.name].attrs["branch_names"]

    @property
    def nbranches(self):
        """Get number of branches from h5 file."""
        with self.open() as f:
            return f[self.name].attrs["nbranches"]

    @property
    def reset_args(self):
        """Get reset_args from h5 file."""
        return [self.nwalkers, self.ndims]

    @property
    def reset_kwargs(self):
        """Get reset_kwargs from h5 file."""
        # move names are rebuilt from the stored "moves" group; this requires
        # that moves were passed to reset when the file was initialized
        return dict(
            nleaves_max=self.nleaves_max,
            ntemps=self.ntemps,
            branch_names=self.branch_names,
            rj=self.rj,
            moves=self.move_keys,
        )

    def has_blobs(self):
        """Returns ``True`` if the model includes blobs."""
        with self.open() as f:
            return f[self.name].attrs["has_blobs"]

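    # Usage sketch (editor's illustration): wipe the stored chain while
    # keeping the file's configuration, rebuilt from the properties above
    # (assumes moves were stored at the original reset).
    #
    #     backend.reset(*backend.reset_args, **backend.reset_kwargs)
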
    def get_value(self, name, thin=1, discard=0, slice_vals=None, temp_index=None, branch_names=None):
        """Returns a requested value to user.

        This function helps to streamline the backend for both
        basic and hdf backend.

        Args:
            name (str): Name of value requested.
            thin (int, optional): Take only every ``thin`` steps from the
                chain. (default: ``1``)
            discard (int, optional): Discard the first ``discard`` steps in
                the chain as burn-in. (default: ``0``)
            slice_vals (indexing np.ndarray or slice, optional): If provided, slice the array directly
                from the HDF5 file with slice = ``slice_vals``. ``thin`` and ``discard`` will be
                ignored if ``slice_vals`` is not ``None``. This is particularly useful if files are
                very large and the user only wants a small subset of the overall array.
                (default: ``None``)
            temp_index (int, optional): Integer for the desired temperature index.
                If ``None``, will return all temperatures. (default: ``None``)
            branch_names (str or list, optional): Specific branch names requested. (default: ``None``)

        Returns:
            dict or np.ndarray: Values requested.

        """
        # check if initialized
        if not self.initialized:
            raise AttributeError(
                "You must run the sampler with "
                "'store == True' before accessing the "
                "results. "
                "When using the HDF backend, make sure you have the file "
                "path correctly set. This is the error that "
                "is given if the backend cannot find the file."
            )

        if slice_vals is None:
            slice_vals = slice(discard + thin - 1, self.iteration, thin)

        # make sure branch_names input is a list
        if branch_names is not None:
            if isinstance(branch_names, str):
                branch_names = [branch_names]

        branch_names_in = self.branch_names if branch_names is None else branch_names

        # open the file wrapped in a "with" statement
        with self.open() as f:
            # get the group that everything is stored in
            g = f[self.name]
            iteration = g.attrs["iteration"]
            if iteration <= 0:
                raise AttributeError(
                    "You must run the sampler with "
                    "'store == True' before accessing the "
                    "results"
                )

            if name == "blobs" and not g.attrs["has_blobs"]:
                return None

            if temp_index is None:
                temp_index = np.arange(self.ntemps)
            else:
                assert isinstance(temp_index, int)

            if name == "chain":
                v_all = {key: g["chain"][key][slice_vals, temp_index] for key in branch_names_in}
                return v_all

            if name == "inds":
                v_all = {key: g["inds"][key][slice_vals, temp_index] for key in branch_names_in}

                return v_all

            v = g[name][slice_vals, temp_index]

            return v

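    # Usage sketch (editor's illustration): pull thinned, burned-in samples,
    # or slice the file directly; the parameter values are arbitrary.
    #
    #     coords = backend.get_value("chain", discard=100, thin=10)
    #     cold_logl = backend.get_value("log_like", temp_index=0)
    #     last_ten = backend.get_value("betas", slice_vals=slice(-10, None))
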
    def get_move_info(self):
        """Get move information.

        Returns:
            dict: Keys are move names and values are dictionaries with information on the moves.

        """
        # setup output dictionary
        move_info_out = {}
        with self.open() as f:
            g = f[self.name]

            # iterate through everything and produce a dictionary
            for move_name in g["moves"]:
                move_info_out[move_name] = {}
                for info_name in g["moves"][move_name]:
                    move_info_out[move_name][info_name] = g["moves"][move_name][
                        info_name
                    ][:]

        return move_info_out

    @property
    def shape(self):
        """The dimensions of the ensemble

        Returns:
            dict: Shape of samples.
                Keys are ``branch_names`` and values are tuples with
                shapes of individual branches: (ntemps, nwalkers, nleaves_max, ndim).

        """
        # open file wrapped in with
        with self.open() as f:
            g = f[self.name]
            return {
                key: (
                    g.attrs["ntemps"],
                    g.attrs["nwalkers"],
                    self.nleaves_max[key],
                    self.ndims[key],
                )
                for key in g.attrs["branch_names"]
            }

    @property
    def iteration(self):
        """Number of iterations stored in the hdf backend so far."""
        with self.open() as f:
            return f[self.name].attrs["iteration"]

    @property
    def accepted(self):
        """Number of accepted moves per walker."""
        with self.open() as f:
            return f[self.name]["accepted"][...]

    @property
    def rj_accepted(self):
        """Number of accepted rj moves per walker."""
        with self.open() as f:
            return f[self.name]["rj_accepted"][...]

    @property
    def swaps_accepted(self):
        """Number of accepted swaps."""
        with self.open() as f:
            return f[self.name]["swaps_accepted"][...]

    @property
    def random_state(self):
        """Get the random state"""
        with self.open() as f:
            elements = [
                v
                for k, v in sorted(f[self.name].attrs.items())
                if k.startswith("random_state_")
            ]
            return elements if len(elements) else None

    def grow(self, ngrow, blobs):
        """Expand the storage space by some number of samples

        Args:
            ngrow (int): The number of steps to grow the chain.
            blobs (None or np.ndarray): The current array of blobs. This is used to compute the
                dtype for the blobs array.

        """
        self._check_blobs(blobs)

        # open the file in append mode
        with self.open("a") as f:
            g = f[self.name]

            # resize all the arrays accordingly

            ntot = g.attrs["iteration"] + ngrow
            for key in g["chain"]:
                g["chain"][key].resize(ntot, axis=0)
                g["inds"][key].resize(ntot, axis=0)

            g["log_like"].resize(ntot, axis=0)
            g["log_prior"].resize(ntot, axis=0)
            g["betas"].resize(ntot, axis=0)

            # deal with blobs
            if blobs is not None:
                has_blobs = g.attrs["has_blobs"]
                # if blobs have not been added yet
                if not has_blobs:
                    nwalkers = g.attrs["nwalkers"]
                    ntemps = g.attrs["ntemps"]
                    g.create_dataset(
                        "blobs",
                        (ntot, ntemps, nwalkers, blobs.shape[-1]),
                        maxshape=(None, ntemps, nwalkers, blobs.shape[-1]),
                        dtype=self.dtype,
                        compression=self.compression,
                        compression_opts=self.compression_opts,
                    )
                else:
                    # resize the blobs if they have been there
                    g["blobs"].resize(ntot, axis=0)
                    if g["blobs"].shape[1:] != blobs.shape:
                        raise ValueError(
                            "Existing blobs have shape {} but new blobs "
                            "requested with shape {}".format(
                                g["blobs"].shape[1:], blobs.shape
                            )
                        )
                g.attrs["has_blobs"] = True

    def save_step(
        self,
        state,
        accepted,
        rj_accepted=None,
        swaps_accepted=None,
        moves_accepted_fraction=None,
    ):
        """Save a step to the backend

        Args:
            state (State): The :class:`State` of the ensemble.
            accepted (ndarray): An array of boolean flags indicating whether
                or not the proposal for each walker was accepted.
            rj_accepted (ndarray, optional): An array of the number of accepted steps
                for the reversible jump proposal for each walker.
                If :code:`self.rj` is True, then rj_accepted must be an array with
                :code:`rj_accepted.shape == accepted.shape`. If :code:`self.rj`
                is False, then rj_accepted must be None, which is the default.
            swaps_accepted (ndarray, optional): 1D array with number of swaps accepted
                for the in-model step. (default: ``None``)
            moves_accepted_fraction (dict, optional): Dict of acceptance fraction arrays for all of the
                moves in the sampler. This dict must have the same keys as ``self.move_keys``.
                (default: ``None``)

        """
        file_opened = False
        max_tries = 100
        try_num = 0
        while not file_opened:
            try:

                # open for appending in with statement
                with self.open("a") as f:
                    g = f[self.name]
                    # get the iteration left off on
                    iteration = g.attrs["iteration"]

                    # make sure the backend has all the information needed to store everything
                    for key in [
                        "rj",
                        "ntemps",
                        "nwalkers",
                        "nbranches",
                        "branch_names",
                        "ndims",
                    ]:
                        if not hasattr(self, key):
                            setattr(self, key, g.attrs[key])

                    # check the inputs are okay
                    self._check(
                        state,
                        accepted,
                        rj_accepted=rj_accepted,
                        swaps_accepted=swaps_accepted,
                    )

                    # branch-specific
                    for name, model in state.branches.items():
                        g["inds"][name][iteration] = model.inds
                        # use self.store_missing_leaves to set value for missing leaves
                        # state retains old coordinates
                        coords_in = model.coords * model.inds[:, :, :, None]
                        inds_all = np.repeat(
                            model.inds, coords_in.shape[-1], axis=-1
                        ).reshape(model.inds.shape + (coords_in.shape[-1],))
                        coords_in[~inds_all] = self.store_missing_leaves
                        # use the local iteration counter read above rather
                        # than the self.iteration property, which reopens the file
                        g["chain"][name][iteration] = coords_in

                    # store everything else in the file
                    g["log_like"][iteration, :] = state.log_like
                    g["log_prior"][iteration, :] = state.log_prior
                    if state.blobs is not None:
                        g["blobs"][iteration, :] = state.blobs
                    if state.betas is not None:
                        g["betas"][iteration, :] = state.betas
                    g["accepted"][:] += accepted
                    if swaps_accepted is not None:
                        g["swaps_accepted"][:] += swaps_accepted
                    if self.rj:
                        g["rj_accepted"][:] += rj_accepted

                    for i, v in enumerate(state.random_state):
                        g.attrs["random_state_{0}".format(i)] = v

                    g.attrs["iteration"] = iteration + 1

                    # moves
                    if moves_accepted_fraction is not None:
                        if "moves" not in g:
                            raise ValueError(
                                "moves_accepted_fraction was passed, but moves_info "
                                "was not initialized. Use the moves kwarg in the "
                                "reset function."
                            )

                        # update acceptance fractions
                        for move_key in self.move_keys:
                            g["moves"][move_key]["acceptance_fraction"][:] = (
                                moves_accepted_fraction[move_key]
                            )
                    file_opened = True

            except BlockingIOError:
                try_num += 1
                if try_num >= max_tries:
                    raise BlockingIOError("Max tries exceeded trying to open h5 file.")
                print("Failed to open h5 file. Trying again.")
                time.sleep(10.0)

class TempHDFBackend(object):
    """Backend stored in a temporary HDF5 file.

    Used in tests and for checking that HDF5 is working and available.
    """

    def __init__(self, dtype=None, compression=None, compression_opts=None):
        self.dtype = dtype
        self.filename = None
        self.compression = compression
        self.compression_opts = compression_opts

    def __enter__(self):
        f = NamedTemporaryFile(
            prefix="emcee-temporary-hdf5", suffix=".hdf5", delete=False
        )
        f.close()
        self.filename = f.name
        return HDFBackend(
            f.name,
            "test",
            dtype=self.dtype,
            compression=self.compression,
            compression_opts=self.compression_opts,
        )

    def __exit__(self, exception_type, exception_value, traceback):
        os.remove(self.filename)
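
# Usage sketch (editor's illustration): the context manager yields a
# ready-to-use backend in a throwaway file that is removed on exit.
#
#     with TempHDFBackend(compression="gzip") as backend:
#         backend.reset(16, 2, ntemps=2, branch_names="model_0")
#         print(backend.shape)  # {'model_0': (2, 16, 1, 2)}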