pywavelet 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/__init__.py CHANGED
@@ -2,4 +2,26 @@
2
2
  WDM Wavelet transform
3
3
  """
4
4
 
5
import importlib
import os

from . import backend as _backend

try:
    # setuptools-scm writes the real release number into _version.py at build
    # time; the previous hard-coded "0.0.2" disagreed with the released package.
    from ._version import __version__
except ImportError:  # e.g. a source checkout where _version.py was not generated
    __version__ = "0.0.0"

# Backend names understood by pywavelet.backend.
_SUPPORTED_BACKENDS = ("numpy", "jax", "cupy")


def set_backend(backend: str):
    """Set the backend for the wavelet transform.

    Stores the choice in the ``PYWAVELET_BACKEND`` environment variable and
    reloads ``pywavelet.backend``, ``pywavelet.types`` and
    ``pywavelet.transforms`` so they are re-executed with the new selection.

    Parameters
    ----------
    backend : str
        Backend to use (case-insensitive). Options are "numpy", "jax", "cupy".

    Raises
    ------
    ValueError
        If ``backend`` is not one of the supported options.
    """
    name = str(backend).lower()
    if name not in _SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {backend!r}; expected one of {_SUPPORTED_BACKENDS}."
        )

    from . import types
    from . import transforms

    os.environ["PYWAVELET_BACKEND"] = name

    # Reload the backend module first so the modules re-executed afterwards
    # (types, transforms) see the newly selected backend.
    importlib.reload(_backend)
    importlib.reload(types)
    importlib.reload(transforms)
pywavelet/_version.py CHANGED
@@ -1,8 +1,13 @@
# file generated by setuptools-scm
# don't change, don't track in version control

# Public names exposed by this module.
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# TYPE_CHECKING is False at runtime, so the typing imports below never execute
# when the package is imported; the branch only matters for static analysis.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    # Runtime placeholder used only in the annotations below.
    VERSION_TUPLE = object
@@ -12,5 +17,5 @@ __version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Release version of this build, exposed under both the dunder and plain names.
__version__ = version = '0.2.6'
__version_tuple__ = version_tuple = (0, 2, 6)
pywavelet/backend.py CHANGED
@@ -1,31 +1,53 @@
1
+ import importlib
1
2
  import os
2
3
 
3
4
  from .logger import logger
4
5
 
5
JAX = "jax"
CUPY = "cupy"
NUMPY = "numpy"


def get_backend_from_env():
    """Select and return the array backend requested via ``PYWAVELET_BACKEND``.

    The environment variable is read case-insensitively and defaults to
    ``"numpy"``. If JAX or CuPy is requested but not installed, or the name is
    not recognised, a warning is logged and the NumPy backend is used.

    Returns
    -------
    tuple
        ``(xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend)``.
        ``backend`` is the name of the backend that was actually loaded
        (``"numpy"`` after any fallback). Code that dispatches on it therefore
        stays consistent with ``xp``.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly.
    import importlib.util

    backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()

    if backend == JAX:
        if importlib.util.find_spec(JAX):
            import jax.numpy as xp
            from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
            from jax.scipy.special import betainc

            logger.info("Using JAX backend")
            return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, JAX
        logger.warning(
            "JAX backend requested but not installed. Falling back to NumPy."
        )
    elif backend == CUPY:
        if importlib.util.find_spec(CUPY):
            import cupy as xp
            from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
            from cupyx.scipy.special import betainc

            logger.info("Using CuPy backend")
            return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, CUPY
        logger.warning(
            "CuPy backend requested but not installed. Falling back to NumPy."
        )
    elif backend != NUMPY:
        logger.warning(
            "Unknown backend %r requested. Falling back to NumPy.", backend
        )

    # Default / fallback: NumPy + SciPy. Report NUMPY (not the requested
    # name) so that e.g. transforms/__init__ does not import the JAX or CuPy
    # implementation while xp is NumPy.
    import numpy as xp
    from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
    from scipy.special import betainc

    logger.info("Using NumPy+Numba backend")
    return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, NUMPY


