lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lisaanalysistools might be problematic.
- lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
- lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
- lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
- lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
- lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
- lisatools/__init__.py +0 -0
- lisatools/_version.py +4 -0
- lisatools/analysiscontainer.py +438 -0
- lisatools/cutils/detector.cpython-312-darwin.so +0 -0
- lisatools/datacontainer.py +292 -0
- lisatools/detector.py +410 -0
- lisatools/diagnostic.py +976 -0
- lisatools/glitch.py +193 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/gbgroupstretch.py +53 -0
- lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
- lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
- lisatools/sampling/moves/gbspecialstretch.py +1836 -0
- lisatools/sampling/moves/mbhspecialmove.py +286 -0
- lisatools/sampling/moves/placeholder.py +16 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/moves/specialforegroundmove.py +564 -0
- lisatools/sampling/prior.py +508 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +324 -0
- lisatools/sensitivity.py +888 -0
- lisatools/sources/__init__.py +0 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/tdiwaveform.py +72 -0
- lisatools/stochastic.py +291 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +40 -0
- lisatools/utils/multigpudataholder.py +730 -0
- lisatools/utils/pointeradjust.py +106 -0
- lisatools/utils/utility.py +240 -0
lisatools/sampling/moves/gbmultipletryrj.py
@@ -0,0 +1,1287 @@
# -*- coding: utf-8 -*-

from copy import deepcopy
from inspect import Attribute
from multiprocessing.sharedctypes import Value
import numpy as np
import warnings
import time
from scipy.special import logsumexp

try:
    import cupy as xp

    gpu_available = True
except ModuleNotFoundError:
    import numpy as xp

    gpu_available = False

from eryn.state import State, BranchSupplimental
from eryn.moves import ReversibleJumpMove, MultipleTryMove
from eryn.prior import ProbDistContainer
from eryn.utils.utility import groups_from_inds

from gbgpu.utils.utility import get_N, get_fdot

from lisatools.sampling.moves.gbspecialstretch import GBSpecialStretchMove

__all__ = ["GBMutlipleTryRJ"]


def shuffle_along_axis(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a, idx, axis=axis)


def searchsorted2d_vec(a, b, xp=None, **kwargs):
    if xp is None:
        xp = np
    m, n = a.shape
    max_num = xp.maximum(a.max() - a.min(), b.max() - b.min()) + 1
    r = max_num * xp.arange(a.shape[0])[:, None]
    p = xp.searchsorted((a + r).ravel(), (b + r).ravel(), **kwargs).reshape(m, -1)
    return p - n * (xp.arange(m)[:, None])
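For reference, a minimal standalone sketch (not part of the package) of what `searchsorted2d_vec` above computes: a row-wise `searchsorted` done in one flattened call by offsetting each row so rows cannot interleave. Values and shapes below are made up for illustration.

# Illustrative only -- a hypothetical check of the row-wise searchsorted trick above.
import numpy as np

a = np.sort(np.random.rand(3, 5), axis=1)   # 3 independently sorted rows
b = np.random.rand(3, 4)                    # query points, row for row

# Offset each row by a gap larger than the data range so a single flattened
# searchsorted cannot mix rows, then undo the offset afterwards.
max_num = np.maximum(a.max() - a.min(), b.max() - b.min()) + 1
r = max_num * np.arange(a.shape[0])[:, None]
p = np.searchsorted((a + r).ravel(), (b + r).ravel()).reshape(3, -1)
p -= a.shape[1] * np.arange(3)[:, None]

# Same answer as looping over rows:
expected = np.stack([np.searchsorted(a[i], b[i]) for i in range(3)])
assert np.array_equal(p, expected)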
class GBMutlipleTryRJ(MultipleTryMove, ReversibleJumpMove, GBSpecialStretchMove):
    """Generate Reversible-Jump proposals for GBs with multiple try

    Will use gpu if template generator uses GPU.

    Args:
        priors (object): :class:`PriorContainer` object that has ``logpdf``
            and ``rvs`` methods.

    """

    def __init__(
        self,
        gb_args,
        gb_kwargs,
        m_chirp_lims,
        *args,
        start_ind_limit=10,
        num_try=1,
        point_generator_func=None,
        fix_change=None,
        **kwargs
    ):
        self.point_generator_func = point_generator_func
        self.fixed_like_diff = 0
        self.time = 0
        self.name = "gbgroupstretch"
        self.start_ind_limit = start_ind_limit
        self.num_try = num_try

        GBSpecialStretchMove.__init__(self, *gb_args, **gb_kwargs)
        ReversibleJumpMove.__init__(self, *args, **kwargs)
        MultipleTryMove.__init__(self, self.num_try, take_max_ll=False, xp=self.xp)

        # setup band edges for priors
        self.band_edges_fdot = self.xp.zeros_like(self.band_edges)
        # lower limit

        self.band_edges_fdot[:-1] = self.xp.asarray(
            get_fdot(self.band_edges[:-1], Mc=m_chirp_lims[0])
        )
        self.band_edges_fdot[1:] = self.xp.asarray(
            get_fdot(self.band_edges[1:], Mc=m_chirp_lims[1])
        )

        self.band_edges_gpu = self.xp.asarray(self.band_edges)
        self.fix_change = fix_change
        if self.fix_change not in [None, +1, -1]:
            raise ValueError("fix_change must be None, +1, or -1.")
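In the constructor above, `band_edges_fdot` brackets the allowed frequency derivative in each band by evaluating `get_fdot` at the band edges for the lower and upper chirp-mass limits. The sketch below reproduces the idea with the standard leading-order relation fdot = (96/5) pi^(8/3) (G Mc / c^3)^(5/3) f^(11/3) standing in for the gbgpu utility; the band layout and chirp-mass limits are made-up values for illustration only.

# Illustrative sketch only: leading-order fdot(f, Mc) used to bound each band,
# standing in for gbgpu.utils.utility.get_fdot. Not shipped with the package.
import numpy as np

G = 6.674e-11    # m^3 kg^-1 s^-2
c = 2.998e8      # m / s
MSUN = 1.989e30  # kg


def fdot_leading_order(f, mc_msun):
    mc = mc_msun * MSUN
    return (
        (96.0 / 5.0)
        * np.pi ** (8.0 / 3.0)
        * (G * mc / c ** 3) ** (5.0 / 3.0)
        * f ** (11.0 / 3.0)
    )


band_edges = np.linspace(1e-4, 1e-2, 101)  # Hz, assumed band layout
m_chirp_lims = (0.05, 1.0)                 # solar masses, assumed prior limits

# mirror the constructor: lower fdot bound from the low chirp mass at the left
# edge, upper fdot bound from the high chirp mass at the right edge
band_edges_fdot = np.zeros_like(band_edges)
band_edges_fdot[:-1] = fdot_leading_order(band_edges[:-1], m_chirp_lims[0])
band_edges_fdot[1:] = fdot_leading_order(band_edges[1:], m_chirp_lims[1])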
    def special_generate_func(
        self,
        coords,
        nwalkers,
        current_priors=None,
        random=None,
        size: int = 1,
        fill=None,
        fill_inds=None,
        band_inds=None,
    ):
        """if self.search_samples is not None:
            # TODO: make replace=True ? in PE
            inds_drawn = self.xp.array([random.choice(
                self.search_inds, size=size, replace=False,
            ) for w in range(nwalkers)])
            generated_points = self.search_samples[inds_drawn].copy()  # .reshape(nwalkers, size, -1)
            # since this is in search only, we will pretend these are coming from the prior
            # so that they are accepted based on the Likelihood (and not the posterior)
            generate_factors = current_priors.logpdf(generated_points.reshape(nwalkers * size, -1)).reshape(nwalkers, size)
        """
        # st = time.perf_counter()
        if band_inds is None:
            raise ValueError("band_inds needs to be set")

        # elif
        if self.point_generator_func is not None:
            # st1 = time.perf_counter()
            generated_points = self.point_generator_func.rvs(size=size * nwalkers)
            """et1 = time.perf_counter()
            print("generate rvs:", et1 - st1)"""

            generated_points[self.xp.isnan(generated_points[:, 0]), 0] = 0.01

            # st2 = time.perf_counter()
            generate_factors = self.point_generator_func.logpdf(generated_points)
            """et2 = time.perf_counter()
            print("generate logpdf:", et2 - st2)"""

            starts = self.band_edges_gpu[band_inds]
            ends = self.band_edges_gpu[band_inds + 1]

            starts_fdot = self.band_edges_fdot[band_inds]
            ends_fdot = self.band_edges_fdot[band_inds + 1]

            # fill before getting logpdf
            generated_points = generated_points.reshape(nwalkers, size, -1)
            generate_factors = generate_factors.reshape(nwalkers, size)

            # map back per band
            generated_points[:, :, 1] = (
                generated_points[:, :, 1] * (ends[:, None] - starts[:, None])
                + starts[:, None]
            ) * 1e3
            generated_points[:, :, 2] = (
                generated_points[:, :, 2] * (ends_fdot[:, None] - starts_fdot[:, None])
                + starts_fdot[:, None]
            )

            if fill is not None or fill_inds is not None:
                if fill is None or fill_inds is None:
                    raise ValueError(
                        "If providing fill_inds or fill, must provide both."
                    )
                generated_points[fill_inds] = fill.copy()

            generated_points = generated_points.reshape(nwalkers, size, -1)

            # logpdf contribution from original distribution is zero = log(1/1)
            # THIS HAS BEEN REMOVED TO SIMULATE A PRIOR HERE THAT IS EQUIVALENT TO THE GLOBAL PRIOR VALUE
            # THE FACT IS THAT THE EFFECTIVE PRIOR HERE WILL BE THE SAME AS THE GENERATING FUNCTION (UP TO THE SNR PRIOR IF THAT IS CHANGED IN THE GENERATING FUNCTION)
            # generate_factors[:] += (self.xp.log(1 / (ends - starts)))[:, None]
            # generate_factors[:] += (self.xp.log(1 / (ends_fdot - starts_fdot)))[:, None]

        else:
            if current_priors is None:
                raise ValueError(
                    "If generating from the prior, must provide current_priors kwargs."
                )

            generated_points = current_priors.rvs(size=nwalkers * size).reshape(
                nwalkers, size, -1
            )
            if fill is not None or fill_inds is not None:
                if fill is None or fill_inds is None:
                    raise ValueError(
                        "If providing fill_inds or fill, must provide both."
                    )
                generated_points[fill_inds[0]] = fill.copy()

            generate_factors = current_priors.logpdf(
                generated_points.reshape(nwalkers * size, -1)
            ).reshape(nwalkers, size)

        """et = time.perf_counter()
        print("GENEARTE:", et - st)"""
        return generated_points, generate_factors
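`special_generate_func` above rescales unit-interval draws from the point generator into the selected band's frequency range (the parameterization stores frequency in mHz, hence the factor of 1e3) and into the corresponding fdot range. A small NumPy sketch of that rescaling, with assumed shapes and band edges:

# Illustrative sketch of the per-band rescaling performed in special_generate_func.
import numpy as np

nwalkers, size, ndim = 4, 3, 8
rng = np.random.default_rng(0)

generated_points = rng.random((nwalkers, size, ndim))  # unit-interval draws
band_inds = rng.integers(0, 10, size=nwalkers)         # one band per walker (assumed)
band_edges = np.linspace(1e-4, 1e-2, 11)               # Hz, assumed layout

starts, ends = band_edges[band_inds], band_edges[band_inds + 1]

# frequency (index 1) is stored in mHz in the sampler's parameterization
generated_points[:, :, 1] = (
    generated_points[:, :, 1] * (ends[:, None] - starts[:, None]) + starts[:, None]
) * 1e3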
    def special_like_func(
        self,
        generated_points,
        base_shape,
        inds_reverse=None,
        old_d_h_d_h=None,
        overall_inds=None,
    ):
        # st = time.perf_counter()
        self.xp.cuda.runtime.setDevice(self.xp.cuda.runtime.getDevice())

        if overall_inds is None:
            raise ValueError("overall_inds is None.")

        # group everything

        # GENERATED POINTS MUST BE PASSED IN by reference not copied
        num_inds_change, nleaves_max, ndim = base_shape
        num_inds_change_gen, num_try, ndim_gen = generated_points.shape
        assert num_inds_change_gen == num_inds_change and ndim == ndim_gen

        if old_d_h_d_h is None:
            raise NotImplementedError
            self.d_h_d_h = d_h_d_h = (
                4 * self.df * self.xp.sum((in_vals.conj() * in_vals) / psd, axis=(1, 2))
            )
        else:
            self.d_h_d_h = d_h_d_h = self.xp.asarray(old_d_h_d_h)

        ll_out = self.xp.zeros((num_inds_change, num_try)).flatten()
        # TODO: take out of loop later?

        phase_marginalize = self.search
        generated_points_here = generated_points.reshape(-1, ndim)

        back_d_d = self.gb.d_d.copy()
        self.gb.d_d = self.xp.repeat(d_h_d_h, self.num_try)

        # do not need mapping because it comes in as overall inds already mapped
        data_index = self.xp.asarray(
            self.xp.repeat(overall_inds, self.num_try).astype(self.xp.int32)
        )

        noise_index = data_index.copy()

        self.data_index_check = data_index.reshape(-1, self.num_try)[:, 0]

        # TODO: Remove batch_size if GPU only ?
        prior_generated_points = generated_points_here

        if self.parameter_transforms is not None:
            prior_generated_points_in = self.parameter_transforms.both_transforms(
                prior_generated_points.copy(), xp=self.xp
            )

        N_temp = self.xp.asarray(
            get_N(
                prior_generated_points_in[:, 0],
                prior_generated_points_in[:, 1],
                self.waveform_kwargs["T"],
                self.waveform_kwargs["oversample"],
            )
        )

        waveform_kwargs_in = self.waveform_kwargs.copy()
        waveform_kwargs_in.pop("N")
        main_gpu = self.xp.cuda.runtime.getDevice()
        # TODO: do search sorted and apply that to nearest found for new points found with group

        # st3 = time.perf_counter()
        ll = self.gb.get_ll(
            prior_generated_points_in,
            self.mgh.data_list,
            self.mgh.psd_list,
            data_index=data_index,
            noise_index=noise_index,
            phase_marginalize=phase_marginalize,
            data_length=self.data_length,
            data_splits=self.mgh.gpu_splits,
            N=N_temp,
            return_cupy=True,
            **waveform_kwargs_in
        )
        self.xp.cuda.runtime.setDevice(main_gpu)
        self.xp.cuda.runtime.deviceSynchronize()

        """et3 = time.perf_counter()
        print("actual like:", et3 - st3)"""
        if self.xp.any(self.xp.isnan(ll)):
            assert self.xp.isnan(ll).sum() < 10
            ll[self.xp.isnan(ll)] = -1e300

        opt_snr = self.xp.sqrt(self.gb.h_h)

        if self.search:
            phase_maximized_snr = (
                self.xp.abs(self.gb.d_h) / self.xp.sqrt(self.gb.h_h)
            ).real.copy()

            phase_change = self.xp.angle(
                self.xp.asarray(self.gb.non_marg_d_h) / self.xp.sqrt(self.gb.h_h.real)
            )

            try:
                phase_maximized_snr = phase_maximized_snr.get()
                phase_change = phase_change.get()
                opt_snr = opt_snr.get()

            except AttributeError:
                pass

            if self.xp.any(self.xp.isnan(prior_generated_points)) or self.xp.any(
                self.xp.isnan(phase_change)
            ):
                breakpoint()

            # adjust for phase change from maximization
            generated_points_here[:, 3] = (
                generated_points_here[:, 3] - phase_change
            ) % (2 * np.pi)

            snr_comp = phase_maximized_snr

        else:
            snr_comp = (self.gb.d_h.real / self.xp.sqrt(self.gb.h_h)).real.copy()
            try:
                snr_comp = snr_comp.get()
                opt_snr = opt_snr.get()

            except AttributeError:
                pass

        snr_comp2 = snr_comp.reshape(-1, self.num_try)
        okay = snr_comp2 >= 1.0
        okay[inds_reverse.get()] = True
        ll[~okay.flatten()] = -1e300

        ##print(opt_snr[snr_comp.argmax()].real, snr_comp.max(), ll[snr_comp.argmax()].real - -1/2 * self.gb.d_d[snr_comp.argmax()].real)
        if self.search and self.search_snr_lim is not None:
            ll[
                (snr_comp < self.search_snr_lim * 0.8)
                | (opt_snr < self.search_snr_lim * self.search_snr_accept_factor)
            ] = -1e300
            """if self.xp.any(~((snr_comp
                < self.search_snr_lim * 0.95)
                | (opt_snr
                < self.search_snr_lim * self.search_snr_accept_factor))):
                breakpoint()"""

        generated_points[:] = generated_points_here.reshape(generated_points.shape)
        ll_out = ll.copy()
        # if inds_reverse is not None and len(inds_reverse) != 0 and split == num_splits - 1:
        #     breakpoint()

        ll_out = ll_out.reshape(num_inds_change, num_try)
        self.old_ll_out_check = ll_out.copy()
        if inds_reverse is not None:
            # try:
            #     tmp_d_h_d_h = d_h_d_h.get()
            # except AttributeError:
            #     tmp_d_h_d_h = d_h_d_h

            # this is special to GBs
            self.special_aux_ll = -ll_out[inds_reverse, 0]

            self.check_h_h = self.gb.h_h.reshape(-1, num_try)[inds_reverse]
            self.check_d_h = self.gb.d_h.reshape(-1, num_try)[inds_reverse]
            ll_out[inds_reverse, :] += self.special_aux_ll[:, None]

        # return gb.d_d
        self.gb.d_d = back_d_d.copy()
        # add noise term
        self.xp.cuda.runtime.deviceSynchronize()
        """et = time.perf_counter()
        print("LIKE:", et - st)"""
        return ll_out  # + self.noise_ll[:, None]
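In search mode, `special_like_func` above phase-maximizes: the statistic is |<d,h>| / sqrt(<h,h>) and the maximizing phase rotation is arg(<d,h>), which is then subtracted from the proposed initial phase. A toy sketch of that maximization for a single complex inner product, with entirely synthetic inputs:

# Toy sketch of the phase maximization used above (single complex <d,h>).
import numpy as np

rng = np.random.default_rng(1)
h = rng.normal(size=256) + 1j * rng.normal(size=256)  # template, arbitrary units
d = np.exp(1j * 0.7) * h + 0.1 * (rng.normal(size=256) + 1j * rng.normal(size=256))

d_h = np.vdot(h, d)        # <h|d>, conjugates the first argument
h_h = np.vdot(h, h).real   # <h|h>

snr_unmaximized = d_h.real / np.sqrt(h_h)
snr_phase_maximized = np.abs(d_h) / np.sqrt(h_h)
phase_change = np.angle(d_h)  # rotating the template by this phase recovers the maximum

assert snr_phase_maximized >= snr_unmaximized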
    def special_prior_func(
        self, generated_points, base_shape, inds_reverse=None, **kwargs
    ):
        # st = time.perf_counter()
        nwalkers, nleaves_max, ndim = base_shape
        # st2 = time.perf_counter()
        lp_new = (
            self.gpu_priors["gb"]
            .logpdf(generated_points.reshape(-1, 8))
            .reshape(nwalkers, self.num_try)
        )
        """et2 = time.perf_counter()
        print("prior logpdf:", et2 - st2)"""
        lp_total = lp_new  # _old[:, None] + lp_new

        self.old_lp_total_check = lp_total.copy()
        if inds_reverse is not None:
            # this is special to GBs
            self.special_aux_lp = -lp_total[inds_reverse, 0]

            lp_total[inds_reverse, :] += self.special_aux_lp[:, None]

        # add noise lp
        """et = time.perf_counter()
        print("PRIOR:", et - st)"""
        return lp_total

    def readout_adjustment(self, out_vals, all_vals_prop, aux_all_vals, inds_reverse):
        self.out_vals, self.all_vals_prop, self.aux_all_vals, self.inds_reverse = (
            out_vals,
            all_vals_prop,
            aux_all_vals,
            inds_reverse,
        )

        (
            self.logP_out,
            self.ll_out,
            self.lp_out,
            self.log_proposal_pdf_out,
            self.log_sum_weights,
        ) = out_vals

        self.ll_out[inds_reverse] = self.special_aux_ll
        self.lp_out[inds_reverse] = self.special_aux_lp
    def get_proposal(
        self,
        gb_coords,
        inds,
        changes,
        leaf_inds_for_changes,
        band_inds,
        random,
        supps=None,
        branch_supps=None,
    ):
        """Make a proposal

        Args:
            all_coords (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max, ndim]. These are the current
                coordinates for all the walkers.
            all_inds (dict): Keys are ``branch_names``. Values are
                np.ndarray[ntemps, nwalkers, nleaves_max]. These are the boolean
                arrays marking which leaves are currently used within each walker.
            all_inds_for_change (dict): Keys are ``branch_names``. Values are
                dictionaries. These dictionaries have keys ``"+1"`` and ``"-1"``,
                indicating walkers that are adding or removing a leaf, respectively.
                The values for these dicts are ``int`` np.ndarray[..., 3]. The "..." indicates
                the number of walkers in all temperatures that fall under either adding
                or removing a leaf. The second dimension, 3, is the indexes into
                the three-dimensional arrays within ``all_inds`` of the specific leaf
                that is being added or removed from those leaves currently considered.
            random (object): Current random state of the sampler.

        Returns:
            tuple: Tuple containing proposal information.
                First entry is the new coordinates as a dictionary with keys
                as ``branch_names`` and values as
                ``double`` np.ndarray[ntemps, nwalkers, nleaves_max, ndim] containing
                proposed coordinates. Second entry is the new ``inds`` array with
                boolean values flipped for added or removed sources. Third entry
                is the factors associated with the
                proposal necessary for detailed balance. This is effectively
                any term in the detailed balance fraction. +log of factors if
                in the numerator. -log of factors if in the denominator.

        """

        name = "gb"

        ntemps, nwalkers, nleaves_max, ndim = gb_coords.shape

        # adjust inds
        changes_gpu = self.xp.asarray(changes)
        # TODO: remove xp.asarrays here if they come in as GPU arrays
        inds_reverse_band = self.xp.where(changes_gpu == -1)

        leaf_inds_for_changes_gpu = self.xp.asarray(leaf_inds_for_changes)

        inds_reverse = inds_reverse_band[:2] + (
            leaf_inds_for_changes_gpu[inds_reverse_band],
        )
        inds_reverse_cpu = tuple([ir.get() for ir in list(inds_reverse)])
        inds_reverse_in = self.xp.where(changes_gpu.flatten() == -1)[0]
        inds_reverse_individual = leaf_inds_for_changes_gpu[changes_gpu < 0]

        inds_forward_band = self.xp.where(changes_gpu == +1)

        inds_forward = inds_forward_band[:2] + (
            leaf_inds_for_changes_gpu[inds_forward_band],
        )
        # inds_forward_in = self.xp.where(changes.flatten() == +1)[0]
        # inds_forward_individual = leaf_inds_for_changes_gpu[changes < 0]

        # add coordinates for new leaves
        current_priors = self.priors[name]

        inds_here_band = self.xp.where((changes_gpu == -1) | (changes_gpu == +1))

        inds_here = inds_here_band[:2] + (leaf_inds_for_changes_gpu[inds_here_band],)
        inds_here_cpu = tuple([ir.get() for ir in list(inds_here)])

        # allows for adjustment during search
        reset_num_try = False
        if len(inds_here[0]) != 0:
            if self.search and self.search_samples is not None:
                raise NotImplementedError
                if len(self.search_inds) >= self.num_try:
                    reset_num_try = True
                    old_num_try = self.num_try
                    self.num_try = len(self.search_inds)

            num_inds_change = len(inds_here[0])

            if self.provide_betas:  # and not self.search:
                betas = self.xp.asarray(self.temperature_control.betas)[inds_here[0]]
            else:
                betas = None

            ll_here = self.xp.asarray(self.current_state.log_like.copy())[inds_here[:2]]
            lp_here = self.xp.asarray(self.current_state.log_prior)[inds_here[:2]]

            self.lp_old = lp_here

            rj_info = dict(
                ll=self.xp.zeros_like(ll_here), lp=self.xp.zeros_like(lp_here)
            )

            N_vals = branch_supps.holder["N_vals"]
            N_vals = self.xp.asarray(N_vals)

            self.inds_reverse = inds_reverse
            self.inds_forward = inds_forward

            if len(inds_reverse[0]) > 0:
                parameters_remove = self.parameter_transforms.both_transforms(
                    gb_coords[inds_reverse_cpu], xp=self.xp
                )

                group_index_tmp = inds_reverse[0] * nwalkers + inds_reverse[1]
                N_vals_in = self.xp.asarray(N_vals[inds_reverse])
                group_index = self.xp.asarray(
                    self.mgh.get_mapped_indices(group_index_tmp).astype(self.xp.int32)
                )

                waveform_kwargs_add = self.waveform_kwargs.copy()
                waveform_kwargs_add.pop("N")
                # removing these so d - (h - r) = d - h + r

                try:
                    self.xp.cuda.runtime.deviceSynchronize()
                    # parameters_remove[:, 0] *= 1e-30
                    self.gb.generate_global_template(
                        parameters_remove,
                        group_index,
                        self.mgh.data_list,
                        N=N_vals_in,
                        data_length=self.data_length,
                        data_splits=self.mgh.gpu_splits,
                        **waveform_kwargs_add
                    )
                except ValueError as e:
                    print(e)
                    breakpoint()

            # self.checkit4 = (-1/2 * self.df * 4 * self.xp.sum(data_minus_template.conj() * data_minus_template / self.psd, axis=(2,3))) + self.xp.asarray(self.noise_ll).reshape(ntemps, nwalkers)

            overall_inds = supps.holder["overall_inds"][inds_here_cpu[:2]]
            base_shape = (len(inds_here[0]), 1, ndim)
            assert np.prod(band_inds.shape) == inds_here[0].shape[0]
            old_d_h_d_h = self.xp.zeros_like(ll_here)

            generate_points_out, logP_out, factors_out = self.get_mt_proposal(
                self.xp.asarray(gb_coords[inds_here_cpu]),
                len(inds_here[0]),
                inds_reverse_in,
                random,
                args_prior=(base_shape, inds_reverse_in),
                kwargs_generate={
                    "current_priors": current_priors,
                    "band_inds": band_inds.flatten(),
                },
                args_like=(base_shape,),
                kwargs_like={
                    "inds_reverse": inds_reverse_in,
                    "old_d_h_d_h": old_d_h_d_h,
                    "overall_inds": overall_inds,
                },
                betas=betas,
                rj_info=rj_info,
            )

            # gb_coords[inds_forward] = generate_points_out.copy()

            # TODO: make sure detailed balance this will move to detailed balance in multiple try

            self.logP_out = logP_out

            return (
                generate_points_out,
                self.ll_out,
                self.lp_out,
                factors_out,
                betas,
                inds_here,
            )

        else:
            breakpoint()
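The `propose` method below tags every active source with a single integer key, int(1e12) * temperature + int(1e6) * walker + band, shuffles, and keeps the first occurrence of each key, so that one uniformly random source per (temperature, walker, band) is put forward for a possible removal. A standalone sketch of that shuffle-then-unique selection, with made-up inputs:

# Standalone sketch of the shuffle-then-unique selection used in propose below.
import numpy as np

rng = np.random.default_rng(2)
n_sources = 20
temp = rng.integers(0, 2, n_sources)
walker = rng.integers(0, 4, n_sources)
band = rng.integers(0, 5, n_sources)

# one integer key per (temperature, walker, band) combination
keys = (int(1e12) * temp + int(1e6) * walker + band).astype(np.int64)

order = np.arange(n_sources)
rng.shuffle(order)            # random order within each key
shuffled_keys = keys[order]

# np.unique returns the index of the first occurrence of each key in the
# shuffled array, i.e. a uniformly random representative source per key.
unique_keys, first_index = np.unique(shuffled_keys, return_index=True)
selected_sources = order[first_index]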
    def propose(self, model, state):
        """Use the move to generate a proposal and compute the acceptance

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information.
            state (:class:`State`): Current state of the sampler.

        Returns:
            :class:`State`: State of sampler after proposal is complete.

        """

        st = time.perf_counter()
        # TODO: keep this?
        # this exposes anywhere in the proposal class to this information
        self.current_state = state
        self.current_model = model

        # Run any move-specific setup.
        self.setup(state.branches)

        # ll_before = model.compute_log_like_fn(state.branches_coords, inds=state.branches_inds, supps=state.supplimental, branch_supps=state.branches_supplimental)

        # if not np.allclose(ll_before[0], state.log_like):
        #     breakpoint()

        new_state = State(state, copy=True)
        all_branch_names = list(new_state.branches_coords.keys())

        if np.any(
            new_state.branches_supplimental["gb"].holder["N_vals"][
                new_state.branches_inds["gb"]
            ]
            == 0
        ):
            breakpoint()

        self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()

        ntemps, nwalkers, _, _ = new_state.branches[
            list(new_state.branches.keys())[0]
        ].shape

        num_consecutive_rj_moves = 1
        # print("starting")

        for rj_move_i in range(num_consecutive_rj_moves):
            # st = time.perf_counter()
            accepted = np.zeros((ntemps, nwalkers), dtype=bool)

            coords_propose_in = self.xp.asarray(new_state.branches_coords["gb"])
            inds_propose_in = self.xp.asarray(new_state.branches_inds["gb"])
            branches_supp_propose_in = new_state.branches_supplimental["gb"]
            remaining_coords = coords_propose_in[inds_propose_in]
            f0 = remaining_coords[:, 1] / 1e3
            inds_into_full_f0 = self.xp.arange(f0.shape[0])

            band_inds = self.xp.searchsorted(self.band_edges, f0) - 1
            temp_inds = self.xp.repeat(
                self.xp.arange(ntemps)[:, None],
                np.prod(coords_propose_in.shape[1:3]),
                axis=1,
            ).reshape(ntemps, nwalkers, -1)[inds_propose_in]
            walkers_inds = self.xp.repeat(
                self.xp.repeat(self.xp.arange(nwalkers)[None, :], ntemps, axis=0)[
                    :, :, None
                ],
                coords_propose_in.shape[2],
                axis=2,
            )[inds_propose_in]
            leaf_inds = self.xp.repeat(
                self.xp.arange(coords_propose_in.shape[2])[None, :],
                ntemps * nwalkers,
                axis=0,
            ).reshape(ntemps, nwalkers, -1)[inds_propose_in]

            # TODO: make this weighted somehow to not focus on empty bands
            # we are already leaving out the last one
            bands_per_walker = self.xp.tile(
                (self.xp.arange(self.num_bands - 1)), (ntemps, nwalkers, 1)
            )

            """et = time.perf_counter()
            print("initial setup", et - st)"""
            # st = time.perf_counter()

            # odds & evens (evens first)
            for i in range(2):
                # st = time.perf_counter()
                bands_per_walker_here = bands_per_walker[
                    bands_per_walker % 2 == i
                ].reshape(ntemps, nwalkers, -1)
                max_band_ind = bands_per_walker_here.max().item()
                band_inds_here = (
                    int(1e12) * temp_inds[band_inds % 2 == i]
                    + int(1e6) * walkers_inds[band_inds % 2 == i]
                    + band_inds[band_inds % 2 == i]
                )

                inds_into_full_f0_here = inds_into_full_f0[band_inds.flatten() % 2 == i]

                inds_for_shuffle = self.xp.arange(len(inds_into_full_f0_here))
                # shuffle allows the first unique index to be the one that is dropped
                self.xp.random.shuffle(inds_for_shuffle)
                bands_inds_here_tmp = band_inds_here[inds_for_shuffle]

                assert band_inds_here.shape == inds_into_full_f0_here.shape

                # shuffle and then take first index from band
                # is a good way to get random binary from that band
                (
                    unique_bands_inds_here,
                    unique_bands_inds_here_index,
                    unique_bands_inds_here_inverse,
                    unique_bands_inds_here_counts,
                ) = self.xp.unique(
                    bands_inds_here_tmp,
                    return_index=True,
                    return_counts=True,
                    return_inverse=True,
                )

                # selected_inds_from_shuffle is indices reverted back to
                # inds_into_full_f0_here after shuffle
                selected_inds_from_shuffle = inds_for_shuffle[
                    unique_bands_inds_here_index
                ]

                # frequencies of the first instance found in band according to self.xp.unique call
                # and what bands they belong to
                inds_into_full_f0_here_selected = inds_into_full_f0_here[
                    selected_inds_from_shuffle
                ]
                bands_inds_here_selected = band_inds_here[selected_inds_from_shuffle]

                temp_index_count = temp_inds[inds_into_full_f0_here_selected]
                walkers_index_count = walkers_inds[inds_into_full_f0_here_selected]
                leaf_index_count = leaf_inds[inds_into_full_f0_here_selected]

                # get band index for each selected source by maths
                band_index_count = self.xp.floor(
                    (
                        unique_bands_inds_here
                        - temp_index_count * int(1e12)
                        - walkers_index_count * int(1e6)
                    )
                ).astype(int)
                band_index_tmp = self.xp.searchsorted(
                    bands_per_walker_here[0, 0], band_index_count, side="left"
                )

                nleaves_here = self.xp.zeros_like(bands_per_walker_here)

                # how many leaves are in each band
                try:
                    nleaves_here[
                        (temp_index_count, walkers_index_count, band_index_tmp)
                    ] = unique_bands_inds_here_counts
                except IndexError:
                    breakpoint()

                # setup leaf count change arrays
                if self.fix_change is None:
                    changes = self.xp.random.choice(
                        [-1, 1], size=bands_per_walker_here.shape
                    )
                else:
                    changes = self.xp.full(bands_per_walker_here.shape, self.fix_change)

                # make sure to add a binary if there are None
                changes[nleaves_here == 0] = +1

                if self.xp.any(
                    nleaves_here >= self.max_k[all_branch_names.index("gb")]
                ):
                    raise ValueError("nleaves_here higher than max_k.")

                # number of sub bands
                num_sub_bands_here = bands_per_walker_here.shape[-1]

                # build arrays to help properly index and locate within these sub bands
                temp_inds_for_change = self.xp.repeat(
                    self.xp.arange(ntemps)[:, None],
                    nwalkers * num_sub_bands_here,
                    axis=-1,
                ).reshape(ntemps, nwalkers, num_sub_bands_here)
                walker_inds_for_change = self.xp.tile(
                    self.xp.arange(nwalkers), (ntemps, num_sub_bands_here, 1)
                ).transpose((0, 2, 1))
                band_inds_for_change = self.xp.tile(
                    self.xp.arange(num_sub_bands_here), (ntemps, nwalkers, 1)
                )

                # leaves will be filled later
                leaf_inds_for_change = -self.xp.ones(
                    (ntemps, nwalkers, num_sub_bands_here), dtype=int
                )

                # add leaves that would be removed if count is not currently zero
                # will adjust for count soon
                # len(leaf_index_count) < total number of proposals
                leaf_inds_for_change[
                    (temp_index_count, walkers_index_count, band_index_tmp)
                ] = leaf_index_count

                # band_inds_for_change[(temp_index_count, walkers_index_count, band_index_tmp)] = band_index_tmp

                leaf_inds_tmp = self.xp.repeat(
                    self.xp.arange(coords_propose_in.shape[2])[None, :],
                    ntemps * nwalkers,
                    axis=0,
                ).reshape(ntemps, nwalkers, -1)

                # any binary that exists give fake high value to remove from sort
                leaf_inds_tmp[inds_propose_in] = int(1e7)

                # this gives unused leaves, you take the number you need for each walker / temperature
                leaf_inds_tmp_2 = self.xp.sort(leaf_inds_tmp, axis=-1)[
                    :, :, :num_sub_bands_here
                ]
                leaf_inds_for_change[changes > 0] = leaf_inds_tmp_2[changes > 0]

                assert not self.xp.any(leaf_inds_for_change == -1)

                if self.xp.any(leaf_inds_tmp_2 == int(1e7)):
                    raise ValueError(
                        "Not enough spots to allocate for new binaries. Need to increase max leaves."
                    )

                # TODO: check that number of available spots is high enough

                """et = time.perf_counter()
                print("start", et - st)
                st = time.perf_counter()"""
                # propose new sources and coordinates
                (
                    new_coords,
                    ll_out,
                    lp_out,
                    factors,
                    betas,
                    inds_here,
                ) = self.get_proposal(
                    coords_propose_in,
                    inds_propose_in,
                    changes,
                    leaf_inds_for_change,
                    bands_per_walker_here,
                    model.random,
                    branch_supps=branches_supp_propose_in,
                    supps=new_state.supplimental,
                )

                """et = time.perf_counter()
                print("proposal", et - st)
                st = time.perf_counter()"""

                # TODO: check this
                edge_factors = self.xp.zeros_like(factors)
                # get factors for edges
                min_k = self.min_k[all_branch_names.index("gb")]
                max_k = self.max_k[all_branch_names.index("gb")]

                # fix proposal asymmetry at bottom of k range
                inds_min = self.xp.where(nleaves_here.flatten() == min_k)
                # numerator term so +ln
                edge_factors[inds_min] += np.log(1 / 2.0)

                # fix proposal asymmetry at top of k range
                inds_max = self.xp.where(nleaves_here.flatten() == max_k)
                # numerator term so -ln
                edge_factors[inds_max] += np.log(1 / 2.0)

                # fix proposal asymmetry at bottom of k range (kmin + 1)
                inds_min = self.xp.where(nleaves_here.flatten() == min_k + 1)
                # numerator term so +ln
                edge_factors[inds_min] -= np.log(1 / 2.0)

                # fix proposal asymmetry at top of k range (kmax - 1)
                inds_max = self.xp.where(nleaves_here.flatten() == max_k - 1)
                # numerator term so -ln
                edge_factors[inds_max] -= np.log(1 / 2.0)

                factors += edge_factors

                prev_logl = self.xp.asarray(new_state.log_like)[inds_here[:2]]

                prev_logp = self.xp.asarray(new_state.log_prior)[inds_here[:2]]

                """et = time.perf_counter()
                print("prior", et - st)
                st = time.perf_counter()"""
                logl = prev_logl + ll_out

                logp = prev_logp + lp_out
                # loglcheck, new_blobs = model.compute_log_like_fn(q, inds=new_inds, logp=logp, supps=new_supps, branch_supps=new_branch_supps)
                # if not self.xp.all(self.xp.abs(logl[logl != -1e300] - loglcheck[logl != -1e300]) < 1e-5):
                #     breakpoint()

                logP = self.compute_log_posterior(logl, logp, betas=betas)

                # TODO: check about prior = - inf
                # takes care of tempering
                prev_logP = self.compute_log_posterior(
                    prev_logl, prev_logp, betas=betas
                )

                # TODO: fix this
                # this is where _metropolisk should come in
                lnpdiff = factors + logP - prev_logP

                accepted = lnpdiff > self.xp.log(
                    self.xp.asarray(model.random.rand(*lnpdiff.shape))
                )

                """et = time.perf_counter()
                print("through accepted", et - st)
                st = time.perf_counter()"""

                # bookkeeping

                inds_reverse = self.xp.where(changes.flatten() == -1)[0]

                # adjust births from False -> True
                inds_forward = self.xp.where(changes.flatten() == +1)[0]

                # accepted_keep = (temp_inds_for_change.flatten()[inds_forward] == 0) & (walker_inds_for_change.flatten()[inds_forward] == 1)

                """accepted_keep_tmp = self.xp.where(accepted_keep)[0]
                accepted_keep[accepted_keep_tmp[:]] = False
                accepted_keep[accepted_keep_tmp[164679]] = True"""

                # accepted[:] = False
                # accepted[inds_forward] = True
                """accepted[inds_forward[:]] = False
                accepted[inds_reverse[0]] = False"""

                # not accepted removals
                accepted_reverse = accepted[inds_reverse]
                accepted_inds_reverse = inds_reverse[accepted_reverse]
                not_accepted_inds_reverse = inds_reverse[~accepted_reverse]

                tuple_not_accepted_reverse = (
                    temp_inds_for_change.flatten()[not_accepted_inds_reverse],
                    walker_inds_for_change.flatten()[not_accepted_inds_reverse],
                    leaf_inds_for_change.flatten()[not_accepted_inds_reverse],
                )

                tuple_accepted_reverse = (
                    temp_inds_for_change.flatten()[accepted_inds_reverse],
                    walker_inds_for_change.flatten()[accepted_inds_reverse],
                    leaf_inds_for_change.flatten()[accepted_inds_reverse],
                )

                if len(not_accepted_inds_reverse) > 0:
                    points_not_accepted_removal = coords_propose_in[
                        tuple_not_accepted_reverse
                    ]

                else:
                    points_not_accepted_removal = self.xp.empty((0, 8))

                tuple_accepted_reverse_cpu = tuple(
                    [tmp.get() for tmp in list(tuple_accepted_reverse)]
                )
                # accepted removals
                new_state.branches["gb"].inds[tuple_accepted_reverse_cpu] = False
                delta_logl_trans1 = self.xp.zeros_like(
                    leaf_inds_for_change, dtype=float
                )
                delta_logl_trans1[
                    (
                        temp_inds_for_change.flatten()[accepted_inds_reverse],
                        walker_inds_for_change.flatten()[accepted_inds_reverse],
                        band_inds_for_change.flatten()[accepted_inds_reverse],
                    )
                ] = (
                    logl[accepted_inds_reverse] - prev_logl[accepted_inds_reverse]
                )

                new_state.log_like += delta_logl_trans1.sum(axis=-1).get()

                # accepted update logp
                delta_logp_trans = self.xp.zeros_like(leaf_inds_for_change, dtype=float)
                delta_logp_trans[
                    (
                        temp_inds_for_change.flatten()[accepted_inds_reverse],
                        walker_inds_for_change.flatten()[accepted_inds_reverse],
                        band_inds_for_change.flatten()[accepted_inds_reverse],
                    )
                ] = (
                    logp[accepted_inds_reverse] - prev_logp[accepted_inds_reverse]
                )

                new_state.log_prior += delta_logp_trans.sum(axis=-1).get()

                # accepted additions
                accepted_forward = accepted[inds_forward]
                accepted_inds_forward = inds_forward[accepted_forward]
                not_accepted_inds_forward = inds_forward[~accepted_forward]

                tuple_accepted_forward = (
                    temp_inds_for_change.flatten()[accepted_inds_forward],
                    walker_inds_for_change.flatten()[accepted_inds_forward],
                    leaf_inds_for_change.flatten()[accepted_inds_forward],
                )

                tuple_accepted_forward_cpu = tuple(
                    [tmp.get() for tmp in list(tuple_accepted_forward)]
                )
                new_state.branches["gb"].inds[tuple_accepted_forward_cpu] = True
                new_state.branches["gb"].coords[
                    tuple_accepted_forward_cpu
                ] = new_coords[accepted_forward].get()

                """et = time.perf_counter()
                print("bookkeeping", et - st)
                st = time.perf_counter()"""

                if len(accepted_inds_forward) > 0:
                    points_accepted_addition = new_coords[accepted_forward]

                    # get group friend finder information
                    f0_accepted_addition = -100 * self.xp.ones_like(
                        leaf_inds_for_change, dtype=float
                    )
                    f0_accepted_addition[
                        (
                            temp_inds_for_change.flatten()[accepted_inds_forward],
                            walker_inds_for_change.flatten()[accepted_inds_forward],
                            band_inds_for_change.flatten()[accepted_inds_forward],
                        )
                    ] = (
                        points_accepted_addition[:, 1] / 1e3
                    )

                    for t in range(ntemps):
                        # use old state to get supp information
                        f0_old = (
                            self.xp.asarray(
                                state.branches["gb"].coords[
                                    t, state.branches["gb"].inds[t]
                                ][:, 1]
                            )
                            / 1e3
                        )
                        friend_start_inds = self.xp.asarray(
                            state.branches["gb"].branch_supplimental.holder[
                                "friend_start_inds"
                            ][t, state.branches["gb"].inds[t]]
                        )

                        f0_old_sorted = self.xp.sort(f0_old, axis=-1)
                        inds_f0_old_sorted = self.xp.argsort(f0_old, axis=-1)
                        f0_accepted_addition_in1 = f0_accepted_addition[t]
                        f0_accepted_addition_in2 = f0_accepted_addition_in1[
                            f0_accepted_addition_in1 > -1.0
                        ]
                        temp_inds_in = temp_inds_for_change[t][
                            f0_accepted_addition_in1 > -1.0
                        ]
                        walker_inds_in = walker_inds_for_change[t][
                            f0_accepted_addition_in1 > -1.0
                        ]
                        leaf_inds_in = leaf_inds_for_change[t][
                            f0_accepted_addition_in1 > -1.0
                        ]

                        inds_f0_accepted_addition = (
                            self.xp.searchsorted(
                                f0_old_sorted, f0_accepted_addition_in2, side="right"
                            )
                            - 1
                        )

                        old_inds_f0_old = inds_f0_old_sorted[inds_f0_accepted_addition]
                        comp_f0_old = f0_old[old_inds_f0_old]
                        new_friends_start_inds = friend_start_inds[old_inds_f0_old]

                        # TODO: maybe check this
                        new_state.branches["gb"].branch_supplimental.holder[
                            "friend_start_inds"
                        ][
                            (
                                temp_inds_in.get(),
                                walker_inds_in.get(),
                                leaf_inds_in.get(),
                            )
                        ] = new_friends_start_inds.get()

                        """inds_into_old_array = self.xp.take_along_axis(inds_f0_old_sorted, inds_f0_accepted_addition, axis=-1).reshape(ntemps, nwalkers, -1)

                        new_start_inds = self.xp.take_along_axis(state.branches["gb"].branch_supplimental.holder["friend_start_inds"], inds_into_old_array, axis=-1)

                        keep_new_start_inds = new_start_inds[f0_accepted_addition.reshape(ntemps, nwalkers, -1) > -10.0]
                        breakpoint()
                        state.branches["gb"].branch_supplimental.holder["friend_start_inds"][(

                        )] = keep_new_start_inds

                        inds_into_old_array
                        keep_inds_f0_accepted_addition = inds_f0_accepted_addition > 0



                        inds_into_old_array = inds_f0_old_sorted[inds_f0_accepted_addition]"""

                else:
                    points_accepted_addition = self.xp.empty((0, 8))

                delta_logl_trans2 = self.xp.zeros_like(
                    leaf_inds_for_change, dtype=float
                )
                delta_logl_trans2[
                    (
                        temp_inds_for_change.flatten()[accepted_inds_forward],
                        walker_inds_for_change.flatten()[accepted_inds_forward],
                        band_inds_for_change.flatten()[accepted_inds_forward],
                    )
                ] = (
                    logl[accepted_inds_forward] - prev_logl[accepted_inds_forward]
                )

                new_state.log_like += delta_logl_trans2.sum(axis=-1).get()

                delta_logp_trans = self.xp.zeros_like(leaf_inds_for_change, dtype=float)
                delta_logp_trans[
                    (
                        temp_inds_for_change.flatten()[accepted_inds_forward],
                        walker_inds_for_change.flatten()[accepted_inds_forward],
                        band_inds_for_change.flatten()[accepted_inds_forward],
                    )
                ] = (
                    logp[accepted_inds_forward] - prev_logp[accepted_inds_forward]
                )

                new_state.log_prior += delta_logp_trans.sum(axis=-1).get()

                accepted_counts = self.xp.zeros_like(leaf_inds_for_change, dtype=bool)
                num_proposals_here = leaf_inds_for_change.shape[-1]

                accepted_counts[
                    (
                        temp_inds_for_change.flatten()[accepted],
                        walker_inds_for_change.flatten()[accepted],
                        band_inds_for_change.flatten()[accepted],
                    )
                ] = True

                accepted_overall = accepted_counts.sum(axis=-1)

                # TODO: if we do not run every band, we need to adjust this
                self.accepted += accepted_overall.get()
                self.num_proposals += num_proposals_here

                points_to_add_to_template = self.xp.concatenate(
                    [
                        points_accepted_addition,
                        points_not_accepted_removal,  # were removed at the start, need to be put back
                    ],
                    axis=0,
                )

                """et = time.perf_counter()
                print("before add", et - st)
                st = time.perf_counter()"""

                if points_to_add_to_template.shape[0] > 0:
                    points_to_add_to_template_in = (
                        self.parameter_transforms.both_transforms(
                            points_to_add_to_template, xp=self.xp
                        )
                    )
                    N_temp = get_N(
                        points_to_add_to_template_in[:, 0],
                        points_to_add_to_template_in[:, 1],
                        self.waveform_kwargs["T"],
                        self.waveform_kwargs["oversample"],
                    )

                    # has to be accepted forward first (same as points_to_add_to_template)
                    inds_here_for_add = tuple(
                        [
                            self.xp.concatenate(
                                [
                                    tuple_accepted_forward[jjj],
                                    tuple_not_accepted_reverse[jjj],
                                ]
                            )
                            for jjj in range(len(tuple_not_accepted_reverse))
                        ]
                    )

                    inds_here_for_add_cpu = tuple(
                        [tmp.get() for tmp in list(inds_here_for_add)]
                    )

                    # inds_here_for_add = tuple_accepted_forward

                    new_state.branches["gb"].branch_supplimental.holder["N_vals"][
                        inds_here_for_add_cpu
                    ] = N_temp.get()

                    group_index_tmp_accepted_forward = (
                        tuple_accepted_forward[0] * nwalkers + tuple_accepted_forward[1]
                    )
                    group_index_tmp_not_accepted_reverse = (
                        tuple_not_accepted_reverse[0] * nwalkers
                        + tuple_not_accepted_reverse[1]
                    )

                    group_index_tmp = self.xp.concatenate(
                        [
                            group_index_tmp_accepted_forward,
                            group_index_tmp_not_accepted_reverse,
                        ]
                    )
                    group_index = self.xp.asarray(
                        self.mgh.get_mapped_indices(group_index_tmp).astype(
                            self.xp.int32
                        )
                    )

                    # adding these so d - (h + a) = d - h - a
                    factors_multiply = -self.xp.ones_like(N_temp, dtype=self.xp.float64)
                    waveform_kwargs_add = self.waveform_kwargs.copy()
                    waveform_kwargs_add.pop("N")

                    # points_to_add_to_template_in[1:, 0] *= 1e-30

                    self.gb.generate_global_template(
                        points_to_add_to_template_in,
                        group_index,
                        self.mgh.data_list,
                        N=N_temp,
                        data_length=self.data_length,
                        data_splits=self.mgh.gpu_splits,
                        factors=factors_multiply,
                        **waveform_kwargs_add
                    )

                if np.any(
                    new_state.branches_supplimental["gb"].holder["N_vals"][
                        new_state.branches_inds["gb"]
                    ]
                    == 0
                ):
                    breakpoint()

                """et = time.perf_counter()
                print("after add", et - st)
                st = time.perf_counter()"""

            # st = time.perf_counter()
            if self.time % 1 == 0:
                ll_after2 = (
                    self.mgh.get_ll(include_psd_info=True)
                    .flatten()[new_state.supplimental[:]["overall_inds"]]
                    .reshape(ntemps, nwalkers)
                )
                if np.abs(ll_after2 - new_state.log_like).max() > 1e-4:
                    if np.abs(ll_after2 - new_state.log_like).max() > 10.0:
                        breakpoint()
                    fix = np.abs(ll_after2 - new_state.log_like) > 1e-4
                    new_state.log_like[fix] = ll_after2[fix]

        # print("rj", (et - st) / num_consecutive_rj_moves)

        if False:  # self.temperature_control is not None and not self.prevent_swaps:
            # TODO: add swaps?
            new_state = self.temperature_control.temper_comps(new_state, adapt=False)

            # et = time.perf_counter()
            # print("swapping", et - st)
        self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()
        accepted = np.zeros_like(new_state.log_like)
        self.time += 1

        self.mempool.free_all_blocks()

        et = time.perf_counter()
        print("RJ end", et - st)
        print(new_state.branches["gb"].nleaves.mean(axis=-1))
        return new_state, accepted
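Acceptance in `propose` follows the reversible-jump Metropolis-Hastings rule ln u < factors + logP_new - logP_old, where `factors` carries the proposal-asymmetry terms, including the ln(1/2) edge corrections applied when a band sits at the minimum or maximum allowed leaf count (only a birth or only a death is possible there). A compact sketch of that acceptance step, detached from the sampler classes and with made-up inputs; the beta-weighted posterior below stands in for `compute_log_posterior`:

# Compact sketch of the RJ acceptance rule used in propose (all inputs synthetic).
import numpy as np

rng = np.random.default_rng(3)
n_prop = 6

logl_new, logl_old = rng.normal(size=n_prop), rng.normal(size=n_prop)
logp_new, logp_old = rng.normal(size=n_prop), rng.normal(size=n_prop)
betas = np.linspace(1.0, 0.2, n_prop)         # one inverse temperature per proposal
factors = rng.normal(scale=0.1, size=n_prop)  # proposal-asymmetry terms

# edge correction: a walker at the minimum leaf count can only propose a birth,
# so the forward proposal probability is 1 instead of 1/2 -> add ln(1/2).
at_min_k = np.array([True, False, False, False, False, False])
factors[at_min_k] += np.log(0.5)

logP_new = betas * logl_new + logp_new
logP_old = betas * logl_old + logp_old

lnpdiff = factors + logP_new - logP_old
accepted = lnpdiff > np.log(rng.random(n_prop))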