ml4gw 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

@@ -0,0 +1,87 @@
1
+ import torch
2
+
3
+ from ml4gw.utils.slicing import unfold_windows
4
+
5
+
6
class ShiftedPearsonCorrelation(torch.nn.Module):
    """
    Compute the [Pearson correlation]
    (https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
    for two equal-length timeseries over a pre-defined number of time
    shifts in each direction. Useful for when you want a
    correlation, but not over every possible shift (i.e.
    a convolution).

    The number of dimensions of the second timeseries `y`
    passed at call time should always be less than or equal
    to the number of dimensions of the first timeseries `x`,
    and each dimension should match the corresponding one of
    `x` in reverse order (i.e. if `x` has shape `(B, C, T)`
    then `y` should either have shape `(T,)`, `(C, T)`, or
    `(B, C, T)`).

    Note that no windowing to either timeseries is applied
    at call time. Users should do any requisite windowing
    beforehand.

    TODOs:
        - Should we perform windowing?
        - Should we support stride > 1?

    Args:
        max_shift:
            The maximum number of 1-step time shifts in
            each direction over which to compute the
            Pearson coefficient. Must be non-negative.
            Output shape will then be `(2 * max_shift + 1, B, C)`,
            where index `k` along the first dimension compares
            `y[..., t]` against `x[..., t + k - max_shift]`, with
            `x` treated as zero outside its original extent.
    """

    def __init__(self, max_shift: int) -> None:
        super().__init__()
        # A negative value would make `forward` crop x instead of
        # padding it, leaving fewer samples than a single window.
        if max_shift < 0:
            raise ValueError(
                f"max_shift must be non-negative, got {max_shift}"
            )
        self.max_shift = max_shift

    def _shape_checks(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """
        Validate that `x` and `y` have compatible shapes.

        Raises:
            ValueError:
                If `x` has more than 3 dimensions, if `y` has
                more dimensions than `x`, or if the trailing
                dimensions of `x` do not match the shape of `y`.
        """
        if x.ndim > 3:
            raise ValueError(
                "Tensor x can only have up to 3 dimensions "
                f"to compute ShiftedPearsonCorrelation. Found {x.ndim}."
            )
        elif y.ndim > x.ndim:
            raise ValueError(
                "y may not have more dimensions than x for "
                "ShiftedPearsonCorrelation, but found shapes "
                f"{y.shape} and {x.shape}"
            )

        # y must match the trailing y.ndim dimensions of x exactly
        if x.shape[x.ndim - y.ndim :] != y.shape:
            raise ValueError(
                "x and y expected to have same size along "
                f"last dimensions, but found shapes {x.shape} and {y.shape}"
            )

    # TODO: torchtyping annotate
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute the Pearson correlation of `y` with `x` at every
        shift in `[-max_shift, max_shift]`.

        Args:
            x:
                Timeseries to shift, with time along the last
                dimension and at most 3 dimensions in total.
            y:
                Reference timeseries whose shape matches the
                trailing dimensions of `x`.

        Returns:
            Tensor whose first dimension has length
            `2 * max_shift + 1` (one entry per shift), followed by
            the non-time dimensions of `x`. A window or `y` with
            zero variance yields 0 / 0 = NaN at that position.

        Raises:
            ValueError: if the shapes of `x` and `y` are incompatible.
        """
        self._shape_checks(x, y)
        num_samples = x.size(-1)

        # zero-pad x along the time dimension so that it has
        # length num_samples + 2 * max_shift
        pad = (self.max_shift, self.max_shift)
        x = torch.nn.functional.pad(x, pad)

        # num_windows x batch x channels x time, where
        # num_windows == 2 * max_shift + 1 for a stride of 1
        x = unfold_windows(x, num_samples, 1)

        # now compute the correlation between each window
        # of x and the single window of y. Start by de-meaning
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)

        # apply formula and sum along time dimension to give final shape
        # num_windows x batch x channels
        corr = (x * y).sum(dim=-1)
        norm = (x**2).sum(dim=-1) * (y**2).sum(dim=-1)

        return corr / norm**0.5
@@ -0,0 +1,162 @@
1
+ import warnings
2
+ from typing import Dict, List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torchaudio.transforms import Spectrogram
7
+
8
+
9
class MultiResolutionSpectrogram(torch.nn.Module):
    """
    Create a batch of multi-resolution spectrograms
    from a batch of timeseries. Input is expected to
    have the shape `(B, C, T)`, where `B` is the number
    of batches, `C` is the number of channels, and `T`
    is the number of time samples.

    For each timeseries, calculate multiple normalized
    spectrograms based on the `Spectrogram` `kwargs` given.
    Combine the spectrograms by taking the maximum value
    from the nearest time-frequency bin.

    If the largest number of time bins among the spectrograms
    is `N` and the largest number of frequency bins is `M`,
    the output will have dimensions `(B, C, M, N)`

    Args:
        kernel_length:
            The length in seconds of the time dimension
            of the tensor that will be turned into a
            spectrogram. `kernel_length * sample_rate` is
            rounded to the nearest whole number of samples.
        sample_rate:
            The sample rate of the timeseries in Hz
        kwargs:
            Arguments passed in kwargs will be used to create
            `torchaudio.transforms.Spectrogram`s. Each
            argument should be a list of values; a single
            non-list (and non-tuple) value is treated as a
            list of length one. Any list of length greater
            than 1 should be the same length
    """

    def __init__(
        self, kernel_length: float, sample_rate: float, **kwargs
    ) -> None:
        super().__init__()
        # Round to a whole number of samples. Keeping the raw float
        # product would make the length check in `forward` impossible
        # to satisfy whenever floating-point error creeps in (e.g.
        # 0.3 * 10 == 3.0000000000000004), and it would disagree with
        # the integer length of the dummy input used below.
        self.kernel_size = int(round(kernel_length * sample_rate))

        # Accept a bare value as shorthand for a length-one list
        kwargs = {k: self._as_list(v) for k, v in kwargs.items()}

        # This method of combination makes sense only when
        # the spectrograms are normalized, so enforce this
        if "normalized" in kwargs:
            if not all(kwargs["normalized"]):
                raise ValueError(
                    "Received a value of False for 'normalized'. "
                    "This method of combination is sensible only for "
                    "normalized spectrograms."
                )
        else:
            kwargs["normalized"] = [True]
        self.kwargs = self._check_and_format_kwargs(kwargs)

        self.transforms = torch.nn.ModuleList(
            [Spectrogram(**k) for k in self.kwargs]
        )

        # Push a dummy 1-d timeseries through each transform to learn
        # the (num_freqs, num_times) shape of each spectrogram
        dummy_input = torch.ones(self.kernel_size)
        self.shapes = torch.tensor(
            [t(dummy_input).shape for t in self.transforms]
        )

        self.num_freqs = max([shape[0] for shape in self.shapes])
        self.num_times = max([shape[1] for shape in self.shapes])

        # Zero-padding that would bring each spectrogram up to
        # (num_freqs, num_times). `forward` does not need it because
        # the index maps below never reach the padded region, but the
        # buffers stay registered so previously saved state_dicts
        # continue to load.
        left_pad = torch.zeros(len(self.transforms), dtype=torch.int)
        top_pad = torch.zeros(len(self.transforms), dtype=torch.int)
        bottom_pad = torch.tensor(
            [int(self.num_freqs - shape[0]) for shape in self.shapes]
        )
        right_pad = torch.tensor(
            [int(self.num_times - shape[1]) for shape in self.shapes]
        )
        self.register_buffer("left_pad", left_pad)
        self.register_buffer("top_pad", top_pad)
        self.register_buffer("bottom_pad", bottom_pad)
        self.register_buffer("right_pad", right_pad)

        # Nearest-bin index maps onto the common grid. Entry [m, n, i]
        # of freq_idxs (time_idxs) is the frequency (time) bin of
        # spectrogram i that maps to output bin (m, n). Both have shape
        # (num_freqs, num_times, num_spectrograms), and every entry in
        # slice [..., i] is strictly less than spectrogram i's own
        # number of frequency (time) bins.
        freq_idxs = torch.tensor(
            [
                [int(i * shape[0] / self.num_freqs) for shape in self.shapes]
                for i in range(self.num_freqs)
            ]
        )
        freq_idxs = freq_idxs.repeat(self.num_times, 1, 1).transpose(0, 1)
        time_idxs = torch.tensor(
            [
                [int(i * shape[1] / self.num_times) for shape in self.shapes]
                for i in range(self.num_times)
            ]
        )
        time_idxs = time_idxs.repeat(self.num_freqs, 1, 1)

        self.register_buffer("freq_idxs", freq_idxs)
        self.register_buffer("time_idxs", time_idxs)

    @staticmethod
    def _as_list(value):
        """Return `value` unchanged if it is a list or tuple, else `[value]`."""
        if isinstance(value, (list, tuple)):
            return value
        return [value]

    def _check_and_format_kwargs(self, kwargs: Dict[str, List]) -> List:
        """
        Validate per-spectrogram keyword arguments and transpose them
        into one kwargs dict per spectrogram.

        Bare (non-list, non-tuple) values are treated as length-one
        lists, and length-one lists are broadcast to the common length
        of the others.

        Raises:
            ValueError:
                If the values have more than one distinct length
                other than one.
        """
        if not kwargs:
            return []
        kwargs = {k: self._as_list(v) for k, v in kwargs.items()}
        lengths = sorted({len(v) for v in kwargs.values()})

        if lengths[-1] > 3:
            warnings.warn(
                "Combining too many spectrograms can impede computation time. "
                "If performance is slower than desired, try reducing the "
                "number of spectrograms",
                RuntimeWarning,
            )

        if len(lengths) > 2 or (len(lengths) == 2 and lengths[0] != 1):
            raise ValueError(
                "Spectrogram keyword args should all have the same "
                f"length or be of length one. Got lengths {lengths}"
            )

        if len(lengths) == 2:
            size = lengths[1]
            # length-one lists repeat `size` times; full-length lists once
            kwargs = {k: v * (size // len(v)) for k, v in kwargs.items()}

        return [dict(zip(kwargs, col)) for col in zip(*kwargs.values())]

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Calculate spectrograms of the input tensor and
        combine them into a single spectrogram

        Args:
            X:
                Batch of multichannel timeseries which will
                be used to calculate the multi-resolution
                spectrogram. Should have the shape
                `(B, C, T)`, where `B` is the number of
                batches, `C` is the number of channels,
                and `T` is the number of time samples.

        Returns:
            Tensor of shape `(B, C, num_freqs, num_times)`. Each bin
            holds the maximum, over all spectrograms, of that
            spectrogram's value in its nearest time-frequency bin.

        Raises:
            ValueError:
                If the last dimension of `X` is not `kernel_size`
                samples long.
        """
        if X.shape[-1] != self.kernel_size:
            raise ValueError(
                "Expected time dimension to be "
                f"{self.kernel_size} samples long, got input with "
                f"{X.shape[-1]} samples"
            )

        # Remap each spectrogram onto the common grid on its own, then
        # take the elementwise max. Indexing a stacked tensor with the
        # full 3-d index maps and taking a diagonal would materialise
        # num_spectrograms**2 copies of the output; this keeps it
        # linear. No padding is needed because the index maps for
        # spectrogram i stay inside that spectrogram's own bins.
        remapped_specs = [
            transform(X)[..., freq_idx, time_idx]
            for transform, freq_idx, time_idx in zip(
                self.transforms,
                self.freq_idxs.unbind(-1),
                self.time_idxs.unbind(-1),
            )
        ]
        return torch.stack(remapped_specs, dim=-1).max(dim=-1).values
@@ -123,7 +123,7 @@ class FixedWhiten(FittableSpectralTransform):
123
123
  num_channels: float,
124
124
  kernel_length: float,
125
125
  sample_rate: float,
126
- dtype: torch.dtype = torch.float32,
126
+ dtype: torch.dtype = torch.float64,
127
127
  ) -> None:
128
128
  super().__init__()
129
129
  self.num_channels = num_channels
@@ -1 +1,3 @@
1
+ from .phenom_d import IMRPhenomD
1
2
  from .sine_gaussian import SineGaussian
3
+ from .taylorf2 import TaylorF2