lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lisaanalysistools might be problematic; see the package's registry page for more details.

Files changed (37):
  1. lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
  2. lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
  3. lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
  4. lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
  6. lisatools/__init__.py +0 -0
  7. lisatools/_version.py +4 -0
  8. lisatools/analysiscontainer.py +438 -0
  9. lisatools/cutils/detector.cpython-312-darwin.so +0 -0
  10. lisatools/datacontainer.py +292 -0
  11. lisatools/detector.py +410 -0
  12. lisatools/diagnostic.py +976 -0
  13. lisatools/glitch.py +193 -0
  14. lisatools/sampling/__init__.py +0 -0
  15. lisatools/sampling/likelihood.py +882 -0
  16. lisatools/sampling/moves/__init__.py +0 -0
  17. lisatools/sampling/moves/gbgroupstretch.py +53 -0
  18. lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
  19. lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
  20. lisatools/sampling/moves/gbspecialstretch.py +1836 -0
  21. lisatools/sampling/moves/mbhspecialmove.py +286 -0
  22. lisatools/sampling/moves/placeholder.py +16 -0
  23. lisatools/sampling/moves/skymodehop.py +110 -0
  24. lisatools/sampling/moves/specialforegroundmove.py +564 -0
  25. lisatools/sampling/prior.py +508 -0
  26. lisatools/sampling/stopping.py +320 -0
  27. lisatools/sampling/utility.py +324 -0
  28. lisatools/sensitivity.py +888 -0
  29. lisatools/sources/__init__.py +0 -0
  30. lisatools/sources/emri/__init__.py +1 -0
  31. lisatools/sources/emri/tdiwaveform.py +72 -0
  32. lisatools/stochastic.py +291 -0
  33. lisatools/utils/__init__.py +0 -0
  34. lisatools/utils/constants.py +40 -0
  35. lisatools/utils/multigpudataholder.py +730 -0
  36. lisatools/utils/pointeradjust.py +106 -0
  37. lisatools/utils/utility.py +240 -0
@@ -0,0 +1,1287 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from copy import deepcopy
4
+ from inspect import Attribute
5
+ from multiprocessing.sharedctypes import Value
6
+ import numpy as np
7
+ import warnings
8
+ import time
9
+ from scipy.special import logsumexp
10
+
11
+ try:
12
+ import cupy as xp
13
+
14
+ gpu_available = True
15
+ except ModuleNotFoundError:
16
+ import numpy as xp
17
+
18
+ gpu_available = False
19
+
20
+ from eryn.state import State, BranchSupplimental
21
+ from eryn.moves import ReversibleJumpMove, MultipleTryMove
22
+ from eryn.prior import ProbDistContainer
23
+ from eryn.utils.utility import groups_from_inds
24
+
25
+ from gbgpu.utils.utility import get_N, get_fdot
26
+
27
+ from lisatools.sampling.moves.gbspecialstretch import GBSpecialStretchMove
28
+
29
+ __all__ = ["GBMutlipleTryRJ"]
30
+
31
+
32
def shuffle_along_axis(a, axis):
    """Independently permute the entries of ``a`` along ``axis``.

    Uniform random keys are drawn from the global ``np.random`` state. Their
    argsort along ``axis`` is used as the permutation, so every 1-D slice
    along ``axis`` is shuffled independently of the others.

    Args:
        a (np.ndarray): Array to shuffle. It is not modified.
        axis (int): Axis along which to shuffle.

    Returns:
        np.ndarray: A shuffled copy of ``a`` with the same shape.
    """
    random_keys = np.random.rand(*a.shape)
    permutation = np.argsort(random_keys, axis=axis)
    return np.take_along_axis(a, permutation, axis=axis)
35
+
36
+
37
def searchsorted2d_vec(a, b, xp=None, **kwargs):
    """Row-wise ``searchsorted``: locate each row of ``b`` in the same row of ``a``.

    Each row of ``a`` must already be sorted ascending. Both arrays are
    shifted by a per-row offset so that the rows occupy disjoint, increasing
    value ranges. A single 1-D ``searchsorted`` on the flattened arrays then
    yields per-row answers once the flat row offset ``n * row`` is removed.

    Args:
        a (xp.ndarray): 2-D array of shape ``(m, n)``; each row sorted ascending.
        b (xp.ndarray): 2-D array with ``m`` rows of values to insert.
        xp (module, optional): Array module (``numpy`` or ``cupy``).
            Defaults to ``numpy``.
        **kwargs: Forwarded to ``xp.searchsorted`` (e.g. ``side="right"``).

    Returns:
        xp.ndarray: Integer array of shape ``(m, b.shape[1])`` holding insertion
        indices into the corresponding row of ``a`` (values in ``[0, n]``).
    """
    if xp is None:
        xp = np
    m, n = a.shape
    # The per-row offset must exceed the span of ALL values (a and b together).
    # Using only the separate spans of a and b lets rows overlap when the two
    # arrays cover different ranges, which silently produces wrong indices.
    lo = xp.minimum(a.min(), b.min())
    hi = xp.maximum(a.max(), b.max())
    max_num = hi - lo + 1
    rows = xp.arange(m)[:, None]
    r = max_num * rows
    p = xp.searchsorted((a + r).ravel(), (b + r).ravel(), **kwargs).reshape(m, -1)
    return p - n * rows
