lisaanalysistools 1.0.10__cp312-cp312-macosx_10_9_x86_64.whl → 1.0.11__cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lisaanalysistools might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lisaanalysistools
3
- Version: 1.0.10
3
+ Version: 1.0.11
4
4
  Home-page: https://github.com/mikekatz04/LISAanalysistools
5
5
  Author: Michael Katz
6
6
  Author-email: mikekatz04@gmail.com
@@ -66,7 +66,7 @@ Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduc
66
66
 
67
67
  We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/mikekatz04/LISAanalysistools/tags).
68
68
 
69
- Current Version: 1.0.10
69
+ Current Version: 1.0.11
70
70
 
71
71
  ## Authors/Developers
72
72
 
@@ -1,13 +1,13 @@
1
1
  lisatools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- lisatools/_version.py,sha256=rrQg9RuwmpsEKf63sR6w0kWC6dP3jibZZD8wvMlcM8E,124
3
- lisatools/analysiscontainer.py,sha256=kvTP0KSzfAlLBkeF3RNn0BcZs_K1Z0ntQA4xglfxqgI,15345
2
+ lisatools/_version.py,sha256=kORDw-aHXqe534RIAm9NvG4SSRug8egjhH2hTnHhut0,124
3
+ lisatools/analysiscontainer.py,sha256=UQft6SvyDueDtuL1H1auiEzG5IEsFwRiiXf_0DHtu5U,16289
4
4
  lisatools/datacontainer.py,sha256=QVz0twD46Fl_J-ueGQUCXOJkkXUEum7yrwp9LaVeohU,9853
5
- lisatools/detector.py,sha256=RhWIj5SaMpOGOc5_i7GkYf8vkoEqzxUU_or1eSEwb8Q,21007
5
+ lisatools/detector.py,sha256=1ATZToFoy14U3ACTyver2x9DKavJ3l6ToU6nUSGSTmo,21161
6
6
  lisatools/diagnostic.py,sha256=o9vtfXbY3yMDk4cGNeOsTrlP7gYG_BRROw7gyuThDaM,34192
7
- lisatools/sensitivity.py,sha256=cUgYvfsfd-MrKCAUk2rwb-oQm6IfHI-IMndbRLOKqVs,27380
7
+ lisatools/sensitivity.py,sha256=3BEVnRjFsQ7HZX6PRjd6B7K2FyIJRL3lpQpGaHjNgMA,27393
8
8
  lisatools/stochastic.py,sha256=CysAQ5BIVD349fLjAMpZAXaT0X0dFjrd4QFi0vfqops,9458
9
9
  lisatools/cutils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- lisatools/cutils/detector_cpu.cpython-312-darwin.so,sha256=sYWmMV8OJKuqkbrnr9jfPomur9WR5pVVU2eWG_XumJk,154720
10
+ lisatools/cutils/detector_cpu.cpython-312-darwin.so,sha256=xhTNwTCXaSHLmHFOWOVAl3sP4tDy2o7TMExj_G5MORE,154720
11
11
  lisatools/cutils/include/Detector.hpp,sha256=Ic37OgP-gvEg8qouhv9aFYf7vQP98tJ_i6PGu8RI1YQ,2050
12
12
  lisatools/cutils/include/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
13
  lisatools/cutils/include/global.hpp,sha256=3VPiqglTMRrIBXlEvDUJO2-CjKy_SLUXZt-9A1fH6lQ,572
@@ -16,10 +16,10 @@ lisatools/cutils/src/Detector.cu,sha256=JvAK8UCoYxgTACBhx3hi0EK9hT-HRYOxX2XozHuH
16
16
  lisatools/cutils/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  lisatools/cutils/src/pycppdetector.pyx,sha256=e0tz79VNS7Sxxwyx9RQwsFVdpuClBsUf7OArSkyRBa0,7613
18
18
  lisatools/sampling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- lisatools/sampling/likelihood.py,sha256=G2kAQ43qlhAzIFWvYsrSmHXd7WKJAzcCN2o07vRE8vc,29585
19
+ lisatools/sampling/likelihood.py,sha256=A9a1McgjqJre8c3Bzoyxakl3onjxaE6c7zVwY2-12iI,29585
20
20
  lisatools/sampling/prior.py,sha256=1K1PMStpwO9WT0qG0aotKSyoNjuehXNbzTDtlk8Q15M,21407
21
21
  lisatools/sampling/stopping.py,sha256=Q8q7nM0wnJervhRduf2tBXZZHUVza5kJiAUAMUQXP5o,9682
22
- lisatools/sampling/utility.py,sha256=rOGotS0Aj8-DAWqsTVy2xWNsxsoz74BVrHEnG2mOkwU,14340
22
+ lisatools/sampling/utility.py,sha256=R5LKNK0A36L1PzLjCENQjDWcDjsR1oUBQGrR9PqasSA,15482
23
23
  lisatools/sampling/moves/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
24
  lisatools/sampling/moves/skymodehop.py,sha256=0nf72eFhFMGwi0dLJci6XZz-bIMGqco2B2_J72hQvf8,3348
25
25
  lisatools/sources/__init__.py,sha256=Fm085xHQ3VpRjqaSlws0bdVefFofAxNzVZyGQQYrQic,140
@@ -36,8 +36,8 @@ lisatools/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  lisatools/utils/constants.py,sha256=r1kVwkpbZS13JTOxj2iRxT5sMgTYX30y-S0JdVmD5Oo,1354
37
37
  lisatools/utils/pointeradjust.py,sha256=2sT-7qeYWr1pd_sHk9leVHUTSJ7jJgYIRoWQOtYqguE,2995
38
38
  lisatools/utils/utility.py,sha256=MKqRsG8qRI1xCsj51mt6QRa80-deUUgedxibyMV3KD4,6776
39
- lisaanalysistools-1.0.10.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
40
- lisaanalysistools-1.0.10.dist-info/METADATA,sha256=hOqVWPDpsoy9kyzDpSwLA9dgJxF9rY1SnSSWJxiUyMY,4202
41
- lisaanalysistools-1.0.10.dist-info/WHEEL,sha256=Vn5rrdwmXMBVNlKjgSWDNoqXXbGpTeTFNnBfj813LFw,110
42
- lisaanalysistools-1.0.10.dist-info/top_level.txt,sha256=qXUn8Xi8yvK0vr3QH0vvT5cmoccjSj-UPeKGLAxdYGo,10
43
- lisaanalysistools-1.0.10.dist-info/RECORD,,
39
+ lisaanalysistools-1.0.11.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
40
+ lisaanalysistools-1.0.11.dist-info/METADATA,sha256=xjT_alb_5mDWKzhKyK9KldcQjlndzlQ8-eyrZ9Fp8fk,4202
41
+ lisaanalysistools-1.0.11.dist-info/WHEEL,sha256=zL8OKpwMaCM_hEdCwF-bHbg5JyuZtrW4vMYNMNxkj9k,110
42
+ lisaanalysistools-1.0.11.dist-info/top_level.txt,sha256=qXUn8Xi8yvK0vr3QH0vvT5cmoccjSj-UPeKGLAxdYGo,10
43
+ lisaanalysistools-1.0.11.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp312-cp312-macosx_10_9_x86_64
5
5
 
lisatools/_version.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '1.0.10'
1
+ __version__ = '1.0.11'
2
2
  __copyright__ = "Michael L. Katz 2024"
3
3
  __name__ = "lisaanalysistools"
4
4
  __author__ = "Michael L. Katz"
@@ -11,6 +11,8 @@ import numpy as np
11
11
  from scipy import interpolate
12
12
  import matplotlib.pyplot as plt
13
13
 
14
+ from eryn.utils import TransformContainer
15
+
14
16
 
15
17
  try:
16
18
  import cupy as cp
@@ -286,6 +288,7 @@ class AnalysisContainer:
286
288
  source_only: bool = False,
287
289
  waveform_kwargs: Optional[dict] = {},
288
290
  data_res_arr_kwargs: Optional[dict] = {},
291
+ transform_fn: Optional[TransformContainer] = None,
289
292
  **kwargs: dict,
290
293
  ) -> float | complex:
291
294
  """Return the likelihood of a generated signal with the data.
@@ -308,8 +311,14 @@ class AnalysisContainer:
308
311
  if data_res_arr_kwargs == {}:
309
312
  data_res_arr_kwargs = self.data_res_arr.init_kwargs
310
313
 
314
+ if transform_fn is not None:
315
+ args_tmp = np.asarray(args)
316
+ args_in = tuple(transform_fn.both_transforms(args_tmp))
317
+ else:
318
+ args_in = args
319
+
311
320
  template = DataResidualArray(
312
- self.signal_gen(*args, **waveform_kwargs), **data_res_arr_kwargs
321
+ self.signal_gen(*args_in, **waveform_kwargs), **data_res_arr_kwargs
313
322
  )
314
323
 
315
324
  args_2 = (template,)
@@ -426,7 +435,30 @@ class AnalysisContainer:
426
435
  **kwargs,
427
436
  )
428
437
 
429
- def eryn_likelihood_function(self, x, *args, **kwargs):
438
+ def eryn_likelihood_function(
439
+ self, x: np.ndarray | list | tuple, *args: Any, **kwargs: Any
440
+ ) -> np.ndarray | float:
441
+ """Likelihood function for Eryn sampler.
442
+
443
+ This function is not vectorized.
444
+
445
+ ``signal_gen`` must be set to use this function.
446
+
447
+ Args:
448
+ x: Parameters. Can be 1D list, tuple, array or 2D array.
449
+ If a 2D array is input, the computation is done serially.
450
+ *args: Likelihood args.
451
+ **kwargs: Likelihood kwargs.
452
+
453
+ Returns:
454
+ Likelihood value(s).
455
+
456
+ """
457
+ assert self.signal_gen is not None
458
+
459
+ if isinstance(x, list) or isinstance(x, tuple):
460
+ x = np.asarray(x)
461
+
430
462
  if x.ndim == 1:
431
463
  input_vals = tuple(x) + tuple(args)
432
464
  return self.calculate_signal_likelihood(*input_vals, **kwargs)
lisatools/detector.py CHANGED
@@ -9,6 +9,7 @@ import h5py
9
9
  from scipy import interpolate
10
10
 
11
11
  from .utils.constants import *
12
+ from .utils.utility import get_array_module
12
13
 
13
14
  import numpy as np
14
15
 
@@ -453,9 +454,10 @@ class Orbits(ABC):
453
454
  squeeze = False
454
455
  t = self.xp.asarray(t)
455
456
  sc = self.xp.asarray(sc).astype(np.int32)
457
+
456
458
  else:
457
459
  raise ValueError(
458
- "(t, sc) can be (float, int), (np.ndarray, int), (np.ndarray, np.ndarray)."
460
+ "(t, sc) can be (float, int), (np.ndarray, int), (np.ndarray, np.ndarray). If the inputs follow this, make sure the orbits class GPU setting matches the arrays coming in (GPU or CPU)."
459
461
  )
460
462
 
461
463
  # buffer arrays for input into c code
@@ -1,5 +1,5 @@
1
1
  import warnings
2
- from eryn.state import Branch, BranchSupplimental
2
+ from eryn.state import Branch, BranchSupplemental
3
3
 
4
4
  import numpy as np
5
5
 
@@ -9,7 +9,7 @@ try:
9
9
  except (ImportError, ModuleNotFoundError) as e:
10
10
  pass
11
11
 
12
- from eryn.state import State, BranchSupplimental
12
+ from eryn.state import State, BranchSupplemental
13
13
  from eryn.utils.utility import groups_from_inds
14
14
 
15
15
 
@@ -19,8 +19,21 @@ class DetermineGBGroups:
19
19
  self.xp = self.gb_wave_generator.xp
20
20
  self.transform_fn = transform_fn
21
21
  self.waveform_kwargs = waveform_kwargs
22
-
23
- def __call__(self, last_sample, name_here, check_temp=0, input_groups=None, input_groups_inds=None, fix_group_count=False, mismatch_lim=0.2, double_check_lim=0.2, start_term="random", waveform_kwargs={}, index_within_group="random"):
22
+
23
+ def __call__(
24
+ self,
25
+ last_sample,
26
+ name_here,
27
+ check_temp=0,
28
+ input_groups=None,
29
+ input_groups_inds=None,
30
+ fix_group_count=False,
31
+ mismatch_lim=0.2,
32
+ double_check_lim=0.2,
33
+ start_term="random",
34
+ waveform_kwargs={},
35
+ index_within_group="random",
36
+ ):
24
37
  # TODO: mess with mismatch lim setting
25
38
  # TODO: some time of mismatch annealing may be useful
26
39
  if isinstance(last_sample, State):
@@ -32,13 +45,13 @@ class DetermineGBGroups:
32
45
  inds = last_sample[name_here][check_temp]["inds"]
33
46
 
34
47
  waveform_kwargs = {**self.waveform_kwargs, **waveform_kwargs}
35
-
48
+
36
49
  # get coordinates and inds of the temperature you are considering.
37
-
50
+
38
51
  nwalkers, nleaves_max, ndim = coords.shape
39
52
  if input_groups is None:
40
53
 
41
- # figure our which walker to start with
54
+ # figure our which walker to start with
42
55
  if start_term == "max":
43
56
  start_walker_ind = inds[check_temp].sum(axis=-1).argmax()
44
57
  elif start_term == "first":
@@ -52,7 +65,7 @@ class DetermineGBGroups:
52
65
  inds_good = np.where(inds[start_walker_ind])[0]
53
66
  groups = []
54
67
  groups_inds = []
55
-
68
+
56
69
  # set up this information to load the information into the group lists
57
70
  for leaf_i, leaf in enumerate(inds_good):
58
71
  groups.append([])
@@ -72,7 +85,7 @@ class DetermineGBGroups:
72
85
  if input_groups is None and w == start_walker_ind:
73
86
  continue
74
87
 
75
- # walker has no binaries
88
+ # walker has no binaries
76
89
  if not np.any(inds[w]):
77
90
  continue
78
91
 
@@ -81,7 +94,6 @@ class DetermineGBGroups:
81
94
  inds_for_group_stuff = np.arange(len(inds[w]))[inds[w]]
82
95
  nleaves, ndim = coords_here.shape
83
96
 
84
-
85
97
  params_for_test = []
86
98
  for group in groups:
87
99
  group_params = np.asarray(group)
@@ -95,39 +107,49 @@ class DetermineGBGroups:
95
107
 
96
108
  params_for_test.append(group_params[test_walker_ind])
97
109
  params_for_test = np.asarray(params_for_test)
98
-
110
+
99
111
  # transform coords
100
112
  if self.transform_fn is not None:
101
- params_for_test_in = self.transform_fn[name_here].both_transforms(params_for_test, return_transpose=False)
102
- coords_here_in = self.transform_fn[name_here].both_transforms(coords_here, return_transpose=False)
103
-
113
+ params_for_test_in = self.transform_fn[name_here].both_transforms(
114
+ params_for_test, return_transpose=False
115
+ )
116
+ coords_here_in = self.transform_fn[name_here].both_transforms(
117
+ coords_here, return_transpose=False
118
+ )
119
+
104
120
  else:
105
121
  params_for_test_in = params_for_test.copy()
106
122
  coords_here_in = coords_here.copy()
107
123
 
108
124
  inds_tmp_test = np.arange(len(params_for_test_in))
109
125
  inds_tmp_here = np.arange(len(coords_here_in))
110
- inds_tmp_test, inds_tmp_here = [tmp.ravel() for tmp in np.meshgrid(inds_tmp_test, inds_tmp_here)]
126
+ inds_tmp_test, inds_tmp_here = [
127
+ tmp.ravel() for tmp in np.meshgrid(inds_tmp_test, inds_tmp_here)
128
+ ]
111
129
 
112
130
  params_for_test_in_full = params_for_test_in[inds_tmp_test]
113
131
  coords_here_in_full = coords_here_in[inds_tmp_here]
114
132
  # build the waveforms at the same time
115
133
 
116
- df = 1. / waveform_kwargs["T"]
117
- max_f = 1. / 2 * 1/waveform_kwargs["dt"]
134
+ df = 1.0 / waveform_kwargs["T"]
135
+ max_f = 1.0 / 2 * 1 / waveform_kwargs["dt"]
118
136
  frqs = self.xp.arange(0.0, max_f, df)
119
- data_minus_template = self.xp.asarray([
120
- self.xp.ones_like(frqs, dtype=complex),
121
- self.xp.ones_like(frqs, dtype=complex)
122
- ])[None, :, :]
123
- psd = self.xp.asarray([
124
- self.xp.ones_like(frqs, dtype=np.float64),
125
- self.xp.ones_like(frqs, dtype=np.float64)
126
- ])
137
+ data_minus_template = self.xp.asarray(
138
+ [
139
+ self.xp.ones_like(frqs, dtype=complex),
140
+ self.xp.ones_like(frqs, dtype=complex),
141
+ ]
142
+ )[None, :, :]
143
+ psd = self.xp.asarray(
144
+ [
145
+ self.xp.ones_like(frqs, dtype=np.float64),
146
+ self.xp.ones_like(frqs, dtype=np.float64),
147
+ ]
148
+ )
127
149
 
128
150
  waveform_kwargs_fill = waveform_kwargs.copy()
129
151
  waveform_kwargs_fill.pop("start_freq_ind")
130
-
152
+
131
153
  # TODO: could use real data and get observed snr for each if needed
132
154
  check = self.gb_wave_generator.swap_likelihood_difference(
133
155
  params_for_test_in_full,
@@ -147,26 +169,36 @@ class DetermineGBGroups:
147
169
  normalized_autocorr = numerator / np.sqrt(norm_here * norm_for_test)
148
170
  normalized_against_test = numerator / norm_for_test
149
171
 
150
- normalized_autocorr = normalized_autocorr.reshape(coords_here_in.shape[0], params_for_test_in.shape[0]).real
151
- normalized_against_test = normalized_against_test.reshape(coords_here_in.shape[0], params_for_test_in.shape[0]).real
152
-
172
+ normalized_autocorr = normalized_autocorr.reshape(
173
+ coords_here_in.shape[0], params_for_test_in.shape[0]
174
+ ).real
175
+ normalized_against_test = normalized_against_test.reshape(
176
+ coords_here_in.shape[0], params_for_test_in.shape[0]
177
+ ).real
178
+
153
179
  # TODO: do based on Likelihood? make sure on same posterior
154
180
  # TODO: add check based on amplitude
155
- test1 = np.abs(1.0 - normalized_autocorr.real) # (numerator / norm_for_test[None, :]).real)
181
+ test1 = np.abs(
182
+ 1.0 - normalized_autocorr.real
183
+ ) # (numerator / norm_for_test[None, :]).real)
156
184
  best = test1.argmin(axis=1)
157
185
  try:
158
186
  best = best.get()
159
187
  except AttributeError:
160
188
  pass
161
189
  best_mismatch = test1[(np.arange(test1.shape[0]), best)]
162
- check_normalized_against_test = np.abs(1.0 - normalized_against_test[(np.arange(test1.shape[0]), best)])
163
-
190
+ check_normalized_against_test = np.abs(
191
+ 1.0 - normalized_against_test[(np.arange(test1.shape[0]), best)]
192
+ )
164
193
 
165
194
  f0_here = coords_here[:, 1]
166
195
  f0_test = params_for_test[best, 1]
167
196
 
168
197
  for leaf in range(nleaves):
169
- if best_mismatch[leaf] < mismatch_lim and check_normalized_against_test[leaf] < double_check_lim:
198
+ if (
199
+ best_mismatch[leaf] < mismatch_lim
200
+ and check_normalized_against_test[leaf] < double_check_lim
201
+ ):
170
202
  groups[best[leaf]].append(coords_here[leaf].copy())
171
203
  groups_inds[best[leaf]].append([w, inds_for_group_stuff[leaf]])
172
204
 
@@ -187,41 +219,65 @@ class GetLastGBState:
187
219
  self.transform_fn = transform_fn
188
220
  self.waveform_kwargs = waveform_kwargs
189
221
 
190
- def __call__(self, mgh, reader, df, supps_base_shape, fix_temp_initial_ind:int=None, fix_temp_inds:list=None, nleaves_max_in=None, waveform_kwargs={}):
222
+ def __call__(
223
+ self,
224
+ mgh,
225
+ reader,
226
+ df,
227
+ supps_base_shape,
228
+ fix_temp_initial_ind: int = None,
229
+ fix_temp_inds: list = None,
230
+ nleaves_max_in=None,
231
+ waveform_kwargs={},
232
+ ):
191
233
 
192
234
  xp.cuda.runtime.setDevice(mgh.gpus[0])
193
235
 
194
236
  if fix_temp_initial_ind is not None or fix_temp_inds is not None:
195
237
  if fix_temp_initial_ind is None or fix_temp_inds is None:
196
- raise ValueError("If giving fix_temp_initial_ind or fix_temp_inds, must give both.")
238
+ raise ValueError(
239
+ "If giving fix_temp_initial_ind or fix_temp_inds, must give both."
240
+ )
197
241
 
198
242
  state = reader.get_last_sample()
199
243
 
200
244
  waveform_kwargs = {**self.waveform_kwargs, **waveform_kwargs}
201
245
  if "start_freq_ind" not in waveform_kwargs:
202
- raise ValueError("In get_last_gb_state, waveform_kwargs must include 'start_freq_ind'.")
246
+ raise ValueError(
247
+ "In get_last_gb_state, waveform_kwargs must include 'start_freq_ind'."
248
+ )
203
249
 
204
- #check = reader.get_last_sample()
250
+ # check = reader.get_last_sample()
205
251
  ntemps, nwalkers, nleaves_max_old, ndim = state.branches["gb"].shape
206
-
207
- #out = get_groups_for_remixing(check, check_temp=0, input_groups=None, input_groups_inds=None, fix_group_count=False, name_here="gb")
208
252
 
209
- #lengths = []
210
- #for group in out[0]:
253
+ # out = get_groups_for_remixing(check, check_temp=0, input_groups=None, input_groups_inds=None, fix_group_count=False, name_here="gb")
254
+
255
+ # lengths = []
256
+ # for group in out[0]:
211
257
  # lengths.append(len(group))
212
- #breakpoint()
258
+ # breakpoint()
213
259
  try:
214
- if fix_temp_initial_ind is not None:
260
+ if fix_temp_initial_ind is not None:
215
261
  for i in fix_temp_inds:
216
262
  if i < fix_temp_initial_ind:
217
- raise ValueError("If providing fix_temp_initial_ind and fix_temp_inds, all values in fix_temp_inds must be greater than fix_temp_initial_ind.")
263
+ raise ValueError(
264
+ "If providing fix_temp_initial_ind and fix_temp_inds, all values in fix_temp_inds must be greater than fix_temp_initial_ind."
265
+ )
218
266
 
219
267
  state.log_like[i] = state.log_like[fix_temp_initial_ind]
220
268
  state.log_prior[i] = state.log_prior[fix_temp_initial_ind]
221
- state.branches_coords["gb"][i] = state.branches_coords["gb"][fix_temp_initial_ind]
222
- state.branches_coords["gb"][i] = state.branches_coords["gb"][fix_temp_initial_ind]
223
- state.branches_inds["gb"][i] = state.branches_inds["gb"][fix_temp_initial_ind]
224
- state.branches_inds["gb"][i] = state.branches_inds["gb"][fix_temp_initial_ind]
269
+ state.branches_coords["gb"][i] = state.branches_coords["gb"][
270
+ fix_temp_initial_ind
271
+ ]
272
+ state.branches_coords["gb"][i] = state.branches_coords["gb"][
273
+ fix_temp_initial_ind
274
+ ]
275
+ state.branches_inds["gb"][i] = state.branches_inds["gb"][
276
+ fix_temp_initial_ind
277
+ ]
278
+ state.branches_inds["gb"][i] = state.branches_inds["gb"][
279
+ fix_temp_initial_ind
280
+ ]
225
281
 
226
282
  ntemps, nwalkers, nleaves_max_old, ndim = state.branches["gb"].shape
227
283
  if nleaves_max_in is None:
@@ -238,23 +294,34 @@ class GetLastGBState:
238
294
  state.branches["gb"].inds = inds_tmp
239
295
  state.branches["gb"].nleaves_max = nleaves_max
240
296
  state.branches["gb"].shape = (ntemps, nwalkers, nleaves_max, ndim)
241
-
297
+
242
298
  else:
243
299
  raise ValueError("new nleaves_max is less than nleaves_max_old.")
244
300
 
245
301
  # add "gb" if there are any
246
302
  data_index_in = groups_from_inds({"gb": state.branches_inds["gb"]})["gb"]
247
303
 
248
- data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(xp.int32)
304
+ data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(
305
+ xp.int32
306
+ )
307
+
308
+ params_add_in = self.transform_fn["gb"].both_transforms(
309
+ state.branches_coords["gb"][state.branches_inds["gb"]]
310
+ )
249
311
 
250
- params_add_in = self.transform_fn["gb"].both_transforms(state.branches_coords["gb"][state.branches_inds["gb"]])
251
-
252
312
  # batch_size is ignored if waveform_kwargs["use_c_implementation"] is True
253
- # -1 is to do -(-d + h) = d - h
254
- mgh.multiply_data(-1.)
255
- self.gb_wave_generator.generate_global_template(params_add_in, data_index, mgh.data_list, data_length=mgh.data_length, data_splits=mgh.gpu_splits, batch_size=1000, **waveform_kwargs)
256
- mgh.multiply_data(-1.)
257
-
313
+ # -1 is to do -(-d + h) = d - h
314
+ mgh.multiply_data(-1.0)
315
+ self.gb_wave_generator.generate_global_template(
316
+ params_add_in,
317
+ data_index,
318
+ mgh.data_list,
319
+ data_length=mgh.data_length,
320
+ data_splits=mgh.gpu_splits,
321
+ batch_size=1000,
322
+ **waveform_kwargs,
323
+ )
324
+ mgh.multiply_data(-1.0)
258
325
 
259
326
  except KeyError:
260
327
  # no "gb"
@@ -263,22 +330,42 @@ class GetLastGBState:
263
330
  data_index_in = groups_from_inds({"gb": state.branches_inds["gb"]})["gb"]
264
331
  data_index = xp.asarray(mgh.get_mapped_indices(data_index_in)).astype(xp.int32)
265
332
 
266
- params_add_in = self.transform_fn["gb"].both_transforms(state.branches_coords["gb"][state.branches_inds["gb"]])
267
-
268
- # -1 is to do -(-d + h) = d - h
269
- mgh.multiply_data(-1.)
270
- self.gb_wave_generator.generate_global_template(params_add_in, data_index, mgh.data_list, data_length=mgh.data_length, data_splits=mgh.gpu_splits, batch_size=1000, **waveform_kwargs)
271
- mgh.multiply_data(-1.)
333
+ params_add_in = self.transform_fn["gb"].both_transforms(
334
+ state.branches_coords["gb"][state.branches_inds["gb"]]
335
+ )
336
+
337
+ # -1 is to do -(-d + h) = d - h
338
+ mgh.multiply_data(-1.0)
339
+ self.gb_wave_generator.generate_global_template(
340
+ params_add_in,
341
+ data_index,
342
+ mgh.data_list,
343
+ data_length=mgh.data_length,
344
+ data_splits=mgh.gpu_splits,
345
+ batch_size=1000,
346
+ **waveform_kwargs,
347
+ )
348
+ mgh.multiply_data(-1.0)
272
349
 
273
350
  self.gb_wave_generator.d_d = np.asarray(mgh.get_inner_product(use_cpu=True))
274
-
275
- state.log_like = -1/2 * self.gb_wave_generator.d_d.real.reshape(ntemps, nwalkers)
351
+
352
+ state.log_like = (
353
+ -1 / 2 * self.gb_wave_generator.d_d.real.reshape(ntemps, nwalkers)
354
+ )
276
355
 
277
356
  temp_inds = mgh.temp_indices.copy()
278
357
  walker_inds = mgh.walker_indices.copy()
279
358
  overall_inds = mgh.overall_indices.copy()
280
-
281
- supps = BranchSupplimental({ "temp_inds": temp_inds, "walker_inds": walker_inds, "overall_inds": overall_inds,}, obj_contained_shape=supps_base_shape, copy=True)
359
+
360
+ supps = BranchSupplemental(
361
+ {
362
+ "temp_inds": temp_inds,
363
+ "walker_inds": walker_inds,
364
+ "overall_inds": overall_inds,
365
+ },
366
+ obj_contained_shape=supps_base_shape,
367
+ copy=True,
368
+ )
282
369
  state.supplimental = supps
283
370
 
284
371
  return state
lisatools/sensitivity.py CHANGED
@@ -162,11 +162,11 @@ class Sensitivity(ABC):
162
162
  ):
163
163
  if stochastic_function is None:
164
164
  stochastic_function = FittedHyperbolicTangentGalacticForeground
165
+ assert len(stochastic_params) == 1
165
166
 
166
- check = stochastic_function.get_Sh(
167
- f, *stochastic_params, **stochastic_kwargs
168
- )
169
- sgal[:] = check
167
+ sgal[:] = stochastic_function.get_Sh(
168
+ f, *stochastic_params, **stochastic_kwargs
169
+ )
170
170
 
171
171
  if squeeze:
172
172
  sgal = sgal.squeeze()