ml4gw 0.4.2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


ml4gw/utils/slicing.py CHANGED
@@ -1,25 +1,30 @@
  from typing import Optional, Union

  import torch
+ from jaxtyping import Float, Int64
+ from torch import Tensor
  from torch.nn.functional import unfold
- from torchtyping import TensorType

- # need to define these for flake8 compatibility
- batch = time = channel = None  # noqa
+ from ml4gw.types import (
+     TimeSeries1d,
+     TimeSeries1to3d,
+     TimeSeries2d,
+     TimeSeries3d,
+ )

- TimeSeriesTensor = Union[TensorType["time"], TensorType["channel", "time"]]
-
- BatchTimeSeriesTensor = Union[
-     TensorType["batch", "time"], TensorType["batch", "channel", "time"]
- ]
+ BatchTimeSeriesTensor = Union[Float[Tensor, "batch time"], TimeSeries3d]


  def unfold_windows(
-     x: torch.Tensor,
+     x: TimeSeries1to3d,
      window_size: int,
      stride: int,
      drop_last: bool = True,
- ):
+ ) -> Union[
+     Float[TimeSeries1d, " window"],
+     Float[TimeSeries2d, " window"],
+     Float[TimeSeries3d, " window"],
+ ]:
      """Unfold a timeseries into windows

      Args:
@@ -83,8 +88,8 @@ def unfold_windows(


  def slice_kernels(
-     x: Union[TimeSeriesTensor, TensorType["batch", "channel", "time"]],
-     idx: TensorType[..., torch.int64],
+     x: TimeSeries1to3d,
+     idx: Int64[Tensor, "..."],
      kernel_size: int,
  ) -> BatchTimeSeriesTensor:
      """Slice kernels from single or multichannel timeseries
@@ -96,7 +101,8 @@ def slice_kernels(
      one more dimension than `x`.

      Args:
-         x: The timeseries tensor to slice kernels from
+         x:
+             The timeseries tensor to slice kernels from
          idx:
              The indices in `x` of the first sample of each
              kernel. If `x` is 1D, `idx` must be 1D as well.
@@ -114,6 +120,7 @@ def slice_kernels(
              coincidentally among the channels.
          kernel_size:
              The length of the kernels to slice from the timeseries
+
      Returns:
          A tensor of shape `(batch_size, kernel_size)` if `x` is
          1D and `(batch_size, num_channels, kernel_size)` if `x`
@@ -225,7 +232,7 @@ def slice_kernels(


  def sample_kernels(
-     X: TimeSeriesTensor,
+     X: TimeSeries1to3d,
      kernel_size: int,
      N: Optional[int] = None,
      max_center_offset: Optional[int] = None,
@@ -245,8 +252,9 @@ def sample_kernels(
      either be `None` or be equal to `len(X)`.

      Args:
-         X: The timeseries tensor from which to sample kernels
-         kernel_size: The size of the kernels to sample
+         X:
+             The timeseries tensor from which to sample kernels
+         kernel_size: The size of the kernels to sample
          N:
              The number of kernels to sample. Can be left as
              `None` if `X` is 3D, otherwise must be specified
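The annotations above swap torchtyping for jaxtyping plus the shared aliases in ml4gw.types, but the call signatures of the slicing utilities are unchanged. A minimal usage sketch (tensor sizes are illustrative, and the output shapes noted in comments are the expectations implied by the docstrings shown in this diff):

import torch

from ml4gw.utils.slicing import sample_kernels, slice_kernels, unfold_windows

# Illustrative 2-channel timeseries; the sizes are arbitrary
x = torch.randn(2, 4096)

# Overlapping windows of 1024 samples taken every 256 samples
windows = unfold_windows(x, window_size=1024, stride=256)

# Kernels starting at explicit sample indices; expected shape (3, 2, 1024)
idx = torch.tensor([0, 512, 1024])
kernels = slice_kernels(x, idx, kernel_size=1024)

# N randomly positioned kernels; expected shape (8, 2, 1024)
sampled = sample_kernels(x, kernel_size=1024, N=8)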
@@ -1,3 +1,5 @@
  from .phenom_d import IMRPhenomD
+ from .phenom_p import IMRPhenomPv2
+ from .ringdown import Ringdown
  from .sine_gaussian import SineGaussian
  from .taylorf2 import TaylorF2
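The two new waveform models are re-exported alongside the existing ones, so, assuming the hunk above is the waveforms package __init__, they become importable at package level:

# Assumes the hunk above is ml4gw/waveforms/__init__.py; constructor
# arguments for the new models are not shown in this diff.
from ml4gw.waveforms import (
    IMRPhenomD,
    IMRPhenomPv2,
    Ringdown,
    SineGaussian,
    TaylorF2,
)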
@@ -1,24 +1,26 @@
- from typing import Callable
+ from typing import Callable, Dict, Tuple

  import torch
+ from jaxtyping import Float
+ from torch import Tensor


  class ParameterSampler(torch.nn.Module):
-     def __init__(self, **parameters: Callable):
+     def __init__(self, **parameters: Callable) -> None:
          super().__init__()
          self.parameters = parameters

      def forward(
          self,
          N: int,
-     ):
+     ) -> Dict[str, Float[Tensor, " {N}"]]:
          return {k: v.sample((N,)) for k, v in self.parameters.items()}


  class WaveformGenerator(torch.nn.Module):
      def __init__(
          self, waveform: Callable, parameter_sampler: ParameterSampler
-     ):
+     ) -> None:
          """
          A torch module that generates waveforms from a given waveform function
          and a parameter sampler.
@@ -34,6 +36,8 @@ class WaveformGenerator(torch.nn.Module):
          self.waveform = waveform
          self.parameter_sampler = parameter_sampler

-     def forward(self, N: int):
+     def forward(
+         self, N: int
+     ) -> Tuple[Float[Tensor, "{N} samples"], Dict[str, Float[Tensor, " {N}"]]]:
          parameters = self.parameter_sampler(N)
          return self.waveform(**parameters), parameters
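ParameterSampler takes keyword arguments whose .sample((N,)) call returns a tensor of N values (torch.distributions objects satisfy this), and WaveformGenerator pairs such a sampler with any waveform callable. A minimal sketch, assuming the module lives at ml4gw.waveforms.generator and using a toy sinusoid in place of a real waveform model:

import torch

# Module path and the toy waveform are assumptions for illustration;
# only ParameterSampler and WaveformGenerator themselves appear in this diff.
from ml4gw.waveforms.generator import ParameterSampler, WaveformGenerator


def toy_sine(frequency, amplitude):
    # Stand-in waveform callable: one sinusoid per sampled parameter set
    t = torch.linspace(0, 1, 2048)
    return amplitude[:, None] * torch.sin(2 * torch.pi * frequency[:, None] * t)


sampler = ParameterSampler(
    frequency=torch.distributions.Uniform(32.0, 64.0),
    amplitude=torch.distributions.Uniform(0.5, 1.0),
)
generator = WaveformGenerator(toy_sine, sampler)

# waveforms has shape (10, 2048); parameters maps each name to a (10,) tensor
waveforms, parameters = generator(10)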