py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/FieldTrip.py +589 -589
  5. py_neuromodulation/__init__.py +74 -13
  6. py_neuromodulation/_write_example_dataset_helper.py +83 -65
  7. py_neuromodulation/data/README +6 -6
  8. py_neuromodulation/data/dataset_description.json +8 -8
  9. py_neuromodulation/data/participants.json +32 -32
  10. py_neuromodulation/data/participants.tsv +2 -2
  11. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  12. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  13. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  14. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  18. py_neuromodulation/grid_cortex.tsv +40 -40
  19. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  20. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  21. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  22. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  23. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  24. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  25. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  26. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  27. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  28. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  29. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  30. py_neuromodulation/nm_IO.py +413 -417
  31. py_neuromodulation/nm_RMAP.py +496 -531
  32. py_neuromodulation/nm_analysis.py +993 -1074
  33. py_neuromodulation/nm_artifacts.py +30 -25
  34. py_neuromodulation/nm_bispectra.py +154 -168
  35. py_neuromodulation/nm_bursts.py +292 -198
  36. py_neuromodulation/nm_coherence.py +251 -205
  37. py_neuromodulation/nm_database.py +149 -0
  38. py_neuromodulation/nm_decode.py +918 -992
  39. py_neuromodulation/nm_define_nmchannels.py +300 -302
  40. py_neuromodulation/nm_features.py +144 -116
  41. py_neuromodulation/nm_filter.py +219 -219
  42. py_neuromodulation/nm_filter_preprocessing.py +79 -91
  43. py_neuromodulation/nm_fooof.py +139 -159
  44. py_neuromodulation/nm_generator.py +45 -37
  45. py_neuromodulation/nm_hjorth_raw.py +52 -73
  46. py_neuromodulation/nm_kalmanfilter.py +71 -58
  47. py_neuromodulation/nm_linelength.py +21 -33
  48. py_neuromodulation/nm_logger.py +66 -0
  49. py_neuromodulation/nm_mne_connectivity.py +149 -112
  50. py_neuromodulation/nm_mnelsl_generator.py +90 -0
  51. py_neuromodulation/nm_mnelsl_stream.py +116 -0
  52. py_neuromodulation/nm_nolds.py +96 -93
  53. py_neuromodulation/nm_normalization.py +173 -214
  54. py_neuromodulation/nm_oscillatory.py +423 -448
  55. py_neuromodulation/nm_plots.py +585 -612
  56. py_neuromodulation/nm_preprocessing.py +83 -0
  57. py_neuromodulation/nm_projection.py +370 -394
  58. py_neuromodulation/nm_rereference.py +97 -95
  59. py_neuromodulation/nm_resample.py +59 -50
  60. py_neuromodulation/nm_run_analysis.py +325 -435
  61. py_neuromodulation/nm_settings.py +289 -68
  62. py_neuromodulation/nm_settings.yaml +244 -0
  63. py_neuromodulation/nm_sharpwaves.py +423 -401
  64. py_neuromodulation/nm_stats.py +464 -480
  65. py_neuromodulation/nm_stream.py +398 -0
  66. py_neuromodulation/nm_stream_abc.py +166 -218
  67. py_neuromodulation/nm_types.py +193 -0
  68. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/METADATA +29 -26
  69. py_neuromodulation-0.0.5.dist-info/RECORD +83 -0
  70. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/WHEEL +1 -1
  71. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.5.dist-info}/licenses/LICENSE +21 -21
  72. py_neuromodulation/nm_EpochStream.py +0 -92
  73. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  74. py_neuromodulation/nm_cohortwrapper.py +0 -435
  75. py_neuromodulation/nm_eval_timing.py +0 -239
  76. py_neuromodulation/nm_features_abc.py +0 -39
  77. py_neuromodulation/nm_settings.json +0 -338
  78. py_neuromodulation/nm_stream_offline.py +0 -359
  79. py_neuromodulation/utils/_logging.py +0 -24
  80. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -1,198 +1,292 @@
