pywavelet 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,28 @@
1
1
  from functools import partial
2
2
 
3
+ import jax
3
4
  import jax.numpy as jnp
4
5
  from jax import jit
5
6
  from jax.numpy.fft import ifft
6
7
 
8
+ X64_PRECISION = jax.config.jax_enable_x64
9
+
10
+ CMPLX_DTYPE = jnp.complex128 if X64_PRECISION else jnp.complex64
11
+
12
+
7
13
  import logging
8
14
 
9
- logger = logging.getLogger('pywavelet')
15
+ logger = logging.getLogger("pywavelet")
10
16
 
11
17
 
12
- @partial(jit, static_argnames=("Nf", "Nt"))
18
+ @partial(jit, static_argnames=("Nf", "Nt", "float_dtype", "complex_dtype"))
13
19
  def transform_wavelet_freq_helper(
14
- data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
20
+ data: jnp.ndarray,
21
+ Nf: int,
22
+ Nt: int,
23
+ phif: jnp.ndarray,
24
+ float_dtype=jnp.float64,
25
+ complex_dtype=jnp.complex128,
15
26
  ) -> jnp.ndarray:
16
27
  """
17
28
  Transforms input data from the frequency domain to the wavelet domain using a
@@ -26,78 +37,60 @@ def transform_wavelet_freq_helper(
26
37
  Returns:
27
38
  - wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
28
39
  """
