eryn 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
eryn/moves/move.py ADDED
@@ -0,0 +1,703 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from ..state import BranchSupplemental
4
+ import numpy as np
5
+
6
+ from copy import deepcopy
7
+
8
+ try:
9
+ import cupy as cp
10
+ except (ModuleNotFoundError, ImportError):
11
+ import numpy as cp
12
+
13
+ __all__ = ["Move"]
14
+
15
+
16
class Move(object):
    """Parent class for proposals or "moves"

    Args:
        temperature_control (:class:`tempering.TemperatureControl`, optional):
            This object controls the tempering. It is passed to the parent class
            to moves so that all proposals can share and use temperature settings.
            (default: ``None``)
        periodic (:class:`eryn.utils.PeriodicContainer`, optional):
            This object holds periodic information and methods for periodic parameters. It is passed to the parent class
            to moves so that all proposals can share and use periodic information.
            (default: ``None``)
        gibbs_sampling_setup (str, tuple, dict, or list, optional): This sets the Gibbs Sampling setup if
            desired. The Gibbs sampling setup is completely customizable down to the leaf and parameters.
            All of the separate Gibbs sampling splits will be run within 1 call to this proposal.
            If ``None``, run all branches and all parameters. If ``str``, run all parameters within the
            branch given as the string. To enter a branch with a specific set of parameters, you can
            provide a 2-tuple with the first entry as the branch name and the second entry as a 2D
            boolean array of shape ``(nleaves_max, ndim)`` that indicates which leaves and/or parameters
            you want to run. ``None`` can also be entered in the second entry if all parameters are to be run.
            A dictionary is also possible with keys as branch names and values as the same 2D boolean array
            of shape ``(nleaves_max, ndim)`` that indicates which leaves and/or parameters
            you want to run. ``None`` can also be entered in the value of the dictionary
            if all parameters are to be run. If multiple keys are provided in the dictionary, those
            branches will be run simultaneously in the proposal as one iteration of the proposing loop.
            The final option is a list. This is how you make sure to run all the Gibbs splits. Each entry
            of the list can be a string, 2-tuple, or dictionary as described above. The list controls
            the order in which all of these splits are run. With ``is_rj=True`` only branch-level
            entries (strings, or tuple/dict entries whose index is ``None``) are allowed.
            (default: ``None``)
        prevent_swaps (bool, optional): If ``True``, do not perform temperature swaps in this move.
            (default: ``False``)
        skip_supp_names_update (list, optional): List of names (`str`), that can be in any
            :class:`eryn.state.BranchSupplemental`,
            to skip when updating states (:func:`Move.update`). This is useful if a
            large amount of memory is stored in the branch supplementals.
            (default: ``None``, meaning an empty list)
        is_rj (bool, optional): If using RJ, this should be ``True``. (default: ``False``)
        use_gpu (bool, optional): If ``True``, use ``CuPy`` for computations.
            Use ``NumPy`` if ``use_gpu == False``. (default: ``False``)
        random_seed (int, optional): Set the random seed in ``CuPy/NumPy`` if not ``None``.
            (default: ``None``)
        **kwargs: Ignored by this base class.

    Raises:
        ValueError: Incorrect inputs.

    Attributes:
        Note: The named keyword arguments above (except ``is_rj``, ``gibbs_sampling_setup``
            and ``random_seed``) are stored as attributes of the same name.
        num_proposals (int): the number of times this move has been run. This is needed to
            compute the acceptance fraction.
        gibbs_sampling_setup (list): All of the Gibbs sampling splits as described above.
        xp (obj): ``NumPy`` or ``CuPy``.
        use_gpu (bool): Whether ``Cupy`` (``True``) is used or not (``False``).

    """

    def __init__(
        self,
        temperature_control=None,
        periodic=None,
        gibbs_sampling_setup=None,
        prevent_swaps=False,
        skip_supp_names_update=None,
        is_rj=False,
        use_gpu=False,
        random_seed=None,
        **kwargs
    ):
        # store all information
        self.temperature_control = temperature_control
        self.periodic = periodic
        # ``None`` default avoids one mutable list being shared by every instance
        self.skip_supp_names_update = (
            [] if skip_supp_names_update is None else skip_supp_names_update
        )
        self.prevent_swaps = prevent_swaps

        self._initialize_branch_setup(gibbs_sampling_setup, is_rj=is_rj)

        # keep track of the number of proposals
        self.num_proposals = 0
        self.time = 0

        self.use_gpu = use_gpu

        # set the random seed of the library if desired
        if random_seed is not None:
            self.xp.random.seed(random_seed)

    @property
    def use_gpu(self):
        """Whether ``CuPy`` (``True``) or ``NumPy`` (``False``) is used."""
        return self._use_gpu

    @use_gpu.setter
    def use_gpu(self, use_gpu):
        self._use_gpu = use_gpu

    @property
    def xp(self):
        """Array module in use: ``CuPy`` if ``use_gpu`` else ``NumPy``.

        Raises:
            ValueError: ``use_gpu`` has not been set yet.
        """
        # getattr so an unset attribute gives the intended ValueError, not AttributeError
        if getattr(self, "_use_gpu", None) is None:
            raise ValueError("use_gpu has not been set.")
        return cp if self.use_gpu else np

    def _initialize_branch_setup(self, gibbs_sampling_setup, is_rj=False):
        """Initialize the gibbs setup properly.

        Normalizes ``gibbs_sampling_setup`` into a list of Gibbs splits stored in
        ``self.gibbs_sampling_setup`` (the original input is kept in
        ``self.gibbs_sampling_setup_input``). It also fills ``self.branch_names_run_all``
        and ``self.inds_run_all`` with one entry per split. Without a setup, both are
        ``[None]``.

        Args:
            gibbs_sampling_setup (str, tuple, dict, list, or None): See class docstring.
            is_rj (bool, optional): If ``True``, only branch-level splits are allowed.
                (default: ``False``)

        Raises:
            ValueError: Incorrect inputs.

        """
        self.gibbs_sampling_setup = gibbs_sampling_setup

        message_rj = """inputting gibbs indexing at the leaf/parameter level is not allowed
        with an RJ proposal. Only branch names."""

        message_non_rj = """When inputing gibbs indexing and using a 2-tuple, second item must be None or 2D np.ndarray of shape (nleaves_max, ndim)."""

        def _check_inds(inds_value):
            # ``None`` means "all leaves/parameters" and is always valid
            if inds_value is None:
                return
            # leaf/parameter-level indexing is incompatible with RJ
            if is_rj:
                raise ValueError(message_rj)
            if not (isinstance(inds_value, np.ndarray) and inds_value.ndim == 2):
                raise ValueError(message_non_rj)

        if self.gibbs_sampling_setup is None:
            # no Gibbs sampling
            self.branch_names_run_all = [None]
            self.inds_run_all = [None]
            return

        if not isinstance(self.gibbs_sampling_setup, (str, tuple, list, dict)):
            raise ValueError(
                "gibbs_sampling_setup must be string, dict, tuple, or list."
            )

        if not isinstance(self.gibbs_sampling_setup, list):
            self.gibbs_sampling_setup = [self.gibbs_sampling_setup]

        gibbs_sampling_setup_tmp = []
        for item in self.gibbs_sampling_setup:
            if isinstance(item, str):
                # strings indicate single branch all parameters
                gibbs_sampling_setup_tmp.append(item)

            elif isinstance(item, tuple):
                # tuple is one branch with a split in the parameters
                if len(item) != 2:
                    raise ValueError(message_non_rj)
                _check_inds(item[1])
                gibbs_sampling_setup_tmp.append(item)

            elif isinstance(item, dict):
                # dict can include multiple branches and parameter splits;
                # these will all be run in one iteration
                tmp = []
                for key, value in item.items():
                    _check_inds(value)
                    tmp.append((key, value))
                gibbs_sampling_setup_tmp.append(tmp)

            else:
                raise ValueError(
                    "If providing a list for gibbs_sampling_setup, each item needs to be a string, tuple, or dict."
                )

        # copy the original for information if needed
        self.gibbs_sampling_setup_input = deepcopy(self.gibbs_sampling_setup)

        # store as the setup that all proposals will follow
        self.gibbs_sampling_setup = gibbs_sampling_setup_tmp

        # sort into branch names and indices to be run, one entry per Gibbs split
        branch_names_run_all = []
        inds_run_all = []
        for proposal_iteration in self.gibbs_sampling_setup:
            if isinstance(proposal_iteration, str):
                branch_names_run_all.append([proposal_iteration])
                inds_run_all.append([None])
            elif isinstance(proposal_iteration, tuple):
                branch_names_run_all.append([proposal_iteration[0]])
                inds_run_all.append([proposal_iteration[1]])
            else:
                # list of (branch name, inds) pairs built from a dict above
                branch_names_run_all.append([name for name, _ in proposal_iteration])
                inds_run_all.append([inds for _, inds in proposal_iteration])

        # store information
        self.branch_names_run_all = branch_names_run_all
        self.inds_run_all = inds_run_all

    def gibbs_sampling_setup_iterator(self, all_branch_names):
        """Iterate through the gibbs splits as a generator

        Args:
            all_branch_names (list): List of all branch names.

        Yields:
            2-tuple: Gibbs sampling split.
                First entry is the branch names to run and the second entry is the index
                into the leaves/parameters for this Gibbs split.

        """
        for branch_names_run, inds_run in zip(
            self.branch_names_run_all, self.inds_run_all
        ):
            # no Gibbs setup: run every branch with all of its parameters
            if branch_names_run is None:
                branch_names_run = all_branch_names
                inds_run = [None for _ in branch_names_run]
            yield (branch_names_run, inds_run)

    @staticmethod
    def _gibbs_leaf_mask(branch_inds, ir):
        """Return the leaves of ``branch_inds`` selected by the Gibbs index ``ir``.

        ``ir`` is ``None`` (select everything present, returns ``branch_inds`` itself)
        or a 2D array over ``(leaves, parameters)``.
        """
        if ir is None:
            return branch_inds
        mask = np.zeros_like(branch_inds, dtype=bool)
        # flatten coordinates to the leaves dimension: a leaf is kept if any parameter is
        ir_keep = ir.astype(int).sum(axis=-1).astype(bool)
        mask[:, :, ir_keep] = True
        # make sure leaves that are actually not there are not counted
        mask[~branch_inds] = False
        return mask

    def setup_proposals(
        self, branch_names_run, inds_run, branches_coords, branches_inds
    ):
        """Setup proposals when gibbs sampling.

        Get inputs into the proposal including Gibbs split information.

        Args:
            branch_names_run (list): List of branch names to run concurrently.
            inds_run (list): List of ``inds`` arrays including Gibbs sampling information.
            branches_coords (dict): Dictionary of coordinate arrays for all branches.
            branches_inds (dict): Dictionary of ``inds`` arrays for all branches.

        Returns:
            tuple: (coords, inds, at_least_one_proposal)
                * Coords including Gibbs sampling info.
                * ``inds`` including Gibbs sampling info.
                * ``at_least_one_proposal`` is boolean. It is passed out to
                  indicate there is at least one leaf available for the requested branch names.

        """
        inds_going_for_proposal = {}
        coords_going_for_proposal = {}

        at_least_one_proposal = False
        for bnr, ir in zip(branch_names_run, inds_run):
            inds_going_for_proposal[bnr] = self._gibbs_leaf_mask(
                branches_inds[bnr], ir
            )

            if np.any(inds_going_for_proposal[bnr]):
                at_least_one_proposal = True

            coords_going_for_proposal[bnr] = branches_coords[bnr]

        return (
            coords_going_for_proposal,
            inds_going_for_proposal,
            at_least_one_proposal,
        )

    def cleanup_proposals_gibbs(
        self,
        branch_names_run,
        inds_run,
        q,
        branches_coords,
        new_inds=None,
        branches_inds=None,
        new_branch_supps=None,
        branches_supplemental=None,
    ):
        """Set all not Gibbs-sampled parameters back

        ``q``, ``new_inds`` and ``new_branch_supps`` are modified in place.

        Args:
            branch_names_run (list): List of branch names to run concurrently.
            inds_run (list): List of ``inds`` arrays including Gibbs sampling information.
            q (dict): Dictionary of new coordinate arrays for all proposal branches.
            branches_coords (dict): Dictionary of old coordinate arrays for all branches.
            new_inds (dict, optional): Dictionary of new inds arrays for all proposal branches.
            branches_inds (dict, optional): Dictionary of old inds arrays for all branches.
            new_branch_supps (dict, optional): Dictionary of new branches supplemental for all proposal branches.
            branches_supplemental (dict, optional): Dictionary of old branches supplemental for all branches.

        """
        # add back any parameters that are fixed for this round
        for bnr, ir in zip(branch_names_run, inds_run):
            if ir is not None:
                q[bnr][:, :, ~ir] = branches_coords[bnr][:, :, ~ir]

        # add other models that were not included
        for key, value in branches_coords.items():
            if key not in q:
                q[key] = value.copy()
            if new_inds is not None and key not in new_inds:
                assert branches_inds is not None
                new_inds[key] = branches_inds[key].copy()

            if new_branch_supps is not None and key not in new_branch_supps:
                assert branches_supplemental is not None
                new_branch_supps[key] = branches_supplemental[key].copy()

    def ensure_ordering(self, correct_key_order, q, new_inds, new_branch_supps):
        """Ensure proper order of key in dictionaries.

        Args:
            correct_key_order (list): Keys in correct order.
            q (dict): Dictionary of new coordinate arrays for all branches.
            new_inds (dict): Dictionary of new inds arrays for all branches.
            new_branch_supps (dict or None): Dictionary of new branches supplemental for all proposal branches.

        Returns:
            Tuple: (q, new_inds, new_branch_supps) in correct key order.

        """
        if list(q.keys()) != correct_key_order:
            q = {key: q[key] for key in correct_key_order}

        if list(new_inds.keys()) != correct_key_order:
            new_inds = {key: new_inds[key] for key in correct_key_order}

        if (
            new_branch_supps is not None
            and list(new_branch_supps.keys()) != correct_key_order
        ):
            # missing branches get ``None``
            temp = {key: None for key in correct_key_order}
            for key in new_branch_supps:
                temp[key] = new_branch_supps[key]
            new_branch_supps = deepcopy(temp)

        return q, new_inds, new_branch_supps

    def fix_logp_gibbs(self, branch_names_run, inds_run, logp, inds):
        """Adjust ``logp`` in place for walkers with no leaves in this Gibbs split.

        Walkers with leaves elsewhere but none in this split get ``-np.inf``
        (no use in running them). Walkers with no leaves at all get ``0.0``.

        Args:
            branch_names_run (list): List of branch names to run concurrently.
            inds_run (list): List of ``inds`` arrays including Gibbs sampling information.
            logp (np.ndarray): Log of the prior going into final posterior computation.
            inds (dict): Dictionary of ``inds`` arrays for all branches.

        """
        total_leaves = np.zeros_like(logp, dtype=int)
        total_leaves_here = np.zeros_like(logp, dtype=int)
        for bnr, ir in zip(branch_names_run, inds_run):
            leaves_here = self._gibbs_leaf_mask(inds[bnr], ir).sum(axis=-1)
            total_leaves += leaves_here
            total_leaves_here += leaves_here

        for name, inds_val in inds.items():
            if name not in branch_names_run:
                total_leaves += inds_val.sum(axis=-1)

        # no use in running because no change
        logp[(total_leaves != 0) & (total_leaves_here == 0)] = -np.inf
        # there is nothing in the model currently
        logp[(total_leaves == 0) & (total_leaves_here == 0)] = 0.0

    @property
    def accepted(self):
        """Accepted counts for this move.

        Raises:
            ValueError: The counts have never been assigned.
        """
        # ``_accepted`` is not created in ``__init__``; treat "missing" like ``None``
        accepted = getattr(self, "_accepted", None)
        if accepted is None:
            raise ValueError(
                "accepted must be inititalized with the init_accepted function if you want to use it."
            )
        return accepted

    @accepted.setter
    def accepted(self, accepted):
        assert isinstance(accepted, np.ndarray)
        self._accepted = accepted

    @property
    def acceptance_fraction(self):
        """Acceptance fraction for this move."""
        return self.accepted / self.num_proposals

    @property
    def temperature_control(self):
        """Temperature controller"""
        return self._temperature_control

    @temperature_control.setter
    def temperature_control(self, temperature_control):
        self._temperature_control = temperature_control

        # use the setting of the temperature control to determine which log posterior function to use
        # tempered or basic
        if temperature_control is None:
            self.compute_log_posterior = self.compute_log_posterior_basic
        else:
            self.compute_log_posterior = (
                self.temperature_control.compute_log_posterior_tempered
            )

            self.ntemps = self.temperature_control.ntemps

    def compute_log_posterior_basic(self, logl, logp):
        r"""Compute the log of posterior

        :math:`\log{P} = \log{L} + \log{p}`

        This method is to mesh with the tempered log posterior computation.

        Args:
            logl (np.ndarray[ntemps, nwalkers]): Log-likelihood values.
            logp (np.ndarray[ntemps, nwalkers]): Log-prior values.

        Returns:
            np.ndarray[ntemps, nwalkers]: Log-Posterior values.
        """
        return logl + logp

    def tune(self, state, accepted):
        """Tune a proposal

        This is a place holder for tuning.

        Args:
            state (:class:`eryn.state.State`): Current state of sampler.
            accepted (np.ndarray[ntemps, nwalkers]): Accepted values for last pass
                through proposal.

        """
        pass

    def _blend_supplemental(self, new_value, old_value, accepted, reference):
        """Take ``new_value`` where ``accepted`` and ``old_value`` elsewhere.

        For non-object dtypes of ``reference``, a fresh copy of ``accepted`` gets
        trailing axes until it has ``reference.ndim`` dims. Arithmetic masking is
        used (not ``np.where``) so object arrays and ``CuPy`` arrays keep working.
        """
        # fresh copy per entry: expanded dims must not leak to the next entry
        accepted_here = accepted.copy()
        if reference.dtype.name != "object":
            for _ in range(reference.ndim - accepted_here.ndim):
                accepted_here = np.expand_dims(accepted_here, (-1,))
        try:
            return new_value * (accepted_here) + old_value * (~accepted_here)
        except TypeError:
            # for gpus: move the mask into the active array module
            return new_value * (self.xp.asarray(accepted_here)) + old_value * (
                self.xp.asarray(~accepted_here)
            )

    def update(self, old_state, new_state, accepted, subset=None):
        """Update a given subset of the ensemble with an accepted proposal

        This class was updated from ``emcee`` to handle the added structure
        of Eryn.

        Args:
            old_state (:class:`eryn.state.State`): State with current information.
                New information is added to this state.
            new_state (:class:`eryn.state.State`): State with information from proposed
                points, laid out like ``old_state`` restricted to ``subset``.
            accepted (np.ndarray[ntemps, nwalkers]): A boolean array indicating
                which walkers were accepted.
            subset (np.ndarray[ntemps, nsubset], optional): Integer index array selecting,
                along the walker axis (axis 1), which walkers of ``old_state`` the entries of
                ``new_state`` correspond to. This can be used, for example, when updating only the
                primary ensemble in a :class:`RedBlueMove`. ``None`` selects all walkers.
                (default: ``None``)

        Returns:
            :class:`eryn.state.State`: ``old_state`` with accepted points added from ``new_state``.

        Raises:
            ValueError: ``new_state`` has blobs but ``old_state`` does not.

        """
        if subset is None:
            # subset of everything
            subset = np.tile(
                np.arange(old_state.log_like.shape[1]), (old_state.log_like.shape[0], 1)
            )

        # each quantity:
        # 1. take the subset of old values (take_along_axis)
        # 2. choose new where accepted, old otherwise
        # 3. put the combined subset back into the full arrays (put_along_axis)
        # np.where is used for plain arrays so that an infinite value on the
        # non-chosen side cannot turn into NaN (inf * False == nan).
        accepted_temp = np.take_along_axis(accepted, subset, axis=1)

        # log likelihood
        old_log_likes = np.take_along_axis(old_state.log_like, subset, axis=1)
        np.put_along_axis(
            old_state.log_like,
            subset,
            np.where(accepted_temp, new_state.log_like, old_log_likes),
            axis=1,
        )

        # log prior
        old_log_priors = np.take_along_axis(old_state.log_prior, subset, axis=1)
        new_log_priors = new_state.log_prior.copy()
        # historical behavior: infinite proposed priors are stored as 0.0
        new_log_priors[np.isinf(new_log_priors)] = 0.0
        np.put_along_axis(
            old_state.log_prior,
            subset,
            np.where(accepted_temp, new_log_priors, old_log_priors),
            axis=1,
        )

        # inds
        for name, branch in old_state.branches.items():
            old_inds = np.take_along_axis(branch.inds, subset[:, :, None], axis=1)
            np.put_along_axis(
                branch.inds,
                subset[:, :, None],
                np.where(
                    accepted_temp[:, :, None],
                    new_state.branches[name].inds,
                    old_inds,
                ),
                axis=1,
            )

        # branch-level supplementals
        run_branches_supplemental = any(
            value is not None for value in old_state.branches_supplemental.values()
        )
        if run_branches_supplemental:
            for name, branch in old_state.branches.items():
                if branch.branch_supplemental is None:
                    continue

                old_branch_supplemental = branch.branch_supplemental.take_along_axis(
                    subset[:, :, None],
                    axis=1,
                    skip_names=self.skip_supp_names_update,
                )
                new_branch_supplemental = new_state.branches[
                    name
                ].branch_supplemental[:]

                tmp = {}
                for key in old_branch_supplemental:
                    # need to check to see if we should skip anything
                    if key in self.skip_supp_names_update:
                        continue
                    tmp[key] = self._blend_supplemental(
                        new_branch_supplemental[key],
                        old_branch_supplemental[key],
                        accepted_temp,
                        new_branch_supplemental[key],
                    )

                temp_change_branch_supplemental = BranchSupplemental(
                    tmp,
                    base_shape=new_state.branches_supplemental[name].base_shape,
                    copy=True,
                )
                branch.branch_supplemental.put_along_axis(
                    subset[:, :, None],
                    temp_change_branch_supplemental[:],
                    axis=1,
                )

        # sampler level supplemental
        if old_state.supplemental is not None:
            old_supplemental = old_state.supplemental.take_along_axis(subset, axis=1)
            new_supplemental = new_state.supplemental[:]

            temp_change_supplemental = {}
            for name in old_supplemental:
                # make sure to get rid of specific supps if requested
                if name in self.skip_supp_names_update:
                    continue
                temp_change_supplemental[name] = self._blend_supplemental(
                    new_supplemental[name],
                    old_supplemental[name],
                    accepted_temp,
                    old_supplemental[name],
                )
            old_state.supplemental.put_along_axis(
                subset, temp_change_supplemental, axis=1
            )

        # coords: copy-then-fill rather than arithmetic, due to issue of adding NaNs
        for name, branch in old_state.branches.items():
            # take_along_axis returns a new array, safe to modify
            temp_change_coords = np.take_along_axis(
                branch.coords, subset[:, :, None, None], axis=1
            )
            temp_change_coords[accepted_temp] = new_state.branches[name].coords[
                accepted_temp
            ]
            np.put_along_axis(
                branch.coords, subset[:, :, None, None], temp_change_coords, axis=1
            )

        # take care of blobs
        if new_state.blobs is not None:
            if old_state.blobs is None:
                raise ValueError(
                    "If you start sampling with a given log_like, "
                    "you also need to provide the current list of "
                    "blobs at that position."
                )

            old_blobs = np.take_along_axis(old_state.blobs, subset[:, :, None], axis=1)
            np.put_along_axis(
                old_state.blobs,
                subset[:, :, None],
                np.where(accepted_temp[:, :, None], new_state.blobs, old_blobs),
                axis=1,
            )

        return old_state