pywavelet 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.1.1'
16
- __version_tuple__ = version_tuple = (0, 1, 1)
15
+ __version__ = version = '0.2.0'
16
+ __version_tuple__ = version_tuple = (0, 2, 0)
pywavelet/backend.py ADDED
@@ -0,0 +1,26 @@
1
+ import os
2
+
3
+ try:
4
+ import jax
5
+
6
+ jax_available = True
7
+
8
+
9
+ except ImportError:
10
+ jax_available = False
11
+
12
+ use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
13
+
14
+ if use_jax:
15
+ import jax.numpy as xp # type: ignore
16
+ from jax.scipy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
17
+ from jax.scipy.special import betainc # type: ignore
18
+
19
+
20
+ else:
21
+ import numpy as xp # type: ignore
22
+ from numpy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
23
+ from scipy.special import betainc # type: ignore
24
+
25
+
26
+ PI = xp.pi
@@ -1,9 +1,17 @@
1
- from .forward import from_freq_to_wavelet, from_time_to_wavelet
2
- from .inverse import from_wavelet_to_freq, from_wavelet_to_time
1
+ from .numpy import (
2
+ from_wavelet_to_time,
3
+ from_wavelet_to_freq,
4
+ from_time_to_wavelet,
5
+ from_freq_to_wavelet,
6
+ )
7
+ from .phi_computer import phi_vec, phitilde_vec_norm, phitilde_vec
3
8
 
4
9
  __all__ = [
5
10
  "from_wavelet_to_time",
6
11
  "from_wavelet_to_freq",
7
12
  "from_time_to_wavelet",
8
13
  "from_freq_to_wavelet",
14
+ "phitilde_vec_norm",
15
+ "phi_vec",
16
+ "phitilde_vec",
9
17
  ]
