pywavelet 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. pywavelet/__init__.py +22 -0
  2. pywavelet/_version.py +9 -4
  3. pywavelet/backend.py +49 -27
  4. pywavelet/transforms/__init__.py +10 -4
  5. pywavelet/transforms/cupy/__init__.py +12 -0
  6. pywavelet/transforms/cupy/forward/__init__.py +3 -0
  7. pywavelet/transforms/cupy/forward/from_freq.py +92 -0
  8. pywavelet/transforms/cupy/forward/from_time.py +50 -0
  9. pywavelet/transforms/cupy/forward/main.py +106 -0
  10. pywavelet/transforms/cupy/inverse/__init__.py +3 -0
  11. pywavelet/transforms/cupy/inverse/main.py +67 -0
  12. pywavelet/transforms/cupy/inverse/to_freq.py +62 -0
  13. pywavelet/transforms/jax/forward/from_freq.py +6 -0
  14. pywavelet/transforms/jax/forward/from_time.py +18 -10
  15. pywavelet/transforms/jax/forward/main.py +6 -10
  16. pywavelet/transforms/jax/inverse/main.py +4 -6
  17. pywavelet/transforms/jax/inverse/to_freq.py +52 -34
  18. pywavelet/transforms/numpy/__init__.py +1 -2
  19. pywavelet/transforms/numpy/forward/from_freq.py +77 -19
  20. pywavelet/transforms/numpy/forward/main.py +1 -2
  21. pywavelet/transforms/numpy/inverse/main.py +4 -6
  22. pywavelet/transforms/numpy/inverse/to_freq.py +64 -1
  23. pywavelet/transforms/phi_computer.py +67 -86
  24. pywavelet/types/common.py +4 -3
  25. pywavelet/types/frequencyseries.py +1 -1
  26. pywavelet/types/plotting.py +14 -5
  27. pywavelet/types/timeseries.py +4 -10
  28. pywavelet/types/wavelet.py +6 -6
  29. pywavelet/types/wavelet_bins.py +0 -1
  30. pywavelet/utils.py +2 -0
  31. {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/METADATA +20 -9
  32. pywavelet-0.2.6.dist-info/RECORD +43 -0
  33. {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/WHEEL +1 -1
  34. pywavelet-0.2.4.dist-info/RECORD +0 -35
  35. {pywavelet-0.2.4.dist-info → pywavelet-0.2.6.dist-info}/top_level.txt +0 -0
@@ -10,7 +10,7 @@ def from_wavelet_to_time(
10
10
  wave_in: Wavelet,
11
11
  dt: float,
12
12
  nx: float = 4.0,
13
- mult: int = 32,
13
+ mult: int = None,
14
14
  ) -> TimeSeries:
15
15
  """Inverse wavelet transform to time domain.
16
16
 
@@ -55,14 +55,12 @@ def from_wavelet_to_freq(
55
55
  Frequency domain signal
56
56
 
57
57
  """
58
- phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, dt=dt, d=nx))
58
+ phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx))
59
59
  freq_data = inverse_wavelet_freq_helper(
60
60
  wave_in.data, phif=phif, Nf=wave_in.Nf, Nt=wave_in.Nt
61
61
  )
62
62
 
63
- freq_data *= 2 ** (
64
- -1 / 2
65
- ) # Normalise to get the proper backwards transformation
63
+ freq_data *= 1.0 / jnp.sqrt(2)
66
64
 
67
- freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
65
+ freqs = rfftfreq(wave_in.ND, d=dt)
68
66
  return FrequencySeries(data=freq_data, freq=freqs)
@@ -1,70 +1,88 @@
1
+ from functools import partial
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  from jax import jit
4
6
  from jax.numpy.fft import fft
5
7
 
6
- from functools import partial
7
8
 
8
-
9
- @partial(jit, static_argnames=('Nf', 'Nt'))
9
+ @partial(jit, static_argnames=("Nf", "Nt"))
10
10
  def inverse_wavelet_freq_helper(
11
- wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
11
+ wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
12
12
  ) -> jnp.ndarray:
13
- """JAX vectorized function for inverse_wavelet_freq"""
13
+ """JAX vectorized function for inverse_wavelet_freq with corrected shapes and ranges."""
14
+ # Transpose to match the NumPy version.
14
15
  wave_in = wave_in.T
