pywavelet 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pywavelet/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.2.0'
16
- __version_tuple__ = version_tuple = (0, 2, 0)
15
+ __version__ = version = '0.2.1'
16
+ __version_tuple__ = version_tuple = (0, 2, 1)
File without changes
@@ -0,0 +1,69 @@
1
+ import jax.numpy as jnp
2
+ from jax.numpy.fft import rfftfreq
3
+
4
+ from ...phi_computer import phi_vec, phitilde_vec_norm
5
+ from ....types import FrequencySeries, TimeSeries, Wavelet
6
+ from .to_freq import inverse_wavelet_freq_helper
7
+ # from .inverse_wavelet_time_funcs import inverse_wavelet_time_helper
8
+
9
+
10
def from_wavelet_to_time(
    wave_in: Wavelet,
    dt: float,
    nx: float = 4.0,
    mult: int = 32,
) -> TimeSeries:
    """Inverse wavelet transform to the time domain.

    The coefficients are first inverted to the frequency domain with
    ``from_wavelet_to_freq``; the resulting ``FrequencySeries`` is then
    converted with its ``to_timeseries`` method.

    Parameters
    ----------
    wave_in : Wavelet
        Input wavelet coefficients.
    dt : float
        Time step of the reconstructed signal.
    nx : float, optional
        Forwarded to ``from_wavelet_to_freq`` (which passes it as ``d`` to
        ``phitilde_vec_norm``), by default 4.0.
    mult : int, optional
        Currently unused by this implementation (``phi_vec`` is not
        called on this code path). Kept in the signature for caller
        compatibility -- presumably for parity with a direct time-domain
        inverse; confirm before removing. By default 32.

    Returns
    -------
    TimeSeries
        Time domain signal.
    """
    # `mult` is deliberately ignored: the inversion goes through the
    # frequency domain, which only needs the frequency-domain window.
    freq = from_wavelet_to_freq(wave_in, dt=dt, nx=nx)
    return freq.to_timeseries()
