eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,776 @@
1
+ from multiprocessing.sharedctypes import Value
2
+ import numpy as np
3
+ import warnings
4
+ from copy import deepcopy
5
+ from abc import ABC
6
+
7
+ # from scipy.special import logsumexp
8
+
9
+ try:
10
+ import cupy as cp
11
+
12
+ gpu_available = True
13
+ except (ModuleNotFoundError, ImportError):
14
+ import numpy as cp
15
+
16
+ gpu_available = False
17
+
18
+ from .rj import ReversibleJumpMove
19
+ from ..prior import ProbDistContainer
20
+ from ..utils.utility import groups_from_inds
21
+
22
# NOTE(review): this name looks like a typo for ``__all__``. As written it has no
# effect on ``from ... import *`` (names starting with "_" are never star-exported,
# and there is no ``__all__``). Renaming it would restrict star-imports to
# ``MultipleTryMove`` only, hiding ``MultipleTryMoveRJ`` and the module-level
# helpers; confirm the intended public API before changing.
___ = ["MultipleTryMove"]
23
+
24
+
25
def logsumexp(a, axis=None, xp=None):
    """Numerically stable ``log(sum(exp(a)))`` along ``axis``.

    Works for any array rank and any ``axis`` (including ``None`` for a full
    reduction). Non-finite maxima are replaced by 0 before shifting, so slices
    that are entirely ``-inf`` give ``-inf`` instead of ``nan``. This matches the
    convention of ``scipy.special.logsumexp``.

    Args:
        a (xp.ndarray): Input log-values.
        axis (int, optional): Axis to reduce over. ``None`` reduces over all
            elements. (default: ``None``)
        xp (module, optional): Array module (``numpy`` or ``cupy``).
            (default: ``numpy``)

    Returns:
        xp.ndarray: ``log(sum(exp(a), axis=axis))`` with ``axis`` removed.

    """
    if xp is None:
        xp = np

    # keepdims so the shift broadcasts against ``a`` for any rank/axis
    a_max = xp.max(a, axis=axis, keepdims=True)

    # avoid ``-inf - -inf = nan`` (and ``inf - inf``) when a slice has no finite max
    a_max = xp.where(xp.isfinite(a_max), a_max, 0.0)

    sum_of_exp = xp.sum(xp.exp(a - a_max), axis=axis, keepdims=True)
    out = xp.log(sum_of_exp) + a_max

    # drop the reduced dimension(s) to match the non-keepdims result shape
    return xp.squeeze(out, axis=axis)
34
+
35
+
36
def get_mt_computations(logP, log_proposal_pdf, symmetric=False, xp=None):
    """Compute multiple-try importance weights and draw one try per walker.

    Args:
        logP (xp.ndarray): Log posterior of every try with shape
            ``(nwalkers, num_try)``.
        log_proposal_pdf (xp.ndarray): Log proposal density of every try, same
            shape as ``logP``. Ignored when ``symmetric`` is ``True``.
        symmetric (bool, optional): If ``True``, the importance weights are just
            ``logP``. (default: ``False``)
        xp (module, optional): Array module (``numpy`` or ``cupy``).
            (default: ``numpy``)

    Returns:
        tuple: ``(log_importance_weights, log_sum_weights, inds_keep)``.
        ``log_sum_weights`` has shape ``(nwalkers,)``, and ``inds_keep`` holds
        the index of the chosen try for each walker.

    """
    if xp is None:
        xp = np

    # set weights based on if symmetric
    if symmetric:
        log_importance_weights = logP
    else:
        log_importance_weights = logP - log_proposal_pdf

    # log of the sum of weights for each walker
    log_sum_weights = logsumexp(log_importance_weights, axis=-1, xp=xp)

    # selection probabilities: w_i / sum_j w_j
    probs = xp.exp(log_importance_weights - log_sum_weights[:, None])

    # Inverse-CDF draw. Scale the uniform draw by the final CDF entry: rounding
    # can leave it slightly below 1, and an unscaled draw above it would make
    # every comparison False so argmax would silently pick try 0.
    cdf = probs.cumsum(axis=1)
    threshold = xp.random.rand(probs.shape[0])[:, None] * cdf[:, -1:]
    inds_keep = (cdf > threshold).argmax(axis=1)

    return log_importance_weights, log_sum_weights, inds_keep