29
-
30
- logger.debug(f"[JAX TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]")
31
-
32
- # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
33
- wave = jnp.zeros((Nt, Nf))
34
- f_bins = jnp.arange(Nf) # Frequency bin indices
35
-
36
- # Compute base indices for time (i_base) and frequency (jj_base)
37
- i_base = Nt // 2
38
- jj_base = f_bins * Nt // 2
39
-
40
- # Set initial values for the center of the transformation
41
- initial_values = jnp.where(
42
- (f_bins == 0)
43
- | (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
44
- phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
45
- phif[0] * data[f_bins * Nt // 2],
40
+ logger.debug(
41
+ f"[JAX TRANSFORM] Input types [data:{type(data)},{data.dtype}, phif:{type(phif)},{phif.dtype}]"
42
+ )
43
+ half = Nt // 2
44
+ f_bins = jnp.arange(Nf + 1) # [0,1,...,Nf]
45
+
46
+ # --- 1) build the full (Nf+1, Nt) DX array ---
47
+ # center (j = 0):
48
+ center = phif[0] * data[f_bins * half]
49
+ center = jnp.where((f_bins == 0) | (f_bins == Nf), center / 2.0, center)
50
+ DX = jnp.zeros((Nf + 1, Nt), complex_dtype)
51
+ DX = DX.at[:, half].set(center)
52
+
53
+ # off-center (j = +/-1...+/-(half−1))
54
+ offs = jnp.arange(1 - half, half) # length Nt−1
55
+ jj = f_bins[:, None] * half + offs[None, :] # shape (Nf+1, Nt−1)
56
+ ii = half + offs # shape (Nt−1,)
57
+ mask = ((f_bins[:, None] == Nf) & (offs[None, :] > 0)) | (
58
+ (f_bins[:, None] == 0) & (offs[None, :] < 0)
59
+ )
60
+ vals = phif[jnp.abs(offs)] * data[jj]
61
+ vals = jnp.where(mask, 0.0, vals)
62
+ DX = DX.at[:, ii].set(vals)
63
+
64
+ # --- 2) ifft along time axis ---
65
+ DXt = jnp.fft.ifft(DX, n=Nt, axis=1)
66
+
67
+ # --- 3) unpack into wave (Nt, Nf) ---
68
+ wave = jnp.zeros((Nt, Nf), float_dtype)
69
+ sqrt2 = jnp.sqrt(2.0)
70
+
71
+ # 3a) DC into col 0, even rows
72
+ evens = jnp.arange(0, Nt, 2)
73
+ wave = wave.at[evens, 0].set(jnp.real(DXt[0, evens]) * sqrt2)
74
+
75
+ # 3b) Nyquist into col 0, odd rows
76
+ odds = jnp.arange(1, Nt, 2)
77
+ wave = wave.at[odds, 0].set(jnp.real(DXt[Nf, evens]) * sqrt2)
78
+
79
+ # 3c) intermediate bins 1...Nf−1
80
+ mids = jnp.arange(1, Nf) # [1...Nf-1]
81
+ Dt_mid = DXt[mids, :] # shape (Nf-1, Nt)
82
+ real_m = jnp.real(Dt_mid).T # (Nt, Nf-1)
83
+ imag_m = jnp.imag(Dt_mid).T # (Nt, Nf-1)
84
+
85
+ odd_f = (mids % 2) == 1 # shape (Nf-1,)
86
+ n_idx = jnp.arange(Nt)[:, None] # (Nt,1)
87
+ odd_n_f = ((n_idx + mids[None, :]) % 2) == 1 # (Nt, Nf-1)
88
+
89
+ mid_vals = jnp.where(
90
+ odd_n_f,
91
+ jnp.where(odd_f, -imag_m, imag_m),
92
+ jnp.where(odd_f, real_m, real_m),
46
93
  )
94
+ wave = wave.at[:, 1:Nf].set(mid_vals)
47
95
 
48
- # Initialize a 2D array to store intermediate FFT input values
49
- DX = jnp.zeros(
50
- (Nf, Nt), dtype=jnp.complex64
51
- ) # TODO: Check dtype -- is complex64 sufficient?
52
- DX = DX.at[:, Nt // 2].set(
53
- initial_values
54
- ) # Set initial values at the center of the transformation (2 sided FFT)
55
-
56
- # Compute time indices for all offsets around the midpoint
57
- j_range = jnp.arange(
58
- 1 - Nt // 2, Nt // 2
59
- ) # Time offsets (centered around zero)
60
- j = jnp.abs(j_range) # Absolute offset indices
61
- i = i_base + j_range # Time indices in output array
62
-
63
- # Compute conditions for edge cases
64
- cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
65
- cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
66
- cond3 = j[None, :] == 0 # Center of the transformation (no offset)
67
-
68
- # Compute frequency indices for the input data
69
- jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
70
- val = jnp.where(
71
- cond1 | cond2, 0.0, phif[j] * data[jj]
72
- ) # Wavelet filter application
73
- DX = DX.at[:, i].set(
74
- jnp.where(cond3, DX[:, i], val)
75
- ) # Update DX with computed values
76
- # At this point, DX contains the data FFT'd with the wavelet filter
77
- # (each row is a frequency bin, each column is a time bin)
78
-
79
- # Perform the inverse FFT along the time dimension
80
- DX_trans = ifft(DX, axis=1)
81
-
82
- # Fill the wavelet output array based on the inverse FFT results
83
- n_range = jnp.arange(Nt) # Time indices
84
- cond1 = (
85
- n_range[:, None] + f_bins[None, :]
86
- ) % 2 == 1 # Odd/even alternation
87
- cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
88
-
89
- # Assign real and imaginary parts based on conditions
90
- real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
91
- imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
92
- wave = jnp.where(cond1, imag_part.T, real_part.T)
93
-
94
- # Special cases for frequency bins 0 (DC) and Nf (Nyquist)
95
- wave = wave.at[::2, 0].set(
96
- jnp.real(DX_trans[0, ::2] * jnp.sqrt(2))
97
- ) # DC component
98
- wave = wave.at[1::2, -1].set(
99
- jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2))
100
- ) # Nyquist component
101
-
102
- # Return the wavelet-transformed array (transposed for freq-major layout)
103
- return wave.T # (Nt, Nf) -> (Nf, Nt)
96
+ return wave.T
@@ -2,6 +2,7 @@ from typing import Union
2
2
 
3
3
  import jax.numpy as jnp
4
4
 
5
+ from .... import backend
5
6
  from ....logger import logger
6
7
  from ....types import FrequencySeries, TimeSeries, Wavelet
7
8
  from ....types.wavelet_bins import _get_bins, _preprocess_bins
@@ -16,7 +17,6 @@ def from_time_to_wavelet(
16
17
  Nt: Union[int, None] = None,
17
18
  nx: float = 4.0,
18
19
  mult: int = 32,
19
- **kwargs,
20
20
  ) -> Wavelet:
21
21
  """Transforms time-domain data to wavelet-domain data.
22
22
 
@@ -45,7 +45,6 @@ def from_time_to_wavelet(
45
45
 
46
46
  """
47
47
  Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
48
- dt = timeseries.dt
49
48
  t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
50
49
 
51
50
  ND = Nf * Nt
@@ -73,7 +72,6 @@ def from_freq_to_wavelet(
73
72
  Nf: Union[int, None] = None,
74
73
  Nt: Union[int, None] = None,
75
74
  nx: float = 4.0,
76
- **kwargs,
77
75
  ) -> Wavelet:
78
76
  """Transforms frequency-domain data to wavelet-domain data.
79
77
 
@@ -100,7 +98,12 @@ def from_freq_to_wavelet(
100
98
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
101
99
  phif = jnp.array(phitilde_vec_norm(Nf, Nt, d=nx))
102
100
  wave = transform_wavelet_freq_helper(
103
- freqseries.data, Nf=Nf, Nt=Nt, phif=phif
101
+ freqseries.data,
102
+ Nf=Nf,
103
+ Nt=Nt,
104
+ phif=phif,
105
+ float_dtype=backend.float_dtype,
106
+ complex_dtype=backend.complex_dtype,
104
107
  )
105
-
106
- return Wavelet((2 / Nf) * wave * jnp.sqrt(2), time=t_bins, freq=f_bins)
108
+ factor = (2 / Nf) * jnp.sqrt(2)
109
+ return Wavelet(factor * wave, time=t_bins, freq=f_bins)
@@ -11,73 +11,56 @@ def inverse_wavelet_freq_helper(
11
11
  wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
12
12
  ) -> jnp.ndarray:
13
13
  """JAX vectorized function for inverse_wavelet_freq with corrected shapes and ranges."""
14
- # Transpose to match the NumPy version.
15
14
  wave_in = wave_in.T
16
15
  ND = Nf * Nt
17
16
 
18
- # Allocate prefactor2s for each m (shape: (Nf+1, Nt)).
19
17
  m_range = jnp.arange(Nf + 1)
20
18
  prefactor2s = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex128)
21
19
  n_range = jnp.arange(Nt)
22
20
 
23
- # m == 0 case
21
+ # Handle m=0 and m=Nf cases
24
22
  prefactor2s = prefactor2s.at[0].set(
25
23
  2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
26
24
  )
27
-
28
- # m == Nf case
29
25
  prefactor2s = prefactor2s.at[Nf].set(
30
26
  2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
31
27
  )
32
28
 
33
- # Other m cases: use meshgrid for vectorization.
29
+ # Handle middle m cases
34
30
  m_mid = m_range[1:Nf]
35
- # Create grids: n_grid (columns) and m_grid (rows)
36
- n_grid, m_grid = jnp.meshgrid(n_range, m_mid)
37
- # Index the transposed wave_in using (n, m) as in the NumPy version.
31
+ n_grid, m_grid = jnp.meshgrid(n_range, m_mid, indexing="ij")
38
32
  val = wave_in[n_grid, m_grid]
39
- # Apply the alternating multiplier based on (n+m) parity.
40
33
  mult2 = jnp.where((n_grid + m_grid) % 2, -1j, 1)
41
- prefactor2s = prefactor2s.at[1:Nf].set(mult2 * val)
34
+ prefactor2s = prefactor2s.at[1:Nf].set((mult2 * val).T)
42
35
 
43
- # Apply FFT along axis 1 for all m.
44
36
  fft_prefactor2s = fft(prefactor2s, axis=1)
45
37
 
46
- # Allocate the result array with corrected shape.
47
38
  res = jnp.zeros(ND // 2 + 1, dtype=jnp.complex128)
48
39
 
49
- # Unpacking for m == 0 and m == Nf cases:
40
+ # Unpack for m=0 and m=Nf
50
41
  i_ind_range = jnp.arange(Nt // 2)
51
- i_0 = jnp.abs(i_ind_range) # for m == 0: i = i_ind_range
52
- i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
42
+ i_0 = i_ind_range
53
43
  ind3_0 = (2 * i_0) % Nt
54
- ind3_Nf = (2 * i_Nf) % Nt
55
-
56
44
  res = res.at[i_0].add(fft_prefactor2s[0, ind3_0] * phif[i_ind_range])
45
+
46
+ i_Nf = jnp.abs(Nf * (Nt // 2) - i_ind_range)
47
+ ind3_Nf = (2 * i_Nf) % Nt
57
48
  res = res.at[i_Nf].add(fft_prefactor2s[Nf, ind3_Nf] * phif[i_ind_range])
58
- # Special case for m == Nf (ensure the Nyquist frequency is updated correctly)
49
+
59
50
  special_index = jnp.abs(Nf * (Nt // 2) - (Nt // 2))
60
51
  res = res.at[special_index].add(fft_prefactor2s[Nf, 0] * phif[Nt // 2])
61
52
 
62
- # Unpacking for m in (1, ..., Nf-1)
53
+ # Unpack for middle m values
63
54
  m_mid = m_range[1:Nf]
64
- # Use range [0, Nt//2) to match the loop in NumPy version.
65
55
  i_ind_range_mid = jnp.arange(Nt // 2)
66
- # Create meshgrid for vectorized computation.
67
56
  m_grid_mid, i_ind_grid_mid = jnp.meshgrid(
68
57
  m_mid, i_ind_range_mid, indexing="ij"
69
58
  )
70
-
71
- # Compute indices i1 and i2 following the NumPy logic.
72
59
  i1 = (Nt // 2) * m_grid_mid - i_ind_grid_mid
73
60
  i2 = (Nt // 2) * m_grid_mid + i_ind_grid_mid
74
- # Compute the wrapped indices for FFT results.
75
61
  ind31 = i1 % Nt
76
62
  ind32 = i2 % Nt
77
63
 
78
- # Update result array using vectorized adds.
79
- # Note: You might need to adjust this further if your target res shape is non-trivial,
80
- # because here we assume that i1 and i2 indices fall within the allocated result shape.
81
64
  res = res.at[i1].add(
82
65
  fft_prefactor2s[m_grid_mid, ind31] * phif[i_ind_grid_mid]
83
66
  )
@@ -85,4 +68,10 @@ def inverse_wavelet_freq_helper(
85
68
  fft_prefactor2s[m_grid_mid, ind32] * phif[i_ind_grid_mid]
86
69
  )
87
70
 
71
+ # Correct the center points for middle m's
72
+ center_indices = (Nt // 2) * m_mid
73
+ fft_indices = center_indices % Nt
74
+ values = fft_prefactor2s[m_mid, fft_indices] * phif[0]
75
+ res = res.at[center_indices].set(values)
76
+
88
77
  return res
@@ -12,7 +12,12 @@ logger = logging.getLogger("pywavelet")
12
12
 
13
13
 
14
14
  def transform_wavelet_freq_helper(
15
- data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
15
+ data: np.ndarray,
16
+ Nf: int,
17
+ Nt: int,
18
+ phif: np.ndarray,
19
+ float_dtype: np.dtype = np.float64,
20
+ complex_dtype: np.dtype = np.complex128,
16
21
  ) -> np.ndarray:
17
22
  """
18
23
  Forward wavelet transform helper using the fast wavelet domain transform,
@@ -36,14 +41,14 @@ def transform_wavelet_freq_helper(
36
41
  f_bin==0 and f_bin==Nf are both stored in column 0.
37
42
  """
38
43
  logger.debug(
39
- f"[NUMPY TRANSFORM] Input types [data:{type(data)}, phif:{type(phif)}]"
44
+ f"[NUMPY TRANSFORM] Input types [data:{type(data)},{data.dtype}, phif:{type(phif)},{phif.dtype}]"
40
45
  )
41
- wave = np.zeros((Nt, Nf), dtype=np.float64)
42
- DX = np.zeros(Nt, dtype=np.complex128)
46
+ wave = np.zeros((Nt, Nf), dtype=float_dtype)
47
+ DX = np.zeros(Nt, dtype=complex_dtype)
43
48
  # Create a copy of the input data (if needed).
44
49
  freq_strain = data.copy()
45
50
  __core(Nf, Nt, DX, freq_strain, phif, wave)
46
- return wave
51
+ return wave.T
47
52
 
48
53
 
49
54
  @njit()
@@ -64,7 +69,10 @@ def __core(
64
69
  for f_bin in range(0, Nf + 1):
65
70
  __fill_wave_1(f_bin, Nt, Nf, DX, data, phif)
66
71
  # Use rocket-fft's ifft (which is JIT-able) instead of np.fft.ifft.
67
- DX_trans = np.fft.ifft(DX, Nt)
72
+ DX_trans = np.fft.ifft(
73
+ DX,
74
+ Nt,
75
+ )
68
76
  __fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
69
77
 
70
78
 
@@ -2,6 +2,7 @@ from typing import Union
2
2
 
3
3
  import numpy as np
4
4
 
5
+ from .... import backend
5
6
  from ....logger import logger
6
7
  from ....types import FrequencySeries, TimeSeries, Wavelet
7
8
  from ....types.wavelet_bins import _get_bins, _preprocess_bins
@@ -112,7 +113,15 @@ def from_freq_to_wavelet(
112
113
  """
113
114
  Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
114
115
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
115
- phif = phitilde_vec_norm(Nf, Nt, d=nx)
116
- wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
117
-
118
- return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
116
+ phif = np.array(phitilde_vec_norm(Nf, Nt, d=nx), dtype=backend.float_dtype)
117
+ data = np.array(freqseries.data, dtype=backend.complex_dtype)
118
+ wave = transform_wavelet_freq_helper(
119
+ data,
120
+ Nf,
121
+ Nt,
122
+ phif,
123
+ float_dtype=backend.float_dtype,
124
+ complex_dtype=backend.complex_dtype,
125
+ )
126
+ factor = (backend.float_dtype)((2 / Nf) * np.sqrt(2))
127
+ return Wavelet(factor * wave, time=t_bins, freq=f_bins)
@@ -1,25 +1,43 @@
1
- import numpy as np
2
- from jaxtyping import Array, Float
1
+ """
2
+ This module contains functions to compute the Fourier transform of the
3
+ wavelet function and its normalization. The wavelet function is defined
4
+ in the frequency domain and is used to transform time-domain data into
5
+ the wavelet domain.
6
+
7
+ Everything in this module is returned as an np.float64 array.
8
+ """
3
9
 
4
- from ..backend import betainc, ifft, xp
10
+ from functools import lru_cache
11
+
12
+ import numpy as np
13
+ from jaxtyping import Float64
14
+ from numpy.fft import ifft
15
+ from scipy.special import betainc
5
16
 
6
17
  __all__ = ["phitilde_vec_norm", "phi_vec", "omega"]
7
18
 
8
19
 
9
- def omega(Nf: int, Nt: int) -> Float[Array, " Nt//2+1"]:
20
+ @lru_cache(maxsize=None)
21
+ def omega(Nf: int, Nt: int) -> Float64[np.ndarray, "{Nt}//2+1"]:
10
22
  """Get the angular frequencies of the time domain signal."""
11
23
  df = 2 * np.pi / (Nf * Nt)
12
- return df * np.arange(0, Nt // 2 + 1)
24
+ return df * np.arange(0, Nt // 2 + 1, dtype=np.float64)
13
25
 
14
26
 
15
- def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> Float[Array, " Nt//2+1"]:
27
+ @lru_cache(maxsize=None)
28
+ def phitilde_vec_norm(
29
+ Nf: int, Nt: int, d: float
30
+ ) -> Float64[np.ndarray, "{Nt}//2+1"]:
16
31
  """Normalize phitilde for inverse frequency domain transform."""
17
32
  omegas = omega(Nf, Nt)
18
33
  _phi_t = _phitilde_vec(omegas, Nf, d) * np.sqrt(np.pi)
19
- return xp.array(_phi_t)
34
+ return np.array(_phi_t)
20
35
 
21
36
 
22
- def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
37
+ @lru_cache(maxsize=None)
38
+ def phi_vec(
39
+ Nf: int, d: float = 4.0, q: int = 16
40
+ ) -> Float64[np.ndarray, "2*{q}*{Nf}"]:
23
41
  """get time domain phi as fourier transform of _phitilde_vec
24
42
  q: number of Nf bins over which the window extends?
25
43
 
@@ -29,7 +47,6 @@ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
29
47
  half_K = q * Nf # xp.int64(K/2)
30
48
 
31
49
  dom = 2 * np.pi / K # max frequency is K/2*dom = pi/dt = OM
32
-
33
50
  DX = np.zeros(K, dtype=np.complex128)
34
51
 
35
52
  # zero frequency
@@ -51,12 +68,12 @@ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> Float[Array, " 2*q*Nf"]:
51
68
  nrm = np.sqrt(2.0) / np.sqrt(K / dom) # *xp.linalg.norm(phi)
52
69
 
53
70
  phi *= nrm
54
- return xp.array(phi)
71
+ return np.array(phi)
55
72
 
56
73
 
57
74
  def _phitilde_vec(
58
- omega: Float[Array, " Nt//2+1"], Nf: int, d: float = 4.0
59
- ) -> Float[Array, " Nt//2+1"]:
75
+ omega: Float64[np.ndarray, "dim"], Nf: int, d: float = 4.0
76
+ ) -> Float64[np.ndarray, "dim"]:
60
77
  """Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
61
78
 
62
79
  Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
@@ -67,8 +84,6 @@ def _phitilde_vec(
67
84
 
68
85
  Where nu_d = normalized incomplete beta function
69
86
 
70
-
71
-
72
87
  Parameters
73
88
  ----------
74
89
  ω : xp.ndarray
@@ -93,22 +108,22 @@ def _phitilde_vec(
93
108
  if B <= 0:
94
109
  raise ValueError("B must be greater than 0")
95
110
 
96
- phi = np.zeros(omega.size)
111
+ phi = np.zeros(omega.size, dtype=np.float64)
97
112
  mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B) # Minor changes
98
113
  vd = (np.pi / 2.0) * _nu_d(omega[mask], A, B, d=d) # different from paper
99
- phi[mask] = inverse_sqrt_dOmega * xp.cos(vd)
114
+ phi[mask] = inverse_sqrt_dOmega * np.cos(vd)
100
115
  phi[np.abs(omega) < A] = inverse_sqrt_dOmega
101
116
  return phi
102
117
 
103
118
 
104
119
  def _nu_d(
105
- omega: Float[Array, " Nt//2+1"], A: float, B: float, d: float = 4.0
106
- ) -> Float[Array, " Nt//2+1"]:
120
+ omega: Float64[np.ndarray, "dim"], A: float, B: float, d: float = 4.0
121
+ ) -> Float64[np.ndarray, "dim"]:
107
122
  """Compute the normalized incomplete beta function.
108
123
 
109
124
  Parameters
110
125
  ----------
111
- ω : xp.ndarray
126
+ omega : np.ndarray
112
127
  Array of angular frequencies
113
128
  A : float
114
129
  Lower bound for the beta function
@@ -119,7 +134,7 @@ def _nu_d(
119
134
 
120
135
  Returns
121
136
  -------
122
- xp.ndarray
137
+ np.ndarray
123
138
  Array of ν_d values
124
139
 
125
140
  scipy.special.betainc
pywavelet/types/common.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from typing import Callable, Tuple, Union
2
2
 
3
- from ..backend import xp
3
+ from ..backend import complex_dtype, float_dtype, xp
4
4
  from ..logger import logger
5
5
 
6
6
 
@@ -303,7 +303,7 @@ def plot_periodogram(
303
303
  flow = np.min(np.abs(freq))
304
304
  ax.set_xlabel("Frequency [Hz]")
305
305
  ax.set_ylabel("Periodigram")
306
- # ax.set_xlim(left=flow, right=nyquist_frequency / 2)
306
+ ax.set_xlim(left=flow, right=nyquist_frequency / 2)
307
307
  return ax.figure, ax
308
308
 
309
309
 
@@ -1,6 +1,7 @@
1
1
  from typing import Optional, Tuple, Union
2
2
 
3
3
  import matplotlib.pyplot as plt
4
+ from astropy.utils.metadata.utils import dtype
4
5
  from scipy.signal import butter, sosfiltfilt
5
6
  from scipy.signal.windows import tukey
6
7
 
@@ -3,7 +3,7 @@ from typing import List, Tuple
3
3
  import matplotlib.pyplot as plt
4
4
  import numpy as np
5
5
 
6
- from .common import fmt_timerange, is_documented_by, xp
6
+ from .common import float_dtype, fmt_timerange, is_documented_by, xp
7
7
  from .plotting import plot_wavelet_grid, plot_wavelet_trend
8
8
  from .wavelet_bins import compute_bins
9
9
 
@@ -70,7 +70,9 @@ class Wavelet:
70
70
  A Wavelet object with zero-filled data array.
71
71
  """
72
72
  Nf, Nt = len(freq), len(time)
73
- return cls(data=xp.zeros((Nf, Nt)), time=time, freq=freq)
73
+ return cls(
74
+ data=xp.zeros((Nf, Nt), dtype=float_dtype), time=time, freq=freq
75
+ )
74
76
 
75
77
  @classmethod
76
78
  def zeros(cls, Nf: int, Nt: int, T: float) -> "Wavelet":
@@ -1,6 +1,6 @@
1
1
  from typing import Tuple, Union
2
2
 
3
- from ..backend import xp
3
+ from ..backend import complex_dtype, float_dtype, xp
4
4
  from .frequencyseries import FrequencySeries
5
5
  from .timeseries import TimeSeries
6
6
 
@@ -35,13 +35,7 @@ def _get_bins(
35
35
  ) -> Tuple[xp.ndarray, xp.ndarray]:
36
36
  T = data.duration
37
37
  t_bins, f_bins = compute_bins(Nf, Nt, T)
38
-
39
- # N = len(data)
40
- # fs = N / T
41
- # assert delta_f == fmax / Nf, f"delta_f={delta_f} != fmax/Nf={fmax/Nf}"
42
-
43
38
  t_bins += data.t0
44
-
45
39
  return t_bins, f_bins
46
40
 
47
41
 
@@ -51,6 +45,6 @@ def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[xp.ndarray, xp.ndarray]:
51
45
  """
52
46
  delta_T = T / Nt
53
47
  delta_F = 1 / (2 * delta_T)
54
- t_bins = xp.arange(0, Nt) * delta_T
55
- f_bins = xp.arange(0, Nf) * delta_F
48
+ t_bins = xp.arange(0, Nt, dtype=float_dtype) * delta_T
49
+ f_bins = xp.arange(0, Nf, dtype=float_dtype) * delta_F
56
50
  return t_bins, f_bins
@@ -0,0 +1,6 @@
1
+ from .analysis import (
2
+ compute_likelihood,
3
+ compute_snr,
4
+ evolutionary_psd_from_stationary_psd,
5
+ noise_weighted_inner_product,
6
+ )
@@ -3,7 +3,7 @@ from typing import Union
3
3
  import numpy as np
4
4
  from scipy.interpolate import interp1d
5
5
 
6
- from .types import FrequencySeries, TimeSeries, Wavelet, WaveletMask
6
+ from ..types import FrequencySeries, TimeSeries, Wavelet, WaveletMask
7
7
 
8
8
  DATA_TYPE = Union[TimeSeries, FrequencySeries, Wavelet]
9
9
 
File without changes