py2ls 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- py2ls/.git/COMMIT_EDITMSG +1 -0
- py2ls/.git/FETCH_HEAD +1 -0
- py2ls/.git/HEAD +1 -0
- py2ls/.git/config +15 -0
- py2ls/.git/description +1 -0
- py2ls/.git/hooks/applypatch-msg.sample +15 -0
- py2ls/.git/hooks/commit-msg.sample +24 -0
- py2ls/.git/hooks/fsmonitor-watchman.sample +174 -0
- py2ls/.git/hooks/post-update.sample +8 -0
- py2ls/.git/hooks/pre-applypatch.sample +14 -0
- py2ls/.git/hooks/pre-commit.sample +49 -0
- py2ls/.git/hooks/pre-merge-commit.sample +13 -0
- py2ls/.git/hooks/pre-push.sample +53 -0
- py2ls/.git/hooks/pre-rebase.sample +169 -0
- py2ls/.git/hooks/pre-receive.sample +24 -0
- py2ls/.git/hooks/prepare-commit-msg.sample +42 -0
- py2ls/.git/hooks/push-to-checkout.sample +78 -0
- py2ls/.git/hooks/update.sample +128 -0
- py2ls/.git/index +0 -0
- py2ls/.git/info/exclude +6 -0
- py2ls/.git/logs/HEAD +1 -0
- py2ls/.git/logs/refs/heads/main +1 -0
- py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
- py2ls/.git/logs/refs/remotes/origin/main +1 -0
- py2ls/.git/objects/25/b796accd261b9135fd32a2c00785f68edf6c46 +0 -0
- py2ls/.git/objects/36/b4a1b7403abc6c360f8fe2cb656ab945254971 +0 -0
- py2ls/.git/objects/3f/d6561300938afbb3d11976cf9c8f29549280d9 +0 -0
- py2ls/.git/objects/58/20a729045d4dc7e37ccaf8aa8eec126850afe2 +0 -0
- py2ls/.git/objects/60/f273eb1c412d916fa3f11318a7da7a9911b52a +0 -0
- py2ls/.git/objects/61/570cec8c061abe74121f27f5face6c69b98f99 +0 -0
- py2ls/.git/objects/69/13c452ca319f7cbf6a0836dc10a5bb033c84e4 +0 -0
- py2ls/.git/objects/78/3d4167bc95c9d2175e0df03ef1c1c880ba75ab +0 -0
- py2ls/.git/objects/79/7ae089b2212a937840e215276005ce76881307 +0 -0
- py2ls/.git/objects/7e/5956c806b5edc344d46dab599dec337891ba1f +1 -0
- py2ls/.git/objects/8e/55a7d2b96184030211f20c9b9af201eefcac82 +0 -0
- py2ls/.git/objects/91/c69ad88fe0ba94aa7859fb5f7edac5e6f1a3f7 +0 -0
- py2ls/.git/objects/b0/56be4be89ba6b76949dd641df45bb7036050c8 +0 -0
- py2ls/.git/objects/b0/9cd7856d58590578ee1a4f3ad45d1310a97f87 +0 -0
- py2ls/.git/objects/d9/005f2cc7fc4e65f14ed5518276007c08cf2fd0 +0 -0
- py2ls/.git/objects/df/e0770424b2a19faf507a501ebfc23be8f54e7b +0 -0
- py2ls/.git/objects/e9/391ffe371f1cc43b42ef09b705d9c767c2e14f +0 -0
- py2ls/.git/objects/fc/292e793ecfd42240ac43be407023bd731fa9e7 +0 -0
- py2ls/.git/refs/heads/main +1 -0
- py2ls/.git/refs/remotes/origin/HEAD +1 -0
- py2ls/.git/refs/remotes/origin/main +1 -0
- py2ls/.gitattributes +2 -0
- py2ls/.gitignore +152 -0
- py2ls/LICENSE +201 -0
- py2ls/README.md +409 -0
- py2ls/__init__.py +17 -0
- py2ls/brain_atlas.py +145 -0
- py2ls/correlators.py +475 -0
- py2ls/dbhandler.py +97 -0
- py2ls/freqanalysis.py +800 -0
- py2ls/internet_finder.py +405 -0
- py2ls/ips.py +2844 -0
- py2ls/netfinder.py +780 -0
- py2ls/sleep_events_detectors.py +1350 -0
- py2ls/translator.py +686 -0
- py2ls/version.py +1 -0
- py2ls/wb_detector.py +169 -0
- py2ls-0.1.0.dist-info/METADATA +12 -0
- py2ls-0.1.0.dist-info/RECORD +64 -0
- py2ls-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,1350 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import pandas as pd
|
3
|
+
from scipy.signal import butter, hilbert, filtfilt, find_peaks,resample, resample_poly
|
4
|
+
from scipy.interpolate import interp1d
|
5
|
+
from scipy.ndimage import convolve1d
|
6
|
+
import os
|
7
|
+
import mne
|
8
|
+
import scipy.io
|
9
|
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
import seaborn as sns
|
12
|
+
# from multitaper_spectrogram_python import *
|
13
|
+
|
14
|
+
def load_mat(dir_mat):
    """Read a MATLAB ``.mat`` file and return its variables.

    Parameters
    ----------
    dir_mat : str or path-like
        Path to the ``.mat`` file.

    Returns
    -------
    dict
        Variable name -> value mapping, exactly as produced by
        :func:`scipy.io.loadmat`.
    """
    mat_contents = scipy.io.loadmat(dir_mat)
    return mat_contents
|
16
|
+
|
17
|
+
# def loadmat(dir_mat):
|
18
|
+
# return scipy.io.loadmat(dir_mat)
|
19
|
+
|
20
|
+
# Define data path
|
21
|
+
# data_file = "./UCLA_data/CSC4.Ncs"
|
22
|
+
def load_ncs(dir_file, header_size=16 * 1024):
    """Load a Neuralynx continuously-sampled channel (``.ncs``) file.

    The file is read as a header of ``header_size`` bytes (skipped) followed
    by packed records. Each record holds a uint64 timestamp, uint32 channel
    number, uint32 sampling frequency, uint32 number of valid samples and
    512 int16 samples.

    Parameters
    ----------
    dir_file : str or path-like
        Path to the ``.ncs`` file.
    header_size : int, optional
        Number of bytes to skip at the start of the file (default 16 KiB).

    Returns
    -------
    pandas.DataFrame
        One row per sample with these columns:

        * ``data``: the int16 samples of all records, concatenated.
        * ``fs``: sampling frequency of the first record.
        * ``dur_sec``: number of samples / fs.
        * ``time``: seconds since the first sample, spaced exactly ``1 / fs``.

    Raises
    ------
    ValueError
        If no complete record follows the header.

    Notes
    -----
    NOTE(review): every record contributes all 512 samples regardless of its
    ``NumValidSamples`` field, and record timestamps are not used to detect
    gaps. Confirm this is acceptable for the recordings in use.
    """
    # Packed (unaligned) record layout: 8 + 4 + 4 + 4 + 512 * 2 = 1044 bytes.
    record_format = np.dtype(
        [
            ("TimeStamp", np.uint64),
            ("ChannelNumber", np.uint32),
            ("SampleFreq", np.uint32),
            ("NumValidSamples", np.uint32),
            ("Samples", np.int16, 512),
        ]
    )

    # Context manager guarantees the file is closed even if reading fails.
    with open(dir_file, "rb") as fid:
        fid.seek(header_size)
        raw = np.fromfile(fid, dtype=record_format)

    if raw.size == 0:
        raise ValueError(
            f"No data records found in {dir_file!r} after the {header_size}-byte header."
        )

    data = raw["Samples"].ravel()
    fs = raw["SampleFreq"][0]
    n_samples = data.shape[0]

    res = {
        "data": data,
        "fs": fs,
        "dur_sec": n_samples / fs,
        # Sample i is taken at i / fs. np.linspace(0, dur, n) would stretch
        # the spacing to dur / (n - 1).
        "time": np.arange(n_samples) / fs,
    }
    return pd.DataFrame(res)
|
58
|
+
|
59
|
+
def ncs2_single_raw(fpath, ch_names=None, ch_types=None):
    """Wrap a single Neuralynx ``.ncs`` channel in an ``mne.io.RawArray``.

    Parameters
    ----------
    fpath : str
        Path to the ``.ncs`` file; it is read with ``load_ncs``.
    ch_names : list of str, optional
        Channel name(s). Defaults to the file name without its extension.
    ch_types : str or list of str, optional
        MNE channel type(s). Defaults to ``"eeg"`` when the first channel name
        contains "eeg" (case-insensitive), otherwise ``"eog"``.

    Returns
    -------
    mne.io.RawArray
        A one-channel raw object at the file's sampling frequency.
    """
    recording = load_ncs(fpath)
    sfreq = recording["fs"][0]

    if ch_names is None:
        base_name = os.path.basename(fpath)
        ch_names = [os.path.splitext(base_name)[0]]

    if ch_types is None:
        # NOTE(review): non-EEG channels are labelled "eog" as a stand-in;
        # the original author noted these should really be 'lfp'.
        is_eeg = "eeg" in ch_names[0].lower()
        ch_types = "eeg" if is_eeg else "eog"

    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
    # RawArray expects (n_channels, n_samples); here that is (1, n_samples).
    samples = np.array(recording["data"]).reshape(1, -1)
    return mne.io.RawArray(samples, info)
