py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/__init__.py +80 -13
  5. py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
  6. py_neuromodulation/analysis/__init__.py +4 -0
  7. py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
  8. py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
  9. py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
  10. py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
  11. py_neuromodulation/data/README +6 -6
  12. py_neuromodulation/data/dataset_description.json +8 -8
  13. py_neuromodulation/data/participants.json +32 -32
  14. py_neuromodulation/data/participants.tsv +2 -2
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  22. py_neuromodulation/default_settings.yaml +241 -0
  23. py_neuromodulation/features/__init__.py +31 -0
  24. py_neuromodulation/features/bandpower.py +165 -0
  25. py_neuromodulation/features/bispectra.py +157 -0
  26. py_neuromodulation/features/bursts.py +297 -0
  27. py_neuromodulation/features/coherence.py +255 -0
  28. py_neuromodulation/features/feature_processor.py +121 -0
  29. py_neuromodulation/features/fooof.py +142 -0
  30. py_neuromodulation/features/hjorth_raw.py +57 -0
  31. py_neuromodulation/features/linelength.py +21 -0
  32. py_neuromodulation/features/mne_connectivity.py +148 -0
  33. py_neuromodulation/features/nolds.py +94 -0
  34. py_neuromodulation/features/oscillatory.py +249 -0
  35. py_neuromodulation/features/sharpwaves.py +432 -0
  36. py_neuromodulation/filter/__init__.py +3 -0
  37. py_neuromodulation/filter/kalman_filter.py +67 -0
  38. py_neuromodulation/filter/kalman_filter_external.py +1890 -0
  39. py_neuromodulation/filter/mne_filter.py +128 -0
  40. py_neuromodulation/filter/notch_filter.py +93 -0
  41. py_neuromodulation/grid_cortex.tsv +40 -40
  42. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  43. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  44. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  45. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  46. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  47. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  48. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  49. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  50. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  51. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  52. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  53. py_neuromodulation/processing/__init__.py +10 -0
  54. py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
  55. py_neuromodulation/processing/data_preprocessor.py +77 -0
  56. py_neuromodulation/processing/filter_preprocessing.py +78 -0
  57. py_neuromodulation/processing/normalization.py +175 -0
  58. py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
  59. py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
  60. py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
  61. py_neuromodulation/stream/__init__.py +3 -0
  62. py_neuromodulation/stream/data_processor.py +325 -0
  63. py_neuromodulation/stream/generator.py +53 -0
  64. py_neuromodulation/stream/mnelsl_player.py +94 -0
  65. py_neuromodulation/stream/mnelsl_stream.py +120 -0
  66. py_neuromodulation/stream/settings.py +292 -0
  67. py_neuromodulation/stream/stream.py +427 -0
  68. py_neuromodulation/utils/__init__.py +2 -0
  69. py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
  70. py_neuromodulation/utils/database.py +149 -0
  71. py_neuromodulation/utils/io.py +378 -0
  72. py_neuromodulation/utils/keyboard.py +52 -0
  73. py_neuromodulation/utils/logging.py +66 -0
  74. py_neuromodulation/utils/types.py +251 -0
  75. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
  76. py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
  77. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
  78. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
  79. py_neuromodulation/FieldTrip.py +0 -589
  80. py_neuromodulation/_write_example_dataset_helper.py +0 -65
  81. py_neuromodulation/nm_EpochStream.py +0 -92
  82. py_neuromodulation/nm_IO.py +0 -417
  83. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  84. py_neuromodulation/nm_bispectra.py +0 -168
  85. py_neuromodulation/nm_bursts.py +0 -198
  86. py_neuromodulation/nm_coherence.py +0 -205
  87. py_neuromodulation/nm_cohortwrapper.py +0 -435
  88. py_neuromodulation/nm_eval_timing.py +0 -239
  89. py_neuromodulation/nm_features.py +0 -116
  90. py_neuromodulation/nm_features_abc.py +0 -39
  91. py_neuromodulation/nm_filter.py +0 -219
  92. py_neuromodulation/nm_filter_preprocessing.py +0 -91
  93. py_neuromodulation/nm_fooof.py +0 -159
  94. py_neuromodulation/nm_generator.py +0 -37
  95. py_neuromodulation/nm_hjorth_raw.py +0 -73
  96. py_neuromodulation/nm_kalmanfilter.py +0 -58
  97. py_neuromodulation/nm_linelength.py +0 -33
  98. py_neuromodulation/nm_mne_connectivity.py +0 -112
  99. py_neuromodulation/nm_nolds.py +0 -93
  100. py_neuromodulation/nm_normalization.py +0 -214
  101. py_neuromodulation/nm_oscillatory.py +0 -448
  102. py_neuromodulation/nm_run_analysis.py +0 -435
  103. py_neuromodulation/nm_settings.json +0 -338
  104. py_neuromodulation/nm_settings.py +0 -68
  105. py_neuromodulation/nm_sharpwaves.py +0 -401
  106. py_neuromodulation/nm_stream_abc.py +0 -218
  107. py_neuromodulation/nm_stream_offline.py +0 -359
  108. py_neuromodulation/utils/_logging.py +0 -24
  109. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -0,0 +1,432 @@
