paradigma 1.0.3__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
paradigma/__init__.py CHANGED
@@ -1,6 +1,15 @@
1
+ """
2
+ ParaDigMa: Parkinson Digital Biomarker Analysis Toolbox
3
+ """
4
+
1
5
  # read version from installed package
2
6
  from importlib.metadata import version
3
7
 
4
8
  __version__ = version("paradigma")
5
9
 
6
- __all__ = []
10
+ # Import main pipeline functions for easy access
11
+ from .orchestrator import run_paradigma
12
+
13
+ __all__ = [
14
+ "run_paradigma",
15
+ ]
@@ -1,14 +1,19 @@
1
- import numpy as np
2
1
  import pickle
3
-
4
2
  from pathlib import Path
3
+ from typing import Any
4
+
5
+ import numpy as np
5
6
  from sklearn.base import BaseEstimator
6
- from typing import Any, Optional
7
+ from sklearn.preprocessing import StandardScaler
8
+
7
9
 
8
10
  class ClassifierPackage:
9
- def __init__(self, classifier: Optional[BaseEstimator] = None,
10
- threshold: Optional[float] = None,
11
- scaler: Optional[Any] = None):
11
+ def __init__(
12
+ self,
13
+ classifier: BaseEstimator | None = None,
14
+ threshold: float | None = None,
15
+ scaler: Any | None = None,
16
+ ):
12
17
  """
13
18
  Initialize the ClassifierPackage with a classifier, threshold, and scaler.
14
19
 
@@ -25,13 +30,13 @@ class ClassifierPackage:
25
30
  self.threshold = threshold
26
31
  self.scaler = scaler
27
32
 
28
- def transform_features(self, X) -> np.ndarray:
33
+ def transform_features(self, x) -> np.ndarray:
29
34
  """
30
35
  Transform the input features using the scaler.
31
36
 
32
37
  Parameters
33
38
  ----------
34
- X : np.ndarray
39
+ x : np.ndarray
35
40
  The input features.
36
41
 
37
42
  Return
@@ -40,16 +45,28 @@ class ClassifierPackage:
40
45
  The transformed features.
41
46
  """
42
47
  if not self.scaler:
43
- return X
44
- return self.scaler.transform(X)
48
+ return x
49
+ return self.scaler.transform(x)
45
50
 
46
- def predict_proba(self, X) -> float:
51
+ def update_scaler(self, x_train: np.ndarray) -> None:
52
+ """
53
+ Update the scaler used for feature transformation.
54
+
55
+ Parameters
56
+ ----------
57
+ x_train : np.ndarray
58
+ The training data to fit the scaler.
59
+ """
60
+ scaler = StandardScaler()
61
+ self.scaler = scaler.fit(x_train)
62
+
63
+ def predict_proba(self, x) -> float:
47
64
  """
48
65
  Make predictions using the classifier and apply the threshold.
49
66
 
50
67
  Parameters
51
68
  ----------
52
- X : np.ndarray
69
+ x : np.ndarray
53
70
  The input features.
54
71
 
55
72
  Return
@@ -60,15 +77,15 @@ class ClassifierPackage:
60
77
  """
61
78
  if not self.classifier:
62
79
  raise ValueError("Classifier is not loaded.")
63
- return self.classifier.predict_proba(X)[:, 1]
64
-
65
- def predict(self, X) -> int:
80
+ return self.classifier.predict_proba(x)[:, 1]
81
+
82
+ def predict(self, x) -> int:
66
83
  """
67
84
  Make predictions using the classifier and apply the threshold.
68
85
 
69
86
  Parameters
70
87
  ----------
71
- X : np.ndarray
88
+ x : np.ndarray
72
89
  The input features.
73
90
 
74
91
  Return
@@ -79,8 +96,8 @@ class ClassifierPackage:
79
96
  """
80
97
  if not self.classifier:
81
98
  raise ValueError("Classifier is not loaded.")
82
- return int(self.predict_proba(X) >= self.threshold)
83
-
99
+ return int(self.predict_proba(x) >= self.threshold)
100
+
84
101
  def save(self, filepath: str | Path) -> None:
85
102
  """
