pywavelet 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pywavelet/_version.py +2 -2
- pywavelet/backend.py +80 -32
- pywavelet/transforms/__init__.py +6 -0
- pywavelet/transforms/jax/__init__.py +24 -0
- {pywavelet-0.2.6.dist-info → pywavelet-0.2.7.dist-info}/METADATA +1 -1
- {pywavelet-0.2.6.dist-info → pywavelet-0.2.7.dist-info}/RECORD +8 -8
- {pywavelet-0.2.6.dist-info → pywavelet-0.2.7.dist-info}/WHEEL +1 -1
- {pywavelet-0.2.6.dist-info → pywavelet-0.2.7.dist-info}/top_level.txt +0 -0
pywavelet/_version.py
CHANGED
pywavelet/backend.py
CHANGED
@@ -1,5 +1,9 @@
|
|
1
1
|
import importlib
|
2
2
|
import os
|
3
|
+
from rich.table import Table, Text
|
4
|
+
from rich.console import Console
|
5
|
+
|
6
|
+
|
3
7
|
|
4
8
|
from .logger import logger
|
5
9
|
|
@@ -8,45 +12,89 @@ CUPY = "cupy"
|
|
8
12
|
NUMPY = "numpy"
|
9
13
|
|
10
14
|
|
15
|
+
def cuda_is_available() -> bool:
    """Check whether CuPy can be imported and the CUDA runtime responds.

    Returns
    -------
    bool
        ``True`` when CuPy is installed, imports without error, and the
        device-count query completes without a CUDA runtime error;
        ``False`` in every other case. The function never raises just
        because CuPy is missing or unusable.
    """
    # A bare ``import importlib`` is not guaranteed to load the ``util``
    # submodule, so import it explicitly before touching ``find_spec``.
    import importlib.util

    if importlib.util.find_spec("cupy") is None:
        return False

    try:
        import cupy
    except ImportError:
        # A spec can exist while the import itself fails (e.g. a CuPy
        # build whose CUDA libraries cannot be loaded); treat that as
        # "no CUDA" rather than crashing the availability probe.
        return False

    try:
        cupy.cuda.runtime.getDeviceCount()  # Raises if no usable CUDA device
    except cupy.cuda.runtime.CUDARuntimeError:
        return False
    return True
|
30
|
+
|
31
|
+
|
32
|
+
def jax_is_available() -> bool:
    """Check whether the JAX package can be located on the import path.

    Only the module spec is looked up; JAX itself is not imported here.

    Returns
    -------
    bool
        ``True`` if a spec for the module named by ``JAX`` is found.
    """
    # A bare ``import importlib`` is not guaranteed to load the ``util``
    # submodule, so import it explicitly before touching ``find_spec``.
    import importlib.util

    return importlib.util.find_spec(JAX) is not None
|
35
|
+
|
36
|
+
|
37
|
+
def get_available_backends_table():
    """Build the backend-availability table and return it as rich ``Text``.

    One row is rendered per backend (JAX, CuPy, NumPy) with a green tick
    or a red cross; the NumPy row is always ticked. The table is rendered
    into a capture buffer on a 150-column console instead of being written
    to the terminal, and the captured output is converted back to ``Text``.
    """
    tick = "[green]✓[/green]"
    cross = "[red]✗[/red]"

    # Probe the optional backends first, in the same order as the rows.
    rows = (
        (JAX, jax_is_available()),
        (CUPY, cuda_is_available()),
        (NUMPY, True),
    )

    table = Table("Backend", "Available", title="Available backends")
    for backend_name, is_present in rows:
        table.add_row(backend_name, tick if is_present else cross)

    console = Console(width=150)
    with console.capture() as captured:
        console.print(table)
    rendered = captured.get()
    return Text.from_ansi(rendered)
|
52
|
+
|
53
|
+
|
11
54
|
def get_backend_from_env():
    """Resolve the array backend requested via ``PYWAVELET_BACKEND``.

    The environment variable is read case-insensitively and defaults to
    NumPy. JAX and CuPy are only selected when their availability checks
    pass. Any other outcome logs an error, prints the availability table,
    overwrites ``PYWAVELET_BACKEND`` in ``os.environ`` with the NumPy
    value, and resolves again.

    Returns
    -------
    tuple
        ``(xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend)`` where
        ``xp`` is the array module, the FFT helpers and ``betainc`` come
        from the matching library, and ``backend`` is the lower-cased
        backend name that was finally used.
    """
    backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()

    if backend == JAX and jax_is_available():
        import jax.numpy as xp
        from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
        from jax.scipy.special import betainc

        logger.info("Using JAX backend")
        return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend

    if backend == CUPY and cuda_is_available():
        import cupy as xp
        from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
        from cupyx.scipy.special import betainc

        logger.info("Using CuPy backend")
        return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend

    if backend == NUMPY:
        import numpy as xp
        from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
        from scipy.special import betainc

        logger.info("Using NumPy backend")
        return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend

    # Unknown name, or a known backend whose availability check failed:
    # report it, show what is usable, then fall back to NumPy by rewriting
    # the environment variable and resolving once more.
    logger.error(f"Backend {backend} is not available. ")
    print(get_available_backends_table())
    logger.warning("Setting backend to NumPy. ")
    os.environ["PYWAVELET_BACKEND"] = NUMPY
    return get_backend_from_env()
