pywavelet 0.2.6__tar.gz → 0.2.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pywavelet-0.2.6 → pywavelet-0.2.7}/.gitignore +1 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/CHANGELOG.rst +22 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/PKG-INFO +1 -1
- pywavelet-0.2.7/docs/runtime.ipynb +322 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/_version.py +2 -2
- pywavelet-0.2.7/src/pywavelet/backend.py +101 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/__init__.py +6 -0
- pywavelet-0.2.7/src/pywavelet/transforms/jax/__init__.py +36 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/PKG-INFO +1 -1
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/SOURCES.txt +0 -1
- pywavelet-0.2.7/tests/test_backends.py +42 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_roundtrip_conversion.py +27 -22
- pywavelet-0.2.7/tests/test_snr.py +249 -0
- pywavelet-0.2.7/tests/utils/__init__.py +3 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/utils/conversions.py +9 -3
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/utils/plotting.py +6 -6
- pywavelet-0.2.6/docs/runtime.ipynb +0 -201
- pywavelet-0.2.6/src/pywavelet/backend.py +0 -53
- pywavelet-0.2.6/src/pywavelet/transforms/jax/__init__.py +0 -12
- pywavelet-0.2.6/tests/test_backends.py +0 -136
- pywavelet-0.2.6/tests/test_snr.py +0 -123
- pywavelet-0.2.6/tests/utils/__init__.py +0 -4
- pywavelet-0.2.6/tests/utils/cupy_check.py +0 -14
- {pywavelet-0.2.6 → pywavelet-0.2.7}/.github/workflows/ci.yml +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/.github/workflows/docs.yml +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/.github/workflows/pypi.yml +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/.pre-commit-config.yaml +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/CITATION.cff +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/README.rst +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/_config.yml +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/_static/demo.gif +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/_toc.yml +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/api.rst +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/example.ipynb +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/index.rst +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/logo.png +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/roundtrip_freq.png +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/roundtrip_time.png +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/pyproject.toml +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/setup.cfg +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/logger.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/from_freq.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/from_time.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/main.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/inverse/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/inverse/main.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/inverse/to_freq.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/from_freq.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/from_time.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/main.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/inverse/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/inverse/main.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/inverse/to_freq.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/from_freq.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/from_time.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/main.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/main.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/to_freq.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/to_time.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/phi_computer.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/__init__.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/common.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/frequencyseries.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/plotting.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/timeseries.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/wavelet.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/wavelet_bins.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/utils.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/dependency_links.txt +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/requires.txt +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/top_level.txt +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/conftest.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/ollie_example.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_chirp_freq.npz +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_chirp_time.npz +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_pure_f0_freq.npz +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_sine_freq.npz +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_sine_time.npz +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_docs.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_lnl.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_mask.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_phi.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_psd.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_timefreq_type.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_version.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_wavelet_plot.py +0 -0
- {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/utils/generate_data.py +0 -0
@@ -5,6 +5,19 @@ CHANGELOG
|
|
5
5
|
=========
|
6
6
|
|
7
7
|
|
8
|
+
.. _changelog-v0.2.7:
|
9
|
+
|
10
|
+
v0.2.7 (2025-03-25)
|
11
|
+
===================
|
12
|
+
|
13
|
+
Unknown
|
14
|
+
-------
|
15
|
+
|
16
|
+
* Merge branch 'main' of github.com:pywavelet/pywavelet (`ef5aa31`_)
|
17
|
+
|
18
|
+
.. _ef5aa31: https://github.com/pywavelet/pywavelet/commit/ef5aa31e72c8a838cfb008c03bd47b009df728dc
|
19
|
+
|
20
|
+
|
8
21
|
.. _changelog-v0.2.6:
|
9
22
|
|
10
23
|
v0.2.6 (2025-03-24)
|
@@ -13,8 +26,15 @@ v0.2.6 (2025-03-24)
|
|
13
26
|
Bug Fixes
|
14
27
|
---------
|
15
28
|
|
29
|
+
* fix: add more checking for cupy and jax backends (`75e54c0`_)
|
30
|
+
|
16
31
|
* fix: add note in readme about JAX and cupy (`d3cd8d9`_)
|
17
32
|
|
33
|
+
Chores
|
34
|
+
------
|
35
|
+
|
36
|
+
* chore(release): 0.2.6 (`d2e3cb2`_)
|
37
|
+
|
18
38
|
Unknown
|
19
39
|
-------
|
20
40
|
|
@@ -120,7 +140,9 @@ Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.gi
|
|
120
140
|
|
121
141
|
* add additional JAX tests (`c0d8396`_)
|
122
142
|
|
143
|
+
.. _75e54c0: https://github.com/pywavelet/pywavelet/commit/75e54c00dd9ddfff98a7f0612f54abf1ac17184e
|
123
144
|
.. _d3cd8d9: https://github.com/pywavelet/pywavelet/commit/d3cd8d92b5e6cf398ff3a4948b911533abe842f1
|
145
|
+
.. _d2e3cb2: https://github.com/pywavelet/pywavelet/commit/d2e3cb271b06d1d9bc3060e924fa5bcd57ceb269
|
124
146
|
.. _6cdfb28: https://github.com/pywavelet/pywavelet/commit/6cdfb28b152a7f7ad499f0b6ba6ef69da9284c57
|
125
147
|
.. _14b250b: https://github.com/pywavelet/pywavelet/commit/14b250b55dbea88c3b7c22b5faca531113d34477
|
126
148
|
.. _3d8b4cd: https://github.com/pywavelet/pywavelet/commit/3d8b4cdc8dea6af33b08985cb97f0897984838fc
|
@@ -0,0 +1,322 @@
|
|
1
|
+
{
|
2
|
+
"cells": [
|
3
|
+
{
|
4
|
+
"metadata": {},
|
5
|
+
"cell_type": "markdown",
|
6
|
+
"source": [
|
7
|
+
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/pywavelet/pywavelet/blob/main/docs/runtime.ipynb\">\n",
|
8
|
+
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
|
9
|
+
"</a>\n",
|
10
|
+
"\n",
|
11
|
+
"\n",
|
12
|
+
"# Runtime Comparisons"
|
13
|
+
],
|
14
|
+
"id": "9df79fd10b46e463"
|
15
|
+
},
|
16
|
+
{
|
17
|
+
"cell_type": "code",
|
18
|
+
"id": "initial_id",
|
19
|
+
"metadata": {
|
20
|
+
"collapsed": true
|
21
|
+
},
|
22
|
+
"source": [
|
23
|
+
"import importlib\n",
|
24
|
+
"import numpy as np\n",
|
25
|
+
"import jax.numpy as jnp\n",
|
26
|
+
"import pandas as pd \n",
|
27
|
+
"import jax\n",
|
28
|
+
"from pywavelet.backend import cuda_available\n",
|
29
|
+
"from tqdm.auto import tqdm\n",
|
30
|
+
"from pywavelet.types import FrequencySeries\n",
|
31
|
+
"from pywavelet.transforms.phi_computer import phitilde_vec_norm\n",
|
32
|
+
"from timeit import repeat as timing_repeat\n",
|
33
|
+
"import matplotlib.pyplot as plt\n",
|
34
|
+
"\n",
|
35
|
+
"jax.config.update(\"jax_enable_x64\", False)\n",
|
36
|
+
" \n",
|
37
|
+
"JAX_DEVICE = jax.default_backend()\n",
|
38
|
+
"JAX_PRECISION = \"x64\" if jax.config.jax_enable_x64 else \"x32\"\n",
|
39
|
+
"\n",
|
40
|
+
"\n",
|
41
|
+
"if cuda_available:\n",
|
42
|
+
" import cupy as cp\n",
|
43
|
+
"\n",
|
44
|
+
"\n",
|
45
|
+
"def generate_freq_domain_signal(\n",
|
46
|
+
" ND, f0=20.0, dt=0.0125, A=2\n",
|
47
|
+
") -> FrequencySeries:\n",
|
48
|
+
" \"\"\"\n",
|
49
|
+
" Generates a frequency domain signal.\n",
|
50
|
+
"\n",
|
51
|
+
" Parameters:\n",
|
52
|
+
" ND (int): Number of data points.\n",
|
53
|
+
" f0 (float): Frequency of the signal. Default is 20.0.\n",
|
54
|
+
" dt (float): Time step. Default is 0.0125.\n",
|
55
|
+
" A (float): Amplitude of the signal. Default is 2.\n",
|
56
|
+
"\n",
|
57
|
+
" Returns:\n",
|
58
|
+
" FrequencySeries: The generated frequency domain signal.\n",
|
59
|
+
" \"\"\"\n",
|
60
|
+
" ts = np.arange(0, ND) * dt\n",
|
61
|
+
" y = A * np.sin(2 * np.pi * f0 * ts)\n",
|
62
|
+
" yf = FrequencySeries(y, ts)\n",
|
63
|
+
" return yf\n",
|
64
|
+
"\n",
|
65
|
+
"\n",
|
66
|
+
"def generate_func_args(ND, backend=\"numpy\"):\n",
|
67
|
+
" Nf = Nt = int(np.sqrt(ND))\n",
|
68
|
+
" yf = generate_freq_domain_signal(ND).data\n",
|
69
|
+
" phif = phitilde_vec_norm(Nf, Nt, d=4.0)\n",
|
70
|
+
" if backend == \"jax\":\n",
|
71
|
+
" yf = jnp.array(yf)\n",
|
72
|
+
" phif = jnp.array(phif)\n",
|
73
|
+
" if backend == \"cupy\" and cuda_available:\n",
|
74
|
+
" yf = cp.array(yf)\n",
|
75
|
+
" phif = cp.array(phif)\n",
|
76
|
+
" return yf, Nf, Nt, phif\n",
|
77
|
+
"\n",
|
78
|
+
"\n",
|
79
|
+
"def collect_runtime(\n",
|
80
|
+
" func, func_args, n=5, nreps=5\n",
|
81
|
+
"):\n",
|
82
|
+
" func(*func_args) # Warm up run\n",
|
83
|
+
" times = timing_repeat(\n",
|
84
|
+
" lambda: func(*func_args),\n",
|
85
|
+
" number=n,\n",
|
86
|
+
" repeat=nreps\n",
|
87
|
+
" )\n",
|
88
|
+
"\n",
|
89
|
+
" return (np.median(times), (np.std(times)))\n",
|
90
|
+
"\n",
|
91
|
+
"\n",
|
92
|
+
"def collect_runtimes(func, backend, NF_values, number=5, repeat=5):\n",
|
93
|
+
" results = {}\n",
|
94
|
+
" bar = tqdm(NF_values, desc=\"Running\")\n",
|
95
|
+
" for Nf in bar:\n",
|
96
|
+
" ND = Nf * Nf\n",
|
97
|
+
" bar.set_postfix(ND=f\"2**{int(np.log2(ND))}\")\n",
|
98
|
+
" func_args = generate_func_args(ND, backend)\n",
|
99
|
+
" try:\n",
|
100
|
+
" results[ND] = collect_runtime(func, func_args, number, repeat)\n",
|
101
|
+
" except Exception as e:\n",
|
102
|
+
" print(f\"Error processing ND={ND}: {e}\")\n",
|
103
|
+
" results[ND] = (np.nan, np.nan)\n",
|
104
|
+
" return results\n",
|
105
|
+
"\n",
|
106
|
+
"\n",
|
107
|
+
"def run_transforms():\n",
|
108
|
+
" from pywavelet.transforms.jax.forward.from_freq import transform_wavelet_freq_helper as jax_transform\n",
|
109
|
+
" from pywavelet.transforms.numpy.forward.from_freq import transform_wavelet_freq_helper as np_transform\n",
|
110
|
+
"\n",
|
111
|
+
" min_pow2 = 2\n",
|
112
|
+
" max_pow2 = 12\n",
|
113
|
+
" NF = [2 ** i for i in range(min_pow2, max_pow2)]\n",
|
114
|
+
"\n",
|
115
|
+
" runtimes = {}\n",
|
116
|
+
" runtimes[\"numpy\"] = collect_runtimes(np_transform, \"numpy\", NF, number=5, repeat=5)\n",
|
117
|
+
" \n",
|
118
|
+
" max_pow2 = 15\n",
|
119
|
+
" NF = [2 ** i for i in range(min_pow2, max_pow2)]\n",
|
120
|
+
" runtimes[\"jax\"] = collect_runtimes(jax_transform, \"jax\", NF, number=5, repeat=5)\n",
|
121
|
+
"\n",
|
122
|
+
" if cuda_available:\n",
|
123
|
+
" from pywavelet.transforms.cupy.forward.from_freq import transform_wavelet_freq_helper as cp_transform\n",
|
124
|
+
" runtimes['cupy'] = collect_runtimes(cp_transform, \"cupy\", NF, number=10, repeat=10)\n",
|
125
|
+
" \n",
|
126
|
+
" \n",
|
127
|
+
" return runtimes\n",
|
128
|
+
"\n",
|
129
|
+
"\n",
|
130
|
+
"def plot(runtimes):\n",
|
131
|
+
" fig, ax = plt.subplots(figsize=(4, 3.5))\n",
|
132
|
+
" for i, backend in enumerate(runtimes.keys()):\n",
|
133
|
+
" _plot_backend_runtime(ax, runtimes[backend], backend, f\"C{i}\")\n",
|
134
|
+
" ax.set_yscale(\"log\")\n",
|
135
|
+
" ax.set_xscale(\"log\")\n",
|
136
|
+
" ax.set_xlabel(\"Number of Data Points\")\n",
|
137
|
+
" ax.set_ylabel(\"Runtime (s)\")\n",
|
138
|
+
" ax.legend(frameon=False)\n",
|
139
|
+
" return fig, ax\n",
|
140
|
+
"\n",
|
141
|
+
"\n",
|
142
|
+
"def _plot_backend_runtime(ax, runtimes, backend, color):\n",
|
143
|
+
" NDs = list(runtimes.keys())\n",
|
144
|
+
" times = [runtimes[ND][0] for ND in NDs]\n",
|
145
|
+
" stds = [runtimes[ND][1] for ND in NDs]\n",
|
146
|
+
" # plot a band around the median runtime\n",
|
147
|
+
" ax.fill_between(NDs, np.array(times) - np.array(stds), np.array(times) + np.array(stds), alpha=0.3, color=color)\n",
|
148
|
+
" ax.plot(NDs, times, label=f\"{backend}\", color=color)\n",
|
149
|
+
" ax.set_xlim(min(NDs), max(NDs))\n",
|
150
|
+
"\n",
|
151
|
+
"\n",
|
152
|
+
"runtimes = run_transforms()\n",
|
153
|
+
"# save runtime data as a txt file\n",
|
154
|
+
"runtime_data = pd.DataFrame(runtimes)\n",
|
155
|
+
"runtime_data.to_csv(f\"runtime_data_{JAX_DEVICE}.csv\")\n",
|
156
|
+
" \n",
|
157
|
+
"\n",
|
158
|
+
"fig, ax = plot(runtimes)\n",
|
159
|
+
"fig.savefig(\"runtime.png\", bbox_inches=\"tight\")\n",
|
160
|
+
"\n"
|
161
|
+
],
|
162
|
+
"outputs": [],
|
163
|
+
"execution_count": null
|
164
|
+
},
|
165
|
+
{
|
166
|
+
"metadata": {
|
167
|
+
"ExecuteTime": {
|
168
|
+
"end_time": "2025-03-24T07:53:51.232554Z",
|
169
|
+
"start_time": "2025-03-24T07:53:51.116610Z"
|
170
|
+
}
|
171
|
+
},
|
172
|
+
"cell_type": "code",
|
173
|
+
"source": [
|
174
|
+
"nan = np.nan\n",
|
175
|
+
"\n",
|
176
|
+
"\n",
|
177
|
+
"gpu_runtimes = {'numpy': {16: (np.float64(2.7647000024444424e-05),\n",
|
178
|
+
" np.float64(1.6699892612917253e-05)),\n",
|
179
|
+
" 64: (np.float64(4.7878000032142154e-05), np.float64(7.901239258106135e-06)),\n",
|
180
|
+
" 256: (np.float64(0.00014417199997751595),\n",
|
181
|
+
" np.float64(2.8245273864177316e-05)),\n",
|
182
|
+
" 1024: (np.float64(0.00025147099995592725),\n",
|
183
|
+
" np.float64(8.368707136149373e-06)),\n",
|
184
|
+
" 4096: (np.float64(0.0008757960000593812),\n",
|
185
|
+
" np.float64(0.00039357430227278433)),\n",
|
186
|
+
" 16384: (np.float64(0.0017043970000258923),\n",
|
187
|
+
" np.float64(0.0016274004399110828)),\n",
|
188
|
+
" 65536: (np.float64(0.006220592999966357),\n",
|
189
|
+
" np.float64(0.00021847848737320755)),\n",
|
190
|
+
" 262144: (np.float64(0.023843203999945217),\n",
|
191
|
+
" np.float64(0.0008785219521943771)),\n",
|
192
|
+
" 1048576: (np.float64(0.11971158300002571), np.float64(0.02209812945217875)),\n",
|
193
|
+
" 4194304: (np.float64(0.7133738380000523), np.float64(0.006134709481659322))},\n",
|
194
|
+
" 'jax': {16: (np.float64(0.0005724070000496795),\n",
|
195
|
+
" np.float64(0.0001069087688290125)),\n",
|
196
|
+
" 64: (np.float64(0.0007151319999820771), np.float64(0.0001322757281663352)),\n",
|
197
|
+
" 256: (np.float64(0.0009230550000438598), np.float64(0.00031332690291921287)),\n",
|
198
|
+
" 1024: (np.float64(0.001453614000070047), np.float64(0.00012679139424215402)),\n",
|
199
|
+
" 4096: (np.float64(0.00050203600005716), np.float64(8.404927733069893e-05)),\n",
|
200
|
+
" 16384: (np.float64(0.0007342449999896417), np.float64(6.60918000741e-05)),\n",
|
201
|
+
" 65536: (np.float64(0.0009854250000671527),\n",
|
202
|
+
" np.float64(9.820818468604813e-05)),\n",
|
203
|
+
" 262144: (np.float64(0.0005175009999902613),\n",
|
204
|
+
" np.float64(4.419451355415323e-05)),\n",
|
205
|
+
" 1048576: (np.float64(0.0003900709999697938),\n",
|
206
|
+
" np.float64(7.364248018965731e-05)),\n",
|
207
|
+
" 4194304: (np.float64(0.0004053449999901204),\n",
|
208
|
+
" np.float64(7.663108380403174e-05)),\n",
|
209
|
+
" 16777216: (np.float64(0.00039541899991490936),\n",
|
210
|
+
" np.float64(6.804382786659215e-05)),\n",
|
211
|
+
" 67108864: (np.float64(0.00040576800006419944),\n",
|
212
|
+
" np.float64(0.00010823411767385025)),\n",
|
213
|
+
" 268435456: (np.float64(0.00041394100003344647),\n",
|
214
|
+
" np.float64(0.00010930952627993826))},\n",
|
215
|
+
" 'cupy': {16: (np.float64(0.019849124999950618),\n",
|
216
|
+
" np.float64(0.0012326430575985117)),\n",
|
217
|
+
" 64: (np.float64(0.022546809499999654), np.float64(0.001692504727864258)),\n",
|
218
|
+
" 256: (np.float64(0.021691686499991647), np.float64(0.00294086495457803)),\n",
|
219
|
+
" 1024: (np.float64(0.022270052500005022), np.float64(0.003807071679929202)),\n",
|
220
|
+
" 4096: (np.float64(0.02085154750000129), np.float64(0.0010313077174266435)),\n",
|
221
|
+
" 16384: (np.float64(0.02093089249996183), np.float64(0.0010492676283390795)),\n",
|
222
|
+
" 65536: (np.float64(0.022014402500019514), np.float64(0.0006802123910234058)),\n",
|
223
|
+
" 262144: (np.float64(0.024796750999996675),\n",
|
224
|
+
" np.float64(0.0057645733683490815)),\n",
|
225
|
+
" 1048576: (np.float64(0.02190001700000721),\n",
|
226
|
+
" np.float64(0.0008732466576075845)),\n",
|
227
|
+
" 4194304: (np.float64(0.05879033200005779), np.float64(0.014904883472313363)),\n",
|
228
|
+
" 16777216: (np.float64(0.23404195400001981), np.float64(0.07533314482865239)),\n",
|
229
|
+
" 67108864: (nan, nan),\n",
|
230
|
+
" 268435456: (nan, nan)}}\n",
|
231
|
+
"\n",
|
232
|
+
"cpu_runtimes = {'numpy': {16: (np.float64(9.805700028664432e-05),\n",
|
233
|
+
" np.float64(0.0002719699903221544)),\n",
|
234
|
+
" 64: (np.float64(8.721999984118156e-05), np.float64(3.582454577267885e-05)),\n",
|
235
|
+
" 256: (np.float64(0.00017828199997893535),\n",
|
236
|
+
" np.float64(1.2305009703125784e-05)),\n",
|
237
|
+
" 1024: (np.float64(0.00036261100103729405),\n",
|
238
|
+
" np.float64(2.4494629976772865e-05)),\n",
|
239
|
+
" 4096: (np.float64(0.0006539139994856669), np.float64(0.0008354864099797487)),\n",
|
240
|
+
" 16384: (np.float64(0.002017660999626969), np.float64(0.0005277675716976345)),\n",
|
241
|
+
" 65536: (np.float64(0.007982569000887452), np.float64(0.0007685313919699793)),\n",
|
242
|
+
" 262144: (np.float64(0.03098107999903732),\n",
|
243
|
+
" np.float64(0.00045184700913424144)),\n",
|
244
|
+
" 1048576: (np.float64(0.12877668900000572), np.float64(0.002616350943173567)),\n",
|
245
|
+
" 4194304: (np.float64(0.719053980001263), np.float64(0.0469158455496374))},\n",
|
246
|
+
" 'jax': {16: (np.float64(0.00023953800155140925),\n",
|
247
|
+
" np.float64(0.00015563753017088007)),\n",
|
248
|
+
" 64: (np.float64(0.00035316300090926234), np.float64(0.0002318705268616366)),\n",
|
249
|
+
" 256: (np.float64(0.00029555400033132173), np.float64(8.832475279764218e-05)),\n",
|
250
|
+
" 1024: (np.float64(0.0004884099998889724),\n",
|
251
|
+
" np.float64(0.00024645926639514133)),\n",
|
252
|
+
" 4096: (np.float64(0.0004465590009203879),\n",
|
253
|
+
" np.float64(0.00019698850450255553)),\n",
|
254
|
+
" 16384: (np.float64(0.0005183960001886589),\n",
|
255
|
+
" np.float64(0.00025487691910374257)),\n",
|
256
|
+
" 65536: (np.float64(0.00037294600042514503),\n",
|
257
|
+
" np.float64(0.0008727913545748731)),\n",
|
258
|
+
" 262144: (np.float64(0.0033758919998945203),\n",
|
259
|
+
" np.float64(0.002688101382975047)),\n",
|
260
|
+
" 1048576: (np.float64(0.019741617999898153),\n",
|
261
|
+
" np.float64(0.010757652737059507)),\n",
|
262
|
+
" 4194304: (np.float64(0.10961731399947894), np.float64(0.05839073400005447))},\n",
|
263
|
+
" 'tpu_jax': {16: (np.float64(0.0001429589999588643),\n",
|
264
|
+
" np.float64(3.5459720815065046e-05)),\n",
|
265
|
+
" 64: (np.float64(0.00017167299995435314), np.float64(4.2662062404195546e-05)),\n",
|
266
|
+
" 256: (np.float64(0.00014238400001431728),\n",
|
267
|
+
" np.float64(2.8613417994725626e-05)),\n",
|
268
|
+
" 1024: (np.float64(0.00016977199993561953),\n",
|
269
|
+
" np.float64(0.0002816987455046532)),\n",
|
270
|
+
" 4096: (np.float64(0.00015566100000796723),\n",
|
271
|
+
" np.float64(4.2144860721421086e-05)),\n",
|
272
|
+
" 16384: (np.float64(0.00015526099991802766),\n",
|
273
|
+
" np.float64(3.326513666243062e-05)),\n",
|
274
|
+
" 65536: (np.float64(0.00011258099993938231),\n",
|
275
|
+
" np.float64(5.987320498690532e-05)),\n",
|
276
|
+
" 262144: (np.float64(0.00011269100002664345),\n",
|
277
|
+
" np.float64(3.9259844590263475e-05)),\n",
|
278
|
+
" 1048576: (np.float64(0.00012115300000914431),\n",
|
279
|
+
" np.float64(4.568541504776411e-05)),\n",
|
280
|
+
" 4194304: (np.float64(0.00013279199993121438),\n",
|
281
|
+
" np.float64(3.920358955286624e-05)),\n",
|
282
|
+
" 16777216: (np.float64(0.00012158700008058076),\n",
|
283
|
+
" np.float64(2.9125096906138658e-05)),\n",
|
284
|
+
" 67108864: (np.float64(0.0002409089998991476), np.float64(2.604900528570504)),\n",
|
285
|
+
" 268435456: (nan, nan)}}\n",
|
286
|
+
"\n"
|
287
|
+
],
|
288
|
+
"id": "7b1a6e9fd3a1955d",
|
289
|
+
"outputs": [],
|
290
|
+
"execution_count": 24
|
291
|
+
},
|
292
|
+
{
|
293
|
+
"metadata": {},
|
294
|
+
"cell_type": "code",
|
295
|
+
"outputs": [],
|
296
|
+
"execution_count": null,
|
297
|
+
"source": "",
|
298
|
+
"id": "59d0660f82642a3"
|
299
|
+
}
|
300
|
+
],
|
301
|
+
"metadata": {
|
302
|
+
"kernelspec": {
|
303
|
+
"display_name": "Python 3",
|
304
|
+
"language": "python",
|
305
|
+
"name": "python3"
|
306
|
+
},
|
307
|
+
"language_info": {
|
308
|
+
"codemirror_mode": {
|
309
|
+
"name": "ipython",
|
310
|
+
"version": 2
|
311
|
+
},
|
312
|
+
"file_extension": ".py",
|
313
|
+
"mimetype": "text/x-python",
|
314
|
+
"name": "python",
|
315
|
+
"nbconvert_exporter": "python",
|
316
|
+
"pygments_lexer": "ipython2",
|
317
|
+
"version": "2.7.6"
|
318
|
+
}
|
319
|
+
},
|
320
|
+
"nbformat": 4,
|
321
|
+
"nbformat_minor": 5
|
322
|
+
}
|
@@ -0,0 +1,101 @@
|
|
1
|
+
import importlib
|
2
|
+
import os
|
3
|
+
from rich.table import Table, Text
|
4
|
+
from rich.console import Console
|
5
|
+
|
6
|
+
|
7
|
+
|
8
|
+
from .logger import logger
|
9
|
+
|
10
|
+
# Canonical backend identifiers, compared against the (lower-cased)
# PYWAVELET_BACKEND environment variable when selecting a backend.
JAX, CUPY, NUMPY = "jax", "cupy", "numpy"
|
13
|
+
|
14
|
+
|
15
|
+
def cuda_is_available() -> bool:
    """Check whether CuPy is importable and at least one CUDA device is visible.

    Returns
    -------
    bool
        ``True`` only if CuPy is installed, imports cleanly, and the CUDA
        runtime reports one or more devices; ``False`` otherwise.
    """
    # `import importlib` alone does not guarantee the `util` submodule is
    # loaded, so import it explicitly.
    import importlib.util

    if importlib.util.find_spec("cupy") is None:
        return False

    try:
        import cupy
    except ImportError:
        # CuPy is installed but unusable (e.g. built against missing CUDA libs).
        return False

    try:
        n_devices = cupy.cuda.runtime.getDeviceCount()
    except cupy.cuda.runtime.CUDARuntimeError:
        return False
    return n_devices > 0
|
30
|
+
|
31
|
+
|
32
|
+
def jax_is_available() -> bool:
    """Check whether the JAX package can be found on the import path.

    This only locates the package spec; it does not import JAX.

    Returns
    -------
    bool
        ``True`` if a ``jax`` package is installed, ``False`` otherwise.
    """
    # `import importlib` alone does not guarantee `importlib.util` exists.
    import importlib.util

    return importlib.util.find_spec(JAX) is not None
|
35
|
+
|
36
|
+
|
37
|
+
def get_available_backends_table():
    """Build a rich table showing which array backends can be used.

    The table is rendered off-screen via ``Console.capture`` and returned
    rather than printed, so the caller decides where it goes (for example
    ``print(get_available_backends_table())``).

    Returns
    -------
    rich.text.Text
        The rendered table, re-parsed from its captured ANSI output.
    """
    true_check = "[green]✓[/green]"
    false_check = "[red]✗[/red]"

    # Probe JAX first, then CuPy/CUDA; NumPy is always reported as available.
    availability = {
        JAX: jax_is_available(),
        CUPY: cuda_is_available(),
        NUMPY: True,
    }

    table = Table("Backend", "Available", title="Available backends")
    for name, is_available in availability.items():
        table.add_row(name, true_check if is_available else false_check)

    console = Console(width=150)
    with console.capture() as capture:
        console.print(table)
    return Text.from_ansi(capture.get())
|
52
|
+
|
53
|
+
|
54
|
+
def get_backend_from_env():
    """Select the array backend named by ``PYWAVELET_BACKEND`` and load it.

    The environment variable is case- and whitespace-insensitive and defaults
    to ``"numpy"``. If the requested backend is unknown, not installed, has no
    usable CUDA device (CuPy), or fails to import, an error is logged, the
    availability table is printed, ``PYWAVELET_BACKEND`` is overwritten with
    ``"numpy"``, and the NumPy backend is loaded instead.

    Returns
    -------
    tuple
        ``(xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend)`` where
        ``xp`` is the array module, the next five are its FFT functions,
        ``betainc`` is the matching regularized incomplete beta function,
        and ``backend`` is the normalized backend name actually selected.
    """
    backend = os.getenv("PYWAVELET_BACKEND", NUMPY).strip().lower()

    if backend == NUMPY:
        import numpy as xp
        from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
        from scipy.special import betainc

        logger.info("Using NumPy backend")
        return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend

    try:
        if backend == JAX and jax_is_available():
            import jax.numpy as xp
            from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
            from jax.scipy.special import betainc

            logger.info("Using JAX backend")
            return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend

        if backend == CUPY and cuda_is_available():
            import cupy as xp
            from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
            from cupyx.scipy.special import betainc

            logger.info("Using CuPy backend")
            return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
    except ImportError as err:
        # Package found but unusable (e.g. broken install): fall back to NumPy
        # below instead of crashing at package import time.
        logger.error(f"Could not import the {backend} backend: {err}")

    logger.error(f"Backend {backend} is not available. ")
    print(get_available_backends_table())
    logger.warning("Setting backend to NumPy. ")
    os.environ["PYWAVELET_BACKEND"] = NUMPY
    # The recursion terminates: the env var now selects the NumPy branch.
    return get_backend_from_env()
|
94
|
+
|
95
|
+
|
96
|
+
# Probe once, at import time, whether a usable CUDA device is present.
cuda_available = cuda_is_available()

# Resolve the backend requested through PYWAVELET_BACKEND (NumPy fallback) and
# expose the array module, FFT helpers, betainc and the backend name as
# module-level names.
(
    xp,
    fft,
    ifft,
    irfft,
    rfft,
    rfftfreq,
    betainc,
    current_backend,
) = get_backend_from_env()
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from ...logger import logger
|
2
|
+
from .forward import from_freq_to_wavelet, from_time_to_wavelet
|
3
|
+
from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
4
|
+
|
5
|
+
# Emitted when this subpackage is first imported: the JAX implementation is
# flagged as not yet fully tested.
logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
|
6
|
+
|
7
|
+
|
8
|
+
def _log_jax_info():
    """Log the device JAX runs on and whether 64-bit precision is enabled.

    Reads ``jax.default_backend()`` (e.g. "cpu", "gpu", "tpu") and the
    ``jax_enable_x64`` config flag, logs both at INFO level, and logs a
    warning with instructions when JAX is running in 32-bit precision.

    Takes no arguments and returns ``None``.
    """
    import jax

    _backend = jax.default_backend()
    _precision = "64bit" if jax.config.jax_enable_x64 else "32bit"

    logger.info(f"Jax running on {_backend} [{_precision} precision].")
    if _precision == "32bit":
        logger.warning(
            "Jax is not running in 64bit precision. "
            "To change, use jax.config.update('jax_enable_x64', True)."
        )


# Report the JAX configuration when the subpackage is imported.
_log_jax_info()
|
30
|
+
|
31
|
+
# Public API: inverse (wavelet -> time/freq) and forward (time/freq -> wavelet)
# transforms re-exported from the .inverse and .forward subpackages.
__all__ = [
    "from_wavelet_to_time", "from_wavelet_to_freq",
    "from_time_to_wavelet", "from_freq_to_wavelet",
]
|
@@ -0,0 +1,42 @@
|
|
1
|
+
import importlib
|
2
|
+
import os
|
3
|
+
|
4
|
+
import matplotlib.pyplot as plt
|
5
|
+
import numpy as np
|
6
|
+
import pytest
|
7
|
+
from utils import cuda_available
|
8
|
+
|
9
|
+
from pywavelet import set_backend
|
10
|
+
from pywavelet.transforms import from_freq_to_wavelet
|
11
|
+
from pywavelet.types import FrequencySeries, TimeSeries
|
12
|
+
from pywavelet.utils import compute_snr, evolutionary_psd_from_stationary_psd
|
13
|
+
|
14
|
+
|
15
|
+
|
16
|
+
@pytest.mark.parametrize("backend", ["jax", "cupy", "numpy"])
def test_backend_loader(backend):
    """Selecting an available backend makes it the current backend.

    Backends that cannot be used in this environment are skipped rather than
    failed, and the backend is always reset to NumPy afterwards so later
    tests are unaffected even if the assertion fails.
    """
    if backend == "cupy" and not cuda_available:
        pytest.skip("CUDA is not available")
    if backend == "jax":
        # Without JAX installed the loader would fall back to NumPy.
        pytest.importorskip("jax")

    set_backend(backend)
    try:
        from pywavelet.backend import current_backend

        assert current_backend == backend
    finally:
        set_backend("numpy")
|
27
|
+
|
28
|
+
def test_backend_fails_gracefully_if_no_cupy():
    """Requesting CuPy on a machine without CUDA must fall back to NumPy."""
    if not cuda_available:
        set_backend("cupy")
        from pywavelet.backend import current_backend

        assert current_backend == "numpy"
        return
    pytest.skip("CUDA is available")
|
36
|
+
|
37
|
+
|
38
|
+
def test_backed_logger():
    """The availability table renders and lists every supported backend."""
    from pywavelet.backend import get_available_backends_table

    table = get_available_backends_table()
    print(table)

    rendered = table.plain
    for name in ("jax", "cupy", "numpy"):
        assert name in rendered
|
41
|
+
|
42
|
+
|