lisaanalysistools-1.0.0-cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lisaanalysistools has been flagged as a potentially problematic release.

Files changed (37)
  1. lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
  2. lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
  3. lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
  4. lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
  6. lisatools/__init__.py +0 -0
  7. lisatools/_version.py +4 -0
  8. lisatools/analysiscontainer.py +438 -0
  9. lisatools/cutils/detector.cpython-312-darwin.so +0 -0
  10. lisatools/datacontainer.py +292 -0
  11. lisatools/detector.py +410 -0
  12. lisatools/diagnostic.py +976 -0
  13. lisatools/glitch.py +193 -0
  14. lisatools/sampling/__init__.py +0 -0
  15. lisatools/sampling/likelihood.py +882 -0
  16. lisatools/sampling/moves/__init__.py +0 -0
  17. lisatools/sampling/moves/gbgroupstretch.py +53 -0
  18. lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
  19. lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
  20. lisatools/sampling/moves/gbspecialstretch.py +1836 -0
  21. lisatools/sampling/moves/mbhspecialmove.py +286 -0
  22. lisatools/sampling/moves/placeholder.py +16 -0
  23. lisatools/sampling/moves/skymodehop.py +110 -0
  24. lisatools/sampling/moves/specialforegroundmove.py +564 -0
  25. lisatools/sampling/prior.py +508 -0
  26. lisatools/sampling/stopping.py +320 -0
  27. lisatools/sampling/utility.py +324 -0
  28. lisatools/sensitivity.py +888 -0
  29. lisatools/sources/__init__.py +0 -0
  30. lisatools/sources/emri/__init__.py +1 -0
  31. lisatools/sources/emri/tdiwaveform.py +72 -0
  32. lisatools/stochastic.py +291 -0
  33. lisatools/utils/__init__.py +0 -0
  34. lisatools/utils/constants.py +40 -0
  35. lisatools/utils/multigpudataholder.py +730 -0
  36. lisatools/utils/pointeradjust.py +106 -0
  37. lisatools/utils/utility.py +240 -0
--- /dev/null
+++ b/lisatools/sampling/stopping.py
@@ -0,0 +1,320 @@
+ import time
+
+ import numpy as np
+
+ from eryn.utils.stopping import Stopping
+ from eryn.utils.utility import thermodynamic_integration_log_evidence
+
+
+ class SNRStopping(Stopping):
+     def __init__(self, snr_limit=100.0, verbose=False):
+         self.snr_limit = snr_limit
+         self.verbose = verbose
+
+     def __call__(self, iter, sample, sampler):
+
+         ind = sampler.get_log_like().argmax()
+
+         log_best = sampler.get_log_like().max()
+         snr_best = sampler.get_blobs()[:, :, :, 0].flatten()[ind]
+         # d_h_best = sampler.get_blobs()[:, :, :, 1].flatten()[ind]
+         # h_h_best = sampler.get_blobs()[:, :, :, 2].flatten()[ind]
+
+         if self.verbose:
+             print(
+                 "snr_best",
+                 snr_best,
+                 "limit:",
+                 self.snr_limit,
+                 "loglike:",
+                 log_best,
+                 # d_h_best,
+                 # h_h_best,
+             )
+
+         if snr_best > self.snr_limit:
+             return True
+
+         else:
+             return False
+
+
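For orientation, these classes implement eryn's Stopping interface: the sampler calls the object periodically and halts when it returns True. Below is a minimal sketch of wiring one up, assuming eryn's EnsembleSampler accepts the stopping_fn and stopping_iterations keywords and that the likelihood returns the SNR as its first blob (both assumptions about the surrounding setup, which is not shown in this file):

    import numpy as np
    from eryn.ensemble import EnsembleSampler

    # hypothetical likelihood returning (logL, snr) so that blob 0 holds the SNR
    def log_like(x):
        logl = -0.5 * np.sum(x ** 2)
        snr = np.sqrt(np.abs(2 * logl))
        return logl, snr

    stop = SNRStopping(snr_limit=100.0, verbose=True)
    sampler = EnsembleSampler(
        32,                   # nwalkers
        2,                    # ndim
        log_like,
        priors,               # placeholder prior container for this problem
        stopping_fn=stop,     # called every stopping_iterations steps
        stopping_iterations=50,
    )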
+ class NLeavesSearchStopping:
+     def __init__(self, convergence_iter=5, verbose=False):
+         self.convergence_iter = convergence_iter
+         self.verbose = verbose
+
+     def __call__(self, current_info):
+
+         if not hasattr(self, "st"):
+             self.st = time.perf_counter()
+
+         current_iter = current_info.gb_info["reader"].iteration
+
+         # not enough iterations stored yet to check for convergence
+         stop = False
+         if current_iter > self.convergence_iter:
+
+             nleaves_cc = current_info.gb_info["reader"].get_nleaves()["gb"][:, 0]
+
+             # do not include most recent iterations in the old maximum
+             nleaves_cc_max_old = nleaves_cc[:-self.convergence_iter].max()
+             nleaves_cc_max_new = nleaves_cc[-self.convergence_iter:].max()
+
+             # stop once the leaf count has stopped growing
+             if nleaves_cc_max_old > nleaves_cc_max_new:
+                 stop = True
+
+             else:
+                 stop = False
+
+             if self.verbose:
+                 dur = (time.perf_counter() - self.st) / 3600.0  # hours
+                 print(
+                     "\nnleaves max old:\n",
+                     nleaves_cc_max_old,
+                     "\nnleaves max new:\n",
+                     nleaves_cc_max_new,
+                     f"\nTIME TO NOW: {dur} hours",
+                 )
+
+         return stop
+
+
+ class SearchConvergeStopping(Stopping):
+     def __init__(self, n_iters=30, diff=1.0, verbose=False, start_iteration=0):
+         self.n_iters = n_iters
+         self.iters_consecutive = 0
+         self.past_like_best = -np.inf
+         self.diff = diff
+         self.verbose = verbose
+         self.start_iteration = start_iteration
+
+     def __call__(self, iter, sample, sampler):
+
+         like_best = sampler.get_log_like(discard=self.start_iteration).max()
+
+         if np.abs(like_best - self.past_like_best) < self.diff:
+             self.iters_consecutive += 1
+
+         else:
+             self.iters_consecutive = 0
+             self.past_like_best = like_best
+
+         if self.verbose:
+             print(
+                 "\nITERS CONSECUTIVE:\n",
+                 self.iters_consecutive,
+                 self.past_like_best,
+                 like_best,
+             )
+
+         if self.iters_consecutive >= self.n_iters:
+             self.iters_consecutive = 0
+             return True
+
+         else:
+             return False
+
+
+ class GBBandLogLConvergeStopping(Stopping):
+
+     def __init__(self, fd, band_edges, n_iters=30, diff=1.0, verbose=False, start_iteration=0):
+         self.band_edge_inds = np.searchsorted(fd, band_edges, side="right") - 1
+         self.num_bands = self.band_edge_inds.shape[0] - 1
+         self.converged = np.zeros(self.num_bands, dtype=bool)
+         self.iters_consecutive = np.zeros(self.num_bands, dtype=int)
+         self.past_like_best = np.full(self.num_bands, -np.inf)
+         self.n_iters = n_iters
+         self.diff = diff
+         self.verbose = verbose
+         self.start_iteration = start_iteration
+
+     def add_mgh(self, mgh):
+         self.mgh = mgh
+
+     def __call__(self, i, sample, sampler):
+
+         ll_per_band = self.mgh.get_ll(band_edge_inds=self.band_edge_inds).max(axis=0)
+
+         ll_movement = (ll_per_band - self.past_like_best) > self.diff
+
+         self.iters_consecutive[~ll_movement] += 1
+         self.iters_consecutive[ll_movement] = 0
+
+         self.converged = self.iters_consecutive >= self.n_iters
+
+         self.past_like_best[ll_movement] = ll_per_band[ll_movement]
+
+         # for move in sampler.all_moves:
+         #     move.converged_sub_bands = self.converged.copy()
+
+         if self.verbose:
+             print("Num still going:", (~self.converged).sum(), "\nChanged here:", (ll_movement).sum())
+
+         if np.all(self.converged):
+             return True
+         else:
+             return False
+
+
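The band bookkeeping above hinges on np.searchsorted: band_edges in Hz are converted once into indices on the frequency grid fd, so per-band log-likelihoods can be sliced out of a single array. A small self-contained illustration of that mapping (toy numbers, not LISA values):

    import numpy as np

    fd = np.arange(0.0, 1.0, 0.1)                 # frequency grid: 0.0, 0.1, ..., 0.9
    band_edges = np.array([0.0, 0.25, 0.55, 0.85])

    # index of the last grid point at or below each edge
    band_edge_inds = np.searchsorted(fd, band_edges, side="right") - 1
    # -> [0, 2, 5, 8]; the bands are fd[0:2], fd[2:5], fd[5:8]

    num_bands = band_edge_inds.shape[0] - 1       # 3 bands from 4 edges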
+ class SearchConvergeStopping2(Stopping):
+     def __init__(self, n_iters=30, diff=0.1, verbose=False, start_iteration=0, iter_back_check=-1):
+         self.n_iters = n_iters
+         self.iters_consecutive = 0
+         self.past_like_best = -np.inf
+         self.diff = diff
+         self.verbose = verbose
+         self.start_iteration = start_iteration
+         # iter_back_check sets the rolling-buffer length; it must be a positive integer
+         self.iter_back_check = iter_back_check
+         self.time = 0
+         self.back_check = [None for _ in range(self.iter_back_check)]
+         self.last_sampler_iteration = self.start_iteration
+         self.back_check_ind = 0
+         self.stop_here = True
+
+     def __call__(self, iter, sample, sampler):
+
+         self.time += 1
+
+         if sampler.iteration <= self.start_iteration:
+             return False
+
+         lps = sampler.get_log_like(discard=self.start_iteration)[self.last_sampler_iteration - self.start_iteration:]
+         if lps.size == 0:
+             # nothing new stored since the last check
+             return False
+         like_best = lps.max()
+         self.last_sampler_iteration = sampler.iteration
+
+         # fill the rolling buffer of recent best values, one entry per call
+         if any(val is None for val in self.back_check):
+             for i in range(len(self.back_check)):
+                 if self.back_check[i] is None:
+                     self.back_check[i] = like_best
+                     return False
+
+         first_check = like_best - self.past_like_best > self.diff
+         second_check = np.all(like_best >= np.asarray(self.back_check))
+
+         # spread in stored values is below difference
+         third_check = np.asarray(self.back_check).max() - np.asarray(self.back_check).min() < self.diff
+
+         update = (
+             (first_check and second_check and self.past_like_best == -np.inf)
+             or (self.past_like_best == -np.inf and third_check)
+             or (self.past_like_best > -np.inf and first_check)
+         )
+
+         self.back_check[self.back_check_ind] = like_best
+         self.back_check_ind = (self.back_check_ind + 1) % len(self.back_check)
+
+         if update:
+             self.past_like_best = like_best
+             self.iters_consecutive = 0
+
+         elif self.past_like_best > -np.inf:
+             self.iters_consecutive += 1
+
+         if self.verbose:
+             print(
+                 "\nITERS CONSECUTIVE:\n",
+                 self.iters_consecutive,
+                 f"previous best: {self.past_like_best}, overall best: {like_best},",
+                 "first check:", first_check,
+                 "second check:", second_check,
+             )
+
+         if self.iters_consecutive >= self.n_iters:
+             self.iters_consecutive = 0
+             return True
+
+         else:
+             return False
+
+
+ class EvidenceStopping(Stopping):
+     def __init__(self, diff=0.5, verbose=False):
+         self.diff = diff
+         self.verbose = verbose
+
+     def __call__(self, iter, sample, sampler):
+
+         # average the log likelihood over steps and walkers at each temperature
+         betas = sampler.get_betas()[-1]
+         logls = sampler.get_log_like().mean(axis=(0, 2))
+
+         logZ, dlogZ = thermodynamic_integration_log_evidence(betas, logls)
+         if self.verbose:
+             print(logZ, dlogZ)
+
+         # diagnostic only: never triggers a stop
+         return False
+
+
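EvidenceStopping leans on thermodynamic integration: given inverse temperatures beta and the tempered mean log-likelihoods <ln L>_beta, the log evidence is ln Z = integral of <ln L>_beta over beta from 0 to 1. A rough numpy sketch of that integral, as a simplified stand-in for eryn's helper (which also returns an error estimate):

    import numpy as np

    def ti_log_evidence(betas, mean_logls):
        # sort by beta, then apply the trapezoidal rule to <ln L>_beta
        order = np.argsort(betas)
        b = betas[order]
        f = mean_logls[order]
        return np.sum(0.5 * (f[1:] + f[:-1]) * np.diff(b))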
+ class MPICommunicateStopping(Stopping):
+
+     def __init__(self, stopper_rank, other_ranks, stop_fn=None):
+
+         self.stopper_rank = stopper_rank
+         self.other_ranks = other_ranks
+         self.stop_fn = stop_fn
+
+     def add_comm(self, comm):
+         self.comm = comm
+
+     def __call__(self, *args, **kwargs):
+
+         if not hasattr(self, "comm"):
+             raise ValueError("Must add comm via add_comm method before __call__ is used.")
+
+         if not hasattr(self, "rank"):
+             self.rank = self.comm.Get_rank()
+             if self.rank != self.stopper_rank and self.rank not in self.other_ranks:
+                 raise ValueError("Rank is not in the other ranks list. It must be either the stopper rank or in the other ranks list.")
+
+             if self.stopper_rank == self.rank and self.stop_fn is None:
+                 raise ValueError("Rank is the stopper rank, but stop_fn was not provided. It must be provided.")
+
+         if self.rank == self.stopper_rank:
+             # evaluate the wrapped stopping function and broadcast the result
+             stop = self.stop_fn(*args, **kwargs)
+
+             if stop:
+                 for rank in self.other_ranks:
+                     tag = int(str(rank) + "1000")
+                     self.comm.isend(True, dest=rank, tag=tag)
+
+         else:
+             # check, without blocking, whether the stopper rank has sent a stop signal
+             tag = int(str(self.rank) + "1000")
+             check_stop = self.comm.irecv(source=self.stopper_rank, tag=tag)
+
+             if check_stop.get_status():
+                 stop = check_stop.wait()
+
+             else:
+                 check_stop.cancel()
+                 stop = False
+
+         return stop
+
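MPICommunicateStopping lets one rank decide while the rest poll without blocking. A minimal sketch of the intended wiring under mpi4py (the rank layout and the wrapped stopper are assumptions for illustration):

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # rank 0 evaluates the real criterion; ranks 1 and 2 just listen
    stop_fn = SearchConvergeStopping(n_iters=30, diff=1.0) if rank == 0 else None
    stopper = MPICommunicateStopping(0, [1, 2], stop_fn=stop_fn)
    stopper.add_comm(comm)

    # inside each rank's sampling loop (sampler pieces not shown here):
    # if stopper(iter, sample, sampler):
    #     break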
--- /dev/null
+++ b/lisatools/sampling/utility.py
@@ -0,0 +1,324 @@
+ import os
+
+ import numpy as np
+
+ try:
+     import cupy as xp
+
+ except (ImportError, ModuleNotFoundError) as e:
+     # cupy is optional; GPU arrays are only needed when a GPU is in use
+     pass
+
+ from eryn.state import State, BranchSupplimental
+ from eryn.utils.utility import groups_from_inds
+
+
+ class DetermineGBGroups:
+     def __init__(self, gb_wave_generator, transform_fn=None, waveform_kwargs={}):
+         self.gb_wave_generator = gb_wave_generator
+         self.xp = self.gb_wave_generator.xp
+         self.transform_fn = transform_fn
+         self.waveform_kwargs = waveform_kwargs
+
+     def __call__(self, last_sample, name_here, check_temp=0, input_groups=None, input_groups_inds=None, fix_group_count=False, mismatch_lim=0.2, double_check_lim=0.2, start_term="random", waveform_kwargs={}, index_within_group="random"):
+         # TODO: mess with mismatch lim setting
+         # TODO: some type of mismatch annealing may be useful
+         if isinstance(last_sample, State):
+             state = last_sample
+             coords = state.branches_coords[name_here][check_temp]
+             inds = state.branches_inds[name_here][check_temp]
+         elif isinstance(last_sample, dict):
+             coords = last_sample[name_here][check_temp]["coords"]
+             inds = last_sample[name_here][check_temp]["inds"]
+
+         waveform_kwargs = {**self.waveform_kwargs, **waveform_kwargs}
+
+         # get coordinates and inds of the temperature you are considering.
+
+         nwalkers, nleaves_max, ndim = coords.shape
+         if input_groups is None:
+
+             # figure out which walker to start with
+             if start_term == "max":
+                 # walker with the most active leaves (inds is already sliced by temperature)
+                 start_walker_ind = inds.sum(axis=-1).argmax()
+             elif start_term == "first":
+                 start_walker_ind = 0
+             elif start_term == "random":
+                 start_walker_ind = np.random.randint(0, nwalkers)
+             else:
+                 raise ValueError("start_term must be 'max', 'first', or 'random'.")
+
+             # get all the good leaves in this walker
+             inds_good = np.where(inds[start_walker_ind])[0]
+             groups = []
+             groups_inds = []
+
+             # seed each group with one leaf from the starting walker
+             for leaf_i, leaf in enumerate(inds_good):
+                 groups.append([])
+                 groups_inds.append([])
+                 groups[leaf_i].append(coords[start_walker_ind, leaf].copy())
+                 groups_inds[leaf_i].append([start_walker_ind, leaf])
+         else:
+             # allows us to check groups based on groups we already have
+             groups = input_groups
+             groups_inds = input_groups_inds
+
+         if len(groups) == 0:
+             return [], [], []
+         for w in range(coords.shape[0]):
+
+             # we have already loaded this group
+             if input_groups is None and w == start_walker_ind:
+                 continue
+
+             # walker has no binaries
+             if not np.any(inds[w]):
+                 continue
+
+             # coords in this walker
+             coords_here = coords[w][inds[w]]
+             inds_for_group_stuff = np.arange(len(inds[w]))[inds[w]]
+             nleaves, ndim = coords_here.shape
+
+             # pick one representative member from each existing group
+             params_for_test = []
+             for group in groups:
+                 group_params = np.asarray(group)
+
+                 if index_within_group == "first":
+                     test_walker_ind = 0
+                 elif index_within_group == "random":
+                     test_walker_ind = np.random.randint(0, group_params.shape[0])
+                 else:
+                     raise ValueError("index_within_group must be 'first' or 'random'.")
+
+                 params_for_test.append(group_params[test_walker_ind])
+             params_for_test = np.asarray(params_for_test)
+
+             # transform coords
+             if self.transform_fn is not None:
+                 params_for_test_in = self.transform_fn[name_here].both_transforms(params_for_test, return_transpose=False)
+                 coords_here_in = self.transform_fn[name_here].both_transforms(coords_here, return_transpose=False)
+
+             else:
+                 params_for_test_in = params_for_test.copy()
+                 coords_here_in = coords_here.copy()
+
+             # all (group representative, walker leaf) pairs
+             inds_tmp_test = np.arange(len(params_for_test_in))
+             inds_tmp_here = np.arange(len(coords_here_in))
+             inds_tmp_test, inds_tmp_here = [tmp.ravel() for tmp in np.meshgrid(inds_tmp_test, inds_tmp_here)]
+
+             params_for_test_in_full = params_for_test_in[inds_tmp_test]
+             coords_here_in_full = coords_here_in[inds_tmp_here]
+             # build the waveforms at the same time
+
+             df = 1. / waveform_kwargs["T"]
+             max_f = 0.5 / waveform_kwargs["dt"]  # Nyquist frequency
+             frqs = self.xp.arange(0.0, max_f, df)
+             data_minus_template = self.xp.asarray([
+                 self.xp.ones_like(frqs, dtype=complex),
+                 self.xp.ones_like(frqs, dtype=complex)
+             ])[None, :, :]
+             psd = self.xp.asarray([
+                 self.xp.ones_like(frqs, dtype=np.float64),
+                 self.xp.ones_like(frqs, dtype=np.float64)
+             ])
+
+             waveform_kwargs_fill = waveform_kwargs.copy()
+             waveform_kwargs_fill.pop("start_freq_ind")
+
+             # TODO: could use real data and get observed snr for each if needed
+             check = self.gb_wave_generator.swap_likelihood_difference(
+                 params_for_test_in_full,
+                 coords_here_in_full,
+                 data_minus_template,
+                 psd,
+                 start_freq_ind=0,
+                 data_index=None,
+                 noise_index=None,
+                 **waveform_kwargs_fill,
+             )
+
+             numerator = self.gb_wave_generator.add_remove
+             norm_here = self.gb_wave_generator.add_add
+             norm_for_test = self.gb_wave_generator.remove_remove
+
+             normalized_autocorr = numerator / np.sqrt(norm_here * norm_for_test)
+             normalized_against_test = numerator / norm_for_test
+
+             normalized_autocorr = normalized_autocorr.reshape(coords_here_in.shape[0], params_for_test_in.shape[0]).real
+             normalized_against_test = normalized_against_test.reshape(coords_here_in.shape[0], params_for_test_in.shape[0]).real
+
+             # TODO: do based on Likelihood? make sure on same posterior
+             # TODO: add check based on amplitude
+             test1 = np.abs(1.0 - normalized_autocorr.real)  # (numerator / norm_for_test[None, :]).real)
+             best = test1.argmin(axis=1)
+             try:
+                 # move off the GPU if these are cupy arrays
+                 best = best.get()
+             except AttributeError:
+                 pass
+             best_mismatch = test1[(np.arange(test1.shape[0]), best)]
+             check_normalized_against_test = np.abs(1.0 - normalized_against_test[(np.arange(test1.shape[0]), best)])
+
+             f0_here = coords_here[:, 1]
+             f0_test = params_for_test[best, 1]
+
+             for leaf in range(nleaves):
+                 if best_mismatch[leaf] < mismatch_lim and check_normalized_against_test[leaf] < double_check_lim:
+                     groups[best[leaf]].append(coords_here[leaf].copy())
+                     groups_inds[best[leaf]].append([w, inds_for_group_stuff[leaf]])
+
+                 elif not fix_group_count:
+                     # this only works for high snr limit
+                     groups.append([coords_here[leaf].copy()])
+                     groups_inds.append([[w, inds_for_group_stuff[leaf]]])
+
+         group_lens = [len(group) for group in groups]
+
+         return groups, groups_inds, group_lens
+
+
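The grouping test above is a normalized overlap: for candidate waveforms a and b with noise-weighted inner product <a, b>, it forms O = <a, b> / sqrt(<a, a><b, b>) and treats 1 - O below mismatch_lim as "same source". A toy version with a plain dot product standing in for the noise-weighted inner product (the real code obtains these terms from swap_likelihood_difference):

    import numpy as np

    def mismatch(a, b):
        # 0 for identical (up to scale) signals, near 1 for orthogonal ones
        overlap = np.vdot(a, b).real / np.sqrt(np.vdot(a, a).real * np.vdot(b, b).real)
        return np.abs(1.0 - overlap)

    t = np.linspace(0, 1, 1000)
    a = np.sin(2 * np.pi * 10 * t)
    b = 0.5 * np.sin(2 * np.pi * 10 * t)      # same shape, different amplitude
    c = np.sin(2 * np.pi * 30 * t)            # a different source

    assert mismatch(a, b) < 0.2 and mismatch(a, c) > 0.2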
+ class GetLastGBState:
+     def __init__(self, gb_wave_generator, transform_fn=None, waveform_kwargs={}):
+         self.gb_wave_generator = gb_wave_generator
+         self.xp = self.gb_wave_generator.xp
+         self.transform_fn = transform_fn
+         self.waveform_kwargs = waveform_kwargs
+
+     def __call__(self, mgh, reader, df, supps_base_shape, fix_temp_initial_ind: int = None, fix_temp_inds: list = None, nleaves_max_in=None, waveform_kwargs={}):
+
+         xp.cuda.runtime.setDevice(mgh.gpus[0])
+
+         if fix_temp_initial_ind is not None or fix_temp_inds is not None:
+             if fix_temp_initial_ind is None or fix_temp_inds is None:
+                 raise ValueError("If giving fix_temp_initial_ind or fix_temp_inds, must give both.")
+
+         state = reader.get_last_sample()
+
+         waveform_kwargs = {**self.waveform_kwargs, **waveform_kwargs}
+         if "start_freq_ind" not in waveform_kwargs:
+             raise ValueError("In get_last_gb_state, waveform_kwargs must include 'start_freq_ind'.")
+
+         ntemps, nwalkers = state.log_like.shape
+
+         try:
+             ntemps, nwalkers, nleaves_max_old, ndim = state.branches["gb"].shape
+
+             if fix_temp_initial_ind is not None:
+                 for i in fix_temp_inds:
+                     if i < fix_temp_initial_ind:
+                         raise ValueError("If providing fix_temp_initial_ind and fix_temp_inds, all values in fix_temp_inds must be greater than fix_temp_initial_ind.")
+
+                     # copy the reference temperature into each fixed temperature
+                     state.log_like[i] = state.log_like[fix_temp_initial_ind]
+                     state.log_prior[i] = state.log_prior[fix_temp_initial_ind]
+                     state.branches_coords["gb"][i] = state.branches_coords["gb"][fix_temp_initial_ind]
+                     state.branches_inds["gb"][i] = state.branches_inds["gb"][fix_temp_initial_ind]
+
+             if nleaves_max_in is None:
+                 nleaves_max = nleaves_max_old
+             else:
+                 nleaves_max = nleaves_max_in
+             if nleaves_max_old <= nleaves_max:
+                 # grow the leaf dimension to the new maximum, padding with inactive leaves
+                 coords_tmp = np.zeros((ntemps, nwalkers, nleaves_max, ndim))
+                 coords_tmp[:, :, :nleaves_max_old, :] = state.branches["gb"].coords
+
+                 inds_tmp = np.zeros((ntemps, nwalkers, nleaves_max), dtype=bool)
+                 inds_tmp[:, :, :nleaves_max_old] = state.branches["gb"].inds
+                 state.branches["gb"].coords = coords_tmp
+                 state.branches["gb"].inds = inds_tmp
+                 state.branches["gb"].nleaves_max = nleaves_max
+                 state.branches["gb"].shape = (ntemps, nwalkers, nleaves_max, ndim)
+
+             else:
+                 raise ValueError("new nleaves_max is less than nleaves_max_old.")
+
+             # add "gb" templates if there are any
+             data_index_in = groups_from_inds({"gb": state.branches_inds["gb"]})["gb"]
+
+             data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(xp.int32)
+
+             params_add_in = self.transform_fn["gb"].both_transforms(state.branches_coords["gb"][state.branches_inds["gb"]])
+
+             # batch_size is ignored if waveform_kwargs["use_c_implementation"] is True
+             # -1 is to do -(-d + h) = d - h
+             mgh.multiply_data(-1.)
+             self.gb_wave_generator.generate_global_template(params_add_in, data_index, mgh.data_list, data_length=mgh.data_length, data_splits=mgh.gpu_splits, batch_size=1000, **waveform_kwargs)
+             mgh.multiply_data(-1.)
+
+         except KeyError:
+             # no "gb" branch in the backend; leave the data untouched
+             pass
+
+         # inner products of the residual (data minus any subtracted templates)
+         self.gb_wave_generator.d_d = np.asarray(mgh.get_inner_product(use_cpu=True))
+
+         state.log_like = -1 / 2 * self.gb_wave_generator.d_d.real.reshape(ntemps, nwalkers)
+
+         temp_inds = mgh.temp_indices.copy()
+         walker_inds = mgh.walker_indices.copy()
+         overall_inds = mgh.overall_indices.copy()
+
+         supps = BranchSupplimental({"temp_inds": temp_inds, "walker_inds": walker_inds, "overall_inds": overall_inds}, obj_contained_shape=supps_base_shape, copy=True)
+         state.supplimental = supps
+
+         return state
+
+
+ class HeterodynedUpdate:
+     def __init__(self, update_kwargs, set_d_d_zero=False):
+         self.update_kwargs = update_kwargs
+         self.set_d_d_zero = set_d_d_zero
+
+     def __call__(self, it, sample_state, sampler, **kwargs):
+
+         samples = sample_state.branches_coords["mbh"].reshape(-1, sampler.ndims[0])
+         lp_max = sample_state.log_like.argmax()
+         best = samples[lp_max]
+
+         lp = sample_state.log_like.flatten()
+         sorted_inds = np.argsort(lp)
+         inds_best = sorted_inds[-1000:]
+         inds_worst = sorted_inds[:1000]
+
+         best_full = sampler.log_like_fn.f.parameter_transforms["mbh"].both_transforms(
+             best, copy=True
+         )
+
+         # re-center the heterodyne reference on the current best point
+         sampler.log_like_fn.f.template_model.init_heterodyne_info(
+             best_full, **self.update_kwargs
+         )
+
+         if self.set_d_d_zero:
+             sampler.log_like_fn.f.template_model.reference_d_d = 0.0
+
+         # TODO: make this a general update function in Eryn (?)
+         # samples[inds_worst] = samples[inds_best].copy()
+         # recompute priors and likelihoods against the new reference
+         samples = samples.reshape(sampler.ntemps, sampler.nwalkers, 1, sampler.ndims[0])
+         logp = sampler.compute_log_prior({"mbh": samples})
+         logL, blobs = sampler.compute_log_like({"mbh": samples}, logp=logp)
+
+         sample_state.branches["mbh"].coords = samples
+         sample_state.log_like = logL
+         sample_state.blobs = blobs
+
+         # sampler.backend.save_step(sample_state, np.full_like(lp, True))
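HeterodynedUpdate follows eryn's update-function pattern: the sampler calls it periodically so the heterodyne reference can track the current best point, after which all likelihoods are recomputed against the new reference. A sketch of how it might be attached, assuming this eryn version exposes update_fn and update_iterations keywords (the likelihood object, priors, and sizes are placeholders):

    update = HeterodynedUpdate({}, set_d_d_zero=False)

    sampler = EnsembleSampler(
        nwalkers,
        ndims,
        log_like_fn,             # heterodyned MBH likelihood wrapper assumed above
        priors,
        update_fn=update,        # re-centers the reference point
        update_iterations=100,   # every 100 sampler iterations
    )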