pywavelet 0.2.1.tar.gz → 0.2.3.tar.gz

This diff shows the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (79)
  1. {pywavelet-0.2.1 → pywavelet-0.2.3}/.github/workflows/ci.yml +1 -0
  2. {pywavelet-0.2.1 → pywavelet-0.2.3}/CHANGELOG.rst +47 -0
  3. {pywavelet-0.2.1 → pywavelet-0.2.3}/PKG-INFO +51 -2
  4. {pywavelet-0.2.1 → pywavelet-0.2.3}/pyproject.toml +1 -1
  5. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/_version.py +2 -2
  6. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/backend.py +9 -4
  7. pywavelet-0.2.3/src/pywavelet/transforms/__init__.py +28 -0
  8. pywavelet-0.2.3/src/pywavelet/transforms/jax/__init__.py +12 -0
  9. pywavelet-0.2.3/src/pywavelet/transforms/jax/forward/from_freq.py +97 -0
  10. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/inverse/main.py +3 -4
  11. {pywavelet-0.2.1/src/pywavelet/transforms/jax → pywavelet-0.2.3/src/pywavelet/transforms/numpy}/forward/__init__.py +0 -3
  12. pywavelet-0.2.3/src/pywavelet/transforms/numpy/inverse/__init__.py +3 -0
  13. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet.egg-info/PKG-INFO +51 -2
  14. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet.egg-info/SOURCES.txt +1 -0
  15. pywavelet-0.2.3/tests/test_jax.py +105 -0
  16. pywavelet-0.2.1/src/pywavelet/transforms/__init__.py +0 -17
  17. pywavelet-0.2.1/src/pywavelet/transforms/jax/__init__.py +0 -0
  18. pywavelet-0.2.1/src/pywavelet/transforms/jax/forward/from_freq.py +0 -56
  19. pywavelet-0.2.1/src/pywavelet/transforms/jax/inverse/__init__.py +0 -0
  20. {pywavelet-0.2.1 → pywavelet-0.2.3}/.github/workflows/docs.yml +0 -0
  21. {pywavelet-0.2.1 → pywavelet-0.2.3}/.github/workflows/pypi.yml +0 -0
  22. {pywavelet-0.2.1 → pywavelet-0.2.3}/.gitignore +0 -0
  23. {pywavelet-0.2.1 → pywavelet-0.2.3}/.pre-commit-config.yaml +0 -0
  24. {pywavelet-0.2.1 → pywavelet-0.2.3}/CITATION.cff +0 -0
  25. {pywavelet-0.2.1 → pywavelet-0.2.3}/README.rst +0 -0
  26. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/_config.yml +0 -0
  27. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/_static/demo.gif +0 -0
  28. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/_toc.yml +0 -0
  29. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/api.rst +0 -0
  30. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/example.ipynb +0 -0
  31. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/index.rst +0 -0
  32. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/logo.png +0 -0
  33. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/roundtrip_freq.png +0 -0
  34. {pywavelet-0.2.1 → pywavelet-0.2.3}/docs/roundtrip_time.png +0 -0
  35. {pywavelet-0.2.1 → pywavelet-0.2.3}/setup.cfg +0 -0
  36. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/__init__.py +0 -0
  37. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/logger.py +0 -0
  38. {pywavelet-0.2.1/src/pywavelet/transforms/numpy → pywavelet-0.2.3/src/pywavelet/transforms/jax}/forward/__init__.py +0 -0
  39. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/forward/from_time.py +0 -0
  40. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/forward/main.py +0 -0
  41. {pywavelet-0.2.1/src/pywavelet/transforms/numpy → pywavelet-0.2.3/src/pywavelet/transforms/jax}/inverse/__init__.py +0 -0
  42. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/jax/inverse/to_freq.py +0 -0
  43. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/__init__.py +0 -0
  44. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/forward/from_freq.py +0 -0
  45. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/forward/from_time.py +0 -0
  46. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/forward/main.py +0 -0
  47. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/inverse/main.py +0 -0
  48. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/inverse/to_freq.py +0 -0
  49. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/numpy/inverse/to_time.py +0 -0
  50. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/transforms/phi_computer.py +0 -0
  51. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/types/__init__.py +0 -0
  52. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/types/common.py +0 -0
  53. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/types/frequencyseries.py +0 -0
  54. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/types/plotting.py +0 -0
  55. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/types/timeseries.py +0 -0
  56. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/types/wavelet.py +0 -0
  57. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/types/wavelet_bins.py +0 -0
  58. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet/utils.py +0 -0
  59. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet.egg-info/dependency_links.txt +0 -0
  60. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet.egg-info/requires.txt +0 -0
  61. {pywavelet-0.2.1 → pywavelet-0.2.3}/src/pywavelet.egg-info/top_level.txt +0 -0
  62. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/conftest.py +0 -0
  63. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_data/roundtrip_chirp_freq.npz +0 -0
  64. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_data/roundtrip_chirp_time.npz +0 -0
  65. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_data/roundtrip_pure_f0_freq.npz +0 -0
  66. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_data/roundtrip_sine_freq.npz +0 -0
  67. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_data/roundtrip_sine_time.npz +0 -0
  68. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_lnl.py +0 -0
  69. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_mask.py +0 -0
  70. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_phi.py +0 -0
  71. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_psd.py +0 -0
  72. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_roundtrip_conversion.py +0 -0
  73. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_snr.py +0 -0
  74. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_timefreq_type.py +0 -0
  75. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_version.py +0 -0
  76. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/test_wavelet_plot.py +0 -0
  77. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/utils/__init__.py +0 -0
  78. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/utils/generate_data.py +0 -0
  79. {pywavelet-0.2.1 → pywavelet-0.2.3}/tests/utils/plotting.py +0 -0
