lisaanalysistools 1.0.9__cp312-cp312-macosx_10_9_x86_64.whl → 1.0.11__cp312-cp312-macosx_10_9_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lisaanalysistools might be problematic. Click here for more details.
- {lisaanalysistools-1.0.9.dist-info → lisaanalysistools-1.0.11.dist-info}/METADATA +2 -2
- {lisaanalysistools-1.0.9.dist-info → lisaanalysistools-1.0.11.dist-info}/RECORD +13 -13
- {lisaanalysistools-1.0.9.dist-info → lisaanalysistools-1.0.11.dist-info}/WHEEL +1 -1
- lisatools/_version.py +1 -1
- lisatools/analysiscontainer.py +34 -2
- lisatools/cutils/detector_cpu.cpython-312-darwin.so +0 -0
- lisatools/detector.py +3 -2
- lisatools/sampling/likelihood.py +1 -1
- lisatools/sampling/utility.py +153 -66
- lisatools/sensitivity.py +4 -4
- lisatools/sources/utils.py +18 -2
- {lisaanalysistools-1.0.9.dist-info → lisaanalysistools-1.0.11.dist-info}/LICENSE +0 -0
- {lisaanalysistools-1.0.9.dist-info → lisaanalysistools-1.0.11.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: lisaanalysistools
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.11
|
|
4
4
|
Home-page: https://github.com/mikekatz04/LISAanalysistools
|
|
5
5
|
Author: Michael Katz
|
|
6
6
|
Author-email: mikekatz04@gmail.com
|
|
@@ -66,7 +66,7 @@ Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduc
|
|
|
66
66
|
|
|
67
67
|
We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/mikekatz04/LISAanalysistools/tags).
|
|
68
68
|
|
|
69
|
-
Current Version: 1.0.
|
|
69
|
+
Current Version: 1.0.11
|
|
70
70
|
|
|
71
71
|
## Authors/Developers
|
|
72
72
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
lisatools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
lisatools/_version.py,sha256=
|
|
3
|
-
lisatools/analysiscontainer.py,sha256=
|
|
2
|
+
lisatools/_version.py,sha256=kORDw-aHXqe534RIAm9NvG4SSRug8egjhH2hTnHhut0,124
|
|
3
|
+
lisatools/analysiscontainer.py,sha256=UQft6SvyDueDtuL1H1auiEzG5IEsFwRiiXf_0DHtu5U,16289
|
|
4
4
|
lisatools/datacontainer.py,sha256=QVz0twD46Fl_J-ueGQUCXOJkkXUEum7yrwp9LaVeohU,9853
|
|
5
|
-
lisatools/detector.py,sha256=
|
|
5
|
+
lisatools/detector.py,sha256=1ATZToFoy14U3ACTyver2x9DKavJ3l6ToU6nUSGSTmo,21161
|
|
6
6
|
lisatools/diagnostic.py,sha256=o9vtfXbY3yMDk4cGNeOsTrlP7gYG_BRROw7gyuThDaM,34192
|
|
7
|
-
lisatools/sensitivity.py,sha256=
|
|
7
|
+
lisatools/sensitivity.py,sha256=3BEVnRjFsQ7HZX6PRjd6B7K2FyIJRL3lpQpGaHjNgMA,27393
|
|
8
8
|
lisatools/stochastic.py,sha256=CysAQ5BIVD349fLjAMpZAXaT0X0dFjrd4QFi0vfqops,9458
|
|
9
9
|
lisatools/cutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
-
lisatools/cutils/detector_cpu.cpython-312-darwin.so,sha256=
|
|
10
|
+
lisatools/cutils/detector_cpu.cpython-312-darwin.so,sha256=xhTNwTCXaSHLmHFOWOVAl3sP4tDy2o7TMExj_G5MORE,154720
|
|
11
11
|
lisatools/cutils/include/Detector.hpp,sha256=Ic37OgP-gvEg8qouhv9aFYf7vQP98tJ_i6PGu8RI1YQ,2050
|
|
12
12
|
lisatools/cutils/include/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
lisatools/cutils/include/global.hpp,sha256=3VPiqglTMRrIBXlEvDUJO2-CjKy_SLUXZt-9A1fH6lQ,572
|
|
@@ -16,15 +16,15 @@ lisatools/cutils/src/Detector.cu,sha256=JvAK8UCoYxgTACBhx3hi0EK9hT-HRYOxX2XozHuH
|
|
|
16
16
|
lisatools/cutils/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
17
|
lisatools/cutils/src/pycppdetector.pyx,sha256=e0tz79VNS7Sxxwyx9RQwsFVdpuClBsUf7OArSkyRBa0,7613
|
|
18
18
|
lisatools/sampling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
|
-
lisatools/sampling/likelihood.py,sha256=
|
|
19
|
+
lisatools/sampling/likelihood.py,sha256=A9a1McgjqJre8c3Bzoyxakl3onjxaE6c7zVwY2-12iI,29585
|
|
20
20
|
lisatools/sampling/prior.py,sha256=1K1PMStpwO9WT0qG0aotKSyoNjuehXNbzTDtlk8Q15M,21407
|
|
21
21
|
lisatools/sampling/stopping.py,sha256=Q8q7nM0wnJervhRduf2tBXZZHUVza5kJiAUAMUQXP5o,9682
|
|
22
|
-
lisatools/sampling/utility.py,sha256=
|
|
22
|
+
lisatools/sampling/utility.py,sha256=R5LKNK0A36L1PzLjCENQjDWcDjsR1oUBQGrR9PqasSA,15482
|
|
23
23
|
lisatools/sampling/moves/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
24
|
lisatools/sampling/moves/skymodehop.py,sha256=0nf72eFhFMGwi0dLJci6XZz-bIMGqco2B2_J72hQvf8,3348
|
|
25
25
|
lisatools/sources/__init__.py,sha256=Fm085xHQ3VpRjqaSlws0bdVefFofAxNzVZyGQQYrQic,140
|
|
26
26
|
lisatools/sources/defaultresponse.py,sha256=2SMbRf-UwgLNaA89tIANjw4BCKh1XEuRg5mgkvAg3-k,839
|
|
27
|
-
lisatools/sources/utils.py,sha256=
|
|
27
|
+
lisatools/sources/utils.py,sha256=5blfG4ozaN1QgOZR49o1vDH6E9_x0g1dNGfsRiTrVeo,14342
|
|
28
28
|
lisatools/sources/waveformbase.py,sha256=JPLqLZd1e-6E3ySyXodO83nZGH8bVq2K_s8sF2Oy84w,845
|
|
29
29
|
lisatools/sources/bbh/__init__.py,sha256=M3yP4eaScZZMhOucn88iiJ2WGL6zZTQI0xUEnmw4Nu8,37
|
|
30
30
|
lisatools/sources/bbh/waveform.py,sha256=-xg2uYB8AWf233W34hxwLsnlSYPF9d7TqzNfXj95zaA,2599
|
|
@@ -36,8 +36,8 @@ lisatools/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
36
36
|
lisatools/utils/constants.py,sha256=r1kVwkpbZS13JTOxj2iRxT5sMgTYX30y-S0JdVmD5Oo,1354
|
|
37
37
|
lisatools/utils/pointeradjust.py,sha256=2sT-7qeYWr1pd_sHk9leVHUTSJ7jJgYIRoWQOtYqguE,2995
|
|
38
38
|
lisatools/utils/utility.py,sha256=MKqRsG8qRI1xCsj51mt6QRa80-deUUgedxibyMV3KD4,6776
|
|
39
|
-
lisaanalysistools-1.0.
|
|
40
|
-
lisaanalysistools-1.0.
|
|
41
|
-
lisaanalysistools-1.0.
|
|
42
|
-
lisaanalysistools-1.0.
|
|
43
|
-
lisaanalysistools-1.0.
|
|
39
|
+
lisaanalysistools-1.0.11.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
40
|
+
lisaanalysistools-1.0.11.dist-info/METADATA,sha256=xjT_alb_5mDWKzhKyK9KldcQjlndzlQ8-eyrZ9Fp8fk,4202
|
|
41
|
+
lisaanalysistools-1.0.11.dist-info/WHEEL,sha256=zL8OKpwMaCM_hEdCwF-bHbg5JyuZtrW4vMYNMNxkj9k,110
|
|
42
|
+
lisaanalysistools-1.0.11.dist-info/top_level.txt,sha256=qXUn8Xi8yvK0vr3QH0vvT5cmoccjSj-UPeKGLAxdYGo,10
|
|
43
|
+
lisaanalysistools-1.0.11.dist-info/RECORD,,
|
lisatools/_version.py
CHANGED
lisatools/analysiscontainer.py
CHANGED
|
@@ -11,6 +11,8 @@ import numpy as np
|
|
|
11
11
|
from scipy import interpolate
|
|
12
12
|
import matplotlib.pyplot as plt
|
|
13
13
|
|
|
14
|
+
from eryn.utils import TransformContainer
|
|
15
|
+
|
|
14
16
|
|
|
15
17
|
try:
|
|
16
18
|
import cupy as cp
|
|
@@ -286,6 +288,7 @@ class AnalysisContainer:
|
|
|
286
288
|
source_only: bool = False,
|
|
287
289
|
waveform_kwargs: Optional[dict] = {},
|
|
288
290
|
data_res_arr_kwargs: Optional[dict] = {},
|
|
291
|
+
transform_fn: Optional[TransformContainer] = None,
|
|
289
292
|
**kwargs: dict,
|
|
290
293
|
) -> float | complex:
|
|
291
294
|
"""Return the likelihood of a generated signal with the data.
|
|
@@ -308,8 +311,14 @@ class AnalysisContainer:
|
|
|
308
311
|
if data_res_arr_kwargs == {}:
|
|
309
312
|
data_res_arr_kwargs = self.data_res_arr.init_kwargs
|
|
310
313
|
|
|
314
|
+
if transform_fn is not None:
|
|
315
|
+
args_tmp = np.asarray(args)
|
|
316
|
+
args_in = tuple(transform_fn.both_transforms(args_tmp))
|
|
317
|
+
else:
|
|
318
|
+
args_in = args
|
|
319
|
+
|
|
311
320
|
template = DataResidualArray(
|
|
312
|
-
self.signal_gen(*
|
|
321
|
+
self.signal_gen(*args_in, **waveform_kwargs), **data_res_arr_kwargs
|
|
313
322
|
)
|
|
314
323
|
|
|
315
324
|
args_2 = (template,)
|
|
@@ -426,7 +435,30 @@ class AnalysisContainer:
|
|
|
426
435
|
**kwargs,
|
|
427
436
|
)
|
|
428
437
|
|
|
429
|
-
def eryn_likelihood_function(
|
|
438
|
+
def eryn_likelihood_function(
|
|
439
|
+
self, x: np.ndarray | list | tuple, *args: Any, **kwargs: Any
|
|
440
|
+
) -> np.ndarray | float:
|
|
441
|
+
"""Likelihood function for Eryn sampler.
|
|
442
|
+
|
|
443
|
+
This function is not vectorized.
|
|
444
|
+
|
|
445
|
+
``signal_gen`` must be set to use this function.
|
|
446
|
+
|
|
447
|
+
Args:
|
|
448
|
+
x: Parameters. Can be 1D list, tuple, array or 2D array.
|
|
449
|
+
If a 2D array is input, the computation is done serially.
|
|
450
|
+
*args: Likelihood args.
|
|
451
|
+
**kwargs: Likelihood kwargs.
|
|
452
|
+
|
|
453
|
+
Returns:
|
|
454
|
+
Likelihood value(s).
|
|
455
|
+
|
|
456
|
+
"""
|
|
457
|
+
assert self.signal_gen is not None
|
|
458
|
+
|
|
459
|
+
if isinstance(x, list) or isinstance(x, tuple):
|
|
460
|
+
x = np.asarray(x)
|
|
461
|
+
|
|
430
462
|
if x.ndim == 1:
|
|
431
463
|
input_vals = tuple(x) + tuple(args)
|
|
432
464
|
return self.calculate_signal_likelihood(*input_vals, **kwargs)
|
|
Binary file
|
lisatools/detector.py
CHANGED
|
@@ -9,6 +9,7 @@ import h5py
|
|
|
9
9
|
from scipy import interpolate
|
|
10
10
|
|
|
11
11
|
from .utils.constants import *
|
|
12
|
+
from .utils.utility import get_array_module
|
|
12
13
|
|
|
13
14
|
import numpy as np
|
|
14
15
|
|
|
@@ -453,9 +454,10 @@ class Orbits(ABC):
|
|
|
453
454
|
squeeze = False
|
|
454
455
|
t = self.xp.asarray(t)
|
|
455
456
|
sc = self.xp.asarray(sc).astype(np.int32)
|
|
457
|
+
|
|
456
458
|
else:
|
|
457
459
|
raise ValueError(
|
|
458
|
-
"(t, sc) can be (float, int), (np.ndarray, int), (np.ndarray, np.ndarray)."
|
|
460
|
+
"(t, sc) can be (float, int), (np.ndarray, int), (np.ndarray, np.ndarray). If the inputs follow this, make sure the orbits class GPU setting matches the arrays coming in (GPU or CPU)."
|
|
459
461
|
)
|
|
460
462
|
|
|
461
463
|
# buffer arrays for input into c code
|
|
@@ -563,7 +565,6 @@ class ESAOrbits(Orbits):
|
|
|
563
565
|
"""
|
|
564
566
|
|
|
565
567
|
def __init__(self, *args, **kwargs):
|
|
566
|
-
# TODO: fix this up
|
|
567
568
|
super().__init__("esa-trailing-orbits.h5", *args, **kwargs)
|
|
568
569
|
|
|
569
570
|
|
lisatools/sampling/likelihood.py
CHANGED
lisatools/sampling/utility.py
CHANGED
|
@@ -9,7 +9,7 @@ try:
|
|
|
9
9
|
except (ImportError, ModuleNotFoundError) as e:
|
|
10
10
|
pass
|
|
11
11
|
|
|
12
|
-
from eryn.state import State,
|
|
12
|
+
from eryn.state import State, BranchSupplemental
|
|
13
13
|
from eryn.utils.utility import groups_from_inds
|
|
14
14
|
|
|
15
15
|
|
|
@@ -19,8 +19,21 @@ class DetermineGBGroups:
|
|
|
19
19
|
self.xp = self.gb_wave_generator.xp
|
|
20
20
|
self.transform_fn = transform_fn
|
|
21
21
|
self.waveform_kwargs = waveform_kwargs
|
|
22
|
-
|
|
23
|
-
def __call__(
|
|
22
|
+
|
|
23
|
+
def __call__(
|
|
24
|
+
self,
|
|
25
|
+
last_sample,
|
|
26
|
+
name_here,
|
|
27
|
+
check_temp=0,
|
|
28
|
+
input_groups=None,
|
|
29
|
+
input_groups_inds=None,
|
|
30
|
+
fix_group_count=False,
|
|
31
|
+
mismatch_lim=0.2,
|
|
32
|
+
double_check_lim=0.2,
|
|
33
|
+
start_term="random",
|
|
34
|
+
waveform_kwargs={},
|
|
35
|
+
index_within_group="random",
|
|
36
|
+
):
|
|
24
37
|
# TODO: mess with mismatch lim setting
|
|
25
38
|
# TODO: some time of mismatch annealing may be useful
|
|
26
39
|
if isinstance(last_sample, State):
|
|
@@ -32,13 +45,13 @@ class DetermineGBGroups:
|
|
|
32
45
|
inds = last_sample[name_here][check_temp]["inds"]
|
|
33
46
|
|
|
34
47
|
waveform_kwargs = {**self.waveform_kwargs, **waveform_kwargs}
|
|
35
|
-
|
|
48
|
+
|
|
36
49
|
# get coordinates and inds of the temperature you are considering.
|
|
37
|
-
|
|
50
|
+
|
|
38
51
|
nwalkers, nleaves_max, ndim = coords.shape
|
|
39
52
|
if input_groups is None:
|
|
40
53
|
|
|
41
|
-
# figure our which walker to start with
|
|
54
|
+
# figure our which walker to start with
|
|
42
55
|
if start_term == "max":
|
|
43
56
|
start_walker_ind = inds[check_temp].sum(axis=-1).argmax()
|
|
44
57
|
elif start_term == "first":
|
|
@@ -52,7 +65,7 @@ class DetermineGBGroups:
|
|
|
52
65
|
inds_good = np.where(inds[start_walker_ind])[0]
|
|
53
66
|
groups = []
|
|
54
67
|
groups_inds = []
|
|
55
|
-
|
|
68
|
+
|
|
56
69
|
# set up this information to load the information into the group lists
|
|
57
70
|
for leaf_i, leaf in enumerate(inds_good):
|
|
58
71
|
groups.append([])
|
|
@@ -72,7 +85,7 @@ class DetermineGBGroups:
|
|
|
72
85
|
if input_groups is None and w == start_walker_ind:
|
|
73
86
|
continue
|
|
74
87
|
|
|
75
|
-
# walker has no binaries
|
|
88
|
+
# walker has no binaries
|
|
76
89
|
if not np.any(inds[w]):
|
|
77
90
|
continue
|
|
78
91
|
|
|
@@ -81,7 +94,6 @@ class DetermineGBGroups:
|
|
|
81
94
|
inds_for_group_stuff = np.arange(len(inds[w]))[inds[w]]
|
|
82
95
|
nleaves, ndim = coords_here.shape
|
|
83
96
|
|
|
84
|
-
|
|
85
97
|
params_for_test = []
|
|
86
98
|
for group in groups:
|
|
87
99
|
group_params = np.asarray(group)
|
|
@@ -95,39 +107,49 @@ class DetermineGBGroups:
|
|
|
95
107
|
|
|
96
108
|
params_for_test.append(group_params[test_walker_ind])
|
|
97
109
|
params_for_test = np.asarray(params_for_test)
|
|
98
|
-
|
|
110
|
+
|
|
99
111
|
# transform coords
|
|
100
112
|
if self.transform_fn is not None:
|
|
101
|
-
params_for_test_in = self.transform_fn[name_here].both_transforms(
|
|
102
|
-
|
|
103
|
-
|
|
113
|
+
params_for_test_in = self.transform_fn[name_here].both_transforms(
|
|
114
|
+
params_for_test, return_transpose=False
|
|
115
|
+
)
|
|
116
|
+
coords_here_in = self.transform_fn[name_here].both_transforms(
|
|
117
|
+
coords_here, return_transpose=False
|
|
118
|
+
)
|
|
119
|
+
|
|
104
120
|
else:
|
|
105
121
|
params_for_test_in = params_for_test.copy()
|
|
106
122
|
coords_here_in = coords_here.copy()
|
|
107
123
|
|
|
108
124
|
inds_tmp_test = np.arange(len(params_for_test_in))
|
|
109
125
|
inds_tmp_here = np.arange(len(coords_here_in))
|
|
110
|
-
inds_tmp_test, inds_tmp_here = [
|
|
126
|
+
inds_tmp_test, inds_tmp_here = [
|
|
127
|
+
tmp.ravel() for tmp in np.meshgrid(inds_tmp_test, inds_tmp_here)
|
|
128
|
+
]
|
|
111
129
|
|
|
112
130
|
params_for_test_in_full = params_for_test_in[inds_tmp_test]
|
|
113
131
|
coords_here_in_full = coords_here_in[inds_tmp_here]
|
|
114
132
|
# build the waveforms at the same time
|
|
115
133
|
|
|
116
|
-
df = 1. / waveform_kwargs["T"]
|
|
117
|
-
max_f = 1. / 2 * 1/waveform_kwargs["dt"]
|
|
134
|
+
df = 1.0 / waveform_kwargs["T"]
|
|
135
|
+
max_f = 1.0 / 2 * 1 / waveform_kwargs["dt"]
|
|
118
136
|
frqs = self.xp.arange(0.0, max_f, df)
|
|
119
|
-
data_minus_template = self.xp.asarray(
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
137
|
+
data_minus_template = self.xp.asarray(
|
|
138
|
+
[
|
|
139
|
+
self.xp.ones_like(frqs, dtype=complex),
|
|
140
|
+
self.xp.ones_like(frqs, dtype=complex),
|
|
141
|
+
]
|
|
142
|
+
)[None, :, :]
|
|
143
|
+
psd = self.xp.asarray(
|
|
144
|
+
[
|
|
145
|
+
self.xp.ones_like(frqs, dtype=np.float64),
|
|
146
|
+
self.xp.ones_like(frqs, dtype=np.float64),
|
|
147
|
+
]
|
|
148
|
+
)
|
|
127
149
|
|
|
128
150
|
waveform_kwargs_fill = waveform_kwargs.copy()
|
|
129
151
|
waveform_kwargs_fill.pop("start_freq_ind")
|
|
130
|
-
|
|
152
|
+
|
|
131
153
|
# TODO: could use real data and get observed snr for each if needed
|
|
132
154
|
check = self.gb_wave_generator.swap_likelihood_difference(
|
|
133
155
|
params_for_test_in_full,
|
|
@@ -147,26 +169,36 @@ class DetermineGBGroups:
|
|
|
147
169
|
normalized_autocorr = numerator / np.sqrt(norm_here * norm_for_test)
|
|
148
170
|
normalized_against_test = numerator / norm_for_test
|
|
149
171
|
|
|
150
|
-
normalized_autocorr = normalized_autocorr.reshape(
|
|
151
|
-
|
|
152
|
-
|
|
172
|
+
normalized_autocorr = normalized_autocorr.reshape(
|
|
173
|
+
coords_here_in.shape[0], params_for_test_in.shape[0]
|
|
174
|
+
).real
|
|
175
|
+
normalized_against_test = normalized_against_test.reshape(
|
|
176
|
+
coords_here_in.shape[0], params_for_test_in.shape[0]
|
|
177
|
+
).real
|
|
178
|
+
|
|
153
179
|
# TODO: do based on Likelihood? make sure on same posterior
|
|
154
180
|
# TODO: add check based on amplitude
|
|
155
|
-
test1 = np.abs(
|
|
181
|
+
test1 = np.abs(
|
|
182
|
+
1.0 - normalized_autocorr.real
|
|
183
|
+
) # (numerator / norm_for_test[None, :]).real)
|
|
156
184
|
best = test1.argmin(axis=1)
|
|
157
185
|
try:
|
|
158
186
|
best = best.get()
|
|
159
187
|
except AttributeError:
|
|
160
188
|
pass
|
|
161
189
|
best_mismatch = test1[(np.arange(test1.shape[0]), best)]
|
|
162
|
-
check_normalized_against_test = np.abs(
|
|
163
|
-
|
|
190
|
+
check_normalized_against_test = np.abs(
|
|
191
|
+
1.0 - normalized_against_test[(np.arange(test1.shape[0]), best)]
|
|
192
|
+
)
|
|
164
193
|
|
|
165
194
|
f0_here = coords_here[:, 1]
|
|
166
195
|
f0_test = params_for_test[best, 1]
|
|
167
196
|
|
|
168
197
|
for leaf in range(nleaves):
|
|
169
|
-
if
|
|
198
|
+
if (
|
|
199
|
+
best_mismatch[leaf] < mismatch_lim
|
|
200
|
+
and check_normalized_against_test[leaf] < double_check_lim
|
|
201
|
+
):
|
|
170
202
|
groups[best[leaf]].append(coords_here[leaf].copy())
|
|
171
203
|
groups_inds[best[leaf]].append([w, inds_for_group_stuff[leaf]])
|
|
172
204
|
|
|
@@ -187,41 +219,65 @@ class GetLastGBState:
|
|
|
187
219
|
self.transform_fn = transform_fn
|
|
188
220
|
self.waveform_kwargs = waveform_kwargs
|
|
189
221
|
|
|
190
|
-
def __call__(
|
|
222
|
+
def __call__(
|
|
223
|
+
self,
|
|
224
|
+
mgh,
|
|
225
|
+
reader,
|
|
226
|
+
df,
|
|
227
|
+
supps_base_shape,
|
|
228
|
+
fix_temp_initial_ind: int = None,
|
|
229
|
+
fix_temp_inds: list = None,
|
|
230
|
+
nleaves_max_in=None,
|
|
231
|
+
waveform_kwargs={},
|
|
232
|
+
):
|
|
191
233
|
|
|
192
234
|
xp.cuda.runtime.setDevice(mgh.gpus[0])
|
|
193
235
|
|
|
194
236
|
if fix_temp_initial_ind is not None or fix_temp_inds is not None:
|
|
195
237
|
if fix_temp_initial_ind is None or fix_temp_inds is None:
|
|
196
|
-
raise ValueError(
|
|
238
|
+
raise ValueError(
|
|
239
|
+
"If giving fix_temp_initial_ind or fix_temp_inds, must give both."
|
|
240
|
+
)
|
|
197
241
|
|
|
198
242
|
state = reader.get_last_sample()
|
|
199
243
|
|
|
200
244
|
waveform_kwargs = {**self.waveform_kwargs, **waveform_kwargs}
|
|
201
245
|
if "start_freq_ind" not in waveform_kwargs:
|
|
202
|
-
raise ValueError(
|
|
246
|
+
raise ValueError(
|
|
247
|
+
"In get_last_gb_state, waveform_kwargs must include 'start_freq_ind'."
|
|
248
|
+
)
|
|
203
249
|
|
|
204
|
-
#check = reader.get_last_sample()
|
|
250
|
+
# check = reader.get_last_sample()
|
|
205
251
|
ntemps, nwalkers, nleaves_max_old, ndim = state.branches["gb"].shape
|
|
206
|
-
|
|
207
|
-
#out = get_groups_for_remixing(check, check_temp=0, input_groups=None, input_groups_inds=None, fix_group_count=False, name_here="gb")
|
|
208
252
|
|
|
209
|
-
#
|
|
210
|
-
|
|
253
|
+
# out = get_groups_for_remixing(check, check_temp=0, input_groups=None, input_groups_inds=None, fix_group_count=False, name_here="gb")
|
|
254
|
+
|
|
255
|
+
# lengths = []
|
|
256
|
+
# for group in out[0]:
|
|
211
257
|
# lengths.append(len(group))
|
|
212
|
-
#breakpoint()
|
|
258
|
+
# breakpoint()
|
|
213
259
|
try:
|
|
214
|
-
if fix_temp_initial_ind is not None:
|
|
260
|
+
if fix_temp_initial_ind is not None:
|
|
215
261
|
for i in fix_temp_inds:
|
|
216
262
|
if i < fix_temp_initial_ind:
|
|
217
|
-
raise ValueError(
|
|
263
|
+
raise ValueError(
|
|
264
|
+
"If providing fix_temp_initial_ind and fix_temp_inds, all values in fix_temp_inds must be greater than fix_temp_initial_ind."
|
|
265
|
+
)
|
|
218
266
|
|
|
219
267
|
state.log_like[i] = state.log_like[fix_temp_initial_ind]
|
|
220
268
|
state.log_prior[i] = state.log_prior[fix_temp_initial_ind]
|
|
221
|
-
state.branches_coords["gb"][i] = state.branches_coords["gb"][
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
state.
|
|
269
|
+
state.branches_coords["gb"][i] = state.branches_coords["gb"][
|
|
270
|
+
fix_temp_initial_ind
|
|
271
|
+
]
|
|
272
|
+
state.branches_coords["gb"][i] = state.branches_coords["gb"][
|
|
273
|
+
fix_temp_initial_ind
|
|
274
|
+
]
|
|
275
|
+
state.branches_inds["gb"][i] = state.branches_inds["gb"][
|
|
276
|
+
fix_temp_initial_ind
|
|
277
|
+
]
|
|
278
|
+
state.branches_inds["gb"][i] = state.branches_inds["gb"][
|
|
279
|
+
fix_temp_initial_ind
|
|
280
|
+
]
|
|
225
281
|
|
|
226
282
|
ntemps, nwalkers, nleaves_max_old, ndim = state.branches["gb"].shape
|
|
227
283
|
if nleaves_max_in is None:
|
|
@@ -238,23 +294,34 @@ class GetLastGBState:
|
|
|
238
294
|
state.branches["gb"].inds = inds_tmp
|
|
239
295
|
state.branches["gb"].nleaves_max = nleaves_max
|
|
240
296
|
state.branches["gb"].shape = (ntemps, nwalkers, nleaves_max, ndim)
|
|
241
|
-
|
|
297
|
+
|
|
242
298
|
else:
|
|
243
299
|
raise ValueError("new nleaves_max is less than nleaves_max_old.")
|
|
244
300
|
|
|
245
301
|
# add "gb" if there are any
|
|
246
302
|
data_index_in = groups_from_inds({"gb": state.branches_inds["gb"]})["gb"]
|
|
247
303
|
|
|
248
|
-
data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(
|
|
304
|
+
data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(
|
|
305
|
+
xp.int32
|
|
306
|
+
)
|
|
307
|
+
|
|
308
|
+
params_add_in = self.transform_fn["gb"].both_transforms(
|
|
309
|
+
state.branches_coords["gb"][state.branches_inds["gb"]]
|
|
310
|
+
)
|
|
249
311
|
|
|
250
|
-
params_add_in = self.transform_fn["gb"].both_transforms(state.branches_coords["gb"][state.branches_inds["gb"]])
|
|
251
|
-
|
|
252
312
|
# batch_size is ignored if waveform_kwargs["use_c_implementation"] is True
|
|
253
|
-
# -1 is to do -(-d + h) = d - h
|
|
254
|
-
mgh.multiply_data(-1.)
|
|
255
|
-
self.gb_wave_generator.generate_global_template(
|
|
256
|
-
|
|
257
|
-
|
|
313
|
+
# -1 is to do -(-d + h) = d - h
|
|
314
|
+
mgh.multiply_data(-1.0)
|
|
315
|
+
self.gb_wave_generator.generate_global_template(
|
|
316
|
+
params_add_in,
|
|
317
|
+
data_index,
|
|
318
|
+
mgh.data_list,
|
|
319
|
+
data_length=mgh.data_length,
|
|
320
|
+
data_splits=mgh.gpu_splits,
|
|
321
|
+
batch_size=1000,
|
|
322
|
+
**waveform_kwargs,
|
|
323
|
+
)
|
|
324
|
+
mgh.multiply_data(-1.0)
|
|
258
325
|
|
|
259
326
|
except KeyError:
|
|
260
327
|
# no "gb"
|
|
@@ -263,22 +330,42 @@ class GetLastGBState:
|
|
|
263
330
|
data_index_in = groups_from_inds({"gb": state.branches_inds["gb"]})["gb"]
|
|
264
331
|
data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(xp.int32)
|
|
265
332
|
|
|
266
|
-
params_add_in = self.transform_fn["gb"].both_transforms(
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
mgh.multiply_data(-1.)
|
|
333
|
+
params_add_in = self.transform_fn["gb"].both_transforms(
|
|
334
|
+
state.branches_coords["gb"][state.branches_inds["gb"]]
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
# -1 is to do -(-d + h) = d - h
|
|
338
|
+
mgh.multiply_data(-1.0)
|
|
339
|
+
self.gb_wave_generator.generate_global_template(
|
|
340
|
+
params_add_in,
|
|
341
|
+
data_index,
|
|
342
|
+
mgh.data_list,
|
|
343
|
+
data_length=mgh.data_length,
|
|
344
|
+
data_splits=mgh.gpu_splits,
|
|
345
|
+
batch_size=1000,
|
|
346
|
+
**waveform_kwargs,
|
|
347
|
+
)
|
|
348
|
+
mgh.multiply_data(-1.0)
|
|
272
349
|
|
|
273
350
|
self.gb_wave_generator.d_d = np.asarray(mgh.get_inner_product(use_cpu=True))
|
|
274
|
-
|
|
275
|
-
state.log_like =
|
|
351
|
+
|
|
352
|
+
state.log_like = (
|
|
353
|
+
-1 / 2 * self.gb_wave_generator.d_d.real.reshape(ntemps, nwalkers)
|
|
354
|
+
)
|
|
276
355
|
|
|
277
356
|
temp_inds = mgh.temp_indices.copy()
|
|
278
357
|
walker_inds = mgh.walker_indices.copy()
|
|
279
358
|
overall_inds = mgh.overall_indices.copy()
|
|
280
|
-
|
|
281
|
-
supps =
|
|
359
|
+
|
|
360
|
+
supps = BranchSupplemental(
|
|
361
|
+
{
|
|
362
|
+
"temp_inds": temp_inds,
|
|
363
|
+
"walker_inds": walker_inds,
|
|
364
|
+
"overall_inds": overall_inds,
|
|
365
|
+
},
|
|
366
|
+
obj_contained_shape=supps_base_shape,
|
|
367
|
+
copy=True,
|
|
368
|
+
)
|
|
282
369
|
state.supplimental = supps
|
|
283
370
|
|
|
284
371
|
return state
|
lisatools/sensitivity.py
CHANGED
|
@@ -162,11 +162,11 @@ class Sensitivity(ABC):
|
|
|
162
162
|
):
|
|
163
163
|
if stochastic_function is None:
|
|
164
164
|
stochastic_function = FittedHyperbolicTangentGalacticForeground
|
|
165
|
+
assert len(stochastic_params) == 1
|
|
165
166
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
sgal[:] = check
|
|
167
|
+
sgal[:] = stochastic_function.get_Sh(
|
|
168
|
+
f, *stochastic_params, **stochastic_kwargs
|
|
169
|
+
)
|
|
170
170
|
|
|
171
171
|
if squeeze:
|
|
172
172
|
sgal = sgal.squeeze()
|
lisatools/sources/utils.py
CHANGED
|
@@ -90,6 +90,21 @@ class CalculationController:
|
|
|
90
90
|
return opt_snr
|
|
91
91
|
|
|
92
92
|
|
|
93
|
+
def mT_q_to_m1_m2(mT: float, q: float) -> Tuple[float, float]:
|
|
94
|
+
"""
|
|
95
|
+
q <= 1.0
|
|
96
|
+
"""
|
|
97
|
+
return (mT / (1 + q), (q * mT) / (1 + q))
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def dist_convert(x: float) -> float:
|
|
101
|
+
return x * 1e9 * PC_SI
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def time_convert(x: float) -> float:
|
|
105
|
+
return x * YRSID_SI
|
|
106
|
+
|
|
107
|
+
|
|
93
108
|
class BBHCalculationController(CalculationController):
|
|
94
109
|
"""Calculation controller for BBHs.
|
|
95
110
|
|
|
@@ -104,10 +119,11 @@ class BBHCalculationController(CalculationController):
|
|
|
104
119
|
# transforms from information matrix basis
|
|
105
120
|
parameter_transforms = {
|
|
106
121
|
0: np.exp,
|
|
107
|
-
4:
|
|
122
|
+
4: dist_convert,
|
|
108
123
|
7: np.arccos,
|
|
109
124
|
9: np.arcsin,
|
|
110
|
-
11:
|
|
125
|
+
11: time_convert,
|
|
126
|
+
(0, 1): mT_q_to_m1_m2,
|
|
111
127
|
}
|
|
112
128
|
self.transform_fn = TransformContainer(
|
|
113
129
|
parameter_transforms=parameter_transforms, fill_dict=None # fill_dict
|
|
File without changes
|
|
File without changes
|