eryn 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eryn/CMakeLists.txt +51 -0
- eryn/__init__.py +35 -0
- eryn/backends/__init__.py +20 -0
- eryn/backends/backend.py +1150 -0
- eryn/backends/hdfbackend.py +819 -0
- eryn/ensemble.py +1690 -0
- eryn/git_version.py.in +7 -0
- eryn/model.py +18 -0
- eryn/moves/__init__.py +42 -0
- eryn/moves/combine.py +135 -0
- eryn/moves/delayedrejection.py +229 -0
- eryn/moves/distgen.py +104 -0
- eryn/moves/distgenrj.py +222 -0
- eryn/moves/gaussian.py +190 -0
- eryn/moves/group.py +281 -0
- eryn/moves/groupstretch.py +120 -0
- eryn/moves/mh.py +193 -0
- eryn/moves/move.py +703 -0
- eryn/moves/mtdistgen.py +137 -0
- eryn/moves/mtdistgenrj.py +190 -0
- eryn/moves/multipletry.py +776 -0
- eryn/moves/red_blue.py +333 -0
- eryn/moves/rj.py +388 -0
- eryn/moves/stretch.py +231 -0
- eryn/moves/tempering.py +649 -0
- eryn/pbar.py +56 -0
- eryn/prior.py +452 -0
- eryn/state.py +775 -0
- eryn/tests/__init__.py +0 -0
- eryn/tests/test_eryn.py +1246 -0
- eryn/utils/__init__.py +10 -0
- eryn/utils/periodic.py +134 -0
- eryn/utils/stopping.py +164 -0
- eryn/utils/transform.py +226 -0
- eryn/utils/updates.py +69 -0
- eryn/utils/utility.py +329 -0
- eryn-1.2.0.dist-info/METADATA +167 -0
- eryn-1.2.0.dist-info/RECORD +39 -0
- eryn-1.2.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,776 @@
|
|
|
1
|
+
from multiprocessing.sharedctypes import Value
|
|
2
|
+
import numpy as np
|
|
3
|
+
import warnings
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from abc import ABC
|
|
6
|
+
|
|
7
|
+
# from scipy.special import logsumexp
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import cupy as cp
|
|
11
|
+
|
|
12
|
+
gpu_available = True
|
|
13
|
+
except (ModuleNotFoundError, ImportError):
|
|
14
|
+
import numpy as cp
|
|
15
|
+
|
|
16
|
+
gpu_available = False
|
|
17
|
+
|
|
18
|
+
from .rj import ReversibleJumpMove
|
|
19
|
+
from ..prior import ProbDistContainer
|
|
20
|
+
from ..utils.utility import groups_from_inds
|
|
21
|
+
|
|
22
|
+
# public API of this module
__all__ = ["MultipleTryMove", "MultipleTryMoveRJ"]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def logsumexp(a, axis=None, xp=None):
    """Numerically stable ``log(sum(exp(a), axis=axis))``.

    Args:
        a (array-like): Input values in log space.
        axis (int or tuple, optional): Axis (or axes) to reduce over. If ``None``,
            reduce over all elements. (default: ``None``)
        xp (module, optional): Array module to use (``numpy`` or ``cupy``).
            (default: ``numpy``)

    Returns:
        array: ``a`` reduced over ``axis`` with the reduced axis removed. With
        ``axis=None`` a 0-d array is returned.

    """
    if xp is None:
        xp = np

    a = xp.asarray(a)

    # keepdims lets the shift broadcast against ``a`` for any axis/rank
    a_max = xp.max(a, axis=axis, keepdims=True)

    # If a slice is entirely -inf (or contains +inf), shifting by its max would
    # give -inf - -inf = nan. Shifting by 0 there yields -inf / +inf instead.
    a_max = xp.where(xp.isfinite(a_max), a_max, 0.0)

    sum_of_exp = xp.exp(a - a_max).sum(axis=axis, keepdims=True)
    out = xp.log(sum_of_exp) + a_max

    return xp.squeeze(out, axis=axis)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_mt_computations(logP, log_proposal_pdf, symmetric=False, xp=None):
    """Compute multiple-try importance weights and draw one try per walker.

    Args:
        logP (array): Log posterior of every try with shape ``(nwalkers, num_try)``.
        log_proposal_pdf (array): Log proposal density of every try, same shape
            as ``logP``. Not used when ``symmetric`` is ``True``.
        symmetric (bool, optional): If ``True``, the log weights are ``logP``
            itself. Otherwise they are ``logP - log_proposal_pdf``.
            (default: ``False``)
        xp (module, optional): Array module (``numpy`` or ``cupy``).
            (default: ``numpy``)

    Returns:
        tuple: ``(log_importance_weights, log_sum_weights, inds_keep)``.
        ``log_sum_weights`` has one entry per walker. ``inds_keep`` is the
        index of the chosen try for each walker, drawn with probability
        proportional to its importance weight.

    """
    if xp is None:
        xp = np

    # set weights based on if symmetric
    if symmetric:
        log_importance_weights = logP
    else:
        log_importance_weights = logP - log_proposal_pdf

    # log of the sum of weights for each walker
    log_sum_weights = logsumexp(log_importance_weights, axis=-1, xp=xp)

    # normalized selection probabilities: w_i / sum_j w_j
    log_of_probs = log_importance_weights - log_sum_weights[:, None]
    probs = xp.exp(log_of_probs)

    # Inverse-CDF draw. The uniform draw is scaled by each row's total so that
    # round-off (a cumulative sum ending slightly below 1) can never leave every
    # entry below the draw, which would make argmax silently fall back to try 0.
    # NOTE: rows whose weights are all nan still end up selecting index 0.
    cdf = probs.cumsum(axis=1)
    draws = xp.random.rand(probs.shape[0])[:, None] * cdf[:, -1:]
    inds_keep = (cdf > draws).argmax(axis=1)

    return log_importance_weights, log_sum_weights, inds_keep
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class MultipleTryMove(ABC):
    """Generate multiple proposal tries.

    This class should be inherited by another proposal class that overwrites
    the ``special_*`` methods (``special_like_func``, ``special_prior_func``,
    ``special_generate_func`` and, for independent proposals,
    ``special_generate_logpdf``). See :class:`eryn.moves.MTDistGenMove`
    and :class:`MTDistGenRJ` for examples.

    Args:
        num_try (int, optional): Number of tries. (default: 1)
        independent (bool, optional): Set to ``True`` if the proposal is independent of the current points.
            (default: ``False``).
        symmetric (bool, optional): Set to ``True`` if the proposal is symmetric.
            (default: ``False``).
        rj (bool, optional): Set to ``True`` if this is a nested reversible jump proposal.
            (default: ``False``).
        use_gpu (bool, optional): If truthy, :attr:`xp` is the module-level ``cp``
            (cupy, or numpy if cupy could not be imported). Otherwise numpy. (default: ``None``)
        **kwargs (dict, optional): for compatibility with other proposals.

    Raises:
        ValueError: Input issues.

    """

    def __init__(
        self,
        num_try=1,
        independent=False,
        symmetric=False,
        rj=False,
        use_gpu=None,
        **kwargs
    ):
        self.num_try = num_try

        self.independent = independent
        self.symmetric = symmetric
        self.rj = rj

        if self.rj:
            if self.symmetric or self.independent:
                raise ValueError(
                    "If rj==True, symmetric and independent must both be False."
                )

        self.use_gpu = use_gpu

    @property
    def xp(self):
        """Array module used for computations (``cp`` if ``use_gpu`` else ``np``)."""
        xp = cp if self.use_gpu else np
        return xp

    @classmethod
    def special_like_func(self, generated_coords, *args, inds_leaves_rj=None, **kwargs):
        """Calculate the Likelihood for sampled points.

        Args:
            generated_coords (np.ndarray): Generated coordinates with leading shape
                ``(number of independent walkers, num_try, ...)``.
            *args (tuple, optional): additional arguments passed by overwriting the
                ``get_proposal`` function and passing ``args_like`` keyword argument.
            inds_leaves_rj (np.ndarray): Index into each individual walker giving the
                leaf index associated with this proposal. Should only be used if ``self.rj is True``. (default: ``None``)
            **kwargs (tuple, optional): additional keyword arguments passed by overwriting the
                ``get_proposal`` function and passing ``kwargs_like`` keyword argument.

        Returns:
            np.ndarray: Likelihood values with shape ``(generated_coords.shape[0], num_try).``

        Raises:
            NotImplementedError: Function not included.

        """
        raise NotImplementedError

    @classmethod
    def special_prior_func(self, generated_coords, *args, **kwargs):
        """Calculate the Prior for sampled points.

        Args:
            generated_coords (np.ndarray): Generated coordinates with leading shape
                ``(number of independent walkers, num_try, ...)``.
            *args (tuple, optional): additional arguments passed by overwriting the
                ``get_proposal`` function and passing ``args_prior`` keyword argument.
            inds_leaves_rj (np.ndarray): Index into each individual walker giving the
                leaf index associated with this proposal. Received through ``**kwargs``.
                Should only be used if ``self.rj is True``. (default: ``None``)
            **kwargs (tuple, optional): additional keyword arguments passed by overwriting the
                ``get_proposal`` function and passing ``kwargs_prior`` keyword argument.

        Returns:
            np.ndarray: Prior values with shape ``(generated_coords.shape[0], num_try).``

        Raises:
            NotImplementedError: Function not included.

        """
        raise NotImplementedError

    @classmethod
    def special_generate_func(
        self, coords, random, size=1, *args, fill_tuple=None, fill_values=None, **kwargs
    ):
        """Generate samples and calculate the logpdf of their proposal function.

        Args:
            coords (np.ndarray): Current coordinates of walkers.
            random (obj): Random generator.
            size (int, optional): Number of tries to generate.
            *args (tuple, optional): additional arguments passed by overwriting the
                ``get_proposal`` function and passing ``args_generate`` keyword argument.
            fill_tuple (tuple, optional): Length 2 tuple with the indexing of which values to fill
                when generating. Can be used for auxiliary proposals or reverse RJ proposals.
                The first index is into walkers and the second index is into the number of tries.
                (default: ``None``)
            fill_values (np.ndarray): values to fill associated with ``fill_tuple``. Should
                have size ``(len(fill_tuple[0]), ndim)``. (default: ``None``).
            **kwargs (tuple, optional): additional keyword arguments passed by overwriting the
                ``get_proposal`` function and passing ``kwargs_generate`` keyword argument.

        Returns:
            tuple: (generated points, logpdf of generated points).

        Raises:
            NotImplementedError: Function not included.

        """
        raise NotImplementedError

    @classmethod
    def special_generate_logpdf(self, coords):
        """Get logpdf of generated coordinates.

        Args:
            coords (np.ndarray): Current coordinates of walkers.

        Returns:
            np.ndarray: logpdf of generated points.

        Raises:
            NotImplementedError: Function not included.
        """
        raise NotImplementedError

    def _clean_nan_ll(self, ll):
        """Replace nan Likelihood values in place with ``-1e300`` (warning if any are found).

        Args:
            ll (np.ndarray): Log Likelihood values. Modified in place.

        Returns:
            np.ndarray: ``ll`` with nans replaced.

        """
        nan_mask = self.xp.isnan(ll)
        if self.xp.any(nan_mask):
            warnings.warn("Getting nans for ll in multiple try.")
            ll[nan_mask] = -1e300
        return ll

    def get_mt_log_posterior(self, ll, lp, betas=None):
        """Calculate the log of the posterior for all tries.

        Args:
            ll (np.ndarray): Log Likelihood values with shape ``(nwalkers, num_tries)``.
            lp (np.ndarray): Log Prior values with shape ``(nwalkers, num_tries)``.
            betas (np.ndarray, optional): Inverse temperatures to include in log Posterior computation.
                (default: ``None``)

        Returns:
            np.ndarray: Log of the Posterior with shape ``(nwalkers, num_tries)``.

        """
        if betas is None:
            ll_temp = ll.copy()
        else:
            assert isinstance(betas, self.xp.ndarray)
            if ll.ndim > 1:
                # broadcast one beta per walker across the tries axis
                betas_tmp = self.xp.expand_dims(betas, ll.ndim - 1)
            else:
                betas_tmp = betas
            ll_temp = betas_tmp * ll

        return ll_temp + lp

    def readout_adjustment(self, out_vals, all_vals_prop, aux_all_vals):
        """Read out values from the proposal.

        Hook that lets a subclass read out any values from the proposal that
        may be needed elsewhere. The default implementation does nothing.

        Args:
            out_vals (list): ``[logP_out, ll_out, lp_out, log_proposal_pdf_out, log_sum_weights]``.
            all_vals_prop (list): ``[logP, ll, lp, log_proposal_pdf, log_sum_weights]``.
            aux_all_vals (list): ``[aux_logP, aux_ll, aux_lp, aux_log_proposal_pdf, aux_log_sum_weights]``.

        """
        pass

    def get_mt_proposal(
        self,
        coords,
        random,
        args_generate=(),
        kwargs_generate={},
        args_like=(),
        kwargs_like={},
        args_prior=(),
        kwargs_prior={},
        betas=None,
        ll_in=None,
        lp_in=None,
        inds_leaves_rj=None,
        inds_reverse_rj=None,
    ):
        """Make a multiple-try proposal.

        Here, ``nwalkers`` refers to all independent walkers which generally
        will mean ``nwalkers * ntemps`` in terms of the rest of the sampler.

        Args:
            coords (np.ndarray): Current coordinates of walkers.
            random (obj): Random generator.
            args_generate (tuple, optional): Additional ``*args`` to pass to generate function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``()``)
            kwargs_generate (dict, optional): Additional ``**kwargs`` to pass to generate function.
                Must overwrite ``get_proposal`` function to use these. Only read, never mutated.
                (default: ``{}``)
            args_like (tuple, optional): Additional ``*args`` to pass to Likelihood function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``()``)
            kwargs_like (dict, optional): Additional ``**kwargs`` to pass to Likelihood function.
                Must overwrite ``get_proposal`` function to use these. Only read, never mutated.
                (default: ``{}``)
            args_prior (tuple, optional): Additional ``*args`` to pass to Prior function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``()``)
            kwargs_prior (dict, optional): Additional ``**kwargs`` to pass to Prior function.
                Must overwrite ``get_proposal`` function to use these. Only read, never mutated.
                (default: ``{}``)
            betas (np.ndarray, optional): Inverse temperatures passed to the proposal with shape ``(nwalkers,)``.
            ll_in (np.ndarray, optional): Log Likelihood values for the current coordinates.
                Must be provided if ``self.rj is True`` or ``self.independent is True``.
                If ``self.rj is True``, for all proposed removals this value must be the
                Likelihood with the binary removed so all proposals are pretending to add a binary.
                (default: ``None``)
            lp_in (np.ndarray, optional): Log Prior values for the current coordinates.
                Must be provided if ``self.rj is True`` or ``self.independent is True``.
                If ``self.rj is True``, for all proposed removals this value must be the
                Prior with the binary removed so all proposals are pretending to add a binary.
                (default: ``None``)
            inds_leaves_rj (np.ndarray, optional): Array giving the leaf index of each incoming walker.
                Must be provided if ``self.rj is True``. (default: ``None``)
            inds_reverse_rj (np.ndarray, optional): Array giving the walker indices whose
                proposals are reverse proposals removing a leaf. May be empty.
                Must be provided if ``self.rj is True``. (default: ``None``)

        Returns:
            tuple: (generated points, factors).

        Raises:
            ValueError: Inputs are incorrect.

        """

        # check if rj and make sure we have all the information in that case
        # (explicit checks instead of asserts so they survive ``python -O``)
        if self.rj:
            if (
                ll_in is None
                or lp_in is None
                or inds_leaves_rj is None
                or inds_reverse_rj is None
            ):
                raise ValueError(
                    "If using rj, must provide ll_in, lp_in, inds_leaves_rj, and inds_reverse_rj."
                )

            # if using reversible jump, fill first spot with values that are proposed to remove
            fill_tuple = (inds_reverse_rj, np.zeros_like(inds_reverse_rj))
            fill_values = coords[inds_reverse_rj]
        else:
            fill_tuple = None
            fill_values = None

        # generate new points and get log of the proposal probability
        generated_points, log_proposal_pdf = self.special_generate_func(
            coords,
            random,
            *args_generate,
            size=self.num_try,
            fill_values=fill_values,
            fill_tuple=fill_tuple,
            **kwargs_generate
        )

        # compute the Likelihood functions (nan -> -1e300 so weights stay finite)
        ll = self.special_like_func(
            generated_points, *args_like, inds_leaves_rj=inds_leaves_rj, **kwargs_like
        )
        ll = self._clean_nan_ll(ll)

        # compute the Prior functions
        lp = self.special_prior_func(
            generated_points, *args_prior, inds_leaves_rj=inds_leaves_rj, **kwargs_prior
        )

        # if rj, make proposal distribution for all other leaves the prior value
        # this will properly cancel the prior with the proposal for leaves that already exists
        if self.rj:
            log_proposal_pdf += lp_in[:, None]

        # get posterior distribution including tempering
        logP = self.get_mt_log_posterior(ll, lp, betas=betas)

        _, log_sum_weights, inds_keep = get_mt_computations(
            logP, log_proposal_pdf, symmetric=self.symmetric, xp=self.xp
        )

        # tuple of index arrays of which try chosen per walker
        inds_tuple = (self.xp.arange(len(inds_keep)), inds_keep)

        if self.rj:
            # this just ensures the cancellation of logP and aux_logP outside of proposal
            inds_tuple[1][inds_reverse_rj] = 0

        # get chosen prior, Likelihood, posterior information
        lp_out = lp[inds_tuple]
        ll_out = ll[inds_tuple]
        logP_out = logP[inds_tuple]

        # store this information for access outside of multiple try part
        self.mt_lp = lp_out
        self.mt_ll = ll_out

        # choose points and get the log of the proposal for storage
        generated_points_out = generated_points[inds_tuple].copy()  # theta^j
        log_proposal_pdf_out = log_proposal_pdf[inds_tuple]

        # prepare auxiliary information based on if it is nested rj, independent, or not
        if self.independent:
            # The tries do not depend on the current point, so the auxiliary set
            # reuses the forward tries with the chosen slot swapped for the
            # current point's values.
            if ll_in is None or lp_in is None:
                raise ValueError(
                    "If independent is True, must provide ll_in and lp_in."
                )
            if ll_in.shape[0] != coords.shape[0] or lp_in.shape[0] != coords.shape[0]:
                raise ValueError(
                    "ll_in and lp_in must have one entry per walker in coords."
                )

            aux_ll = ll.copy()
            aux_lp = lp.copy()

            # sub in this information from the current coordinates
            aux_ll[inds_tuple] = ll_in
            aux_lp[inds_tuple] = lp_in

            # get auxiliary posterior
            aux_logP = self.get_mt_log_posterior(aux_ll, aux_lp, betas=betas)

            # generation pdf of the current coordinates replaces the chosen try's pdf
            aux_log_proposal_pdf = log_proposal_pdf.copy()
            aux_log_proposal_pdf[inds_tuple] = self.special_generate_logpdf(coords)

            # weights must be formed exactly like the forward weights
            # (get_mt_computations) for the ratio of their sums to be valid
            if self.symmetric:
                aux_log_importance_weights = aux_logP
            else:
                aux_log_importance_weights = aux_logP - aux_log_proposal_pdf

        elif self.rj:
            # in rj, set aux_ll and aux_lp to be repeats of the model with one less leaf
            aux_ll = np.repeat(ll_in[:, None], self.num_try, axis=-1)
            aux_lp = np.repeat(lp_in[:, None], self.num_try, axis=-1)

            # probability is the prior for the existing points
            aux_log_proposal_pdf = aux_lp.copy()

            # get log posterior
            aux_logP = self.get_mt_log_posterior(aux_ll, aux_lp, betas=betas)

            # get importance weights
            aux_log_importance_weights = aux_logP - aux_log_proposal_pdf

        else:
            # Auxiliary tries are generated around the chosen point, with the
            # chosen slot holding the CURRENT point, mirroring the independent
            # branch where that slot holds the current point's ll/lp.
            aux_generated_points, aux_log_proposal_pdf = self.special_generate_func(
                generated_points_out,
                random,
                *args_generate,
                size=self.num_try,
                fill_tuple=inds_tuple,
                fill_values=coords,
                **kwargs_generate
            )

            # get ll, lp, and lP (same extra arguments / nan handling as the forward tries)
            aux_ll = self.special_like_func(
                aux_generated_points,
                *args_like,
                inds_leaves_rj=inds_leaves_rj,
                **kwargs_like
            )
            aux_ll = self._clean_nan_ll(aux_ll)

            aux_lp = self.special_prior_func(
                aux_generated_points,
                *args_prior,
                inds_leaves_rj=inds_leaves_rj,
                **kwargs_prior
            )

            aux_logP = self.get_mt_log_posterior(aux_ll, aux_lp, betas=betas)

            # set auxiliary weights
            if self.symmetric:
                aux_log_importance_weights = aux_logP
            else:
                aux_log_importance_weights = aux_logP - aux_log_proposal_pdf

        # chosen output old Posteriors
        aux_logP_out = aux_logP[inds_tuple]
        # get sum of log weights
        aux_log_sum_weights = logsumexp(aux_log_importance_weights, axis=-1, xp=self.xp)

        aux_log_proposal_pdf_out = aux_log_proposal_pdf[inds_tuple]
        # this is setup to make clear with the math.
        # setting up factors properly means the
        # final lnpdiff will be effectively be the ratio of the sums
        # of the weights

        # IMPORTANT: logP_out must be subtracted against log_sum_weights before anything else due to -1e300s.
        factors = (
            (aux_logP_out - aux_log_sum_weights)
            - aux_log_proposal_pdf_out
            + aux_log_proposal_pdf_out
        ) - ((logP_out - log_sum_weights) - log_proposal_pdf_out + log_proposal_pdf_out)

        if self.rj:
            # adjust all information for reverse rj proposals
            factors[inds_reverse_rj] *= -1
            self.mt_ll[inds_reverse_rj] = ll_in[inds_reverse_rj]
            self.mt_lp[inds_reverse_rj] = lp_in[inds_reverse_rj]

        # store output information
        self.aux_logP_out = aux_logP_out
        self.logP_out = logP_out
        self.aux_ll = aux_ll
        self.aux_lp = aux_lp

        self.log_sum_weights = log_sum_weights
        self.aux_log_sum_weights = aux_log_sum_weights

        if self.rj:
            self.inds_reverse_rj = inds_reverse_rj
            self.inds_forward_rj = np.delete(
                np.arange(coords.shape[0]), inds_reverse_rj
            )

        # prepare to readout any information the user would like in readout_adjustment
        out_vals = [logP_out, ll_out, lp_out, log_proposal_pdf_out, log_sum_weights]
        all_vals_prop = [logP, ll, lp, log_proposal_pdf, log_sum_weights]
        aux_all_vals = [
            aux_logP,
            aux_ll,
            aux_lp,
            aux_log_proposal_pdf,
            aux_log_sum_weights,
        ]
        self.readout_adjustment(out_vals, all_vals_prop, aux_all_vals)

        return (
            generated_points_out,
            factors,
        )

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Get proposal

        Args:
            branches_coords (dict): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] representing
                coordinates for walkers. Must contain exactly one branch.
            random (object): Current random state object.
            branches_inds (dict, optional): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max] representing which
                leaves are currently being used. (default: ``None``)
            **kwargs (ignored): This is added for compatibility. It is ignored in this function.

        Returns:
            tuple: (Proposed coordinates, factors) -> (dict, np.ndarray)

        Raises:
            ValueError: Input issues.

        """

        # multiple try is only made for one branch here
        if len(list(branches_coords.keys())) > 1:
            raise ValueError("Can only propose change to one model at a time with MT.")

        # get main key
        key_in = list(branches_coords.keys())[0]
        self.key_in = key_in

        # get inds information
        if branches_inds is None:
            branches_inds = {}
            branches_inds[key_in] = np.ones(
                branches_coords[key_in].shape[:-1], dtype=bool
            )

        # Make sure for base proposals that there is only one leaf
        if np.any(branches_inds[key_in].sum(axis=-1) > 1):
            raise ValueError(
                "Multiple try proposal without rj only supports one leaf per walker."
            )

        ntemps, nwalkers, _, _ = branches_coords[key_in].shape

        # get temperature information (one beta per used leaf, temperature-major)
        betas_here = np.repeat(
            self.temperature_control.betas[:, None],
            np.prod(branches_coords[key_in].shape[1:-1]),
        ).reshape(branches_inds[key_in].shape)[branches_inds[key_in]]

        # previous Likelihoods in case proposal is independent
        ll_here = np.repeat(
            self.current_state.log_like[:, :, None],
            branches_coords[key_in].shape[2],
            axis=-1,
        ).reshape(branches_inds[key_in].shape)[branches_inds[key_in]]

        # previous Priors in case proposal is independent
        lp_here = np.repeat(
            self.current_state.log_prior[:, :, None],
            branches_coords[key_in].shape[2],
            axis=-1,
        ).reshape(branches_inds[key_in].shape)[branches_inds[key_in]]

        # get mt proposal
        generated_points, factors = self.get_mt_proposal(
            branches_coords[key_in][branches_inds[key_in]],
            random,
            betas=betas_here,
            ll_in=ll_here,
            lp_in=lp_here,
        )

        # store this information for access outside
        self.mt_ll = self.mt_ll.reshape(ntemps, nwalkers)
        self.mt_lp = self.mt_lp.reshape(ntemps, nwalkers)

        return (
            {key_in: generated_points.reshape(ntemps, nwalkers, 1, -1)},
            factors.reshape(ntemps, nwalkers),
        )
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
class MultipleTryMoveRJ(MultipleTryMove):
    """Multiple-try proposal nested inside a +1/-1 reversible jump move.

    Each walker either proposes adding one leaf, chosen among ``num_try``
    generated candidates, or removing one leaf. For a removal, the removed
    leaf occupies the first try slot so the reverse weights can be computed.

    NOTE(review): this class uses ``self.get_model_change_proposal``,
    ``self.current_state``, ``self.current_model`` and
    ``self.temperature_control``, none of which are defined here or in
    :class:`MultipleTryMove`. They are presumably supplied by a co-inherited
    move class; confirm against subclasses.
    """

    def get_proposal(
        self,
        branches_coords,
        branches_inds,
        nleaves_min_all,
        nleaves_max_all,
        random,
        **kwargs
    ):
        """Make a proposal

        Args:
            branches_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers. Must contain exactly one branch.
            branches_inds (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max]. These are the boolean
                arrays marking which leaves are currently used within each walker.
            nleaves_min_all (dict): Minimum leaf count, indexed by branch name.
                Must contain exactly one entry.
            nleaves_max_all (dict): Maximum leaf count, indexed by branch name.
                Must contain exactly one entry.
            random (object): Current random state of the sampler.
            **kwargs (ignored): For modularity.

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                ``double`` np.ndarray[ntemps, nwalkers, nleaves_max, ndim] containing
                proposed coordinates. Second entry is the new ``inds`` array with
                boolean values flipped for added or removed sources. Third entry
                is the factors associated with the
                proposal necessary for detailed balance, shape ``(ntemps, nwalkers)``.
                This is effectively any term in the detailed balance fraction.
                +log of factors if in the numerator. -log of factors if in the denominator.

        Raises:
            ValueError: Input issues.

        """

        if len(list(branches_coords.keys())) > 1:
            raise ValueError("Can only propose change to one model at a time with MT.")

        # get main key
        key_in = list(branches_coords.keys())[0]
        self.key_in = key_in

        if branches_inds is None:
            raise ValueError("In MT RJ proposal, branches_inds cannot be None.")

        ntemps, nwalkers, _, ndim = branches_coords[key_in].shape

        # get temperature information (one beta per flattened walker, temperature-major)
        betas_here = np.repeat(
            self.temperature_control.betas[:, None], nwalkers, axis=-1
        ).flatten()

        # current Likelihood and prior information (flatten returns copies)
        ll_here = self.current_state.log_like.flatten()
        lp_here = self.current_state.log_prior.flatten()

        # do rj setup
        if len(nleaves_min_all) != 1 or len(nleaves_max_all) != 1:
            raise ValueError(
                "MT RJ proposal requires nleaves_min_all and nleaves_max_all for exactly one branch."
            )
        nleaves_min = nleaves_min_all[key_in]
        nleaves_max = nleaves_max_all[key_in]

        if nleaves_min == nleaves_max:
            raise ValueError("MT RJ proposal requires that nleaves_min != nleaves_max.")
        elif nleaves_min > nleaves_max:
            raise ValueError("nleaves_min is greater than nleaves_max. Not allowed.")

        # get the inds adjustment information
        all_inds_for_change = self.get_model_change_proposal(
            branches_inds[key_in], random, nleaves_min, nleaves_max
        )

        # preparing leaf information for going into the proposal
        inds_leaves_rj = np.zeros(ntemps * nwalkers, dtype=int)
        coords_in = np.zeros((ntemps * nwalkers, ndim))

        # flattened walker indices proposing a removal. Stays empty (rather than
        # None) when no walker removes a leaf, so get_mt_proposal still works.
        inds_reverse_rj = np.array([], dtype=int)

        # prepare proposal dictionaries
        new_inds = deepcopy(branches_inds)
        q = deepcopy(branches_coords)
        for change, inds_change in all_inds_for_change.items():
            if change not in ["+1", "-1"]:
                raise ValueError("MT RJ is only implemented for +1/-1 moves.")

            # get indices of changing leaves
            temp_inds = inds_change[:, 0]
            walker_inds = inds_change[:, 1]
            leaf_inds = inds_change[:, 2]
            flat_walker_inds = temp_inds * nwalkers + walker_inds

            # leaf index to change
            inds_leaves_rj[flat_walker_inds] = leaf_inds
            coords_in[flat_walker_inds] = branches_coords[key_in][
                (temp_inds, walker_inds, leaf_inds)
            ]

            # add -> leaf switched on, remove -> leaf switched off
            new_inds[key_in][(temp_inds, walker_inds, leaf_inds)] = change == "+1"

            if change == "-1":
                # which walkers are removing
                inds_reverse_rj = flat_walker_inds

        if len(inds_reverse_rj) > 0:
            # setup reversal coords and inds
            # need to determine Likelihood and prior of removed binaries.
            # this goes into the multiple try proposal as previous ll and lp
            temp_reverse_coords = {}
            temp_reverse_inds = {}

            for key in self.current_state.branches:
                (
                    ntemps_tmp,
                    nwalkers_tmp,
                    nleaves_max_tmp,
                    ndim_tmp,
                ) = self.current_state.branches[key].shape

                # coords from reversal, as a single "temperature" batch
                temp_reverse_coords[key] = self.current_state.branches[
                    key
                ].coords.reshape(ntemps_tmp * nwalkers_tmp, nleaves_max_tmp, ndim_tmp)[
                    inds_reverse_rj
                ][
                    None, :
                ]

                # which inds array to use (proposed inds for the changing branch)
                inds_tmp_here = (
                    new_inds[key]
                    if key == key_in
                    else self.current_state.branches[key].inds
                )
                temp_reverse_inds[key] = inds_tmp_here.reshape(
                    ntemps_tmp * nwalkers_tmp, nleaves_max_tmp
                )[inds_reverse_rj][None, :]

            # calculate information for the reverse
            lp_reverse_here = self.current_model.compute_log_prior_fn(
                temp_reverse_coords, inds=temp_reverse_inds
            )[0]
            # logp must describe the same (1, n_reverse) batch as the coords above,
            # not every walker's current prior
            ll_reverse_here = self.current_model.compute_log_like_fn(
                temp_reverse_coords,
                inds=temp_reverse_inds,
                logp=lp_reverse_here.reshape(1, -1),
            )[0]

            # fill the here values
            ll_here[inds_reverse_rj] = ll_reverse_here
            lp_here[inds_reverse_rj] = lp_reverse_here

        # get mt proposal
        generated_points, factors = self.get_mt_proposal(
            coords_in,
            random,
            betas=betas_here,
            ll_in=ll_here,
            lp_in=lp_here,
            inds_leaves_rj=inds_leaves_rj,
            inds_reverse_rj=inds_reverse_rj,
        )

        # for reading outside
        self.mt_ll = self.mt_ll.reshape(ntemps, nwalkers)
        self.mt_lp = self.mt_lp.reshape(ntemps, nwalkers)

        # update the coordinates of added leaves. Index generated points by each
        # adding walker's own flattened index so the result does not depend on
        # the row order of all_inds_for_change["+1"].
        if "+1" in all_inds_for_change:
            inds_add = all_inds_for_change["+1"]
            temp_inds = inds_add[:, 0]
            walker_inds = inds_add[:, 1]
            leaf_inds = inds_add[:, 2]
            q[key_in][(temp_inds, walker_inds, leaf_inds)] = generated_points[
                temp_inds * nwalkers + walker_inds
            ]

        return q, new_inds, factors.reshape(ntemps, nwalkers)
|