@@ -28,6 +28,7 @@ jobs:
28
28
  run: |
29
29
  python -m pip install --upgrade pip
30
30
  python -m pip install -e .[dev]
31
+ python -m pip install -e .[jax]
31
32
  pre-commit install
32
33
 
33
34
  - name: pre-commit
@@ -5,16 +5,63 @@ CHANGELOG
5
5
  =========
6
6
 
7
7
 
8
+ .. _changelog-v0.2.3:
9
+
10
+ v0.2.3 (2025-01-24)
11
+ ===================
12
+
13
+ Bug Fixes
14
+ ---------
15
+
16
+ * fix: add backend check for os.environ (`98c0818`_)
17
+
18
+ * fix: add test for jax (`1940394`_)
19
+
20
+ .. _98c0818: https://github.com/pywavelet/pywavelet/commit/98c0818078190d829a23734f932f1f93c9932167
21
+ .. _1940394: https://github.com/pywavelet/pywavelet/commit/194039437a3a9b3ada303d101b4e2573ab7d0afd
22
+
23
+
24
+ .. _changelog-v0.2.2:
25
+
26
+ v0.2.2 (2025-01-23)
27
+ ===================
28
+
29
+ Chores
30
+ ------
31
+
32
+ * chore(release): 0.2.2 (`eed5d68`_)
33
+
34
+ Unknown
35
+ -------
36
+
37
+ * Merge branch 'main' of github.com:pywavelet/pywavelet (`e8e2115`_)
38
+
39
+ .. _eed5d68: https://github.com/pywavelet/pywavelet/commit/eed5d6864276fc5f90c4866749903e3e358df5ca
40
+ .. _e8e2115: https://github.com/pywavelet/pywavelet/commit/e8e2115e797a5001f236ff027a14ef226151dcc1
41
+
42
+
8
43
  .. _changelog-v0.2.1:
9
44
 
10
45
  v0.2.1 (2025-01-23)
11
46
  ===================
12
47
 
48
+ Bug Fixes
49
+ ---------
50
+
51
+ * fix: fix readme path (`34e927d`_)
52
+
53
+ Chores
54
+ ------
55
+
56
+ * chore(release): 0.2.1 (`a55bce5`_)
57
+
13
58
  Unknown
14
59
  -------
15
60
 
16
61
  * Merge branch 'main' of github.com:pywavelet/pywavelet (`b8c6d15`_)
17
62
 
63
+ .. _34e927d: https://github.com/pywavelet/pywavelet/commit/34e927d411ec8fde89f552bd5ec89b38820e07e0
64
+ .. _a55bce5: https://github.com/pywavelet/pywavelet/commit/a55bce518c3484543efada283399a41df3ecf001
18
65
  .. _b8c6d15: https://github.com/pywavelet/pywavelet/commit/b8c6d1579d48ec5fa22130430267794ae8e54f6c
19
66
 
20
67
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -12,7 +12,7 @@ Classifier: License :: OSI Approved :: MIT License
12
12
  Classifier: Operating System :: OS Independent
13
13
  Classifier: Programming Language :: Python :: 3.8
14
14
  Requires-Python: >=3.8
15
- Description-Content-Type: text/markdown
15
+ Description-Content-Type: text/x-rst
16
16
  Requires-Dist: numpy
17
17
  Requires-Dist: numba
18
18
  Requires-Dist: scipy
@@ -34,3 +34,52 @@ Requires-Dist: isort; extra == "dev"
34
34
  Requires-Dist: mypy; extra == "dev"
35
35
  Requires-Dist: jupyter-book; extra == "dev"
36
36
  Requires-Dist: GitPython; extra == "dev"
37
+
38
+ pywavelet
39
+ #########
40
+
41
+ .. image:: https://badge.fury.io/py/pywavelet.svg
42
+ :target: https://badge.fury.io/py/pywavelet
43
+ .. image:: https://coveralls.io/repos/github/avivajpeyi/pywavelet/badge.svg?branch=main&kill_cache=1
44
+ :target: https://coveralls.io/github/avivajpeyi/pywavelet?branch=main
45
+
46
+
47
+
48
+
49
+
50
+ WDM Wavelet transform
51
+
52
+
53
+ Quickstart
54
+ ==========
55
+
56
+ pywavelet is available on PyPI and can be installed with `pip <https://pip.pypa.io>`_.
57
+
58
+ .. code-block:: console
59
+
60
+ $ pip install pywavelet
61
+
62
+ For developers
63
+ --------------
64
+
65
+ First set up a conda environment with the latest version of python.
66
+
67
+ .. code-block::
68
+
69
+ $ conda create -n pywavelet -c conda-forge python=3.12
70
+
71
+ .. code-block::
72
+
73
+ $ pip install -e ".[dev]"
74
+ $ pre-commit install
75
+
76
+ Test code
77
+ ---------
78
+
79
+ Locate directory /tests from root directory. run
80
+
81
+ .. code-block::
82
+
83
+ $ pytest .
84
+
85
+ Hopefully everything should run fine.
@@ -11,7 +11,7 @@ name = "pywavelet"
11
11
  dynamic = ["version"] # scm versioning (using tags)
12
12
  requires-python = ">=3.8"
13
13
  description = "WDM wavelet transform your time/freq series!"
14
- readme = "README.md"
14
+ readme = "README.rst"
15
15
  authors = [
16
16
  { name = "Pywavelet Team", email = "avi.vajpeyi@gmail.com" },
17
17
  ]
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.1'
16
- __version_tuple__ = version_tuple = (0, 2, 1)
15
+ __version__ = version = '0.2.3'
16
+ __version_tuple__ = version_tuple = (0, 2, 3)
@@ -1,5 +1,7 @@
1
1
  import os
2
2
 
3
+ from .logger import logger
4
+
3
5
  try:
4
6
  import jax
5
7
 
@@ -13,14 +15,17 @@ use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
13
15
 
14
16
  if use_jax:
15
17
  import jax.numpy as xp # type: ignore
16
- from jax.scipy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
18
+ from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
17
19
  from jax.scipy.special import betainc # type: ignore
18
20
 
21
+ logger.info("Using JAX backend")
19
22
 
20
23
  else:
21
24
  import numpy as xp # type: ignore
22
- from numpy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
23
- from scipy.special import betainc # type: ignore
25
+ from numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
26
+ from scipy.special import betainc # type: ignore
27
+
28
+ logger.info("Using NumPy+numba backend")
24
29
 
25
30
 
