ml4gw-0.7.5-py3-none-any.whl → ml4gw-0.7.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic; see the package's registry page for more details.

@@ -14,39 +14,39 @@ class SpectralDensity(torch.nn.Module):
14
14
  of a batch of multichannel timeseries, or the cross spectral
15
15
  density of two batches of multichannel timeseries.
16
16
 
17
- On `SpectralDensity.forward` call, if only one tensor is provided,
17
+ On ``SpectralDensity.forward`` call, if only one tensor is provided,
18
18
  this transform will compute its power spectral density. If a second
19
19
  tensor is provided, the cross spectral density between the two
20
20
  timeseries will be computed. For information about the allowed
21
21
  relationships between these two tensors, see the documentation to
22
- `ml4gw.spectral.fast_spectral_density`.
22
+ :meth:`~ml4gw.spectral.fast_spectral_density`.
23
23
 
24
24
  Note that the cross spectral density computation is currently
25
- only available for the `fast_spectral_density` option. If
26
- `fast=False` and a second tensor is passed to `SpectralDensity.forward`,
27
- a `NotImplementedError` will be raised.
25
+ only available for :meth:`~ml4gw.spectral.fast_spectral_density`. If
26
+ ``fast=False`` and a second tensor is passed to ``SpectralDensity.forward``, # noqa E501
27
+ a ``NotImplementedError`` will be raised.
28
28
 
29
29
  Args:
30
30
  sample_rate:
31
- Rate at which tensors passed to `forward` will be sampled
31
+ Rate at which tensors passed to ``forward`` will be sampled
32
32
  fftlength:
33
33
  Length of the window, in seconds, to use for FFT estimates
34
34
  overlap:
35
35
  Overlap between windows used for FFT calculation. If left
36
- as `None`, this will be set to `fftlength / 2`.
36
+ as ``None``, this will be set to ``fftlength / 2``.
37
37
  average:
38
38
  Aggregation method to use for combining windowed FFTs.
39
- Allowed values are `"mean"` and `"median"`.
39
+ Allowed values are ``"mean"`` and ``"median"``.
40
40
  window:
41
41
  Window array to multiply by each FFT window before
42
- FFT computation. Should have length `nperseg`.
42
+ FFT computation. Should have length ``nperseg``.
43
43
  Defaults to a hanning window.
44
44
  fast:
45
45
  Whether to use a faster spectral density computation that
46
46
  support cross spectral density, or a slower one which does
47
47
  not. The cost of the fast implementation is that it is not
48
48
  exact for the two lowest frequency bins.
49
- """
49
+ """ # noqa E501
50
50
 
51
51
  def __init__(
52
52
  self,
@@ -14,18 +14,18 @@ class MultiResolutionSpectrogram(torch.nn.Module):
14
14
  """
15
15
  Create a batch of multi-resolution spectrograms
16
16
  from a batch of timeseries. Input is expected to
17
- have the shape `(B, C, T)`, where `B` is the number
18
- of batches, `C` is the number of channels, and `T`
17
+ have the shape ``(B, C, T)``, where ``B`` is the number
18
+ of batches, ``C`` is the number of channels, and ``T``
19
19
  is the number of time samples.
20
20
 
21
21
  For each timeseries, calculate multiple normalized
22
- spectrograms based on the `Spectrogram` `kwargs` given.
22
+ spectrograms based on the ``Spectrogram`` ``kwargs`` given.
23
23
  Combine the spectrograms by taking the maximum value
24
24
  from the nearest time-frequncy bin.
25
25
 
26
26
  If the largest number of time bins among the spectrograms
27
- is `N` and the largest number of frequency bins is `M`,
28
- the output will have dimensions `(B, C, M, N)`
27
+ is ``N`` and the largest number of frequency bins is ``M``,
28
+ the output will have dimensions ``(B, C, M, N)``
29
29
 
30
30
  Args:
31
31
  kernel_length:
@@ -34,10 +34,11 @@ class MultiResolutionSpectrogram(torch.nn.Module):
34
34
  spectrogram
35
35
  sample_rate:
36
36
  The sample rate of the timeseries in Hz
37
- kwargs:
37
+ **kwargs:
38
38
  Arguments passed in kwargs will used to create
39
- `torchaudio.transforms.Spectrogram`s. Each
40
- argument should be a list of values. Any list
39
+ ``torchaudio.transforms.Spectrogram`` (see
40
+ `documentation <https://docs.pytorch.org/audio/main/generated/torchaudio.transforms.Spectrogram.html>`_).
41
+ Each argument should be a list of values. Any list
41
42
  of length greater than 1 should be the same
42
43
  length
43
44
  """
@@ -140,9 +141,9 @@ class MultiResolutionSpectrogram(torch.nn.Module):
140
141
  Batch of multichannel timeseries which will
141
142
  be used to calculate the multi-resolution
142
143
  spectrogram. Should have the shape
143
- `(B, C, T)`, where `B` is the number of
144
- batches, `C` is the number of channels,
145
- and `T` is the number of time samples.
144
+ ``(B, C, T)``, where ``B`` is the number of
145
+ batches, ``C`` is the number of channels,
146
+ and ``T`` is the number of time samples.
146
147
  """
147
148
  if X.shape[-1] != self.kernel_size:
148
149
  raise ValueError(
@@ -1,131 +1,29 @@
1
1
  """
2
- Adaptation of code from https://github.com/dottormale/Qtransform
2
+ Adaptation of code from https://github.com/dottormale/Qtransform_torch/
3
3
  """
4
4
 
5
5
  from typing import Optional, Tuple
6
6
 
7
7
  import torch
8
- import torch.nn.functional as F
9
8
  from torch import Tensor
10
9
 
11
10
 
12
- class SplineInterpolate(torch.nn.Module):
11
+ class SplineInterpolateBase(torch.nn.Module):
13
12
  """
14
- Perform 1D or 2D spline interpolation based on De Boor's method.
15
- Supports batched, multi-channel inputs, so acceptable data
16
- shapes are `(width)`, `(height, width)`, `(batch, width)`,
17
- `(batch, height, width)`, `(batch, channel, width)`, and
18
- `(batch, channel, height, width)`.
19
-
20
- During initialization of this Module, both the desired input
21
- and output coordinate Tensors can be specified to allow
22
- pre-computation of the B-spline basis matrices, though the only
23
- mandatory argument is the coordinates of the data along the
24
- `width` dimension. If no argument is given for coordinates along
25
- the `height` dimension, it is assumed that 1D interpolation is
26
- desired.
27
-
28
- Unlike scipy's implementation of spline interpolation, the data
29
- to be interpolated is not passed until actually calling the
30
- object. This is useful for cases where the input and output
31
- coordinates are known in advance, but the data is not, so that
32
- the interpolator can be set up ahead of time.
33
-
34
- WARNING: compared to scipy's spline interpolation, this method
35
- produces edge artifacts when the output coordinates are near
36
- the boundaries of the input coordinates. Therefore, it is
37
- recommended to interpolate only to coordinates that are well
38
- within the input coordinate range. Unfortunately, the specific
39
- definition of "well within" changes based on the size of the
40
- data, so some testing may be required to get good results.
41
-
42
- Args:
43
- x_in:
44
- Coordinates of the width dimension of the data
45
- y_in:
46
- Coordinates of the height dimension of the data. If not
47
- specified, it is assumed the 1D interpolation is desired,
48
- and so the default value is a Tensor of length 1
49
- kx:
50
- Degree of spline interpolation along the width dimension.
51
- Default is cubic.
52
- ky:
53
- Degree of spline interpolation along the height dimension.
54
- Default is cubic.
55
- sx:
56
- Regularization factor to avoid singularities during matrix
57
- inversion for interpolation along the width dimension. Not
58
- to be confused with the `s` parameter in scipy's spline
59
- methods, which controls the number of knots.
60
- sy:
61
- Regularization factor to avoid singularities during matrix
62
- inversion for interpolation along the height dimension.
63
- x_out:
64
- Coordinates for the data to be interpolated to along the
65
- width dimension. If not specified during initialization,
66
- this must be specified during the object call.
67
- y_out:
68
- Coordinates for the data to be interpolated to along the
69
- height dimension. If not specified during initialization,
70
- this must be specified during the object call.
71
-
13
+ Base class for spline interpolation.
72
14
  """
73
15
 
74
- def __init__(
75
- self,
76
- x_in: Tensor,
77
- y_in: Tensor = None,
78
- kx: int = 3,
79
- ky: int = 3,
80
- sx: float = 0.001,
81
- sy: float = 0.001,
82
- x_out: Optional[Tensor] = None,
83
- y_out: Optional[Tensor] = None,
84
- ):
85
- super().__init__()
86
- if y_in is None:
87
- y_in = Tensor([1])
88
- self.kx = kx
89
- self.ky = ky
90
- self.sx = sx
91
- self.sy = sy
92
- self.register_buffer("x_in", x_in)
93
- self.register_buffer("y_in", y_in)
94
- self.register_buffer("x_out", x_out)
95
- self.register_buffer("y_out", y_out)
96
-
97
- tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
98
- self.register_buffer("tx", tx)
99
- self.register_buffer("Bx", Bx)
100
- self.register_buffer("BxT_Bx", BxT_Bx)
101
-
102
- ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
103
- self.register_buffer("ty", ty)
104
- self.register_buffer("By", By)
105
- self.register_buffer("ByT_By", ByT_By)
106
-
107
- if self.x_out is not None:
108
- Bx_out = self.bspline_basis_natural(x_out, kx, self.tx)
109
- self.register_buffer("Bx_out", Bx_out)
110
- if self.y_out is not None:
111
- By_out = self.bspline_basis_natural(y_out, ky, self.ty)
112
- self.register_buffer("By_out", By_out)
113
-
114
16
  def _compute_knots_and_basis_matrices(self, x, k, s):
115
- knots = self.generate_natural_knots(x, k)
17
+ knots = self.generate_fitpack_knots(x, k)
116
18
  basis_matrix = self.bspline_basis_natural(x, k, knots)
117
19
  identity = torch.eye(basis_matrix.shape[-1])
118
20
  B_T_B = basis_matrix.T @ basis_matrix + s * identity
119
21
  return knots, basis_matrix, B_T_B
120
22
 
121
- def generate_natural_knots(self, x: Tensor, k: int) -> Tensor:
23
+ def generate_fitpack_knots(self, x: Tensor, k: int) -> Tensor:
122
24
  """
123
- Generates a natural knot sequence for B-spline interpolation.
124
- Natural knot sequence means that 2*k knots are added to the beginning
125
- and end of datapoints as replicas of first and last datapoint
126
- respectively in order to enforce natural boundary conditions,
127
- i.e. second derivative = 0.
128
- The other n nodes are placed in correspondece of the data points.
25
+ Generates a knot sequence for B-spline interpolation
26
+ in the same way as the FITPACK algorithm used by SciPy.
129
27
 
130
28
  Args:
131
29
  x: Tensor of data point positions.
@@ -134,7 +32,17 @@ class SplineInterpolate(torch.nn.Module):
134
32
  Returns:
135
33
  Tensor of knot positions.
136
34
  """
137
- return F.pad(x[None], (k, k), mode="replicate")[0]
35
+ num_knots = x.shape[-1] + k + 1
36
+ knots = torch.zeros(num_knots, dtype=x.dtype)
37
+ knots[: k + 1] = x[0]
38
+ knots[-(k + 1) :] = x[-1]
39
+
40
+ # Interior knots are the rolling average of the data points
41
+ # excluding the first and last points
42
+ windows = x[1:-1].unfold(dimension=-1, size=k, step=1)
43
+ knots[k + 1 : -k - 1] = windows.mean(dim=-1)
44
+
45
+ return knots
138
46
 
139
47
  def compute_L_R(
140
48
  self,
@@ -233,9 +141,6 @@ class SplineInterpolate(torch.nn.Module):
233
141
  Returns:
234
142
  Tensor containing the kth-order B-spline basis functions
235
143
  """
236
-
237
- if len(x) == 1:
238
- return torch.eye(1)
239
144
  n = x.shape[0]
240
145
  m = t.shape[0] - k - 1
241
146
 
@@ -255,12 +160,272 @@ class SplineInterpolate(torch.nn.Module):
255
160
 
256
161
  return b[:, :, -1]
257
162
 
258
- def bivariate_spline_fit_natural(self, Z):
259
- if len(Z.shape) == 3:
260
- Z_Bx = torch.matmul(Z, self.Bx)
261
- # ((BxT @ Bx)^-1 @ (Z @ Bx)T)T = Z @ BxT^-1
262
- return torch.linalg.solve(self.BxT_Bx, Z_Bx.mT).mT
263
163
 
164
+ class SplineInterpolate1D(SplineInterpolateBase):
165
+ """
166
+ Perform 1D spline interpolation based on De Boor's method.
167
+ It is allowed to have two spatial dimensions, but the second
168
+ dimension cannot be interpolated along. To interpolate along both
169
+ dimensions, use :class:`SplineInterpolate2D`.
170
+
171
+ Supports batched, multi-channel inputs, so acceptable data
172
+ shapes are ``(width)``, ``(height, width)``, ``(batch, width)``,
173
+ ``(batch, height, width)``, ``(batch, channel, width)``, and
174
+ ``(batch, channel, height, width)``.
175
+
176
+ During initialization of this Module, both the desired input
177
+ and output coordinate Tensors can be specified to allow
178
+ pre-computation of the B-spline basis matrices, though the only
179
+ mandatory argument is the coordinates of the data along the
180
+ ``width`` dimension.
181
+
182
+ Unlike scipy's implementation of spline interpolation, the data
183
+ to be interpolated is not passed until actually calling the
184
+ object. This is useful for cases where the input and output
185
+ coordinates are known in advance, but the data is not, so that
186
+ the interpolator can be set up ahead of time.
187
+
188
+ Args:
189
+ x_in:
190
+ Coordinates of the width dimension of the data
191
+ kx:
192
+ Degree of spline interpolation along the width dimension.
193
+ Default is cubic.
194
+ sx:
195
+ Regularization factor to avoid singularities during matrix
196
+ inversion for interpolation along the width dimension. Not
197
+ to be confused with the ``s`` parameter in scipy's spline
198
+ methods, which controls the number of knots.
199
+ x_out:
200
+ Coordinates for the data to be interpolated to along the
201
+ width dimension. If not specified during initialization,
202
+ this must be specified during the object call.
203
+
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ x_in: Tensor,
209
+ kx: int = 3,
210
+ sx: float = 0.0,
211
+ x_out: Optional[Tensor] = None,
212
+ ):
213
+ super().__init__()
214
+
215
+ if len(x_in) < kx + 2:
216
+ raise ValueError(
217
+ "Input x-coordinates must have at least kx + 2 points."
218
+ )
219
+
220
+ # Ensure that coordinates are floats
221
+ x_in = x_in.float()
222
+ x_out = x_out.float() if x_out is not None else None
223
+
224
+ self.kx = kx
225
+ self.sx = sx
226
+ self.register_buffer("x_in", x_in)
227
+ self.register_buffer("x_out", x_out)
228
+
229
+ tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
230
+ self.register_buffer("tx", tx)
231
+ self.register_buffer("Bx", Bx)
232
+ self.register_buffer("BxT_Bx", BxT_Bx)
233
+
234
+ if self.x_out is not None:
235
+ x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
236
+ Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
237
+ self.register_buffer("Bx_out", Bx_out)
238
+
239
+ def spline_fit_natural(self, Z):
240
+ # Adding batch/channel dimension handling
241
+ # Bx @ Z
242
+ BxT_Z = torch.einsum("ij,bchj->bchi", self.Bx.T, Z)
243
+ # (BxT @ Bx)^-1 @ (BxT @ Z) = Bx^-1 @ Z
244
+ C = torch.linalg.solve(self.BxT_Bx, BxT_Z.unsqueeze(-1))
245
+ return C.squeeze(-1)
246
+
247
+ def evaluate_spline(self, C: Tensor):
248
+ """
249
+ Evaluate a bivariate spline on a grid of x and y points.
250
+
251
+ Args:
252
+ C: Coefficient tensor of shape (batch_size, mx, my).
253
+
254
+ Returns:
255
+ Z_interp: Interpolated values at the grid points.
256
+ """
257
+ # Perform matrix multiplication using einsum to get Z_interp
258
+ return torch.einsum("ij,bchj->bchi", self.Bx_out, C)
259
+
260
+ def _validate_inputs(self, Z, x_out):
261
+ if x_out is None and self.x_out is None:
262
+ raise ValueError(
263
+ "Output x-coordinates were not specified in either object "
264
+ "creation or in forward call"
265
+ )
266
+
267
+ dims = len(Z.shape)
268
+ if dims > 4:
269
+ raise ValueError("Input data has more than 4 dimensions")
270
+
271
+ if Z.shape[-1] != len(self.x_in):
272
+ raise ValueError(
273
+ "The spatial dimensions of the data tensor do not match "
274
+ "the given input dimensions. "
275
+ f"Expected {len(self.x_in)}, but got {Z.shape[-1]}"
276
+ )
277
+
278
+ # Expand Z to have a batch, channel, and height dimension if needed
279
+ while len(Z.shape) < 4:
280
+ Z = Z.unsqueeze(0)
281
+
282
+ return Z
283
+
284
+ def forward(
285
+ self,
286
+ Z: Tensor,
287
+ x_out: Optional[Tensor] = None,
288
+ ) -> Tensor:
289
+ """
290
+ Compute the interpolated data
291
+
292
+ Args:
293
+ Z:
294
+ Tensor of data to be interpolated. Must be between 2 and 4
295
+ dimensions. The shape of the tensor must agree with the
296
+ input coordinates given on initialization.
297
+ x_out:
298
+ Coordinates to interpolate the data to along the width
299
+ dimension. Overrides any value that was set during
300
+ initialization.
301
+
302
+ Returns:
303
+ A 4D tensor with shape ``(batch, channel, height, width)``.
304
+ Depending on the input data shape, many of these dimensions
305
+ may have length 1.
306
+ """
307
+
308
+ Z = self._validate_inputs(Z, x_out)
309
+
310
+ if x_out is not None:
311
+ x_out = x_out.float()
312
+ x_clamped = torch.clamp(
313
+ x_out, self.tx[self.kx], self.tx[-self.kx - 1]
314
+ )
315
+ self.Bx_out = self.bspline_basis_natural(
316
+ x_clamped, self.kx, self.tx
317
+ )
318
+
319
+ coef = self.spline_fit_natural(Z)
320
+ Z_interp = self.evaluate_spline(coef)
321
+ return Z_interp
322
+
323
+
324
+ class SplineInterpolate2D(SplineInterpolateBase):
325
+ """
326
+ Perform 2D spline interpolation based on De Boor's method.
327
+ Supports batched, multi-channel inputs, so acceptable data
328
+ shapes are ``(height, width)``, ``(batch, height, width)``,
329
+ and ``(batch, channel, height, width)``.
330
+
331
+ During initialization of this Module, both the desired input
332
+ and output coordinate Tensors can be specified to allow
333
+ pre-computation of the B-spline basis matrices, though the only
334
+ mandatory arguments are the input coordinates.
335
+
336
+ Unlike scipy's implementation of spline interpolation, the data
337
+ to be interpolated is not passed until actually calling the
338
+ object. This is useful for cases where the input and output
339
+ coordinates are known in advance, but the data is not, so that
340
+ the interpolator can be set up ahead of time.
341
+
342
+ Args:
343
+ x_in:
344
+ Coordinates of the width dimension of the data
345
+ y_in:
346
+ Coordinates of the height dimension of the data.
347
+ kx:
348
+ Degree of spline interpolation along the width dimension.
349
+ Default is cubic.
350
+ ky:
351
+ Degree of spline interpolation along the height dimension.
352
+ Default is cubic.
353
+ sx:
354
+ Regularization factor to avoid singularities during matrix
355
+ inversion for interpolation along the width dimension. Not
356
+ to be confused with the ``s`` parameter in scipy's spline
357
+ methods, which controls the number of knots.
358
+ sy:
359
+ Regularization factor to avoid singularities during matrix
360
+ inversion for interpolation along the height dimension.
361
+ x_out:
362
+ Coordinates for the data to be interpolated to along the
363
+ width dimension. If not specified during initialization,
364
+ this must be specified during the object call.
365
+ y_out:
366
+ Coordinates for the data to be interpolated to along the
367
+ height dimension. If not specified during initialization,
368
+ this must be specified during the object call.
369
+
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ x_in: Tensor,
375
+ y_in: Tensor,
376
+ kx: int = 3,
377
+ ky: int = 3,
378
+ sx: float = 0.0,
379
+ sy: float = 0.0,
380
+ x_out: Optional[Tensor] = None,
381
+ y_out: Optional[Tensor] = None,
382
+ ):
383
+ super().__init__()
384
+
385
+ if len(x_in) < kx + 2:
386
+ raise ValueError(
387
+ "Input x-coordinates must have at least kx + 2 points."
388
+ )
389
+ if len(y_in) < ky + 2:
390
+ raise ValueError(
391
+ "Input y-coordinates must have at least ky + 2 points."
392
+ )
393
+
394
+ # Ensure that coordinates are floats
395
+ x_in = x_in.float()
396
+ y_in = y_in.float()
397
+ x_out = x_out.float() if x_out is not None else None
398
+ y_out = y_out.float() if y_out is not None else None
399
+
400
+ self.kx = kx
401
+ self.ky = ky
402
+ self.sx = sx
403
+ self.sy = sy
404
+ self.register_buffer("x_in", x_in)
405
+ self.register_buffer("y_in", y_in)
406
+ self.register_buffer("x_out", x_out)
407
+ self.register_buffer("y_out", y_out)
408
+
409
+ tx, Bx, BxT_Bx = self._compute_knots_and_basis_matrices(x_in, kx, sx)
410
+ self.register_buffer("tx", tx)
411
+ self.register_buffer("Bx", Bx)
412
+ self.register_buffer("BxT_Bx", BxT_Bx)
413
+
414
+ ty, By, ByT_By = self._compute_knots_and_basis_matrices(y_in, ky, sy)
415
+ self.register_buffer("ty", ty)
416
+ self.register_buffer("By", By)
417
+ self.register_buffer("ByT_By", ByT_By)
418
+
419
+ if self.x_out is not None:
420
+ x_clamped = torch.clamp(x_out, tx[kx], tx[-kx - 1])
421
+ Bx_out = self.bspline_basis_natural(x_clamped, kx, self.tx)
422
+ self.register_buffer("Bx_out", Bx_out)
423
+ if self.y_out is not None:
424
+ y_clamped = torch.clamp(y_out, ty[ky], ty[-ky - 1])
425
+ By_out = self.bspline_basis_natural(y_clamped, ky, self.ty)
426
+ self.register_buffer("By_out", By_out)
427
+
428
+ def bivariate_spline_fit_natural(self, Z):
264
429
  # Adding batch/channel dimension handling
265
430
  # ByT @ Z @ BxW
266
431
  ByT_Z_Bx = torch.einsum("ij,bcik,kl->bcjl", self.By, Z, self.Bx)
@@ -280,8 +445,6 @@ class SplineInterpolate(torch.nn.Module):
280
445
  Z_interp: Interpolated values at the grid points.
281
446
  """
282
447
  # Perform matrix multiplication using einsum to get Z_interp
283
- if len(C.shape) == 3:
284
- return torch.matmul(C, self.Bx_out.mT)
285
448
  return torch.einsum("ik,bckm,mj->bcij", self.By_out, C, self.Bx_out.mT)
286
449
 
287
450
  def _validate_inputs(self, Z, x_out, y_out):
@@ -292,29 +455,16 @@ class SplineInterpolate(torch.nn.Module):
292
455
  )
293
456
 
294
457
  if y_out is None and self.y_out is None:
295
- y_out = self.y_in
458
+ raise ValueError(
459
+ "Output y-coordinates were not specified in either object "
460
+ "creation or in forward call"
461
+ )
296
462
 
297
463
  dims = len(Z.shape)
298
464
  if dims > 4:
299
465
  raise ValueError("Input data has more than 4 dimensions")
300
-
301
- if len(self.y_in) > 1 and dims == 1:
302
- raise ValueError(
303
- "An input y-coordinate array with length greater than 1 "
304
- "was given, but the input data is 1-dimensional. Expected "
305
- "input data to be at least 2-dimensional"
306
- )
307
-
308
- # Expand Z to have 4 dimensions
309
- # There are 6 valid input shapes: (w), (b, w), (b, c, w),
310
- # (h, w), (b, h, w), and (b, c, h, w).
311
-
312
- # If the input y coordinate array has length 1,
313
- # assume the first dimension(s) are batch dimensions
314
- # and that no height dimension is included in Z
315
- idx = -2 if len(self.y_in) == 1 else -3
316
- while len(Z.shape) < 4:
317
- Z = Z.unsqueeze(idx)
466
+ if dims < 2:
467
+ raise ValueError("Input data has fewer than 2 dimensions")
318
468
 
319
469
  if Z.shape[-2:] != torch.Size([len(self.y_in), len(self.x_in)]):
320
470
  raise ValueError(
@@ -324,6 +474,10 @@ class SplineInterpolate(torch.nn.Module):
324
474
  f"[{Z.shape[-2]}, {Z.shape[-1]}]"
325
475
  )
326
476
 
477
+ # Expand Z to have a batch and channel dimension if needed
478
+ while len(Z.shape) < 4:
479
+ Z = Z.unsqueeze(0)
480
+
327
481
  return Z, y_out
328
482
 
329
483
  def forward(
@@ -337,11 +491,9 @@ class SplineInterpolate(torch.nn.Module):
337
491
 
338
492
  Args:
339
493
  Z:
340
- Tensor of data to be interpolated. Must be between 1 and 4
494
+ Tensor of data to be interpolated. Must be between 2 and 4
341
495
  dimensions. The shape of the tensor must agree with the
342
- input coordinates given on initialization. If `y_in` was
343
- not specified during initialization, it is assumed that
344
- Z does not have a height dimension.
496
+ input coordinates given on initialization.
345
497
  x_out:
346
498
  Coordinates to interpolate the data to along the width
347
499
  dimension. Overrides any value that was set during
@@ -352,7 +504,7 @@ class SplineInterpolate(torch.nn.Module):
352
504
  initialization.
353
505
 
354
506
  Returns:
355
- A 4D tensor with shape `(batch, channel, height, width)`.
507
+ A 4D tensor with shape ``(batch, channel, height, width)``.
356
508
  Depending on the input data shape, many of these dimensions
357
509
  may have length 1.
358
510
  """
@@ -360,9 +512,21 @@ class SplineInterpolate(torch.nn.Module):
360
512
  Z, y_out = self._validate_inputs(Z, x_out, y_out)
361
513
 
362
514
  if x_out is not None:
363
- self.Bx_out = self.bspline_basis_natural(x_out, self.kx, self.tx)
515
+ x_out = x_out.float()
516
+ x_clamped = torch.clamp(
517
+ x_out, self.tx[self.kx], self.tx[-self.kx - 1]
518
+ )
519
+ self.Bx_out = self.bspline_basis_natural(
520
+ x_clamped, self.kx, self.tx
521
+ )
364
522
  if y_out is not None:
365
- self.By_out = self.bspline_basis_natural(y_out, self.ky, self.ty)
523
+ y_out = y_out.float()
524
+ y_clamped = torch.clamp(
525
+ y_out, self.ty[self.ky], self.ty[-self.ky - 1]
526
+ )
527
+ self.By_out = self.bspline_basis_natural(
528
+ y_clamped, self.ky, self.ty
529
+ )
366
530
 
367
531
  coef = self.bivariate_spline_fit_natural(Z)
368
532
  Z_interp = self.evaluate_bivariate_spline(coef)
@@ -70,7 +70,7 @@ class FittableSpectralTransform(FittableTransform):
70
70
  )
71
71
 
72
72
  # add two dummy dimensions in case we need to interpolate
73
- # the frequency dimension, since `interpolate` expects
73
+ # the frequency dimension, since ``interpolate`` expects
74
74
  # a (batch, channel, spatial) formatted tensor as input
75
75
  x = x.view(1, 1, -1)
76
76
  if x.size(-1) != num_freqs: