openecg 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openecg/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.1.0"
openecg/codec.py ADDED
@@ -0,0 +1,144 @@
1
+ """OpenECG token format codec — uint16 pack/unpack, frame expansion, ASCII render.
2
+
3
+ Spec: docs/superpowers/specs/2026-05-03-ecgcode-stage1-design.md §5
4
+ """
5
+
6
+ import numpy as np
7
+
8
+ from openecg import vocab
9
+
10
# Time quantum of the packed format: one length unit is 4 ms.
MS_PER_UNIT = 4
# Longest duration one packed token can carry: 255 units, i.e. 1020 ms.
MAX_LENGTH_MS = MS_PER_UNIT * 0xFF
12
+
13
+
14
def _split_long(sym: int, length_ms: int) -> list[tuple[int, int]]:
    """Break one event into pieces no longer than ``MAX_LENGTH_MS``.

    Returns ``(sym, piece_ms)`` tuples: zero or more full-size pieces followed
    by the remainder when it is positive. A non-positive ``length_ms`` yields
    an empty list.
    """
    pieces: list[tuple[int, int]] = []
    remaining = length_ms
    while remaining > MAX_LENGTH_MS:
        pieces.append((sym, MAX_LENGTH_MS))
        remaining -= MAX_LENGTH_MS
    if remaining > 0:
        pieces.append((sym, remaining))
    return pieces
