pywavelet 0.0.5__py3-none-any.whl → 0.1.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
pywavelet/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.0.5'
16
- __version_tuple__ = version_tuple = (0, 0, 5)
15
+ __version__ = version = '0.1.1'
16
+ __version_tuple__ = version_tuple = (0, 1, 1)
pywavelet/logger.py CHANGED
@@ -1,11 +1,15 @@
1
+ import logging
1
2
  import sys
2
3
  import warnings
4
+
3
5
  from rich.logging import RichHandler
4
- import logging
5
6
 
6
7
  FORMAT = "%(message)s"
7
8
  logging.basicConfig(
8
- level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
9
+ level="INFO",
10
+ format=FORMAT,
11
+ datefmt="[%X]",
12
+ handlers=[RichHandler(rich_tracebacks=True)],
9
13
  )
10
14
 
11
15
  logger = logging.getLogger("pywavelet")
@@ -1,4 +1,4 @@
1
- from .forward import from_freq_to_wavelet, from_time_to_wavelet, compute_bins
1
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
2
2
  from .inverse import from_wavelet_to_freq, from_wavelet_to_time
3
3
 
4
4
  __all__ = [
@@ -6,5 +6,4 @@ __all__ = [
6
6
  "from_wavelet_to_freq",
7
7
  "from_time_to_wavelet",
8
8
  "from_freq_to_wavelet",
9
- "compute_bins",
10
9
  ]
@@ -1,4 +1,3 @@
1
1
  from .main import from_freq_to_wavelet, from_time_to_wavelet
2
- from .wavelet_bins import compute_bins
3
2
 
4
- __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet", "compute_bins"]
3
+ __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -1,8 +1,10 @@
1
1
  """helper functions for transform_freq"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
5
6
 
7
+
6
8
  def transform_wavelet_freq_helper(
7
9
  data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
8
10
  ) -> np.ndarray:
@@ -13,8 +15,16 @@ def transform_wavelet_freq_helper(
13
15
  __core(Nf, Nt, DX, freq_strain, phif, wave)
14
16
  return wave
15
17
 
18
+
16
19
  # @njit()
17
- def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarray, wave:np.ndarray):
20
+ def __core(
21
+ Nf: int,
22
+ Nt: int,
23
+ DX: np.ndarray,
24
+ freq_strain: np.ndarray,
25
+ phif: np.ndarray,
26
+ wave: np.ndarray,
27
+ ):
18
28
  for f_bin in range(0, Nf + 1):
19
29
  __fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
20
30
  # Numba doesn't support np.ifft so we cant jit this
@@ -22,8 +32,6 @@ def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarra
22
32
  __fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
23
33
 
24
34
 
25
-
26
-
27
35
  @njit()
28
36
  def __fill_wave_1(
29
37
  f_bin: int,
@@ -55,6 +63,7 @@ def __fill_wave_1(
55
63
  else:
56
64
  DX[i] = phif[j] * data[jj]
57
65
 
66
+
58
67
  @njit()
59
68
  def __fill_wave_2(
60
69
  f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
@@ -72,7 +81,7 @@ def __fill_wave_2(
72
81
  if (n + f_bin) % 2:
73
82
  wave[n, f_bin] = -DX_trans[n].imag
74
83
  else:
75
- wave[n, f_bin] = DX_trans[n].real
84
+ wave[n, f_bin] = DX_trans[n].real
76
85
  else:
77
86
  if (n + f_bin) % 2:
78
87
  wave[n, f_bin] = DX_trans[n].imag
@@ -1,4 +1,5 @@
1
1
  """helper functions for transform_time.py"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -20,7 +21,18 @@ def transform_wavelet_time_helper(
20
21
  __core(Nf, Nt, K, ND, wdata, data_pad, phi, wave, mult)
21
22
  return wave
22
23
 
23
- def __core(Nf: int, Nt: int, K: int, ND: int, wdata: np.ndarray, data_pad: np.ndarray, phi: np.ndarray, wave: np.ndarray, mult: int) -> None:
24
+
25
+ def __core(
26
+ Nf: int,
27
+ Nt: int,
28
+ K: int,
29
+ ND: int,
30
+ wdata: np.ndarray,
31
+ data_pad: np.ndarray,
32
+ phi: np.ndarray,
33
+ wave: np.ndarray,
34
+ mult: int,
35
+ ) -> None:
24
36
  for time_bin_i in range(0, Nt):
25
37
  __fill_wave_1(time_bin_i, K, ND, Nf, wdata, data_pad, phi)
26
38
  wdata_trans = np.fft.rfft(wdata, K)
@@ -3,15 +3,15 @@ from typing import Union
3
3
  import numpy as np
4
4
 
5
5
  from ...logger import logger
6
+ from ...types import FrequencySeries, TimeSeries, Wavelet
7
+ from ...types.wavelet_bins import _get_bins, _preprocess_bins
6
8
  from ..phi_computer import phi_vec, phitilde_vec_norm
7
- from ..types import FrequencySeries, TimeSeries, Wavelet
8
9
  from .from_freq import transform_wavelet_freq_helper
9
10
  from .from_time import transform_wavelet_time_helper
10
- from .wavelet_bins import _get_bins, _preprocess_bins
11
-
12
11
 
13
12
  __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
14
13
 
14
+
15
15
  def from_time_to_wavelet(
16
16
  timeseries: TimeSeries,
17
17
  Nf: Union[int, None] = None,
@@ -74,9 +74,7 @@ def from_time_to_wavelet(
74
74
  mult = min(mult, Nt // 2) # Ensure mult is not larger than ND/2
75
75
  phi = phi_vec(Nf, dt=dt, d=nx, q=mult)
76
76
  wave = transform_wavelet_time_helper(timeseries.data, Nf, Nt, phi, mult).T
77
- return Wavelet(
78
- wave * np.sqrt(2), time=t_bins, freq=f_bins
79
- )
77
+ return Wavelet(wave * np.sqrt(2), time=t_bins, freq=f_bins)
80
78
 
81
79
 
82
80
  def from_freq_to_wavelet(
@@ -117,12 +115,6 @@ def from_freq_to_wavelet(
117
115
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
118
116
  dt = freqseries.dt
119
117
  phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)
120
- wave = transform_wavelet_freq_helper(
121
- freqseries.data, Nf, Nt, phif
122
- )
123
-
124
- return Wavelet(
125
- (2 / Nf) * wave.T * np.sqrt(2),
126
- time=t_bins,
127
- freq=f_bins
128
- )
118
+ wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
119
+
120
+ return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
 
3
3
  from ...transforms.phi_computer import phi_vec, phitilde_vec_norm
4
- from ..types import FrequencySeries, TimeSeries, Wavelet
4
+ from ...types import FrequencySeries, TimeSeries, Wavelet
5
5
  from .to_freq import inverse_wavelet_freq_helper_fast
6
6
  from .to_time import inverse_wavelet_time_helper_fast
7
7
 
@@ -9,6 +9,7 @@ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
9
9
 
10
10
  INV_ROOT2 = 1.0 / np.sqrt(2)
11
11
 
12
+
12
13
  def from_wavelet_to_time(
13
14
  wave_in: Wavelet,
14
15
  dt: float,
@@ -55,9 +56,7 @@ def from_wavelet_to_time(
55
56
 
56
57
 
57
58
  def from_wavelet_to_freq(
58
- wave_in: Wavelet,
59
- dt: float,
60
- nx:float=4.0
59
+ wave_in: Wavelet, dt: float, nx: float = 4.0
61
60
  ) -> FrequencySeries:
62
61
  """
63
62
  Perform an inverse wavelet transform to the frequency domain.
@@ -1,4 +1,5 @@
1
1
  """functions for computing the inverse wavelet transforms"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -12,13 +13,20 @@ def inverse_wavelet_freq_helper_fast(
12
13
  ND = Nf * Nt
13
14
 
14
15
  prefactor2s = np.zeros(Nt, np.complex128)
15
- res = np.zeros(ND//2 +1, dtype=np.complex128)
16
+ res = np.zeros(ND // 2 + 1, dtype=np.complex128)
16
17
  __core(Nf, Nt, prefactor2s, wave_in, phif, res)
17
18
 
18
-
19
19
  return res
20
20
 
21
- def __core(Nf: int, Nt: int, prefactor2s: np.ndarray, wave_in: np.ndarray, phif: np.ndarray, res: np.ndarray) -> None:
21
+
22
+ def __core(
23
+ Nf: int,
24
+ Nt: int,
25
+ prefactor2s: np.ndarray,
26
+ wave_in: np.ndarray,
27
+ phif: np.ndarray,
28
+ res: np.ndarray,
29
+ ) -> None:
22
30
  for m in range(0, Nf + 1):
23
31
  __pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
24
32
  fft_prefactor2s = np.fft.fft(prefactor2s)
@@ -1,4 +1,5 @@
1
1
  """functions for computing the inverse wavelet transforms"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -21,7 +22,16 @@ def inverse_wavelet_time_helper_fast(
21
22
  return res[:ND]
22
23
 
23
24
 
24
- def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarray, res: np.ndarray, afins:np.ndarray) -> None:
25
+ def __core(
26
+ Nf: int,
27
+ Nt: int,
28
+ K: int,
29
+ ND: int,
30
+ wave_in: np.ndarray,
31
+ phi: np.ndarray,
32
+ res: np.ndarray,
33
+ afins: np.ndarray,
34
+ ) -> None:
25
35
  for n in range(0, Nt):
26
36
  if n % 2 == 0:
27
37
  pack_wave_time_helper_compact(n, Nf, Nt, wave_in, afins)
@@ -29,9 +39,9 @@ def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarr
29
39
  unpack_time_wave_helper_compact(n, Nf, Nt, K, phi, ffts_fin, res)
30
40
 
31
41
  # wrap boundary conditions
32
- res[: min(K + Nf, ND)] += res[ND: min(ND + K + Nf, 2 * ND)]
42
+ res[: min(K + Nf, ND)] += res[ND : min(ND + K + Nf, 2 * ND)]
33
43
  if K + Nf > ND:
34
- res[: K + Nf - ND] += res[2 * ND: ND + K * Nf]
44
+ res[: K + Nf - ND] += res[2 * ND : ND + K * Nf]
35
45
 
36
46
 
37
47
  def unpack_time_wave_helper(
@@ -1,3 +1,3 @@
1
1
  from .frequencyseries import FrequencySeries
2
2
  from .timeseries import TimeSeries
3
- from .wavelet import Wavelet
3
+ from .wavelet import Wavelet, WaveletMask
@@ -1,9 +1,9 @@
1
- from typing import Literal, Tuple
1
+ from typing import Tuple, Union
2
2
 
3
3
  import numpy as xp
4
- from numpy.fft import irfft, fft, rfft, rfftfreq
4
+ from numpy.fft import fft, irfft, rfft, rfftfreq # type: ignore
5
5
 
6
- from ...logger import logger
6
+ from ..logger import logger
7
7
 
8
8
 
9
9
  def _len_check(d):
@@ -19,7 +19,7 @@ def is_documented_by(original):
19
19
  return wrapper
20
20
 
21
21
 
22
- def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
22
+ def fmt_time(seconds: float, units=False) -> Union[str, Tuple[str, str]]:
23
23
  """Returns formatted time and units [ms, s, min, hr, day]"""
24
24
  t, u = "", ""
25
25
  if seconds < 1e-3:
@@ -42,12 +42,12 @@ def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
42
42
 
43
43
  def fmt_timerange(trange):
44
44
  t0 = fmt_time(trange[0])
45
- tend, units = fmt_time(trange[1], units = True)
45
+ tend, units = fmt_time(trange[1], units=True)
46
46
  return f"[{t0}, {tend}] {units}"
47
47
 
48
48
 
49
- def fmt_pow2(n:float)->str:
49
+ def fmt_pow2(n: float) -> str:
50
50
  pow2 = xp.log2(n)
51
51
  if pow2.is_integer():
52
52
  return f"2^{int(pow2)}"
53
- return f"{n:,}"
53
+ return f"{n:,}"
@@ -1,11 +1,13 @@
1
+ from typing import Optional, Tuple, Union
2
+
1
3
  import matplotlib.pyplot as plt
2
- from typing import Tuple, Union, Optional
3
4
 
4
- from .common import is_documented_by, xp, irfft, fmt_time, fmt_pow2
5
+ from .common import fmt_pow2, fmt_time, irfft, is_documented_by, xp
5
6
  from .plotting import plot_freqseries, plot_periodogram
6
7
 
7
8
  __all__ = ["FrequencySeries"]
8
9
 
10
+
9
11
  class FrequencySeries:
10
12
  """
11
13
  A class to represent a one-sided frequency series, with various methods for
@@ -39,9 +41,13 @@ class FrequencySeries:
39
41
  If any frequency is negative or if `data` and `freq` do not have the same length.
40
42
  """
41
43
  if xp.any(freq < 0):
42
- raise ValueError("FrequencySeries must be one-sided (only non-negative frequencies)")
44
+ raise ValueError(
45
+ "FrequencySeries must be one-sided (only non-negative frequencies)"
46
+ )
43
47
  if len(data) != len(freq):
44
- raise ValueError(f"data and freq must have the same length ({len(data)} != {len(freq)})")
48
+ raise ValueError(
49
+ f"data and freq must have the same length ({len(data)} != {len(freq)})"
50
+ )
45
51
  self.data = data
46
52
  self.freq = freq
47
53
  self.t0 = t0
@@ -53,10 +59,10 @@ class FrequencySeries:
53
59
  )
54
60
 
55
61
  @is_documented_by(plot_periodogram)
56
- def plot_periodogram(self, ax=None, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
57
- return plot_periodogram(
58
- self.data, self.freq, self.fs, ax=ax, **kwargs
59
- )
62
+ def plot_periodogram(
63
+ self, ax=None, **kwargs
64
+ ) -> Tuple[plt.Figure, plt.Axes]:
65
+ return plot_periodogram(self.data, self.freq, self.fs, ax=ax, **kwargs)
60
66
 
61
67
  def __len__(self):
62
68
  """Return the length of the frequency series."""
@@ -127,7 +133,9 @@ class FrequencySeries:
127
133
  n = fmt_pow2(len(self))
128
134
  return f"FrequencySeries(n={n}, frange=[{self.range[0]:.2f}, {self.range[1]:.2f}] Hz, T={dur}, fs={self.fs:.2f} Hz)"
129
135
 
130
- def noise_weighted_inner_product(self, other: "FrequencySeries", psd:"FrequencySeries") -> float:
136
+ def noise_weighted_inner_product(
137
+ self, other: "FrequencySeries", psd: "FrequencySeries"
138
+ ) -> float:
131
139
  """
132
140
  Compute the noise-weighted inner product of two FrequencySeries.
133
141
 
@@ -144,9 +152,11 @@ class FrequencySeries:
144
152
  The noise-weighted inner product of the two FrequencySeries.
145
153
  """
146
154
  integrand = xp.real(xp.conj(self.data) * other.data / psd.data)
147
- return (4 * self.dt/self.ND) * xp.nansum(integrand)
155
+ return (4 * self.dt / self.ND) * xp.nansum(integrand)
148
156
 
149
- def matched_filter_snr(self, other: "FrequencySeries", psd: "FrequencySeries") -> float:
157
+ def matched_filter_snr(
158
+ self, other: "FrequencySeries", psd: "FrequencySeries"
159
+ ) -> float:
150
160
  """
151
161
  Compute the signal-to-noise ratio (SNR) of a matched filter.
152
162
 
@@ -199,15 +209,15 @@ class FrequencySeries:
199
209
 
200
210
  # Create and return a TimeSeries object
201
211
  from .timeseries import TimeSeries
202
- return TimeSeries(time_data, time)
203
212
 
213
+ return TimeSeries(time_data, time)
204
214
 
205
215
  def to_wavelet(
206
- self,
207
- Nf: Union[int, None] = None,
208
- Nt: Union[int, None] = None,
209
- nx: Optional[float] = 4.0,
210
- )->"Wavelet":
216
+ self,
217
+ Nf: Union[int, None] = None,
218
+ Nt: Union[int, None] = None,
219
+ nx: Optional[float] = 4.0,
220
+ ) -> "Wavelet":
211
221
  """
212
222
  Convert the frequency series to a wavelet using inverse Fourier transform.
213
223
 
@@ -216,9 +226,9 @@ class FrequencySeries:
216
226
  Wavelet
217
227
  The corresponding wavelet.
218
228
  """
219
- from ..forward import from_freq_to_wavelet
220
- return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
229
+ from ..transforms.forward import from_freq_to_wavelet
221
230
 
231
+ return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
222
232
 
223
233
  def __eq__(self, other):
224
234
  """Check if two FrequencySeries objects are equal."""
@@ -228,10 +238,8 @@ class FrequencySeries:
228
238
 
229
239
  def __copy__(self):
230
240
  return FrequencySeries(
231
- xp.copy(self.data),
232
- xp.copy(self.freq),
233
- t0=self.t0
241
+ xp.copy(self.data), xp.copy(self.freq), t0=self.t0
234
242
  )
235
243
 
236
244
  def copy(self):
237
- return self.__copy__()
245
+ return self.__copy__()
@@ -1,12 +1,11 @@
1
1
  import warnings
2
- from typing import Tuple, Optional, Union
3
- from scipy.signal import savgol_filter
4
- from scipy.interpolate import interp1d
2
+ from typing import Optional, Tuple, Union
5
3
 
6
4
  import matplotlib.pyplot as plt
7
5
  import numpy as np
8
6
  from matplotlib.colors import LogNorm, TwoSlopeNorm
9
- from scipy.signal import spectrogram
7
+ from scipy.interpolate import interp1d
8
+ from scipy.signal import savgol_filter, spectrogram
10
9
 
11
10
  MIN_S = 60
12
11
  HOUR_S = 60 * MIN_S
@@ -53,8 +52,13 @@ def __get_smoothed_y(x, z, y_grid):
53
52
  # Interpolate to fill NaNs in y before smoothing
54
53
  nan_mask = ~np.isnan(y)
55
54
  if np.isnan(y).any():
56
- interpolator = interp1d(x[nan_mask], y[nan_mask], kind='cubic', bounds_error=False,
57
- fill_value="extrapolate")
55
+ interpolator = interp1d(
56
+ x[nan_mask],
57
+ y[nan_mask],
58
+ kind="cubic",
59
+ bounds_error=False,
60
+ fill_value="extrapolate",
61
+ )
58
62
  y = interpolator(x) # Fill NaNs with interpolated values
59
63
 
60
64
  # Smooth the curve
@@ -64,8 +68,6 @@ def __get_smoothed_y(x, z, y_grid):
64
68
  return y
65
69
 
66
70
 
67
-
68
-
69
71
  def plot_wavelet_grid(
70
72
  wavelet_data: np.ndarray,
71
73
  time_grid: np.ndarray,
@@ -80,8 +82,8 @@ def plot_wavelet_grid(
80
82
  norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
81
83
  cbar_label: Optional[str] = None,
82
84
  nan_color: Optional[str] = "black",
83
- detailed_axes:bool = False,
84
- show_gridinfo:bool = True,
85
+ detailed_axes: bool = False,
86
+ show_gridinfo: bool = True,
85
87
  trend_color: Optional[str] = None,
86
88
  whiten_by: Optional[np.ndarray] = None,
87
89
  **kwargs,
@@ -153,7 +155,9 @@ def plot_wavelet_grid(
153
155
 
154
156
  # Validate the dimensions
155
157
  if (Nf, Nt) != (len(freq_grid), len(time_grid)):
156
- raise ValueError(f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}.")
158
+ raise ValueError(
159
+ f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}."
160
+ )
157
161
 
158
162
  # Prepare the data for plotting
159
163
  z = wavelet_data.copy()
@@ -162,14 +166,15 @@ def plot_wavelet_grid(
162
166
  if absolute:
163
167
  z = np.abs(z)
164
168
 
165
-
166
169
  # Determine normalization and colormap
167
170
  if norm is None:
168
171
  try:
169
172
  if np.all(np.isnan(z)):
170
173
  raise ValueError("All wavelet data is NaN.")
171
174
  if zscale == "log":
172
- norm = LogNorm(vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z<np.inf]))
175
+ norm = LogNorm(
176
+ vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z < np.inf])
177
+ )
173
178
  elif not absolute:
174
179
  vmin, vmax = np.nanmin(z), np.nanmax(z)
175
180
  vcenter = 0.0
@@ -177,7 +182,9 @@ def plot_wavelet_grid(
177
182
  else:
178
183
  norm = None # Default linear scaling
179
184
  except Exception as e:
180
- warnings.warn(f"Error in determining normalization: {e}. Using default linear scaling.")
185
+ warnings.warn(
186
+ f"Error in determining normalization: {e}. Using default linear scaling."
187
+ )
181
188
  norm = None
182
189
 
183
190
  if cmap is None:
@@ -195,7 +202,7 @@ def plot_wavelet_grid(
195
202
  im = ax.imshow(
196
203
  z,
197
204
  aspect="auto",
198
- extent=[time_grid[0],time_grid[-1], freq_grid[0], freq_grid[-1]],
205
+ extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
199
206
  origin="lower",
200
207
  cmap=cmap,
201
208
  norm=norm,
@@ -203,13 +210,25 @@ def plot_wavelet_grid(
203
210
  **kwargs,
204
211
  )
205
212
  if trend_color is not None:
206
- plot_wavelet_trend(wavelet_data, time_grid, freq_grid, ax, color=trend_color, freq_range=freq_range, freq_scale=freq_scale)
213
+ plot_wavelet_trend(
214
+ wavelet_data,
215
+ time_grid,
216
+ freq_grid,
217
+ ax,
218
+ color=trend_color,
219
+ freq_range=freq_range,
220
+ freq_scale=freq_scale,
221
+ )
207
222
 
208
223
  # Add colorbar if requested
209
224
  if show_colorbar:
210
225
  cbar = fig.colorbar(im, ax=ax)
211
226
  if cbar_label is None:
212
- cbar_label = "Absolute Wavelet Amplitude" if absolute else "Wavelet Amplitude"
227
+ cbar_label = (
228
+ "Absolute Wavelet Amplitude"
229
+ if absolute
230
+ else "Wavelet Amplitude"
231
+ )
213
232
  cbar.set_label(cbar_label)
214
233
 
215
234
  # Configure axes scales
@@ -239,14 +258,12 @@ def plot_wavelet_grid(
239
258
  bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
240
259
  )
241
260
 
242
-
243
261
  # Adjust layout
244
262
  fig.tight_layout()
245
263
 
246
264
  return fig, ax
247
265
 
248
266
 
249
-
250
267
  def plot_freqseries(
251
268
  data: np.ndarray,
252
269
  freq: np.ndarray,
@@ -277,9 +294,10 @@ def plot_periodogram(
277
294
  flow = np.min(np.abs(freq))
278
295
  ax.set_xlabel("Frequency [Hz]")
279
296
  ax.set_ylabel("Periodigram")
280
- ax.set_xlim(left=flow, right=nyquist_frequency/2)
297
+ ax.set_xlim(left=flow, right=nyquist_frequency / 2)
281
298
  return ax.figure, ax
282
299
 
300
+
283
301
  def plot_timeseries(
284
302
  data: np.ndarray, time: np.ndarray, ax=None, **kwargs
285
303
  ) -> Tuple[plt.Figure, plt.Axes]:
@@ -314,7 +332,6 @@ def plot_spectrogram(
314
332
 
315
333
  _fmt_time_axis(t, ax)
316
334
 
317
-
318
335
  ax.set_ylabel("Frequency [Hz]")
319
336
  ax.set_ylim(top=fs / 2.0)
320
337
  cbar = plt.colorbar(cm, ax=ax)
@@ -322,20 +339,24 @@ def plot_spectrogram(
322
339
  return ax.figure, ax
323
340
 
324
341
 
325
-
326
342
  def _fmt_time_axis(t, axes, t0=None, tmax=None):
327
343
  if t[-1] > DAY_S: # If time goes beyond a day
328
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}"))
344
+ axes.xaxis.set_major_formatter(
345
+ plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}")
346
+ )
329
347
  axes.set_xlabel("Time [days]")
330
348
  elif t[-1] > HOUR_S: # If time goes beyond an hour
331
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}"))
349
+ axes.xaxis.set_major_formatter(
350
+ plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}")
351
+ )
332
352
  axes.set_xlabel("Time [hr]")
333
353
  elif t[-1] > MIN_S: # If time goes beyond a minute
334
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}"))
354
+ axes.xaxis.set_major_formatter(
355
+ plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}")
356
+ )
335
357
  axes.set_xlabel("Time [min]")
336
358
  else:
337
359
  axes.set_xlabel("Time [s]")
338
360
  t0 = t[0] if t0 is None else t0
339
361
  tmax = t[-1] if tmax is None else tmax
340
362
  axes.set_xlim(t0, tmax)
341
-
@@ -1,11 +1,20 @@
1
+ from typing import Optional, Tuple, Union
2
+
1
3
  import matplotlib.pyplot as plt
2
- from typing import Tuple, Optional, Union
3
- from scipy.signal.windows import tukey
4
4
  from scipy.signal import butter, sosfiltfilt
5
+ from scipy.signal.windows import tukey
5
6
 
6
- from ...logger import logger
7
- from .common import is_documented_by, xp, rfft, rfftfreq, fmt_timerange, fmt_time, fmt_pow2
8
- from .plotting import plot_timeseries, plot_spectrogram
7
+ from ..logger import logger
8
+ from .common import (
9
+ fmt_pow2,
10
+ fmt_time,
11
+ fmt_timerange,
12
+ is_documented_by,
13
+ rfft,
14
+ rfftfreq,
15
+ xp,
16
+ )
17
+ from .plotting import plot_spectrogram, plot_timeseries
9
18
 
10
19
  __all__ = ["TimeSeries"]
11
20
 
@@ -50,7 +59,7 @@ class TimeSeries:
50
59
 
51
60
  @is_documented_by(plot_spectrogram)
52
61
  def plot_spectrogram(
53
- self, ax=None, spec_kwargs={}, plot_kwargs={}
62
+ self, ax=None, spec_kwargs={}, plot_kwargs={}
54
63
  ) -> Tuple[plt.Figure, plt.Axes]:
55
64
  return plot_spectrogram(
56
65
  self.data,
@@ -122,9 +131,11 @@ class TimeSeries:
122
131
  trange = fmt_timerange((self.t0, self.tend))
123
132
  T = " ".join(fmt_time(self.duration, units=True))
124
133
  n = fmt_pow2(len(self))
125
- return f"TimeSeries(n={n}, trange={trange}, T={T}, fs={self.fs:.2f} Hz)"
134
+ return (
135
+ f"TimeSeries(n={n}, trange={trange}, T={T}, fs={self.fs:.2f} Hz)"
136
+ )
126
137
 
127
- def to_frequencyseries(self) -> 'FrequencySeries':
138
+ def to_frequencyseries(self) -> "FrequencySeries":
128
139
  """
129
140
  Convert the time series to a frequency series using the one-sided FFT.
130
141
 
@@ -137,14 +148,15 @@ class TimeSeries:
137
148
  data = rfft(self.data)
138
149
 
139
150
  from .frequencyseries import FrequencySeries # Avoid circular import
151
+
140
152
  return FrequencySeries(data, freq, t0=self.t0)
141
153
 
142
154
  def to_wavelet(
143
- self,
144
- Nf: Union[int, None] = None,
145
- Nt: Union[int, None] = None,
146
- nx: Optional[float] = 4.0,
147
- ) -> 'Wavelet':
155
+ self,
156
+ Nf: Union[int, None] = None,
157
+ Nt: Union[int, None] = None,
158
+ nx: Optional[float] = 4.0,
159
+ ) -> "Wavelet":
148
160
  """
149
161
  Convert the time series to a wavelet representation.
150
162
 
@@ -166,32 +178,37 @@ class TimeSeries:
166
178
  hf = self.to_frequencyseries()
167
179
  return hf.to_wavelet(Nf, Nt, nx=nx)
168
180
 
169
-
170
- def __add__(self, other: 'TimeSeries') -> 'TimeSeries':
181
+ def __add__(self, other: "TimeSeries") -> "TimeSeries":
171
182
  """Add two TimeSeries objects together."""
172
183
  if self.shape != other.shape:
173
- raise ValueError("TimeSeries objects must have the same shape to add them together")
184
+ raise ValueError(
185
+ "TimeSeries objects must have the same shape to add them together"
186
+ )
174
187
  return TimeSeries(self.data + other.data, self.time)
175
188
 
176
- def __sub__(self, other: 'TimeSeries') -> 'TimeSeries':
189
+ def __sub__(self, other: "TimeSeries") -> "TimeSeries":
177
190
  """Subtract one TimeSeries object from another."""
178
191
  if self.shape != other.shape:
179
- raise ValueError("TimeSeries objects must have the same shape to subtract them")
192
+ raise ValueError(
193
+ "TimeSeries objects must have the same shape to subtract them"
194
+ )
180
195
  return TimeSeries(self.data - other.data, self.time)
181
196
 
182
- def __eq__(self, other: 'TimeSeries') -> bool:
197
+ def __eq__(self, other: "TimeSeries") -> bool:
183
198
  """Check if two TimeSeries objects are equal."""
184
199
  shape_same = self.shape == other.shape
185
200
  range_same = self.t0 == other.t0 and self.tend == other.tend
186
- time_same = xp.allclose(self.time, other.time)
201
+ time_same = xp.allclose(self.time, other.time)
187
202
  data_same = xp.allclose(self.data, other.data)
188
203
  return shape_same and range_same and data_same and time_same
189
204
 
190
- def __mul__(self, other: float) -> 'TimeSeries':
205
+ def __mul__(self, other: float) -> "TimeSeries":
191
206
  """Multiply a TimeSeries object by a scalar."""
192
207
  return TimeSeries(self.data * other, self.time)
193
208
 
194
- def zero_pad_to_power_of_2(self, tukey_window_alpha:float=0.0)->'TimeSeries':
209
+ def zero_pad_to_power_of_2(
210
+ self, tukey_window_alpha: float = 0.0
211
+ ) -> "TimeSeries":
195
212
  """Zero pad the time series to make the length a power of two (useful to speed up FFTs, O(NlogN) versus O(N^2)).
196
213
 
197
214
  Parameters
@@ -207,7 +224,7 @@ class TimeSeries:
207
224
  """
208
225
  N, dt, t0 = self.ND, self.dt, self.t0
209
226
  pow_2 = xp.ceil(xp.log2(N))
210
- n_pad = int((2 ** pow_2) - N)
227
+ n_pad = int((2**pow_2) - N)
211
228
  new_N = N + n_pad
212
229
  if n_pad > 0:
213
230
  logger.warning(
@@ -220,7 +237,12 @@ class TimeSeries:
220
237
  time = xp.arange(0, len(data) * dt, dt) + t0
221
238
  return TimeSeries(data, time)
222
239
 
223
- def highpass_filter(self, fmin: float, tukey_window_alpha:float=0.0, bandpass_order: int = 4) -> 'TimeSeries':
240
+ def highpass_filter(
241
+ self,
242
+ fmin: float,
243
+ tukey_window_alpha: float = 0.0,
244
+ bandpass_order: int = 4,
245
+ ) -> "TimeSeries":
224
246
  """
225
247
  Filter the time series with a highpass bandpass filter.
226
248
 
@@ -244,24 +266,25 @@ class TimeSeries:
244
266
  """
245
267
 
246
268
  if fmin <= 0 or fmin > self.nyquist_frequency:
247
- raise ValueError(f"Invalid fmin value: {fmin}. Must be in the range [0, {self.nyquist_frequency}]")
269
+ raise ValueError(
270
+ f"Invalid fmin value: {fmin}. Must be in the range [0, {self.nyquist_frequency}]"
271
+ )
248
272
 
249
- sos = butter(bandpass_order, Wn=fmin, btype="highpass", output='sos', fs=self.fs)
273
+ sos = butter(
274
+ bandpass_order, Wn=fmin, btype="highpass", output="sos", fs=self.fs
275
+ )
250
276
  window = tukey(self.ND, alpha=tukey_window_alpha)
251
277
  data = self.data.copy()
252
278
  data = sosfiltfilt(sos, data * window)
253
279
  return TimeSeries(data, self.time)
254
280
 
255
281
  def __copy__(self):
256
- return TimeSeries(
257
- self.data.copy(),
258
- self.time.copy()
259
- )
282
+ return TimeSeries(self.data.copy(), self.time.copy())
260
283
 
261
284
  def copy(self):
262
285
  return self.__copy__()
263
286
 
264
- def __getitem__(self, key)->"TimeSeries":
287
+ def __getitem__(self, key) -> "TimeSeries":
265
288
  if isinstance(key, slice):
266
289
  # Handle slicing
267
290
  return self.__handle_slice(key)
@@ -269,12 +292,9 @@ class TimeSeries:
269
292
  # Handle regular indexing
270
293
  return TimeSeries(self.data[key], self.time[key])
271
294
 
272
- def __handle_slice(self, slice_obj)->"TimeSeries":
273
- return TimeSeries(
274
- self.data[slice_obj],
275
- self.time[slice_obj]
276
- )
295
+ def __handle_slice(self, slice_obj) -> "TimeSeries":
296
+ return TimeSeries(self.data[slice_obj], self.time[slice_obj])
277
297
 
278
298
  @classmethod
279
- def _EMPTY(cls, ND:int, dt:float)->"TimeSeries":
280
- return cls(xp.zeros(ND), xp.arange(0, ND*dt, dt))
299
+ def _EMPTY(cls, ND: int, dt: float) -> "TimeSeries":
300
+ return cls(xp.zeros(ND), xp.arange(0, ND * dt, dt))
@@ -1,10 +1,11 @@
1
- import matplotlib.pyplot as plt
2
- from typing import Optional, Tuple
1
+ from typing import List, Optional, Tuple
3
2
 
3
+ import matplotlib.pyplot as plt
4
4
  import numpy as np
5
5
 
6
- from .common import is_documented_by, xp, fmt_timerange
6
+ from .common import fmt_timerange, is_documented_by, xp
7
7
  from .plotting import plot_wavelet_grid, plot_wavelet_trend
8
+ from .wavelet_bins import compute_bins
8
9
 
9
10
 
10
11
  class Wavelet:
@@ -22,10 +23,10 @@ class Wavelet:
22
23
  """
23
24
 
24
25
  def __init__(
25
- self,
26
- data: xp.ndarray,
27
- time: xp.ndarray,
28
- freq: xp.ndarray,
26
+ self,
27
+ data: xp.ndarray,
28
+ time: xp.ndarray,
29
+ freq: xp.ndarray,
29
30
  ):
30
31
  """
31
32
  Initialize the Wavelet object with data, time, and frequency arrays.
@@ -53,17 +54,60 @@ class Wavelet:
53
54
  self.time = time
54
55
  self.freq = freq
55
56
 
57
+ @classmethod
58
+ def zeros_from_grid(cls, time: xp.ndarray, freq: xp.ndarray) -> "Wavelet":
59
+ """
60
+ Create a Wavelet object filled with zeros.
61
+
62
+ Parameters
63
+ ----------
64
+ time: xp.ndarray
65
+ freq: xp.ndarray
66
+
67
+ Returns
68
+ -------
69
+ Wavelet
70
+ A Wavelet object with zero-filled data array.
71
+ """
72
+ Nf, Nt = len(freq), len(time)
73
+ return cls(data=xp.zeros((Nf, Nt)), time=time, freq=freq)
74
+
75
+ @classmethod
76
+ def zeros(cls, Nf: int, Nt: int, T: float) -> "Wavelet":
77
+ """
78
+ Create a Wavelet object filled with zeros.
79
+
80
+ Parameters
81
+ ----------
82
+ Nf : int
83
+ Number of frequency bins.
84
+ Nt : int
85
+ Number of time bins.
86
+
87
+ Returns
88
+ -------
89
+ Wavelet
90
+ A Wavelet object with zero-filled data array.
91
+ """
92
+ return cls.zeros_from_grid(*compute_bins(Nf, Nt, T))
93
+
56
94
  @is_documented_by(plot_wavelet_grid)
57
95
  def plot(self, ax=None, *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
58
96
  kwargs["time_grid"] = kwargs.get("time_grid", self.time)
59
97
  kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
60
- return plot_wavelet_grid(wavelet_data=self.data, ax=ax, *args, **kwargs)
98
+ return plot_wavelet_grid(
99
+ wavelet_data=self.data, ax=ax, *args, **kwargs
100
+ )
61
101
 
62
102
  @is_documented_by(plot_wavelet_trend)
63
- def plot_trend(self, ax=None, *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
103
+ def plot_trend(
104
+ self, ax=None, *args, **kwargs
105
+ ) -> Tuple[plt.Figure, plt.Axes]:
64
106
  kwargs["time_grid"] = kwargs.get("time_grid", self.time)
65
107
  kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
66
- return plot_wavelet_trend(wavelet_data=self.data, ax=ax, *args, **kwargs)
108
+ return plot_wavelet_trend(
109
+ wavelet_data=self.data, ax=ax, *args, **kwargs
110
+ )
67
111
 
68
112
  @property
69
113
  def Nt(self) -> int:
@@ -242,7 +286,8 @@ class Wavelet:
242
286
  TimeSeries
243
287
  A `TimeSeries` object representing the time-domain signal.
244
288
  """
245
- from ..inverse import from_wavelet_to_time
289
+ from ..transforms.inverse import from_wavelet_to_time
290
+
246
291
  return from_wavelet_to_time(self, dt=self.delta_t, nx=nx, mult=mult)
247
292
 
248
293
  def to_frequencyseries(self, nx: float = 4.0) -> "FrequencySeries":
@@ -254,7 +299,8 @@ class Wavelet:
254
299
  FrequencySeries
255
300
  A `FrequencySeries` object representing the frequency-domain signal.
256
301
  """
257
- from ..inverse import from_wavelet_to_freq
302
+ from ..transforms.inverse import from_wavelet_to_freq
303
+
258
304
  return from_wavelet_to_freq(self, dt=self.delta_t, nx=nx)
259
305
 
260
306
  def __repr__(self) -> str:
@@ -277,40 +323,57 @@ class Wavelet:
277
323
  def __add__(self, other):
278
324
  """Element-wise addition of two Wavelet objects."""
279
325
  if isinstance(other, Wavelet):
280
- return Wavelet(data=self.data + other.data, time=self.time, freq=self.freq)
326
+ return Wavelet(
327
+ data=self.data + other.data, time=self.time, freq=self.freq
328
+ )
281
329
  elif isinstance(other, float):
282
- return Wavelet(data=self.data + other, time=self.time, freq=self.freq)
330
+ return Wavelet(
331
+ data=self.data + other, time=self.time, freq=self.freq
332
+ )
283
333
 
284
334
  def __sub__(self, other):
285
335
  """Element-wise subtraction of two Wavelet objects."""
286
336
  if isinstance(other, Wavelet):
287
- return Wavelet(data=self.data - other.data, time=self.time, freq=self.freq)
337
+ return Wavelet(
338
+ data=self.data - other.data, time=self.time, freq=self.freq
339
+ )
288
340
  elif isinstance(other, float):
289
- return Wavelet(data=self.data - other, time=self.time, freq=self.freq)
341
+ return Wavelet(
342
+ data=self.data - other, time=self.time, freq=self.freq
343
+ )
290
344
 
291
345
  def __mul__(self, other):
292
346
  """Element-wise multiplication of two Wavelet objects."""
293
347
  if isinstance(other, Wavelet):
294
- return Wavelet(data=self.data * other.data, time=self.time, freq=self.freq)
348
+ return Wavelet(
349
+ data=self.data * other.data, time=self.time, freq=self.freq
350
+ )
295
351
  elif isinstance(other, float):
296
- return Wavelet(data=self.data * other, time=self.time, freq=self.freq)
352
+ return Wavelet(
353
+ data=self.data * other, time=self.time, freq=self.freq
354
+ )
297
355
 
298
356
  def __truediv__(self, other):
299
357
  """Element-wise division of two Wavelet objects."""
300
358
  if isinstance(other, Wavelet):
301
- return Wavelet(data=self.data / other.data, time=self.time, freq=self.freq)
359
+ return Wavelet(
360
+ data=self.data / other.data, time=self.time, freq=self.freq
361
+ )
302
362
  elif isinstance(other, float):
303
- return Wavelet(data=self.data / other, time=self.time, freq=self.freq)
363
+ return Wavelet(
364
+ data=self.data / other, time=self.time, freq=self.freq
365
+ )
304
366
 
305
- def __eq__(self, other:"Wavelet") -> bool:
367
+ def __eq__(self, other: "Wavelet") -> bool:
306
368
  """Element-wise comparison of two Wavelet objects."""
307
369
  data_all_same = xp.isclose(xp.nansum(self.data - other.data), 0)
308
370
  time_same = (self.time == other.time).all()
309
371
  freq_same = (self.freq == other.freq).all()
310
372
  return data_all_same and time_same and freq_same
311
373
 
312
-
313
- def noise_weighted_inner_product(self, other:"Wavelet", psd:"Wavelet") -> float:
374
+ def noise_weighted_inner_product(
375
+ self, other: "Wavelet", psd: "Wavelet"
376
+ ) -> float:
314
377
  """
315
378
  Compute the noise-weighted inner product of two wavelet grids given a PSD.
316
379
 
@@ -326,11 +389,11 @@ class Wavelet:
326
389
  float
327
390
  The noise-weighted inner product.
328
391
  """
329
- from ...utils import noise_weighted_inner_product
330
- return noise_weighted_inner_product(self, other, psd)
392
+ from ..utils import noise_weighted_inner_product
331
393
 
394
+ return noise_weighted_inner_product(self, other, psd)
332
395
 
333
- def matched_filter_snr(self, template:"Wavelet", psd:"Wavelet") -> float:
396
+ def matched_filter_snr(self, template: "Wavelet", psd: "Wavelet") -> float:
334
397
  """
335
398
  Compute the matched filter SNR of the wavelet grid given a template.
336
399
 
@@ -347,7 +410,7 @@ class Wavelet:
347
410
  mf = self.noise_weighted_inner_product(template, psd)
348
411
  return mf / self.optimal_snr(psd)
349
412
 
350
- def optimal_snr(self, psd:"Wavelet") -> float:
413
+ def optimal_snr(self, psd: "Wavelet") -> float:
351
414
  """
352
415
  Compute the optimal SNR of the wavelet grid given a PSD.
353
416
 
@@ -365,10 +428,27 @@ class Wavelet:
365
428
 
366
429
  def __copy__(self):
367
430
  return Wavelet(
368
- data=self.data.copy(),
369
- time=self.time.copy(),
370
- freq=self.freq.copy()
431
+ data=self.data.copy(), time=self.time.copy(), freq=self.freq.copy()
371
432
  )
372
433
 
373
434
  def copy(self):
374
- return self.__copy__()
435
+ return self.__copy__()
436
+
437
+
438
+ class WaveletMask(Wavelet):
439
+ @property
440
+ def mask(self):
441
+ return self.data
442
+
443
+ def __repr__(self):
444
+ return f"WaveletMask({self.mask.shape}, {fmt_timerange(self.time)}, {self.freq})"
445
+
446
+ @classmethod
447
+ def from_frange(
448
+ cls, time_grid: xp.ndarray, freq_grid: xp.ndarray, frange: List[float]
449
+ ):
450
+ self = cls.zeros_from_grid(time_grid, freq_grid)
451
+ self.mask[
452
+ (freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
453
+ ] = True
454
+ return self
@@ -2,7 +2,8 @@ from typing import Tuple, Union
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ..types import FrequencySeries, TimeSeries
5
+ from .frequencyseries import FrequencySeries
6
+ from .timeseries import TimeSeries
6
7
 
7
8
 
8
9
  def _preprocess_bins(
@@ -10,7 +11,6 @@ def _preprocess_bins(
10
11
  ) -> Tuple[int, int]:
11
12
  """preprocess the bins"""
12
13
 
13
-
14
14
  if isinstance(data, TimeSeries):
15
15
  N = len(data)
16
16
  elif isinstance(data, FrequencySeries):
@@ -34,7 +34,6 @@ def _get_bins(
34
34
  Nf: Union[int, None] = None,
35
35
  Nt: Union[int, None] = None,
36
36
  ) -> Tuple[np.ndarray, np.ndarray]:
37
-
38
37
  T = data.duration
39
38
  t_bins, f_bins = compute_bins(Nf, Nt, T)
40
39
 
@@ -42,12 +41,12 @@ def _get_bins(
42
41
  # fs = N / T
43
42
  # assert delta_f == fmax / Nf, f"delta_f={delta_f} != fmax/Nf={fmax/Nf}"
44
43
 
45
- t_bins+= data.t0
44
+ t_bins += data.t0
46
45
 
47
46
  return t_bins, f_bins
48
47
 
49
48
 
50
- def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
49
+ def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[np.ndarray, np.ndarray]:
51
50
  """Get the bins for the wavelet transform
52
51
  Eq 4-6 in Wavelets paper
53
52
  """
@@ -55,4 +54,4 @@ def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
55
54
  delta_F = 1 / (2 * delta_T)
56
55
  t_bins = np.arange(0, Nt) * delta_T
57
56
  f_bins = np.arange(0, Nf) * delta_F
58
- return t_bins, f_bins
57
+ return t_bins, f_bins
pywavelet/utils.py CHANGED
@@ -3,7 +3,7 @@ from typing import Union
3
3
  import numpy as np
4
4
  from scipy.interpolate import interp1d
5
5
 
6
- from .transforms.types import FrequencySeries, TimeSeries, Wavelet
6
+ from .types import FrequencySeries, TimeSeries, Wavelet, WaveletMask
7
7
 
8
8
  DATA_TYPE = Union[TimeSeries, FrequencySeries, Wavelet]
9
9
 
@@ -34,10 +34,13 @@ def evolutionary_psd_from_stationary_psd(
34
34
  return Wavelet(psd_grid.T, time=t_grid, freq=f_grid)
35
35
 
36
36
 
37
- def noise_weighted_inner_product(d: Wavelet, h: Wavelet, PSD: Wavelet) -> float:
37
+ def noise_weighted_inner_product(
38
+ d: Wavelet, h: Wavelet, PSD: Wavelet
39
+ ) -> float:
38
40
  return np.nansum((d.data * h.data) / PSD.data)
39
41
 
40
- def compute_snr(d:Wavelet, h: Wavelet, PSD: Wavelet) -> float:
42
+
43
+ def compute_snr(d: Wavelet, h: Wavelet, PSD: Wavelet) -> float:
41
44
  """Compute the SNR of a model h[ti,fi] given freqseries d[ti,fi] and PSD[ti,fi].
42
45
 
43
46
  SNR(h) = Sum_{ti,fi} [ h_hat[ti,fi] d[ti,fi] / PSD[ti,fi]
@@ -60,5 +63,14 @@ def compute_snr(d:Wavelet, h: Wavelet, PSD: Wavelet) -> float:
60
63
  return np.sqrt(noise_weighted_inner_product(d, h, PSD))
61
64
 
62
65
 
63
- def compute_likelihood(data:Wavelet, template:Wavelet, psd:Wavelet) -> float:
64
- return -0.5 * np.nansum((data.data - template.data) ** 2 / psd.data)
66
+ def compute_likelihood(
67
+ data: Wavelet, template: Wavelet, psd: Wavelet, mask: WaveletMask = None
68
+ ) -> float:
69
+ d = data.data
70
+ h = template.data
71
+ p = psd.data
72
+ if mask is not None:
73
+ m = mask.mask
74
+ d, h, p = d * m, h * m, p * m
75
+
76
+ return -0.5 * np.nansum((d - h) ** 2 / p)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.0.5
3
+ Version: 0.1.1
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -0,0 +1,25 @@
1
+ pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
+ pywavelet/_version.py,sha256=PKIMyjdUACH4-ONvtunQCnYE2UhlMfp9su83e3HXl5E,411
3
+ pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
4
+ pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
5
+ pywavelet/transforms/__init__.py,sha256=1Ibsup9UwMajeZ9NCQ4BN15qZTeJ_EHkgGu8XNFdA18,255
6
+ pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
7
+ pywavelet/transforms/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
8
+ pywavelet/transforms/forward/from_freq.py,sha256=wCiyLpzJE3rGxYjQBdXlwkxPIRYhQWjKq0C_8zYlmDk,2697
9
+ pywavelet/transforms/forward/from_time.py,sha256=-Y6VEKwDCYBAHAjLdO46vT-6alpM5fXTgTZ_xkYxqA8,2381
10
+ pywavelet/transforms/forward/main.py,sha256=Gfy0sp-woy_3ihKMzuk2WuZ7dRk-Mm6sp5dVpYrSvj4,4005
11
+ pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
12
+ pywavelet/transforms/inverse/main.py,sha256=Q7wUaRjB1sgqdB7dniWQGbPTWYQNnIsIrYtjsaHJEdE,3012
13
+ pywavelet/transforms/inverse/to_freq.py,sha256=so_TDbwdS1N8sd1QcpeAEkI10XFDtoFJGohtD4YulZM,2809
14
+ pywavelet/transforms/inverse/to_time.py,sha256=w5vmImdsb_4YeInZtXh0llsThLTxS0tmYDlNGJ-IUew,5080
15
+ pywavelet/types/__init__.py,sha256=5YptzQvYBnRfC8N5lpOBf9I1lzpJ0pw0QMnvIcwP3YI,122
16
+ pywavelet/types/common.py,sha256=OSAW6GqLTgqJ-RYEv__XbzsfFd8AFo5w-ctXQ4XAFZo,1317
17
+ pywavelet/types/frequencyseries.py,sha256=UqcE6UQfw5HZm4na2q9k-X-mfqO-BCiTAvGjaYpSrwc,7518
18
+ pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
19
+ pywavelet/types/timeseries.py,sha256=6DPO0xLi4Dq2srhJLmavFMf4fYIC3wwdbyMU7lMdjTo,9446
20
+ pywavelet/types/wavelet.py,sha256=ptTEnq6nRZiW2x6g_NV_FuoeVNOsbNOu6caSxYDZNgk,12583
21
+ pywavelet/types/wavelet_bins.py,sha256=SC9nhyigWvOfs2TbH8-Ck_iS1Mrz6sG-8EaIFYLIHuk,1453
22
+ pywavelet-0.1.1.dist-info/METADATA,sha256=OgKTqqZQKfhCCcsYpcXnrut-2EAH42c3tCpCGc0PIDg,1307
23
+ pywavelet-0.1.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
24
+ pywavelet-0.1.1.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
25
+ pywavelet-0.1.1.dist-info/RECORD,,
@@ -1,25 +0,0 @@
1
- pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
- pywavelet/_version.py,sha256=EJB7__SNK9kQS_SWZB_U4DHJ3P8ftF6etZEihTYnuXE,411
3
- pywavelet/logger.py,sha256=u5Et6QLUU0IuTnPnxfev5-4GYmihTo7_yBCTf-rVUsA,377
4
- pywavelet/utils.py,sha256=rbmZ-PrnFstq05kh6xj8emuQ3-7oXqZinU03bpg3N7A,1746
5
- pywavelet/transforms/__init__.py,sha256=FolK8WiVEJmGDC9xMupYVI_essXaXS4LYWKbqEqGx6o,289
6
- pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
7
- pywavelet/transforms/forward/__init__.py,sha256=Yq4Tg3Ze98-17C9FIkOqMUdiLHe9x_YoyuRvxOxMOP0,176
8
- pywavelet/transforms/forward/from_freq.py,sha256=HWADciv746P5zWNhnyRKDEDACTQnwDTR4usf2k0VIWk,2663
9
- pywavelet/transforms/forward/from_time.py,sha256=zfpvmq7UHnynzAnsM_Pf4MXpO1GJ0TaCCERQKNsaBNU,2340
10
- pywavelet/transforms/forward/main.py,sha256=FJSLUUsk3I91lUeB2YFiXQdX8gB9maFW60e4tgrd45A,4053
11
- pywavelet/transforms/forward/wavelet_bins.py,sha256=43CHVSjNgC59Af1_h9t33VBBqGHJmIUaOUhJgekhMnc,1419
12
- pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
13
- pywavelet/transforms/inverse/main.py,sha256=l5yFvzmWObrO5Xt_8KYp62w829ab0pQLxcv0-QxkvG0,3015
14
- pywavelet/transforms/inverse/to_freq.py,sha256=SExZMax-8A-tJpIA86pYY61X2qvlZ2MrZY27uzCQSV0,2778
15
- pywavelet/transforms/inverse/to_time.py,sha256=BAYvrr41QHIbzwYMPyMnzv5mqSx40YigmBruWBwtZwc,5041
16
- pywavelet/transforms/types/__init__.py,sha256=4wUTVBk6A02xjZUY_w056eUZurYI9vVfa--I3Q6Udng,109
17
- pywavelet/transforms/types/common.py,sha256=sLBn2d9cuL6NYeIv6NIogDjY6rYPZAPzCtWGwMlAwkI,1290
18
- pywavelet/transforms/types/frequencyseries.py,sha256=Rtwt486UL0-TgMAdcMMVpyfi5PSLzxFcdE_RsYuyxQk,7463
19
- pywavelet/transforms/types/plotting.py,sha256=aEIFoSuQRYwYc2639yLkbujXx_aav5A5tXG29rOSMOQ,10275
20
- pywavelet/transforms/types/timeseries.py,sha256=Nl1tiKZ7kwu-EZ5JtubwgzgyjQZR86eGU3C3kElIPNg,9296
21
- pywavelet/transforms/types/wavelet.py,sha256=raODhyegBd1_esuv2YP6tK_Nj9fYaLQ6O4pxpiFAVZU,10790
22
- pywavelet-0.0.5.dist-info/METADATA,sha256=BaLK8jbUqss_8BeGWOlf_031PoCAlEV4rmuSaitQd0E,1307
23
- pywavelet-0.0.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
24
- pywavelet-0.0.5.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
25
- pywavelet-0.0.5.dist-info/RECORD,,