lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lisaanalysistools might be problematic. Click here for more details.
- lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
- lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
- lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
- lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
- lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
- lisatools/__init__.py +0 -0
- lisatools/_version.py +4 -0
- lisatools/analysiscontainer.py +438 -0
- lisatools/cutils/detector.cpython-312-darwin.so +0 -0
- lisatools/datacontainer.py +292 -0
- lisatools/detector.py +410 -0
- lisatools/diagnostic.py +976 -0
- lisatools/glitch.py +193 -0
- lisatools/sampling/__init__.py +0 -0
- lisatools/sampling/likelihood.py +882 -0
- lisatools/sampling/moves/__init__.py +0 -0
- lisatools/sampling/moves/gbgroupstretch.py +53 -0
- lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
- lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
- lisatools/sampling/moves/gbspecialstretch.py +1836 -0
- lisatools/sampling/moves/mbhspecialmove.py +286 -0
- lisatools/sampling/moves/placeholder.py +16 -0
- lisatools/sampling/moves/skymodehop.py +110 -0
- lisatools/sampling/moves/specialforegroundmove.py +564 -0
- lisatools/sampling/prior.py +508 -0
- lisatools/sampling/stopping.py +320 -0
- lisatools/sampling/utility.py +324 -0
- lisatools/sensitivity.py +888 -0
- lisatools/sources/__init__.py +0 -0
- lisatools/sources/emri/__init__.py +1 -0
- lisatools/sources/emri/tdiwaveform.py +72 -0
- lisatools/stochastic.py +291 -0
- lisatools/utils/__init__.py +0 -0
- lisatools/utils/constants.py +40 -0
- lisatools/utils/multigpudataholder.py +730 -0
- lisatools/utils/pointeradjust.py +106 -0
- lisatools/utils/utility.py +240 -0
|
@@ -0,0 +1,508 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy import stats
|
|
3
|
+
|
|
4
|
+
from ..utils.constants import *
|
|
5
|
+
from ..sensitivity import get_sensitivity
|
|
6
|
+
from eryn.moves.multipletry import logsumexp
|
|
7
|
+
|
|
8
|
+
from typing import Union, Optional, Tuple, List
|
|
9
|
+
|
|
10
|
+
import sys

# NOTE: this module used to append a hard-coded, developer-specific directory
# ("/data/mkatz/...") to ``sys.path`` at import time. It was only needed by the
# commented-out ``galaxy`` / ``galaxy_ffdot`` imports below. Mutating the global
# import path on import is a side effect paid by every importer, so it was removed.
# from galaxy_ffdot import GalaxyFFdot
# from galaxy import Galaxy

try:
    import cupy as cp

except ImportError:
    # GPU support is optional. When cupy is unavailable ``cp`` stays undefined,
    # so code in this module must only touch ``cp`` when ``use_cupy`` is true.
    pass
