pywavelet 0.0.5.tar.gz → 0.1.1.tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (65)
  1. {pywavelet-0.0.5 → pywavelet-0.1.1}/CHANGELOG.rst +44 -0
  2. {pywavelet-0.0.5 → pywavelet-0.1.1}/PKG-INFO +1 -1
  3. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/_version.py +2 -2
  4. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/logger.py +6 -2
  5. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/__init__.py +1 -2
  6. pywavelet-0.1.1/src/pywavelet/transforms/forward/__init__.py +3 -0
  7. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/forward/from_freq.py +13 -4
  8. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/forward/from_time.py +13 -1
  9. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/forward/main.py +7 -15
  10. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/inverse/main.py +3 -4
  11. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/inverse/to_freq.py +11 -3
  12. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/inverse/to_time.py +13 -3
  13. {pywavelet-0.0.5/src/pywavelet/transforms → pywavelet-0.1.1/src/pywavelet}/types/__init__.py +1 -1
  14. {pywavelet-0.0.5/src/pywavelet/transforms → pywavelet-0.1.1/src/pywavelet}/types/common.py +7 -7
  15. {pywavelet-0.0.5/src/pywavelet/transforms → pywavelet-0.1.1/src/pywavelet}/types/frequencyseries.py +31 -23
  16. {pywavelet-0.0.5/src/pywavelet/transforms → pywavelet-0.1.1/src/pywavelet}/types/plotting.py +47 -26
  17. {pywavelet-0.0.5/src/pywavelet/transforms → pywavelet-0.1.1/src/pywavelet}/types/timeseries.py +58 -38
  18. {pywavelet-0.0.5/src/pywavelet/transforms → pywavelet-0.1.1/src/pywavelet}/types/wavelet.py +111 -31
  19. {pywavelet-0.0.5/src/pywavelet/transforms/forward → pywavelet-0.1.1/src/pywavelet/types}/wavelet_bins.py +5 -6
  20. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/utils.py +17 -5
  21. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet.egg-info/PKG-INFO +1 -1
  22. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet.egg-info/SOURCES.txt +9 -7
  23. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/conftest.py +28 -5
  24. pywavelet-0.1.1/tests/test_lnl.py +18 -0
  25. pywavelet-0.1.1/tests/test_mask.py +37 -0
  26. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_roundtrip_conversion.py +47 -22
  27. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_snr.py +20 -25
  28. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_timefreq_type.py +15 -12
  29. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_wavelet_plot.py +9 -2
  30. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/utils/__init__.py +1 -1
  31. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/utils/generate_data.py +9 -12
  32. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/utils/plotting.py +70 -23
  33. pywavelet-0.0.5/src/pywavelet/transforms/forward/__init__.py +0 -4
  34. {pywavelet-0.0.5 → pywavelet-0.1.1}/.github/workflows/ci.yml +0 -0
  35. {pywavelet-0.0.5 → pywavelet-0.1.1}/.github/workflows/docs.yml +0 -0
  36. {pywavelet-0.0.5 → pywavelet-0.1.1}/.github/workflows/pypi.yml +0 -0
  37. {pywavelet-0.0.5 → pywavelet-0.1.1}/.gitignore +0 -0
  38. {pywavelet-0.0.5 → pywavelet-0.1.1}/.pre-commit-config.yaml +0 -0
  39. {pywavelet-0.0.5 → pywavelet-0.1.1}/CITATION.cff +0 -0
  40. {pywavelet-0.0.5 → pywavelet-0.1.1}/README.rst +0 -0
  41. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/_config.yml +0 -0
  42. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/_static/demo.gif +0 -0
  43. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/_toc.yml +0 -0
  44. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/api.rst +0 -0
  45. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/example.ipynb +0 -0
  46. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/index.rst +0 -0
  47. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/logo.png +0 -0
  48. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/roundtrip_freq.png +0 -0
  49. {pywavelet-0.0.5 → pywavelet-0.1.1}/docs/roundtrip_time.png +0 -0
  50. {pywavelet-0.0.5 → pywavelet-0.1.1}/pyproject.toml +0 -0
  51. {pywavelet-0.0.5 → pywavelet-0.1.1}/setup.cfg +0 -0
  52. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/__init__.py +0 -0
  53. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/inverse/__init__.py +0 -0
  54. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet/transforms/phi_computer.py +0 -0
  55. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet.egg-info/dependency_links.txt +0 -0
  56. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet.egg-info/requires.txt +0 -0
  57. {pywavelet-0.0.5 → pywavelet-0.1.1}/src/pywavelet.egg-info/top_level.txt +0 -0
  58. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_data/roundtrip_chirp_freq.npz +0 -0
  59. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_data/roundtrip_chirp_time.npz +0 -0
  60. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_data/roundtrip_pure_f0_freq.npz +0 -0
  61. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_data/roundtrip_sine_freq.npz +0 -0
  62. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_data/roundtrip_sine_time.npz +0 -0
  63. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_phi.py +0 -0
  64. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_psd.py +0 -0
  65. {pywavelet-0.0.5 → pywavelet-0.1.1}/tests/test_version.py +0 -0
