pywavelet 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/__init__.py +22 -0
- pywavelet/_version.py +9 -4
- pywavelet/backend.py +49 -27
- pywavelet/transforms/__init__.py +10 -4
- pywavelet/transforms/cupy/__init__.py +12 -0
- pywavelet/transforms/cupy/forward/__init__.py +3 -0
- pywavelet/transforms/cupy/forward/from_freq.py +92 -0
- pywavelet/transforms/cupy/forward/from_time.py +50 -0
- pywavelet/transforms/cupy/forward/main.py +106 -0
- pywavelet/transforms/cupy/inverse/__init__.py +3 -0
- pywavelet/transforms/cupy/inverse/main.py +67 -0
- pywavelet/transforms/cupy/inverse/to_freq.py +62 -0
- pywavelet/transforms/jax/forward/from_freq.py +6 -0
- pywavelet/transforms/jax/forward/from_time.py +18 -10
- pywavelet/transforms/jax/forward/main.py +6 -10
- pywavelet/transforms/jax/inverse/main.py +4 -6
- pywavelet/transforms/jax/inverse/to_freq.py +52 -34
- pywavelet/transforms/numpy/__init__.py +1 -2
- pywavelet/transforms/numpy/forward/from_freq.py +77 -19
- pywavelet/transforms/numpy/forward/main.py +1 -2
- pywavelet/transforms/numpy/inverse/main.py +4 -6
- pywavelet/transforms/numpy/inverse/to_freq.py +64 -1
- pywavelet/transforms/phi_computer.py +67 -86
- pywavelet/types/common.py +4 -3
- pywavelet/types/frequencyseries.py +1 -1
- pywavelet/types/plotting.py +14 -5
- pywavelet/types/timeseries.py +4 -10
- pywavelet/types/wavelet.py +6 -6
- pywavelet/types/wavelet_bins.py +0 -1
- pywavelet/utils.py +2 -0
- {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/METADATA +20 -9
- pywavelet-0.2.6.dist-info/RECORD +43 -0
- {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/WHEEL +1 -1
- pywavelet-0.2.4.dist-info/RECORD +0 -35
- {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/top_level.txt +0 -0
@@ -10,7 +10,7 @@ def from_wavelet_to_time(
|
|
10
10
|
wave_in: Wavelet,
|
11
11
|
dt: float,
|
12
12
|
nx: float = 4.0,
|
13
|
-
mult: int =
|
13
|
+
mult: int = None,
|
14
14
|
) -> TimeSeries:
|
15
15
|
"""Inverse wavelet transform to time domain.
|
16
16
|
|
@@ -55,14 +55,12 @@ def from_wavelet_to_freq(
|
|
55
55
|
Frequency domain signal
|
56
56
|
|
57
57
|
"""
|
58
|
-
phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt,
|
58
|
+
phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx))
|
59
59
|
freq_data = inverse_wavelet_freq_helper(
|
60
60
|
wave_in.data, phif=phif, Nf=wave_in.Nf, Nt=wave_in.Nt
|
61
61
|
)
|
62
62
|
|
63
|
-
freq_data *=
|
64
|
-
-1 / 2
|
65
|
-
) # Normalise to get the proper backwards transformation
|
63
|
+
freq_data *= 1.0 / jnp.sqrt(2)
|
66
64
|
|
67
|
-
freqs = rfftfreq(wave_in.ND
|
65
|
+
freqs = rfftfreq(wave_in.ND, d=dt)
|
68
66
|
return FrequencySeries(data=freq_data, freq=freqs)
|
@@ -1,70 +1,88 @@
|
|
1
|
+
from functools import partial
|
2
|
+
|
1
3
|
import jax
|
2
4
|
import jax.numpy as jnp
|
3
5
|
from jax import jit
|
4
6
|
from jax.numpy.fft import fft
|
5
7
|
|
6
|
-
from functools import partial
|
7
8
|
|
8
|
-
|
9
|
-
@partial(jit, static_argnames=('Nf', 'Nt'))
|
9
|
+
@partial(jit, static_argnames=("Nf", "Nt"))
|
10
10
|
def inverse_wavelet_freq_helper(
|
11
|
-
|
11
|
+
wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
|
12
12
|
) -> jnp.ndarray:
|
13
|
-
"""JAX vectorized function for inverse_wavelet_freq"""
|
13
|
+
"""JAX vectorized function for inverse_wavelet_freq with corrected shapes and ranges."""
|
14
|
+
# Transpose to match the NumPy version.
|
14
15
|
wave_in = wave_in.T
|
15
16
|
ND = Nf * Nt
|
16
17
|
|
18
|
+
# Allocate prefactor2s for each m (shape: (Nf+1, Nt)).
|
17
19
|
m_range = jnp.arange(Nf + 1)
|
18
20
|
prefactor2s = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex128)
|
19
|
-
|
20
21
|
n_range = jnp.arange(Nt)
|
21
22
|
|
22
23
|
# m == 0 case
|
23
|
-
prefactor2s = prefactor2s.at[0].set(
|
24
|
+
prefactor2s = prefactor2s.at[0].set(
|
25
|
+
2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
|
26
|
+
)
|
24
27
|
|
25
28
|
# m == Nf case
|
26
|
-
prefactor2s = prefactor2s.at[Nf].set(
|
29
|
+
prefactor2s = prefactor2s.at[Nf].set(
|
30
|
+
2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
|
31
|
+
)
|
27
32
|
|
28
|
-
# Other m cases
|
33
|
+
# Other m cases: use meshgrid for vectorization.
|
29
34
|
m_mid = m_range[1:Nf]
|
35
|
+
# Create grids: n_grid (columns) and m_grid (rows)
|
30
36
|
n_grid, m_grid = jnp.meshgrid(n_range, m_mid)
|
37
|
+
# Index the transposed wave_in using (n, m) as in the NumPy version.
|
31
38
|
val = wave_in[n_grid, m_grid]
|
39
|
+
# Apply the alternating multiplier based on (n+m) parity.
|
32
40
|
mult2 = jnp.where((n_grid + m_grid) % 2, -1j, 1)
|
33
41
|
prefactor2s = prefactor2s.at[1:Nf].set(mult2 * val)
|
34
42
|
|
35
|
-
#
|
43
|
+
# Apply FFT along axis 1 for all m.
|
36
44
|
fft_prefactor2s = fft(prefactor2s, axis=1)
|
37
45
|
|
38
|
-
#
|
39
|
-
|
40
|
-
# ND or ND // 2 + 1?
|
41
|
-
# https://github.com/pywavelet/pywavelet/blob/63151a47cde9edc14f1e7e0bf17f554e78ad257c/src/pywavelet/transforms/from_wavelets/inverse_wavelet_freq_funcs.py
|
42
|
-
res = jnp.zeros(ND, dtype=jnp.complex128)
|
46
|
+
# Allocate the result array with corrected shape.
|
47
|
+
res = jnp.zeros(ND // 2 + 1, dtype=jnp.complex128)
|
43
48
|
|
44
|
-
# m == 0
|
49
|
+
# Unpacking for m == 0 and m == Nf cases:
|
45
50
|
i_ind_range = jnp.arange(Nt // 2)
|
46
|
-
i_0 = jnp.abs(i_ind_range)
|
47
|
-
i_Nf = jnp.abs(Nf * Nt // 2 - i_ind_range)
|
51
|
+
i_0 = jnp.abs(i_ind_range) # for m == 0: i = i_ind_range
|
52
|
+
i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
|
48
53
|
ind3_0 = (2 * i_0) % Nt
|
49
54
|
ind3_Nf = (2 * i_Nf) % Nt
|
50
55
|
|
51
56
|
res = res.at[i_0].add(fft_prefactor2s[0, ind3_0] * phif[i_ind_range])
|
52
57
|
res = res.at[i_Nf].add(fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range])
|
58
|
+
# Special case for m == Nf: add the i_ind == Nt // 2 term, which lands Nt // 2 bins below the Nyquist index.
|
59
|
+
special_index = jnp.abs(Nf * (Nt // 2) - (Nt // 2))
|
60
|
+
res = res.at[special_index].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
|
53
61
|
|
54
|
-
#
|
55
|
-
res = res.at[Nf * Nt // 2].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
|
56
|
-
|
57
|
-
# Other m cases
|
62
|
+
# Unpacking for m in (1, ..., Nf-1)
|
58
63
|
m_mid = m_range[1:Nf]
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
64
|
+
# Use range [0, Nt//2) to match the loop in NumPy version.
|
65
|
+
i_ind_range_mid = jnp.arange(Nt // 2)
|
66
|
+
# Create meshgrid for vectorized computation.
|
67
|
+
m_grid_mid, i_ind_grid_mid = jnp.meshgrid(
|
68
|
+
m_mid, i_ind_range_mid, indexing="ij"
|
69
|
+
)
|
70
|
+
|
71
|
+
# Compute indices i1 and i2 following the NumPy logic.
|
72
|
+
i1 = (Nt // 2) * m_grid_mid - i_ind_grid_mid
|
73
|
+
i2 = (Nt // 2) * m_grid_mid + i_ind_grid_mid
|
74
|
+
# Compute the wrapped indices for FFT results.
|
75
|
+
ind31 = i1 % Nt
|
76
|
+
ind32 = i2 % Nt
|
77
|
+
|
78
|
+
# Update result array using vectorized adds.
|
79
|
+
# Note: You might need to adjust this further if your target res shape is non-trivial,
|
80
|
+
# because here we assume that i1 and i2 indices fall within the allocated result shape.
|
81
|
+
res = res.at[i1].add(
|
82
|
+
fft_prefactor2s[m_grid_mid, ind31] * phif[i_ind_grid_mid]
|
83
|
+
)
|
84
|
+
res = res.at[i2].add(
|
85
|
+
fft_prefactor2s[m_grid_mid, ind32] * phif[i_ind_grid_mid]
|
86
|
+
)
|
87
|
+
|
88
|
+
return res
|
@@ -1,10 +1,9 @@
|
|
1
1
|
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
2
2
|
from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
3
3
|
|
4
|
-
|
5
4
|
__all__ = [
|
6
5
|
"from_wavelet_to_time",
|
7
6
|
"from_wavelet_to_freq",
|
8
7
|
"from_time_to_wavelet",
|
9
8
|
"from_freq_to_wavelet",
|
10
|
-
]
|
9
|
+
]
|
@@ -1,32 +1,69 @@
|
|
1
1
|
"""helper functions for transform_freq"""
|
2
2
|
|
3
|
+
import logging
|
4
|
+
|
3
5
|
import numpy as np
|
4
6
|
from numba import njit
|
5
7
|
|
8
|
+
# import rocket_fft.special as rfft # JIT‐able FFT routines
|
9
|
+
|
10
|
+
|
11
|
+
logger = logging.getLogger("pywavelet")
|
12
|
+
|
6
13
|
|
7
14
|
def transform_wavelet_freq_helper(
|
8
15
|
data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
|
9
16
|
) -> np.ndarray:
|
10
|
-
"""
|
11
|
-
|
17
|
+
"""
|
18
|
+
Forward wavelet transform helper using the fast wavelet domain transform,
|
19
|
+
with a JIT-able FFT (rocket-fft) so that the whole transform is jittable.
|
20
|
+
|
21
|
+
Parameters
|
22
|
+
----------
|
23
|
+
data : np.ndarray
|
24
|
+
Input frequency-domain data (1D array).
|
25
|
+
Nf : int
|
26
|
+
Number of frequency bins.
|
27
|
+
Nt : int
|
28
|
+
Number of time bins.
|
29
|
+
phif : np.ndarray
|
30
|
+
Frequency-domain window values, e.g. from phitilde_vec_norm (array of length Nt//2 + 1).
|
31
|
+
|
32
|
+
Returns
|
33
|
+
-------
|
34
|
+
wave : np.ndarray
|
35
|
+
The wavelet transform output of shape (Nt, Nf). Note that contributions from
|
36
|
+
f_bin==0 and f_bin==Nf are both stored in column 0.
|
37
|
+
"""
|
38
|
+
logger.debug(
|
39
|
+
f"[NUMPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]"
|
40
|
+
)
|
41
|
+
wave = np.zeros((Nt, Nf), dtype=np.float64)
|
12
42
|
DX = np.zeros(Nt, dtype=np.complex128)
|
13
|
-
|
43
|
+
# Create a copy of the input data (if needed).
|
44
|
+
freq_strain = data.copy()
|
14
45
|
__core(Nf, Nt, DX, freq_strain, phif, wave)
|
15
46
|
return wave
|
16
47
|
|
17
48
|
|
18
|
-
|
49
|
+
@njit()
|
19
50
|
def __core(
|
20
51
|
Nf: int,
|
21
52
|
Nt: int,
|
22
53
|
DX: np.ndarray,
|
23
|
-
|
54
|
+
data: np.ndarray,
|
24
55
|
phif: np.ndarray,
|
25
56
|
wave: np.ndarray,
|
26
57
|
):
|
58
|
+
"""
|
59
|
+
Process each frequency bin (f_bin) to compute the temporary array DX,
|
60
|
+
perform the inverse FFT (np.fft.ifft), and then unpack the result into wave.
|
61
|
+
|
62
|
+
This function is fully jittable.
|
63
|
+
"""
|
27
64
|
for f_bin in range(0, Nf + 1):
|
28
|
-
__fill_wave_1(f_bin, Nt, Nf, DX,
|
29
|
-
#
|
65
|
+
__fill_wave_1(f_bin, Nt, Nf, DX, data, phif)
|
66
|
+
# np.fft.ifft is called directly; inside @njit this presumably relies on rocket-fft's Numba overloads being installed (TODO confirm).
|
30
67
|
DX_trans = np.fft.ifft(DX, Nt)
|
31
68
|
__fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
|
32
69
|
|
@@ -40,24 +77,32 @@ def __fill_wave_1(
|
|
40
77
|
data: np.ndarray,
|
41
78
|
phif: np.ndarray,
|
42
79
|
) -> None:
|
43
|
-
"""
|
80
|
+
"""
|
81
|
+
Fill the temporary complex array DX for the given frequency bin (f_bin)
|
82
|
+
based on the input data and the phase factors phif.
|
83
|
+
|
84
|
+
The computation is performed over a window of indices defined by the current f_bin.
|
85
|
+
"""
|
44
86
|
i_base = Nt // 2
|
45
|
-
jj_base = f_bin * Nt // 2
|
87
|
+
jj_base = f_bin * (Nt // 2)
|
46
88
|
|
89
|
+
# Special center assignment:
|
47
90
|
if f_bin == 0 or f_bin == Nf:
|
48
|
-
|
49
|
-
DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2] / 2.0
|
91
|
+
DX[i_base] = phif[0] * data[jj_base] / 2.0
|
50
92
|
else:
|
51
|
-
DX[
|
93
|
+
DX[i_base] = phif[0] * data[jj_base]
|
52
94
|
|
53
|
-
|
95
|
+
# Determine the window of indices.
|
96
|
+
start = jj_base + 1 - (Nt // 2)
|
97
|
+
end = jj_base + (Nt // 2)
|
98
|
+
for jj in range(start, end):
|
54
99
|
j = np.abs(jj - jj_base)
|
55
100
|
i = i_base - jj_base + jj
|
56
|
-
|
57
|
-
|
58
|
-
elif f_bin == 0 and jj < jj_base:
|
101
|
+
# For the highest frequency (f_bin==Nf) or the lowest (f_bin==0), zero out the out-of-range values.
|
102
|
+
if (f_bin == Nf and jj > jj_base) or (f_bin == 0 and jj < jj_base):
|
59
103
|
DX[i] = 0.0
|
60
104
|
elif j == 0:
|
105
|
+
# Center already assigned.
|
61
106
|
continue
|
62
107
|
else:
|
63
108
|
DX[i] = phif[j] * data[jj]
|
@@ -67,21 +112,34 @@ def __fill_wave_1(
|
|
67
112
|
def __fill_wave_2(
|
68
113
|
f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
|
69
114
|
) -> None:
|
115
|
+
"""
|
116
|
+
Unpack the inverse FFT output (DX_trans) into the output wave array.
|
117
|
+
|
118
|
+
For f_bin==0 and f_bin==Nf, the results are stored in column 0 of wave,
|
119
|
+
using even- or odd-indexed rows respectively. For intermediate f_bin values,
|
120
|
+
the values are stored in column f_bin with a sign and component (real or imag)
|
121
|
+
determined by parity.
|
122
|
+
"""
|
123
|
+
sqrt2 = np.sqrt(2.0)
|
70
124
|
if f_bin == 0:
|
71
|
-
#
|
125
|
+
# f_bin==0: assign even-indexed rows of column 0.
|
72
126
|
for n in range(0, Nt, 2):
|
73
|
-
wave[n, 0] = DX_trans[n].real *
|
127
|
+
wave[n, 0] = DX_trans[n].real * sqrt2
|
74
128
|
elif f_bin == Nf:
|
129
|
+
# f_bin==Nf: assign odd-indexed rows of column 0.
|
75
130
|
for n in range(0, Nt, 2):
|
76
|
-
wave[n + 1, 0] = DX_trans[n].real *
|
131
|
+
wave[n + 1, 0] = DX_trans[n].real * sqrt2
|
77
132
|
else:
|
133
|
+
# For intermediate f_bin, assign values to column f_bin.
|
78
134
|
for n in range(0, Nt):
|
79
135
|
if f_bin % 2:
|
136
|
+
# For odd f_bin: use -imag when (n+f_bin) is odd; otherwise use real.
|
80
137
|
if (n + f_bin) % 2:
|
81
138
|
wave[n, f_bin] = -DX_trans[n].imag
|
82
139
|
else:
|
83
140
|
wave[n, f_bin] = DX_trans[n].real
|
84
141
|
else:
|
142
|
+
# For even f_bin: use imag when (n+f_bin) is odd; otherwise use real.
|
85
143
|
if (n + f_bin) % 2:
|
86
144
|
wave[n, f_bin] = DX_trans[n].imag
|
87
145
|
else:
|
@@ -56,7 +56,6 @@ def from_time_to_wavelet(
|
|
56
56
|
to inaccurate results.
|
57
57
|
"""
|
58
58
|
Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
|
59
|
-
dt = timeseries.dt
|
60
59
|
t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
|
61
60
|
|
62
61
|
ND = Nf * Nt
|
@@ -113,7 +112,7 @@ def from_freq_to_wavelet(
|
|
113
112
|
"""
|
114
113
|
Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
|
115
114
|
t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
|
116
|
-
phif = phitilde_vec_norm(Nf, Nt,
|
115
|
+
phif = phitilde_vec_norm(Nf, Nt, d=nx)
|
117
116
|
wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
|
118
117
|
|
119
118
|
return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
|
@@ -1,14 +1,12 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from ...phi_computer import phi_vec, phitilde_vec_norm
|
4
3
|
from ....types import FrequencySeries, TimeSeries, Wavelet
|
4
|
+
from ...phi_computer import phi_vec, phitilde_vec_norm
|
5
5
|
from .to_freq import inverse_wavelet_freq_helper_fast
|
6
6
|
from .to_time import inverse_wavelet_time_helper_fast
|
7
7
|
|
8
8
|
__all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
|
9
9
|
|
10
|
-
INV_ROOT2 = 1.0 / np.sqrt(2)
|
11
|
-
|
12
10
|
|
13
11
|
def from_wavelet_to_time(
|
14
12
|
wave_in: Wavelet,
|
@@ -50,7 +48,7 @@ def from_wavelet_to_time(
|
|
50
48
|
h_t = inverse_wavelet_time_helper_fast(
|
51
49
|
wave_in.data.T, phi, wave_in.Nf, wave_in.Nt, mult
|
52
50
|
)
|
53
|
-
h_t *=
|
51
|
+
h_t *= 1.0 / np.sqrt(2) # Normalize to get proper backward transformation
|
54
52
|
ts = np.arange(0, wave_in.Nf * wave_in.Nt) * dt
|
55
53
|
return TimeSeries(data=h_t, time=ts)
|
56
54
|
|
@@ -84,12 +82,12 @@ def from_wavelet_to_freq(
|
|
84
82
|
to ensure the proper backwards transformation.
|
85
83
|
"""
|
86
84
|
|
87
|
-
phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt,
|
85
|
+
phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx)
|
88
86
|
freq_data = inverse_wavelet_freq_helper_fast(
|
89
87
|
wave_in.data, phif, wave_in.Nf, wave_in.Nt
|
90
88
|
)
|
91
89
|
|
92
|
-
freq_data *=
|
90
|
+
freq_data *= 1.0 / np.sqrt(2)
|
93
91
|
|
94
92
|
freqs = np.fft.rfftfreq(wave_in.ND, d=dt)
|
95
93
|
return FrequencySeries(data=freq_data, freq=freqs)
|
@@ -46,7 +46,7 @@ def __pack_wave_inverse(
|
|
46
46
|
prefactor2s[n] = 2 ** (-1 / 2) * wave_in[(2 * n) % Nt + 1, 0]
|
47
47
|
else:
|
48
48
|
for n in range(0, Nt):
|
49
|
-
val = wave_in[n, m]
|
49
|
+
val = wave_in[n, m]
|
50
50
|
if (n + m) % 2:
|
51
51
|
mult2 = -1j
|
52
52
|
else:
|
@@ -93,3 +93,66 @@ def __unpack_wave_inverse(
|
|
93
93
|
if ind32 == Nt:
|
94
94
|
ind32 = 0
|
95
95
|
res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0]
|
96
|
+
|
97
|
+
|
98
|
+
#
|
99
|
+
# # @njit
|
100
|
+
# def inverse_wavelet_freq_helper_fast_version2(
|
101
|
+
# wave_in: np.ndarray, phif: np.ndarray, Nf: int, Nt: int
|
102
|
+
# ) -> np.ndarray:
|
103
|
+
# wave_in = wave_in.T
|
104
|
+
# ND = Nf * Nt
|
105
|
+
# prefactor2s = np.zeros((Nf + 1, Nt), dtype=np.complex128)
|
106
|
+
# n_range = np.arange(Nt)
|
107
|
+
#
|
108
|
+
# # m == 0 case
|
109
|
+
# indices = (2 * n_range) % Nt
|
110
|
+
# prefactor2s[0] = (2 ** (-0.5)) * wave_in[indices, 0]
|
111
|
+
#
|
112
|
+
# # m == Nf case
|
113
|
+
# indices = ((2 * n_range) % Nt) + 1
|
114
|
+
# prefactor2s[Nf] = (2 ** (-0.5)) * wave_in[indices, 0]
|
115
|
+
#
|
116
|
+
# # For m = 1, ..., Nf-1
|
117
|
+
# m_mid = np.arange(1, Nf)
|
118
|
+
# m_grid, n_grid = np.meshgrid(m_mid, n_range, indexing='ij')
|
119
|
+
# val = wave_in[n_grid, m_grid]
|
120
|
+
# mult2 = np.where(((n_grid + m_grid) % 2) != 0, -1j, 1)
|
121
|
+
# prefactor2s[1:Nf] = mult2 * val
|
122
|
+
#
|
123
|
+
# fft_prefactor2s = np.fft.fft(prefactor2s, axis=1)
|
124
|
+
#
|
125
|
+
# res = np.zeros(ND // 2 + 1, dtype=np.complex128)
|
126
|
+
#
|
127
|
+
# # Unpacking for m == 0 and m == Nf
|
128
|
+
# for m in [0, Nf]:
|
129
|
+
# i_ind_range = np.arange(Nt // 2 + 1 if m == Nf else Nt // 2)
|
130
|
+
# i = np.abs(m * Nt // 2 - i_ind_range)
|
131
|
+
# ind3 = (2 * i) % Nt
|
132
|
+
# res[i] += fft_prefactor2s[m, ind3] * phif[i_ind_range]
|
133
|
+
#
|
134
|
+
# # Unpacking for m = 1,..., Nf-1
|
135
|
+
# for m in range(1, Nf):
|
136
|
+
# ind31 = (Nt // 2 * m) % Nt
|
137
|
+
# ind32 = ind31
|
138
|
+
# for i_ind in range(Nt // 2):
|
139
|
+
# i1 = Nt // 2 * m - i_ind
|
140
|
+
# i2 = Nt // 2 * m + i_ind
|
141
|
+
# res[i1] += fft_prefactor2s[m, ind31] * phif[i_ind]
|
142
|
+
# res[i2] += fft_prefactor2s[m, ind32] * phif[i_ind]
|
143
|
+
# ind31 = (ind31 - 1) % Nt
|
144
|
+
# ind32 = (ind32 + 1) % Nt
|
145
|
+
# res[Nt // 2 * m] += fft_prefactor2s[m, (Nt // 2 * m) % Nt] * phif[0]
|
146
|
+
#
|
147
|
+
# return res
|
148
|
+
#
|
149
|
+
# #
|
150
|
+
# #
|
151
|
+
# # if __name__ == '__main__':
|
152
|
+
# # phif = np.array(np.random.rand(64))
|
153
|
+
# # wave_in = np.array(np.random.rand(64, 64))
|
154
|
+
# # Nf = 64
|
155
|
+
# # Nt = 64
|
156
|
+
# # res = inverse_wavelet_freq_helper_fast(wave_in, phif, Nf, Nt)
|
157
|
+
# # res2 = inverse_wavelet_freq_helper_fast_version2(wave_in, phif, Nf, Nt)
|
158
|
+
# # assert np.allclose(res, res2), "Results do not match!"
|
@@ -1,10 +1,62 @@
|
|
1
|
-
|
1
|
+
import numpy as np
|
2
|
+
from jaxtyping import Array, Float
|
2
3
|
|
4
|
+
from ..backend import betainc, ifft, xp
|
3
5
|
|
6
|
+
__all__ = ["phitilde_vec_norm", "phi_vec", "omega"]
|
4
7
|
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
+
|
9
|
+
def omega(Nf: int, Nt: int) -> Float[Array, " Nt//2+1"]:
|
10
|
+
"""Get the angular frequencies of the time domain signal."""
|
11
|
+
df = 2 * np.pi / (Nf * Nt)
|
12
|
+
return df * np.arange(0, Nt // 2 + 1)
|
13
|
+
|
14
|
+
|
15
|
+
def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> Float[Array, " Nt//2+1"]:
|
16
|
+
"""Normalize phitilde for inverse frequency domain transform."""
|
17
|
+
omegas = omega(Nf, Nt)
|
18
|
+
_phi_t = _phitilde_vec(omegas, Nf, d) * np.sqrt(np.pi)
|
19
|
+
return xp.array(_phi_t)
|
20
|
+
|
21
|
+
|
22
|
+
def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
|
23
|
+
"""get time domain phi as fourier transform of _phitilde_vec
|
24
|
+
q: number of Nf bins over which the window extends?
|
25
|
+
|
26
|
+
"""
|
27
|
+
insDOM = 1.0 / np.sqrt(np.pi / Nf)
|
28
|
+
K = q * 2 * Nf
|
29
|
+
half_K = q * Nf # xp.int64(K/2)
|
30
|
+
|
31
|
+
dom = 2 * np.pi / K # max frequency is K/2*dom = pi/dt = OM
|
32
|
+
|
33
|
+
DX = np.zeros(K, dtype=np.complex128)
|
34
|
+
|
35
|
+
# zero frequency
|
36
|
+
DX[0] = insDOM
|
37
|
+
|
38
|
+
DX = DX.copy()
|
39
|
+
# positive frequencies
|
40
|
+
DX[1 : half_K + 1] = _phitilde_vec(dom * np.arange(1, half_K + 1), Nf, d)
|
41
|
+
# negative frequencies
|
42
|
+
DX[half_K + 1 :] = _phitilde_vec(
|
43
|
+
-dom * np.arange(half_K - 1, 0, -1), Nf, d
|
44
|
+
)
|
45
|
+
DX = K * ifft(DX, K)
|
46
|
+
|
47
|
+
phi = np.zeros(K)
|
48
|
+
phi[0:half_K] = np.real(DX[half_K:K])
|
49
|
+
phi[half_K:] = np.real(DX[0:half_K])
|
50
|
+
|
51
|
+
nrm = np.sqrt(2.0) / np.sqrt(K / dom) # *xp.linalg.norm(phi)
|
52
|
+
|
53
|
+
phi *= nrm
|
54
|
+
return xp.array(phi)
|
55
|
+
|
56
|
+
|
57
|
+
def _phitilde_vec(
|
58
|
+
omega: Float[Array, " Nt//2+1"], Nf: int, d: float = 4.0
|
59
|
+
) -> Float[Array, " Nt//2+1"]:
|
8
60
|
"""Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
|
9
61
|
|
10
62
|
Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
|
@@ -33,25 +85,25 @@ def phitilde_vec(
|
|
33
85
|
|
34
86
|
"""
|
35
87
|
dF = 1.0 / (2 * Nf) # NOTE: missing 1/dt?
|
36
|
-
dOmega = 2 *
|
37
|
-
inverse_sqrt_dOmega = 1.0 /
|
88
|
+
dOmega = 2 * np.pi * dF # Near Eq 10 # 2 pi times DF
|
89
|
+
inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)
|
38
90
|
|
39
91
|
A = dOmega / 4
|
40
92
|
B = dOmega - 2 * A # Cannot have B \leq 0.
|
41
93
|
if B <= 0:
|
42
94
|
raise ValueError("B must be greater than 0")
|
43
95
|
|
44
|
-
phi =
|
45
|
-
mask = (A <=
|
46
|
-
vd = (
|
96
|
+
phi = np.zeros(omega.size)
|
97
|
+
mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B) # Minor changes
|
98
|
+
vd = (np.pi / 2.0) * _nu_d(omega[mask], A, B, d=d) # different from paper
|
47
99
|
phi[mask] = inverse_sqrt_dOmega * xp.cos(vd)
|
48
|
-
phi[
|
100
|
+
phi[np.abs(omega) < A] = inverse_sqrt_dOmega
|
49
101
|
return phi
|
50
102
|
|
51
103
|
|
52
|
-
def
|
53
|
-
omega:
|
54
|
-
) ->
|
104
|
+
def _nu_d(
|
105
|
+
omega: Float[Array, " Nt//2+1"], A: float, B: float, d: float = 4.0
|
106
|
+
) -> Float[Array, " Nt//2+1"]:
|
55
107
|
"""Compute the normalized incomplete beta function.
|
56
108
|
|
57
109
|
Parameters
|
@@ -74,76 +126,5 @@ def __nu_d(
|
|
74
126
|
https://docs.scipy.org/doc/scipy-1.7.1/reference/reference/generated/scipy.special.betainc.html
|
75
127
|
|
76
128
|
"""
|
77
|
-
x = (
|
78
|
-
return betainc(d, d, x)
|
79
|
-
|
80
|
-
|
81
|
-
def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
|
82
|
-
"""Normalize phitilde for inverse frequency domain transform."""
|
83
|
-
|
84
|
-
# Calculate the frequency values
|
85
|
-
ND = Nf * Nt
|
86
|
-
omegas = 2 * xp.pi / ND * xp.arange(0, Nt // 2 + 1)
|
87
|
-
|
88
|
-
# Calculate the unnormalized phitilde (u_phit)
|
89
|
-
u_phit = phitilde_vec(omegas, Nf, d)
|
90
|
-
|
91
|
-
# Normalize the phitilde
|
92
|
-
normalising_factor = PI ** (-1 / 2) # Ollie's normalising factor
|
93
|
-
|
94
|
-
# Notes: this is the overall normalising factor that is different from Cornish's paper
|
95
|
-
# It is the only way I can force this code to be consistent with our work in the
|
96
|
-
# frequency domain. First note that
|
97
|
-
|
98
|
-
# old normalising factor -- This factor is absolutely ridiculous. Why!?
|
99
|
-
# Matt_normalising_factor = np.sqrt(
|
100
|
-
# (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) * 2 * PI / ND
|
101
|
-
# )
|
102
|
-
# Matt_normalising_factor /= PI**(3/2)/PI
|
103
|
-
|
104
|
-
# The expression above is equal to np.pi**(-1/2) after working through the maths.
|
105
|
-
# I have pulled (2/Nf) from __init__.py (from freq to wavelet) into the normalsiing
|
106
|
-
# factor here. I thnk it's cleaner to have ONE normalising constant. Avoids confusion
|
107
|
-
# and it is much easier to track.
|
108
|
-
|
109
|
-
# TODO: understand the following:
|
110
|
-
# (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) = 0.5 * Nt / dOmega
|
111
|
-
# Matt_normalising_factor is equal to 1/sqrt(pi)... why is this computed?
|
112
|
-
# in such a stupid way?
|
113
|
-
|
114
|
-
return u_phit / (normalising_factor)
|
115
|
-
|
116
|
-
|
117
|
-
def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> xp.ndarray:
|
118
|
-
"""get time domain phi as fourier transform of phitilde_vec"""
|
119
|
-
insDOM = 1.0 / xp.sqrt(PI / Nf)
|
120
|
-
K = q * 2 * Nf
|
121
|
-
half_K = q * Nf # xp.int64(K/2)
|
122
|
-
|
123
|
-
dom = 2 * PI / K # max frequency is K/2*dom = pi/dt = OM
|
124
|
-
|
125
|
-
DX = xp.zeros(K, dtype=xp.complex128)
|
126
|
-
|
127
|
-
# zero frequency
|
128
|
-
DX[0] = insDOM
|
129
|
-
|
130
|
-
DX = DX.copy()
|
131
|
-
# postive frequencies
|
132
|
-
DX[1 : half_K + 1] = phitilde_vec(
|
133
|
-
dom * xp.arange(1, half_K + 1), Nf, d
|
134
|
-
)
|
135
|
-
# negative frequencies
|
136
|
-
DX[half_K + 1 :] = phitilde_vec(
|
137
|
-
-dom * xp.arange(half_K - 1, 0, -1), Nf, d
|
138
|
-
)
|
139
|
-
DX = K * ifft(DX, K)
|
140
|
-
|
141
|
-
phi = xp.zeros(K)
|
142
|
-
phi[0:half_K] = xp.real(DX[half_K:K])
|
143
|
-
phi[half_K:] = xp.real(DX[0:half_K])
|
144
|
-
|
145
|
-
nrm = xp.sqrt(K / dom) # *xp.linalg.norm(phi)
|
146
|
-
|
147
|
-
fac = xp.sqrt(2.0) / nrm
|
148
|
-
phi *= fac
|
149
|
-
return phi
|
129
|
+
x = (np.abs(omega) - A) / B
|
130
|
+
return betainc(d, d, x)
|
pywavelet/types/common.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
|
-
from typing import Tuple, Union
|
2
|
-
|
1
|
+
from typing import Callable, Tuple, Union
|
2
|
+
|
3
3
|
from ..backend import xp
|
4
|
+
from ..logger import logger
|
4
5
|
|
5
6
|
|
6
7
|
def _len_check(d):
|
@@ -8,7 +9,7 @@ def _len_check(d):
|
|
8
9
|
logger.warning(f"Data length {len(d)} is suggested to be a power of 2")
|
9
10
|
|
10
11
|
|
11
|
-
def is_documented_by(original:Callable):
|
12
|
+
def is_documented_by(original: Callable):
|
12
13
|
def wrapper(target):
|
13
14
|
target.__doc__ = original.__doc__
|
14
15
|
return target
|