# Resolve the backend once at import time; pywavelet.set_backend reloads this
# module to switch backends.
xp, fft, ifft, irfft, rfft, rfftfreq, betainc, current_backend = (
    get_backend_from_env()
)
@@ -1,12 +1,19 @@
1
from ..backend import CUPY, JAX, current_backend

# Pick the transform implementation matching the backend resolved in
# pywavelet.backend; anything else uses the NumPy implementation.
if current_backend == JAX:
    from .jax import (
        from_freq_to_wavelet,
        from_time_to_wavelet,
        from_wavelet_to_freq,
        from_wavelet_to_time,
    )
elif current_backend == CUPY:
    from .cupy import (
        from_freq_to_wavelet,
        from_time_to_wavelet,
        from_wavelet_to_freq,
        from_wavelet_to_time,
    )
else:
    from .numpy import (
        from_wavelet_to_time,
        from_wavelet_to_freq,
        from_time_to_wavelet,
        from_freq_to_wavelet,
    )

from .phi_computer import omega, phi_vec, phitilde_vec_norm

# ``omega`` is imported at package level alongside the other phi helpers, so
# it is exported with them.
__all__ = [
    "from_wavelet_to_time",
    "from_wavelet_to_freq",
    "from_time_to_wavelet",
    "from_freq_to_wavelet",
    "phitilde_vec_norm",
    "phi_vec",
    "omega",
]
@@ -0,0 +1,12 @@
1
+ from ...logger import logger
2
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
3
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
+
5
+ logger.warning("CUPY SUBPACKAGE NOT FULLY TESTED")
6
+
7
+ __all__ = [
8
+ "from_wavelet_to_time",
9
+ "from_wavelet_to_freq",
10
+ "from_time_to_wavelet",
11
+ "from_freq_to_wavelet",
12
+ ]
@@ -0,0 +1,3 @@
1
+ from .main import from_freq_to_wavelet, from_time_to_wavelet
2
+
3
# Forward transforms re-exported from .main.
__all__ = [
    "from_time_to_wavelet",
    "from_freq_to_wavelet",
]
@@ -0,0 +1,92 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import ifft
3
+
4
+
5
import logging

# Use the package-wide "pywavelet" logger.
logger = logging.getLogger(name="pywavelet")
8
+
9
+
10
def transform_wavelet_freq_helper(
    data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray
) -> cp.ndarray:
    """
    Transforms input data from the frequency domain to the wavelet domain using a
    pre-computed wavelet filter (`phif`) and performs an efficient inverse FFT.

    Parameters:
    - data (cp.ndarray): 1D array representing the input data in the frequency domain.
    - Nf (int): Number of frequency bins.
    - Nt (int): Number of time bins.
    - phif (cp.ndarray): Pre-computed wavelet filter for frequency components
      (entries 0 .. Nt//2 - 1 are read).

    Returns:
    - wave (cp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
      Its precision follows the inputs (float64 for double-precision data).
    """
    logger.debug(
        "[CUPY TRANSFORM] Input types [data:%s, phif:%s]", type(data), type(phif)
    )

    # Work in at least complex64, but keep double precision when the inputs
    # are double precision (a hard-coded complex64 silently truncated them).
    work_dtype = cp.result_type(data.dtype, phif.dtype, cp.complex64)

    half_t = Nt // 2
    f_bins = cp.arange(Nf)  # frequency bin indices 0 .. Nf-1
    jj_base = f_bins * Nt // 2  # centre sample of each bin in `data`

    # NOTE(review): f_bins never equals Nf, so the "Nyquist" tests below are
    # always False; they are kept exactly as in the original code.

    # Zero-offset (centre) value of each row, halved for the edge bins.
    centre = phif[0] * data[jj_base]
    initial_values = cp.where((f_bins == 0) | (f_bins == Nf), centre / 2.0, centre)

    # One row per frequency bin; columns are the two-sided FFT input.
    DX = cp.zeros((Nf, Nt), dtype=work_dtype)
    DX[:, half_t] = initial_values

    # Offsets around the centre (centred on zero) and the DX columns they fill.
    j_range = cp.arange(1 - half_t, half_t)
    j = cp.abs(j_range)
    cols = half_t + j_range

    nyquist_upper = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)
    dc_lower = (f_bins[:, None] == 0) & (j_range[None, :] < 0)
    is_centre = j[None, :] == 0  # already set above; leave untouched

    jj = jj_base[:, None] + j_range[None, :]
    val = cp.where(nyquist_upper | dc_lower, 0.0, phif[j] * data[jj])
    DX[:, cols] = cp.where(is_centre, DX[:, cols], val)

    # Inverse FFT along the time dimension.
    DX_trans = ifft(DX, axis=1)

    # Select real or imaginary parts according to time/frequency parity.
    n_range = cp.arange(Nt)
    odd_sum = (n_range[:, None] + f_bins[None, :]) % 2 == 1
    odd_freq = cp.expand_dims(f_bins % 2 == 1, axis=-1)

    real_part = cp.where(odd_freq, -cp.imag(DX_trans), cp.real(DX_trans))
    imag_part = cp.where(odd_freq, cp.real(DX_trans), cp.imag(DX_trans))
    wave = cp.where(odd_sum, imag_part.T, real_part.T)  # (Nt, Nf)

    # Special packing of the first and last rows of DX_trans.
    wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2))
    wave[1::2, -1] = cp.real(DX_trans[-1, ::2] * cp.sqrt(2))

    return wave.T  # (Nt, Nf) -> (Nf, Nt)
