lisaanalysistools 1.0.0__cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lisaanalysistools might be problematic. Click here for more details.

Files changed (37):
  1. lisaanalysistools-1.0.0.dist-info/LICENSE +201 -0
  2. lisaanalysistools-1.0.0.dist-info/METADATA +80 -0
  3. lisaanalysistools-1.0.0.dist-info/RECORD +37 -0
  4. lisaanalysistools-1.0.0.dist-info/WHEEL +5 -0
  5. lisaanalysistools-1.0.0.dist-info/top_level.txt +2 -0
  6. lisatools/__init__.py +0 -0
  7. lisatools/_version.py +4 -0
  8. lisatools/analysiscontainer.py +438 -0
  9. lisatools/cutils/detector.cpython-312-darwin.so +0 -0
  10. lisatools/datacontainer.py +292 -0
  11. lisatools/detector.py +410 -0
  12. lisatools/diagnostic.py +976 -0
  13. lisatools/glitch.py +193 -0
  14. lisatools/sampling/__init__.py +0 -0
  15. lisatools/sampling/likelihood.py +882 -0
  16. lisatools/sampling/moves/__init__.py +0 -0
  17. lisatools/sampling/moves/gbgroupstretch.py +53 -0
  18. lisatools/sampling/moves/gbmultipletryrj.py +1287 -0
  19. lisatools/sampling/moves/gbspecialgroupstretch.py +671 -0
  20. lisatools/sampling/moves/gbspecialstretch.py +1836 -0
  21. lisatools/sampling/moves/mbhspecialmove.py +286 -0
  22. lisatools/sampling/moves/placeholder.py +16 -0
  23. lisatools/sampling/moves/skymodehop.py +110 -0
  24. lisatools/sampling/moves/specialforegroundmove.py +564 -0
  25. lisatools/sampling/prior.py +508 -0
  26. lisatools/sampling/stopping.py +320 -0
  27. lisatools/sampling/utility.py +324 -0
  28. lisatools/sensitivity.py +888 -0
  29. lisatools/sources/__init__.py +0 -0
  30. lisatools/sources/emri/__init__.py +1 -0
  31. lisatools/sources/emri/tdiwaveform.py +72 -0
  32. lisatools/stochastic.py +291 -0
  33. lisatools/utils/__init__.py +0 -0
  34. lisatools/utils/constants.py +40 -0
  35. lisatools/utils/multigpudataholder.py +730 -0
  36. lisatools/utils/pointeradjust.py +106 -0
  37. lisatools/utils/utility.py +240 -0
@@ -0,0 +1,508 @@
1
+ import numpy as np
2
+ from scipy import stats
3
+
4
+ from ..utils.constants import *
5
+ from ..sensitivity import get_sensitivity
6
+ from eryn.moves.multipletry import logsumexp
7
+
8
+ from typing import Union, Optional, Tuple, List
9
+
10
+ import sys
11
+ sys.path.append("/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/gf_search/")
12
+ # from galaxy_ffdot import GalaxyFFdot
13
+ # from galaxy import Galaxy
14
+
15
+ try:
16
+ import cupy as cp
17
+
18
+ except (ModuleNotFoundError, ImportError) as e:
19
+ pass
20
+
21
class AmplitudeFrequencySNRPrior:
    """Joint prior over (amplitude, frequency) expressed through an SNR prior.

    The amplitude is mapped to an SNR ``rho`` with :class:`AmplitudeFromSNR`,
    ``rho`` is scored or drawn with :class:`SNRPrior`, and the frequency is
    scored or drawn with the user-supplied ``frequency_prior``.

    Args:
        rho_star: Scale parameter handed to :class:`SNRPrior`.
        frequency_prior: Object providing ``logpdf(f0_ms)`` and
            ``rvs(size=...)``; its ``use_cupy`` attribute is overwritten by
            this class.
        L: Forwarded to :class:`AmplitudeFromSNR`.
        Tobs: Forwarded to :class:`AmplitudeFromSNR`.
        use_cupy: If ``True`` the CuPy backend is selected.
        **noise_kwargs: Default noise keyword arguments stored on the
            amplitude/SNR transform.
    """

    def __init__(self, rho_star, frequency_prior, L, Tobs, use_cupy=False, **noise_kwargs):
        self.rho_star = rho_star
        self.frequency_prior = frequency_prior

        self.transform = AmplitudeFromSNR(L, Tobs, use_cupy=use_cupy, **noise_kwargs)
        self.snr_prior = SNRPrior(rho_star, use_cupy=use_cupy)

        # Assigned last on purpose: the setter pushes the flag onto
        # ``transform``, ``snr_prior`` and ``frequency_prior``.
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether the CuPy backend is selected."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy
        # Keep every wrapped component on the same backend.
        self.transform.use_cupy = use_cupy
        self.snr_prior.use_cupy = use_cupy
        self.frequency_prior.use_cupy = use_cupy

    def pdf(self, *args, **noise_kwargs):
        """Return ``exp(logpdf(*args, **noise_kwargs))``."""
        log_density = self.logpdf(*args, **noise_kwargs)
        return np.exp(log_density)

    def logpdf(self, amp, f0_ms, **noise_kwargs):
        """Evaluate the joint log-density at ``(amp, f0_ms)``.

        Args:
            amp: Amplitude values.
            f0_ms: Frequency values; divided by ``1e3`` before being passed
                to the transform (presumably mHz -> Hz -- TODO confirm).
            **noise_kwargs: Forwarded to ``self.transform.forward``.

        Returns:
            ``log(|rho / amp| * p_snr(rho)) + frequency_prior.logpdf(f0_ms)``.
        """
        xp = cp if self.use_cupy else np

        rho, _ = self.transform.forward(amp, f0_ms / 1e3, **noise_kwargs)
        snr_density = self.snr_prior.pdf(rho)

        # rho is linear in amp, so |d rho / d amp| = |rho / amp|.
        jacobian = xp.abs(rho / amp)

        amp_term = np.log(np.abs(jacobian * snr_density))
        freq_term = self.frequency_prior.logpdf(f0_ms)
        return amp_term + freq_term

    def rvs(self, size=1, f0_input=None, **noise_kwargs):
        """Draw ``(amp, f0_ms)`` samples.

        Args:
            size: Number of draws (``int``) or shape tuple.
            f0_input: Optional pre-drawn frequencies used instead of sampling
                ``frequency_prior``.
            **noise_kwargs: Forwarded to the amplitude/SNR transform.

        Returns:
            Tuple ``(amp, f0_ms)``.
        """
        if isinstance(size, int):
            size = (size,)

        if f0_input is not None:
            f0_ms = f0_input
            # NOTE(review): this compares all but the LAST axis of f0_input
            # with ``size``; for a 1-D ``f0_input`` that is ``()``. Confirm
            # the intended shape contract with callers.
            assert f0_input.shape[:-1] == size
        else:
            f0_ms = self.frequency_prior.rvs(size=size)

        rho = self.snr_prior.rvs(size=size)
        amp, _ = self.transform(rho, f0_ms / 1e3, **noise_kwargs)
        return (amp, f0_ms)
