ml4gw-0.5.0-py3-none-any.whl → ml4gw-0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of ml4gw might be problematic; see the registry's advisory for this release for more details.

Files changed (44)
  1. ml4gw/augmentations.py +8 -2
  2. ml4gw/constants.py +10 -19
  3. ml4gw/dataloading/chunked_dataset.py +4 -2
  4. ml4gw/dataloading/hdf5_dataset.py +1 -1
  5. ml4gw/dataloading/in_memory_dataset.py +8 -4
  6. ml4gw/distributions.py +5 -3
  7. ml4gw/gw.py +21 -27
  8. ml4gw/nn/autoencoder/base.py +11 -6
  9. ml4gw/nn/autoencoder/convolutional.py +7 -4
  10. ml4gw/nn/autoencoder/skip_connection.py +7 -6
  11. ml4gw/nn/autoencoder/utils.py +2 -1
  12. ml4gw/nn/norm.py +5 -1
  13. ml4gw/nn/streaming/online_average.py +7 -5
  14. ml4gw/nn/streaming/snapshotter.py +7 -5
  15. ml4gw/spectral.py +41 -37
  16. ml4gw/transforms/__init__.py +1 -0
  17. ml4gw/transforms/pearson.py +7 -3
  18. ml4gw/transforms/qtransform.py +151 -53
  19. ml4gw/transforms/scaler.py +9 -3
  20. ml4gw/transforms/snr_rescaler.py +6 -5
  21. ml4gw/transforms/spectral.py +9 -2
  22. ml4gw/transforms/spectrogram.py +7 -1
  23. ml4gw/transforms/spline_interpolation.py +370 -0
  24. ml4gw/transforms/transform.py +4 -3
  25. ml4gw/transforms/waveforms.py +10 -7
  26. ml4gw/transforms/whitening.py +12 -4
  27. ml4gw/types.py +25 -10
  28. ml4gw/utils/interferometer.py +1 -1
  29. ml4gw/utils/slicing.py +24 -16
  30. ml4gw/waveforms/__init__.py +2 -5
  31. ml4gw/waveforms/adhoc/__init__.py +2 -0
  32. ml4gw/waveforms/{ringdown.py → adhoc/ringdown.py} +8 -9
  33. ml4gw/waveforms/{sine_gaussian.py → adhoc/sine_gaussian.py} +6 -6
  34. ml4gw/waveforms/cbc/__init__.py +3 -0
  35. ml4gw/waveforms/{phenom_d.py → cbc/phenom_d.py} +20 -18
  36. ml4gw/waveforms/{phenom_p.py → cbc/phenom_p.py} +106 -95
  37. ml4gw/waveforms/{taylorf2.py → cbc/taylorf2.py} +33 -27
  38. ml4gw/waveforms/conversion.py +187 -0
  39. ml4gw/waveforms/generator.py +9 -5
  40. {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/METADATA +4 -3
  41. ml4gw-0.6.0.dist-info/RECORD +51 -0
  42. {ml4gw-0.5.0.dist-info → ml4gw-0.6.0.dist-info}/WHEEL +1 -1
  43. ml4gw-0.5.0.dist-info/RECORD +0 -47
  44. ml4gw/waveforms/{phenom_d_data.py → cbc/phenom_d_data.py} +0 -0
@@ -1,5 +1,2 @@
1
- from .phenom_d import IMRPhenomD
2
- from .phenom_p import IMRPhenomPv2
3
- from .ringdown import Ringdown
4
- from .sine_gaussian import SineGaussian
5
- from .taylorf2 import TaylorF2
1
+ from .adhoc import *
2
+ from .cbc import *
@@ -0,0 +1,2 @@
1
+ from .ringdown import Ringdown
2
+ from .sine_gaussian import SineGaussian
@@ -1,9 +1,8 @@
1
1
  import numpy as np
2
2
  import torch
3
3
 
4
- from ml4gw.types import ScalarTensor
5
-
6
- from ..constants import PI, C, G, m_per_Mpc
4
+ from ml4gw.constants import PI, C, G, m_per_Mpc
5
+ from ml4gw.types import BatchTensor
7
6
 
8
7
 
9
8
  class Ringdown(torch.nn.Module):
@@ -27,12 +26,12 @@ class Ringdown(torch.nn.Module):
27
26
 
28
27
  def forward(
29
28
  self,
30
- frequency: ScalarTensor,
31
- quality: ScalarTensor,
32
- epsilon: ScalarTensor,
33
- phase: ScalarTensor,
34
- inclination: ScalarTensor,
35
- distance: ScalarTensor,
29
+ frequency: BatchTensor,
30
+ quality: BatchTensor,
31
+ epsilon: BatchTensor,
32
+ phase: BatchTensor,
33
+ inclination: BatchTensor,
34
+ distance: BatchTensor,
36
35
  ):
37
36
  """
38
37
  Generate ringdown waveform based on the damped sinusoid equation.
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  from torch import Tensor
3
3
 
4
- from ml4gw.types import ScalarTensor
4
+ from ml4gw.types import BatchTensor
5
5
 
6
6
 
7
7
  def semi_major_minor_from_e(e: Tensor):
@@ -32,11 +32,11 @@ class SineGaussian(torch.nn.Module):
32
32
 
33
33
  def forward(
34
34
  self,
35
- quality: ScalarTensor,
36
- frequency: ScalarTensor,
37
- hrss: ScalarTensor,
38
- phase: ScalarTensor,
39
- eccentricity: ScalarTensor,
35
+ quality: BatchTensor,
36
+ frequency: BatchTensor,
37
+ hrss: BatchTensor,
38
+ phase: BatchTensor,
39
+ eccentricity: BatchTensor,
40
40
  ):
41
41
  """
42
42
  Generate lalinference implementation of a sine-Gaussian waveform.
@@ -0,0 +1,3 @@
1
+ from .phenom_d import IMRPhenomD
2
+ from .phenom_p import IMRPhenomPv2
3
+ from .taylorf2 import TaylorF2
@@ -1,7 +1,9 @@
1
1
  import torch
2
- from torchtyping import TensorType
2
+ from jaxtyping import Float
3
+
4
+ from ml4gw.constants import MTSUN_SI, PI
5
+ from ml4gw.types import BatchTensor, FrequencySeries1d
3
6
 
4
- from ..constants import MTSUN_SI, PI
5
7
  from .phenom_d_data import QNMData_a, QNMData_fdamp, QNMData_fring
6
8
  from .taylorf2 import TaylorF2
7
9
 
@@ -15,14 +17,14 @@ class IMRPhenomD(TaylorF2):
15
17
 
16
18
  def forward(
17
19
  self,
18
- f: TensorType,
19
- chirp_mass: TensorType,
20
- mass_ratio: TensorType,
21
- chi1: TensorType,
22
- chi2: TensorType,
23
- distance: TensorType,
24
- phic: TensorType,
25
- inclination: TensorType,
20
+ f: FrequencySeries1d,
21
+ chirp_mass: BatchTensor,
22
+ mass_ratio: BatchTensor,
23
+ chi1: BatchTensor,
24
+ chi2: BatchTensor,
25
+ distance: BatchTensor,
26
+ phic: BatchTensor,
27
+ inclination: BatchTensor,
26
28
  f_ref: float,
27
29
  ):
28
30
  """
@@ -76,15 +78,15 @@ class IMRPhenomD(TaylorF2):
76
78
 
77
79
  def phenom_d_htilde(
78
80
  self,
79
- f: TensorType,
80
- chirp_mass: TensorType,
81
- mass_ratio: TensorType,
82
- chi1: TensorType,
83
- chi2: TensorType,
84
- distance: TensorType,
85
- phic: TensorType,
81
+ f: FrequencySeries1d,
82
+ chirp_mass: BatchTensor,
83
+ mass_ratio: BatchTensor,
84
+ chi1: BatchTensor,
85
+ chi2: BatchTensor,
86
+ distance: BatchTensor,
87
+ phic: BatchTensor,
86
88
  f_ref: float,
87
- ):
89
+ ) -> Float[FrequencySeries1d, " batch"]:
88
90
  total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio**0.6
89
91
  mass_1 = total_mass / (1 + mass_ratio)
90
92
  mass_2 = mass_1 * mass_ratio
@@ -1,9 +1,13 @@
1
- from typing import Dict, Tuple
1
+ from typing import Dict, Optional, Tuple
2
2
 
3
3
  import torch
4
- from torchtyping import TensorType
4
+ from jaxtyping import Float
5
+ from torch import Tensor
6
+
7
+ from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
8
+ from ml4gw.types import BatchTensor, FrequencySeries1d
9
+ from ml4gw.waveforms.conversion import rotate_y, rotate_z
5
10
 
6
- from ..constants import MPC_SEC, MTSUN_SI, PI
7
11
  from .phenom_d import IMRPhenomD
8
12
 
9
13
 
@@ -13,20 +17,20 @@ class IMRPhenomPv2(IMRPhenomD):
13
17
 
14
18
  def forward(
15
19
  self,
16
- fs: TensorType,
17
- chirp_mass: TensorType,
18
- mass_ratio: TensorType,
19
- s1x: TensorType,
20
- s1y: TensorType,
21
- s1z: TensorType,
22
- s2x: TensorType,
23
- s2y: TensorType,
24
- s2z: TensorType,
25
- dist_mpc: TensorType,
26
- tc: TensorType,
27
- phiRef: TensorType,
28
- incl: TensorType,
20
+ fs: FrequencySeries1d,
21
+ chirp_mass: BatchTensor,
22
+ mass_ratio: BatchTensor,
23
+ s1x: BatchTensor,
24
+ s1y: BatchTensor,
25
+ s1z: BatchTensor,
26
+ s2x: BatchTensor,
27
+ s2y: BatchTensor,
28
+ s2z: BatchTensor,
29
+ distance: BatchTensor,
30
+ phic: BatchTensor,
31
+ inclination: BatchTensor,
29
32
  f_ref: float,
33
+ tc: Optional[BatchTensor] = None,
30
34
  ):
31
35
  """
32
36
  IMRPhenomPv2 waveform
@@ -50,13 +54,13 @@ class IMRPhenomPv2(IMRPhenomD):
50
54
  Spin component y of the second BH.
51
55
  s2z :
52
56
  Spin component z of the second BH.
53
- dist_mpc :
57
+ distance :
54
58
  Luminosity distance in Mpc.
55
59
  tc :
56
60
  Coalescence time.
57
- phiRef :
61
+ phic :
58
62
  Reference phase.
59
- incl :
63
+ inclination :
60
64
  Inclination angle.
61
65
  f_ref :
62
66
  Reference frequency in Hz.
@@ -68,6 +72,9 @@ class IMRPhenomPv2(IMRPhenomD):
68
72
  Note: m1 must be larger than m2.
69
73
  """
70
74
 
75
+ if tc is None:
76
+ tc = torch.zeros_like(chirp_mass)
77
+
71
78
  m2 = chirp_mass * (1.0 + mass_ratio) ** 0.2 / mass_ratio**0.6
72
79
  m1 = m2 * mass_ratio
73
80
 
@@ -86,7 +93,7 @@ class IMRPhenomPv2(IMRPhenomD):
86
93
  phi_aligned,
87
94
  zeta_polariz,
88
95
  ) = self.convert_spins(
89
- m1, m2, f_ref, phiRef, incl, s1x, s1y, s1z, s2x, s2y, s2z
96
+ m1, m2, f_ref, phic, inclination, s1x, s1y, s1z, s2x, s2y, s2z
90
97
  )
91
98
 
92
99
  phic = 2 * phi_aligned
@@ -149,7 +156,7 @@ class IMRPhenomPv2(IMRPhenomD):
149
156
  phic,
150
157
  M,
151
158
  xi,
152
- dist_mpc,
159
+ distance,
153
160
  )
154
161
 
155
162
  hp, hc = self.PhenomPCoreTwistUp(
@@ -184,18 +191,18 @@ class IMRPhenomPv2(IMRPhenomD):
184
191
 
185
192
  def PhenomPCoreTwistUp(
186
193
  self,
187
- fHz: TensorType,
188
- hPhenom: TensorType,
189
- eta: TensorType,
190
- chi1_l: TensorType,
191
- chi2_l: TensorType,
192
- chip: TensorType,
193
- M: TensorType,
194
- angcoeffs: Dict[str, TensorType],
195
- Y2m: TensorType,
196
- alphaoffset: TensorType,
197
- epsilonoffset: TensorType,
198
- ) -> Tuple[TensorType, TensorType]:
194
+ fHz: FrequencySeries1d,
195
+ hPhenom: BatchTensor,
196
+ eta: BatchTensor,
197
+ chi1_l: BatchTensor,
198
+ chi2_l: BatchTensor,
199
+ chip: BatchTensor,
200
+ M: BatchTensor,
201
+ angcoeffs: Dict[str, BatchTensor],
202
+ Y2m: BatchTensor,
203
+ alphaoffset: BatchTensor,
204
+ epsilonoffset: BatchTensor,
205
+ ) -> Tuple[BatchTensor, BatchTensor]:
199
206
  assert angcoeffs is not None
200
207
  assert Y2m is not None
201
208
  f = fHz * MTSUN_SI * M.unsqueeze(1) # Frequency in geometric units
@@ -306,7 +313,7 @@ class IMRPhenomPv2(IMRPhenomD):
306
313
  phic,
307
314
  M,
308
315
  xi,
309
- dist_mpc,
316
+ distance,
310
317
  ):
311
318
  """
312
319
  m1, m2: in solar masses
@@ -321,10 +328,10 @@ class IMRPhenomPv2(IMRPhenomD):
321
328
  phase, _ = self.phenom_d_phase(Mf, m1, m2, eta, eta2, chi1, chi2, xi)
322
329
  phase = (phase.mT - (phic + PI / 4.0)).mT
323
330
  Amp = self.phenom_d_amp(
324
- Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, dist_mpc
331
+ Mf, m1, m2, eta, eta2, Seta, chi1, chi2, chi12, chi22, xi, distance
325
332
  )[0]
326
333
  Amp0 = self.get_Amp0(Mf, eta)
327
- dist_s = dist_mpc * MPC_SEC
334
+ dist_s = distance * MPC_SEC
328
335
  Amp = ((Amp0 * Amp).mT * (M_s**2.0) / dist_s).mT
329
336
  # phase -= 2. * phic; # line 1316 ???
330
337
  hPhenom = Amp * (torch.exp(-1j * phase))
@@ -354,8 +361,11 @@ class IMRPhenomPv2(IMRPhenomD):
354
361
  # Utility functions
355
362
 
356
363
  def interpolate(
357
- self, x: TensorType, xp: TensorType, fp: TensorType
358
- ) -> TensorType:
364
+ self,
365
+ x: Float[Tensor, " new_series"],
366
+ xp: Float[Tensor, " series"],
367
+ fp: Float[Tensor, " series"],
368
+ ) -> Float[Tensor, " new_series"]:
359
369
  """One-dimensional linear interpolation for monotonically
360
370
  increasing sample points.
361
371
 
@@ -385,17 +395,11 @@ class IMRPhenomPv2(IMRPhenomD):
385
395
 
386
396
  return interpolated.reshape(original_shape)
387
397
 
388
- def ROTATEZ(self, angle: TensorType, x, y, z):
389
- tmp_x = x * torch.cos(angle) - y * torch.sin(angle)
390
- tmp_y = x * torch.sin(angle) + y * torch.cos(angle)
391
- return tmp_x, tmp_y, z
392
-
393
- def ROTATEY(self, angle, x, y, z):
394
- tmp_x = x * torch.cos(angle) + z * torch.sin(angle)
395
- tmp_z = -x * torch.sin(angle) + z * torch.cos(angle)
396
- return tmp_x, y, tmp_z
397
-
398
- def L2PNR(self, v: TensorType, eta: TensorType) -> TensorType:
398
+ def L2PNR(
399
+ self,
400
+ v: BatchTensor,
401
+ eta: BatchTensor,
402
+ ) -> BatchTensor:
399
403
  eta2 = eta**2
400
404
  x = v**2
401
405
  x2 = x**2
@@ -412,25 +416,25 @@ class IMRPhenomPv2(IMRPhenomD):
412
416
 
413
417
  def convert_spins(
414
418
  self,
415
- m1: TensorType,
416
- m2: TensorType,
419
+ m1: BatchTensor,
420
+ m2: BatchTensor,
417
421
  f_ref: float,
418
- phiRef: TensorType,
419
- incl: TensorType,
420
- s1x: TensorType,
421
- s1y: TensorType,
422
- s1z: TensorType,
423
- s2x: TensorType,
424
- s2y: TensorType,
425
- s2z: TensorType,
422
+ phic: BatchTensor,
423
+ inclination: BatchTensor,
424
+ s1x: BatchTensor,
425
+ s1y: BatchTensor,
426
+ s1z: BatchTensor,
427
+ s2x: BatchTensor,
428
+ s2y: BatchTensor,
429
+ s2z: BatchTensor,
426
430
  ) -> Tuple[
427
- TensorType,
428
- TensorType,
429
- TensorType,
430
- TensorType,
431
- TensorType,
432
- TensorType,
433
- TensorType,
431
+ BatchTensor,
432
+ BatchTensor,
433
+ BatchTensor,
434
+ BatchTensor,
435
+ BatchTensor,
436
+ BatchTensor,
437
+ BatchTensor,
434
438
  ]:
435
439
  M = m1 + m2
436
440
  m1_2 = m1 * m1
@@ -476,32 +480,32 @@ class IMRPhenomPv2(IMRPhenomD):
476
480
  # First we determine kappa
477
481
  # in the source frame, the components of N are given in
478
482
  # Eq (35c) of T1500606-v6
479
- Nx_sf = torch.sin(incl) * torch.cos(PI / 2.0 - phiRef)
480
- Ny_sf = torch.sin(incl) * torch.sin(PI / 2.0 - phiRef)
481
- Nz_sf = torch.cos(incl)
483
+ Nx_sf = torch.sin(inclination) * torch.cos(PI / 2.0 - phic)
484
+ Ny_sf = torch.sin(inclination) * torch.sin(PI / 2.0 - phic)
485
+ Nz_sf = torch.cos(inclination)
482
486
 
483
487
  tmp_x = Nx_sf
484
488
  tmp_y = Ny_sf
485
489
  tmp_z = Nz_sf
486
490
 
487
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
488
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
491
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
492
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
489
493
 
490
494
  kappa = -torch.arctan2(tmp_y, tmp_x)
491
495
 
492
496
  # Then we determine alpha0, by rotating LN
493
497
  tmp_x, tmp_y, tmp_z = 0, 0, 1
494
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
495
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
496
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
498
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
499
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
500
+ tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z)
497
501
 
498
502
  alpha0 = torch.arctan2(tmp_y, tmp_x)
499
503
 
500
504
  # Finally we determine thetaJ, by rotating N
501
505
  tmp_x, tmp_y, tmp_z = Nx_sf, Ny_sf, Nz_sf
502
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
503
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
504
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
506
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
507
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
508
+ tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z)
505
509
  Nx_Jf, Nz_Jf = tmp_x, tmp_z
506
510
  thetaJN = torch.arccos(Nz_Jf)
507
511
 
@@ -518,13 +522,13 @@ class IMRPhenomPv2(IMRPhenomD):
518
522
  # Both triads differ from each other by a rotation around N by an angle
519
523
  # \zeta and we need to rotate the polarizations accordingly by 2\zeta
520
524
 
521
- Xx_sf = -torch.cos(incl) * torch.sin(phiRef)
522
- Xy_sf = -torch.cos(incl) * torch.cos(phiRef)
523
- Xz_sf = torch.sin(incl)
525
+ Xx_sf = -torch.cos(inclination) * torch.sin(phic)
526
+ Xy_sf = -torch.cos(inclination) * torch.cos(phic)
527
+ Xz_sf = torch.sin(inclination)
524
528
  tmp_x, tmp_y, tmp_z = Xx_sf, Xy_sf, Xz_sf
525
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(-phiJ_sf, tmp_x, tmp_y, tmp_z)
526
- tmp_x, tmp_y, tmp_z = self.ROTATEY(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
527
- tmp_x, tmp_y, tmp_z = self.ROTATEZ(kappa, tmp_x, tmp_y, tmp_z)
529
+ tmp_x, tmp_y, tmp_z = rotate_z(-phiJ_sf, tmp_x, tmp_y, tmp_z)
530
+ tmp_x, tmp_y, tmp_z = rotate_y(-thetaJ_sf, tmp_x, tmp_y, tmp_z)
531
+ tmp_x, tmp_y, tmp_z = rotate_z(kappa, tmp_x, tmp_y, tmp_z)
528
532
 
529
533
  # Now the tmp_a are the components of X in the J frame
530
534
  # We need the polar angle of that vector in the P,Q basis of Arun et al
@@ -591,8 +595,12 @@ class IMRPhenomPv2(IMRPhenomD):
591
595
  )
592
596
 
593
597
  def WignerdCoefficients(
594
- self, v: TensorType, SL: TensorType, eta: TensorType, Sp: TensorType
595
- ) -> Tuple[TensorType, TensorType]:
598
+ self,
599
+ v: BatchTensor,
600
+ SL: BatchTensor,
601
+ eta: BatchTensor,
602
+ Sp: BatchTensor,
603
+ ) -> Tuple[BatchTensor, BatchTensor]:
596
604
  # We define the shorthand s := Sp / (L + SL)
597
605
  L = self.L2PNR(v, eta)
598
606
  s = (Sp / (L + SL)).mT
@@ -604,8 +612,11 @@ class IMRPhenomPv2(IMRPhenomD):
604
612
  return cos_beta_half, sin_beta_half
605
613
 
606
614
  def ComputeNNLOanglecoeffs(
607
- self, q: TensorType, chil: TensorType, chip: TensorType
608
- ) -> Dict[str, TensorType]:
615
+ self,
616
+ q: BatchTensor,
617
+ chil: BatchTensor,
618
+ chip: BatchTensor,
619
+ ) -> Dict[str, BatchTensor]:
609
620
  m2 = q / (1.0 + q)
610
621
  m1 = 1.0 / (1.0 + q)
611
622
  dm = m1 - m2
@@ -730,12 +741,12 @@ class IMRPhenomPv2(IMRPhenomD):
730
741
 
731
742
  def FinalSpin_inplane(
732
743
  self,
733
- m1: TensorType,
734
- m2: TensorType,
735
- chi1_l: TensorType,
736
- chi2_l: TensorType,
737
- chip: TensorType,
738
- ) -> TensorType:
744
+ m1: BatchTensor,
745
+ m2: BatchTensor,
746
+ chi1_l: BatchTensor,
747
+ chi2_l: BatchTensor,
748
+ chip: BatchTensor,
749
+ ) -> BatchTensor:
739
750
  M = m1 + m2
740
751
  eta = m1 * m2 / (M * M)
741
752
  eta2 = eta * eta
@@ -751,7 +762,7 @@ class IMRPhenomPv2(IMRPhenomD):
751
762
 
752
763
  def phP_get_fRD_fdamp(
753
764
  self, m1, m2, chi1_l, chi2_l, chip
754
- ) -> Tuple[TensorType, TensorType]:
765
+ ) -> Tuple[BatchTensor, BatchTensor]:
755
766
  # m1 > m2 should hold here
756
767
  finspin = self.FinalSpin_inplane(m1, m2, chi1_l, chi2_l, chip)
757
768
  m1_s = m1 * MTSUN_SI
@@ -770,7 +781,7 @@ class IMRPhenomPv2(IMRPhenomD):
770
781
  ) / (1.0 - Erad)
771
782
  return fRD / M_s, fdamp / M_s
772
783
 
773
- def get_Amp0(self, fM_s: TensorType, eta: TensorType) -> TensorType:
784
+ def get_Amp0(self, fM_s: BatchTensor, eta: BatchTensor) -> BatchTensor:
774
785
  Amp0 = (
775
786
  (2.0 / 3.0 * eta.unsqueeze(1)) ** (1.0 / 2.0)
776
787
  * (fM_s) ** (-7.0 / 6.0)
@@ -1,8 +1,9 @@
1
1
  import torch
2
- from torchtyping import TensorType
2
+ from jaxtyping import Float
3
3
 
4
- from ..constants import MPC_SEC, MTSUN_SI, PI
5
- from ..constants import EulerGamma as GAMMA
4
+ from ml4gw.constants import MPC_SEC, MTSUN_SI, PI
5
+ from ml4gw.constants import EulerGamma as GAMMA
6
+ from ml4gw.types import BatchTensor, FrequencySeries1d
6
7
 
7
8
 
8
9
  class TaylorF2(torch.nn.Module):
@@ -11,14 +12,14 @@ class TaylorF2(torch.nn.Module):
11
12
 
12
13
  def forward(
13
14
  self,
14
- f: TensorType,
15
- chirp_mass: TensorType,
16
- mass_ratio: TensorType,
17
- chi1: TensorType,
18
- chi2: TensorType,
19
- distance: TensorType,
20
- phic: TensorType,
21
- inclination: TensorType,
15
+ f: FrequencySeries1d,
16
+ chirp_mass: BatchTensor,
17
+ mass_ratio: BatchTensor,
18
+ chi1: BatchTensor,
19
+ chi2: BatchTensor,
20
+ distance: BatchTensor,
21
+ phic: BatchTensor,
22
+ inclination: BatchTensor,
22
23
  f_ref: float,
23
24
  ):
24
25
  """
