lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lisaanalysistools might be problematic.

Files changed (37)
  1. lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
  2. lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
  3. lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
  4. lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
  6. lisatools/__init__.py +0 -0
  7. lisatools/_version.py +4 -0
  8. lisatools/analysiscontainer.py +438 -0
  9. lisatools/cutils/detector.cpython-312-darwin.so +0 -0
  10. lisatools/datacontainer.py +292 -0
  11. lisatools/detector.py +410 -0
  12. lisatools/diagnostic.py +976 -0
  13. lisatools/glitch.py +193 -0
  14. lisatools/sampling/__init__.py +0 -0
  15. lisatools/sampling/likelihood.py +882 -0
  16. lisatools/sampling/moves/__init__.py +0 -0
  17. lisatools/sampling/moves/gbgroupstretch.py +53 -0
  18. lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
  19. lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
  20. lisatools/sampling/moves/gbspecialstretch.py +1836 -0
  21. lisatools/sampling/moves/mbhspecialmove.py +286 -0
  22. lisatools/sampling/moves/placeholder.py +16 -0
  23. lisatools/sampling/moves/skymodehop.py +110 -0
  24. lisatools/sampling/moves/specialforegroundmove.py +564 -0
  25. lisatools/sampling/prior.py +508 -0
  26. lisatools/sampling/stopping.py +320 -0
  27. lisatools/sampling/utility.py +324 -0
  28. lisatools/sensitivity.py +888 -0
  29. lisatools/sources/__init__.py +0 -0
  30. lisatools/sources/emri/__init__.py +1 -0
  31. lisatools/sources/emri/tdiwaveform.py +72 -0
  32. lisatools/stochastic.py +291 -0
  33. lisatools/utils/__init__.py +0 -0
  34. lisatools/utils/constants.py +40 -0
  35. lisatools/utils/multigpudataholder.py +730 -0
  36. lisatools/utils/pointeradjust.py +106 -0
  37. lisatools/utils/utility.py +240 -0
@@ -0,0 +1,1836 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from copy import deepcopy
4
+ from inspect import Attribute
5
+ import numpy as np
6
+ from scipy import stats
7
+ import warnings
8
+ import time
9
+ from gbgpu.utils.utility import get_N
10
+ from lisatools.sensitivity import Soms_d_all, Sa_a_all
11
+
12
+ try:
13
+ import cupy as xp
14
+
15
+ gpu_available = True
16
+ except ModuleNotFoundError:
17
+ import numpy as xp
18
+
19
+ gpu_available = False
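As an aside, a minimal standalone sketch of this optional-CuPy pattern (the helper name to_numpy is hypothetical and not part of lisaanalysistools): array math is written against xp, and results are copied back to the host only when a NumPy array is explicitly needed.

import numpy as np

try:
    import cupy as xp
    gpu_available = True
except ModuleNotFoundError:
    import numpy as xp
    gpu_available = False

def to_numpy(arr):
    # CuPy arrays expose .get() to copy device memory to the host;
    # NumPy arrays pass through unchanged.
    return arr.get() if hasattr(arr, "get") else np.asarray(arr)

x = xp.linspace(0.0, 1.0, 5)
print(to_numpy(x))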
20
+
21
+ from lisatools.utils.utility import searchsorted2d_vec, get_groups_from_band_structure
22
+ from eryn.moves import StretchMove
23
+ from eryn.prior import ProbDistContainer
24
+ from eryn.utils.utility import groups_from_inds
25
+ from eryn.utils import PeriodicContainer
26
+
27
+ from eryn.moves import GroupStretchMove, Move
28
+ from eryn.moves.multipletry import logsumexp, get_mt_computations
29
+
30
+ from ...diagnostic import inner_product
31
+ from lisatools.globalfit.state import State
32
+ from lisatools.sampling.prior import GBPriorWrap
33
+
34
+
35
+ __all__ = ["GBSpecialStretchMove"]
36
+
37
+
38
+ # MHMove needs to be to the left here to overwrite GBBruteRejectionRJ RJ proposal method
39
+ class GBSpecialStretchMove(GroupStretchMove, Move):
40
+ """Generate Revesible-Jump proposals for GBs with try-force rejection
41
+
42
+ Will use the GPU if the template generator uses the GPU.
43
+
44
+ Args:
45
+ priors (object): :class:`ProbDistContainer` object that has ``logpdf``
46
+ and ``rvs`` methods.
47
+
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ gb,
53
+ priors,
54
+ start_freq_ind,
55
+ data_length,
56
+ mgh,
57
+ fd,
58
+ band_edges,
59
+ gpu_priors,
60
+ *args,
61
+ waveform_kwargs={},
62
+ parameter_transforms=None,
63
+ snr_lim=1e-10,
64
+ rj_proposal_distribution=None,
65
+ num_repeat_proposals=1,
66
+ name=None,
67
+ use_prior_removal=False,
68
+ phase_maximize=False,
69
+ **kwargs
70
+ ):
71
+ # return_gpu is a kwarg for the stretch move
72
+ GroupStretchMove.__init__(self, *args, return_gpu=True, **kwargs)
73
+
74
+ self.gpu_priors = gpu_priors
75
+ self.name = name
76
+ self.num_repeat_proposals = num_repeat_proposals
77
+ self.use_prior_removal = use_prior_removal
78
+
79
+ for key in priors:
80
+ if not isinstance(priors[key], ProbDistContainer) and not isinstance(priors[key], GBPriorWrap):
81
+ raise ValueError(
82
+ "Priors need to be eryn.priors.ProbDistContainer object."
83
+ )
84
+
85
+ self.priors = priors
86
+ self.gb = gb
87
+ self.stop_here = True
88
+
89
+ args = [priors["gb"].priors_in[(0, 1)].rho_star]
90
+ args += [priors["gb"].priors_in[(0, 1)].frequency_prior.min_val, priors["gb"].priors_in[(0, 1)].frequency_prior.max_val]
91
+ for i in range(2, 8):
92
+ args += [priors["gb"].priors_in[i].min_val, priors["gb"].priors_in[i].max_val]
93
+
94
+ self.gpu_cuda_priors = self.gb.pyPriorPackage(*tuple(args))
95
+ self.gpu_cuda_wrap = self.gb.pyPeriodicPackage(2 * np.pi, np.pi, 2 * np.pi)
96
+
97
+ # use gpu from template generator
98
+ # self.use_gpu = gb.use_gpu
99
+ if self.use_gpu:
100
+ self.mempool = self.xp.get_default_memory_pool()
101
+
102
+ self.band_edges = band_edges
103
+ self.num_bands = len(band_edges) - 1
104
+ self.start_freq_ind = start_freq_ind
105
+ self.data_length = data_length
106
+ self.waveform_kwargs = waveform_kwargs
107
+ self.parameter_transforms = parameter_transforms
108
+ self.fd = fd
109
+ self.df = (fd[1] - fd[0]).item()
110
+ self.mgh = mgh
111
+ self.phase_maximize = phase_maximize
112
+
113
+ self.snr_lim = snr_lim
114
+
115
+ self.band_edges = self.xp.asarray(self.band_edges)
116
+
117
+ self.rj_proposal_distribution = rj_proposal_distribution
118
+ self.is_rj_prop = self.rj_proposal_distribution is not None
119
+
120
+ # setup N vals for bands
121
+ band_mean_f = (self.band_edges[1:] + self.band_edges[:-1]).get() / 2
122
+ self.band_N_vals = xp.asarray(get_N(np.full_like(band_mean_f, 1e-30), band_mean_f, self.waveform_kwargs["T"], self.waveform_kwargs["oversample"]))
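For intuition, a hedged sketch of this per-band setup with invented numbers; toy_N below is only a stand-in for gbgpu's get_N, not its actual rule.

import numpy as np

band_edges = np.linspace(1e-4, 1e-2, 11)               # Hz, hypothetical band edges
band_mean_f = (band_edges[1:] + band_edges[:-1]) / 2   # one representative frequency per band

def toy_N(f, oversample=4):
    # stand-in for get_N: allot more frequency-domain points to higher-frequency bands
    base = 32 * 2 ** np.ceil(np.log2(np.maximum(f / 1e-3, 1.0))).astype(int)
    return oversample * base

band_N_vals = toy_N(band_mean_f)
print(band_N_vals)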
123
+
124
+ def setup_gbs(self, branch):
125
+ st = time.perf_counter()
126
+ coords = branch.coords
127
+ inds = branch.inds
128
+ supps = branch.branch_supplimental
129
+ ntemps, nwalkers, nleaves_max, ndim = branch.shape
130
+ all_remaining_freqs = coords[0][inds[0]][:, 1]
131
+
132
+ all_remaining_cords = coords[0][inds[0]]
133
+
134
+ num_remaining = len(all_remaining_freqs)
135
+
136
+ all_temp_fs = self.xp.asarray(coords[inds][:, 1])
137
+
138
+ # TODO: improve this?
139
+ self.inds_freqs_sorted = self.xp.asarray(np.argsort(all_remaining_freqs))
140
+ self.freqs_sorted = self.xp.asarray(np.sort(all_remaining_freqs))
141
+ self.all_coords_sorted = self.xp.asarray(all_remaining_cords)[
142
+ self.inds_freqs_sorted
143
+ ]
144
+
145
+ total_binaries = inds.sum().item()
146
+ still_going = xp.ones(total_binaries, dtype=bool)
147
+ inds_zero = xp.searchsorted(self.freqs_sorted, all_temp_fs, side="right") - 1
148
+ left_inds = inds_zero - int(self.nfriends / 2)
149
+ right_inds = inds_zero + int(self.nfriends / 2) - 1
150
+
151
+ # do right first here
152
+ right_inds[left_inds < 0] = self.nfriends - 1
153
+ left_inds[left_inds < 0] = 0
154
+
155
+ # do left first here
156
+ left_inds[right_inds > len(self.freqs_sorted) - 1] = len(self.freqs_sorted) - self.nfriends
157
+ right_inds[right_inds > len(self.freqs_sorted) - 1] = len(self.freqs_sorted) - 1
158
+
159
+ assert np.all(right_inds - left_inds == self.nfriends - 1)
160
+ assert not np.any(right_inds < 0) and not np.any(right_inds > len(self.freqs_sorted) - 1) and not np.any(left_inds < 0) and not np.any(left_inds > len(self.freqs_sorted) - 1)
161
+
162
+ jjj = 0
163
+ while np.any(still_going):
164
+ distance_left = np.abs(all_temp_fs[still_going] - self.freqs_sorted[left_inds[still_going]])
165
+ distance_right = np.abs(all_temp_fs[still_going] - self.freqs_sorted[right_inds[still_going]])
166
+
167
+ check_move_right = (distance_right <= distance_left)
168
+ check_left_inds = left_inds[still_going][check_move_right] + 1
169
+ check_right_inds = right_inds[still_going][check_move_right] + 1
170
+
171
+ new_distance_right = np.abs(all_temp_fs[still_going][check_move_right] - self.freqs_sorted[check_right_inds])
172
+
173
+ change_inds = xp.arange(len(all_temp_fs))[still_going][check_move_right][(new_distance_right < distance_left[check_move_right]) & (check_right_inds < len(self.freqs_sorted))]
174
+
175
+ left_inds[change_inds] += 1
176
+ right_inds[change_inds] += 1
177
+
178
+ stop_inds_right_1 = xp.arange(len(all_temp_fs))[still_going][check_move_right][(check_right_inds >= len(self.freqs_sorted))]
179
+
180
+ # the last condition is only needed here; the left-move check below removes the case where the distances are still equal
181
+ stop_inds_right_2 = xp.arange(len(all_temp_fs))[still_going][check_move_right][(new_distance_right >= distance_left[check_move_right]) & (check_right_inds < len(self.freqs_sorted)) & (distance_right[check_move_right] != distance_left[check_move_right])]
182
+ stop_inds_right = xp.concatenate([stop_inds_right_1, stop_inds_right_2])
183
+ assert np.all(still_going[stop_inds_right])
184
+
185
+ # an equal distance should only be left over if it was equal above and moving right did not help
186
+ check_move_left = (distance_left <= distance_right)
187
+ check_left_inds = left_inds[still_going][check_move_left] - 1
188
+ check_right_inds = right_inds[still_going][check_move_left] - 1
189
+
190
+ new_distance_left = np.abs(all_temp_fs[still_going][check_move_left] - self.freqs_sorted[check_left_inds])
191
+
192
+ change_inds = xp.arange(len(all_temp_fs))[still_going][check_move_left][(new_distance_left < distance_right[check_move_left]) & (check_left_inds >= 0)]
193
+
194
+ left_inds[change_inds] -= 1
195
+ right_inds[change_inds] -= 1
196
+
197
+ stop_inds_left_1 = xp.arange(len(all_temp_fs))[still_going][check_move_left][(check_left_inds < 0)]
198
+ stop_inds_left_2 = xp.arange(len(all_temp_fs))[still_going][check_move_left][(new_distance_left >= distance_right[check_move_left]) & (check_left_inds >= 0)]
199
+ stop_inds_left = xp.concatenate([stop_inds_left_1, stop_inds_left_2])
200
+
201
+ stop_inds = xp.concatenate([stop_inds_right, stop_inds_left])
202
+ still_going[stop_inds] = False
203
+ # print(jjj, still_going.sum())
204
+ if jjj >= self.nfriends:
205
+ breakpoint()
206
+ jjj += 1
207
+
208
+ start_inds = left_inds.copy().get()
209
+
210
+ start_inds_all = np.zeros_like(inds, dtype=np.int32)
211
+ start_inds_all[inds] = start_inds.astype(np.int32)
212
+
213
+ if "friend_start_inds" not in supps:
214
+ supps.add_objects({"friend_start_inds": start_inds_all})
215
+ else:
216
+ supps[:] = {"friend_start_inds": start_inds_all}
217
+
218
+ self.stretch_friends_args_in = tuple([tmp.copy() for tmp in self.all_coords_sorted.T])
219
+ et = time.perf_counter()
220
+ self.mempool.free_all_blocks()
221
+ # print("SETUP:", et - st)
222
+ # start_inds_freq_out = np.zeros((ntemps, nwalkers, nleaves_max), dtype=int)
223
+ # freqs_sorted_here = self.freqs_sorted.get()
224
+ # freqs_remaining_here = all_remaining_freqs
225
+
226
+ # start_ind_best = np.zeros_like(freqs_remaining_here, dtype=int)
227
+
228
+ # best_index = (
229
+ # np.searchsorted(freqs_sorted_here, freqs_remaining_here, side="right") - 1
230
+ # )
231
+ # best_index[best_index < self.nfriends] = self.nfriends
232
+ # best_index[best_index >= len(freqs_sorted_here) - self.nfriends] = (
233
+ # len(freqs_sorted_here) - self.nfriends
234
+ # )
235
+ # check_inds = (
236
+ # best_index[:, None]
237
+ # + np.tile(np.arange(2 * self.nfriends), (best_index.shape[0], 1))
238
+ # - self.nfriends
239
+ # )
240
+
241
+ # check_freqs = freqs_sorted_here[check_inds]
242
+ # breakpoint()
243
+
244
+ # # batch_count = 1000
245
+ # # split_inds = np.arange(batch_count, freqs_remaining_here.shape[0], batch_count)
246
+
247
+ # # splits_remain = np.split(freqs_remaining_here, split_inds)
248
+ # # splits_check = np.split(check_freqs, split_inds)
249
+
250
+ # # out = []
251
+ # # for i, (split_r, split_c) in enumerate(zip(splits_remain, splits_check)):
252
+ # # out.append(np.abs(split_r[:, None] - split_c))
253
+ # # print(i)
254
+
255
+ # # freq_distance = np.asarray(out)
256
+
257
+ # freq_distance = np.abs(freqs_remaining_here[:, None] - check_freqs)
258
+ # breakpoint()
259
+
260
+ # keep_min_inds = np.argsort(freq_distance, axis=-1)[:, : self.nfriends].min(
261
+ # axis=-1
262
+ # )
263
+ # start_inds_freq = check_inds[(np.arange(len(check_inds)), keep_min_inds)]
264
+
265
+ # start_inds_freq_out[inds] = start_inds_freq
266
+
267
+ # start_inds_freq_out[~inds] = -1
268
+
269
+ # if "friend_start_inds" not in supps:
270
+ # supps.add_objects({"friend_start_inds": start_inds_freq_out})
271
+ # else:
272
+ # supps[:] = {"friend_start_inds": start_inds_freq_out}
273
+
274
+ # self.all_friends_start_inds_sorted = self.xp.asarray(
275
+ # start_inds_freq_out[inds][self.inds_freqs_sorted.get()]
276
+ # )
277
+
278
+ def find_friends(self, name, gb_points_to_move, s_inds=None):
279
+ if s_inds is None:
280
+ raise ValueError
281
+
282
+ inds_points_to_move = self.xp.asarray(s_inds.flatten())
283
+
284
+ half_friends = int(self.nfriends / 2)
285
+
286
+ gb_points_for_move = gb_points_to_move.reshape(-1, 8).copy()
287
+
288
+ if not hasattr(self, "ntemps"):
289
+ self.ntemps = 1
290
+
291
+ inds_start_freq_to_move = self.current_friends_start_inds[
292
+ inds_points_to_move.reshape(self.ntemps, self.nwalkers, -1)
293
+ ]
294
+
295
+ deviation = self.xp.random.randint(
296
+ 0, self.nfriends, size=len(inds_start_freq_to_move)
297
+ )
298
+
299
+ inds_keep_friends = inds_start_freq_to_move + deviation
300
+
301
+ inds_keep_friends[inds_keep_friends < 0] = 0
302
+ inds_keep_friends[inds_keep_friends >= len(self.all_coords_sorted)] = (
303
+ len(self.all_coords_sorted) - 1
304
+ )
305
+
306
+ gb_points_for_move[inds_points_to_move] = self.all_coords_sorted[
307
+ inds_keep_friends
308
+ ]
309
+
310
+ return gb_points_for_move.reshape(self.ntemps, -1, 1, 8)
311
+
312
+ def new_find_friends(self, name, inds_in):
313
+ inds_start_freq_to_move = self.current_friends_start_inds[tuple(inds_in)]
314
+
315
+ deviation = self.xp.random.randint(
316
+ 0, self.nfriends, size=len(inds_start_freq_to_move)
317
+ )
318
+
319
+ inds_keep_friends = inds_start_freq_to_move + deviation
320
+
321
+ inds_keep_friends[inds_keep_friends < 0] = 0
322
+ inds_keep_friends[inds_keep_friends >= len(self.all_coords_sorted)] = (
323
+ len(self.all_coords_sorted) - 1
324
+ )
325
+
326
+ gb_points_for_move = self.all_coords_sorted[
327
+ inds_keep_friends
328
+ ]
329
+
330
+ return gb_points_for_move
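Both friend-finding helpers above reduce to the same pattern: each binary keeps the start of its window into the frequency-sorted cold-chain catalogue and draws one window member at random. A small hedged sketch with invented shapes:

import numpy as np

def draw_friends(all_coords_sorted, window_starts, nfriends, rng=None):
    rng = rng or np.random.default_rng()
    offsets = rng.integers(0, nfriends, size=len(window_starts))  # one random window member each
    idx = np.clip(window_starts + offsets, 0, len(all_coords_sorted) - 1)
    return all_coords_sorted[idx]

catalogue = np.random.default_rng(1).normal(size=(100, 8))  # 100 frequency-sorted binaries, 8 params
print(draw_friends(catalogue, window_starts=np.array([0, 40, 95]), nfriends=10).shape)  # (3, 8)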
331
+
332
+ def setup(self, branches):
333
+ for i, (name, branch) in enumerate(branches.items()):
334
+ if name != "gb":
335
+ continue
336
+
337
+ if not self.is_rj_prop and self.time % self.n_iter_update == 0:
338
+ self.setup_gbs(branch)
339
+
340
+ elif self.is_rj_prop and self.time == 0:
341
+ ndim = branch.shape[-1]
342
+ self.stretch_friends_args_in = tuple([xp.array([]) for _ in range(ndim)])
343
+
344
+ # update any shifted start inds due to tempering (need to do this every non-rj move)
345
+ """if not self.is_rj_prop:
346
+ # fix the ones that have been added in RJ
347
+ fix = (
348
+ branch.branch_supplimental.holder["friend_start_inds"][:] == -1
349
+ ) & branch.inds
350
+
351
+ if np.any(fix):
352
+ new_freqs = xp.asarray(branch.coords[fix][:, 1])
353
+ # TODO: is there a better way of doing this?
354
+
355
+ # fill information into friend finder for new binaries
356
+ branch.branch_supplimental.holder["friend_start_inds"][fix] = (
357
+ (
358
+ xp.searchsorted(self.freqs_sorted, new_freqs, side="right")
359
+ - 1
360
+ )
361
+ * (
362
+ (new_freqs > self.freqs_sorted[0])
363
+ & (new_freqs < self.freqs_sorted[-1])
364
+ )
365
+ + 0 * (new_freqs < self.freqs_sorted[0])
366
+ + (len(self.freqs_sorted) - 1)
367
+ * (new_freqs > self.freqs_sorted[-1])
368
+ ).get()
369
+
370
+ # make sure current start inds reflect alive binaries
371
+ self.current_friends_start_inds = self.xp.asarray(
372
+ branch.branch_supplimental.holder["friend_start_inds"][:]
373
+ )
374
+ """
375
+
376
+ self.mempool.free_all_blocks()
377
+
378
+ def propose(self, model, state):
379
+ """Use the move to generate a proposal and compute the acceptance
380
+
381
+ Args:
382
+ model (:class:`eryn.model.Model`): Carrier of sampler information.
383
+ state (:class:`State`): Current state of the sampler.
384
+
385
+ Returns:
386
+ :class:`State`: State of sampler after proposal is complete.
387
+
388
+ """
389
+ st = time.perf_counter()
390
+
391
+ self.xp.cuda.runtime.setDevice(self.mgh.gpus[0])
392
+
393
+ self.current_state = state
394
+ np.random.seed(10)
395
+ # print("start stretch")
396
+
397
+ # Check that the dimensions are compatible.
398
+ ndim_total = 0
399
+ for branch in state.branches.values():
400
+ ntemps, nwalkers, nleaves_, ndim_ = branch.shape
401
+ ndim_total += ndim_ * nleaves_
402
+
403
+ self.nwalkers = nwalkers
404
+
405
+ # Run any move-specific setup.
406
+ self.setup(state.branches)
407
+
408
+ new_state = State(state, copy=True)
409
+ band_temps = xp.asarray(state.band_info["band_temps"].copy())
410
+
411
+ if self.is_rj_prop:
412
+ orig_store = new_state.log_like[0].copy()
413
+ gb_coords = xp.asarray(new_state.branches["gb"].coords)
414
+
415
+ self.mempool.free_all_blocks()
416
+
417
+ # self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()
418
+
419
+ # data should not be whitened
420
+
421
+ ntemps, nwalkers, nleaves_max, ndim = state.branches_coords["gb"].shape
422
+
423
+ group_temp_finder = [
424
+ self.xp.repeat(self.xp.arange(ntemps), nwalkers * nleaves_max).reshape(
425
+ ntemps, nwalkers, nleaves_max
426
+ ),
427
+ self.xp.tile(self.xp.arange(nwalkers), (ntemps, nleaves_max, 1)).transpose(
428
+ (0, 2, 1)
429
+ ),
430
+ self.xp.tile(self.xp.arange(nleaves_max), ((ntemps, nwalkers, 1))),
431
+ ]
432
+
433
+ """(
434
+ gb_coords,
435
+ gb_inds_orig,
436
+ points_curr,
437
+ prior_all_curr,
438
+ gb_inds,
439
+ N_vals_in,
440
+ prop_to_curr_map,
441
+ factors,
442
+ inds_curr,
443
+ proposal_specific_information
444
+ ) = self.get_special_proposal_setup(model, new_state, state, group_temp_finder)"""
445
+
446
+ waveform_kwargs_now = self.waveform_kwargs.copy()
447
+ if "N" in waveform_kwargs_now:
448
+ waveform_kwargs_now.pop("N")
449
+ waveform_kwargs_now["start_freq_ind"] = self.start_freq_ind
450
+
451
+ # if self.is_rj_prop:
452
+ # print("START:", new_state.log_like[0])
453
+ log_like_tmp = self.xp.asarray(new_state.log_like)
454
+ log_prior_tmp = self.xp.asarray(new_state.log_prior)
455
+
456
+ self.mempool.free_all_blocks()
457
+
458
+ gb_inds = self.xp.asarray(new_state.branches["gb"].inds)
459
+ gb_inds_orig = gb_inds.copy()
460
+
461
+ data = self.mgh.data_list
462
+ base_data = self.mgh.base_data_list
463
+ psd = self.mgh.psd_list
464
+ lisasens = self.mgh.lisasens_list
465
+
466
+ # do unique for band size as separator between asynchronous kernel launches
467
+ # band_indices = self.xp.asarray(new_state.branches["gb"].branch_supplimental.holder["band_inds"])
468
+ band_indices = self.xp.searchsorted(self.band_edges, xp.asarray(new_state.branches["gb"].coords[:, :, :, 1]).flatten() / 1e3, side="right").reshape(new_state.branches["gb"].coords[:, :, :, 1].shape) - 1
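A NumPy illustration of this binning step (the sampler stores frequency in mHz, hence the division by 1e3; the edges and values below are invented):

import numpy as np

band_edges = np.array([1.0e-3, 2.0e-3, 3.0e-3, 4.0e-3])   # Hz, hypothetical
f_mhz = np.array([1.2, 2.9, 3.5])                          # mHz, as stored in the coords array
band_index = np.searchsorted(band_edges, f_mhz / 1e3, side="right") - 1
print(band_index)                                          # [0 1 2]; values below band_edges[0] give -1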
469
+
470
+ # N_vals_in = self.xp.asarray(new_state.branches["gb"].branch_supplimental.holder["N_vals"])
471
+ points_curr = self.xp.asarray(new_state.branches["gb"].coords)
472
+ points_curr_orig = points_curr.copy()
473
+ # N_vals_in_orig = N_vals_in.copy()
474
+ band_indices_orig = band_indices.copy()
475
+
476
+ if self.is_rj_prop:
477
+ if isinstance(self.rj_proposal_distribution["gb"], list):
478
+ raise NotImplementedError
479
+ assert len(self.rj_proposal_distribution["gb"]) == ntemps
480
+ proposal_logpdf = xp.zeros((points_curr.shape[0], np.prod(points_curr.shape[1:-1])))
481
+ for t in range(ntemps):
482
+ new_sources = xp.full_like(points_curr[t, ~gb_inds[t]], np.nan)
483
+ fix = xp.full(new_sources.shape[0], True)
484
+ while xp.any(fix):
485
+ new_sources[fix] = self.rj_proposal_distribution["gb"][t].rvs(size=fix.sum().item())
486
+ fix = xp.any(xp.isnan(new_sources), axis=-1)
487
+ points_curr[t, ~gb_inds[t]] = new_sources
488
+ band_indices[t, ~gb_inds[t]] = self.xp.searchsorted(self.band_edges, new_sources[:, 1] / 1e3, side="right") - 1
489
+ # change gb_inds
490
+ gb_inds[t, :] = True
491
+ assert np.all(gb_inds[t])
492
+ proposal_logpdf[t] = self.rj_proposal_distribution["gb"][t].logpdf(
493
+ points_curr[t, gb_inds[t]]
494
+ )
495
+
496
+ proposal_logpdf = proposal_logpdf.flatten().copy()
497
+
498
+ else:
499
+ new_sources = xp.full_like(points_curr[~gb_inds], np.nan)
500
+ fix = xp.full(new_sources.shape[0], True)
501
+ while xp.any(fix):
502
+ if self.name == "rj_prior":
503
+ new_sources[fix] = self.rj_proposal_distribution["gb"].rvs(size=fix.sum().item(), psds=self.mgh.psd_shaped[0][0], walker_inds=group_temp_finder[1][~gb_inds][fix])
504
+ else:
505
+ new_sources[fix] = self.rj_proposal_distribution["gb"].rvs(size=fix.sum().item())
506
+
507
+ fix = xp.any(xp.isnan(new_sources), axis=-1)
508
+ points_curr[~gb_inds] = new_sources
509
+
510
+ band_indices[~gb_inds] = self.xp.searchsorted(self.band_edges, new_sources[:, 1] / 1e3, side="right") - 1
511
+ # N_vals_in[~gb_inds] = self.xp.asarray(
512
+ # get_N(
513
+ # xp.full_like(new_sources[:, 1], 1e-30),
514
+ # new_sources[:, 1] / 1e3,
515
+ # self.waveform_kwargs["T"],
516
+ # self.waveform_kwargs["oversample"],
517
+ # xp=self.xp
518
+ # )
519
+ # )
520
+
521
+ # change gb_inds
522
+ gb_inds[:] = True
523
+ assert np.all(gb_inds)
524
+ if self.name == "rj_prior":
525
+ proposal_logpdf = self.rj_proposal_distribution["gb"].logpdf(
526
+ points_curr[gb_inds], psds=self.mgh.psd_shaped[0][0], walker_inds=group_temp_finder[1][gb_inds]
527
+ )
528
+ else:
529
+ proposal_logpdf = self.rj_proposal_distribution["gb"].logpdf(
530
+ points_curr[gb_inds]
531
+ )
532
+
533
+ factors = (proposal_logpdf * -1) * (~gb_inds_orig).flatten() + (proposal_logpdf * +1) * (gb_inds_orig).flatten()
534
+
535
+ if self.name == "rj_prior" and self.use_prior_removal:
536
+ factors[~gb_inds_orig.flatten()] = -1e300
537
+ else:
538
+ factors = xp.zeros(gb_inds_orig.sum().item())
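The factors above carry the usual reversible-jump correction: the proposal log-density is subtracted for newly proposed (birth) sources and added for removals. A toy illustration, under the assumption that the correction simply enters an MH-style log-acceptance:

def rj_log_accept(delta_log_like, delta_log_prior, logq, is_birth, beta=1.0):
    # birth: q(new source) sits in the denominator of the MH ratio -> subtract log q
    # death: q(removed source) sits in the numerator               -> add log q
    factor = -logq if is_birth else +logq
    return beta * delta_log_like + delta_log_prior + factor

print(rj_log_accept(delta_log_like=5.0, delta_log_prior=-2.0, logq=3.0, is_birth=True))  # 0.0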
539
+
540
+ start_inds_all = xp.asarray(new_state.branches["gb"].branch_supplimental.holder["friend_start_inds"], dtype=xp.int32)[gb_inds]
541
+ points_curr = points_curr[gb_inds]
542
+ # N_vals_in = N_vals_in[gb_inds]
543
+ band_indices = band_indices[gb_inds]
544
+ gb_inds_in = gb_inds_orig[gb_inds]
545
+
546
+ temp_indices = group_temp_finder[0][gb_inds]
547
+ walker_indices = group_temp_finder[1][gb_inds]
548
+ leaf_indices = group_temp_finder[2][gb_inds]
549
+
550
+ unique_N = self.xp.unique(self.band_N_vals)
551
+ # remove 0
552
+ unique_N = unique_N[unique_N != 0]
553
+
554
+ do_synchronize = False
555
+ device = self.xp.cuda.runtime.getDevice()
556
+
557
+ units = 2 if not self.is_rj_prop else 2
558
+ # random starting unit for the rotation through the band units
559
+ start_unit = model.random.randint(units)
560
+
561
+ ll_after = (
562
+ self.mgh.get_ll(include_psd_info=True)
563
+ )
564
+ # print(np.abs(new_state.log_like - ll_after).max())
565
+ store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
566
+ # print("CHECKING 0:", store_max_diff, self.is_rj_prop)
567
+ self.check_ll_inject(new_state)
568
+
569
+ per_walker_band_proposals = xp.zeros((ntemps, nwalkers, self.num_bands), dtype=int)
570
+ per_walker_band_accepted = xp.zeros((ntemps, nwalkers, self.num_bands), dtype=int)
571
+
572
+ total_keep = 0
573
+ for tmp in range(units):
574
+ remainder = (start_unit + tmp) % units
575
+ all_inputs = []
576
+ band_bookkeep_info = []
577
+ indiv_info = []
578
+ current_parameter_arrays = []
579
+ N_vals_list = []
580
+ output_info = []
581
+ inds_list = []
582
+ inds_orig_list = []
583
+ bands_list = []
584
+ for N_now in unique_N:
585
+ N_now = N_now.item()
586
+ if N_now == 0: # or N_now != 1024:
587
+ continue
588
+
589
+ # old_data_1 = self.mgh.data_shaped[0][0][11].copy()
590
+ # checkit = np.where(new_state.supplimental[:]["overall_inds"] == 0)
591
+
592
+ # TODO; check the maximum allowable band
593
+ keep = (
594
+ (band_indices % units == remainder)
595
+ # & (temp_indices == 0) # & (walker_indices == 2) # & (band_indices < 50)
596
+ & (self.band_N_vals[band_indices] == N_now)
597
+ & (band_indices < len(self.band_edges) - 2)
598
+ # & (band_indices == 1)
599
+ ) # & (band_indices == 501) # # & (N_vals_in <= 256) & (temp_inds == checkit[0].item()) & (walker_inds == checkit[1].item()) # & (band_indices < 540) # & (temp_inds == 0) & (walker_inds == 0)
600
+
601
+ # if self.time > 0:
602
+ # keep[np.where(keep)[0][1:]] = False
603
+ """if self.time > 0:
604
+ keep = (
605
+ (band_indices % units == remainder)
606
+ & (band_indices == 346) & (walker_indices == 9) & (temp_indices == 0)
607
+ & (self.band_N_vals[band_indices] == N_now)
608
+ & (band_indices < len(self.band_edges) - 2)
609
+ ) # & (band_indices == 501) # # & (N_vals_in <= 256) & (temp_inds == checkit[0].item()) & (walker_inds == checkit[1].item()) # & (band_indices < 540) # & (temp_inds == 0) & (walker_inds == 0)
610
+ if np.any(keep):
611
+ breakpoint()"""
612
+
613
+ if keep.sum().item() == 0:
614
+ continue
615
+
616
+ total_keep += keep.sum().item()
617
+ # for testing
618
+ # keep[:] = False
619
+ # keep[tmp_keep] = True
620
+ # keep[3000:3020:1] = True
621
+
622
+ # permutate for rj just in case
623
+ permute_inds = xp.random.permutation(xp.arange(keep.sum()))
624
+ params_curr = points_curr[keep][permute_inds]
625
+ start_inds_here = start_inds_all[keep][permute_inds]
626
+ inds_here = gb_inds_in[keep][permute_inds]
627
+ factors_here = factors[keep][permute_inds]
628
+
629
+ prior_all_curr_here = self.gpu_priors["gb"].logpdf(params_curr, psds=self.mgh.psd_shaped[0][0], walker_inds=group_temp_finder[1][gb_inds][keep][permute_inds])
630
+ if not self.is_rj_prop:
631
+ assert xp.all(~xp.isinf(prior_all_curr_here))
632
+
633
+ temp_inds_here = temp_indices[keep][permute_inds]
634
+ walker_inds_here = walker_indices[keep][permute_inds]
635
+ leaf_inds_here = leaf_indices[keep][permute_inds]
636
+ band_inds_here = band_indices[keep][permute_inds]
637
+
638
+ special_band_inds = (temp_inds_here * nwalkers + walker_inds_here) * int(1e6) + band_inds_here
639
+
640
+ sort_special = xp.argsort(special_band_inds)
641
+
642
+ params_curr = params_curr[sort_special]
643
+ inds_here = inds_here[sort_special]
644
+ factors_here = factors_here[sort_special]
645
+ start_inds_here = start_inds_here[sort_special]
646
+ prior_all_curr_here = prior_all_curr_here[sort_special]
647
+ temp_inds_here = temp_inds_here[sort_special]
648
+ walker_inds_here = walker_inds_here[sort_special]
649
+ leaf_inds_here = leaf_inds_here[sort_special]
650
+ band_inds_here = band_inds_here[sort_special]
651
+ special_band_inds = special_band_inds[sort_special]
652
+
653
+ # assert np.all(np.sort(special_band_inds_sorted_here) == special_band_inds_sorted_here)
654
+
655
+ (
656
+ uni_special_band_inds_here,
657
+ uni_index_special_band_inds_here,
658
+ uni_count_special_band_inds_here,
659
+ ) = self.xp.unique(
660
+ special_band_inds, return_index=True, return_counts=True
661
+ )
662
+
663
+ band_start_bin_ind_here = uni_index_special_band_inds_here.astype(
664
+ np.int32
665
+ )
666
+
667
+ band_num_bins_here = uni_count_special_band_inds_here.astype(np.int32)
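The sort-then-unique pattern above is how contiguous per-band segments are located inside the flat binary arrays; a small self-contained NumPy version with invented data:

import numpy as np

temp   = np.array([0, 0, 0, 1, 1])
walker = np.array([0, 0, 1, 0, 0])
band   = np.array([7, 7, 3, 7, 9])
nwalkers = 2

key = (temp * nwalkers + walker) * 1_000_000 + band        # one integer per (temp, walker, band)
order = np.argsort(key)
key_sorted = key[order]
groups, starts, counts = np.unique(key_sorted, return_index=True, return_counts=True)
# key_sorted[starts[i] : starts[i] + counts[i]] covers every binary in group i
print(starts, counts)                                       # [0 2 3 4] [2 1 1 1]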
668
+
669
+ band_inds = band_inds_here[uni_index_special_band_inds_here]
670
+ band_temps_inds = temp_inds_here[uni_index_special_band_inds_here]
671
+ band_walkers_inds = walker_inds_here[uni_index_special_band_inds_here]
672
+
673
+ band_inv_temp_vals_here = band_temps[band_inds, band_temps_inds]
674
+
675
+ indiv_info.append((temp_inds_here, walker_inds_here, leaf_inds_here))
676
+
677
+ band_bookkeep_info.append(
678
+ (band_temps_inds, band_walkers_inds, band_inds)
679
+ )
680
+
681
+ data_index_here = (((band_inds + 1) % 2) * nwalkers + band_walkers_inds).astype(np.int32)
682
+ noise_index_here = (band_walkers_inds).astype(np.int32)
683
+
684
+ # for updates
685
+ update_data_index_here = (((band_inds + 0) % 2) * nwalkers + band_walkers_inds).astype(np.int32)
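The (band plus/minus parity) arithmetic above selects which of two stacked residual buffers a band reads from and writes into; neighbouring bands land in different buffers, which presumably keeps concurrent band updates from colliding. A hedged sketch of just the index arithmetic, with invented shapes:

import numpy as np

nwalkers = 4
band_inds   = np.array([10, 11, 12])
walker_inds = np.array([0, 1, 2])

data_index   = ((band_inds + 1) % 2) * nwalkers + walker_inds   # buffer of the opposite parity
update_index = ((band_inds + 0) % 2) * nwalkers + walker_inds   # buffer matching the band's parity
print(data_index, update_index)                                 # [4 1 6] [0 5 2]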
686
+
687
+ L_contribution_here = xp.zeros_like(band_inds, dtype=complex)
688
+ p_contribution_here = xp.zeros_like(band_inds, dtype=complex)
689
+
690
+ buffer = 5 # bins
691
+
692
+ start_band_index = ((self.band_edges[band_inds] / self.df).astype(int) - N_now - 1).astype(np.int32)
693
+ end_band_index = (((self.band_edges[band_inds + 1]) / self.df).astype(int) + N_now + 1).astype(np.int32)
694
+
695
+ start_band_index[start_band_index < 0] = 0
696
+ end_band_index[end_band_index >= len(self.fd)] = len(self.fd) - 1
697
+
698
+ band_lengths = end_band_index - start_band_index
699
+
700
+ start_interest_band_index = ((self.band_edges[band_inds] / self.df).astype(int)).astype(np.int32) # - N_now).astype(np.int32)
701
+ end_interest_band_index = (((self.band_edges[band_inds + 1]) / self.df).astype(int)).astype(np.int32) # + N_now).astype(np.int32)
702
+
703
+ start_interest_band_index[start_interest_band_index < 0] = 0
704
+ end_interest_band_index[end_interest_band_index >= len(self.fd)] = len(self.fd) - 1
705
+
706
+ band_interest_lengths = end_interest_band_index - start_interest_band_index
707
+
708
+ if False: # self.is_rj_prop:
709
+ fmin_allow = ((self.band_edges[band_inds] / self.df).astype(int) + 1) * self.df
710
+ fmax_allow = ((self.band_edges[band_inds + 1] / self.df).astype(int) - 1) * self.df
711
+
712
+ else:
713
+ # proposal limits in a band
714
+ fmin_allow = ((self.band_edges[band_inds] / self.df).astype(int)- (N_now / 2)) * self.df
715
+ fmax_allow = ((self.band_edges[band_inds + 1] / self.df).astype(int) + (N_now / 2 )) * self.df
716
+
717
+ # if self.is_rj_prop:
718
+ # breakpoint()
719
+ fmin_allow[fmin_allow < self.band_edges[0]] = self.band_edges[0]
720
+ fmax_allow[fmax_allow > self.band_edges[-1]] = self.band_edges[-1]
721
+
722
+ max_data_store_size = band_lengths.max().item()
723
+
724
+ num_bands_here = len(band_inds)
725
+
726
+ # makes in-model effectively not tempered
727
+ # if not self.is_rj_prop:
728
+ # band_inv_temp_vals_here[:] = 1.0
729
+
730
+ params_curr_separated = tuple([params_curr[:, i].copy() for i in range(params_curr.shape[1])])
731
+ params_curr_separated_orig = tuple([params_curr[:, i].copy() for i in range(params_curr.shape[1])])
732
+ params_extra_params = (
733
+ self.waveform_kwargs["T"],
734
+ self.waveform_kwargs["dt"],
735
+ N_now,
736
+ params_curr.shape[0],
737
+ self.start_freq_ind,
738
+ Soms_d_all["sangria"] ** (1/2),
739
+ Sa_a_all["sangria"] ** (1/2),
740
+ 1e-100, 0.0, 0.0, 0.0, 0.0, # foreground params for snr -> amp transform
741
+ )
742
+
743
+ gb_params_curr_in = self.gb.pyGalacticBinaryParams(
744
+ *(params_curr_separated + params_curr_separated_orig + params_extra_params)
745
+ )
746
+
747
+ current_parameter_arrays.append(params_curr_separated)
748
+
749
+ data_package = self.gb.pyDataPackage(
750
+ data[0][0],
751
+ data[1][0],
752
+ base_data[0][0],
753
+ base_data[1][0],
754
+ psd[0][0],
755
+ psd[1][0],
756
+ lisasens[0][0],
757
+ lisasens[1][0],
758
+ self.df,
759
+ self.data_length,
760
+ self.nwalkers * self.ntemps,
761
+ self.nwalkers * self.ntemps
762
+ )
763
+
764
+ loc_band_index = xp.arange(num_bands_here, dtype=xp.int32)
765
+ band_package = self.gb.pyBandPackage(
766
+ loc_band_index,
767
+ data_index_here,
768
+ noise_index_here,
769
+ band_start_bin_ind_here, # uni_index
770
+ band_num_bins_here, # uni_count
771
+ start_band_index,
772
+ band_lengths,
773
+ start_interest_band_index,
774
+ band_interest_lengths,
775
+ num_bands_here,
776
+ max_data_store_size,
777
+ fmin_allow,
778
+ fmax_allow,
779
+ update_data_index_here,
780
+ self.ntemps,
781
+ band_inds.astype(np.int32),
782
+ band_walkers_inds.astype(np.int32),
783
+ band_temps_inds.astype(np.int32),
784
+ xp.zeros_like(band_temps_inds, dtype=np.int32), # swaps proposed
785
+ xp.zeros_like(band_temps_inds, dtype=np.int32) # swaps accepted
786
+ )
787
+
788
+ accepted_out_here = xp.zeros_like(band_start_bin_ind_here, dtype=xp.int32)
789
+
790
+ mcmc_info = self.gb.pyMCMCInfo(
791
+ L_contribution_here,
792
+ p_contribution_here,
793
+ prior_all_curr_here,
794
+ accepted_out_here,
795
+ band_inv_temp_vals_here, # band_inv_temp_vals
796
+ self.is_rj_prop,
797
+ self.phase_maximize,
798
+ self.snr_lim,
799
+ )
800
+
801
+ if not self.is_rj_prop:
802
+ num_proposals_here = self.num_repeat_proposals
803
+
804
+ else:
805
+ num_proposals_here = 1
806
+
807
+ num_proposals_per_band = band_num_bins_here * num_proposals_here
808
+
809
+ assert start_inds_here.dtype == np.int32
810
+
811
+ proposal_info = self.gb.pyStretchProposalPackage(
812
+ *(self.stretch_friends_args_in + (self.nfriends, len(self.stretch_friends_args_in[0]), num_proposals_here, self.a, ndim, inds_here, factors_here, start_inds_here))
813
+ )
814
+
815
+ output_info.append([L_contribution_here, p_contribution_here, accepted_out_here, num_proposals_per_band])
816
+
817
+ inputs_now = (
818
+ data_package,
819
+ band_package,
820
+ gb_params_curr_in,
821
+ mcmc_info,
822
+ self.gpu_cuda_priors,
823
+ proposal_info,
824
+ self.gpu_cuda_wrap,
825
+ device,
826
+ do_synchronize,
827
+ )
828
+
829
+ # if self.is_rj_prop:
830
+ # prior_check = xp.zeros(params_curr.shape[0])
831
+ # self.gb.check_prior_vals(prior_check, self.gpu_cuda_priors, gb_params_curr_in, 100)
832
+ # prior_check2 = self.gpu_priors["gb"].logpdf(params_curr)
833
+ # breakpoint()
834
+
835
+
836
+ N_vals_list.append(N_now)
837
+ inds_list.append(inds_here)
838
+ inds_orig_list.append(inds_here.copy())
839
+ bands_list.append(band_inds_here)
840
+ all_inputs.append(inputs_now)
841
+ st = time.perf_counter()
842
+ # print(params_curr_separated[0].shape[0])
843
+ # if self.is_rj_prop:
844
+ # if self.time > 0:
845
+ # breakpoint()
846
+
847
+ # tmp_check = self.mgh.channel1_data[0][11 * self.data_length + 3911].real + self.mgh.channel1_data[0][29 * self.data_length + 3911].real - self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real
848
+ # print(f"BEFORE {tmp_check}, {self.mgh.channel1_data[0][11 * self.data_length + 3911].real} , {self.mgh.channel1_data[0][29 * self.data_length + 3911].real} , {self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real} ")
849
+ # print(f"BEFORE2 {params_curr_separated[1].min().item()} ")
850
+ self.gb.SharedMemoryMakeNewMove_wrap(*inputs_now)
851
+ self.xp.cuda.runtime.deviceSynchronize()
852
+ # tmp_check = self.mgh.channel1_data[0][11 * self.data_length + 3911].real + self.mgh.channel1_data[0][29 * self.data_length + 3911].real - self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real
853
+ # print(f"After {tmp_check}, {self.mgh.channel1_data[0][11 * self.data_length + 3911].real} , {self.mgh.channel1_data[0][29 * self.data_length + 3911].real} , {self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real} ")
854
+ # print(f"After2 {params_curr_separated[1].min().item()} ")
855
+
856
+ et = time.perf_counter()
857
+ # print(et - st, N_now)
858
+ # breakpoint()
859
+
860
+ self.xp.cuda.runtime.deviceSynchronize()
861
+ new_point_info = []
862
+ ll_diff_info = []
863
+ for (
864
+ band_info,
865
+ indiv_info_now,
866
+ N_now,
867
+ outputs,
868
+ current_parameters,
869
+ inds,
870
+ inds_prev,
871
+ bands
872
+ ) in zip(
873
+ band_bookkeep_info, indiv_info, N_vals_list, output_info, current_parameter_arrays, inds_list, inds_orig_list, bands_list
874
+ ):
875
+ ll_contrib_now = outputs[0]
876
+ lp_contrib_now = outputs[1]
877
+ accepted_now = outputs[2]
878
+ num_proposals_now = outputs[3]
879
+
880
+ per_walker_band_proposals[band_info] += num_proposals_now
881
+ per_walker_band_accepted[band_info] += accepted_now
882
+
883
+ # remove accepted
884
+ # print(accepted_now.sum(0) / accepted_now.shape[0])
885
+ temp_tmp, walker_tmp, leaf_tmp = (
886
+ indiv_info_now[0],
887
+ indiv_info_now[1],
888
+ indiv_info_now[2],
889
+ )
890
+
891
+ # updates related to newly added sources
892
+ if self.is_rj_prop:
893
+ gb_inds_orig_check = gb_inds_orig.copy()
894
+ gb_coords[(temp_tmp[inds], walker_tmp[inds], leaf_tmp[inds])] = xp.asarray(current_parameters).T[inds]
895
+
896
+ gb_inds_orig[(temp_tmp, walker_tmp, leaf_tmp)] = inds
897
+ new_state.branches_supplimental["gb"].holder["N_vals"][
898
+ (temp_tmp[inds].get(), walker_tmp[inds].get(), leaf_tmp[inds].get())
899
+ ] = N_now
900
+ new_state.branches_supplimental["gb"].holder["band_inds"][
901
+ (temp_tmp[inds].get(), walker_tmp[inds].get(), leaf_tmp[inds].get())
902
+ ] = bands[inds].get()
903
+
904
+ else:
905
+ gb_coords[(temp_tmp, walker_tmp, leaf_tmp)] = xp.asarray(current_parameters).T
906
+
907
+ new_band_inds = self.xp.searchsorted(self.band_edges, xp.asarray(current_parameters).T[:, 1] / 1e3, side="right") - 1
908
+
909
+ # if np.any(new_band_inds != bands):
910
+ # print(new_band_inds, bands)
911
+ ll_change = self.xp.zeros((ntemps, nwalkers, len(self.band_edges)))
912
+ lp_change = self.xp.zeros((ntemps, nwalkers, len(self.band_edges)))
913
+
914
+ self.xp.cuda.runtime.deviceSynchronize()
915
+
916
+ ll_change[band_info] = ll_contrib_now
917
+
918
+ ll_diff_info.append(ll_change.copy())
919
+
920
+ ll_adjustment = ll_change.sum(axis=-1)
921
+ log_like_tmp += ll_adjustment
922
+
923
+ self.xp.cuda.runtime.deviceSynchronize()
924
+
925
+ lp_change[band_info] = lp_contrib_now
926
+
927
+ lp_adjustment = lp_change.sum(axis=-1)
928
+ log_prior_tmp += lp_adjustment
929
+
930
+ self.xp.cuda.runtime.deviceSynchronize()
931
+ if True: # self.time > 0:
932
+ ll_after = (
933
+ self.mgh.get_ll(include_psd_info=True)
934
+ )
935
+ # print(np.abs(new_state.log_like - ll_after).max())
936
+ store_max_diff = np.abs(log_like_tmp[0].get() - ll_after).max()
937
+ # print("CHECKING in:", tmp, store_max_diff)
938
+ if store_max_diff > 3e-4:
939
+ print("LARGER ERROR:", store_max_diff)
940
+ self.check_ll_inject(new_state)
941
+ # self.mgh.get_ll(include_psd_info=True, stop=True)
942
+
943
+ new_state.branches["gb"].coords[:] = gb_coords.get()
944
+ if self.is_rj_prop:
945
+ new_state.branches["gb"].inds[:] = gb_inds_orig.get()
946
+ new_state.log_like[:] = log_like_tmp.get()
947
+ new_state.log_prior[:] = log_prior_tmp.get()
948
+
949
+ # get updated band inds ONLY FOR COLD CHAIN and
950
+ # propagate changes to higher temperatures
951
+ new_freqs = gb_coords[gb_inds_orig, 1]
952
+ new_band_inds = (xp.searchsorted(self.band_edges, new_freqs / 1e3, side="right") - 1)
953
+ new_state.branches["gb"].branch_supplimental.holder["band_inds"][gb_inds_orig.get()] = new_band_inds.get()
954
+
955
+ ll_after = (
956
+ self.mgh.get_ll(include_psd_info=True)
957
+ )
958
+ # print(np.abs(new_state.log_like - ll_after).max())
959
+ store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
960
+ # print("CHECKING 1:", store_max_diff, self.is_rj_prop)
961
+ # if self.time > 0:
962
+ # breakpoint()
963
+ # self.check_ll_inject(new_state)
964
+
965
+
966
+ if not self.is_rj_prop:
967
+ # if self.time > 0:
968
+ # breakpoint()
969
+ # check2 = self.mgh.get_ll()
970
+ old_band_inds_cold_chain = state.branches["gb"].branch_supplimental.holder["band_inds"][0] * state.branches["gb"].inds[0]
971
+ new_band_inds_cold_chain = new_state.branches["gb"].branch_supplimental.holder["band_inds"][0] * state.branches["gb"].inds[0]
972
+ inds_band_change_cold_chain = np.where(new_band_inds_cold_chain != old_band_inds_cold_chain)
973
+ # when adjusting temperatures, be careful here
974
+ if len(inds_band_change_cold_chain[0]) > 0:
975
+ check2 = self.mgh.get_ll(include_psd_info=True)
976
+ # print("SWITCH", len(inds_band_change_cold_chain[0]))
977
+ walker_inds_change_cold_chain = np.tile(inds_band_change_cold_chain[0], (self.ntemps - 1, 1)).flatten()
978
+ old_leaf_inds_change_cold_chain = np.tile(inds_band_change_cold_chain[1], (self.ntemps - 1, 1)).flatten()
979
+ new_temp_inds_change_cold_chain = np.repeat(np.arange(1, self.ntemps), len(inds_band_change_cold_chain[0]))
980
+
981
+ special_check = new_temp_inds_change_cold_chain * self.nwalkers + walker_inds_change_cold_chain
982
+
983
+ uni_special_check, uni_special_check_count = np.unique(special_check, return_counts=True)
984
+
985
+ # get new leaf positions
986
+ temp_leaves = np.ones_like(group_temp_finder[2].reshape(self.ntemps * self.nwalkers, -1).get()[uni_special_check], dtype=int) * (~new_state.branches["gb"].inds.reshape(self.ntemps * self.nwalkers, -1)[uni_special_check])
987
+ temp_leaves_2 = np.cumsum(temp_leaves, axis=-1)
988
+ temp_leaves_2[new_state.branches["gb"].inds.reshape(self.ntemps * self.nwalkers, -1)[uni_special_check]] = -1
989
+
990
+ leaf_guide_here = np.tile(np.arange(nleaves_max), (len(uni_special_check), 1))
991
+ new_leaf_inds_change_cold_chain = leaf_guide_here[((temp_leaves_2 >= 0) & (temp_leaves_2 <= uni_special_check_count[:, None]))]
992
+ try:
993
+ assert np.all(~new_state.branches["gb"].inds[new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain])
994
+ except IndexError:
995
+ breakpoint()
996
+
997
+ new_state.branches["gb"].inds[new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = True
998
+
999
+ new_state.branches["gb"].coords[new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = new_state.branches["gb"].coords[np.zeros_like(walker_inds_change_cold_chain), walker_inds_change_cold_chain, old_leaf_inds_change_cold_chain]
1000
+ new_state.branches["gb"].branch_supplimental.holder["band_inds"][new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = new_state.branches["gb"].branch_supplimental.holder["band_inds"][np.zeros_like(walker_inds_change_cold_chain), walker_inds_change_cold_chain, old_leaf_inds_change_cold_chain]
1001
+ new_state.branches["gb"].branch_supplimental.holder["N_vals"][new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = new_state.branches["gb"].branch_supplimental.holder["N_vals"][np.zeros_like(walker_inds_change_cold_chain), walker_inds_change_cold_chain, old_leaf_inds_change_cold_chain]
1002
+
1003
+ # adjust data
1004
+ adjust_binaries = xp.asarray(new_state.branches["gb"].coords[0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]])
1005
+ adjust_binaries_in = self.parameter_transforms.both_transforms(
1006
+ adjust_binaries, xp=xp
1007
+ )
1008
+ adjust_walker_inds = inds_band_change_cold_chain[0]
1009
+ adjust_band_new = new_state.branches["gb"].branch_supplimental.holder["band_inds"][0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]]
1010
+ adjust_band_old = state.branches["gb"].branch_supplimental.holder["band_inds"][0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]]
1011
+ # N_vals_in = new_state.branches["gb"].branch_supplimental.holder["N_vals"][0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]]
1012
+
1013
+ adjust_binaries_in_in = xp.concatenate([
1014
+ adjust_binaries_in,
1015
+ adjust_binaries_in
1016
+ ], axis=0)
1017
+
1018
+ data_index = xp.concatenate([
1019
+ xp.asarray(((adjust_band_old + 0) % 2) * nwalkers + adjust_walker_inds),
1020
+ xp.asarray(((adjust_band_new + 0) % 2) * nwalkers + adjust_walker_inds)
1021
+ ]).astype(xp.int32)
1022
+
1023
+ N_vals_in_in = xp.concatenate([
1024
+ xp.asarray(self.band_N_vals[adjust_band_old]),
1025
+ xp.asarray(self.band_N_vals[adjust_band_new])
1026
+ ])
1027
+
1028
+ factors = xp.concatenate([
1029
+ +xp.ones_like(adjust_band_old, dtype=xp.float64), # remove
1030
+ -xp.ones_like(adjust_band_old, dtype=xp.float64) # add
1031
+ ])
1032
+
1033
+ self.gb.generate_global_template(
1034
+ adjust_binaries_in_in,
1035
+ data_index,
1036
+ self.mgh.data_list,
1037
+ N=N_vals_in_in,
1038
+ factors=factors,
1039
+ data_length=self.mgh.data_length,
1040
+ data_splits=self.mgh.gpu_splits,
1041
+ **waveform_kwargs_now
1042
+ )
1043
+ check3 = self.mgh.get_ll(include_psd_info=True)
1044
+
1045
+ # print(check3 - check2)
1046
+
1047
+ if np.any(
1048
+ self.band_N_vals[adjust_band_old] !=
1049
+ self.band_N_vals[adjust_band_new]
1050
+ ):
1051
+ walkers_focus = np.unique(inds_band_change_cold_chain[0][
1052
+ self.band_N_vals.get()[adjust_band_old] !=
1053
+ self.band_N_vals.get()[adjust_band_new]
1054
+ ])
1055
+ # print(f"specific change in N across boundary: {walkers_focus}")
1056
+ new_state.log_like[0, walkers_focus] = self.mgh.get_ll(include_psd_info=True)[walkers_focus]
1057
+
1058
+ # self.check_ll_inject(new_state)
1059
+ # breakpoint()
1060
+ ll_after = (
1061
+ self.mgh.get_ll(include_psd_info=True)
1062
+ )
1063
+ # print(np.abs(new_state.log_like - ll_after).max())
1064
+ store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
1065
+ # print("CHECKING 2:", store_max_diff, self.is_rj_prop)
1066
+
1067
+ self.mempool.free_all_blocks()
1068
+ # get accepted fraction
1069
+ if not self.is_rj_prop:
1070
+ accepted_check_tmp = np.zeros_like(
1071
+ state.branches_inds["gb"], dtype=bool
1072
+ )
1073
+ accepted_check_tmp[state.branches_inds["gb"]] = np.all(
1074
+ np.abs(
1075
+ new_state.branches_coords["gb"][
1076
+ state.branches_inds["gb"]
1077
+ ]
1078
+ - state.branches_coords["gb"][state.branches_inds["gb"]]
1079
+ )
1080
+ > 0.0,
1081
+ axis=-1,
1082
+ )
1083
+ proposed = gb_inds.get()
1084
+ accepted_check = accepted_check_tmp.sum(
1085
+ axis=(1, 2)
1086
+ ) / proposed.sum(axis=(1, 2))
1087
+ else:
1088
+ accepted_check_tmp = (
1089
+ new_state.branches_inds["gb"] == (~state.branches_inds["gb"])
1090
+ )
1091
+
1092
+ proposed = gb_inds.get()
1093
+ accepted_check = accepted_check_tmp.sum(axis=(1, 2)) / proposed.sum(axis=(1, 2))
1094
+
1095
+ # manually tell the temperature control what the real overall acceptance fraction is
1096
+ number_of_walkers_for_accepted = np.floor(nwalkers * accepted_check).astype(int)
1097
+
1098
+ accepted_inds = np.tile(np.arange(nwalkers), (ntemps, 1))
1099
+
1100
+ accepted = np.zeros((ntemps, nwalkers), dtype=bool)
1101
+ accepted[accepted_inds < number_of_walkers_for_accepted[:, None]] = True
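The block above converts each temperature's overall acceptance fraction into a per-walker boolean mask so the temperature controller sees a realistic acceptance count; a NumPy sketch with invented fractions:

import numpy as np

ntemps, nwalkers = 3, 10
accepted_fraction = np.array([0.35, 0.20, 0.05])             # per temperature, invented

n_acc = np.floor(nwalkers * accepted_fraction).astype(int)
walker_grid = np.tile(np.arange(nwalkers), (ntemps, 1))
accepted = walker_grid < n_acc[:, None]                      # flag the first n_acc walkers per temperature
print(accepted.sum(axis=1))                                  # [3 2 0]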
1102
+
1103
+ tmp1 = np.all(
1104
+ np.abs(
1105
+ new_state.branches_coords["gb"]
1106
+ - state.branches_coords["gb"]
1107
+ )
1108
+ > 0.0,
1109
+ axis=-1,
1110
+ ).sum(axis=(2,))
1111
+ tmp2 = new_state.branches_inds["gb"].sum(axis=(2,))
1112
+
1113
+ # add to move-specific accepted information
1114
+ self.accepted += tmp1
1115
+ if isinstance(self.num_proposals, int):
1116
+ self.num_proposals = tmp2
1117
+ else:
1118
+ self.num_proposals += tmp2
1119
+
1120
+ new_inds = xp.asarray(new_state.branches_inds["gb"])
1121
+
1122
+ # in-model inds will not change
1123
+ tmp_freqs_find_bands = xp.asarray(new_state.branches_coords["gb"][:, :, :, 1])
1124
+
1125
+ # calculate current band counts
1126
+ band_here = (xp.searchsorted(self.band_edges, tmp_freqs_find_bands.flatten() / 1e3, side="right") - 1).reshape(tmp_freqs_find_bands.shape)
1127
+
1128
+ # get binaries per band
1129
+ special_band_here_num_per_band = ((group_temp_finder[0] * nwalkers + group_temp_finder[1]) * int(1e6) + band_here)[new_inds]
1130
+ unique_special_band_here_num_per_band, unique_special_band_here_num_per_band_count = xp.unique(special_band_here_num_per_band, return_counts=True)
1131
+ temp_walker_index_num_per_band = (unique_special_band_here_num_per_band / 1e6).astype(int)
1132
+ temp_index_num_per_band = (temp_walker_index_num_per_band / nwalkers).astype(int)
1133
+ walker_index_num_per_band = temp_walker_index_num_per_band - temp_index_num_per_band * nwalkers
1134
+ band_index_num_per_band = unique_special_band_here_num_per_band - temp_walker_index_num_per_band * int(1e6)
1135
+
1136
+ per_walker_band_counts = xp.zeros((ntemps, nwalkers, self.num_bands), dtype=int)
1137
+ per_walker_band_counts[temp_index_num_per_band, walker_index_num_per_band, band_index_num_per_band] = unique_special_band_here_num_per_band_count
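Recovering the (temperature, walker, band) triple is just the inverse of the composite-key encoding used throughout; integer division is equivalent to the float divide-and-cast used above at these magnitudes. A short check:

import numpy as np

nwalkers = 2
key = np.array([
    (0 * nwalkers + 1) * 1_000_000 + 7,       # temp 0, walker 1, band 7
    (1 * nwalkers + 0) * 1_000_000 + 3,       # temp 1, walker 0, band 3
])

tw     = key // 1_000_000
temp   = tw // nwalkers
walker = tw - temp * nwalkers
band   = key - tw * 1_000_000
print(temp, walker, band)                      # [0 1] [1 0] [7 3]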
1138
+
1139
+ # TEMPERING
1140
+ self.temperature_control.swaps_accepted = np.zeros(ntemps - 1)
1141
+ self.temperature_control.swaps_proposed = np.zeros(ntemps - 1)
1142
+
1143
+ band_swaps_accepted = np.zeros((len(self.band_edges) - 1, self.ntemps - 1), dtype=int)
1144
+ band_swaps_proposed = np.zeros((len(self.band_edges) - 1, self.ntemps - 1), dtype=int)
1145
+ current_band_counts = np.zeros((len(self.band_edges) - 1, self.ntemps), dtype=int)
1146
+
1147
+ # if self.is_rj_prop:
1148
+ # print("1st count check:", new_state.branches["gb"].inds.sum(axis=-1).mean(axis=-1), "\nll:", new_state.log_like[0] - orig_store, new_state.log_like[0])
1149
+
1150
+ # if self.time > 0:
1151
+ # self.check_ll_inject(new_state)
1152
+ if (
1153
+ self.temperature_control is not None
1154
+ and self.time % 1 == 0
1155
+ and self.ntemps > 1
1156
+ and self.is_rj_prop
1157
+ # and False
1158
+ ):
1159
+
1160
+ new_coords_after_tempering = xp.asarray(np.zeros_like(new_state.branches["gb"].coords))
1161
+ new_inds_after_tempering = xp.asarray(np.zeros_like(new_state.branches["gb"].inds))
1162
+
1163
+ # TODO: check if N changes / need to deal with that
1164
+ betas = self.temperature_control.betas
1165
+
1166
+ # cannot recompute these directly because higher temperatures can move across band edges / need the supplimental band inds
1167
+ band_inds_temp = new_state.branches["gb"].branch_supplimental.holder["band_inds"][new_state.branches["gb"].inds]
1168
+ temp_inds_temp = group_temp_finder[0].get()[new_state.branches["gb"].inds]
1169
+ walker_inds_temp = group_temp_finder[1].get()[new_state.branches["gb"].inds]
1170
+
1171
+ bands_guide = np.tile(np.arange(self.num_bands), (self.ntemps, self.nwalkers, 1)).transpose(2, 1, 0).flatten()
1172
+ temps_guide = np.repeat(np.arange(self.ntemps)[:, None], self.nwalkers * self.num_bands).reshape(self.ntemps, self.nwalkers, self.num_bands).transpose(2, 1, 0).flatten()
1173
+ walkers_guide = np.repeat(np.arange(self.nwalkers)[:, None], self.ntemps * self.num_bands).reshape(self.nwalkers, self.ntemps, self.num_bands).transpose(2, 0, 1).flatten()
1174
+
1175
+ walkers_permuted = np.asarray([np.random.permutation(np.arange(self.nwalkers)) for _ in range(self.ntemps * self.num_bands)]).reshape(self.num_bands, self.ntemps, self.nwalkers).transpose(0, 2, 1).flatten()
1176
+
1177
+ # special_inds_guide = bands_guide * int(1e6) + walkers_permuted * int(1e3) + temps_guide
1178
+
1179
+ coords_in = new_state.branches["gb"].coords[new_state.branches["gb"].inds]
1180
+
1181
+ # N_vals_in = new_state.branches["gb"].branch_supplimental.holder["N_vals"][new_state.branches["gb"].inds]
1182
+ unique_N = np.unique(self.band_N_vals).get()
1183
+ # remove 0
1184
+ unique_N = unique_N[unique_N != 0]
1185
+
1186
+ bands_in = band_inds_temp
1187
+ temps_in = temp_inds_temp
1188
+ walkers_in = walker_inds_temp
1189
+
1190
+ main_gpu = self.xp.cuda.runtime.getDevice()
1191
+
1192
+ walkers_info = walkers_permuted if self.temperature_control.permute else walkers_guide
1193
+ units = 2
1194
+ start_unit = np.random.randint(0, units)
1195
+ for unit in range(units):
1196
+ current_band_remainder = (start_unit + unit) % units
1197
+ for N_now in unique_N:
1198
+ keep = (bands_in % units == current_band_remainder) & (self.band_N_vals[xp.asarray(bands_in)].get() == N_now)
1199
+ keep_base = (bands_guide % units == current_band_remainder) & (self.band_N_vals[xp.asarray(bands_guide)].get() == N_now)
1200
+ bands_in_tmp = bands_in[keep]
1201
+ temps_in_tmp = temps_in[keep]
1202
+ walkers_in_tmp = walkers_in[keep]
1203
+ coords_in_tmp = coords_in[keep]
1204
+
1205
+ # adjust for N_val
1206
+ special_inds_bins = bands_in_tmp * int(1e8) + walkers_in_tmp * int(1e4) + temps_in_tmp
1207
+
1208
+ bands_guide_keep = bands_guide[keep_base]
1209
+ temps_guide_keep = temps_guide[keep_base]
1210
+ walkers_info_keep = walkers_info[keep_base]
1211
+
1212
+ special_inds_guide_keep = bands_guide_keep * int(1e8) + walkers_info_keep * int(1e4) + temps_guide_keep
1213
+ sort_tmp = np.argsort(special_inds_guide_keep)
1214
+ sorted_special_inds_guide_keep = special_inds_guide_keep[sort_tmp]
1215
+ sorting_info_need_to_revert = np.searchsorted(sorted_special_inds_guide_keep, special_inds_bins, side="left")
1216
+
1217
+ sorting_info = sort_tmp[sorting_info_need_to_revert]
1218
+
1219
+ assert np.all(special_inds_guide_keep[sorting_info] == special_inds_bins)
1220
+
1221
+ sort_bins = np.argsort(sorting_info)
1222
+
1223
+ # p1 = self.mgh.psd_shaped[0][0][8, 19748:20132]
1224
+ # p2 = self.mgh.psd_shaped[1][0][8, 19748:20132]
1225
+ # c2 = self.mgh.data_shaped[1][0][26, 19748:20132] + self.mgh.data_shaped[1][0][8, 19748:20132]- self.mgh.channel2_base_data[0][8 * self.data_length + 19748:8 * self.data_length + 20132]
1226
+ # c1 = self.mgh.data_shaped[0][0][26, 19748:20132] + self.mgh.data_shaped[0][0][8, 19748:20132]- self.mgh.channel1_base_data[0][8 * self.data_length + 19748:8 * self.data_length + 20132]
1227
+ # args_sorting = np.arange(len(sorting_info)) # np.argsort(sorting_info)
1228
+ bands_in_tmp = bands_in_tmp[sort_bins]
1229
+ temps_in_tmp = temps_in_tmp[sort_bins]
1230
+ walkers_in_tmp = walkers_in_tmp[sort_bins]
1231
+ coords_in_tmp = coords_in_tmp[sort_bins]
1232
+ sorting_info = sorting_info[sort_bins]
1233
+
1234
+ uni_sorted, uni_index, uni_inverse, uni_counts = xp.unique(xp.asarray(sorting_info), return_index=True, return_counts=True, return_inverse=True)
1235
+
1236
+ band_info_start_index_bin = xp.zeros_like(bands_guide_keep, dtype=np.int32)
1237
+ band_info_num_bin = xp.zeros_like(bands_guide_keep, dtype=np.int32)
1238
+
1239
+ band_info_start_index_bin[uni_sorted.astype(np.int32)] = uni_index.astype(np.int32)
1240
+ band_info_num_bin[uni_sorted.astype(np.int32)] = uni_counts.astype(np.int32)
1241
+
1242
+ params_curr_separated = tuple([xp.asarray(coords_in_tmp[:, i].copy()) for i in range(coords_in_tmp.shape[1])])
1243
+ params_curr_separated_orig = tuple([xp.asarray(coords_in_tmp[:, i].copy()) for i in range(coords_in_tmp.shape[1])])
1244
+ params_extra_params = (
1245
+ self.waveform_kwargs["T"],
1246
+ self.waveform_kwargs["dt"],
1247
+ N_now,
1248
+ coords_in.shape[0],
1249
+ self.start_freq_ind,
1250
+ Soms_d_all["sangria"] ** (1/2),
1251
+ Sa_a_all["sangria"] ** (1/2),
1252
+ 1e-100, 0.0, 0.0, 0.0, 0.0, # foreground params for snr -> amp transform
1253
+ )
1254
+
1255
+ gb_params_curr_in = self.gb.pyGalacticBinaryParams(
1256
+ *(params_curr_separated + params_curr_separated_orig + params_extra_params)
1257
+ )
1258
+
1259
+ # current_parameter_arrays.append(params_curr_separated)
1260
+
1261
+ data_package = self.gb.pyDataPackage(
1262
+ data[0][0],
1263
+ data[1][0],
1264
+ base_data[0][0],
1265
+ base_data[1][0],
1266
+ psd[0][0],
1267
+ psd[1][0],
1268
+ lisasens[0][0],
1269
+ lisasens[1][0],
1270
+ self.df,
1271
+ self.data_length,
1272
+ self.nwalkers * self.ntemps,
1273
+ self.nwalkers * self.ntemps
1274
+ )
1275
+
1276
+ data_index_here = xp.asarray((((bands_guide_keep + 1) % 2) * self.nwalkers + walkers_info_keep)).astype(np.int32)
1277
+ noise_index_here = xp.asarray(walkers_info_keep.copy()).astype(np.int32)
1278
+ update_data_index_here = xp.asarray((((bands_guide_keep + 0) % 2) * self.nwalkers + walkers_info_keep)).astype(np.int32)
1279
+
1280
+ start_band_index = ((self.band_edges[bands_guide_keep] / self.df).astype(int) - N_now - 1).astype(np.int32)
1281
+ end_band_index = (((self.band_edges[bands_guide_keep + 1]) / self.df).astype(int) + N_now + 1).astype(np.int32)
1282
+
1283
+ start_band_index[start_band_index < 0] = 0
1284
+ end_band_index[end_band_index >= len(self.fd)] = len(self.fd) - 1
1285
+
1286
+ band_lengths = end_band_index - start_band_index
1287
+
1288
+ start_interest_band_index = ((self.band_edges[bands_guide_keep] / self.df).astype(int)).astype(np.int32) # - N_now).astype(np.int32)
1289
+ end_interest_band_index = (((self.band_edges[bands_guide_keep + 1]) / self.df).astype(int)).astype(np.int32) # + N_now).astype(np.int32)
1290
+
1291
+ start_interest_band_index[start_interest_band_index < 0] = 0
1292
+ end_interest_band_index[end_interest_band_index >= len(self.fd)] = len(self.fd) - 1
1293
+
1294
+ band_interest_lengths = end_interest_band_index - start_interest_band_index
1295
+
1296
+
1297
+ # proposal limits in a band
1298
+ fmin_allow = xp.asarray(((self.band_edges[bands_guide_keep] / self.df).astype(int) - (N_now / 2))) * self.df
1299
+ fmax_allow = xp.asarray(((self.band_edges[bands_guide_keep + 1] / self.df).astype(int) + (N_now / 2))) * self.df
1300
+
1301
+ fmin_allow[fmin_allow < self.band_edges[0]] = self.band_edges[0]
1302
+ fmax_allow[fmax_allow > self.band_edges[-1]] = self.band_edges[-1]
1303
+
1304
+ max_data_store_size = band_lengths.max().item()
1305
+
1306
+ num_bands_here = len(band_info_start_index_bin)
1307
+
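+ # one swap setup per (band, walker) pair; the counters track proposed/accepted swaps for each of the ntemps - 1 adjacent-temperature pairs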
1308
+ num_swap_setups = np.unique(bands_guide_keep).shape[0] * self.nwalkers
1309
+
1310
+ swaps_proposed_here = xp.zeros(num_swap_setups * (self.ntemps - 1), dtype=np.int32)
1311
+ swaps_accepted_here = xp.zeros(num_swap_setups * (self.ntemps - 1), dtype=np.int32)
1312
+
1313
+ bands_guide_keep = xp.asarray(bands_guide_keep).astype(np.int32)
1314
+ walkers_info_keep = xp.asarray(walkers_info_keep).astype(np.int32)
1315
+ temps_guide_keep = xp.asarray(temps_guide_keep).astype(np.int32)
1316
+ loc_band_index = xp.arange(num_bands_here, dtype=xp.int32)
1317
+
1318
+ before_loc_band_index = loc_band_index.copy()
1319
+ before_band_info_start_index_bin = band_info_start_index_bin.copy()
1320
+ before_band_info_num_bin = band_info_num_bin.copy()
1321
+
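+ # per-band bookkeeping (binary start/count, bin ranges, buffer indices, allowed frequency limits, swap counters) packaged for the kernel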
1322
+ band_package = self.gb.pyBandPackage(
1323
+ loc_band_index,
1324
+ data_index_here,
1325
+ noise_index_here,
1326
+ band_info_start_index_bin, # uni_index
1327
+ band_info_num_bin, # uni_count
1328
+ start_band_index,
1329
+ band_lengths,
1330
+ start_interest_band_index,
1331
+ band_interest_lengths,
1332
+ num_bands_here,
1333
+ max_data_store_size,
1334
+ fmin_allow,
1335
+ fmax_allow,
1336
+ update_data_index_here,
1337
+ self.ntemps,
1338
+ bands_guide_keep,
1339
+ walkers_info_keep,
1340
+ temps_guide_keep,
1341
+ swaps_proposed_here,
1342
+ swaps_accepted_here
1343
+ )
1344
+
1345
+ band_inv_temp_vals_here = band_temps[bands_guide_keep, temps_guide_keep]
1346
+
1347
+ accepted_out_here = xp.zeros_like(bands_guide_keep, dtype=xp.int32)
1348
+ L_contribution_here = xp.zeros_like(bands_guide_keep, dtype=complex)
1349
+ p_contribution_here = xp.zeros_like(bands_guide_keep, dtype=complex)
1350
+ prior_all_curr_here = xp.zeros_like(bands_guide_keep, dtype=np.float64)
1351
+
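+ # outputs the kernel fills in: per-band likelihood and prior contributions, acceptance flags, and the per-band inverse temperatures used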
1352
+ mcmc_info = self.gb.pyMCMCInfo(
1353
+ L_contribution_here,
1354
+ p_contribution_here,
1355
+ prior_all_curr_here,
1356
+ accepted_out_here,
1357
+ band_inv_temp_vals_here, # band_inv_temp_vals
1358
+ self.is_rj_prop,
1359
+ False, # phase maximization (off)
1360
+ self.snr_lim,
1361
+ )
1362
+
1363
+ inds_here = xp.ones_like(temps_in_tmp, dtype=bool)
1364
+ factors_here = xp.zeros_like(temps_in_tmp, dtype=float)
1365
+ start_inds_here = xp.zeros_like(temps_in_tmp, dtype=np.int32)
1366
+
1367
+ num_proposals_here = 1
1368
+
1369
+ assert start_inds_here.dtype == np.int32
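+ # stretch-proposal settings: friend coordinates, number of friends, scale parameter a, and the per-binary inds/factors/start-index arrays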
1370
+ proposal_info = self.gb.pyStretchProposalPackage(
1371
+ *(self.stretch_friends_args_in + (self.nfriends, len(self.stretch_friends_args_in[0]), num_proposals_here, self.a, ndim, inds_here, factors_here, start_inds_here))
1372
+ )
1373
+
1374
+ inputs_now = (
1375
+ data_package,
1376
+ band_package,
1377
+ gb_params_curr_in,
1378
+ mcmc_info,
1379
+ self.gpu_cuda_priors,
1380
+ proposal_info,
1381
+ self.gpu_cuda_wrap,
1382
+ num_swap_setups,
1383
+ device,
1384
+ do_synchronize,
1385
+ -1,
1386
+ 100000
1387
+ )
1388
+
1389
+ self.gb.SharedMemoryMakeTemperingMove_wrap(*inputs_now)
1390
+
1391
+ self.xp.cuda.runtime.deviceSynchronize()
1392
+
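+ # expand the per-band results back to per-binary entries so accepted coordinates can be scattered into the (temp, walker, leaf) arrays below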
1393
+ walkers_info_keep_per_bin = xp.repeat(walkers_info_keep, list(band_info_num_bin.get()))
1394
+ temps_guide_keep_per_bin = xp.repeat(temps_guide_keep, list(band_info_num_bin.get()))
1395
+ start_bin_ind_mapping = xp.repeat(band_info_start_index_bin, list(band_info_num_bin.get()))
1396
+ uni_after, uni_after_start_index, uni_after_inverse = xp.unique(start_bin_ind_mapping, return_index=True, return_inverse=True)
1397
+ ind_map = xp.arange(len(start_bin_ind_mapping)) - xp.arange(len(start_bin_ind_mapping))[uni_after_start_index][uni_after_inverse] + start_bin_ind_mapping
1398
+
1399
+ # out = []
1400
+ # for i in range(len(band_info_num_bin_has)):
1401
+ # num_bins = band_info_num_bin_has[i].item()
1402
+ # start_index = band_info_start_index_bin_has[i].item()
1403
+ # for j in range(num_bins):
1404
+ # out.append(start_index + j)
1405
+
1406
+ # breakpoint()
1407
+ new_coords_mapped = coords_in_tmp[ind_map.get()]
1408
+
1409
+ special_after_tempering = temps_guide_keep_per_bin * self.nwalkers + walkers_info_keep_per_bin
1410
+
1411
+ sorted_after = xp.argsort(special_after_tempering)
1412
+ special_after_tempering = special_after_tempering[sorted_after]
1413
+ walkers_info_keep_per_bin = walkers_info_keep_per_bin[sorted_after]
1414
+ temps_guide_keep_per_bin = temps_guide_keep_per_bin[sorted_after]
1415
+ new_coords_mapped = new_coords_mapped[sorted_after.get()]
1416
+
1417
+ uni_after, uni_index_after, uni_inverse_after = xp.unique(special_after_tempering, return_index=True, return_inverse=True)
1418
+
1419
+ relative_leaf_info_after_tempering = xp.arange(len(special_after_tempering)) - uni_index_after[uni_inverse_after]
1420
+
1421
+ leaf_start_per_walker = new_inds_after_tempering.argmin(axis=-1)[temps_guide_keep_per_bin[uni_index_after], walkers_info_keep_per_bin[uni_index_after]]
1422
+
1423
+ absolute_leaf_info_after_tempering = leaf_start_per_walker[uni_inverse_after] + relative_leaf_info_after_tempering
1424
+
1425
+ assert (np.all(~new_inds_after_tempering[temps_guide_keep_per_bin, walkers_info_keep_per_bin, absolute_leaf_info_after_tempering]))
1426
+ new_coords_after_tempering[temps_guide_keep_per_bin, walkers_info_keep_per_bin, absolute_leaf_info_after_tempering] = new_coords_mapped
1427
+ new_inds_after_tempering[temps_guide_keep_per_bin, walkers_info_keep_per_bin, absolute_leaf_info_after_tempering] = True
1428
+
1429
+ bands_unique_here = np.unique(bands_guide_keep).get()
1430
+ band_swaps_accepted[bands_unique_here] += swaps_accepted_here.reshape(bands_unique_here.shape[0], self.nwalkers, self.ntemps - 1).sum(axis=1).get()
1431
+ band_swaps_proposed[bands_unique_here] += swaps_proposed_here.reshape(bands_unique_here.shape[0], self.nwalkers, self.ntemps - 1).sum(axis=1).get()
1432
+
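+ # fold the per-band likelihood contributions into the running per-walker log-likelihood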
1433
+ ll_change = self.xp.zeros((ntemps, nwalkers, len(self.band_edges)))
1434
+
1435
+ self.xp.cuda.runtime.deviceSynchronize()
1436
+
1437
+ ll_change[(temps_guide_keep,walkers_info_keep,bands_guide_keep)] = L_contribution_here
1438
+
1439
+ ll_adjustment = ll_change.sum(axis=-1)
1440
+ log_like_tmp += ll_adjustment
1441
+
1442
+ new_state.branches["gb"].coords[:] = new_coords_after_tempering.get()
1443
+ new_state.branches["gb"].inds[:] = new_inds_after_tempering.get()
1444
+ new_state.log_like[:] = log_like_tmp.get()
1445
+
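+ # recompute each active binary's band index from its updated frequency and store it in the branch supplemental holder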
1446
+ new_freqs = new_coords_after_tempering[new_inds_after_tempering, 1]
1447
+ new_band_inds = (xp.searchsorted(self.band_edges, new_freqs / 1e3, side="right") - 1)
1448
+ new_state.branches["gb"].branch_supplimental.holder["band_inds"][new_inds_after_tempering.get()] = new_band_inds.get()
1449
+
1450
+ # breakpoint()
1451
+ # self.check_ll_inject(new_state)
1452
+ # breakpoint()
1453
+
1454
+ # adjust priors accordingly
1455
+ log_prior_new_per_bin = xp.zeros_like(
1456
+ new_state.branches_inds["gb"], dtype=xp.float64
1457
+ )
1458
+ # self.gpu_priors
1459
+ log_prior_new_per_bin[
1460
+ new_state.branches_inds["gb"]
1461
+ ] = self.gpu_priors["gb"].logpdf(
1462
+ xp.asarray(
1463
+ new_state.branches_coords["gb"][
1464
+ new_state.branches_inds["gb"]
1465
+ ]
1466
+ ),
1467
+ psds=self.mgh.psd_shaped[0][0],
1468
+ walker_inds=group_temp_finder[1].get()[new_state.branches_inds["gb"]]
1469
+ )
1470
+
1471
+ new_state.log_prior = log_prior_new_per_bin.sum(axis=-1).get()
1472
+
1473
+ ratios = (band_swaps_accepted / band_swaps_proposed).T # self.swaps_accepted / self.swaps_proposed
1474
+ ratios[np.isnan(ratios)] = 0.0
1475
+
1476
+ # only change those with a binary in them
1477
+
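+ # per-band ladder adaptation: nudge inverse temperatures toward equal swap acceptance between neighboring rungs, with a step that decays over time (mirroring the standard adaptive parallel-tempering update)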
1478
+ # adapt if desired
1479
+ if self.time > 50:
1480
+ betas0 = band_temps.copy().T.get()
1481
+ betas1 = betas0.copy()
1482
+
1483
+ # Modulate temperature adjustments with a hyperbolic decay.
1484
+ decay = self.temperature_control.adaptation_lag / (self.time + self.temperature_control.adaptation_lag)
1485
+ kappa = decay / self.temperature_control.adaptation_time
1486
+
1487
+ # Construct temperature adjustments.
1488
+ dSs = kappa * (ratios[:-1] - ratios[1:])
1489
+
1490
+ # Compute new ladder (hottest and coldest chains don't move).
1491
+ deltaTs = np.diff(1 / betas1[:-1], axis=0)
1492
+
1493
+ deltaTs *= np.exp(dSs)
1494
+ betas1[1:-1] = 1 / (np.cumsum(deltaTs, axis=0) + 1 / betas1[0])
1495
+
1496
+ # Compute the ladder change and apply it to the per-band temperatures (the controller's global betas are left untouched).
1497
+ dbetas = betas1 - betas0
1498
+
1499
+ band_temps += self.xp.asarray(dbetas.T)
1500
+
1501
+ # band_temps[:] = band_temps[553, :][None, :]
1502
+
1503
+ # store the (unchanged) global temperature ladder in the new state
1504
+ new_state.betas = self.temperature_control.betas.copy()
1505
+
1506
+ self.mempool.free_all_blocks()
1507
+ # print(
1508
+ # self.is_rj_prop,
1509
+ # band_swaps_accepted[350] / band_swaps_proposed[350],
1510
+ # band_swaps_accepted[450] / band_swaps_proposed[450],
1511
+ # band_swaps_accepted[501] / band_swaps_proposed[501]
1512
+ # )
1513
+
1514
+ self.mempool.free_all_blocks()
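+ # consistency check: recompute the log-likelihood directly from the residuals and compare it with the bookkept value; if the mismatch is large, rebuild the residuals and reset the stored value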
1515
+ if self.time % 1 == 0:
1516
+ ll_after = (
1517
+ self.mgh.get_ll(include_psd_info=True)
1518
+ )
1519
+ # print(np.abs(new_state.log_like - ll_after).max())
1520
+ store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
1521
+ # print("CHECKING:", store_max_diff, self.is_rj_prop)
1522
+ if store_max_diff > 1e-5:
1523
+ ll_after = (
1524
+ self.mgh.get_ll(include_psd_info=True, stop=True)
1525
+ )
1526
+
1527
+ if store_max_diff > 1.0:
1528
+ breakpoint()
1529
+
1530
+ # reset data and fix likelihood
1531
+ new_state.log_like[0] = self.check_ll_inject(new_state)
1532
+
1533
+
1534
+ self.time += 1
1535
+ # self.xp.cuda.runtime.deviceSynchronize()
1536
+
1537
+ new_state.update_band_information(
1538
+ band_temps.get(), per_walker_band_proposals.sum(axis=1).get().T, per_walker_band_accepted.sum(axis=1).get().T, band_swaps_proposed, band_swaps_accepted,
1539
+ per_walker_band_counts.get(), self.is_rj_prop
1540
+ )
1541
+ # TODO: check rj numbers
1542
+
1543
+ # new_state.log_like[:] = self.check_ll_inject(new_state)
1544
+
1545
+ self.mempool.free_all_blocks()
1546
+
1547
+ if self.is_rj_prop:
1548
+ pass # print(self.name, "2nd count check:", new_state.branches["gb"].inds.sum(axis=-1).mean(axis=-1), "\nll:", new_state.log_like[0] - orig_store, new_state.log_like[0])
1549
+ return new_state, accepted
1550
+
1551
+ def check_ll_inject(self, new_state):
1552
+
1553
+ check_ll = self.mgh.get_ll(include_psd_info=True).copy()
1554
+
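+ # reset both data copies to the base data, subtract every currently active binary, and recompute the log-likelihood from scratch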
1555
+ nleaves_max = new_state.branches["gb"].shape[-2]
1556
+ for i in range(2):
1557
+ self.mgh.channel1_data[0][self.nwalkers * self.data_length * i: self.nwalkers * self.data_length * (i + 1)] = self.mgh.channel1_base_data[0][:]
1558
+ self.mgh.channel2_data[0][self.nwalkers * self.data_length * i: self.nwalkers * self.data_length * (i + 1)] = self.mgh.channel2_base_data[0][:]
1559
+
1560
+ coords_out_gb = new_state.branches["gb"].coords[0, new_state.branches["gb"].inds[0]]
1561
+ coords_in_in = self.parameter_transforms.both_transforms(coords_out_gb)
1562
+
1563
+ band_inds = np.searchsorted(self.band_edges.get(), coords_in_in[:, 1], side="right") - 1
1564
+ assert np.all(band_inds == new_state.branches["gb"].branch_supplimental.holder["band_inds"][0, new_state.branches["gb"].inds[0]])
1565
+
1566
+ walker_vals = np.tile(np.arange(self.nwalkers), (nleaves_max, 1)).transpose((1, 0))[new_state.branches["gb"].inds[0]]
1567
+
1568
+ data_index_1 = ((band_inds % 2) + 0) * self.nwalkers + walker_vals
1569
+
1570
+ data_index = xp.asarray(data_index_1).astype(xp.int32)
1571
+
1572
+ # templates enter with factor -1 (as -h) so they are subtracted from the data to form residuals
1573
+ factors = -xp.ones_like(data_index, dtype=xp.float64)
1574
+
1575
+ waveform_kwargs_tmp = self.waveform_kwargs.copy()
1576
+
1577
+ N_vals = self.band_N_vals[band_inds]
1578
+ self.gb.generate_global_template(
1579
+ coords_in_in,
1580
+ data_index,
1581
+ self.mgh.data_list,
1582
+ batch_size=1000,
1583
+ data_length=self.data_length,
1584
+ factors=factors,
1585
+ N=N_vals,
1586
+ data_splits=self.mgh.gpu_splits,
1587
+ **waveform_kwargs_tmp,
1588
+ )
1589
+
1590
+ check_ll_new = self.mgh.get_ll(include_psd_info=True)
1591
+ check_ll_diff1 = check_ll_new - check_ll
1592
+ # print(check_ll_diff1)
1593
+
1594
+ # breakpoint()
1595
+ return check_ll_new
1596
+
1597
+ # breakpoint()
1598
+
1599
+ # # print(self.accepted / self.num_proposals)
1600
+
1601
+ # # MULTIPLE TRY after RJ
1602
+
1603
+ # if False: # self.is_rj_prop:
1604
+
1605
+ # inds_added = np.where(new_state.branches["gb"].inds.astype(int) - state.branches["gb"].inds.astype(int) == +1)
1606
+
1607
+ # new_coords = xp.asarray(new_state.branches["gb"].coords[inds_added])
1608
+ # temp_inds_add = xp.repeat(xp.arange(ntemps)[:, None], nwalkers * nleaves_max, axis=-1).reshape(ntemps, nwalkers, nleaves_max)[inds_added]
1609
+ # walker_inds_add = xp.repeat(xp.arange(nwalkers)[:, None], ntemps * nleaves_max, axis=-1).reshape(nwalkers, ntemps, nleaves_max).transpose(1, 0, 2)[inds_added]
1610
+ # leaf_inds_add = xp.repeat(xp.arange(nleaves_max)[:, None], ntemps * nwalkers, axis=-1).reshape(nleaves_max, ntemps, nwalkers).transpose(1, 2, 0)[inds_added]
1611
+ # band_inds_add = xp.searchsorted(self.band_edges, new_coords[:, 1] / 1e3, side="right") - 1
1612
+ # N_vals_add = xp.asarray(new_state.branches["gb"].branch_supplimental.holder["N_vals"][inds_added])
1613
+
1614
+ # #### RIGHT NOW WE ARE EXPERIMENTING WITH NO OR SMALL CHANGE IN FREQUENCY
1615
+ # # because if it was added in RJ, it has already locked onto a frequency to some degree
1616
+
1617
+ # # randomize order
1618
+ # inds_random = xp.random.permutation(xp.arange(len(new_coords)))
1619
+ # new_coords = new_coords[inds_random]
1620
+ # temp_inds_add = temp_inds_add[inds_random]
1621
+ # walker_inds_add = walker_inds_add[inds_random]
1622
+ # leaf_inds_add = leaf_inds_add[inds_random]
1623
+ # band_inds_add = band_inds_add[inds_random]
1624
+ # N_vals_add = N_vals_add[inds_random]
1625
+
1626
+ # special_band_map = leaf_inds_add * int(1e12) + walker_inds_add * int(1e6) + band_inds_add
1627
+ # inds_sorted_special = xp.argsort(special_band_map)
1628
+ # special_band_map = special_band_map[inds_sorted_special]
1629
+ # new_coords = new_coords[inds_sorted_special]
1630
+ # temp_inds_add = temp_inds_add[inds_sorted_special]
1631
+ # walker_inds_add = walker_inds_add[inds_sorted_special]
1632
+ # leaf_inds_add = leaf_inds_add[inds_sorted_special]
1633
+ # band_inds_add = band_inds_add[inds_sorted_special]
1634
+ # N_vals_add = N_vals_add[inds_sorted_special]
1635
+
1636
+ # unique_special_bands, unique_special_bands_index, unique_special_bands_inverse = xp.unique(special_band_map, return_index=True, return_inverse=True)
1637
+ # group_index = xp.arange(len(special_band_map)) - xp.arange(len(special_band_map))[unique_special_bands_index][unique_special_bands_inverse]
1638
+
1639
+ # band_splits = 3
1640
+ # num_try = 1000
1641
+ # for group in range(group_index.max().item()):
1642
+ # for split_i in range(band_splits):
1643
+ # group_split = (group_index == group) & (band_inds_add % band_splits == split_i)
1644
+
1645
+ # new_coords_group_split = new_coords[group_split]
1646
+ # temp_inds_add_group_split = temp_inds_add[group_split]
1647
+ # walker_inds_add_group_split = walker_inds_add[group_split]
1648
+ # leaf_inds_add_group_split = leaf_inds_add[group_split]
1649
+ # band_inds_add_group_split = band_inds_add[group_split]
1650
+ # N_vals_add_group_split = N_vals_add[group_split]
1651
+
1652
+ # coords_remove = xp.repeat(new_coords_group_split, num_try, axis=0)
1653
+ # coords_add = coords_remove.copy()
1654
+
1655
+ # inds_params = xp.array([0, 2, 3, 4, 5, 6, 7])
1656
+ # coords_add[:, inds_params] = self.gpu_priors["gb"].rvs(size=coords_remove.shape[0])[:, inds_params]
1657
+
1658
+ # log_proposal_pdf = self.gpu_priors["gb"].logpdf(coords_add)
1659
+ # # remove logpdf from fchange for now
1660
+ # log_proposal_pdf -= self.gpu_priors["gb"].priors_in[1].logpdf(coords_add[:, 1])
1661
+
1662
+ # priors_remove = self.gpu_priors["gb"].logpdf(coords_remove)
1663
+ # priors_add = self.gpu_priors["gb"].logpdf(coords_add)
1664
+
1665
+ # if xp.any(coords_add[:, 1] != coords_remove[:, 1]):
1666
+ # raise NotImplementedError("Assumes frequencies are the same.")
1667
+ # independent = False
1668
+ # else:
1669
+ # independent = True
1670
+
1671
+ # coords_remove_in = self.parameter_transforms.both_transforms(coords_remove, xp=xp)
1672
+ # coords_add_in = self.parameter_transforms.both_transforms(coords_add, xp=xp)
1673
+
1674
+ # waveform_kwargs_tmp = self.waveform_kwargs.copy()
1675
+ # waveform_kwargs_tmp.pop("N")
1676
+
1677
+ # data_index_in = xp.repeat(temp_inds_add_group_split * nwalkers + walker_inds_add_group_split, num_try).astype(xp.int32)
1678
+ # noise_index_in = data_index_in.copy()
1679
+ # N_vals_add_group_split_in = self.xp.repeat(N_vals_add_group_split, num_try)
1680
+
1681
+ # ll_diff = self.xp.asarray(self.gb.swap_likelihood_difference(coords_remove_in, coords_add_in, self.mgh.data_list, self.mgh.psd_list, data_index=data_index_in, noise_index=noise_index_in, N=N_vals_add_group_split_in, data_length=self.data_length, data_splits=self.mgh.gpu_splits, **waveform_kwargs_tmp))
1682
+
1683
+ # ll_diff[self.xp.isnan(ll_diff)] = -1e300
1684
+
1685
+ # band_inv_temps = self.xp.repeat(band_temps[(band_inds_add_group_split, temp_inds_add_group_split)], num_try)
1686
+
1687
+ # logP = band_inv_temps * ll_diff + priors_add
1688
+
1689
+ # from eryn.moves.multipletry import logsumexp, get_mt_computations
1690
+
1691
+ # band_inv_temps = band_inv_temps.reshape(-1, num_try)
1692
+ # ll_diff = ll_diff.reshape(-1, num_try)
1693
+ # logP = logP.reshape(-1, num_try)
1694
+ # priors_add = priors_add.reshape(-1, num_try)
1695
+ # log_proposal_pdf = log_proposal_pdf.reshape(-1, num_try)
1696
+ # coords_add = coords_add.reshape(-1, num_try, ndim)
1697
+
1698
+ # log_importance_weights, log_sum_weights, inds_group_split = get_mt_computations(logP, log_proposal_pdf, symmetric=False, xp=self.xp)
1699
+
1700
+ # inds_tuple = (self.xp.arange(len(inds_group_split)), inds_group_split)
1701
+
1702
+ # ll_diff_out = ll_diff[inds_tuple]
1703
+ # logP_out = logP[inds_tuple]
1704
+ # priors_add_out = priors_add[inds_tuple]
1705
+ # coords_add_out = coords_add[inds_tuple]
1706
+ # log_proposal_pdf_out = log_proposal_pdf[inds_tuple]
1707
+
1708
+ # if not independent:
1709
+ # raise NotImplementedError
1710
+ # else:
1711
+ # aux_coords_add = coords_add.copy()
1712
+ # aux_ll_diff = ll_diff.copy()
1713
+ # aux_priors_add = priors_add.copy()
1714
+ # aux_log_proposal_pdf = log_proposal_pdf.copy()
1715
+
1716
+ # aux_coords_add[:, 0] = new_coords_group_split
1717
+ # aux_ll_diff[:, 0] = 0.0 # diff is zero because the points are already in the data
1718
+ # aux_priors_add[:, 0] = priors_remove[::num_try]
1719
+
1720
+ # initial_log_proposal_pdf = self.gpu_priors["gb"].logpdf(new_coords_group_split)
1721
+ # # remove logpdf from fchange for now
1722
+ # initial_log_proposal_pdf -= self.gpu_priors["gb"].priors_in[1].logpdf(new_coords_group_split[:, 1])
1723
+
1724
+ # aux_log_proposal_pdf[:, 0] = initial_log_proposal_pdf
1725
+
1726
+ # aux_logP = band_inv_temps * aux_ll_diff + aux_priors_add
1727
+
1728
+ # aux_log_importance_weights, aux_log_sum_weights, _ = get_mt_computations(aux_logP, aux_log_proposal_pdf, symmetric=False, xp=self.xp)
1729
+
1730
+ # aux_logP_out = aux_logP[:, 0]
1731
+ # aux_log_proposal_pdf_out = aux_log_proposal_pdf[:, 0]
1732
+
1733
+ # factors = ((aux_logP_out - aux_log_sum_weights) - aux_log_proposal_pdf_out + aux_log_proposal_pdf_out) - ((logP_out - log_sum_weights) - log_proposal_pdf_out + log_proposal_pdf_out)
1734
+
1735
+ # lnpdiff = factors + logP_out - aux_logP_out
1736
+
1737
+ # keep = lnpdiff > self.xp.asarray(self.xp.log(self.xp.random.rand(*logP_out.shape)))
1738
+
1739
+ # coords_remove_keep = new_coords_group_split[keep]
1740
+ # coords_add_keep = coords_add_out[keep]
1741
+ # temp_inds_add_keep = temp_inds_add_group_split[keep]
1742
+ # walker_inds_add_keep = walker_inds_add_group_split[keep]
1743
+ # leaf_inds_add_keep = leaf_inds_add_group_split[keep]
1744
+ # band_inds_add_keep = band_inds_add_group_split[keep]
1745
+ # N_vals_add_keep = N_vals_add_group_split[keep]
1746
+ # ll_diff_keep = ll_diff_out[keep]
1747
+ # priors_add_keep = priors_add_out[keep]
1748
+ # priors_remove_keep = priors_remove[::num_try][keep]
1749
+
1750
+ # # adjust everything
1751
+ # ll_band_diff = xp.zeros((ntemps, nwalkers, len(self.band_edges) - 1))
1752
+ # ll_band_diff[temp_inds_add_keep, walker_inds_add_keep, band_inds_add_keep] = ll_diff_keep
1753
+ # lp_band_diff = xp.zeros((ntemps, nwalkers, len(self.band_edges) - 1))
1754
+ # lp_band_diff[temp_inds_add_keep, walker_inds_add_keep, band_inds_add_keep] = priors_add_keep - priors_remove_keep
1755
+
1756
+ # new_state.branches["gb"].coords[temp_inds_add_keep.get(), walker_inds_add_keep.get(), band_inds_add_keep.get()] = coords_add_keep.get()
1757
+ # new_state.log_like += ll_band_diff.sum(axis=-1).get()
1758
+ # new_state.log_prior += ll_band_diff.sum(axis=-1).get()
1759
+
1760
+ # waveform_kwargs_tmp = self.waveform_kwargs.copy()
1761
+ # waveform_kwargs_tmp.pop("N")
1762
+ # coords_remove_keep_in = self.parameter_transforms.both_transforms(coords_remove_keep, xp=self.xp)
1763
+ # coords_add_keep_in = self.parameter_transforms.both_transforms(coords_add_keep, xp=self.xp)
1764
+
1765
+ # coords_in = xp.concatenate([coords_remove_keep_in, coords_add_keep_in], axis=0)
1766
+ # factors = xp.concatenate([+xp.ones(coords_remove_keep_in.shape[0]), -xp.ones(coords_remove_keep_in.shape[0])])
1767
+ # data_index_tmp = (temp_inds_add_keep * nwalkers + walker_inds_add_keep).astype(xp.int32)
1768
+ # data_index_in = xp.concatenate([data_index_tmp, data_index_tmp], dtype=xp.int32)
1769
+ # N_vals_in = xp.concatenate([N_vals_add_keep, N_vals_add_keep])
1770
+ # self.gb.generate_global_template(
1771
+ # coords_in,
1772
+ # data_index_in,
1773
+ # self.mgh.data_list,
1774
+ # N=N_vals_in,
1775
+ # data_length=self.data_length,
1776
+ # data_splits=self.mgh.gpu_splits,
1777
+ # factors=factors,
1778
+ # **waveform_kwargs_tmp
1779
+ # )
1780
+ # self.xp.cuda.runtime.deviceSynchronize()
1781
+
1782
+ # ll_after = (
1783
+ # self.mgh.get_ll(include_psd_info=True)
1784
+ # .flatten()[new_state.supplimental[:]["overall_inds"]]
1785
+ # .reshape(ntemps, nwalkers)
1786
+ # )
1787
+ # # print(np.abs(new_state.log_like - ll_after).max())
1788
+ # store_max_diff = np.abs(new_state.log_like - ll_after).max()
1789
+ # breakpoint()
1790
+ # self.mempool.free_all_blocks()
1791
+
1792
+ # self.mgh.restore_base_injections()
1793
+
1794
+ # for name in new_state.branches.keys():
1795
+ # if name not in ["gb", "gb"]:
1796
+ # continue
1797
+ # new_state_branch = new_state.branches[name]
1798
+ # coords_here = new_state_branch.coords[new_state_branch.inds]
1799
+ # ntemps, nwalkers, nleaves_max_here, ndim = new_state_branch.shape
1800
+ # try:
1801
+ # group_index = self.xp.asarray(
1802
+ # self.mgh.get_mapped_indices(
1803
+ # np.repeat(
1804
+ # np.arange(ntemps * nwalkers).reshape(
1805
+ # ntemps, nwalkers, 1
1806
+ # ),
1807
+ # nleaves_max,
1808
+ # axis=-1,
1809
+ # )[new_state_branch.inds]
1810
+ # ).astype(self.xp.int32)
1811
+ # )
1812
+ # except IndexError:
1813
+ # breakpoint()
1814
+ # coords_here_in = self.parameter_transforms.both_transforms(
1815
+ # coords_here, xp=np
1816
+ # )
1817
+
1818
+ # waveform_kwargs_fill = self.waveform_kwargs.copy()
1819
+ # waveform_kwargs_fill["start_freq_ind"] = self.start_freq_ind
1820
+
1821
+ # if "N" in waveform_kwargs_fill:
1822
+ # waveform_kwargs_fill.pop("N")
1823
+
1824
+ # self.mgh.multiply_data(-1.0)
1825
+ # self.gb.generate_global_template(
1826
+ # coords_here_in,
1827
+ # group_index,
1828
+ # self.mgh.data_list,
1829
+ # data_length=self.data_length,
1830
+ # data_splits=self.mgh.gpu_splits,
1831
+ # **waveform_kwargs_fill
1832
+ # )
1833
+ # self.xp.cuda.runtime.deviceSynchronize()
1834
+ # self.mgh.multiply_data(-1.0)
1835
+
1836
+