81
+
82
+
83
+
84
+
85
+
86
+
87
class SNRPrior:
    """One-parameter prior on the signal-to-noise ratio ``rho``.

    For ``rho > 0`` the density is::

        p(rho) = 3 rho / (4 rho_star**2 * (1 + rho / (4 rho_star))**5)

    and it is zero otherwise.

    Args:
        rho_star: Scale parameter of the density.
        use_cupy: If ``True`` arrays are created with CuPy instead of NumPy.
    """

    def __init__(self, rho_star, use_cupy=False):
        self.rho_star = rho_star
        self.use_cupy = use_cupy

    @property
    def use_cupy(self):
        """Whether the CuPy backend is selected."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy

    @staticmethod
    def _as_float_array(values, xp):
        """Return ``values`` as a floating-point ``xp`` array.

        Integer (or boolean) input is promoted to ``float`` so that the
        output buffers allocated with ``zeros_like`` cannot truncate the
        computed densities to integers.
        """
        arr = xp.asarray(values)
        if arr.dtype.kind != "f":
            arr = arr.astype(float)
        return arr

    def pdf(self, rho):
        """Evaluate the density at ``rho`` (zero where ``rho <= 0``)."""
        xp = cp if self.use_cupy else np
        rho = self._as_float_array(rho, xp)

        p = xp.zeros_like(rho)
        good = rho > 0.0
        rho_good = rho[good]
        p[good] = 3 * rho_good / (4 * self.rho_star ** 2 * (1 + rho_good / (4 * self.rho_star)) ** 5)
        return p

    def logpdf(self, rho):
        """Return ``log(pdf(rho))``; ``-inf`` where the density is zero."""
        xp = cp if self.use_cupy else np
        return xp.log(self.pdf(rho))

    def cdf(self, rho):
        """Evaluate the cumulative distribution at ``rho``.

        Algebraically this is
        ``1 - 256 rho_star**3 (rho + rho_star) / (rho + 4 rho_star)**4`` for
        ``rho > 0`` and zero otherwise.
        """
        xp = cp if self.use_cupy else np
        rho = self._as_float_array(rho, xp)

        c = xp.zeros_like(rho)
        good = rho > 0.0
        rho_good = rho[good]
        c[good] = 768 * self.rho_star ** 3 * (
            1 / (768. * self.rho_star ** 3)
            - (rho_good + self.rho_star) / (3. * (rho_good + 4 * self.rho_star) ** 4)
        )
        return c

    def rvs(self, size=1):
        """Draw samples by inverse-transform sampling.

        Uniform deviates ``u`` are pushed through a closed-form inverse of
        :meth:`cdf`. The repeated sub-expressions of that closed form are
        evaluated once and reused; the arithmetic and its evaluation order
        are otherwise unchanged.

        Args:
            size: Number of draws (``int``) or shape tuple.

        Returns:
            Array of SNR samples with shape ``size``.
        """
        if isinstance(size, int):
            size = (size,)

        xp = cp if self.use_cupy else np
        rs = self.rho_star

        u = xp.random.rand(*size)

        cbrt_two = 2**0.3333333333333333
        # cbrt((u - 1)**3), written in expanded form.
        cb_u = xp.cbrt(-1 + 3*u - 3*u**2 + u**3)
        cb_res = xp.cbrt(
            -1769472*rs**6 + 1769472*u*rs**6
            - xp.sqrt(3131031158784*u*rs**12 - 6262062317568*u**2*rs**12
                      + 3131031158784*u**3*rs**12)
        )

        lin = (32*(-rs**2 + u*rs**2))/(1 - u)
        frac_a = (3072*cbrt_two*cb_u*(rs**4 - u*rs**4))/((-1 + u)**2*cb_res)
        frac_b = cb_res/(3.*cbrt_two*cb_u)

        inner = -32*rs**2 - lin + frac_a + frac_b
        outer = (
            32*rs**2 + lin - frac_a - frac_b
            + (2048*rs**3 - (2048*u*rs**3)/(-1 + u))/(4.*xp.sqrt(inner))
        )

        rho = (-4*rs + xp.sqrt(inner)/2. + xp.sqrt(outer)/2.)
        return rho
163
+
164
+
165
class AmplitudeFromSNR:
    """Convert between amplitude and SNR at a given frequency.

    Both directions use the same factor::

        factor = 0.5 * sqrt(Tobs * sin(f0 / f_star)**2 / Sn(f0))

    with ``rho = amp * factor``.

    Args:
        L: Length used to build ``f_star = C_SI / (2 pi L)`` (presumably the
            detector arm length in metres -- TODO confirm).
        Tobs: Observation time entering the factor above.
        fd: Optional frequency grid used by :meth:`interp_psd`.
        use_cupy: If ``True`` the frequency grid is kept as a CuPy array.
        **noise_kwargs: Default keyword arguments for :meth:`get_Sn_f`, used
            whenever a call supplies none of its own.
    """

    def __init__(self, L, Tobs, fd=None, use_cupy=False, **noise_kwargs):
        self.f_star = 1 / (2. * np.pi * L) * C_SI
        self.Tobs = Tobs
        self.noise_kwargs = noise_kwargs

        xp = cp if use_cupy else np
        self.fd = xp.asarray(fd) if fd is not None else None

        # Must come after ``fd``: the setter moves ``fd`` between backends.
        self.use_cupy = use_cupy

    @staticmethod
    def _is_cupy_array(arr):
        """Return ``True`` if ``arr`` is a CuPy array.

        Returns ``False`` instead of raising when CuPy could not be imported
        (the module-level import is optional, so ``cp`` may be undefined).
        """
        try:
            return isinstance(arr, cp.ndarray)
        except NameError:
            return False

    @property
    def use_cupy(self):
        """Whether the CuPy backend is selected."""
        return self._use_cupy

    @use_cupy.setter
    def use_cupy(self, use_cupy):
        self._use_cupy = use_cupy
        if self.fd is None:
            # Nothing to move between host and device.
            return

        on_device = self._is_cupy_array(self.fd)
        if use_cupy and not on_device:
            self.fd = cp.asarray(self.fd)
        elif not use_cupy and on_device:
            self.fd = self.fd.get()

    def interp_psd(self, f0, psds, walker_inds=None):
        """Linearly interpolate tabulated PSDs onto the frequencies ``f0``.

        Args:
            f0: Frequencies to evaluate at.
            psds: PSD values tabulated on ``self.fd``; promoted to 2-D, with
                the first axis selected by ``walker_inds``.
            walker_inds: Row of ``psds`` to use for each entry of ``f0``.
                Defaults to row 0 for every entry.

        Returns:
            Interpolated PSD values, one per entry of ``f0``. Frequencies
            outside ``[fd[0], fd[-1]]`` are linearly extrapolated from the
            nearest grid interval.

        Raises:
            AssertionError: If no frequency grid ``fd`` was provided.
        """
        assert self.fd is not None, "interp_psd requires a frequency grid `fd`."
        xp = cp if self.use_cupy else np
        psds = xp.atleast_2d(psds)

        if self.use_cupy and not self._is_cupy_array(self.fd):
            self.fd = xp.asarray(self.fd)
        fd = self.fd

        # Index of the grid point at or below each f0, restricted so that
        # ``inds_fd + 1`` is always a valid index (covers f0 == fd[-1] and
        # prevents negative indices from wrapping around for f0 < fd[0]).
        inds_fd = xp.searchsorted(fd, f0, side="right") - 1
        inds_fd = xp.clip(inds_fd, 0, fd.shape[0] - 2)

        if walker_inds is None:
            walker_inds = xp.zeros_like(f0, dtype=int)

        lower = psds[(walker_inds, inds_fd)]
        upper = psds[(walker_inds, inds_fd + 1)]
        return (upper - lower) / (fd[inds_fd + 1] - fd[inds_fd]) * (f0 - fd[inds_fd]) + lower

    def _snr_per_unit_amplitude(self, f0, noise_kwargs):
        """Return the factor ``rho / amp`` at ``f0``.

        Falls back to the noise keyword arguments stored at construction
        when ``noise_kwargs`` is empty.
        """
        if not noise_kwargs:
            noise_kwargs = self.noise_kwargs

        Sn_f = self.get_Sn_f(f0, **noise_kwargs)

        # NumPy ufuncs are used deliberately: they dispatch on the array type
        # and therefore also accept CuPy arrays.
        return 1. / 2. * np.sqrt((self.Tobs * np.sin(f0 / self.f_star) ** 2) / Sn_f)

    def __call__(self, rho, f0, **noise_kwargs):
        """Map SNR ``rho`` at frequency ``f0`` to ``(amp, f0)``."""
        amp = rho / self._snr_per_unit_amplitude(f0, noise_kwargs)
        return (amp, f0)

    def get_Sn_f(self, f0, psds=None, walker_inds=None, Sn_f=None, **noise_kwargs):
        """Return the noise level at ``f0``.

        Precedence: an explicit ``Sn_f`` is returned as is; otherwise
        ``psds`` is interpolated with :meth:`interp_psd`; otherwise
        ``get_sensitivity(f0, **noise_kwargs)`` is evaluated.

        Raises:
            AssertionError: If ``Sn_f`` is given but differs from ``f0`` in
                length or array type.
        """
        if Sn_f is not None:
            assert len(f0) == len(Sn_f)
            assert isinstance(f0, type(Sn_f))
            return Sn_f

        if psds is not None:
            return self.interp_psd(f0, psds, walker_inds=walker_inds)

        return get_sensitivity(f0, **noise_kwargs)

    def forward(self, amp, f0, **noise_kwargs):
        """Map amplitude ``amp`` at frequency ``f0`` to ``(rho, f0)``."""
        rho = amp * self._snr_per_unit_amplitude(f0, noise_kwargs)
        return (rho, f0)
244
+
245
+
246
class GBPriorWrap:
    """Wrap a prior container so that amplitude and frequency are handled jointly.

    Columns ``0`` and ``1`` of a parameter array are treated as amplitude
    and frequency and are scored/drawn through
    ``full_prior_container.priors_in[(0, 1)]``; the columns listed in
    ``keys_sep`` are delegated to the container itself.

    Args:
        ndim: Number of parameters per sample.
        full_prior_container: Object exposing ``use_cupy``, ``priors_in``,
            ``logpdf(x, keys=...)`` and ``rvs(size, keys=...)``.
        gen_frequency_alone: If ``True`` the frequency (key ``1``) is also
            delegated to the container and its draw is handed to the joint
            prior as ``f0_input``.
    """

    def __init__(self, ndim, full_prior_container, gen_frequency_alone=False):
        self.base_prior = full_prior_container
        self.use_cupy = full_prior_container.use_cupy
        self.ndim = ndim
        self.gen_frequency_alone = gen_frequency_alone

        # Keys handled by the container: 1..7 when the frequency is generated
        # on its own, 2..7 otherwise.
        first_key = 1 if gen_frequency_alone else 2
        self.keys_sep = list(range(first_key, 8))

    @property
    def priors_in(self):
        """The ``priors_in`` mapping of the wrapped container."""
        return self.base_prior.priors_in

    def logpdf(self, x, **noise_kwargs):
        """Return the log-prior of every row of ``x``.

        Args:
            x: 2-D array of shape ``(n, ndim)``.
            **noise_kwargs: Forwarded to the joint amplitude/frequency prior.

        Returns:
            Joint amplitude/frequency log-density plus the container's
            log-density over ``keys_sep``.
        """
        xp = cp if self.use_cupy else np
        assert x.shape[1] == self.ndim and x.ndim == 2

        separable_part = self.base_prior.logpdf(x, keys=self.keys_sep)

        frequency = xp.asarray(x[:, 1])
        amplitude = xp.asarray(x[:, 0])
        joint_prior = self.base_prior.priors_in[(0, 1)]
        joint_part = joint_prior.logpdf(amplitude, frequency, **noise_kwargs)

        return joint_part + separable_part

    def rvs(self, size=1, ignore_amp=False, **kwargs):
        """Draw samples of shape ``size + (ndim,)``.

        Args:
            size: Number of draws (``int``) or shape tuple.
            ignore_amp: If ``True`` skip the joint amplitude/frequency draw
                and return only what the container generated.
            **kwargs: Forwarded to the joint prior's ``rvs``.
        """
        xp = cp if self.use_cupy else np
        if isinstance(size, int):
            size = (size,)

        full_shape = size + (self.ndim,)
        flat = xp.zeros(full_shape).reshape(-1, self.ndim)

        # Number of leading columns produced by the joint prior.
        n_joint = self.ndim - len(self.keys_sep)
        assert n_joint >= 0

        flat[:, :] = self.base_prior.rvs(size, keys=self.keys_sep).reshape(-1, self.ndim)

        if not ignore_amp:
            f0_input = flat[:, 1] if self.gen_frequency_alone else None
            joint_prior = self.base_prior.priors_in[(0, 1)]
            joint_draw = joint_prior.rvs(size, f0_input=f0_input, **kwargs)
            # NOTE(review): the joint prior's draw is stacked and reshaped to
            # ``n_joint`` rows; with ``gen_frequency_alone=True`` the number
            # of rows may not match a two-element draw -- confirm that path.
            flat[:, :n_joint] = xp.asarray(joint_draw).reshape(n_joint, -1).T

        return flat.reshape(full_shape)
292
+
293
+
294
class FullGaussianMixtureModel:
    """Mixture of several Gaussian mixture models evaluated by a backend kernel.

    Each entry of ``weights``/``means``/``covs``/``invcovs``/``dets``
    describes one group of Gaussian components, and the matching entry of
    ``mins``/``maxs`` gives the per-dimension bounds used to map between the
    physical domain and ``[-1, 1]``. All groups are concatenated into one
    flat list of components; every group receives equal total weight.

    Args:
        gb: Backend object providing ``compute_logpdf``.
        weights: Sequence of 1-D arrays of component weights, one per group.
        means: Sequence of ``(n_comp, ndim)`` arrays.
        covs: Sequence of ``(n_comp, ndim, ndim)`` arrays.
        invcovs: Sequence of ``(n_comp, ndim, ndim)`` arrays.
        dets: Sequence of 1-D arrays, one value per component.
        mins: Sequence of per-group lower bounds (length ``ndim`` each).
        maxs: Sequence of per-group upper bounds (length ``ndim`` each).
        limit: Half-width, in mapped units along dimension 1, of the window
            outside which a component is not evaluated for a point.
        use_cupy: If ``True`` arrays are stored as CuPy arrays.
    """

    def __init__(self, gb, weights, means, covs, invcovs, dets, mins, maxs, limit=10.0, use_cupy=False):
        self.use_cupy = use_cupy
        xp = cp if use_cupy else np

        self.gb = gb

        # indexing[k] is the group that flattened component k belongs to.
        group_of_component = [
            np.full_like(weight, i, dtype=int) for i, weight in enumerate(weights)
        ]
        self.indexing = xp.asarray(np.concatenate(group_of_component))

        # Individual weights divided by the number of groups, so that groups
        # are chosen uniformly.
        self.weights = xp.asarray(np.concatenate(weights, axis=0) * 1 / len(weights))
        assert xp.allclose(self.weights.sum(), 1.0)

        self.means = xp.asarray(np.concatenate(means, axis=0))
        self.covs = xp.asarray(np.concatenate(covs, axis=0))
        self.invcovs = xp.asarray(np.concatenate(invcovs, axis=0))
        self.dets = xp.asarray(np.concatenate(dets, axis=0))
        self.ndim = self.means.shape[1]

        self.mins = xp.asarray(np.vstack(mins))
        self.maxs = xp.asarray(np.vstack(maxs))

        # Flattened, transposed copies in the layout handed to the backend.
        self.mins_in_pdf = self.mins[self.indexing].T.flatten().copy()
        self.maxs_in_pdf = self.maxs[self.indexing].T.flatten().copy()
        self.means_in_pdf = self.means.T.flatten().copy()
        self.invcovs_in_pdf = self.invcovs.transpose(1, 2, 0).flatten().copy()

        self.cumulative_weights = xp.concatenate([xp.array([0.0]), xp.cumsum(self.weights)])

        # Per-component window along dimension 1, in the physical domain.
        self.min_limit_f = self.map_back_frequency(-1. * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1])
        self.max_limit_f = self.map_back_frequency(+1. * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1])

        # Log-Jacobian of the map onto [-1, 1], one value per component.
        self.log_det_J = (self.ndim * np.log(2) - xp.sum(xp.log(self.maxs - self.mins), axis=-1))[self.indexing].copy()

    def logpdf(self, x):
        """Evaluate the mixture log-density at every row of ``x``.

        Only (point, component) pairs whose dimension-1 value lies inside
        the component's ``[min_limit_f, max_limit_f]`` window are passed to
        ``gb.compute_logpdf``. Points inside no window get ``-inf``.

        Args:
            x: 2-D array of shape ``(n_points, ndim)``.

        Returns:
            1-D array of ``n_points`` log-densities, in the order of ``x``.

        Raises:
            AssertionError: If ``x`` is not 2-D with ``ndim`` columns.
        """
        xp = cp if self.use_cupy else np

        assert len(x.shape) == 2
        assert x.shape[1] == self.ndim

        # Sort by dimension 1 so each component's window is one index range.
        inds_sort = xp.argsort(x[:, 1])
        f_sort = x[:, 1][inds_sort]
        points_sorted = x[inds_sort]

        ind_min_limit = xp.searchsorted(f_sort, self.min_limit_f, side="left")
        ind_max_limit = xp.searchsorted(f_sort, self.max_limit_f, side="right")

        # Enumerate every (component, sorted point) pair inside a window.
        n_in_window = ind_max_limit - ind_min_limit
        cs = xp.concatenate([xp.array([0]), xp.cumsum(n_in_window)])
        n_pairs = int(cs[-1])

        logpdf_out = xp.full(x.shape[0], -xp.inf)
        if n_pairs == 0:
            # No point falls inside any window: nothing to evaluate.
            return logpdf_out

        pair_ids = xp.arange(n_pairs)
        keep_component_map = xp.searchsorted(cs, pair_ids, side="right") - 1
        keep_point_map = pair_ids - cs[keep_component_map] + ind_min_limit[keep_component_map]

        # Re-order the pairs by point, then component, by sorting one
        # combined integer key; the key is decoded with exact integer
        # arithmetic.
        int_check = int(1e6)
        assert int_check > self.min_limit_f.shape[0]
        sorted_special = xp.sort(int_check * keep_point_map + keep_component_map)
        points_keep_in = sorted_special // int_check
        components_keep_in = sorted_special % int_check

        unique_points, unique_starts = xp.unique(points_keep_in, return_index=True)
        start_index_in_pdf = xp.concatenate([unique_starts, xp.array([len(points_keep_in)])]).astype(xp.int32)
        assert xp.all(xp.diff(unique_starts) > 0)

        points_sorted_in = points_sorted[unique_points]
        logpdf_out_tmp = xp.zeros(points_sorted_in.shape[0])

        # The backend writes its result into ``logpdf_out_tmp`` in place.
        self.gb.compute_logpdf(
            logpdf_out_tmp, components_keep_in.astype(xp.int32), points_sorted_in,
            self.weights, self.mins_in_pdf, self.maxs_in_pdf, self.means_in_pdf,
            self.invcovs_in_pdf, self.dets, self.log_det_J,
            points_sorted_in.shape[0], start_index_in_pdf, self.weights.shape[0], x.shape[1],
        )

        # Undo the sort: scatter results back to the original row positions.
        logpdf_out[inds_sort[unique_points]] = logpdf_out_tmp
        return logpdf_out

    def map_input(self, x, mins, maxs):
        """Map ``x`` from ``[mins, maxs]`` onto ``[-1, 1]``."""
        return ((x - mins) / (maxs - mins)) * 2. - 1.

    def map_back_frequency(self, x, mins, maxs):
        """Map ``x`` from ``[-1, 1]`` back onto ``[mins, maxs]``."""
        return (x + 1.) * 1. / 2. * (maxs - mins) + mins

    def rvs(self, size=(1,)):
        """Draw samples of shape ``size + (ndim,)`` in the physical domain.

        A component is chosen according to ``weights``, a standard-normal
        vector is transformed as ``mean + covs[component] @ z``, and the
        result is mapped back from ``[-1, 1]`` using the bounds of the
        component's group.

        Args:
            size: Number of draws (``int``) or shape tuple.
        """
        if isinstance(size, int):
            size = (size,)

        xp = cp if self.use_cupy else np

        # Choose a component for every draw.
        draw = xp.random.rand(*size)
        component = (xp.searchsorted(self.cumulative_weights, draw.flatten(), side="right") - 1).reshape(draw.shape)

        mean_here = self.means[component]
        cov_here = self.covs[component]

        # Draw the deviates on the active backend so they can be combined
        # with ``cov_here`` when CuPy is in use.
        # NOTE(review): ``covs`` is applied directly as the linear map, which
        # yields samples with covariance ``cov @ cov.T`` -- confirm whether
        # ``covs`` is expected to hold a matrix square root.
        normal_draw = xp.random.randn(*(component.shape + (self.ndim,)))
        new_points = mean_here + xp.einsum("...kj,...j->...k", cov_here, normal_draw)

        index_here = self.indexing[component]
        mins_here = self.mins[index_here]
        maxs_here = self.maxs[index_here]
        return self.map_back_frequency(new_points, mins_here, maxs_here)
442
+
443
+
444
+ # class FlowDist:
445
+ # def __init__(self, config: dict, model: Union[Galaxy, GalaxyFFdot], fit: str, ndim: int):
446
+
447
+ # self.dist = model(config)
448
+ # self.dist.load_fit()
449
+
450
+ # param_min, param_max = np.loadtxt(fit)
451
+ # self.dist.set_min(param_min)
452
+ # self.dist.set_max(param_max)
453
+
454
+ # self.config = config
455
+ # self.fit = fit
456
+ # self.ndim = ndim
457
+
458
+ # def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
459
+ # if isinstance(size, int):
460
+ # size = (size,)
461
+
462
+ # total_samp = int(np.prod(size))
463
+ # samples = self.dist.sample(total_samp).reshape(size + (self.ndim,))
464
+ # return samples
465
+
466
+ # def logpdf(self, x: cp.ndarray) -> cp.ndarray:
467
+ # assert x.shape[-1] == self.ndim
468
+ # log_prob = self.dist.log_prob(x.reshape(-1, self.ndim)).reshape(x.shape[:-1])
469
+ # return log_prob
470
+
471
+ # class GalaxyFlowDist(FlowDist):
472
+ # def __init__(self):
473
+ # config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_galaxy.yaml'
474
+ # model = Galaxy
475
+ # fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_galaxy_sangria.txt'
476
+ # ndim = 3
477
+ # super().__init__(config, model, fit, ndim)
478
+
479
+ # def logpdf(self, x: cp.ndarray) -> cp.ndarray:
480
+ # # adjust amplitudes to exp
481
+ # x[:, 0] = np.log(x[:, 0])
482
+ # return super().logpdf(x)
483
+
484
+ # def rvs(self, size: Optional[Union[int, tuple]]=(1,)) -> cp.ndarray:
485
+ # if isinstance(size, int):
486
+ # size = (size,)
487
+
488
+ # samples = super().rvs(size=size)
489
+ # samples = samples.reshape(-1, samples.shape[-1])
490
+ # samples[:, 0] = np.exp(samples[:, 0])
491
+ # samples = samples.reshape(size + (samples.shape[-1],))
492
+ # return samples
493
+
494
+ # class FFdotFlowDist(FlowDist):
495
+ # def __init__(self):
496
+ # config = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/configs/gbs/density_f.yaml'
497
+ # model = GalaxyFFdot
498
+ # fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_ffdot_sangria.txt'
499
+ # ndim = 2
500
+ # super().__init__(config, model, fit, ndim)
501
+
502
+
503
+
504
+
505
+
506
+
507
+
508
+