pywavelet 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {pywavelet-0.2.2 → pywavelet-0.2.3}/.github/workflows/ci.yml +1 -0
  2. {pywavelet-0.2.2 → pywavelet-0.2.3}/CHANGELOG.rst +22 -0
  3. {pywavelet-0.2.2 → pywavelet-0.2.3}/PKG-INFO +1 -1
  4. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/_version.py +2 -2
  5. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/backend.py +9 -4
  6. pywavelet-0.2.3/src/pywavelet/transforms/__init__.py +28 -0
  7. pywavelet-0.2.3/src/pywavelet/transforms/jax/__init__.py +12 -0
  8. pywavelet-0.2.3/src/pywavelet/transforms/jax/forward/from_freq.py +97 -0
  9. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/inverse/main.py +3 -4
  10. {pywavelet-0.2.2/src/pywavelet/transforms/jax → pywavelet-0.2.3/src/pywavelet/transforms/numpy}/forward/__init__.py +0 -3
  11. pywavelet-0.2.3/src/pywavelet/transforms/numpy/inverse/__init__.py +3 -0
  12. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet.egg-info/PKG-INFO +1 -1
  13. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet.egg-info/SOURCES.txt +1 -0
  14. pywavelet-0.2.3/tests/test_jax.py +105 -0
  15. pywavelet-0.2.2/src/pywavelet/transforms/__init__.py +0 -17
  16. pywavelet-0.2.2/src/pywavelet/transforms/jax/__init__.py +0 -0
  17. pywavelet-0.2.2/src/pywavelet/transforms/jax/forward/from_freq.py +0 -56
  18. pywavelet-0.2.2/src/pywavelet/transforms/jax/inverse/__init__.py +0 -0
  19. {pywavelet-0.2.2 → pywavelet-0.2.3}/.github/workflows/docs.yml +0 -0
  20. {pywavelet-0.2.2 → pywavelet-0.2.3}/.github/workflows/pypi.yml +0 -0
  21. {pywavelet-0.2.2 → pywavelet-0.2.3}/.gitignore +0 -0
  22. {pywavelet-0.2.2 → pywavelet-0.2.3}/.pre-commit-config.yaml +0 -0
  23. {pywavelet-0.2.2 → pywavelet-0.2.3}/CITATION.cff +0 -0
  24. {pywavelet-0.2.2 → pywavelet-0.2.3}/README.rst +0 -0
  25. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/_config.yml +0 -0
  26. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/_static/demo.gif +0 -0
  27. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/_toc.yml +0 -0
  28. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/api.rst +0 -0
  29. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/example.ipynb +0 -0
  30. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/index.rst +0 -0
  31. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/logo.png +0 -0
  32. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/roundtrip_freq.png +0 -0
  33. {pywavelet-0.2.2 → pywavelet-0.2.3}/docs/roundtrip_time.png +0 -0
  34. {pywavelet-0.2.2 → pywavelet-0.2.3}/pyproject.toml +0 -0
  35. {pywavelet-0.2.2 → pywavelet-0.2.3}/setup.cfg +0 -0
  36. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/__init__.py +0 -0
  37. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/logger.py +0 -0
  38. {pywavelet-0.2.2/src/pywavelet/transforms/numpy → pywavelet-0.2.3/src/pywavelet/transforms/jax}/forward/__init__.py +0 -0
  39. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/forward/from_time.py +0 -0
  40. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/forward/main.py +0 -0
  41. {pywavelet-0.2.2/src/pywavelet/transforms/numpy → pywavelet-0.2.3/src/pywavelet/transforms/jax}/inverse/__init__.py +0 -0
  42. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/inverse/to_freq.py +0 -0
  43. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/__init__.py +0 -0
  44. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/forward/from_freq.py +0 -0
  45. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/forward/from_time.py +0 -0
  46. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/forward/main.py +0 -0
  47. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/inverse/main.py +0 -0
  48. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/inverse/to_freq.py +0 -0
  49. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/inverse/to_time.py +0 -0
  50. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/transforms/phi_computer.py +0 -0
  51. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/types/__init__.py +0 -0
  52. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/types/common.py +0 -0
  53. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/types/frequencyseries.py +0 -0
  54. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/types/plotting.py +0 -0
  55. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/types/timeseries.py +0 -0
  56. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/types/wavelet.py +0 -0
  57. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/types/wavelet_bins.py +0 -0
  58. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet/utils.py +0 -0
  59. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet.egg-info/dependency_links.txt +0 -0
  60. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet.egg-info/requires.txt +0 -0
  61. {pywavelet-0.2.2 → pywavelet-0.2.3}/src/pywavelet.egg-info/top_level.txt +0 -0
  62. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/conftest.py +0 -0
  63. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_data/roundtrip_chirp_freq.npz +0 -0
  64. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_data/roundtrip_chirp_time.npz +0 -0
  65. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_data/roundtrip_pure_f0_freq.npz +0 -0
  66. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_data/roundtrip_sine_freq.npz +0 -0
  67. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_data/roundtrip_sine_time.npz +0 -0
  68. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_lnl.py +0 -0
  69. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_mask.py +0 -0
  70. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_phi.py +0 -0
  71. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_psd.py +0 -0
  72. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_roundtrip_conversion.py +0 -0
  73. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_snr.py +0 -0
  74. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_timefreq_type.py +0 -0
  75. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_version.py +0 -0
  76. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/test_wavelet_plot.py +0 -0
  77. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/utils/__init__.py +0 -0
  78. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/utils/generate_data.py +0 -0
  79. {pywavelet-0.2.2 → pywavelet-0.2.3}/tests/utils/plotting.py +0 -0
