pywavelet 0.2.2__tar.gz → 0.2.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {pywavelet-0.2.2 → pywavelet-0.2.4}/.github/workflows/ci.yml +1 -0
  2. {pywavelet-0.2.2 → pywavelet-0.2.4}/CHANGELOG.rst +44 -0
  3. {pywavelet-0.2.2 → pywavelet-0.2.4}/PKG-INFO +1 -1
  4. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/_version.py +2 -2
  5. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/backend.py +9 -4
  6. pywavelet-0.2.4/src/pywavelet/transforms/__init__.py +28 -0
  7. pywavelet-0.2.4/src/pywavelet/transforms/jax/__init__.py +12 -0
  8. pywavelet-0.2.4/src/pywavelet/transforms/jax/forward/from_freq.py +97 -0
  9. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/jax/inverse/main.py +3 -4
  10. {pywavelet-0.2.2/src/pywavelet/transforms/jax → pywavelet-0.2.4/src/pywavelet/transforms/numpy}/forward/__init__.py +0 -3
  11. pywavelet-0.2.4/src/pywavelet/transforms/numpy/inverse/__init__.py +3 -0
  12. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/types/wavelet.py +40 -6
  13. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet.egg-info/PKG-INFO +1 -1
  14. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet.egg-info/SOURCES.txt +1 -0
  15. pywavelet-0.2.4/tests/test_jax.py +105 -0
  16. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_mask.py +12 -11
  17. pywavelet-0.2.2/src/pywavelet/transforms/__init__.py +0 -17
  18. pywavelet-0.2.2/src/pywavelet/transforms/jax/__init__.py +0 -0
  19. pywavelet-0.2.2/src/pywavelet/transforms/jax/forward/from_freq.py +0 -56
  20. pywavelet-0.2.2/src/pywavelet/transforms/jax/inverse/__init__.py +0 -0
  21. {pywavelet-0.2.2 → pywavelet-0.2.4}/.github/workflows/docs.yml +0 -0
  22. {pywavelet-0.2.2 → pywavelet-0.2.4}/.github/workflows/pypi.yml +0 -0
  23. {pywavelet-0.2.2 → pywavelet-0.2.4}/.gitignore +0 -0
  24. {pywavelet-0.2.2 → pywavelet-0.2.4}/.pre-commit-config.yaml +0 -0
  25. {pywavelet-0.2.2 → pywavelet-0.2.4}/CITATION.cff +0 -0
  26. {pywavelet-0.2.2 → pywavelet-0.2.4}/README.rst +0 -0
  27. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/_config.yml +0 -0
  28. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/_static/demo.gif +0 -0
  29. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/_toc.yml +0 -0
  30. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/api.rst +0 -0
  31. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/example.ipynb +0 -0
  32. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/index.rst +0 -0
  33. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/logo.png +0 -0
  34. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/roundtrip_freq.png +0 -0
  35. {pywavelet-0.2.2 → pywavelet-0.2.4}/docs/roundtrip_time.png +0 -0
  36. {pywavelet-0.2.2 → pywavelet-0.2.4}/pyproject.toml +0 -0
  37. {pywavelet-0.2.2 → pywavelet-0.2.4}/setup.cfg +0 -0
  38. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/__init__.py +0 -0
  39. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/logger.py +0 -0
  40. {pywavelet-0.2.2/src/pywavelet/transforms/numpy → pywavelet-0.2.4/src/pywavelet/transforms/jax}/forward/__init__.py +0 -0
  41. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/jax/forward/from_time.py +0 -0
  42. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/jax/forward/main.py +0 -0
  43. {pywavelet-0.2.2/src/pywavelet/transforms/numpy → pywavelet-0.2.4/src/pywavelet/transforms/jax}/inverse/__init__.py +0 -0
  44. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/jax/inverse/to_freq.py +0 -0
  45. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/numpy/__init__.py +0 -0
  46. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/numpy/forward/from_freq.py +0 -0
  47. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/numpy/forward/from_time.py +0 -0
  48. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/numpy/forward/main.py +0 -0
  49. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/numpy/inverse/main.py +0 -0
  50. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/numpy/inverse/to_freq.py +0 -0
  51. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/numpy/inverse/to_time.py +0 -0
  52. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/transforms/phi_computer.py +0 -0
  53. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/types/__init__.py +0 -0
  54. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/types/common.py +0 -0
  55. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/types/frequencyseries.py +0 -0
  56. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/types/plotting.py +0 -0
  57. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/types/timeseries.py +0 -0
  58. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/types/wavelet_bins.py +0 -0
  59. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet/utils.py +0 -0
  60. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet.egg-info/dependency_links.txt +0 -0
  61. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet.egg-info/requires.txt +0 -0
  62. {pywavelet-0.2.2 → pywavelet-0.2.4}/src/pywavelet.egg-info/top_level.txt +0 -0
  63. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/conftest.py +0 -0
  64. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_data/roundtrip_chirp_freq.npz +0 -0
  65. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_data/roundtrip_chirp_time.npz +0 -0
  66. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_data/roundtrip_pure_f0_freq.npz +0 -0
  67. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_data/roundtrip_sine_freq.npz +0 -0
  68. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_data/roundtrip_sine_time.npz +0 -0
  69. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_lnl.py +0 -0
  70. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_phi.py +0 -0
  71. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_psd.py +0 -0
  72. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_roundtrip_conversion.py +0 -0
  73. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_snr.py +0 -0
  74. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_timefreq_type.py +0 -0
  75. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_version.py +0 -0
  76. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/test_wavelet_plot.py +0 -0
  77. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/utils/__init__.py +0 -0
  78. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/utils/generate_data.py +0 -0
  79. {pywavelet-0.2.2 → pywavelet-0.2.4}/tests/utils/plotting.py +0 -0
