eryn-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
eryn/moves/rj.py
ADDED
@@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-

from multiprocessing.sharedctypes import Value
import numpy as np
from copy import deepcopy

from ..state import State
from .move import Move
from .delayedrejection import DelayedRejection
from .distgen import DistributionGenerate

__all__ = ["ReversibleJumpMove"]


class ReversibleJumpMove(Move):
    """
    An abstract reversible jump move.

    Args:
        nleaves_max (dict): Maximum number(s) of leaves for each model.
            Keys are ``branch_names`` and values are ``nleaves_max`` for each branch.
            This is a keyword argument, but it is required.
        nleaves_min (dict): Minimum number(s) of leaves for each model.
            Keys are ``branch_names`` and values are ``nleaves_min`` for each branch.
            This is a keyword argument, but it is required.
        tune (bool, optional): If True, tune proposal. (Default: ``False``)
        fix_change (int or None, optional): Fix the change in the number of leaves. Make them all
            add a leaf or remove a leaf. This can be useful for some search functions. Options
            are ``+1`` or ``-1``. (default: ``None``)

    """

    def __init__(
        self,
        nleaves_max=None,
        nleaves_min=None,
        dr=None,
        dr_max_iter=5,
        tune=False,
        fix_change=None,
        **kwargs
    ):
        # super(ReversibleJumpMove, self).__init__(**kwargs)
        Move.__init__(self, is_rj=True, **kwargs)

        if nleaves_max is None or nleaves_min is None:
            raise ValueError(
                "Must provide nleaves_min and nleaves_max keyword arguments for RJ."
            )

        if not isinstance(nleaves_max, dict) or not isinstance(nleaves_min, dict):
            raise ValueError(
                "nleaves_min and nleaves_max must be provided as dictionaries with keys as branch names and values as the max or min leaf count."
            )
        # store info
        self.nleaves_max = nleaves_max
        self.nleaves_min = nleaves_min
        self.tune = tune
        self.dr = dr
        self.fix_change = fix_change
        if self.fix_change not in [None, +1, -1]:
            raise ValueError("fix_change must be None, +1, or -1.")

        # Decide if DR is desirable. TODO: Now it uses the prior generator, we need to
        # think carefully if we want to use the in-model sampling proposal
        if self.dr is not None and self.dr is not False:
            if self.dr is True:  # Check if it's a boolean, then we just generate
                # from prior (kills the purpose, but yields "healthier" chains)
                dr_proposal = DistributionGenerate(
                    self.generate_dist, temperature_control=self.temperature_control
                )
            else:
                # Otherwise pass given input
                dr_proposal = self.dr

            self.dr = DelayedRejection(dr_proposal, max_iter=dr_max_iter)

    def setup(self, branches_coords):
        """Any setup for the proposal.

        Args:
            branches_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers.

        """

    def get_proposal(
        self, all_coords, all_inds, nleaves_min_all, nleaves_max_all, random, **kwargs
    ):
        """Make a proposal

        Args:
            all_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers.
            all_inds (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max]. These are the boolean
                arrays marking which leaves are currently used within each walker.
            nleaves_min_all (dict): Minimum values of leaf count for each model. Must have same order as ``all_coords``.
            nleaves_max_all (dict): Maximum values of leaf count for each model. Must have same order as ``all_coords``.
            random (object): Current random state of the sampler.
            **kwargs (ignored): For modularity.

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                double np.ndarray[ntemps, nwalkers, nleaves_max, ndim] containing
                proposed coordinates. Second entry is the new ``inds`` array with
                boolean values flipped for added or removed sources. Third entry
                is the factors associated with the proposal necessary for detailed
                balance. This is effectively any term in the detailed balance fraction:
                +log of factors if in the numerator, -log of factors if in the denominator.

        Raises:
            NotImplementedError: If this proposal is not implemented by a subclass.

        """
        raise NotImplementedError("The proposal must be implemented by subclasses")

    def get_model_change_proposal(self, inds, random, nleaves_min, nleaves_max):
        """Helper function for changing the model count by 1.

        This helper function works with nested models where you want to add or remove
        one leaf at a time.

        Args:
            inds (np.ndarray): ``inds`` values for this specific branch with shape
                ``(ntemps, nwalkers, nleaves_max)``.
            random (object): Current random state of the sampler.
            nleaves_min (int): Minimum allowable leaf count for this branch.
            nleaves_max (int): Maximum allowable leaf count for this branch.

        Returns:
            dict: Keys are ``"+1"`` and ``"-1"``. Values are indexing information.
                ``"+1"`` and ``"-1"`` indicate if a source is being added or removed, respectively.
                The indexing information is a 2D array with shape ``(number changing, 3)``.
                The length 3 is the index into each of the ``(ntemps, nwalkers, nleaves_max)``.

        """
        raise NotImplementedError

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            :class:`State`: State of sampler after proposal is complete.

        """
        # this exposes this information anywhere in the proposal class

        # Run any move-specific setup.
        self.setup(state.branches)

        ntemps, nwalkers, _, _ = state.branches[list(state.branches.keys())[0]].shape

        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        all_branch_names = list(state.branches.keys())

        ntemps, nwalkers, _, _ = state.branches[all_branch_names[0]].shape

        for branch_names_run, inds_run in self.gibbs_sampling_setup_iterator(
            all_branch_names
        ):
            # gibbs sampling is only over branches so pick out that info
            coords_propose_in = {
                key: state.branches_coords[key] for key in branch_names_run
            }
            inds_propose_in = {
                key: state.branches_inds[key] for key in branch_names_run
            }
            branches_supp_propose_in = {
                key: state.branches_supplemental[key] for key in branch_names_run
            }

            if len(list(coords_propose_in.keys())) == 0:
                raise ValueError(
                    "Right now, no models are getting a reversible jump proposal. Check nleaves_min and nleaves_max or do not use rj proposal."
                )

            # get min and max leaf information
            nleaves_max_all = {brn: self.nleaves_max[brn] for brn in branch_names_run}
            nleaves_min_all = {brn: self.nleaves_min[brn] for brn in branch_names_run}

            self.current_model = model
            self.current_state = state
            # propose new sources and coordinates
            q, new_inds, factors = self.get_proposal(
                coords_propose_in,
                inds_propose_in,
                nleaves_min_all,
                nleaves_max_all,
                model.random,
                branch_supps=branches_supp_propose_in,
                supps=state.supplemental,
            )

            branches_supps_new = {
                key: item for key, item in branches_supp_propose_in.items()
            }
            # account for gibbs sampling
            self.cleanup_proposals_gibbs(
                branch_names_run, inds_run, q, state.branches_coords
            )

            # put back any branches that were left out from Gibbs split
            for name, branch in state.branches.items():
                if name not in q:
                    q[name] = state.branches[name].coords[:].copy()
                if name not in new_inds:
                    new_inds[name] = state.branches[name].inds[:].copy()

                if name not in branches_supps_new:
                    branches_supps_new[name] = state.branches_supplemental[name]

            # fix any ordering issues
            q, new_inds, branches_supps_new = self.ensure_ordering(
                list(state.branches.keys()), q, new_inds, branches_supps_new
            )

            edge_factors = np.zeros((ntemps, nwalkers))
            # get factors for edges
            for name, branch in state.branches.items():
                nleaves_max = self.nleaves_max[name]
                nleaves_min = self.nleaves_min[name]

                if name not in branch_names_run:
                    continue

                # get old and new values
                old_nleaves = branch.nleaves
                new_nleaves = new_inds[name].sum(axis=-1)

                # do not work on sources with fixed source count
                if nleaves_min == nleaves_max or nleaves_min + 1 == nleaves_max:
                    # nleaves_min == nleaves_max --> no rj proposal
                    # nleaves_min + 1 == nleaves_max --> no edge factors because it is guaranteed to be nleaves_min or nleaves_max
                    continue

                elif nleaves_min > nleaves_max:
                    raise ValueError("nleaves_min cannot be greater than nleaves_max.")

                else:
                    # fix proposal asymmetry at bottom of k range (kmin -> kmin + 1)
                    inds_min = np.where(old_nleaves == nleaves_min)
                    # numerator term so +ln
                    edge_factors[inds_min] += np.log(1 / 2.0)

                    # fix proposal asymmetry at top of k range (kmax -> kmax - 1)
                    inds_max = np.where(old_nleaves == nleaves_max)
                    # numerator term so +ln
                    edge_factors[inds_max] += np.log(1 / 2.0)

                    # fix proposal asymmetry at bottom of k range (kmin + 1 -> kmin)
                    inds_min = np.where(new_nleaves == nleaves_min)
                    # denominator term so -ln
                    edge_factors[inds_min] -= np.log(1 / 2.0)

                    # fix proposal asymmetry at top of k range (kmax - 1 -> kmax)
                    inds_max = np.where(new_nleaves == nleaves_max)
                    # denominator term so -ln
                    edge_factors[inds_max] -= np.log(1 / 2.0)

            factors += edge_factors

            # setup supplemental information
            if state.supplemental is not None:
                # TODO: should there be a copy?
                new_supps = deepcopy(state.supplemental)

            else:
                new_supps = None

            # for_transfer information can be taken directly from custom proposal

            # supp info
            if hasattr(self, "mt_supps"):
                # logp = self.lp_for_transfer.reshape(ntemps, nwalkers)
                new_supps = self.mt_supps

            if hasattr(self, "mt_branch_supps"):
                # logp = self.lp_for_transfer.reshape(ntemps, nwalkers)
                new_branch_supps = self.mt_branch_supps

            # logp and logl

            # Compute prior of the proposed position
            if hasattr(self, "mt_lp"):
                logp = self.mt_lp.reshape(ntemps, nwalkers)

            else:
                logp = model.compute_log_prior_fn(q, inds=new_inds)

            self.fix_logp_gibbs(branch_names_run, inds_run, logp, new_inds)

            if hasattr(self, "mt_ll"):
                logl = self.mt_ll.reshape(ntemps, nwalkers)

            else:
                # Compute the ln like of the proposed position.
                logl, new_blobs = model.compute_log_like_fn(
                    q,
                    inds=new_inds,
                    logp=logp,
                    supps=new_supps,
                    branch_supps=branches_supps_new,
                )

            # posterior and previous info
            logP = self.compute_log_posterior(logl, logp)

            prev_logl = state.log_like

            prev_logp = state.log_prior

            # takes care of tempering
            prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

            # acceptance fraction
            lnpdiff = factors + logP - prev_logP

            accepted = lnpdiff > np.log(model.random.rand(ntemps, nwalkers))

            # update with new state
            new_state = State(
                q,
                log_like=logl,
                log_prior=logp,
                blobs=None,
                inds=new_inds,
                supplemental=new_supps,
                branch_supplemental=branches_supps_new,
            )
            state = self.update(state, new_state, accepted)

            # apply delayed rejection to walkers that are +1
            # TODO: need to reexamine this a bit. I have a feeling that only applying
            # this to +1 may not be preserving detailed balance. You may need to
            # "simulate it" for -1 similar to what we do in multiple try
            if self.dr:
                raise NotImplementedError(
                    "Delayed Rejection will be implemented soon. Check for updated versions."
                )
                # for name, branch in state.branches.items():
                #     # We have to work with the binaries added only.
                #     # We need the a) rejected points, b) the model,
                #     # c) the current state, d) the indices where we had +1 (True),
                #     # and the e) factors.
                inds_for_change = {}
                for name in branch_names_run:
                    inds_for_change[name] = {
                        "+1": np.argwhere(new_inds[name] & (~state.branches[name].inds))
                    }

                state, accepted = self.dr.propose(
                    lnpdiff,
                    accepted,
                    model,
                    state,
                    new_state,
                    new_inds,
                    inds_for_change,
                    factors,
                )  # model, state

        # If RJ is true we temperature-control only on the in-model step, so no need to do it here as well.
        # In most cases, the RJ proposal has a small acceptance rate, so in the end we end up
        # switching back what was swapped in the previous in-model step.
        # TODO: MLK: I think we should allow for swapping but no adaptation.
        if self.temperature_control is not None and not self.prevent_swaps:
            state = self.temperature_control.temper_comps(state, adapt=False)

        # add to move-specific accepted information
        self.accepted += accepted
        self.num_proposals += 1

        return state, accepted
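
The detailed-balance bookkeeping in ``propose`` above is the part most easily misread in a diff, so here is a minimal, self-contained NumPy sketch of the same +/-log(1/2) edge correction. This is an illustration only, not code from the package: the temperature/walker counts, leaf-count bounds, and random seed are made up. The point is that when a walker sits at the minimum or maximum leaf count, only one of birth/death is available, so that direction is proposed with probability 1 instead of 1/2, and the missing or extra 1/2 must be restored in the proposal ratio.

    import numpy as np

    # toy stand-ins for the quantities used in ReversibleJumpMove.propose (illustrative only)
    ntemps, nwalkers = 2, 4
    nleaves_min, nleaves_max = 1, 4
    rng = np.random.default_rng(0)

    old_nleaves = rng.integers(nleaves_min, nleaves_max + 1, size=(ntemps, nwalkers))

    # a birth/death proposal changes each count by +1 or -1; at the edges the move is forced
    change = rng.choice([-1, +1], size=(ntemps, nwalkers))
    change[old_nleaves == nleaves_min] = +1  # forced birth at the bottom edge
    change[old_nleaves == nleaves_max] = -1  # forced death at the top edge
    new_nleaves = old_nleaves + change

    edge_factors = np.zeros((ntemps, nwalkers))

    # starting at an edge: the forward move was forced (prob 1, not 1/2), while the
    # reverse move still carries a 1/2 in the numerator -> +log(1/2)
    edge_factors[old_nleaves == nleaves_min] += np.log(1 / 2.0)
    edge_factors[old_nleaves == nleaves_max] += np.log(1 / 2.0)

    # landing on an edge: the reverse move is forced (prob 1), while the forward move
    # carried a 1/2 in the denominator -> -log(1/2)
    edge_factors[new_nleaves == nleaves_min] -= np.log(1 / 2.0)
    edge_factors[new_nleaves == nleaves_max] -= np.log(1 / 2.0)

    print(edge_factors)  # non-zero only for walkers that started at or landed on an edge

In the method above, the same corrections are accumulated per branch into ``edge_factors`` and added to ``factors`` before the acceptance test ``lnpdiff = factors + logP - prev_logP``. The abstract ``get_proposal`` is supplied by the concrete RJ moves elsewhere in this wheel (for example ``eryn/moves/distgenrj.py`` and ``eryn/moves/mtdistgenrj.py``).
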
eryn/moves/stretch.py
ADDED
@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
try:
    import cupy as cp
except (ModuleNotFoundError, ImportError):
    pass

import numpy as np

from .red_blue import RedBlueMove

__all__ = ["StretchMove"]


class StretchMove(RedBlueMove):
    """Affine-Invariant Proposal

    A `Goodman & Weare (2010)
    <https://msp.org/camcos/2010/5-1/p04.xhtml>`_ "stretch move" with
    parallelization as described in `Foreman-Mackey et al. (2013)
    <https://arxiv.org/abs/1202.3665>`_.

    This class was originally implemented in ``emcee``.

    Args:
        a (double, optional): The stretch scale parameter. (default: ``2.0``)
        return_gpu (bool, optional): If ``use_gpu == True and return_gpu == True``,
            the returned arrays will be returned as ``CuPy`` arrays. (default: ``False``)
        kwargs (dict, optional): Additional keyword arguments passed down through :class:`RedBlueMove`.

    Attributes:
        a (double): The stretch scale parameter.
        return_gpu (bool): Whether the array being returned is in ``CuPy`` (``True``)
            or ``NumPy`` (``False``).

    """

    def __init__(self, a=2.0, return_gpu=False, random_seed=None, **kwargs):
        # store scale factor
        self.a = a

        # pass kwargs up
        RedBlueMove.__init__(self, **kwargs)

        # change array library based on GPU usage

        # set the random seed of the library if desired
        if random_seed is not None:
            self.xp.random.seed(random_seed)

        self.return_gpu = return_gpu

        # how it was formerly
        # super(StretchMove, self).__init__(**kwargs)

    def adjust_factors(self, factors, ndims_old, ndims_new):
        """Adjust the ``factors`` based on changing dimensions.

        ``factors`` is adjusted in place.

        Args:
            factors (xp.ndarray): Array of ``factors`` values. It is adjusted in place.
            ndims_old (int or xp.ndarray): Old dimension. If given as an ``xp.ndarray``,
                must be broadcastable with ``factors``.
            ndims_new (int or xp.ndarray): New dimension. If given as an ``xp.ndarray``,
                must be broadcastable with ``factors``.

        """
        # adjusts in place
        if ndims_old == ndims_new:
            return
        logzz = factors / (ndims_old - 1.0)
        factors[:] = logzz * (ndims_new - 1.0)

    def choose_c_vals(self, c, Nc, Ns, ntemps, random_number_generator, **kwargs):
        """Get the complement array

        The complement represents the points that are used to move the actual points
        whose position is changing.

        Args:
            c (np.ndarray): Possible complement values with shape ``(ntemps, Nc, nleaves_max, ndim)``.
            Nc (int): Number of complement walkers available to draw from (usually nwalkers/2).
            Ns (int): Number of generation points.
            ntemps (int): Number of temperatures.
            random_number_generator (object): Random state object.
            **kwargs (ignored): Ignored here. For modularity.

        Returns:
            np.ndarray: Complement values to use with shape ``(ntemps, Ns, nleaves_max, ndim)``.

        """
        rint = random_number_generator.randint(
            Nc,
            size=(
                ntemps,
                Ns,
            ),
        )
        c_temp = self.xp.take_along_axis(c, rint[:, :, None, None], axis=1)
        return c_temp

    def get_new_points(
        self, name, s, c_temp, Ns, branch_shape, branch_i, random_number_generator
    ):
        """Get new points in stretch move.

        Takes the complement and uses it to get new points for those being proposed.

        Args:
            name (str): Branch name.
            s (np.ndarray): Points to be moved with shape ``(ntemps, Ns, nleaves_max, ndim)``.
            c_temp (np.ndarray): Complement used to move points with shape ``(ntemps, Ns, nleaves_max, ndim)``.
            Ns (int): Number to generate.
            branch_shape (tuple): Full branch shape.
            branch_i (int): Which branch in the order is being run now. This ensures that the
                randomly generated quantity per walker remains the same over branches.
            random_number_generator (object): Random state object.

        Returns:
            np.ndarray: New proposed points with shape ``(ntemps, Ns, nleaves_max, ndim)``.

        """
        ntemps, nwalkers, nleaves_max, ndim_here = branch_shape

        # only for the first branch do we draw for zz
        if branch_i == 0:
            self.zz = (
                (self.a - 1.0) * random_number_generator.rand(ntemps, Ns) + 1
            ) ** 2.0 / self.a

        # get proper distance
        if self.periodic is not None:
            diff = self.periodic.distance(
                {name: s.reshape(ntemps * nwalkers, nleaves_max, ndim_here)},
                {name: c_temp.reshape(ntemps * nwalkers, nleaves_max, ndim_here)},
                xp=self.xp,
            )[name].reshape(ntemps, nwalkers, nleaves_max, ndim_here)
        else:
            diff = c_temp - s

        temp = c_temp - (diff) * self.zz[:, :, None, None]

        # wrap periodic values
        if self.periodic is not None:
            temp = self.periodic.wrap(
                {name: temp.reshape(ntemps * nwalkers, nleaves_max, ndim_here)},
                xp=self.xp,
            )[name].reshape(ntemps, nwalkers, nleaves_max, ndim_here)

        # get from gpu or not
        if self.use_gpu and not self.return_gpu:
            temp = temp.get()
        return temp

    def get_proposal(self, s_all, c_all, random, gibbs_ndim=None, **kwargs):
        """Generate stretch proposal

        Args:
            s_all (dict): Keys are ``branch_names`` and values are coordinates
                for which a proposal is to be generated.
            c_all (dict): Keys are ``branch_names`` and values are lists. These
                lists contain all the complement array values.
            random (object): Random state object.
            gibbs_ndim (int or np.ndarray, optional): If Gibbs sampling, this indicates
                the true dimension. If given as an array, must have shape ``(ntemps, nwalkers)``.
                See the tutorial for more information.
                (default: ``None``)

        Returns:
            tuple: First entry is new positions. Second entry is detailed balance factors.

        Raises:
            ValueError: Issues with dimensionality.

        """
        # needs to be set before we reach the end
        self.zz = None
        random_number_generator = random if not self.use_gpu else self.xp.random
        newpos = {}

        # iterate over branches
        for i, name in enumerate(s_all):
            # get points to move
            s = self.xp.asarray(s_all[name])

            if not isinstance(c_all[name], list):
                raise ValueError("c_all for each branch needs to be a list.")

            # get complement possibilities
            c = [self.xp.asarray(c_tmp) for c_tmp in c_all[name]]

            ntemps, nwalkers, nleaves_max, ndim_here = s.shape
            c = self.xp.concatenate(c, axis=1)

            Ns, Nc = s.shape[1], c.shape[1]
            # dimensions contributed by this branch
            ndim_temp = nleaves_max * ndim_here

            # need to properly handle ndim
            if i == 0:
                ndim = ndim_temp
                Ns_check = Ns

            else:
                ndim += ndim_temp
                if Ns_check != Ns:
                    raise ValueError("Different number of walkers across models.")

            # get actual complement values
            c_temp = self.choose_c_vals(c, Nc, Ns, ntemps, random_number_generator)

            # use stretch to get new proposals
            newpos[name] = self.get_new_points(
                name, s, c_temp, Ns, s.shape, i, random_number_generator
            )

        # proper factors
        factors = (ndim - 1.0) * self.xp.log(self.zz)
        if self.use_gpu and not self.return_gpu:
            factors = factors.get()

        if gibbs_ndim is not None:
            # adjust factors in place
            self.adjust_factors(factors, ndim, gibbs_ndim)

        return newpos, factors
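
For reference, the core of the stretch update in ``get_new_points``/``get_proposal`` above reduces to a few array operations. The following is a minimal NumPy sketch for one temperature and one branch (an illustration with made-up sizes and seed, not code from the package): draw z via the ((a - 1) u + 1)^2 / a form used above, slide each walker toward a randomly chosen complement walker, and keep the (ndim - 1) log z term as the detailed-balance factor.

    import numpy as np

    rng = np.random.default_rng(0)
    a = 2.0                    # stretch scale parameter
    Ns, Nc, ndim = 5, 5, 3     # illustrative sizes: moving walkers, complement walkers, dimension

    s = rng.normal(size=(Ns, ndim))   # points being moved
    c = rng.normal(size=(Nc, ndim))   # complement set

    # z = ((a - 1) * u + 1)^2 / a, the same draw as in get_new_points
    zz = ((a - 1.0) * rng.random(Ns) + 1.0) ** 2 / a

    # pick one complement walker per moving walker (what choose_c_vals does)
    c_temp = c[rng.integers(Nc, size=Ns)]

    # stretch update: newpos = c - z * (c - s), i.e. z * s + (1 - z) * c
    newpos = c_temp - (c_temp - s) * zz[:, None]

    # detailed-balance factor entering the acceptance probability
    factors = (ndim - 1.0) * np.log(zz)
    print(newpos.shape, factors.shape)

The class above carries the same update through the full (ntemps, nwalkers, nleaves_max, ndim) bookkeeping, adds optional periodic wrapping and CuPy support through ``self.xp``, and rescales ``factors`` with ``adjust_factors`` when Gibbs sampling changes the effective dimension.
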