pywavelet 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/__init__.py +17 -5
- pywavelet/_version.py +2 -2
- pywavelet/backend.py +98 -20
- pywavelet/logger.py +6 -6
- pywavelet/transforms/__init__.py +1 -3
- pywavelet/transforms/cupy/forward/from_freq.py +64 -67
- pywavelet/transforms/cupy/forward/main.py +11 -7
- pywavelet/transforms/cupy/inverse/to_freq.py +54 -50
- pywavelet/transforms/jax/forward/from_freq.py +69 -76
- pywavelet/transforms/jax/forward/main.py +9 -6
- pywavelet/transforms/jax/inverse/to_freq.py +17 -28
- pywavelet/transforms/numpy/forward/from_freq.py +14 -6
- pywavelet/transforms/numpy/forward/main.py +13 -4
- pywavelet/transforms/phi_computer.py +35 -20
- pywavelet/types/common.py +1 -1
- pywavelet/types/plotting.py +1 -1
- pywavelet/types/timeseries.py +1 -0
- pywavelet/types/wavelet.py +4 -2
- pywavelet/types/wavelet_bins.py +3 -9
- pywavelet/utils/__init__.py +6 -0
- pywavelet/{utils.py → utils/analysis.py} +1 -1
- pywavelet/utils/timing_cli/__init__.py +0 -0
- pywavelet/utils/timing_cli/cli.py +95 -0
- pywavelet/utils/timing_cli/collect_runtimes.py +192 -0
- pywavelet/utils/timing_cli/plot.py +76 -0
- {pywavelet-0.2.7.dist-info → pywavelet-0.2.8.dist-info}/METADATA +3 -1
- pywavelet-0.2.8.dist-info/RECORD +49 -0
- {pywavelet-0.2.7.dist-info → pywavelet-0.2.8.dist-info}/WHEEL +1 -1
- pywavelet-0.2.8.dist-info/entry_points.txt +2 -0
- pywavelet-0.2.7.dist-info/RECORD +0 -43
- {pywavelet-0.2.7.dist-info → pywavelet-0.2.8.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,28 @@
|
|
1
1
|
from functools import partial
|
2
2
|
|
3
|
+
import jax
|
3
4
|
import jax.numpy as jnp
|
4
5
|
from jax import jit
|
5
6
|
from jax.numpy.fft import ifft
|
6
7
|
|
8
|
+
X64_PRECISION = jax.config.jax_enable_x64
|
9
|
+
|
10
|
+
CMPLX_DTYPE = jnp.complex128 if X64_PRECISION else jnp.complex64
|
11
|
+
|
12
|
+
|
7
13
|
import logging
|
8
14
|
|
9
|
-
logger = logging.getLogger(
|
15
|
+
logger = logging.getLogger("pywavelet")
|
10
16
|
|
11
17
|
|
12
|
-
@partial(jit, static_argnames=("Nf", "Nt"))
|
18
|
+
@partial(jit, static_argnames=("Nf", "Nt", "float_dtype", "complex_dtype"))
|
13
19
|
def transform_wavelet_freq_helper(
|
14
|
-
data: jnp.ndarray,
|
20
|
+
data: jnp.ndarray,
|
21
|
+
Nf: int,
|
22
|
+
Nt: int,
|
23
|
+
phif: jnp.ndarray,
|
24
|
+
float_dtype=jnp.float64,
|
25
|
+
complex_dtype=jnp.complex128,
|
15
26
|
) -> jnp.ndarray:
|
16
27
|
"""
|
17
28
|
Transforms input data from the frequency domain to the wavelet domain using a
|
@@ -26,78 +37,60 @@ def transform_wavelet_freq_helper(
|
|
26
37
|
Returns:
|
27
38
|
- wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
|
28
39
|
"""
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
#
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
40
|
+
logger.debug(
|
41
|
+
f"[JAX TRANSFORM] Input types [data:{type(data)},{data.dtype}, phif:{type(phif)},{phif.dtype}]"
|
42
|
+
)
|
43
|
+
half = Nt // 2
|
44
|
+
f_bins = jnp.arange(Nf + 1) # [0,1,...,Nf]
|
45
|
+
|
46
|
+
# --- 1) build the full (Nf+1, Nt) DX array ---
|
47
|
+
# center (j = 0):
|
48
|
+
center = phif[0] * data[f_bins * half]
|
49
|
+
center = jnp.where((f_bins == 0) | (f_bins == Nf), center / 2.0, center)
|
50
|
+
DX = jnp.zeros((Nf + 1, Nt), complex_dtype)
|
51
|
+
DX = DX.at[:, half].set(center)
|
52
|
+
|
53
|
+
# off-center (j = +/-1...+/-(half−1))
|
54
|
+
offs = jnp.arange(1 - half, half) # length Nt−1
|
55
|
+
jj = f_bins[:, None] * half + offs[None, :] # shape (Nf+1, Nt−1)
|
56
|
+
ii = half + offs # shape (Nt−1,)
|
57
|
+
mask = ((f_bins[:, None] == Nf) & (offs[None, :] > 0)) | (
|
58
|
+
(f_bins[:, None] == 0) & (offs[None, :] < 0)
|
59
|
+
)
|
60
|
+
vals = phif[jnp.abs(offs)] * data[jj]
|
61
|
+
vals = jnp.where(mask, 0.0, vals)
|
62
|
+
DX = DX.at[:, ii].set(vals)
|
63
|
+
|
64
|
+
# --- 2) ifft along time axis ---
|
65
|
+
DXt = jnp.fft.ifft(DX, n=Nt, axis=1)
|
66
|
+
|
67
|
+
# --- 3) unpack into wave (Nt, Nf) ---
|
68
|
+
wave = jnp.zeros((Nt, Nf), float_dtype)
|
69
|
+
sqrt2 = jnp.sqrt(2.0)
|
70
|
+
|
71
|
+
# 3a) DC into col 0, even rows
|
72
|
+
evens = jnp.arange(0, Nt, 2)
|
73
|
+
wave = wave.at[evens, 0].set(jnp.real(DXt[0, evens]) * sqrt2)
|
74
|
+
|
75
|
+
# 3b) Nyquist into col 0, odd rows
|
76
|
+
odds = jnp.arange(1, Nt, 2)
|
77
|
+
wave = wave.at[odds, 0].set(jnp.real(DXt[Nf, evens]) * sqrt2)
|
78
|
+
|
79
|
+
# 3c) intermediate bins 1...Nf−1
|
80
|
+
mids = jnp.arange(1, Nf) # [1...Nf-1]
|
81
|
+
Dt_mid = DXt[mids, :] # shape (Nf-1, Nt)
|
82
|
+
real_m = jnp.real(Dt_mid).T # (Nt, Nf-1)
|
83
|
+
imag_m = jnp.imag(Dt_mid).T # (Nt, Nf-1)
|
84
|
+
|
85
|
+
odd_f = (mids % 2) == 1 # shape (Nf-1,)
|
86
|
+
n_idx = jnp.arange(Nt)[:, None] # (Nt,1)
|
87
|
+
odd_n_f = ((n_idx + mids[None, :]) % 2) == 1 # (Nt, Nf-1)
|
88
|
+
|
89
|
+
mid_vals = jnp.where(
|
90
|
+
odd_n_f,
|
91
|
+
jnp.where(odd_f, -imag_m, imag_m),
|
92
|
+
jnp.where(odd_f, real_m, real_m),
|
46
93
|
)
|
94
|
+
wave = wave.at[:, 1:Nf].set(mid_vals)
|
47
95
|
|
48
|
-
|
49
|
-
DX = jnp.zeros(
|
50
|
-
(Nf, Nt), dtype=jnp.complex64
|
51
|
-
) # TODO: Check dtype -- is complex64 sufficient?
|
52
|
-
DX = DX.at[:, Nt // 2].set(
|
53
|
-
initial_values
|
54
|
-
) # Set initial values at the center of the transformation (2 sided FFT)
|
55
|
-
|
56
|
-
# Compute time indices for all offsets around the midpoint
|
57
|
-
j_range = jnp.arange(
|
58
|
-
1 - Nt // 2, Nt // 2
|
59
|
-
) # Time offsets (centered around zero)
|
60
|
-
j = jnp.abs(j_range) # Absolute offset indices
|
61
|
-
i = i_base + j_range # Time indices in output array
|
62
|
-
|
63
|
-
# Compute conditions for edge cases
|
64
|
-
cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
|
65
|
-
cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
|
66
|
-
cond3 = j[None, :] == 0 # Center of the transformation (no offset)
|
67
|
-
|
68
|
-
# Compute frequency indices for the input data
|
69
|
-
jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
|
70
|
-
val = jnp.where(
|
71
|
-
cond1 | cond2, 0.0, phif[j] * data[jj]
|
72
|
-
) # Wavelet filter application
|
73
|
-
DX = DX.at[:, i].set(
|
74
|
-
jnp.where(cond3, DX[:, i], val)
|
75
|
-
) # Update DX with computed values
|
76
|
-
# At this point, DX contains the data FFT'd with the wavelet filter
|
77
|
-
# (each row is a frequency bin, each column is a time bin)
|
78
|
-
|
79
|
-
# Perform the inverse FFT along the time dimension
|
80
|
-
DX_trans = ifft(DX, axis=1)
|
81
|
-
|
82
|
-
# Fill the wavelet output array based on the inverse FFT results
|
83
|
-
n_range = jnp.arange(Nt) # Time indices
|
84
|
-
cond1 = (
|
85
|
-
n_range[:, None] + f_bins[None, :]
|
86
|
-
) % 2 == 1 # Odd/even alternation
|
87
|
-
cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
|
88
|
-
|
89
|
-
# Assign real and imaginary parts based on conditions
|
90
|
-
real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
|
91
|
-
imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
|
92
|
-
wave = jnp.where(cond1, imag_part.T, real_part.T)
|
93
|
-
|
94
|
-
# Special cases for frequency bins 0 (DC) and Nf (Nyquist)
|
95
|
-
wave = wave.at[::2, 0].set(
|
96
|
-
jnp.real(DX_trans[0, ::2] * jnp.sqrt(2))
|
97
|
-
) # DC component
|
98
|
-
wave = wave.at[1::2, -1].set(
|
99
|
-
jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2))
|
100
|
-
) # Nyquist component
|
101
|
-
|
102
|
-
# Return the wavelet-transformed array (transposed for freq-major layout)
|
103
|
-
return wave.T # (Nt, Nf) -> (Nf, Nt)
|
96
|
+
return wave.T
|
@@ -2,6 +2,7 @@ from typing import Union
|
|
2
2
|
|
3
3
|
import jax.numpy as jnp
|
4
4
|
|
5
|
+
from .... import backend
|
5
6
|
from ....logger import logger
|
6
7
|
from ....types import FrequencySeries, TimeSeries, Wavelet
|
7
8
|
from ....types.wavelet_bins import _get_bins, _preprocess_bins
|
@@ -16,7 +17,6 @@ def from_time_to_wavelet(
|
|
16
17
|
Nt: Union[int, None] = None,
|
17
18
|
nx: float = 4.0,
|
18
19
|
mult: int = 32,
|
19
|
-
**kwargs,
|
20
20
|
) -> Wavelet:
|
21
21
|
"""Transforms time-domain data to wavelet-domain data.
|
22
22
|
|
@@ -45,7 +45,6 @@ def from_time_to_wavelet(
|
|
45
45
|
|
46
46
|
"""
|
47
47
|
Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
|
48
|
-
dt = timeseries.dt
|
49
48
|
t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
|
50
49
|
|
51
50
|
ND = Nf * Nt
|
@@ -73,7 +72,6 @@ def from_freq_to_wavelet(
|
|
73
72
|
Nf: Union[int, None] = None,
|
74
73
|
Nt: Union[int, None] = None,
|
75
74
|
nx: float = 4.0,
|
76
|
-
**kwargs,
|
77
75
|
) -> Wavelet:
|
78
76
|
"""Transforms frequency-domain data to wavelet-domain data.
|
79
77
|
|
@@ -100,7 +98,12 @@ def from_freq_to_wavelet(
|
|
100
98
|
t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
|
101
99
|
phif = jnp.array(phitilde_vec_norm(Nf, Nt, d=nx))
|
102
100
|
wave = transform_wavelet_freq_helper(
|
103
|
-
freqseries.data,
|
101
|
+
freqseries.data,
|
102
|
+
Nf=Nf,
|
103
|
+
Nt=Nt,
|
104
|
+
phif=phif,
|
105
|
+
float_dtype=backend.float_dtype,
|
106
|
+
complex_dtype=backend.complex_dtype,
|
104
107
|
)
|
105
|
-
|
106
|
-
return Wavelet(
|
108
|
+
factor = (2 / Nf) * jnp.sqrt(2)
|
109
|
+
return Wavelet(factor * wave, time=t_bins, freq=f_bins)
|
@@ -11,73 +11,56 @@ def inverse_wavelet_freq_helper(
|
|
11
11
|
wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
|
12
12
|
) -> jnp.ndarray:
|
13
13
|
"""JAX vectorized function for inverse_wavelet_freq with corrected shapes and ranges."""
|
14
|
-
# Transpose to match the NumPy version.
|
15
14
|
wave_in = wave_in.T
|
16
15
|
ND = Nf * Nt
|
17
16
|
|
18
|
-
# Allocate prefactor2s for each m (shape: (Nf+1, Nt)).
|
19
17
|
m_range = jnp.arange(Nf + 1)
|
20
18
|
prefactor2s = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex128)
|
21
19
|
n_range = jnp.arange(Nt)
|
22
20
|
|
23
|
-
# m
|
21
|
+
# Handle m=0 and m=Nf cases
|
24
22
|
prefactor2s = prefactor2s.at[0].set(
|
25
23
|
2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
|
26
24
|
)
|
27
|
-
|
28
|
-
# m == Nf case
|
29
25
|
prefactor2s = prefactor2s.at[Nf].set(
|
30
26
|
2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
|
31
27
|
)
|
32
28
|
|
33
|
-
#
|
29
|
+
# Handle middle m cases
|
34
30
|
m_mid = m_range[1:Nf]
|
35
|
-
|
36
|
-
n_grid, m_grid = jnp.meshgrid(n_range, m_mid)
|
37
|
-
# Index the transposed wave_in using (n, m) as in the NumPy version.
|
31
|
+
n_grid, m_grid = jnp.meshgrid(n_range, m_mid, indexing="ij")
|
38
32
|
val = wave_in[n_grid, m_grid]
|
39
|
-
# Apply the alternating multiplier based on (n+m) parity.
|
40
33
|
mult2 = jnp.where((n_grid + m_grid) % 2, -1j, 1)
|
41
|
-
prefactor2s = prefactor2s.at[1:Nf].set(mult2 * val)
|
34
|
+
prefactor2s = prefactor2s.at[1:Nf].set((mult2 * val).T)
|
42
35
|
|
43
|
-
# Apply FFT along axis 1 for all m.
|
44
36
|
fft_prefactor2s = fft(prefactor2s, axis=1)
|
45
37
|
|
46
|
-
# Allocate the result array with corrected shape.
|
47
38
|
res = jnp.zeros(ND // 2 + 1, dtype=jnp.complex128)
|
48
39
|
|
49
|
-
#
|
40
|
+
# Unpack for m=0 and m=Nf
|
50
41
|
i_ind_range = jnp.arange(Nt // 2)
|
51
|
-
i_0 =
|
52
|
-
i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
|
42
|
+
i_0 = i_ind_range
|
53
43
|
ind3_0 = (2 * i_0) % Nt
|
54
|
-
ind3_Nf = (2 * i_Nf) % Nt
|
55
|
-
|
56
44
|
res = res.at[i_0].add(fft_prefactor2s[0, ind3_0] * phif[i_ind_range])
|
45
|
+
|
46
|
+
i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
|
47
|
+
ind3_Nf = (2 * i_Nf) % Nt
|
57
48
|
res = res.at[i_Nf].add(fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range])
|
58
|
-
|
49
|
+
|
59
50
|
special_index = jnp.abs(Nf * (Nt // 2) - (Nt // 2))
|
60
51
|
res = res.at[special_index].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
|
61
52
|
|
62
|
-
#
|
53
|
+
# Unpack for middle m values
|
63
54
|
m_mid = m_range[1:Nf]
|
64
|
-
# Use range [0, Nt//2) to match the loop in NumPy version.
|
65
55
|
i_ind_range_mid = jnp.arange(Nt // 2)
|
66
|
-
# Create meshgrid for vectorized computation.
|
67
56
|
m_grid_mid, i_ind_grid_mid = jnp.meshgrid(
|
68
57
|
m_mid, i_ind_range_mid, indexing="ij"
|
69
58
|
)
|
70
|
-
|
71
|
-
# Compute indices i1 and i2 following the NumPy logic.
|
72
59
|
i1 = (Nt // 2) * m_grid_mid - i_ind_grid_mid
|
73
60
|
i2 = (Nt // 2) * m_grid_mid + i_ind_grid_mid
|
74
|
-
# Compute the wrapped indices for FFT results.
|
75
61
|
ind31 = i1 % Nt
|
76
62
|
ind32 = i2 % Nt
|
77
63
|
|
78
|
-
# Update result array using vectorized adds.
|
79
|
-
# Note: You might need to adjust this further if your target res shape is non-trivial,
|
80
|
-
# because here we assume that i1 and i2 indices fall within the allocated result shape.
|
81
64
|
res = res.at[i1].add(
|
82
65
|
fft_prefactor2s[m_grid_mid, ind31] * phif[i_ind_grid_mid]
|
83
66
|
)
|
@@ -85,4 +68,10 @@ def inverse_wavelet_freq_helper(
|
|
85
68
|
fft_prefactor2s[m_grid_mid, ind32] * phif[i_ind_grid_mid]
|
86
69
|
)
|
87
70
|
|
71
|
+
# Correct the center points for middle m's
|
72
|
+
center_indices = (Nt // 2) * m_mid
|
73
|
+
fft_indices = center_indices % Nt
|
74
|
+
values = fft_prefactor2s[m_mid, fft_indices] * phif[0]
|
75
|
+
res = res.at[center_indices].set(values)
|
76
|
+
|
88
77
|
return res
|
@@ -12,7 +12,12 @@ logger = logging.getLogger("pywavelet")
|
|
12
12
|
|
13
13
|
|
14
14
|
def transform_wavelet_freq_helper(
|
15
|
-
data: np.ndarray,
|
15
|
+
data: np.ndarray,
|
16
|
+
Nf: int,
|
17
|
+
Nt: int,
|
18
|
+
phif: np.ndarray,
|
19
|
+
float_dtype: np.dtype = np.float64,
|
20
|
+
complex_dtype: np.dtype = np.complex128,
|
16
21
|
) -> np.ndarray:
|
17
22
|
"""
|
18
23
|
Forward wavelet transform helper using the fast wavelet domain transform,
|
@@ -36,14 +41,14 @@ def transform_wavelet_freq_helper(
|
|
36
41
|
f_bin==0 and f_bin==Nf are both stored in column 0.
|
37
42
|
"""
|
38
43
|
logger.debug(
|
39
|
-
f"[NUMPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]"
|
44
|
+
f"[NUMPY TRANSFORM] Input types [data:{type(data)},{data.dtype}, phif:{type(phif)},{phif.dtype}]"
|
40
45
|
)
|
41
|
-
wave = np.zeros((Nt, Nf), dtype=
|
42
|
-
DX = np.zeros(Nt, dtype=
|
46
|
+
wave = np.zeros((Nt, Nf), dtype=float_dtype)
|
47
|
+
DX = np.zeros(Nt, dtype=complex_dtype)
|
43
48
|
# Create a copy of the input data (if needed).
|
44
49
|
freq_strain = data.copy()
|
45
50
|
__core(Nf, Nt, DX, freq_strain, phif, wave)
|
46
|
-
return wave
|
51
|
+
return wave.T
|
47
52
|
|
48
53
|
|
49
54
|
@njit()
|
@@ -64,7 +69,10 @@ def __core(
|
|
64
69
|
for f_bin in range(0, Nf + 1):
|
65
70
|
__fill_wave_1(f_bin, Nt, Nf, DX, data, phif)
|
66
71
|
# Use rocket-fft's ifft (which is JIT-able) instead of np.fft.ifft.
|
67
|
-
DX_trans = np.fft.ifft(
|
72
|
+
DX_trans = np.fft.ifft(
|
73
|
+
DX,
|
74
|
+
Nt,
|
75
|
+
)
|
68
76
|
__fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
|
69
77
|
|
70
78
|
|
@@ -2,6 +2,7 @@ from typing import Union
|
|
2
2
|
|
3
3
|
import numpy as np
|
4
4
|
|
5
|
+
from .... import backend
|
5
6
|
from ....logger import logger
|
6
7
|
from ....types import FrequencySeries, TimeSeries, Wavelet
|
7
8
|
from ....types.wavelet_bins import _get_bins, _preprocess_bins
|
@@ -112,7 +113,15 @@ def from_freq_to_wavelet(
|
|
112
113
|
"""
|
113
114
|
Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
|
114
115
|
t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
|
115
|
-
phif = phitilde_vec_norm(Nf, Nt, d=nx)
|
116
|
-
|
117
|
-
|
118
|
-
|
116
|
+
phif = np.array(phitilde_vec_norm(Nf, Nt, d=nx), dtype=backend.float_dtype)
|
117
|
+
data = np.array(freqseries.data, dtype=backend.complex_dtype)
|
118
|
+
wave = transform_wavelet_freq_helper(
|
119
|
+
data,
|
120
|
+
Nf,
|
121
|
+
Nt,
|
122
|
+
phif,
|
123
|
+
float_dtype=backend.float_dtype,
|
124
|
+
complex_dtype=backend.complex_dtype,
|
125
|
+
)
|
126
|
+
factor = (backend.float_dtype)((2 / Nf) * np.sqrt(2))
|
127
|
+
return Wavelet(factor * wave, time=t_bins, freq=f_bins)
|
@@ -1,25 +1,43 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
"""
|
2
|
+
This module contains functions to compute the Fourier transform of the
|
3
|
+
wavelet function and its normalization. The wavelet function is defined
|
4
|
+
in the frequency domain and is used to transform time-domain data into
|
5
|
+
the wavelet domain.
|
6
|
+
|
7
|
+
Everything in this module is retured as a npfloat64 array.
|
8
|
+
"""
|
3
9
|
|
4
|
-
from
|
10
|
+
from functools import lru_cache
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from jaxtyping import Float64
|
14
|
+
from numpy.fft import ifft
|
15
|
+
from scipy.special import betainc
|
5
16
|
|
6
17
|
__all__ = ["phitilde_vec_norm", "phi_vec", "omega"]
|
7
18
|
|
8
19
|
|
9
|
-
|
20
|
+
@lru_cache(maxsize=None)
|
21
|
+
def omega(Nf: int, Nt: int) -> Float64[np.ndarray, "{Nt}//2+1"]:
|
10
22
|
"""Get the angular frequencies of the time domain signal."""
|
11
23
|
df = 2 * np.pi / (Nf * Nt)
|
12
|
-
return df * np.arange(0, Nt // 2 + 1)
|
24
|
+
return df * np.arange(0, Nt // 2 + 1, dtype=np.float64)
|
13
25
|
|
14
26
|
|
15
|
-
|
27
|
+
@lru_cache(maxsize=None)
|
28
|
+
def phitilde_vec_norm(
|
29
|
+
Nf: int, Nt: int, d: float
|
30
|
+
) -> Float64[np.ndarray, "{Nt}//2+1"]:
|
16
31
|
"""Normalize phitilde for inverse frequency domain transform."""
|
17
32
|
omegas = omega(Nf, Nt)
|
18
33
|
_phi_t = _phitilde_vec(omegas, Nf, d) * np.sqrt(np.pi)
|
19
|
-
return
|
34
|
+
return np.array(_phi_t)
|
20
35
|
|
21
36
|
|
22
|
-
|
37
|
+
@lru_cache(maxsize=None)
|
38
|
+
def phi_vec(
|
39
|
+
Nf: int, d: float = 4.0, q: int = 16
|
40
|
+
) -> Float64[np.ndarray, "2*{q}*{Nf}"]:
|
23
41
|
"""get time domain phi as fourier transform of _phitilde_vec
|
24
42
|
q: number of Nf bins over which the window extends?
|
25
43
|
|
@@ -29,7 +47,6 @@ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
|
|
29
47
|
half_K = q * Nf # xp.int64(K/2)
|
30
48
|
|
31
49
|
dom = 2 * np.pi / K # max frequency is K/2*dom = pi/dt = OM
|
32
|
-
|
33
50
|
DX = np.zeros(K, dtype=np.complex128)
|
34
51
|
|
35
52
|
# zero frequency
|
@@ -51,12 +68,12 @@ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
|
|
51
68
|
nrm = np.sqrt(2.0) / np.sqrt(K / dom) # *xp.linalg.norm(phi)
|
52
69
|
|
53
70
|
phi *= nrm
|
54
|
-
return
|
71
|
+
return np.array(phi)
|
55
72
|
|
56
73
|
|
57
74
|
def _phitilde_vec(
|
58
|
-
omega:
|
59
|
-
) ->
|
75
|
+
omega: Float64[np.ndarray, "dim"], Nf: int, d: float = 4.0
|
76
|
+
) -> Float64[np.ndarray, "dim"]:
|
60
77
|
"""Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
|
61
78
|
|
62
79
|
Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
|
@@ -67,8 +84,6 @@ def _phitilde_vec(
|
|
67
84
|
|
68
85
|
Where nu_d = normalized incomplete beta function
|
69
86
|
|
70
|
-
|
71
|
-
|
72
87
|
Parameters
|
73
88
|
----------
|
74
89
|
ω : xp.ndarray
|
@@ -93,22 +108,22 @@ def _phitilde_vec(
|
|
93
108
|
if B <= 0:
|
94
109
|
raise ValueError("B must be greater than 0")
|
95
110
|
|
96
|
-
phi = np.zeros(omega.size)
|
111
|
+
phi = np.zeros(omega.size, dtype=np.float64)
|
97
112
|
mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B) # Minor changes
|
98
113
|
vd = (np.pi / 2.0) * _nu_d(omega[mask], A, B, d=d) # different from paper
|
99
|
-
phi[mask] = inverse_sqrt_dOmega *
|
114
|
+
phi[mask] = inverse_sqrt_dOmega * np.cos(vd)
|
100
115
|
phi[np.abs(omega) < A] = inverse_sqrt_dOmega
|
101
116
|
return phi
|
102
117
|
|
103
118
|
|
104
119
|
def _nu_d(
|
105
|
-
omega:
|
106
|
-
) ->
|
120
|
+
omega: Float64[np.ndarray, "dim"], A: float, B: float, d: float = 4.0
|
121
|
+
) -> Float64[np.ndarray, "dim"]:
|
107
122
|
"""Compute the normalized incomplete beta function.
|
108
123
|
|
109
124
|
Parameters
|
110
125
|
----------
|
111
|
-
|
126
|
+
omega : np.ndarray
|
112
127
|
Array of angular frequencies
|
113
128
|
A : float
|
114
129
|
Lower bound for the beta function
|
@@ -119,7 +134,7 @@ def _nu_d(
|
|
119
134
|
|
120
135
|
Returns
|
121
136
|
-------
|
122
|
-
|
137
|
+
np.ndarray
|
123
138
|
Array of ν_d values
|
124
139
|
|
125
140
|
scipy.special.betainc
|
pywavelet/types/common.py
CHANGED
pywavelet/types/plotting.py
CHANGED
@@ -303,7 +303,7 @@ def plot_periodogram(
|
|
303
303
|
flow = np.min(np.abs(freq))
|
304
304
|
ax.set_xlabel("Frequency [Hz]")
|
305
305
|
ax.set_ylabel("Periodigram")
|
306
|
-
|
306
|
+
ax.set_xlim(left=flow, right=nyquist_frequency / 2)
|
307
307
|
return ax.figure, ax
|
308
308
|
|
309
309
|
|
pywavelet/types/timeseries.py
CHANGED
pywavelet/types/wavelet.py
CHANGED
@@ -3,7 +3,7 @@ from typing import List, Tuple
|
|
3
3
|
import matplotlib.pyplot as plt
|
4
4
|
import numpy as np
|
5
5
|
|
6
|
-
from .common import fmt_timerange, is_documented_by, xp
|
6
|
+
from .common import float_dtype, fmt_timerange, is_documented_by, xp
|
7
7
|
from .plotting import plot_wavelet_grid, plot_wavelet_trend
|
8
8
|
from .wavelet_bins import compute_bins
|
9
9
|
|
@@ -70,7 +70,9 @@ class Wavelet:
|
|
70
70
|
A Wavelet object with zero-filled data array.
|
71
71
|
"""
|
72
72
|
Nf, Nt = len(freq), len(time)
|
73
|
-
return cls(
|
73
|
+
return cls(
|
74
|
+
data=xp.zeros((Nf, Nt), dtype=float_dtype), time=time, freq=freq
|
75
|
+
)
|
74
76
|
|
75
77
|
@classmethod
|
76
78
|
def zeros(cls, Nf: int, Nt: int, T: float) -> "Wavelet":
|
pywavelet/types/wavelet_bins.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
from typing import Tuple, Union
|
2
2
|
|
3
|
-
from ..backend import xp
|
3
|
+
from ..backend import complex_dtype, float_dtype, xp
|
4
4
|
from .frequencyseries import FrequencySeries
|
5
5
|
from .timeseries import TimeSeries
|
6
6
|
|
@@ -35,13 +35,7 @@ def _get_bins(
|
|
35
35
|
) -> Tuple[xp.ndarray, xp.ndarray]:
|
36
36
|
T = data.duration
|
37
37
|
t_bins, f_bins = compute_bins(Nf, Nt, T)
|
38
|
-
|
39
|
-
# N = len(data)
|
40
|
-
# fs = N / T
|
41
|
-
# assert delta_f == fmax / Nf, f"delta_f={delta_f} != fmax/Nf={fmax/Nf}"
|
42
|
-
|
43
38
|
t_bins += data.t0
|
44
|
-
|
45
39
|
return t_bins, f_bins
|
46
40
|
|
47
41
|
|
@@ -51,6 +45,6 @@ def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[xp.ndarray, xp.ndarray]:
|
|
51
45
|
"""
|
52
46
|
delta_T = T / Nt
|
53
47
|
delta_F = 1 / (2 * delta_T)
|
54
|
-
t_bins = xp.arange(0, Nt) * delta_T
|
55
|
-
f_bins = xp.arange(0, Nf) * delta_F
|
48
|
+
t_bins = xp.arange(0, Nt, dtype=float_dtype) * delta_T
|
49
|
+
f_bins = xp.arange(0, Nf, dtype=float_dtype) * delta_F
|
56
50
|
return t_bins, f_bins
|
@@ -3,7 +3,7 @@ from typing import Union
|
|
3
3
|
import numpy as np
|
4
4
|
from scipy.interpolate import interp1d
|
5
5
|
|
6
|
-
from
|
6
|
+
from ..types import FrequencySeries, TimeSeries, Wavelet, WaveletMask
|
7
7
|
|
8
8
|
DATA_TYPE = Union[TimeSeries, FrequencySeries, Wavelet]
|
9
9
|
|
File without changes
|