sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/data/utils.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
"""Miscellaneous utility functions for data processing."""
|
|
2
2
|
|
|
3
3
|
from typing import Tuple, List, Any, Optional
|
|
4
|
+
import sys
|
|
4
5
|
import torch
|
|
5
6
|
from omegaconf import DictConfig
|
|
6
7
|
import sleap_io as sio
|
|
7
8
|
from sleap_nn.config.utils import get_model_type_from_cfg
|
|
8
9
|
import psutil
|
|
9
10
|
import numpy as np
|
|
11
|
+
from loguru import logger
|
|
10
12
|
from sleap_nn.data.providers import get_max_instances
|
|
11
13
|
|
|
12
14
|
|
|
@@ -115,35 +117,151 @@ def check_memory(
|
|
|
115
117
|
return img_mem
|
|
116
118
|
|
|
117
119
|
|
|
120
|
+
def estimate_cache_memory(
|
|
121
|
+
train_labels: List[sio.Labels],
|
|
122
|
+
val_labels: List[sio.Labels],
|
|
123
|
+
num_workers: int = 0,
|
|
124
|
+
memory_buffer: float = 0.2,
|
|
125
|
+
) -> dict:
|
|
126
|
+
"""Estimate memory requirements for in-memory caching dataset pipeline.
|
|
127
|
+
|
|
128
|
+
This function calculates the total memory needed for caching images, accounting for:
|
|
129
|
+
- Raw image data size
|
|
130
|
+
- Python object overhead (dictionary keys, numpy array wrappers)
|
|
131
|
+
- DataLoader worker memory overhead (Copy-on-Write duplication on Unix systems)
|
|
132
|
+
- General memory buffer for training overhead
|
|
133
|
+
|
|
134
|
+
When using DataLoader with num_workers > 0, worker processes are spawned via fork()
|
|
135
|
+
on Unix systems. While Copy-on-Write (CoW) initially shares memory, Python's reference
|
|
136
|
+
counting can trigger memory page duplication when workers access cached data.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
train_labels: List of `sleap_io.Labels` objects for training data.
|
|
140
|
+
val_labels: List of `sleap_io.Labels` objects for validation data.
|
|
141
|
+
num_workers: Number of DataLoader worker processes. When > 0, additional memory
|
|
142
|
+
overhead is estimated for worker process duplication.
|
|
143
|
+
memory_buffer: Fraction of memory to reserve as buffer for training overhead
|
|
144
|
+
(model weights, activations, gradients, etc.). Default: 0.2 (20%).
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
dict: Memory estimation breakdown with keys:
|
|
148
|
+
- 'raw_cache_bytes': Raw image data size in bytes
|
|
149
|
+
- 'python_overhead_bytes': Estimated Python object overhead
|
|
150
|
+
- 'worker_overhead_bytes': Estimated memory for DataLoader workers
|
|
151
|
+
- 'buffer_bytes': Memory buffer for training overhead
|
|
152
|
+
- 'total_bytes': Total estimated memory requirement
|
|
153
|
+
- 'available_bytes': Available system memory
|
|
154
|
+
- 'sufficient': True if total <= available, False otherwise
|
|
155
|
+
"""
|
|
156
|
+
# Calculate raw image cache size
|
|
157
|
+
train_cache_bytes = 0
|
|
158
|
+
val_cache_bytes = 0
|
|
159
|
+
num_train_samples = 0
|
|
160
|
+
num_val_samples = 0
|
|
161
|
+
|
|
162
|
+
for train, val in zip(train_labels, val_labels):
|
|
163
|
+
train_cache_bytes += check_memory(train)
|
|
164
|
+
val_cache_bytes += check_memory(val)
|
|
165
|
+
num_train_samples += len(train)
|
|
166
|
+
num_val_samples += len(val)
|
|
167
|
+
|
|
168
|
+
raw_cache_bytes = train_cache_bytes + val_cache_bytes
|
|
169
|
+
total_samples = num_train_samples + num_val_samples
|
|
170
|
+
|
|
171
|
+
# Python object overhead: dict keys, numpy array wrappers, tuple keys
|
|
172
|
+
# Estimate ~200 bytes per sample for Python object overhead
|
|
173
|
+
python_overhead_per_sample = 200
|
|
174
|
+
python_overhead_bytes = total_samples * python_overhead_per_sample
|
|
175
|
+
|
|
176
|
+
# Worker memory overhead
|
|
177
|
+
# When num_workers > 0, workers are forked or spawned depending on platform.
|
|
178
|
+
# Default start methods (Python 3.8+):
|
|
179
|
+
# - Linux: fork (Copy-on-Write, partial memory duplication)
|
|
180
|
+
# - macOS: spawn (full dataset copy to each worker, changed in Python 3.8)
|
|
181
|
+
# - Windows: spawn (full dataset copy to each worker)
|
|
182
|
+
worker_overhead_bytes = 0
|
|
183
|
+
if num_workers > 0:
|
|
184
|
+
if sys.platform == "linux":
|
|
185
|
+
# Linux uses fork() with Copy-on-Write by default
|
|
186
|
+
# Estimate 25% duplication per worker due to Python refcounting
|
|
187
|
+
# triggering CoW page copies
|
|
188
|
+
worker_overhead_bytes = int(raw_cache_bytes * 0.25 * num_workers)
|
|
189
|
+
if num_workers >= 4:
|
|
190
|
+
logger.info(
|
|
191
|
+
f"Using in-memory caching with {num_workers} DataLoader workers. "
|
|
192
|
+
f"Estimated additional memory for workers: "
|
|
193
|
+
f"{worker_overhead_bytes / (1024**3):.2f} GB"
|
|
194
|
+
)
|
|
195
|
+
else:
|
|
196
|
+
# macOS (darwin) and Windows use spawn - dataset is copied to each worker
|
|
197
|
+
# Since Python 3.8, macOS defaults to spawn due to fork safety issues
|
|
198
|
+
# With caching enabled, we avoid pickling labels_list, but the cache
|
|
199
|
+
# dict is still part of the dataset and gets copied to each worker
|
|
200
|
+
worker_overhead_bytes = int(raw_cache_bytes * 0.5 * num_workers)
|
|
201
|
+
platform_name = "macOS" if sys.platform == "darwin" else "Windows"
|
|
202
|
+
logger.warning(
|
|
203
|
+
f"Using in-memory caching with {num_workers} DataLoader workers on {platform_name}. "
|
|
204
|
+
f"Memory usage may be significantly higher than estimated (~{worker_overhead_bytes / (1024**3):.1f} GB extra) "
|
|
205
|
+
f"due to spawn-based multiprocessing. "
|
|
206
|
+
f"Consider using disk caching or num_workers=0 for large datasets."
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
# Memory buffer for training overhead (model, gradients, activations)
|
|
210
|
+
subtotal = raw_cache_bytes + python_overhead_bytes + worker_overhead_bytes
|
|
211
|
+
buffer_bytes = int(subtotal * memory_buffer)
|
|
212
|
+
|
|
213
|
+
total_bytes = subtotal + buffer_bytes
|
|
214
|
+
available_bytes = psutil.virtual_memory().available
|
|
215
|
+
|
|
216
|
+
return {
|
|
217
|
+
"raw_cache_bytes": raw_cache_bytes,
|
|
218
|
+
"python_overhead_bytes": python_overhead_bytes,
|
|
219
|
+
"worker_overhead_bytes": worker_overhead_bytes,
|
|
220
|
+
"buffer_bytes": buffer_bytes,
|
|
221
|
+
"total_bytes": total_bytes,
|
|
222
|
+
"available_bytes": available_bytes,
|
|
223
|
+
"sufficient": total_bytes <= available_bytes,
|
|
224
|
+
"num_samples": total_samples,
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
|
|
118
228
|
def check_cache_memory(
|
|
119
229
|
train_labels: List[sio.Labels],
|
|
120
230
|
val_labels: List[sio.Labels],
|
|
121
231
|
memory_buffer: float = 0.2,
|
|
232
|
+
num_workers: int = 0,
|
|
122
233
|
) -> bool:
|
|
123
234
|
"""Check memory requirements for in-memory caching dataset pipeline.
|
|
124
235
|
|
|
236
|
+
This function determines if the system has sufficient memory for in-memory
|
|
237
|
+
image caching, accounting for DataLoader worker processes.
|
|
238
|
+
|
|
125
239
|
Args:
|
|
126
240
|
train_labels: List of `sleap_io.Labels` objects for training data.
|
|
127
241
|
val_labels: List of `sleap_io.Labels` objects for validation data.
|
|
128
|
-
memory_buffer: Fraction of
|
|
129
|
-
|
|
242
|
+
memory_buffer: Fraction of memory to reserve as buffer. Default: 0.2 (20%).
|
|
243
|
+
num_workers: Number of DataLoader worker processes. When > 0, additional memory
|
|
244
|
+
overhead is estimated for worker process duplication.
|
|
130
245
|
|
|
131
246
|
Returns:
|
|
132
247
|
bool: True if the total memory required for caching is within available system
|
|
133
248
|
memory, False otherwise.
|
|
134
249
|
"""
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
250
|
+
estimate = estimate_cache_memory(
|
|
251
|
+
train_labels=train_labels,
|
|
252
|
+
val_labels=val_labels,
|
|
253
|
+
num_workers=num_workers,
|
|
254
|
+
memory_buffer=memory_buffer,
|
|
255
|
+
)
|
|
256
|
+
|
|
257
|
+
if not estimate["sufficient"]:
|
|
258
|
+
total_gb = estimate["total_bytes"] / (1024**3)
|
|
259
|
+
available_gb = estimate["available_bytes"] / (1024**3)
|
|
260
|
+
raw_gb = estimate["raw_cache_bytes"] / (1024**3)
|
|
261
|
+
logger.info(
|
|
262
|
+
f"Memory check failed: need ~{total_gb:.2f} GB "
|
|
263
|
+
f"(raw cache: {raw_gb:.2f} GB, {estimate['num_samples']} samples), "
|
|
264
|
+
f"available: {available_gb:.2f} GB"
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
return estimate["sufficient"]
|
sleap_nn/evaluation.py
CHANGED
|
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
|
|
|
29
29
|
"""
|
|
30
30
|
instance_list = []
|
|
31
31
|
frame_idx = labeled_frame.frame_idx
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
32
|
+
|
|
33
|
+
# Extract video path with fallbacks for embedded videos
|
|
34
|
+
video = labeled_frame.video
|
|
35
|
+
video_path = None
|
|
36
|
+
if video is not None:
|
|
37
|
+
backend = getattr(video, "backend", None)
|
|
38
|
+
if backend is not None:
|
|
39
|
+
# Try source_filename first (for embedded videos with provenance)
|
|
40
|
+
video_path = getattr(backend, "source_filename", None)
|
|
41
|
+
if video_path is None:
|
|
42
|
+
video_path = getattr(backend, "filename", None)
|
|
43
|
+
# Fallback to video.filename if backend doesn't have it
|
|
44
|
+
if video_path is None:
|
|
45
|
+
video_path = getattr(video, "filename", None)
|
|
46
|
+
# Handle list filenames (image sequences)
|
|
47
|
+
if isinstance(video_path, list) and video_path:
|
|
48
|
+
video_path = video_path[0]
|
|
49
|
+
# Final fallback: use a unique identifier
|
|
50
|
+
if video_path is None:
|
|
51
|
+
video_path = f"video_{id(video)}" if video is not None else "unknown"
|
|
52
|
+
|
|
37
53
|
for instance in labeled_frame.instances:
|
|
38
54
|
match_instance = MatchInstance(
|
|
39
55
|
instance=instance, frame_idx=frame_idx, video_path=video_path
|
|
@@ -47,6 +63,10 @@ def find_frame_pairs(
|
|
|
47
63
|
) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
|
|
48
64
|
"""Find corresponding frames across two sets of labels.
|
|
49
65
|
|
|
66
|
+
This function uses sleap-io's robust video matching API to handle various
|
|
67
|
+
scenarios including embedded videos, cross-platform paths, and videos with
|
|
68
|
+
different metadata.
|
|
69
|
+
|
|
50
70
|
Args:
|
|
51
71
|
labels_gt: A `sio.Labels` instance with ground truth instances.
|
|
52
72
|
labels_pr: A `sio.Labels` instance with predicted instances.
|
|
@@ -56,25 +76,15 @@ def find_frame_pairs(
|
|
|
56
76
|
Returns:
|
|
57
77
|
A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
|
|
58
78
|
"""
|
|
79
|
+
# Use sleap-io's robust video matching API (added in 0.6.2)
|
|
80
|
+
# The match() method returns a MatchResult with video_map: {pred_video: gt_video}
|
|
81
|
+
match_result = labels_gt.match(labels_pr)
|
|
82
|
+
|
|
59
83
|
frame_pairs = []
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
if (
|
|
65
|
-
isinstance(video.backend, type(video_gt.backend))
|
|
66
|
-
and video.filename == video_gt.filename
|
|
67
|
-
):
|
|
68
|
-
same_dataset = (
|
|
69
|
-
(video.backend.dataset == video_gt.backend.dataset)
|
|
70
|
-
if hasattr(video.backend, "dataset")
|
|
71
|
-
else True
|
|
72
|
-
) # `dataset` attr exists only for hdf5 backend not for mediavideo
|
|
73
|
-
if same_dataset:
|
|
74
|
-
video_pr = video
|
|
75
|
-
break
|
|
76
|
-
|
|
77
|
-
if video_pr is None:
|
|
84
|
+
# Iterate over matched video pairs (pred_video -> gt_video mapping)
|
|
85
|
+
for video_pr, video_gt in match_result.video_map.items():
|
|
86
|
+
if video_gt is None:
|
|
87
|
+
# No match found for this prediction video
|
|
78
88
|
continue
|
|
79
89
|
|
|
80
90
|
# Find labeled frames in this video.
|
|
@@ -629,11 +639,19 @@ class Evaluator:
|
|
|
629
639
|
mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
|
|
630
640
|
mPCK = mPCK_parts.mean()
|
|
631
641
|
|
|
642
|
+
# Precompute PCK at common thresholds
|
|
643
|
+
idx_5 = np.argmin(np.abs(thresholds - 5))
|
|
644
|
+
idx_10 = np.argmin(np.abs(thresholds - 10))
|
|
645
|
+
pck5 = pcks[:, :, idx_5].mean()
|
|
646
|
+
pck10 = pcks[:, :, idx_10].mean()
|
|
647
|
+
|
|
632
648
|
return {
|
|
633
649
|
"thresholds": thresholds,
|
|
634
650
|
"pcks": pcks,
|
|
635
651
|
"mPCK_parts": mPCK_parts,
|
|
636
652
|
"mPCK": mPCK,
|
|
653
|
+
"PCK@5": pck5,
|
|
654
|
+
"PCK@10": pck10,
|
|
637
655
|
}
|
|
638
656
|
|
|
639
657
|
def visibility_metrics(self):
|
|
@@ -678,24 +696,109 @@ class Evaluator:
|
|
|
678
696
|
return metrics
|
|
679
697
|
|
|
680
698
|
|
|
681
|
-
def
|
|
682
|
-
"""
|
|
699
|
+
def _find_metrics_file(model_dir: Path, split: str, dataset_idx: int) -> Path:
|
|
700
|
+
"""Find the metrics file in a model directory.
|
|
701
|
+
|
|
702
|
+
Tries new naming format first, then falls back to old format.
|
|
703
|
+
If split is "test" and not found, falls back to "val".
|
|
704
|
+
"""
|
|
705
|
+
# Try new naming format first: metrics.{split}.{idx}.npz
|
|
706
|
+
metrics_path = model_dir / f"metrics.{split}.{dataset_idx}.npz"
|
|
707
|
+
if metrics_path.exists():
|
|
708
|
+
return metrics_path
|
|
709
|
+
|
|
710
|
+
# Fall back to old naming format: {split}_{idx}_pred_metrics.npz
|
|
711
|
+
metrics_path = model_dir / f"{split}_{dataset_idx}_pred_metrics.npz"
|
|
712
|
+
if metrics_path.exists():
|
|
713
|
+
return metrics_path
|
|
714
|
+
|
|
715
|
+
# If split is "test" and not found, try "val" fallback
|
|
716
|
+
if split == "test":
|
|
717
|
+
return _find_metrics_file(model_dir, "val", dataset_idx)
|
|
718
|
+
|
|
719
|
+
# Return the new format path (will raise FileNotFoundError later)
|
|
720
|
+
return model_dir / f"metrics.{split}.{dataset_idx}.npz"
|
|
721
|
+
|
|
722
|
+
|
|
723
|
+
def _load_npz_metrics(metrics_path: Path) -> dict:
|
|
724
|
+
"""Load metrics from an npz file, supporting both old and new formats.
|
|
725
|
+
|
|
726
|
+
New format: single "metrics" key containing a dict with all metrics.
|
|
727
|
+
Old format: individual metric keys at top level (voc_metrics, mOKS, etc.).
|
|
728
|
+
"""
|
|
729
|
+
with np.load(metrics_path, allow_pickle=True) as data:
|
|
730
|
+
keys = list(data.keys())
|
|
731
|
+
|
|
732
|
+
# New format: single "metrics" key containing dict
|
|
733
|
+
if "metrics" in keys:
|
|
734
|
+
return data["metrics"].item()
|
|
735
|
+
|
|
736
|
+
# Old format: individual metric keys at top level
|
|
737
|
+
expected_keys = {
|
|
738
|
+
"voc_metrics",
|
|
739
|
+
"mOKS",
|
|
740
|
+
"distance_metrics",
|
|
741
|
+
"pck_metrics",
|
|
742
|
+
"visibility_metrics",
|
|
743
|
+
}
|
|
744
|
+
if expected_keys.issubset(set(keys)):
|
|
745
|
+
return {
|
|
746
|
+
k: data[k].item() if data[k].ndim == 0 else data[k]
|
|
747
|
+
for k in expected_keys
|
|
748
|
+
}
|
|
749
|
+
|
|
750
|
+
# Unknown format - return all keys as dict
|
|
751
|
+
return {k: data[k].item() if data[k].ndim == 0 else data[k] for k in keys}
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def load_metrics(
|
|
755
|
+
path: str,
|
|
756
|
+
split: str = "test",
|
|
757
|
+
dataset_idx: int = 0,
|
|
758
|
+
) -> dict:
|
|
759
|
+
"""Load metrics from a model folder or metrics file.
|
|
760
|
+
|
|
761
|
+
This function supports both the new format (single "metrics" key) and the old
|
|
762
|
+
format (individual metric keys at top level). It also handles both old and new
|
|
763
|
+
file naming conventions in model folders.
|
|
683
764
|
|
|
684
765
|
Args:
|
|
685
|
-
|
|
686
|
-
split: Name of the split to load
|
|
687
|
-
|
|
688
|
-
|
|
766
|
+
path: Path to a model folder or metrics file (.npz).
|
|
767
|
+
split: Name of the split to load. Must be "train", "val", or "test".
|
|
768
|
+
Default: "test". If "test" is not found, falls back to "val".
|
|
769
|
+
Ignored if path points directly to a .npz file.
|
|
770
|
+
dataset_idx: Index of the dataset (for multi-dataset training).
|
|
771
|
+
Default: 0. Ignored if path points directly to a .npz file.
|
|
772
|
+
|
|
773
|
+
Returns:
|
|
774
|
+
Dictionary containing metrics with keys: voc_metrics, mOKS,
|
|
775
|
+
distance_metrics, pck_metrics, visibility_metrics.
|
|
776
|
+
|
|
777
|
+
Raises:
|
|
778
|
+
FileNotFoundError: If no metrics file is found.
|
|
779
|
+
|
|
780
|
+
Examples:
|
|
781
|
+
>>> # Load from model folder (tries test, falls back to val)
|
|
782
|
+
>>> metrics = load_metrics("/path/to/model")
|
|
783
|
+
>>> print(metrics["mOKS"]["mOKS"])
|
|
689
784
|
|
|
785
|
+
>>> # Load specific split and dataset
|
|
786
|
+
>>> metrics = load_metrics("/path/to/model", split="val", dataset_idx=1)
|
|
787
|
+
|
|
788
|
+
>>> # Load directly from npz file
|
|
789
|
+
>>> metrics = load_metrics("/path/to/metrics.val.0.npz")
|
|
690
790
|
"""
|
|
691
|
-
|
|
692
|
-
|
|
791
|
+
path = Path(path)
|
|
792
|
+
|
|
793
|
+
if path.suffix == ".npz":
|
|
794
|
+
metrics_path = path
|
|
693
795
|
else:
|
|
694
|
-
metrics_path =
|
|
796
|
+
metrics_path = _find_metrics_file(path, split, dataset_idx)
|
|
797
|
+
|
|
695
798
|
if not metrics_path.exists():
|
|
696
799
|
raise FileNotFoundError(f"Metrics file not found at {metrics_path}")
|
|
697
|
-
|
|
698
|
-
|
|
800
|
+
|
|
801
|
+
return _load_npz_metrics(metrics_path)
|
|
699
802
|
|
|
700
803
|
|
|
701
804
|
def run_evaluation(
|
|
@@ -710,11 +813,26 @@ def run_evaluation(
|
|
|
710
813
|
"""Evaluate SLEAP-NN model predictions against ground truth labels."""
|
|
711
814
|
logger.info("Loading ground truth labels...")
|
|
712
815
|
ground_truth_instances = sio.load_slp(ground_truth_path)
|
|
816
|
+
logger.info(
|
|
817
|
+
f" Ground truth: {len(ground_truth_instances.videos)} videos, "
|
|
818
|
+
f"{len(ground_truth_instances.labeled_frames)} frames"
|
|
819
|
+
)
|
|
713
820
|
|
|
714
821
|
logger.info("Loading predicted labels...")
|
|
715
822
|
predicted_instances = sio.load_slp(predicted_path)
|
|
823
|
+
logger.info(
|
|
824
|
+
f" Predictions: {len(predicted_instances.videos)} videos, "
|
|
825
|
+
f"{len(predicted_instances.labeled_frames)} frames"
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
logger.info("Matching videos and frames...")
|
|
829
|
+
# Get match stats before creating evaluator
|
|
830
|
+
match_result = ground_truth_instances.match(predicted_instances)
|
|
831
|
+
logger.info(
|
|
832
|
+
f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
|
|
833
|
+
)
|
|
716
834
|
|
|
717
|
-
logger.info("
|
|
835
|
+
logger.info("Matching instances...")
|
|
718
836
|
evaluator = Evaluator(
|
|
719
837
|
ground_truth_instances=ground_truth_instances,
|
|
720
838
|
predicted_instances=predicted_instances,
|
|
@@ -723,21 +841,38 @@ def run_evaluation(
|
|
|
723
841
|
match_threshold=match_threshold,
|
|
724
842
|
user_labels_only=user_labels_only,
|
|
725
843
|
)
|
|
844
|
+
logger.info(
|
|
845
|
+
f" Frame pairs: {len(evaluator.frame_pairs)}, "
|
|
846
|
+
f"Matched instances: {len(evaluator.positive_pairs)}, "
|
|
847
|
+
f"Unmatched GT: {len(evaluator.false_negatives)}"
|
|
848
|
+
)
|
|
726
849
|
|
|
727
850
|
logger.info("Computing evaluation metrics...")
|
|
728
851
|
metrics = evaluator.evaluate()
|
|
729
852
|
|
|
853
|
+
# Compute PCK at specific thresholds (5 and 10 pixels)
|
|
854
|
+
dists = metrics["distance_metrics"]["dists"]
|
|
855
|
+
dists_clean = np.copy(dists)
|
|
856
|
+
dists_clean[np.isnan(dists_clean)] = np.inf
|
|
857
|
+
pck_5 = (dists_clean < 5).mean()
|
|
858
|
+
pck_10 = (dists_clean < 10).mean()
|
|
859
|
+
|
|
730
860
|
# Print key metrics
|
|
731
861
|
logger.info("Evaluation Results:")
|
|
732
|
-
logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
|
|
733
|
-
logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
|
|
734
|
-
logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
|
|
735
|
-
logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.
|
|
736
|
-
logger.info(f"
|
|
862
|
+
logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
|
|
863
|
+
logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
|
|
864
|
+
logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
|
|
865
|
+
logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
|
|
866
|
+
logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
|
|
867
|
+
logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
|
|
868
|
+
logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
|
|
869
|
+
logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
|
|
870
|
+
logger.info(f" PCK@5px: {pck_5:.4f}")
|
|
871
|
+
logger.info(f" PCK@10px: {pck_10:.4f}")
|
|
737
872
|
logger.info(
|
|
738
|
-
f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
|
|
873
|
+
f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
|
|
739
874
|
)
|
|
740
|
-
logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
|
|
875
|
+
logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
|
|
741
876
|
|
|
742
877
|
# Save metrics if path provided
|
|
743
878
|
if save_metrics:
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Export utilities for sleap-nn."""
|
|
2
|
+
|
|
3
|
+
from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
|
|
4
|
+
from sleap_nn.export.metadata import ExportMetadata
|
|
5
|
+
from sleap_nn.export.predictors import (
|
|
6
|
+
load_exported_model,
|
|
7
|
+
ONNXPredictor,
|
|
8
|
+
TensorRTPredictor,
|
|
9
|
+
)
|
|
10
|
+
from sleap_nn.export.utils import build_bottomup_candidate_template
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"export_model",
|
|
14
|
+
"export_to_onnx",
|
|
15
|
+
"export_to_tensorrt",
|
|
16
|
+
"load_exported_model",
|
|
17
|
+
"ONNXPredictor",
|
|
18
|
+
"TensorRTPredictor",
|
|
19
|
+
"ExportMetadata",
|
|
20
|
+
"build_bottomup_candidate_template",
|
|
21
|
+
]
|