lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lisaanalysistools might be problematic. Click here for more details.

Files changed (37)
  1. lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
  2. lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
  3. lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
  4. lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
  6. lisatools/__init__.py +0 -0
  7. lisatools/_version.py +4 -0
  8. lisatools/analysiscontainer.py +438 -0
  9. lisatools/cutils/detector.cpython-312-darwin.so +0 -0
  10. lisatools/datacontainer.py +292 -0
  11. lisatools/detector.py +410 -0
  12. lisatools/diagnostic.py +976 -0
  13. lisatools/glitch.py +193 -0
  14. lisatools/sampling/__init__.py +0 -0
  15. lisatools/sampling/likelihood.py +882 -0
  16. lisatools/sampling/moves/__init__.py +0 -0
  17. lisatools/sampling/moves/gbgroupstretch.py +53 -0
  18. lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
  19. lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
  20. lisatools/sampling/moves/gbspecialstretch.py +1836 -0
  21. lisatools/sampling/moves/mbhspecialmove.py +286 -0
  22. lisatools/sampling/moves/placeholder.py +16 -0
  23. lisatools/sampling/moves/skymodehop.py +110 -0
  24. lisatools/sampling/moves/specialforegroundmove.py +564 -0
  25. lisatools/sampling/prior.py +508 -0
  26. lisatools/sampling/stopping.py +320 -0
  27. lisatools/sampling/utility.py +324 -0
  28. lisatools/sensitivity.py +888 -0
  29. lisatools/sources/__init__.py +0 -0
  30. lisatools/sources/emri/__init__.py +1 -0
  31. lisatools/sources/emri/tdiwaveform.py +72 -0
  32. lisatools/stochastic.py +291 -0
  33. lisatools/utils/__init__.py +0 -0
  34. lisatools/utils/constants.py +40 -0
  35. lisatools/utils/multigpudataholder.py +730 -0
  36. lisatools/utils/pointeradjust.py +106 -0
  37. lisatools/utils/utility.py +240 -0
