pywavelet 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/__init__.py +22 -0
- pywavelet/_version.py +9 -4
- pywavelet/backend.py +87 -17
- pywavelet/transforms/__init__.py +16 -4
- pywavelet/transforms/cupy/__init__.py +12 -0
- pywavelet/transforms/cupy/forward/__init__.py +3 -0
- pywavelet/transforms/cupy/forward/from_freq.py +92 -0
- pywavelet/transforms/cupy/forward/from_time.py +50 -0
- pywavelet/transforms/cupy/forward/main.py +106 -0
- pywavelet/transforms/cupy/inverse/__init__.py +3 -0
- pywavelet/transforms/cupy/inverse/main.py +67 -0
- pywavelet/transforms/cupy/inverse/to_freq.py +62 -0
- pywavelet/transforms/jax/__init__.py +24 -0
- pywavelet/transforms/jax/forward/from_freq.py +6 -0
- pywavelet/transforms/jax/inverse/main.py +4 -6
- pywavelet/transforms/jax/inverse/to_freq.py +39 -25
- pywavelet/transforms/numpy/forward/from_freq.py +77 -19
- pywavelet/transforms/numpy/forward/main.py +0 -1
- pywavelet/transforms/numpy/inverse/main.py +2 -4
- pywavelet/transforms/numpy/inverse/to_freq.py +64 -1
- pywavelet/transforms/phi_computer.py +68 -80
- pywavelet/types/plotting.py +14 -5
- pywavelet/types/wavelet.py +6 -6
- {pywavelet-0.2.5.dist-info → pywavelet-0.2.7.dist-info}/METADATA +18 -7
- pywavelet-0.2.7.dist-info/RECORD +43 -0
- {pywavelet-0.2.5.dist-info → pywavelet-0.2.7.dist-info}/WHEEL +1 -1
- pywavelet-0.2.5.dist-info/RECORD +0 -35
- {pywavelet-0.2.5.dist-info → pywavelet-0.2.7.dist-info}/top_level.txt +0 -0
@@ -10,13 +10,14 @@ from jax.numpy.fft import fft
|
|
10
10
|
def inverse_wavelet_freq_helper(
|
11
11
|
wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
|
12
12
|
) -> jnp.ndarray:
|
13
|
-
"""JAX vectorized function for inverse_wavelet_freq"""
|
13
|
+
"""JAX vectorized function for inverse_wavelet_freq with corrected shapes and ranges."""
|
14
|
+
# Transpose to match the NumPy version.
|
14
15
|
wave_in = wave_in.T
|
15
16
|
ND = Nf * Nt
|
16
17
|
|
18
|
+
# Allocate prefactor2s for each m (shape: (Nf+1, Nt)).
|
17
19
|
m_range = jnp.arange(Nf + 1)
|
18
20
|
prefactor2s = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex128)
|
19
|
-
|
20
21
|
n_range = jnp.arange(Nt)
|
21
22
|
|
22
23
|
# m == 0 case
|
@@ -29,46 +30,59 @@ def inverse_wavelet_freq_helper(
|
|
29
30
|
2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
|
30
31
|
)
|
31
32
|
|
32
|
-
# Other m cases
|
33
|
+
# Other m cases: use meshgrid for vectorization.
|
33
34
|
m_mid = m_range[1:Nf]
|
35
|
+
# Create grids: n_grid (columns) and m_grid (rows)
|
34
36
|
n_grid, m_grid = jnp.meshgrid(n_range, m_mid)
|
37
|
+
# Index the transposed wave_in using (n, m) as in the NumPy version.
|
35
38
|
val = wave_in[n_grid, m_grid]
|
39
|
+
# Apply the alternating multiplier based on (n+m) parity.
|
36
40
|
mult2 = jnp.where((n_grid + m_grid) % 2, -1j, 1)
|
37
41
|
prefactor2s = prefactor2s.at[1:Nf].set(mult2 * val)
|
38
42
|
|
39
|
-
#
|
43
|
+
# Apply FFT along axis 1 for all m.
|
40
44
|
fft_prefactor2s = fft(prefactor2s, axis=1)
|
41
45
|
|
42
|
-
#
|
43
|
-
|
44
|
-
# ND or ND // 2 + 1?
|
45
|
-
# https://github.com/pywavelet/pywavelet/blob/63151a47cde9edc14f1e7e0bf17f554e78ad257c/src/pywavelet/transforms/from_wavelets/inverse_wavelet_freq_funcs.py
|
46
|
-
res = jnp.zeros(ND, dtype=jnp.complex128)
|
46
|
+
# Allocate the result array with corrected shape.
|
47
|
+
res = jnp.zeros(ND // 2 + 1, dtype=jnp.complex128)
|
47
48
|
|
48
|
-
# m == 0
|
49
|
+
# Unpacking for m == 0 and m == Nf cases:
|
49
50
|
i_ind_range = jnp.arange(Nt // 2)
|
50
|
-
i_0 = jnp.abs(i_ind_range)
|
51
|
-
i_Nf = jnp.abs(Nf * Nt // 2 - i_ind_range)
|
51
|
+
i_0 = jnp.abs(i_ind_range) # for m == 0: i = i_ind_range
|
52
|
+
i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
|
52
53
|
ind3_0 = (2 * i_0) % Nt
|
53
54
|
ind3_Nf = (2 * i_Nf) % Nt
|
54
55
|
|
55
56
|
res = res.at[i_0].add(fft_prefactor2s[0, ind3_0] * phif[i_ind_range])
|
56
57
|
res = res.at[i_Nf].add(fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range])
|
58
|
+
# Special case for m == Nf (ensure the Nyquist frequency is updated correctly)
|
59
|
+
special_index = jnp.abs(Nf * (Nt // 2) - (Nt // 2))
|
60
|
+
res = res.at[special_index].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
|
57
61
|
|
58
|
-
#
|
59
|
-
res = res.at[Nf * Nt // 2].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
|
60
|
-
|
61
|
-
# Other m cases
|
62
|
+
# Unpacking for m in (1, ..., Nf-1)
|
62
63
|
m_mid = m_range[1:Nf]
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
ind32 = (Nt // 2 * m_grid + i_ind_grid) % Nt
|
64
|
+
# Use range [0, Nt//2) to match the loop in NumPy version.
|
65
|
+
i_ind_range_mid = jnp.arange(Nt // 2)
|
66
|
+
# Create meshgrid for vectorized computation.
|
67
|
+
m_grid_mid, i_ind_grid_mid = jnp.meshgrid(
|
68
|
+
m_mid, i_ind_range_mid, indexing="ij"
|
69
|
+
)
|
70
70
|
|
71
|
-
|
72
|
-
|
71
|
+
# Compute indices i1 and i2 following the NumPy logic.
|
72
|
+
i1 = (Nt // 2) * m_grid_mid - i_ind_grid_mid
|
73
|
+
i2 = (Nt // 2) * m_grid_mid + i_ind_grid_mid
|
74
|
+
# Compute the wrapped indices for FFT results.
|
75
|
+
ind31 = i1 % Nt
|
76
|
+
ind32 = i2 % Nt
|
77
|
+
|
78
|
+
# Update result array using vectorized adds.
|
79
|
+
# Note: You might need to adjust this further if your target res shape is non-trivial,
|
80
|
+
# because here we assume that i1 and i2 indices fall within the allocated result shape.
|
81
|
+
res = res.at[i1].add(
|
82
|
+
fft_prefactor2s[m_grid_mid, ind31] * phif[i_ind_grid_mid]
|
83
|
+
)
|
84
|
+
res = res.at[i2].add(
|
85
|
+
fft_prefactor2s[m_grid_mid, ind32] * phif[i_ind_grid_mid]
|
86
|
+
)
|
73
87
|
|
74
88
|
return res
|
@@ -1,32 +1,69 @@
|
|
1
1
|
"""helper functions for transform_freq"""
|
2
2
|
|
3
|
+
import logging
|
4
|
+
|
3
5
|
import numpy as np
|
4
6
|
from numba import njit
|
5
7
|
|
8
|
+
# import rocket_fft.special as rfft # JIT‐able FFT routines
|
9
|
+
|
10
|
+
|
11
|
+
logger = logging.getLogger("pywavelet")
|
12
|
+
|
6
13
|
|
7
14
|
def transform_wavelet_freq_helper(
|
8
15
|
data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
|
9
16
|
) -> np.ndarray:
|
10
|
-
"""
|
11
|
-
|
17
|
+
"""
|
18
|
+
Forward wavelet transform helper using the fast wavelet domain transform,
|
19
|
+
with a JIT-able FFT (rocket-fft) so that the whole transform is jittable.
|
20
|
+
|
21
|
+
Parameters
|
22
|
+
----------
|
23
|
+
data : np.ndarray
|
24
|
+
Input frequency-domain data (1D array).
|
25
|
+
Nf : int
|
26
|
+
Number of frequency bins.
|
27
|
+
Nt : int
|
28
|
+
Number of time bins.
|
29
|
+
phif : np.ndarray
|
30
|
+
Fourier-domain phase factors (complex-valued array of length Nt//2 + 1).
|
31
|
+
|
32
|
+
Returns
|
33
|
+
-------
|
34
|
+
wave : np.ndarray
|
35
|
+
The wavelet transform output of shape (Nt, Nf). Note that contributions from
|
36
|
+
f_bin==0 and f_bin==Nf are both stored in column 0.
|
37
|
+
"""
|
38
|
+
logger.debug(
|
39
|
+
f"[NUMPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]"
|
40
|
+
)
|
41
|
+
wave = np.zeros((Nt, Nf), dtype=np.float64)
|
12
42
|
DX = np.zeros(Nt, dtype=np.complex128)
|
13
|
-
|
43
|
+
# Create a copy of the input data (if needed).
|
44
|
+
freq_strain = data.copy()
|
14
45
|
__core(Nf, Nt, DX, freq_strain, phif, wave)
|
15
46
|
return wave
|
16
47
|
|
17
48
|
|
18
|
-
|
49
|
+
@njit()
|
19
50
|
def __core(
|
20
51
|
Nf: int,
|
21
52
|
Nt: int,
|
22
53
|
DX: np.ndarray,
|
23
|
-
|
54
|
+
data: np.ndarray,
|
24
55
|
phif: np.ndarray,
|
25
56
|
wave: np.ndarray,
|
26
57
|
):
|
58
|
+
"""
|
59
|
+
Process each frequency bin (f_bin) to compute the temporary array DX,
|
60
|
+
perform the inverse FFT using rocket-fft, and then unpack the result into wave.
|
61
|
+
|
62
|
+
This function is fully jittable.
|
63
|
+
"""
|
27
64
|
for f_bin in range(0, Nf + 1):
|
28
|
-
__fill_wave_1(f_bin, Nt, Nf, DX,
|
29
|
-
#
|
65
|
+
__fill_wave_1(f_bin, Nt, Nf, DX, data, phif)
|
66
|
+
# Use rocket-fft's ifft (which is JIT-able) instead of np.fft.ifft.
|
30
67
|
DX_trans = np.fft.ifft(DX, Nt)
|
31
68
|
__fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
|
32
69
|
|
@@ -40,24 +77,32 @@ def __fill_wave_1(
|
|
40
77
|
data: np.ndarray,
|
41
78
|
phif: np.ndarray,
|
42
79
|
) -> None:
|
43
|
-
"""
|
80
|
+
"""
|
81
|
+
Fill the temporary complex array DX for the given frequency bin (f_bin)
|
82
|
+
based on the input data and the phase factors phif.
|
83
|
+
|
84
|
+
The computation is performed over a window of indices defined by the current f_bin.
|
85
|
+
"""
|
44
86
|
i_base = Nt // 2
|
45
|
-
jj_base = f_bin * Nt // 2
|
87
|
+
jj_base = f_bin * (Nt // 2)
|
46
88
|
|
89
|
+
# Special center assignment:
|
47
90
|
if f_bin == 0 or f_bin == Nf:
|
48
|
-
|
49
|
-
DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2] / 2.0
|
91
|
+
DX[i_base] = phif[0] * data[jj_base] / 2.0
|
50
92
|
else:
|
51
|
-
DX[
|
93
|
+
DX[i_base] = phif[0] * data[jj_base]
|
52
94
|
|
53
|
-
|
95
|
+
# Determine the window of indices.
|
96
|
+
start = jj_base + 1 - (Nt // 2)
|
97
|
+
end = jj_base + (Nt // 2)
|
98
|
+
for jj in range(start, end):
|
54
99
|
j = np.abs(jj - jj_base)
|
55
100
|
i = i_base - jj_base + jj
|
56
|
-
|
57
|
-
|
58
|
-
elif f_bin == 0 and jj < jj_base:
|
101
|
+
# For the highest frequency (f_bin==Nf) or the lowest (f_bin==0), zero out the out-of-range values.
|
102
|
+
if (f_bin == Nf and jj > jj_base) or (f_bin == 0 and jj < jj_base):
|
59
103
|
DX[i] = 0.0
|
60
104
|
elif j == 0:
|
105
|
+
# Center already assigned.
|
61
106
|
continue
|
62
107
|
else:
|
63
108
|
DX[i] = phif[j] * data[jj]
|
@@ -67,21 +112,34 @@ def __fill_wave_1(
|
|
67
112
|
def __fill_wave_2(
|
68
113
|
f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
|
69
114
|
) -> None:
|
115
|
+
"""
|
116
|
+
Unpack the inverse FFT output (DX_trans) into the output wave array.
|
117
|
+
|
118
|
+
For f_bin==0 and f_bin==Nf, the results are stored in column 0 of wave,
|
119
|
+
using even- or odd-indexed rows respectively. For intermediate f_bin values,
|
120
|
+
the values are stored in column f_bin with a sign and component (real or imag)
|
121
|
+
determined by parity.
|
122
|
+
"""
|
123
|
+
sqrt2 = np.sqrt(2.0)
|
70
124
|
if f_bin == 0:
|
71
|
-
#
|
125
|
+
# f_bin==0: assign even-indexed rows of column 0.
|
72
126
|
for n in range(0, Nt, 2):
|
73
|
-
wave[n, 0] = DX_trans[n].real *
|
127
|
+
wave[n, 0] = DX_trans[n].real * sqrt2
|
74
128
|
elif f_bin == Nf:
|
129
|
+
# f_bin==Nf: assign odd-indexed rows of column 0.
|
75
130
|
for n in range(0, Nt, 2):
|
76
|
-
wave[n + 1, 0] = DX_trans[n].real *
|
131
|
+
wave[n + 1, 0] = DX_trans[n].real * sqrt2
|
77
132
|
else:
|
133
|
+
# For intermediate f_bin, assign values to column f_bin.
|
78
134
|
for n in range(0, Nt):
|
79
135
|
if f_bin % 2:
|
136
|
+
# For odd f_bin: use -imag when (n+f_bin) is odd; otherwise use real.
|
80
137
|
if (n + f_bin) % 2:
|
81
138
|
wave[n, f_bin] = -DX_trans[n].imag
|
82
139
|
else:
|
83
140
|
wave[n, f_bin] = DX_trans[n].real
|
84
141
|
else:
|
142
|
+
# For even f_bin: use imag when (n+f_bin) is odd; otherwise use real.
|
85
143
|
if (n + f_bin) % 2:
|
86
144
|
wave[n, f_bin] = DX_trans[n].imag
|
87
145
|
else:
|
@@ -7,8 +7,6 @@ from .to_time import inverse_wavelet_time_helper_fast
|
|
7
7
|
|
8
8
|
__all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
|
9
9
|
|
10
|
-
INV_ROOT2 = 1.0 / np.sqrt(2)
|
11
|
-
|
12
10
|
|
13
11
|
def from_wavelet_to_time(
|
14
12
|
wave_in: Wavelet,
|
@@ -50,7 +48,7 @@ def from_wavelet_to_time(
|
|
50
48
|
h_t = inverse_wavelet_time_helper_fast(
|
51
49
|
wave_in.data.T, phi, wave_in.Nf, wave_in.Nt, mult
|
52
50
|
)
|
53
|
-
h_t *=
|
51
|
+
h_t *= 1.0 / np.sqrt(2) # Normalize to get proper backward transformation
|
54
52
|
ts = np.arange(0, wave_in.Nf * wave_in.Nt) * dt
|
55
53
|
return TimeSeries(data=h_t, time=ts)
|
56
54
|
|
@@ -89,7 +87,7 @@ def from_wavelet_to_freq(
|
|
89
87
|
wave_in.data, phif, wave_in.Nf, wave_in.Nt
|
90
88
|
)
|
91
89
|
|
92
|
-
freq_data *=
|
90
|
+
freq_data *= 1.0 / np.sqrt(2)
|
93
91
|
|
94
92
|
freqs = np.fft.rfftfreq(wave_in.ND, d=dt)
|
95
93
|
return FrequencySeries(data=freq_data, freq=freqs)
|
@@ -46,7 +46,7 @@ def __pack_wave_inverse(
|
|
46
46
|
prefactor2s[n] = 2 ** (-1 / 2) * wave_in[(2 * n) % Nt + 1, 0]
|
47
47
|
else:
|
48
48
|
for n in range(0, Nt):
|
49
|
-
val = wave_in[n, m]
|
49
|
+
val = wave_in[n, m]
|
50
50
|
if (n + m) % 2:
|
51
51
|
mult2 = -1j
|
52
52
|
else:
|
@@ -93,3 +93,66 @@ def __unpack_wave_inverse(
|
|
93
93
|
if ind32 == Nt:
|
94
94
|
ind32 = 0
|
95
95
|
res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0]
|
96
|
+
|
97
|
+
|
98
|
+
#
|
99
|
+
# # @njit
|
100
|
+
# def inverse_wavelet_freq_helper_fast_version2(
|
101
|
+
# wave_in: np.ndarray, phif: np.ndarray, Nf: int, Nt: int
|
102
|
+
# ) -> np.ndarray:
|
103
|
+
# wave_in = wave_in.T
|
104
|
+
# ND = Nf * Nt
|
105
|
+
# prefactor2s = np.zeros((Nf + 1, Nt), dtype=np.complex128)
|
106
|
+
# n_range = np.arange(Nt)
|
107
|
+
#
|
108
|
+
# # m == 0 case
|
109
|
+
# indices = (2 * n_range) % Nt
|
110
|
+
# prefactor2s[0] = (2 ** (-0.5)) * wave_in[indices, 0]
|
111
|
+
#
|
112
|
+
# # m == Nf case
|
113
|
+
# indices = ((2 * n_range) % Nt) + 1
|
114
|
+
# prefactor2s[Nf] = (2 ** (-0.5)) * wave_in[indices, 0]
|
115
|
+
#
|
116
|
+
# # For m = 1, ..., Nf-1
|
117
|
+
# m_mid = np.arange(1, Nf)
|
118
|
+
# m_grid, n_grid = np.meshgrid(m_mid, n_range, indexing='ij')
|
119
|
+
# val = wave_in[n_grid, m_grid]
|
120
|
+
# mult2 = np.where(((n_grid + m_grid) % 2) != 0, -1j, 1)
|
121
|
+
# prefactor2s[1:Nf] = mult2 * val
|
122
|
+
#
|
123
|
+
# fft_prefactor2s = np.fft.fft(prefactor2s, axis=1)
|
124
|
+
#
|
125
|
+
# res = np.zeros(ND // 2 + 1, dtype=np.complex128)
|
126
|
+
#
|
127
|
+
# # Unpacking for m == 0 and m == Nf
|
128
|
+
# for m in [0, Nf]:
|
129
|
+
# i_ind_range = np.arange(Nt // 2 + 1 if m == Nf else Nt // 2)
|
130
|
+
# i = np.abs(m * Nt // 2 - i_ind_range)
|
131
|
+
# ind3 = (2 * i) % Nt
|
132
|
+
# res[i] += fft_prefactor2s[m, ind3] * phif[i_ind_range]
|
133
|
+
#
|
134
|
+
# # Unpacking for m = 1,..., Nf-1
|
135
|
+
# for m in range(1, Nf):
|
136
|
+
# ind31 = (Nt // 2 * m) % Nt
|
137
|
+
# ind32 = ind31
|
138
|
+
# for i_ind in range(Nt // 2):
|
139
|
+
# i1 = Nt // 2 * m - i_ind
|
140
|
+
# i2 = Nt // 2 * m + i_ind
|
141
|
+
# res[i1] += fft_prefactor2s[m, ind31] * phif[i_ind]
|
142
|
+
# res[i2] += fft_prefactor2s[m, ind32] * phif[i_ind]
|
143
|
+
# ind31 = (ind31 - 1) % Nt
|
144
|
+
# ind32 = (ind32 + 1) % Nt
|
145
|
+
# res[Nt // 2 * m] += fft_prefactor2s[m, (Nt // 2 * m) % Nt] * phif[0]
|
146
|
+
#
|
147
|
+
# return res
|
148
|
+
#
|
149
|
+
# #
|
150
|
+
# #
|
151
|
+
# # if __name__ == '__main__':
|
152
|
+
# # phif = np.array(np.random.rand(64))
|
153
|
+
# # wave_in = np.array(np.random.rand(64, 64))
|
154
|
+
# # Nf = 64
|
155
|
+
# # Nt = 64
|
156
|
+
# # res = inverse_wavelet_freq_helper_fast(wave_in, phif, Nf, Nt)
|
157
|
+
# # res2 = inverse_wavelet_freq_helper_fast_version2(wave_in, phif, Nf, Nt)
|
158
|
+
# # assert np.allclose(res, res2), "Results do not match!"
|
@@ -1,7 +1,62 @@
|
|
1
|
-
|
1
|
+
import numpy as np
|
2
|
+
from jaxtyping import Array, Float
|
2
3
|
|
4
|
+
from ..backend import betainc, ifft, xp
|
3
5
|
|
4
|
-
|
6
|
+
__all__ = ["phitilde_vec_norm", "phi_vec", "omega"]
|
7
|
+
|
8
|
+
|
9
|
+
def omega(Nf: int, Nt: int) -> Float[Array, " Nt//2+1"]:
|
10
|
+
"""Get the angular frequencies of the time domain signal."""
|
11
|
+
df = 2 * np.pi / (Nf * Nt)
|
12
|
+
return df * np.arange(0, Nt // 2 + 1)
|
13
|
+
|
14
|
+
|
15
|
+
def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> Float[Array, " Nt//2+1"]:
|
16
|
+
"""Normalize phitilde for inverse frequency domain transform."""
|
17
|
+
omegas = omega(Nf, Nt)
|
18
|
+
_phi_t = _phitilde_vec(omegas, Nf, d) * np.sqrt(np.pi)
|
19
|
+
return xp.array(_phi_t)
|
20
|
+
|
21
|
+
|
22
|
+
def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
|
23
|
+
"""get time domain phi as fourier transform of _phitilde_vec
|
24
|
+
q: number of Nf bins over which the window extends?
|
25
|
+
|
26
|
+
"""
|
27
|
+
insDOM = 1.0 / np.sqrt(np.pi / Nf)
|
28
|
+
K = q * 2 * Nf
|
29
|
+
half_K = q * Nf # xp.int64(K/2)
|
30
|
+
|
31
|
+
dom = 2 * np.pi / K # max frequency is K/2*dom = pi/dt = OM
|
32
|
+
|
33
|
+
DX = np.zeros(K, dtype=np.complex128)
|
34
|
+
|
35
|
+
# zero frequency
|
36
|
+
DX[0] = insDOM
|
37
|
+
|
38
|
+
DX = DX.copy()
|
39
|
+
# postive frequencies
|
40
|
+
DX[1 : half_K + 1] = _phitilde_vec(dom * np.arange(1, half_K + 1), Nf, d)
|
41
|
+
# negative frequencies
|
42
|
+
DX[half_K + 1 :] = _phitilde_vec(
|
43
|
+
-dom * np.arange(half_K - 1, 0, -1), Nf, d
|
44
|
+
)
|
45
|
+
DX = K * ifft(DX, K)
|
46
|
+
|
47
|
+
phi = np.zeros(K)
|
48
|
+
phi[0:half_K] = np.real(DX[half_K:K])
|
49
|
+
phi[half_K:] = np.real(DX[0:half_K])
|
50
|
+
|
51
|
+
nrm = np.sqrt(2.0) / np.sqrt(K / dom) # *xp.linalg.norm(phi)
|
52
|
+
|
53
|
+
phi *= nrm
|
54
|
+
return xp.array(phi)
|
55
|
+
|
56
|
+
|
57
|
+
def _phitilde_vec(
|
58
|
+
omega: Float[Array, " Nt//2+1"], Nf: int, d: float = 4.0
|
59
|
+
) -> Float[Array, " Nt//2+1"]:
|
5
60
|
"""Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
|
6
61
|
|
7
62
|
Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
|
@@ -30,25 +85,25 @@ def phitilde_vec(omega: xp.ndarray, Nf: int, d: float = 4.0) -> xp.ndarray:
|
|
30
85
|
|
31
86
|
"""
|
32
87
|
dF = 1.0 / (2 * Nf) # NOTE: missing 1/dt?
|
33
|
-
dOmega = 2 *
|
34
|
-
inverse_sqrt_dOmega = 1.0 /
|
88
|
+
dOmega = 2 * np.pi * dF # Near Eq 10 # 2 pi times DF
|
89
|
+
inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)
|
35
90
|
|
36
91
|
A = dOmega / 4
|
37
92
|
B = dOmega - 2 * A # Cannot have B \leq 0.
|
38
93
|
if B <= 0:
|
39
94
|
raise ValueError("B must be greater than 0")
|
40
95
|
|
41
|
-
phi =
|
42
|
-
mask = (A <=
|
43
|
-
vd = (
|
96
|
+
phi = np.zeros(omega.size)
|
97
|
+
mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B) # Minor changes
|
98
|
+
vd = (np.pi / 2.0) * _nu_d(omega[mask], A, B, d=d) # different from paper
|
44
99
|
phi[mask] = inverse_sqrt_dOmega * xp.cos(vd)
|
45
|
-
phi[
|
100
|
+
phi[np.abs(omega) < A] = inverse_sqrt_dOmega
|
46
101
|
return phi
|
47
102
|
|
48
103
|
|
49
|
-
def
|
50
|
-
omega:
|
51
|
-
) ->
|
104
|
+
def _nu_d(
|
105
|
+
omega: Float[Array, " Nt//2+1"], A: float, B: float, d: float = 4.0
|
106
|
+
) -> Float[Array, " Nt//2+1"]:
|
52
107
|
"""Compute the normalized incomplete beta function.
|
53
108
|
|
54
109
|
Parameters
|
@@ -71,72 +126,5 @@ def __nu_d(
|
|
71
126
|
https://docs.scipy.org/doc/scipy-1.7.1/reference/reference/generated/scipy.special.betainc.html
|
72
127
|
|
73
128
|
"""
|
74
|
-
x = (
|
75
|
-
return betainc(d, d, x)
|
76
|
-
|
77
|
-
|
78
|
-
def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
|
79
|
-
"""Normalize phitilde for inverse frequency domain transform."""
|
80
|
-
|
81
|
-
# Calculate the frequency values
|
82
|
-
ND = Nf * Nt
|
83
|
-
omegas = 2 * xp.pi / ND * xp.arange(0, Nt // 2 + 1)
|
84
|
-
|
85
|
-
# Calculate the unnormalized phitilde (u_phit)
|
86
|
-
u_phit = phitilde_vec(omegas, Nf, d)
|
87
|
-
|
88
|
-
# Normalize the phitilde
|
89
|
-
normalising_factor = PI ** (-1 / 2) # Ollie's normalising factor
|
90
|
-
|
91
|
-
# Notes: this is the overall normalising factor that is different from Cornish's paper
|
92
|
-
# It is the only way I can force this code to be consistent with our work in the
|
93
|
-
# frequency domain. First note that
|
94
|
-
|
95
|
-
# old normalising factor -- This factor is absolutely ridiculous. Why!?
|
96
|
-
# Matt_normalising_factor = np.sqrt(
|
97
|
-
# (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) * 2 * PI / ND
|
98
|
-
# )
|
99
|
-
# Matt_normalising_factor /= PI**(3/2)/PI
|
100
|
-
|
101
|
-
# The expression above is equal to np.pi**(-1/2) after working through the maths.
|
102
|
-
# I have pulled (2/Nf) from __init__.py (from freq to wavelet) into the normalsiing
|
103
|
-
# factor here. I thnk it's cleaner to have ONE normalising constant. Avoids confusion
|
104
|
-
# and it is much easier to track.
|
105
|
-
|
106
|
-
# TODO: understand the following:
|
107
|
-
# (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) = 0.5 * Nt / dOmega
|
108
|
-
# Matt_normalising_factor is equal to 1/sqrt(pi)... why is this computed?
|
109
|
-
# in such a stupid way?
|
110
|
-
|
111
|
-
return u_phit / (normalising_factor)
|
112
|
-
|
113
|
-
|
114
|
-
def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> xp.ndarray:
|
115
|
-
"""get time domain phi as fourier transform of phitilde_vec"""
|
116
|
-
insDOM = 1.0 / xp.sqrt(PI / Nf)
|
117
|
-
K = q * 2 * Nf
|
118
|
-
half_K = q * Nf # xp.int64(K/2)
|
119
|
-
|
120
|
-
dom = 2 * PI / K # max frequency is K/2*dom = pi/dt = OM
|
121
|
-
|
122
|
-
DX = xp.zeros(K, dtype=xp.complex128)
|
123
|
-
|
124
|
-
# zero frequency
|
125
|
-
DX[0] = insDOM
|
126
|
-
|
127
|
-
DX = DX.copy()
|
128
|
-
# postive frequencies
|
129
|
-
DX[1 : half_K + 1] = phitilde_vec(dom * xp.arange(1, half_K + 1), Nf, d)
|
130
|
-
# negative frequencies
|
131
|
-
DX[half_K + 1 :] = phitilde_vec(-dom * xp.arange(half_K - 1, 0, -1), Nf, d)
|
132
|
-
DX = K * ifft(DX, K)
|
133
|
-
|
134
|
-
phi = xp.zeros(K)
|
135
|
-
phi[0:half_K] = xp.real(DX[half_K:K])
|
136
|
-
phi[half_K:] = xp.real(DX[0:half_K])
|
137
|
-
|
138
|
-
nrm = xp.sqrt(K / dom) # *xp.linalg.norm(phi)
|
139
|
-
|
140
|
-
fac = xp.sqrt(2.0) / nrm
|
141
|
-
phi *= fac
|
142
|
-
return phi
|
129
|
+
x = (np.abs(omega) - A) / B
|
130
|
+
return betainc(d, d, x)
|
pywavelet/types/plotting.py
CHANGED
@@ -84,6 +84,7 @@ def plot_wavelet_grid(
|
|
84
84
|
nan_color: Optional[str] = "black",
|
85
85
|
detailed_axes: bool = False,
|
86
86
|
show_gridinfo: bool = True,
|
87
|
+
txtbox_kwargs: dict = {},
|
87
88
|
trend_color: Optional[str] = None,
|
88
89
|
whiten_by: Optional[np.ndarray] = None,
|
89
90
|
**kwargs,
|
@@ -172,12 +173,17 @@ def plot_wavelet_grid(
|
|
172
173
|
if np.all(np.isnan(z)):
|
173
174
|
raise ValueError("All wavelet data is NaN.")
|
174
175
|
if zscale == "log":
|
175
|
-
|
176
|
-
|
177
|
-
|
176
|
+
vmin = np.nanmin(z[z > 0])
|
177
|
+
vmax = np.nanmax(z[z < np.inf])
|
178
|
+
if vmin > vmax:
|
179
|
+
raise ValueError("vmin > vmax... something wrong")
|
180
|
+
norm = LogNorm(vmin=vmin, vmax=vmax)
|
178
181
|
elif not absolute:
|
179
182
|
vmin, vmax = np.nanmin(z), np.nanmax(z)
|
180
183
|
vcenter = 0.0
|
184
|
+
if vmin > vmax:
|
185
|
+
raise ValueError("vmin > vmax... something wrong")
|
186
|
+
|
181
187
|
norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
|
182
188
|
else:
|
183
189
|
norm = None # Default linear scaling
|
@@ -248,6 +254,9 @@ def plot_wavelet_grid(
|
|
248
254
|
NfNt_label = f"{Nf}x{Nt}" if show_gridinfo else ""
|
249
255
|
txt = f"{label}\n{NfNt_label}" if label else NfNt_label
|
250
256
|
if txt:
|
257
|
+
txtbox_kwargs.setdefault("boxstyle", "round")
|
258
|
+
txtbox_kwargs.setdefault("facecolor", "white")
|
259
|
+
txtbox_kwargs.setdefault("alpha", 0.2)
|
251
260
|
ax.text(
|
252
261
|
0.05,
|
253
262
|
0.95,
|
@@ -255,7 +264,7 @@ def plot_wavelet_grid(
|
|
255
264
|
transform=ax.transAxes,
|
256
265
|
fontsize=14,
|
257
266
|
verticalalignment="top",
|
258
|
-
bbox=
|
267
|
+
bbox=txtbox_kwargs,
|
259
268
|
)
|
260
269
|
|
261
270
|
# Adjust layout
|
@@ -294,7 +303,7 @@ def plot_periodogram(
|
|
294
303
|
flow = np.min(np.abs(freq))
|
295
304
|
ax.set_xlabel("Frequency [Hz]")
|
296
305
|
ax.set_ylabel("Periodigram")
|
297
|
-
ax.set_xlim(left=flow, right=nyquist_frequency / 2)
|
306
|
+
# ax.set_xlim(left=flow, right=nyquist_frequency / 2)
|
298
307
|
return ax.figure, ax
|
299
308
|
|
300
309
|
|
pywavelet/types/wavelet.py
CHANGED
@@ -477,13 +477,13 @@ class WaveletMask(Wavelet):
|
|
477
477
|
A WaveletMask object with the specified restrictions.
|
478
478
|
"""
|
479
479
|
self = cls.zeros_from_grid(time_grid, freq_grid)
|
480
|
-
self.data[
|
481
|
-
|
482
|
-
|
480
|
+
self.data[(freq_grid >= frange[0]) & (freq_grid <= frange[1]), :] = (
|
481
|
+
True
|
482
|
+
)
|
483
483
|
|
484
484
|
for tgap in tgaps:
|
485
|
-
self.data[
|
486
|
-
|
487
|
-
|
485
|
+
self.data[:, (time_grid >= tgap[0]) & (time_grid <= tgap[1])] = (
|
486
|
+
False
|
487
|
+
)
|
488
488
|
self.data = self.data.astype(bool)
|
489
489
|
return self
|