pywavelet 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.6'
21
- __version_tuple__ = version_tuple = (0, 2, 6)
20
+ __version__ = version = '0.2.7'
21
+ __version_tuple__ = version_tuple = (0, 2, 7)
pywavelet/backend.py CHANGED
@@ -1,5 +1,9 @@
1
1
  import importlib
2
2
  import os
3
+ from rich.table import Table, Text
4
+ from rich.console import Console
5
+
6
+
3
7
 
4
8
  from .logger import logger
5
9
 
@@ -8,45 +12,89 @@ CUPY = "cupy"
8
12
  NUMPY = "numpy"
9
13
 
10
14
 
15
+ def cuda_is_available() -> bool:
16
+ """Check if CUDA is available."""
17
+ # Check if CuPy is available and CUDA is accessible
18
+ cupy_available = importlib.util.find_spec("cupy") is not None
19
+ if cupy_available:
20
+ import cupy
21
+
22
+ try:
23
+ cupy.cuda.runtime.getDeviceCount() # Check if any CUDA device is available
24
+ cuda_available = True
25
+ except cupy.cuda.runtime.CUDARuntimeError:
26
+ cuda_available = False
27
+ else:
28
+ cuda_available = False
29
+ return cuda_available
30
+
31
+
32
+ def jax_is_available() -> bool:
33
+ """Check if JAX is available."""
34
+ return importlib.util.find_spec(JAX) is not None
35
+
36
+
37
+ def get_available_backends_table():
38
+ """Print the available backends as a rich table."""
39
+
40
+ jax_avail = jax_is_available()
41
+ cupy_avail = cuda_is_available()
42
+ table = Table("Backend", "Available", title="Available backends")
43
+ true_check = "[green]✓[/green]"
44
+ false_check = "[red]✗[/red]"
45
+ table.add_row(JAX, true_check if jax_avail else false_check)
46
+ table.add_row(CUPY, true_check if cupy_avail else false_check)
47
+ table.add_row(NUMPY, true_check)
48
+ console = Console(width=150)
49
+ with console.capture() as capture:
50
+ console.print(table)
51
+ return Text.from_ansi(capture.get())
52
+
53
+
11
54
  def get_backend_from_env():
12
55
  """Select and return the appropriate backend module."""
13
56
  backend = os.getenv("PYWAVELET_BACKEND", NUMPY).lower()
14
57
 
15
- if backend == JAX:
16
- if importlib.util.find_spec(JAX):
17
- import jax.numpy as xp
18
- from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
19
- from jax.scipy.special import betainc
20
-
21
- logger.info("Using JAX backend")
22
- return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
23
- else:
24
- logger.warning(
25
- "JAX backend requested but not installed. Falling back to NumPy."
26
- )
27
-
28
- elif backend == CUPY:
29
- if importlib.util.find_spec(CUPY):
30
- import cupy as xp
31
- from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
32
- from cupyx.scipy.special import betainc
33
-
34
- logger.info("Using CuPy backend")
35
- return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
36
- else:
37
- logger.warning(
38
- "CuPy backend requested but not installed. Falling back to NumPy."
39
- )
40
-
41
- # Default to NumPy
42
- import numpy as xp
43
- from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
44
- from scipy.special import betainc
45
-
46
- logger.info("Using NumPy+Numba backend")
58
+ if backend == JAX and jax_is_available():
59
+
60
+ import jax.numpy as xp
61
+ from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq
62
+ from jax.scipy.special import betainc
63
+
64
+ logger.info("Using JAX backend")
65
+
66
+ elif backend == CUPY and cuda_is_available():
67
+
68
+ import cupy as xp
69
+ from cupy.fft import fft, ifft, irfft, rfft, rfftfreq
70
+ from cupyx.scipy.special import betainc
71
+
72
+ logger.info("Using CuPy backend")
73
+
74
+ elif backend == NUMPY:
75
+ import numpy as xp
76
+ from numpy.fft import fft, ifft, irfft, rfft, rfftfreq
77
+ from scipy.special import betainc
78
+
79
+ logger.info("Using NumPy backend")
80
+
81
+
82
+ else:
83
+ logger.error(
84
+ f"Backend {backend} is not available. "
85
+ )
86
+ print(get_available_backends_table())
87
+ logger.warning(
88
+ f"Setting backend to NumPy. "
89
+ )
90
+ os.environ["PYWAVELET_BACKEND"] = NUMPY
91
+ return get_backend_from_env()
92
+
47
93
  return xp, fft, ifft, irfft, rfft, rfftfreq, betainc, backend
48
94
 
49
95
 
96
+ cuda_available = cuda_is_available()
97
+
50
98
  # Get the chosen backend
