ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic; see the registry's advisory page for this release for more details.

@@ -0,0 +1,109 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from ml4gw.constants import PI, C, G, m_per_Mpc
5
+ from ml4gw.types import BatchTensor
6
+
7
+
8
class Ringdown(torch.nn.Module):
    """
    Callable class for generating ringdown waveforms.

    Waveforms are damped sinusoids sampled on a fixed time grid of
    ``int(duration * sample_rate)`` points spaced ``1 / sample_rate``
    apart, starting at t=0 (the onset of the ringdown).

    Args:
        sample_rate: Sample rate of waveform
        duration: Duration of waveform
    """

    def __init__(self, sample_rate: float, duration: float):
        super().__init__()
        # Sample times start at t=0 and run forward: a ringdown begins at
        # its peak and then decays, so (unlike a sine-Gaussian) the grid
        # is NOT shifted to be centered on zero. float64 keeps the phase
        # 2*pi*f*t accurate over long durations.
        num = int(duration * sample_rate)
        times = torch.arange(num, dtype=torch.float64) / sample_rate

        self.register_buffer("times", times)

    def forward(
        self,
        frequency: BatchTensor,
        quality: BatchTensor,
        epsilon: BatchTensor,
        phase: BatchTensor,
        inclination: BatchTensor,
        distance: BatchTensor,
    ):
        """
        Generate ringdown waveform based on the damped sinusoid equation.

        Each argument is flattened to shape ``(batch, 1)`` so that it
        broadcasts against the ``(num,)`` time grid, producing one
        waveform per batch element.

        Args:
            frequency:
                Central frequency of the ringdown waveform in Hz
            quality:
                Quality factor of the ringdown waveform
            epsilon:
                Fraction of black hole's mass radiated as gravitational waves
            phase:
                Initial phase of the ringdown waveform in rad
            inclination:
                Inclination angle of the source in rad
            distance:
                Distance to the source in Mpc
        Returns:
            Tensors of cross and plus polarizations, in that order, each
            of shape ``(batch, num)`` and dtype float64
        """
        # Flatten to (batch, 1) so parameters broadcast against the time
        # grid; reshape (unlike view) also accepts non-contiguous inputs.
        frequency = frequency.reshape(-1, 1)
        quality = quality.reshape(-1, 1)
        epsilon = epsilon.reshape(-1, 1)
        phase = phase.reshape(-1, 1)
        inclination = inclination.reshape(-1, 1)
        distance = distance.reshape(-1, 1) * m_per_Mpc  # Mpc -> m

        # Keep pi as a Python float: it is device-agnostic and, unlike a
        # default (float32) tensor, is not rounded to single precision
        # when the inputs are float64.
        pi = float(PI)

        # Remnant spin and mass, obtained by solving the fits
        #   Q = 2 (1 - a)^(-9/20)
        #   f = c^3 / (2 pi G M) * (1 - 0.63 (1 - a)^(3/10))
        # for a and M, using (1 - a)^(3/10) == (2 / Q)^(2/3).
        # NOTE(review): mass is in kg assuming C and G are SI constants.
        spin = 1 - (2 / quality) ** (20 / 9)
        mass = (
            (1 / (2 * pi))
            * (C**3 / (G * frequency))
            * (1 - 0.63 * (2 / quality) ** (2 / 3))
        )

        # Amplitude. The radiated energy (epsilon * M c^2) scales with the
        # square of the strain, so the amplitude scales as sqrt(epsilon).
        F_Q = 1 + ((7 / 24) / quality**2)
        g_a = 1 - 0.63 * (1 - spin) ** (3 / 10)
        amplitude = (
            torch.sqrt(5 / 2 * epsilon)
            * (G * mass / C**2)
            * quality ** (-0.5)
            * F_Q ** (-0.5)
            * g_a ** (-0.5)
        )

        # Inclination dependence of the dominant quadrupolar mode:
        # (1 + cos^2 i) for plus and 2 cos i for cross, so face-on sources
        # are circularly polarized and edge-on sources are purely
        # plus-polarized.
        cos_i = torch.cos(inclination)
        scale = amplitude / distance
        a_plus = scale * (1 + cos_i**2)
        a_cross = scale * (2 * cos_i)

        # Damped envelope and oscillation phase on the (batch, num) grid
        exp_term = torch.exp(-pi * frequency * self.times / quality)
        phase_term = 2 * pi * frequency * self.times + phase

        h_plus = a_plus * exp_term * torch.cos(phase_term)
        h_cross = a_cross * exp_term * torch.sin(phase_term)

        # Outputs are always returned in double precision.
        return h_cross.double(), h_plus.double()
@@ -1,7 +1,7 @@
1
1
  import torch
2
2
  from torch import Tensor
3
3
 
4
- from ml4gw.types import ScalarTensor
4
+ from ml4gw.types import BatchTensor
5
5
 
6
6
 
7
7
  def semi_major_minor_from_e(e: Tensor):
@@ -30,13 +30,13 @@ class SineGaussian(torch.nn.Module):
30
30
 
31
31
  self.register_buffer("times", times)
32
32
 
33
- def __call__(
33
+ def forward(
34
34
  self,
35
- quality: ScalarTensor,
36
- frequency: ScalarTensor,
37
- hrss: ScalarTensor,
38
- phase: ScalarTensor,
39
- eccentricity: ScalarTensor,
35
+ quality: BatchTensor,
36
+ frequency: BatchTensor,
37
+ hrss: BatchTensor,
38
+ phase: BatchTensor,
39
+ eccentricity: BatchTensor,
40
40
  ):
41
41
  """
42
42
  Generate lalinference implementation of a sine-Gaussian waveform.
@@ -60,7 +60,7 @@ class SineGaussian(torch.nn.Module):
60
60
  Returns:
61
61
  Tensors of cross and plus polarizations
62
62
  """
63
-
63
+ dtype = frequency.dtype
64
64
  # add dimension for calculating waveforms in batch
65
65
  frequency = frequency.view(-1, 1)
66
66
  quality = quality.view(-1, 1)
@@ -105,8 +105,7 @@ class SineGaussian(torch.nn.Module):
105
105
  cross = fac.imag * h0_cross
106
106
  plus = fac.real * h0_plus
107
107
 
108
- # TODO dtype as argument?
109
- cross = cross.double()
110
- plus = plus.double()
108
+ cross = cross.to(dtype)
109
+ plus = plus.to(dtype)
111
110
 
112
111
  return cross, plus