@@ -28,6 +28,7 @@ jobs:
28
28
  run: |
29
29
  python -m pip install --upgrade pip
30
30
  python -m pip install -e .[dev]
31
+ python -m pip install -e .[jax]
31
32
  pre-commit install
32
33
 
33
34
  - name: pre-commit
@@ -5,16 +5,60 @@ CHANGELOG
5
5
  =========
6
6
 
7
7
 
8
+ .. _changelog-v0.2.4:
9
+
10
+ v0.2.4 (2025-01-24)
11
+ ===================
12
+
13
+ Unknown
14
+ -------
15
+
16
+ * Merge branch 'main' of github.com:avivajpeyi/pywavelet into main (`d2c84d9`_)
17
+
18
+ .. _d2c84d9: https://github.com/pywavelet/pywavelet/commit/d2c84d980b1701baf99e40ba6191cbd9336cfa59
19
+
20
+
21
+ .. _changelog-v0.2.3:
22
+
23
+ v0.2.3 (2025-01-24)
24
+ ===================
25
+
26
+ Bug Fixes
27
+ ---------
28
+
29
+ * fix: add masks to filter out gaps (`26fe40a`_)
30
+
31
+ * fix: add backend check for os.environ (`98c0818`_)
32
+
33
+ * fix: add test for jax (`1940394`_)
34
+
35
+ Chores
36
+ ------
37
+
38
+ * chore(release): 0.2.3 (`d067461`_)
39
+
40
+ .. _26fe40a: https://github.com/pywavelet/pywavelet/commit/26fe40ace80d5f9d598e1efeba2f8ca4a6f1043b
41
+ .. _98c0818: https://github.com/pywavelet/pywavelet/commit/98c0818078190d829a23734f932f1f93c9932167
42
+ .. _1940394: https://github.com/pywavelet/pywavelet/commit/194039437a3a9b3ada303d101b4e2573ab7d0afd
43
+ .. _d067461: https://github.com/pywavelet/pywavelet/commit/d0674615df328774a0d80eb224b5c503fbd8f332
44
+
45
+
8
46
  .. _changelog-v0.2.2:
9
47
 
10
48
  v0.2.2 (2025-01-23)
11
49
  ===================
12
50
 
51
+ Chores
52
+ ------
53
+
54
+ * chore(release): 0.2.2 (`eed5d68`_)
55
+
13
56
  Unknown
14
57
  -------
15
58
 
16
59
  * Merge branch 'main' of github.com:pywavelet/pywavelet (`e8e2115`_)
17
60
 
61
+ .. _eed5d68: https://github.com/pywavelet/pywavelet/commit/eed5d6864276fc5f90c4866749903e3e358df5ca
18
62
  .. _e8e2115: https://github.com/pywavelet/pywavelet/commit/e8e2115e797a5001f236ff027a14ef226151dcc1