1
+ from collections.abc import Sequence
2
+ from collections import defaultdict
3
+ from itertools import product
4
+
5
+ from pydantic import model_validator
6
+ from typing import TYPE_CHECKING, Any, Callable
7
+
8
+ import numpy as np
9
+
10
# Use NumPy's low-level mean implementation for performance. Its private
# module moved in NumPy 2.0. Compare the *major* version numerically: a plain
# string comparison misorders versions such as "10.0.0" < "2.0.0".
try:
    if int(np.__version__.split(".")[0]) >= 2:
        from numpy._core._methods import _mean as np_mean  # type: ignore
    else:
        from numpy.core._methods import _mean as np_mean
except ImportError:
    # The private API is not guaranteed to exist; fall back to the public one.
    np_mean = np.mean
14
+
15
+ from py_neuromodulation.utils.types import (
16
+ NMFeature,
17
+ NMBaseModel,
18
+ BoolSelector,
19
+ FrequencyRange,
20
+ )
21
+
22
+ if TYPE_CHECKING:
23
+ from py_neuromodulation import NMSettings
24
+
25
# Summary statistics that may be applied to the per-trough sharpwave features,
# keyed by the estimator name used in the settings. "mean" uses NumPy's
# low-level implementation for performance; the other estimators could be
# switched to their low-level counterparts in the same way.
ESTIMATOR_DICT = dict(
    mean=np_mean,
    median=np.median,
    max=np.max,
    min=np.min,
    var=np.var,
)
33
+
34
+
35
class PeakDetectionSettings(NMBaseModel):
    """Options for one extremum-detection pass of the sharpwave analysis.

    ``estimate`` toggles whether the corresponding pass ("Peak" or "Trough")
    is computed at all. The two ``*_ms`` fields are minimum distances between
    neighbouring troughs / peaks.
    NOTE(review): the field names say milliseconds -- confirm how the consumer
    converts them before handing them to a peak finder.
    """

    estimate: bool = True
    # Minimum distance between neighbouring troughs (named in ms).
    distance_troughs_ms: float = 10
    # Minimum distance between neighbouring peaks (named in ms).
    distance_peaks_ms: float = 5
39
+
40
+
41
class SharpwaveFeatures(BoolSelector):
    """On/off switches for the individual sharpwave (waveform-shape) features.

    Each field names one per-trough feature computed by ``SharpwaveAnalyzer``;
    set a flag to ``True`` to request that feature. Every enabled feature must
    also be listed under some estimator in ``SharpwaveEstimators`` (checked by
    ``SharpwaveSettings.test_settings``).
    """

    # Signal value at the nearest peak before each trough.
    peak_left: bool = False
    # Signal value at the nearest peak after each trough.
    peak_right: bool = False
    # Number of retained troughs (a single value per analysed trace).
    num_peaks: bool = False
    # Signal value at each trough.
    trough: bool = False
    # Distance between the two peaks surrounding each trough.
    width: bool = False
    # |mean(left peak, right peak) - trough|.
    prominence: bool = True
    # Time between consecutive troughs in ms (first entry is 0).
    interval: bool = True
    # Time offset between each trough and its left peak, in ms.
    decay_time: bool = False
    # Time offset between each trough and its right peak, in ms.
    rise_time: bool = False
    # Trough value minus the mean of the values a fixed offset to either side.
    sharpness: bool = True
    # Max absolute first difference between the left peak and the trough.
    rise_steepness: bool = False
    # Max absolute first difference between the trough and the right peak.
    decay_steepness: bool = False
    # rise_steepness - decay_steepness.
    slope_ratio: bool = False
55
+
56
+
57
class SharpwaveEstimators(NMBaseModel):
    """Assign sharpwave features to the summary statistics applied to them.

    Each field is named after an estimator (an entry of ``ESTIMATOR_DICT``)
    and holds the names of the features summarised with that estimator.
    """

    mean: list[str] = ["interval"]
    median: list[str] = []
    max: list[str] = ["prominence", "sharpness"]
    min: list[str] = []
    var: list[str] = []

    def keys(self):
        """Return the estimator names, in declaration order."""
        return "mean median max min var".split()

    def values(self):
        """Return the feature lists, aligned with ``keys()``."""
        return [
            getattr(self, estimator_name)
            for estimator_name in ("mean", "median", "max", "min", "var")
        ]
69
+
70
+
71
class SharpwaveSettings(NMBaseModel):
    """Configuration of the sharpwave (waveform-shape) feature analysis."""

    sharpwave_features: SharpwaveFeatures = SharpwaveFeatures()
    filter_ranges_hz: list[FrequencyRange] = [
        FrequencyRange(5, 80),
        FrequencyRange(5, 30),
    ]
    detect_troughs: PeakDetectionSettings = PeakDetectionSettings()
    detect_peaks: PeakDetectionSettings = PeakDetectionSettings()
    estimator: SharpwaveEstimators = SharpwaveEstimators()
    apply_estimator_between_peaks_and_troughs: bool = True

    def disable_all_features(self):
        """Switch off every sharpwave feature and clear all estimator lists."""
        self.sharpwave_features.disable_all()
        for est in self.estimator.keys():
            self.estimator[est] = []

    # The first parameter stays named ``cls`` on purpose: pydantic inspects it
    # to decide how the validator is bound, and ``settings`` receives the model.
    @model_validator(mode="after")
    def test_settings(cls, settings):
        """Ensure each enabled feature is summarised by at least one estimator.

        Raises
        ------
        ValueError
            If an enabled feature is not listed under any estimator; pydantic
            surfaces it as a ``ValidationError``.
        """
        estimator_features = {
            feature for feature_list in settings.estimator.values() for feature in feature_list
        }

        for used_feature in settings.sharpwave_features.get_enabled():
            if used_feature not in estimator_features:
                # Explicit raise instead of ``assert`` so the check also runs
                # when Python is started with ``-O``.
                raise ValueError(f"Add estimator key for {used_feature}")

        return settings
98
+
99
+
100
class SharpwaveAnalyzer(NMFeature):
    """Waveform-shape ("sharpwave") features of band-pass filtered signals.

    Every channel is filtered with each range in
    ``sharpwave_analysis_settings.filter_ranges_hz``. In every filtered trace,
    troughs are paired with their neighbouring peaks, the enabled per-trough
    features are computed and then summarised with the configured estimators.
    The analysis runs on the trace itself ("Peak" pass) and on its
    sign-flipped copy ("Trough" pass), each gated by its ``estimate`` flag.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
    ) -> None:
        self.sw_settings = settings.sharpwave_analysis_settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.list_filter: list[tuple[str, Any]] = []
        self.trough: list = []
        self.troughs_idx: list = []

        settings.validate()

        # FrequencyRange's are already ensured to have high > low
        # Test that the higher frequency is smaller than the sampling frequency
        for filter_range in self.sw_settings.filter_ranges_hz:
            assert filter_range[1] < sfreq, (
                "Filter range has to be smaller than sfreq, "
                f"got sfreq {sfreq} and filter range {filter_range}"
            )

        for filter_range in self.sw_settings.filter_ranges_hz:
            if filter_range[0] is None:
                # Unfiltered pass-through; realised as an identity kernel below.
                self.list_filter.append(("no_filter", None))
            else:
                from mne.filter import create_filter

                self.list_filter.append(
                    (
                        f"range_{filter_range[0]:.0f}_{filter_range[1]:.0f}",
                        create_filter(
                            None,
                            sfreq,
                            l_freq=filter_range[0],
                            h_freq=filter_range[1],
                            fir_design="firwin",
                            verbose=False,
                        ),
                    )
                )

        self.filter_names = [name for name, _ in self.list_filter]
        # Final shape: (n_channels, n_filters, kernel_length)
        self.filters = self._stack_kernels([kernel for _, kernel in self.list_filter])
        self.filters = np.tile(self.filters[None, :, :], (len(self.ch_names), 1, 1))

        self.used_features = self.sw_settings.sharpwave_features.get_enabled()

        # initializing estimator functions, respective for all sharpwave features
        self.estimator_dict: dict[str, dict[str, Callable]] = {
            feat: {
                est: ESTIMATOR_DICT[est]
                for est in self.sw_settings.estimator.keys()
                if feat in self.sw_settings.estimator[est]
            }
            for feat_list in self.sw_settings.estimator.values()
            for feat in feat_list
        }

        estimator_combinations = [
            (feature_name, estimator_name, estimator)
            for feature_name in self.used_features
            for estimator_name, estimator in self.estimator_dict[feature_name].items()
        ]

        filter_combinations = list(
            product(
                enumerate(self.ch_names), enumerate(self.filter_names), [False, True]
            )
        )

        self.estimator_key_map: dict[str, Callable] = {}
        self.combinations = []
        for (ch_idx, ch_name), (
            filter_idx,
            filter_name,
        ), detect_troughs in filter_combinations:
            for feature_name, estimator_name, estimator in estimator_combinations:
                key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
                self.estimator_key_map[key_name] = estimator
            # Exactly one entry per (channel, filter, pass): the waveform
            # analysis runs once per entry and all estimators reuse its result.
            if estimator_combinations:
                self.combinations.append(
                    (
                        (ch_idx, ch_name),
                        (filter_idx, filter_name),
                        detect_troughs,
                        estimator_combinations,
                    )
                )

        # Check required feature computations according to settings
        features = self.sw_settings.sharpwave_features
        self.need_peak_left = features.peak_left or features.prominence
        self.need_peak_right = features.peak_right or features.prominence
        self.need_trough = features.trough or features.prominence
        self.need_decay_steepness = features.decay_steepness or features.slope_ratio
        self.need_rise_steepness = features.rise_steepness or features.slope_ratio
        self.need_steepness = self.need_rise_steepness or self.need_decay_steepness

    @staticmethod
    def _stack_kernels(kernels: list) -> np.ndarray:
        """Zero-pad FIR kernels (``None`` = identity) to a common length and stack them.

        Padding is split evenly on both sides. This keeps each kernel centred,
        so ``mode="same"`` convolution gives the same output as the unpadded
        kernel. When all kernels already share one length, nothing is padded.
        """
        arrays = [
            np.ones(1) if kernel is None else np.asarray(kernel, dtype=float)
            for kernel in kernels
        ]
        max_len = max(array.size for array in arrays)
        padded = []
        for array in arrays:
            missing = max_len - array.size
            left = missing // 2
            padded.append(np.pad(array, (left, missing - left)))
        return np.vstack(padded)

    def calc_feature(self, data: np.ndarray) -> dict:
        """Filter a new data batch and compute the configured sharpwave features.

        Only the new data batch is analysed. Each pre-initialised filter is
        applied to every channel, and the waveform features are summarised per
        (channel, filter, estimator).

        Parameters
        ----------
        data : np.ndarray
            2-D array of shape (num_channels, samples).

        Returns
        -------
        dict
            Feature name -> value. If
            ``apply_estimator_between_peaks_and_troughs`` is set, the
            estimator is applied once more across the "Peak" and "Trough"
            results. Otherwise both are reported with an ``_analyze_Peak`` /
            ``_analyze_Trough`` suffix.
        """
        from scipy.signal import fftconvolve

        dict_ch_features: dict[str, dict[str, float]] = defaultdict(dict)

        data = np.tile(data[:, None, :], (1, len(self.list_filter), 1))
        data = fftconvolve(data, self.filters, axes=2, mode="same")

        # TONI: Expose filtered data for example 3, need a better way
        self.filtered_data = data

        feature_results = {}

        for (
            (ch_idx, ch_name),
            (filter_idx, filter_name),
            detect_troughs,
            estimator_combinations,
        ) in self.combinations:
            if detect_troughs:
                if not self.sw_settings.detect_troughs.estimate:
                    continue
            elif not self.sw_settings.detect_peaks.estimate:
                continue

            key_name_pt = "Trough" if detect_troughs else "Peak"

            sub_data = data[ch_idx, filter_idx, :]
            # The "Trough" pass analyses the sign-flipped trace.
            if detect_troughs:
                sub_data = -sub_data

            waveform_results = self.analyze_waveform(sub_data)

            for feature_name, estimator_name, estimator in estimator_combinations:
                feature_data = waveform_results[feature_name]
                key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
                # zero check because it can happen that no peaks are detected
                dict_ch_features[key_name][key_name_pt] = (
                    estimator(feature_data) if len(feature_data) != 0 else 0
                )

        if self.sw_settings.apply_estimator_between_peaks_and_troughs:
            # Apply the same estimator across the available 'Peak'/'Trough'
            # values; the key name stays, since the estimator is unchanged.
            for key_name, estimator in self.estimator_key_map.items():
                pt_values = list(dict_ch_features.get(key_name, {}).values())
                if not pt_values:
                    # Neither pass was estimated (both ``estimate`` flags off).
                    continue
                feature_results[key_name] = estimator(pt_values)
        else:
            # otherwise write all "flattened" key value pairs to feature_results
            for key, subdict in dict_ch_features.items():
                for key_sub, value_sub in subdict.items():
                    feature_results[key + "_analyze_" + key_sub] = value_sub

        return feature_results

    def analyze_waveform(self, data) -> dict:
        """Compute the required per-trough waveform features of a 1-D signal.

        Troughs are local minima and peaks local maxima of ``data``
        (``scipy.signal.find_peaks``). Each trough is paired with the nearest
        peak on either side. Troughs without a peak on both sides are dropped,
        so all per-trough arrays stay aligned.

        Parameters
        ----------
        data : np.ndarray
            1-D filtered signal of one channel / filter band.

        Returns
        -------
        dict
            Feature name -> per-trough values (``num_peaks`` is a one-element
            list). Only features required by the current settings are present.
        """
        from scipy.signal import find_peaks

        ms_per_sample = 1000 / self.sfreq
        samples_per_ms = self.sfreq / 1000

        # The distances are configured in milliseconds, while find_peaks expects
        # samples and rejects values below one sample.
        # NOTE(review): both distances come from ``detect_troughs``; the
        # ``detect_peaks`` distances are not used here.
        detect = self.sw_settings.detect_troughs
        peak_idx: np.ndarray = find_peaks(
            data, distance=max(1.0, detect.distance_peaks_ms * samples_per_ms)
        )[0]
        trough_idx: np.ndarray = find_peaks(
            -data, distance=max(1.0, detect.distance_troughs_ms * samples_per_ms)
        )[0]

        # Pair each trough with its neighbouring peaks: ``right_pos`` is the
        # index of the first peak at or after the trough; ``right_pos - 1`` is
        # the last peak before it. Troughs missing either neighbour are dropped.
        right_pos = np.searchsorted(peak_idx, trough_idx, side="left")
        has_both = (right_pos > 0) & (right_pos < peak_idx.size)
        trough_idx = trough_idx[has_both]
        right_pos = right_pos[has_both]
        peak_idx_left = peak_idx[right_pos - 1]
        peak_idx_right = peak_idx[right_pos]

        features = self.sw_settings.sharpwave_features
        results: dict = {}

        if self.need_peak_left:
            results["peak_left"] = data[peak_idx_left]

        if self.need_peak_right:
            results["peak_right"] = data[peak_idx_right]

        if self.need_trough:
            results["trough"] = data[trough_idx]

        if features.interval:
            results["interval"] = (
                np.concatenate((np.zeros(1), np.diff(trough_idx))) * ms_per_sample
            )

        if features.sharpness:
            # Sharpness uses a +-5 ms window, converted to samples (at least one).
            # Valid troughs need that margin on both sides.
            offset = max(1, round(5 * samples_per_ms))
            in_bounds = (trough_idx - offset >= 0) & (
                trough_idx + offset < data.shape[0]
            )
            troughs_valid = trough_idx[in_bounds]
            trough_height = data[troughs_valid]
            left_height = data[troughs_valid - offset]
            right_height = data[troughs_valid + offset]
            results["sharpness"] = trough_height - 0.5 * (left_height + right_height)

        if features.num_peaks:
            # keep a list so the estimator can be applied
            results["num_peaks"] = [trough_idx.shape[0]]

        if self.need_steepness:
            # steepness is calculated as the first derivative
            steepness: np.ndarray = np.concatenate((np.zeros(1), np.diff(data)))

            # 0th dimension: left-peak->trough / trough->right-peak segment,
            # 1st: trough index, 2nd: timepoint (zero-padded).
            steepness_troughs = np.zeros((2, trough_idx.shape[0], steepness.shape[0]))
            for i, (left, trough, right) in enumerate(
                zip(peak_idx_left, trough_idx, peak_idx_right)
            ):
                # + 1 so that the segment end point is included
                steepness_troughs[0, i, : trough - left + 1] = steepness[left : trough + 1]
                steepness_troughs[1, i, : right - trough + 1] = steepness[trough : right + 1]

            if self.need_rise_steepness:
                # left peak -> trough
                results["rise_steepness"] = np.max(
                    np.abs(steepness_troughs[0, :, :]), axis=1
                )

            if self.need_decay_steepness:
                # trough -> right peak
                results["decay_steepness"] = np.max(
                    np.abs(steepness_troughs[1, :, :]), axis=1
                )

            if features.slope_ratio:
                results["slope_ratio"] = (
                    results["rise_steepness"] - results["decay_steepness"]
                )

        if features.prominence:
            results["prominence"] = np.abs(
                (results["peak_right"] + results["peak_left"]) / 2 - results["trough"]
            )

        if features.decay_time:
            # NOTE(review): left peak precedes the trough, so this is <= 0 ms;
            # confirm whether a positive duration was intended.
            results["decay_time"] = (peak_idx_left - trough_idx) * ms_per_sample

        if features.rise_time:
            results["rise_time"] = (peak_idx_right - trough_idx) * ms_per_sample

        if features.width:
            # Peak-to-peak width in ms, consistent with the other time features.
            results["width"] = (peak_idx_right - peak_idx_left) * ms_per_sample

        return results
@@ -0,0 +1,3 @@
1
+ from .kalman_filter import define_KF, KalmanSettings
2
+ from .notch_filter import NotchFilter
3
+ from .mne_filter import MNEFilter
@@ -0,0 +1,67 @@
1
+ import numpy as np
2
+ from typing import TYPE_CHECKING
3
+
4
+ from py_neuromodulation.utils.types import NMBaseModel
5
+
6
+
7
+ if TYPE_CHECKING:
8
+ from py_neuromodulation.stream.settings import NMSettings
9
+
10
+
11
class KalmanSettings(NMBaseModel):
    """Kalman filter parameters and the frequency bands to which it applies.

    The field names match the parameters of ``define_KF`` (``Tp``: prediction
    interval, ``sigma_w``: process noise, ``sigma_v``: measurement noise).
    NOTE(review): presumably these are passed to ``define_KF`` by the caller;
    verify.
    """

    Tp: float = 0.1
    sigma_w: float = 0.7
    sigma_v: float = 1.0
    frequency_bands: list[str] = [
        "theta",
        "alpha",
        "low_beta",
        "high_beta",
        "low_gamma",
        "high_gamma",
        "HFA",
    ]

    def validate_fbands(self, settings: "NMSettings") -> None:
        """Check that every Kalman band is defined in ``settings.frequency_ranges_hz``.

        Raises
        ------
        AssertionError
            If one or more of ``frequency_bands`` is missing; the message
            lists the missing bands.
        """
        missing = [
            band for band in self.frequency_bands if band not in settings.frequency_ranges_hz
        ]
        if missing:
            # Raised explicitly rather than via ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for existing callers.
            raise AssertionError(
                "Frequency bands for Kalman filter must also be specified in "
                f"frequency_ranges_hz; missing: {missing}"
            )
32
+
33
+
34
def define_KF(Tp, sigma_w, sigma_v):
    """Build a two-state Kalman filter for a white-noise-acceleration model.

    The state holds the sensor signal and its first derivative; only the
    signal itself is observed.
    Model: see DOI 10.1109/TBME.2009.2038990.
    Implementation details:
    https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html#r64ca38088676-2

    Parameters
    ----------
    Tp : float
        prediction interval
    sigma_w : float
        process noise
    sigma_v : float
        measurement noise

    Returns
    -------
    KalmanFilter
        initialized KalmanFilter object (vendored filterpy implementation)
    """
    from .kalman_filter_external import KalmanFilter

    process_var = sigma_w**2

    kf = KalmanFilter(dim_x=2, dim_z=1)
    # state: [sensor signal, its first derivative]
    kf.x = np.array([0, 1])
    # signal advances by derivative * Tp; derivative is carried over
    kf.F = np.array([[1, Tp], [0, 1]])
    # only the signal (first state) is measured
    kf.H = np.array([[1, 0]])
    # NOTE(review): sigma_v is used directly as R; if it is meant as a standard
    # deviation, sigma_v**2 may be intended (identical for sigma_v == 1).
    kf.R = sigma_v
    kf.Q = np.array(
        [
            [process_var * Tp**3 / 3, process_var * Tp**2 / 2],
            [process_var * Tp**2 / 2, process_var * Tp],
        ]
    )
    # Identical to np.cov([[1, 0], [0, 1]]). NOTE(review): this initial
    # covariance is singular; np.eye(2) may have been intended.
    kf.P = np.array([[0.5, -0.5], [-0.5, 0.5]])
    return kf