pywavelet 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/__init__.py +22 -0
- pywavelet/_version.py +9 -4
- pywavelet/backend.py +49 -27
- pywavelet/transforms/__init__.py +10 -4
- pywavelet/transforms/cupy/__init__.py +12 -0
- pywavelet/transforms/cupy/forward/__init__.py +3 -0
- pywavelet/transforms/cupy/forward/from_freq.py +92 -0
- pywavelet/transforms/cupy/forward/from_time.py +50 -0
- pywavelet/transforms/cupy/forward/main.py +106 -0
- pywavelet/transforms/cupy/inverse/__init__.py +3 -0
- pywavelet/transforms/cupy/inverse/main.py +67 -0
- pywavelet/transforms/cupy/inverse/to_freq.py +62 -0
- pywavelet/transforms/jax/forward/from_freq.py +6 -0
- pywavelet/transforms/jax/forward/from_time.py +18 -10
- pywavelet/transforms/jax/forward/main.py +6 -10
- pywavelet/transforms/jax/inverse/main.py +4 -6
- pywavelet/transforms/jax/inverse/to_freq.py +52 -34
- pywavelet/transforms/numpy/__init__.py +1 -2
- pywavelet/transforms/numpy/forward/from_freq.py +77 -19
- pywavelet/transforms/numpy/forward/main.py +1 -2
- pywavelet/transforms/numpy/inverse/main.py +4 -6
- pywavelet/transforms/numpy/inverse/to_freq.py +64 -1
- pywavelet/transforms/phi_computer.py +67 -86
- pywavelet/types/common.py +4 -3
- pywavelet/types/frequencyseries.py +1 -1
- pywavelet/types/plotting.py +14 -5
- pywavelet/types/timeseries.py +4 -10
- pywavelet/types/wavelet.py +6 -6
- pywavelet/types/wavelet_bins.py +0 -1
- pywavelet/utils.py +2 -0
- {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/METADATA +20 -9
- pywavelet-0.2.6.dist-info/RECORD +43 -0
- {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/WHEEL +1 -1
- pywavelet-0.2.4.dist-info/RECORD +0 -35
- {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/top_level.txt +0 -0
pywavelet/__init__.py
CHANGED
@@ -2,4 +2,26 @@
|
|
2
2
|
WDM Wavelet transform
|
3
3
|
"""
|
4
4
|
|
5
|
+
import importlib
|
6
|
+
import os
|
7
|
+
|
8
|
+
from . import backend as _backend
|
9
|
+
|
5
10
|
__version__ = "0.0.2"
|
11
|
+
|
12
|
+
|
13
|
+
def set_backend(backend: str):
    """Switch the array backend used by pywavelet.

    The requested name is exported through the ``PYWAVELET_BACKEND``
    environment variable, then the backend module and the ``types`` and
    ``transforms`` packages are reloaded, in that order.

    NOTE(review): ``importlib.reload`` on a package re-executes only that
    package's ``__init__``; confirm that submodules which were already
    imported pick up the new backend as intended.

    Parameters
    ----------
    backend : str
        Backend to use. Options are "numpy", "jax", "cupy".
    """
    # Imported lazily, inside the function, exactly as the packages are
    # only needed when a switch is requested.
    from . import types
    from . import transforms

    os.environ["PYWAVELET_BACKEND"] = backend

    # The backend module must be refreshed first so that the packages
    # re-executed afterwards see the new selection.
    for module in (_backend, types, transforms):
        importlib.reload(module)
|
pywavelet/_version.py
CHANGED
@@ -1,8 +1,13 @@
|
|
1
|
-
# file generated by
|
1
|
+
# file generated by setuptools-scm
|
2
2
|
# don't change, don't track in version control
|
3
|
+
|
4
|
+
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
|
5
|
+
|
3
6
|
TYPE_CHECKING = False
|
4
7
|
if TYPE_CHECKING:
|
5
|
-
from typing import Tuple
|
8
|
+
from typing import Tuple
|
9
|
+
from typing import Union
|
10
|
+
|
6
11
|
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
7
12
|
else:
|
8
13
|
VERSION_TUPLE = object
|
@@ -12,5 +17,5 @@ __version__: str
|
|
12
17
|
__version_tuple__: VERSION_TUPLE
|
13
18
|
version_tuple: VERSION_TUPLE
|
14
19
|
|
15
|
-
__version__ = version = '0.2.
|
16
|
-
__version_tuple__ = version_tuple = (0, 2,
|
20
|
+
__version__ = version = '0.2.6'
|
21
|
+
__version_tuple__ = version_tuple = (0, 2, 6)
|
pywavelet/backend.py
CHANGED
@@ -1,31 +1,53 @@
|
|
1
|
+
import importlib
|
1
2
|
import os
|
2
3
|
|
3
4
|
from .logger import logger
|
4
5
|
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
6
|
+
JAX = "jax"
CUPY = "cupy"
NUMPY = "numpy"


def get_backend_from_env():
    """Select and return the array backend named by ``PYWAVELET_BACKEND``.

    The environment variable is read case-insensitively and defaults to
    NumPy when unset. If JAX or CuPy is requested but cannot be found, or
    if the name is not recognised, a warning is logged and the NumPy
    backend is used instead.

    Returns
    -------
    tuple
        ``(xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend)`` where
        ``xp`` is the array module, the next five entries are its FFT
        routines, ``betainc`` is the matching incomplete-beta function and
        ``backend`` is the name of the backend that was *actually* loaded
        (so it is ``"numpy"`` after a fallback, never the unavailable name).
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly before using
    # ``find_spec``.
    import importlib.util

    requested = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()

    if requested == JAX:
        if importlib.util.find_spec(JAX):
            import jax.numpy as xp
            from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
            from jax.scipy.special import betainc

            logger.info("Using JAX backend")
            return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, JAX
        logger.warning(
            "JAX backend requested but not installed. Falling back to NumPy."
        )
    elif requested == CUPY:
        if importlib.util.find_spec(CUPY):
            import cupy as xp
            from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
            from cupyx.scipy.special import betainc

            logger.info("Using CuPy backend")
            return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, CUPY
        logger.warning(
            "CuPy backend requested but not installed. Falling back to NumPy."
        )
    elif requested != NUMPY:
        logger.warning(
            "Unknown backend %r requested. Falling back to NumPy.", requested
        )

    # Default to NumPy
    import numpy as xp
    from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
    from scipy.special import betainc

    logger.info("Using NumPy+Numba backend")
    # Report the backend that is really in use, so that callers dispatching
    # on the returned name do not try to import an unavailable subpackage.
    return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, NUMPY
|
48
|
+
|
49
|
+
|
50
|
+
# Get the chosen backend
|
51
|
+
xp, fft, ifft, irfft, rfft, rfftfreq, betainc, current_backend = (
|
52
|
+
get_backend_from_env()
|
53
|
+
)
|
pywavelet/transforms/__init__.py
CHANGED
@@ -1,12 +1,19 @@
|
|
1
|
-
from ..backend import
|
1
|
+
from ..backend import current_backend
|
2
2
|
|
3
|
-
if
|
3
|
+
if current_backend == "jax":
|
4
4
|
from .jax import (
|
5
5
|
from_freq_to_wavelet,
|
6
6
|
from_time_to_wavelet,
|
7
7
|
from_wavelet_to_freq,
|
8
8
|
from_wavelet_to_time,
|
9
9
|
)
|
10
|
+
elif current_backend == "cupy":
|
11
|
+
from .cupy import (
|
12
|
+
from_freq_to_wavelet,
|
13
|
+
from_time_to_wavelet,
|
14
|
+
from_wavelet_to_freq,
|
15
|
+
from_wavelet_to_time,
|
16
|
+
)
|
10
17
|
else:
|
11
18
|
from .numpy import (
|
12
19
|
from_wavelet_to_time,
|
@@ -15,7 +22,7 @@ else:
|
|
15
22
|
from_freq_to_wavelet,
|
16
23
|
)
|
17
24
|
|
18
|
-
from .phi_computer import
|
25
|
+
from .phi_computer import omega, phi_vec, phitilde_vec_norm
|
19
26
|
|
20
27
|
__all__ = [
|
21
28
|
"from_wavelet_to_time",
|
@@ -24,5 +31,4 @@ __all__ = [
|
|
24
31
|
"from_freq_to_wavelet",
|
25
32
|
"phitilde_vec_norm",
|
26
33
|
"phi_vec",
|
27
|
-
"phitilde_vec",
|
28
34
|
]
|
@@ -0,0 +1,12 @@
|
|
1
|
+
from ...logger import logger
|
2
|
+
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
3
|
+
from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
4
|
+
|
5
|
+
logger.warning("CUPY SUBPACKAGE NOT FULLY TESTED")
|
6
|
+
|
7
|
+
__all__ = [
|
8
|
+
"from_wavelet_to_time",
|
9
|
+
"from_wavelet_to_freq",
|
10
|
+
"from_time_to_wavelet",
|
11
|
+
"from_freq_to_wavelet",
|
12
|
+
]
|
@@ -0,0 +1,92 @@
|
|
1
|
+
import cupy as cp
|
2
|
+
from cupyx.scipy.fft import ifft
|
3
|
+
|
4
|
+
|
5
|
+
import logging
|
6
|
+
|
7
|
+
logger = logging.getLogger('pywavelet')
|
8
|
+
|
9
|
+
|
10
|
+
def transform_wavelet_freq_helper(
    data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray
) -> cp.ndarray:
    """
    Transforms input data from the frequency domain to the wavelet domain using a
    pre-computed wavelet filter (`phif`) and one batched inverse FFT.

    Parameters:
    - data (cp.ndarray): 1D array representing the input data in the frequency domain.
    - Nf (int): Number of frequency bins.
    - Nt (int): Number of time bins.
    - phif (cp.ndarray): Pre-computed wavelet filter; indexed here from 0 up to Nt // 2 - 1.

    Returns:
    - wave (cp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).

    Notes:
    - The FFT work buffer is at least complex64 and is promoted to match the
      precision of `data` and `phif`, so double-precision input is no longer
      truncated to single precision.
    - NOTE(review): only rows 0..Nf-1 are built, so there is no separate
      Nyquist row; the last row is written into the odd time slots of the
      last column. Confirm this matches the reference transform.
    """

    logger.debug(f"[CUPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]")

    half = Nt // 2
    f_bins = cp.arange(Nf)  # Frequency bin indices
    jj_base = f_bins * Nt // 2  # Index in `data` of the centre of each bin

    # One row of FFT input per frequency bin.
    work_dtype = cp.result_type(data.dtype, phif.dtype, cp.complex64)
    DX = cp.zeros((Nf, Nt), dtype=work_dtype)

    # Centre (zero-offset) sample of every row; the DC row is halved.
    centre = phif[0] * data[jj_base]
    DX[:, half] = cp.where(f_bins == 0, centre / 2.0, centre)

    # Offsets around the centre and the columns of DX they map to.
    j_range = cp.arange(1 - half, half)
    cols = half + j_range

    # For the DC row, negative offsets would wrap around `data`; zero them.
    dc_negative = (f_bins[:, None] == 0) & (j_range[None, :] < 0)
    jj = jj_base[:, None] + j_range[None, :]
    val = cp.where(dc_negative, 0.0, phif[cp.abs(j_range)] * data[jj])

    # Keep the centre column set above, fill every other offset.
    DX[:, cols] = cp.where(j_range[None, :] == 0, DX[:, cols], val)

    # Inverse FFT along the time dimension
    DX_trans = ifft(DX, axis=1)

    # Pick real or imaginary parts depending on the parity of (n + m) and m.
    n_range = cp.arange(Nt)
    take_imag = (n_range[:, None] + f_bins[None, :]) % 2 == 1  # (Nt, Nf)
    odd_bin = cp.expand_dims(f_bins % 2 == 1, axis=-1)  # (Nf, 1)

    real_part = cp.where(odd_bin, -cp.imag(DX_trans), cp.real(DX_trans))
    imag_part = cp.where(odd_bin, cp.real(DX_trans), cp.imag(DX_trans))
    wave = cp.where(take_imag, imag_part.T, real_part.T)

    # Special handling of the first and last rows (see NOTE(review) above).
    wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2))
    wave[1::2, -1] = cp.real(DX_trans[-1, ::2] * cp.sqrt(2))

    # Return the wavelet-transformed array (transposed for freq-major layout)
    return wave.T  # (Nt, Nf) -> (Nf, Nt)
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import cupy as cp
|
2
|
+
from cupyx.scipy.fft import rfft
|
3
|
+
|
4
|
+
|
5
|
+
def transform_wavelet_time_helper(
    data: cp.ndarray, phi: cp.ndarray, Nf: int, Nt: int, mult: int
) -> cp.ndarray:
    """Time-domain wavelet transform helper (CuPy).

    For every time bin a window of ``K = 2 * mult * Nf`` samples is gathered
    from ``data`` (indices wrap modulo ``Nf * Nt``), multiplied by ``phi``
    and passed through a real FFT. Real and imaginary parts of selected FFT
    bins are then written into the output grid.

    Returns the coefficients with shape ``(Nf, Nt)``.
    """
    n_total = Nf * Nt
    window_len = mult * 2 * Nf

    # Data followed by a copy of its first ``window_len`` samples.
    padded = cp.concatenate((data, data[:window_len]))

    # Row t holds the (wrapped) sample indices of the window for time bin t.
    t_idx = cp.arange(Nt)
    starts = (t_idx[:, None] * Nf - window_len // 2) % n_total
    gather = (starts + cp.arange(window_len)[None, :]) % n_total

    # Windowed data and its FFT, one row per time bin.
    spectrum = rfft(padded[gather] * phi[None, :], axis=1)

    wave = cp.zeros((Nt, Nf))

    # Column 0: even time bins (excluding the last bin) fill their own slot
    # from FFT bin 0 and the following slot from FFT bin ``Nf * mult``.
    evens = cp.nonzero((t_idx % 2 == 0) & (t_idx < Nt - 1))[0]
    wave[evens, 0] = cp.real(spectrum[evens, 0]) / cp.sqrt(2)
    wave[evens + 1, 0] = cp.real(spectrum[evens, Nf * mult]) / cp.sqrt(2)

    # Columns 1..Nf-1: minus the imaginary part where (t + m) is odd,
    # the real part otherwise.
    m_idx = cp.arange(1, Nf)
    use_imag = (t_idx[:, None] + m_idx[None, :]) % 2 == 1
    wave[:, 1:] = cp.where(
        use_imag,
        -cp.imag(spectrum[:, m_idx * mult]),
        cp.real(spectrum[:, m_idx * mult]),
    )

    return wave.T
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from typing import Union
|
2
|
+
|
3
|
+
import cupy as cp
|
4
|
+
|
5
|
+
from ....logger import logger
|
6
|
+
from ....types import FrequencySeries, TimeSeries, Wavelet
|
7
|
+
from ....types.wavelet_bins import _get_bins, _preprocess_bins
|
8
|
+
from ...phi_computer import phi_vec, phitilde_vec_norm
|
9
|
+
from .from_freq import transform_wavelet_freq_helper
|
10
|
+
from .from_time import transform_wavelet_time_helper
|
11
|
+
|
12
|
+
|
13
|
+
def from_time_to_wavelet(
    timeseries: TimeSeries,
    Nf: Union[int, None] = None,
    Nt: Union[int, None] = None,
    nx: float = 4.0,
    mult: int = 32,
    **kwargs,
) -> Wavelet:
    """Transforms time-domain data to wavelet-domain data.

    Warning: there can be significant leakage if mult is too small and the
    transform is only approximately exact if mult=Nt/2

    Parameters
    ----------
    timeseries : TimeSeries
        Time domain data
    Nf : int
        Number of frequency bins
    Nt : int
        Number of time bins
    nx : float, optional
        Number of standard deviations for the phi_vec, by default 4.
    mult : int, optional
        Number of time bins to use for the wavelet transform, by default 32.
        Values above ``Nt // 2`` are reduced to ``Nt // 2`` (with a warning).
    **kwargs:
        Accepted for signature compatibility; not used by this function.

    Returns
    -------
    Wavelet
        Wavelet domain data

    """
    Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
    t_bins, f_bins = _get_bins(timeseries, Nf, Nt)

    ND = Nf * Nt

    if len(timeseries) != ND:
        # The message names the time series, which is what is being truncated.
        logger.warning(
            f"len(timeseries)={len(timeseries)} != Nf*Nt={ND}. Truncating to timeseries[:{ND}]"
        )
        timeseries = timeseries[:ND]
    if mult > Nt / 2:
        logger.warning(
            f"mult={mult} is too large for Nt={Nt}. This may lead to bogus results."
        )

    mult = min(mult, Nt // 2)  # make sure K isn't bigger than ND
    phi = cp.array(phi_vec(Nf, d=nx, q=mult))
    wave = transform_wavelet_time_helper(
        cp.array(timeseries.data), Nf=Nf, Nt=Nt, phi=phi, mult=mult
    )
    return Wavelet(wave * cp.sqrt(2), time=t_bins, freq=f_bins)
|
69
|
+
|
70
|
+
|
71
|
+
def from_freq_to_wavelet(
    freqseries: FrequencySeries,
    Nf: Union[int, None] = None,
    Nt: Union[int, None] = None,
    nx: float = 4.0,
    **kwargs,
) -> Wavelet:
    """Convert frequency-domain data into the wavelet domain (CuPy backend).

    Parameters
    ----------
    freqseries : FrequencySeries
        Frequency domain data
    Nf : int
        Number of frequency bins
    Nt : int
        Number of time bins
    nx : float, optional
        Passed as ``d`` to ``phitilde_vec_norm``, by default 4.
    **kwargs:
        Not used in the body of this function.

    Returns
    -------
    Wavelet
        Wavelet domain data, scaled by ``(2 / Nf) * sqrt(2)``.

    """
    Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
    t_bins, f_bins = _get_bins(freqseries, Nf, Nt)

    # Move the filter and the input spectrum into CuPy arrays.
    window = cp.array(phitilde_vec_norm(Nf, Nt, d=nx))
    spectrum = cp.array(freqseries.data)

    coeffs = transform_wavelet_freq_helper(
        spectrum, Nf=Nf, Nt=Nt, phif=window
    )

    scaled = (2 / Nf) * coeffs * cp.sqrt(2)
    return Wavelet(scaled, time=t_bins, freq=f_bins)
|
@@ -0,0 +1,67 @@
|
|
1
|
+
import cupy as cp
|
2
|
+
from cupyx.scipy.fft import rfftfreq
|
3
|
+
|
4
|
+
from ....types import FrequencySeries, TimeSeries, Wavelet
|
5
|
+
from ...phi_computer import phi_vec, phitilde_vec_norm
|
6
|
+
from .to_freq import inverse_wavelet_freq_helper
|
7
|
+
|
8
|
+
|
9
|
+
def from_wavelet_to_time(
    wave_in: Wavelet,
    dt: float,
    nx: float = 4.0,
    mult: int = 32,
) -> TimeSeries:
    """Inverse wavelet transform to the time domain.

    The wavelet is first taken to the frequency domain with
    ``from_wavelet_to_freq`` and the result is converted with its
    ``to_timeseries`` method.

    Parameters
    ----------
    wave_in : Wavelet
        input wavelet
    dt : float
        time step
    nx : float, optional
        forwarded to ``from_wavelet_to_freq``, by default 4.0
    mult : int, optional
        not used in the body of this function, by default 32

    Returns
    -------
    TimeSeries
        Time domain signal
    """
    return from_wavelet_to_freq(wave_in, dt=dt, nx=nx).to_timeseries()
|
35
|
+
|
36
|
+
|
37
|
+
def from_wavelet_to_freq(
    wave_in: Wavelet, dt: float, nx=4.0
) -> FrequencySeries:
    """Inverse wavelet transform to frequency domain.

    Parameters
    ----------
    wave_in : Wavelet
        input wavelet
    dt : float
        time step
    nx : float, optional
        parameter for phitilde_vec_norm, by default 4.0

    Returns
    -------
    FrequencySeries
        Frequency domain signal

    """
    phif = cp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx))

    # The helper indexes its input with CuPy index arrays, so make sure the
    # coefficients are a CuPy array (as the forward transforms do for their
    # inputs). ``asarray`` does not copy data that is already a CuPy array.
    freq_data = inverse_wavelet_freq_helper(
        cp.asarray(wave_in.data), phif=phif, Nf=wave_in.Nf, Nt=wave_in.Nt
    )

    # Normalise to get the proper backwards transformation
    freq_data *= 2 ** (-1 / 2)

    # Frequencies of a length-2*ND real FFT with the zero frequency dropped.
    freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
    return FrequencySeries(data=freq_data, freq=freqs)
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import cupy as cp
|
2
|
+
from cupyx.scipy.fft import fft
|
3
|
+
|
4
|
+
|
5
|
+
def inverse_wavelet_freq_helper(
    wave_in: cp.ndarray, phif: cp.ndarray, Nf: int, Nt: int
) -> cp.ndarray:
    """CuPy vectorized function for inverse_wavelet_freq.

    Parameters
    ----------
    wave_in : cp.ndarray
        Wavelet coefficients, indexed ``[m, n]`` (it is transposed on entry).
    phif : cp.ndarray
        Frequency-domain filter; entries ``0 .. Nt // 2`` are read.
    Nf : int
        Number of frequency bins.
    Nt : int
        Number of time bins (assumed even and >= 2 -- TODO confirm).

    Returns
    -------
    cp.ndarray
        Complex array of length ``Nf * Nt``.

    Notes
    -----
    Every (m, offset) pair contributes to the result exactly once. Each
    scatter below uses an index set without repeated entries, because
    ``res[idx] += values`` does not accumulate over repeated indices (and in
    CuPy the stored value is then undefined).
    """
    wave_in = wave_in.T
    ND = Nf * Nt
    half = Nt // 2

    n_range = cp.arange(Nt)
    prefactor2s = cp.zeros((Nf + 1, Nt), dtype=cp.complex128)

    # m == 0 case
    prefactor2s[0] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]

    # m == Nf case
    prefactor2s[Nf] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]

    # Other m cases: multiply by -1j where (n + m) is odd.
    m_mid = cp.arange(1, Nf)
    n_grid, m_grid = cp.meshgrid(n_range, m_mid)
    mult2 = cp.where((n_grid + m_grid) % 2, -1j, 1)
    prefactor2s[1:Nf] = mult2 * wave_in[n_grid, m_grid]

    # Vectorized FFT over the time index of every row
    fft_prefactor2s = fft(prefactor2s, axis=1)

    res = cp.zeros(ND, dtype=cp.complex128)

    # m == 0 or m == Nf cases (index sets are unique within each statement)
    offsets = cp.arange(half)
    i_Nf = cp.abs(Nf * Nt // 2 - offsets)
    res[offsets] += fft_prefactor2s[0, (2 * offsets) % Nt] * phif[offsets]
    res[i_Nf] += fft_prefactor2s[Nf, (2 * i_Nf) % Nt] * phif[offsets]

    # Special case for m == Nf
    res[Nf * Nt // 2] += fft_prefactor2s[Nf, 0] * phif[Nt // 2]

    # Other m cases. Row m is centred on index ``half * m`` and spreads over
    # offsets -half .. +half. The offsets are split into three groups so
    # that no index is repeated inside a single scatter and the zero offset
    # is counted once.
    m_col = m_mid[None, :]  # (1, Nf - 1)

    # Offsets -1 .. -half
    below = cp.arange(1, half + 1)[:, None]
    idx = half * m_col - below
    res[idx] += fft_prefactor2s[m_col, idx % Nt] * phif[below]

    # Offsets 0 .. half - 1
    above = cp.arange(half)[:, None]
    idx = half * m_col + above
    res[idx] += fft_prefactor2s[m_col, idx % Nt] * phif[above]

    # Offset +half (lands on the centre of the next row)
    idx = half * (m_mid + 1)
    res[idx] += fft_prefactor2s[m_mid, idx % Nt] * phif[half]

    return res
|
@@ -4,6 +4,10 @@ import jax.numpy as jnp
|
|
4
4
|
from jax import jit
|
5
5
|
from jax.numpy.fft import ifft
|
6
6
|
|
7
|
+
import logging
|
8
|
+
|
9
|
+
logger = logging.getLogger('pywavelet')
|
10
|
+
|
7
11
|
|
8
12
|
@partial(jit, static_argnames=("Nf", "Nt"))
|
9
13
|
def transform_wavelet_freq_helper(
|
@@ -23,6 +27,8 @@ def transform_wavelet_freq_helper(
|
|
23
27
|
- wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
|
24
28
|
"""
|
25
29
|
|
30
|
+
logger.debug(f"[JAX TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]")
|
31
|
+
|
26
32
|
# Initialize the wavelet output array with zeros (time-rows, frequency-columns)
|
27
33
|
wave = jnp.zeros((Nt, Nf))
|
28
34
|
f_bins = jnp.arange(Nf) # Frequency bin indices
|
@@ -1,12 +1,14 @@
|
|
1
|
+
from functools import partial
|
2
|
+
|
1
3
|
import jax
|
2
4
|
import jax.numpy as jnp
|
3
5
|
from jax import jit
|
4
6
|
from jax.numpy.fft import rfft
|
5
|
-
from functools import partial
|
6
7
|
|
7
|
-
|
8
|
+
|
9
|
+
@partial(jit, static_argnames=("Nf", "Nt", "mult"))
|
8
10
|
def transform_wavelet_time_helper(
|
9
|
-
data: jnp.ndarray, phi: jnp.ndarray, Nf: int, Nt: int,
|
11
|
+
data: jnp.ndarray, phi: jnp.ndarray, Nf: int, Nt: int, mult: int
|
10
12
|
) -> jnp.ndarray:
|
11
13
|
"""Helper function to do the wavelet transform in the time domain using JAX"""
|
12
14
|
# Define constants
|
@@ -35,17 +37,23 @@ def transform_wavelet_time_helper(
|
|
35
37
|
even_indices = jnp.nonzero(even_mask, size=even_mask.shape[0])[0]
|
36
38
|
|
37
39
|
# Update wave for m=0 using even time bins
|
38
|
-
wave = wave.at[even_indices, 0].set(
|
39
|
-
|
40
|
+
wave = wave.at[even_indices, 0].set(
|
41
|
+
jnp.real(wdata_trans[even_indices, 0]) / jnp.sqrt(2)
|
42
|
+
)
|
43
|
+
wave = wave.at[even_indices + 1, 0].set(
|
44
|
+
jnp.real(wdata_trans[even_indices, Nf * mult]) / jnp.sqrt(2)
|
45
|
+
)
|
40
46
|
|
41
47
|
# Handle other cases (j > 0) using vectorized operations
|
42
48
|
j_range = jnp.arange(1, Nf)
|
43
|
-
odd_condition = (
|
49
|
+
odd_condition = (time_bins[:, None] + j_range[None, :]) % 2 == 1
|
44
50
|
|
45
51
|
wave = wave.at[:, 1:].set(
|
46
|
-
jnp.where(
|
47
|
-
|
48
|
-
|
52
|
+
jnp.where(
|
53
|
+
odd_condition,
|
54
|
+
-jnp.imag(wdata_trans[:, j_range * mult]),
|
55
|
+
jnp.real(wdata_trans[:, j_range * mult]),
|
56
|
+
)
|
49
57
|
)
|
50
58
|
|
51
|
-
return wave.T
|
59
|
+
return wave.T
|
@@ -62,10 +62,10 @@ def from_time_to_wavelet(
|
|
62
62
|
|
63
63
|
mult = min(mult, Nt // 2) # make sure K isn't bigger than ND
|
64
64
|
phi = jnp.array(phi_vec(Nf, d=nx, q=mult))
|
65
|
-
wave = transform_wavelet_time_helper(
|
66
|
-
|
67
|
-
wave* jnp.sqrt(2), time=t_bins, freq=f_bins
|
65
|
+
wave = transform_wavelet_time_helper(
|
66
|
+
timeseries.data, Nf=Nf, Nt=Nt, phi=phi, mult=mult
|
68
67
|
)
|
68
|
+
return Wavelet(wave * jnp.sqrt(2), time=t_bins, freq=f_bins)
|
69
69
|
|
70
70
|
|
71
71
|
def from_freq_to_wavelet(
|
@@ -98,13 +98,9 @@ def from_freq_to_wavelet(
|
|
98
98
|
"""
|
99
99
|
Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
|
100
100
|
t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
|
101
|
-
phif = jnp.array(phitilde_vec_norm(Nf, Nt,
|
102
|
-
wave =
|
101
|
+
phif = jnp.array(phitilde_vec_norm(Nf, Nt, d=nx))
|
102
|
+
wave = transform_wavelet_freq_helper(
|
103
103
|
freqseries.data, Nf=Nf, Nt=Nt, phif=phif
|
104
104
|
)
|
105
105
|
|
106
|
-
return Wavelet(
|
107
|
-
(2 / Nf) * wave * jnp.sqrt(2),
|
108
|
-
time=t_bins,
|
109
|
-
freq=f_bins
|
110
|
-
)
|
106
|
+
return Wavelet((2 / Nf) * wave * jnp.sqrt(2), time=t_bins, freq=f_bins)
|