sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff reflects the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/data/utils.py CHANGED
@@ -1,12 +1,14 @@
 """Miscellaneous utility functions for data processing."""
 
 from typing import Tuple, List, Any, Optional
+import sys
 import torch
 from omegaconf import DictConfig
 import sleap_io as sio
 from sleap_nn.config.utils import get_model_type_from_cfg
 import psutil
 import numpy as np
+from loguru import logger
 from sleap_nn.data.providers import get_max_instances
 
 
@@ -115,35 +117,151 @@ def check_memory(
     return img_mem
 
 
+def estimate_cache_memory(
+    train_labels: List[sio.Labels],
+    val_labels: List[sio.Labels],
+    num_workers: int = 0,
+    memory_buffer: float = 0.2,
+) -> dict:
+    """Estimate memory requirements for in-memory caching dataset pipeline.
+
+    This function calculates the total memory needed for caching images, accounting for:
+    - Raw image data size
+    - Python object overhead (dictionary keys, numpy array wrappers)
+    - DataLoader worker memory overhead (Copy-on-Write duplication on Unix systems)
+    - General memory buffer for training overhead
+
+    When using DataLoader with num_workers > 0, worker processes are spawned via fork()
+    on Unix systems. While Copy-on-Write (CoW) initially shares memory, Python's reference
+    counting can trigger memory page duplication when workers access cached data.
+
+    Args:
+        train_labels: List of `sleap_io.Labels` objects for training data.
+        val_labels: List of `sleap_io.Labels` objects for validation data.
+        num_workers: Number of DataLoader worker processes. When > 0, additional memory
+            overhead is estimated for worker process duplication.
+        memory_buffer: Fraction of memory to reserve as buffer for training overhead
+            (model weights, activations, gradients, etc.). Default: 0.2 (20%).
+
+    Returns:
+        dict: Memory estimation breakdown with keys:
+            - 'raw_cache_bytes': Raw image data size in bytes
+            - 'python_overhead_bytes': Estimated Python object overhead
+            - 'worker_overhead_bytes': Estimated memory for DataLoader workers
+            - 'buffer_bytes': Memory buffer for training overhead
+            - 'total_bytes': Total estimated memory requirement
+            - 'available_bytes': Available system memory
+            - 'sufficient': True if total <= available, False otherwise
+    """
+    # Calculate raw image cache size
+    train_cache_bytes = 0
+    val_cache_bytes = 0
+    num_train_samples = 0
+    num_val_samples = 0
+
+    for train, val in zip(train_labels, val_labels):
+        train_cache_bytes += check_memory(train)
+        val_cache_bytes += check_memory(val)
+        num_train_samples += len(train)
+        num_val_samples += len(val)
+
+    raw_cache_bytes = train_cache_bytes + val_cache_bytes
+    total_samples = num_train_samples + num_val_samples
+
+    # Python object overhead: dict keys, numpy array wrappers, tuple keys
+    # Estimate ~200 bytes per sample for Python object overhead
+    python_overhead_per_sample = 200
+    python_overhead_bytes = total_samples * python_overhead_per_sample
+
+    # Worker memory overhead
+    # When num_workers > 0, workers are forked or spawned depending on platform.
+    # Default start methods (Python 3.8+):
+    # - Linux: fork (Copy-on-Write, partial memory duplication)
+    # - macOS: spawn (full dataset copy to each worker, changed in Python 3.8)
+    # - Windows: spawn (full dataset copy to each worker)
+    worker_overhead_bytes = 0
+    if num_workers > 0:
+        if sys.platform == "linux":
+            # Linux uses fork() with Copy-on-Write by default
+            # Estimate 25% duplication per worker due to Python refcounting
+            # triggering CoW page copies
+            worker_overhead_bytes = int(raw_cache_bytes * 0.25 * num_workers)
+            if num_workers >= 4:
+                logger.info(
+                    f"Using in-memory caching with {num_workers} DataLoader workers. "
+                    f"Estimated additional memory for workers: "
+                    f"{worker_overhead_bytes / (1024**3):.2f} GB"
+                )
+        else:
+            # macOS (darwin) and Windows use spawn - dataset is copied to each worker
+            # Since Python 3.8, macOS defaults to spawn due to fork safety issues
+            # With caching enabled, we avoid pickling labels_list, but the cache
+            # dict is still part of the dataset and gets copied to each worker
+            worker_overhead_bytes = int(raw_cache_bytes * 0.5 * num_workers)
+            platform_name = "macOS" if sys.platform == "darwin" else "Windows"
+            logger.warning(
+                f"Using in-memory caching with {num_workers} DataLoader workers on {platform_name}. "
+                f"Memory usage may be significantly higher than estimated (~{worker_overhead_bytes / (1024**3):.1f} GB extra) "
+                f"due to spawn-based multiprocessing. "
+                f"Consider using disk caching or num_workers=0 for large datasets."
+            )
+
+    # Memory buffer for training overhead (model, gradients, activations)
+    subtotal = raw_cache_bytes + python_overhead_bytes + worker_overhead_bytes
+    buffer_bytes = int(subtotal * memory_buffer)
+
+    total_bytes = subtotal + buffer_bytes
+    available_bytes = psutil.virtual_memory().available
+
+    return {
+        "raw_cache_bytes": raw_cache_bytes,
+        "python_overhead_bytes": python_overhead_bytes,
+        "worker_overhead_bytes": worker_overhead_bytes,
+        "buffer_bytes": buffer_bytes,
+        "total_bytes": total_bytes,
+        "available_bytes": available_bytes,
+        "sufficient": total_bytes <= available_bytes,
+        "num_samples": total_samples,
+    }
+
+
 def check_cache_memory(
     train_labels: List[sio.Labels],
     val_labels: List[sio.Labels],
     memory_buffer: float = 0.2,
+    num_workers: int = 0,
 ) -> bool:
     """Check memory requirements for in-memory caching dataset pipeline.
 
+    This function determines if the system has sufficient memory for in-memory
+    image caching, accounting for DataLoader worker processes.
+
     Args:
         train_labels: List of `sleap_io.Labels` objects for training data.
         val_labels: List of `sleap_io.Labels` objects for validation data.
-        memory_buffer: Fraction of the total image memory required for caching that
-            should be reserved as a buffer.
+        memory_buffer: Fraction of memory to reserve as buffer. Default: 0.2 (20%).
+        num_workers: Number of DataLoader worker processes. When > 0, additional memory
+            overhead is estimated for worker process duplication.
 
     Returns:
         bool: True if the total memory required for caching is within available system
            memory, False otherwise.
    """
-    train_cache_memory_final = 0
-    val_cache_memory_final = 0
-    for train, val in zip(train_labels, val_labels):
-        train_cache_memory = check_memory(train)
-        val_cache_memory = check_memory(val)
-        train_cache_memory_final += train_cache_memory
-        val_cache_memory_final += val_cache_memory
-
-    total_cache_memory = train_cache_memory_final + val_cache_memory_final
-    total_cache_memory += memory_buffer * total_cache_memory  # memory required in bytes
-    available_memory = psutil.virtual_memory().available  # available memory in bytes
-
-    if total_cache_memory > available_memory:
-        return False
-    return True
+    estimate = estimate_cache_memory(
+        train_labels=train_labels,
+        val_labels=val_labels,
+        num_workers=num_workers,
+        memory_buffer=memory_buffer,
+    )
+
+    if not estimate["sufficient"]:
+        total_gb = estimate["total_bytes"] / (1024**3)
+        available_gb = estimate["available_bytes"] / (1024**3)
+        raw_gb = estimate["raw_cache_bytes"] / (1024**3)
+        logger.info(
+            f"Memory check failed: need ~{total_gb:.2f} GB "
+            f"(raw cache: {raw_gb:.2f} GB, {estimate['num_samples']} samples), "
+            f"available: {available_gb:.2f} GB"
+        )
+
+    return estimate["sufficient"]
sleap_nn/evaluation.py CHANGED
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
     """
     instance_list = []
    frame_idx = labeled_frame.frame_idx
-    video_path = (
-        labeled_frame.video.backend.source_filename
-        if hasattr(labeled_frame.video.backend, "source_filename")
-        else labeled_frame.video.backend.filename
-    )
+
+    # Extract video path with fallbacks for embedded videos
+    video = labeled_frame.video
+    video_path = None
+    if video is not None:
+        backend = getattr(video, "backend", None)
+        if backend is not None:
+            # Try source_filename first (for embedded videos with provenance)
+            video_path = getattr(backend, "source_filename", None)
+            if video_path is None:
+                video_path = getattr(backend, "filename", None)
+        # Fallback to video.filename if backend doesn't have it
+        if video_path is None:
+            video_path = getattr(video, "filename", None)
+        # Handle list filenames (image sequences)
+        if isinstance(video_path, list) and video_path:
+            video_path = video_path[0]
+    # Final fallback: use a unique identifier
+    if video_path is None:
+        video_path = f"video_{id(video)}" if video is not None else "unknown"
+
     for instance in labeled_frame.instances:
         match_instance = MatchInstance(
             instance=instance, frame_idx=frame_idx, video_path=video_path
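
Note: a short sketch of the effect of this change; the .slp path is hypothetical, while `get_instances` and `MatchInstance` are the module's own names from the context above.

    import sleap_io as sio
    from sleap_nn.evaluation import get_instances

    labels = sio.load_slp("predictions.slp")  # hypothetical file
    for mi in get_instances(labels.labeled_frames[0]):
        # video_path now resolves for embedded videos and image sequences,
        # and falls back to a synthetic identifier instead of raising.
        print(mi.frame_idx, mi.video_path)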
@@ -47,6 +63,10 @@ def find_frame_pairs(
 ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
     """Find corresponding frames across two sets of labels.
 
+    This function uses sleap-io's robust video matching API to handle various
+    scenarios including embedded videos, cross-platform paths, and videos with
+    different metadata.
+
     Args:
         labels_gt: A `sio.Labels` instance with ground truth instances.
         labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,25 +76,15 @@
     Returns:
         A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
     """
+    # Use sleap-io's robust video matching API (added in 0.6.2)
+    # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
+    match_result = labels_gt.match(labels_pr)
+
     frame_pairs = []
-    for video_gt in labels_gt.videos:
-        # Find matching video instance in predictions.
-        video_pr = None
-        for video in labels_pr.videos:
-            if (
-                isinstance(video.backend, type(video_gt.backend))
-                and video.filename == video_gt.filename
-            ):
-                same_dataset = (
-                    (video.backend.dataset == video_gt.backend.dataset)
-                    if hasattr(video.backend, "dataset")
-                    else True
-                )  # `dataset` attr exists only for hdf5 backend not for mediavideo
-                if same_dataset:
-                    video_pr = video
-                    break
-
-        if video_pr is None:
+    # Iterate over matched video pairs (pred_video -> gt_video mapping)
+    for video_pr, video_gt in match_result.video_map.items():
+        if video_gt is None:
+            # No match found for this prediction video
             continue
 
         # Find labeled frames in this video.
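
Note: a sketch of the matching flow the new code relies on, assuming sleap-io >= 0.6.2 as the diff comment states; the paths are hypothetical.

    import sleap_io as sio

    labels_gt = sio.load_slp("ground_truth.slp")
    labels_pr = sio.load_slp("predictions.slp")

    # Per the comments above, match() returns a MatchResult whose video_map
    # maps each prediction video to its ground-truth counterpart (or None).
    match_result = labels_gt.match(labels_pr)
    for video_pr, video_gt in match_result.video_map.items():
        if video_gt is None:
            continue  # prediction video with no ground-truth match
        print(video_pr.filename, "->", video_gt.filename)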
@@ -629,11 +639,19 @@ class Evaluator:
         mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
         mPCK = mPCK_parts.mean()
 
+        # Precompute PCK at common thresholds
+        idx_5 = np.argmin(np.abs(thresholds - 5))
+        idx_10 = np.argmin(np.abs(thresholds - 10))
+        pck5 = pcks[:, :, idx_5].mean()
+        pck10 = pcks[:, :, idx_10].mean()
+
         return {
             "thresholds": thresholds,
             "pcks": pcks,
             "mPCK_parts": mPCK_parts,
             "mPCK": mPCK,
+            "PCK@5": pck5,
+            "PCK@10": pck10,
         }
 
     def visibility_metrics(self):
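
Note: `np.argmin(np.abs(thresholds - 5))` selects the grid point nearest 5 px, so "PCK@5"/"PCK@10" are exact only when those values appear in `thresholds`. A tiny illustration with a hypothetical grid:

    import numpy as np

    thresholds = np.linspace(1, 20, 20)  # hypothetical 1..20 px grid
    idx_5 = np.argmin(np.abs(thresholds - 5))    # index of 5.0
    idx_10 = np.argmin(np.abs(thresholds - 10))  # index of 10.0
    assert thresholds[idx_5] == 5.0 and thresholds[idx_10] == 10.0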
@@ -678,24 +696,109 @@
     return metrics
 
 
-def load_metrics(model_path: str, split="val"):
-    """Load the metrics for a given model and split.
+def _find_metrics_file(model_dir: Path, split: str, dataset_idx: int) -> Path:
+    """Find the metrics file in a model directory.
+
+    Tries new naming format first, then falls back to old format.
+    If split is "test" and not found, falls back to "val".
+    """
+    # Try new naming format first: metrics.{split}.{idx}.npz
+    metrics_path = model_dir / f"metrics.{split}.{dataset_idx}.npz"
+    if metrics_path.exists():
+        return metrics_path
+
+    # Fall back to old naming format: {split}_{idx}_pred_metrics.npz
+    metrics_path = model_dir / f"{split}_{dataset_idx}_pred_metrics.npz"
+    if metrics_path.exists():
+        return metrics_path
+
+    # If split is "test" and not found, try "val" fallback
+    if split == "test":
+        return _find_metrics_file(model_dir, "val", dataset_idx)
+
+    # Return the new format path (will raise FileNotFoundError later)
+    return model_dir / f"metrics.{split}.{dataset_idx}.npz"
+
+
+def _load_npz_metrics(metrics_path: Path) -> dict:
+    """Load metrics from an npz file, supporting both old and new formats.
+
+    New format: single "metrics" key containing a dict with all metrics.
+    Old format: individual metric keys at top level (voc_metrics, mOKS, etc.).
+    """
+    with np.load(metrics_path, allow_pickle=True) as data:
+        keys = list(data.keys())
+
+        # New format: single "metrics" key containing dict
+        if "metrics" in keys:
+            return data["metrics"].item()
+
+        # Old format: individual metric keys at top level
+        expected_keys = {
+            "voc_metrics",
+            "mOKS",
+            "distance_metrics",
+            "pck_metrics",
+            "visibility_metrics",
+        }
+        if expected_keys.issubset(set(keys)):
+            return {
+                k: data[k].item() if data[k].ndim == 0 else data[k]
+                for k in expected_keys
+            }
+
+        # Unknown format - return all keys as dict
+        return {k: data[k].item() if data[k].ndim == 0 else data[k] for k in keys}
+
+
+def load_metrics(
+    path: str,
+    split: str = "test",
+    dataset_idx: int = 0,
+) -> dict:
+    """Load metrics from a model folder or metrics file.
+
+    This function supports both the new format (single "metrics" key) and the old
+    format (individual metric keys at top level). It also handles both old and new
+    file naming conventions in model folders.
 
     Args:
-        model_path: Path to a model folder or metrics file (.npz).
-        split: Name of the split to load the metrics for. Must be `"train"`, `"val"` or
-            `"test"` (default: `"val"`). Ignored if a path to a metrics NPZ file is
-            provided.
+        path: Path to a model folder or metrics file (.npz).
+        split: Name of the split to load. Must be "train", "val", or "test".
+            Default: "test". If "test" is not found, falls back to "val".
+            Ignored if path points directly to a .npz file.
+        dataset_idx: Index of the dataset (for multi-dataset training).
+            Default: 0. Ignored if path points directly to a .npz file.
+
+    Returns:
+        Dictionary containing metrics with keys: voc_metrics, mOKS,
+        distance_metrics, pck_metrics, visibility_metrics.
+
+    Raises:
+        FileNotFoundError: If no metrics file is found.
+
+    Examples:
+        >>> # Load from model folder (tries test, falls back to val)
+        >>> metrics = load_metrics("/path/to/model")
+        >>> print(metrics["mOKS"]["mOKS"])
 
+        >>> # Load specific split and dataset
+        >>> metrics = load_metrics("/path/to/model", split="val", dataset_idx=1)
+
+        >>> # Load directly from npz file
+        >>> metrics = load_metrics("/path/to/metrics.val.0.npz")
     """
-    if Path(model_path).suffix == ".npz":
-        metrics_path = Path(model_path)
+    path = Path(path)
+
+    if path.suffix == ".npz":
+        metrics_path = path
     else:
-        metrics_path = Path(model_path) / f"{split}_0_pred_metrics.npz"
+        metrics_path = _find_metrics_file(path, split, dataset_idx)
+
     if not metrics_path.exists():
         raise FileNotFoundError(f"Metrics file not found at {metrics_path}")
-    with np.load(metrics_path, allow_pickle=True) as data:
-        return data["metrics"].item()
+
+    return _load_npz_metrics(metrics_path)
 
 
 def run_evaluation(
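
Note: for context, the two on-disk layouts `_load_npz_metrics` accepts could be produced like this (a sketch with placeholder values; the key names follow the docstrings above).

    import numpy as np

    metrics = {"mOKS": {"mOKS": 0.9}}  # placeholder metrics dict

    # New format: everything pickled under a single "metrics" key.
    np.savez("metrics.val.0.npz", metrics=metrics)

    # Old format: each metric dict stored as its own top-level key.
    np.savez(
        "val_0_pred_metrics.npz",
        voc_metrics={}, mOKS={"mOKS": 0.9}, distance_metrics={},
        pck_metrics={}, visibility_metrics={},
    )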
@@ -710,11 +813,26 @@
     """Evaluate SLEAP-NN model predictions against ground truth labels."""
     logger.info("Loading ground truth labels...")
     ground_truth_instances = sio.load_slp(ground_truth_path)
+    logger.info(
+        f" Ground truth: {len(ground_truth_instances.videos)} videos, "
+        f"{len(ground_truth_instances.labeled_frames)} frames"
+    )
 
     logger.info("Loading predicted labels...")
     predicted_instances = sio.load_slp(predicted_path)
+    logger.info(
+        f" Predictions: {len(predicted_instances.videos)} videos, "
+        f"{len(predicted_instances.labeled_frames)} frames"
+    )
+
+    logger.info("Matching videos and frames...")
+    # Get match stats before creating evaluator
+    match_result = ground_truth_instances.match(predicted_instances)
+    logger.info(
+        f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
+    )
 
-    logger.info("Creating evaluator...")
+    logger.info("Matching instances...")
     evaluator = Evaluator(
         ground_truth_instances=ground_truth_instances,
         predicted_instances=predicted_instances,
@@ -723,21 +841,38 @@
         match_threshold=match_threshold,
         user_labels_only=user_labels_only,
     )
+    logger.info(
+        f" Frame pairs: {len(evaluator.frame_pairs)}, "
+        f"Matched instances: {len(evaluator.positive_pairs)}, "
+        f"Unmatched GT: {len(evaluator.false_negatives)}"
+    )
 
     logger.info("Computing evaluation metrics...")
     metrics = evaluator.evaluate()
 
+    # Compute PCK at specific thresholds (5 and 10 pixels)
+    dists = metrics["distance_metrics"]["dists"]
+    dists_clean = np.copy(dists)
+    dists_clean[np.isnan(dists_clean)] = np.inf
+    pck_5 = (dists_clean < 5).mean()
+    pck_10 = (dists_clean < 10).mean()
+
     # Print key metrics
     logger.info("Evaluation Results:")
-    logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
-    logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
-    logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
-    logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.4f}")
-    logger.info(f"mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+    logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
+    logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+    logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+    logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
+    logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
+    logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
+    logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
+    logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+    logger.info(f" PCK@5px: {pck_5:.4f}")
+    logger.info(f" PCK@10px: {pck_10:.4f}")
     logger.info(
-        f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
+        f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
     )
-    logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
+    logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
 
     # Save metrics if path provided
     if save_metrics:
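
Note: a hedged invocation sketch; only the parameter names visible in the hunks above are used, and the full signature of `run_evaluation` in sleap_nn/evaluation.py may include additional arguments.

    from sleap_nn.evaluation import run_evaluation

    run_evaluation(
        ground_truth_path="ground_truth.slp",  # hypothetical paths
        predicted_path="predictions.slp",
        match_threshold=0.8,    # assumed value for illustration
        user_labels_only=True,
    )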
sleap_nn/export/__init__.py ADDED
@@ -0,0 +1,21 @@
+"""Export utilities for sleap-nn."""
+
+from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
+from sleap_nn.export.metadata import ExportMetadata
+from sleap_nn.export.predictors import (
+    load_exported_model,
+    ONNXPredictor,
+    TensorRTPredictor,
+)
+from sleap_nn.export.utils import build_bottomup_candidate_template
+
+__all__ = [
+    "export_model",
+    "export_to_onnx",
+    "export_to_tensorrt",
+    "load_exported_model",
+    "ONNXPredictor",
+    "TensorRTPredictor",
+    "ExportMetadata",
+    "build_bottomup_candidate_template",
+]
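
Note: a heavily hedged sketch of the new public export surface; the names come from `__all__` above, but none of the call signatures appear in this diff, so the calls are left as comments.

    from sleap_nn.export import (
        export_model,         # ONNX/TensorRT export entry point (signature not shown here)
        load_exported_model,  # returns an ONNXPredictor or TensorRTPredictor
    )

    # Hypothetical call pattern -- consult sleap_nn/export/exporters/__init__.py
    # and sleap_nn/export/predictors/__init__.py for the actual signatures:
    # export_model("path/to/model_dir", ...)
    # predictor = load_exported_model("path/to/exported", ...)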