19
63
 
20
64
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.2'
16
- __version_tuple__ = version_tuple = (0, 2, 2)
15
+ __version__ = version = '0.2.4'
16
+ __version_tuple__ = version_tuple = (0, 2, 4)
@@ -1,5 +1,7 @@
1
1
  import os
2
2
 
3
+ from .logger import logger
4
+
3
5
  try:
4
6
  import jax
5
7
 
@@ -13,14 +15,17 @@ use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
13
15
 
14
16
  if use_jax:
15
17
  import jax.numpy as xp # type: ignore
16
- from jax.scipy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
18
+ from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
17
19
  from jax.scipy.special import betainc # type: ignore
18
20
 
21
+ logger.info("Using JAX backend")
19
22
 
20
23
  else:
21
24
  import numpy as xp # type: ignore
22
- from numpy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
23
- from scipy.special import betainc # type: ignore
25
+ from numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
26
+ from scipy.special import betainc # type: ignore
27
+
28
+ logger.info("Using NumPy+numba backend")
24
29
 
25
30
 
26
- PI = xp.pi
31
+ PI = xp.pi
@@ -0,0 +1,28 @@
1
+ from ..backend import use_jax
2
+
3
+ if use_jax:
4
+ from .jax import (
5
+ from_freq_to_wavelet,
6
+ from_time_to_wavelet,
7
+ from_wavelet_to_freq,
8
+ from_wavelet_to_time,
9
+ )
10
+ else:
11
+ from .numpy import (
12
+ from_wavelet_to_time,
13
+ from_wavelet_to_freq,
14
+ from_time_to_wavelet,
15
+ from_freq_to_wavelet,
16
+ )
17
+
18
+ from .phi_computer import phi_vec, phitilde_vec, phitilde_vec_norm
19
+
20
+ __all__ = [
21
+ "from_wavelet_to_time",
22
+ "from_wavelet_to_freq",
23
+ "from_time_to_wavelet",
24
+ "from_freq_to_wavelet",
25
+ "phitilde_vec_norm",
26
+ "phi_vec",
27
+ "phitilde_vec",
28
+ ]
@@ -0,0 +1,12 @@
1
+ from ...logger import logger
2
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
3
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
+
5
+ logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
+
7
+ __all__ = [
8
+ "from_wavelet_to_time",
9
+ "from_wavelet_to_freq",
10
+ "from_time_to_wavelet",
11
+ "from_freq_to_wavelet",
12
+ ]
@@ -0,0 +1,97 @@
1
+ from functools import partial
2
+
3
+ import jax.numpy as jnp
4
+ from jax import jit
5
+ from jax.numpy.fft import ifft
6
+
7
+
8
+ @partial(jit, static_argnames=("Nf", "Nt"))
9
+ def transform_wavelet_freq_helper(
10
+ data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
11
+ ) -> jnp.ndarray:
12
+ """
13
+ Transforms input data from the frequency domain to the wavelet domain using a
14
+ pre-computed wavelet filter (`phif`) and performs an efficient inverse FFT.
15
+
16
+ Parameters:
17
+ - data (jnp.ndarray): 1D array representing the input data in the frequency domain.
18
+ - Nf (int): Number of frequency bins.
19
+ - Nt (int): Number of time bins. (Note: Nt * Nf == len(data))
20
+ - phif (jnp.ndarray): Pre-computed wavelet filter for frequency components.
21
+
22
+ Returns:
23
+ - wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
24
+ """
25
+
26
+ # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
27
+ wave = jnp.zeros((Nt, Nf))
28
+ f_bins = jnp.arange(Nf) # Frequency bin indices
29
+
30
+ # Compute base indices for time (i_base) and frequency (jj_base)
31
+ i_base = Nt // 2
32
+ jj_base = f_bins * Nt // 2
33
+
34
+ # Set initial values for the center of the transformation
35
+ initial_values = jnp.where(
36
+ (f_bins == 0)
37
+ | (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
38
+ phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
39
+ phif[0] * data[f_bins * Nt // 2],
40
+ )
41
+
42
+ # Initialize a 2D array to store intermediate FFT input values
43
+ DX = jnp.zeros(
44
+ (Nf, Nt), dtype=jnp.complex64
45
+ ) # TODO: Check dtype -- is complex64 sufficient?
46
+ DX = DX.at[:, Nt // 2].set(
47
+ initial_values
48
+ ) # Set initial values at the center of the transformation (2 sided FFT)
49
+
50
+ # Compute time indices for all offsets around the midpoint
51
+ j_range = jnp.arange(
52
+ 1 - Nt // 2, Nt // 2
53
+ ) # Time offsets (centered around zero)
54
+ j = jnp.abs(j_range) # Absolute offset indices
55
+ i = i_base + j_range # Time indices in output array
56
+
57
+ # Compute conditions for edge cases
58
+ cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
59
+ cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
60
+ cond3 = j[None, :] == 0 # Center of the transformation (no offset)
61
+
62
+ # Compute frequency indices for the input data
63
+ jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
64
+ val = jnp.where(
65
+ cond1 | cond2, 0.0, phif[j] * data[jj]
66
+ ) # Wavelet filter application
67
+ DX = DX.at[:, i].set(
68
+ jnp.where(cond3, DX[:, i], val)
69
+ ) # Update DX with computed values
70
+ # At this point, DX contains the data FFT'd with the wavelet filter
71
+ # (each row is a frequency bin, each column is a time bin)
72
+
73
+ # Perform the inverse FFT along the time dimension
74
+ DX_trans = ifft(DX, axis=1)
75
+
76
+ # Fill the wavelet output array based on the inverse FFT results
77
+ n_range = jnp.arange(Nt) # Time indices
78
+ cond1 = (
79
+ n_range[:, None] + f_bins[None, :]
80
+ ) % 2 == 1 # Odd/even alternation
81
+ cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
82
+
83
+ # Assign real and imaginary parts based on conditions
84
+ real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
85
+ imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
86
+ wave = jnp.where(cond1, imag_part.T, real_part.T)
87
+
88
+ # Special cases for frequency bins 0 (DC) and Nf (Nyquist)
89
+ wave = wave.at[::2, 0].set(
90
+ jnp.real(DX_trans[0, ::2] * jnp.sqrt(2))
91
+ ) # DC component
92
+ wave = wave.at[1::2, -1].set(
93
+ jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2))
94
+ ) # Nyquist component
95
+
96
+ # Return the wavelet-transformed array (transposed for freq-major layout)
97
+ return wave.T # (Nt, Nf) -> (Nf, Nt)
@@ -1,10 +1,9 @@
1
1
  import jax.numpy as jnp