86
103
  Save the ClassifierPackage to a file.
@@ -90,7 +107,7 @@ class ClassifierPackage:
90
107
  filepath : str
91
108
  The path to the file.
92
109
  """
93
- with open(filepath, 'wb') as f:
110
+ with open(filepath, "wb") as f:
94
111
  pickle.dump(self, f)
95
112
 
96
113
  @classmethod
@@ -109,7 +126,7 @@ class ClassifierPackage:
109
126
  The loaded classifier package.
110
127
  """
111
128
  try:
112
- with open(filepath, 'rb') as f:
129
+ with open(filepath, "rb") as f:
113
130
  return pickle.load(f)
114
131
  except Exception as e:
115
- raise ValueError(f"Failed to load classifier package: {e}") from e
132
+ raise ValueError(f"Failed to load classifier package: {e}") from e
paradigma/config.py CHANGED
@@ -1,12 +1,16 @@
1
- from typing import Dict, List
2
- from paradigma.constants import DataColumns, DataUnits
1
+ import warnings
2
+ from dataclasses import asdict
3
+
3
4
  import numpy as np
4
5
 
6
+ from paradigma.constants import DataColumns, DataUnits
7
+
8
+
5
9
  class BaseConfig:
6
10
  def __init__(self) -> None:
7
- self.meta_filename = ''
8
- self.values_filename = ''
9
- self.time_filename = ''
11
+ self.meta_filename = ""
12
+ self.values_filename = ""
13
+ self.time_filename = ""
10
14
 
11
15
  def set_sensor(self, sensor: str) -> None:
12
16
  """Sets the sensor and derived filenames"""
@@ -14,8 +18,10 @@ class BaseConfig:
14
18
  self.set_filenames(sensor)
15
19
 
16
20
  def set_filenames(self, prefix: str) -> None:
17
- """Sets the filenames based on the prefix. This method is duplicated from `gaits_analysis_config.py`.
18
-
21
+ """Sets the filenames based on the prefix.
22
+
23
+ This method is duplicated from `gaits_analysis_config.py`.
24
+
19
25
  Parameters
20
26
  ----------
21
27
  prefix : str
@@ -25,79 +31,108 @@ class BaseConfig:
25
31
  self.time_filename = f"{prefix}_time.bin"
26
32
  self.values_filename = f"{prefix}_values.bin"
27
33
 
34
+
28
35
  class IMUConfig(BaseConfig):
36
+ """
37
+ IMU configuration that uses DataColumns() to dynamically map available channels.
38
+ Works even if only accelerometer or only gyroscope data is present.
39
+ """
29
40
 
30
- def __init__(self) -> None:
41
+ def __init__(self, column_mapping: dict[str, str] | None = None) -> None:
31
42
  super().__init__()
32
-
33
- self.set_filenames('IMU')
43
+ self.set_filenames("IMU")
34
44
 
35
45
  self.acceleration_units = DataUnits.ACCELERATION
36
46
  self.rotation_units = DataUnits.ROTATION
37
-
38
47
  self.axes = ["x", "y", "z"]
39
48
 
