bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bmtool/analysis/lfp.py CHANGED
@@ -3,16 +3,18 @@ Module for processing BMTK LFP output.
3
3
  """
4
4
 
5
5
  import h5py
6
+ import matplotlib.pyplot as plt
6
7
  import numpy as np
8
+ import pandas as pd
9
+ import pywt
7
10
  import xarray as xr
8
11
  from fooof import FOOOF
9
- from fooof.sim.gen import gen_model, gen_aperiodic
10
- import matplotlib.pyplot as plt
11
- from scipy import signal
12
- import pywt
13
- import pandas as pd
12
+ from fooof.sim.gen import gen_model
13
+ from scipy import signal
14
+
14
15
  from ..bmplot.connections import is_notebook
15
16
 
17
+
16
18
  def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
17
19
  """
18
20
  Load ECP data from an HDF5 file (BMTK sim) into an xarray DataArray.
@@ -30,26 +32,27 @@ def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
30
32
  An xarray DataArray containing the ECP data, with time as one dimension
31
33
  and channel_id as another.
32
34
  """
33
- with h5py.File(ecp_file, 'r') as f:
35
+ with h5py.File(ecp_file, "r") as f:
34
36
  ecp = xr.DataArray(
35
- f['ecp']['data'][()].T,
37
+ f["ecp"]["data"][()].T,
36
38
  coords=dict(
37
- channel_id=f['ecp']['channel_id'][()],
38
- time=np.arange(*f['ecp']['time']) # ms
39
+ channel_id=f["ecp"]["channel_id"][()],
40
+ time=np.arange(*f["ecp"]["time"]), # ms
39
41
  ),
40
42
  attrs=dict(
41
- fs=1000 / f['ecp']['time'][2] # Hz
42
- )
43
+ fs=1000 / f["ecp"]["time"][2] # Hz
44
+ ),
43
45
  )
44
46
  if demean:
45
- ecp -= ecp.mean(dim='time')
47
+ ecp -= ecp.mean(dim="time")
46
48
  return ecp
47
49
 
48
50
 
49
- def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
50
- downsample_freq: float = 1000) -> xr.DataArray:
51
+ def ecp_to_lfp(
52
+ ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000, downsample_freq: float = 1000
53
+ ) -> xr.DataArray:
51
54
  """
52
- Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
55
+ Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
53
56
  This filters out the high end frequencies turning the ECP into a LFP
54
57
 
55
58
  Parameters:
@@ -71,21 +74,25 @@ def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
71
74
  # Bandpass filter design
72
75
  nyq = 0.5 * fs
73
76
  cut = cutoff / nyq
74
- b, a = signal.butter(8, cut, btype='low', analog=False)
77
+ b, a = signal.butter(8, cut, btype="low", analog=False)
75
78
 
76
79
  # Initialize an array to hold filtered data
77
- filtered_data = xr.DataArray(np.zeros_like(ecp_data), coords=ecp_data.coords, dims=ecp_data.dims)
80
+ filtered_data = xr.DataArray(
81
+ np.zeros_like(ecp_data), coords=ecp_data.coords, dims=ecp_data.dims
82
+ )
78
83
 
79
84
  # Apply the filter to each channel
80
85
  for channel in ecp_data.channel_id:
81
- filtered_data.loc[channel, :] = signal.filtfilt(b, a, ecp_data.sel(channel_id=channel).values)
86
+ filtered_data.loc[channel, :] = signal.filtfilt(
87
+ b, a, ecp_data.sel(channel_id=channel).values
88
+ )
82
89
 
83
90
  # Downsample the filtered data if a downsample frequency is provided
84
91
  if downsample_freq is not None:
85
92
  downsample_factor = int(fs / downsample_freq)
86
93
  filtered_data = filtered_data.isel(time=slice(None, None, downsample_factor))
87
94
  # Update the sampling frequency attribute
88
- filtered_data.attrs['fs'] = downsample_freq
95
+ filtered_data.attrs["fs"] = downsample_freq
89
96
 
90
97
  return filtered_data
91
98
 
@@ -100,7 +107,7 @@ def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
100
107
  data : xr.DataArray
101
108
  The input xarray DataArray containing time-series data.
102
109
  time_ranges : tuple or list of tuples
103
- One or more tuples representing the (start, stop) time points for slicing.
110
+ One or more tuples representing the (start, stop) time points for slicing.
104
111
  For example: (start, stop) or [(start1, stop1), (start2, stop2)]
105
112
 
106
113
  Returns:
@@ -122,17 +129,26 @@ def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
122
129
 
123
130
  # Concatenate all slices along the time dimension if more than one slice
124
131
  if len(slices) > 1:
125
- return xr.concat(slices, dim='time')
132
+ return xr.concat(slices, dim="time")
126
133
  else:
127
134
  return slices[0]
128
135
 
129
136
 
130
- def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
131
- dB_threshold: float = 3.0, max_n_peaks: int = 10,
132
- freq_range: tuple = None, peak_width_limits: tuple = None,
133
- report: bool = False, plot: bool = False,
134
- plt_log: bool = False, plt_range: tuple = None,
135
- figsize: tuple = None, title: str = None) -> tuple:
137
+ def fit_fooof(
138
+ f: np.ndarray,
139
+ pxx: np.ndarray,
140
+ aperiodic_mode: str = "fixed",
141
+ dB_threshold: float = 3.0,
142
+ max_n_peaks: int = 10,
143
+ freq_range: tuple = None,
144
+ peak_width_limits: tuple = None,
145
+ report: bool = False,
146
+ plot: bool = False,
147
+ plt_log: bool = False,
148
+ plt_range: tuple = None,
149
+ figsize: tuple = None,
150
+ title: str = None,
151
+ ) -> tuple:
136
152
  """
137
153
  Fit a FOOOF model to power spectral density data.
138
154
 
@@ -170,43 +186,50 @@ def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
170
186
  tuple
171
187
  A tuple containing the fitting results and the FOOOF model object.
172
188
  """
173
- if aperiodic_mode != 'knee':
174
- aperiodic_mode = 'fixed'
175
-
189
+ if aperiodic_mode != "knee":
190
+ aperiodic_mode = "fixed"
191
+
176
192
  def set_range(x, upper=f[-1]):
177
193
  x = np.array(upper) if x is None else np.array(x)
178
194
  return [f[2], x.item()] if x.size == 1 else x.tolist()
179
-
195
+
180
196
  freq_range = set_range(freq_range)
181
197
  peak_width_limits = set_range(peak_width_limits, np.inf)
182
198
 
183
199
  # Initialize a FOOOF object
184
- fm = FOOOF(peak_width_limits=peak_width_limits, min_peak_height=dB_threshold / 10,
185
- peak_threshold=0., max_n_peaks=max_n_peaks, aperiodic_mode=aperiodic_mode)
186
-
200
+ fm = FOOOF(
201
+ peak_width_limits=peak_width_limits,
202
+ min_peak_height=dB_threshold / 10,
203
+ peak_threshold=0.0,
204
+ max_n_peaks=max_n_peaks,
205
+ aperiodic_mode=aperiodic_mode,
206
+ )
207
+
187
208
  # Fit the model
188
209
  try:
189
210
  fm.fit(f, pxx, freq_range)
190
211
  except Exception as e:
191
212
  fl = np.linspace(f[0], f[-1], int((f[-1] - f[0]) / np.min(np.diff(f))) + 1)
192
213
  fm.fit(fl, np.interp(fl, f, pxx), freq_range)
193
-
214
+
194
215
  results = fm.get_results()
195
216
 
196
217
  if report:
197
218
  fm.print_results()
198
- if aperiodic_mode == 'knee':
219
+ if aperiodic_mode == "knee":
199
220
  ap_params = results.aperiodic_params
200
221
  if ap_params[1] <= 0:
201
- print('Negative value of knee parameter occurred. Suggestion: Fit without knee parameter.')
222
+ print(
223
+ "Negative value of knee parameter occurred. Suggestion: Fit without knee parameter."
224
+ )
202
225
  knee_freq = np.abs(ap_params[1]) ** (1 / ap_params[2])
203
- print(f'Knee location: {knee_freq:.2f} Hz')
204
-
226
+ print(f"Knee location: {knee_freq:.2f} Hz")
227
+
205
228
  if plot:
206
229
  plt_range = set_range(plt_range)
