pywavelet 0.0.1b0__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. pywavelet/__init__.py +1 -1
  2. pywavelet/_version.py +2 -2
  3. pywavelet/logger.py +6 -7
  4. pywavelet/transforms/__init__.py +10 -10
  5. pywavelet/transforms/forward/__init__.py +4 -0
  6. pywavelet/transforms/forward/from_freq.py +80 -0
  7. pywavelet/transforms/forward/from_time.py +66 -0
  8. pywavelet/transforms/forward/main.py +128 -0
  9. pywavelet/transforms/forward/wavelet_bins.py +58 -0
  10. pywavelet/transforms/inverse/__init__.py +3 -0
  11. pywavelet/transforms/inverse/main.py +96 -0
  12. pywavelet/transforms/{from_wavelets/inverse_wavelet_freq_funcs.py → inverse/to_freq.py} +43 -32
  13. pywavelet/transforms/{from_wavelets/inverse_wavelet_time_funcs.py → inverse/to_time.py} +49 -21
  14. pywavelet/transforms/phi_computer.py +152 -0
  15. pywavelet/transforms/types/__init__.py +4 -0
  16. pywavelet/transforms/types/common.py +53 -0
  17. pywavelet/transforms/types/frequencyseries.py +237 -0
  18. pywavelet/transforms/types/plotting.py +341 -0
  19. pywavelet/transforms/types/timeseries.py +280 -0
  20. pywavelet/transforms/types/wavelet.py +374 -0
  21. pywavelet/transforms/types/wavelet_mask.py +34 -0
  22. pywavelet/utils.py +76 -0
  23. pywavelet-0.1.0.dist-info/METADATA +35 -0
  24. pywavelet-0.1.0.dist-info/RECORD +26 -0
  25. {pywavelet-0.0.1b0.dist-info → pywavelet-0.1.0.dist-info}/WHEEL +1 -1
  26. pywavelet/fft_funcs.py +0 -16
  27. pywavelet/likelihood/__init__.py +0 -0
  28. pywavelet/likelihood/likelihood_base.py +0 -9
  29. pywavelet/likelihood/whittle.py +0 -24
  30. pywavelet/transforms/common.py +0 -77
  31. pywavelet/transforms/from_wavelets/__init__.py +0 -25
  32. pywavelet/transforms/to_wavelets/__init__.py +0 -52
  33. pywavelet/transforms/to_wavelets/transform_freq_funcs.py +0 -84
  34. pywavelet/transforms/to_wavelets/transform_time_funcs.py +0 -63
  35. pywavelet/utils/__init__.py +0 -0
  36. pywavelet/utils/fisher_matrix.py +0 -6
  37. pywavelet/utils/snr.py +0 -37
  38. pywavelet/waveform_generator/__init__.py +0 -0
  39. pywavelet/waveform_generator/build_lookup_table.py +0 -0
  40. pywavelet/waveform_generator/generators/__init__.py +0 -2
  41. pywavelet/waveform_generator/generators/functional_waveform_generator.py +0 -33
  42. pywavelet/waveform_generator/generators/lookuptable_waveform_generator.py +0 -15
  43. pywavelet/waveform_generator/generators/rom_waveform_generator.py +0 -0
  44. pywavelet/waveform_generator/waveform_generator.py +0 -14
  45. pywavelet-0.0.1b0.dist-info/METADATA +0 -35
  46. pywavelet-0.0.1b0.dist-info/RECORD +0 -29
  47. {pywavelet-0.0.1b0.dist-info → pywavelet-0.1.0.dist-info}/top_level.txt +0 -0
pywavelet/__init__.py CHANGED
@@ -2,4 +2,4 @@
2
2
  WDM Wavelet transform