37
+
38
+
39
def from_wavelet_to_freq(
    wave_in: Wavelet, dt: float, nx=4.0
) -> FrequencySeries:
    """Inverse wavelet transform to the frequency domain.

    Parameters
    ----------
    wave_in : Wavelet
        Input wavelet coefficients.
    dt : float
        Time step of the underlying time-domain signal.
    nx : float, optional
        Passed as ``d`` to ``phitilde_vec_norm``, by default 4.0.

    Returns
    -------
    FrequencySeries
        One-sided spectrum of the ``ND``-sample signal: ``ND // 2 + 1``
        bins on the grid ``rfftfreq(ND, d=dt)`` (DC up to Nyquist).
    """
    ND = wave_in.ND
    phif = jnp.array(phitilde_vec_norm(wave_in.Nf, wave_in.Nt, dt=dt, d=nx))
    freq_data = inverse_wavelet_freq_helper(
        wave_in.data, phif=phif, Nf=wave_in.Nf, Nt=wave_in.Nt
    )

    # The helper only populates bins 0..ND//2 (the rfft half-spectrum).
    # Trim to that length; this is a no-op if it already returns ND//2 + 1.
    freq_data = freq_data[: ND // 2 + 1]

    # Normalise to get the proper backwards transformation.
    freq_data *= 2 ** (-1 / 2)

    # Bin k of an ND-sample real FFT sits at k / (ND * dt); k = 0 is DC.
    freqs = rfftfreq(ND, d=dt)
    return FrequencySeries(data=freq_data, freq=freqs)
@@ -0,0 +1,70 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from jax import jit
4
+ from jax.numpy.fft import fft
5
+
6
+ from functools import partial
7
+
8
+
9
@partial(jit, static_argnames=('Nf', 'Nt'))
def inverse_wavelet_freq_helper(
    wave_in: jnp.ndarray, phif: jnp.ndarray, Nf: int, Nt: int
) -> jnp.ndarray:
    """JAX vectorized function for inverse_wavelet_freq.

    Parameters
    ----------
    wave_in : jnp.ndarray
        Wavelet coefficients; indexed below (after transposing) as
        ``[time_bin, freq_bin]``, so assumed shape ``(Nf, Nt)``.
    phif : jnp.ndarray
        Frequency-domain window; entries ``0 .. Nt // 2`` are read.
    Nf, Nt : int
        Number of frequency / time bins (static under ``jit``).
        ``Nt`` is assumed even -- TODO confirm callers guarantee this.

    Returns
    -------
    jnp.ndarray
        Complex array of length ``Nf * Nt``. Only entries
        ``0 .. Nf * Nt // 2`` (the one-sided spectrum) are written;
        the remainder stays zero.
    """
    wave_in = wave_in.T  # now indexed as wave_in[n, m]
    ND = Nf * Nt
    half_Nt = Nt // 2

    # ---- Pack: one length-Nt row of prefactors per m in 0..Nf ----
    n_range = jnp.arange(Nt)
    prefactor2s = jnp.zeros((Nf + 1, Nt), dtype=jnp.complex128)

    # m == 0 and m == Nf are both read from column 0: even time indices
    # feed m == 0, odd time indices feed m == Nf.
    prefactor2s = prefactor2s.at[0].set(
        2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt, 0]
    )
    prefactor2s = prefactor2s.at[Nf].set(
        2 ** (-1 / 2) * wave_in[(2 * n_range) % Nt + 1, 0]
    )

    # 0 < m < Nf: phase factor 1 or -1j depending on parity of n + m.
    m_mid = jnp.arange(1, Nf)
    m_grid, n_grid = jnp.meshgrid(m_mid, n_range, indexing="ij")  # (Nf-1, Nt)
    mult2 = jnp.where((n_grid + m_grid) % 2, -1j, 1)
    prefactor2s = prefactor2s.at[1:Nf].set(mult2 * wave_in[n_grid, m_grid])

    # One FFT per m row.
    fft_prefactor2s = fft(prefactor2s, axis=1)

    # ---- Unpack into the spectrum ----
    # Length ND is kept for interface compatibility; only bins
    # 0..ND//2 are ever written (the rfft half-spectrum).
    res = jnp.zeros(ND, dtype=jnp.complex128)

    # m == 0 and m == Nf: offsets i_ind in [0, Nt//2).
    i_ind = jnp.arange(half_Nt)
    i_0 = i_ind
    i_Nf = jnp.abs(Nf * Nt // 2 - i_ind)
    res = res.at[i_0].add(fft_prefactor2s[0, (2 * i_0) % Nt] * phif[i_ind])
    res = res.at[i_Nf].add(fft_prefactor2s[Nf, (2 * i_Nf) % Nt] * phif[i_ind])

    # Extra Nyquist term for m == Nf.
    res = res.at[Nf * Nt // 2].add(fft_prefactor2s[Nf, 0] * phif[half_Nt])

    # 0 < m < Nf: spread symmetrically around the centre bin half_Nt * m,
    # offsets i_ind in [0, Nt//2) (not Nt//2 + 1, which would leak
    # phif[Nt//2] terms into neighbouring centres).
    m_grid, i_grid = jnp.meshgrid(m_mid, i_ind, indexing="ij")  # (Nf-1, Nt//2)
    i1 = half_Nt * m_grid - i_grid
    i2 = half_Nt * m_grid + i_grid
    res = res.at[i1].add(fft_prefactor2s[m_grid, i1 % Nt] * phif[i_grid])
    res = res.at[i2].add(fft_prefactor2s[m_grid, i2 % Nt] * phif[i_grid])

    # At i_ind == 0, i1 == i2 == centre, so the centre bin was added twice.
    # Overwrite it with its single contribution; no other m touches it.
    centres = half_Nt * m_mid
    res = res.at[centres].set(fft_prefactor2s[m_mid, centres % Nt] * phif[0])

    return res
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -1,5 +1,5 @@
1
1
  pywavelet/__init__.py,sha256=zcK3Qj4wTrGZF1rU3aT6yA9LvliAOD4DVOY7gNfHhCI,53
2
- pywavelet/_version.py,sha256=H-qsvrxCpdhaQzyddR-yajEqI71hPxLa4KxzpP3uS1g,411
2
+ pywavelet/_version.py,sha256=MxUhzLJIZQfEpDTTcKSxciTGrMLd5v2VmMlHa2HGeo0,411
3
3
  pywavelet/backend.py,sha256=k4pDi6f4cwNY6HsUIx1xfuga9f2wLnFr_FIb7Fs1Mds,553
4
4
  pywavelet/logger.py,sha256=DyKC-pJ_N9GlVeXL00E1D8hUd8GceBg-pnn7g1YPKcM,391
5
5
  pywavelet/utils.py,sha256=l47C643nGlV9q4a0G7wtKzuas0Ou4En2e1FTATCgwlw,1907
@@ -10,6 +10,9 @@ pywavelet/transforms/jax/forward/__init__.py,sha256=Ki2RJCfkE9Zy59mqT3oEtGK9Ro9k
10
10
  pywavelet/transforms/jax/forward/from_freq.py,sha256=PsUC7RfrN6pRWWkMSXYHk9z5lxCXW3DfF0m-Rd1GOBE,1785
11
11
  pywavelet/transforms/jax/forward/from_time.py,sha256=xNeoZq54B6Gi3TdTTYLr_euaFeJcwpms-lSyCG53AdI,1726
12
12
  pywavelet/transforms/jax/forward/main.py,sha256=mm0R4m0pXcnzZB0jCckAc4ynG8STH5mldCmHyyU_PGo,3091
13
+ pywavelet/transforms/jax/inverse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ pywavelet/transforms/jax/inverse/main.py,sha256=ZK8NyfMI6oFYMKcALatETWnCystH0LWjEYInvnMMmh0,1714
15
+ pywavelet/transforms/jax/inverse/to_freq.py,sha256=ASNARcDBJQr4EizAP_77e5ai36iPwP6hzfvwGbZQ6BM,2295
13
16
  pywavelet/transforms/numpy/__init__.py,sha256=qFLpGpW3VJSbDp2JpD0Gx7PdwDjH-wrW_aO84ASkIgA,255
14
17
  pywavelet/transforms/numpy/forward/__init__.py,sha256=E_A8plyfTSKDRXlAAvdiRMTe9f3Y6MbK3pXMHFg8mr0,121
15
18
  pywavelet/transforms/numpy/forward/from_freq.py,sha256=JmJyjrNSb64WnpP50VZRt0BICP64iZJP5QAZTZoexkw,2675
@@ -26,7 +29,7 @@ pywavelet/types/plotting.py,sha256=JNDxeP-fB8U09E90J-rVT-h5yCGA_tGRHtctbgINiRo,1
26
29
  pywavelet/types/timeseries.py,sha256=u35bIqFo3QdlQRBEu6maeWA7DePS11LQ6WMiLjZPcWo,9456
27
30
  pywavelet/types/wavelet.py,sha256=el48oyAfwtSw2tCQLUb85F9lKr0qMSRJPUmAUU8TS50,12552
28
31
  pywavelet/types/wavelet_bins.py,sha256=GoQGKeZlPc-KbYY7LoxAhB-HI4diHpPcTABBXRfUTLA,1459
29
- pywavelet-0.2.0.dist-info/METADATA,sha256=rg9LNZxrykv39lKIGuX65v7WzuY93_D0oY268mMe5iw,1362
30
- pywavelet-0.2.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
31
- pywavelet-0.2.0.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
32
- pywavelet-0.2.0.dist-info/RECORD,,
32
+ pywavelet-0.2.1.dist-info/METADATA,sha256=pVblWF5CqZ_hFU9J8_bAENKh5HqJnIdSy8y1RR3TkVU,1362
33
+ pywavelet-0.2.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
34
+ pywavelet-0.2.1.dist-info/top_level.txt,sha256=g0Ezt0Rg0X-nrd-a0pAXKVRkuWNsF2M9Ynsjb9b2UYQ,10
35
+ pywavelet-0.2.1.dist-info/RECORD,,