lisaanalysistools-1.0.2-cp312-cp312-macosx_10_9_x86_64.whl → lisaanalysistools-1.0.4-cp312-cp312-macosx_10_9_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of lisaanalysistools might be problematic; see the registry's page for this release for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lisaanalysistools
3
- Version: 1.0.2
3
+ Version: 1.0.4
4
4
  Home-page: https://github.com/mikekatz04/lisa-on-gpu
5
5
  Author: Michael Katz
6
6
  Author-email: mikekatz04@gmail.com
@@ -18,8 +18,9 @@ License-File: LICENSE
18
18
  # LISA Analysis Tools
19
19
 
20
20
  [![Doc badge](https://img.shields.io/badge/Docs-master-brightgreen)](https://mikekatz04.github.io/LISAanalysistools)
21
+ [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10930980.svg)](https://doi.org/10.5281/zenodo.10930980)
21
22
 
22
- LISA Analysis Tools is a package for performing LISA Data Analysis tasks, including building the LISA Global Fit.
23
+ LISA Analysis Tools is a package for performing LISA Data Analysis tasks, including building the LISA Global Fit.
23
24
 
24
25
  ## 1 - Getting Started
25
26
 
@@ -59,16 +60,15 @@ Please read [CONTRIBUTING.md](CONTRIBUTING.md) for details on our code of conduc
59
60
 
60
61
  We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/BlackHolePerturbationToolkit/FastEMRIWaveforms/tags).
61
62
 
62
- Current Version: 1.0.2
63
+ Current Version: 1.0.4
63
64
 
64
65
  ## Authors/Developers
65
66
 
66
67
  * **Michael Katz**
67
-
68
- ### Contibutors
69
-
70
68
  * Lorenzo Speri
71
69
  * Christian Chapman-Bird
70
+ * Natalia Korsakova
71
+ * Nikos Karnesis
72
72
 
73
73
  ## License
74
74
 
@@ -76,5 +76,20 @@ This project is licensed under the Apache License - see the [LICENSE.md](LICENSE
76
76
 
77
77
  ## Citation
78
78
 
79
- TODO.
79
+ ```
80
+ @software{michael_katz_2024_10930980,
81
+ author = {Michael Katz and
82
+ CChapmanbird and
83
+ Lorenzo Speri and
84
+ Nikolaos Karnesis and
85
+ Korsakova, Natalia},
86
+ title = {mikekatz04/LISAanalysistools: First main release.},
87
+ month = apr,
88
+ year = 2024,
89
+ publisher = {Zenodo},
90
+ version = {v1.0.3},
91
+ doi = {10.5281/zenodo.10930980},
92
+ url = {https://doi.org/10.5281/zenodo.10930980}
93
+ }
94
+ ```
80
95
 
@@ -1,16 +1,16 @@
1
1
  lisatools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- lisatools/_version.py,sha256=deQ7mveodnSkIp1gTJM0VfFR2eczJFUm4STvvmnmvZ4,123
3
- lisatools/analysiscontainer.py,sha256=KG1Iww3iYHyeoKYgfASqQ2Lu2VilXuW9zDW1Jkw5ARc,15329
4
- lisatools/datacontainer.py,sha256=MErM0cWZdxVtbaA8HTU79bXmxv2XcYDcqPiHyPbPZE0,9224
5
- lisatools/detector.py,sha256=VqOhvzG3RiG8o7AX9Q9123WYuRNl7Leq_6oBzenRZOU,12183
6
- lisatools/diagnostic.py,sha256=oPMovelkyTYUXQWzLvtZB-nia-oxYhlNIci4P2r0Bhg,34177
2
+ lisatools/_version.py,sha256=ZyRiY_EiNuQjMpqN7rHllKpqA3B9m-CTwI9pWBKukog,123
3
+ lisatools/analysiscontainer.py,sha256=ePwTBUTEBJn2TK93_afARate9SAqUKK8c8T6DcGUx1Y,15321
4
+ lisatools/datacontainer.py,sha256=W89ErPJynfeioZwYqcpehHVorhKsb8FLKrj69zIsKKU,9187
5
+ lisatools/detector.py,sha256=Ht5v-Iq_DJlCWMt9iSbZwmEwTGwd3pdeOaw4lBtmiBQ,13852
6
+ lisatools/diagnostic.py,sha256=CfPpfvDys1XyZRWmmTqTSWb0SY2eH0G_8TRnt1OxBFo,34174
7
7
  lisatools/glitch.py,sha256=qMNSqdmGqdm6kVtZP9qncP_40DyPj9ushbXh88g9wlU,5154
8
- lisatools/sensitivity.py,sha256=4ELRfrB5CY7Fkg6LYuuseTdbqnx59My0i7cGvK2VxKg,27236
9
- lisatools/stochastic.py,sha256=XYZpsWiVB8Yz9V5-_hAKjzOfqe8e5ln0amtWypjqCCg,9440
10
- lisatools/cutils/detector.cpython-312-darwin.so,sha256=THx-1ysBWhHf54vU10ZyBpQN-38ON7F7IE74FTwvatQ,121664
8
+ lisatools/sensitivity.py,sha256=eYSqM3Kr5UsAhqR3J5DWQBGaavL5dGO1ouWpbhMwhao,27263
9
+ lisatools/stochastic.py,sha256=wdiToEj4nUpCDIb0j7vQ7netTPDDtPGPbUg8-RiFA9U,9421
10
+ lisatools/cutils/detector.cpython-312-darwin.so,sha256=FguczB3QPZfL3CCSMEuIx3VP0jj240_Dh30anER84m8,121664
11
11
  lisatools/sampling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  lisatools/sampling/likelihood.py,sha256=G2kAQ43qlhAzIFWvYsrSmHXd7WKJAzcCN2o07vRE8vc,29585
13
- lisatools/sampling/prior.py,sha256=V9sWorxDamgsSFPZpK3t1651yPp6znRVXe0zto0wFpc,18461
13
+ lisatools/sampling/prior.py,sha256=1K1PMStpwO9WT0qG0aotKSyoNjuehXNbzTDtlk8Q15M,21407
14
14
  lisatools/sampling/stopping.py,sha256=Q8q7nM0wnJervhRduf2tBXZZHUVza5kJiAUAMUQXP5o,9682
15
15
  lisatools/sampling/utility.py,sha256=rOGotS0Aj8-DAWqsTVy2xWNsxsoz74BVrHEnG2mOkwU,14340
16
16
  lisatools/sampling/moves/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -29,9 +29,9 @@ lisatools/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
29
  lisatools/utils/constants.py,sha256=r1kVwkpbZS13JTOxj2iRxT5sMgTYX30y-S0JdVmD5Oo,1354
30
30
  lisatools/utils/multigpudataholder.py,sha256=6HwkOceqga1Q7eK4TjStGXy4oKgx37hTkdcAwiiZ8_Y,33765
31
31
  lisatools/utils/pointeradjust.py,sha256=2sT-7qeYWr1pd_sHk9leVHUTSJ7jJgYIRoWQOtYqguE,2995
32
- lisatools/utils/utility.py,sha256=TgZ4vLVGih4ZU2caMRlK06m8nMoEVvwrS3Q7dH83u1g,6742
33
- lisaanalysistools-1.0.2.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
34
- lisaanalysistools-1.0.2.dist-info/METADATA,sha256=lZPzQp4QDwR64AUkz4sjKk6pRMmXMYDo4Z1w0DQCCoY,2753
35
- lisaanalysistools-1.0.2.dist-info/WHEEL,sha256=KYtn_mzb_QwZSHwPlosUO3fDl70znfUFngLlrLVHeBY,111
36
- lisaanalysistools-1.0.2.dist-info/top_level.txt,sha256=oCQGY7qy66i_b9MCsK2fTRdbV1pcC9GsGgIDjN47Tyc,14
37
- lisaanalysistools-1.0.2.dist-info/RECORD,,
32
+ lisatools/utils/utility.py,sha256=3mJoJKNGrm3KuNXIa2RUKi9WKd593V4q9XjjQZCQD0M,6831
33
+ lisaanalysistools-1.0.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
34
+ lisaanalysistools-1.0.4.dist-info/METADATA,sha256=reVoFCjiQuN-XxTsgR7PW_ROH4yQEkWLeQPxF3adDcw,3380
35
+ lisaanalysistools-1.0.4.dist-info/WHEEL,sha256=KYtn_mzb_QwZSHwPlosUO3fDl70znfUFngLlrLVHeBY,111
36
+ lisaanalysistools-1.0.4.dist-info/top_level.txt,sha256=oCQGY7qy66i_b9MCsK2fTRdbV1pcC9GsGgIDjN47Tyc,14
37
+ lisaanalysistools-1.0.4.dist-info/RECORD,,
lisatools/_version.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '1.0.2'
1
+ __version__ = '1.0.4'
2
2
  __copyright__ = "Michael L. Katz 2024"
3
3
  __name__ = "lisaanalysistools"
4
4
  __author__ = "Michael L. Katz"
@@ -1,13 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+
1
4
  import warnings
2
5
  from abc import ABC
3
6
  from typing import Any, Tuple, Optional, List
4
7
 
5
8
  import math
6
9
  import numpy as np
7
- from numpy.typing import ArrayLike
10
+
8
11
  from scipy import interpolate
9
12
  import matplotlib.pyplot as plt
10
13
 
14
+
11
15
  try:
12
16
  import cupy as cp
13
17
 
@@ -209,7 +213,7 @@ class AnalysisContainer:
209
213
  template: DataResidualArray,
210
214
  include_psd_info: bool = False,
211
215
  phase_maximize: bool = False,
212
- **kwargs: dict
216
+ **kwargs: dict,
213
217
  ) -> float:
214
218
  """Calculate the Likelihood of a template against the data.
215
219
 
@@ -268,7 +272,7 @@ class AnalysisContainer:
268
272
  return noise_likelihood_term(self.sens_mat)
269
273
  elif source_only:
270
274
  return residual_source_likelihood_term(
271
- self.data_res_arr, self.sens_mat, **kwargs
275
+ self.data_res_arr, psd=self.sens_mat, **kwargs
272
276
  )
273
277
  else:
274
278
  return residual_full_source_and_noise_likelihood(
@@ -278,11 +282,11 @@ class AnalysisContainer:
278
282
  def _calculate_signal_operation(
279
283
  self,
280
284
  calc: str,
281
- *args: ArrayLike,
285
+ *args: Any,
282
286
  source_only: bool = False,
283
287
  waveform_kwargs: Optional[dict] = {},
284
288
  data_res_arr_kwargs: Optional[dict] = {},
285
- **kwargs: dict
289
+ **kwargs: dict,
286
290
  ) -> float | complex:
287
291
  """Return the likelihood of a generated signal with the data.
288
292
 
@@ -328,11 +332,11 @@ class AnalysisContainer:
328
332
 
329
333
  def calculate_signal_likelihood(
330
334
  self,
331
- *args: ArrayLike,
335
+ *args: Any,
332
336
  source_only: bool = False,
333
337
  waveform_kwargs: Optional[dict] = {},
334
338
  data_res_arr_kwargs: Optional[dict] = {},
335
- **kwargs: dict
339
+ **kwargs: dict,
336
340
  ) -> float | complex:
337
341
  """Return the likelihood of a generated signal with the data.
338
342
 
@@ -355,16 +359,16 @@ class AnalysisContainer:
355
359
  source_only=source_only,
356
360
  waveform_kwargs=waveform_kwargs,
357
361
  data_res_arr_kwargs=data_res_arr_kwargs,
358
- **kwargs
362
+ **kwargs,
359
363
  )
360
364
 
361
365
  def calculate_signal_inner_product(
362
366
  self,
363
- *args: ArrayLike,
367
+ *args: Any,
364
368
  source_only: bool = False,
365
369
  waveform_kwargs: Optional[dict] = {},
366
370
  data_res_arr_kwargs: Optional[dict] = {},
367
- **kwargs: dict
371
+ **kwargs: dict,
368
372
  ) -> float | complex:
369
373
  """Return the inner product of a generated signal with the data.
370
374
 
@@ -387,16 +391,16 @@ class AnalysisContainer:
387
391
  source_only=source_only,
388
392
  waveform_kwargs=waveform_kwargs,
389
393
  data_res_arr_kwargs=data_res_arr_kwargs,
390
- **kwargs
394
+ **kwargs,
391
395
  )
