lisaanalysistools 1.1.20__cp39-cp39-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. lisaanalysistools/git_version.py +7 -0
  2. lisaanalysistools-1.1.20.dist-info/METADATA +281 -0
  3. lisaanalysistools-1.1.20.dist-info/RECORD +48 -0
  4. lisaanalysistools-1.1.20.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.1.20.dist-info/licenses/LICENSE +201 -0
  6. lisatools/.dylibs/libgcc_s.1.1.dylib +0 -0
  7. lisatools/.dylibs/libstdc++.6.dylib +0 -0
  8. lisatools/__init__.py +90 -0
  9. lisatools/_version.py +34 -0
  10. lisatools/analysiscontainer.py +474 -0
  11. lisatools/cutils/Detector.cu +307 -0
  12. lisatools/cutils/Detector.hpp +84 -0
  13. lisatools/cutils/__init__.py +129 -0
  14. lisatools/cutils/global.hpp +28 -0
  15. lisatools/cutils/pycppdetector.pyx +256 -0
  16. lisatools/datacontainer.py +312 -0
  17. lisatools/detector.py +867 -0
  18. lisatools/diagnostic.py +990 -0
  19. lisatools/git_version.py.in +7 -0
  20. lisatools/orbit_files/equalarmlength-orbits-best-fit-to-esa.h5 +0 -0
  21. lisatools/orbit_files/equalarmlength-orbits.h5 +0 -0
  22. lisatools/orbit_files/esa-trailing-orbits.h5 +0 -0
  23. lisatools/sampling/__init__.py +0 -0
  24. lisatools/sampling/likelihood.py +882 -0
  25. lisatools/sampling/moves/__init__.py +0 -0
  26. lisatools/sampling/moves/skymodehop.py +110 -0
  27. lisatools/sampling/prior.py +646 -0
  28. lisatools/sampling/stopping.py +320 -0
  29. lisatools/sampling/utility.py +411 -0
  30. lisatools/sensitivity.py +1554 -0
  31. lisatools/sources/__init__.py +6 -0
  32. lisatools/sources/bbh/__init__.py +1 -0
  33. lisatools/sources/bbh/waveform.py +106 -0
  34. lisatools/sources/defaultresponse.py +37 -0
  35. lisatools/sources/emri/__init__.py +1 -0
  36. lisatools/sources/emri/waveform.py +79 -0
  37. lisatools/sources/gb/__init__.py +1 -0
  38. lisatools/sources/gb/waveform.py +69 -0
  39. lisatools/sources/utils.py +459 -0
  40. lisatools/sources/waveformbase.py +41 -0
  41. lisatools/stochastic.py +327 -0
  42. lisatools/utils/__init__.py +0 -0
  43. lisatools/utils/constants.py +54 -0
  44. lisatools/utils/exceptions.py +95 -0
  45. lisatools/utils/parallelbase.py +11 -0
  46. lisatools/utils/utility.py +122 -0
  47. lisatools_backend_cpu/git_version.py +7 -0
  48. lisatools_backend_cpu/pycppdetector.cpython-39-darwin.so +0 -0