2
2
  from jax.numpy.fft import rfftfreq
3
3
 
4
- from ...phi_computer import phi_vec, phitilde_vec_norm
5
4
  from ....types import FrequencySeries, TimeSeries, Wavelet
5
+ from ...phi_computer import phi_vec, phitilde_vec_norm
6
6
  from .to_freq import inverse_wavelet_freq_helper
7
- # from .inverse_wavelet_time_funcs import inverse_wavelet_time_helper
8
7
 
9
8
 
10
9
  def from_wavelet_to_time(
@@ -65,5 +64,5 @@ def from_wavelet_to_freq(
65
64
  -1 / 2
66
65
  ) # Normalise to get the proper backwards transformation
67
66
 
68
- freqs = rfftfreq(wave_in.ND*2, d=dt)[1:]
69
- return FrequencySeries(data=freq_data, freq=freqs)
67
+ freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
68
+ return FrequencySeries(data=freq_data, freq=freqs)
@@ -1,6 +1,3 @@
1
1
  from .main import from_freq_to_wavelet, from_time_to_wavelet
2
- from ....logger import logger
3
-
4
- logger.warning("JAX SUBPACKAGE NOT YET TESTED")
5
2
 
6
3
  __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -0,0 +1,3 @@
1
+ from .main import from_wavelet_to_freq, from_wavelet_to_time
2
+
3
+ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
@@ -1,6 +1,7 @@
1
1
  from typing import List, Tuple
2
2
 
3
3
  import matplotlib.pyplot as plt
4
+ import numpy as np
4
5
 
5
6
  from .common import fmt_timerange, is_documented_by, xp
6
7
  from .plotting import plot_wavelet_grid, plot_wavelet_trend
@@ -343,14 +344,18 @@ class Wavelet:
343
344
 
344
345
  def __mul__(self, other):
345
346
  """Element-wise multiplication of two Wavelet objects."""
346
- if isinstance(other, Wavelet):
347
- return Wavelet(
348
- data=self.data * other.data, time=self.time, freq=self.freq
349
- )
347
+ if isinstance(other, WaveletMask):
348
+ data = self.data.copy()
349
+ data[~other.mask] = np.nan
350
+ return Wavelet(data=data, time=self.time, freq=self.freq)
350
351
  elif isinstance(other, float):
351
352
  return Wavelet(
352
353
  data=self.data * other, time=self.time, freq=self.freq
353
354
  )
355
+ elif isinstance(other, WaveletMask):
356
+ return Wavelet(
357
+ data=self.data * other.data, time=self.time, freq=self.freq
358
+ )
354
359
 
355
360
  def __truediv__(self, other):
356
361
  """Element-wise division of two Wavelet objects."""
@@ -445,11 +450,40 @@ class WaveletMask(Wavelet):
445
450
  return rpr
446
451
 
447
452
  @classmethod
448
- def from_frange(
449
- cls, time_grid: xp.ndarray, freq_grid: xp.ndarray, frange: List[float]
453
+ def from_restrictions(
454
+ cls,
455
+ time_grid: xp.ndarray,
456
+ freq_grid: xp.ndarray,
457
+ frange: List[float],
458
+ tgaps: List[Tuple[float, float]] = [],
450
459
  ):
460
+ """
461
+ Create a WaveletMask object from restrictions on time and frequency.
462
+
463
+ Parameters
464
+ ----------
465
+ time_grid : xp.ndarray
466
+ Array of time points.
467
+ freq_grid : xp.ndarray
468
+ Array of corresponding frequency points.
469
+ frange : List[float]
470
+ Frequency range to include.
471
+ tgaps : List[Tuple[float, float]]
472
+ List of time gaps to exclude.
473
+
474
+ Returns
475
+ -------
476
+ WaveletMask
477
+ A WaveletMask object with the specified restrictions.
478
+ """
451
479
  self = cls.zeros_from_grid(time_grid, freq_grid)
452
480
  self.data[
453
481
  (freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
454
482
  ] = True
483
+
484
+ for tgap in tgaps:
485
+ self.data[
486
+ :, (time_grid >= tgap[0]) & (time_grid <= tgap[1])
487
+ ] = False
488
+ self.data = self.data.astype(bool)
455
489
  return self
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -53,6 +53,7 @@ src/pywavelet/types/timeseries.py
53
53
  src/pywavelet/types/wavelet.py
54
54
  src/pywavelet/types/wavelet_bins.py
55
55
  tests/conftest.py
56
+ tests/test_jax.py
56
57
  tests/test_lnl.py
57
58
  tests/test_mask.py
58
59
  tests/test_phi.py
@@ -0,0 +1,105 @@
1
+ import importlib
2
+ import os
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from conftest import monochromatic_wnm
7
+
8
+ from pywavelet.transforms import from_freq_to_wavelet, from_time_to_wavelet
9
+ from pywavelet.types import FrequencySeries, TimeSeries
10
+ from pywavelet.types.wavelet_bins import compute_bins
11
+ from pywavelet.utils import compute_snr, evolutionary_psd_from_stationary_psd
12
+
13
+
14
def test_toy_model_snr(plot_dir):
    """
    Check that the wavelet-domain SNR of a sinusoid matches the analytic SNR.

    The same frequency series is transformed with the NumPy and the JAX
    implementations of ``from_freq_to_wavelet``. Both wavelet-domain SNRs must
    agree with the frequency-domain optimal SNR to within 0.5. A side-by-side
    plot (NumPy, JAX, difference) is written to ``plot_dir``.
    """
    # Import the NumPy backend explicitly. ``pywavelet.transforms`` re-exports
    # whichever backend PYWAVELET_JAX selects, so with that variable set the
    # "numpy" part would otherwise compare JAX against JAX.
    from pywavelet.transforms.numpy import (
        from_freq_to_wavelet as np_from_freq_to_wavelet,
    )

    f0 = 20
    dt = 0.0125
    A = 2
    Nt = 128
    Nf = 256
    ND = Nt * Nf
    t = np.arange(0, ND) * dt
    PSD_AMP = 1

    ########################################
    # Part1: Analytical SNR calculation
    ########################################

    # Eq 21
    y = A * np.sin(2 * np.pi * f0 * t)  # Signal waveform we wish to test
    signal_timeseries = TimeSeries(y, t)
    signal_freq = signal_timeseries.to_frequencyseries()
    psd_freq = FrequencySeries(
        PSD_AMP * np.ones(len(signal_freq)), signal_freq.freq
    )
    snr = signal_freq.optimal_snr(psd_freq)

    ########################################
    # Part2: Wavelet domain (numpy)
    ########################################

    signal_wavelet = np_from_freq_to_wavelet(signal_freq, Nf=Nf, Nt=Nt)
    psd_wavelet = evolutionary_psd_from_stationary_psd(
        psd=psd_freq.data,
        psd_f=psd_freq.freq,
        f_grid=signal_wavelet.freq,
        t_grid=signal_wavelet.time,
        dt=dt,
    )
    wdm_snr = compute_snr(signal_wavelet, signal_wavelet, psd_wavelet)
    assert np.isclose(snr, wdm_snr, atol=0.5), f"{snr}!={wdm_snr}"

    ########################################
    # Part3: Wavelet domain (jax)
    ########################################

    from pywavelet.transforms.jax import (
        from_freq_to_wavelet as jax_from_freq_to_wavelet,
    )

    signal_wavelet_jax = jax_from_freq_to_wavelet(signal_freq, Nf=Nf, Nt=Nt)
    psd_wavelet_jax = evolutionary_psd_from_stationary_psd(
        psd=psd_freq.data,
        psd_f=psd_freq.freq,
        f_grid=signal_wavelet_jax.freq,
        t_grid=signal_wavelet_jax.time,
        dt=dt,
    )
    wdm_snr_jax = compute_snr(
        signal_wavelet_jax, signal_wavelet_jax, psd_wavelet_jax
    )
    assert np.isclose(snr, wdm_snr_jax, atol=0.5), f"{snr}!={wdm_snr_jax}"

    wdm_diff = signal_wavelet - signal_wavelet_jax

    ########################################
    # Part4: Plot
    ########################################

    fig, ax = plt.subplots(1, 3, figsize=(15, 6))
    signal_wavelet.plot(ax=ax[0])
    signal_wavelet_jax.plot(ax=ax[1])
    wdm_diff.plot(ax=ax[2])
    ax[0].set_title(f"Numpy SNR={wdm_snr:.2f}")
    ax[1].set_title(f"Jax SNR={wdm_snr_jax:.2f}")
    ax[2].set_title("Difference")
    fig.tight_layout()
    fig.savefig(f"{plot_dir}/jax_vs_np.png")
    # Release the figure so repeated test runs do not accumulate open figures.
    plt.close(fig)
88
+
89
+
90
def test_backend_loader():
    """
    Reloading ``pywavelet.backend`` honours the PYWAVELET_JAX variable.

    ``use_jax`` is ``jax_available and PYWAVELET_JAX == "1"``, so the test is
    skipped when jax is not importable. The caller's PYWAVELET_JAX value is
    restored, and the backend reloaded, even if an assertion fails. This keeps
    later tests unaffected.
    """
    import pytest

    pytest.importorskip("jax")

    import pywavelet.backend

    original = os.environ.get("PYWAVELET_JAX")
    try:
        os.environ["PYWAVELET_JAX"] = "1"
        importlib.reload(pywavelet.backend)
        assert pywavelet.backend.use_jax

        os.environ["PYWAVELET_JAX"] = "0"
        importlib.reload(pywavelet.backend)
        assert not pywavelet.backend.use_jax
    finally:
        if original is None:
            os.environ.pop("PYWAVELET_JAX", None)
        else:
            os.environ["PYWAVELET_JAX"] = original
        importlib.reload(pywavelet.backend)
@@ -13,25 +13,26 @@ def test_mask(plot_dir):
13
13
  psd = Wavelet(np.ones((h.Nf, h.Nt)), h.time, h.freq)
14
14
  assert compute_likelihood(d, h, psd) == 0
15
15
 
16
- mask = WaveletMask.from_frange(h.time, h.freq, [f0 - 0.5, f0 + 0.5])
16
+ mask = WaveletMask.from_restrictions(h.time, h.freq, [f0 - 0.5, f0 + 0.5])
17
+ dmasked = d * mask
17
18
  assert np.isclose(compute_likelihood(d, h, psd, mask), 0)
19
+ assert np.isclose(compute_likelihood(dmasked, h, psd), 0)
20
+ # number of nans in dmasked.data
21
+ assert np.isnan(dmasked.data).sum() != 0
18
22
 
19
- mask1 = WaveletMask.from_frange(h.time, h.freq, [f0 + 0.5, f0 + 1.5])
23
+ mask1 = WaveletMask.from_restrictions(
24
+ h.time, h.freq, [f0 + 0.5, f0 + 1.5], tgaps=[[1.7 * 60, 3.4 * 60]]
25
+ )
26
+ dmasked1 = d * mask1
20
27
  # assert np.isclose(compute_likelihood(d, h, psd_analysis, mask1), 0) == False
21
28
 
22
29
  # plt the 3 differnet datasets
23
30
  fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharex=True, sharey=True)
24
- kwgs = dict(
25
- cmap="viridis",
26
- aspect="auto",
27
- origin="lower",
28
- extent=[d.time[0], d.time[-1], d.freq[0], d.freq[-1]],
29
- )
30
- axes[0].imshow(d.data, **kwgs)
31
+ d.plot(ax=axes[0])
31
32
  axes[0].set_title("data")
32
- axes[1].imshow(d.data * mask.mask, **kwgs)
33
+ (d * mask).plot(ax=axes[1])
33
34
  axes[1].set_title("data*mask[f0-0.5, f0+0.5]")
34
- axes[2].imshow(d.data * mask1.mask, **kwgs)
35
+ (d * mask1).plot(ax=axes[2])
35
36
  axes[2].set_title("data*mask[f0+0.5, f0+1.5]")
36
37
  plt.tight_layout()
37
38
  fig.savefig(f"{plot_dir}/test_mask.png")
@@ -1,17 +0,0 @@
1
- from .numpy import (
2
- from_wavelet_to_time,
3
- from_wavelet_to_freq,
4
- from_time_to_wavelet,
5
- from_freq_to_wavelet,
6
- )
7
- from .phi_computer import phi_vec, phitilde_vec_norm, phitilde_vec
8
-
9
- __all__ = [
10
- "from_wavelet_to_time",
11
- "from_wavelet_to_freq",
12
- "from_time_to_wavelet",
13
- "from_freq_to_wavelet",
14
- "phitilde_vec_norm",
15
- "phi_vec",
16
- "phitilde_vec",
17
- ]
File without changes
@@ -1,56 +0,0 @@
1
- import jax
2
- import jax.numpy as jnp
3
- from functools import partial
4
- from jax import jit
5
- from jax.numpy.fft import ifft
6
-
7
- @partial(jit, static_argnames=('Nf', 'Nt'))
8
- def transform_wavelet_freq_helper(
9
- data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
10
- ) -> jnp.ndarray:
11
- # Initially all wrk being done in time-rws, freq-cols
12
- wave = jnp.zeros((Nt, Nf))
13
- f_bins = jnp.arange(Nf)
14
-
15
- i_base = Nt // 2
16
- jj_base = f_bins * Nt // 2
17
-
18
- initial_values = jnp.where(
19
- (f_bins == 0) | (f_bins == Nf),
20
- phif[0] * data[f_bins * Nt // 2] / 2.0,
21
- phif[0] * data[f_bins * Nt // 2]
22
- )
23
-
24
- DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)
25
- DX = DX.at[:, Nt // 2].set(initial_values)
26
-
27
- j_range = jnp.arange(1 - Nt // 2, Nt // 2)
28
- j = jnp.abs(j_range)
29
- i = i_base + j_range
30
-
31
- cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)
32
- cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)
33
- cond3 = j[None, :] == 0
34
-
35
- jj = jj_base[:, None] + j_range[None, :]
36
- val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])
37
- DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))
38
-
39
- # Vectorized ifft
40
- DX_trans = ifft(DX, axis=1)
41
-
42
- # Vectorized __fill_wave_2_jax
43
- n_range = jnp.arange(Nt)
44
- cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1
45
- cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # shape: (Nf, 1)
46
-
47
- real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
48
- imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
49
-
50
- wave = jnp.where(cond1, imag_part.T, real_part.T)
51
-
52
- ## Special cases for f_bin 0 and Nf
53
- wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))
54
- wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))
55
-
56
- return wave.T
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes