vbi-0.1.3-cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
vbi/feature_extraction/features_utils.py
@@ -0,0 +1,1357 @@
+ import vbi
+ import scipy
+ import torch
+ import numpy as np
+ from os.path import join
+ from typing import Union, List
+ from copy import deepcopy
+ import scipy.stats as stats
+ from numpy import linalg as LA
+ from sklearn.decomposition import PCA
+ from scipy.signal import butter, detrend, filtfilt, hilbert
+ from vbi.feature_extraction.features_settings import load_json
+ from vbi.feature_extraction.utility import *
+
+ try:
+     import jpype as jp
+     import ssm
+ except ImportError:
+     # optional dependencies: jpype is needed for the JIDT-based features,
+     # ssm for the HMM state-duration feature
+     pass
+
+
+ def slice_features(x: Union[np.ndarray, torch.Tensor], feature_names: list, info: dict):
+     """
+     Slice a stacked feature matrix, keeping only the given features.
+
+     Parameters
+     ----------
+     x: array-like
+         stacked feature matrix [n_samples, n_columns]
+     feature_names: list of str
+         names of the features to keep
+     info: dict
+         each feature's column index range in x, e.g. info[name]["index"] = [start, stop]
+
+     Returns
+     -------
+     x_sliced: array-like
+         sliced features
+     """
+     if isinstance(x, (list, tuple)):
+         x = np.array(x)
+
+     if x.ndim == 1:
+         x = x.reshape(1, -1)
+
+     is_tensor = isinstance(x, torch.Tensor)
+     if is_tensor:
+         x_sliced = torch.Tensor([])
+     else:
+         x_sliced = np.array([])
+
+     if len(feature_names) == 0:
+         return x_sliced
+
+     for f_name in feature_names:
+         if f_name in info:
+             coli, colf = info[f_name]["index"][0], info[f_name]["index"][1]
+             if is_tensor:
+                 x_sliced = torch.cat((x_sliced, x[:, coli:colf]), dim=1)
+             else:
+                 if x_sliced.size == 0:
+                     x_sliced = x[:, coli:colf]
+                 else:
+                     x_sliced = np.concatenate((x_sliced, x[:, coli:colf]), axis=1)
+         else:
+             raise ValueError(f"{f_name} not in info")
+
+     return x_sliced
+
+
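A minimal usage sketch for slice_features (the feature names, shapes, and the layout of `info` below are illustrative assumptions, not taken from the package):

    import numpy as np

    # hypothetical stacked feature matrix: 4 samples x 5 columns
    x = np.arange(20.0).reshape(4, 5)
    # hypothetical layout: "fc_sum" occupies columns 0..2, "fcd_std" columns 3..4
    info = {
        "fc_sum": {"index": [0, 3]},
        "fcd_std": {"index": [3, 5]},
    }
    slice_features(x, ["fc_sum"], info).shape             # (4, 3)
    slice_features(x, ["fc_sum", "fcd_std"], info).shape  # (4, 5)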
+ def preprocess(ts, fs=None, preprocess_dict=None, **kwargs):
+     """
+     Preprocess time series data
+
+     Parameters
+     ----------
+     ts : nd-array [n_regions, n_timepoints]
+         Input from which the features are extracted
+     fs : int
+         Sampling frequency, set to 1 if not used
+     preprocess_dict : dict
+         Dictionary of preprocessing options; defaults to
+         vbi/feature_extraction/preprocess.json if not given
+     **kwargs : dict
+         Additional arguments
+
+     Returns
+     -------
+     ts : nd-array [n_regions, n_timepoints]
+         Preprocessed time series
+     """
+
+     if not preprocess_dict:
+         preprocess_dict = load_json(
+             vbi.__path__[0] + "/feature_extraction/preprocess.json"
+         )
+
+     if preprocess_dict["zscores"]["use"] == "yes":
+         ts = stats.zscore(ts, axis=1)
+
+     if preprocess_dict["offset"]["use"] == "yes":
+         value = preprocess_dict["offset"]["parameters"]["value"]
+         ts = ts[:, value:]
+
+     if preprocess_dict["demean"]["use"] == "yes":
+         ts = ts - np.mean(ts, axis=1)[:, None]
+
+     if preprocess_dict["detrend"]["use"] == "yes":
+         ts = detrend(ts, axis=1)
+
+     if preprocess_dict["filter"]["use"] == "yes":
+         low_cut = preprocess_dict["filter"]["parameters"]["low"]
+         high_cut = preprocess_dict["filter"]["parameters"]["high"]
+         order = preprocess_dict["filter"]["parameters"]["order"]
+         TR = 1.0 / fs
+         # note: band_pass_filter takes `order`, not `k`
+         ts = band_pass_filter(ts, order=order, TR=TR, low_cut=low_cut, high_cut=high_cut)
+
+     if preprocess_dict["remove_strong_artefacts"]["use"] == "yes":
+         ts = remove_strong_artefacts(ts)
+
+     return ts
+
+
+ def band_pass_filter(ts, low_cut=0.02, high_cut=0.1, TR=2.0, order=2):
+     """
+     Apply a band-pass filter to the given time series.
+
+     Parameters
+     ----------
+     ts : numpy.ndarray [n_regions, n_timepoints]
+         Input signal
+     low_cut : float, optional
+         Low cut frequency. The default is 0.02.
+     high_cut : float, optional
+         High cut frequency. The default is 0.1.
+     TR : float, optional
+         Sampling interval in seconds. The default is 2.0.
+     order : int, optional
+         Filter order. The default is 2.
+
+     Returns
+     -------
+     ts_filt : numpy.ndarray
+         filtered signal
+     """
+
+     assert not np.isnan(ts).any(), "band_pass_filter: input contains NaNs"
+
+     fnq = 1.0 / (2.0 * TR)  # Nyquist frequency
+     Wn = [low_cut / fnq, high_cut / fnq]
+     bfilt, afilt = butter(order, Wn, btype="band")
+     return filtfilt(bfilt, afilt, ts, axis=1)
+
+
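For orientation, a small sketch of the filter on synthetic data (sampling rate and cut-offs are illustrative). With TR = 1 s the Nyquist frequency is fnq = 1/(2*TR) = 0.5 Hz, so the normalized band passed to butter is [0.02/0.5, 0.1/0.5] = [0.04, 0.2]:

    import numpy as np

    ts = np.random.randn(10, 500)  # 10 regions, 500 samples at TR = 1 s
    ts_filt = band_pass_filter(ts, low_cut=0.02, high_cut=0.1, TR=1.0, order=2)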
+ def remove_strong_artefacts(ts, threshold=3.0):
+     """
+     Clip samples whose amplitude exceeds `threshold` standard deviations,
+     region by region.
+     """
+
+     if isinstance(ts, (list, tuple)):
+         ts = np.array(ts)
+
+     if ts.ndim == 1:
+         ts = ts.reshape(1, -1)
+
+     nn = ts.shape[0]
+
+     for i in range(nn):
+         x_ = ts[i, :]
+         std_dev = threshold * np.std(x_)
+         x_[x_ > std_dev] = std_dev
+         x_[x_ < -std_dev] = -std_dev
+         ts[i, :] = x_
+     return ts
+
+
+ def get_fc(ts, masks=None, positive=False, fc_function="corrcoef"):
+     """
+     Calculate the functional connectivity matrix.
+
+     Parameters
+     ----------
+     ts : numpy.ndarray [n_regions, n_timepoints]
+         Input signal
+     masks : dict, optional
+         dictionary of masks to apply; defaults to the full matrix
+     positive : bool, optional
+         if True, keep only positive FC values
+     fc_function : str, optional
+         "corrcoef" or "cov"
+
+     Returns
+     -------
+     FCs : dict of numpy.ndarray
+         one functional connectivity matrix per mask, diagonal set to zero
+     """
+
+     from numpy import corrcoef, cov
+
+     n_nodes = ts.shape[0]
+     if masks is None:
+         masks = {"full": np.ones((n_nodes, n_nodes))}
+
+     FCs = {}
+     FC = eval(fc_function)(ts)
+     for _, key in enumerate(masks.keys()):
+         mask = masks[key]
+         fc = deepcopy(FC)
+         if positive:
+             fc = fc * (fc > 0)
+         fc = fc * mask
+         fc = fc - np.diag(np.diagonal(fc))
+         FCs[key] = fc
+
+     return FCs
+
+
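A usage sketch (the random data and the sub-network mask are invented for illustration):

    import numpy as np

    ts = np.random.randn(8, 200)          # 8 regions, 200 time points
    FCs = get_fc(ts)                      # {"full": 8x8 FC, diagonal zeroed}
    mask = np.zeros((8, 8))
    mask[:4, :4] = 1                      # hypothetical sub-network mask
    FCs_sub = get_fc(ts, masks={"subnet": mask}, positive=True)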
+ def get_fcd(
+     ts,
+     TR=1,
+     win_len=30,
+     positive=False,
+     masks=None,
+     #!TODO: add overlap
+ ):
+     """
+     Compute dynamic functional connectivity.
+
+     Parameters
+     ----------
+     ts: numpy.ndarray [n_regions, n_timepoints]
+         Input signal
+     win_len: int
+         sliding window length in samples, default is 30
+     TR: int
+         repetition time: the amount of time that passes between
+         consecutive acquired brain volumes during functional magnetic
+         resonance imaging (fMRI) scans.
+     positive: bool
+         if True, only positive values of FC are considered.
+         default is False
+     masks: dict
+         dictionary of masks to compute FCD on.
+         default is None, which means that FCD is computed on the full matrix.
+         see also `hbt.utility.make_mask` and `hbt.utility.get_masks`.
+
+     Returns
+     -------
+     FCD: ndarray
+         matrix of functional connectivity dynamics
+     """
+     if not isinstance(ts, np.ndarray):
+         ts = np.array(ts)
+
+     ts = ts.T
+     n_samples, n_nodes = ts.shape
+     # check that the time series is long enough
+     if n_samples < 2 * win_len:
+         raise ValueError(
+             f"get_fcd: Length of the time series should be at least 2 times win_len. n_samples: {n_samples}, win_len: {win_len}"
+         )
+
+     mask_full = np.ones((n_nodes, n_nodes))
+     if masks is None:
+         masks = {"full": mask_full}
+
+     windowed_data = np.lib.stride_tricks.sliding_window_view(
+         ts, (int(win_len / TR), n_nodes), axis=(0, 1)
+     ).squeeze()
+     n_windows = windowed_data.shape[0]
+     fc_stream = np.asarray(
+         [np.corrcoef(windowed_data[i, :, :], rowvar=False) for i in range(n_windows)]
+     )
+
+     if positive:
+         fc_stream *= fc_stream > 0
+
+     FCDs = {}
+     for _, key in enumerate(masks.keys()):
+         mask = masks[key].astype(np.float64)
+         mask *= np.triu(mask_full, k=1)
+         nonzero_idx = np.nonzero(mask)
+         fc_stream_masked = fc_stream[:, nonzero_idx[0], nonzero_idx[1]]
+         fcd = np.corrcoef(fc_stream_masked, rowvar=True)
+         FCDs[key] = fcd
+
+     return FCDs
+
+
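A sketch of the expected shapes (random input; real use would pass BOLD time series). With 300 samples, TR = 1 and win_len = 30, the sliding window yields 300 - 30 + 1 = 271 windows, so the FCD matrix is 271 x 271:

    import numpy as np

    ts = np.random.randn(8, 300)          # needs n_samples >= 2 * win_len
    FCDs = get_fcd(ts, TR=1, win_len=30)
    FCDs["full"].shape                    # (271, 271)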
+ def get_fcd2(ts, wwidth=30, maxNwindows=200, olap=0.94, indices=[], verbose=False):
+     """
+     Functional connectivity dynamics from the given time series,
+     using overlapping sliding windows.
+
+     Parameters
+     ----------
+     ts: np.ndarray (2d)
+         time series in rows [n_nodes, n_samples]
+     wwidth: int
+         window width in samples
+     maxNwindows: int
+         maximum number of windows
+     olap: float
+         fractional overlap between consecutive windows, in [0, 1]
+     verbose: bool
+         if True, print the exception on failure
+
+     Returns
+     -------
+     FCD: np.ndarray (2d)
+         functional connectivity dynamics matrix
+     """
+
+     assert olap <= 1 and olap >= 0, "olap must be between 0 and 1"
+
+     all_corr_matrix = []
+     nt = len(ts[0])  # number of time points/samples
+
+     try:
+         Nwindows = min(
+             ((nt - wwidth * olap) // (wwidth * (1 - olap)), maxNwindows)
+         )
+         shift = int((nt - wwidth) // (Nwindows - 1))
+         if Nwindows == maxNwindows:
+             wwidth = int(shift // (1 - olap))
+
+         indx_start = range(0, (nt - wwidth + 1), shift)
+         indx_stop = range(wwidth, (1 + nt), shift)
+
+         nnodes = ts.shape[0]
+
+         for j1, j2 in zip(indx_start, indx_stop):
+             aux_s = ts[:, j1:j2]
+             corr_mat = np.corrcoef(aux_s)
+             all_corr_matrix.append(corr_mat)
+
+         corr_vectors = np.array(
+             [allPm[np.tril_indices(nnodes, k=-1)] for allPm in all_corr_matrix]
+         )
+         CV_centered = corr_vectors - np.mean(corr_vectors, -1)[:, None]
+
+         return np.corrcoef(CV_centered)
+
+     except Exception as e:
+         if verbose:
+             print(e)
+         return np.array([np.nan])
+
+
+ def set_attribute(key, value):
+     """Decorator factory that sets an attribute on the decorated function."""
+
+     def decorate_func(func):
+         setattr(func, key, value)
+         return func
+
+     return decorate_func
+
+
+ def compute_time(signal, fs):
+     """Creates the time array corresponding to the signal.
+
+     Parameters
+     ----------
+     signal: nd-array
+         Input from which the time is computed.
+     fs: int
+         Sampling Frequency
+
+     Returns
+     -------
+     time : float list
+         Signal time
+     """
+
+     return np.arange(0, len(signal)) / fs
+
+
+ def calculate_plv(data):
+     """
+     Phase-locking value (PLV) between all channel pairs, computed from
+     the instantaneous phases of the analytic (Hilbert) signal.
+
+     data: nd-array [n_channels, n_samples]
+     """
+     n_channels, n_samples = data.shape
+
+     analytic_signal = hilbert(data)
+     phase_angles = np.angle(analytic_signal)
+     plv_matrix = np.zeros((n_channels, n_channels))
+
+     for i in range(n_channels):
+         for j in range(i + 1, n_channels):
+             plv = np.abs(np.mean(np.exp(1j * (phase_angles[i] - phase_angles[j]))))
+             plv_matrix[i, j] = plv
+             plv_matrix[j, i] = plv
+
+     return plv_matrix
+
+
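A quick check of calculate_plv on synthetic data: three noisy copies of the same 10 Hz rhythm should give off-diagonal PLV values close to 1 (the diagonal is left at zero by this implementation):

    import numpy as np

    t = np.arange(0, 4, 1 / 250.0)            # 4 s at 250 Hz
    common = np.sin(2 * np.pi * 10 * t)
    data = np.vstack([common + 0.1 * np.random.randn(t.size) for _ in range(3)])
    plv = calculate_plv(data)                 # off-diagonal entries near 1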
+ def calc_fft(signal, fs):
+     """Computes the FFT of a signal.
+
+     Note: this definition is shadowed by the rfft-based `calc_fft`
+     defined later in this module.
+
+     Parameters
+     ----------
+     signal : nd-array
+         The input signal from which fft is computed
+     fs : int
+         Sampling frequency
+
+     Returns
+     -------
+     f: nd-array
+         Frequency values (xx axis)
+     fmag: nd-array
+         Amplitude of the frequency values (yy axis)
+     """
+
+     fmag = np.abs(np.fft.fft(signal))
+     f = np.linspace(0, fs // 2, len(signal) // 2)
+
+     return f[: len(signal) // 2].copy(), fmag[: len(signal) // 2].copy()
+
+
+ def filterbank(signal, fs, pre_emphasis=0.97, nfft=512, nfilt=40):
+     """Computes the MEL-spaced filterbank.
+
+     It provides the information about the power in each frequency band.
+
+     Implementation details and description on:
+     https://www.kaggle.com/ilyamich/mfcc-implementation-and-tutorial
+     https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html#fnref:1
+
+     Parameters
+     ----------
+     signal : nd-array
+         Input from which filterbank is computed
+     fs : int
+         Sampling frequency
+     pre_emphasis : float
+         Pre-emphasis coefficient for pre-emphasis filter application
+     nfft : int
+         Number of points of fft
+     nfilt : int
+         Number of filters
+
+     Returns
+     -------
+     nd-array
+         MEL-spaced filterbank
+     """
+
+     # The signal is already a window of the original signal, so no framing is needed.
+     # The references apply a window function (e.g. a Hann window) here; however,
+     # if the signal windows don't overlap, a Hann window would attenuate the
+     # window edges and lose information, so it is skipped.
+
+     # pre-emphasis filter to amplify the high frequencies
+     emphasized_signal = np.append(
+         np.array(signal)[0], np.array(signal[1:]) - pre_emphasis * np.array(signal[:-1])
+     )
+
+     # Fourier transform and power spectrum
+     mag_frames = np.absolute(
+         np.fft.rfft(emphasized_signal, nfft)
+     )  # Magnitude of the FFT
+
+     pow_frames = (1.0 / nfft) * (mag_frames**2)  # Power Spectrum
+
+     low_freq_mel = 0
+     high_freq_mel = 2595 * np.log10(1 + (fs / 2) / 700)  # Convert Hz to Mel
+     # Equally spaced in Mel scale
+     mel_points = np.linspace(low_freq_mel, high_freq_mel, nfilt + 2)
+     hz_points = 700 * (10 ** (mel_points / 2595) - 1)  # Convert Mel to Hz
+     filter_bin = np.floor((nfft + 1) * hz_points / fs)
+
+     fbank = np.zeros((nfilt, int(np.floor(nfft / 2 + 1))))
+     for m in range(1, nfilt + 1):
+         f_m_minus = int(filter_bin[m - 1])  # left
+         f_m = int(filter_bin[m])  # center
+         f_m_plus = int(filter_bin[m + 1])  # right
+
+         for k in range(f_m_minus, f_m):
+             fbank[m - 1, k] = (k - filter_bin[m - 1]) / (
+                 filter_bin[m] - filter_bin[m - 1]
+             )
+         for k in range(f_m, f_m_plus):
+             fbank[m - 1, k] = (filter_bin[m + 1] - k) / (
+                 filter_bin[m + 1] - filter_bin[m]
+             )
+
+     # Area normalization: without it, noise grows with frequency
+     # because the filters get wider.
+     enorm = 2.0 / (hz_points[2 : nfilt + 2] - hz_points[:nfilt])
+     fbank *= enorm[:, np.newaxis]
+
+     filter_banks = np.dot(pow_frames, fbank.T)
+     filter_banks = np.where(
+         filter_banks == 0, np.finfo(float).eps, filter_banks
+     )  # numerical stability
+     filter_banks = 20 * np.log10(filter_banks)  # dB
+
+     return filter_banks
+
+
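The Hz/Mel conversions above are easy to sanity-check by hand. For fs = 1000 Hz the upper band edge is fs/2 = 500 Hz, which maps to 2595 * log10(1 + 500/700) ≈ 607.4 mel, and the inverse formula maps it back:

    import numpy as np

    fs = 1000.0
    mel = 2595 * np.log10(1 + (fs / 2) / 700)   # ~607.4
    hz = 700 * (10 ** (mel / 2595) - 1)         # ~500.0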
+ def autocorr_norm(signal):
+     """Computes the normalized autocorrelation.
+
+     Implementation details and description in:
+     https://ccrma.stanford.edu/~orchi/Documents/speaker_recognition_report.pdf
+
+     Parameters
+     ----------
+     signal : nd-array
+         Input from which the autocorrelation is computed
+
+     Returns
+     -------
+     nd-array
+         Autocorrelation result
+     """
+
+     variance = np.var(signal)
+     signal = np.copy(signal - signal.mean())
+     r = scipy.signal.correlate(signal, signal)[-len(signal) :]
+
+     if (signal == 0).all():
+         return np.zeros(len(signal))
+
+     acf = r / variance / len(signal)
+
+     return acf
+
+
+ def create_symmetric_matrix(acf, order=11):
+     """Computes a symmetric matrix from autocorrelation values.
+
+     Implementation details and description in:
+     https://ccrma.stanford.edu/~orchi/Documents/speaker_recognition_report.pdf
+
+     Parameters
+     ----------
+     acf : nd-array
+         Input from which a symmetric matrix is computed
+     order : int
+         Matrix order (the result is order x order)
+
+     Returns
+     -------
+     nd-array
+         Symmetric matrix with entries smatrix[i, j] = acf[|i - j|]
+     """
+
+     smatrix = np.empty((order, order))
+     xx = np.arange(order)
+     j = np.tile(xx, order)
+     i = np.repeat(xx, order)
+     smatrix[i, j] = acf[np.abs(i - j)]
+
+     return smatrix
+
+
+ def lpc(signal, n_coeff=12):
+     """Computes the linear prediction coefficients.
+
+     Implementation details and description in:
+     https://ccrma.stanford.edu/~orchi/Documents/speaker_recognition_report.pdf
+
+     Parameters
+     ----------
+     signal : nd-array
+         Input from which linear prediction coefficients are computed
+     n_coeff : int
+         Number of coefficients
+
+     Returns
+     -------
+     nd-array
+         Linear prediction coefficients
+     """
+
+     if signal.ndim > 1:
+         raise ValueError("Only 1 dimensional arrays are valid")
+     if n_coeff > signal.size:
+         raise ValueError("Input signal must have a length >= n_coeff")
+
+     # Calculate the order based on the number of coefficients
+     order = n_coeff - 1
+
+     # Calculate LPC with Yule-Walker
+     acf = np.correlate(signal, signal, "full")
+
+     r = np.zeros(order + 1, "float32")
+     # Ensures this works for all input lengths
+     nx = np.min([order + 1, len(signal)])
+     r[:nx] = acf[len(signal) - 1 : len(signal) + order]
+
+     smatrix = create_symmetric_matrix(r[:-1], order)
+
+     if np.sum(smatrix) == 0:
+         return tuple(np.zeros(order + 1))
+
+     lpc_coeffs = np.dot(np.linalg.inv(smatrix), -r[1:])
+
+     return tuple(np.concatenate(([1.0], lpc_coeffs)))
+
+
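A usage sketch (a plain sine; values illustrative): lpc returns a tuple of n_coeff values whose first entry is 1.0, per the Yule-Walker construction above:

    import numpy as np

    sig = np.sin(2 * np.pi * 5 * np.arange(0, 1, 0.01))   # 100-sample sine
    coeffs = lpc(sig, n_coeff=12)                          # (1.0, a_1, ..., a_11)
    len(coeffs)                                            # 12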
+ def create_xx(features):
+     """Computes the range of feature amplitudes for the probability density function calculation.
+
+     Parameters
+     ----------
+     features : nd-array
+         Input features
+
+     Returns
+     -------
+     nd-array
+         range of features amplitude
+     """
+
+     features_ = np.copy(features)
+
+     if max(features_) < 0:
+         max_f = -max(features_)
+         min_f = min(features_)
+     else:
+         min_f = min(features_)
+         max_f = max(features_)
+
+     if min(features_) == max(features_):
+         xx = np.linspace(min_f, min_f + 10, len(features_))
+     else:
+         xx = np.linspace(min_f, max_f, len(features_))
+
+     return xx
+
+
+ def kde(features):
+     """Computes the probability density function of the input signal
+     using a Gaussian KDE (Kernel Density Estimate)
+
+     Parameters
+     ----------
+     features : nd-array
+         Input from which probability density function is computed
+
+     Returns
+     -------
+     nd-array
+         probability density values
+     """
+     features_ = np.copy(features)
+     xx = create_xx(features_)
+
+     if min(features_) == max(features_):
+         noise = np.random.randn(len(features_)) * 0.0001
+         features_ = np.copy(features_ + noise)
+
+     kernel = scipy.stats.gaussian_kde(features_, bw_method="silverman")
+
+     return np.array(kernel(xx) / np.sum(kernel(xx)))
+
+
+ def gaussian(features):
+     """Computes the probability density function of the input signal using a Gaussian function
+
+     Parameters
+     ----------
+     features : nd-array
+         Input from which probability density function is computed
+
+     Returns
+     -------
+     nd-array
+         probability density values
+     """
+
+     features_ = np.copy(features)
+
+     xx = create_xx(features_)
+     std_value = np.std(features_)
+     mean_value = np.mean(features_)
+
+     if std_value == 0:
+         return 0.0
+     pdf_gauss = scipy.stats.norm.pdf(xx, mean_value, std_value)
+
+     return np.array(pdf_gauss / np.sum(pdf_gauss))
+
+
+ def calc_ecdf(signal):
+     """Computes the ECDF of the signal.
+     ECDF is the empirical cumulative distribution function.
+
+     Parameters
+     ----------
+     signal : nd-array
+         Input from which ECDF is computed
+
+     Returns
+     -------
+     nd-array
+         Sorted signal and computed ECDF.
+     """
+     return np.sort(signal), np.arange(1, len(signal) + 1) / len(signal)
+
+
+ def matrix_stat(
+     A: np.ndarray,
+     k: int = 1,
+     eigenvalues: bool = True,
+     pca_num_components: int = 3,
+     quantiles: List[float] = [0.05, 0.25, 0.5, 0.75, 0.95],
+     features: List[str] = ["sum", "max", "min", "mean", "std", "skew", "kurtosis"],
+ ):
+     """
+     Calculate statistics of the given matrix.
+
+     Parameters
+     ----------
+     A: np.ndarray (2d)
+         input matrix
+     k: int
+         upper triangular matrix offset
+     pca_num_components: int
+         number of components to keep for PCA, set to 0 if not used
+     features: list
+         list of features to compute
+     quantiles: list
+         list of quantiles to compute, set to [] or None if not used
+
+     Returns
+     -------
+     values: np.ndarray (1d)
+         feature values
+     labels: list
+         feature labels
+     """
+     from numpy import sum, max, min, mean, std
+     from scipy.stats import skew, kurtosis
+
+     off_diag_sum_A = np.sum(np.abs(A)) - np.trace(np.abs(A))
+
+     ut_idx = np.triu_indices_from(A, k=k)
+     A_ut = A[ut_idx[0], ut_idx[1]]
+
+     values = []
+     labels = []
+     if quantiles:
+         q = np.quantile(A, quantiles)
+         values.extend(q.tolist())
+         labels.extend([f"quantile_{i}" for i in quantiles])
+
+     if pca_num_components:
+         try:
+             pca = PCA(n_components=pca_num_components)
+             pca_a = pca.fit_transform(A)
+         except Exception:
+             return [np.nan], ["pca_error"]
+
+         for f in features:
+             v = eval(f)(pca_a.reshape(-1))
+             values.append(v)
+             labels.append(f"pca_{f}")
+
+     if eigenvalues:
+         eigen_vals_A, _ = LA.eig(A)
+         for f in features:
+             v = eval(f)(np.real(eigen_vals_A[:-1]))
+             values.append(v)
+             labels.append(f"eig_{f}")
+
+     for f in features:
+         v = eval(f)(A_ut)
+         values.append(v)
+         labels.append(f"ut_{f}")
+
+     values.append(off_diag_sum_A)
+     labels.append("sum")
+
+     return values, labels
+
+
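A usage sketch on a random correlation matrix (shapes illustrative). With the defaults this yields 5 quantiles, 7 PCA stats, 7 eigenvalue stats, 7 upper-triangle stats, and the off-diagonal sum, 27 values in total:

    import numpy as np

    A = np.corrcoef(np.random.randn(10, 100))   # 10 x 10 symmetric matrix
    values, labels = matrix_stat(A)
    len(values), len(labels)                     # (27, 27)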
+ def report_cfg(cfg: dict):
+     """
+     Report the features selected in the provided config.
+     """
+
+     print("Selected features:")
+     print("------------------")
+
+     for d in cfg:
+         if d == "features_path":
+             continue
+         if cfg[d]:
+             print("■ Domain:", d)
+             for f in cfg[d]:
+                 print(" ▢ Function: ", f)
+                 print(" ▫ description: ", cfg[d][f]["description"])
+                 print(" ▫ function : ", cfg[d][f]["function"])
+                 print(" ▫ parameters : ", cfg[d][f]["parameters"])
+                 print(" ▫ tag : ", cfg[d][f]["tag"])
+                 print(" ▫ use : ", cfg[d][f]["use"])
+
+
+ def get_jar_location():
+     """Return the path of the bundled infodynamics.jar (JIDT)."""
+
+     jar_file_name = "infodynamics.jar"
+     # vbi.__file__ points at vbi/__init__.py; strip it to get the package dir
+     jar_location = join(vbi.__file__, "feature_extraction")
+     jar_location = jar_location.replace("__init__.py", "")
+     jar_location = join(jar_location, jar_file_name)
+
+     return jar_location
+
+
+ def init_jvm():
+     """Start the JVM with infodynamics.jar on the class path (idempotent)."""
+
+     jar_location = get_jar_location()
+
+     if jp.isJVMStarted():
+         return
+     jp.startJVM(jp.getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jar_location)
+
+
+ def nat2bit(x):
+     """
+     Convert nats to bits (multiply by 1/ln 2).
+     """
+     return x * 1.4426950408889634
+
+
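A hedged sketch of how these helpers combine with the bundled JIDT jar (the class path and call sequence follow JIDT's documented Gaussian mutual-information estimator, but treat them as assumptions; nothing in this file exercises them directly):

    import numpy as np
    import jpype as jp

    init_jvm()   # starts the JVM with infodynamics.jar on the class path

    CalcMI = jp.JClass(
        "infodynamics.measures.continuous.gaussian.MutualInfoCalculatorMultiVariateGaussian"
    )
    calc = CalcMI()
    calc.initialise(1, 1)                      # univariate source and target
    x = np.random.randn(1000)
    y = x + 0.5 * np.random.randn(1000)
    calc.setObservations(
        jp.JArray(jp.JDouble)(x.tolist()), jp.JArray(jp.JDouble)(y.tolist())
    )
    mi_bits = nat2bit(calc.computeAverageLocalOfObservations())  # estimator reports nats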
+ def compute_time(ts, fs):
+     """Creates the time array corresponding to the signal.
+     (Redefines the `compute_time` above.)
+
+     Parameters
+     ----------
+     ts: nd-array
+         Input from which the time is computed.
+     fs: int
+         Sampling Frequency
+
+     Returns
+     -------
+     time : float list
+         Signal time
+     """
+
+     return np.arange(0, len(ts)) / fs
+
+
+ def calc_fft(ts, fs):
+     """Computes the FFT of a signal, region by region.
+     (Redefines the 1d `calc_fft` above.)
+
+     Parameters
+     ----------
+     ts : nd-array [n_regions, n_timepoints]
+         The input signal from which fft is computed
+     fs : float
+         Sampling frequency
+
+     Returns
+     -------
+     f: nd-array
+         Frequency values (xx axis)
+     fmag: nd-array [n_regions, n_freqs]
+         Amplitude of the frequency values (yy axis)
+     """
+
+     fmag = np.abs(np.fft.rfft(ts, axis=1))
+     f = np.fft.rfftfreq(len(ts[0]), d=1 / fs)
+
+     return f, fmag
+
+
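A quick check of the rfft-based calc_fft: two pure tones should put each region's spectral peak at its own frequency (values illustrative):

    import numpy as np

    fs = 100.0
    t = np.arange(0, 10, 1 / fs)
    ts = np.vstack([np.sin(2 * np.pi * 3 * t), np.sin(2 * np.pi * 7 * t)])
    f, fmag = calc_fft(ts, fs)
    f[np.argmax(fmag, axis=1)]   # ~[3., 7.]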
+ def fundamental_frequency(f, fmag):
+     """Computes the fundamental frequency of the signals.
+
+     The fundamental frequency is the frequency whose integer multiples
+     best explain the content of the signal spectrum.
+
+     Feature computational cost: 1
+
+     Parameters
+     ----------
+     f : nd-array
+         frequency values
+     fmag : nd-array [n_regions, n_freqs]
+         power spectrum of the signal
+
+     Returns
+     -------
+     f0: array of floats
+         Predominant frequency of the signals
+     """
+
+     def one_dim(f, fmag):
+         bp = scipy.signal.find_peaks(fmag, height=max(fmag) * 0.3)[0]
+
+         # Remove the peak at frequency zero generated by any offset
+         bp = bp[bp != 0]
+         if not list(bp):
+             f0 = 0
+         else:
+             # f0 is the lowest of the big-peak frequencies
+             f0 = f[min(bp)]
+
+         return f0
+
+     r, c = fmag.shape
+     f0 = np.zeros(r)
+     for i in range(r):
+         f0[i] = one_dim(f, fmag[i])
+     labels = [f"fundamental_frequency_{i}" for i in range(len(f0))]
+     return f0, labels
+
+
+ def spectral_distance(freq, fmag):
+     """Computes the signal spectral distance.
+
+     Distance of the signal's cumulative sum of the FFT elements to
+     the respective linear regression.
+
+     Parameters
+     ----------
+     fmag: nd-array [n_regions, n_freqs]
+         power spectrum of the signal
+
+     Returns
+     -------
+     values: array-like
+         spectral distances
+     labels: array-like
+         labels of the features
+     """
+
+     r, c = fmag.shape
+     values = np.zeros(r)
+     cum_fmag = np.cumsum(fmag, axis=1)
+
+     for i in range(r):
+         # straight line from 0 to the final value of the cumulative spectrum
+         points_y = np.linspace(0, cum_fmag[i, -1], c)
+         values[i] = np.sum(points_y - cum_fmag[i]) / c
+     labels = [f"spectral_distance_{i}" for i in range(r)]
+     return values, labels
+
+
+ def max_frequency(f, psd):
+     """
+     Computes the maximum frequency of the signals.
+
+     Parameters
+     ----------
+     f: nd-array
+         frequency values
+     psd: nd-array [n_regions, n_freqs]
+         power spectral density of the signal
+
+     Returns
+     -------
+     values: array-like
+         maximum frequencies
+     """
+     if not isinstance(f, np.ndarray):
+         f = np.array(f)
+     if not isinstance(psd, np.ndarray):
+         psd = np.array(psd)
+     if psd.ndim == 1:
+         psd = psd.reshape(1, -1)
+
+     ind_max = np.argmax(psd, axis=1)
+     fmax = f[ind_max]
+
+     labels = [f"max_frequency_{i}" for i in range(len(fmax))]
+     return fmax, labels
+
+ def max_psd(f, psd):
+     """
+     Computes the maximum power spectral density of the signals.
+
+     Parameters
+     ----------
+     f: nd-array
+         frequency values
+     psd: nd-array [n_regions, n_freqs]
+         power spectral density of the signal
+
+     Returns
+     -------
+     values: array-like
+         maximum power spectral densities
+     """
+     if not isinstance(psd, np.ndarray):
+         psd = np.array(psd)
+     if psd.ndim == 1:
+         psd = psd.reshape(1, -1)
+
+     pmax = np.max(psd, axis=1)
+     labels = [f"max_psd_{i}" for i in range(len(pmax))]
+     return pmax, labels
+
+
+ def median_frequency(f, fmag):
+     """
+     Computes the median frequency of the signals: the frequency below
+     which half of the cumulative spectral magnitude lies.
+     """
+
+     def one_d(cum_fmag):
+         try:
+             ind_mag = np.where(cum_fmag > cum_fmag[-1] * 0.5)[0][0]
+         except IndexError:
+             ind_mag = np.argmax(cum_fmag)
+         return f[ind_mag]
+
+     cum_fmag = np.cumsum(fmag, axis=1)
+     # apply one_d to each row of cum_fmag
+     fmed = np.array(list(map(one_d, cum_fmag)))
+     labels = [f"median_frequency_{i}" for i in range(len(fmed))]
+     return fmed, labels
+
+
+ def spectral_centroid(f, fmag):
+     """
+     Calculate the spectral centroid of the signals.
+     The Timbre Toolbox: Extracting audio descriptors from musical signals
+     Authors: Peeters G., Giordano B., Misdariis P., McAdams S.
+
+     Parameters
+     ----------
+     f: nd-array
+         frequency values
+     fmag: nd-array [n_regions, n_freqs]
+         power spectrum of the signal
+
+     Returns
+     -------
+     values: array-like
+         spectral centroids
+     labels: array-like
+         labels of the features
+     """
+
+     def one_d(f, fmag):
+         if not np.sum(fmag):
+             return 0
+         else:
+             return np.sum(f * fmag) / np.sum(fmag)
+
+     # apply one_d to each row of fmag, with the full frequency axis each time
+     values = np.array([one_d(f, row) for row in fmag])
+     labels = [f"spectral_centroid_{i}" for i in range(len(values))]
+     return values, labels
+
+
+ def spectral_kurtosis(f, fmag):
+     """
+     Measures the flatness of the power spectrum of the signals.
+     The Timbre Toolbox: Extracting audio descriptors from musical signals
+     Authors: Peeters G., Giordano B., Misdariis P., McAdams S.
+
+     Parameters
+     ----------
+     f: nd-array
+         frequency values
+     fmag: nd-array [n_regions, n_freqs]
+         power spectrum of the signal
+
+     Returns
+     -------
+     values: array-like
+         spectral kurtosis
+     labels: array-like
+         labels of the features
+     """
+
+     spread = spectral_spread(f, fmag)[0]
+     centroid = spectral_centroid(f, fmag)[0]
+     values = np.zeros(len(spread))
+     for i in range(len(spread)):
+         if spread[i] == 0:
+             values[i] = 0
+         else:
+             # normalize each region's own spectrum (row i)
+             spect_kurt = ((f - centroid[i]) ** 4) * (fmag[i] / np.sum(fmag[i]))
+             values[i] = np.sum(spect_kurt) / (spread[i] ** 4)
+     labels = [f"spectral_kurtosis_{i}" for i in range(len(values))]
+
+     return values, labels
+
+
+ def spectral_spread(f, fmag):
+     """Measures the spread of the spectrum around its mean value.
+
+     Description and formula in Article:
+     The Timbre Toolbox: Extracting audio descriptors from musical signals
+     Authors: Peeters G., Giordano B., Misdariis P., McAdams S.
+
+     Feature computational cost: 2
+
+     Parameters
+     ----------
+     f : nd-array
+         frequency values
+     fmag : nd-array [n_regions, n_freqs]
+         power spectrum of the signal
+
+     Returns
+     -------
+     values: array-like
+         spectral spreads
+     """
+     n = fmag.shape[0]
+     centroid = spectral_centroid(f, fmag)[0]
+     values = np.zeros(n)
+     for i in range(n):
+         if not np.sum(fmag[i]):
+             values[i] = 0
+         else:
+             values[i] = (
+                 np.dot(((f - centroid[i]) ** 2), (fmag[i] / np.sum(fmag[i]))) ** 0.5
+             )
+
+     return values, [f"spectral_spread_{i}" for i in range(len(values))]
+
+
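A worked example for the two spectral moments above: for a spectrum with equal mass at 1 and 2 Hz and zero elsewhere, the centroid is (1 + 2)/2 = 1.5 Hz and the spread is sqrt(0.5 * 0.25 + 0.5 * 0.25) = 0.5 Hz:

    import numpy as np

    f = np.array([0.0, 1.0, 2.0, 3.0])
    fmag = np.array([[0.0, 1.0, 1.0, 0.0]])
    spectral_centroid(f, fmag)[0]   # [1.5]
    spectral_spread(f, fmag)[0]     # [0.5]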
+ def spectral_variation(freq, fmag):
+     """
+     Computes the amount of variation of the spectrum along time.
+     Spectral variation is computed from the normalized cross-correlation
+     between two consecutive amplitude spectra.
+
+     Description and formula in Article:
+     The Timbre Toolbox: Extracting audio descriptors from musical signals
+     Authors: Peeters G., Giordano B., Misdariis P., McAdams S.
+     """
+
+     def one_d(sum1, sum2, sum3):
+         if not sum2 or not sum3:
+             return 1
+         else:
+             return 1 - (sum1 / ((sum2**0.5) * (sum3**0.5)))
+
+     sum1 = np.sum(fmag[:, :-1] * fmag[:, 1:], axis=1)
+     sum2 = np.sum(fmag[:, 1:] ** 2, axis=1)
+     sum3 = np.sum(fmag[:, :-1] ** 2, axis=1)
+     sums = np.array([sum1, sum2, sum3]).T
+
+     values = np.array(list(map(lambda x: one_d(*x), sums)))
+     labels = [f"spectral_variation_{i}" for i in range(len(values))]
+     return values, labels
+
+
+ def wavelet(signal, function=None, widths=np.arange(1, 10)):
+     """Computes CWT (continuous wavelet transform) of the signal.
+
+     Parameters
+     ----------
+     signal : nd-array
+         Input from which CWT is computed
+     function : wavelet function
+         Default: scipy.signal.ricker
+     widths : nd-array
+         Widths to use for transformation
+         Default: np.arange(1,10)
+
+     Returns
+     -------
+     nd-array
+         The result of the CWT along the time axis
+         matrix with size (len(widths), len(signal))
+     """
+
+     # note: scipy.signal.ricker and scipy.signal.cwt are deprecated in
+     # recent SciPy releases, so this requires an older SciPy
+     if function is None:
+         function = scipy.signal.ricker
+
+     if isinstance(function, str):
+         function = eval(function)
+
+     if isinstance(widths, str):
+         widths = eval(widths)
+
+     cwt = scipy.signal.cwt(signal, function, widths)
+
+     return cwt
+
+
+ def km_order(ts, indices=None, avg=True):
+     """
+     Calculate the (local) Kuramoto order parameter (KOP) of the given time series.
+
+     Parameters
+     ----------
+     ts: np.ndarray (2d) [n_regions, n_timepoints]
+         input array of phases
+     indices: list
+         list of indices of the regions of interest
+     avg: bool
+         if True, average the KOP across time
+
+     Returns
+     -------
+     values: np.ndarray (1d) or float
+         feature values
+     """
+
+     if not isinstance(ts, np.ndarray):
+         ts = np.array(ts)
+
+     if ts.ndim == 1:
+         raise ValueError("Input array must be 2d")
+
+     if indices is None:
+         indices = np.arange(ts.shape[0], dtype=int)
+
+     if max(indices) >= ts.shape[0]:
+         raise ValueError("Invalid indices")
+
+     # np.integer covers all NumPy integer dtypes, not just np.int64
+     if not all(isinstance(i, (int, np.integer)) for i in indices):
+         raise ValueError("Indices must be integers")
+
+     if len(indices) < 2:
+         raise ValueError("At least two indices are required")
+
+     ts = ts[indices, :]
+
+     nn, nt = ts.shape
+     r = np.abs(np.sum(np.exp(1j * ts), axis=0) / nn)
+     if avg:
+         return np.mean(r)
+     else:
+         return r
+
+
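A sanity check for the Kuramoto order parameter (synthetic phases): identical phases across regions give KOP = 1, while independent uniform phases give a much smaller value:

    import numpy as np

    theta = np.tile(np.linspace(0, 2 * np.pi, 100), (5, 1))   # 5 synchronized oscillators
    km_order(theta)                                           # 1.0
    km_order(np.random.uniform(0, 2 * np.pi, (5, 100)))       # << 1 on average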
+ def normalize_signal(ts, method="zscore"):
+     """
+     Normalize the input time series.
+
+     Parameters
+     ----------
+     ts: np.ndarray (2d) [n_regions, n_timepoints]
+         input array
+     method: str
+         normalization method: "zscore", "minmax", "mean", "max", or "none"
+
+     Returns
+     -------
+     ts: np.ndarray (2d) [n_regions, n_timepoints]
+         normalized array
+     """
+
+     if not isinstance(ts, np.ndarray):
+         ts = np.array(ts)
+     if ts.ndim == 1:
+         ts = ts.reshape(1, -1)
+
+     if method == "zscore":
+         ts = stats.zscore(ts, axis=1)
+
+     elif method == "minmax":
+         ts = (ts - np.min(ts, axis=1)[:, None]) / (
+             np.max(ts, axis=1) - np.min(ts, axis=1)
+         )[:, None]
+
+     elif method == "mean":
+         ts = (ts - np.mean(ts, axis=1)[:, None]) / np.std(ts, axis=1)[:, None]
+
+     elif method == "max":
+         ts = ts / np.max(ts, axis=1)[:, None]
+
+     elif method == "none":
+         pass
+
+     else:
+         raise ValueError("Invalid method")
+
+     return ts
+
+
+ def state_duration(
+     hmm_z: np.ndarray, n_states: int, avg: bool = True, tcut: int = 5, bins: int = 10
+ ):
+     """
+     Measure the duration of each state.
+
+     Parameters
+     ----------
+     hmm_z : nd-array [n_samples]
+         The most likely states for each time point
+     n_states : int
+         The number of states
+     avg : bool
+         If True, the average duration histogram over states is returned.
+         Otherwise, the histogram of each state is returned.
+     tcut : int
+         maximum duration of a state, default is 5
+     bins : int
+         number of bins for the histogram, default is 10
+
+     Returns
+     -------
+     stat_vec : array-like
+         The duration histogram of the states
+     """
+
+     inferred_state = hmm_z.astype(int)
+     inferred_state_list, inferred_dur = ssm.util.rle(inferred_state)
+
+     inferred_dur_stack = []
+     for s in range(n_states):
+         inferred_dur_stack.append(inferred_dur[inferred_state_list == s])
+
+     V = []
+     for i in range(n_states):
+         v, _ = np.histogram(inferred_dur_stack[i], bins=bins, range=(0, tcut))
+         V.append(v)
+     V = np.array(V)
+
+     if avg:
+         return V.mean(axis=0)
+     else:
+         return V.flatten()
+
+
+ def seizure_onset_indicator(ts: np.ndarray, thr: float = 0.02):
+     """
+     Return the onset index of seizure-like activity for each region:
+     the location of the largest jump in the signal, or 0 if that jump
+     stays below `thr`.
+     """
+
+     if not isinstance(ts, np.ndarray):
+         ts = np.array(ts)
+
+     if ts.ndim == 1:
+         ts = ts.reshape(1, -1)
+
+     df = np.diff(ts, axis=1)
+     onset_idx = np.argmax(df, axis=1)
+     onset_amp = np.max(df, axis=1)
+     onset_idx = np.where(onset_amp < thr, 0, onset_idx)
+     # onset_amp = np.where(onset_amp < thr, 0, onset_amp)
+     return onset_idx
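A sketch of the intended behaviour (the step signal is invented for illustration): the index of the largest jump is returned per region, and regions whose largest jump stays below `thr` report 0:

    import numpy as np

    ts = np.zeros((2, 100))
    ts[0, 60:] = 1.0                   # abrupt jump between samples 59 and 60
    seizure_onset_indicator(ts)        # array([59, 0])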