392
396
 
393
397
  def calculate_signal_snr(
394
398
  self,
395
- *args: ArrayLike,
399
+ *args: Any,
396
400
  source_only: bool = False,
397
401
  waveform_kwargs: Optional[dict] = {},
398
402
  data_res_arr_kwargs: Optional[dict] = {},
399
- **kwargs: dict
403
+ **kwargs: dict,
400
404
  ) -> Tuple[float, float]:
401
405
  """Return the SNR of a generated signal with the data.
402
406
 
@@ -419,7 +423,7 @@ class AnalysisContainer:
419
423
  source_only=source_only,
420
424
  waveform_kwargs=waveform_kwargs,
421
425
  data_res_arr_kwargs=data_res_arr_kwargs,
422
- **kwargs
426
+ **kwargs,
423
427
  )
424
428
 
425
429
  def eryn_likelihood_function(self, x, *args, **kwargs):
@@ -4,7 +4,6 @@ from typing import Any, Tuple, Optional, List
4
4
 
5
5
  import math
6
6
  import numpy as np
7
- from numpy.typing import ArrayLike
8
7
  from scipy import interpolate
9
8
  import matplotlib.pyplot as plt
10
9
 
@@ -36,7 +35,7 @@ class DataResidualArray:
36
35
  sens_mat: Input sensitivity list. The shape of the nested lists should represent the shape of the
37
36
  desired matrix. Each entry in the list must be an array, :class:`Sensitivity`-derived object,
38
37
  or a string corresponding to the :class:`Sensitivity` object.
39
- **sens_kwargs: Keyword arguments to pass to :method:`Sensitivity.get_Sn`.
38
+ **sens_kwargs: Keyword arguments to pass to :func:`Sensitivity.get_Sn`.
40
39
 
41
40
  """
42
41
 
lisatools/detector.py CHANGED
@@ -22,7 +22,6 @@ class Orbits(ABC):
22
22
  Args:
23
23
  filename: File name. File should be in the style of LISAOrbits
24
24
 
25
-
26
25
  """
27
26
 
28
27
  def __init__(self, filename: str) -> None:
@@ -51,6 +50,7 @@ class Orbits(ABC):
51
50
  return [int(str(link_i)[1]) for link_i in self.LINKS]
52
51
 
53
52
  def _setup(self) -> None:
53
+ """Read in orbital data from file and store."""
54
54
  with self.open() as f:
55
55
  for key in f.attrs.keys():
56
56
  setattr(self, key + "_base", f.attrs[key])
@@ -63,22 +63,33 @@ class Orbits(ABC):
63
63
  @filename.setter
64
64
  def filename(self, filename: str) -> None:
65
65
  """Set file name."""
66
+
66
67
  assert isinstance(filename, str)
68
+
69
+ # get path
67
70
  path_to_this_file = __file__.split("detector.py")[0]
71
+
72
+ # make sure orbit_files directory exists in the right place
68
73
  if not os.path.exists(path_to_this_file + "orbit_files/"):
69
74
  os.mkdir(path_to_this_file + "orbit_files/")
70
75
  path_to_this_file = path_to_this_file + "orbit_files/"
76
+
71
77
  if not os.path.exists(path_to_this_file + filename):
78
+ # download files from github if they are not there
72
79
  github_file = f"https://github.com/mikekatz04/LISAanalysistools/raw/main/lisatools/orbit_files/{filename}"
73
80
  r = requests.get(github_file)
81
+
82
+ # if not success
74
83
  if r.status_code != 200:
75
84
  raise ValueError(
76
85
  f"Cannot find {filename} within default files located at github.com/mikekatz04/LISAanalysistools/lisatools/orbit_files/."
77
86
  )
78
87
 
88
+ # write the contents to a local file
79
89
  with open(path_to_this_file + filename, "wb") as f:
80
90
  f.write(r.content)
81
91
 
92
+ # store
82
93
  self._filename = path_to_this_file + filename
83
94
 
84
95
  def open(self) -> h5py.File:
@@ -96,7 +107,7 @@ class Orbits(ABC):
96
107
 
97
108
  @property
98
109
  def t_base(self) -> np.ndarray:
99
- """Light travel times along links from file."""
110
+ """Time array from file."""
100
111
  with self.open() as f:
101
112
  t_base = np.arange(self.size_base) * self.dt_base
102
113
  return t_base
@@ -117,14 +128,14 @@ class Orbits(ABC):
117
128
 
118
129
  @property
119
130
  def x_base(self) -> np.ndarray:
120
- """Light travel times along links from file."""
131
+ """Spacecraft position from file."""
121
132
  with self.open() as f:
122
133
  x = f["tcb"]["x"][:]
123
134
  return x
124
135
 
125
136
  @property
126
137
  def v_base(self) -> np.ndarray:
127
- """Light travel times along links from file."""
138
+ """Spacecraft velocities from file."""
128
139
  with self.open() as f:
129
140
  v = f["tcb"]["v"][:]
130
141
  return v
@@ -154,35 +165,35 @@ class Orbits(ABC):
154
165
 
155
166
  @property
156
167
  def n(self) -> np.ndarray:
157
- """Light travel time."""
168
+ """Normal vectors along links."""
158
169
  self._check_configured()
159
170
  return self._n
160
171
 
161
172
  @n.setter
162
173
  def n(self, n: np.ndarray) -> np.ndarray:
163
- """Set light travel time."""
174
+ """Set Normal vectors along links."""
164
175
  return self._n
165
176
 
166
177
  @property
167
178
  def x(self) -> np.ndarray:
168
- """Light travel time."""
179
+ """Spacecraft positions."""
169
180
  self._check_configured()
170
181
  return self._x
171
182
 
172
183
  @x.setter
173
184
  def x(self, x: np.ndarray) -> np.ndarray:
174
- """Set light travel time."""
185
+ """Set Spacecraft positions."""
175
186
  return self._x
176
187
 
177
188
  @property
178
189
  def v(self) -> np.ndarray:
179
- """Light travel time."""
190
+ """Spacecraft velocities."""
180
191
  self._check_configured()
181
192
  return self._v
182
193
 
183
194
  @v.setter
184
195
  def v(self, v: np.ndarray) -> np.ndarray:
185
- """Set light travel time."""
196
+ """Set Spacecraft velocities."""
186
197
  return self._v
187
198
 
188
199
  def configure(
@@ -290,6 +301,7 @@ class Orbits(ABC):
290
301
 
291
302
  @property
292
303
  def pycppdetector_args(self) -> tuple:
304
+ """args for the c++ class."""
293
305
  return self._pycppdetector_args
294
306
 
295
307
  @pycppdetector_args.setter
@@ -311,30 +323,81 @@ class Orbits(ABC):
311
323
  def get_light_travel_times(
312
324
  self, t: float | np.ndarray, link: int
313
325
  ) -> float | np.ndarray:
326
+ """Compute light travel time as a function of time.
327
+
328
+ Computes with the c++ backend.
329
+
330
+ Args:
331
+ t: Time array in seconds.
332
+ link: which link. Must be ``in self.LINKS``.
333
+
334
+ Returns:
335
+ Light travel times.
336
+
337
+ """
314
338
  return self.pycppdetector.get_light_travel_time(t, link)
315
339
 
316
340
  def get_normal_unit_vec(self, t: float | np.ndarray, link: int) -> np.ndarray:
341
+ """Compute link normal vector as a function of time.
342
+
343
+ Computes with the c++ backend.
344
+
345
+ Args:
346
+ t: Time array in seconds.
347
+ link: which link. Must be ``in self.LINKS``.
348
+
349
+ Returns:
350
+ Link normal vectors.
351
+
352
+ """
317
353
  return self.pycppdetector.get_normal_unit_vec(t, link)
318
354
 
319
355
  def get_pos(self, t: float | np.ndarray, sc: int) -> np.ndarray:
356
+ """Compute spacecraft position as a function of time.
357
+
358
+ Computes with the c++ backend.
359
+
360
+ Args:
361
+ t: Time array in seconds.
362
+ sc: which spacecraft. Must be ``in self.SC``.
363
+
364
+ Returns:
365
+ Spacecraft positions.
366
+
367
+ """
320
368
  return self.pycppdetector.get_pos(t, sc)
321
369
 
322
370
  @property
323
371
  def ptr(self) -> int:
324
- """pointer to c-class"""
372
+ """pointer to c++ class"""
325
373
  return self.pycppdetector.ptr
326
374
 
327
375
 
328
376
  class EqualArmlengthOrbits(Orbits):
329
- """Equal Armlength Orbits"""
377
+ """Equal Armlength Orbits
378
+
379
+ Orbit file: equalarmlength-orbits.h5
380
+
381
+ """
330
382
 
331
383
  def __init__(self):
332
- # TODO: fix this up
333
384
  super().__init__("equalarmlength-orbits.h5")
334
385
 
335
386
 
387
+ class ESAOrbits(Orbits):
388
+ """ESA Orbits
389
+
390
+ Orbit file: esa-trailing-orbits.h5
391
+
392
+ """
393
+
394
+ def __init__(self):
395
+ # TODO: fix this up
396
+ super().__init__("esa-trailing-orbits.h5")
397
+
398
+
336
399
  class DefaultOrbits(EqualArmlengthOrbits):
337
- """Set default orbit class to Equal Arm Length orbits for now."""
400
+ """Set default orbit class to Equal Armlength orbits for now."""
338
401
 
339
402
  pass
340
403
 
@@ -343,11 +406,10 @@ class DefaultOrbits(EqualArmlengthOrbits):
343
406
  class LISAModelSettings:
344
407
  """Required LISA model settings:
345
408
 
346
- TODO: rename these
347
-
348
409
  Args:
349
410
  Soms_d: OMS displacement noise.
350
411
  Sa_a: Acceleration noise.
412
+ orbits: Orbital information.
351
413
  name: Name of model.
352
414
 
353
415
  """
@@ -359,7 +421,16 @@ class LISAModelSettings:
359
421
 
360
422
 
361
423
  class LISAModel(LISAModelSettings, ABC):
362
- """Model for the LISA Constellation"""
424
+ """Model for the LISA Constellation
425
+
426
+ This includes sensitivity information computed in
427
+ :module:`lisatools.sensitivity` and orbital information
428
+ contained in an :class:`Orbits` class object.
429
+
430
+ This class is used to house high-level methods useful
431
+ to various needed computations.
432
+
433
+ """
363
434
 
364
435
  def __str__(self) -> str:
365
436
  out = "LISA Constellation Configurations Settings:\n"
@@ -368,6 +439,7 @@ class LISAModel(LISAModelSettings, ABC):
368
439
  return out
369
440
 
370
441
 
442
+ # defaults
371
443
  scirdv1 = LISAModel((15.0e-12) ** 2, (3.0e-15) ** 2, DefaultOrbits(), "scirdv1")
372
444
  proposal = LISAModel((10.0e-12) ** 2, (3.0e-15) ** 2, DefaultOrbits(), "proposal")
373
445
  mrdv1 = LISAModel((10.0e-12) ** 2, (2.4e-15) ** 2, DefaultOrbits(), "mrdv1")
lisatools/diagnostic.py CHANGED
@@ -2,7 +2,6 @@ import warnings
2
2
  from types import ModuleType, NoneType
3
3
  from typing import Optional, Any, Tuple, List
4
4
 
5
- from numpy.typing import ArrayLike
6
5
  import matplotlib.pyplot as plt
7
6
 
8
7
  from eryn.utils import TransformContainer
@@ -387,7 +386,7 @@ def snr(
387
386
  def h_var_p_eps(
388
387
  step: float,
389
388
  waveform_model: callable,
390
- params: ArrayLike,
389
+ params: np.ndarray | list,
391
390
  index: int,
392
391
  parameter_transforms: Optional[TransformContainer] = None,
393
392
  waveform_args: Optional[tuple] = (),
@@ -499,8 +498,8 @@ def dh_dlambda(
499
498
  def info_matrix(
500
499
  eps: float | np.ndarray,
501
500
  waveform_model: callable,
502
- params: ArrayLike,
503
- deriv_inds: Optional[ArrayLike] = None,
501
+ params: np.ndarray | list,
502
+ deriv_inds: Optional[np.ndarray | list] = None,
504
503
  inner_product_kwargs: Optional[dict] = {},
505
504
  return_derivs: Optional[bool] = False,
506
505
  **kwargs: dict,
@@ -793,7 +792,7 @@ def cutler_vallisneri_bias(
793
792
  eps: float | np.ndarray,
794
793
  input_diagnostics: Optional[dict] = None,
795
794
  info_mat: Optional[np.ndarray] = None,
796
- deriv_inds: Optional[ArrayLike] = None,
795
+ deriv_inds: Optional[np.ndarray | list] = None,
797
796
  return_derivs: Optional[bool] = False,
798
797
  return_cov: Optional[bool] = False,
799
798
  parameter_transforms: Optional[TransformContainer] = None,
@@ -8,7 +8,10 @@ from eryn.moves.multipletry import logsumexp
8
8
  from typing import Union, Optional, Tuple, List
9
9
 
10
10
  import sys
11
- sys.path.append("/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/gf_search/")
11
+
12
+ sys.path.append(
13
+ "/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/gf_search/"
14
+ )
12
15
  # from galaxy_ffdot import GalaxyFFdot
13
16
  # from galaxy import Galaxy
14
17
 
@@ -18,17 +21,20 @@ try:
18
21
  except (ModuleNotFoundError, ImportError) as e:
19
22
  pass
20
23
 
24
+
21
25
  class AmplitudeFrequencySNRPrior:
22
- def __init__(self, rho_star, frequency_prior, L, Tobs, use_cupy=False, **noise_kwargs):
26
+ def __init__(
27
+ self, rho_star, frequency_prior, L, Tobs, use_cupy=False, **noise_kwargs
28
+ ):
23
29
  self.rho_star = rho_star
24
30
  self.frequency_prior = frequency_prior
25
31
 
26
32
  self.transform = AmplitudeFromSNR(L, Tobs, use_cupy=use_cupy, **noise_kwargs)
27
33
  self.snr_prior = SNRPrior(rho_star, use_cupy=use_cupy)
28
-
34
+
29
35
  # must be after transform and snr_prior due to setter
30
36
  self.use_cupy = use_cupy
31
-
37
+
32
38
  @property
33
39
  def use_cupy(self):
34
40
  return self._use_cupy
@@ -70,7 +76,7 @@ class AmplitudeFrequencySNRPrior:
70
76
  else:
71
77
  f0_ms = f0_input
72
78
  assert f0_input.shape[:-1] == size
73
-
79
+
74
80
  f0 = f0_ms / 1e3
75
81
 
76
82
  rho = self.snr_prior.rvs(size=size)
@@ -80,10 +86,6 @@ class AmplitudeFrequencySNRPrior:
80
86
  return (amp, f0_ms)
81
87
 
82
88
 
83
-
84
-
85
-
86
-
87
89
  class SNRPrior:
88
90
  def __init__(self, rho_star, use_cupy=False):
89
91
  self.rho_star = rho_star
@@ -98,12 +100,16 @@ class SNRPrior:
98
100
  self._use_cupy = use_cupy
99
101
 
100
102
  def pdf(self, rho):
101
-
103
+
102
104
  xp = np if not self.use_cupy else cp
103
105
 
104
106
  p = xp.zeros_like(rho)
105
107
  good = rho > 0.0
106
- p[good] = 3 * rho[good] / (4 * self.rho_star ** 2 * (1 + rho[good] / (4 * self.rho_star)) ** 5)
108
+ p[good] = (
109
+ 3
110
+ * rho[good]
111
+ / (4 * self.rho_star**2 * (1 + rho[good] / (4 * self.rho_star)) ** 5)
112
+ )
107
113
  return p
108
114
 
109
115
  def logpdf(self, rho):
@@ -114,7 +120,15 @@ class SNRPrior:
114
120
  xp = np if not self.use_cupy else cp
115
121
  c = xp.zeros_like(rho)
116
122
  good = rho > 0.0
117
- c[good] = 768 * self.rho_star ** 3 * (1 / (768. * self.rho_star ** 3) - (rho[good] + self.rho_star)/(3. * (rho[good] + 4 * self.rho_star) ** 4))
123
+ c[good] = (
124
+ 768
125
+ * self.rho_star**3
126
+ * (
127
+ 1 / (768.0 * self.rho_star**3)
128
+ - (rho[good] + self.rho_star)
129
+ / (3.0 * (rho[good] + 4 * self.rho_star) ** 4)
130
+ )
131
+ )
118
132
  return c
119
133
 
120
134
  def rvs(self, size=1):
@@ -125,49 +139,125 @@ class SNRPrior:
125
139
 
126
140
  u = xp.random.rand(*size)
127
141
 
128
- rho = (-4*self.rho_star + xp.sqrt(-32*self.rho_star**2 - (32*(-self.rho_star**2 + u*self.rho_star**2))/(1 - u) +
129
- (3072*2**0.3333333333333333*xp.cbrt(-1 + 3*u - 3*u**2 + u**3)*
130
- (self.rho_star**4 - u*self.rho_star**4))/
131
- ((-1 + u)**2*xp.cbrt(-1769472*self.rho_star**6 + 1769472*u*self.rho_star**6 -
132
- xp.sqrt(3131031158784*u*self.rho_star**12 - 6262062317568*u**2*self.rho_star**12 +
133
- 3131031158784*u**3*self.rho_star**12))) +
134
- xp.cbrt(-1769472*self.rho_star**6 + 1769472*u*self.rho_star**6 -
135
- xp.sqrt(3131031158784*u*self.rho_star**12 - 6262062317568*u**2*self.rho_star**12 +
136
- 3131031158784*u**3*self.rho_star**12))/
137
- (3.*2**0.3333333333333333*xp.cbrt(-1 + 3*u - 3*u**2 + u**3)))/2.
138
- + xp.sqrt(32*self.rho_star**2 + (32*(-self.rho_star**2 + u*self.rho_star**2))/(1 - u) -
139
- (3072*2**0.3333333333333333*xp.cbrt(-1 + 3*u - 3*u**2 + u**3)*
140
- (self.rho_star**4 - u*self.rho_star**4))/
141
- ((-1 + u)**2*xp.cbrt(-1769472*self.rho_star**6 + 1769472*u*self.rho_star**6 -
142
- xp.sqrt(3131031158784*u*self.rho_star**12 - 6262062317568*u**2*self.rho_star**12 +
143
- 3131031158784*u**3*self.rho_star**12))) -
144
- xp.cbrt(-1769472*self.rho_star**6 + 1769472*u*self.rho_star**6 -
145
- xp.sqrt(3131031158784*u*self.rho_star**12 - 6262062317568*u**2*self.rho_star**12 +
146
- 3131031158784*u**3*self.rho_star**12))/
147
- (3.*2**0.3333333333333333*xp.cbrt(-1 + 3*u - 3*u**2 + u**3)) +
148
- (2048*self.rho_star**3 - (2048*u*self.rho_star**3)/(-1 + u))/
149
- (4.*xp.sqrt(-32*self.rho_star**2 - (32*(-self.rho_star**2 + u*self.rho_star**2))/(1 - u) +
150
- (3072*2**0.3333333333333333*
151
- xp.cbrt(-1 + 3*u - 3*u**2 + u**3)*(self.rho_star**4 - u*self.rho_star**4)
152
- )/
153
- ((-1 + u)**2*xp.cbrt(-1769472*self.rho_star**6 + 1769472*u*self.rho_star**6 -
154
- xp.sqrt(3131031158784*u*self.rho_star**12 - 6262062317568*u**2*self.rho_star**12 +
155
- 3131031158784*u**3*self.rho_star**12))) +
156
- xp.cbrt(-1769472*self.rho_star**6 + 1769472*u*self.rho_star**6 -
157
- xp.sqrt(3131031158784*u*self.rho_star**12 - 6262062317568*u**2*self.rho_star**12 +
158
- 3131031158784*u**3*self.rho_star**12))/
159
- (3.*2**0.3333333333333333*
160
- xp.cbrt(-1 + 3*u - 3*u**2 + u**3)))))/2.)
142
+ rho = (
143
+ -4 * self.rho_star
144
+ + xp.sqrt(
145
+ -32 * self.rho_star**2
146
+ - (32 * (-self.rho_star**2 + u * self.rho_star**2)) / (1 - u)
147
+ + (
148
+ 3072
149
+ * 2**0.3333333333333333
150
+ * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
151
+ * (self.rho_star**4 - u * self.rho_star**4)
152
+ )
153
+ / (
154
+ (-1 + u) ** 2
155
+ * xp.cbrt(
156
+ -1769472 * self.rho_star**6
157
+ + 1769472 * u * self.rho_star**6
158
+ - xp.sqrt(
159
+ 3131031158784 * u * self.rho_star**12
160
+ - 6262062317568 * u**2 * self.rho_star**12
161
+ + 3131031158784 * u**3 * self.rho_star**12
162
+ )
163
+ )
164
+ )
165
+ + xp.cbrt(
166
+ -1769472 * self.rho_star**6
167
+ + 1769472 * u * self.rho_star**6
168
+ - xp.sqrt(
169
+ 3131031158784 * u * self.rho_star**12
170
+ - 6262062317568 * u**2 * self.rho_star**12
171
+ + 3131031158784 * u**3 * self.rho_star**12
172
+ )
173
+ )
174
+ / (3.0 * 2**0.3333333333333333 * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3))
175
+ )
176
+ / 2.0
177
+ + xp.sqrt(
178
+ 32 * self.rho_star**2
179
+ + (32 * (-self.rho_star**2 + u * self.rho_star**2)) / (1 - u)
180
+ - (
181
+ 3072
182
+ * 2**0.3333333333333333
183
+ * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
184
+ * (self.rho_star**4 - u * self.rho_star**4)
185
+ )
186
+ / (
187
+ (-1 + u) ** 2
188
+ * xp.cbrt(
189
+ -1769472 * self.rho_star**6
190
+ + 1769472 * u * self.rho_star**6
191
+ - xp.sqrt(
192
+ 3131031158784 * u * self.rho_star**12
193
+ - 6262062317568 * u**2 * self.rho_star**12
194
+ + 3131031158784 * u**3 * self.rho_star**12
195
+ )
196
+ )
197
+ )
198
+ - xp.cbrt(
199
+ -1769472 * self.rho_star**6
200
+ + 1769472 * u * self.rho_star**6
201
+ - xp.sqrt(
202
+ 3131031158784 * u * self.rho_star**12
203
+ - 6262062317568 * u**2 * self.rho_star**12
204
+ + 3131031158784 * u**3 * self.rho_star**12
205
+ )
206
+ )
207
+ / (3.0 * 2**0.3333333333333333 * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3))
208
+ + (2048 * self.rho_star**3 - (2048 * u * self.rho_star**3) / (-1 + u))
209
+ / (
210
+ 4.0
211
+ * xp.sqrt(
212
+ -32 * self.rho_star**2
213
+ - (32 * (-self.rho_star**2 + u * self.rho_star**2)) / (1 - u)
214
+ + (
215
+ 3072
216
+ * 2**0.3333333333333333
217
+ * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
218
+ * (self.rho_star**4 - u * self.rho_star**4)
219
+ )
220
+ / (
221
+ (-1 + u) ** 2
222
+ * xp.cbrt(
223
+ -1769472 * self.rho_star**6
224
+ + 1769472 * u * self.rho_star**6
225
+ - xp.sqrt(
226
+ 3131031158784 * u * self.rho_star**12
227
+ - 6262062317568 * u**2 * self.rho_star**12
228
+ + 3131031158784 * u**3 * self.rho_star**12
229
+ )
230
+ )
231
+ )
232
+ + xp.cbrt(
233
+ -1769472 * self.rho_star**6
234
+ + 1769472 * u * self.rho_star**6
235
+ - xp.sqrt(
236
+ 3131031158784 * u * self.rho_star**12
237
+ - 6262062317568 * u**2 * self.rho_star**12
238
+ + 3131031158784 * u**3 * self.rho_star**12
239
+ )
240
+ )
241
+ / (
242
+ 3.0
243
+ * 2**0.3333333333333333
244
+ * xp.cbrt(-1 + 3 * u - 3 * u**2 + u**3)
245
+ )
246
+ )
247
+ )
248
+ )
249
+ / 2.0
250
+ )
161
251
 
162
252
  return rho
163
253
 
164
254
 
165
255
  class AmplitudeFromSNR:
166
256
  def __init__(self, L, Tobs, fd=None, use_cupy=False, **noise_kwargs):
167
- self.f_star = 1 / (2. * np.pi * L) * C_SI
257
+ self.f_star = 1 / (2.0 * np.pi * L) * C_SI
168
258
  self.Tobs = Tobs
169
259
  self.noise_kwargs = noise_kwargs
170
-
260
+
171
261
  xp = np if not use_cupy else cp
172
262
  if fd is not None:
173
263
  self.fd = xp.asarray(fd)
@@ -193,7 +283,7 @@ class AmplitudeFromSNR:
193
283
  assert self.fd is not None
194
284
  xp = np if not self.use_cupy else cp
195
285
  psds = xp.atleast_2d(psds)
196
-
286
+
197
287
  if xp == cp and not isinstance(self.fd, cp.ndarray):
198
288
  self.fd = xp.asarray(self.fd)
199
289
  try:
@@ -203,7 +293,9 @@ class AmplitudeFromSNR:
203
293
  if walker_inds is None:
204
294
  walker_inds = xp.zeros_like(f0, dtype=int)
205
295
 
206
- new_psds = (psds[(walker_inds, inds_fd + 1)] - psds[(walker_inds, inds_fd)]) / (self.fd[inds_fd + 1] - self.fd[inds_fd]) * (f0 - self.fd[inds_fd]) + psds[(walker_inds, inds_fd)]
296
+ new_psds = (psds[(walker_inds, inds_fd + 1)] - psds[(walker_inds, inds_fd)]) / (
297
+ self.fd[inds_fd + 1] - self.fd[inds_fd]
298
+ ) * (f0 - self.fd[inds_fd]) + psds[(walker_inds, inds_fd)]
207
299
  return new_psds
208
300
 
209
301
  def __call__(self, rho, f0, **noise_kwargs):
@@ -215,7 +307,7 @@ class AmplitudeFromSNR:
215
307
 
216
308
  Sn_f = self.get_Sn_f(f0, **noise_kwargs)
217
309
 
218
- factor = 1./2. * np.sqrt((self.Tobs * np.sin(f0 / self.f_star) ** 2) / Sn_f)
310
+ factor = 1.0 / 2.0 * np.sqrt((self.Tobs * np.sin(f0 / self.f_star) ** 2) / Sn_f)
219
311
  amp = rho / factor
220
312
  return (amp, f0)
221
313
 
@@ -238,7 +330,7 @@ class AmplitudeFromSNR:
238
330
 
239
331
  Sn_f = self.get_Sn_f(f0, **noise_kwargs)
240
332
 
241
- factor = 1./2. * np.sqrt((self.Tobs * np.sin(f0 / self.f_star) ** 2) / Sn_f)
333
+ factor = 1.0 / 2.0 * np.sqrt((self.Tobs * np.sin(f0 / self.f_star) ** 2) / Sn_f)
242
334
  rho = amp * factor
243
335
  return (rho, f0)
244
336
 
@@ -275,7 +367,7 @@ class GBPriorWrap:
275
367
  xp = np if not self.use_cupy else cp
276
368
  if isinstance(size, int):
277
369
  size = (size,)
278
-
370
+
279
371
  arr = xp.zeros(size + (self.ndim,)).reshape(-1, self.ndim)
280
372
 
281
373
  diff = self.ndim - len(self.keys_sep)
@@ -285,15 +377,35 @@ class GBPriorWrap:
285
377
 
286
378
  if not ignore_amp:
287
379
  f0_input = arr[:, 1] if self.gen_frequency_alone else None
288
- arr[:, :diff] = xp.asarray(self.base_prior.priors_in[(0, 1)].rvs(size, f0_input=f0_input, **kwargs)).reshape(diff, -1).T
380
+ arr[:, :diff] = (
381
+ xp.asarray(
382
+ self.base_prior.priors_in[(0, 1)].rvs(
383
+ size, f0_input=f0_input, **kwargs
384
+ )
385
+ )
386
+ .reshape(diff, -1)
387
+ .T
388
+ )
289
389
 
290
390
  arr = arr.reshape(size + (self.ndim,))
291
391
  return arr
292
392
 
293
393
 
294
394
  class FullGaussianMixtureModel:
295
- def __init__(self, gb, weights, means, covs, invcovs, dets, mins, maxs, limit=10.0, use_cupy=False):
296
-
395
+ def __init__(
396
+ self,
397
+ gb,
398
+ weights,
399
+ means,
400
+ covs,
401
+ invcovs,
402
+ dets,
403
+ mins,
404
+ maxs,
405
+ limit=10.0,
406
+ use_cupy=False,
407
+ ):
408
+
297
409
  self.use_cupy = use_cupy
298
410
  if use_cupy:
299
411
  xp = cp
@@ -306,7 +418,7 @@ class FullGaussianMixtureModel:
306
418
  for i, weight in enumerate(weights):
307
419
  index_base = np.full_like(weight, i, dtype=int)
308
420
  indexing.append(index_base)
309
-
421
+
310
422
  self.indexing = xp.asarray(np.concatenate(indexing))
311
423
  # invidivual weights / total number of components to uniformly choose from them
312
424
  self.weights = xp.asarray(np.concatenate(weights, axis=0) * 1 / len(weights))
@@ -326,19 +438,28 @@ class FullGaussianMixtureModel:
326
438
  self.means_in_pdf = self.means.T.flatten().copy()
327
439
  self.invcovs_in_pdf = self.invcovs.transpose(1, 2, 0).flatten().copy()
328
440
 
329
- self.cumulative_weights = xp.concatenate([xp.array([0.0]), xp.cumsum(self.weights)])
441
+ self.cumulative_weights = xp.concatenate(
442
+ [xp.array([0.0]), xp.cumsum(self.weights)]
443
+ )
330
444
 
331
- self.min_limit_f = self.map_back_frequency(-1. * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1])
332
- self.max_limit_f = self.map_back_frequency(+1. * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1])
445
+ self.min_limit_f = self.map_back_frequency(
446
+ -1.0 * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1]
447
+ )
448
+ self.max_limit_f = self.map_back_frequency(
449
+ +1.0 * limit, self.mins[self.indexing, 1], self.maxs[self.indexing, 1]
450
+ )
333
451
 
334
452
  # compute the jacobian
335
- self.log_det_J = (self.ndim * np.log(2) - xp.sum(xp.log(self.maxs - self.mins), axis=-1))[self.indexing].copy()
453
+ self.log_det_J = (
454
+ self.ndim * np.log(2) - xp.sum(xp.log(self.maxs - self.mins), axis=-1)
455
+ )[self.indexing].copy()
336
456
 
337
457
  """self.inds_sort_min_limit_f = xp.argsort(self.min_limit_f)
338
458
  self.inds_sort_max_limit_f = xp.argsort(self.max_limit_f)
339
459
  self.sorted_min_limit_f = self.min_limit_f[self.inds_sort_min_limit_f]
340
460
  self.sorted_max_limit_f = self.max_limit_f[self.inds_sort_max_limit_f]
341
461
  """
462
+
342
463
  def logpdf(self, x):
343
464
 
344
465
  if self.use_cupy:
@@ -358,13 +479,15 @@ class FullGaussianMixtureModel:
358
479
  ind_min_limit = xp.searchsorted(f_sort, self.min_limit_f, side="left")
359
480
  ind_max_limit = xp.searchsorted(f_sort, self.max_limit_f, side="right")
360
481
 
361
- diff = (ind_max_limit - ind_min_limit)
482
+ diff = ind_max_limit - ind_min_limit
362
483
  cs = xp.concatenate([xp.array([0]), xp.cumsum(diff)])
363
484
  tmp = xp.arange(cs[-1])
364
485
  keep_component_map = xp.searchsorted(cs, tmp, side="right") - 1
365
- keep_point_map = tmp - cs[keep_component_map] + ind_min_limit[keep_component_map]
486
+ keep_point_map = (
487
+ tmp - cs[keep_component_map] + ind_min_limit[keep_component_map]
488
+ )
366
489
  max_components = diff.max().item()
367
-
490
+
368
491
  int_check = int(1e6)
369
492
  assert int_check > self.min_limit_f.shape[0]
370
493
  special_point_component_map = int_check * keep_point_map + keep_component_map
@@ -375,20 +498,37 @@ class FullGaussianMixtureModel:
375
498
  components_keep_in = sorted_special - points_keep_in * int_check
376
499
 
377
500
  unique_points, unique_starts = xp.unique(points_keep_in, return_index=True)
378
- start_index_in_pdf = xp.concatenate([unique_starts, xp.array([len(points_keep_in)])]).astype(xp.int32)
501
+ start_index_in_pdf = xp.concatenate(
502
+ [unique_starts, xp.array([len(points_keep_in)])]
503
+ ).astype(xp.int32)
379
504
  assert xp.all(xp.diff(unique_starts) > 0)
380
-
505
+
381
506
  points_sorted_in = points_sorted[unique_points]
382
507
 
383
508
  logpdf_out_tmp = xp.zeros(points_sorted_in.shape[0])
384
509
 
385
- self.gb.compute_logpdf(logpdf_out_tmp, components_keep_in.astype(xp.int32), points_sorted_in,
386
- self.weights, self.mins_in_pdf, self.maxs_in_pdf, self.means_in_pdf, self.invcovs_in_pdf, self.dets, self.log_det_J,
387
- points_sorted_in.shape[0], start_index_in_pdf, self.weights.shape[0], x.shape[1])
510
+ self.gb.compute_logpdf(
511
+ logpdf_out_tmp,
512
+ components_keep_in.astype(xp.int32),
513
+ points_sorted_in,
514
+ self.weights,
515
+ self.mins_in_pdf,
516
+ self.maxs_in_pdf,
517
+ self.means_in_pdf,
518
+ self.invcovs_in_pdf,
519
+ self.dets,
520
+ self.log_det_J,
521
+ points_sorted_in.shape[0],
522
+ start_index_in_pdf,
523
+ self.weights.shape[0],
524
+ x.shape[1],
525
+ )
388
526
 
389
527
  # need to reverse the sort
390
528
  logpdf_out = xp.full(x.shape[0], -xp.inf)
391
- logpdf_out[xp.sort(inds_sort[unique_points])] = logpdf_out_tmp[xp.argsort(inds_sort[unique_points])]
529
+ logpdf_out[xp.sort(inds_sort[unique_points])] = logpdf_out_tmp[
530
+ xp.argsort(inds_sort[unique_points])
531
+ ]
392
532
  return logpdf_out
393
533
  """# breakpoint()