45
+
46
+
47
+ class GBMutlipleTryRJ(MultipleTryMove, ReversibleJumpMove, GBSpecialStretchMove):
48
+ """Generate Revesible-Jump proposals for GBs with multiple try
49
+
50
+ Will use gpu if template generator uses GPU.
51
+
52
+ Args:
53
+ priors (object): :class:`PriorContainer` object that has ``logpdf``
54
+ and ``rvs`` methods.
55
+
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ gb_args,
61
+ gb_kwargs,
62
+ m_chirp_lims,
63
+ *args,
64
+ start_ind_limit=10,
65
+ num_try=1,
66
+ point_generator_func=None,
67
+ fix_change=None,
68
+ **kwargs
69
+ ):
70
+ self.point_generator_func = point_generator_func
71
+ self.fixed_like_diff = 0
72
+ self.time = 0
73
+ self.name = "gbgroupstretch"
74
+ self.start_ind_limit = start_ind_limit
75
+ self.num_try = num_try
76
+
77
+ GBSpecialStretchMove.__init__(self, *gb_args, **gb_kwargs)
78
+ ReversibleJump.__init__(self, *args, **kwargs)
79
+ MultipleTryMove.__init__(self, self.num_try, take_max_ll=False, xp=self.xp)
80
+
81
+ # setup band edges for priors
82
+ self.band_edges_fdot = self.xp.zeros_like(self.band_edges)
83
+ # lower limit
84
+
85
+ self.band_edges_fdot[:-1] = self.xp.asarray(
86
+ get_fdot(self.band_edges[:-1], Mc=m_chirp_lims[0])
87
+ )
88
+ self.band_edges_fdot[1:] = self.xp.asarray(
89
+ get_fdot(self.band_edges[1:], Mc=m_chirp_lims[1])
90
+ )
91
+
92
+ self.band_edges_gpu = self.xp.asarray(self.band_edges)
93
+ self.fix_change = fix_change
94
+ if self.fix_change not in [None, +1, -1]:
95
+ raise ValueError("fix_change must be None, +1, or -1.")
96
+
97
+ def special_generate_func(
98
+ self,
99
+ coords,
100
+ nwalkers,
101
+ current_priors=None,
102
+ random=None,
103
+ size: int = 1,
104
+ fill=None,
105
+ fill_inds=None,
106
+ band_inds=None,
107
+ ):
108
+ """if self.search_samples is not None:
109
+ # TODO: make replace=True ? in PE
110
+ inds_drawn = self.xp.array([random.choice(
111
+ self.search_inds, size=size, replace=False,
112
+ ) for w in range(nwalkers)])
113
+ generated_points = self.search_samples[inds_drawn].copy() # .reshape(nwalkers, size, -1)
114
+ # since this is in search only, we will pretend these are coming from the prior
115
+ # so that they are accepted based on the Likelihood (and not the posterior)
116
+ generate_factors = current_priors.logpdf(generated_points.reshape(nwalkers * size, -1)).reshape(nwalkers, size)
117
+ """
118
+ # st = time.perf_counter()
119
+ if band_inds is None:
120
+ raise ValueError("band_inds needs to be set")
121
+
122
+ # elif
123
+ if self.point_generator_func is not None:
124
+ # st1 = time.perf_counter()
125
+ generated_points = self.point_generator_func.rvs(size=size * nwalkers)
126
+ """et1 = time.perf_counter()
127
+ print("generate rvs:", et1 - st1)"""
128
+
129
+ generated_points[self.xp.isnan(generated_points[:, 0]), 0] = 0.01
130
+
131
+ # st2 = time.perf_counter()
132
+ generate_factors = self.point_generator_func.logpdf(generated_points)
133
+ """et2 = time.perf_counter()
134
+ print("generate logpdf:", et2 - st2)"""
135
+
136
+ starts = self.band_edges_gpu[band_inds]
137
+ ends = self.band_edges_gpu[band_inds + 1]
138
+
139
+ starts_fdot = self.band_edges_fdot[band_inds]
140
+ ends_fdot = self.band_edges_fdot[band_inds + 1]
141
+
142
+ # fill before getting logpdf
143
+ generated_points = generated_points.reshape(nwalkers, size, -1)
144
+ generate_factors = generate_factors.reshape(nwalkers, size)
145
+
146
+ # map back per band
147
+ generated_points[:, :, 1] = (
148
+ generated_points[:, :, 1] * (ends[:, None] - starts[:, None])
149
+ + starts[:, None]
150
+ ) * 1e3
151
+ generated_points[:, :, 2] = (
152
+ generated_points[:, :, 2] * (ends_fdot[:, None] - starts_fdot[:, None])
153
+ + starts_fdot[:, None]
154
+ )
155
+
156
+ if fill is not None or fill_inds is not None:
157
+ if fill is None or fill_inds is None:
158
+ raise ValueError(
159
+ "If providing fill_inds or fill, must provide both."
160
+ )
161
+ generated_points[fill_inds] = fill.copy()
162
+
163
+ generated_points = generated_points.reshape(nwalkers, size, -1)
164
+
165
+ # logpdf contribution from original distribution is zero = log(1/1)
166
+ # THIS HAS BEEN REMOVED TO SIMULATE A PRIOR HERE THAT IS EQUIVALENT TO THE GLOBAL PRIOR VALUE
167
+ # THE FACT IS THAT THE EFFECTIVE PRIOR HERE WILL BE THE SAME AS THE GENERATING FUNCTION (UP TO THE SNR PRIOR IF THAT IS CHANGED IN THE GENERATING FUNCTION)
168
+ # generate_factors[:] += (self.xp.log(1 / (ends - starts)))[:, None]
169
+ # generate_factors[:] += (self.xp.log(1 / (ends_fdot - starts_fdot)))[:, None]
170
+
171
+ else:
172
+ if current_priors is None:
173
+ raise ValueError(
174
+ "If generating from the prior, must provide current_priors kwargs."
175
+ )
176
+
177
+ generated_points = current_priors.rvs(size=nwalkers * size).reshape(
178
+ nwalkers, size, -1
179
+ )
180
+ if fill is not None or fill_inds is not None:
181
+ if fill is None or fill_inds is None:
182
+ raise ValueError(
183
+ "If providing fill_inds or fill, must provide both."
184
+ )
185
+ generated_points[fill_inds[0]] = fill.copy()
186
+
187
+ generate_factors = current_priors.logpdf(
188
+ generated_points.reshape(nwalkers * size, -1)
189
+ ).reshape(nwalkers, size)
190
+
191
+ """et = time.perf_counter()
192
+ print("GENEARTE:", et - st)"""
193
+ return generated_points, generate_factors
194
+
195
+ def special_like_func(
196
+ self,
197
+ generated_points,
198
+ base_shape,
199
+ inds_reverse=None,
200
+ old_d_h_d_h=None,
201
+ overall_inds=None,
202
+ ):
203
+ # st = time.perf_counter()
204
+ self.xp.cuda.runtime.setDevice(self.xp.cuda.runtime.getDevice())
205
+
206
+ if overall_inds is None:
207
+ raise ValueError("overall_inds is None.")
208
+
209
+ # group everything
210
+
211
+ # GENERATED POINTS MUST BE PASSED IN by reference not copied
212
+ num_inds_change, nleaves_max, ndim = base_shape
213
+ num_inds_change_gen, num_try, ndim_gen = generated_points.shape
214
+ assert num_inds_change_gen == num_inds_change and ndim == ndim_gen
215
+
216
+ if old_d_h_d_h is None:
217
+ raise NotImplementedError
218
+ self.d_h_d_h = d_h_d_h = (
219
+ 4 * self.df * self.xp.sum((in_vals.conj() * in_vals) / psd, axis=(1, 2))
220
+ )
221
+ else:
222
+ self.d_h_d_h = d_h_d_h = self.xp.asarray(old_d_h_d_h)
223
+
224
+ ll_out = self.xp.zeros((num_inds_change, num_try)).flatten()
225
+ # TODO: take out of loop later?
226
+
227
+ phase_marginalize = self.search
228
+ generated_points_here = generated_points.reshape(-1, ndim)
229
+
230
+ back_d_d = self.gb.d_d.copy()
231
+ self.gb.d_d = self.xp.repeat(d_h_d_h, self.num_try)
232
+
233
+ # do not need mapping because it comes in as overall inds already mapped
234
+ data_index = self.xp.asarray(
235
+ self.xp.repeat(overall_inds, self.num_try).astype(self.xp.int32)
236
+ )
237
+
238
+ noise_index = data_index.copy()
239
+
240
+ self.data_index_check = data_index.reshape(-1, self.num_try)[:, 0]
241
+
242
+ # TODO: Remove batch_size if GPU only ?
243
+ prior_generated_points = generated_points_here
244
+
245
+ if self.parameter_transforms is not None:
246
+ prior_generated_points_in = self.parameter_transforms.both_transforms(
247
+ prior_generated_points.copy(), xp=self.xp
248
+ )
249
+
250
+ N_temp = self.xp.asarray(
251
+ get_N(
252
+ prior_generated_points_in[:, 0],
253
+ prior_generated_points_in[:, 1],
254
+ self.waveform_kwargs["T"],
255
+ self.waveform_kwargs["oversample"],
256
+ )
257
+ )
258
+
259
+ waveform_kwargs_in = self.waveform_kwargs.copy()
260
+ waveform_kwargs_in.pop("N")
261
+ main_gpu = self.xp.cuda.runtime.getDevice()
262
+ # TODO: do search sorted and apply that to nearest found for new points found with group
263
+
264
+ # st3 = time.perf_counter()
265
+ ll = self.gb.get_ll(
266
+ prior_generated_points_in,
267
+ self.mgh.data_list,
268
+ self.mgh.psd_list,
269
+ data_index=data_index,
270
+ noise_index=noise_index,
271
+ phase_marginalize=phase_marginalize,
272
+ data_length=self.data_length,
273
+ data_splits=self.mgh.gpu_splits,
274
+ N=N_temp,
275
+ return_cupy=True,
276
+ **waveform_kwargs_in
277
+ )
278
+ self.xp.cuda.runtime.setDevice(main_gpu)
279
+ self.xp.cuda.runtime.deviceSynchronize()
280
+
281
+ """et3 = time.perf_counter()
282
+ print("actual like:", et3 - st3)"""
283
+ if self.xp.any(self.xp.isnan(ll)):
284
+ assert self.xp.isnan(ll).sum() < 10
285
+ ll[self.xp.isnan(ll)] = -1e300
286
+
287
+ opt_snr = self.xp.sqrt(self.gb.h_h)
288
+
289
+ if self.search:
290
+ phase_maximized_snr = (
291
+ self.xp.abs(self.gb.d_h) / self.xp.sqrt(self.gb.h_h)
292
+ ).real.copy()
293
+
294
+ phase_change = self.xp.angle(
295
+ self.xp.asarray(self.gb.non_marg_d_h) / self.xp.sqrt(self.gb.h_h.real)
296
+ )
297
+
298
+ try:
299
+ phase_maximized_snr = phase_maximized_snr.get()
300
+ phase_change = phase_change.get()
301
+ opt_snr = opt_snr.get()
302
+
303
+ except AttributeError:
304
+ pass
305
+
306
+ if self.xp.any(self.xp.isnan(prior_generated_points)) or self.xp.any(
307
+ self.xp.isnan(phase_change)
308
+ ):
309
+ breakpoint()
310
+
311
+ # adjust for phase change from maximization
312
+ generated_points_here[:, 3] = (
313
+ generated_points_here[:, 3] - phase_change
314
+ ) % (2 * np.pi)
315
+
316
+ snr_comp = phase_maximized_snr
317
+
318
+ else:
319
+ snr_comp = (self.gb.d_h.real / self.xp.sqrt(self.gb.h_h)).real.copy()
320
+ try:
321
+ snr_comp = snr_comp.get()
322
+ opt_snr = opt_snr.get()
323
+
324
+ except AttributeError:
325
+ pass
326
+
327
+ snr_comp2 = snr_comp.reshape(-1, self.num_try)
328
+ okay = snr_comp2 >= 1.0
329
+ okay[inds_reverse.get()] = True
330
+ ll[~okay.flatten()] = -1e300
331
+
332
+ ##print(opt_snr[snr_comp.argmax()].real, snr_comp.max(), ll[snr_comp.argmax()].real - -1/2 * self.gb.d_d[snr_comp.argmax()].real)
333
+ if self.search and self.search_snr_lim is not None:
334
+ ll[
335
+ (snr_comp < self.search_snr_lim * 0.8)
336
+ | (opt_snr < self.search_snr_lim * self.search_snr_accept_factor)
337
+ ] = -1e300
338
+ """if self.xp.any(~((snr_comp
339
+ < self.search_snr_lim * 0.95)
340
+ | (opt_snr
341
+ < self.search_snr_lim * self.search_snr_accept_factor))):
342
+ breakpoint()"""
343
+
344
+ generated_points[:] = generated_points_here.reshape(generated_points.shape)
345
+ ll_out = ll.copy()
346
+ # if inds_reverse is not None and len(inds_reverse) != 0 and split == num_splits - 1:
347
+ # breakpoint()
348
+
349
+ ll_out = ll_out.reshape(num_inds_change, num_try)
350
+ self.old_ll_out_check = ll_out.copy()
351
+ if inds_reverse is not None:
352
+ # try:
353
+ # tmp_d_h_d_h = d_h_d_h.get()
354
+ # except AttributeError:
355
+ # tmp_d_h_d_h = d_h_d_h
356
+
357
+ # this is special to GBs
358
+ self.special_aux_ll = -ll_out[inds_reverse, 0]
359
+
360
+ self.check_h_h = self.gb.h_h.reshape(-1, num_try)[inds_reverse]
361
+ self.check_d_h = self.gb.d_h.reshape(-1, num_try)[inds_reverse]
362
+ ll_out[inds_reverse, :] += self.special_aux_ll[:, None]
363
+
364
+ # return gb.d_d
365
+ self.gb.d_d = back_d_d.copy()
366
+ # add noise term
367
+ self.xp.cuda.runtime.deviceSynchronize()
368
+ """et = time.perf_counter()
369
+ print("LIKE:", et - st)"""
370
+ return ll_out # + self.noise_ll[:, None]
371
+
372
+ def special_prior_func(
373
+ self, generated_points, base_shape, inds_reverse=None, **kwargs
374
+ ):
375
+ # st = time.perf_counter()
376
+ nwalkers, nleaves_max, ndim = base_shape
377
+ # st2 = time.perf_counter()
378
+ lp_new = (
379
+ self.gpu_priors["gb"]
380
+ .logpdf(generated_points.reshape(-1, 8))
381
+ .reshape(nwalkers, self.num_try)
382
+ )
383
+ """et2 = time.perf_counter()
384
+ print("prior logpdf:", et2 - st2)"""
385
+ lp_total = lp_new # _old[:, None] + lp_new
386
+
387
+ self.old_lp_total_check = lp_total.copy()
388
+ if inds_reverse is not None:
389
+ # this is special to GBs
390
+ self.special_aux_lp = -lp_total[inds_reverse, 0]
391
+
392
+ lp_total[inds_reverse, :] += self.special_aux_lp[:, None]
393
+
394
+ # add noise lp
395
+ """et = time.perf_counter()
396
+ print("PRIOR:", et - st)"""
397
+ return lp_total
398
+
399
+ def readout_adjustment(self, out_vals, all_vals_prop, aux_all_vals, inds_reverse):
400
+ self.out_vals, self.all_vals_prop, self.aux_all_vals, self.inds_reverse = (
401
+ out_vals,
402
+ all_vals_prop,
403
+ aux_all_vals,
404
+ inds_reverse,
405
+ )
406
+
407
+ (
408
+ self.logP_out,
409
+ self.ll_out,
410
+ self.lp_out,
411
+ self.log_proposal_pdf_out,
412
+ self.log_sum_weights,
413
+ ) = out_vals
414
+
415
+ self.ll_out[inds_reverse] = self.special_aux_ll
416
+ self.lp_out[inds_reverse] = self.special_aux_lp
417
+
418
+ def get_proposal(
419
+ self,
420
+ gb_coords,
421
+ inds,
422
+ changes,
423
+ leaf_inds_for_changes,
424
+ band_inds,
425
+ random,
426
+ supps=None,
427
+ branch_supps=None,
428
+ ):
429
+ """Make a proposal
430
+
431
+ Args:
432
+ all_coords (dict): Keys are ``branch_names``. Values are
433
+ np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the curent
434
+ coordinates for all the walkers.
435
+ all_inds (dict): Keys are ``branch_names``. Values are
436
+ np.ndarray[ntemps, nwalkers, nleaves_max]. These are the boolean
437
+ arrays marking which leaves are currently used within each walker.
438
+ all_inds_for_change (dict): Keys are ``branch_names``. Values are
439
+ dictionaries. These dictionaries have keys ``"+1"`` and ``"-1"``,
440
+ indicating waklkers that are adding or removing a leafm respectively.
441
+ The values for these dicts are ``int`` np.ndarray[..., 3]. The "..." indicates
442
+ the number of walkers in all temperatures that fall under either adding
443
+ or removing a leaf. The second dimension, 3, is the indexes into
444
+ the three-dimensional arrays within ``all_inds`` of the specific leaf
445
+ that is being added or removed from those leaves currently considered.
446
+ random (object): Current random state of the sampler.
447
+
448
+ Returns:
449
+ tuple: Tuple containing proposal information.
450
+ First entry is the new coordinates as a dictionary with keys
451
+ as ``branch_names`` and values as
452
+ ``double `` np.ndarray[ntemps, nwalkers, nleaves_max, ndim] containing
453
+ proposed coordinates. Second entry is the new ``inds`` array with
454
+ boolean values flipped for added or removed sources. Third entry
455
+ is the factors associated with the
456
+ proposal necessary for detailed balance. This is effectively
457
+ any term in the detailed balance fraction. +log of factors if
458
+ in the numerator. -log of factors if in the denominator.
459
+
460
+ """
461
+
462
+ name = "gb"
463
+
464
+ ntemps, nwalkers, nleaves_max, ndim = gb_coords.shape
465
+
466
+ # adjust inds
467
+ changes_gpu = self.xp.asarray(changes)
468
+ # TODO: remove xp.asarrays here if they come in as GPU arrays
469
+ inds_reverse_band = self.xp.where(changes_gpu == -1)
470
+
471
+ leaf_inds_for_changes_gpu = self.xp.asarray(leaf_inds_for_changes)
472
+
473
+ inds_reverse = inds_reverse_band[:2] + (
474
+ leaf_inds_for_changes_gpu[inds_reverse_band],
475
+ )
476
+ inds_reverse_cpu = tuple([ir.get() for ir in list(inds_reverse)])
477
+ inds_reverse_in = self.xp.where(changes_gpu.flatten() == -1)[0]
478
+ inds_reverse_individual = leaf_inds_for_changes_gpu[changes_gpu < 0]
479
+
480
+ inds_forward_band = self.xp.where(changes_gpu == +1)
481
+
482
+ inds_forward = inds_forward_band[:2] + (
483
+ leaf_inds_for_changes_gpu[inds_forward_band],
484
+ )
485
+ # inds_forward_in = self.xp.where(changes.flatten() == +1)[0]
486
+ # inds_forward_individual = leaf_inds_for_changes_gpu[changes < 0]
487
+
488
+ # add coordinates for new leaves
489
+ current_priors = self.priors[name]
490
+
491
+ inds_here_band = self.xp.where((changes_gpu == -1) | (changes_gpu == +1))
492
+
493
+ inds_here = inds_here_band[:2] + (leaf_inds_for_changes_gpu[inds_here_band],)
494
+ inds_here_cpu = tuple([ir.get() for ir in list(inds_here)])
495
+
496
+ # allows for adjustment during search
497
+ reset_num_try = False
498
+ if len(inds_here[0]) != 0:
499
+ if self.search and self.search_samples is not None:
500
+ raise not NotImplementedError
501
+ if len(self.search_inds) >= self.num_try:
502
+ reset_num_try = True
503
+ old_num_try = self.num_try
504
+ self.num_try = len(self.search_inds)
505
+
506
+ num_inds_change = len(inds_here[0])
507
+
508
+ if self.provide_betas: # and not self.search:
509
+ betas = self.xp.asarray(self.temperature_control.betas)[inds_here[0]]
510
+ else:
511
+ betas = None
512
+
513
+ ll_here = self.xp.asarray(self.current_state.log_like.copy())[inds_here[:2]]
514
+ lp_here = self.xp.asarray(self.current_state.log_prior)[inds_here[:2]]
515
+
516
+ self.lp_old = lp_here
517
+
518
+ rj_info = dict(
519
+ ll=self.xp.zeros_like(ll_here), lp=self.xp.zeros_like(lp_here)
520
+ )
521
+
522
+ N_vals = branch_supps.holder["N_vals"]
523
+ N_vals = self.xp.asarray(N_vals)
524
+
525
+ self.inds_reverse = inds_reverse
526
+ self.inds_forward = inds_forward
527
+
528
+ if len(inds_reverse[0]) > 0:
529
+ parameters_remove = self.parameter_transforms.both_transforms(
530
+ gb_coords[inds_reverse_cpu], xp=self.xp
531
+ )
532
+
533
+ group_index_tmp = inds_reverse[0] * nwalkers + inds_reverse[1]
534
+ N_vals_in = self.xp.asarray(N_vals[inds_reverse])
535
+ group_index = self.xp.asarray(
536
+ self.mgh.get_mapped_indices(group_index_tmp).astype(self.xp.int32)
537
+ )
538
+
539
+ waveform_kwargs_add = self.waveform_kwargs.copy()
540
+ waveform_kwargs_add.pop("N")
541
+ # removing these so d - (h - r) = d - h + r
542
+
543
+ try:
544
+ self.xp.cuda.runtime.deviceSynchronize()
545
+ # parameters_remove[:, 0] *= 1e-30
546
+ self.gb.generate_global_template(
547
+ parameters_remove,
548
+ group_index,
549
+ self.mgh.data_list,
550
+ N=N_vals_in,
551
+ data_length=self.data_length,
552
+ data_splits=self.mgh.gpu_splits,
553
+ **waveform_kwargs_add
554
+ )
555
+ except ValueError as e:
556
+ print(e)
557
+ breakpoint()
558
+
559
+ # self.checkit4 = (-1/2 * self.df * 4 * self.xp.sum(data_minus_template.conj() * data_minus_template / self.psd, axis=(2,3))) + self.xp.asarray(self.noise_ll).reshape(ntemps, nwalkers)
560
+
561
+ overall_inds = supps.holder["overall_inds"][inds_here_cpu[:2]]
562
+ base_shape = (len(inds_here[0]), 1, ndim)
563
+ assert np.prod(band_inds.shape) == inds_here[0].shape[0]
564
+ old_d_h_d_h = self.xp.zeros_like(ll_here)
565
+
566
+ generate_points_out, logP_out, factors_out = self.get_mt_proposal(
567
+ self.xp.asarray(gb_coords[inds_here_cpu]),
568
+ len(inds_here[0]),
569
+ inds_reverse_in,
570
+ random,
571
+ args_prior=(base_shape, inds_reverse_in),
572
+ kwargs_generate={
573
+ "current_priors": current_priors,
574
+ "band_inds": band_inds.flatten(),
575
+ },
576
+ args_like=(base_shape,),
577
+ kwargs_like={
578
+ "inds_reverse": inds_reverse_in,
579
+ "old_d_h_d_h": old_d_h_d_h,
580
+ "overall_inds": overall_inds,
581
+ },
582
+ betas=betas,
583
+ rj_info=rj_info,
584
+ )
585
+
586
+ # gb_coords[inds_forward] = generate_points_out.copy()
587
+
588
+ # TODO: make sure detailed balance this will move to detailed balance in multiple try
589
+
590
+ self.logP_out = logP_out
591
+
592
+ return (
593
+ generate_points_out,
594
+ self.ll_out,
595
+ self.lp_out,
596
+ factors_out,
597
+ betas,
598
+ inds_here,
599
+ )
600
+
601
+ else:
602
+ breakpoint()
603
+
604
+ def propose(self, model, state):
605
+ """Use the move to generate a proposal and compute the acceptance
606
+
607
+ Args:
608
+ model (:class:`eryn.model.Model`): Carrier of sampler information.
609
+ state (:class:`State`): Current state of the sampler.
610
+
611
+ Returns:
612
+ :class:`State`: State of sampler after proposal is complete.
613
+
614
+ """
615
+
616
+ st = time.perf_counter()
617
+ # TODO: keep this?
618
+ # this exposes anywhere in the proposal class to this information
619
+ self.current_state = state
620
+ self.current_model = model
621
+
622
+ # Run any move-specific setup.
623
+ self.setup(state.branches)
624
+
625
+ # ll_before = model.compute_log_like_fn(state.branches_coords, inds=state.branches_inds, supps=state.supplimental, branch_supps=state.branches_supplimental)
626
+
627
+ # if not np.allclose(ll_before[0], state.log_like):
628
+ # breakpoint()
629
+
630
+ new_state = State(state, copy=True)
631
+ all_branch_names = list(new_state.branches_coords.keys())
632
+
633
+ if np.any(
634
+ new_state.branches_supplimental["gb"].holder["N_vals"][
635
+ new_state.branches_inds["gb"]
636
+ ]
637
+ == 0
638
+ ):
639
+ breakpoint()
640
+
641
+ self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()
642
+
643
+ ntemps, nwalkers, _, _ = new_state.branches[
644
+ list(new_state.branches.keys())[0]
645
+ ].shape
646
+
647
+ num_consecutive_rj_moves = 1
648
+ # print("starting")
649
+
650
+ for rj_move_i in range(num_consecutive_rj_moves):
651
+ # st = time.perf_counter()
652
+ accepted = np.zeros((ntemps, nwalkers), dtype=bool)
653
+
654
+ coords_propose_in = self.xp.asarray(new_state.branches_coords["gb"])
655
+ inds_propose_in = self.xp.asarray(new_state.branches_inds["gb"])
656
+ branches_supp_propose_in = new_state.branches_supplimental["gb"]
657
+ remaining_coords = coords_propose_in[inds_propose_in]
658
+ f0 = remaining_coords[:, 1] / 1e3
659
+ inds_into_full_f0 = self.xp.arange(f0.shape[0])
660
+
661
+ band_inds = self.xp.searchsorted(self.band_edges, f0) - 1
662
+ temp_inds = self.xp.repeat(
663
+ self.xp.arange(ntemps)[:, None],
664
+ np.prod(coords_propose_in.shape[1:3]),
665
+ axis=1,
666
+ ).reshape(ntemps, nwalkers, -1)[inds_propose_in]
667
+ walkers_inds = self.xp.repeat(
668
+ self.xp.repeat(self.xp.arange(nwalkers)[None, :], ntemps, axis=0)[
669
+ :, :, None
670
+ ],
671
+ coords_propose_in.shape[2],
672
+ axis=2,
673
+ )[inds_propose_in]
674
+ leaf_inds = self.xp.repeat(
675
+ self.xp.arange(coords_propose_in.shape[2])[None, :],
676
+ ntemps * nwalkers,
677
+ axis=0,
678
+ ).reshape(ntemps, nwalkers, -1)[inds_propose_in]
679
+
680
+ # TODO: make this weighted somehow to not focus on empty bands
681
+ # we are already leaving out the last one
682
+ bands_per_walker = self.xp.tile(
683
+ (self.xp.arange(self.num_bands - 1)), (ntemps, nwalkers, 1)
684
+ )
685
+
686
+ """et = time.perf_counter()
687
+ print("initial setup", et - st)"""
688
+ # st = time.perf_counter()
689
+
690
+ # odds & evens (evens first)
691
+ for i in range(2):
692
+ # st = time.perf_counter()
693
+ bands_per_walker_here = bands_per_walker[
694
+ bands_per_walker % 2 == i
695
+ ].reshape(ntemps, nwalkers, -1)
696
+ max_band_ind = bands_per_walker_here.max().item()
697
+ band_inds_here = (
698
+ int(1e12) * temp_inds[band_inds % 2 == i]
699
+ + int(1e6) * walkers_inds[band_inds % 2 == i]
700
+ + band_inds[band_inds % 2 == i]
701
+ )
702
+
703
+ inds_into_full_f0_here = inds_into_full_f0[band_inds.flatten() % 2 == i]
704
+
705
+ inds_for_shuffle = self.xp.arange(len(inds_into_full_f0_here))
706
+ # shuffle allows the first unique index to be the one that is dropped
707
+ self.xp.random.shuffle(inds_for_shuffle)
708
+ bands_inds_here_tmp = band_inds_here[inds_for_shuffle]
709
+
710
+ assert band_inds_here.shape == inds_into_full_f0_here.shape
711
+
712
+ # shuffle and then take first index from band
713
+ # is a good way to get random binary from that band
714
+ (
715
+ unique_bands_inds_here,
716
+ unique_bands_inds_here_index,
717
+ unique_bands_inds_here_inverse,
718
+ unique_bands_inds_here_counts,
719
+ ) = self.xp.unique(
720
+ bands_inds_here_tmp,
721
+ return_index=True,
722
+ return_counts=True,
723
+ return_inverse=True,
724
+ )
725
+
726
+ # selected_inds_from_shuffle is indices reverted back to
727
+ # inds_into_full_f0_here after shuffle
728
+ selected_inds_from_shuffle = inds_for_shuffle[
729
+ unique_bands_inds_here_index
730
+ ]
731
+
732
+ # frequencies of the first instance found in band according to self.xp.unique call
733
+ # and what bands they belong to
734
+ inds_into_full_f0_here_selected = inds_into_full_f0_here[
735
+ selected_inds_from_shuffle
736
+ ]
737
+ bands_inds_here_selected = band_inds_here[selected_inds_from_shuffle]
738
+
739
+ temp_index_count = temp_inds[inds_into_full_f0_here_selected]
740
+ walkers_index_count = walkers_inds[inds_into_full_f0_here_selected]
741
+ leaf_index_count = leaf_inds[inds_into_full_f0_here_selected]
742
+
743
+ # get band index for each selected source by maths
744
+ band_index_count = self.xp.floor(
745
+ (
746
+ unique_bands_inds_here
747
+ - temp_index_count * int(1e12)
748
+ - walkers_index_count * int(1e6)
749
+ )
750
+ ).astype(int)
751
+ band_index_tmp = self.xp.searchsorted(
752
+ bands_per_walker_here[0, 0], band_index_count, side="left"
753
+ )
754
+
755
+ nleaves_here = self.xp.zeros_like(bands_per_walker_here)
756
+
757
+ # how many leaves are in each band
758
+ try:
759
+ nleaves_here[
760
+ (temp_index_count, walkers_index_count, band_index_tmp)
761
+ ] = unique_bands_inds_here_counts
762
+ except IndexError:
763
+ breakpoint()
764
+
765
+ # setup leaf count change arrays
766
+ if self.fix_change is None:
767
+ changes = self.xp.random.choice(
768
+ [-1, 1], size=bands_per_walker_here.shape
769
+ )
770
+ else:
771
+ changes = self.xp.full(bands_per_walker_here.shape, self.fix_change)
772
+
773
+ # make sure to add a binary if there are None
774
+ changes[nleaves_here == 0] = +1
775
+
776
+ if self.xp.any(
777
+ nleaves_here >= self.max_k[all_branch_names.index("gb")]
778
+ ):
779
+ raise ValueError("nleaves_here higher than max_k.")
780
+
781
+ # number of sub bands
782
+ num_sub_bands_here = bands_per_walker_here.shape[-1]
783
+
784
+ # build arrays to help properly index and locate within these sub bands
785
+ temp_inds_for_change = self.xp.repeat(
786
+ self.xp.arange(ntemps)[:, None],
787
+ nwalkers * num_sub_bands_here,
788
+ axis=-1,
789
+ ).reshape(ntemps, nwalkers, num_sub_bands_here)
790
+ walker_inds_for_change = self.xp.tile(
791
+ self.xp.arange(nwalkers), (ntemps, num_sub_bands_here, 1)
792
+ ).transpose((0, 2, 1))
793
+ band_inds_for_change = self.xp.tile(
794
+ self.xp.arange(num_sub_bands_here), (ntemps, nwalkers, 1)
795
+ )
796
+
797
+ # leaves will be filled later
798
+ leaf_inds_for_change = -self.xp.ones(
799
+ (ntemps, nwalkers, num_sub_bands_here), dtype=int
800
+ )
801
+
802
+ # add leaves that would be removed if count is not currently zero
803
+ # will adjust for count soon
804
+ # len(leaf_index_count) < total number of proposals
805
+ leaf_inds_for_change[
806
+ (temp_index_count, walkers_index_count, band_index_tmp)
807
+ ] = leaf_index_count
808
+
809
+ # band_inds_for_change[(temp_index_count, walkers_index_count, band_index_tmp)] = band_index_tmp
810
+
811
+ leaf_inds_tmp = self.xp.repeat(
812
+ self.xp.arange(coords_propose_in.shape[2])[None, :],
813
+ ntemps * nwalkers,
814
+ axis=0,
815
+ ).reshape(ntemps, nwalkers, -1)
816
+
817
+ # any binary that exists give fake high value to remove from sort
818
+ leaf_inds_tmp[inds_propose_in] = int(1e7)
819
+
820
+ # this gives unused leaves, you take the number you need for each walker / temperature
821
+ leaf_inds_tmp_2 = self.xp.sort(leaf_inds_tmp, axis=-1)[
822
+ :, :, :num_sub_bands_here
823
+ ]
824
+ leaf_inds_for_change[changes > 0] = leaf_inds_tmp_2[changes > 0]
825
+
826
+ assert not self.xp.any(leaf_inds_for_change == -1)
827
+
828
+ if self.xp.any(leaf_inds_tmp_2 == int(1e7)):
829
+ raise ValueError(
830
+ "Not enough spots to allocate for new binaries. Need to increase max leaves."
831
+ )
832
+
833
+ # TODO: check that number of available spots is high enough
834
+
835
+ """et = time.perf_counter()
836
+ print("start", et - st)
837
+ st = time.perf_counter()"""
838
+ # propose new sources and coordinates
839
+ (
840
+ new_coords,
841
+ ll_out,
842
+ lp_out,
843
+ factors,
844
+ betas,
845
+ inds_here,
846
+ ) = self.get_proposal(
847
+ coords_propose_in,
848
+ inds_propose_in,
849
+ changes,
850
+ leaf_inds_for_change,
851
+ bands_per_walker_here,
852
+ model.random,
853
+ branch_supps=branches_supp_propose_in,
854
+ supps=new_state.supplimental,
855
+ )
856
+
857
+ """et = time.perf_counter()
858
+ print("proposal", et - st)
859
+ st = time.perf_counter()"""
860
+
861
+ # TODO: check this
862
+ edge_factors = self.xp.zeros_like(factors)
863
+ # get factors for edges
864
+ min_k = self.min_k[all_branch_names.index("gb")]
865
+ max_k = self.max_k[all_branch_names.index("gb")]
866
+
867
+ # fix proposal asymmetry at bottom of k range
868
+ inds_min = self.xp.where(nleaves_here.flatten() == min_k)
869
+ # numerator term so +ln
870
+ edge_factors[inds_min] += np.log(1 / 2.0)
871
+
872
+ # fix proposal asymmetry at top of k range
873
+ inds_max = self.xp.where(nleaves_here.flatten() == max_k)
874
+ # numerator term so -ln
875
+ edge_factors[inds_max] += np.log(1 / 2.0)
876
+
877
+ # fix proposal asymmetry at bottom of k range (kmin + 1)
878
+ inds_min = self.xp.where(nleaves_here.flatten() == min_k + 1)
879
+ # numerator term so +ln
880
+ edge_factors[inds_min] -= np.log(1 / 2.0)
881
+
882
+ # fix proposal asymmetry at top of k range (kmax - 1)
883
+ inds_max = self.xp.where(nleaves_here.flatten() == max_k - 1)
884
+ # numerator term so -ln
885
+ edge_factors[inds_max] -= np.log(1 / 2.0)
886
+
887
+ factors += edge_factors
888
+
889
+ prev_logl = self.xp.asarray(new_state.log_like)[inds_here[:2]]
890
+
891
+ prev_logp = self.xp.asarray(new_state.log_prior)[inds_here[:2]]
892
+
893
+ """et = time.perf_counter()
894
+ print("prior", et - st)
895
+ st = time.perf_counter()"""
896
+ logl = prev_logl + ll_out
897
+
898
+ logp = prev_logp + lp_out
899
+ # loglcheck, new_blobs = model.compute_log_like_fn(q, inds=new_inds, logp=logp, supps=new_supps, branch_supps=new_branch_supps)
900
+ # if not self.xp.all(self.xp.abs(logl[logl != -1e300] - loglcheck[logl != -1e300]) < 1e-5):
901
+ # breakpoint()
902
+
903
+ logP = self.compute_log_posterior(logl, logp, betas=betas)
904
+
905
+ # TODO: check about prior = - inf
906
+ # takes care of tempering
907
+ prev_logP = self.compute_log_posterior(
908
+ prev_logl, prev_logp, betas=betas
909
+ )
910
+
911
+ # TODO: fix this
912
+ # this is where _metropolisk should come in
913
+ lnpdiff = factors + logP - prev_logP
914
+
915
+ accepted = lnpdiff > self.xp.log(
916
+ self.xp.asarray(model.random.rand(*lnpdiff.shape))
917
+ )
918
+
919
+ """et = time.perf_counter()
920
+ print("through accepted", et - st)
921
+ st = time.perf_counter()"""
922
+
923
+ # bookkeeping
924
+
925
+ inds_reverse = self.xp.where(changes.flatten() == -1)[0]
926
+
927
+ # adjust births from False -> True
928
+ inds_forward = self.xp.where(changes.flatten() == +1)[0]
929
+
930
+ # accepted_keep = (temp_inds_for_change.flatten()[inds_forward] == 0) & (walker_inds_for_change.flatten()[inds_forward] == 1)
931
+
932
+ """accepted_keep_tmp = self.xp.where(accepted_keep)[0]
933
+ accepted_keep[accepted_keep_tmp[:]] = False
934
+ accepted_keep[accepted_keep_tmp[164679]] = True"""
935
+
936
+ # accepted[:] = False
937
+ # accepted[inds_forward] = True
938
+ """accepted[inds_forward[:]] = False
939
+ accepted[inds_reverse[0]] = False"""
940
+
941
+ # not accepted removals
942
+ accepted_reverse = accepted[inds_reverse]
943
+ accepted_inds_reverse = inds_reverse[accepted_reverse]
944
+ not_accepted_inds_reverse = inds_reverse[~accepted_reverse]
945
+
946
+ tuple_not_accepted_reverse = (
947
+ temp_inds_for_change.flatten()[not_accepted_inds_reverse],
948
+ walker_inds_for_change.flatten()[not_accepted_inds_reverse],
949
+ leaf_inds_for_change.flatten()[not_accepted_inds_reverse],
950
+ )
951
+
952
+ tuple_accepted_reverse = (
953
+ temp_inds_for_change.flatten()[accepted_inds_reverse],
954
+ walker_inds_for_change.flatten()[accepted_inds_reverse],
955
+ leaf_inds_for_change.flatten()[accepted_inds_reverse],
956
+ )
957
+
958
+ if len(not_accepted_inds_reverse) > 0:
959
+ points_not_accepted_removal = coords_propose_in[
960
+ tuple_not_accepted_reverse
961
+ ]
962
+
963
+ else:
964
+ points_not_accepted_removal = self.xp.empty((0, 8))
965
+
966
+ tuple_accepted_reverse_cpu = tuple(
967
+ [tmp.get() for tmp in list(tuple_accepted_reverse)]
968
+ )
969
+ # accepted removals
970
+ new_state.branches["gb"].inds[tuple_accepted_reverse_cpu] = False
971
+ delta_logl_trans1 = self.xp.zeros_like(
972
+ leaf_inds_for_change, dtype=float
973
+ )
974
+ delta_logl_trans1[
975
+ (
976
+ temp_inds_for_change.flatten()[accepted_inds_reverse],
977
+ walker_inds_for_change.flatten()[accepted_inds_reverse],
978
+ band_inds_for_change.flatten()[accepted_inds_reverse],
979
+ )
980
+ ] = (
981
+ logl[accepted_inds_reverse] - prev_logl[accepted_inds_reverse]
982
+ )
983
+
984
+ new_state.log_like += delta_logl_trans1.sum(axis=-1).get()
985
+
986
+ # accepted update logp
987
+ delta_logp_trans = self.xp.zeros_like(leaf_inds_for_change, dtype=float)
988
+ delta_logp_trans[
989
+ (
990
+ temp_inds_for_change.flatten()[accepted_inds_reverse],
991
+ walker_inds_for_change.flatten()[accepted_inds_reverse],
992
+ band_inds_for_change.flatten()[accepted_inds_reverse],
993
+ )
994
+ ] = (
995
+ logp[accepted_inds_reverse] - prev_logp[accepted_inds_reverse]
996
+ )
997
+
998
+ new_state.log_prior += delta_logp_trans.sum(axis=-1).get()
999
+
1000
+ # accepted additions
1001
+ accepted_forward = accepted[inds_forward]
1002
+ accepted_inds_forward = inds_forward[accepted_forward]
1003
+ not_accepted_inds_forward = inds_forward[~accepted_forward]
1004
+
1005
+ tuple_accepted_forward = (
1006
+ temp_inds_for_change.flatten()[accepted_inds_forward],
1007
+ walker_inds_for_change.flatten()[accepted_inds_forward],
1008
+ leaf_inds_for_change.flatten()[accepted_inds_forward],
1009
+ )
1010
+
1011
+ tuple_accepted_forward_cpu = tuple(
1012
+ [tmp.get() for tmp in list(tuple_accepted_forward)]
1013
+ )
1014
+ new_state.branches["gb"].inds[tuple_accepted_forward_cpu] = True
1015
+ new_state.branches["gb"].coords[
1016
+ tuple_accepted_forward_cpu
1017
+ ] = new_coords[accepted_forward].get()
1018
+
1019
+ """et = time.perf_counter()
1020
+ print("bookkeeping", et - st)
1021
+ st = time.perf_counter()"""
1022
+
1023
+ if len(accepted_inds_forward) > 0:
1024
+ points_accepted_addition = new_coords[accepted_forward]
1025
+
1026
+ # get group friend finder information
1027
+ f0_accepted_addition = -100 * self.xp.ones_like(
1028
+ leaf_inds_for_change, dtype=float
1029
+ )
1030
+ f0_accepted_addition[
1031
+ (
1032
+ temp_inds_for_change.flatten()[accepted_inds_forward],
1033
+ walker_inds_for_change.flatten()[accepted_inds_forward],
1034
+ band_inds_for_change.flatten()[accepted_inds_forward],
1035
+ )
1036
+ ] = (
1037
+ points_accepted_addition[:, 1] / 1e3
1038
+ )
1039
+
1040
+ for t in range(ntemps):
1041
+ # use old state to get supp information
1042
+ f0_old = (
1043
+ self.xp.asarray(
1044
+ state.branches["gb"].coords[
1045
+ t, state.branches["gb"].inds[t]
1046
+ ][:, 1]
1047
+ )
1048
+ / 1e3
1049
+ )
1050
+ friend_start_inds = self.xp.asarray(
1051
+ state.branches["gb"].branch_supplimental.holder[
1052
+ "friend_start_inds"
1053
+ ][t, state.branches["gb"].inds[t]]
1054
+ )
1055
+
1056
+ f0_old_sorted = self.xp.sort(f0_old, axis=-1)
1057
+ inds_f0_old_sorted = self.xp.argsort(f0_old, axis=-1)
1058
+ f0_accepted_addition_in1 = f0_accepted_addition[t]
1059
+ f0_accepted_addition_in2 = f0_accepted_addition_in1[
1060
+ f0_accepted_addition_in1 > -1.0
1061
+ ]
1062
+ temp_inds_in = temp_inds_for_change[t][
1063
+ f0_accepted_addition_in1 > -1.0
1064
+ ]
1065
+ walker_inds_in = walker_inds_for_change[t][
1066
+ f0_accepted_addition_in1 > -1.0
1067
+ ]
1068
+ leaf_inds_in = leaf_inds_for_change[t][
1069
+ f0_accepted_addition_in1 > -1.0
1070
+ ]
1071
+
1072
+ inds_f0_accepted_addition = (
1073
+ self.xp.searchsorted(
1074
+ f0_old_sorted, f0_accepted_addition_in2, side="right"
1075
+ )
1076
+ - 1
1077
+ )
1078
+
1079
+ old_inds_f0_old = inds_f0_old_sorted[inds_f0_accepted_addition]
1080
+ comp_f0_old = f0_old[old_inds_f0_old]
1081
+ new_friends_start_inds = friend_start_inds[old_inds_f0_old]
1082
+
1083
+ # TODO: maybe check this
1084
+ new_state.branches["gb"].branch_supplimental.holder[
1085
+ "friend_start_inds"
1086
+ ][
1087
+ (
1088
+ temp_inds_in.get(),
1089
+ walker_inds_in.get(),
1090
+ leaf_inds_in.get(),
1091
+ )
1092
+ ] = new_friends_start_inds.get()
1093
+
1094
+ """inds_into_old_array = self.xp.take_along_axis(inds_f0_old_sorted, inds_f0_accepted_addition, axis=-1).reshape(ntemps, nwalkers, -1)
1095
+
1096
+ new_start_inds = self.xp.take_along_axis(state.branches["gb"].branch_supplimental.holder["friend_start_inds"], inds_into_old_array, axis=-1)
1097
+
1098
+ keep_new_start_inds = new_start_inds[f0_accepted_addition.reshape(ntemps, nwalkers, -1) > -10.0]
1099
+ breakpoint()
1100
+ state.branches["gb"].branch_supplimental.holder["friend_start_inds"][(
1101
+
1102
+ )] = keep_new_start_inds
1103
+
1104
+ inds_into_old_array
1105
+ keep_inds_f0_accepted_addition = inds_f0_accepted_addition > 0
1106
+
1107
+
1108
+
1109
+ inds_into_old_array = inds_f0_old_sorted[inds_f0_accepted_addition]"""
1110
+
1111
+ else:
1112
+ points_accepted_addition = self.xp.empty((0, 8))
1113
+
1114
+ delta_logl_trans2 = self.xp.zeros_like(
1115
+ leaf_inds_for_change, dtype=float
1116
+ )
1117
+ delta_logl_trans2[
1118
+ (
1119
+ temp_inds_for_change.flatten()[accepted_inds_forward],
1120
+ walker_inds_for_change.flatten()[accepted_inds_forward],
1121
+ band_inds_for_change.flatten()[accepted_inds_forward],
1122
+ )
1123
+ ] = (
1124
+ logl[accepted_inds_forward] - prev_logl[accepted_inds_forward]
1125
+ )
1126
+
1127
+ new_state.log_like += delta_logl_trans2.sum(axis=-1).get()
1128
+
1129
+ delta_logp_trans = self.xp.zeros_like(leaf_inds_for_change, dtype=float)
1130
+ delta_logp_trans[
1131
+ (
1132
+ temp_inds_for_change.flatten()[accepted_inds_forward],
1133
+ walker_inds_for_change.flatten()[accepted_inds_forward],
1134
+ band_inds_for_change.flatten()[accepted_inds_forward],
1135
+ )
1136
+ ] = (
1137
+ logp[accepted_inds_forward] - prev_logp[accepted_inds_forward]
1138
+ )
1139
+
1140
+ new_state.log_prior += delta_logp_trans.sum(axis=-1).get()
1141
+
1142
+ accepted_counts = self.xp.zeros_like(leaf_inds_for_change, dtype=bool)
1143
+ num_proposals_here = leaf_inds_for_change.shape[-1]
1144
+
1145
+ accepted_counts[
1146
+ (
1147
+ temp_inds_for_change.flatten()[accepted],
1148
+ walker_inds_for_change.flatten()[accepted],
1149
+ band_inds_for_change.flatten()[accepted],
1150
+ )
1151
+ ] = True
1152
+
1153
+ accepted_overall = accepted_counts.sum(axis=-1)
1154
+
1155
+ # TODO: if we do not run every band, we need to adjust this
1156
+ self.accepted += accepted_overall.get()
1157
+ self.num_proposals += num_proposals_here
1158
+
1159
+ points_to_add_to_template = self.xp.concatenate(
1160
+ [
1161
+ points_accepted_addition,
1162
+ points_not_accepted_removal, # were removed at the start, need to be put back
1163
+ ],
1164
+ axis=0,
1165
+ )
1166
+
1167
+ """et = time.perf_counter()
1168
+ print("before add", et - st)
1169
+ st = time.perf_counter()"""
1170
+
1171
+ if points_to_add_to_template.shape[0] > 0:
1172
+ points_to_add_to_template_in = (
1173
+ self.parameter_transforms.both_transforms(
1174
+ points_to_add_to_template, xp=self.xp
1175
+ )
1176
+ )
1177
+ N_temp = get_N(
1178
+ points_to_add_to_template_in[:, 0],
1179
+ points_to_add_to_template_in[:, 1],
1180
+ self.waveform_kwargs["T"],
1181
+ self.waveform_kwargs["oversample"],
1182
+ )
1183
+
1184
+ # has to be accepted forward first (sase as points_to_add_to_template)
1185
+ inds_here_for_add = tuple(
1186
+ [
1187
+ self.xp.concatenate(
1188
+ [
1189
+ tuple_accepted_forward[jjj],
1190
+ tuple_not_accepted_reverse[jjj],
1191
+ ]
1192
+ )
1193
+ for jjj in range(len(tuple_not_accepted_reverse))
1194
+ ]
1195
+ )
1196
+
1197
+ inds_here_for_add_cpu = tuple(
1198
+ [tmp.get() for tmp in list(inds_here_for_add)]
1199
+ )
1200
+
1201
+ # inds_here_for_add = tuple_accepted_forward
1202
+
1203
+ new_state.branches["gb"].branch_supplimental.holder["N_vals"][
1204
+ inds_here_for_add_cpu
1205
+ ] = N_temp.get()
1206
+
1207
+ group_index_tmp_accepted_forward = (
1208
+ tuple_accepted_forward[0] * nwalkers + tuple_accepted_forward[1]
1209
+ )
1210
+ group_index_tmp_not_accepted_reverse = (
1211
+ tuple_not_accepted_reverse[0] * nwalkers
1212
+ + tuple_not_accepted_reverse[1]
1213
+ )
1214
+
1215
+ group_index_tmp = self.xp.concatenate(
1216
+ [
1217
+ group_index_tmp_accepted_forward,
1218
+ group_index_tmp_not_accepted_reverse,
1219
+ ]
1220
+ )
1221
+ group_index = self.xp.asarray(
1222
+ self.mgh.get_mapped_indices(group_index_tmp).astype(
1223
+ self.xp.int32
1224
+ )
1225
+ )
1226
+
1227
+ # adding these so d - (h + a) = d - h - a
1228
+ factors_multiply = -self.xp.ones_like(N_temp, dtype=self.xp.float64)
1229
+ waveform_kwargs_add = self.waveform_kwargs.copy()
1230
+ waveform_kwargs_add.pop("N")
1231
+
1232
+ # points_to_add_to_template_in[1:, 0] *= 1e-30
1233
+
1234
+ self.gb.generate_global_template(
1235
+ points_to_add_to_template_in,
1236
+ group_index,
1237
+ self.mgh.data_list,
1238
+ N=N_temp,
1239
+ data_length=self.data_length,
1240
+ data_splits=self.mgh.gpu_splits,
1241
+ factors=factors_multiply,
1242
+ **waveform_kwargs_add
1243
+ )
1244
+
1245
+ if np.any(
1246
+ new_state.branches_supplimental["gb"].holder["N_vals"][
1247
+ new_state.branches_inds["gb"]
1248
+ ]
1249
+ == 0
1250
+ ):
1251
+ breakpoint()
1252
+
1253
+ """et = time.perf_counter()
1254
+ print("after add", et - st)
1255
+ st = time.perf_counter()"""
1256
+
1257
+ # st = time.perf_counter()
1258
+ if self.time % 1 == 0:
1259
+ ll_after2 = (
1260
+ self.mgh.get_ll(include_psd_info=True)
1261
+ .flatten()[new_state.supplimental[:]["overall_inds"]]
1262
+ .reshape(ntemps, nwalkers)
1263
+ )
1264
+ if np.abs(ll_after2 - new_state.log_like).max() > 1e-4:
1265
+ if np.abs(ll_after2 - new_state.log_like).max() > 10.0:
1266
+ breakpoint()
1267
+ fix = np.abs(ll_after2 - new_state.log_like) > 1e-4
1268
+ new_state.log_like[fix] = ll_after2[fix]
1269
+
1270
+ # print("rj", (et - st) / num_consecutive_rj_moves)
1271
+
1272
+ if False: # self.temperature_control is not None and not self.prevent_swaps:
1273
+ # TODO: add swaps?
1274
+ new_state = self.temperature_control.temper_comps(new_state, adapt=False)
1275
+
1276
+ # et = time.perf_counter()
1277
+ # print("swapping", et - st)
1278
+ self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()
1279
+ accepted = np.zeros_like(new_state.log_like)
1280
+ self.time += 1
1281
+
1282
+ self.mempool.free_all_blocks()
1283
+
1284
+ et = time.perf_counter()
1285
+ print("RJ end", et - st)
1286
+ print(new_state.branches["gb"].nleaves.mean(axis=-1))
1287
+ return new_state, accepted