pywavelet 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.2'
16
- __version_tuple__ = version_tuple = (0, 2, 2)
15
+ __version__ = version = '0.2.4'
16
+ __version_tuple__ = version_tuple = (0, 2, 4)
pywavelet/backend.py CHANGED
@@ -1,5 +1,7 @@
1
1
  import os
2
2
 
3
+ from .logger import logger
4
+
3
5
  try:
4
6
  import jax
5
7
 
@@ -13,14 +15,17 @@ use_jax = jax_available and os.getenv("PYWAVELET_JAX", "0") == "1"
13
15
 
14
16
  if use_jax:
15
17
  import jax.numpy as xp # type: ignore
16
- from jax.scipy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
18
+ from jax.numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
17
19
  from jax.scipy.special import betainc # type: ignore
18
20
 
21
+ logger.info("Using JAX backend")
19
22
 
20
23
  else:
21
24
  import numpy as xp # type: ignore
22
- from numpy.fft import fft, ifft, rfft, irfft, rfftfreq # type: ignore
23
- from scipy.special import betainc # type: ignore
25
+ from numpy.fft import fft, ifft, irfft, rfft, rfftfreq # type: ignore
26
+ from scipy.special import betainc # type: ignore
27
+
28
+ logger.info("Using NumPy+numba backend")
24
29
 
25
30
 
26
- PI = xp.pi
31
+ PI = xp.pi
@@ -1,10 +1,21 @@
1
- from .numpy import (
2
- from_wavelet_to_time,
3
- from_wavelet_to_freq,
4
- from_time_to_wavelet,
5
- from_freq_to_wavelet,
6
- )
7
- from .phi_computer import phi_vec, phitilde_vec_norm, phitilde_vec
1
+ from ..backend import use_jax
2
+
3
+ if use_jax:
4
+ from .jax import (
5
+ from_freq_to_wavelet,
6
+ from_time_to_wavelet,
7
+ from_wavelet_to_freq,
8
+ from_wavelet_to_time,
9
+ )
10
+ else:
11
+ from .numpy import (
12
+ from_wavelet_to_time,
13
+ from_wavelet_to_freq,
14
+ from_time_to_wavelet,
15
+ from_freq_to_wavelet,
16
+ )
17
+
18
+ from .phi_computer import phi_vec, phitilde_vec, phitilde_vec_norm
8
19
 