@@ -28,6 +28,7 @@ jobs:
28
28
  run: |
29
29
  python -m pip install --upgrade pip
30
30
  python -m pip install -e .[dev]
31
+ python -m pip install -e .[jax]
31
32
  pre-commit install
32
33
 
33
34
  - name: pre-commit
@@ -5,16 +5,38 @@ CHANGELOG
5
5
  =========
6
6
 
7
7
 
8
+ .. _changelog-v0.2.3:
9
+
10
+ v0.2.3 (2025-01-24)
11
+ ===================
12
+
13
+ Bug Fixes
14
+ ---------
15
+
16
+ * fix: add backend check for os.environ (`98c0818`_)
17
+
18
+ * fix: add test for jax (`1940394`_)
19
+
20
+ .. _98c0818: https://github.com/pywavelet/pywavelet/commit/98c0818078190d829a23734f932f1f93c9932167
21
+ .. _1940394: https://github.com/pywavelet/pywavelet/commit/194039437a3a9b3ada303d101b4e2573ab7d0afd
22
+
23
+
8
24
  .. _changelog-v0.2.2:
9
25
 
10
26
  v0.2.2 (2025-01-23)
11
27
  ===================
12
28
 
29
+ Chores
30
+ ------
31
+
32
+ * chore(release): 0.2.2 (`eed5d68`_)
33
+
13
34
  Unknown
14
35
  -------
15
36
 
16
37
  * Merge branch 'main' of github.com:pywavelet/pywavelet (`e8e2115`_)
17
38
 
39
+ .. _eed5d68: https://github.com/pywavelet/pywavelet/commit/eed5d6864276fc5f90c4866749903e3e358df5ca
18
40
  .. _e8e2115: https://github.com/pywavelet/pywavelet/commit/e8e2115e797a5001f236ff027a14ef226151dcc1
19
41
 
20
42
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.2'
16
- __version_tuple__ = version_tuple = (0, 2, 2)
15
+ __version__ = version = '0.2.3'
16
+ __version_tuple__ = version_tuple = (0, 2, 3)
@@ -1,5 +1,7 @@
1
1
  import os
