ml4gw 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

@@ -0,0 +1,110 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from ml4gw.types import ScalarTensor
5
+
6
+ from ..constants import PI, C, G, m_per_Mpc
7
+
8
+
9
class Ringdown(torch.nn.Module):
    """
    Callable class for generating ringdown waveforms.

    Waveforms are exponentially damped sinusoids that begin at t=0. The
    remnant black hole's spin and mass are inferred from the ringdown's
    central frequency and quality factor, and set the waveform amplitude.

    Args:
        sample_rate: Sample rate of waveform
        duration: Duration of waveform
    """

    def __init__(self, sample_rate: float, duration: float):
        super().__init__()
        # Sample times start at t=0, where the ringdown begins, and span
        # `duration` seconds. The grid is intentionally not shifted or
        # centered: the damped sinusoid is one-sided in time.
        num = int(duration * sample_rate)
        times = torch.arange(num, dtype=torch.float64) / sample_rate

        self.register_buffer("times", times)

    def forward(
        self,
        frequency: ScalarTensor,
        quality: ScalarTensor,
        epsilon: ScalarTensor,
        phase: ScalarTensor,
        inclination: ScalarTensor,
        distance: ScalarTensor,
    ):
        """
        Generate ringdown waveform based on the damped sinusoid equation.

        The polarizations are

            h_plus  = h0 * (1 + cos^2(i)) * exp(-pi f t / Q) * cos(2 pi f t + phi)
            h_cross = h0 * (2 cos(i))     * exp(-pi f t / Q) * sin(2 pi f t + phi)

        where h0 = A / distance. The amplitude A is proportional to
        sqrt(epsilon), because the radiated energy is quadratic in strain.

        Args:
            frequency:
                Central frequency of the ringdown waveform in Hz
            quality:
                Quality factor of the ringdown waveform
            epsilon:
                Fraction of black hole's mass radiated as gravitational waves
            phase:
                Initial phase of the ringdown waveform in rad
            inclination:
                Inclination angle of the source in rad
            distance:
                Distance to the source in Mpc
        Returns:
            Tensors of cross and plus polarizations, in that order. Each has
            shape (batch, int(duration * sample_rate)) and dtype float64.
        """

        # Add a trailing dimension so per-waveform parameters broadcast
        # against the time grid. `reshape` (unlike `view`) also accepts
        # non-contiguous inputs.
        frequency = frequency.reshape(-1, 1)
        quality = quality.reshape(-1, 1)
        epsilon = epsilon.reshape(-1, 1)
        phase = phase.reshape(-1, 1)
        inclination = inclination.reshape(-1, 1)
        distance = distance.reshape(-1, 1)

        # convert Mpc to m
        distance = distance * m_per_Mpc

        # A 0-dim float64 pi on the inputs' device. It keeps full precision
        # for float64 inputs, and as a 0-dim tensor it does not promote
        # float32 inputs to float64.
        pi = torch.tensor(PI, dtype=torch.float64, device=frequency.device)

        # Remnant spin and mass, obtained by inverting
        #   Q = 2 (1 - a)^(-9/20)
        #   f = c^3 / (2 pi G M) * [1 - 0.63 (1 - a)^(3/10)]
        # using (1 - a)^(3/10) = (2 / Q)^(2/3).
        spin = 1 - (2 / quality) ** (20 / 9)
        mass = (
            (1 / (2 * pi))
            * (C**3 / (G * frequency))
            * (1 - 0.63 * (2 / quality) ** (2 / 3))
        )

        # Amplitude (in metres, before dividing by distance). Energy scales
        # with the square of the strain, so the radiated-mass fraction
        # enters as sqrt(epsilon).
        F_Q = 1 + ((7 / 24) / quality**2)
        g_a = 1 - 0.63 * (1 - spin) ** (3 / 10)
        amplitude = (
            torch.sqrt(5 / 2 * epsilon)
            * (G * mass / C**2)
            * quality ** (-0.5)
            * F_Q ** (-0.5)
            * g_a ** (-0.5)
        )
        # dimensionless strain prefactor
        h0 = amplitude / distance

        # f * t is shared by the decay envelope and the oscillation phase.
        # It has shape (batch, num_samples) and is float64 because
        # `self.times` is float64.
        ft = frequency * self.times
        envelope = h0 * torch.exp(-pi * ft / quality)
        phase_term = 2 * pi * ft + phase

        # Quadrupolar inclination dependence: (1 + cos^2 i) for plus and
        # 2 cos i for cross, so a face-on source (i = 0) is circularly
        # polarized.
        cos_i = torch.cos(inclination)
        h_plus = envelope * (1 + cos_i**2) * torch.cos(phase_term)
        h_cross = envelope * (2 * cos_i) * torch.sin(phase_term)

        # ensure the dtype is double
        h_plus = h_plus.double()
        h_cross = h_cross.double()

        return h_cross, h_plus
@@ -30,7 +30,7 @@ class SineGaussian(torch.nn.Module):
30
30
 
31
31
  self.register_buffer("times", times)
32
32
 
33
- def __call__(
33
+ def forward(
34
34
  self,
35
35
  quality: ScalarTensor,
36
36
  frequency: ScalarTensor,
@@ -60,7 +60,7 @@ class SineGaussian(torch.nn.Module):
60
60
  Returns:
61
61
  Tensors of cross and plus polarizations
62
62
  """
63
-
63
+ dtype = frequency.dtype
64
64
  # add dimension for calculating waveforms in batch
65
65
  frequency = frequency.view(-1, 1)
66
66
  quality = quality.view(-1, 1)
@@ -105,8 +105,7 @@ class SineGaussian(torch.nn.Module):
105
105
  cross = fac.imag * h0_cross
106
106
  plus = fac.real * h0_plus
107
107
 
108
- # TODO dtype as argument?
109
- cross = cross.double()
110
- plus = plus.double()
108
+ cross = cross.to(dtype)
109
+ plus = plus.to(dtype)
111
110
 
112
111
  return cross, plus