207
230
  fm.plot(plt_log=plt_log)
208
231
  plt.xlim(np.log10(plt_range) if plt_log else plt_range)
209
- #plt.ylim(-8, -5.5)
232
+ # plt.ylim(-8, -5.5)
210
233
  if figsize:
211
234
  plt.gcf().set_size_inches(figsize)
212
235
  if title:
@@ -215,7 +238,7 @@ def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
215
238
  pass
216
239
  else:
217
240
  plt.show()
218
-
241
+
219
242
  return results, fm
220
243
 
221
244
 
@@ -234,13 +257,19 @@ def generate_resd_from_fooof(fooof_model: FOOOF) -> tuple:
234
257
  A tuple containing the residual power spectral density and the aperiodic fit.
235
258
  """
236
259
  results = fooof_model.get_results()
237
- full_fit, _, ap_fit = gen_model(fooof_model.freqs[1:], results.aperiodic_params,
238
- results.gaussian_params, return_components=True)
239
-
240
- full_fit, ap_fit = 10 ** full_fit, 10 ** ap_fit # Convert back from log
241
- res_psd = np.insert((10 ** fooof_model.power_spectrum[1:]) - ap_fit, 0, 0.) # Convert back from log
242
- res_fit = np.insert(full_fit - ap_fit, 0, 0.)
243
- ap_fit = np.insert(ap_fit, 0, 0.)
260
+ full_fit, _, ap_fit = gen_model(
261
+ fooof_model.freqs[1:],
262
+ results.aperiodic_params,
263
+ results.gaussian_params,
264
+ return_components=True,
265
+ )
266
+
267
+ full_fit, ap_fit = 10**full_fit, 10**ap_fit # Convert back from log
268
+ res_psd = np.insert(
269
+ (10 ** fooof_model.power_spectrum[1:]) - ap_fit, 0, 0.0
270
+ ) # Convert back from log
271
+ res_fit = np.insert(full_fit - ap_fit, 0, 0.0)
272
+ ap_fit = np.insert(ap_fit, 0, 0.0)
244
273
 
245
274
  return res_psd, ap_fit
246
275
 
@@ -276,7 +305,7 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
276
305
  def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
277
306
  """
278
307
  Calculate the passband of a complex Morlet wavelet filter.
279
-
308
+
280
309
  Parameters
281
310
  ----------
282
311
  center_freq : float
@@ -285,7 +314,7 @@ def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
285
314
  Bandwidth parameter of the wavelet filter
286
315
  threshold : float, optional
287
316
  Power threshold to define the passband edges (default: 0.5 = -3dB point)
288
-
317
+
289
318
  Returns
290
319
  -------
291
320
  tuple
@@ -297,32 +326,39 @@ def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
297
326
  freq_min = max(0.1, center_freq - 3 * expected_width)
298
327
  freq_max = center_freq + 3 * expected_width
299
328
  freq_axis = np.linspace(freq_min, freq_max, 1000)
300
-
329
+
301
330
  # Calculate the theoretical frequency response of the Morlet wavelet
302
331
  # For a complex Morlet wavelet, the frequency response approximates a Gaussian
303
332
  # centered at the center frequency with width related to the bandwidth parameter
304
333
  sigma_f = bandwidth * center_freq / 8 # Approximate relationship for cmor wavelet
305
- response = np.exp(-((freq_axis - center_freq)**2) / (2 * sigma_f**2))
306
-
334
+ response = np.exp(-((freq_axis - center_freq) ** 2) / (2 * sigma_f**2))
335
+
307
336
  # Find the passband edges (where response crosses the threshold)
308
337
  above_threshold = response >= threshold
309
338
  if not np.any(above_threshold):
310
339
  return (center_freq, center_freq, 0) # No passband found
311
-
340
+
312
341
  # Find the first and last indices where response is above threshold
313
342
  indices = np.where(above_threshold)[0]
314
343
  lower_idx = indices[0]
315
344
  upper_idx = indices[-1]
316
-
345
+
317
346
  # Get the corresponding frequencies
318
347
  lower_bound = freq_axis[lower_idx]
319
348
  upper_bound = freq_axis[upper_idx]
320
349
  passband_width = upper_bound - lower_bound
321
-
350
+
322
351
  return (lower_bound, upper_bound, passband_width)
323
352
 
324
353
 
325
- def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1,show_passband: bool = False) -> np.ndarray:
354
+ def wavelet_filter(
355
+ x: np.ndarray,
356
+ freq: float,
357
+ fs: float,
358
+ bandwidth: float = 1.0,
359
+ axis: int = -1,
360
+ show_passband: bool = False,
361
+ ) -> np.ndarray:
326
362
  """
327
363
  Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
328
364
 
@@ -340,41 +376,55 @@ def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0
340
376
  Axis along which to compute the CWT (default is -1)
341
377
  show_passband : bool, optional
342
378
  If True, print the passband of the wavelet filter (default is False)
343
-
379
+
344
380
  Returns
345
381
  -------
346
382
  np.ndarray
347
383
  Continuous Wavelet Transform of the input signal
348
384
  """
349
385
  if show_passband:
350
- lower_bound, upper_bound, passband_width = calculate_wavelet_passband(freq, bandwidth, threshold=0.3) # kinda made up threshold gives the rough idea
386
+ lower_bound, upper_bound, passband_width = calculate_wavelet_passband(
387
+ freq, bandwidth, threshold=0.3
388
+ ) # kinda made up threshold gives the rough idea
351
389
  print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
352
- print(f" Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)")
353
- wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
390
+ print(
391
+ f" Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)"
392
+ )
393
+ wavelet = "cmor" + str(2 * bandwidth**2) + "-1.0"
354
394
  scale = pywt.scale2frequency(wavelet, 1) * fs / freq
355
395
  x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
356
396
  return x_a
357
397
 
358
398
 
359
- def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1) -> np.ndarray:
399
+ def butter_bandpass_filter(
400
+ data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1
401
+ ) -> np.ndarray:
360
402
  """
361
403
  Apply a Butterworth bandpass filter to the input data.
362
404
  """
363
- sos = signal.butter(order, [lowcut, highcut], fs=fs, btype='band', output='sos')
405
+ sos = signal.butter(order, [lowcut, highcut], fs=fs, btype="band", output="sos")
364
406
  x_a = signal.sosfiltfilt(sos, data, axis=axis)
365
407
  return x_a
366
408
 
367
409
 
368
- def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: str = 'wavelet',
369
- lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
410
+ def get_lfp_power(
411
+ lfp_data,
412
+ freq_of_interest: float,
413
+ fs: float,
414
+ filter_method: str = "wavelet",
415
+ lowcut: float = None,
416
+ highcut: float = None,
417
+ bandwidth: float = 1.0,
418
+ ):
370
419
  """
371
- Compute the power of the raw LFP signal in a specified frequency band.
372
-
420
+ Compute the power of the raw LFP signal in a specified frequency band,
421
+ preserving xarray structure if input is xarray.
422
+
373
423
  Parameters
374
424
  ----------
375
- lfp_data : np.ndarray
425
+ lfp_data : np.ndarray or xr.DataArray
376
426
  Raw local field potential (LFP) time series data
377
- freq : float
427
+ freq_of_interest : float
378
428
  Center frequency (Hz) for wavelet filtering method
379
429
  fs : float
380
430
  Sampling frequency (Hz) of the input data
@@ -386,45 +436,88 @@ def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: s
386
436
  Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
387
437
  bandwidth : float, optional
388
438
  Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
389
-
439
+
390
440
  Returns
391
441
  -------
392
- np.ndarray
393
- Power of the filtered signal (magnitude squared)
394
-
442
+ np.ndarray or xr.DataArray
443
+ Power of the filtered signal (magnitude squared) with same structure as input
444
+
395
445
  Notes
396
446
  -----
397
447
  - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
398
448
  - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
399
449
  - When using the 'butter' method, both lowcut and highcut must be provided
450
+ - If input is an xarray DataArray, the output will preserve the same structure with coordinates
400
451
  """