1
- import enum
2
- import numpy as np
3
- from typing import Iterable
4
- from scipy import signal
5
-
6
- from py_neuromodulation import nm_features_abc, nm_filter
7
-
8
-
9
- class Burst(nm_features_abc.Feature):
10
- def __init__(
11
- self, settings: dict, ch_names: Iterable[str], sfreq: float
12
- ) -> None:
13
- self.s = settings
14
- self.sfreq = sfreq
15
- self.ch_names = ch_names
16
- self.threshold = self.s["burst_settings"]["threshold"]
17
- self.time_duration_s = self.s["burst_settings"]["time_duration_s"]
18
- self.samples_overlap = int(
19
- self.sfreq
20
- * (self.s["segment_length_features_ms"] / 1000)
21
- / self.s["sampling_rate_features_hz"]
22
- )
23
-
24
- self.fband_names = self.s["burst_settings"]["frequency_bands"]
25
- self.f_ranges = [
26
- self.s["frequency_ranges_hz"][fband_name]
27
- for fband_name in self.fband_names
28
- ]
29
- self.seglengths = np.floor(
30
- self.sfreq
31
- / 1000
32
- * np.array(
33
- [
34
- self.s["bandpass_filter_settings"]["segment_lengths_ms"][
35
- fband
36
- ]
37
- for fband in self.fband_names
38
- ]
39
- )
40
- ).astype(int)
41
-
42
- self.num_max_samples_ring_buffer = int(
43
- self.sfreq * self.time_duration_s
44
- )
45
-
46
- self.bandpass_filter = nm_filter.MNEFilter(
47
- f_ranges=self.f_ranges,
48
- sfreq=self.sfreq,
49
- filter_length=self.sfreq - 1,
50
- verbose=False,
51
- )
52
-
53
- # create dict with fband, channel specific data store
54
- # for previous time_duration_s
55
- def init_ch_fband_dict() -> dict:
56
- d = {}
57
- for ch in self.ch_names:
58
- if ch not in d:
59
- d[ch] = {}
60
- for fb in self.fband_names:
61
- if fb not in d[ch]:
62
- d[ch][fb] = None
63
- return d
64
-
65
- self.data_buffer = init_ch_fband_dict()
66
-
67
- def test_settings(
68
- settings: dict,
69
- ch_names: Iterable[str],
70
- sfreq: int | float,
71
- ):
72
- assert isinstance(
73
- settings["burst_settings"]["threshold"], (float, int)
74
- ), f"burst settings threshold needs to be type int or float, got: {settings['burst_settings']['threshold']}"
75
- assert (
76
- 0 < settings["burst_settings"]["threshold"] < 100
77
- ), f"burst setting threshold needs to be between 0 and 100, got: {settings['burst_settings']['threshold']}"
78
- assert isinstance(
79
- settings["burst_settings"]["time_duration_s"], (float, int)
80
- ), f"burst settings time_duration_s needs to be type int or float, got: {settings['burst_settings']['time_duration_s']}"
81
- assert (
82
- settings["burst_settings"]["time_duration_s"] > 0
83
- ), f"burst setting time_duration_s needs to be greater than 0, got: {settings['burst_settings']['time_duration_s']}"
84
-
85
- for fband_burst in settings["burst_settings"]["frequency_bands"]:
86
- assert fband_burst in list(
87
- settings["frequency_ranges_hz"].keys()
88
- ), f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']"
89
-
90
- for burst_feature in settings["burst_settings"][
91
- "burst_features"
92
- ].keys():
93
- assert isinstance(
94
- settings["burst_settings"]["burst_features"][burst_feature],
95
- bool,
96
- ), (
97
- f"bursting feature {burst_feature} needs to be type bool, "
98
- f"got: {settings['burst_settings']['burst_features'][burst_feature]}"
99
- )
100
-
101
- def calc_feature(self, data: np.array, features_compute: dict) -> dict:
102
- # filter_data returns (n_channels, n_fbands, n_samples)
103
- filtered_data = np.abs(
104
- signal.hilbert(self.bandpass_filter.filter_data(data), axis=2)
105
- )
106
- for ch_idx, ch_name in enumerate(self.ch_names):
107
- for fband_idx, fband_name in enumerate(self.fband_names):
108
- new_dat = filtered_data[ch_idx, fband_idx, :]
109
- if self.data_buffer[ch_name][fband_name] is None:
110
- self.data_buffer[ch_name][fband_name] = new_dat
111
- else:
112
- self.data_buffer[ch_name][fband_name] = np.concatenate(
113
- (
114
- self.data_buffer[ch_name][fband_name],
115
- new_dat[-self.samples_overlap :],
116
- ),
117
- axis=0,
118
- )[-self.num_max_samples_ring_buffer :]
119
-
120
- # calc features
121
- burst_thr = np.percentile(
122
- self.data_buffer[ch_name][fband_name], q=self.threshold
123
- )
124
-
125
- burst_amplitude, burst_length = self.get_burst_amplitude_length(
126
- new_dat, burst_thr, self.sfreq
127
- )
128
-
129
- features_compute[
130
- f"{ch_name}_bursts_{fband_name}_duration_mean"
131
- ] = (np.mean(burst_length) if len(burst_length) != 0 else 0)
132
- features_compute[
133
- f"{ch_name}_bursts_{fband_name}_amplitude_mean"
134
- ] = (
135
- np.mean([np.mean(a) for a in burst_amplitude])
136
- if len(burst_length) != 0
137
- else 0
138
- )
139
-
140
- features_compute[
141
- f"{ch_name}_bursts_{fband_name}_duration_max"
142
- ] = (np.max(burst_length) if len(burst_length) != 0 else 0)
143
- features_compute[
144
- f"{ch_name}_bursts_{fband_name}_amplitude_max"
145
- ] = (
146
- np.max([np.max(a) for a in burst_amplitude])
147
- if len(burst_amplitude) != 0
148
- else 0
149
- )
150
-
151
- features_compute[
152
- f"{ch_name}_bursts_{fband_name}_burst_rate_per_s"
153
- ] = (
154
- np.mean(burst_length)
155
- / (self.s["segment_length_features_ms"] / 1000)
156
- if len(burst_length) != 0
157
- else 0
158
- )
159
-
160
- in_burst = False
161
- if self.data_buffer[ch_name][fband_name][-1] > burst_thr:
162
- in_burst = True
163
-
164
- features_compute[f"{ch_name}_bursts_{fband_name}_in_burst"] = (
165
- in_burst
166
- )
167
- return features_compute
168
-
169
- @staticmethod
170
- def get_burst_amplitude_length(
171
- beta_averp_norm, burst_thr: float, sfreq: float
172
- ):
173
- """
174
- Analysing the duration of beta burst
175
- """
176
- bursts = np.zeros((beta_averp_norm.shape[0] + 1), dtype=bool)
177
- bursts[1:] = beta_averp_norm >= burst_thr
178
- deriv = np.diff(bursts)
179
- burst_length = []
180
- burst_amplitude = []
181
-
182
- burst_time_points = np.where(deriv == True)[0]
183
-
184
- for i in range(burst_time_points.size // 2):
185
- burst_length.append(
186
- burst_time_points[2 * i + 1] - burst_time_points[2 * i]
187
- )
188
- burst_amplitude.append(
189
- beta_averp_norm[
190
- burst_time_points[2 * i] : burst_time_points[2 * i + 1]
191
- ]
192
- )
193
-
194
- # the last burst length (in case isburst == True) is omitted,
195
- # since the true burst length cannot be estimated
196
- burst_length = np.array(burst_length) / sfreq
197
-
198
- return burst_amplitude, burst_length
1
+ import numpy as np
2
+
3
+ if np.__version__ >= "2.0.0":
4
+ from numpy.lib._function_base_impl import _quantile as np_quantile # type:ignore
5
+ else:
6
+ from numpy.lib.function_base import _quantile as np_quantile # type:ignore
7
+ from collections.abc import Sequence
8
+ from itertools import product
9
+
10
+ from pydantic import Field, field_validator
11
+ from py_neuromodulation.nm_types import BoolSelector, NMBaseModel
12
+ from py_neuromodulation.nm_features import NMFeature
13
+
14
+ from typing import TYPE_CHECKING, Callable
15
+
16
+ if TYPE_CHECKING:
17
+ from py_neuromodulation.nm_settings import NMSettings
18
+
19
+
20
# Sentinel "initial" value for the masked per-row minimum in get_label_pos:
# rows that contain no burst label end up with this as their minimum, so no
# real label id can fall inside their [min, max] range.
LARGE_NUM = 1 << 24
21
+
22
+
23
def get_label_pos(burst_labels, valid_labels):
    """Map each valid burst label to the flattened row (first two axes) that holds it.

    Parameters
    ----------
    burst_labels : np.ndarray
        3-D integer label array with time on the last axis; 0 means "no burst".
    valid_labels : np.ndarray
        Non-zero label ids to locate, in ascending order (np.unique output in
        the caller).

    Returns
    -------
    np.ndarray
        Array shaped and typed like ``valid_labels``; entry i is the flattened
        index over the first two axes of ``burst_labels`` whose label range
        contains ``valid_labels[i]``.

    The single forward pass assumes every row owns a contiguous, ascending range
    of label ids and that rows are ordered by label id (the raster-order
    labelling of scipy.ndimage.label with last-axis-only connectivity) — confirm
    if the labelling scheme changes.
    """
    row_max = np.max(burst_labels, axis=2).flatten()
    # Rows without any burst get LARGE_NUM as minimum (and 0 as maximum), so
    # their range is empty and the walk skips them.
    row_min = np.min(
        burst_labels, axis=2, initial=LARGE_NUM, where=burst_labels != 0
    ).flatten()

    positions = np.zeros_like(valid_labels)
    row = 0
    for idx, lab in enumerate(valid_labels):
        # Advance to the first row whose label range contains this label.
        while not (row_min[row] <= lab <= row_max[row]):
            row += 1
        positions[idx] = row
    return positions
39
+
40
+
41
class BurstFeatures(BoolSelector):
    """On/off switches for the burst feature groups computed by ``Burst``.

    Field declaration order is intentionally unchanged: presumably
    ``BoolSelector.get_enabled`` reports enabled names in declaration order,
    which would fix the order of the output feature keys — TODO confirm.
    """

    # -> "<ch>_bursts_<fb>_duration_mean" and "<ch>_bursts_<fb>_duration_max"
    duration: bool = True
    # -> "<ch>_bursts_<fb>_amplitude_mean" and "<ch>_bursts_<fb>_amplitude_max"
    amplitude: bool = True
    # -> "<ch>_bursts_<fb>_burst_rate_per_s"
    burst_rate_per_s: bool = True
    # -> "<ch>_bursts_<fb>_in_burst"
    in_burst: bool = True
46
+
47
+
48
class BurstSettings(NMBaseModel):
    """Configuration for burst feature extraction.

    Attributes:
        threshold: Percentile (0-100) of the buffered band envelope that a
            sample must reach to count as part of a burst.
        time_duration_s: Length in seconds of the history buffer the threshold
            percentile is computed over. Must be strictly positive.
        frequency_bands: Names of the bands to analyse; spaces are converted
            to underscores.
        burst_features: Which burst feature groups to compute.
    """

    threshold: float = Field(default=75, ge=0, le=100)
    # Strictly positive: a duration of 0 gives a ring-buffer length of 0 samples,
    # and slicing with [..., -0:] then keeps the entire history, so the buffer
    # would grow without bound.
    time_duration_s: float = Field(default=30, gt=0)
    frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
    burst_features: BurstFeatures = BurstFeatures()

    @field_validator("frequency_bands")
    @classmethod
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Normalise band names so "low beta" matches the key "low_beta"."""
        return [f.replace(" ", "_") for f in frequency_bands]
57
+
58
+
59
class Burst(NMFeature):
    """Per-channel, per-band burst features from the band-limited signal envelope.

    On every call the incoming window is band-pass filtered for each configured
    band and its Hilbert envelope is taken. The newest envelope samples are
    appended to a ring buffer spanning ``time_duration_s``, and the
    ``threshold`` percentile of that buffer is the burst threshold. A burst is
    a contiguous run of envelope samples at or above the threshold. A burst
    that is still running at the last sample of the window is excluded from
    the duration and amplitude statistics, because its true length is unknown.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
    ) -> None:
        # Every burst band needs a frequency range to build its band-pass filter
        for fband_burst in settings.burst_settings.frequency_bands:
            assert (
                fband_burst in list(settings.frequency_ranges_hz.keys())
            ), f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']"

        from py_neuromodulation.nm_filter import MNEFilter

        self.settings = settings.burst_settings
        self.sfreq = sfreq
        self.ch_names = ch_names
        self.segment_length_features_s = settings.segment_length_features_ms / 1000
        # Number of newest samples appended to the ring buffer on each call after the first.
        # NOTE(review): sfreq * segment_length_s / sampling_rate_features_hz equals the
        # per-step sample advance (sfreq / sampling_rate_features_hz) only for a 1 s
        # segment — confirm the intended formula.
        self.samples_overlap = int(
            self.sfreq
            * self.segment_length_features_s
            / settings.sampling_rate_features_hz
        )

        self.fband_names = settings.burst_settings.frequency_bands

        f_ranges: list[tuple[float, float]] = [
            (
                settings.frequency_ranges_hz[fband_name][0],
                settings.frequency_ranges_hz[fband_name][1],
            )
            for fband_name in self.fband_names
        ]

        self.bandpass_filter = MNEFilter(
            f_ranges=f_ranges,
            sfreq=self.sfreq,
            filter_length=self.sfreq - 1,
            verbose=False,
        )
        self.filter_data = self.bandpass_filter.filter_data

        self.num_max_samples_ring_buffer = int(
            self.sfreq * self.settings.time_duration_s
        )
        # A length of 0 would make the [..., -0:] truncation in calc_feature keep
        # the whole history, so the buffer would grow without bound.
        if self.num_max_samples_ring_buffer < 1:
            raise ValueError(
                "burst_settings.time_duration_s must span at least one sample, "
                f"got {self.settings.time_duration_s} s at sfreq={self.sfreq} Hz"
            )

        self.n_channels = len(self.ch_names)
        self.n_fbands = len(self.fband_names)

        # Ring buffer of envelope samples for the previous time_duration_s,
        # shaped (n_channels, n_fbands, n_samples)
        self.data_buffer = np.empty(
            (self.n_channels, self.n_fbands, 0), dtype=np.float64
        )

        self.used_features = self.settings.burst_features.get_enabled()

        self.feature_combinations = list(
            product(
                enumerate(self.ch_names),
                enumerate(self.fband_names),
                self.settings.burst_features.get_enabled(),
            )
        )

        # Per-call results, each shaped (n_channels, n_fbands)
        self.burst_duration_mean: np.ndarray
        self.burst_duration_max: np.ndarray
        self.burst_amplitude_max: np.ndarray
        self.burst_amplitude_mean: np.ndarray
        self.burst_rate_per_s: np.ndarray
        self.end_in_burst: np.ndarray

        self.STORE_FEAT_DICT: dict[str, Callable] = {
            "duration": self.store_duration,
            "amplitude": self.store_amplitude,
            "burst_rate_per_s": self.store_burst_rate,
            "in_burst": self.store_in_burst,
        }

        self.batch = 0

        # Structure matrix for scipy.ndimage.label:
        # samples are connected only to adjacent neighbours along the last (time) axis
        self.label_structure_matrix = np.zeros((3, 3, 3))
        self.label_structure_matrix[1, 1, :] = 1

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute the enabled burst features for one window of ``data``.

        Returns a dict keyed "<ch>_bursts_<fband>_<feature>".
        """
        from scipy.signal import hilbert
        from scipy.ndimage import label, sum_labels as label_sum, mean as label_mean

        # Envelope of the band-passed data, shaped (n_channels, n_fbands, n_samples)
        filtered_data = np.abs(np.array(hilbert(self.filter_data(data))))

        # Update ring buffer: take the whole first window, then only the newest samples
        batch_size = (
            filtered_data.shape[-1] if self.batch == 0 else self.samples_overlap
        )

        self.batch += 1
        self.data_buffer = np.concatenate(
            (
                self.data_buffer,
                filtered_data[:, :, -batch_size:],
            ),
            axis=2,
        )[:, :, -self.num_max_samples_ring_buffer :]

        # Burst threshold = configured percentile of the buffer, per channel/band.
        # np.quantile works on a copy. The private numpy `_quantile` partitions its
        # input in place, which reorders the buffer in time, so the truncation above
        # would drop the smallest values instead of the oldest samples.
        burst_thr = np.quantile(
            self.data_buffer, self.settings.threshold / 100, axis=2, keepdims=True
        )

        # True where the envelope is at or above threshold (i.e. inside a burst)
        bursts = filtered_data >= burst_thr

        # Count burst edges with np.diff; prepending False means a burst at sample 0
        # still has a start edge. Floor division drops a burst still running at the
        # last sample (true length unknown), so this counts complete bursts only.
        num_bursts = (
            np.sum(np.diff(bursts, axis=2, prepend=False), axis=2) // 2
        ).astype(np.float64)  # float so np.divide below can write into zeros_like

        # Label each burst with a unique id, connectivity along the last axis only
        burst_labels = label(bursts, self.label_structure_matrix)[0]  # type: ignore # wrong return type in scipy

        # Labels touching the last sample (incomplete bursts) plus background 0.
        # union1d keeps this unique, as required by assume_unique=True below.
        labels_at_end = np.union1d(burst_labels[:, :, -1], (0,))
        valid_labels = np.unique(burst_labels)
        valid_labels = valid_labels[
            ~np.isin(valid_labels, labels_at_end, assume_unique=True)
        ]

        # Burst samples that belong to complete bursts only
        complete_bursts = bursts & ~np.isin(burst_labels, labels_at_end)

        # Flattened (channel, band) index for each valid label
        label_positions = get_label_pos(burst_labels, valid_labels)
        n_positions = self.n_channels * self.n_fbands
        out_shape = (self.n_channels, self.n_fbands)

        if "duration" in self.used_features or "burst_rate_per_s" in self.used_features:
            # Mean complete-burst length in seconds; 0 where there are no complete bursts
            self.burst_duration_mean = (
                np.divide(
                    np.sum(complete_bursts, axis=2),
                    num_bursts,
                    out=np.zeros_like(num_bursts),
                    where=num_bursts != 0,
                )
                / self.sfreq
            )

        if "duration" in self.used_features:
            # Length in seconds of every complete burst
            burst_lengths = (
                label_sum(bursts, burst_labels, index=valid_labels) / self.sfreq
            )
            # Per channel/band maximum; stays 0 where there are no bursts
            duration_max_flat = np.zeros(n_positions)
            np.maximum.at(duration_max_flat, label_positions, burst_lengths)
            self.burst_duration_max = duration_max_flat.reshape(out_shape)

        if "amplitude" in self.used_features:
            # Max envelope value inside complete bursts (0 where there are none)
            self.burst_amplitude_max = (filtered_data * complete_bursts).max(axis=2)

            # Mean of per-burst mean amplitudes, grouped per channel/band
            label_means = label_mean(filtered_data, burst_labels, index=valid_labels)
            counts = np.bincount(label_positions, minlength=n_positions)
            sums = np.bincount(
                label_positions, weights=label_means, minlength=n_positions
            )
            amplitude_mean_flat = np.divide(
                sums, counts, out=np.zeros(n_positions), where=counts != 0
            )
            self.burst_amplitude_mean = amplitude_mean_flat.reshape(out_shape)

        if "burst_rate_per_s" in self.used_features:
            # NOTE(review): this is mean burst duration / segment length (dimensionless),
            # carried over from 0.0.4; a per-second rate would be
            # num_bursts / segment length — confirm the intended definition.
            self.burst_rate_per_s = (
                self.burst_duration_mean / self.segment_length_features_s
            )

        if "in_burst" in self.used_features:
            self.end_in_burst = bursts[:, :, -1]  # window ends inside a burst

        # Flatten the per-(channel, band) arrays into the feature dictionary
        feature_results = {}
        for (ch_i, ch), (fb_i, fb), feat in self.feature_combinations:
            self.STORE_FEAT_DICT[feat](feature_results, ch_i, ch, fb_i, fb)

        return feature_results

    def store_duration(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Write mean and max burst duration for one channel/band."""
        feature_results[f"{ch}_bursts_{fb}_duration_mean"] = self.burst_duration_mean[
            ch_i, fb_i
        ]

        feature_results[f"{ch}_bursts_{fb}_duration_max"] = self.burst_duration_max[
            ch_i, fb_i
        ]

    def store_amplitude(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Write mean and max burst amplitude for one channel/band."""
        feature_results[f"{ch}_bursts_{fb}_amplitude_mean"] = self.burst_amplitude_mean[
            ch_i, fb_i
        ]
        feature_results[f"{ch}_bursts_{fb}_amplitude_max"] = self.burst_amplitude_max[
            ch_i, fb_i
        ]

    def store_burst_rate(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Write the burst-rate feature for one channel/band."""
        feature_results[f"{ch}_bursts_{fb}_burst_rate_per_s"] = self.burst_rate_per_s[
            ch_i, fb_i
        ]

    def store_in_burst(
        self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
    ):
        """Write whether the window ends inside a burst for one channel/band."""
        feature_results[f"{ch}_bursts_{fb}_in_burst"] = self.end_in_burst[ch_i, fb_i]