ml4gw 0.5.1__tar.gz → 0.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ml4gw might be problematic. Click here for more details.

Files changed (51)
  1. {ml4gw-0.5.1 → ml4gw-0.6.1}/PKG-INFO +1 -1
  2. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/constants.py +10 -19
  3. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/spectral.py +1 -1
  4. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/__init__.py +1 -0
  5. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/qtransform.py +134 -42
  6. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/scaler.py +4 -2
  7. ml4gw-0.6.1/ml4gw/transforms/spline_interpolation.py +370 -0
  8. ml4gw-0.6.1/ml4gw/waveforms/__init__.py +2 -0
  9. ml4gw-0.6.1/ml4gw/waveforms/adhoc/__init__.py +2 -0
  10. {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.1/ml4gw/waveforms/cbc}/__init__.py +0 -2
  11. {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.1/ml4gw/waveforms/cbc}/phenom_d.py +13 -12
  12. {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.1/ml4gw/waveforms/cbc}/phenom_p.py +36 -42
  13. {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.1/ml4gw/waveforms/cbc}/taylorf2.py +6 -2
  14. ml4gw-0.6.1/ml4gw/waveforms/conversion.py +204 -0
  15. {ml4gw-0.5.1 → ml4gw-0.6.1}/pyproject.toml +1 -1
  16. {ml4gw-0.5.1 → ml4gw-0.6.1}/README.md +0 -0
  17. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/__init__.py +0 -0
  18. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/augmentations.py +0 -0
  19. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/dataloading/__init__.py +0 -0
  20. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/dataloading/chunked_dataset.py +0 -0
  21. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/dataloading/hdf5_dataset.py +0 -0
  22. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/dataloading/in_memory_dataset.py +0 -0
  23. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/distributions.py +0 -0
  24. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/gw.py +0 -0
  25. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/__init__.py +0 -0
  26. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/autoencoder/__init__.py +0 -0
  27. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/autoencoder/base.py +0 -0
  28. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/autoencoder/convolutional.py +0 -0
  29. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/autoencoder/skip_connection.py +0 -0
  30. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/autoencoder/utils.py +0 -0
  31. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/norm.py +0 -0
  32. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/resnet/__init__.py +0 -0
  33. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/resnet/resnet_1d.py +0 -0
  34. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/resnet/resnet_2d.py +0 -0
  35. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/streaming/__init__.py +0 -0
  36. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/streaming/online_average.py +0 -0
  37. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/nn/streaming/snapshotter.py +0 -0
  38. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/pearson.py +0 -0
  39. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/snr_rescaler.py +0 -0
  40. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/spectral.py +0 -0
  41. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/spectrogram.py +0 -0
  42. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/transform.py +0 -0
  43. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/waveforms.py +0 -0
  44. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/transforms/whitening.py +0 -0
  45. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/types.py +0 -0
  46. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/utils/interferometer.py +0 -0
  47. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/utils/slicing.py +0 -0
  48. {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.1/ml4gw/waveforms/adhoc}/ringdown.py +0 -0
  49. {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.1/ml4gw/waveforms/adhoc}/sine_gaussian.py +0 -0
  50. {ml4gw-0.5.1/ml4gw/waveforms → ml4gw-0.6.1/ml4gw/waveforms/cbc}/phenom_d_data.py +0 -0
  51. {ml4gw-0.5.1 → ml4gw-0.6.1}/ml4gw/waveforms/generator.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ml4gw
3
- Version: 0.5.1
3
+ Version: 0.6.1
4
4
  Summary: Tools for training torch models on gravitational wave data
5
5
  Author: Alec Gunny
6
6
  Author-email: alec.gunny@ligo.org
@@ -4,42 +4,33 @@ Various constants, all in SI units.
4
4
 
5
5
  EulerGamma = 0.577215664901532860606512090082402431
6
6
 
7
+ # solar mass
7
8
  MSUN = 1.988409902147041637325262574352366540e30 # kg
8
- """Solar mass"""
9
9
 
10
+ # Geometrized nominal solar mass, m
10
11
  MRSUN = 1.476625038050124729627979840144936351e3
11
- """Geometrized nominal solar mass, m"""
12
12
 
13
+ # Newton's gravitational constant
13
14
  G = 6.67430e-11 # m^3 / kg / s^2
14
- """Newton's gravitational constant"""
15
15
 
16
+ # Speed of light
16
17
  C = 299792458.0 # m / s
17
- """Speed of light"""
18
18
 
19
- """Pi"""
19
+ # pi and 2pi
20
20
  PI = 3.141592653589793238462643383279502884
21
-
22
21
  TWO_PI = 6.283185307179586476925286766559005768
23
22
 
23
+ # G MSUN / C^3 in seconds
24
24
  gt = G * MSUN / (C**3.0)
25
- """
26
- G MSUN / C^3 in seconds
27
- """
28
25
 
26
+ # 1 solar mass in seconds. Same value as lal.MTSUN_SI
29
27
  MTSUN_SI = 4.925490947641266978197229498498379006e-6
30
- """1 solar mass in seconds. Same value as lal.MTSUN_SI"""
31
28
 
29
+ # Meters per Mpc.
32
30
  m_per_Mpc = 3.085677581491367278913937957796471611e22
33
- """
34
- Meters per Mpc.
35
- """
36
31
 
32
+ # 1 Mpc in seconds.
37
33
  MPC_SEC = m_per_Mpc / C
38
- """
39
- 1 Mpc in seconds.
40
- """
41
34
 
35
+ # Speed of light in vacuum (:math:`c`), in gigaparsecs per second
42
36
  clightGpc = C / 3.0856778570831e22
43
- """
44
- Speed of light in vacuum (:math:`c`), in gigaparsecs per second
45
- """
@@ -441,7 +441,7 @@ def normalize_by_psd(
441
441
 
442
442
  # convert back to the time domain and normalize
443
443
  # TODO: what's this normalization factor?
444
- X = torch.fft.irfft(X_tilde, norm="forward", dim=-1)
444
+ X = torch.fft.irfft(X_tilde, n=X.shape[-1], norm="forward", dim=-1)
445
445
  X = X.float() / sample_rate**0.5
446
446
 
447
447
  # slice off corrupted data at edges of kernel
@@ -4,5 +4,6 @@ from .scaler import ChannelWiseScaler
4
4
  from .snr_rescaler import SnrRescaler
5
5
  from .spectral import SpectralDensity
6
6
  from .spectrogram import MultiResolutionSpectrogram
7
+ from .spline_interpolation import SplineInterpolate
7
8
  from .waveforms import WaveformProjector, WaveformSampler
8
9
  from .whitening import FixedWhiten, Whiten
@@ -1,11 +1,13 @@
1
1
  import math
2
- from typing import List, Optional, Tuple
2
+ import warnings
3
+ from typing import List, Tuple
3
4
 
4
5
  import torch
5
6
  import torch.nn.functional as F
6
7
  from jaxtyping import Float, Int
7
8
  from torch import Tensor
8
9
 
10
+ from ml4gw.transforms.spline_interpolation import SplineInterpolate
9
11
  from ml4gw.types import FrequencySeries1to3d, TimeSeries1to3d, TimeSeries3d
10
12
 
11
13
  """
@@ -38,7 +40,6 @@ class QTile(torch.nn.Module):
38
40
  mismatch:
39
41
  The maximum fractional mismatch between neighboring tiles
40
42
 
41
-
42
43
  """
43
44
 
44
45
  def __init__(
@@ -100,7 +101,9 @@ class QTile(torch.nn.Module):
100
101
  ).type(torch.long)
101
102
 
102
103
  def forward(
103
- self, fseries: FrequencySeries1to3d, norm: str = "median"
104
+ self,
105
+ fseries: FrequencySeries1to3d,
106
+ norm: str = "median",
104
107
  ) -> TimeSeries1to3d:
105
108
  """
106
109
  Compute the transform for this row
@@ -144,7 +147,7 @@ class QTile(torch.nn.Module):
144
147
  energy /= means
145
148
  else:
146
149
  raise ValueError("Invalid normalisation %r" % norm)
147
- return energy.type(torch.float32)
150
+ energy = energy.type(torch.float32)
148
151
  return energy
149
152
 
150
153
 
@@ -172,6 +175,19 @@ class SingleQTransform(torch.nn.Module):
172
175
  be chosen based on q, sample_rate, and duration
173
176
  mismatch:
174
177
  The maximum fractional mismatch between neighboring tiles
178
+ interpolation_method:
179
+ The method by which to interpolate each `QTile` to the specified
180
+ number of time and frequency bins. The acceptable values are
181
+ "bilinear", "bicubic", and "spline". The "bilinear" and "bicubic"
182
+ options will use PyTorch's built-in interpolation modes, while
183
+ "spline" will use the custom Torch-based implementation in
184
+ `ml4gw`, as PyTorch does not have spline-based interpolation.
185
+ The "spline" mode is most similar to the results of GWpy's
186
+ Q-transform, which uses `scipy` to do spline interpolation.
187
+ However, it is also the slowest and most memory intensive due to
188
+ the matrix equation solving steps. Therefore, the default method
189
+ is "bicubic" as it produces the most similar results while
190
+ optimizing for computing performance.
175
191
  """
176
192
 
177
193
  def __init__(
@@ -182,6 +198,7 @@ class SingleQTransform(torch.nn.Module):
182
198
  q: float = 12,
183
199
  frange: List[float] = [0, torch.inf],
184
200
  mismatch: float = 0.2,
201
+ interpolation_method: str = "bicubic",
185
202
  ) -> None:
186
203
  super().__init__()
187
204
  self.q = q
@@ -190,20 +207,87 @@ class SingleQTransform(torch.nn.Module):
190
207
  self.duration = duration
191
208
  self.mismatch = mismatch
192
209
 
210
+ # If q is too large, the minimum of the frange computed
211
+ # below will be larger than the maximum
212
+ max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
213
+ if q >= max_q:
214
+ raise ValueError(
215
+ "The given q value is too large for the given duration and "
216
+ f"sample rate. The maximum allowable value is {max_q}"
217
+ )
218
+
219
+ if interpolation_method not in ["bilinear", "bicubic", "spline"]:
220
+ raise ValueError(
221
+ "Interpolation method must be either 'bilinear', 'bicubic', "
222
+ f"or 'spline'; got {interpolation_method}"
223
+ )
224
+ self.interpolation_method = interpolation_method
225
+
193
226
  qprime = self.q / 11 ** (1 / 2.0)
194
227
  if self.frange[0] <= 0: # set non-zero lower frequency
195
228
  self.frange[0] = 50 * self.q / (2 * torch.pi * duration)
196
229
  if math.isinf(self.frange[1]): # set non-infinite upper frequency
197
230
  self.frange[1] = sample_rate / 2 / (1 + 1 / qprime)
231
+
198
232
  self.freqs = self.get_freqs()
199
233
  self.qtile_transforms = torch.nn.ModuleList(
200
234
  [
201
- QTile(self.q, freq, self.duration, sample_rate, self.mismatch)
235
+ QTile(
236
+ q=self.q,
237
+ frequency=freq,
238
+ duration=self.duration,
239
+ sample_rate=sample_rate,
240
+ mismatch=self.mismatch,
241
+ )
202
242
  for freq in self.freqs
203
243
  ]
204
244
  )
205
245
  self.qtiles = None
206
246
 
247
+ if self.interpolation_method == "spline":
248
+ self._set_up_spline_interp()
249
+
250
+ def _set_up_spline_interp(self):
251
+ ntiles = [qtile.ntiles() for qtile in self.qtile_transforms]
252
+ # For efficiency, we'll stack all qtiles of the same length before
253
+ # interpolating, so we need to figure out which those are
254
+ unique_ntiles = sorted(list(set(ntiles)))
255
+ idx = torch.arange(len(ntiles))
256
+ self.stack_idx = [idx[Tensor(ntiles) == n] for n in unique_ntiles]
257
+
258
+ t_out = torch.arange(
259
+ 0, self.duration, self.duration / self.spectrogram_shape[1]
260
+ )
261
+ self.qtile_interpolators = torch.nn.ModuleList(
262
+ [
263
+ SplineInterpolate(
264
+ kx=3,
265
+ x_in=torch.arange(0, self.duration, self.duration / tiles),
266
+ y_in=torch.arange(len(idx)),
267
+ x_out=t_out,
268
+ y_out=torch.arange(len(idx)),
269
+ )
270
+ for tiles, idx in zip(unique_ntiles, self.stack_idx)
271
+ ]
272
+ )
273
+
274
+ t_in = t_out
275
+ f_in = self.freqs
276
+ f_out = torch.logspace(
277
+ math.log10(self.frange[0]),
278
+ math.log10(self.frange[-1]),
279
+ self.spectrogram_shape[0],
280
+ )
281
+
282
+ self.interpolator = SplineInterpolate(
283
+ kx=3,
284
+ ky=3,
285
+ x_in=t_in,
286
+ y_in=f_in,
287
+ x_out=t_out,
288
+ y_out=f_out,
289
+ )
290
+
207
291
  def get_freqs(self) -> Float[Tensor, " nfreq"]:
208
292
  """
209
293
  Calculate the frequencies that will be used in this transform.
@@ -220,7 +304,8 @@ class SingleQTransform(torch.nn.Module):
220
304
 
221
305
  freq_base = math.exp(2 / ((2 + self.q**2) ** (1 / 2.0)) * fstep)
222
306
  freqs = torch.Tensor([freq_base ** (i + 0.5) for i in range(nfreq)])
223
- freqs = (minf * freqs // fstepmin) * fstepmin
307
+ # Cast freqs to float64 to avoid off-by-ones from rounding
308
+ freqs = (minf * freqs.double() // fstepmin) * fstepmin
224
309
  return torch.unique(freqs)
225
310
 
226
311
  def get_max_energy(
@@ -268,7 +353,11 @@ class SingleQTransform(torch.nn.Module):
268
353
  if dimension == "batch":
269
354
  return torch.max(max_across_ft, dim=-1).values
270
355
 
271
- def compute_qtiles(self, X: TimeSeries1to3d, norm: str = "median") -> None:
356
+ def compute_qtiles(
357
+ self,
358
+ X: TimeSeries1to3d,
359
+ norm: str = "median",
360
+ ) -> None:
272
361
  """
273
362
  Take the FFT of the input timeseries and calculate the transform
274
363
  for each `QTile`
@@ -278,28 +367,40 @@ class SingleQTransform(torch.nn.Module):
278
367
  X[..., 1:] *= 2
279
368
  self.qtiles = [qtile(X, norm) for qtile in self.qtile_transforms]
280
369
 
281
- def interpolate(self, num_f_bins: int, num_t_bins: int) -> TimeSeries3d:
282
- """
283
- Interpolate each `QTile` to the specified number of time and
284
- frequency bins. Note that PyTorch does not have the same
285
- interpolation methods that GWpy uses, and so the interpolated
286
- spectrograms will be different even though the uninterpolated
287
- values match. The `bicubic` interpolation method is used as
288
- it seems to match GWpy most closely.
289
- """
370
+ def interpolate(self) -> TimeSeries3d:
290
371
  if self.qtiles is None:
291
372
  raise RuntimeError(
292
373
  "Q-tiles must first be computed with .compute_qtiles()"
293
374
  )
375
+ if self.interpolation_method == "spline":
376
+ qtiles = [
377
+ torch.stack([self.qtiles[i] for i in idx], dim=-2)
378
+ for idx in self.stack_idx
379
+ ]
380
+ time_interped = torch.cat(
381
+ [
382
+ interpolator(qtile)
383
+ for qtile, interpolator in zip(
384
+ qtiles, self.qtile_interpolators
385
+ )
386
+ ],
387
+ dim=-2,
388
+ )
389
+ return self.interpolator(time_interped)
390
+ num_f_bins, num_t_bins = self.spectrogram_shape
294
391
  resampled = [
295
392
  F.interpolate(
296
- qtile[None], (qtile.shape[-2], num_t_bins), mode="bicubic"
393
+ qtile[None],
394
+ (qtile.shape[-2], num_t_bins),
395
+ mode=self.interpolation_method,
297
396
  )
298
397
  for qtile in self.qtiles
299
398
  ]
300
399
  resampled = torch.stack(resampled, dim=-2)
301
400
  resampled = F.interpolate(
302
- resampled[0], (num_f_bins, num_t_bins), mode="bicubic"
401
+ resampled[0],
402
+ (num_f_bins, num_t_bins),
403
+ mode=self.interpolation_method,
303
404
  )
304
405
  return torch.squeeze(resampled)
305
406
 
@@ -307,7 +408,6 @@ class SingleQTransform(torch.nn.Module):
307
408
  self,
308
409
  X: TimeSeries1to3d,
309
410
  norm: str = "median",
310
- spectrogram_shape: Optional[Tuple[int, int]] = None,
311
411
  ):
312
412
  """
313
413
  Compute the Q-tiles and interpolate
@@ -321,24 +421,15 @@ class SingleQTransform(torch.nn.Module):
321
421
  three-dimensional, axes will be added during Q-tile
322
422
  computation.
323
423
  norm:
324
- The method of interpolation used by each QTile
325
- spectrogram_shape:
326
- The shape of the interpolated spectrogram, specified as
327
- `(num_f_bins, num_t_bins)`. Because the
328
- frequency spacing of the Q-tiles is in log-space, the frequency
329
- interpolation is log-spaced as well. If not given, the shape
330
- used to initialize the transform will be used.
424
+ The method of normalization used by each QTile
331
425
 
332
426
  Returns:
333
427
  The interpolated Q-transform for the batch of data. Output will
334
428
  have one more dimension than the input
335
429
  """
336
430
 
337
- if spectrogram_shape is None:
338
- spectrogram_shape = self.spectrogram_shape
339
- num_f_bins, num_t_bins = spectrogram_shape
340
431
  self.compute_qtiles(X, norm)
341
- return self.interpolate(num_f_bins, num_t_bins)
432
+ return self.interpolate()
342
433
 
343
434
 
344
435
  class QScan(torch.nn.Module):
@@ -376,14 +467,22 @@ class QScan(torch.nn.Module):
376
467
  spectrogram_shape: Tuple[int, int],
377
468
  qrange: List[float] = [4, 64],
378
469
  frange: List[float] = [0, torch.inf],
470
+ interpolation_method="bicubic",
379
471
  mismatch: float = 0.2,
380
472
  ) -> None:
381
473
  super().__init__()
382
474
  self.qrange = qrange
383
475
  self.mismatch = mismatch
384
- self.qs = self.get_qs()
385
476
  self.frange = frange
386
477
  self.spectrogram_shape = spectrogram_shape
478
+ max_q = torch.pi * duration * sample_rate / 50 - 11 ** (0.5)
479
+ self.qs = self.get_qs()
480
+ if self.qs[-1] >= max_q:
481
+ warnings.warn(
482
+ "Some Q values exceed the maximum allowable Q value of "
483
+ f"{max_q}. The list of Q values to be tested in this "
484
+ "scan will be truncated to avoid those values."
485
+ )
387
486
 
388
487
  # Deliberately doing something different from GWpy here.
389
488
  # Their final frange is the intersection of the frange
@@ -397,9 +496,11 @@ class QScan(torch.nn.Module):
397
496
  spectrogram_shape=spectrogram_shape,
398
497
  q=q,
399
498
  frange=self.frange.copy(),
499
+ interpolation_method=interpolation_method,
400
500
  mismatch=self.mismatch,
401
501
  )
402
502
  for q in self.qs
503
+ if q < max_q
403
504
  ]
404
505
  )
405
506
 
@@ -415,6 +516,7 @@ class QScan(torch.nn.Module):
415
516
  self.qrange[0] * math.exp(2 ** (1 / 2.0) * dq * (i + 0.5))
416
517
  for i in range(nplanes)
417
518
  ]
519
+
418
520
  return qs
419
521
 
420
522
  def forward(
@@ -422,7 +524,6 @@ class QScan(torch.nn.Module):
422
524
  X: TimeSeries1to3d,
423
525
  fsearch_range: List[float] = None,
424
526
  norm: str = "median",
425
- spectrogram_shape: Optional[Tuple[int, int]] = None,
426
527
  ):
427
528
  """
428
529
  Compute the set of QTiles for each Q transform and determine which
@@ -442,12 +543,6 @@ class QScan(torch.nn.Module):
442
543
  for the maximum energy
443
544
  norm:
444
545
  The method of normalization used by each QTile
445
- spectrogram_shape:
446
- The shape of the interpolated spectrogram, specified as
447
- `(num_f_bins, num_t_bins)`. Because the
448
- frequency spacing of the Q-tiles is in log-space, the frequency
449
- interpolation is log-spaced as well. If not given, the shape
450
- used to initialize the transform will be used.
451
546
 
452
547
  Returns:
453
548
  An interpolated Q-transform for the batch of data. Output will
@@ -463,7 +558,4 @@ class QScan(torch.nn.Module):
463
558
  ]
464
559
  )
465
560
  )
466
- if spectrogram_shape is None:
467
- spectrogram_shape = self.spectrogram_shape
468
- num_f_bins, num_t_bins = spectrogram_shape
469
- return self.q_transforms[idx].interpolate(num_f_bins, num_t_bins)
561
+ return self.q_transforms[idx].interpolate()
@@ -36,7 +36,9 @@ class ChannelWiseScaler(FittableTransform):
36
36
  self.register_buffer("mean", mean)
37
37
  self.register_buffer("std", std)
38
38
 
39
- def fit(self, X: Float[Tensor, "... time"]) -> None:
39
+ def fit(
40
+ self, X: Float[Tensor, "... time"], std_reg: Optional[float] = 0.0
41
+ ) -> None:
40
42
  """Fit the scaling parameters to a timeseries
41
43
 
42
44
  Computes the channel-wise mean and standard deviation
@@ -59,7 +61,7 @@ class ChannelWiseScaler(FittableTransform):
59
61
  "Can't fit channel wise mean and standard deviation "
60
62
  "from tensor of shape {}".format(X.shape)
61
63
  )
62
-
64
+ std += std_reg * torch.ones_like(std)
63
65
  super().build(mean=mean, std=std)
64
66
 
65
67
  def forward(