pywavelet 0.0.1b0__py3-none-any.whl → 0.1.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. pywavelet/__init__.py +1 -1
  2. pywavelet/_version.py +2 -2
  3. pywavelet/logger.py +6 -7
  4. pywavelet/transforms/__init__.py +10 -10
  5. pywavelet/transforms/forward/__init__.py +4 -0
  6. pywavelet/transforms/forward/from_freq.py +80 -0
  7. pywavelet/transforms/forward/from_time.py +66 -0
  8. pywavelet/transforms/forward/main.py +128 -0
  9. pywavelet/transforms/forward/wavelet_bins.py +58 -0
  10. pywavelet/transforms/inverse/__init__.py +3 -0
  11. pywavelet/transforms/inverse/main.py +96 -0
  12. pywavelet/transforms/{from_wavelets/inverse_wavelet_freq_funcs.py → inverse/to_freq.py} +43 -32
  13. pywavelet/transforms/{from_wavelets/inverse_wavelet_time_funcs.py → inverse/to_time.py} +49 -21
  14. pywavelet/transforms/phi_computer.py +152 -0
  15. pywavelet/transforms/types/__init__.py +4 -0
  16. pywavelet/transforms/types/common.py +53 -0
  17. pywavelet/transforms/types/frequencyseries.py +237 -0
  18. pywavelet/transforms/types/plotting.py +341 -0
  19. pywavelet/transforms/types/timeseries.py +280 -0
  20. pywavelet/transforms/types/wavelet.py +374 -0
  21. pywavelet/transforms/types/wavelet_mask.py +34 -0
  22. pywavelet/utils.py +76 -0
  23. pywavelet-0.1.0.dist-info/METADATA +35 -0
  24. pywavelet-0.1.0.dist-info/RECORD +26 -0
  25. {pywavelet-0.0.1b0.dist-info → pywavelet-0.1.0.dist-info}/WHEEL +1 -1
  26. pywavelet/fft_funcs.py +0 -16
  27. pywavelet/likelihood/__init__.py +0 -0
  28. pywavelet/likelihood/likelihood_base.py +0 -9
  29. pywavelet/likelihood/whittle.py +0 -24
  30. pywavelet/transforms/common.py +0 -77
  31. pywavelet/transforms/from_wavelets/__init__.py +0 -25
  32. pywavelet/transforms/to_wavelets/__init__.py +0 -52
  33. pywavelet/transforms/to_wavelets/transform_freq_funcs.py +0 -84
  34. pywavelet/transforms/to_wavelets/transform_time_funcs.py +0 -63
  35. pywavelet/utils/__init__.py +0 -0
  36. pywavelet/utils/fisher_matrix.py +0 -6
  37. pywavelet/utils/snr.py +0 -37
  38. pywavelet/waveform_generator/__init__.py +0 -0
  39. pywavelet/waveform_generator/build_lookup_table.py +0 -0
  40. pywavelet/waveform_generator/generators/__init__.py +0 -2
  41. pywavelet/waveform_generator/generators/functional_waveform_generator.py +0 -33
  42. pywavelet/waveform_generator/generators/lookuptable_waveform_generator.py +0 -15
  43. pywavelet/waveform_generator/generators/rom_waveform_generator.py +0 -0
  44. pywavelet/waveform_generator/waveform_generator.py +0 -14
  45. pywavelet-0.0.1b0.dist-info/METADATA +0 -35
  46. pywavelet-0.0.1b0.dist-info/RECORD +0 -29
  47. {pywavelet-0.0.1b0.dist-info → pywavelet-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,341 @@
1
+ import warnings
2
+ from typing import Tuple, Optional, Union
3
+ from scipy.signal import savgol_filter
4
+ from scipy.interpolate import interp1d
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from matplotlib.colors import LogNorm, TwoSlopeNorm
9
+ from scipy.signal import spectrogram
10
+
11
# Time-unit conversion factors, expressed as a number of seconds.
MIN_S: int = 60
HOUR_S: int = MIN_S * 60
DAY_S: int = HOUR_S * 24
14
+
15
+
16
def plot_wavelet_trend(
    wavelet_data: np.ndarray,
    time_grid: np.ndarray,
    freq_grid: np.ndarray,
    ax: Optional[plt.Axes] = None,
    freq_scale: str = "linear",
    freq_range: Optional[Tuple[float, float]] = None,
    color: str = "black",
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw the smoothed frequency of peak |wavelet amplitude| versus time.

    For every time bin, the frequency at which ``|wavelet_data|`` is largest is
    selected (via ``__get_smoothed_y``). Gaps are interpolated, the curve is
    smoothed, and the result is plotted as a line.

    Parameters
    ----------
    wavelet_data : np.ndarray
        2D array of wavelet coefficients with shape (Nf, Nt).
    time_grid : np.ndarray
        Time values of the Nt time bins (x-axis).
    freq_grid : np.ndarray
        Frequency values of the Nf frequency bins.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    freq_scale : str, optional
        Scale of the frequency (y) axis, e.g. 'linear' or 'log'.
    freq_range : tuple of float, optional
        (min, max) frequency limits. Defaults to ``(freq_grid[0], freq_grid[-1])``.
    color : str, optional
        Colour of the trend line. Default is 'black'.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes holding the trend line.
    """
    y = __get_smoothed_y(time_grid, np.abs(wavelet_data), freq_grid)
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(time_grid, y, color=color)

    # Configure axes scales
    ax.set_yscale(freq_scale)
    _fmt_time_axis(time_grid, ax)
    ax.set_ylabel("Frequency [Hz]")

    # Set frequency range if specified
    freq_range = freq_range or (freq_grid[0], freq_grid[-1])
    ax.set_ylim(freq_range)
    return ax.figure, ax
39
+
40
+
41
def __get_smoothed_y(x, z, y_grid):
    """
    Trace the y-value of each column's maximum in ``z`` and smooth the result.

    Parameters
    ----------
    x : np.ndarray
        1D coordinates of the Nt columns of ``z``; used as the abscissa when
        filling gaps by interpolation.
    z : np.ndarray
        2D array of shape (Nf, Nt). NaNs are ignored when locating maxima.
    y_grid : np.ndarray
        1D array of length Nf giving the y-value associated with each row.

    Returns
    -------
    np.ndarray
        Float array of length Nt. Columns of ``z`` that are entirely NaN are
        NaN in the output. Other entries are the smoothed argmax y-values.
    """
    polyorder = 3
    nt = z.shape[1]
    y = np.full(nt, np.nan)

    valid = ~np.all(np.isnan(z), axis=0)
    if not valid.any():
        return y
    y[valid] = y_grid[np.nanargmax(z[:, valid], axis=0)]

    # Fill all-NaN columns by interpolation so the smoothing filter sees a
    # gap-free signal. Cubic interpolation needs >= 4 points, so fall back to
    # linear (>= 2 points) or to a constant (a single valid point).
    n_valid = int(valid.sum())
    if n_valid < nt:
        if n_valid == 1:
            y = np.full(nt, y[valid][0])
        else:
            kind = "cubic" if n_valid >= 4 else "linear"
            interpolator = interp1d(
                x[valid],
                y[valid],
                kind=kind,
                bounds_error=False,
                fill_value="extrapolate",
            )
            y = interpolator(x)

    # Largest odd window (capped at 51) that fits the data. savgol_filter
    # requires polyorder < window_length, so very short series are returned
    # unsmoothed instead of raising.
    window_length = min(51, nt if nt % 2 else nt - 1)
    if window_length > polyorder:
        y = savgol_filter(y, window_length, polyorder)

    y[~valid] = np.nan
    return y
65
+
66
+
67
+
68
+
69
def plot_wavelet_grid(
    wavelet_data: np.ndarray,
    time_grid: Optional[np.ndarray] = None,
    freq_grid: Optional[np.ndarray] = None,
    ax: Optional[plt.Axes] = None,
    zscale: str = "linear",
    freq_scale: str = "linear",
    absolute: bool = False,
    freq_range: Optional[Tuple[float, float]] = None,
    show_colorbar: bool = True,
    cmap: Optional[str] = None,
    norm: Optional[Union[LogNorm, TwoSlopeNorm]] = None,
    cbar_label: Optional[str] = None,
    nan_color: Optional[str] = "black",
    detailed_axes: bool = False,
    show_gridinfo: bool = True,
    trend_color: Optional[str] = None,
    whiten_by: Optional[np.ndarray] = None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a 2D grid of wavelet coefficients.

    Parameters
    ----------
    wavelet_data : np.ndarray
        A 2D array containing the wavelet coefficients with shape (Nf, Nt),
        where Nf is the number of frequency bins and Nt is the number of time bins.

    time_grid : np.ndarray, optional
        1D array of time values corresponding to the time bins. If None, uses np.arange(Nt).

    freq_grid : np.ndarray, optional
        1D array of frequency values corresponding to the frequency bins. If None, uses np.arange(Nf).

    ax : plt.Axes, optional
        Matplotlib Axes object to plot on. If None, creates a new figure and axes.

    zscale : str, optional
        Scale for the color mapping. Options are 'linear' or 'log'. Default is 'linear'.

    freq_scale : str, optional
        Scale for the frequency axis. Options are 'linear' or 'log'. Default is 'linear'.

    absolute : bool, optional
        If True, plots the absolute value of the wavelet coefficients. Default is False.

    freq_range : tuple of float, optional
        Tuple specifying the (min, max) frequency range to display. If None, displays the full range.

    show_colorbar : bool, optional
        If True, displays a colorbar next to the plot. Default is True.

    cmap : str, optional
        Colormap to use for the plot. If None, uses 'viridis' for absolute values or 'bwr' for signed values.
        A copy of the colormap is modified, so the caller's or registered colormap is left untouched.

    norm : matplotlib.colors.Normalize, optional
        Normalization instance to scale data values. If None, a suitable normalization is chosen based on `zscale`.

    cbar_label : str, optional
        Label for the colorbar. If None, a default label is used based on the `absolute` parameter.

    nan_color : str, optional
        Color to use for NaN values. Default is 'black'.

    detailed_axes : bool, optional
        If True, relabel the axes with the bin spacing (taken from `time_grid` / `freq_grid`)
        and the number of bins. Default is False.

    show_gridinfo : bool, optional
        If True, annotate the plot with the grid size "NfxNt". Default is True.

    trend_color : str, optional
        Color to use for the trend line (see `plot_wavelet_trend`). Not shown if None.

    whiten_by : np.ndarray, optional
        Array the coefficients are divided by before plotting (applied before `absolute`).
        The trend line is computed from the un-whitened `wavelet_data`.

    **kwargs
        Additional keyword arguments passed to `ax.imshow()`. A `label` entry is also
        shown in the annotation box.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes objects of the plot.

    Raises
    ------
    ValueError
        If the dimensions of `wavelet_data` do not match the lengths of `freq_grid` and `time_grid`.
    """
    import copy  # only needed to copy the colormap before customising it

    # Determine the dimensions of the data
    Nf, Nt = wavelet_data.shape

    # Fall back to bin indices when no grids are supplied (as documented).
    time_grid = np.arange(Nt) if time_grid is None else time_grid
    freq_grid = np.arange(Nf) if freq_grid is None else freq_grid

    # Validate the dimensions
    if (Nf, Nt) != (len(freq_grid), len(time_grid)):
        raise ValueError(f"Wavelet shape {Nf, Nt} does not match provided grids {(len(freq_grid), len(time_grid))}.")

    # Prepare the data for plotting
    z = wavelet_data.copy()
    if whiten_by is not None:
        z = z / whiten_by
    if absolute:
        z = np.abs(z)

    # Determine normalization and colormap. Failures here are best-effort:
    # warn and fall back to matplotlib's default linear scaling.
    if norm is None:
        try:
            if np.all(np.isnan(z)):
                raise ValueError("All wavelet data is NaN.")
            if zscale == "log":
                norm = LogNorm(vmin=np.nanmin(z[z > 0]), vmax=np.nanmax(z[z < np.inf]))
            elif not absolute:
                vmin, vmax = np.nanmin(z), np.nanmax(z)
                vcenter = 0.0
                norm = TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
            else:
                norm = None  # Default linear scaling
        except Exception as e:
            warnings.warn(f"Error in determining normalization: {e}. Using default linear scaling.")
            norm = None

    if cmap is None:
        cmap = "viridis" if absolute else "bwr"
    # Copy before set_bad so the NaN colour does not leak into a shared colormap.
    cmap = copy.copy(plt.get_cmap(cmap))
    cmap.set_bad(color=nan_color)

    # Set up the plot
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Plot the data
    im = ax.imshow(
        z,
        aspect="auto",
        extent=[time_grid[0], time_grid[-1], freq_grid[0], freq_grid[-1]],
        origin="lower",
        cmap=cmap,
        norm=norm,
        interpolation="nearest",
        **kwargs,
    )
    if trend_color is not None:
        plot_wavelet_trend(wavelet_data, time_grid, freq_grid, ax, color=trend_color, freq_range=freq_range, freq_scale=freq_scale)

    # Add colorbar if requested
    if show_colorbar:
        cbar = fig.colorbar(im, ax=ax)
        if cbar_label is None:
            cbar_label = "Absolute Wavelet Amplitude" if absolute else "Wavelet Amplitude"
        cbar.set_label(cbar_label)

    # Configure axes scales
    ax.set_yscale(freq_scale)
    _fmt_time_axis(time_grid, ax)
    ax.set_ylabel("Frequency [Hz]")

    # Set frequency range if specified
    freq_range = freq_range or (freq_grid[0], freq_grid[-1])
    ax.set_ylim(freq_range)

    if detailed_axes:
        # Bin widths come from the supplied grids (1/Nt and 1/Nf are only the
        # widths for a unit-length grid). NaN when a single bin is present.
        delta_t = float(time_grid[1] - time_grid[0]) if Nt > 1 else float("nan")
        delta_f = float(freq_grid[1] - freq_grid[0]) if Nf > 1 else float("nan")
        ax.set_xlabel(r"Time Bins [$\Delta T$=" + f"{delta_t:.4f}s, Nt={Nt}]")
        ax.set_ylabel(r"Freq Bins [$\Delta F$=" + f"{delta_f:.4f}Hz, Nf={Nf}]")

    label = kwargs.get("label", "")
    NfNt_label = f"{Nf}x{Nt}" if show_gridinfo else ""
    txt = f"{label}\n{NfNt_label}" if label else NfNt_label
    if txt:
        ax.text(
            0.05,
            0.95,
            txt,
            transform=ax.transAxes,
            fontsize=14,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor=None, alpha=0.2),
        )

    # Adjust layout
    fig.tight_layout()

    return fig, ax
247
+
248
+
249
+
250
def plot_freqseries(
    data: np.ndarray,
    freq: np.ndarray,
    nyquist_frequency: float,
    ax=None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a frequency-domain series on linear axes.

    Parameters
    ----------
    data : np.ndarray
        Values to plot (y-axis), one per entry of `freq`.
    freq : np.ndarray
        Frequencies of the samples (x-axis).
    nyquist_frequency : float
        Upper x-limit; the x-axis spans [0, nyquist_frequency].
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to `ax.plot()`.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(freq, data, **kwargs)
    ax.set_xlabel("Frequency Bin [Hz]")
    ax.set_ylabel("Amplitude")
    ax.set_xlim(0, nyquist_frequency)
    return ax.figure, ax
264
+
265
+
266
def plot_periodogram(
    data: np.ndarray,
    freq: np.ndarray,
    nyquist_frequency: float,
    ax=None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot |data|**2 against frequency on log-log axes.

    Parameters
    ----------
    data : np.ndarray
        (Possibly complex) frequency-domain values; their squared magnitude is plotted.
    freq : np.ndarray
        Frequencies of the samples (x-axis).
    nyquist_frequency : float
        Used to set the right x-limit (to ``nyquist_frequency / 2``).
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to `ax.loglog()`.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    ax.loglog(freq, np.abs(data) ** 2, **kwargs)

    # A log axis cannot start at 0 Hz (present e.g. in rfftfreq output), so use
    # the smallest strictly positive frequency; leave the limit unset if none.
    abs_freq = np.abs(freq)
    positive = abs_freq[abs_freq > 0]
    flow = positive.min() if positive.size else None

    ax.set_xlabel("Frequency [Hz]")
    ax.set_ylabel("Periodogram")
    # NOTE(review): the right limit is half of `nyquist_frequency`; confirm
    # whether callers pass the sampling rate here or the Nyquist frequency.
    ax.set_xlim(left=flow, right=nyquist_frequency / 2)
    return ax.figure, ax
282
+
283
def plot_timeseries(
    data: np.ndarray, time: np.ndarray, ax=None, **kwargs
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a time-domain series.

    The x-axis is limited to ``[time[0], time[-1]]``. It is labelled in seconds,
    minutes, hours or days depending on the last time value (see `_fmt_time_axis`).

    Parameters
    ----------
    data : np.ndarray
        Sample values (y-axis).
    time : np.ndarray
        Sample times in seconds (x-axis), same length as `data`.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs
        Passed through to `ax.plot()`.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(time, data, **kwargs)

    ax.set_ylabel("Amplitude")
    ax.set_xlim(left=time[0], right=time[-1])

    _fmt_time_axis(time, ax)

    return ax.figure, ax
297
+
298
+
299
def plot_spectrogram(
    timeseries_data: np.ndarray,
    fs: float,
    ax=None,
    spec_kwargs: Optional[dict] = None,
    plot_kwargs: Optional[dict] = None,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Compute and plot a spectrogram of a uniformly sampled time series.

    Parameters
    ----------
    timeseries_data : np.ndarray
        Time-domain samples.
    fs : float
        Sampling frequency in Hz; the y-axis is capped at ``fs / 2``.
    ax : plt.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    spec_kwargs : dict, optional
        Extra keyword arguments for `scipy.signal.spectrogram`.
    plot_kwargs : dict, optional
        Extra keyword arguments for `ax.pcolormesh`. 'cmap' defaults to 'Reds'.
        The dict is copied and never modified.

    Returns
    -------
    Tuple[plt.Figure, plt.Axes]
        The figure and axes of the plot.
    """
    # Avoid shared mutable defaults and never modify the caller's dicts.
    spec_kwargs = {} if spec_kwargs is None else spec_kwargs
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    plot_kwargs.setdefault("cmap", "Reds")

    f, t, Sxx = spectrogram(timeseries_data, fs=fs, **spec_kwargs)
    if ax is None:
        _, ax = plt.subplots()

    cm = ax.pcolormesh(t, f, Sxx, shading="nearest", **plot_kwargs)

    _fmt_time_axis(t, ax)

    ax.set_ylabel("Frequency [Hz]")
    ax.set_ylim(top=fs / 2.0)
    # Attach the colorbar to the figure owning `ax`, not pyplot's current figure.
    cbar = ax.figure.colorbar(cm, ax=ax)
    cbar.set_label("Spectrogram Amplitude")
    return ax.figure, ax
323
+
324
+
325
+
326
def _fmt_time_axis(t, axes, t0=None, tmax=None):
    """
    Choose a time unit for the x-axis of `axes` and set its limits.

    The unit is picked from the last entry of `t`. Above one day, hour or minute
    (checked in that order), tick labels are rescaled to that unit with one
    decimal. Otherwise the axis is left in seconds. The x-limits are then set to
    ``(t0, tmax)``, defaulting to ``t[0]`` and ``t[-1]``.
    """
    last = t[-1]
    unit_table = (
        (DAY_S, "Time [days]"),
        (HOUR_S, "Time [hr]"),
        (MIN_S, "Time [min]"),
    )
    for seconds_per_unit, label in unit_table:
        if last > seconds_per_unit:
            # Bind the divisor as a default argument to avoid late binding.
            axes.xaxis.set_major_formatter(
                plt.FuncFormatter(
                    lambda x, _, scale=seconds_per_unit: f"{x / scale:.1f}"
                )
            )
            axes.set_xlabel(label)
            break
    else:
        axes.set_xlabel("Time [s]")

    left = t[0] if t0 is None else t0
    right = t[-1] if tmax is None else tmax
    axes.set_xlim(left, right)
341
+
@@ -0,0 +1,280 @@
1
+ import matplotlib.pyplot as plt
2
+ from typing import Tuple, Optional, Union
3
+ from scipy.signal.windows import tukey
4
+ from scipy.signal import butter, sosfiltfilt
5
+
6
+ from ...logger import logger
7
+ from .common import is_documented_by, xp, rfft, rfftfreq, fmt_timerange, fmt_time, fmt_pow2
8
+ from .plotting import plot_timeseries, plot_spectrogram
9
+
10
+ __all__ = ["TimeSeries"]
11
+
12
+
13
class TimeSeries:
    """
    A class to represent a time series, with methods for plotting and converting
    the series to a frequency-domain representation.

    Attributes
    ----------
    data : xp.ndarray
        Time domain data.
    time : xp.ndarray
        Array of corresponding time points.

    Notes
    -----
    Sampling-derived properties (``dt``, ``fs``, ``duration``, ...) compute the
    spacing from ``time[0]`` and ``time[1]`` only. They therefore need at least
    two samples and treat that spacing as the spacing of the whole series.
    """

    def __init__(self, data: xp.ndarray, time: xp.ndarray):
        """
        Initialize the TimeSeries with data and time arrays.

        Parameters
        ----------
        data : xp.ndarray
            Time domain data.
        time : xp.ndarray
            Array of corresponding time points. Must be the same length as `data`.

        Raises
        ------
        ValueError
            If `data` and `time` do not have the same length.
        """
        if len(data) != len(time):
            raise ValueError("data and time must have the same length")
        self.data = data
        self.time = time

    @is_documented_by(plot_timeseries)
    def plot(self, ax=None, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        return plot_timeseries(self.data, self.time, ax=ax, **kwargs)

    @is_documented_by(plot_spectrogram)
    def plot_spectrogram(
        self,
        ax=None,
        spec_kwargs: Optional[dict] = None,
        plot_kwargs: Optional[dict] = None,
    ) -> Tuple[plt.Figure, plt.Axes]:
        # Pass fresh dicts on every call: a shared `{}` default (or the caller's
        # dict) must not be affected by any in-place changes made downstream.
        return plot_spectrogram(
            self.data,
            self.fs,
            ax=ax,
            spec_kwargs={} if spec_kwargs is None else dict(spec_kwargs),
            plot_kwargs={} if plot_kwargs is None else dict(plot_kwargs),
        )

    def __len__(self):
        """Return the number of data points in the time series."""
        return len(self.data)

    @property
    def sample_rate(self) -> float:
        """
        Return the sample rate (fs).

        The sample rate is the inverse of the time resolution (Δt).
        """
        return float(xp.round(1.0 / self.dt, decimals=14))

    @property
    def fs(self) -> float:
        """Return the sample rate (fs)."""
        return self.sample_rate

    @property
    def duration(self) -> float:
        """Return the duration of the time series in seconds."""
        return len(self) / self.fs

    @property
    def dt(self) -> float:
        """Return the time resolution (Δt)."""
        return float(self.time[1] - self.time[0])

    @property
    def nyquist_frequency(self) -> float:
        """Return the Nyquist frequency (fs/2)."""
        return self.fs / 2

    @property
    def t0(self) -> float:
        """Return the initial time point in the series."""
        return float(self.time[0])

    @property
    def tend(self) -> float:
        """Return the end of the series: last time point plus one Δt."""
        return float(self.time[-1]) + self.dt

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return the shape of the data array."""
        return self.data.shape

    @property
    def ND(self) -> int:
        """Return the number of data points in the time series."""
        return len(self)

    def __repr__(self) -> str:
        """Return a string representation of the TimeSeries."""
        trange = fmt_timerange((self.t0, self.tend))
        T = " ".join(fmt_time(self.duration, units=True))
        n = fmt_pow2(len(self))
        return f"TimeSeries(n={n}, trange={trange}, T={T}, fs={self.fs:.2f} Hz)"

    def to_frequencyseries(self) -> 'FrequencySeries':
        """
        Convert the time series to a frequency series using the one-sided FFT.

        Returns
        -------
        FrequencySeries
            The frequency-domain representation of the time series.
        """
        freq = rfftfreq(len(self), d=self.dt)
        data = rfft(self.data)

        from .frequencyseries import FrequencySeries  # Avoid circular import
        return FrequencySeries(data, freq, t0=self.t0)

    def to_wavelet(
        self,
        Nf: Union[int, None] = None,
        Nt: Union[int, None] = None,
        nx: Optional[float] = 4.0,
    ) -> 'Wavelet':
        """
        Convert the time series to a wavelet representation.

        Parameters
        ----------
        Nf : int
            Number of frequency bins for the wavelet transform.
        Nt : int
            Number of time bins for the wavelet transform.
        nx : float, optional
            Number of standard deviations for the `phi_vec`, controlling the
            width of the wavelets. Default is 4.0.

        Returns
        -------
        Wavelet
            The wavelet-domain representation of the time series.
        """
        hf = self.to_frequencyseries()
        return hf.to_wavelet(Nf, Nt, nx=nx)

    def __add__(self, other: 'TimeSeries') -> 'TimeSeries':
        """Add two TimeSeries objects together."""
        if self.shape != other.shape:
            raise ValueError("TimeSeries objects must have the same shape to add them together")
        return TimeSeries(self.data + other.data, self.time)

    def __sub__(self, other: 'TimeSeries') -> 'TimeSeries':
        """Subtract one TimeSeries object from another."""
        if self.shape != other.shape:
            raise ValueError("TimeSeries objects must have the same shape to subtract them")
        return TimeSeries(self.data - other.data, self.time)

    def __eq__(self, other: object) -> bool:
        """Check if two TimeSeries objects are equal (same shape, range, times and data)."""
        if not isinstance(other, TimeSeries):
            return NotImplemented
        # Short-circuit: allclose raises on arrays that cannot be broadcast.
        if self.shape != other.shape:
            return False
        range_same = self.t0 == other.t0 and self.tend == other.tend
        return bool(
            range_same
            and xp.allclose(self.time, other.time)
            and xp.allclose(self.data, other.data)
        )

    def __mul__(self, other: float) -> 'TimeSeries':
        """Multiply a TimeSeries object by a scalar."""
        return TimeSeries(self.data * other, self.time)

    # Support `scalar * TimeSeries` as well as `TimeSeries * scalar`.
    __rmul__ = __mul__

    def zero_pad_to_power_of_2(self, tukey_window_alpha: float = 0.0) -> 'TimeSeries':
        """Zero pad the time series to make the length a power of two (useful to speed up FFTs, O(NlogN) versus O(N^2)).

        The Tukey window is applied to the original samples before padding,
        even when no padding is needed.

        Parameters
        ----------
        tukey_window_alpha : float, optional
            Alpha parameter for the Tukey window. Default is 0.0.
            (prevents spectral leakage when padding the data)

        Returns
        -------
        TimeSeries
            A new TimeSeries object with the data zero-padded to a power of two.
        """
        N, dt, t0 = self.ND, self.dt, self.t0
        pow_2 = int(xp.ceil(xp.log2(N)))
        new_N = 2 ** pow_2
        n_pad = new_N - N
        if n_pad > 0:
            logger.warning(
                f"Padding the data to a power of two. "
                f"{N:,} (2**{xp.log2(N):.2f}) -> {new_N:,} (2**{pow_2}). "
            )
        window = tukey(N, alpha=tukey_window_alpha)
        data = xp.pad(self.data * window, (0, n_pad), "constant")
        # Build the time axis from integer indices: float arange(0, n*dt, dt)
        # can return n + 1 points through rounding (e.g. n=3, dt=0.1).
        time = t0 + xp.arange(len(data)) * dt
        return TimeSeries(data, time)

    def highpass_filter(self, fmin: float, tukey_window_alpha: float = 0.0, bandpass_order: int = 4) -> 'TimeSeries':
        """
        Filter the time series with a highpass Butterworth filter.

        (we use sosfiltfilt instead of filtfilt for numerical stability)

        Note: filtfilt should be used if phase accuracy (zero-phase filtering) is critical for your analysis
        and if the filter order is low to moderate.

        Parameters
        ----------
        fmin : float
            Minimum frequency to pass through the filter. Must satisfy
            0 < fmin < Nyquist frequency.
        tukey_window_alpha : float, optional
            Alpha parameter of the Tukey window applied before filtering. Default is 0.0.
        bandpass_order : int, optional
            Order of the Butterworth filter. Default is 4.

        Returns
        -------
        TimeSeries
            A new TimeSeries object with the highpass filter applied.

        Raises
        ------
        ValueError
            If `fmin` is not strictly between 0 and the Nyquist frequency.
        """
        # butter() requires 0 < Wn < fs/2 strictly, so fmin == Nyquist is invalid too.
        if fmin <= 0 or fmin >= self.nyquist_frequency:
            raise ValueError(f"Invalid fmin value: {fmin}. Must be in the range (0, {self.nyquist_frequency})")

        sos = butter(bandpass_order, Wn=fmin, btype="highpass", output='sos', fs=self.fs)
        window = tukey(self.ND, alpha=tukey_window_alpha)
        # `self.data * window` already allocates a new array, so no explicit copy.
        data = sosfiltfilt(sos, self.data * window)
        return TimeSeries(data, self.time)

    def __copy__(self):
        return TimeSeries(
            self.data.copy(),
            self.time.copy()
        )

    def copy(self):
        return self.__copy__()

    def __getitem__(self, key) -> Union["TimeSeries", float]:
        """
        Index the time series.

        Slices and array-like keys return a new TimeSeries containing the
        selected samples. A scalar index returns the single data value, since a
        one-sample TimeSeries cannot be built from scalars (len() fails on them).
        """
        if isinstance(key, slice):
            # Handle slicing
            return self.__handle_slice(key)
        data = self.data[key]
        if getattr(data, "ndim", 0) == 0:
            return data
        return TimeSeries(data, self.time[key])

    def __handle_slice(self, slice_obj) -> "TimeSeries":
        return TimeSeries(
            self.data[slice_obj],
            self.time[slice_obj]
        )

    @classmethod
    def _EMPTY(cls, ND: int, dt: float) -> "TimeSeries":
        """Return an all-zero series of `ND` samples spaced `dt` apart, starting at t=0."""
        # Integer-index time axis: float arange(0, ND*dt, dt) may yield ND + 1 points.
        return cls(xp.zeros(ND), xp.arange(ND) * dt)