pywavelet 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/__init__.py CHANGED
@@ -6,11 +6,10 @@ import importlib
6
6
  import os
7
7
 
8
8
  from . import backend as _backend
9
+ from ._version import __version__
9
10
 
10
- __version__ = "0.0.2"
11
11
 
12
-
13
def set_backend(backend: str, precision: str = "float32") -> None:
    """Set the backend (and floating-point precision) for the wavelet transform.

    The choice is written to the ``PYWAVELET_BACKEND`` and
    ``PYWAVELET_PRECISION`` environment variables. The backend module, the
    types, the transforms package and the selected backend's transform
    sub-package are then reloaded so they pick up the new settings.

    Parameters
    ----------
    backend : str
        Backend to use. Options are "numpy", "jax", "cupy".
    precision : str, optional
        Floating-point precision, "float32" (default) or "float64". The backend
        module falls back to "float32" (logging an error) for other values.
    """
    from . import transforms, types

    os.environ["PYWAVELET_BACKEND"] = backend
    os.environ["PYWAVELET_PRECISION"] = precision

    importlib.reload(_backend)
    importlib.reload(types)
    importlib.reload(transforms)

    # Pick the transform sub-package from the backend that was actually
    # selected. The backend module lower-cases the requested name and falls
    # back to NumPy when the requested backend is unavailable, so the raw
    # ``backend`` argument may not match what is really in use.
    resolved = _backend.current_backend
    if resolved not in ("cupy", "jax"):
        resolved = "numpy"
    subpackage = getattr(transforms, resolved)
    for module in (subpackage, subpackage.forward, subpackage.inverse):
        importlib.reload(module)
pywavelet/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.7'
21
- __version_tuple__ = version_tuple = (0, 2, 7)
20
+ __version__ = version = '0.2.8'
21
+ __version_tuple__ = version_tuple = (0, 2, 8)
pywavelet/backend.py CHANGED
@@ -1,9 +1,10 @@
1
1
  import importlib
2
2
  import os
3
- from rich.table import Table, Text
4
- from rich.console import Console
5
-
3
+ from typing import Tuple
6
4
 
5
+ import numpy as np
6
+ from rich.console import Console
7
+ from rich.table import Table, Text
7
8
 
8
9
  from .logger import logger
9
10
 
@@ -11,22 +12,25 @@ JAX = "jax"
11
12
# Backend identifiers; compared against the lower-cased PYWAVELET_BACKEND value.
CUPY = "cupy"
NUMPY = "numpy"

# Precision names accepted by get_precision_from_env() and set_precision().
# Other values are rejected by set_precision() and replaced by "float32" in
# get_precision_from_env().
VALID_PRECISIONS = ["float32", "float64"]
16
+
14
17
 
15
18
def cuda_is_available() -> bool:
    """Check whether CuPy is installed and can see at least one CUDA device.

    Returns
    -------
    bool
        ``True`` when ``cupy`` is importable and the CUDA runtime reports one
        or more devices, ``False`` otherwise.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule has been loaded, so import it explicitly here.
    import importlib.util

    if importlib.util.find_spec("cupy") is None:
        return False

    import cupy

    try:
        n_devices = cupy.cuda.runtime.getDeviceCount()
    except cupy.cuda.runtime.CUDARuntimeError:
        # The CUDA runtime could not be queried (e.g. no driver or device).
        return False
    # A successful query that reports zero devices is still "not available".
    return n_devices > 0
30
34
 
31
35
 
32
36
  def jax_is_available() -> bool:
@@ -51,9 +55,21 @@ def get_available_backends_table():
51
55
  return Text.from_ansi(capture.get())
52
56
 
53
57
 
58
def log_backend(level="info"):
    """Log the current backend and precision read from the environment.

    Parameters
    ----------
    level : str, optional
        Name of the logging level to emit at ("info" by default; "debug",
        "warning", ... are also accepted). Unrecognised names fall back to
        INFO instead of being silently dropped.
    """
    import logging

    backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
    precision = os.getenv("PYWAVELET_PRECISION", "float32").lower()
    message = f"Current backend: {backend}[{precision}]"

    # getLevelName() maps a registered level name to its numeric value and
    # returns a string for unknown names.
    numeric_level = logging.getLevelName(str(level).upper())
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO
    logger.log(numeric_level, message)
+
68
+
54
69
  def get_backend_from_env():
55
70
  """Select and return the appropriate backend module."""
56
71
  backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
72
+ precision = os.getenv("PYWAVELET_PRECISION", "float32").lower()
57
73
 
58
74
  if backend == JAX and jax_is_available():
59
75
 
@@ -61,41 +77,103 @@ def get_backend_from_env():
61
77
  from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
62
78
  from jax.scipy.special import betainc
63
79
 
64
- logger.info("Using JAX backend")
65
-
66
80
  elif backend == CUPY and cuda_is_available():
67
81
 
68
82
  import cupy as xp
69
83
  from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
70
84
  from cupyx.scipy.special import betainc
71
85
 
72
- logger.info("Using CuPy backend")
73
-
74
86
  elif backend == NUMPY:
75
87
  import numpy as xp
76
88
  from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
77
89
  from scipy.special import betainc
78
90
 
79
- logger.info("Using NumPy backend")
80
-
81
-
82
91
  else:
83
- logger.error(
84
- f"Backend {backend} is not available. "
85
- )
92
+ logger.error(f"Backend {backend}[{precision}] is not available. ")
86
93
  print(get_available_backends_table())
87
- logger.warning(
88
- f"Setting backend to NumPy. "
89
- )
94
+ logger.warning(f"Setting backend to NumPy. ")
90
95
  os.environ["PYWAVELET_BACKEND"] = NUMPY
91
96
  return get_backend_from_env()
92
97
 
98
+ log_backend("debug")
93
99
  return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
94
100
 
95
101
 
102
def get_precision_from_env() -> str:
    """Return the validated precision name from ``PYWAVELET_PRECISION``.

    The variable defaults to "float32" and is lower-cased. A value outside
    ``VALID_PRECISIONS`` is reported via ``logger.error`` and replaced with
    "float32".
    """
    requested = os.getenv("PYWAVELET_PRECISION", "float32").lower()
    if requested in VALID_PRECISIONS:
        return requested
    logger.error(
        f"Precision {requested} is not supported, defaulting to float32."
    )
    return "float32"
111
+
112
+
113
def set_precision(precision: str) -> None:
    """Validate *precision* and store it in ``PYWAVELET_PRECISION``.

    The name is lower-cased first. An unsupported value is reported via
    ``logger.error`` and leaves the environment untouched.

    Note: this only updates the environment variable. The module-level
    ``float_dtype``/``complex_dtype`` computed at import time are not
    recomputed here (``pywavelet.set_backend`` reloads the modules).
    """
    normalised = precision.lower()
    if normalised not in VALID_PRECISIONS:
        logger.error(f"Precision {normalised} is not supported.")
        return
    os.environ["PYWAVELET_PRECISION"] = normalised
    logger.info(f"Setting precision to {normalised}.")
123
+
124
+
125
def get_dtype_from_env(backend=None) -> Tuple[np.dtype, np.dtype]:
    """Return the ``(float_dtype, complex_dtype)`` pair for the configured precision.

    The precision comes from ``get_precision_from_env()``:

    - "float32" gives float32/complex64.
    - "float64" gives float64/complex128.

    The types are taken from the active array library: ``jax.numpy`` for JAX
    (whose ``jax_enable_x64`` flag is set to match the precision), ``cupy``
    for CuPy, and ``numpy`` otherwise.

    Parameters
    ----------
    backend : str, optional
        Backend name ("jax", "cupy" or "numpy"). When omitted it is resolved
        with ``get_backend_from_env()``. Passing an already-resolved name
        avoids repeating that lookup and its import/logging side effects.

    Returns
    -------
    tuple
        ``(float_dtype, complex_dtype)`` scalar types of the chosen library.
    """
    precision = get_precision_from_env()
    if backend is None:
        backend = get_backend_from_env()[-1]
    double = precision == "float64"

    if backend == JAX:
        import jax

        # Toggle JAX's 64-bit mode to match the requested precision.
        jax.config.update("jax_enable_x64", double)

        import jax.numpy as lib
    elif backend == CUPY:
        import cupy as lib
    else:
        lib = np

    if double:
        return lib.float64, lib.complex128
    return lib.float32, lib.complex64
169
+
170
+
96
171
# Probe for a usable CUDA device once, when the module is imported.
cuda_available = cuda_is_available()

# Resolve the active array library and its FFT / special-function helpers,
# plus the name of the backend that was actually selected.
(
    xp,
    fft,
    ifft,
    irfft,
    rfft,
    rfftfreq,
    betainc,
    current_backend,
) = get_backend_from_env()

# Resolve the float / complex scalar types for the configured precision.
float_dtype, complex_dtype = get_dtype_from_env()
pywavelet/logger.py CHANGED
@@ -5,14 +5,14 @@ import warnings
5
5
  from rich.logging import RichHandler
6
6
 
7
7
# Records are formatted as the bare message; datefmt="[%X]" below is intended
# for the timestamp (presumably consumed by RichHandler — confirm).
FORMAT = "%(message)s"

logger = logging.getLogger("pywavelet")
# Logger.hasHandlers() also checks ancestor loggers while propagation is on.
# A RichHandler is therefore only attached when no handler is reachable yet,
# e.g. not again after a module reload, and not when the root logger was
# already configured by the application.
if not logger.hasHandlers():
    handler = RichHandler(rich_tracebacks=True)
    formatter = logging.Formatter(fmt=FORMAT, datefmt="[%X]")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

# NOTE(review): these filters are process-wide. Once pywavelet is imported
# they silence every RuntimeWarning and UserWarning in the interpreter, not
# just pywavelet's own. Consider scoping them (e.g. ``module=...``) —
# confirm intent.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@@ -1,7 +1,5 @@
1
- from ..logger import logger
2
-
3
1
  from ..backend import current_backend
4
-
2
+ from ..logger import logger
5
3
 
6
4
# Record (at DEBUG level) which backend the CuPy transforms were imported under.
logger.debug("Using %s backend", current_backend)
7
5
 
@@ -1,14 +1,18 @@
1
+ import logging
2
+
1
3
  import cupy as cp
2
4
  from cupyx.scipy.fft import ifft
3
5
 
4
-
5
- import logging
6
-
7
- logger = logging.getLogger('pywavelet')
6
+ logger = logging.getLogger("pywavelet")
8
7
 
9
8
 
10
9
def transform_wavelet_freq_helper(
    data: cp.ndarray,
    Nf: int,
    Nt: int,
    phif: cp.ndarray,
    float_dtype=cp.float64,
    complex_dtype=cp.complex128,
) -> cp.ndarray:
    """
    Transform frequency-domain data into the wavelet domain (vectorised CuPy).

    Each of the Nf + 1 bands (DC ... Nyquist) is windowed with ``phif``. The
    result is inverse-FFT'd along time and the real/imaginary parts are
    unpacked into the wavelet grid.

    Parameters:
    - data (cp.ndarray): frequency-domain samples, indexed up to
      ``Nf * (Nt // 2)``. Presumably the one-sided spectrum of length
      ``Nf * Nt // 2 + 1`` — TODO confirm against callers.
    - Nf (int): number of frequency bins.
    - Nt (int): number of time bins (assumed even — the DC/Nyquist unpacking
      pairs even and odd rows).
    - phif (cp.ndarray): frequency-domain window, indexed at ``0 ... Nt//2 - 1``.
    - float_dtype: dtype of the returned real array.
    - complex_dtype: dtype of the intermediate complex buffer.

    Returns:
    - wave (cp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
    """
    logger.debug(
        "[CUPY TRANSFORM] Input types [data:%s, phif:%s, precision:%s]",
        type(data),
        type(phif),
        data.dtype,
    )

    half = Nt // 2
    f_bins = cp.arange(Nf + 1)  # [0, 1, ..., Nf]

    # 1) Build the (Nf+1, Nt) matrix DX of windowed spectra.
    DX = cp.zeros((Nf + 1, Nt), complex_dtype)

    # 1a) Off-centre taps j = ±1 ... ±(half-1). The centre tap (j = 0) is
    # deliberately excluded so it cannot overwrite the halved DC/Nyquist
    # centre values written in step 1b.
    offs = cp.concatenate((cp.arange(1 - half, 0), cp.arange(1, half)))
    jj = f_bins[:, None] * half + offs[None, :]  # (Nf+1, Nt-2)
    # Taps below DC (f = 0, j < 0) and above Nyquist (f = Nf, j > 0) are
    # zeroed. Their indices are clipped so the gather never relies on
    # CuPy's out-of-bounds wrap-around.
    outside = ((f_bins[:, None] == Nf) & (offs[None, :] > 0)) | (
        (f_bins[:, None] == 0) & (offs[None, :] < 0)
    )
    jj = cp.clip(jj, 0, data.shape[0] - 1)
    DX[:, half + offs] = cp.where(outside, 0.0, phif[cp.abs(offs)] * data[jj])

    # 1b) Centre tap (j = 0); the DC and Nyquist rows are halved.
    centre = phif[0] * data[f_bins * half]
    DX[:, half] = cp.where((f_bins == 0) | (f_bins == Nf), centre * 0.5, centre)

    # 2) Inverse FFT along the time axis.
    DXt = ifft(DX, axis=1)

    # 3) Unpack into wave with shape (Nt, Nf).
    wave = cp.zeros((Nt, Nf), float_dtype)
    sqrt2 = 2.0**0.5

    # 3a) DC band -> column 0, even rows.
    wave[0::2, 0] = cp.real(DXt[0, 0::2]) * sqrt2
    # 3b) Nyquist band -> column 0, odd rows.
    wave[1::2, 0] = cp.real(DXt[Nf, 0::2]) * sqrt2

    # 3c) Interior bands 1 ... Nf-1. When n + m is even take the real part;
    # otherwise take the imaginary part, negated for odd m.
    mids = cp.arange(1, Nf)
    real_m = cp.real(DXt[1:Nf, :]).T  # (Nt, Nf-1)
    imag_m = cp.imag(DXt[1:Nf, :]).T  # (Nt, Nf-1)
    odd_m = (mids % 2) == 1  # (Nf-1,)
    odd_nm = ((cp.arange(Nt)[:, None] + mids[None, :]) % 2) == 1  # (Nt, Nf-1)
    wave[:, 1:Nf] = cp.where(odd_nm, cp.where(odd_m, -imag_m, imag_m), real_m)

    return wave.T
@@ -2,6 +2,7 @@ from typing import Union
2
2
 
3
3
  import cupy as cp
4
4
 
5
+ from .... import backend
5
6
  from ....logger import logger
6
7
  from ....types import FrequencySeries, TimeSeries, Wavelet
7
8
  from ....types.wavelet_bins import _get_bins, _preprocess_bins
@@ -16,7 +17,6 @@ def from_time_to_wavelet(
16
17
  Nt: Union[int, None] = None,
17
18
  nx: float = 4.0,
18
19
  mult: int = 32,
19
- **kwargs,
20
20
  ) -> Wavelet:
21
21
  """Transforms time-domain data to wavelet-domain data.
22
22
 
@@ -45,7 +45,6 @@ def from_time_to_wavelet(
45
45
 
46
46
  """
47
47
  Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
48
- dt = timeseries.dt
49
48
  t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
50
49
 
51
50
  ND = Nf * Nt
@@ -73,7 +72,6 @@ def from_freq_to_wavelet(
73
72
  Nf: Union[int, None] = None,
74
73
  Nt: Union[int, None] = None,
75
74
  nx: float = 4.0,
76
- **kwargs,
77
75
  ) -> Wavelet:
78
76
  """Transforms frequency-domain data to wavelet-domain data.
79
77
 
@@ -98,9 +96,15 @@ def from_freq_to_wavelet(
98
96
  """
99
97
  Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
100
98
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
101
- phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx))
99
+ phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx), dtype=backend.float_dtype)
100
+ data = cp.array(freqseries.data, dtype=backend.complex_dtype)
102
101
  wave = transform_wavelet_freq_helper(
103
- cp.array(freqseries.data), Nf=Nf, Nt=Nt, phif=phif
102
+ data,
103
+ Nf=Nf,
104
+ Nt=Nt,
105
+ phif=phif,
106
+ float_dtype=backend.float_dtype,
107
+ complex_dtype=backend.complex_dtype,
104
108
  )
105
-
106
- return Wavelet((2 / Nf) * wave * cp.sqrt(2), time=t_bins, freq=f_bins)
109
+ factor = (2 / Nf) * cp.sqrt(2)
110
+ return Wavelet(factor * wave, time=t_bins, freq=f_bins)
@@ -6,57 +6,61 @@ def inverse_wavelet_freq_helper(
6
6
  wave_in: cp.ndarray, phif: cp.ndarray, Nf: int, Nt: int
7
7
  ) -> cp.ndarray:
8
8
  """CuPy vectorized function for inverse_wavelet_freq"""
9
- wave_in = wave_in.T
10
- ND = Nf * Nt
9
+ wave = wave_in.T
10
+ ND2 = (Nf * Nt) // 2
11
+ half = Nt // 2
11
12
 
13
+ # === STEP 1: build prefactor2s[m, n] ===
12
14
  m_range = cp.arange(Nf + 1)
13
- prefactor2s = cp.zeros((Nf + 1, Nt), dtype=cp.complex128)
14
-
15
- n_range = cp.arange(Nt)
16
-
17
- # m == 0 case
18
- prefactor2s[0] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
19
-
20
- # m == Nf case
21
- prefactor2s[Nf] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
22
-
23
- # Other m cases
24
- m_mid = m_range[1:Nf]
25
- n_grid, m_grid = cp.meshgrid(n_range, m_mid)
26
- val = wave_in[n_grid, m_grid]
27
- mult2 = cp.where((n_grid + m_grid) % 2, -1j, 1)
28
- prefactor2s[1:Nf] = mult2 * val
29
-
30
- # Vectorized FFT
31
- fft_prefactor2s = fft(prefactor2s, axis=1)
32
-
33
- # Vectorized __unpack_wave_inverse
34
- res = cp.zeros(ND, dtype=cp.complex128)
35
-
36
- # m == 0 or m == Nf cases
37
- i_ind_range = cp.arange(Nt // 2)
38
- i_0 = cp.abs(i_ind_range)
39
- i_Nf = cp.abs(Nf * Nt // 2 - i_ind_range)
40
- ind3_0 = (2 * i_0) % Nt
41
- ind3_Nf = (2 * i_Nf) % Nt
42
-
43
- res[i_0] += fft_prefactor2s[0, ind3_0] * phif[i_ind_range]
44
- res[i_Nf] += fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range]
45
-
46
- # Special case for m == Nf
47
- res[Nf * Nt // 2] += fft_prefactor2s[Nf, 0] * phif[Nt // 2]
48
-
49
- # Other m cases
50
- m_mid = m_range[1:Nf]
51
- i_ind_range = cp.arange(Nt // 2 + 1)
52
- m_grid, i_ind_grid = cp.meshgrid(m_mid, i_ind_range)
53
-
54
- i1 = Nt // 2 * m_grid - i_ind_grid
55
- i2 = Nt // 2 * m_grid + i_ind_grid
56
- ind31 = (Nt // 2 * m_grid - i_ind_grid) % Nt
57
- ind32 = (Nt // 2 * m_grid + i_ind_grid) % Nt
58
-
59
- res[i1] += fft_prefactor2s[m_grid, ind31] * phif[i_ind_grid]
60
- res[i2] += fft_prefactor2s[m_grid, ind32] * phif[i_ind_grid]
15
+ pref2 = cp.zeros((Nf + 1, Nt), dtype=cp.complex128)
16
+ n = cp.arange(Nt)
17
+
18
+ # m=0 and m=Nf, with 1/√2
19
+ pref2[0, :] = wave[(2 * n) % Nt, 0] * (2 ** (-0.5))
20
+ pref2[Nf, :] = wave[((2 * n) % Nt) + 1, 0] * (2 ** (-0.5))
21
+
22
+ # middle m=1...Nf-1
23
+ m_mid = cp.arange(1, Nf)
24
+ # build meshgrids (m_mid rows, n cols)
25
+ mm, nn = cp.meshgrid(m_mid, n, indexing="ij")
26
+ vals = wave[nn, mm]
27
+ signs = cp.where(((nn + mm) % 2) == 1, -1j, 1)
28
+ pref2[1:Nf, :] = signs * vals
29
+
30
+ # === STEP 2: FFT along time axis ===
31
+ F = fft(pref2, axis=1) # shape (Nf+1, Nt)
32
+
33
+ # === STEP 3: unpack back into half-spectrum res[0...ND2] ===
34
+ res = cp.zeros(ND2 + 1, dtype=cp.complex128)
35
+ idx = cp.arange(half)
36
+
37
+ # 3a) contribution from m=0
38
+ res[idx] += F[0, (2 * idx) % Nt] * phif[idx]
39
+
40
+ # 3b) contribution from m=Nf
41
+ iNf = cp.abs(Nf * half - idx)
42
+ res[iNf] += F[Nf, (2 * idx) % Nt] * phif[idx]
43
+
44
+ # special Nyquist‐folding term
45
+ special = cp.abs(Nf * half - half)
46
+ res[special] += F[Nf, 0] * phif[half]
47
+
48
+ # 3c) middle m cases
49
+ m_mid = cp.arange(1, Nf)
50
+ i_mid = cp.arange(half) # 0...half-1
51
+ mm, ii = cp.meshgrid(m_mid, i_mid, indexing="ij")
52
+ i1 = (half * mm - ii) % (ND2 + 1)
53
+ i2 = (half * mm + ii) % (ND2 + 1)
54
+ ind1 = (half * mm - ii) % Nt
55
+ ind2 = (half * mm + ii) % Nt
56
+
57
+ # accumulate
58
+ res[i1] += F[mm, ind1] * phif[ii]
59
+ res[i2] += F[mm, ind2] * phif[ii]
60
+
61
+ # override the "center" points (j=0) exactly
62
+ centers = half * m_mid
63
+ fft_idx = centers % Nt
64
+ res[centers] = F[m_mid, fft_idx] * phif[0]
61
65
 
62
66
  return res