pywavelet 0.0.5__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/_version.py +2 -2
- pywavelet/logger.py +6 -2
- pywavelet/transforms/__init__.py +1 -2
- pywavelet/transforms/forward/__init__.py +1 -2
- pywavelet/transforms/forward/from_freq.py +13 -4
- pywavelet/transforms/forward/from_time.py +13 -1
- pywavelet/transforms/forward/main.py +7 -15
- pywavelet/transforms/inverse/main.py +3 -4
- pywavelet/transforms/inverse/to_freq.py +11 -3
- pywavelet/transforms/inverse/to_time.py +13 -3
- pywavelet/{transforms/types → types}/__init__.py +1 -1
- pywavelet/{transforms/types → types}/common.py +7 -7
- pywavelet/{transforms/types → types}/frequencyseries.py +31 -23
- pywavelet/{transforms/types → types}/plotting.py +47 -26
- pywavelet/{transforms/types → types}/timeseries.py +58 -38
- pywavelet/{transforms/types → types}/wavelet.py +111 -31
- pywavelet/{transforms/forward → types}/wavelet_bins.py +5 -6
- pywavelet/utils.py +17 -5
- {pywavelet-0.0.5.dist-info → pywavelet-0.1.1.dist-info}/METADATA +1 -1
- pywavelet-0.1.1.dist-info/RECORD +25 -0
- pywavelet-0.0.5.dist-info/RECORD +0 -25
- {pywavelet-0.0.5.dist-info → pywavelet-0.1.1.dist-info}/WHEEL +0 -0
- {pywavelet-0.0.5.dist-info → pywavelet-0.1.1.dist-info}/top_level.txt +0 -0
pywavelet/_version.py
CHANGED
pywavelet/logger.py
CHANGED
@@ -1,11 +1,15 @@
|
|
1
|
+
import logging
|
1
2
|
import sys
|
2
3
|
import warnings
|
4
|
+
|
3
5
|
from rich.logging import RichHandler
|
4
|
-
import logging
|
5
6
|
|
6
7
|
FORMAT = "%(message)s"
|
7
8
|
logging.basicConfig(
|
8
|
-
level="INFO",
|
9
|
+
level="INFO",
|
10
|
+
format=FORMAT,
|
11
|
+
datefmt="[%X]",
|
12
|
+
handlers=[RichHandler(rich_tracebacks=True)],
|
9
13
|
)
|
10
14
|
|
11
15
|
logger = logging.getLogger("pywavelet")
|
pywavelet/transforms/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
1
|
+
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
2
2
|
from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
3
3
|
|
4
4
|
__all__ = [
|
@@ -6,5 +6,4 @@ __all__ = [
|
|
6
6
|
"from_wavelet_to_freq",
|
7
7
|
"from_time_to_wavelet",
|
8
8
|
"from_freq_to_wavelet",
|
9
|
-
"compute_bins",
|
10
9
|
]
|
@@ -1,8 +1,10 @@
|
|
1
1
|
"""helper functions for transform_freq"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
5
6
|
|
7
|
+
|
6
8
|
def transform_wavelet_freq_helper(
|
7
9
|
data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
|
8
10
|
) -> np.ndarray:
|
@@ -13,8 +15,16 @@ def transform_wavelet_freq_helper(
|
|
13
15
|
__core(Nf, Nt, DX, freq_strain, phif, wave)
|
14
16
|
return wave
|
15
17
|
|
18
|
+
|
16
19
|
# @njit()
|
17
|
-
def __core(
|
20
|
+
def __core(
|
21
|
+
Nf: int,
|
22
|
+
Nt: int,
|
23
|
+
DX: np.ndarray,
|
24
|
+
freq_strain: np.ndarray,
|
25
|
+
phif: np.ndarray,
|
26
|
+
wave: np.ndarray,
|
27
|
+
):
|
18
28
|
for f_bin in range(0, Nf + 1):
|
19
29
|
__fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
|
20
30
|
# Numba doesn't support np.ifft so we cant jit this
|
@@ -22,8 +32,6 @@ def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarra
|
|
22
32
|
__fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
|
23
33
|
|
24
34
|
|
25
|
-
|
26
|
-
|
27
35
|
@njit()
|
28
36
|
def __fill_wave_1(
|
29
37
|
f_bin: int,
|
@@ -55,6 +63,7 @@ def __fill_wave_1(
|
|
55
63
|
else:
|
56
64
|
DX[i] = phif[j] * data[jj]
|
57
65
|
|
66
|
+
|
58
67
|
@njit()
|
59
68
|
def __fill_wave_2(
|
60
69
|
f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
|
@@ -72,7 +81,7 @@ def __fill_wave_2(
|
|
72
81
|
if (n + f_bin) % 2:
|
73
82
|
wave[n, f_bin] = -DX_trans[n].imag
|
74
83
|
else:
|
75
|
-
wave[n, f_bin] =
|
84
|
+
wave[n, f_bin] = DX_trans[n].real
|
76
85
|
else:
|
77
86
|
if (n + f_bin) % 2:
|
78
87
|
wave[n, f_bin] = DX_trans[n].imag
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""helper functions for transform_time.py"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
@@ -20,7 +21,18 @@ def transform_wavelet_time_helper(
|
|
20
21
|
__core(Nf, Nt, K, ND, wdata, data_pad, phi, wave, mult)
|
21
22
|
return wave
|
22
23
|
|
23
|
-
|
24
|
+
|
25
|
+
def __core(
|
26
|
+
Nf: int,
|
27
|
+
Nt: int,
|
28
|
+
K: int,
|
29
|
+
ND: int,
|
30
|
+
wdata: np.ndarray,
|
31
|
+
data_pad: np.ndarray,
|
32
|
+
phi: np.ndarray,
|
33
|
+
wave: np.ndarray,
|
34
|
+
mult: int,
|
35
|
+
) -> None:
|
24
36
|
for time_bin_i in range(0, Nt):
|
25
37
|
__fill_wave_1(time_bin_i, K, ND, Nf, wdata, data_pad, phi)
|
26
38
|
wdata_trans = np.fft.rfft(wdata, K)
|
@@ -3,15 +3,15 @@ from typing import Union
|
|
3
3
|
import numpy as np
|
4
4
|
|
5
5
|
from ...logger import logger
|
6
|
+
from ...types import FrequencySeries, TimeSeries, Wavelet
|
7
|
+
from ...types.wavelet_bins import _get_bins, _preprocess_bins
|
6
8
|
from ..phi_computer import phi_vec, phitilde_vec_norm
|
7
|
-
from ..types import FrequencySeries, TimeSeries, Wavelet
|
8
9
|
from .from_freq import transform_wavelet_freq_helper
|
9
10
|
from .from_time import transform_wavelet_time_helper
|
10
|
-
from .wavelet_bins import _get_bins, _preprocess_bins
|
11
|
-
|
12
11
|
|
13
12
|
__all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
|
14
13
|
|
14
|
+
|
15
15
|
def from_time_to_wavelet(
|
16
16
|
timeseries: TimeSeries,
|
17
17
|
Nf: Union[int, None] = None,
|
@@ -74,9 +74,7 @@ def from_time_to_wavelet(
|
|
74
74
|
mult = min(mult, Nt // 2) # Ensure mult is not larger than ND/2
|
75
75
|
phi = phi_vec(Nf, dt=dt, d=nx, q=mult)
|
76
76
|
wave = transform_wavelet_time_helper(timeseries.data, Nf, Nt, phi, mult).T
|
77
|
-
return Wavelet(
|
78
|
-
wave * np.sqrt(2), time=t_bins, freq=f_bins
|
79
|
-
)
|
77
|
+
return Wavelet(wave * np.sqrt(2), time=t_bins, freq=f_bins)
|
80
78
|
|
81
79
|
|
82
80
|
def from_freq_to_wavelet(
|
@@ -117,12 +115,6 @@ def from_freq_to_wavelet(
|
|
117
115
|
t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
|
118
116
|
dt = freqseries.dt
|
119
117
|
phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)
|
120
|
-
wave = transform_wavelet_freq_helper(
|
121
|
-
|
122
|
-
)
|
123
|
-
|
124
|
-
return Wavelet(
|
125
|
-
(2 / Nf) * wave.T * np.sqrt(2),
|
126
|
-
time=t_bins,
|
127
|
-
freq=f_bins
|
128
|
-
)
|
118
|
+
wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
|
119
|
+
|
120
|
+
return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
3
|
from ...transforms.phi_computer import phi_vec, phitilde_vec_norm
|
4
|
-
from
|
4
|
+
from ...types import FrequencySeries, TimeSeries, Wavelet
|
5
5
|
from .to_freq import inverse_wavelet_freq_helper_fast
|
6
6
|
from .to_time import inverse_wavelet_time_helper_fast
|
7
7
|
|
@@ -9,6 +9,7 @@ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
|
|
9
9
|
|
10
10
|
INV_ROOT2 = 1.0 / np.sqrt(2)
|
11
11
|
|
12
|
+
|
12
13
|
def from_wavelet_to_time(
|
13
14
|
wave_in: Wavelet,
|
14
15
|
dt: float,
|
@@ -55,9 +56,7 @@ def from_wavelet_to_time(
|
|
55
56
|
|
56
57
|
|
57
58
|
def from_wavelet_to_freq(
|
58
|
-
wave_in: Wavelet,
|
59
|
-
dt: float,
|
60
|
-
nx:float=4.0
|
59
|
+
wave_in: Wavelet, dt: float, nx: float = 4.0
|
61
60
|
) -> FrequencySeries:
|
62
61
|
"""
|
63
62
|
Perform an inverse wavelet transform to the frequency domain.
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""functions for computing the inverse wavelet transforms"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
@@ -12,13 +13,20 @@ def inverse_wavelet_freq_helper_fast(
|
|
12
13
|
ND = Nf * Nt
|
13
14
|
|
14
15
|
prefactor2s = np.zeros(Nt, np.complex128)
|
15
|
-
res = np.zeros(ND//2 +1, dtype=np.complex128)
|
16
|
+
res = np.zeros(ND // 2 + 1, dtype=np.complex128)
|
16
17
|
__core(Nf, Nt, prefactor2s, wave_in, phif, res)
|
17
18
|
|
18
|
-
|
19
19
|
return res
|
20
20
|
|
21
|
-
|
21
|
+
|
22
|
+
def __core(
|
23
|
+
Nf: int,
|
24
|
+
Nt: int,
|
25
|
+
prefactor2s: np.ndarray,
|
26
|
+
wave_in: np.ndarray,
|
27
|
+
phif: np.ndarray,
|
28
|
+
res: np.ndarray,
|
29
|
+
) -> None:
|
22
30
|
for m in range(0, Nf + 1):
|
23
31
|
__pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
|
24
32
|
fft_prefactor2s = np.fft.fft(prefactor2s)
|
@@ -1,4 +1,5 @@
|
|
1
1
|
"""functions for computing the inverse wavelet transforms"""
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
from numba import njit
|
4
5
|
from numpy import fft
|
@@ -21,7 +22,16 @@ def inverse_wavelet_time_helper_fast(
|
|
21
22
|
return res[:ND]
|
22
23
|
|
23
24
|
|
24
|
-
def __core(
|
25
|
+
def __core(
|
26
|
+
Nf: int,
|
27
|
+
Nt: int,
|
28
|
+
K: int,
|
29
|
+
ND: int,
|
30
|
+
wave_in: np.ndarray,
|
31
|
+
phi: np.ndarray,
|
32
|
+
res: np.ndarray,
|
33
|
+
afins: np.ndarray,
|
34
|
+
) -> None:
|
25
35
|
for n in range(0, Nt):
|
26
36
|
if n % 2 == 0:
|
27
37
|
pack_wave_time_helper_compact(n, Nf, Nt, wave_in, afins)
|
@@ -29,9 +39,9 @@ def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarr
|
|
29
39
|
unpack_time_wave_helper_compact(n, Nf, Nt, K, phi, ffts_fin, res)
|
30
40
|
|
31
41
|
# wrap boundary conditions
|
32
|
-
res[: min(K + Nf, ND)] += res[ND: min(ND + K + Nf, 2 * ND)]
|
42
|
+
res[: min(K + Nf, ND)] += res[ND : min(ND + K + Nf, 2 * ND)]
|
33
43
|
if K + Nf > ND:
|
34
|
-
res[: K + Nf - ND] += res[2 * ND: ND + K * Nf]
|
44
|
+
res[: K + Nf - ND] += res[2 * ND : ND + K * Nf]
|
35
45
|
|
36
46
|
|
37
47
|
def unpack_time_wave_helper(
|
@@ -1,9 +1,9 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Tuple, Union
|
2
2
|
|
3
3
|
import numpy as xp
|
4
|
-
from numpy.fft import
|
4
|
+
from numpy.fft import fft, irfft, rfft, rfftfreq # type: ignore
|
5
5
|
|
6
|
-
from
|
6
|
+
from ..logger import logger
|
7
7
|
|
8
8
|
|
9
9
|
def _len_check(d):
|
@@ -19,7 +19,7 @@ def is_documented_by(original):
|
|
19
19
|
return wrapper
|
20
20
|
|
21
21
|
|
22
|
-
def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
|
22
|
+
def fmt_time(seconds: float, units=False) -> Union[str, Tuple[str, str]]:
|
23
23
|
"""Returns formatted time and units [ms, s, min, hr, day]"""
|
24
24
|
t, u = "", ""
|
25
25
|
if seconds < 1e-3:
|
@@ -42,12 +42,12 @@ def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
|
|
42
42
|
|
43
43
|
def fmt_timerange(trange):
|
44
44
|
t0 = fmt_time(trange[0])
|
45
|
-
tend, units = fmt_time(trange[1], units
|
45
|
+
tend, units = fmt_time(trange[1], units=True)
|
46
46
|
return f"[{t0}, {tend}] {units}"
|
47
47
|
|
48
48
|
|
49
|
-
def fmt_pow2(n:float)->str:
|
49
|
+
def fmt_pow2(n: float) -> str:
|
50
50
|
pow2 = xp.log2(n)
|
51
51
|
if pow2.is_integer():
|
52
52
|
return f"2^{int(pow2)}"
|
53
|
-
return f"{n:,}"
|
53
|
+
return f"{n:,}"
|
@@ -1,11 +1,13 @@
|
|
1
|
+
from typing import Optional, Tuple, Union
|
2
|
+
|
1
3
|
import matplotlib.pyplot as plt
|
2
|
-
from typing import Tuple, Union, Optional
|
3
4
|
|
4
|
-
from .common import
|
5
|
+
from .common import fmt_pow2, fmt_time, irfft, is_documented_by, xp
|
5
6
|
from .plotting import plot_freqseries, plot_periodogram
|
6
7
|
|
7
8
|
__all__ = ["FrequencySeries"]
|
8
9
|
|
10
|
+
|
9
11
|
class FrequencySeries:
|
10
12
|
"""
|
11
13
|
A class to represent a one-sided frequency series, with various methods for
|
@@ -39,9 +41,13 @@ class FrequencySeries:
|
|
39
41
|
If any frequency is negative or if `data` and `freq` do not have the same length.
|
40
42
|
"""
|
41
43
|
if xp.any(freq < 0):
|
42
|
-
raise ValueError(
|
44
|
+
raise ValueError(
|
45
|
+
"FrequencySeries must be one-sided (only non-negative frequencies)"
|
46
|
+
)
|
43
47
|
if len(data) != len(freq):
|
44
|
-
raise ValueError(
|
48
|
+
raise ValueError(
|
49
|
+
f"data and freq must have the same length ({len(data)} != {len(freq)})"
|
50
|
+
)
|
45
51
|
self.data = data
|
46
52
|
self.freq = freq
|
47
53
|
self.t0 = t0
|
@@ -53,10 +59,10 @@ class FrequencySeries:
|
|
53
59
|
)
|
54
60
|
|
55
61
|
@is_documented_by(plot_periodogram)
|
56
|
-
def plot_periodogram(
|
57
|
-
|
58
|
-
|
59
|
-
)
|
62
|
+
def plot_periodogram(
|
63
|
+
self, ax=None, **kwargs
|
64
|
+
) -> Tuple[plt.Figure, plt.Axes]:
|
65
|
+
return plot_periodogram(self.data, self.freq, self.fs, ax=ax, **kwargs)
|
60
66
|
|
61
67
|
def __len__(self):
|
62
68
|
"""Return the length of the frequency series."""
|
@@ -127,7 +133,9 @@ class FrequencySeries:
|
|
127
133
|
n = fmt_pow2(len(self))
|
128
134
|
return f"FrequencySeries(n={n}, frange=[{self.range[0]:.2f}, {self.range[1]:.2f}] Hz, T={dur}, fs={self.fs:.2f} Hz)"
|
129
135
|
|
130
|
-
def noise_weighted_inner_product(
|
136
|
+
def noise_weighted_inner_product(
|
137
|
+
self, other: "FrequencySeries", psd: "FrequencySeries"
|
138
|
+
) -> float:
|
131
139
|
"""
|
132
140
|
Compute the noise-weighted inner product of two FrequencySeries.
|
133
141
|
|
@@ -144,9 +152,11 @@ class FrequencySeries:
|
|
144
152
|
The noise-weighted inner product of the two FrequencySeries.
|
145
153
|
"""
|
146
154
|
integrand = xp.real(xp.conj(self.data) * other.data / psd.data)
|
147
|
-
return (4 * self.dt/self.ND) * xp.nansum(integrand)
|
155
|
+
return (4 * self.dt / self.ND) * xp.nansum(integrand)
|
148
156
|
|
149
|
-
def matched_filter_snr(
|
157
|
+
def matched_filter_snr(
|
158
|
+
self, other: "FrequencySeries", psd: "FrequencySeries"
|
159
|
+
) -> float:
|
150
160
|
"""
|
151
161
|
Compute the signal-to-noise ratio (SNR) of a matched filter.
|
152
162
|
|
@@ -199,15 +209,15 @@ class FrequencySeries:
|
|
199
209
|
|
200
210
|
# Create and return a TimeSeries object
|
201
211
|
from .timeseries import TimeSeries
|
202
|
-
return TimeSeries(time_data, time)
|
203
212
|
|
213
|
+
return TimeSeries(time_data, time)
|
204
214
|
|
205
215
|
def to_wavelet(
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
216
|
+
self,
|
217
|
+
Nf: Union[int, None] = None,
|
218
|
+
Nt: Union[int, None] = None,
|
219
|
+
nx: Optional[float] = 4.0,
|
220
|
+
) -> "Wavelet":
|
211
221
|
"""
|
212
222
|
Convert the frequency series to a wavelet using inverse Fourier transform.
|
213
223
|
|
@@ -216,9 +226,9 @@ class FrequencySeries:
|
|
216
226
|
Wavelet
|
217
227
|
The corresponding wavelet.
|
218
228
|
"""
|
219
|
-
from ..forward import from_freq_to_wavelet
|
220
|
-
return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
|
229
|
+
from ..transforms.forward import from_freq_to_wavelet
|
221
230
|
|
231
|
+
return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
|
222
232
|
|
223
233
|
def __eq__(self, other):
|
224
234
|
"""Check if two FrequencySeries objects are equal."""
|
@@ -228,10 +238,8 @@ class FrequencySeries:
|
|
228
238
|
|
229
239
|
def __copy__(self):
|
230
240
|
return FrequencySeries(
|
231
|
-
xp.copy(self.data),
|
232
|
-
xp.copy(self.freq),
|
233
|
-
t0=self.t0
|
241
|
+
xp.copy(self.data), xp.copy(self.freq), t0=self.t0
|
234
242
|
)
|
235
243
|
|
236
244
|
def copy(self):
|
237
|
-
return self.__copy__()
|
245
|
+
return self.__copy__()
|
@@ -1,12 +1,11 @@
|
|
1
1
|
import warnings
|
2
|
-
from typing import
|
3
|
-
from scipy.signal import savgol_filter
|
4
|
-
from scipy.interpolate import interp1d
|
2
|
+
from typing import Optional, Tuple, Union
|
5
3
|
|
6
4
|
import matplotlib.pyplot as plt
|
7
5
|
import numpy as np
|
8
6
|
from matplotlib.colors import LogNorm, TwoSlopeNorm
|
9
|
-
from scipy.
|
7
|
+
from scipy.interpolate import interp1d
|
8
|
+
from scipy.signal import savgol_filter, spectrogram
|
10
9
|
|
11
10
|
MIN_S = 60
|
12
11
|
HOUR_S = 60 * MIN_S
|
@@ -53,8 +52,13 @@ def __get_smoothed_y(x, z, y_grid):
|
|
53
52
|
# Interpolate to fill NaNs in y before smoothing
|
54
53
|
nan_mask = ~np.isnan(y)
|
55
54
|
if np.isnan(y).any():
|
56
|
-
interpolator = interp1d(
|
57
|
-
|
55
|
+
interpolator = interp1d(
|
56
|
+
x[nan_mask],
|
57
|
+
y[nan_mask],
|
58
|
+
kind="cubic",
|
59
|
+
bounds_error=False,
|
60
|
+
fill_value="extrapolate",
|
61
|
+
)
|
58
62
|
y = interpolator(x) # Fill NaNs with interpolated values
|
59
63
|
|
60
64
|
# Smooth the curve
|
@@ -64,8 +68,6 @@ def __get_smoothed_y(x, z, y_grid):
|
|
64
68
|
return y
|
65
69
|
|
66
70
|
|
67
|
-
|
68
|
-
|
69
71
|
def plot_wavelet_grid(
|
70
72
|
wavelet_data: np.ndarray,
|
71
73
|
time_grid: np.ndarray,
|
@@ -80,8 +82,8 @@ def plot_wavelet_grid(
|
|
80
82
|
norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
|
81
83
|
cbar_label: Optional[str] = None,
|
82
84
|
nan_color: Optional[str] = "black",
|
83
|
-
detailed_axes:bool = False,
|
84
|
-
show_gridinfo:bool = True,
|
85
|
+
detailed_axes: bool = False,
|
86
|
+
show_gridinfo: bool = True,
|
85
87
|
trend_color: Optional[str] = None,
|
86
88
|
whiten_by: Optional[np.ndarray] = None,
|
87
89
|
**kwargs,
|
@@ -153,7 +155,9 @@ def plot_wavelet_grid(
|
|
153
155
|
|
154
156
|
# Validate the dimensions
|
155
157
|
if (Nf, Nt) != (len(freq_grid), len(time_grid)):
|
156
|
-
raise ValueError(
|
158
|
+
raise ValueError(
|
159
|
+
f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}."
|
160
|
+
)
|
157
161
|
|
158
162
|
# Prepare the data for plotting
|
159
163
|
z = wavelet_data.copy()
|
@@ -162,14 +166,15 @@ def plot_wavelet_grid(
|
|
162
166
|
if absolute:
|
163
167
|
z = np.abs(z)
|
164
168
|
|
165
|
-
|
166
169
|
# Determine normalization and colormap
|
167
170
|
if norm is None:
|
168
171
|
try:
|
169
172
|
if np.all(np.isnan(z)):
|
170
173
|
raise ValueError("All wavelet data is NaN.")
|
171
174
|
if zscale == "log":
|
172
|
-
norm = LogNorm(
|
175
|
+
norm = LogNorm(
|
176
|
+
vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z < np.inf])
|
177
|
+
)
|
173
178
|
elif not absolute:
|
174
179
|
vmin, vmax = np.nanmin(z), np.nanmax(z)
|
175
180
|
vcenter = 0.0
|
@@ -177,7 +182,9 @@ def plot_wavelet_grid(
|
|
177
182
|
else:
|
178
183
|
norm = None # Default linear scaling
|
179
184
|
except Exception as e:
|
180
|
-
warnings.warn(
|
185
|
+
warnings.warn(
|
186
|
+
f"Error in determining normalization: {e}. Using default linear scaling."
|
187
|
+
)
|
181
188
|
norm = None
|
182
189
|
|
183
190
|
if cmap is None:
|
@@ -195,7 +202,7 @@ def plot_wavelet_grid(
|
|
195
202
|
im = ax.imshow(
|
196
203
|
z,
|
197
204
|
aspect="auto",
|
198
|
-
extent=[time_grid[0],time_grid[-1], freq_grid[0], freq_grid[-1]],
|
205
|
+
extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
|
199
206
|
origin="lower",
|
200
207
|
cmap=cmap,
|
201
208
|
norm=norm,
|
@@ -203,13 +210,25 @@ def plot_wavelet_grid(
|
|
203
210
|
**kwargs,
|
204
211
|
)
|
205
212
|
if trend_color is not None:
|
206
|
-
plot_wavelet_trend(
|
213
|
+
plot_wavelet_trend(
|
214
|
+
wavelet_data,
|
215
|
+
time_grid,
|
216
|
+
freq_grid,
|
217
|
+
ax,
|
218
|
+
color=trend_color,
|
219
|
+
freq_range=freq_range,
|
220
|
+
freq_scale=freq_scale,
|
221
|
+
)
|
207
222
|
|
208
223
|
# Add colorbar if requested
|
209
224
|
if show_colorbar:
|
210
225
|
cbar = fig.colorbar(im, ax=ax)
|
211
226
|
if cbar_label is None:
|
212
|
-
cbar_label =
|
227
|
+
cbar_label = (
|
228
|
+
"Absolute Wavelet Amplitude"
|
229
|
+
if absolute
|
230
|
+
else "Wavelet Amplitude"
|
231
|
+
)
|
213
232
|
cbar.set_label(cbar_label)
|
214
233
|
|
215
234
|
# Configure axes scales
|
@@ -239,14 +258,12 @@ def plot_wavelet_grid(
|
|
239
258
|
bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
|
240
259
|
)
|
241
260
|
|
242
|
-
|
243
261
|
# Adjust layout
|
244
262
|
fig.tight_layout()
|
245
263
|
|
246
264
|
return fig, ax
|
247
265
|
|
248
266
|
|
249
|
-
|
250
267
|
def plot_freqseries(
|
251
268
|
data: np.ndarray,
|
252
269
|
freq: np.ndarray,
|
@@ -277,9 +294,10 @@ def plot_periodogram(
|
|
277
294
|
flow = np.min(np.abs(freq))
|
278
295
|
ax.set_xlabel("Frequency [Hz]")
|
279
296
|
ax.set_ylabel("Periodigram")
|
280
|
-
ax.set_xlim(left=flow, right=nyquist_frequency/2)
|
297
|
+
ax.set_xlim(left=flow, right=nyquist_frequency / 2)
|
281
298
|
return ax.figure, ax
|
282
299
|
|
300
|
+
|
283
301
|
def plot_timeseries(
|
284
302
|
data: np.ndarray, time: np.ndarray, ax=None, **kwargs
|
285
303
|
) -> Tuple[plt.Figure, plt.Axes]:
|
@@ -314,7 +332,6 @@ def plot_spectrogram(
|
|
314
332
|
|
315
333
|
_fmt_time_axis(t, ax)
|
316
334
|
|
317
|
-
|
318
335
|
ax.set_ylabel("Frequency [Hz]")
|
319
336
|
ax.set_ylim(top=fs / 2.0)
|
320
337
|
cbar = plt.colorbar(cm, ax=ax)
|
@@ -322,20 +339,24 @@ def plot_spectrogram(
|
|
322
339
|
return ax.figure, ax
|
323
340
|
|
324
341
|
|
325
|
-
|
326
342
|
def _fmt_time_axis(t, axes, t0=None, tmax=None):
|
327
343
|
if t[-1] > DAY_S: # If time goes beyond a day
|
328
|
-
axes.xaxis.set_major_formatter(
|
344
|
+
axes.xaxis.set_major_formatter(
|
345
|
+
plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}")
|
346
|
+
)
|
329
347
|
axes.set_xlabel("Time [days]")
|
330
348
|
elif t[-1] > HOUR_S: # If time goes beyond an hour
|
331
|
-
axes.xaxis.set_major_formatter(
|
349
|
+
axes.xaxis.set_major_formatter(
|
350
|
+
plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}")
|
351
|
+
)
|
332
352
|
axes.set_xlabel("Time [hr]")
|
333
353
|
elif t[-1] > MIN_S: # If time goes beyond a minute
|
334
|
-
axes.xaxis.set_major_formatter(
|
354
|
+
axes.xaxis.set_major_formatter(
|
355
|
+
plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}")
|
356
|
+
)
|
335
357
|
axes.set_xlabel("Time [min]")
|
336
358
|
else:
|
337
359
|
axes.set_xlabel("Time [s]")
|
338
360
|
t0 = t[0] if t0 is None else t0
|
339
361
|
tmax = t[-1] if tmax is None else tmax
|
340
362
|
axes.set_xlim(t0, tmax)
|
341
|
-
|
@@ -1,11 +1,20 @@
|
|
1
|
+
from typing import Optional, Tuple, Union
|
2
|
+
|
1
3
|
import matplotlib.pyplot as plt
|
2
|
-
from typing import Tuple, Optional, Union
|
3
|
-
from scipy.signal.windows import tukey
|
4
4
|
from scipy.signal import butter, sosfiltfilt
|
5
|
+
from scipy.signal.windows import tukey
|
5
6
|
|
6
|
-
from
|
7
|
-
from .common import
|
8
|
-
|
7
|
+
from ..logger import logger
|
8
|
+
from .common import (
|
9
|
+
fmt_pow2,
|
10
|
+
fmt_time,
|
11
|
+
fmt_timerange,
|
12
|
+
is_documented_by,
|
13
|
+
rfft,
|
14
|
+
rfftfreq,
|
15
|
+
xp,
|
16
|
+
)
|
17
|
+
from .plotting import plot_spectrogram, plot_timeseries
|
9
18
|
|
10
19
|
__all__ = ["TimeSeries"]
|
11
20
|
|
@@ -50,7 +59,7 @@ class TimeSeries:
|
|
50
59
|
|
51
60
|
@is_documented_by(plot_spectrogram)
|
52
61
|
def plot_spectrogram(
|
53
|
-
|
62
|
+
self, ax=None, spec_kwargs={}, plot_kwargs={}
|
54
63
|
) -> Tuple[plt.Figure, plt.Axes]:
|
55
64
|
return plot_spectrogram(
|
56
65
|
self.data,
|
@@ -122,9 +131,11 @@ class TimeSeries:
|
|
122
131
|
trange = fmt_timerange((self.t0, self.tend))
|
123
132
|
T = " ".join(fmt_time(self.duration, units=True))
|
124
133
|
n = fmt_pow2(len(self))
|
125
|
-
return
|
134
|
+
return (
|
135
|
+
f"TimeSeries(n={n}, trange={trange}, T={T}, fs={self.fs:.2f} Hz)"
|
136
|
+
)
|
126
137
|
|
127
|
-
def to_frequencyseries(self) ->
|
138
|
+
def to_frequencyseries(self) -> "FrequencySeries":
|
128
139
|
"""
|
129
140
|
Convert the time series to a frequency series using the one-sided FFT.
|
130
141
|
|
@@ -137,14 +148,15 @@ class TimeSeries:
|
|
137
148
|
data = rfft(self.data)
|
138
149
|
|
139
150
|
from .frequencyseries import FrequencySeries # Avoid circular import
|
151
|
+
|
140
152
|
return FrequencySeries(data, freq, t0=self.t0)
|
141
153
|
|
142
154
|
def to_wavelet(
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
) ->
|
155
|
+
self,
|
156
|
+
Nf: Union[int, None] = None,
|
157
|
+
Nt: Union[int, None] = None,
|
158
|
+
nx: Optional[float] = 4.0,
|
159
|
+
) -> "Wavelet":
|
148
160
|
"""
|
149
161
|
Convert the time series to a wavelet representation.
|
150
162
|
|
@@ -166,32 +178,37 @@ class TimeSeries:
|
|
166
178
|
hf = self.to_frequencyseries()
|
167
179
|
return hf.to_wavelet(Nf, Nt, nx=nx)
|
168
180
|
|
169
|
-
|
170
|
-
def __add__(self, other: 'TimeSeries') -> 'TimeSeries':
|
181
|
+
def __add__(self, other: "TimeSeries") -> "TimeSeries":
|
171
182
|
"""Add two TimeSeries objects together."""
|
172
183
|
if self.shape != other.shape:
|
173
|
-
raise ValueError(
|
184
|
+
raise ValueError(
|
185
|
+
"TimeSeries objects must have the same shape to add them together"
|
186
|
+
)
|
174
187
|
return TimeSeries(self.data + other.data, self.time)
|
175
188
|
|
176
|
-
def __sub__(self, other:
|
189
|
+
def __sub__(self, other: "TimeSeries") -> "TimeSeries":
|
177
190
|
"""Subtract one TimeSeries object from another."""
|
178
191
|
if self.shape != other.shape:
|
179
|
-
raise ValueError(
|
192
|
+
raise ValueError(
|
193
|
+
"TimeSeries objects must have the same shape to subtract them"
|
194
|
+
)
|
180
195
|
return TimeSeries(self.data - other.data, self.time)
|
181
196
|
|
182
|
-
def __eq__(self, other:
|
197
|
+
def __eq__(self, other: "TimeSeries") -> bool:
|
183
198
|
"""Check if two TimeSeries objects are equal."""
|
184
199
|
shape_same = self.shape == other.shape
|
185
200
|
range_same = self.t0 == other.t0 and self.tend == other.tend
|
186
|
-
time_same =
|
201
|
+
time_same = xp.allclose(self.time, other.time)
|
187
202
|
data_same = xp.allclose(self.data, other.data)
|
188
203
|
return shape_same and range_same and data_same and time_same
|
189
204
|
|
190
|
-
def __mul__(self, other: float) ->
|
205
|
+
def __mul__(self, other: float) -> "TimeSeries":
|
191
206
|
"""Multiply a TimeSeries object by a scalar."""
|
192
207
|
return TimeSeries(self.data * other, self.time)
|
193
208
|
|
194
|
-
def zero_pad_to_power_of_2(
|
209
|
+
def zero_pad_to_power_of_2(
|
210
|
+
self, tukey_window_alpha: float = 0.0
|
211
|
+
) -> "TimeSeries":
|
195
212
|
"""Zero pad the time series to make the length a power of two (useful to speed up FFTs, O(NlogN) versus O(N^2)).
|
196
213
|
|
197
214
|
Parameters
|
@@ -207,7 +224,7 @@ class TimeSeries:
|
|
207
224
|
"""
|
208
225
|
N, dt, t0 = self.ND, self.dt, self.t0
|
209
226
|
pow_2 = xp.ceil(xp.log2(N))
|
210
|
-
n_pad = int((2
|
227
|
+
n_pad = int((2**pow_2) - N)
|
211
228
|
new_N = N + n_pad
|
212
229
|
if n_pad > 0:
|
213
230
|
logger.warning(
|
@@ -220,7 +237,12 @@ class TimeSeries:
|
|
220
237
|
time = xp.arange(0, len(data) * dt, dt) + t0
|
221
238
|
return TimeSeries(data, time)
|
222
239
|
|
223
|
-
def highpass_filter(
|
240
|
+
def highpass_filter(
|
241
|
+
self,
|
242
|
+
fmin: float,
|
243
|
+
tukey_window_alpha: float = 0.0,
|
244
|
+
bandpass_order: int = 4,
|
245
|
+
) -> "TimeSeries":
|
224
246
|
"""
|
225
247
|
Filter the time series with a highpass bandpass filter.
|
226
248
|
|
@@ -244,24 +266,25 @@ class TimeSeries:
|
|
244
266
|
"""
|
245
267
|
|
246
268
|
if fmin <= 0 or fmin > self.nyquist_frequency:
|
247
|
-
raise ValueError(
|
269
|
+
raise ValueError(
|
270
|
+
f"Invalid fmin value: {fmin}. Must be in the range [0, {self.nyquist_frequency}]"
|
271
|
+
)
|
248
272
|
|
249
|
-
sos = butter(
|
273
|
+
sos = butter(
|
274
|
+
bandpass_order, Wn=fmin, btype="highpass", output="sos", fs=self.fs
|
275
|
+
)
|
250
276
|
window = tukey(self.ND, alpha=tukey_window_alpha)
|
251
277
|
data = self.data.copy()
|
252
278
|
data = sosfiltfilt(sos, data * window)
|
253
279
|
return TimeSeries(data, self.time)
|
254
280
|
|
255
281
|
def __copy__(self):
|
256
|
-
return TimeSeries(
|
257
|
-
self.data.copy(),
|
258
|
-
self.time.copy()
|
259
|
-
)
|
282
|
+
return TimeSeries(self.data.copy(), self.time.copy())
|
260
283
|
|
261
284
|
def copy(self):
|
262
285
|
return self.__copy__()
|
263
286
|
|
264
|
-
def __getitem__(self, key)->"TimeSeries":
|
287
|
+
def __getitem__(self, key) -> "TimeSeries":
|
265
288
|
if isinstance(key, slice):
|
266
289
|
# Handle slicing
|
267
290
|
return self.__handle_slice(key)
|
@@ -269,12 +292,9 @@ class TimeSeries:
|
|
269
292
|
# Handle regular indexing
|
270
293
|
return TimeSeries(self.data[key], self.time[key])
|
271
294
|
|
272
|
-
def __handle_slice(self, slice_obj)->"TimeSeries":
|
273
|
-
return TimeSeries(
|
274
|
-
self.data[slice_obj],
|
275
|
-
self.time[slice_obj]
|
276
|
-
)
|
295
|
+
def __handle_slice(self, slice_obj) -> "TimeSeries":
|
296
|
+
return TimeSeries(self.data[slice_obj], self.time[slice_obj])
|
277
297
|
|
278
298
|
@classmethod
|
279
|
-
def _EMPTY(cls, ND:int, dt:float)->"TimeSeries":
|
280
|
-
return cls(xp.zeros(ND), xp.arange(0, ND*dt, dt))
|
299
|
+
def _EMPTY(cls, ND: int, dt: float) -> "TimeSeries":
|
300
|
+
return cls(xp.zeros(ND), xp.arange(0, ND * dt, dt))
|
@@ -1,10 +1,11 @@
|
|
1
|
-
import
|
2
|
-
from typing import Optional, Tuple
|
1
|
+
from typing import List, Optional, Tuple
|
3
2
|
|
3
|
+
import matplotlib.pyplot as plt
|
4
4
|
import numpy as np
|
5
5
|
|
6
|
-
from .common import is_documented_by, xp
|
6
|
+
from .common import fmt_timerange, is_documented_by, xp
|
7
7
|
from .plotting import plot_wavelet_grid, plot_wavelet_trend
|
8
|
+
from .wavelet_bins import compute_bins
|
8
9
|
|
9
10
|
|
10
11
|
class Wavelet:
|
@@ -22,10 +23,10 @@ class Wavelet:
|
|
22
23
|
"""
|
23
24
|
|
24
25
|
def __init__(
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
26
|
+
self,
|
27
|
+
data: xp.ndarray,
|
28
|
+
time: xp.ndarray,
|
29
|
+
freq: xp.ndarray,
|
29
30
|
):
|
30
31
|
"""
|
31
32
|
Initialize the Wavelet object with data, time, and frequency arrays.
|
@@ -53,17 +54,60 @@ class Wavelet:
|
|
53
54
|
self.time = time
|
54
55
|
self.freq = freq
|
55
56
|
|
57
|
+
@classmethod
|
58
|
+
def zeros_from_grid(cls, time: xp.ndarray, freq: xp.ndarray) -> "Wavelet":
|
59
|
+
"""
|
60
|
+
Create a Wavelet object filled with zeros.
|
61
|
+
|
62
|
+
Parameters
|
63
|
+
----------
|
64
|
+
time: xp.ndarray
|
65
|
+
freq: xp.ndarray
|
66
|
+
|
67
|
+
Returns
|
68
|
+
-------
|
69
|
+
Wavelet
|
70
|
+
A Wavelet object with zero-filled data array.
|
71
|
+
"""
|
72
|
+
Nf, Nt = len(freq), len(time)
|
73
|
+
return cls(data=xp.zeros((Nf, Nt)), time=time, freq=freq)
|
74
|
+
|
75
|
+
@classmethod
|
76
|
+
def zeros(cls, Nf: int, Nt: int, T: float) -> "Wavelet":
|
77
|
+
"""
|
78
|
+
Create a Wavelet object filled with zeros.
|
79
|
+
|
80
|
+
Parameters
|
81
|
+
----------
|
82
|
+
Nf : int
|
83
|
+
Number of frequency bins.
|
84
|
+
Nt : int
|
85
|
+
Number of time bins.
|
86
|
+
|
87
|
+
Returns
|
88
|
+
-------
|
89
|
+
Wavelet
|
90
|
+
A Wavelet object with zero-filled data array.
|
91
|
+
"""
|
92
|
+
return cls.zeros_from_grid(*compute_bins(Nf, Nt, T))
|
93
|
+
|
56
94
|
@is_documented_by(plot_wavelet_grid)
|
57
95
|
def plot(self, ax=None, *args, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
|
58
96
|
kwargs["time_grid"] = kwargs.get("time_grid", self.time)
|
59
97
|
kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
|
60
|
-
return plot_wavelet_grid(
|
98
|
+
return plot_wavelet_grid(
|
99
|
+
wavelet_data=self.data, ax=ax, *args, **kwargs
|
100
|
+
)
|
61
101
|
|
62
102
|
@is_documented_by(plot_wavelet_trend)
|
63
|
-
def plot_trend(
|
103
|
+
def plot_trend(
|
104
|
+
self, ax=None, *args, **kwargs
|
105
|
+
) -> Tuple[plt.Figure, plt.Axes]:
|
64
106
|
kwargs["time_grid"] = kwargs.get("time_grid", self.time)
|
65
107
|
kwargs["freq_grid"] = kwargs.get("freq_grid", self.freq)
|
66
|
-
return plot_wavelet_trend(
|
108
|
+
return plot_wavelet_trend(
|
109
|
+
wavelet_data=self.data, ax=ax, *args, **kwargs
|
110
|
+
)
|
67
111
|
|
68
112
|
@property
|
69
113
|
def Nt(self) -> int:
|
@@ -242,7 +286,8 @@ class Wavelet:
|
|
242
286
|
TimeSeries
|
243
287
|
A `TimeSeries` object representing the time-domain signal.
|
244
288
|
"""
|
245
|
-
from ..inverse import from_wavelet_to_time
|
289
|
+
from ..transforms.inverse import from_wavelet_to_time
|
290
|
+
|
246
291
|
return from_wavelet_to_time(self, dt=self.delta_t, nx=nx, mult=mult)
|
247
292
|
|
248
293
|
def to_frequencyseries(self, nx: float = 4.0) -> "FrequencySeries":
|
@@ -254,7 +299,8 @@ class Wavelet:
|
|
254
299
|
FrequencySeries
|
255
300
|
A `FrequencySeries` object representing the frequency-domain signal.
|
256
301
|
"""
|
257
|
-
from ..inverse import from_wavelet_to_freq
|
302
|
+
from ..transforms.inverse import from_wavelet_to_freq
|
303
|
+
|
258
304
|
return from_wavelet_to_freq(self, dt=self.delta_t, nx=nx)
|
259
305
|
|
260
306
|
def __repr__(self) -> str:
|
@@ -277,40 +323,57 @@ class Wavelet:
|
|
277
323
|
def __add__(self, other):
|
278
324
|
"""Element-wise addition of two Wavelet objects."""
|
279
325
|
if isinstance(other, Wavelet):
|
280
|
-
return Wavelet(
|
326
|
+
return Wavelet(
|
327
|
+
data=self.data + other.data, time=self.time, freq=self.freq
|
328
|
+
)
|
281
329
|
elif isinstance(other, float):
|
282
|
-
return Wavelet(
|
330
|
+
return Wavelet(
|
331
|
+
data=self.data + other, time=self.time, freq=self.freq
|
332
|
+
)
|
283
333
|
|
284
334
|
def __sub__(self, other):
|
285
335
|
"""Element-wise subtraction of two Wavelet objects."""
|
286
336
|
if isinstance(other, Wavelet):
|
287
|
-
return Wavelet(
|
337
|
+
return Wavelet(
|
338
|
+
data=self.data - other.data, time=self.time, freq=self.freq
|
339
|
+
)
|
288
340
|
elif isinstance(other, float):
|
289
|
-
return Wavelet(
|
341
|
+
return Wavelet(
|
342
|
+
data=self.data - other, time=self.time, freq=self.freq
|
343
|
+
)
|
290
344
|
|
291
345
|
def __mul__(self, other):
|
292
346
|
"""Element-wise multiplication of two Wavelet objects."""
|
293
347
|
if isinstance(other, Wavelet):
|
294
|
-
return Wavelet(
|
348
|
+
return Wavelet(
|
349
|
+
data=self.data * other.data, time=self.time, freq=self.freq
|
350
|
+
)
|
295
351
|
elif isinstance(other, float):
|
296
|
-
return Wavelet(
|
352
|
+
return Wavelet(
|
353
|
+
data=self.data * other, time=self.time, freq=self.freq
|
354
|
+
)
|
297
355
|
|
298
356
|
def __truediv__(self, other):
|
299
357
|
"""Element-wise division of two Wavelet objects."""
|
300
358
|
if isinstance(other, Wavelet):
|
301
|
-
return Wavelet(
|
359
|
+
return Wavelet(
|
360
|
+
data=self.data / other.data, time=self.time, freq=self.freq
|
361
|
+
)
|
302
362
|
elif isinstance(other, float):
|
303
|
-
return Wavelet(
|
363
|
+
return Wavelet(
|
364
|
+
data=self.data / other, time=self.time, freq=self.freq
|
365
|
+
)
|
304
366
|
|
305
|
-
def __eq__(self, other:"Wavelet") -> bool:
|
367
|
+
def __eq__(self, other: "Wavelet") -> bool:
|
306
368
|
"""Element-wise comparison of two Wavelet objects."""
|
307
369
|
data_all_same = xp.isclose(xp.nansum(self.data - other.data), 0)
|
308
370
|
time_same = (self.time == other.time).all()
|
309
371
|
freq_same = (self.freq == other.freq).all()
|
310
372
|
return data_all_same and time_same and freq_same
|
311
373
|
|
312
|
-
|
313
|
-
|
374
|
+
def noise_weighted_inner_product(
|
375
|
+
self, other: "Wavelet", psd: "Wavelet"
|
376
|
+
) -> float:
|
314
377
|
"""
|
315
378
|
Compute the noise-weighted inner product of two wavelet grids given a PSD.
|
316
379
|
|
@@ -326,11 +389,11 @@ class Wavelet:
|
|
326
389
|
float
|
327
390
|
The noise-weighted inner product.
|
328
391
|
"""
|
329
|
-
from
|
330
|
-
return noise_weighted_inner_product(self, other, psd)
|
392
|
+
from ..utils import noise_weighted_inner_product
|
331
393
|
|
394
|
+
return noise_weighted_inner_product(self, other, psd)
|
332
395
|
|
333
|
-
def matched_filter_snr(self, template:"Wavelet", psd:"Wavelet") -> float:
|
396
|
+
def matched_filter_snr(self, template: "Wavelet", psd: "Wavelet") -> float:
|
334
397
|
"""
|
335
398
|
Compute the matched filter SNR of the wavelet grid given a template.
|
336
399
|
|
@@ -347,7 +410,7 @@ class Wavelet:
|
|
347
410
|
mf = self.noise_weighted_inner_product(template, psd)
|
348
411
|
return mf / self.optimal_snr(psd)
|
349
412
|
|
350
|
-
def optimal_snr(self, psd:"Wavelet") -> float:
|
413
|
+
def optimal_snr(self, psd: "Wavelet") -> float:
|
351
414
|
"""
|
352
415
|
Compute the optimal SNR of the wavelet grid given a PSD.
|
353
416
|
|
@@ -365,10 +428,27 @@ class Wavelet:
|
|
365
428
|
|
366
429
|
def __copy__(self):
|
367
430
|
return Wavelet(
|
368
|
-
data=self.data.copy(),
|
369
|
-
time=self.time.copy(),
|
370
|
-
freq=self.freq.copy()
|
431
|
+
data=self.data.copy(), time=self.time.copy(), freq=self.freq.copy()
|
371
432
|
)
|
372
433
|
|
373
434
|
def copy(self):
|
374
|
-
return self.__copy__()
|
435
|
+
return self.__copy__()
|
436
|
+
|
437
|
+
|
438
|
+
class WaveletMask(Wavelet):
|
439
|
+
@property
|
440
|
+
def mask(self):
|
441
|
+
return self.data
|
442
|
+
|
443
|
+
def __repr__(self):
|
444
|
+
return f"WaveletMask({self.mask.shape}, {fmt_timerange(self.time)}, {self.freq})"
|
445
|
+
|
446
|
+
@classmethod
|
447
|
+
def from_frange(
|
448
|
+
cls, time_grid: xp.ndarray, freq_grid: xp.ndarray, frange: List[float]
|
449
|
+
):
|
450
|
+
self = cls.zeros_from_grid(time_grid, freq_grid)
|
451
|
+
self.mask[
|
452
|
+
(freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
|
453
|
+
] = True
|
454
|
+
return self
|
@@ -2,7 +2,8 @@ from typing import Tuple, Union
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
|
5
|
-
from
|
5
|
+
from .frequencyseries import FrequencySeries
|
6
|
+
from .timeseries import TimeSeries
|
6
7
|
|
7
8
|
|
8
9
|
def _preprocess_bins(
|
@@ -10,7 +11,6 @@ def _preprocess_bins(
|
|
10
11
|
) -> Tuple[int, int]:
|
11
12
|
"""preprocess the bins"""
|
12
13
|
|
13
|
-
|
14
14
|
if isinstance(data, TimeSeries):
|
15
15
|
N = len(data)
|
16
16
|
elif isinstance(data, FrequencySeries):
|
@@ -34,7 +34,6 @@ def _get_bins(
|
|
34
34
|
Nf: Union[int, None] = None,
|
35
35
|
Nt: Union[int, None] = None,
|
36
36
|
) -> Tuple[np.ndarray, np.ndarray]:
|
37
|
-
|
38
37
|
T = data.duration
|
39
38
|
t_bins, f_bins = compute_bins(Nf, Nt, T)
|
40
39
|
|
@@ -42,12 +41,12 @@ def _get_bins(
|
|
42
41
|
# fs = N / T
|
43
42
|
# assert delta_f == fmax / Nf, f"delta_f={delta_f} != fmax/Nf={fmax/Nf}"
|
44
43
|
|
45
|
-
t_bins+= data.t0
|
44
|
+
t_bins += data.t0
|
46
45
|
|
47
46
|
return t_bins, f_bins
|
48
47
|
|
49
48
|
|
50
|
-
def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
|
49
|
+
def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[np.ndarray, np.ndarray]:
|
51
50
|
"""Get the bins for the wavelet transform
|
52
51
|
Eq 4-6 in Wavelets paper
|
53
52
|
"""
|
@@ -55,4 +54,4 @@ def compute_bins(Nf:int, Nt:int, T:float) -> Tuple[np.ndarray, np.ndarray]:
|
|
55
54
|
delta_F = 1 / (2 * delta_T)
|
56
55
|
t_bins = np.arange(0, Nt) * delta_T
|
57
56
|
f_bins = np.arange(0, Nf) * delta_F
|
58
|
-
return
|
57
|
+
return t_bins, f_bins
|
pywavelet/utils.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Union
|
|
3
3
|
import numpy as np
|
4
4
|
from scipy.interpolate import interp1d
|
5
5
|
|
6
|
-
from .
|
6
|
+
from .types import FrequencySeries, TimeSeries, Wavelet, WaveletMask
|
7
7
|
|
8
8
|
DATA_TYPE = Union[TimeSeries, FrequencySeries, Wavelet]
|
9
9
|
|
@@ -34,10 +34,13 @@ def evolutionary_psd_from_stationary_psd(
|
|
34
34
|
return Wavelet(psd_grid.T, time=t_grid, freq=f_grid)
|
35
35
|
|
36
36
|
|
37
|
-
def noise_weighted_inner_product(
|
37
|
+
def noise_weighted_inner_product(
|
38
|
+
d: Wavelet, h: Wavelet, PSD: Wavelet
|
39
|
+
) -> float:
|
38
40
|
return np.nansum((d.data * h.data) / PSD.data)
|
39
41
|
|
40
|
-
|
42
|
+
|
43
|
+
def compute_snr(d: Wavelet, h: Wavelet, PSD: Wavelet) -> float:
|
41
44
|
"""Compute the SNR of a model h[ti,fi] given freqseries d[ti,fi] and PSD[ti,fi].
|
42
45
|
|
43
46
|
SNR(h) = Sum_{ti,fi} [ h_hat[ti,fi] d[ti,fi] / PSD[ti,fi]
|
@@ -60,5 +63,14 @@ def compute_snr(d:Wavelet, h: Wavelet, PSD: Wavelet) -> float:
|
|
60
63
|
return np.sqrt(noise_weighted_inner_product(d, h, PSD))
|
61
64
|
|
62
65
|
|
63
|
-
def compute_likelihood(
|
64
|
-
|
66
|
+
def compute_likelihood(
|
67
|
+
data: Wavelet, template: Wavelet, psd: Wavelet, mask: WaveletMask = None
|
68
|
+
) -> float:
|
69
|
+
d = data.data
|
70
|
+
h = template.data
|
71
|
+
p = psd.data
|
72
|
+
if mask is not None:
|
73
|
+
m = mask.mask
|
74
|
+
d, h, p = d * m, h * m, p * m
|
75
|
+
|
76
|
+
return -0.5 * np.nansum((d - h) ** 2 / p)
|
@@ -0,0 +1,25 @@
|
|
1
|
+
pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
|
2
|
+
pywavelet/_version.py,sha256=PKIMyjdUACH4-ONvtunQCnYE2UhlMfp9su83e3HXl5E,411
|
3
|
+
pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
|
4
|
+
pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
|
5
|
+
pywavelet/transforms/__init__.py,sha256=1Ibsup9UwMajeZ9NCQ4BN15qZTeJ_EHkgGu8XNFdA18,255
|
6
|
+
pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
|
7
|
+
pywavelet/transforms/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
8
|
+
pywavelet/transforms/forward/from_freq.py,sha256=wCiyLpzJE3rGxYjQBdXlwkxPIRYhQWjKq0C_8zYlmDk,2697
|
9
|
+
pywavelet/transforms/forward/from_time.py,sha256=-Y6VEKwDCYBAHAjLdO46vT-6alpM5fXTgTZ_xkYxqA8,2381
|
10
|
+
pywavelet/transforms/forward/main.py,sha256=Gfy0sp-woy_3ihKMzuk2WuZ7dRk-Mm6sp5dVpYrSvj4,4005
|
11
|
+
pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
|
12
|
+
pywavelet/transforms/inverse/main.py,sha256=Q7wUaRjB1sgqdB7dniWQGbPTWYQNnIsIrYtjsaHJEdE,3012
|
13
|
+
pywavelet/transforms/inverse/to_freq.py,sha256=so_TDbwdS1N8sd1QcpeAEkI10XFDtoFJGohtD4YulZM,2809
|
14
|
+
pywavelet/transforms/inverse/to_time.py,sha256=w5vmImdsb_4YeInZtXh0llsThLTxS0tmYDlNGJ-IUew,5080
|
15
|
+
pywavelet/types/__init__.py,sha256=5YptzQvYBnRfC8N5lpOBf9I1lzpJ0pw0QMnvIcwP3YI,122
|
16
|
+
pywavelet/types/common.py,sha256=OSAW6GqLTgqJ-RYEv__XbzsfFd8AFo5w-ctXQ4XAFZo,1317
|
17
|
+
pywavelet/types/frequencyseries.py,sha256=UqcE6UQfw5HZm4na2q9k-X-mfqO-BCiTAvGjaYpSrwc,7518
|
18
|
+
pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
|
19
|
+
pywavelet/types/timeseries.py,sha256=6DPO0xLi4Dq2srhJLmavFMf4fYIC3wwdbyMU7lMdjTo,9446
|
20
|
+
pywavelet/types/wavelet.py,sha256=ptTEnq6nRZiW2x6g_NV_FuoeVNOsbNOu6caSxYDZNgk,12583
|
21
|
+
pywavelet/types/wavelet_bins.py,sha256=SC9nhyigWvOfs2TbH8-Ck_iS1Mrz6sG-8EaIFYLIHuk,1453
|
22
|
+
pywavelet-0.1.1.dist-info/METADATA,sha256=OgKTqqZQKfhCCcsYpcXnrut-2EAH42c3tCpCGc0PIDg,1307
|
23
|
+
pywavelet-0.1.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
24
|
+
pywavelet-0.1.1.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
|
25
|
+
pywavelet-0.1.1.dist-info/RECORD,,
|
pywavelet-0.0.5.dist-info/RECORD
DELETED
@@ -1,25 +0,0 @@
|
|
1
|
-
pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
|
2
|
-
pywavelet/_version.py,sha256=EJB7__SNK9kQS_SWZB_U4DHJ3P8ftF6etZEihTYnuXE,411
|
3
|
-
pywavelet/logger.py,sha256=u5Et6QLUU0IuTnPnxfev5-4GYmihTo7_yBCTf-rVUsA,377
|
4
|
-
pywavelet/utils.py,sha256=rbmZ-PrnFstq05kh6xj8emuQ3-7oXqZinU03bpg3N7A,1746
|
5
|
-
pywavelet/transforms/__init__.py,sha256=FolK8WiVEJmGDC9xMupYVI_essXaXS4LYWKbqEqGx6o,289
|
6
|
-
pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
|
7
|
-
pywavelet/transforms/forward/__init__.py,sha256=Yq4Tg3Ze98-17C9FIkOqMUdiLHe9x_YoyuRvxOxMOP0,176
|
8
|
-
pywavelet/transforms/forward/from_freq.py,sha256=HWADciv746P5zWNhnyRKDEDACTQnwDTR4usf2k0VIWk,2663
|
9
|
-
pywavelet/transforms/forward/from_time.py,sha256=zfpvmq7UHnynzAnsM_Pf4MXpO1GJ0TaCCERQKNsaBNU,2340
|
10
|
-
pywavelet/transforms/forward/main.py,sha256=FJSLUUsk3I91lUeB2YFiXQdX8gB9maFW60e4tgrd45A,4053
|
11
|
-
pywavelet/transforms/forward/wavelet_bins.py,sha256=43CHVSjNgC59Af1_h9t33VBBqGHJmIUaOUhJgekhMnc,1419
|
12
|
-
pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
|
13
|
-
pywavelet/transforms/inverse/main.py,sha256=l5yFvzmWObrO5Xt_8KYp62w829ab0pQLxcv0-QxkvG0,3015
|
14
|
-
pywavelet/transforms/inverse/to_freq.py,sha256=SExZMax-8A-tJpIA86pYY61X2qvlZ2MrZY27uzCQSV0,2778
|
15
|
-
pywavelet/transforms/inverse/to_time.py,sha256=BAYvrr41QHIbzwYMPyMnzv5mqSx40YigmBruWBwtZwc,5041
|
16
|
-
pywavelet/transforms/types/__init__.py,sha256=4wUTVBk6A02xjZUY_w056eUZurYI9vVfa--I3Q6Udng,109
|
17
|
-
pywavelet/transforms/types/common.py,sha256=sLBn2d9cuL6NYeIv6NIogDjY6rYPZAPzCtWGwMlAwkI,1290
|
18
|
-
pywavelet/transforms/types/frequencyseries.py,sha256=Rtwt486UL0-TgMAdcMMVpyfi5PSLzxFcdE_RsYuyxQk,7463
|
19
|
-
pywavelet/transforms/types/plotting.py,sha256=aEIFoSuQRYwYc2639yLkbujXx_aav5A5tXG29rOSMOQ,10275
|
20
|
-
pywavelet/transforms/types/timeseries.py,sha256=Nl1tiKZ7kwu-EZ5JtubwgzgyjQZR86eGU3C3kElIPNg,9296
|
21
|
-
pywavelet/transforms/types/wavelet.py,sha256=raODhyegBd1_esuv2YP6tK_Nj9fYaLQ6O4pxpiFAVZU,10790
|
22
|
-
pywavelet-0.0.5.dist-info/METADATA,sha256=BaLK8jbUqss_8BeGWOlf_031PoCAlEV4rmuSaitQd0E,1307
|
23
|
-
pywavelet-0.0.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
24
|
-
pywavelet-0.0.5.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
|
25
|
-
pywavelet-0.0.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|