@@ -0,0 +1,50 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import rfft
3
+
4
+
5
def transform_wavelet_time_helper(
    data: cp.ndarray, phi: cp.ndarray, Nf: int, Nt: int, mult: int
) -> cp.ndarray:
    """Helper function to do the wavelet transform in the time domain using CuPy.

    Each time bin ``n`` takes ``K = 2 * mult * Nf`` samples starting at
    ``n * Nf - K // 2`` (wrapping periodically over the first ``Nf * Nt``
    samples), multiplies them by ``phi`` and packs the real FFT of the result
    into the wavelet grid.

    Parameters
    ----------
    data : cp.ndarray
        1D time-domain samples; the first ``Nf * Nt`` are used.
    phi : cp.ndarray
        Time-domain window of length ``K``.
    Nf, Nt : int
        Number of frequency / time bins.
    mult : int
        Sets the window length ``K = 2 * mult * Nf``.

    Returns
    -------
    cp.ndarray
        Wavelet coefficients with shape (Nf, Nt).

    Raises
    ------
    ValueError
        If ``data`` has fewer than ``Nf * Nt`` samples.
    """
    ND = Nf * Nt
    K = mult * 2 * Nf

    if len(data) < ND:
        raise ValueError(
            f"data has {len(data)} samples; expected at least Nf*Nt={ND}"
        )

    # Sample indices of every window. The modulo already wraps around the end
    # of the series, so no padded copy of `data` is needed.
    time_bins = cp.arange(Nt)
    jj_base = (time_bins[:, None] * Nf - K // 2) % ND
    jj = (jj_base + cp.arange(K)[None, :]) % ND

    # Window the data and take the real FFT of each row.
    wdata = data[jj] * phi[None, :]
    wdata_trans = rfft(wdata, axis=1)

    wave = cp.zeros((Nt, Nf))

    # m = 0 column: even time bins hold the DC term, the next (odd) bin holds
    # the Nf*mult term.
    even_mask = (time_bins % 2 == 0) & (time_bins < Nt - 1)
    even_indices = cp.nonzero(even_mask)[0]
    wave[even_indices, 0] = cp.real(wdata_trans[even_indices, 0]) / cp.sqrt(2)
    wave[even_indices + 1, 0] = cp.real(
        wdata_trans[even_indices, Nf * mult]
    ) / cp.sqrt(2)

    # Remaining columns: real or -imag part depending on (time + freq) parity.
    j_range = cp.arange(1, Nf)
    odd_condition = (time_bins[:, None] + j_range[None, :]) % 2 == 1
    wave[:, 1:] = cp.where(
        odd_condition,
        -cp.imag(wdata_trans[:, j_range * mult]),
        cp.real(wdata_trans[:, j_range * mult]),
    )

    return wave.T
@@ -0,0 +1,106 @@
1
+ from typing import Union
2
+
3
+ import cupy as cp
4
+
5
+ from ....logger import logger
6
+ from ....types import FrequencySeries, TimeSeries, Wavelet
7
+ from ....types.wavelet_bins import _get_bins, _preprocess_bins
8
+ from ...phi_computer import phi_vec, phitilde_vec_norm
9
+ from .from_freq import transform_wavelet_freq_helper
10
+ from .from_time import transform_wavelet_time_helper
11
+
12
+
13
def from_time_to_wavelet(
    timeseries: TimeSeries,
    Nf: Union[int, None] = None,
    Nt: Union[int, None] = None,
    nx: float = 4.0,
    mult: int = 32,
    **kwargs,
) -> Wavelet:
    """Transforms time-domain data to wavelet-domain data.

    Warning: there can be significant leakage if mult is too small and the
    transform is only approximately exact if mult=Nt/2

    Parameters
    ----------
    timeseries : TimeSeries
        Time domain data
    Nf : int
        Number of frequency bins
    Nt : int
        Number of time bins
    nx : float, optional
        Number of standard deviations for the phi_vec, by default 4.
    mult : int, optional
        Number of time bins to use for the wavelet transform, by default 32.
        Values above Nt // 2 are clamped to Nt // 2 (with a warning).
    **kwargs:
        Accepted for API compatibility with the other backends; currently ignored.

    Returns
    -------
    Wavelet
        Wavelet domain data

    """
    Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
    t_bins, f_bins = _get_bins(timeseries, Nf, Nt)

    ND = Nf * Nt

    if len(timeseries) != ND:
        logger.warning(
            f"len(timeseries)={len(timeseries)} != Nf*Nt={ND}. Truncating to timeseries[:{ND}]"
        )
        timeseries = timeseries[:ND]
    if mult > Nt / 2:
        logger.warning(
            f"mult={mult} is too large for Nt={Nt}. This may lead to bogus results."
        )

    mult = min(mult, Nt // 2)  # make sure K isn't bigger than ND
    phi = cp.array(phi_vec(Nf, d=nx, q=mult))
    wave = transform_wavelet_time_helper(
        cp.array(timeseries.data), Nf=Nf, Nt=Nt, phi=phi, mult=mult
    )
    return Wavelet(wave * cp.sqrt(2), time=t_bins, freq=f_bins)
69
+
70
+
71
def from_freq_to_wavelet(
    freqseries: FrequencySeries,
    Nf: Union[int, None] = None,
    Nt: Union[int, None] = None,
    nx: float = 4.0,
    **kwargs,
) -> Wavelet:
    """Transform frequency-domain data to wavelet-domain data on the GPU.

    Parameters
    ----------
    freqseries : FrequencySeries
        Frequency domain data
    Nf : int
        Number of frequency bins (resolved by ``_preprocess_bins`` if None)
    Nt : int
        Number of time bins (resolved by ``_preprocess_bins`` if None)
    nx : float, optional
        Number of standard deviations for phitilde_vec_norm, by default 4.
    **kwargs:
        Accepted for API compatibility with the other backends; currently ignored.

    Returns
    -------
    Wavelet
        Wavelet domain data, scaled by ``2 * sqrt(2) / Nf``.
    """
    n_freq, n_time = _preprocess_bins(freqseries, Nf, Nt)
    time_grid, freq_grid = _get_bins(freqseries, n_freq, n_time)

    window = cp.array(phitilde_vec_norm(n_freq, n_time, d=nx))
    coeffs = transform_wavelet_freq_helper(
        cp.array(freqseries.data), Nf=n_freq, Nt=n_time, phif=window
    )

    # Normalisation, evaluated left to right as (2 / Nf) * coeffs * sqrt(2).
    scaled = (2 / n_freq) * coeffs * cp.sqrt(2)
    return Wavelet(scaled, time=time_grid, freq=freq_grid)
@@ -0,0 +1,3 @@
1
+ from .main import from_wavelet_to_freq, from_wavelet_to_time
2
+
3
# Inverse transforms re-exported from .main.
__all__ = [
    "from_wavelet_to_time",
    "from_wavelet_to_freq",
]
@@ -0,0 +1,67 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import rfftfreq
3
+
4
+ from ....types import FrequencySeries, TimeSeries, Wavelet
5
+ from ...phi_computer import phi_vec, phitilde_vec_norm
6
+ from .to_freq import inverse_wavelet_freq_helper
7
+
8
+
9
def from_wavelet_to_time(
    wave_in: Wavelet,
    dt: float,
    nx: float = 4.0,
    mult: int = 32,
) -> TimeSeries:
    """Inverse wavelet transform to time domain.

    The inversion goes through the frequency domain: ``from_wavelet_to_freq``
    is applied first and its result is converted with ``to_timeseries()``.

    Parameters
    ----------
    wave_in : Wavelet
        input wavelet
    dt : float
        time step
    nx : float, optional
        parameter for phitilde_vec_norm (forwarded to from_wavelet_to_freq),
        by default 4.0
    mult : int, optional
        Not used by this implementation; kept for signature compatibility.
        By default 32.

    Returns
    -------
    TimeSeries
        Time domain signal
    """
    return from_wavelet_to_freq(wave_in, dt=dt, nx=nx).to_timeseries()
35
+
36
+
37
def from_wavelet_to_freq(
    wave_in: Wavelet, dt: float, nx=4.0
) -> FrequencySeries:
    """Inverse wavelet transform to frequency domain.

    Parameters
    ----------
    wave_in : Wavelet
        input wavelet
    dt : float
        time step
    nx : float, optional
        parameter for phitilde_vec_norm, by default 4.0

    Returns
    -------
    FrequencySeries
        One-sided frequency domain signal with ``ND // 2 + 1`` samples on the
        ``rfftfreq(ND, d=dt)`` grid.

    """
    Nf, Nt, ND = wave_in.Nf, wave_in.Nt, wave_in.ND
    phif = cp.array(phitilde_vec_norm(Nf, Nt, d=nx))

    # The helper indexes with CuPy arrays, so make sure the coefficients live
    # on the GPU (no copy if they already do).
    freq_data = inverse_wavelet_freq_helper(
        cp.asarray(wave_in.data), phif=phif, Nf=Nf, Nt=Nt
    )

    # The helper only writes bins 0 .. Nf*Nt//2, i.e. the one-sided spectrum
    # of an ND-sample series. Keep exactly those, on the matching grid (the
    # former rfftfreq(2*ND)[1:] grid put bin k at (k+1)/(2*ND*dt)).
    freq_data = freq_data[: ND // 2 + 1] * 2 ** (
        -1 / 2
    )  # Normalise to get the proper backwards transformation

    freqs = rfftfreq(ND, d=dt)
    return FrequencySeries(data=freq_data, freq=freqs)
@@ -0,0 +1,62 @@
1
+ import cupy as cp
2
+ from cupyx.scipy.fft import fft
3
+
4
+
5
def inverse_wavelet_freq_helper(
    wave_in: cp.ndarray, phif: cp.ndarray, Nf: int, Nt: int
) -> cp.ndarray:
    """CuPy vectorized function for inverse_wavelet_freq.

    Parameters
    ----------
    wave_in : cp.ndarray
        Wavelet coefficients with shape (Nf, Nt); converted with ``cp.asarray``.
    phif : cp.ndarray
        Frequency-domain window; entries ``0 .. Nt // 2`` are read.
    Nf, Nt : int
        Number of frequency / time bins.

    Returns
    -------
    cp.ndarray
        complex128 array of length ``Nf * Nt``. Only indices
        ``0 .. Nf * Nt // 2`` are written; the rest stay zero.
    """
    wave_in = cp.asarray(wave_in).T  # -> (Nt, Nf)
    ND = Nf * Nt
    half_t = Nt // 2

    n_range = cp.arange(Nt)
    m_mid = cp.arange(1, Nf)  # bands strictly between DC and Nyquist

    # Build one length-Nt row per band m = 0 .. Nf.
    prefactor2s = cp.zeros((Nf + 1, Nt), dtype=cp.complex128)
    # m == 0 case
    prefactor2s[0] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
    # m == Nf case
    prefactor2s[Nf] = 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
    # Other m cases
    n_grid, m_grid = cp.meshgrid(n_range, m_mid)
    mult2 = cp.where((n_grid + m_grid) % 2, -1j, 1)
    prefactor2s[1:Nf] = mult2 * wave_in[n_grid, m_grid]

    fft_prefactor2s = fft(prefactor2s, axis=1)

    res = cp.zeros(ND, dtype=cp.complex128)

    # m == 0 and m == Nf: one-sided windows (offsets 0 .. Nt//2 - 1).
    i_ind_range = cp.arange(half_t)
    i_0 = cp.abs(i_ind_range)
    i_Nf = cp.abs(Nf * Nt // 2 - i_ind_range)
    res[i_0] += fft_prefactor2s[0, (2 * i_0) % Nt] * phif[i_ind_range]
    res[i_Nf] += fft_prefactor2s[Nf, (2 * i_Nf) % Nt] * phif[i_ind_range]

    # Special case for m == Nf
    res[Nf * Nt // 2] += fft_prefactor2s[Nf, 0] * phif[half_t]

    # 0 < m < Nf: two-sided window around each band centre c = (Nt//2) * m.
    # The centre (offset 0) is added exactly once. Building both sides from
    # offset 0 counted it twice, unlike the forward transform, which sets the
    # centre once and skips offset 0.
    centres = half_t * m_mid
    res[centres] += fft_prefactor2s[m_mid, centres % Nt] * phif[0]

    # Offsets 1 .. Nt//2 on each side. Within each statement below the target
    # indices are all distinct. The buffered fancy-index ``+=`` keeps only one
    # value per repeated index, and the previous 0 .. Nt//2 range produced
    # repeats at the band boundaries.
    m_grid, i_ind_grid = cp.meshgrid(m_mid, cp.arange(1, half_t + 1))
    i1 = half_t * m_grid - i_ind_grid
    i2 = half_t * m_grid + i_ind_grid
    res[i1] += fft_prefactor2s[m_grid, i1 % Nt] * phif[i_ind_grid]
    res[i2] += fft_prefactor2s[m_grid, i2 % Nt] * phif[i_ind_grid]

    return res
@@ -4,6 +4,10 @@ import jax.numpy as jnp
4
4
  from jax import jit
5
5
  from jax.numpy.fft import ifft
6
6
 
7
import logging

# Use the package-wide "pywavelet" logger.
logger = logging.getLogger(name="pywavelet")
10
+
7
11
 
8
12
  @partial(jit, static_argnames=("Nf", "Nt"))
9
13
  def transform_wavelet_freq_helper(
@@ -23,6 +27,8 @@ def transform_wavelet_freq_helper(
23
27
  - wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
24
28
  """
25
29
 
30
+ logger.debug(f"[JAX TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]")
31
+
26
32
  # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
27
33
  wave = jnp.zeros((Nt, Nf))
28
34
  f_bins = jnp.arange(Nf) # Frequency bin indices
@@ -10,7 +10,7 @@ def from_wavelet_to_time(
10
10
  wave_in: Wavelet,
11
11
  dt: float,
12
12
  nx: float = 4.0,
13
- mult: int = 32,
13
+ mult: int = None,
14
14
  ) -> TimeSeries:
15
15
  """Inverse wavelet transform to time domain.
16
16
 
@@ -55,14 +55,12 @@ def from_wavelet_to_freq(
55
55
  Frequency domain signal
56
56
 
57
57
  """
58
- phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, dt=dt, d=nx))
58
+ phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx))
59
59
  freq_data = inverse_wavelet_freq_helper(
60
60
  wave_in.data, phif=phif, Nf=wave_in.Nf, Nt=wave_in.Nt
61
61
  )
62
62
 
63
- freq_data *= 2 ** (
64
- -1 / 2
65
- ) # Normalise to get the proper backwards transformation
63
+ freq_data *= 1.0 / jnp.sqrt(2)
66
64
 
67
- freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
65
+ freqs = rfftfreq(wave_in.ND, d=dt)
68
66
  return FrequencySeries(data=freq_data, freq=freqs)