py2ls 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. py2ls/.git/COMMIT_EDITMSG +1 -0
  2. py2ls/.git/FETCH_HEAD +1 -0
  3. py2ls/.git/HEAD +1 -0
  4. py2ls/.git/config +15 -0
  5. py2ls/.git/description +1 -0
  6. py2ls/.git/hooks/applypatch-msg.sample +15 -0
  7. py2ls/.git/hooks/commit-msg.sample +24 -0
  8. py2ls/.git/hooks/fsmonitor-watchman.sample +174 -0
  9. py2ls/.git/hooks/post-update.sample +8 -0
  10. py2ls/.git/hooks/pre-applypatch.sample +14 -0
  11. py2ls/.git/hooks/pre-commit.sample +49 -0
  12. py2ls/.git/hooks/pre-merge-commit.sample +13 -0
  13. py2ls/.git/hooks/pre-push.sample +53 -0
  14. py2ls/.git/hooks/pre-rebase.sample +169 -0
  15. py2ls/.git/hooks/pre-receive.sample +24 -0
  16. py2ls/.git/hooks/prepare-commit-msg.sample +42 -0
  17. py2ls/.git/hooks/push-to-checkout.sample +78 -0
  18. py2ls/.git/hooks/update.sample +128 -0
  19. py2ls/.git/index +0 -0
  20. py2ls/.git/info/exclude +6 -0
  21. py2ls/.git/logs/HEAD +1 -0
  22. py2ls/.git/logs/refs/heads/main +1 -0
  23. py2ls/.git/logs/refs/remotes/origin/HEAD +1 -0
  24. py2ls/.git/logs/refs/remotes/origin/main +1 -0
  25. py2ls/.git/objects/25/b796accd261b9135fd32a2c00785f68edf6c46 +0 -0
  26. py2ls/.git/objects/36/b4a1b7403abc6c360f8fe2cb656ab945254971 +0 -0
  27. py2ls/.git/objects/3f/d6561300938afbb3d11976cf9c8f29549280d9 +0 -0
  28. py2ls/.git/objects/58/20a729045d4dc7e37ccaf8aa8eec126850afe2 +0 -0
  29. py2ls/.git/objects/60/f273eb1c412d916fa3f11318a7da7a9911b52a +0 -0
  30. py2ls/.git/objects/61/570cec8c061abe74121f27f5face6c69b98f99 +0 -0
  31. py2ls/.git/objects/69/13c452ca319f7cbf6a0836dc10a5bb033c84e4 +0 -0
  32. py2ls/.git/objects/78/3d4167bc95c9d2175e0df03ef1c1c880ba75ab +0 -0
  33. py2ls/.git/objects/79/7ae089b2212a937840e215276005ce76881307 +0 -0
  34. py2ls/.git/objects/7e/5956c806b5edc344d46dab599dec337891ba1f +1 -0
  35. py2ls/.git/objects/8e/55a7d2b96184030211f20c9b9af201eefcac82 +0 -0
  36. py2ls/.git/objects/91/c69ad88fe0ba94aa7859fb5f7edac5e6f1a3f7 +0 -0
  37. py2ls/.git/objects/b0/56be4be89ba6b76949dd641df45bb7036050c8 +0 -0
  38. py2ls/.git/objects/b0/9cd7856d58590578ee1a4f3ad45d1310a97f87 +0 -0
  39. py2ls/.git/objects/d9/005f2cc7fc4e65f14ed5518276007c08cf2fd0 +0 -0
  40. py2ls/.git/objects/df/e0770424b2a19faf507a501ebfc23be8f54e7b +0 -0
  41. py2ls/.git/objects/e9/391ffe371f1cc43b42ef09b705d9c767c2e14f +0 -0
  42. py2ls/.git/objects/fc/292e793ecfd42240ac43be407023bd731fa9e7 +0 -0
  43. py2ls/.git/refs/heads/main +1 -0
  44. py2ls/.git/refs/remotes/origin/HEAD +1 -0
  45. py2ls/.git/refs/remotes/origin/main +1 -0
  46. py2ls/.gitattributes +2 -0
  47. py2ls/.gitignore +152 -0
  48. py2ls/LICENSE +201 -0
  49. py2ls/README.md +409 -0
  50. py2ls/__init__.py +17 -0
  51. py2ls/brain_atlas.py +145 -0
  52. py2ls/correlators.py +475 -0
  53. py2ls/dbhandler.py +97 -0
  54. py2ls/freqanalysis.py +800 -0
  55. py2ls/internet_finder.py +405 -0
  56. py2ls/ips.py +2844 -0
  57. py2ls/netfinder.py +780 -0
  58. py2ls/sleep_events_detectors.py +1350 -0
  59. py2ls/translator.py +686 -0
  60. py2ls/version.py +1 -0
  61. py2ls/wb_detector.py +169 -0
  62. py2ls-0.1.0.dist-info/METADATA +12 -0
  63. py2ls-0.1.0.dist-info/RECORD +64 -0
  64. py2ls-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,1350 @@
+ import numpy as np
+ import pandas as pd
+ from scipy.signal import butter, hilbert, filtfilt, find_peaks, resample, resample_poly, decimate  # decimate is used by resample_data
+ from scipy.interpolate import interp1d
+ from scipy.ndimage import convolve1d
+ import os
+ import mne
+ import scipy.io
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ # from multitaper_spectrogram_python import *
+
+ def load_mat(dir_mat):
+     return scipy.io.loadmat(dir_mat)
+
+ # def loadmat(dir_mat):
+ #     return scipy.io.loadmat(dir_mat)
+
+ # Define data path
+ # data_file = "./UCLA_data/CSC4.Ncs"
+ def load_ncs(dir_file, header_size=16 * 1024):
+     # The Neuralynx header is 16 kB long
+     # header_size = 16 * 1024
+
+     # Open file
+     fid = open(dir_file, "rb")
+
+     # Skip header by shifting position by header size
+     fid.seek(header_size)
+
+     # Read data according to the Neuralynx record layout
+     data_format = np.dtype(
+         [
+             ("TimeStamp", np.uint64),
+             ("ChannelNumber", np.uint32),
+             ("SampleFreq", np.uint32),
+             ("NumValidSamples", np.uint32),
+             ("Samples", np.int16, 512),
+         ]
+     )
+
+     raw = np.fromfile(fid, dtype=data_format)
+     # Close file
+     fid.close()
+
+     # filling output
+     res = {}
+     res["data"] = raw["Samples"].ravel()  # Create data vector
+     res["fs"] = raw["SampleFreq"][0]  # Get sampling frequency
+     res["dur_sec"] = (
+         res["data"].shape[0] / raw["SampleFreq"][0]
+     )  # Determine duration of the recording in seconds
+     res["time"] = np.linspace(
+         0, res["dur_sec"], res["data"].shape[0]
+     )  # Create time vector
+     return pd.DataFrame(res)
+
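+ # usage sketch (added for illustration, not part of the original file; the path
+ # is the hypothetical Neuralynx file mentioned in the comment above):
+ # df = load_ncs("./UCLA_data/CSC4.Ncs")
+ # print(df["fs"][0], df["dur_sec"][0])  # sampling rate (Hz) and duration (s)
+ # plt.plot(df["time"], df["data"])
+ # plt.show()
+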
+ def ncs2_single_raw(fpath, ch_names=None, ch_types=None):
+     ncs_data = load_ncs(fpath)
+     data = ncs_data["data"]
+     sfreq = ncs_data["fs"][0]
+     if ch_names is None:
+         ch_names = [os.path.splitext(os.path.basename(fpath))[0]]
+     if ch_types is None:
+         ch_types = "eeg" if "eeg" in ch_names[0].lower() else "eog"  # should be 'lfp'
+     info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
+     data_nchan_nsamples = np.array(data)[np.newaxis, :]
+     raw = mne.io.RawArray(data_nchan_nsamples, info)
+     return raw
+
+ def trial_extract(trial):
+     # the .mat file should be saved in an older (pre-v7.3) format, otherwise it cannot be read
+     # recommended: mat_data = scipy.io.loadmat(dir_data)  # import scipy.io
+     # unwrap nested single-element arrays (type check fixed: compare with np.float64, not a string)
+     while isinstance(trial[0], np.float64) or (len(trial[0]) == 1):
+         trial = trial[0].copy()
+         print(trial[0].shape)
+     return trial[0]
+
+ # # dir_data = "/Users/macjianfeng/Desktop/test_v7_dat.mat"
+ # dir_data = "/Users/macjianfeng/Desktop/mat_r06rec1.mat"
+ # mat_data = scipy.io.loadmat(dir_data)
+
+ # trials = trial_extract(mat_data["trial"])
+ # fs = mat_data["fsample"][0][0]
+ # label = repr(mat_data["label"][0][0][0])  # convert to 'str'
+
+ # print("first 12 trials: ", trials[:12])
+ # print("fs=", fs)
+ # print("label=", label, "type(label):", type(label))
+ def cal_slope(data, segment=1, correct=True):
+     length = len(data)
+     slope = []
+     for i in range(0, length - segment, segment):
+         change_in_y = data[i + segment] - data[i]
+         change_in_x = segment
+         slope.append(change_in_y / change_in_x)
+     if correct:
+         # Interpolate the slopes to fill in the gaps
+         interpolated_slopes = np.repeat(slope, segment)
+         # Pad the slope trace so its length matches the length of the input data
+         missing_values = len(data) - len(interpolated_slopes)
+         if missing_values > 0:
+             interpolated_slopes = np.append(
+                 interpolated_slopes, [slope[-1]] * missing_values
+             )
+         return interpolated_slopes
+     else:
+         return slope
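+
+ # usage sketch (added for illustration, not part of the original file): the slope
+ # of a linear ramp is constant, so cal_slope should return ~2.0 everywhere
+ # y = np.arange(0, 20, 2)         # ramp rising by 2 per sample
+ # print(cal_slope(y, segment=1))  # array of 2.0 values, same length as y
+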
+ # Apply a bandpass filter to an EEG signal
+ def butter_band_filter(data, lowcut, highcut, fs, ord=3):
+     nyq = 0.5 * fs
+     low = lowcut / nyq
+     high = highcut / nyq
+     b, a = butter(ord, [low, high], btype="band")
+     dat_spin_filt = filtfilt(b, a, data)
+     return dat_spin_filt
+
+
+ def butter_bandpass_filter(data=None, ord=4, freq_range=[11, 16], fs=1000):
+     print("usage:\n butter_bandpass_filter(data=None, ord=4, freq_range=[11, 16], fs=1000)")
+     # alternative: pass the band edges and fs directly to butter()
+     b, a = butter(ord, freq_range, btype="bandpass", fs=fs)
+     data_filt = filtfilt(b, a, data)
+     return data_filt
+
+ def filter_bandpass(data=None, ord=4, freq_range=[11, 16], fs=1000):
+     # print("usage:\n butter_bandpass_filter(data=None, ord=4, freq_range=[11, 16], fs=1000)")
+     # alternative: pass the band edges and fs directly to butter()
+     b, a = butter(ord, freq_range, btype="bandpass", fs=fs)
+     data_filt = filtfilt(b, a, data)
+     return data_filt
+
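+ # usage sketch (added for illustration; fs, band and the synthetic signal are
+ # hypothetical): isolate the 11-16 Hz sigma band from a noisy 1-s trace
+ # fs = 1000
+ # t = np.arange(fs) / fs
+ # raw = np.sin(2 * np.pi * 13 * t) + 0.5 * np.random.randn(fs)
+ # sigma = filter_bandpass(data=raw, ord=4, freq_range=[11, 16], fs=fs)
+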
+ # Apply smoothing (moving average)
+ def moving_average(data, window_size):
+     return convolve1d(data, np.ones(window_size) / window_size)
+
+ def detect_cross(data, thr=0):
+     if isinstance(data, list):
+         data = np.array(data)
+     if data.ndim == 1:
+         pass
+     elif data.ndim == 2 and data.shape[0] > data.shape[1]:
+         data = data.T
+     else:
+         raise ValueError("Input data must be 1-D or a 2-D column vector.")
+
+     thr_cross = np.sign(data[:, np.newaxis] - thr)
+     falling_before = np.where((thr_cross[:-1] == 1) & (thr_cross[1:] == -1))[0]  # + 1
+     rising_before = np.where((thr_cross[:-1] == -1) & (thr_cross[1:] == 1))[0]
+     falling_before = falling_before.tolist()
+     rising_before = rising_before.tolist()
+     # pair each rising crossing with the following falling crossing
+     if rising_before and falling_before:
+         if rising_before[0] < falling_before[0]:
+             if len(rising_before) > len(falling_before):
+                 rising_before.pop(0)
+         else:
+             falling_before.pop(0)
+             if len(rising_before) > len(falling_before):
+                 rising_before.pop(0)
+     ## debug
+     # a = np.sin(np.arange(0, 10 * np.pi, np.pi / 100))
+
+     # thres = 0.75
+     # rise, fall = detect_cross(a, thres)
+     # RisingFalling = np.column_stack((rise, fall))
+     # plt.figure(figsize=[5, 2])
+     # t = np.arange(len(a))
+     # plt.plot(t, a)
+     # for i in range(4):
+     #     plt.plot(
+     #         t[RisingFalling[i][0] : RisingFalling[i][1]],
+     #         a[RisingFalling[i][0] : RisingFalling[i][1]],
+     #         lw=10 - i,
+     #     )
+     #     plt.plot(
+     #         t[RisingFalling[i][0] : RisingFalling[i + 1][0]],
+     #         a[RisingFalling[i][0] : RisingFalling[i + 1][0]],
+     #         lw=7 - i,
+     #     )
+     #     plt.plot(
+     #         t[RisingFalling[i][0] : RisingFalling[i + 1][1]],
+     #         a[RisingFalling[i][0] : RisingFalling[i + 1][1]],
+     #         lw=5 - i,
+     #     )
+     # plt.gca().axhline(thres)
+     return rising_before, falling_before
+
+ def find_repeats(data, N, nGap=None):
+     """
+     Find the beginning and end points of repeated occurrences in a dataset.
+
+     Parameters:
+         data (list or numpy.ndarray): The dataset in which repeated occurrences are to be found.
+         N (int or list of int): The element(s) to search for.
+         nGap (int, optional): The number of elements that can appear between repeated occurrences.
+             Defaults to 1 if not provided.
+
+     Returns:
+         numpy.ndarray: An array containing the beginning and end points of repeated occurrences
+             of the specified element(s).
+
+     Description:
+         This function identifies the beginning and end points of repeated occurrences
+         of specified elements in a dataset. It searches for the element(s) specified
+         by `N` in the input `data` and returns the indices where consecutive occurrences
+         are separated by at most `nGap` elements.
+
+     Example:
+         data = [1, 2, 3, 4, 1, 2, 5, 1, 2, 2, 3]
+         idx = find_repeats(data, [1, 2])
+         print(idx)  # Output: [[0, 2], [3, 5], [6, 8], [7, 9]]
+
+         idx = find_repeats(data, 2, 2)
+         print(idx)  # Output: [[4, 8]]
+     """
+     # Convert data to a numpy array if it's a list
+     if isinstance(data, list):
+         data = np.array(data)
+
+     if nGap is None:
+         nGap = 1
+
+     if isinstance(N, int):
+         N = [N]
+
+     idx = []
+     for num in N:
+         if num in data:
+             idx_beg = [
+                 i
+                 for i, x in enumerate(data)
+                 if x == num and (i == 0 or data[i - 1] != num)
+             ]
+             idx_end = [
+                 i - 1
+                 for i, x in enumerate(data)
+                 if x == num and (i == len(data) - 1 or data[i + 1] != num)
+             ]
+
+             # Adjust indices for Python's zero-based indexing
+             # idx_beg = [i + 1 for i in idx_beg]
+             idx_end = [i + 1 for i in idx_end]
+
+             idx_array = list(zip(idx_beg, idx_end))
+             # Widen single-element runs (start == end) in idx_array
+             idx_single = [
+                 i
+                 for i in range(len(idx_array))
+                 if idx_array[i][1] - idx_array[i][0] == 0
+             ]
+             for single in idx_single:
+                 idx_array[single] = (idx_array[single][0], idx_array[single][1] + 1)
+
+             if nGap == 1:
+                 idx.append([(beg, end) for beg, end in idx_array])
+             elif nGap > 1:
+                 idx.append([(beg * nGap + 1, end * nGap) for beg, end in idx_array])
+             else:
+                 idx.append([])
+
+     return np.concatenate(idx)
+
+ def find_continue(data, step=1):
+     """
+     Find indices for the beginning and end of continuous segments in data.
+
+     Parameters:
+         data (numpy.ndarray): Input array.
+         step (int): Comparison difference. Default is 1.
+
+     Returns:
+         tuple: Tuple containing arrays of indices for the beginning and end of continuous segments.
+     """
+     if isinstance(data, (list, np.ndarray)):
+         data = np.array(data)  # Ensure data is a numpy array
+     if not isinstance(step, int):
+         raise TypeError("step must be an integer")
+
+     idx_beg = np.where(np.diff([-99999] + data.tolist()) != step)[0]
+     idx_end = np.where(np.diff(data.tolist() + [-99999]) != step)[0]
+
+     return idx_beg, idx_end
+
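+ # usage sketch (added for illustration): locate runs of consecutive indices
+ # idx = np.array([2, 3, 4, 9, 10, 15])
+ # beg, end = find_continue(idx, step=1)
+ # print(beg, end)  # beg=[0 3 5], end=[2 4 5] -> runs 2-4, 9-10, and 15
+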
+ def resample_data(data, input_fs, output_fs, t=None, axis=0, window=None, domain='time', method='fourier', **kwargs):
+     """
+     Resample a signal to a target sampling frequency.
+
+     Parameters:
+         data (array-like):
+             The signal to be resampled.
+         input_fs (float):
+             The original sampling frequency of the signal.
+         output_fs (float):
+             The target sampling frequency.
+         # num int:
+         #     The number of samples in the resampled signal.
+         t : array_like, optional
+             If t is given, it is assumed to be the equally spaced sample positions associated with the
+             signal data in x.
+         axis : int, optional
+             The axis of x that is resampled. Default is 0.
+         window : array_like, callable, string, float, or tuple, optional
+             Specifies the window applied to the signal in the Fourier domain.
+         domain : string, optional
+             A string indicating the domain of the input x: 'time' treats the input x as
+             time-domain (default); 'freq' treats the input x as frequency-domain.
+     Returns:
+         array-like: The resampled signal.
+     """
+     if input_fs < output_fs:
+         # raise ValueError(f"Target freq = {output_fs} must be <= {input_fs} Hz (input_fs).")
+         # scipy's resample uses the Fourier method by default: the signal is upsampled in
+         # the frequency domain, low-pass filtered to remove aliasing, and finally
+         # resampled to the desired rate.
+         if method == 'fourier':
+             # Upsampling using the Fourier method
+             factor = output_fs / input_fs
+             num_samples_new = int(len(data) * factor)
+             data_resampled = resample(data, num_samples_new)
+             print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using Fourier method")
+         elif method == 'linear':
+             # Upsampling using linear interpolation
+             t_original = np.arange(len(data)) / input_fs
+             t_resampled = np.arange(len(data) * output_fs / input_fs) / output_fs
+             f = interp1d(t_original, data, kind='linear', fill_value='extrapolate')
+             data_resampled = f(t_resampled)
+             print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using linear interpolation")
+         elif method == 'spline':
+             # Upsampling using spline interpolation
+             t_original = np.arange(len(data)) / input_fs
+             t_resampled = np.arange(len(data) * output_fs / input_fs) / output_fs
+             f = interp1d(t_original, data, kind='cubic', fill_value='extrapolate')
+             data_resampled = f(t_resampled)
+             print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using spline interpolation")
+         elif method == 'sinc':
+             # Upsampling using windowed-sinc (polyphase) interpolation
+             upsample_factor = output_fs / input_fs
+             data_resampled = resample_poly(data, int(len(data) * upsample_factor), 1)
+             print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using sinc interpolation")
+         elif method == 'zero':
+             # Upsampling by sample repetition (zero-order hold)
+             upsample_factor = output_fs / input_fs
+             data_resampled = np.repeat(data, int(upsample_factor))
+             print(f"Original data {input_fs} Hz\nUpsampled data {output_fs} Hz using zero-padding")
+     elif input_fs == output_fs:
+         print(f"Input = output = {output_fs} Hz. \nNo resampling is performed.")
+         return data
+     else:
+         if method == 'fourier':
+             # Calculate the resampling factor
+             resampling_factor = input_fs / output_fs
+             # Calculate the new number of samples
+             num_samples_new = int(len(data) / resampling_factor)
+             # Perform resampling
+             data_resampled = resample(data, num_samples_new)
+             print(f"Original data {input_fs} Hz\nResampled data {output_fs} Hz using Fourier method")
+         elif method == 'decimate':
+             # Downsampling with decimate (an FIR anti-aliasing filter is applied internally)
+             decimation_factor = int(input_fs / output_fs)
+             data_resampled = decimate(data, decimation_factor, zero_phase=True)
+             print(f"Original data {input_fs} Hz\nDownsampled data {output_fs} Hz using decimation")
+
+     return data_resampled
+
+
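+ # usage sketch (added for illustration; frequencies are hypothetical): downsample
+ # a 1000-Hz trace to 250 Hz with the default Fourier method
+ # fs_in, fs_out = 1000, 250
+ # sig = np.random.randn(10 * fs_in)
+ # sig_ds = resample_data(sig, input_fs=fs_in, output_fs=fs_out)
+ # print(len(sig), len(sig_ds))  # 10000 2500
+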
+ def extract_score(dir_score, kron_go=True, scalar=10 * 1000):
+     df_score = pd.read_csv(dir_score, sep="\t")
+     score_val = df_score[df_score.columns[1]].values
+     if kron_go:
+         # expand each epoch-level score to sample resolution
+         score_val = np.kron(
+             score_val, np.ones((1, round(scalar)))
+         ).reshape(1, -1)
+     return score_val[0]
+
+ def plot_sleeparc(
+     ax,
+     score,
+     code_org=[1, 2, 3],
+     code_new=[1, 0, -1],
+     kind="patch",
+     c=["#474747", "#0C5DA5", "#0C5DA5"],
+ ):
+     score[np.where(score == code_org[0])] = code_new[0]
+     score[np.where(score == code_org[1])] = code_new[1]
+     score[np.where(score == code_org[2])] = code_new[2]
+     dist = code_new[0] - code_new[1]
+     colorlist = c
+     # patch method
+     if "pa" in kind.lower():
+         for i in range(score.shape[0]):
+             x = [i - 1, i, i, i - 1]
+             if score[i] == code_new[0]:
+                 y = [code_new[0], code_new[0], code_new[0] + dist, code_new[0] + dist]
+                 c = colorlist[0]
+             elif score[i] == code_new[1]:
+                 y = [code_new[1], code_new[1], code_new[0], code_new[0]]
+                 c = colorlist[1]
+             elif score[i] == code_new[2]:
+                 y = [code_new[2], code_new[2], code_new[1], code_new[1]]
+                 c = colorlist[2]
+             ax.fill(x, y, c=c, edgecolor="none")
+     # line method
+     if "l" in kind.lower():
+         ax.plot(score, c=colorlist[0])
+     return ax
+
+ def filter_linenoise(
+     data, fs, method="notch", bandwidth=2, n_components=None, random_state=None
+ ):
+     from scipy import signal
+
+     nyquist = fs / 2
+     freq = np.arange(1, 100) * 50
+     # keep harmonics strictly below Nyquist (w0 = 1 is invalid for iirnotch)
+     freq = list(freq[np.where(freq < nyquist)])
+     print(f"nyquist={nyquist}hz, cleaned freqs:{freq}")
+     if method == "notch":
+         for f0 in freq:
+             w0 = f0 / nyquist  # Normalized Frequency
+             # Design notch filter
+             b_, a_ = signal.iirnotch(w0, Q=bandwidth)
+             if f0 == freq[0]:
+                 clean_data = signal.lfilter(b_, a_, data)
+             else:
+                 clean_data = signal.lfilter(b_, a_, clean_data)
+     elif method == "bandstop":
+         # apply a zero-phase notch at each harmonic (iirnotch accepts one frequency at a time)
+         clean_data = data
+         for f0 in freq:
+             b_, a_ = signal.iirnotch(w0=f0, Q=bandwidth, fs=fs)
+             clean_data = signal.filtfilt(b_, a_, clean_data, axis=-1)
+     elif method == "ica":  # Independent Component Analysis (ICA)
+         from sklearn.decomposition import FastICA
+
+         ica = FastICA(n_components=n_components, random_state=random_state)
+         ica_components = ica.fit_transform(data.T).T
+         clean_data = ica.inverse_transform(ica_components)
+
+     return clean_data
+
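+ # usage sketch (added for illustration; fs and the synthetic signal are
+ # hypothetical): remove 50-Hz line noise and its harmonics with cascaded notches
+ # fs = 1000
+ # t = np.arange(10 * fs) / fs
+ # noisy = np.random.randn(t.size) + 0.3 * np.sin(2 * np.pi * 50 * t)
+ # clean = filter_linenoise(noisy, fs=fs, method="notch", bandwidth=30)
+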
+ def MAD(signal, mSTD=1.4826):
+     """
+     Calculate the Median Absolute Deviation (MAD) of a 1D array (signal).
+     Parameters:
+         signal: 1D array of data.
+         mSTD: Multiplier used to scale the MAD. The default value 1.4826 is the scaling
+             factor that makes the MAD an asymptotically consistent estimator of the standard
+             deviation under the assumption of a normal distribution.
+     Output:
+         MAD: The Median Absolute Deviation of the input signal.
+     Explanation:
+         The function calculates the median of the input signal using np.median(signal).
+         It then calculates the absolute deviations of each element in the signal from the median.
+         The median of these absolute deviations is then calculated using np.median().
+         Finally, the MAD is obtained by multiplying this median absolute deviation by the scaling factor mSTD.
+     """
+     # signal is a 1-d array
+     medians = np.median(signal)
+     MAD = mSTD * np.median(np.abs(signal - medians))
+     return MAD
+
+
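+ # usage sketch (added for illustration): for a standard normal sample the scaled
+ # MAD approximates the standard deviation
+ # x = np.random.randn(100_000)
+ # print(MAD(x))  # ~1.0
+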
+ def detect_spikes(data, Fs, mSTD=1.4826, direction="both"):
+     """
+     Purpose: This function detects spikes in the input data.
+     Parameters:
+         data: The input signal data.
+         Fs: The sampling rate of the signal in Hz.
+         mSTD: The multiplier used to compute the threshold. The default 1.4826 is the
+             scaling factor that relates the Median Absolute Deviation (MAD) to the
+             standard deviation of a Gaussian distribution; it is used with the MAD
+             to estimate the threshold.
+         direction: Specifies the direction of spikes to detect. Default is "both",
+             meaning it detects both positive and negative spikes.
+     Return:
+         spike_idx: Indices of the detected spikes in the input data.
+         time: Time stamps of the detected spikes in seconds (index / Fs).
+     Explanation:
+         The function computes the threshold (thr_) based on the Median Absolute Deviation
+         (MAD) of the input data.
+         It then detects spikes based on the specified direction using the computed threshold.
+         The detected spike indices and their corresponding time stamps are returned.
+     """
+     # Detect spikes
+     # s, t = detectSpikes(x, Fs) detects spikes in x, where Fs is the sampling
+     # rate (in Hz). The outputs s and t are column vectors of spike times in
+     # samples and ms, respectively. By convention the time of the zeroth
+     # sample is 0 ms.
+
+     # if is_dataframe(x):
+     #     pass
+     # mad_ = (x - x.mean()).abs().mean()
+     spike_idx = {}
+     t = {}
+     thr_ = MAD(data, mSTD=mSTD)
+     # thr_ = np.percentile(signal, 66.6)
+
+     # # alternative: thr_ is set to 8.0, and the threshold is calculated as 8 times
+     # # the square root of the variance of the first 10 seconds of data
+     # thr_ = mSTD * np.sqrt(x[col][: 10 * Fs].astype("float64"))
+
+     directions = ["both", "pos", "neg"]  # 'both': detect both positive and negative
+     if direction in "both":
+         time = [x / Fs for x in np.where(np.abs(data) >= thr_)]
+         spike_idx, _ = find_peaks(np.abs(data), height=thr_)
+     elif direction in "positive":
+         time = [x / Fs for x in np.where(data >= thr_)]
+         spike_idx, _ = find_peaks(data, height=thr_)
+     elif direction in "negative":
+         # mirror the signal so negative deflections become peaks (sign fixed: >=, not <=)
+         time = [x / Fs for x in np.where((-1) * data >= thr_)]
+         spike_idx, _ = find_peaks((-1) * data, height=thr_)
+
+     return spike_idx, time
+
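+ # usage sketch (added for illustration; the synthetic trace is hypothetical):
+ # detect deflections exceeding the MAD-based threshold
+ # Fs = 1000
+ # sig = np.random.randn(10 * Fs)
+ # sig[::Fs] += 10  # one large positive deflection per second
+ # spike_idx, spike_t = detect_spikes(sig, Fs, direction="both")
+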
+ def extract_spikes_waveforms(data, Fs, mSTD=1.4826, win=[-10, 20], direction="both"):
+     """
+     Purpose: This function extracts waveforms of detected spikes from the input data.
+     Parameters:
+         data: The input signal data.
+         Fs: The sampling rate of the signal in Hz.
+         mSTD: The multiplier used to compute the threshold. The default 1.4826 is the
+             scaling factor that relates the MAD to the standard deviation of a Gaussian
+             distribution; it is used with the Median Absolute Deviation (MAD) to estimate
+             the threshold.
+         win: The window around each detected spike to extract. It is specified as a
+             list [start_offset, end_offset], where start_offset is the number of samples before the
+             spike index and end_offset is the number of samples after the spike index.
+         direction: Specifies the direction of spikes to detect. Default is "both", meaning it
+             extracts waveforms for both positive and negative spikes.
+     Output:
+         waveforms: Extracted waveforms of the detected spikes.
+     It first detects spikes using the detect_spikes() function.
+     For each detected spike, it extracts the waveform within the specified window
+     around the spike index.
+     The extracted waveforms are stored in a NumPy array and returned.
+     """
+     spike_idx, _ = detect_spikes(data, Fs, mSTD=mSTD, direction=direction)
+     num_spikes = len(spike_idx)
+     win_start = win[0] + 1
+     win_end = win[1] + 1
+     waveform_length = win_end - win_start
+
+     waveforms = np.empty((num_spikes, waveform_length)) * np.nan
+
+     for i, idx in enumerate(spike_idx):
+         start_idx = int(idx + win_start)
+         end_idx = int(idx + win_end)
+
+         # Ensure the start and end indices are within bounds
+         if start_idx >= 0 and end_idx <= data.shape[0]:
+             waveforms[i, :] = data[start_idx:end_idx]
+
+     # Remove rows with NaN values (corresponding to spikes outside bounds)
+     waveforms = waveforms[~np.isnan(waveforms).all(axis=1)]
+
+     print(f"Extracted waveforms number: {waveforms.shape}")
+
+     return waveforms
+
+ def extract_peaks_waveforms(data, pks, win_size):
+     """
+     Extracts waveforms from data centered around peak indices.
+
+     Parameters:
+     - data (1d-array): The data array from which waveforms are to be extracted.
+     - pks (list): A list of peak indices.
+     - win_size: A tuple specifying the window size around each peak.
+       It should be of the form (start_offset, end_offset).
+
+     Returns:
+     - waveform_array: A 2D NumPy array containing the extracted waveforms.
+       Each row corresponds to a waveform.
+     """
+     waveforms = []
+     for i in pks:
+         start_index = int(i + win_size[0])
+         end_index = int(i + win_size[1])
+         waveforms.append(data[start_index:end_index])
+     waveform_array = np.array(waveforms)
+     return waveform_array
+
+ # usage: win_size = [-500, 500]
+ # waveform = extract_peaks_waveforms(
+ #     data=data_ds, pks=res_sos.sos.pks_neg_idx, win_size=win_size
+ # )
+ # Function to find the closest timestamp
+ def find_closest_timestamp(time_rel, time_fly):
+     closest_idx = np.argmin(np.abs(time_fly - time_rel))
+     return time_fly[closest_idx]
+
+
+ def coupling_finder(rel_pks, fly_pks, win, verbose=False):
+     pks_cp_rel = []
+     pks_cp_fly = []
+     for rel_pk in rel_pks:
+         closest_pt = find_closest_timestamp(rel_pk, fly_pks)
+         delta_t = closest_pt - rel_pk
+         if abs(delta_t) <= win:
+             pks_cp_rel.append(rel_pk)
+             pks_cp_fly.append(closest_pt)
+
+     # Calculate coupling rate
+     if not pks_cp_rel:
+         cp_rate = 0
+     else:
+         cp_rate = (len(pks_cp_rel) / len(rel_pks)) * 100
+
+     if verbose:
+         print(f"Coupling Rate: {cp_rate}%")
+
+     return pks_cp_rel, pks_cp_fly, cp_rate
+
+ def perm_circ(data_1d):
+     # perform a circular permutation (random circular shift)
+     permuted_1d = np.roll(data_1d, np.random.randint(len(data_1d)))
+     return permuted_1d
+
+
+ def coupling_permutate(rel_pks, fly_pks, win, n_perm=1000):
+     # Function to simulate the shuffled (e.g. SWR-shuffled) null condition
+     pks_cp_rel_shuf, pks_cp_fly_shuf, cp_rate_shuf = [], [], []
+     for _ in range(n_perm):
+         pks_rel_shuf = perm_circ(rel_pks)
+         pks_cp_rel_tmp, pks_cp_fly_tmp, cp_rate_tmp = coupling_finder(
+             rel_pks=pks_rel_shuf, fly_pks=fly_pks, win=win, verbose=False
+         )
+         pks_cp_rel_shuf.append(pks_cp_rel_tmp)
+         pks_cp_fly_shuf.append(pks_cp_fly_tmp)
+         cp_rate_shuf.append(cp_rate_tmp)
+     return pks_cp_rel_shuf, pks_cp_fly_shuf, cp_rate_shuf
+
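+ # usage sketch (added for illustration; the event times in seconds are
+ # hypothetical): fraction of reference events with a partner within +/- 0.1 s,
+ # compared against a circularly shuffled null
+ # rel = np.array([1.0, 2.0, 3.5, 7.2])
+ # fly = np.array([1.05, 3.45, 9.0])
+ # _, _, rate = coupling_finder(rel, fly, win=0.1, verbose=True)
+ # _, _, rates_null = coupling_permutate(rel, fly, win=0.1, n_perm=1000)
+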
+ def detect_spindles(data, opt):
+     # usage: prepare the opt cfg
+     # opt = pd.DataFrame(
+     #     {
+     #         "spin": {
+     #             "thr": [1.5, 2, 2.5],
+     #             "dur_sec": [0.5, 2.5],
+     #             "freq_range": [11, 16],
+     #             "stage": 2,
+     #             "smth": False,
+     #         },
+     #         "info": {
+     #             "fs": 1000,
+     #             "epoch_dur": 10,
+     #             "dir_score": "/Users/macjianfeng/DataCenter/Meflo-SSD/Data_Scored_Txt/R6Rec1_scoring.txt",
+     #         },
+     #     }
+     # )
+
+     fs = opt["info"]["fs"]
+     epoch_dur = opt["info"]["epoch_dur"]
+
+     # Filter
+     # amp_filt = butter_band_filter(
+     #     data=data,
+     #     lowcut=opt["spin"]["freq_range"][0],
+     #     highcut=opt["spin"]["freq_range"][1],
+     #     fs=fs,
+     # )
+     amp_filt = filter_bandpass(
+         ord=opt["spin"]["filt_ord"],
+         data=data,
+         freq_range=opt["spin"]["freq_range"],
+         fs=fs,
+     )
+
+     # Calculate amp_filt_env using the Hilbert transform
+     amp_filt_env = np.abs(hilbert(amp_filt))
+
+     # Apply additional smoothing (moving average with 200-ms window size)
+     if opt["spin"]["smth"]:
+         # Calculate mean and standard deviation of the smoothed amp_filt_env
+         amp_filt_env_mean = np.mean(moving_average(amp_filt_env, int(0.2 * fs)))
+         amp_filt_env_std = np.std(moving_average(amp_filt_env, int(0.2 * fs)))
+     else:
+         # Calculate mean and standard deviation of amp_filt_env
+         amp_filt_env_mean = np.mean(amp_filt_env)
+         amp_filt_env_std = np.std(amp_filt_env)
+
+     # 2.3 fill the thresholds into one matrix
+     Thr = []
+     for m_std in opt["spin"]["thr"]:
+         Thr.append(amp_filt_env_std * m_std)
+     # 2.4 use the defined thresholds
+     if len(Thr) >= 1:
+         a, b = detect_cross(amp_filt_env, Thr[0])
+     else:
+         raise ValueError("Did not find the 1st spin Thr")
+
+     RisingFalling = np.column_stack((a, b))
+     dur_sec = opt["spin"]["dur_sec"]
+     Dura_tmp = np.diff(RisingFalling, axis=1)
+     Thr1Spin1 = RisingFalling[
+         np.where((dur_sec[0] * fs < Dura_tmp) & (Dura_tmp < dur_sec[1] * fs)),
+         :,
+     ][0]
+     Thr1Spin1 = Thr1Spin1.reshape(-1, 2)
+
+     # 2.4.1.2 keep only the EventsSpin1 that fall in NREM (specific sleep stages)
+     score_code = extract_score(opt["info"]["dir_score"], scalar=fs * epoch_dur)
+     stage_spin_idx = find_repeats(
+         score_code, opt["spin"]["stage"], nGap=1
+     )
+     stage_spin_idx = stage_spin_idx[
+         np.where(
+             (stage_spin_idx[:, 0] >= fs * epoch_dur)
+             & (stage_spin_idx[:, 1] <= len(amp_filt) - fs * epoch_dur)
+         )
+     ]
+     EventsSpin1 = Thr1Spin1[
+         np.where(
+             (stage_spin_idx[:, 0] < Thr1Spin1[:, 0].reshape(-1, 1))
+             & (Thr1Spin1[:, 1].reshape(-1, 1) < stage_spin_idx[:, 1])
+         )[0],
+         :,
+     ]
+     # print("step1", EventsSpin1.shape)
+     # 2.4.2 Thr2 crossing
+     if len(Thr) >= 2:
+         a, b = detect_cross(amp_filt_env, Thr[1])
+         RisingFalling = np.column_stack((a, b))
+         SpinDura_min2 = dur_sec[0] / 2  # half of the minimum duration
+         SpinDura_max2 = dur_sec[1]
+         Dura_tmp = np.diff(RisingFalling, axis=1)
+         EventsSpin2 = RisingFalling[
+             np.where(
+                 (SpinDura_min2 * fs <= Dura_tmp) & (Dura_tmp <= SpinDura_max2 * fs)
+             ),
+             :,
+         ]
+     else:
+         EventsSpin2 = np.copy(EventsSpin1)
+     EventsSpin2 = EventsSpin2.reshape(-1, 2)
+     # print("step2", EventsSpin2.shape)
+     # 2.4.2.3 check EventsSpin2 in EventsSpin1
+     EventsSpin3 = np.empty((0, 2), dtype=int)  # array init so .shape works even with no events
+     if (
+         ("EventsSpin1" in locals())
+         and ("EventsSpin2" in locals())
+         and (EventsSpin1.shape[0] != 0)
+         and (EventsSpin2.shape[0] != 0)
+     ):
+         EventsSpin3 = EventsSpin1[
+             np.where(
+                 (EventsSpin1[:, 0] < EventsSpin2[:, 0].reshape(-1, 1))
+                 & (EventsSpin2[:, 1].reshape(-1, 1) < EventsSpin1[:, 1])
+             )[1],
+             :,
+         ]
+     # print("step3", EventsSpin3.shape)
+     # 2.4.2.4 unique EventsSpin3
+     if EventsSpin3.shape[0] != 0:
+         EventsSpin3_orgs = np.copy(EventsSpin3)
+         EventsSpin3 = np.unique(EventsSpin3_orgs[:, 0:2], axis=0)
+
+     if len(Thr) >= 3:
+         # 2.4.3 Crossing positions - Thr 3
+         EventsSpin4 = []
+         iSpi4 = 0
+         if EventsSpin3.shape[0] != 0:
+             for iSpi3 in range(EventsSpin3.shape[0]):
+                 if (
+                     np.max(amp_filt_env[EventsSpin3[iSpi3, 0] : EventsSpin3[iSpi3, 1]])
+                     >= Thr[2]
+                 ):
+                     EventsSpin4.append([EventsSpin3[iSpi3, 0], EventsSpin3[iSpi3, 1]])
+                     iSpi4 += 1
+             if isinstance(EventsSpin4, list):
+                 EventsSpin4 = np.array(EventsSpin4)
+                 EventsSpin4 = EventsSpin4.reshape(-1, 2)
+         else:
+             EventsSpin4 = EventsSpin3.copy()
+     else:
+         EventsSpin4 = EventsSpin3.copy()
+         print("\ncannot find the 3rd Thr_spin, only 2 Thr were used for spin detection\n")
+     # print("step4", EventsSpin4.shape)
+     # 2.5 check whether two spindles are too close; the gap should be more than 50 ms
+     if "EventsSpin4" in locals() and EventsSpin4.shape[0] != 0:
+         iSpin4 = 0
+         EventsSpin5 = []
+         for iSpin in range(1, EventsSpin4.shape[0]):
+             tmp_gap = (
+                 EventsSpin4[iSpin, 0] - EventsSpin4[iSpin - 1, 1]
+             ) / fs  # in seconds
+             if (
+                 tmp_gap <= 0.05
+             ):  # gap less than SpinDura_min; the total duration should not exceed SpinDura_max1
+                 EventsSpin5.append([EventsSpin4[iSpin - 1, 0], EventsSpin4[iSpin, 1]])
+             else:
+                 EventsSpin5.append(list(EventsSpin4[iSpin]))
+             iSpin4 += 1
+     else:
+         EventsSpin5 = EventsSpin4.copy()
+     if isinstance(EventsSpin5, list):
+         EventsSpin5 = np.array(EventsSpin5)
+     EventsSpin5 = EventsSpin5.reshape(-1, 2)
+     # print("step5", EventsSpin5.shape)
+     if "EventsSpin5" in locals():
+         # 2.5.2 merge overlapping events into one spindle
+         if EventsSpin5.shape[0] != 0:
+             EventsSpin5_diff_merge = np.where(
+                 np.diff(np.hstack((0, EventsSpin5[:, 0]))) == 0
+             )[0]
+             EventsSpin5_diff_rm = EventsSpin5_diff_merge - 1
+             EventsSpin5 = np.delete(
+                 EventsSpin5, EventsSpin5_diff_rm, axis=0
+             )  # remove the merged parts
+             del EventsSpin5_diff_rm, EventsSpin5_diff_merge
+
+             # 2.5.3 remove the last 5 s of the recording
+             RecTail = len(amp_filt)  # in sample resolution
+             Last5s = RecTail - epoch_dur / 2 * fs  # half epoch
+             if EventsSpin5.shape[0] != 0:
+                 for iSpin in range(EventsSpin5.shape[0], 0, -1):
+                     if EventsSpin5[iSpin - 1, 1] <= Last5s:
+                         EventsSpin6 = EventsSpin5[0:iSpin, :]
+                         break
+                     else:
+                         EventsSpin6 = EventsSpin5.copy()
+             else:
+                 EventsSpin6 = EventsSpin5.copy()
+         else:
+             EventsSpin6 = EventsSpin5.copy()
+     else:
+         EventsSpin6 = EventsSpin5.copy()
+     EventsSpin6 = EventsSpin6.reshape(-1, 2)
+     # print("step6", EventsSpin6.shape)
+     # 2.5.3 spin-to-spin duration should not exceed 'SpinDura_max2'
+     if len(Thr) >= 2:
+         EventsSpin = EventsSpin6[
+             np.where(np.diff(EventsSpin6, axis=1) <= SpinDura_max2 * fs)[0], :
+         ]
+     else:
+         EventsSpin = EventsSpin6.copy()
+     EventsSpin = EventsSpin.reshape(-1, 2)
+     # print("final detected spindle number: ", EventsSpin.shape)
+     # 2.6 Spindle density (counts during NREM), in events/min:
+     # the number of spindles detected at each recording site
+     # divided by the time spent in SWS.
+     SpinDensity = EventsSpin.shape[0] / (
+         np.sum(np.diff(stage_spin_idx, axis=1)) / fs / 60
+     )  # per minute
+     # print("spindle density: ", SpinDensity)
+
+     # Frequency of each spindle
+     num_spin_pks = []
+     for i in range(EventsSpin.shape[0]):
+         peaks, _ = find_peaks(
+             amp_filt[EventsSpin[i, 0] : EventsSpin[i, 1]], height=Thr[0]
+         )  # Assuming Thr is a scalar
+         num_spin_pks.append(len(peaks))
+
+     dur_spin_sec = np.diff(EventsSpin, axis=1) / fs
+     spin_freq = [
+         (x / y).tolist()[0] for (x, y) in zip(np.array(num_spin_pks), dur_spin_sec)
+     ]
+     spin_avg_freq = np.nanmean(spin_freq, axis=0)
+     # print(f"Average spindle frequency: {spin_avg_freq:.4f} Hz")
+
+     # Spindle Power
+     spin_pow_single = []
+     for iPow in range(EventsSpin.shape[0]):
+         # index with iPow (the original used a stale loop variable here)
+         spin_pow_single.append(np.trapz(amp_filt_env[EventsSpin[iPow, 0] : EventsSpin[iPow, 1]]))
+
+     # find the max pks loc
+     if EventsSpin.shape[0] > 1:
+         spin_pk2pk = []
+         spin_pks_loc = []
+         loc_max_spin = []
+         ipk = 0
+         for ispin in range(EventsSpin.shape[0]):
+             tmp = amp_filt[
+                 EventsSpin[ispin, 0] : EventsSpin[ispin, 1] + 1
+             ]  # +1 to include the end index
+             # (1) find pks_max
+             locs_max, _ = find_peaks(list(tmp))
+             pks_max = tmp[locs_max]
+             pks_max_spin = np.max(pks_max)
+             loc_max_spin_ = (
+                 locs_max[np.where(pks_max == pks_max_spin)[0][0]] + EventsSpin[ispin, 0]
+             )
+             loc_max_spin.append(loc_max_spin_)
+             # (2) find pks_min
+             pks_min = tmp[locs_max]
+             pks_min = pks_min * (-1)  # don't forget to multiply by -1
+             pks_min_spin = np.min(pks_min)
+             loc_min_spin_ = (
+                 locs_max[np.where(pks_min == pks_min_spin)[0][0]] + EventsSpin[ispin, 0]
+             )
+             # (3) spin_pk2pk
+             spin_pk2pk.append(pks_max_spin - pks_min_spin)
+             spin_pks_loc.append(
+                 [loc_min_spin_, pks_min_spin, loc_max_spin_, pks_max_spin]
+             )
+             ipk += 1
+         spin_pk2pk = np.array(spin_pk2pk)
+         spin_pks_loc = np.array(spin_pks_loc)
+     else:
+         spin_pks_loc = np.array([])
+         spin_pk2pk = np.array([])
+         loc_max_spin = []  # keep defined for the output below
+     if opt.spin["extract"]:
+         waveform = extract_peaks_waveforms(data=data, pks=loc_max_spin, win_size=opt.spin["win_size"])
+     else:
+         waveform = []  # keep 'waveform' defined when extraction is disabled
+     print(f"detected {EventsSpin.shape[0]} spindles, density={SpinDensity}")
+     # filling output
+     res = pd.DataFrame(
+         {
+             "spin": {
+                 "idx_start_stop": EventsSpin,
+                 "num": EventsSpin.shape[0],
+                 "density": SpinDensity,
+                 "thr": Thr,
+                 "freq": spin_freq,
+                 "pow": spin_pow_single,
+                 "pk2pk": spin_pk2pk,
+                 "pk2pk_loc": spin_pks_loc,
+                 "max_pks_loc": loc_max_spin,
+                 "avg_freq": spin_avg_freq,
+                 "win_size": opt.spin["win_size"],
+                 "waveform": waveform,
+             }
+         }
+     )
+     del amp_filt_env
+     del amp_filt
+     del data
+     return res
+
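+ # usage sketch (added for illustration, not part of the original file): besides the
+ # keys shown in the comment at the top of detect_spindles, the function also reads
+ # opt["spin"]["filt_ord"], opt.spin["extract"] and opt.spin["win_size"], e.g.:
+ # opt["spin"]["filt_ord"] = 4
+ # opt["spin"]["extract"] = True
+ # opt["spin"]["win_size"] = [-500, 500]
+ # res_spin = detect_spindles(data_ds, opt)  # data_ds: a 1-D LFP/EEG trace
+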
+ def detect_sos(data, opt):
+     fs = opt["info"]["fs"]
+     epoch_dur = opt["info"]["epoch_dur"]
+
+     amp_filt = filter_bandpass(
+         ord=opt["sos"]["filt_ord"],
+         data=data,
+         freq_range=opt["sos"]["freq_range"],
+         fs=fs,
+     )
+     # zero-crossings
+     rise_b4, fall_b4 = detect_cross(amp_filt, 0)
+     loc_cross = np.zeros((len(fall_b4) - 1, 2))
+     # [falling1, falling2; falling2, falling3; ...]
+     loc_cross[:, 0] = [x for x in fall_b4[:-1]]
+     loc_cross[:, 1] = [x for x in fall_b4[1:]]
+     # #+++++++++ check the loc_cross +++++++++
+     # t = np.arange(len(amp_filt))
+     # plt.figure(figsize=[6, 2])
+     # plt.plot(
+     #     t[int(loc_cross[0][0] - 50) : int(loc_cross[0][1] + 50)],
+     #     amp_filt[int(loc_cross[0][0] - 50) : int(loc_cross[0][1] + 50)],
+     #     lw=0.75,
+     # )
+     # plt.plot(
+     #     t[int(loc_cross[0][0]) : int(loc_cross[0][1])],
+     #     amp_filt[int(loc_cross[0][0]) : int(loc_cross[0][1])],
+     #     lw=1.5,
+     #     c="r",
+     # )
+     # plt.axhline(0, lw=0.75)
+     # plt.show()
+     # # SO candidates within the NREM_idx time frame
+     score_code = extract_score(
+         opt["info"]["dir_score"], kron_go=True, scalar=fs * epoch_dur
+     )
+     so_stage_idx = find_repeats(score_code, opt["sos"]["stage"], nGap=1)
+     so_stage_idx = so_stage_idx[np.where(so_stage_idx[:, 0] >= fs * epoch_dur)]
+     so_stage_idx = so_stage_idx[
+         np.where(so_stage_idx[:, 1] <= (len(data) - fs * epoch_dur))
+     ]
+     sos_kndt_Loc = loc_cross[
+         np.where(
+             (so_stage_idx[:, 0] < loc_cross[:, 0].reshape(-1, 1))
+             & (loc_cross[:, 1].reshape(-1, 1) < so_stage_idx[:, 1])
+         )[0],
+         :,
+     ]
+
+     dur_sec = opt["sos"]["dur_sec"]
+     dura_tmp = np.diff(sos_kndt_Loc, axis=1)
+     event_sos_loc = sos_kndt_Loc[
+         np.where((dur_sec[0] * fs < dura_tmp) & (dura_tmp < dur_sec[1] * fs))[0].tolist(), :
+     ]
+     event_sos_loc = np.array(
+         [(int(x), int(y)) for (x, y) in np.array(event_sos_loc).reshape(-1, 2)]
+     ).reshape(
+         -1, 2
+     )  # int
+     sos_pks_pos_idx = []
+     sos_pks_neg_idx = []
+     sos_pks_neg_value = []
+     sos_pk2pk = []
+
+     for iso in range(event_sos_loc.shape[0]):
+         # max
+         sos_max_tmp = np.max(amp_filt[event_sos_loc[iso, 0] : event_sos_loc[iso, 1]])
+         sos_pks_idx_max_tmp = list(
+             amp_filt[event_sos_loc[iso, 0] : event_sos_loc[iso, 1]]
+         ).index(sos_max_tmp)
+         sos_pks_pos_idx.append(int(event_sos_loc[iso, 0] + sos_pks_idx_max_tmp))
+         # min
+         sos_min_tmp = np.min(amp_filt[event_sos_loc[iso, 0] : event_sos_loc[iso, 1]])
+         sos_pks_idx_min_tmp = list(
+             amp_filt[event_sos_loc[iso, 0] : event_sos_loc[iso, 1]]
+         ).index(sos_min_tmp)
+         sos_pks_neg_idx.append(int(event_sos_loc[iso, 0] + sos_pks_idx_min_tmp))
+         sos_pks_neg_value.append(sos_min_tmp)
+         # pk2pk
+         sos_pk2pk.append(sos_max_tmp + np.abs(sos_min_tmp))
+     if isinstance(sos_pks_neg_value, list):
+         sos_pks_neg_value = np.array(sos_pks_neg_value)
+         sos_pk2pk = np.array(sos_pk2pk)
+     if opt["sos"]["thr"] == []:
+         n_prctile_amplitude = opt["sos"]["n_prctile_amplitude"]
+         thr_negpks_amp = np.percentile(
+             np.abs(sos_pks_neg_value), n_prctile_amplitude, axis=0
+         )
+         thr_pks2pks = np.percentile(sos_pk2pk, n_prctile_amplitude, axis=0)
+         sos_thr = np.array([-thr_negpks_amp, thr_pks2pks])
+     else:
+         if len(opt["sos"]["thr"]) == 1:
+             thr_negpks_amp = abs(opt["sos"]["thr"][0])
+             sos_thr = np.array([-thr_negpks_amp])
+         elif len(opt["sos"]["thr"]) == 2:
+             thr_negpks_amp = abs(opt["sos"]["thr"][0])
+             thr_pks2pks = opt["sos"]["thr"][1]
+             sos_thr = np.array([-thr_negpks_amp, thr_pks2pks])
+     ithr = 1
+     sos_loc = []
+     if "event_sos_loc" in locals() and event_sos_loc.shape[0] != 0:
+         for iso in range(sos_pk2pk.shape[0]):
+             thr_criterion = (
+                 abs(sos_pks_neg_value[iso]) > sos_thr[0]
+                 and abs(sos_pk2pk[iso]) > sos_thr[1]
+                 if len(sos_thr) == 2
+                 else abs(sos_pks_neg_value[iso]) > sos_thr[0]
+             )
+
+             if thr_criterion:
+                 sos_loc.append(
+                     [
+                         event_sos_loc[iso, 0],
+                         event_sos_loc[iso, 1],
+                         sos_pks_neg_idx[iso],
+                         amp_filt[sos_pks_neg_idx[iso]],
+                         sos_pks_pos_idx[iso],
+                         amp_filt[sos_pks_pos_idx[iso]],
+                     ]
+                 )
+                 ithr += 1
+
+     if len(sos_loc) != 0:
+         sos_loc = np.array(sos_loc)
+         sos_idx = sos_loc[:, :2].astype(int)
+         sos_pks_neg_idx = sos_loc[:, 2].astype(int)
+         sos_pks_pos_idx = sos_loc[:, 4].astype(int)
+         sos_pks_loc = sos_loc[:, 2:6]
+         sos_pk2pk = sos_loc[:, 5] - sos_loc[:, 3]
+     # 3.7 SO density
+     if "sos_idx" in locals():
+         sos_density = len(sos_idx) / (
+             np.sum(np.diff(so_stage_idx, axis=1)) / fs / 60
+         )  # per minute
+     # 3.8 frequency of each SO
+     if "sos_idx" in locals():
+         # SO power and SO frequency
+         # amp_filt = np.abs(hilbert(amp_filt))
+         sos_power = np.zeros((len(sos_idx), 1))
+         sos_dura_sec = np.diff(sos_idx, axis=1) / fs
+         for i in range(len(sos_idx)):
+             sos_power[i] = np.trapz(amp_filt[sos_idx[i, 0] : sos_idx[i, 1]])
+         sos_freq = 1 / sos_dura_sec.flatten()  # Frequency
+     # 3.10 slope_sos
+     # calculate the slope of each slow-wave event by taking the difference between
+     # the positive and negative peaks and dividing it by the respective time
+     # interval.
+     #
+     # Calculate slope for each slow wave event
+     sos_slope = np.zeros(len(sos_pks_pos_idx))
+     for i in range(len(sos_pks_pos_idx)):
+         # Calculate the slope as (positive peak - negative peak) / time interval
+         sos_slope[i] = (amp_filt[sos_pks_pos_idx[i]] - amp_filt[sos_pks_neg_idx[i]]) / (
+             sos_pks_pos_idx[i] - sos_pks_neg_idx[i]
+         )
+     if opt.sos["extract"]:
+         waveform = extract_peaks_waveforms(
+             data=data, pks=sos_pks_neg_idx, win_size=opt.sos["win_size"]
+         )
+     else:
+         waveform = []
+     print(f"detected {sos_idx.shape[0]} SOs, density={sos_density}")
+     # filling output
+     res = pd.DataFrame(
+         {
+             "sos": {
+                 "idx_start_stop": sos_idx,
+                 "num": sos_idx.shape[0],
+                 "density": sos_density,
+                 "thr": sos_thr,
+                 "slope": sos_slope,
+                 "freq": sos_freq,
+                 "pks_neg_idx": sos_pks_neg_idx,
+                 "pks_pos_idx": sos_pks_pos_idx,
+                 "pk2pk": sos_pk2pk,
+                 "pks_loc": sos_pks_loc,
+                 "pow": sos_power,
+                 "win_size": opt.sos["win_size"],  # was opt.spin["win_size"] (copy-paste slip)
+                 "waveform": waveform,
+             },
+         }
+     )
+
+     # # debug ++++++++ plot the grand-average SO +++++++++++
+     # win_size = 1.5
+     # waveform = []
+     # for pks_ne_idx in res.sos.pks_neg_idx:
+     #     waveform.append(
+     #         amp_filt[int(pks_ne_idx - win_size * fs) : int(pks_ne_idx + win_size * fs)]
+     #     )
+     # waveform = np.array(waveform).reshape(-1, int(win_size * 2 * fs))
+     # # plot
+     # fig, axs = plt.subplots(1, 1, figsize=[8, 3])
+
+     # stdshade(
+     #     axs,
+     #     range(waveform.shape[1]),
+     #     waveform,
+     #     0.39,
+     #     [x / 255 for x in [48, 109, 99]],
+     #     30,
+     # )
+     # plt.axvline(waveform.shape[1] / 2, c="k", label="t0='sos negative peak'")
+     # plt.axhline(0, c=".6")
+     # plt.legend()
+     # plt.show()
+     del amp_filt
+     del data
+     return res
+
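+ # usage sketch (added for illustration; all values are hypothetical): detect_sos
+ # reads opt["sos"] with "filt_ord", "freq_range", "stage", "dur_sec", "thr" (or
+ # "n_prctile_amplitude" when "thr" is []), plus "extract" and "win_size", e.g.:
+ # opt["sos"] = pd.Series({"filt_ord": 3, "freq_range": [0.5, 4], "stage": 2,
+ #                         "dur_sec": [0.8, 2], "thr": [], "n_prctile_amplitude": 75,
+ #                         "extract": False, "win_size": [-1500, 1500]})
+ # res_sos = detect_sos(data_ds, opt)
+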
+ def detect_ripples(data, opt):
+     fs = opt["info"]["fs"]
+     epoch_dur = opt["info"]["epoch_dur"]
+
+     amp_filt = filter_bandpass(
+         ord=opt["rip"]["filt_ord"],
+         data=data,
+         freq_range=opt["rip"]["freq_range"],
+         fs=fs,
+     )
+
+     # Calculate amp_filt_env using the Hilbert transform
+     amp_filt_env = np.abs(hilbert(amp_filt))
+     # Apply additional smoothing (moving average with 200-ms window size)
+     if opt["rip"]["smth"]:
+         # Calculate mean and standard deviation of the smoothed amp_filt_env
+         amp_filt_env_mean = np.mean(moving_average(amp_filt_env, int(0.2 * fs)))
+         amp_filt_env_std = np.std(moving_average(amp_filt_env, int(0.2 * fs)))
+     else:
+         # Calculate mean and standard deviation of amp_filt_env
+         amp_filt_env_mean = np.mean(amp_filt_env)
+         amp_filt_env_std = np.std(amp_filt_env)
+     # 2.3 fill the thresholds into one matrix
+     Thr = np.array(opt["rip"]["thr"]) * amp_filt_env_std
+     # 2.4 use the defined thresholds
+     if len(Thr) >= 1:
+         a, b = detect_cross(amp_filt_env, Thr[0])
+     else:
+         raise ValueError("Did not find the 1st rip Thr")
+     RisingFalling = np.column_stack((a, b))
+     rip_dura_sec = opt["rip"]["dur_sec"]
+     Dura_tmp = np.diff(RisingFalling, axis=1)
+     Thr1rip1 = RisingFalling[
+         np.where((rip_dura_sec[0] * fs < Dura_tmp) & (Dura_tmp < rip_dura_sec[1] * fs)),
+         :,
+     ][0]
+     # 2.4.1.2 keep only the EventsRip1 that fall in NREM (specific sleep stages)
+     score_code = extract_score(opt["info"]["dir_score"], scalar=fs * epoch_dur)
+     stage_rip_idx = find_repeats(score_code, opt["rip"]["stage"], nGap=1)
+     stage_rip_idx = stage_rip_idx[
+         np.where(
+             (stage_rip_idx[:, 0] >= fs * epoch_dur)
+             & (stage_rip_idx[:, 1] <= len(amp_filt) - fs * epoch_dur)
+         )
+     ]
+     EventsRip1 = Thr1rip1[
+         np.where(
+             (stage_rip_idx[:, 0] < Thr1rip1[:, 0].reshape(-1, 1))
+             & (Thr1rip1[:, 1].reshape(-1, 1) < stage_rip_idx[:, 1])
+         )[0],
+         :,
+     ]
+     # print("step1", EventsRip1.shape)
+     # 2.4.2 Thr2 crossing
+     if len(Thr) >= 2:
+         a, b = detect_cross(amp_filt_env, Thr[1])
+         RisingFalling = np.column_stack((a, b))
+         ripDura_min2 = rip_dura_sec[0] / 2  # half of the minimum duration
+         ripDura_max2 = rip_dura_sec[1]
+         Dura_tmp = np.diff(RisingFalling, axis=1)
+         EventsRip2 = RisingFalling[
+             np.where((ripDura_min2 * fs <= Dura_tmp) & (Dura_tmp <= ripDura_max2 * fs)),
+             :,
+         ][0]
+     else:
+         EventsRip2 = np.copy(EventsRip1)
+     # print("step2", EventsRip2.shape)
+     # 2.4.2.3 check EventsRip2 in EventsRip1
+     EventsRip3 = np.empty((0, 2), dtype=int)  # array init so .shape works even with no events
+     if (
+         ("EventsRip1" in locals())
+         and ("EventsRip2" in locals())
+         and (EventsRip1.shape[0] != 0)
+         and (EventsRip2.shape[0] != 0)
+     ):
+         EventsRip3 = EventsRip1[
+             np.where(
+                 (EventsRip1[:, 0] < EventsRip2[:, 0].reshape(-1, 1))
+                 & (EventsRip2[:, 1].reshape(-1, 1) < EventsRip1[:, 1])
+             )[1],
+             :,
+         ]
+     # print("step3", EventsRip3.shape)
+     # 2.4.2.4 unique EventsRip3
+     if EventsRip3.shape[0] != 0:
+         EventsRip3_orgs = np.copy(EventsRip3)
+         EventsRip3 = np.unique(EventsRip3_orgs[:, 0:2], axis=0)
+     if len(Thr) >= 3:
+         # 2.4.3 Crossing positions - Thr 3
+         EventsRip = []
+         irip4 = 0
+         if EventsRip3.shape[0] != 0:
+             for irip3 in range(EventsRip3.shape[0]):
+                 if (
+                     np.max(amp_filt_env[EventsRip3[irip3, 0] : EventsRip3[irip3, 1]])
+                     >= Thr[2]
+                 ):
+                     EventsRip.append([EventsRip3[irip3, 0], EventsRip3[irip3, 1]])
+                     irip4 += 1
+             if isinstance(EventsRip, list):
+                 EventsRip = np.array(EventsRip)
+                 EventsRip = EventsRip.reshape(-1, 2)
+         else:
+             EventsRip = EventsRip3.copy()
+     else:
+         EventsRip = EventsRip3.copy()
+         print("\ncannot find the 3rd Thr_rip, only 2 Thr were used for rip detection\n")
+     EventsRip = EventsRip.reshape(-1, 2)
+     # print("final detected ripple number: ", EventsRip.shape)
+     # 2.6 ripple density (counts during NREM), in events/min:
+     # the number of ripples detected at each recording site
+     # divided by the time spent in SWS.
+     rip_density = EventsRip.shape[0] / (
+         np.sum(np.diff(stage_rip_idx, axis=1)) / fs / 60
+     )  # per minute
+     print(f"detected {EventsRip.shape[0]} ripples, density={rip_density}")
+     # Frequency of each ripple
+     num_rip_pks = []
+     for i in range(EventsRip.shape[0]):
+         peaks, _ = find_peaks(
+             amp_filt[EventsRip[i, 0] : EventsRip[i, 1]], height=Thr[0]
+         )  # Assuming Thr is a scalar
+         num_rip_pks.append(len(peaks))
+
+     dur_rip_sec = np.diff(EventsRip, axis=1) / fs
+     rip_freq = [(x / y).tolist()[0] for (x, y) in zip(np.array(num_rip_pks), dur_rip_sec)]
+     rip_avg_freq = np.nanmean(rip_freq, axis=0)  # renamed from 'rig_avg_freq' (typo)
+     print(f"Average ripple frequency: {rip_avg_freq:.4f} Hz")
+     # ripple power: use numpy's np.trapz() to compute the area under the curve
+     rip_pow_single = [np.trapz(amp_filt_env[start:end]) for start, end in EventsRip]
+
+     if EventsRip.shape[0] > 1:
+         loc_max_rip = [np.argmax(amp_filt[start:end]) + start for start, end in EventsRip]
+         rip_pk2pk = [
+             (np.max(amp_filt[start:end]) - np.min(amp_filt[start:end]))
+             for start, end in EventsRip
+         ]
+         # rip_pk2pk = []
+         # rip_pks_loc = []
+         # loc_max_rip = []
+         # ipk = 0
+         # for irip in range(EventsRip.shape[0]):
+         #     tmp = amp_filt[
+         #         EventsRip[irip, 0] : EventsRip[irip, 1] + 1
+         #     ]  # +1 to include the end index
+         #     # (1) find pks_max
+         #     locs_max, _ = find_peaks(list(tmp))
+         #     pks_max = tmp[locs_max]
+         #     pks_max_rip = np.max(pks_max)
+         #     loc_max_rip_ = (
+         #         locs_max[np.where(pks_max == pks_max_rip)[0][0]] + EventsRip[irip, 0]
+         #     )
+         #     loc_max_rip.append(loc_max_rip_)
+
+         #     # (2) find pks_min
+         #     pks_min = tmp[locs_max]
+         #     pks_min = pks_min * (-1)  # don't forget to multiply by -1
+         #     pks_min_rip = np.min(pks_min)
+         #     loc_min_rip_ = (
+         #         locs_max[np.where(pks_min == pks_min_rip)[0][0]] + EventsRip[irip, 0]
+         #     )
+         #     # (3) rip_pk2pk
+         #     rip_pk2pk.append(pks_max_rip - pks_min_rip)
+         #     rip_pks_loc.append([loc_min_rip_, pks_min_rip, loc_max_rip_, pks_max_rip])
+         #     ipk += 1
+         # rip_pk2pk = np.array(rip_pk2pk)
+         # rip_pks_loc = np.array(rip_pks_loc)
+     else:
+         # rip_pks_loc = np.array([])
+         rip_pk2pk = np.array([])
+         loc_max_rip = []  # keep defined for the output below
+     if opt.rip["extract"] and EventsRip.shape[0] > 1:
+         waveform = extract_peaks_waveforms(
+             data=data, pks=loc_max_rip, win_size=opt.rip["win_size"]
+         )
+     else:
+         waveform = []
+     # filling output
+     res = pd.DataFrame(
+         {
+             "rip": {
+                 "idx_start_stop": EventsRip,
+                 "num": EventsRip.shape[0],
+                 "density": rip_density,
+                 "thr": Thr,
+                 "freq": rip_freq,
+                 "pow": rip_pow_single,
+                 "pk2pk": rip_pk2pk,
+                 # "pk2pk_loc": rip_pks_loc,
+                 "avg_freq": rip_avg_freq,
+                 "max_pks_loc": loc_max_rip,
+                 "win_size": opt.rip["win_size"],  # was opt.spin["win_size"] (copy-paste slip)
+                 "waveform": waveform,
+             }
+         }
+     )
+     del amp_filt
+     del data
+     del amp_filt_env
+
+     return res
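+
+ # usage sketch (added for illustration; all values are hypothetical): detect_ripples
+ # uses the same opt layout with a "rip" entry
+ # opt["rip"] = pd.Series({"filt_ord": 4, "freq_range": [80, 200], "thr": [2, 3.5, 5],
+ #                         "dur_sec": [0.02, 0.2], "stage": 2, "smth": False,
+ #                         "extract": False, "win_size": [-250, 250]})
+ # res_rip = detect_ripples(data_ds, opt)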