|
71
|
+
|
72
|
+
def trial_extract(trial):
    """Strip singleton nesting from a trial array loaded with ``scipy.io``.

    While ``trial[0]`` is a sized (non-scalar) object of length 1,
    ``trial`` is replaced by a copy of ``trial[0]``. The shape of the final
    ``trial[0]`` is printed, and ``trial[0]`` is returned.

    The .mat file must use a format that ``scipy.io.loadmat`` can read (not
    v7.3). Recommended usage: ``mat_data = scipy.io.loadmat(path)``, then
    ``trial_extract(mat_data["trial"])``.

    Parameters
    ----------
    trial : array-like
        Presumably ``mat_data["trial"]`` (a MATLAB cell array); confirm with
        the caller.

    Returns
    -------
    object
        ``trial[0]`` after unwrapping.
    """
    # np.ndim(...) > 0 stops the unwrapping at scalars (e.g. numpy.float64),
    # where len() would raise TypeError. The previous test,
    # type(...) == "numpy.float64", compared a type with a string and was
    # therefore always False.
    while np.ndim(trial[0]) > 0 and len(trial[0]) == 1:
        trial = trial[0].copy()
    print(np.shape(trial[0]))
    return trial[0]
|
79
|
+
|
80
|
+
# # dir_data = "/Users/macjianfeng/Desktop/test_v7_dat.mat"
|
81
|
+
# dir_data = "/Users/macjianfeng/Desktop/mat_r06rec1.mat"
|
82
|
+
# mat_data = scipy.io.loadmat(dir_data)
|
83
|
+
|
84
|
+
# trials = trial_extract(mat_data["trial"])
|
85
|
+
# fs = mat_data["fsample"][0][0]
|
86
|
+
# label = repr(mat_data["label"][0][0][0]) # convert to 'str'
|
87
|
+
|
88
|
+
# print("first 12 trials: ", trials[:12])
|
89
|
+
# print("fs=", fs)
|
90
|
+
# print("label=", label, "type(label):", type(label))
|
91
|
+
def cal_slope(data, segment=1, correct=True):
    """Compute piecewise slopes of ``data`` over non-overlapping segments.

    For every start index ``i`` in ``range(0, len(data) - segment, segment)``
    the slope is ``(data[i + segment] - data[i]) / segment``.

    Parameters
    ----------
    data : array-like (1-D)
        Input samples.
    segment : int, optional
        Segment length in samples (default 1). Must be >= 1.
    correct : bool, optional
        If True (default), stretch the slopes back to ``len(data)`` samples.
        Each slope is repeated ``segment`` times and any remaining tail is
        filled with the last slope.

    Returns
    -------
    numpy.ndarray or list
        With ``correct=True``, an array of length ``len(data)``. It is all NaN
        when ``data`` is too short to contain one full segment.
        With ``correct=False``, a list with one slope per segment.

    Raises
    ------
    ValueError
        If ``segment`` is smaller than 1.
    """
    if segment < 1:
        raise ValueError(f"segment must be a positive integer, got {segment}")

    values = np.asarray(data)
    starts = np.arange(0, len(values) - segment, segment)
    slope = (values[starts + segment] - values[starts]) / segment

    if not correct:
        return list(slope)

    if slope.size == 0:
        # No complete segment: there is no slope to propagate.
        return np.full(len(values), np.nan)

    # Repeat each slope over its segment, then pad the tail with the last one.
    interpolated_slopes = np.repeat(slope, segment)
    missing_values = len(values) - len(interpolated_slopes)
    if missing_values > 0:
        interpolated_slopes = np.append(
            interpolated_slopes, np.full(missing_values, slope[-1])
        )
    return interpolated_slopes
|
110
|
+
# Apply bandpass filter to EEG signal
|
111
|
+
def butter_band_filter(data, lowcut, highcut, fs, ord=3):
    """Zero-phase Butterworth band-pass filter.

    Parameters
    ----------
    data : array-like
        Signal to filter (``filtfilt`` default axis, i.e. the last axis).
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling rate in Hz.
    ord : int, optional
        Butterworth filter order (default 3).

    Returns
    -------
    numpy.ndarray
        The forward-backward filtered signal.
    """
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(ord, normalized_band, btype="band")
    return filtfilt(numerator, denominator, data)
|
120
|
+
|
121
|
+
|
122
|
+
def butter_bandpass_filter(data=None, ord=4, freq_range=[11, 16], fs=1000):
    """Zero-phase Butterworth band-pass filter (prints a usage line first).

    Parameters
    ----------
    data : array-like
        Signal to filter.
    ord : int, optional
        Filter order (default 4).
    freq_range : sequence of two floats, optional
        Pass-band edges in Hz (default [11, 16]).
    fs : float, optional
        Sampling rate in Hz (default 1000).

    Returns
    -------
    numpy.ndarray
        The forward-backward filtered signal.
    """
    print("usage:\n butter_bandpass_filter(data=None, ord=4, freq_range=[11, 16], fs=1000)")
    # Passing fs lets butter take the band edges directly in Hz.
    numerator, denominator = butter(ord, freq_range, btype="bandpass", fs=fs)
    return filtfilt(numerator, denominator, data)
|
129
|
+
|
130
|
+
def filter_bandpass(data=None, ord=4, freq_range=[11, 16], fs=1000):
    """Zero-phase Butterworth band-pass filter (silent variant).

    Parameters
    ----------
    data : array-like
        Signal to filter.
    ord : int, optional
        Filter order (default 4).
    freq_range : sequence of two floats, optional
        Pass-band edges in Hz (default [11, 16]).
    fs : float, optional
        Sampling rate in Hz (default 1000).

    Returns
    -------
    numpy.ndarray
        The forward-backward filtered signal.
    """
    numerator, denominator = butter(ord, freq_range, btype="bandpass", fs=fs)
    filtered = filtfilt(numerator, denominator, data)
    return filtered
|
137
|
+
|
138
|
+
# Apply smoothing (moving average)
|
139
|
+
def moving_average(data, window_size):
    """Smooth ``data`` with a boxcar (moving-average) filter.

    Parameters
    ----------
    data : array-like
        Signal to smooth. It is filtered along the last axis, which is the
        ``scipy.ndimage.convolve1d`` default.
    window_size : int
        Number of samples in the averaging window. Must be >= 1.

    Returns
    -------
    numpy.ndarray
        Smoothed signal with the same shape as ``data``. Edges use
        ``convolve1d``'s default ``"reflect"`` boundary mode. Integer or
        boolean input is converted to float64 first; floating-point input
        keeps its dtype.

    Raises
    ------
    ValueError
        If ``window_size`` is smaller than 1.
    """
    window_size = int(window_size)
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    values = np.asarray(data)
    if not np.issubdtype(values.dtype, np.inexact):
        # convolve1d writes into an array of the input dtype by default, so
        # averaging integer data (e.g. raw int16 samples) would be truncated.
        values = values.astype(np.float64)
    return convolve1d(values, np.ones(window_size) / window_size)
|
141
|
+
|
142
|
+
def detect_cross(data, thr=0):
    """Find paired upward/downward crossings of a threshold.

    A rising crossing is reported at index ``i`` when ``data[i] <= thr`` and
    ``data[i + 1] > thr``. A falling crossing is reported when
    ``data[i] > thr`` and ``data[i + 1] <= thr``. Both indices point at the
    sample *before* the crossing.

    The events are then trimmed into pairs. A falling crossing that precedes
    the first rising one (the signal starts above ``thr``) is dropped. A
    final rising crossing with no falling crossing after it (the signal ends
    above ``thr``) is also dropped. For NaN-free input this guarantees
    ``len(rising) == len(falling)`` and ``rising[k] < falling[k]``.

    Parameters
    ----------
    data : array-like
        A single channel: 1-D, or 2-D with one singleton dimension (a row or
        column vector).
    thr : float, optional
        Threshold (default 0).

    Returns
    -------
    rising_before, falling_before : list of int
        Indices of the paired rising and falling crossings.

    Raises
    ------
    ValueError
        If ``data`` has more than one non-singleton dimension.

    Example
    -------
    >>> a = np.sin(np.arange(0, 10 * np.pi, np.pi / 100))
    >>> rise, fall = detect_cross(a, 0.75)
    >>> len(rise), len(fall)
    (5, 5)
    """
    values = np.asarray(data)
    if sum(size > 1 for size in values.shape) > 1:
        raise ValueError(
            "Input data must be a single channel (1-D or a row/column vector)."
        )
    values = values.ravel()

    # NaN compares False on both sides, so a NaN sample never takes part in
    # a crossing.
    above = values > thr
    at_or_below = values <= thr
    rising_before = np.flatnonzero(at_or_below[:-1] & above[1:]).tolist()
    falling_before = np.flatnonzero(above[:-1] & at_or_below[1:]).tolist()

    # Signal starts above thr: the first fall has no preceding rise.
    if falling_before and (not rising_before or falling_before[0] < rising_before[0]):
        falling_before.pop(0)
    # Signal ends above thr: the last rise has no following fall.
    if rising_before and (not falling_before or rising_before[-1] > falling_before[-1]):
        rising_before.pop()

    return rising_before, falling_before
|
192
|
+
|
193
|
+
def find_repeats(data, N, nGap=None):
    """Locate runs of consecutive identical values in ``data``.

    Each value ``num`` in ``N`` is processed in the given order. Every
    maximal run of consecutive elements equal to ``num`` yields one row
    ``[begin, end]``:

    * For a run longer than one element, ``begin`` is the index of its first
      element and ``end`` the index of its last element (inclusive).
    * For a single-element run, ``end = begin + 1``.

    NOTE(review): the two cases use different end conventions (inclusive
    vs. exclusive-looking). This is preserved from the original
    implementation; confirm which one downstream code expects.

    Parameters
    ----------
    data : array-like (1-D)
        Sequence to search.
    N : scalar or iterable of scalars
        Value(s) whose runs are reported. Values absent from ``data``
        contribute no rows.
    nGap : int or float, optional
        Index scaling applied to the rows (default 1).

        * ``nGap == 1`` returns the rows unchanged.
        * ``nGap > 1`` maps each row to ``[begin * nGap + 1, end * nGap]``.
        * Any other value contributes no rows.

        NOTE(review): this only rescales indices; it does not allow gaps
        between occurrences.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_runs, 2)``. It has shape ``(0, 2)`` when nothing
        matches.

    Example
    -------
    >>> data = [1, 2, 3, 4, 1, 2, 5, 1, 2, 2, 3]
    >>> find_repeats(data, [1, 2]).tolist()
    [[0, 1], [4, 5], [7, 8], [1, 2], [5, 6], [8, 9]]
    >>> find_repeats(data, 2, 2).tolist()
    [[3, 4], [11, 12], [17, 18]]
    """
    values = np.asarray(data)
    if nGap is None:
        nGap = 1

    # Accept plain and NumPy scalars as a single target. Strings keep the old
    # behaviour of being iterated character by character.
    if np.ndim(N) == 0 and not isinstance(N, str):
        targets = [N]
    else:
        targets = list(N)

    rows = []
    for num in targets:
        is_num = np.asarray(values == num, dtype=bool).ravel()
        if not is_num.any():
            continue
        # Pad with False so runs touching either end of the data produce an edge.
        edges = np.diff(np.concatenate(([False], is_num, [False])).astype(np.int8))
        begins = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1) - 1  # index of each run's last element
        ends = np.where(ends == begins, ends + 1, ends)

        if nGap == 1:
            rows.append(np.column_stack((begins, ends)))
        elif nGap > 1:
            rows.append(np.column_stack((begins * nGap + 1, ends * nGap)))

    if not rows:
        return np.empty((0, 2), dtype=int)
    return np.concatenate(rows)
|
268
|
+
|
269
|
+
def find_continue(data, step=1):
    """Find the first and last index of every run of evenly stepped values.

    A run continues while ``data[i + 1] - data[i] == step``.

    Parameters
    ----------
    data : array-like (1-D)
        Input values (e.g. sample indices).
    step : int, optional
        Expected difference between consecutive elements (default 1).

    Returns
    -------
    idx_beg, idx_end : numpy.ndarray of int
        Index of the first and of the last (inclusive) element of each run.
        Both are empty for empty input.

    Raises
    ------
    TypeError
        If ``step`` is not an integer.

    Example
    -------
    >>> b, e = find_continue([1, 2, 3, 7, 8, 10])
    >>> b.tolist(), e.tolist()
    ([0, 3, 5], [2, 4, 5])
    """
    if not isinstance(step, (int, np.integer)):
        raise TypeError("step must be an integer")

    values = np.asarray(data)
    if values.dtype.kind in "bu":
        # Differences of unsigned ints wrap around and boolean diff is XOR;
        # compute on signed integers instead.
        values = values.astype(np.int64)

    if values.size == 0:
        empty = np.array([], dtype=np.intp)
        return empty, empty.copy()

    # Mark every position where the step is broken. The first element always
    # starts a run and the last always ends one; no sentinel value is needed.
    breaks = np.diff(values) != step
    idx_beg = np.flatnonzero(np.concatenate(([True], breaks)))
    idx_end = np.flatnonzero(np.concatenate((breaks, [True])))
    return idx_beg, idx_end
|
289
|
+
|
290
|
+
def _resample_poly_factors(input_fs, output_fs):
    """Return integer (up, down) factors with up / down == output_fs / input_fs."""
    from fractions import Fraction

    ratio = Fraction(float(output_fs)) / Fraction(float(input_fs))
    if not (float(output_fs).is_integer() and float(input_fs).is_integer()):
        # Non-integer rates: replace float round-off with a small rational.
        ratio = ratio.limit_denominator(1000)
    return ratio.numerator, ratio.denominator


def resample_data(data, input_fs, output_fs, t=None, axis=0, window=None, domain='time', method='fourier', **kwargs):
    """Resample a signal from ``input_fs`` to ``output_fs``.

    Parameters
    ----------
    data : array-like
        The signal to resample.
    input_fs : float
        Original sampling frequency (Hz).
    output_fs : float
        Target sampling frequency (Hz).
    t : array_like, optional
        Accepted for compatibility; currently ignored.
    axis : int, optional
        Axis to resample (default 0). It is used by the ``"fourier"``,
        ``"sinc"``, ``"decimate"`` and ``"zero"`` methods. ``"linear"`` and
        ``"spline"`` expect 1-D data.
    window : optional
        Forwarded to ``scipy.signal.resample`` (``"fourier"`` only).
    domain : {"time", "freq"}, optional
        Forwarded to ``scipy.signal.resample`` (``"fourier"`` only).
    method : str, optional
        Methods for upsampling (``output_fs > input_fs``):

        * ``"fourier"`` (default)
        * ``"linear"``
        * ``"spline"`` (cubic)
        * ``"sinc"`` (polyphase FIR, ``scipy.signal.resample_poly``)
        * ``"zero"`` (each sample repeated ``int(output_fs / input_fs)`` times)

        Methods for downsampling:

        * ``"fourier"`` (default)
        * ``"decimate"`` (``scipy.signal.decimate`` with the integer factor
          ``int(input_fs / output_fs)``, so non-integer ratios are truncated)
        * ``"sinc"``
    **kwargs
        Accepted for compatibility; currently ignored.

    Returns
    -------
    numpy.ndarray
        The resampled signal. ``data`` itself is returned when
        ``input_fs == output_fs``.

    Raises
    ------
    ValueError
        If ``method`` is not supported for the requested direction.
    """
    from scipy.signal import decimate

    if input_fs == output_fs:
        print(f"Input = output = {output_fs} Hz. \nNo resampling is performed.")
        return data

    if input_fs < output_fs:
        if method == 'fourier':
            # Fourier-domain resampling to the new number of samples.
            factor = output_fs / input_fs
            num_samples_new = int(np.shape(data)[axis] * factor)
            data_resampled = resample(data, num_samples_new, axis=axis, window=window, domain=domain)
            print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using Fourier method")

        elif method == 'linear':
            t_original = np.arange(len(data)) / input_fs
            t_resampled = np.arange(len(data) * output_fs / input_fs) / output_fs
            f = interp1d(t_original, data, kind='linear', fill_value='extrapolate')
            data_resampled = f(t_resampled)
            print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using linear interpolation")

        elif method == 'spline':
            t_original = np.arange(len(data)) / input_fs
            t_resampled = np.arange(len(data) * output_fs / input_fs) / output_fs
            f = interp1d(t_original, data, kind='cubic', fill_value='extrapolate')
            data_resampled = f(t_resampled)
            print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using spline interpolation")

        elif method == 'sinc':
            # resample_poly takes the rational up/down factors, not a sample count.
            up, down = _resample_poly_factors(input_fs, output_fs)
            data_resampled = resample_poly(data, up, down, axis=axis)
            print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using sinc interpolation")

        elif method == 'zero':
            # NOTE(review): despite the name this is sample-and-hold (each sample
            # is repeated), not zero-stuffing; only the integer part of the
            # ratio is used.
            upsample_factor = output_fs / input_fs
            data_resampled = np.repeat(data, int(upsample_factor), axis=axis)
            print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using zero-padding")

        else:
            raise ValueError(
                f"Unsupported upsampling method {method!r}; "
                "choose from 'fourier', 'linear', 'spline', 'sinc', 'zero'."
            )
    else:
        if method == 'fourier':
            resampling_factor = input_fs / output_fs
            num_samples_new = int(np.shape(data)[axis] / resampling_factor)
            data_resampled = resample(data, num_samples_new, axis=axis, window=window, domain=domain)
            print(f"Original data {input_fs} Hz\nResampled data {output_fs} Hz using Fourier method")

        elif method == 'decimate':
            # scipy's decimate applies its default IIR anti-aliasing filter
            # (ftype='iir') forward-backward because zero_phase=True.
            decimation_factor = int(input_fs / output_fs)
            data_resampled = decimate(data, decimation_factor, zero_phase=True, axis=axis)
            print(f"Original data {input_fs} Hz\nDownsampled data {output_fs} Hz using decimation")

        elif method == 'sinc':
            up, down = _resample_poly_factors(input_fs, output_fs)
            data_resampled = resample_poly(data, up, down, axis=axis)
            print(f"Original data {input_fs} Hz\nDownsampled data {output_fs} Hz using sinc interpolation")

        else:
            raise ValueError(
                f"Unsupported downsampling method {method!r}; "
                "choose from 'fourier', 'decimate', 'sinc'."
            )

    return data_resampled
|
378
|
+
|
379
|
+
|
380
|
+
def extract_score(dir_score, kron_go=True, scalar=10*1000):
    """Read sleep scores from a tab-separated file.

    The scores are taken from the file's second column. The first row is
    read as the header.

    Parameters
    ----------
    dir_score : str
        Path to the tab-separated score file.
    kron_go : bool, optional
        If True (default), repeat every score ``round(scalar)`` times so that
        there is one value per sample.
    scalar : float, optional
        Number of samples per score (default ``10 * 1000``).
        NOTE(review): presumably epoch duration (s) times sampling rate (Hz);
        confirm with callers.

    Returns
    -------
    numpy.ndarray
        1-D array. With ``kron_go`` True it holds the expanded scores as
        floats; otherwise it holds the raw column values.
    """
    df_score = pd.read_csv(dir_score, sep="\t")
    score_val = df_score[df_score.columns[1]].values
    if not kron_go:
        # Return the whole column, not only its first element.
        return score_val
    # Kronecker product with a 1-D block of ones repeats each score
    # consecutively: [a, b] -> [a, a, ..., b, b, ...].
    return np.kron(score_val, np.ones(round(scalar)))
|
388
|
+
|
389
|
+
def plot_sleeparc(
    ax,
    score,
    code_org=[1, 2, 3],
    code_new=[1, 0, -1],
    kind="patch",
    c=["#474747", "#0C5DA5", "#0C5DA5"],
):
    """Draw a hypnogram-style sleep-architecture plot on ``ax``.

    Scores equal to ``code_org[k]`` are remapped to ``code_new[k]``
    (k = 0, 1, 2) on a copy of ``score``; the caller's array is not modified.
    All masks are computed on the original values, so one remapping cannot
    feed into the next (e.g. ``code_org=[1, 2, 3]``, ``code_new=[2, 3, 4]``).

    Parameters
    ----------
    ax : matplotlib Axes (assumed)
        Only ``ax.fill`` and ``ax.plot`` are called.
    score : array-like (1-D)
        Stage code per sample.
    code_org, code_new : sequence of 3 values
        Original codes and the plotting levels they map to.
    kind : str
        If it contains "pa", filled patches are drawn, one per run of equal
        remapped values; runs matching none of ``code_new`` are skipped.
        If it contains "l", the remapped scores are drawn as a line.
        Both may apply.
    c : sequence of 3 colours
        Patch colours per level; ``c[0]`` is also the line colour.

    Returns
    -------
    ax
        The same axes.
    """
    original = np.asarray(score)
    masks = [original == code for code in code_org[:3]]
    remapped = original.copy()
    for mask, code in zip(masks, code_new[:3]):
        remapped[mask] = code

    dist = code_new[0] - code_new[1]
    colorlist = c
    # (level, y-vertices, colour). The first matching level wins, as in an
    # if/elif chain.
    levels = [
        (code_new[0], [code_new[0], code_new[0], code_new[0] + dist, code_new[0] + dist], colorlist[0]),
        (code_new[1], [code_new[1], code_new[1], code_new[0], code_new[0]], colorlist[1]),
        (code_new[2], [code_new[2], code_new[2], code_new[1], code_new[1]], colorlist[2]),
    ]

    # patch method
    if "pa" in kind.lower() and remapped.shape[0] > 0:
        change = np.flatnonzero(remapped[1:] != remapped[:-1]) + 1
        run_starts = np.concatenate(([0], change))
        run_stops = np.concatenate((change, [remapped.shape[0]]))
        for start, stop in zip(run_starts, run_stops):
            value = remapped[start]
            for level, y, color in levels:
                if value == level:
                    # Sample i covers x in [i - 1, i], so the run start..stop-1
                    # spans [start - 1, stop - 1].
                    x = [start - 1, stop - 1, stop - 1, start - 1]
                    ax.fill(x, y, c=color, edgecolor="none")
                    break

    # line method
    if "l" in kind.lower():
        ax.plot(remapped, c=colorlist[0])
    return ax
|
420
|
+
|
421
|
+
def filter_linenoise(
    data, fs, method="notch", bandwidth=2, n_components=None, random_state=None, line_freq=50
):
    """Attenuate mains (line) noise and its harmonics.

    The targeted harmonics are ``line_freq * k`` for k = 1..99 that lie
    strictly below the Nyquist frequency.

    Parameters
    ----------
    data : array-like
        Signal(s). The "notch" and "bandstop" methods filter along the last
        axis.
    fs : float
        Sampling rate in Hz.
    method : {"notch", "bandstop", "ica"}
        * "notch": one IIR notch per harmonic, applied causally with
          ``scipy.signal.lfilter`` (introduces a phase shift).
        * "bandstop": the same notch filters applied forward-backward with
          ``scipy.signal.filtfilt`` (zero phase).
        * "ica": FastICA decomposition followed by reconstruction.
          NOTE(review): no component is removed, so this only reconstructs
          the input. Confirm the intended rejection step.
    bandwidth : float
        Quality factor ``Q`` for ``scipy.signal.iirnotch`` (higher means a
        narrower notch). It is not a width in Hz.
    n_components, random_state :
        Forwarded to ``sklearn.decomposition.FastICA`` ("ica" only).
    line_freq : float, optional
        Mains frequency in Hz (default 50; use 60 where applicable).

    Returns
    -------
    numpy.ndarray
        The cleaned signal. ``data`` is returned unchanged by "notch" and
        "bandstop" when no harmonic lies below Nyquist.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported options.
    """
    from scipy import signal

    nyquist = fs / 2
    freq = np.arange(1, 100) * line_freq
    # iirnotch requires 0 < w0 < Nyquist, so a harmonic exactly at Nyquist
    # (e.g. 500 Hz at fs=1000) must be excluded.
    freq = list(freq[np.where(freq < nyquist)])
    print(f"nyquist={nyquist}hz, clearned freq:{freq}")

    if method == "notch":
        clean_data = data
        for f0 in freq:
            w0 = f0 / nyquist  # normalized frequency (1 == Nyquist)
            b_, a_ = signal.iirnotch(w0, Q=bandwidth)
            clean_data = signal.lfilter(b_, a_, clean_data)
    elif method == "bandstop":
        clean_data = data
        for f0 in freq:
            # With fs given, iirnotch takes the frequency directly in Hz.
            b_, a_ = signal.iirnotch(f0, Q=bandwidth, fs=fs)
            clean_data = signal.filtfilt(b_, a_, clean_data, axis=-1)
    elif method == "ica":  # Independent Component Analysis (ICA)
        from sklearn.decomposition import FastICA

        # Assumes data is (n_channels, n_times), as implied by the transpose;
        # sklearn works on (n_samples, n_features).
        values = np.asarray(data)
        ica = FastICA(n_components=n_components, random_state=random_state)
        sources = ica.fit_transform(values.T)
        clean_data = ica.inverse_transform(sources).T
    else:
        raise ValueError(
            f"Unknown method {method!r}; choose 'notch', 'bandstop' or 'ica'."
        )

    return clean_data
|
451
|
+
|
452
|
+
def MAD(signal, mSTD=1.4826):
    """Return the scaled median absolute deviation of a 1-D array.

    Computes ``mSTD * median(|signal - median(signal)|)``.

    Parameters
    ----------
    signal : numpy.ndarray
        1-D data. It must support elementwise subtraction, so a plain list is
        not accepted.
    mSTD : float, optional
        Scale factor (default 1.4826). With this value the result is a
        consistent estimate of the standard deviation for normally
        distributed data.

    Returns
    -------
    float
        The scaled median absolute deviation.
    """
    center = np.median(signal)
    abs_deviation = np.abs(signal - center)
    return mSTD * np.median(abs_deviation)
|
472
|
+
|
473
|
+
|
474
|
+
def detect_spikes(data, Fs, mSTD=1.4826, direction="both"):
    """Detect spikes as peaks exceeding a MAD-based threshold.

    The threshold is ``MAD(data, mSTD=mSTD)``, i.e.
    ``mSTD * median(|data - median(data)|)``. With the default 1.4826 this is
    a robust estimate of one standard deviation.

    Parameters
    ----------
    data : numpy.ndarray (1-D)
        The input signal.
    Fs : float
        Sampling rate in Hz.
    mSTD : float, optional
        Scale factor passed to ``MAD`` (default 1.4826).
    direction : str, optional
        Matched case-insensitively as a substring of "both", "positive" or
        "negative", checked in that order, so "pos" and "neg" work.

        * "both": peaks of ``|data|``.
        * "positive": peaks of ``data``.
        * "negative": peaks of ``-data``.

    Returns
    -------
    spike_idx : numpy.ndarray of int
        Sample indices of the detected peaks.
    time : numpy.ndarray of float
        The same spikes in milliseconds (``spike_idx / Fs * 1000``). Sample 0
        is at 0 ms.

    Raises
    ------
    ValueError
        If ``direction`` matches none of the options.
    """
    thr_ = MAD(data, mSTD=mSTD)
    key = str(direction).lower()

    if key in "both":
        spike_idx, _ = find_peaks(np.abs(data), height=thr_)
    elif key in "positive":
        spike_idx, _ = find_peaks(data, height=thr_)
    elif key in "negative":
        spike_idx, _ = find_peaks((-1) * np.asarray(data), height=thr_)
    else:
        raise ValueError(
            f"direction must be 'both', 'pos(itive)' or 'neg(ative)', got {direction!r}"
        )

    # Spike times correspond to the detected peaks, in ms.
    time = spike_idx / Fs * 1000.0
    return spike_idx, time
|
526
|
+
|
527
|
+
def extract_spikes_waveforms(data, Fs, mSTD=1.4826, win=[-10, 20], direction="both"):
    """Cut out a fixed-length waveform around every spike found by ``detect_spikes``.

    For a spike at sample ``idx``, the waveform is
    ``data[idx + win[0] + 1 : idx + win[1] + 1]``. That is
    ``win[1] - win[0]`` samples, from ``idx + win[0] + 1`` through
    ``idx + win[1]`` inclusive.

    Parameters
    ----------
    data : numpy.ndarray (1-D)
        The input signal.
    Fs : float
        Sampling rate in Hz.
    mSTD : float, optional
        MAD scale factor forwarded to ``detect_spikes`` (default 1.4826).
    win : sequence of two ints, optional
        Offsets in samples relative to each spike (default [-10, 20]).
    direction : str, optional
        Spike polarity forwarded to ``detect_spikes`` (default "both").

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_kept, win[1] - win[0])``. Spikes whose window
        would run past either edge of ``data`` are left out, and so are rows
        that end up entirely NaN.
    """
    peak_positions, _ = detect_spikes(data, Fs, mSTD=mSTD, direction=direction)
    offset_lo = win[0] + 1
    offset_hi = win[1] + 1
    n_points = offset_hi - offset_lo

    snippets = np.full((len(peak_positions), n_points), np.nan)

    for row, peak in enumerate(peak_positions):
        lo = int(peak + offset_lo)
        hi = int(peak + offset_hi)
        # Leave the row as NaN when the window would leave the data.
        if lo < 0 or hi > data.shape[0]:
            continue
        snippets[row, :] = data[lo:hi]

    keep = ~np.isnan(snippets).all(axis=1)
    snippets = snippets[keep]

    print(f"Extracted waveforms number: {snippets.shape}")

    return snippets
|
570
|
+
|
571
|
+
def extract_peaks_waveforms(data, pks, win_size):
    """Extract waveforms from ``data`` around each peak index.

    For each peak ``p`` the waveform is
    ``data[int(p + win_size[0]) : int(p + win_size[1])]``.

    Parameters
    ----------
    data : array-like (1-D)
        The data array from which waveforms are extracted.
    pks : iterable of int
        Peak indices.
    win_size : sequence of two ints
        ``(start_offset, end_offset)`` relative to each peak; the end is
        exclusive.

    Returns
    -------
    numpy.ndarray
        2-D array with one row per kept peak. A peak is skipped when its
        window would start before sample 0 (which would otherwise wrap
        around) or end past the last sample (which would otherwise give a
        truncated row). Returns shape ``(0, window_length)`` when no window
        fits.
    """
    values = np.asarray(data)
    n_samples = values.shape[0]

    waveforms = []
    for pk in pks:
        start_index = int(pk + win_size[0])
        end_index = int(pk + win_size[1])
        if start_index < 0 or end_index > n_samples:
            continue
        waveforms.append(values[start_index:end_index])

    if not waveforms:
        length = max(int(win_size[1]) - int(win_size[0]), 0)
        return np.empty((0, length), dtype=values.dtype)
    return np.array(waveforms)
|
592
|
+
# usage: win_size = [-500, 500]
|
593
|
+
# waveform = extract_peaks_waveforms(
|
594
|
+
# data=data_ds, pks=res_sos.sos.pks_neg_idx, win_size=win_size
|
595
|
+
# )
|
596
|
+
# Function to find the closest timestamp
|
597
|
+
def find_closest_timestamp(time_rel, time_fly):
    """Return the element of ``time_fly`` closest to ``time_rel``.

    Ties resolve to the earliest such element (``np.argmin`` semantics).

    Parameters
    ----------
    time_rel : float
        Reference time.
    time_fly : array-like
        Candidate times, as a list or an array. Must be non-empty.

    Returns
    -------
    scalar
        The closest candidate value.

    Raises
    ------
    ValueError
        If ``time_fly`` is empty (raised by ``np.argmin``).
    """
    # asarray lets plain lists work; list - float would raise TypeError.
    candidates = np.asarray(time_fly)
    closest_idx = np.argmin(np.abs(candidates - time_rel))
    return candidates[closest_idx]
|
600
|
+
|
601
|
+
|
602
|
+
def coupling_finder(rel_pks, fly_pks, win, verbose=False):
    """Pair each reference peak with its nearest candidate peak within ``win``.

    Parameters
    ----------
    rel_pks : sized iterable of float
        Reference event times.
    fly_pks : array-like of float
        Candidate event times, in the same units as ``rel_pks``.
    win : float
        Maximum absolute time difference for a pair to count as coupled.
    verbose : bool, optional
        If True, print the coupling rate.

    Returns
    -------
    pks_cp_rel : list
        Reference times that have a candidate within ``win``.
    pks_cp_fly : list
        The nearest candidate matched to each of them.
    cp_rate : float
        Percentage of reference events that are coupled. It is 0 when none
        are, including when ``rel_pks`` or ``fly_pks`` is empty.
    """
    fly = np.asarray(fly_pks)
    pks_cp_rel = []
    pks_cp_fly = []

    # Without candidates nothing can couple; np.argmin would fail on an empty array.
    if fly.size:
        for rel_pk in rel_pks:
            closest_pt = find_closest_timestamp(rel_pk, fly)
            if abs(closest_pt - rel_pk) <= win:
                pks_cp_rel.append(rel_pk)
                pks_cp_fly.append(closest_pt)

    # Calculate coupling rate
    cp_rate = (len(pks_cp_rel) / len(rel_pks)) * 100 if pks_cp_rel else 0

    if verbose:
        print(f"Coupling Rate: {cp_rate}%")

    return pks_cp_rel, pks_cp_fly, cp_rate
|
622
|
+
|
623
|
+
def perm_circ(data_1d, rng=None):
    """Circularly shift ``data_1d`` by a random offset.

    Parameters
    ----------
    data_1d : array-like (1-D)
        Values to shift.
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, the legacy global NumPy RNG
        (``np.random.randint``) is used, so ``np.random.seed`` still controls
        the result.

    Returns
    -------
    numpy.ndarray
        The shifted copy. Empty input returns an empty array instead of
        failing in ``randint(0)``.
    """
    n = len(data_1d)
    if n == 0:
        return np.roll(data_1d, 0)
    shift = np.random.randint(n) if rng is None else int(rng.integers(n))
    return np.roll(data_1d, shift)
|
627
|
+
|
628
|
+
|
629
|
+
def coupling_permutate(rel_pks, fly_pks, win, n_perm=1000):
    """Build a surrogate (shuffled) distribution of coupling results.

    Each surrogate takes the inter-event intervals of the sorted ``rel_pks``,
    circularly permutes them with ``perm_circ``, and rebuilds the event times
    starting from the first event. This keeps the number of events, their
    intervals and their total span, but moves the events relative to
    ``fly_pks``. Each surrogate is then scored with ``coupling_finder``.

    Why intervals: rolling the timestamps themselves only reorders the same
    set of times, and ``coupling_finder``'s rate depends only on that set, so
    every surrogate would just reproduce the observed rate.

    Parameters
    ----------
    rel_pks : array-like of float
        Reference event times.
    fly_pks : array-like of float
        Candidate event times.
    win : float
        Coupling window, forwarded to ``coupling_finder``.
    n_perm : int, optional
        Number of surrogates (default 1000).

    Returns
    -------
    pks_cp_rel_shuf, pks_cp_fly_shuf, cp_rate_shuf : list
        Per-surrogate outputs of ``coupling_finder``; each list has length
        ``n_perm``.
    """
    rel_sorted = np.sort(np.asarray(rel_pks))
    intervals = np.diff(rel_sorted)

    pks_cp_rel_shuf, pks_cp_fly_shuf, cp_rate_shuf = [], [], []
    for _ in range(n_perm):
        if intervals.size:
            shuffled_intervals = perm_circ(intervals)
            pks_rel_shuf = rel_sorted[0] + np.concatenate(([0], np.cumsum(shuffled_intervals)))
        else:
            # Fewer than two events: there are no intervals to permute.
            pks_rel_shuf = rel_sorted
        pks_cp_rel_tmp, pks_cp_fly_tmp, cp_rate_tmp = coupling_finder(
            rel_pks=pks_rel_shuf, fly_pks=fly_pks, win=win, verbose=False
        )
        pks_cp_rel_shuf.append(pks_cp_rel_tmp)
        pks_cp_fly_shuf.append(pks_cp_fly_tmp)
        cp_rate_shuf.append(cp_rate_tmp)
    return pks_cp_rel_shuf, pks_cp_fly_shuf, cp_rate_shuf
|
641
|
+
|
642
|
+
def detect_spindles(data, opt):
    """Detect sleep spindles with a multi-threshold envelope method.

    Pipeline (thresholds are multiples of the envelope SD):
      1. Band-pass ``data`` to ``opt["spin"]["freq_range"]`` and take the
         Hilbert amplitude envelope.
      2. Thr[0] crossings whose duration lies strictly inside
         ``opt["spin"]["dur_sec"]`` and that fall strictly inside a bout of
         the requested sleep stage are candidates.
      3. With a 2nd threshold, a candidate is kept only if it strictly
         contains a Thr[1] crossing lasting between half the minimum and the
         maximum duration.
      4. With a 3rd threshold, the envelope must reach Thr[2] inside the
         candidate.
      5. Events separated by <= 50 ms are merged. Events ending within the
         last half epoch are dropped. With >= 2 thresholds, events longer
         than the maximum duration are dropped.

    Parameters
    ----------
    data : 1-D array-like
        Raw signal.
    opt : dict or pandas.DataFrame
        Configuration. Reads ``opt["info"]`` keys ``fs``, ``epoch_dur`` and
        ``dir_score``, and ``opt["spin"]`` keys ``filt_ord``, ``freq_range``,
        ``smth``, ``thr``, ``dur_sec``, ``stage``, ``extract`` and
        ``win_size``. Example::

            opt = pd.DataFrame({
                "spin": {"thr": [1.5, 2, 2.5], "dur_sec": [0.5, 2.5],
                         "freq_range": [11, 16], "stage": 2, "smth": False,
                         "filt_ord": ..., "extract": ..., "win_size": ...},
                "info": {"fs": 1000, "epoch_dur": 10,
                         "dir_score": ".../R6Rec1_scoring.txt"},
            })

    Returns
    -------
    pandas.DataFrame
        One column ``"spin"`` with rows ``idx_start_stop`` (n x 2 sample
        indices), ``num``, ``density`` (events / min of stage time, NaN if
        no stage time), ``thr``, ``freq``, ``pow``, ``pk2pk``, ``pk2pk_loc``
        ([loc_min, val_min, loc_max, val_max] per event), ``max_pks_loc``,
        ``avg_freq``, ``win_size`` and ``waveform`` ([] unless extracted).

    Raises
    ------
    ValueError
        If ``opt["spin"]["thr"]`` is empty.
    """
    fs = opt["info"]["fs"]
    epoch_dur = opt["info"]["epoch_dur"]
    spin_opt = opt["spin"]
    dur_sec = spin_opt["dur_sec"]

    # ---- band-pass + Hilbert envelope ------------------------------------
    amp_filt = filter_bandpass(
        ord=spin_opt["filt_ord"],
        data=data,
        freq_range=spin_opt["freq_range"],
        fs=fs,
    )
    amp_filt_env = np.abs(hilbert(amp_filt))

    # The SD behind the thresholds optionally comes from a 200-ms moving
    # average of the envelope; detection itself runs on the raw envelope.
    # NOTE(review): thresholds are k * SD without a mean offset - confirm intent.
    if spin_opt["smth"]:
        amp_filt_env_std = np.std(moving_average(amp_filt_env, int(0.2 * fs)))
    else:
        amp_filt_env_std = np.std(amp_filt_env)
    Thr = [amp_filt_env_std * m_std for m_std in spin_opt["thr"]]
    if len(Thr) < 1:
        raise ValueError("Didn not find the 1st spi.Thr")

    # ---- Thr1: crossings within the duration limits (strict) -------------
    a, b = detect_cross(amp_filt_env, Thr[0])
    rise_fall = np.column_stack((a, b))
    dura = np.diff(rise_fall, axis=1).ravel()
    Thr1Spin1 = rise_fall[
        (dur_sec[0] * fs < dura) & (dura < dur_sec[1] * fs)
    ].reshape(-1, 2)

    # ---- keep events strictly inside a bout of the requested stage -------
    score_code = extract_score(opt["info"]["dir_score"], scalar=fs * epoch_dur)
    stage_spin_idx = find_repeats(score_code, spin_opt["stage"], nGap=1)
    # discard bouts touching the first / last epoch of the recording
    stage_spin_idx = stage_spin_idx[
        (stage_spin_idx[:, 0] >= fs * epoch_dur)
        & (stage_spin_idx[:, 1] <= len(amp_filt) - fs * epoch_dur)
    ]
    in_stage = (
        (stage_spin_idx[:, 0] < Thr1Spin1[:, 0].reshape(-1, 1))
        & (Thr1Spin1[:, 1].reshape(-1, 1) < stage_spin_idx[:, 1])
    ).any(axis=1)
    EventsSpin1 = Thr1Spin1[in_stage]

    # ---- Thr2: candidate must strictly contain a Thr2 crossing -----------
    if len(Thr) >= 2:
        a, b = detect_cross(amp_filt_env, Thr[1])
        rise_fall = np.column_stack((a, b))
        SpinDura_min2 = dur_sec[0] / 2  # half of the minimum duration
        SpinDura_max2 = dur_sec[1]
        dura = np.diff(rise_fall, axis=1).ravel()
        EventsSpin2 = rise_fall[
            (SpinDura_min2 * fs <= dura) & (dura <= SpinDura_max2 * fs)
        ].reshape(-1, 2)
        contains_thr2 = (
            (EventsSpin1[:, 0] < EventsSpin2[:, 0].reshape(-1, 1))
            & (EventsSpin2[:, 1].reshape(-1, 1) < EventsSpin1[:, 1])
        ).any(axis=0)
        EventsSpin3 = EventsSpin1[contains_thr2]
    else:
        # Only one threshold: every in-stage Thr1 event is a candidate
        # (testing the events for strict containment in themselves would
        # reject all of them).
        EventsSpin3 = EventsSpin1
    if EventsSpin3.shape[0] != 0:
        # de-duplicate and sort by start
        EventsSpin3 = np.unique(EventsSpin3[:, 0:2], axis=0)

    # ---- Thr3: envelope must reach Thr[2] inside the event ---------------
    if len(Thr) >= 3:
        reaches_thr3 = [
            np.max(amp_filt_env[start:stop]) >= Thr[2] for start, stop in EventsSpin3
        ]
        EventsSpin4 = EventsSpin3[np.array(reaches_thr3, dtype=bool)].reshape(-1, 2)
    else:
        EventsSpin4 = EventsSpin3.copy()
        print("\ncannot find the 3rd Thr_spin, only 2 Thr were used for spin dtk \n")

    # ---- merge events whose gap is <= 50 ms (events are sorted by start) -
    merged = []
    for start, stop in EventsSpin4:
        if merged and (start - merged[-1][1]) / fs <= 0.05:
            merged[-1][1] = max(merged[-1][1], stop)
        else:
            merged.append([start, stop])
    EventsSpin5 = np.array(merged, dtype=EventsSpin4.dtype).reshape(-1, 2)

    # ---- drop events that end within the last half epoch -----------------
    Last5s = len(amp_filt) - epoch_dur / 2 * fs  # in samples
    EventsSpin6 = EventsSpin5[EventsSpin5[:, 1] <= Last5s]

    # ---- events must not be longer than the maximum duration -------------
    if len(Thr) >= 2:
        EventsSpin = EventsSpin6[
            np.diff(EventsSpin6, axis=1).ravel() <= SpinDura_max2 * fs
        ]
    else:
        EventsSpin = EventsSpin6.copy()
    EventsSpin = EventsSpin.reshape(-1, 2)

    # ---- density: events per minute of selected-stage time ---------------
    stage_minutes = np.sum(np.diff(stage_spin_idx, axis=1)) / fs / 60
    SpinDensity = EventsSpin.shape[0] / stage_minutes if stage_minutes > 0 else np.nan

    # ---- frequency: peaks above Thr[0] per second of event ---------------
    num_spin_pks = [
        len(find_peaks(amp_filt[start:stop], height=Thr[0])[0])
        for start, stop in EventsSpin
    ]
    dur_spin_sec = np.diff(EventsSpin, axis=1).ravel() / fs
    spin_freq = (np.asarray(num_spin_pks, dtype=float) / dur_spin_sec).tolist()
    spin_avg_freq = np.nanmean(spin_freq, axis=0)

    # ---- power: trapezoidal area under the envelope of each event --------
    spin_pow_single = [
        np.trapz(amp_filt_env[start:stop]) for start, stop in EventsSpin
    ]

    # ---- largest local maximum / smallest local minimum per event --------
    spin_pk2pk, spin_pks_loc, loc_max_spin = [], [], []
    for start, stop in EventsSpin:
        seg = amp_filt[start : stop + 1]  # +1 to include the end index
        locs_max, _ = find_peaks(seg)
        locs_min, _ = find_peaks(-seg)  # troughs are peaks of the negated trace
        # fall back to the global extremum if there is no interior local one
        if len(locs_max) == 0:
            locs_max = np.array([np.argmax(seg)])
        if len(locs_min) == 0:
            locs_min = np.array([np.argmin(seg)])
        i_max = locs_max[np.argmax(seg[locs_max])]
        i_min = locs_min[np.argmin(seg[locs_min])]
        pks_max_spin, pks_min_spin = seg[i_max], seg[i_min]
        loc_max_spin_, loc_min_spin_ = i_max + start, i_min + start
        loc_max_spin.append(loc_max_spin_)
        spin_pk2pk.append(pks_max_spin - pks_min_spin)
        spin_pks_loc.append([loc_min_spin_, pks_min_spin, loc_max_spin_, pks_max_spin])
    spin_pk2pk = np.array(spin_pk2pk)
    spin_pks_loc = np.array(spin_pks_loc)

    if spin_opt["extract"] and len(loc_max_spin) > 0:
        waveform = extract_peaks_waveforms(
            data=data, pks=loc_max_spin, win_size=spin_opt["win_size"]
        )
    else:
        waveform = []
    print(f"detected {EventsSpin.shape[0]} spindles, density={SpinDensity}")

    # filling output
    res = pd.DataFrame(
        {
            "spin": {
                "idx_start_stop": EventsSpin,
                "num": EventsSpin.shape[0],
                "density": SpinDensity,
                "thr": Thr,
                "freq": spin_freq,
                "pow": spin_pow_single,
                "pk2pk": spin_pk2pk,
                "pk2pk_loc": spin_pks_loc,
                "max_pks_loc": loc_max_spin,
                "avg_freq": spin_avg_freq,
                "win_size": spin_opt["win_size"],
                "waveform": waveform,
            }
        }
    )
    return res
|
936
|
+
|
937
|
+
def detect_sos(data, opt):
    """Detect slow oscillations (SOs) between consecutive falling zero crossings.

    Candidates are [falling_k, falling_k+1] zero-crossing windows of the
    band-passed signal that lie strictly inside a bout of the requested
    stage and whose duration is strictly inside ``opt["sos"]["dur_sec"]``.
    A candidate is kept when:
      * |negative peak| > |thr[0]|, and
      * if a second threshold exists, |pk2pk| > thr[1].

    If ``opt["sos"]["thr"]`` is empty (or None), both thresholds are the
    ``n_prctile_amplitude`` percentile of the candidates' |negative peak|
    and peak-to-peak amplitudes.

    Parameters
    ----------
    data : 1-D array-like
        Raw signal.
    opt : dict or pandas.DataFrame
        Reads ``opt["info"]`` keys ``fs``, ``epoch_dur`` and ``dir_score``,
        and ``opt["sos"]`` keys ``filt_ord``, ``freq_range``, ``stage``,
        ``dur_sec``, ``thr``, ``n_prctile_amplitude`` (only when ``thr`` is
        empty), ``extract`` and ``win_size``.

    Returns
    -------
    pandas.DataFrame
        One column ``"sos"`` with rows ``idx_start_stop``, ``num``,
        ``density`` (events / min of stage time, NaN if none), ``thr``
        (negative-peak threshold stored as a negative number), ``slope``
        (amplitude per sample from trough to peak), ``freq``,
        ``pks_neg_idx``, ``pks_pos_idx``, ``pk2pk``, ``pks_loc``
        ([neg_idx, neg_val, pos_idx, pos_val]), ``pow``, ``win_size`` and
        ``waveform``.

    Raises
    ------
    ValueError
        If ``opt["sos"]["thr"]`` has more than two elements.
    """
    fs = opt["info"]["fs"]
    epoch_dur = opt["info"]["epoch_dur"]
    sos_opt = opt["sos"]

    amp_filt = filter_bandpass(
        ord=sos_opt["filt_ord"],
        data=data,
        freq_range=sos_opt["freq_range"],
        fs=fs,
    )

    # ---- candidate windows: [falling1, falling2; falling2, falling3; ...] -
    _, fall_b4 = detect_cross(amp_filt, 0)
    fall_b4 = np.asarray(fall_b4, dtype=float).ravel()
    loc_cross = np.column_stack((fall_b4[:-1], fall_b4[1:]))  # (0, 2) if < 2 crossings

    # ---- restrict to the requested sleep stage ---------------------------
    score_code = extract_score(
        opt["info"]["dir_score"], kron_go=True, scalar=fs * epoch_dur
    )
    so_stage_idx = find_repeats(score_code, sos_opt["stage"], nGap=1)
    # discard bouts touching the first / last epoch of the recording
    so_stage_idx = so_stage_idx[
        (so_stage_idx[:, 0] >= fs * epoch_dur)
        & (so_stage_idx[:, 1] <= len(data) - fs * epoch_dur)
    ]
    in_stage = (
        (so_stage_idx[:, 0] < loc_cross[:, 0].reshape(-1, 1))
        & (loc_cross[:, 1].reshape(-1, 1) < so_stage_idx[:, 1])
    ).any(axis=1)
    sos_kndt_Loc = loc_cross[in_stage]

    # ---- duration limits (strict), then integer sample indices -----------
    dur_sec = sos_opt["dur_sec"]
    dura = np.diff(sos_kndt_Loc, axis=1).ravel()
    event_sos_loc = (
        sos_kndt_Loc[(dur_sec[0] * fs < dura) & (dura < dur_sec[1] * fs)]
        .astype(int)
        .reshape(-1, 2)
    )

    # ---- positive / negative peak of every candidate ---------------------
    n_cand = event_sos_loc.shape[0]
    cand_pos_idx = np.zeros(n_cand, dtype=int)
    cand_neg_idx = np.zeros(n_cand, dtype=int)
    cand_neg_value = np.zeros(n_cand)
    cand_pk2pk = np.zeros(n_cand)
    for k, (start, stop) in enumerate(event_sos_loc):
        seg = amp_filt[start:stop]
        i_max = int(np.argmax(seg))  # first occurrence of the maximum
        i_min = int(np.argmin(seg))  # first occurrence of the minimum
        cand_pos_idx[k] = start + i_max
        cand_neg_idx[k] = start + i_min
        cand_neg_value[k] = seg[i_min]
        cand_pk2pk[k] = seg[i_max] + np.abs(seg[i_min])

    # ---- thresholds ------------------------------------------------------
    thr_cfg = sos_opt["thr"]
    if thr_cfg is None or len(thr_cfg) == 0:
        n_prctile_amplitude = sos_opt["n_prctile_amplitude"]
        if n_cand:
            thr_negpks_amp = np.percentile(
                np.abs(cand_neg_value), n_prctile_amplitude, axis=0
            )
            thr_pks2pks = np.percentile(cand_pk2pk, n_prctile_amplitude, axis=0)
        else:
            # np.percentile cannot handle an empty array
            thr_negpks_amp = thr_pks2pks = np.nan
        sos_thr = np.array([-thr_negpks_amp, thr_pks2pks])
    elif len(thr_cfg) == 1:
        sos_thr = np.array([-abs(thr_cfg[0])])
    elif len(thr_cfg) == 2:
        sos_thr = np.array([-abs(thr_cfg[0]), thr_cfg[1]])
    else:
        raise ValueError(
            f"opt['sos']['thr'] must have 0, 1 or 2 elements, got {len(thr_cfg)}"
        )

    # sos_thr[0] is stored as a negative amplitude, so compare magnitudes
    # (abs(x) > negative number would always be True).
    keep = np.abs(cand_neg_value) > np.abs(sos_thr[0])
    if len(sos_thr) == 2:
        keep &= np.abs(cand_pk2pk) > sos_thr[1]

    sos_idx = event_sos_loc[keep]
    sos_pks_neg_idx = cand_neg_idx[keep]
    sos_pks_pos_idx = cand_pos_idx[keep]
    sos_pks_loc = np.column_stack(
        (
            sos_pks_neg_idx,
            amp_filt[sos_pks_neg_idx],
            sos_pks_pos_idx,
            amp_filt[sos_pks_pos_idx],
        )
    )
    sos_pk2pk = sos_pks_loc[:, 3] - sos_pks_loc[:, 1]

    # ---- density: events per minute of selected-stage time ---------------
    stage_minutes = np.sum(np.diff(so_stage_idx, axis=1)) / fs / 60
    sos_density = len(sos_idx) / stage_minutes if stage_minutes > 0 else np.nan

    # ---- power (area under the filtered trace) and frequency -------------
    sos_power = np.zeros((len(sos_idx), 1))
    for i, (start, stop) in enumerate(sos_idx):
        sos_power[i] = np.trapz(amp_filt[start:stop])
    sos_freq = 1 / (np.diff(sos_idx, axis=1) / fs).flatten()

    # ---- slope: (positive peak - negative peak) / samples between them ---
    sos_slope = (amp_filt[sos_pks_pos_idx] - amp_filt[sos_pks_neg_idx]) / (
        sos_pks_pos_idx - sos_pks_neg_idx
    )

    if sos_opt["extract"] and len(sos_pks_neg_idx) > 0:
        waveform = extract_peaks_waveforms(
            data=data, pks=sos_pks_neg_idx, win_size=sos_opt["win_size"]
        )
    else:
        waveform = []
    print(f"detected {sos_idx.shape[0]} SOs, density={sos_density}")

    # filling output
    res = pd.DataFrame(
        {
            "sos": {
                "idx_start_stop": sos_idx,
                "num": sos_idx.shape[0],
                "density": sos_density,
                "thr": sos_thr,
                "slope": sos_slope,
                "freq": sos_freq,
                "pks_neg_idx": sos_pks_neg_idx,
                "pks_pos_idx": sos_pks_pos_idx,
                "pk2pk": sos_pk2pk,
                "pks_loc": sos_pks_loc,
                "pow": sos_power,
                "win_size": sos_opt.get("win_size"),
                "waveform": waveform,
            },
        }
    )
    return res
|
1151
|
+
|
1152
|
+
def detect_ripples(data, opt):
    """Detect ripples with a multi-threshold envelope method.

    Pipeline (thresholds are multiples of the envelope SD):
      1. Band-pass ``data`` to ``opt["rip"]["freq_range"]`` and take the
         Hilbert amplitude envelope.
      2. Thr[0] crossings whose duration lies strictly inside
         ``opt["rip"]["dur_sec"]`` and that fall strictly inside a bout of
         the requested sleep stage are candidates.
      3. With a 2nd threshold, a candidate must strictly contain a Thr[1]
         crossing lasting between half the minimum and the maximum duration.
      4. With a 3rd threshold, the envelope must reach Thr[2] inside it.

    Parameters
    ----------
    data : 1-D array-like
        Raw signal.
    opt : dict or pandas.DataFrame
        Reads ``opt["info"]`` keys ``fs``, ``epoch_dur`` and ``dir_score``,
        and ``opt["rip"]`` keys ``filt_ord``, ``freq_range``, ``smth``,
        ``thr``, ``dur_sec``, ``stage``, ``extract`` and ``win_size``.

    Returns
    -------
    pandas.DataFrame
        One column ``"rip"`` with rows ``idx_start_stop``, ``num``,
        ``density`` (events / min of stage time, NaN if none), ``thr``,
        ``freq``, ``pow``, ``pk2pk``, ``avg_freq``, ``max_pks_loc``,
        ``win_size`` and ``waveform``.

    Raises
    ------
    ValueError
        If ``opt["rip"]["thr"]`` is empty.
    """
    fs = opt["info"]["fs"]
    epoch_dur = opt["info"]["epoch_dur"]
    rip_opt = opt["rip"]
    rip_dura_sec = rip_opt["dur_sec"]

    # ---- band-pass + Hilbert envelope ------------------------------------
    amp_filt = filter_bandpass(
        ord=rip_opt["filt_ord"],
        data=data,
        freq_range=rip_opt["freq_range"],
        fs=fs,
    )
    amp_filt_env = np.abs(hilbert(amp_filt))
    # SD for the thresholds, optionally from a 200-ms moving average.
    # NOTE(review): thresholds are k * SD without a mean offset - confirm intent.
    if rip_opt["smth"]:
        amp_filt_env_std = np.std(moving_average(amp_filt_env, int(0.2 * fs)))
    else:
        amp_filt_env_std = np.std(amp_filt_env)
    Thr = np.array(rip_opt["thr"]) * amp_filt_env_std
    if len(Thr) < 1:
        raise ValueError("Didn not find the 1st spi.Thr")

    # ---- Thr1: crossings within the duration limits (strict) -------------
    a, b = detect_cross(amp_filt_env, Thr[0])
    rise_fall = np.column_stack((a, b))
    dura = np.diff(rise_fall, axis=1).ravel()
    Thr1rip1 = rise_fall[
        (rip_dura_sec[0] * fs < dura) & (dura < rip_dura_sec[1] * fs)
    ].reshape(-1, 2)

    # ---- keep events strictly inside a bout of the requested stage -------
    score_code = extract_score(opt["info"]["dir_score"], scalar=fs * epoch_dur)
    stage_rip_idx = find_repeats(score_code, rip_opt["stage"], nGap=1)
    # discard bouts touching the first / last epoch of the recording
    stage_rip_idx = stage_rip_idx[
        (stage_rip_idx[:, 0] >= fs * epoch_dur)
        & (stage_rip_idx[:, 1] <= len(amp_filt) - fs * epoch_dur)
    ]
    in_stage = (
        (stage_rip_idx[:, 0] < Thr1rip1[:, 0].reshape(-1, 1))
        & (Thr1rip1[:, 1].reshape(-1, 1) < stage_rip_idx[:, 1])
    ).any(axis=1)
    EventsRip1 = Thr1rip1[in_stage]

    # ---- Thr2: candidate must strictly contain a Thr2 crossing -----------
    if len(Thr) >= 2:
        a, b = detect_cross(amp_filt_env, Thr[1])
        rise_fall = np.column_stack((a, b))
        ripDura_min2 = rip_dura_sec[0] / 2  # half of the minimum duration
        ripDura_max2 = rip_dura_sec[1]
        dura = np.diff(rise_fall, axis=1).ravel()
        EventsRip2 = rise_fall[
            (ripDura_min2 * fs <= dura) & (dura <= ripDura_max2 * fs)
        ].reshape(-1, 2)
        contains_thr2 = (
            (EventsRip1[:, 0] < EventsRip2[:, 0].reshape(-1, 1))
            & (EventsRip2[:, 1].reshape(-1, 1) < EventsRip1[:, 1])
        ).any(axis=0)
        EventsRip3 = EventsRip1[contains_thr2]
    else:
        # Only one threshold: every in-stage Thr1 event is a candidate.
        EventsRip3 = EventsRip1
    if EventsRip3.shape[0] != 0:
        # de-duplicate and sort by start
        EventsRip3 = np.unique(EventsRip3[:, 0:2], axis=0)

    # ---- Thr3: envelope must reach Thr[2] inside the event ---------------
    if len(Thr) >= 3:
        reaches_thr3 = [
            np.max(amp_filt_env[start:stop]) >= Thr[2] for start, stop in EventsRip3
        ]
        EventsRip = EventsRip3[np.array(reaches_thr3, dtype=bool)]
    else:
        EventsRip = EventsRip3.copy()
        print("\ncannot find the 3rd Thr_rip, only 2 Thr were used for rip dtk \n")
    EventsRip = EventsRip.reshape(-1, 2)

    # ---- density: events per minute of selected-stage time ---------------
    stage_minutes = np.sum(np.diff(stage_rip_idx, axis=1)) / fs / 60
    rip_density = EventsRip.shape[0] / stage_minutes if stage_minutes > 0 else np.nan
    print(f"detected {EventsRip.shape[0]} ripples, density={rip_density}")

    # ---- frequency: peaks above Thr[0] per second of event ---------------
    num_rip_pks = [
        len(find_peaks(amp_filt[start:stop], height=Thr[0])[0])
        for start, stop in EventsRip
    ]
    dur_rip_sec = np.diff(EventsRip, axis=1).ravel() / fs
    rip_freq = (np.asarray(num_rip_pks, dtype=float) / dur_rip_sec).tolist()
    rip_avg_freq = np.nanmean(rip_freq, axis=0)
    print(f"Average ripple frequency: {rip_avg_freq:.4f} Hz")

    # ---- power: trapezoidal area under the envelope ----------------------
    rip_pow_single = [np.trapz(amp_filt_env[start:stop]) for start, stop in EventsRip]

    # ---- location of the maximum and peak-to-peak amplitude --------------
    loc_max_rip = [np.argmax(amp_filt[start:stop]) + start for start, stop in EventsRip]
    rip_pk2pk = np.array(
        [
            np.max(amp_filt[start:stop]) - np.min(amp_filt[start:stop])
            for start, stop in EventsRip
        ]
    )

    if rip_opt["extract"] and len(loc_max_rip) > 0:
        waveform = extract_peaks_waveforms(
            data=data, pks=loc_max_rip, win_size=rip_opt["win_size"]
        )
    else:
        waveform = []

    # filling output
    res = pd.DataFrame(
        {
            "rip": {
                "idx_start_stop": EventsRip,
                "num": EventsRip.shape[0],
                "density": rip_density,
                "thr": Thr,
                "freq": rip_freq,
                "pow": rip_pow_single,
                "pk2pk": rip_pk2pk,
                "avg_freq": rip_avg_freq,
                "max_pks_loc": loc_max_rip,
                "win_size": rip_opt.get("win_size"),
                "waveform": waveform,
            }
        }
    )
    return res
|