@@ -0,0 +1,411 @@
1
+ from multiprocessing.sharedctypes import Value
2
+ import os
3
+
4
+ import numpy as np
5
+
6
+ try:
7
+ import cupy as xp
8
+
9
+ except (ImportError, ModuleNotFoundError) as e:
10
+ pass
11
+
12
+ from eryn.state import State, BranchSupplemental
13
+ from eryn.utils.utility import groups_from_inds
14
+
15
+
16
class DetermineGBGroups:
    """Cluster galactic-binary leaves from different walkers into groups.

    Leaves of one seed walker (or the caller-supplied ``input_groups``) define
    the initial groups. Every active leaf of every other walker is then compared
    against one representative per group using a normalized overlap built from
    quantities the waveform generator exposes after
    ``swap_likelihood_difference``. A leaf joins its best-matching group when
    both mismatch tests pass; otherwise it opens a new group (unless
    ``fix_group_count`` is set).

    Args:
        gb_wave_generator: Waveform generator. Must expose ``xp`` (array
            module) and ``swap_likelihood_difference``. After that call it must
            expose the attributes ``add_remove``, ``add_add`` and
            ``remove_remove``.
        transform_fn: Optional mapping ``{branch_name: transform}``. Each
            transform's ``both_transforms`` maps sampling coordinates to
            waveform parameters. ``None`` uses the coordinates unchanged.
        waveform_kwargs: Default keyword arguments, overridden by the per-call
            ``waveform_kwargs``. The merged dict must contain ``"T"`` and
            ``"dt"``.
    """

    def __init__(self, gb_wave_generator, transform_fn=None, waveform_kwargs=None):
        self.gb_wave_generator = gb_wave_generator
        self.xp = self.gb_wave_generator.xp
        self.transform_fn = transform_fn
        self.waveform_kwargs = {} if waveform_kwargs is None else waveform_kwargs

    def __call__(
        self,
        last_sample,
        name_here,
        check_temp=0,
        input_groups=None,
        input_groups_inds=None,
        fix_group_count=False,
        mismatch_lim=0.2,
        double_check_lim=0.2,
        start_term="random",
        waveform_kwargs=None,
        index_within_group="random",
    ):
        """Group the leaves of branch ``name_here`` at temperature ``check_temp``.

        Args:
            last_sample: An eryn ``State``, or a dict laid out as
                ``last_sample[name_here][check_temp]["coords" | "inds"]``.
            name_here: Branch name to group.
            check_temp: Temperature index whose walkers are grouped.
            input_groups: Existing groups (lists of coordinate rows) to extend.
                When given, no seed walker is used. Note: these lists are
                extended in place.
            input_groups_inds: ``[walker, leaf]`` index lists parallel to
                ``input_groups``.
            fix_group_count: If True, leaves that match no group are dropped
                instead of opening new groups.
            mismatch_lim: Maximum ``|1 - normalized overlap|`` for joining a group.
            double_check_lim: Maximum ``|1 - add_remove / remove_remove|`` for
                joining a group.
            start_term: Seed walker choice when ``input_groups`` is None.
                ``"max"`` picks the walker with the most active leaves,
                ``"first"`` picks walker 0, ``"random"`` picks one uniformly.
            waveform_kwargs: Per-call overrides merged over the constructor's
                ``waveform_kwargs``.
            index_within_group: Which member represents each group in the
                comparison: ``"first"`` or ``"random"``.

        Returns:
            ``(groups, groups_inds, group_lens)``. Returns ``([], [], [])`` when
            there are no groups to compare against.

        Raises:
            ValueError: For an unknown ``start_term`` or ``index_within_group``.
            TypeError: If ``last_sample`` is neither a dict nor a ``State``.
        """
        # TODO: mess with mismatch lim setting
        # TODO: some type of mismatch annealing may be useful
        if index_within_group not in ("first", "random"):
            raise ValueError("index_within_group must be 'first' or 'random'.")

        # Coordinates and inds of the temperature being considered.
        if isinstance(last_sample, dict):
            coords = last_sample[name_here][check_temp]["coords"]
            inds = last_sample[name_here][check_temp]["inds"]
        elif isinstance(last_sample, State):
            coords = last_sample.branches_coords[name_here][check_temp]
            inds = last_sample.branches_inds[name_here][check_temp]
        else:
            raise TypeError("last_sample must be an eryn State or a dict.")

        waveform_kwargs = {**self.waveform_kwargs, **(waveform_kwargs or {})}

        nwalkers, _, _ = coords.shape
        start_walker_ind = None
        if input_groups is None:
            # Figure out which walker seeds the groups.
            if start_term == "max":
                # ``inds`` is already (nwalkers, nleaves): count leaves per walker.
                start_walker_ind = int(inds.sum(axis=-1).argmax())
            elif start_term == "first":
                start_walker_ind = 0
            elif start_term == "random":
                start_walker_ind = np.random.randint(0, nwalkers)
            else:
                raise ValueError("start_term must be 'max', 'first', or 'random'.")

            # Each active leaf of the seed walker opens its own group.
            inds_good = np.where(inds[start_walker_ind])[0]
            groups = [[coords[start_walker_ind, leaf].copy()] for leaf in inds_good]
            groups_inds = [[[start_walker_ind, leaf]] for leaf in inds_good]
        else:
            # Allows checking against groups we already have.
            groups = input_groups
            groups_inds = input_groups_inds

        if len(groups) == 0:
            return [], [], []

        # Placeholder data (all ones) and unit PSD. These are loop-invariant, so
        # they are built once. Only the generator's add/remove products are read
        # back below.
        # TODO: could use real data and get observed snr for each if needed
        df = 1.0 / waveform_kwargs["T"]
        max_f = 0.5 / waveform_kwargs["dt"]
        frqs = self.xp.arange(0.0, max_f, df)
        data_minus_template = self.xp.asarray(
            [
                self.xp.ones_like(frqs, dtype=complex),
                self.xp.ones_like(frqs, dtype=complex),
            ]
        )[None, :, :]
        psd = self.xp.asarray(
            [
                self.xp.ones_like(frqs, dtype=np.float64),
                self.xp.ones_like(frqs, dtype=np.float64),
            ]
        )
        waveform_kwargs_fill = waveform_kwargs.copy()
        # start_freq_ind is passed explicitly as 0 for the placeholder grid.
        waveform_kwargs_fill.pop("start_freq_ind", None)

        for w in range(nwalkers):
            # The seed walker's leaves are already loaded.
            if w == start_walker_ind:
                continue

            # Walker has no binaries.
            if not np.any(inds[w]):
                continue

            coords_here = coords[w][inds[w]]
            inds_for_group_stuff = np.arange(len(inds[w]))[inds[w]]
            nleaves = coords_here.shape[0]

            # One representative per group.
            params_for_test = []
            for group in groups:
                if index_within_group == "first":
                    test_walker_ind = 0
                else:
                    test_walker_ind = np.random.randint(0, len(group))
                params_for_test.append(group[test_walker_ind])
            params_for_test = np.asarray(params_for_test)

            if self.transform_fn is not None:
                params_for_test_in = self.transform_fn[name_here].both_transforms(
                    params_for_test, return_transpose=False
                )
                coords_here_in = self.transform_fn[name_here].both_transforms(
                    coords_here, return_transpose=False
                )
            else:
                params_for_test_in = params_for_test.copy()
                coords_here_in = coords_here.copy()

            n_test = params_for_test_in.shape[0]
            n_here = coords_here_in.shape[0]

            # All (leaf, representative) pairs. With meshgrid's default "xy"
            # indexing the flattened order is leaf-major, so the results reshape
            # to (n_here, n_test).
            inds_tmp_test, inds_tmp_here = [
                tmp.ravel() for tmp in np.meshgrid(np.arange(n_test), np.arange(n_here))
            ]
            params_for_test_in_full = params_for_test_in[inds_tmp_test]
            coords_here_in_full = coords_here_in[inds_tmp_here]

            # Called for its side effect: it sets add_remove / add_add /
            # remove_remove on the generator.
            self.gb_wave_generator.swap_likelihood_difference(
                params_for_test_in_full,
                coords_here_in_full,
                data_minus_template,
                psd,
                start_freq_ind=0,
                data_index=None,
                noise_index=None,
                **waveform_kwargs_fill,
            )

            numerator = self.gb_wave_generator.add_remove
            norm_here = self.gb_wave_generator.add_add
            norm_for_test = self.gb_wave_generator.remove_remove

            normalized_autocorr = (
                (numerator / np.sqrt(norm_here * norm_for_test))
                .reshape(n_here, n_test)
                .real
            )
            normalized_against_test = (
                (numerator / norm_for_test).reshape(n_here, n_test).real
            )

            # TODO: do based on Likelihood? make sure on same posterior
            # TODO: add check based on amplitude
            test1 = np.abs(1.0 - normalized_autocorr)
            best = test1.argmin(axis=1)
            try:
                best = best.get()  # CuPy -> NumPy
            except AttributeError:
                pass
            rows = np.arange(test1.shape[0])
            best_mismatch = test1[(rows, best)]
            check_normalized_against_test = np.abs(
                1.0 - normalized_against_test[(rows, best)]
            )

            for leaf in range(nleaves):
                if (
                    best_mismatch[leaf] < mismatch_lim
                    and check_normalized_against_test[leaf] < double_check_lim
                ):
                    groups[best[leaf]].append(coords_here[leaf].copy())
                    groups_inds[best[leaf]].append([w, inds_for_group_stuff[leaf]])
                elif not fix_group_count:
                    # this only works for high snr limit
                    groups.append([coords_here[leaf].copy()])
                    groups_inds.append([[w, inds_for_group_stuff[leaf]]])

        group_lens = [len(group) for group in groups]

        return groups, groups_inds, group_lens