2
2
 
3
+ from .logger import logger
4
+
3
5
  try:
4
6
  import jax
5
7
 
@@ -13,14 +15,17 @@ use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
13
15
 
14
16
  if use_jax:
15
17
  import jax.numpy as xp # type: ignore
16
- from jax.scipy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
18
+ from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
17
19
  from jax.scipy.special import betainc # type: ignore
18
20
 
21
+ logger.info("Using JAX backend")
19
22
 
20
23
  else:
21
24
  import numpy as xp # type: ignore
22
- from numpy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
23
- from scipy.special import betainc # type: ignore
25
+ from numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
26
+ from scipy.special import betainc # type: ignore
27
+
28
+ logger.info("Using NumPy+numba backend")
24
29
 
25
30
 
26
- PI = xp.pi
31
+ PI = xp.pi
@@ -0,0 +1,28 @@
1
+ from ..backend import use_jax
2
+
3
+ if use_jax:
4
+ from .jax import (
5
+ from_freq_to_wavelet,
6
+ from_time_to_wavelet,
7
+ from_wavelet_to_freq,
8
+ from_wavelet_to_time,
9
+ )
10
+ else:
11
+ from .numpy import (
12
+ from_wavelet_to_time,
13
+ from_wavelet_to_freq,
14
+ from_time_to_wavelet,
15
+ from_freq_to_wavelet,
16
+ )
17
+
18
+ from .phi_computer import phi_vec, phitilde_vec, phitilde_vec_norm
19
+
20
+ __all__ = [
21
+ "from_wavelet_to_time",
22
+ "from_wavelet_to_freq",
23
+ "from_time_to_wavelet",
24
+ "from_freq_to_wavelet",
25
+ "phitilde_vec_norm",
26
+ "phi_vec",
27
+ "phitilde_vec",
28
+ ]
@@ -0,0 +1,12 @@
1
+ from ...logger import logger
2
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
3
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
+
5
+ logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
+
7
+ __all__ = [
8
+ "from_wavelet_to_time",
9
+ "from_wavelet_to_freq",
10
+ "from_time_to_wavelet",
11
+ "from_freq_to_wavelet",
12
+ ]
@@ -0,0 +1,97 @@
1
+ from functools import partial
2
+
3
+ import jax.numpy as jnp
4
+ from jax import jit
5
+ from jax.numpy.fft import ifft
6
+
7
+
8
@partial(jit, static_argnames=("Nf", "Nt"))
def transform_wavelet_freq_helper(
    data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
) -> jnp.ndarray:
    """
    Transform frequency-domain data into WDM wavelet coefficients.

    For every frequency layer ``m`` in ``0..Nf`` (``Nf + 1`` layers: DC,
    ``Nf - 1`` interior layers, Nyquist) the data around bin ``m * Nt // 2``
    is windowed with the pre-computed filter ``phif`` and inverse-FFT'd along
    the time axis. The complex result is then unpacked into real pixels.

    Parameters:
    - data (jnp.ndarray): 1D frequency-domain data. Indices up to
      ``Nf * Nt // 2`` are read (the Nyquist layer needs that last sample).
      NOTE(review): presumably the one-sided spectrum of ``Nf * Nt`` time
      samples (``len(data) == Nf * Nt // 2 + 1``) -- confirm with callers.
      JAX clamps out-of-range gathers, so a shorter array does not raise.
    - Nf (int): Number of frequency layers (static under ``jit``).
    - Nt (int): Number of time bins (static under ``jit``); assumed even.
    - phif (jnp.ndarray): Filter samples; entries ``0 .. Nt // 2 - 1`` are used.

    Returns:
    - wave (jnp.ndarray): Real array of shape (Nf, Nt). Row 0 packs the DC
      layer into its even time bins and the Nyquist layer into its odd time
      bins. NOTE(review): this follows the reference WDM packing; confirm it
      matches the NumPy backend and the JAX inverse (``to_freq.py``).
    """
    half = Nt // 2
    layers = jnp.arange(Nf + 1)  # 0 = DC, Nf = Nyquist
    offsets = jnp.arange(1 - half, half)  # bin offsets around each layer centre
    centres = layers * Nt // 2  # data index at the centre of each layer

    # Windowed data for every (layer, offset) pair.
    jj = centres[:, None] + offsets[None, :]
    windowed = phif[jnp.abs(offsets)][None, :] * data[jj]
    # DC keeps nothing below bin 0, Nyquist keeps nothing above its centre.
    outside = ((layers[:, None] == 0) & (offsets[None, :] < 0)) | (
        (layers[:, None] == Nf) & (offsets[None, :] > 0)
    )
    windowed = jnp.where(outside, 0.0, windowed)

    # The centre sample of the DC and Nyquist layers is halved.
    centre_vals = phif[0] * data[centres]
    centre_vals = jnp.where(
        (layers == 0) | (layers == Nf), centre_vals / 2.0, centre_vals
    )

    # Lay each window into a length-Nt buffer centred at Nt // 2 (column 0
    # is never written and stays zero).
    # TODO: Check dtype -- is complex64 sufficient?
    DX = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex64)
    DX = DX.at[:, half + offsets].set(windowed)
    DX = DX.at[:, half].set(centre_vals)

    # Inverse FFT along the time axis for all layers at once.
    DX_trans = ifft(DX, axis=1)  # (Nf + 1, Nt)

    # Interior layers: Re(X) where n + m is even; otherwise Im(X), with the
    # sign flipped for odd m. Computed for every column of the (Nt, Nf) grid;
    # column 0 is overwritten by the DC/Nyquist packing below.
    n = jnp.arange(Nt)[:, None]
    m = jnp.arange(Nf)[None, :]
    X = DX_trans[:Nf].T  # (Nt, Nf)
    imag_sign = jnp.where(m % 2 == 1, -1.0, 1.0)
    wave = jnp.where((n + m) % 2 == 1, imag_sign * jnp.imag(X), jnp.real(X))

    # DC -> even time bins of layer 0, Nyquist -> odd time bins of layer 0.
    root2 = jnp.sqrt(2.0)
    wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2]) * root2)
    wave = wave.at[1::2, 0].set(jnp.real(DX_trans[Nf, ::2]) * root2)

    return wave.T  # (Nt, Nf) -> (Nf, Nt)