51
99
  xp, fft, ifft, irfft, rfft, rfftfreq, betainc, current_backend = (
52
100
  get_backend_from_env()
@@ -1,5 +1,11 @@
1
+ from ..logger import logger
2
+
1
3
  from ..backend import current_backend
2
4
 
5
+
6
+ logger.debug(f"Using {current_backend} backend")
7
+
8
+
3
9
  if current_backend == "jax":
4
10
  from .jax import (
5
11
  from_freq_to_wavelet,
@@ -4,6 +4,30 @@ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
4
 
5
5
  logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
6
 
7
+
8
+ def _log_jax_info():
9
+ """Log JAX backend and precision information.
10
+
11
+ backend : str
12
+ JAX backend. ["cpu", "gpu", "tpu"]
13
+ precision : str
14
+ JAX precision. ["32bit", "64bit"]
15
+ """
16
+ import jax
17
+
18
+ _backend = jax.default_backend()
19
+ _precision = "64bit" if jax.config.jax_enable_x64 else "32bit"
20
+
21
+ logger.info(f"Jax running on {_backend} [{_precision} precision].")
22
+ if _precision == "32bit":
23
+ logger.warning(
24
+ "Jax is not running in 64bit precision. "
25
+ "To change, use jax.config.update('jax_enable_x64', True)."
26
+ )
27
+
28
+
29
+ _log_jax_info()
30
+
7
31
  __all__ = [
8
32
  "from_wavelet_to_time",
9
33
  "from_wavelet_to_freq",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pywavelet
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -1,9 +1,9 @@
1
1
  pywavelet/__init__.py,sha256=K7pQ8W2w9d5qwI4KzPdTpRn5-YaUfMpjnJmg7oQnYSM,508
2
- pywavelet/_version.py,sha256=nObnONsicQ3YX6SG5MVBxmIp5dmRacXDauSqZijWQbY,511
3
- pywavelet/backend.py,sha256=1AjwqoIlan6vNFZcon_LIVsiPH8HrWQwU3RON7dnjUE,1585
2
+ pywavelet/_version.py,sha256=Xk20v7uvkFqkpy9aLJzVngs1eKQn0FYUP2oyA1MEQUU,511
3
+ pywavelet/backend.py,sha256=Ixl9nVP7Lobbyff5WF-pqXo0zuM7RVbyC1j35dSO3jI,2791
4
4
  pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
5
5
  pywavelet/utils.py,sha256=FqQ6V41WGHMbLC4wv_1xnwHjOPDVSWnG78sAeqbYtYU,1994
6
- pywavelet/transforms/__init__.py,sha256=t4cHI8Rd5UnLwqCunr4sCQRmsKhHOnZ5VqkDphhi-VM,784
6
+ pywavelet/transforms/__init__.py,sha256=c5dnTdzm-Se3idp2LP-dHvk6fmv_ynDCixIDfLCjEPw,865
7
7
  pywavelet/transforms/phi_computer.py,sha256=jVxeWtfx5P1H-_HdMsK7xHuINZAjH9bj7cA8CJ98isw,3667
8
8
  pywavelet/transforms/cupy/__init__.py,sha256=8BBE6msB071WdstA860a7g64C0aHT2PZsqfEgP6nmkA,336
9
9
  pywavelet/transforms/cupy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
@@ -13,7 +13,7 @@ pywavelet/transforms/cupy/forward/main.py,sha256=g2Pl-j4LBg7GLlzzCSoCGuEd6NNCckJ
13
13
  pywavelet/transforms/cupy/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
14
14
  pywavelet/transforms/cupy/inverse/main.py,sha256=5pTtGNNdwlSGDQV4sqGyzUPnmqFUgFOFUFfpqjZx07Q,1608
15
15
  pywavelet/transforms/cupy/inverse/to_freq.py,sha256=gpqu5Y65ZvuET5jANp6UAuAamg2PRkpAlaAjWPh7uBk,1835
16
- pywavelet/transforms/jax/__init__.py,sha256=D_f-JgFAzOIJ-EuQZhTMziD4MT6lVWS3XV9s51Cu7Kg,335
16
+ pywavelet/transforms/jax/__init__.py,sha256=_JG9EaHKD95U-Nzm5zQ1RHQ7XuBjZarMa7VpZ-y7rgY,941
17
17
  pywavelet/transforms/jax/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
18
18
  pywavelet/transforms/jax/forward/from_freq.py,sha256=XYtRziPD7MCbeKf4HAucQrMzko4T0zmNV7jg5bziVwA,3910
19
19
  pywavelet/transforms/jax/forward/from_time.py,sha256=4RZ8-ah0qOMP20i3-xThVWddxa1QTCvZKnGpNAJbb0g,1765
@@ -37,7 +37,7 @@ pywavelet/types/plotting.py,sha256=qjv5IeuSEc9WWkfJYvz1eQRgTKTspWxj4lwB5N69SbU,1
37
37
  pywavelet/types/timeseries.py,sha256=sataMW4BPFqi23h_NBZ_U9-Svuo9pLXVRmUJI6KTXG0,9430
38
38
  pywavelet/types/wavelet.py,sha256=lDhpy9bEb_I-YDQbI3elaWuU8l9E2P6wDcuAQONv8lA,13591
39
39
  pywavelet/types/wavelet_bins.py,sha256=gBjhWwfjcbbSnbGZVMNUeFFVUo2DVxJS4abDUVCL7ts,1458
40
- pywavelet-0.2.6.dist-info/METADATA,sha256=WScrhO_gC_5wKwY39T0aI81YBj71xn9O6-MI0GBRucQ,2571
41
- pywavelet-0.2.6.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
42
- pywavelet-0.2.6.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
43
- pywavelet-0.2.6.dist-info/RECORD,,
40
+ pywavelet-0.2.7.dist-info/METADATA,sha256=5oaGFUO6ASLeTlweptCKP5D9LUGkLQbmJNfWMVYgwxc,2571
41
+ pywavelet-0.2.7.dist-info/WHEEL,sha256=DK49LOLCYiurdXXOXwGJm6U4DkHkg4lcxjhqwRa0CP4,91
42
+ pywavelet-0.2.7.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
43
+ pywavelet-0.2.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (77.0.3)
2
+ Generator: setuptools (78.0.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5