pywavelet 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,13 +10,14 @@ from jax.numpy.fft import fft
10
10
  def inverse_wavelet_freq_helper(
11
11
  wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
12
12
  ) -> jnp.ndarray:
13
- """JAX vectorized function for inverse_wavelet_freq"""
13
+ """JAX vectorized function for inverse_wavelet_freq with corrected shapes and ranges."""
14
+ # Transpose to match the NumPy version.
14
15
  wave_in = wave_in.T
15
16
  ND = Nf * Nt
16
17
 
18
+ # Allocate prefactor2s for each m (shape: (Nf+1, Nt)).
17
19
  m_range = jnp.arange(Nf + 1)
18
20
  prefactor2s = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex128)
19
-
20
21
  n_range = jnp.arange(Nt)
21
22
 
22
23
  # m == 0 case
@@ -29,46 +30,59 @@ def inverse_wavelet_freq_helper(
29
30
  2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
30
31
  )
31
32
 
32
- # Other m cases
33
+ # Other m cases: use meshgrid for vectorization.
33
34
  m_mid = m_range[1:Nf]
35
+ # Create grids: n_grid (columns) and m_grid (rows)
34
36
  n_grid, m_grid = jnp.meshgrid(n_range, m_mid)
37
+ # Index the transposed wave_in using (n, m) as in the NumPy version.
35
38
  val = wave_in[n_grid, m_grid]
39
+ # Apply the alternating multiplier based on (n+m) parity.
36
40
  mult2 = jnp.where((n_grid + m_grid) % 2, -1j, 1)
37
41
  prefactor2s = prefactor2s.at[1:Nf].set(mult2 * val)
38
42
 
39
- # Vectorized FFT
43
+ # Apply FFT along axis 1 for all m.
40
44
  fft_prefactor2s = fft(prefactor2s, axis=1)
41
45
 
42
- # Vectorized __unpack_wave_inverse
43
- ## TODO: Check with Giorgio
44
- # ND or ND // 2 + 1?
45
- # https://github.com/pywavelet/pywavelet/blob/63151a47cde9edc14f1e7e0bf17f554e78ad257c/src/pywavelet/transforms/from_wavelets/inverse_wavelet_freq_funcs.py
46
- res = jnp.zeros(ND, dtype=jnp.complex128)
46
+ # Allocate the result array with corrected shape.
47
+ res = jnp.zeros(ND // 2 + 1, dtype=jnp.complex128)
47
48
 
48
- # m == 0 or m == Nf cases
49
+ # Unpacking for m == 0 and m == Nf cases:
49
50
  i_ind_range = jnp.arange(Nt // 2)
50
- i_0 = jnp.abs(i_ind_range)
51
- i_Nf = jnp.abs(Nf * Nt // 2 - i_ind_range)
51
+ i_0 = jnp.abs(i_ind_range) # for m == 0: i = i_ind_range
52
+ i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
52
53
  ind3_0 = (2 * i_0) % Nt
53
54
  ind3_Nf = (2 * i_Nf) % Nt
54
55
 
55
56
  res = res.at[i_0].add(fft_prefactor2s[0, ind3_0] * phif[i_ind_range])
56
57
  res = res.at[i_Nf].add(fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range])
58
+ # Special case for m == Nf (ensure the Nyquist frequency is updated correctly)
59
+ special_index = jnp.abs(Nf * (Nt // 2) - (Nt // 2))
60
+ res = res.at[special_index].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
57
61
 
58
- # Special case for m == Nf
59
- res = res.at[Nf * Nt // 2].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
60
-
61
- # Other m cases
62
+ # Unpacking for m in (1, ..., Nf-1)
62
63
  m_mid = m_range[1:Nf]
63
- i_ind_range = jnp.arange(Nt // 2 + 1)
64
- m_grid, i_ind_grid = jnp.meshgrid(m_mid, i_ind_range)
65
-
66
- i1 = Nt // 2 * m_grid - i_ind_grid
67
- i2 = Nt // 2 * m_grid + i_ind_grid
68
- ind31 = (Nt // 2 * m_grid - i_ind_grid) % Nt
69
- ind32 = (Nt // 2 * m_grid + i_ind_grid) % Nt
64
+ # Use range [0, Nt//2) to match the loop in NumPy version.
65
+ i_ind_range_mid = jnp.arange(Nt // 2)
66
+ # Create meshgrid for vectorized computation.
67
+ m_grid_mid, i_ind_grid_mid = jnp.meshgrid(
68
+ m_mid, i_ind_range_mid, indexing="ij"
69
+ )
70
70
 
71
- res = res.at[i1].add(fft_prefactor2s[m_grid, ind31] * phif[i_ind_grid])
72
- res = res.at[i2].add(fft_prefactor2s[m_grid, ind32] * phif[i_ind_grid])
71
+ # Compute indices i1 and i2 following the NumPy logic.
72
+ i1 = (Nt // 2) * m_grid_mid - i_ind_grid_mid
73
+ i2 = (Nt // 2) * m_grid_mid + i_ind_grid_mid
74
+ # Compute the wrapped indices for FFT results.
75
+ ind31 = i1 % Nt
76
+ ind32 = i2 % Nt
77
+
78
+ # Update result array using vectorized adds.
79
+ # Note: You might need to adjust this further if your target res shape is non-trivial,
80
+ # because here we assume that i1 and i2 indices fall within the allocated result shape.
81
+ res = res.at[i1].add(
82
+ fft_prefactor2s[m_grid_mid, ind31] * phif[i_ind_grid_mid]
83
+ )
84
+ res = res.at[i2].add(
85
+ fft_prefactor2s[m_grid_mid, ind32] * phif[i_ind_grid_mid]
86
+ )
73
87
 
74
88
  return res
@@ -1,32 +1,69 @@
1
1
  """helper functions for transform_freq"""
2
2
 
3
+ import logging
4
+
3
5
  import numpy as np
4
6
  from numba import njit
5
7
 
8
+ # import rocket_fft.special as rfft # JIT‐able FFT routines
9
+
10
+
11
+ logger = logging.getLogger("pywavelet")
12
+
6
13
 
7
14
  def transform_wavelet_freq_helper(
8
15
  data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
9
16
  ) -> np.ndarray:
10
- """helper to do the wavelet transform using the fast wavelet domain transform"""
11
- wave = np.zeros((Nt, Nf)) # wavelet wavepacket transform of the signal
17
+ """
18
+ Forward wavelet transform helper using the fast wavelet domain transform,
19
+ with a JIT-able FFT (rocket-fft) so that the whole transform is jittable.
20
+
21
+ Parameters
22
+ ----------
23
+ data : np.ndarray
24
+ Input frequency-domain data (1D array).
25
+ Nf : int
26
+ Number of frequency bins.
27
+ Nt : int
28
+ Number of time bins.
29
+ phif : np.ndarray
30
+ Fourier-domain phase factors (complex-valued array of length Nt//2 + 1).
31
+
32
+ Returns
33
+ -------
34
+ wave : np.ndarray
35
+ The wavelet transform output of shape (Nt, Nf). Note that contributions from
36
+ f_bin==0 and f_bin==Nf are both stored in column 0.
37
+ """
38
+ logger.debug(
39
+ f"[NUMPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]"
40
+ )
41
+ wave = np.zeros((Nt, Nf), dtype=np.float64)
12
42
  DX = np.zeros(Nt, dtype=np.complex128)
13
- freq_strain = data.copy() # Convert
43
+ # Create a copy of the input data (if needed).
44
+ freq_strain = data.copy()
14
45
  __core(Nf, Nt, DX, freq_strain, phif, wave)
15
46
  return wave
16
47
 
17
48
 
18
- # @njit()
49
+ @njit()
19
50
  def __core(
20
51
  Nf: int,
21
52
  Nt: int,
22
53
  DX: np.ndarray,
23
- freq_strain: np.ndarray,
54
+ data: np.ndarray,
24
55
  phif: np.ndarray,
25
56
  wave: np.ndarray,
26
57
  ):
58
+ """
59
+ Process each frequency bin (f_bin) to compute the temporary array DX,
60
+ perform the inverse FFT using rocket-fft, and then unpack the result into wave.
61
+
62
+ This function is fully jittable.
63
+ """
27
64
  for f_bin in range(0, Nf + 1):
28
- __fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
29
- # Numba doesn't support np.ifft so we cant jit this
65
+ __fill_wave_1(f_bin, Nt, Nf, DX, data, phif)
66
+ # Use rocket-fft's ifft (which is JIT-able) instead of np.fft.ifft.
30
67
  DX_trans = np.fft.ifft(DX, Nt)
31
68
  __fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
32
69
 
@@ -40,24 +77,32 @@ def __fill_wave_1(
40
77
  data: np.ndarray,
41
78
  phif: np.ndarray,
42
79
  ) -> None:
43
- """helper for assigning DX in the main loop"""
80
+ """
81
+ Fill the temporary complex array DX for the given frequency bin (f_bin)
82
+ based on the input data and the phase factors phif.
83
+
84
+ The computation is performed over a window of indices defined by the current f_bin.
85
+ """
44
86
  i_base = Nt // 2
45
- jj_base = f_bin * Nt // 2
87
+ jj_base = f_bin * (Nt // 2)
46
88
 
89
+ # Special center assignment:
47
90
  if f_bin == 0 or f_bin == Nf:
48
- # NOTE this term appears to be needed to recover correct constant (at least for m=0), but was previously missing
49
- DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2] / 2.0
91
+ DX[i_base] = phif[0] * data[jj_base] / 2.0
50
92
  else:
51
- DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2]
93
+ DX[i_base] = phif[0] * data[jj_base]
52
94
 
53
- for jj in range(jj_base + 1 - Nt // 2, jj_base + Nt // 2):
95
+ # Determine the window of indices.
96
+ start = jj_base + 1 - (Nt // 2)
97
+ end = jj_base + (Nt // 2)
98
+ for jj in range(start, end):
54
99
  j = np.abs(jj - jj_base)
55
100
  i = i_base - jj_base + jj
56
- if f_bin == Nf and jj > jj_base:
57
- DX[i] = 0.0
58
- elif f_bin == 0 and jj < jj_base:
101
+ # For the highest frequency (f_bin==Nf) or the lowest (f_bin==0), zero out the out-of-range values.
102
+ if (f_bin == Nf and jj > jj_base) or (f_bin == 0 and jj < jj_base):
59
103
  DX[i] = 0.0
60
104
  elif j == 0:
105
+ # Center already assigned.
61
106
  continue
62
107
  else:
63
108
  DX[i] = phif[j] * data[jj]
@@ -67,21 +112,34 @@ def __fill_wave_1(
67
112
  def __fill_wave_2(
68
113
  f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
69
114
  ) -> None:
115
+ """
116
+ Unpack the inverse FFT output (DX_trans) into the output wave array.
117
+
118
+ For f_bin==0 and f_bin==Nf, the results are stored in column 0 of wave,
119
+ using even- or odd-indexed rows respectively. For intermediate f_bin values,
120
+ the values are stored in column f_bin with a sign and component (real or imag)
121
+ determined by parity.
122
+ """
123
+ sqrt2 = np.sqrt(2.0)
70
124
  if f_bin == 0:
71
- # half of lowest and highest frequency bin pixels are redundant, so store them in even and odd components of f_bin=0 respectively
125
+ # f_bin==0: assign even-indexed rows of column 0.
72
126
  for n in range(0, Nt, 2):
73
- wave[n, 0] = DX_trans[n].real * np.sqrt(2)
127
+ wave[n, 0] = DX_trans[n].real * sqrt2
74
128
  elif f_bin == Nf:
129
+ # f_bin==Nf: assign odd-indexed rows of column 0.
75
130
  for n in range(0, Nt, 2):
76
- wave[n + 1, 0] = DX_trans[n].real * np.sqrt(2)
131
+ wave[n + 1, 0] = DX_trans[n].real * sqrt2
77
132
  else:
133
+ # For intermediate f_bin, assign values to column f_bin.
78
134
  for n in range(0, Nt):
79
135
  if f_bin % 2:
136
+ # For odd f_bin: use -imag when (n+f_bin) is odd; otherwise use real.
80
137
  if (n + f_bin) % 2:
81
138
  wave[n, f_bin] = -DX_trans[n].imag
82
139
  else:
83
140
  wave[n, f_bin] = DX_trans[n].real
84
141
  else:
142
+ # For even f_bin: use imag when (n+f_bin) is odd; otherwise use real.
85
143
  if (n + f_bin) % 2:
86
144
  wave[n, f_bin] = DX_trans[n].imag
87
145
  else:
@@ -56,7 +56,6 @@ def from_time_to_wavelet(
56
56
  to inaccurate results.
57
57
  """
58
58
  Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
59
- dt = timeseries.dt
60
59
  t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
61
60
 
62
61
  ND = Nf * Nt
@@ -7,8 +7,6 @@ from .to_time import inverse_wavelet_time_helper_fast
7
7
 
8
8
  __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
9
9
 
10
- INV_ROOT2 = 1.0 / np.sqrt(2)
11
-
12
10
 
13
11
  def from_wavelet_to_time(
14
12
  wave_in: Wavelet,
@@ -50,7 +48,7 @@ def from_wavelet_to_time(
50
48
  h_t = inverse_wavelet_time_helper_fast(
51
49
  wave_in.data.T, phi, wave_in.Nf, wave_in.Nt, mult
52
50
  )
53
- h_t *= INV_ROOT2 # Normalize to get proper backward transformation
51
+ h_t *= 1.0 / np.sqrt(2) # Normalize to get proper backward transformation
54
52
  ts = np.arange(0, wave_in.Nf * wave_in.Nt) * dt
55
53
  return TimeSeries(data=h_t, time=ts)
56
54
 
@@ -89,7 +87,7 @@ def from_wavelet_to_freq(
89
87
  wave_in.data, phif, wave_in.Nf, wave_in.Nt
90
88
  )
91
89
 
92
- freq_data *= INV_ROOT2
90
+ freq_data *= 1.0 / np.sqrt(2)
93
91
 
94
92
  freqs = np.fft.rfftfreq(wave_in.ND, d=dt)
95
93
  return FrequencySeries(data=freq_data, freq=freqs)
@@ -46,7 +46,7 @@ def __pack_wave_inverse(
46
46
  prefactor2s[n] = 2 ** (-1 / 2) * wave_in[(2 * n) % Nt + 1, 0]
47
47
  else:
48
48
  for n in range(0, Nt):
49
- val = wave_in[n, m] # bug is here
49
+ val = wave_in[n, m]
50
50
  if (n + m) % 2:
51
51
  mult2 = -1j
52
52
  else:
@@ -93,3 +93,66 @@ def __unpack_wave_inverse(
93
93
  if ind32 == Nt:
94
94
  ind32 = 0
95
95
  res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0]
96
+
97
+
98
+ #
99
+ # # @njit
100
+ # def inverse_wavelet_freq_helper_fast_version2(
101
+ # wave_in: np.ndarray, phif: np.ndarray, Nf: int, Nt: int
102
+ # ) -> np.ndarray:
103
+ # wave_in = wave_in.T
104
+ # ND = Nf * Nt
105
+ # prefactor2s = np.zeros((Nf + 1, Nt), dtype=np.complex128)
106
+ # n_range = np.arange(Nt)
107
+ #
108
+ # # m == 0 case
109
+ # indices = (2 * n_range) % Nt
110
+ # prefactor2s[0] = (2 ** (-0.5)) * wave_in[indices, 0]
111
+ #
112
+ # # m == Nf case
113
+ # indices = ((2 * n_range) % Nt) + 1
114
+ # prefactor2s[Nf] = (2 ** (-0.5)) * wave_in[indices, 0]
115
+ #
116
+ # # For m = 1, ..., Nf-1
117
+ # m_mid = np.arange(1, Nf)
118
+ # m_grid, n_grid = np.meshgrid(m_mid, n_range, indexing='ij')
119
+ # val = wave_in[n_grid, m_grid]
120
+ # mult2 = np.where(((n_grid + m_grid) % 2) != 0, -1j, 1)
121
+ # prefactor2s[1:Nf] = mult2 * val
122
+ #
123
+ # fft_prefactor2s = np.fft.fft(prefactor2s, axis=1)
124
+ #
125
+ # res = np.zeros(ND // 2 + 1, dtype=np.complex128)
126
+ #
127
+ # # Unpacking for m == 0 and m == Nf
128
+ # for m in [0, Nf]:
129
+ # i_ind_range = np.arange(Nt // 2 + 1 if m == Nf else Nt // 2)
130
+ # i = np.abs(m * Nt // 2 - i_ind_range)
131
+ # ind3 = (2 * i) % Nt
132
+ # res[i] += fft_prefactor2s[m, ind3] * phif[i_ind_range]
133
+ #
134
+ # # Unpacking for m = 1,..., Nf-1
135
+ # for m in range(1, Nf):
136
+ # ind31 = (Nt // 2 * m) % Nt
137
+ # ind32 = ind31
138
+ # for i_ind in range(Nt // 2):
139
+ # i1 = Nt // 2 * m - i_ind
140
+ # i2 = Nt // 2 * m + i_ind
141
+ # res[i1] += fft_prefactor2s[m, ind31] * phif[i_ind]
142
+ # res[i2] += fft_prefactor2s[m, ind32] * phif[i_ind]
143
+ # ind31 = (ind31 - 1) % Nt
144
+ # ind32 = (ind32 + 1) % Nt
145
+ # res[Nt // 2 * m] += fft_prefactor2s[m, (Nt // 2 * m) % Nt] * phif[0]
146
+ #
147
+ # return res
148
+ #
149
+ # #
150
+ # #
151
+ # # if __name__ == '__main__':
152
+ # # phif = np.array(np.random.rand(64))
153
+ # # wave_in = np.array(np.random.rand(64, 64))
154
+ # # Nf = 64
155
+ # # Nt = 64
156
+ # # res = inverse_wavelet_freq_helper_fast(wave_in, phif, Nf, Nt)
157
+ # # res2 = inverse_wavelet_freq_helper_fast_version2(wave_in, phif, Nf, Nt)
158
+ # # assert np.allclose(res, res2), "Results do not match!"
@@ -1,7 +1,62 @@
1
- from ..backend import PI, betainc, ifft, xp
1
+ import numpy as np
2
+ from jaxtyping import Array, Float
2
3
 
4
+ from ..backend import betainc, ifft, xp
3
5
 
4
- def phitilde_vec(omega: xp.ndarray, Nf: int, d: float = 4.0) -> xp.ndarray:
6
+ __all__ = ["phitilde_vec_norm", "phi_vec", "omega"]
7
+
8
+
9
+ def omega(Nf: int, Nt: int) -> Float[Array, " Nt//2+1"]:
10
+ """Get the angular frequencies of the time domain signal."""
11
+ df = 2 * np.pi / (Nf * Nt)
12
+ return df * np.arange(0, Nt // 2 + 1)
13
+
14
+
15
+ def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> Float[Array, " Nt//2+1"]:
16
+ """Normalize phitilde for inverse frequency domain transform."""
17
+ omegas = omega(Nf, Nt)
18
+ _phi_t = _phitilde_vec(omegas, Nf, d) * np.sqrt(np.pi)
19
+ return xp.array(_phi_t)
20
+
21
+
22
+ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
23
+ """get time domain phi as fourier transform of _phitilde_vec
24
+ q: number of Nf bins over which the window extends?
25
+
26
+ """
27
+ insDOM = 1.0 / np.sqrt(np.pi / Nf)
28
+ K = q * 2 * Nf
29
+ half_K = q * Nf # xp.int64(K/2)
30
+
31
+ dom = 2 * np.pi / K # max frequency is K/2*dom = pi/dt = OM
32
+
33
+ DX = np.zeros(K, dtype=np.complex128)
34
+
35
+ # zero frequency
36
+ DX[0] = insDOM
37
+
38
+ DX = DX.copy()
39
+ # postive frequencies
40
+ DX[1 : half_K + 1] = _phitilde_vec(dom * np.arange(1, half_K + 1), Nf, d)
41
+ # negative frequencies
42
+ DX[half_K + 1 :] = _phitilde_vec(
43
+ -dom * np.arange(half_K - 1, 0, -1), Nf, d
44
+ )
45
+ DX = K * ifft(DX, K)
46
+
47
+ phi = np.zeros(K)
48
+ phi[0:half_K] = np.real(DX[half_K:K])
49
+ phi[half_K:] = np.real(DX[0:half_K])
50
+
51
+ nrm = np.sqrt(2.0) / np.sqrt(K / dom) # *xp.linalg.norm(phi)
52
+
53
+ phi *= nrm
54
+ return xp.array(phi)
55
+
56
+
57
+ def _phitilde_vec(
58
+ omega: Float[Array, " Nt//2+1"], Nf: int, d: float = 4.0
59
+ ) -> Float[Array, " Nt//2+1"]:
5
60
  """Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
6
61
 
7
62
  Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
@@ -30,25 +85,25 @@ def phitilde_vec(omega: xp.ndarray, Nf: int, d: float = 4.0) -> xp.ndarray:
30
85
 
31
86
  """
32
87
  dF = 1.0 / (2 * Nf) # NOTE: missing 1/dt?
33
- dOmega = 2 * PI * dF # Near Eq 10 # 2 pi times DF
34
- inverse_sqrt_dOmega = 1.0 / xp.sqrt(dOmega)
88
+ dOmega = 2 * np.pi * dF # Near Eq 10 # 2 pi times DF
89
+ inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)
35
90
 
36
91
  A = dOmega / 4
37
92
  B = dOmega - 2 * A # Cannot have B \leq 0.
38
93
  if B <= 0:
39
94
  raise ValueError("B must be greater than 0")
40
95
 
41
- phi = xp.zeros(omega.size)
42
- mask = (A <= xp.abs(omega)) & (xp.abs(omega) < A + B) # Minor changes
43
- vd = (PI / 2.0) * __nu_d(omega[mask], A, B, d=d) # different from paper
96
+ phi = np.zeros(omega.size)
97
+ mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B) # Minor changes
98
+ vd = (np.pi / 2.0) * _nu_d(omega[mask], A, B, d=d) # different from paper
44
99
  phi[mask] = inverse_sqrt_dOmega * xp.cos(vd)
45
- phi[xp.abs(omega) < A] = inverse_sqrt_dOmega
100
+ phi[np.abs(omega) < A] = inverse_sqrt_dOmega
46
101
  return phi
47
102
 
48
103
 
49
- def __nu_d(
50
- omega: xp.ndarray, A: float, B: float, d: float = 4.0
51
- ) -> xp.ndarray:
104
+ def _nu_d(
105
+ omega: Float[Array, " Nt//2+1"], A: float, B: float, d: float = 4.0
106
+ ) -> Float[Array, " Nt//2+1"]:
52
107
  """Compute the normalized incomplete beta function.
53
108
 
54
109
  Parameters
@@ -71,72 +126,5 @@ def __nu_d(
71
126
  https://docs.scipy.org/doc/scipy-1.7.1/reference/reference/generated/scipy.special.betainc.html
72
127
 
73
128
  """
74
- x = (xp.abs(omega) - A) / B
75
- return betainc(d, d, x) / betainc(d, d, 1)
76
-
77
-
78
- def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
79
- """Normalize phitilde for inverse frequency domain transform."""
80
-
81
- # Calculate the frequency values
82
- ND = Nf * Nt
83
- omegas = 2 * xp.pi / ND * xp.arange(0, Nt // 2 + 1)
84
-
85
- # Calculate the unnormalized phitilde (u_phit)
86
- u_phit = phitilde_vec(omegas, Nf, d)
87
-
88
- # Normalize the phitilde
89
- normalising_factor = PI ** (-1 / 2) # Ollie's normalising factor
90
-
91
- # Notes: this is the overall normalising factor that is different from Cornish's paper
92
- # It is the only way I can force this code to be consistent with our work in the
93
- # frequency domain. First note that
94
-
95
- # old normalising factor -- This factor is absolutely ridiculous. Why!?
96
- # Matt_normalising_factor = np.sqrt(
97
- # (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) * 2 * PI / ND
98
- # )
99
- # Matt_normalising_factor /= PI**(3/2)/PI
100
-
101
- # The expression above is equal to np.pi**(-1/2) after working through the maths.
102
- # I have pulled (2/Nf) from __init__.py (from freq to wavelet) into the normalsiing
103
- # factor here. I thnk it's cleaner to have ONE normalising constant. Avoids confusion
104
- # and it is much easier to track.
105
-
106
- # TODO: understand the following:
107
- # (2 * np.sum(u_phit[1:] ** 2) + u_phit[0] ** 2) = 0.5 * Nt / dOmega
108
- # Matt_normalising_factor is equal to 1/sqrt(pi)... why is this computed?
109
- # in such a stupid way?
110
-
111
- return u_phit / (normalising_factor)
112
-
113
-
114
- def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> xp.ndarray:
115
- """get time domain phi as fourier transform of phitilde_vec"""
116
- insDOM = 1.0 / xp.sqrt(PI / Nf)
117
- K = q * 2 * Nf
118
- half_K = q * Nf # xp.int64(K/2)
119
-
120
- dom = 2 * PI / K # max frequency is K/2*dom = pi/dt = OM
121
-
122
- DX = xp.zeros(K, dtype=xp.complex128)
123
-
124
- # zero frequency
125
- DX[0] = insDOM
126
-
127
- DX = DX.copy()
128
- # postive frequencies
129
- DX[1 : half_K + 1] = phitilde_vec(dom * xp.arange(1, half_K + 1), Nf, d)
130
- # negative frequencies
131
- DX[half_K + 1 :] = phitilde_vec(-dom * xp.arange(half_K - 1, 0, -1), Nf, d)
132
- DX = K * ifft(DX, K)
133
-
134
- phi = xp.zeros(K)
135
- phi[0:half_K] = xp.real(DX[half_K:K])
136
- phi[half_K:] = xp.real(DX[0:half_K])
137
-
138
- nrm = xp.sqrt(K / dom) # *xp.linalg.norm(phi)
139
-
140
- fac = xp.sqrt(2.0) / nrm
141
- phi *= fac
142
- return phi
129
+ x = (np.abs(omega) - A) / B
130
+ return betainc(d, d, x)
@@ -84,6 +84,7 @@ def plot_wavelet_grid(
84
84
  nan_color: Optional[str] = "black",
85
85
  detailed_axes: bool = False,
86
86
  show_gridinfo: bool = True,
87
+ txtbox_kwargs: dict = {},
87
88
  trend_color: Optional[str] = None,
88
89
  whiten_by: Optional[np.ndarray] = None,
89
90
  **kwargs,
@@ -172,12 +173,17 @@ def plot_wavelet_grid(
172
173
  if np.all(np.isnan(z)):
173
174
  raise ValueError("All wavelet data is NaN.")
174
175
  if zscale == "log":
175
- norm = LogNorm(
176
- vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z < np.inf])
177
- )
176
+ vmin = np.nanmin(z[z > 0])
177
+ vmax = np.nanmax(z[z < np.inf])
178
+ if vmin > vmax:
179
+ raise ValueError("vmin > vmax... something wrong")
180
+ norm = LogNorm(vmin=vmin, vmax=vmax)
178
181
  elif not absolute:
179
182
  vmin, vmax = np.nanmin(z), np.nanmax(z)
180
183
  vcenter = 0.0
184
+ if vmin > vmax:
185
+ raise ValueError("vmin > vmax... something wrong")
186
+
181
187
  norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
182
188
  else:
183
189
  norm = None # Default linear scaling
@@ -248,6 +254,9 @@ def plot_wavelet_grid(
248
254
  NfNt_label = f"{Nf}x{Nt}" if show_gridinfo else ""
249
255
  txt = f"{label}\n{NfNt_label}" if label else NfNt_label
250
256
  if txt:
257
+ txtbox_kwargs.setdefault("boxstyle", "round")
258
+ txtbox_kwargs.setdefault("facecolor", "white")
259
+ txtbox_kwargs.setdefault("alpha", 0.2)
251
260
  ax.text(
252
261
  0.05,
253
262
  0.95,
@@ -255,7 +264,7 @@ def plot_wavelet_grid(
255
264
  transform=ax.transAxes,
256
265
  fontsize=14,
257
266
  verticalalignment="top",
258
- bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
267
+ bbox=txtbox_kwargs,
259
268
  )
260
269
 
261
270
  # Adjust layout
@@ -294,7 +303,7 @@ def plot_periodogram(
294
303
  flow = np.min(np.abs(freq))
295
304
  ax.set_xlabel("Frequency [Hz]")
296
305
  ax.set_ylabel("Periodigram")
297
- ax.set_xlim(left=flow, right=nyquist_frequency / 2)
306
+ # ax.set_xlim(left=flow, right=nyquist_frequency / 2)
298
307
  return ax.figure, ax
299
308
 
300
309
 
@@ -477,13 +477,13 @@ class WaveletMask(Wavelet):
477
477
  A WaveletMask object with the specified restrictions.
478
478
  """
479
479
  self = cls.zeros_from_grid(time_grid, freq_grid)
480
- self.data[
481
- (freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
482
- ] = True
480
+ self.data[(freq_grid >= frange[0]) & (freq_grid <= frange[1]), :] = (
481
+ True
482
+ )
483
483
 
484
484
  for tgap in tgaps:
485
- self.data[
486
- :, (time_grid >= tgap[0]) & (time_grid <= tgap[1])
487
- ] = False
485
+ self.data[:, (time_grid >= tgap[0]) & (time_grid <= tgap[1])] = (
486
+ False
487
+ )
488
488
  self.data = self.data.astype(bool)
489
489
  return self