3
3
  """
4
4
 
# The generated ``_version.py`` is the single source of truth for the package
# version (the previous hard-coded "0.0.2" disagreed with its '0.1.0').  The
# literal fallback keeps ``__version__`` available if that module is absent,
# e.g. when running from a source tree that was never built.
try:
    from ._version import __version__
except ImportError:
    __version__ = "0.1.0"
pywavelet/_version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
# Release identifiers: the plain names and their dunder aliases refer to the
# same objects.
version = "0.1.0"
__version__ = version
version_tuple = (0, 1, 0)
__version_tuple__ = version_tuple
pywavelet/logger.py CHANGED
@@ -1,15 +1,14 @@
"""Package logger and warning filters for pywavelet."""
import logging
import sys
import warnings

try:
    from rich.logging import RichHandler
except ImportError:  # degrade to a plain stream handler if rich is unavailable
    RichHandler = None

FORMAT = "%(message)s"

# A library must not call ``logging.basicConfig``: that configures the *root*
# logger at import time, hijacking the host application's logging (and making
# the application's own later ``basicConfig`` call a silent no-op).  Attach the
# handler to the package logger only.
logger = logging.getLogger("pywavelet")
logger.setLevel(logging.INFO)
if not logger.handlers:  # don't stack duplicate handlers on module reload
    if RichHandler is not None:
        _handler = RichHandler(rich_tracebacks=True)
    else:
        _handler = logging.StreamHandler(sys.stderr)
    _handler.setFormatter(logging.Formatter(FORMAT, datefmt="[%X]"))
    logger.addHandler(_handler)

# NOTE(review): these filters are process-wide rather than scoped to
# pywavelet, so importing the package hides every RuntimeWarning/UserWarning in
# the host program.  Kept for backward compatibility; consider wrapping only the
# numerical code that needs it in ``warnings.catch_warnings()``.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@@ -1,10 +1,10 @@
1
- from .from_wavelets import (
2
- from_wavelet_to_freq,
3
- from_wavelet_to_freq_to_time,
4
- from_wavelet_to_time,
5
- )
6
- from .to_wavelets import (
7
- from_freq_to_wavelet,
8
- from_time_to_freq_to_wavelet,
9
- from_time_to_wavelet,
10
- )
1
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet, compute_bins
2
+ from .inverse import from_wavelet_to_freq, from_wavelet_to_time
3
+
# Public API of ``pywavelet.transforms``: the two inverse transforms, the two
# forward transforms, and the wavelet-grid helper.
__all__ = [
    "from_wavelet_to_time", "from_wavelet_to_freq",
    "from_time_to_wavelet", "from_freq_to_wavelet",
    "compute_bins",
]
@@ -0,0 +1,4 @@
1
+ from .main import from_freq_to_wavelet, from_time_to_wavelet
2
+ from .wavelet_bins import compute_bins
3
+
# Public API of ``pywavelet.transforms.forward``.
__all__ = [
    "from_time_to_wavelet",
    "from_freq_to_wavelet",
    "compute_bins",
]
@@ -0,0 +1,80 @@
1
+ """helper functions for transform_freq"""
2
+ import numpy as np
3
+ from numba import njit
4
+ from numpy import fft
5
+
def transform_wavelet_freq_helper(
    data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
) -> np.ndarray:
    """Run the fast frequency-domain wavelet (WDM) transform.

    Parameters
    ----------
    data : np.ndarray
        One-sided frequency-series samples (presumably complex; they are
        combined into a complex128 buffer).
    Nf, Nt : int
        Number of wavelet frequency and time bins.
    phif : np.ndarray
        Frequency-domain window, indexed by distance from a bin centre.

    Returns
    -------
    np.ndarray
        Real array of shape ``(Nt, Nf)`` holding the wavelet coefficients.
    """
    coeffs = np.zeros((Nt, Nf))  # output, filled in place by the core loop
    scratch = np.zeros(Nt, dtype=np.complex128)  # per-bin buffer reused each iteration
    spectrum = data.copy()  # defensive copy; the caller's array is left untouched
    __core(Nf, Nt, scratch, spectrum, phif, coeffs)
    return coeffs
15
+
def __core(
    Nf: int,
    Nt: int,
    DX: np.ndarray,
    freq_strain: np.ndarray,
    phif: np.ndarray,
    wave: np.ndarray,
):
    """Fill ``wave`` one frequency bin at a time, for bins 0..Nf inclusive.

    This driver stays plain Python because it calls ``np.fft.ifft`` (the
    original note says numba can't compile that); the per-bin packing and
    unpacking helpers are the jitted parts.
    """
    last_bin = Nf
    for f_bin in range(last_bin + 1):
        __fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
        packet = np.fft.ifft(DX, Nt)
        __fill_wave_2(f_bin, packet, wave, Nt, Nf)
23
+
24
+
25
+
26
+
@njit()
def __fill_wave_1(
    f_bin: int,
    Nt: int,
    Nf: int,
    DX: np.ndarray,
    data: np.ndarray,
    phif: np.ndarray,
) -> None:
    """Load the windowed spectrum around frequency bin ``f_bin`` into ``DX``.

    Slot ``Nt // 2`` of ``DX`` is the window centre and holds
    ``phif[0] * data[f_bin * Nt // 2]``.  Offsets ``1 - Nt//2 .. Nt//2 - 1``
    around it are weighted by ``phif[|offset|]``.  For bin 0 the negative
    offsets, and for bin Nf the positive offsets, are zeroed.  For those two
    edge bins the centre sample is also halved.
    """
    half = Nt // 2
    centre = f_bin * Nt // 2
    is_edge = f_bin == 0 or f_bin == Nf

    centre_val = phif[0] * data[centre]
    if is_edge:
        # NOTE this term appears to be needed to recover correct constant (at
        # least for m=0), but was previously missing
        centre_val = centre_val / 2.0
    DX[half] = centre_val

    for offset in range(1 - half, half):
        if offset == 0:
            continue  # the centre slot was written above
        if (f_bin == Nf and offset > 0) or (f_bin == 0 and offset < 0):
            DX[half + offset] = 0.0
        else:
            DX[half + offset] = phif[abs(offset)] * data[centre + offset]
57
+
@njit()
def __fill_wave_2(
    f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
) -> None:
    """Unpack the inverse-FFT'd packet ``DX_trans`` for ``f_bin`` into ``wave``."""
    if f_bin == 0 or f_bin == Nf:
        # Half of the lowest and highest frequency bin pixels are redundant,
        # so both edge bins share column 0: bin 0 fills the even rows and
        # bin Nf fills the odd rows.
        row_shift = 0 if f_bin == 0 else 1
        for n in range(0, Nt, 2):
            wave[n + row_shift, 0] = DX_trans[n].real * np.sqrt(2)
        return

    # Interior bins: rows with even (n + f_bin) take the real part.  The other
    # rows take the imaginary part, negated when f_bin is odd.
    imag_sign = -1.0 if f_bin % 2 else 1.0
    for n in range(Nt):
        if (n + f_bin) % 2:
            wave[n, f_bin] = imag_sign * DX_trans[n].imag
        else:
            wave[n, f_bin] = DX_trans[n].real
@@ -0,0 +1,66 @@
1
+ """helper functions for transform_time.py"""
2
+ import numpy as np
3
+ from numba import njit
4
+ from numpy import fft
5
+
6
+
def transform_wavelet_time_helper(
    data: np.ndarray, Nf: int, Nt: int, phi: np.ndarray, mult: int
) -> np.ndarray:
    """Run the time-domain wavelet (WDM) transform.

    Parameters
    ----------
    data : np.ndarray
        Time-series samples; must contain exactly ``Nf * Nt`` values.
    Nf, Nt : int
        Number of wavelet frequency and time bins.
    phi : np.ndarray
        Time-domain window of length ``2 * mult * Nf``.
    mult : int
        Window half-width multiplier.

    Returns
    -------
    np.ndarray
        Real array of shape ``(Nt, Nf)`` of wavelet coefficients.
    """
    ND = Nf * Nt
    assert len(data) == ND, f"len(data)={len(data)} != Nf*Nt={ND}"
    window_len = mult * 2 * Nf  # K: samples under each time window

    windowed = np.zeros(window_len)  # scratch buffer for one windowed segment
    wave = np.zeros((Nt, Nf))
    padded = np.concatenate((data, data[:window_len]))
    __core(Nf, Nt, window_len, ND, windowed, padded, phi, wave, mult)
    return wave
22
+
def __core(
    Nf: int,
    Nt: int,
    K: int,
    ND: int,
    wdata: np.ndarray,
    data_pad: np.ndarray,
    phi: np.ndarray,
    wave: np.ndarray,
    mult: int,
) -> None:
    """For each time bin: window the data, rFFT it, and pack row ``t_bin`` of ``wave``.

    Plain Python because of the ``np.fft.rfft`` call; the window and pack
    steps are jitted helpers.
    """
    for t_bin in range(Nt):
        __fill_wave_1(t_bin, K, ND, Nf, wdata, data_pad, phi)
        spectrum = np.fft.rfft(wdata, K)
        __fill_wave_2(t_bin, wave, spectrum, Nf, mult)
28
+
29
+
@njit()
def __fill_wave_1(
    t_bin: int,
    K: int,
    ND: int,
    Nf: int,
    wdata: np.ndarray,
    data_pad: np.ndarray,
    phi: np.ndarray,
) -> None:
    """Write the K windowed samples for time bin ``t_bin`` into ``wdata``.

    The segment starts K//2 samples before sample ``t_bin * Nf``.  Indices
    wrap periodically modulo ND, which is what lets the sum in Eq 13 of the
    Cornish paper run over [-K/2, K/2].
    """
    start = (t_bin * Nf - K // 2) % ND
    for j in range(K):
        wdata[j] = data_pad[(start + j) % ND] * phi[j]  # apply the window
47
+
48
+
@njit()
def __fill_wave_2(
    t_bin: int, wave: np.ndarray, wdata_trans: np.ndarray, Nf: int, mult: int
) -> None:
    """Pack the rFFT of one windowed segment into row ``t_bin`` of ``wave``."""
    has_partner_row = t_bin < wave.shape[0] - 1
    if t_bin % 2 == 0 and has_partner_row:
        # Column 0 is shared.  The rFFT's first bin goes on the even row.  Bin
        # Nf * mult (the last rFFT bin, since K = 2 * mult * Nf) goes on the
        # following odd row.
        wave[t_bin, 0] = wdata_trans[0].real / np.sqrt(2)
        wave[t_bin + 1, 0] = wdata_trans[Nf * mult].real / np.sqrt(2)

    # C_nm in Eq 13: every mult-th rFFT bin is one wavelet frequency channel.
    for j in range(1, Nf):
        coeff = wdata_trans[j * mult]
        wave[t_bin, j] = -coeff.imag if (t_bin + j) % 2 else coeff.real
@@ -0,0 +1,128 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+
5
+ from ...logger import logger
6
+ from ..phi_computer import phi_vec, phitilde_vec_norm
7
+ from ..types import FrequencySeries, TimeSeries, Wavelet
8
+ from .from_freq import transform_wavelet_freq_helper
9
+ from .from_time import transform_wavelet_time_helper
10
+ from .wavelet_bins import _get_bins, _preprocess_bins
11
+
12
+
# Forward transforms exported by this module.
__all__ = [
    "from_time_to_wavelet",
    "from_freq_to_wavelet",
]
14
+
def from_time_to_wavelet(
    timeseries: TimeSeries,
    Nf: Union[int, None] = None,
    Nt: Union[int, None] = None,
    nx: float = 4.0,
    mult: int = 32,
) -> Wavelet:
    """
    Transform time-domain data to wavelet-domain data.

    This function performs a forward wavelet transform, converting a
    time-domain signal into a wavelet-domain representation.

    Parameters
    ----------
    timeseries : TimeSeries
        Input time-domain data, represented as a `TimeSeries` object.
        If it is longer than ``Nf * Nt`` samples it is truncated (with a
        warning) to the first ``Nf * Nt`` samples.
    Nf : int, optional
        Number of frequency bins for the wavelet transform. Default is None.
    Nt : int, optional
        Number of time bins for the wavelet transform. Default is None.
    nx : float, optional
        Number of standard deviations for the `phi_vec`, controlling the
        width of the wavelets. Default is 4.0.
    mult : int, optional
        Window half-width multiplier: the time-domain window spans
        ``2 * mult * Nf`` samples. Values above ``Nt // 2`` are clamped
        (with a warning). Default is 32.

    Returns
    -------
    Wavelet
        A `Wavelet` object representing the transformed wavelet-domain data.
        Its time/frequency bins describe the (possibly truncated) data that
        was actually transformed.

    Warnings
    --------
    There can be significant leakage if `mult` is too small. The transform is
    only approximately exact if `mult = Nt / 2`.

    Notes
    -----
    The function warns when the `mult` value is too large, potentially leading
    to inaccurate results.
    """
    Nf, Nt = _preprocess_bins(timeseries, Nf, Nt)
    dt = timeseries.dt
    ND = Nf * Nt

    if len(timeseries) != ND:
        logger.warning(
            "len(timeseries)=%d != Nf*Nt=%d. Truncating to timeseries[:%d]",
            len(timeseries),
            ND,
            ND,
        )
        timeseries = timeseries[:ND]

    # Compute the grid only after truncation.  The bins are derived from the
    # series' duration, so using the untruncated input would mis-scale both
    # the time step and the frequency resolution of the wavelet grid.
    t_bins, f_bins = _get_bins(timeseries, Nf, Nt)

    if mult > Nt / 2:
        logger.warning(
            "mult=%d is too large for Nt=%d. This may lead to bogus results.",
            mult,
            Nt,
        )
    # Clamp so the 2 * mult * Nf window never exceeds the ND = Nf * Nt samples.
    mult = min(mult, Nt // 2)
    phi = phi_vec(Nf, dt=dt, d=nx, q=mult)
    wave = transform_wavelet_time_helper(timeseries.data, Nf, Nt, phi, mult).T
    return Wavelet(wave * np.sqrt(2), time=t_bins, freq=f_bins)
80
+
81
+
def from_freq_to_wavelet(
    freqseries: FrequencySeries,
    Nf: Union[int, None] = None,
    Nt: Union[int, None] = None,
    nx: float = 4.0,
) -> Wavelet:
    """
    Transform frequency-domain data to wavelet-domain data.

    This function performs a forward wavelet transform, converting a
    frequency-domain signal into a wavelet-domain representation.

    Parameters
    ----------
    freqseries : FrequencySeries
        Input frequency-domain data, represented as a `FrequencySeries` object.
        Must hold at least ``Nf * Nt // 2 + 1`` samples.
    Nf : int, optional
        Number of frequency bins for the wavelet transform. Default is None.
    Nt : int, optional
        Number of time bins for the wavelet transform. Default is None.
    nx : float, optional
        Number of standard deviations for the `phi_vec`, controlling the
        width of the wavelets. Default is 4.0.

    Returns
    -------
    Wavelet
        A `Wavelet` object representing the transformed wavelet-domain data.

    Raises
    ------
    ValueError
        If `freqseries` is too short for the requested ``Nf * Nt`` grid
        (only possible when both `Nf` and `Nt` are given explicitly).

    Notes
    -----
    The wavelet-domain data is scaled by ``(2 / Nf) * sqrt(2)``.
    """
    Nf, Nt = _preprocess_bins(freqseries, Nf, Nt)
    ND = Nf * Nt
    # The jitted packing kernel reads freqseries.data up to index ND // 2 (the
    # centre sample of bin Nf).  Numba does not bounds-check by default, so a
    # short series would silently read past the array instead of raising.
    n_required = ND // 2 + 1
    if len(freqseries) < n_required:
        raise ValueError(
            f"len(freqseries)={len(freqseries)} is too short for Nf={Nf}, "
            f"Nt={Nt}: need at least Nf*Nt//2+1={n_required} samples"
        )

    t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
    dt = freqseries.dt
    phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)
    wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)

    return Wavelet(
        (2 / Nf) * wave.T * np.sqrt(2),
        time=t_bins,
        freq=f_bins,
    )
@@ -0,0 +1,58 @@
1
+ from typing import Tuple, Union
2
+
3
+ import numpy as np
4
+
5
+ from ..types import FrequencySeries, TimeSeries
6
+
7
+
def _preprocess_bins(
    data: Union[TimeSeries, FrequencySeries], Nf=None, Nt=None
) -> Tuple[int, int]:
    """Resolve the wavelet grid dimensions ``(Nf, Nt)`` for ``data``.

    N, the number of time samples represented by ``data``, is ``len(data)``
    for a ``TimeSeries`` and ``2 * (len(data) - 1)`` for a one-sided
    ``FrequencySeries``.  If only one of ``Nf`` / ``Nt`` is given, the other
    is derived as ``N // given``.  If both are given they are returned as-is.

    Parameters
    ----------
    data : TimeSeries or FrequencySeries
        The data to be transformed.
    Nf, Nt : int or None
        Requested number of frequency / time bins; at least one is required.

    Returns
    -------
    tuple of int
        ``(Nf, Nt)``.

    Raises
    ------
    TypeError
        If ``data`` is neither a TimeSeries nor a FrequencySeries, or if both
        ``Nf`` and ``Nt`` are None.
    AssertionError
        If the single given bin count lies outside ``[1, N]``.
    """
    if isinstance(data, TimeSeries):
        N = len(data)
    elif isinstance(data, FrequencySeries):
        # one-sided spectrum: len(data) = N // 2 + 1
        N = 2 * (len(data) - 1)
    else:
        raise TypeError(
            "data must be a TimeSeries or FrequencySeries, "
            f"got {type(data).__name__}"
        )

    if Nf is None and Nt is None:
        # Fail fast: returning (None, None) only surfaces later as an opaque
        # TypeError in the bin arithmetic.
        raise TypeError("At least one of Nf or Nt must be provided")

    if Nt is not None and Nf is None:
        assert 1 <= Nt <= N, f"Nt={Nt} must be between 1 and N={N}"
        Nf = N // Nt
    elif Nf is not None and Nt is None:
        assert 1 <= Nf <= N, f"Nf={Nf} must be between 1 and N={N}"
        Nt = N // Nf

    return Nf, Nt
30
+
31
+
def _get_bins(
    data: Union[TimeSeries, FrequencySeries],
    Nf: Union[int, None] = None,
    Nt: Union[int, None] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return ``(t_bins, f_bins)`` for the wavelet grid spanning ``data``.

    The grid is computed from ``data.duration`` via ``compute_bins``, and the
    time bins are shifted by the data's start time ``data.t0``.
    """
    t_bins, f_bins = compute_bins(Nf, Nt, data.duration)
    t_bins += data.t0  # align the grid with the start of the data
    return t_bins, f_bins
48
+
49
+
def compute_bins(Nf: int, Nt: int, T: float) -> Tuple[np.ndarray, np.ndarray]:
    """Return the wavelet time and frequency grids (Eqs 4-6 of the Wavelets paper).

    Parameters
    ----------
    Nf : int
        Number of frequency bins.
    Nt : int
        Number of time bins.
    T : float
        Total duration of the data.

    Returns
    -------
    tuple of np.ndarray
        ``(t_bins, f_bins)`` with ``t_bins[i] = i * T / Nt`` and
        ``f_bins[j] = j / (2 * T / Nt)``.
    """
    time_step = T / Nt
    freq_step = 1 / (2 * time_step)
    return np.arange(Nt) * time_step, np.arange(Nf) * freq_step
@@ -0,0 +1,3 @@
1
+ from .main import from_wavelet_to_freq, from_wavelet_to_time
2
+
3
+ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
@@ -0,0 +1,96 @@
1
+ import numpy as np
2
+
3
+ from ...transforms.phi_computer import phi_vec, phitilde_vec_norm
4
+ from ..types import FrequencySeries, TimeSeries, Wavelet
5
+ from .to_freq import inverse_wavelet_freq_helper_fast
6
+ from .to_time import inverse_wavelet_time_helper_fast
7
+
# Inverse transforms exported by this module.
__all__ = [
    "from_wavelet_to_time",
    "from_wavelet_to_freq",
]
9
+
# 1/sqrt(2): final output normalisation of the inverse transforms in this module.
INV_ROOT2 = 1.0 / np.sqrt(2)
11
+
def from_wavelet_to_time(
    wave_in: Wavelet,
    dt: float,
    nx: float = 4.0,
    mult: int = 32,
) -> TimeSeries:
    """
    Perform an inverse wavelet transform to the time domain.

    Parameters
    ----------
    wave_in : Wavelet
        Input wavelet-domain data.
    dt : float
        Sampling interval of the output time series.
    nx : float, optional
        Width parameter passed to ``phi_vec`` as ``d``. Default is 4.0.
    mult : int, optional
        Window half-width multiplier passed to ``phi_vec`` as ``q``; clamped
        to at most ``wave_in.Nt // 2``. Default is 32.

    Returns
    -------
    TimeSeries
        ``wave_in.Nf * wave_in.Nt`` samples on the time axis ``k * dt``.

    Notes
    -----
    The result is multiplied by 1/sqrt(2) for proper backwards normalisation.

    NOTE(review): the forward transform offsets its time bins by the input's
    ``t0``, but the time axis built here always starts at 0.
    """
    n_freq, n_time = wave_in.Nf, wave_in.Nt
    mult = min(mult, n_time // 2)  # same clamp as the forward transform
    window = phi_vec(n_freq, d=nx, q=mult, dt=dt) / 2
    signal = inverse_wavelet_time_helper_fast(
        wave_in.data.T, window, n_freq, n_time, mult
    )
    signal *= INV_ROOT2
    times = np.arange(0, n_freq * n_time) * dt
    return TimeSeries(data=signal, time=times)
55
+
56
+
def from_wavelet_to_freq(
    wave_in: Wavelet,
    dt: float,
    nx: float = 4.0,
) -> FrequencySeries:
    """
    Perform an inverse wavelet transform to the frequency domain.

    Parameters
    ----------
    wave_in : Wavelet
        Input wavelet-domain data.
    dt : float
        Sampling interval of the underlying time series.
    nx : float, optional
        Width parameter passed to ``phitilde_vec_norm`` as ``d``. Default is 4.0.

    Returns
    -------
    FrequencySeries
        One-sided spectrum on the frequencies ``np.fft.rfftfreq(wave_in.ND, d=dt)``.

    Notes
    -----
    The result is multiplied by 1/sqrt(2) for proper backwards normalisation.
    """
    window = phitilde_vec_norm(wave_in.Nf, wave_in.Nt, dt=dt, d=nx)
    spectrum = inverse_wavelet_freq_helper_fast(
        wave_in.data, window, wave_in.Nf, wave_in.Nt
    )
    spectrum *= INV_ROOT2
    return FrequencySeries(
        data=spectrum, freq=np.fft.rfftfreq(wave_in.ND, d=dt)
    )
@@ -1,29 +1,61 @@
1
1
  """functions for computing the inverse wavelet transforms"""
2
2
  import numpy as np
3
3
  from numba import njit
4
+ from numpy import fft
4
5
 
5
- from ... import fft_funcs as fft
6
6
 
def inverse_wavelet_freq_helper_fast(
    wave_in: np.ndarray, phif: np.ndarray, Nf: int, Nt: int
) -> np.ndarray:
    """Invert wavelet coefficients back to a one-sided frequency series.

    ``wave_in`` is transposed on entry because the packing helper indexes it
    as ``[time_bin, freq_bin]``.  This presumably means callers pass it as
    (Nf, Nt); confirm against the Wavelet type.

    Returns
    -------
    np.ndarray
        complex128 array of length ``Nf * Nt // 2 + 1``.
    """
    coeffs = wave_in.T
    n_samples = Nf * Nt
    buffer = np.zeros(Nt, np.complex128)  # per-bin scratch reused across bins
    spectrum = np.zeros(n_samples // 2 + 1, dtype=np.complex128)
    __core(Nf, Nt, buffer, coeffs, phif, spectrum)
    return spectrum
23
20
 
def __core(
    Nf: int,
    Nt: int,
    prefactor2s: np.ndarray,
    wave_in: np.ndarray,
    phif: np.ndarray,
    res: np.ndarray,
) -> None:
    """Process frequency bins m = 0..Nf, writing results into ``res``.

    Plain Python because it calls ``np.fft.fft``; packing and unpacking are
    jitted helpers.
    """
    for m in range(Nf + 1):
        __pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
        spectrum = np.fft.fft(prefactor2s)
        __unpack_wave_inverse(m, Nt, Nf, phif, spectrum, res)
26
+
24
27
 
@njit()
def __pack_wave_inverse(
    m: int, Nt: int, Nf: int, prefactor2s: np.ndarray, wave_in: np.ndarray
) -> None:
    """Fill ``prefactor2s`` with the coefficients of bin ``m``, ready for the FFT."""
    if m == 0 or m == Nf:
        # Both edge bins live in column 0: bin 0 on even rows, bin Nf on odd rows.
        row_shift = 0 if m == 0 else 1
        for n in range(Nt):
            prefactor2s[n] = 2 ** (-1 / 2) * wave_in[(2 * n) % Nt + row_shift, 0]
        return

    for n in range(Nt):
        # NOTE(review): an earlier comment flagged this line as buggy.  This
        # (n + m)-parity phase convention must stay consistent with the packing
        # in the forward frequency transform; verify with round-trip tests.
        phase = -1j if (n + m) % 2 else 1 + 0j
        prefactor2s[n] = phase * wave_in[n, m]
48
+
49
+
50
+ @njit()
51
+ def __unpack_wave_inverse(
52
+ m: int,
53
+ Nt: int,
54
+ Nf: int,
55
+ phif: np.ndarray,
56
+ fft_prefactor2s: np.ndarray,
57
+ res: np.ndarray,
58
+ ) -> None:
27
59
  """helper for unpacking results of frequency domain inverse transform"""
28
60
 
29
61
  if m == 0 or m == Nf:
@@ -52,25 +84,4 @@ def unpack_wave_inverse(m, Nt, Nf, phif, fft_prefactor2s, res):
52
84
  ind31 = Nt - 1
53
85
  if ind32 == Nt:
54
86
  ind32 = 0
55
-
56
87
  res[Nt // 2 * m] = fft_prefactor2s[(Nt // 2 * m) % Nt] * phif[0]
57
-
58
-
59
- @njit()
60
- def pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in):
61
- """helper for fast frequency domain inverse transform to prepare for fourier transform"""
62
- if m == 0:
63
- for n in range(0, Nt):
64
- prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt, 0]
65
- elif m == Nf:
66
- for n in range(0, Nt):
67
- prefactor2s[n] = 1 / np.sqrt(2) * wave_in[(2 * n) % Nt + 1, 0]
68
- else:
69
- for n in range(0, Nt):
70
- val = wave_in[n, m]
71
- if (n + m) % 2:
72
- mult2 = -1j
73
- else:
74
- mult2 = 1
75
-
76
- prefactor2s[n] = mult2 * val