15
16
  ND = Nf * Nt
16
17
 
18
+ # Allocate prefactor2s for each m (shape: (Nf+1, Nt)).
17
19
  m_range = jnp.arange(Nf + 1)
18
20
  prefactor2s = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex128)
19
-
20
21
  n_range = jnp.arange(Nt)
21
22
 
22
23
  # m == 0 case
23
- prefactor2s = prefactor2s.at[0].set(2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0])
24
+ prefactor2s = prefactor2s.at[0].set(
25
+ 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
26
+ )
24
27
 
25
28
  # m == Nf case
26
- prefactor2s = prefactor2s.at[Nf].set(2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0])
29
+ prefactor2s = prefactor2s.at[Nf].set(
30
+ 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
31
+ )
27
32
 
28
- # Other m cases
33
+ # Other m cases: use meshgrid for vectorization.
29
34
  m_mid = m_range[1:Nf]
35
+ # Create grids: n_grid (columns) and m_grid (rows)
30
36
  n_grid, m_grid = jnp.meshgrid(n_range, m_mid)
37
+ # Index the transposed wave_in using (n, m) as in the NumPy version.
31
38
  val = wave_in[n_grid, m_grid]
39
+ # Apply the alternating multiplier based on (n+m) parity.
32
40
  mult2 = jnp.where((n_grid + m_grid) % 2, -1j, 1)
33
41
  prefactor2s = prefactor2s.at[1:Nf].set(mult2 * val)
34
42
 
35
- # Vectorized FFT
43
+ # Apply FFT along axis 1 for all m.
36
44
  fft_prefactor2s = fft(prefactor2s, axis=1)
37
45
 
38
- # Vectorized __unpack_wave_inverse
39
- ## TODO: Check with Giorgio
40
- # ND or ND // 2 + 1?
41
- # https://github.com/pywavelet/pywavelet/blob/63151a47cde9edc14f1e7e0bf17f554e78ad257c/src/pywavelet/transforms/from_wavelets/inverse_wavelet_freq_funcs.py
42
- res = jnp.zeros(ND, dtype=jnp.complex128)
46
+ # Allocate the result array with corrected shape.
47
+ res = jnp.zeros(ND // 2 + 1, dtype=jnp.complex128)
43
48
 
44
- # m == 0 or m == Nf cases
49
+ # Unpacking for m == 0 and m == Nf cases:
45
50
  i_ind_range = jnp.arange(Nt // 2)
46
- i_0 = jnp.abs(i_ind_range)
47
- i_Nf = jnp.abs(Nf * Nt // 2 - i_ind_range)
51
+ i_0 = jnp.abs(i_ind_range) # for m == 0: i = i_ind_range
52
+ i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
48
53
  ind3_0 = (2 * i_0) % Nt
49
54
  ind3_Nf = (2 * i_Nf) % Nt
50
55
 
51
56
  res = res.at[i_0].add(fft_prefactor2s[0, ind3_0] * phif[i_ind_range])
52
57
  res = res.at[i_Nf].add(fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range])
58
+ # Special case for m == Nf (ensure the Nyquist frequency is updated correctly)
59
+ special_index = jnp.abs(Nf * (Nt // 2) - (Nt // 2))
60
+ res = res.at[special_index].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
53
61
 
54
- # Special case for m == Nf
55
- res = res.at[Nf * Nt // 2].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
56
-
57
- # Other m cases
62
+ # Unpacking for m in (1, ..., Nf-1)
58
63
  m_mid = m_range[1:Nf]
59
- i_ind_range = jnp.arange(Nt // 2 + 1)
60
- m_grid, i_ind_grid = jnp.meshgrid(m_mid, i_ind_range)
61
-
62
- i1 = Nt // 2 * m_grid - i_ind_grid
63
- i2 = Nt // 2 * m_grid + i_ind_grid
64
- ind31 = (Nt // 2 * m_grid - i_ind_grid) % Nt
65
- ind32 = (Nt // 2 * m_grid + i_ind_grid) % Nt
66
-
67
- res = res.at[i1].add(fft_prefactor2s[m_grid, ind31] * phif[i_ind_grid])
68
- res = res.at[i2].add(fft_prefactor2s[m_grid, ind32] * phif[i_ind_grid])
69
-
70
- return res
64
+ # Use range [0, Nt//2) to match the loop in NumPy version.
65
+ i_ind_range_mid = jnp.arange(Nt // 2)
66
+ # Create meshgrid for vectorized computation.
67
+ m_grid_mid, i_ind_grid_mid = jnp.meshgrid(
68
+ m_mid, i_ind_range_mid, indexing="ij"
69
+ )
70
+
71
+ # Compute indices i1 and i2 following the NumPy logic.
72
+ i1 = (Nt // 2) * m_grid_mid - i_ind_grid_mid
73
+ i2 = (Nt // 2) * m_grid_mid + i_ind_grid_mid
74
+ # Compute the wrapped indices for FFT results.
75
+ ind31 = i1 % Nt
76
+ ind32 = i2 % Nt
77
+
78
+ # Update result array using vectorized adds.
79
+ # Note: You might need to adjust this further if your target res shape is non-trivial,
80
+ # because here we assume that i1 and i2 indices fall within the allocated result shape.
81
+ res = res.at[i1].add(
82
+ fft_prefactor2s[m_grid_mid, ind31] * phif[i_ind_grid_mid]
83
+ )
84
+ res = res.at[i2].add(
85
+ fft_prefactor2s[m_grid_mid, ind32] * phif[i_ind_grid_mid]
86
+ )
87
+
88
+ return res
@@ -1,10 +1,9 @@
1
1
  from .forward import from_freq_to_wavelet, from_time_to_wavelet
2
2
  from .inverse import from_wavelet_to_freq, from_wavelet_to_time
3
3
 
4
-
5
4
  __all__ = [
6
5
  "from_wavelet_to_time",
7
6
  "from_wavelet_to_freq",
8
7
  "from_time_to_wavelet",
9
8
  "from_freq_to_wavelet",
10
- ]
9
+ ]
@@ -1,32 +1,69 @@
1
1
  """helper functions for transform_freq"""
2
2
 
3
+ import logging
4
+
3
5
  import numpy as np
4
6
  from numba import njit
5
7
 
8
+ # import rocket_fft.special as rfft # JIT‐able FFT routines
9
+
10
+
11
+ logger = logging.getLogger("pywavelet")
12
+
6
13
 
7
14
  def transform_wavelet_freq_helper(
8
15
  data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
9
16
  ) -> np.ndarray:
10
- """helper to do the wavelet transform using the fast wavelet domain transform"""
11
- wave = np.zeros((Nt, Nf)) # wavelet wavepacket transform of the signal
17
+ """
18
+ Forward wavelet transform helper using the fast wavelet domain transform,
19
+ with a JIT-able FFT (rocket-fft) so that the whole transform is jittable.
20
+
21
+ Parameters
22
+ ----------
23
+ data : np.ndarray
24
+ Input frequency-domain data (1D array).
25
+ Nf : int
26
+ Number of frequency bins.
27
+ Nt : int
28
+ Number of time bins.
29
+ phif : np.ndarray
30
+ Fourier-domain phase factors (complex-valued array of length Nt//2 + 1).
31
+
32
+ Returns
33
+ -------
34
+ wave : np.ndarray
35
+ The wavelet transform output of shape (Nt, Nf). Note that contributions from
36
+ f_bin==0 and f_bin==Nf are both stored in column 0.
37
+ """
38
+ logger.debug(
39
+ f"[NUMPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]"
40
+ )
41
+ wave = np.zeros((Nt, Nf), dtype=np.float64)
12
42
  DX = np.zeros(Nt, dtype=np.complex128)
13
- freq_strain = data.copy() # Convert
43
+ # Create a copy of the input data (if needed).
44
+ freq_strain = data.copy()
14
45
  __core(Nf, Nt, DX, freq_strain, phif, wave)
15
46
  return wave
16
47
 
17
48
 
18
- # @njit()
49
+ @njit()
19
50
  def __core(
20
51
  Nf: int,
21
52
  Nt: int,
22
53
  DX: np.ndarray,
23
- freq_strain: np.ndarray,
54
+ data: np.ndarray,
24
55
  phif: np.ndarray,
25
56
  wave: np.ndarray,
26
57
  ):
58
+ """
59
+ Process each frequency bin (f_bin) to compute the temporary array DX,
60
+ perform the inverse FFT using rocket-fft, and then unpack the result into wave.
61
+
62
+ This function is fully jittable.
63
+ """
27
64
  for f_bin in range(0, Nf + 1):
28
- __fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
29
- # Numba doesn't support np.ifft so we cant jit this
65
+ __fill_wave_1(f_bin, Nt, Nf, DX, data, phif)
66
+ # Use rocket-fft's ifft (which is JIT-able) instead of np.fft.ifft.
30
67
  DX_trans = np.fft.ifft(DX, Nt)
31
68
  __fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
32
69
 
@@ -40,24 +77,32 @@ def __fill_wave_1(
40
77
  data: np.ndarray,
41
78
  phif: np.ndarray,
42
79
  ) -> None:
43
- """helper for assigning DX in the main loop"""
80
+ """
81
+ Fill the temporary complex array DX for the given frequency bin (f_bin)
82
+ based on the input data and the phase factors phif.
83
+
84
+ The computation is performed over a window of indices defined by the current f_bin.
85
+ """
44
86
  i_base = Nt // 2
45
- jj_base = f_bin * Nt // 2
87
+ jj_base = f_bin * (Nt // 2)
46
88
 
89
+ # Special center assignment:
47
90
  if f_bin == 0 or f_bin == Nf:
48
- # NOTE this term appears to be needed to recover correct constant (at least for m=0), but was previously missing
49
- DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2] / 2.0
91
+ DX[i_base] = phif[0] * data[jj_base] / 2.0
50
92
  else:
51
- DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2]
93
+ DX[i_base] = phif[0] * data[jj_base]
52
94
 
53
- for jj in range(jj_base + 1 - Nt // 2, jj_base + Nt // 2):
95
+ # Determine the window of indices.
96
+ start = jj_base + 1 - (Nt // 2)
97
+ end = jj_base + (Nt // 2)
98
+ for jj in range(start, end):
54
99
  j = np.abs(jj - jj_base)
55
100
  i = i_base - jj_base + jj
56
- if f_bin == Nf and jj > jj_base:
57
- DX[i] = 0.0
58
- elif f_bin == 0 and jj < jj_base:
101
+ # For the highest frequency (f_bin==Nf) or the lowest (f_bin==0), zero out the out-of-range values.
102
+ if (f_bin == Nf and jj > jj_base) or (f_bin == 0 and jj < jj_base):
59
103
  DX[i] = 0.0
60
104
  elif j == 0:
105
+ # Center already assigned.
61
106
  continue
62
107
  else:
63
108
  DX[i] = phif[j] * data[jj]
@@ -67,21 +112,34 @@ def __fill_wave_1(
67
112
  def __fill_wave_2(
68
113
  f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
69
114
  ) -> None:
115
+ """
116
+ Unpack the inverse FFT output (DX_trans) into the output wave array.
117
+
118
+ For f_bin==0 and f_bin==Nf, the results are stored in column 0 of wave,
119
+ using even- or odd-indexed rows respectively. For intermediate f_bin values,
120
+ the values are stored in column f_bin with a sign and component (real or imag)
121
+ determined by parity.
122
+ """
123
+ sqrt2 = np.sqrt(2.0)
70
124
  if f_bin == 0:
71
- # half of lowest and highest frequency bin pixels are redundant, so store them in even and odd components of f_bin=0 respectively
125
+ # f_bin==0: assign even-indexed rows of column 0.
72
126
  for n in range(0, Nt, 2):
73
- wave[n, 0] = DX_trans[n].real * np.sqrt(2)
127
+ wave[n, 0] = DX_trans[n].real * sqrt2
74
128
  elif f_bin == Nf:
129
+ # f_bin==Nf: assign odd-indexed rows of column 0.
75
130
  for n in range(0, Nt, 2):
76
- wave[n + 1, 0] = DX_trans[n].real * np.sqrt(2)
131
+ wave[n + 1, 0] = DX_trans[n].real * sqrt2
77
132
  else:
133
+ # For intermediate f_bin, assign values to column f_bin.
78
134
  for n in range(0, Nt):
79
135
  if f_bin % 2:
136
+ # For odd f_bin: use -imag when (n+f_bin) is odd; otherwise use real.
80
137
  if (n + f_bin) % 2:
81
138
  wave[n, f_bin] = -DX_trans[n].imag
82
139
  else:
83
140
  wave[n, f_bin] = DX_trans[n].real
84
141
  else:
142
+ # For even f_bin: use imag when (n+f_bin) is odd; otherwise use real.
85
143
  if (n + f_bin) % 2:
86
144
  wave[n, f_bin] = DX_trans[n].imag
87
145
  else:
@@ -56,7 +56,6 @@ def from_time_to_wavelet(
56
56
  to inaccurate results.
57
57
  """
58
58
  Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
59
- dt = timeseries.dt
60
59
  t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
61
60
 
62
61
  ND = Nf * Nt
@@ -113,7 +112,7 @@ def from_freq_to_wavelet(
113
112
  """
114
113
  Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
115
114
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
116
- phif = phitilde_vec_norm(Nf, Nt, d=nx)
115
+ phif = phitilde_vec_norm(Nf, Nt, d=nx)
117
116
  wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
118
117
 
119
118
  return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
@@ -1,14 +1,12 @@
1
1
  import numpy as np
2
2
 
3
- from ...phi_computer import phi_vec, phitilde_vec_norm
4
3
  from ....types import FrequencySeries, TimeSeries, Wavelet
4
+ from ...phi_computer import phi_vec, phitilde_vec_norm
5
5
  from .to_freq import inverse_wavelet_freq_helper_fast
6
6
  from .to_time import inverse_wavelet_time_helper_fast
7
7
 
8
8
  __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
9
9
 
10
- INV_ROOT2 = 1.0 / np.sqrt(2)
11
-
12
10
 
13
11
  def from_wavelet_to_time(
14
12
  wave_in: Wavelet,
@@ -50,7 +48,7 @@ def from_wavelet_to_time(
50
48
  h_t = inverse_wavelet_time_helper_fast(
51
49
  wave_in.data.T, phi, wave_in.Nf, wave_in.Nt, mult
52
50
  )
53
- h_t *= INV_ROOT2 # Normalize to get proper backward transformation
51
+ h_t *= 1.0 / np.sqrt(2) # Normalize to get proper backward transformation
54
52
  ts = np.arange(0, wave_in.Nf * wave_in.Nt) * dt
55
53
  return TimeSeries(data=h_t, time=ts)
56
54
 
@@ -84,12 +82,12 @@ def from_wavelet_to_freq(
84
82
  to ensure the proper backwards transformation.
85
83
  """
86
84
 
87
- phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx)
85
+ phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx)
88
86
  freq_data = inverse_wavelet_freq_helper_fast(
89
87
  wave_in.data, phif, wave_in.Nf, wave_in.Nt
90
88
  )
91
89
 
92
- freq_data *= INV_ROOT2
90
+ freq_data *= 1.0 / np.sqrt(2)
93
91
 
94
92
  freqs = np.fft.rfftfreq(wave_in.ND, d=dt)
95
93
  return FrequencySeries(data=freq_data, freq=freqs)
@@ -46,7 +46,7 @@ def __pack_wave_inverse(
46
46
  prefactor2s[n] = 2 ** (-1 / 2) * wave_in[(2 * n) % Nt + 1, 0]
47
47
  else:
48
48
  for n in range(0, Nt):
49
- val = wave_in[n, m] # bug is here
49
+ val = wave_in[n, m]
50
50
  if (n + m) % 2:
51
51
  mult2 = -1j
52
52
  else:
@@ -93,3 +93,66 @@ def __unpack_wave_inverse(
93
93
  if ind32 == Nt:
94
94
  ind32 = 0
95
95
  res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0]
96
+
97
+
98
+ #
99
+ # # @njit
100
+ # def inverse_wavelet_freq_helper_fast_version2(
101
+ # wave_in: np.ndarray, phif: np.ndarray, Nf: int, Nt: int
102
+ # ) -> np.ndarray:
103
+ # wave_in = wave_in.T
104
+ # ND = Nf * Nt
105
+ # prefactor2s = np.zeros((Nf + 1, Nt), dtype=np.complex128)
106
+ # n_range = np.arange(Nt)
107
+ #
108
+ # # m == 0 case
109
+ # indices = (2 * n_range) % Nt
110
+ # prefactor2s[0] = (2 ** (-0.5)) * wave_in[indices, 0]
111
+ #
112
+ # # m == Nf case
113
+ # indices = ((2 * n_range) % Nt) + 1
114
+ # prefactor2s[Nf] = (2 ** (-0.5)) * wave_in[indices, 0]
115
+ #
116
+ # # For m = 1, ..., Nf-1
117
+ # m_mid = np.arange(1, Nf)
118
+ # m_grid, n_grid = np.meshgrid(m_mid, n_range, indexing='ij')
119
+ # val = wave_in[n_grid, m_grid]
120
+ # mult2 = np.where(((n_grid + m_grid) % 2) != 0, -1j, 1)
121
+ # prefactor2s[1:Nf] = mult2 * val
122
+ #
123
+ # fft_prefactor2s = np.fft.fft(prefactor2s, axis=1)
124
+ #
125
+ # res = np.zeros(ND // 2 + 1, dtype=np.complex128)
126
+ #
127
+ # # Unpacking for m == 0 and m == Nf
128
+ # for m in [0, Nf]:
129
+ # i_ind_range = np.arange(Nt // 2 + 1 if m == Nf else Nt // 2)
130
+ # i = np.abs(m * Nt // 2 - i_ind_range)
131
+ # ind3 = (2 * i) % Nt
132
+ # res[i] += fft_prefactor2s[m, ind3] * phif[i_ind_range]
133
+ #
134
+ # # Unpacking for m = 1,..., Nf-1
135
+ # for m in range(1, Nf):
136
+ # ind31 = (Nt // 2 * m) % Nt
137
+ # ind32 = ind31
138
+ # for i_ind in range(Nt // 2):
139
+ # i1 = Nt // 2 * m - i_ind
140
+ # i2 = Nt // 2 * m + i_ind
141
+ # res[i1] += fft_prefactor2s[m, ind31] * phif[i_ind]
142
+ # res[i2] += fft_prefactor2s[m, ind32] * phif[i_ind]
143
+ # ind31 = (ind31 - 1) % Nt
144
+ # ind32 = (ind32 + 1) % Nt
145
+ # res[Nt // 2 * m] += fft_prefactor2s[m, (Nt // 2 * m) % Nt] * phif[0]
146
+ #
147
+ # return res
148
+ #
149
+ # #
150
+ # #
151
+ # # if __name__ == '__main__':
152
+ # # phif = np.array(np.random.rand(64))
153
+ # # wave_in = np.array(np.random.rand(64, 64))
154
+ # # Nf = 64
155
+ # # Nt = 64
156
+ # # res = inverse_wavelet_freq_helper_fast(wave_in, phif, Nf, Nt)
157
+ # # res2 = inverse_wavelet_freq_helper_fast_version2(wave_in, phif, Nf, Nt)
158
+ # # assert np.allclose(res, res2), "Results do not match!"
@@ -1,10 +1,62 @@
1
- from ..backend import xp, PI, betainc, ifft
1
+ import numpy as np
2
+ from jaxtyping import Array, Float
2
3
 
4
+ from ..backend import betainc, ifft, xp
3
5
 
6
+ __all__ = ["phitilde_vec_norm", "phi_vec", "omega"]
4
7
 
5
- def phitilde_vec(
6
- omega: xp.ndarray, Nf: int, d: float = 4.0
7
- ) -> xp.ndarray:
8
+
9
+ def omega(Nf: int, Nt: int) -> Float[Array, " Nt//2+1"]:
10
+ """Get the angular frequencies of the time domain signal."""
11
+ df = 2 * np.pi / (Nf * Nt)
12
+ return df * np.arange(0, Nt // 2 + 1)
13
+
14
+
15
+ def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> Float[Array, " Nt//2+1"]:
16
+ """Normalize phitilde for inverse frequency domain transform."""
17
+ omegas = omega(Nf, Nt)
18
+ _phi_t = _phitilde_vec(omegas, Nf, d) * np.sqrt(np.pi)
19
+ return xp.array(_phi_t)
20
+
21
+
22
+ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
23
+ """get time domain phi as fourier transform of _phitilde_vec
24
+ q: number of Nf bins over which the window extends?
25
+
26
+ """
27
+ insDOM = 1.0 / np.sqrt(np.pi / Nf)
28
+ K = q * 2 * Nf
29
+ half_K = q * Nf # xp.int64(K/2)
30
+
31
+ dom = 2 * np.pi / K # max frequency is K/2*dom = pi/dt = OM
32
+
33
+ DX = np.zeros(K, dtype=np.complex128)
34
+
35
+ # zero frequency
36
+ DX[0] = insDOM
37
+
38
+ DX = DX.copy()
39
+ # postive frequencies
40
+ DX[1 : half_K + 1] = _phitilde_vec(dom * np.arange(1, half_K + 1), Nf, d)
41
+ # negative frequencies
42
+ DX[half_K + 1 :] = _phitilde_vec(
43
+ -dom * np.arange(half_K - 1, 0, -1), Nf, d
44
+ )
45
+ DX = K * ifft(DX, K)
46
+
47
+ phi = np.zeros(K)
48
+ phi[0:half_K] = np.real(DX[half_K:K])
49
+ phi[half_K:] = np.real(DX[0:half_K])
50
+
51
+ nrm = np.sqrt(2.0) / np.sqrt(K / dom) # *xp.linalg.norm(phi)
52
+
53
+ phi *= nrm
54
+ return xp.array(phi)
55
+
56
+
57
+ def _phitilde_vec(
58
+ omega: Float[Array, " Nt//2+1"], Nf: int, d: float = 4.0
59
+ ) -> Float[Array, " Nt//2+1"]:
8
60
  """Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
9
61
 
10
62
  Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
@@ -33,25 +85,25 @@ def phitilde_vec(
33
85
 
34
86
  """
35
87
  dF = 1.0 / (2 * Nf) # NOTE: missing 1/dt?
36
- dOmega = 2 * PI * dF # Near Eq 10 # 2 pi times DF
37
- inverse_sqrt_dOmega = 1.0 / xp.sqrt(dOmega)
88
+ dOmega = 2 * np.pi * dF # Near Eq 10 # 2 pi times DF
89
+ inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)
38
90
 
39
91
  A = dOmega / 4
40
92
  B = dOmega - 2 * A # Cannot have B \leq 0.
41
93
  if B <= 0:
42
94
  raise ValueError("B must be greater than 0")
43
95
 
44
- phi = xp.zeros(omega.size)
45
- mask = (A <= xp.abs(omega)) & (xp.abs(omega) < A + B) # Minor changes
46
- vd = (PI / 2.0) * __nu_d(omega[mask], A, B, d=d) # different from paper
96
+ phi = np.zeros(omega.size)
97
+ mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B) # Minor changes
98
+ vd = (np.pi / 2.0) * _nu_d(omega[mask], A, B, d=d) # different from paper
47
99
  phi[mask] = inverse_sqrt_dOmega * xp.cos(vd)
48
- phi[xp.abs(omega) < A] = inverse_sqrt_dOmega
100
+ phi[np.abs(omega) < A] = inverse_sqrt_dOmega
49
101
  return phi
50
102
 
51
103
 
52
- def __nu_d(
53
- omega: xp.ndarray, A: float, B: float, d: float = 4.0
54
- ) -> xp.ndarray:
104
+ def _nu_d(
105
+ omega: Float[Array, " Nt//2+1"], A: float, B: float, d: float = 4.0
106
+ ) -> Float[Array, " Nt//2+1"]:
55
107
  """Compute the normalized incomplete beta function.
56
108
 
57
109
  Parameters
@@ -74,76 +126,5 @@ def __nu_d(
74
126
  https://docs.scipy.org/doc/scipy-1.7.1/reference/reference/generated/scipy.special.betainc.html
75
127
 
76
128
  """
77
- x = (xp.abs(omega) - A) / B
78
- return betainc(d, d, x) / betainc(d, d, 1)
79
-
80
-
81
- def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
82
- """Normalize phitilde for inverse frequency domain transform."""
83
-
84
- # Calculate the frequency values
85
- ND = Nf * Nt
86
- omegas = 2 * xp.pi / ND * xp.arange(0, Nt // 2 + 1)
87
-
88
- # Calculate the unnormalized phitilde (u_phit)
89
- u_phit = phitilde_vec(omegas, Nf, d)
90
-
91
- # Normalize the phitilde
92
- normalising_factor = PI ** (-1 / 2) # Ollie's normalising factor
93
-
94
- # Notes: this is the overall normalising factor that is different from Cornish's paper
95
- # It is the only way I can force this code to be consistent with our work in the
96
- # frequency domain. First note that
97
-
98
- # old normalising factor -- This factor is absolutely ridiculous. Why!?
99
- # Matt_normalising_factor = np.sqrt(
100
- # (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) * 2 * PI / ND
101
- # )
102
- # Matt_normalising_factor /= PI**(3/2)/PI
103
-
104
- # The expression above is equal to np.pi**(-1/2) after working through the maths.
105
- # I have pulled (2/Nf) from __init__.py (from freq to wavelet) into the normalsiing
106
- # factor here. I thnk it's cleaner to have ONE normalising constant. Avoids confusion
107
- # and it is much easier to track.
108
-
109
- # TODO: understand the following:
110
- # (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) = 0.5 * Nt / dOmega
111
- # Matt_normalising_factor is equal to 1/sqrt(pi)... why is this computed?
112
- # in such a stupid way?
113
-
114
- return u_phit / (normalising_factor)
115
-
116
-
117
- def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> xp.ndarray:
118
- """get time domain phi as fourier transform of phitilde_vec"""
119
- insDOM = 1.0 / xp.sqrt(PI / Nf)
120
- K = q * 2 * Nf
121
- half_K = q * Nf # xp.int64(K/2)
122
-
123
- dom = 2 * PI / K # max frequency is K/2*dom = pi/dt = OM
124
-
125
- DX = xp.zeros(K, dtype=xp.complex128)
126
-
127
- # zero frequency
128
- DX[0] = insDOM
129
-
130
- DX = DX.copy()
131
- # postive frequencies
132
- DX[1 : half_K + 1] = phitilde_vec(
133
- dom * xp.arange(1, half_K + 1), Nf, d
134
- )
135
- # negative frequencies
136
- DX[half_K + 1 :] = phitilde_vec(
137
- -dom * xp.arange(half_K - 1, 0, -1), Nf, d
138
- )
139
- DX = K * ifft(DX, K)
140
-
141
- phi = xp.zeros(K)
142
- phi[0:half_K] = xp.real(DX[half_K:K])
143
- phi[half_K:] = xp.real(DX[0:half_K])
144
-
145
- nrm = xp.sqrt(K / dom) # *xp.linalg.norm(phi)
146
-
147
- fac = xp.sqrt(2.0) / nrm
148
- phi *= fac
149
- return phi
129
+ x = (np.abs(omega) - A) / B
130
+ return betainc(d, d, x)
pywavelet/types/common.py CHANGED
@@ -1,6 +1,7 @@
1
- from typing import Tuple, Union, Callable
2
- from ..logger import logger
1
+ from typing import Callable, Tuple, Union
2
+
3
3
  from ..backend import xp
4
+ from ..logger import logger
4
5
 
5
6
 
6
7
  def _len_check(d):
@@ -8,7 +9,7 @@ def _len_check(d):
8
9
  logger.warning(f"Data length {len(d)} is suggested to be a power of 2")
9
10
 
10
11
 
11
- def is_documented_by(original:Callable):
12
+ def is_documented_by(original: Callable):
12
13
  def wrapper(target):
13
14
  target.__doc__ = original.__doc__
14
15
  return target
@@ -2,7 +2,7 @@ from typing import Optional, Tuple, Union
2
2
 
3
3
  import matplotlib.pyplot as plt
4
4
 
5
- from ..backend import xp, irfft
5
+ from ..backend import irfft, xp
6
6
  from .common import fmt_pow2, fmt_time, is_documented_by
7
7
  from .plotting import plot_freqseries, plot_periodogram
8
8