@@ -5,6 +5,44 @@ CHANGELOG
5
5
  =========
6
6
 
7
7
 
8
+ .. _changelog-v0.1.1:
9
+
10
+ v0.1.1 (2025-01-16)
11
+ ===================
12
+
13
+ Unknown
14
+ -------
15
+
16
+ * Merge branch 'main' of github.com:avivajpeyi/pywavelet into main (`69eefa2`_)
17
+
18
+ .. _69eefa2: https://github.com/pywavelet/pywavelet/commit/69eefa29b7873c30fcb74ad1e051eb20101a277a
19
+
20
+
21
+ .. _changelog-v0.1.0:
22
+
23
+ v0.1.0 (2025-01-15)
24
+ ===================
25
+
26
+ Bug Fixes
27
+ ---------
28
+
29
+ * fix: refactor type outside transforms (`efb8878`_)
30
+
31
+ Chores
32
+ ------
33
+
34
+ * chore(release): 0.1.0 (`c5a3fde`_)
35
+
36
+ Features
37
+ --------
38
+
39
+ * feat: add wavelet mask and more tests (`e009903`_)
40
+
41
+ .. _efb8878: https://github.com/pywavelet/pywavelet/commit/efb88789f8468ff18f99abaf6168bb8fc0f5947b
42
+ .. _c5a3fde: https://github.com/pywavelet/pywavelet/commit/c5a3fdea455c16478f04049f14bc35dfcf4efb15
43
+ .. _e009903: https://github.com/pywavelet/pywavelet/commit/e00990300d9c013438580c2bc47ea93570fd95be
44
+
45
+
8
46
  .. _changelog-v0.0.5:
9
47
 
10
48
  v0.0.5 (2024-12-12)
@@ -15,7 +53,13 @@ Bug Fixes
15
53
 
16
54
  * fix: update changelog generator (`884c87b`_)
17
55
 
56
+ Chores
57
+ ------
58
+
59
+ * chore(release): 0.0.5 (`4ed6b03`_)
60
+
18
61
  .. _884c87b: https://github.com/pywavelet/pywavelet/commit/884c87bcd36b5d21eb1a8e10ee9e0edf6f65d744
62
+ .. _4ed6b03: https://github.com/pywavelet/pywavelet/commit/4ed6b03618347cc179195feec57b05e04a004100
19
63
 
20
64
 
21
65
  .. _changelog-v0.0.4:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pywavelet
3
- Version: 0.0.5
3
+ Version: 0.1.1
4
4
  Summary: WDM wavelet transform your time/freq series!
5
5
  Author-email: Pywavelet Team <avi.vajpeyi@gmail.com>
6
6
  Project-URL: Homepage, https://pywavelet.github.io/pywavelet/
@@ -12,5 +12,5 @@ __version__: str
12
12
  __version_tuple__: VERSION_TUPLE
13
13
  version_tuple: VERSION_TUPLE
14
14
 
15
- __version__ = version = '0.0.5'
16
- __version_tuple__ = version_tuple = (0, 0, 5)
15
+ __version__ = version = '0.1.1'
16
+ __version_tuple__ = version_tuple = (0, 1, 1)
@@ -1,11 +1,15 @@
1
+ import logging
1
2
  import sys
2
3
  import warnings
4
+
3
5
  from rich.logging import RichHandler
4
- import logging
5
6
 
6
7
  FORMAT = "%(message)s"
7
8
  logging.basicConfig(
8
- level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler(rich_tracebacks=True)]
9
+ level="INFO",
10
+ format=FORMAT,
11
+ datefmt="[%X]",
12
+ handlers=[RichHandler(rich_tracebacks=True)],
9
13
  )