@@ -1,10 +1,9 @@
1
1
  import jax.numpy as jnp
2
2
  from jax.numpy.fft import rfftfreq
3
3
 
4
- from ...phi_computer import phi_vec, phitilde_vec_norm
5
4
  from ....types import FrequencySeries, TimeSeries, Wavelet
5
+ from ...phi_computer import phi_vec, phitilde_vec_norm
6
6
  from .to_freq import inverse_wavelet_freq_helper
7
- # from .inverse_wavelet_time_funcs import inverse_wavelet_time_helper
8
7
 
9
8
 
10
9
  def from_wavelet_to_time(
@@ -65,5 +64,5 @@ def from_wavelet_to_freq(
65
64
  -1 / 2
66
65
  ) # Normalise to get the proper backwards transformation
67
66
 
68
- freqs = rfftfreq(wave_in.ND*2, d=dt)[1:]
69
- return FrequencySeries(data=freq_data, freq=freqs)
67
+ freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
68
+ return FrequencySeries(data=freq_data, freq=freqs)
@@ -1,6 +1,3 @@
1
1
  from .main import from_freq_to_wavelet, from_time_to_wavelet
2
- from ....logger import logger
3
-
4
- logger.warning("JAX SUBPACKAGE NOT YET TESTED")
5
2
 
6
3
  __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -0,0 +1,3 @@
1
+ from .main import from_wavelet_to_freq, from_wavelet_to_time
2
+
3
+ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -53,6 +53,7 @@ src/pywavelet/types/timeseries.py
53
53
  src/pywavelet/types/wavelet.py
54
54
  src/pywavelet/types/wavelet_bins.py
55
55
  tests/conftest.py
56
+ tests/test_jax.py
56
57
  tests/test_lnl.py
57
58
  tests/test_mask.py
58
59
  tests/test_phi.py
@@ -0,0 +1,105 @@
1
+ import importlib
2
+ import os
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from conftest import monochromatic_wnm
7
+
8
+ from pywavelet.transforms import from_freq_to_wavelet, from_time_to_wavelet
9
+ from pywavelet.types import FrequencySeries, TimeSeries
10
+ from pywavelet.types.wavelet_bins import compute_bins
11
+ from pywavelet.utils import compute_snr, evolutionary_psd_from_stationary_psd
12
+
13
+
14
def _wavelet_snr(signal_wavelet, psd_freq, dt):
    """Return the WDM-domain SNR of ``signal_wavelet`` for a stationary PSD."""
    psd_wavelet = evolutionary_psd_from_stationary_psd(
        psd=psd_freq.data,
        psd_f=psd_freq.freq,
        f_grid=signal_wavelet.freq,
        t_grid=signal_wavelet.time,
        dt=dt,
    )
    return compute_snr(signal_wavelet, signal_wavelet, psd_wavelet)


def test_toy_model_snr(plot_dir):
    """Compare the frequency-domain SNR of a sinusoid with its WDM-domain SNR.

    The optimal SNR of ``A sin(2 pi f0 t)`` against a flat PSD is computed in
    the frequency domain, then again from the NumPy and the JAX
    frequency->wavelet transforms; both must agree to within ``atol=0.5``.
    A side-by-side plot (NumPy, JAX, difference) is saved to ``plot_dir``.
    """
    f0 = 20
    dt = 0.0125
    A = 2
    Nt = 128
    Nf = 256
    ND = Nt * Nf
    t = np.arange(0, ND) * dt
    PSD_AMP = 1

    ########################################
    # Part1: Analytical SNR calculation
    ########################################

    # Eq 21
    y = A * np.sin(2 * np.pi * f0 * t)  # Signal waveform we wish to test
    signal_timeseries = TimeSeries(y, t)
    signal_freq = signal_timeseries.to_frequencyseries()
    psd_freq = FrequencySeries(
        PSD_AMP * np.ones(len(signal_freq)), signal_freq.freq
    )
    snr = signal_freq.optimal_snr(psd_freq)

    ########################################
    # Part2: Wavelet domain (numpy)
    ########################################

    signal_wavelet = from_freq_to_wavelet(signal_freq, Nf=Nf, Nt=Nt)
    wdm_snr = _wavelet_snr(signal_wavelet, psd_freq, dt)
    assert np.isclose(snr, wdm_snr, atol=0.5), f"{snr}!={wdm_snr}"

    ########################################
    # Part3: Wavelet domain (jax)
    ########################################

    from pywavelet.transforms.jax import (
        from_freq_to_wavelet as jax_from_freq_to_wavelet,
    )

    signal_wavelet_jax = jax_from_freq_to_wavelet(signal_freq, Nf=Nf, Nt=Nt)
    wdm_snr_jax = _wavelet_snr(signal_wavelet_jax, psd_freq, dt)
    assert np.isclose(snr, wdm_snr_jax, atol=0.5), f"{snr}!={wdm_snr_jax}"

    wdm_diff = signal_wavelet - signal_wavelet_jax

    ########################################
    # Part4: Plot
    ########################################

    fig, ax = plt.subplots(1, 3, figsize=(15, 6))
    signal_wavelet.plot(ax=ax[0])
    signal_wavelet_jax.plot(ax=ax[1])
    wdm_diff.plot(ax=ax[2])
    ax[0].set_title(f"Numpy SNR={wdm_snr:.2f}")
    ax[1].set_title(f"Jax SNR={wdm_snr_jax:.2f}")
    ax[2].set_title("Difference")
    plt.tight_layout()
    plt.savefig(f"{plot_dir}/jax_vs_np.png")
    # Release the figure so it does not stay open for the rest of the session.
    plt.close(fig)
88
+
89
+
90
def test_backend_loader():
    """Check that ``PYWAVELET_JAX`` selects the backend when it is (re)loaded.

    Setting the variable to "1" must enable the JAX backend (this requires
    jax to be installed), and "0" must disable it. The caller's original
    environment value and the ``pywavelet.backend`` module state are restored
    afterwards, even if an assertion fails, so other tests are unaffected.
    """
    import pywavelet.backend

    original = os.environ.get("PYWAVELET_JAX")
    try:
        os.environ["PYWAVELET_JAX"] = "1"
        importlib.reload(pywavelet.backend)
        assert pywavelet.backend.use_jax

        os.environ["PYWAVELET_JAX"] = "0"
        importlib.reload(pywavelet.backend)
        assert not pywavelet.backend.use_jax
    finally:
        if original is None:
            os.environ.pop("PYWAVELET_JAX", None)
        else:
            os.environ["PYWAVELET_JAX"] = original
        # Re-evaluate the backend under the restored environment.
        importlib.reload(pywavelet.backend)
@@ -1,17 +0,0 @@
1
- from .numpy import (
2
- from_wavelet_to_time,
3
- from_wavelet_to_freq,
4
- from_time_to_wavelet,
5
- from_freq_to_wavelet,
6
- )
7
- from .phi_computer import phi_vec, phitilde_vec_norm, phitilde_vec
8
-
9
- __all__ = [
10
- "from_wavelet_to_time",
11
- "from_wavelet_to_freq",
12
- "from_time_to_wavelet",
13
- "from_freq_to_wavelet",
14
- "phitilde_vec_norm",
15
- "phi_vec",
16
- "phitilde_vec",
17
- ]
File without changes
@@ -1,56 +0,0 @@
1
- import jax
2
- import jax.numpy as jnp
3
- from functools import partial
4
- from jax import jit
5
- from jax.numpy.fft import ifft
6
-
7
- @partial(jit, static_argnames=('Nf', 'Nt'))
8
- def transform_wavelet_freq_helper(
9
- data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
10
- ) -> jnp.ndarray:
11
- # Initially all wrk being done in time-rws, freq-cols
12
- wave = jnp.zeros((Nt, Nf))
13
- f_bins = jnp.arange(Nf)
14
-
15
- i_base = Nt // 2
16
- jj_base = f_bins * Nt // 2
17
-
18
- initial_values = jnp.where(
19
- (f_bins == 0) | (f_bins == Nf),
20
- phif[0] * data[f_bins * Nt // 2] / 2.0,
21
- phif[0] * data[f_bins * Nt // 2]
22
- )
23
-
24
- DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)
25
- DX = DX.at[:, Nt // 2].set(initial_values)
26
-
27
- j_range = jnp.arange(1 - Nt // 2, Nt // 2)
28
- j = jnp.abs(j_range)
29
- i = i_base + j_range
30
-
31
- cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)
32
- cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)
33
- cond3 = j[None, :] == 0
34
-
35
- jj = jj_base[:, None] + j_range[None, :]
36
- val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])
37
- DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))
38
-
39
- # Vectorized ifft
40
- DX_trans = ifft(DX, axis=1)
41
-
42
- # Vectorized __fill_wave_2_jax
43
- n_range = jnp.arange(Nt)
44
- cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1
45
- cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # shape: (Nf, 1)
46
-
47
- real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
48
- imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
49
-
50
- wave = jnp.where(cond1, imag_part.T, real_part.T)
51
-
52
- ## Special cases for f_bin 0 and Nf
53
- wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))
54
- wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))
55
-
56
- return wave.T
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes