pywavelet 0.2.6__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. {pywavelet-0.2.6 → pywavelet-0.2.7}/.gitignore +1 -0
  2. {pywavelet-0.2.6 → pywavelet-0.2.7}/CHANGELOG.rst +22 -0
  3. {pywavelet-0.2.6 → pywavelet-0.2.7}/PKG-INFO +1 -1
  4. pywavelet-0.2.7/docs/runtime.ipynb +322 -0
  5. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/_version.py +2 -2
  6. pywavelet-0.2.7/src/pywavelet/backend.py +101 -0
  7. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/__init__.py +6 -0
  8. pywavelet-0.2.7/src/pywavelet/transforms/jax/__init__.py +36 -0
  9. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/PKG-INFO +1 -1
  10. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/SOURCES.txt +0 -1
  11. pywavelet-0.2.7/tests/test_backends.py +42 -0
  12. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_roundtrip_conversion.py +27 -22
  13. pywavelet-0.2.7/tests/test_snr.py +249 -0
  14. pywavelet-0.2.7/tests/utils/__init__.py +3 -0
  15. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/utils/conversions.py +9 -3
  16. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/utils/plotting.py +6 -6
  17. pywavelet-0.2.6/docs/runtime.ipynb +0 -201
  18. pywavelet-0.2.6/src/pywavelet/backend.py +0 -53
  19. pywavelet-0.2.6/src/pywavelet/transforms/jax/__init__.py +0 -12
  20. pywavelet-0.2.6/tests/test_backends.py +0 -136
  21. pywavelet-0.2.6/tests/test_snr.py +0 -123
  22. pywavelet-0.2.6/tests/utils/__init__.py +0 -4
  23. pywavelet-0.2.6/tests/utils/cupy_check.py +0 -14
  24. {pywavelet-0.2.6 → pywavelet-0.2.7}/.github/workflows/ci.yml +0 -0
  25. {pywavelet-0.2.6 → pywavelet-0.2.7}/.github/workflows/docs.yml +0 -0
  26. {pywavelet-0.2.6 → pywavelet-0.2.7}/.github/workflows/pypi.yml +0 -0
  27. {pywavelet-0.2.6 → pywavelet-0.2.7}/.pre-commit-config.yaml +0 -0
  28. {pywavelet-0.2.6 → pywavelet-0.2.7}/CITATION.cff +0 -0
  29. {pywavelet-0.2.6 → pywavelet-0.2.7}/README.rst +0 -0
  30. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/_config.yml +0 -0
  31. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/_static/demo.gif +0 -0
  32. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/_toc.yml +0 -0
  33. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/api.rst +0 -0
  34. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/example.ipynb +0 -0
  35. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/index.rst +0 -0
  36. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/logo.png +0 -0
  37. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/roundtrip_freq.png +0 -0
  38. {pywavelet-0.2.6 → pywavelet-0.2.7}/docs/roundtrip_time.png +0 -0
  39. {pywavelet-0.2.6 → pywavelet-0.2.7}/pyproject.toml +0 -0
  40. {pywavelet-0.2.6 → pywavelet-0.2.7}/setup.cfg +0 -0
  41. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/__init__.py +0 -0
  42. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/logger.py +0 -0
  43. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/__init__.py +0 -0
  44. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/__init__.py +0 -0
  45. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/from_freq.py +0 -0
  46. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/from_time.py +0 -0
  47. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/forward/main.py +0 -0
  48. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/inverse/__init__.py +0 -0
  49. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/inverse/main.py +0 -0
  50. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/cupy/inverse/to_freq.py +0 -0
  51. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/__init__.py +0 -0
  52. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/from_freq.py +0 -0
  53. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/from_time.py +0 -0
  54. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/forward/main.py +0 -0
  55. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/inverse/__init__.py +0 -0
  56. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/inverse/main.py +0 -0
  57. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/jax/inverse/to_freq.py +0 -0
  58. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/__init__.py +0 -0
  59. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/__init__.py +0 -0
  60. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/from_freq.py +0 -0
  61. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/from_time.py +0 -0
  62. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/forward/main.py +0 -0
  63. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/__init__.py +0 -0
  64. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/main.py +0 -0
  65. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/to_freq.py +0 -0
  66. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/numpy/inverse/to_time.py +0 -0
  67. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/transforms/phi_computer.py +0 -0
  68. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/__init__.py +0 -0
  69. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/common.py +0 -0
  70. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/frequencyseries.py +0 -0
  71. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/plotting.py +0 -0
  72. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/timeseries.py +0 -0
  73. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/wavelet.py +0 -0
  74. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/types/wavelet_bins.py +0 -0
  75. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet/utils.py +0 -0
  76. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/dependency_links.txt +0 -0
  77. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/requires.txt +0 -0
  78. {pywavelet-0.2.6 → pywavelet-0.2.7}/src/pywavelet.egg-info/top_level.txt +0 -0
  79. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/conftest.py +0 -0
  80. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/ollie_example.py +0 -0
  81. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_chirp_freq.npz +0 -0
  82. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_chirp_time.npz +0 -0
  83. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_pure_f0_freq.npz +0 -0
  84. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_sine_freq.npz +0 -0
  85. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_data/roundtrip_sine_time.npz +0 -0
  86. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_docs.py +0 -0
  87. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_lnl.py +0 -0
  88. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_mask.py +0 -0
  89. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_phi.py +0 -0
  90. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_psd.py +0 -0
  91. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_timefreq_type.py +0 -0
  92. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_version.py +0 -0
  93. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/test_wavelet_plot.py +0 -0
  94. {pywavelet-0.2.6 → pywavelet-0.2.7}/tests/utils/generate_data.py +0 -0
