lisaanalysistools 1.1.20__cp39-cp39-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. lisaanalysistools/git_version.py +7 -0
  2. lisaanalysistools-1.1.20.dist-info/METADATA +281 -0
  3. lisaanalysistools-1.1.20.dist-info/RECORD +48 -0
  4. lisaanalysistools-1.1.20.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.1.20.dist-info/licenses/LICENSE +201 -0
  6. lisatools/.dylibs/libgcc_s.1.1.dylib +0 -0
  7. lisatools/.dylibs/libstdc++.6.dylib +0 -0
  8. lisatools/__init__.py +90 -0
  9. lisatools/_version.py +34 -0
  10. lisatools/analysiscontainer.py +474 -0
  11. lisatools/cutils/Detector.cu +307 -0
  12. lisatools/cutils/Detector.hpp +84 -0
  13. lisatools/cutils/__init__.py +129 -0
  14. lisatools/cutils/global.hpp +28 -0
  15. lisatools/cutils/pycppdetector.pyx +256 -0
  16. lisatools/datacontainer.py +312 -0
  17. lisatools/detector.py +867 -0
  18. lisatools/diagnostic.py +990 -0
  19. lisatools/git_version.py.in +7 -0
  20. lisatools/orbit_files/equalarmlength-orbits-best-fit-to-esa.h5 +0 -0
  21. lisatools/orbit_files/equalarmlength-orbits.h5 +0 -0
  22. lisatools/orbit_files/esa-trailing-orbits.h5 +0 -0
  23. lisatools/sampling/__init__.py +0 -0
  24. lisatools/sampling/likelihood.py +882 -0
  25. lisatools/sampling/moves/__init__.py +0 -0
  26. lisatools/sampling/moves/skymodehop.py +110 -0
  27. lisatools/sampling/prior.py +646 -0
  28. lisatools/sampling/stopping.py +320 -0
  29. lisatools/sampling/utility.py +411 -0
  30. lisatools/sensitivity.py +1554 -0
  31. lisatools/sources/__init__.py +6 -0
  32. lisatools/sources/bbh/__init__.py +1 -0
  33. lisatools/sources/bbh/waveform.py +106 -0
  34. lisatools/sources/defaultresponse.py +37 -0
  35. lisatools/sources/emri/__init__.py +1 -0
  36. lisatools/sources/emri/waveform.py +79 -0
  37. lisatools/sources/gb/__init__.py +1 -0
  38. lisatools/sources/gb/waveform.py +69 -0
  39. lisatools/sources/utils.py +459 -0
  40. lisatools/sources/waveformbase.py +41 -0
  41. lisatools/stochastic.py +327 -0
  42. lisatools/utils/__init__.py +0 -0
  43. lisatools/utils/constants.py +54 -0
  44. lisatools/utils/exceptions.py +95 -0
  45. lisatools/utils/parallelbase.py +11 -0
  46. lisatools/utils/utility.py +122 -0
  47. lisatools_backend_cpu/git_version.py +7 -0
  48. lisatools_backend_cpu/pycppdetector.cpython-39-darwin.so +0 -0
@@ -0,0 +1,646 @@
1
+ import numpy as np
2
+ from scipy import stats
3
+
4
+ from ..utils.constants import *
5
+ from ..sensitivity import get_sensitivity
6
+ from eryn.moves.multipletry import logsumexp
7
+
8
+ from typing import Union, Optional, Tuple, List
9
+
10
+ import sys
11
+
12
# NOTE(review): hard-coded, developer-specific path. It only served the
# commented-out ``galaxy`` / ``galaxy_ffdot`` imports below; consider removing
# it from the published package.
sys.path.append(
    "/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/gf_search/"
)
# from galaxy_ffdot import GalaxyFFdot
# from galaxy import Galaxy

try:
    import cupy as cp

except (ModuleNotFoundError, ImportError):
    # CuPy is optional. Bind ``cp`` anyway so code that merely refers to it
    # (e.g. ``xp == cp``) does not raise NameError on CPU-only installs;
    # actually requesting GPU arrays still fails, since ``cp`` is None.
    cp = None