394
534
 
@@ -409,11 +549,11 @@ class FullGaussianMixtureModel:
409
549
  return logpdf_full_dist"""
410
550
 
411
551
  def map_input(self, x, mins, maxs):
412
- return ((x - mins) / (maxs - mins)) * 2. - 1.
552
+ return ((x - mins) / (maxs - mins)) * 2.0 - 1.0
413
553
 
414
554
  def map_back_frequency(self, x, mins, maxs):
415
- return (x + 1.) * 1. / 2. * (maxs - mins) + mins
416
-
555
+ return (x + 1.0) * 1.0 / 2.0 * (maxs - mins) + mins
556
+
417
557
  def rvs(self, size=(1,)):
418
558
 
419
559
  if isinstance(size, int):
@@ -426,20 +566,26 @@ class FullGaussianMixtureModel:
426
566
 
427
567
  # choose which component
428
568
  draw = xp.random.rand(*size)
429
- component = (xp.searchsorted(self.cumulative_weights, draw.flatten(), side="right") - 1).reshape(draw.shape)
569
+ component = (
570
+ xp.searchsorted(self.cumulative_weights, draw.flatten(), side="right") - 1
571
+ ).reshape(draw.shape)
430
572
 
431
573
  mean_here = self.means[component]
432
574
  cov_here = self.covs[component]
433
575
 
434
- new_points = mean_here + xp.einsum("...kj,...j->...k", cov_here, np.random.randn(*(component.shape + (self.ndim,))))
576
+ new_points = mean_here + xp.einsum(
577
+ "...kj,...j->...k",
578
+ cov_here,
579
+ np.random.randn(*(component.shape + (self.ndim,))),
580
+ )
435
581
 
436
582
  index_here = self.indexing[component]
437
583
  mins_here = self.mins[index_here]
438
584
  maxs_here = self.maxs[index_here]
439
585
  new_points_mapped = self.map_back_frequency(new_points, mins_here, maxs_here)
440
-
586
+
441
587
  return new_points_mapped
442
-
588
+
443
589
 
444
590
  # class FlowDist:
445
591
  # def __init__(self, config: dict, model: Union[Galaxy, GalaxyFFdot], fit: str, ndim: int):
@@ -450,7 +596,7 @@ class FullGaussianMixtureModel:
450
596
  # param_min, param_max = np.loadtxt(fit)
451
597
  # self.dist.set_min(param_min)
452
598
  # self.dist.set_max(param_max)
453
-
599
+
454
600
  # self.config = config
455
601
  # self.fit = fit
456
602
  # self.ndim = ndim
@@ -461,7 +607,7 @@ class FullGaussianMixtureModel:
461
607
 
462
608
  # total_samp = int(np.prod(size))
463
609
  # samples = self.dist.sample(total_samp).reshape(size + (self.ndim,))
464
- # return samples
610
+ # return samples
465
611
 
466
612
  # def logpdf(self, x: cp.ndarray) -> cp.ndarray:
467
613
  # assert x.shape[-1] == self.ndim
@@ -498,11 +644,3 @@ class FullGaussianMixtureModel:
498
644
  # fit = '/data/mkatz/LISAanalysistools/lisaflow/flow/experiments/rvs/minmax_ffdot_sangria.txt'
499
645
  # ndim = 2
500
646
  # super().__init__(config, model, fit, ndim)
501
-
502
-
503
-
504
-
505
-
506
-
507
-
508
-
lisatools/sensitivity.py CHANGED
@@ -5,7 +5,6 @@ from copy import deepcopy
5
5
 
6
6
  import math
7
7
  import numpy as np
8
- from numpy.typing import ArrayLike
9
8
  from scipy import interpolate
10
9
  import matplotlib.pyplot as plt
11
10
 
@@ -527,17 +526,19 @@ class SensitivityMatrix:
527
526
  sens_mat: Input sensitivity list. The shape of the nested lists should represent the shape of the
528
527
  desired matrix. Each entry in the list must be an array, :class:`Sensitivity`-derived object,
529
528
  or a string corresponding to the :class:`Sensitivity` object.
530
- **sens_kwargs: Keyword arguments to pass to :method:`Sensitivity.get_Sn`.
529
+ **sens_kwargs: Keyword arguments to pass to :func:`Sensitivity.get_Sn`.
531
530
 
532
531
  """