@@ -119,3 +119,4 @@ venv.bak/
119
119
 
120
120
  # VSCode
121
121
  .vscode/
122
+ /docs/runtime_data_cpu.csv
@@ -5,6 +5,19 @@ CHANGELOG
5
5
  =========
6
6
 
7
7
 
8
+ .. _changelog-v0.2.7:
9
+
10
+ v0.2.7 (2025-03-25)
11
+ ===================
12
+
13
+ Unknown
14
+ -------
15
+
16
+ * Merge branch 'main' of github.com:pywavelet/pywavelet (`ef5aa31`_)
17
+
18
+ .. _ef5aa31: https://github.com/pywavelet/pywavelet/commit/ef5aa31e72c8a838cfb008c03bd47b009df728dc
19
+
20
+
8
21
  .. _changelog-v0.2.6:
9
22
 
10
23
  v0.2.6 (2025-03-24)
@@ -13,8 +26,15 @@ v0.2.6 (2025-03-24)
13
26
  Bug Fixes
14
27
  ---------
15
28
 
29
+ * fix: add more checking for cupy and jax backends (`75e54c0`_)
30
+
16
31
  * fix: add note in readme about JAX and cupy (`d3cd8d9`_)
17
32
 
33
+ Chores
34
+ ------
35
+
36
+ * chore(release): 0.2.6 (`d2e3cb2`_)
37
+
18
38
  Unknown
19
39
  -------
20
40
 
@@ -120,7 +140,9 @@ Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.gi
120
140
 
121
141
  * add additional JAX tests (`c0d8396`_)
122
142
 
143
+ .. _75e54c0: https://github.com/pywavelet/pywavelet/commit/75e54c00dd9ddfff98a7f0612f54abf1ac17184e
123
144
  .. _d3cd8d9: https://github.com/pywavelet/pywavelet/commit/d3cd8d92b5e6cf398ff3a4948b911533abe842f1
145
+ .. _d2e3cb2: https://github.com/pywavelet/pywavelet/commit/d2e3cb271b06d1d9bc3060e924fa5bcd57ceb269
124
146
  .. _6cdfb28: https://github.com/pywavelet/pywavelet/commit/6cdfb28b152a7f7ad499f0b6ba6ef69da9284c57
125
147
  .. _14b250b: https://github.com/pywavelet/pywavelet/commit/14b250b55dbea88c3b7c22b5faca531113d34477
