ml4gw 0.7.7__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. ml4gw/augmentations.py +5 -0
  2. ml4gw/dataloading/__init__.py +5 -0
  3. ml4gw/dataloading/chunked_dataset.py +2 -4
  4. ml4gw/dataloading/hdf5_dataset.py +12 -10
  5. ml4gw/dataloading/in_memory_dataset.py +12 -12
  6. ml4gw/distributions.py +2 -2
  7. ml4gw/gw.py +18 -21
  8. ml4gw/nn/__init__.py +6 -0
  9. ml4gw/nn/autoencoder/base.py +5 -9
  10. ml4gw/nn/autoencoder/convolutional.py +7 -10
  11. ml4gw/nn/autoencoder/skip_connection.py +3 -5
  12. ml4gw/nn/norm.py +4 -4
  13. ml4gw/nn/resnet/resnet_1d.py +12 -13
  14. ml4gw/nn/resnet/resnet_2d.py +13 -14
  15. ml4gw/nn/streaming/online_average.py +3 -5
  16. ml4gw/nn/streaming/snapshotter.py +10 -14
  17. ml4gw/spectral.py +20 -23
  18. ml4gw/transforms/__init__.py +6 -0
  19. ml4gw/transforms/decimator.py +183 -0
  20. ml4gw/transforms/iirfilter.py +3 -5
  21. ml4gw/transforms/pearson.py +3 -4
  22. ml4gw/transforms/qtransform.py +10 -11
  23. ml4gw/transforms/scaler.py +3 -5
  24. ml4gw/transforms/snr_rescaler.py +7 -11
  25. ml4gw/transforms/spectral.py +6 -13
  26. ml4gw/transforms/spectrogram.py +6 -3
  27. ml4gw/transforms/spline_interpolation.py +7 -9
  28. ml4gw/transforms/transform.py +4 -6
  29. ml4gw/transforms/waveforms.py +8 -15
  30. ml4gw/transforms/whitening.py +11 -16
  31. ml4gw/types.py +8 -5
  32. ml4gw/utils/interferometer.py +20 -3
  33. ml4gw/utils/slicing.py +26 -30
  34. ml4gw/waveforms/__init__.py +6 -0
  35. ml4gw/waveforms/cbc/phenom_p.py +7 -9
  36. ml4gw/waveforms/conversion.py +2 -4
  37. ml4gw/waveforms/generator.py +3 -3
  38. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/METADATA +28 -8
  39. ml4gw-0.7.8.dist-info/RECORD +57 -0
  40. ml4gw-0.7.7.dist-info/RECORD +0 -56
  41. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/WHEEL +0 -0
  42. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/licenses/LICENSE +0 -0
  43. {ml4gw-0.7.7.dist-info → ml4gw-0.7.8.dist-info}/top_level.txt +0 -0
ml4gw/augmentations.py CHANGED
@@ -1,3 +1,8 @@
+ """
+ This module contains transformations that may be useful
+ for augmenting timeseries data during training
+ """
+
  import torch
  from jaxtyping import Float
  from torch import Tensor
ml4gw/dataloading/__init__.py CHANGED
@@ -1,3 +1,8 @@
+ """
+ This module contains tools for efficient in-memory and
+ out-of-memory dataloading.
+ """
+
  from .chunked_dataset import ChunkedTimeSeriesDataset
  from .hdf5_dataset import Hdf5TimeSeriesDataset
  from .in_memory_dataset import InMemoryDataset
ml4gw/dataloading/chunked_dataset.py CHANGED
@@ -94,10 +94,8 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
  # flatten it to make it easier to slice
  if chunk_size < self.kernel_size:
      raise ValueError(
-         (
-             "Can't sample kernels of size {} from chunk "
-             "with size {}"
-         ).format(self.kernel_size, chunk_size)
+         f"Can't sample kernels of size {self.kernel_size} from "
+         f"chunk with size {chunk_size}"
      )
  chunk = chunk.reshape(-1)

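Note: a recurring change in this release replaces str.format calls with f-strings in error messages. A minimal sketch of the equivalence, with hypothetical values:

    kernel_size, chunk_size = 1024, 512
    # old style: positional .format arguments
    old = (
        "Can't sample kernels of size {} from chunk with size {}"
    ).format(kernel_size, chunk_size)
    # new style: values interpolated inline
    new = (
        f"Can't sample kernels of size {kernel_size} from "
        f"chunk with size {chunk_size}"
    )
    assert old == new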
ml4gw/dataloading/hdf5_dataset.py CHANGED
@@ -1,5 +1,5 @@
  import warnings
- from typing import Optional, Sequence, Union
+ from collections.abc import Sequence

  import h5py
  import numpy as np
@@ -63,13 +63,13 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
      kernel_size: int,
      batch_size: int,
      batches_per_epoch: int,
-     coincident: Union[bool, str],
-     num_files_per_batch: Optional[int] = None,
+     coincident: bool | str,
+     num_files_per_batch: int | None = None,
  ) -> None:
      if not isinstance(coincident, bool) and coincident != "files":
          raise ValueError(
              "coincident must be either a boolean or 'files', "
-             "got unrecognized value {}".format(coincident)
+             f"got unrecognized value {coincident}"
          )

      self.fnames = np.array(fnames)
@@ -94,13 +94,11 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
  dset = f[channels[0]]
  if dset.chunks is None:
      warnings.warn(
-         "File {} contains datasets that were generated "
+         f"File {fname} contains datasets that were generated "
          "without using chunked storage. This can have "
          "severe performance impacts at data loading time. "
          "If you need faster loading, try re-generating "
-         "your dataset with chunked storage turned on.".format(
-             fname
-         ),
+         "your dataset with chunked storage turned on.",
          category=ContiguousHdf5Warning,
          stacklevel=2,
      )
@@ -153,7 +151,9 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
  unique_fnames, inv, counts = np.unique(
      fnames, return_inverse=True, return_counts=True
  )
- for i, (fname, count) in enumerate(zip(unique_fnames, counts)):
+ for i, (fname, count) in enumerate(
+     zip(unique_fnames, counts, strict=True)
+ ):
      size = self.sizes[fname]
      max_idx = size - self.kernel_size

@@ -185,7 +185,9 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
  # open the file and sample a different set of
  # kernels for each batch element it occupies
  with h5py.File(fname, "r") as f:
-     for b, c, i in zip(batch_indices, channel_indices, idx):
+     for b, c, i in zip(
+         batch_indices, channel_indices, idx, strict=True
+     ):
          x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
  return torch.Tensor(x)

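Note: several hunks above add strict=True to zip calls. By default, zip silently truncates to the shortest iterable; strict=True (Python 3.10+) raises a ValueError on a length mismatch instead of masking it. A minimal sketch of the difference, with hypothetical values:

    fnames = ["a.h5", "b.h5"]
    counts = [3, 5, 7]

    # default zip silently drops the trailing count
    print(list(zip(fnames, counts)))  # [('a.h5', 3), ('b.h5', 5)]

    # strict=True surfaces the mismatch
    try:
        list(zip(fnames, counts, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is longer than argument 1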
ml4gw/dataloading/in_memory_dataset.py CHANGED
@@ -1,5 +1,4 @@
  import itertools
- from typing import Optional, Tuple, Union

  import torch
  from jaxtyping import Float
@@ -79,10 +78,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
      self,
      X: Float[Tensor, "channels time"],
      kernel_size: int,
-     y: Optional[Float[Tensor, " time"]] = None,
+     y: Float[Tensor, " time"] | None = None,
      batch_size: int = 32,
      stride: int = 1,
-     batches_per_epoch: Optional[int] = None,
+     batches_per_epoch: int | None = None,
      coincident: bool = True,
      shuffle: bool = True,
      device: str = "cpu",
@@ -122,10 +121,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
      batch_size * batches_per_epoch
  ):
      raise ValueError(
-         "Number of kernels {} in timeseries insufficient "
-         "to generate {} batches of size {}".format(
-             self.num_kernels, batch_size, batches_per_epoch
-         )
+         f"Number of kernels {self.num_kernels} in timeseries "
+         f"insufficient to generate {batch_size} batches of size "
+         f"{batches_per_epoch}"
      )

  self.batch_size = batch_size
@@ -191,7 +189,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
      # indices we'll need rather than having to generate
      # everything.
      idx = [range(self.num_kernels) for _ in range(len(self.X))]
-     idx = zip(range(num_kernels), itertools.product(*idx))
+     idx = zip(
+         range(num_kernels), itertools.product(*idx), strict=False
+     )
      idx = torch.stack([torch.Tensor(i[1]) for i in idx])
      idx = idx.type(torch.int64).to(device)
  elif self.shuffle:
@@ -208,10 +208,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):

  def __iter__(
      self,
- ) -> Union[
-     Float[Tensor, "batch channel time"],
-     Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
- ]:
+ ) -> (
+     Float[Tensor, "batch channel time"]
+     | tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]]
+ ):
      indices = self.init_indices()
      for i in range(len(self)):
          # slice the array of _indices_ we'll be using to
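Note: the typing changes above follow one pattern applied across the whole release: typing.Optional, Union, List, and Tuple are replaced by PEP 604 unions and PEP 585 builtin generics, and abstract types move from typing to collections.abc. These annotation forms require Python 3.10+ at runtime. A minimal sketch of the correspondence (function and defaults are hypothetical):

    from collections.abc import Sequence

    # Optional[int]    ->  int | None
    # Union[bool, str] ->  bool | str
    # List[int]        ->  list[int]
    # Tuple[int, ...]  ->  tuple[int, ...]
    def example(
        batches_per_epoch: int | None = None,
        coincident: bool | str = True,
        layers: Sequence[int] = (2, 2, 2),
    ) -> tuple[int, ...]:
        return tuple(layers)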
ml4gw/distributions.py CHANGED
@@ -6,7 +6,7 @@ from the corresponding distribution.
  """

  import math
- from typing import Callable, Optional
+ from collections.abc import Callable

  import torch
  import torch.distributions as dist
@@ -104,7 +104,7 @@ class LogNormal(dist.LogNormal):
      self,
      mean: float,
      std: float,
-     low: Optional[float] = None,
+     low: float | None = None,
      validate_args=None,
  ):
      self.low = low
ml4gw/gw.py CHANGED
@@ -1,6 +1,7 @@
  """
- Tools for manipulating raw gravitational waveforms
- and projecting them onto interferometer responses.
+ Tools for manipulating raw gravitational waveforms,
+ projecting them onto interferometer responses, and
+ calculating SNRs.
  Much of the projection code is an extension of the
  implementation made available in
  `bilby <https://arxiv.org/abs/1811.02042>`_.
@@ -8,8 +9,6 @@ Specifically code from
  `this module <https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py>`_.
  """ # noqa E501

- from typing import List, Tuple, Union
-
  import torch
  from jaxtyping import Float
  from torch import Tensor
@@ -58,7 +57,7 @@ def compute_antenna_responses(
      psi: BatchTensor,
      phi: BatchTensor,
      detector_tensors: NetworkDetectorTensors,
-     modes: List[str],
+     modes: list[str],
  ) -> Float[Tensor, "batch polarizations num_ifos"]:
      """
      Compute the antenna pattern factors of a batch of
@@ -257,7 +256,7 @@ def compute_observed_strain(

  def get_ifo_geometry(
      *ifos: str,
- ) -> Tuple[NetworkDetectorTensors, NetworkVertices]:
+ ) -> tuple[NetworkDetectorTensors, NetworkVertices]:
      """
      For a given list of interferometer names, retrieve and
      concatenate the associated detector tensors and vertices
@@ -286,8 +285,8 @@ def compute_ifo_snr(
      responses: WaveformTensor,
      psd: PSDTensor,
      sample_rate: float,
-     highpass: Union[float, Float[Tensor, " frequency"], None] = None,
-     lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
+     highpass: float | Float[Tensor, " frequency"] | None = None,
+     lowpass: float | Float[Tensor, " frequency"] | None = None,
  ) -> Float[Tensor, "batch num_ifos"]:
      """Compute the SNRs of a batch of interferometer responses

@@ -367,10 +366,9 @@
          highpass = freqs >= highpass
      elif len(highpass) != integrand.shape[-1]:
          raise ValueError(
-             "Can't apply highpass filter mask with {} frequency bins"
-             "to signal fft with {} frequency bins".format(
-                 len(highpass), integrand.shape[-1]
-             )
+             f"Can't apply highpass filter mask with {len(highpass)} "
+             f"frequency bins to signal fft with {integrand.shape[-1]} "
+             "frequency bins"
          )
      integrand *= highpass.to(integrand.device)
      if lowpass is not None:
@@ -379,10 +377,9 @@
          lowpass = freqs < lowpass
      elif len(lowpass) != integrand.shape[-1]:
          raise ValueError(
-             "Can't apply lowpass filter mask with {} frequency bins"
-             "to signal fft with {} frequency bins".format(
-                 len(lowpass), integrand.shape[-1]
-             )
+             f"Can't apply lowpass filter mask with {len(lowpass)} "
+             f"frequency bins to signal fft with {integrand.shape[-1]} "
+             "frequency bins"
          )
      integrand *= lowpass.to(integrand.device)

@@ -410,8 +407,8 @@ def compute_network_snr(
      responses: WaveformTensor,
      psd: PSDTensor,
      sample_rate: float,
-     highpass: Union[float, Float[Tensor, " frequency"], None] = None,
-     lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
+     highpass: float | Float[Tensor, " frequency"] | None = None,
+     lowpass: float | Float[Tensor, " frequency"] | None = None,
  ) -> BatchTensor:
      """
      Compute the total SNR from a gravitational waveform
@@ -467,11 +464,11 @@

  def reweight_snrs(
      responses: WaveformTensor,
-     target_snrs: Union[float, BatchTensor],
+     target_snrs: float | BatchTensor,
      psd: PSDTensor,
      sample_rate: float,
-     highpass: Union[float, Float[Tensor, " frequency"], None] = None,
-     lowpass: Union[float, Float[Tensor, " frequency"], None] = None,
+     highpass: float | Float[Tensor, " frequency"] | None = None,
+     lowpass: float | Float[Tensor, " frequency"] | None = None,
  ) -> WaveformTensor:
      """Scale interferometer responses such that they have a desired SNR

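Note: the signatures above now accept highpass/lowpass as a float cutoff, a per-bin frequency mask tensor, or None. A hedged usage sketch of the SNR utilities based only on these signatures; the shapes and values below are illustrative assumptions, not taken from the diff:

    import torch
    from ml4gw.gw import compute_network_snr, reweight_snrs

    sample_rate = 2048.0
    responses = torch.randn(8, 2, 4096)  # assumed (batch, num_ifos, time)
    psd = torch.ones(8, 2, 2049)         # assumed PSD with one bin per rfft frequency

    # scalar highpass cutoff in Hz
    snrs = compute_network_snr(responses, psd, sample_rate, highpass=20.0)

    # rescale every response to a target SNR of 12
    rescaled = reweight_snrs(responses, 12.0, psd, sample_rate, highpass=20.0)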
ml4gw/nn/__init__.py CHANGED
@@ -0,0 +1,6 @@
+ """
+ This module contains neural network architectures and
+ architecture components. These can be a good place
+ to get started, rather than defining your own
+ architecture from the start.
+ """
ml4gw/nn/autoencoder/base.py CHANGED
@@ -1,5 +1,4 @@
  from collections.abc import Sequence
- from typing import Optional, Tuple, Union

  import torch
  from torch import Tensor
@@ -28,16 +27,14 @@ class Autoencoder(torch.nn.Module):
  and how they operate.
  """

- def __init__(
-     self, skip_connection: Optional[SkipConnection] = None
- ) -> None:
+ def __init__(self, skip_connection: SkipConnection | None = None) -> None:
      super().__init__()
      self.skip_connection = skip_connection
      self.blocks = torch.nn.ModuleList()

  def encode(
      self, *X: Tensor, return_states: bool = False
- ) -> Union[Tensor, Tuple[Tensor, Sequence]]:
+ ) -> Tensor | tuple[Tensor, Sequence]:
      states = []
      for block in self.blocks:
          if isinstance(X, tuple):
@@ -53,7 +50,7 @@ class Autoencoder(torch.nn.Module):
      return X, states[:-1]
  return X

- def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
+ def decode(self, *X, states: Sequence[Tensor] | None = None) -> Tensor:
      if self.skip_connection is not None and states is None:
          raise ValueError(
              "Must pass intermediate states when autoencoder "
@@ -62,9 +59,8 @@ class Autoencoder(torch.nn.Module):
  elif states is not None:
      if len(states) != len(self.blocks) - 1:
          raise ValueError(
-             "Passed {} intermediate states, expected {}".format(
-                 len(states), len(self.blocks) - 1
-             )
+             f"Passed {len(states)} intermediate states, expected "
+             f"{len(self.blocks) - 1}"
          )

  # Don't skip connect the output layer
ml4gw/nn/autoencoder/convolutional.py CHANGED
@@ -1,5 +1,4 @@
  from collections.abc import Callable, Sequence
- from typing import Optional

  import torch
  from torch import Tensor
@@ -21,9 +20,9 @@ class ConvBlock(Autoencoder):
      groups: int = 1,
      activation: torch.nn.Module = torch.nn.ReLU,
      norm: Module = torch.nn.BatchNorm1d,
-     decode_channels: Optional[int] = None,
-     output_activation: Optional[torch.nn.Module] = None,
-     skip_connection: Optional[SkipConnection] = None,
+     decode_channels: int | None = None,
+     output_activation: torch.nn.Module | None = None,
+     skip_connection: SkipConnection | None = None,
  ) -> None:
      super().__init__(skip_connection=None)

@@ -98,10 +97,10 @@ class ConvolutionalAutoencoder(Autoencoder):
      stride: int = 1,
      groups: int = 1,
      activation: torch.nn.Module = torch.nn.ReLU,
-     output_activation: Optional[torch.nn.Module] = None,
+     output_activation: torch.nn.Module | None = None,
      norm: Module = torch.nn.BatchNorm1d,
-     decode_channels: Optional[int] = None,
-     skip_connection: Optional[SkipConnection] = None,
+     decode_channels: int | None = None,
+     skip_connection: SkipConnection | None = None,
  ) -> None:
      # TODO: how to do this dynamically? Maybe the base
      # architecture looks for overlapping arguments between
@@ -145,9 +144,7 @@ class ConvolutionalAutoencoder(Autoencoder):
      self.blocks.append(block)
      in_channels = channels * groups

- def decode(
-     self, *X, states=None, input_size: Optional[int] = None
- ) -> Tensor:
+ def decode(self, *X, states=None, input_size: int | None = None) -> Tensor:
      X = super().decode(*X, states=states)
      if input_size is not None:
          return match_size(X, input_size)
ml4gw/nn/autoencoder/skip_connection.py CHANGED
@@ -35,13 +35,11 @@ class ConcatSkipConnect(SkipConnection):
  rem = num_channels % self.groups
  if rem:
      raise ValueError(
-         "Number of channels in input tensor {} cannot "
-         "be divided evenly into {} groups".format(
-             num_channels, self.groups
-         )
+         f"Number of channels in input tensor {num_channels} cannot "
+         f"be divided evenly into {self.groups} groups"
      )

  X = torch.split(X, self.groups, dim=1)
  state = torch.split(state, self.groups, dim=1)
- frags = [i for j in zip(X, state) for i in j]
+ frags = [i for j in zip(X, state, strict=True) for i in j]
  return torch.cat(frags, dim=1)
ml4gw/nn/norm.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Callable, Optional
+ from collections.abc import Callable

  import torch
  from jaxtyping import Float
@@ -16,7 +16,7 @@ class GroupNorm1D(torch.nn.Module):
  def __init__(
      self,
      num_channels: int,
-     num_groups: Optional[int] = None,
+     num_groups: int | None = None,
      eps: float = 1e-5,
  ):
      super().__init__()
@@ -77,7 +77,7 @@ class GroupNorm1DGetter:
  for command-line parameterization with jsonargparse.
  """

- def __init__(self, groups: Optional[int] = None) -> None:
+ def __init__(self, groups: int | None = None) -> None:
      self.groups = groups

  def __call__(self, num_channels: int) -> torch.nn.Module:
@@ -96,7 +96,7 @@ class GroupNorm2DGetter:
  for command-line parameterization with jsonargparse.
  """

- def __init__(self, groups: Optional[int] = None) -> None:
+ def __init__(self, groups: int | None = None) -> None:
      self.groups = groups

  def __call__(self, num_channels: int) -> torch.nn.Module:
ml4gw/nn/resnet/resnet_1d.py CHANGED
@@ -7,7 +7,8 @@ where training-time statistics are entirely arbitrary due to
  simulations.
  """

- from typing import Callable, List, Literal, Optional
+ from collections.abc import Callable
+ from typing import Literal

  import torch
  import torch.nn as nn
@@ -58,11 +59,11 @@ class BasicBlock(nn.Module):
      planes: int,
      kernel_size: int = 3,
      stride: int = 1,
-     downsample: Optional[nn.Module] = None,
+     downsample: nn.Module | None = None,
      groups: int = 1,
      base_width: int = 64,
      dilation: int = 1,
-     norm_layer: Optional[Callable[..., nn.Module]] = None,
+     norm_layer: Callable[..., nn.Module] | None = None,
  ) -> None:
      super().__init__()
      if norm_layer is None:
@@ -123,11 +124,11 @@ class Bottleneck(nn.Module):
      planes: int,
      kernel_size: int = 3,
      stride: int = 1,
-     downsample: Optional[nn.Module] = None,
+     downsample: nn.Module | None = None,
      groups: int = 1,
      base_width: int = 64,
      dilation: int = 1,
-     norm_layer: Optional[NormLayer] = None,
+     norm_layer: NormLayer | None = None,
  ) -> None:
      super().__init__()
      if norm_layer is None:
@@ -231,14 +232,14 @@ class ResNet1D(nn.Module):
  def __init__(
      self,
      in_channels: int,
-     layers: List[int],
+     layers: list[int],
      classes: int,
      kernel_size: int = 3,
      zero_init_residual: bool = False,
      groups: int = 1,
      width_per_group: int = 64,
-     stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
-     norm_layer: Optional[NormLayer] = None,
+     stride_type: list[Literal["stride", "dilation"]] | None = None,
+     norm_layer: NormLayer | None = None,
  ) -> None:
      super().__init__()

@@ -257,10 +258,8 @@
      stride_type = ["stride"] * (len(layers) - 1)
  if len(stride_type) != (len(layers) - 1):
      raise ValueError(
-         (
-             "'stride_type' should be None or a {}-element "
-             "tuple, got {}"
-         ).format(len(layers) - 1, stride_type)
+         f"'stride_type' should be None or a {len(layers) - 1}-element "
+         f"tuple, got {stride_type}"
      )

  self.groups = groups
@@ -289,7 +288,7 @@
  # striding or dilating depending on the stride_type
  # argument)
  residual_layers = [self._make_layer(64, layers[0], kernel_size)]
- it = zip(layers[1:], stride_type)
+ it = zip(layers[1:], stride_type, strict=True)
  for i, (num_blocks, stride) in enumerate(it):
      block_size = 64 * 2 ** (i + 1)
      layer = self._make_layer(
ml4gw/nn/resnet/resnet_2d.py CHANGED
@@ -4,7 +4,8 @@ https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
  but with arbitrary kernel sizes
  """

- from typing import Callable, List, Literal, Optional
+ from collections.abc import Callable
+ from typing import Literal

  import torch
  import torch.nn as nn
@@ -55,11 +56,11 @@ class BasicBlock(nn.Module):
      planes: int,
      kernel_size: int = 3,
      stride: int = 1,
-     downsample: Optional[nn.Module] = None,
+     downsample: nn.Module | None = None,
      groups: int = 1,
      base_width: int = 64,
      dilation: int = 1,
-     norm_layer: Optional[Callable[..., nn.Module]] = None,
+     norm_layer: Callable[..., nn.Module] | None = None,
  ) -> None:
      super().__init__()
      if norm_layer is None:
@@ -120,11 +121,11 @@ class Bottleneck(nn.Module):
      planes: int,
      kernel_size: int = 3,
      stride: int = 1,
-     downsample: Optional[nn.Module] = None,
+     downsample: nn.Module | None = None,
      groups: int = 1,
      base_width: int = 64,
      dilation: int = 1,
-     norm_layer: Optional[Callable[..., nn.Module]] = None,
+     norm_layer: Callable[..., nn.Module] | None = None,
  ) -> None:
      super().__init__()
      if norm_layer is None:
@@ -232,14 +233,14 @@ class ResNet2D(nn.Module):
  def __init__(
      self,
      in_channels: int,
-     layers: List[int],
+     layers: list[int],
      classes: int,
      kernel_size: int = 3,
      zero_init_residual: bool = False,
      groups: int = 1,
      width_per_group: int = 64,
-     stride_type: Optional[List[Literal["stride", "dilation"]]] = None,
-     norm_layer: Optional[NormLayer] = None,
+     stride_type: list[Literal["stride", "dilation"]] | None = None,
+     norm_layer: NormLayer | None = None,
  ) -> None:
      super().__init__()
      # default to using InstanceNorm if no
@@ -257,10 +258,8 @@
      stride_type = ["stride"] * (len(layers) - 1)
  if len(stride_type) != (len(layers) - 1):
      raise ValueError(
-         (
-             "'stride_type' should be None or a {}-element "
-             "tuple, got {}"
-         ).format(len(layers) - 1, stride_type)
+         f"'stride_type' should be None or a {len(layers) - 1}-element "
+         f"tuple, got {stride_type}"
      )

  self.groups = groups
@@ -289,7 +288,7 @@ class ResNet2D(nn.Module):
  # striding or dilating depending on the stride_type
  # argument)
  residual_layers = [self._make_layer(64, layers[0], kernel_size)]
- it = zip(layers[1:], stride_type)
+ it = zip(layers[1:], stride_type, strict=True)
  for i, (num_blocks, stride) in enumerate(it):
      block_size = 64 * 2 ** (i + 1)
      layer = self._make_layer(
@@ -316,7 +315,7 @@ class ResNet2D(nn.Module):
  nn.init.kaiming_normal_(
      m.weight, mode="fan_out", nonlinearity="relu"
  )
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ elif isinstance(m, nn.BatchNorm2d | nn.GroupNorm):
      nn.init.constant_(m.weight, 1)
      nn.init.constant_(m.bias, 0)

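Note: the isinstance change above relies on Python 3.10+ accepting a PEP 604 union in place of a tuple of types. Both forms below are equivalent:

    import torch.nn as nn

    m = nn.GroupNorm(4, 64)
    assert isinstance(m, (nn.BatchNorm2d, nn.GroupNorm))  # tuple form
    assert isinstance(m, nn.BatchNorm2d | nn.GroupNorm)   # union form, Python 3.10+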
ml4gw/nn/streaming/online_average.py CHANGED
@@ -1,5 +1,3 @@
- from typing import Optional, Tuple
-
  import torch
  from jaxtyping import Float
  from torch import Tensor
@@ -38,7 +36,7 @@ class OnlineAverager(torch.nn.Module):
      batch_size: int,
      num_updates: int,
      num_channels: int,
-     offset: Optional[int] = None,
+     offset: int | None = None,
  ) -> None:
      super().__init__()
      self.update_size = update_size
@@ -76,8 +74,8 @@
  def forward(
      self,
      update: Float[Tensor, "batch channel time1"],
-     state: Optional[Float[Tensor, "channel time2"]] = None,
- ) -> Tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
+     state: Float[Tensor, "channel time2"] | None = None,
+ ) -> tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
      if state is None:
          state = self.get_initial_state()

ml4gw/nn/streaming/snapshotter.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Optional, Sequence, Tuple
+ from collections.abc import Sequence

  import torch
  from jaxtyping import Float
@@ -58,15 +58,13 @@ class Snapshotter(torch.nn.Module):
      snapshot_size: int,
      stride_size: int,
      batch_size: int,
-     channels_per_snapshot: Optional[Sequence[int]] = None,
+     channels_per_snapshot: Sequence[int] | None = None,
  ) -> None:
      super().__init__()
      if stride_size >= snapshot_size:
          raise ValueError(
-             "Snapshotter can't accommodate stride {} "
-             "which is greater than snapshot size {}".format(
-                 stride_size, snapshot_size
-             )
+             f"Snapshotter can't accommodate stride {stride_size} "
+             f"which is greater than snapshot size {snapshot_size}"
          )

      self.snapshot_size = snapshot_size
@@ -77,9 +75,8 @@ class Snapshotter(torch.nn.Module):
  if channels_per_snapshot is not None:
      if sum(channels_per_snapshot) != num_channels:
          raise ValueError(
-             "Can't break {} channels into {}".format(
-                 num_channels, channels_per_snapshot
-             )
+             f"Can't break {num_channels} channels into "
+             f"{channels_per_snapshot}"
          )
      self.channels_per_snapshot = channels_per_snapshot
  self.num_channels = num_channels
@@ -90,8 +87,8 @@
  def forward(
      self,
      update: Float[Tensor, "channel time1"],
-     snapshot: Optional[Float[Tensor, "channel time2"]] = None,
- ) -> Tuple[Tensor, ...]:
+     snapshot: Float[Tensor, "channel time2"] | None = None,
+ ) -> tuple[Tensor, ...]:
      if snapshot is None:
          snapshot = self.get_initial_state()

@@ -108,9 +105,8 @@
  if self.channels_per_snapshot is not None:
      if snapshots.size(1) != self.num_channels:
          raise ValueError(
-             "Expected {} channels, found {}".format(
-                 self.num_channels, snapshots.size(1)
-             )
+             f"Expected {self.num_channels} channels, found "
+             f"{snapshots.size(1)}"
          )
      snapshots = torch.split(
          snapshots, self.channels_per_snapshot, dim=1
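Note: the sum(channels_per_snapshot) != num_channels check above guards the torch.split call that follows it: when given a sequence of sizes, torch.split partitions a dimension into chunks of exactly those sizes, which must sum to that dimension's length. A minimal sketch with hypothetical shapes:

    import torch

    snapshots = torch.randn(1, 6, 128)  # (batch, channels, time)
    parts = torch.split(snapshots, [2, 3, 1], dim=1)
    print([tuple(p.shape) for p in parts])
    # [(1, 2, 128), (1, 3, 128), (1, 1, 128)]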