40
- self.accelerometer_cols: List[str] = [
41
- DataColumns.ACCELEROMETER_X,
42
- DataColumns.ACCELEROMETER_Y,
43
- DataColumns.ACCELEROMETER_Z,
44
- ]
45
- self.gyroscope_cols: List[str] = [
46
- DataColumns.GYROSCOPE_X,
47
- DataColumns.GYROSCOPE_Y,
48
- DataColumns.GYROSCOPE_Z,
49
- ]
50
- self.gravity_cols: List[str] = [
51
- DataColumns.GRAV_ACCELEROMETER_X,
52
- DataColumns.GRAV_ACCELEROMETER_Y,
53
- DataColumns.GRAV_ACCELEROMETER_Z,
49
+ # Generate a default mapping or override with user-provided mapping
50
+ default_mapping = asdict(DataColumns())
51
+ self.column_mapping = {**default_mapping, **(column_mapping or {})}
52
+
53
+ self.time_colname = self.column_mapping["TIME"]
54
+
55
+ self.accelerometer_colnames: list[str] = []
56
+ self.gyroscope_colnames: list[str] = []
57
+ self.gravity_colnames: list[str] = []
58
+
59
+ self.d_channels_accelerometer: dict[str, str] = {}
60
+ self.d_channels_gyroscope: dict[str, str] = {}
61
+
62
+ accel_keys = ["ACCELEROMETER_X", "ACCELEROMETER_Y", "ACCELEROMETER_Z"]
63
+ grav_keys = [
64
+ "GRAV_ACCELEROMETER_X",
65
+ "GRAV_ACCELEROMETER_Y",
66
+ "GRAV_ACCELEROMETER_Z",
54
67
  ]
68
+ gyro_keys = ["GYROSCOPE_X", "GYROSCOPE_Y", "GYROSCOPE_Z"]
55
69
 
56
- self.d_channels_accelerometer = {
57
- DataColumns.ACCELEROMETER_X: self.acceleration_units,
58
- DataColumns.ACCELEROMETER_Y: self.acceleration_units,
59
- DataColumns.ACCELEROMETER_Z: self.acceleration_units,
60
- }
61
- self.d_channels_gyroscope = {
62
- DataColumns.GYROSCOPE_X: self.rotation_units,
63
- DataColumns.GYROSCOPE_Y: self.rotation_units,
64
- DataColumns.GYROSCOPE_Z: self.rotation_units,
70
+ if all(k in self.column_mapping for k in accel_keys):
71
+ self.accelerometer_colnames = [self.column_mapping[k] for k in accel_keys]
72
+
73
+ if all(k in self.column_mapping for k in grav_keys):
74
+ self.gravity_colnames = [self.column_mapping[k] for k in grav_keys]
75
+
76
+ self.d_channels_accelerometer = {
77
+ c: self.acceleration_units for c in self.accelerometer_colnames
78
+ }
79
+
80
+ if all(k in self.column_mapping for k in gyro_keys):
81
+ self.gyroscope_colnames = [self.column_mapping[k] for k in gyro_keys]
82
+
83
+ self.d_channels_gyroscope = {
84
+ c: self.rotation_units for c in self.gyroscope_colnames
85
+ }
86
+
87
+ self.d_channels_imu: dict[str, str] = {
88
+ **self.d_channels_accelerometer,
89
+ **self.d_channels_gyroscope,
65
90
  }
66
- self.d_channels_imu = {**self.d_channels_accelerometer, **self.d_channels_gyroscope}
67
91
 
68
92
  self.sampling_frequency = 100
69
93
  self.resampling_frequency = 100
94
+ self.tolerance = 3 * 1 / self.sampling_frequency
70
95
  self.lower_cutoff_frequency = 0.2
71
96
  self.upper_cutoff_frequency = 3.5
72
97
  self.filter_order = 4
73
98
 
99
+ # Segmentation parameters for handling non-contiguous data
100
+ self.max_segment_gap_s = 1.5
101
+ self.min_segment_length_s = 1.5
102
+
74
103
 
75
104
  class PPGConfig(BaseConfig):
76
105
 
77
- def __init__(self) -> None:
106
+ def __init__(self, column_mapping: dict[str, str] | None = None) -> None:
78
107
  super().__init__()
79
108
 
80
- self.set_filenames('PPG')
109
+ self.set_filenames("PPG")
110
+
111
+ # Generate a default mapping or override with user-provided mapping
112
+ default_mapping = asdict(DataColumns())
113
+ self.column_mapping = {**default_mapping, **(column_mapping or {})}
81
114
 
82
- self.ppg_colname = DataColumns.PPG
115
+ self.time_colname = self.column_mapping["TIME"]
116
+ self.ppg_colname = self.column_mapping["PPG"]
83
117
 
84
118
  self.sampling_frequency = 30
119
+ self.resampling_frequency = 30
120
+ self.tolerance = 3 * 1 / self.sampling_frequency
85
121
  self.lower_cutoff_frequency = 0.4
86
122
  self.upper_cutoff_frequency = 3.5
87
123
  self.filter_order = 4
88
124
 
89
- self.d_channels_ppg = {
90
- DataColumns.PPG: DataUnits.NONE
91
- }
125
+ self.d_channels_ppg = {self.ppg_colname: DataUnits.NONE}
92
126
 
93
127
 
94
128
  # Domain base configs
95
129
  class GaitConfig(IMUConfig):
96
130
 
97
- def __init__(self, step) -> None:
98
- super().__init__()
131
+ def __init__(self, step, column_mapping: dict[str, str] | None = None) -> None:
132
+ # Pass column_mapping through to IMUConfig
133
+ super().__init__(column_mapping=column_mapping)
99
134
 
100
- self.set_sensor('accelerometer')
135
+ self.set_sensor("accelerometer")
101
136
 
102
137
  # ----------
103
138
  # Segmenting
@@ -105,7 +140,7 @@ class GaitConfig(IMUConfig):
105
140
  self.max_segment_gap_s = 1.5
106
141
  self.min_segment_length_s = 1.5
107
142
 
108
- if step == 'gait':
143
+ if step == "gait":
109
144
  self.window_length_s: float = 6
110
145
  self.window_step_length_s: float = 1
111
146
  else:
@@ -120,7 +155,7 @@ class GaitConfig(IMUConfig):
120
155
  self.spectrum_high_frequency: int = int(self.sampling_frequency / 2)
121
156
 
122
157
  # Power in specified frequency bands
123
- self.d_frequency_bandwidths: Dict[str, List[float]] = {
158
+ self.d_frequency_bandwidths: dict[str, list[float]] = {
124
159
  "power_below_gait": [0.2, 0.7],
125
160
  "power_gait": [0.7, 3.5],
126
161
  "power_tremor": [3.5, 8],
@@ -140,7 +175,7 @@ class GaitConfig(IMUConfig):
140
175
  # -----------------
141
176
  # TSDF data storage
142
177
  # -----------------
143
- self.d_channels_values: Dict[str, str] = {
178
+ self.d_channels_values: dict[str, str] = {
144
179
  "accelerometer_std_norm": DataUnits.GRAVITY,
145
180
  "accelerometer_x_grav_mean": DataUnits.GRAVITY,
146
181
  "accelerometer_y_grav_mean": DataUnits.GRAVITY,
@@ -148,29 +183,33 @@ class GaitConfig(IMUConfig):
148
183
  "accelerometer_x_grav_std": DataUnits.GRAVITY,
149
184
  "accelerometer_y_grav_std": DataUnits.GRAVITY,
150
185
  "accelerometer_z_grav_std": DataUnits.GRAVITY,
151
- "accelerometer_x_power_below_gait": DataUnits.POWER_SPECTRAL_DENSITY,
152
- "accelerometer_y_power_below_gait": DataUnits.POWER_SPECTRAL_DENSITY,
153
- "accelerometer_z_power_below_gait": DataUnits.POWER_SPECTRAL_DENSITY,
154
- "accelerometer_x_power_gait": DataUnits.POWER_SPECTRAL_DENSITY,
155
- "accelerometer_y_power_gait": DataUnits.POWER_SPECTRAL_DENSITY,
156
- "accelerometer_z_power_gait": DataUnits.POWER_SPECTRAL_DENSITY,
157
- "accelerometer_x_power_tremor": DataUnits.POWER_SPECTRAL_DENSITY,
158
- "accelerometer_y_power_tremor": DataUnits.POWER_SPECTRAL_DENSITY,
159
- "accelerometer_z_power_tremor": DataUnits.POWER_SPECTRAL_DENSITY,
160
- "accelerometer_x_power_above_tremor": DataUnits.POWER_SPECTRAL_DENSITY,
161
- "accelerometer_y_power_above_tremor": DataUnits.POWER_SPECTRAL_DENSITY,
162
- "accelerometer_z_power_above_tremor": DataUnits.POWER_SPECTRAL_DENSITY,
186
+ "accelerometer_x_power_below_gait": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
187
+ "accelerometer_y_power_below_gait": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
188
+ "accelerometer_z_power_below_gait": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
189
+ "accelerometer_x_power_gait": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
190
+ "accelerometer_y_power_gait": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
191
+ "accelerometer_z_power_gait": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
192
+ "accelerometer_x_power_tremor": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
193
+ "accelerometer_y_power_tremor": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
194
+ "accelerometer_z_power_tremor": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
195
+ "accelerometer_x_power_above_tremor": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
196
+ "accelerometer_y_power_above_tremor": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
197
+ "accelerometer_z_power_above_tremor": DataUnits.POWER_SPECTRAL_DENSITY_ACC,
163
198
  "accelerometer_x_dominant_frequency": DataUnits.FREQUENCY,
164
199
  "accelerometer_y_dominant_frequency": DataUnits.FREQUENCY,
165
200
  "accelerometer_z_dominant_frequency": DataUnits.FREQUENCY,
166
201
  }
167
202
 
168
203
  for mfcc_coef in range(1, self.mfcc_n_coefficients + 1):
169
- self.d_channels_values[f"accelerometer_mfcc_{mfcc_coef}"] = DataUnits.GRAVITY
204
+ self.d_channels_values[f"accelerometer_mfcc_{mfcc_coef}"] = (
205
+ DataUnits.GRAVITY
206
+ )
170
207
 
171
- if step == 'arm_activity':
208
+ if step == "arm_activity":
172
209
  for mfcc_coef in range(1, self.mfcc_n_coefficients + 1):
173
- self.d_channels_values[f"gyroscope_mfcc_{mfcc_coef}"] = DataUnits.GRAVITY
210
+ self.d_channels_values[f"gyroscope_mfcc_{mfcc_coef}"] = (
211
+ DataUnits.GRAVITY
212
+ )
174
213
 
175
214
 
176
215
  class TremorConfig(IMUConfig):
@@ -184,7 +223,7 @@ class TremorConfig(IMUConfig):
184
223
  """
185
224
  super().__init__()
186
225
 
187
- self.set_sensor('gyroscope')
226
+ self.set_sensor("gyroscope")
188
227
 
189
228
  # ----------
190
229
  # Segmenting
@@ -195,12 +234,12 @@ class TremorConfig(IMUConfig):
195
234
  # -----------------
196
235
  # Feature extraction
197
236
  # -----------------
198
- self.window_type = 'hann'
237
+ self.window_type = "hann"
199
238
  self.overlap_fraction: float = 0.8
200
239
  self.segment_length_psd_s: float = 3
201
240
  self.segment_length_spectrogram_s: float = 2
202
241
  self.spectral_resolution: float = 0.25
203
-
242
+
204
243
  # PSD based features
205
244
  self.fmin_peak_search: float = 1
206
245
  self.fmax_peak_search: float = 25
@@ -223,95 +262,120 @@ class TremorConfig(IMUConfig):
223
262
  # -----------
224
263
  # Aggregation
225
264
  # -----------
226
- self.aggregates_tremor_power: List[str] = ['mode_binned', 'median', '90p']
227
- self.evaluation_points_tremor_power: np.ndarray = np.linspace(0, 6, 301)
265
+ self.aggregates_tremor_power: list[str] = ["mode_binned", "median", "90p"]
266
+ self.evaluation_points_tremor_power: np.ndarray = np.linspace(0, 6, 301)
228
267
 
229
268
  # -----------------
230
269
  # TSDF data storage
231
270
  # -----------------
232
- if step == 'features':
233
- self.d_channels_values: Dict[str, str] = {}
271
+ if step == "features":
272
+ self.d_channels_values: dict[str, str] = {}
234
273
  for mfcc_coef in range(1, self.n_coefficients_mfcc + 1):
235
- self.d_channels_values[f"mfcc_{mfcc_coef}"] = "unitless"
274
+ self.d_channels_values[f"mfcc_{mfcc_coef}"] = DataUnits.NONE
275
+
276
+ self.d_channels_values[DataColumns.FREQ_PEAK] = DataUnits.FREQUENCY
277
+ self.d_channels_values[DataColumns.BELOW_TREMOR_POWER] = (
278
+ DataUnits.POWER_ROTATION
279
+ )
280
+ self.d_channels_values[DataColumns.TREMOR_POWER] = DataUnits.POWER_ROTATION
236
281
 
237
- self.d_channels_values["freq_peak"] = "Hz"
238
- self.d_channels_values["below_tremor_power"] = "(deg/s)^2"
239
- self.d_channels_values["tremor_power"] = "(deg/s)^2"
240
- elif step == 'classification':
282
+ elif step == "classification":
241
283
  self.d_channels_values = {
242
284
  DataColumns.PRED_TREMOR_PROBA: "probability",
243
285
  DataColumns.PRED_TREMOR_LOGREG: "boolean",
244
286
  DataColumns.PRED_TREMOR_CHECKED: "boolean",
245
- DataColumns.PRED_ARM_AT_REST: "boolean"
287
+ DataColumns.PRED_ARM_AT_REST: "boolean",
246
288
  }
247
289
 
248
-
290
+
249
291
  class PulseRateConfig(PPGConfig):
250
- def __init__(self, sensor: str = 'ppg', min_window_length_s: int = 30) -> None:
292
+ def __init__(
293
+ self,
294
+ sensor: str = "ppg",
295
+ ppg_sampling_frequency: int = 30,
296
+ imu_sampling_frequency: int | None = None,
297
+ min_window_length_s: int = 30,
298
+ accelerometer_colnames: list[str] | None = None,
299
+ ) -> None:
251
300
  super().__init__()
252
301
 
253
- # ----------
254
- # Segmenting
255
- # ----------
302
+ self.ppg_sampling_frequency = ppg_sampling_frequency
303
+
304
+ if sensor == "imu":
305
+ if imu_sampling_frequency is not None:
306
+ self.imu_sampling_frequency = imu_sampling_frequency
307
+ else:
308
+ self.imu_sampling_frequency = IMUConfig().sampling_frequency
309
+ warnings.warn(
310
+ f"imu_sampling_frequency not provided, using default "
311
+ f"of {self.imu_sampling_frequency} Hz"
312
+ )
313
+
314
+ # Windowing parameters
256
315
  self.window_length_s: int = 6
257
316
  self.window_step_length_s: int = 1
258
317
  self.window_overlap_s = self.window_length_s - self.window_step_length_s
259
318
 
260
- self.accelerometer_cols = IMUConfig().accelerometer_cols
319
+ self.accelerometer_colnames = accelerometer_colnames
261
320
 
262
- # -----------------------
263
- # Signal quality analysis
264
- # -----------------------
265
- self.freq_band_physio = [0.75, 3] # Hz
266
- self.bandwidth = 0.2 # Hz
267
- self.freq_bin_resolution = 0.05 # Hz
321
+ # Signal quality analysis parameters
322
+ self.freq_band_physio = [0.75, 3] # Hz
323
+ self.bandwidth = 0.2 # Hz
324
+ self.freq_bin_resolution = 0.05 # Hz
268
325
 
269
- # ---------------------
270
- # Pulse rate estimation
271
- # ---------------------
272
- self.set_tfd_length(min_window_length_s) # Set tfd length to default of 30 seconds
326
+ # Pulse rate estimation parameters
273
327
  self.threshold_sqa = 0.5
274
328
  self.threshold_sqa_accelerometer = 0.10
275
329
 
330
+ # Set initial sensor and update sampling-dependent params
331
+ self.set_sensor(sensor, min_window_length_s)
332
+
333
+ def set_sensor(self, sensor: str, min_window_length_s: int | None = None) -> None:
334
+ """Sets the active sensor and recomputes sampling-dependent parameters."""
335
+ if sensor not in ["ppg", "imu"]:
336
+ raise ValueError(f"Invalid sensor type: {sensor}")
337
+ self.sensor = sensor
338
+
339
+ # Decide which frequency to use
340
+ self.sampling_frequency = (
341
+ self.imu_sampling_frequency
342
+ if sensor == "imu"
343
+ else self.ppg_sampling_frequency
344
+ )
345
+
346
+ # Update all frequency-dependent parameters
347
+ if min_window_length_s is not None:
348
+ self._update_sampling_dependent_params(min_window_length_s)
349
+ else:
350
+ # Reuse previous tfd_length if it exists, else fallback to 30
351
+ self._update_sampling_dependent_params(getattr(self, "tfd_length", 30))
352
+
353
+ def _update_sampling_dependent_params(self, tfd_length: int):
354
+ """Compute attributes that depend on sampling frequency."""
355
+
356
+ # --- PPG-dependent parameters ---
357
+ self.tfd_length = tfd_length
358
+ self.min_pr_samples = int(round(self.tfd_length * self.ppg_sampling_frequency))
359
+
276
360
  pr_est_length = 2 # pulse rate estimation length in seconds
277
- self.pr_est_samples = pr_est_length * self.sampling_frequency
361
+ self.pr_est_samples = pr_est_length * self.ppg_sampling_frequency
278
362
 
279
363
  # Time-frequency distribution parameters
280
- self.kern_type = 'sep'
281
- win_type_doppler = 'hamm'
282
- win_type_lag = 'hamm'
364
+ win_type_doppler = "hamm"
365
+ win_type_lag = "hamm"
283
366
  win_length_doppler = 8
284
367
  win_length_lag = 1
285
- doppler_samples = self.sampling_frequency * win_length_doppler
286
- lag_samples = win_length_lag * self.sampling_frequency
368
+ doppler_samples = self.ppg_sampling_frequency * win_length_doppler
369
+ lag_samples = win_length_lag * self.ppg_sampling_frequency
370
+ self.kern_type = "sep"
287
371
  self.kern_params = {
288
- 'doppler': {
289
- 'win_length': doppler_samples,
290
- 'win_type': win_type_doppler,
291
- },
292
- 'lag': {
293
- 'win_length': lag_samples,
294
- 'win_type': win_type_lag,
295
- }
372
+ "doppler": {"win_length": doppler_samples, "win_type": win_type_doppler},
373
+ "lag": {"win_length": lag_samples, "win_type": win_type_lag},
296
374
  }
297
375
 
298
- self.set_sensor(sensor)
299
-
300
- def set_tfd_length(self, tfd_length: int):
301
- self.tfd_length = tfd_length
302
- self.min_pr_samples = int(round(self.tfd_length * self.sampling_frequency))
303
-
304
- def set_sensor(self, sensor):
305
- self.sensor = sensor
306
-
307
- if sensor not in ['ppg', 'imu']:
308
- raise ValueError(f"Invalid sensor type: {sensor}")
309
-
310
- if sensor == 'imu':
311
- self.sampling_frequency = IMUConfig().sampling_frequency
312
- else:
313
- self.sampling_frequency = PPGConfig().sampling_frequency
314
-
376
+ # --- Welch / FFT parameters based on current sensor frequency ---
315
377
  self.window_length_welch = 3 * self.sampling_frequency
316
378
  self.overlap_welch_window = self.window_length_welch // 2
317
- self.nfft = len(np.arange(0, self.sampling_frequency/2, self.freq_bin_resolution))*2
379
+ self.nfft = (
380
+ len(np.arange(0, self.sampling_frequency / 2, self.freq_bin_resolution)) * 2
381
+ )