py-neuromodulation 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109):
  1. py_neuromodulation/ConnectivityDecoding/_get_grid_hull.m +34 -34
  2. py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +95 -106
  3. py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +107 -119
  4. py_neuromodulation/__init__.py +80 -13
  5. py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +496 -531
  6. py_neuromodulation/analysis/__init__.py +4 -0
  7. py_neuromodulation/{nm_decode.py → analysis/decode.py} +918 -992
  8. py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +994 -1074
  9. py_neuromodulation/{nm_plots.py → analysis/plots.py} +627 -612
  10. py_neuromodulation/{nm_stats.py → analysis/stats.py} +458 -480
  11. py_neuromodulation/data/README +6 -6
  12. py_neuromodulation/data/dataset_description.json +8 -8
  13. py_neuromodulation/data/participants.json +32 -32
  14. py_neuromodulation/data/participants.tsv +2 -2
  15. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_coordsystem.json +5 -5
  16. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_space-mni_electrodes.tsv +11 -11
  17. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_channels.tsv +11 -11
  18. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.json +18 -18
  19. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr +35 -35
  20. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/ieeg/sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vmrk +13 -13
  21. py_neuromodulation/data/sub-testsub/ses-EphysMedOff/sub-testsub_ses-EphysMedOff_scans.tsv +2 -2
  22. py_neuromodulation/default_settings.yaml +241 -0
  23. py_neuromodulation/features/__init__.py +31 -0
  24. py_neuromodulation/features/bandpower.py +165 -0
  25. py_neuromodulation/features/bispectra.py +157 -0
  26. py_neuromodulation/features/bursts.py +297 -0
  27. py_neuromodulation/features/coherence.py +255 -0
  28. py_neuromodulation/features/feature_processor.py +121 -0
  29. py_neuromodulation/features/fooof.py +142 -0
  30. py_neuromodulation/features/hjorth_raw.py +57 -0
  31. py_neuromodulation/features/linelength.py +21 -0
  32. py_neuromodulation/features/mne_connectivity.py +148 -0
  33. py_neuromodulation/features/nolds.py +94 -0
  34. py_neuromodulation/features/oscillatory.py +249 -0
  35. py_neuromodulation/features/sharpwaves.py +432 -0
  36. py_neuromodulation/filter/__init__.py +3 -0
  37. py_neuromodulation/filter/kalman_filter.py +67 -0
  38. py_neuromodulation/filter/kalman_filter_external.py +1890 -0
  39. py_neuromodulation/filter/mne_filter.py +128 -0
  40. py_neuromodulation/filter/notch_filter.py +93 -0
  41. py_neuromodulation/grid_cortex.tsv +40 -40
  42. py_neuromodulation/liblsl/libpugixml.so.1.12 +0 -0
  43. py_neuromodulation/liblsl/linux/bionic_amd64/liblsl.1.16.2.so +0 -0
  44. py_neuromodulation/liblsl/linux/bookworm_amd64/liblsl.1.16.2.so +0 -0
  45. py_neuromodulation/liblsl/linux/focal_amd46/liblsl.1.16.2.so +0 -0
  46. py_neuromodulation/liblsl/linux/jammy_amd64/liblsl.1.16.2.so +0 -0
  47. py_neuromodulation/liblsl/linux/jammy_x86/liblsl.1.16.2.so +0 -0
  48. py_neuromodulation/liblsl/linux/noble_amd64/liblsl.1.16.2.so +0 -0
  49. py_neuromodulation/liblsl/macos/amd64/liblsl.1.16.2.dylib +0 -0
  50. py_neuromodulation/liblsl/macos/arm64/liblsl.1.16.0.dylib +0 -0
  51. py_neuromodulation/liblsl/windows/amd64/liblsl.1.16.2.dll +0 -0
  52. py_neuromodulation/liblsl/windows/x86/liblsl.1.16.2.dll +0 -0
  53. py_neuromodulation/processing/__init__.py +10 -0
  54. py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +29 -25
  55. py_neuromodulation/processing/data_preprocessor.py +77 -0
  56. py_neuromodulation/processing/filter_preprocessing.py +78 -0
  57. py_neuromodulation/processing/normalization.py +175 -0
  58. py_neuromodulation/{nm_projection.py → processing/projection.py} +370 -394
  59. py_neuromodulation/{nm_rereference.py → processing/rereference.py} +97 -95
  60. py_neuromodulation/{nm_resample.py → processing/resample.py} +56 -50
  61. py_neuromodulation/stream/__init__.py +3 -0
  62. py_neuromodulation/stream/data_processor.py +325 -0
  63. py_neuromodulation/stream/generator.py +53 -0
  64. py_neuromodulation/stream/mnelsl_player.py +94 -0
  65. py_neuromodulation/stream/mnelsl_stream.py +120 -0
  66. py_neuromodulation/stream/settings.py +292 -0
  67. py_neuromodulation/stream/stream.py +427 -0
  68. py_neuromodulation/utils/__init__.py +2 -0
  69. py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +305 -302
  70. py_neuromodulation/utils/database.py +149 -0
  71. py_neuromodulation/utils/io.py +378 -0
  72. py_neuromodulation/utils/keyboard.py +52 -0
  73. py_neuromodulation/utils/logging.py +66 -0
  74. py_neuromodulation/utils/types.py +251 -0
  75. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +28 -33
  76. py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
  77. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +1 -1
  78. {py_neuromodulation-0.0.4.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +21 -21
  79. py_neuromodulation/FieldTrip.py +0 -589
  80. py_neuromodulation/_write_example_dataset_helper.py +0 -65
  81. py_neuromodulation/nm_EpochStream.py +0 -92
  82. py_neuromodulation/nm_IO.py +0 -417
  83. py_neuromodulation/nm_across_patient_decoding.py +0 -927
  84. py_neuromodulation/nm_bispectra.py +0 -168
  85. py_neuromodulation/nm_bursts.py +0 -198
  86. py_neuromodulation/nm_coherence.py +0 -205
  87. py_neuromodulation/nm_cohortwrapper.py +0 -435
  88. py_neuromodulation/nm_eval_timing.py +0 -239
  89. py_neuromodulation/nm_features.py +0 -116
  90. py_neuromodulation/nm_features_abc.py +0 -39
  91. py_neuromodulation/nm_filter.py +0 -219
  92. py_neuromodulation/nm_filter_preprocessing.py +0 -91
  93. py_neuromodulation/nm_fooof.py +0 -159
  94. py_neuromodulation/nm_generator.py +0 -37
  95. py_neuromodulation/nm_hjorth_raw.py +0 -73
  96. py_neuromodulation/nm_kalmanfilter.py +0 -58
  97. py_neuromodulation/nm_linelength.py +0 -33
  98. py_neuromodulation/nm_mne_connectivity.py +0 -112
  99. py_neuromodulation/nm_nolds.py +0 -93
  100. py_neuromodulation/nm_normalization.py +0 -214
  101. py_neuromodulation/nm_oscillatory.py +0 -448
  102. py_neuromodulation/nm_run_analysis.py +0 -435
  103. py_neuromodulation/nm_settings.json +0 -338
  104. py_neuromodulation/nm_settings.py +0 -68
  105. py_neuromodulation/nm_sharpwaves.py +0 -401
  106. py_neuromodulation/nm_stream_abc.py +0 -218
  107. py_neuromodulation/nm_stream_offline.py +0 -359
  108. py_neuromodulation/utils/_logging.py +0 -24
  109. py_neuromodulation-0.0.4.dist-info/RECORD +0 -72
@@ -0,0 +1,148 @@
1
+ from collections.abc import Iterable
2
+ import numpy as np
3
+ from typing import TYPE_CHECKING
4
+
5
+ from py_neuromodulation.utils.types import NMFeature, NMBaseModel
6
+
7
+ if TYPE_CHECKING:
8
+ from py_neuromodulation import NMSettings
9
+ from mne.io import RawArray
10
+ from mne import Epochs
11
+
12
+
13
class MNEConnectivitySettings(NMBaseModel):
    """Options handed to ``mne_connectivity.spectral_connectivity_epochs``.

    ``method`` is forwarded as the connectivity measure and ``mode`` as the
    spectral estimation mode.
    """

    # Connectivity measure ("plv" = phase-locking value in mne_connectivity)
    method: str = "plv"
    # Spectral estimation mode used by spectral_connectivity_epochs
    mode: str = "multitaper"
16
+
17
+
18
class MNEConnectivity(NMFeature):
    """Spectral connectivity features computed with ``mne_connectivity``.

    Each batch is cut into consecutive 1 s epochs and passed to
    ``spectral_connectivity_epochs`` for the fixed channel pairs
    (0, 2), (0, 3), (1, 2) and (1, 3). The connectivity spectrum of every pair
    is then averaged inside each band of ``settings.frequency_ranges_hz``.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Iterable[str],
        sfreq: float,
    ) -> None:
        from mne import create_info

        self.settings = settings

        self.ch_names = ch_names
        self.sfreq = sfreq

        # Params used by spectral_connectivity_epochs
        self.mode = settings.mne_connectivity_settings.mode
        self.method = settings.mne_connectivity_settings.method

        self.fbands = settings.frequency_ranges_hz
        # Frequency-bin indices per band; filled on the first processed batch
        self.fband_ranges: list = []
        self.result_keys = []

        # create_info needs a concrete list/tuple of names, not any iterable
        self.raw_info = create_info(ch_names=list(self.ch_names), sfreq=self.sfreq)
        self.raw_array: "RawArray"
        self.epochs: "Epochs"
        self.prev_batch_shape: tuple = (-1, -1)  # sentinel value

    def calc_feature(self, data: np.ndarray) -> dict:
        """Compute band-averaged connectivity for one batch.

        Parameters
        ----------
        data : np.ndarray
            Batch indexed as ``data[channel, sample]``.

        Returns
        -------
        dict
            ``"ch1_<method>_<pair index>_<band>"`` -> mean connectivity of that
            channel pair inside that frequency band.

        Raises
        ------
        ValueError
            If the batch is shorter than one epoch.
        RuntimeError
            If fewer than two epochs fit into the batch.
        """
        from mne.io import RawArray
        from mne import Epochs
        from mne_connectivity import spectral_connectivity_epochs
        import pandas as pd

        time_samples_s = data.shape[1] / self.sfreq
        epoch_length: float = 1  # TODO: Make this a parameter?

        if epoch_length > time_samples_s:
            raise ValueError(
                f"the intended epoch length for mne connectivity: {epoch_length}s"
                f" are longer than the passed data array {np.round(time_samples_s, 2)}s"
            )

        # Only reinitialize the raw_array and epochs object if the data shape has changed
        # That could mean that the channels have been re-selected, or we're in the last batch
        # TODO: If sfreq or channels change, do we re-initialize the whole Stream object?
        if data.shape != self.prev_batch_shape:
            self.raw_array = RawArray(
                data=data,
                info=self.raw_info,
                copy=None,  # type: ignore
                verbose=False,
            )

            # Equivalent of mne.make_fixed_length_events(raw, duration=epoch_length).
            # Onsets are generated as floats and truncated afterwards:
            # np.arange(..., dtype=int) with a non-integer step (e.g. sfreq=250.5)
            # derives its step from the cast first two values and yields wrong onsets.
            event_times = np.arange(0, data.shape[-1], self.sfreq * epoch_length).astype(
                int
            )
            events = np.column_stack(
                (
                    event_times,
                    np.zeros_like(event_times, dtype=int),
                    np.ones_like(event_times, dtype=int),
                )
            )

            # there need to be minimum 2 of two epochs, otherwise mne_connectivity
            # is not correctly initialized
            if events.shape[0] < 2:
                raise RuntimeError(
                    f"A minimum of 2 epochs is required for mne_connectivity,"
                    f" got only {events.shape[0]}. Increase settings['segment_length_features_ms']"
                )

            self.epochs = Epochs(
                self.raw_array,
                events=events,
                event_id={"rest": 1},
                tmin=0,
                tmax=epoch_length,
                baseline=None,
                reject_by_annotation=True,
                verbose=False,
            )

            # Trick the function "spectral_connectivity_epochs" into not calling "add_annotations_to_metadata"
            # TODO: This is a hack, and maybe needs a fix in the mne_connectivity library
            self.epochs._metadata = pd.DataFrame(index=np.arange(events.shape[0]))

        else:
            # As long as the initialization parameters, channels, sfreq and batch size are the same
            # We can re-use the existing epochs object by updating the raw data
            self.raw_array._data = data
            self.epochs._raw = self.raw_array

        # n_jobs is here kept to 1, since setup of the multiprocessing Pool
        # takes longer than most batch computing sizes
        spec_out = spectral_connectivity_epochs(
            data=self.epochs,
            sfreq=self.sfreq,
            method=self.method,
            mode=self.mode,
            indices=(np.array([0, 0, 1, 1]), np.array([2, 3, 2, 3])),
            verbose=False,
        )
        dat_conn: np.ndarray = spec_out.get_data()

        # Band indices are computed once on the first batch and reused afterwards
        # (assumes the frequency grid stays the same, since epoch_length is fixed).
        if len(self.fband_ranges) == 0:
            freqs = np.asarray(spec_out.freqs)
            for fband_range in self.fbands.values():
                self.fband_ranges.append(
                    np.where((freqs > fband_range[0]) & (freqs < fband_range[1]))[0]
                )

        # TODO: If I compute the mean for the entire fband, results are almost the same before
        # normalization (0.9999999... vs 1.0), but some change wildly after normalization (-3 vs 0)
        # Investigate why, is this a bug in normalization?
        feature_results = {}
        for conn in range(dat_conn.shape[0]):
            for fband_idx, fband in enumerate(self.fbands):
                key = "_".join(["ch1", self.method, str(conn), fband])
                feature_results[key] = np.mean(
                    dat_conn[conn, self.fband_ranges[fband_idx]]
                )

        # Store current experiment parameters to check if re-initialization is needed
        self.prev_batch_shape = data.shape

        return feature_results
@@ -0,0 +1,94 @@
1
+ import numpy as np
2
+ from collections.abc import Iterable
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ from py_neuromodulation.utils.types import NMFeature, BoolSelector, NMBaseModel
7
+
8
+ from pydantic import field_validator
9
+
10
+ if TYPE_CHECKING:
11
+ from py_neuromodulation import NMSettings
12
+
13
+
14
class NoldsFeatures(BoolSelector):
    """On/off switches for the nonlinear measures computed by ``Nolds``.

    Only ``lyapunov_exponent`` is enabled by default. The field order is left
    untouched because it presumably defines the order of ``get_enabled()``
    (and thus of the output keys) -- confirm in BoolSelector.
    """

    sample_entropy: bool = False  # nolds.sampen
    correlation_dimension: bool = False  # nolds.corr_dim(emb_dim=2)
    lyapunov_exponent: bool = True  # nolds.lyap_r
    hurst_exponent: bool = False  # nolds.hurst_rs
    detrended_fluctuation_analysis: bool = False  # nolds.dfa
20
+
21
+
22
class NoldsSettings(NMBaseModel):
    """Configuration of the nolds feature.

    ``raw`` toggles computation on the unfiltered signal. ``frequency_bands``
    names bands (keys of ``frequency_ranges_hz``) whose band-passed signal is
    analysed as well. ``features`` selects the measures.
    """

    raw: bool = True
    frequency_bands: list[str] = ["low_beta"]
    features: NoldsFeatures = NoldsFeatures()

    @field_validator("frequency_bands")
    def fbands_spaces_to_underscores(cls, frequency_bands):
        """Normalize band names by turning every space into an underscore."""
        normalized = []
        for band_name in frequency_bands:
            normalized.append(band_name.replace(" ", "_"))
        return normalized
30
+
31
+
32
+ class Nolds(NMFeature):
33
+ def __init__(
34
+ self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
35
+ ) -> None:
36
+ self.settings = settings.nolds_settings
37
+ self.ch_names = ch_names
38
+
39
+ if len(self.settings.frequency_bands) > 0:
40
+ from py_neuromodulation.features.bandpower import BandPower
41
+
42
+ self.bp_filter = BandPower(settings, ch_names, sfreq, use_kf=False)
43
+
44
+ # Check if the selected frequency bands are defined in the global settings
45
+ for fb in settings.nolds_settings.frequency_bands:
46
+ assert (
47
+ fb in settings.frequency_ranges_hz
48
+ ), f"{fb} selected in nolds_features, but not defined in s['frequency_ranges_hz']"
49
+
50
+ def calc_feature(self, data: np.ndarray) -> dict:
51
+ feature_results = {}
52
+ data = np.nan_to_num(data)
53
+ if self.settings.raw:
54
+ feature_results = self.calc_nolds(data, feature_results)
55
+ if len(self.settings.frequency_bands) > 0:
56
+ data_filt = self.bp_filter.bandpass_filter.filter_data(data)
57
+
58
+ for f_band_idx, f_band in enumerate(self.settings.frequency_bands):
59
+ # filter data now for a specific fband and pass to calc_nolds
60
+ feature_results = self.calc_nolds(
61
+ data_filt[:, f_band_idx, :], feature_results, f_band
62
+ ) # ch, bands, samples
63
+ return feature_results
64
+
65
+ def calc_nolds(
66
+ self, data: np.ndarray, feature_results: dict, data_str: str = "raw"
67
+ ) -> dict:
68
+ for ch_idx, ch_name in enumerate(self.ch_names):
69
+ for f_name in self.settings.features.get_enabled():
70
+ feature_results[f"{ch_name}_nolds_{f_name}_{data_str}"] = (
71
+ self.calc_nolds_feature(f_name, data[ch_idx, :])
72
+ if data[ch_idx, :].sum()
73
+ else 0
74
+ )
75
+
76
+ return feature_results
77
+
78
+ @staticmethod
79
+ def calc_nolds_feature(f_name: str, dat: np.ndarray):
80
+ import nolds
81
+
82
+ match f_name:
83
+ case "sample_entropy":
84
+ return nolds.sampen(dat)
85
+ case "correlation_dimension":
86
+ return nolds.corr_dim(dat, emb_dim=2)
87
+ case "lyapunov_exponent":
88
+ return nolds.lyap_r(dat)
89
+ case "hurst_exponent":
90
+ return nolds.hurst_rs(dat)
91
+ case "detrended_fluctuation_analysis":
92
+ return nolds.dfa(dat)
93
+ case _:
94
+ raise ValueError(f"Invalid nolds feature name: {f_name}")
@@ -0,0 +1,249 @@
1
+ from collections.abc import Sequence
2
+ import numpy as np
3
+ from itertools import product
4
+
5
+ from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
6
+ from typing import TYPE_CHECKING
7
+
8
+ if TYPE_CHECKING:
9
+ from py_neuromodulation.stream.settings import NMSettings
10
+
11
+
12
class OscillatoryFeatures(BoolSelector):
    """Selects which ``ESTIMATOR_DICT`` reductions summarize each frequency band.

    Only ``mean`` is enabled by default.
    """

    mean: bool = True  # np.nanmean
    median: bool = False  # np.nanmedian
    std: bool = False  # np.nanstd
    max: bool = False  # np.nanmax
17
+
18
+
19
class OscillatorySettings(NMBaseModel):
    """Settings model used by the FFT, Welch and STFT features.

    ``windowlength_ms`` is a window length in milliseconds (each feature
    decides how it is used). ``log_transform`` applies ``log10`` to the
    spectrum. ``features`` selects the band estimators. ``return_spectrum``
    additionally outputs every frequency bin.
    """

    windowlength_ms: int = 1000
    log_transform: bool = True
    # All four values are passed explicitly (not left to the field defaults),
    # which keeps them in pydantic's model_fields_set of the default instance.
    features: OscillatoryFeatures = OscillatoryFeatures(
        mean=True,
        median=False,
        std=False,
        max=False,
    )
    return_spectrum: bool = False
26
+
27
+
28
# Estimator name -> NaN-ignoring numpy reduction used to summarize the
# spectral values of one frequency band.
ESTIMATOR_DICT = dict(
    mean=np.nanmean,
    median=np.nanmedian,
    std=np.nanstd,
    max=np.nanmax,
)
34
+
35
+
36
class OscillatoryFeature(NMFeature):
    """Common setup for the FFT, Welch and STFT features.

    Subclasses must assign ``self.settings`` and ``self.osc_feature_name``
    before calling ``super().__init__``, because the window-length check
    below reads ``self.settings``.
    """

    def __init__(
        self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int
    ) -> None:
        settings.validate()
        self.settings: OscillatorySettings  # Assignment in subclass __init__
        self.osc_feature_name: str  # Required for output

        self.sfreq = int(sfreq)
        self.ch_names = ch_names

        self.frequency_ranges = settings.frequency_ranges_hz

        # The analysis window has to fit into one feature segment. The message
        # is a single string (the original had a trailing comma that turned it
        # into a tuple) and its fragments are separated by spaces.
        assert self.settings.windowlength_ms <= settings.segment_length_features_ms, (
            f"oscillatory feature windowlength_ms = ({self.settings.windowlength_ms})"
            f" needs to be smaller than or equal to"
            f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
        )
55
+
56
+
57
class FFT(OscillatoryFeature):
    """Band features from the ``rfft`` magnitude of the most recent window.

    Only the last ``floor(fft_settings.windowlength_ms / 1000 * sfreq)``
    samples of each batch are transformed.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Sequence[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        self.osc_feature_name = "fft"
        self.settings = settings.fft_settings
        # super.__init__ needs osc_feature_name and settings
        super().__init__(settings, ch_names, sfreq)

        window_ms = self.settings.windowlength_ms

        # Stored negated so that data[:, self.window_samples:] keeps the last
        # floor(window_ms / 1000 * sfreq) samples of the batch.
        self.window_samples = int(-np.floor(window_ms / 1000 * sfreq))
        self.freqs = rfftfreq(-self.window_samples, 1 / np.floor(self.sfreq))

        # Pre-calculate frequency ranges (lower edge inclusive, upper exclusive)
        self.idx_range = [
            (
                f_band,
                np.where((self.freqs >= f_range[0]) & (self.freqs < f_range[1]))[0],
            )
            for f_band, f_range in self.frequency_ranges.items()
        ]

        self.estimators = [
            (est, ESTIMATOR_DICT[est]) for est in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Return per-channel band estimates (and optionally the spectrum).

        Keys are ``"<ch>_fft_<band>_<estimator>"`` and, with
        ``return_spectrum``, ``"<ch>_fft_psd_<freq>"``.
        """
        from scipy.fft import rfft

        data = data[:, self.window_samples :]

        Z = np.abs(rfft(data))  # type: ignore

        if self.settings.log_transform:
            Z = np.log10(Z)

        feature_results = {}

        for f_band_name, idx_range in self.idx_range:
            # TODO Can we get rid of this for-loop? Hard to vectorize windows of different lengths...
            Z_band = Z[:, idx_range]  # Data for all channels

            for est_name, est_fun in self.estimators:
                result = est_fun(Z_band, axis=1)

                for ch_idx, ch_name in enumerate(self.ch_names):
                    feature_results[
                        f"{ch_name}_{self.osc_feature_name}_{f_band_name}_{est_name}"
                    ] = result[ch_idx]

        if self.settings.return_spectrum:
            # Integral frequencies keep their "<int>" label. Non-integral ones
            # (resolution finer than 1 Hz, i.e. windows > 1 s) keep their
            # fractional value: int() truncation made e.g. 0.5 Hz and 0 Hz share
            # the key "..._0" and silently overwrite each other.
            freq_labels = [
                str(int(f)) if float(f).is_integer() else str(f) for f in self.freqs
            ]
            combinations = product(enumerate(self.ch_names), enumerate(freq_labels))
            for (ch_idx, ch_name), (idx, label) in combinations:
                feature_results[f"{ch_name}_fft_psd_{label}"] = Z[ch_idx][idx]

        return feature_results
119
+
120
+
121
class Welch(OscillatoryFeature):
    """Band features from a Welch power spectral density estimate.

    The PSD uses a Hann window over segments of ``sfreq`` samples (1 s, giving
    1 Hz bins) with scipy's default overlap.

    NOTE(review): ``welch_settings.windowlength_ms`` only enters the window
    check of the base class; the segment length is always ``sfreq`` samples.
    If a batch were shorter than that, scipy would shorten ``nperseg`` and the
    precomputed ``self.freqs`` would no longer match the PSD bins -- confirm
    whether that combination can occur.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Sequence[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        self.settings = settings.welch_settings
        self.osc_feature_name = "welch"
        # Both attributes must exist before the base-class initializer runs
        super().__init__(settings, ch_names, sfreq)

        # Bin frequencies of a welch() call with nperseg == fs
        self.freqs = rfftfreq(self.sfreq, 1 / self.sfreq)

        self.idx_range = []
        for band_name, band_limits in self.frequency_ranges.items():
            in_band = (self.freqs >= band_limits[0]) & (self.freqs < band_limits[1])
            self.idx_range.append((band_name, np.flatnonzero(in_band)))

        enabled = self.settings.features.get_enabled()
        self.estimators = [(name, ESTIMATOR_DICT[name]) for name in enabled]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Return per-channel band estimates (and optionally the full PSD).

        Keys are ``"<ch>_welch_<band>_<estimator>"`` and, with
        ``return_spectrum``, ``"<ch>_welch_psd_<freq>"``.
        """
        from scipy.signal import welch

        _, psd = welch(
            data, fs=self.sfreq, window="hann", nperseg=self.sfreq, noverlap=None
        )
        if self.settings.log_transform:
            psd = np.log10(psd)

        feature_results = {}
        for band_name, band_idx in self.idx_range:
            band_psd = psd[:, band_idx]
            for est_name, estimator in self.estimators:
                per_channel = estimator(band_psd, axis=1)
                feature_results.update(
                    (
                        f"{ch_name}_{self.osc_feature_name}_{band_name}_{est_name}",
                        per_channel[i],
                    )
                    for i, ch_name in enumerate(self.ch_names)
                )

        if self.settings.return_spectrum:
            for ch_idx, ch_name in enumerate(self.ch_names):
                ch_psd = psd[ch_idx]
                for f_idx, freq in enumerate(self.freqs):
                    feature_results[f"{ch_name}_welch_psd_{str(freq)}"] = ch_psd[f_idx]

        return feature_results
182
+
183
+
184
class STFT(OscillatoryFeature):
    """Band features from the magnitude of a short-time Fourier transform.

    Band estimates reduce over both frequency bins and time frames of the
    band. The optional spectrum output is the time-averaged magnitude.
    """

    def __init__(
        self,
        settings: "NMSettings",
        ch_names: Sequence[str],
        sfreq: int,
    ) -> None:
        from scipy.fft import rfftfreq

        self.osc_feature_name = "stft"
        self.settings = settings.stft_settings
        # super.__init__ needs osc_feature_name and settings
        super().__init__(settings, ch_names, sfreq)

        # windowlength_ms is in milliseconds, while stft() expects nperseg in
        # samples. Using the ms value directly is only right at sfreq == 1000 Hz.
        # Integer arithmetic keeps the conversion exact (identical at 1000 Hz).
        self.nperseg = int(self.settings.windowlength_ms * self.sfreq // 1000)

        self.freqs = rfftfreq(self.nperseg, 1 / self.sfreq)

        # NOTE: the upper band edge is inclusive here (FFT/Welch use exclusive)
        self.idx_range = [
            (
                f_band,
                np.where((self.freqs >= f_range[0]) & (self.freqs <= f_range[1]))[0],
            )
            for f_band, f_range in self.frequency_ranges.items()
        ]

        self.estimators = [
            (est, ESTIMATOR_DICT[est]) for est in self.settings.features.get_enabled()
        ]

    def calc_feature(self, data: np.ndarray) -> dict:
        """Return per-channel band estimates (and optionally the spectrum).

        Keys are ``"<ch>_stft_<band>_<estimator>"`` and, with
        ``return_spectrum``, ``"<ch>_stft_psd_<freq>"``.
        """
        from scipy.signal import stft

        _, _, Zxx = stft(
            data,
            fs=self.sfreq,
            window="hamming",
            nperseg=self.nperseg,
            boundary="even",
        )

        # Z is indexed as (channel, frequency, time frame)
        Z = np.abs(Zxx)
        if self.settings.log_transform:
            Z = np.log10(Z)

        feature_results = {}

        for f_band_name, idx_range in self.idx_range:
            Z_band = Z[:, idx_range, :]

            for est_name, est_fun in self.estimators:
                result = est_fun(Z_band, axis=(1, 2))

                for ch_idx, ch_name in enumerate(self.ch_names):
                    feature_results[
                        f"{ch_name}_{self.osc_feature_name}_{f_band_name}_{est_name}"
                    ] = result[ch_idx]

        if self.settings.return_spectrum:
            # Time-averaged spectrum computed once for all channels instead of
            # once per (channel, frequency) pair
            mean_spectrum = Z.mean(axis=2)
            combinations = product(enumerate(self.ch_names), enumerate(self.freqs))
            for (ch_idx, ch_name), (idx, f) in combinations:
                feature_results[f"{ch_name}_stft_psd_{str(f)}"] = mean_spectrum[
                    ch_idx
                ][idx]

        return feature_results