|
48
94
|
|
49
95
|
|
96
|
+
# Run the CUDA availability check and keep the result in a module-level flag.
cuda_available = cuda_is_available()
|
97
|
+
|
50
98
|
# Get the chosen backend
|
51
99
|
xp, fft, ifft, irfft, rfft, rfftfreq, betainc, current_backend = (
|
52
100
|
get_backend_from_env()
|
pywavelet/transforms/jax/__init__.py
CHANGED
@@ -4,6 +4,30 @@ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
|
|
4
4
|
|
5
5
|
# Warn that the JAX subpackage has not been fully tested.
logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
|
6
6
|
|
7
|
+
|
8
|
+
def _log_jax_info():
    """Log which platform JAX is running on and at what precision.

    Takes no arguments and returns nothing. The platform name comes from
    ``jax.default_backend()`` and the precision label ("64bit" or "32bit")
    from the ``jax_enable_x64`` config flag. An INFO message is always
    logged; a WARNING with instructions for enabling 64-bit mode follows
    when that flag is off.
    """
    import jax

    platform = jax.default_backend()
    x64_enabled = jax.config.jax_enable_x64
    if x64_enabled:
        precision_label = "64bit"
    else:
        precision_label = "32bit"

    logger.info(f"Jax running on {platform} [{precision_label} precision].")

    if x64_enabled:
        return
    logger.warning(
        "Jax is not running in 64bit precision. "
        "To change, use jax.config.update('jax_enable_x64', True)."
    )
|
27
|
+
|
28
|
+
|
29
|
+
# Report the JAX platform and precision through the logger.
_log_jax_info()
|
30
|
+
|
7
31
|
__all__ = [
|
8
32
|
"from_wavelet_to_time",
|
9
33
|
"from_wavelet_to_freq",
|
@@ -1,9 +1,9 @@
|
|
1
1
|
pywavelet/__init__.py,sha256=K7pQ8W2w9d5qwI4KzPdTpRn5-YaUfMpjnJmg7oQnYSM,508
|
2
|
-
pywavelet/_version.py,sha256=
|
3
|
-
pywavelet/backend.py,sha256=
|
2
|
+
pywavelet/_version.py,sha256=Xk20v7uvkFqkpy9aLJzVngs1eKQn0FYUP2oyA1MEQUU,511
|
3
|
+
pywavelet/backend.py,sha256=Ixl9nVP7Lobbyff5WF-pqXo0zuM7RVbyC1j35dSO3jI,2791
|
4
4
|
pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
|
5
5
|
pywavelet/utils.py,sha256=FqQ6V41WGHMbLC4wv_1xnwHjOPDVSWnG78sAeqbYtYU,1994
|
6
|
-
pywavelet/transforms/__init__.py,sha256=
|
6
|
+
pywavelet/transforms/__init__.py,sha256=c5dnTdzm-Se3idp2LP-dHvk6fmv_ynDCixIDfLCjEPw,865
|
7
7
|
pywavelet/transforms/phi_computer.py,sha256=jVxeWtfx5P1H-_HdMsK7xHuINZAjH9bj7cA8CJ98isw,3667
|
8
8
|
pywavelet/transforms/cupy/__init__.py,sha256=8BBE6msB071WdstA860a7g64C0aHT2PZsqfEgP6nmkA,336
|
9
9
|
pywavelet/transforms/cupy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
@@ -13,7 +13,7 @@ pywavelet/transforms/cupy/forward/main.py,sha256=g2Pl-j4LBg7GLlzzCSoCGuEd6NNCckJ
|
|
13
13
|
pywavelet/transforms/cupy/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
|
14
14
|
pywavelet/transforms/cupy/inverse/main.py,sha256=5pTtGNNdwlSGDQV4sqGyzUPnmqFUgFOFUFfpqjZx07Q,1608
|
15
15
|
pywavelet/transforms/cupy/inverse/to_freq.py,sha256=gpqu5Y65ZvuET5jANp6UAuAamg2PRkpAlaAjWPh7uBk,1835
|
16
|
-
pywavelet/transforms/jax/__init__.py,sha256=
|
16
|
+
pywavelet/transforms/jax/__init__.py,sha256=_JG9EaHKD95U-Nzm5zQ1RHQ7XuBjZarMa7VpZ-y7rgY,941
|
17
17
|
pywavelet/transforms/jax/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
|
18
18
|
pywavelet/transforms/jax/forward/from_freq.py,sha256=XYtRziPD7MCbeKf4HAucQrMzko4T0zmNV7jg5bziVwA,3910
|
19
19
|
pywavelet/transforms/jax/forward/from_time.py,sha256=4RZ8-ah0qOMP20i3-xThVWddxa1QTCvZKnGpNAJbb0g,1765
|
@@ -37,7 +37,7 @@ pywavelet/types/plotting.py,sha256=qjv5IeuSEc9WWkfJYvz1eQRgTKTspWxj4lwB5N69SbU,1
|
|
37
37
|
pywavelet/types/timeseries.py,sha256=sataMW4BPFqi23h_NBZ_U9-Svuo9pLXVRmUJI6KTXG0,9430
|
38
38
|
pywavelet/types/wavelet.py,sha256=lDhpy9bEb_I-YDQbI3elaWuU8l9E2P6wDcuAQONv8lA,13591
|
39
39
|
pywavelet/types/wavelet_bins.py,sha256=gBjhWwfjcbbSnbGZVMNUeFFVUo2DVxJS4abDUVCL7ts,1458
|
40
|
-
pywavelet-0.2.
|
41
|
-
pywavelet-0.2.
|
42
|
-
pywavelet-0.2.
|
43
|
-
pywavelet-0.2.
|
40
|
+
pywavelet-0.2.7.dist-info/METADATA,sha256=5oaGFUO6ASLeTlweptCKP5D9LUGkLQbmJNfWMVYgwxc,2571
|
41
|
+
pywavelet-0.2.7.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
|
42
|
+
pywavelet-0.2.7.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
|
43
|
+
pywavelet-0.2.7.dist-info/RECORD,,
|
File without changes
|