533
532
 
534
533
  def __init__(
535
534
  self,
536
535
  f: np.ndarray,
537
- sens_mat: List[List[np.ndarray | Sensitivity]]
538
- | List[np.ndarray | Sensitivity]
539
- | np.ndarray
540
- | Sensitivity,
536
+ sens_mat: (
537
+ List[List[np.ndarray | Sensitivity]]
538
+ | List[np.ndarray | Sensitivity]
539
+ | np.ndarray
540
+ | Sensitivity
541
+ ),
541
542
  *sens_args: tuple,
542
543
  **sens_kwargs: dict,
543
544
  ) -> None:
@@ -598,10 +599,12 @@ class SensitivityMatrix:
598
599
  @sens_mat.setter
599
600
  def sens_mat(
600
601
  self,
601
- sens_mat: List[List[np.ndarray | Sensitivity]]
602
- | List[np.ndarray | Sensitivity]
603
- | np.ndarray
604
- | Sensitivity,
602
+ sens_mat: (
603
+ List[List[np.ndarray | Sensitivity]]
604
+ | List[np.ndarray | Sensitivity]
605
+ | np.ndarray
606
+ | Sensitivity
607
+ ),
605
608
  ) -> None:
606
609
  """Set sensitivity matrix."""
607
610
  self.sens_mat_input = deepcopy(sens_mat)