|
|
20
|
+
|
|
21
|
+
class AmplitudeFrequencySNRPrior:
    """Joint prior on (amplitude, frequency) built from a prior on the SNR.

    Frequencies (in mHz, divided by 1e3 before use) come from
    ``frequency_prior``. Amplitudes are obtained by drawing an SNR ``rho`` from
    :class:`SNRPrior` and converting it with :class:`AmplitudeFromSNR`.

    Args:
        rho_star: Scale parameter of the SNR prior.
        frequency_prior: Prior on the frequency in mHz. Must expose
            ``rvs(size=...)``, ``logpdf(f0_ms)`` and a settable ``use_cupy``.
        L: Passed to :class:`AmplitudeFromSNR`, where it enters
            ``f_star = C_SI / (2 pi L)``.
        Tobs: Observation time passed to :class:`AmplitudeFromSNR`.
        use_cupy: If True, operate on cupy arrays instead of numpy arrays.
        **noise_kwargs: Default noise settings forwarded to :class:`AmplitudeFromSNR`.
    """

    def __init__(self, rho_star, frequency_prior, L, Tobs, use_cupy=False, **noise_kwargs):
        self.rho_star = rho_star
        self.frequency_prior = frequency_prior

        self.transform = AmplitudeFromSNR(L, Tobs, use_cupy=use_cupy, **noise_kwargs)
        self.snr_prior = SNRPrior(rho_star, use_cupy=use_cupy)

        # must be after transform and snr_prior: the setter propagates the flag to them
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether array operations use cupy instead of numpy."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy
        self.transform.use_cupy = use_cupy
        self.snr_prior.use_cupy = use_cupy
        self.frequency_prior.use_cupy = use_cupy

    def pdf(self, *args, **noise_kwargs):
        """Return ``exp(logpdf(*args, **noise_kwargs))``."""
        xp = np if not self.use_cupy else cp
        return xp.exp(self.logpdf(*args, **noise_kwargs))

    def logpdf(self, amp, f0_ms, **noise_kwargs):
        """Log density of ``(amp, f0_ms)``.

        Args:
            amp: Amplitudes.
            f0_ms: Frequencies in mHz.
            **noise_kwargs: Forwarded to ``AmplitudeFromSNR.forward``.

        Returns:
            ``log p_amp(amp | f0) + frequency_prior.logpdf(f0_ms)``, where the
            amplitude density is the SNR density times the Jacobian ``|rho / amp|``.
        """
        xp = np if not self.use_cupy else cp

        f0 = f0_ms / 1e3
        rho, f0 = self.transform.forward(amp, f0, **noise_kwargs)

        rho_pdf = self.snr_prior.pdf(rho)

        # rho is linear in amp, so |d rho / d amp| = |rho / amp|
        Jac = xp.abs(rho / amp)

        logpdf_amp = xp.log(xp.abs(Jac * rho_pdf))
        logpdf_f = self.frequency_prior.logpdf(f0_ms)

        return logpdf_amp + logpdf_f

    def rvs(self, size=1, f0_input=None, **noise_kwargs):
        """Draw ``(amp, f0_ms)`` samples.

        Args:
            size: Output shape (int or tuple).
            f0_input: Optional pre-drawn frequencies in mHz. Any array with
                ``prod(size)`` elements is accepted (e.g. a flat column of
                length N for ``size == (N,)``) and is reshaped to ``size``.
            **noise_kwargs: Forwarded to the SNR -> amplitude transform.

        Returns:
            Tuple ``(amp, f0_ms)``, each of shape ``size``.

        Raises:
            ValueError: If ``f0_input`` does not have ``prod(size)`` elements.
        """
        if isinstance(size, int):
            size = (size,)

        if f0_input is None:
            f0_ms = self.frequency_prior.rvs(size=size)
        else:
            f0_ms = f0_input
            if f0_ms.shape != size:
                n_expected = int(np.prod(size))
                if f0_ms.size != n_expected:
                    raise ValueError(
                        f"f0_input has {f0_ms.size} elements but size={size} requires {n_expected}."
                    )
                f0_ms = f0_ms.reshape(size)

        f0 = f0_ms / 1e3

        rho = self.snr_prior.rvs(size=size)

        amp, _ = self.transform(rho, f0, **noise_kwargs)

        return (amp, f0_ms)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class SNRPrior:
    """Prior on the signal-to-noise ratio ``rho``.

    The density is

        p(rho) = 3 rho / (4 rho_star**2 (1 + rho / (4 rho_star))**5)   for rho > 0

    and 0 for ``rho <= 0``. It integrates to one, peaks at ``rho = rho_star``,
    and has the closed-form CDF

        F(rho) = 1 - 256 rho_star**3 (rho + rho_star) / (rho + 4 rho_star)**4.

    Args:
        rho_star: Scale (and mode) of the distribution.
        use_cupy: If True, operate on cupy arrays instead of numpy arrays.
    """

    def __init__(self, rho_star, use_cupy=False):
        self.rho_star = rho_star
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether array operations use cupy instead of numpy."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy

    def _as_float_array(self, rho):
        """Return ``(xp, rho)`` with ``rho`` as an ``xp`` array of floating dtype.

        Integer inputs are promoted to float64 so the masked in-place
        assignments in :meth:`pdf` / :meth:`cdf` do not truncate results to 0.
        Python scalars become 0-d arrays so boolean-mask indexing works.
        """
        xp = np if not self.use_cupy else cp
        rho = xp.asarray(rho)
        if not np.issubdtype(rho.dtype, np.floating):
            rho = rho.astype(np.float64)
        return xp, rho

    def pdf(self, rho):
        """Probability density at ``rho`` (0 where ``rho <= 0``)."""
        xp, rho = self._as_float_array(rho)

        p = xp.zeros_like(rho)
        good = rho > 0.0
        p[good] = 3 * rho[good] / (4 * self.rho_star ** 2 * (1 + rho[good] / (4 * self.rho_star)) ** 5)
        return p

    def logpdf(self, rho):
        """Log density at ``rho`` (``-inf`` where ``rho <= 0``)."""
        xp = np if not self.use_cupy else cp
        return xp.log(self.pdf(rho))

    def cdf(self, rho):
        """Cumulative distribution function at ``rho`` (0 where ``rho <= 0``)."""
        xp, rho = self._as_float_array(rho)

        c = xp.zeros_like(rho)
        good = rho > 0.0
        c[good] = 768 * self.rho_star ** 3 * (1 / (768. * self.rho_star ** 3) - (rho[good] + self.rho_star) / (3. * (rho[good] + 4 * self.rho_star) ** 4))
        return c

    def ppf(self, u):
        """Inverse of :meth:`cdf`: return ``rho`` with ``cdf(rho) == u`` for ``u`` in [0, 1).

        ``cdf(rho) = u`` is the quartic
        ``(1 - u) (rho + 4 rho_star)**4 = 256 rho_star**3 (rho + rho_star)``.
        Its physical root is written here in terms of ``v = 1 - u``. This is
        algebraically the same closed form this class used previously, but no
        expanded polynomial in ``u`` (such as ``-1 + 3u - 3u**2 + u**3``) is
        formed and then cancelled. That cancellation made the old expression
        inaccurate, and eventually nan, as ``u -> 1``.
        """
        xp = np if not self.use_cupy else cp
        u = xp.asarray(u)

        r = self.rho_star
        v = 1.0 - u
        w = xp.cbrt(v * (1.0 + xp.sqrt(u)))
        a = 32.0 * r ** 2 * (1.0 / w + w / v)
        sqrt_a = xp.sqrt(a)
        b = -a + 512.0 * r ** 3 / (v * sqrt_a)

        # b (and rho) can only dip below zero through round-off as u -> 0
        rho = -4.0 * r + 0.5 * (sqrt_a + xp.sqrt(xp.maximum(b, 0.0)))
        return xp.maximum(rho, 0.0)

    def rvs(self, size=1):
        """Draw samples of shape ``size`` by inverse-transform sampling."""
        if isinstance(size, int):
            size = (size,)

        xp = np if not self.use_cupy else cp

        u = xp.random.rand(*size)
        return self.ppf(u)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class AmplitudeFromSNR:
    """Convert between an SNR ``rho`` and an amplitude at frequency ``f0``.

    Uses ``rho = amp * factor`` with

        factor = 1/2 * sqrt(Tobs * sin(f0 / f_star)**2 / Sn(f0)),
        f_star = C_SI / (2 pi L).

    ``Sn(f0)`` is resolved by :meth:`get_Sn_f`.

    Args:
        L: Length entering ``f_star``.
        Tobs: Observation time.
        fd: Optional frequency grid on which PSDs passed as ``psds`` are sampled.
        use_cupy: If True, operate on cupy arrays instead of numpy arrays.
        **noise_kwargs: Default keyword arguments used when a call supplies none.
    """

    def __init__(self, L, Tobs, fd=None, use_cupy=False, **noise_kwargs):
        self.f_star = 1 / (2. * np.pi * L) * C_SI
        self.Tobs = Tobs
        self.noise_kwargs = noise_kwargs

        xp = np if not use_cupy else cp
        self.fd = xp.asarray(fd) if fd is not None else None

        # must come after fd: the setter moves fd to the selected array library
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether array operations use cupy instead of numpy."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy
        if self.fd is None:
            # nothing to move (previously cp.asarray(None) was attempted)
            return
        if use_cupy:
            if not isinstance(self.fd, cp.ndarray):
                self.fd = cp.asarray(self.fd)
        elif not isinstance(self.fd, np.ndarray):
            # Do not reference ``cp`` here: it is undefined when cupy is not
            # installed. A cupy array copies back to host memory via ``.get()``.
            self.fd = self.fd.get() if hasattr(self.fd, "get") else np.asarray(self.fd)

    def interp_psd(self, f0, psds, walker_inds=None):
        """Linearly interpolate tabulated PSDs at frequencies ``f0``.

        Args:
            f0: Frequencies at which to evaluate.
            psds: PSD values sampled on ``self.fd`` along the last axis; promoted
                to 2D so that rows can be selected with ``walker_inds``.
            walker_inds: Row of ``psds`` to use for each ``f0`` (default row 0).

        Returns:
            Interpolated PSD values. Frequencies outside ``[fd[0], fd[-1]]`` are
            linearly extrapolated from the nearest grid segment.

        Raises:
            ValueError: If ``fd`` is unset or has fewer than two points.
        """
        if self.fd is None:
            raise ValueError("interp_psd requires the frequency grid `fd` to be set.")

        xp = np if not self.use_cupy else cp
        psds = xp.atleast_2d(psds)

        if self.use_cupy and not isinstance(self.fd, cp.ndarray):
            self.fd = cp.asarray(self.fd)
        fd = self.fd

        if fd.shape[0] < 2:
            raise ValueError("interp_psd needs at least two frequency grid points.")

        # left grid index of each f0, clipped so that both i and i + 1 are valid
        inds_fd = xp.clip(xp.searchsorted(fd, f0, side="right") - 1, 0, fd.shape[0] - 2)

        if walker_inds is None:
            walker_inds = xp.zeros_like(f0, dtype=int)

        lower = psds[(walker_inds, inds_fd)]
        upper = psds[(walker_inds, inds_fd + 1)]
        return (upper - lower) / (fd[inds_fd + 1] - fd[inds_fd]) * (f0 - fd[inds_fd]) + lower

    def _snr_per_unit_amplitude(self, f0, noise_kwargs):
        """Return ``factor`` such that ``rho = amp * factor`` at ``f0``."""
        if not noise_kwargs:
            noise_kwargs = self.noise_kwargs

        Sn_f = self.get_Sn_f(f0, **noise_kwargs)

        return 1./2. * np.sqrt((self.Tobs * np.sin(f0 / self.f_star) ** 2) / Sn_f)

    def __call__(self, rho, f0, **noise_kwargs):
        """Convert SNR to amplitude. Returns ``(amp, f0)``."""
        amp = rho / self._snr_per_unit_amplitude(f0, noise_kwargs)
        return (amp, f0)

    def get_Sn_f(self, f0, psds=None, walker_inds=None, Sn_f=None, **noise_kwargs):
        """Return the noise PSD at ``f0``.

        Priority: an explicit ``Sn_f`` (checked to match ``f0`` in length and
        type), then interpolation of ``psds`` on ``self.fd``, then
        ``get_sensitivity(f0, **noise_kwargs)``.
        """
        if Sn_f is not None:
            assert len(f0) == len(Sn_f)
            assert isinstance(f0, type(Sn_f))

        elif psds is not None:
            Sn_f = self.interp_psd(f0, psds, walker_inds=walker_inds)
        else:
            Sn_f = get_sensitivity(f0, **noise_kwargs)

        return Sn_f

    def forward(self, amp, f0, **noise_kwargs):
        """Convert amplitude to SNR. Returns ``(rho, f0)``."""
        rho = amp * self._snr_per_unit_amplitude(f0, noise_kwargs)
        return (rho, f0)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class GBPriorWrap:
    """Wrap a prior container so that amplitude and frequency are handled jointly.

    Column 0 (amplitude) and column 1 (frequency) are evaluated together by
    ``base_prior.priors_in[(0, 1)]``. The remaining columns, listed in
    ``keys_sep``, are handled by ``base_prior``. With ``gen_frequency_alone``,
    column 1 is also drawn by ``base_prior`` and only the amplitude comes from
    the joint prior.

    Args:
        ndim: Number of parameter columns.
        full_prior_container: Container exposing ``use_cupy``, ``priors_in``,
            ``logpdf(x, keys=...)`` and ``rvs(size, keys=...)``.
        gen_frequency_alone: If True, draw the frequency from the container and
            pass it to the joint prior as ``f0_input``.
    """

    def __init__(self, ndim, full_prior_container, gen_frequency_alone=False):
        self.base_prior = full_prior_container
        self.use_cupy = full_prior_container.use_cupy
        self.ndim = ndim
        self.gen_frequency_alone = gen_frequency_alone

        # keys evaluated / drawn by the base container
        if gen_frequency_alone:
            self.keys_sep = [1, 2, 3, 4, 5, 6, 7]
        else:
            self.keys_sep = [2, 3, 4, 5, 6, 7]

    @property
    def priors_in(self):
        """The wrapped container's ``priors_in`` mapping."""
        return self.base_prior.priors_in

    def logpdf(self, x, **noise_kwargs):
        """Log prior of a 2D array ``x`` of shape ``(n, ndim)``."""
        xp = np if not self.use_cupy else cp
        assert x.shape[1] == self.ndim and x.ndim == 2

        logpdf_everything_else = self.base_prior.logpdf(x, keys=self.keys_sep)

        f0 = xp.asarray(x[:, 1])
        amp = xp.asarray(x[:, 0])
        # NOTE(review): with gen_frequency_alone=True, key 1 is in keys_sep, so
        # column 1 contributes via base_prior.logpdf and is also passed as f0
        # to priors_in[(0, 1)].logpdf. Confirm this does not count the frequency
        # prior twice.
        logpdf_A_f = self.base_prior.priors_in[(0, 1)].logpdf(amp, f0, **noise_kwargs)

        return logpdf_A_f + logpdf_everything_else

    def rvs(self, size=1, ignore_amp=False, **kwargs):
        """Draw samples of shape ``size + (ndim,)``.

        Args:
            size: Leading sample shape (int or tuple).
            ignore_amp: If True, keep the base container's draws for all
                columns and skip the joint amplitude/frequency prior.
            **kwargs: Forwarded to ``priors_in[(0, 1)].rvs``.
        """
        xp = np if not self.use_cupy else cp
        if isinstance(size, int):
            size = (size,)

        # number of leading columns that the joint prior fills in
        diff = self.ndim - len(self.keys_sep)
        assert diff >= 0

        arr = xp.zeros(size + (self.ndim,)).reshape(-1, self.ndim)
        arr[:, :] = self.base_prior.rvs(size, keys=self.keys_sep).reshape(-1, self.ndim)

        if not ignore_amp:
            f0_input = arr[:, 1] if self.gen_frequency_alone else None
            draws = xp.asarray(self.base_prior.priors_in[(0, 1)].rvs(size, f0_input=f0_input, **kwargs))
            # draws stacks the joint prior's outputs (amplitude first, then
            # frequency). Keep only the `diff` leading outputs: with
            # gen_frequency_alone only the amplitude column is replaced.
            arr[:, :diff] = draws.reshape(draws.shape[0], -1)[:diff].T

        return arr.reshape(size + (self.ndim,))
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class FullGaussianMixtureModel:
    """Gaussian mixture whose components live in per-group rescaled coordinates.

    Each list entry describes one group of components. A group's parameters
    are mapped from ``[mins, maxs]`` to ``[-1, 1]`` (see :meth:`map_input`),
    and its Gaussians are defined in that mapped space. Each group's weights
    are divided by the number of groups, and the concatenated weights must
    sum to 1.

    ``logpdf`` is evaluated by the external kernel ``gb.compute_logpdf``. Only
    (point, component) pairs whose frequency (column 1) lies inside the
    component's window are passed to it. The window is the mapped interval
    ``[-limit, limit]`` mapped back to physical units.

    Args:
        gb: Object exposing ``compute_logpdf`` (the evaluation kernel).
        weights: Per-group 1D arrays of component weights.
        means: Per-group arrays of component means, ``(n_i, ndim)``.
        covs: Per-group arrays used to draw samples in :meth:`rvs`, ``(n_i, ndim, ndim)``.
        invcovs: Per-group inverse covariances, ``(n_i, ndim, ndim)``.
        dets: Per-group covariance determinants (presumably ``(n_i,)`` — TODO confirm).
        mins, maxs: Per-group parameter bounds, each ``(ndim,)``.
        limit: Half-width, in mapped coordinates, of each component's frequency window.
        use_cupy: If True, store and operate on cupy arrays.
    """

    def __init__(self, gb, weights, means, covs, invcovs, dets, mins, maxs, limit=10.0, use_cupy=False):

        self.use_cupy = use_cupy
        xp = cp if use_cupy else np

        self.gb = gb

        # group index of every component
        self.indexing = xp.asarray(
            np.concatenate([np.full_like(weight, i, dtype=int) for i, weight in enumerate(weights)])
        )
        # individual weights / number of groups, so groups are chosen uniformly
        self.weights = xp.asarray(np.concatenate(weights, axis=0) * 1 / len(weights))

        assert xp.allclose(self.weights.sum(), 1.0)
        self.means = xp.asarray(np.concatenate(means, axis=0))
        self.covs = xp.asarray(np.concatenate(covs, axis=0))
        self.invcovs = xp.asarray(np.concatenate(invcovs, axis=0))
        self.dets = xp.asarray(np.concatenate(dets, axis=0))
        self.ndim = self.means.shape[1]

        self.mins = xp.asarray(np.vstack(mins))
        self.maxs = xp.asarray(np.vstack(maxs))

        # flattened, parameter-major layouts handed to the compute_logpdf kernel
        self.mins_in_pdf = self.mins[self.indexing].T.flatten().copy()
        self.maxs_in_pdf = self.maxs[self.indexing].T.flatten().copy()
        self.means_in_pdf = self.means.T.flatten().copy()
        self.invcovs_in_pdf = self.invcovs.transpose(1, 2, 0).flatten().copy()

        self.cumulative_weights = xp.concatenate([xp.array([0.0]), xp.cumsum(self.weights)])

        # physical-frequency window of each component
        self.min_limit_f = self.map_back_frequency(-1. * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1])
        self.max_limit_f = self.map_back_frequency(+1. * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1])

        # log Jacobian of the map [mins, maxs] -> [-1, 1], per component
        self.log_det_J = (self.ndim * np.log(2) - xp.sum(xp.log(self.maxs - self.mins), axis=-1))[self.indexing].copy()

    def logpdf(self, x):
        """Log density of points ``x`` of shape ``(n, ndim)``.

        Points outside every component's frequency window get ``-inf``.
        """
        xp = cp if self.use_cupy else np

        assert len(x.shape) == 2
        assert x.shape[1] == self.ndim

        # sort by frequency so each component's window is a contiguous slice
        inds_sort = xp.argsort(x[:, 1])
        f_sort = x[:, 1][inds_sort]
        points_sorted = x[inds_sort]

        ind_min_limit = xp.searchsorted(f_sort, self.min_limit_f, side="left")
        ind_max_limit = xp.searchsorted(f_sort, self.max_limit_f, side="right")

        # enumerate every (component, sorted point) pair inside the component's window
        diff = ind_max_limit - ind_min_limit
        cs = xp.concatenate([xp.array([0]), xp.cumsum(diff)])
        tmp = xp.arange(cs[-1])
        keep_component_map = xp.searchsorted(cs, tmp, side="right") - 1
        keep_point_map = tmp - cs[keep_component_map] + ind_min_limit[keep_component_map]

        # order pairs by point, then component, through a combined integer key
        int_check = int(1e6)
        assert int_check > self.min_limit_f.shape[0]
        special_point_component_map = int_check * keep_point_map + keep_component_map

        sorted_special = xp.sort(special_point_component_map)

        # exact integer decode of the combined key
        points_keep_in = sorted_special // int_check
        components_keep_in = sorted_special - points_keep_in * int_check

        # start offset of each point's run of components (plus an end sentinel)
        unique_points, unique_starts = xp.unique(points_keep_in, return_index=True)
        start_index_in_pdf = xp.concatenate([unique_starts, xp.array([len(points_keep_in)])]).astype(xp.int32)
        assert xp.all(xp.diff(unique_starts) > 0)

        points_sorted_in = points_sorted[unique_points]

        logpdf_out_tmp = xp.zeros(points_sorted_in.shape[0])

        self.gb.compute_logpdf(logpdf_out_tmp, components_keep_in.astype(xp.int32), points_sorted_in,
                               self.weights, self.mins_in_pdf, self.maxs_in_pdf, self.means_in_pdf, self.invcovs_in_pdf, self.dets, self.log_det_J,
                               points_sorted_in.shape[0], start_index_in_pdf, self.weights.shape[0], x.shape[1])

        # undo the frequency sort; points never inside a window stay at -inf
        logpdf_out = xp.full(x.shape[0], -xp.inf)
        logpdf_out[xp.sort(inds_sort[unique_points])] = logpdf_out_tmp[xp.argsort(inds_sort[unique_points])]
        return logpdf_out

    def map_input(self, x, mins, maxs):
        """Map ``x`` from ``[mins, maxs]`` to ``[-1, 1]``."""
        return ((x - mins) / (maxs - mins)) * 2. - 1.

    def map_back_frequency(self, x, mins, maxs):
        """Map ``x`` from ``[-1, 1]`` back to ``[mins, maxs]`` (inverse of :meth:`map_input`)."""
        return (x + 1.) * 1. / 2. * (maxs - mins) + mins

    def rvs(self, size=(1,)):
        """Draw samples of shape ``size + (ndim,)`` in physical coordinates."""
        if isinstance(size, int):
            size = (size,)

        xp = cp if self.use_cupy else np

        # choose a component for each sample according to the mixture weights
        draw = xp.random.rand(*size)
        component = (xp.searchsorted(self.cumulative_weights, draw.flatten(), side="right") - 1).reshape(draw.shape)

        mean_here = self.means[component]
        cov_here = self.covs[component]

        # Standard normals come from the same array library as the model, so
        # einsum never mixes numpy and cupy operands.
        # NOTE(review): normals are multiplied by ``covs`` directly, which gives
        # covariance covs @ covs.T. That is only the intended covariance if
        # ``covs`` holds a matrix square root (e.g. a Cholesky factor) — confirm.
        noise = xp.random.randn(*(component.shape + (self.ndim,)))
        new_points = mean_here + xp.einsum("...kj,...j->...k", cov_here, noise)

        index_here = self.indexing[component]
        mins_here = self.mins[index_here]
        maxs_here = self.maxs[index_here]
        return self.map_back_frequency(new_points, mins_here, maxs_here)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
