ml4gw 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ml4gw/augmentations.py +5 -0
- ml4gw/dataloading/__init__.py +5 -0
- ml4gw/dataloading/chunked_dataset.py +2 -4
- ml4gw/dataloading/hdf5_dataset.py +12 -10
- ml4gw/dataloading/in_memory_dataset.py +12 -12
- ml4gw/distributions.py +3 -3
- ml4gw/gw.py +18 -21
- ml4gw/nn/__init__.py +6 -0
- ml4gw/nn/autoencoder/base.py +5 -9
- ml4gw/nn/autoencoder/convolutional.py +7 -10
- ml4gw/nn/autoencoder/skip_connection.py +3 -5
- ml4gw/nn/norm.py +4 -4
- ml4gw/nn/resnet/resnet_1d.py +12 -13
- ml4gw/nn/resnet/resnet_2d.py +13 -14
- ml4gw/nn/streaming/online_average.py +3 -5
- ml4gw/nn/streaming/snapshotter.py +10 -14
- ml4gw/spectral.py +20 -23
- ml4gw/transforms/__init__.py +7 -1
- ml4gw/transforms/decimator.py +183 -0
- ml4gw/transforms/iirfilter.py +3 -5
- ml4gw/transforms/pearson.py +3 -4
- ml4gw/transforms/qtransform.py +20 -26
- ml4gw/transforms/scaler.py +3 -5
- ml4gw/transforms/snr_rescaler.py +7 -11
- ml4gw/transforms/spectral.py +6 -13
- ml4gw/transforms/spectrogram.py +6 -3
- ml4gw/transforms/spline_interpolation.py +312 -143
- ml4gw/transforms/transform.py +4 -6
- ml4gw/transforms/waveforms.py +8 -15
- ml4gw/transforms/whitening.py +11 -16
- ml4gw/types.py +8 -5
- ml4gw/utils/interferometer.py +20 -3
- ml4gw/utils/slicing.py +26 -30
- ml4gw/waveforms/__init__.py +6 -0
- ml4gw/waveforms/cbc/phenom_p.py +7 -9
- ml4gw/waveforms/conversion.py +2 -4
- ml4gw/waveforms/generator.py +3 -3
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/METADATA +33 -12
- ml4gw-0.7.8.dist-info/RECORD +57 -0
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +2 -1
- ml4gw-0.7.8.dist-info/top_level.txt +1 -0
- ml4gw-0.7.6.dist-info/RECORD +0 -55
- {ml4gw-0.7.6.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
ml4gw/augmentations.py
CHANGED
ml4gw/dataloading/__init__.py
CHANGED
ml4gw/dataloading/chunked_dataset.py
CHANGED
@@ -94,10 +94,8 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
         # flatten it to make it easier to slice
         if chunk_size < self.kernel_size:
             raise ValueError(
-                (
-                    "Can't sample kernels of size {} from chunk "
-                    "with size {}"
-                ).format(self.kernel_size, chunk_size)
+                f"Can't sample kernels of size {self.kernel_size} from "
+                f"chunk with size {chunk_size}"
             )
         chunk = chunk.reshape(-1)
 
ml4gw/dataloading/hdf5_dataset.py
CHANGED
@@ -1,5 +1,5 @@
 import warnings
-from typing import Optional, Sequence, Union
+from collections.abc import Sequence
 
 import h5py
 import numpy as np
@@ -63,13 +63,13 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
         kernel_size: int,
         batch_size: int,
         batches_per_epoch: int,
-        coincident: Union[bool, str],
-        num_files_per_batch: Optional[int] = None,
+        coincident: bool | str,
+        num_files_per_batch: int | None = None,
     ) -> None:
         if not isinstance(coincident, bool) and coincident != "files":
             raise ValueError(
                 "coincident must be either a boolean or 'files', "
-                "got unrecognized value {}".format(coincident)
+                f"got unrecognized value {coincident}"
             )
 
         self.fnames = np.array(fnames)
@@ -94,13 +94,11 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
             dset = f[channels[0]]
             if dset.chunks is None:
                 warnings.warn(
-                    "File {} contains datasets that were generated "
+                    f"File {fname} contains datasets that were generated "
                     "without using chunked storage. This can have "
                     "severe performance impacts at data loading time. "
                     "If you need faster loading, try re-generating "
-                    "your dataset with chunked storage turned on.".format(
-                        fname
-                    ),
+                    "your dataset with chunked storage turned on.",
                     category=ContiguousHdf5Warning,
                     stacklevel=2,
                 )
@@ -153,7 +151,9 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
         unique_fnames, inv, counts = np.unique(
             fnames, return_inverse=True, return_counts=True
         )
-        for i, (fname, count) in enumerate(zip(unique_fnames, counts)):
+        for i, (fname, count) in enumerate(
+            zip(unique_fnames, counts, strict=True)
+        ):
             size = self.sizes[fname]
             max_idx = size - self.kernel_size
 
@@ -185,7 +185,9 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
             # open the file and sample a different set of
             # kernels for each batch element it occupies
             with h5py.File(fname, "r") as f:
-                for b, c, i in zip(batch_indices, channel_indices, idx):
+                for b, c, i in zip(
+                    batch_indices, channel_indices, idx, strict=True
+                ):
                     x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
         return torch.Tensor(x)
 
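Note: several hunks in this release add strict=True to zip calls. The flag was added in Python 3.10 (PEP 618); a minimal sketch of the behavior it guards against, using made-up values:

    channels = ["H1", "L1"]
    indices = [0, 1, 2]

    # plain zip silently truncates to the shorter iterable
    assert list(zip(channels, indices)) == [("H1", 0), ("L1", 1)]

    # strict=True surfaces the length mismatch as an error instead
    try:
        list(zip(channels, indices, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is longer than argument 1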
ml4gw/dataloading/in_memory_dataset.py
CHANGED
@@ -1,5 +1,4 @@
 import itertools
-from typing import Optional, Tuple, Union
 
 import torch
 from jaxtyping import Float
@@ -79,10 +78,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
         self,
         X: Float[Tensor, "channels time"],
         kernel_size: int,
-        y: Optional[Float[Tensor, " time"]] = None,
+        y: Float[Tensor, " time"] | None = None,
         batch_size: int = 32,
         stride: int = 1,
-        batches_per_epoch: Optional[int] = None,
+        batches_per_epoch: int | None = None,
         coincident: bool = True,
         shuffle: bool = True,
         device: str = "cpu",
@@ -122,10 +121,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
             batch_size * batches_per_epoch
         ):
             raise ValueError(
-                "Number of kernels {} in timeseries insufficient "
-                "to generate {} batches of size {}".format(
-                    self.num_kernels, batch_size, batches_per_epoch
-                )
+                f"Number of kernels {self.num_kernels} in timeseries "
+                f"insufficient to generate {batch_size} batches of size "
+                f"{batches_per_epoch}"
             )
 
         self.batch_size = batch_size
@@ -191,7 +189,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
            # indices we'll need rather than having to generate
            # everything.
            idx = [range(self.num_kernels) for _ in range(len(self.X))]
-           idx = zip(range(num_kernels), itertools.product(*idx))
+           idx = zip(
+               range(num_kernels), itertools.product(*idx), strict=False
+           )
            idx = torch.stack([torch.Tensor(i[1]) for i in idx])
            idx = idx.type(torch.int64).to(device)
        elif self.shuffle:
@@ -208,10 +208,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
 
     def __iter__(
         self,
-    ) -> Union[
-        Float[Tensor, "batch channel time"],
-        Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
-    ]:
+    ) -> (
+        Float[Tensor, "batch channel time"]
+        | tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]]
+    ):
         indices = self.init_indices()
         for i in range(len(self)):
             # slice the array of _indices_ we'll be using to
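Note: the ValueError hunk above is typical of the message changes throughout this release, which swap str.format templates for f-strings. A minimal sketch of the equivalence, with hypothetical values:

    num_kernels, batch_size, batches_per_epoch = 10, 32, 8  # hypothetical

    old = (
        "Number of kernels {} in timeseries insufficient "
        "to generate {} batches of size {}".format(
            num_kernels, batch_size, batches_per_epoch
        )
    )
    new = (
        f"Number of kernels {num_kernels} in timeseries "
        f"insufficient to generate {batch_size} batches of size "
        f"{batches_per_epoch}"
    )
    assert old == new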
ml4gw/distributions.py
CHANGED
@@ -6,7 +6,7 @@ from the corresponding distribution.
 """
 
 import math
-from typing import Callable, Optional
+from collections.abc import Callable
 
 import torch
 import torch.distributions as dist
@@ -104,7 +104,7 @@ class LogNormal(dist.LogNormal):
         self,
         mean: float,
         std: float,
-        low: Optional[float] = None,
+        low: float | None = None,
         validate_args=None,
     ):
         self.low = low
@@ -137,7 +137,7 @@ class PowerLaw(dist.TransformedDistribution):
     support = dist.constraints.nonnegative
 
     def __init__(
-        self, minimum: float, maximum: float, index: int, validate_args=None
+        self, minimum: float, maximum: float, index: float, validate_args=None
     ):
         if index == 0:
            raise ValueError("Index of 0 is the same as Uniform")
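Note: the annotation changes here follow one pattern repeated across the whole package: typing.Optional and typing.Union spellings are replaced with PEP 604 unions, implying the package now assumes Python 3.10+. A minimal sketch of the equivalence:

    from typing import Optional, Union

    # the two spellings compare equal at runtime on Python >= 3.10
    assert (float | None) == Optional[float]
    assert (bool | str) == Union[bool, str]

    # hypothetical signatures showing the before/after styles
    def old_style(low: Optional[float] = None) -> None: ...
    def new_style(low: float | None = None) -> None: ...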
ml4gw/gw.py
CHANGED
@@ -1,6 +1,7 @@
 """
-Tools for manipulating raw gravitational waveforms
-and projecting them onto interferometer responses.
+Tools for manipulating raw gravitational waveforms,
+projecting them onto interferometer responses, and
+calculating SNRs.
 Much of the projection code is an extension of the
 implementation made available in
 `bilby <https://arxiv.org/abs/1811.02042>`_.
@@ -8,8 +9,6 @@ Specifically code from
 `this module <https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py>`_.
 """  # noqa E501
 
-from typing import List, Tuple, Union
-
 import torch
 from jaxtyping import Float
 from torch import Tensor
@@ -58,7 +57,7 @@ def compute_antenna_responses(
     psi: BatchTensor,
     phi: BatchTensor,
     detector_tensors: NetworkDetectorTensors,
-    modes: List[str],
+    modes: list[str],
 ) -> Float[Tensor, "batch polarizations num_ifos"]:
     """
     Compute the antenna pattern factors of a batch of
@@ -257,7 +256,7 @@ def compute_observed_strain(
 
 def get_ifo_geometry(
     *ifos: str,
-) -> Tuple[NetworkDetectorTensors, NetworkVertices]:
+) -> tuple[NetworkDetectorTensors, NetworkVertices]:
     """
     For a given list of interferometer names, retrieve and
     concatenate the associated detector tensors and vertices
@@ -286,8 +285,8 @@ def compute_ifo_snr(
     responses: WaveformTensor,
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float, Float[Tensor, " frequency"]] = None,
-    lowpass: Union[float, Float[Tensor, " frequency"]] = None,
+    highpass: float | Float[Tensor, " frequency"] | None = None,
+    lowpass: float | Float[Tensor, " frequency"] | None = None,
 ) -> Float[Tensor, "batch num_ifos"]:
     """Compute the SNRs of a batch of interferometer responses
 
@@ -367,10 +366,9 @@ def compute_ifo_snr(
            highpass = freqs >= highpass
        elif len(highpass) != integrand.shape[-1]:
            raise ValueError(
-                "Can't apply highpass filter mask with {} frequency bins "
-                "to signal fft with {} frequency bins".format(
-                    len(highpass), integrand.shape[-1]
-                )
+                f"Can't apply highpass filter mask with {len(highpass)} "
+                f"frequency bins to signal fft with {integrand.shape[-1]} "
+                "frequency bins"
            )
        integrand *= highpass.to(integrand.device)
    if lowpass is not None:
@@ -379,10 +377,9 @@ def compute_ifo_snr(
            lowpass = freqs < lowpass
        elif len(lowpass) != integrand.shape[-1]:
            raise ValueError(
-                "Can't apply lowpass filter mask with {} frequency bins "
-                "to signal fft with {} frequency bins".format(
-                    len(lowpass), integrand.shape[-1]
-                )
+                f"Can't apply lowpass filter mask with {len(lowpass)} "
+                f"frequency bins to signal fft with {integrand.shape[-1]} "
+                "frequency bins"
            )
        integrand *= lowpass.to(integrand.device)
 
@@ -410,8 +407,8 @@ def compute_network_snr(
     responses: WaveformTensor,
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float, Float[Tensor, " frequency"]] = None,
-    lowpass: Union[float, Float[Tensor, " frequency"]] = None,
+    highpass: float | Float[Tensor, " frequency"] | None = None,
+    lowpass: float | Float[Tensor, " frequency"] | None = None,
 ) -> BatchTensor:
     """
     Compute the total SNR from a gravitational waveform
@@ -467,11 +464,11 @@ def compute_network_snr(
 
 def reweight_snrs(
     responses: WaveformTensor,
-    target_snrs: Union[float, BatchTensor],
+    target_snrs: float | BatchTensor,
     psd: PSDTensor,
     sample_rate: float,
-    highpass: Union[float, Float[Tensor, " frequency"]] = None,
-    lowpass: Union[float, Float[Tensor, " frequency"]] = None,
+    highpass: float | Float[Tensor, " frequency"] | None = None,
+    lowpass: float | Float[Tensor, " frequency"] | None = None,
 ) -> WaveformTensor:
     """Scale interferometer responses such that they have a desired SNR
 
ml4gw/nn/__init__.py
CHANGED
ml4gw/nn/autoencoder/base.py
CHANGED
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -28,16 +27,14 @@ class Autoencoder(torch.nn.Module):
     and how they operate.
     """
 
-    def __init__(
-        self, skip_connection: Optional[SkipConnection] = None
-    ) -> None:
+    def __init__(self, skip_connection: SkipConnection | None = None) -> None:
         super().__init__()
         self.skip_connection = skip_connection
         self.blocks = torch.nn.ModuleList()
 
     def encode(
         self, *X: Tensor, return_states: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Sequence]]:
+    ) -> Tensor | tuple[Tensor, Sequence]:
         states = []
         for block in self.blocks:
             if isinstance(X, tuple):
@@ -53,7 +50,7 @@ class Autoencoder(torch.nn.Module):
             return X, states[:-1]
         return X
 
-    def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
+    def decode(self, *X, states: Sequence[Tensor] | None = None) -> Tensor:
         if self.skip_connection is not None and states is None:
             raise ValueError(
                 "Must pass intermediate states when autoencoder "
@@ -62,9 +59,8 @@ class Autoencoder(torch.nn.Module):
         elif states is not None:
             if len(states) != len(self.blocks) - 1:
                 raise ValueError(
-                    "Passed {} intermediate states, expected {}".format(
-                        len(states), len(self.blocks) - 1
-                    )
+                    f"Passed {len(states)} intermediate states, expected "
+                    f"{len(self.blocks) - 1}"
                 )
 
         # Don't skip connect the output layer
ml4gw/nn/autoencoder/convolutional.py
CHANGED
@@ -1,5 +1,4 @@
 from collections.abc import Callable, Sequence
-from typing import Optional
 
 import torch
 from torch import Tensor
@@ -21,9 +20,9 @@ class ConvBlock(Autoencoder):
         groups: int = 1,
         activation: torch.nn.Module = torch.nn.ReLU,
         norm: Module = torch.nn.BatchNorm1d,
-        decode_channels: Optional[int] = None,
-        output_activation: Optional[torch.nn.Module] = None,
-        skip_connection: Optional[SkipConnection] = None,
+        decode_channels: int | None = None,
+        output_activation: torch.nn.Module | None = None,
+        skip_connection: SkipConnection | None = None,
     ) -> None:
         super().__init__(skip_connection=None)
 
@@ -98,10 +97,10 @@ class ConvolutionalAutoencoder(Autoencoder):
         stride: int = 1,
         groups: int = 1,
         activation: torch.nn.Module = torch.nn.ReLU,
-        output_activation: Optional[torch.nn.Module] = None,
+        output_activation: torch.nn.Module | None = None,
         norm: Module = torch.nn.BatchNorm1d,
-        decode_channels: Optional[int] = None,
-        skip_connection: Optional[SkipConnection] = None,
+        decode_channels: int | None = None,
+        skip_connection: SkipConnection | None = None,
     ) -> None:
         # TODO: how to do this dynamically? Maybe the base
         # architecture looks for overlapping arguments between
@@ -145,9 +144,7 @@ class ConvolutionalAutoencoder(Autoencoder):
         self.blocks.append(block)
         in_channels = channels * groups
 
-    def decode(
-        self, *X, states=None, input_size: Optional[int] = None
-    ) -> Tensor:
+    def decode(self, *X, states=None, input_size: int | None = None) -> Tensor:
         X = super().decode(*X, states=states)
         if input_size is not None:
             return match_size(X, input_size)
ml4gw/nn/autoencoder/skip_connection.py
CHANGED
@@ -35,13 +35,11 @@ class ConcatSkipConnect(SkipConnection):
         rem = num_channels % self.groups
         if rem:
             raise ValueError(
-                "Number of channels in input tensor {} cannot "
-                "be divided evenly into {} groups".format(
-                    num_channels, self.groups
-                )
+                f"Number of channels in input tensor {num_channels} cannot "
+                f"be divided evenly into {self.groups} groups"
             )
 
         X = torch.split(X, self.groups, dim=1)
         state = torch.split(state, self.groups, dim=1)
-        frags = [i for j in zip(X, state) for i in j]
+        frags = [i for j in zip(X, state, strict=True) for i in j]
         return torch.cat(frags, dim=1)
ml4gw/nn/norm.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Callable, Optional
+from collections.abc import Callable
 
 import torch
 from jaxtyping import Float
@@ -16,7 +16,7 @@ class GroupNorm1D(torch.nn.Module):
     def __init__(
         self,
         num_channels: int,
-        num_groups: Optional[int] = None,
+        num_groups: int | None = None,
         eps: float = 1e-5,
     ):
         super().__init__()
@@ -77,7 +77,7 @@ class GroupNorm1DGetter:
     for command-line parameterization with jsonargparse.
     """
 
-    def __init__(self, groups: Optional[int] = None) -> None:
+    def __init__(self, groups: int | None = None) -> None:
         self.groups = groups
 
     def __call__(self, num_channels: int) -> torch.nn.Module:
@@ -96,7 +96,7 @@ class GroupNorm2DGetter:
     for command-line parameterization with jsonargparse.
     """
 
-    def __init__(self, groups: Optional[int] = None) -> None:
+    def __init__(self, groups: int | None = None) -> None:
         self.groups = groups
 
     def __call__(self, num_channels: int) -> torch.nn.Module:
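Note: the Getter classes touched above are callable factories intended for jsonargparse command-line configs. A hypothetical usage sketch based only on the signatures shown in this diff:

    import torch
    from ml4gw.nn.norm import GroupNorm1DGetter

    # build a factory for 8-group GroupNorm layers, then invoke it
    # with a channel count, e.g. when constructing a ResNet1D
    norm_layer = GroupNorm1DGetter(groups=8)
    norm = norm_layer(num_channels=64)
    assert isinstance(norm, torch.nn.Module)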
ml4gw/nn/resnet/resnet_1d.py
CHANGED
@@ -7,7 +7,8 @@ where training-time statistics are entirely arbitrary due to
 simulations.
 """
 
-from typing import Callable, List, Literal, Optional
+from collections.abc import Callable
+from typing import Literal
 
 import torch
 import torch.nn as nn
@@ -58,11 +59,11 @@ class BasicBlock(nn.Module):
         planes: int,
         kernel_size: int = 3,
         stride: int = 1,
-        downsample: Optional[nn.Module] = None,
+        downsample: nn.Module | None = None,
         groups: int = 1,
         base_width: int = 64,
         dilation: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Callable[..., nn.Module] | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -123,11 +124,11 @@ class Bottleneck(nn.Module):
         planes: int,
         kernel_size: int = 3,
         stride: int = 1,
-        downsample: Optional[nn.Module] = None,
+        downsample: nn.Module | None = None,
         groups: int = 1,
         base_width: int = 64,
         dilation: int = 1,
-        norm_layer: Optional[NormLayer] = None,
+        norm_layer: NormLayer | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -231,14 +232,14 @@ class ResNet1D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        layers: List[int],
+        layers: list[int],
         classes: int,
         kernel_size: int = 3,
         zero_init_residual: bool = False,
         groups: int = 1,
         width_per_group: int = 64,
-        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
-        norm_layer: Optional[NormLayer] = None,
+        stride_type: list[Literal["stride", "dilation"]] | None = None,
+        norm_layer: NormLayer | None = None,
     ) -> None:
         super().__init__()
 
@@ -257,10 +258,8 @@ class ResNet1D(nn.Module):
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
-                (
-                    "'stride_type' should be None or a {}-element "
-                    "tuple, got {}"
-                ).format(len(layers) - 1, stride_type)
+                f"'stride_type' should be None or a {len(layers) - 1}-element "
+                f"tuple, got {stride_type}"
            )
 
        self.groups = groups
@@ -289,7 +288,7 @@ class ResNet1D(nn.Module):
        # striding or dilating depending on the stride_type
        # argument)
        residual_layers = [self._make_layer(64, layers[0], kernel_size)]
-        it = zip(layers[1:], stride_type)
+        it = zip(layers[1:], stride_type, strict=True)
        for i, (num_blocks, stride) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
ml4gw/nn/resnet/resnet_2d.py
CHANGED
@@ -4,7 +4,8 @@ https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
 but with arbitrary kernel sizes
 """
 
-from typing import Callable, List, Literal, Optional
+from collections.abc import Callable
+from typing import Literal
 
 import torch
 import torch.nn as nn
@@ -55,11 +56,11 @@ class BasicBlock(nn.Module):
         planes: int,
         kernel_size: int = 3,
         stride: int = 1,
-        downsample: Optional[nn.Module] = None,
+        downsample: nn.Module | None = None,
         groups: int = 1,
         base_width: int = 64,
         dilation: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Callable[..., nn.Module] | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -120,11 +121,11 @@ class Bottleneck(nn.Module):
         planes: int,
         kernel_size: int = 3,
         stride: int = 1,
-        downsample: Optional[nn.Module] = None,
+        downsample: nn.Module | None = None,
         groups: int = 1,
         base_width: int = 64,
         dilation: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Callable[..., nn.Module] | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -232,14 +233,14 @@ class ResNet2D(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        layers: List[int],
+        layers: list[int],
         classes: int,
         kernel_size: int = 3,
         zero_init_residual: bool = False,
         groups: int = 1,
         width_per_group: int = 64,
-        stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
-        norm_layer: Optional[NormLayer] = None,
+        stride_type: list[Literal["stride", "dilation"]] | None = None,
+        norm_layer: NormLayer | None = None,
     ) -> None:
         super().__init__()
         # default to using InstanceNorm if no
@@ -257,10 +258,8 @@ class ResNet2D(nn.Module):
            stride_type = ["stride"] * (len(layers) - 1)
        if len(stride_type) != (len(layers) - 1):
            raise ValueError(
-                (
-                    "'stride_type' should be None or a {}-element "
-                    "tuple, got {}"
-                ).format(len(layers) - 1, stride_type)
+                f"'stride_type' should be None or a {len(layers) - 1}-element "
+                f"tuple, got {stride_type}"
            )
 
        self.groups = groups
@@ -289,7 +288,7 @@ class ResNet2D(nn.Module):
        # striding or dilating depending on the stride_type
        # argument)
        residual_layers = [self._make_layer(64, layers[0], kernel_size)]
-        it = zip(layers[1:], stride_type)
+        it = zip(layers[1:], stride_type, strict=True)
        for i, (num_blocks, stride) in enumerate(it):
            block_size = 64 * 2 ** (i + 1)
            layer = self._make_layer(
@@ -316,7 +315,7 @@ class ResNet2D(nn.Module):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
-            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+            elif isinstance(m, nn.BatchNorm2d | nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
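Note: the isinstance change above relies on PEP 604 unions being accepted as the second argument to isinstance from Python 3.10 onward; a minimal sketch:

    import torch.nn as nn

    m = nn.GroupNorm(4, 64)

    # the 0.7.6 tuple form and the 0.7.8 union form are equivalent checks
    assert isinstance(m, (nn.BatchNorm2d, nn.GroupNorm))
    assert isinstance(m, nn.BatchNorm2d | nn.GroupNorm)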
ml4gw/nn/streaming/online_average.py
CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional, Tuple
-
 import torch
 from jaxtyping import Float
 from torch import Tensor
@@ -38,7 +36,7 @@ class OnlineAverager(torch.nn.Module):
         batch_size: int,
         num_updates: int,
         num_channels: int,
-        offset: Optional[int] = None,
+        offset: int | None = None,
     ) -> None:
         super().__init__()
         self.update_size = update_size
@@ -76,8 +74,8 @@ class OnlineAverager(torch.nn.Module):
     def forward(
         self,
         update: Float[Tensor, "batch channel time1"],
-        state: Optional[Float[Tensor, "channel time2"]] = None,
-    ) -> Tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
+        state: Float[Tensor, "channel time2"] | None = None,
+    ) -> tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
         if state is None:
             state = self.get_initial_state()