23
+
24
+
25
class AmplitudeFrequencySNRPrior:
    """Joint prior on (amplitude, frequency) induced by a prior on the SNR.

    The SNR ``rho`` follows :class:`SNRPrior`; it is converted to an
    amplitude by :class:`AmplitudeFromSNR` at the sampled frequency. The
    frequency prior is supplied by the caller. Frequencies enter and leave
    this class as ``f0_ms`` and are divided by 1e3 before being handed to
    the amplitude transform (presumably mHz -> Hz; confirm against callers).

    Args:
        rho_star: Scale parameter of the SNR prior.
        frequency_prior: Object with ``rvs(size=...)``, ``logpdf(f0_ms)`` and
            a writable ``use_cupy`` attribute.
        L: Arm length forwarded to :class:`AmplitudeFromSNR`.
        Tobs: Observation time forwarded to :class:`AmplitudeFromSNR`.
        use_cupy: If True, operate on CuPy arrays.
        **noise_kwargs: Default noise settings forwarded to
            :class:`AmplitudeFromSNR`.
    """

    def __init__(
        self, rho_star, frequency_prior, L, Tobs, use_cupy=False, **noise_kwargs
    ):
        self.rho_star = rho_star
        self.frequency_prior = frequency_prior

        self.transform = AmplitudeFromSNR(L, Tobs, use_cupy=use_cupy, **noise_kwargs)
        self.snr_prior = SNRPrior(rho_star, use_cupy=use_cupy)

        # must come after ``transform`` and ``snr_prior``: the setter
        # propagates the flag to them
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether CuPy arrays are used; setting it updates all sub-objects."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy
        self.transform.use_cupy = use_cupy
        self.snr_prior.use_cupy = use_cupy
        self.frequency_prior.use_cupy = use_cupy

    def pdf(self, *args, **noise_kwargs):
        """Return ``exp(logpdf(...))``; same arguments as :meth:`logpdf`."""
        xp = np if not self.use_cupy else cp
        return xp.exp(self.logpdf(*args, **noise_kwargs))

    def logpdf(self, amp, f0_ms, **noise_kwargs):
        """Return the joint log-density of ``amp`` and ``f0_ms``.

        Args:
            amp: Amplitudes.
            f0_ms: Frequencies in the units of ``frequency_prior``.
            **noise_kwargs: Forwarded to ``AmplitudeFromSNR.forward``.
        """
        xp = np if not self.use_cupy else cp

        f0 = f0_ms / 1e3
        rho, f0 = self.transform.forward(amp, f0, **noise_kwargs)

        rho_pdf = self.snr_prior.pdf(rho)

        # change of variables rho -> amp: p(amp) = p(rho) |d rho / d amp|;
        # ``forward`` computes rho = amp * factor, so the Jacobian is rho / amp
        jac = xp.abs(rho / amp)

        logpdf_amp = xp.log(xp.abs(jac * rho_pdf))
        logpdf_f = self.frequency_prior.logpdf(f0_ms)

        return logpdf_amp + logpdf_f

    def rvs(self, size=1, f0_input=None, **noise_kwargs):
        """Draw ``(amp, f0_ms)`` samples.

        Args:
            size: Number of draws (int) or output shape (tuple).
            f0_input: Optional frequencies (same units as ``f0_ms``) to
                condition on instead of drawing them from
                ``frequency_prior``. It must hold one value per draw; an array
                with the right number of elements but a different shape
                (e.g. a flattened column) is reshaped to ``size``.
            **noise_kwargs: Forwarded to the SNR -> amplitude transform.

        Returns:
            Tuple ``(amp, f0_ms)``.

        Raises:
            AssertionError: If ``f0_input`` does not hold one value per draw.
        """
        if isinstance(size, int):
            size = (size,)

        if f0_input is None:
            f0_ms = self.frequency_prior.rvs(size=size)
        else:
            f0_ms = f0_input
            if f0_ms.shape != tuple(size):
                # one frequency per draw is required so it lines up
                # element-wise with the SNR draws below
                assert f0_ms.size == int(np.prod(size)), (
                    "f0_input must contain one frequency per requested draw"
                )
                f0_ms = f0_ms.reshape(size)

        f0 = f0_ms / 1e3

        rho = self.snr_prior.rvs(size=size)

        amp, _ = self.transform(rho, f0, **noise_kwargs)

        return (amp, f0_ms)
87
+
88
+
89
class SNRPrior:
    """Prior on the signal-to-noise ratio ``rho`` with scale ``rho_star``.

    The density is

        p(rho) = 3 rho / (4 rho_star**2 (1 + rho / (4 rho_star))**5)   for rho > 0

    and zero for ``rho <= 0``. Its CDF has the closed form
    ``1 - 256 rho_star**3 (rho + rho_star) / (rho + 4 rho_star)**4``
    (:meth:`cdf`), and :meth:`rvs` samples by inverting that CDF analytically.

    Args:
        rho_star: Scale parameter of the distribution.
        use_cupy: If True, operate on CuPy arrays instead of NumPy arrays.
    """

    def __init__(self, rho_star, use_cupy=False):
        self.rho_star = rho_star
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether CuPy arrays are used."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy

    def _as_float_array(self, rho):
        """Return ``rho`` as an array with a floating-point dtype.

        Integer input would otherwise make the output buffer integer-typed
        and silently truncate the density, and Python scalars would fail the
        boolean-mask indexing used by :meth:`pdf` and :meth:`cdf`.
        """
        xp = np if not self.use_cupy else cp
        rho = xp.asarray(rho)
        if rho.dtype.kind not in "fc":
            rho = rho.astype(xp.float64)
        return rho

    def pdf(self, rho):
        """Return the density at ``rho`` (zero where ``rho <= 0``)."""
        xp = np if not self.use_cupy else cp

        rho = self._as_float_array(rho)
        p = xp.zeros_like(rho)
        good = rho > 0.0
        p[good] = (
            3
            * rho[good]
            / (4 * self.rho_star**2 * (1 + rho[good] / (4 * self.rho_star)) ** 5)
        )
        return p

    def logpdf(self, rho):
        """Return the log-density at ``rho`` (``-inf`` where ``rho <= 0``)."""
        xp = np if not self.use_cupy else cp
        return xp.log(self.pdf(rho))

    def cdf(self, rho):
        """Return the cumulative distribution at ``rho`` (zero where ``rho <= 0``)."""
        xp = np if not self.use_cupy else cp

        rho = self._as_float_array(rho)
        c = xp.zeros_like(rho)
        good = rho > 0.0
        c[good] = (
            768
            * self.rho_star**3
            * (
                1 / (768.0 * self.rho_star**3)
                - (rho[good] + self.rho_star)
                / (3.0 * (rho[good] + 4 * self.rho_star) ** 4)
            )
        )
        return c

    def rvs(self, size=1):
        """Draw samples by analytic inverse-CDF sampling.

        Args:
            size: Number of draws (int) or output shape (tuple).

        Returns:
            Array of SNR samples with shape ``size``.
        """
        if isinstance(size, int):
            size = (size,)

        xp = np if not self.use_cupy else cp

        u = xp.random.rand(*size)

        # Solving cdf(rho) = u means finding the positive root of a quartic
        # in rho. The closed-form root below repeats several sub-expressions;
        # each is named once here and evaluated in exactly the same order as
        # the original expanded formula, so the samples are unchanged.
        r = self.rho_star
        cbrt2 = 2**0.3333333333333333

        cbrt_u = xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
        cbrt_big = xp.cbrt(
            -1769472 * r**6
            + 1769472 * u * r**6
            - xp.sqrt(
                3131031158784 * u * r**12
                - 6262062317568 * u**2 * r**12
                + 3131031158784 * u**3 * r**12
            )
        )

        term_q = (32 * (-r**2 + u * r**2)) / (1 - u)
        term_p = (3072 * cbrt2 * cbrt_u * (r**4 - u * r**4)) / (
            (-1 + u) ** 2 * cbrt_big
        )
        term_r = cbrt_big / (3.0 * cbrt2 * cbrt_u)

        sqrt_inner = xp.sqrt(-32 * r**2 - term_q + term_p + term_r)
        sqrt_outer = xp.sqrt(
            32 * r**2
            + term_q
            - term_p
            - term_r
            + (2048 * r**3 - (2048 * u * r**3) / (-1 + u)) / (4.0 * sqrt_inner)
        )

        rho = -4 * r + sqrt_inner / 2.0 + sqrt_outer / 2.0

        return rho
253
+
254
+
255
class AmplitudeFromSNR:
    """Convert between signal amplitude and SNR at a given frequency.

    The conversion is linear in the amplitude:

        rho = amp * 0.5 * sqrt(Tobs * sin(f0 / f_star)**2 / Sn(f0)),

    with ``f_star = C_SI / (2 pi L)``. ``Sn(f0)`` is taken from an explicit
    ``Sn_f`` argument, from linear interpolation of supplied ``psds`` on the
    frequency grid ``fd``, or from :func:`get_sensitivity` (see
    :meth:`get_Sn_f`).

    Args:
        L: Arm length used to compute ``f_star``.
        Tobs: Observation time.
        fd: Optional frequency grid on which ``psds`` arguments are sampled.
        use_cupy: If True, operate on CuPy arrays.
        **noise_kwargs: Default keyword arguments for :meth:`get_Sn_f`, used
            whenever a conversion call supplies none of its own.
    """

    def __init__(self, L, Tobs, fd=None, use_cupy=False, **noise_kwargs):
        self.f_star = 1 / (2.0 * np.pi * L) * C_SI
        self.Tobs = Tobs
        self.noise_kwargs = noise_kwargs

        xp = np if not use_cupy else cp
        self.fd = xp.asarray(fd) if fd is not None else None

        # must come after ``fd`` is set: the setter moves ``fd`` between devices
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether CuPy arrays are used; setting it moves ``fd`` accordingly."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy
        if self.fd is None:
            # no grid to move (``cp.asarray(None)`` would fail)
            return
        if use_cupy:
            # ``cp.asarray`` leaves CuPy inputs as they are
            self.fd = cp.asarray(self.fd)
        elif not isinstance(self.fd, np.ndarray) and hasattr(self.fd, "get"):
            # device array -> host array; detected by duck typing so ``cp`` is
            # never touched when CuPy is not installed
            self.fd = self.fd.get()

    def interp_psd(self, f0, psds, walker_inds=None):
        """Linearly interpolate PSDs sampled on ``self.fd`` at frequencies ``f0``.

        Args:
            f0: Query frequencies.
            psds: PSD values on the ``fd`` grid, one row per walker (a 1-D
                input is treated as a single row).
            walker_inds: Row of ``psds`` to use for each ``f0``; defaults to
                row 0 for every point.

        Returns:
            Interpolated PSD values, one per ``f0``. Points on or outside the
            grid edges use the outermost grid interval (linear extrapolation).
        """
        assert self.fd is not None
        xp = np if not self.use_cupy else cp
        psds = xp.atleast_2d(psds)

        if self.use_cupy:
            # keep the grid on the same device as the query points
            self.fd = cp.asarray(self.fd)

        # index of the grid interval holding each f0, clipped so the right
        # neighbour ``inds_fd + 1`` always exists
        inds_fd = xp.searchsorted(self.fd, f0, side="right") - 1
        inds_fd = xp.clip(inds_fd, 0, len(self.fd) - 2)

        if walker_inds is None:
            walker_inds = xp.zeros_like(f0, dtype=int)

        f_lo = self.fd[inds_fd]
        f_hi = self.fd[inds_fd + 1]
        psd_lo = psds[(walker_inds, inds_fd)]
        psd_hi = psds[(walker_inds, inds_fd + 1)]
        return (psd_hi - psd_lo) / (f_hi - f_lo) * (f0 - f_lo) + psd_lo

    def _snr_factor(self, f0, Sn_f):
        """Return ``rho / amp`` at frequency ``f0`` for noise level ``Sn_f``."""
        xp = np if not self.use_cupy else cp
        return 1.0 / 2.0 * xp.sqrt((self.Tobs * xp.sin(f0 / self.f_star) ** 2) / Sn_f)

    def __call__(self, rho, f0, **noise_kwargs):
        """Convert SNR ``rho`` at ``f0`` to amplitude; returns ``(amp, f0)``.

        ``noise_kwargs`` go to :meth:`get_Sn_f`; when empty, the defaults
        given at construction are used.
        """
        if not noise_kwargs:
            noise_kwargs = self.noise_kwargs

        Sn_f = self.get_Sn_f(f0, **noise_kwargs)
        amp = rho / self._snr_factor(f0, Sn_f)
        return (amp, f0)

    def get_Sn_f(self, f0, psds=None, walker_inds=None, Sn_f=None, **noise_kwargs):
        """Return the noise PSD at ``f0``.

        Priority: an explicit ``Sn_f`` (checked to match ``f0`` in length and
        type), then interpolation of ``psds`` via :meth:`interp_psd`, then
        ``get_sensitivity(f0, **noise_kwargs)``.
        """
        if Sn_f is not None:
            assert len(f0) == len(Sn_f)
            assert isinstance(f0, type(Sn_f))

        elif psds is not None:
            Sn_f = self.interp_psd(f0, psds, walker_inds=walker_inds)
        else:
            Sn_f = get_sensitivity(f0, **noise_kwargs)

        return Sn_f

    def forward(self, amp, f0, **noise_kwargs):
        """Convert amplitude ``amp`` at ``f0`` to SNR; returns ``(rho, f0)``.

        ``noise_kwargs`` go to :meth:`get_Sn_f`; when empty, the defaults
        given at construction are used.
        """
        if not noise_kwargs:
            noise_kwargs = self.noise_kwargs

        Sn_f = self.get_Sn_f(f0, **noise_kwargs)
        rho = amp * self._snr_factor(f0, Sn_f)
        return (rho, f0)