@@ -729,7 +732,7 @@ class XYZ1SensitivityMatrix(SensitivityMatrix):
729
732
 
730
733
  Args:
731
734
  f: Frequency array.
732
- **sens_kwargs: Keyword arguments to pass to :method:`Sensitivity.get_Sn`.
735
+ **sens_kwargs: Keyword arguments to pass to :func:`Sensitivity.get_Sn`.
733
736
 
734
737
  """
735
738
 
@@ -749,7 +752,7 @@ class AET1SensitivityMatrix(SensitivityMatrix):
749
752
 
750
753
  Args:
751
754
  f: Frequency array.
752
- **sens_kwargs: Keyword arguments to pass to :method:`Sensitivity.get_Sn`.
755
+ **sens_kwargs: Keyword arguments to pass to :func:`Sensitivity.get_Sn`.
753
756
 
754
757
  """
755
758
 
@@ -763,7 +766,7 @@ class AE1SensitivityMatrix(SensitivityMatrix):
763
766
 
764
767
  Args:
765
768
  f: Frequency array.
766
- **sens_kwargs: Keyword arguments to pass to :method:`Sensitivity.get_Sn`.
769
+ **sens_kwargs: Keyword arguments to pass to :func:`Sensitivity.get_Sn`.
767
770
 
768
771
  """
769
772
 
@@ -778,7 +781,7 @@ class LISASensSensitivityMatrix(SensitivityMatrix):
778
781
  Args:
779
782
  f: Frequency array.
780
783
  nchannels: Number of channels.
781
- **sens_kwargs: Keyword arguments to pass to :method:`Sensitivity.get_Sn`.
784
+ **sens_kwargs: Keyword arguments to pass to :func:`Sensitivity.get_Sn`.
782
785
 
783
786
  """
784
787
 
lisatools/stochastic.py CHANGED
@@ -4,7 +4,6 @@ from typing import Any, Tuple, Optional, List, Dict
4
4
 
5
5
  import math
6
6
  import numpy as np
7
- from numpy.typing import ArrayLike
8
7
  from scipy import interpolate
9
8
 
10
9
  try:
@@ -25,7 +24,7 @@ class StochasticContribution(ABC):
25
24
  added_stochastic_list = []
26
25
 
27
26
  @classmethod
28
- def _check_ndim(cls, params: ArrayLike) -> None:
27
+ def _check_ndim(cls, params: np.ndarray | list) -> None:
29
28
  """Check the dimensionality of the parameters matches the model.
30
29
 
31
30
  Args:
@@ -42,7 +41,7 @@ class StochasticContribution(ABC):
42
41
 
43
42
  @classmethod
44
43
  def get_Sh(
45
- cls, f: float | np.ndarray, *params: ArrayLike, **kwargs: Any
44
+ cls, f: float | np.ndarray, *params: np.ndarray | list, **kwargs: Any
46
45
  ) -> float | np.ndarray:
47
46
  """Calculate the power spectral density of the stochastic contribution.
48
47
 
@@ -226,9 +226,9 @@ def get_groups_from_band_structure(
226
226
  groups_even_odd = xp.sum(groups_even_odd_tmp, axis=0)
227
227
 
228
228
  groups_out = -2 * xp.ones_like(f0, dtype=int)
229
- groups_out[
230
- (temp_inds, walker_inds, inds_band_indices.flatten()[keep])
231
- ] = groups_even_odd
229
+ groups_out[(temp_inds, walker_inds, inds_band_indices.flatten()[keep])] = (
230
+ groups_even_odd
231
+ )
232
232
 
233
233
  groups_out[bad] = -1
234
234
 
@@ -238,3 +238,9 @@ def get_groups_from_band_structure(
238
238
  fix_1 = band_indices[fix]"""
239
239
 
240
240
  return groups_out
241
+
242
+
243
+ autodoc_type_aliases = {
244
+ "Iterable": "Iterable",
245
+ "ArrayLike": "ArrayLike",
246
+ }