lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lisaanalysistools might be problematic. Click here for more details.
- lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
- lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
- lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
- lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
- lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
- lisatools/__init__.py +0 -0
- lisatools/_version.py +4 -0
- lisatools/analysiscontainer.py +438 -0
- lisatools/cutils/detector.cpython-312-darwin.so +0 -0
- lisatools/datacontainer.py +292 -0
- lisatools/detector.py +410 -0
- lisatools/diagnostic.py +976 -0
- lisatools/glitch.py +193 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/gbgroupstretch.py +53 -0
- lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
- lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
- lisatools/sampling/moves/gbspecialstretch.py +1836 -0
- lisatools/sampling/moves/mbhspecialmove.py +286 -0
- lisatools/sampling/moves/placeholder.py +16 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/moves/specialforegroundmove.py +564 -0
- lisatools/sampling/prior.py +508 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +324 -0
- lisatools/sensitivity.py +888 -0
- lisatools/sources/__init__.py +0 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/tdiwaveform.py +72 -0
- lisatools/stochastic.py +291 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +40 -0
- lisatools/utils/multigpudataholder.py +730 -0
- lisatools/utils/pointeradjust.py +106 -0
- lisatools/utils/utility.py +240 -0
|
@@ -0,0 +1,564 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from inspect import Attribute
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy import stats
|
|
7
|
+
import warnings
|
|
8
|
+
import time
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import cupy as xp
|
|
12
|
+
|
|
13
|
+
gpu_available = True
|
|
14
|
+
except ModuleNotFoundError:
|
|
15
|
+
import numpy as xp
|
|
16
|
+
|
|
17
|
+
gpu_available = False
|
|
18
|
+
|
|
19
|
+
from eryn.moves import StretchMove
|
|
20
|
+
from eryn.prior import ProbDistContainer
|
|
21
|
+
from eryn.state import State
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Public API of this module: the only class defined here.
__all__ = ["GBForegroundSpecialMove"]
|
|
25
|
+
|
|
26
|
+
class GBForegroundSpecialMove(StretchMove):
    """Joint stretch move for the ``"psd"`` and ``"galfor"`` branches.

    The walkers are split into two halves. For each half, a stretch proposal
    (``get_proposal``, inherited from :class:`eryn.moves.StretchMove`) is drawn
    for the PSD and galactic-foreground parameters, using the other half as the
    complementary ensemble. The proposed values are loaded into the data holder
    ``mgh``, the likelihood is evaluated there, and each walker is accepted or
    rejected with a Metropolis-Hastings step. Rejected walkers have their
    previous PSD / foreground values written back into ``mgh``.

    On the first call and every 200 calls thereafter, the stored
    log-likelihoods are compared against a fresh evaluation from ``mgh``. The
    holder data are then reset with ``restore_base_injections`` and the
    ``"gb"`` templates are regenerated into it before the log-likelihoods are
    recomputed.

    Args:
        gb: Template generator. Its ``use_gpu`` attribute selects the array
            backend, and its ``generate_global_template`` method is used when
            rebuilding the residual data.
        priors (dict): Mapping of branch name to
            :class:`eryn.prior.ProbDistContainer`. ``propose`` reads the
            ``"psd"`` and ``"galfor"`` entries.
        start_freq_ind (int): Forwarded as ``start_freq_ind`` to
            ``generate_global_template``.
        data_length (int): Forwarded as ``data_length`` to
            ``generate_global_template``.
        mgh: Data holder providing ``set_psd_vals``, ``get_ll``,
            ``get_mapped_indices``, ``restore_base_injections``,
            ``multiply_data``, ``data_list``, ``gpu_splits`` and ``gpus``.
        fd: Frequency grid; ``self.df`` is ``(fd[1] - fd[0]).item()``.
        band_edges: Frequency band edges; ``self.num_bands`` is
            ``len(band_edges) - 1``.
        *args: Accepted for backward compatibility and ignored.
        waveform_kwargs (dict, optional): Extra keyword arguments for template
            generation. ``None`` means no extra arguments.
        noise_kwargs (dict, optional): Stored as ``self.noise_kwargs``.
            ``None`` means an empty dict.
        parameter_transforms (optional): Object with a ``both_transforms``
            method applied to ``"gb"`` coordinates before template generation.
            If ``None``, the coordinates are used untransformed.
        search, search_samples, search_snrs, search_snr_lim,
        search_snr_accept_factor, take_max_ll, global_template_builder,
        psd_func, provide_betas, batch_size: Stored as attributes; not used by
            ``propose``.
        alternate_priors: Accepted but not used by this class.
        **kwargs: Forwarded to :class:`eryn.moves.StretchMove`.

    Raises:
        ValueError: If a prior is not a ``ProbDistContainer``, or if
            ``search_snrs`` is given without a matching-length
            ``search_samples``.
    """

    def __init__(
        self,
        gb,
        priors,
        start_freq_ind,
        data_length,
        mgh,
        fd,
        band_edges,
        *args,
        waveform_kwargs=None,
        noise_kwargs=None,
        parameter_transforms=None,
        search=False,
        search_samples=None,
        search_snrs=None,
        search_snr_lim=None,
        search_snr_accept_factor=1.0,
        take_max_ll=False,
        global_template_builder=None,
        psd_func=None,
        provide_betas=False,
        alternate_priors=None,
        batch_size=5,
        **kwargs
    ):
        StretchMove.__init__(self, **kwargs)

        # call counter driving the periodic likelihood consistency check
        self.time = 0
        self.greater_than_1e0 = 0
        self.name = "GBForegroundSpecialMove".lower()

        # TODO: make priors optional like special generate function?
        for key in priors:
            if not isinstance(priors[key], ProbDistContainer):
                raise ValueError("Priors need to be eryn.priors.ProbDistContainer object.")
        self.priors = priors
        self.gb = gb
        self.provide_betas = provide_betas
        self.batch_size = batch_size
        self.stop_here = True

        # use gpu from template generator
        self.use_gpu = gb.use_gpu
        if self.use_gpu:
            self.xp = xp
            self.mempool = self.xp.get_default_memory_pool()
        else:
            self.xp = np
            # no CuPy memory pool on the CPU path; see _free_gpu_memory
            self.mempool = None

        self.band_edges = band_edges
        self.num_bands = len(band_edges) - 1
        self.start_freq_ind = start_freq_ind
        self.data_length = data_length
        # None defaults avoid sharing one mutable dict between instances
        self.waveform_kwargs = {} if waveform_kwargs is None else waveform_kwargs
        self.noise_kwargs = {} if noise_kwargs is None else noise_kwargs
        self.parameter_transforms = parameter_transforms
        self.psd_func = psd_func
        self.fd = fd
        self.df = (fd[1] - fd[0]).item()
        self.mgh = mgh
        self.search = search
        self.global_template_builder = global_template_builder

        if search_snrs is not None:
            if search_snr_lim is None:
                search_snr_lim = 0.1

            if search_samples is None or len(search_samples) != len(search_snrs):
                raise ValueError("search_samples and search_snrs must have the same length.")

        self.search_samples = search_samples
        self.search_snrs = search_snrs
        self.search_snr_lim = search_snr_lim
        self.search_snr_accept_factor = search_snr_accept_factor

        self.take_max_ll = take_max_ll

    def _free_gpu_memory(self):
        """Release cached CuPy memory-pool blocks; no-op when running on CPU."""
        mempool = getattr(self, "mempool", None)
        if mempool is not None:
            mempool.free_all_blocks()

    def _joint_log_prior(self, psd_coords, galfor_coords, ntemps, nwalkers_here):
        """Return the summed "psd" + "galfor" prior log-density, shaped (ntemps, nwalkers_here)."""
        logp = self.priors["psd"].logpdf(
            psd_coords.reshape(-1, psd_coords.shape[-1])
        ) + self.priors["galfor"].logpdf(
            galfor_coords.reshape(-1, galfor_coords.shape[-1])
        )
        return logp.reshape((ntemps, nwalkers_here))

    def _propose_half(self, model, new_state, split_here, ntemps, nwalkers, accepted):
        """Stretch-update the walkers selected by ``split_here``.

        Updates ``new_state`` (coords, log_like, log_prior), the PSD values in
        ``mgh``, and the boolean ``accepted`` array in place.
        """
        # actual size of this half: ceil(nwalkers / 2) or floor(nwalkers / 2)
        nwalkers_here = int(split_here.sum())

        points_to_move = {
            key: new_state.branches[key].coords[:, split_here] for key in ["psd", "galfor"]
        }
        points_for_move = {
            key: [new_state.branches[key].coords[:, ~split_here]] for key in ["psd", "galfor"]
        }

        q, factors = self.get_proposal(points_to_move, points_for_move, model.random)

        # flattened (temperature, walker) indices of the walkers being moved, row-major
        temp_part_general = np.repeat(
            np.arange(ntemps)[:, None], nwalkers, axis=-1
        )[:, split_here].flatten()
        walker_part_general = np.tile(np.arange(nwalkers), (ntemps, 1))[:, split_here].flatten()

        logp_here = self._joint_log_prior(q["psd"], q["galfor"], ntemps, nwalkers_here)
        prev_logp_here = self._joint_log_prior(
            new_state.branches_coords["psd"][:, split_here],
            new_state.branches_coords["galfor"][:, split_here],
            ntemps,
            nwalkers_here,
        )

        # proposals outside the prior support are never loaded into the data holder
        bad = ~np.isfinite(logp_here.flatten())

        data_index_tmp = (temp_part_general * self.nwalkers + walker_part_general).astype(np.int32)
        data_index = self.mgh.get_mapped_indices(data_index_tmp)
        data_index_in = data_index[~bad]

        psd_params = q["psd"].reshape(-1, q["psd"].shape[-1])[~bad]
        foreground_params = q["galfor"].reshape(-1, q["galfor"].shape[-1])[~bad]

        self.mgh.set_psd_vals(psd_params, foreground_params=foreground_params, overall_inds=data_index_in)
        logl_temp = self.mgh.get_ll(include_psd_info=True, overall_inds=data_index_in)

        # invalid proposals get a negligible likelihood so they are always rejected
        logl = np.full((ntemps, nwalkers_here), -1e300)
        logl[~bad.reshape(ntemps, -1)] = logl_temp

        prev_logl = new_state.log_like[:, split_here]
        prev_logp = new_state.log_prior[:, split_here]

        # replace only the psd/galfor contribution to the stored total prior
        logp = prev_logp + logp_here - prev_logp_here

        logP = self.compute_log_posterior(logl, logp)
        prev_logP = self.compute_log_posterior(prev_logl, prev_logp)

        lnpdiff = factors + logP - prev_logP
        keep = lnpdiff > np.log(model.random.rand(ntemps, nwalkers_here))
        keep_flat = keep.flatten()

        temp_inds_keep = temp_part_general[keep_flat]
        walker_inds_keep = walker_part_general[keep_flat]

        accepted[temp_inds_keep, walker_inds_keep] = True

        new_state.log_like[temp_inds_keep, walker_inds_keep] = logl[keep]
        new_state.log_prior[temp_inds_keep, walker_inds_keep] = logp[keep]

        for key in ["psd", "galfor"]:
            new_state.branches[key].coords[
                temp_inds_keep, walker_inds_keep, np.zeros_like(walker_inds_keep)
            ] = q[key][keep][:, 0]

        # write the previous PSD / foreground values back for rejected walkers
        temp_inds_fix = temp_part_general[~keep_flat]
        walker_inds_fix = walker_part_general[~keep_flat]

        data_index_fix = data_index[~keep_flat]
        psd_params_fix = new_state.branches_coords["psd"][(temp_inds_fix, walker_inds_fix)][:, 0]
        foreground_params_fix = new_state.branches_coords["galfor"][(temp_inds_fix, walker_inds_fix)][:, 0]

        self.mgh.set_psd_vals(psd_params_fix, foreground_params=foreground_params_fix, overall_inds=data_index_fix)

    def _rebuild_log_like(self, new_state, ntemps, nwalkers, nleaves_max):
        """Check stored log-likelihoods against ``mgh``, rebuild residuals, and reset ``new_state.log_like``."""
        overall_inds = new_state.supplimental[:]["overall_inds"]

        ll_after = self.mgh.get_ll(include_psd_info=True).flatten()[overall_inds].reshape(ntemps, nwalkers)
        check = np.abs(new_state.log_like - ll_after).max()
        if check > 1e-3:
            warnings.warn(
                f"GBForegroundSpecialMove: stored log-likelihood differs from the data holder by {check:.3e}.",
                RuntimeWarning,
            )

        self.mgh.restore_base_injections()

        if "gb" in new_state.branches:
            branch = new_state.branches["gb"]
            coords_here = branch.coords[branch.inds]

            walker_index = np.repeat(
                np.arange(ntemps * nwalkers).reshape(ntemps, nwalkers, 1), nleaves_max, axis=-1
            )[branch.inds]
            group_index = self.xp.asarray(
                self.mgh.get_mapped_indices(walker_index).astype(self.xp.int32)
            )

            if self.parameter_transforms is None:
                coords_here_in = coords_here
            else:
                coords_here_in = self.parameter_transforms.both_transforms(coords_here, xp=np)

            waveform_kwargs_fill = self.waveform_kwargs.copy()
            waveform_kwargs_fill["start_freq_ind"] = self.start_freq_ind
            waveform_kwargs_fill.pop("N", None)

            # Negating the data around the call turns an in-place accumulation
            # into a subtraction: -(-d + h) = d - h (assumes
            # generate_global_template adds into data_list — confirm).
            self.mgh.multiply_data(-1.)
            self.gb.generate_global_template(
                coords_here_in,
                group_index,
                self.mgh.data_list,
                data_length=self.data_length,
                data_splits=self.mgh.gpu_splits,
                batch_size=1000,
                **waveform_kwargs_fill
            )
            self.mgh.multiply_data(-1.)

        # include_psd_info=True keeps the reset values consistent with the
        # likelihoods stored during the accept step and the drift check above
        ll_after2 = self.mgh.get_ll(include_psd_info=True, use_cpu=True).flatten()[overall_inds].reshape(ntemps, nwalkers)
        new_state.log_like = ll_after2

    def propose(self, model, state):
        """Propose new PSD / foreground parameters and accept or reject them.

        Args:
            model (:class:`eryn.model.Model`): Carrier of sampler information;
                ``model.random`` supplies the proposal and acceptance draws.
            state (:class:`State`): Current state of the sampler.

        Returns:
            tuple: ``(new_state, accepted)`` where ``new_state`` is the
            :class:`State` after the proposal and ``accepted`` is a boolean
            array of shape ``(ntemps, nwalkers)`` marking accepted walkers.
        """
        if self.use_gpu:
            self.xp.cuda.runtime.setDevice(self.mgh.gpus[0])

        # Check that the dimensions are compatible.
        ntemps, nwalkers, _, _ = state.branches["galfor"].shape
        self.nwalkers = nwalkers

        # Run any move-specific setup.
        self.setup(state.branches)

        new_state = State(state)
        self._free_gpu_memory()

        self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()

        accepted = np.zeros((ntemps, nwalkers), dtype=bool)

        # the "gb" branch leaf count is needed when rebuilding the residuals
        ntemps, nwalkers, nleaves_max, _ = state.branches_coords["gb"].shape

        # Split the ensemble in half and iterate over these two halves.
        split_inds = np.zeros(nwalkers, dtype=int)
        split_inds[1::2] = 1
        np.random.shuffle(split_inds)

        for split in range(2):
            self._propose_half(model, new_state, split_inds == split, ntemps, nwalkers, accepted)

        self.accepted += accepted.astype(int)
        self.num_proposals += 1

        self._free_gpu_memory()

        if self.time % 200 == 0:
            self._rebuild_log_like(new_state, ntemps, nwalkers, nleaves_max)

        self._free_gpu_memory()

        if self.temperature_control is not None:
            new_state = self.temperature_control.temper_comps(new_state, adapt=False)

        if np.any(new_state.log_like > 1e10):
            warnings.warn(
                "GBForegroundSpecialMove: log-likelihood above 1e10 detected; the sampler state is likely corrupt.",
                RuntimeWarning,
            )

        self.time += 1

        self.mgh.map = new_state.supplimental.holder["overall_inds"].flatten()

        return new_state, accepted
|
|
564
|
+
|