60
+
61
+
62
class MultipleTryMove(ABC):
    """Generate multiple proposal tries.

    This class should be inherited by another proposal class that overwrites
    the ``special_*`` hook methods. See :class:`eryn.moves.MTDistGenMove`
    and :class:`MTDistGenRJ` for examples.

    Args:
        num_try (int, optional): Number of tries. (default: 1)
        independent (bool, optional): Set to ``True`` if the proposal is independent of the current points.
            (default: ``False``).
        symmetric (bool, optional): Set to ``True`` if the proposal is symmetric.
            (default: ``False``).
        rj (bool, optional): Set to ``True`` if this is a nested reversible jump proposal.
            (default: ``False``).
        use_gpu (bool, optional): If truthy, :attr:`xp` returns the module-level ``cp``
            (cupy, or numpy when cupy is not importable) instead of numpy. (default: ``None``)
        **kwargs (dict, optional): for compatibility with other proposals.

    Raises:
        ValueError: Input issues.

    """

    def __init__(
        self,
        num_try=1,
        independent=False,
        symmetric=False,
        rj=False,
        use_gpu=None,
        **kwargs
    ):
        self.num_try = num_try

        self.independent = independent
        self.symmetric = symmetric
        self.rj = rj

        # nested rj proposals may be neither symmetric nor independent
        if self.rj and (self.symmetric or self.independent):
            raise ValueError(
                "If rj==True, symmetric and independt must both be False."
            )

        self.use_gpu = use_gpu

    @property
    def xp(self):
        """Array module in use: ``cp`` when ``use_gpu`` is truthy, else ``numpy``."""
        return cp if self.use_gpu else np

    @classmethod
    def special_like_func(cls, generated_coords, *args, inds_leaves_rj=None, **kwargs):
        """Calculate the Likelihood for sampled points.

        Args:
            generated_coords (np.ndarray): Generated coordinates with shape ``(number of independent walkers, num_try)``.
            *args (tuple, optional): additional arguments passed by overwriting the
                ``get_proposal`` function and passing ``args_like`` keyword argument.
            inds_leaves_rj (np.ndarray): Index into each individual walker giving the
                leaf index associated with this proposal. Should only be used if ``self.rj is True``. (default: ``None``)
            **kwargs (tuple, optional): additional keyword arguments passed by overwriting the
                ``get_proposal`` function and passing ``kwargs_like`` keyword argument.

        Returns:
            np.ndarray: Likelihood values with shape ``(generated_coords.shape[0], num_try).``

        Raises:
            NotImplementedError: Function not included.

        """
        raise NotImplementedError

    @classmethod
    def special_prior_func(cls, generated_coords, *args, inds_leaves_rj=None, **kwargs):
        """Calculate the Prior for sampled points.

        Args:
            generated_coords (np.ndarray): Generated coordinates with shape ``(number of independent walkers, num_try)``.
            *args (tuple, optional): additional arguments passed by overwriting the
                ``get_proposal`` function and passing ``args_prior`` keyword argument.
            inds_leaves_rj (np.ndarray): Index into each individual walker giving the
                leaf index associated with this proposal. Should only be used if ``self.rj is True``. (default: ``None``)
            **kwargs (tuple, optional): additional keyword arguments passed by overwriting the
                ``get_proposal`` function and passing ``kwargs_prior`` keyword argument.

        Returns:
            np.ndarray: Prior values with shape ``(generated_coords.shape[0], num_try).``

        Raises:
            NotImplementedError: Function not included.

        """
        raise NotImplementedError

    @classmethod
    def special_generate_func(
        cls, coords, random, *args, size=1, fill_tuple=None, fill_values=None, **kwargs
    ):
        """Generate samples and calculate the logpdf of their proposal function.

        Args:
            coords (np.ndarray): Current coordinates of walkers.
            random (obj): Random generator.
            *args (tuple, optional): additional arguments passed by overwriting the
                ``get_proposal`` function and passing ``args_generate`` keyword argument.
            size (int, optional): Number of tries to generate.
            fill_tuple (tuple, optional): Length 2 tuple with the indexing of which values to fill
                when generating. Can be used for auxiliary proposals or reverse RJ proposals. First index is the index into walkers and the second index is
                the index into the number of tries. (default: ``None``)
            fill_values (np.ndarray): values to fill associated with ``fill_tuple``. Should
                have size ``(len(fill_tuple[0]), ndim)``. (default: ``None``).
            **kwargs (tuple, optional): additional keyword arguments passed by overwriting the
                ``get_proposal`` function and passing ``kwargs_generate`` keyword argument.

        Returns:
            tuple: (generated points, logpdf of generated points).

        Raises:
            NotImplementedError: Function not included.

        """
        raise NotImplementedError

    @classmethod
    def special_generate_logpdf(cls, coords):
        """Get logpdf of generated coordinates.

        Args:
            coords (np.ndarray): Current coordinates of walkers.

        Returns:
            np.ndarray: logpdf of generated points.

        Raises:
            NotImplementedError: Function not included.
        """
        raise NotImplementedError

    @classmethod
    def special_generate_like(cls, coords):
        """Log Likelihood of the current coordinates for independent proposals.

        Only called when ``self.independent is True`` and ``ll_in`` is not
        passed to :meth:`get_mt_proposal`.

        Args:
            coords (np.ndarray): Current coordinates of walkers.

        Returns:
            np.ndarray: One Log Likelihood value per walker (``coords.shape[0]``).

        Raises:
            NotImplementedError: Function not included.
        """
        raise NotImplementedError

    @classmethod
    def special_generate_prior(cls, coords):
        """Log Prior of the current coordinates for independent proposals.

        Only called when ``self.independent is True`` and ``lp_in`` is not
        passed to :meth:`get_mt_proposal`.

        Args:
            coords (np.ndarray): Current coordinates of walkers.

        Returns:
            np.ndarray: One Log Prior value per walker (``coords.shape[0]``).

        Raises:
            NotImplementedError: Function not included.
        """
        raise NotImplementedError

    def _replace_nan_ll(self, ll):
        """Replace NaN Log Likelihood values in place with ``-1e300`` (warning once)."""
        nan_mask = self.xp.isnan(ll)
        if self.xp.any(nan_mask):
            warnings.warn("Getting nans for ll in multiple try.")
            ll[nan_mask] = -1e300
        return ll

    def get_mt_log_posterior(self, ll, lp, betas=None):
        """Calculate the log of the posterior for all tries.

        Args:
            ll (np.ndarray): Log Likelihood values with shape ``(nwalkers, num_tries)``.
            lp (np.ndarray): Log Prior values with shape ``(nwalkers, num_tries)``.
            betas (np.ndarray, optional): Inverse temperatures to include in log Posterior computation.
                (default: ``None``)

        Returns:
            np.ndarray: Log of the Posterior with shape ``(nwalkers, num_tries)``.

        """
        if betas is None:
            ll_temp = ll.copy()
        else:
            assert isinstance(betas, self.xp.ndarray)
            if ll.ndim > 1:
                # one beta per walker, broadcast across the tries axis
                betas_tmp = self.xp.expand_dims(betas, ll.ndim - 1)
            else:
                betas_tmp = betas
            ll_temp = betas_tmp * ll

        return ll_temp + lp

    def readout_adjustment(self, out_vals, all_vals_prop, aux_all_vals):
        """Read out values from the proposal.

        Allows the user to read out any values from the proposal that may be
        needed elsewhere. Overwrite this in a subclass; the default does nothing.

        Args:
            out_vals (list): ``[logP_out, ll_out, lp_out, log_proposal_pdf_out, log_sum_weights]``.
            all_vals_prop (list): ``[logP, ll, lp, log_proposal_pdf, log_sum_weights]``.
            aux_all_vals (list): ``[aux_logP, aux_ll, aux_lp, aux_log_proposal_pdf, aux_log_sum_weights]``.

        """
        pass

    def get_mt_proposal(
        self,
        coords,
        random,
        args_generate=(),
        kwargs_generate=None,
        args_like=(),
        kwargs_like=None,
        args_prior=(),
        kwargs_prior=None,
        betas=None,
        ll_in=None,
        lp_in=None,
        inds_leaves_rj=None,
        inds_reverse_rj=None,
    ):
        """Make a multiple-try proposal

        Here, ``nwalkers`` refers to all independent walkers which generally
        will mean ``nwalkers * ntemps`` in terms of the rest of the sampler.

        Args:
            coords (np.ndarray): Current coordinates of walkers.
            random (obj): Random generator.
            args_generate (tuple, optional): Additional ``*args`` to pass to generate function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``()``)
            kwargs_generate (dict, optional): Additional ``**kwargs`` to pass to generate function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``None``, treated as ``{}``)
            args_like (tuple, optional): Additional ``*args`` to pass to Likelihood function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``()``)
            kwargs_like (dict, optional): Additional ``**kwargs`` to pass to Likelihood function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``None``, treated as ``{}``)
            args_prior (tuple, optional): Additional ``*args`` to pass to Prior function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``()``)
            kwargs_prior (dict, optional): Additional ``**kwargs`` to pass to Prior function.
                Must overwrite ``get_proposal`` function to use these.
                (default: ``None``, treated as ``{}``)
            betas (np.ndarray, optional): Inverse temperatures passed to the proposal with shape ``(nwalkers,)``.
            ll_in (np.ndarray, optional): Log Likelihood values coming in for current coordinates. Must be provided
                if ``self.rj is True``. If ``self.rj is True``, must be nested.
                Also, for all proposed removals, this value must be the Likelihood with the binary
                removed so all proposals are pretending to add a binary.
                Useful if ``self.independent is True``. (default: ``None``)
            lp_in (np.ndarray, optional): Log Prior values coming in for current coordinates. Must be provided
                if ``self.rj is True``. If ``self.rj is True``, must be nested.
                Also, for all proposed removals, this value must be the Prior with the binary
                removed so all proposals are pretending to add a binary.
                Useful if ``self.independent is True``. (default: ``None``)
            inds_leaves_rj (np.ndarray, optional): Array giving the leaf index of each incoming walker.
                Must be provided if ``self.rj is True``. (default: ``None``)
            inds_reverse_rj (np.ndarray, optional): Array giving the walker index for which proposals are
                reverse proposal removing a leaf.
                Must be provided if ``self.rj is True``. (default: ``None``)

        Returns:
            tuple: (generated points, factors).

        Raises:
            ValueError: Inputs are incorrect.

        """
        # fresh dicts per call instead of shared mutable defaults
        kwargs_generate = {} if kwargs_generate is None else kwargs_generate
        kwargs_like = {} if kwargs_like is None else kwargs_like
        kwargs_prior = {} if kwargs_prior is None else kwargs_prior

        # check if rj and make sure we have all the information in that case
        # (explicit check rather than ``assert`` so it also runs under ``python -O``)
        if self.rj:
            if (
                ll_in is None
                or lp_in is None
                or inds_leaves_rj is None
                or inds_reverse_rj is None
            ):
                raise ValueError(
                    "If using rj, must provide ll_in, lp_in, inds_leaves_rj, and inds_reverse_rj."
                )

            # if using reversible jump, fill first spot with values that are proposed to remove
            fill_tuple = (inds_reverse_rj, np.zeros_like(inds_reverse_rj))
            fill_values = coords[inds_reverse_rj]
        else:
            fill_tuple = None
            fill_values = None

        # generate new points and get log of the proposal probability
        generated_points, log_proposal_pdf = self.special_generate_func(
            coords,
            random,
            *args_generate,
            size=self.num_try,
            fill_values=fill_values,
            fill_tuple=fill_tuple,
            **kwargs_generate
        )

        # compute the Likelihood functions (NaNs replaced by -1e300)
        ll = self.special_like_func(
            generated_points, *args_like, inds_leaves_rj=inds_leaves_rj, **kwargs_like
        )
        ll = self._replace_nan_ll(ll)

        # compute the Prior functions
        lp = self.special_prior_func(
            generated_points, *args_prior, inds_leaves_rj=inds_leaves_rj, **kwargs_prior
        )

        # if rj, make proposal distribution for all other leaves the prior value
        # this will properly cancel the prior with the proposal for leaves that already exist
        if self.rj:
            log_proposal_pdf += lp_in[:, None]

        # get posterior distribution including tempering
        logP = self.get_mt_log_posterior(ll, lp, betas=betas)

        log_importance_weights, log_sum_weights, inds_keep = get_mt_computations(
            logP, log_proposal_pdf, symmetric=self.symmetric, xp=self.xp
        )

        # tuple of index arrays of which try chosen per walker
        inds_tuple = (self.xp.arange(len(inds_keep)), inds_keep)

        if self.rj:
            # this just ensures the cancellation of logP and aux_logP outside of proposal
            inds_tuple[1][inds_reverse_rj] = 0

        # get chosen prior, Likelihood, posterior information
        lp_out = lp[inds_tuple]
        ll_out = ll[inds_tuple]
        logP_out = logP[inds_tuple]

        # store this information for access outside of multiple try part
        self.mt_lp = lp_out
        self.mt_ll = ll_out

        # choose points and get the log of the proposal for storage
        generated_points_out = generated_points[inds_tuple].copy()  # theta^j
        log_proposal_pdf_out = log_proposal_pdf[inds_tuple]

        # prepare auxiliary information based on if it is nested rj, independent, or not.
        # In every branch the chosen slot (inds_tuple) of the auxiliary set must
        # describe the CURRENT state so aux_logP_out is the old posterior.
        if self.independent:
            # if independent, all the tries can be repeated for the auxiliary draws
            aux_ll = ll.copy()
            aux_lp = lp.copy()

            # sub in the generation pdf for the current coordinates
            aux_log_proposal_pdf_sub = self.special_generate_logpdf(coords)

            # set sub ll based on if it is provided
            if ll_in is None:
                aux_ll_sub = self.special_generate_like(coords)
            else:
                assert ll_in.shape[0] == coords.shape[0]
                aux_ll_sub = ll_in

            # set sub lp based on if it is provided
            if lp_in is None:
                aux_lp_sub = self.special_generate_prior(coords)
            else:
                assert lp_in.shape[0] == coords.shape[0]
                aux_lp_sub = lp_in

            # sub in this information from the current coordinates
            aux_ll[inds_tuple] = aux_ll_sub
            aux_lp[inds_tuple] = aux_lp_sub

            # get auxiliary posterior
            aux_logP = self.get_mt_log_posterior(aux_ll, aux_lp, betas=betas)

            # get aux_log_proposal_pdf information
            aux_log_proposal_pdf = log_proposal_pdf.copy()
            aux_log_proposal_pdf[inds_tuple] = aux_log_proposal_pdf_sub

            # set auxiliary weights
            aux_log_importance_weights = aux_logP - aux_log_proposal_pdf

        elif self.rj:
            # in rj, set aux_ll and aux_lp to be repeats of the model with one less leaf
            aux_ll = np.repeat(ll_in[:, None], self.num_try, axis=-1)
            aux_lp = np.repeat(lp_in[:, None], self.num_try, axis=-1)

            # probability is the prior for the existing points
            aux_log_proposal_pdf = aux_lp.copy()

            # get log posterior
            aux_logP = self.get_mt_log_posterior(aux_ll, aux_lp, betas=betas)

            # get importance weights
            aux_log_importance_weights = aux_logP - aux_log_proposal_pdf

        else:
            # Generate auxiliary points around the chosen new points. The chosen
            # slot is filled with the current coordinates (the reverse move's
            # "chosen try"), so aux_logP_out below is the old posterior.
            aux_generated_points, aux_log_proposal_pdf = self.special_generate_func(
                generated_points_out,
                random,
                *args_generate,
                size=self.num_try,
                fill_tuple=inds_tuple,
                fill_values=coords,
                **kwargs_generate
            )

            # get ll, lp, and lP with the same extra arguments as the forward tries
            aux_ll = self.special_like_func(
                aux_generated_points,
                *args_like,
                inds_leaves_rj=inds_leaves_rj,
                **kwargs_like
            )
            aux_ll = self._replace_nan_ll(aux_ll)

            aux_lp = self.special_prior_func(
                aux_generated_points,
                *args_prior,
                inds_leaves_rj=inds_leaves_rj,
                **kwargs_prior
            )

            aux_logP = self.get_mt_log_posterior(aux_ll, aux_lp, betas=betas)

            # set auxiliary weights
            if self.symmetric:
                aux_log_importance_weights = aux_logP
            else:
                aux_log_importance_weights = aux_logP - aux_log_proposal_pdf

        # chosen output old Posteriors
        aux_logP_out = aux_logP[inds_tuple]
        # get sum of log weights
        aux_log_sum_weights = logsumexp(aux_log_importance_weights, axis=-1, xp=self.xp)

        aux_log_proposal_pdf_out = aux_log_proposal_pdf[inds_tuple]
        # this is setup to make clear with the math.
        # setting up factors properly means the
        # final lnpdiff will be effectively be the ratio of the sums
        # of the weights

        # IMPORTANT: logP_out must be subtracted against log_sum_weights before anything else due to -1e300s.
        factors = (
            (aux_logP_out - aux_log_sum_weights)
            - aux_log_proposal_pdf_out
            + aux_log_proposal_pdf_out
        ) - ((logP_out - log_sum_weights) - log_proposal_pdf_out + log_proposal_pdf_out)

        if self.rj:
            # adjust all information for reverse rj proposals
            factors[inds_reverse_rj] *= -1
            self.mt_ll[inds_reverse_rj] = ll_in[inds_reverse_rj]
            self.mt_lp[inds_reverse_rj] = lp_in[inds_reverse_rj]

        # store output information
        self.aux_logP_out = aux_logP_out
        self.logP_out = logP_out
        self.aux_ll = aux_ll
        self.aux_lp = aux_lp

        self.log_sum_weights = log_sum_weights
        self.aux_log_sum_weights = aux_log_sum_weights

        if self.rj:
            self.inds_reverse_rj = inds_reverse_rj
            self.inds_forward_rj = np.delete(
                np.arange(coords.shape[0]), inds_reverse_rj
            )

        # prepare to readout any information the user would like in readout_adjustment
        out_vals = [logP_out, ll_out, lp_out, log_proposal_pdf_out, log_sum_weights]
        all_vals_prop = [logP, ll, lp, log_proposal_pdf, log_sum_weights]
        aux_all_vals = [
            aux_logP,
            aux_ll,
            aux_lp,
            aux_log_proposal_pdf,
            aux_log_sum_weights,
        ]
        self.readout_adjustment(out_vals, all_vals_prop, aux_all_vals)

        return (
            generated_points_out,
            factors,
        )

    def get_proposal(self, branches_coords, random, branches_inds=None, **kwargs):
        """Get proposal

        Args:
            branches_coords (dict): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim] representing
                coordinates for walkers.
            random (object): Current random state object.
            branches_inds (dict, optional): Keys are ``branch_names`` and values are
                np.ndarray[ntemps, nwalkers, nleaves_max] representing which
                leaves are currently being used. (default: ``None``)
            **kwargs (ignored): This is added for compatibility. It is ignored in this function.

        Returns:
            tuple: (Proposed coordinates, factors) -> (dict, np.ndarray)

        Raises:
            ValueError: Input issues.

        """
        # multiple try is only made for one branch here
        if len(branches_coords) > 1:
            raise ValueError("Can only propose change to one model at a time with MT.")

        # get main key
        key_in = list(branches_coords.keys())[0]
        self.key_in = key_in

        coords_all = branches_coords[key_in]

        # get inds information (all leaves active when not given)
        if branches_inds is None:
            branches_inds = {key_in: np.ones(coords_all.shape[:-1], dtype=bool)}
        inds_all = branches_inds[key_in]

        # Make sure for base proposals that there is only one leaf
        if np.any(inds_all.sum(axis=-1) > 1):
            raise ValueError(
                "Non-RJ MT proposals require at most one active leaf per walker."
            )

        ntemps, nwalkers, nleaves_max, _ = coords_all.shape

        # temperature of every active leaf
        betas_here = np.repeat(
            self.temperature_control.betas[:, None],
            np.prod(coords_all.shape[1:-1]),
        ).reshape(inds_all.shape)[inds_all]

        # previous Likelihoods in case proposal is independent
        ll_here = np.repeat(
            self.current_state.log_like[:, :, None], nleaves_max, axis=-1
        ).reshape(inds_all.shape)[inds_all]

        # previous Priors in case proposal is independent
        lp_here = np.repeat(
            self.current_state.log_prior[:, :, None], nleaves_max, axis=-1
        ).reshape(inds_all.shape)[inds_all]

        # get mt proposal
        generated_points, factors = self.get_mt_proposal(
            coords_all[inds_all],
            random,
            betas=betas_here,
            ll_in=ll_here,
            lp_in=lp_here,
        )

        # store this information for access outside
        self.mt_ll = self.mt_ll.reshape(ntemps, nwalkers)
        self.mt_lp = self.mt_lp.reshape(ntemps, nwalkers)

        return (
            {key_in: generated_points.reshape(ntemps, nwalkers, 1, -1)},
            factors.reshape(ntemps, nwalkers),
        )
