emgio 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emgio/__init__.py +8 -0
- emgio/analysis/__init__.py +16 -0
- emgio/analysis/signal.py +345 -0
- emgio/analysis/verification.py +205 -0
- emgio/core/__init__.py +0 -0
- emgio/core/emg.py +485 -0
- emgio/exporters/__init__.py +0 -0
- emgio/exporters/edf.py +650 -0
- emgio/importers/__init__.py +0 -0
- emgio/importers/base.py +20 -0
- emgio/importers/csv.py +440 -0
- emgio/importers/edf.py +171 -0
- emgio/importers/eeglab.py +298 -0
- emgio/importers/otb.py +309 -0
- emgio/importers/trigno.py +134 -0
- emgio/importers/wfdb.py +152 -0
- emgio/tests/__init__.py +0 -0
- emgio/tests/test_core.py +711 -0
- emgio/tests/test_eeglab_importer.py +244 -0
- emgio/tests/test_exporters.py +905 -0
- emgio/tests/test_importer_wfdb.py +149 -0
- emgio/tests/test_importers.py +474 -0
- emgio/tests/test_verification.py +356 -0
- emgio/tests/test_visualization.py +306 -0
- emgio/utils/__init__.py +0 -0
- emgio/version.py +14 -0
- emgio/visualization/__init__.py +6 -0
- emgio/visualization/static.py +321 -0
- emgio-0.2.0.dist-info/METADATA +228 -0
- emgio-0.2.0.dist-info/RECORD +33 -0
- emgio-0.2.0.dist-info/WHEEL +5 -0
- emgio-0.2.0.dist-info/licenses/LICENSE +29 -0
- emgio-0.2.0.dist-info/top_level.txt +1 -0
emgio/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""EMGIO: A Python package for EMG data import/export and manipulation."""
|
|
2
|
+
|
|
3
|
+
from .core.emg import EMG
|
|
4
|
+
from .exporters.edf import EDFExporter
|
|
5
|
+
from .importers.trigno import TrignoImporter
|
|
6
|
+
from .version import __version__, __version_info__
|
|
7
|
+
|
|
8
|
+
__all__ = [
    "EMG",
    "TrignoImporter",
    "EDFExporter",
    "__version__",
    "__version_info__",
]
|
emgio/analysis/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Signal analysis module
|
|
2
|
+
|
|
3
|
+
from .signal import (
|
|
4
|
+
analyze_signal,
|
|
5
|
+
determine_format_suitability,
|
|
6
|
+
quantization_analysis,
|
|
7
|
+
# Add other signal analysis functions if needed
|
|
8
|
+
)
|
|
9
|
+
from .verification import compare_signals
|
|
10
|
+
|
|
11
|
+
# Public API of the analysis subpackage (re-exported from .signal and .verification).
__all__ = [
    "analyze_signal",
    "determine_format_suitability",
    "quantization_analysis",
    "compare_signals",
]
|
emgio/analysis/signal.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Signal analysis functions for EMG data.
|
|
3
|
+
|
|
4
|
+
This module provides functions for analyzing EMG signals, including noise floor estimation,
|
|
5
|
+
dynamic range calculation, and format suitability determination.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# SVD-based analysis functions
|
|
12
|
+
def find_elbow_point(singular_values: np.ndarray) -> int:
    """
    Find the elbow point in singular values using the second derivative method.

    The singular values are turned into a normalized cumulative-energy curve,
    and the position of the largest absolute second difference of that curve
    is taken as the elbow.

    Args:
        singular_values: 1-D array of singular values from SVD (expected in the
            descending order returned by ``np.linalg.svd``).

    Returns:
        int: Number of leading components to treat as signal, clamped to
        ``[1, len(singular_values) - 1]``. Inputs too short for a second
        difference (fewer than 3 values) or with zero/non-finite total energy
        return 1.
    """
    singular_values = np.asarray(singular_values, dtype=float)
    n_values = singular_values.size

    # A second difference needs at least 3 points; for shorter inputs the
    # final clamp would reduce any answer to 1 anyway.
    if n_values < 3:
        return 1

    energy = np.cumsum(singular_values**2)
    total_energy = energy[-1]
    # An all-zero (or NaN/inf) spectrum has no meaningful elbow; avoid 0/0.
    if not np.isfinite(total_energy) or total_energy <= 0:
        return 1
    cumulative_energy = energy / total_energy

    first_derivative = np.diff(cumulative_energy)
    second_derivative = np.diff(first_derivative)

    # second_derivative[i] is centred on cumulative_energy[i + 1]; adding 2
    # turns that position into a count of leading components (0 .. i + 1).
    elbow_idx = int(np.argmax(np.abs(second_derivative))) + 2

    # Keep at least one signal component and at least one noise component.
    return max(1, min(elbow_idx, n_values - 1))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def analyze_signal_svd(detrended: np.ndarray, svd_rank: int = None) -> float:
    """
    Estimate noise floor using SVD-based method with automatic elbow detection.

    The signal is embedded in an ``m x k`` Hankel (time-delay) matrix with
    ``m = min(int(sqrt(n)), n // 3)`` and ``k = n - m + 1``. Singular values
    beyond the signal rank are treated as noise.

    Args:
        detrended: Detrended (zero-mean) 1-D signal array.
        svd_rank: Optional manual rank cutoff for signal/noise separation. When
            omitted, the rank is the number of components needed to reach 99.5%
            of the total energy, clamped to ``[1, len(S) // 2]``.

    Returns:
        float: Estimated noise floor. Signals shorter than 10 samples use the
        first-difference estimate ``std(diff(x)) / sqrt(2)``; signals shorter
        than 2 samples return machine epsilon.
    """
    detrended = np.asarray(detrended, dtype=float)
    n = detrended.size
    eps = np.finfo(float).eps

    # With fewer than two samples no difference exists; std([]) would be NaN.
    if n < 2:
        return eps
    if n < 10:  # For very short signals, use the simpler first-difference method
        return np.std(np.diff(detrended)) / np.sqrt(2)

    # Embedding dimension: rule of thumb sqrt(n), but never more than n // 3.
    m = min(int(np.sqrt(n)), n // 3)
    k = n - m + 1

    # Hankel matrix: hankel[i, j] == detrended[i + j].
    hankel = np.empty((m, k))
    for row in range(m):
        hankel[row, :] = detrended[row : row + k]

    # Only the singular values are needed. Skipping U/Vh avoids allocating a
    # second m x k array (Vh), which dominates memory for long recordings.
    S = np.linalg.svd(hankel, compute_uv=False)

    if svd_rank is None:
        # Signal rank = components needed to capture 99.5% of the energy.
        cumulative_energy = np.cumsum(S**2) / np.sum(S**2)
        energy_threshold = 0.995
        signal_indices = np.where(cumulative_energy >= energy_threshold)[0]
        if len(signal_indices) > 0:
            svd_rank = signal_indices[0] + 1  # include the threshold-crossing component
        else:
            # Only reachable when the energy is NaN (e.g. an all-zero input).
            svd_rank = find_elbow_point(S)

        # Keep at least one signal component and at most half of the spectrum.
        svd_rank = max(1, min(svd_rank, len(S) // 2))

    noise_singular_values = S[svd_rank:]

    # No (numerically non-zero) noise components: return a tiny fraction of
    # the smallest singular value.
    if len(noise_singular_values) == 0 or np.all(noise_singular_values < eps * 1e3):
        return S[-1] * 1e-8 if len(S) > 0 else eps

    # Median of the noise singular values (more robust than the mean).
    # NOTE(review): for white noise of std sigma the noise singular values of an
    # m x k Hankel matrix are roughly sigma * sqrt(k), so dividing by sqrt(m)
    # gives about sigma * sqrt(k / m) rather than sigma. Downstream thresholds
    # may be tuned to this scale -- confirm before changing it.
    noise_floor = np.median(noise_singular_values) / np.sqrt(m)

    # Very small estimates are floored at 1e-6 of the peak-to-peak range
    # (i.e. at most ~120 dB dynamic range).
    if noise_floor < eps * 1e3:
        signal_range = np.max(detrended) - np.min(detrended)
        noise_floor = max(noise_floor, signal_range * 1e-6)

    return noise_floor
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# FFT-based analysis functions
|
|
112
|
+
def analyze_signal_fft(detrended: np.ndarray, fft_noise_range: tuple = None) -> float:
    """
    Estimate noise floor using enhanced FFT-based method.

    A Blackman window is applied and the power spectrum ``|rfft(x)|**2`` is
    computed. The noise floor is the square root of the median power, taken
    either over a user-specified band or, by default, over the quietest 10% of
    frequency bins. In both cases the result is floored at ``1e-6`` times the
    peak-to-peak range of the input (i.e. at most ~120 dB dynamic range).

    Args:
        detrended: Detrended (zero-mean) 1-D signal array.
        fft_noise_range: Optional ``(min_freq, max_freq)`` band to treat as
            noise. Frequencies are normalized (cycles per sample, 0 to 0.5)
            because ``np.fft.rfftfreq`` is called without a sample spacing. If
            no frequency bin falls inside the band, the default
            quietest-bins method is used instead.

    Returns:
        float: Estimated noise floor.
        NOTE(review): the power spectrum is not normalized by the signal length
        or window energy, so for white noise this estimate grows with
        sqrt(len(detrended)) and is not on the same scale as the SVD or
        first-difference estimates -- confirm before comparing them.

    Raises:
        ValueError: From ``np.fft.rfft`` when ``detrended`` is empty.
    """
    n = len(detrended)
    # Blackman window reduces spectral leakage from strong components.
    windowed = detrended * np.blackman(n)
    power = np.abs(np.fft.rfft(windowed)) ** 2

    noise_power = None
    if fft_noise_range is not None:
        min_freq, max_freq = fft_noise_range
        freq = np.fft.rfftfreq(n)
        band_power = power[(freq >= min_freq) & (freq <= max_freq)]
        if band_power.size > 0:
            noise_power = band_power

    if noise_power is None:
        # Adaptive method: the quietest 10% of bins (at least one) are noise.
        n_noise_bins = max(1, int(power.size * 0.1))
        noise_power = np.sort(power)[:n_noise_bins]

    noise_floor = np.sqrt(np.median(noise_power))

    # Same lower bound for both paths so an (almost) empty band cannot produce
    # an unrealistically small floor.
    signal_range = np.max(detrended) - np.min(detrended)
    min_noise_floor = signal_range * 1e-6
    return max(noise_floor, min_noise_floor)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
# High-level analysis functions
|
|
167
|
+
def _statistical_noise_floor(detrended: np.ndarray) -> float:
    """Return the first-difference noise estimate ``std(diff(x)) / sqrt(2)``."""
    return np.std(np.diff(detrended)) / np.sqrt(2)


def _estimate_noise_floor(
    detrended: np.ndarray, method: str, fft_noise_range: tuple, svd_rank: int
) -> tuple:
    """Run the requested noise estimator(s) with fallbacks; return ``(noise_floor, method_label)``."""
    if method.lower() == "both":
        # Try SVD first; if it raises, try FFT, then the statistical estimate.
        try:
            noise_floor_svd = analyze_signal_svd(detrended, svd_rank)
        except Exception:
            try:
                return analyze_signal_fft(detrended, fft_noise_range), "fft (fallback)"
            except Exception:
                return _statistical_noise_floor(detrended), "statistical (fallback)"
        try:
            noise_floor_fft = analyze_signal_fft(detrended, fft_noise_range)
        except Exception:
            # FFT failed but SVD worked: keep the SVD result.
            return noise_floor_svd, "svd (fallback)"
        # Take the smaller estimate to favour preserving high dynamic range.
        # NOTE(review): the two estimators are not on the same amplitude scale;
        # confirm that min() is the intended combination.
        return min(noise_floor_svd, noise_floor_fft), "both (min)"

    try:
        if method.lower() == "svd":
            noise_floor = analyze_signal_svd(detrended, svd_rank)
        else:
            noise_floor = analyze_signal_fft(detrended, fft_noise_range)
    except Exception:
        # Numerical failure of the chosen estimator: fall back to statistics.
        return (
            _statistical_noise_floor(detrended),
            f"{method} failed, using statistical (fallback)",
        )
    return noise_floor, method


def analyze_signal(
    signal: np.ndarray, method: str = "svd", fft_noise_range: tuple = None, svd_rank: int = None
) -> dict:
    """
    Analyze signal characteristics including noise floor and dynamic range.

    Args:
        signal: Input signal array.
        method: Method for noise floor estimation: 'svd' (default), 'fft', or
            'both' (case-insensitive). 'both' takes the smaller of the two
            estimates. If an estimator raises, a first-difference statistical
            estimate is used and recorded in the returned ``method`` label.
        fft_noise_range: Optional tuple (min_freq, max_freq) for the FFT method,
            in normalized frequency (cycles per sample).
        svd_rank: Optional rank cutoff for the SVD method.

    Returns:
        dict: Analysis results. An all-(near-)zero signal returns only
        ``range``, ``noise_floor``, ``dynamic_range_db`` and ``is_zero=True``.
        Otherwise the dict also contains ``snr_db`` and ``method`` (the label of
        the estimator actually used). A constant non-zero signal has zero range
        and reports 0 dB dynamic range and SNR. Dynamic range and SNR are capped
        at 140 dB.

    Raises:
        ValueError: If ``method`` is not 'svd', 'fft' or 'both'.
    """
    # Validate before any fallback logic so a typo is not silently swallowed.
    if method.lower() not in ("svd", "fft", "both"):
        raise ValueError(f"Unknown method: {method}. Use 'svd', 'fft', or 'both'.")

    eps = np.finfo(float).eps
    signal = np.asarray(signal, dtype=float)

    # Handle zero signal case
    if np.allclose(signal, 0):
        return {
            "range": 0.0,
            "noise_floor": eps,
            "dynamic_range_db": 0.0,
            "is_zero": True,
        }

    # Remove DC offset for better analysis
    detrended = signal - np.mean(signal)
    signal_range = np.max(detrended) - np.min(detrended)

    # A constant (but non-zero) signal has no variation to measure; without this
    # guard both dB values would be log10(0) = -inf with divide-by-zero warnings.
    if signal_range == 0:
        return {
            "range": 0.0,
            "noise_floor": eps,
            "dynamic_range_db": 0.0,
            "snr_db": 0.0,
            "is_zero": False,
            "method": method,
        }

    noise_floor, method = _estimate_noise_floor(detrended, method, fft_noise_range, svd_rank)
    noise_floor = max(noise_floor, eps)

    dynamic_range_db = 20 * np.log10(signal_range / noise_floor)

    # Cap at the BDF (24-bit) practical limit. Values up to 140 dB are kept so
    # high-dynamic-range signals still select BDF; the noise floor is adjusted
    # to stay consistent with the capped value.
    max_realistic_dr = 140
    if dynamic_range_db > max_realistic_dr:
        noise_floor = signal_range / (10 ** (max_realistic_dr / 20))
        dynamic_range_db = max_realistic_dr

    snr_db = 20 * np.log10(np.std(signal) / noise_floor)
    max_realistic_snr = 140
    if snr_db > max_realistic_snr:
        snr_db = max_realistic_snr

    return {
        "range": signal_range,
        "noise_floor": noise_floor,
        "dynamic_range_db": dynamic_range_db,
        "snr_db": snr_db,
        "is_zero": False,
        "method": method,
    }
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# Format-related functions
|
|
275
|
+
def determine_format_suitability(signal: np.ndarray, analysis: dict) -> tuple:
    """
    Decide whether a signal should be stored as EDF (16-bit) or BDF (24-bit).

    Only ``analysis`` drives the decision; ``signal`` is accepted for interface
    compatibility and is not read.

    Args:
        signal: Input signal array (unused).
        analysis: Result of ``analyze_signal()``. Reads ``is_zero`` (optional),
            ``dynamic_range_db`` (required unless ``is_zero``) and ``snr_db``
            (optional, defaults to 0).

    Returns:
        tuple: ``(use_bdf, reason, snr_db)``. EDF is chosen when the dynamic
        range is at most 87 dB (90 dB minus a 3 dB margin). Otherwise BDF is
        chosen, and the reason warns when the dynamic range also exceeds 137 dB
        (140 dB minus the margin).
    """
    if analysis.get("is_zero", False):
        return False, "Zero signal, using EDF format", 0.0

    # Practical format limits (slightly below the 96 / 144 dB theoretical values)
    # minus a small safety margin.
    margin_db = 3
    edf_limit_db = 90
    bdf_limit_db = 140

    dynamic_range = analysis["dynamic_range_db"]
    snr = analysis.get("snr_db", 0)

    if dynamic_range <= edf_limit_db - margin_db:
        return False, f"EDF dynamic range ({edf_limit_db} dB) is sufficient", snr

    fits_bdf = dynamic_range <= bdf_limit_db - margin_db
    reason = (
        f"Signal requires BDF format (DR: {dynamic_range:.1f} dB)"
        if fits_bdf
        else f"Signal may require higher resolution than BDF (DR: {dynamic_range:.1f} dB)"
    )
    return True, reason, snr
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def quantization_analysis(signal: np.ndarray, bits: int) -> dict:
    """
    Perform detailed quantization error analysis.

    The signal's peak-to-peak range is divided into ``2**bits`` steps and every
    sample is rounded to the nearest multiple of the step size.

    Args:
        signal: Input signal array.
        bits: Number of bits (16 for EDF, 24 for BDF).

    Returns:
        dict: ``step_size``, ``max_error`` (max absolute error), ``rmse`` and
        ``snr`` (dB, signal power over quantization-noise power, with the
        noise power floored at machine epsilon). A constant signal has a step
        size of 0 and is treated as exactly representable (zero error). An
        all-zero signal has ``snr == -inf``.
    """
    signal = np.asarray(signal)
    eps = np.finfo(float).eps

    signal_range = np.max(signal) - np.min(signal)
    step_size = signal_range / (2**bits)

    if step_size > 0:
        quantized = np.round(signal / step_size) * step_size
    else:
        # Constant signal: a zero step would give 0/0 = NaN everywhere; the
        # single value is representable exactly, so there is no error.
        quantized = signal.copy()

    error = signal - quantized
    abs_error = np.abs(error)
    noise_power = np.mean(error**2)
    rmse = np.sqrt(noise_power)

    signal_power = np.mean(signal**2)
    if noise_power < eps:
        noise_power = eps
    # A zero-power signal legitimately gives -inf dB; silence the log10(0) warning.
    with np.errstate(divide="ignore"):
        snr = 10 * np.log10(signal_power / noise_power)

    return {"step_size": step_size, "max_error": np.max(abs_error), "rmse": rmse, "snr": snr}
|
emgio/analysis/verification.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Functions for verifying signal integrity after operations like export/import.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from typing import TYPE_CHECKING, Dict, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
# Use TYPE_CHECKING to avoid circular import at runtime
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from ..core.emg import EMG
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compare_signals(
    emg_original: "EMG",
    emg_reloaded: "EMG",
    tolerance: float = 0.01,  # Default tolerance 1% for NRMSE and Max Norm Abs Diff
    channel_map: Optional[Dict[str, str]] = None,
) -> dict:
    """
    Compare signals between two EMG objects using normalized metrics.
    Returns a dictionary with comparison results per channel and a summary.
    Does NOT perform logging/printing.

    Args:
        emg_original: The original EMG object before export.
        emg_reloaded: The EMG object reloaded from the exported file.
        tolerance: Relative tolerance for comparisons (default: 0.01 or 1%).
            A channel pair is "identical" when both NRMSE and the max
            normalized absolute difference are below it.
        channel_map: Optional dictionary mapping original channel names (keys)
            to reloaded channel names (values). If None, tries exact name
            match first, then falls back to matching by column position.

    Returns:
        dict: ``'channel_summary'`` (comparison mode and unmatched channels,
        listed in column order) followed by one entry per compared original
        channel with ``reloaded_channel``, ``original_range``, ``nrmse``,
        ``max_norm_abs_diff`` and ``is_identical``. Metrics are normalized by
        the original channel's peak-to-peak range (1.0 for constant channels).
        Signals of different lengths are compared over their common prefix.

    Raises:
        ValueError: If ``channel_map`` names an original channel that is not
            present in ``emg_original``.
    """
    # Keep column order (de-duplicated) so pairing and reports are deterministic.
    original_order = list(dict.fromkeys(emg_original.signals.columns))
    reloaded_order = list(dict.fromkeys(emg_reloaded.signals.columns))
    original_channels = set(original_order)
    reloaded_channels = set(reloaded_order)

    channel_summary = {
        "comparison_mode": "unknown",
        "unmatched_original": [],
        "unmatched_reloaded": [],
    }

    if channel_map is not None:
        channel_summary["comparison_mode"] = "mapped"
        missing_original = [ch for ch in channel_map if ch not in original_channels]
        if missing_original:
            raise ValueError(
                f"Channel map contains original channels not found in data: {missing_original}"
            )

        # Only mappings whose target exists in the reloaded data can be compared.
        channel_pairs = [
            (orig, mapped) for orig, mapped in channel_map.items() if mapped in reloaded_channels
        ]
        matched_original = {orig for orig, _ in channel_pairs}
        matched_reloaded = {mapped for _, mapped in channel_pairs}
        # Originals mapped to a missing target are unmatched too.
        channel_summary["unmatched_original"] = [
            ch for ch in original_order if ch not in matched_original
        ]
        channel_summary["unmatched_reloaded"] = [
            ch for ch in reloaded_order if ch not in matched_reloaded
        ]
    else:
        common_channels = [ch for ch in original_order if ch in reloaded_channels]
        if common_channels:
            channel_summary["comparison_mode"] = "exact_name"
            channel_pairs = [(ch, ch) for ch in common_channels]
            channel_summary["unmatched_original"] = [
                ch for ch in original_order if ch not in reloaded_channels
            ]
            channel_summary["unmatched_reloaded"] = [
                ch for ch in reloaded_order if ch not in original_channels
            ]
        else:
            # Pair channels by column position (not by sorted name, which would
            # mispair labels such as EMG1/EMG2/EMG10).
            channel_summary["comparison_mode"] = "order_based"
            n_pairs = min(len(original_order), len(reloaded_order))
            channel_pairs = list(zip(original_order[:n_pairs], reloaded_order[:n_pairs]))
            channel_summary["unmatched_original"] = original_order[n_pairs:]
            channel_summary["unmatched_reloaded"] = reloaded_order[n_pairs:]

    results = {"channel_summary": channel_summary}
    if not channel_pairs:
        return results

    eps = np.finfo(float).eps
    for orig_channel, reloaded_channel in channel_pairs:
        sig_orig = emg_original.signals[orig_channel].values
        sig_reloaded = emg_reloaded.signals[reloaded_channel].values

        # Compare only the overlapping prefix when lengths differ.
        if len(sig_orig) != len(sig_reloaded):
            n_common = min(len(sig_orig), len(sig_reloaded))
            sig_orig = sig_orig[:n_common]
            sig_reloaded = sig_reloaded[:n_common]

        # Normalize by the original peak-to-peak range; constant signals use 1.0.
        sig_orig_range = np.ptp(sig_orig)
        norm_factor = sig_orig_range if sig_orig_range > eps else 1.0

        diff = sig_orig - sig_reloaded
        rmse = np.sqrt(np.mean(diff**2))
        max_abs_diff = np.max(np.abs(diff))

        nrmse = rmse / (norm_factor + eps)
        max_norm_abs_diff = max_abs_diff / (norm_factor + eps)

        results[orig_channel] = {
            "reloaded_channel": reloaded_channel,
            "original_range": sig_orig_range,  # Store original range for context
            "nrmse": nrmse,
            "max_norm_abs_diff": max_norm_abs_diff,
            "is_identical": bool(nrmse < tolerance and max_norm_abs_diff < tolerance),
        }

    return results
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def report_verification_results(verification_results: dict, verify_tolerance: float) -> bool:
    """
    Logs a detailed report based on the results from compare_signals.

    Messages go through the root ``logging`` functions: per-channel matches at
    INFO, per-channel differences and the final verdict at CRITICAL, unmatched
    channels at WARNING, and a failure to compare anything at ERROR.

    Args:
        verification_results: The dictionary output from compare_signals.
        verify_tolerance: The tolerance used during comparison (for reporting).

    Returns:
        bool: True if at least one channel pair was compared and all compared
        pairs were identical within tolerance, False otherwise.
    """
    summary = verification_results.get("channel_summary", {})
    logging.info("--- Verification Report ---")
    logging.info(f"Comparison mode: {summary.get('comparison_mode', 'unknown')}")

    if summary.get("unmatched_original"):
        logging.warning(f"Unmatched original channels: {summary['unmatched_original']}")
    if summary.get("unmatched_reloaded"):
        logging.warning(f"Unmatched reloaded channels: {summary['unmatched_reloaded']}")

    compared_count = 0
    differing_count = 0
    for orig_channel, metrics in verification_results.items():
        if orig_channel == "channel_summary":
            continue
        compared_count += 1

        reloaded_channel = metrics["reloaded_channel"]
        if orig_channel != reloaded_channel:
            channel_label = f"'{orig_channel}' -> '{reloaded_channel}'"
        else:
            channel_label = f"'{orig_channel}'"

        if metrics["is_identical"]:
            logging.info(
                f"Channel {channel_label}: Signals are identical (within tolerance {verify_tolerance:.1e})."
            )
        else:
            differing_count += 1
            logging.critical(
                f"Channel {channel_label}: Signals differ "
                f"(nRMSE: {metrics['nrmse']:.2e}, "
                f"MaxNormDiff: {metrics['max_norm_abs_diff']:.2e})"
            )

    if compared_count == 0:
        # Nothing compared: report a failure rather than "differences in 0 pairs".
        logging.critical("No channels were actually compared.")
        logging.error("Verification failed: Could not compare channels.")
        all_identical = False
    elif differing_count == 0:
        logging.critical(
            f"Verification successful: All {compared_count} compared "
            "channel pairs are identical within tolerance."
        )
        all_identical = True
    else:
        logging.critical(
            f"Verification finished: Differences found in {differing_count} "
            f"of {compared_count} compared pairs."
        )
        all_identical = False

    logging.info("---------------------------")
    return all_identical
|
emgio/core/__init__.py
ADDED
|
File without changes
|