pywavelet-0.2.2-py3-none-any.whl → pywavelet-0.2.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/_version.py +2 -2
- pywavelet/backend.py +9 -4
- pywavelet/transforms/__init__.py +18 -7
- pywavelet/transforms/jax/__init__.py +12 -0
- pywavelet/transforms/jax/forward/__init__.py +0 -3
- pywavelet/transforms/jax/forward/from_freq.py +71 -30
- pywavelet/transforms/jax/inverse/__init__.py +3 -0
- pywavelet/transforms/jax/inverse/main.py +3 -4
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.3.dist-info}/METADATA +1 -1
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.3.dist-info}/RECORD +12 -12
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.3.dist-info}/WHEEL +0 -0
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.3.dist-info}/top_level.txt +0 -0
pywavelet/_version.py
CHANGED
pywavelet/backend.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
|
3
|
+
from .logger import logger
|
4
|
+
|
3
5
|
try:
|
4
6
|
import jax
|
5
7
|
|
@@ -13,14 +15,17 @@ use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
|
|
13
15
|
|
14
16
|
if use_jax:
|
15
17
|
import jax.numpy as xp # type: ignore
|
16
|
-
from jax.
|
18
|
+
from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
|
17
19
|
from jax.scipy.special import betainc # type: ignore
|
18
20
|
|
21
|
+
logger.info("Using JAX backend")
|
19
22
|
|
20
23
|
else:
|
21
24
|
import numpy as xp # type: ignore
|
22
|
-
from numpy.fft import fft, ifft,
|
23
|
-
from scipy.special import betainc
|
25
|
+
from numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
|
26
|
+
from scipy.special import betainc # type: ignore
|
27
|
+
|
28
|
+
logger.info("Using NumPy+numba backend")
|
24
29
|
|
25
30
|
|
26
|
-
PI = xp.pi
|
31
|
+
PI = xp.pi
|
pywavelet/transforms/__init__.py
CHANGED
@@ -1,10 +1,21 @@
|
|
1
|
-
from
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
1
|
+
from ..backend import use_jax
|
2
|
+
|
3
|
+
if use_jax:
|
4
|
+
from .jax import (
|
5
|
+
from_freq_to_wavelet,
|
6
|
+
from_time_to_wavelet,
|
7
|
+
from_wavelet_to_freq,
|
8
|
+
from_wavelet_to_time,
|
9
|
+
)
|
10
|
+
else:
|
11
|
+
from .numpy import (
|
12
|
+
from_wavelet_to_time,
|
13
|
+
from_wavelet_to_freq,
|
14
|
+
from_time_to_wavelet,
|
15
|
+
from_freq_to_wavelet,
|
16
|
+
)
|
17
|
+
|
18
|
+
from .phi_computer import phi_vec, phitilde_vec, phitilde_vec_norm
|
8
19
|
|
9
20
|
__all__ = [
|
10
21
|
"from_wavelet_to_time",
|
@@ -0,0 +1,12 @@
|
|
1
|
+
from ...logger import logger
|
2
|
+
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
3
|
+
from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
4
|
+
|
5
|
+
logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
|
6
|
+
|
7
|
+
__all__ = [
|
8
|
+
"from_wavelet_to_time",
|
9
|
+
"from_wavelet_to_freq",
|
10
|
+
"from_time_to_wavelet",
|
11
|
+
"from_freq_to_wavelet",
|
12
|
+
]
|
@@ -1,56 +1,97 @@
|
|
1
|
-
import jax
|
2
|
-
import jax.numpy as jnp
|
3
1
|
from functools import partial
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
4
|
from jax import jit
|
5
5
|
from jax.numpy.fft import ifft
|
6
6
|
|
7
|
-
|
7
|
+
|
8
|
+
@partial(jit, static_argnames=("Nf", "Nt"))
def transform_wavelet_freq_helper(
    data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
) -> jnp.ndarray:
    """
    Transform frequency-domain data into the (WDM) wavelet domain.

    Every frequency layer m = 0..Nf (DC and Nyquist included) is windowed with
    the pre-computed filter ``phif`` around data index ``m * Nt // 2``. Each
    windowed layer is inverse-FFT'd along the time axis and unpacked into a
    real (Nf, Nt) grid.

    Parameters:
    - data (jnp.ndarray): 1D complex frequency-domain samples. Indices up to
      ``Nf * Nt // 2`` (the Nyquist layer centre) are read, so this is
      presumably a one-sided spectrum of length ``Nf * Nt // 2 + 1``.
      TODO confirm against callers.
    - Nf (int): Number of frequency bins (static under ``jit``).
    - Nt (int): Number of time bins (static under ``jit``; assumed even).
    - phif (jnp.ndarray): Pre-computed wavelet filter. Entries
      ``phif[0]..phif[Nt // 2 - 1]`` are read.

    Returns:
    - wave (jnp.ndarray): Real 2D array of wavelet coefficients with shape
      (Nf, Nt). Row 0 packs the two half-redundant edge layers: DC at even
      time indices and Nyquist at odd time indices.
    """
    half = Nt // 2

    # All Nf + 1 layers are needed: m = 0 (DC) and m = Nf (Nyquist) are both
    # folded into row 0 of the output.
    m_bins = jnp.arange(Nf + 1)
    is_edge = (m_bins == 0) | (m_bins == Nf)

    # Offsets j in (-Nt/2, Nt/2) around each layer's centre index m * Nt / 2.
    j_range = jnp.arange(1 - half, half)
    j_abs = jnp.abs(j_range)
    cols = half + j_range  # destination columns in DX (column 0 stays zero)

    jj_base = m_bins * Nt // 2
    jj = jj_base[:, None] + j_range[None, :]
    # Out-of-range indices only occur at positions masked to zero below;
    # clip them so the gather is always in bounds.
    jj = jnp.clip(jj, 0, data.shape[0] - 1)

    # Keep full precision when x64 is enabled instead of forcing complex64.
    dx_dtype = jnp.result_type(data.dtype, phif.dtype, jnp.complex64)

    windowed = phif[j_abs][None, :] * data[jj]
    # Centre tap is halved on the DC and Nyquist layers (they are half-redundant).
    centre = phif[0] * data[jj_base] / jnp.where(is_edge, 2.0, 1.0)

    # Taps above the Nyquist centre / below the DC centre lie outside the
    # one-sided spectrum and are zeroed.
    outside = ((m_bins[:, None] == Nf) & (j_range[None, :] > 0)) | (
        (m_bins[:, None] == 0) & (j_range[None, :] < 0)
    )
    row_vals = jnp.where(
        j_abs[None, :] == 0,
        centre[:, None],
        jnp.where(outside, 0.0, windowed),
    )

    DX = jnp.zeros((Nf + 1, Nt), dtype=dx_dtype)
    DX = DX.at[:, cols].set(row_vals.astype(dx_dtype))

    # Inverse FFT along the time dimension, one row per frequency layer.
    DX_trans = ifft(DX, axis=1)

    # Interior bins 1..Nf-1: pixels with (n + m) even take the real part and
    # pixels with (n + m) odd take the imaginary part. The imaginary part is
    # negated for odd m.
    inner = DX_trans[:Nf]
    f_bins = jnp.arange(Nf)
    n_range = jnp.arange(Nt)
    odd_pixel = (f_bins[:, None] + n_range[None, :]) % 2 == 1
    sign = jnp.where(f_bins % 2 == 1, -1.0, 1.0)[:, None]
    wave = jnp.where(odd_pixel, sign * jnp.imag(inner), jnp.real(inner))

    # Row 0: DC layer in the even time pixels, Nyquist layer in the odd ones.
    root2 = jnp.sqrt(2.0)
    wave = wave.at[0, ::2].set(jnp.real(DX_trans[0, ::2] * root2))
    wave = wave.at[0, 1::2].set(jnp.real(DX_trans[Nf, ::2] * root2))

    return wave  # (Nf, Nt)
|
@@ -1,10 +1,9 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
2
|
from jax.numpy.fft import rfftfreq
|
3
3
|
|
4
|
-
from ...phi_computer import phi_vec, phitilde_vec_norm
|
5
4
|
from ....types import FrequencySeries, TimeSeries, Wavelet
|
5
|
+
from ...phi_computer import phi_vec, phitilde_vec_norm
|
6
6
|
from .to_freq import inverse_wavelet_freq_helper
|
7
|
-
# from .inverse_wavelet_time_funcs import inverse_wavelet_time_helper
|
8
7
|
|
9
8
|
|
10
9
|
def from_wavelet_to_time(
|
@@ -65,5 +64,5 @@ def from_wavelet_to_freq(
|
|
65
64
|
-1 / 2
|
66
65
|
) # Normalise to get the proper backwards transformation
|
67
66
|
|
68
|
-
freqs = rfftfreq(wave_in.ND*2, d=dt)[1:]
|
69
|
-
return FrequencySeries(data=freq_data, freq=freqs)
|
67
|
+
freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
|
68
|
+
return FrequencySeries(data=freq_data, freq=freqs)
|
@@ -1,17 +1,17 @@
|
|
1
1
|
pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
|
2
|
-
pywavelet/_version.py,sha256=
|
3
|
-
pywavelet/backend.py,sha256=
|
2
|
+
pywavelet/_version.py,sha256=AaQEeqeDwmZAHoPuwg2C0ulADePbIYLSFanZzt0cytQ,411
|
3
|
+
pywavelet/backend.py,sha256=SmpgIBHvTO1rtIAQQN_zpVB8i6R-x23FNKJG6_JlrNs,666
|
4
4
|
pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
|
5
5
|
pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
|
6
|
-
pywavelet/transforms/__init__.py,sha256=
|
6
|
+
pywavelet/transforms/__init__.py,sha256=EYX8glRWojYbrjtbgrjS4vigYTRi7FOtIV3D1UwI5fY,604
|
7
7
|
pywavelet/transforms/phi_computer.py,sha256=ppFSGJwtNnO2flaiok9ms3WXlAxGQikvA7eNfLgriNQ,4461
|
8
|
-
pywavelet/transforms/jax/__init__.py,sha256=
|
9
|
-
pywavelet/transforms/jax/forward/__init__.py,sha256=
|
10
|
-
pywavelet/transforms/jax/forward/from_freq.py,sha256=
|
8
|
+
pywavelet/transforms/jax/__init__.py,sha256=D_f-JgFAzOIJ-EuQZhTMziD4MT6lVWS3XV9s51Cu7Kg,335
|
9
|
+
pywavelet/transforms/jax/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
10
|
+
pywavelet/transforms/jax/forward/from_freq.py,sha256=tKEdqPyEvX8ZKVQf16wGxN3d6gkcjm_RtAHQuWHUzy4,3764
|
11
11
|
pywavelet/transforms/jax/forward/from_time.py,sha256=xNeoZq54B6Gi3TdTTYLr_euaFeJcwpms-lSyCG53AdI,1726
|
12
12
|
pywavelet/transforms/jax/forward/main.py,sha256=mm0R4m0pXcnzZB0jCckAc4ynG8STH5mldCmHyyU_PGo,3091
|
13
|
-
pywavelet/transforms/jax/inverse/__init__.py,sha256=
|
14
|
-
pywavelet/transforms/jax/inverse/main.py,sha256
|
13
|
+
pywavelet/transforms/jax/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
|
14
|
+
pywavelet/transforms/jax/inverse/main.py,sha256=-HVOOBsYo3GJvGNCsQLbNPnt9s14JvbB2bGAd9LOr3A,1647
|
15
15
|
pywavelet/transforms/jax/inverse/to_freq.py,sha256=ASNARcDBJQr4EizAP_77e5ai36iPwP6hzfvwGbZQ6BM,2295
|
16
16
|
pywavelet/transforms/numpy/__init__.py,sha256=qFLpGpW3VJSbDp2JpD0Gx7PdwDjH-wrW_aO84ASkIgA,255
|
17
17
|
pywavelet/transforms/numpy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
@@ -29,7 +29,7 @@ pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,1
|
|
29
29
|
pywavelet/types/timeseries.py,sha256=u35bIqFo3QdlQRBEu6maeWA7DePS11LQ6WMiLjZPcWo,9456
|
30
30
|
pywavelet/types/wavelet.py,sha256=el48oyAfwtSw2tCQLUb85F9lKr0qMSRJPUmAUU8TS50,12552
|
31
31
|
pywavelet/types/wavelet_bins.py,sha256=GoQGKeZlPc-KbYY7LoxAhB-HI4diHpPcTABBXRfUTLA,1459
|
32
|
-
pywavelet-0.2.
|
33
|
-
pywavelet-0.2.
|
34
|
-
pywavelet-0.2.
|
35
|
-
pywavelet-0.2.
|
32
|
+
pywavelet-0.2.3.dist-info/METADATA,sha256=IGnMbmU9Cer13p5ZpYNmdcWKnWG1J0p1p_BN7_I1smE,2241
|
33
|
+
pywavelet-0.2.3.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
34
|
+
pywavelet-0.2.3.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
|
35
|
+
pywavelet-0.2.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|