22
+
23
+
24
def encode(events: list[tuple[int, int]]) -> np.ndarray:
    """Pack ``(symbol_id, length_ms)`` events into a uint16 array.

    Each output word holds the symbol in the high byte and the length, in
    ``MS_PER_UNIT`` units, in the low byte. Lengths are rounded to the nearest
    unit with a floor of one unit; events longer than one token can hold are
    emitted as several consecutive tokens of the same symbol.

    Raises:
        ValueError: if a length is not positive or a symbol is outside 0..255.
    """
    words: list[int] = []
    for symbol, duration_ms in events:
        if duration_ms <= 0:
            raise ValueError(f"Event length must be positive, got {duration_ms}")
        if not (0 <= symbol <= 255):
            raise ValueError(f"Invalid symbol_id {symbol}")
        # Snap to the unit grid, never below a single unit.
        n_units = max(1, round(duration_ms / MS_PER_UNIT))
        words.extend(
            (piece_sym << 8) | (piece_ms // MS_PER_UNIT)
            for piece_sym, piece_ms in _split_long(symbol, n_units * MS_PER_UNIT)
        )
    return np.array(words, dtype=np.uint16)
43
+
44
+
45
def decode(packed: np.ndarray) -> list[tuple[int, int]]:
    """Unpack a uint16 token array into ``(symbol_id, length_ms)`` events.

    Adjacent tokens carrying the same symbol are folded into one event, so an
    event that ``encode`` had to split across several tokens comes back as a
    single entry.
    """
    symbols = (packed >> 8).astype(np.uint8)
    lengths = (packed & 0xFF).astype(np.uint8)
    merged: list[tuple[int, int]] = []
    for raw_sym, raw_units in zip(symbols, lengths):
        sym = int(raw_sym)
        ms = int(raw_units) * MS_PER_UNIT
        if not merged or merged[-1][0] != sym:
            merged.append((sym, ms))
            continue
        # Same symbol as the previous token: extend that event instead.
        merged[-1] = (sym, merged[-1][1] + ms)
    return merged
62
+
63
+
64
+ def to_frames(
65
+ events: list[tuple[int, int]],
66
+ frame_ms: int = 20,
67
+ total_ms: int | None = None,
68
+ ) -> np.ndarray:
69
+ """Expand RLE events to per-frame symbol array.
70
+
71
+ Rule: each frame gets the symbol with maximum overlap, with `*` (pacer)
72
+ as priority override (any frame containing a spike -> pacer label).
73
+ """
74
+ if total_ms is None:
75
+ total_ms = sum(ms for _, ms in events)
76
+ n_frames = round(total_ms / frame_ms)
77
+ out = np.zeros(n_frames, dtype=np.uint8)
78
+
79
+ # Build (start_ms, end_ms, sym) intervals
80
+ intervals = []
81
+ cum = 0
82
+ for sym, ms in events:
83
+ intervals.append((cum, cum + ms, sym))
84
+ cum += ms
85
+
86
+ for f in range(n_frames):
87
+ f_start = f * frame_ms
88
+ f_end = f_start + frame_ms
89
+ overlap = {}
90
+ spike_present = False
91
+ for s_start, s_end, sym in intervals:
92
+ if s_end <= f_start:
93
+ continue
94
+ if s_start >= f_end:
95
+ break
96
+ ov = min(s_end, f_end) - max(s_start, f_start)
97
+ if ov > 0:
98
+ if sym == vocab.ID_PACER:
99
+ spike_present = True
100
+ else:
101
+ overlap[sym] = overlap.get(sym, 0) + ov
102
+ if spike_present:
103
+ out[f] = vocab.ID_PACER
104
+ elif overlap:
105
+ out[f] = max(overlap, key=overlap.get)
106
+ else:
107
+ out[f] = vocab.ID_ISO
108
+ return out
109
+
110
+
111
def render_compact(events: list[tuple[int, int]]) -> str:
    """Render events as a string with exactly one character per event."""
    glyphs = []
    for sym, _length_ms in events:
        glyphs.append(vocab.ID_TO_CHAR[sym])
    return "".join(glyphs)
114
+
115
+
116
def render_timed(events: list[tuple[int, int]], ms_per_char: int = 20) -> str:
    """Render events with a run of characters proportional to each duration.

    Every event contributes ``round(ms / ms_per_char)`` copies of its
    character, and never fewer than one.
    """
    return "".join(
        max(1, round(ms / ms_per_char)) * vocab.ID_TO_CHAR[sym]
        for sym, ms in events
    )
123
+
124
+
125
def render_json(events: list[tuple[int, int]]) -> list[dict]:
    """Return a verbose view of the events: ``[{'sym': name, 'ms': length}, ...]``."""
    rows: list[dict] = []
    for sym, length_ms in events:
        rows.append({"sym": vocab.ID_TO_NAME[sym], "ms": length_ms})
    return rows
128
+
129
+
130
def from_frames(frames: np.ndarray, frame_ms: int = 20) -> list[tuple[int, int]]:
    """Run-length encode a per-frame array into ``(symbol_id, length_ms)`` events.

    Inverse of ``to_frames`` at frame granularity: every output duration is a
    multiple of ``frame_ms``.

    Args:
        frames: 1-D sequence of per-frame symbol IDs.
        frame_ms: duration represented by one frame, in milliseconds.

    Returns:
        List of ``(symbol_id, length_ms)`` tuples; empty for empty input.
        Symbol IDs and run counts are converted to built-in ``int`` so the
        events are not numpy scalars (which ``json`` cannot serialize).
    """
    frames = np.asarray(frames)
    if len(frames) == 0:
        return []
    # Indices where the symbol changes, i.e. the first frame of each new run.
    change_idx = np.flatnonzero(np.diff(frames)) + 1
    boundaries = np.concatenate(([0], change_idx, [len(frames)]))
    return [
        (int(frames[start]), int(end - start) * frame_ms)
        for start, end in zip(boundaries[:-1], boundaries[1:])
    ]
openecg/delineate.py ADDED
@@ -0,0 +1,71 @@
1
+ # openecg/delineate.py
2
+ """NeuroKit2 ecg_delineate wrapper.
3
+
4
+ NK provides per-beat onset/peak/offset for P, QRS, T plus separate Q/S peaks.
5
+ Missing waves are marked with NaN inside NK output; we keep that and let
6
+ labeler handle (via np.isnan checks).
7
+
8
+ Spec: docs/superpowers/specs/2026-05-03-ecgcode-stage1-design.md §6
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+
13
+ import neurokit2 as nk
14
+ import numpy as np
15
+
16
+
17
@dataclass
class DelineateResult:
    """Wave delineation output for one lead, one array entry per beat.

    Missing wave indices are stored as NaN.
    """

    # P wave
    p_onsets: np.ndarray
    p_peaks: np.ndarray
    p_offsets: np.ndarray
    # QRS complex (Q and S have peaks only)
    q_peaks: np.ndarray
    r_onsets: np.ndarray
    r_peaks: np.ndarray
    r_offsets: np.ndarray
    s_peaks: np.ndarray
    # T wave
    t_onsets: np.ndarray
    t_peaks: np.ndarray
    t_offsets: np.ndarray

    @classmethod
    def empty(cls) -> "DelineateResult":
        """Return a zero-beat result; all fields share one empty float array."""
        blank = np.array([], dtype=float)
        return cls(**{field_name: blank for field_name in cls.__dataclass_fields__})

    @property
    def n_beats(self) -> int:
        """Number of beats, taken from the length of ``r_peaks``."""
        return len(self.r_peaks)
41
+
42
+
43
def run(signal: np.ndarray, fs: int = 500, method: str = "dwt") -> DelineateResult:
    """Detect R peaks with NeuroKit2, then delineate the waves around them.

    Args:
        signal: single-lead ECG samples.
        fs: sampling rate passed to NeuroKit2 as ``sampling_rate``.
        method: delineation method forwarded to ``nk.ecg_delineate``.

    Returns:
        A populated ``DelineateResult``, or ``DelineateResult.empty()`` when
        no R peak is found or either NeuroKit2 call raises.
    """
    try:
        _, peak_info = nk.ecg_peaks(signal, sampling_rate=fs)
        rpeaks = np.asarray(peak_info["ECG_R_Peaks"], dtype=float)
        if len(rpeaks) == 0:
            return DelineateResult.empty()
        _, waves = nk.ecg_delineate(
            signal, rpeaks=rpeaks.astype(int), sampling_rate=fs, method=method
        )
    except Exception:
        # Deliberate best-effort: any NeuroKit2 failure maps to "no beats".
        return DelineateResult.empty()

    def _column(key: str) -> np.ndarray:
        # Float dtype keeps NaN markers for waves NeuroKit2 could not locate.
        return np.asarray(waves[key], dtype=float)

    return DelineateResult(
        p_onsets=_column("ECG_P_Onsets"),
        p_peaks=_column("ECG_P_Peaks"),
        p_offsets=_column("ECG_P_Offsets"),
        q_peaks=_column("ECG_Q_Peaks"),
        r_onsets=_column("ECG_R_Onsets"),
        r_peaks=rpeaks,
        r_offsets=_column("ECG_R_Offsets"),
        s_peaks=_column("ECG_S_Peaks"),
        t_onsets=_column("ECG_T_Onsets"),
        t_peaks=_column("ECG_T_Peaks"),
        t_offsets=_column("ECG_T_Offsets"),
    )
openecg/eval.py ADDED
@@ -0,0 +1,189 @@
1
+ """Evaluation metrics — frame-level F1 (4-class) + boundary error (Martinez-style).
2
+
3
+ Spec: docs/superpowers/specs/2026-05-03-ecgcode-stage1-design.md §7
4
+ """
5
+
6
+ import numpy as np
7
+
8
+ from openecg import vocab
9
+
10
# LUDB-compatible 4-class supercategory IDs.
SUPER_OTHER, SUPER_P, SUPER_QRS, SUPER_T = 0, 1, 2, 3
SUPER_NAMES = dict(
    zip(
        (SUPER_OTHER, SUPER_P, SUPER_QRS, SUPER_T),
        ("other", "P", "QRS", "T"),
    )
)

# Label value marking masked frames (boundary regions where the model only has
# one-sided context and predictions are unreliable). PyTorch cross_entropy
# supports `ignore_index` natively; our focal_cross_entropy does too.
IGNORE_INDEX = 0xFF
21
+
22
# Alphabet symbol ID -> 4-class supercategory. Built from ordered pairs so the
# mapping's iteration order is explicit.
_SUPER_MAP = dict(
    (
        (vocab.ID_PAD, SUPER_OTHER),
        (vocab.ID_UNK, SUPER_OTHER),
        (vocab.ID_ISO, SUPER_OTHER),
        (vocab.ID_PACER, SUPER_OTHER),
        (vocab.ID_P, SUPER_P),
        (vocab.ID_Q, SUPER_QRS),
        (vocab.ID_R, SUPER_QRS),
        (vocab.ID_S, SUPER_QRS),
        (vocab.ID_W, SUPER_QRS),
        (vocab.ID_T, SUPER_T),
        (vocab.ID_U, SUPER_T),  # U is repolarization-adjacent
        (vocab.ID_D, SUPER_QRS),  # delta is QRS-adjacent
        (vocab.ID_J, SUPER_QRS),
    )
)
37
+
38
+
39
def to_supercategory(frames: np.ndarray) -> np.ndarray:
    """Translate per-frame alphabet IDs into the LUDB-compatible 4 classes.

    The result is a uint8 array shaped like ``frames``. IDs that do not appear
    in ``_SUPER_MAP`` keep the zero the output was initialised with.
    """
    mapped = np.zeros_like(frames, dtype=np.uint8)
    for symbol_id in _SUPER_MAP:
        mapped[frames == symbol_id] = _SUPER_MAP[symbol_id]
    return mapped
45
+
46
+
47
def frame_f1(pred: np.ndarray, true: np.ndarray) -> dict:
    """Compute precision, recall and F1 for each supercategory.

    Positions where ``true == IGNORE_INDEX`` are dropped from both arrays
    before counting, so whatever ``pred`` holds there has no effect.

    Returns:
        ``{super_id: {'precision', 'recall', 'f1', 'tp', 'fp', 'fn'}}``.
        A ratio whose denominator is zero is reported as 0.0.
    """
    keep = true != IGNORE_INDEX
    pred_kept = pred[keep]
    true_kept = true[keep]

    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    scores = {}
    for cls_id in (SUPER_OTHER, SUPER_P, SUPER_QRS, SUPER_T):
        said_yes = pred_kept == cls_id
        is_yes = true_kept == cls_id
        tp = int(np.sum(said_yes & is_yes))
        fp = int(np.sum(said_yes & ~is_yes))
        fn = int(np.sum(~said_yes & is_yes))
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)
        scores[cls_id] = {
            "precision": precision,
            "recall": recall,
            "f1": _ratio(2 * precision * recall, precision + recall),
            "tp": tp,
            "fp": fp,
            "fn": fn,
        }
    return scores
68
+
69
+
70
def boundary_error(
    pred_indices: list[int],
    true_indices: list[int],
    tolerance_ms: float,
    fs: int,
) -> dict:
    """Greedy nearest-match comparison of boundaries (Martinez 2004 style).

    True boundaries are visited in ascending order. Each one claims the
    closest still-unclaimed predicted boundary (lowest index on ties) and
    counts as a hit when that distance is within the tolerance.

    Args:
        pred_indices: predicted boundary positions, in samples.
        true_indices: reference boundary positions, in samples.
        tolerance_ms: maximum accepted distance, in milliseconds.
        fs: sampling rate used to convert between samples and milliseconds.

    Returns:
        Dict with sensitivity, ppv, n_hits, n_true, n_pred and the
        mean/median/p95 error of the hits in ms (0.0 when there are no hits).
    """
    max_dist = tolerance_ms * fs / 1000
    pred_sorted = np.sort(np.array(pred_indices, dtype=int))
    true_sorted = np.sort(np.array(true_indices, dtype=int))
    n_pred = len(pred_sorted)
    n_true = len(true_sorted)

    if n_true == 0 and n_pred == 0:
        return _empty_boundary_result()

    claimed = set()
    hit_errors = []
    for t in true_sorted:
        candidates = [
            (abs(int(p) - int(t)), j)
            for j, p in enumerate(pred_sorted)
            if j not in claimed
        ]
        if not candidates:
            # No prediction left to claim, so no later boundary can hit either.
            break
        # Tuple ordering picks the smallest error, then the smallest index.
        err, j = min(candidates)
        if err <= max_dist:
            claimed.add(j)
            hit_errors.append(err)
    n_hits = len(hit_errors)

    mean_err = median_err = p95_err = 0.0
    if hit_errors:
        errors_ms = np.array(hit_errors) * 1000.0 / fs
        mean_err = float(np.mean(errors_ms))
        median_err = float(np.median(errors_ms))
        p95_err = float(np.percentile(errors_ms, 95))

    return {
        "sensitivity": n_hits / n_true if n_true > 0 else 0.0,
        "ppv": n_hits / n_pred if n_pred > 0 else 0.0,
        "n_hits": n_hits,
        "n_true": int(n_true),
        "n_pred": int(n_pred),
        "mean_error_ms": mean_err,
        "median_error_ms": median_err,
        "p95_error_ms": p95_err,
    }
130
+
131
+
132
def boundary_f1(pred_indices, true_indices, tolerance_ms, fs):
    """Run ``boundary_error`` and add the F1 score to its result.

    A true boundary matched to a predicted one within ``tolerance_ms`` is a
    true positive, so sensitivity = TP / |true| and PPV = TP / |pred|.
    F1 is their harmonic mean, or 0.0 when both are zero.

    Returns:
        The ``boundary_error`` dict with an extra ``"f1"`` key.
    """
    stats = boundary_error(
        pred_indices, true_indices, tolerance_ms=tolerance_ms, fs=fs
    )
    total = stats["sensitivity"] + stats["ppv"]
    if total > 0:
        f1 = 2 * stats["sensitivity"] * stats["ppv"] / total
    else:
        f1 = 0.0
    return {**stats, "f1": f1}
144
+
145
+
146
+ def _empty_boundary_result():
147
+ return {
148
+ "sensitivity": 0.0, "ppv": 0.0, "n_hits": 0, "n_true": 0, "n_pred": 0,
149
+ "mean_error_ms": 0.0, "median_error_ms": 0.0, "p95_error_ms": 0.0,
150
+ }
151
+
152
+
153
def events_to_super_frames(events, n_samples, fs=500, frame_ms=20):
    """Convert pipeline events into a per-frame supercategory array.

    The signal duration is ``n_samples`` at rate ``fs``, rounded to whole
    milliseconds. Events are expanded with ``codec.to_frames`` and then mapped
    with ``to_supercategory``. Used by validate_v1.py and ablate_methods.py.
    """
    from openecg import codec

    duration_ms = round(n_samples * 1000.0 / fs)
    per_frame_ids = codec.to_frames(events, frame_ms=frame_ms, total_ms=duration_ms)
    return to_supercategory(per_frame_ids)
160
+
161
+
162
def gt_to_super_frames(gt_ann, n_samples, fs=500, frame_ms=20):
    """Turn an LUDB-style annotation dict into per-frame supercategories.

    Each sample is labelled first (onset and offset both inclusive; later
    waves overwrite earlier ones in the order P, QRS, T), then every frame
    takes the majority label of its samples.

    samples_per_frame is fixed by physical time (fs * frame_ms / 1000, at
    least 1), so each output frame represents exactly frame_ms of signal.
    n_frames = n_samples // samples_per_frame; trailing samples that do not
    fill a whole frame are dropped. Earlier versions computed
    samples_per_frame = n_samples // n_frames, which introduced cumulative
    time drift (up to 500ms by frame 499) when n_samples was not a clean
    multiple of samples-per-frame (e.g. ISP records of 9998-9999 samples at
    1000Hz with frame_ms=20).

    NOTE(review): the frame count here is floored, while ``codec.to_frames``
    rounds, so the two can differ by one frame for such records. Confirm
    callers align the lengths.
    """
    samples_per_frame = max(1, round(fs * frame_ms / 1000.0))

    per_sample = np.full(n_samples, SUPER_OTHER, dtype=np.uint8)
    wave_specs = (
        ("p_on", "p_off", SUPER_P),
        ("qrs_on", "qrs_off", SUPER_QRS),
        ("t_on", "t_off", SUPER_T),
    )
    for on_key, off_key, wave_label in wave_specs:
        for onset, offset in zip(gt_ann[on_key], gt_ann[off_key]):
            per_sample[onset:offset + 1] = wave_label

    n_frames = n_samples // samples_per_frame
    per_frame = np.zeros(n_frames, dtype=np.uint8)
    for frame_no in range(n_frames):
        lo = frame_no * samples_per_frame
        window = per_sample[lo: lo + samples_per_frame]
        # Majority vote; a tie goes to the smallest label value.
        per_frame[frame_no] = np.bincount(window).argmax()
    return per_frame
openecg/isp.py ADDED
@@ -0,0 +1,110 @@
1
+ # openecg/isp.py
2
+ """ISP ECG delineation dataset loader.
3
+
4
+ Source: https://zenodo.org/records/14679837 (475 records, 12-lead, ~10s @ 1000Hz,
5
+ 2-cardiologist annotations of P/QRS/T onset+offset).
6
+
7
+ Target format: list of (class_id, onset_sample, offset_sample) where class 0=P, 1=QRS, 2=T.
8
+ """
9
+
10
+ import csv
11
+ import os
12
+ import re
13
+ import zipfile
14
+ from pathlib import Path
15
+
16
+ import numpy as np
17
+ import wfdb
18
+
19
+ ISP_INNER_DIR = "isp_delineation_dataset"
20
+ LEADS_12 = ("i", "ii", "iii", "avr", "avl", "avf",
21
+ "v1", "v2", "v3", "v4", "v5", "v6")
22
+ FS_NATIVE = 1000
23
+
24
+ _TUPLE_RE = re.compile(r"\((\d+)\s*,\s*(\d+)\s*,\s*(\d+)\)")
25
+
26
+
27
+ def _zip_path() -> Path:
28
+ p = os.environ.get("OPENECG_ISP_ZIP")
29
+ if not p:
30
+ raise FileNotFoundError("Set OPENECG_ISP_ZIP env var")
31
+ return Path(p)
32
+
33
+
34
+ def _cache_path() -> Path:
35
+ p = os.environ.get("OPENECG_ISP_CACHE")
36
+ if p:
37
+ return Path(p).expanduser()
38
+ return Path.home() / ".cache" / "openecg" / "isp"
39
+
40
+
41
def ensure_extracted() -> Path:
    """Make sure the dataset archive is unpacked and return its inner directory.

    If ``<cache>/<ISP_INNER_DIR>`` already exists it is returned as is.
    Otherwise the archive named by ``_zip_path()`` is extracted into the
    cache directory first.

    Raises:
        FileNotFoundError: if the archive path is not configured, or if the
            archive does not produce the expected inner directory. The latter
            would otherwise surface later as a confusing missing-file error.
    """
    cache = _cache_path()
    inner = cache / ISP_INNER_DIR
    if inner.exists():
        return inner
    cache.mkdir(parents=True, exist_ok=True)
    archive = _zip_path()
    with zipfile.ZipFile(archive) as z:
        z.extractall(cache)
    if not inner.is_dir():
        raise FileNotFoundError(
            f"Archive {archive} did not contain the expected directory "
            f"{ISP_INNER_DIR!r}"
        )
    return inner
50
+
51
+
52
def _parse_target(s: str) -> list[tuple[int, int, int]]:
    """Extract integer triples from a string like '[(0, 100, 150), (1, 200, 250)]'.

    Uses a regex rather than evaluating the text; anything that is not a
    parenthesised triple of digit runs is ignored.
    """
    triples: list[tuple[int, int, int]] = []
    for match in _TUPLE_RE.finditer(s):
        first, second, third = match.groups()
        triples.append((int(first), int(second), int(third)))
    return triples
55
+
56
+
57
+ def _load_csv(path: Path) -> dict[int, list[tuple[int, int, int]]]:
58
+ """Returns {file_name (int): list of (class, onset, offset)}."""
59
+ out = {}
60
+ with open(path) as f:
61
+ for row in csv.DictReader(f):
62
+ fid = int(row["file_name"])
63
+ out[fid] = _parse_target(row["target"])
64
+ return out
65
+
66
+
67
def load_split() -> dict[str, list[int]]:
    """Return the dataset's predefined split as sorted record IDs.

    Result shape: ``{'train': [...], 'test': [...]}``.
    """
    root = ensure_extracted()
    train_ids = sorted(_load_csv(root / "train_isp_delineation_data.csv").keys())
    test_ids = sorted(_load_csv(root / "test_isp_delineation_data.csv").keys())
    return {"train": train_ids, "test": test_ids}
73
+
74
+
75
def _load_annotations(record_id: int, split: str) -> list[tuple[int, int, int]]:
    """Return the annotation triples of one record in the given split.

    ``split`` is 'train' or 'test' and selects which CSV is read. A record
    absent from that CSV yields an empty list.
    """
    root = ensure_extracted()
    per_record = _load_csv(root / f"{split}_isp_delineation_data.csv")
    if record_id in per_record:
        return per_record[record_id]
    return []
81
+
82
+
83
def load_record(record_id: int, split: str = "train") -> dict[str, np.ndarray]:
    """Load a 12-lead ECG record as ``{lead_name: float64 signal}`` at 1000Hz.

    Signals are roughly 9998 samples long. Lead names are assigned to signal
    columns by position, following ``LEADS_12``.

    NOTE(review): the record's own channel names are not consulted. This
    presumably relies on every record storing leads in ``LEADS_12`` order;
    confirm.
    """
    root = ensure_extracted()
    record = wfdb.rdrecord(str(root / f"{split}_data" / str(record_id)))
    signals: dict[str, np.ndarray] = {}
    for column, lead_name in enumerate(LEADS_12):
        signals[lead_name] = record.p_signal[:, column].astype(np.float64)
    return signals
90
+
91
+
92
def load_annotations_as_super(record_id: int, split: str = "train") -> dict[str, list[int]]:
    """Reshape ISP annotation triples into an LUDB-style dict.

    The result is meant for ``gt_to_super_frames``. Class 0 fills the ``p_*``
    lists, class 1 the ``qrs_*`` lists and class 2 the ``t_*`` lists; any
    other class is ignored. The ``*_peak`` lists are always empty: they are
    unused but expected to be present.
    """
    prefix_by_class = {0: "p", 1: "qrs", 2: "t"}
    result: dict[str, list[int]] = {
        "p_on": [], "p_off": [],
        "qrs_on": [], "qrs_off": [],
        "t_on": [], "t_off": [],
        "p_peak": [], "qrs_peak": [], "t_peak": [],
    }
    for cls, onset, offset in _load_annotations(record_id, split):
        prefix = prefix_by_class.get(cls)
        if prefix is None:
            continue
        result[f"{prefix}_on"].append(onset)
        result[f"{prefix}_off"].append(offset)
    return result
openecg/labeler.py ADDED
@@ -0,0 +1,154 @@
1
+ # openecg/labeler.py
2
+ """Convert NK delineate output + pacer spikes -> RLE token stream.
3
+
4
+ Algorithm (per the spec section 6):
5
+ 1. Initialize sample-level array as iso
6
+ 2. Mark P / T regions
7
+ 3. Decompose each QRS into q/r/s by midpoint, with wide-QRS fallback (w)
8
+ 4. Override with pacer spikes (priority: spike > wave > iso)
9
+ 5. Run-length compress to (symbol_id, length_ms) events
10
+ """
11
+
12
+ import numpy as np
13
+
14
+ from openecg import vocab
15
+ from openecg.delineate import DelineateResult
16
+
17
# A QRS with neither Q nor S peak and longer than this (ms) is labelled wide.
WIDE_QRS_THRESHOLD_MS = 120.0
18
+
19
+
20
+ def _safe_int(x, n: int):
21
+ """Cast NK index (float, possibly NaN) to int and clamp to [0, n-1]. Returns None if NaN."""
22
+ if x is None:
23
+ return None
24
+ try:
25
+ if np.isnan(x):
26
+ return None
27
+ except (TypeError, ValueError):
28
+ pass
29
+ return max(0, min(n - 1, int(x)))
30
+
31
+
32
+ def _has(x) -> bool:
33
+ if x is None:
34
+ return False
35
+ try:
36
+ return not np.isnan(x)
37
+ except (TypeError, ValueError):
38
+ return True
39
+
40
+
41
def label(
    dr: DelineateResult,
    spike_idx,
    n_samples: int,
    fs: int = 500,
) -> list:
    """Paint a per-sample label array from the delineation, then RLE-compress it.

    Painting order, with later steps overwriting earlier ones: iso background,
    P waves, T waves, QRS complexes, pacer spikes.

    Args:
        dr: delineation result; NaN entries mark missing waves.
        spike_idx: iterable of pacer spike sample positions; out-of-range
            positions are ignored.
        n_samples: length of the signal in samples.
        fs: sampling rate, used to convert samples to milliseconds.

    Returns:
        List of ``(symbol_id, length_ms)`` tuples. When ``dr`` holds no beats
        the whole signal is returned as one unknown-symbol event.
    """
    ms_per_sample = 1000.0 / fs

    # Delineation produced nothing: emit the whole signal as a single event.
    if dr.n_beats == 0:
        return [(vocab.ID_UNK, int(round(n_samples * ms_per_sample)))]

    labels = np.full(n_samples, vocab.ID_ISO, dtype=np.uint8)

    # P waves first, then T waves; both span onset..offset inclusive.
    simple_waves = (
        (dr.p_onsets, dr.p_offsets, vocab.ID_P),
        (dr.t_onsets, dr.t_offsets, vocab.ID_T),
    )
    for onsets, offsets, wave_id in simple_waves:
        for on_raw, off_raw in zip(onsets, offsets):
            if _has(on_raw) and _has(off_raw):
                lo = _safe_int(on_raw, n_samples)
                hi = _safe_int(off_raw, n_samples)
                labels[lo:hi + 1] = wave_id

    # QRS: split into q/r/s at peak midpoints, or mark as wide.
    for beat in range(dr.n_beats):
        if not _has(dr.r_onsets[beat]) or not _has(dr.r_offsets[beat]):
            continue
        on = _safe_int(dr.r_onsets[beat], n_samples)
        off = _safe_int(dr.r_offsets[beat], n_samples)
        r = _safe_int(dr.r_peaks[beat], n_samples)
        q = _safe_int(
            dr.q_peaks[beat] if beat < len(dr.q_peaks) else None, n_samples
        )
        s = _safe_int(
            dr.s_peaks[beat] if beat < len(dr.s_peaks) else None, n_samples
        )

        # Span between onset and offset, in milliseconds.
        qrs_ms = (off - on) * ms_per_sample

        # Wide-QRS fallback: neither Q nor S peak, and longer than the threshold.
        if q is None and s is None and qrs_ms > WIDE_QRS_THRESHOLD_MS:
            labels[on:off + 1] = vocab.ID_W
            continue

        # Without an R peak the whole complex is labelled r.
        if r is None:
            labels[on:off + 1] = vocab.ID_R
            continue

        # q runs up to the Q/R midpoint, s starts at the R/S midpoint and r
        # fills the gap. Without Q (or S), r extends to the onset (or offset).
        if q is not None:
            q_end = (q + r) // 2
            labels[on:q_end] = vocab.ID_Q
        else:
            q_end = on
        s_start = (r + s) // 2 if s is not None else off + 1
        labels[q_end:s_start] = vocab.ID_R
        if s is not None:
            labels[s_start:off + 1] = vocab.ID_S

    # Pacer spikes overwrite whatever label the sample had.
    for raw_pos in spike_idx:
        pos = int(raw_pos)
        if pos < 0 or pos >= n_samples:
            continue
        labels[pos] = vocab.ID_PACER

    return _rle_compress(labels, ms_per_sample)
121
+
122
+
123
def _rle_compress(labels: np.ndarray, ms_per_sample: float) -> list:
    """Collapse runs of equal labels into ``(symbol_id, length_ms)`` events.

    Lengths land on the codec grid (``MS_PER_UNIT``) by rounding the
    cumulative end time of each run, not each run's own length, so rounding
    error does not accumulate. A run too short to reach a new grid line is
    still emitted with one quantum so its symbol is not lost (e.g. a
    single-sample pacer spike). That pushes the emitted time ahead of the true
    time until later runs catch up.
    """
    from openecg.codec import MS_PER_UNIT

    if len(labels) == 0:
        return []
    cuts = np.flatnonzero(np.diff(labels)) + 1
    run_starts = np.concatenate(([0], cuts))
    run_ends = np.concatenate((cuts, [len(labels)]))

    events = []
    true_ms = 0.0
    emitted_ms = 0
    for start, end in zip(run_starts, run_ends):
        true_ms += int(end - start) * ms_per_sample
        # Grid line nearest to where this run really ends.
        snapped_end = int(round(true_ms / MS_PER_UNIT)) * MS_PER_UNIT
        if snapped_end <= emitted_ms:
            # No new grid line reached: force exactly one quantum.
            snapped_end = emitted_ms + MS_PER_UNIT
        events.append((int(labels[start]), snapped_end - emitted_ms))
        emitted_ms = snapped_end
    return events