eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/moves/red_blue.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from abc import ABC
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
import numpy as np
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
from ..state import BranchSupplemental, State
|
|
8
|
+
from .move import Move
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Names exported by a star-import of this module.
__all__ = ["RedBlueMove"]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class RedBlueMove(Move, ABC):
    """
    An abstract red-blue ensemble move with parallelization as described in
    `Foreman-Mackey et al. (2013) <https://arxiv.org/abs/1202.3665>`_.

    This class is heavily based on the original from ``emcee``.

    Args:
        nsplits (int, optional): The number of sub-ensembles to use. Each
            sub-ensemble is updated in parallel using the other sets as the
            complementary ensemble. The default value is ``2`` and you
            probably won't need to change that.
        randomize_split (bool, optional): Randomly shuffle walkers between
            sub-ensembles. The same number of walkers will be assigned to
            each sub-ensemble on each iteration. (default: ``True``)
        live_dangerously (bool, optional): By default, an update will fail with
            a ``RuntimeError`` if the number of walkers is smaller than twice
            the dimension of the problem because the walkers would then be
            stuck on a low dimensional subspace. This can be avoided by
            switching between the stretch move and, for example, a
            Metropolis-Hastings step. If you want to do this and suppress the
            error, set ``live_dangerously = True``. Thanks goes (once again)
            to @dstndstn for this wonderful terminology. (default: ``False``)
        **kwargs (dict, optional): Kwargs for parent classes.

    """

    def __init__(
        self, nsplits=2, randomize_split=True, live_dangerously=False, **kwargs
    ):
        super().__init__(**kwargs)
        self.nsplits = int(nsplits)
        self.live_dangerously = live_dangerously
        self.randomize_split = randomize_split

    def setup(self, branches_coords):
        """Any setup for the proposal.

        The base implementation does nothing; subclasses may override it.

        Args:
            branches_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the
                current coordinates for all the walkers.

        """

    # NOTE(review): this is declared as a ``classmethod`` although its first
    # parameter is named ``self`` and ``propose`` calls it on the instance.
    # Possibly ``abstractmethod`` was intended; the decorator is left as-is
    # to keep the existing interface -- confirm.
    @classmethod
    def get_proposal(self, sample, complement, random, gibbs_ndim=None):
        """Make a proposal

        Args:
            sample (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim].
            complement (dict): Keys are ``branch_names``. Values are lists of
                other np.ndarray[ntemps, nwalkers - subset size, nleaves_max, ndim]
                from all other subsets. This is the complement whose ``coords``
                are used to form the proposal for the ``sample`` subset.
            random (object): Current random state of the sampler.
            gibbs_ndim (int or np.ndarray, optional): If Gibbs sampling, this
                indicates the true dimension. If given as an array, must have
                shape ``(ntemps, nwalkers)``.
                See the tutorial for more information.
                (default: ``None``)

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]
                of new coordinates. Second entry is the factors associated with
                the proposal necessary for detailed balance. This is
                effectively any term in the detailed balance fraction. +log of
                factors if in the numerator. -log of factors if in the
                denominator.

        Raises:
            NotImplementedError: Always, in this base class.

        """

        raise NotImplementedError("The proposal must be implemented by " "subclasses")

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: (state, accepted)
                The first return is the state of the sampler after the move.
                The second return value is the accepted count array.

        Raises:
            RuntimeError: If the number of walkers is smaller than twice the
                total number of dimensions and ``live_dangerously`` is not set.

        """
        # Check that the dimensions are compatible.
        ndim_total = 0
        for branch in state.branches.values():
            ntemps, nwalkers, nleaves_, ndim_ = branch.shape
            ndim_total += ndim_ * nleaves_

        if nwalkers < 2 * ndim_total and not self.live_dangerously:
            raise RuntimeError(
                "It is unadvisable to use a red-blue move "
                "with fewer walkers than twice the number of "
                "dimensions. If you would like to do this, please set "
                "live_dangerously to True."
            )

        # Run any move-specific setup.
        self.setup(state.branches)

        # Assign every walker to one of ``nsplits`` sub-ensembles, separately
        # for each temperature, and iterate over those sub-ensembles below.
        accepted = np.zeros((ntemps, nwalkers), dtype=bool)
        all_inds = np.tile(np.arange(nwalkers), (ntemps, 1))
        inds = all_inds % self.nsplits
        if self.randomize_split:
            # Shuffle each temperature row in place. A per-row shuffle keeps
            # the number of walkers per split equal across temperatures, which
            # the reshape to (ntemps, nwalkers_here) below relies on.
            # NOTE(review): this draws from the global ``np.random`` rather
            # than ``model.random``; kept to preserve the existing stream.
            for row in inds:
                np.random.shuffle(row)

        all_branch_names = list(state.branches.keys())

        ntemps, nwalkers, _, _ = state.branches[all_branch_names[0]].shape

        # get gibbs sampling information
        for branch_names_run, inds_run in self.gibbs_sampling_setup_iterator(
            all_branch_names
        ):
            # setup proposals based on Gibbs sampling
            (
                coords_going_for_proposal,
                inds_going_for_proposal,
                at_least_one_proposal,
            ) = self.setup_proposals(
                branch_names_run, inds_run, state.branches_coords, state.branches_inds
            )

            if not at_least_one_proposal:
                continue

            # prepare accepted fraction
            accepted_here = np.zeros((ntemps, nwalkers), dtype=bool)
            for split in range(self.nsplits):
                # get split information
                S1 = inds == split
                nwalkers_here = np.sum(S1[0])

                # walker indices belonging to this split, per temperature
                all_inds_shaped = all_inds[S1].reshape(ntemps, nwalkers_here)

                # inds including gibbs information
                new_inds = {
                    name: np.take_along_axis(
                        state.branches[name].inds,
                        all_inds_shaped[:, :, None],
                        axis=1,
                    )
                    for name in state.branches
                }

                # the actual inds for the subset
                real_inds_subset = {
                    name: new_inds[name] for name in inds_going_for_proposal
                }

                # actual coordinates of subset
                temp_coords = {
                    name: np.take_along_axis(
                        state.branches_coords[name],
                        all_inds_shaped[:, :, None, None],
                        axis=1,
                    )
                    for name in state.branches_coords
                }

                # prepare the sets for each model: for every branch being run,
                # one array per split gathered along the walker axis
                sets = {
                    key: [
                        np.take_along_axis(
                            state.branches[key].coords,
                            all_inds[inds == j].reshape(ntemps, -1)[:, :, None, None],
                            axis=1,
                        )
                        for j in range(self.nsplits)
                    ]
                    for key in branch_names_run
                }

                # setup s (this split) and c (all other splits)
                s = {key: sets[key][split] for key in sets}
                c = {key: sets[key][:split] + sets[key][split + 1 :] for key in sets}

                # need to trick stretch proposal into using the dimensionality
                # associated with Gibbs sampling if it is being used
                gibbs_ndim = 0
                for brn, ir in zip(branch_names_run, inds_run):
                    if ir is not None:
                        gibbs_ndim += ir.sum()
                    else:
                        # nleaves * ndim
                        gibbs_ndim += np.prod(state.branches[brn].shape[-2:])

                # Get the move-specific proposal.
                q, factors = self.get_proposal(
                    s, c, model.random, gibbs_ndim=gibbs_ndim
                )

                # account for gibbs sampling
                self.cleanup_proposals_gibbs(branch_names_run, inds_run, q, temp_coords)

                # setup supplemental information
                if state.supplemental is not None:
                    # TODO: should there be a copy?
                    # NOTE(review): base_shape uses the full ``nwalkers`` even
                    # though the subset holds ``nwalkers_here`` walkers --
                    # confirm this is intended.
                    new_supps = BranchSupplemental(
                        state.supplemental.take_along_axis(all_inds_shaped, axis=1),
                        base_shape=(ntemps, nwalkers),
                        copy=False,
                    )

                else:
                    new_supps = None

                # default for removing inds info from supp
                has_branch_supps = any(
                    supp is not None
                    for supp in state.branches_supplemental.values()
                )
                if has_branch_supps:
                    new_branch_supps_tmp = {
                        name: state.branches[name].branch_supplemental.take_along_axis(
                            all_inds_shaped[:, :, None], axis=1
                        )
                        for name in state.branches
                        if state.branches[name].branch_supplemental is not None
                    }

                    new_branch_supps = {
                        name: BranchSupplemental(
                            new_branch_supps_tmp[name],
                            base_shape=new_inds[name].shape,
                            copy=False,
                        )
                        for name in new_branch_supps_tmp
                    }

                else:
                    new_branch_supps = None

                # order everything properly
                q, new_inds, new_branch_supps = self.ensure_ordering(
                    list(state.branches.keys()), q, new_inds, new_branch_supps
                )

                # Compute prior of the proposed position
                # new_inds_prior is adjusted if product-space is used
                logp = model.compute_log_prior_fn(
                    q,
                    inds=new_inds,
                    supps=new_supps,
                    branch_supps=new_branch_supps,
                )

                self.fix_logp_gibbs(branch_names_run, inds_run, logp, real_inds_subset)

                # Compute the lnprobs of the proposed position.
                logl, new_blobs = model.compute_log_like_fn(
                    q,
                    inds=new_inds,
                    logp=logp,
                    supps=new_supps,
                    branch_supps=new_branch_supps,
                )

                # catch and warn about nans
                if np.any(np.isnan(logl)):
                    logl[np.isnan(logl)] = -1e300
                    warnings.warn("Getting Nan in likelihood computation.")

                logP = self.compute_log_posterior(logl, logp)

                prev_logl = np.take_along_axis(state.log_like, all_inds_shaped, axis=1)

                prev_logp = np.take_along_axis(state.log_prior, all_inds_shaped, axis=1)

                # takes care of tempering
                prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

                lnpdiff = factors + logP - prev_logP

                # accept where the log-ratio exceeds the log of a uniform draw
                keep = lnpdiff > np.log(model.random.rand(ntemps, nwalkers_here))

                # if gibbs sampling, this will say it is accepted if
                # any of the gibbs proposals were accepted
                np.put_along_axis(
                    accepted_here,
                    all_inds_shaped,
                    keep,
                    axis=1,
                )

                # readout for total
                accepted = accepted | accepted_here

                # new state
                new_state = State(
                    q,
                    log_like=logl,
                    log_prior=logp,
                    blobs=new_blobs,
                    inds=new_inds,
                    supplemental=new_supps,
                    branch_supplemental=new_branch_supps,
                )

                # update state
                state = self.update(
                    state, new_state, accepted_here, subset=all_inds_shaped
                )

        # add to move-specific accepted information
        self.accepted += accepted
        self.num_proposals += 1

        # temp swaps
        if self.temperature_control is not None:
            state = self.temperature_control.temper_comps(state)

        return state, accepted
|