py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/__init__.py +80 -13
  5. py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
  6. py_neuromodulation/analysis/__init__.py +4 -0
  7. py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
  8. py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
  9. py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
  10. py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
  11. py_neuromodulation/data/README +6 -6
  12. py_neuromodulation/data/dataset_description.json +8 -8
  13. py_neuromodulation/data/participants.json +32 -32
  14. py_neuromodulation/data/participants.tsv +2 -2
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  22. py_neuromodulation/default_settings.yaml +241 -0
  23. py_neuromodulation/features/__init__.py +31 -0
  24. py_neuromodulation/features/bandpower.py +165 -0
  25. py_neuromodulation/features/bispectra.py +157 -0
  26. py_neuromodulation/features/bursts.py +297 -0
  27. py_neuromodulation/features/coherence.py +255 -0
  28. py_neuromodulation/features/feature_processor.py +121 -0
  29. py_neuromodulation/features/fooof.py +142 -0
  30. py_neuromodulation/features/hjorth_raw.py +57 -0
  31. py_neuromodulation/features/linelength.py +21 -0
  32. py_neuromodulation/features/mne_connectivity.py +148 -0
  33. py_neuromodulation/features/nolds.py +94 -0
  34. py_neuromodulation/features/oscillatory.py +249 -0
  35. py_neuromodulation/features/sharpwaves.py +432 -0
  36. py_neuromodulation/filter/__init__.py +3 -0
  37. py_neuromodulation/filter/kalman_filter.py +67 -0
  38. py_neuromodulation/filter/kalman_filter_external.py +1890 -0
  39. py_neuromodulation/filter/mne_filter.py +128 -0
  40. py_neuromodulation/filter/notch_filter.py +93 -0
  41. py_neuromodulation/grid_cortex.tsv +40 -40
  42. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  43. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  44. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  45. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  46. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  47. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  48. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  49. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  50. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  51. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  52. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  53. py_neuromodulation/processing/__init__.py +10 -0
  54. py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
  55. py_neuromodulation/processing/data_preprocessor.py +77 -0
  56. py_neuromodulation/processing/filter_preprocessing.py +78 -0
  57. py_neuromodulation/processing/normalization.py +175 -0
  58. py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
  59. py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
  60. py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
  61. py_neuromodulation/stream/__init__.py +3 -0
  62. py_neuromodulation/stream/data_processor.py +325 -0
  63. py_neuromodulation/stream/generator.py +53 -0
  64. py_neuromodulation/stream/mnelsl_player.py +94 -0
  65. py_neuromodulation/stream/mnelsl_stream.py +120 -0
  66. py_neuromodulation/stream/settings.py +292 -0
  67. py_neuromodulation/stream/stream.py +427 -0
  68. py_neuromodulation/utils/__init__.py +2 -0
  69. py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
  70. py_neuromodulation/utils/database.py +149 -0
  71. py_neuromodulation/utils/io.py +378 -0
  72. py_neuromodulation/utils/keyboard.py +52 -0
  73. py_neuromodulation/utils/logging.py +66 -0
  74. py_neuromodulation/utils/types.py +251 -0
  75. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
  76. py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
  77. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
  78. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
  79. py_neuromodulation/FieldTrip.py +0 -589
  80. py_neuromodulation/_write_example_dataset_helper.py +0 -65
  81. py_neuromodulation/nm_EpochStream.py +0 -92
  82. py_neuromodulation/nm_IO.py +0 -417
  83. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  84. py_neuromodulation/nm_bispectra.py +0 -168
  85. py_neuromodulation/nm_bursts.py +0 -198
  86. py_neuromodulation/nm_coherence.py +0 -205
  87. py_neuromodulation/nm_cohortwrapper.py +0 -435
  88. py_neuromodulation/nm_eval_timing.py +0 -239
  89. py_neuromodulation/nm_features.py +0 -116
  90. py_neuromodulation/nm_features_abc.py +0 -39
  91. py_neuromodulation/nm_filter.py +0 -219
  92. py_neuromodulation/nm_filter_preprocessing.py +0 -91
  93. py_neuromodulation/nm_fooof.py +0 -159
  94. py_neuromodulation/nm_generator.py +0 -37
  95. py_neuromodulation/nm_hjorth_raw.py +0 -73
  96. py_neuromodulation/nm_kalmanfilter.py +0 -58
  97. py_neuromodulation/nm_linelength.py +0 -33
  98. py_neuromodulation/nm_mne_connectivity.py +0 -112
  99. py_neuromodulation/nm_nolds.py +0 -93
  100. py_neuromodulation/nm_normalization.py +0 -214
  101. py_neuromodulation/nm_oscillatory.py +0 -448
  102. py_neuromodulation/nm_run_analysis.py +0 -435
  103. py_neuromodulation/nm_settings.json +0 -338
  104. py_neuromodulation/nm_settings.py +0 -68
  105. py_neuromodulation/nm_sharpwaves.py +0 -401
  106. py_neuromodulation/nm_stream_abc.py +0 -218
  107. py_neuromodulation/nm_stream_offline.py +0 -359
  108. py_neuromodulation/utils/_logging.py +0 -24
  109. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
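The listing above replaces the flat `nm_*` modules of 0.0.4 with subpackages (`analysis`, `features`, `filter`, `processing`, `stream`, `utils`). A minimal sketch of how import paths change as a result, assuming the new modules are imported by the paths shown in the file list (whether the top-level `py_neuromodulation/__init__.py` also re-exports them is not visible in this diff):

# 0.0.4-style imports against the removed flat modules
# from py_neuromodulation import nm_analysis, nm_settings, nm_stream_offline

# 0.0.6-style imports following the layout in the file list above
from py_neuromodulation.analysis import feature_reader  # was nm_analysis.py
from py_neuromodulation.stream import settings, stream  # replace nm_settings.py / nm_stream_offline.py
from py_neuromodulation.features import bandpower       # new feature module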
py_neuromodulation/features/bandpower.py (new file)
@@ -0,0 +1,165 @@
+import numpy as np
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+from pydantic import field_validator
+
+from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
+
+if TYPE_CHECKING:
+    from py_neuromodulation.stream.settings import NMSettings
+    from py_neuromodulation.filter import KalmanSettings
+
+
+class BandpowerFeatures(BoolSelector):
+    activity: bool = True
+    mobility: bool = False
+    complexity: bool = False
+
+
+class BandPowerSettings(NMBaseModel):
+    segment_lengths_ms: dict[str, int] = {
+        "theta": 1000,
+        "alpha": 500,
+        "low beta": 333,
+        "high beta": 333,
+        "low gamma": 100,
+        "high gamma": 100,
+        "HFA": 100,
+    }
+    bandpower_features: BandpowerFeatures = BandpowerFeatures()
+    log_transform: bool = True
+    kalman_filter: bool = False
+
+    @field_validator("bandpower_features")
+    @classmethod
+    def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
+        assert (
+            len(bandpower_features.get_enabled()) > 0
+        ), "Set at least one bandpower_feature to True."
+
+        return bandpower_features
+
+    def validate_fbands(self, settings: "NMSettings") -> None:
+        for fband_name, seg_length_fband in self.segment_lengths_ms.items():
+            assert seg_length_fband <= settings.segment_length_features_ms, (
+                f"segment length {seg_length_fband} needs to be smaller than "
+                f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
+            )
+
+        for fband_name in settings.frequency_ranges_hz.keys():
+            assert fband_name in self.segment_lengths_ms, (
+                f"frequency range {fband_name} "
+                "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]"
+            )
+
+
+class BandPower(NMFeature):
+    def __init__(
+        self,
+        settings: "NMSettings",
+        ch_names: Sequence[str],
+        sfreq: float,
+        use_kf: bool | None = None,
+    ) -> None:
+        settings.validate()
+
+        self.bp_settings: BandPowerSettings = settings.bandpass_filter_settings
+        self.kalman_filter_settings: KalmanSettings = settings.kalman_filter_settings
+        self.sfreq = sfreq
+        self.ch_names = ch_names
+        self.KF_dict: dict = {}
+
+        from py_neuromodulation.filter import MNEFilter
+
+        self.bandpass_filter = MNEFilter(
+            f_ranges=[
+                tuple(frange) for frange in settings.frequency_ranges_hz.values()
+            ],
+            sfreq=self.sfreq,
+            filter_length=self.sfreq - 1,
+            verbose=False,
+        )
+
+        if use_kf or (use_kf is None and self.bp_settings.kalman_filter):
+            self.init_KF("bandpass_activity")
+
+        seglengths = self.bp_settings.segment_lengths_ms
+
+        self.feature_params = []
+        for ch_idx, ch_name in enumerate(self.ch_names):
+            for f_band_idx, f_band in enumerate(settings.frequency_ranges_hz.keys()):
+                seglength_ms = seglengths[f_band]
+                seglen = int(np.floor(self.sfreq / 1000 * seglength_ms))
+                for bp_feature in self.bp_settings.bandpower_features.get_enabled():
+                    feature_name = "_".join([ch_name, "bandpass", bp_feature, f_band])
+                    self.feature_params.append(
+                        (
+                            ch_idx,
+                            f_band_idx,
+                            seglen,
+                            bp_feature,
+                            feature_name,
+                        )
+                    )
+
+    def init_KF(self, feature: str) -> None:
+        from py_neuromodulation.filter import define_KF
+
+        for f_band in self.kalman_filter_settings.frequency_bands:
+            for channel in self.ch_names:
+                self.KF_dict["_".join([channel, feature, f_band])] = define_KF(
+                    self.kalman_filter_settings.Tp,
+                    self.kalman_filter_settings.sigma_w,
+                    self.kalman_filter_settings.sigma_v,
+                )
+
+    def update_KF(self, feature_calc: np.floating, KF_name: str) -> np.floating:
+        if KF_name in self.KF_dict:
+            self.KF_dict[KF_name].predict()
+            self.KF_dict[KF_name].update(feature_calc)
+            feature_calc = self.KF_dict[KF_name].x[0]
+        return feature_calc
+
+    def calc_feature(self, data: np.ndarray) -> dict:
+        data = self.bandpass_filter.filter_data(data)
+
+        feature_results = {}
+        for (
+            ch_idx,
+            f_band_idx,
+            seglen,
+            bp_feature,
+            feature_name,
+        ) in self.feature_params:
+            feature_results[feature_name] = self.calc_bp_feature(
+                bp_feature, feature_name, data[ch_idx, f_band_idx, -seglen:]
+            )
+
+        return feature_results
+
+    def calc_bp_feature(self, bp_feature, feature_name, data):
+        match bp_feature:
+            case "activity":
+                feature_calc = np.var(data)
+                if self.bp_settings.log_transform:
+                    feature_calc = np.log10(feature_calc)
+                if self.KF_dict:
+                    feature_calc = self.update_KF(feature_calc, feature_name)
+            case "mobility":
+                feature_calc = np.sqrt(np.var(np.diff(data)) / np.var(data))
+            case "complexity":
+                feature_calc = self.calc_complexity(data)
+            case _:
+                raise ValueError(f"Unknown bandpower feature: {bp_feature}")
+
+        return np.nan_to_num(feature_calc)
+
+    @staticmethod
+    def calc_complexity(data: np.ndarray) -> float:
+        dat_deriv = np.diff(data)
+        deriv_variance = np.var(dat_deriv)
+        mobility = np.sqrt(deriv_variance / np.var(data))
+        dat_deriv_2_var = np.var(np.diff(dat_deriv))
+        deriv_mobility = np.sqrt(dat_deriv_2_var / deriv_variance)
+
+        return deriv_mobility / mobility
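The new `features/bandpower.py` above computes, per channel and frequency band, the three Hjorth-style measures selected in `BandpowerFeatures`. A standalone sketch of those formulas on a hypothetical 1-D segment (plain NumPy, not using the package classes), mirroring `calc_bp_feature` and `calc_complexity`:

import numpy as np

# Toy band-pass filtered segment; in the package this is one channel/band slice
# of length `seglen` (e.g. 333 samples for "low beta" at 1 kHz).
rng = np.random.default_rng(0)
segment = rng.standard_normal(333)

activity = np.var(segment)                        # "activity": signal variance
d1 = np.diff(segment)
mobility = np.sqrt(np.var(d1) / np.var(segment))  # "mobility"
d2 = np.diff(d1)
complexity = np.sqrt(np.var(d2) / np.var(d1)) / mobility  # "complexity"

print(np.log10(activity), mobility, complexity)   # log10 mirrors log_transform=True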
py_neuromodulation/features/bispectra.py (new file)
@@ -0,0 +1,157 @@
+from collections.abc import Iterable
+from pydantic import field_validator
+from typing import TYPE_CHECKING, Callable
+
+import numpy as np
+
+from py_neuromodulation.utils.types import (
+    NMBaseModel,
+    NMFeature,
+    BoolSelector,
+    FrequencyRange,
+)
+
+if TYPE_CHECKING:
+    from py_neuromodulation import NMSettings
+
+
+class BispectraComponents(BoolSelector):
+    absolute: bool = True
+    real: bool = True
+    imag: bool = True
+    phase: bool = True
+
+
+class BispectraFeatures(BoolSelector):
+    mean: bool = True
+    sum: bool = True
+    var: bool = True
+
+
+class BispectraSettings(NMBaseModel):
+    f1s: FrequencyRange = FrequencyRange(5, 35)
+    f2s: FrequencyRange = FrequencyRange(5, 35)
+    compute_features_for_whole_fband_range: bool = True
+    frequency_bands: list[str] = ["theta", "alpha", "low_beta", "high_beta"]
+
+    components: BispectraComponents = BispectraComponents()
+    bispectrum_features: BispectraFeatures = BispectraFeatures()
+
+    @field_validator("f1s", "f2s")
+    def test_range(cls, filter_range):
+        assert (
+            filter_range[1] > filter_range[0]
+        ), f"second frequency range value needs to be higher than first one, got {filter_range}"
+        return filter_range
+
+    @field_validator("frequency_bands")
+    def fbands_spaces_to_underscores(cls, frequency_bands):
+        return [f.replace(" ", "_") for f in frequency_bands]
+
+
+FEATURE_DICT: dict[str, Callable] = {
+    "mean": np.nanmean,
+    "sum": np.nansum,
+    "var": np.nanvar,
+}
+
+COMPONENT_DICT: dict[str, Callable] = {
+    "real": lambda obj: getattr(obj, "real"),
+    "imag": lambda obj: getattr(obj, "imag"),
+    "absolute": np.abs,
+    "phase": np.angle,
+}
+
+
+class Bispectra(NMFeature):
+    def __init__(
+        self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
+    ) -> None:
+        self.sfreq = sfreq
+        self.ch_names = ch_names
+        self.frequency_ranges_hz = settings.frequency_ranges_hz
+        self.settings: BispectraSettings = settings.bispectrum_settings
+
+        assert all(
+            f_band_bispectrum in settings.frequency_ranges_hz
+            for f_band_bispectrum in self.settings.frequency_bands
+        ), (
+            "bispectrum selected frequency bands don't match the ones"
+            "specified in s['frequency_ranges_hz']"
+            f"bispectrum frequency bands: {self.settings.frequency_bands}"
+            f"specified frequency_ranges_hz: {settings.frequency_ranges_hz}"
+        )
+
+        self.used_features = self.settings.bispectrum_features.get_enabled()
+
+        self.min_freq = min(
+            self.settings.f1s.frequency_low_hz, self.settings.f2s.frequency_low_hz
+        )
+        self.max_freq = max(
+            self.settings.f1s.frequency_high_hz, self.settings.f2s.frequency_high_hz
+        )
+
+        # self.freqs: np.ndarray = np.array([]) # In case we pre-computed this
+
+    def calc_feature(self, data: np.ndarray) -> dict:
+        from pybispectra import compute_fft, WaveShape
+
+        # PyBispectra's compute_fft uses PQDM to parallelize the calculation per channel
+        # Is this necessary? Maybe the overhead of parallelization is not worth it
+        # considering that we incur in it once per batch of data
+        fft_coeffs, freqs = compute_fft(
+            data=np.expand_dims(data, axis=(0)),
+            sampling_freq=self.sfreq,
+            n_points=data.shape[1],
+            verbose=False,
+        )
+
+        # freqs is batch independent, except for the last batch perhaps (if it has different shape)
+        # but it's computed by compute_fft regardless so no advantage in pre-computing it
+        # if not self.freqs = self.freqs = np.fft.rfftfreq(n=data.shape[1], d = 1 / sfreq)
+
+        # fft_coeffs shape: [epochs, channels, frequencies]
+
+        f_spectrum_range = freqs[
+            np.logical_and(freqs >= self.min_freq, freqs <= self.max_freq)
+        ]
+
+        waveshape = WaveShape(
+            data=fft_coeffs,
+            freqs=freqs,
+            sampling_freq=self.sfreq,
+            verbose=False,
+        )
+
+        waveshape.compute(
+            f1s=tuple(self.settings.f1s),  # type: ignore
+            f2s=tuple(self.settings.f2s),  # type: ignore
+        )
+
+        feature_results = {}
+        for ch_idx, ch_name in enumerate(self.ch_names):
+            bispectrum = waveshape._bicoherence[
+                ch_idx
+            ]  # Same as waveshape.results._data, skips a copy
+
+            for component in self.settings.components.get_enabled():
+                spectrum_ch = COMPONENT_DICT[component](bispectrum)
+
+                for fb in self.settings.frequency_bands:
+                    range_ = (f_spectrum_range >= self.frequency_ranges_hz[fb][0]) & (
+                        f_spectrum_range <= self.frequency_ranges_hz[fb][1]
+                    )
+                    # waveshape.results.plot()
+                    data_bs = spectrum_ch[range_, range_]
+
+                    for bispectrum_feature in self.used_features:
+                        feature_results[
+                            f"{ch_name}_Bispectrum_{component}_{bispectrum_feature}_{fb}"
+                        ] = FEATURE_DICT[bispectrum_feature](data_bs)
+
+                        if self.settings.compute_features_for_whole_fband_range:
+                            feature_results[
+                                f"{ch_name}_Bispectrum_{component}_{bispectrum_feature}_whole_fband_range"
+                            ] = FEATURE_DICT[bispectrum_feature](spectrum_ch)
+
+        return feature_results
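`features/bispectra.py` reduces each channel's bispectrum to per-band summary features: one component (absolute, real, imag, phase) is taken from the complex matrix, in-band values are selected along both frequency axes, and mean/sum/var are applied. A sketch of that selection-and-reduction step on a synthetic complex matrix standing in for one channel's result (the real values come from pybispectra's `WaveShape`; the band limits here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
freqs = np.arange(5, 36)  # matches the default f1s/f2s range of 5-35 Hz
bispectrum = rng.standard_normal((freqs.size, freqs.size)) + 1j * rng.standard_normal(
    (freqs.size, freqs.size)
)

band = (13, 20)  # hypothetical "low_beta" limits; the package reads frequency_ranges_hz
mask = (freqs >= band[0]) & (freqs <= band[1])
spectrum_ch = np.abs(bispectrum)   # "absolute" component, as in COMPONENT_DICT
data_bs = spectrum_ch[mask, mask]  # same indexing as spectrum_ch[range_, range_] above

features = {
    "mean": np.nanmean(data_bs),
    "sum": np.nansum(data_bs),
    "var": np.nanvar(data_bs),
}
print(features)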
py_neuromodulation/features/bursts.py (new file)
@@ -0,0 +1,297 @@
+import numpy as np
+
+if np.__version__ >= "2.0.0":
+    from numpy.lib._function_base_impl import _quantile as np_quantile  # type:ignore
+else:
+    from numpy.lib.function_base import _quantile as np_quantile  # type:ignore
+from collections.abc import Sequence
+from itertools import product
+
+from pydantic import Field, field_validator
+from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
+
+from typing import TYPE_CHECKING, Callable
+from py_neuromodulation.utils.types import create_validation_error
+
+if TYPE_CHECKING:
+    from py_neuromodulation import NMSettings
+
+
+LARGE_NUM = 2**24
+
+
+def get_label_pos(burst_labels, valid_labels):
+    max_label = np.max(burst_labels, axis=2).flatten()
+    min_label = np.min(
+        burst_labels, axis=2, initial=LARGE_NUM, where=burst_labels != 0
+    ).flatten()
+    label_positions = np.zeros_like(valid_labels)
+    N = len(valid_labels)
+    pos = 0
+    i = 0
+    while i < N:
+        if valid_labels[i] >= min_label[pos] and valid_labels[i] <= max_label[pos]:
+            label_positions[i] = pos
+            i += 1
+        else:
+            pos += 1
+    return label_positions
+
+
+class BurstFeatures(BoolSelector):
+    duration: bool = True
+    amplitude: bool = True
+    burst_rate_per_s: bool = True
+    in_burst: bool = True
+
+
+class BurstsSettings(NMBaseModel):
+    threshold: float = Field(default=75, ge=0, le=100)
+    time_duration_s: float = Field(default=30, ge=0)
+    frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
+    burst_features: BurstFeatures = BurstFeatures()
+
+    @field_validator("frequency_bands")
+    def fbands_spaces_to_underscores(cls, frequency_bands):
+        return [f.replace(" ", "_") for f in frequency_bands]
+
+
+class Bursts(NMFeature):
+    def __init__(
+        self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
+    ) -> None:
+        # Test settings
+        settings.validate()
+
+        # Validate that all frequency bands are defined in the settings
+        for fband_burst in settings.burst_settings.frequency_bands:
+            if fband_burst not in list(settings.frequency_ranges_hz.keys()):
+                raise create_validation_error(
+                    f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']",
+                    loc=["burst_settings", "frequency_bands"],
+                )
+
+        from py_neuromodulation.filter import MNEFilter
+
+        self.settings = settings.burst_settings
+        self.sfreq = sfreq
+        self.ch_names = ch_names
+        self.segment_length_features_s = settings.segment_length_features_ms / 1000
+        self.samples_overlap = int(
+            self.sfreq
+            * self.segment_length_features_s
+            / settings.sampling_rate_features_hz
+        )
+
+        self.fband_names = settings.burst_settings.frequency_bands
+
+        f_ranges: list[tuple[float, float]] = [
+            (
+                settings.frequency_ranges_hz[fband_name][0],
+                settings.frequency_ranges_hz[fband_name][1],
+            )
+            for fband_name in self.fband_names
+        ]
+
+        self.bandpass_filter = MNEFilter(
+            f_ranges=f_ranges,
+            sfreq=self.sfreq,
+            filter_length=self.sfreq - 1,
+            verbose=False,
+        )
+        self.filter_data = self.bandpass_filter.filter_data
+
+        self.num_max_samples_ring_buffer = int(
+            self.sfreq * self.settings.time_duration_s
+        )
+
+        self.n_channels = len(self.ch_names)
+        self.n_fbands = len(self.fband_names)
+
+        # Create circular buffer array for previous time_duration_s
+        self.data_buffer = np.empty(
+            (self.n_channels, self.n_fbands, 0), dtype=np.float64
+        )
+
+        self.used_features = self.settings.burst_features.get_enabled()
+
+        self.feature_combinations = list(
+            product(
+                enumerate(self.ch_names),
+                enumerate(self.fband_names),
+                self.settings.burst_features.get_enabled(),
+            )
+        )
+
+        # Variables to store results
+        self.burst_duration_mean: np.ndarray
+        self.burst_duration_max: np.ndarray
+        self.burst_amplitude_max: np.ndarray
+        self.burst_amplitude_mean: np.ndarray
+        self.burst_rate_per_s: np.ndarray
+        self.end_in_burst: np.ndarray
+
+        self.STORE_FEAT_DICT: dict[str, Callable] = {
+            "duration": self.store_duration,
+            "amplitude": self.store_amplitude,
+            "burst_rate_per_s": self.store_burst_rate,
+            "in_burst": self.store_in_burst,
+        }
+
+        self.batch = 0
+
+        # Structure matrix for np.ndimage.label
+        # pixels are connected only to adjacent neighbors along the last axis
+        self.label_structure_matrix = np.zeros((3, 3, 3))
+        self.label_structure_matrix[1, 1, :] = 1
+
+    def calc_feature(self, data: np.ndarray) -> dict:
+        from scipy.signal import hilbert
+        from scipy.ndimage import label, sum_labels as label_sum, mean as label_mean
+
+        filtered_data = np.abs(np.array(hilbert(self.filter_data(data))))
+
+        # Update buffer array
+        batch_size = (
+            filtered_data.shape[-1] if self.batch == 0 else self.samples_overlap
+        )
+
+        self.batch += 1
+        self.data_buffer = np.concatenate(
+            (
+                self.data_buffer,
+                filtered_data[:, :, -batch_size:],
+            ),
+            axis=2,
+        )[:, :, -self.num_max_samples_ring_buffer :]
+
+        # Burst threshold is calculated with the percentile defined in the settings
+        # Call low-level numpy function directly, extra checks not needed
+        burst_thr = np_quantile(self.data_buffer, self.settings.threshold / 100)[
+            :, :, None
+        ]  # Add back the extra dimension
+
+        # Get burst locations as a boolean array, True where data is above threshold (i.e. a burst)
+        bursts = filtered_data >= burst_thr
+
+        # Use np.diff to find the places where bursts start and end
+        # Prepend False at the beginning ensures that data never starts on a burst
+        # Floor division to ignore last burst if series ends in a burst (true burst length unknown)
+        num_bursts = (
+            np.sum(np.diff(bursts, axis=2, prepend=False), axis=2) // 2
+        ).astype(np.float64)  # np.astype added to avoid casting error in np.divide
+
+        # Label each burst with a unique id, limiting connectivity to last axis (see scipy.ndimage.label docs for details)
+        burst_labels = label(bursts, self.label_structure_matrix)[0]  # type: ignore # wrong return type in scipy
+
+        # Remove labels of bursts that are at the end of the dataset, and 0
+        labels_at_end = np.concatenate((np.unique(burst_labels[:, :, -1]), (0,)))
+        valid_labels = np.unique(burst_labels)
+        valid_labels = valid_labels[
+            ~np.isin(valid_labels, labels_at_end, assume_unique=True)
+        ]
+
+        # Find (channel, band) coordinates for each valid label and get an array that maps each valid label to its channel/band
+        # Channel band coordinate is flattened to a 1D array of length (n_channels x n_fbands)
+        label_positions = get_label_pos(burst_labels, valid_labels)
+
+        # Now we're ready to calculate features
+
+        if "duration" in self.used_features or "burst_rate_per_s" in self.used_features:
+            # Handle division by zero using np.divide. Where num_bursts is 0, the result is 0
+            self.burst_duration_mean = (
+                np.divide(
+                    np.sum(bursts, axis=2),
+                    num_bursts,
+                    out=np.zeros_like(num_bursts),
+                    where=num_bursts != 0,
+                )
+                / self.sfreq
+            )
+
+        if "duration" in self.used_features:
+            # First get burst length for each valid burst
+            burst_lengths = (
+                label_sum(bursts, burst_labels, index=valid_labels) / self.sfreq
+            )
+
+            # Now the max needs to be calculated per channel/band
+            # For that, loop over channels/bands, get the corresponding burst lengths, and get the max
+            # Give parameter initial=0 so that when there are no bursts, the max is 0
+            # TODO: it might be interesting to write a C function for this
+            duration_max_flat = np.zeros(self.n_channels * self.n_fbands)
+            for idx in range(self.n_channels * self.n_fbands):
+                duration_max_flat[idx] = np.max(
+                    burst_lengths[label_positions == idx], initial=0
+                )
+
+            self.burst_duration_max = duration_max_flat.reshape(
+                (self.n_channels, self.n_fbands)
+            )
+
+        if "amplitude" in self.used_features:
+            # Max amplitude is just the max of the filtered data where there is a burst
+            self.burst_amplitude_max = (filtered_data * bursts).max(axis=2)
+
+            # The mean is actually a mean of means, so we need the mean for each individual burst
+            label_means = label_mean(filtered_data, burst_labels, index=valid_labels)
+            # Now, loop over channels/bands, get the corresponding burst means, and calculate the mean of means
+            # TODO: it might be interesting to write a C function for this
+            amplitude_mean_flat = np.zeros(self.n_channels * self.n_fbands)
+            for idx in range(self.n_channels * self.n_fbands):
+                mask = label_positions == idx
+                amplitude_mean_flat[idx] = (
+                    np.mean(label_means[mask]) if np.any(mask) else 0
+                )
+
+            self.burst_amplitude_mean = amplitude_mean_flat.reshape(
+                (self.n_channels, self.n_fbands)
+            )
+
+        if "burst_rate_per_s" in self.used_features:
+            self.burst_rate_per_s = (
+                self.burst_duration_mean / self.segment_length_features_s
+            )
+
+        if "in_burst" in self.used_features:
+            self.end_in_burst = bursts[:, :, -1]  # End in burst
+
+        # Create dictionary of features which is the correct return format
+        feature_results = {}
+        for (ch_i, ch), (fb_i, fb), feat in self.feature_combinations:
+            self.STORE_FEAT_DICT[feat](feature_results, ch_i, ch, fb_i, fb)
+
+        return feature_results
+
+    def store_duration(
+        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
+    ):
+        feature_results[f"{ch}_bursts_{fb}_duration_mean"] = self.burst_duration_mean[
+            ch_i, fb_i
+        ]
+
+        feature_results[f"{ch}_bursts_{fb}_duration_max"] = self.burst_duration_max[
+            ch_i, fb_i
+        ]
+
+    def store_amplitude(
+        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
+    ):
+        feature_results[f"{ch}_bursts_{fb}_amplitude_mean"] = self.burst_amplitude_mean[
+            ch_i, fb_i
+        ]
+        feature_results[f"{ch}_bursts_{fb}_amplitude_max"] = self.burst_amplitude_max[
+            ch_i, fb_i
+        ]
+
+    def store_burst_rate(
+        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
+    ):
+        feature_results[f"{ch}_bursts_{fb}_burst_rate_per_s"] = self.burst_rate_per_s[
+            ch_i, fb_i
+        ]
+
+    def store_in_burst(
+        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
+    ):
+        feature_results[f"{ch}_bursts_{fb}_in_burst"] = self.end_in_burst[ch_i, fb_i]