eryn-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,819 @@
+ # -*- coding: utf-8 -*-
+
+ from __future__ import division, print_function
+
+ __all__ = ["HDFBackend", "TempHDFBackend", "does_hdf5_support_longdouble"]
+
+ import os
+ import time
+ from tempfile import NamedTemporaryFile
+
+ import numpy as np
+
+ from .. import __version__
+ from .backend import Backend
+
+
+ try:
+     import h5py
+ except ImportError:
+     h5py = None
+
+
+ def does_hdf5_support_longdouble():
+     if h5py is None:
+         return False
+     with NamedTemporaryFile(
+         prefix="emcee-temporary-hdf5", suffix=".hdf5", delete=False
+     ) as f:
+         f.close()
+
+         with h5py.File(f.name, "w") as hf:
+             g = hf.create_group("group")
+             g.create_dataset("data", data=np.ones(1, dtype=np.longdouble))
+             if g["data"].dtype != np.longdouble:
+                 return False
+         with h5py.File(f.name, "r") as hf:
+             if hf["group"]["data"].dtype != np.longdouble:
+                 return False
+     return True
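+
+
+ # Illustrative usage (editor's sketch, not part of the original module); the
+ # filename is hypothetical:
+ #
+ #     dtype = np.longdouble if does_hdf5_support_longdouble() else np.float64
+ #     backend = HDFBackend("chains.h5", dtype=dtype)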
+
+
+ class HDFBackend(Backend):
+     """A backend that stores the chain in an HDF5 file using h5py.
+
+     .. note:: You must install `h5py <http://www.h5py.org/>`_ to use this
+         backend.
+
+     Args:
+         filename (str): The name of the HDF5 file where the chain will be
+             saved.
+         name (str, optional): The name of the group where the chain will
+             be saved. (default: ``"mcmc"``)
+         read_only (bool, optional): If ``True``, the backend will throw a
+             ``RuntimeError`` if the file is opened with write access.
+             (default: ``False``)
+         dtype (dtype, optional): Dtype to use for data storage. If ``None``,
+             ``np.float64`` is used. (default: ``None``)
+         compression (str, optional): Compression type for the h5 file. See the
+             `h5py documentation <https://docs.h5py.org/en/stable/high/dataset.html#filter-pipeline>`_
+             for more information. (default: ``None``)
+         compression_opts (int, optional): Compression level for the h5 file. See the
+             `h5py documentation <https://docs.h5py.org/en/stable/high/dataset.html#filter-pipeline>`_
+             for more information. (default: ``None``)
+         store_missing_leaves (float, optional): Value to store for leaves that are not
+             used in a specific step. (default: ``np.nan``)
+
+     """
+
+     def __init__(
+         self,
+         filename,
+         name="mcmc",
+         read_only=False,
+         dtype=None,
+         compression=None,
+         compression_opts=None,
+         store_missing_leaves=np.nan,
+     ):
+         if h5py is None:
+             raise ImportError("you must install 'h5py' to use the HDFBackend")
+
+         # store all necessary quantities
+         self.filename = filename
+         self.name = name
+         self.read_only = read_only
+         self.compression = compression
+         self.compression_opts = compression_opts
+         if dtype is None:
+             self.dtype_set = False
+             self.dtype = np.float64
+         else:
+             self.dtype_set = True
+             self.dtype = dtype
+
+         self.store_missing_leaves = store_missing_leaves
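+
+     # Illustrative usage (editor's sketch, not part of the original module);
+     # the filename and compression settings here are assumptions:
+     #
+     #     backend = HDFBackend("chains.h5", compression="gzip", compression_opts=4)
+     #     print(backend.initialized)  # False until reset() creates the group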
+
+     @property
+     def initialized(self):
+         """Check if the backend file has been initialized properly."""
+         if not os.path.exists(self.filename):
+             return False
+         try:
+             with self.open() as f:
+                 return self.name in f
+         except (OSError, IOError):
+             return False
+
+     def open(self, mode="r"):
+         """Open the h5 file in the proper mode.
+
+         Args:
+             mode (str, optional): Mode used to open the h5 file.
+
+         Returns:
+             H5 file object: Opened file.
+
+         Raises:
+             RuntimeError: If the backend is opened for writing when it is read-only.
+
+         """
+
+         if self.read_only and mode != "r":
+             raise RuntimeError(
+                 "The backend has been loaded in read-only "
+                 "mode. Set `read_only = False` to make "
+                 "changes."
+             )
+
+         # open the file, retrying if another process holds the lock
+         file_opened = False
+
+         try_num = 0
+         max_tries = 100
+         while not file_opened:
+             try:
+                 f = h5py.File(self.filename, mode)
+                 file_opened = True
+
+             except BlockingIOError:
+                 try_num += 1
+                 if try_num >= max_tries:
+                     raise BlockingIOError("Max tries exceeded trying to open h5 file.")
+                 print("Failed to open h5 file. Trying again.")
+                 time.sleep(10.0)
+
+         # get the data type and store it if it has not been set previously
+         if not self.dtype_set and self.name in f:
+             # get the group from the file
+             g = f[self.name]
+             if "chain" in g:
+                 # get the model names in the chain
+                 keys = list(g["chain"])
+
+                 # they all have the same dtype, so use the first one
+                 try:
+                     self.dtype = g["chain"][keys[0]].dtype
+
+                     # we now have it
+                     self.dtype_set = True
+                 # catch the error raised if the chain has not been initialized yet
+                 except IndexError:
+                     pass
+
+         return f
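+
+     # Illustrative usage (editor's sketch, not part of the original module);
+     # the filename is hypothetical:
+     #
+     #     ro = HDFBackend("chains.h5", read_only=True)
+     #     with ro.open() as f:  # default mode is "r", so this is allowed
+     #         print(list(f))
+     #     ro.open("a")  # raises RuntimeError because the backend is read-only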
+
+     def reset(
+         self,
+         nwalkers,
+         ndims,
+         nleaves_max=1,
+         ntemps=1,
+         branch_names=None,
+         nbranches=1,
+         rj=False,
+         moves=None,
+         **info,
+     ):
+         """Clear the state of the chain and empty the backend.
+
+         Args:
+             nwalkers (int): The size of the ensemble.
+             ndims (int, list of ints, or dict): The number of dimensions for each branch. If
+                 ``dict``, keys should be the branch names and values the associated dimensionality.
+             nleaves_max (int, list of ints, or dict, optional): Maximum allowable leaf count for each branch.
+                 It should have the same length as the number of branches.
+                 If ``dict``, keys should be the branch names and values the associated maximal leaf count.
+                 (default: ``1``)
+             ntemps (int, optional): Number of rungs in the temperature ladder.
+                 (default: ``1``)
+             branch_names (str or list of str, optional): Names of the branches used. If not given,
+                 branches will be named ``model_0``, ..., ``model_n`` for ``n`` branches.
+                 (default: ``None``)
+             nbranches (int, optional): Number of branches. This is only used if ``branch_names is None``.
+                 (default: ``1``)
+             rj (bool, optional): If ``True``, reversible-jump techniques are used.
+                 (default: ``False``)
+             moves (list, optional): List of all of the move classes input into the sampler.
+                 (default: ``None``)
+             **info (dict, optional): Any other key-value pairs to be added
+                 as attributes to the backend. These are also added to the HDF5 file.
+
+         """
+
+         # open the file in append mode
+         with self.open("a") as f:
+             # we are resetting, so if self.name is in the file we need to delete it
+             if self.name in f:
+                 del f[self.name]
+
+             # turn things into lists/dicts if needed
+             if branch_names is not None:
+                 if isinstance(branch_names, str):
+                     branch_names = [branch_names]
+
+                 elif not isinstance(branch_names, list):
+                     raise ValueError("branch_names must be a string or list of strings.")
+
+             else:
+                 branch_names = ["model_{}".format(i) for i in range(nbranches)]
+
+             nbranches = len(branch_names)
+
+             if isinstance(ndims, int):
+                 assert len(branch_names) == 1
+                 ndims = {branch_names[0]: ndims}
+
+             elif isinstance(ndims, (list, np.ndarray)):
+                 assert len(branch_names) == len(ndims)
+                 ndims = {bn: nd for bn, nd in zip(branch_names, ndims)}
+
+             elif isinstance(ndims, dict):
+                 assert len(list(ndims.keys())) == len(branch_names)
+                 for key in ndims:
+                     if key not in branch_names:
+                         raise ValueError(
+                             f"{key} is in ndims but does not appear in branch_names: {branch_names}."
+                         )
+             else:
+                 raise ValueError("ndims must be a scalar int, a list, or a dict.")
+
+             if isinstance(nleaves_max, int):
+                 assert len(branch_names) == 1
+                 nleaves_max = {branch_names[0]: nleaves_max}
+
+             elif isinstance(nleaves_max, (list, np.ndarray)):
+                 assert len(branch_names) == len(nleaves_max)
+                 nleaves_max = {bn: nl for bn, nl in zip(branch_names, nleaves_max)}
+
+             elif isinstance(nleaves_max, dict):
+                 assert len(list(nleaves_max.keys())) == len(branch_names)
+                 for key in nleaves_max:
+                     if key not in branch_names:
+                         raise ValueError(
+                             f"{key} is in nleaves_max but does not appear in branch_names: {branch_names}."
+                         )
+             else:
+                 raise ValueError("nleaves_max must be a scalar int, a list, or a dict.")
+
+             # store all the info needed in memory and in the file
+
+             g = f.create_group(self.name)
+
+             g.attrs["version"] = __version__
+             g.attrs["nbranches"] = len(branch_names)
+             g.attrs["branch_names"] = branch_names
+             g.attrs["ntemps"] = ntemps
+             g.attrs["nwalkers"] = nwalkers
+             g.attrs["has_blobs"] = False
+             g.attrs["rj"] = rj
+             g.attrs["iteration"] = 0
+
+             # create info group
+             g.create_group("info")
+             # load info into the class and into the file
+             for key, value in info.items():
+                 setattr(self, key, value)
+                 g["info"].attrs[key] = value
+
+             # store the nleaves_max and ndims dicts
+             g.create_group("ndims")
+             for key, value in ndims.items():
+                 g["ndims"].attrs[key] = value
+
+             g.create_group("nleaves_max")
+             for key, value in nleaves_max.items():
+                 g["nleaves_max"].attrs[key] = value
+
+             # prepare all the data sets
+
+             g.create_dataset(
+                 "accepted",
+                 data=np.zeros((ntemps, nwalkers)),
+                 compression=self.compression,
+                 compression_opts=self.compression_opts,
+             )
+
+             g.create_dataset(
+                 "swaps_accepted",
+                 data=np.zeros((ntemps - 1,)),
+                 compression=self.compression,
+                 compression_opts=self.compression_opts,
+             )
+
+             # use the local ``rj`` flag; the ``self.rj`` property would
+             # reopen the file that is already open in this context
+             if rj:
+                 g.create_dataset(
+                     "rj_accepted",
+                     data=np.zeros((ntemps, nwalkers)),
+                     compression=self.compression,
+                     compression_opts=self.compression_opts,
+                 )
+
+             g.create_dataset(
+                 "log_like",
+                 (0, ntemps, nwalkers),
+                 maxshape=(None, ntemps, nwalkers),
+                 dtype=self.dtype,
+                 compression=self.compression,
+                 compression_opts=self.compression_opts,
+             )
+
+             g.create_dataset(
+                 "log_prior",
+                 (0, ntemps, nwalkers),
+                 maxshape=(None, ntemps, nwalkers),
+                 dtype=self.dtype,
+                 compression=self.compression,
+                 compression_opts=self.compression_opts,
+             )
+
+             g.create_dataset(
+                 "betas",
+                 (0, ntemps),
+                 maxshape=(None, ntemps),
+                 dtype=self.dtype,
+                 compression=self.compression,
+                 compression_opts=self.compression_opts,
+             )
+
+             # set up data sets for branch-specific items
+
+             chain = g.create_group("chain")
+             inds = g.create_group("inds")
+
+             for name in branch_names:
+                 # use the local dicts rather than the properties for the same reason
+                 nleaves = nleaves_max[name]
+                 ndim = ndims[name]
+                 chain.create_dataset(
+                     name,
+                     (0, ntemps, nwalkers, nleaves, ndim),
+                     maxshape=(None, ntemps, nwalkers, nleaves, ndim),
+                     dtype=self.dtype,
+                     compression=self.compression,
+                     compression_opts=self.compression_opts,
+                 )
+
+                 inds.create_dataset(
+                     name,
+                     (0, ntemps, nwalkers, nleaves),
+                     maxshape=(None, ntemps, nwalkers, nleaves),
+                     dtype=bool,
+                     compression=self.compression,
+                     compression_opts=self.compression_opts,
+                 )
+
+             # store move-specific information
+             if moves is not None:
+                 move_group = g.create_group("moves")
+                 # set up info and keys
+                 for full_move_name in moves:
+                     single_move = move_group.create_group(full_move_name)
+
+                     # prepare information dictionary
+                     single_move.create_dataset(
+                         "acceptance_fraction",
+                         (ntemps, nwalkers),
+                         maxshape=(ntemps, nwalkers),
+                         dtype=self.dtype,
+                         compression=self.compression,
+                         compression_opts=self.compression_opts,
+                     )
+
+             else:
+                 self.move_info = None
+
+         self.blobs = None
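+
+     # Illustrative usage (editor's sketch, not part of the original module);
+     # the branch names and sizes here are assumptions:
+     #
+     #     backend.reset(
+     #         nwalkers=32,
+     #         ndims={"gauss": 3, "sine": 2},
+     #         nleaves_max={"gauss": 5, "sine": 1},
+     #         ntemps=4,
+     #         branch_names=["gauss", "sine"],
+     #         rj=True,
+     #     )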
+
+     @property
+     def nwalkers(self):
+         """Get nwalkers from the h5 file."""
+         with self.open() as f:
+             return f[self.name].attrs["nwalkers"]
+
+     @property
+     def ntemps(self):
+         """Get ntemps from the h5 file."""
+         with self.open() as f:
+             return f[self.name].attrs["ntemps"]
+
+     @property
+     def rj(self):
+         """Get rj from the h5 file."""
+         with self.open() as f:
+             return f[self.name].attrs["rj"]
+
+     @property
+     def nleaves_max(self):
+         """Get nleaves_max from the h5 file."""
+         with self.open() as f:
+             return {
+                 key: f[self.name]["nleaves_max"].attrs[key]
+                 for key in f[self.name]["nleaves_max"].attrs
+             }
+
+     @property
+     def ndims(self):
+         """Get ndims from the h5 file."""
+         with self.open() as f:
+             return {
+                 key: f[self.name]["ndims"].attrs[key]
+                 for key in f[self.name]["ndims"].attrs
+             }
+
+     @property
+     def move_keys(self):
+         """Get move_keys from the h5 file."""
+         with self.open() as f:
+             return list(f[self.name]["moves"])
+
+     @property
+     def branch_names(self):
+         """Get branch names from the h5 file."""
+         with self.open() as f:
+             return f[self.name].attrs["branch_names"]
+
+     @property
+     def nbranches(self):
+         """Get the number of branches from the h5 file."""
+         with self.open() as f:
+             return f[self.name].attrs["nbranches"]
+
+     @property
+     def reset_args(self):
+         """Get the positional arguments needed to call :func:`reset` again."""
+         return [self.nwalkers, self.ndims]
+
+     @property
+     def reset_kwargs(self):
+         """Get the keyword arguments needed to call :func:`reset` again."""
+         return dict(
+             nleaves_max=self.nleaves_max,
+             ntemps=self.ntemps,
+             branch_names=self.branch_names,
+             rj=self.rj,
+             # ``moves`` is not stored on the instance by default, so fall
+             # back to ``None`` if it was never set
+             moves=getattr(self, "moves", None),
+         )
+
+     def has_blobs(self):
+         """Returns ``True`` if the model includes blobs."""
+         with self.open() as f:
+             return f[self.name].attrs["has_blobs"]
+
+     def get_value(
+         self, name, thin=1, discard=0, slice_vals=None, temp_index=None, branch_names=None
+     ):
+         """Return a requested value to the user.
+
+         This function streamlines value access for both the
+         basic and HDF backends.
+
+         Args:
+             name (str): Name of the requested value.
+             thin (int, optional): Take only every ``thin`` steps from the
+                 chain. (default: ``1``)
+             discard (int, optional): Discard the first ``discard`` steps in
+                 the chain as burn-in. (default: ``0``)
+             slice_vals (indexing np.ndarray or slice, optional): If provided, slice the array directly
+                 from the HDF5 file with slice = ``slice_vals``. ``thin`` and ``discard`` will be
+                 ignored if ``slice_vals`` is not ``None``. This is particularly useful if files are
+                 very large and the user only wants a small subset of the overall array.
+                 (default: ``None``)
+             temp_index (int, optional): Integer for the desired temperature index.
+                 If ``None``, will return all temperatures. (default: ``None``)
+             branch_names (str or list, optional): Specific branch names requested. (default: ``None``)
+
+         Returns:
+             dict or np.ndarray: Values requested.
+
+         """
+         # check if initialized
+         if not self.initialized:
+             raise AttributeError(
+                 "You must run the sampler with 'store == True' before accessing "
+                 "the results. When using the HDF backend, make sure you have the "
+                 "file path correctly set. This is the error that is given if the "
+                 "backend cannot find the file."
+             )
+
+         if slice_vals is None:
+             slice_vals = slice(discard + thin - 1, self.iteration, thin)
+
+         # make sure the branch_names input is a list
+         if branch_names is not None:
+             if isinstance(branch_names, str):
+                 branch_names = [branch_names]
+
+         branch_names_in = self.branch_names if branch_names is None else branch_names
+
+         # open the file wrapped in a "with" statement
+         with self.open() as f:
+             # get the group that everything is stored in
+             g = f[self.name]
+             iteration = g.attrs["iteration"]
+             if iteration <= 0:
+                 raise AttributeError(
+                     "You must run the sampler with "
+                     "'store == True' before accessing the "
+                     "results"
+                 )
+
+             if name == "blobs" and not g.attrs["has_blobs"]:
+                 return None
+
+             if temp_index is None:
+                 temp_index = np.arange(self.ntemps)
+             else:
+                 assert isinstance(temp_index, int)
+
+             if name == "chain":
+                 v_all = {
+                     key: g["chain"][key][slice_vals, temp_index]
+                     for key in branch_names_in
+                 }
+                 return v_all
+
+             if name == "inds":
+                 v_all = {
+                     key: g["inds"][key][slice_vals, temp_index]
+                     for key in branch_names_in
+                 }
+                 return v_all
+
+             v = g[name][slice_vals, temp_index]
+
+             return v
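+
+     # Illustrative usage (editor's sketch, not part of the original module);
+     # the branch name "gauss" is an assumption:
+     #
+     #     # every 10th stored step, after a 100-step burn-in, coldest chain only
+     #     chain = backend.get_value("chain", thin=10, discard=100, temp_index=0)
+     #     samples = chain["gauss"]  # (nsteps, nwalkers, nleaves_max, ndim)
+     #
+     #     # or slice the file directly to avoid loading everything
+     #     last_10 = backend.get_value("log_like", slice_vals=slice(-10, None))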
+
+     def get_move_info(self):
+         """Get move information.
+
+         Returns:
+             dict: Keys are move names and values are dictionaries with information on the moves.
+
+         """
+         # set up the output dictionary
+         move_info_out = {}
+         with self.open() as f:
+             g = f[self.name]
+
+             # iterate through everything and produce a dictionary
+             for move_name in g["moves"]:
+                 move_info_out[move_name] = {}
+                 for info_name in g["moves"][move_name]:
+                     move_info_out[move_name][info_name] = g["moves"][move_name][
+                         info_name
+                     ][:]
+
+         return move_info_out
+
+     @property
+     def shape(self):
+         """The dimensions of the ensemble.
+
+         Returns:
+             dict: Shape of samples.
+                 Keys are ``branch_names`` and values are tuples with the
+                 shapes of individual branches: ``(ntemps, nwalkers, nleaves_max, ndim)``.
+
+         """
+         # open the file wrapped in a "with" statement
+         with self.open() as f:
+             g = f[self.name]
+             return {
+                 key: (
+                     g.attrs["ntemps"],
+                     g.attrs["nwalkers"],
+                     self.nleaves_max[key],
+                     self.ndims[key],
+                 )
+                 for key in g.attrs["branch_names"]
+             }
+
+     @property
+     def iteration(self):
+         """Number of iterations stored in the hdf backend so far."""
+         with self.open() as f:
+             return f[self.name].attrs["iteration"]
+
+     @property
+     def accepted(self):
+         """Number of accepted moves per walker."""
+         with self.open() as f:
+             return f[self.name]["accepted"][...]
+
+     @property
+     def rj_accepted(self):
+         """Number of accepted rj moves per walker."""
+         with self.open() as f:
+             return f[self.name]["rj_accepted"][...]
+
+     @property
+     def swaps_accepted(self):
+         """Number of accepted swaps."""
+         with self.open() as f:
+             return f[self.name]["swaps_accepted"][...]
+
+     @property
+     def random_state(self):
+         """Get the random state."""
+         with self.open() as f:
+             elements = [
+                 v
+                 for k, v in sorted(f[self.name].attrs.items())
+                 if k.startswith("random_state_")
+             ]
+         return elements if len(elements) else None
+
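+     # Illustrative usage (editor's sketch, not part of the original module):
+     #
+     #     # per-walker acceptance rate over the stored iterations
+     #     rate = backend.accepted / backend.iteration  # shape: (ntemps, nwalkers)
+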
+     def grow(self, ngrow, blobs):
+         """Expand the storage space by some number of samples.
+
+         Args:
+             ngrow (int): The number of steps to grow the chain.
+             blobs (None or np.ndarray): The current array of blobs. This is used to compute the
+                 dtype for the blobs array.
+
+         """
+         self._check_blobs(blobs)
+
+         # open the file in append mode
+         with self.open("a") as f:
+             g = f[self.name]
+
+             # resize all the arrays accordingly
+
+             ntot = g.attrs["iteration"] + ngrow
+             for key in g["chain"]:
+                 g["chain"][key].resize(ntot, axis=0)
+                 g["inds"][key].resize(ntot, axis=0)
+
+             g["log_like"].resize(ntot, axis=0)
+             g["log_prior"].resize(ntot, axis=0)
+             g["betas"].resize(ntot, axis=0)
+
+             # deal with blobs
+             if blobs is not None:
+                 has_blobs = g.attrs["has_blobs"]
+                 # if blobs have not been added yet
+                 if not has_blobs:
+                     nwalkers = g.attrs["nwalkers"]
+                     ntemps = g.attrs["ntemps"]
+                     g.create_dataset(
+                         "blobs",
+                         (ntot, ntemps, nwalkers, blobs.shape[-1]),
+                         maxshape=(None, ntemps, nwalkers, blobs.shape[-1]),
+                         dtype=self.dtype,
+                         compression=self.compression,
+                         compression_opts=self.compression_opts,
+                     )
+                 else:
+                     # resize the blobs if they are already there
+                     g["blobs"].resize(ntot, axis=0)
+                     if g["blobs"].shape[1:] != blobs.shape:
+                         raise ValueError(
+                             "Existing blobs have shape {} but new blobs "
+                             "requested with shape {}".format(
+                                 g["blobs"].shape[1:], blobs.shape
+                             )
+                         )
+                 g.attrs["has_blobs"] = True
+
+     def save_step(
+         self,
+         state,
+         accepted,
+         rj_accepted=None,
+         swaps_accepted=None,
+         moves_accepted_fraction=None,
+     ):
+         """Save a step to the backend.
+
+         Args:
+             state (State): The :class:`State` of the ensemble.
+             accepted (ndarray): An array of boolean flags indicating whether
+                 or not the proposal for each walker was accepted.
+             rj_accepted (ndarray, optional): An array of the number of accepted steps
+                 for the reversible-jump proposal for each walker.
+                 If :code:`self.rj` is True, then rj_accepted must be an array with
+                 :code:`rj_accepted.shape == accepted.shape`. If :code:`self.rj`
+                 is False, then rj_accepted must be None, which is the default.
+             swaps_accepted (ndarray, optional): 1D array with the number of swaps accepted
+                 for the in-model step. (default: ``None``)
+             moves_accepted_fraction (dict, optional): Dict of acceptance-fraction arrays for all of the
+                 moves in the sampler. This dict must have the same keys as ``self.move_keys``.
+                 (default: ``None``)
+
+         """
+         file_opened = False
+         max_tries = 100
+         try_num = 0
+         while not file_opened:
+             try:
+                 # open for appending in a with statement
+                 with self.open("a") as f:
+                     g = f[self.name]
+                     # get the iteration left off on
+                     iteration = g.attrs["iteration"]
+
+                     # make sure the backend has all the information needed to store everything
+                     for key in [
+                         "rj",
+                         "ntemps",
+                         "nwalkers",
+                         "nbranches",
+                         "branch_names",
+                         "ndims",
+                     ]:
+                         if not hasattr(self, key):
+                             setattr(self, key, g.attrs[key])
+
+                     # check the inputs are okay
+                     self._check(
+                         state,
+                         accepted,
+                         rj_accepted=rj_accepted,
+                         swaps_accepted=swaps_accepted,
+                     )
+
+                     # branch-specific
+                     for name, model in state.branches.items():
+                         g["inds"][name][iteration] = model.inds
+                         # use self.store_missing_leaves to set the value for missing leaves
+                         # (the state retains the old coordinates)
+                         coords_in = model.coords * model.inds[:, :, :, None]
+                         inds_all = np.repeat(
+                             model.inds, coords_in.shape[-1], axis=-1
+                         ).reshape(model.inds.shape + (coords_in.shape[-1],))
+                         coords_in[~inds_all] = self.store_missing_leaves
+                         # use the local ``iteration``; the ``self.iteration``
+                         # property would reopen the file already open here
+                         g["chain"][name][iteration] = coords_in
+
+                     # store everything else in the file
+                     g["log_like"][iteration, :] = state.log_like
+                     g["log_prior"][iteration, :] = state.log_prior
+                     if state.blobs is not None:
+                         g["blobs"][iteration, :] = state.blobs
+                     if state.betas is not None:
+                         g["betas"][iteration, :] = state.betas
+                     g["accepted"][:] += accepted
+                     if swaps_accepted is not None:
+                         g["swaps_accepted"][:] += swaps_accepted
+                     if self.rj:
+                         g["rj_accepted"][:] += rj_accepted
+
+                     for i, v in enumerate(state.random_state):
+                         g.attrs["random_state_{0}".format(i)] = v
+
+                     g.attrs["iteration"] = iteration + 1
+
+                     # moves
+                     if moves_accepted_fraction is not None:
+                         if "moves" not in g:
+                             raise ValueError(
+                                 "moves_accepted_fraction was passed, but move "
+                                 "information was not initialized. Use the moves "
+                                 "kwarg in the reset function."
+                             )
+
+                         # update acceptance fractions; iterate over the move
+                         # groups stored in the file (equivalent to self.move_keys)
+                         for move_key in list(g["moves"]):
+                             g["moves"][move_key]["acceptance_fraction"][:] = (
+                                 moves_accepted_fraction[move_key]
+                             )
+                 file_opened = True
+
+             except BlockingIOError:
+                 try_num += 1
+                 if try_num >= max_tries:
+                     raise BlockingIOError("Max tries exceeded trying to open h5 file.")
+                 print("Failed to open h5 file. Trying again.")
+                 time.sleep(10.0)
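+
+     # Illustrative sketch (editor's addition, not part of the original module):
+     # the leaf-masking step in ``save_step`` broadcasts the boolean ``inds``
+     # array across the parameter dimension. With toy shapes:
+     #
+     #     ntemps, nwalkers, nleaves, ndim = 2, 4, 3, 5
+     #     inds = np.random.rand(ntemps, nwalkers, nleaves) > 0.5
+     #     coords = np.random.rand(ntemps, nwalkers, nleaves, ndim)
+     #     mask = np.repeat(inds, ndim, axis=-1).reshape(inds.shape + (ndim,))
+     #     coords[~mask] = np.nan  # unused leaves stored as store_missing_leaves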
+
+
+ class TempHDFBackend(object):
+     """Context manager providing an :class:`HDFBackend` backed by a temporary file.
+
+     This is useful for checking that HDF5 is working and available.
+     """
+
+     def __init__(self, dtype=None, compression=None, compression_opts=None):
+         self.dtype = dtype
+         self.filename = None
+         self.compression = compression
+         self.compression_opts = compression_opts
+
+     def __enter__(self):
+         f = NamedTemporaryFile(
+             prefix="emcee-temporary-hdf5", suffix=".hdf5", delete=False
+         )
+         f.close()
+         self.filename = f.name
+         return HDFBackend(
+             f.name,
+             "test",
+             dtype=self.dtype,
+             compression=self.compression,
+             compression_opts=self.compression_opts,
+         )
+
+     def __exit__(self, exception_type, exception_value, traceback):
+         os.remove(self.filename)
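+
+
+ # Illustrative usage (editor's sketch, not part of the original module):
+ #
+ #     with TempHDFBackend() as backend:
+ #         backend.reset(16, 3, ntemps=2)
+ #         print(backend.shape)
+ #     # the temporary file is removed on exit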