ml4gw 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

ml4gw/spectral.py CHANGED
@@ -27,8 +27,8 @@ def median(x: Float[Tensor, "... size"], axis: int) -> Float[Tensor, "..."]:
27
27
  """
28
28
  Implements a median calculation that matches numpy's
29
29
  behavior for an even number of elements and includes
30
- the same bias correction used by scipy's implementation.
31
- see https://github.com/scipy/scipy/blob/main/scipy/signal/_spectral_py.py#L2066
30
+ the same bias correction used by
31
+ `scipy's implementation <https://github.com/scipy/scipy/blob/main/scipy/signal/_spectral_py.py#L2066>`_.
32
32
  """ # noqa: E501
33
33
  n = x.shape[axis]
34
34
  ii_2 = 2 * torch.arange(1.0, (n - 1) // 2 + 1)
@@ -111,50 +111,50 @@ def fast_spectral_density(
111
111
  The timeseries tensor whose power spectral density
112
112
  to compute, or for cross spectral density the
113
113
  timeseries whose fft will be conjugated. Can have
114
- shape `(batch_size, num_channels, length * sample_rate)`,
115
- `(num_channels, length * sample_rate)`, or
116
- `(length * sample_rate)`.
114
+ shape ``(batch_size, num_channels, length * sample_rate)``,
115
+ ``(num_channels, length * sample_rate)``, or
116
+ ``(length * sample_rate)``.
117
117
  nperseg:
118
118
  Number of samples included in each FFT window
119
119
  nstride:
120
120
  Stride between FFT windows
121
121
  window:
122
122
  Window array to multiply by each FFT window before
123
- FFT computation. Should have length `nperseg // 2 + 1`.
123
+ FFT computation. Should have length ``nperseg // 2 + 1``.
124
124
  scale:
125
125
  Scale factor to multiply the FFT'd data by, related to
126
126
  desired units for output tensor (e.g. letting this equal
127
- `1 / (sample_rate * (window**2).sum())` will give output
128
- units of density, $\\text{Hz}^-1$$.
127
+ ``1 / (sample_rate * (window**2).sum())`` will give output
128
+ units of density, :math:`\\text{Hz}^{-1}`.
129
129
  average:
130
130
  How to aggregate the contributions of each FFT window to
131
- the spectral density. Allowed options are `'mean'` and
132
- `'median'`.
131
+ the spectral density. Allowed options are ``'mean'`` and
132
+ ``'median'``.
133
133
  y:
134
134
  Timeseries tensor to compute cross spectral density
135
- with `x`. If left as `None`, `x`'s power spectral
136
- density will be returned. Otherwise, if `x` is 1D,
137
- `y` must also be 1D. If `x` is 2D, the assumption
135
+ with ``x``. If left as ``None``, ``x``'s power spectral
136
+ density will be returned. Otherwise, if ``x`` is 1D,
137
+ ``y`` must also be 1D. If ``x`` is 2D, the assumption
138
138
  is that this represents a single multi-channel timeseries,
139
- and `y` must be either 2D or 1D. In the former case,
139
+ and ``y`` must be either 2D or 1D. In the former case,
140
140
  the cross-spectral densities of each channel will be
141
- computed individually, so `y` must have the same shape as `x`.
142
- Otherwise, this will compute the CSD of each of `x`'s channels
143
- with `y`. If `x` is 3D, this will be assumed to be a batch
144
- of multi-channel timeseries. In this case, `y` can either
141
+ computed individually, so ``y`` must have the same shape as ``x``.
142
+ Otherwise, this will compute the CSD of each of ``x``'s channels
143
+ with ``y``. If ``x`` is 3D, this will be assumed to be a batch
144
+ of multi-channel timeseries. In this case, ``y`` can either
145
145
  be 3D, in which case each channel of each batch element will
146
146
  have its CSD calculated or 2D, which has two different options.
147
- If `y`'s 0th dimension matches `x`'s 0th dimension, it will
148
- be assumed that `y` represents a batch of 1D timeseries, and
147
+ If ``y``'s 0th dimension matches ``x``'s 0th dimension, it will
148
+ be assumed that ``y`` represents a batch of 1D timeseries, and
149
149
  for each batch element this timeseries will have its CSD with
150
- each channel of the corresponding batch element of `x`
151
- calculated. Otherwise, it sill be assumed that `y` represents
150
+ each channel of the corresponding batch element of ``x``
151
+ calculated. Otherwise, it will be assumed that ``y`` represents
152
152
  a single multi-channel timeseries, in which case each channel
153
- of `y` will have its CSD calculated with the corresponding
154
- channel in `x` across _all_ of `x`'s batch elements.
153
+ of ``y`` will have its CSD calculated with the corresponding
154
+ channel in ``x`` across _all_ of ``x``'s batch elements.
155
155
  Returns:
156
- Tensor of power spectral densities of `x` or its cross spectral
157
- density with the timeseries in `y`.
156
+ Tensor of power spectral densities of ``x`` or its cross spectral
157
+ density with the timeseries in ``y``.
158
158
  """
159
159
 
160
160
  _validate_shapes(x, nperseg, y)
@@ -262,25 +262,25 @@ def spectral_density(
262
262
  The timeseries tensor whose power spectral density
263
263
  to compute, or for cross spectral density the
264
264
  timeseries whose fft will be conjugated. Can have
265
- shape `(batch_size, num_channels, length * sample_rate)`,
266
- `(num_channels, length * sample_rate)`, or
267
- `(length * sample_rate)`.
265
+ shape ``(batch_size, num_channels, length * sample_rate)``,
266
+ ``(num_channels, length * sample_rate)``, or
267
+ ``(length * sample_rate)``.
268
268
  nperseg:
269
269
  Number of samples included in each FFT window
270
270
  nstride:
271
271
  Stride between FFT windows
272
272
  window:
273
273
  Window array to multiply by each FFT window before
274
- FFT computation. Should have length `nperseg // 2 + 1`.
274
+ FFT computation. Should have length ``nperseg // 2 + 1``.
275
275
  scale:
276
276
  Scale factor to multiply the FFT'd data by, related to
277
277
  desired units for output tensor (e.g. letting this equal
278
- `1 / (sample_rate * (window**2).sum())` will give output
279
- units of density, $\\text{Hz}^-1$$.
278
+ ``1 / (sample_rate * (window**2).sum())`` will give output
279
+ units of density, :math:`\\text{Hz}^{-1}`.
280
280
  average:
281
281
  How to aggregate the contributions of each FFT window to
282
- the spectral density. Allowed options are `'mean'` and
283
- `'median'`.
282
+ the spectral density. Allowed options are ``'mean'`` and
283
+ ``'median'``.
284
284
  """
285
285
 
286
286
  _validate_shapes(x, nperseg)
@@ -348,18 +348,18 @@ def truncate_inverse_power_spectrum(
348
348
  """
349
349
  Truncate the length of the time domain response
350
350
  of a whitening filter built using the specified
351
- `psd` so that it has maximum length `fduration`
351
+ ``psd`` so that it has maximum length ``fduration``
352
352
  seconds. This is meant to mitigate the impact
353
353
  of sharp features in the background PSD causing
354
354
  time domain responses longer than the segments
355
355
  to which the whitening filter will be applied.
356
356
 
357
357
  Implementation details adapted from
358
- https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8
358
+ `here <https://github.com/vivinousi/gw-detection-deep-learning/blob/203966cc2ee47c32c292be000fb009a16824b7d9/modules/whiten.py#L8>`_.
359
359
 
360
360
  Args:
361
361
  psd:
362
- The one-sided power spectraul density used
362
+ The one-sided power spectral density used
363
363
  to construct a whitening filter.
364
364
  fduration:
365
365
  Desired length in seconds of the time domain
@@ -375,14 +375,14 @@ def truncate_inverse_power_spectrum(
375
375
  highpass:
376
376
  If specified, will zero out the frequency response
377
377
  of all frequencies below this value in Hz. If left
378
- as `None`, no highpass filtering will be applied.
378
+ as ``None``, no highpass filtering will be applied.
379
379
  lowpass:
380
380
  If specified, will zero out the frequency response
381
381
  of all frequencies above this value in Hz. If left
382
- as `None`, no lowpass filtering will be applied.
382
+ as ``None``, no lowpass filtering will be applied.
383
383
  Returns:
384
384
  The PSD with its time domain response truncated
385
- to `fduration` and any filtered frequencies
385
+ to ``fduration`` and any filtered frequencies
386
386
  tapered.
387
387
  """ # noqa: E501
388
388
 
@@ -469,7 +469,7 @@ def whiten(
469
469
  Whiten a batch of timeseries using the specified
470
470
  background one-sided power spectral densities (PSDs),
471
471
  modified to have the desired time domain response length
472
- `fduration` and possibly to highpass/lowpass filter.
472
+ ``fduration`` and possibly to highpass/lowpass filter.
473
473
 
474
474
  Args:
475
475
  X:
@@ -480,11 +480,11 @@ def whiten(
480
480
  the inverse of the square root of this PSD, ensuring
481
481
  that data from the same distribution will have
482
482
  approximately uniform power after whitening.
483
- If 2D, each batch element in `X` will be whitened
483
+ If 2D, each batch element in ``X`` will be whitened
484
484
  using the same PSDs. If 3D, each batch element will
485
485
  be whitened by the PSDs contained along the 0th
486
- dimenion of `psd`, and so the first two dimensions
487
- of `X` and `psd` should match.
486
+ dimension of ``psd``, and so the first two dimensions
487
+ of ``X`` and ``psd`` should match.
488
488
  fduration:
489
489
  Desired length in seconds of the time domain
490
490
  response of a whitening filter built using
@@ -496,20 +496,20 @@ def whiten(
496
496
  the whitened timeseries to account for filter
497
497
  settle-in time.
498
498
  sample_rate:
499
- Rate at which the data in `X` has been sampled
499
+ Rate at which the data in ``X`` has been sampled
500
500
  highpass:
501
501
  The frequency in Hz at which to highpass filter
502
502
  the data, setting the frequency response in the
503
- whitening filter to 0. If left as `None`, no
503
+ whitening filter to 0. If left as ``None``, no
504
504
  highpass filtering will be applied.
505
505
  lowpass:
506
506
  The frequency in Hz at which to lowpass filter
507
507
  the data, setting the frequency response in the
508
- whitening filter to 0. If left as `None`, no
508
+ whitening filter to 0. If left as ``None``, no
509
509
  lowpass filtering will be applied.
510
510
  Returns:
511
511
  Batch of whitened multichannel timeseries with
512
- `fduration / 2` seconds trimmed from each side.
512
+ ``fduration / 2`` seconds trimmed from each side.
513
513
  """
514
514
 
515
515
  # figure out how much data we'll need to slice
@@ -9,7 +9,7 @@ class IIRFilter(torch.nn.Module):
9
9
  r"""
10
10
  IIR digital and analog filter design given order and critical points.
11
11
  Design an Nth-order digital or analog filter and apply it to a signal.
12
- Uses SciPy's `iirfilter` function to create the filter coefficients.
12
+ Uses SciPy's ``iirfilter`` function to create the filter coefficients.
13
13
  https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirfilter.html
14
14
 
15
15
  The forward call of this module accepts a batch tensor of shape
@@ -24,8 +24,8 @@ class IIRFilter(torch.nn.Module):
24
24
  default, fs is 2 half-cycles/sample, so these are normalized
25
25
  from 0 to 1, where 1 is the Nyquist frequency. (Wn is thus in
26
26
  half-cycles / sample). For analog filters, Wn is an angular
27
- frequency (e.g., rad/s). When Wn is a length-2 sequence,`Wn[0]`
28
- must be less than `Wn[1]`.
27
+ frequency (e.g., rad/s). When Wn is a length-2 sequence, ``Wn[0]``
28
+ must be less than ``Wn[1]``.
29
29
  rp:
30
30
  For Chebyshev and elliptic filters, provides the maximum ripple in
31
31
  the passband. (dB)
@@ -8,20 +8,19 @@ from ..utils.slicing import unfold_windows
8
8
 
9
9
  class ShiftedPearsonCorrelation(torch.nn.Module):
10
10
  """
11
- Compute the [Pearson correlation]
12
- (https://en.wikipedia.org/wiki/Pearson_correlation_coefficient)
11
+ Compute the `Pearson correlation <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_
13
12
  for two equal-length timeseries over a pre-defined number of time
14
13
  shifts in each direction. Useful for when you want a
15
14
  correlation, but not over every possible shift (i.e.
16
15
  a convolution).
17
16
 
18
- The number of dimensions of the second timeseries `y`
17
+ The number of dimensions of the second timeseries ``y``
19
18
  passed at call time should always be less than or equal
20
- to the number of dimensions of the first timeseries `x`,
19
+ to the number of dimensions of the first timeseries ``x``,
21
20
  and each dimension should match the corresponding one of
22
- `x` in reverse order (i.e. if `x` has shape `(B, C, T)`
23
- then `y` should either have shape `(T,)`, `(C, T)`, or
24
- `(B, C, T)`).
21
+ ``x`` in reverse order (i.e. if ``x`` has shape ``(B, C, T)``
22
+ then ``y`` should either have shape ``(T,)``, ``(C, T)``, or
23
+ ``(B, C, T)``).
25
24
 
26
25
  Note that no windowing to either timeseries is applied
27
26
  at call time. Users should do any requisite windowing
@@ -36,7 +35,7 @@ class ShiftedPearsonCorrelation(torch.nn.Module):
36
35
  The maximum number of 1-step time shifts in
37
36
  each direction over which to compute the
38
37
  Pearson coefficient. Output shape will then
39
- be `(2 * max_shifts + 1, B, C)`.
38
+ be ``(2 * max_shifts + 1, B, C)``.
40
39
  """
41
40
 
42
41
  def __init__(self, max_shift: int) -> None:
@@ -22,7 +22,7 @@ class QTile(torch.nn.Module):
22
22
  """
23
23
  Compute the row of Q-tiles for a single Q value and a single
24
24
  frequency for a batch of multi-channel frequency series data.
25
- Should really be called `QRow`, but I want to match GWpy.
25
+ Should really be called ``QRow``, but I want to match GWpy.
26
26
  Input data should have three dimensions or fewer.
27
27
  If fewer, dimensions will be added until the input is
28
28
  three-dimensional.
@@ -112,17 +112,17 @@ class QTile(torch.nn.Module):
112
112
  fseries:
113
113
  Frequency series of data. Should correspond to data with
114
114
  the duration and sample rate used to initialize this object.
115
- Expected input shape is `(B, C, F)`, where F is the number
115
+ Expected input shape is ``(B, C, F)``, where F is the number
116
116
  of samples, C is the number of channels, and B is the number
117
117
  of batches. If less than three-dimensional, axes will be
118
118
  added.
119
119
  norm:
120
120
  The method of normalization. Options are "median", "mean", or
121
- `None`.
121
+ ``None``.
122
122
 
123
123
  Returns:
124
124
  The row of Q-tiles for the given Q and frequency. Output is
125
- three-dimensional: `(B, C, T)`
125
+ three-dimensional: ``(B, C, T)``
126
126
  """
127
127
  if len(fseries.shape) > 3:
128
128
  raise ValueError("Input data has more than 3 dimensions")
@@ -164,7 +164,7 @@ class SingleQTransform(torch.nn.Module):
164
164
  Sample rate of the data in Hz
165
165
  spectrogram_shape:
166
166
  The shape of the interpolated spectrogram, specified as
167
- `(num_f_bins, num_t_bins)`. Because the
167
+ ``(num_f_bins, num_t_bins)``. Because the
168
168
  frequency spacing of the Q-tiles is in log-space, the frequency
169
169
  interpolation is log-spaced as well.
170
170
  q:
@@ -176,14 +176,14 @@ class SingleQTransform(torch.nn.Module):
176
176
  mismatch:
177
177
  The maximum fractional mismatch between neighboring tiles
178
178
  interpolation_method:
179
- The method by which to interpolate each `QTile` to the specified
179
+ The method by which to interpolate each ``QTile`` to the specified
180
180
  number of time and frequency bins. The acceptable values are
181
181
  "bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
182
182
  options will use PyTorch's built-in interpolation modes, while
183
183
  "spline" will use the custom Torch-based implementation in
184
- `ml4gw`, as PyTorch does not have spline-based intertpolation.
184
+ ``ml4gw``, as PyTorch does not have spline-based interpolation.
185
185
  The "spline" mode is most similar to the results of GWpy's
186
- Q-transform, which uses `scipy` to do spline interpolation.
186
+ Q-transform, which uses ``scipy`` to do spline interpolation.
187
187
  However, it is also the slowest and most memory intensive due to
188
188
  the matrix equation solving steps. Therefore, the default method
189
189
  is "bicubic" as it produces the most similar results while
@@ -291,7 +291,7 @@ class SingleQTransform(torch.nn.Module):
291
291
  def get_freqs(self) -> Float[Tensor, " nfreq"]:
292
292
  """
293
293
  Calculate the frequencies that will be used in this transform.
294
- For each frequency, a `QTile` is created.
294
+ For each frequency, a ``QTile`` is created.
295
295
  """
296
296
  minf, maxf = self.frange
297
297
  fcum_mismatch = (
@@ -320,7 +320,7 @@ class SingleQTransform(torch.nn.Module):
320
320
  be slow, so this isn't used yet.
321
321
 
322
322
  Optionally, a pair of frequency values can be specified for
323
- `fsearch_range` to restrict the frequencies in which the maximum
323
+ ``fsearch_range`` to restrict the frequencies in which the maximum
324
324
  energy value is sought.
325
325
  """
326
326
  allowed_dimensions = ["both", "neither", "channel", "batch"]
@@ -360,7 +360,7 @@ class SingleQTransform(torch.nn.Module):
360
360
  ) -> None:
361
361
  """
362
362
  Take the FFT of the input timeseries and calculate the transform
363
- for each `QTile`
363
+ for each ``QTile``
364
364
  """
365
365
  # Computing the FFT with the same normalization and scaling as GWpy
366
366
  X = torch.fft.rfft(X, norm="forward")
@@ -416,9 +416,9 @@ class SingleQTransform(torch.nn.Module):
416
416
  X:
417
417
  Time series of data. Should have the duration and sample rate
418
418
  used to initialize this object. Expected input shape is
419
- `(B, C, T)`, where T is the number of samples, C is the number
420
- of channels, and B is the number of batches. If less than
421
- three-dimensional, axes will be added during Q-tile
419
+ ``(B, C, T)``, where T is the number of samples, C is the
420
+ number of channels, and B is the number of batches. If less
421
+ than three-dimensional, axes will be added during Q-tile
422
422
  computation.
423
423
  norm:
424
424
  The method of normalization used by each QTile
@@ -445,13 +445,13 @@ class QScan(torch.nn.Module):
445
445
  Sample rate of the data in Hz
446
446
  spectrogram_shape:
447
447
  The shape of the interpolated spectrogram, specified as
448
- `(num_f_bins, num_t_bins)`. Because the
448
+ ``(num_f_bins, num_t_bins)``. Because the
449
449
  frequency spacing of the Q-tiles is in log-space, the frequency
450
450
  interpolation is log-spaced as well.
451
451
  qrange:
452
452
  The lower and upper values of Q to consider. The
453
453
  actual values of Q used for the transforms are
454
- determined by the `get_qs` method
454
+ determined by the ``get_qs`` method
455
455
  frange:
456
456
  The lower and upper frequency limit to consider for
457
457
  the transform. If unspecified, default values will
@@ -535,9 +535,9 @@ class QScan(torch.nn.Module):
535
535
  X:
536
536
  Time series of data. Should have the duration and sample rate
537
537
  used to initialize this object. Expected input shape is
538
- `(B, C, T)`, where T is the number of samples, C is the number
539
- of channels, and B is the number of batches. If less than
540
- three-dimensional, axes will be added during Q-tile
538
+ ``(B, C, T)``, where T is the number of samples, C is the
539
+ number of channels, and B is the number of batches. If less
540
+ than three-dimensional, axes will be added during Q-tile
541
541
  computation.
542
542
  fsearch_range:
543
543
  The lower and upper frequency values within which to search
@@ -13,14 +13,14 @@ class ChannelWiseScaler(FittableTransform):
13
13
  Scales timeseries channels by the mean and standard
14
14
  deviation of the channels of the timeseries used to
15
15
  fit the module. To reverse the scaling, provide the
16
- `reverse=True` keyword argument at call time.
16
+ ``reverse=True`` keyword argument at call time.
17
17
  By default, the scaling parameters are set to zero mean
18
18
  and unit variance, amounting to an identity transform.
19
19
 
20
20
  Args:
21
21
  num_channels:
22
22
  The number of channels of the target timeseries.
23
- If left as `None`, the timeseries will be assumed
23
+ If left as ``None``, the timeseries will be assumed
24
24
  to be 1D (single channel).
25
25
  """
26
26
 
@@ -42,8 +42,8 @@ class ChannelWiseScaler(FittableTransform):
42
42
  """Fit the scaling parameters to a timeseries
43
43
 
44
44
  Computes the channel-wise mean and standard deviation
45
- of the timeseries `X` and sets these values to the
46
- `mean` and `std` parameters of the scaler.
45
+ of the timeseries ``X`` and sets these values to the
46
+ ``mean`` and ``std`` parameters of the scaler.
47
47
  """
48
48
 
49
49
  if X.ndim == 1:
@@ -14,39 +14,39 @@ class SpectralDensity(torch.nn.Module):
14
14
  of a batch of multichannel timeseries, or the cross spectral
15
15
  density of two batches of multichannel timeseries.
16
16
 
17
- On `SpectralDensity.forward` call, if only one tensor is provided,
17
+ On ``SpectralDensity.forward`` call, if only one tensor is provided,
18
18
  this transform will compute its power spectral density. If a second
19
19
  tensor is provided, the cross spectral density between the two
20
20
  timeseries will be computed. For information about the allowed
21
21
  relationships between these two tensors, see the documentation to
22
- `ml4gw.spectral.fast_spectral_density`.
22
+ :meth:`~ml4gw.spectral.fast_spectral_density`.
23
23
 
24
24
  Note that the cross spectral density computation is currently
25
- only available for the `fast_spectral_density` option. If
26
- `fast=False` and a second tensor is passed to `SpectralDensity.forward`,
27
- a `NotImplementedError` will be raised.
25
+ only available for :meth:`~ml4gw.spectral.fast_spectral_density`. If
26
+ ``fast=False`` and a second tensor is passed to ``SpectralDensity.forward``,
27
+ a ``NotImplementedError`` will be raised.
28
28
 
29
29
  Args:
30
30
  sample_rate:
31
- Rate at which tensors passed to `forward` will be sampled
31
+ Rate at which tensors passed to ``forward`` will be sampled
32
32
  fftlength:
33
33
  Length of the window, in seconds, to use for FFT estimates
34
34
  overlap:
35
35
  Overlap between windows used for FFT calculation. If left
36
- as `None`, this will be set to `fftlength / 2`.
36
+ as ``None``, this will be set to ``fftlength / 2``.
37
37
  average:
38
38
  Aggregation method to use for combining windowed FFTs.
39
- Allowed values are `"mean"` and `"median"`.
39
+ Allowed values are ``"mean"`` and ``"median"``.
40
40
  window:
41
41
  Window array to multiply by each FFT window before
42
- FFT computation. Should have length `nperseg`.
42
+ FFT computation. Should have length ``nperseg``.
43
43
  Defaults to a hanning window.
44
44
  fast:
45
45
  Whether to use a faster spectral density computation that
46
46
  supports cross spectral density, or a slower one which does
47
47
  not. The cost of the fast implementation is that it is not
48
48
  exact for the two lowest frequency bins.
49
- """
49
+ """ # noqa E501
50
50
 
51
51
  def __init__(
52
52
  self,
@@ -14,18 +14,18 @@ class MultiResolutionSpectrogram(torch.nn.Module):
14
14
  """
15
15
  Create a batch of multi-resolution spectrograms
16
16
  from a batch of timeseries. Input is expected to
17
- have the shape `(B, C, T)`, where `B` is the number
18
- of batches, `C` is the number of channels, and `T`
17
+ have the shape ``(B, C, T)``, where ``B`` is the number
18
+ of batches, ``C`` is the number of channels, and ``T``
19
19
  is the number of time samples.
20
20
 
21
21
  For each timeseries, calculate multiple normalized
22
- spectrograms based on the `Spectrogram` `kwargs` given.
22
+ spectrograms based on the ``Spectrogram`` ``kwargs`` given.
23
23
  Combine the spectrograms by taking the maximum value
24
24
  from the nearest time-frequency bin.
25
25
 
26
26
  If the largest number of time bins among the spectrograms
27
- is `N` and the largest number of frequency bins is `M`,
28
- the output will have dimensions `(B, C, M, N)`
27
+ is ``N`` and the largest number of frequency bins is ``M``,
28
+ the output will have dimensions ``(B, C, M, N)``
29
29
 
30
30
  Args:
31
31
  kernel_length:
@@ -34,10 +34,11 @@ class MultiResolutionSpectrogram(torch.nn.Module):
34
34
  spectrogram
35
35
  sample_rate:
36
36
  The sample rate of the timeseries in Hz
37
- kwargs:
37
+ **kwargs:
38
38
  Arguments passed in kwargs will be used to create
39
- `torchaudio.transforms.Spectrogram`s. Each
40
- argument should be a list of values. Any list
39
+ ``torchaudio.transforms.Spectrogram`` (see
40
+ `documentation <https://docs.pytorch.org/audio/main/generated/torchaudio.transforms.Spectrogram.html>`_).
41
+ Each argument should be a list of values. Any list
41
42
  of length greater than 1 should be the same
42
43
  length
43
44
  """
@@ -140,9 +141,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
140
141
  Batch of multichannel timeseries which will
141
142
  be used to calculate the multi-resolution
142
143
  spectrogram. Should have the shape
143
- `(B, C, T)`, where `B` is the number of
144
- batches, `C` is the number of channels,
145
- and `T` is the number of time samples.
144
+ ``(B, C, T)``, where ``B`` is the number of
145
+ batches, ``C`` is the number of channels,
146
+ and ``T`` is the number of time samples.
146
147
  """
147
148
  if X.shape[-1] != self.kernel_size:
148
149
  raise ValueError(
@@ -13,16 +13,16 @@ class SplineInterpolate(torch.nn.Module):
13
13
  """
14
14
  Perform 1D or 2D spline interpolation based on De Boor's method.
15
15
  Supports batched, multi-channel inputs, so acceptable data
16
- shapes are `(width)`, `(height, width)`, `(batch, width)`,
17
- `(batch, height, width)`, `(batch, channel, width)`, and
18
- `(batch, channel, height, width)`.
16
+ shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
17
+ ``(batch, height, width)``, ``(batch, channel, width)``, and
18
+ ``(batch, channel, height, width)``.
19
19
 
20
20
  During initialization of this Module, both the desired input
21
21
  and output coordinate Tensors can be specified to allow
22
22
  pre-computation of the B-spline basis matrices, though the only
23
23
  mandatory argument is the coordinates of the data along the
24
- `width` dimension. If no argument is given for coordinates along
25
- the `height` dimension, it is assumed that 1D interpolation is
24
+ ``width`` dimension. If no argument is given for coordinates along
25
+ the ``height`` dimension, it is assumed that 1D interpolation is
26
26
  desired.
27
27
 
28
28
  Unlike scipy's implementation of spline interpolation, the data
@@ -55,7 +55,7 @@ class SplineInterpolate(torch.nn.Module):
55
55
  sx:
56
56
  Regularization factor to avoid singularities during matrix
57
57
  inversion for interpolation along the width dimension. Not
58
- to be confused with the `s` parameter in scipy's spline
58
+ to be confused with the ``s`` parameter in scipy's spline
59
59
  methods, which controls the number of knots.
60
60
  sy:
61
61
  Regularization factor to avoid singularities during matrix
@@ -256,11 +256,6 @@ class SplineInterpolate(torch.nn.Module):
256
256
  return b[:, :, -1]
257
257
 
258
258
  def bivariate_spline_fit_natural(self, Z):
259
- if len(Z.shape) == 3:
260
- Z_Bx = torch.matmul(Z, self.Bx)
261
- # ((BxT @ Bx)^-1 @ (Z @ Bx)T)T = Z @ BxT^-1
262
- return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT
263
-
264
259
  # Adding batch/channel dimension handling
265
260
  # ByT @ Z @ BxW
266
261
  ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
@@ -280,8 +275,6 @@ class SplineInterpolate(torch.nn.Module):
280
275
  Z_interp: Interpolated values at the grid points.
281
276
  """
282
277
  # Perform matrix multiplication using einsum to get Z_interp
283
- if len(C.shape) == 3:
284
- return torch.matmul(C, self.Bx_out.mT)
285
278
  return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT)
286
279
 
287
280
  def _validate_inputs(self, Z, x_out, y_out):
@@ -339,7 +332,7 @@ class SplineInterpolate(torch.nn.Module):
339
332
  Z:
340
333
  Tensor of data to be interpolated. Must be between 1 and 4
341
334
  dimensions. The shape of the tensor must agree with the
342
- input coordinates given on initialization. If `y_in` was
335
+ input coordinates given on initialization. If ``y_in`` was
343
336
  not specified during initialization, it is assumed that
344
337
  Z does not have a height dimension.
345
338
  x_out:
@@ -352,7 +345,7 @@ class SplineInterpolate(torch.nn.Module):
352
345
  initialization.
353
346
 
354
347
  Returns:
355
- A 4D tensor with shape `(batch, channel, height, width)`.
348
+ A 4D tensor with shape ``(batch, channel, height, width)``.
356
349
  Depending on the input data shape, many of these dimensions
357
350
  may have length 1.
358
351
  """
@@ -70,7 +70,7 @@ class FittableSpectralTransform(FittableTransform):
70
70
  )
71
71
 
72
72
  # add two dummy dimensions in case we need to interpolate
73
- # the frequency dimension, since `interpolate` expects
73
+ # the frequency dimension, since ``interpolate`` expects
74
74
  # a (batch, channel, spatial) formatted tensor as input
75
75
  x = x.view(1, 1, -1)
76
76
  if x.size(-1) != num_freqs: