pywavelet 0.0.5__py3-none-any.whl → 0.1.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- pywavelet/_version.py +2 -2
- pywavelet/logger.py +6 -2
- pywavelet/transforms/__init__.py +1 -2
- pywavelet/transforms/forward/__init__.py +1 -2
- pywavelet/transforms/forward/from_freq.py +13 -4
- pywavelet/transforms/forward/from_time.py +13 -1
- pywavelet/transforms/forward/main.py +7 -15
- pywavelet/transforms/inverse/main.py +3 -4
- pywavelet/transforms/inverse/to_freq.py +11 -3
- pywavelet/transforms/inverse/to_time.py +13 -3
- pywavelet/{transforms/types → types}/__init__.py +1 -1
- pywavelet/{transforms/types → types}/common.py +7 -7
- pywavelet/{transforms/types → types}/frequencyseries.py +31 -23
- pywavelet/{transforms/types → types}/plotting.py +47 -26
- pywavelet/{transforms/types → types}/timeseries.py +58 -38
- pywavelet/{transforms/types → types}/wavelet.py +111 -31
- pywavelet/{transforms/forward → types}/wavelet_bins.py +5 -6
- pywavelet/utils.py +17 -5
- {pywavelet-0.0.5.dist-info → pywavelet-0.1.1.dist-info}/METADATA +1 -1
- pywavelet-0.1.1.dist-info/RECORD +25 -0
- pywavelet-0.0.5.dist-info/RECORD +0 -25
- {pywavelet-0.0.5.dist-info → pywavelet-0.1.1.dist-info}/WHEEL +0 -0
- {pywavelet-0.0.5.dist-info → pywavelet-0.1.1.dist-info}/top_level.txt +0 -0
pywavelet/_version.py
CHANGED
pywavelet/logger.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1
|
+
import logging
|
1
2
|
import sys
|
2
3
|
import warnings
|
4
|
+
|
3
5
|
from rich.logging import RichHandler
|
4
|
-
import logging
|
5
6
|
|
6
7
|
FORMAT = "%(message)s"
|
7
8
|
logging.basicConfig(
|
8
|
-
level="INFO",
|
9
|
+
level="INFO",
|
10
|
+
format=FORMAT,
|
11
|
+
datefmt="[%X]",
|
12
|
+
handlers=[RichHandler(rich_tracebacks=True)],
|
9
13
|
)
|
10
14
|
|
11
15
|
logger = logging.getLogger("pywavelet")
|
pywavelet/transforms/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
1
|
+
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
2
2
|
from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
3
3
|
|
4
4
|
__all__ = [
|
@@ -6,5 +6,4 @@ __all__ = [
|
|
6
6
|
"from_wavelet_to_freq",
|
7
7
|
"from_time_to_wavelet",
|
8
8
|
"from_freq_to_wavelet",
|
9
|
-
"compute_bins",
|
10
9
|
]
|
@@ -1,8 +1,10 @@
|
|
1
1
|
"""helper functions for transform_freq"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
5
6
|
|
7
|
+
|
6
8
|
def transform_wavelet_freq_helper(
|
7
9
|
data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
|
8
10
|
) -> np.ndarray:
|
@@ -13,8 +15,16 @@ def transform_wavelet_freq_helper(
|
|
13
15
|
__core(Nf, Nt, DX, freq_strain, phif, wave)
|
14
16
|
return wave
|
15
17
|
|
18
|
+
|
16
19
|
# @njit()
|
17
|
-
def __core(
|
20
|
+
def __core(
|
21
|
+
Nf: int,
|
22
|
+
Nt: int,
|
23
|
+
DX: np.ndarray,
|
24
|
+
freq_strain: np.ndarray,
|
25
|
+
phif: np.ndarray,
|
26
|
+
wave: np.ndarray,
|
27
|
+
):
|
18
28
|
for f_bin in range(0, Nf + 1):
|
19
29
|
__fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
|
20
30
|
# Numba doesn't support np.ifft so we cant jit this
|
@@ -22,8 +32,6 @@ def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarra
|
|
22
32
|
__fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
|
23
33
|
|
24
34
|
|
25
|
-
|
26
|
-
|
27
35
|
@njit()
|
28
36
|
def __fill_wave_1(
|
29
37
|
f_bin: int,
|
@@ -55,6 +63,7 @@ def __fill_wave_1(
|
|
55
63
|
else:
|
56
64
|
DX[i] = phif[j] * data[jj]
|
57
65
|
|
66
|
+
|
58
67
|
@njit()
|
59
68
|
def __fill_wave_2(
|
60
69
|
f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
|
@@ -72,7 +81,7 @@ def __fill_wave_2(
|
|
72
81
|
if (n + f_bin) % 2:
|
73
82
|
wave[n, f_bin] = -DX_trans[n].imag
|
74
83
|
else:
|
75
|
-
wave[n, f_bin] =
|
84
|
+
wave[n, f_bin] = DX_trans[n].real
|
76
85
|
else:
|
77
86
|
if (n + f_bin) % 2:
|
78
87
|
wave[n, f_bin] = DX_trans[n].imag
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""helper functions for transform_time.py"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
@@ -20,7 +21,18 @@ def transform_wavelet_time_helper(
|
|
20
21
|
__core(Nf, Nt, K, ND, wdata, data_pad, phi, wave, mult)
|
21
22
|
return wave
|
22
23
|
|
23
|
-
|
24
|
+
|
25
|
+
def __core(
|
26
|
+
Nf: int,
|
27
|
+
Nt: int,
|
28
|
+
K: int,
|
29
|
+
ND: int,
|
30
|
+
wdata: np.ndarray,
|
31
|
+
data_pad: np.ndarray,
|
32
|
+
phi: np.ndarray,
|
33
|
+
wave: np.ndarray,
|
34
|
+
mult: int,
|
35
|
+
) -> None:
|
24
36
|
for time_bin_i in range(0, Nt):
|
25
37
|
__fill_wave_1(time_bin_i, K, ND, Nf, wdata, data_pad, phi)
|
26
38
|
wdata_trans = np.fft.rfft(wdata, K)
|
@@ -3,15 +3,15 @@ from typing import Union
|
|
3
3
|
import numpy as np
|
4
4
|
|
5
5
|
from ...logger import logger
|
6
|
+
from ...types import FrequencySeries, TimeSeries, Wavelet
|
7
|
+
from ...types.wavelet_bins import _get_bins, _preprocess_bins
|
6
8
|
from ..phi_computer import phi_vec, phitilde_vec_norm
|
7
|
-
from ..types import FrequencySeries, TimeSeries, Wavelet
|
8
9
|
from .from_freq import transform_wavelet_freq_helper
|
9
10
|
from .from_time import transform_wavelet_time_helper
|
10
|
-
from .wavelet_bins import _get_bins, _preprocess_bins
|
11
|
-
|
12
11
|
|
13
12
|
__all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
|
14
13
|
|
14
|
+
|
15
15
|
def from_time_to_wavelet(
|
16
16
|
timeseries: TimeSeries,
|
17
17
|
Nf: Union[int, None] = None,
|
@@ -74,9 +74,7 @@ def from_time_to_wavelet(
|
|
74
74
|
mult = min(mult, Nt // 2) # Ensure mult is not larger than ND/2
|
75
75
|
phi = phi_vec(Nf, dt=dt, d=nx, q=mult)
|
76
76
|
wave = transform_wavelet_time_helper(timeseries.data, Nf, Nt, phi, mult).T
|
77
|
-
return Wavelet(
|
78
|
-
wave * np.sqrt(2), time=t_bins, freq=f_bins
|
79
|
-
)
|
77
|
+
return Wavelet(wave * np.sqrt(2), time=t_bins, freq=f_bins)
|
80
78
|
|
81
79
|
|
82
80
|
def from_freq_to_wavelet(
|
@@ -117,12 +115,6 @@ def from_freq_to_wavelet(
|
|
117
115
|
t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
|
118
116
|
dt = freqseries.dt
|
119
117
|
phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)
|
120
|
-
wave = transform_wavelet_freq_helper(
|
121
|
-
|
122
|
-
)
|
123
|
-
|
124
|
-
return Wavelet(
|
125
|
-
(2 / Nf) * wave.T * np.sqrt(2),
|
126
|
-
time=t_bins,
|
127
|
-
freq=f_bins
|
128
|
-
)
|
118
|
+
wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
|
119
|
+
|
120
|
+
return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
3
|
from ...transforms.phi_computer import phi_vec, phitilde_vec_norm
|
4
|
-
from
|
4
|
+
from ...types import FrequencySeries, TimeSeries, Wavelet
|
5
5
|
from .to_freq import inverse_wavelet_freq_helper_fast
|
6
6
|
from .to_time import inverse_wavelet_time_helper_fast
|
7
7
|
|
@@ -9,6 +9,7 @@ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
|
|
9
9
|
|
10
10
|
INV_ROOT2 = 1.0 / np.sqrt(2)
|
11
11
|
|
12
|
+
|
12
13
|
def from_wavelet_to_time(
|
13
14
|
wave_in: Wavelet,
|
14
15
|
dt: float,
|
@@ -55,9 +56,7 @@ def from_wavelet_to_time(
|
|
55
56
|
|
56
57
|
|
57
58
|
def from_wavelet_to_freq(
|
58
|
-
wave_in: Wavelet,
|
59
|
-
dt: float,
|
60
|
-
nx:float=4.0
|
59
|
+
wave_in: Wavelet, dt: float, nx: float = 4.0
|
61
60
|
) -> FrequencySeries:
|
62
61
|
"""
|
63
62
|
Perform an inverse wavelet transform to the frequency domain.
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""functions for computing the inverse wavelet transforms"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
@@ -12,13 +13,20 @@ def inverse_wavelet_freq_helper_fast(
|
|
12
13
|
ND = Nf * Nt
|
13
14
|
|
14
15
|
prefactor2s = np.zeros(Nt, np.complex128)
|
15
|
-
res = np.zeros(ND//2 +1, dtype=np.complex128)
|
16
|
+
res = np.zeros(ND // 2 + 1, dtype=np.complex128)
|
16
17
|
__core(Nf, Nt, prefactor2s, wave_in, phif, res)
|
17
18
|
|
18
|
-
|
19
19
|
return res
|
20
20
|
|
21
|
-
|
21
|
+
|
22
|
+
def __core(
|
23
|
+
Nf: int,
|
24
|
+
Nt: int,
|
25
|
+
prefactor2s: np.ndarray,
|
26
|
+
wave_in: np.ndarray,
|
27
|
+
phif: np.ndarray,
|
28
|
+
res: np.ndarray,
|
29
|
+
) -> None:
|
22
30
|
for m in range(0, Nf + 1):
|
23
31
|
__pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
|
24
32
|
fft_prefactor2s = np.fft.fft(prefactor2s)
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""functions for computing the inverse wavelet transforms"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
@@ -21,7 +22,16 @@ def inverse_wavelet_time_helper_fast(
|
|
21
22
|
return res[:ND]
|
22
23
|
|
23
24
|
|
24
|
-
def __core(
|
25
|
+
def __core(
|
26
|
+
Nf: int,
|
27
|
+
Nt: int,
|
28
|
+
K: int,
|
29
|
+
ND: int,
|
30
|
+
wave_in: np.ndarray,
|
31
|
+
phi: np.ndarray,
|
32
|
+
res: np.ndarray,
|
33
|
+
afins: np.ndarray,
|
34
|
+
) -> None:
|
25
35
|
for n in range(0, Nt):
|
26
36
|
if n % 2 == 0:
|
27
37
|
pack_wave_time_helper_compact(n, Nf, Nt, wave_in, afins)
|
@@ -29,9 +39,9 @@ def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarr
|
|
29
39
|
unpack_time_wave_helper_compact(n, Nf, Nt, K, phi, ffts_fin, res)
|
30
40
|
|
31
41
|
# wrap boundary conditions
|
32
|
-
res[: min(K + Nf, ND)] += res[ND: min(ND + K + Nf, 2 * ND)]
|
42
|
+
res[: min(K + Nf, ND)] += res[ND : min(ND + K + Nf, 2 * ND)]
|
33
43
|
if K + Nf > ND:
|
34
|
-
res[: K + Nf - ND] += res[2 * ND: ND + K * Nf]
|
44
|
+
res[: K + Nf - ND] += res[2 * ND : ND + K * Nf]
|
35
45
|
|
36
46
|
|
37
47
|
def unpack_time_wave_helper(
|
@@ -1,9 +1,9 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Tuple, Union
|
2
2
|
|
3
3
|
import numpy as xp
|
4
|
-
from numpy.fft import
|
4
|
+
from numpy.fft import fft, irfft, rfft, rfftfreq # type: ignore
|
5
5
|
|
6
|
-
from
|
6
|
+
from ..logger import logger
|
7
7
|
|
8
8
|
|
9
9
|
def _len_check(d):
|
@@ -19,7 +19,7 @@ def is_documented_by(original):
|
|
19
19
|
return wrapper
|
20
20
|
|
21
21
|
|
22
|
-
def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
|
22
|
+
def fmt_time(seconds: float, units=False) -> Union[str, Tuple[str, str]]:
|
23
23
|
"""Returns formatted time and units [ms, s, min, hr, day]"""
|
24
24
|
t, u = "", ""
|
25
25
|
if seconds < 1e-3:
|
@@ -42,12 +42,12 @@ def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
|
|
42
42
|
|
43
43
|
def fmt_timerange(trange):
|
44
44
|
t0 = fmt_time(trange[0])
|
45
|
-
tend, units = fmt_time(trange[1], units
|
45
|
+
tend, units = fmt_time(trange[1], units=True)
|
46
46
|
return f"[{t0}, {tend}] {units}"
|
47
47
|
|
48
48
|
|
49
|
-
def fmt_pow2(n:float)->str:
|
49
|
+
def fmt_pow2(n: float) -> str:
|
50
50
|
pow2 = xp.log2(n)
|
51
51
|
if pow2.is_integer():
|
52
52
|
return f"2^{int(pow2)}"
|
53
|
-
return f"{n:,}"
|
53
|
+
return f"{n:,}"
|
@@ -1,11 +1,13 @@
|
|
1
|
+
from typing import Optional, Tuple, Union
|
2
|
+
|
1
3
|
import matplotlib.pyplot as plt
|
2
|
-
from typing import Tuple, Union, Optional
|
3
4
|
|
4
|
-
from .common import
|
5
|
+
from .common import fmt_pow2, fmt_time, irfft, is_documented_by, xp
|
5
6
|
from .plotting import plot_freqseries, plot_periodogram
|
6
7
|
|
7
8
|
__all__ = ["FrequencySeries"]
|
8
9
|
|
10
|
+
|
9
11
|
class FrequencySeries:
|
10
12
|
"""
|
11
13
|
A class to represent a one-sided frequency series, with various methods for
|
@@ -39,9 +41,13 @@ class FrequencySeries:
|
|
39
41
|
If any frequency is negative or if `data` and `freq` do not have the same length.
|
40
42
|
"""
|
41
43
|
if xp.any(freq < 0):
|
42
|
-
raise ValueError(
|
44
|
+
raise ValueError(
|
45
|
+
"FrequencySeries must be one-sided (only non-negative frequencies)"
|
46
|
+
)
|
43
47
|
if len(data) != len(freq):
|
44
|
-
raise ValueError(
|
48
|
+
raise ValueError(
|
49
|
+
f"data and freq must have the same length ({len(data)} != {len(freq)})"
|
50
|
+
)
|
45
51
|
self.data = data
|
46
52
|
self.freq = freq
|
47
53
|
self.t0 = t0
|
@@ -53,10 +59,10 @@ class FrequencySeries:
|
|
53
59
|
)
|
54
60
|
|
55
61
|
@is_documented_by(plot_periodogram)
|
56
|
-
def plot_periodogram(
|
57
|
-
|
58
|
-
|
59
|
-
)
|
62
|
+
def plot_periodogram(
|
63
|
+
self, ax=None, **kwargs
|
64
|
+
) -> Tuple[plt.Figure, plt.Axes]:
|
65
|
+
return plot_periodogram(self.data, self.freq, self.fs, ax=ax, **kwargs)
|
60
66
|
|
61
67
|
def __len__(self):
|
62
68
|
"""Return the length of the frequency series."""
|
@@ -127,7 +133,9 @@ class FrequencySeries:
|
|
127
133
|
n = fmt_pow2(len(self))
|
128
134
|
return f"FrequencySeries(n={n}, frange=[{self.range[0]:.2f}, {self.range[1]:.2f}] Hz, T={dur}, fs={self.fs:.2f} Hz)"
|
129
135
|
|
130
|
-
def noise_weighted_inner_product(
|
136
|
+
def noise_weighted_inner_product(
|
137
|
+
self, other: "FrequencySeries", psd: "FrequencySeries"
|
138
|
+
) -> float:
|
131
139
|
"""
|
132
140
|
Compute the noise-weighted inner product of two FrequencySeries.
|
133
141
|
|
@@ -144,9 +152,11 @@ class FrequencySeries:
|
|
144
152
|
The noise-weighted inner product of the two FrequencySeries.
|
145
153
|
"""
|
146
154
|
integrand = xp.real(xp.conj(self.data) * other.data / psd.data)
|
147
|
-
return (4 * self.dt/self.ND) * xp.nansum(integrand)
|
155
|
+
return (4 * self.dt / self.ND) * xp.nansum(integrand)
|
148
156
|
|
149
|
-
def matched_filter_snr(
|
157
|
+
def matched_filter_snr(
|
158
|
+
self, other: "FrequencySeries", psd: "FrequencySeries"
|
159
|
+
) -> float:
|
150
160
|
"""
|
151
161
|
Compute the signal-to-noise ratio (SNR) of a matched filter.
|
152
162
|
|
@@ -199,15 +209,15 @@ class FrequencySeries:
|
|
199
209
|
|
200
210
|
# Create and return a TimeSeries object
|
201
211
|
from .timeseries import TimeSeries
|
202
|
-
return TimeSeries(time_data, time)
|
203
212
|
|
213
|
+
return TimeSeries(time_data, time)
|
204
214
|
|
205
215
|
def to_wavelet(
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
216
|
+
self,
|
217
|
+
Nf: Union[int, None] = None,
|
218
|
+
Nt: Union[int, None] = None,
|
219
|
+
nx: Optional[float] = 4.0,
|
220
|
+
) -> "Wavelet":
|
211
221
|
"""
|
212
222
|
Convert the frequency series to a wavelet using inverse Fourier transform.
|
213
223
|
|
@@ -216,9 +226,9 @@ class FrequencySeries:
|
|
216
226
|
Wavelet
|
217
227
|
The corresponding wavelet.
|
218
228
|
"""
|
219
|
-
from ..forward import from_freq_to_wavelet
|
220
|
-
return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
|
229
|
+
from ..transforms.forward import from_freq_to_wavelet
|
221
230
|
|
231
|
+
return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
|
222
232
|
|
223
233
|
def __eq__(self, other):
|
224
234
|
"""Check if two FrequencySeries objects are equal."""
|
@@ -228,10 +238,8 @@ class FrequencySeries:
|
|
228
238
|
|
229
239
|
def __copy__(self):
|
230
240
|
return FrequencySeries(
|
231
|
-
xp.copy(self.data),
|
232
|
-
xp.copy(self.freq),
|
233
|
-
t0=self.t0
|
241
|
+
xp.copy(self.data), xp.copy(self.freq), t0=self.t0
|
234
242
|
)
|
235
243
|
|
236
244
|
def copy(self):
|
237
|
-
return self.__copy__()
|
245
|
+
return self.__copy__()
|
@@ -1,12 +1,11 @@
|
|
1
1
|
import warnings
|
2
|
-
from typing import
|
3
|
-
from scipy.signal import savgol_filter
|
4
|
-
from scipy.interpolate import interp1d
|
2
|
+
from typing import Optional, Tuple, Union
|
5
3
|
|
6
4
|
import matplotlib.pyplot as plt
|
7
5
|
import numpy as np
|
8
6
|
from matplotlib.colors import LogNorm, TwoSlopeNorm
|
9
|
-
from scipy.
|
7
|
+
from scipy.interpolate import interp1d
|
8
|
+
from scipy.signal import savgol_filter, spectrogram
|
10
9
|
|
11
10
|
MIN_S = 60
|
12
11
|
HOUR_S = 60 * MIN_S
|
@@ -53,8 +52,13 @@ def __get_smoothed_y(x, z, y_grid):
|
|
53
52
|
# Interpolate to fill NaNs in y before smoothing
|
54
53
|
nan_mask = ~np.isnan(y)
|
55
54
|
if np.isnan(y).any():
|
56
|
-
interpolator = interp1d(
|
57
|
-
|
55
|
+
interpolator = interp1d(
|
56
|
+
x[nan_mask],
|
57
|
+
y[nan_mask],
|
58
|
+
kind="cubic",
|
59
|
+
bounds_error=False,
|
60
|
+
fill_value="extrapolate",
|
61
|
+
)
|
58
62
|
y = interpolator(x) # Fill NaNs with interpolated values
|
59
63
|
|
60
64
|
# Smooth the curve
|
@@ -64,8 +68,6 @@ def __get_smoothed_y(x, z, y_grid):
|
|
64
68
|
return y
|
65
69
|
|
66
70
|
|
67
|
-
|
68
|
-
|
69
71
|
def plot_wavelet_grid(
|
70
72
|
wavelet_data: np.ndarray,
|
71
73
|
time_grid: np.ndarray,
|
@@ -80,8 +82,8 @@ def plot_wavelet_grid(
|
|
80
82
|
norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
|
81
83
|
cbar_label: Optional[str] = None,
|
82
84
|
nan_color: Optional[str] = "black",
|
83
|
-
detailed_axes:bool = False,
|
84
|
-
show_gridinfo:bool = True,
|
85
|
+
detailed_axes: bool = False,
|
86
|
+
show_gridinfo: bool = True,
|
85
87
|
trend_color: Optional[str] = None,
|
86
88
|
whiten_by: Optional[np.ndarray] = None,
|
87
89
|
**kwargs,
|
@@ -153,7 +155,9 @@ def plot_wavelet_grid(
|
|
153
155
|
|
154
156
|
# Validate the dimensions
|
155
157
|
if (Nf, Nt) != (len(freq_grid), len(time_grid)):
|
156
|
-
raise ValueError(
|
158
|
+
raise ValueError(
|
159
|
+
f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}."
|
160
|
+
)
|
157
161
|
|
158
162
|
# Prepare the data for plotting
|
159
163
|
z = wavelet_data.copy()
|
@@ -162,14 +166,15 @@ def plot_wavelet_grid(
|
|
162
166
|
if absolute:
|
163
167
|
z = np.abs(z)
|
164
168
|
|
165
|
-
|
166
169
|
# Determine normalization and colormap
|
167
170
|
if norm is None:
|
168
171
|
try:
|
169
172
|
if np.all(np.isnan(z)):
|
170
173
|
raise ValueError("All wavelet data is NaN.")
|
171
174
|
if zscale == "log":
|
172
|
-
norm = LogNorm(
|
175
|
+
norm = LogNorm(
|
176
|
+
vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z < np.inf])
|
177
|
+
)
|
173
178
|
elif not absolute:
|
174
179
|
vmin, vmax = np.nanmin(z), np.nanmax(z)
|
175
180
|
vcenter = 0.0
|
@@ -177,7 +182,9 @@ def plot_wavelet_grid(
|
|
177
182
|
else:
|
178
183
|
norm = None # Default linear scaling
|
179
184
|
except Exception as e:
|
180
|
-
warnings.warn(
|
185
|
+
warnings.warn(
|
186
|
+
f"Error in determining normalization: {e}. Using default linear scaling."
|
187
|
+
)
|
181
188
|
norm = None
|
182
189
|
|
183
190
|
if cmap is None:
|
@@ -195,7 +202,7 @@ def plot_wavelet_grid(
|
|
195
202
|
im = ax.imshow(
|
196
203
|
z,
|
197
204
|
aspect="auto",
|
198
|
-
extent=[time_grid[0],time_grid[-1], freq_grid[0], freq_grid[-1]],
|
205
|
+
extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
|
199
206
|
origin="lower",
|
200
207
|
cmap=cmap,
|
201
208
|
norm=norm,
|
@@ -203,13 +210,25 @@ def plot_wavelet_grid(
|
|
203
210
|
**kwargs,
|
204
211
|
)
|
205
212
|
if trend_color is not None:
|
206
|
-
plot_wavelet_trend(
|
213
|
+
plot_wavelet_trend(
|
214
|
+
wavelet_data,
|
215
|
+
time_grid,
|
216
|
+
freq_grid,
|
217
|
+
ax,
|
218
|
+
color=trend_color,
|
219
|
+
freq_range=freq_range,
|
220
|
+
freq_scale=freq_scale,
|
221
|
+
)
|
207
222
|
|
208
223
|
# Add colorbar if requested
|
209
224
|
if show_colorbar:
|
210
225
|
cbar = fig.colorbar(im, ax=ax)
|
211
226
|
if cbar_label is None:
|
212
|
-
cbar_label =
|
227
|
+
cbar_label = (
|
228
|
+
"Absolute Wavelet Amplitude"
|
229
|
+
if absolute
|
230
|
+
else "Wavelet Amplitude"
|
231
|
+
)
|
213
232
|
cbar.set_label(cbar_label)
|
214
233
|
|
215
234
|
# Configure axes scales
|
@@ -239,14 +258,12 @@ def plot_wavelet_grid(
|
|
239
258
|
bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
|
240
259
|
)
|
241
260
|
|
242
|
-
|
243
261
|
# Adjust layout
|
244
262
|
fig.tight_layout()
|
245
263
|
|
246
264
|
return fig, ax
|
247
265
|
|
248
266
|
|
249
|
-
|
250
267
|
def plot_freqseries(
|
251
268
|
data: np.ndarray,
|
252
269
|
freq: np.ndarray,
|
@@ -277,9 +294,10 @@ def plot_periodogram(
|
|
277
294
|
flow = np.min(np.abs(freq))
|
278
295
|
ax.set_xlabel("Frequency [Hz]")
|
279
296
|
ax.set_ylabel("Periodigram")
|
280
|
-
ax.set_xlim(left=flow, right=nyquist_frequency/2)
|
297
|
+
ax.set_xlim(left=flow, right=nyquist_frequency / 2)
|
281
298
|
return ax.figure, ax
|
282
299
|
|
300
|
+
|
283
301
|
def plot_timeseries(
|
284
302
|
data: np.ndarray, time: np.ndarray, ax=None, **kwargs
|
285
303
|
) -> Tuple[plt.Figure, plt.Axes]:
|
@@ -314,7 +332,6 @@ def plot_spectrogram(
|
|
314
332
|
|
315
333
|
_fmt_time_axis(t, ax)
|
316
334
|
|
317
|
-
|
318
335
|
ax.set_ylabel("Frequency [Hz]")
|
319
336
|
ax.set_ylim(top=fs / 2.0)
|
320
337
|
cbar = plt.colorbar(cm, ax=ax)
|
@@ -322,20 +339,24 @@ def plot_spectrogram(
|
|
322
339
|
return ax.figure, ax
|
323
340
|
|
324
341
|
|
325
|
-
|
326
342
|
def _fmt_time_axis(t, axes, t0=None, tmax=None):
|
327
343
|
if t[-1] > DAY_S: # If time goes beyond a day
|
328
|
-
axes.xaxis.set_major_formatter(
|
344
|
+
axes.xaxis.set_major_formatter(
|
345
|
+
plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}")
|
346
|
+
)
|
329
347
|
axes.set_xlabel("Time [days]")
|
330
348
|
elif t[-1] > HOUR_S: # If time goes beyond an hour
|
331
|
-
axes.xaxis.set_major_formatter(
|
349
|
+
axes.xaxis.set_major_formatter(
|
350
|
+
plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}")
|
351
|
+
)
|
332
352
|
axes.set_xlabel("Time [hr]")
|
333
353
|
elif t[-1] > MIN_S: # If time goes beyond a minute
|
334
|
-
axes.xaxis.set_major_formatter(
|
354
|
+
axes.xaxis.set_major_formatter(
|
355
|
+
plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}")
|
356
|
+
)
|
335
357
|
axes.set_xlabel("Time [min]")
|
336
358
|
else:
|
337
359
|
axes.set_xlabel("Time [s]")
|
338
360
|
t0 = t[0] if t0 is None else t0
|
339
361
|
tmax = t[-1] if tmax is None else tmax
|
340
362
|
axes.set_xlim(t0, tmax)
|
341
|
-
|
@@ -1,11 +1,20 @@
|
|
1
|
+
from typing import Optional, Tuple, Union
|
2
|
+
|
1
3
|
import matplotlib.pyplot as plt
|
2
|
-
from typing import Tuple, Optional, Union
|
3
|
-
from scipy.signal.windows import tukey
|
4
4
|
from scipy.signal import butter, sosfiltfilt
|
5
|
+
from scipy.signal.windows import tukey
|
5
6
|
|
6
|
-
from
|
7
|
-
from .common import
|
8
|
-
|
7
|
+
from ..logger import logger
|
8
|
+
from .common import (
|
9
|
+
fmt_pow2,
|
10
|
+
fmt_time,
|
11
|
+
fmt_timerange,
|
12
|
+
is_documented_by,
|
13
|
+
rfft,
|
14
|
+
rfftfreq,
|
15
|
+
xp,
|
16
|
+
)
|
17
|
+
from .plotting import plot_spectrogram, plot_timeseries
|
9
18
|
|
10
19
|
__all__ = ["TimeSeries"]
|
11
20
|
|
@@ -50,7 +59,7 @@ class TimeSeries:
|
|
50
59
|
|
51
60
|
@is_documented_by(plot_spectrogram)
|
52
61
|
def plot_spectrogram(
|
53
|
-
|
62
|
+
self, ax=None, spec_kwargs={}, plot_kwargs={}
|
54
63
|
) -> Tuple[plt.Figure, plt.Axes]:
|
55
64
|
return plot_spectrogram(
|
56
65
|
self.data,
|
@@ -122,9 +131,11 @@ class TimeSeries:
|
|
122
131
|
trange = fmt_timerange((self.t0, self.tend))
|
123
132
|
T = " ".join(fmt_time(self.duration, units=True))
|
124
133
|
n = fmt_pow2(len(self))
|
125
|
-
return
|
134
|
+
return (
|
135
|
+
f"TimeSeries(n={n}, trange={trange}, T={T}, fs={self.fs:.2f} Hz)"
|
136
|
+
)
|
126
137
|
|
127
|
-
def to_frequencyseries(self) ->
|
138
|
+
def to_frequencyseries(self) -> "FrequencySeries":
|
128
139
|
"""
|
129
140
|
Convert the time series to a frequency series using the one-sided FFT.
|
130
141
|
|
@@ -137,14 +148,15 @@ class TimeSeries:
|
|
137
148
|
data = rfft(self.data)
|
138
149
|
|
139
150
|
from .frequencyseries import FrequencySeries # Avoid circular import
|
151
|
+
|
140
152
|
return FrequencySeries(data, freq, t0=self.t0)
|
141
153
|
|
142
154
|
def to_wavelet(
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
) ->
|
155
|
+
self,
|
156
|
+
Nf: Union[int, None] = None,
|
157
|
+
Nt: Union[int, None] = None,
|
158
|
+
nx: Optional[float] = 4.0,
|
159
|
+
) -> "Wavelet":
|
148
160
|
"""
|
149
161
|
Convert the time series to a wavelet representation.
|
150
162
|
|
@@ -166,32 +178,37 @@ class TimeSeries:
|
|
166
178
|
hf = self.to_frequencyseries()
|
167
179
|
return hf.to_wavelet(Nf, Nt, nx=nx)
|
168
180
|
|
169
|
-
|
170
|
-
def __add__(self, other: 'TimeSeries') -> 'TimeSeries':
|
181
|
+
def __add__(self, other: "TimeSeries") -> "TimeSeries":
|
171
182
|
"""Add two TimeSeries objects together."""
|
172
183
|
if self.shape != other.shape:
|
173
|
-
raise ValueError(
|
184
|
+
raise ValueError(
|
185
|
+
"TimeSeries objects must have the same shape to add them together"
|
186
|
+
)
|
174
187
|
return TimeSeries(self.data + other.data, self.time)
|
175
188
|
|
176
|
-
def __sub__(self, other:
|
189
|
+
def __sub__(self, other: "TimeSeries") -> "TimeSeries":
|
177
190
|
"""Subtract one TimeSeries object from another."""
|
178
191
|
if self.shape != other.shape:
|
179
|
-
raise ValueError(
|
192
|
+
raise ValueError(
|
193
|
+
"TimeSeries objects must have the same shape to subtract them"
|
194
|
+
)
|
180
195
|
return TimeSeries(self.data - other.data, self.time)
|
181
196
|
|
182
|
-
def __eq__(self, other:
|
197
|
+
def __eq__(self, other: "TimeSeries") -> bool:
|
183
198
|
"""Check if two TimeSeries objects are equal."""
|
184
199
|
shape_same = self.shape == other.shape
|
185
200
|
range_same = self.t0 == other.t0 and self.tend == other.tend
|
186
|
-
time_same =
|
201
|
+
time_same = xp.allclose(self.time, other.time)
|
187
202
|
data_same = xp.allclose(self.data, other.data)
|
188
203
|
return shape_same and range_same and data_same and time_same
|
189
204
|
|
190
|
-
def __mul__(self, other: float) ->
|
205
|
+
def __mul__(self, other: float) -> "TimeSeries":
|
191
206
|
"""Multiply a TimeSeries object by a scalar."""
|
192
207
|
return TimeSeries(self.data * other, self.time)
|
193
208
|
|
194
|
-
def zero_pad_to_power_of_2(
|
209
|
+
def zero_pad_to_power_of_2(
|
210
|
+
self, tukey_window_alpha: float = 0.0
|
211
|
+
) -> "TimeSeries":
|
195
212
|
"""Zero pad the time series to make the length a power of two (useful to speed up FFTs, O(NlogN) versus O(N^2)).
|
196
213
|
|
197
214
|
Parameters
|
@@ -207,7 +224,7 @@ class TimeSeries:
|
|
207
224
|
"""
|
208
225
|
N, dt, t0 = self.ND, self.dt, self.t0
|
209
226
|
pow_2 = xp.ceil(xp.log2(N))
|
210
|
-
n_pad = int((2
|
227
|
+
n_pad = int((2**pow_2) - N)
|
211
228
|
new_N = N + n_pad
|
212
229
|
if n_pad > 0:
|
213
230
|
logger.warning(
|
@@ -220,7 +237,12 @@ class TimeSeries:
|
|
220
237
|
time = xp.arange(0, len(data) * dt, dt) + t0
|
221
238
|
return TimeSeries(data, time)
|
222
239
|
|
223
|
-
def highpass_filter(
|
240
|
+
def highpass_filter(
|
241
|
+
self,
|
242
|
+
fmin: float,
|
243
|
+
tukey_window_alpha: float = 0.0,
|
244
|
+
bandpass_order: int = 4,
|
245
|
+
) -> "TimeSeries":
|
224
246
|
"""
|
225
247
|
Filter the time series with a highpass bandpass filter.
|
226
248
|
|
@@ -244,24 +266,25 @@ class TimeSeries:
|
|
244
266
|
"""
|
245
267
|
|
246
268
|
if fmin <= 0 or fmin > self.nyquist_frequency:
|
247
|
-
raise ValueError(
|
269
|
+
raise ValueError(
|
270
|
+
f"Invalid fmin value: {fmin}. Must be in the range [0, {self.nyquist_frequency}]"
|
271
|
+
)
|
248
272
|
|
249
|
-
sos = butter(
|
273
|
+
sos = butter(
|
274
|
+
bandpass_order, Wn=fmin, btype="highpass", output="sos", fs=self.fs
|
275
|
+
)
|
250
276
|
window = tukey(self.ND, alpha=tukey_window_alpha)
|
251
277
|
data = self.data.copy()
|
252
278
|
data = sosfiltfilt(sos, data * window)
|
253
279
|
return TimeSeries(data, self.time)
|
254
280
|
|
255
281
|
def __copy__(self):
|
256
|
-
return TimeSeries(
|
257
|
-
self.data.copy(),
|
258
|
-
self.time.copy()
|
259
|
-
)
|
282
|
+
return TimeSeries(self.data.copy(), self.time.copy())
|
260
283
|
|
261
284
|
def copy(self):
|
262
285
|
return self.__copy__()
|
263
286
|
|
264
|
-
def __getitem__(self, key)->"TimeSeries":
|
287
|
+
def __getitem__(self, key) -> "TimeSeries":
|
265
288
|
if isinstance(key, slice):
|
266
289
|
# Handle slicing
|
267
290
|
return self.__handle_slice(key)
|
@@ -269,12 +292,9 @@ class TimeSeries:
|
|
269
292
|
# Handle regular indexing
|
270
293
|
return TimeSeries(self.data[key], self.time[key])
|
271
294
|
|
272
|
-
def __handle_slice(self, slice_obj)->"TimeSeries":
|
273
|
-
return TimeSeries(
|
274
|
-
self.data[slice_obj],
|
275
|
-
self.time[slice_obj]
|
276
|
-
)
|
295
|
+
def __handle_slice(self, slice_obj) -> "TimeSeries":
|
296
|
+
return TimeSeries(self.data[slice_obj], self.time[slice_obj])
|
277
297
|
|
278
298
|
@classmethod
|
279
|
-
def _EMPTY(cls, ND:int, dt:float)->"TimeSeries":
|
280
|
-
return cls(xp.zeros(ND), xp.arange(0, ND*dt, dt))
|
299
|
+
def _EMPTY(cls, ND: int, dt: float) -> "TimeSeries":
|
300
|
+
return cls(xp.zeros(ND), xp.arange(0, ND * dt, dt))
|
@@ -1,10 +1,11 @@
|
|
1
|
-
import
|
2
|
-
from typing import Optional, Tuple
|
1
|
+
from typing import List, Optional, Tuple
|
3
2
|
|
3
|
+
import matplotlib.pyplot as plt
|
4
4
|
import numpy as np
|
5
5
|
|
6
|
-
from .common import is_documented_by, xp
|
6
|
+
from .common import fmt_timerange, is_documented_by, xp
|
7
7
|
from .plotting import plot_wavelet_grid, plot_wavelet_trend
|
8
|
+
from .wavelet_bins import compute_bins
|
8
9
|
|
9
10
|
|
10
11
|
class Wavelet:
|
@@ -22,10 +23,10 @@ class Wavelet:
|
|
22
23
|
"""
|
23
24
|
|
24
25
|
def __init__(
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
26
|
+
self,
|
27
|
+
data: xp.ndarray,
|
28
|
+
time: xp.ndarray,
|
29
|
+
freq: xp.ndarray,
|
29
30
|
):
|
30
31
|
"""
|
31
32
|
Initialize the Wavelet object with data, time, and frequency arrays.
|
@@ -53,17 +54,60 @@ class Wavelet:
|
|
53
54
|
self.time = time
|
54
55
|
self.freq = freq
|
55
56
|
|
57
|
+
@classmethod
|
58
|
+
def zeros_from_grid(cls, time: xp.ndarray, freq: xp.ndarray) -> "Wavelet":
|
59
|
+
"""
|
60
|
+
Create a Wavelet object filled with zeros.
|
61
|
+
|
62
|
+
Parameters
|
63
|
+
----------
|
64
|
+
time: xp.ndarray
|
65
|
+
freq: xp.ndarray
|
66
|
+
|
67
|
+
Returns
|
68
|
+
-------
|
69
|
+
Wavelet
|
70
|
+
A Wavelet object with zero-filled data array.
|
71
|
+
"""
|
72
|
+
Nf, Nt = len(freq), len(time)
|
73
|
+
return cls(data=xp.zeros((Nf, Nt)), time=time, freq=freq)
|
74
|
+
|
75
|
+
@classmethod
|
76
|
+
def zeros(cls, Nf: int, Nt: int, T: float) -> "Wavelet":
|
77
|
+
"""
|
78
|
+
Create a Wavelet object filled with zeros.
|
79
|
+
|
80
|
+
Parameters
|
81
|
+
----------
|
82
|
+
Nf : int
|
83
|
+
Number of frequency bins.
|
84
|
+
Nt : int
|
85
|
+
Number of time bins.
|
86
|
+
|
87
|
+
Returns
|
88
|
+
-------
|
89
|
+
Wavelet
|
90
|
+
A Wavelet object with zero-filled data array.
|
91
|
+
"""
|
92
|
+
return cls.zeros_from_grid(*compute_bins(Nf, Nt, T))
|
93
|
+
|
56
94
|
@is_documented_by(plot_wavelet_grid)
|
57
95
|
def plot(self, ax=None, *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
|
58
96
|
kwargs["time_grid"] = kwargs.get("time_grid", self.time)
|
59
97
|
kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
|
60
|
-
return plot_wavelet_grid(
|
98
|
+
return plot_wavelet_grid(
|
99
|
+
wavelet_data=self.data, ax=ax, *args, **kwargs
|
100
|
+
)
|
61
101
|
|
62
102
|
@is_documented_by(plot_wavelet_trend)
|
63
|
-
def plot_trend(
|
103
|
+
def plot_trend(
|
104
|
+
self, ax=None, *args, **kwargs
|
105
|
+
) -> Tuple[plt.Figure, plt.Axes]:
|
64
106
|
kwargs["time_grid"] = kwargs.get("time_grid", self.time)
|
65
107
|
kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
|
66
|
-
return plot_wavelet_trend(
|
108
|
+
return plot_wavelet_trend(
|
109
|
+
wavelet_data=self.data, ax=ax, *args, **kwargs
|
110
|
+
)
|
67
111
|
|
68
112
|
@property
|
69
113
|
def Nt(self) -> int:
|
@@ -242,7 +286,8 @@ class Wavelet:
|
|
242
286
|
TimeSeries
|
243
287
|
A `TimeSeries` object representing the time-domain signal.
|
244
288
|
"""
|
245
|
-
from ..inverse import from_wavelet_to_time
|
289
|
+
from ..transforms.inverse import from_wavelet_to_time
|
290
|
+
|
246
291
|
return from_wavelet_to_time(self, dt=self.delta_t, nx=nx, mult=mult)
|
247
292
|
|
248
293
|
def to_frequencyseries(self, nx: float = 4.0) -> "FrequencySeries":
|
@@ -254,7 +299,8 @@ class Wavelet:
|
|
254
299
|
FrequencySeries
|
255
300
|
A `FrequencySeries` object representing the frequency-domain signal.
|
256
301
|
"""
|
257
|
-
from ..inverse import from_wavelet_to_freq
|
302
|
+
from ..transforms.inverse import from_wavelet_to_freq
|
303
|
+
|
258
304
|
return from_wavelet_to_freq(self, dt=self.delta_t, nx=nx)
|
259
305
|
|
260
306
|
def __repr__(self) -> str:
|
@@ -277,40 +323,57 @@ class Wavelet:
|
|
277
323
|
def __add__(self, other):
|
278
324
|
"""Element-wise addition of two Wavelet objects."""
|
279
325
|
if isinstance(other, Wavelet):
|
280
|
-
return Wavelet(
|
326
|
+
return Wavelet(
|
327
|
+
data=self.data + other.data, time=self.time, freq=self.freq
|
328
|
+
)
|
281
329
|
elif isinstance(other, float):
|
282
|
-
return Wavelet(
|
330
|
+
return Wavelet(
|
331
|
+
data=self.data + other, time=self.time, freq=self.freq
|
332
|
+
)
|
283
333
|
|
284
334
|
def __sub__(self, other):
|
285
335
|
"""Element-wise subtraction of two Wavelet objects."""
|
286
336
|
if isinstance(other, Wavelet):
|
287
|
-
return Wavelet(
|
337
|
+
return Wavelet(
|
338
|
+
data=self.data - other.data, time=self.time, freq=self.freq
|
339
|
+
)
|
288
340
|
elif isinstance(other, float):
|
289
|
-
return Wavelet(
|
341
|
+
return Wavelet(
|
342
|
+
data=self.data - other, time=self.time, freq=self.freq
|
343
|
+
)
|
290
344
|
|
291
345
|
def __mul__(self, other):
|
292
346
|
"""Element-wise multiplication of two Wavelet objects."""
|
293
347
|
if isinstance(other, Wavelet):
|
294
|
-
return Wavelet(
|
348
|
+
return Wavelet(
|
349
|
+
data=self.data * other.data, time=self.time, freq=self.freq
|
350
|
+
)
|
295
351
|
elif isinstance(other, float):
|
296
|
-
return Wavelet(
|
352
|
+
return Wavelet(
|
353
|
+
data=self.data * other, time=self.time, freq=self.freq
|
354
|
+
)
|
297
355
|
|
298
356
|
def __truediv__(self, other):
|
299
357
|
"""Element-wise division of two Wavelet objects."""
|
300
358
|
if isinstance(other, Wavelet):
|
301
|
-
return Wavelet(
|
359
|
+
return Wavelet(
|
360
|
+
data=self.data / other.data, time=self.time, freq=self.freq
|
361
|
+
)
|
302
362
|
elif isinstance(other, float):
|
303
|
-
return Wavelet(
|
363
|
+
return Wavelet(
|
364
|
+
data=self.data / other, time=self.time, freq=self.freq
|
365
|
+
)
|
304
366
|
|
305
|
-
def __eq__(self, other:"Wavelet") -> bool:
|
367
|
+
def __eq__(self, other: "Wavelet") -> bool:
|
306
368
|
"""Element-wise comparison of two Wavelet objects."""
|
307
369
|
data_all_same = xp.isclose(xp.nansum(self.data - other.data), 0)
|
308
370
|
time_same = (self.time == other.time).all()
|
309
371
|
freq_same = (self.freq == other.freq).all()
|
310
372
|
return data_all_same and time_same and freq_same
|
311
373
|
|
312
|
-
|
313
|
-
|
374
|
+
def noise_weighted_inner_product(
|
375
|
+
self, other: "Wavelet", psd: "Wavelet"
|
376
|
+
) -> float:
|
314
377
|
"""
|
315
378
|
Compute the noise-weighted inner product of two wavelet grids given a PSD.
|
316
379
|
|
@@ -326,11 +389,11 @@ class Wavelet:
|
|
326
389
|
float
|
327
390
|
The noise-weighted inner product.
|
328
391
|
"""
|
329
|
-
from
|
330
|
-
return noise_weighted_inner_product(self, other, psd)
|
392
|
+
from ..utils import noise_weighted_inner_product
|
331
393
|
|
394
|
+
return noise_weighted_inner_product(self, other, psd)
|
332
395
|
|
333
|
-
def matched_filter_snr(self, template:"Wavelet", psd:"Wavelet") -> float:
|
396
|
+
def matched_filter_snr(self, template: "Wavelet", psd: "Wavelet") -> float:
|
334
397
|
"""
|
335
398
|
Compute the matched filter SNR of the wavelet grid given a template.
|
336
399
|
|
@@ -347,7 +410,7 @@ class Wavelet:
|
|
347
410
|
mf = self.noise_weighted_inner_product(template, psd)
|
348
411
|
return mf / self.optimal_snr(psd)
|
349
412
|
|
350
|
-
def optimal_snr(self, psd:"Wavelet") -> float:
|
413
|
+
def optimal_snr(self, psd: "Wavelet") -> float:
|
351
414
|
"""
|
352
415
|
Compute the optimal SNR of the wavelet grid given a PSD.
|
353
416
|
|
@@ -365,10 +428,27 @@ class Wavelet:
|
|
365
428
|
|
366
429
|
def __copy__(self):
|
367
430
|
return Wavelet(
|
368
|
-
data=self.data.copy(),
|
369
|
-
time=self.time.copy(),
|
370
|
-
freq=self.freq.copy()
|
431
|
+
data=self.data.copy(), time=self.time.copy(), freq=self.freq.copy()
|
371
432
|
)
|
372
433
|
|
373
434
|
def copy(self):
|
374
|
-
return self.__copy__()
|
435
|
+
return self.__copy__()
|
436
|
+
|
437
|
+
|
438
|
+
class WaveletMask(Wavelet):
|
439
|
+
@property
|
440
|
+
def mask(self):
|
441
|
+
return self.data
|
442
|
+
|
443
|
+
def __repr__(self):
|
444
|
+
return f"WaveletMask({self.mask.shape}, {fmt_timerange(self.time)}, {self.freq})"
|
445
|
+
|
446
|
+
@classmethod
|
447
|
+
def from_frange(
|
448
|
+
cls, time_grid: xp.ndarray, freq_grid: xp.ndarray, frange: List[float]
|
449
|
+
):
|
450
|
+
self = cls.zeros_from_grid(time_grid, freq_grid)
|
451
|
+
self.mask[
|
452
|
+
(freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
|
453
|
+
] = True
|
454
|
+
return self
|
@@ -2,7 +2,8 @@ from typing import Tuple, Union
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
|
5
|
-
from
|
5
|
+
from .frequencyseries import FrequencySeries
|
6
|
+
from .timeseries import TimeSeries
|
6
7
|
|
7
8
|
|
8
9
|
def _preprocess_bins(
|
@@ -10,7 +11,6 @@ def _preprocess_bins(
|
|
10
11
|
) -> Tuple[int, int]:
|
11
12
|
"""preprocess the bins"""
|
12
13
|
|
13
|
-
|
14
14
|
if isinstance(data, TimeSeries):
|
15
15
|
N = len(data)
|
16
16
|
elif isinstance(data, FrequencySeries):
|
@@ -34,7 +34,6 @@ def _get_bins(
|
|
34
34
|
Nf: Union[int, None] = None,
|
35
35
|
Nt: Union[int, None] = None,
|
36
36
|
) -> Tuple[np.ndarray, np.ndarray]:
|
37
|
-
|
38
37
|
T = data.duration
|
39
38
|
t_bins, f_bins = compute_bins(Nf, Nt, T)
|
40
39
|
|
@@ -42,12 +41,12 @@ def _get_bins(
|
|
42
41
|
# fs = N / T
|
43
42
|
# assert delta_f == fmax / Nf, f"delta_f={delta_f} != fmax/Nf={fmax/Nf}"
|
44
43
|
|
45
|
-
t_bins+= data.t0
|
44
|
+
t_bins += data.t0
|
46
45
|
|
47
46
|
return t_bins, f_bins
|
48
47
|
|
49
48
|
|
50
|
-
def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
|
49
|
+
def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[np.ndarray, np.ndarray]:
|
51
50
|
"""Get the bins for the wavelet transform
|
52
51
|
Eq 4-6 in Wavelets paper
|
53
52
|
"""
|
@@ -55,4 +54,4 @@ def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
|
|
55
54
|
delta_F = 1 / (2 * delta_T)
|
56
55
|
t_bins = np.arange(0, Nt) * delta_T
|
57
56
|
f_bins = np.arange(0, Nf) * delta_F
|
58
|
-
return
|
57
|
+
return t_bins, f_bins
|
pywavelet/utils.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Union
|
|
3
3
|
import numpy as np
|
4
4
|
from scipy.interpolate import interp1d
|
5
5
|
|
6
|
-
from .
|
6
|
+
from .types import FrequencySeries, TimeSeries, Wavelet, WaveletMask
|
7
7
|
|
8
8
|
DATA_TYPE = Union[TimeSeries, FrequencySeries, Wavelet]
|
9
9
|
|
@@ -34,10 +34,13 @@ def evolutionary_psd_from_stationary_psd(
|
|
34
34
|
return Wavelet(psd_grid.T, time=t_grid, freq=f_grid)
|
35
35
|
|
36
36
|
|
37
|
-
def noise_weighted_inner_product(
|
37
|
+
def noise_weighted_inner_product(
|
38
|
+
d: Wavelet, h: Wavelet, PSD: Wavelet
|
39
|
+
) -> float:
|
38
40
|
return np.nansum((d.data * h.data) / PSD.data)
|
39
41
|
|
40
|
-
|
42
|
+
|
43
|
+
def compute_snr(d: Wavelet, h: Wavelet, PSD: Wavelet) -> float:
|
41
44
|
"""Compute the SNR of a model h[ti,fi] given freqseries d[ti,fi] and PSD[ti,fi].
|
42
45
|
|
43
46
|
SNR(h) = Sum_{ti,fi} [ h_hat[ti,fi] d[ti,fi] / PSD[ti,fi]
|
@@ -60,5 +63,14 @@ def compute_snr(d:Wavelet, h: Wavelet, PSD: Wavelet) -> float:
|
|
60
63
|
return np.sqrt(noise_weighted_inner_product(d, h, PSD))
|
61
64
|
|
62
65
|
|
63
|
-
def compute_likelihood(
|
64
|
-
|
66
|
+
def compute_likelihood(
|
67
|
+
data: Wavelet, template: Wavelet, psd: Wavelet, mask: WaveletMask = None
|
68
|
+
) -> float:
|
69
|
+
d = data.data
|
70
|
+
h = template.data
|
71
|
+
p = psd.data
|
72
|
+
if mask is not None:
|
73
|
+
m = mask.mask
|
74
|
+
d, h, p = d * m, h * m, p * m
|
75
|
+
|
76
|
+
return -0.5 * np.nansum((d - h) ** 2 / p)
|
@@ -0,0 +1,25 @@
|
|
1
|
+
pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
|
2
|
+
pywavelet/_version.py,sha256=PKIMyjdUACH4-ONvtunQCnYE2UhlMfp9su83e3HXl5E,411
|
3
|
+
pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
|
4
|
+
pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
|
5
|
+
pywavelet/transforms/__init__.py,sha256=1Ibsup9UwMajeZ9NCQ4BN15qZTeJ_EHkgGu8XNFdA18,255
|
6
|
+
pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
|
7
|
+
pywavelet/transforms/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
8
|
+
pywavelet/transforms/forward/from_freq.py,sha256=wCiyLpzJE3rGxYjQBdXlwkxPIRYhQWjKq0C_8zYlmDk,2697
|
9
|
+
pywavelet/transforms/forward/from_time.py,sha256=-Y6VEKwDCYBAHAjLdO46vT-6alpM5fXTgTZ_xkYxqA8,2381
|
10
|
+
pywavelet/transforms/forward/main.py,sha256=Gfy0sp-woy_3ihKMzuk2WuZ7dRk-Mm6sp5dVpYrSvj4,4005
|
11
|
+
pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
|
12
|
+
pywavelet/transforms/inverse/main.py,sha256=Q7wUaRjB1sgqdB7dniWQGbPTWYQNnIsIrYtjsaHJEdE,3012
|
13
|
+
pywavelet/transforms/inverse/to_freq.py,sha256=so_TDbwdS1N8sd1QcpeAEkI10XFDtoFJGohtD4YulZM,2809
|
14
|
+
pywavelet/transforms/inverse/to_time.py,sha256=w5vmImdsb_4YeInZtXh0llsThLTxS0tmYDlNGJ-IUew,5080
|
15
|
+
pywavelet/types/__init__.py,sha256=5YptzQvYBnRfC8N5lpOBf9I1lzpJ0pw0QMnvIcwP3YI,122
|
16
|
+
pywavelet/types/common.py,sha256=OSAW6GqLTgqJ-RYEv__XbzsfFd8AFo5w-ctXQ4XAFZo,1317
|
17
|
+
pywavelet/types/frequencyseries.py,sha256=UqcE6UQfw5HZm4na2q9k-X-mfqO-BCiTAvGjaYpSrwc,7518
|
18
|
+
pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
|
19
|
+
pywavelet/types/timeseries.py,sha256=6DPO0xLi4Dq2srhJLmavFMf4fYIC3wwdbyMU7lMdjTo,9446
|
20
|
+
pywavelet/types/wavelet.py,sha256=ptTEnq6nRZiW2x6g_NV_FuoeVNOsbNOu6caSxYDZNgk,12583
|
21
|
+
pywavelet/types/wavelet_bins.py,sha256=SC9nhyigWvOfs2TbH8-Ck_iS1Mrz6sG-8EaIFYLIHuk,1453
|
22
|
+
pywavelet-0.1.1.dist-info/METADATA,sha256=OgKTqqZQKfhCCcsYpcXnrut-2EAH42c3tCpCGc0PIDg,1307
|
23
|
+
pywavelet-0.1.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
24
|
+
pywavelet-0.1.1.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
|
25
|
+
pywavelet-0.1.1.dist-info/RECORD,,
|
pywavelet-0.0.5.dist-info/RECORD
DELETED
@@ -1,25 +0,0 @@
|
|
1
|
-
pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
|
2
|
-
pywavelet/_version.py,sha256=EJB7__SNK9kQS_SWZB_U4DHJ3P8ftF6etZEihTYnuXE,411
|
3
|
-
pywavelet/logger.py,sha256=u5Et6QLUU0IuTnPnxfev5-4GYmihTo7_yBCTf-rVUsA,377
|
4
|
-
pywavelet/utils.py,sha256=rbmZ-PrnFstq05kh6xj8emuQ3-7oXqZinU03bpg3N7A,1746
|
5
|
-
pywavelet/transforms/__init__.py,sha256=FolK8WiVEJmGDC9xMupYVI_essXaXS4LYWKbqEqGx6o,289
|
6
|
-
pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
|
7
|
-
pywavelet/transforms/forward/__init__.py,sha256=Yq4Tg3Ze98-17C9FIkOqMUdiLHe9x_YoyuRvxOxMOP0,176
|
8
|
-
pywavelet/transforms/forward/from_freq.py,sha256=HWADciv746P5zWNhnyRKDEDACTQnwDTR4usf2k0VIWk,2663
|
9
|
-
pywavelet/transforms/forward/from_time.py,sha256=zfpvmq7UHnynzAnsM_Pf4MXpO1GJ0TaCCERQKNsaBNU,2340
|
10
|
-
pywavelet/transforms/forward/main.py,sha256=FJSLUUsk3I91lUeB2YFiXQdX8gB9maFW60e4tgrd45A,4053
|
11
|
-
pywavelet/transforms/forward/wavelet_bins.py,sha256=43CHVSjNgC59Af1_h9t33VBBqGHJmIUaOUhJgekhMnc,1419
|
12
|
-
pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
|
13
|
-
pywavelet/transforms/inverse/main.py,sha256=l5yFvzmWObrO5Xt_8KYp62w829ab0pQLxcv0-QxkvG0,3015
|
14
|
-
pywavelet/transforms/inverse/to_freq.py,sha256=SExZMax-8A-tJpIA86pYY61X2qvlZ2MrZY27uzCQSV0,2778
|
15
|
-
pywavelet/transforms/inverse/to_time.py,sha256=BAYvrr41QHIbzwYMPyMnzv5mqSx40YigmBruWBwtZwc,5041
|
16
|
-
pywavelet/transforms/types/__init__.py,sha256=4wUTVBk6A02xjZUY_w056eUZurYI9vVfa--I3Q6Udng,109
|
17
|
-
pywavelet/transforms/types/common.py,sha256=sLBn2d9cuL6NYeIv6NIogDjY6rYPZAPzCtWGwMlAwkI,1290
|
18
|
-
pywavelet/transforms/types/frequencyseries.py,sha256=Rtwt486UL0-TgMAdcMMVpyfi5PSLzxFcdE_RsYuyxQk,7463
|
19
|
-
pywavelet/transforms/types/plotting.py,sha256=aEIFoSuQRYwYc2639yLkbujXx_aav5A5tXG29rOSMOQ,10275
|
20
|
-
pywavelet/transforms/types/timeseries.py,sha256=Nl1tiKZ7kwu-EZ5JtubwgzgyjQZR86eGU3C3kElIPNg,9296
|
21
|
-
pywavelet/transforms/types/wavelet.py,sha256=raODhyegBd1_esuv2YP6tK_Nj9fYaLQ6O4pxpiFAVZU,10790
|
22
|
-
pywavelet-0.0.5.dist-info/METADATA,sha256=BaLK8jbUqss_8BeGWOlf_031PoCAlEV4rmuSaitQd0E,1307
|
23
|
-
pywavelet-0.0.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
24
|
-
pywavelet-0.0.5.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
|
25
|
-
pywavelet-0.0.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|