@@ -0,0 +1,671 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from copy import deepcopy
4
+ from inspect import Attribute
5
+ import numpy as np
6
+ from scipy import stats
7
+ import warnings
8
+ import time
9
+
10
+ try:
11
+ import cupy as xp
12
+
13
+ gpu_available = True
14
+ except ModuleNotFoundError:
15
+ import numpy as xp
16
+
17
+ gpu_available = False
18
+
19
+ from lisatools.utils.utility import searchsorted2d_vec, get_groups_from_band_structure
20
+ from eryn.moves import GroupStretchMove
21
+ from eryn.prior import ProbDistContainer
22
+ from eryn.utils.utility import groups_from_inds
23
+ from .gbspecialstretch import GBSpecialStretchMove
24
+ from .gbgroupstretch import GBGroupStretchMove
25
+
26
+ from ...diagnostic import inner_product
27
+ from eryn.state import State
28
+ from lisatools.utils.utility import searchsorted2d_vec
29
+
30
# Public API of this module. ``GBGroupStretchMove`` is only an imported base
# class; it stays listed so that existing ``from ... import *`` users keep it.
__all__ = ["GBSpecialGroupStretchMove", "GBGroupStretchMove"]
31
+
32
+
33
# GBGroupStretchMove is listed first so that, in the MRO, its methods take
# precedence over same-named methods inherited from GBSpecialStretchMove.
class GBSpecialGroupStretchMove(GBGroupStretchMove, GBSpecialStretchMove):
    """Group-stretch proposals for Galactic-binary (GB) sources.

    Every active GB leaf is meant to be moved with a stretch proposal whose
    complementary ("friend") points are other GB leaves close in frequency.
    Friends are pooled across all temperatures and walkers by
    :meth:`setup_gbs` and stored in the branch supplemental under
    ``"group_move_points"``.

    GB waveform/likelihood machinery comes from :class:`GBSpecialStretchMove`
    (built from ``gb_args`` / ``gb_kwargs``); the stretch-move machinery comes
    from eryn's :class:`GroupStretchMove` (built from ``*args`` / ``**kwargs``).

    Args:
        gb_args (tuple): Positional arguments forwarded to
            ``GBSpecialStretchMove.__init__``.
        gb_kwargs (dict): Keyword arguments forwarded to
            ``GBSpecialStretchMove.__init__``.
        start_ind_limit (int, optional): Stored as ``self.start_ind_limit``;
            not read by the methods defined in this class. Default is 10.
        *args: Forwarded to ``GroupStretchMove.__init__``.
        **kwargs: Forwarded to ``GroupStretchMove.__init__``.

    .. warning::
        :meth:`propose` is not implemented in this release. It only handles
        the "not enough friends yet" start-up case and otherwise raises
        :class:`NotImplementedError`.
    """

    def __init__(
        self,
        gb_args,
        gb_kwargs,
        start_ind_limit=10,
        *args,
        **kwargs
    ):
        # Attributes kept from the original implementation; they are set
        # before the parent initializers run, which may read or override them.
        self.fixed_like_diff = 0
        self.time = 0
        self.name = "gbgroupstretch"
        self.start_ind_limit = start_ind_limit
        GBSpecialStretchMove.__init__(self, *gb_args, **gb_kwargs)
        GroupStretchMove.__init__(self, *args, **kwargs)

    def setup_gbs(self, branch):
        """Attach frequency-neighbour "friends" to every active GB leaf.

        All active leaves (``branch.coords[branch.inds]``, i.e. pooled over
        every temperature and walker) are sorted by coordinate column 1, which
        this module treats as the frequency. For the leaf at sorted position
        ``p`` a window of ``2 * nfriends`` consecutive sorted positions is
        taken, nominally ``[p - nfriends, p + nfriends - 1]`` and shifted so it
        stays inside the sorted list. From that window the first ``nfriends``
        candidates (in frequency order) are kept, preferring candidates whose
        frequency differs from the leaf's own, so the leaf itself is normally
        excluded.

        If there are fewer active leaves than ``2 * nfriends`` the window wraps
        around the sorted list, so friends may repeat.

        Args:
            branch: Branch object of the ``"gb"`` model. Its ``coords``,
                ``inds`` and ``branch_supplimental`` attributes are read;
                ``coords[inds]`` is assumed to give one row of parameters per
                active leaf (TODO confirm against eryn's ``Branch``).

        Side effects:
            Writes ``{"group_move_points": friends}`` into
            ``branch.branch_supplimental`` at the active leaves, where
            ``friends`` has shape ``(num_active, nfriends, ndim)``.
        """
        nfriends = self.nfriends  # presumably set by GroupStretchMove.__init__
        window = 2 * nfriends

        inds = branch.inds
        supps = branch.branch_supplimental
        active_coords = branch.coords[inds]
        num_active = len(active_coords)

        freqs = active_coords[:, 1]
        order = np.argsort(freqs)  # leaf index at each sorted position
        rank = np.empty_like(order)
        rank[order] = np.arange(num_active)  # sorted position of each leaf

        # Window start for each sorted position, clamped to [0, N - window].
        # (The previous clamping stopped one position short at the top end, so
        # the highest-frequency leaf could never be anyone's friend.)
        start = np.clip(
            np.arange(num_active) - nfriends, 0, max(num_active - window, 0)
        )
        positions = start[:, None] + np.arange(window)[None, :]
        if 0 < num_active < window:
            # Too few leaves to fill a window: wrap around explicitly instead of
            # relying on negative / out-of-range indexing.
            positions %= num_active

        # Sorted positions -> leaf indices, then reorder rows so that row k
        # holds the candidate window of leaf k.
        candidate_inds = order[positions][rank]
        candidates = active_coords[candidate_inds]

        # Push candidates that share the leaf's exact frequency (the leaf
        # itself) to the back while keeping the window order otherwise.
        same_freq = candidates[:, :, 1] == freqs[:, None]
        pick = np.argsort(same_freq, axis=1, kind="stable")[:, :nfriends]
        friends = np.take_along_axis(candidates, pick[:, :, None], axis=1)

        supps[inds] = {"group_move_points": friends}

    def find_friends(self, branches):
        """Rebuild the friend lists of the ``"gb"`` branch.

        Args:
            branches (dict-like): Mapping of branch name to branch object.
                Only the entry named ``"gb"`` is processed; all others are
                ignored.
        """
        for name, branch in branches.items():
            if name == "gb":
                self.setup_gbs(branch)

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance.

        Only the start-up case is implemented. While the total number of
        active GB leaves (over all temperatures and walkers) is smaller than
        ``2 * nfriends``, the state is returned unchanged with nothing
        accepted.

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
                Not used by the implemented path.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: ``(state, accepted)``, where ``state`` is the unchanged
            input state and ``accepted`` is an all-``False`` boolean array of
            shape ``(ntemps, nwalkers)``.

        Raises:
            NotImplementedError: When there are at least ``2 * nfriends``
                active GB leaves. The group-stretch step shipped in this
                release was an unfinished draft. It referenced names that are
                never defined (``new_state`` before assignment,
                ``data_minus_template``, ``group_temp_finder``, ``split_here``,
                ``dbin``, ...) and could only fail with
                ``UnboundLocalError``.
        """
        gb_branch = state.branches["gb"]
        ntemps, nwalkers, _, _ = gb_branch.shape

        if gb_branch.nleaves.sum() < 2 * self.nfriends:
            print("Not enough friends yet.")
            accepted = np.zeros((ntemps, nwalkers), dtype=bool)
            # No move was made, so report zero accepted temperature swaps.
            if self.temperature_control is not None:
                self.temperature_control.swaps_accepted = np.zeros(
                    ntemps - 1, dtype=int
                )
            return state, accepted

        # Outline of the unfinished draft, for whoever completes this move:
        #   1. Group leaves by frequency band with
        #      get_groups_from_band_structure(f0 / 1e3, self.band_edges).
        #   2. Per band group, draw stretch proposals from the stored
        #      "group_move_points" via self.get_proposal.
        #   3. Evaluate the log-likelihood change with
        #      self.gb.swap_likelihood_difference against the data residual.
        #   4. Metropolis-Hastings accept using the stretch factors, tempered
        #      betas and self.priors["gb"].logpdf; reject accepted points that
        #      land too close in frequency to another leaf.
        #   5. Update the residual with self.gb.generate_global_template and
        #      write back coords / log_like / log_prior.
        #   6. Derive the acceptance fraction and call
        #      self.temperature_control.temper_comps.
        raise NotImplementedError(
            "GBSpecialGroupStretchMove.propose is not implemented in this "
            "release: the group-stretch proposal step is an unfinished draft."
        )
+