336
+
337
+
338
class GBPriorWrap:
    """Prior wrapper that treats amplitude and frequency jointly.

    Columns 0 (amplitude) and 1 (frequency) are handled by the joint prior
    stored under key ``(0, 1)`` in ``full_prior_container.priors_in``. The
    remaining columns come from ``full_prior_container`` itself through its
    ``keys`` argument.

    Args:
        ndim: Total number of parameters per sample.
        full_prior_container: Container exposing ``logpdf(x, keys=...)``,
            ``rvs(size, keys=...)``, ``priors_in`` and ``use_cupy``.
        gen_frequency_alone: If True, the frequency (column 1) is drawn by the
            container and the joint prior only draws amplitudes conditioned on
            it; otherwise the joint prior draws both.
    """

    def __init__(self, ndim, full_prior_container, gen_frequency_alone=False):
        self.base_prior = full_prior_container
        self.use_cupy = full_prior_container.use_cupy
        self.ndim = ndim
        self.gen_frequency_alone = gen_frequency_alone

        # parameter keys drawn/evaluated by the container rather than the
        # joint (0, 1) prior
        if gen_frequency_alone:
            self.keys_sep = [1, 2, 3, 4, 5, 6, 7]
        else:
            self.keys_sep = [2, 3, 4, 5, 6, 7]

    @property
    def priors_in(self):
        """The container's ``priors_in`` mapping."""
        return self.base_prior.priors_in

    def logpdf(self, x, **noise_kwargs):
        """Return the log prior of each row of the 2-D array ``x``.

        ``noise_kwargs`` are forwarded to the joint amplitude/frequency prior.
        """
        xp = np if not self.use_cupy else cp
        assert x.shape[1] == self.ndim and x.ndim == 2

        logpdf_everything_else = self.base_prior.logpdf(x, keys=self.keys_sep)

        f0 = xp.asarray(x[:, 1])
        amp = xp.asarray(x[:, 0])
        logpdf_A_f = self.base_prior.priors_in[(0, 1)].logpdf(amp, f0, **noise_kwargs)

        return logpdf_A_f + logpdf_everything_else

    def rvs(self, size=1, ignore_amp=False, **kwargs):
        """Draw samples of shape ``size + (ndim,)``.

        Args:
            size: Number of draws (int) or leading output shape (tuple).
            ignore_amp: If True, skip the joint prior and leave whatever the
                container produced in columns 0/1.
            **kwargs: Forwarded to the joint prior's ``rvs``.
        """
        xp = np if not self.use_cupy else cp
        if isinstance(size, int):
            size = (size,)

        arr = xp.zeros(size + (self.ndim,)).reshape(-1, self.ndim)

        diff = self.ndim - len(self.keys_sep)
        assert diff >= 0

        arr[:, :] = self.base_prior.rvs(size, keys=self.keys_sep).reshape(-1, self.ndim)

        if not ignore_amp:
            f0_input = arr[:, 1] if self.gen_frequency_alone else None
            joint = xp.asarray(
                self.base_prior.priors_in[(0, 1)].rvs(size, f0_input=f0_input, **kwargs)
            )
            # The joint (0, 1) prior always returns both columns, (amp, f0).
            # When the frequency was drawn by the container it comes back
            # unchanged, so both columns are written in either mode (sizing
            # this by ``diff`` breaks the reshape when ``diff == 1``).
            n_joint = 2
            arr[:, :n_joint] = joint.reshape(n_joint, -1).T

        arr = arr.reshape(size + (self.ndim,))
        return arr
392
+
393
+
394
class FullGaussianMixtureModel:
    """Gaussian mixture distribution defined on per-group rescaled domains.

    Components are passed in groups (one entry of ``weights``/``means``/...
    per group). Group ``i`` owns the box ``[mins[i], maxs[i]]``. Points are
    mapped linearly onto ``[-1, 1]`` in every dimension of that box
    (:meth:`map_input`) before the group's Gaussians are evaluated, and
    samples drawn in the mapped domain are mapped back with
    :meth:`map_back_frequency`.

    Log-densities are computed by ``gb.compute_logpdf``. To limit the work, a
    point is only evaluated against components whose window in dimension 1
    (``[-limit, +limit]`` in mapped units) contains the point's value in that
    dimension. Points with no such component get ``-inf``.

    Args:
        gb: Object providing ``compute_logpdf``; only used by :meth:`logpdf`.
        weights: Per-group arrays of component weights. They are scaled by
            ``1 / len(weights)`` and the total is asserted to be 1.
        means: Per-group arrays of component means (mapped domain), shape
            ``(n_i, ndim)``.
        covs: Per-group component covariance matrices; used for sampling.
        invcovs: Per-group inverse covariance matrices.
        dets: Per-group covariance determinants.
        mins: Per-group lower box corners, each of length ``ndim``.
        maxs: Per-group upper box corners, each of length ``ndim``.
        limit: Half-width of the dimension-1 window, in mapped units.
        use_cupy: If True, store and operate on CuPy arrays.
    """

    def __init__(
        self,
        gb,
        weights,
        means,
        covs,
        invcovs,
        dets,
        mins,
        maxs,
        limit=10.0,
        use_cupy=False,
    ):
        self.use_cupy = use_cupy
        xp = cp if use_cupy else np

        self.gb = gb

        # component -> group index, so every component can find its box
        indexing = [
            np.full_like(weight, i, dtype=int) for i, weight in enumerate(weights)
        ]
        self.indexing = xp.asarray(np.concatenate(indexing))

        # individual weights / number of groups, so each group is picked
        # uniformly (assuming each group's weights sum to 1)
        self.weights = xp.asarray(np.concatenate(weights, axis=0) * 1 / len(weights))
        assert xp.allclose(self.weights.sum(), 1.0)

        self.means = xp.asarray(np.concatenate(means, axis=0))
        self.covs = xp.asarray(np.concatenate(covs, axis=0))
        self.invcovs = xp.asarray(np.concatenate(invcovs, axis=0))
        self.dets = xp.asarray(np.concatenate(dets, axis=0))
        self.ndim = self.means.shape[1]

        self.mins = xp.asarray(np.vstack(mins))
        self.maxs = xp.asarray(np.vstack(maxs))

        # flattened, parameter-major copies in the layout passed to
        # ``gb.compute_logpdf``
        self.mins_in_pdf = self.mins[self.indexing].T.flatten().copy()
        self.maxs_in_pdf = self.maxs[self.indexing].T.flatten().copy()
        self.means_in_pdf = self.means.T.flatten().copy()
        self.invcovs_in_pdf = self.invcovs.transpose(1, 2, 0).flatten().copy()

        self.cumulative_weights = xp.concatenate(
            [xp.array([0.0]), xp.cumsum(self.weights)]
        )

        # per-component dimension-1 window in original units
        self.min_limit_f = self.map_back_frequency(
            -1.0 * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1]
        )
        self.max_limit_f = self.map_back_frequency(
            +1.0 * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1]
        )

        # log |Jacobian| of ``map_input``: each dimension is scaled by
        # 2 / (max - min)
        self.log_det_J = (
            self.ndim * np.log(2) - xp.sum(xp.log(self.maxs - self.mins), axis=-1)
        )[self.indexing].copy()

        # Cholesky factors of ``covs`` for sampling; filled lazily by
        # ``_cholesky_factors`` so construction does not depend on them
        self._chol_covs = None

    def logpdf(self, x):
        """Return the mixture log-density of each row of the 2-D array ``x``."""
        xp = cp if self.use_cupy else np

        assert len(x.shape) == 2
        assert x.shape[1] == self.ndim

        # sort by dimension 1 so each component's window is a contiguous slice
        inds_sort = xp.argsort(x[:, 1])
        f_sort = x[:, 1][inds_sort]
        points_sorted = x[inds_sort]

        ind_min_limit = xp.searchsorted(f_sort, self.min_limit_f, side="left")
        ind_max_limit = xp.searchsorted(f_sort, self.max_limit_f, side="right")

        # enumerate every (sorted point, component) pair inside a window
        n_in_window = ind_max_limit - ind_min_limit
        cs = xp.concatenate([xp.array([0]), xp.cumsum(n_in_window)])
        pair_index = xp.arange(cs[-1])
        keep_component_map = xp.searchsorted(cs, pair_index, side="right") - 1
        keep_point_map = (
            pair_index - cs[keep_component_map] + ind_min_limit[keep_component_map]
        )

        # group the pairs by point: pack (point, component) into one integer,
        # sort, then unpack with exact integer arithmetic
        int_check = int(1e6)
        assert int_check > self.min_limit_f.shape[0]
        special_point_component_map = int_check * keep_point_map + keep_component_map

        sorted_special = xp.sort(special_point_component_map)

        points_keep_in = sorted_special // int_check
        components_keep_in = sorted_special - points_keep_in * int_check

        unique_points, unique_starts = xp.unique(points_keep_in, return_index=True)
        # CSR-style offsets: the components of point j are
        # components_keep_in[start[j]:start[j + 1]]
        start_index_in_pdf = xp.concatenate(
            [unique_starts, xp.array([len(points_keep_in)])]
        ).astype(xp.int32)
        assert xp.all(xp.diff(unique_starts) > 0)

        points_sorted_in = points_sorted[unique_points]

        logpdf_out_tmp = xp.zeros(points_sorted_in.shape[0])

        self.gb.compute_logpdf(
            logpdf_out_tmp,
            components_keep_in.astype(xp.int32),
            points_sorted_in,
            self.weights,
            self.mins_in_pdf,
            self.maxs_in_pdf,
            self.means_in_pdf,
            self.invcovs_in_pdf,
            self.dets,
            self.log_det_J,
            points_sorted_in.shape[0],
            start_index_in_pdf,
            self.weights.shape[0],
            x.shape[1],
        )

        # scatter back to the caller's row order; rows covered by no component
        # keep -inf
        logpdf_out = xp.full(x.shape[0], -xp.inf)
        logpdf_out[inds_sort[unique_points]] = logpdf_out_tmp
        return logpdf_out

    def map_input(self, x, mins, maxs):
        """Map ``x`` linearly from ``[mins, maxs]`` onto ``[-1, 1]``."""
        return ((x - mins) / (maxs - mins)) * 2.0 - 1.0

    def map_back_frequency(self, x, mins, maxs):
        """Map ``x`` linearly from ``[-1, 1]`` back onto ``[mins, maxs]``."""
        return (x + 1.0) * 1.0 / 2.0 * (maxs - mins) + mins

    def _cholesky_factors(self):
        """Return lower Cholesky factors of ``self.covs`` (computed once, cached)."""
        if getattr(self, "_chol_covs", None) is None:
            xp = cp if self.use_cupy else np
            self._chol_covs = xp.linalg.cholesky(self.covs)
        return self._chol_covs

    def rvs(self, size=(1,)):
        """Draw samples of shape ``size + (ndim,)`` in the original domain."""
        if isinstance(size, int):
            size = (size,)

        xp = cp if self.use_cupy else np

        # choose a component per draw; clip guards against a draw landing
        # beyond a cumulative sum that rounds to slightly below 1
        draw = xp.random.rand(*size)
        component = xp.clip(
            xp.searchsorted(self.cumulative_weights, draw.flatten(), side="right") - 1,
            0,
            self.weights.shape[0] - 1,
        ).reshape(draw.shape)

        mean_here = self.means[component]
        chol_here = self._cholesky_factors()[component]

        # N(mean, cov) draw: mean + L z with L L^T = cov and z ~ N(0, I);
        # multiplying z by ``cov`` itself would give covariance cov @ cov.T
        new_points = mean_here + xp.einsum(
            "...kj,...j->...k",
            chol_here,
            xp.random.randn(*(component.shape + (self.ndim,))),
        )

        index_here = self.indexing[component]
        mins_here = self.mins[index_here]
        maxs_here = self.maxs[index_here]
        new_points_mapped = self.map_back_frequency(new_points, mins_here, maxs_here)

        return new_points_mapped
588
+
589
+
590
+ # class FlowDist:
591
+ # def __init__(self, config: dict, model: Union[Galaxy, GalaxyFFdot], fit: str, ndim: int):
592
+
593
+ # self.dist = model(config)
594
+ # self.dist.load_fit()
595
+
596
+ # param_min, param_max = np.loadtxt(fit)
597
+ # self.dist.set_min(param_min)
598
+ # self.dist.set_max(param_max)
599
+
600
+ # self.config = config
601
+ # self.fit = fit
602
+ # self.ndim = ndim
603
+
604
+ # def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
605
+ # if isinstance(size, int):
606
+ # size = (size,)
607
+
608
+ # total_samp = int(np.prod(size))
609
+ # samples = self.dist.sample(total_samp).reshape(size + (self.ndim,))
610
+ # return samples
611
+
612
+ # def logpdf(self, x: cp.ndarray) -> cp.ndarray:
613
+ # assert x.shape[-1] == self.ndim
614
+ # log_prob = self.dist.log_prob(x.reshape(-1, self.ndim)).reshape(x.shape[:-1])
615
+ # return log_prob
616
+
617
+ # class GalaxyFlowDist(FlowDist):
618
+ # def __init__(self):
619
+ # config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_galaxy.yaml'
620
+ # model = Galaxy
621
+ # fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_galaxy_sangria.txt'
622
+ # ndim = 3
623
+ # super().__init__(config, model, fit, ndim)
624
+
625
+ # def logpdf(self, x: cp.ndarray) -> cp.ndarray:
626
+ # # adjust amplitudes to exp
627
+ # x[:, 0] = np.log(x[:, 0])
628
+ # return super().logpdf(x)
629
+
630
+ # def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
631
+ # if isinstance(size, int):
632
+ # size = (size,)
633
+
634
+ # samples = super().rvs(size=size)
635
+ # samples = samples.reshape(-1, samples.shape[-1])
636
+ # samples[:, 0] = np.exp(samples[:, 0])
637
+ # samples = samples.reshape(size + (samples.shape[-1],))
638
+ # return samples
639
+
640
+ # class FFdotFlowDist(FlowDist):
641
+ # def __init__(self):
642
+ # config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_f.yaml'
643
+ # model = GalaxyFFdot
644
+ # fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_ffdot_sangria.txt'
645
+ # ndim = 2
646
+ # super().__init__(config, model, fit, ndim)