ml4gw 0.7.5__py3-none-any.whl → 0.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -105,10 +105,10 @@ class BasicBlock(nn.Module):
105
105
  class Bottleneck(nn.Module):
106
106
  """
107
107
  Bottleneck blocks implement one extra convolution
108
- compared to basic blocks. In this layers, the `planes`
108
+ compared to basic blocks. In this layers, the ``planes``
109
109
  parameter is generally meant to _downsize_ the number
110
110
  of feature maps first, which then get expanded out to
111
- `planes * Bottleneck.expansion` feature maps at the
111
+ ``planes * Bottleneck.expansion`` feature maps at the
112
112
  output of the layer.
113
113
  """
114
114
 
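The Bottleneck docstring above says each bottleneck block adds one convolution compared to a basic block. That is why the ResNet2D docstring that follows counts total layers as ``2 + sum(layers) * factor``, with a factor of 3 for ``BottleneckResNet`` and 2 for vanilla ``ResNet``. Below is a plain-Python sketch of that arithmetic, plus the ``stride_type`` length rule from the same docstring. The helpers ``resnet_depth`` and ``check_stride_type`` are hypothetical illustrations, not part of ml4gw.

```python
def resnet_depth(layers, block="basic"):
    """Total layer count per the docstring: 2 + sum(layers) * factor."""
    # basic blocks hold 2 convolutions, bottleneck blocks hold 3
    factors = {"basic": 2, "bottleneck": 3}
    if block not in factors:
        raise ValueError(f"unknown block type {block!r}")
    return 2 + sum(layers) * factors[block]


def check_stride_type(layers, stride_type=None):
    """Hypothetical validation of the documented stride_type rule.

    None means strided convolutions everywhere; otherwise stride_type
    needs one entry per layer after the first, each "stride" or "dilation".
    """
    if stride_type is None:
        return ["stride"] * (len(layers) - 1)
    if len(stride_type) != len(layers) - 1:
        raise ValueError(
            f"stride_type needs {len(layers) - 1} entries, got {len(stride_type)}"
        )
    for s in stride_type:
        if s not in ("stride", "dilation"):
            raise ValueError(f"invalid stride_type entry {s!r}")
    return list(stride_type)


print(resnet_depth([3, 4, 6, 3], "bottleneck"))  # 50, i.e. ResNet50
print(resnet_depth([2, 2, 2, 2], "basic"))  # 18, i.e. ResNet18
```

The same ``layers`` list gives different depths for the two block types, so ``[3, 4, 6, 3]`` is ResNet34 with basic blocks and ResNet50 with bottleneck blocks.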
@@ -188,9 +188,9 @@ class ResNet2D(nn.Module):
188
188
  A list representing the number of residual
189
189
  blocks to include in each "layer" of the
190
190
  network. Total layers (e.g. 50 in ResNet50)
191
- is `2 + sum(layers) * factor`, where factor
192
- is `2` for vanilla `ResNet` and `3` for
193
- `BottleneckResNet`.
191
+ is ``2 + sum(layers) * factor``, where factor
192
+ is ``2`` for vanilla ``ResNet`` and ``3`` for
193
+ ``BottleneckResNet``.
194
194
  kernel_size:
195
195
  The size of the convolutional kernel to
196
196
  use in all residual layers. _NOT_ the size
@@ -207,22 +207,22 @@ class ResNet2D(nn.Module):
207
207
  connections between feature maps at subsequent
208
208
  layers rather than global. Generally won't
209
209
  need this to be >1, and wil raise an error if
210
- >1 when using vanilla `ResNet`.
210
+ >1 when using vanilla ``ResNet``.
211
211
  width_per_group:
212
212
  Base width of each of the feature map groups,
213
213
  which is scaled up by the typical expansion
214
214
  factor at each layer of the network. Meaningless
215
- for vanilla `ResNet`.
215
+ for vanilla ``ResNet``.
216
216
  stride_type:
217
217
  Whether to achieve downsampling on the time axis
218
218
  by strided or dilated convolutions for each layer.
219
- If left as `None`, strided convolutions will be
220
- used at each layer. Otherwise, `stride_type` should
221
- be one element shorter than `layers` and indicate either
222
- `stride` or `dilation` for each layer after the first.
219
+ If left as ``None``, strided convolutions will be
220
+ used at each layer. Otherwise, ``stride_type`` should
221
+ be one element shorter than ``layers`` and indicate either
222
+ ``stride`` or ``dilation`` for each layer after the first.
223
223
  norm_groups:
224
224
  The number of groups to use in GroupNorm layers
225
- throughout the model. If left as `-1`, the number
225
+ throughout the model. If left as ``-1``, the number
226
226
  of groups will be equal to the number of channels,
227
227
  making this equilavent to LayerNorm
228
228
  """
@@ -11,7 +11,7 @@ class OnlineAverager(torch.nn.Module):
11
11
  """
12
12
  Module for performing stateful online averaging of
13
13
  batches of overlapping timeseries. At present, the
14
- first `num_updates` predictions produced by this
14
+ first ``num_updates`` predictions produced by this
15
15
  model will underestimate the true average.
16
16
 
17
17
  Args:
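The OnlineAverager docstring warns that the first ``num_updates`` predictions underestimate the true average. One way to see why: if every output sample is divided by a fixed count of overlapping windows, samples that have not yet been covered by that many windows come out too small. The sketch below assumes windows that advance by a fixed ``stride`` and a count of ``window_len // stride``; it illustrates the effect and is not ml4gw's implementation, and the name ``overlap_average`` is hypothetical.

```python
def overlap_average(windows, stride):
    """Average overlapping prediction windows that advance by `stride` samples.

    Every output sample is divided by the same count,
    num_updates = window_len // stride, regardless of how many
    windows have actually covered it so far.
    """
    window_len = len(windows[0])
    if window_len % stride:
        raise ValueError("window length must be a multiple of stride")
    num_updates = window_len // stride
    total_len = stride * (len(windows) - 1) + window_len
    sums = [0.0] * total_len
    for i, window in enumerate(windows):
        start = i * stride
        for j, value in enumerate(window):
            sums[start + j] += value
    return [s / num_updates for s in sums]


avg = overlap_average([[1.0] * 4] * 4, stride=2)
print(avg)  # [0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.5]
```

Every prediction here is 1.0, yet the first samples average to 0.5 because only one window has reached them. In a streaming setting the trailing samples would receive further updates later, so only the leading ones stay biased.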
@@ -12,24 +12,24 @@ class Snapshotter(torch.nn.Module):
12
12
  Model for converting streaming state updates into
13
13
  a batch of overlapping snaphots of a multichannel
14
14
  timeseries. Can support multiple timeseries in a
15
- single state update via the `channels_per_snapshot`
15
+ single state update via the ``channels_per_snapshot``
16
16
  kwarg.
17
17
 
18
18
  Specifically, maps tensors of shape
19
- `(num_channels, batch_size * stride_size)` to a tensor
20
- of shape `(batch_size, num_channels, snapshot_size)`.
21
- If `channels_per_snapshot` is specified, it will return
22
- `len(channels_per_snapshot)` tensors of this shape,
19
+ ``(num_channels, batch_size * stride_size)`` to a tensor
20
+ of shape ``(batch_size, num_channels, snapshot_size)``.
21
+ If ``channels_per_snapshot`` is specified, it will return
22
+ ``len(channels_per_snapshot)`` tensors of this shape,
23
23
  with the channel dimension replaced by the corresponding
24
- value of `channels_per_snapshot`. The last tensor returned
24
+ value of ``channels_per_snapshot``. The last tensor returned
25
25
  at call time will be the current state that can be passed
26
- to the next `forward` call.
26
+ to the next ``forward`` call.
27
27
 
28
28
  Args:
29
29
  num_channels:
30
30
  Number of channels in the timeseries. If
31
- `channels_per_snapshot` is not `None`,
32
- this should be equal to `sum(channels_per_snapshot)`.
31
+ ``channels_per_snapshot`` is not ``None``,
32
+ this should be equal to ``sum(channels_per_snapshot)``.
33
33
  snapshot_size:
34
34
  The size of the output snapshot windows in
35
35
  number of samples
@@ -39,17 +39,17 @@ class Snapshotter(torch.nn.Module):
39
39
  batch_size:
40
40
  The number of snapshots to produce at each
41
41
  update. The last dimension of the input
42
- tensor should have size `batch_size * stride_size`.
42
+ tensor should have size ``batch_size * stride_size``.
43
43
  channels_per_snapshot:
44
44
  How to split up the channels in the timeseries
45
- for different tensors. If left as `None`, all
45
+ for different tensors. If left as ``None``, all
46
46
  the channels will be returned in a single tensor.
47
47
  Otherwise, the channels will be split up into
48
- `len(channels_per_snapshot)` tensors, with each
48
+ ``len(channels_per_snapshot)`` tensors, with each
49
49
  tensor's channel dimension being equal to the
50
- corresponding value in `channels_per_snapshot`.
50
+ corresponding value in ``channels_per_snapshot``.
51
51
  Therefore, if specified, these values should
52
- add up to `num_channels`.
52
+ add up to ``num_channels``.
53
53
  """
54
54
 
55
55
  def __init__(
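The Snapshotter docstring describes a shape mapping: an update of shape ``(num_channels, batch_size * stride_size)`` becomes ``batch_size`` overlapping snapshots of ``snapshot_size`` samples, optionally split by ``channels_per_snapshot``, with the state returned last. The nested-list sketch below assumes the carried state holds the previous ``snapshot_size - stride_size`` samples per channel. That assumption is not stated in the docstring, and ``snapshot`` is a hypothetical helper rather than the module's API.

```python
def snapshot(state, update, snapshot_size, stride_size, batch_size,
             channels_per_snapshot=None):
    """Sketch of the documented Snapshotter shape mapping using nested lists.

    state:  num_channels lists of length snapshot_size - stride_size (assumed)
    update: num_channels lists of length batch_size * stride_size
    Returns the snapshot batch(es) followed by the new state.
    """
    state_len = snapshot_size - stride_size
    if any(len(u) != batch_size * stride_size for u in update):
        raise ValueError("update must have batch_size * stride_size samples")
    x = [s + u for s, u in zip(state, update)]
    # batch_size windows of snapshot_size samples spaced stride_size apart,
    # giving shape (batch_size, num_channels, snapshot_size)
    batch = [
        [ch[i * stride_size : i * stride_size + snapshot_size] for ch in x]
        for i in range(batch_size)
    ]
    # keep the most recent samples for the next call
    new_state = [ch[len(ch) - state_len :] for ch in x]
    if channels_per_snapshot is None:
        return batch, new_state
    if sum(channels_per_snapshot) != len(x):
        raise ValueError("channels_per_snapshot must sum to num_channels")
    outputs, start = [], 0
    for n in channels_per_snapshot:
        outputs.append([snap[start : start + n] for snap in batch])
        start += n
    return (*outputs, new_state)


batch, state = snapshot([[0, 0]], [[1, 2, 3, 4]], snapshot_size=4,
                        stride_size=2, batch_size=2)
print(batch)  # [[[0, 0, 1, 2]], [[1, 2, 3, 4]]]
print(state)  # [[3, 4]]
```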
ml4gw/spectral.py CHANGED
@@ -27,8 +27,8 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
27
27
  """
28
28
  Implements a median calculation that matches numpy's
29
29
  behavior for an even number of elements and includes
30
- the same bias correction used by scipy's implementation.
31
- see https://github.com/scipy/scipy/blob/main/scipy/signal/_spectral_py.py#L2066
30
+ the same bias correction used by
31
+ `scipy's implementation <https://github.com/scipy/scipy/blob/main/scipy/signal/_spectral_py.py#L2066>`_.
32
32
  """ # noqa: E501
33
33
  n = x.shape[axis]
34
34
  ii_2 = 2 * torch.arange(1.0, (n - 1) // 2 + 1)
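The ``median`` docstring refers to scipy's bias correction, and the first lines of the implementation build the same ``ii_2`` sequence that scipy's ``_median_bias`` helper uses. Here is a plain-Python sketch of that factor and the corrected median. It mirrors the scipy formula and is not the torch code in ml4gw. ``median_bias`` and ``corrected_median`` are hypothetical names.

```python
import statistics


def median_bias(n):
    """Bias of the median of n periodogram values relative to their mean
    (same series as scipy's _median_bias)."""
    ii_2 = [2.0 * k for k in range(1, (n - 1) // 2 + 1)]
    return 1.0 + sum(1.0 / (i + 1.0) - 1.0 / i for i in ii_2)


def corrected_median(values):
    # statistics.median averages the two middle values for an even count,
    # like numpy.median; dividing by the bias applies the scipy correction
    return statistics.median(values) / median_bias(len(values))
```

For a single segment the factor is 1, and for three or four segments it is 5/6, so ``corrected_median([1, 2, 3, 4])`` gives 2.5 / (5/6) = 3.0.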
@@ -111,50 +111,50 @@ def fast_spectral_density(
111
111
  The timeseries tensor whose power spectral density
112
112
  to compute, or for cross spectral density the
113
113
  timeseries whose fft will be conjugated. Can have
114
- shape `(batch_size, num_channels, length * sample_rate)`,
115
- `(num_channels, length * sample_rate)`, or
116
- `(length * sample_rate)`.
114
+ shape ``(batch_size, num_channels, length * sample_rate)``,
115
+ ``(num_channels, length * sample_rate)``, or
116
+ ``(length * sample_rate)``.
117
117
  nperseg:
118
118
  Number of samples included in each FFT window
119
119
  nstride:
120
120
  Stride between FFT windows
121
121
  window:
122
122
  Window array to multiply by each FFT window before
123
- FFT computation. Should have length `nperseg // 2 + 1`.
123
+ FFT computation. Should have length ``nperseg // 2 + 1``.
124
124
  scale:
125
125
  Scale factor to multiply the FFT'd data by, related to
126
126
  desired units for output tensor (e.g. letting this equal
127
- `1 / (sample_rate * (window**2).sum())` will give output
128
- units of density, $\\text{Hz}^-1$$.
127
+ ``1 / (sample_rate * (window**2).sum())`` will give output
128
+ units of density, :math``\\text{Hz}^-1``.
129
129
  average:
130
130
  How to aggregate the contributions of each FFT window to
131
- the spectral density. Allowed options are `'mean'` and
132
- `'median'`.
131
+ the spectral density. Allowed options are ``'mean'`` and
132
+ ``'median'``.
133
133
  y:
134
134
  Timeseries tensor to compute cross spectral density
135
- with `x`. If left as `None`, `x`'s power spectral
136
- density will be returned. Otherwise, if `x` is 1D,
137
- `y` must also be 1D. If `x` is 2D, the assumption
135
+ with ``x``. If left as ``None``, ``x``'s power spectral
136
+ density will be returned. Otherwise, if ``x`` is 1D,
137
+ ``y`` must also be 1D. If ``x`` is 2D, the assumption
138
138
  is that this represents a single multi-channel timeseries,
139
- and `y` must be either 2D or 1D. In the former case,
139
+ and ``y`` must be either 2D or 1D. In the former case,
140
140
  the cross-spectral densities of each channel will be
141
- computed individually, so `y` must have the same shape as `x`.
142
- Otherwise, this will compute the CSD of each of `x`'s channels
143
- with `y`. If `x` is 3D, this will be assumed to be a batch
144
- of multi-channel timeseries. In this case, `y` can either
141
+ computed individually, so ``y`` must have the same shape as ``x``.
142
+ Otherwise, this will compute the CSD of each of ``x``'s channels
143
+ with ``y``. If ``x`` is 3D, this will be assumed to be a batch
144
+ of multi-channel timeseries. In this case, ``y`` can either
145
145
  be 3D, in which case each channel of each batch element will
146
146
  have its CSD calculated or 2D, which has two different options.
147
- If `y`'s 0th dimension matches `x`'s 0th dimension, it will
148
- be assumed that `y` represents a batch of 1D timeseries, and
147
+ If ``y``'s 0th dimension matches ``x``'s 0th dimension, it will
148
+ be assumed that ``y`` represents a batch of 1D timeseries, and
149
149
  for each batch element this timeseries will have its CSD with
150
- each channel of the corresponding batch element of `x`
151
- calculated. Otherwise, it sill be assumed that `y` represents
150
+ each channel of the corresponding batch element of ``x``
151
+ calculated. Otherwise, it sill be assumed that ``y`` represents
152
152
  a single multi-channel timeseries, in which case each channel
153
- of `y` will have its CSD calculated with the corresponding
154
- channel in `x` across _all_ of `x`'s batch elements.
153
+ of ``y`` will have its CSD calculated with the corresponding
154
+ channel in ``x`` across _all_ of ``x``'s batch elements.
155
155
  Returns:
156
- Tensor of power spectral densities of `x` or its cross spectral
157
- density with the timeseries in `y`.
156
+ Tensor of power spectral densities of ``x`` or its cross spectral
157
+ density with the timeseries in ``y``.
158
158
  """
159
159
 
160
160
  _validate_shapes(x, nperseg, y)
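The ``y`` argument of ``fast_spectral_density`` has several shape-dependent pairing rules. The sketch below encodes them as a plain-Python classifier over shape tuples, following the order in which the docstring states them. Two checks are assumptions the docstring only implies: a 3D ``y`` must match ``x`` exactly, and a non-batch 2D ``y`` must have one row per channel of ``x``. ``csd_pairing`` is a hypothetical helper, not ml4gw's ``_validate_shapes``.

```python
def csd_pairing(x_shape, y_shape):
    """Classify how y pairs with x for cross spectral density.

    Shapes are tuples whose last entry is the number of time samples.
    """
    if x_shape[-1] != y_shape[-1]:
        raise ValueError("x and y must have the same number of samples")
    xd, yd = len(x_shape), len(y_shape)
    if xd == 1 and yd == 1:
        return "single"  # one CSD between two 1D series
    if xd == 2 and yd == 2 and y_shape == x_shape:
        return "per_channel"  # channel i of x with channel i of y
    if xd == 2 and yd == 1:
        return "each_channel_with_y"  # every x channel with the same y
    if xd == 3 and yd == 3 and y_shape == x_shape:
        return "per_channel_per_batch"
    if xd == 3 and yd == 2:
        if y_shape[0] == x_shape[0]:
            return "batch_of_1d"  # y[b] with every channel of x[b]
        if y_shape[0] == x_shape[1]:
            return "shared_multichannel"  # y[c] with x[:, c] in every batch
    raise ValueError(f"unsupported shapes: x {x_shape}, y {y_shape}")
```

Because the batch check comes first, a 2D ``y`` is read as a batch of 1D series whenever its leading size equals the batch size, even if it also equals the channel count.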
@@ -262,25 +262,25 @@ def spectral_density(
262
262
  The timeseries tensor whose power spectral density
263
263
  to compute, or for cross spectral density the
264
264
  timeseries whose fft will be conjugated. Can have
265
- shape `(batch_size, num_channels, length * sample_rate)`,
266
- `(num_channels, length * sample_rate)`, or
267
- `(length * sample_rate)`.
265
+ shape ``(batch_size, num_channels, length * sample_rate)``,
266
+ ``(num_channels, length * sample_rate)``, or
267
+ ``(length * sample_rate)``.
268
268
  nperseg:
269
269
  Number of samples included in each FFT window
270
270
  nstride:
271
271
  Stride between FFT windows
272
272
  window:
273
273
  Window array to multiply by each FFT window before
274
- FFT computation. Should have length `nperseg // 2 + 1`.
274
+ FFT computation. Should have length ``nperseg // 2 + 1``.
275
275
  scale:
276
276
  Scale factor to multiply the FFT'd data by, related to
277
277
  desired units for output tensor (e.g. letting this equal
278
- `1 / (sample_rate * (window**2).sum())` will give output
279
- units of density, $\\text{Hz}^-1$$.
278
+ ``1 / (sample_rate * (window**2).sum())`` will give output
279
+ units of density, :math:`\\text{Hz}^-1`.
280
280
  average:
281
281
  How to aggregate the contributions of each FFT window to
282
- the spectral density. Allowed options are `'mean'` and
283
- `'median'`.
282
+ the spectral density. Allowed options are ``'mean'`` and
283
+ ``'median'``.
284
284
  """
285
285
 
286
286
  _validate_shapes(x, nperseg)
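The ``scale`` argument ``1 / (sample_rate * (window**2).sum())`` is what turns squared FFT magnitudes into a density in units of 1/Hz. The plain-Python sketch below computes a single one-sided periodogram with that scale and a periodic Hann window, the same window ``scipy.signal.get_window("hann", n)`` returns. A Welch estimate would average such periodograms over segments. ``hann`` and ``one_sided_psd`` are hypothetical helpers written for illustration, not ml4gw functions.

```python
import cmath
import math


def hann(n):
    """Periodic Hann window of n samples."""
    return [0.5 - 0.5 * math.cos(2 * math.pi * k / n) for k in range(n)]


def one_sided_psd(x, sample_rate, window):
    """Density-scaled one-sided periodogram of a single segment (naive DFT)."""
    n = len(x)
    scale = 1.0 / (sample_rate * sum(w * w for w in window))
    xw = [a * w for a, w in zip(x, window)]
    psd = []
    for k in range(n // 2 + 1):
        coeff = sum(v * cmath.exp(-2j * math.pi * k * t / n) for t, v in enumerate(xw))
        p = scale * abs(coeff) ** 2
        if 0 < k < n / 2:  # fold in the matching negative frequency
            p *= 2
        psd.append(p)
    return psd


fs, n = 64.0, 64
x = [2.0 * math.sin(2 * math.pi * 8 * t / n) for t in range(n)]
psd = one_sided_psd(x, fs, hann(n))
# integrating the density over frequency (df = fs / n) recovers the
# sinusoid's mean power, A**2 / 2 = 2.0, up to floating-point error
total_power = sum(psd) * fs / n
```

Dividing by ``sum(window**2)`` rather than ``sum(window)**2`` is what makes the result a density: the integral over frequency reproduces signal power regardless of the window chosen.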
@@ -348,18 +348,18 @@ def truncate_inverse_power_spectrum(
348
348
  """
349
349
  Truncate the length of the time domain response
350
350
  of a whitening filter built using the specified
351
- `psd` so that it has maximum length `fduration`
351
+ ``psd`` so that it has maximum length ``fduration``
352
352
  seconds. This is meant to mitigate the impact
353
353
  of sharp features in the background PSD causing
354
354
  time domain responses longer than the segments
355
355
  to which the whitening filter will be applied.
356
356
 
357
357
  Implementation details adapted from
358
- https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8
358
+ `here <https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8>`_.
359
359
 
360
360
  Args:
361
361
  psd:
362
- The one-sided power spectraul density used
362
+ The one-sided power spectral density used
363
363
  to construct a whitening filter.
364
364
  fduration:
365
365
  Desired length in seconds of the time domain
@@ -375,14 +375,14 @@ def truncate_inverse_power_spectrum(
375
375
  highpass:
376
376
  If specified, will zero out the frequency response
377
377
  of all frequencies below this value in Hz. If left
378
- as `None`, no highpass filtering will be applied.
378
+ as ``None``, no highpass filtering will be applied.
379
379
  lowpass:
380
380
  If specified, will zero out the frequency response
381
381
  of all frequencies above this value in Hz. If left
382
- as `None`, no lowpass filtering will be applied.
382
+ as ``None``, no lowpass filtering will be applied.
383
383
  Returns:
384
384
  The PSD with its time domain response truncated
385
- to `fduration` and any filtered frequencies
385
+ to ``fduration`` and any filtered frequencies
386
386
  tapered.
387
387
  """ # noqa: E501
388
388
 
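The ``highpass`` and ``lowpass`` arguments of ``truncate_inverse_power_spectrum`` zero the whitening filter's response below and above the given frequencies. The sketch below shows only that masking step on an inverse amplitude spectral density, with a hard cut. It does not reproduce the time-domain truncation to ``fduration`` or the tapering the docstring mentions, and ``inverse_asd`` is a hypothetical helper.

```python
def inverse_asd(psd, frequencies, highpass=None, lowpass=None):
    """1 / sqrt(psd), zeroed below `highpass` and above `lowpass` (Hz).

    None for either bound means no filtering on that side.
    """
    out = []
    for p, f in zip(psd, frequencies):
        keep = (highpass is None or f >= highpass) and (lowpass is None or f <= lowpass)
        out.append(1.0 / p ** 0.5 if keep else 0.0)
    return out


print(inverse_asd([4.0] * 5, [0, 1, 2, 3, 4], highpass=1, lowpass=3))
# [0.0, 0.5, 0.5, 0.5, 0.0]
```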
@@ -469,7 +469,7 @@ def whiten(
469
469
  Whiten a batch of timeseries using the specified
470
470
  background one-sided power spectral densities (PSDs),
471
471
  modified to have the desired time domain response length
472
- `fduration` and possibly to highpass/lowpass filter.
472
+ ``fduration`` and possibly to highpass/lowpass filter.
473
473
 
474
474
  Args:
475
475
  X:
@@ -480,11 +480,11 @@ def whiten(
480
480
  the inverse of the square root of this PSD, ensuring
481
481
  that data from the same distribution will have
482
482
  approximately uniform power after whitening.
483
- If 2D, each batch element in `X` will be whitened
483
+ If 2D, each batch element in ``X`` will be whitened
484
484
  using the same PSDs. If 3D, each batch element will
485
485
  be whitened by the PSDs contained along the 0th
486
- dimenion of `psd`, and so the first two dimensions
487
- of `X` and `psd` should match.
486
+ dimenion of ``psd``, and so the first two dimensions
487
+ of ``X`` and ``psd`` should match.
488
488
  fduration:
489
489
  Desired length in seconds of the time domain
490
490
  response of a whitening filter built using
@@ -496,20 +496,20 @@ def whiten(
496
496
  the whitened timeseries to account for filter
497
497
  settle-in time.
498
498
  sample_rate:
499
- Rate at which the data in `X` has been sampled
499
+ Rate at which the data in ``X`` has been sampled
500
500
  highpass:
501
501
  The frequency in Hz at which to highpass filter
502
502
  the data, setting the frequency response in the
503
- whitening filter to 0. If left as `None`, no
503
+ whitening filter to 0. If left as ``None``, no
504
504
  highpass filtering will be applied.
505
505
  lowpass:
506
506
  The frequency in Hz at which to lowpass filter
507
507
  the data, setting the frequency response in the
508
- whitening filter to 0. If left as `None`, no
508
+ whitening filter to 0. If left as ``None``, no
509
509
  lowpass filtering will be applied.
510
510
  Returns:
511
511
  Batch of whitened multichannel timeseries with
512
- `fduration / 2` seconds trimmed from each side.
512
+ ``fduration / 2`` seconds trimmed from each side.
513
513
  """
514
514
 
515
515
  # figure out how much data we'll need to slice
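Two bookkeeping rules from the ``whiten`` docstring can be checked without any signal processing. A 3D ``psd`` must match the first two dimensions of ``X``, and the output loses ``fduration / 2`` seconds from each side. The plain-Python sketch below encodes both. Requiring a 2D ``psd`` to have one row per channel of ``X`` is an assumption based on "each batch element ... whitened using the same PSDs", and ``check_psd_shape`` and ``trimmed_length`` are hypothetical helpers.

```python
def check_psd_shape(x_shape, psd_shape):
    """X is (batch, channels, time); psd is (channels, freqs) or
    (batch, channels, freqs)."""
    if len(psd_shape) == 2:
        # assumed: one shared PSD per channel, reused for every batch element
        if psd_shape[0] != x_shape[1]:
            raise ValueError("2D psd must have one row per channel of X")
    elif len(psd_shape) == 3:
        if tuple(psd_shape[:2]) != tuple(x_shape[:2]):
            raise ValueError("first two dimensions of X and psd must match")
    else:
        raise ValueError("psd must be 2D or 3D")


def trimmed_length(num_samples, fduration, sample_rate):
    """Samples left after trimming fduration / 2 seconds from each side."""
    pad = int(fduration / 2 * sample_rate)
    if 2 * pad >= num_samples:
        raise ValueError("fduration is too long for this timeseries")
    return num_samples - 2 * pad


print(trimmed_length(8 * 2048, fduration=2, sample_rate=2048))  # 12288
```

With 8 s of data at 2048 Hz and ``fduration=2``, one second is removed from each end, leaving 6 s (12288 samples).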
@@ -5,6 +5,6 @@ from .scaler import ChannelWiseScaler
5
5
  from .snr_rescaler import SnrRescaler
6
6
  from .spectral import SpectralDensity
7
7
  from .spectrogram import MultiResolutionSpectrogram
8
- from .spline_interpolation import SplineInterpolate
8
+ from .spline_interpolation import SplineInterpolate1D, SplineInterpolate2D
9
9
  from .waveforms import WaveformProjector, WaveformSampler
10
10
  from .whitening import FixedWhiten, Whiten
@@ -9,7 +9,7 @@ class IIRFilter(torch.nn.Module):
9
9
  r"""
10
10
  IIR digital and analog filter design given order and critical points.
11
11
  Design an Nth-order digital or analog filter and apply it to a signal.
12
- Uses SciPy's `iirfilter` function to create the filter coefficients.
12
+ Uses SciPy's ``iirfilter`` function to create the filter coefficients.
13
13
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html
14
14
 
15
15
  The forward call of this module accepts a batch tensor of shape
@@ -24,8 +24,8 @@ class IIRFilter(torch.nn.Module):
24
24
  default, fs is 2 half-cycles/sample, so these are normalized
25
25
  from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
26
26
  half-cycles / sample). For analog filters, Wn is an angular
27
- frequency (e.g., rad/s). When Wn is a length-2 sequence,`Wn[0]`
28
- must be less than `Wn[1]`.
27
+ frequency (e.g., rad/s). When Wn is a length-2 sequence,``Wn[0]``
28
+ must be less than ``Wn[1]``.
29
29
  rp:
30
30
  For Chebyshev and elliptic filters, provides the maximum ripple in
31
31
  the passband. (dB)
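For digital filters, the IIRFilter docstring expresses ``Wn`` in half-cycles per sample, where 1 is the Nyquist frequency, and requires ``Wn[0] < Wn[1]`` for a length-2 band. The plain-Python sketch below converts critical frequencies in Hz into that normalized form and applies the checks. ``normalized_wn`` is a hypothetical helper. Recent SciPy versions can also do this conversion themselves when ``fs`` is passed to ``scipy.signal.iirfilter``.

```python
def normalized_wn(critical_hz, sample_rate):
    """Convert critical frequencies in Hz to digital Wn (1.0 = Nyquist)."""
    nyquist = sample_rate / 2
    wn = [f / nyquist for f in critical_hz]
    # scipy requires 0 < Wn < 1 for digital filters
    if any(not 0 < w < 1 for w in wn):
        raise ValueError("critical frequencies must lie strictly between 0 and Nyquist")
    if len(wn) == 2 and not wn[0] < wn[1]:
        raise ValueError("Wn[0] must be less than Wn[1]")
    return wn


print(normalized_wn([20, 512], sample_rate=2048))  # [0.01953125, 0.5]
```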
@@ -8,20 +8,19 @@ from ..utils.slicing import unfold_windows
8
8
 
9
9
  class ShiftedPearsonCorrelation(torch.nn.Module):
10
10
  """
11
- Compute the [Pearson correlation]
12
- (https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
11
+ Compute the `Pearson correlation <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_
13
12
  for two equal-length timeseries over a pre-defined number of time
14
13
  shifts in each direction. Useful for when you want a
15
14
  correlation, but not over every possible shift (i.e.
16
15
  a convolution).
17
16
 
18
- The number of dimensions of the second timeseries `y`
17
+ The number of dimensions of the second timeseries ``y``
19
18
  passed at call time should always be less than or equal
20
- to the number of dimensions of the first timeseries `x`,
19
+ to the number of dimensions of the first timeseries ``x``,
21
20
  and each dimension should match the corresponding one of
22
- `x` in reverse order (i.e. if `x` has shape `(B, C, T)`
23
- then `y` should either have shape `(T,)`, `(C, T)`, or
24
- `(B, C, T)`).
21
+ ``x`` in reverse order (i.e. if ``x`` has shape ``(B, C, T)``
22
+ then ``y`` should either have shape ``(T,)``, ``(C, T)``, or
23
+ ``(B, C, T)``).
25
24
 
26
25
  Note that no windowing to either timeseries is applied
27
26
  at call time. Users should do any requisite windowing
@@ -36,7 +35,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
36
35
  The maximum number of 1-step time shifts in
37
36
  each direction over which to compute the
38
37
  Pearson coefficient. Output shape will then
39
- be `(2 * max_shifts + 1, B, C)`.
38
+ be ``(2 * max_shifts + 1, B, C)``.
40
39
  """
41
40
 
42
41
  def __init__(self, max_shift: int) -> None:
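ShiftedPearsonCorrelation computes the Pearson coefficient at every shift from ``-max_shift`` to ``max_shift``, giving ``2 * max_shift + 1`` values per channel. The 1D plain-Python sketch below uses only the samples where both shifted series overlap. That edge handling is an assumption, and ml4gw's batched torch version may treat the edges differently. ``pearson`` and ``shifted_pearson`` are hypothetical helpers.

```python
def pearson(a, b):
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((u - ma) * (v - mb) for u, v in zip(a, b))
    va = sum((u - ma) ** 2 for u in a)
    vb = sum((v - mb) ** 2 for v in b)
    if va == 0 or vb == 0:
        return 0.0  # convention for this sketch: undefined correlation -> 0
    return cov / (va * vb) ** 0.5


def shifted_pearson(x, y, max_shift):
    """Entry k is the correlation at shift s = k - max_shift, pairing
    x[i] with y[i + s] over the samples where both exist."""
    if len(x) != len(y):
        raise ValueError("x and y must have equal length")
    if max_shift >= len(x) - 1:
        raise ValueError("max_shift leaves fewer than two overlapping samples")
    out = []
    for s in range(-max_shift, max_shift + 1):
        if s >= 0:
            out.append(pearson(x[: len(x) - s], y[s:]))
        else:
            out.append(pearson(x[-s:], y[: len(y) + s]))
    return out


x = [1, 3, 2, 5, 4, 6, 2, 7]
y = [0] + x[:-1]  # y lags x by one sample
r = shifted_pearson(x, y, max_shift=2)  # r[3] (shift +1) equals 1.0
```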
@@ -8,7 +8,7 @@ from jaxtyping import Float, Int
8
8
  from torch import Tensor
9
9
 
10
10
  from ..types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
11
- from .spline_interpolation import SplineInterpolate
11
+ from .spline_interpolation import SplineInterpolate1D
12
12
 
13
13
  """
14
14
  All based on https://github.com/gwpy/gwpy/blob/v3.0.8/gwpy/signal/qtransform.py
@@ -22,7 +22,7 @@ class QTile(torch.nn.Module):
22
22
  """
23
23
  Compute the row of Q-tiles for a single Q value and a single
24
24
  frequency for a batch of multi-channel frequency series data.
25
- Should really be called `QRow`, but I want to match GWpy.
25
+ Should really be called ``QRow``, but I want to match GWpy.
26
26
  Input data should have three dimensions or fewer.
27
27
  If fewer, dimensions will be added until the input is
28
28
  three-dimensional.
@@ -112,17 +112,17 @@ class QTile(torch.nn.Module):
112
112
  fseries:
113
113
  Frequency series of data. Should correspond to data with
114
114
  the duration and sample rate used to initialize this object.
115
- Expected input shape is `(B, C, F)`, where F is the number
115
+ Expected input shape is ``(B, C, F)``, where F is the number
116
116
  of samples, C is the number of channels, and B is the number
117
117
  of batches. If less than three-dimensional, axes will be
118
118
  added.
119
119
  norm:
120
120
  The method of normalization. Options are "median", "mean", or
121
- `None`.
121
+ ``None``.
122
122
 
123
123
  Returns:
124
124
  The row of Q-tiles for the given Q and frequency. Output is
125
- three-dimensional: `(B, C, T)`
125
+ three-dimensional: ``(B, C, T)``
126
126
  """
127
127
  if len(fseries.shape) > 3:
128
128
  raise ValueError("Input data has more than 3 dimensions")
@@ -164,7 +164,7 @@ class SingleQTransform(torch.nn.Module):
164
164
  Sample rate of the data in Hz
165
165
  spectrogram_shape:
166
166
  The shape of the interpolated spectrogram, specified as
167
- `(num_f_bins, num_t_bins)`. Because the
167
+ ``(num_f_bins, num_t_bins)``. Because the
168
168
  frequency spacing of the Q-tiles is in log-space, the frequency
169
169
  interpolation is log-spaced as well.
170
170
  q:
@@ -176,14 +176,14 @@ class SingleQTransform(torch.nn.Module):
176
176
  mismatch:
177
177
  The maximum fractional mismatch between neighboring tiles
178
178
  interpolation_method:
179
- The method by which to interpolate each `QTile` to the specified
179
+ The method by which to interpolate each ``QTile`` to the specified
180
180
  number of time and frequency bins. The acceptable values are
181
181
  "bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
182
182
  options will use PyTorch's built-in interpolation modes, while
183
183
  "spline" will use the custom Torch-based implementation in
184
- `ml4gw`, as PyTorch does not have spline-based intertpolation.
184
+ ``ml4gw``, as PyTorch does not have spline-based intertpolation.
185
185
  The "spline" mode is most similar to the results of GWpy's
186
- Q-transform, which uses `scipy` to do spline interpolation.
186
+ Q-transform, which uses ``scipy`` to do spline interpolation.
187
187
  However, it is also the slowest and most memory intensive due to
188
188
  the matrix equation solving steps. Therefore, the default method
189
189
  is "bicubic" as it produces the most similar results while
@@ -260,18 +260,15 @@ class SingleQTransform(torch.nn.Module):
260
260
  )
261
261
  self.qtile_interpolators = torch.nn.ModuleList(
262
262
  [
263
- SplineInterpolate(
263
+ SplineInterpolate1D(
264
264
  kx=3,
265
265
  x_in=torch.arange(0, self.duration, self.duration / tiles),
266
- y_in=torch.arange(len(idx)),
267
266
  x_out=t_out,
268
- y_out=torch.arange(len(idx)),
269
267
  )
270
- for tiles, idx in zip(unique_ntiles, self.stack_idx)
268
+ for tiles in unique_ntiles
271
269
  ]
272
270
  )
273
271
 
274
- t_in = t_out
275
272
  f_in = self.freqs
276
273
  f_out = torch.logspace(
277
274
  math.log10(self.frange[0]),
@@ -279,19 +276,16 @@ class SingleQTransform(torch.nn.Module):
279
276
  self.spectrogram_shape[0],
280
277
  )
281
278
 
282
- self.interpolator = SplineInterpolate(
279
+ self.interpolator = SplineInterpolate1D(
283
280
  kx=3,
284
- ky=3,
285
- x_in=t_in,
286
- y_in=f_in,
287
- x_out=t_out,
288
- y_out=f_out,
281
+ x_in=f_in,
282
+ x_out=f_out,
289
283
  )
290
284
 
291
285
  def get_freqs(self) -> Float[Tensor, " nfreq"]:
292
286
  """
293
287
  Calculate the frequencies that will be used in this transform.
294
- For each frequency, a `QTile` is created.
288
+ For each frequency, a ``QTile`` is created.
295
289
  """
296
290
  minf, maxf = self.frange
297
291
  fcum_mismatch = (
@@ -320,7 +314,7 @@ class SingleQTransform(torch.nn.Module):
320
314
  be slow, so this isn't used yet.
321
315
 
322
316
  Optionally, a pair of frequency values can be specified for
323
- `fsearch_range` to restrict the frequencies in which the maximum
317
+ ``fsearch_range`` to restrict the frequencies in which the maximum
324
318
  energy value is sought.
325
319
  """
326
320
  allowed_dimensions = ["both", "neither", "channel", "batch"]
@@ -360,7 +354,7 @@ class SingleQTransform(torch.nn.Module):
360
354
  ) -> None:
361
355
  """
362
356
  Take the FFT of the input timeseries and calculate the transform
363
- for each `QTile`
357
+ for each ``QTile``
364
358
  """
365
359
  # Computing the FFT with the same normalization and scaling as GWpy
366
360
  X = torch.fft.rfft(X, norm="forward")
@@ -379,14 +373,15 @@ class SingleQTransform(torch.nn.Module):
379
373
  ]
380
374
  time_interped = torch.cat(
381
375
  [
382
- interpolator(qtile)
383
- for qtile, interpolator in zip(
376
+ qtile_interpolator(qtile)
377
+ for qtile, qtile_interpolator in zip(
384
378
  qtiles, self.qtile_interpolators
385
379
  )
386
380
  ],
387
381
  dim=-2,
388
382
  )
389
- return self.interpolator(time_interped)
383
+ # Transpose because the final dimension gets interpolated
384
+ return self.interpolator(time_interped.mT).mT
390
385
  num_f_bins, num_t_bins = self.spectrogram_shape
391
386
  resampled = [
392
387
  F.interpolate(
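In 0.7.7, the hunks above replace the single 2D ``SplineInterpolate`` with ``SplineInterpolate1D`` instances. Time is interpolated first, and the frequency interpolator is then applied to the transposed tensor (``.mT``), because, as the new comment notes, the final dimension is the one that gets interpolated. The plain-Python sketch below shows the same separable pattern. It swaps in piecewise-linear interpolation for splines, so it illustrates the transpose trick and not ml4gw's spline math. All names are hypothetical.

```python
from bisect import bisect_right


def interp_last_axis(rows, x_in, x_out):
    """Piecewise-linear interpolation of each row from x_in to x_out.

    x_in must be strictly increasing; x_out values outside it are clamped.
    """
    result = []
    for row in rows:
        new_row = []
        for x in x_out:
            if x <= x_in[0]:
                new_row.append(row[0])
            elif x >= x_in[-1]:
                new_row.append(row[-1])
            else:
                j = bisect_right(x_in, x) - 1
                t = (x - x_in[j]) / (x_in[j + 1] - x_in[j])
                new_row.append(row[j] + t * (row[j + 1] - row[j]))
        result.append(new_row)
    return result


def transpose(rows):
    return [list(col) for col in zip(*rows)]


def interp_2d(grid, f_in, t_in, f_out, t_out):
    """grid has shape (len(f_in), len(t_in)); time is the last axis."""
    time_interped = interp_last_axis(grid, t_in, t_out)  # (F_in, T_out)
    # move frequency to the last axis, interpolate, then move it back
    return transpose(interp_last_axis(transpose(time_interped), f_in, f_out))
```

Two 1D passes along the last axis combine into bilinear interpolation here. The 1D spline version follows the same structure but uses cubic splines in each pass.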
@@ -416,9 +411,9 @@ class SingleQTransform(torch.nn.Module):
416
411
  X:
417
412
  Time series of data. Should have the duration and sample rate
418
413
  used to initialize this object. Expected input shape is
419
- `(B, C, T)`, where T is the number of samples, C is the number
420
- of channels, and B is the number of batches. If less than
421
- three-dimensional, axes will be added during Q-tile
414
+ ``(B, C, T)``, where T is the number of samples, C is the
415
+ number of channels, and B is the number of batches. If less
416
+ than three-dimensional, axes will be added during Q-tile
422
417
  computation.
423
418
  norm:
424
419
  The method of normalization used by each QTile
@@ -445,13 +440,13 @@ class QScan(torch.nn.Module):
445
440
  Sample rate of the data in Hz
446
441
  spectrogram_shape:
447
442
  The shape of the interpolated spectrogram, specified as
448
- `(num_f_bins, num_t_bins)`. Because the
443
+ ``(num_f_bins, num_t_bins)``. Because the
449
444
  frequency spacing of the Q-tiles is in log-space, the frequency
450
445
  interpolation is log-spaced as well.
451
446
  qrange:
452
447
  The lower and upper values of Q to consider. The
453
448
  actual values of Q used for the transforms are
454
- determined by the `get_qs` method
449
+ determined by the ``get_qs`` method
455
450
  frange:
456
451
  The lower and upper frequency limit to consider for
457
452
  the transform. If unspecified, default values will
@@ -535,9 +530,9 @@ class QScan(torch.nn.Module):
535
530
  X:
536
531
  Time series of data. Should have the duration and sample rate
537
532
  used to initialize this object. Expected input shape is
538
- `(B, C, T)`, where T is the number of samples, C is the number
539
- of channels, and B is the number of batches. If less than
540
- three-dimensional, axes will be added during Q-tile
533
+ ``(B, C, T)``, where T is the number of samples, C is the
534
+ number of channels, and B is the number of batches. If less
535
+ than three-dimensional, axes will be added during Q-tile
541
536
  computation.
542
537
  fsearch_range:
543
538
  The lower and upper frequency values within which to search
@@ -13,14 +13,14 @@ class ChannelWiseScaler(FittableTransform):
13
13
  Scales timeseries channels by the mean and standard
14
14
  deviation of the channels of the timeseries used to
15
15
  fit the module. To reverse the scaling, provide the
16
- `reverse=True` keyword argument at call time.
16
+ ``reverse=True`` keyword argument at call time.
17
17
  By default, the scaling parameters are set to zero mean
18
18
  and unit variance, amounting to an identity transform.
19
19
 
20
20
  Args:
21
21
  num_channels:
22
22
  The number of channels of the target timeseries.
23
- If left as `None`, the timeseries will be assumed
23
+ If left as ``None``, the timeseries will be assumed
24
24
  to be 1D (single channel).
25
25
  """
26
26
 
@@ -42,8 +42,8 @@ class ChannelWiseScaler(FittableTransform):
42
42
  """Fit the scaling parameters to a timeseries
43
43
 
44
44
  Computes the channel-wise mean and standard deviation
45
- of the timeseries `X` and sets these values to the
46
- `mean` and `std` parameters of the scaler.
45
+ of the timeseries ``X`` and sets these values to the
46
+ ``mean`` and ``std`` parameters of the scaler.
47
47
  """
48
48
 
49
49
  if X.ndim == 1:
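ChannelWiseScaler starts as an identity transform, learns a per-channel mean and standard deviation in ``fit``, and undoes the scaling when called with ``reverse=True``. The plain-Python sketch below reproduces that behavior on nested lists. It uses the population standard deviation, which may differ from ml4gw's torch reduction, and does not guard against a zero standard deviation. ``ChannelScalerSketch`` is a hypothetical class, not the ml4gw API.

```python
import statistics


class ChannelScalerSketch:
    """Per-channel standardization with a reverse mode."""

    def __init__(self, num_channels):
        # zero mean and unit std: an identity transform until fit() is called
        self.mean = [0.0] * num_channels
        self.std = [1.0] * num_channels

    def fit(self, X):
        """X is a list of channels, each a list of samples."""
        self.mean = [statistics.fmean(ch) for ch in X]
        self.std = [statistics.pstdev(ch) for ch in X]

    def __call__(self, X, reverse=False):
        pairs = zip(X, self.mean, self.std)
        if reverse:
            return [[v * s + m for v in ch] for ch, m, s in pairs]
        return [[(v - m) / s for v in ch] for ch, m, s in pairs]


X = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
scaler = ChannelScalerSketch(num_channels=2)
scaler.fit(X)
Z = scaler(X)  # each channel now has zero mean and unit std
X_back = scaler(Z, reverse=True)  # recovers X up to rounding
```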