pywavelet 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.4'
16
- __version_tuple__ = version_tuple = (0, 2, 4)
15
+ __version__ = version = '0.2.5'
16
+ __version_tuple__ = version_tuple = (0, 2, 5)
@@ -1,12 +1,14 @@
1
+ from functools import partial
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  from jax import jit
4
6
  from jax.numpy.fft import rfft
5
- from functools import partial
6
7
 
7
- @partial(jit, static_argnames=('Nf', 'Nt', 'mult'))
8
+
9
+ @partial(jit, static_argnames=("Nf", "Nt", "mult"))
8
10
  def transform_wavelet_time_helper(
9
- data: jnp.ndarray, phi: jnp.ndarray, Nf: int, Nt: int, mult: int
11
+ data: jnp.ndarray, phi: jnp.ndarray, Nf: int, Nt: int, mult: int
10
12
  ) -> jnp.ndarray:
11
13
  """Helper function to do the wavelet transform in the time domain using JAX"""
12
14
  # Define constants
@@ -35,17 +37,23 @@ def transform_wavelet_time_helper(
35
37
  even_indices = jnp.nonzero(even_mask, size=even_mask.shape[0])[0]
36
38
 
37
39
  # Update wave for m=0 using even time bins
38
- wave = wave.at[even_indices, 0].set(jnp.real(wdata_trans[even_indices, 0]) / jnp.sqrt(2))
39
- wave = wave.at[even_indices + 1, 0].set(jnp.real(wdata_trans[even_indices, Nf * mult]) / jnp.sqrt(2))
40
+ wave = wave.at[even_indices, 0].set(
41
+ jnp.real(wdata_trans[even_indices, 0]) / jnp.sqrt(2)
42
+ )
43
+ wave = wave.at[even_indices + 1, 0].set(
44
+ jnp.real(wdata_trans[even_indices, Nf * mult]) / jnp.sqrt(2)
45
+ )
40
46
 
41
47
  # Handle other cases (j > 0) using vectorized operations
42
48
  j_range = jnp.arange(1, Nf)
43
- odd_condition = ((time_bins[:, None] + j_range[None, :]) % 2 == 1)
49
+ odd_condition = (time_bins[:, None] + j_range[None, :]) % 2 == 1
44
50
 
45
51
  wave = wave.at[:, 1:].set(
46
- jnp.where(odd_condition,
47
- -jnp.imag(wdata_trans[:, j_range * mult]),
48
- jnp.real(wdata_trans[:, j_range * mult]))
52
+ jnp.where(
53
+ odd_condition,
54
+ -jnp.imag(wdata_trans[:, j_range * mult]),
55
+ jnp.real(wdata_trans[:, j_range * mult]),
56
+ )
49
57
  )
50
58
 
51
- return wave.T
59
+ return wave.T
@@ -62,10 +62,10 @@ def from_time_to_wavelet(
62
62
 
63
63
  mult = min(mult, Nt // 2) # make sure K isn't bigger than ND
64
64
  phi = jnp.array(phi_vec(Nf, d=nx, q=mult))
65
- wave = transform_wavelet_time_helper(timeseries.data, Nf=Nf, Nt=Nt, phi=phi, mult=mult)
66
- return Wavelet(
67
- wave* jnp.sqrt(2), time=t_bins, freq=f_bins
65
+ wave = transform_wavelet_time_helper(
66
+ timeseries.data, Nf=Nf, Nt=Nt, phi=phi, mult=mult
68
67
  )
68
+ return Wavelet(wave * jnp.sqrt(2), time=t_bins, freq=f_bins)
69
69
 
70
70
 
71
71
  def from_freq_to_wavelet(
@@ -98,13 +98,9 @@ def from_freq_to_wavelet(
98
98
  """
99
99
  Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
100
100
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
101
- phif = jnp.array(phitilde_vec_norm(Nf, Nt, d=nx))
102
- wave = transform_wavelet_freq_helper(
101
+ phif = jnp.array(phitilde_vec_norm(Nf, Nt, d=nx))
102
+ wave = transform_wavelet_freq_helper(
103
103
  freqseries.data, Nf=Nf, Nt=Nt, phif=phif
104
104
  )
105
105
 
106
- return Wavelet(
107
- (2 / Nf) * wave * jnp.sqrt(2),
108
- time=t_bins,
109
- freq=f_bins
110
- )
106
+ return Wavelet((2 / Nf) * wave * jnp.sqrt(2), time=t_bins, freq=f_bins)
@@ -1,14 +1,14 @@
1
+ from functools import partial
2
+
1
3
  import jax
2
4
  import jax.numpy as jnp
3
5
  from jax import jit
4
6
  from jax.numpy.fft import fft
5
7
 
6
- from functools import partial
7
-
8
8
 
9
- @partial(jit, static_argnames=('Nf', 'Nt'))
9
+ @partial(jit, static_argnames=("Nf", "Nt"))
10
10
  def inverse_wavelet_freq_helper(
11
- wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
11
+ wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
12
12
  ) -> jnp.ndarray:
13
13
  """JAX vectorized function for inverse_wavelet_freq"""
14
14
  wave_in = wave_in.T
@@ -20,10 +20,14 @@ def inverse_wavelet_freq_helper(
20
20
  n_range = jnp.arange(Nt)
21
21
 
22
22
  # m == 0 case
23
- prefactor2s = prefactor2s.at[0].set(2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0])
23
+ prefactor2s = prefactor2s.at[0].set(
24
+ 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
25
+ )
24
26
 
25
27
  # m == Nf case
26
- prefactor2s = prefactor2s.at[Nf].set(2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0])
28
+ prefactor2s = prefactor2s.at[Nf].set(
29
+ 2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
30
+ )
27
31
 
28
32
  # Other m cases
29
33
  m_mid = m_range[1:Nf]
@@ -67,4 +71,4 @@ def inverse_wavelet_freq_helper(
67
71
  res = res.at[i1].add(fft_prefactor2s[m_grid, ind31] * phif[i_ind_grid])
68
72
  res = res.at[i2].add(fft_prefactor2s[m_grid, ind32] * phif[i_ind_grid])
69
73
 
70
- return res
74
+ return res
@@ -1,10 +1,9 @@
1
1
  from .forward import from_freq_to_wavelet, from_time_to_wavelet
2
2
  from .inverse import from_wavelet_to_freq, from_wavelet_to_time
3
3
 
4
-
5
4
  __all__ = [
6
5
  "from_wavelet_to_time",
7
6
  "from_wavelet_to_freq",
8
7
  "from_time_to_wavelet",
9
8
  "from_freq_to_wavelet",
10
- ]
9
+ ]
@@ -113,7 +113,7 @@ def from_freq_to_wavelet(
113
113
  """
114
114
  Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
115
115
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
116
- phif = phitilde_vec_norm(Nf, Nt, d=nx)
116
+ phif = phitilde_vec_norm(Nf, Nt, d=nx)
117
117
  wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
118
118
 
119
119
  return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
 
3
- from ...phi_computer import phi_vec, phitilde_vec_norm
4
3
  from ....types import FrequencySeries, TimeSeries, Wavelet
4
+ from ...phi_computer import phi_vec, phitilde_vec_norm
5
5
  from .to_freq import inverse_wavelet_freq_helper_fast
6
6
  from .to_time import inverse_wavelet_time_helper_fast
7
7
 
@@ -84,7 +84,7 @@ def from_wavelet_to_freq(
84
84
  to ensure the proper backwards transformation.
85
85
  """
86
86
 
87
- phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx)
87
+ phif = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, d=nx)
88
88
  freq_data = inverse_wavelet_freq_helper_fast(
89
89
  wave_in.data, phif, wave_in.Nf, wave_in.Nt
90
90
  )
@@ -1,10 +1,7 @@
1
- from ..backend import xp, PI, betainc, ifft
1
+ from ..backend import PI, betainc, ifft, xp
2
2
 
3
3
 
4
-
5
- def phitilde_vec(
6
- omega: xp.ndarray, Nf: int, d: float = 4.0
7
- ) -> xp.ndarray:
4
+ def phitilde_vec(omega: xp.ndarray, Nf: int, d: float = 4.0) -> xp.ndarray:
8
5
  """Compute phi_tilde(omega_i) array, nx is filter steepness, defaults to 4.
9
6
 
10
7
  Eq 11 of https://arxiv.org/pdf/2009.00043.pdf (Cornish et al. 2020)
@@ -78,7 +75,7 @@ def __nu_d(
78
75
  return betainc(d, d, x) / betainc(d, d, 1)
79
76
 
80
77
 
81
- def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
78
+ def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
82
79
  """Normalize phitilde for inverse frequency domain transform."""
83
80
 
84
81
  # Calculate the frequency values
@@ -86,7 +83,7 @@ def phitilde_vec_norm(Nf: int, Nt: int, d: float) -> xp.ndarray:
86
83
  omegas = 2 * xp.pi / ND * xp.arange(0, Nt // 2 + 1)
87
84
 
88
85
  # Calculate the unnormalized phitilde (u_phit)
89
- u_phit = phitilde_vec(omegas, Nf, d)
86
+ u_phit = phitilde_vec(omegas, Nf, d)
90
87
 
91
88
  # Normalize the phitilde
92
89
  normalising_factor = PI ** (-1 / 2) # Ollie's normalising factor
@@ -129,13 +126,9 @@ def phi_vec(Nf: int, d: float = 4.0, q: int = 16) -> xp.ndarray:
129
126
 
130
127
  DX = DX.copy()
131
128
  # postive frequencies
132
- DX[1 : half_K + 1] = phitilde_vec(
133
- dom * xp.arange(1, half_K + 1), Nf, d
134
- )
129
+ DX[1 : half_K + 1] = phitilde_vec(dom * xp.arange(1, half_K + 1), Nf, d)
135
130
  # negative frequencies
136
- DX[half_K + 1 :] = phitilde_vec(
137
- -dom * xp.arange(half_K - 1, 0, -1), Nf, d
138
- )
131
+ DX[half_K + 1 :] = phitilde_vec(-dom * xp.arange(half_K - 1, 0, -1), Nf, d)
139
132
  DX = K * ifft(DX, K)
140
133
 
141
134
  phi = xp.zeros(K)
pywavelet/types/common.py CHANGED
@@ -1,6 +1,7 @@
1
- from typing import Tuple, Union, Callable
2
- from ..logger import logger
1
+ from typing import Callable, Tuple, Union
2
+
3
3
  from ..backend import xp
4
+ from ..logger import logger
4
5
 
5
6
 
6
7
  def _len_check(d):
@@ -8,7 +9,7 @@ def _len_check(d):
8
9
  logger.warning(f"Data length {len(d)} is suggested to be a power of 2")
9
10
 
10
11
 
11
- def is_documented_by(original:Callable):
12
+ def is_documented_by(original: Callable):
12
13
  def wrapper(target):
13
14
  target.__doc__ = original.__doc__
14
15
  return target
@@ -2,7 +2,7 @@ from typing import Optional, Tuple, Union
2
2
 
3
3
  import matplotlib.pyplot as plt
4
4
 
5
- from ..backend import xp, irfft
5
+ from ..backend import irfft, xp
6
6
  from .common import fmt_pow2, fmt_time, is_documented_by
7
7
  from .plotting import plot_freqseries, plot_periodogram
8
8
 
@@ -4,15 +4,9 @@ import matplotlib.pyplot as plt
4
4
  from scipy.signal import butter, sosfiltfilt
5
5
  from scipy.signal.windows import tukey
6
6
 
7
+ from ..backend import rfft, rfftfreq, xp
7
8
  from ..logger import logger
8
- from .common import (
9
- fmt_pow2,
10
- fmt_time,
11
- fmt_timerange,
12
- is_documented_by,
13
- )
14
- from ..backend import xp, rfftfreq, rfft
15
-
9
+ from .common import fmt_pow2, fmt_time, fmt_timerange, is_documented_by
16
10
  from .plotting import plot_spectrogram, plot_timeseries
17
11
 
18
12
  __all__ = ["TimeSeries"]
@@ -272,9 +266,9 @@ class TimeSeries:
272
266
  sos = butter(
273
267
  bandpass_order, Wn=fmin, btype="highpass", output="sos", fs=self.fs
274
268
  )
275
- window = tukey(self.ND, alpha=tukey_window_alpha)
276
269
  data = self.data.copy()
277
- data = sosfiltfilt(sos, data * window)
270
+ data = sosfiltfilt(sos, data)
271
+ data = data * tukey(self.ND, alpha=tukey_window_alpha)
278
272
  return TimeSeries(data, self.time)
279
273
 
280
274
  def __copy__(self):
@@ -1,7 +1,6 @@
1
1
  from typing import Tuple, Union
2
2
 
3
3
  from ..backend import xp
4
-
5
4
  from .frequencyseries import FrequencySeries
6
5
  from .timeseries import TimeSeries
7
6
 
pywavelet/utils.py CHANGED
@@ -71,6 +71,8 @@ def compute_likelihood(
71
71
  p = psd.data
72
72
  if mask is not None:
73
73
  m = mask.mask
74
+ # convert mask to numbers -- 0 for False, 1 for True
75
+ m = m.astype(int)
74
76
  d, h, p = d * m, h * m, p * m
75
77
 
76
78
  return -0.5 * np.nansum((d - h) ** 2 / p)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.4
3
+ Version: 0.2.5
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -76,10 +76,10 @@ First set up a conda environment with the latest version of python.
76
76
  Test code
77
77
  ---------
78
78
 
79
- Locate directory /tests from root directory. run
79
+ Locate directory /tests from root directory. run
80
80
 
81
81
  .. code-block::
82
82
 
83
83
  $ pytest .
84
84
 
85
- Hopefully everything should run fine.
85
+ Hopefully everything should run fine.
@@ -1,35 +1,35 @@
1
1
  pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
- pywavelet/_version.py,sha256=4gL0W4-u58XR5lRLpeoIPrGhcewTk0-527de6uTNmkg,411
2
+ pywavelet/_version.py,sha256=8n5F0z2KUdXF-kMA1VxiH6w0elyZUUWXosPRzNz_3I0,411
3
3
  pywavelet/backend.py,sha256=SmpgIBHvTO1rtIAQQN_zpVB8i6R-x23FNKJG6_JlrNs,666
4
4
  pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
5
- pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
5
+ pywavelet/utils.py,sha256=FqQ6V41WGHMbLC4wv_1xnwHjOPDVSWnG78sAeqbYtYU,1994
6
6
  pywavelet/transforms/__init__.py,sha256=EYX8glRWojYbrjtbgrjS4vigYTRi7FOtIV3D1UwI5fY,604
7
- pywavelet/transforms/phi_computer.py,sha256=ppFSGJwtNnO2flaiok9ms3WXlAxGQikvA7eNfLgriNQ,4461
7
+ pywavelet/transforms/phi_computer.py,sha256=d39RfU-7Zbdo4GnmvrV21fXnFLEtf1ccPX_mBHg96Lw,4423
8
8
  pywavelet/transforms/jax/__init__.py,sha256=D_f-JgFAzOIJ-EuQZhTMziD4MT6lVWS3XV9s51Cu7Kg,335
9
9
  pywavelet/transforms/jax/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
10
10
  pywavelet/transforms/jax/forward/from_freq.py,sha256=tKEdqPyEvX8ZKVQf16wGxN3d6gkcjm_RtAHQuWHUzy4,3764
11
- pywavelet/transforms/jax/forward/from_time.py,sha256=xNeoZq54B6Gi3TdTTYLr_euaFeJcwpms-lSyCG53AdI,1726
12
- pywavelet/transforms/jax/forward/main.py,sha256=mm0R4m0pXcnzZB0jCckAc4ynG8STH5mldCmHyyU_PGo,3091
11
+ pywavelet/transforms/jax/forward/from_time.py,sha256=4RZ8-ah0qOMP20i3-xThVWddxa1QTCvZKnGpNAJbb0g,1765
12
+ pywavelet/transforms/jax/forward/main.py,sha256=7gpHUycEclDwlb6KpLqUZoIkhJjPH0sBITBGVqepYAI,3061
13
13
  pywavelet/transforms/jax/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
14
14
  pywavelet/transforms/jax/inverse/main.py,sha256=-HVOOBsYo3GJvGNCsQLbNPnt9s14JvbB2bGAd9LOr3A,1647
15
- pywavelet/transforms/jax/inverse/to_freq.py,sha256=ASNARcDBJQr4EizAP_77e5ai36iPwP6hzfvwGbZQ6BM,2295
16
- pywavelet/transforms/numpy/__init__.py,sha256=qFLpGpW3VJSbDp2JpD0Gx7PdwDjH-wrW_aO84ASkIgA,255
15
+ pywavelet/transforms/jax/inverse/to_freq.py,sha256=y27Lx797Hcmg3gZLS1IpoyRalmJk54my91farO6G64M,2320
16
+ pywavelet/transforms/numpy/__init__.py,sha256=1Ibsup9UwMajeZ9NCQ4BN15qZTeJ_EHkgGu8XNFdA18,255
17
17
  pywavelet/transforms/numpy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
18
18
  pywavelet/transforms/numpy/forward/from_freq.py,sha256=JmJyjrNSb64WnpP50VZRt0BICP64iZJP5QAZTZoexkw,2675
19
19
  pywavelet/transforms/numpy/forward/from_time.py,sha256=-Y6VEKwDCYBAHAjLdO46vT-6alpM5fXTgTZ_xkYxqA8,2381
20
- pywavelet/transforms/numpy/forward/main.py,sha256=3y-YCnhpvN7M4N7xy3CVts7n3QQPwDcJ6mkklX1QbFM,3973
20
+ pywavelet/transforms/numpy/forward/main.py,sha256=1YCwBuhWAPFyah-XHT5Q98rMjVtoSnaEvH3sl6NWWmA,3972
21
21
  pywavelet/transforms/numpy/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
22
- pywavelet/transforms/numpy/inverse/main.py,sha256=-11U5tnDizIssHk824rpYrzbJRl6WFpH6K2KKpVpDnU,2989
22
+ pywavelet/transforms/numpy/inverse/main.py,sha256=rzPsukcvOgN4vqtTpVxNSb-60KT-L-1AYr3_OSeDulk,2988
23
23
  pywavelet/transforms/numpy/inverse/to_freq.py,sha256=so_TDbwdS1N8sd1QcpeAEkI10XFDtoFJGohtD4YulZM,2809
24
24
  pywavelet/transforms/numpy/inverse/to_time.py,sha256=w5vmImdsb_4YeInZtXh0llsThLTxS0tmYDlNGJ-IUew,5080
25
25
  pywavelet/types/__init__.py,sha256=5YptzQvYBnRfC8N5lpOBf9I1lzpJ0pw0QMnvIcwP3YI,122
26
- pywavelet/types/common.py,sha256=aIcYq-0KOLHnPQjrVbVmw_TQ3Xm5a7xA30rSgwt3rk4,1275
27
- pywavelet/types/frequencyseries.py,sha256=hrtLaIUaRrqXw8l00yFe2tPJwpksDa_4n1z6R8XSPPQ,7531
26
+ pywavelet/types/common.py,sha256=_SMmXLrRO0Nw_A7Oa6C10kZAbj8jq9agXx7tMDjnYJg,1277
27
+ pywavelet/types/frequencyseries.py,sha256=tAbZr0vEBCe0MwH7ZjaK00UVupjRNxvjoW9LCMsiiMo,7531
28
28
  pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
29
- pywavelet/types/timeseries.py,sha256=u35bIqFo3QdlQRBEu6maeWA7DePS11LQ6WMiLjZPcWo,9456
29
+ pywavelet/types/timeseries.py,sha256=sataMW4BPFqi23h_NBZ_U9-Svuo9pLXVRmUJI6KTXG0,9430
30
30
  pywavelet/types/wavelet.py,sha256=uHJzTS2ZXTRr7I7NHWv3qNjknSBhQUpcED3jM6ti7UM,13587
31
- pywavelet/types/wavelet_bins.py,sha256=GoQGKeZlPc-KbYY7LoxAhB-HI4diHpPcTABBXRfUTLA,1459
32
- pywavelet-0.2.4.dist-info/METADATA,sha256=Thhhz8I2XTKr0mVuf09UpcvjeEGKUnVUX0jxENu6gEQ,2241
33
- pywavelet-0.2.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
34
- pywavelet-0.2.4.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
35
- pywavelet-0.2.4.dist-info/RECORD,,
31
+ pywavelet/types/wavelet_bins.py,sha256=gBjhWwfjcbbSnbGZVMNUeFFVUo2DVxJS4abDUVCL7ts,1458
32
+ pywavelet-0.2.5.dist-info/METADATA,sha256=jfqO4IRwiYgyhfGvE9Wrbpw4t5P6L8kuLgAA7W82Jls,2239
33
+ pywavelet-0.2.5.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
34
+ pywavelet-0.2.5.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
35
+ pywavelet-0.2.5.dist-info/RECORD,,