126
148
  .. _3d8b4cd: https://github.com/pywavelet/pywavelet/commit/3d8b4cdc8dea6af33b08985cb97f0897984838fc
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pywavelet
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -0,0 +1,322 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {},
5
+ "cell_type": "markdown",
6
+ "source": [
7
+ "<a target=\"_blank\" href=\"https://colab.research.google.com/github/pywavelet/pywavelet/blob/main/docs/runtime.ipynb\">\n",
8
+ " <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
9
+ "</a>\n",
10
+ "\n",
11
+ "\n",
12
+ "# Runtime Comparisons"
13
+ ],
14
+ "id": "9df79fd10b46e463"
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "id": "initial_id",
19
+ "metadata": {
20
+ "collapsed": true
21
+ },
22
+ "source": [
23
+ "import importlib\n",
24
+ "import numpy as np\n",
25
+ "import jax.numpy as jnp\n",
26
+ "import pandas as pd \n",
27
+ "import jax\n",
28
+ "from pywavelet.backend import cuda_available\n",
29
+ "from tqdm.auto import tqdm\n",
30
+ "from pywavelet.types import FrequencySeries\n",
31
+ "from pywavelet.transforms.phi_computer import phitilde_vec_norm\n",
32
+ "from timeit import repeat as timing_repeat\n",
33
+ "import matplotlib.pyplot as plt\n",
34
+ "\n",
35
+ "jax.config.update(\"jax_enable_x64\", False)\n",
36
+ " \n",
37
+ "JAX_DEVICE = jax.default_backend()\n",
38
+ "JAX_PRECISION = \"x64\" if jax.config.jax_enable_x64 else \"x32\"\n",
39
+ "\n",
40
+ "\n",
41
+ "if cuda_available:\n",
42
+ " import cupy as cp\n",
43
+ "\n",
44
+ "\n",
45
+ "def generate_freq_domain_signal(\n",
46
+ " ND, f0=20.0, dt=0.0125, A=2\n",
47
+ ") -> FrequencySeries:\n",
48
+ " \"\"\"\n",
49
+ " Generates a frequency domain signal.\n",
50
+ "\n",
51
+ " Parameters:\n",
52
+ " ND (int): Number of data points.\n",
53
+ " f0 (float): Frequency of the signal. Default is 20.0.\n",
54
+ " dt (float): Time step. Default is 0.0125.\n",
55
+ " A (float): Amplitude of the signal. Default is 2.\n",
56
+ "\n",
57
+ " Returns:\n",
58
+ " FrequencySeries: The generated frequency domain signal.\n",
59
+ " \"\"\"\n",
60
+ " ts = np.arange(0, ND) * dt\n",
61
+ " y = A * np.sin(2 * np.pi * f0 * ts)\n",
62
+ " yf = FrequencySeries(y, ts)\n",
63
+ " return yf\n",
64
+ "\n",
65
+ "\n",
66
+ "def generate_func_args(ND, backend=\"numpy\"):\n",
67
+ " Nf = Nt = int(np.sqrt(ND))\n",
68
+ " yf = generate_freq_domain_signal(ND).data\n",
69
+ " phif = phitilde_vec_norm(Nf, Nt, d=4.0)\n",
70
+ " if backend == \"jax\":\n",
71
+ " yf = jnp.array(yf)\n",
72
+ " phif = jnp.array(phif)\n",
73
+ " if backend == \"cupy\" and cuda_available:\n",
74
+ " yf = cp.array(yf)\n",
75
+ " phif = cp.array(phif)\n",
76
+ " return yf, Nf, Nt, phif\n",
77
+ "\n",
78
+ "\n",
79
+ "def collect_runtime(\n",
80
+ " func, func_args, n=5, nreps=5\n",
81
+ "):\n",
82
+ " func(*func_args) # Warm up run\n",
83
+ " times = timing_repeat(\n",
84
+ " lambda: func(*func_args),\n",
85
+ " number=n,\n",
86
+ " repeat=nreps\n",
87
+ " )\n",
88
+ "\n",
89
+ " return (np.median(times), (np.std(times)))\n",
90
+ "\n",
91
+ "\n",
92
+ "def collect_runtimes(func, backend, NF_values, number=5, repeat=5):\n",
93
+ " results = {}\n",
94
+ " bar = tqdm(NF_values, desc=\"Running\")\n",
95
+ " for Nf in bar:\n",
96
+ " ND = Nf * Nf\n",
97
+ " bar.set_postfix(ND=f\"2**{int(np.log2(ND))}\")\n",
98
+ " func_args = generate_func_args(ND, backend)\n",
99
+ " try:\n",
100
+ " results[ND] = collect_runtime(func, func_args, number, repeat)\n",
101
+ " except Exception as e:\n",
102
+ " print(f\"Error processing ND={ND}: {e}\")\n",
103
+ " results[ND] = (np.nan, np.nan)\n",
104
+ " return results\n",
105
+ "\n",
106
+ "\n",
107
+ "def run_transforms():\n",
108
+ " from pywavelet.transforms.jax.forward.from_freq import transform_wavelet_freq_helper as jax_transform\n",
109
+ " from pywavelet.transforms.numpy.forward.from_freq import transform_wavelet_freq_helper as np_transform\n",
110
+ "\n",
111
+ " min_pow2 = 2\n",
112
+ " max_pow2 = 12\n",
113
+ " NF = [2 ** i for i in range(min_pow2, max_pow2)]\n",
114
+ "\n",
115
+ " runtimes = {}\n",
116
+ " runtimes[\"numpy\"] = collect_runtimes(np_transform, \"numpy\", NF, number=5, repeat=5)\n",
117
+ " \n",
118
+ " max_pow2 = 15\n",
119
+ " NF = [2 ** i for i in range(min_pow2, max_pow2)]\n",
120
+ " runtimes[\"jax\"] = collect_runtimes(jax_transform, \"jax\", NF, number=5, repeat=5)\n",
121
+ "\n",
122
+ " if cuda_available:\n",
123
+ " from pywavelet.transforms.cupy.forward.from_freq import transform_wavelet_freq_helper as cp_transform\n",
124
+ " runtimes['cupy'] = collect_runtimes(cp_transform, \"cupy\", NF, number=10, repeat=10)\n",
125
+ " \n",
126
+ " \n",
127
+ " return runtimes\n",
128
+ "\n",
129
+ "\n",
130
+ "def plot(runtimes):\n",
131
+ " fig, ax = plt.subplots(figsize=(4, 3.5))\n",
132
+ " for i, backend in enumerate(runtimes.keys()):\n",
133
+ " _plot_backend_runtime(ax, runtimes[backend], backend, f\"C{i}\")\n",
134
+ " ax.set_yscale(\"log\")\n",
135
+ " ax.set_xscale(\"log\")\n",
136
+ " ax.set_xlabel(\"Number of Data Points\")\n",
137
+ " ax.set_ylabel(\"Runtime (s)\")\n",
138
+ " ax.legend(frameon=False)\n",
139
+ " return fig, ax\n",
140
+ "\n",
141
+ "\n",
142
+ "def _plot_backend_runtime(ax, runtimes, backend, color):\n",
143
+ " NDs = list(runtimes.keys())\n",
144
+ " times = [runtimes[ND][0] for ND in NDs]\n",
145
+ " stds = [runtimes[ND][1] for ND in NDs]\n",
146
+ " # plot a band around the median runtime\n",
147
+ " ax.fill_between(NDs, np.array(times) - np.array(stds), np.array(times) + np.array(stds), alpha=0.3, color=color)\n",
148
+ " ax.plot(NDs, times, label=f\"{backend}\", color=color)\n",
149
+ " ax.set_xlim(min(NDs), max(NDs))\n",
150
+ "\n",
151
+ "\n",
152
+ "runtimes = run_transforms()\n",
153
+ "# save runtime data as a CSV file\n",
154
+ "runtime_data = pd.DataFrame(runtimes)\n",
155
+ "runtime_data.to_csv(f\"runtime_data_{JAX_DEVICE}.csv\")\n",
156
+ " \n",
157
+ "\n",
158
+ "fig, ax = plot(runtimes)\n",
159
+ "fig.savefig(\"runtime.png\", bbox_inches=\"tight\")\n",
160
+ "\n"
161
+ ],
162
+ "outputs": [],
163
+ "execution_count": null
164
+ },
165
+ {
166
+ "metadata": {
167
+ "ExecuteTime": {
168
+ "end_time": "2025-03-24T07:53:51.232554Z",
169
+ "start_time": "2025-03-24T07:53:51.116610Z"
170
+ }
171
+ },
172
+ "cell_type": "code",
173
+ "source": [
174
+ "nan = np.nan\n",
175
+ "\n",
176
+ "\n",
177
+ "gpu_runtimes = {'numpy': {16: (np.float64(2.7647000024444424e-05),\n",
178
+ " np.float64(1.6699892612917253e-05)),\n",
179
+ " 64: (np.float64(4.7878000032142154e-05), np.float64(7.901239258106135e-06)),\n",
180
+ " 256: (np.float64(0.00014417199997751595),\n",
181
+ " np.float64(2.8245273864177316e-05)),\n",
182
+ " 1024: (np.float64(0.00025147099995592725),\n",
183
+ " np.float64(8.368707136149373e-06)),\n",
184
+ " 4096: (np.float64(0.0008757960000593812),\n",
185
+ " np.float64(0.00039357430227278433)),\n",
186
+ " 16384: (np.float64(0.0017043970000258923),\n",
187
+ " np.float64(0.0016274004399110828)),\n",
188
+ " 65536: (np.float64(0.006220592999966357),\n",
189
+ " np.float64(0.00021847848737320755)),\n",
190
+ " 262144: (np.float64(0.023843203999945217),\n",
191
+ " np.float64(0.0008785219521943771)),\n",
192
+ " 1048576: (np.float64(0.11971158300002571), np.float64(0.02209812945217875)),\n",
193
+ " 4194304: (np.float64(0.7133738380000523), np.float64(0.006134709481659322))},\n",
194
+ " 'jax': {16: (np.float64(0.0005724070000496795),\n",
195
+ " np.float64(0.0001069087688290125)),\n",
196
+ " 64: (np.float64(0.0007151319999820771), np.float64(0.0001322757281663352)),\n",
197
+ " 256: (np.float64(0.0009230550000438598), np.float64(0.00031332690291921287)),\n",
198
+ " 1024: (np.float64(0.001453614000070047), np.float64(0.00012679139424215402)),\n",
199
+ " 4096: (np.float64(0.00050203600005716), np.float64(8.404927733069893e-05)),\n",
200
+ " 16384: (np.float64(0.0007342449999896417), np.float64(6.60918000741e-05)),\n",
201
+ " 65536: (np.float64(0.0009854250000671527),\n",
202
+ " np.float64(9.820818468604813e-05)),\n",
203
+ " 262144: (np.float64(0.0005175009999902613),\n",
204
+ " np.float64(4.419451355415323e-05)),\n",
205
+ " 1048576: (np.float64(0.0003900709999697938),\n",
206
+ " np.float64(7.364248018965731e-05)),\n",
207
+ " 4194304: (np.float64(0.0004053449999901204),\n",
208
+ " np.float64(7.663108380403174e-05)),\n",
209
+ " 16777216: (np.float64(0.00039541899991490936),\n",
210
+ " np.float64(6.804382786659215e-05)),\n",
211
+ " 67108864: (np.float64(0.00040576800006419944),\n",
212
+ " np.float64(0.00010823411767385025)),\n",
213
+ " 268435456: (np.float64(0.00041394100003344647),\n",
214
+ " np.float64(0.00010930952627993826))},\n",
215
+ " 'cupy': {16: (np.float64(0.019849124999950618),\n",
216
+ " np.float64(0.0012326430575985117)),\n",
217
+ " 64: (np.float64(0.022546809499999654), np.float64(0.001692504727864258)),\n",
218
+ " 256: (np.float64(0.021691686499991647), np.float64(0.00294086495457803)),\n",
219
+ " 1024: (np.float64(0.022270052500005022), np.float64(0.003807071679929202)),\n",
220
+ " 4096: (np.float64(0.02085154750000129), np.float64(0.0010313077174266435)),\n",
221
+ " 16384: (np.float64(0.02093089249996183), np.float64(0.0010492676283390795)),\n",
222
+ " 65536: (np.float64(0.022014402500019514), np.float64(0.0006802123910234058)),\n",
223
+ " 262144: (np.float64(0.024796750999996675),\n",
224
+ " np.float64(0.0057645733683490815)),\n",
225
+ " 1048576: (np.float64(0.02190001700000721),\n",
226
+ " np.float64(0.0008732466576075845)),\n",
227
+ " 4194304: (np.float64(0.05879033200005779), np.float64(0.014904883472313363)),\n",
228
+ " 16777216: (np.float64(0.23404195400001981), np.float64(0.07533314482865239)),\n",
229
+ " 67108864: (nan, nan),\n",
230
+ " 268435456: (nan, nan)}}\n",
231
+ "\n",
232
+ "cpu_runtimes = {'numpy': {16: (np.float64(9.805700028664432e-05),\n",
233
+ " np.float64(0.0002719699903221544)),\n",
234
+ " 64: (np.float64(8.721999984118156e-05), np.float64(3.582454577267885e-05)),\n",
235
+ " 256: (np.float64(0.00017828199997893535),\n",
236
+ " np.float64(1.2305009703125784e-05)),\n",
237
+ " 1024: (np.float64(0.00036261100103729405),\n",
238
+ " np.float64(2.4494629976772865e-05)),\n",
239
+ " 4096: (np.float64(0.0006539139994856669), np.float64(0.0008354864099797487)),\n",
240
+ " 16384: (np.float64(0.002017660999626969), np.float64(0.0005277675716976345)),\n",
241
+ " 65536: (np.float64(0.007982569000887452), np.float64(0.0007685313919699793)),\n",
242
+ " 262144: (np.float64(0.03098107999903732),\n",
243
+ " np.float64(0.00045184700913424144)),\n",
244
+ " 1048576: (np.float64(0.12877668900000572), np.float64(0.002616350943173567)),\n",
245
+ " 4194304: (np.float64(0.719053980001263), np.float64(0.0469158455496374))},\n",
246
+ " 'jax': {16: (np.float64(0.00023953800155140925),\n",
247
+ " np.float64(0.00015563753017088007)),\n",
248
+ " 64: (np.float64(0.00035316300090926234), np.float64(0.0002318705268616366)),\n",
249
+ " 256: (np.float64(0.00029555400033132173), np.float64(8.832475279764218e-05)),\n",
250
+ " 1024: (np.float64(0.0004884099998889724),\n",
251
+ " np.float64(0.00024645926639514133)),\n",
252
+ " 4096: (np.float64(0.0004465590009203879),\n",
253
+ " np.float64(0.00019698850450255553)),\n",
254
+ " 16384: (np.float64(0.0005183960001886589),\n",
255
+ " np.float64(0.00025487691910374257)),\n",
256
+ " 65536: (np.float64(0.00037294600042514503),\n",
257
+ " np.float64(0.0008727913545748731)),\n",
258
+ " 262144: (np.float64(0.0033758919998945203),\n",
259
+ " np.float64(0.002688101382975047)),\n",
260
+ " 1048576: (np.float64(0.019741617999898153),\n",
261
+ " np.float64(0.010757652737059507)),\n",
262
+ " 4194304: (np.float64(0.10961731399947894), np.float64(0.05839073400005447))},\n",
263
+ " 'tpu_jax': {16: (np.float64(0.0001429589999588643),\n",
264
+ " np.float64(3.5459720815065046e-05)),\n",
265
+ " 64: (np.float64(0.00017167299995435314), np.float64(4.2662062404195546e-05)),\n",
266
+ " 256: (np.float64(0.00014238400001431728),\n",
267
+ " np.float64(2.8613417994725626e-05)),\n",
268
+ " 1024: (np.float64(0.00016977199993561953),\n",
269
+ " np.float64(0.0002816987455046532)),\n",
270
+ " 4096: (np.float64(0.00015566100000796723),\n",
271
+ " np.float64(4.2144860721421086e-05)),\n",
272
+ " 16384: (np.float64(0.00015526099991802766),\n",
273
+ " np.float64(3.326513666243062e-05)),\n",
274
+ " 65536: (np.float64(0.00011258099993938231),\n",
275
+ " np.float64(5.987320498690532e-05)),\n",
276
+ " 262144: (np.float64(0.00011269100002664345),\n",
277
+ " np.float64(3.9259844590263475e-05)),\n",
278
+ " 1048576: (np.float64(0.00012115300000914431),\n",
279
+ " np.float64(4.568541504776411e-05)),\n",
280
+ " 4194304: (np.float64(0.00013279199993121438),\n",
281
+ " np.float64(3.920358955286624e-05)),\n",
282
+ " 16777216: (np.float64(0.00012158700008058076),\n",
283
+ " np.float64(2.9125096906138658e-05)),\n",
284
+ " 67108864: (np.float64(0.0002409089998991476), np.float64(2.604900528570504)),\n",
285
+ " 268435456: (nan, nan)}}\n",
286
+ "\n"
287
+ ],
288
+ "id": "7b1a6e9fd3a1955d",
289
+ "outputs": [],
290
+ "execution_count": 24
291
+ },
292
+ {
293
+ "metadata": {},
294
+ "cell_type": "code",
295
+ "outputs": [],
296
+ "execution_count": null,
297
+ "source": "",
298
+ "id": "59d0660f82642a3"
299
+ }
300
+ ],
301
+ "metadata": {
302
+ "kernelspec": {
303
+ "display_name": "Python 3",
304
+ "language": "python",
305
+ "name": "python3"
306
+ },
307
+ "language_info": {
308
+ "codemirror_mode": {
309
+ "name": "ipython",
310
+ "version": 2
311
+ },
312
+ "file_extension": ".py",
313
+ "mimetype": "text/x-python",
314
+ "name": "python",
315
+ "nbconvert_exporter": "python",
316
+ "pygments_lexer": "ipython2",
317
+ "version": "2.7.6"
318
+ }
319
+ },
320
+ "nbformat": 4,
321
+ "nbformat_minor": 5
322
+ }
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.6'
21
- __version_tuple__ = version_tuple = (0, 2, 6)
20
+ __version__ = version = '0.2.7'
21
+ __version_tuple__ = version_tuple = (0, 2, 7)
@@ -0,0 +1,101 @@
1
+ import importlib
2
+ import os
3
+ from rich.table import Table, Text
4
+ from rich.console import Console
5
+
6
+
7
+
8
+ from .logger import logger
9
+
10
+ JAX = "jax"
11
+ CUPY = "cupy"
12
+ NUMPY = "numpy"
13
+
14
+
15
def cuda_is_available() -> bool:
    """Check if CUDA is available.

    Returns
    -------
    bool
        True when CuPy can be imported and reports at least one CUDA
        device, False otherwise.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly before using it.
    import importlib.util

    if importlib.util.find_spec("cupy") is None:
        return False

    try:
        import cupy
    except ImportError:
        # CuPy is discoverable but could not be imported; treat it the
        # same as "not available" instead of crashing at import time.
        return False

    try:
        # Check if any CUDA device is available
        return cupy.cuda.runtime.getDeviceCount() > 0
    except cupy.cuda.runtime.CUDARuntimeError:
        return False
30
+
31
+
32
def jax_is_available() -> bool:
    """Check if JAX is available.

    Returns
    -------
    bool
        True when an importable module spec is found for ``JAX``.
        The package itself is not imported by this check.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly before using it.
    import importlib.util

    return importlib.util.find_spec(JAX) is not None