401
- if filter_method == 'wavelet':
402
- filtered_signal = wavelet_filter(lfp_data, freq, fs, bandwidth)
403
- elif filter_method == 'butter':
452
+ import xarray as xr
453
+
454
+ # Check if input is xarray
455
+ is_xarray = isinstance(lfp_data, xr.DataArray)
456
+
457
+ if is_xarray:
458
+ # Get the raw data from xarray
459
+ raw_data = lfp_data.values
460
+ # Check if 'fs' attribute exists in the xarray and override if necessary
461
+ if "fs" in lfp_data.attrs and fs is None:
462
+ fs = lfp_data.attrs["fs"]
463
+ else:
464
+ raw_data = lfp_data
465
+
466
+ if filter_method == "wavelet":
467
+ filtered_signal = wavelet_filter(raw_data, freq_of_interest, fs, bandwidth)
468
+ elif filter_method == "butter":
404
469
  if lowcut is None or highcut is None:
405
- raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
406
- filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
470
+ raise ValueError(
471
+ "Both lowcut and highcut must be specified when using 'butter' method."
472
+ )
473
+ filtered_signal = butter_bandpass_filter(raw_data, lowcut, highcut, fs)
407
474
  else:
408
475
  raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
409
-
476
+
410
477
  # Calculate power (magnitude squared of filtered signal)
411
- power = np.abs(filtered_signal)**2
478
+ power = np.abs(filtered_signal) ** 2
479
+
480
+ # If the input was an xarray, return an xarray with the same coordinates
481
+ if is_xarray:
482
+ power_xarray = xr.DataArray(
483
+ power,
484
+ coords=lfp_data.coords,
485
+ dims=lfp_data.dims,
486
+ attrs={
487
+ **lfp_data.attrs,
488
+ "filter_method": filter_method,
489
+ "frequency_of_interest": freq_of_interest,
490
+ "bandwidth": bandwidth,
491
+ "lowcut": lowcut,
492
+ "highcut": highcut,
493
+ "power_type": "magnitude_squared",
494
+ },
495
+ )
496
+ return power_xarray
497
+
412
498
  return power
413
499
 
414
500
 
415
- def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filter_method: str = 'wavelet',
416
- lowcut: float = None, highcut: float = None, bandwidth: float = 1.0) -> np.ndarray:
501
+ def get_lfp_phase(
502
+ lfp_data,
503
+ freq_of_interest: float,
504
+ fs: float,
505
+ filter_method: str = "wavelet",
506
+ lowcut: float = None,
507
+ highcut: float = None,
508
+ bandwidth: float = 1.0,
509
+ ) -> np.ndarray:
417
510
  """
418
- Calculate the phase of the filtered signal.
419
-
511
+ Calculate the phase of the filtered signal, preserving xarray structure if input is xarray.
512
+
420
513
  Parameters
421
514
  ----------
422
- lfp_data : np.ndarray
515
+ lfp_data : np.ndarray or xr.DataArray
423
516
  Input LFP data
517
+ freq_of_interest : float
518
+ Frequency of interest (Hz)
424
519
  fs : float
425
520
  Sampling frequency (Hz)
426
- freq : float
427
- Frequency of interest (Hz)
428
521
  filter_method : str, optional
429
522
  Method for filtering the signal ('wavelet' or 'butter')
430
523
  bandwidth : float, optional
@@ -433,43 +526,77 @@ def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filt
433
526
  Low cutoff frequency for Butterworth filter when method='butter'
434
527
  highcut : float, optional
435
528
  High cutoff frequency for Butterworth filter when method='butter'
436
-
529
+
437
530
  Returns
438
531
  -------
439
- np.ndarray
440
- Phase of the filtered signal
441
-
532
+ np.ndarray or xr.DataArray
533
+ Phase of the filtered signal with same structure as input
534
+
442
535
  Notes
443
536
  -----
444
537
  - The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
445
538
  - The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
446
539
  followed by Hilbert transform to extract the phase
447
540
  - When using the 'butter' method, both lowcut and highcut must be provided
541
+ - If input is an xarray DataArray, the output will preserve the same structure with coordinates
448
542
  """
449
- if filter_method == 'wavelet':
543
+ import xarray as xr
544
+
545
+ # Check if input is xarray
546
+ is_xarray = isinstance(lfp_data, xr.DataArray)
547
+
548
+ if is_xarray:
549
+ # Get the raw data from xarray
550
+ raw_data = lfp_data.values
551
+ # Check if 'fs' attribute exists in the xarray and override if necessary
552
+ if "fs" in lfp_data.attrs and fs is None:
553
+ fs = lfp_data.attrs["fs"]
554
+ else:
555
+ raw_data = lfp_data
556
+
557
+ if filter_method == "wavelet":
450
558
  if freq_of_interest is None:
451
559
  raise ValueError("freq_of_interest must be provided for the wavelet method.")
452
560
  # Wavelet filter returns complex values directly
453
- filtered_signal = wavelet_filter(lfp_data, freq_of_interest, fs, bandwidth)
561
+ filtered_signal = wavelet_filter(raw_data, freq_of_interest, fs, bandwidth)
454
562
  # Phase is the angle of the complex signal
455
563
  phase = np.angle(filtered_signal)
456
- elif filter_method == 'butter':
564
+ elif filter_method == "butter":
457
565
  if lowcut is None or highcut is None:
458
- raise ValueError("Both lowcut and highcut must be specified when using 'butter' method.")
566
+ raise ValueError(
567
+ "Both lowcut and highcut must be specified when using 'butter' method."
568
+ )
459
569
  # Butterworth filter returns real values
460
- filtered_signal = butter_bandpass_filter(lfp_data, lowcut, highcut, fs)
570
+ filtered_signal = butter_bandpass_filter(raw_data, lowcut, highcut, fs)
461
571
  # Apply Hilbert transform to get analytic signal (complex)
462
572
  analytic_signal = signal.hilbert(filtered_signal)
463
573
  # Phase is the angle of the analytic signal
464
574
  phase = np.angle(analytic_signal)
465
575
  else:
466
576
  raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")
467
-
577
+
578
+ # If the input was an xarray, return an xarray with the same coordinates
579
+ if is_xarray:
580
+ phase_xarray = xr.DataArray(
581
+ phase,
582
+ coords=lfp_data.coords,
583
+ dims=lfp_data.dims,
584
+ attrs={
585
+ **lfp_data.attrs,
586
+ "filter_method": filter_method,
587
+ "freq_of_interest": freq_of_interest,
588
+ "bandwidth": bandwidth,
589
+ "lowcut": lowcut,
590
+ "highcut": highcut,
591
+ },
592
+ )
593
+ return phase_xarray
594
+
468
595
  return phase
469
596
 
470
- # windowing functions
471
- def windowed_xarray(da, windows, dim='time',
472
- new_coord_name='cycle', new_coord=None):
597
+
598
+ # windowing functions
599
+ def windowed_xarray(da, windows, dim="time", new_coord_name="cycle", new_coord=None):
473
600
  """Divide xarray into windows of equal size along an axis
474
601
  da: input DataArray
475
602
  windows: 2d-array of windows
@@ -488,13 +615,13 @@ def windowed_xarray(da, windows, dim='time',
488
615
  return win_da
489
616
 
490
617
 
491
- def group_windows(win_da, win_grp_idx={}, win_dim='cycle'):
618
+ def group_windows(win_da, win_grp_idx={}, win_dim="cycle"):
492
619
  """Group windows into a dictionary of DataArrays
493
620
  win_da: input windowed DataArrays
494
621
  win_grp_idx: dictionary of {window group id: window indices}
495
622
  win_dim: dimension for different windows
496
623
  Return: dictionaries of {window group id: DataArray of grouped windows}
497
- win_on / win_off for windows selected / not selected by `win_grp_idx`
624
+ win_on / win_off for windows selected / not selected by `win_grp_idx`
498
625
  """
499
626
  win_on, win_off = {}, {}
500
627
  for g, w in win_grp_idx.items():
@@ -503,22 +630,27 @@ def group_windows(win_da, win_grp_idx={}, win_dim='cycle'):
503
630
  return win_on, win_off
504
631
 
505
632
 
506
- def average_group_windows(win_da, win_dim='cycle', grp_dim='unique_cycle'):
633
+ def average_group_windows(win_da, win_dim="cycle", grp_dim="unique_cycle"):
507
634
  """Average over windows in each group and stack groups in a DataArray
508
635
  win_da: input dictionary of {window group id: DataArray of grouped windows}
509
636
  win_dim: dimension for different windows
