lisaanalysistools 1.1.20__cp39-cp39-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lisaanalysistools/git_version.py +7 -0
- lisaanalysistools-1.1.20.dist-info/METADATA +281 -0
- lisaanalysistools-1.1.20.dist-info/RECORD +48 -0
- lisaanalysistools-1.1.20.dist-info/WHEEL +5 -0
- lisaanalysistools-1.1.20.dist-info/licenses/LICENSE +201 -0
- lisatools/.dylibs/libgcc_s.1.1.dylib +0 -0
- lisatools/.dylibs/libstdc++.6.dylib +0 -0
- lisatools/__init__.py +90 -0
- lisatools/_version.py +34 -0
- lisatools/analysiscontainer.py +474 -0
- lisatools/cutils/Detector.cu +307 -0
- lisatools/cutils/Detector.hpp +84 -0
- lisatools/cutils/__init__.py +129 -0
- lisatools/cutils/global.hpp +28 -0
- lisatools/cutils/pycppdetector.pyx +256 -0
- lisatools/datacontainer.py +312 -0
- lisatools/detector.py +867 -0
- lisatools/diagnostic.py +990 -0
- lisatools/git_version.py.in +7 -0
- lisatools/orbit_files/equalarmlength-orbits-best-fit-to-esa.h5 +0 -0
- lisatools/orbit_files/equalarmlength-orbits.h5 +0 -0
- lisatools/orbit_files/esa-trailing-orbits.h5 +0 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/prior.py +646 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +411 -0
- lisatools/sensitivity.py +1554 -0
- lisatools/sources/__init__.py +6 -0
- lisatools/sources/bbh/__init__.py +1 -0
- lisatools/sources/bbh/waveform.py +106 -0
- lisatools/sources/defaultresponse.py +37 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/waveform.py +79 -0
- lisatools/sources/gb/__init__.py +1 -0
- lisatools/sources/gb/waveform.py +69 -0
- lisatools/sources/utils.py +459 -0
- lisatools/sources/waveformbase.py +41 -0
- lisatools/stochastic.py +327 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +54 -0
- lisatools/utils/exceptions.py +95 -0
- lisatools/utils/parallelbase.py +11 -0
- lisatools/utils/utility.py +122 -0
- lisatools_backend_cpu/git_version.py +7 -0
- lisatools_backend_cpu/pycppdetector.cpython-39-darwin.so +0 -0
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
from multiprocessing.sharedctypes import Value
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
import cupy as xp
|
|
8
|
+
|
|
9
|
+
except (ImportError, ModuleNotFoundError) as e:
|
|
10
|
+
pass
|
|
11
|
+
|
|
12
|
+
from eryn.state import State, BranchSupplemental
|
|
13
|
+
from eryn.utils.utility import groups_from_inds
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DetermineGBGroups:
    """Group GB leaves from different walkers that appear to describe the same source.

    Leaves of one temperature are compared pairwise against one representative
    member of each existing group. The comparison uses the normalized overlap
    obtained from ``gb_wave_generator.swap_likelihood_difference``, evaluated
    with flat (all-ones) data and PSD. A leaf joins the best-matching group when
    both mismatch tests pass. Otherwise, unless the group count is fixed, it
    starts a new group.

    Args:
        gb_wave_generator: Object exposing ``xp`` (array module) and
            ``swap_likelihood_difference``. After that call it must hold the
            ``add_remove``, ``add_add`` and ``remove_remove`` arrays.
        transform_fn: Optional mapping ``name -> transform``. Each transform
            provides ``both_transforms(x, return_transpose=False)``, applied to
            coordinates before the waveform call.
        waveform_kwargs: Default waveform keyword arguments. Must supply
            ``"T"`` and ``"dt"`` (here or per call) whenever a comparison is made.
    """

    def __init__(self, gb_wave_generator, transform_fn=None, waveform_kwargs=None):
        self.gb_wave_generator = gb_wave_generator
        self.xp = self.gb_wave_generator.xp
        self.transform_fn = transform_fn
        # Never share one mutable default dict between instances.
        self.waveform_kwargs = {} if waveform_kwargs is None else waveform_kwargs

    def __call__(
        self,
        last_sample,
        name_here,
        check_temp=0,
        input_groups=None,
        input_groups_inds=None,
        fix_group_count=False,
        mismatch_lim=0.2,
        double_check_lim=0.2,
        start_term="random",
        waveform_kwargs=None,
        index_within_group="random",
    ):
        """Assign every active leaf of temperature ``check_temp`` to a group.

        Args:
            last_sample: eryn ``State`` or a dict laid out as
                ``last_sample[name_here][check_temp]["coords" | "inds"]``.
            name_here: Branch name to group.
            check_temp: Temperature index to read.
            input_groups, input_groups_inds: Existing groups to extend. They
                are appended to in place. When ``None``, groups are seeded
                from one walker chosen by ``start_term``.
            fix_group_count: If True, leaves matching no group are dropped
                instead of opening a new group.
            mismatch_lim: Upper bound on ``|1 - normalized overlap|``.
            double_check_lim: Upper bound on ``|1 - <h_leaf, h_rep>/<h_rep, h_rep>|``.
            start_term: ``"max"`` (walker with most active leaves),
                ``"first"`` (walker 0) or ``"random"``.
            waveform_kwargs: Per-call overrides merged over the instance defaults.
            index_within_group: Which group member represents the group:
                ``"first"`` or ``"random"``.

        Returns:
            ``(groups, groups_inds, group_lens)``:
            - ``groups``: list of lists of coordinate arrays.
            - ``groups_inds``: matching lists of ``[walker, leaf]`` pairs.
            - ``group_lens``: list of group sizes.
            Returns ``([], [], [])`` when there is no group to start from.

        Raises:
            ValueError: Invalid ``start_term`` or ``index_within_group``.
            TypeError: ``last_sample`` is neither a dict nor a ``State``.
        """
        # TODO: mess with mismatch lim setting
        # TODO: some type of mismatch annealing may be useful
        if index_within_group not in ("first", "random"):
            raise ValueError("index_within_group must be 'first' or 'random'.")

        coords, inds = self._get_coords_and_inds(last_sample, name_here, check_temp)
        waveform_kwargs = {**self.waveform_kwargs, **(waveform_kwargs or {})}

        # coords: (nwalkers, nleaves_max, ndim); inds: boolean mask over leaves.
        nwalkers = coords.shape[0]

        if input_groups is None:
            start_walker_ind = self._choose_start_walker(inds, start_term, nwalkers)
            # Seed one group per active leaf of the starting walker.
            groups = []
            groups_inds = []
            for leaf in np.where(inds[start_walker_ind])[0]:
                groups.append([coords[start_walker_ind, leaf].copy()])
                groups_inds.append([[start_walker_ind, leaf]])
        else:
            # Extend groups we already have.
            start_walker_ind = None
            groups = input_groups
            groups_inds = input_groups_inds

        if len(groups) == 0:
            return [], [], []

        # start_freq_ind is passed explicitly (as 0) to the comparison call.
        # Drop it from the forwarded kwargs to avoid a duplicate keyword.
        waveform_kwargs_fill = waveform_kwargs.copy()
        waveform_kwargs_fill.pop("start_freq_ind", None)

        # Flat data/PSD do not depend on the walker. They are built lazily so
        # that "T"/"dt" are only required when a comparison actually happens.
        flat_inputs = None

        for w in range(nwalkers):
            # This walker's leaves already seeded the groups.
            if input_groups is None and w == start_walker_ind:
                continue

            # Walker has no binaries.
            if not np.any(inds[w]):
                continue

            coords_here = coords[w][inds[w]]
            inds_for_group_stuff = np.arange(len(inds[w]))[inds[w]]

            # One representative per group (drawn in group order).
            params_for_test = np.asarray(
                [
                    self._pick_representative(np.asarray(group), index_within_group)
                    for group in groups
                ]
            )

            if flat_inputs is None:
                flat_inputs = self._flat_data_and_psd(waveform_kwargs)

            best, best_mismatch, check_against_test = self._match_to_groups(
                params_for_test,
                coords_here,
                name_here,
                flat_inputs[0],
                flat_inputs[1],
                waveform_kwargs_fill,
            )

            for leaf in range(coords_here.shape[0]):
                target = best[leaf]
                if (
                    best_mismatch[leaf] < mismatch_lim
                    and check_against_test[leaf] < double_check_lim
                ):
                    groups[target].append(coords_here[leaf].copy())
                    groups_inds[target].append([w, inds_for_group_stuff[leaf]])

                elif not fix_group_count:
                    # this only works for high snr limit
                    groups.append([coords_here[leaf].copy()])
                    groups_inds.append([[w, inds_for_group_stuff[leaf]]])

        group_lens = [len(group) for group in groups]

        return groups, groups_inds, group_lens

    @staticmethod
    def _get_coords_and_inds(last_sample, name_here, check_temp):
        """Return ``(coords, inds)`` of branch ``name_here`` at temperature ``check_temp``."""
        # The dict case is checked first. The two cases are disjoint as long as
        # State is not a dict subclass.
        if isinstance(last_sample, dict):
            branch = last_sample[name_here][check_temp]
            return branch["coords"], branch["inds"]
        if isinstance(last_sample, State):
            return (
                last_sample.branches_coords[name_here][check_temp],
                last_sample.branches_inds[name_here][check_temp],
            )
        raise TypeError("last_sample must be an eryn State or a dict.")

    @staticmethod
    def _choose_start_walker(inds, start_term, nwalkers):
        """Pick the walker whose active leaves seed the groups."""
        if start_term == "max":
            # inds is already restricted to check_temp: (nwalkers, nleaves_max).
            return inds.sum(axis=-1).argmax()
        if start_term == "first":
            return 0
        if start_term == "random":
            return np.random.randint(0, nwalkers)
        raise ValueError("start_term must be 'max', 'first', or 'random'.")

    @staticmethod
    def _pick_representative(group_params, index_within_group):
        """Return the group member compared against new leaves."""
        if index_within_group == "first":
            return group_params[0]
        return group_params[np.random.randint(0, group_params.shape[0])]

    def _flat_data_and_psd(self, waveform_kwargs):
        """Build two-channel all-ones data and PSD on a grid from 0 to 1/(2 dt), step 1/T."""
        df = 1.0 / waveform_kwargs["T"]
        max_f = 1.0 / 2 * 1 / waveform_kwargs["dt"]
        frqs = self.xp.arange(0.0, max_f, df)
        data_minus_template = self.xp.asarray(
            [
                self.xp.ones_like(frqs, dtype=complex),
                self.xp.ones_like(frqs, dtype=complex),
            ]
        )[None, :, :]
        psd = self.xp.asarray(
            [
                self.xp.ones_like(frqs, dtype=np.float64),
                self.xp.ones_like(frqs, dtype=np.float64),
            ]
        )
        return data_minus_template, psd

    def _match_to_groups(
        self,
        params_for_test,
        coords_here,
        name_here,
        data_minus_template,
        psd,
        waveform_kwargs_fill,
    ):
        """Return, per leaf, the best representative index and both mismatch measures."""
        if self.transform_fn is not None:
            transform = self.transform_fn[name_here]
            params_for_test_in = transform.both_transforms(
                params_for_test, return_transpose=False
            )
            coords_here_in = transform.both_transforms(
                coords_here, return_transpose=False
            )
        else:
            params_for_test_in = params_for_test.copy()
            coords_here_in = coords_here.copy()

        n_test = params_for_test_in.shape[0]
        n_here = coords_here_in.shape[0]

        # All (leaf, representative) pairs. meshgrid's default "xy" indexing
        # gives flat index = i_here * n_test + i_test, matching the reshape below.
        inds_tmp_test, inds_tmp_here = [
            tmp.ravel() for tmp in np.meshgrid(np.arange(n_test), np.arange(n_here))
        ]

        # TODO: could use real data and get observed snr for each if needed
        # Called for its side effect: the generator stores the inner products read below.
        self.gb_wave_generator.swap_likelihood_difference(
            params_for_test_in[inds_tmp_test],
            coords_here_in[inds_tmp_here],
            data_minus_template,
            psd,
            start_freq_ind=0,
            data_index=None,
            noise_index=None,
            **waveform_kwargs_fill,
        )

        numerator = self.gb_wave_generator.add_remove
        norm_here = self.gb_wave_generator.add_add
        norm_for_test = self.gb_wave_generator.remove_remove

        # np.sqrt also dispatches for CuPy arrays via __array_ufunc__.
        normalized_autocorr = (numerator / np.sqrt(norm_here * norm_for_test)).reshape(
            n_here, n_test
        ).real
        normalized_against_test = (numerator / norm_for_test).reshape(
            n_here, n_test
        ).real

        # TODO: do based on Likelihood? make sure on same posterior
        # TODO: add check based on amplitude
        mismatch = np.abs(1.0 - normalized_autocorr)
        best = mismatch.argmin(axis=1)
        try:
            # CuPy -> NumPy so the indices can address Python lists.
            best = best.get()
        except AttributeError:
            pass

        rows = np.arange(mismatch.shape[0])
        best_mismatch = mismatch[(rows, best)]
        check_against_test = np.abs(1.0 - normalized_against_test[(rows, best)])
        return best, best_mismatch, check_against_test
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class GetLastGBState:
    """Load the last sampler state and subtract its GB templates from the data in ``mgh``.

    Args:
        gb_wave_generator: Object exposing ``xp`` and
            ``generate_global_template``. Its ``d_d`` attribute is overwritten
            by each call.
        transform_fn: Mapping whose ``"gb"`` entry provides ``both_transforms``.
            It is applied to the active coordinates before template generation.
        waveform_kwargs: Default waveform keyword arguments. Merged with the
            per-call overrides, the result must contain ``"start_freq_ind"``.
    """

    def __init__(self, gb_wave_generator, transform_fn=None, waveform_kwargs=None):
        self.gb_wave_generator = gb_wave_generator
        self.xp = self.gb_wave_generator.xp
        self.transform_fn = transform_fn
        # Never share one mutable default dict between instances.
        self.waveform_kwargs = {} if waveform_kwargs is None else waveform_kwargs

    def __call__(
        self,
        mgh,
        reader,
        df,
        supps_base_shape,
        fix_temp_initial_ind: int = None,
        fix_temp_inds: list = None,
        nleaves_max_in=None,
        waveform_kwargs=None,
    ):
        """Return the last state, with residual data and log-likelihoods refreshed.

        Args:
            mgh: Data holder (assumed to be the multi-GPU holder). Must provide
                ``gpus``, ``multiply_data``, ``data_list``, ``data_length``,
                ``gpu_splits``, ``get_mapped_indices``, ``get_inner_product``
                and ``temp_indices`` / ``walker_indices`` / ``overall_indices``.
            reader: Backend exposing ``get_last_sample()``.
            df: Not used by this method; kept for interface compatibility.
            supps_base_shape: ``obj_contained_shape`` of the attached
                ``BranchSupplemental``.
            fix_temp_initial_ind: Temperature whose state is copied into
                ``fix_temp_inds``. Must be given together with ``fix_temp_inds``.
            fix_temp_inds: Temperatures to overwrite. Each must be
                >= ``fix_temp_initial_ind``.
            nleaves_max_in: New maximum leaf count for "gb". It must not be
                smaller than the stored one.
            waveform_kwargs: Per-call overrides merged over the instance defaults.

        Returns:
            The modified state.

        Raises:
            ValueError: Inconsistent temperature-fix arguments, missing
                ``start_freq_ind``, or a shrinking ``nleaves_max``.
            KeyError: The state has no "gb" branch.
        """
        # Requires CuPy: ``xp`` here is the module-level ``cupy`` import.
        xp.cuda.runtime.setDevice(mgh.gpus[0])

        if (fix_temp_initial_ind is None) != (fix_temp_inds is None):
            raise ValueError(
                "If giving fix_temp_initial_ind or fix_temp_inds, must give both."
            )

        state = reader.get_last_sample()

        waveform_kwargs = {**self.waveform_kwargs, **(waveform_kwargs or {})}
        if "start_freq_ind" not in waveform_kwargs:
            raise ValueError(
                "In get_last_gb_state, waveform_kwargs must include 'start_freq_ind'."
            )

        # A "gb" branch is required; a missing one raises KeyError here.
        ntemps, nwalkers = state.branches["gb"].shape[:2]

        if fix_temp_initial_ind is not None:
            self._copy_temperature(state, fix_temp_initial_ind, fix_temp_inds)

        self._expand_leaves(state, nleaves_max_in)

        # Subtract the templates exactly once. Applying this step a second time
        # would leave d - 2h in the data instead of d - h.
        self._subtract_gb_templates(state, mgh, waveform_kwargs)

        self.gb_wave_generator.d_d = np.asarray(mgh.get_inner_product(use_cpu=True))

        state.log_like = (
            -1 / 2 * self.gb_wave_generator.d_d.real.reshape(ntemps, nwalkers)
        )

        supps = BranchSupplemental(
            {
                "temp_inds": mgh.temp_indices.copy(),
                "walker_inds": mgh.walker_indices.copy(),
                "overall_inds": mgh.overall_indices.copy(),
            },
            obj_contained_shape=supps_base_shape,
            copy=True,
        )
        # NOTE(review): attribute is spelled "supplimental"; confirm it matches
        # what consumers of this state read (eryn's State may use "supplemental").
        state.supplimental = supps

        return state

    @staticmethod
    def _copy_temperature(state, fix_temp_initial_ind, fix_temp_inds):
        """Overwrite temperatures ``fix_temp_inds`` with the state at ``fix_temp_initial_ind``."""
        fix_temp_inds = list(fix_temp_inds)
        # Validate everything before mutating, so a bad index leaves state untouched.
        for i in fix_temp_inds:
            if i < fix_temp_initial_ind:
                raise ValueError(
                    "If providing fix_temp_initial_ind and fix_temp_inds, all values in fix_temp_inds must be greater than fix_temp_initial_ind."
                )

        for i in fix_temp_inds:
            state.log_like[i] = state.log_like[fix_temp_initial_ind]
            state.log_prior[i] = state.log_prior[fix_temp_initial_ind]
            state.branches_coords["gb"][i] = state.branches_coords["gb"][
                fix_temp_initial_ind
            ]
            state.branches_inds["gb"][i] = state.branches_inds["gb"][
                fix_temp_initial_ind
            ]

    @staticmethod
    def _expand_leaves(state, nleaves_max_in):
        """Zero-pad the "gb" branch to ``nleaves_max_in`` leaves; added leaves are inactive."""
        branch = state.branches["gb"]
        ntemps, nwalkers, nleaves_max_old, ndim = branch.shape
        nleaves_max = nleaves_max_old if nleaves_max_in is None else nleaves_max_in

        if nleaves_max < nleaves_max_old:
            raise ValueError("new nleaves_max is less than nleaves_max_old.")

        coords_tmp = np.zeros((ntemps, nwalkers, nleaves_max, ndim))
        coords_tmp[:, :, :nleaves_max_old, :] = branch.coords

        inds_tmp = np.zeros((ntemps, nwalkers, nleaves_max), dtype=bool)
        inds_tmp[:, :, :nleaves_max_old] = branch.inds

        branch.coords = coords_tmp
        branch.inds = inds_tmp
        branch.nleaves_max = nleaves_max
        branch.shape = (ntemps, nwalkers, nleaves_max, ndim)

    def _subtract_gb_templates(self, state, mgh, waveform_kwargs):
        """Add the templates of all active "gb" leaves into the sign-flipped data in ``mgh``."""
        data_index_in = groups_from_inds({"gb": state.branches_inds["gb"]})["gb"]
        data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(xp.int32)

        params_add_in = self.transform_fn["gb"].both_transforms(
            state.branches_coords["gb"][state.branches_inds["gb"]]
        )

        # batch_size is ignored if waveform_kwargs["use_c_implementation"] is True
        # -1 is to do -(-d + h) = d - h
        mgh.multiply_data(-1.0)
        self.gb_wave_generator.generate_global_template(
            params_add_in,
            data_index,
            mgh.data_list,
            data_length=mgh.data_length,
            data_splits=mgh.gpu_splits,
            batch_size=1000,
            **waveform_kwargs,
        )
        mgh.multiply_data(-1.0)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
