lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lisaanalysistools might be problematic. Click here for more details.
- lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
- lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
- lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
- lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
- lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
- lisatools/__init__.py +0 -0
- lisatools/_version.py +4 -0
- lisatools/analysiscontainer.py +438 -0
- lisatools/cutils/detector.cpython-312-darwin.so +0 -0
- lisatools/datacontainer.py +292 -0
- lisatools/detector.py +410 -0
- lisatools/diagnostic.py +976 -0
- lisatools/glitch.py +193 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/gbgroupstretch.py +53 -0
- lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
- lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
- lisatools/sampling/moves/gbspecialstretch.py +1836 -0
- lisatools/sampling/moves/mbhspecialmove.py +286 -0
- lisatools/sampling/moves/placeholder.py +16 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/moves/specialforegroundmove.py +564 -0
- lisatools/sampling/prior.py +508 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +324 -0
- lisatools/sensitivity.py +888 -0
- lisatools/sources/__init__.py +0 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/tdiwaveform.py +72 -0
- lisatools/stochastic.py +291 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +40 -0
- lisatools/utils/multigpudataholder.py +730 -0
- lisatools/utils/pointeradjust.py +106 -0
- lisatools/utils/utility.py +240 -0
|
@@ -0,0 +1,1836 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from inspect import Attribute
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy import stats
|
|
7
|
+
import warnings
|
|
8
|
+
import time
|
|
9
|
+
from gbgpu.utils.utility import get_N
|
|
10
|
+
from lisatools.sensitivity import Soms_d_all, Sa_a_all
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import cupy as xp
|
|
14
|
+
|
|
15
|
+
gpu_available = True
|
|
16
|
+
except ModuleNotFoundError:
|
|
17
|
+
import numpy as xp
|
|
18
|
+
|
|
19
|
+
gpu_available = False
|
|
20
|
+
|
|
21
|
+
from lisatools.utils.utility import searchsorted2d_vec, get_groups_from_band_structure
|
|
22
|
+
from eryn.moves import StretchMove
|
|
23
|
+
from eryn.prior import ProbDistContainer
|
|
24
|
+
from eryn.utils.utility import groups_from_inds
|
|
25
|
+
from eryn.utils import PeriodicContainer
|
|
26
|
+
|
|
27
|
+
from eryn.moves import GroupStretchMove, Move
|
|
28
|
+
from eryn.moves.multipletry import logsumexp, get_mt_computations
|
|
29
|
+
|
|
30
|
+
from ...diagnostic import inner_product
|
|
31
|
+
from lisatools.globalfit.state import State
|
|
32
|
+
from lisatools.sampling.prior import GBPriorWrap
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Public API of this module: ``from ... import *`` exports only the stretch move.
__all__ = ["GBSpecialStretchMove"]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# MHMove needs to be to the left here to overwrite GBBruteRejectionRJ RJ proposal method
|
|
39
|
+
class GBSpecialStretchMove(GroupStretchMove, Move):
|
|
40
|
+
"""Generate Revesible-Jump proposals for GBs with try-force rejection
|
|
41
|
+
|
|
42
|
+
Will use gpu if template generator uses GPU.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
priors (object): :class:`ProbDistContainer` object that has ``logpdf``
|
|
46
|
+
and ``rvs`` methods.
|
|
47
|
+
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init__(
    self,
    gb,
    priors,
    start_freq_ind,
    data_length,
    mgh,
    fd,
    band_edges,
    gpu_priors,
    *args,
    waveform_kwargs=None,
    parameter_transforms=None,
    snr_lim=1e-10,
    rj_proposal_distribution=None,
    num_repeat_proposals=1,
    name=None,
    use_prior_removal=False,
    phase_maximize=False,
    **kwargs
):
    """Configure the GB stretch / RJ move.

    Args:
        gb: template generator. Its ``pyPriorPackage`` and
            ``pyPeriodicPackage`` constructors are called here.
        priors (dict): branch name -> prior. Every value must be a
            ``ProbDistContainer`` or ``GBPriorWrap``. ``priors["gb"].priors_in``
            supplies the bounds packed into the compiled prior package.
        start_freq_ind: stored on the instance.
        data_length: stored on the instance.
        mgh: data/PSD holder, stored on the instance.
        fd: frequency grid. ``fd[1] - fd[0]`` is stored as ``self.df``.
        band_edges: frequency band edges. The move uses
            ``len(band_edges) - 1`` bands.
        gpu_priors: stored as-is.
        *args: forwarded to ``GroupStretchMove.__init__``.
        waveform_kwargs (dict, optional): waveform settings. Must contain
            ``"T"`` and ``"oversample"``, which are used to compute the
            per-band ``N`` values.
        parameter_transforms: stored on the instance.
        snr_lim (float): stored on the instance.
        rj_proposal_distribution: when not None, the move acts as an RJ
            proposal (``self.is_rj_prop``).
        num_repeat_proposals (int): stored on the instance.
        name: stored on the instance.
        use_prior_removal (bool): stored on the instance.
        phase_maximize (bool): stored on the instance.
        **kwargs: forwarded to ``GroupStretchMove.__init__``.

    Raises:
        ValueError: if any prior is not a ProbDistContainer or GBPriorWrap.
        KeyError: if ``waveform_kwargs`` lacks ``"T"`` or ``"oversample"``.
    """
    # return_gpu is a kwarg for the stretch move
    GroupStretchMove.__init__(self, *args, return_gpu=True, **kwargs)

    if waveform_kwargs is None:
        # fresh dict per instance instead of a shared mutable default
        waveform_kwargs = {}

    self.gpu_priors = gpu_priors
    self.name = name
    self.num_repeat_proposals = num_repeat_proposals
    self.use_prior_removal = use_prior_removal

    for key in priors:
        if not isinstance(priors[key], (ProbDistContainer, GBPriorWrap)):
            raise ValueError(
                "Priors need to be eryn.prior.ProbDistContainer or GBPriorWrap objects."
            )

    self.priors = priors
    self.gb = gb
    self.stop_here = True

    # Flat argument list for the compiled prior package: rho_star and the
    # frequency bounds of the joint (0, 1) prior, then (min_val, max_val)
    # for parameters 2..7. (Named so it does not shadow ``*args``.)
    gb_priors = priors["gb"].priors_in
    joint_prior = gb_priors[(0, 1)]
    prior_package_args = [
        joint_prior.rho_star,
        joint_prior.frequency_prior.min_val,
        joint_prior.frequency_prior.max_val,
    ]
    for i in range(2, 8):
        prior_package_args += [gb_priors[i].min_val, gb_priors[i].max_val]

    self.gpu_cuda_priors = self.gb.pyPriorPackage(*prior_package_args)
    self.gpu_cuda_wrap = self.gb.pyPeriodicPackage(2 * np.pi, np.pi, 2 * np.pi)

    # self.use_gpu / self.xp are presumably set by GroupStretchMove.__init__;
    # the memory pool only exists on the GPU path.
    if self.use_gpu:
        self.mempool = self.xp.get_default_memory_pool()

    self.start_freq_ind = start_freq_ind
    self.data_length = data_length
    self.waveform_kwargs = waveform_kwargs
    self.parameter_transforms = parameter_transforms
    self.fd = fd
    self.df = (fd[1] - fd[0]).item()
    self.mgh = mgh
    self.phase_maximize = phase_maximize
    self.snr_lim = snr_lim

    # Host copy for the N computation (works for NumPy input or a CuPy
    # array); array-module copy for use inside the move.
    band_edges_host = (
        band_edges.get() if hasattr(band_edges, "get") else np.asarray(band_edges)
    )
    self.band_edges = self.xp.asarray(band_edges)
    self.num_bands = len(band_edges) - 1

    self.rj_proposal_distribution = rj_proposal_distribution
    self.is_rj_prop = self.rj_proposal_distribution is not None

    for required in ("T", "oversample"):
        if required not in self.waveform_kwargs:
            raise KeyError(
                f"waveform_kwargs must contain {required!r} to compute band N values."
            )

    # setup N vals for bands (evaluated at each band's mid frequency)
    band_mean_f = (band_edges_host[1:] + band_edges_host[:-1]) / 2
    self.band_N_vals = self.xp.asarray(
        get_N(
            np.full_like(band_mean_f, 1e-30),
            band_mean_f,
            self.waveform_kwargs["T"],
            self.waveform_kwargs["oversample"],
        )
    )
|
|
123
|
+
|
|
124
|
+
def setup_gbs(self, branch):
    """Build the frequency-sorted reference set and each source's friend window.

    The active sources at temperature index 0 (``coords[0][inds[0]]``) are
    sorted by frequency (column 1) to form the reference set. Every active
    source at every temperature then gets a window of ``self.nfriends``
    consecutive reference sources:

    - The window starts centred on the last reference source whose frequency
      is <= the source's frequency.
    - It is clamped to the valid index range.
    - Each iteration slides it one index toward whichever edge is nearer, as
      long as the newly included neighbour is closer than the edge being
      dropped.

    Sets ``self.inds_freqs_sorted``, ``self.freqs_sorted``,
    ``self.all_coords_sorted`` and ``self.stretch_friends_args_in`` (per-column
    copies of the sorted coordinates). It also stores the window start index
    of every leaf in ``branch.branch_supplimental["friend_start_inds"]``
    (int32; 0 for inactive leaves).

    Args:
        branch: branch exposing ``coords``, ``inds`` and ``branch_supplimental``.

    Raises:
        AssertionError: if the initial windows are out of range (e.g. fewer
            reference sources than ``self.nfriends``).
        RuntimeError: if the window search has not settled after
            ``self.nfriends`` iterations.
    """
    xp = self.xp
    coords = branch.coords
    inds = branch.inds
    supps = branch.branch_supplimental

    # reference set: active sources at temperature index 0
    reference_coords = coords[0][inds[0]]
    reference_freqs = reference_coords[:, 1]

    # frequency of every active source at every temperature
    all_temp_fs = xp.asarray(coords[inds][:, 1])

    # TODO: improve this?
    self.inds_freqs_sorted = xp.asarray(np.argsort(reference_freqs))
    self.freqs_sorted = xp.asarray(np.sort(reference_freqs))
    self.all_coords_sorted = xp.asarray(reference_coords)[self.inds_freqs_sorted]

    num_sorted = len(self.freqs_sorted)
    last_sorted = num_sorted - 1
    half_friends = int(self.nfriends / 2)

    total_binaries = int(inds.sum())
    still_going = xp.ones(total_binaries, dtype=bool)
    all_indices = xp.arange(len(all_temp_fs))

    inds_zero = xp.searchsorted(self.freqs_sorted, all_temp_fs, side="right") - 1
    left_inds = inds_zero - half_friends
    right_inds = inds_zero + half_friends - 1

    # clamp windows hanging off the low end (mask taken from left_inds)
    below = left_inds < 0
    right_inds[below] = self.nfriends - 1
    left_inds[below] = 0

    # clamp windows hanging off the high end (mask taken from right_inds)
    above = right_inds > last_sorted
    left_inds[above] = num_sorted - self.nfriends
    right_inds[above] = last_sorted

    assert xp.all(right_inds - left_inds == self.nfriends - 1)
    assert not (
        xp.any(right_inds < 0)
        or xp.any(right_inds > last_sorted)
        or xp.any(left_inds < 0)
        or xp.any(left_inds > last_sorted)
    )

    n_iter = 0
    while xp.any(still_going):
        active = all_indices[still_going]
        f_active = all_temp_fs[still_going]
        distance_left = xp.abs(f_active - self.freqs_sorted[left_inds[still_going]])
        distance_right = xp.abs(f_active - self.freqs_sorted[right_inds[still_going]])

        # --- try sliding one step to the right ---
        move_right = distance_right <= distance_left
        cand_right = right_inds[still_going][move_right] + 1
        right_ok = cand_right < num_sorted
        # Clip only for the lookup: out-of-range candidates are masked out by
        # right_ok below (NumPy would raise on them, CuPy would wrap).
        new_distance_right = xp.abs(
            f_active[move_right] - self.freqs_sorted[xp.minimum(cand_right, last_sorted)]
        )
        active_r = active[move_right]
        dl_r = distance_left[move_right]
        dr_r = distance_right[move_right]

        shift = active_r[(new_distance_right < dl_r) & right_ok]
        left_inds[shift] += 1
        right_inds[shift] += 1

        # stop at the top edge, or when moving right does not help; exact
        # ties are resolved by the left-hand check below
        stop_right = xp.concatenate(
            [
                active_r[~right_ok],
                active_r[(new_distance_right >= dl_r) & right_ok & (dr_r != dl_r)],
            ]
        )
        assert xp.all(still_going[stop_right])

        # --- try sliding one step to the left (uses the windows updated above) ---
        move_left = distance_left <= distance_right
        cand_left = left_inds[still_going][move_left] - 1
        left_ok = cand_left >= 0
        new_distance_left = xp.abs(
            f_active[move_left] - self.freqs_sorted[xp.maximum(cand_left, 0)]
        )
        active_l = active[move_left]
        dr_l = distance_right[move_left]

        shift = active_l[(new_distance_left < dr_l) & left_ok]
        left_inds[shift] -= 1
        right_inds[shift] -= 1

        stop_left = xp.concatenate(
            [
                active_l[~left_ok],
                active_l[(new_distance_left >= dr_l) & left_ok],
            ]
        )

        still_going[xp.concatenate([stop_right, stop_left])] = False

        n_iter += 1
        # sanity guard (previously a leftover ``breakpoint()``)
        if n_iter > self.nfriends and xp.any(still_going):
            raise RuntimeError(
                f"friend window search did not settle within {self.nfriends} "
                f"iterations ({int(still_going.sum())} sources still moving)"
            )

    start_inds = left_inds.get() if hasattr(left_inds, "get") else left_inds

    start_inds_all = np.zeros_like(inds, dtype=np.int32)
    start_inds_all[inds] = np.asarray(start_inds).astype(np.int32)

    if "friend_start_inds" not in supps:
        supps.add_objects({"friend_start_inds": start_inds_all})
    else:
        supps[:] = {"friend_start_inds": start_inds_all}

    self.stretch_friends_args_in = tuple(col.copy() for col in self.all_coords_sorted.T)

    if self.use_gpu:
        self.mempool.free_all_blocks()
|
277
|
+
|
|
278
|
+
def find_friends(self, name, gb_points_to_move, s_inds=None):
    """Replace the selected points with random members of their friend windows.

    For every selected point, the window start is read from
    ``self.current_friends_start_inds``. An offset is drawn uniformly from
    ``[0, self.nfriends)`` and clipped into ``self.all_coords_sorted``; that
    row replaces the point.

    Args:
        name: not used by this method.
        gb_points_to_move: array whose total size is a multiple of 8. It is
            viewed as rows of 8 parameters and is not modified.
        s_inds: selection of the rows to replace. It is flattened, and also
            reshaped to ``(self.ntemps, self.nwalkers, -1)`` to index
            ``self.current_friends_start_inds``. Presumably a boolean mask —
            TODO confirm with callers.

    Returns:
        A modified copy of the points with shape ``(self.ntemps, -1, 1, 8)``.

    Raises:
        ValueError: if ``s_inds`` is not given.

    NOTE(review): no assignment to ``self.current_friends_start_inds`` is
    visible in this chunk — confirm it is set before this method is called.
    """
    if s_inds is None:
        raise ValueError("find_friends requires s_inds (selection of points to move).")

    inds_points_to_move = self.xp.asarray(s_inds.flatten())

    # work on a copy so the caller's array is left untouched
    gb_points_for_move = gb_points_to_move.reshape(-1, 8).copy()

    # fall back to a single temperature when none has been recorded
    if not hasattr(self, "ntemps"):
        self.ntemps = 1

    inds_start_freq_to_move = self.current_friends_start_inds[
        inds_points_to_move.reshape(self.ntemps, self.nwalkers, -1)
    ]

    # pick one member of each nfriends-wide window uniformly at random
    deviation = self.xp.random.randint(
        0, self.nfriends, size=len(inds_start_freq_to_move)
    )
    inds_keep_friends = inds_start_freq_to_move + deviation

    num_sorted = len(self.all_coords_sorted)
    inds_keep_friends[inds_keep_friends < 0] = 0
    inds_keep_friends[inds_keep_friends >= num_sorted] = num_sorted - 1

    gb_points_for_move[inds_points_to_move] = self.all_coords_sorted[inds_keep_friends]

    return gb_points_for_move.reshape(self.ntemps, -1, 1, 8)
|
|
311
|
+
|
|
312
|
+
def new_find_friends(self, name, inds_in):
    """Draw one random friend per requested leaf and return its coordinates.

    Args:
        name: not used by this method.
        inds_in: sequence of index arrays. ``tuple(inds_in)`` indexes
            ``self.current_friends_start_inds`` to obtain the window starts.

    Returns:
        Rows of ``self.all_coords_sorted``, one per window start. Each row is
        taken at ``start + offset``, where the offset is drawn uniformly from
        ``[0, self.nfriends)`` and the index is clipped into the valid range.
    """
    window_starts = self.current_friends_start_inds[tuple(inds_in)]
    offsets = self.xp.random.randint(0, self.nfriends, size=len(window_starts))

    # keep every chosen position inside the sorted reference set
    last_row = len(self.all_coords_sorted) - 1
    chosen_rows = self.xp.clip(window_starts + offsets, 0, last_row)

    return self.all_coords_sorted[chosen_rows]
|
|
331
|
+
|
|
332
|
+
def setup(self, branches):
|
|
333
|
+
for i, (name, branch) in enumerate(branches.items()):
|
|
334
|
+
if name != "gb":
|
|
335
|
+
continue
|
|
336
|
+
|
|
337
|
+
if not self.is_rj_prop and self.time % self.n_iter_update == 0:
|
|
338
|
+
self.setup_gbs(branch)
|
|
339
|
+
|
|
340
|
+
elif self.is_rj_prop and self.time == 0:
|
|
341
|
+
ndim = branch.shape[-1]
|
|
342
|
+
self.stretch_friends_args_in = tuple([xp.array([]) for _ in range(ndim)])
|
|
343
|
+
|
|
344
|
+
# update any shifted start inds due to tempering (need to do this every non-rj move)
|
|
345
|
+
"""if not self.is_rj_prop:
|
|
346
|
+
# fix the ones that have been added in RJ
|
|
347
|
+
fix = (
|
|
348
|
+
branch.branch_supplimental.holder["friend_start_inds"][:] == -1
|
|
349
|
+
) & branch.inds
|
|
350
|
+
|
|
351
|
+
if np.any(fix):
|
|
352
|
+
new_freqs = xp.asarray(branch.coords[fix][:, 1])
|
|
353
|
+
# TODO: is there a better way of doing this?
|
|
354
|
+
|
|
355
|
+
# fill information into friend finder for new binaries
|
|
356
|
+
branch.branch_supplimental.holder["friend_start_inds"][fix] = (
|
|
357
|
+
(
|
|
358
|
+
xp.searchsorted(self.freqs_sorted, new_freqs, side="right")
|
|
359
|
+
- 1
|
|
360
|
+
)
|
|
361
|
+
* (
|
|
362
|
+
(new_freqs > self.freqs_sorted[0])
|
|
363
|
+
& (new_freqs < self.freqs_sorted[-1])
|
|
364
|
+
)
|
|
365
|
+
+ 0 * (new_freqs < self.freqs_sorted[0])
|
|
366
|
+
+ (len(self.freqs_sorted) - 1)
|
|
367
|
+
* (new_freqs > self.freqs_sorted[-1])
|
|
368
|
+
).get()
|
|
369
|
+
|
|
370
|
+
# make sure current start inds reflect alive binaries
|
|
371
|
+
self.current_friends_start_inds = self.xp.asarray(
|
|
372
|
+
branch.branch_supplimental.holder["friend_start_inds"][:]
|
|
373
|
+
)
|
|
374
|
+
"""
|
|
375
|
+
|
|
376
|
+
self.mempool.free_all_blocks()
|
|
377
|
+
|
|
378
|
+
def propose(self, model, state):
|
|
379
|
+
"""Use the move to generate a proposal and compute the acceptance
|
|
380
|
+
|
|
381
|
+
Args:
|
|
382
|
+
model (:class:`eryn.model.Model`): Carrier of sampler information.
|
|
383
|
+
state (:class:`State`): Current state of the sampler.
|
|
384
|
+
|
|
385
|
+
Returns:
|
|
386
|
+
:class:`State`: State of sampler after proposal is complete.
|
|
387
|
+
|
|
388
|
+
"""
|
|
389
|
+
st = time.perf_counter()
|
|
390
|
+
|
|
391
|
+
self.xp.cuda.runtime.setDevice(self.mgh.gpus[0])
|
|
392
|
+
|
|
393
|
+
self.current_state = state
|
|
394
|
+
np.random.seed(10)
|
|
395
|
+
# print("start stretch")
|
|
396
|
+
|
|
397
|
+
# Check that the dimensions are compatible.
|
|
398
|
+
ndim_total = 0
|
|
399
|
+
for branch in state.branches.values():
|
|
400
|
+
ntemps, nwalkers, nleaves_, ndim_ = branch.shape
|
|
401
|
+
ndim_total += ndim_ * nleaves_
|
|
402
|
+
|
|
403
|
+
self.nwalkers = nwalkers
|
|
404
|
+
|
|
405
|
+
# Run any move-specific setup.
|
|
406
|
+
self.setup(state.branches)
|
|
407
|
+
|
|
408
|
+
new_state = State(state, copy=True)
|
|
409
|
+
band_temps = xp.asarray(state.band_info["band_temps"].copy())
|
|
410
|
+
|
|
411
|
+
if self.is_rj_prop:
|
|
412
|
+
orig_store = new_state.log_like[0].copy()
|
|
413
|
+
gb_coords = xp.asarray(new_state.branches["gb"].coords)
|
|
414
|
+
|
|
415
|
+
self.mempool.free_all_blocks()
|
|
416
|
+
|
|
417
|
+
# self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()
|
|
418
|
+
|
|
419
|
+
# data should not be whitened
|
|
420
|
+
|
|
421
|
+
ntemps, nwalkers, nleaves_max, ndim = state.branches_coords["gb"].shape
|
|
422
|
+
|
|
423
|
+
group_temp_finder = [
|
|
424
|
+
self.xp.repeat(self.xp.arange(ntemps), nwalkers * nleaves_max).reshape(
|
|
425
|
+
ntemps, nwalkers, nleaves_max
|
|
426
|
+
),
|
|
427
|
+
self.xp.tile(self.xp.arange(nwalkers), (ntemps, nleaves_max, 1)).transpose(
|
|
428
|
+
(0, 2, 1)
|
|
429
|
+
),
|
|
430
|
+
self.xp.tile(self.xp.arange(nleaves_max), ((ntemps, nwalkers, 1))),
|
|
431
|
+
]
|
|
432
|
+
|
|
433
|
+
"""(
|
|
434
|
+
gb_coords,
|
|
435
|
+
gb_inds_orig,
|
|
436
|
+
points_curr,
|
|
437
|
+
prior_all_curr,
|
|
438
|
+
gb_inds,
|
|
439
|
+
N_vals_in,
|
|
440
|
+
prop_to_curr_map,
|
|
441
|
+
factors,
|
|
442
|
+
inds_curr,
|
|
443
|
+
proposal_specific_information
|
|
444
|
+
) = self.get_special_proposal_setup(model, new_state, state, group_temp_finder)"""
|
|
445
|
+
|
|
446
|
+
waveform_kwargs_now = self.waveform_kwargs.copy()
|
|
447
|
+
if "N" in waveform_kwargs_now:
|
|
448
|
+
waveform_kwargs_now.pop("N")
|
|
449
|
+
waveform_kwargs_now["start_freq_ind"] = self.start_freq_ind
|
|
450
|
+
|
|
451
|
+
# if self.is_rj_prop:
|
|
452
|
+
# print("START:", new_state.log_like[0])
|
|
453
|
+
log_like_tmp = self.xp.asarray(new_state.log_like)
|
|
454
|
+
log_prior_tmp = self.xp.asarray(new_state.log_prior)
|
|
455
|
+
|
|
456
|
+
self.mempool.free_all_blocks()
|
|
457
|
+
|
|
458
|
+
gb_inds = self.xp.asarray(new_state.branches["gb"].inds)
|
|
459
|
+
gb_inds_orig = gb_inds.copy()
|
|
460
|
+
|
|
461
|
+
data = self.mgh.data_list
|
|
462
|
+
base_data = self.mgh.base_data_list
|
|
463
|
+
psd = self.mgh.psd_list
|
|
464
|
+
lisasens = self.mgh.lisasens_list
|
|
465
|
+
|
|
466
|
+
# do unique for band size as separator between asynchronous kernel launches
|
|
467
|
+
# band_indices = self.xp.asarray(new_state.branches["gb"].branch_supplimental.holder["band_inds"])
|
|
468
|
+
band_indices = self.xp.searchsorted(self.band_edges, xp.asarray(new_state.branches["gb"].coords[:, :, :, 1]).flatten() / 1e3, side="right").reshape(new_state.branches["gb"].coords[:, :, :, 1].shape) - 1
|
|
469
|
+
|
|
470
|
+
# N_vals_in = self.xp.asarray(new_state.branches["gb"].branch_supplimental.holder["N_vals"])
|
|
471
|
+
points_curr = self.xp.asarray(new_state.branches["gb"].coords)
|
|
472
|
+
points_curr_orig = points_curr.copy()
|
|
473
|
+
# N_vals_in_orig = N_vals_in.copy()
|
|
474
|
+
band_indices_orig = band_indices.copy()
|
|
475
|
+
|
|
476
|
+
if self.is_rj_prop:
|
|
477
|
+
if isinstance(self.rj_proposal_distribution["gb"], list):
|
|
478
|
+
raise NotImplementedError
|
|
479
|
+
assert len(self.rj_proposal_distribution["gb"]) == ntemps
|
|
480
|
+
proposal_logpdf = xp.zeros((points_curr.shape[0], np.prod(points_curr.shape[1:-1])))
|
|
481
|
+
for t in range(ntemps):
|
|
482
|
+
new_sources = xp.full_like(points_curr[t, ~gb_inds[t]], np.nan)
|
|
483
|
+
fix = xp.full(new_sources.shape[0], True)
|
|
484
|
+
while xp.any(fix):
|
|
485
|
+
new_sources[fix] = self.rj_proposal_distribution["gb"][t].rvs(size=fix.sum().item())
|
|
486
|
+
fix = xp.any(xp.isnan(new_sources), axis=-1)
|
|
487
|
+
points_curr[t, ~gb_inds[t]] = new_sources
|
|
488
|
+
band_indices[t, ~gb_inds[t]] = self.xp.searchsorted(self.band_edges, new_sources[:, 1] / 1e3, side="right") - 1
|
|
489
|
+
# change gb_inds
|
|
490
|
+
gb_inds[t, :] = True
|
|
491
|
+
assert np.all(gb_inds[t])
|
|
492
|
+
proposal_logpdf[t] = self.rj_proposal_distribution["gb"][t].logpdf(
|
|
493
|
+
points_curr[t, gb_inds[t]]
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
proposal_logpdf = proposal_logpdf.flatten().copy()
|
|
497
|
+
|
|
498
|
+
else:
|
|
499
|
+
new_sources = xp.full_like(points_curr[~gb_inds], np.nan)
|
|
500
|
+
fix = xp.full(new_sources.shape[0], True)
|
|
501
|
+
while xp.any(fix):
|
|
502
|
+
if self.name == "rj_prior":
|
|
503
|
+
new_sources[fix] = self.rj_proposal_distribution["gb"].rvs(size=fix.sum().item(), psds=self.mgh.psd_shaped[0][0], walker_inds=group_temp_finder[1][~gb_inds][fix])
|
|
504
|
+
else:
|
|
505
|
+
new_sources[fix] = self.rj_proposal_distribution["gb"].rvs(size=fix.sum().item())
|
|
506
|
+
|
|
507
|
+
fix = xp.any(xp.isnan(new_sources), axis=-1)
|
|
508
|
+
points_curr[~gb_inds] = new_sources
|
|
509
|
+
|
|
510
|
+
band_indices[~gb_inds] = self.xp.searchsorted(self.band_edges, new_sources[:, 1] / 1e3, side="right") - 1
|
|
511
|
+
# N_vals_in[~gb_inds] = self.xp.asarray(
|
|
512
|
+
# get_N(
|
|
513
|
+
# xp.full_like(new_sources[:, 1], 1e-30),
|
|
514
|
+
# new_sources[:, 1] / 1e3,
|
|
515
|
+
# self.waveform_kwargs["T"],
|
|
516
|
+
# self.waveform_kwargs["oversample"],
|
|
517
|
+
# xp=self.xp
|
|
518
|
+
# )
|
|
519
|
+
# )
|
|
520
|
+
|
|
521
|
+
# change gb_inds
|
|
522
|
+
gb_inds[:] = True
|
|
523
|
+
assert np.all(gb_inds)
|
|
524
|
+
if self.name == "rj_prior":
|
|
525
|
+
proposal_logpdf = self.rj_proposal_distribution["gb"].logpdf(
|
|
526
|
+
points_curr[gb_inds], psds=self.mgh.psd_shaped[0][0], walker_inds=group_temp_finder[1][gb_inds]
|
|
527
|
+
)
|
|
528
|
+
else:
|
|
529
|
+
proposal_logpdf = self.rj_proposal_distribution["gb"].logpdf(
|
|
530
|
+
points_curr[gb_inds]
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
factors = (proposal_logpdf * -1) * (~gb_inds_orig).flatten() + (proposal_logpdf * +1) * (gb_inds_orig).flatten()
|
|
534
|
+
|
|
535
|
+
if self.name == "rj_prior" and self.use_prior_removal:
|
|
536
|
+
factors[~gb_inds_orig.flatten()] = -1e300
|
|
537
|
+
else:
|
|
538
|
+
factors = xp.zeros(gb_inds_orig.sum().item())
|
|
539
|
+
|
|
540
|
+
start_inds_all = xp.asarray(new_state.branches["gb"].branch_supplimental.holder["friend_start_inds"], dtype=xp.int32)[gb_inds]
|
|
541
|
+
points_curr = points_curr[gb_inds]
|
|
542
|
+
# N_vals_in = N_vals_in[gb_inds]
|
|
543
|
+
band_indices = band_indices[gb_inds]
|
|
544
|
+
gb_inds_in = gb_inds_orig[gb_inds]
|
|
545
|
+
|
|
546
|
+
temp_indices = group_temp_finder[0][gb_inds]
|
|
547
|
+
walker_indices = group_temp_finder[1][gb_inds]
|
|
548
|
+
leaf_indices = group_temp_finder[2][gb_inds]
|
|
549
|
+
|
|
550
|
+
unique_N = self.xp.unique(self.band_N_vals)
|
|
551
|
+
# remove 0
|
|
552
|
+
unique_N = unique_N[unique_N != 0]
|
|
553
|
+
|
|
554
|
+
do_synchronize = False
|
|
555
|
+
device = self.xp.cuda.runtime.getDevice()
|
|
556
|
+
|
|
557
|
+
units = 2 if not self.is_rj_prop else 2
|
|
558
|
+
# random start to rotation around
|
|
559
|
+
start_unit = model.random.randint(units)
|
|
560
|
+
|
|
561
|
+
ll_after = (
|
|
562
|
+
self.mgh.get_ll(include_psd_info=True)
|
|
563
|
+
)
|
|
564
|
+
# print(np.abs(new_state.log_like - ll_after).max())
|
|
565
|
+
store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
|
|
566
|
+
# print("CHECKING 0:", store_max_diff, self.is_rj_prop)
|
|
567
|
+
self.check_ll_inject(new_state)
|
|
568
|
+
|
|
569
|
+
per_walker_band_proposals = xp.zeros((ntemps, nwalkers, self.num_bands), dtype=int)
|
|
570
|
+
per_walker_band_accepted = xp.zeros((ntemps, nwalkers, self.num_bands), dtype=int)
|
|
571
|
+
|
|
572
|
+
total_keep = 0
|
|
573
|
+
for tmp in range(units):
|
|
574
|
+
remainder = (start_unit + tmp) % units
|
|
575
|
+
all_inputs = []
|
|
576
|
+
band_bookkeep_info = []
|
|
577
|
+
indiv_info = []
|
|
578
|
+
current_parameter_arrays = []
|
|
579
|
+
N_vals_list = []
|
|
580
|
+
output_info = []
|
|
581
|
+
inds_list = []
|
|
582
|
+
inds_orig_list = []
|
|
583
|
+
bands_list = []
|
|
584
|
+
for N_now in unique_N:
|
|
585
|
+
N_now = N_now.item()
|
|
586
|
+
if N_now == 0: # or N_now != 1024:
|
|
587
|
+
continue
|
|
588
|
+
|
|
589
|
+
# old_data_1 = self.mgh.data_shaped[0][0][11].copy()
|
|
590
|
+
# checkit = np.where(new_state.supplimental[:]["overall_inds"] == 0)
|
|
591
|
+
|
|
592
|
+
# TODO; check the maximum allowable band
|
|
593
|
+
keep = (
|
|
594
|
+
(band_indices % units == remainder)
|
|
595
|
+
# & (temp_indices == 0) # & (walker_indices == 2) # & (band_indices < 50)
|
|
596
|
+
& (self.band_N_vals[band_indices] == N_now)
|
|
597
|
+
& (band_indices < len(self.band_edges) - 2)
|
|
598
|
+
# & (band_indices == 1)
|
|
599
|
+
) # & (band_indices == 501) # # & (N_vals_in <= 256) & (temp_inds == checkit[0].item()) & (walker_inds == checkit[1].item()) # & (band_indices < 540) # & (temp_inds == 0) & (walker_inds == 0)
|
|
600
|
+
|
|
601
|
+
# if self.time > 0:
|
|
602
|
+
# keep[np.where(keep)[0][1:]] = False
|
|
603
|
+
"""if self.time > 0:
|
|
604
|
+
keep = (
|
|
605
|
+
(band_indices % units == remainder)
|
|
606
|
+
& (band_indices == 346) & (walker_indices == 9) & (temp_indices == 0)
|
|
607
|
+
& (self.band_N_vals[band_indices] == N_now)
|
|
608
|
+
& (band_indices < len(self.band_edges) - 2)
|
|
609
|
+
) # & (band_indices == 501) # # & (N_vals_in <= 256) & (temp_inds == checkit[0].item()) & (walker_inds == checkit[1].item()) # & (band_indices < 540) # & (temp_inds == 0) & (walker_inds == 0)
|
|
610
|
+
if np.any(keep):
|
|
611
|
+
breakpoint()"""
|
|
612
|
+
|
|
613
|
+
if keep.sum().item() == 0:
|
|
614
|
+
continue
|
|
615
|
+
|
|
616
|
+
total_keep += keep.sum().item()
|
|
617
|
+
# for testing
|
|
618
|
+
# keep[:] = False
|
|
619
|
+
# keep[tmp_keep] = True
|
|
620
|
+
# keep[3000:3020:1] = True
|
|
621
|
+
|
|
622
|
+
# permutate for rj just in case
|
|
623
|
+
permute_inds = xp.random.permutation(xp.arange(keep.sum()))
|
|
624
|
+
params_curr = points_curr[keep][permute_inds]
|
|
625
|
+
start_inds_here = start_inds_all[keep][permute_inds]
|
|
626
|
+
inds_here = gb_inds_in[keep][permute_inds]
|
|
627
|
+
factors_here = factors[keep][permute_inds]
|
|
628
|
+
|
|
629
|
+
prior_all_curr_here = self.gpu_priors["gb"].logpdf(params_curr, psds=self.mgh.psd_shaped[0][0], walker_inds=group_temp_finder[1][gb_inds][keep][permute_inds])
|
|
630
|
+
if not self.is_rj_prop:
|
|
631
|
+
assert xp.all(~xp.isinf(prior_all_curr_here))
|
|
632
|
+
|
|
633
|
+
temp_inds_here = temp_indices[keep][permute_inds]
|
|
634
|
+
walker_inds_here = walker_indices[keep][permute_inds]
|
|
635
|
+
leaf_inds_here = leaf_indices[keep][permute_inds]
|
|
636
|
+
band_inds_here = band_indices[keep][permute_inds]
|
|
637
|
+
|
|
638
|
+
special_band_inds = (temp_inds_here * nwalkers + walker_inds_here) * int(1e6) + band_inds_here
|
|
639
|
+
|
|
640
|
+
sort_special = xp.argsort(special_band_inds)
|
|
641
|
+
|
|
642
|
+
params_curr = params_curr[sort_special]
|
|
643
|
+
inds_here = inds_here[sort_special]
|
|
644
|
+
factors_here = factors_here[sort_special]
|
|
645
|
+
start_inds_here = start_inds_here[sort_special]
|
|
646
|
+
prior_all_curr_here = prior_all_curr_here[sort_special]
|
|
647
|
+
temp_inds_here = temp_inds_here[sort_special]
|
|
648
|
+
walker_inds_here = walker_inds_here[sort_special]
|
|
649
|
+
leaf_inds_here = leaf_inds_here[sort_special]
|
|
650
|
+
band_inds_here = band_inds_here[sort_special]
|
|
651
|
+
special_band_inds = special_band_inds[sort_special]
|
|
652
|
+
|
|
653
|
+
# assert np.all(np.sort(special_band_inds_sorted_here) == special_band_inds_sorted_here)
|
|
654
|
+
|
|
655
|
+
(
|
|
656
|
+
uni_special_band_inds_here,
|
|
657
|
+
uni_index_special_band_inds_here,
|
|
658
|
+
uni_count_special_band_inds_here,
|
|
659
|
+
) = self.xp.unique(
|
|
660
|
+
special_band_inds, return_index=True, return_counts=True
|
|
661
|
+
)
|
|
662
|
+
|
|
663
|
+
band_start_bin_ind_here = uni_index_special_band_inds_here.astype(
|
|
664
|
+
np.int32
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
band_num_bins_here = uni_count_special_band_inds_here.astype(np.int32)
|
|
668
|
+
|
|
669
|
+
band_inds = band_inds_here[uni_index_special_band_inds_here]
|
|
670
|
+
band_temps_inds = temp_inds_here[uni_index_special_band_inds_here]
|
|
671
|
+
band_walkers_inds = walker_inds_here[uni_index_special_band_inds_here]
|
|
672
|
+
|
|
673
|
+
band_inv_temp_vals_here = band_temps[band_inds, band_temps_inds]
|
|
674
|
+
|
|
675
|
+
indiv_info.append((temp_inds_here, walker_inds_here, leaf_inds_here))
|
|
676
|
+
|
|
677
|
+
band_bookkeep_info.append(
|
|
678
|
+
(band_temps_inds, band_walkers_inds, band_inds)
|
|
679
|
+
)
|
|
680
|
+
|
|
681
|
+
data_index_here = (((band_inds + 1) % 2) * nwalkers + band_walkers_inds).astype(np.int32)
|
|
682
|
+
noise_index_here = (band_walkers_inds).astype(np.int32)
|
|
683
|
+
|
|
684
|
+
# for updates
|
|
685
|
+
update_data_index_here = (((band_inds + 0) % 2) * nwalkers + band_walkers_inds).astype(np.int32)
|
|
686
|
+
|
|
687
|
+
L_contribution_here = xp.zeros_like(band_inds, dtype=complex)
|
|
688
|
+
p_contribution_here = xp.zeros_like(band_inds, dtype=complex)
|
|
689
|
+
|
|
690
|
+
buffer = 5 # bins
|
|
691
|
+
|
|
692
|
+
start_band_index = ((self.band_edges[band_inds] / self.df).astype(int) - N_now - 1).astype(np.int32)
|
|
693
|
+
end_band_index = (((self.band_edges[band_inds + 1]) / self.df).astype(int) + N_now + 1).astype(np.int32)
|
|
694
|
+
|
|
695
|
+
start_band_index[start_band_index < 0] = 0
|
|
696
|
+
end_band_index[end_band_index >= len(self.fd)] = len(self.fd) - 1
|
|
697
|
+
|
|
698
|
+
band_lengths = end_band_index - start_band_index
|
|
699
|
+
|
|
700
|
+
start_interest_band_index = ((self.band_edges[band_inds] / self.df).astype(int)).astype(np.int32) # - N_now).astype(np.int32)
|
|
701
|
+
end_interest_band_index = (((self.band_edges[band_inds + 1]) / self.df).astype(int)).astype(np.int32) # + N_now).astype(np.int32)
|
|
702
|
+
|
|
703
|
+
start_interest_band_index[start_interest_band_index < 0] = 0
|
|
704
|
+
end_interest_band_index[end_interest_band_index >= len(self.fd)] = len(self.fd) - 1
|
|
705
|
+
|
|
706
|
+
band_interest_lengths = end_interest_band_index - start_interest_band_index
|
|
707
|
+
|
|
708
|
+
if False: # self.is_rj_prop:
|
|
709
|
+
fmin_allow = ((self.band_edges[band_inds] / self.df).astype(int) + 1) * self.df
|
|
710
|
+
fmax_allow = ((self.band_edges[band_inds + 1] / self.df).astype(int) - 1) * self.df
|
|
711
|
+
|
|
712
|
+
else:
|
|
713
|
+
# proposal limits in a band
|
|
714
|
+
fmin_allow = ((self.band_edges[band_inds] / self.df).astype(int)- (N_now / 2)) * self.df
|
|
715
|
+
fmax_allow = ((self.band_edges[band_inds + 1] / self.df).astype(int) + (N_now / 2 )) * self.df
|
|
716
|
+
|
|
717
|
+
# if self.is_rj_prop:
|
|
718
|
+
# breakpoint()
|
|
719
|
+
fmin_allow[fmin_allow < self.band_edges[0]] = self.band_edges[0]
|
|
720
|
+
fmax_allow[fmax_allow > self.band_edges[-1]] = self.band_edges[-1]
|
|
721
|
+
|
|
722
|
+
max_data_store_size = band_lengths.max().item()
|
|
723
|
+
|
|
724
|
+
num_bands_here = len(band_inds)
|
|
725
|
+
|
|
726
|
+
# makes in-model effectively not tempered
|
|
727
|
+
# if not self.is_rj_prop:
|
|
728
|
+
# band_inv_temp_vals_here[:] = 1.0
|
|
729
|
+
|
|
730
|
+
params_curr_separated = tuple([params_curr[:, i].copy() for i in range(params_curr.shape[1])])
|
|
731
|
+
params_curr_separated_orig = tuple([params_curr[:, i].copy() for i in range(params_curr.shape[1])])
|
|
732
|
+
params_extra_params = (
|
|
733
|
+
self.waveform_kwargs["T"],
|
|
734
|
+
self.waveform_kwargs["dt"],
|
|
735
|
+
N_now,
|
|
736
|
+
params_curr.shape[0],
|
|
737
|
+
self.start_freq_ind,
|
|
738
|
+
Soms_d_all["sangria"] ** (1/2),
|
|
739
|
+
Sa_a_all["sangria"] ** (1/2),
|
|
740
|
+
1e-100, 0.0, 0.0, 0.0, 0.0, # foreground params for snr -> amp transform
|
|
741
|
+
)
|
|
742
|
+
|
|
743
|
+
gb_params_curr_in = self.gb.pyGalacticBinaryParams(
|
|
744
|
+
*(params_curr_separated + params_curr_separated_orig + params_extra_params)
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
current_parameter_arrays.append(params_curr_separated)
|
|
748
|
+
|
|
749
|
+
data_package = self.gb.pyDataPackage(
|
|
750
|
+
data[0][0],
|
|
751
|
+
data[1][0],
|
|
752
|
+
base_data[0][0],
|
|
753
|
+
base_data[1][0],
|
|
754
|
+
psd[0][0],
|
|
755
|
+
psd[1][0],
|
|
756
|
+
lisasens[0][0],
|
|
757
|
+
lisasens[1][0],
|
|
758
|
+
self.df,
|
|
759
|
+
self.data_length,
|
|
760
|
+
self.nwalkers * self.ntemps,
|
|
761
|
+
self.nwalkers * self.ntemps
|
|
762
|
+
)
|
|
763
|
+
|
|
764
|
+
loc_band_index = xp.arange(num_bands_here, dtype=xp.int32)
|
|
765
|
+
band_package = self.gb.pyBandPackage(
|
|
766
|
+
loc_band_index,
|
|
767
|
+
data_index_here,
|
|
768
|
+
noise_index_here,
|
|
769
|
+
band_start_bin_ind_here, # uni_index
|
|
770
|
+
band_num_bins_here, # uni_count
|
|
771
|
+
start_band_index,
|
|
772
|
+
band_lengths,
|
|
773
|
+
start_interest_band_index,
|
|
774
|
+
band_interest_lengths,
|
|
775
|
+
num_bands_here,
|
|
776
|
+
max_data_store_size,
|
|
777
|
+
fmin_allow,
|
|
778
|
+
fmax_allow,
|
|
779
|
+
update_data_index_here,
|
|
780
|
+
self.ntemps,
|
|
781
|
+
band_inds.astype(np.int32),
|
|
782
|
+
band_walkers_inds.astype(np.int32),
|
|
783
|
+
band_temps_inds.astype(np.int32),
|
|
784
|
+
xp.zeros_like(band_temps_inds, dtype=np.int32), # swaps propsoed
|
|
785
|
+
xp.zeros_like(band_temps_inds, dtype=np.int32) # swaps accepted
|
|
786
|
+
)
|
|
787
|
+
|
|
788
|
+
accepted_out_here = xp.zeros_like(band_start_bin_ind_here, dtype=xp.int32)
|
|
789
|
+
|
|
790
|
+
mcmc_info = self.gb.pyMCMCInfo(
|
|
791
|
+
L_contribution_here,
|
|
792
|
+
p_contribution_here,
|
|
793
|
+
prior_all_curr_here,
|
|
794
|
+
accepted_out_here,
|
|
795
|
+
band_inv_temp_vals_here, # band_inv_temp_vals
|
|
796
|
+
self.is_rj_prop,
|
|
797
|
+
self.phase_maximize,
|
|
798
|
+
self.snr_lim,
|
|
799
|
+
)
|
|
800
|
+
|
|
801
|
+
if not self.is_rj_prop:
|
|
802
|
+
num_proposals_here = self.num_repeat_proposals
|
|
803
|
+
|
|
804
|
+
else:
|
|
805
|
+
num_proposals_here = 1
|
|
806
|
+
|
|
807
|
+
num_proposals_per_band = band_num_bins_here * num_proposals_here
|
|
808
|
+
|
|
809
|
+
assert start_inds_here.dtype == np.int32
|
|
810
|
+
|
|
811
|
+
proposal_info = self.gb.pyStretchProposalPackage(
|
|
812
|
+
*(self.stretch_friends_args_in + (self.nfriends, len(self.stretch_friends_args_in[0]), num_proposals_here, self.a, ndim, inds_here, factors_here, start_inds_here))
|
|
813
|
+
)
|
|
814
|
+
|
|
815
|
+
output_info.append([L_contribution_here, p_contribution_here, accepted_out_here, num_proposals_per_band])
|
|
816
|
+
|
|
817
|
+
inputs_now = (
|
|
818
|
+
data_package,
|
|
819
|
+
band_package,
|
|
820
|
+
gb_params_curr_in,
|
|
821
|
+
mcmc_info,
|
|
822
|
+
self.gpu_cuda_priors,
|
|
823
|
+
proposal_info,
|
|
824
|
+
self.gpu_cuda_wrap,
|
|
825
|
+
device,
|
|
826
|
+
do_synchronize,
|
|
827
|
+
)
|
|
828
|
+
|
|
829
|
+
# if self.is_rj_prop:
|
|
830
|
+
# prior_check = xp.zeros(params_curr.shape[0])
|
|
831
|
+
# self.gb.check_prior_vals(prior_check, self.gpu_cuda_priors, gb_params_curr_in, 100)
|
|
832
|
+
# prior_check2 = self.gpu_priors["gb"].logpdf(params_curr)
|
|
833
|
+
# breakpoint()
|
|
834
|
+
|
|
835
|
+
|
|
836
|
+
N_vals_list.append(N_now)
|
|
837
|
+
inds_list.append(inds_here)
|
|
838
|
+
inds_orig_list.append(inds_here.copy())
|
|
839
|
+
bands_list.append(band_inds_here)
|
|
840
|
+
all_inputs.append(inputs_now)
|
|
841
|
+
st = time.perf_counter()
|
|
842
|
+
# print(params_curr_separated[0].shape[0])
|
|
843
|
+
# if self.is_rj_prop:
|
|
844
|
+
# if self.time > 0:
|
|
845
|
+
# breakpoint()
|
|
846
|
+
|
|
847
|
+
# tmp_check = self.mgh.channel1_data[0][11 * self.data_length + 3911].real + self.mgh.channel1_data[0][29 * self.data_length + 3911].real - self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real
|
|
848
|
+
# print(f"BEFORE {tmp_check}, {self.mgh.channel1_data[0][11 * self.data_length + 3911].real} , {self.mgh.channel1_data[0][29 * self.data_length + 3911].real} , {self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real} ")
|
|
849
|
+
# print(f"BEFORE2 {params_curr_separated[1].min().item()} ")
|
|
850
|
+
self.gb.SharedMemoryMakeNewMove_wrap(*inputs_now)
|
|
851
|
+
self.xp.cuda.runtime.deviceSynchronize()
|
|
852
|
+
# tmp_check = self.mgh.channel1_data[0][11 * self.data_length + 3911].real + self.mgh.channel1_data[0][29 * self.data_length + 3911].real - self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real
|
|
853
|
+
# print(f"After {tmp_check}, {self.mgh.channel1_data[0][11 * self.data_length + 3911].real} , {self.mgh.channel1_data[0][29 * self.data_length + 3911].real} , {self.mgh.channel1_base_data[0][11 * self.data_length + 3911].real} ")
|
|
854
|
+
# print(f"After2 {params_curr_separated[1].min().item()} ")
|
|
855
|
+
|
|
856
|
+
et = time.perf_counter()
|
|
857
|
+
# print(et - st, N_now)
|
|
858
|
+
# breakpoint()
|
|
859
|
+
|
|
860
|
+
self.xp.cuda.runtime.deviceSynchronize()
|
|
861
|
+
new_point_info = []
|
|
862
|
+
ll_diff_info = []
|
|
863
|
+
for (
|
|
864
|
+
band_info,
|
|
865
|
+
indiv_info_now,
|
|
866
|
+
N_now,
|
|
867
|
+
outputs,
|
|
868
|
+
current_parameters,
|
|
869
|
+
inds,
|
|
870
|
+
inds_prev,
|
|
871
|
+
bands
|
|
872
|
+
) in zip(
|
|
873
|
+
band_bookkeep_info, indiv_info, N_vals_list, output_info, current_parameter_arrays, inds_list, inds_orig_list, bands_list
|
|
874
|
+
):
|
|
875
|
+
ll_contrib_now = outputs[0]
|
|
876
|
+
lp_contrib_now = outputs[1]
|
|
877
|
+
accepted_now = outputs[2]
|
|
878
|
+
num_proposals_now = outputs[3]
|
|
879
|
+
|
|
880
|
+
per_walker_band_proposals[band_info] += num_proposals_now
|
|
881
|
+
per_walker_band_accepted[band_info] += accepted_now
|
|
882
|
+
|
|
883
|
+
# remove accepted
|
|
884
|
+
# print(accepted_now.sum(0) / accepted_now.shape[0])
|
|
885
|
+
temp_tmp, walker_tmp, leaf_tmp = (
|
|
886
|
+
indiv_info_now[0],
|
|
887
|
+
indiv_info_now[1],
|
|
888
|
+
indiv_info_now[2],
|
|
889
|
+
)
|
|
890
|
+
|
|
891
|
+
# updates related to newly added sources
|
|
892
|
+
if self.is_rj_prop:
|
|
893
|
+
gb_inds_orig_check = gb_inds_orig.copy()
|
|
894
|
+
gb_coords[(temp_tmp[inds], walker_tmp[inds], leaf_tmp[inds])] = xp.asarray(current_parameters).T[inds]
|
|
895
|
+
|
|
896
|
+
gb_inds_orig[(temp_tmp, walker_tmp, leaf_tmp)] = inds
|
|
897
|
+
new_state.branches_supplimental["gb"].holder["N_vals"][
|
|
898
|
+
(temp_tmp[inds].get(), walker_tmp[inds].get(), leaf_tmp[inds].get())
|
|
899
|
+
] = N_now
|
|
900
|
+
new_state.branches_supplimental["gb"].holder["band_inds"][
|
|
901
|
+
(temp_tmp[inds].get(), walker_tmp[inds].get(), leaf_tmp[inds].get())
|
|
902
|
+
] = bands[inds].get()
|
|
903
|
+
|
|
904
|
+
else:
|
|
905
|
+
gb_coords[(temp_tmp, walker_tmp, leaf_tmp)] = xp.asarray(current_parameters).T
|
|
906
|
+
|
|
907
|
+
new_band_inds = self.xp.searchsorted(self.band_edges, xp.asarray(current_parameters).T[:, 1] / 1e3, side="right") - 1
|
|
908
|
+
|
|
909
|
+
# if np.any(new_band_inds != bands):
|
|
910
|
+
# print(new_band_inds, bands)
|
|
911
|
+
ll_change = self.xp.zeros((ntemps, nwalkers, len(self.band_edges)))
|
|
912
|
+
lp_change = self.xp.zeros((ntemps, nwalkers, len(self.band_edges)))
|
|
913
|
+
|
|
914
|
+
self.xp.cuda.runtime.deviceSynchronize()
|
|
915
|
+
|
|
916
|
+
ll_change[band_info] = ll_contrib_now
|
|
917
|
+
|
|
918
|
+
ll_diff_info.append(ll_change.copy())
|
|
919
|
+
|
|
920
|
+
ll_adjustment = ll_change.sum(axis=-1)
|
|
921
|
+
log_like_tmp += ll_adjustment
|
|
922
|
+
|
|
923
|
+
self.xp.cuda.runtime.deviceSynchronize()
|
|
924
|
+
|
|
925
|
+
lp_change[band_info] = lp_contrib_now
|
|
926
|
+
|
|
927
|
+
lp_adjustment = lp_change.sum(axis=-1)
|
|
928
|
+
log_prior_tmp += lp_adjustment
|
|
929
|
+
|
|
930
|
+
self.xp.cuda.runtime.deviceSynchronize()
|
|
931
|
+
if True: # self.time > 0:
|
|
932
|
+
ll_after = (
|
|
933
|
+
self.mgh.get_ll(include_psd_info=True)
|
|
934
|
+
)
|
|
935
|
+
# print(np.abs(new_state.log_like - ll_after).max())
|
|
936
|
+
store_max_diff = np.abs(log_like_tmp[0].get() - ll_after).max()
|
|
937
|
+
# print("CHECKING in:", tmp, store_max_diff)
|
|
938
|
+
if store_max_diff > 3e-4:
|
|
939
|
+
print("LARGER ERROR:", store_max_diff)
|
|
940
|
+
self.check_ll_inject(new_state)
|
|
941
|
+
# self.mgh.get_ll(include_psd_info=True, stop=True)
|
|
942
|
+
|
|
943
|
+
new_state.branches["gb"].coords[:] = gb_coords.get()
|
|
944
|
+
if self.is_rj_prop:
|
|
945
|
+
new_state.branches["gb"].inds[:] = gb_inds_orig.get()
|
|
946
|
+
new_state.log_like[:] = log_like_tmp.get()
|
|
947
|
+
new_state.log_prior[:] = log_prior_tmp.get()
|
|
948
|
+
|
|
949
|
+
# get updated bands inds ONLY FOR COLD CHAIN and
|
|
950
|
+
# propogate changes to higher temperatures
|
|
951
|
+
new_freqs = gb_coords[gb_inds_orig, 1]
|
|
952
|
+
new_band_inds = (xp.searchsorted(self.band_edges, new_freqs / 1e3, side="right") - 1)
|
|
953
|
+
new_state.branches["gb"].branch_supplimental.holder["band_inds"][gb_inds_orig.get()] = new_band_inds.get()
|
|
954
|
+
|
|
955
|
+
ll_after = (
|
|
956
|
+
self.mgh.get_ll(include_psd_info=True)
|
|
957
|
+
)
|
|
958
|
+
# print(np.abs(new_state.log_like - ll_after).max())
|
|
959
|
+
store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
|
|
960
|
+
# print("CHECKING 1:", store_max_diff, self.is_rj_prop)
|
|
961
|
+
# if self.time > 0:
|
|
962
|
+
# breakpoint()
|
|
963
|
+
# self.check_ll_inject(new_state)
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
if not self.is_rj_prop:
|
|
967
|
+
# if self.time > 0:
|
|
968
|
+
# breakpoint()
|
|
969
|
+
# check2 = self.mgh.get_ll()
|
|
970
|
+
old_band_inds_cold_chain = state.branches["gb"].branch_supplimental.holder["band_inds"][0] * state.branches["gb"].inds[0]
|
|
971
|
+
new_band_inds_cold_chain = new_state.branches["gb"].branch_supplimental.holder["band_inds"][0] * state.branches["gb"].inds[0]
|
|
972
|
+
inds_band_change_cold_chain = np.where(new_band_inds_cold_chain != old_band_inds_cold_chain)
|
|
973
|
+
# when adjusting temperatures, be careful here
|
|
974
|
+
if len(inds_band_change_cold_chain[0]) > 0:
|
|
975
|
+
check2 = self.mgh.get_ll(include_psd_info=True)
|
|
976
|
+
# print("SWITCH", len(inds_band_change_cold_chain[0]))
|
|
977
|
+
walker_inds_change_cold_chain = np.tile(inds_band_change_cold_chain[0], (self.ntemps - 1, 1)).flatten()
|
|
978
|
+
old_leaf_inds_change_cold_chain = np.tile(inds_band_change_cold_chain[1], (self.ntemps - 1, 1)).flatten()
|
|
979
|
+
new_temp_inds_change_cold_chain = np.repeat(np.arange(1, self.ntemps), len(inds_band_change_cold_chain[0]))
|
|
980
|
+
|
|
981
|
+
special_check = new_temp_inds_change_cold_chain * self.nwalkers + walker_inds_change_cold_chain
|
|
982
|
+
|
|
983
|
+
uni_special_check, uni_special_check_count = np.unique(special_check, return_counts=True)
|
|
984
|
+
|
|
985
|
+
# get new leaf positions
|
|
986
|
+
temp_leaves = np.ones_like(group_temp_finder[2].reshape(self.ntemps * self.nwalkers, -1).get()[uni_special_check], dtype=int) * (~new_state.branches["gb"].inds.reshape(self.ntemps * self.nwalkers, -1)[uni_special_check])
|
|
987
|
+
temp_leaves_2 = np.cumsum(temp_leaves, axis=-1)
|
|
988
|
+
temp_leaves_2[new_state.branches["gb"].inds.reshape(self.ntemps * self.nwalkers, -1)[uni_special_check]] = -1
|
|
989
|
+
|
|
990
|
+
leaf_guide_here = np.tile(np.arange(nleaves_max), (len(uni_special_check), 1))
|
|
991
|
+
new_leaf_inds_change_cold_chain = leaf_guide_here[((temp_leaves_2 >= 0) & (temp_leaves_2 <= uni_special_check_count[:, None]))]
|
|
992
|
+
try:
|
|
993
|
+
assert np.all(~new_state.branches["gb"].inds[new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain])
|
|
994
|
+
except IndexError:
|
|
995
|
+
breakpoint()
|
|
996
|
+
|
|
997
|
+
new_state.branches["gb"].inds[new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = True
|
|
998
|
+
|
|
999
|
+
new_state.branches["gb"].coords[new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = new_state.branches["gb"].coords[np.zeros_like(walker_inds_change_cold_chain), walker_inds_change_cold_chain, old_leaf_inds_change_cold_chain]
|
|
1000
|
+
new_state.branches["gb"].branch_supplimental.holder["band_inds"][new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = new_state.branches["gb"].branch_supplimental.holder["band_inds"][np.zeros_like(walker_inds_change_cold_chain), walker_inds_change_cold_chain, old_leaf_inds_change_cold_chain]
|
|
1001
|
+
new_state.branches["gb"].branch_supplimental.holder["N_vals"][new_temp_inds_change_cold_chain, walker_inds_change_cold_chain, new_leaf_inds_change_cold_chain] = new_state.branches["gb"].branch_supplimental.holder["N_vals"][np.zeros_like(walker_inds_change_cold_chain), walker_inds_change_cold_chain, old_leaf_inds_change_cold_chain]
|
|
1002
|
+
|
|
1003
|
+
# adjust data
|
|
1004
|
+
adjust_binaries = xp.asarray(new_state.branches["gb"].coords[0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]])
|
|
1005
|
+
adjust_binaries_in = self.parameter_transforms.both_transforms(
|
|
1006
|
+
adjust_binaries, xp=xp
|
|
1007
|
+
)
|
|
1008
|
+
adjust_walker_inds = inds_band_change_cold_chain[0]
|
|
1009
|
+
adjust_band_new = new_state.branches["gb"].branch_supplimental.holder["band_inds"][0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]]
|
|
1010
|
+
adjust_band_old = state.branches["gb"].branch_supplimental.holder["band_inds"][0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]]
|
|
1011
|
+
# N_vals_in = new_state.branches["gb"].branch_supplimental.holder["N_vals"][0, inds_band_change_cold_chain[0], inds_band_change_cold_chain[1]]
|
|
1012
|
+
|
|
1013
|
+
adjust_binaries_in_in = xp.concatenate([
|
|
1014
|
+
adjust_binaries_in,
|
|
1015
|
+
adjust_binaries_in
|
|
1016
|
+
], axis=0)
|
|
1017
|
+
|
|
1018
|
+
data_index = xp.concatenate([
|
|
1019
|
+
xp.asarray(((adjust_band_old + 0) % 2) * nwalkers + adjust_walker_inds),
|
|
1020
|
+
xp.asarray(((adjust_band_new + 0) % 2) * nwalkers + adjust_walker_inds)
|
|
1021
|
+
]).astype(xp.int32)
|
|
1022
|
+
|
|
1023
|
+
N_vals_in_in = xp.concatenate([
|
|
1024
|
+
xp.asarray(self.band_N_vals[adjust_band_old]),
|
|
1025
|
+
xp.asarray(self.band_N_vals[adjust_band_new])
|
|
1026
|
+
])
|
|
1027
|
+
|
|
1028
|
+
factors = xp.concatenate([
|
|
1029
|
+
+xp.ones_like(adjust_band_old, dtype=xp.float64), # remove
|
|
1030
|
+
-xp.ones_like(adjust_band_old, dtype=xp.float64) # add
|
|
1031
|
+
])
|
|
1032
|
+
|
|
1033
|
+
self.gb.generate_global_template(
|
|
1034
|
+
adjust_binaries_in_in,
|
|
1035
|
+
data_index,
|
|
1036
|
+
self.mgh.data_list,
|
|
1037
|
+
N=N_vals_in_in,
|
|
1038
|
+
factors=factors,
|
|
1039
|
+
data_length=self.mgh.data_length,
|
|
1040
|
+
data_splits=self.mgh.gpu_splits,
|
|
1041
|
+
**waveform_kwargs_now
|
|
1042
|
+
)
|
|
1043
|
+
check3 = self.mgh.get_ll(include_psd_info=True)
|
|
1044
|
+
|
|
1045
|
+
# print(check3 - check2)
|
|
1046
|
+
|
|
1047
|
+
if np.any(
|
|
1048
|
+
self.band_N_vals[adjust_band_old] !=
|
|
1049
|
+
self.band_N_vals[adjust_band_new]
|
|
1050
|
+
):
|
|
1051
|
+
walkers_focus = np.unique(inds_band_change_cold_chain[0][
|
|
1052
|
+
self.band_N_vals.get()[adjust_band_old] !=
|
|
1053
|
+
self.band_N_vals.get()[adjust_band_new]
|
|
1054
|
+
])
|
|
1055
|
+
# print(f"specific change in N across boundary: {walkers_focus}")
|
|
1056
|
+
new_state.log_like[0, walkers_focus] = self.mgh.get_ll(include_psd_info=True)[walkers_focus]
|
|
1057
|
+
|
|
1058
|
+
# self.check_ll_inject(new_state)
|
|
1059
|
+
# breakpoint()
|
|
1060
|
+
ll_after = (
|
|
1061
|
+
self.mgh.get_ll(include_psd_info=True)
|
|
1062
|
+
)
|
|
1063
|
+
# print(np.abs(new_state.log_like - ll_after).max())
|
|
1064
|
+
store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
|
|
1065
|
+
# print("CHECKING 2:", store_max_diff, self.is_rj_prop)
|
|
1066
|
+
|
|
1067
|
+
self.mempool.free_all_blocks()
|
|
1068
|
+
# get accepted fraction
|
|
1069
|
+
if not self.is_rj_prop:
|
|
1070
|
+
accepted_check_tmp = np.zeros_like(
|
|
1071
|
+
state.branches_inds["gb"], dtype=bool
|
|
1072
|
+
)
|
|
1073
|
+
accepted_check_tmp[state.branches_inds["gb"]] = np.all(
|
|
1074
|
+
np.abs(
|
|
1075
|
+
new_state.branches_coords["gb"][
|
|
1076
|
+
state.branches_inds["gb"]
|
|
1077
|
+
]
|
|
1078
|
+
- state.branches_coords["gb"][state.branches_inds["gb"]]
|
|
1079
|
+
)
|
|
1080
|
+
> 0.0,
|
|
1081
|
+
axis=-1,
|
|
1082
|
+
)
|
|
1083
|
+
proposed = gb_inds.get()
|
|
1084
|
+
accepted_check = accepted_check_tmp.sum(
|
|
1085
|
+
axis=(1, 2)
|
|
1086
|
+
) / proposed.sum(axis=(1, 2))
|
|
1087
|
+
else:
|
|
1088
|
+
accepted_check_tmp = (
|
|
1089
|
+
new_state.branches_inds["gb"] == (~state.branches_inds["gb"])
|
|
1090
|
+
)
|
|
1091
|
+
|
|
1092
|
+
proposed = gb_inds.get()
|
|
1093
|
+
accepted_check = accepted_check_tmp.sum(axis=(1, 2)) / proposed.sum(axis=(1, 2))
|
|
1094
|
+
|
|
1095
|
+
# manually tell temperatures how real overall acceptance fraction is
|
|
1096
|
+
number_of_walkers_for_accepted = np.floor(nwalkers * accepted_check).astype(int)
|
|
1097
|
+
|
|
1098
|
+
accepted_inds = np.tile(np.arange(nwalkers), (ntemps, 1))
|
|
1099
|
+
|
|
1100
|
+
accepted = np.zeros((ntemps, nwalkers), dtype=bool)
|
|
1101
|
+
accepted[accepted_inds < number_of_walkers_for_accepted[:, None]] = True
|
|
1102
|
+
|
|
1103
|
+
tmp1 = np.all(
|
|
1104
|
+
np.abs(
|
|
1105
|
+
new_state.branches_coords["gb"]
|
|
1106
|
+
- state.branches_coords["gb"]
|
|
1107
|
+
)
|
|
1108
|
+
> 0.0,
|
|
1109
|
+
axis=-1,
|
|
1110
|
+
).sum(axis=(2,))
|
|
1111
|
+
tmp2 = new_state.branches_inds["gb"].sum(axis=(2,))
|
|
1112
|
+
|
|
1113
|
+
# add to move-specific accepted information
|
|
1114
|
+
self.accepted += tmp1
|
|
1115
|
+
if isinstance(self.num_proposals, int):
|
|
1116
|
+
self.num_proposals = tmp2
|
|
1117
|
+
else:
|
|
1118
|
+
self.num_proposals += tmp2
|
|
1119
|
+
|
|
1120
|
+
new_inds = xp.asarray(new_state.branches_inds["gb"])
|
|
1121
|
+
|
|
1122
|
+
# in-model inds will not change
|
|
1123
|
+
tmp_freqs_find_bands = xp.asarray(new_state.branches_coords["gb"][:, :, :, 1])
|
|
1124
|
+
|
|
1125
|
+
# calculate current band counts
|
|
1126
|
+
band_here = (xp.searchsorted(self.band_edges, tmp_freqs_find_bands.flatten() / 1e3, side="right") - 1).reshape(tmp_freqs_find_bands.shape)
|
|
1127
|
+
|
|
1128
|
+
# get binaries per band
|
|
1129
|
+
special_band_here_num_per_band = ((group_temp_finder[0] * nwalkers + group_temp_finder[1]) * int(1e6) + band_here)[new_inds]
|
|
1130
|
+
unique_special_band_here_num_per_band, unique_special_band_here_num_per_band_count = xp.unique(special_band_here_num_per_band, return_counts=True)
|
|
1131
|
+
temp_walker_index_num_per_band = (unique_special_band_here_num_per_band / 1e6).astype(int)
|
|
1132
|
+
temp_index_num_per_band = (temp_walker_index_num_per_band / nwalkers).astype(int)
|
|
1133
|
+
walker_index_num_per_band = temp_walker_index_num_per_band - temp_index_num_per_band * nwalkers
|
|
1134
|
+
band_index_num_per_band = unique_special_band_here_num_per_band - temp_walker_index_num_per_band * int(1e6)
|
|
1135
|
+
|
|
1136
|
+
per_walker_band_counts = xp.zeros((ntemps, nwalkers, self.num_bands), dtype=int)
|
|
1137
|
+
per_walker_band_counts[temp_index_num_per_band, walker_index_num_per_band, band_index_num_per_band] = unique_special_band_here_num_per_band_count
|
|
1138
|
+
|
|
1139
|
+
# TEMPERING
|
|
1140
|
+
self.temperature_control.swaps_accepted = np.zeros(ntemps - 1)
|
|
1141
|
+
self.temperature_control.swaps_proposed = np.zeros(ntemps - 1)
|
|
1142
|
+
|
|
1143
|
+
band_swaps_accepted = np.zeros((len(self.band_edges) - 1, self.ntemps - 1), dtype=int)
|
|
1144
|
+
band_swaps_proposed = np.zeros((len(self.band_edges) - 1, self.ntemps - 1), dtype=int)
|
|
1145
|
+
current_band_counts = np.zeros((len(self.band_edges) - 1, self.ntemps), dtype=int)
|
|
1146
|
+
|
|
1147
|
+
# if self.is_rj_prop:
|
|
1148
|
+
# print("1st count check:", new_state.branches["gb"].inds.sum(axis=-1).mean(axis=-1), "\nll:", new_state.log_like[0] - orig_store, new_state.log_like[0])
|
|
1149
|
+
|
|
1150
|
+
# if self.time > 0:
|
|
1151
|
+
# self.check_ll_inject(new_state)
|
|
1152
|
+
if (
|
|
1153
|
+
self.temperature_control is not None
|
|
1154
|
+
and self.time % 1 == 0
|
|
1155
|
+
and self.ntemps > 1
|
|
1156
|
+
and self.is_rj_prop
|
|
1157
|
+
# and False
|
|
1158
|
+
):
|
|
1159
|
+
|
|
1160
|
+
new_coords_after_tempering = xp.asarray(np.zeros_like(new_state.branches["gb"].coords))
|
|
1161
|
+
new_inds_after_tempering = xp.asarray(np.zeros_like(new_state.branches["gb"].inds))
|
|
1162
|
+
|
|
1163
|
+
# TODO: check if N changes / need to deal with that
|
|
1164
|
+
betas = self.temperature_control.betas
|
|
1165
|
+
|
|
1166
|
+
# cannot find them yourself because of higher temps moving across band edge / need supplimental band inds
|
|
1167
|
+
band_inds_temp = new_state.branches["gb"].branch_supplimental.holder["band_inds"][new_state.branches["gb"].inds]
|
|
1168
|
+
temp_inds_temp = group_temp_finder[0].get()[new_state.branches["gb"].inds]
|
|
1169
|
+
walker_inds_temp = group_temp_finder[1].get()[new_state.branches["gb"].inds]
|
|
1170
|
+
|
|
1171
|
+
bands_guide = np.tile(np.arange(self.num_bands), (self.ntemps, self.nwalkers, 1)).transpose(2, 1, 0).flatten()
|
|
1172
|
+
temps_guide = np.repeat(np.arange(self.ntemps)[:, None], self.nwalkers * self.num_bands).reshape(self.ntemps, self.nwalkers, self.num_bands).transpose(2, 1, 0).flatten()
|
|
1173
|
+
walkers_guide = np.repeat(np.arange(self.nwalkers)[:, None], self.ntemps * self.num_bands).reshape(self.nwalkers, self.ntemps, self.num_bands).transpose(2, 0, 1).flatten()
|
|
1174
|
+
|
|
1175
|
+
walkers_permuted = np.asarray([np.random.permutation(np.arange(self.nwalkers)) for _ in range(self.ntemps * self.num_bands)]).reshape(self.num_bands, self.ntemps, self.nwalkers).transpose(0, 2, 1).flatten()
|
|
1176
|
+
|
|
1177
|
+
# special_inds_guide = bands_guide * int(1e6) + walkers_permuted * int(1e3) + temps_guide
|
|
1178
|
+
|
|
1179
|
+
coords_in = new_state.branches["gb"].coords[new_state.branches["gb"].inds]
|
|
1180
|
+
|
|
1181
|
+
# N_vals_in = new_state.branches["gb"].branch_supplimental.holder["N_vals"][new_state.branches["gb"].inds]
|
|
1182
|
+
unique_N = np.unique(self.band_N_vals).get()
|
|
1183
|
+
# remove 0
|
|
1184
|
+
unique_N = unique_N[unique_N != 0]
|
|
1185
|
+
|
|
1186
|
+
bands_in = band_inds_temp
|
|
1187
|
+
temps_in = temp_inds_temp
|
|
1188
|
+
walkers_in = walker_inds_temp
|
|
1189
|
+
|
|
1190
|
+
main_gpu = self.xp.cuda.runtime.getDevice()
|
|
1191
|
+
|
|
1192
|
+
walkers_info = walkers_permuted if self.temperature_control.permute else walkers_guide
|
|
1193
|
+
units = 2
|
|
1194
|
+
start_unit = np.random.randint(0, units)
|
|
1195
|
+
for unit in range(units):
|
|
1196
|
+
current_band_remainder = (start_unit + unit) % units
|
|
1197
|
+
for N_now in unique_N:
|
|
1198
|
+
keep = (bands_in % units == current_band_remainder) & (self.band_N_vals[xp.asarray(bands_in)].get() == N_now)
|
|
1199
|
+
keep_base = (bands_guide % units == current_band_remainder) & (self.band_N_vals[xp.asarray(bands_guide)].get() == N_now)
|
|
1200
|
+
bands_in_tmp = bands_in[keep]
|
|
1201
|
+
temps_in_tmp = temps_in[keep]
|
|
1202
|
+
walkers_in_tmp = walkers_in[keep]
|
|
1203
|
+
coords_in_tmp = coords_in[keep]
|
|
1204
|
+
|
|
1205
|
+
# adjust for N_val
|
|
1206
|
+
special_inds_bins = bands_in_tmp * int(1e8) + walkers_in_tmp * int(1e4) + temps_in_tmp
|
|
1207
|
+
|
|
1208
|
+
bands_guide_keep = bands_guide[keep_base]
|
|
1209
|
+
temps_guide_keep = temps_guide[keep_base]
|
|
1210
|
+
walkers_info_keep = walkers_info[keep_base]
|
|
1211
|
+
|
|
1212
|
+
special_inds_guide_keep = bands_guide_keep * int(1e8) + walkers_info_keep * int(1e4) + temps_guide_keep
|
|
1213
|
+
sort_tmp = np.argsort(special_inds_guide_keep)
|
|
1214
|
+
sorted_special_inds_guide_keep = special_inds_guide_keep[sort_tmp]
|
|
1215
|
+
sorting_info_need_to_revert = np.searchsorted(sorted_special_inds_guide_keep, special_inds_bins, side="left")
|
|
1216
|
+
|
|
1217
|
+
sorting_info = sort_tmp[sorting_info_need_to_revert]
|
|
1218
|
+
|
|
1219
|
+
assert np.all(special_inds_guide_keep[sorting_info] == special_inds_bins)
|
|
1220
|
+
|
|
1221
|
+
sort_bins = np.argsort(sorting_info)
|
|
1222
|
+
|
|
1223
|
+
# p1 = self.mgh.psd_shaped[0][0][8, 19748:20132]
|
|
1224
|
+
# p2 = self.mgh.psd_shaped[1][0][8, 19748:20132]
|
|
1225
|
+
# c2 = self.mgh.data_shaped[1][0][26, 19748:20132] + self.mgh.data_shaped[1][0][8, 19748:20132]- self.mgh.channel2_base_data[0][8 * self.data_length + 19748:8 * self.data_length + 20132]
|
|
1226
|
+
# c1 = self.mgh.data_shaped[0][0][26, 19748:20132] + self.mgh.data_shaped[0][0][8, 19748:20132]- self.mgh.channel1_base_data[0][8 * self.data_length + 19748:8 * self.data_length + 20132]
|
|
1227
|
+
# args_sorting = np.arange(len(sorting_info)) # np.argsort(sorting_info)
|
|
1228
|
+
bands_in_tmp = bands_in_tmp[sort_bins]
|
|
1229
|
+
temps_in_tmp = temps_in_tmp[sort_bins]
|
|
1230
|
+
walkers_in_tmp = walkers_in_tmp[sort_bins]
|
|
1231
|
+
coords_in_tmp = coords_in_tmp[sort_bins]
|
|
1232
|
+
sorting_info = sorting_info[sort_bins]
|
|
1233
|
+
|
|
1234
|
+
uni_sorted, uni_index, uni_inverse, uni_counts = xp.unique(xp.asarray(sorting_info), return_index=True, return_counts=True, return_inverse=True)
|
|
1235
|
+
|
|
1236
|
+
band_info_start_index_bin = xp.zeros_like(bands_guide_keep, dtype=np.int32)
|
|
1237
|
+
band_info_num_bin = xp.zeros_like(bands_guide_keep, dtype=np.int32)
|
|
1238
|
+
|
|
1239
|
+
band_info_start_index_bin[uni_sorted.astype(np.int32)] = uni_index.astype(np.int32)
|
|
1240
|
+
band_info_num_bin[uni_sorted.astype(np.int32)] = uni_counts.astype(np.int32)
|
|
1241
|
+
|
|
1242
|
+
params_curr_separated = tuple([xp.asarray(coords_in_tmp[:, i].copy()) for i in range(coords_in_tmp.shape[1])])
|
|
1243
|
+
params_curr_separated_orig = tuple([xp.asarray(coords_in_tmp[:, i].copy()) for i in range(coords_in_tmp.shape[1])])
|
|
1244
|
+
params_extra_params = (
|
|
1245
|
+
self.waveform_kwargs["T"],
|
|
1246
|
+
self.waveform_kwargs["dt"],
|
|
1247
|
+
N_now,
|
|
1248
|
+
coords_in.shape[0],
|
|
1249
|
+
self.start_freq_ind,
|
|
1250
|
+
Soms_d_all["sangria"] ** (1/2),
|
|
1251
|
+
Sa_a_all["sangria"] ** (1/2),
|
|
1252
|
+
1e-100, 0.0, 0.0, 0.0, 0.0, # foreground params for snr -> amp transform
|
|
1253
|
+
)
|
|
1254
|
+
|
|
1255
|
+
gb_params_curr_in = self.gb.pyGalacticBinaryParams(
|
|
1256
|
+
*(params_curr_separated + params_curr_separated_orig + params_extra_params)
|
|
1257
|
+
)
|
|
1258
|
+
|
|
1259
|
+
# current_parameter_arrays.append(params_curr_separated)
|
|
1260
|
+
|
|
1261
|
+
data_package = self.gb.pyDataPackage(
|
|
1262
|
+
data[0][0],
|
|
1263
|
+
data[1][0],
|
|
1264
|
+
base_data[0][0],
|
|
1265
|
+
base_data[1][0],
|
|
1266
|
+
psd[0][0],
|
|
1267
|
+
psd[1][0],
|
|
1268
|
+
lisasens[0][0],
|
|
1269
|
+
lisasens[1][0],
|
|
1270
|
+
self.df,
|
|
1271
|
+
self.data_length,
|
|
1272
|
+
self.nwalkers * self.ntemps,
|
|
1273
|
+
self.nwalkers * self.ntemps
|
|
1274
|
+
)
|
|
1275
|
+
|
|
1276
|
+
data_index_here = xp.asarray((((bands_guide_keep + 1) % 2) * self.nwalkers + walkers_info_keep)).astype(np.int32)
|
|
1277
|
+
noise_index_here = xp.asarray(walkers_info_keep.copy()).astype(np.int32)
|
|
1278
|
+
update_data_index_here = xp.asarray((((bands_guide_keep + 0) % 2) * self.nwalkers + walkers_info_keep)).astype(np.int32)
|
|
1279
|
+
|
|
1280
|
+
start_band_index = ((self.band_edges[bands_guide_keep] / self.df).astype(int) - N_now - 1).astype(np.int32)
|
|
1281
|
+
end_band_index = (((self.band_edges[bands_guide_keep + 1]) / self.df).astype(int) + N_now + 1).astype(np.int32)
|
|
1282
|
+
|
|
1283
|
+
start_band_index[start_band_index < 0] = 0
|
|
1284
|
+
end_band_index[end_band_index >= len(self.fd)] = len(self.fd) - 1
|
|
1285
|
+
|
|
1286
|
+
band_lengths = end_band_index - start_band_index
|
|
1287
|
+
|
|
1288
|
+
start_interest_band_index = ((self.band_edges[bands_guide_keep] / self.df).astype(int)).astype(np.int32) # - N_now).astype(np.int32)
|
|
1289
|
+
end_interest_band_index = (((self.band_edges[bands_guide_keep + 1]) / self.df).astype(int)).astype(np.int32) # + N_now).astype(np.int32)
|
|
1290
|
+
|
|
1291
|
+
start_interest_band_index[start_interest_band_index < 0] = 0
|
|
1292
|
+
end_interest_band_index[end_interest_band_index >= len(self.fd)] = len(self.fd) - 1
|
|
1293
|
+
|
|
1294
|
+
band_interest_lengths = end_interest_band_index - start_interest_band_index
|
|
1295
|
+
|
|
1296
|
+
|
|
1297
|
+
# proposal limits in a band
|
|
1298
|
+
fmin_allow = xp.asarray(((self.band_edges[bands_guide_keep] / self.df).astype(int) - (N_now / 2))) * self.df
|
|
1299
|
+
fmax_allow = xp.asarray(((self.band_edges[bands_guide_keep + 1] / self.df).astype(int) + (N_now / 2))) * self.df
|
|
1300
|
+
|
|
1301
|
+
fmin_allow[fmin_allow < self.band_edges[0]] = self.band_edges[0]
|
|
1302
|
+
fmax_allow[fmax_allow > self.band_edges[-1]] = self.band_edges[-1]
|
|
1303
|
+
|
|
1304
|
+
max_data_store_size = band_lengths.max().item()
|
|
1305
|
+
|
|
1306
|
+
num_bands_here = len(band_info_start_index_bin)
|
|
1307
|
+
|
|
1308
|
+
num_swap_setups = np.unique(bands_guide_keep).shape[0] * self.nwalkers
|
|
1309
|
+
|
|
1310
|
+
swaps_proposed_here = xp.zeros(num_swap_setups * (self.ntemps - 1), dtype=np.int32)
|
|
1311
|
+
swaps_accepted_here = xp.zeros(num_swap_setups * (self.ntemps - 1), dtype=np.int32)
|
|
1312
|
+
|
|
1313
|
+
bands_guide_keep = xp.asarray(bands_guide_keep).astype(np.int32)
|
|
1314
|
+
walkers_info_keep = xp.asarray(walkers_info_keep).astype(np.int32)
|
|
1315
|
+
temps_guide_keep = xp.asarray(temps_guide_keep).astype(np.int32)
|
|
1316
|
+
loc_band_index = xp.arange(num_bands_here, dtype=xp.int32)
|
|
1317
|
+
|
|
1318
|
+
before_loc_band_index = loc_band_index.copy()
|
|
1319
|
+
before_band_info_start_index_bin = band_info_start_index_bin.copy()
|
|
1320
|
+
before_band_info_num_bin = band_info_num_bin.copy()
|
|
1321
|
+
|
|
1322
|
+
band_package = self.gb.pyBandPackage(
|
|
1323
|
+
loc_band_index,
|
|
1324
|
+
data_index_here,
|
|
1325
|
+
noise_index_here,
|
|
1326
|
+
band_info_start_index_bin, # uni_index
|
|
1327
|
+
band_info_num_bin, # uni_count
|
|
1328
|
+
start_band_index,
|
|
1329
|
+
band_lengths,
|
|
1330
|
+
start_interest_band_index,
|
|
1331
|
+
band_interest_lengths,
|
|
1332
|
+
num_bands_here,
|
|
1333
|
+
max_data_store_size,
|
|
1334
|
+
fmin_allow,
|
|
1335
|
+
fmax_allow,
|
|
1336
|
+
update_data_index_here,
|
|
1337
|
+
self.ntemps,
|
|
1338
|
+
bands_guide_keep,
|
|
1339
|
+
walkers_info_keep,
|
|
1340
|
+
temps_guide_keep,
|
|
1341
|
+
swaps_proposed_here,
|
|
1342
|
+
swaps_accepted_here
|
|
1343
|
+
)
|
|
1344
|
+
|
|
1345
|
+
band_inv_temp_vals_here = band_temps[bands_guide_keep, temps_guide_keep]
|
|
1346
|
+
|
|
1347
|
+
accepted_out_here = xp.zeros_like(bands_guide_keep, dtype=xp.int32)
|
|
1348
|
+
L_contribution_here = xp.zeros_like(bands_guide_keep, dtype=complex)
|
|
1349
|
+
p_contribution_here = xp.zeros_like(bands_guide_keep, dtype=complex)
|
|
1350
|
+
prior_all_curr_here = xp.zeros_like(bands_guide_keep, dtype=np.float64)
|
|
1351
|
+
|
|
1352
|
+
mcmc_info = self.gb.pyMCMCInfo(
|
|
1353
|
+
L_contribution_here,
|
|
1354
|
+
p_contribution_here,
|
|
1355
|
+
prior_all_curr_here,
|
|
1356
|
+
accepted_out_here,
|
|
1357
|
+
band_inv_temp_vals_here, # band_inv_temp_vals
|
|
1358
|
+
self.is_rj_prop,
|
|
1359
|
+
False, # phased maximize
|
|
1360
|
+
self.snr_lim,
|
|
1361
|
+
)
|
|
1362
|
+
|
|
1363
|
+
inds_here = xp.ones_like(temps_in_tmp, dtype=bool)
|
|
1364
|
+
factors_here = xp.zeros_like(temps_in_tmp, dtype=float)
|
|
1365
|
+
start_inds_here = xp.zeros_like(temps_in_tmp, dtype=np.int32)
|
|
1366
|
+
|
|
1367
|
+
num_proposals_here = 1
|
|
1368
|
+
|
|
1369
|
+
assert start_inds_here.dtype == np.int32
|
|
1370
|
+
proposal_info = self.gb.pyStretchProposalPackage(
|
|
1371
|
+
*(self.stretch_friends_args_in + (self.nfriends, len(self.stretch_friends_args_in[0]), num_proposals_here, self.a, ndim, inds_here, factors_here, start_inds_here))
|
|
1372
|
+
)
|
|
1373
|
+
|
|
1374
|
+
inputs_now = (
|
|
1375
|
+
data_package,
|
|
1376
|
+
band_package,
|
|
1377
|
+
gb_params_curr_in,
|
|
1378
|
+
mcmc_info,
|
|
1379
|
+
self.gpu_cuda_priors,
|
|
1380
|
+
proposal_info,
|
|
1381
|
+
self.gpu_cuda_wrap,
|
|
1382
|
+
num_swap_setups,
|
|
1383
|
+
device,
|
|
1384
|
+
do_synchronize,
|
|
1385
|
+
-1,
|
|
1386
|
+
100000
|
|
1387
|
+
)
|
|
1388
|
+
|
|
1389
|
+
self.gb.SharedMemoryMakeTemperingMove_wrap(*inputs_now)
|
|
1390
|
+
|
|
1391
|
+
self.xp.cuda.runtime.deviceSynchronize()
|
|
1392
|
+
|
|
1393
|
+
walkers_info_keep_per_bin = xp.repeat(walkers_info_keep, list(band_info_num_bin.get()))
|
|
1394
|
+
temps_guide_keep_per_bin = xp.repeat(temps_guide_keep, list(band_info_num_bin.get()))
|
|
1395
|
+
start_bin_ind_mapping = xp.repeat(band_info_start_index_bin, list(band_info_num_bin.get()))
|
|
1396
|
+
uni_after, uni_after_start_index, uni_after_inverse = xp.unique(start_bin_ind_mapping, return_index=True, return_inverse=True)
|
|
1397
|
+
ind_map = xp.arange(len(start_bin_ind_mapping)) - xp.arange(len(start_bin_ind_mapping))[uni_after_start_index][uni_after_inverse] + start_bin_ind_mapping
|
|
1398
|
+
|
|
1399
|
+
# out = []
|
|
1400
|
+
# for i in range(len(band_info_num_bin_has)):
|
|
1401
|
+
# num_bins = band_info_num_bin_has[i].item()
|
|
1402
|
+
# start_index = band_info_start_index_bin_has[i].item()
|
|
1403
|
+
# for j in range(num_bins):
|
|
1404
|
+
# out.append(start_index + j)
|
|
1405
|
+
|
|
1406
|
+
# breakpoint()
|
|
1407
|
+
new_coords_mapped = coords_in_tmp[ind_map.get()]
|
|
1408
|
+
|
|
1409
|
+
special_after_tempering = temps_guide_keep_per_bin * self.nwalkers + walkers_info_keep_per_bin
|
|
1410
|
+
|
|
1411
|
+
sorted_after = xp.argsort(special_after_tempering)
|
|
1412
|
+
special_after_tempering = special_after_tempering[sorted_after]
|
|
1413
|
+
walkers_info_keep_per_bin = walkers_info_keep_per_bin[sorted_after]
|
|
1414
|
+
temps_guide_keep_per_bin = temps_guide_keep_per_bin[sorted_after]
|
|
1415
|
+
new_coords_mapped = new_coords_mapped[sorted_after.get()]
|
|
1416
|
+
|
|
1417
|
+
uni_after, uni_index_after, uni_inverse_after = xp.unique(special_after_tempering, return_index=True, return_inverse=True)
|
|
1418
|
+
|
|
1419
|
+
relative_leaf_info_after_tempering = xp.arange(len(special_after_tempering)) - uni_index_after[uni_inverse_after]
|
|
1420
|
+
|
|
1421
|
+
leaf_start_per_walker = new_inds_after_tempering.argmin(axis=-1)[temps_guide_keep_per_bin[uni_index_after], walkers_info_keep_per_bin[uni_index_after]]
|
|
1422
|
+
|
|
1423
|
+
absolute_leaf_info_after_tempering = leaf_start_per_walker[uni_inverse_after] + relative_leaf_info_after_tempering
|
|
1424
|
+
|
|
1425
|
+
assert (np.all(~new_inds_after_tempering[temps_guide_keep_per_bin, walkers_info_keep_per_bin, absolute_leaf_info_after_tempering]))
|
|
1426
|
+
new_coords_after_tempering[temps_guide_keep_per_bin, walkers_info_keep_per_bin, absolute_leaf_info_after_tempering] = new_coords_mapped
|
|
1427
|
+
new_inds_after_tempering[temps_guide_keep_per_bin, walkers_info_keep_per_bin, absolute_leaf_info_after_tempering] = True
|
|
1428
|
+
|
|
1429
|
+
bands_unique_here = np.unique(bands_guide_keep).get()
|
|
1430
|
+
band_swaps_accepted[bands_unique_here] += swaps_accepted_here.reshape(bands_unique_here.shape[0], self.nwalkers, self.ntemps - 1).sum(axis=1).get()
|
|
1431
|
+
band_swaps_proposed[bands_unique_here] += swaps_proposed_here.reshape(bands_unique_here.shape[0], self.nwalkers, self.ntemps - 1).sum(axis=1).get()
|
|
1432
|
+
|
|
1433
|
+
ll_change = self.xp.zeros((ntemps, nwalkers, len(self.band_edges)))
|
|
1434
|
+
|
|
1435
|
+
self.xp.cuda.runtime.deviceSynchronize()
|
|
1436
|
+
|
|
1437
|
+
ll_change[(temps_guide_keep,walkers_info_keep,bands_guide_keep)] = L_contribution_here
|
|
1438
|
+
|
|
1439
|
+
ll_adjustment = ll_change.sum(axis=-1)
|
|
1440
|
+
log_like_tmp += ll_adjustment
|
|
1441
|
+
|
|
1442
|
+
new_state.branches["gb"].coords[:] = new_coords_after_tempering.get()
|
|
1443
|
+
new_state.branches["gb"].inds[:] = new_inds_after_tempering.get()
|
|
1444
|
+
new_state.log_like[:] = log_like_tmp.get()
|
|
1445
|
+
|
|
1446
|
+
new_freqs = new_coords_after_tempering[new_inds_after_tempering, 1]
|
|
1447
|
+
new_band_inds = (xp.searchsorted(self.band_edges, new_freqs / 1e3, side="right") - 1)
|
|
1448
|
+
new_state.branches["gb"].branch_supplimental.holder["band_inds"][new_inds_after_tempering.get()] = new_band_inds.get()
|
|
1449
|
+
|
|
1450
|
+
# breakpoint()
|
|
1451
|
+
# self.check_ll_inject(new_state)
|
|
1452
|
+
# breakpoint()
|
|
1453
|
+
|
|
1454
|
+
# adjust priors accordingly
|
|
1455
|
+
log_prior_new_per_bin = xp.zeros_like(
|
|
1456
|
+
new_state.branches_inds["gb"], dtype=xp.float64
|
|
1457
|
+
)
|
|
1458
|
+
# self.gpu_priors
|
|
1459
|
+
log_prior_new_per_bin[
|
|
1460
|
+
new_state.branches_inds["gb"]
|
|
1461
|
+
] = self.gpu_priors["gb"].logpdf(
|
|
1462
|
+
xp.asarray(
|
|
1463
|
+
new_state.branches_coords["gb"][
|
|
1464
|
+
new_state.branches_inds["gb"]
|
|
1465
|
+
]
|
|
1466
|
+
),
|
|
1467
|
+
psds=self.mgh.psd_shaped[0][0],
|
|
1468
|
+
walker_inds=group_temp_finder[1].get()[new_state.branches_inds["gb"]]
|
|
1469
|
+
)
|
|
1470
|
+
|
|
1471
|
+
new_state.log_prior = log_prior_new_per_bin.sum(axis=-1).get()
|
|
1472
|
+
|
|
1473
|
+
ratios = (band_swaps_accepted / band_swaps_proposed).T # self.swaps_accepted / self.swaps_proposed
|
|
1474
|
+
ratios[np.isnan(ratios)] = 0.0
|
|
1475
|
+
|
|
1476
|
+
# only change those with a binary in them
|
|
1477
|
+
|
|
1478
|
+
# adapt if desired
|
|
1479
|
+
if self.time > 50:
|
|
1480
|
+
betas0 = band_temps.copy().T.get()
|
|
1481
|
+
betas1 = betas0.copy()
|
|
1482
|
+
|
|
1483
|
+
# Modulate temperature adjustments with a hyperbolic decay.
|
|
1484
|
+
decay = self.temperature_control.adaptation_lag / (self.time + self.temperature_control.adaptation_lag)
|
|
1485
|
+
kappa = decay / self.temperature_control.adaptation_time
|
|
1486
|
+
|
|
1487
|
+
# Construct temperature adjustments.
|
|
1488
|
+
dSs = kappa * (ratios[:-1] - ratios[1:])
|
|
1489
|
+
|
|
1490
|
+
# Compute new ladder (hottest and coldest chains don't move).
|
|
1491
|
+
deltaTs = np.diff(1 / betas1[:-1], axis=0)
|
|
1492
|
+
|
|
1493
|
+
deltaTs *= np.exp(dSs)
|
|
1494
|
+
betas1[1:-1] = 1 / (np.cumsum(deltaTs, axis=0) + 1 / betas1[0])
|
|
1495
|
+
|
|
1496
|
+
# Don't mutate the ladder here; let the client code do that.
|
|
1497
|
+
dbetas = betas1 - betas0
|
|
1498
|
+
|
|
1499
|
+
band_temps += self.xp.asarray(dbetas.T)
|
|
1500
|
+
|
|
1501
|
+
# band_temps[:] = band_temps[553, :][None, :]
|
|
1502
|
+
|
|
1503
|
+
# only increase time if it is adaptive.
|
|
1504
|
+
new_state.betas = self.temperature_control.betas.copy()
|
|
1505
|
+
|
|
1506
|
+
self.mempool.free_all_blocks()
|
|
1507
|
+
# print(
|
|
1508
|
+
# self.is_rj_prop,
|
|
1509
|
+
# band_swaps_accepted[350] / band_swaps_proposed[350],
|
|
1510
|
+
# band_swaps_accepted[450] / band_swaps_proposed[450],
|
|
1511
|
+
# band_swaps_accepted[501] / band_swaps_proposed[501]
|
|
1512
|
+
# )
|
|
1513
|
+
|
|
1514
|
+
self.mempool.free_all_blocks()
|
|
1515
|
+
if self.time % 1 == 0:
|
|
1516
|
+
ll_after = (
|
|
1517
|
+
self.mgh.get_ll(include_psd_info=True)
|
|
1518
|
+
)
|
|
1519
|
+
# print(np.abs(new_state.log_like - ll_after).max())
|
|
1520
|
+
store_max_diff = np.abs(new_state.log_like[0] - ll_after).max()
|
|
1521
|
+
# print("CHECKING:", store_max_diff, self.is_rj_prop)
|
|
1522
|
+
if store_max_diff > 1e-5:
|
|
1523
|
+
ll_after = (
|
|
1524
|
+
self.mgh.get_ll(include_psd_info=True, stop=True)
|
|
1525
|
+
)
|
|
1526
|
+
|
|
1527
|
+
if store_max_diff > 1.0:
|
|
1528
|
+
breakpoint()
|
|
1529
|
+
|
|
1530
|
+
# reset data and fix likelihood
|
|
1531
|
+
new_state.log_like[0] = self.check_ll_inject(new_state)
|
|
1532
|
+
|
|
1533
|
+
|
|
1534
|
+
self.time += 1
|
|
1535
|
+
# self.xp.cuda.runtime.deviceSynchronize()
|
|
1536
|
+
|
|
1537
|
+
new_state.update_band_information(
|
|
1538
|
+
band_temps.get(), per_walker_band_proposals.sum(axis=1).get().T, per_walker_band_accepted.sum(axis=1).get().T, band_swaps_proposed, band_swaps_accepted,
|
|
1539
|
+
per_walker_band_counts.get(), self.is_rj_prop
|
|
1540
|
+
)
|
|
1541
|
+
# TODO: check rj numbers
|
|
1542
|
+
|
|
1543
|
+
# new_state.log_like[:] = self.check_ll_inject(new_state)
|
|
1544
|
+
|
|
1545
|
+
self.mempool.free_all_blocks()
|
|
1546
|
+
|
|
1547
|
+
if self.is_rj_prop:
|
|
1548
|
+
pass # print(self.name, "2nd count check:", new_state.branches["gb"].inds.sum(axis=-1).mean(axis=-1), "\nll:", new_state.log_like[0] - orig_store, new_state.log_like[0])
|
|
1549
|
+
return new_state, accepted
|
|
1550
|
+
|
|
1551
|
+
def check_ll_inject(self, new_state):
    """Rebuild the residual data from the base data and return its log-likelihood.

    The first ``2 * nwalkers * data_length`` entries of ``self.mgh.channel1_data[0]``
    and ``self.mgh.channel2_data[0]`` are overwritten with two stacked copies of the
    base data. Every active ``"gb"`` source of ``new_state`` at index 0 of the
    leading (temperature) axis is then injected again through
    ``self.gb.generate_global_template`` with a factor of -1. The likelihood of the
    rebuilt data is returned; the propose step uses it to overwrite
    ``new_state.log_like[0]`` when the tracked value has drifted.

    Args:
        new_state: sampler state whose ``branches["gb"]`` exposes ``coords``,
            ``inds`` and ``branch_supplimental.holder["band_inds"]``.

    Returns:
        The result of ``self.mgh.get_ll(include_psd_info=True)`` evaluated after
        the data have been rebuilt.

    Raises:
        AssertionError: if the band index recomputed from the transformed source
            frequencies disagrees with the stored ``band_inds``.
    """
    gb_branch = new_state.branches["gb"]
    mgh = self.mgh

    # Reset both stacked data copies (data indices 0 .. 2 * nwalkers - 1,
    # presumably laid out as data_index * data_length) to the base data before
    # re-injecting every source.
    copy_length = self.nwalkers * self.data_length
    for parity in range(2):
        window = slice(parity * copy_length, (parity + 1) * copy_length)
        mgh.channel1_data[0][window] = mgh.channel1_base_data[0][:]
        mgh.channel2_data[0][window] = mgh.channel2_base_data[0][:]

    # Active sources at leading index 0, mapped into the waveform basis.
    active = gb_branch.inds[0]
    coords_in = self.parameter_transforms.both_transforms(gb_branch.coords[0, active])

    # Recompute each source's band from its frequency (column 1) and check it
    # against the band bookkeeping stored on the state.
    band_inds = np.searchsorted(self.band_edges.get(), coords_in[:, 1], side="right") - 1
    stored_band_inds = gb_branch.branch_supplimental.holder["band_inds"][0, active]
    assert np.all(band_inds == stored_band_inds), "stored band_inds disagree with source frequencies"

    # Walker index of every active source: the row index of each True entry of
    # the (walker, leaf) mask, in the same C order the boolean mask uses above.
    walker_vals = np.nonzero(active)[0]

    # Band parity selects which of the two stacked data copies a source lives in;
    # indices are cast to int32 as elsewhere in this move.
    data_index = xp.asarray((band_inds % 2) * self.nwalkers + walker_vals).astype(xp.int32)

    # goes in as -h
    factors = -xp.ones_like(data_index, dtype=xp.float64)

    # ``N`` is passed explicitly per source below, so a global ``N`` in the
    # shared kwargs must be removed; otherwise the call raises
    # "got multiple values for keyword argument 'N'". Pop from a copy so
    # self.waveform_kwargs itself is left untouched.
    waveform_kwargs_tmp = self.waveform_kwargs.copy()
    waveform_kwargs_tmp.pop("N", None)

    self.gb.generate_global_template(
        coords_in,
        data_index,
        mgh.data_list,
        batch_size=1000,
        data_length=self.data_length,
        factors=factors,
        N=self.band_N_vals[band_inds],
        data_splits=mgh.gpu_splits,
        **waveform_kwargs_tmp,
    )

    return mgh.get_ll(include_psd_info=True)
|
|
1596
|
+
|
|
1597
|
+
# breakpoint()
|
|
1598
|
+
|
|
1599
|
+
# # print(self.accepted / self.num_proposals)
|
|
1600
|
+
|
|
1601
|
+
# # MULTIPLE TRY after RJ
|
|
1602
|
+
|
|
1603
|
+
# if False: # self.is_rj_prop:
|
|
1604
|
+
|
|
1605
|
+
# inds_added = np.where(new_state.branches["gb"].inds.astype(int) - state.branches["gb"].inds.astype(int) == +1)
|
|
1606
|
+
|
|
1607
|
+
# new_coords = xp.asarray(new_state.branches["gb"].coords[inds_added])
|
|
1608
|
+
# temp_inds_add = xp.repeat(xp.arange(ntemps)[:, None], nwalkers * nleaves_max, axis=-1).reshape(ntemps, nwalkers, nleaves_max)[inds_added]
|
|
1609
|
+
# walker_inds_add = xp.repeat(xp.arange(nwalkers)[:, None], ntemps * nleaves_max, axis=-1).reshape(nwalkers, ntemps, nleaves_max).transpose(1, 0, 2)[inds_added]
|
|
1610
|
+
# leaf_inds_add = xp.repeat(xp.arange(nleaves_max)[:, None], ntemps * nwalkers, axis=-1).reshape(nleaves_max, ntemps, nwalkers).transpose(1, 2, 0)[inds_added]
|
|
1611
|
+
# band_inds_add = xp.searchsorted(self.band_edges, new_coords[:, 1] / 1e3, side="right") - 1
|
|
1612
|
+
# N_vals_add = xp.asarray(new_state.branches["gb"].branch_supplimental.holder["N_vals"][inds_added])
|
|
1613
|
+
|
|
1614
|
+
# #### RIGHT NOW WE ARE EXPERIMENTING WITH NO OR SMALL CHANGE IN FREQUENCY
|
|
1615
|
+
# # because if it is added in RJ, it has locked on in frequency in some form
|
|
1616
|
+
|
|
1617
|
+
# # randomize order
|
|
1618
|
+
# inds_random = xp.random.permutation(xp.arange(len(new_coords)))
|
|
1619
|
+
# new_coords = new_coords[inds_random]
|
|
1620
|
+
# temp_inds_add = temp_inds_add[inds_random]
|
|
1621
|
+
# walker_inds_add = walker_inds_add[inds_random]
|
|
1622
|
+
# leaf_inds_add = leaf_inds_add[inds_random]
|
|
1623
|
+
# band_inds_add = band_inds_add[inds_random]
|
|
1624
|
+
# N_vals_add = N_vals_add[inds_random]
|
|
1625
|
+
|
|
1626
|
+
# special_band_map = leaf_inds_add * int(1e12) + walker_inds_add * int(1e6) + band_inds_add
|
|
1627
|
+
# inds_sorted_special = xp.argsort(special_band_map)
|
|
1628
|
+
# special_band_map = special_band_map[inds_sorted_special]
|
|
1629
|
+
# new_coords = new_coords[inds_sorted_special]
|
|
1630
|
+
# temp_inds_add = temp_inds_add[inds_sorted_special]
|
|
1631
|
+
# walker_inds_add = walker_inds_add[inds_sorted_special]
|
|
1632
|
+
# leaf_inds_add = leaf_inds_add[inds_sorted_special]
|
|
1633
|
+
# band_inds_add = band_inds_add[inds_sorted_special]
|
|
1634
|
+
# N_vals_add = N_vals_add[inds_sorted_special]
|
|
1635
|
+
|
|
1636
|
+
# unique_special_bands, unique_special_bands_index, unique_special_bands_inverse = xp.unique(special_band_map, return_index=True, return_inverse=True)
|
|
1637
|
+
# group_index = xp.arange(len(special_band_map)) - xp.arange(len(special_band_map))[unique_special_bands_index][unique_special_bands_inverse]
|
|
1638
|
+
|
|
1639
|
+
# band_splits = 3
|
|
1640
|
+
# num_try = 1000
|
|
1641
|
+
# for group in range(group_index.max().item()):
|
|
1642
|
+
# for split_i in range(band_splits):
|
|
1643
|
+
# group_split = (group_index == group) & (band_inds_add % band_splits == split_i)
|
|
1644
|
+
|
|
1645
|
+
# new_coords_group_split = new_coords[group_split]
|
|
1646
|
+
# temp_inds_add_group_split = temp_inds_add[group_split]
|
|
1647
|
+
# walker_inds_add_group_split = walker_inds_add[group_split]
|
|
1648
|
+
# leaf_inds_add_group_split = leaf_inds_add[group_split]
|
|
1649
|
+
# band_inds_add_group_split = band_inds_add[group_split]
|
|
1650
|
+
# N_vals_add_group_split = N_vals_add[group_split]
|
|
1651
|
+
|
|
1652
|
+
# coords_remove = xp.repeat(new_coords_group_split, num_try, axis=0)
|
|
1653
|
+
# coords_add = coords_remove.copy()
|
|
1654
|
+
|
|
1655
|
+
# inds_params = xp.array([0, 2, 3, 4, 5, 6, 7])
|
|
1656
|
+
# coords_add[:, inds_params] = self.gpu_priors["gb"].rvs(size=coords_remove.shape[0])[:, inds_params]
|
|
1657
|
+
|
|
1658
|
+
# log_proposal_pdf = self.gpu_priors["gb"].logpdf(coords_add)
|
|
1659
|
+
# # remove logpdf from fchange for now
|
|
1660
|
+
# log_proposal_pdf -= self.gpu_priors["gb"].priors_in[1].logpdf(coords_add[:, 1])
|
|
1661
|
+
|
|
1662
|
+
# priors_remove = self.gpu_priors["gb"].logpdf(coords_remove)
|
|
1663
|
+
# priors_add = self.gpu_priors["gb"].logpdf(coords_add)
|
|
1664
|
+
|
|
1665
|
+
# if xp.any(coords_add[:, 1] != coords_remove[:, 1]):
|
|
1666
|
+
# raise NotImplementedError("Assumes frequencies are the same.")
|
|
1667
|
+
# independent = False
|
|
1668
|
+
# else:
|
|
1669
|
+
# independent = True
|
|
1670
|
+
|
|
1671
|
+
# coords_remove_in = self.parameter_transforms.both_transforms(coords_remove, xp=xp)
|
|
1672
|
+
# coords_add_in = self.parameter_transforms.both_transforms(coords_add, xp=xp)
|
|
1673
|
+
|
|
1674
|
+
# waveform_kwargs_tmp = self.waveform_kwargs.copy()
|
|
1675
|
+
# waveform_kwargs_tmp.pop("N")
|
|
1676
|
+
|
|
1677
|
+
# data_index_in = xp.repeat(temp_inds_add_group_split * nwalkers + walker_inds_add_group_split, num_try).astype(xp.int32)
|
|
1678
|
+
# noise_index_in = data_index_in.copy()
|
|
1679
|
+
# N_vals_add_group_split_in = self.xp.repeat(N_vals_add_group_split, num_try)
|
|
1680
|
+
|
|
1681
|
+
# ll_diff = self.xp.asarray(self.gb.swap_likelihood_difference(coords_remove_in, coords_add_in, self.mgh.data_list, self.mgh.psd_list, data_index=data_index_in, noise_index=noise_index_in, N=N_vals_add_group_split_in, data_length=self.data_length, data_splits=self.mgh.gpu_splits, **waveform_kwargs_tmp))
|
|
1682
|
+
|
|
1683
|
+
# ll_diff[self.xp.isnan(ll_diff)] = -1e300
|
|
1684
|
+
|
|
1685
|
+
# band_inv_temps = self.xp.repeat(band_temps[(band_inds_add_group_split, temp_inds_add_group_split)], num_try)
|
|
1686
|
+
|
|
1687
|
+
# logP = band_inv_temps * ll_diff + priors_add
|
|
1688
|
+
|
|
1689
|
+
# from eryn.moves.multipletry import logsumexp, get_mt_computations
|
|
1690
|
+
|
|
1691
|
+
# band_inv_temps = band_inv_temps.reshape(-1, num_try)
|
|
1692
|
+
# ll_diff = ll_diff.reshape(-1, num_try)
|
|
1693
|
+
# logP = logP.reshape(-1, num_try)
|
|
1694
|
+
# priors_add = priors_add.reshape(-1, num_try)
|
|
1695
|
+
# log_proposal_pdf = log_proposal_pdf.reshape(-1, num_try)
|
|
1696
|
+
# coords_add = coords_add.reshape(-1, num_try, ndim)
|
|
1697
|
+
|
|
1698
|
+
# log_importance_weights, log_sum_weights, inds_group_split = get_mt_computations(logP, log_proposal_pdf, symmetric=False, xp=self.xp)
|
|
1699
|
+
|
|
1700
|
+
# inds_tuple = (self.xp.arange(len(inds_group_split)), inds_group_split)
|
|
1701
|
+
|
|
1702
|
+
# ll_diff_out = ll_diff[inds_tuple]
|
|
1703
|
+
# logP_out = logP[inds_tuple]
|
|
1704
|
+
# priors_add_out = priors_add[inds_tuple]
|
|
1705
|
+
# coords_add_out = coords_add[inds_tuple]
|
|
1706
|
+
# log_proposal_pdf_out = log_proposal_pdf[inds_tuple]
|
|
1707
|
+
|
|
1708
|
+
# if not independent:
|
|
1709
|
+
# raise NotImplementedError
|
|
1710
|
+
# else:
|
|
1711
|
+
# aux_coords_add = coords_add.copy()
|
|
1712
|
+
# aux_ll_diff = ll_diff.copy()
|
|
1713
|
+
# aux_priors_add = priors_add.copy()
|
|
1714
|
+
# aux_log_proposal_pdf = log_proposal_pdf.copy()
|
|
1715
|
+
|
|
1716
|
+
# aux_coords_add[:, 0] = new_coords_group_split
|
|
1717
|
+
# aux_ll_diff[:, 0] = 0.0 # diff is zero because the points are already in the data
|
|
1718
|
+
# aux_priors_add[:, 0] = priors_remove[::num_try]
|
|
1719
|
+
|
|
1720
|
+
# initial_log_proposal_pdf = self.gpu_priors["gb"].logpdf(new_coords_group_split)
|
|
1721
|
+
# # remove logpdf from fchange for now
|
|
1722
|
+
# initial_log_proposal_pdf -= self.gpu_priors["gb"].priors_in[1].logpdf(new_coords_group_split[:, 1])
|
|
1723
|
+
|
|
1724
|
+
# aux_log_proposal_pdf[:, 0] = initial_log_proposal_pdf
|
|
1725
|
+
|
|
1726
|
+
# aux_logP = band_inv_temps * aux_ll_diff + aux_priors_add
|
|
1727
|
+
|
|
1728
|
+
# aux_log_importane_weights, aux_log_sum_weights, _ = get_mt_computations(aux_logP, aux_log_proposal_pdf, symmetric=False, xp=self.xp)
|
|
1729
|
+
|
|
1730
|
+
# aux_logP_out = aux_logP[:, 0]
|
|
1731
|
+
# aux_log_proposal_pdf_out = aux_log_proposal_pdf[:, 0]
|
|
1732
|
+
|
|
1733
|
+
# factors = ((aux_logP_out - aux_log_sum_weights)- aux_log_proposal_pdf_out + aux_log_proposal_pdf_out) - ((logP_out - log_sum_weights) - log_proposal_pdf_out + log_proposal_pdf_out)
|
|
1734
|
+
|
|
1735
|
+
# lnpdiff = factors + logP_out - aux_logP_out
|
|
1736
|
+
|
|
1737
|
+
# keep = lnpdiff > self.xp.asarray(self.xp.log(self.xp.random.rand(*logP_out.shape)))
|
|
1738
|
+
|
|
1739
|
+
# coords_remove_keep = new_coords_group_split[keep]
|
|
1740
|
+
# coords_add_keep = coords_add_out[keep]
|
|
1741
|
+
# temp_inds_add_keep = temp_inds_add_group_split[keep]
|
|
1742
|
+
# walker_inds_add_keep = walker_inds_add_group_split[keep]
|
|
1743
|
+
# leaf_inds_add_keep = leaf_inds_add_group_split[keep]
|
|
1744
|
+
# band_inds_add_keep = band_inds_add_group_split[keep]
|
|
1745
|
+
# N_vals_add_keep = N_vals_add_group_split[keep]
|
|
1746
|
+
# ll_diff_keep = ll_diff_out[keep]
|
|
1747
|
+
# priors_add_keep = priors_add_out[keep]
|
|
1748
|
+
# priors_remove_keep = priors_remove[::num_try][keep]
|
|
1749
|
+
|
|
1750
|
+
# # adjust everything
|
|
1751
|
+
# ll_band_diff = xp.zeros((ntemps, nwalkers, len(self.band_edges) - 1))
|
|
1752
|
+
# ll_band_diff[temp_inds_add_keep, walker_inds_add_keep, band_inds_add_keep] = ll_diff_keep
|
|
1753
|
+
# lp_band_diff = xp.zeros((ntemps, nwalkers, len(self.band_edges) - 1))
|
|
1754
|
+
# lp_band_diff[temp_inds_add_keep, walker_inds_add_keep, band_inds_add_keep] = priors_add_keep - priors_remove_keep
|
|
1755
|
+
|
|
1756
|
+
# new_state.branches["gb"].coords[temp_inds_add_keep.get(), walker_inds_add_keep.get(), band_inds_add_keep.get()] = coords_add_keep.get()
|
|
1757
|
+
# new_state.log_like += ll_band_diff.sum(axis=-1).get()
|
|
1758
|
+
# new_state.log_prior += ll_band_diff.sum(axis=-1).get()
|
|
1759
|
+
|
|
1760
|
+
# waveform_kwargs_tmp = self.waveform_kwargs.copy()
|
|
1761
|
+
# waveform_kwargs_tmp.pop("N")
|
|
1762
|
+
# coords_remove_keep_in = self.parameter_transforms.both_transforms(coords_remove_keep, xp=self.xp)
|
|
1763
|
+
# coords_add_keep_in = self.parameter_transforms.both_transforms(coords_add_keep, xp=self.xp)
|
|
1764
|
+
|
|
1765
|
+
# coords_in = xp.concatenate([coords_remove_keep_in, coords_add_keep_in], axis=0)
|
|
1766
|
+
# factors = xp.concatenate([+xp.ones(coords_remove_keep_in.shape[0]), -xp.ones(coords_remove_keep_in.shape[0])])
|
|
1767
|
+
# data_index_tmp = (temp_inds_add_keep * nwalkers + walker_inds_add_keep).astype(xp.int32)
|
|
1768
|
+
# data_index_in = xp.concatenate([data_index_tmp, data_index_tmp], dtype=xp.int32)
|
|
1769
|
+
# N_vals_in = xp.concatenate([N_vals_add_keep, N_vals_add_keep])
|
|
1770
|
+
# self.gb.generate_global_template(
|
|
1771
|
+
# coords_in,
|
|
1772
|
+
# data_index_in,
|
|
1773
|
+
# self.mgh.data_list,
|
|
1774
|
+
# N=N_vals_in,
|
|
1775
|
+
# data_length=self.data_length,
|
|
1776
|
+
# data_splits=self.mgh.gpu_splits,
|
|
1777
|
+
# factors=factors,
|
|
1778
|
+
# **waveform_kwargs_tmp
|
|
1779
|
+
# )
|
|
1780
|
+
# self.xp.cuda.runtime.deviceSynchronize()
|
|
1781
|
+
|
|
1782
|
+
# ll_after = (
|
|
1783
|
+
# self.mgh.get_ll(include_psd_info=True)
|
|
1784
|
+
# .flatten()[new_state.supplimental[:]["overall_inds"]]
|
|
1785
|
+
# .reshape(ntemps, nwalkers)
|
|
1786
|
+
# )
|
|
1787
|
+
# # print(np.abs(new_state.log_like - ll_after).max())
|
|
1788
|
+
# store_max_diff = np.abs(new_state.log_like - ll_after).max()
|
|
1789
|
+
# breakpoint()
|
|
1790
|
+
# self.mempool.free_all_blocks()
|
|
1791
|
+
|
|
1792
|
+
# self.mgh.restore_base_injections()
|
|
1793
|
+
|
|
1794
|
+
# for name in new_state.branches.keys():
|
|
1795
|
+
# if name not in ["gb", "gb"]:
|
|
1796
|
+
# continue
|
|
1797
|
+
# new_state_branch = new_state.branches[name]
|
|
1798
|
+
# coords_here = new_state_branch.coords[new_state_branch.inds]
|
|
1799
|
+
# ntemps, nwalkers, nleaves_max_here, ndim = new_state_branch.shape
|
|
1800
|
+
# try:
|
|
1801
|
+
# group_index = self.xp.asarray(
|
|
1802
|
+
# self.mgh.get_mapped_indices(
|
|
1803
|
+
# np.repeat(
|
|
1804
|
+
# np.arange(ntemps * nwalkers).reshape(
|
|
1805
|
+
# ntemps, nwalkers, 1
|
|
1806
|
+
# ),
|
|
1807
|
+
# nleaves_max,
|
|
1808
|
+
# axis=-1,
|
|
1809
|
+
# )[new_state_branch.inds]
|
|
1810
|
+
# ).astype(self.xp.int32)
|
|
1811
|
+
# )
|
|
1812
|
+
# except IndexError:
|
|
1813
|
+
# breakpoint()
|
|
1814
|
+
# coords_here_in = self.parameter_transforms.both_transforms(
|
|
1815
|
+
# coords_here, xp=np
|
|
1816
|
+
# )
|
|
1817
|
+
|
|
1818
|
+
# waveform_kwargs_fill = self.waveform_kwargs.copy()
|
|
1819
|
+
# waveform_kwargs_fill["start_freq_ind"] = self.start_freq_ind
|
|
1820
|
+
|
|
1821
|
+
# if "N" in waveform_kwargs_fill:
|
|
1822
|
+
# waveform_kwargs_fill.pop("N")
|
|
1823
|
+
|
|
1824
|
+
# self.mgh.multiply_data(-1.0)
|
|
1825
|
+
# self.gb.generate_global_template(
|
|
1826
|
+
# coords_here_in,
|
|
1827
|
+
# group_index,
|
|
1828
|
+
# self.mgh.data_list,
|
|
1829
|
+
# data_length=self.data_length,
|
|
1830
|
+
# data_splits=self.mgh.gpu_splits,
|
|
1831
|
+
# **waveform_kwargs_fill
|
|
1832
|
+
# )
|
|
1833
|
+
# self.xp.cuda.runtime.deviceSynchronize()
|
|
1834
|
+
# self.mgh.multiply_data(-1.0)
|
|
1835
|
+
|
|
1836
|
+
|