lisaanalysistools-1.0.0-cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of lisaanalysistools might be problematic.

Files changed (37)
  1. lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
  2. lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
  3. lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
  4. lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
  6. lisatools/__init__.py +0 -0
  7. lisatools/_version.py +4 -0
  8. lisatools/analysiscontainer.py +438 -0
  9. lisatools/cutils/detector.cpython-312-darwin.so +0 -0
  10. lisatools/datacontainer.py +292 -0
  11. lisatools/detector.py +410 -0
  12. lisatools/diagnostic.py +976 -0
  13. lisatools/glitch.py +193 -0
  14. lisatools/sampling/__init__.py +0 -0
  15. lisatools/sampling/likelihood.py +882 -0
  16. lisatools/sampling/moves/__init__.py +0 -0
  17. lisatools/sampling/moves/gbgroupstretch.py +53 -0
  18. lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
  19. lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
  20. lisatools/sampling/moves/gbspecialstretch.py +1836 -0
  21. lisatools/sampling/moves/mbhspecialmove.py +286 -0
  22. lisatools/sampling/moves/placeholder.py +16 -0
  23. lisatools/sampling/moves/skymodehop.py +110 -0
  24. lisatools/sampling/moves/specialforegroundmove.py +564 -0
  25. lisatools/sampling/prior.py +508 -0
  26. lisatools/sampling/stopping.py +320 -0
  27. lisatools/sampling/utility.py +324 -0
  28. lisatools/sensitivity.py +888 -0
  29. lisatools/sources/__init__.py +0 -0
  30. lisatools/sources/emri/__init__.py +1 -0
  31. lisatools/sources/emri/tdiwaveform.py +72 -0
  32. lisatools/stochastic.py +291 -0
  33. lisatools/utils/__init__.py +0 -0
  34. lisatools/utils/constants.py +40 -0
  35. lisatools/utils/multigpudataholder.py +730 -0
  36. lisatools/utils/pointeradjust.py +106 -0
  37. lisatools/utils/utility.py +240 -0
lisatools/sampling/moves/specialforegroundmove.py
@@ -0,0 +1,564 @@
+ # -*- coding: utf-8 -*-
+
+ from copy import deepcopy
+ import numpy as np
+ from scipy import stats
+ import warnings
+ import time
+
+ try:
+     import cupy as xp
+
+     gpu_available = True
+ except ModuleNotFoundError:
+     import numpy as xp
+
+     gpu_available = False
+
+ from eryn.moves import StretchMove
+ from eryn.prior import ProbDistContainer
+ from eryn.state import State
+
+
+ __all__ = ["GBForegroundSpecialMove"]
+
+
+ # MHMove needs to be to the left here to overwrite GBBruteRejectionRJ RJ proposal method
+ class GBForegroundSpecialMove(StretchMove):
+     """Stretch-move proposal for the "psd" and "galfor" (galactic-foreground) branches."""
+
+     def __init__(
+         self,
+         gb,
+         priors,
+         start_freq_ind,
+         data_length,
+         mgh,
+         fd,
+         band_edges,
+         *args,
+         waveform_kwargs={},
+         noise_kwargs={},
+         parameter_transforms=None,
+         search=False,
+         search_samples=None,
+         search_snrs=None,
+         search_snr_lim=None,
+         search_snr_accept_factor=1.0,
+         take_max_ll=False,
+         global_template_builder=None,
+         psd_func=None,
+         provide_betas=False,
+         alternate_priors=None,
+         batch_size=5,
+         **kwargs
+     ):
+         StretchMove.__init__(self, **kwargs)
+
+         self.time = 0
+         self.greater_than_1e0 = 0
+         self.name = "GBForegroundSpecialMove".lower()
+
+         # TODO: make priors optional like special generate function?
+         for key in priors:
+             if not isinstance(priors[key], ProbDistContainer):
+                 raise ValueError("Priors need to be eryn.prior.ProbDistContainer objects.")
+         self.priors = priors
+         self.gb = gb
+         self.provide_betas = provide_betas
+         self.batch_size = batch_size
+         self.stop_here = True
+
+         # use gpu from template generator
+         self.use_gpu = gb.use_gpu
+         if self.use_gpu:
+             self.xp = xp
+             self.mempool = self.xp.get_default_memory_pool()
+         else:
+             self.xp = np
+
+         self.band_edges = band_edges
+         self.num_bands = len(band_edges) - 1
+         self.start_freq_ind = start_freq_ind
+         self.data_length = data_length
+         self.waveform_kwargs = waveform_kwargs
+         self.noise_kwargs = noise_kwargs
+         self.parameter_transforms = parameter_transforms
+         self.psd_func = psd_func
+         self.fd = fd
+         self.df = (fd[1] - fd[0]).item()
+         self.mgh = mgh
+         self.search = search
+         self.global_template_builder = global_template_builder
+
+         if search_snrs is not None:
+             if search_snr_lim is None:
+                 search_snr_lim = 0.1
+
+             assert len(search_samples) == len(search_snrs)
+
+         self.search_samples = search_samples
+         self.search_snrs = search_snrs
+         self.search_snr_lim = search_snr_lim
+         self.search_snr_accept_factor = search_snr_accept_factor
+
+         self.take_max_ll = take_max_ll
+
+     def propose(self, model, state):
+         """Use the move to generate a proposal and compute the acceptance
+
+         Args:
+             model (:class:`eryn.model.Model`): Carrier of sampler information.
+             state (:class:`State`): Current state of the sampler.
+
+         Returns:
+             :class:`State`: State of sampler after proposal is complete.
+
+         """
+         # st = time.perf_counter()
+         self.xp.cuda.runtime.setDevice(self.mgh.gpus[0])
+
+         # np.random.seed(10)
+         # print("start stretch")
+         # st = time.perf_counter()
+         # Check that the dimensions are compatible.
+         ntemps, nwalkers, nleaves, ndim = state.branches["galfor"].shape
+
+         self.nwalkers = nwalkers
+         # TODO: deal with more intensive acceptance fractions
+         # Run any move-specific setup.
+         self.setup(state.branches)
+
+         new_state = State(state)  # , copy=True)
+         self.mempool.free_all_blocks()
+
+         self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()
+
+         # data should not be whitened
+
+         # Split the ensemble in half and iterate over these two halves.
+         accepted = np.zeros((ntemps, nwalkers), dtype=bool)
+
+         ntemps, nwalkers, nleaves_max, ndim = state.branches_coords["gb"].shape
+
+         split_inds = np.zeros(nwalkers, dtype=int)
+         split_inds[1::2] = 1
+         np.random.shuffle(split_inds)
+
+         current_coords_galfor = new_state.branches["galfor"].coords.copy()
+         """et = time.perf_counter()
+         print("setup", (et - st))
+         st = time.perf_counter()"""
+         for split in range(2):
+             # st = time.perf_counter()
+             split_here = split_inds == split
+             walkers_keep = self.xp.arange(nwalkers)[split_here]
+
+             points_to_move = {key: new_state.branches[key].coords[:, split_here] for key in ["psd", "galfor"]}
+             points_for_move = {key: [new_state.branches[key].coords[:, ~split_here]] for key in ["psd", "galfor"]}
+
+             q, factors = self.get_proposal(
+                 points_to_move,
+                 points_for_move,
+                 model.random
+             )
+
+             temp_part_general = np.repeat(np.arange(ntemps)[:, None], nwalkers, axis=-1)[:, split_here].flatten()
+             walker_part_general = np.tile(np.arange(nwalkers), (ntemps, 1))[:, split_here].flatten()
+
+             # get logp
+             logp_here = (self.priors["psd"].logpdf(q["psd"].reshape(-1, q["psd"].shape[-1])) + self.priors["galfor"].logpdf(q["galfor"].reshape(-1, q["galfor"].shape[-1]))).reshape((ntemps, int(nwalkers / 2)))
+
+             prev_logp_here = (self.priors["psd"].logpdf(new_state.branches_coords["psd"][:, split_here].reshape(-1, q["psd"].shape[-1])) + self.priors["galfor"].logpdf(new_state.branches_coords["galfor"][:, split_here].reshape(-1, q["galfor"].shape[-1]))).reshape((ntemps, int(nwalkers / 2)))
+
+             bad = np.isinf(logp_here.flatten())
+
+             data_index_tmp = np.asarray((temp_part_general * self.nwalkers + walker_part_general).astype(xp.int32))
+
+             data_index = self.mgh.get_mapped_indices(data_index_tmp)
+
+             data_index_in = data_index[~bad]
+
+             psd_params = q["psd"].reshape(-1, q["psd"].shape[-1])[~bad]
+             foreground_params = q["galfor"].reshape(-1, q["galfor"].shape[-1])[~bad]
+
+             self.mgh.set_psd_vals(psd_params, foreground_params=foreground_params, overall_inds=data_index_in)
+
+             logl_temp = self.mgh.get_ll(include_psd_info=True, overall_inds=data_index_in)
+
+             logl = np.full((ntemps, int(nwalkers / 2)), -1e300)
+             logl[~bad.reshape(ntemps, -1)] = logl_temp
+
+             prev_logl = new_state.log_like[:, split_here]
+             prev_logp = new_state.log_prior[:, split_here]
+
+             logp = prev_logp + logp_here - prev_logp_here
+
+             logP = self.compute_log_posterior(logl, logp)
+             prev_logP = self.compute_log_posterior(prev_logl, prev_logp)
+
+             lnpdiff = factors + logP - prev_logP
+
+             keep = lnpdiff > np.log(model.random.rand(ntemps, int(nwalkers / 2)))
+
+             temp_inds_keep = temp_part_general[keep.flatten()]
+             walker_inds_keep = walker_part_general[keep.flatten()]
+
+             accepted[temp_inds_keep, walker_inds_keep] = True
+
+             new_state.log_like[temp_inds_keep, walker_inds_keep] = logl[keep]
+             new_state.log_prior[temp_inds_keep, walker_inds_keep] = logp[keep]
+
+             for key in ["psd", "galfor"]:
+                 new_state.branches[key].coords[temp_inds_keep, walker_inds_keep, np.zeros_like(walker_inds_keep)] = q[key][keep][:, 0]
+
+             temp_inds_fix = temp_part_general[~keep.flatten()]
+             walker_inds_fix = walker_part_general[~keep.flatten()]
+
+             # return unaccepted psds
+             data_index_fix = data_index[~keep.flatten()]
+             psd_params_fix = new_state.branches_coords["psd"][(temp_inds_fix, walker_inds_fix)][:, 0]
+             foreground_params_fix = new_state.branches_coords["galfor"][(temp_inds_fix, walker_inds_fix)][:, 0]
+
+             self.mgh.set_psd_vals(psd_params_fix, foreground_params=foreground_params_fix, overall_inds=data_index_fix)
+
+         self.accepted += accepted.astype(int)
+         self.num_proposals += 1
+
+         self.mempool.free_all_blocks()
+
+         if self.time % 200 == 0:
+             ll_after = self.mgh.get_ll(include_psd_info=True).flatten()[new_state.supplimental[:]["overall_inds"]].reshape(ntemps, nwalkers)
+             check = np.abs(new_state.log_like - ll_after).max()
+             if check > 1e-3:
+                 breakpoint()
+                 self.mgh.restore_base_injections()
+
+                 for name in new_state.branches.keys():
+                     if name != "gb":
+                         continue
+                     new_state_branch = new_state.branches[name]
+                     coords_here = new_state_branch.coords[new_state_branch.inds]
+                     ntemps, nwalkers, nleaves_max_here, ndim = new_state_branch.shape
+                     try:
+                         group_index = self.xp.asarray(
+                             self.mgh.get_mapped_indices(
+                                 np.repeat(np.arange(ntemps * nwalkers).reshape(ntemps, nwalkers, 1), nleaves_max, axis=-1)[new_state_branch.inds]
+                             ).astype(self.xp.int32)
+                         )
+                     except IndexError:
+                         breakpoint()
+                     coords_here_in = self.parameter_transforms.both_transforms(coords_here, xp=np)
+
+                     waveform_kwargs_fill = self.waveform_kwargs.copy()
+                     waveform_kwargs_fill["start_freq_ind"] = self.start_freq_ind
+
+                     if "N" in waveform_kwargs_fill:
+                         waveform_kwargs_fill.pop("N")
+
+                     self.mgh.multiply_data(-1.)
+                     self.gb.generate_global_template(coords_here_in, group_index, self.mgh.data_list, data_length=self.data_length, data_splits=self.mgh.gpu_splits, batch_size=1000, **waveform_kwargs_fill)
+                     self.mgh.multiply_data(-1.)
+
+                 ll_after2 = self.mgh.get_ll(use_cpu=True).flatten()[new_state.supplimental[:]["overall_inds"]].reshape(ntemps, nwalkers)
+                 new_state.log_like = ll_after2
+
+         """
+         data_minus_template = self.xp.concatenate(
+             [
+                 tmp.reshape(ntemps, nwalkers, 1, self.data_length) for tmp in data_minus_template_in_swap
+             ],
+             axis=2
+         )
+         del data_minus_template_in_swap
+
+         psd = self.xp.concatenate(
+             [
+                 tmp.reshape(ntemps * nwalkers, 1, self.data_length) for tmp in psd_in_swap
+             ],
+             axis=1
+         )
+         del psd_in_swap
+         self.mempool.free_all_blocks()
+
+         new_state.supplimental.holder["data_minus_template"] = data_minus_template
+
+         lp_after = model.compute_log_prior_fn(new_state.branches_coords, inds=new_state.branches_inds)
+
+         ll_after = (-1/2 * 4 * self.df * self.xp.sum(data_minus_template.conj() * data_minus_template / self.xp.asarray(self.psd), axis=(2, 3))).get()  # model.compute_log_like_fn(new_state.branches_coords, inds=new_state.branches_inds, logp=lp_after, supps=new_state.supplimental, branch_supps=new_state.branches_supplimental)
+         # check = -1/2 * 4 * self.df * self.xp.sum(data_minus_template.conj() * data_minus_template / self.xp.asarray(self.psd), axis=(2, 3))
+         # check2 = -1/2 * 4 * self.df * self.xp.sum(tmp.conj() * tmp / self.xp.asarray(self.psd), axis=(2, 3))
+         # print(np.abs(new_state.log_like - ll_after[0]).max())
+
+         # if any are even remotely getting to be different, reset all (small change)
+         if np.abs(new_state.log_like - ll_after).max() > 1e-1:
+             if np.abs(new_state.log_like - ll_after).max() > 1e0:
+                 self.greater_than_1e0 += 1
+                 print("Greater:", self.greater_than_1e0)
+                 breakpoint()
+             fix_here = np.abs(new_state.log_like - ll_after) > 1e-6
+             data_minus_template_old = data_minus_template.copy()
+             data_minus_template = self.xp.zeros_like(data_minus_template_old)
+             data_minus_template[:] = self.xp.asarray(self.data)[None, None]
+             templates = self.xp.zeros_like(data_minus_template).reshape(-1, 2, data_minus_template.shape[-1])
+             for name in new_state.branches.keys():
+                 if name not in ["gb", "gb"]:
+                     continue
+                 new_state_branch = new_state.branches[name]
+                 coords_here = new_state_branch.coords[new_state_branch.inds]
+                 ntemps, nwalkers, nleaves_max_here, ndim = new_state_branch.shape
+                 try:
+                     group_index = np.repeat(np.arange(ntemps * nwalkers).reshape(ntemps, nwalkers, 1), nleaves_max, axis=-1)[new_state_branch.inds]
+                 except IndexError:
+                     breakpoint()
+                 coords_here_in = self.parameter_transforms.both_transforms(coords_here, xp=np)
+
+                 self.gb.generate_global_template(coords_here_in, group_index, templates, batch_size=1000, **self.waveform_kwargs)
+
+             data_minus_template -= templates.reshape(ntemps, nwalkers, 2, templates.shape[-1])
+
+             new_like = -1 / 2 * 4 * self.df * self.xp.sum(data_minus_template.conj() * data_minus_template / psd, axis=(2, 3)).real.get()
+
+             new_like += self.noise_ll
+             new_state.log_like[:] = new_like.reshape(ntemps, nwalkers)
+
+         self.mempool.free_all_blocks()
+         data_minus_template_in_swap = [data_minus_template[:, :, 0, :].flatten().copy(), data_minus_template[:, :, 1, :].flatten().copy()]
+         del data_minus_template
+
+         psd_in_swap = [psd[:, 0, :].flatten().copy(), psd[:, 1, :].flatten().copy()]
+         self.mempool.free_all_blocks()
+         del psd
+         self.mempool.free_all_blocks()
+         """
+
+         self.mempool.free_all_blocks()
+
+         if self.temperature_control is not None:
+             new_state = self.temperature_control.temper_comps(new_state, adapt=False)
+
+             """# new_state, accepted = self.temperature_control.temper_comps(new_state, accepted)
+             self.swaps_accepted = np.zeros(ntemps - 1)
+             self.attempted_swaps = np.zeros(ntemps - 1)
+             betas = self.temperature_control.betas
+             for i in range(ntemps - 1, 0, -1):
+                 bi = betas[i]
+                 bi1 = betas[i - 1]
+
+                 dbeta = bi1 - bi
+
+                 iperm = np.random.permutation(nwalkers)
+                 i1perm = np.random.permutation(nwalkers)
+
+                 # need to calculate switch likelihoods
+
+                 coords_iperm = new_state.branches["gb"].coords[i, iperm]
+                 coords_i1perm = new_state.branches["gb"].coords[i - 1, i1perm]
+
+                 N_vals_iperm = new_state.branches["gb"].branch_supplimental.holder["N_vals"][i, iperm]
+
+                 N_vals_i1perm = new_state.branches["gb"].branch_supplimental.holder["N_vals"][i - 1, i1perm]
+
+                 f_test_i = coords_iperm[None, :, :, 1] / 1e3
+                 f_test_2_i = coords_i1perm[None, :, :, 1] / 1e3
+
+                 fix_f_test_i = (np.abs(f_test_i - f_test_2_i) > (self.df * N_vals_iperm * 1.5))
+
+                 if hasattr(self, "keep_bands") and self.keep_bands is not None:
+                     band_indices = np.searchsorted(self.band_edges, f_test_i.flatten()).reshape(f_test_i.shape) - 1
+                     keep_bands = self.keep_bands
+                     assert isinstance(keep_bands, np.ndarray)
+                     fix_f_test_i[~np.in1d(band_indices, keep_bands).reshape(band_indices.shape)] = True
+
+                 groups = get_groups_from_band_structure(f_test_i, self.band_edges, f0_2=f_test_2_i, xp=np, num_groups_base=3, fix_f_test=fix_f_test_i)
+
+                 unique_groups, group_len = np.unique(groups.flatten(), return_counts=True)
+
+                 # remove information about the bad "-1" group
+                 for check_val in [-1, -2]:
+                     group_len = np.delete(group_len, unique_groups == check_val)
+                     unique_groups = np.delete(unique_groups, unique_groups == check_val)
+
+                 # needs to be max because some values may be missing due to evens and odds
+                 num_groups = unique_groups.max().item() + 1
+
+                 for group_iter in range(num_groups):
+                     # st = time.perf_counter()
+                     # sometimes you will have an extra odd or even group only
+                     # the group_iter may not match the actual running group number in this case
+                     if group_iter not in groups:
+                         continue
+
+                     group = [grp[i:i+1][groups == group_iter].flatten() for grp in group_temp_finder]
+
+                     # st = time.perf_counter()
+                     temp_inds_back, walkers_inds_back, leaf_inds = [self.xp.asarray(grp) for grp in group]
+
+                     temp_inds_i = temp_inds_back.copy()
+                     walkers_inds_i = walkers_inds_back.copy()
+
+                     temp_inds_i1 = temp_inds_back.copy()
+                     walkers_inds_i1 = walkers_inds_back.copy()
+
+                     temp_inds_i[:] = i
+                     walkers_inds_i[:] = self.xp.asarray(iperm)[walkers_inds_back]
+
+                     temp_inds_i1[:] = i - 1
+                     walkers_inds_i1[:] = self.xp.asarray(i1perm)[walkers_inds_back]
+
+                     group_here_i = (temp_inds_i, walkers_inds_i, leaf_inds)
+
+                     group_here_i1 = (temp_inds_i1, walkers_inds_i1, leaf_inds)
+
+                     # factors_here = factors[group_here]
+                     old_points = self.xp.asarray(new_state.branches["gb"].coords)[group_here_i]
+                     new_points = self.xp.asarray(new_state.branches["gb"].coords)[group_here_i1]
+
+                     N_vals_here_i = N_vals[group_here_i]
+
+                     log_like_tmp = self.xp.asarray(new_state.log_like.copy())
+                     log_prior_tmp = self.xp.asarray(new_state.log_prior.copy())
+
+                     delta_logl_i = self.run_swap_ll(None, old_points, new_points, group_here_i, N_vals_here_i, waveform_kwargs_now, None, log_like_tmp, log_prior_tmp, return_at_logl=True)
+
+                     # factors_here = factors[group_here]
+                     old_points[:] = self.xp.asarray(new_state.branches["gb"].coords)[group_here_i1]
+                     new_points[:] = self.xp.asarray(new_state.branches["gb"].coords)[group_here_i]
+
+                     N_vals_here_i1 = N_vals[group_here_i1]
+
+                     log_like_tmp[:] = self.xp.asarray(new_state.log_like.copy())
+                     log_prior_tmp[:] = self.xp.asarray(new_state.log_prior.copy())
+
+                     delta_logl_i1 = self.run_swap_ll(None, old_points, new_points, group_here_i1, N_vals_here_i1, waveform_kwargs_now, None, log_like_tmp, log_prior_tmp, return_at_logl=True)
+
+                     paccept = dbeta * 1. / 2. * (delta_logl_i - delta_logl_i1)
+                     raccept = np.log(np.random.uniform(size=paccept.shape[0]))
+
+                     # How many swaps were accepted?
+                     sel = paccept > self.xp.asarray(raccept)
+
+                     inds_i_swap = tuple([tmp[sel].get() for tmp in list(group_here_i)])
+                     inds_i1_swap = tuple([tmp[sel].get() for tmp in list(group_here_i1)])
+
+                     group_index_i = self.xp.asarray(
+                         self.mgh.get_mapped_indices(
+                             temp_inds_i[sel] + nwalkers * walkers_inds_i[sel]
+                         )
+                     ).astype(self.xp.int32)
+
+                     group_index_i1 = self.xp.asarray(
+                         self.mgh.get_mapped_indices(
+                             temp_inds_i1[sel] + nwalkers * walkers_inds_i1[sel]
+                         )
+                     ).astype(self.xp.int32)
+
+                     N_vals_i = N_vals[inds_i_swap]
+                     params_i = self.xp.asarray(new_state.branches["gb"].coords)[inds_i_swap]
+                     params_i1 = self.xp.asarray(new_state.branches["gb"].coords)[inds_i1_swap]
+
+                     params_generate = self.xp.concatenate([
+                         params_i,
+                         params_i1,
+                         params_i1,  # reverse of above
+                         params_i,
+                     ], axis=0)
+
+                     params_generate_in = self.parameter_transforms.both_transforms(params_generate, xp=self.xp)
+
+                     group_index_gen = self.xp.concatenate(
+                         [
+                             group_index_i,
+                             group_index_i,
+                             group_index_i1,
+                             group_index_i1
+                         ], dtype=self.xp.int32
+                     )
+
+                     factors_multiply_generate = self.xp.concatenate([
+                         +1 * self.xp.ones_like(group_index_i, dtype=float),
+                         -1 * self.xp.ones_like(group_index_i, dtype=float),
+                         +1 * self.xp.ones_like(group_index_i, dtype=float),
+                         -1 * self.xp.ones_like(group_index_i, dtype=float),
+                     ])
+
+                     N_vals_in_gen = self.xp.concatenate([
+                         N_vals_i,
+                         N_vals_i,
+                         N_vals_i,
+                         N_vals_i
+                     ])
+
+                     waveform_kwargs_fill = waveform_kwargs_now.copy()
+                     waveform_kwargs_fill["start_freq_ind"] = self.start_freq_ind
+
+                     self.gb.generate_global_template(
+                         params_generate_in,
+                         group_index_gen,
+                         self.mgh.data_list,
+                         N=N_vals_in_gen,
+                         data_length=self.data_length,
+                         data_splits=self.mgh.gpu_splits,
+                         factors=factors_multiply_generate,
+                         **waveform_kwargs_fill
+                     )
+
+                     # update likelihoods
+
+                     # set unaccepted differences to zero
+                     accepted_delta_ll_i = delta_logl_i * (sel)
+                     accepted_delta_ll_i1 = delta_logl_i1 * (sel)
+
+                     logl_change_contribution = np.zeros_like(log_like_tmp.get())
+                     try:
+                         in_tuple = (accepted_delta_ll_i[sel].get(), accepted_delta_ll_i1[sel].get(), temp_inds_i[sel].get(), temp_inds_i1[sel].get(), walkers_inds_i[sel].get(), walkers_inds_i[sel].get())
+                     except AttributeError:
+                         in_tuple = (accepted_delta_ll_i[sel], accepted_delta_ll_i1[sel], temp_inds_i[sel], temp_inds_i1[sel], walkers_inds_i[sel], walkers_inds_i[sel])
+                     for j, (dlli, dlli1, ti, ti1, wi, wi1) in enumerate(zip(*in_tuple)):
+                         logl_change_contribution[ti, wi] += dlli
+                         logl_change_contribution[ti1, wi1] += dlli1
+
+                     log_like_tmp[:] += self.xp.asarray(logl_change_contribution)
+
+                     tmp_swap = new_state.branches["gb"].coords[inds_i_swap]
+                     new_state.branches["gb"].coords[inds_i_swap] = new_state.branches["gb"].coords[inds_i1_swap]
+
+                     new_state.branches["gb"].coords[inds_i1_swap] = tmp_swap
+
+                     tmp_swap = new_state.branches["gb"].branch_supplimental[inds_i_swap]
+
+                     new_state.branches["gb"].branch_supplimental[inds_i_swap] = new_state.branches["gb"].branch_supplimental[inds_i1_swap]
+
+                     new_state.branches["gb"].branch_supplimental[inds_i1_swap] = tmp_swap
+
+                     # inds are all non-zero
+                     self.swaps_accepted[i - 1] += np.sum(sel)
+                     self.attempted_swaps[i - 1] += sel.shape[0]
+
+             ll_after = self.mgh.get_ll(use_cpu=True).flatten()[new_state.supplimental[:]["overall_inds"]].reshape(ntemps, nwalkers)
+             breakpoint()
+             """
+         else:
+             self.temperature_control.swaps_accepted = np.zeros((ntemps - 1))
+
+         if np.any(new_state.log_like > 1e10):
+             breakpoint()
+
+         self.time += 1
+         # self.xp.cuda.runtime.deviceSynchronize()
+         # et = time.perf_counter()
+         # print("end stretch", (et - st))
+
+         self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()
+
+         """et = time.perf_counter()
+         print("end", (et - st), group_iter, group_len[group_iter])"""
+
+         # breakpoint()
+         return new_state, accepted
+
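For readers tracing the propose() logic above: a proposal is accepted when the stretch-move factor plus the change in (tempered) log-posterior beats the log of a uniform draw. Below is a self-contained sketch of that rule with synthetic numbers; all shapes and values are illustrative stand-ins, not taken from the package.

import numpy as np

rng = np.random.default_rng(42)
ntemps, half = 4, 8  # temperatures x half of the walkers

# stand-ins for the quantities computed in propose()
factors = rng.normal(size=(ntemps, half))     # stretch-move Jacobian terms
logP = rng.normal(size=(ntemps, half))        # proposed log-posterior
prev_logP = rng.normal(size=(ntemps, half))   # current log-posterior

# keep a proposal when lnpdiff exceeds log(u), u ~ Uniform(0, 1)
lnpdiff = factors + logP - prev_logP
keep = lnpdiff > np.log(rng.random((ntemps, half)))
print(keep.sum(), "of", keep.size, "proposals accepted")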