26
- PI = xp.pi
31
+ PI = xp.pi
@@ -0,0 +1,28 @@
1
+ from ..backend import use_jax
2
+
3
+ if use_jax:
4
+ from .jax import (
5
+ from_freq_to_wavelet,
6
+ from_time_to_wavelet,
7
+ from_wavelet_to_freq,
8
+ from_wavelet_to_time,
9
+ )
10
+ else:
11
+ from .numpy import (
12
+ from_wavelet_to_time,
13
+ from_wavelet_to_freq,
14
+ from_time_to_wavelet,
15
+ from_freq_to_wavelet,
16
+ )
17
+
18
+ from .phi_computer import phi_vec, phitilde_vec, phitilde_vec_norm
19
+
20
+ __all__ = [
21
+ "from_wavelet_to_time",
22
+ "from_wavelet_to_freq",
23
+ "from_time_to_wavelet",
24
+ "from_freq_to_wavelet",
25
+ "phitilde_vec_norm",
26
+ "phi_vec",
27
+ "phitilde_vec",
28
+ ]
@@ -0,0 +1,12 @@
1
+ from ...logger import logger
2
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
3
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
+
5
+ logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
+
7
+ __all__ = [
8
+ "from_wavelet_to_time",
9
+ "from_wavelet_to_freq",
10
+ "from_time_to_wavelet",
11
+ "from_freq_to_wavelet",
12
+ ]
@@ -0,0 +1,97 @@
1
+ from functools import partial
2
+
3
+ import jax.numpy as jnp
4
+ from jax import jit
5
+ from jax.numpy.fft import ifft
6
+
7
+
8
+ @partial(jit, static_argnames=("Nf", "Nt"))
9
+ def transform_wavelet_freq_helper(
10
+ data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
11
+ ) -> jnp.ndarray:
12
+ """
13
+ Transforms input data from the frequency domain to the wavelet domain using a
14
+ pre-computed wavelet filter (`phif`) and performs an efficient inverse FFT.
15
+
16
+ Parameters:
17
+ - data (jnp.ndarray): 1D array representing the input data in the frequency domain.
18
+ - Nf (int): Number of frequency bins.
19
+ - Nt (int): Number of time bins. (Note: Nt * Nf == len(data))
20
+ - phif (jnp.ndarray): Pre-computed wavelet filter for frequency components.
21
+
22
+ Returns:
23
+ - wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
24
+ """
25
+
26
+ # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
27
+ wave = jnp.zeros((Nt, Nf))
28
+ f_bins = jnp.arange(Nf) # Frequency bin indices
29
+
30
+ # Compute base indices for time (i_base) and frequency (jj_base)
31
+ i_base = Nt // 2
32
+ jj_base = f_bins * Nt // 2
33
+
34
+ # Set initial values for the center of the transformation
35
+ initial_values = jnp.where(
36
+ (f_bins == 0)
37
+ | (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
38
+ phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
39
+ phif[0] * data[f_bins * Nt // 2],
40
+ )
41
+
42
+ # Initialize a 2D array to store intermediate FFT input values
43
+ DX = jnp.zeros(
44
+ (Nf, Nt), dtype=jnp.complex64
45
+ ) # TODO: Check dtype -- is complex64 sufficient?
46
+ DX = DX.at[:, Nt // 2].set(
47
+ initial_values
48
+ ) # Set initial values at the center of the transformation (2 sided FFT)
49
+
50
+ # Compute time indices for all offsets around the midpoint
51
+ j_range = jnp.arange(
52
+ 1 - Nt // 2, Nt // 2
53
+ ) # Time offsets (centered around zero)
54
+ j = jnp.abs(j_range) # Absolute offset indices
55
+ i = i_base + j_range # Time indices in output array
56
+
57
+ # Compute conditions for edge cases
58
+ cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
59
+ cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
60
+ cond3 = j[None, :] == 0 # Center of the transformation (no offset)
61
+
62
+ # Compute frequency indices for the input data
63
+ jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
64
+ val = jnp.where(
65
+ cond1 | cond2, 0.0, phif[j] * data[jj]
66
+ ) # Wavelet filter application
67
+ DX = DX.at[:, i].set(
68
+ jnp.where(cond3, DX[:, i], val)
69
+ ) # Update DX with computed values
70
+ # At this point, DX contains the data FFT'd with the wavelet filter
71
+ # (each row is a frequency bin, each column is a time bin)
72
+
73
+ # Perform the inverse FFT along the time dimension
74
+ DX_trans = ifft(DX, axis=1)
75
+
76
+ # Fill the wavelet output array based on the inverse FFT results
77
+ n_range = jnp.arange(Nt) # Time indices
78
+ cond1 = (
79
+ n_range[:, None] + f_bins[None, :]
80
+ ) % 2 == 1 # Odd/even alternation
81
+ cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
82
+
83
+ # Assign real and imaginary parts based on conditions
84
+ real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
85
+ imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
86
+ wave = jnp.where(cond1, imag_part.T, real_part.T)
87
+
88
+ # Special cases for frequency bins 0 (DC) and Nf (Nyquist)
89
+ wave = wave.at[::2, 0].set(
90
+ jnp.real(DX_trans[0, ::2] * jnp.sqrt(2))
91
+ ) # DC component
92
+ wave = wave.at[1::2, -1].set(
93
+ jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2))
94
+ ) # Nyquist component
95
+
96
+ # Return the wavelet-transformed array (transposed for freq-major layout)
97
+ return wave.T # (Nt, Nf) -> (Nf, Nt)
@@ -1,10 +1,9 @@
1
1
  import jax.numpy as jnp
