ml4gw 0.5.0-py3-none-any.whl → 0.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw has been flagged as potentially problematic; see the package's page on the registry for more details.

ml4gw/augmentations.py CHANGED
@@ -1,4 +1,6 @@
1
1
  import torch
2
+ from jaxtyping import Float
3
+ from torch import Tensor
2
4
 
3
5
 
4
6
  class SignalInverter(torch.nn.Module):
@@ -16,7 +18,9 @@ class SignalInverter(torch.nn.Module):
16
18
  super().__init__()
17
19
  self.prob = prob
18
20
 
19
- def forward(self, X):
21
+ def forward(
22
+ self, X: Float[Tensor, "*batch time"]
23
+ ) -> Float[Tensor, "*batch time"]:
20
24
  mask = torch.rand(size=X.shape[:-1]) < self.prob
21
25
  X[mask] *= -1
22
26
  return X
@@ -37,7 +41,9 @@ class SignalReverser(torch.nn.Module):
37
41
  super().__init__()
38
42
  self.prob = prob
39
43
 
40
- def forward(self, X):
44
+ def forward(
45
+ self, X: Float[Tensor, "*batch time"]
46
+ ) -> Float[Tensor, "*batch time"]:
41
47
  mask = torch.rand(size=X.shape[:-1]) < self.prob
42
48
  X[mask] = X[mask].flip(-1)
43
49
  return X
@@ -2,6 +2,8 @@ from collections.abc import Iterable
2
2
 
3
3
  import torch
4
4
 
5
+ from ml4gw.types import WaveformTensor
6
+
5
7
 
6
8
  class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
7
9
  """
@@ -55,10 +57,10 @@ class ChunkedTimeSeriesDataset(torch.utils.data.IterableDataset):
55
57
  self.coincident = coincident
56
58
  self.device = device
57
59
 
58
- def __len__(self):
60
+ def __len__(self) -> int:
59
61
  return len(self.chunk_it) * self.batches_per_chunk
60
62
 
61
- def __iter__(self):
63
+ def __iter__(self) -> WaveformTensor:
62
64
  it = iter(self.chunk_it)
63
65
  chunk = next(it)
64
66
  num_chunks, num_channels, chunk_size = chunk.shape
@@ -161,7 +161,7 @@ class Hdf5TimeSeriesDataset(torch.utils.data.IterableDataset):
161
161
  x[b, c] = f[self.channels[c]][i : i + self.kernel_size]
162
162
  return torch.Tensor(x)
163
163
 
164
- def __iter__(self) -> torch.Tensor:
164
+ def __iter__(self) -> WaveformTensor:
165
165
  worker_info = torch.utils.data.get_worker_info()
166
166
  if worker_info is None:
167
167
  num_batches = self.batches_per_epoch
@@ -2,8 +2,9 @@ import itertools
2
2
  from typing import Optional, Tuple, Union
3
3
 
4
4
  import torch
5
+ from jaxtyping import Float
6
+ from torch import Tensor
5
7
 
6
- from ml4gw import types
7
8
  from ml4gw.utils.slicing import slice_kernels
8
9
 
9
10
 
@@ -76,9 +77,9 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
76
77
 
77
78
  def __init__(
78
79
  self,
79
- X: types.TimeSeriesTensor,
80
+ X: Float[Tensor, "channels time"],
80
81
  kernel_size: int,
81
- y: Optional[types.ScalarTensor] = None,
82
+ y: Optional[Float[Tensor, " time"]] = None,
82
83
  batch_size: int = 32,
83
84
  stride: int = 1,
84
85
  batches_per_epoch: Optional[int] = None,
@@ -207,7 +208,10 @@ class InMemoryDataset(torch.utils.data.IterableDataset):
207
208
 
208
209
  def __iter__(
209
210
  self,
210
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
211
+ ) -> Union[
212
+ Float[Tensor, "batch channel time"],
213
+ Tuple[Float[Tensor, "batch channel time"], Float[Tensor, " batch"]],
214
+ ]:
211
215
 
212
216
  indices = self.init_indices()
213
217
  for i in range(len(self)):
ml4gw/distributions.py CHANGED
@@ -9,6 +9,8 @@ from typing import Optional
9
9
 
10
10
  import torch
11
11
  import torch.distributions as dist
12
+ from jaxtyping import Float
13
+ from torch import Tensor
12
14
 
13
15
 
14
16
  class Cosine(dist.Distribution):
@@ -31,11 +33,11 @@ class Cosine(dist.Distribution):
31
33
  self.high = torch.as_tensor(high)
32
34
  self.norm = 1 / (torch.sin(self.high) - torch.sin(self.low))
33
35
 
34
- def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
36
+ def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
35
37
  u = torch.rand(sample_shape, device=self.low.device)
36
38
  return torch.arcsin(u / self.norm + torch.sin(self.low))
37
39
 
38
- def log_prob(self, value):
40
+ def log_prob(self, value: float) -> Float[Tensor, ""]:
39
41
  value = torch.as_tensor(value)
40
42
  inside_range = (value >= self.low) & (value <= self.high)
41
43
  return value.cos().log() * inside_range
@@ -164,7 +166,7 @@ class DeltaFunction(dist.Distribution):
164
166
  super().__init__(batch_shape, validate_args=validate_args)
165
167
  self.peak = torch.as_tensor(peak)
166
168
 
167
- def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
169
+ def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
168
170
  return self.peak * torch.ones(
169
171
  sample_shape, device=self.peak.device, dtype=torch.float32
170
172
  )
ml4gw/gw.py CHANGED
@@ -13,27 +13,21 @@ https://github.com/lscsoft/bilby/blob/master/bilby/gw/detector/interferometer.py
13
13
  from typing import List, Tuple, Union
14
14
 
15
15
  import torch
16
- from torchtyping import TensorType
16
+ from jaxtyping import Float
17
+ from torch import Tensor
17
18
 
19
+ from ml4gw.constants import C
18
20
  from ml4gw.types import (
21
+ BatchTensor,
19
22
  NetworkDetectorTensors,
20
23
  NetworkVertices,
21
24
  PSDTensor,
22
- ScalarTensor,
23
25
  TensorGeometry,
24
26
  VectorGeometry,
25
27
  WaveformTensor,
26
28
  )
27
29
  from ml4gw.utils.interferometer import InterferometerGeometry
28
30
 
29
- SPEED_OF_LIGHT = 299792458.0 # m/s
30
-
31
-
32
- # define some tensor shapes we'll reuse a bit
33
- # up front. Need to assign these variables so
34
- # that static linters don't give us name errors
35
- batch = num_ifos = polarizations = time = frequency = space = None # noqa
36
-
37
31
 
38
32
  def outer(x: VectorGeometry, y: VectorGeometry) -> TensorGeometry:
39
33
  """
@@ -62,12 +56,12 @@ polarization_funcs = {
62
56
 
63
57
 
64
58
  def compute_antenna_responses(
65
- theta: ScalarTensor,
66
- psi: ScalarTensor,
67
- phi: ScalarTensor,
59
+ theta: BatchTensor,
60
+ psi: BatchTensor,
61
+ phi: BatchTensor,
68
62
  detector_tensors: NetworkDetectorTensors,
69
63
  modes: List[str],
70
- ) -> TensorType["batch", "polarizations", "num_ifos"]:
64
+ ) -> Float[Tensor, "batch polarizations num_ifos"]:
71
65
  """
72
66
  Compute the antenna pattern factors of a batch of
73
67
  waveforms as a function of the sky parameters of
@@ -147,8 +141,8 @@ def compute_antenna_responses(
147
141
 
148
142
  def shift_responses(
149
143
  responses: WaveformTensor,
150
- theta: ScalarTensor,
151
- phi: ScalarTensor,
144
+ theta: BatchTensor,
145
+ phi: BatchTensor,
152
146
  vertices: NetworkVertices,
153
147
  sample_rate: float,
154
148
  ) -> WaveformTensor:
@@ -166,7 +160,7 @@ def shift_responses(
166
160
  # Divide by c in the second line so that we only
167
161
  # need to multiply the array by a single float
168
162
  dt = -(omega * vertices).sum(axis=-1)
169
- dt *= sample_rate / SPEED_OF_LIGHT
163
+ dt *= sample_rate / C
170
164
  dt = torch.trunc(dt).type(torch.int64)
171
165
 
172
166
  # rolling by gathering implementation based on
@@ -191,13 +185,13 @@ def shift_responses(
191
185
 
192
186
 
193
187
  def compute_observed_strain(
194
- dec: ScalarTensor,
195
- psi: ScalarTensor,
196
- phi: ScalarTensor,
188
+ dec: BatchTensor,
189
+ psi: BatchTensor,
190
+ phi: BatchTensor,
197
191
  detector_tensors: NetworkDetectorTensors,
198
192
  detector_vertices: NetworkVertices,
199
193
  sample_rate: float,
200
- **polarizations: TensorType["batch", "time"],
194
+ **polarizations: Float[Tensor, "batch time"],
201
195
  ) -> WaveformTensor:
202
196
  """
203
197
  Compute the strain timeseries $h(t)$ observed by a network
@@ -289,8 +283,8 @@ def compute_ifo_snr(
289
283
  responses: WaveformTensor,
290
284
  psd: PSDTensor,
291
285
  sample_rate: float,
292
- highpass: Union[float, TensorType["frequency"], None] = None,
293
- ) -> TensorType["batch", "num_ifos"]:
286
+ highpass: Union[float, Float[Tensor, " frequency"], None] = None,
287
+ ) -> Float[Tensor, "batch num_ifos"]:
294
288
  r"""Compute the SNRs of a batch of interferometer responses
295
289
 
296
290
  Compute the signal to noise ratio (SNR) of individual
@@ -390,8 +384,8 @@ def compute_network_snr(
390
384
  responses: WaveformTensor,
391
385
  psd: PSDTensor,
392
386
  sample_rate: float,
393
- highpass: Union[float, TensorType["frequency"], None] = None,
394
- ) -> ScalarTensor:
387
+ highpass: Union[float, Float[Tensor, " frequency"], None] = None,
388
+ ) -> BatchTensor:
395
389
  r"""
396
390
  Compute the total SNR from a gravitational waveform
397
391
  from a network of interferometers. The total SNR for
@@ -437,10 +431,10 @@ def compute_network_snr(
437
431
 
438
432
  def reweight_snrs(
439
433
  responses: WaveformTensor,
440
- target_snrs: Union[float, ScalarTensor],
434
+ target_snrs: Union[float, BatchTensor],
441
435
  psd: PSDTensor,
442
436
  sample_rate: float,
443
- highpass: Union[float, TensorType["frequency"], None] = None,
437
+ highpass: Union[float, Float[Tensor, " frequency"], None] = None,
444
438
  ) -> WaveformTensor:
445
439
  """Scale interferometer responses such that they have a desired SNR
446
440
 
@@ -1,7 +1,8 @@
1
1
  from collections.abc import Sequence
2
- from typing import Optional
2
+ from typing import Optional, Tuple, Union
3
3
 
4
4
  import torch
5
+ from torch import Tensor
5
6
 
6
7
  from ml4gw.nn.autoencoder.skip_connection import SkipConnection
7
8
 
@@ -27,12 +28,16 @@ class Autoencoder(torch.nn.Module):
27
28
  and how they operate.
28
29
  """
29
30
 
30
- def __init__(self, skip_connection: Optional[SkipConnection] = None):
31
+ def __init__(
32
+ self, skip_connection: Optional[SkipConnection] = None
33
+ ) -> None:
31
34
  super().__init__()
32
35
  self.skip_connection = skip_connection
33
36
  self.blocks = torch.nn.ModuleList()
34
37
 
35
- def encode(self, *X: torch.Tensor, return_states: bool = False):
38
+ def encode(
39
+ self, *X: Tensor, return_states: bool = False
40
+ ) -> Union[Tensor, Tuple[Tensor, Sequence]]:
36
41
  states = []
37
42
  for block in self.blocks:
38
43
  if isinstance(X, tuple):
@@ -48,7 +53,7 @@ class Autoencoder(torch.nn.Module):
48
53
  return X, states[:-1]
49
54
  return X
50
55
 
51
- def decode(self, *X, states: Optional[Sequence[torch.Tensor]] = None):
56
+ def decode(self, *X, states: Optional[Sequence[Tensor]] = None) -> Tensor:
52
57
  if self.skip_connection is not None and states is None:
53
58
  raise ValueError(
54
59
  "Must pass intermediate states when autoencoder "
@@ -76,7 +81,7 @@ class Autoencoder(torch.nn.Module):
76
81
  X = self.skip_connection(X, state)
77
82
  return X
78
83
 
79
- def forward(self, *X):
84
+ def forward(self, *X: Tensor) -> Tensor:
80
85
  return_states = self.skip_connection is not None
81
86
  X = self.encode(*X, return_states=return_states)
82
87
  if return_states:
@@ -84,6 +89,6 @@ class Autoencoder(torch.nn.Module):
84
89
  else:
85
90
  states = None
86
91
 
87
- if isinstance(X, torch.Tensor):
92
+ if isinstance(X, Tensor):
88
93
  X = (X,)
89
94
  return self.decode(*X, states=states)
@@ -2,6 +2,7 @@ from collections.abc import Callable, Sequence
2
2
  from typing import Optional
3
3
 
4
4
  import torch
5
+ from torch import Tensor
5
6
 
6
7
  from ml4gw.nn.autoencoder.base import Autoencoder
7
8
  from ml4gw.nn.autoencoder.skip_connection import SkipConnection
@@ -64,12 +65,12 @@ class ConvBlock(Autoencoder):
64
65
  self.encode_norm = norm(out_channels)
65
66
  self.decode_norm = norm(decode_channels)
66
67
 
67
- def encode(self, X):
68
+ def encode(self, X: Tensor) -> Tensor:
68
69
  X = self.encode_layer(X)
69
70
  X = self.encode_norm(X)
70
71
  return self.activation(X)
71
72
 
72
- def decode(self, X):
73
+ def decode(self, X: Tensor) -> Tensor:
73
74
  X = self.decode_layer(X)
74
75
  X = self.decode_norm(X)
75
76
  return self.output_activation(X)
@@ -144,13 +145,15 @@ class ConvolutionalAutoencoder(Autoencoder):
144
145
  self.blocks.append(block)
145
146
  in_channels = channels * groups
146
147
 
147
- def decode(self, *X, states=None, input_size: Optional[int] = None):
148
+ def decode(
149
+ self, *X, states=None, input_size: Optional[int] = None
150
+ ) -> Tensor:
148
151
  X = super().decode(*X, states=states)
149
152
  if input_size is not None:
150
153
  return match_size(X, input_size)
151
154
  return X
152
155
 
153
- def forward(self, X):
156
+ def forward(self, X: Tensor) -> Tensor:
154
157
  input_size = X.size(-1)
155
158
  X = super().forward(X)
156
159
  return match_size(X, input_size)
@@ -1,31 +1,32 @@
1
1
  import torch
2
+ from torch import Tensor
2
3
 
3
4
  from ml4gw.nn.autoencoder.utils import match_size
4
5
 
5
6
 
6
7
  class SkipConnection(torch.nn.Module):
7
- def forward(self, X: torch.Tensor, state: torch.Tensor):
8
+ def forward(self, X: Tensor, state: Tensor) -> Tensor:
8
9
  return match_size(X, state.size(-1))
9
10
 
10
- def get_out_channels(self, in_channels):
11
+ def get_out_channels(self, in_channels: int) -> int:
11
12
  return in_channels
12
13
 
13
14
 
14
15
  class AddSkipConnect(SkipConnection):
15
- def forward(self, X, state):
16
+ def forward(self, X: Tensor, state: Tensor) -> Tensor:
16
17
  X = super().forward(X, state)
17
18
  return X + state
18
19
 
19
20
 
20
21
  class ConcatSkipConnect(SkipConnection):
21
- def __init__(self, groups: int = 1):
22
+ def __init__(self, groups: int = 1) -> None:
22
23
  super().__init__()
23
24
  self.groups = groups
24
25
 
25
- def get_out_channels(self, in_channels):
26
+ def get_out_channels(self, in_channels: int) -> int:
26
27
  return 2 * in_channels
27
28
 
28
- def forward(self, X, state):
29
+ def forward(self, X: Tensor, state: Tensor) -> Tensor:
29
30
  X = super().forward(X, state)
30
31
  if self.groups == 1:
31
32
  return torch.cat([X, state], dim=1)
@@ -1,7 +1,8 @@
1
1
  import torch
2
+ from torch import Tensor
2
3
 
3
4
 
4
- def match_size(X: torch.Tensor, target_size: int):
5
+ def match_size(X: Tensor, target_size: int) -> Tensor:
5
6
  diff = target_size - X.size(-1)
6
7
  left = int(diff // 2)
7
8
  right = diff - left
ml4gw/nn/norm.py CHANGED
@@ -1,6 +1,8 @@
1
1
  from typing import Callable, Optional
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  NormLayer = Callable[[int], torch.nn.Module]
6
8
 
@@ -31,7 +33,9 @@ class GroupNorm1D(torch.nn.Module):
31
33
  self.weight = torch.nn.Parameter(torch.ones(shape))
32
34
  self.bias = torch.nn.Parameter(torch.zeros(shape))
33
35
 
34
- def forward(self, x):
36
+ def forward(
37
+ self, x: Float[Tensor, "batch channel length"]
38
+ ) -> Float[Tensor, "batch channel length"]:
35
39
  if len(x.shape) != 3:
36
40
  raise ValueError(
37
41
  "GroupNorm1D requires 3-dimensional input, "
@@ -1,11 +1,11 @@
1
1
  from typing import Optional, Tuple
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.utils.slicing import unfold_windows
6
8
 
7
- Tensor = torch.Tensor
8
-
9
9
 
10
10
  class OnlineAverager(torch.nn.Module):
11
11
  """
@@ -70,12 +70,14 @@ class OnlineAverager(torch.nn.Module):
70
70
  weights = unfold_windows(weights, weight_size, update_size)
71
71
  self.register_buffer("weights", weights)
72
72
 
73
- def get_initial_state(self):
73
+ def get_initial_state(self) -> Float[Tensor, "channel time"]:
74
74
  return torch.zeros((self.num_channels, self.state_size))
75
75
 
76
76
  def forward(
77
- self, update: torch.Tensor, state: Optional[torch.Tensor] = None
78
- ) -> Tuple[torch.Tensor, torch.Tensor]:
77
+ self,
78
+ update: Float[Tensor, "batch channel time1"],
79
+ state: Optional[Float[Tensor, "channel time2"]] = None,
80
+ ) -> Tuple[Float[Tensor, "channel time3"], Float[Tensor, "channel time4"]]:
79
81
  if state is None:
80
82
  state = self.get_initial_state()
81
83
 
@@ -1,6 +1,8 @@
1
1
  from typing import Optional, Sequence, Tuple
2
2
 
3
3
  import torch
4
+ from jaxtyping import Float
5
+ from torch import Tensor
4
6
 
5
7
  from ml4gw.utils.slicing import unfold_windows
6
8
 
@@ -82,14 +84,14 @@ class Snapshotter(torch.nn.Module):
82
84
  self.channels_per_snapshot = channels_per_snapshot
83
85
  self.num_channels = num_channels
84
86
 
85
- def get_initial_state(self):
87
+ def get_initial_state(self) -> Float[Tensor, "channel time"]:
86
88
  return torch.zeros((self.num_channels, self.state_size))
87
89
 
88
- # TODO: use torchtyping annotations to make
89
- # clear what the expected shapes are
90
90
  def forward(
91
- self, update: torch.Tensor, snapshot: Optional[torch.Tensor] = None
92
- ) -> Tuple[torch.Tensor, ...]:
91
+ self,
92
+ update: Float[Tensor, "channel time1"],
93
+ snapshot: Optional[Float[Tensor, "channel time2"]] = None,
94
+ ) -> Tuple[Tensor, ...]:
93
95
  if snapshot is None:
94
96
  snapshot = self.get_initial_state()
95
97
 
ml4gw/spectral.py CHANGED
@@ -12,14 +12,18 @@ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.csd.html
12
12
  from typing import Optional, Union
13
13
 
14
14
  import torch
15
- from torchtyping import TensorType
15
+ from jaxtyping import Float
16
+ from torch import Tensor
16
17
 
17
- from ml4gw import types
18
+ from ml4gw.types import (
19
+ FrequencySeries1to3d,
20
+ PSDTensor,
21
+ TimeSeries1to3d,
22
+ WaveformTensor,
23
+ )
18
24
 
19
- time = None
20
25
 
21
-
22
- def median(x, axis):
26
+ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
23
27
  """
24
28
  Implements a median calculation that matches numpy's
25
29
  behavior for an even number of elements and includes
@@ -33,7 +37,7 @@ def median(x, axis):
33
37
 
34
38
 
35
39
  def _validate_shapes(
36
- x: torch.Tensor, nperseg: int, y: Optional[torch.Tensor] = None
40
+ x: Tensor, nperseg: int, y: Optional[Tensor] = None
37
41
  ) -> None:
38
42
  if x.shape[-1] < nperseg:
39
43
  raise ValueError(
@@ -83,14 +87,14 @@ def _validate_shapes(
83
87
 
84
88
 
85
89
  def fast_spectral_density(
86
- x: torch.Tensor,
90
+ x: TimeSeries1to3d,
87
91
  nperseg: int,
88
92
  nstride: int,
89
- window: torch.Tensor,
90
- scale: torch.Tensor,
93
+ window: Float[Tensor, " {nperseg//2+1}"],
94
+ scale: float,
91
95
  average: str = "median",
92
- y: Optional[torch.Tensor] = None,
93
- ) -> torch.Tensor:
96
+ y: Optional[TimeSeries1to3d] = None,
97
+ ) -> FrequencySeries1to3d:
94
98
  """
95
99
  Compute the power spectral density of a multichannel
96
100
  timeseries or a batch of multichannel timeseries, or
@@ -107,9 +111,9 @@ def fast_spectral_density(
107
111
  The timeseries tensor whose power spectral density
108
112
  to compute, or for cross spectral density the
109
113
  timeseries whose fft will be conjugated. Can have
110
- shape either
111
- `(batch_size, num_channels, length * sample_rate)`
112
- or `(num_channels, length * sample_rate)`.
114
+ shape `(batch_size, num_channels, length * sample_rate)`,
115
+ `(num_channels, length * sample_rate)`, or
116
+ `(length * sample_rate)`.
113
117
  nperseg:
114
118
  Number of samples included in each FFT window
115
119
  nstride:
@@ -150,7 +154,7 @@ def fast_spectral_density(
150
154
  channel in `x` across _all_ of `x`'s batch elements.
151
155
  Returns:
152
156
  Tensor of power spectral densities of `x` or its cross spectral
153
- density with the timeseries in `y`.
157
+ density with the timeseries in `y`.
154
158
  """
155
159
 
156
160
  _validate_shapes(x, nperseg, y)
@@ -240,17 +244,16 @@ def fast_spectral_density(
240
244
 
241
245
 
242
246
  def spectral_density(
243
- x: torch.Tensor,
247
+ x: TimeSeries1to3d,
244
248
  nperseg: int,
245
249
  nstride: int,
246
- window: torch.Tensor,
247
- scale: torch.Tensor,
250
+ window: Float[Tensor, " {nperseg//2+1}"],
251
+ scale: float,
248
252
  average: str = "median",
249
- ) -> torch.Tensor:
253
+ ) -> FrequencySeries1to3d:
250
254
  """
251
255
  Compute the power spectral density of a multichannel
252
- timeseries or a batch of multichannel timeseries, or
253
- the cross power spectral density of two such timeseries.
256
+ timeseries or a batch of multichannel timeseries.
254
257
  This implementation is exact for all frequency bins, but
255
258
  slower than the fast implementation.
256
259
 
@@ -259,9 +262,9 @@ def spectral_density(
259
262
  The timeseries tensor whose power spectral density
260
263
  to compute, or for cross spectral density the
261
264
  timeseries whose fft will be conjugated. Can have
262
- shape either
263
- `(batch_size, num_channels, length * sample_rate)`
264
- or `(num_channels, length * sample_rate)`.
265
+ shape `(batch_size, num_channels, length * sample_rate)`,
266
+ `(num_channels, length * sample_rate)`, or
267
+ `(length * sample_rate)`.
265
268
  nperseg:
266
269
  Number of samples included in each FFT window
267
270
  nstride:
@@ -336,11 +339,11 @@ def spectral_density(
336
339
 
337
340
 
338
341
  def truncate_inverse_power_spectrum(
339
- psd: types.PSDTensor,
340
- fduration: Union[TensorType["time"], float],
342
+ psd: PSDTensor,
343
+ fduration: Union[Float[Tensor, " time"], float],
341
344
  sample_rate: float,
342
345
  highpass: Optional[float] = None,
343
- ) -> types.PSDTensor:
346
+ ) -> PSDTensor:
344
347
  """
345
348
  Truncate the length of the time domain response
346
349
  of a whitening filter built using the specified
@@ -399,7 +402,7 @@ def truncate_inverse_power_spectrum(
399
402
  q = torch.fft.irfft(inv_asd, n=N, norm="forward", dim=-1)
400
403
 
401
404
  # taper the edges of the TD filter
402
- if isinstance(fduration, torch.Tensor):
405
+ if isinstance(fduration, Tensor):
403
406
  pad = fduration.size(-1) // 2
404
407
  window = fduration
405
408
  else:
@@ -422,8 +425,8 @@ def truncate_inverse_power_spectrum(
422
425
 
423
426
 
424
427
  def normalize_by_psd(
425
- X: types.WaveformTensor,
426
- psd: types.PSDTensor,
428
+ X: WaveformTensor,
429
+ psd: PSDTensor,
427
430
  sample_rate: float,
428
431
  pad: int,
429
432
  ):
@@ -447,12 +450,12 @@ def normalize_by_psd(
447
450
 
448
451
 
449
452
  def whiten(
450
- X: types.WaveformTensor,
451
- psd: types.PSDTensor,
452
- fduration: Union[TensorType["time"], float],
453
+ X: WaveformTensor,
454
+ psd: PSDTensor,
455
+ fduration: Union[Float[Tensor, " time"], float],
453
456
  sample_rate: float,
454
457
  highpass: Optional[float] = None,
455
- ) -> types.WaveformTensor:
458
+ ) -> WaveformTensor:
456
459
  """
457
460
  Whiten a batch of timeseries using the specified
458
461
  background one-sided power spectral densities (PSDs),
@@ -460,7 +463,8 @@ def whiten(
460
463
  `fduration` and possibly to highpass filter.
461
464
 
462
465
  Args:
463
- X: batch of multichannel timeseries to whiten
466
+ X:
467
+ batch of multichannel timeseries to whiten
464
468
  psd:
465
469
  PSDs use to whiten the data. The frequency
466
470
  response of the whitening filter will be roughly
@@ -496,7 +500,7 @@ def whiten(
496
500
 
497
501
  # figure out how much data we'll need to slice
498
502
  # off after whitening
499
- if isinstance(fduration, torch.Tensor):
503
+ if isinstance(fduration, Tensor):
500
504
  pad = fduration.size(-1) // 2
501
505
  else:
502
506
  pad = int(fduration * sample_rate / 2)