File without changes
@@ -0,0 +1,6 @@
1
+ from .main import from_freq_to_wavelet, from_time_to_wavelet
2
+ from ....logger import logger
3
+
4
+ logger.warning("JAX SUBPACKAGE NOT YET TESTED")
5
+
6
+ __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -0,0 +1,56 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from functools import partial
4
+ from jax import jit
5
+ from jax.numpy.fft import ifft
6
+
7
+ @partial(jit, static_argnames=('Nf', 'Nt'))
8
+ def transform_wavelet_freq_helper(
9
+ data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
10
+ ) -> jnp.ndarray:
11
+ # Initially all wrk being done in time-rws, freq-cols
12
+ wave = jnp.zeros((Nt, Nf))
13
+ f_bins = jnp.arange(Nf)
14
+
15
+ i_base = Nt // 2
16
+ jj_base = f_bins * Nt // 2
17
+
18
+ initial_values = jnp.where(
19
+ (f_bins == 0) | (f_bins == Nf),
20
+ phif[0] * data[f_bins * Nt // 2] / 2.0,
21
+ phif[0] * data[f_bins * Nt // 2]
22
+ )
23
+
24
+ DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)
25
+ DX = DX.at[:, Nt // 2].set(initial_values)
26
+
27
+ j_range = jnp.arange(1 - Nt // 2, Nt // 2)
28
+ j = jnp.abs(j_range)
29
+ i = i_base + j_range
30
+
31
+ cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)
32
+ cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)
33
+ cond3 = j[None, :] == 0
34
+
35
+ jj = jj_base[:, None] + j_range[None, :]
36
+ val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])
37
+ DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))
38
+
39
+ # Vectorized ifft
40
+ DX_trans = ifft(DX, axis=1)
41
+
42
+ # Vectorized __fill_wave_2_jax
43
+ n_range = jnp.arange(Nt)
44
+ cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1
45
+ cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # shape: (Nf, 1)
46
+
47
+ real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
48
+ imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
49
+
50
+ wave = jnp.where(cond1, imag_part.T, real_part.T)
51
+
52
+ ## Special cases for f_bin 0 and Nf
53
+ wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))
54
+ wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))
55
+
56
+ return wave.T
@@ -0,0 +1,51 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from jax import jit
4
+ from jax.numpy.fft import rfft
5
+ from functools import partial
6
+
7
+ @partial(jit, static_argnames=('Nf', 'Nt', 'mult'))
8
+ def transform_wavelet_time_helper(
9
+ data: jnp.ndarray, phi: jnp.ndarray, Nf: int, Nt: int, mult: int
10
+ ) -> jnp.ndarray:
11
+ """Helper function to do the wavelet transform in the time domain using JAX"""
12
+ # Define constants
13
+ ND = Nf * Nt
14
+ K = mult * 2 * Nf
15
+
16
+ # Pad the data with K extra values
17
+ data_pad = jnp.concatenate((data, data[:K]))
18
+
19
+ # Generate time bin indices
20
+ time_bins = jnp.arange(Nt)
21
+ jj_base = (time_bins[:, None] * Nf - K // 2) % ND
22
+ jj = (jj_base + jnp.arange(K)[None, :]) % ND
23
+
24
+ # Apply the window (phi) to the data
25
+ wdata = data_pad[jj] * phi[None, :]
26
+
27
+ # Perform FFT on the windowed data
28
+ wdata_trans = rfft(wdata, axis=1)
29
+
30
+ # Initialize the wavelet transform result
31
+ wave = jnp.zeros((Nt, Nf))
32
+
33
+ # Handle m=0 case for even time bins
34
+ even_mask = (time_bins % 2 == 0) & (time_bins < Nt - 1)
35
+ even_indices = jnp.nonzero(even_mask, size=even_mask.shape[0])[0]
36
+
37
+ # Update wave for m=0 using even time bins
38
+ wave = wave.at[even_indices, 0].set(jnp.real(wdata_trans[even_indices, 0]) / jnp.sqrt(2))
39
+ wave = wave.at[even_indices + 1, 0].set(jnp.real(wdata_trans[even_indices, Nf * mult]) / jnp.sqrt(2))
40
+
41
+ # Handle other cases (j > 0) using vectorized operations
42
+ j_range = jnp.arange(1, Nf)
43
+ odd_condition = ((time_bins[:, None] + j_range[None, :]) % 2 == 1)
44
+
45
+ wave = wave.at[:, 1:].set(
46
+ jnp.where(odd_condition,
47
+ -jnp.imag(wdata_trans[:, j_range * mult]),
48
+ jnp.real(wdata_trans[:, j_range * mult]))
49
+ )
50
+
51
+ return wave.T
@@ -0,0 +1,110 @@
1
+ from typing import Union
2
+
3
+ import jax.numpy as jnp
4
+
5
+ from ....logger import logger
6
+ from ....types import FrequencySeries, TimeSeries, Wavelet
7
+ from ....types.wavelet_bins import _get_bins, _preprocess_bins
8
+ from ...phi_computer import phi_vec, phitilde_vec_norm
9
+ from .from_freq import transform_wavelet_freq_helper
10
+ from .from_time import transform_wavelet_time_helper
11
+
12
+
13
+ def from_time_to_wavelet(
14
+ timeseries: TimeSeries,
15
+ Nf: Union[int, None] = None,
16
+ Nt: Union[int, None] = None,
17
+ nx: float = 4.0,
18
+ mult: int = 32,
19
+ **kwargs,
20
+ ) -> Wavelet:
21
+ """Transforms time-domain data to wavelet-domain data.
22
+
23
+ Warning: there can be significant leakage if mult is too small and the
24
+ transform is only approximately exact if mult=Nt/2
25
+
26
+ Parameters
27
+ ----------
28
+ timeseries : TimeSeries
29
+ Time domain data
30
+ Nf : int
31
+ Number of frequency bins
32
+ Nt : int
33
+ Number of time bins
34
+ nx : float, optional
35
+ Number of standard deviations for the phi_vec, by default 4.
36
+ mult : int, optional
37
+ Number of time bins to use for the wavelet transform, by default 32
38
+ **kwargs:
39
+ Additional keyword arguments passed to the Wavelet.from_data constructor.
40
+
41
+ Returns
42
+ -------
43
+ Wavelet
44
+ Wavelet domain data
45
+
46
+ """
47
+ Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
48
+ dt = timeseries.dt
49
+ t_bins, f_bins = _get_bins(timeseries, Nf, Nt)
50
+
51
+ ND = Nf * Nt
52
+
53
+ if len(timeseries) != ND:
54
+ logger.warning(
55
+ f"len(freqseries)={len(timeseries)} != Nf*Nt={ND}. Truncating to freqseries[:{ND}]"
56
+ )
57
+ timeseries = timeseries[:ND]
58
+ if mult > Nt / 2:
59
+ logger.warning(
60
+ f"mult={mult} is too large for Nt={Nt}. This may lead to bogus results."
61
+ )
62
+
63
+ mult = min(mult, Nt // 2) # make sure K isn't bigger than ND
64
+ phi = jnp.array(phi_vec(Nf, d=nx, q=mult))
65
+ wave = transform_wavelet_time_helper(timeseries.data, Nf=Nf, Nt=Nt, phi=phi, mult=mult)
66
+ return Wavelet(
67
+ wave* jnp.sqrt(2), time=t_bins, freq=f_bins
68
+ )
69
+
70
+
71
+ def from_freq_to_wavelet(
72
+ freqseries: FrequencySeries,
73
+ Nf: Union[int, None] = None,
74
+ Nt: Union[int, None] = None,
75
+ nx: float = 4.0,
76
+ **kwargs,
77
+ ) -> Wavelet:
78
+ """Transforms frequency-domain data to wavelet-domain data.
79
+
80
+ Parameters
81
+ ----------
82
+ freqseries : FrequencySeries
83
+ Frequency domain data
84
+ Nf : int
85
+ Number of frequency bins
86
+ Nt : int
87
+ Number of time bins
88
+ nx : float, optional
89
+ Number of standard deviations for the phi_vec, by default 4.
90
+ **kwargs:
91
+ Additional keyword arguments passed to the Wavelet.from_data constructor.
92
+
93
+ Returns
94
+ -------
95
+ Wavelet
96
+ Wavelet domain data
97
+
98
+ """
99
+ Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
100
+ t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
101
+ phif = jnp.array(phitilde_vec_norm(Nf, Nt, d=nx))
102
+ wave = transform_wavelet_freq_helper(
103
+ freqseries.data, Nf=Nf, Nt=Nt, phif=phif
104
+ )
105
+
106
+ return Wavelet(
107
+ (2 / Nf) * wave * jnp.sqrt(2),
108
+ time=t_bins,
109
+ freq=f_bins
110
+ )
@@ -0,0 +1,10 @@
1
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
2
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
3
+
4
+
5
+ __all__ = [
6
+ "from_wavelet_to_time",
7
+ "from_wavelet_to_freq",
8
+ "from_time_to_wavelet",
9
+ "from_freq_to_wavelet",
10
+ ]
@@ -2,7 +2,6 @@
2
2
 
3
3
  import numpy as np
4
4
  from numba import njit
5
- from numpy import fft
6
5
 
7
6
 
8
7
  def transform_wavelet_freq_helper(
@@ -2,10 +2,10 @@ from typing import Union
2
2
 
3
3
  import numpy as np
4
4
 
5
- from ...logger import logger
6
- from ...types import FrequencySeries, TimeSeries, Wavelet
7
- from ...types.wavelet_bins import _get_bins, _preprocess_bins
8
- from ..phi_computer import phi_vec, phitilde_vec_norm
5
+ from ....logger import logger
6
+ from ....types import FrequencySeries, TimeSeries, Wavelet
7
+ from ....types.wavelet_bins import _get_bins, _preprocess_bins
8
+ from ...phi_computer import phi_vec, phitilde_vec_norm
9
9
  from .from_freq import transform_wavelet_freq_helper
10
10
  from .from_time import transform_wavelet_time_helper
11
11
 
@@ -72,7 +72,7 @@ def from_time_to_wavelet(
72
72
  )
73
73
 
74
74
  mult = min(mult, Nt // 2) # Ensure mult is not larger than ND/2
75
- phi = phi_vec(Nf, dt=dt, d=nx, q=mult)
75
+ phi = phi_vec(Nf, d=nx, q=mult)
76
76
  wave = transform_wavelet_time_helper(timeseries.data, Nf, Nt, phi, mult).T
77
77
  return Wavelet(wave * np.sqrt(2), time=t_bins, freq=f_bins)
78
78
 
@@ -113,8 +113,7 @@ def from_freq_to_wavelet(
113
113
  """
114
114
  Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
115
115
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
116
- dt = freqseries.dt
117
- phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)
116
+ phif = phitilde_vec_norm(Nf, Nt, d=nx)
118
117
  wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
119
118
 
120
119
  return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
 
3
- from ...transforms.phi_computer import phi_vec, phitilde_vec_norm
4
- from ...types import FrequencySeries, TimeSeries, Wavelet
3
+ from ...phi_computer import phi_vec, phitilde_vec_norm
4
+ from ....types import FrequencySeries, TimeSeries, Wavelet
5
5
  from .to_freq import inverse_wavelet_freq_helper_fast
6
6
  from .to_time import inverse_wavelet_time_helper_fast
7
7
 
@@ -46,7 +46,7 @@ def from_wavelet_to_time(
46
46
  """
47
47
 
48
48
  mult = min(mult, wave_in.Nt // 2) # Ensure mult is not larger than ND/2
49
- phi = phi_vec(wave_in.Nf, d=nx, q=mult, dt=dt) / 2
49
+ phi = phi_vec(wave_in.Nf, d=nx, q=mult) / 2
50
50
  h_t = inverse_wavelet_time_helper_fast(
51
51
  wave_in.data.T, phi, wave_in.Nf, wave_in.Nt, mult
52
52
  )
@@ -84,7 +84,7 @@ def from_wavelet_to_freq(
84
84
  to ensure the proper backwards transformation.
85
85
  """
86
86
 
87
- phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, dt=dt, d=nx)
87
+ phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx)
88
88
  freq_data = inverse_wavelet_freq_helper_fast(
89
89
  wave_in.data, phif, wave_in.Nf, wave_in.Nt
90
90
  )
@@ -1,13 +1,10 @@
1
- import numpy as np
2
- from numpy import fft
3
- from scipy.special import betainc
1
+ from ..backend import xp, PI, betainc, ifft
4
2
 
5
- PI = np.pi
6
3
 
7
4
 
8
5
  def phitilde_vec(
9
- omega: np.ndarray, Nf: int, dt: float, d: float = 4.0
10
- ) -> np.ndarray:
6
+ omega: xp.ndarray, Nf: int, d: float = 4.0
7
+ ) -> xp.ndarray:
11
8
  """Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
12
9
 
13
10
  Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
@@ -22,7 +19,7 @@ def phitilde_vec(
22
19
 
23
20
  Parameters
24
21
  ----------
25
- ω : np.ndarray
22
+ ω : xp.ndarray
26
23
  Array of angular frequencies
27
24
  Nf : int
28
25
  Number of frequency bins
@@ -31,35 +28,35 @@ def phitilde_vec(
31
28
 
32
29
  Returns
33
30
  -------
34
- np.ndarray
31
+ xp.ndarray
35
32
  Array of phi_tilde(omega_i) values
36
33
 
37
34
  """
38
35
  dF = 1.0 / (2 * Nf) # NOTE: missing 1/dt?
39
36
  dOmega = 2 * PI * dF # Near Eq 10 # 2 pi times DF
40
- inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)
37
+ inverse_sqrt_dOmega = 1.0 / xp.sqrt(dOmega)
41
38
 
42
39
  A = dOmega / 4
43
40
  B = dOmega - 2 * A # Cannot have B \leq 0.
44
41
  if B <= 0:
45
42
  raise ValueError("B must be greater than 0")
46
43
 
47
- phi = np.zeros(omega.size)
48
- mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B) # Minor changes
44
+ phi = xp.zeros(omega.size)
45
+ mask = (A <= xp.abs(omega)) & (xp.abs(omega) < A + B) # Minor changes
49
46
  vd = (PI / 2.0) * __nu_d(omega[mask], A, B, d=d) # different from paper
50
- phi[mask] = inverse_sqrt_dOmega * np.cos(vd)
51
- phi[np.abs(omega) < A] = inverse_sqrt_dOmega
47
+ phi[mask] = inverse_sqrt_dOmega * xp.cos(vd)
48
+ phi[xp.abs(omega) < A] = inverse_sqrt_dOmega
52
49
  return phi
53
50
 
54
51
 
55
52
  def __nu_d(
56
- omega: np.ndarray, A: float, B: float, d: float = 4.0
57
- ) -> np.ndarray:
53
+ omega: xp.ndarray, A: float, B: float, d: float = 4.0
54
+ ) -> xp.ndarray:
58
55
  """Compute the normalized incomplete beta function.
59
56
 
60
57
  Parameters
61
58
  ----------
62
- ω : np.ndarray
59
+ ω : xp.ndarray
63
60
  Array of angular frequencies
64
61
  A : float
65
62
  Lower bound for the beta function
@@ -70,29 +67,29 @@ def __nu_d(
70
67
 
71
68
  Returns
72
69
  -------
73
- np.ndarray
70
+ xp.ndarray
74
71
  Array of ν_d values
75
72
 
76
73
  scipy.special.betainc
77
74
  https://docs.scipy.org/doc/scipy-1.7.1/reference/reference/generated/scipy.special.betainc.html
78
75
 
79
76
  """
80
- x = (np.abs(omega) - A) / B
77
+ x = (xp.abs(omega) - A) / B
81
78
  return betainc(d, d, x) / betainc(d, d, 1)
82
79
 
83
80
 
84
- def phitilde_vec_norm(Nf: int, Nt: int, dt: float, d: float) -> np.ndarray:
81
+ def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
85
82
  """Normalize phitilde for inverse frequency domain transform."""
86
83
 
87
84
  # Calculate the frequency values
88
85
  ND = Nf * Nt
89
- omegas = 2 * np.pi / ND * np.arange(0, Nt // 2 + 1)
86
+ omegas = 2 * xp.pi / ND * xp.arange(0, Nt // 2 + 1)
90
87
 
91
88
  # Calculate the unnormalized phitilde (u_phit)
92
- u_phit = phitilde_vec(omegas, Nf, dt, d)
89
+ u_phit = phitilde_vec(omegas, Nf, d)
93
90
 
94
91
  # Normalize the phitilde
95
- normalising_factor = np.pi ** (-1 / 2) # Ollie's normalising factor
92
+ normalising_factor = PI ** (-1 / 2) # Ollie's normalising factor
96
93
 
97
94
  # Notes: this is the overall normalising factor that is different from Cornish's paper
98
95
  # It is the only way I can force this code to be consistent with our work in the
@@ -117,15 +114,15 @@ def phitilde_vec_norm(Nf: int, Nt: int, dt: float, d: float) -> np.ndarray:
117
114
  return u_phit / (normalising_factor)
118
115
 
119
116
 
120
- def phi_vec(Nf: int, dt, d: float = 4.0, q: int = 16) -> np.ndarray:
117
+ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> xp.ndarray:
121
118
  """get time domain phi as fourier transform of phitilde_vec"""
122
- insDOM = 1.0 / np.sqrt(PI / Nf)
119
+ insDOM = 1.0 / xp.sqrt(PI / Nf)
123
120
  K = q * 2 * Nf
124
- half_K = q * Nf # np.int64(K/2)
121
+ half_K = q * Nf # xp.int64(K/2)
125
122
 
126
123
  dom = 2 * PI / K # max frequency is K/2*dom = pi/dt = OM
127
124
 
128
- DX = np.zeros(K, dtype=np.complex128)
125
+ DX = xp.zeros(K, dtype=xp.complex128)
129
126
 
130
127
  # zero frequency
131
128
  DX[0] = insDOM
@@ -133,20 +130,20 @@ def phi_vec(Nf: int, dt, d: float = 4.0, q: int = 16) -> np.ndarray:
133
130
  DX = DX.copy()
134
131
  # postive frequencies
135
132
  DX[1 : half_K + 1] = phitilde_vec(
136
- dom * np.arange(1, half_K + 1), Nf, dt, d
133
+ dom * xp.arange(1, half_K + 1), Nf, d
137
134
  )
138
135
  # negative frequencies
139
136
  DX[half_K + 1 :] = phitilde_vec(
140
- -dom * np.arange(half_K - 1, 0, -1), Nf, dt, d
137
+ -dom * xp.arange(half_K - 1, 0, -1), Nf, d
141
138
  )
142
- DX = K * fft.ifft(DX, K)
139
+ DX = K * ifft(DX, K)
143
140
 
144
- phi = np.zeros(K)
145
- phi[0:half_K] = np.real(DX[half_K:K])
146
- phi[half_K:] = np.real(DX[0:half_K])
141
+ phi = xp.zeros(K)
142
+ phi[0:half_K] = xp.real(DX[half_K:K])
143
+ phi[half_K:] = xp.real(DX[0:half_K])
147
144
 
148
- nrm = np.sqrt(K / dom) # *np.linalg.norm(phi)
145
+ nrm = xp.sqrt(K / dom) # *xp.linalg.norm(phi)
149
146
 
150
- fac = np.sqrt(2.0) / nrm
147
+ fac = xp.sqrt(2.0) / nrm
151
148
  phi *= fac
152
149
  return phi
pywavelet/types/common.py CHANGED
@@ -1,9 +1,6 @@
1
- from typing import Tuple, Union
2
-
3
- import numpy as xp
4
- from numpy.fft import fft, irfft, rfft, rfftfreq # type: ignore
5
-
1
+ from typing import Tuple, Union, Callable
6
2
  from ..logger import logger
3
+ from ..backend import xp
7
4
 
8
5
 
9
6
  def _len_check(d):
@@ -11,7 +8,7 @@ def _len_check(d):
11
8
  logger.warning(f"Data length {len(d)} is suggested to be a power of 2")
12
9
 
13
10
 
14
- def is_documented_by(original):
11
+ def is_documented_by(original:Callable):
15
12
  def wrapper(target):
16
13
  target.__doc__ = original.__doc__
17
14
  return target
@@ -2,7 +2,8 @@ from typing import Optional, Tuple, Union
2
2
 
3
3
  import matplotlib.pyplot as plt
4
4
 
5
- from .common import fmt_pow2, fmt_time, irfft, is_documented_by, xp
5
+ from ..backend import xp, irfft
6
+ from .common import fmt_pow2, fmt_time, is_documented_by
6
7
  from .plotting import plot_freqseries, plot_periodogram
7
8
 
8
9
  __all__ = ["FrequencySeries"]
@@ -226,7 +227,7 @@ class FrequencySeries:
226
227
  Wavelet
227
228
  The corresponding wavelet.
228
229
  """
229
- from ..transforms.forward import from_freq_to_wavelet
230
+ from ..transforms import from_freq_to_wavelet
230
231
 
231
232
  return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
232
233
 
@@ -10,10 +10,9 @@ from .common import (
10
10
  fmt_time,
11
11
  fmt_timerange,
12
12
  is_documented_by,
13
- rfft,
14
- rfftfreq,
15
- xp,
16
13
  )
14
+ from ..backend import xp, rfftfreq, rfft
15
+
17
16
  from .plotting import plot_spectrogram, plot_timeseries
18
17
 
19
18
  __all__ = ["TimeSeries"]
@@ -1,7 +1,6 @@
1
- from typing import List, Optional, Tuple
1
+ from typing import List, Tuple
2
2
 
3
3
  import matplotlib.pyplot as plt
4
- import numpy as np
5
4
 
6
5
  from .common import fmt_timerange, is_documented_by, xp
7
6
  from .plotting import plot_wavelet_grid, plot_wavelet_trend
@@ -286,7 +285,7 @@ class Wavelet:
286
285
  TimeSeries
287
286
  A `TimeSeries` object representing the time-domain signal.
288
287
  """
289
- from ..transforms.inverse import from_wavelet_to_time
288
+ from ..transforms import from_wavelet_to_time
290
289
 
291
290
  return from_wavelet_to_time(self, dt=self.delta_t, nx=nx, mult=mult)
292
291
 
@@ -299,7 +298,7 @@ class Wavelet:
299
298
  FrequencySeries
300
299
  A `FrequencySeries` object representing the frequency-domain signal.
301
300
  """
302
- from ..transforms.inverse import from_wavelet_to_freq
301
+ from ..transforms import from_wavelet_to_freq
303
302
 
304
303
  return from_wavelet_to_freq(self, dt=self.delta_t, nx=nx)
305
304
 
@@ -315,8 +314,8 @@ class Wavelet:
315
314
 
316
315
  frange = ",".join([f"{f:.2e}" for f in (self.freq[0], self.freq[-1])])
317
316
  trange = fmt_timerange((self.t0, self.tend))
318
- Nfpow2 = int(np.log2(self.shape[0]))
319
- Ntpow2 = int(np.log2(self.shape[1]))
317
+ Nfpow2 = int(xp.log2(self.shape[0]))
318
+ Ntpow2 = int(xp.log2(self.shape[1]))
320
319
  shapef = f"NfxNf=[2^{Nfpow2}, 2^{Ntpow2}]"
321
320
  return f"Wavelet({shapef}, [{frange}]Hz, {trange})"
322
321
 
@@ -441,14 +440,16 @@ class WaveletMask(Wavelet):
441
440
  return self.data
442
441
 
443
442
  def __repr__(self):
444
- return f"WaveletMask({self.mask.shape}, {fmt_timerange(self.time)}, {self.freq})"
443
+ rpr = super().__repr__()
444
+ rpr = rpr.replace("Wavelet", "WaveletMask")
445
+ return rpr
445
446
 
446
447
  @classmethod
447
448
  def from_frange(
448
449
  cls, time_grid: xp.ndarray, freq_grid: xp.ndarray, frange: List[float]
449
450
  ):
450
451
  self = cls.zeros_from_grid(time_grid, freq_grid)
451
- self.mask[
452
+ self.data[
452
453
  (freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
453
454
  ] = True
454
455
  return self
@@ -1,6 +1,6 @@
1
1
  from typing import Tuple, Union
2
2
 
3
- import numpy as np
3
+ from ..backend import xp
4
4
 
5
5
  from .frequencyseries import FrequencySeries
6
6
  from .timeseries import TimeSeries
@@ -33,7 +33,7 @@ def _get_bins(
33
33
  data: Union[TimeSeries, FrequencySeries],
34
34
  Nf: Union[int, None] = None,
35
35
  Nt: Union[int, None] = None,
36
- ) -> Tuple[np.ndarray, np.ndarray]:
36
+ ) -> Tuple[xp.ndarray, xp.ndarray]:
37
37
  T = data.duration
38
38
  t_bins, f_bins = compute_bins(Nf, Nt, T)
39
39
 
@@ -46,12 +46,12 @@ def _get_bins(
46
46
  return t_bins, f_bins
47
47
 
48
48
 
49
- def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[np.ndarray, np.ndarray]:
49
+ def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[xp.ndarray, xp.ndarray]:
50
50
  """Get the bins for the wavelet transform
51
51
  Eq 4-6 in Wavelets paper
52
52
  """
53
53
  delta_T = T / Nt
54
54
  delta_F = 1 / (2 * delta_T)
55
- t_bins = np.arange(0, Nt) * delta_T
56
- f_bins = np.arange(0, Nf) * delta_F
55
+ t_bins = xp.arange(0, Nt) * delta_T
56
+ f_bins = xp.arange(0, Nf) * delta_F
57
57
  return t_bins, f_bins
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.1.1
3
+ Version: 0.2.0
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -32,4 +32,6 @@ Requires-Dist: isort; extra == "dev"
32
32
  Requires-Dist: mypy; extra == "dev"
33
33
  Requires-Dist: jupyter-book; extra == "dev"
34
34
  Requires-Dist: GitPython; extra == "dev"
35
+ Provides-Extra: jax
36
+ Requires-Dist: jax; extra == "jax"
35
37
 
@@ -0,0 +1,32 @@
1
+ pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
+ pywavelet/_version.py,sha256=H-qsvrxCpdhaQzyddR-yajEqI71hPxLa4KxzpP3uS1g,411
3
+ pywavelet/backend.py,sha256=k4pDi6f4cwNY6HsUIx1xfuga9f2wLnFr_FIb7Fs1Mds,553
4
+ pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
5
+ pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
6
+ pywavelet/transforms/__init__.py,sha256=uc1fKbBGQgEDafJHk6GEVCc0G_EXL5CtFTKCoFsewoM,381
7
+ pywavelet/transforms/phi_computer.py,sha256=ppFSGJwtNnO2flaiok9ms3WXlAxGQikvA7eNfLgriNQ,4461
8
+ pywavelet/transforms/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ pywavelet/transforms/jax/forward/__init__.py,sha256=Ki2RJCfkE9Zy59mqT3oEtGK9Ro9kS5kAz9duZFbxyZo,200
10
+ pywavelet/transforms/jax/forward/from_freq.py,sha256=PsUC7RfrN6pRWWkMSXYHk9z5lxCXW3DfF0m-Rd1GOBE,1785
11
+ pywavelet/transforms/jax/forward/from_time.py,sha256=xNeoZq54B6Gi3TdTTYLr_euaFeJcwpms-lSyCG53AdI,1726
12
+ pywavelet/transforms/jax/forward/main.py,sha256=mm0R4m0pXcnzZB0jCckAc4ynG8STH5mldCmHyyU_PGo,3091
13
+ pywavelet/transforms/numpy/__init__.py,sha256=qFLpGpW3VJSbDp2JpD0Gx7PdwDjH-wrW_aO84ASkIgA,255
14
+ pywavelet/transforms/numpy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
15
+ pywavelet/transforms/numpy/forward/from_freq.py,sha256=JmJyjrNSb64WnpP50VZRt0BICP64iZJP5QAZTZoexkw,2675
16
+ pywavelet/transforms/numpy/forward/from_time.py,sha256=-Y6VEKwDCYBAHAjLdO46vT-6alpM5fXTgTZ_xkYxqA8,2381
17
+ pywavelet/transforms/numpy/forward/main.py,sha256=3y-YCnhpvN7M4N7xy3CVts7n3QQPwDcJ6mkklX1QbFM,3973
18
+ pywavelet/transforms/numpy/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
19
+ pywavelet/transforms/numpy/inverse/main.py,sha256=-11U5tnDizIssHk824rpYrzbJRl6WFpH6K2KKpVpDnU,2989
20
+ pywavelet/transforms/numpy/inverse/to_freq.py,sha256=so_TDbwdS1N8sd1QcpeAEkI10XFDtoFJGohtD4YulZM,2809
21
+ pywavelet/transforms/numpy/inverse/to_time.py,sha256=w5vmImdsb_4YeInZtXh0llsThLTxS0tmYDlNGJ-IUew,5080
22
+ pywavelet/types/__init__.py,sha256=5YptzQvYBnRfC8N5lpOBf9I1lzpJ0pw0QMnvIcwP3YI,122
23
+ pywavelet/types/common.py,sha256=aIcYq-0KOLHnPQjrVbVmw_TQ3Xm5a7xA30rSgwt3rk4,1275
24
+ pywavelet/types/frequencyseries.py,sha256=hrtLaIUaRrqXw8l00yFe2tPJwpksDa_4n1z6R8XSPPQ,7531
25
+ pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
26
+ pywavelet/types/timeseries.py,sha256=u35bIqFo3QdlQRBEu6maeWA7DePS11LQ6WMiLjZPcWo,9456
27
+ pywavelet/types/wavelet.py,sha256=el48oyAfwtSw2tCQLUb85F9lKr0qMSRJPUmAUU8TS50,12552
28
+ pywavelet/types/wavelet_bins.py,sha256=GoQGKeZlPc-KbYY7LoxAhB-HI4diHpPcTABBXRfUTLA,1459
29
+ pywavelet-0.2.0.dist-info/METADATA,sha256=rg9LNZxrykv39lKIGuX65v7WzuY93_D0oY268mMe5iw,1362
30
+ pywavelet-0.2.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
31
+ pywavelet-0.2.0.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
32
+ pywavelet-0.2.0.dist-info/RECORD,,
@@ -1,25 +0,0 @@
1
- pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
- pywavelet/_version.py,sha256=PKIMyjdUACH4-ONvtunQCnYE2UhlMfp9su83e3HXl5E,411
3
- pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
4
- pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
5
- pywavelet/transforms/__init__.py,sha256=1Ibsup9UwMajeZ9NCQ4BN15qZTeJ_EHkgGu8XNFdA18,255
6
- pywavelet/transforms/phi_computer.py,sha256=vo1PK9Z70kKV-1lfyRoxWdhSYqwIgJK5CRCCJVei3xI,4545
7
- pywavelet/transforms/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
8
- pywavelet/transforms/forward/from_freq.py,sha256=wCiyLpzJE3rGxYjQBdXlwkxPIRYhQWjKq0C_8zYlmDk,2697
9
- pywavelet/transforms/forward/from_time.py,sha256=-Y6VEKwDCYBAHAjLdO46vT-6alpM5fXTgTZ_xkYxqA8,2381
10
- pywavelet/transforms/forward/main.py,sha256=Gfy0sp-woy_3ihKMzuk2WuZ7dRk-Mm6sp5dVpYrSvj4,4005
11
- pywavelet/transforms/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
12
- pywavelet/transforms/inverse/main.py,sha256=Q7wUaRjB1sgqdB7dniWQGbPTWYQNnIsIrYtjsaHJEdE,3012
13
- pywavelet/transforms/inverse/to_freq.py,sha256=so_TDbwdS1N8sd1QcpeAEkI10XFDtoFJGohtD4YulZM,2809
14
- pywavelet/transforms/inverse/to_time.py,sha256=w5vmImdsb_4YeInZtXh0llsThLTxS0tmYDlNGJ-IUew,5080
15
- pywavelet/types/__init__.py,sha256=5YptzQvYBnRfC8N5lpOBf9I1lzpJ0pw0QMnvIcwP3YI,122
16
- pywavelet/types/common.py,sha256=OSAW6GqLTgqJ-RYEv__XbzsfFd8AFo5w-ctXQ4XAFZo,1317
17
- pywavelet/types/frequencyseries.py,sha256=UqcE6UQfw5HZm4na2q9k-X-mfqO-BCiTAvGjaYpSrwc,7518
18
- pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
19
- pywavelet/types/timeseries.py,sha256=6DPO0xLi4Dq2srhJLmavFMf4fYIC3wwdbyMU7lMdjTo,9446
20
- pywavelet/types/wavelet.py,sha256=ptTEnq6nRZiW2x6g_NV_FuoeVNOsbNOu6caSxYDZNgk,12583
21
- pywavelet/types/wavelet_bins.py,sha256=SC9nhyigWvOfs2TbH8-Ck_iS1Mrz6sG-8EaIFYLIHuk,1453
22
- pywavelet-0.1.1.dist-info/METADATA,sha256=OgKTqqZQKfhCCcsYpcXnrut-2EAH42c3tCpCGc0PIDg,1307
23
- pywavelet-0.1.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
24
- pywavelet-0.1.1.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
25
- pywavelet-0.1.1.dist-info/RECORD,,