2
2
  from jax.numpy.fft import rfftfreq
3
3
 
4
- from ...phi_computer import phi_vec, phitilde_vec_norm
5
4
  from ....types import FrequencySeries, TimeSeries, Wavelet
5
+ from ...phi_computer import phi_vec, phitilde_vec_norm
6
6
  from .to_freq import inverse_wavelet_freq_helper
7
- # from .inverse_wavelet_time_funcs import inverse_wavelet_time_helper
8
7
 
9
8
 
10
9
  def from_wavelet_to_time(
@@ -65,5 +64,5 @@ def from_wavelet_to_freq(
65
64
  -1 / 2
66
65
  ) # Normalise to get the proper backwards transformation
67
66
 
68
- freqs = rfftfreq(wave_in.ND*2, d=dt)[1:]
69
- return FrequencySeries(data=freq_data, freq=freqs)
67
+ freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
68
+ return FrequencySeries(data=freq_data, freq=freqs)
@@ -1,6 +1,3 @@
1
1
  from .main import from_freq_to_wavelet, from_time_to_wavelet
2
- from ....logger import logger
3
-
4
- logger.warning("JAX SUBPACKAGE NOT YET TESTED")
5
2
 
6
3
  __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -0,0 +1,3 @@
1
+ from .main import from_wavelet_to_freq, from_wavelet_to_time
2
+
3
+ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -12,7 +12,7 @@ Classifier: License :: OSI Approved :: MIT License
12
12
  Classifier: Operating System :: OS Independent
13
13
  Classifier: Programming Language :: Python :: 3.8
14
14
  Requires-Python: >=3.8
15
- Description-Content-Type: text/markdown
15
+ Description-Content-Type: text/x-rst
16
16
  Requires-Dist: numpy
17
17
  Requires-Dist: numba
18
18
  Requires-Dist: scipy
@@ -34,3 +34,52 @@ Requires-Dist: isort; extra == "dev"
34
34
  Requires-Dist: mypy; extra == "dev"
35
35
  Requires-Dist: jupyter-book; extra == "dev"
36
36
  Requires-Dist: GitPython; extra == "dev"
37
+
38
+ pywavelet
39
+ #########
40
+
41
+ .. image:: https://badge.fury.io/py/pywavelet.svg
42
+ :target: https://badge.fury.io/py/pywavelet
43
+ .. image:: https://coveralls.io/repos/github/avivajpeyi/pywavelet/badge.svg?branch=main&kill_cache=1
44
+ :target: https://coveralls.io/github/avivajpeyi/pywavelet?branch=main
45
+
46
+
47
+
48
+
49
+
50
+ WDM Wavelet transform
51
+
52
+
53
+ Quickstart
54
+ ==========
55
+
56
+ pywavelet is available on PyPI and can be installed with `pip <https://pip.pypa.io>`_.
57
+
58
+ .. code-block:: console
59
+
60
+ $ pip install pywavelet
61
+
62
+ For developers
63
+ --------------
64
+
65
+ First set up a conda environment with the latest version of python.
66
+
67
+ .. code-block::
68
+
69
+ $ conda create -n pywavelet -c conda-forge python=3.12
70
+
71
+ .. code-block::
72
+
73
+ $ pip install -e ".[dev]"
74
+ $ pre-commit install
75
+
76
+ Test code
77
+ ---------
78
+
79
+ Locate directory /tests from root directory. run
80
+
81
+ .. code-block::
82
+
83
+ $ pytest .
84
+
85
+ Hopefully everything should run fine.
@@ -53,6 +53,7 @@ src/pywavelet/types/timeseries.py
53
53
  src/pywavelet/types/wavelet.py
54
54
  src/pywavelet/types/wavelet_bins.py
55
55
  tests/conftest.py
56
+ tests/test_jax.py
56
57
  tests/test_lnl.py
57
58
  tests/test_mask.py
58
59
  tests/test_phi.py
@@ -0,0 +1,105 @@
1
+ import importlib
2
+ import os
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from conftest import monochromatic_wnm
7
+
8
+ from pywavelet.transforms import from_freq_to_wavelet, from_time_to_wavelet
9
+ from pywavelet.types import FrequencySeries, TimeSeries
10
+ from pywavelet.types.wavelet_bins import compute_bins
11
+ from pywavelet.utils import compute_snr, evolutionary_psd_from_stationary_psd
12
+
13
+
14
+ def test_toy_model_snr(plot_dir):
15
+ f0 = 20
16
+ dt = 0.0125
17
+ A = 2
18
+ Nt = 128
19
+ Nf = 256
20
+ ND = Nt * Nf
21
+ t = np.arange(0, ND) * dt
22
+ PSD_AMP = 1
23
+
24
+ ########################################
25
+ # Part1: Analytical SNR calculation
26
+ #######################################
27
+
28
+ # Eq 21
29
+ y = A * np.sin(2 * np.pi * f0 * t) # Signal waveform we wish to test
30
+ signal_timeseries = TimeSeries(y, t)
31
+ signal_freq = signal_timeseries.to_frequencyseries()
32
+ psd_freq = FrequencySeries(
33
+ PSD_AMP * np.ones(len(signal_freq)), signal_freq.freq
34
+ )
35
+ snr = signal_freq.optimal_snr(psd_freq)
36
+
37
+ ########################################
38
+ # Part2: Wavelet domain (numpy)
39
+ ########################################
40
+
41
+ signal_wavelet = from_freq_to_wavelet(signal_freq, Nf=Nf, Nt=Nt)
42
+ psd_wavelet = evolutionary_psd_from_stationary_psd(
43
+ psd=psd_freq.data,
44
+ psd_f=psd_freq.freq,
45
+ f_grid=signal_wavelet.freq,
46
+ t_grid=signal_wavelet.time,
47
+ dt=dt,
48
+ )
49
+ wdm_snr = compute_snr(signal_wavelet, signal_wavelet, psd_wavelet)
50
+ assert np.isclose(snr, wdm_snr, atol=0.5), f"{snr}!={wdm_snr}"
51
+
52
+ ########################################
53
+ # Part3: Wavelet domain (jax)
54
+ ########################################
55
+
56
+ from pywavelet.transforms.jax import (
57
+ from_freq_to_wavelet as jax_from_freq_to_wavelet,
58
+ )
59
+
60
+ signal_wavelet_jax = jax_from_freq_to_wavelet(signal_freq, Nf=Nf, Nt=Nt)
61
+ psd_wavelet_jax = evolutionary_psd_from_stationary_psd(
62
+ psd=psd_freq.data,
63
+ psd_f=psd_freq.freq,
64
+ f_grid=signal_wavelet_jax.freq,
65
+ t_grid=signal_wavelet_jax.time,
66
+ dt=dt,
67
+ )
68
+ wdm_snr_jax = compute_snr(
69
+ signal_wavelet_jax, signal_wavelet_jax, psd_wavelet_jax
70
+ )
71
+ assert np.isclose(snr, wdm_snr_jax, atol=0.5), f"{snr}!={wdm_snr_jax}"
72
+
73
+ wdm_diff = signal_wavelet - signal_wavelet_jax
74
+
75
+ ########################################
76
+ # Part4: Plot
77
+ ########################################
78
+
79
+ fig, ax = plt.subplots(1, 3, figsize=(15, 6))
80
+ signal_wavelet.plot(ax=ax[0])
81
+ signal_wavelet_jax.plot(ax=ax[1])
82
+ wdm_diff.plot(ax=ax[2])
83
+ ax[0].set_title(f"Numpy SNR={wdm_snr:.2f}")
84
+ ax[1].set_title(f"Jax SNR={wdm_snr_jax:.2f}")
85
+ ax[2].set_title("Difference")
86
+ plt.tight_layout()
87
+ plt.savefig(f"{plot_dir}/jax_vs_np.png")
88
+
89
+
90
+ def test_backend_loader():
91
+ # temporarily set os.environ["PYWAVELET_JAX"] = "1"
92
+
93
+ import pywavelet.backend
94
+
95
+ os.environ["PYWAVELET_JAX"] = "1"
96
+ importlib.reload(pywavelet.backend)
97
+ from pywavelet.backend import use_jax
98
+
99
+ assert use_jax
100
+ os.environ["PYWAVELET_JAX"] = "0"
101
+
102
+ importlib.reload(pywavelet.backend)
103
+ from pywavelet.backend import use_jax
104
+
105
+ assert not use_jax
@@ -1,17 +0,0 @@
1
- from .numpy import (
2
- from_wavelet_to_time,
3
- from_wavelet_to_freq,
4
- from_time_to_wavelet,
5
- from_freq_to_wavelet,
6
- )
7
- from .phi_computer import phi_vec, phitilde_vec_norm, phitilde_vec
8
-
9
- __all__ = [
10
- "from_wavelet_to_time",
11
- "from_wavelet_to_freq",
12
- "from_time_to_wavelet",
13
- "from_freq_to_wavelet",
14
- "phitilde_vec_norm",
15
- "phi_vec",
16
- "phitilde_vec",
17
- ]
File without changes
@@ -1,56 +0,0 @@
1
- import jax
2
- import jax.numpy as jnp
3
- from functools import partial
4
- from jax import jit
5
- from jax.numpy.fft import ifft
6
-
7
- @partial(jit, static_argnames=('Nf', 'Nt'))
8
- def transform_wavelet_freq_helper(
9
- data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
10
- ) -> jnp.ndarray:
11
- # Initially all wrk being done in time-rws, freq-cols
12
- wave = jnp.zeros((Nt, Nf))
13
- f_bins = jnp.arange(Nf)
14
-
15
- i_base = Nt // 2
16
- jj_base = f_bins * Nt // 2
17
-
18
- initial_values = jnp.where(
19
- (f_bins == 0) | (f_bins == Nf),
20
- phif[0] * data[f_bins * Nt // 2] / 2.0,
21
- phif[0] * data[f_bins * Nt // 2]
22
- )
23
-
24
- DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)
25
- DX = DX.at[:, Nt // 2].set(initial_values)
26
-
27
- j_range = jnp.arange(1 - Nt // 2, Nt // 2)
28
- j = jnp.abs(j_range)
29
- i = i_base + j_range
30
-
31
- cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)
32
- cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)
33
- cond3 = j[None, :] == 0
34
-
35
- jj = jj_base[:, None] + j_range[None, :]
36
- val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])
37
- DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))
38
-
39
- # Vectorized ifft
40
- DX_trans = ifft(DX, axis=1)
41
-
42
- # Vectorized __fill_wave_2_jax
43
- n_range = jnp.arange(Nt)
44
- cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1
45
- cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # shape: (Nf, 1)
46
-
47
- real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
48
- imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
49
-
50
- wave = jnp.where(cond1, imag_part.T, real_part.T)
51
-
52
- ## Special cases for f_bin 0 and Nf
53
- wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))
54
- wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))
55
-
56
- return wave.T
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes