bmtool 0.7.0.6.4__py3-none-any.whl → 0.7.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bmtool/SLURM.py +162 -109
- bmtool/__init__.py +1 -1
- bmtool/__main__.py +8 -7
- bmtool/analysis/entrainment.py +290 -147
- bmtool/analysis/lfp.py +279 -134
- bmtool/analysis/netcon_reports.py +41 -44
- bmtool/analysis/spikes.py +114 -73
- bmtool/bmplot/connections.py +658 -325
- bmtool/bmplot/entrainment.py +17 -18
- bmtool/bmplot/lfp.py +24 -17
- bmtool/bmplot/netcon_reports.py +0 -4
- bmtool/bmplot/spikes.py +97 -48
- bmtool/connectors.py +394 -251
- bmtool/debug/commands.py +13 -7
- bmtool/debug/debug.py +2 -2
- bmtool/graphs.py +26 -19
- bmtool/manage.py +6 -11
- bmtool/plot_commands.py +350 -151
- bmtool/singlecell.py +357 -195
- bmtool/synapses.py +564 -470
- bmtool/util/commands.py +1079 -627
- bmtool/util/neuron/celltuner.py +989 -609
- bmtool/util/util.py +992 -588
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/METADATA +40 -2
- bmtool-0.7.1.1.dist-info/RECORD +34 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/WHEEL +1 -1
- bmtool-0.7.0.6.4.dist-info/RECORD +0 -34
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/entry_points.txt +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/licenses/LICENSE +0 -0
- {bmtool-0.7.0.6.4.dist-info → bmtool-0.7.1.1.dist-info}/top_level.txt +0 -0
bmtool/analysis/lfp.py
CHANGED
@@ -3,16 +3,18 @@ Module for processing BMTK LFP output.
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
import h5py
|
6
|
+
import matplotlib.pyplot as plt
|
6
7
|
import numpy as np
|
8
|
+
import pandas as pd
|
9
|
+
import pywt
|
7
10
|
import xarray as xr
|
8
11
|
from fooof import FOOOF
|
9
|
-
from fooof.sim.gen import gen_model
|
10
|
-
|
11
|
-
|
12
|
-
import pywt
|
13
|
-
import pandas as pd
|
12
|
+
from fooof.sim.gen import gen_model
|
13
|
+
from scipy import signal
|
14
|
+
|
14
15
|
from ..bmplot.connections import is_notebook
|
15
16
|
|
17
|
+
|
16
18
|
def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
|
17
19
|
"""
|
18
20
|
Load ECP data from an HDF5 file (BMTK sim) into an xarray DataArray.
|
@@ -30,26 +32,27 @@ def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
|
|
30
32
|
An xarray DataArray containing the ECP data, with time as one dimension
|
31
33
|
and channel_id as another.
|
32
34
|
"""
|
33
|
-
with h5py.File(ecp_file,
|
35
|
+
with h5py.File(ecp_file, "r") as f:
|
34
36
|
ecp = xr.DataArray(
|
35
|
-
f[
|
37
|
+
f["ecp"]["data"][()].T,
|
36
38
|
coords=dict(
|
37
|
-
channel_id=f[
|
38
|
-
time=np.arange(*f[
|
39
|
+
channel_id=f["ecp"]["channel_id"][()],
|
40
|
+
time=np.arange(*f["ecp"]["time"]), # ms
|
39
41
|
),
|
40
42
|
attrs=dict(
|
41
|
-
fs=1000 / f[
|
42
|
-
)
|
43
|
+
fs=1000 / f["ecp"]["time"][2] # Hz
|
44
|
+
),
|
43
45
|
)
|
44
46
|
if demean:
|
45
|
-
ecp -= ecp.mean(dim=
|
47
|
+
ecp -= ecp.mean(dim="time")
|
46
48
|
return ecp
|
47
49
|
|
48
50
|
|
49
|
-
def ecp_to_lfp(
|
50
|
-
|
51
|
+
def ecp_to_lfp(
|
52
|
+
ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000, downsample_freq: float = 1000
|
53
|
+
) -> xr.DataArray:
|
51
54
|
"""
|
52
|
-
Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
|
55
|
+
Apply a low-pass Butterworth filter to an xarray DataArray and optionally downsample.
|
53
56
|
This filters out the high end frequencies turning the ECP into a LFP
|
54
57
|
|
55
58
|
Parameters:
|
@@ -71,21 +74,25 @@ def ecp_to_lfp(ecp_data: xr.DataArray, cutoff: float = 250, fs: float = 10000,
|
|
71
74
|
# Low-pass filter design (cutoff normalized by the Nyquist frequency)
|
72
75
|
nyq = 0.5 * fs
|
73
76
|
cut = cutoff / nyq
|
74
|
-
b, a = signal.butter(8, cut, btype=
|
77
|
+
b, a = signal.butter(8, cut, btype="low", analog=False)
|
75
78
|
|
76
79
|
# Initialize an array to hold filtered data
|
77
|
-
filtered_data = xr.DataArray(
|
80
|
+
filtered_data = xr.DataArray(
|
81
|
+
np.zeros_like(ecp_data), coords=ecp_data.coords, dims=ecp_data.dims
|
82
|
+
)
|
78
83
|
|
79
84
|
# Apply the filter to each channel
|
80
85
|
for channel in ecp_data.channel_id:
|
81
|
-
filtered_data.loc[channel, :] = signal.filtfilt(
|
86
|
+
filtered_data.loc[channel, :] = signal.filtfilt(
|
87
|
+
b, a, ecp_data.sel(channel_id=channel).values
|
88
|
+
)
|
82
89
|
|
83
90
|
# Downsample the filtered data if a downsample frequency is provided
|
84
91
|
if downsample_freq is not None:
|
85
92
|
downsample_factor = int(fs / downsample_freq)
|
86
93
|
filtered_data = filtered_data.isel(time=slice(None, None, downsample_factor))
|
87
94
|
# Update the sampling frequency attribute
|
88
|
-
filtered_data.attrs[
|
95
|
+
filtered_data.attrs["fs"] = downsample_freq
|
89
96
|
|
90
97
|
return filtered_data
|
91
98
|
|
@@ -100,7 +107,7 @@ def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
|
|
100
107
|
data : xr.DataArray
|
101
108
|
The input xarray DataArray containing time-series data.
|
102
109
|
time_ranges : tuple or list of tuples
|
103
|
-
One or more tuples representing the (start, stop) time points for slicing.
|
110
|
+
One or more tuples representing the (start, stop) time points for slicing.
|
104
111
|
For example: (start, stop) or [(start1, stop1), (start2, stop2)]
|
105
112
|
|
106
113
|
Returns:
|
@@ -122,17 +129,26 @@ def slice_time_series(data: xr.DataArray, time_ranges: tuple) -> xr.DataArray:
|
|
122
129
|
|
123
130
|
# Concatenate all slices along the time dimension if more than one slice
|
124
131
|
if len(slices) > 1:
|
125
|
-
return xr.concat(slices, dim=
|
132
|
+
return xr.concat(slices, dim="time")
|
126
133
|
else:
|
127
134
|
return slices[0]
|
128
135
|
|
129
136
|
|
130
|
-
def fit_fooof(
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
137
|
+
def fit_fooof(
|
138
|
+
f: np.ndarray,
|
139
|
+
pxx: np.ndarray,
|
140
|
+
aperiodic_mode: str = "fixed",
|
141
|
+
dB_threshold: float = 3.0,
|
142
|
+
max_n_peaks: int = 10,
|
143
|
+
freq_range: tuple = None,
|
144
|
+
peak_width_limits: tuple = None,
|
145
|
+
report: bool = False,
|
146
|
+
plot: bool = False,
|
147
|
+
plt_log: bool = False,
|
148
|
+
plt_range: tuple = None,
|
149
|
+
figsize: tuple = None,
|
150
|
+
title: str = None,
|
151
|
+
) -> tuple:
|
136
152
|
"""
|
137
153
|
Fit a FOOOF model to power spectral density data.
|
138
154
|
|
@@ -170,43 +186,50 @@ def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
|
|
170
186
|
tuple
|
171
187
|
A tuple containing the fitting results and the FOOOF model object.
|
172
188
|
"""
|
173
|
-
if aperiodic_mode !=
|
174
|
-
aperiodic_mode =
|
175
|
-
|
189
|
+
if aperiodic_mode != "knee":
|
190
|
+
aperiodic_mode = "fixed"
|
191
|
+
|
176
192
|
def set_range(x, upper=f[-1]):
|
177
193
|
x = np.array(upper) if x is None else np.array(x)
|
178
194
|
return [f[2], x.item()] if x.size == 1 else x.tolist()
|
179
|
-
|
195
|
+
|
180
196
|
freq_range = set_range(freq_range)
|
181
197
|
peak_width_limits = set_range(peak_width_limits, np.inf)
|
182
198
|
|
183
199
|
# Initialize a FOOOF object
|
184
|
-
fm = FOOOF(
|
185
|
-
|
186
|
-
|
200
|
+
fm = FOOOF(
|
201
|
+
peak_width_limits=peak_width_limits,
|
202
|
+
min_peak_height=dB_threshold / 10,
|
203
|
+
peak_threshold=0.0,
|
204
|
+
max_n_peaks=max_n_peaks,
|
205
|
+
aperiodic_mode=aperiodic_mode,
|
206
|
+
)
|
207
|
+
|
187
208
|
# Fit the model
|
188
209
|
try:
|
189
210
|
fm.fit(f, pxx, freq_range)
|
190
211
|
except Exception as e:
|
191
212
|
fl = np.linspace(f[0], f[-1], int((f[-1] - f[0]) / np.min(np.diff(f))) + 1)
|
192
213
|
fm.fit(fl, np.interp(fl, f, pxx), freq_range)
|
193
|
-
|
214
|
+
|
194
215
|
results = fm.get_results()
|
195
216
|
|
196
217
|
if report:
|
197
218
|
fm.print_results()
|
198
|
-
if aperiodic_mode ==
|
219
|
+
if aperiodic_mode == "knee":
|
199
220
|
ap_params = results.aperiodic_params
|
200
221
|
if ap_params[1] <= 0:
|
201
|
-
print(
|
222
|
+
print(
|
223
|
+
"Negative value of knee parameter occurred. Suggestion: Fit without knee parameter."
|
224
|
+
)
|
202
225
|
knee_freq = np.abs(ap_params[1]) ** (1 / ap_params[2])
|
203
|
-
print(f
|
204
|
-
|
226
|
+
print(f"Knee location: {knee_freq:.2f} Hz")
|
227
|
+
|
205
228
|
if plot:
|
206
229
|
plt_range = set_range(plt_range)
|
207
230
|
fm.plot(plt_log=plt_log)
|
208
231
|
plt.xlim(np.log10(plt_range) if plt_log else plt_range)
|
209
|
-
#plt.ylim(-8, -5.5)
|
232
|
+
# plt.ylim(-8, -5.5)
|
210
233
|
if figsize:
|
211
234
|
plt.gcf().set_size_inches(figsize)
|
212
235
|
if title:
|
@@ -215,7 +238,7 @@ def fit_fooof(f: np.ndarray, pxx: np.ndarray, aperiodic_mode: str = 'fixed',
|
|
215
238
|
pass
|
216
239
|
else:
|
217
240
|
plt.show()
|
218
|
-
|
241
|
+
|
219
242
|
return results, fm
|
220
243
|
|
221
244
|
|
@@ -234,13 +257,19 @@ def generate_resd_from_fooof(fooof_model: FOOOF) -> tuple:
|
|
234
257
|
A tuple containing the residual power spectral density and the aperiodic fit.
|
235
258
|
"""
|
236
259
|
results = fooof_model.get_results()
|
237
|
-
full_fit, _, ap_fit = gen_model(
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
260
|
+
full_fit, _, ap_fit = gen_model(
|
261
|
+
fooof_model.freqs[1:],
|
262
|
+
results.aperiodic_params,
|
263
|
+
results.gaussian_params,
|
264
|
+
return_components=True,
|
265
|
+
)
|
266
|
+
|
267
|
+
full_fit, ap_fit = 10**full_fit, 10**ap_fit # Convert back from log
|
268
|
+
res_psd = np.insert(
|
269
|
+
(10 ** fooof_model.power_spectrum[1:]) - ap_fit, 0, 0.0
|
270
|
+
) # Convert back from log
|
271
|
+
res_fit = np.insert(full_fit - ap_fit, 0, 0.0)
|
272
|
+
ap_fit = np.insert(ap_fit, 0, 0.0)
|
244
273
|
|
245
274
|
return res_psd, ap_fit
|
246
275
|
|
@@ -276,7 +305,7 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
|
|
276
305
|
def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
|
277
306
|
"""
|
278
307
|
Calculate the passband of a complex Morlet wavelet filter.
|
279
|
-
|
308
|
+
|
280
309
|
Parameters
|
281
310
|
----------
|
282
311
|
center_freq : float
|
@@ -285,7 +314,7 @@ def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
|
|
285
314
|
Bandwidth parameter of the wavelet filter
|
286
315
|
threshold : float, optional
|
287
316
|
Power threshold to define the passband edges (default: 0.3; 0.5 would be the -3 dB point)
|
288
|
-
|
317
|
+
|
289
318
|
Returns
|
290
319
|
-------
|
291
320
|
tuple
|
@@ -297,32 +326,39 @@ def calculate_wavelet_passband(center_freq, bandwidth, threshold=0.3):
|
|
297
326
|
freq_min = max(0.1, center_freq - 3 * expected_width)
|
298
327
|
freq_max = center_freq + 3 * expected_width
|
299
328
|
freq_axis = np.linspace(freq_min, freq_max, 1000)
|
300
|
-
|
329
|
+
|
301
330
|
# Calculate the theoretical frequency response of the Morlet wavelet
|
302
331
|
# For a complex Morlet wavelet, the frequency response approximates a Gaussian
|
303
332
|
# centered at the center frequency with width related to the bandwidth parameter
|
304
333
|
sigma_f = bandwidth * center_freq / 8 # Approximate relationship for cmor wavelet
|
305
|
-
response = np.exp(-((freq_axis - center_freq)**2) / (2 * sigma_f**2))
|
306
|
-
|
334
|
+
response = np.exp(-((freq_axis - center_freq) ** 2) / (2 * sigma_f**2))
|
335
|
+
|
307
336
|
# Find the passband edges (where response crosses the threshold)
|
308
337
|
above_threshold = response >= threshold
|
309
338
|
if not np.any(above_threshold):
|
310
339
|
return (center_freq, center_freq, 0) # No passband found
|
311
|
-
|
340
|
+
|
312
341
|
# Find the first and last indices where response is above threshold
|
313
342
|
indices = np.where(above_threshold)[0]
|
314
343
|
lower_idx = indices[0]
|
315
344
|
upper_idx = indices[-1]
|
316
|
-
|
345
|
+
|
317
346
|
# Get the corresponding frequencies
|
318
347
|
lower_bound = freq_axis[lower_idx]
|
319
348
|
upper_bound = freq_axis[upper_idx]
|
320
349
|
passband_width = upper_bound - lower_bound
|
321
|
-
|
350
|
+
|
322
351
|
return (lower_bound, upper_bound, passband_width)
|
323
352
|
|
324
353
|
|
325
|
-
def wavelet_filter(
|
354
|
+
def wavelet_filter(
|
355
|
+
x: np.ndarray,
|
356
|
+
freq: float,
|
357
|
+
fs: float,
|
358
|
+
bandwidth: float = 1.0,
|
359
|
+
axis: int = -1,
|
360
|
+
show_passband: bool = False,
|
361
|
+
) -> np.ndarray:
|
326
362
|
"""
|
327
363
|
Compute the Continuous Wavelet Transform (CWT) for a specified frequency using a complex Morlet wavelet.
|
328
364
|
|
@@ -340,41 +376,55 @@ def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0
|
|
340
376
|
Axis along which to compute the CWT (default is -1)
|
341
377
|
show_passband : bool, optional
|
342
378
|
If True, print the passband of the wavelet filter (default is False)
|
343
|
-
|
379
|
+
|
344
380
|
Returns
|
345
381
|
-------
|
346
382
|
np.ndarray
|
347
383
|
Continuous Wavelet Transform of the input signal
|
348
384
|
"""
|
349
385
|
if show_passband:
|
350
|
-
lower_bound, upper_bound, passband_width = calculate_wavelet_passband(
|
386
|
+
lower_bound, upper_bound, passband_width = calculate_wavelet_passband(
|
387
|
+
freq, bandwidth, threshold=0.3
|
388
|
+
) # kinda made up threshold gives the rough idea
|
351
389
|
print(f"Wavelet filter at {freq:.1f} Hz Bandwidth: {bandwidth:.1f} Hz:")
|
352
|
-
print(
|
353
|
-
|
390
|
+
print(
|
391
|
+
f" Passband: {lower_bound:.1f} - {upper_bound:.1f} Hz (width: {passband_width:.1f} Hz)"
|
392
|
+
)
|
393
|
+
wavelet = "cmor" + str(2 * bandwidth**2) + "-1.0"
|
354
394
|
scale = pywt.scale2frequency(wavelet, 1) * fs / freq
|
355
395
|
x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
|
356
396
|
return x_a
|
357
397
|
|
358
398
|
|
359
|
-
def butter_bandpass_filter(
|
399
|
+
def butter_bandpass_filter(
|
400
|
+
data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1
|
401
|
+
) -> np.ndarray:
|
360
402
|
"""
|
361
403
|
Apply a Butterworth bandpass filter to the input data.
|
362
404
|
"""
|
363
|
-
sos = signal.butter(order, [lowcut, highcut], fs=fs, btype=
|
405
|
+
sos = signal.butter(order, [lowcut, highcut], fs=fs, btype="band", output="sos")
|
364
406
|
x_a = signal.sosfiltfilt(sos, data, axis=axis)
|
365
407
|
return x_a
|
366
408
|
|
367
409
|
|
368
|
-
def get_lfp_power(
|
369
|
-
|
410
|
+
def get_lfp_power(
|
411
|
+
lfp_data,
|
412
|
+
freq_of_interest: float,
|
413
|
+
fs: float,
|
414
|
+
filter_method: str = "wavelet",
|
415
|
+
lowcut: float = None,
|
416
|
+
highcut: float = None,
|
417
|
+
bandwidth: float = 1.0,
|
418
|
+
):
|
370
419
|
"""
|
371
|
-
Compute the power of the raw LFP signal in a specified frequency band
|
372
|
-
|
420
|
+
Compute the power of the raw LFP signal in a specified frequency band,
|
421
|
+
preserving xarray structure if input is xarray.
|
422
|
+
|
373
423
|
Parameters
|
374
424
|
----------
|
375
|
-
lfp_data : np.ndarray
|
425
|
+
lfp_data : np.ndarray or xr.DataArray
|
376
426
|
Raw local field potential (LFP) time series data
|
377
|
-
|
427
|
+
freq_of_interest : float
|
378
428
|
Center frequency (Hz) for wavelet filtering method
|
379
429
|
fs : float
|
380
430
|
Sampling frequency (Hz) of the input data
|
@@ -386,45 +436,88 @@ def get_lfp_power(lfp_data: np.ndarray, freq: float, fs: float, filter_method: s
|
|
386
436
|
Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
|
387
437
|
bandwidth : float, optional
|
388
438
|
Bandwidth parameter for wavelet filter when method='wavelet' (default: 1.0)
|
389
|
-
|
439
|
+
|
390
440
|
Returns
|
391
441
|
-------
|
392
|
-
np.ndarray
|
393
|
-
Power of the filtered signal (magnitude squared)
|
394
|
-
|
442
|
+
np.ndarray or xr.DataArray
|
443
|
+
Power of the filtered signal (magnitude squared) with same structure as input
|
444
|
+
|
395
445
|
Notes
|
396
446
|
-----
|
397
447
|
- The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
|
398
448
|
- The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
|
399
449
|
- When using the 'butter' method, both lowcut and highcut must be provided
|
450
|
+
- If input is an xarray DataArray, the output will preserve the same structure with coordinates
|
400
451
|
"""
|
401
|
-
|
402
|
-
|
403
|
-
|
452
|
+
import xarray as xr
|
453
|
+
|
454
|
+
# Check if input is xarray
|
455
|
+
is_xarray = isinstance(lfp_data, xr.DataArray)
|
456
|
+
|
457
|
+
if is_xarray:
|
458
|
+
# Get the raw data from xarray
|
459
|
+
raw_data = lfp_data.values
|
460
|
+
# Check if 'fs' attribute exists in the xarray and override if necessary
|
461
|
+
if "fs" in lfp_data.attrs and fs is None:
|
462
|
+
fs = lfp_data.attrs["fs"]
|
463
|
+
else:
|
464
|
+
raw_data = lfp_data
|
465
|
+
|
466
|
+
if filter_method == "wavelet":
|
467
|
+
filtered_signal = wavelet_filter(raw_data, freq_of_interest, fs, bandwidth)
|
468
|
+
elif filter_method == "butter":
|
404
469
|
if lowcut is None or highcut is None:
|
405
|
-
raise ValueError(
|
406
|
-
|
470
|
+
raise ValueError(
|
471
|
+
"Both lowcut and highcut must be specified when using 'butter' method."
|
472
|
+
)
|
473
|
+
filtered_signal = butter_bandpass_filter(raw_data, lowcut, highcut, fs)
|
407
474
|
else:
|
408
475
|
raise ValueError("Invalid method. Choose 'wavelet' or 'butter'.")
|
409
|
-
|
476
|
+
|
410
477
|
# Calculate power (magnitude squared of filtered signal)
|
411
|
-
power = np.abs(filtered_signal)**2
|
478
|
+
power = np.abs(filtered_signal) ** 2
|
479
|
+
|
480
|
+
# If the input was an xarray, return an xarray with the same coordinates
|
481
|
+
if is_xarray:
|
482
|
+
power_xarray = xr.DataArray(
|
483
|
+
power,
|
484
|
+
coords=lfp_data.coords,
|
485
|
+
dims=lfp_data.dims,
|
486
|
+
attrs={
|
487
|
+
**lfp_data.attrs,
|
488
|
+
"filter_method": filter_method,
|
489
|
+
"frequency_of_interest": freq_of_interest,
|
490
|
+
"bandwidth": bandwidth,
|
491
|
+
"lowcut": lowcut,
|
492
|
+
"highcut": highcut,
|
493
|
+
"power_type": "magnitude_squared",
|
494
|
+
},
|
495
|
+
)
|
496
|
+
return power_xarray
|
497
|
+
|
412
498
|
return power
|
413
499
|
|
414
500
|
|
415
|
-
def get_lfp_phase(
|
416
|
-
|
501
|
+
def get_lfp_phase(
|
502
|
+
lfp_data,
|
503
|
+
freq_of_interest: float,
|
504
|
+
fs: float,
|
505
|
+
filter_method: str = "wavelet",
|
506
|
+
lowcut: float = None,
|
507
|
+
highcut: float = None,
|
508
|
+
bandwidth: float = 1.0,
|
509
|
+
) -> np.ndarray:
|
417
510
|
"""
|
418
|
-
Calculate the phase of the filtered signal.
|
419
|
-
|
511
|
+
Calculate the phase of the filtered signal, preserving xarray structure if input is xarray.
|
512
|
+
|
420
513
|
Parameters
|
421
514
|
----------
|
422
|
-
lfp_data : np.ndarray
|
515
|
+
lfp_data : np.ndarray or xr.DataArray
|
423
516
|
Input LFP data
|
517
|
+
freq_of_interest : float
|
518
|
+
Frequency of interest (Hz)
|
424
519
|
fs : float
|
425
520
|
Sampling frequency (Hz)
|
426
|
-
freq : float
|
427
|
-
Frequency of interest (Hz)
|
428
521
|
filter_method : str, optional
|
429
522
|
Method for filtering the signal ('wavelet' or 'butter')
|
430
523
|
bandwidth : float, optional
|
@@ -433,43 +526,77 @@ def get_lfp_phase(lfp_data: np.ndarray, freq_of_interest: float, fs: float, filt
|
|
433
526
|
Low cutoff frequency for Butterworth filter when method='butter'
|
434
527
|
highcut : float, optional
|
435
528
|
High cutoff frequency for Butterworth filter when method='butter'
|
436
|
-
|
529
|
+
|
437
530
|
Returns
|
438
531
|
-------
|
439
|
-
np.ndarray
|
440
|
-
Phase of the filtered signal
|
441
|
-
|
532
|
+
np.ndarray or xr.DataArray
|
533
|
+
Phase of the filtered signal with same structure as input
|
534
|
+
|
442
535
|
Notes
|
443
536
|
-----
|
444
537
|
- The 'wavelet' method uses a complex Morlet wavelet centered at the specified frequency
|
445
538
|
- The 'butter' method uses a Butterworth bandpass filter with the specified cutoff frequencies
|
446
539
|
followed by Hilbert transform to extract the phase
|
447
540
|
- When using the 'butter' method, both lowcut and highcut must be provided
|
541
|
+
- If input is an xarray DataArray, the output will preserve the same structure with coordinates
|
448
542
|
"""
|
449
|
-
|
543
|
+
import xarray as xr
|
544
|
+
|
545
|
+
# Check if input is xarray
|
546
|
+
is_xarray = isinstance(lfp_data, xr.DataArray)
|
547
|
+
|
548
|
+
if is_xarray:
|
549
|
+
# Get the raw data from xarray
|
550
|
+
raw_data = lfp_data.values
|
551
|
+
# Check if 'fs' attribute exists in the xarray and override if necessary
|
552
|
+
if "fs" in lfp_data.attrs and fs is None:
|
553
|
+
fs = lfp_data.attrs["fs"]
|
554
|
+
else:
|
555
|
+
raw_data = lfp_data
|
556
|
+
|
557
|
+
if filter_method == "wavelet":
|
450
558
|
if freq_of_interest is None:
|
451
559
|
raise ValueError("freq_of_interest must be provided for the wavelet method.")
|
452
560
|
# Wavelet filter returns complex values directly
|
453
|
-
filtered_signal = wavelet_filter(
|
561
|
+
filtered_signal = wavelet_filter(raw_data, freq_of_interest, fs, bandwidth)
|
454
562
|
# Phase is the angle of the complex signal
|
455
563
|
phase = np.angle(filtered_signal)
|
456
|
-
elif filter_method ==
|
564
|
+
elif filter_method == "butter":
|
457
565
|
if lowcut is None or highcut is None:
|
458
|
-
raise ValueError(
|
566
|
+
raise ValueError(
|
567
|
+
"Both lowcut and highcut must be specified when using 'butter' method."
|
568
|
+
)
|
459
569
|
# Butterworth filter returns real values
|
460
|
-
filtered_signal = butter_bandpass_filter(
|
570
|
+
filtered_signal = butter_bandpass_filter(raw_data, lowcut, highcut, fs)
|
461
571
|
# Apply Hilbert transform to get analytic signal (complex)
|
462
572
|
analytic_signal = signal.hilbert(filtered_signal)
|
463
573
|
# Phase is the angle of the analytic signal
|
464
574
|
phase = np.angle(analytic_signal)
|
465
575
|
else:
|
466
576
|
raise ValueError(f"Invalid method {filter_method}. Choose 'wavelet' or 'butter'.")
|
467
|
-
|
577
|
+
|
578
|
+
# If the input was an xarray, return an xarray with the same coordinates
|
579
|
+
if is_xarray:
|
580
|
+
phase_xarray = xr.DataArray(
|
581
|
+
phase,
|
582
|
+
coords=lfp_data.coords,
|
583
|
+
dims=lfp_data.dims,
|
584
|
+
attrs={
|
585
|
+
**lfp_data.attrs,
|
586
|
+
"filter_method": filter_method,
|
587
|
+
"freq_of_interest": freq_of_interest,
|
588
|
+
"bandwidth": bandwidth,
|
589
|
+
"lowcut": lowcut,
|
590
|
+
"highcut": highcut,
|
591
|
+
},
|
592
|
+
)
|
593
|
+
return phase_xarray
|
594
|
+
|
468
595
|
return phase
|
469
596
|
|
470
|
-
|
471
|
-
|
472
|
-
|
597
|
+
|
598
|
+
# windowing functions
|
599
|
+
def windowed_xarray(da, windows, dim="time", new_coord_name="cycle", new_coord=None):
|
473
600
|
"""Divide xarray into windows of equal size along an axis
|
474
601
|
da: input DataArray
|
475
602
|
windows: 2d-array of windows
|
@@ -488,13 +615,13 @@ def windowed_xarray(da, windows, dim='time',
|
|
488
615
|
return win_da
|
489
616
|
|
490
617
|
|
491
|
-
def group_windows(win_da, win_grp_idx={}, win_dim=
|
618
|
+
def group_windows(win_da, win_grp_idx={}, win_dim="cycle"):
|
492
619
|
"""Group windows into a dictionary of DataArrays
|
493
620
|
win_da: input windowed DataArrays
|
494
621
|
win_grp_idx: dictionary of {window group id: window indices}
|
495
622
|
win_dim: dimension for different windows
|
496
623
|
Return: dictionaries of {window group id: DataArray of grouped windows}
|
497
|
-
win_on / win_off for windows selected / not selected by `win_grp_idx`
|
624
|
+
win_on / win_off for windows selected / not selected by `win_grp_idx`
|
498
625
|
"""
|
499
626
|
win_on, win_off = {}, {}
|
500
627
|
for g, w in win_grp_idx.items():
|
@@ -503,22 +630,27 @@ def group_windows(win_da, win_grp_idx={}, win_dim='cycle'):
|
|
503
630
|
return win_on, win_off
|
504
631
|
|
505
632
|
|
506
|
-
def average_group_windows(win_da, win_dim=
|
633
|
+
def average_group_windows(win_da, win_dim="cycle", grp_dim="unique_cycle"):
|
507
634
|
"""Average over windows in each group and stack groups in a DataArray
|
508
635
|
win_da: input dictionary of {window group id: DataArray of grouped windows}
|
509
636
|
win_dim: dimension for different windows
|
510
|
-
grp_dim: dimension along which to stack average of window groups
|
637
|
+
grp_dim: dimension along which to stack average of window groups
|
511
638
|
"""
|
512
|
-
win_avg = {
|
513
|
-
|
514
|
-
|
639
|
+
win_avg = {
|
640
|
+
g: xr.concat(
|
641
|
+
[x.mean(dim=win_dim), x.std(dim=win_dim)], pd.Index(("mean_", "std_"), name="stats")
|
642
|
+
)
|
643
|
+
for g, x in win_da.items()
|
644
|
+
}
|
515
645
|
win_avg = xr.concat(win_avg.values(), dim=pd.Index(win_avg.keys(), name=grp_dim))
|
516
|
-
win_avg = win_avg.to_dataset(dim=
|
646
|
+
win_avg = win_avg.to_dataset(dim="stats")
|
517
647
|
return win_avg
|
518
648
|
|
649
|
+
|
519
650
|
# used for avg spectrogram across different trials
|
520
|
-
def get_windowed_data(
|
521
|
-
|
651
|
+
def get_windowed_data(
|
652
|
+
x, windows, win_grp_idx, dim="time", win_dim="cycle", win_coord=None, grp_dim="unique_cycle"
|
653
|
+
):
|
522
654
|
"""Apply functions of windowing to data
|
523
655
|
x: DataArray
|
524
656
|
windows: `windows` for `windowed_xarray`
|
@@ -531,47 +663,58 @@ def get_windowed_data(x, windows, win_grp_idx, dim='time',
|
|
531
663
|
Return: data returned by three functions,
|
532
664
|
`windowed_xarray`, `group_windows`, `average_group_windows`
|
533
665
|
"""
|
534
|
-
x_win = windowed_xarray(x, windows, dim=dim,
|
535
|
-
new_coord_name=win_dim, new_coord=win_coord)
|
666
|
+
x_win = windowed_xarray(x, windows, dim=dim, new_coord_name=win_dim, new_coord=win_coord)
|
536
667
|
x_win_onff = group_windows(x_win, win_grp_idx, win_dim=win_dim)
|
537
668
|
if grp_dim:
|
538
|
-
x_win_avg = [average_group_windows(x, win_dim=win_dim, grp_dim=grp_dim)
|
539
|
-
for x in x_win_onff]
|
669
|
+
x_win_avg = [average_group_windows(x, win_dim=win_dim, grp_dim=grp_dim) for x in x_win_onff]
|
540
670
|
else:
|
541
671
|
x_win_avg = None
|
542
672
|
return x_win, x_win_onff, x_win_avg
|
543
|
-
|
544
|
-
|
673
|
+
|
674
|
+
|
675
|
+
# cone of influence in frequency for cmorxx-1.0 wavelet. need to add logic to calculate in function
|
545
676
|
f0 = 2 * np.pi
|
546
|
-
CMOR_COI = 2
|
547
|
-
CMOR_FLAMBDA = 4 * np.pi / (f0 + (2 + f0
|
677
|
+
CMOR_COI = 2**-0.5
|
678
|
+
CMOR_FLAMBDA = 4 * np.pi / (f0 + (2 + f0**2) ** 0.5)
|
548
679
|
COI_FREQ = 1 / (CMOR_COI * CMOR_FLAMBDA)
|
549
680
|
|
550
|
-
|
551
|
-
|
681
|
+
|
682
|
+
def cwt_spectrogram(
|
683
|
+
x,
|
684
|
+
fs,
|
685
|
+
nNotes=6,
|
686
|
+
nOctaves=np.inf,
|
687
|
+
freq_range=(0, np.inf),
|
688
|
+
bandwidth=1.0,
|
689
|
+
axis=-1,
|
690
|
+
detrend=False,
|
691
|
+
normalize=False,
|
692
|
+
):
|
552
693
|
"""Calculate spectrogram using continuous wavelet transform"""
|
553
694
|
x = np.asarray(x)
|
554
695
|
N = x.shape[axis]
|
555
696
|
times = np.arange(N) / fs
|
556
697
|
# detrend and normalize
|
557
698
|
if detrend:
|
558
|
-
x = signal.detrend(x, axis=axis, type=
|
699
|
+
x = signal.detrend(x, axis=axis, type="linear")
|
559
700
|
if normalize:
|
560
701
|
x = x / x.std()
|
561
|
-
# Define some parameters of our wavelet analysis.
|
702
|
+
# Define some parameters of our wavelet analysis.
|
562
703
|
# range of scales (in time) that makes sense
|
563
704
|
# min = 2 (Nyquist frequency)
|
564
705
|
# max = np.floor(N/2)
|
565
706
|
nOctaves = min(nOctaves, np.log2(2 * np.floor(N / 2)))
|
566
707
|
scales = 2 ** np.arange(1, nOctaves, 1 / nNotes)
|
567
|
-
# cwt and the frequencies used.
|
708
|
+
# cwt and the frequencies used.
|
568
709
|
# Use the complex Morlet wavelet with bw=2*bandwidth^2 and center frequency of 1.0
|
569
710
|
# bandwidth is sigma of the gaussian envelope
|
570
|
-
wavelet =
|
711
|
+
wavelet = "cmor" + str(2 * bandwidth**2) + "-1.0"
|
571
712
|
frequencies = pywt.scale2frequency(wavelet, scales) * fs
|
572
713
|
scales = scales[(frequencies >= freq_range[0]) & (frequencies <= freq_range[1])]
|
573
|
-
coef, frequencies = pywt.cwt(
|
574
|
-
|
714
|
+
coef, frequencies = pywt.cwt(
|
715
|
+
x, scales[::-1], wavelet=wavelet, sampling_period=1 / fs, axis=axis
|
716
|
+
)
|
717
|
+
power = np.real(coef * np.conj(coef)) # equivalent to power = np.abs(coef)**2
|
575
718
|
# cone of influence in terms of wavelength
|
576
719
|
coi = N / 2 - np.abs(np.arange(N) - (N - 1) / 2)
|
577
720
|
# cone of influence in terms of frequency
|
@@ -579,8 +722,9 @@ def cwt_spectrogram(x, fs, nNotes=6, nOctaves=np.inf, freq_range=(0, np.inf),
|
|
579
722
|
return power, times, frequencies, coif
|
580
723
|
|
581
724
|
|
582
|
-
def cwt_spectrogram_xarray(
|
583
|
-
|
725
|
+
def cwt_spectrogram_xarray(
|
726
|
+
x, fs, time=None, axis=-1, downsample_fs=None, channel_coords=None, **cwt_kwargs
|
727
|
+
):
|
584
728
|
"""Calculate spectrogram using continuous wavelet transform and return an xarray.Dataset
|
585
729
|
x: input array
|
586
730
|
fs: sampling frequency (Hz)
|
@@ -590,7 +734,7 @@ def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
|
|
590
734
|
cwt_kwargs: keyword arguments for cwt_spectrogram()
|
591
735
|
"""
|
592
736
|
x = np.asarray(x)
|
593
|
-
T = x.shape[axis]
|
737
|
+
T = x.shape[axis] # number of time points
|
594
738
|
t = np.arange(T) / fs if time is None else np.asarray(time)
|
595
739
|
if downsample_fs is None or downsample_fs >= fs:
|
596
740
|
downsample_fs = fs
|
@@ -601,10 +745,11 @@ def cwt_spectrogram_xarray(x, fs, time=None, axis=-1, downsample_fs=None,
|
|
601
745
|
downsampled, t = signal.resample(x, num=num, t=t, axis=axis)
|
602
746
|
downsampled = np.moveaxis(downsampled, axis, -1)
|
603
747
|
sxx, _, f, coif = cwt_spectrogram(downsampled, downsample_fs, **cwt_kwargs)
|
604
|
-
sxx = np.moveaxis(sxx, 0, -2)
|
748
|
+
sxx = np.moveaxis(sxx, 0, -2) # shape (... , freq, time)
|
605
749
|
if channel_coords is None:
|
606
|
-
channel_coords = {f
|
607
|
-
sxx = xr.DataArray(sxx, coords={**channel_coords,
|
608
|
-
|
750
|
+
channel_coords = {f"dim_{i:d}": range(d) for i, d in enumerate(sxx.shape[:-2])}
|
751
|
+
sxx = xr.DataArray(sxx, coords={**channel_coords, "frequency": f, "time": t}).to_dataset(
|
752
|
+
name="PSD"
|
753
|
+
)
|
754
|
+
sxx.update(dict(cone_of_influence_frequency=xr.DataArray(coif, coords={"time": t})))
|
609
755
|
return sxx
|
610
|
-
|