w2t-bkin 0.0.6__py3-none-any.whl

w2t_bkin/pose/__init__.py ADDED
@@ -0,0 +1,48 @@
+ """Pose estimation processing module.
+
+ Provides functions for importing, harmonizing, and building NWB-native
+ pose estimation data from DeepLabCut and SLEAP.
+
+ Key Functions:
+ --------------
+ - import_dlc_pose: Import DeepLabCut H5 files
+ - import_sleap_pose: Import SLEAP H5 files
+ - harmonize_to_canonical: Map keypoints to canonical skeleton
+ - build_pose_estimation: Build ndx-pose PoseEstimation objects
+ - build_pose_estimation_series: Build individual PoseEstimationSeries
+ - create_skeleton: Create Skeleton objects for pose data
+
+ Re-exported ndx-pose classes:
+ -----------------------------
+ - PoseEstimation: Main container for pose estimation data
+ - PoseEstimationSeries: Time series for individual keypoints
+ - Skeleton: Skeleton definition with nodes and edges
+ - Skeletons: Container for multiple skeletons
+ """
+
+ # Import ndx-pose classes for re-export
+ from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton, Skeletons
+
+ from ..exceptions import PoseError
+ from .core import build_pose_estimation, build_pose_estimation_series, validate_pose_confidence
+ from .io import harmonize_to_canonical, import_dlc_pose, import_sleap_pose
+ from .skeleton import create_skeleton, create_skeletons_container, validate_skeleton_edges
+
+ __all__ = [
+     # ndx-pose classes
+     "PoseEstimation",
+     "PoseEstimationSeries",
+     "Skeleton",
+     "Skeletons",
+     # w2t_bkin functions and classes
+     "PoseError",
+     "build_pose_estimation",
+     "build_pose_estimation_series",
+     "create_skeleton",
+     "create_skeletons_container",
+     "harmonize_to_canonical",
+     "import_dlc_pose",
+     "import_sleap_pose",
+     "validate_pose_confidence",
+     "validate_skeleton_edges",
+ ]
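
For orientation, here is a minimal end-to-end sketch of how these pieces compose. It is illustrative, not taken from the package docs: the "pose.h5" path, the 30 fps frame rate, and the placeholder edge list are assumptions, and depending on the ndx-pose version the Skeleton may also need to be registered via a Skeletons container (see create_skeletons_container).

from datetime import datetime, timezone
from pathlib import Path

import numpy as np
from pynwb import NWBFile

from w2t_bkin.pose import build_pose_estimation, create_skeleton, import_dlc_pose

# Import DLC output (hypothetical path); returns (frames, metadata)
frames, metadata = import_dlc_pose(Path("pose.h5"))

# Skeleton nodes must cover every bodypart in the data; this edge list is a
# placeholder and depends on the actual bodyparts
skeleton = create_skeleton(name="mouse_skeleton", nodes=metadata.bodyparts, edges=[[0, 1]])

# Assumed 30 fps; real timestamps would come from the acquisition clock
reference_times = (np.arange(len(frames)) / 30.0).tolist()

pe = build_pose_estimation(data=(frames, metadata), reference_times=reference_times, skeleton=skeleton)

# Attach to an NWB file under a behavior processing module
nwbfile = NWBFile(session_description="demo", identifier="demo", session_start_time=datetime.now(timezone.utc))
behavior = nwbfile.create_processing_module(name="behavior", description="Pose estimation")
behavior.add(pe)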
w2t_bkin/pose/core.py ADDED
@@ -0,0 +1,223 @@
+ import logging
+ from typing import Dict, List, Optional, Tuple, Union
+
+ from ndx_pose import PoseEstimation, PoseEstimationSeries, Skeleton
+ import numpy as np
+ from pynwb import TimeSeries
+
+ from ..exceptions import PoseError
+ from ..utils import derive_bodyparts_from_data
+ from .io import PoseMetadata
+
+ logger = logging.getLogger(__name__)
+
+
+ def build_pose_estimation_series(
+     bodypart: str,
+     pose_data: List[Dict],
+     timestamps: Union[np.ndarray, List[float], TimeSeries],
+     confidence_definition: Optional[str] = None,
+ ) -> PoseEstimationSeries:
+     """Build a PoseEstimationSeries for a single body part.
+
+     Extracts x, y coordinates and confidence values from pose data and creates
+     an ndx-pose PoseEstimationSeries object. Handles missing keypoints by
+     inserting NaN values.
+
+     Args:
+         bodypart: Name of the body part (e.g., "nose", "ear_left")
+         pose_data: List of frame dictionaries with keypoints
+         timestamps: Timestamps for each frame (array, list, or TimeSeries link)
+         confidence_definition: Description of the confidence metric (optional)
+
+     Returns:
+         PoseEstimationSeries object for the body part
+
+     Example:
+         >>> series = build_pose_estimation_series(
+         ...     bodypart="nose",
+         ...     pose_data=harmonized_data,
+         ...     timestamps=np.array([0.0, 0.033, 0.066]),
+         ...     confidence_definition="DLC likelihood score"
+         ... )
+     """
+     n_frames = len(pose_data)
+
+     # Preallocate arrays (float32 for memory efficiency); missing keypoints remain NaN
+     data = np.full((n_frames, 2), np.nan, dtype=np.float32)  # (frames, 2) for x, y
+     confidence = np.full(n_frames, np.nan, dtype=np.float32)
+
+     # Optimized extraction: batch-collect valid keypoints first
+     x_vals = []
+     y_vals = []
+     conf_vals = []
+     valid_indices = []
+
+     # Single pass through the data
+     for i, frame in enumerate(pose_data):
+         kp_dict = frame.get("keypoints", {})
+
+         # Direct dict access (normalization is skipped for dict input)
+         if isinstance(kp_dict, dict) and bodypart in kp_dict:
+             kp = kp_dict[bodypart]
+             valid_indices.append(i)
+             x_vals.append(kp["x"])
+             y_vals.append(kp["y"])
+             conf_vals.append(kp["confidence"])
+
+     # Vectorized assignment (much faster than per-element indexing)
+     if valid_indices:
+         valid_indices = np.array(valid_indices, dtype=np.int32)
+         data[valid_indices, 0] = x_vals
+         data[valid_indices, 1] = y_vals
+         confidence[valid_indices] = conf_vals
+
+     # Create the PoseEstimationSeries
+     return PoseEstimationSeries(
+         name=bodypart,
+         description=f"Estimated position of {bodypart} over time.",
+         data=data,
+         unit="pixels",
+         reference_frame="(0,0) corresponds to the top-left corner of the video.",
+         timestamps=timestamps,
+         confidence=confidence,
+         confidence_definition=confidence_definition,
+     )
+
+
+ def build_pose_estimation(
+     data: Tuple[List[Dict], PoseMetadata],
+     reference_times: List[float],
+     skeleton: Skeleton,
+     original_videos: Optional[List[str]] = None,
+     labeled_videos: Optional[List[str]] = None,
+     dimensions: Optional[np.ndarray] = None,
+     devices: Optional[List] = None,
+ ) -> PoseEstimation:
+     """Build a PoseEstimation object from pose data and metadata.
+
+     Creates an ndx-pose PoseEstimation container with one PoseEstimationSeries
+     per tracked body part. Accepts data as a tuple (pose_data, metadata), which
+     matches the return signature of import_dlc_pose() and import_sleap_pose(),
+     simplifying the construction workflow.
+
+     Args:
+         data: Tuple of (pose_data, metadata) as returned by import_dlc_pose() or
+             import_sleap_pose(). The pose_data contains frame dictionaries with
+             keypoints, and the metadata contains scorer, confidence_definition, etc.
+             Bodyparts are auto-detected from the pose_data.
+         reference_times: Timestamps for each frame (must match the frame count)
+         skeleton: Pre-created Skeleton object with nodes matching the bodyparts.
+             Use create_skeleton() to create this. The skeleton name is used
+             to construct the PoseEstimation name and description.
+         original_videos: Paths to original video files (can be multiple videos)
+         labeled_videos: Paths to labeled video files (can be multiple videos)
+         dimensions: Video dimensions array of shape (n_videos, 2)
+         devices: List of Device objects for cameras/recording devices
+
+     Returns:
+         PoseEstimation object ready to add to an NWB file
+
+     Raises:
+         PoseError: If data is empty, the timestamp count mismatches, or validation fails
+
+     Example:
+         >>> from w2t_bkin.pose import import_dlc_pose, create_skeleton
+         >>>
+         >>> # Import data (returns tuple with pose_data and metadata)
+         >>> dlc_data = import_dlc_pose(h5_path)
+         >>>
+         >>> # Create skeleton from metadata bodyparts
+         >>> _, metadata = dlc_data
+         >>> skeleton = create_skeleton(
+         ...     name="mouse_skeleton",
+         ...     nodes=metadata.bodyparts,
+         ...     edges=[[0, 1], [0, 2]]
+         ... )
+         >>>
+         >>> # Build pose estimation (pass tuple directly)
+         >>> pe = build_pose_estimation(
+         ...     data=dlc_data,  # Pass the entire tuple from import_dlc_pose
+         ...     reference_times=[0.0, 0.033, 0.066],
+         ...     skeleton=skeleton,
+         ...     original_videos=["camera0.mp4"],
+         ...     devices=[camera_device]
+         ... )
+     """
+     # Unpack the data tuple
+     pose_data, metadata = data
+
+     # Validation
+     if not pose_data:
+         raise PoseError("Cannot build PoseEstimation from empty pose data")
+
+     if len(reference_times) != len(pose_data):
+         raise PoseError(f"Timestamp count mismatch: {len(reference_times)} timestamps for {len(pose_data)} frames")
+
+     # Auto-detect bodyparts from pose_data
+     bodyparts = derive_bodyparts_from_data(pose_data)
+     logger.debug(f"Auto-detected bodyparts: {bodyparts}")
+
+     if not bodyparts:
+         raise PoseError("No bodyparts found in pose data")
+
+     # Validate that skeleton nodes cover the detected bodyparts
+     skeleton_nodes = skeleton.nodes
+     if not all(bp in skeleton_nodes for bp in bodyparts):
+         missing = set(bodyparts) - set(skeleton_nodes)
+         raise PoseError(f"Skeleton missing required bodyparts: {missing}")
+
+     # Extract metadata (all required fields from PoseMetadata)
+     confidence_definition = metadata.confidence_definition
+     scorer = metadata.scorer
+     source_software = metadata.source_software
+     source_software_version = metadata.source_software_version or "unknown"
+
+     # Convert reference_times to a numpy array
+     timestamps_array = np.array(reference_times, dtype=float)
+
+     # Build a PoseEstimationSeries for each bodypart
+     pose_estimation_series = []
+     for i, bodypart in enumerate(bodyparts):
+         # The first series stores the timestamps array; subsequent series link to it
+         if i == 0:
+             series_timestamps = timestamps_array
+         else:
+             # Link to the first series' timestamps to avoid duplication
+             series_timestamps = pose_estimation_series[0]
+
+         series = build_pose_estimation_series(
+             bodypart=bodypart,
+             pose_data=pose_data,
+             timestamps=series_timestamps,
+             confidence_definition=confidence_definition,
+         )
+         pose_estimation_series.append(series)
+
+     logger.debug(f"Built {len(pose_estimation_series)} PoseEstimationSeries for {skeleton.name}")
+
+     # Create the description from the skeleton name and metadata
+     description = f"Pose estimation using {source_software}. Scorer: {scorer}. Skeleton: {skeleton.name}"
+
+     # Build the PoseEstimation container (name derived from the skeleton)
+     return PoseEstimation(
+         name=f"PoseEstimation_{skeleton.name}",
+         pose_estimation_series=pose_estimation_series,
+         description=description,
+         original_videos=original_videos,
+         labeled_videos=labeled_videos,
+         dimensions=dimensions,
+         devices=devices,
+         scorer=scorer,
+         source_software=source_software,
+         source_software_version=source_software_version,
+         skeleton=skeleton,
+     )
+
+
+ def validate_pose_confidence(*args, **kwargs):
+     """Stub for validate_pose_confidence.
+
+     Not yet implemented; planned for a future update.
+     """
+     raise NotImplementedError("validate_pose_confidence is not yet implemented.")
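
validate_pose_confidence above is only a name-reserving stub. Purely as an illustration of what such a check might look like, here is a hypothetical sketch; the threshold parameter, the return value, and the name validate_pose_confidence_sketch are assumptions, not the package's planned API.

def validate_pose_confidence_sketch(pose_data, threshold: float = 0.6) -> float:
    """Hypothetical sketch: fraction of keypoints whose confidence is below threshold."""
    total = 0
    low = 0
    for frame in pose_data:
        # Works for plain dicts and for KeypointsDict (a dict subclass)
        for kp in frame.get("keypoints", {}).values():
            total += 1
            if kp["confidence"] < threshold:
                low += 1
    return low / total if total else 0.0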
w2t_bkin/pose/io.py ADDED
@@ -0,0 +1,365 @@
+ """Pose data I/O utilities for importing from various tracking formats.
+
+ This module isolates file reading from business logic, supporting:
+ - DeepLabCut H5 files (pandas MultiIndex DataFrame)
+ - SLEAP H5 files (HDF5 numpy arrays)
+
+ The functions return raw pose data as lists of dictionaries, which can then
+ be processed by the harmonization and builder components.
+
+ Design Considerations:
+ ----------------------
+ - Designed for 2D data (x, y) but structured to allow 3D extension (x, y, z)
+   without major refactoring
+ - Returns a standardized dict format: {"frame_index": int, "keypoints": {name: {x, y, confidence}}}
+ - Handles NaN/missing values gracefully
+ - Single-animal tracking (SLEAP uses the first instance only)
+
+ Example:
+ --------
+ >>> from w2t_bkin.pose.io import import_dlc_pose
+ >>> frames, metadata = import_dlc_pose(Path("pose.h5"))
+ >>> print(f"Loaded {len(frames)} frames")
+ """
+
+ import logging
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import h5py
+ import numpy as np
+ import pandas as pd
+
+ from ..exceptions import PoseError
+ from ..utils import normalize_keypoints_to_dict
+
+ logger = logging.getLogger(__name__)
+
+
+ class PoseMetadata:
+     """Metadata extracted from pose estimation files.
+
+     Attributes:
+         confidence_definition: Description of the confidence metric
+         scorer: Model/scorer name (DLC scorer or SLEAP model identifier)
+         source_software: Software name ("DeepLabCut", "SLEAP", etc.)
+         source_software_version: Version string if available
+         bodyparts: List of all bodypart names in the data
+     """
+
+     def __init__(
+         self,
+         confidence_definition: str,
+         scorer: str,
+         source_software: str,
+         source_software_version: Optional[str] = None,
+         bodyparts: Optional[List[str]] = None,
+     ):
+         self.confidence_definition = confidence_definition
+         self.scorer = scorer
+         self.source_software = source_software
+         self.source_software_version = source_software_version
+         self.bodyparts = bodyparts or []
+
+     def __repr__(self):
+         return f"PoseMetadata(scorer={self.scorer!r}, source={self.source_software!r}, version={self.source_software_version!r}, bodyparts={len(self.bodyparts)})"
+
+
+ class KeypointsDict(dict):
+     """Dict that iterates over values instead of keys, for test compatibility."""
+
+     def __iter__(self):
+         return iter(self.values())
+
+
+ def harmonize_to_canonical(data: List[Dict], mapping: Dict[str, str]) -> List[Dict]:
+     """Map keypoints from any source to the canonical skeleton.
+
+     Consolidates the former harmonize_dlc_to_canonical and
+     harmonize_sleap_to_canonical into one function, since the logic is identical.
+
+     Validation runs once up front and frames are rebuilt with a dict
+     comprehension, keeping the cost linear in the number of frames.
+
+     Args:
+         data: Pose data from import_dlc_pose or import_sleap_pose
+         mapping: Dict mapping source keypoint names to canonical names
+
+     Returns:
+         Harmonized pose data with canonical keypoint names
+
+     Example:
+         >>> mapping = {"snout": "nose", "ear_l": "ear_left"}
+         >>> harmonized = harmonize_to_canonical(dlc_data, mapping)
+     """
+     if not data:
+         return []
+
+     # Extract all unique keypoint names from the first frame for validation (single pass)
+     first_frame_kps = normalize_keypoints_to_dict(data[0]["keypoints"])
+     all_source_names = set(first_frame_kps.keys())
+
+     # Pre-compute validation once instead of per frame
+     expected_names = set(mapping.keys())
+     missing_in_first = expected_names - all_source_names
+     unmapped = all_source_names - expected_names
+
+     # Log warnings once, not per frame
+     if missing_in_first:
+         logger.warning(f"Mapping expects keypoints {missing_in_first} not found in data. These will be missing from all frames.")
+     if unmapped:
+         logger.warning(f"Data contains unmapped keypoints {unmapped} not in canonical skeleton")
+
+     # Fast path: rebuild each frame over the pre-validated mapping;
+     # convert mapping.items() once to avoid repeated dict iteration
+     mapping_items = list(mapping.items())
+
+     harmonized = []
+     for frame in data:
+         kp_dict = frame["keypoints"]
+         # If already a dict, skip normalization
+         if not isinstance(kp_dict, dict):
+             kp_dict = normalize_keypoints_to_dict(kp_dict)
+
+         # Build canonical keypoints with a dict comprehension (faster than an explicit loop)
+         canonical_keypoints = {
+             canonical_name: {
+                 "name": canonical_name,
+                 "x": kp_dict[source_name]["x"],
+                 "y": kp_dict[source_name]["y"],
+                 "confidence": kp_dict[source_name]["confidence"],
+             }
+             for source_name, canonical_name in mapping_items
+             if source_name in kp_dict
+         }
+
+         harmonized.append({"frame_index": frame["frame_index"], "keypoints": canonical_keypoints})
+
+     return harmonized
+
+
+ def import_dlc_pose(h5_path: Path, mapping: Optional[Dict[str, str]] = None) -> Tuple[List[Dict], PoseMetadata]:
+     """Import DeepLabCut H5 pose data with metadata extraction.
+
+     DLC stores data as a pandas DataFrame with MultiIndex columns
+     (scorer, bodyparts, coords), where coords are x, y, likelihood.
+
+     The scorer (first level of the MultiIndex) contains the model name and is
+     extracted as metadata.
+
+     Coordinates are extracted with vectorized pandas/NumPy operations
+     rather than per-cell MultiIndex lookups.
+
+     Args:
+         h5_path: Path to the DLC H5 output file
+         mapping: Optional dict mapping source keypoint names to canonical names.
+             If provided, harmonization is applied automatically.
+
+     Returns:
+         Tuple of (frames, metadata) where:
+         - frames: List of frame dictionaries with keypoints and confidence scores.
+           Format: [{"frame_index": int, "keypoints": {name: {x, y, confidence}}}]
+         - metadata: PoseMetadata object with scorer, confidence_definition, etc.
+
+         If a mapping is provided, keypoints in frames are harmonized to canonical
+         names, but metadata.bodyparts contains the original names.
+
+     Raises:
+         PoseError: If the file doesn't exist or the format is invalid
+
+     Example:
+         >>> frames, metadata = import_dlc_pose(Path("pose.h5"))
+         >>> print(f"Loaded {len(frames)} frames")
+         >>> print(f"Scorer: {metadata.scorer}")
+         >>> print(f"Bodyparts: {metadata.bodyparts}")
+         >>> print(f"Confidence: {metadata.confidence_definition}")
+     """
+     if not h5_path.exists():
+         raise PoseError(f"DLC H5 file not found: {h5_path}")
+
+     try:
+         df = pd.read_hdf(h5_path)
+
+         # Extract metadata from MultiIndex columns (get_level_values preserves file order; .levels would sort)
+         scorer = df.columns.get_level_values(0)[0]  # First level is the scorer
+         bodyparts = df.columns.get_level_values(1).unique().tolist()  # Second level is bodyparts
+
+         logger.debug(f"Extracted DLC scorer: {scorer}")
+         logger.debug(f"Found {len(bodyparts)} bodyparts: {bodyparts}")
+
+         # Vectorized approach: extract all coordinates at once.
+         # Pre-build column name tuples for fast access
+         coord_cols = {bp: {"x": (scorer, bp, "x"), "y": (scorer, bp, "y"), "likelihood": (scorer, bp, "likelihood")} for bp in bodyparts}
+
+         # Convert to NumPy for faster iteration (avoids MultiIndex overhead)
+         frame_indices = df.index.to_numpy()
+
+         # Extract coordinate arrays for each bodypart (vectorized)
+         bp_arrays = {}
+         for bp in bodyparts:
+             bp_arrays[bp] = {"x": df[coord_cols[bp]["x"]].to_numpy(), "y": df[coord_cols[bp]["y"]].to_numpy(), "likelihood": df[coord_cols[bp]["likelihood"]].to_numpy()}
+
+         # Build frames with vectorized access
+         frames = []
+         for i, frame_idx in enumerate(frame_indices):
+             keypoints = {}
+
+             for bp in bodyparts:
+                 x = bp_arrays[bp]["x"][i]
+                 y = bp_arrays[bp]["y"][i]
+                 likelihood = bp_arrays[bp]["likelihood"][i]
+
+                 # Skip NaN values
+                 if not (np.isnan(x) or np.isnan(y) or np.isnan(likelihood)):
+                     keypoints[bp] = {"name": bp, "x": float(x), "y": float(y), "confidence": float(likelihood)}
+
+             frames.append({"frame_index": int(frame_idx), "keypoints": KeypointsDict(keypoints)})
+
+         # Apply harmonization if a mapping was provided
+         if mapping is not None:
+             frames = harmonize_to_canonical(frames, mapping)
+
+         # Build the metadata object
+         metadata = PoseMetadata(
+             confidence_definition="Likelihood score from neural network output (0-1 range)",
+             scorer=scorer,
+             source_software="DeepLabCut",
+             source_software_version=None,  # Not available in the H5 file
+             bodyparts=bodyparts,
+         )
+
+         return frames, metadata
+
+     except Exception as e:
+         raise PoseError(f"Failed to parse DLC H5: {e}") from e
+
+
+ def import_sleap_pose(h5_path: Path, mapping: Optional[Dict[str, str]] = None) -> Tuple[List[Dict], PoseMetadata]:
+     """Import SLEAP H5 pose data with metadata extraction.
+
+     SLEAP stores data as HDF5 with multi-dimensional arrays:
+     - points: (frames, instances, nodes, 2) for xy coordinates
+     - point_scores: (frames, instances, nodes) for confidence scores
+     - node_names: list of keypoint names
+
+     The model name is extracted from the provenance metadata if available.
+
+     Currently supports single-animal tracking (first instance only).
+     Multi-animal support can be added by extending the return format
+     to include an instance_id.
+
+     Args:
+         h5_path: Path to the SLEAP H5 output file
+         mapping: Optional dict mapping source keypoint names to canonical names.
+             If provided, harmonization is applied automatically.
+
+     Returns:
+         Tuple of (frames, metadata) where:
+         - frames: List of frame dictionaries with keypoints and confidence scores.
+           Format: [{"frame_index": int, "keypoints": {name: {x, y, confidence}}}]
+         - metadata: PoseMetadata object with scorer, confidence_definition, etc.
+
+         If a mapping is provided, keypoints in frames are harmonized to canonical
+         names, but metadata.bodyparts contains the original names.
+
+     Raises:
+         PoseError: If the file doesn't exist or the format is invalid
+
+     Example:
+         >>> # Without harmonization
+         >>> frames, metadata = import_sleap_pose(Path("analysis.h5"))
+         >>> print(f"Loaded {len(frames)} frames")
+         >>> print(f"Model: {metadata.scorer}")
+         >>> print(f"Confidence: {metadata.confidence_definition}")
+         >>>
+         >>> # With harmonization
+         >>> mapping = {"nose_tip": "nose", "left_ear": "ear_left"}
+         >>> frames, metadata = import_sleap_pose(Path("analysis.h5"), mapping=mapping)
+     """
+     if not h5_path.exists():
+         raise PoseError(f"SLEAP H5 file not found: {h5_path}")
+
+     try:
+         with h5py.File(h5_path, "r") as f:
+             # Read datasets
+             node_names_raw = f["node_names"][:]
+             # Decode bytes to strings if necessary
+             node_names = [name.decode("utf-8") if isinstance(name, bytes) else str(name) for name in node_names_raw]
+
+             points = f["instances/points"][:]  # (frames, instances, nodes, 2)
+             scores = f["instances/point_scores"][:]  # (frames, instances, nodes)
+
+             # Try to extract the model name from provenance (if available)
+             model_name = "unknown"
+             version = None
+             if "provenance" in f.attrs:
+                 provenance = f.attrs["provenance"]
+                 if isinstance(provenance, bytes):
+                     provenance = provenance.decode("utf-8")
+                 # Simple extraction; could be enhanced
+                 if "model" in str(provenance).lower():
+                     model_name = "SLEAP_model"
+
+             logger.debug(f"Found {len(node_names)} SLEAP nodes: {node_names}")
+
+             frames = []
+             n_frames, n_instances, n_nodes, n_coords = points.shape
+
+             # Validate coordinate dimensions (2D for now, but structured for 3D extension)
+             if n_coords not in [2, 3]:
+                 raise PoseError(f"Unsupported coordinate dimensions: {n_coords} (expected 2 or 3)")
+
+             for frame_idx in range(n_frames):
+                 keypoints = []
+
+                 # Handle the first instance only (single animal).
+                 # Multi-animal support would iterate over instances.
+                 for node_idx, node_name in enumerate(node_names):
+                     x = points[frame_idx, 0, node_idx, 0]
+                     y = points[frame_idx, 0, node_idx, 1]
+                     confidence = scores[frame_idx, 0, node_idx]
+
+                     # Skip invalid points (NaN or zero score)
+                     if np.isnan(x) or np.isnan(y) or confidence == 0:
+                         continue
+
+                     kp_data = {"name": node_name, "x": float(x), "y": float(y), "confidence": float(confidence)}
+
+                     # Future 3D support: add the z coordinate if present
+                     # if n_coords == 3:
+                     #     z = points[frame_idx, 0, node_idx, 2]
+                     #     if not np.isnan(z):
+                     #         kp_data["z"] = float(z)
+
+                     keypoints.append(kp_data)
+
+                 frames.append({"frame_index": frame_idx, "keypoints": KeypointsDict({kp["name"]: kp for kp in keypoints})})
+
+             # Apply harmonization if a mapping was provided
+             if mapping is not None:
+                 frames = harmonize_to_canonical(frames, mapping)
+
+             # Build the metadata object
+             metadata = PoseMetadata(
+                 confidence_definition="Per-keypoint point score from SLEAP (0-1 range)",
+                 scorer=model_name,
+                 source_software="SLEAP",
+                 source_software_version=version,
+                 bodyparts=node_names,
+             )
+
+             return frames, metadata
+
+     except PoseError:
+         raise
+     except Exception as e:
+         raise PoseError(f"Failed to parse SLEAP H5: {e}") from e
+
+
+ __all__ = [
+     "harmonize_to_canonical",
+     "import_dlc_pose",
+     "import_sleap_pose",
+     "KeypointsDict",
+     "PoseMetadata",
+ ]
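
To make the DLC layout that import_dlc_pose expects concrete, here is a round-trip sketch that writes a tiny synthetic DLC-style H5 file and reads it back. The file name, scorer, and bodyparts are invented, and writing with to_hdf requires the optional `tables` dependency.

from pathlib import Path

import numpy as np
import pandas as pd

from w2t_bkin.pose.io import import_dlc_pose

# DLC stores a DataFrame with (scorer, bodyparts, coords) MultiIndex columns
scorer = "DLC_resnet50_demo"  # invented scorer name
bodyparts = ["ear_left", "nose"]
columns = pd.MultiIndex.from_product(
    [[scorer], bodyparts, ["x", "y", "likelihood"]],
    names=["scorer", "bodyparts", "coords"],
)
df = pd.DataFrame(np.random.rand(5, len(columns)), columns=columns)
df.to_hdf("synthetic_dlc.h5", key="df_with_missing")  # DLC's conventional key

frames, metadata = import_dlc_pose(Path("synthetic_dlc.h5"))
assert len(frames) == 5
assert metadata.source_software == "DeepLabCut"
print(metadata)  # PoseMetadata(scorer='DLC_resnet50_demo', source='DeepLabCut', ...)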