eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/moves/red_blue.py ADDED
@@ -0,0 +1,333 @@
1
+ # -*- coding: utf-8 -*-
2
+ from abc import ABC
3
+ from copy import deepcopy
4
+ import numpy as np
5
+ import warnings
6
+
7
+ from ..state import BranchSupplemental, State
8
+ from .move import Move
9
+
10
+
11
# Public API of this module: only the abstract red-blue move is exported.
__all__ = ["RedBlueMove"]
12
+
13
+
14
class RedBlueMove(Move, ABC):
    """
    An abstract red-blue ensemble move with parallelization as described in
    `Foreman-Mackey et al. (2013) <https://arxiv.org/abs/1202.3665>`_.

    This class is heavily based on the original from ``emcee``.

    Subclasses implement :meth:`get_proposal`. :meth:`propose` partitions the
    walkers into ``nsplits`` sub-ensembles and updates each one in turn, using
    the remaining sub-ensembles as the complementary ensemble.

    Args:
        nsplits (int, optional): The number of sub-ensembles to use. Each
            sub-ensemble is updated in parallel using the other sets as the
            complementary ensemble. The default value is ``2`` and you
            probably won't need to change that.
        randomize_split (bool, optional): Randomly shuffle walkers between
            sub-ensembles. The same number of walkers will be assigned to
            each sub-ensemble on each iteration. (default: ``True``)
        live_dangerously (bool, optional): By default, an update will fail with
            a ``RuntimeError`` if the number of walkers is smaller than twice
            the dimension of the problem because the walkers would then be
            stuck on a low dimensional subspace. This can be avoided by
            switching between the stretch move and, for example, a
            Metropolis-Hastings step. If you want to do this and suppress the
            error, set ``live_dangerously = True``. Thanks goes (once again)
            to @dstndstn for this wonderful terminology. (default: ``False``)
        **kwargs (dict, optional): Kwargs for parent classes.

    """

    def __init__(
        self, nsplits=2, randomize_split=True, live_dangerously=False, **kwargs
    ):
        super().__init__(**kwargs)
        self.nsplits = int(nsplits)
        self.live_dangerously = live_dangerously
        self.randomize_split = randomize_split

    def setup(self, branches_coords):
        """Any setup for the proposal.

        Called once at the start of every :meth:`propose` call. The base
        implementation does nothing; subclasses may override it.

        Args:
            branches_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the
                current coordinates for all the walkers.

        NOTE(review): :meth:`propose` passes ``state.branches`` here, i.e.
        branch objects exposing ``.shape``/``.coords``/``.inds``, rather than
        raw coordinate arrays. Confirm what overriding subclasses expect.

        """

    @classmethod
    def get_proposal(cls, sample, complement, random, gibbs_ndim=None):
        """Make a proposal.

        Must be overridden by subclasses; the base version always raises.

        Args:
            sample (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim].
            complement (dict): Keys are ``branch_names``. Values are lists of
                other np.ndarray[ntemps, nwalkers - subset size, nleaves_max, ndim]
                from all other subsets. This is the complement whose ``coords``
                are used to form the proposal for the ``sample`` subset.
            random (object): Current random state of the sampler.
            gibbs_ndim (int or np.ndarray, optional): If Gibbs sampling, this
                indicates the true dimension. If given as an array, must have
                shape ``(ntemps, nwalkers)``. See the tutorial for more
                information. (default: ``None``)

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] of new
                coordinates. Second entry is the factors associated with the
                proposal necessary for detailed balance. This is effectively
                any term in the detailed balance fraction. +log of factors if
                in the numerator. -log of factors if in the denominator.

        Raises:
            NotImplementedError: Always, in this base class.

        """
        raise NotImplementedError("The proposal must be implemented by subclasses")

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance.

        Walkers are assigned to ``self.nsplits`` sub-ensembles. For each Gibbs
        step yielded by ``self.gibbs_sampling_setup_iterator``, every
        sub-ensemble in turn is proposed via :meth:`get_proposal` (with the
        other sub-ensembles as the complement) and accepted or rejected with a
        Metropolis-Hastings test. The result is written back into ``state``
        before the next sub-ensemble is processed.

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: (state, accepted)
                The first return is the state of the sampler after the move.
                The second return value is a boolean array of shape
                ``(ntemps, nwalkers)``. An entry is ``True`` when that walker
                was accepted in at least one Gibbs step.

        Raises:
            RuntimeError: If there are fewer than twice as many walkers as
                total dimensions and ``live_dangerously`` is ``False``.

        """
        all_branch_names = list(state.branches.keys())

        # All branches share the leading (ntemps, nwalkers) dimensions.
        ntemps, nwalkers, _, _ = state.branches[all_branch_names[0]].shape

        # Check that the dimensions are compatible.
        ndim_total = 0
        for branch in state.branches.values():
            _, _, nleaves_, ndim_ = branch.shape
            ndim_total += ndim_ * nleaves_

        if nwalkers < 2 * ndim_total and not self.live_dangerously:
            raise RuntimeError(
                "It is unadvisable to use a red-blue move "
                "with fewer walkers than twice the number of "
                "dimensions. If you would like to do this, please set "
                "live_dangerously to True."
            )

        # Run any move-specific setup.
        self.setup(state.branches)

        # Assign each walker to a sub-ensemble. Every temperature row holds
        # the same multiset of labels, so every split has the same size in
        # every row. Shuffle with the sampler's random state (not the global
        # numpy RNG) so that seeded runs are reproducible.
        all_inds = np.tile(np.arange(nwalkers), (ntemps, 1))
        inds = all_inds % self.nsplits
        if self.randomize_split:
            for row in inds:
                # ``row`` is a view, so this shuffles ``inds`` in place
                model.random.shuffle(row)

        # Walker indices of each sub-ensemble, each (ntemps, nwalkers_split).
        # The indices never change during this call; only the coordinates do.
        split_inds = [
            all_inds[inds == j].reshape(ntemps, -1) for j in range(self.nsplits)
        ]

        # Cumulative over Gibbs steps: True if any step accepted the walker.
        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        # get gibbs sampling information
        for branch_names_run, inds_run in self.gibbs_sampling_setup_iterator(
            all_branch_names
        ):
            # setup proposals based on Gibbs sampling
            (
                _coords_going_for_proposal,
                inds_going_for_proposal,
                at_least_one_proposal,
            ) = self.setup_proposals(
                branch_names_run, inds_run, state.branches_coords, state.branches_inds
            )

            if not at_least_one_proposal:
                continue

            # Trick the proposal (e.g. stretch) into using the dimensionality
            # associated with Gibbs sampling: the sum of the Gibbs index mask
            # when given (presumably the number of parameters being updated),
            # otherwise nleaves_max * ndim for the branch. This depends only on
            # the Gibbs setup and the branch shapes, so it is split-invariant.
            gibbs_ndim = 0
            for brn, ir in zip(branch_names_run, inds_run):
                if ir is not None:
                    gibbs_ndim += ir.sum()
                else:
                    # nleaves * ndim
                    gibbs_ndim += np.prod(state.branches[brn].shape[-2:])

            # acceptance for this Gibbs step, filled in split by split
            accepted_here = np.zeros((ntemps, nwalkers), dtype=bool)
            for split in range(self.nsplits):
                # (ntemps, nwalkers_here) walker indices of this sub-ensemble
                all_inds_shaped = split_inds[split]
                nwalkers_here = all_inds_shaped.shape[1]

                # inds including gibbs information
                new_inds = {
                    name: np.take_along_axis(
                        branch.inds,
                        all_inds_shaped[:, :, None],
                        axis=1,
                    )
                    for name, branch in state.branches.items()
                }

                # the actual inds for the subset
                real_inds_subset = {
                    name: new_inds[name] for name in inds_going_for_proposal
                }

                # actual coordinates of subset
                temp_coords = {
                    name: np.take_along_axis(
                        coords,
                        all_inds_shaped[:, :, None, None],
                        axis=1,
                    )
                    for name, coords in state.branches_coords.items()
                }

                # Coordinates of every sub-ensemble, read from the current
                # (already partially updated) state so that later splits see
                # the moves accepted for earlier ones.
                sets = {
                    key: [
                        np.take_along_axis(
                            state.branches[key].coords,
                            split_inds[j][:, :, None, None],
                            axis=1,
                        )
                        for j in range(self.nsplits)
                    ]
                    for key in branch_names_run
                }

                # the sample to move (s) and its complementary ensemble (c)
                s = {key: sets[key][split] for key in sets}
                c = {key: sets[key][:split] + sets[key][split + 1 :] for key in sets}

                # Get the move-specific proposal.
                q, factors = self.get_proposal(
                    s, c, model.random, gibbs_ndim=gibbs_ndim
                )

                # account for gibbs sampling
                self.cleanup_proposals_gibbs(branch_names_run, inds_run, q, temp_coords)

                # setup supplemental information
                if state.supplemental is not None:
                    # TODO: should there be a copy?
                    # NOTE(review): the taken supplemental covers
                    # (ntemps, nwalkers_here) walkers while base_shape is
                    # (ntemps, nwalkers); confirm BranchSupplemental does not
                    # rely on this matching.
                    new_supps = BranchSupplemental(
                        state.supplemental.take_along_axis(all_inds_shaped, axis=1),
                        base_shape=(ntemps, nwalkers),
                        copy=False,
                    )
                else:
                    new_supps = None

                # per-branch supplemental information for the subset, if any
                if any(
                    supp is not None for supp in state.branches_supplemental.values()
                ):
                    new_branch_supps = {
                        name: BranchSupplemental(
                            branch.branch_supplemental.take_along_axis(
                                all_inds_shaped[:, :, None], axis=1
                            ),
                            base_shape=new_inds[name].shape,
                            copy=False,
                        )
                        for name, branch in state.branches.items()
                        if branch.branch_supplemental is not None
                    }
                else:
                    new_branch_supps = None

                # order everything properly
                q, new_inds, new_branch_supps = self.ensure_ordering(
                    list(state.branches.keys()), q, new_inds, new_branch_supps
                )

                # Compute prior of the proposed position
                logp = model.compute_log_prior_fn(
                    q,
                    inds=new_inds,
                    supps=new_supps,
                    branch_supps=new_branch_supps,
                )

                self.fix_logp_gibbs(branch_names_run, inds_run, logp, real_inds_subset)

                # Compute the lnprobs of the proposed position.
                logl, new_blobs = model.compute_log_like_fn(
                    q,
                    inds=new_inds,
                    logp=logp,
                    supps=new_supps,
                    branch_supps=new_branch_supps,
                )

                # catch and warn about nans
                nan_mask = np.isnan(logl)
                if np.any(nan_mask):
                    logl[nan_mask] = -1e300
                    warnings.warn("Getting Nan in likelihood computation.")

                logP = self.compute_log_posterior(logl, logp)

                prev_logl = np.take_along_axis(state.log_like, all_inds_shaped, axis=1)
                prev_logp = np.take_along_axis(state.log_prior, all_inds_shaped, axis=1)

                # takes care of tempering
                prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

                # Metropolis-Hastings acceptance test in log space
                lnpdiff = factors + logP - prev_logP
                keep = lnpdiff > np.log(model.random.rand(ntemps, nwalkers_here))

                np.put_along_axis(accepted_here, all_inds_shaped, keep, axis=1)

                # a walker counts as accepted if any Gibbs step accepted it
                accepted = accepted | accepted_here

                # new state
                new_state = State(
                    q,
                    log_like=logl,
                    log_prior=logp,
                    blobs=new_blobs,
                    inds=new_inds,
                    supplemental=new_supps,
                    branch_supplemental=new_branch_supps,
                )

                # update state
                state = self.update(
                    state, new_state, accepted_here, subset=all_inds_shaped
                )

            # Move-level bookkeeping for this Gibbs step. Use this step's own
            # acceptances: ``accepted`` is cumulative over Gibbs steps, so
            # adding it here would re-count earlier steps. With a single Gibbs
            # step the two arrays are identical.
            self.accepted += accepted_here
            self.num_proposals += 1

        # temp swaps
        if self.temperature_control is not None:
            state = self.temperature_control.temper_comps(state)

        return state, accepted