9
20
  __all__ = [
10
21
  "from_wavelet_to_time",
@@ -0,0 +1,12 @@
1
+ from ...logger import logger
2
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
3
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
4
+
5
+ logger.warning("JAX SUBPACKAGE NOT FULLY TESTED")
6
+
7
+ __all__ = [
8
+ "from_wavelet_to_time",
9
+ "from_wavelet_to_freq",
10
+ "from_time_to_wavelet",
11
+ "from_freq_to_wavelet",
12
+ ]
@@ -1,6 +1,3 @@
1
1
  from .main import from_freq_to_wavelet, from_time_to_wavelet
2
- from ....logger import logger
3
-
4
- logger.warning("JAX SUBPACKAGE NOT YET TESTED")
5
2
 
6
3
  __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -1,56 +1,97 @@
1
- import jax
2
- import jax.numpy as jnp
3
1
  from functools import partial
2
+
3
+ import jax.numpy as jnp
4
4
  from jax import jit
5
5
  from jax.numpy.fft import ifft
6
6
 
7
- @partial(jit, static_argnames=('Nf', 'Nt'))
7
+
8
+ @partial(jit, static_argnames=("Nf", "Nt"))
8
9
  def transform_wavelet_freq_helper(
9
- data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
10
+ data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray
10
11
  ) -> jnp.ndarray:
11
- # Initially all wrk being done in time-rws, freq-cols
12
+ """
13
+ Transforms input data from the frequency domain to the wavelet domain using a
14
+ pre-computed wavelet filter (`phif`) and performs an efficient inverse FFT.
15
+
16
+ Parameters:
17
+ - data (jnp.ndarray): 1D array representing the input data in the frequency domain.
18
+ - Nf (int): Number of frequency bins.
19
+ - Nt (int): Number of time bins. (Note: Nt * Nf == len(data))
20
+ - phif (jnp.ndarray): Pre-computed wavelet filter for frequency components.
21
+
22
+ Returns:
23
+ - wave (jnp.ndarray): 2D array of wavelet-transformed data with shape (Nf, Nt).
24
+ """
25
+
26
+ # Initialize the wavelet output array with zeros (time-rows, frequency-columns)
12
27
  wave = jnp.zeros((Nt, Nf))
13
- f_bins = jnp.arange(Nf)
28
+ f_bins = jnp.arange(Nf) # Frequency bin indices
14
29
 
30
+ # Compute base indices for time (i_base) and frequency (jj_base)
15
31
  i_base = Nt // 2
16
32
  jj_base = f_bins * Nt // 2
17
33
 
34
+ # Set initial values for the center of the transformation
18
35
  initial_values = jnp.where(
19
- (f_bins == 0) | (f_bins == Nf),
20
- phif[0] * data[f_bins * Nt // 2] / 2.0,
21
- phif[0] * data[f_bins * Nt // 2]
36
+ (f_bins == 0)
37
+ | (f_bins == Nf), # Edge cases: DC (f=0) and Nyquist (f=Nf)
38
+ phif[0] * data[f_bins * Nt // 2] / 2.0, # Adjust for symmetry
39
+ phif[0] * data[f_bins * Nt // 2],
22
40
  )
23
41
 
24
- DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)
25
- DX = DX.at[:, Nt // 2].set(initial_values)
42
+ # Initialize a 2D array to store intermediate FFT input values
43
+ DX = jnp.zeros(
44
+ (Nf, Nt), dtype=jnp.complex64
45
+ ) # TODO: Check dtype -- is complex64 sufficient?
46
+ DX = DX.at[:, Nt // 2].set(
47
+ initial_values
48
+ ) # Set initial values at the center of the transformation (2 sided FFT)
26
49
 
27
- j_range = jnp.arange(1 - Nt // 2, Nt // 2)
28
- j = jnp.abs(j_range)
29
- i = i_base + j_range
50
+ # Compute time indices for all offsets around the midpoint
51
+ j_range = jnp.arange(
52
+ 1 - Nt // 2, Nt // 2
53
+ ) # Time offsets (centered around zero)
54
+ j = jnp.abs(j_range) # Absolute offset indices
55
+ i = i_base + j_range # Time indices in output array
30
56
 
31
- cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)
32
- cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)
33
- cond3 = j[None, :] == 0
57
+ # Compute conditions for edge cases
58
+ cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0) # Nyquist
59
+ cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0) # DC
60
+ cond3 = j[None, :] == 0 # Center of the transformation (no offset)
34
61
 
35
- jj = jj_base[:, None] + j_range[None, :]
36
- val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])
37
- DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))
62
+ # Compute frequency indices for the input data
63
+ jj = jj_base[:, None] + j_range[None, :] # Frequency offsets
64
+ val = jnp.where(
65
+ cond1 | cond2, 0.0, phif[j] * data[jj]
66
+ ) # Wavelet filter application
67
+ DX = DX.at[:, i].set(
68
+ jnp.where(cond3, DX[:, i], val)
69
+ ) # Update DX with computed values
70
+ # At this point, DX contains the data FFT'd with the wavelet filter
71
+ # (each row is a frequency bin, each column is a time bin)
38
72
 
39
- # Vectorized ifft
73
+ # Perform the inverse FFT along the time dimension
40
74
  DX_trans = ifft(DX, axis=1)
41
75
 
42
- # Vectorized __fill_wave_2_jax
43
- n_range = jnp.arange(Nt)
44
- cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1
45
- cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # shape: (Nf, 1)
76
+ # Fill the wavelet output array based on the inverse FFT results
77
+ n_range = jnp.arange(Nt) # Time indices
78
+ cond1 = (
79
+ n_range[:, None] + f_bins[None, :]
80
+ ) % 2 == 1 # Odd/even alternation
81
+ cond2 = jnp.expand_dims(f_bins % 2 == 1, axis=-1) # Odd frequency bins
46
82
 
83
+ # Assign real and imaginary parts based on conditions
47
84
  real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))
48
85
  imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))
49
-
50
86
  wave = jnp.where(cond1, imag_part.T, real_part.T)
51
87
 
52
- ## Special cases for f_bin 0 and Nf
53
- wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))
54
- wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))
88
+ # Special cases for frequency bins 0 (DC) and Nf (Nyquist)
89
+ wave = wave.at[::2, 0].set(
90
+ jnp.real(DX_trans[0, ::2] * jnp.sqrt(2))
91
+ ) # DC component
92
+ wave = wave.at[1::2, -1].set(
93
+ jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2))
94
+ ) # Nyquist component
55
95
 
56
- return wave.T
96
+ # Return the wavelet-transformed array (transposed for freq-major layout)
97
+ return wave.T # (Nt, Nf) -> (Nf, Nt)
@@ -0,0 +1,3 @@
1
+ from .main import from_wavelet_to_freq, from_wavelet_to_time
2
+
3
+ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
@@ -1,10 +1,9 @@
1
1
  import jax.numpy as jnp
2
2
  from jax.numpy.fft import rfftfreq
3
3
 
4
- from ...phi_computer import phi_vec, phitilde_vec_norm
5
4
  from ....types import FrequencySeries, TimeSeries, Wavelet
5
+ from ...phi_computer import phi_vec, phitilde_vec_norm
6
6
  from .to_freq import inverse_wavelet_freq_helper
7
- # from .inverse_wavelet_time_funcs import inverse_wavelet_time_helper
8
7
 
9
8
 
10
9
  def from_wavelet_to_time(
@@ -65,5 +64,5 @@ def from_wavelet_to_freq(
65
64
  -1 / 2
66
65
  ) # Normalise to get the proper backwards transformation
67
66
 
68
- freqs = rfftfreq(wave_in.ND*2, d=dt)[1:]
69
- return FrequencySeries(data=freq_data, freq=freqs)
67
+ freqs = rfftfreq(wave_in.ND * 2, d=dt)[1:]
68
+ return FrequencySeries(data=freq_data, freq=freqs)
@@ -1,6 +1,7 @@
1
1
  from typing import List, Tuple
2
2
 
3
3
  import matplotlib.pyplot as plt
4
+ import numpy as np
4
5
 
5
6
  from .common import fmt_timerange, is_documented_by, xp
6
7
  from .plotting import plot_wavelet_grid, plot_wavelet_trend
@@ -343,14 +344,18 @@ class Wavelet:
343
344
 
344
345
  def __mul__(self, other):
345
346
  """Element-wise multiplication of two Wavelet objects."""
346
- if isinstance(other, Wavelet):
347
- return Wavelet(
348
- data=self.data * other.data, time=self.time, freq=self.freq
349
- )
347
+ if isinstance(other, WaveletMask):
348
+ data = self.data.copy()
349
+ data[~other.mask] = np.nan
350
+ return Wavelet(data=data, time=self.time, freq=self.freq)
350
351
  elif isinstance(other, float):
351
352
  return Wavelet(
352
353
  data=self.data * other, time=self.time, freq=self.freq
353
354
  )
355
+ elif isinstance(other, WaveletMask):
356
+ return Wavelet(
357
+ data=self.data * other.data, time=self.time, freq=self.freq
358
+ )
354
359
 
355
360
  def __truediv__(self, other):
356
361
  """Element-wise division of two Wavelet objects."""
@@ -445,11 +450,40 @@ class WaveletMask(Wavelet):
445
450
  return rpr
446
451
 
447
452
  @classmethod
448
- def from_frange(
449
- cls, time_grid: xp.ndarray, freq_grid: xp.ndarray, frange: List[float]
453
+ def from_restrictions(
454
+ cls,
455
+ time_grid: xp.ndarray,
456
+ freq_grid: xp.ndarray,
457
+ frange: List[float],
458
+ tgaps: List[Tuple[float, float]] = [],
450
459
  ):
460
+ """
461
+ Create a WaveletMask object from restrictions on time and frequency.
462
+
463
+ Parameters
464
+ ----------
465
+ time_grid : xp.ndarray
466
+ Array of time points.
467
+ freq_grid : xp.ndarray
468
+ Array of corresponding frequency points.
469
+ frange : List[float]
470
+ Frequency range to include.
471
+ tgaps : List[Tuple[float, float]]
472
+ List of time gaps to exclude.
473
+
474
+ Returns
475
+ -------
476
+ WaveletMask
477
+ A WaveletMask object with the specified restrictions.
478
+ """
451
479
  self = cls.zeros_from_grid(time_grid, freq_grid)
452
480
  self.data[
453
481
  (freq_grid >= frange[0]) & (freq_grid <= frange[1]), :
454
482
  ] = True
483
+
484
+ for tgap in tgaps:
485
+ self.data[
486
+ :, (time_grid >= tgap[0]) & (time_grid <= tgap[1])
487
+ ] = False
488
+ self.data = self.data.astype(bool)
455
489
  return self
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -1,17 +1,17 @@
1
1
  pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
- pywavelet/_version.py,sha256=RrHB9KG1O3GPm--rbTedqmZbdDrbgeRLXBmT4OBUqqI,411
3
- pywavelet/backend.py,sha256=k4pDi6f4cwNY6HsUIx1xfuga9f2wLnFr_FIb7Fs1Mds,553
2
+ pywavelet/_version.py,sha256=4gL0W4-u58XR5lRLpeoIPrGhcewTk0-527de6uTNmkg,411
3
+ pywavelet/backend.py,sha256=SmpgIBHvTO1rtIAQQN_zpVB8i6R-x23FNKJG6_JlrNs,666
4
4
  pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
5
5
  pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
6
- pywavelet/transforms/__init__.py,sha256=uc1fKbBGQgEDafJHk6GEVCc0G_EXL5CtFTKCoFsewoM,381
6
+ pywavelet/transforms/__init__.py,sha256=EYX8glRWojYbrjtbgrjS4vigYTRi7FOtIV3D1UwI5fY,604
7
7
  pywavelet/transforms/phi_computer.py,sha256=ppFSGJwtNnO2flaiok9ms3WXlAxGQikvA7eNfLgriNQ,4461
8
- pywavelet/transforms/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- pywavelet/transforms/jax/forward/__init__.py,sha256=Ki2RJCfkE9Zy59mqT3oEtGK9Ro9kS5kAz9duZFbxyZo,200
10
- pywavelet/transforms/jax/forward/from_freq.py,sha256=PsUC7RfrN6pRWWkMSXYHk9z5lxCXW3DfF0m-Rd1GOBE,1785
8
+ pywavelet/transforms/jax/__init__.py,sha256=D_f-JgFAzOIJ-EuQZhTMziD4MT6lVWS3XV9s51Cu7Kg,335
9
+ pywavelet/transforms/jax/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
10
+ pywavelet/transforms/jax/forward/from_freq.py,sha256=tKEdqPyEvX8ZKVQf16wGxN3d6gkcjm_RtAHQuWHUzy4,3764
11
11
  pywavelet/transforms/jax/forward/from_time.py,sha256=xNeoZq54B6Gi3TdTTYLr_euaFeJcwpms-lSyCG53AdI,1726
12
12
  pywavelet/transforms/jax/forward/main.py,sha256=mm0R4m0pXcnzZB0jCckAc4ynG8STH5mldCmHyyU_PGo,3091
13
- pywavelet/transforms/jax/inverse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- pywavelet/transforms/jax/inverse/main.py,sha256=ZK8NyfMI6oFYMKcALatETWnCystH0LWjEYInvnMMmh0,1714
13
+ pywavelet/transforms/jax/inverse/__init__.py,sha256=J4KIzPzbHNg_8fV_c1MpPq3slSqHQV0j3VFrjfd1Nog,121
14
+ pywavelet/transforms/jax/inverse/main.py,sha256=-HVOOBsYo3GJvGNCsQLbNPnt9s14JvbB2bGAd9LOr3A,1647
15
15
  pywavelet/transforms/jax/inverse/to_freq.py,sha256=ASNARcDBJQr4EizAP_77e5ai36iPwP6hzfvwGbZQ6BM,2295
16
16
  pywavelet/transforms/numpy/__init__.py,sha256=qFLpGpW3VJSbDp2JpD0Gx7PdwDjH-wrW_aO84ASkIgA,255
17
17
  pywavelet/transforms/numpy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
@@ -27,9 +27,9 @@ pywavelet/types/common.py,sha256=aIcYq-0KOLHnPQjrVbVmw_TQ3Xm5a7xA30rSgwt3rk4,127
27
27
  pywavelet/types/frequencyseries.py,sha256=hrtLaIUaRrqXw8l00yFe2tPJwpksDa_4n1z6R8XSPPQ,7531
28
28
  pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,10625
29
29
  pywavelet/types/timeseries.py,sha256=u35bIqFo3QdlQRBEu6maeWA7DePS11LQ6WMiLjZPcWo,9456
30
- pywavelet/types/wavelet.py,sha256=el48oyAfwtSw2tCQLUb85F9lKr0qMSRJPUmAUU8TS50,12552
30
+ pywavelet/types/wavelet.py,sha256=uHJzTS2ZXTRr7I7NHWv3qNjknSBhQUpcED3jM6ti7UM,13587
31
31
  pywavelet/types/wavelet_bins.py,sha256=GoQGKeZlPc-KbYY7LoxAhB-HI4diHpPcTABBXRfUTLA,1459
32
- pywavelet-0.2.2.dist-info/METADATA,sha256=JqDqYCarCAWF5DTmbKzY1u1zUlvlB7OexbrrcLuCRRM,2241
33
- pywavelet-0.2.2.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
34
- pywavelet-0.2.2.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
35
- pywavelet-0.2.2.dist-info/RECORD,,
32
+ pywavelet-0.2.4.dist-info/METADATA,sha256=Thhhz8I2XTKr0mVuf09UpcvjeEGKUnVUX0jxENu6gEQ,2241
33
+ pywavelet-0.2.4.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
34
+ pywavelet-0.2.4.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
35
+ pywavelet-0.2.4.dist-info/RECORD,,