213
+
214
+
215
class GetLastGBState:
    """Rebuild the last sampler state together with its "gb" residual.

    Calling an instance does the following:
    - reads the last sample from a backend reader;
    - optionally copies one temperature onto others;
    - pads the "gb" branch to a larger maximum leaf count;
    - subtracts the summed "gb" templates from the data held by ``mgh``
      (exactly once);
    - recomputes ``state.log_like`` from the resulting inner product.

    Args:
        gb_wave_generator: Waveform generator providing ``xp`` and
            ``generate_global_template``. Its ``d_d`` attribute is overwritten
            on every call.
        transform_fn: Mapping containing ``"gb"``, whose ``both_transforms``
            maps sampling coordinates to waveform parameters.
        waveform_kwargs: Default keyword arguments, overridden by the per-call
            ``waveform_kwargs``. The merged dict must contain
            ``"start_freq_ind"``.
    """

    def __init__(self, gb_wave_generator, transform_fn=None, waveform_kwargs=None):
        self.gb_wave_generator = gb_wave_generator
        self.xp = self.gb_wave_generator.xp
        self.transform_fn = transform_fn
        self.waveform_kwargs = {} if waveform_kwargs is None else waveform_kwargs

    def __call__(
        self,
        mgh,
        reader,
        df,
        supps_base_shape,
        fix_temp_initial_ind: int = None,
        fix_temp_inds: list = None,
        nleaves_max_in=None,
        waveform_kwargs=None,
    ):
        """Load the last sample, form d - h for the "gb" branch, and update log_like.

        Args:
            mgh: Data holder providing ``gpus``, ``multiply_data``,
                ``get_mapped_indices``, ``data_list``, ``data_length``,
                ``gpu_splits``, ``get_inner_product`` and the
                ``temp_indices``/``walker_indices``/``overall_indices`` arrays.
            reader: Backend exposing ``get_last_sample()``.
            df: Unused. Kept for interface compatibility.
            supps_base_shape: ``obj_contained_shape`` for the
                ``BranchSupplemental`` that is attached.
            fix_temp_initial_ind: Source temperature index to copy from.
            fix_temp_inds: Destination temperature indices. Each must be
                ``>= fix_temp_initial_ind``.
            nleaves_max_in: New maximum number of "gb" leaves. Must not be
                smaller than the stored one. ``None`` keeps the stored value.
            waveform_kwargs: Per-call overrides for the template generator.

        Returns:
            The updated eryn ``State``.

        Raises:
            ValueError: If only one of the fix-temperature arguments is given,
                a destination temperature precedes the source,
                ``start_freq_ind`` is missing, or ``nleaves_max_in`` would
                shrink the branch.
        """
        # ``xp`` is the module-level CuPy import (undefined if CuPy is missing).
        xp.cuda.runtime.setDevice(mgh.gpus[0])

        if (fix_temp_initial_ind is None) != (fix_temp_inds is None):
            raise ValueError(
                "If giving fix_temp_initial_ind or fix_temp_inds, must give both."
            )

        state = reader.get_last_sample()

        waveform_kwargs = {**self.waveform_kwargs, **(waveform_kwargs or {})}
        if "start_freq_ind" not in waveform_kwargs:
            raise ValueError(
                "In get_last_gb_state, waveform_kwargs must include 'start_freq_ind'."
            )

        gb_branch = state.branches["gb"]
        ntemps, nwalkers, nleaves_max_old, ndim = gb_branch.shape

        if fix_temp_initial_ind is not None:
            for i in fix_temp_inds:
                if i < fix_temp_initial_ind:
                    raise ValueError(
                        "If providing fix_temp_initial_ind and fix_temp_inds, all values in fix_temp_inds must be greater than fix_temp_initial_ind."
                    )

                state.log_like[i] = state.log_like[fix_temp_initial_ind]
                state.log_prior[i] = state.log_prior[fix_temp_initial_ind]
                state.branches_coords["gb"][i] = state.branches_coords["gb"][
                    fix_temp_initial_ind
                ]
                state.branches_inds["gb"][i] = state.branches_inds["gb"][
                    fix_temp_initial_ind
                ]

        nleaves_max = nleaves_max_old if nleaves_max_in is None else nleaves_max_in
        if nleaves_max < nleaves_max_old:
            raise ValueError("new nleaves_max is less than nleaves_max_old.")

        # Zero-pad coords and mark the padded leaves inactive.
        coords_tmp = np.zeros((ntemps, nwalkers, nleaves_max, ndim))
        coords_tmp[:, :, :nleaves_max_old, :] = gb_branch.coords
        inds_tmp = np.zeros((ntemps, nwalkers, nleaves_max), dtype=bool)
        inds_tmp[:, :, :nleaves_max_old] = gb_branch.inds
        gb_branch.coords = coords_tmp
        gb_branch.inds = inds_tmp
        gb_branch.nleaves_max = nleaves_max
        gb_branch.shape = (ntemps, nwalkers, nleaves_max, ndim)

        data_index_in = groups_from_inds({"gb": state.branches_inds["gb"]})["gb"]
        data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(xp.int32)

        params_add_in = self.transform_fn["gb"].both_transforms(
            state.branches_coords["gb"][state.branches_inds["gb"]]
        )

        # Subtract the templates exactly once: -(-d + h) = d - h.
        # (Repeating this block would yield d - 2h.)
        # batch_size is ignored if waveform_kwargs["use_c_implementation"] is True
        mgh.multiply_data(-1.0)
        self.gb_wave_generator.generate_global_template(
            params_add_in,
            data_index,
            mgh.data_list,
            data_length=mgh.data_length,
            data_splits=mgh.gpu_splits,
            batch_size=1000,
            **waveform_kwargs,
        )
        mgh.multiply_data(-1.0)

        self.gb_wave_generator.d_d = np.asarray(mgh.get_inner_product(use_cpu=True))

        state.log_like = (
            -1 / 2 * self.gb_wave_generator.d_d.real.reshape(ntemps, nwalkers)
        )

        supps = BranchSupplemental(
            {
                "temp_inds": mgh.temp_indices.copy(),
                "walker_inds": mgh.walker_indices.copy(),
                "overall_inds": mgh.overall_indices.copy(),
            },
            obj_contained_shape=supps_base_shape,
            copy=True,
        )
        # NOTE(review): attribute name spelled "supplimental"; confirm it matches
        # the attribute eryn's State actually reads.
        state.supplimental = supps

        return state