35
+
36
+
37
def get_available_backends_table():
    """Build a rich table of backend availability and return it as text.

    The table has one row each for the JAX, CuPy and NumPy backends;
    NumPy is always marked as available. Nothing is printed here: the
    table is rendered into a capture buffer and returned as a rich
    ``Text`` object parsed from the captured output.
    """
    ok_mark = "[green]✓[/green]"
    missing_mark = "[red]✗[/red]"

    # Probe the optional backends first, in the same order as the rows.
    optional_backends = (
        (JAX, jax_is_available()),
        (CUPY, cuda_is_available()),
    )

    table = Table("Backend", "Available", title="Available backends")
    for backend_name, is_present in optional_backends:
        table.add_row(backend_name, ok_mark if is_present else missing_mark)
    table.add_row(NUMPY, ok_mark)

    console = Console(width=150)
    with console.capture() as captured:
        console.print(table)
    rendered = captured.get()
    return Text.from_ansi(rendered)
52
+
53
+
54
def get_backend_from_env():
    """Select and return the appropriate backend module.

    The backend name is read from the ``PYWAVELET_BACKEND`` environment
    variable (default: NumPy) and lower-cased. If the requested backend
    is unknown or unavailable, an error is logged, the availability
    table is printed, and ``PYWAVELET_BACKEND`` is reset to NumPy.

    Returns
    -------
    tuple
        ``(xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend)``,
        where ``backend`` is the name of the backend actually in use.
    """
    backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()

    # Availability probes are only run for the backend that was requested.
    use_jax = backend == JAX and jax_is_available()
    use_cupy = not use_jax and backend == CUPY and cuda_is_available()

    if not (use_jax or use_cupy or backend == NUMPY):
        logger.error(f"Backend {backend} is not available. ")
        print(get_available_backends_table())
        logger.warning("Setting backend to NumPy. ")
        # Persist the fallback so later reads of the variable agree.
        os.environ["PYWAVELET_BACKEND"] = NUMPY
        backend = NUMPY.lower()

    if use_jax:
        import jax.numpy as xp
        from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
        from jax.scipy.special import betainc

        logger.info("Using JAX backend")
    elif use_cupy:
        import cupy as xp
        from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
        from cupyx.scipy.special import betainc

        logger.info("Using CuPy backend")
    else:
        import numpy as xp
        from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
        from scipy.special import betainc

        logger.info("Using NumPy backend")

    return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
94
+
95
+
96
# CUDA availability is probed once, when this module is imported.
cuda_available = cuda_is_available()

# Get the chosen backend
# (array module, FFT helpers, betainc and the resolved backend name are
# bound as module-level names at import time).
xp, fft, ifft, irfft, rfft, rfftfreq, betainc, current_backend = (
    get_backend_from_env()
)
@@ -1,5 +1,11 @@
1
+ from ..logger import logger
2
+
1
3
  from ..backend import current_backend
2
4
 
5
+
6
+ logger.debug(f"Using {current_backend} backend")
7
+
8
+
3
9
  if current_backend == "jax":
4
10
  from .jax import (
5
11
  from_freq_to_wavelet,
@@ -0,0 +1,36 @@
1
+ from ...logger import logger
2
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
3
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
+
5
+ logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
+
7
+
8
def _log_jax_info():
    """Log the JAX default backend and the active float precision.

    Takes no arguments and returns nothing. The backend is whatever
    ``jax.default_backend()`` reports, and the precision is "64bit" or
    "32bit" depending on ``jax.config.jax_enable_x64``. A warning is
    logged when 64-bit mode is not enabled.
    """
    import jax

    device_kind = jax.default_backend()
    x64_enabled = jax.config.jax_enable_x64
    precision_label = "64bit" if x64_enabled else "32bit"

    logger.info(f"Jax running on {device_kind} [{precision_label} precision].")
    if not x64_enabled:
        logger.warning(
            "Jax is not running in 64bit precision. "
            "To change, use jax.config.update('jax_enable_x64', True)."
        )
27
+
28
+
29
+ _log_jax_info()
30
+
31
+ __all__ = [
32
+ "from_wavelet_to_time",
33
+ "from_wavelet_to_freq",
34
+ "from_time_to_wavelet",
35
+ "from_freq_to_wavelet",
36
+ ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pywavelet
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -81,6 +81,5 @@ tests/test_data/roundtrip_sine_freq.npz
81
81
  tests/test_data/roundtrip_sine_time.npz
82
82
  tests/utils/__init__.py
83
83
  tests/utils/conversions.py
84
- tests/utils/cupy_check.py
85
84
  tests/utils/generate_data.py
86
85
  tests/utils/plotting.py
@@ -0,0 +1,42 @@
1
+ import importlib
2
+ import os
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import pytest
7
+ from utils import cuda_available
8
+
9
+ from pywavelet import set_backend
10
+ from pywavelet.transforms import from_freq_to_wavelet
11
+ from pywavelet.types import FrequencySeries, TimeSeries
12
+ from pywavelet.utils import compute_snr, evolutionary_psd_from_stationary_psd
13
+
14
+
15
+
16
@pytest.mark.parametrize("backend", ["jax", "cupy", "numpy"])
def test_backend_loader(backend):
    """Check that ``set_backend`` makes the requested backend current.

    The cupy case is skipped when CUDA is not available.
    """
    if backend == "cupy" and not cuda_available:
        pytest.skip("CUDA is not available")

    try:
        set_backend(backend)
        from pywavelet.backend import current_backend

        assert current_backend == backend
    finally:
        # Always restore the default backend, even when the assertion
        # fails, so a failure here cannot leak into later tests.
        set_backend("numpy")
27
+
28
def test_backend_fails_gracefully_if_no_cupy():
    """Requesting cupy without CUDA should leave numpy as the backend."""
    if cuda_available:
        pytest.skip("CUDA is available")

    set_backend('cupy')
    from pywavelet.backend import current_backend as active_backend

    assert active_backend == 'numpy'
36
+
37
+
38
def test_backed_logger():
    """Smoke test: the backend availability table can be built and printed."""
    from pywavelet.backend import get_available_backends_table

    availability_table = get_available_backends_table()
    print(availability_table)
41
+
42
+