pywavelet 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/_version.py +2 -2
- pywavelet/backend.py +9 -4
- pywavelet/transforms/__init__.py +18 -7
- pywavelet/transforms/jax/__init__.py +12 -0
- pywavelet/transforms/jax/forward/__init__.py +0 -3
- pywavelet/transforms/jax/forward/from_freq.py +71 -30
- pywavelet/transforms/jax/inverse/__init__.py +3 -0
- pywavelet/transforms/jax/inverse/main.py +3 -4
- pywavelet/types/wavelet.py +40 -6
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.4.dist-info}/METADATA +1 -1
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.4.dist-info}/RECORD +13 -13
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.4.dist-info}/WHEEL +0 -0
- {pywavelet-0.2.2.dist-info → pywavelet-0.2.4.dist-info}/top_level.txt +0 -0
pywavelet/_version.py
CHANGED
pywavelet/backend.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
|
3
|
+
from .logger import logger
|
4
|
+
|
3
5
|
try:
|
4
6
|
import jax
|
5
7
|
|
@@ -13,14 +15,17 @@ use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
|
|
13
15
|
|
14
16
|
if use_jax:
|
15
17
|
import jax.numpy as xp # type: ignore
|
16
|
-
from jax.
|
18
|
+
from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
|
17
19
|
from jax.scipy.special import betainc # type: ignore
|
18
20
|
|
21
|
+
logger.info("Using JAX backend")
|
19
22
|
|
20
23
|
else:
|
21
24
|
import numpy as xp # type: ignore
|
22
|
-
from numpy.fft import fft, ifft,
|
23
|
-
from scipy.special import betainc
|
25
|
+
from numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
|
26
|
+
from scipy.special import betainc # type: ignore
|
27
|
+
|
28
|
+
logger.info("Using NumPy+numba backend")
|
24
29
|
|
25
30
|
|
26
|
-
PI = xp.pi
|
31
|
+
PI = xp.pi
|
pywavelet/transforms/__init__.py
CHANGED
@@ -1,10 +1,21 @@
|
|
1
|
-
from
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
1
|
+
from ..backend import use_jax
|
2
|
+
|
3
|
+
if use_jax:
|
4
|
+
from .jax import (
|
5
|
+
from_freq_to_wavelet,
|
6
|
+
from_time_to_wavelet,
|
7
|
+
from_wavelet_to_freq,
|
8
|
+
from_wavelet_to_time,
|
9
|
+
)
|
10
|
+
else:
|
11
|
+
from .numpy import (
|
12
|
+
from_wavelet_to_time,
|
13
|
+
from_wavelet_to_freq,
|
14
|
+
from_time_to_wavelet,
|
15
|
+
from_freq_to_wavelet,
|
16
|
+
)
|
17
|
+
|
18
|
+
from .phi_computer import phi_vec, phitilde_vec, phitilde_vec_norm
|
8
19
|
|
9
20
|
__all__ = [
|
10
21
|
"from_wavelet_to_time",
|
@@ -0,0 +1,12 @@
|
|
1
|
+
from ...logger import logger
|
2
|
+
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
3
|
+
from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
4
|
+
|
5
|
+
logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
|
6
|
+
|
7
|
+
__all__ = [
|
8
|
+
"from_wavelet_to_time",
|
9
|
+
"from_wavelet_to_freq",
|
10
|
+
"from_time_to_wavelet",
|
11
|
+
"from_freq_to_wavelet",
|
12
|
+
]
|
@@ -1,56 +1,97 @@
|
|
1
|
-
import jax
|
2
|
-
import jax.numpy as jnp
|
3
1
|
from functools import partial
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
4
|
from jax import jit
|
5
5
|
from jax.numpy.fft import ifft
|
6
6
|
|
7
|
-
|
7
|
+
|
8
|
+
@partial(jit, static_argnames=("Nf", "Nt"))
|
8
9
|
def transform_wavelet_freq_helper(
|
9
|
-
|
10
|
+
data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
|
10
11
|
) -> jnp.ndarray:
|
11
|
-
|
12
|
+
"""
|
13
|
+
Transforms input data from the frequency domain to the wavelet domain using a
|
14
|
+
pre-computed wavelet filter (`phif`) and performs an efficient inverse FFT.
|
15
|
+
|
16
|
+
Parameters:
|
17
|
+
- data (jnp.ndarray): 1D array representing the input data in the frequency domain.
|
18
|
+
- Nf (int): Number of frequency bins.
|
19
|
+
- Nt (int): Number of time bins. (Note: Nt * Nf == len(data))
|
20
|
+
- phif (jnp.ndarray): Pre-computed wavelet filter for frequency components.
|
21
|
+
|
22
|
+
Returns:
|
23
|
+
- wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
|
24
|
+
"""
|
25
|
+
|
26
|
+
# Initialize the wavelet output array with zeros (time-rows, frequency-columns)
|
12
27
|
wave = jnp.zeros((Nt, Nf))
|
13
|
-
f_bins = jnp.arange(Nf)
|
28
|
+
f_bins = jnp.arange(Nf) # Frequency bin indices
|
14
29
|
|
30
|
+
# Compute base indices for time (i_base) and frequency (jj_base)
|
15
31
|
i_base = Nt // 2
|
16
32
|
jj_base = f_bins * Nt // 2
|
17
33
|
|
34
|
+
# Set initial values for the center of the transformation
|
18
35
|
initial_values = jnp.where(
|
19
|
-
(f_bins == 0)
|
20
|
-
|
21
|
-
phif[0] * data[f_bins * Nt // 2]
|
36
|
+
(f_bins == 0)
|
37
|
+
| (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
|
38
|
+
phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
|
39
|
+
phif[0] * data[f_bins * Nt // 2],
|
22
40
|
)
|
23
41
|
|
24
|
-
|
25
|
-
DX =
|
42
|
+
# Initialize a 2D array to store intermediate FFT input values
|
43
|
+
DX = jnp.zeros(
|
44
|
+
(Nf, Nt), dtype=jnp.complex64
|
45
|
+
) # TODO: Check dtype -- is complex64 sufficient?
|
46
|
+
DX = DX.at[:, Nt // 2].set(
|
47
|
+
initial_values
|
48
|
+
) # Set initial values at the center of the transformation (2 sided FFT)
|
26
49
|
|
27
|
-
|
28
|
-
|
29
|
-
|
50
|
+
# Compute time indices for all offsets around the midpoint
|
51
|
+
j_range = jnp.arange(
|
52
|
+
1 - Nt // 2, Nt // 2
|
53
|
+
) # Time offsets (centered around zero)
|
54
|
+
j = jnp.abs(j_range) # Absolute offset indices
|
55
|
+
i = i_base + j_range # Time indices in output array
|
30
56
|
|
31
|
-
|
32
|
-
|
33
|
-
|
57
|
+
# Compute conditions for edge cases
|
58
|
+
cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
|
59
|
+
cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
|
60
|
+
cond3 = j[None, :] == 0 # Center of the transformation (no offset)
|
34
61
|
|
35
|
-
|
36
|
-
|
37
|
-
|
62
|
+
# Compute frequency indices for the input data
|
63
|
+
jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
|
64
|
+
val = jnp.where(
|
65
|
+
cond1 | cond2, 0.0, phif[j] * data[jj]
|
66
|
+
) # Wavelet filter application
|
67
|
+
DX = DX.at[:, i].set(
|
68
|
+
jnp.where(cond3, DX[:, i], val)
|
69
|
+
) # Update DX with computed values
|
70
|
+
# At this point, DX contains the data FFT'd with the wavelet filter
|
71
|
+
# (each row is a frequency bin, each column is a time bin)
|
38
72
|
|
39
|
-
#
|
73
|
+
# Perform the inverse FFT along the time dimension
|
40
74
|
DX_trans = ifft(DX, axis=1)
|
41
75
|
|
42
|
-
#
|
43
|
-
n_range = jnp.arange(Nt)
|
44
|
-
cond1 = (
|
45
|
-
|
76
|
+
# Fill the wavelet output array based on the inverse FFT results
|
77
|
+
n_range = jnp.arange(Nt) # Time indices
|
78
|
+
cond1 = (
|
79
|
+
n_range[:, None] + f_bins[None, :]
|
80
|
+
) % 2 == 1 # Odd/even alternation
|
81
|
+
cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
|
46
82
|
|
83
|
+
# Assign real and imaginary parts based on conditions
|
47
84
|
real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
|
48
85
|
imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
|
49
|
-
|
50
86
|
wave = jnp.where(cond1, imag_part.T, real_part.T)
|
51
87
|
|
52
|
-
|
53
|
-
wave = wave.at[::2, 0].set(
|
54
|
-
|
88
|
+
# Special cases for frequency bins 0 (DC) and Nf (Nyquist)
|
89
|
+
wave = wave.at[::2, 0].set(
|
90
|
+
jnp.real(DX_trans[0, ::2] * jnp.sqrt(2))
|
91
|
+
) # DC component
|
92
|
+
wave = wave.at[1::2, -1].set(
|
93
|
+
jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2))
|
94
|
+
) # Nyquist component
|
55
95
|
|
56
|
-
|
96
|
+
# Return the wavelet-transformed array (transposed for freq-major layout)
|
97
|
+
return wave.T # (Nt, Nf) -> (Nf, Nt)
|
@@ -1,10 +1,9 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
2
|
from jax.numpy.fft import rfftfreq
|
3
3
|
|
4
|
-
from ...phi_computer import phi_vec, phitilde_vec_norm
|
5
4
|
from ....types import FrequencySeries, TimeSeries, Wavelet
|
5
|
+
from ...phi_computer import phi_vec, phitilde_vec_norm
|
6
6
|
from .to_freq import inverse_wavelet_freq_helper
|
7
|
-
# from .inverse_wavelet_time_funcs import inverse_wavelet_time_helper
|
8
7
|
|
9
8
|
|
10
9
|
def from_wavelet_to_time(
|
@@ -65,5 +64,5 @@ def from_wavelet_to_freq(
|
|
65
64
|
-1 / 2
|
66
65
|
) # Normalise to get the proper backwards transformation
|
67
66
|
|
68
|
-
freqs = rfftfreq(wave_in.ND*2, d=dt)[1:]
|
69
|
-
return FrequencySeries(data=freq_data, freq=freqs)
|
67
|
+
freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
|
68
|
+
return FrequencySeries(data=freq_data, freq=freqs)
|
pywavelet/types/wavelet.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from typing import List, Tuple
|
2
2
|
|
3
3
|
import matplotlib.pyplot as plt
|
4
|
+
import numpy as np
|
4
5
|
|
5
6
|
from .common import fmt_timerange, is_documented_by, xp
|
6
7
|
from .plotting import plot_wavelet_grid, plot_wavelet_trend
|
@@ -343,14 +344,18 @@ class Wavelet:
|
|
343
344
|
|
344
345
|
def __mul__(self, other):
|
345
346
|
"""Element-wise multiplication of two Wavelet objects."""
|
346
|
-
if isinstance(other,
|
347
|
-
|
348
|
-
|
349
|
-
)
|
347
|
+
if isinstance(other, WaveletMask):
|
348
|
+
data = self.data.copy()
|
349
|
+
data[~other.mask] = np.nan
|
350
|
+
return Wavelet(data=data, time=self.time, freq=self.freq)
|
350
351
|
elif isinstance(other, float):
|
351
352
|
return Wavelet(
|
352
353
|
data=self.data * other, time=self.time, freq=self.freq
|
353
354
|
)
|
355
|
+
elif isinstance(other, WaveletMask):
|
356
|
+
return Wavelet(
|
357
|
+
data=self.data * other.data, time=self.time, freq=self.freq
|
358
|
+
)
|
354
359
|
|
355
360
|
def __truediv__(self, other):
|
356
361
|
"""Element-wise division of two Wavelet objects."""
|
@@ -445,11 +450,40 @@ class WaveletMask(Wavelet):
|
|
445
450
|
return rpr
|
446
451
|
|
447
452
|
@classmethod
|
448
|
-
def
|
449
|
-
cls,
|
453
|
+
def from_restrictions(
|
454
|
+
cls,
|
455
|
+
time_grid: xp.ndarray,
|
456
|
+
freq_grid: xp.ndarray,
|
457
|
+
frange: List[float],
|
458
|
+
tgaps: List[Tuple[float, float]] = [],
|
450
459
|
):
|
460
|
+
"""
|
461
|
+
Create a WaveletMask object from restrictions on time and frequency.
|
462
|
+
|
463
|
+
Parameters
|
464
|
+
----------
|
465
|
+
time_grid : xp.ndarray
|
466
|
+
Array of time points.
|
467
|
+
freq_grid : xp.ndarray
|
468
|
+
Array of corresponding frequency points.
|
469
|
+
frange : List[float]
|
470
|
+
Frequency range to include.
|
471
|
+
tgaps : List[Tuple[float, float]]
|
472
|
+
List of time gaps to exclude.
|
473
|
+
|
474
|
+
Returns
|
475
|
+
-------
|
476
|
+
WaveletMask
|
477
|
+
A WaveletMask object with the specified restrictions.
|
478
|
+
"""
|
451
479
|
self = cls.zeros_from_grid(time_grid, freq_grid)
|
452
480
|
self.data[
|
453
481
|
(freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
|
454
482
|
] = True
|
483
|
+
|
484
|
+
for tgap in tgaps:
|
485
|
+
self.data[
|
486
|
+
:, (time_grid >= tgap[0]) & (time_grid <= tgap[1])
|
487
|
+
] = False
|
488
|
+
self.data = self.data.astype(bool)
|
455
489
|
return self
|
@@ -1,17 +1,17 @@
|
|
1
1
|
pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
|
2
|
-
pywavelet/_version.py,sha256=
|
3
|
-
pywavelet/backend.py,sha256=
|
2
|
+
pywavelet/_version.py,sha256=4gL0W4-u58XR5lRLpeoIPrGhcewTk0-527de6uTNmkg,411
|
3
|
+
pywavelet/backend.py,sha256=SmpgIBHvTO1rtIAQQN_zpVB8i6R-x23FNKJG6_JlrNs,666
|
4
4
|
pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
|
5
5
|
pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
|
6
|
-
pywavelet/transforms/__init__.py,sha256=
|
6
|
+
pywavelet/transforms/__init__.py,sha256=EYX8glRWojYbrjtbgrjS4vigYTRi7FOtIV3D1UwI5fY,604
|
7
7
|
pywavelet/transforms/phi_computer.py,sha256=ppFSGJwtNnO2flaiok9ms3WXlAxGQikvA7eNfLgriNQ,4461
|
8
|
-
pywavelet/transforms/jax/__init__.py,sha256=
|
9
|
-
pywavelet/transforms/jax/forward/__init__.py,sha256=
|
10
|
-
pywavelet/transforms/jax/forward/from_freq.py,sha256=
|
8
|
+
pywavelet/transforms/jax/__init__.py,sha256=D_f-JgFAzOIJ-EuQZhTMziD4MT6lVWS3XV9s51Cu7Kg,335
|
9
|
+
pywavelet/transforms/jax/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
10
|
+
pywavelet/transforms/jax/forward/from_freq.py,sha256=tKEdqPyEvX8ZKVQf16wGxN3d6gkcjm_RtAHQuWHUzy4,3764
|
11
11
|
pywavelet/transforms/jax/forward/from_time.py,sha256=xNeoZq54B6Gi3TdTTYLr_euaFeJcwpms-lSyCG53AdI,1726
|
12
12
|
pywavelet/transforms/jax/forward/main.py,sha256=mm0R4m0pXcnzZB0jCckAc4ynG8STH5mldCmHyyU_PGo,3091
|
13
|
-
pywavelet/transforms/jax/inverse/__init__.py,sha256=
|
14
|
-
pywavelet/transforms/jax/inverse/main.py,sha256
|
13
|
+
pywavelet/transforms/jax/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
|
14
|
+
pywavelet/transforms/jax/inverse/main.py,sha256=-HVOOBsYo3GJvGNCsQLbNPnt9s14JvbB2bGAd9LOr3A,1647
|
15
15
|
pywavelet/transforms/jax/inverse/to_freq.py,sha256=ASNARcDBJQr4EizAP_77e5ai36iPwP6hzfvwGbZQ6BM,2295
|
16
16
|
pywavelet/transforms/numpy/__init__.py,sha256=qFLpGpW3VJSbDp2JpD0Gx7PdwDjH-wrW_aO84ASkIgA,255
|
17
17
|
pywavelet/transforms/numpy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
@@ -27,9 +27,9 @@ pywavelet/types/common.py,sha256=aIcYq-0KOLHnPQjrVbVmw_TQ3Xm5a7xA30rSgwt3rk4,127
|
|
27
27
|
pywavelet/types/frequencyseries.py,sha256=hrtLaIUaRrqXw8l00yFe2tPJwpksDa_4n1z6R8XSPPQ,7531
|
28
28
|
pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
|
29
29
|
pywavelet/types/timeseries.py,sha256=u35bIqFo3QdlQRBEu6maeWA7DePS11LQ6WMiLjZPcWo,9456
|
30
|
-
pywavelet/types/wavelet.py,sha256=
|
30
|
+
pywavelet/types/wavelet.py,sha256=uHJzTS2ZXTRr7I7NHWv3qNjknSBhQUpcED3jM6ti7UM,13587
|
31
31
|
pywavelet/types/wavelet_bins.py,sha256=GoQGKeZlPc-KbYY7LoxAhB-HI4diHpPcTABBXRfUTLA,1459
|
32
|
-
pywavelet-0.2.
|
33
|
-
pywavelet-0.2.
|
34
|
-
pywavelet-0.2.
|
35
|
-
pywavelet-0.2.
|
32
|
+
pywavelet-0.2.4.dist-info/METADATA,sha256=Thhhz8I2XTKr0mVuf09UpcvjeEGKUnVUX0jxENu6gEQ,2241
|
33
|
+
pywavelet-0.2.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
34
|
+
pywavelet-0.2.4.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
|
35
|
+
pywavelet-0.2.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|