pywavelet 0.0.5__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.0.5'
16
- __version_tuple__ = version_tuple = (0, 0, 5)
15
+ __version__ = version = '0.1.1'
16
+ __version_tuple__ = version_tuple = (0, 1, 1)
pywavelet/logger.py CHANGED
@@ -1,11 +1,15 @@
1
+ import logging
1
2
  import sys
2
3
  import warnings
4
+
3
5
  from rich.logging import RichHandler
4
- import logging
5
6
 
6
7
  FORMAT = "%(message)s"
7
8
  logging.basicConfig(
8
- level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
9
+ level="INFO",
10
+ format=FORMAT,
11
+ datefmt="[%X]",
12
+ handlers=[RichHandler(rich_tracebacks=True)],
9
13
  )
10
14
 
11
15
  logger = logging.getLogger("pywavelet")
@@ -1,4 +1,4 @@
1
- from .forward import from_freq_to_wavelet, from_time_to_wavelet, compute_bins
1
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
2
2
  from .inverse import from_wavelet_to_freq, from_wavelet_to_time
3
3
 
4
4
  __all__ = [
@@ -6,5 +6,4 @@ __all__ = [
6
6
  "from_wavelet_to_freq",
7
7
  "from_time_to_wavelet",
8
8
  "from_freq_to_wavelet",
9
- "compute_bins",
10
9
  ]
@@ -1,4 +1,3 @@
1
1
  from .main import from_freq_to_wavelet, from_time_to_wavelet
2
- from .wavelet_bins import compute_bins
3
2
 
4
- __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet", "compute_bins"]
3
+ __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -1,8 +1,10 @@
1
1
  """helper functions for transform_freq"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
5
6
 
7
+
6
8
  def transform_wavelet_freq_helper(
7
9
  data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
8
10
  ) -> np.ndarray:
@@ -13,8 +15,16 @@ def transform_wavelet_freq_helper(
13
15
  __core(Nf, Nt, DX, freq_strain, phif, wave)
14
16
  return wave
15
17
 
18
+
16
19
  # @njit()
17
- def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarray, wave:np.ndarray):
20
+ def __core(
21
+ Nf: int,
22
+ Nt: int,
23
+ DX: np.ndarray,
24
+ freq_strain: np.ndarray,
25
+ phif: np.ndarray,
26
+ wave: np.ndarray,
27
+ ):
18
28
  for f_bin in range(0, Nf + 1):
19
29
  __fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
20
30
  # Numba doesn't support np.ifft so we cant jit this
@@ -22,8 +32,6 @@ def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarra
22
32
  __fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
23
33
 
24
34
 
25
-
26
-
27
35
  @njit()
28
36
  def __fill_wave_1(
29
37
  f_bin: int,
@@ -55,6 +63,7 @@ def __fill_wave_1(
55
63
  else:
56
64
  DX[i] = phif[j] * data[jj]
57
65
 
66
+
58
67
  @njit()
59
68
  def __fill_wave_2(
60
69
  f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
@@ -72,7 +81,7 @@ def __fill_wave_2(
72
81
  if (n + f_bin) % 2:
73
82
  wave[n, f_bin] = -DX_trans[n].imag
74
83
  else:
75
- wave[n, f_bin] = DX_trans[n].real
84
+ wave[n, f_bin] = DX_trans[n].real
76
85
  else:
77
86
  if (n + f_bin) % 2:
78
87
  wave[n, f_bin] = DX_trans[n].imag
@@ -1,4 +1,5 @@
1
1
  """helper functions for transform_time.py"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -20,7 +21,18 @@ def transform_wavelet_time_helper(
20
21
  __core(Nf, Nt, K, ND, wdata, data_pad, phi, wave, mult)
21
22
  return wave
22
23
 
23
- def __core(Nf: int, Nt: int, K: int, ND: int, wdata: np.ndarray, data_pad: np.ndarray, phi: np.ndarray, wave: np.ndarray, mult: int) -> None:
24
+
25
+ def __core(
26
+ Nf: int,
27
+ Nt: int,
28
+ K: int,
29
+ ND: int,
30
+ wdata: np.ndarray,
31
+ data_pad: np.ndarray,
32
+ phi: np.ndarray,
33
+ wave: np.ndarray,
34
+ mult: int,
35
+ ) -> None:
24
36
  for time_bin_i in range(0, Nt):
25
37
  __fill_wave_1(time_bin_i, K, ND, Nf, wdata, data_pad, phi)
26
38
  wdata_trans = np.fft.rfft(wdata, K)
@@ -3,15 +3,15 @@ from typing import Union
3
3
  import numpy as np
4
4
 
5
5
  from ...logger import logger
6
+ from ...types import FrequencySeries, TimeSeries, Wavelet
7
+ from ...types.wavelet_bins import _get_bins, _preprocess_bins
6
8
  from ..phi_computer import phi_vec, phitilde_vec_norm
7
- from ..types import FrequencySeries, TimeSeries, Wavelet
8
9
  from .from_freq import transform_wavelet_freq_helper
9
10
  from .from_time import transform_wavelet_time_helper
10
- from .wavelet_bins import _get_bins, _preprocess_bins
11
-
12
11
 
13
12
  __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
14
13
 
14
+
15
15
  def from_time_to_wavelet(
16
16
  timeseries: TimeSeries,
17
17
  Nf: Union[int, None] = None,
@@ -74,9 +74,7 @@ def from_time_to_wavelet(
74
74
  mult = min(mult, Nt // 2) # Ensure mult is not larger than ND/2
75
75
  phi = phi_vec(Nf, dt=dt, d=nx, q=mult)
76
76
  wave = transform_wavelet_time_helper(timeseries.data, Nf, Nt, phi, mult).T
77
- return Wavelet(
78
- wave * np.sqrt(2), time=t_bins, freq=f_bins
79
- )
77
+ return Wavelet(wave * np.sqrt(2), time=t_bins, freq=f_bins)
80
78
 
81
79
 
82
80
  def from_freq_to_wavelet(
@@ -117,12 +115,6 @@ def from_freq_to_wavelet(
117
115
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
118
116
  dt = freqseries.dt
119
117
  phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)
120
- wave = transform_wavelet_freq_helper(
121
- freqseries.data, Nf, Nt, phif
122
- )
123
-
124
- return Wavelet(
125
- (2 / Nf) * wave.T * np.sqrt(2),
126
- time=t_bins,
127
- freq=f_bins
128
- )
118
+ wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
119
+
120
+ return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
 
3
3
  from ...transforms.phi_computer import phi_vec, phitilde_vec_norm
4
- from ..types import FrequencySeries, TimeSeries, Wavelet
4
+ from ...types import FrequencySeries, TimeSeries, Wavelet
5
5
  from .to_freq import inverse_wavelet_freq_helper_fast
6
6
  from .to_time import inverse_wavelet_time_helper_fast
7
7
 
@@ -9,6 +9,7 @@ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
9
9
 
10
10
  INV_ROOT2 = 1.0 / np.sqrt(2)
11
11
 
12
+
12
13
  def from_wavelet_to_time(
13
14
  wave_in: Wavelet,
14
15
  dt: float,
@@ -55,9 +56,7 @@ def from_wavelet_to_time(
55
56
 
56
57
 
57
58
  def from_wavelet_to_freq(
58
- wave_in: Wavelet,
59
- dt: float,
60
- nx:float=4.0
59
+ wave_in: Wavelet, dt: float, nx: float = 4.0
61
60
  ) -> FrequencySeries:
62
61
  """
63
62
  Perform an inverse wavelet transform to the frequency domain.
@@ -1,4 +1,5 @@
1
1
  """functions for computing the inverse wavelet transforms"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -12,13 +13,20 @@ def inverse_wavelet_freq_helper_fast(
12
13
  ND = Nf * Nt
13
14
 
14
15
  prefactor2s = np.zeros(Nt, np.complex128)
15
- res = np.zeros(ND//2 +1, dtype=np.complex128)
16
+ res = np.zeros(ND // 2 + 1, dtype=np.complex128)
16
17
  __core(Nf, Nt, prefactor2s, wave_in, phif, res)
17
18
 
18
-
19
19
  return res
20
20
 
21
- def __core(Nf: int, Nt: int, prefactor2s: np.ndarray, wave_in: np.ndarray, phif: np.ndarray, res: np.ndarray) -> None:
21
+
22
+ def __core(
23
+ Nf: int,
24
+ Nt: int,
25
+ prefactor2s: np.ndarray,
26
+ wave_in: np.ndarray,
27
+ phif: np.ndarray,
28
+ res: np.ndarray,
29
+ ) -> None:
22
30
  for m in range(0, Nf + 1):
23
31
  __pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
24
32
  fft_prefactor2s = np.fft.fft(prefactor2s)
@@ -1,4 +1,5 @@
1
1
  """functions for computing the inverse wavelet transforms"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -21,7 +22,16 @@ def inverse_wavelet_time_helper_fast(
21
22
  return res[:ND]
22
23
 
23
24
 
24
- def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarray, res: np.ndarray, afins:np.ndarray) -> None:
25
+ def __core(
26
+ Nf: int,
27
+ Nt: int,
28
+ K: int,
29
+ ND: int,
30
+ wave_in: np.ndarray,
31
+ phi: np.ndarray,
32
+ res: np.ndarray,
33
+ afins: np.ndarray,
34
+ ) -> None:
25
35
  for n in range(0, Nt):
26
36
  if n % 2 == 0:
27
37
  pack_wave_time_helper_compact(n, Nf, Nt, wave_in, afins)
@@ -29,9 +39,9 @@ def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarr
29
39
  unpack_time_wave_helper_compact(n, Nf, Nt, K, phi, ffts_fin, res)
30
40
 
31
41
  # wrap boundary conditions
32
- res[: min(K + Nf, ND)] += res[ND: min(ND + K + Nf, 2 * ND)]
42
+ res[: min(K + Nf, ND)] += res[ND : min(ND + K + Nf, 2 * ND)]
33
43
  if K + Nf > ND:
34
- res[: K + Nf - ND] += res[2 * ND: ND + K * Nf]
44
+ res[: K + Nf - ND] += res[2 * ND : ND + K * Nf]
35
45
 
36
46
 
37
47
  def unpack_time_wave_helper(
@@ -1,3 +1,3 @@
1
1
  from .frequencyseries import FrequencySeries
2
2
  from .timeseries import TimeSeries
3
- from .wavelet import Wavelet
3
+ from .wavelet import Wavelet, WaveletMask
@@ -1,9 +1,9 @@
1
- from typing import Literal, Tuple
1
+ from typing import Tuple, Union
2
2
 
3
3
  import numpy as xp
4
- from numpy.fft import irfft, fft, rfft, rfftfreq
4
+ from numpy.fft import fft, irfft, rfft, rfftfreq # type: ignore
5
5
 
6
- from ...logger import logger
6
+ from ..logger import logger
7
7
 
8
8
 
9
9
  def _len_check(d):
@@ -19,7 +19,7 @@ def is_documented_by(original):
19
19
  return wrapper
20
20
 
21
21
 
22
- def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
22
+ def fmt_time(seconds: float, units=False) -> Union[str, Tuple[str, str]]:
23
23
  """Returns formatted time and units [ms, s, min, hr, day]"""
24
24
  t, u = "", ""
25
25
  if seconds < 1e-3:
@@ -42,12 +42,12 @@ def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
42
42
 
43
43
  def fmt_timerange(trange):
44
44
  t0 = fmt_time(trange[0])
45
- tend, units = fmt_time(trange[1], units = True)
45
+ tend, units = fmt_time(trange[1], units=True)
46
46
  return f"[{t0}, {tend}] {units}"
47
47
 
48
48
 
49
- def fmt_pow2(n:float)->str:
49
+ def fmt_pow2(n: float) -> str:
50
50
  pow2 = xp.log2(n)
51
51
  if pow2.is_integer():
52
52
  return f"2^{int(pow2)}"
53
- return f"{n:,}"
53
+ return f"{n:,}"
@@ -1,11 +1,13 @@
1
+ from typing import Optional, Tuple, Union
2
+
1
3
  import matplotlib.pyplot as plt
2
- from typing import Tuple, Union, Optional
3
4
 
4
- from .common import is_documented_by, xp, irfft, fmt_time, fmt_pow2
5
+ from .common import fmt_pow2, fmt_time, irfft, is_documented_by, xp
5
6
  from .plotting import plot_freqseries, plot_periodogram
6
7
 
7
8
  __all__ = ["FrequencySeries"]
8
9
 
10
+
9
11
  class FrequencySeries:
10
12
  """
11
13
  A class to represent a one-sided frequency series, with various methods for
@@ -39,9 +41,13 @@ class FrequencySeries:
39
41
  If any frequency is negative or if `data` and `freq` do not have the same length.
40
42
  """
41
43
  if xp.any(freq < 0):
42
- raise ValueError("FrequencySeries must be one-sided (only non-negative frequencies)")
44
+ raise ValueError(
45
+ "FrequencySeries must be one-sided (only non-negative frequencies)"
46
+ )
43
47
  if len(data) != len(freq):
44
- raise ValueError(f"data and freq must have the same length ({len(data)} != {len(freq)})")
48
+ raise ValueError(
49
+ f"data and freq must have the same length ({len(data)} != {len(freq)})"
50
+ )
45
51
  self.data = data
46
52
  self.freq = freq
47
53
  self.t0 = t0
@@ -53,10 +59,10 @@ class FrequencySeries:
53
59
  )
54
60
 
55
61
  @is_documented_by(plot_periodogram)
56
- def plot_periodogram(self, ax=None, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
57
- return plot_periodogram(
58
- self.data, self.freq, self.fs, ax=ax, **kwargs
59
- )
62
+ def plot_periodogram(
63
+ self, ax=None, **kwargs
64
+ ) -> Tuple[plt.Figure, plt.Axes]:
65
+ return plot_periodogram(self.data, self.freq, self.fs, ax=ax, **kwargs)
60
66
 
61
67
  def __len__(self):
62
68
  """Return the length of the frequency series."""
@@ -127,7 +133,9 @@ class FrequencySeries:
127
133
  n = fmt_pow2(len(self))
128
134
  return f"FrequencySeries(n={n}, frange=[{self.range[0]:.2f}, {self.range[1]:.2f}] Hz, T={dur}, fs={self.fs:.2f} Hz)"
129
135
 
130
- def noise_weighted_inner_product(self, other: "FrequencySeries", psd:"FrequencySeries") -> float:
136
+ def noise_weighted_inner_product(
137
+ self, other: "FrequencySeries", psd: "FrequencySeries"
138
+ ) -> float:
131
139
  """
132
140
  Compute the noise-weighted inner product of two FrequencySeries.
133
141
 
@@ -144,9 +152,11 @@ class FrequencySeries:
144
152
  The noise-weighted inner product of the two FrequencySeries.
145
153
  """
146
154
  integrand = xp.real(xp.conj(self.data) * other.data / psd.data)
147
- return (4 * self.dt/self.ND) * xp.nansum(integrand)
155
+ return (4 * self.dt / self.ND) * xp.nansum(integrand)
148
156
 
149
- def matched_filter_snr(self, other: "FrequencySeries", psd: "FrequencySeries") -> float:
157
+ def matched_filter_snr(
158
+ self, other: "FrequencySeries", psd: "FrequencySeries"
159
+ ) -> float:
150
160
  """
151
161
  Compute the signal-to-noise ratio (SNR) of a matched filter.
152
162
 
@@ -199,15 +209,15 @@ class FrequencySeries:
199
209
 
200
210
  # Create and return a TimeSeries object
201
211
  from .timeseries import TimeSeries
202
- return TimeSeries(time_data, time)
203
212
 
213
+ return TimeSeries(time_data, time)
204
214
 
205
215
  def to_wavelet(
206
- self,
207
- Nf: Union[int, None] = None,
208
- Nt: Union[int, None] = None,
209
- nx: Optional[float] = 4.0,
210
- )->"Wavelet":
216
+ self,
217
+ Nf: Union[int, None] = None,
218
+ Nt: Union[int, None] = None,
219
+ nx: Optional[float] = 4.0,
220
+ ) -> "Wavelet":
211
221
  """
212
222
  Convert the frequency series to a wavelet using inverse Fourier transform.
213
223
 
@@ -216,9 +226,9 @@ class FrequencySeries:
216
226
  Wavelet
217
227
  The corresponding wavelet.
218
228
  """
219
- from ..forward import from_freq_to_wavelet
220
- return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
229
+ from ..transforms.forward import from_freq_to_wavelet
221
230
 
231
+ return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
222
232
 
223
233
  def __eq__(self, other):
224
234
  """Check if two FrequencySeries objects are equal."""
@@ -228,10 +238,8 @@ class FrequencySeries:
228
238
 
229
239
  def __copy__(self):
230
240
  return FrequencySeries(
231
- xp.copy(self.data),
232
- xp.copy(self.freq),
233
- t0=self.t0
241
+ xp.copy(self.data), xp.copy(self.freq), t0=self.t0
234
242
  )
235
243
 
236
244
  def copy(self):
237
- return self.__copy__()
245
+ return self.__copy__()
@@ -1,12 +1,11 @@
1
1
  import warnings
2
- from typing import Tuple, Optional, Union
3
- from scipy.signal import savgol_filter
4
- from scipy.interpolate import interp1d
2
+ from typing import Optional, Tuple, Union
5
3
 
6
4
  import matplotlib.pyplot as plt
7
5
  import numpy as np
8
6
  from matplotlib.colors import LogNorm, TwoSlopeNorm
9
- from scipy.signal import spectrogram
7
+ from scipy.interpolate import interp1d
8
+ from scipy.signal import savgol_filter, spectrogram
10
9
 
11
10
  MIN_S = 60
12
11
  HOUR_S = 60 * MIN_S
@@ -53,8 +52,13 @@ def __get_smoothed_y(x, z, y_grid):
53
52
  # Interpolate to fill NaNs in y before smoothing
54
53
  nan_mask = ~np.isnan(y)
55
54
  if np.isnan(y).any():
56
- interpolator = interp1d(x[nan_mask], y[nan_mask], kind='cubic', bounds_error=False,
57
- fill_value="extrapolate")
55
+ interpolator = interp1d(
56
+ x[nan_mask],
57
+ y[nan_mask],
58
+ kind="cubic",
59
+ bounds_error=False,
60
+ fill_value="extrapolate",
61
+ )
58
62
  y = interpolator(x) # Fill NaNs with interpolated values
59
63
 
60
64
  # Smooth the curve
@@ -64,8 +68,6 @@ def __get_smoothed_y(x, z, y_grid):
64
68
  return y
65
69
 
66
70
 
67
-
68
-
69
71
  def plot_wavelet_grid(
70
72
  wavelet_data: np.ndarray,
71
73
  time_grid: np.ndarray,
@@ -80,8 +82,8 @@ def plot_wavelet_grid(
80
82
  norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
81
83
  cbar_label: Optional[str] = None,
82
84
  nan_color: Optional[str] = "black",
83
- detailed_axes:bool = False,
84
- show_gridinfo:bool = True,
85
+ detailed_axes: bool = False,
86
+ show_gridinfo: bool = True,
85
87
  trend_color: Optional[str] = None,
86
88
  whiten_by: Optional[np.ndarray] = None,
87
89
  **kwargs,
@@ -153,7 +155,9 @@ def plot_wavelet_grid(
153
155
 
154
156
  # Validate the dimensions
155
157
  if (Nf, Nt) != (len(freq_grid), len(time_grid)):
156
- raise ValueError(f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}.")
158
+ raise ValueError(
159
+ f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}."
160
+ )
157
161
 
158
162
  # Prepare the data for plotting
159
163
  z = wavelet_data.copy()
@@ -162,14 +166,15 @@ def plot_wavelet_grid(
162
166
  if absolute:
163
167
  z = np.abs(z)
164
168
 
165
-
166
169
  # Determine normalization and colormap
167
170
  if norm is None:
168
171
  try:
169
172
  if np.all(np.isnan(z)):
170
173
  raise ValueError("All wavelet data is NaN.")
171
174
  if zscale == "log":
172
- norm = LogNorm(vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z<np.inf]))
175
+ norm = LogNorm(
176
+ vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z < np.inf])
177
+ )
173
178
  elif not absolute:
174
179
  vmin, vmax = np.nanmin(z), np.nanmax(z)
175
180
  vcenter = 0.0
@@ -177,7 +182,9 @@ def plot_wavelet_grid(
177
182
  else:
178
183
  norm = None # Default linear scaling
179
184
  except Exception as e:
180
- warnings.warn(f"Error in determining normalization: {e}. Using default linear scaling.")
185
+ warnings.warn(
186
+ f"Error in determining normalization: {e}. Using default linear scaling."
187
+ )
181
188
  norm = None
182
189
 
183
190
  if cmap is None:
@@ -195,7 +202,7 @@ def plot_wavelet_grid(
195
202
  im = ax.imshow(
196
203
  z,
197
204
  aspect="auto",
198
- extent=[time_grid[0],time_grid[-1], freq_grid[0], freq_grid[-1]],
205
+ extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
199
206
  origin="lower",
200
207
  cmap=cmap,
201
208
  norm=norm,
@@ -203,13 +210,25 @@ def plot_wavelet_grid(
203
210
  **kwargs,
204
211
  )
205
212
  if trend_color is not None:
206
- plot_wavelet_trend(wavelet_data, time_grid, freq_grid, ax, color=trend_color, freq_range=freq_range, freq_scale=freq_scale)
213
+ plot_wavelet_trend(
214
+ wavelet_data,
215
+ time_grid,
216
+ freq_grid,
217
+ ax,
218
+ color=trend_color,
219
+ freq_range=freq_range,
220
+ freq_scale=freq_scale,
221
+ )
207
222
 
208
223
  # Add colorbar if requested
209
224
  if show_colorbar:
210
225
  cbar = fig.colorbar(im, ax=ax)
211
226
  if cbar_label is None:
212
- cbar_label = "Absolute Wavelet Amplitude" if absolute else "Wavelet Amplitude"
227
+ cbar_label = (
228
+ "Absolute Wavelet Amplitude"
229
+ if absolute
230
+ else "Wavelet Amplitude"
231
+ )
213
232
  cbar.set_label(cbar_label)
214
233
 
215
234
  # Configure axes scales
@@ -239,14 +258,12 @@ def plot_wavelet_grid(
239
258
  bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
240
259
  )
241
260
 
242
-
243
261
  # Adjust layout
244
262
  fig.tight_layout()
245
263
 
246
264
  return fig, ax
247
265
 
248
266
 
249
-
250
267
  def plot_freqseries(
251
268
  data: np.ndarray,
252
269
  freq: np.ndarray,
@@ -277,9 +294,10 @@ def plot_periodogram(
277
294
  flow = np.min(np.abs(freq))
278
295
  ax.set_xlabel("Frequency [Hz]")
279
296
  ax.set_ylabel("Periodigram")
280
- ax.set_xlim(left=flow, right=nyquist_frequency/2)
297
+ ax.set_xlim(left=flow, right=nyquist_frequency / 2)
281
298
  return ax.figure, ax
282
299
 
300
+
283
301
  def plot_timeseries(
284
302
  data: np.ndarray, time: np.ndarray, ax=None, **kwargs
285
303
  ) -> Tuple[plt.Figure, plt.Axes]:
@@ -314,7 +332,6 @@ def plot_spectrogram(
314
332
 
315
333
  _fmt_time_axis(t, ax)
316
334
 
317
-
318
335
  ax.set_ylabel("Frequency [Hz]")
319
336
  ax.set_ylim(top=fs / 2.0)
320
337
  cbar = plt.colorbar(cm, ax=ax)
@@ -322,20 +339,24 @@ def plot_spectrogram(
322
339
  return ax.figure, ax
323
340
 
324
341
 
325
-
326
342
  def _fmt_time_axis(t, axes, t0=None, tmax=None):
327
343
  if t[-1] > DAY_S: # If time goes beyond a day
328
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}"))
344
+ axes.xaxis.set_major_formatter(
345
+ plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}")
346
+ )
329
347
  axes.set_xlabel("Time [days]")
330
348
  elif t[-1] > HOUR_S: # If time goes beyond an hour
331
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}"))
349
+ axes.xaxis.set_major_formatter(
350
+ plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}")
351
+ )
332
352
  axes.set_xlabel("Time [hr]")
333
353
  elif t[-1] > MIN_S: # If time goes beyond a minute
334
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}"))
354
+ axes.xaxis.set_major_formatter(
355
+ plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}")
356
+ )
335
357
  axes.set_xlabel("Time [min]")
336
358
  else:
337
359
  axes.set_xlabel("Time [s]")
338
360
  t0 = t[0] if t0 is None else t0
339
361
  tmax = t[-1] if tmax is None else tmax
340
362
  axes.set_xlim(t0, tmax)
341
-
@@ -1,11 +1,20 @@
1
+ from typing import Optional, Tuple, Union
2
+
1
3
  import matplotlib.pyplot as plt
2
- from typing import Tuple, Optional, Union
3
- from scipy.signal.windows import tukey
4
4
  from scipy.signal import butter, sosfiltfilt
5
+ from scipy.signal.windows import tukey
5
6
 
6
- from ...logger import logger
7
- from .common import is_documented_by, xp, rfft, rfftfreq, fmt_timerange, fmt_time, fmt_pow2
8
- from .plotting import plot_timeseries, plot_spectrogram
7
+ from ..logger import logger
8
+ from .common import (
9
+ fmt_pow2,
10
+ fmt_time,
11
+ fmt_timerange,
12
+ is_documented_by,
13
+ rfft,
14
+ rfftfreq,
15
+ xp,
16
+ )
17
+ from .plotting import plot_spectrogram, plot_timeseries
9
18
 
10
19
  __all__ = ["TimeSeries"]
11
20
 
@@ -50,7 +59,7 @@ class TimeSeries:
50
59
 
51
60
  @is_documented_by(plot_spectrogram)
52
61
  def plot_spectrogram(
53
- self, ax=None, spec_kwargs={}, plot_kwargs={}
62
+ self, ax=None, spec_kwargs={}, plot_kwargs={}
54
63
  ) -> Tuple[plt.Figure, plt.Axes]:
55
64
  return plot_spectrogram(
56
65
  self.data,
@@ -122,9 +131,11 @@ class TimeSeries:
122
131
  trange = fmt_timerange((self.t0, self.tend))
123
132
  T = " ".join(fmt_time(self.duration, units=True))
124
133
  n = fmt_pow2(len(self))
125
- return f"TimeSeries(n={n}, trange={trange}, T={T}, fs={self.fs:.2f} Hz)"
134
+ return (
135
+ f"TimeSeries(n={n}, trange={trange}, T={T}, fs={self.fs:.2f} Hz)"
136
+ )
126
137
 
127
- def to_frequencyseries(self) -> 'FrequencySeries':
138
+ def to_frequencyseries(self) -> "FrequencySeries":
128
139
  """
129
140
  Convert the time series to a frequency series using the one-sided FFT.
130
141
 
@@ -137,14 +148,15 @@ class TimeSeries:
137
148
  data = rfft(self.data)
138
149
 
139
150
  from .frequencyseries import FrequencySeries # Avoid circular import
151
+
140
152
  return FrequencySeries(data, freq, t0=self.t0)
141
153
 
142
154
  def to_wavelet(
143
- self,
144
- Nf: Union[int, None] = None,
145
- Nt: Union[int, None] = None,
146
- nx: Optional[float] = 4.0,
147
- ) -> 'Wavelet':
155
+ self,
156
+ Nf: Union[int, None] = None,
157
+ Nt: Union[int, None] = None,
158
+ nx: Optional[float] = 4.0,
159
+ ) -> "Wavelet":
148
160
  """
149
161
  Convert the time series to a wavelet representation.
150
162
 
@@ -166,32 +178,37 @@ class TimeSeries:
166
178
  hf = self.to_frequencyseries()
167
179
  return hf.to_wavelet(Nf, Nt, nx=nx)
168
180
 
169
-
170
- def __add__(self, other: 'TimeSeries') -> 'TimeSeries':
181
+ def __add__(self, other: "TimeSeries") -> "TimeSeries":
171
182
  """Add two TimeSeries objects together."""
172
183
  if self.shape != other.shape:
173
- raise ValueError("TimeSeries objects must have the same shape to add them together")
184
+ raise ValueError(
185
+ "TimeSeries objects must have the same shape to add them together"
186
+ )
174
187
  return TimeSeries(self.data + other.data, self.time)
175
188
 
176
- def __sub__(self, other: 'TimeSeries') -> 'TimeSeries':
189
+ def __sub__(self, other: "TimeSeries") -> "TimeSeries":
177
190
  """Subtract one TimeSeries object from another."""
178
191
  if self.shape != other.shape:
179
- raise ValueError("TimeSeries objects must have the same shape to subtract them")
192
+ raise ValueError(
193
+ "TimeSeries objects must have the same shape to subtract them"
194
+ )
180
195
  return TimeSeries(self.data - other.data, self.time)
181
196
 
182
- def __eq__(self, other: 'TimeSeries') -> bool:
197
+ def __eq__(self, other: "TimeSeries") -> bool:
183
198
  """Check if two TimeSeries objects are equal."""
184
199
  shape_same = self.shape == other.shape
185
200
  range_same = self.t0 == other.t0 and self.tend == other.tend
186
- time_same = xp.allclose(self.time, other.time)
201
+ time_same = xp.allclose(self.time, other.time)
187
202
  data_same = xp.allclose(self.data, other.data)
188
203
  return shape_same and range_same and data_same and time_same
189
204
 
190
- def __mul__(self, other: float) -> 'TimeSeries':
205
+ def __mul__(self, other: float) -> "TimeSeries":
191
206
  """Multiply a TimeSeries object by a scalar."""
192
207
  return TimeSeries(self.data * other, self.time)
193
208
 
194
- def zero_pad_to_power_of_2(self, tukey_window_alpha:float=0.0)->'TimeSeries':
209
+ def zero_pad_to_power_of_2(
210
+ self, tukey_window_alpha: float = 0.0
211
+ ) -> "TimeSeries":
195
212
  """Zero pad the time series to make the length a power of two (useful to speed up FFTs, O(NlogN) versus O(N^2)).
196
213
 
197
214
  Parameters
@@ -207,7 +224,7 @@ class TimeSeries:
207
224
  """
208
225
  N, dt, t0 = self.ND, self.dt, self.t0
209
226
  pow_2 = xp.ceil(xp.log2(N))
210
- n_pad = int((2 ** pow_2) - N)
227
+ n_pad = int((2**pow_2) - N)
211
228
  new_N = N + n_pad
212
229
  if n_pad > 0:
213
230
  logger.warning(
@@ -220,7 +237,12 @@ class TimeSeries:
220
237
  time = xp.arange(0, len(data) * dt, dt) + t0
221
238
  return TimeSeries(data, time)
222
239
 
223
- def highpass_filter(self, fmin: float, tukey_window_alpha:float=0.0, bandpass_order: int = 4) -> 'TimeSeries':
240
+ def highpass_filter(
241
+ self,
242
+ fmin: float,
243
+ tukey_window_alpha: float = 0.0,
244
+ bandpass_order: int = 4,
245
+ ) -> "TimeSeries":
224
246
  """
225
247
  Filter the time series with a highpass bandpass filter.
226
248
 
@@ -244,24 +266,25 @@ class TimeSeries:
244
266
  """
245
267
 
246
268
  if fmin <= 0 or fmin > self.nyquist_frequency:
247
- raise ValueError(f"Invalid fmin value: {fmin}. Must be in the range [0, {self.nyquist_frequency}]")
269
+ raise ValueError(
270
+ f"Invalid fmin value: {fmin}. Must be in the range [0, {self.nyquist_frequency}]"
271
+ )
248
272
 
249
- sos = butter(bandpass_order, Wn=fmin, btype="highpass", output='sos', fs=self.fs)
273
+ sos = butter(
274
+ bandpass_order, Wn=fmin, btype="highpass", output="sos", fs=self.fs
275
+ )
250
276
  window = tukey(self.ND, alpha=tukey_window_alpha)
251
277
  data = self.data.copy()
252
278
  data = sosfiltfilt(sos, data * window)
253
279
  return TimeSeries(data, self.time)
254
280
 
255
281
  def __copy__(self):
256
- return TimeSeries(
257
- self.data.copy(),
258
- self.time.copy()
259
- )
282
+ return TimeSeries(self.data.copy(), self.time.copy())
260
283
 
261
284
  def copy(self):
262
285
  return self.__copy__()
263
286
 
264
- def __getitem__(self, key)->"TimeSeries":
287
+ def __getitem__(self, key) -> "TimeSeries":
265
288
  if isinstance(key, slice):
266
289
  # Handle slicing
267
290
  return self.__handle_slice(key)
@@ -269,12 +292,9 @@ class TimeSeries:
269
292
  # Handle regular indexing
270
293
  return TimeSeries(self.data[key], self.time[key])
271
294
 
272
- def __handle_slice(self, slice_obj)->"TimeSeries":
273
- return TimeSeries(
274
- self.data[slice_obj],
275
- self.time[slice_obj]
276
- )
295
+ def __handle_slice(self, slice_obj) -> "TimeSeries":
296
+ return TimeSeries(self.data[slice_obj], self.time[slice_obj])
277
297
 
278
298
  @classmethod
279
- def _EMPTY(cls, ND:int, dt:float)->"TimeSeries":
280
- return cls(xp.zeros(ND), xp.arange(0, ND*dt, dt))
299
+ def _EMPTY(cls, ND: int, dt: float) -> "TimeSeries":
300
+ return cls(xp.zeros(ND), xp.arange(0, ND * dt, dt))
@@ -1,10 +1,11 @@
1
- import matplotlib.pyplot as plt
2
- from typing import Optional, Tuple
1
+ from typing import List, Optional, Tuple
3
2
 
3
+ import matplotlib.pyplot as plt
4
4
  import numpy as np
5
5
 
6
- from .common import is_documented_by, xp, fmt_timerange
6
+ from .common import fmt_timerange, is_documented_by, xp
7
7
  from .plotting import plot_wavelet_grid, plot_wavelet_trend
8
+ from .wavelet_bins import compute_bins
8
9
 
9
10
 
10
11
  class Wavelet:
@@ -22,10 +23,10 @@ class Wavelet:
22
23
  """
23
24
 
24
25
  def __init__(
25
- self,
26
- data: xp.ndarray,
27
- time: xp.ndarray,
28
- freq: xp.ndarray,
26
+ self,
27
+ data: xp.ndarray,
28
+ time: xp.ndarray,
29
+ freq: xp.ndarray,
29
30
  ):
30
31
  """
31
32
  Initialize the Wavelet object with data, time, and frequency arrays.
@@ -53,17 +54,60 @@ class Wavelet:
53
54
  self.time = time
54
55
  self.freq = freq
55
56
 
57
+ @classmethod
58
+ def zeros_from_grid(cls, time: xp.ndarray, freq: xp.ndarray) -> "Wavelet":
59
+ """
60
+ Create a Wavelet object filled with zeros.
61
+
62
+ Parameters
63
+ ----------
64
+ time: xp.ndarray
65
+ freq: xp.ndarray
66
+
67
+ Returns
68
+ -------
69
+ Wavelet
70
+ A Wavelet object with zero-filled data array.
71
+ """
72
+ Nf, Nt = len(freq), len(time)
73
+ return cls(data=xp.zeros((Nf, Nt)), time=time, freq=freq)
74
+
75
+ @classmethod
76
+ def zeros(cls, Nf: int, Nt: int, T: float) -> "Wavelet":
77
+ """
78
+ Create a Wavelet object filled with zeros.
79
+
80
+ Parameters
81
+ ----------
82
+ Nf : int
83
+ Number of frequency bins.
84
+ Nt : int
85
+ Number of time bins.
86
+
87
+ Returns
88
+ -------
89
+ Wavelet
90
+ A Wavelet object with zero-filled data array.
91
+ """
92
+ return cls.zeros_from_grid(*compute_bins(Nf, Nt, T))
93
+
56
94
  @is_documented_by(plot_wavelet_grid)
57
95
  def plot(self, ax=None, *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
58
96
  kwargs["time_grid"] = kwargs.get("time_grid", self.time)
59
97
  kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
60
- return plot_wavelet_grid(wavelet_data=self.data, ax=ax, *args, **kwargs)
98
+ return plot_wavelet_grid(
99
+ wavelet_data=self.data, ax=ax, *args, **kwargs
100
+ )
61
101
 
62
102
  @is_documented_by(plot_wavelet_trend)
63
- def plot_trend(self, ax=None, *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
103
+ def plot_trend(
104
+ self, ax=None, *args, **kwargs
105
+ ) -> Tuple[plt.Figure, plt.Axes]:
64
106
  kwargs["time_grid"] = kwargs.get("time_grid", self.time)
65
107
  kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
66
- return plot_wavelet_trend(wavelet_data=self.data, ax=ax, *args, **kwargs)
108
+ return plot_wavelet_trend(
109
+ wavelet_data=self.data, ax=ax, *args, **kwargs
110
+ )
67
111
 
68
112
  @property
69
113
  def Nt(self) -> int:
@@ -242,7 +286,8 @@ class Wavelet:
242
286
  TimeSeries
243
287
  A `TimeSeries` object representing the time-domain signal.
244
288
  """
245
- from ..inverse import from_wavelet_to_time
289
+ from ..transforms.inverse import from_wavelet_to_time
290
+
246
291
  return from_wavelet_to_time(self, dt=self.delta_t, nx=nx, mult=mult)
247
292
 
248
293
  def to_frequencyseries(self, nx: float = 4.0) -> "FrequencySeries":
@@ -254,7 +299,8 @@ class Wavelet:
254
299
  FrequencySeries
255
300
  A `FrequencySeries` object representing the frequency-domain signal.
256
301
  """
257
- from ..inverse import from_wavelet_to_freq
302
+ from ..transforms.inverse import from_wavelet_to_freq
303
+
258
304
  return from_wavelet_to_freq(self, dt=self.delta_t, nx=nx)
259
305
 
260
306
  def __repr__(self) -> str:
@@ -277,40 +323,57 @@ class Wavelet:
277
323
  def __add__(self, other):
278
324
  """Element-wise addition of two Wavelet objects."""
279
325
  if isinstance(other, Wavelet):
280
- return Wavelet(data=self.data + other.data, time=self.time, freq=self.freq)
326
+ return Wavelet(
327
+ data=self.data + other.data, time=self.time, freq=self.freq
328
+ )
281
329
  elif isinstance(other, float):
282
- return Wavelet(data=self.data + other, time=self.time, freq=self.freq)
330
+ return Wavelet(
331
+ data=self.data + other, time=self.time, freq=self.freq
332
+ )
283
333
 
284
334
  def __sub__(self, other):
285
335
  """Element-wise subtraction of two Wavelet objects."""
286
336
  if isinstance(other, Wavelet):
287
- return Wavelet(data=self.data - other.data, time=self.time, freq=self.freq)
337
+ return Wavelet(
338
+ data=self.data - other.data, time=self.time, freq=self.freq
339
+ )
288
340
  elif isinstance(other, float):
289
- return Wavelet(data=self.data - other, time=self.time, freq=self.freq)
341
+ return Wavelet(
342
+ data=self.data - other, time=self.time, freq=self.freq
343
+ )
290
344
 
291
345
  def __mul__(self, other):
292
346
  """Element-wise multiplication of two Wavelet objects."""
293
347
  if isinstance(other, Wavelet):
294
- return Wavelet(data=self.data * other.data, time=self.time, freq=self.freq)
348
+ return Wavelet(
349
+ data=self.data * other.data, time=self.time, freq=self.freq
350
+ )
295
351
  elif isinstance(other, float):
296
- return Wavelet(data=self.data * other, time=self.time, freq=self.freq)
352
+ return Wavelet(
353
+ data=self.data * other, time=self.time, freq=self.freq
354
+ )
297
355
 
298
356
  def __truediv__(self, other):
299
357
  """Element-wise division of two Wavelet objects."""
300
358
  if isinstance(other, Wavelet):
301
- return Wavelet(data=self.data / other.data, time=self.time, freq=self.freq)
359
+ return Wavelet(
360
+ data=self.data / other.data, time=self.time, freq=self.freq
361
+ )
302
362
  elif isinstance(other, float):
303
- return Wavelet(data=self.data / other, time=self.time, freq=self.freq)
363
+ return Wavelet(
364
+ data=self.data / other, time=self.time, freq=self.freq
365
+ )
304
366
 
305
- def __eq__(self, other:"Wavelet") -> bool:
367
+ def __eq__(self, other: "Wavelet") -> bool:
306
368
  """Element-wise comparison of two Wavelet objects."""
307
369
  data_all_same = xp.isclose(xp.nansum(self.data - other.data), 0)
308
370
  time_same = (self.time == other.time).all()
309
371
  freq_same = (self.freq == other.freq).all()
310
372
  return data_all_same and time_same and freq_same
311
373
 
312
-
313
- def noise_weighted_inner_product(self, other:"Wavelet", psd:"Wavelet") -> float:
374
+ def noise_weighted_inner_product(
375
+ self, other: "Wavelet", psd: "Wavelet"
376
+ ) -> float:
314
377
  """
315
378
  Compute the noise-weighted inner product of two wavelet grids given a PSD.
316
379
 
@@ -326,11 +389,11 @@ class Wavelet:
326
389
  float
327
390
  The noise-weighted inner product.
328
391
  """
329
- from ...utils import noise_weighted_inner_product
330
- return noise_weighted_inner_product(self, other, psd)
392
+ from ..utils import noise_weighted_inner_product
331
393
 
394
+ return noise_weighted_inner_product(self, other, psd)
332
395
 
333
- def matched_filter_snr(self, template:"Wavelet", psd:"Wavelet") -> float:
396
+ def matched_filter_snr(self, template: "Wavelet", psd: "Wavelet") -> float:
334
397
  """
335
398
  Compute the matched filter SNR of the wavelet grid given a template.
336
399
 
@@ -347,7 +410,7 @@ class Wavelet:
347
410
  mf = self.noise_weighted_inner_product(template, psd)
348
411
  return mf / self.optimal_snr(psd)
349
412
 
350
- def optimal_snr(self, psd:"Wavelet") -> float:
413
+ def optimal_snr(self, psd: "Wavelet") -> float:
351
414
  """
352
415
  Compute the optimal SNR of the wavelet grid given a PSD.
353
416
 
@@ -365,10 +428,27 @@ class Wavelet:
365
428
 
366
429
  def __copy__(self):
367
430
  return Wavelet(
368
- data=self.data.copy(),
369
- time=self.time.copy(),
370
- freq=self.freq.copy()
431
+ data=self.data.copy(), time=self.time.copy(), freq=self.freq.copy()
371
432
  )
372
433
 
373
434
  def copy(self):
374
- return self.__copy__()
435
+ return self.__copy__()
436
+
437
+
438
+ class WaveletMask(Wavelet):
439
+ @property
440
+ def mask(self):
441
+ return self.data
442
+
443
+ def __repr__(self):
444
+ return f"WaveletMask({self.mask.shape}, {fmt_timerange(self.time)}, {self.freq})"
445
+
446
+ @classmethod
447
+ def from_frange(
448
+ cls, time_grid: xp.ndarray, freq_grid: xp.ndarray, frange: List[float]
449
+ ):
450
+ self = cls.zeros_from_grid(time_grid, freq_grid)
451
+ self.mask[
452
+ (freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
453
+ ] = True
454
+ return self
@@ -2,7 +2,8 @@ from typing import Tuple, Union
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ..types import FrequencySeries, TimeSeries
5
+ from .frequencyseries import FrequencySeries
6
+ from .timeseries import TimeSeries
6
7
 
7
8
 
8
9
  def _preprocess_bins(
@@ -10,7 +11,6 @@ def _preprocess_bins(
10
11
  ) -> Tuple[int, int]:
11
12
  """preprocess the bins"""
12
13
 
13
-
14
14
  if isinstance(data, TimeSeries):
15
15
  N = len(data)
16
16
  elif isinstance(data, FrequencySeries):
@@ -34,7 +34,6 @@ def _get_bins(
34
34
  Nf: Union[int, None] = None,
35
35
  Nt: Union[int, None] = None,
36
36
  ) -> Tuple[np.ndarray, np.ndarray]:
37
-
38
37
  T = data.duration
39
38
  t_bins, f_bins = compute_bins(Nf, Nt, T)
40
39
 
@@ -42,12 +41,12 @@ def _get_bins(
42
41
  # fs = N / T
43
42
  # assert delta_f == fmax / Nf, f"delta_f={delta_f} != fmax/Nf={fmax/Nf}"
44
43
 
45
- t_bins+= data.t0
44
+ t_bins += data.t0
46
45
 
47
46
  return t_bins, f_bins
48
47
 
49
48
 
50
- def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
49
+ def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[np.ndarray, np.ndarray]:
51
50
  """Get the bins for the wavelet transform
52
51
  Eq 4-6 in Wavelets paper
53
52
  """
@@ -55,4 +54,4 @@ def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
55
54
  delta_F = 1 / (2 * delta_T)
56
55
  t_bins = np.arange(0, Nt) * delta_T
57
56
  f_bins = np.arange(0, Nf) * delta_F
58
- return t_bins, f_bins
57
+ return t_bins, f_bins
pywavelet/utils.py CHANGED
@@ -3,7 +3,7 @@ from typing import Union
3
3
  import numpy as np
4
4
  from scipy.interpolate import interp1d
5
5
 
6
- from .transforms.types import FrequencySeries, TimeSeries, Wavelet
6
+ from .types import FrequencySeries, TimeSeries, Wavelet, WaveletMask
7
7
 
8
8
  DATA_TYPE = Union[TimeSeries, FrequencySeries, Wavelet]
9
9
 
@@ -34,10 +34,13 @@ def evolutionary_psd_from_stationary_psd(
34
34
  return Wavelet(psd_grid.T, time=t_grid, freq=f_grid)
35
35
 
36
36
 
37
- def noise_weighted_inner_product(d: Wavelet, h: Wavelet, PSD: Wavelet) -> float:
37
+ def noise_weighted_inner_product(
38
+ d: Wavelet, h: Wavelet, PSD: Wavelet
39
+ ) -> float:
38
40
  return np.nansum((d.data * h.data) / PSD.data)
39
41
 
40
- def compute_snr(d:Wavelet, h: Wavelet, PSD: Wavelet) -> float:
42
+
43
+ def compute_snr(d: Wavelet, h: Wavelet, PSD: Wavelet) -> float:
41
44
  """Compute the SNR of a model h[ti,fi] given freqseries d[ti,fi] and PSD[ti,fi].
42
45
 
43
46
  SNR(h) = Sum_{ti,fi} [ h_hat[ti,fi] d[ti,fi] / PSD[ti,fi]
@@ -60,5 +63,14 @@ def compute_snr(d:Wavelet, h: Wavelet, PSD: Wavelet) -> float:
60
63
  return np.sqrt(noise_weighted_inner_product(d, h, PSD))
61
64
 
62
65
 
63
- def compute_likelihood(data:Wavelet, template:Wavelet, psd:Wavelet) -> float:
64
- return -0.5 * np.nansum((data.data - template.data) ** 2 / psd.data)
66
+ def compute_likelihood(
67
+ data: Wavelet, template: Wavelet, psd: Wavelet, mask: WaveletMask = None
68
+ ) -> float:
69
+ d = data.data
70
+ h = template.data
71
+ p = psd.data
72
+ if mask is not None:
73
+ m = mask.mask
74
+ d, h, p = d * m, h * m, p * m
75
+
76
+ return -0.5 * np.nansum((d - h) ** 2 / p)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.0.5
3
+ Version: 0.1.1
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -0,0 +1,25 @@
1
+ pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
+ pywavelet/_version.py,sha256=PKIMyjdUACH4-ONvtunQCnYE2UhlMfp9su83e3HXl5E,411
3
+ pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
4
+ pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
5
+ pywavelet/transforms/__init__.py,sha256=1Ibsup9UwMajeZ9NCQ4BN15qZTeJ_EHkgGu8XNFdA18,255
6
+ pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
7
+ pywavelet/transforms/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
8
+ pywavelet/transforms/forward/from_freq.py,sha256=wCiyLpzJE3rGxYjQBdXlwkxPIRYhQWjKq0C_8zYlmDk,2697
9
+ pywavelet/transforms/forward/from_time.py,sha256=-Y6VEKwDCYBAHAjLdO46vT-6alpM5fXTgTZ_xkYxqA8,2381
10
+ pywavelet/transforms/forward/main.py,sha256=Gfy0sp-woy_3ihKMzuk2WuZ7dRk-Mm6sp5dVpYrSvj4,4005
11
+ pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
12
+ pywavelet/transforms/inverse/main.py,sha256=Q7wUaRjB1sgqdB7dniWQGbPTWYQNnIsIrYtjsaHJEdE,3012
13
+ pywavelet/transforms/inverse/to_freq.py,sha256=so_TDbwdS1N8sd1QcpeAEkI10XFDtoFJGohtD4YulZM,2809
14
+ pywavelet/transforms/inverse/to_time.py,sha256=w5vmImdsb_4YeInZtXh0llsThLTxS0tmYDlNGJ-IUew,5080
15
+ pywavelet/types/__init__.py,sha256=5YptzQvYBnRfC8N5lpOBf9I1lzpJ0pw0QMnvIcwP3YI,122
16
+ pywavelet/types/common.py,sha256=OSAW6GqLTgqJ-RYEv__XbzsfFd8AFo5w-ctXQ4XAFZo,1317
17
+ pywavelet/types/frequencyseries.py,sha256=UqcE6UQfw5HZm4na2q9k-X-mfqO-BCiTAvGjaYpSrwc,7518
18
+ pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
19
+ pywavelet/types/timeseries.py,sha256=6DPO0xLi4Dq2srhJLmavFMf4fYIC3wwdbyMU7lMdjTo,9446
20
+ pywavelet/types/wavelet.py,sha256=ptTEnq6nRZiW2x6g_NV_FuoeVNOsbNOu6caSxYDZNgk,12583
21
+ pywavelet/types/wavelet_bins.py,sha256=SC9nhyigWvOfs2TbH8-Ck_iS1Mrz6sG-8EaIFYLIHuk,1453
22
+ pywavelet-0.1.1.dist-info/METADATA,sha256=OgKTqqZQKfhCCcsYpcXnrut-2EAH42c3tCpCGc0PIDg,1307
23
+ pywavelet-0.1.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
24
+ pywavelet-0.1.1.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
25
+ pywavelet-0.1.1.dist-info/RECORD,,
@@ -1,25 +0,0 @@
1
- pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
- pywavelet/_version.py,sha256=EJB7__SNK9kQS_SWZB_U4DHJ3P8ftF6etZEihTYnuXE,411
3
- pywavelet/logger.py,sha256=u5Et6QLUU0IuTnPnxfev5-4GYmihTo7_yBCTf-rVUsA,377
4
- pywavelet/utils.py,sha256=rbmZ-PrnFstq05kh6xj8emuQ3-7oXqZinU03bpg3N7A,1746
5
- pywavelet/transforms/__init__.py,sha256=FolK8WiVEJmGDC9xMupYVI_essXaXS4LYWKbqEqGx6o,289
6
- pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
7
- pywavelet/transforms/forward/__init__.py,sha256=Yq4Tg3Ze98-17C9FIkOqMUdiLHe9x_YoyuRvxOxMOP0,176
8
- pywavelet/transforms/forward/from_freq.py,sha256=HWADciv746P5zWNhnyRKDEDACTQnwDTR4usf2k0VIWk,2663
9
- pywavelet/transforms/forward/from_time.py,sha256=zfpvmq7UHnynzAnsM_Pf4MXpO1GJ0TaCCERQKNsaBNU,2340
10
- pywavelet/transforms/forward/main.py,sha256=FJSLUUsk3I91lUeB2YFiXQdX8gB9maFW60e4tgrd45A,4053
11
- pywavelet/transforms/forward/wavelet_bins.py,sha256=43CHVSjNgC59Af1_h9t33VBBqGHJmIUaOUhJgekhMnc,1419
12
- pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
13
- pywavelet/transforms/inverse/main.py,sha256=l5yFvzmWObrO5Xt_8KYp62w829ab0pQLxcv0-QxkvG0,3015
14
- pywavelet/transforms/inverse/to_freq.py,sha256=SExZMax-8A-tJpIA86pYY61X2qvlZ2MrZY27uzCQSV0,2778
15
- pywavelet/transforms/inverse/to_time.py,sha256=BAYvrr41QHIbzwYMPyMnzv5mqSx40YigmBruWBwtZwc,5041
16
- pywavelet/transforms/types/__init__.py,sha256=4wUTVBk6A02xjZUY_w056eUZurYI9vVfa--I3Q6Udng,109
17
- pywavelet/transforms/types/common.py,sha256=sLBn2d9cuL6NYeIv6NIogDjY6rYPZAPzCtWGwMlAwkI,1290
18
- pywavelet/transforms/types/frequencyseries.py,sha256=Rtwt486UL0-TgMAdcMMVpyfi5PSLzxFcdE_RsYuyxQk,7463
19
- pywavelet/transforms/types/plotting.py,sha256=aEIFoSuQRYwYc2639yLkbujXx_aav5A5tXG29rOSMOQ,10275
20
- pywavelet/transforms/types/timeseries.py,sha256=Nl1tiKZ7kwu-EZ5JtubwgzgyjQZR86eGU3C3kElIPNg,9296
21
- pywavelet/transforms/types/wavelet.py,sha256=raODhyegBd1_esuv2YP6tK_Nj9fYaLQ6O4pxpiFAVZU,10790
22
- pywavelet-0.0.5.dist-info/METADATA,sha256=BaLK8jbUqss_8BeGWOlf_031PoCAlEV4rmuSaitQd0E,1307
23
- pywavelet-0.0.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
24
- pywavelet-0.0.5.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
25
- pywavelet-0.0.5.dist-info/RECORD,,