372
+
373
+
374
class HeterodynedUpdate:
    """Sampler update hook that re-centres the heterodyned "mbh" likelihood.

    On each call the highest-log-likelihood "mbh" sample becomes the new
    heterodyne reference. The state's log-likelihood and blobs are then
    recomputed under the updated reference.

    Args:
        update_kwargs: Extra keyword arguments forwarded to
            ``template_model.init_heterodyne_info``.
        set_d_d_zero: If True, set ``template_model.reference_d_d`` to 0.0
            after re-initialising.
    """

    def __init__(self, update_kwargs, set_d_d_zero=False):
        self.update_kwargs = update_kwargs
        self.set_d_d_zero = set_d_d_zero

    def __call__(self, it, sample_state, sampler, **kwargs):
        """Re-initialise the heterodyne reference and refresh ``sample_state``.

        Args:
            it: Iteration counter (unused).
            sample_state: Eryn state with an "mbh" branch. Its coords are
                reshaped below to ``(ntemps, nwalkers, 1, ndim)``, i.e. one leaf
                per walker is assumed.
            sampler: Sampler exposing ``ndims``, ``ntemps``, ``nwalkers``,
                ``log_like_fn.f``, ``compute_log_prior`` and
                ``compute_log_like``.
            **kwargs: Ignored.
        """
        ndim = sampler.ndims[0]
        like_model = sampler.log_like_fn.f

        # Flattened argmax over (temps, walkers) indexes the flattened samples
        # directly because there is one leaf per walker.
        samples = sample_state.branches_coords["mbh"].reshape(-1, ndim)
        best = samples[sample_state.log_like.argmax()]

        best_full = like_model.parameter_transforms["mbh"].both_transforms(
            best, copy=True
        )
        like_model.template_model.init_heterodyne_info(
            best_full, **self.update_kwargs
        )

        if self.set_d_d_zero:
            like_model.template_model.reference_d_d = 0.0

        # TODO: make this a general update function in Eryn (?)
        # Coordinates are unchanged, so only the likelihood needs refreshing.
        samples = samples.reshape(sampler.ntemps, sampler.nwalkers, 1, ndim)
        logp = sampler.compute_log_prior({"mbh": samples})
        logL, blobs = sampler.compute_log_like({"mbh": samples}, logp=logp)

        sample_state.branches["mbh"].coords = samples
        sample_state.log_like = logL
        sample_state.blobs = blobs