pywavelet 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/__init__.py +17 -5
- pywavelet/_version.py +2 -2
- pywavelet/backend.py +98 -20
- pywavelet/logger.py +6 -6
- pywavelet/transforms/__init__.py +1 -3
- pywavelet/transforms/cupy/forward/from_freq.py +64 -67
- pywavelet/transforms/cupy/forward/main.py +11 -7
- pywavelet/transforms/cupy/inverse/to_freq.py +54 -50
- pywavelet/transforms/jax/forward/from_freq.py +69 -76
- pywavelet/transforms/jax/forward/main.py +9 -6
- pywavelet/transforms/jax/inverse/to_freq.py +17 -28
- pywavelet/transforms/numpy/forward/from_freq.py +14 -6
- pywavelet/transforms/numpy/forward/main.py +13 -4
- pywavelet/transforms/phi_computer.py +35 -20
- pywavelet/types/common.py +1 -1
- pywavelet/types/plotting.py +1 -1
- pywavelet/types/timeseries.py +1 -0
- pywavelet/types/wavelet.py +4 -2
- pywavelet/types/wavelet_bins.py +3 -9
- pywavelet/utils/__init__.py +6 -0
- pywavelet/{utils.py → utils/analysis.py} +1 -1
- pywavelet/utils/timing_cli/__init__.py +0 -0
- pywavelet/utils/timing_cli/cli.py +95 -0
- pywavelet/utils/timing_cli/collect_runtimes.py +192 -0
- pywavelet/utils/timing_cli/plot.py +76 -0
- {pywavelet-0.2.7.dist-info → pywavelet-0.2.8.dist-info}/METADATA +3 -1
- pywavelet-0.2.8.dist-info/RECORD +49 -0
- {pywavelet-0.2.7.dist-info → pywavelet-0.2.8.dist-info}/WHEEL +1 -1
- pywavelet-0.2.8.dist-info/entry_points.txt +2 -0
- pywavelet-0.2.7.dist-info/RECORD +0 -43
- {pywavelet-0.2.7.dist-info → pywavelet-0.2.8.dist-info}/top_level.txt +0 -0
pywavelet/__init__.py
CHANGED
@@ -6,11 +6,10 @@ import importlib
|
|
6
6
|
import os
|
7
7
|
|
8
8
|
from . import backend as _backend
|
9
|
+
from ._version import __version__
|
9
10
|
|
10
|
-
__version__ = "0.0.2"
|
11
11
|
|
12
|
-
|
13
|
-
def set_backend(backend: str):
|
12
|
+
def set_backend(backend: str, precision: str = "float32") -> None:
|
14
13
|
"""Set the backend for the wavelet transform.
|
15
14
|
|
16
15
|
Parameters
|
@@ -18,10 +17,23 @@ def set_backend(backend: str):
|
|
18
17
|
backend : str
|
19
18
|
Backend to use. Options are "numpy", "jax", "cupy".
|
20
19
|
"""
|
21
|
-
from . import types
|
22
|
-
|
20
|
+
from . import transforms, types
|
21
|
+
|
23
22
|
os.environ["PYWAVELET_BACKEND"] = backend
|
23
|
+
os.environ["PYWAVELET_PRECISION"] = precision
|
24
24
|
|
25
25
|
importlib.reload(_backend)
|
26
26
|
importlib.reload(types)
|
27
27
|
importlib.reload(transforms)
|
28
|
+
if backend == "cupy":
|
29
|
+
importlib.reload(transforms.cupy)
|
30
|
+
importlib.reload(transforms.cupy.forward)
|
31
|
+
importlib.reload(transforms.cupy.inverse)
|
32
|
+
elif backend == "jax":
|
33
|
+
importlib.reload(transforms.jax)
|
34
|
+
importlib.reload(transforms.jax.forward)
|
35
|
+
importlib.reload(transforms.jax.inverse)
|
36
|
+
else:
|
37
|
+
importlib.reload(transforms.numpy)
|
38
|
+
importlib.reload(transforms.numpy.forward)
|
39
|
+
importlib.reload(transforms.numpy.inverse)
|
pywavelet/_version.py
CHANGED
pywavelet/backend.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
import importlib
|
2
2
|
import os
|
3
|
-
from
|
4
|
-
from rich.console import Console
|
5
|
-
|
3
|
+
from typing import Tuple
|
6
4
|
|
5
|
+
import numpy as np
|
6
|
+
from rich.console import Console
|
7
|
+
from rich.table import Table, Text
|
7
8
|
|
8
9
|
from .logger import logger
|
9
10
|
|
@@ -11,22 +12,25 @@ JAX = "jax"
|
|
11
12
|
CUPY = "cupy"
|
12
13
|
NUMPY = "numpy"
|
13
14
|
|
15
|
+
VALID_PRECISIONS = ["float32", "float64"]
|
16
|
+
|
14
17
|
|
15
18
|
def cuda_is_available() -> bool:
|
16
19
|
"""Check if CUDA is available."""
|
17
20
|
# Check if CuPy is available and CUDA is accessible
|
18
21
|
cupy_available = importlib.util.find_spec("cupy") is not None
|
22
|
+
_cuda_available = False
|
19
23
|
if cupy_available:
|
20
24
|
import cupy
|
21
25
|
|
22
26
|
try:
|
23
27
|
cupy.cuda.runtime.getDeviceCount() # Check if any CUDA device is available
|
24
|
-
|
28
|
+
_cuda_available = True
|
25
29
|
except cupy.cuda.runtime.CUDARuntimeError:
|
26
|
-
|
30
|
+
_cuda_available = False
|
27
31
|
else:
|
28
|
-
|
29
|
-
return
|
32
|
+
_cuda_available = False
|
33
|
+
return _cuda_available
|
30
34
|
|
31
35
|
|
32
36
|
def jax_is_available() -> bool:
|
@@ -51,9 +55,21 @@ def get_available_backends_table():
|
|
51
55
|
return Text.from_ansi(capture.get())
|
52
56
|
|
53
57
|
|
58
|
+
def log_backend(level="info"):
|
59
|
+
"""Print the current backend and precision."""
|
60
|
+
backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
|
61
|
+
precision = os.getenv("PYWAVELET_PRECISION", "float32").lower()
|
62
|
+
str = f"Current backend: {backend}[{precision}]"
|
63
|
+
if level == "info":
|
64
|
+
logger.info(str)
|
65
|
+
elif level == "debug":
|
66
|
+
logger.debug(str)
|
67
|
+
|
68
|
+
|
54
69
|
def get_backend_from_env():
|
55
70
|
"""Select and return the appropriate backend module."""
|
56
71
|
backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
|
72
|
+
precision = os.getenv("PYWAVELET_PRECISION", "float32").lower()
|
57
73
|
|
58
74
|
if backend == JAX and jax_is_available():
|
59
75
|
|
@@ -61,41 +77,103 @@ def get_backend_from_env():
|
|
61
77
|
from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
|
62
78
|
from jax.scipy.special import betainc
|
63
79
|
|
64
|
-
logger.info("Using JAX backend")
|
65
|
-
|
66
80
|
elif backend == CUPY and cuda_is_available():
|
67
81
|
|
68
82
|
import cupy as xp
|
69
83
|
from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
|
70
84
|
from cupyx.scipy.special import betainc
|
71
85
|
|
72
|
-
logger.info("Using CuPy backend")
|
73
|
-
|
74
86
|
elif backend == NUMPY:
|
75
87
|
import numpy as xp
|
76
88
|
from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
|
77
89
|
from scipy.special import betainc
|
78
90
|
|
79
|
-
logger.info("Using NumPy backend")
|
80
|
-
|
81
|
-
|
82
91
|
else:
|
83
|
-
logger.error(
|
84
|
-
f"Backend {backend} is not available. "
|
85
|
-
)
|
92
|
+
logger.error(f"Backend {backend}[{precision}] is not available. ")
|
86
93
|
print(get_available_backends_table())
|
87
|
-
logger.warning(
|
88
|
-
f"Setting backend to NumPy. "
|
89
|
-
)
|
94
|
+
logger.warning(f"Setting backend to NumPy. ")
|
90
95
|
os.environ["PYWAVELET_BACKEND"] = NUMPY
|
91
96
|
return get_backend_from_env()
|
92
97
|
|
98
|
+
log_backend("debug")
|
93
99
|
return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
|
94
100
|
|
95
101
|
|
102
|
+
def get_precision_from_env() -> str:
|
103
|
+
"""Get the precision from the environment variable."""
|
104
|
+
precision = os.getenv("PYWAVELET_PRECISION", "float32").lower()
|
105
|
+
if precision not in VALID_PRECISIONS:
|
106
|
+
logger.error(
|
107
|
+
f"Precision {precision} is not supported, defaulting to float32."
|
108
|
+
)
|
109
|
+
precision = "float32"
|
110
|
+
return precision
|
111
|
+
|
112
|
+
|
113
|
+
def set_precision(precision: str) -> None:
|
114
|
+
"""Set the precision for the backend."""
|
115
|
+
precision = precision.lower()
|
116
|
+
if precision not in VALID_PRECISIONS:
|
117
|
+
logger.error(f"Precision {precision} is not supported.")
|
118
|
+
return
|
119
|
+
else:
|
120
|
+
os.environ["PYWAVELET_PRECISION"] = precision
|
121
|
+
logger.info(f"Setting precision to {precision}.")
|
122
|
+
return
|
123
|
+
|
124
|
+
|
125
|
+
def get_dtype_from_env() -> Tuple[np.dtype, np.dtype]:
|
126
|
+
"""Get the data type from the environment variable."""
|
127
|
+
precision = get_precision_from_env()
|
128
|
+
backend = get_backend_from_env()[-1]
|
129
|
+
if backend == JAX:
|
130
|
+
|
131
|
+
if precision == "float32":
|
132
|
+
import jax
|
133
|
+
|
134
|
+
jax.config.update("jax_enable_x64", False)
|
135
|
+
|
136
|
+
import jax.numpy as jnp
|
137
|
+
|
138
|
+
float_dtype = jnp.float32
|
139
|
+
complex_dtype = jnp.complex64
|
140
|
+
elif precision == "float64":
|
141
|
+
import jax
|
142
|
+
|
143
|
+
jax.config.update("jax_enable_x64", True)
|
144
|
+
|
145
|
+
import jax.numpy as jnp
|
146
|
+
|
147
|
+
float_dtype = jnp.float64
|
148
|
+
complex_dtype = jnp.complex128
|
149
|
+
|
150
|
+
elif backend == CUPY:
|
151
|
+
import cupy as cp
|
152
|
+
|
153
|
+
if precision == "float32":
|
154
|
+
float_dtype = cp.float32
|
155
|
+
complex_dtype = cp.complex64
|
156
|
+
elif precision == "float64":
|
157
|
+
float_dtype = cp.float64
|
158
|
+
complex_dtype = cp.complex128
|
159
|
+
|
160
|
+
else:
|
161
|
+
if precision == "float32":
|
162
|
+
float_dtype = np.float32
|
163
|
+
complex_dtype = np.complex64
|
164
|
+
elif precision == "float64":
|
165
|
+
float_dtype = np.float64
|
166
|
+
complex_dtype = np.complex128
|
167
|
+
|
168
|
+
return float_dtype, complex_dtype
|
169
|
+
|
170
|
+
|
96
171
|
cuda_available = cuda_is_available()
|
97
172
|
|
98
173
|
# Get the chosen backend
|
99
174
|
xp, fft, ifft, irfft, rfft, rfftfreq, betainc, current_backend = (
|
100
175
|
get_backend_from_env()
|
101
176
|
)
|
177
|
+
|
178
|
+
# Get the chosen precision
|
179
|
+
float_dtype, complex_dtype = get_dtype_from_env()
|
pywavelet/logger.py
CHANGED
@@ -5,14 +5,14 @@ import warnings
|
|
5
5
|
from rich.logging import RichHandler
|
6
6
|
|
7
7
|
FORMAT = "%(message)s"
|
8
|
-
logging.basicConfig(
|
9
|
-
level="INFO",
|
10
|
-
format=FORMAT,
|
11
|
-
datefmt="[%X]",
|
12
|
-
handlers=[RichHandler(rich_tracebacks=True)],
|
13
|
-
)
|
14
8
|
|
15
9
|
logger = logging.getLogger("pywavelet")
|
10
|
+
if not logger.hasHandlers():
|
11
|
+
handler = RichHandler(rich_tracebacks=True)
|
12
|
+
formatter = logging.Formatter(fmt=FORMAT, datefmt="[%X]")
|
13
|
+
handler.setFormatter(formatter)
|
14
|
+
logger.addHandler(handler)
|
15
|
+
logger.setLevel(logging.INFO)
|
16
16
|
|
17
17
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
18
18
|
warnings.filterwarnings("ignore", category=UserWarning)
|
pywavelet/transforms/__init__.py
CHANGED
@@ -1,14 +1,18 @@
|
|
1
|
+
import logging
|
2
|
+
|
1
3
|
import cupy as cp
|
2
4
|
from cupyx.scipy.fft import ifft
|
3
5
|
|
4
|
-
|
5
|
-
import logging
|
6
|
-
|
7
|
-
logger = logging.getLogger('pywavelet')
|
6
|
+
logger = logging.getLogger("pywavelet")
|
8
7
|
|
9
8
|
|
10
9
|
def transform_wavelet_freq_helper(
|
11
|
-
data: cp.ndarray,
|
10
|
+
data: cp.ndarray,
|
11
|
+
Nf: int,
|
12
|
+
Nt: int,
|
13
|
+
phif: cp.ndarray,
|
14
|
+
float_dtype=cp.float64,
|
15
|
+
complex_dtype=cp.complex128,
|
12
16
|
) -> cp.ndarray:
|
13
17
|
"""
|
14
18
|
Transforms input data from the frequency domain to the wavelet domain using a
|
@@ -24,69 +28,62 @@ def transform_wavelet_freq_helper(
|
|
24
28
|
- wave (cp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
|
25
29
|
"""
|
26
30
|
|
27
|
-
logger.debug(
|
28
|
-
|
29
|
-
# Initialize the wavelet output array with zeros (time-rows, frequency-columns)
|
30
|
-
wave = cp.zeros((Nt, Nf))
|
31
|
-
f_bins = cp.arange(Nf) # Frequency bin indices
|
32
|
-
|
33
|
-
# Compute base indices for time (i_base) and frequency (jj_base)
|
34
|
-
i_base = Nt // 2
|
35
|
-
jj_base = f_bins * Nt // 2
|
36
|
-
|
37
|
-
# Set initial values for the center of the transformation
|
38
|
-
initial_values = cp.where(
|
39
|
-
(f_bins == 0)
|
40
|
-
| (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
|
41
|
-
phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
|
42
|
-
phif[0] * data[f_bins * Nt // 2],
|
31
|
+
logger.debug(
|
32
|
+
f"[CUPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}, precision:{data.dtype}]"
|
43
33
|
)
|
44
34
|
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
35
|
+
half = Nt // 2
|
36
|
+
f_bins = cp.arange(Nf + 1) # [0, 1, ..., Nf]
|
37
|
+
|
38
|
+
# 1) Build (Nf+1, Nt) DX
|
39
|
+
# — center:
|
40
|
+
center = phif[0] * data[f_bins * half]
|
41
|
+
center = cp.where((f_bins == 0) | (f_bins == Nf), center * 0.5, center)
|
42
|
+
DX = cp.zeros((Nf + 1, Nt), complex_dtype)
|
43
|
+
DX[:, half] = center
|
44
|
+
|
45
|
+
# — off-center (j = ±1...±(half-1))
|
46
|
+
offs = cp.arange(1 - half, half) # length Nt-1
|
47
|
+
jj = f_bins[:, None] * half + offs[None, :] # (Nf+1, Nt-1)
|
48
|
+
ii = half + offs # (Nt-1,)
|
49
|
+
mask = ((f_bins[:, None] == Nf) & (offs[None, :] > 0)) | (
|
50
|
+
(f_bins[:, None] == 0) & (offs[None, :] < 0)
|
51
|
+
)
|
52
|
+
vals = phif[cp.abs(offs)] * data[jj]
|
53
|
+
vals = cp.where(mask, 0.0, vals)
|
54
|
+
DX[:, ii] = vals
|
55
|
+
|
56
|
+
# 2) IFFT along time
|
57
|
+
DXt = ifft(DX, axis=1)
|
58
|
+
|
59
|
+
# 3) Unpack into wave (Nt, Nf)
|
60
|
+
wave = cp.zeros((Nt, Nf), float_dtype)
|
61
|
+
sqrt2 = cp.sqrt(2.0)
|
62
|
+
|
63
|
+
# 3a) DC into col 0, even rows
|
64
|
+
evens = cp.arange(0, Nt, 2)
|
65
|
+
wave[evens, 0] = cp.real(DXt[0, evens]) * sqrt2
|
66
|
+
|
67
|
+
# 3b) Nyquist into col 0, odd rows
|
68
|
+
odds = cp.arange(1, Nt, 2)
|
69
|
+
wave[odds, 0] = cp.real(DXt[Nf, evens]) * sqrt2
|
70
|
+
|
71
|
+
# 3c) intermediate bins 1...Nf-1
|
72
|
+
mids = cp.arange(1, Nf) # [1...Nf-1]
|
73
|
+
Dt_mid = DXt[mids, :] # (Nf-1, Nt)
|
74
|
+
real_m = cp.real(Dt_mid).T # (Nt, Nf-1)
|
75
|
+
imag_m = cp.imag(Dt_mid).T # (Nt, Nf-1)
|
76
|
+
|
77
|
+
odd_f = (mids % 2) == 1 # shape (Nf-1,)
|
78
|
+
n_idx = cp.arange(Nt)[:, None] # (Nt,1)
|
79
|
+
odd_nf = ((n_idx + mids[None, :]) % 2) == 1
|
80
|
+
|
81
|
+
# Select real vs imag and sign exactly as in __fill_wave_2
|
82
|
+
mid_vals = cp.where(
|
83
|
+
odd_nf,
|
84
|
+
cp.where(odd_f, -imag_m, imag_m),
|
85
|
+
cp.where(odd_f, real_m, real_m),
|
49
86
|
)
|
87
|
+
wave[:, 1:Nf] = mid_vals
|
50
88
|
|
51
|
-
|
52
|
-
j_range = cp.arange(
|
53
|
-
1 - Nt // 2, Nt // 2
|
54
|
-
) # Time offsets (centered around zero)
|
55
|
-
j = cp.abs(j_range) # Absolute offset indices
|
56
|
-
i = i_base + j_range # Time indices in output array
|
57
|
-
|
58
|
-
# Compute conditions for edge cases
|
59
|
-
cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
|
60
|
-
cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
|
61
|
-
cond3 = j[None, :] == 0 # Center of the transformation (no offset)
|
62
|
-
|
63
|
-
# Compute frequency indices for the input data
|
64
|
-
jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
|
65
|
-
val = cp.where(
|
66
|
-
cond1 | cond2, 0.0, phif[j] * data[jj]
|
67
|
-
) # Wavelet filter application
|
68
|
-
DX[:, i] = cp.where(cond3, DX[:, i], val) # Update DX with computed values
|
69
|
-
|
70
|
-
# Perform the inverse FFT along the time dimension
|
71
|
-
DX_trans = ifft(DX, axis=1)
|
72
|
-
|
73
|
-
# Fill the wavelet output array based on the inverse FFT results
|
74
|
-
n_range = cp.arange(Nt) # Time indices
|
75
|
-
cond1 = (
|
76
|
-
n_range[:, None] + f_bins[None, :]
|
77
|
-
) % 2 == 1 # Odd/even alternation
|
78
|
-
cond2 = cp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
|
79
|
-
|
80
|
-
# Assign real and imaginary parts based on conditions
|
81
|
-
real_part = cp.where(cond2, -cp.imag(DX_trans), cp.real(DX_trans))
|
82
|
-
imag_part = cp.where(cond2, cp.real(DX_trans), cp.imag(DX_trans))
|
83
|
-
wave = cp.where(cond1, imag_part.T, real_part.T)
|
84
|
-
|
85
|
-
# Special cases for frequency bins 0 (DC) and Nf (Nyquist)
|
86
|
-
wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2)) # DC component
|
87
|
-
wave[1::2, -1] = cp.real(
|
88
|
-
DX_trans[-1, ::2] * cp.sqrt(2)
|
89
|
-
) # Nyquist component
|
90
|
-
|
91
|
-
# Return the wavelet-transformed array (transposed for freq-major layout)
|
92
|
-
return wave.T # (Nt, Nf) -> (Nf, Nt)
|
89
|
+
return wave.T
|
@@ -2,6 +2,7 @@ from typing import Union
|
|
2
2
|
|
3
3
|
import cupy as cp
|
4
4
|
|
5
|
+
from .... import backend
|
5
6
|
from ....logger import logger
|
6
7
|
from ....types import FrequencySeries, TimeSeries, Wavelet
|
7
8
|
from ....types.wavelet_bins import _get_bins, _preprocess_bins
|
@@ -16,7 +17,6 @@ def from_time_to_wavelet(
|
|
16
17
|
Nt: Union[int, None] = None,
|
17
18
|
nx: float = 4.0,
|
18
19
|
mult: int = 32,
|
19
|
-
**kwargs,
|
20
20
|
) -> Wavelet:
|
21
21
|
"""Transforms time-domain data to wavelet-domain data.
|
22
22
|
|
@@ -45,7 +45,6 @@ def from_time_to_wavelet(
|
|
45
45
|
|
46
46
|
"""
|
47
47
|
Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
|
48
|
-
dt = timeseries.dt
|
49
48
|
t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
|
50
49
|
|
51
50
|
ND = Nf * Nt
|
@@ -73,7 +72,6 @@ def from_freq_to_wavelet(
|
|
73
72
|
Nf: Union[int, None] = None,
|
74
73
|
Nt: Union[int, None] = None,
|
75
74
|
nx: float = 4.0,
|
76
|
-
**kwargs,
|
77
75
|
) -> Wavelet:
|
78
76
|
"""Transforms frequency-domain data to wavelet-domain data.
|
79
77
|
|
@@ -98,9 +96,15 @@ def from_freq_to_wavelet(
|
|
98
96
|
"""
|
99
97
|
Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
|
100
98
|
t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
|
101
|
-
phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx))
|
99
|
+
phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx), dtype=backend.float_dtype)
|
100
|
+
data = cp.array(freqseries.data, dtype=backend.complex_dtype)
|
102
101
|
wave = transform_wavelet_freq_helper(
|
103
|
-
|
102
|
+
data,
|
103
|
+
Nf=Nf,
|
104
|
+
Nt=Nt,
|
105
|
+
phif=phif,
|
106
|
+
float_dtype=backend.float_dtype,
|
107
|
+
complex_dtype=backend.complex_dtype,
|
104
108
|
)
|
105
|
-
|
106
|
-
return Wavelet(
|
109
|
+
factor = (2 / Nf) * cp.sqrt(2)
|
110
|
+
return Wavelet(factor * wave, time=t_bins, freq=f_bins)
|
@@ -6,57 +6,61 @@ def inverse_wavelet_freq_helper(
|
|
6
6
|
wave_in: cp.ndarray, phif: cp.ndarray, Nf: int, Nt: int
|
7
7
|
) -> cp.ndarray:
|
8
8
|
"""CuPy vectorized function for inverse_wavelet_freq"""
|
9
|
-
|
10
|
-
|
9
|
+
wave = wave_in.T
|
10
|
+
ND2 = (Nf * Nt) // 2
|
11
|
+
half = Nt // 2
|
11
12
|
|
13
|
+
# === STEP 1: build prefactor2s[m, n] ===
|
12
14
|
m_range = cp.arange(Nf + 1)
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
# m
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
res[
|
45
|
-
|
46
|
-
#
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
15
|
+
pref2 = cp.zeros((Nf + 1, Nt), dtype=cp.complex128)
|
16
|
+
n = cp.arange(Nt)
|
17
|
+
|
18
|
+
# m=0 and m=Nf, with 1/√2
|
19
|
+
pref2[0, :] = wave[(2 * n) % Nt, 0] * (2 ** (-0.5))
|
20
|
+
pref2[Nf, :] = wave[((2 * n) % Nt) + 1, 0] * (2 ** (-0.5))
|
21
|
+
|
22
|
+
# middle m=1...Nf-1
|
23
|
+
m_mid = cp.arange(1, Nf)
|
24
|
+
# build meshgrids (m_mid rows, n cols)
|
25
|
+
mm, nn = cp.meshgrid(m_mid, n, indexing="ij")
|
26
|
+
vals = wave[nn, mm]
|
27
|
+
signs = cp.where(((nn + mm) % 2) == 1, -1j, 1)
|
28
|
+
pref2[1:Nf, :] = signs * vals
|
29
|
+
|
30
|
+
# === STEP 2: FFT along time axis ===
|
31
|
+
F = fft(pref2, axis=1) # shape (Nf+1, Nt)
|
32
|
+
|
33
|
+
# === STEP 3: unpack back into half-spectrum res[0...ND2] ===
|
34
|
+
res = cp.zeros(ND2 + 1, dtype=cp.complex128)
|
35
|
+
idx = cp.arange(half)
|
36
|
+
|
37
|
+
# 3a) contribution from m=0
|
38
|
+
res[idx] += F[0, (2 * idx) % Nt] * phif[idx]
|
39
|
+
|
40
|
+
# 3b) contribution from m=Nf
|
41
|
+
iNf = cp.abs(Nf * half - idx)
|
42
|
+
res[iNf] += F[Nf, (2 * idx) % Nt] * phif[idx]
|
43
|
+
|
44
|
+
# special Nyquist‐folding term
|
45
|
+
special = cp.abs(Nf * half - half)
|
46
|
+
res[special] += F[Nf, 0] * phif[half]
|
47
|
+
|
48
|
+
# 3c) middle m cases
|
49
|
+
m_mid = cp.arange(1, Nf)
|
50
|
+
i_mid = cp.arange(half) # 0...half-1
|
51
|
+
mm, ii = cp.meshgrid(m_mid, i_mid, indexing="ij")
|
52
|
+
i1 = (half * mm - ii) % (ND2 + 1)
|
53
|
+
i2 = (half * mm + ii) % (ND2 + 1)
|
54
|
+
ind1 = (half * mm - ii) % Nt
|
55
|
+
ind2 = (half * mm + ii) % Nt
|
56
|
+
|
57
|
+
# accumulate
|
58
|
+
res[i1] += F[mm, ind1] * phif[ii]
|
59
|
+
res[i2] += F[mm, ind2] * phif[ii]
|
60
|
+
|
61
|
+
# override the "center" points (j=0) exactly
|
62
|
+
centers = half * m_mid
|
63
|
+
fft_idx = centers % Nt
|
64
|
+
res[centers] = F[m_mid, fft_idx] * phif[0]
|
61
65
|
|
62
66
|
return res
|