ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +8 -2
- ml4gw/constants.py +45 -0
- ml4gw/dataloading/chunked_dataset.py +4 -2
- ml4gw/dataloading/hdf5_dataset.py +1 -1
- ml4gw/dataloading/in_memory_dataset.py +8 -4
- ml4gw/distributions.py +18 -12
- ml4gw/gw.py +21 -27
- ml4gw/nn/autoencoder/base.py +11 -6
- ml4gw/nn/autoencoder/convolutional.py +7 -4
- ml4gw/nn/autoencoder/skip_connection.py +7 -6
- ml4gw/nn/autoencoder/utils.py +2 -1
- ml4gw/nn/norm.py +11 -1
- ml4gw/nn/streaming/online_average.py +7 -5
- ml4gw/nn/streaming/snapshotter.py +7 -5
- ml4gw/spectral.py +40 -36
- ml4gw/transforms/pearson.py +7 -3
- ml4gw/transforms/qtransform.py +20 -14
- ml4gw/transforms/scaler.py +6 -2
- ml4gw/transforms/snr_rescaler.py +6 -5
- ml4gw/transforms/spectral.py +25 -6
- ml4gw/transforms/spectrogram.py +7 -1
- ml4gw/transforms/transform.py +4 -3
- ml4gw/transforms/waveforms.py +10 -7
- ml4gw/transforms/whitening.py +12 -4
- ml4gw/types.py +25 -10
- ml4gw/utils/interferometer.py +7 -1
- ml4gw/utils/slicing.py +24 -16
- ml4gw/waveforms/__init__.py +2 -0
- ml4gw/waveforms/generator.py +9 -5
- ml4gw/waveforms/phenom_d.py +1338 -1256
- ml4gw/waveforms/phenom_p.py +796 -0
- ml4gw/waveforms/ringdown.py +109 -0
- ml4gw/waveforms/sine_gaussian.py +10 -11
- ml4gw/waveforms/taylorf2.py +304 -279
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/METADATA +5 -3
- ml4gw-0.5.1.dist-info/RECORD +47 -0
- ml4gw-0.4.2.dist-info/RECORD +0 -44
- {ml4gw-0.4.2.dist-info → ml4gw-0.5.1.dist-info}/WHEEL +0 -0
ml4gw/augmentations.py
CHANGED
@@ -1,4 +1,6 @@
 import torch
+from jaxtyping import Float
+from torch import Tensor


 class SignalInverter(torch.nn.Module):
@@ -16,7 +18,9 @@ class SignalInverter(torch.nn.Module):
         super().__init__()
         self.prob = prob

-    def forward(self, X):
+    def forward(
+        self, X: Float[Tensor, "*batch time"]
+    ) -> Float[Tensor, "*batch time"]:
         mask = torch.rand(size=X.shape[:-1]) < self.prob
         X[mask] *= -1
         return X
@@ -37,7 +41,9 @@ class SignalReverser(torch.nn.Module):
         super().__init__()
         self.prob = prob

-    def forward(self, X):
+    def forward(
+        self, X: Float[Tensor, "*batch time"]
+    ) -> Float[Tensor, "*batch time"]:
         mask = torch.rand(size=X.shape[:-1]) < self.prob
         X[mask] = X[mask].flip(-1)
         return X

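A minimal usage sketch of the newly annotated augmentations, assuming both modules are constructed with the `prob` argument shown above and applied to a `(batch, channels, time)` tensor:

```python
import torch

from ml4gw.augmentations import SignalInverter, SignalReverser

X = torch.randn(8, 2, 2048)  # hypothetical batch: 8 waveforms, 2 channels

# prob is the per-element probability of applying each augmentation
inverter = SignalInverter(prob=0.5)
reverser = SignalReverser(prob=0.5)

X = reverser(inverter(X))
print(X.shape)  # torch.Size([8, 2, 2048]); shapes are unchanged
```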
ml4gw/constants.py
ADDED
@@ -0,0 +1,45 @@
+"""
+Various constants, all in SI units.
+"""
+
+EulerGamma = 0.577215664901532860606512090082402431
+
+MSUN = 1.988409902147041637325262574352366540e30  # kg
+"""Solar mass"""
+
+MRSUN = 1.476625038050124729627979840144936351e3
+"""Geometrized nominal solar mass, m"""
+
+G = 6.67430e-11  # m^3 / kg / s^2
+"""Newton's gravitational constant"""
+
+C = 299792458.0  # m / s
+"""Speed of light"""
+
+"""Pi"""
+PI = 3.141592653589793238462643383279502884
+
+TWO_PI = 6.283185307179586476925286766559005768
+
+gt = G * MSUN / (C**3.0)
+"""
+G MSUN / C^3 in seconds
+"""
+
+MTSUN_SI = 4.925490947641266978197229498498379006e-6
+"""1 solar mass in seconds. Same value as lal.MTSUN_SI"""
+
+m_per_Mpc = 3.085677581491367278913937957796471611e22
+"""
+Meters per Mpc.
+"""
+
+MPC_SEC = m_per_Mpc / C
+"""
+1 Mpc in seconds.
+"""
+
+clightGpc = C / 3.0856778570831e22
+"""
+Speed of light in vacuum (:math:`c`), in gigaparsecs per second
+"""

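A quick sanity-check sketch of the new constants module (module path taken from the file list above):

```python
from ml4gw import constants

# one solar mass expressed as a time, two equivalent ways
print(constants.MTSUN_SI)                             # ~4.925e-06 s
print(constants.G * constants.MSUN / constants.C**3)  # same value, built from G, MSUN, C

# light-travel time across 100 Mpc, in seconds
print(100 * constants.MPC_SEC)
```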
ml4gw/dataloading/chunked_dataset.py
CHANGED
@@ -2,6 +2,8 @@ from collections.abc import Iterable

 import torch

+from ml4gw.types import WaveformTensor
+

 class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
     """
@@ -55,10 +57,10 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
         self.coincident = coincident
         self.device = device

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.chunk_it) * self.batches_per_chunk

-    def __iter__(self):
+    def __iter__(self) -> WaveformTensor:
         it = iter(self.chunk_it)
         chunk = next(it)
         num_chunks, num_channels, chunk_size = chunk.shape

ml4gw/dataloading/hdf5_dataset.py
CHANGED
@@ -161,7 +161,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
                x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
         return torch.Tensor(x)

-    def __iter__(self) ->
+    def __iter__(self) -> WaveformTensor:
         worker_info = torch.utils.data.get_worker_info()
         if worker_info is None:
             num_batches = self.batches_per_epoch

ml4gw/dataloading/in_memory_dataset.py
CHANGED
@@ -2,8 +2,9 @@ import itertools
 from typing import Optional, Tuple, Union

 import torch
+from jaxtyping import Float
+from torch import Tensor

-from ml4gw import types
 from ml4gw.utils.slicing import slice_kernels


@@ -76,9 +77,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):

     def __init__(
         self,
-        X:
+        X: Float[Tensor, "channels time"],
         kernel_size: int,
-        y: Optional[
+        y: Optional[Float[Tensor, " time"]] = None,
         batch_size: int = 32,
         stride: int = 1,
         batches_per_epoch: Optional[int] = None,
@@ -207,7 +208,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):

     def __iter__(
         self,
-    ) -> Union[
+    ) -> Union[
+        Float[Tensor, "batch channel time"],
+        Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
+    ]:

         indices = self.init_indices()
         for i in range(len(self)):

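A hedged usage sketch of the updated `InMemoryDataset` signature, assuming the remaining constructor arguments keep their defaults:

```python
import torch

from ml4gw.dataloading.in_memory_dataset import InMemoryDataset

X = torch.randn(2, 16384)  # (channels, time) timeseries held in memory

dataset = InMemoryDataset(X, kernel_size=2048, batch_size=32)
for kernels in dataset:
    # each yielded batch has shape (batch, channel, time), per the new annotation
    print(kernels.shape)
    break
```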
ml4gw/distributions.py
CHANGED
@@ -4,11 +4,13 @@ from specified distributions. Each callable should map from
 an integer `N` to a 1D torch `Tensor` containing `N` samples
 from the corresponding distribution.
 """
-
+import math
 from typing import Optional

 import torch
 import torch.distributions as dist
+from jaxtyping import Float
+from torch import Tensor


 class Cosine(dist.Distribution):
@@ -21,20 +23,21 @@ class Cosine(dist.Distribution):

     def __init__(
         self,
-        low: float =
-        high: float =
+        low: float = -math.pi / 2,
+        high: float = math.pi / 2,
         validate_args=None,
     ):
         batch_shape = torch.Size()
         super().__init__(batch_shape, validate_args=validate_args)
-        self.low = low
-        self.
+        self.low = torch.as_tensor(low)
+        self.high = torch.as_tensor(high)
+        self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low))

-    def rsample(self, sample_shape: torch.Size = torch.Size()) ->
+    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
         u = torch.rand(sample_shape, device=self.low.device)
         return torch.arcsin(u / self.norm + torch.sin(self.low))

-    def log_prob(self, value):
+    def log_prob(self, value: float) -> Float[Tensor, ""]:
         value = torch.as_tensor(value)
         inside_range = (value >= self.low) & (value <= self.high)
         return value.cos().log() * inside_range
@@ -48,13 +51,16 @@ class Sine(dist.TransformedDistribution):

     def __init__(
         self,
-        low: float =
-        high: float =
+        low: float = 0.0,
+        high: float = math.pi,
         validate_args=None,
     ):
+        low = torch.as_tensor(low)
+        high = torch.as_tensor(high)
         base_dist = Cosine(
             low - torch.pi / 2, high - torch.pi / 2, validate_args
         )
+
         super().__init__(
             base_dist,
             [
@@ -153,14 +159,14 @@ class DeltaFunction(dist.Distribution):

     def __init__(
         self,
-        peak: float =
+        peak: float = 0.0,
         validate_args=None,
     ):
         batch_shape = torch.Size()
         super().__init__(batch_shape, validate_args=validate_args)
-        self.peak = peak
+        self.peak = torch.as_tensor(peak)

-    def rsample(self, sample_shape: torch.Size = torch.Size()) ->
+    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
         return self.peak * torch.ones(
             sample_shape, device=self.peak.device, dtype=torch.float32
         )

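A short sampling sketch against the new defaults (a sanity check, not part of the package):

```python
import math

import torch

from ml4gw.distributions import Cosine, Sine

# declination-like angles, density proportional to cos(x) on [-pi/2, pi/2]
dec = Cosine().rsample(torch.Size((1024,)))

# inclination-like angles, density proportional to sin(x) on [0, pi]
inc = Sine(low=0.0, high=math.pi).rsample(torch.Size((1024,)))

print(dec.min().item(), dec.max().item())
print(inc.min().item(), inc.max().item())
```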
ml4gw/gw.py
CHANGED
@@ -13,27 +13,21 @@ https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
 from typing import List, Tuple, Union

 import torch
-from
+from jaxtyping import Float
+from torch import Tensor

+from ml4gw.constants import C
 from ml4gw.types import (
+    BatchTensor,
     NetworkDetectorTensors,
     NetworkVertices,
     PSDTensor,
-    ScalarTensor,
     TensorGeometry,
     VectorGeometry,
     WaveformTensor,
 )
 from ml4gw.utils.interferometer import InterferometerGeometry

-SPEED_OF_LIGHT = 299792458.0  # m/s
-
-
-# define some tensor shapes we'll reuse a bit
-# up front. Need to assign these variables so
-# that static linters don't give us name errors
-batch = num_ifos = polarizations = time = frequency = space = None  # noqa
-

 def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
     """
@@ -62,12 +56,12 @@ polarization_funcs = {


 def compute_antenna_responses(
-    theta:
-    psi:
-    phi:
+    theta: BatchTensor,
+    psi: BatchTensor,
+    phi: BatchTensor,
     detector_tensors: NetworkDetectorTensors,
     modes: List[str],
-) ->
+) -> Float[Tensor, "batch polarizations num_ifos"]:
     """
     Compute the antenna pattern factors of a batch of
     waveforms as a function of the sky parameters of
@@ -147,8 +141,8 @@ def compute_antenna_responses(

 def shift_responses(
     responses: WaveformTensor,
-    theta:
-    phi:
+    theta: BatchTensor,
+    phi: BatchTensor,
     vertices: NetworkVertices,
     sample_rate: float,
 ) -> WaveformTensor:
@@ -166,7 +160,7 @@ def shift_responses(
     # Divide by c in the second line so that we only
     # need to multiply the array by a single float
     dt = -(omega * vertices).sum(axis=-1)
-    dt *= sample_rate / SPEED_OF_LIGHT
+    dt *= sample_rate / C
     dt = torch.trunc(dt).type(torch.int64)

     # rolling by gathering implementation based on
@@ -191,13 +185,13 @@ def shift_responses(


 def compute_observed_strain(
-    dec:
-    psi:
-    phi:
+    dec: BatchTensor,
+    psi: BatchTensor,
+    phi: BatchTensor,
     detector_tensors: NetworkDetectorTensors,
     detector_vertices: NetworkVertices,
     sample_rate: float,
-    **polarizations:
+    **polarizations: Float[Tensor, "batch time"],
 ) -> WaveformTensor:
     """
     Compute the strain timeseries $h(t)$ observed by a network
@@ -289,8 +283,8 @@ def compute_ifo_snr(
     responses: WaveformTensor,
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float,
-) ->
+    highpass: Union[float, Float[Tensor, " frequency"], None] = None,
+) -> Float[Tensor, "batch num_ifos"]:
     r"""Compute the SNRs of a batch of interferometer responses

     Compute the signal to noise ratio (SNR) of individual
@@ -390,8 +384,8 @@ def compute_network_snr(
     responses: WaveformTensor,
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float,
-) ->
+    highpass: Union[float, Float[Tensor, " frequency"], None] = None,
+) -> BatchTensor:
     r"""
     Compute the total SNR from a gravitational waveform
     from a network of interferometers. The total SNR for
@@ -437,10 +431,10 @@ def compute_network_snr(

 def reweight_snrs(
     responses: WaveformTensor,
-    target_snrs: Union[float,
+    target_snrs: Union[float, BatchTensor],
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float,
+    highpass: Union[float, Float[Tensor, " frequency"], None] = None,
 ) -> WaveformTensor:
     """Scale interferometer responses such that they have a desired SNR

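For orientation, a hedged sketch of the re-annotated `compute_antenna_responses` call; the detector tensors here are placeholders, and the "plus"/"cross" mode names are assumed from `polarization_funcs`:

```python
import torch

from ml4gw import gw

batch = 128
theta = torch.rand(batch) * torch.pi                # polar sky angle
psi = torch.rand(batch) * torch.pi                  # polarization angle
phi = torch.rand(batch) * 2 * torch.pi - torch.pi   # azimuthal sky angle

# placeholder (num_ifos, 3, 3) detector tensors; real values would come from
# InterferometerGeometry for an actual detector network
detector_tensors = torch.eye(3).repeat(2, 1, 1)

responses = gw.compute_antenna_responses(
    theta, psi, phi, detector_tensors, ["plus", "cross"]
)
print(responses.shape)  # (batch, polarizations, num_ifos) == (128, 2, 2)
```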
ml4gw/nn/autoencoder/base.py
CHANGED
@@ -1,7 +1,8 @@
 from collections.abc import Sequence
-from typing import Optional
+from typing import Optional, Tuple, Union

 import torch
+from torch import Tensor

 from ml4gw.nn.autoencoder.skip_connection import SkipConnection

@@ -27,12 +28,16 @@ class Autoencoder(torch.nn.Module):
     and how they operate.
     """

-    def __init__(
+    def __init__(
+        self, skip_connection: Optional[SkipConnection] = None
+    ) -> None:
         super().__init__()
         self.skip_connection = skip_connection
         self.blocks = torch.nn.ModuleList()

-    def encode(
+    def encode(
+        self, *X: Tensor, return_states: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Sequence]]:
         states = []
         for block in self.blocks:
             if isinstance(X, tuple):
@@ -48,7 +53,7 @@ class Autoencoder(torch.nn.Module):
             return X, states[:-1]
         return X

-    def decode(self, *X, states: Optional[Sequence[
+    def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
         if self.skip_connection is not None and states is None:
             raise ValueError(
                 "Must pass intermediate states when autoencoder "
@@ -76,7 +81,7 @@ class Autoencoder(torch.nn.Module):
             X = self.skip_connection(X, state)
         return X

-    def forward(self, *X):
+    def forward(self, *X: Tensor) -> Tensor:
         return_states = self.skip_connection is not None
         X = self.encode(*X, return_states=return_states)
         if return_states:
@@ -84,6 +89,6 @@ class Autoencoder(torch.nn.Module):
         else:
             states = None

-        if isinstance(X,
+        if isinstance(X, Tensor):
             X = (X,)
         return self.decode(*X, states=states)

ml4gw/nn/autoencoder/convolutional.py
CHANGED
@@ -2,6 +2,7 @@ from collections.abc import Callable, Sequence
 from typing import Optional

 import torch
+from torch import Tensor

 from ml4gw.nn.autoencoder.base import Autoencoder
 from ml4gw.nn.autoencoder.skip_connection import SkipConnection
@@ -64,12 +65,12 @@ class ConvBlock(Autoencoder):
         self.encode_norm = norm(out_channels)
         self.decode_norm = norm(decode_channels)

-    def encode(self, X):
+    def encode(self, X: Tensor) -> Tensor:
         X = self.encode_layer(X)
         X = self.encode_norm(X)
         return self.activation(X)

-    def decode(self, X):
+    def decode(self, X: Tensor) -> Tensor:
         X = self.decode_layer(X)
         X = self.decode_norm(X)
         return self.output_activation(X)
@@ -144,13 +145,15 @@ class ConvolutionalAutoencoder(Autoencoder):
             self.blocks.append(block)
             in_channels = channels * groups

-    def decode(
+    def decode(
+        self, *X, states=None, input_size: Optional[int] = None
+    ) -> Tensor:
         X = super().decode(*X, states=states)
         if input_size is not None:
             return match_size(X, input_size)
         return X

-    def forward(self, X):
+    def forward(self, X: Tensor) -> Tensor:
         input_size = X.size(-1)
         X = super().forward(X)
         return match_size(X, input_size)

ml4gw/nn/autoencoder/skip_connection.py
CHANGED
@@ -1,31 +1,32 @@
 import torch
+from torch import Tensor

 from ml4gw.nn.autoencoder.utils import match_size


 class SkipConnection(torch.nn.Module):
-    def forward(self, X:
+    def forward(self, X: Tensor, state: Tensor) -> Tensor:
         return match_size(X, state.size(-1))

-    def get_out_channels(self, in_channels):
+    def get_out_channels(self, in_channels: int) -> int:
         return in_channels


 class AddSkipConnect(SkipConnection):
-    def forward(self, X, state):
+    def forward(self, X: Tensor, state: Tensor) -> Tensor:
         X = super().forward(X, state)
         return X + state


 class ConcatSkipConnect(SkipConnection):
-    def __init__(self, groups: int = 1):
+    def __init__(self, groups: int = 1) -> None:
         super().__init__()
         self.groups = groups

-    def get_out_channels(self, in_channels):
+    def get_out_channels(self, in_channels: int) -> int:
         return 2 * in_channels

-    def forward(self, X, state):
+    def forward(self, X: Tensor, state: Tensor) -> Tensor:
         X = super().forward(X, state)
         if self.groups == 1:
             return torch.cat([X, state], dim=1)

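A small sketch of the two skip connections under the new annotations (shapes chosen arbitrarily; `match_size` is assumed to be a no-op when the time dimensions already agree):

```python
import torch

from ml4gw.nn.autoencoder.skip_connection import AddSkipConnect, ConcatSkipConnect

X = torch.randn(8, 16, 512)      # decoder activation: (batch, channels, time)
state = torch.randn(8, 16, 512)  # matching encoder state

print(AddSkipConnect()(X, state).shape)     # torch.Size([8, 16, 512])
print(ConcatSkipConnect()(X, state).shape)  # torch.Size([8, 32, 512]), channels doubled
```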
ml4gw/nn/autoencoder/utils.py
CHANGED
ml4gw/nn/norm.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import Callable, Optional

 import torch
+from jaxtyping import Float
+from torch import Tensor

 NormLayer = Callable[[int], torch.nn.Module]

@@ -31,7 +33,15 @@ class GroupNorm1D(torch.nn.Module):
         self.weight = torch.nn.Parameter(torch.ones(shape))
         self.bias = torch.nn.Parameter(torch.zeros(shape))

-    def forward(
+    def forward(
+        self, x: Float[Tensor, "batch channel length"]
+    ) -> Float[Tensor, "batch channel length"]:
+        if len(x.shape) != 3:
+            raise ValueError(
+                "GroupNorm1D requires 3-dimensional input, "
+                f"received {len(x.shape)} dimensional input"
+            )
+
         keepdims = self.num_groups == self.num_channels

         # compute group variance via the E[x**2] - E**2[x] trick

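A quick sketch of the new shape check (the `GroupNorm1D(num_channels)` constructor form is assumed here):

```python
import torch

from ml4gw.nn.norm import GroupNorm1D

norm = GroupNorm1D(8)              # constructor arguments assumed
y = norm(torch.randn(4, 8, 1024))  # (batch, channel, length): accepted
print(y.shape)

try:
    norm(torch.randn(8, 1024))     # 2D input is now rejected explicitly
except ValueError as err:
    print(err)
```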
ml4gw/nn/streaming/online_average.py
CHANGED
@@ -1,11 +1,11 @@
 from typing import Optional, Tuple

 import torch
+from jaxtyping import Float
+from torch import Tensor

 from ml4gw.utils.slicing import unfold_windows

-Tensor = torch.Tensor
-

 class OnlineAverager(torch.nn.Module):
     """
@@ -70,12 +70,14 @@ class OnlineAverager(torch.nn.Module):
         weights = unfold_windows(weights, weight_size, update_size)
         self.register_buffer("weights", weights)

-    def get_initial_state(self):
+    def get_initial_state(self) -> Float[Tensor, "channel time"]:
         return torch.zeros((self.num_channels, self.state_size))

     def forward(
-        self,
-
+        self,
+        update: Float[Tensor, "batch channel time1"],
+        state: Optional[Float[Tensor, "channel time2"]] = None,
+    ) -> Tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
         if state is None:
             state = self.get_initial_state()

ml4gw/nn/streaming/snapshotter.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import Optional, Sequence, Tuple

 import torch
+from jaxtyping import Float
+from torch import Tensor

 from ml4gw.utils.slicing import unfold_windows

@@ -82,14 +84,14 @@ class Snapshotter(torch.nn.Module):
         self.channels_per_snapshot = channels_per_snapshot
         self.num_channels = num_channels

-    def get_initial_state(self):
+    def get_initial_state(self) -> Float[Tensor, "channel time"]:
         return torch.zeros((self.num_channels, self.state_size))

-    # TODO: use torchtyping annotations to make
-    # clear what the expected shapes are
     def forward(
-        self,
-
+        self,
+        update: Float[Tensor, "channel time1"],
+        snapshot: Optional[Float[Tensor, "channel time2"]] = None,
+    ) -> Tuple[Tensor, ...]:
         if snapshot is None:
             snapshot = self.get_initial_state()
