lisaanalysistools 1.1.20__cp39-cp39-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lisaanalysistools/git_version.py +7 -0
- lisaanalysistools-1.1.20.dist-info/METADATA +281 -0
- lisaanalysistools-1.1.20.dist-info/RECORD +48 -0
- lisaanalysistools-1.1.20.dist-info/WHEEL +5 -0
- lisaanalysistools-1.1.20.dist-info/licenses/LICENSE +201 -0
- lisatools/.dylibs/libgcc_s.1.1.dylib +0 -0
- lisatools/.dylibs/libstdc++.6.dylib +0 -0
- lisatools/__init__.py +90 -0
- lisatools/_version.py +34 -0
- lisatools/analysiscontainer.py +474 -0
- lisatools/cutils/Detector.cu +307 -0
- lisatools/cutils/Detector.hpp +84 -0
- lisatools/cutils/__init__.py +129 -0
- lisatools/cutils/global.hpp +28 -0
- lisatools/cutils/pycppdetector.pyx +256 -0
- lisatools/datacontainer.py +312 -0
- lisatools/detector.py +867 -0
- lisatools/diagnostic.py +990 -0
- lisatools/git_version.py.in +7 -0
- lisatools/orbit_files/equalarmlength-orbits-best-fit-to-esa.h5 +0 -0
- lisatools/orbit_files/equalarmlength-orbits.h5 +0 -0
- lisatools/orbit_files/esa-trailing-orbits.h5 +0 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/prior.py +646 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +411 -0
- lisatools/sensitivity.py +1554 -0
- lisatools/sources/__init__.py +6 -0
- lisatools/sources/bbh/__init__.py +1 -0
- lisatools/sources/bbh/waveform.py +106 -0
- lisatools/sources/defaultresponse.py +37 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/waveform.py +79 -0
- lisatools/sources/gb/__init__.py +1 -0
- lisatools/sources/gb/waveform.py +69 -0
- lisatools/sources/utils.py +459 -0
- lisatools/sources/waveformbase.py +41 -0
- lisatools/stochastic.py +327 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +54 -0
- lisatools/utils/exceptions.py +95 -0
- lisatools/utils/parallelbase.py +11 -0
- lisatools/utils/utility.py +122 -0
- lisatools_backend_cpu/git_version.py +7 -0
- lisatools_backend_cpu/pycppdetector.cpython-39-darwin.so +0 -0
|
@@ -0,0 +1,646 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy import stats
|
|
3
|
+
|
|
4
|
+
from ..utils.constants import *
|
|
5
|
+
from ..sensitivity import get_sensitivity
|
|
6
|
+
from eryn.moves.multipletry import logsumexp
|
|
7
|
+
|
|
8
|
+
from typing import Union, Optional, Tuple, List
|
|
9
|
+
|
|
10
|
+
import sys
|
|
11
|
+
|
|
12
|
+
sys.path.append(
|
|
13
|
+
"/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/gf_search/"
|
|
14
|
+
)
|
|
15
|
+
# from galaxy_ffdot import GalaxyFFdot
|
|
16
|
+
# from galaxy import Galaxy
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import cupy as cp
|
|
20
|
+
|
|
21
|
+
except (ModuleNotFoundError, ImportError) as e:
|
|
22
|
+
pass
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AmplitudeFrequencySNRPrior:
    """Joint prior over amplitude and frequency built from an SNR prior.

    The amplitude is not sampled directly. An SNR ``rho`` is drawn from
    ``SNRPrior`` and mapped to an amplitude by ``AmplitudeFromSNR`` at the
    sampled frequency. At this interface frequencies are ``f0_ms``, which is
    divided by 1e3 before being handed to the transform (presumably
    mHz -> Hz; confirm against callers).

    Args:
        rho_star: Scale parameter forwarded to ``SNRPrior``.
        frequency_prior: Distribution over ``f0_ms``. It must expose ``rvs``,
            ``logpdf`` and a settable ``use_cupy`` attribute.
        L: Arm length forwarded to ``AmplitudeFromSNR``.
        Tobs: Observation time forwarded to ``AmplitudeFromSNR``.
        use_cupy: If True, CuPy arrays are used instead of NumPy.
        **noise_kwargs: Default noise keyword arguments for the transform.
    """

    def __init__(
        self, rho_star, frequency_prior, L, Tobs, use_cupy=False, **noise_kwargs
    ):
        self.rho_star = rho_star
        self.frequency_prior = frequency_prior

        self.transform = AmplitudeFromSNR(L, Tobs, use_cupy=use_cupy, **noise_kwargs)
        self.snr_prior = SNRPrior(rho_star, use_cupy=use_cupy)

        # must come after transform and snr_prior: the setter propagates the flag
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether CuPy (True) or NumPy (False) arrays are used."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        # keep every sub-component on the same array backend
        self._use_cupy = use_cupy
        self.transform.use_cupy = use_cupy
        self.snr_prior.use_cupy = use_cupy
        self.frequency_prior.use_cupy = use_cupy

    def pdf(self, *args, **noise_kwargs):
        """Return ``exp(logpdf(*args, **noise_kwargs))``."""
        xp = np if not self.use_cupy else cp
        return xp.exp(self.logpdf(*args, **noise_kwargs))

    def logpdf(self, amp, f0_ms, **noise_kwargs):
        """Joint log-density of ``(amp, f0_ms)``.

        The SNR density is converted to an amplitude density with the Jacobian
        ``|rho / amp|``. The transform's ``forward`` computes
        ``rho = amp * factor(f0)``, so this is ``|d rho / d amp|``. The
        frequency prior's log-density is then added.

        Args:
            amp: Amplitudes.
            f0_ms: Frequencies in the same units as ``frequency_prior``.
            **noise_kwargs: Passed to the transform (e.g. ``psds``, ``Sn_f``).

        Returns:
            Array of log-densities. It is ``-inf`` where the SNR density is zero.
        """
        xp = np if not self.use_cupy else cp

        f0 = f0_ms / 1e3
        rho, f0 = self.transform.forward(amp, f0, **noise_kwargs)

        rho_pdf = self.snr_prior.pdf(rho)

        # rho is linear in amp, so |d rho / d amp| = |rho / amp|
        jac = xp.abs(rho / amp)

        logpdf_amp = xp.log(xp.abs(jac * rho_pdf))
        logpdf_f = self.frequency_prior.logpdf(f0_ms)

        return logpdf_amp + logpdf_f

    def rvs(self, size=1, f0_input=None, **noise_kwargs):
        """Draw ``(amp, f0_ms)`` samples.

        Args:
            size: int or tuple giving the sample shape.
            f0_input: Optional frequencies (``f0_ms`` units) to use instead of
                sampling ``frequency_prior``. Give one value per draw, either
                flat or already shaped like ``size``. A trailing extra axis
                (``size + (k,)``) is also accepted and used as-is, as before.
            **noise_kwargs: Passed to the amplitude transform.

        Returns:
            Tuple ``(amp, f0_ms)``.

        Raises:
            ValueError: If ``f0_input`` does not hold one frequency per draw.
        """
        if isinstance(size, int):
            size = (size,)

        xp = np if not self.use_cupy else cp

        if f0_input is None:
            f0_ms = self.frequency_prior.rvs(size=size)
        else:
            f0_ms = xp.asarray(f0_input)
            if f0_ms.shape[:-1] != size:
                # one frequency per draw (e.g. a flat column taken from a
                # (n_samples, ndim) array): align it with the SNR draws
                n_expected = int(np.prod(size))
                if f0_ms.size != n_expected:
                    raise ValueError(
                        f"f0_input must contain one frequency per draw: expected "
                        f"{n_expected} values for size {size}, got shape {f0_ms.shape}"
                    )
                f0_ms = f0_ms.reshape(size)

        f0 = f0_ms / 1e3

        rho = self.snr_prior.rvs(size=size)

        amp, _ = self.transform(rho, f0, **noise_kwargs)

        return (amp, f0_ms)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class SNRPrior:
    """Prior on the SNR ``rho`` with scale ``rho_star``.

    Density (zero for ``rho <= 0``), which peaks at ``rho = rho_star``::

        p(rho) = 3 rho / (4 rho_star**2 * (1 + rho / (4 rho_star))**5)

    Cumulative distribution::

        F(rho) = 1 - 256 rho_star**3 (rho + rho_star) / (rho + 4 rho_star)**4

    Args:
        rho_star: Scale parameter.
        use_cupy: If True, CuPy is used instead of NumPy.
    """

    def __init__(self, rho_star, use_cupy=False):
        self.rho_star = rho_star
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether CuPy (True) or NumPy (False) arrays are used."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy

    def _prepare(self, rho):
        """Return ``(xp, rho)`` with ``rho`` as a floating-point array on the active backend."""
        xp = np if not self.use_cupy else cp
        rho = xp.asarray(rho)
        if rho.dtype.kind in "biu":
            # bool/int input would make the zeros_like output integer and
            # truncate the (fractional) density/CDF values to 0
            rho = rho.astype(xp.float64)
        return xp, rho

    def pdf(self, rho):
        """Probability density at ``rho``; 0 for ``rho <= 0``."""
        xp, rho = self._prepare(rho)

        p = xp.zeros_like(rho)
        good = rho > 0.0
        p[good] = (
            3
            * rho[good]
            / (4 * self.rho_star**2 * (1 + rho[good] / (4 * self.rho_star)) ** 5)
        )
        return p

    def logpdf(self, rho):
        """Log of :meth:`pdf`; ``-inf`` for ``rho <= 0``."""
        xp = np if not self.use_cupy else cp
        return xp.log(self.pdf(rho))

    def cdf(self, rho):
        """Cumulative distribution at ``rho``; 0 for ``rho <= 0``."""
        xp, rho = self._prepare(rho)

        c = xp.zeros_like(rho)
        good = rho > 0.0
        c[good] = (
            768
            * self.rho_star**3
            * (
                1 / (768.0 * self.rho_star**3)
                - (rho[good] + self.rho_star)
                / (3.0 * (rho[good] + 4 * self.rho_star) ** 4)
            )
        )
        return c

    def rvs(self, size=1):
        """Draw samples by inverse-transform sampling.

        ``u ~ U[0, 1)`` is drawn and ``rho`` is computed with a closed-form root
        of ``cdf(rho) = u``, which is a quartic equation in ``rho``. The
        expression is kept verbatim. Re-check numerical agreement with
        :meth:`cdf` before simplifying it.

        Args:
            size: int or tuple giving the sample shape.

        Returns:
            Array of shape ``size``.
        """
        if isinstance(size, int):
            size = (size,)

        xp = np if not self.use_cupy else cp

        u = xp.random.rand(*size)

        rho = (
            -4 * self.rho_star
            + xp.sqrt(
                -32 * self.rho_star**2
                - (32 * (-self.rho_star**2 + u * self.rho_star**2)) / (1 - u)
                + (
                    3072
                    * 2**0.3333333333333333
                    * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
                    * (self.rho_star**4 - u * self.rho_star**4)
                )
                / (
                    (-1 + u) ** 2
                    * xp.cbrt(
                        -1769472 * self.rho_star**6
                        + 1769472 * u * self.rho_star**6
                        - xp.sqrt(
                            3131031158784 * u * self.rho_star**12
                            - 6262062317568 * u**2 * self.rho_star**12
                            + 3131031158784 * u**3 * self.rho_star**12
                        )
                    )
                )
                + xp.cbrt(
                    -1769472 * self.rho_star**6
                    + 1769472 * u * self.rho_star**6
                    - xp.sqrt(
                        3131031158784 * u * self.rho_star**12
                        - 6262062317568 * u**2 * self.rho_star**12
                        + 3131031158784 * u**3 * self.rho_star**12
                    )
                )
                / (3.0 * 2**0.3333333333333333 * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3))
            )
            / 2.0
            + xp.sqrt(
                32 * self.rho_star**2
                + (32 * (-self.rho_star**2 + u * self.rho_star**2)) / (1 - u)
                - (
                    3072
                    * 2**0.3333333333333333
                    * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
                    * (self.rho_star**4 - u * self.rho_star**4)
                )
                / (
                    (-1 + u) ** 2
                    * xp.cbrt(
                        -1769472 * self.rho_star**6
                        + 1769472 * u * self.rho_star**6
                        - xp.sqrt(
                            3131031158784 * u * self.rho_star**12
                            - 6262062317568 * u**2 * self.rho_star**12
                            + 3131031158784 * u**3 * self.rho_star**12
                        )
                    )
                )
                - xp.cbrt(
                    -1769472 * self.rho_star**6
                    + 1769472 * u * self.rho_star**6
                    - xp.sqrt(
                        3131031158784 * u * self.rho_star**12
                        - 6262062317568 * u**2 * self.rho_star**12
                        + 3131031158784 * u**3 * self.rho_star**12
                    )
                )
                / (3.0 * 2**0.3333333333333333 * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3))
                + (2048 * self.rho_star**3 - (2048 * u * self.rho_star**3) / (-1 + u))
                / (
                    4.0
                    * xp.sqrt(
                        -32 * self.rho_star**2
                        - (32 * (-self.rho_star**2 + u * self.rho_star**2)) / (1 - u)
                        + (
                            3072
                            * 2**0.3333333333333333
                            * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
                            * (self.rho_star**4 - u * self.rho_star**4)
                        )
                        / (
                            (-1 + u) ** 2
                            * xp.cbrt(
                                -1769472 * self.rho_star**6
                                + 1769472 * u * self.rho_star**6
                                - xp.sqrt(
                                    3131031158784 * u * self.rho_star**12
                                    - 6262062317568 * u**2 * self.rho_star**12
                                    + 3131031158784 * u**3 * self.rho_star**12
                                )
                            )
                        )
                        + xp.cbrt(
                            -1769472 * self.rho_star**6
                            + 1769472 * u * self.rho_star**6
                            - xp.sqrt(
                                3131031158784 * u * self.rho_star**12
                                - 6262062317568 * u**2 * self.rho_star**12
                                + 3131031158784 * u**3 * self.rho_star**12
                            )
                        )
                        / (
                            3.0
                            * 2**0.3333333333333333
                            * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
                        )
                    )
                )
            )
            / 2.0
        )

        return rho
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class AmplitudeFromSNR:
    """Linear map between a signal amplitude and its SNR at frequency ``f0``.

    ``rho = amp * factor`` with::

        factor = 0.5 * sqrt(Tobs * sin(f0 / f_star)**2 / Sn(f0))
        f_star = C_SI / (2 * pi * L)

    The noise PSD ``Sn(f0)`` comes from ``get_Sn_f``. It is either supplied
    directly as ``Sn_f``, interpolated from tabulated ``psds`` on the grid
    ``fd``, or computed with ``get_sensitivity``.

    Args:
        L: Arm length used for ``f_star``.
        Tobs: Observation time.
        fd: Optional frequency grid on which tabulated ``psds`` are sampled.
        use_cupy: If True, CuPy arrays are used instead of NumPy.
        **noise_kwargs: Default noise keyword arguments, used when a call
            supplies none.
    """

    def __init__(self, L, Tobs, fd=None, use_cupy=False, **noise_kwargs):
        self.f_star = 1 / (2.0 * np.pi * L) * C_SI
        self.Tobs = Tobs
        self.noise_kwargs = noise_kwargs

        xp = np if not use_cupy else cp
        self.fd = xp.asarray(fd) if fd is not None else None

        # must come after fd: the setter moves fd to the selected backend
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether CuPy (True) or NumPy (False) arrays are used."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy
        if self.fd is None:
            return
        if use_cupy:
            if not isinstance(self.fd, cp.ndarray):
                self.fd = cp.asarray(self.fd)
        elif not isinstance(self.fd, np.ndarray) and hasattr(self.fd, "get"):
            # device array: copy back to host. Duck typing keeps the CPU path
            # from touching ``cp``, which is undefined when CuPy is missing.
            self.fd = self.fd.get()

    def interp_psd(self, f0, psds, walker_inds=None):
        """Linearly interpolate tabulated PSDs to the frequencies ``f0``.

        Args:
            f0: Frequencies at which to evaluate.
            psds: PSD values sampled on ``self.fd``. They are promoted to 2-D;
                rows are selected per point by ``walker_inds``.
            walker_inds: Row of ``psds`` to use for each ``f0``. Defaults to
                row 0.

        Returns:
            Interpolated PSD values, one per entry of ``f0``. Frequencies
            outside ``[fd[0], fd[-1]]`` are extrapolated linearly from the
            nearest grid segment.
        """
        assert self.fd is not None
        xp = np if not self.use_cupy else cp
        psds = xp.atleast_2d(psds)

        if self.use_cupy and not isinstance(self.fd, cp.ndarray):
            self.fd = xp.asarray(self.fd)

        # left end of the grid segment containing f0; clipped so that
        # f0 == fd[-1] (and out-of-range f0) still uses a valid segment
        inds_fd = xp.searchsorted(self.fd, f0, side="right") - 1
        inds_fd = xp.clip(inds_fd, 0, len(self.fd) - 2)

        if walker_inds is None:
            walker_inds = xp.zeros_like(f0, dtype=int)

        psd_lo = psds[(walker_inds, inds_fd)]
        psd_hi = psds[(walker_inds, inds_fd + 1)]
        f_lo = self.fd[inds_fd]
        f_hi = self.fd[inds_fd + 1]
        return (psd_hi - psd_lo) / (f_hi - f_lo) * (f0 - f_lo) + psd_lo

    def _snr_factor(self, f0, **noise_kwargs):
        """Return ``rho / amp`` at ``f0``; falls back to the stored noise kwargs when none are given."""
        if noise_kwargs == {}:
            noise_kwargs = self.noise_kwargs

        Sn_f = self.get_Sn_f(f0, **noise_kwargs)

        return 1.0 / 2.0 * np.sqrt((self.Tobs * np.sin(f0 / self.f_star) ** 2) / Sn_f)

    def __call__(self, rho, f0, **noise_kwargs):
        """Convert SNR ``rho`` at ``f0`` to an amplitude; returns ``(amp, f0)``."""
        amp = rho / self._snr_factor(f0, **noise_kwargs)
        return (amp, f0)

    def get_Sn_f(self, f0, psds=None, walker_inds=None, Sn_f=None, **noise_kwargs):
        """Return the noise PSD at ``f0``.

        Precedence:

        1. An explicit ``Sn_f`` is returned as-is, after checking that it
           matches ``f0`` in length and type.
        2. Otherwise ``psds`` are interpolated on ``self.fd`` with
           ``interp_psd``.
        3. Otherwise ``get_sensitivity(f0, **noise_kwargs)`` is used.
        """
        if Sn_f is not None:
            assert len(f0) == len(Sn_f)
            assert isinstance(f0, type(Sn_f))

        elif psds is not None:
            Sn_f = self.interp_psd(f0, psds, walker_inds=walker_inds)
        else:
            Sn_f = get_sensitivity(f0, **noise_kwargs)

        return Sn_f

    def forward(self, amp, f0, **noise_kwargs):
        """Convert amplitude ``amp`` at ``f0`` to an SNR; returns ``(rho, f0)``."""
        rho = amp * self._snr_factor(f0, **noise_kwargs)
        return (rho, f0)
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
class GBPriorWrap:
    """Wrap a prior container so amplitude and frequency are handled jointly.

    Columns 0 (amplitude) and 1 (frequency) are governed by the joint prior
    ``full_prior_container.priors_in[(0, 1)]``. That prior may take noise
    keyword arguments. The columns listed in ``keys_sep`` are handled by the
    container itself.

    Args:
        ndim: Number of parameters per sample.
        full_prior_container: Container exposing ``use_cupy``, ``priors_in``,
            and ``logpdf``/``rvs`` methods that accept a ``keys`` argument.
        gen_frequency_alone: If True, the container also draws column 1
            (frequency). That column is passed to the joint prior as
            ``f0_input``, and the joint prior then fills only the leading
            ``ndim - len(keys_sep)`` columns.
    """

    def __init__(self, ndim, full_prior_container, gen_frequency_alone=False):
        self.base_prior = full_prior_container
        self.use_cupy = full_prior_container.use_cupy
        self.ndim = ndim
        self.gen_frequency_alone = gen_frequency_alone

        if gen_frequency_alone:
            self.keys_sep = [1, 2, 3, 4, 5, 6, 7]
        else:
            self.keys_sep = [2, 3, 4, 5, 6, 7]

    @property
    def priors_in(self):
        """The wrapped container's ``priors_in`` mapping."""
        return self.base_prior.priors_in

    def logpdf(self, x, **noise_kwargs):
        """Log-prior of samples ``x`` with shape ``(n, ndim)``.

        The result is the sum of the container's log-prior over ``keys_sep``
        and the joint ``(0, 1)`` prior evaluated at ``(x[:, 0], x[:, 1])``.
        """
        xp = np if not self.use_cupy else cp
        assert x.shape[1] == self.ndim and x.ndim == 2

        # NOTE(review): when gen_frequency_alone is True, keys_sep contains
        # column 1. The frequency density may then be counted both here and
        # inside the joint (0, 1) prior — confirm this is intended.
        logpdf_everything_else = self.base_prior.logpdf(x, keys=self.keys_sep)

        f0 = xp.asarray(x[:, 1])
        amp = xp.asarray(x[:, 0])
        logpdf_A_f = self.base_prior.priors_in[(0, 1)].logpdf(amp, f0, **noise_kwargs)

        return logpdf_A_f + logpdf_everything_else

    def rvs(self, size=1, ignore_amp=False, **kwargs):
        """Draw samples of shape ``size + (ndim,)``.

        Args:
            size: int or tuple giving the sample shape.
            ignore_amp: If True, skip the joint ``(0, 1)`` prior. Its leading
                columns are then left as returned by the container.
            **kwargs: Forwarded to the joint prior's ``rvs``.
        """
        xp = np if not self.use_cupy else cp
        if isinstance(size, int):
            size = (size,)

        # number of leading columns owned by the joint (amp, f0) prior
        n_joint = self.ndim - len(self.keys_sep)
        assert n_joint >= 0

        arr = xp.zeros(size + (self.ndim,)).reshape(-1, self.ndim)
        arr[:, :] = self.base_prior.rvs(size, keys=self.keys_sep).reshape(-1, self.ndim)

        if not ignore_amp:
            f0_input = arr[:, 1] if self.gen_frequency_alone else None
            joint = xp.asarray(
                self.base_prior.priors_in[(0, 1)].rvs(size, f0_input=f0_input, **kwargs)
            )
            # (n_outputs, *size) -> (n_samples, n_outputs). The joint prior
            # returns (amp, f0), but only the first n_joint columns belong to it.
            joint = joint.reshape(joint.shape[0], -1).T
            arr[:, :n_joint] = joint[:, :n_joint]

        return arr.reshape(size + (self.ndim,))
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
class FullGaussianMixtureModel:
    """Gaussian mixture defined in per-group rescaled coordinates.

    Several groups of mixture components are concatenated. Group ``g`` has
    bounds ``mins[g]``/``maxs[g]``. Its components live in the coordinates
    ``2 * (x - min) / (max - min) - 1`` (see ``map_input``). Each group's
    weights are scaled by ``1 / n_groups``, and the concatenated weights must
    sum to 1.

    The log-density itself is evaluated by ``gb.compute_logpdf``. For each
    point, only the components whose frequency window contains it are passed
    to that routine. The window is column 1, within ``[-limit, limit]`` in
    mapped coordinates. Points with no such component get ``-inf``.

    Args:
        gb: Object providing ``compute_logpdf``.
        weights, means, covs, invcovs, dets: Per-group sequences, each
            concatenated along axis 0. ``means`` rows have length ``ndim``.
        mins, maxs: Per-group lower/upper bounds, stacked with ``np.vstack``.
        limit: Half-width, in mapped coordinates, of each component's
            frequency window.
        use_cupy: If True, arrays are stored with CuPy.
    """

    def __init__(
        self,
        gb,
        weights,
        means,
        covs,
        invcovs,
        dets,
        mins,
        maxs,
        limit=10.0,
        use_cupy=False,
    ):
        self.use_cupy = use_cupy
        xp = cp if use_cupy else np

        self.gb = gb

        # group index of every component, aligned with the concatenated arrays
        indexing = [
            np.full_like(weight, i, dtype=int) for i, weight in enumerate(weights)
        ]
        self.indexing = xp.asarray(np.concatenate(indexing))

        # individual weights / number of groups, so groups are chosen uniformly
        self.weights = xp.asarray(np.concatenate(weights, axis=0) * 1 / len(weights))

        assert xp.allclose(self.weights.sum(), 1.0)
        self.means = xp.asarray(np.concatenate(means, axis=0))
        self.covs = xp.asarray(np.concatenate(covs, axis=0))
        self.invcovs = xp.asarray(np.concatenate(invcovs, axis=0))
        self.dets = xp.asarray(np.concatenate(dets, axis=0))
        self.ndim = self.means.shape[1]

        self.mins = xp.asarray(np.vstack(mins))
        self.maxs = xp.asarray(np.vstack(maxs))

        # flattened per-component layouts handed to gb.compute_logpdf
        self.mins_in_pdf = self.mins[self.indexing].T.flatten().copy()
        self.maxs_in_pdf = self.maxs[self.indexing].T.flatten().copy()
        self.means_in_pdf = self.means.T.flatten().copy()
        self.invcovs_in_pdf = self.invcovs.transpose(1, 2, 0).flatten().copy()

        self.cumulative_weights = xp.concatenate(
            [xp.array([0.0]), xp.cumsum(self.weights)]
        )

        # frequency (column 1) window of each component in original coordinates
        self.min_limit_f = self.map_back_frequency(
            -1.0 * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1]
        )
        self.max_limit_f = self.map_back_frequency(
            +1.0 * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1]
        )

        # log|Jacobian| of x -> 2 (x - min) / (max - min) - 1, per component
        self.log_det_J = (
            self.ndim * np.log(2) - xp.sum(xp.log(self.maxs - self.mins), axis=-1)
        )[self.indexing].copy()

    def logpdf(self, x):
        """Log-density of points ``x`` with shape ``(n, ndim)``.

        Points with no component whose frequency window contains them get
        ``-inf``.
        """
        xp = cp if self.use_cupy else np

        assert len(x.shape) == 2
        assert x.shape[1] == self.ndim

        # sort by frequency so each component's window is a contiguous range
        inds_sort = xp.argsort(x[:, 1])
        f_sort = x[:, 1][inds_sort]
        points_sorted = x[inds_sort]

        ind_min_limit = xp.searchsorted(f_sort, self.min_limit_f, side="left")
        ind_max_limit = xp.searchsorted(f_sort, self.max_limit_f, side="right")

        # enumerate every (component, sorted point) pair inside a window
        diff = ind_max_limit - ind_min_limit
        cs = xp.concatenate([xp.array([0]), xp.cumsum(diff)])
        tmp = xp.arange(cs[-1])
        keep_component_map = xp.searchsorted(cs, tmp, side="right") - 1
        keep_point_map = (
            tmp - cs[keep_component_map] + ind_min_limit[keep_component_map]
        )

        # pack (point, component) into one integer so one sort groups the
        # pairs by point, then by component
        int_check = int(1e6)
        assert int_check > self.min_limit_f.shape[0]
        special_point_component_map = int_check * keep_point_map + keep_component_map

        sorted_special = xp.sort(special_point_component_map)

        # integer decode (avoids a float round-trip)
        points_keep_in = sorted_special // int_check
        components_keep_in = sorted_special - points_keep_in * int_check

        # start offset of each distinct point within the grouped pair list
        unique_points, unique_starts = xp.unique(points_keep_in, return_index=True)
        start_index_in_pdf = xp.concatenate(
            [unique_starts, xp.array([len(points_keep_in)])]
        ).astype(xp.int32)
        assert xp.all(xp.diff(unique_starts) > 0)

        points_sorted_in = points_sorted[unique_points]

        # output buffer, presumably filled in place by gb.compute_logpdf
        logpdf_out_tmp = xp.zeros(points_sorted_in.shape[0])

        self.gb.compute_logpdf(
            logpdf_out_tmp,
            components_keep_in.astype(xp.int32),
            points_sorted_in,
            self.weights,
            self.mins_in_pdf,
            self.maxs_in_pdf,
            self.means_in_pdf,
            self.invcovs_in_pdf,
            self.dets,
            self.log_det_J,
            points_sorted_in.shape[0],
            start_index_in_pdf,
            self.weights.shape[0],
            x.shape[1],
        )

        # undo the frequency sort; uncovered points keep -inf
        logpdf_out = xp.full(x.shape[0], -xp.inf)
        logpdf_out[xp.sort(inds_sort[unique_points])] = logpdf_out_tmp[
            xp.argsort(inds_sort[unique_points])
        ]
        return logpdf_out

    def map_input(self, x, mins, maxs):
        """Map ``x`` from ``[mins, maxs]`` to ``[-1, 1]``."""
        return ((x - mins) / (maxs - mins)) * 2.0 - 1.0

    def map_back_frequency(self, x, mins, maxs):
        """Inverse of ``map_input``: map ``x`` from ``[-1, 1]`` back to ``[mins, maxs]``."""
        return (x + 1.0) * 1.0 / 2.0 * (maxs - mins) + mins

    def rvs(self, size=(1,)):
        """Draw samples of shape ``size + (ndim,)`` in original coordinates."""
        if isinstance(size, int):
            size = (size,)

        xp = cp if self.use_cupy else np

        # pick a component per draw by inverting the cumulative weights; clip
        # so round-off in the weight sum cannot yield an out-of-range index
        draw = xp.random.rand(*size)
        component = (
            xp.searchsorted(self.cumulative_weights, draw.flatten(), side="right") - 1
        ).reshape(draw.shape)
        component = xp.clip(component, 0, self.weights.shape[0] - 1)

        mean_here = self.means[component]
        cov_here = self.covs[component]

        # NOTE(review): the standard-normal draw is multiplied by the stored
        # ``covs`` matrix directly. If ``covs`` are covariance matrices (as the
        # separate ``invcovs``/``dets`` suggest), a matrix square root such as a
        # Cholesky factor is needed for samples to match logpdf — confirm.
        new_points = mean_here + xp.einsum(
            "...kj,...j->...k",
            cov_here,
            xp.random.randn(*(component.shape + (self.ndim,))),
        )

        # map from rescaled coordinates back to each component's group bounds
        index_here = self.indexing[component]
        mins_here = self.mins[index_here]
        maxs_here = self.maxs[index_here]
        new_points_mapped = self.map_back_frequency(new_points, mins_here, maxs_here)

        return new_points_mapped
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
# class FlowDist:
|
|
591
|
+
# def __init__(self, config: dict, model: Union[Galaxy, GalaxyFFdot], fit: str, ndim: int):
|
|
592
|
+
|
|
593
|
+
# self.dist = model(config)
|
|
594
|
+
# self.dist.load_fit()
|
|
595
|
+
|
|
596
|
+
# param_min, param_max = np.loadtxt(fit)
|
|
597
|
+
# self.dist.set_min(param_min)
|
|
598
|
+
# self.dist.set_max(param_max)
|
|
599
|
+
|
|
600
|
+
# self.config = config
|
|
601
|
+
# self.fit = fit
|
|
602
|
+
# self.ndim = ndim
|
|
603
|
+
|
|
604
|
+
# def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
|
|
605
|
+
# if isinstance(size, int):
|
|
606
|
+
# size = (size,)
|
|
607
|
+
|
|
608
|
+
# total_samp = int(np.prod(size))
|
|
609
|
+
# samples = self.dist.sample(total_samp).reshape(size + (self.ndim,))
|
|
610
|
+
# return samples
|
|
611
|
+
|
|
612
|
+
# def logpdf(self, x: cp.ndarray) -> cp.ndarray:
|
|
613
|
+
# assert x.shape[-1] == self.ndim
|
|
614
|
+
# log_prob = self.dist.log_prob(x.reshape(-1, self.ndim)).reshape(x.shape[:-1])
|
|
615
|
+
# return log_prob
|
|
616
|
+
|
|
617
|
+
# class GalaxyFlowDist(FlowDist):
|
|
618
|
+
# def __init__(self):
|
|
619
|
+
# config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_galaxy.yaml'
|
|
620
|
+
# model = Galaxy
|
|
621
|
+
# fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_galaxy_sangria.txt'
|
|
622
|
+
# ndim = 3
|
|
623
|
+
# super().__init__(config, model, fit, ndim)
|
|
624
|
+
|
|
625
|
+
# def logpdf(self, x: cp.ndarray) -> cp.ndarray:
|
|
626
|
+
# # adjust amplitudes to exp
|
|
627
|
+
# x[:, 0] = np.log(x[:, 0])
|
|
628
|
+
# return super().logpdf(x)
|
|
629
|
+
|
|
630
|
+
# def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
|
|
631
|
+
# if isinstance(size, int):
|
|
632
|
+
# size = (size,)
|
|
633
|
+
|
|
634
|
+
# samples = super().rvs(size=size)
|
|
635
|
+
# samples = samples.reshape(-1, samples.shape[-1])
|
|
636
|
+
# samples[:, 0] = np.exp(samples[:, 0])
|
|
637
|
+
# samples = samples.reshape(size + (samples.shape[-1],))
|
|
638
|
+
# return samples
|
|
639
|
+
|
|
640
|
+
# class FFdotFlowDist(FlowDist):
|
|
641
|
+
# def __init__(self):
|
|
642
|
+
# config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_f.yaml'
|
|
643
|
+
# model = GalaxyFFdot
|
|
644
|
+
# fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_ffdot_sangria.txt'
|
|
645
|
+
# ndim = 2
|
|
646
|
+
# super().__init__(config, model, fit, ndim)
|