pywavelet 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/__init__.py CHANGED
@@ -2,4 +2,26 @@
2
2
  WDM Wavelet transform
3
3
  """
4
4
 
5
+ import importlib
6
+ import os
7
+
8
+ from . import backend as _backend
9
+
5
10
  __version__ = "0.0.2"
11
+
12
+
13
+ def set_backend(backend: str):
14
+ """Set the backend for the wavelet transform.
15
+
16
+ Parameters
17
+ ----------
18
+ backend : str
19
+ Backend to use. Options are "numpy", "jax", "cupy".
20
+ """
21
+ from . import types
22
+ from . import transforms
23
+ os.environ["PYWAVELET_BACKEND"] = backend
24
+
25
+ importlib.reload(_backend)
26
+ importlib.reload(types)
27
+ importlib.reload(transforms)
pywavelet/_version.py CHANGED
@@ -1,8 +1,13 @@
1
- # file generated by setuptools_scm
1
+ # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
+
4
+ __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
5
+
3
6
  TYPE_CHECKING = False
4
7
  if TYPE_CHECKING:
5
- from typing import Tuple, Union
8
+ from typing import Tuple
9
+ from typing import Union
10
+
6
11
  VERSION_TUPLE = Tuple[Union[int, str], ...]
7
12
  else:
8
13
  VERSION_TUPLE = object
@@ -12,5 +17,5 @@ __version__: str
12
17
  __version_tuple__: VERSION_TUPLE
13
18
  version_tuple: VERSION_TUPLE
14
19
 
15
- __version__ = version = '0.2.5'
16
- __version_tuple__ = version_tuple = (0, 2, 5)
20
+ __version__ = version = '0.2.7'
21
+ __version_tuple__ = version_tuple = (0, 2, 7)
pywavelet/backend.py CHANGED
@@ -1,31 +1,101 @@
1
+ import importlib
1
2
  import os
3
+ from rich.table import Table, Text
4
+ from rich.console import Console
5
+
6
+
2
7
 
3
8
  from .logger import logger
4
9
 
5
- try:
6
- import jax
10
+ JAX = "jax"
11
+ CUPY = "cupy"
12
+ NUMPY = "numpy"
13
+
14
+
15
+ def cuda_is_available() -> bool:
16
+ """Check if CUDA is available."""
17
+ # Check if CuPy is available and CUDA is accessible
18
+ cupy_available = importlib.util.find_spec("cupy") is not None
19
+ if cupy_available:
20
+ import cupy
21
+
22
+ try:
23
+ cupy.cuda.runtime.getDeviceCount() # Check if any CUDA device is available
24
+ cuda_available = True
25
+ except cupy.cuda.runtime.CUDARuntimeError:
26
+ cuda_available = False
27
+ else:
28
+ cuda_available = False
29
+ return cuda_available
30
+
31
+
32
+ def jax_is_available() -> bool:
33
+ """Check if JAX is available."""
34
+ return importlib.util.find_spec(JAX) is not None
35
+
36
+
37
+ def get_available_backends_table():
38
+ """Print the available backends as a rich table."""
39
+
40
+ jax_avail = jax_is_available()
41
+ cupy_avail = cuda_is_available()
42
+ table = Table("Backend", "Available", title="Available backends")
43
+ true_check = "[green]✓[/green]"
44
+ false_check = "[red]✗[/red]"
45
+ table.add_row(JAX, true_check if jax_avail else false_check)
46
+ table.add_row(CUPY, true_check if cupy_avail else false_check)
47
+ table.add_row(NUMPY, true_check)
48
+ console = Console(width=150)
49
+ with console.capture() as capture:
50
+ console.print(table)
51
+ return Text.from_ansi(capture.get())
52
+
53
+
54
+ def get_backend_from_env():
55
+ """Select and return the appropriate backend module."""
56
+ backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
57
+
58
+ if backend == JAX and jax_is_available():
59
+
60
+ import jax.numpy as xp
61
+ from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
62
+ from jax.scipy.special import betainc
63
+
64
+ logger.info("Using JAX backend")
65
+
66
+ elif backend == CUPY and cuda_is_available():
7
67
 
8
- jax_available = True
68
+ import cupy as xp
69
+ from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
70
+ from cupyx.scipy.special import betainc
9
71
 
72
+ logger.info("Using CuPy backend")
10
73
 
11
- except ImportError:
12
- jax_available = False
74
+ elif backend == NUMPY:
75
+ import numpy as xp
76
+ from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
77
+ from scipy.special import betainc
13
78
 
14
- use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
79
+ logger.info("Using NumPy backend")
15
80
 
16
- if use_jax:
17
- import jax.numpy as xp # type: ignore
18
- from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
19
- from jax.scipy.special import betainc # type: ignore
20
81
 
21
- logger.info("Using JAX backend")
82
+ else:
83
+ logger.error(
84
+ f"Backend {backend} is not available. "
85
+ )
86
+ print(get_available_backends_table())
87
+ logger.warning(
88
+ f"Setting backend to NumPy. "
89
+ )
90
+ os.environ["PYWAVELET_BACKEND"] = NUMPY
91
+ return get_backend_from_env()
22
92
 
23
- else:
24
- import numpy as xp # type: ignore
25
- from numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
26
- from scipy.special import betainc # type: ignore
93
+ return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
27
94
 
28
- logger.info("Using NumPy+numba backend")
29
95
 
96
+ cuda_available = cuda_is_available()
30
97
 
31
- PI = xp.pi
98
+ # Get the chosen backend
99
+ xp, fft, ifft, irfft, rfft, rfftfreq, betainc, current_backend = (
100
+ get_backend_from_env()
101
+ )
@@ -1,12 +1,25 @@
1
- from ..backend import use_jax
1
+ from ..logger import logger
2
2
 
3
- if use_jax:
3
+ from ..backend import current_backend
4
+
5
+
6
+ logger.debug(f"Using {current_backend} backend")
7
+
8
+
9
+ if current_backend == "jax":
4
10
  from .jax import (
5
11
  from_freq_to_wavelet,
6
12
  from_time_to_wavelet,
7
13
  from_wavelet_to_freq,
8
14
  from_wavelet_to_time,
9
15
  )
16
+ elif current_backend == "cupy":
17
+ from .cupy import (
18
+ from_freq_to_wavelet,
19
+ from_time_to_wavelet,
20
+ from_wavelet_to_freq,
21
+ from_wavelet_to_time,
22
+ )
10
23
  else:
11
24
  from .numpy import (
12
25
  from_wavelet_to_time,
@@ -15,7 +28,7 @@ else:
15
28
  from_freq_to_wavelet,
16
29
  )
17
30
 
18
- from .phi_computer import phi_vec, phitilde_vec, phitilde_vec_norm
31
+ from .phi_computer import omega, phi_vec, phitilde_vec_norm
19
32
 
20
33
  __all__ = [
21
34
  "from_wavelet_to_time",
@@ -24,5 +37,4 @@ __all__ = [
24
37
  "from_freq_to_wavelet",
25
38
  "phitilde_vec_norm",
26
39
  "phi_vec",
27
- "phitilde_vec",
28
40
  ]
@@ -0,0 +1,12 @@
1
+ from ...logger import logger
2
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
3
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
+
5
+ logger.warning("CUPY SUBPACKAGE NOT FULLY TESTED")
6
+
7
+ __all__ = [
8
+ "from_wavelet_to_time",
9
+ "from_wavelet_to_freq",
10
+ "from_time_to_wavelet",
11
+ "from_freq_to_wavelet",
12
+ ]
@@ -0,0 +1,3 @@
1
+ from .main import from_freq_to_wavelet, from_time_to_wavelet
2
+
3
+ __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -0,0 +1,92 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import ifft
3
+
4
+
5
+ import logging
6
+
7
+ logger = logging.getLogger('pywavelet')
8
+
9
+
10
+ def transform_wavelet_freq_helper(
11
+ data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray
12
+ ) -> cp.ndarray:
13
+ """
14
+ Transforms input data from the frequency domain to the wavelet domain using a
15
+ pre-computed wavelet filter (`phif`) and performs an efficient inverse FFT.
16
+
17
+ Parameters:
18
+ - data (cp.ndarray): 1D array representing the input data in the frequency domain.
19
+ - Nf (int): Number of frequency bins.
20
+ - Nt (int): Number of time bins. (Note: Nt * Nf == len(data))
21
+ - phif (cp.ndarray): Pre-computed wavelet filter for frequency components.
22
+
23
+ Returns:
24
+ - wave (cp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
25
+ """
26
+
27
+ logger.debug(f"[CUPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]")
28
+
29
+ # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
30
+ wave = cp.zeros((Nt, Nf))
31
+ f_bins = cp.arange(Nf) # Frequency bin indices
32
+
33
+ # Compute base indices for time (i_base) and frequency (jj_base)
34
+ i_base = Nt // 2
35
+ jj_base = f_bins * Nt // 2
36
+
37
+ # Set initial values for the center of the transformation
38
+ initial_values = cp.where(
39
+ (f_bins == 0)
40
+ | (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
41
+ phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
42
+ phif[0] * data[f_bins * Nt // 2],
43
+ )
44
+
45
+ # Initialize a 2D array to store intermediate FFT input values
46
+ DX = cp.zeros((Nf, Nt), dtype=cp.complex64)
47
+ DX[:, Nt // 2] = (
48
+ initial_values # Set initial values at the center of the transformation (2 sided FFT)
49
+ )
50
+
51
+ # Compute time indices for all offsets around the midpoint
52
+ j_range = cp.arange(
53
+ 1 - Nt // 2, Nt // 2
54
+ ) # Time offsets (centered around zero)
55
+ j = cp.abs(j_range) # Absolute offset indices
56
+ i = i_base + j_range # Time indices in output array
57
+
58
+ # Compute conditions for edge cases
59
+ cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
60
+ cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
61
+ cond3 = j[None, :] == 0 # Center of the transformation (no offset)
62
+
63
+ # Compute frequency indices for the input data
64
+ jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
65
+ val = cp.where(
66
+ cond1 | cond2, 0.0, phif[j] * data[jj]
67
+ ) # Wavelet filter application
68
+ DX[:, i] = cp.where(cond3, DX[:, i], val) # Update DX with computed values
69
+
70
+ # Perform the inverse FFT along the time dimension
71
+ DX_trans = ifft(DX, axis=1)
72
+
73
+ # Fill the wavelet output array based on the inverse FFT results
74
+ n_range = cp.arange(Nt) # Time indices
75
+ cond1 = (
76
+ n_range[:, None] + f_bins[None, :]
77
+ ) % 2 == 1 # Odd/even alternation
78
+ cond2 = cp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
79
+
80
+ # Assign real and imaginary parts based on conditions
81
+ real_part = cp.where(cond2, -cp.imag(DX_trans), cp.real(DX_trans))
82
+ imag_part = cp.where(cond2, cp.real(DX_trans), cp.imag(DX_trans))
83
+ wave = cp.where(cond1, imag_part.T, real_part.T)
84
+
85
+ # Special cases for frequency bins 0 (DC) and Nf (Nyquist)
86
+ wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2)) # DC component
87
+ wave[1::2, -1] = cp.real(
88
+ DX_trans[-1, ::2] * cp.sqrt(2)
89
+ ) # Nyquist component
90
+
91
+ # Return the wavelet-transformed array (transposed for freq-major layout)
92
+ return wave.T # (Nt, Nf) -> (Nf, Nt)
@@ -0,0 +1,50 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import rfft
3
+
4
+
5
+ def transform_wavelet_time_helper(
6
+ data: cp.ndarray, phi: cp.ndarray, Nf: int, Nt: int, mult: int
7
+ ) -> cp.ndarray:
8
+ """Helper function to do the wavelet transform in the time domain using CuPy"""
9
+ # Define constants
10
+ ND = Nf * Nt
11
+ K = mult * 2 * Nf
12
+
13
+ # Pad the data with K extra values
14
+ data_pad = cp.concatenate((data, data[:K]))
15
+
16
+ # Generate time bin indices
17
+ time_bins = cp.arange(Nt)
18
+ jj_base = (time_bins[:, None] * Nf - K // 2) % ND
19
+ jj = (jj_base + cp.arange(K)[None, :]) % ND
20
+
21
+ # Apply the window (phi) to the data
22
+ wdata = data_pad[jj] * phi[None, :]
23
+
24
+ # Perform FFT on the windowed data
25
+ wdata_trans = rfft(wdata, axis=1)
26
+
27
+ # Initialize the wavelet transform result
28
+ wave = cp.zeros((Nt, Nf))
29
+
30
+ # Handle m=0 case for even time bins
31
+ even_mask = (time_bins % 2 == 0) & (time_bins < Nt - 1)
32
+ even_indices = cp.nonzero(even_mask)[0]
33
+
34
+ # Update wave for m=0 using even time bins
35
+ wave[even_indices, 0] = cp.real(wdata_trans[even_indices, 0]) / cp.sqrt(2)
36
+ wave[even_indices + 1, 0] = cp.real(
37
+ wdata_trans[even_indices, Nf * mult]
38
+ ) / cp.sqrt(2)
39
+
40
+ # Handle other cases (j > 0) using vectorized operations
41
+ j_range = cp.arange(1, Nf)
42
+ odd_condition = (time_bins[:, None] + j_range[None, :]) % 2 == 1
43
+
44
+ wave[:, 1:] = cp.where(
45
+ odd_condition,
46
+ -cp.imag(wdata_trans[:, j_range * mult]),
47
+ cp.real(wdata_trans[:, j_range * mult]),
48
+ )
49
+
50
+ return wave.T
@@ -0,0 +1,106 @@
1
+ from typing import Union
2
+
3
+ import cupy as cp
4
+
5
+ from ....logger import logger
6
+ from ....types import FrequencySeries, TimeSeries, Wavelet
7
+ from ....types.wavelet_bins import _get_bins, _preprocess_bins
8
+ from ...phi_computer import phi_vec, phitilde_vec_norm
9
+ from .from_freq import transform_wavelet_freq_helper
10
+ from .from_time import transform_wavelet_time_helper
11
+
12
+
13
+ def from_time_to_wavelet(
14
+ timeseries: TimeSeries,
15
+ Nf: Union[int, None] = None,
16
+ Nt: Union[int, None] = None,
17
+ nx: float = 4.0,
18
+ mult: int = 32,
19
+ **kwargs,
20
+ ) -> Wavelet:
21
+ """Transforms time-domain data to wavelet-domain data.
22
+
23
+ Warning: there can be significant leakage if mult is too small and the
24
+ transform is only approximately exact if mult=Nt/2
25
+
26
+ Parameters
27
+ ----------
28
+ timeseries : TimeSeries
29
+ Time domain data
30
+ Nf : int
31
+ Number of frequency bins
32
+ Nt : int
33
+ Number of time bins
34
+ nx : float, optional
35
+ Number of standard deviations for the phi_vec, by default 4.
36
+ mult : int, optional
37
+ Number of time bins to use for the wavelet transform, by default 32
38
+ **kwargs:
39
+ Additional keyword arguments passed to the Wavelet.from_data constructor.
40
+
41
+ Returns
42
+ -------
43
+ Wavelet
44
+ Wavelet domain data
45
+
46
+ """
47
+ Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
48
+ dt = timeseries.dt
49
+ t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
50
+
51
+ ND = Nf * Nt
52
+
53
+ if len(timeseries) != ND:
54
+ logger.warning(
55
+ f"len(freqseries)={len(timeseries)} != Nf*Nt={ND}. Truncating to freqseries[:{ND}]"
56
+ )
57
+ timeseries = timeseries[:ND]
58
+ if mult > Nt / 2:
59
+ logger.warning(
60
+ f"mult={mult} is too large for Nt={Nt}. This may lead to bogus results."
61
+ )
62
+
63
+ mult = min(mult, Nt // 2) # make sure K isn't bigger than ND
64
+ phi = cp.array(phi_vec(Nf, d=nx, q=mult))
65
+ wave = transform_wavelet_time_helper(
66
+ cp.array(timeseries.data), Nf=Nf, Nt=Nt, phi=phi, mult=mult
67
+ )
68
+ return Wavelet(wave * cp.sqrt(2), time=t_bins, freq=f_bins)
69
+
70
+
71
+ def from_freq_to_wavelet(
72
+ freqseries: FrequencySeries,
73
+ Nf: Union[int, None] = None,
74
+ Nt: Union[int, None] = None,
75
+ nx: float = 4.0,
76
+ **kwargs,
77
+ ) -> Wavelet:
78
+ """Transforms frequency-domain data to wavelet-domain data.
79
+
80
+ Parameters
81
+ ----------
82
+ freqseries : FrequencySeries
83
+ Frequency domain data
84
+ Nf : int
85
+ Number of frequency bins
86
+ Nt : int
87
+ Number of time bins
88
+ nx : float, optional
89
+ Number of standard deviations for the phi_vec, by default 4.
90
+ **kwargs:
91
+ Additional keyword arguments passed to the Wavelet.from_data constructor.
92
+
93
+ Returns
94
+ -------
95
+ Wavelet
96
+ Wavelet domain data
97
+
98
+ """
99
+ Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
100
+ t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
101
+ phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx))
102
+ wave = transform_wavelet_freq_helper(
103
+ cp.array(freqseries.data), Nf=Nf, Nt=Nt, phif=phif
104
+ )
105
+
106
+ return Wavelet((2 / Nf) * wave * cp.sqrt(2), time=t_bins, freq=f_bins)
@@ -0,0 +1,3 @@
1
+ from .main import from_wavelet_to_freq, from_wavelet_to_time
2
+
3
+ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
@@ -0,0 +1,67 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import rfftfreq
3
+
4
+ from ....types import FrequencySeries, TimeSeries, Wavelet
5
+ from ...phi_computer import phi_vec, phitilde_vec_norm
6
+ from .to_freq import inverse_wavelet_freq_helper
7
+
8
+
9
+ def from_wavelet_to_time(
10
+ wave_in: Wavelet,
11
+ dt: float,
12
+ nx: float = 4.0,
13
+ mult: int = 32,
14
+ ) -> TimeSeries:
15
+ """Inverse wavelet transform to time domain.
16
+
17
+ Parameters
18
+ ----------
19
+ wave_in : Wavelet
20
+ input wavelet
21
+ dt : float
22
+ time step
23
+ nx : float, optional
24
+ parameter for phi_vec, by default 4.0
25
+ mult : int, optional
26
+ parameter for phi_vec, by default 32
27
+
28
+ Returns
29
+ -------
30
+ TimeSeries
31
+ Time domain signal
32
+ """
33
+ freq = from_wavelet_to_freq(wave_in, dt=dt, nx=nx)
34
+ return freq.to_timeseries()
35
+
36
+
37
+ def from_wavelet_to_freq(
38
+ wave_in: Wavelet, dt: float, nx=4.0
39
+ ) -> FrequencySeries:
40
+ """Inverse wavelet transform to frequency domain.
41
+
42
+ Parameters
43
+ ----------
44
+ wave_in : Wavelet
45
+ input wavelet
46
+ dt : float
47
+ time step
48
+ nx : float, optional
49
+ parameter for phitilde_vec_norm, by default 4.0
50
+
51
+ Returns
52
+ -------
53
+ FrequencySeries
54
+ Frequency domain signal
55
+
56
+ """
57
+ phif = cp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx))
58
+ freq_data = inverse_wavelet_freq_helper(
59
+ wave_in.data, phif=phif, Nf=wave_in.Nf, Nt=wave_in.Nt
60
+ )
61
+
62
+ freq_data *= 2 ** (
63
+ -1 / 2
64
+ ) # Normalise to get the proper backwards transformation
65
+
66
+ freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
67
+ return FrequencySeries(data=freq_data, freq=freqs)
@@ -0,0 +1,62 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import fft
3
+
4
+
5
+ def inverse_wavelet_freq_helper(
6
+ wave_in: cp.ndarray, phif: cp.ndarray, Nf: int, Nt: int
7
+ ) -> cp.ndarray:
8
+ """CuPy vectorized function for inverse_wavelet_freq"""
9
+ wave_in = wave_in.T
10
+ ND = Nf * Nt
11
+
12
+ m_range = cp.arange(Nf + 1)
13
+ prefactor2s = cp.zeros((Nf + 1, Nt), dtype=cp.complex128)
14
+
15
+ n_range = cp.arange(Nt)
16
+
17
+ # m == 0 case
18
+ prefactor2s[0] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
19
+
20
+ # m == Nf case
21
+ prefactor2s[Nf] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
22
+
23
+ # Other m cases
24
+ m_mid = m_range[1:Nf]
25
+ n_grid, m_grid = cp.meshgrid(n_range, m_mid)
26
+ val = wave_in[n_grid, m_grid]
27
+ mult2 = cp.where((n_grid + m_grid) % 2, -1j, 1)
28
+ prefactor2s[1:Nf] = mult2 * val
29
+
30
+ # Vectorized FFT
31
+ fft_prefactor2s = fft(prefactor2s, axis=1)
32
+
33
+ # Vectorized __unpack_wave_inverse
34
+ res = cp.zeros(ND, dtype=cp.complex128)
35
+
36
+ # m == 0 or m == Nf cases
37
+ i_ind_range = cp.arange(Nt // 2)
38
+ i_0 = cp.abs(i_ind_range)
39
+ i_Nf = cp.abs(Nf * Nt // 2 - i_ind_range)
40
+ ind3_0 = (2 * i_0) % Nt
41
+ ind3_Nf = (2 * i_Nf) % Nt
42
+
43
+ res[i_0] += fft_prefactor2s[0, ind3_0] * phif[i_ind_range]
44
+ res[i_Nf] += fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range]
45
+
46
+ # Special case for m == Nf
47
+ res[Nf * Nt // 2] += fft_prefactor2s[Nf, 0] * phif[Nt // 2]
48
+
49
+ # Other m cases
50
+ m_mid = m_range[1:Nf]
51
+ i_ind_range = cp.arange(Nt // 2 + 1)
52
+ m_grid, i_ind_grid = cp.meshgrid(m_mid, i_ind_range)
53
+
54
+ i1 = Nt // 2 * m_grid - i_ind_grid
55
+ i2 = Nt // 2 * m_grid + i_ind_grid
56
+ ind31 = (Nt // 2 * m_grid - i_ind_grid) % Nt
57
+ ind32 = (Nt // 2 * m_grid + i_ind_grid) % Nt
58
+
59
+ res[i1] += fft_prefactor2s[m_grid, ind31] * phif[i_ind_grid]
60
+ res[i2] += fft_prefactor2s[m_grid, ind32] * phif[i_ind_grid]
61
+
62
+ return res
@@ -4,6 +4,30 @@ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
4
 
5
5
  logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
6
 
7
+
8
+ def _log_jax_info():
9
+ """Log JAX backend and precision information.
10
+
11
+ backend : str
12
+ JAX backend. ["cpu", "gpu", "tpu"]
13
+ precision : str
14
+ JAX precision. ["32bit", "64bit"]
15
+ """
16
+ import jax
17
+
18
+ _backend = jax.default_backend()
19
+ _precision = "64bit" if jax.config.jax_enable_x64 else "32bit"
20
+
21
+ logger.info(f"Jax running on {_backend} [{_precision} precision].")
22
+ if _precision == "32bit":
23
+ logger.warning(
24
+ "Jax is not running in 64bit precision. "
25
+ "To change, use jax.config.update('jax_enable_x64', True)."
26
+ )
27
+
28
+
29
+ _log_jax_info()
30
+
7
31
  __all__ = [
8
32
  "from_wavelet_to_time",
9
33
  "from_wavelet_to_freq",
@@ -4,6 +4,10 @@ import jax.numpy as jnp
4
4
  from jax import jit
5
5
  from jax.numpy.fft import ifft
6
6
 
7
+ import logging
8
+
9
+ logger = logging.getLogger('pywavelet')
10
+
7
11
 
8
12
  @partial(jit, static_argnames=("Nf", "Nt"))
9
13
  def transform_wavelet_freq_helper(
@@ -23,6 +27,8 @@ def transform_wavelet_freq_helper(
23
27
  - wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
24
28
  """
25
29
 
30
+ logger.debug(f"[JAX TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]")
31
+
26
32
  # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
27
33
  wave = jnp.zeros((Nt, Nf))
28
34
  f_bins = jnp.arange(Nf) # Frequency bin indices
@@ -10,7 +10,7 @@ def from_wavelet_to_time(
10
10
  wave_in: Wavelet,
11
11
  dt: float,
12
12
  nx: float = 4.0,
13
- mult: int = 32,
13
+ mult: int = None,
14
14
  ) -> TimeSeries:
15
15
  """Inverse wavelet transform to time domain.
16
16
 
@@ -55,14 +55,12 @@ def from_wavelet_to_freq(
55
55
  Frequency domain signal
56
56
 
57
57
  """
58
- phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, dt=dt, d=nx))
58
+ phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx))
59
59
  freq_data = inverse_wavelet_freq_helper(
60
60
  wave_in.data, phif=phif, Nf=wave_in.Nf, Nt=wave_in.Nt
61
61
  )
62
62
 
63
- freq_data *= 2 ** (
64
- -1 / 2
65
- ) # Normalise to get the proper backwards transformation
63
+ freq_data *= 1.0 / jnp.sqrt(2)
66
64
 
67
- freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
65
+ freqs = rfftfreq(wave_in.ND, d=dt)
68
66
  return FrequencySeries(data=freq_data, freq=freqs)