510
- grp_dim: dimension along which to stack average of window groups
637
+ grp_dim: dimension along which to stack average of window groups
511
638
  """
512
- win_avg = {g: xr.concat([x.mean(dim=win_dim), x.std(dim=win_dim)],
513
- pd.Index(('mean_', 'std_'), name='stats'))
514
- for g, x in win_da.items()}
639
+ win_avg = {
640
+ g: xr.concat(
641
+ [x.mean(dim=win_dim), x.std(dim=win_dim)], pd.Index(("mean_", "std_"), name="stats")
642
+ )
643
+ for g, x in win_da.items()
644
+ }
515
645
  win_avg = xr.concat(win_avg.values(), dim=pd.Index(win_avg.keys(), name=grp_dim))
516
- win_avg = win_avg.to_dataset(dim='stats')
646
+ win_avg = win_avg.to_dataset(dim="stats")
517
647
  return win_avg
518
648
 
649
+
519
650
  # used for avg spectrogram across different trials
520
- def get_windowed_data(x, windows, win_grp_idx, dim='time',
521
- win_dim='cycle', win_coord=None, grp_dim='unique_cycle'):
651
+ def get_windowed_data(
652
+ x, windows, win_grp_idx, dim="time", win_dim="cycle", win_coord=None, grp_dim="unique_cycle"
653
+ ):
522
654
  """Apply functions of windowing to data
523
655
  x: DataArray
524
656
  windows: `windows` for `windowed_xarray`
