lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lisaanalysistools might be problematic. Click here for more details.
- lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
- lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
- lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
- lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
- lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
- lisatools/__init__.py +0 -0
- lisatools/_version.py +4 -0
- lisatools/analysiscontainer.py +438 -0
- lisatools/cutils/detector.cpython-312-darwin.so +0 -0
- lisatools/datacontainer.py +292 -0
- lisatools/detector.py +410 -0
- lisatools/diagnostic.py +976 -0
- lisatools/glitch.py +193 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/gbgroupstretch.py +53 -0
- lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
- lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
- lisatools/sampling/moves/gbspecialstretch.py +1836 -0
- lisatools/sampling/moves/mbhspecialmove.py +286 -0
- lisatools/sampling/moves/placeholder.py +16 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/moves/specialforegroundmove.py +564 -0
- lisatools/sampling/prior.py +508 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +324 -0
- lisatools/sensitivity.py +888 -0
- lisatools/sources/__init__.py +0 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/tdiwaveform.py +72 -0
- lisatools/stochastic.py +291 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +40 -0
- lisatools/utils/multigpudataholder.py +730 -0
- lisatools/utils/pointeradjust.py +106 -0
- lisatools/utils/utility.py +240 -0
|
@@ -0,0 +1,671 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from inspect import Attribute
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy import stats
|
|
7
|
+
import warnings
|
|
8
|
+
import time
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import cupy as xp
|
|
12
|
+
|
|
13
|
+
gpu_available = True
|
|
14
|
+
except ModuleNotFoundError:
|
|
15
|
+
import numpy as xp
|
|
16
|
+
|
|
17
|
+
gpu_available = False
|
|
18
|
+
|
|
19
|
+
from lisatools.utils.utility import searchsorted2d_vec, get_groups_from_band_structure
|
|
20
|
+
from eryn.moves import GroupStretchMove
|
|
21
|
+
from eryn.prior import ProbDistContainer
|
|
22
|
+
from eryn.utils.utility import groups_from_inds
|
|
23
|
+
from .gbspecialstretch import GBSpecialStretchMove
|
|
24
|
+
from .gbgroupstretch import GBGroupStretchMove
|
|
25
|
+
|
|
26
|
+
from ...diagnostic import inner_product
|
|
27
|
+
from eryn.state import State
|
|
28
|
+
from lisatools.utils.utility import searchsorted2d_vec
|
|
29
|
+
|
|
30
|
+
# Public API of this module. The move defined here is GBSpecialGroupStretchMove;
# the imported base GBGroupStretchMove stays listed so existing star-imports keep working.
__all__ = ["GBSpecialGroupStretchMove", "GBGroupStretchMove"]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# GBGroupStretchMove is listed first so that, in the MRO, its methods override those of GBSpecialStretchMove
|
|
34
|
+
class GBSpecialGroupStretchMove(GBGroupStretchMove, GBSpecialStretchMove):
    """Group stretch move for the Galactic-binary (``"gb"``) branch.

    Mixes :class:`GBGroupStretchMove` with :class:`GBSpecialStretchMove`.
    Because :class:`GBGroupStretchMove` is listed first, its methods take
    precedence over those of :class:`GBSpecialStretchMove` in the method
    resolution order. For every active leaf, :meth:`setup_gbs` picks a set of
    "friends" that are close in frequency and stores their coordinates in the
    branch supplemental data under ``"group_move_points"``.

    Array work goes through ``self.xp`` / ``self.use_gpu``; these, and
    ``self.nfriends``, are presumably set by the base-class initializers (not
    visible here) -- TODO confirm.

    Args:
        gb_args (tuple): Positional arguments forwarded to
            ``GBSpecialStretchMove.__init__``.
        gb_kwargs (dict): Keyword arguments forwarded to
            ``GBSpecialStretchMove.__init__``.
        start_ind_limit (int, optional): Stored as ``self.start_ind_limit``;
            not read by the methods of this class. Default ``10``.
        *args: Forwarded to ``eryn.moves.GroupStretchMove.__init__``.
        **kwargs: Forwarded to ``eryn.moves.GroupStretchMove.__init__``.

    """

    def __init__(
        self,
        gb_args,
        gb_kwargs,
        start_ind_limit=10,
        *args,
        **kwargs
    ):
        # bookkeeping; self.time is incremented once per call to propose
        self.fixed_like_diff = 0
        self.time = 0
        self.name = "gbgroupstretch"
        self.start_ind_limit = start_ind_limit
        # NOTE(review): initializes GBSpecialStretchMove and eryn's GroupStretchMove
        # directly, so any GBGroupStretchMove.__init__ is bypassed -- TODO confirm intended.
        GBSpecialStretchMove.__init__(self, *gb_args, **gb_kwargs)
        GroupStretchMove.__init__(self, *args, **kwargs)

    def setup_gbs(self, branch):
        """Assign stretch-move "friends" to every active leaf of ``branch``.

        Active leaves (``branch.coords[branch.inds]``) are ordered by their
        frequency coordinate (column 1). Each leaf gets a window of
        ``2 * self.nfriends`` consecutive leaves in that order, starting
        ``self.nfriends`` positions below it and shifted inward so the window
        never runs past either end. Candidates whose frequency equals the
        leaf's own (the leaf itself and exact duplicates) are pushed to the
        back, and the first ``self.nfriends`` remaining candidates are kept.
        Their coordinates, shape ``(num_active, self.nfriends, ndim)``, are
        written to ``branch.branch_supplimental[branch.inds]`` under the key
        ``"group_move_points"``.

        Args:
            branch: Branch object exposing ``coords``, ``inds`` and
                ``branch_supplimental``. Assumes ``inds`` is a boolean mask over
                the leading axes of ``coords`` -- TODO confirm.

        Raises:
            ValueError: If there are some, but fewer than ``2 * self.nfriends``,
                active leaves.

        """
        coords = branch.coords
        inds = branch.inds
        supps = branch.branch_supplimental

        active_coords = coords[inds]
        num_active = len(active_coords)
        window = 2 * self.nfriends

        # TODO: make faster? / improve this?
        if 0 < num_active < window:
            raise ValueError(
                f"setup_gbs needs at least 2 * nfriends = {window} active leaves "
                f"to choose friends from; got {num_active}."
            )

        freqs = active_coords[:, 1]

        # sort once; inds_reverse maps original position -> sorted position
        inds_freqs_sorted = np.argsort(freqs)
        inds_reverse = np.empty_like(inds_freqs_sorted)
        inds_reverse[inds_freqs_sorted] = np.arange(num_active)

        # first sorted index of each leaf's window, clipped to [0, n - window] so
        # that windows at the top end include the highest-frequency leaf
        start = np.clip(np.arange(num_active) - self.nfriends, 0, num_active - window)
        inds_keep = start[:, None] + np.arange(window)[None, :]

        # rows back to original order, entries back to original leaf indices
        candidates = inds_freqs_sorted[inds_keep][inds_reverse]
        suggested_friends = active_coords[candidates]

        # zero-distance candidates (self / duplicates) get 1e300, all others 0,
        # so argsort moves them to the end before taking nfriends
        distances = np.abs(suggested_friends[:, :, 1] - freqs[:, None])
        distances_pretend = np.zeros_like(distances)
        distances_pretend[distances == 0.0] = 1e300
        keep = np.argsort(distances_pretend, axis=1)[:, :self.nfriends]
        friends = np.take_along_axis(suggested_friends, keep[:, :, None], axis=1)

        supps[inds] = {"group_move_points": friends}

    def find_friends(self, branches):
        """Recompute friends for the ``"gb"`` branch when it is present.

        Args:
            branches (dict): Mapping of branch name to branch object; only the
                entry named ``"gb"`` is processed (via :meth:`setup_gbs`).

        """
        for name, branch in branches.items():
            if name == "gb":
                self.setup_gbs(branch)

    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance.

        WARNING (review): this method is an unfinished draft. Only the
        early-return path ("Not enough friends yet.") can complete. On every
        other path the first use of ``new_state`` raises ``UnboundLocalError``:
        ``new_state`` is assigned only near the end of the method (from
        ``temperature_control.temper_comps``), which makes it a local name for
        the whole function. The code after that point also uses ``q`` and
        ``group`` before assignment, uses ``data_minus_template`` before it is
        assigned, refers to names not defined anywhere in this file
        (``group_cpu``, ``split_here``, ``group_temp_finder``, ``dbin``), and
        contains ``breakpoint()`` calls. It is kept below so it can be finished.

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: ``(state, accepted)`` -- the :class:`State` after the
            proposal and a boolean array of shape ``(ntemps, nwalkers)``.

        """

        # st = time.perf_counter()
        ntemps, nwalkers, nleaves_, ndim_ = state.branches["gb"].shape

        # setup_gbs needs at least 2 * nfriends active leaves to choose friends
        if state.branches["gb"].nleaves.sum() < 2 * self.nfriends:
            print("Not enough friends yet.")
            accepted = np.zeros((ntemps, nwalkers), dtype=bool)
            # same None-guard as used before temper_comps at the end of this method
            if self.temperature_control is not None:
                self.temperature_control.swaps_accepted = np.zeros(ntemps - 1, dtype=int)
            return state, accepted

        # TODO: deal with more intensive acceptance fractions
        # Run any move-specific setup.

        # data should not be whitened

        # Split the ensemble in half and iterate over these two halves.
        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        ntemps, nwalkers, nleaves_max, ndim = state.branches_coords["gb"].shape

        f_test = state.branches_coords["gb"][:, :, :, 1] / 1e3

        # NOTE(review): UnboundLocalError here -- `new_state` is only bound at the end
        # of this method. Presumably `new_state = State(state, copy=True)` was intended;
        # TODO confirm before enabling (the code below still hits breakpoint() calls).
        # TODO: add actual amplitude
        N_vals = new_state.branches["gb"].branch_supplimental.holder["N_vals"]

        N_vals = self.xp.asarray(N_vals)

        N_vals_2_times = self.xp.concatenate([N_vals, N_vals], axis=-1)

        gb_coords = np.zeros((ntemps, nwalkers, nleaves_max, ndim))
        gb_coords[new_state.branches_inds["gb"]] = new_state.branches_coords["gb"][new_state.branches_inds["gb"]].copy()

        log_like_tmp = self.xp.asarray(new_state.log_like)
        log_prior_tmp = self.xp.asarray(new_state.log_prior)

        self.mempool.free_all_blocks()

        unique_N = np.unique(N_vals)

        groups = get_groups_from_band_structure(f_test, self.band_edges, xp=np)
        # NOTE(review): leftover debugger halt
        breakpoint()
        # NOTE(review): this marks the *active* leaves as group -1 (the "bad" group
        # removed below); the mask may be inverted (~inds?) -- TODO confirm.
        groups[new_state.branches_inds["gb"]] = -1
        unique_groups, group_len = np.unique(groups.flatten(), return_counts=True)

        # remove information about the bad "-1" group
        if -1 in unique_groups:
            group_len = np.delete(group_len, unique_groups == -1)
            unique_groups = np.delete(unique_groups, unique_groups == -1)

        if len(unique_groups) == 0:
            return state, accepted

        # needs to be max because some values may be missing due to evens and odds
        num_groups = unique_groups.max().item() + 1

        waveform_kwargs_now = self.waveform_kwargs.copy()
        if "N" in waveform_kwargs_now:
            waveform_kwargs_now.pop("N")
        waveform_kwargs_now["start_freq_ind"] = self.start_freq_ind

        # NOTE(review): draft of a vectorized proposal; `q` and `group` are used
        # before assignment and `group_cpu` is never defined.
        points_to_move = q["gb"][group]
        group_branch_supp_info = new_state.branches_supplimental["gb"][group_cpu]
        points_for_move = self.xp.asarray(group_branch_supp_info["group_move_points"])

        q_temp, factors_temp = self.get_proposal(
            {"gb": points_to_move.transpose(0, 2, 1, 3).reshape(ntemps * nleaves_max, int(nwalkers / 2), 1, ndim)},
            {"gb": [points_for_move.transpose(0, 2, 1, 3).reshape(ntemps * nleaves_max, int(nwalkers / 2), 1, ndim)]},
            model.random
        )
        # NOTE(review): leftover debugger halt
        breakpoint()

        # NOTE(review): draft loop; `group_temp_finder` and `split_here` are never defined.
        for group_iter in range(num_groups):
            # sometimes you will have an extra odd or even group only
            # the group_iter may not match the actual running group number in this case
            if group_iter not in groups:
                continue
            group = [grp[groups == group_iter].flatten() for grp in group_temp_finder]

            temp_inds, walkers_inds, leaf_inds = [self.xp.asarray(grp) for grp in group]

            if True:
                factors = self.xp.zeros((ntemps, nwalkers, nleaves_max))
                factors[:, split_here] = factors_temp.reshape(ntemps, nleaves_max, int(nwalkers / 2)).transpose(0, 2, 1)

                # use new_state here to get change after 1st round
                q = {"gb": gb_coords.copy()}

                q["gb"][:, split_here] = q_temp["gb"].reshape(ntemps, nleaves_max, int(nwalkers / 2), ndim).transpose(0, 2, 1, 3)

        for group_iter in range(num_groups):
            # sometimes you will have an extra odd or even group only
            # the group_iter may not match the actual running group number in this case
            if group_iter not in groups:
                continue
            group = [grp[groups == group_iter].flatten() for grp in group_temp_finder]

            temp_inds, walkers_inds, leaf_inds = [self.xp.asarray(grp) for grp in group]
            q = {"gb": gb_coords.copy()}
            # new_inds = deepcopy(new_state.branches_inds)

            points_to_move = q["gb"][group]
            group_branch_supp_info = new_state.branches_supplimental["gb"][group_cpu]
            points_for_move = self.xp.asarray(group_branch_supp_info["group_move_points"])

            q_temp, factors_temp = self.get_proposal(
                {"gb": points_to_move}, {"gb": points_for_move}, model.random
            )

            q["gb"][group] = q_temp["gb"]

            # data should not be whitened
            # TODO: take this out of the groups loop
            if "noise_params" not in state.branches:
                use_stock_psd = True
                psd = self.xp.tile(self.xp.asarray(self.psd), (ntemps * nwalkers, 1, 1))

            else:
                use_stock_psd = False
                noise_params = state.branches["noise_params"].coords
                if self.psd_func is None:
                    raise ValueError("When providing noise_params, psd_func kwargs in __init__ function must be given.")

                if noise_params.ndim == 3:
                    noise_params = noise_params[0]
                tmp = self.xp.asarray([self.psd_func(self.fd, *noise_params.reshape(-1, noise_params.shape[-1]).T, **self.noise_kwargs) for _ in range(2)])
                psd = tmp.transpose((1,0,2))
                self.noise_ll = -self.xp.sum(self.xp.log(psd), axis=(1, 2)).reshape(ntemps, nwalkers)

                try:
                    self.noise_ll = self.noise_ll.get()
                except AttributeError:
                    pass

            noise_ll = self.noise_ll

            if self.use_gpu:
                new_points_prior = q["gb"][group].get()
                old_points_prior = gb_coords[group].get()
            else:
                new_points_prior = q["gb"][group]
                old_points_prior = gb_coords[group]

            # TODO: GPUize prior
            logp = self.xp.asarray(self.priors["gb"].logpdf(new_points_prior))

            if self.xp.all(self.xp.isinf(logp)):
                pass

            keep_here = self.xp.where(~self.xp.isinf(logp))

            points_remove = self.parameter_transforms.both_transforms(points_to_move[keep_here], xp=self.xp)
            points_add = self.parameter_transforms.both_transforms(q_temp["gb"][keep_here], xp=self.xp)

            data_index = self.xp.asarray((temp_inds[keep_here] * nwalkers + walkers_inds[keep_here]).astype(xp.int32))
            noise_index = self.xp.asarray((temp_inds[keep_here] * nwalkers + walkers_inds[keep_here]).astype(xp.int32))
            nChannels = 2

            delta_ll = self.xp.full(points_to_move.shape[0], -1e300)

            # NOTE(review): `data_minus_template` is assigned later in this method,
            # so this first use raises UnboundLocalError.
            delta_ll[keep_here] = self.gb.swap_likelihood_difference(
                points_remove,
                points_add,
                data_minus_template.reshape(ntemps * nwalkers, nChannels, -1).copy(),
                psd.copy(),
                data_index=data_index,
                noise_index=noise_index,
                adjust_inplace=False,
                **self.waveform_kwargs
            )

            optimized_snr = self.xp.sqrt(self.gb.add_add.real)
            detected_snr = (self.gb.d_h_add + self.gb.add_remove).real / optimized_snr

            if self.search:
                inds_fix = ((optimized_snr < self.search_snr_lim) | (detected_snr < (0.8 * self.search_snr_lim)))

                if self.xp.any(inds_fix):
                    delta_ll[keep_here[0][inds_fix]] = -1e300

            prev_logl = log_like_tmp[(temp_inds, walkers_inds)]
            logl = delta_ll + prev_logl

            prev_logp = self.xp.asarray(self.priors["gb"].logpdf(old_points_prior))

            betas_in = self.xp.asarray(self.temperature_control.betas)[temp_inds]
            logP = self.compute_log_posterior(logl, logp, betas=betas_in)

            # TODO: check about prior = - inf
            # takes care of tempering
            prev_logP = self.compute_log_posterior(prev_logl, prev_logp, betas=betas_in)

            # TODO: think about factors in tempering
            lnpdiff = factors_temp + logP - prev_logP

            # TODO: think about random states
            keep = lnpdiff > self.xp.log(self.xp.random.rand(*logP.shape))

            if self.xp.any(keep):
                # if gibbs sampling, this will say it is accepted if
                # any of the gibbs proposals were accepted
                accepted_here = keep.copy()

                # check freq overlap
                f0_new = q_temp["gb"][keep, 1]
                f0_old = gb_coords[(temp_inds[keep], walkers_inds[keep], leaf_inds[keep])][:, 1]
                nleaves_max = state.branches["gb"].nleaves_max
                check_f0 = self.xp.zeros((ntemps, nwalkers, nleaves_max))
                check_f0[(temp_inds[keep], walkers_inds[keep], leaf_inds[keep])] = f0_new

                check_f0_old = self.xp.zeros((ntemps, nwalkers, nleaves_max))

                check_f0_old[(temp_inds[keep], walkers_inds[keep], leaf_inds[keep])] = f0_old

                check_f0_old_sorted = self.xp.sort(check_f0_old, axis=-1)
                inds_f0_old_sorted = self.xp.argsort(check_f0_old, axis=-1)

                check_f0_tmp = check_f0.reshape(-1, check_f0.shape[-1])
                check_f0_old_sorted_tmp = check_f0_old_sorted.reshape(-1, check_f0_old_sorted.shape[-1])

                check_f0_in_old_inds = searchsorted2d_vec(check_f0_old_sorted_tmp, check_f0_tmp, xp=self.xp, side="right").reshape(check_f0.shape)

                zero_check = check_f0_in_old_inds[(temp_inds[keep], walkers_inds[keep], leaf_inds[keep])]
                f0_new_here = f0_new.copy()
                keep_for_after = keep.copy()
                inds_test_here = [-2, -1, 0, 1]
                for ind_test_here in inds_test_here:
                    here_check = zero_check + ind_test_here
                    do_check = np.ones_like(here_check, dtype=bool)
                    do_check[here_check < 0] = False
                    do_check[here_check >= check_f0_old.shape[-1]] = False

                    # NOTE(review): here_vals / here_inds ignore `here_check`, and
                    # `dbin` is not defined in this file.
                    here_vals = check_f0_old_sorted[(temp_inds[keep], walkers_inds[keep], leaf_inds[keep])]
                    here_inds = inds_f0_old_sorted[(temp_inds[keep], walkers_inds[keep], leaf_inds[keep])]
                    here_test = (self.xp.abs(f0_new_here - here_vals) / 1e3 / self.df).astype(int)
                    fix_bad2_tmp = self.xp.arange(len(keep))[keep]
                    fix_bad2 = fix_bad2_tmp[(here_test < dbin) & (leaf_inds[keep] != here_inds) & (do_check)]

                    keep_for_after[fix_bad2] = False

                check_f0_sorted = self.xp.sort(check_f0, axis=-1)
                inds_f0_sorted = self.xp.argsort(check_f0, axis=-1)
                check_f0_diff = self.xp.zeros_like(check_f0_sorted, dtype=int)
                check_f0_diff[:, :, 1:] = (self.xp.diff(check_f0_sorted, axis=-1) / 1e3 / self.df).astype(int)

                bad = (check_f0_diff < dbin) & (check_f0_sorted != 0.0)
                if self.xp.any(bad):
                    try:
                        bad_inds = self.xp.where(bad)

                        # fix the last entry of bad inds
                        inds_bad = (bad_inds[0], bad_inds[1], inds_f0_sorted[bad])
                        bad_check_val = (inds_bad[0] * 1e10 + inds_bad[1] * 1e5 + inds_bad[2]).astype(int)
                        # we are going to make this proposal not accepted
                        # this so far is only ever an accepted-level problem at high temps
                        # where large frequency jumps can happen
                        check_val = (temp_inds[keep] * 1e10 + walkers_inds[keep] * 1e5 + leaf_inds[keep]).astype(int)
                        fix_keep = self.xp.arange(len(keep))[keep][self.xp.in1d(check_val, bad_check_val)]
                        keep[fix_keep] = False
                    # NOTE(review): bare except + debugger halt; should re-raise
                    except:
                        breakpoint()

                gb_coords[(temp_inds[keep], walkers_inds[keep], leaf_inds[keep])] = q_temp["gb"][keep]

                # parameters were run for all ~np.isinf(logp), need to adjust for those not accepted
                keep_from_before = (keep * (~self.xp.isinf(logp)))[~self.xp.isinf(logp)]
                group_index = data_index[keep_from_before]

                waveform_kwargs_fill = self.waveform_kwargs.copy()
                waveform_kwargs_fill["start_freq_ind"] = self.start_freq_ind

                # remove templates by multiplying by "adding them to" d - h
                try:
                    self.gb.generate_global_template(points_remove[keep_from_before],
                        group_index, data_minus_template.reshape((-1,) + data_minus_template.shape[2:]), **waveform_kwargs_fill
                    )
                except ValueError:
                    breakpoint()

                # add templates by adding to -(-(d - h) + add) = d - (h + add)
                data_minus_template *= -1
                try:
                    self.gb.generate_global_template(points_add[keep_from_before],
                        group_index, data_minus_template.reshape((-1,) + data_minus_template.shape[2:]), **waveform_kwargs_fill
                    )
                except ValueError:
                    breakpoint()
                data_minus_template *= -1

                # set unaccepted differences to zero
                accepted_delta_ll = delta_ll * (keep)
                accepted_delta_lp = (logp - prev_logp)
                accepted_delta_lp[self.xp.isinf(accepted_delta_lp)] = 0.0
                logl_change_contribution = np.zeros_like(new_state.log_like)
                logp_change_contribution = np.zeros_like(new_state.log_prior)

                try:
                    in_tuple = (accepted_delta_ll[keep].get(), accepted_delta_lp[keep].get(), temp_inds[keep].get(), walkers_inds[keep].get())
                except AttributeError:
                    in_tuple = (accepted_delta_ll[keep], accepted_delta_lp[keep], temp_inds[keep], walkers_inds[keep])

                for i, (dll, dlp, ti, wi) in enumerate(zip(*in_tuple)):

                    logl_change_contribution[ti, wi] += dll
                    logp_change_contribution[ti, wi] += dlp

                log_like_tmp += self.xp.asarray(logl_change_contribution)
                log_prior_tmp += self.xp.asarray(logp_change_contribution)

                if np.any(accepted_delta_ll > 1e8):
                    breakpoint()

            try:
                new_state.branches["gb"].coords[:] = gb_coords.get()
                new_state.log_like[:] = log_like_tmp.get()
                new_state.log_prior[:] = log_prior_tmp.get()
            except AttributeError:
                new_state.branches["gb"].coords[:] = gb_coords
                new_state.log_like[:] = log_like_tmp
                new_state.log_prior[:] = log_prior_tmp

            if self.time % 1 == 0:
                lp_after = model.compute_log_prior_fn(new_state.branches_coords, inds=new_state.branches_inds)
                ll_after = model.compute_log_like_fn(new_state.branches_coords, inds=new_state.branches_inds, logp=lp_after, supps=new_state.supplimental, branch_supps=new_state.branches_supplimental)
                if np.abs(new_state.log_prior - lp_after).max() > 0.1 or np.abs(new_state.log_like - ll_after[0]).max() > 1e0:
                    breakpoint()

                # if any are even remotely getting to be different, reset all (small change)
                elif np.abs(new_state.log_like - ll_after[0]).max() > 1e-3:

                    fix_here = np.abs(new_state.log_like - ll_after[0]) > 1e-6
                    data_minus_template_old = data_minus_template.copy()
                    data_minus_template = self.xp.zeros_like(data_minus_template_old)
                    data_minus_template[:] = self.xp.asarray(self.data)[None, None]
                    templates = self.xp.zeros_like(data_minus_template).reshape(-1, 2, data_minus_template.shape[-1])
                    for name in new_state.branches.keys():
                        if name not in ["gb", "gb"]:
                            continue
                        new_state_branch = new_state.branches[name]
                        coords_here = new_state_branch.coords[new_state_branch.inds]
                        ntemps, nwalkers, nleaves_max_here, ndim = new_state_branch.shape

                        group_index = np.repeat(np.arange(ntemps * nwalkers).reshape(ntemps, nwalkers, 1), nleaves_max_here, axis=-1)[new_state_branch.inds]

                        coords_here_in = self.parameter_transforms.both_transforms(coords_here, xp=np)

                        self.gb.generate_global_template(coords_here_in, group_index, templates, batch_size=1000, **self.waveform_kwargs)

                    data_minus_template -= templates.reshape(ntemps, nwalkers, 2, templates.shape[-1])

                    psd_here = psd.reshape(ntemps, nwalkers, 2, templates.shape[-1])
                    # NOTE(review): `.get()` exists only on CuPy arrays (fails with NumPy)
                    new_like = -1 / 2 * 4 * self.df * self.xp.sum(data_minus_template.conj() * data_minus_template / psd_here, axis=(2, 3)).real.get()

                    new_like += self.noise_ll
                    new_state.log_like[:] = new_like.reshape(ntemps, nwalkers)

        # get accepted fraction
        accepted_check = np.all(np.abs(new_state.branches_coords["gb"] - state.branches_coords["gb"]) > 0.0, axis=-1).sum(axis=(1, 2)) / new_state.branches_inds["gb"].sum(axis=(1,2))

        # manually tell temperatures how real overall acceptance fraction is
        number_of_walkers_for_accepted = np.floor(nwalkers * accepted_check).astype(int)

        accepted_inds = np.tile(np.arange(nwalkers), (ntemps, 1))

        accepted = np.zeros((ntemps, nwalkers), dtype=bool)
        accepted[accepted_inds < number_of_walkers_for_accepted[:, None]] = True

        if self.temperature_control is not None:
            new_state, accepted = self.temperature_control.temper_comps(new_state, accepted)
            # self.temperature_control.swaps_accepted = np.zeros((ntemps - 1))

        if np.any(new_state.log_like > 1e10):
            breakpoint()
        self.time += 1

        return new_state, accepted
|
|
671
|
+
|