@@ -75,15 +76,15 @@ class TaylorF2(torch.nn.Module):
75
76
 
76
77
  def taylorf2_htilde(
77
78
  self,
78
- f: TensorType,
79
- mass1: TensorType,
80
- mass2: TensorType,
81
- chi1: TensorType,
82
- chi2: TensorType,
83
- distance: TensorType,
84
- phic: TensorType,
79
+ f: FrequencySeries1d,
80
+ mass1: BatchTensor,
81
+ mass2: BatchTensor,
82
+ chi1: BatchTensor,
83
+ chi2: BatchTensor,
84
+ distance: BatchTensor,
85
+ phic: BatchTensor,
85
86
  f_ref: float,
86
- ):
87
+ ) -> Float[FrequencySeries1d, " batch"]:
87
88
  mass1_s = mass1 * MTSUN_SI
88
89
  mass2_s = mass2 * MTSUN_SI
89
90
  M_s = mass1_s + mass2_s
@@ -103,8 +104,13 @@ class TaylorF2(torch.nn.Module):
103
104
  return h0
104
105
 
105
106
  def taylorf2_amplitude(
106
- self, Mf: TensorType, mass1, mass2, eta, distance
107
- ) -> TensorType:
107
+ self,
108
+ Mf: BatchTensor,
109
+ mass1: BatchTensor,
110
+ mass2: BatchTensor,
111
+ eta: BatchTensor,
112
+ distance: BatchTensor,
113
+ ) -> Float[FrequencySeries1d, " batch"]:
108
114
  mass1_s = mass1 * MTSUN_SI
109
115
  mass2_s = mass2 * MTSUN_SI
110
116
  v = (PI * Mf) ** (1.0 / 3.0)
@@ -126,12 +132,12 @@ class TaylorF2(torch.nn.Module):
126
132
 
127
133
  def taylorf2_phase(
128
134
  self,
129
- Mf: TensorType,
130
- mass1: TensorType,
131
- mass2: TensorType,
132
- chi1: TensorType,
133
- chi2: TensorType,
134
- ) -> TensorType:
135
+ Mf: BatchTensor,
136
+ mass1: BatchTensor,
137
+ mass2: BatchTensor,
138
+ chi1: BatchTensor,
139
+ chi2: BatchTensor,
140
+ ) -> Float[FrequencySeries1d, " batch"]:
135
141
  """
136
142
  Calculate the inspiral phase for the TaylorF2.
137
143
  """