+
596
+
597
class MultipleTryMoveRJ(MultipleTryMove):
    """Nested reversible-jump (+1/-1 leaf) multiple-try proposal for one branch.

    Uses ``temperature_control``, ``current_state``, ``current_model`` and
    ``get_model_change_proposal``, none of which are defined in this class or in
    :class:`MultipleTryMove`; they must be provided elsewhere in the move
    hierarchy.
    """

    def get_proposal(
        self,
        branches_coords,
        branches_inds,
        nleaves_min_all,
        nleaves_max_all,
        random,
        **kwargs
    ):
        """Make a proposal

        Args:
            branches_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers. Must contain exactly one branch.
            branches_inds (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max]. These are the boolean
                arrays marking which leaves are currently used within each walker.
            nleaves_min_all (dict): Minimum leaf count, keyed by branch name. Must have exactly one entry.
            nleaves_max_all (dict): Maximum leaf count, keyed by branch name. Must have exactly one entry.
            random (object): Current random state of the sampler.
            **kwargs (ignored): For modularity.

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                ``double`` np.ndarray[ntemps, nwalkers, nleaves_max, ndim] containing
                proposed coordinates. Second entry is the new ``inds`` array with
                boolean values flipped for added or removed sources. Third entry
                is the factors associated with the
                proposal necessary for detailed balance. This is effectively
                any term in the detailed balance fraction. +log of factors if
                in the numerator. -log of factors if in the denominator.

        Raises:
            ValueError: Input issues.

        """
        if len(branches_coords) > 1:
            raise ValueError("Can only propose change to one model at a time with MT.")

        # get main key
        key_in = list(branches_coords.keys())[0]
        self.key_in = key_in

        if branches_inds is None:
            raise ValueError("In MT RJ proposal, branches_inds cannot be None.")

        ntemps, nwalkers, nleaves_max, ndim = branches_coords[key_in].shape
        nwalkers_total = ntemps * nwalkers

        # one beta per flattened (temp, walker)
        betas_here = np.repeat(
            self.temperature_control.betas[:, None], nwalkers, axis=-1
        ).flatten()

        # current Likelihood and prior information
        # (``flatten`` copies, so the in-place fills below leave current_state intact)
        ll_here = self.current_state.log_like.flatten()
        lp_here = self.current_state.log_prior.flatten()

        # do rj setup
        assert len(nleaves_min_all) == 1 and len(nleaves_max_all) == 1
        nleaves_min = nleaves_min_all[key_in]
        nleaves_max = nleaves_max_all[key_in]

        if nleaves_min == nleaves_max:
            raise ValueError("MT RJ proposal requires that nleaves_min != nleaves_max.")
        elif nleaves_min > nleaves_max:
            raise ValueError("nleaves_min is greater than nleaves_max. Not allowed.")

        # get the inds adjustment information
        all_inds_for_change = self.get_model_change_proposal(
            branches_inds[key_in], random, nleaves_min, nleaves_max
        )

        # preparing leaf information for going into the proposal
        inds_leaves_rj = np.zeros(nwalkers_total, dtype=int)
        coords_in = np.zeros((nwalkers_total, ndim))
        # flattened walker indices proposing a removal; empty when there are none
        inds_reverse_rj = np.array([], dtype=int)

        # prepare proposal dictionaries
        new_inds = deepcopy(branches_inds)
        q = deepcopy(branches_coords)
        for change, change_inds in all_inds_for_change.items():
            if change not in ["+1", "-1"]:
                raise ValueError("MT RJ is only implemented for +1/-1 moves.")

            # get indices of changing leaves
            temp_inds = change_inds[:, 0]
            walker_inds = change_inds[:, 1]
            leaf_inds = change_inds[:, 2]
            flat_walker_inds = temp_inds * nwalkers + walker_inds

            # leaf index to change and its current coordinates
            inds_leaves_rj[flat_walker_inds] = leaf_inds
            coords_in[flat_walker_inds] = branches_coords[key_in][
                (temp_inds, walker_inds, leaf_inds)
            ]

            # switch leaves on for additions, off for removals
            new_inds[key_in][(temp_inds, walker_inds, leaf_inds)] = change == "+1"

            if change == "-1":
                # which walkers are removing
                inds_reverse_rj = flat_walker_inds

        if len(inds_reverse_rj) > 0:
            # setup reversal coords and inds
            # need to determine Likelihood and prior of removed binaries.
            # this goes into the multiple try proposal as previous ll and lp
            temp_reverse_coords = {}
            temp_reverse_inds = {}

            for key in self.current_state.branches:
                (
                    ntemps_tmp,
                    nwalkers_tmp,
                    nleaves_max_tmp,
                    ndim_tmp,
                ) = self.current_state.branches[key].shape
                nwalkers_total_tmp = ntemps_tmp * nwalkers_tmp

                # coords from reversal, shaped as a single "temperature" of n_reverse walkers
                temp_reverse_coords[key] = self.current_state.branches[
                    key
                ].coords.reshape(nwalkers_total_tmp, nleaves_max_tmp, ndim_tmp)[
                    inds_reverse_rj
                ][
                    None, :
                ]

                # which inds array to use (the updated one for the proposed branch)
                inds_tmp_here = (
                    new_inds[key]
                    if key == key_in
                    else self.current_state.branches[key].inds
                )
                temp_reverse_inds[key] = inds_tmp_here.reshape(
                    nwalkers_total_tmp, nleaves_max_tmp
                )[inds_reverse_rj][None, :]

            # Prior of the reverse (leaf-removed) state. Presumably shaped
            # (1, n_reverse) like the coords' leading dims -- the original took [0].
            lp_reverse = self.current_model.compute_log_prior_fn(
                temp_reverse_coords, inds=temp_reverse_inds
            )
            # Pass the reverse-state prior (matching the coords layout) rather than
            # the flattened current-state prior of every walker.
            ll_reverse_here = self.current_model.compute_log_like_fn(
                temp_reverse_coords, inds=temp_reverse_inds, logp=lp_reverse
            )[0]
            lp_reverse_here = lp_reverse[0]

            # fill the here values
            ll_here[inds_reverse_rj] = ll_reverse_here
            lp_here[inds_reverse_rj] = lp_reverse_here

        # get mt proposal
        generated_points, factors = self.get_mt_proposal(
            coords_in,
            random,
            betas=betas_here,
            ll_in=ll_here,
            lp_in=lp_here,
            inds_leaves_rj=inds_leaves_rj,
            inds_reverse_rj=inds_reverse_rj,
        )

        # for reading outside
        self.mt_ll = self.mt_ll.reshape(ntemps, nwalkers)
        self.mt_lp = self.mt_lp.reshape(ntemps, nwalkers)

        # update coordinates of added leaves. generated_points holds one point per
        # flattened walker, so index it by each walker's flattened position
        # (independent of the order in which "+1" entries are listed).
        if "+1" in all_inds_for_change:
            add_inds = all_inds_for_change["+1"]
            temp_inds = add_inds[:, 0]
            walker_inds = add_inds[:, 1]
            leaf_inds = add_inds[:, 2]
            q[key_in][(temp_inds, walker_inds, leaf_inds)] = generated_points[
                temp_inds * nwalkers + walker_inds
            ]

        return q, new_inds, factors.reshape(ntemps, nwalkers)