py-neuromodulation 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (57)
  1. py_neuromodulation/__init__.py +16 -10
  2. py_neuromodulation/{nm_RMAP.py → analysis/RMAP.py} +2 -2
  3. py_neuromodulation/analysis/__init__.py +4 -0
  4. py_neuromodulation/{nm_decode.py → analysis/decode.py} +4 -4
  5. py_neuromodulation/{nm_analysis.py → analysis/feature_reader.py} +21 -20
  6. py_neuromodulation/{nm_plots.py → analysis/plots.py} +54 -12
  7. py_neuromodulation/{nm_stats.py → analysis/stats.py} +2 -8
  8. py_neuromodulation/{nm_settings.yaml → default_settings.yaml} +6 -9
  9. py_neuromodulation/features/__init__.py +31 -0
  10. py_neuromodulation/features/bandpower.py +165 -0
  11. py_neuromodulation/{nm_bispectra.py → features/bispectra.py} +8 -5
  12. py_neuromodulation/{nm_bursts.py → features/bursts.py} +14 -9
  13. py_neuromodulation/{nm_coherence.py → features/coherence.py} +17 -13
  14. py_neuromodulation/{nm_features.py → features/feature_processor.py} +30 -53
  15. py_neuromodulation/{nm_fooof.py → features/fooof.py} +11 -8
  16. py_neuromodulation/{nm_hjorth_raw.py → features/hjorth_raw.py} +10 -5
  17. py_neuromodulation/{nm_linelength.py → features/linelength.py} +1 -1
  18. py_neuromodulation/{nm_mne_connectivity.py → features/mne_connectivity.py} +5 -6
  19. py_neuromodulation/{nm_nolds.py → features/nolds.py} +5 -7
  20. py_neuromodulation/{nm_oscillatory.py → features/oscillatory.py} +7 -181
  21. py_neuromodulation/{nm_sharpwaves.py → features/sharpwaves.py} +13 -4
  22. py_neuromodulation/filter/__init__.py +3 -0
  23. py_neuromodulation/{nm_kalmanfilter.py → filter/kalman_filter.py} +67 -71
  24. py_neuromodulation/filter/kalman_filter_external.py +1890 -0
  25. py_neuromodulation/{nm_filter.py → filter/mne_filter.py} +128 -219
  26. py_neuromodulation/filter/notch_filter.py +93 -0
  27. py_neuromodulation/processing/__init__.py +10 -0
  28. py_neuromodulation/{nm_artifacts.py → processing/artifacts.py} +2 -3
  29. py_neuromodulation/{nm_preprocessing.py → processing/data_preprocessor.py} +19 -25
  30. py_neuromodulation/{nm_filter_preprocessing.py → processing/filter_preprocessing.py} +3 -4
  31. py_neuromodulation/{nm_normalization.py → processing/normalization.py} +9 -7
  32. py_neuromodulation/{nm_projection.py → processing/projection.py} +14 -14
  33. py_neuromodulation/{nm_rereference.py → processing/rereference.py} +13 -13
  34. py_neuromodulation/{nm_resample.py → processing/resample.py} +1 -4
  35. py_neuromodulation/stream/__init__.py +3 -0
  36. py_neuromodulation/{nm_run_analysis.py → stream/data_processor.py} +42 -42
  37. py_neuromodulation/stream/generator.py +53 -0
  38. py_neuromodulation/{nm_mnelsl_generator.py → stream/mnelsl_player.py} +10 -6
  39. py_neuromodulation/{nm_mnelsl_stream.py → stream/mnelsl_stream.py} +13 -9
  40. py_neuromodulation/{nm_settings.py → stream/settings.py} +27 -24
  41. py_neuromodulation/{nm_stream.py → stream/stream.py} +217 -188
  42. py_neuromodulation/utils/__init__.py +2 -0
  43. py_neuromodulation/{nm_define_nmchannels.py → utils/channels.py} +14 -9
  44. py_neuromodulation/{nm_database.py → utils/database.py} +2 -2
  45. py_neuromodulation/{nm_IO.py → utils/io.py} +42 -77
  46. py_neuromodulation/utils/keyboard.py +52 -0
  47. py_neuromodulation/{nm_logger.py → utils/logging.py} +3 -3
  48. py_neuromodulation/{nm_types.py → utils/types.py} +72 -14
  49. {py_neuromodulation-0.0.5.dist-info → py_neuromodulation-0.0.6.dist-info}/METADATA +3 -11
  50. py_neuromodulation-0.0.6.dist-info/RECORD +89 -0
  51. py_neuromodulation/FieldTrip.py +0 -589
  52. py_neuromodulation/_write_example_dataset_helper.py +0 -83
  53. py_neuromodulation/nm_generator.py +0 -45
  54. py_neuromodulation/nm_stream_abc.py +0 -166
  55. py_neuromodulation-0.0.5.dist-info/RECORD +0 -83
  56. {py_neuromodulation-0.0.5.dist-info → py_neuromodulation-0.0.6.dist-info}/WHEEL +0 -0
  57. {py_neuromodulation-0.0.5.dist-info → py_neuromodulation-0.0.6.dist-info}/licenses/LICENSE +0 -0
@@ -8,13 +8,13 @@ from collections.abc import Sequence
8
8
  from itertools import product
9
9
 
10
10
  from pydantic import Field, field_validator
11
- from py_neuromodulation.nm_types import BoolSelector, NMBaseModel
12
- from py_neuromodulation.nm_features import NMFeature
11
+ from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
13
12
 
14
13
  from typing import TYPE_CHECKING, Callable
14
+ from py_neuromodulation.utils.types import create_validation_error
15
15
 
16
16
  if TYPE_CHECKING:
17
- from py_neuromodulation.nm_settings import NMSettings
17
+ from py_neuromodulation import NMSettings
18
18
 
19
19
 
20
20
  LARGE_NUM = 2**24
@@ -45,7 +45,7 @@ class BurstFeatures(BoolSelector):
45
45
  in_burst: bool = True
46
46
 
47
47
 
48
- class BurstSettings(NMBaseModel):
48
+ class BurstsSettings(NMBaseModel):
49
49
  threshold: float = Field(default=75, ge=0, le=100)
50
50
  time_duration_s: float = Field(default=30, ge=0)
51
51
  frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
@@ -56,17 +56,22 @@ class BurstSettings(NMBaseModel):
56
56
  return [f.replace(" ", "_") for f in frequency_bands]
57
57
 
58
58
 
59
- class Burst(NMFeature):
59
+ class Bursts(NMFeature):
60
60
  def __init__(
61
61
  self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
62
62
  ) -> None:
63
63
  # Test settings
64
+ settings.validate()
65
+
66
+ # Validate that all frequency bands are defined in the settings
64
67
  for fband_burst in settings.burst_settings.frequency_bands:
65
- assert (
66
- fband_burst in list(settings.frequency_ranges_hz.keys())
67
- ), f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']"
68
+ if fband_burst not in list(settings.frequency_ranges_hz.keys()):
69
+ raise create_validation_error(
70
+ f"bursting {fband_burst} needs to be defined in settings['frequency_ranges_hz']",
71
+ loc=["burst_settings", "frequency_bands"],
72
+ )
68
73
 
69
- from py_neuromodulation.nm_filter import MNEFilter
74
+ from py_neuromodulation.filter import MNEFilter
70
75
 
71
76
  self.settings = settings.burst_settings
72
77
  self.sfreq = sfreq
@@ -5,12 +5,16 @@ from collections.abc import Iterable
5
5
  from typing import TYPE_CHECKING, Annotated
6
6
  from pydantic import Field, field_validator
7
7
 
8
- from py_neuromodulation.nm_features import NMFeature
9
- from py_neuromodulation.nm_types import BoolSelector, FrequencyRange, NMBaseModel
8
+ from py_neuromodulation.utils.types import (
9
+ NMFeature,
10
+ BoolSelector,
11
+ FrequencyRange,
12
+ NMBaseModel,
13
+ )
10
14
  from py_neuromodulation import logger
11
15
 
12
16
  if TYPE_CHECKING:
13
- from py_neuromodulation.nm_settings import NMSettings
17
+ from py_neuromodulation import NMSettings
14
18
 
15
19
 
16
20
  class CoherenceMethods(BoolSelector):
@@ -140,11 +144,11 @@ class CoherenceObject:
140
144
  return feature_results
141
145
 
142
146
 
143
- class NMCoherence(NMFeature):
147
+ class Coherence(NMFeature):
144
148
  def __init__(
145
149
  self, settings: "NMSettings", ch_names: list[str], sfreq: float
146
150
  ) -> None:
147
- self.settings = settings.coherence
151
+ self.settings = settings.coherence_settings
148
152
  self.frequency_ranges_hz = settings.frequency_ranges_hz
149
153
  self.sfreq = sfreq
150
154
  self.ch_names = ch_names
@@ -193,7 +197,7 @@ class NMCoherence(NMFeature):
193
197
  sfreq: float,
194
198
  ):
195
199
  flat_channels = [
196
- ch for ch_pair in settings.coherence.channels for ch in ch_pair
200
+ ch for ch_pair in settings.coherence_settings.channels for ch in ch_pair
197
201
  ]
198
202
 
199
203
  valid_coh_channel = [
@@ -203,37 +207,37 @@ class NMCoherence(NMFeature):
203
207
  if valid_coh_channel[ch_idx] == 0:
204
208
  raise RuntimeError(
205
209
  f"Coherence selected channel {ch_coh} does not match any channel name: \n"
206
- f" - settings.coherence.channels: {settings.coherence.channels}\n"
210
+ f" - settings.coherence_settings.channels: {settings.coherence_settings.channels}\n"
207
211
  f" - ch_names: {ch_names} \n"
208
212
  )
209
213
 
210
214
  if valid_coh_channel[ch_idx] > 1:
211
215
  raise RuntimeError(
212
216
  f"Coherence selected channel {ch_coh} is ambigous and matches more than one channel name: \n"
213
- f" - settings.coherence.channels: {settings.coherence.channels}\n"
217
+ f" - settings.coherence_settings.channels: {settings.coherence_settings.channels}\n"
214
218
  f" - ch_names: {ch_names} \n"
215
219
  )
216
220
 
217
221
  assert all(
218
222
  f_band_coh in settings.frequency_ranges_hz
219
- for f_band_coh in settings.coherence.frequency_bands
223
+ for f_band_coh in settings.coherence_settings.frequency_bands
220
224
  ), (
221
225
  "coherence selected frequency bands don't match the ones"
222
226
  "specified in s['frequency_ranges_hz']"
223
- f"coherence frequency bands: {settings.coherence.frequency_bands}"
227
+ f"coherence frequency bands: {settings.coherence_settings.frequency_bands}"
224
228
  f"specified frequency_ranges_hz: {settings.frequency_ranges_hz}"
225
229
  )
226
230
 
227
231
  assert all(
228
232
  settings.frequency_ranges_hz[fb][0] < sfreq / 2
229
233
  and settings.frequency_ranges_hz[fb][1] < sfreq / 2
230
- for fb in settings.coherence.frequency_bands
234
+ for fb in settings.coherence_settings.frequency_bands
231
235
  ), (
232
236
  "the coherence frequency band ranges need to be smaller than the Nyquist frequency"
233
- f"got sfreq = {sfreq} and fband ranges {settings.coherence.frequency_bands}"
237
+ f"got sfreq = {sfreq} and fband ranges {settings.coherence_settings.frequency_bands}"
234
238
  )
235
239
 
236
- if not settings.coherence.method.get_enabled():
240
+ if not settings.coherence_settings.method.get_enabled():
237
241
  logger.warn(
238
242
  "feature coherence enabled, but no coherence['method'] selected"
239
243
  )
@@ -1,51 +1,27 @@
1
- from typing import Protocol, Type, runtime_checkable, TYPE_CHECKING
2
- from collections.abc import Sequence
3
- import numpy as np
1
+ from typing import Type, TYPE_CHECKING
4
2
 
5
- if TYPE_CHECKING:
6
- from nm_settings import NMSettings
7
-
8
- from py_neuromodulation.nm_types import ImportDetails, get_class, FeatureName
9
-
10
-
11
- @runtime_checkable
12
- class NMFeature(Protocol):
13
- def __init__(
14
- self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int | float
15
- ) -> None: ...
3
+ from py_neuromodulation.utils.types import NMFeature, FeatureName
16
4
 
17
- def calc_feature(self, data: np.ndarray) -> dict:
18
- """
19
- Feature calculation method. Each method needs to loop through all channels
20
-
21
- Parameters
22
- ----------
23
- data : 'np.ndarray'
24
- (channels, time)
25
- feature_results : dict
26
-
27
- Returns
28
- -------
29
- dict
30
- """
31
- ...
32
-
33
-
34
- FEATURE_DICT: dict[FeatureName | str, ImportDetails] = {
35
- "raw_hjorth": ImportDetails("nm_hjorth_raw", "Hjorth"),
36
- "return_raw": ImportDetails("nm_hjorth_raw", "Raw"),
37
- "bandpass_filter": ImportDetails("nm_oscillatory", "BandPower"),
38
- "stft": ImportDetails("nm_oscillatory", "STFT"),
39
- "fft": ImportDetails("nm_oscillatory", "FFT"),
40
- "welch": ImportDetails("nm_oscillatory", "Welch"),
41
- "sharpwave_analysis": ImportDetails("nm_sharpwaves", "SharpwaveAnalyzer"),
42
- "fooof": ImportDetails("nm_fooof", "FooofAnalyzer"),
43
- "nolds": ImportDetails("nm_nolds", "Nolds"),
44
- "coherence": ImportDetails("nm_coherence", "NMCoherence"),
45
- "bursts": ImportDetails("nm_bursts", "Burst"),
46
- "linelength": ImportDetails("nm_linelength", "LineLength"),
47
- "mne_connectivity": ImportDetails("nm_mne_connectivity", "MNEConnectivity"),
48
- "bispectrum": ImportDetails("nm_bispectra", "Bispectra"),
5
+ if TYPE_CHECKING:
6
+ import numpy as np
7
+ from py_neuromodulation import NMSettings
8
+
9
+
10
+ FEATURE_DICT: dict[FeatureName | str, str] = {
11
+ "raw_hjorth": "Hjorth",
12
+ "return_raw": "Raw",
13
+ "bandpass_filter": "BandPower",
14
+ "stft": "STFT",
15
+ "fft": "FFT",
16
+ "welch": "Welch",
17
+ "sharpwave_analysis": "SharpwaveAnalyzer",
18
+ "fooof": "FooofAnalyzer",
19
+ "nolds": "Nolds",
20
+ "coherence": "Coherence",
21
+ "bursts": "Bursts",
22
+ "linelength": "LineLength",
23
+ "mne_connectivity": "MNEConnectivity",
24
+ "bispectrum": "Bispectra",
49
25
  }
50
26
 
51
27
 
@@ -63,12 +39,13 @@ class FeatureProcessors:
63
39
  sfreq (float): sampling frequency in Hz
64
40
  """
65
41
  from py_neuromodulation import user_features
42
+ from importlib import import_module
66
43
 
67
44
  # Accept 'str' for custom features
68
45
  self.features: dict[FeatureName | str, NMFeature] = {
69
- feature_name: get_class(FEATURE_DICT[feature_name])(
70
- settings, ch_names, sfreq
71
- )
46
+ feature_name: getattr(
47
+ import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name]
48
+ )(settings, ch_names, sfreq)
72
49
  for feature_name in settings.features.get_enabled()
73
50
  }
74
51
 
@@ -80,12 +57,12 @@ class FeatureProcessors:
80
57
 
81
58
  Parameters
82
59
  ----------
83
- feature : nm_features_abc.Feature
60
+ feature : features_abc.Feature
84
61
  New feature to add to feature list
85
62
  """
86
63
  self.features[feature_name] = feature # type: ignore
87
64
 
88
- def estimate_features(self, data: np.ndarray) -> dict:
65
+ def estimate_features(self, data: "np.ndarray") -> dict:
89
66
  """Calculate features, as defined in settings.json
90
67
  Features are based on bandpower, raw Hjorth parameters and sharp wave
91
68
  characteristics.
@@ -125,7 +102,7 @@ def add_custom_feature(feature_name: str, new_feature: Type[NMFeature]):
125
102
  in this file).
126
103
  """
127
104
  from py_neuromodulation import user_features
128
- from py_neuromodulation.nm_settings import NMSettings
105
+ from py_neuromodulation import NMSettings
129
106
 
130
107
  user_features[feature_name] = new_feature
131
108
  NMSettings._add_feature(feature_name)
@@ -138,7 +115,7 @@ def remove_custom_feature(feature_name: str):
138
115
  feature_name (str): Name of the feature to remove
139
116
  """
140
117
  from py_neuromodulation import user_features
141
- from py_neuromodulation.nm_settings import NMSettings
118
+ from py_neuromodulation import NMSettings
142
119
 
143
120
  user_features.pop(feature_name)
144
121
  NMSettings._remove_feature(feature_name)
@@ -2,13 +2,16 @@ from collections.abc import Iterable
2
2
  import numpy as np
3
3
 
4
4
  from typing import TYPE_CHECKING
5
- from py_neuromodulation.nm_types import NMBaseModel
6
5
 
7
- from py_neuromodulation.nm_features import NMFeature
8
- from py_neuromodulation.nm_types import BoolSelector, FrequencyRange
6
+ from py_neuromodulation.utils.types import (
7
+ NMBaseModel,
8
+ NMFeature,
9
+ BoolSelector,
10
+ FrequencyRange,
11
+ )
9
12
 
10
13
  if TYPE_CHECKING:
11
- from py_neuromodulation.nm_settings import NMSettings
14
+ from py_neuromodulation import NMSettings
12
15
 
13
16
 
14
17
  class FooofAperiodicSettings(BoolSelector):
@@ -48,7 +51,7 @@ class FooofAnalyzer(NMFeature):
48
51
  def __init__(
49
52
  self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
50
53
  ) -> None:
51
- self.settings = settings.fooof
54
+ self.settings = settings.fooof_settings
52
55
  self.sfreq = sfreq
53
56
  self.ch_names = ch_names
54
57
 
@@ -59,12 +62,12 @@ class FooofAnalyzer(NMFeature):
59
62
  self.f_vec = np.arange(0, int(self.num_samples / 2) + 1, 1)
60
63
 
61
64
  assert (
62
- settings.fooof.windowlength_ms <= settings.segment_length_features_ms
65
+ self.settings.windowlength_ms <= settings.segment_length_features_ms
63
66
  ), f"fooof windowlength_ms ({settings.fooof.windowlength_ms}) needs to be smaller equal than segment_length_features_ms ({settings.segment_length_features_ms})."
64
67
 
65
68
  assert (
66
- settings.fooof.freq_range_hz[0] < sfreq
67
- and settings.fooof.freq_range_hz[1] < sfreq
69
+ self.settings.freq_range_hz[0] < sfreq
70
+ and self.settings.freq_range_hz[1] < sfreq
68
71
  ), f"fooof frequency range needs to be below sfreq, got {settings.fooof.freq_range_hz}"
69
72
 
70
73
  from fooof import FOOOFGroup
@@ -5,16 +5,19 @@ Reference: B Hjorth
5
5
  DOI: 10.1016/0013-4694(70)90143-4
6
6
  """
7
7
 
8
+ from typing import TYPE_CHECKING
8
9
  import numpy as np
9
- from collections.abc import Iterable
10
+ from collections.abc import Sequence
10
11
 
11
- from py_neuromodulation.nm_features import NMFeature
12
- from py_neuromodulation.nm_settings import NMSettings
12
+ from py_neuromodulation.utils.types import NMFeature
13
+
14
+ if TYPE_CHECKING:
15
+ from py_neuromodulation import NMSettings
13
16
 
14
17
 
15
18
  class Hjorth(NMFeature):
16
19
  def __init__(
17
- self, settings: NMSettings, ch_names: Iterable[str], sfreq: float
20
+ self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
18
21
  ) -> None:
19
22
  self.ch_names = ch_names
20
23
 
@@ -40,7 +43,9 @@ class Hjorth(NMFeature):
40
43
 
41
44
 
42
45
  class Raw(NMFeature):
43
- def __init__(self, settings: dict, ch_names: Iterable[str], sfreq: float) -> None:
46
+ def __init__(
47
+ self, settings: "NMSettings", ch_names: Sequence[str], sfreq: float
48
+ ) -> None:
44
49
  self.ch_names = ch_names
45
50
 
46
51
  def calc_feature(self, data: np.ndarray) -> dict:
@@ -1,7 +1,7 @@
1
1
  import numpy as np
2
2
  from collections.abc import Sequence
3
3
 
4
- from py_neuromodulation.nm_features import NMFeature
4
+ from py_neuromodulation.features.feature_processor import NMFeature
5
5
 
6
6
 
7
7
  class LineLength(NMFeature):
@@ -1,13 +1,11 @@
1
1
  from collections.abc import Iterable
2
2
  import numpy as np
3
- import pandas as pd
4
3
  from typing import TYPE_CHECKING
5
4
 
6
- from py_neuromodulation.nm_features import NMFeature
7
- from py_neuromodulation.nm_types import NMBaseModel
5
+ from py_neuromodulation.utils.types import NMFeature, NMBaseModel
8
6
 
9
7
  if TYPE_CHECKING:
10
- from py_neuromodulation.nm_settings import NMSettings
8
+ from py_neuromodulation import NMSettings
11
9
  from mne.io import RawArray
12
10
  from mne import Epochs
13
11
 
@@ -32,8 +30,8 @@ class MNEConnectivity(NMFeature):
32
30
  self.sfreq = sfreq
33
31
 
34
32
  # Params used by spectral_connectivity_epochs
35
- self.mode = settings.mne_connectivity.mode
36
- self.method = settings.mne_connectivity.method
33
+ self.mode = settings.mne_connectivity_settings.mode
34
+ self.method = settings.mne_connectivity_settings.method
37
35
 
38
36
  self.fbands = settings.frequency_ranges_hz
39
37
  self.fband_ranges: list = []
@@ -48,6 +46,7 @@ class MNEConnectivity(NMFeature):
48
46
  from mne.io import RawArray
49
47
  from mne import Epochs
50
48
  from mne_connectivity import spectral_connectivity_epochs
49
+ import pandas as pd
51
50
 
52
51
  time_samples_s = data.shape[1] / self.sfreq
53
52
  epoch_length: float = 1 # TODO: Make this a parameter?
@@ -1,16 +1,14 @@
1
1
  import numpy as np
2
2
  from collections.abc import Iterable
3
3
 
4
- from py_neuromodulation.nm_types import NMBaseModel
5
4
  from typing import TYPE_CHECKING
6
5
 
7
- from py_neuromodulation.nm_features import NMFeature
8
- from py_neuromodulation.nm_types import BoolSelector
6
+ from py_neuromodulation.utils.types import NMFeature, BoolSelector, NMBaseModel
9
7
 
10
8
  from pydantic import field_validator
11
9
 
12
10
  if TYPE_CHECKING:
13
- from py_neuromodulation.nm_settings import NMSettings
11
+ from py_neuromodulation import NMSettings
14
12
 
15
13
 
16
14
  class NoldsFeatures(BoolSelector):
@@ -35,16 +33,16 @@ class Nolds(NMFeature):
35
33
  def __init__(
36
34
  self, settings: "NMSettings", ch_names: Iterable[str], sfreq: float
37
35
  ) -> None:
38
- self.settings = settings.nolds_features
36
+ self.settings = settings.nolds_settings
39
37
  self.ch_names = ch_names
40
38
 
41
39
  if len(self.settings.frequency_bands) > 0:
42
- from py_neuromodulation.nm_oscillatory import BandPower
40
+ from py_neuromodulation.features.bandpower import BandPower
43
41
 
44
42
  self.bp_filter = BandPower(settings, ch_names, sfreq, use_kf=False)
45
43
 
46
44
  # Check if the selected frequency bands are defined in the global settings
47
- for fb in settings.nolds_features.frequency_bands:
45
+ for fb in settings.nolds_settings.frequency_bands:
48
46
  assert (
49
47
  fb in settings.frequency_ranges_hz
50
48
  ), f"{fb} selected in nolds_features, but not defined in s['frequency_ranges_hz']"
@@ -1,17 +1,12 @@
1
- from collections.abc import Iterable
1
+ from collections.abc import Sequence
2
2
  import numpy as np
3
3
  from itertools import product
4
4
 
5
- from py_neuromodulation.nm_types import NMBaseModel
6
- from pydantic import field_validator
5
+ from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
7
6
  from typing import TYPE_CHECKING
8
7
 
9
- from py_neuromodulation.nm_features import NMFeature
10
- from py_neuromodulation.nm_types import BoolSelector
11
-
12
8
  if TYPE_CHECKING:
13
- from py_neuromodulation.nm_settings import NMSettings
14
- from py_neuromodulation.nm_kalmanfilter import KalmanSettings
9
+ from py_neuromodulation.stream.settings import NMSettings
15
10
 
16
11
 
17
12
  class OscillatoryFeatures(BoolSelector):
@@ -40,7 +35,7 @@ ESTIMATOR_DICT = {
40
35
 
41
36
  class OscillatoryFeature(NMFeature):
42
37
  def __init__(
43
- self, settings: "NMSettings", ch_names: Iterable[str], sfreq: int
38
+ self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int
44
39
  ) -> None:
45
40
  settings.validate()
46
41
  self.settings: OscillatorySettings # Assignment in subclass __init__
@@ -63,7 +58,7 @@ class FFT(OscillatoryFeature):
63
58
  def __init__(
64
59
  self,
65
60
  settings: "NMSettings",
66
- ch_names: Iterable[str],
61
+ ch_names: Sequence[str],
67
62
  sfreq: int,
68
63
  ) -> None:
69
64
  from scipy.fft import rfftfreq
@@ -127,7 +122,7 @@ class Welch(OscillatoryFeature):
127
122
  def __init__(
128
123
  self,
129
124
  settings: "NMSettings",
130
- ch_names: Iterable[str],
125
+ ch_names: Sequence[str],
131
126
  sfreq: int,
132
127
  ) -> None:
133
128
  from scipy.fft import rfftfreq
@@ -190,7 +185,7 @@ class STFT(OscillatoryFeature):
190
185
  def __init__(
191
186
  self,
192
187
  settings: "NMSettings",
193
- ch_names: Iterable[str],
188
+ ch_names: Sequence[str],
194
189
  sfreq: int,
195
190
  ) -> None:
196
191
  from scipy.fft import rfftfreq
@@ -252,172 +247,3 @@ class STFT(OscillatoryFeature):
252
247
  )[idx]
253
248
 
254
249
  return feature_results
255
-
256
-
257
- class BandpowerFeatures(BoolSelector):
258
- activity: bool = True
259
- mobility: bool = False
260
- complexity: bool = False
261
-
262
-
263
- ###################################
264
- ######## BANDPOWER FEATURE ########
265
- ###################################
266
-
267
-
268
- class BandpassSettings(NMBaseModel):
269
- segment_lengths_ms: dict[str, int] = {
270
- "theta": 1000,
271
- "alpha": 500,
272
- "low_beta": 333,
273
- "high_beta": 333,
274
- "low_gamma": 100,
275
- "high_gamma": 100,
276
- "HFA": 100,
277
- }
278
- bandpower_features: BandpowerFeatures = BandpowerFeatures()
279
- log_transform: bool = True
280
- kalman_filter: bool = False
281
-
282
- @field_validator("segment_lengths_ms")
283
- @classmethod
284
- # Replace spaces with underscores in frequency band names
285
- def fbands_spaces_to_underscores(cls, segment_lengths_ms: dict[str, int]):
286
- return {k.replace(" ", "_"): v for k, v in segment_lengths_ms.items()}
287
-
288
- @field_validator("bandpower_features")
289
- @classmethod
290
- def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
291
- assert (
292
- len(bandpower_features.get_enabled()) > 0
293
- ), "Set at least one bandpower_feature to True."
294
-
295
- return bandpower_features
296
-
297
- def validate_fbands(self, settings: "NMSettings") -> None:
298
- # Ensure that each freq-band is defined in the global settings
299
- for fband_name in settings.frequency_ranges_hz.keys():
300
- assert fband_name in self.segment_lengths_ms, (
301
- f"frequency range {fband_name} "
302
- "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]"
303
- )
304
-
305
- # Ensure that segment length for each freq-band is smaller than the feature segment length setting
306
- for fband_name, seg_length_fband in self.segment_lengths_ms.items():
307
- assert seg_length_fband <= settings.segment_length_features_ms, (
308
- f"segment length {seg_length_fband} needs to be smaller than "
309
- f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
310
- )
311
-
312
-
313
- class BandPower(NMFeature):
314
- def __init__(
315
- self,
316
- settings: "NMSettings",
317
- ch_names: Iterable[str],
318
- sfreq: float,
319
- use_kf: bool | None = None,
320
- ) -> None:
321
- settings.validate()
322
-
323
- self.bp_settings: BandpassSettings = settings.bandpass_filter_settings
324
- self.kalman_filter_settings: KalmanSettings = settings.kalman_filter_settings
325
- self.sfreq = sfreq
326
- self.ch_names = ch_names
327
- self.KF_dict: dict = {}
328
-
329
- from py_neuromodulation.nm_filter import MNEFilter
330
-
331
- self.bandpass_filter = MNEFilter(
332
- f_ranges=[
333
- tuple(frange) for frange in settings.frequency_ranges_hz.values()
334
- ],
335
- sfreq=self.sfreq,
336
- filter_length=self.sfreq - 1,
337
- verbose=False,
338
- )
339
-
340
- if use_kf or (use_kf is None and self.bp_settings.kalman_filter):
341
- self.init_KF("bandpass_activity")
342
-
343
- seglengths = self.bp_settings.segment_lengths_ms
344
-
345
- self.feature_params = []
346
- for ch_idx, ch_name in enumerate(self.ch_names):
347
- for f_band_idx, f_band in enumerate(settings.frequency_ranges_hz.keys()):
348
- seglength_ms = seglengths[f_band]
349
- seglen = int(np.floor(self.sfreq / 1000 * seglength_ms))
350
- for bp_feature in self.bp_settings.bandpower_features.get_enabled():
351
- feature_name = "_".join([ch_name, "bandpass", bp_feature, f_band])
352
- self.feature_params.append(
353
- (
354
- ch_idx,
355
- f_band_idx,
356
- seglen,
357
- bp_feature,
358
- feature_name,
359
- )
360
- )
361
-
362
- def init_KF(self, feature: str) -> None:
363
- from py_neuromodulation.nm_kalmanfilter import define_KF
364
-
365
- for f_band in self.kalman_filter_settings.frequency_bands:
366
- for channel in self.ch_names:
367
- self.KF_dict["_".join([channel, feature, f_band])] = define_KF(
368
- self.kalman_filter_settings.Tp,
369
- self.kalman_filter_settings.sigma_w,
370
- self.kalman_filter_settings.sigma_v,
371
- )
372
-
373
- def update_KF(self, feature_calc: np.floating, KF_name: str) -> np.floating:
374
- if KF_name in self.KF_dict:
375
- self.KF_dict[KF_name].predict()
376
- self.KF_dict[KF_name].update(feature_calc)
377
- feature_calc = self.KF_dict[KF_name].x[0]
378
- return feature_calc
379
-
380
- def calc_feature(self, data: np.ndarray) -> dict:
381
- data = self.bandpass_filter.filter_data(data)
382
-
383
- feature_results = {}
384
-
385
- for (
386
- ch_idx,
387
- f_band_idx,
388
- seglen,
389
- bp_feature,
390
- feature_name,
391
- ) in self.feature_params:
392
- feature_results[feature_name] = self.calc_bp_feature(
393
- bp_feature, feature_name, data[ch_idx, f_band_idx, -seglen:]
394
- )
395
-
396
- return feature_results
397
-
398
- def calc_bp_feature(self, bp_feature, feature_name, data):
399
- match bp_feature:
400
- case "activity":
401
- feature_calc = np.var(data)
402
- if self.bp_settings.log_transform:
403
- feature_calc = np.log10(feature_calc)
404
- if self.KF_dict:
405
- feature_calc = self.update_KF(feature_calc, feature_name)
406
- case "mobility":
407
- feature_calc = np.sqrt(np.var(np.diff(data)) / np.var(data))
408
- case "complexity":
409
- feature_calc = self.calc_complexity(data)
410
- case _:
411
- raise ValueError(f"Unknown bandpower feature: {bp_feature}")
412
-
413
- return np.nan_to_num(feature_calc)
414
-
415
- @staticmethod
416
- def calc_complexity(data: np.ndarray) -> float:
417
- dat_deriv = np.diff(data)
418
- deriv_variance = np.var(dat_deriv)
419
- mobility = np.sqrt(deriv_variance / np.var(data))
420
- dat_deriv_2_var = np.var(np.diff(dat_deriv))
421
- deriv_mobility = np.sqrt(dat_deriv_2_var / deriv_variance)
422
-
423
- return deriv_mobility / mobility