@@ -531,47 +663,58 @@ def get_windowed_data(x, windows, win_grp_idx, dim='time',
531
663
  Return: data returned by three functions,
532
664
  `windowed_xarray`, `group_windows`, `average_group_windows`
533
665
  """
534
- x_win = windowed_xarray(x, windows, dim=dim,
535
- new_coord_name=win_dim, new_coord=win_coord)
666
+ x_win = windowed_xarray(x, windows, dim=dim, new_coord_name=win_dim, new_coord=win_coord)
536
667
  x_win_onff = group_windows(x_win, win_grp_idx, win_dim=win_dim)
537
668
  if grp_dim:
538
- x_win_avg = [average_group_windows(x, win_dim=win_dim, grp_dim=grp_dim)
539
- for x in x_win_onff]
669
+ x_win_avg = [average_group_windows(x, win_dim=win_dim, grp_dim=grp_dim) for x in x_win_onff]
540
670
  else:
541
671
  x_win_avg = None
542
672
  return x_win, x_win_onff, x_win_avg
543
-
544
- # cone of influence in frequency for cmorxx-1.0 wavelet. need to add logic to calculate in function
673
+
674
+
675
+ # cone of influence in frequency for cmorxx-1.0 wavelet. need to add logic to calculate in function
545
676
  f0 = 2 * np.pi
546
- CMOR_COI = 2 ** -0.5
547
- CMOR_FLAMBDA = 4 * np.pi / (f0 + (2 + f0 ** 2) ** 0.5)
677
+ CMOR_COI = 2**-0.5
678
+ CMOR_FLAMBDA = 4 * np.pi / (f0 + (2 + f0**2) ** 0.5)
548
679
  COI_FREQ = 1 / (CMOR_COI * CMOR_FLAMBDA)
549
680
 
550
- def cwt_spectrogram(x, fs, nNotes=6, nOctaves=np.inf, freq_range=(0, np.inf),
551
- bandwidth=1.0, axis=-1, detrend=False, normalize=False):
681
+
682
+ def cwt_spectrogram(
683
+ x,
684
+ fs,
685
+ nNotes=6,
686
+ nOctaves=np.inf,
687
+ freq_range=(0, np.inf),
688
+ bandwidth=1.0,
689
+ axis=-1,
690
+ detrend=False,
691
+ normalize=False,
692
+ ):
552
693
  """Calculate spectrogram using continuous wavelet transform"""
553
694
  x = np.asarray(x)
554
695
  N = x.shape[axis]
555
696
  times = np.arange(N) / fs
556
697
  # detrend and normalize
557
698
  if detrend:
558
- x = signal.detrend(x, axis=axis, type='linear')
699
+ x = signal.detrend(x, axis=axis, type="linear")
559
700
  if normalize:
560
701
  x = x / x.std()
561
- # Define some parameters of our wavelet analysis.
702
+ # Define some parameters of our wavelet analysis.
562
703
  # range of scales (in time) that makes sense
563
704
  # min = 2 (Nyquist frequency)
564
705
  # max = np.floor(N/2)
565
706
  nOctaves = min(nOctaves, np.log2(2 * np.floor(N / 2)))
566
707
  scales = 2 ** np.arange(1, nOctaves, 1 / nNotes)
567
- # cwt and the frequencies used.
708
+ # cwt and the frequencies used.
568
709
  # Use the complex morelet with bw=2*bandwidth^2 and center frequency of 1.0
569
710
  # bandwidth is sigma of the gaussian envelope
570
- wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
711
+ wavelet = "cmor" + str(2 * bandwidth**2) + "-1.0"
571
712
  frequencies = pywt.scale2frequency(wavelet, scales) * fs
572
713
  scales = scales[(frequencies >= freq_range[0]) & (frequencies <= freq_range[1])]
573
- coef, frequencies = pywt.cwt(x, scales[::-1], wavelet=wavelet, sampling_period=1 / fs, axis=axis)
574
- power = np.real(coef * np.conj(coef)) # equivalent to power = np.abs(coef)**2
714
+ coef, frequencies = pywt.cwt(
715
+ x, scales[::-1], wavelet=wavelet, sampling_period=1 / fs, axis=axis
716
+ )
717
+ power = np.real(coef * np.conj(coef)) # equivalent to power = np.abs(coef)**2
575
718
  # cone of influence in terms of wavelength
576
719
  coi = N / 2 - np.abs(np.arange(N) - (N - 1) / 2)
577
720
  # cone of influence in terms of frequency
@@ -579,8 +722,9 @@ def cwt_spectrogram(x, fs, nNotes=6, nOctaves=np.inf, freq_range=(0, np.inf),
579
722
  return power, times, frequencies, coif
580
723
 
581
724
 
582
- def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
583
- channel_coords=None, **cwt_kwargs):
725
+ def cwt_spectrogram_xarray(
726
+ x, fs, time=None, axis=-1, downsample_fs=None, channel_coords=None, **cwt_kwargs
727
+ ):
584
728
  """Calculate spectrogram using continuous wavelet transform and return an xarray.Dataset
585
729
  x: input array
586
730
  fs: sampling frequency (Hz)
@@ -590,7 +734,7 @@ def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
590
734
  cwt_kwargs: keyword arguments for cwt_spectrogram()
591
735
  """
592
736
  x = np.asarray(x)
593
- T = x.shape[axis] # number of time points
737
+ T = x.shape[axis] # number of time points
594
738
  t = np.arange(T) / fs if time is None else np.asarray(time)
595
739
  if downsample_fs is None or downsample_fs >= fs:
596
740
  downsample_fs = fs
@@ -601,10 +745,11 @@ def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
601
745
  downsampled, t = signal.resample(x, num=num, t=t, axis=axis)
602
746
  downsampled = np.moveaxis(downsampled, axis, -1)
603
747
  sxx, _, f, coif = cwt_spectrogram(downsampled, downsample_fs, **cwt_kwargs)
604
- sxx = np.moveaxis(sxx, 0, -2) # shape (... , freq, time)
748
+ sxx = np.moveaxis(sxx, 0, -2) # shape (... , freq, time)
605
749
  if channel_coords is None:
606
- channel_coords = {f'dim_{i:d}': range(d) for i, d in enumerate(sxx.shape[:-2])}
607
- sxx = xr.DataArray(sxx, coords={**channel_coords, 'frequency': f, 'time': t}).to_dataset(name='PSD')
608
- sxx.update(dict(cone_of_influence_frequency=xr.DataArray(coif, coords={'time': t})))
750
+ channel_coords = {f"dim_{i:d}": range(d) for i, d in enumerate(sxx.shape[:-2])}
751
+ sxx = xr.DataArray(sxx, coords={**channel_coords, "frequency": f, "time": t}).to_dataset(
752
+ name="PSD"
753
+ )
754
+ sxx.update(dict(cone_of_influence_frequency=xr.DataArray(coif, coords={"time": t})))
609
755
  return sxx
610
-