# class FlowDist:
|
|
445
|
+
# def __init__(self, config: dict, model: Union[Galaxy, GalaxyFFdot], fit: str, ndim: int):
|
|
446
|
+
|
|
447
|
+
# self.dist = model(config)
|
|
448
|
+
# self.dist.load_fit()
|
|
449
|
+
|
|
450
|
+
# param_min, param_max = np.loadtxt(fit)
|
|
451
|
+
# self.dist.set_min(param_min)
|
|
452
|
+
# self.dist.set_max(param_max)
|
|
453
|
+
|
|
454
|
+
# self.config = config
|
|
455
|
+
# self.fit = fit
|
|
456
|
+
# self.ndim = ndim
|
|
457
|
+
|
|
458
|
+
# def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
|
|
459
|
+
# if isinstance(size, int):
|
|
460
|
+
# size = (size,)
|
|
461
|
+
|
|
462
|
+
# total_samp = int(np.prod(size))
|
|
463
|
+
# samples = self.dist.sample(total_samp).reshape(size + (self.ndim,))
|
|
464
|
+
# return samples
|
|
465
|
+
|
|
466
|
+
# def logpdf(self, x: cp.ndarray) -> cp.ndarray:
|
|
467
|
+
# assert x.shape[-1] == self.ndim
|
|
468
|
+
# log_prob = self.dist.log_prob(x.reshape(-1, self.ndim)).reshape(x.shape[:-1])
|
|
469
|
+
# return log_prob
|
|
470
|
+
|
|
471
|
+
# class GalaxyFlowDist(FlowDist):
|
|
472
|
+
# def __init__(self):
|
|
473
|
+
# config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_galaxy.yaml'
|
|
474
|
+
# model = Galaxy
|
|
475
|
+
# fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_galaxy_sangria.txt'
|
|
476
|
+
# ndim = 3
|
|
477
|
+
# super().__init__(config, model, fit, ndim)
|
|
478
|
+
|
|
479
|
+
# def logpdf(self, x: cp.ndarray) -> cp.ndarray:
|
|
480
|
+
# # adjust amplitudes to exp
|
|
481
|
+
# x[:, 0] = np.log(x[:, 0])
|
|
482
|
+
# return super().logpdf(x)
|
|
483
|
+
|
|
484
|
+
# def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
|
|
485
|
+
# if isinstance(size, int):
|
|
486
|
+
# size = (size,)
|
|
487
|
+
|
|
488
|
+
# samples = super().rvs(size=size)
|
|
489
|
+
# samples = samples.reshape(-1, samples.shape[-1])
|
|
490
|
+
# samples[:, 0] = np.exp(samples[:, 0])
|
|
491
|
+
# samples = samples.reshape(size + (samples.shape[-1],))
|
|
492
|
+
# return samples
|
|
493
|
+
|
|
494
|
+
# class FFdotFlowDist(FlowDist):
|
|
495
|
+
# def __init__(self):
|
|
496
|
+
# config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_f.yaml'
|
|
497
|
+
# model = GalaxyFFdot
|
|
498
|
+
# fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_ffdot_sangria.txt'
|
|
499
|
+
# ndim = 2
|
|
500
|
+
# super().__init__(config, model, fit, ndim)
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
|