10
14
 
11
15
  logger = logging.getLogger("pywavelet")
@@ -1,4 +1,4 @@
1
- from .forward import from_freq_to_wavelet, from_time_to_wavelet, compute_bins
1
+ from .forward import from_freq_to_wavelet, from_time_to_wavelet
2
2
  from .inverse import from_wavelet_to_freq, from_wavelet_to_time
3
3
 
4
4
  __all__ = [
@@ -6,5 +6,4 @@ __all__ = [
6
6
  "from_wavelet_to_freq",
7
7
  "from_time_to_wavelet",
8
8
  "from_freq_to_wavelet",
9
- "compute_bins",
10
9
  ]
@@ -0,0 +1,3 @@
1
+ from .main import from_freq_to_wavelet, from_time_to_wavelet
2
+
3
+ __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
@@ -1,8 +1,10 @@
1
1
  """helper functions for transform_freq"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
5
6
 
7
+
6
8
  def transform_wavelet_freq_helper(
7
9
  data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray
8
10
  ) -> np.ndarray:
@@ -13,8 +15,16 @@ def transform_wavelet_freq_helper(
13
15
  __core(Nf, Nt, DX, freq_strain, phif, wave)
14
16
  return wave
15
17
 
18
+
16
19
  # @njit()
17
- def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarray, wave:np.ndarray):
20
+ def __core(
21
+ Nf: int,
22
+ Nt: int,
23
+ DX: np.ndarray,
24
+ freq_strain: np.ndarray,
25
+ phif: np.ndarray,
26
+ wave: np.ndarray,
27
+ ):
18
28
  for f_bin in range(0, Nf + 1):
19
29
  __fill_wave_1(f_bin, Nt, Nf, DX, freq_strain, phif)
20
30
  # Numba doesn't support np.ifft so we cant jit this
@@ -22,8 +32,6 @@ def __core(Nf:int, Nt:int, DX:np.ndarray, freq_strain:np.ndarray, phif:np.ndarra
22
32
  __fill_wave_2(f_bin, DX_trans, wave, Nt, Nf)
23
33
 
24
34
 
25
-
26
-
27
35
  @njit()
28
36
  def __fill_wave_1(
29
37
  f_bin: int,
@@ -55,6 +63,7 @@ def __fill_wave_1(
55
63
  else:
56
64
  DX[i] = phif[j] * data[jj]
57
65
 
66
+
58
67
  @njit()
59
68
  def __fill_wave_2(
60
69
  f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int
@@ -72,7 +81,7 @@ def __fill_wave_2(
72
81
  if (n + f_bin) % 2:
73
82
  wave[n, f_bin] = -DX_trans[n].imag
74
83
  else:
75
- wave[n, f_bin] = DX_trans[n].real
84
+ wave[n, f_bin] = DX_trans[n].real
76
85
  else:
77
86
  if (n + f_bin) % 2:
78
87
  wave[n, f_bin] = DX_trans[n].imag
@@ -1,4 +1,5 @@
1
1
  """helper functions for transform_time.py"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -20,7 +21,18 @@ def transform_wavelet_time_helper(
20
21
  __core(Nf, Nt, K, ND, wdata, data_pad, phi, wave, mult)
21
22
  return wave
22
23
 
23
- def __core(Nf: int, Nt: int, K: int, ND: int, wdata: np.ndarray, data_pad: np.ndarray, phi: np.ndarray, wave: np.ndarray, mult: int) -> None:
24
+
25
+ def __core(
26
+ Nf: int,
27
+ Nt: int,
28
+ K: int,
29
+ ND: int,
30
+ wdata: np.ndarray,
31
+ data_pad: np.ndarray,
32
+ phi: np.ndarray,
33
+ wave: np.ndarray,
34
+ mult: int,
35
+ ) -> None:
24
36
  for time_bin_i in range(0, Nt):
25
37
  __fill_wave_1(time_bin_i, K, ND, Nf, wdata, data_pad, phi)
26
38
  wdata_trans = np.fft.rfft(wdata, K)
@@ -3,15 +3,15 @@ from typing import Union
3
3
  import numpy as np
4
4
 
5
5
  from ...logger import logger
6
+ from ...types import FrequencySeries, TimeSeries, Wavelet
7
+ from ...types.wavelet_bins import _get_bins, _preprocess_bins
6
8
  from ..phi_computer import phi_vec, phitilde_vec_norm
7
- from ..types import FrequencySeries, TimeSeries, Wavelet
8
9
  from .from_freq import transform_wavelet_freq_helper
9
10
  from .from_time import transform_wavelet_time_helper
10
- from .wavelet_bins import _get_bins, _preprocess_bins
11
-
12
11
 
13
12
  __all__ = ["from_time_to_wavelet", "from_freq_to_wavelet"]
14
13
 
14
+
15
15
  def from_time_to_wavelet(
16
16
  timeseries: TimeSeries,
17
17
  Nf: Union[int, None] = None,
@@ -74,9 +74,7 @@ def from_time_to_wavelet(
74
74
  mult = min(mult, Nt // 2) # Ensure mult is not larger than ND/2
75
75
  phi = phi_vec(Nf, dt=dt, d=nx, q=mult)
76
76
  wave = transform_wavelet_time_helper(timeseries.data, Nf, Nt, phi, mult).T
77
- return Wavelet(
78
- wave * np.sqrt(2), time=t_bins, freq=f_bins
79
- )
77
+ return Wavelet(wave * np.sqrt(2), time=t_bins, freq=f_bins)
80
78
 
81
79
 
82
80
  def from_freq_to_wavelet(
@@ -117,12 +115,6 @@ def from_freq_to_wavelet(
117
115
  t_bins, f_bins = _get_bins(freqseries, Nf, Nt)
118
116
  dt = freqseries.dt
119
117
  phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)
120
- wave = transform_wavelet_freq_helper(
121
- freqseries.data, Nf, Nt, phif
122
- )
123
-
124
- return Wavelet(
125
- (2 / Nf) * wave.T * np.sqrt(2),
126
- time=t_bins,
127
- freq=f_bins
128
- )
118
+ wave = transform_wavelet_freq_helper(freqseries.data, Nf, Nt, phif)
119
+
120
+ return Wavelet((2 / Nf) * wave.T * np.sqrt(2), time=t_bins, freq=f_bins)
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
 
3
3
  from ...transforms.phi_computer import phi_vec, phitilde_vec_norm
4
- from ..types import FrequencySeries, TimeSeries, Wavelet
4
+ from ...types import FrequencySeries, TimeSeries, Wavelet
5
5
  from .to_freq import inverse_wavelet_freq_helper_fast
6
6
  from .to_time import inverse_wavelet_time_helper_fast
7
7
 
@@ -9,6 +9,7 @@ __all__ = ["from_wavelet_to_time", "from_wavelet_to_freq"]
9
9
 
10
10
  INV_ROOT2 = 1.0 / np.sqrt(2)
11
11
 
12
+
12
13
  def from_wavelet_to_time(
13
14
  wave_in: Wavelet,
14
15
  dt: float,
@@ -55,9 +56,7 @@ def from_wavelet_to_time(
55
56
 
56
57
 
57
58
  def from_wavelet_to_freq(
58
- wave_in: Wavelet,
59
- dt: float,
60
- nx:float=4.0
59
+ wave_in: Wavelet, dt: float, nx: float = 4.0
61
60
  ) -> FrequencySeries:
62
61
  """
63
62
  Perform an inverse wavelet transform to the frequency domain.
@@ -1,4 +1,5 @@
1
1
  """functions for computing the inverse wavelet transforms"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -12,13 +13,20 @@ def inverse_wavelet_freq_helper_fast(
12
13
  ND = Nf * Nt
13
14
 
14
15
  prefactor2s = np.zeros(Nt, np.complex128)
15
- res = np.zeros(ND//2 +1, dtype=np.complex128)
16
+ res = np.zeros(ND // 2 + 1, dtype=np.complex128)
16
17
  __core(Nf, Nt, prefactor2s, wave_in, phif, res)
17
18
 
18
-
19
19
  return res
20
20
 
21
- def __core(Nf: int, Nt: int, prefactor2s: np.ndarray, wave_in: np.ndarray, phif: np.ndarray, res: np.ndarray) -> None:
21
+
22
+ def __core(
23
+ Nf: int,
24
+ Nt: int,
25
+ prefactor2s: np.ndarray,
26
+ wave_in: np.ndarray,
27
+ phif: np.ndarray,
28
+ res: np.ndarray,
29
+ ) -> None:
22
30
  for m in range(0, Nf + 1):
23
31
  __pack_wave_inverse(m, Nt, Nf, prefactor2s, wave_in)
24
32
  fft_prefactor2s = np.fft.fft(prefactor2s)
@@ -1,4 +1,5 @@
1
1
  """functions for computing the inverse wavelet transforms"""
2
+
2
3
  import numpy as np
3
4
  from numba import njit
4
5
  from numpy import fft
@@ -21,7 +22,16 @@ def inverse_wavelet_time_helper_fast(
21
22
  return res[:ND]
22
23
 
23
24
 
24
- def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarray, res: np.ndarray, afins:np.ndarray) -> None:
25
+ def __core(
26
+ Nf: int,
27
+ Nt: int,
28
+ K: int,
29
+ ND: int,
30
+ wave_in: np.ndarray,
31
+ phi: np.ndarray,
32
+ res: np.ndarray,
33
+ afins: np.ndarray,
34
+ ) -> None:
25
35
  for n in range(0, Nt):
26
36
  if n % 2 == 0:
27
37
  pack_wave_time_helper_compact(n, Nf, Nt, wave_in, afins)
@@ -29,9 +39,9 @@ def __core(Nf: int, Nt: int, K: int, ND: int, wave_in: np.ndarray, phi: np.ndarr
29
39
  unpack_time_wave_helper_compact(n, Nf, Nt, K, phi, ffts_fin, res)
30
40
 
31
41
  # wrap boundary conditions
32
- res[: min(K + Nf, ND)] += res[ND: min(ND + K + Nf, 2 * ND)]
42
+ res[: min(K + Nf, ND)] += res[ND : min(ND + K + Nf, 2 * ND)]
33
43
  if K + Nf > ND:
34
- res[: K + Nf - ND] += res[2 * ND: ND + K * Nf]
44
+ res[: K + Nf - ND] += res[2 * ND : ND + K * Nf]
35
45
 
36
46
 
37
47
  def unpack_time_wave_helper(
@@ -1,3 +1,3 @@
1
1
  from .frequencyseries import FrequencySeries
2
2
  from .timeseries import TimeSeries
3
- from .wavelet import Wavelet
3
+ from .wavelet import Wavelet, WaveletMask
@@ -1,9 +1,9 @@
1
- from typing import Literal, Tuple
1
+ from typing import Tuple, Union
2
2
 
3
3
  import numpy as xp
4
- from numpy.fft import irfft, fft, rfft, rfftfreq
4
+ from numpy.fft import fft, irfft, rfft, rfftfreq # type: ignore
5
5
 
6
- from ...logger import logger
6
+ from ..logger import logger
7
7
 
8
8
 
9
9
  def _len_check(d):
@@ -19,7 +19,7 @@ def is_documented_by(original):
19
19
  return wrapper
20
20
 
21
21
 
22
- def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
22
+ def fmt_time(seconds: float, units=False) -> Union[str, Tuple[str, str]]:
23
23
  """Returns formatted time and units [ms, s, min, hr, day]"""
24
24
  t, u = "", ""
25
25
  if seconds < 1e-3:
@@ -42,12 +42,12 @@ def fmt_time(seconds: float, units=False) -> Tuple[str, str]:
42
42
 
43
43
  def fmt_timerange(trange):
44
44
  t0 = fmt_time(trange[0])
45
- tend, units = fmt_time(trange[1], units = True)
45
+ tend, units = fmt_time(trange[1], units=True)
46
46
  return f"[{t0}, {tend}] {units}"
47
47
 
48
48
 
49
- def fmt_pow2(n:float)->str:
49
+ def fmt_pow2(n: float) -> str:
50
50
  pow2 = xp.log2(n)
51
51
  if pow2.is_integer():
52
52
  return f"2^{int(pow2)}"
53
- return f"{n:,}"
53
+ return f"{n:,}"
@@ -1,11 +1,13 @@
1
+ from typing import Optional, Tuple, Union
2
+
1
3
  import matplotlib.pyplot as plt
2
- from typing import Tuple, Union, Optional
3
4
 
4
- from .common import is_documented_by, xp, irfft, fmt_time, fmt_pow2
5
+ from .common import fmt_pow2, fmt_time, irfft, is_documented_by, xp
5
6
  from .plotting import plot_freqseries, plot_periodogram
6
7
 
7
8
  __all__ = ["FrequencySeries"]
8
9
 
10
+
9
11
  class FrequencySeries:
10
12
  """
11
13
  A class to represent a one-sided frequency series, with various methods for
@@ -39,9 +41,13 @@ class FrequencySeries:
39
41
  If any frequency is negative or if `data` and `freq` do not have the same length.
40
42
  """
41
43
  if xp.any(freq < 0):
42
- raise ValueError("FrequencySeries must be one-sided (only non-negative frequencies)")
44
+ raise ValueError(
45
+ "FrequencySeries must be one-sided (only non-negative frequencies)"
46
+ )
43
47
  if len(data) != len(freq):
44
- raise ValueError(f"data and freq must have the same length ({len(data)} != {len(freq)})")
48
+ raise ValueError(
49
+ f"data and freq must have the same length ({len(data)} != {len(freq)})"
50
+ )
45
51
  self.data = data
46
52
  self.freq = freq
47
53
  self.t0 = t0
@@ -53,10 +59,10 @@ class FrequencySeries:
53
59
  )
54
60
 
55
61
  @is_documented_by(plot_periodogram)
56
- def plot_periodogram(self, ax=None, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
57
- return plot_periodogram(
58
- self.data, self.freq, self.fs, ax=ax, **kwargs
59
- )
62
+ def plot_periodogram(
63
+ self, ax=None, **kwargs
64
+ ) -> Tuple[plt.Figure, plt.Axes]:
65
+ return plot_periodogram(self.data, self.freq, self.fs, ax=ax, **kwargs)
60
66
 
61
67
  def __len__(self):
62
68
  """Return the length of the frequency series."""
@@ -127,7 +133,9 @@ class FrequencySeries:
127
133
  n = fmt_pow2(len(self))
128
134
  return f"FrequencySeries(n={n}, frange=[{self.range[0]:.2f}, {self.range[1]:.2f}] Hz, T={dur}, fs={self.fs:.2f} Hz)"
129
135
 
130
- def noise_weighted_inner_product(self, other: "FrequencySeries", psd:"FrequencySeries") -> float:
136
+ def noise_weighted_inner_product(
137
+ self, other: "FrequencySeries", psd: "FrequencySeries"
138
+ ) -> float:
131
139
  """
132
140
  Compute the noise-weighted inner product of two FrequencySeries.
133
141
 
@@ -144,9 +152,11 @@ class FrequencySeries:
144
152
  The noise-weighted inner product of the two FrequencySeries.
145
153
  """
146
154
  integrand = xp.real(xp.conj(self.data) * other.data / psd.data)
147
- return (4 * self.dt/self.ND) * xp.nansum(integrand)
155
+ return (4 * self.dt / self.ND) * xp.nansum(integrand)
148
156
 
149
- def matched_filter_snr(self, other: "FrequencySeries", psd: "FrequencySeries") -> float:
157
+ def matched_filter_snr(
158
+ self, other: "FrequencySeries", psd: "FrequencySeries"
159
+ ) -> float:
150
160
  """
151
161
  Compute the signal-to-noise ratio (SNR) of a matched filter.
152
162
 
@@ -199,15 +209,15 @@ class FrequencySeries:
199
209
 
200
210
  # Create and return a TimeSeries object
201
211
  from .timeseries import TimeSeries
202
- return TimeSeries(time_data, time)
203
212
 
213
+ return TimeSeries(time_data, time)
204
214
 
205
215
  def to_wavelet(
206
- self,
207
- Nf: Union[int, None] = None,
208
- Nt: Union[int, None] = None,
209
- nx: Optional[float] = 4.0,
210
- )->"Wavelet":
216
+ self,
217
+ Nf: Union[int, None] = None,
218
+ Nt: Union[int, None] = None,
219
+ nx: Optional[float] = 4.0,
220
+ ) -> "Wavelet":
211
221
  """
212
222
  Convert the frequency series to a wavelet using inverse Fourier transform.
213
223
 
@@ -216,9 +226,9 @@ class FrequencySeries:
216
226
  Wavelet
217
227
  The corresponding wavelet.
218
228
  """
219
- from ..forward import from_freq_to_wavelet
220
- return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
229
+ from ..transforms.forward import from_freq_to_wavelet
221
230
 
231
+ return from_freq_to_wavelet(self, Nf=Nf, Nt=Nt, nx=nx)
222
232
 
223
233
  def __eq__(self, other):
224
234
  """Check if two FrequencySeries objects are equal."""
@@ -228,10 +238,8 @@ class FrequencySeries:
228
238
 
229
239
  def __copy__(self):
230
240
  return FrequencySeries(
231
- xp.copy(self.data),
232
- xp.copy(self.freq),
233
- t0=self.t0
241
+ xp.copy(self.data), xp.copy(self.freq), t0=self.t0
234
242
  )
235
243
 
236
244
  def copy(self):
237
- return self.__copy__()
245
+ return self.__copy__()
@@ -1,12 +1,11 @@
1
1
  import warnings
2
- from typing import Tuple, Optional, Union
3
- from scipy.signal import savgol_filter
4
- from scipy.interpolate import interp1d
2
+ from typing import Optional, Tuple, Union
5
3
 
6
4
  import matplotlib.pyplot as plt
7
5
  import numpy as np
8
6
  from matplotlib.colors import LogNorm, TwoSlopeNorm
9
- from scipy.signal import spectrogram
7
+ from scipy.interpolate import interp1d
8
+ from scipy.signal import savgol_filter, spectrogram
10
9
 
11
10
  MIN_S = 60
12
11
  HOUR_S = 60 * MIN_S
@@ -53,8 +52,13 @@ def __get_smoothed_y(x, z, y_grid):
53
52
  # Interpolate to fill NaNs in y before smoothing
54
53
  nan_mask = ~np.isnan(y)
55
54
  if np.isnan(y).any():
56
- interpolator = interp1d(x[nan_mask], y[nan_mask], kind='cubic', bounds_error=False,
57
- fill_value="extrapolate")
55
+ interpolator = interp1d(
56
+ x[nan_mask],
57
+ y[nan_mask],
58
+ kind="cubic",
59
+ bounds_error=False,
60
+ fill_value="extrapolate",
61
+ )
58
62
  y = interpolator(x) # Fill NaNs with interpolated values
59
63
 
60
64
  # Smooth the curve
@@ -64,8 +68,6 @@ def __get_smoothed_y(x, z, y_grid):
64
68
  return y
65
69
 
66
70
 
67
-
68
-
69
71
  def plot_wavelet_grid(
70
72
  wavelet_data: np.ndarray,
71
73
  time_grid: np.ndarray,
@@ -80,8 +82,8 @@ def plot_wavelet_grid(
80
82
  norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
81
83
  cbar_label: Optional[str] = None,
82
84
  nan_color: Optional[str] = "black",
83
- detailed_axes:bool = False,
84
- show_gridinfo:bool = True,
85
+ detailed_axes: bool = False,
86
+ show_gridinfo: bool = True,
85
87
  trend_color: Optional[str] = None,
86
88
  whiten_by: Optional[np.ndarray] = None,
87
89
  **kwargs,
@@ -153,7 +155,9 @@ def plot_wavelet_grid(
153
155
 
154
156
  # Validate the dimensions
155
157
  if (Nf, Nt) != (len(freq_grid), len(time_grid)):
156
- raise ValueError(f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}.")
158
+ raise ValueError(
159
+ f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}."
160
+ )
157
161
 
158
162
  # Prepare the data for plotting
159
163
  z = wavelet_data.copy()
@@ -162,14 +166,15 @@ def plot_wavelet_grid(
162
166
  if absolute:
163
167
  z = np.abs(z)
164
168
 
165
-
166
169
  # Determine normalization and colormap
167
170
  if norm is None:
168
171
  try:
169
172
  if np.all(np.isnan(z)):
170
173
  raise ValueError("All wavelet data is NaN.")
171
174
  if zscale == "log":
172
- norm = LogNorm(vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z<np.inf]))
175
+ norm = LogNorm(
176
+ vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z < np.inf])
177
+ )
173
178
  elif not absolute:
174
179
  vmin, vmax = np.nanmin(z), np.nanmax(z)
175
180
  vcenter = 0.0
@@ -177,7 +182,9 @@ def plot_wavelet_grid(
177
182
  else:
178
183
  norm = None # Default linear scaling
179
184
  except Exception as e:
180
- warnings.warn(f"Error in determining normalization: {e}. Using default linear scaling.")
185
+ warnings.warn(
186
+ f"Error in determining normalization: {e}. Using default linear scaling."
187
+ )
181
188
  norm = None
182
189
 
183
190
  if cmap is None:
@@ -195,7 +202,7 @@ def plot_wavelet_grid(
195
202
  im = ax.imshow(
196
203
  z,
197
204
  aspect="auto",
198
- extent=[time_grid[0],time_grid[-1], freq_grid[0], freq_grid[-1]],
205
+ extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
199
206
  origin="lower",
200
207
  cmap=cmap,
201
208
  norm=norm,
@@ -203,13 +210,25 @@ def plot_wavelet_grid(
203
210
  **kwargs,
204
211
  )
205
212
  if trend_color is not None:
206
- plot_wavelet_trend(wavelet_data, time_grid, freq_grid, ax, color=trend_color, freq_range=freq_range, freq_scale=freq_scale)
213
+ plot_wavelet_trend(
214
+ wavelet_data,
215
+ time_grid,
216
+ freq_grid,
217
+ ax,
218
+ color=trend_color,
219
+ freq_range=freq_range,
220
+ freq_scale=freq_scale,
221
+ )
207
222
 
208
223
  # Add colorbar if requested
209
224
  if show_colorbar:
210
225
  cbar = fig.colorbar(im, ax=ax)
211
226
  if cbar_label is None:
212
- cbar_label = "Absolute Wavelet Amplitude" if absolute else "Wavelet Amplitude"
227
+ cbar_label = (
228
+ "Absolute Wavelet Amplitude"
229
+ if absolute
230
+ else "Wavelet Amplitude"
231
+ )
213
232
  cbar.set_label(cbar_label)
214
233
 
215
234
  # Configure axes scales
@@ -239,14 +258,12 @@ def plot_wavelet_grid(
239
258
  bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
240
259
  )
241
260
 
242
-
243
261
  # Adjust layout
244
262
  fig.tight_layout()
245
263
 
246
264
  return fig, ax
247
265
 
248
266
 
249
-
250
267
  def plot_freqseries(
251
268
  data: np.ndarray,
252
269
  freq: np.ndarray,
@@ -277,9 +294,10 @@ def plot_periodogram(
277
294
  flow = np.min(np.abs(freq))
278
295
  ax.set_xlabel("Frequency [Hz]")
279
296
  ax.set_ylabel("Periodigram")
280
- ax.set_xlim(left=flow, right=nyquist_frequency/2)
297
+ ax.set_xlim(left=flow, right=nyquist_frequency / 2)
281
298
  return ax.figure, ax
282
299
 
300
+
283
301
  def plot_timeseries(
284
302
  data: np.ndarray, time: np.ndarray, ax=None, **kwargs
285
303
  ) -> Tuple[plt.Figure, plt.Axes]:
@@ -314,7 +332,6 @@ def plot_spectrogram(
314
332
 
315
333
  _fmt_time_axis(t, ax)
316
334
 
317
-
318
335
  ax.set_ylabel("Frequency [Hz]")
319
336
  ax.set_ylim(top=fs / 2.0)
320
337
  cbar = plt.colorbar(cm, ax=ax)
@@ -322,20 +339,24 @@ def plot_spectrogram(
322
339
  return ax.figure, ax
323
340
 
324
341
 
325
-
326
342
  def _fmt_time_axis(t, axes, t0=None, tmax=None):
327
343
  if t[-1] > DAY_S: # If time goes beyond a day
328
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}"))
344
+ axes.xaxis.set_major_formatter(
345
+ plt.FuncFormatter(lambda x, _: f"{x / DAY_S:.1f}")
346
+ )
329
347
  axes.set_xlabel("Time [days]")
330
348
  elif t[-1] > HOUR_S: # If time goes beyond an hour
331
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}"))
349
+ axes.xaxis.set_major_formatter(
350
+ plt.FuncFormatter(lambda x, _: f"{x / HOUR_S:.1f}")
351
+ )
332
352
  axes.set_xlabel("Time [hr]")
333
353
  elif t[-1] > MIN_S: # If time goes beyond a minute
334
- axes.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}"))
354
+ axes.xaxis.set_major_formatter(
355
+ plt.FuncFormatter(lambda x, _: f"{x / MIN_S:.1f}")
356
+ )
335
357
  axes.set_xlabel("Time [min]")
336
358
  else:
337
359
  axes.set_xlabel("Time [s]")
338
360
  t0 = t[0] if t0 is None else t0
339
361
  tmax = t[-1] if tmax is None else tmax
340
362
  axes.set_xlim(t0, tmax)
341
-