pywavelet-0.2.6-py3-none-any.whl → pywavelet-0.2.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/__init__.py CHANGED
@@ -6,11 +6,10 @@ import importlib
6
6
  import os
7
7
 
8
8
  from . import backend as _backend
9
+ from ._version import __version__
9
10
 
10
- __version__ = "0.0.2"
11
11
 
12
-
def set_backend(backend: str, precision: str = "float32") -> None:
    """Set the backend and floating-point precision for the wavelet transform.

    The choice is written to the ``PYWAVELET_BACKEND`` and
    ``PYWAVELET_PRECISION`` environment variables, then the backend, types and
    transforms modules (plus the forward/inverse modules of the active
    transform sub-package) are reloaded so they pick up the new settings.

    Parameters
    ----------
    backend : str
        Backend to use. Options are "numpy", "jax", "cupy".
    precision : str, optional
        Floating-point precision, "float32" (default) or "float64".
    """
    from . import transforms, types

    os.environ["PYWAVELET_BACKEND"] = backend
    os.environ["PYWAVELET_PRECISION"] = precision

    importlib.reload(_backend)
    importlib.reload(types)
    importlib.reload(transforms)

    # Branch on the backend that was actually selected, not the requested
    # string: the backend module lower-cases the name and falls back to NumPy
    # when JAX/CUDA is unavailable, in which case ``transforms`` presumably
    # never imports its ``jax``/``cupy`` sub-package and reloading
    # ``transforms.cupy`` would raise AttributeError.
    active = _backend.current_backend
    if active == "cupy":
        subpackage = transforms.cupy
    elif active == "jax":
        subpackage = transforms.jax
    else:
        subpackage = transforms.numpy
    importlib.reload(subpackage)
    importlib.reload(subpackage.forward)
    importlib.reload(subpackage.inverse)
pywavelet/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.6'
21
- __version_tuple__ = version_tuple = (0, 2, 6)
20
+ __version__ = version = '0.2.8'
21
+ __version_tuple__ = version_tuple = (0, 2, 8)
pywavelet/backend.py CHANGED
@@ -1,5 +1,10 @@
1
1
  import importlib
2
2
  import os
3
+ from typing import Tuple
4
+
5
+ import numpy as np
6
+ from rich.console import Console
7
+ from rich.table import Table, Text
3
8
 
4
9
  from .logger import logger
5
10
 
@@ -7,47 +12,168 @@ JAX = "jax"
# Backend identifiers accepted in PYWAVELET_BACKEND.
CUPY = "cupy"
NUMPY = "numpy"

# Precision names accepted in PYWAVELET_PRECISION / by set_precision().
VALID_PRECISIONS = ["float32", "float64"]
def cuda_is_available() -> bool:
    """Return True if CuPy can be imported and reports at least one CUDA device.

    This runs when the backend module is imported, so it must never raise.
    A missing CuPy package, a CuPy install that fails to import, and a CUDA
    runtime error are all reported as "not available".
    """
    # ``import importlib`` alone does not guarantee ``importlib.util`` is loaded.
    import importlib.util

    if importlib.util.find_spec("cupy") is None:
        return False

    try:
        import cupy
    except ImportError:
        # CuPy is installed but cannot be loaded here (e.g. missing CUDA libs);
        # previously this propagated and broke ``import pywavelet``.
        return False

    try:
        return cupy.cuda.runtime.getDeviceCount() > 0
    except cupy.cuda.runtime.CUDARuntimeError:
        return False
34
+
35
+
def jax_is_available() -> bool:
    """Return True if the ``jax`` package can be found on the import path.

    Only the package's presence is checked; JAX itself is not imported.
    """
    # ``import importlib`` alone does not guarantee ``importlib.util`` is loaded.
    import importlib.util

    return importlib.util.find_spec(JAX) is not None
39
+
40
+
def get_available_backends_table() -> "Text":
    """Render a rich table of which backends can be used.

    JAX is marked available when the package can be found, CuPy when CUDA is
    usable through CuPy, and NumPy always.

    Returns
    -------
    rich.text.Text
        The table rendered (150 columns wide) to ANSI text. Nothing is
        printed here; callers decide where to print or log it.
    """
    ok = "[green]✓[/green]"
    missing = "[red]✗[/red]"
    jax_avail = jax_is_available()
    cupy_avail = cuda_is_available()

    table = Table("Backend", "Available", title="Available backends")
    table.add_row(JAX, ok if jax_avail else missing)
    table.add_row(CUPY, ok if cupy_avail else missing)
    table.add_row(NUMPY, ok)

    console = Console(width=150)
    with console.capture() as capture:
        console.print(table)
    return Text.from_ansi(capture.get())
56
+
57
+
def log_backend(level="info"):
    """Log the backend and precision currently named in the environment.

    Parameters
    ----------
    level : str, optional
        "info" (default) or "debug". Any other value logs nothing.
    """
    backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
    precision = os.getenv("PYWAVELET_PRECISION", "float32").lower()
    # Named ``message`` rather than ``str`` to avoid shadowing the builtin.
    message = f"Current backend: {backend}[{precision}]"
    if level == "info":
        logger.info(message)
    elif level == "debug":
        logger.debug(message)
67
+
10
68
 
def get_backend_from_env():
    """Resolve PYWAVELET_BACKEND to an array namespace and its FFT helpers.

    Returns
    -------
    tuple
        ``(xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend)``, where
        ``backend`` is the lower-cased name of the backend that was selected.

    Notes
    -----
    "jax" is honoured only if JAX is installed and "cupy" only if CUDA is
    usable through CuPy. In every other case (including unknown names) an
    error and the table of available backends are emitted,
    PYWAVELET_BACKEND is reset to "numpy" and the lookup is retried.
    """
    requested = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
    precision = os.getenv("PYWAVELET_PRECISION", "float32").lower()

    # Decide first, import afterwards. The availability probe only runs for
    # the backend that was actually requested.
    if requested == JAX and jax_is_available():
        chosen = JAX
    elif requested == CUPY and cuda_is_available():
        chosen = CUPY
    elif requested == NUMPY:
        chosen = NUMPY
    else:
        logger.error(f"Backend {requested}[{precision}] is not available. ")
        print(get_available_backends_table())
        logger.warning("Setting backend to NumPy. ")
        os.environ["PYWAVELET_BACKEND"] = NUMPY
        return get_backend_from_env()

    if chosen == JAX:
        import jax.numpy as xp
        from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
        from jax.scipy.special import betainc
    elif chosen == CUPY:
        import cupy as xp
        from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
        from cupyx.scipy.special import betainc
    else:
        import numpy as xp
        from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
        from scipy.special import betainc

    log_backend("debug")
    return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, chosen
100
+
101
+
def get_precision_from_env() -> str:
    """Return the precision named by PYWAVELET_PRECISION.

    The value is lower-cased. Anything not in VALID_PRECISIONS is logged as
    an error and replaced by "float32", which is also the default when the
    variable is unset.
    """
    requested = os.getenv("PYWAVELET_PRECISION", "float32").lower()
    if requested in VALID_PRECISIONS:
        return requested
    logger.error(
        f"Precision {requested} is not supported, defaulting to float32."
    )
    return "float32"
111
+
112
+
def set_precision(precision: str) -> None:
    """Set the floating-point precision used by the backend.

    Parameters
    ----------
    precision : str
        "float32" or "float64" (case-insensitive). Unsupported values are
        logged as an error and leave the current setting untouched.

    Notes
    -----
    Besides updating PYWAVELET_PRECISION, this recomputes the module-level
    ``float_dtype`` / ``complex_dtype``. Code that reads
    ``backend.float_dtype`` at call time therefore sees the new precision,
    and for JAX ``jax_enable_x64`` is toggled to match. Names bound earlier
    via ``from ...backend import float_dtype`` are not updated; use
    ``pywavelet.set_backend`` to reload everything.
    """
    global float_dtype, complex_dtype

    precision = precision.lower()
    if precision not in VALID_PRECISIONS:
        logger.error(f"Precision {precision} is not supported.")
        return

    os.environ["PYWAVELET_PRECISION"] = precision
    logger.info(f"Setting precision to {precision}.")
    # Without this the dtypes computed at import time stay in effect, and the
    # new precision only takes hold after this module is reloaded.
    float_dtype, complex_dtype = get_dtype_from_env()
123
+
124
+
def get_dtype_from_env() -> Tuple[np.dtype, np.dtype]:
    """Return the ``(float_dtype, complex_dtype)`` pair for the current settings.

    The precision comes from ``get_precision_from_env`` and the backend from
    ``get_backend_from_env``, so an unavailable backend has already fallen
    back to NumPy. The dtypes are taken from that backend's namespace. For
    JAX, the ``jax_enable_x64`` flag is also set to match the precision.
    """
    precision = get_precision_from_env()
    backend_name = get_backend_from_env()[-1]
    double = precision == "float64"

    if backend_name == JAX:
        import jax

        jax.config.update("jax_enable_x64", double)

        import jax.numpy as lib
    elif backend_name == CUPY:
        import cupy as lib
    else:
        lib = np

    if double:
        return lib.float64, lib.complex128
    return lib.float32, lib.complex64
169
+
170
+
# Evaluated once at import: whether CuPy can reach a CUDA device.
cuda_available = cuda_is_available()

# Array namespace, FFT routines, betainc and the name of the backend actually
# selected from PYWAVELET_BACKEND (after any fallback to NumPy).
(
    xp,
    fft,
    ifft,
    irfft,
    rfft,
    rfftfreq,
    betainc,
    current_backend,
) = get_backend_from_env()

# Float/complex dtypes for PYWAVELET_PRECISION on that backend.
float_dtype, complex_dtype = get_dtype_from_env()
pywavelet/logger.py CHANGED
@@ -5,14 +5,14 @@ import warnings
5
5
  from rich.logging import RichHandler
6
6
 
FORMAT = "%(message)s"

logger = logging.getLogger("pywavelet")

# ``hasHandlers`` also consults ancestor loggers. The RichHandler is attached
# (and the level forced to INFO) only when nothing up the hierarchy, e.g. an
# application's root-logger configuration, already handles records.
if not logger.hasHandlers():
    formatter = logging.Formatter(fmt=FORMAT, datefmt="[%X]")
    handler = RichHandler(rich_tracebacks=True)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

# NOTE(review): these filters are process-wide, not limited to pywavelet.
# Importing the package hides every RuntimeWarning/UserWarning in the caller's
# program too. Consider scoping them with ``module=r"pywavelet"``.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@@ -1,4 +1,8 @@
1
1
  from ..backend import current_backend
2
+ from ..logger import logger
3
+
4
+ logger.debug(f"Using {current_backend} backend")
5
+
2
6
 
3
7
  if current_backend == "jax":
4
8
  from .jax import (
@@ -1,14 +1,18 @@
1
+ import logging
2
+
1
3
  import cupy as cp
2
4
  from cupyx.scipy.fft import ifft
3
5
 
4
-
5
- import logging
6
-
7
- logger = logging.getLogger('pywavelet')
6
+ logger = logging.getLogger("pywavelet")
8
7
 
9
8
 
10
9
  def transform_wavelet_freq_helper(
11
- data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray
10
+ data: cp.ndarray,
11
+ Nf: int,
12
+ Nt: int,
13
+ phif: cp.ndarray,
14
+ float_dtype=cp.float64,
15
+ complex_dtype=cp.complex128,
12
16
  ) -> cp.ndarray:
13
17
  """
14
18
  Transforms input data from the frequency domain to the wavelet domain using a
@@ -24,69 +28,62 @@ def transform_wavelet_freq_helper(
24
28
  - wave (cp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
25
29
  """
26
30
 
27
- logger.debug(f"[CUPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]")
28
-
29
- # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
30
- wave = cp.zeros((Nt, Nf))
31
- f_bins = cp.arange(Nf) # Frequency bin indices
32
-
33
- # Compute base indices for time (i_base) and frequency (jj_base)
34
- i_base = Nt // 2
35
- jj_base = f_bins * Nt // 2
36
-
37
- # Set initial values for the center of the transformation
38
- initial_values = cp.where(
39
- (f_bins == 0)
40
- | (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
41
- phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
42
- phif[0] * data[f_bins * Nt // 2],
31
+ logger.debug(
32
+ f"[CUPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}, precision:{data.dtype}]"
43
33
  )
44
34
 
45
- # Initialize a 2D array to store intermediate FFT input values
46
- DX = cp.zeros((Nf, Nt), dtype=cp.complex64)
47
- DX[:, Nt // 2] = (
48
- initial_values # Set initial values at the center of the transformation (2 sided FFT)
35
+ half = Nt // 2
36
+ f_bins = cp.arange(Nf + 1) # [0, 1, ..., Nf]
37
+
38
+ # 1) Build (Nf+1, Nt) DX
39
+ # — center:
40
+ center = phif[0] * data[f_bins * half]
41
+ center = cp.where((f_bins == 0) | (f_bins == Nf), center * 0.5, center)
42
+ DX = cp.zeros((Nf + 1, Nt), complex_dtype)
43
+ DX[:, half] = center
44
+
45
+ # — off-center (j = ±1...±(half-1))
46
+ offs = cp.arange(1 - half, half) # length Nt-1
47
+ jj = f_bins[:, None] * half + offs[None, :] # (Nf+1, Nt-1)
48
+ ii = half + offs # (Nt-1,)
49
+ mask = ((f_bins[:, None] == Nf) & (offs[None, :] > 0)) | (
50
+ (f_bins[:, None] == 0) & (offs[None, :] < 0)
51
+ )
52
+ vals = phif[cp.abs(offs)] * data[jj]
53
+ vals = cp.where(mask, 0.0, vals)
54
+ DX[:, ii] = vals
55
+
56
+ # 2) IFFT along time
57
+ DXt = ifft(DX, axis=1)
58
+
59
+ # 3) Unpack into wave (Nt, Nf)
60
+ wave = cp.zeros((Nt, Nf), float_dtype)
61
+ sqrt2 = cp.sqrt(2.0)
62
+
63
+ # 3a) DC into col 0, even rows
64
+ evens = cp.arange(0, Nt, 2)
65
+ wave[evens, 0] = cp.real(DXt[0, evens]) * sqrt2
66
+
67
+ # 3b) Nyquist into col 0, odd rows
68
+ odds = cp.arange(1, Nt, 2)
69
+ wave[odds, 0] = cp.real(DXt[Nf, evens]) * sqrt2
70
+
71
+ # 3c) intermediate bins 1...Nf-1
72
+ mids = cp.arange(1, Nf) # [1...Nf-1]
73
+ Dt_mid = DXt[mids, :] # (Nf-1, Nt)
74
+ real_m = cp.real(Dt_mid).T # (Nt, Nf-1)
75
+ imag_m = cp.imag(Dt_mid).T # (Nt, Nf-1)
76
+
77
+ odd_f = (mids % 2) == 1 # shape (Nf-1,)
78
+ n_idx = cp.arange(Nt)[:, None] # (Nt,1)
79
+ odd_nf = ((n_idx + mids[None, :]) % 2) == 1
80
+
81
+ # Select real vs imag and sign exactly as in __fill_wave_2
82
+ mid_vals = cp.where(
83
+ odd_nf,
84
+ cp.where(odd_f, -imag_m, imag_m),
85
+ cp.where(odd_f, real_m, real_m),
49
86
  )
87
+ wave[:, 1:Nf] = mid_vals
50
88
 
51
- # Compute time indices for all offsets around the midpoint
52
- j_range = cp.arange(
53
- 1 - Nt // 2, Nt // 2
54
- ) # Time offsets (centered around zero)
55
- j = cp.abs(j_range) # Absolute offset indices
56
- i = i_base + j_range # Time indices in output array
57
-
58
- # Compute conditions for edge cases
59
- cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
60
- cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
61
- cond3 = j[None, :] == 0 # Center of the transformation (no offset)
62
-
63
- # Compute frequency indices for the input data
64
- jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
65
- val = cp.where(
66
- cond1 | cond2, 0.0, phif[j] * data[jj]
67
- ) # Wavelet filter application
68
- DX[:, i] = cp.where(cond3, DX[:, i], val) # Update DX with computed values
69
-
70
- # Perform the inverse FFT along the time dimension
71
- DX_trans = ifft(DX, axis=1)
72
-
73
- # Fill the wavelet output array based on the inverse FFT results
74
- n_range = cp.arange(Nt) # Time indices
75
- cond1 = (
76
- n_range[:, None] + f_bins[None, :]
77
- ) % 2 == 1 # Odd/even alternation
78
- cond2 = cp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
79
-
80
- # Assign real and imaginary parts based on conditions
81
- real_part = cp.where(cond2, -cp.imag(DX_trans), cp.real(DX_trans))
82
- imag_part = cp.where(cond2, cp.real(DX_trans), cp.imag(DX_trans))
83
- wave = cp.where(cond1, imag_part.T, real_part.T)
84
-
85
- # Special cases for frequency bins 0 (DC) and Nf (Nyquist)
86
- wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2)) # DC component
87
- wave[1::2, -1] = cp.real(
88
- DX_trans[-1, ::2] * cp.sqrt(2)
89
- ) # Nyquist component
90
-
91
- # Return the wavelet-transformed array (transposed for freq-major layout)
92
- return wave.T # (Nt, Nf) -> (Nf, Nt)
89
+ return wave.T
@@ -2,6 +2,7 @@ from typing import Union
2
2
 
3
3
  import cupy as cp
4
4
 
5
+ from .... import backend
5
6
  from ....logger import logger
6
7
  from ....types import FrequencySeries, TimeSeries, Wavelet
7
8
  from ....types.wavelet_bins import _get_bins, _preprocess_bins
@@ -16,7 +17,6 @@ def from_time_to_wavelet(
16
17
  Nt: Union[int, None] = None,
17
18
  nx: float = 4.0,
18
19
  mult: int = 32,
19
- **kwargs,
20
20
  ) -> Wavelet:
21
21
  """Transforms time-domain data to wavelet-domain data.
22
22
 
@@ -45,7 +45,6 @@ def from_time_to_wavelet(
45
45
 
46
46
  """
47
47
  Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
48
- dt = timeseries.dt
49
48
  t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
50
49
 
51
50
  ND = Nf * Nt
@@ -73,7 +72,6 @@ def from_freq_to_wavelet(
73
72
  Nf: Union[int, None] = None,
74
73
  Nt: Union[int, None] = None,
75
74
  nx: float = 4.0,
76
- **kwargs,
77
75
  ) -> Wavelet:
78
76
  """Transforms frequency-domain data to wavelet-domain data.
79
77
 
@@ -98,9 +96,15 @@ def from_freq_to_wavelet(
98
96
  """
99
97
  Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
100
98
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
101
- phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx))
99
+ phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx), dtype=backend.float_dtype)
100
+ data = cp.array(freqseries.data, dtype=backend.complex_dtype)
102
101
  wave = transform_wavelet_freq_helper(
103
- cp.array(freqseries.data), Nf=Nf, Nt=Nt, phif=phif
102
+ data,
103
+ Nf=Nf,
104
+ Nt=Nt,
105
+ phif=phif,
106
+ float_dtype=backend.float_dtype,
107
+ complex_dtype=backend.complex_dtype,
104
108
  )
105
-
106
- return Wavelet((2 / Nf) * wave * cp.sqrt(2), time=t_bins, freq=f_bins)
109
+ factor = (2 / Nf) * cp.sqrt(2)
110
+ return Wavelet(factor * wave, time=t_bins, freq=f_bins)
def inverse_wavelet_freq_helper(
    wave_in: cp.ndarray,
    phif: cp.ndarray,
    Nf: int,
    Nt: int,
    complex_dtype=cp.complex128,
) -> cp.ndarray:
    """CuPy vectorized function for inverse_wavelet_freq.

    Rebuilds the half-spectrum from the wavelet coefficients: pack the
    coefficients per frequency bin, FFT along time, then unpack each bin's
    FFT back onto the frequency grid weighted by ``phif``.

    Parameters
    ----------
    wave_in : cp.ndarray
        Wavelet coefficients, indexed (frequency, time), i.e. shape (Nf, Nt).
    phif : cp.ndarray
        Frequency-domain window; indexed up to ``Nt // 2`` inclusive.
    Nf, Nt : int
        Number of frequency and time bins (Nt assumed even — TODO confirm).
    complex_dtype : optional
        Complex dtype of the intermediate and returned arrays. Defaults to
        complex128, the value previously hard-coded.

    Returns
    -------
    cp.ndarray
        Complex half-spectrum of length ``Nf * Nt // 2 + 1``.
    """
    wave = wave_in.T  # (Nt, Nf): rows are time bins, columns frequency bins
    ND2 = (Nf * Nt) // 2
    half = Nt // 2

    # === STEP 1: build pref2[m, n] ===
    pref2 = cp.zeros((Nf + 1, Nt), dtype=complex_dtype)
    n = cp.arange(Nt)

    # m=0 (DC) and m=Nf (Nyquist) are packed into column 0 (even/odd rows),
    # each rescaled by 1/sqrt(2).
    pref2[0, :] = wave[(2 * n) % Nt, 0] * (2 ** (-0.5))
    pref2[Nf, :] = wave[((2 * n) % Nt) + 1, 0] * (2 ** (-0.5))

    # middle m = 1..Nf-1: alternate between real and -i * value
    m_mid = cp.arange(1, Nf)
    mm, nn = cp.meshgrid(m_mid, n, indexing="ij")
    signs = cp.where(((nn + mm) % 2) == 1, -1j, 1)
    pref2[1:Nf, :] = signs * wave[nn, mm]

    # === STEP 2: FFT along time axis ===
    F = fft(pref2, axis=1)  # shape (Nf+1, Nt)

    # === STEP 3: unpack back into half-spectrum res[0..ND2] ===
    # Within each fancy-indexed ``+=`` below the target indices are distinct,
    # so no contribution is dropped. Overlaps only occur *between* statements,
    # which run sequentially.
    res = cp.zeros(ND2 + 1, dtype=complex_dtype)
    idx = cp.arange(half)

    # 3a) m=0: target i = idx, FFT index (2*i) % Nt
    res[idx] += F[0, (2 * idx) % Nt] * phif[idx]

    # 3b) m=Nf: target i = |Nf*half - idx|. As for m=0, the FFT index is
    # (2*i) % Nt, computed from the target i (== (-2*idx) % Nt), not from idx.
    # Using (2*idx) % Nt (as 0.2.8 did) picks the mirrored FFT bin, which for
    # the real-valued pref2[Nf] row is the complex conjugate.
    iNf = cp.abs(Nf * half - idx)
    res[iNf] += F[Nf, (2 * iNf) % Nt] * phif[idx]

    # special Nyquist-folding term (i_ind = half, FFT index 0)
    special = cp.abs(Nf * half - half)
    res[special] += F[Nf, 0] * phif[half]

    # 3c) middle m cases: each bin spreads over [half*m - half, half*m + half)
    i_mid = cp.arange(half)  # 0..half-1
    mm, ii = cp.meshgrid(m_mid, i_mid, indexing="ij")
    i1 = (half * mm - ii) % (ND2 + 1)
    i2 = (half * mm + ii) % (ND2 + 1)
    ind1 = (half * mm - ii) % Nt
    ind2 = (half * mm + ii) % Nt

    res[i1] += F[mm, ind1] * phif[ii]
    res[i2] += F[mm, ind2] * phif[ii]

    # override the "center" points (ii=0 was added by both statements above)
    centers = half * m_mid
    res[centers] = F[m_mid, centers % Nt] * phif[0]

    return res
@@ -4,6 +4,30 @@ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
4
 
5
5
  logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
6
 
7
+
def _log_jax_info():
    """Log which JAX platform is in use and whether 64-bit mode is enabled.

    Reads ``jax.default_backend()`` (e.g. "cpu", "gpu", "tpu") and the
    ``jax_enable_x64`` config flag and logs both at INFO level. A warning is
    added when JAX runs in 32-bit mode. Takes no arguments and returns None.
    """
    import jax

    platform = jax.default_backend()
    precision = "64bit" if jax.config.jax_enable_x64 else "32bit"

    logger.info(f"Jax running on {platform} [{precision} precision].")
    if precision == "32bit":
        logger.warning(
            "Jax is not running in 64bit precision. "
            "To change, use jax.config.update('jax_enable_x64', True)."
        )


_log_jax_info()
30
+
7
31
  __all__ = [
8
32
  "from_wavelet_to_time",
9
33
  "from_wavelet_to_freq",