py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/FieldTrip.py +589 -589
  5. py_neuromodulation/__init__.py +74 -13
  6. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  7. py_neuromodulation/data/README +6 -6
  8. py_neuromodulation/data/dataset_description.json +8 -8
  9. py_neuromodulation/data/participants.json +32 -32
  10. py_neuromodulation/data/participants.tsv +2 -2
  11. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  12. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  13. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  14. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  18. py_neuromodulation/grid_cortex.tsv +40 -40
  19. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  20. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  21. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  22. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  23. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  24. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  25. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  26. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  27. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  28. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  29. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  30. py_neuromodulation/nm_IO.py +413 -417
  31. py_neuromodulation/nm_RMAP.py +496 -531
  32. py_neuromodulation/nm_analysis.py +993 -1074
  33. py_neuromodulation/nm_artifacts.py +30 -25
  34. py_neuromodulation/nm_bispectra.py +154 -168
  35. py_neuromodulation/nm_bursts.py +292 -198
  36. py_neuromodulation/nm_coherence.py +251 -205
  37. py_neuromodulation/nm_database.py +149 -0
  38. py_neuromodulation/nm_decode.py +918 -992
  39. py_neuromodulation/nm_define_nmchannels.py +300 -302
  40. py_neuromodulation/nm_features.py +144 -116
  41. py_neuromodulation/nm_filter.py +219 -219
  42. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  43. py_neuromodulation/nm_fooof.py +139 -159
  44. py_neuromodulation/nm_generator.py +45 -37
  45. py_neuromodulation/nm_hjorth_raw.py +52 -73
  46. py_neuromodulation/nm_kalmanfilter.py +71 -58
  47. py_neuromodulation/nm_linelength.py +21 -33
  48. py_neuromodulation/nm_logger.py +66 -0
  49. py_neuromodulation/nm_mne_connectivity.py +149 -112
  50. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  51. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  52. py_neuromodulation/nm_nolds.py +96 -93
  53. py_neuromodulation/nm_normalization.py +173 -214
  54. py_neuromodulation/nm_oscillatory.py +423 -448
  55. py_neuromodulation/nm_plots.py +585 -612
  56. py_neuromodulation/nm_preprocessing.py +83 -0
  57. py_neuromodulation/nm_projection.py +370 -394
  58. py_neuromodulation/nm_rereference.py +97 -95
  59. py_neuromodulation/nm_resample.py +59 -50
  60. py_neuromodulation/nm_run_analysis.py +325 -435
  61. py_neuromodulation/nm_settings.py +289 -68
  62. py_neuromodulation/nm_settings.yaml +244 -0
  63. py_neuromodulation/nm_sharpwaves.py +423 -401
  64. py_neuromodulation/nm_stats.py +464 -480
  65. py_neuromodulation/nm_stream.py +398 -0
  66. py_neuromodulation/nm_stream_abc.py +166 -218
  67. py_neuromodulation/nm_types.py +193 -0
  68. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
  69. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  70. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
  71. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
  72. py_neuromodulation/nm_EpochStream.py +0 -92
  73. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  74. py_neuromodulation/nm_cohortwrapper.py +0 -435
  75. py_neuromodulation/nm_eval_timing.py +0 -239
  76. py_neuromodulation/nm_features_abc.py +0 -39
  77. py_neuromodulation/nm_settings.json +0 -338
  78. py_neuromodulation/nm_stream_offline.py +0 -359
  79. py_neuromodulation/utils/_logging.py +0 -24
  80. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
py_neuromodulation/nm_sharpwaves.py
@@ -1,401 +1,423 @@
- import numpy as np
- from scipy import signal
- from mne.filter import create_filter
- from typing import Iterable
-
- from py_neuromodulation import nm_features_abc
-
-
- class NoValidTroughException(Exception):
-     pass
-
-
- class SharpwaveAnalyzer(nm_features_abc.Feature):
-     def __init__(
-         self, settings: dict, ch_names: Iterable[str], sfreq: float
-     ) -> None:
-
-         self.sw_settings = settings["sharpwave_analysis_settings"]
-         self.sfreq = sfreq
-         self.ch_names = ch_names
-         self.list_filter = []
-         for filter_range in settings["sharpwave_analysis_settings"][
-             "filter_ranges_hz"
-         ]:
-
-             if filter_range[0] is None:
-                 self.list_filter.append(("no_filter", None))
-             else:
-                 self.list_filter.append(
-                     (
-                         f"range_{filter_range[0]}_{filter_range[1]}",
-                         create_filter(
-                             None,
-                             sfreq,
-                             l_freq=filter_range[0],
-                             h_freq=filter_range[1],
-                             fir_design="firwin",
-                             # l_trans_bandwidth=None,
-                             # h_trans_bandwidth=None,
-                             # filter_length=str(sfreq) + "ms",
-                             verbose=False,
-                         ),
-                     )
-                 )
-
-         # initialize used features
-         self.used_features = list()
-         for feature_name, val in self.sw_settings["sharpwave_features"].items():
-             if val is True:
-                 self.used_features.append(feature_name)
-
-         # initialize attributes
-         self._initialize_sw_features()
-
-         # initializing estimator functions, respecitive for all sharpwave features
-         fun_names = []
-         for used_feature in self.used_features:
-             estimator_list_feature = (
-                 []
-             ) # one feature can have multiple estimators
-             for estimator, est_features in self.sw_settings[
-                 "estimator"
-             ].items():
-                 if est_features is not None:
-                     for est_feature in est_features:
-                         if used_feature == est_feature:
-                             estimator_list_feature.append(estimator)
-             fun_names.append(estimator_list_feature)
-
-         self.estimator_names = fun_names
-         self.estimator_functions = [
-             [
-                 getattr(np, est_name)
-                 for est_name in self.estimator_names[feature_idx]
-             ]
-             for feature_idx in range(len(self.estimator_names))
-         ]
-
-     def _initialize_sw_features(self) -> None:
-         """Resets used attributes to empty lists"""
-         for feature_name in self.used_features:
-             setattr(self, feature_name, list())
-         if "trough" not in self.used_features:
-             # trough attribute is still necessary, even if it is not specified in settings
-             self.trough = list()
-         self.troughs_idx = list()
-
-     def _get_peaks_around(self, trough_ind, arr_ind_peaks, filtered_dat):
-         """Find the closest peaks to the right and left side a given trough.
-
-         Parameters
-         ----------
-         trough_ind (int): index of trough
-         arr_ind_peaks (np.ndarray): array of peak indices
-         filtered_dat (np.ndarray): raw data batch
-
-         Raises:
-             NoValidTroughException: Returned if no adjacent peak can be found
-
-         Returns
-         -------
-         peak_left_idx (np.ndarray): index of left peak
-         peak_right_idx (np.ndarray): index of right peak
-         peak_left_val (np.ndarray): value of left peak
-         peak_right_val (np.ndarray): value of righ peak
-         """
-
-         try: peak_right_idx = arr_ind_peaks[arr_ind_peaks > trough_ind][0]
-         except IndexError: raise NoValidTroughException("No valid trough")
-
-         try: peak_left_idx = arr_ind_peaks[arr_ind_peaks < trough_ind][-1]
-         except IndexError: raise NoValidTroughException("No valid trough")
-
-         return (
-             peak_left_idx,
-             peak_right_idx,
-             filtered_dat[peak_left_idx],
-             filtered_dat[peak_right_idx],
-         )
-
-     def calc_feature(
-         self,
-         data: np.array,
-         features_compute: dict,
-     ) -> dict:
-         """Given a new data batch, the peaks, troughs and sharpwave features
-         are estimated. Importantly only new data is being analyzed here. In
-         steps of 1/settings["sampling_rate_features] analyzed and returned.
-         Pre-initialized filters are applied to each channel.
-
-         Parameters
-         ----------
-         data (np.ndarray): 2d data array with shape [num_channels, samples]
-         features_compute (dict): Features.py estimated features
-
-         Returns
-         -------
-         features_compute (dict): set features for Features.py object
-         """
-         for ch_idx, ch_name in enumerate(self.ch_names):
-             for filter_name, filter in self.list_filter:
-                 self.data_process_sw = (data[ch_idx, :]
-                     if filter_name == "no_filter"
-                     else signal.fftconvolve(data[ch_idx, :], filter, mode="same")
-                 )
-
-                 # check settings if troughs and peaks are analyzed
-
-                 dict_ch_features = {}
-
-                 for detect_troughs in [False, True]:
-
-                     if detect_troughs is False:
-                         if (
-                             self.sw_settings["detect_peaks"]["estimate"]
-                             is False
-                         ):
-                             continue
-                         key_name_pt = "Peak"
-                         # the detect_troughs loop start with peaks, s.t. data does not
-                         # need to be flipped
-
-                     if detect_troughs is True:
-                         if (
-                             self.sw_settings["detect_troughs"]["estimate"]
-                             is False
-                         ):
-                             continue
-                         key_name_pt = "Trough"
-
-                         self.data_process_sw = -self.data_process_sw
-
-                     self._initialize_sw_features() # reset sharpwave feature attriubtes to empty lists
-                     self.analyze_waveform()
-
-                     # for each feature take the respective fun.
-                     for feature_idx, feature_name in enumerate(
-                         self.used_features
-                     ):
-                         for est_idx, estimator_name in enumerate(
-                             self.estimator_names[feature_idx]
-                         ):
-                             key_name = "_".join(
-                                 [
-                                     ch_name,
-                                     "Sharpwave",
-                                     self.estimator_names[feature_idx][
-                                         est_idx
-                                     ].title(),
-                                     feature_name,
-                                     filter_name,
-                                 ]
-                             )
-                             # zero check because no peaks can be also detected
-                             val = (
-                                 self.estimator_functions[feature_idx][est_idx](
-                                     getattr(self, feature_name)
-                                 )
-                                 if len(getattr(self, feature_name)) != 0
-                                 else 0
-                             )
-                             if key_name not in dict_ch_features:
-                                 dict_ch_features[key_name] = {}
-                             dict_ch_features[key_name][key_name_pt] = val
-
-                 if self.sw_settings[
-                     "apply_estimator_between_peaks_and_troughs"
-                 ]:
-                     # apply between 'Trough' and 'Peak' the respective function again
-                     # save only the 'est_fun' (e.g. max) between them
-
-                     for idx, key_name in enumerate(dict_ch_features):
-                         # the key_name stays, since the estimator function stays between peaks and troughs
-                         # this array needs to be flattened
-                         features_compute[key_name] = np.concatenate(
-                             self.estimator_functions
-                         )[idx](
-                             [
-                                 list(dict_ch_features[key_name].values())[0],
-                                 list(dict_ch_features[key_name].values())[1],
-                             ]
-                         )
-
-                 else:
-                     # otherwise, save all
-                     # write all "flatted" key value pairs in features_
-                     for key, value in dict_ch_features.items():
-                         for key_sub, value_sub in dict_ch_features[key].items():
-                             features_compute[
-                                 key + "_analyze_" + key_sub
-                             ] = value_sub
-
-         return features_compute
-
-     def analyze_waveform(self) -> None:
-         """Given the scipy.signal.find_peaks trough/peak distance
-         settings specified sharpwave features are estimated.
-
-         Parameters
-         ----------
-
-         Raises:
-             NoValidTroughException: Return if no adjacent peak can be found
-             NoValidTroughException: Return if no adjacent peak can be found
-
-         """
-
-         peaks = signal.find_peaks(
-             self.data_process_sw,
-             distance=self.sw_settings["detect_troughs"]["distance_peaks_ms"],
-         )[0]
-         troughs = signal.find_peaks(
-             -self.data_process_sw,
-             distance=self.sw_settings["detect_troughs"]["distance_troughs_ms"],
-         )[0]
-
-         """ Find left and right peak indexes for each trough """
-         peak_pointer = 0
-         peak_idx_left = []
-         peak_idx_right = []
-         first_valid = last_valid = 0
-
-         for i, trough_idx in enumerate(troughs):
-
-             # Locate peak right of current trough
-             while peak_pointer < peaks.size and peaks[peak_pointer] < trough_idx:
-                 peak_pointer += 1
-
-             if peak_pointer - 1 < 0:
-                 # If trough has no peak to it's left, it's not valid
-                 first_valid = i + 1 # Try with next one
-                 continue
-
-             if peak_pointer == peaks.size:
-                 # If we went past the end of the peaks list, trough had no peak to its right
-                 continue
-
-             last_valid = i
-             peak_idx_left.append(peaks[peak_pointer - 1])
-             peak_idx_right.append(peaks[peak_pointer])
-
-         troughs = troughs[first_valid:last_valid + 1] # Remove non valid troughs
-
-         peak_idx_left = np.array(peak_idx_left, dtype=np.integer)
-         peak_idx_right = np.array(peak_idx_right, dtype=np.integer)
-
-         peak_left = self.data_process_sw[peak_idx_left]
-         peak_right = self.data_process_sw[peak_idx_right]
-         trough_values = self.data_process_sw[troughs]
-
-         # No need to store trough data as it is not used anywhere else in the program
-         # self.trough.append(trough)
-         # self.troughs_idx.append(trough_idx)
-
-         """ Calculate features (vectorized) """
-
-         if self.sw_settings["sharpwave_features"]["interval"]:
-             self.interval = np.concatenate(([0], np.diff(troughs))) * (1000 / self.sfreq)
-
-         if self.sw_settings["sharpwave_features"]["peak_left"]:
-             self.peak_left = peak_left
-
-         if self.sw_settings["sharpwave_features"]["peak_right"]:
-             self.peak_right = peak_right
-
-         if self.sw_settings["sharpwave_features"]["sharpness"]:
-             # sharpess is calculated on a +- 5 ms window
-             # valid troughs need 5 ms of margin on both siddes
-             troughs_valid = troughs[np.logical_and(
-                 troughs - int(5 * (1000 / self.sfreq)) > 0,
-                 troughs + int(5 * (1000 / self.sfreq)) < self.data_process_sw.shape[0])]
-
-             self.sharpness = (
-                 (self.data_process_sw[troughs_valid] - self.data_process_sw[troughs_valid - int(5 * (1000 / self.sfreq))]) +
-                 (self.data_process_sw[troughs_valid] - self.data_process_sw[troughs_valid + int(5 * (1000 / self.sfreq))])
-             ) / 2
-
-         if (self.sw_settings["sharpwave_features"]["rise_steepness"] or
-                 self.sw_settings["sharpwave_features"]["decay_steepness"]):
-
-             # steepness is calculated as the first derivative
-             steepness = np.concatenate(([0],np.diff(self.data_process_sw)))
-
-             if self.sw_settings["sharpwave_features"]["rise_steepness"]: # left peak -> trough
-                 # + 1 due to python syntax, s.t. the last element is included
-                 self.rise_steepness = np.array([
-                     np.max(np.abs(steepness[peak_idx_left[i] : troughs[i] + 1]))
-                     for i in range(trough_idx.size)
-                 ])
-
-             if self.sw_settings["sharpwave_features"]["decay_steepness"]: # trough -> right peak
-                 self.decay_steepness = np.array([
-                     np.max(np.abs(steepness[troughs[i] : peak_idx_right[i] + 1]))
-                     for i in range(trough_idx.size)
-                 ])
-
-             if (self.sw_settings["sharpwave_features"]["rise_steepness"] and
-                     self.sw_settings["sharpwave_features"]["decay_steepness"] and
-                     self.sw_settings["sharpwave_features"]["slope_ratio"]):
-                 self.slope_ratio = self.rise_steepness - self.decay_steepness
-
-         if self.sw_settings["sharpwave_features"]["prominence"]:
-             self.prominence = np.abs((peak_right + peak_left) / 2 - trough_values)
-
-         if self.sw_settings["sharpwave_features"]["decay_time"]:
-             self.decay_time = (peak_idx_left - troughs) * (1000 / self.sfreq) # ms
-
-         if self.sw_settings["sharpwave_features"]["rise_time"]:
-             self.rise_time = (peak_idx_right - troughs) * (1000 / self.sfreq) # ms
-
-         if self.sw_settings["sharpwave_features"]["width"]:
-             self.width = peak_idx_right - peak_idx_left # ms
-
-     @staticmethod
-     def test_settings(
-         s: dict,
-         ch_names: Iterable[str],
-         sfreq: int | float,
-     ):
-         for filter_range in s["sharpwave_analysis_settings"][
-             "filter_ranges_hz"
-         ]:
-             assert isinstance(
-                 filter_range[0],
-                 int,
-             ), f"filter range needs to be of type int, got {filter_range[0]}"
-             assert isinstance(
-                 filter_range[1],
-                 int,
-             ), f"filter range needs to be of type int, got {filter_range[1]}"
-             assert (
-                 filter_range[1] > filter_range[0]
-             ), f"second filter value needs to be higher than first one, got {filter_range}"
-             assert filter_range[0] < sfreq and filter_range[1] < sfreq, (
-                 "filter range has to be smaller than sfreq, "
-                 f"got sfreq {sfreq} and filter range {filter_range}"
-             )
-         # check if all features are also enbled via an estimator
-         used_features = list()
-         for feature_name, val in s["sharpwave_analysis_settings"][
-             "sharpwave_features"
-         ].items():
-             assert isinstance(
-                 val, bool
-             ), f"sharpwave_feature type {feature_name} has to be of type bool, got {val}"
-             if val is True:
-                 used_features.append(feature_name)
-         for used_feature in used_features:
-             estimator_list_feature = (
-                 []
-             ) # one feature can have multiple estimators
-             for estimator, est_features in s["sharpwave_analysis_settings"][
-                 "estimator"
-             ].items():
-                 if est_features is not None:
-                     for est_feature in est_features:
-                         if used_feature == est_feature:
-                             estimator_list_feature.append(estimator)
-             assert (
-                 len(estimator_list_feature) > 0
-             ), f"add estimator key for {used_feature}"
+ from collections.abc import Sequence
+ from collections import defaultdict
+ from itertools import product
+
+ from py_neuromodulation.nm_types import NMBaseModel
+ from pydantic import model_validator
+ from typing import TYPE_CHECKING, Any, Callable
+
+ import numpy as np
+
+ if np.__version__ >= "2.0.0":
+     from numpy._core._methods import _mean as np_mean # type: ignore
+ else:
+     from numpy.core._methods import _mean as np_mean
+
+ from py_neuromodulation.nm_features import NMFeature
+ from py_neuromodulation.nm_types import BoolSelector, FrequencyRange
+
+ if TYPE_CHECKING:
+     from py_neuromodulation.nm_settings import NMSettings
+
+ # Using low-level numpy mean function for performance, could do the same for the other estimators
+ ESTIMATOR_DICT = {
+     "mean": np_mean,
+     "median": np.median,
+     "max": np.max,
+     "min": np.min,
+     "var": np.var,
+ }
+
+
+ class PeakDetectionSettings(NMBaseModel):
+     estimate: bool = True
+     distance_troughs_ms: float = 10
+     distance_peaks_ms: float = 5
+
+
+ class SharpwaveFeatures(BoolSelector):
+     peak_left: bool = False
+     peak_right: bool = False
+     trough: bool = False
+     width: bool = False
+     prominence: bool = True
+     interval: bool = True
+     decay_time: bool = False
+     rise_time: bool = False
+     sharpness: bool = True
+     rise_steepness: bool = False
+     decay_steepness: bool = False
+     slope_ratio: bool = False
+
+
+ class SharpwaveEstimators(NMBaseModel):
+     mean: list[str] = ["interval"]
+     median: list[str] = []
+     max: list[str] = ["prominence", "sharpness"]
+     min: list[str] = []
+     var: list[str] = []
+
+     def keys(self):
+         return ["mean", "median", "max", "min", "var"]
+
+     def values(self):
+         return [self.mean, self.median, self.max, self.min, self.var]
+
+
+ class SharpwaveSettings(NMBaseModel):
+     sharpwave_features: SharpwaveFeatures = SharpwaveFeatures()
+     filter_ranges_hz: list[FrequencyRange] = [
+         FrequencyRange(5, 80),
+         FrequencyRange(5, 30),
+     ]
+     detect_troughs: PeakDetectionSettings = PeakDetectionSettings()
+     detect_peaks: PeakDetectionSettings = PeakDetectionSettings()
+     estimator: SharpwaveEstimators = SharpwaveEstimators()
+     apply_estimator_between_peaks_and_troughs: bool = True
+
+     def disable_all_features(self):
+         self.sharpwave_features.disable_all()
+         for est in self.estimator.keys():
+             self.estimator[est] = []
+
+     @model_validator(mode="after")
+     def test_settings(cls, settings):
+         # check if all features are also enabled via an estimator
+         estimator_list = [est for list_ in settings.estimator.values() for est in list_]
+
+         for used_feature in settings.sharpwave_features.get_enabled():
+             assert (
+                 used_feature in estimator_list
+             ), f"Add estimator key for {used_feature}"
+
+         return settings
+
+
+ class SharpwaveAnalyzer(NMFeature):
+     def __init__(
+         self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
+     ) -> None:
+         self.sw_settings = settings.sharpwave_analysis_settings
+         self.sfreq = sfreq
+         self.ch_names = ch_names
+         self.list_filter: list[tuple[str, Any]] = []
+         self.trough: list = []
+         self.troughs_idx: list = []
+
+         settings.validate()
+
+         # FrequencyRange's are already ensured to have high > low
+         # Test that the higher frequency is smaller than the sampling frequency
+         for filter_range in settings.sharpwave_analysis_settings.filter_ranges_hz:
+             assert filter_range[1] < sfreq, (
+                 "Filter range has to be smaller than sfreq, "
+                 f"got sfreq {sfreq} and filter range {filter_range}"
+             )
+
+         for filter_range in settings.sharpwave_analysis_settings.filter_ranges_hz:
+             # Test settings
+             # TODO: handle None values
+             if filter_range[0] is None:
+                 self.list_filter.append(("no_filter", None))
+             else:
+                 from mne.filter import create_filter
+
+                 self.list_filter.append(
+                     (
+                         f"range_{filter_range[0]:.0f}_{filter_range[1]:.0f}",
+                         create_filter(
+                             None,
+                             sfreq,
+                             l_freq=filter_range[0],
+                             h_freq=filter_range[1],
+                             fir_design="firwin",
+                             # l_trans_bandwidth=None,
+                             # h_trans_bandwidth=None,
+                             # filter_length=str(sfreq) + "ms",
+                             verbose=False,
+                         ),
+                     )
+                 )
+
+         self.filter_names = [name for name, _ in self.list_filter]
+         self.filters = np.vstack([filter for _, filter in self.list_filter])
+         self.filters = np.tile(self.filters[None, :, :], (len(self.ch_names), 1, 1))
+
+         self.used_features = self.sw_settings.sharpwave_features.get_enabled()
+
+         # initializing estimator functions, respecitive for all sharpwave features
+         self.estimator_dict: dict[str, dict[str, Callable]] = {
+             feat: {
+                 est: ESTIMATOR_DICT[est]
+                 for est in self.sw_settings.estimator.keys()
+                 if feat in self.sw_settings.estimator[est]
+             }
+             for feat_list in self.sw_settings.estimator.values()
+             for feat in feat_list
+         }
+
+         estimator_combinations = [
+             (feature_name, estimator_name, estimator)
+             for feature_name in self.used_features
+             for estimator_name, estimator in self.estimator_dict[feature_name].items()
+         ]
+
+         filter_combinations = list(
+             product(
+                 enumerate(self.ch_names), enumerate(self.filter_names), [False, True]
+             )
+         )
+
+         self.estimator_key_map: dict[str, Callable] = {}
+         self.combinations = []
+         for (ch_idx, ch_name), (
+             filter_idx,
+             filter_name,
+         ), detect_troughs in filter_combinations:
+             for feature_name, estimator_name, estimator in estimator_combinations:
+                 key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
+                 self.estimator_key_map[key_name] = estimator
+             self.combinations.append(
+                 (
+                     (ch_idx, ch_name),
+                     (filter_idx, filter_name),
+                     detect_troughs,
+                     estimator_combinations,
+                 )
+             )
+
+         # Check required feature computations according to settings
+         self.need_peak_left = (
+             self.sw_settings.sharpwave_features.peak_left
+             or self.sw_settings.sharpwave_features.prominence
+         )
+         self.need_peak_right = (
+             self.sw_settings.sharpwave_features.peak_right
+             or self.sw_settings.sharpwave_features.prominence
+         )
+         self.need_trough = (
+             self.sw_settings.sharpwave_features.trough
+             or self.sw_settings.sharpwave_features.prominence
+         )
+
+         self.need_decay_steepness = (
+             self.sw_settings.sharpwave_features.decay_steepness
+             or self.sw_settings.sharpwave_features.slope_ratio
+         )
+
+         self.need_rise_steepness = (
+             self.sw_settings.sharpwave_features.rise_steepness
+             or self.sw_settings.sharpwave_features.slope_ratio
+         )
+
+         self.need_steepness = self.need_rise_steepness or self.need_decay_steepness
+
+     def calc_feature(self, data: np.ndarray) -> dict:
+         """Given a new data batch, the peaks, troughs and sharpwave features
+         are estimated. Importantly only new data is being analyzed here. In
+         steps of 1/settings["sampling_rate_features] analyzed and returned.
+         Pre-initialized filters are applied to each channel.
+
+         Parameters
+         ----------
+         data (np.ndarray): 2d data array with shape [num_channels, samples]
+         feature_results (dict): Features.py estimated features
+
+         Returns
+         -------
+         feature_results (dict): set features for Features.py object
+         """
+         dict_ch_features: dict[str, dict[str, float]] = defaultdict(lambda: {})
+
+         from scipy.signal import fftconvolve
+
+         data = np.tile(data[:, None, :], (1, len(self.list_filter), 1))
+         data = fftconvolve(data, self.filters, axes=2, mode="same")
+
+         self.filtered_data = (
+             data # TONI: Expose filtered data for example 3, need a better way
+         )
+
+         feature_results = {}
+
+         for (
+             (ch_idx, ch_name),
+             (filter_idx, filter_name),
+             detect_troughs,
+             estimator_combinations,
+         ) in self.combinations:
+             sub_data = data[ch_idx, filter_idx, :]
+
+             key_name_pt = "Trough" if detect_troughs else "Peak"
+
+             if (not detect_troughs and not self.sw_settings.detect_peaks.estimate) or (
+                 detect_troughs and not self.sw_settings.detect_troughs.estimate
+             ):
+                 continue
+
+             # the detect_troughs loop start with peaks, s.t. data does not need to be flipped
+             sub_data = -sub_data if detect_troughs else sub_data
+             # sub_data *= 1 - 2 * detect_troughs # branchless version
+
+             waveform_results = self.analyze_waveform(sub_data)
+
+             # for each feature take the respective fun.
+             for feature_name, estimator_name, estimator in estimator_combinations:
+                 feature_data = waveform_results[feature_name]
+                 key_name = f"{ch_name}_Sharpwave_{estimator_name.title()}_{feature_name}_{filter_name}"
+
+                 # zero check because no peaks can be also detected
+                 feature_data = estimator(feature_data) if len(feature_data) != 0 else 0
+                 dict_ch_features[key_name][key_name_pt] = feature_data
+
+         if self.sw_settings.apply_estimator_between_peaks_and_troughs:
+             # apply between 'Trough' and 'Peak' the respective function again
+             # save only the 'est_fun' (e.g. max) between them
+
+             # the key_name stays, since the estimator function stays between peaks and troughs
+             for key_name, estimator in self.estimator_key_map.items():
+                 feature_results[key_name] = estimator(
+                     [
+                         list(dict_ch_features[key_name].values())[0],
+                         list(dict_ch_features[key_name].values())[1],
+                     ]
+                 )
+         else:
+             # otherwise, save all write all "flattened" key value pairs in feature_results
+             for key, subdict in dict_ch_features.items():
+                 for key_sub, value_sub in subdict.items():
+                     feature_results[key + "_analyze_" + key_sub] = value_sub
+
+         return feature_results
+
+     def analyze_waveform(self, data) -> dict:
+         """Given the scipy.signal.find_peaks trough/peak distance
+         settings specified sharpwave features are estimated.
+         """
+
+         from scipy.signal import find_peaks
+
+         # TODO: find peaks is actually not that big a performance hit, but the rest
+         # of this function is. Perhaps find_peaks can be put in a loop and the rest optimized somehow?
+         peak_idx: np.ndarray = find_peaks(
+             data, distance=self.sw_settings.detect_troughs.distance_peaks_ms
+         )[0]
+         trough_idx: np.ndarray = find_peaks(
+             -data, distance=self.sw_settings.detect_troughs.distance_troughs_ms
+         )[0]
+
+         """ Find left and right peak indexes for each trough """
+         peak_pointer = first_valid = last_valid = 0
+         peak_idx_left_list: list[int] = []
+         peak_idx_right_list: list[int] = []
+
+         for i in range(len(trough_idx)):
+             # Locate peak right of current trough
+             while (
+                 peak_pointer < peak_idx.size and peak_idx[peak_pointer] < trough_idx[i]
+             ):
+                 peak_pointer += 1
+
+             if peak_pointer - 1 < 0:
+                 # If trough has no peak to it's left, it's not valid
+                 first_valid = i + 1 # Try with next one
+                 continue
+
+             if peak_pointer == peak_idx.size:
+                 # If we went past the end of the peaks list, trough had no peak to its right
+                 continue
+
+             last_valid = i
+             peak_idx_left_list.append(peak_idx[peak_pointer - 1])
+             peak_idx_right_list.append(peak_idx[peak_pointer])
+
+         # Remove non valid troughs and make array of left and right peaks for each trough
+         trough_idx = trough_idx[first_valid : last_valid + 1]
+         peak_idx_left = np.array(peak_idx_left_list, dtype=int)
+         peak_idx_right = np.array(peak_idx_right_list, dtype=int)
+
+         """ Calculate features (vectorized) """
+         results: dict = {}
+
+         if self.need_peak_left:
+             results["peak_left"] = data[peak_idx_left]
+
+         if self.need_peak_right:
+             results["peak_right"] = data[peak_idx_right]
+
+         if self.need_trough:
+             results["trough"] = data[trough_idx]
+
+         if self.sw_settings.sharpwave_features.interval:
+             results["interval"] = np.concatenate((np.zeros(1), np.diff(trough_idx))) * (
+                 1000 / self.sfreq
+             )
+
+         if self.sw_settings.sharpwave_features.sharpness:
+             # sharpess is calculated on a +- 5 ms window
+             # valid troughs need 5 ms of margin on both sides
+             troughs_valid = trough_idx[
+                 np.logical_and(
+                     trough_idx - int(5 * (1000 / self.sfreq)) > 0,
+                     trough_idx + int(5 * (1000 / self.sfreq)) < data.shape[0],
+                 )
+             ]
+             trough_height = data[troughs_valid]
+             left_height = data[troughs_valid - int(5 * (1000 / self.sfreq))]
+             right_height = data[troughs_valid + int(5 * (1000 / self.sfreq))]
+             # results["sharpness"] = ((trough_height - left_height) + (trough_height - right_height)) / 2
+             results["sharpness"] = trough_height - 0.5 * (left_height + right_height)
+
+         if self.need_steepness:
+             # steepness is calculated as the first derivative
+             steepness: np.ndarray = np.concatenate((np.zeros(1), np.diff(data)))
+
+             # Create an array with the rise and decay steepness for each trough
+             # 0th dimension for rise/decay, 1st for trough index, 2nd for timepoint
+             steepness_troughs = np.zeros((2, trough_idx.shape[0], steepness.shape[0]))
+             if self.need_rise_steepness or self.need_decay_steepness:
+                 for i in range(len(trough_idx)):
+                     steepness_troughs[
+                         0, i, 0 : trough_idx[i] - peak_idx_left[i] + 1
+                     ] = steepness[peak_idx_left[i] : trough_idx[i] + 1]
+                     steepness_troughs[
+                         1, i, 0 : peak_idx_right[i] - trough_idx[i] + 1
+                     ] = steepness[trough_idx[i] : peak_idx_right[i] + 1]
+
+             if self.need_rise_steepness:
+                 # left peak -> trough
+                 # + 1 due to python syntax, s.t. the last element is included
+                 results["rise_steepness"] = np.max(
+                     np.abs(steepness_troughs[0, :, :]), axis=1
+                 )
+
+             if self.need_decay_steepness:
+                 # trough -> right peak
+                 results["decay_steepness"] = np.max(
+                     np.abs(steepness_troughs[1, :, :]), axis=1
+                 )
+
+             if self.sw_settings.sharpwave_features.slope_ratio:
+                 results["slope_ratio"] = (
+                     results["rise_steepness"] - results["decay_steepness"]
+                 )
+
+         if self.sw_settings.sharpwave_features.prominence:
+             results["prominence"] = np.abs(
+                 (results["peak_right"] + results["peak_left"]) / 2 - results["trough"]
+             )
+
+         if self.sw_settings.sharpwave_features.decay_time:
+             results["decay_time"] = (peak_idx_left - trough_idx) * (
+                 1000 / self.sfreq
+             ) # ms
+
+         if self.sw_settings.sharpwave_features.rise_time:
+             results["rise_time"] = (peak_idx_right - trough_idx) * (
+                 1000 / self.sfreq
+             ) # ms
+
+         if self.sw_settings.sharpwave_features.width:
+             results["width"] = peak_idx_right - peak_idx_left # ms
+
+         return results
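
The biggest functional change in this hunk is how filtering is applied: calc_feature no longer convolves each (channel, filter) pair in a nested Python loop, but stacks the FIR kernels, tiles them across channels, and filters everything with one fftconvolve call over the sample axis. The standalone sketch below reproduces that strategy; the kernels k1 and k2 and all shapes are illustrative stand-ins, not code from the release (the package stacks MNE-designed band-pass filters, which np.vstack likewise requires to have equal tap counts).

import numpy as np
from scipy.signal import fftconvolve

rng = np.random.default_rng(0)
n_ch, n_samp, n_taps = 4, 1000, 31
data = rng.standard_normal((n_ch, n_samp))

# Two stand-in FIR kernels of equal length
k1 = np.ones(n_taps) / n_taps                       # moving average
k2 = np.hanning(n_taps) / np.hanning(n_taps).sum()  # smoother low-pass

filters = np.tile(np.vstack([k1, k2])[None, :, :], (n_ch, 1, 1))  # (n_ch, n_filt, n_taps)
batch = np.tile(data[:, None, :], (1, 2, 1))                      # (n_ch, n_filt, n_samp)

# One call convolves every channel with every kernel along the last axis
filtered = fftconvolve(batch, filters, axes=2, mode="same")
print(filtered.shape)  # (4, 2, 1000)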
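The refactor also replaces the old static test_settings method with a pydantic model_validator that runs when the settings object is built. A minimal sketch of the resulting behavior, assuming py-neuromodulation 0.0.5 is installed and that NMBaseModel behaves like a regular pydantic BaseModel (only names visible in the hunk above are used):

from py_neuromodulation.nm_sharpwaves import SharpwaveSettings

# Defaults pass validation: the enabled features (prominence, interval, sharpness)
# are covered by the default estimators (mean -> interval, max -> prominence/sharpness).
settings = SharpwaveSettings()

# Enabling a feature that no estimator lists should now fail at construction time,
# because the mode="after" validator asserts every enabled feature has an estimator.
try:
    SharpwaveSettings(sharpwave_features={"width": True})
except Exception as err:  # pydantic wraps the AssertionError in a ValidationError
    print(err)  # "... Add estimator key for width"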