class HeterodynedUpdate:
    """Sampler update hook that rebuilds the heterodyne info around the best sample.

    Args:
        update_kwargs: Extra keyword arguments forwarded to
            ``template_model.init_heterodyne_info``.
        set_d_d_zero: If True, ``template_model.reference_d_d`` is set to 0.0
            after the update.
    """

    def __init__(self, update_kwargs, set_d_d_zero=False):
        self.update_kwargs = update_kwargs
        self.set_d_d_zero = set_d_d_zero

    def __call__(self, it, sample_state, sampler, **kwargs):
        """Re-initialize heterodyning at the max-likelihood "mbh" sample and refresh the state.

        Args:
            it: Iteration number (unused).
            sample_state: State with an "mbh" branch. It is updated in place:
                coords, ``log_like`` and ``blobs`` are replaced.
            sampler: Sampler providing ``ndims``, ``ntemps``, ``nwalkers``,
                ``log_like_fn.f`` (with ``parameter_transforms`` and
                ``template_model``), ``compute_log_prior`` and ``compute_log_like``.
            **kwargs: Ignored.
        """
        like_fn = sampler.log_like_fn.f
        ndim = sampler.ndims[0]

        # Flat index into log_like lines up with the flattened samples only when
        # each walker holds exactly one "mbh" leaf (see the reshape below).
        samples = sample_state.branches_coords["mbh"].reshape(-1, ndim)
        best = samples[sample_state.log_like.argmax()]

        best_full = like_fn.parameter_transforms["mbh"].both_transforms(best, copy=True)
        like_fn.template_model.init_heterodyne_info(best_full, **self.update_kwargs)

        if self.set_d_d_zero:
            like_fn.template_model.reference_d_d = 0.0

        # TODO: make this a general update function in Eryn (?)
        # Recompute priors/likelihoods with the updated template model.
        samples = samples.reshape(sampler.ntemps, sampler.nwalkers, 1, ndim)
        logp = sampler.compute_log_prior({"mbh": samples})
        logL, blobs = sampler.compute_log_like({"mbh": samples}, logp=logp)

        sample_state.branches["mbh"].coords = samples
        sample_state.log_like = logL
        sample_state.blobs = blobs
|