sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/evaluation.py +8 -0
  4. sleap_nn/export/__init__.py +21 -0
  5. sleap_nn/export/cli.py +1778 -0
  6. sleap_nn/export/exporters/__init__.py +51 -0
  7. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  8. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  9. sleap_nn/export/metadata.py +225 -0
  10. sleap_nn/export/predictors/__init__.py +63 -0
  11. sleap_nn/export/predictors/base.py +22 -0
  12. sleap_nn/export/predictors/onnx.py +154 -0
  13. sleap_nn/export/predictors/tensorrt.py +312 -0
  14. sleap_nn/export/utils.py +307 -0
  15. sleap_nn/export/wrappers/__init__.py +25 -0
  16. sleap_nn/export/wrappers/base.py +96 -0
  17. sleap_nn/export/wrappers/bottomup.py +243 -0
  18. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  19. sleap_nn/export/wrappers/centered_instance.py +56 -0
  20. sleap_nn/export/wrappers/centroid.py +58 -0
  21. sleap_nn/export/wrappers/single_instance.py +83 -0
  22. sleap_nn/export/wrappers/topdown.py +180 -0
  23. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  24. sleap_nn/inference/postprocessing.py +284 -0
  25. sleap_nn/predict.py +29 -0
  26. sleap_nn/train.py +64 -0
  27. sleap_nn/training/callbacks.py +62 -20
  28. sleap_nn/training/lightning_modules.py +332 -30
  29. sleap_nn/training/model_trainer.py +35 -67
  30. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
  31. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
  32. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  33. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  34. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/export/predictors/__init__.py
@@ -0,0 +1,63 @@
+"""Predictors for exported models."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from sleap_nn.export.predictors.base import ExportPredictor
+from sleap_nn.export.predictors.onnx import ONNXPredictor
+from sleap_nn.export.predictors.tensorrt import TensorRTPredictor
+
+
+def detect_runtime(model_path: str | Path) -> str:
+    """Auto-detect runtime from file extension or folder contents."""
+    model_path = Path(model_path)
+    if model_path.is_dir():
+        onnx_path = model_path / "exported" / "model.onnx"
+        trt_path = model_path / "exported" / "model.trt"
+        if trt_path.exists():
+            return "tensorrt"
+        if onnx_path.exists():
+            return "onnx"
+        raise ValueError(f"No exported model found in {model_path}")
+
+    ext = model_path.suffix.lower()
+    if ext == ".onnx":
+        return "onnx"
+    if ext in (".trt", ".engine"):
+        return "tensorrt"
+
+    raise ValueError(f"Unknown model format: {ext}")
+
+
+def load_exported_model(
+    model_path: str | Path,
+    runtime: str = "auto",
+    device: str = "auto",
+    **kwargs,
+) -> ExportPredictor:
+    """Load an exported model and return a predictor instance."""
+    if runtime == "auto":
+        runtime = detect_runtime(model_path)
+
+    if device == "auto":
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    if runtime == "onnx":
+        return ONNXPredictor(str(model_path), device=device, **kwargs)
+    if runtime == "tensorrt":
+        return TensorRTPredictor(str(model_path), device=device, **kwargs)
+
+    raise ValueError(f"Unknown runtime: {runtime}")
+
+
+__all__ = [
+    "ExportPredictor",
+    "ONNXPredictor",
+    "TensorRTPredictor",
+    "detect_runtime",
+    "load_exported_model",
+]
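The new predictors package keeps runtime selection in two small helpers: detect_runtime maps a .onnx / .trt / .engine file (or an exported/ folder inside a model directory) to a runtime name, and load_exported_model instantiates the matching predictor. A minimal usage sketch against this API, assuming a hypothetical models/centroid.onnx produced beforehand by the export tooling:

import numpy as np
from sleap_nn.export.predictors import load_exported_model

# runtime="auto" resolves "onnx" from the .onnx suffix; device="auto" falls
# back to CPU when CUDA is unavailable.
predictor = load_exported_model("models/centroid.onnx", runtime="auto", device="auto")

# Batch of 4 grayscale 512x512 frames in [B, C, H, W] layout (shape is illustrative).
batch = np.zeros((4, 1, 512, 512), dtype=np.uint8)
outputs = predictor.predict(batch)                    # dict: output name -> np.ndarray
stats = predictor.benchmark(batch, n_warmup=10, n_runs=50)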
sleap_nn/export/predictors/base.py
@@ -0,0 +1,22 @@
+"""Predictor base class for exported models."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Dict
+
+import numpy as np
+
+
+class ExportPredictor(ABC):
+    """Base interface for exported model inference."""
+
+    @abstractmethod
+    def predict(self, image: np.ndarray) -> Dict[str, np.ndarray]:
+        """Run inference on a batch of images."""
+
+    @abstractmethod
+    def benchmark(
+        self, image: np.ndarray, n_warmup: int = 50, n_runs: int = 200
+    ) -> Dict[str, float]:
+        """Benchmark inference latency and throughput."""
sleap_nn/export/predictors/onnx.py
@@ -0,0 +1,154 @@
+"""ONNX Runtime predictor."""
+
+from __future__ import annotations
+
+from typing import Dict, Iterable, Optional
+import time
+
+import numpy as np
+
+from sleap_nn.export.predictors.base import ExportPredictor
+
+
+class ONNXPredictor(ExportPredictor):
+    """ONNX Runtime inference with provider selection."""
+
+    def __init__(
+        self,
+        model_path: str,
+        device: str = "auto",
+        providers: Optional[Iterable[str]] = None,
+    ) -> None:
+        """Initialize ONNX predictor with execution providers.
+
+        Args:
+            model_path: Path to the ONNX model file.
+            device: Device for inference ("auto", "cpu", or "cuda").
+            providers: ONNX Runtime execution providers. Auto-selected if None.
+        """
+        try:
+            import onnxruntime as ort
+        except ImportError as exc:
+            raise ImportError(
+                "onnxruntime is required for ONNXPredictor. Install with "
+                "`pip install onnxruntime` or `onnxruntime-gpu`."
+            ) from exc
+
+        # Preload CUDA/cuDNN libraries from pip-installed nvidia packages
+        # This is required for onnxruntime-gpu to find the CUDA libraries
+        if hasattr(ort, "preload_dlls"):
+            ort.preload_dlls()
+
+        self.ort = ort
+        if providers is None:
+            providers = _select_providers(device, ort.get_available_providers())
+
+        self.session = ort.InferenceSession(model_path, providers=list(providers))
+        input_info = self.session.get_inputs()[0]
+        self.input_name = input_info.name
+        self.input_type = input_info.type
+        self.input_dtype = _onnx_type_to_numpy(self.input_type)
+        self.output_names = [out.name for out in self.session.get_outputs()]
+
+    def predict(self, image: np.ndarray) -> Dict[str, np.ndarray]:
+        """Run inference on a batch of images."""
+        image = _as_numpy(image, expected_dtype=self.input_dtype)
+        outputs = self.session.run(None, {self.input_name: image})
+        return dict(zip(self.output_names, outputs))
+
+    def benchmark(
+        self, image: np.ndarray, n_warmup: int = 50, n_runs: int = 200
+    ) -> Dict[str, float]:
+        """Benchmark inference latency and throughput."""
+        image = _as_numpy(image, expected_dtype=self.input_dtype)
+        for _ in range(n_warmup):
+            self.session.run(None, {self.input_name: image})
+
+        times = []
+        for _ in range(n_runs):
+            start = time.perf_counter()
+            self.session.run(None, {self.input_name: image})
+            times.append(time.perf_counter() - start)
+
+        times_ms = np.array(times) * 1000.0
+        mean_ms = float(times_ms.mean())
+        p50_ms = float(np.percentile(times_ms, 50))
+        p95_ms = float(np.percentile(times_ms, 95))
+        fps = float(1000.0 / mean_ms) if mean_ms > 0 else 0.0
+
+        return {
+            "latency_ms_mean": mean_ms,
+            "latency_ms_p50": p50_ms,
+            "latency_ms_p95": p95_ms,
+            "fps": fps,
+        }
+
+
+def _select_providers(device: str, available: Iterable[str]) -> Iterable[str]:
+    device = device.lower()
+    available = list(available)
+
+    if device in ("cpu", "host"):
+        return ["CPUExecutionProvider"]
+
+    if device.startswith("cuda") or device == "auto":
+        # Note: We don't include TensorrtExecutionProvider here because:
+        # 1. We have a dedicated TensorRTPredictor for native TRT inference
+        # 2. ORT's TensorRT provider requires TRT libs in LD_LIBRARY_PATH
+        preferred = [
+            "CUDAExecutionProvider",
+            "CPUExecutionProvider",
+        ]
+        return [p for p in preferred if p in available] or available
+
+    if device in ("directml", "dml"):
+        preferred = ["DmlExecutionProvider", "CPUExecutionProvider"]
+        return [p for p in preferred if p in available] or available
+
+    return available
+
+
+def _onnx_type_to_numpy(type_str: str | None) -> np.dtype | None:
+    if not type_str:
+        return None
+    if type_str.startswith("tensor(") and type_str.endswith(")"):
+        key = type_str[len("tensor(") : -1]
+    else:
+        return None
+
+    mapping = {
+        "float": np.float32,
+        "float16": np.float16,
+        "double": np.float64,
+        "uint8": np.uint8,
+        "int8": np.int8,
+        "uint16": np.uint16,
+        "int16": np.int16,
+        "uint32": np.uint32,
+        "int32": np.int32,
+        "uint64": np.uint64,
+        "int64": np.int64,
+    }
+    return mapping.get(key)
+
+
+def _as_numpy(image: np.ndarray, expected_dtype: np.dtype | None = None) -> np.ndarray:
+    if isinstance(image, np.ndarray):
+        data = image
+    else:
+        try:
+            import torch
+
+            if isinstance(image, torch.Tensor):
+                data = image.detach().cpu().numpy()
+            else:
+                data = np.asarray(image)
+        except ImportError:
+            data = np.asarray(image)
+
+    if expected_dtype is not None:
+        if data.dtype != expected_dtype:
+            data = data.astype(expected_dtype)
+    elif data.dtype != np.float32:
+        data = data.astype(np.float32)
+    return data
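Most of the logic above concerns execution providers: _select_providers prefers CUDAExecutionProvider (with CPU fallback) for device="cuda" or "auto", honors "directml"/"dml", and pins CPU for device="cpu", while _as_numpy casts inputs to the dtype declared by the ONNX graph. A short usage sketch, with hypothetical model path and input shape:

import numpy as np
from sleap_nn.export.predictors.onnx import ONNXPredictor

# Default provider selection (CUDA if onnxruntime-gpu and a GPU are available).
predictor = ONNXPredictor("exported/model.onnx", device="auto")
frames = np.zeros((1, 1, 384, 384), dtype=np.uint8)    # cast to the graph's input dtype internally
outputs = predictor.predict(frames)
print(sorted(outputs))                                 # output names from the ONNX graph

# Explicit override, e.g. forcing CPU even when onnxruntime-gpu is installed.
cpu_predictor = ONNXPredictor("exported/model.onnx", providers=["CPUExecutionProvider"])
print(cpu_predictor.benchmark(frames, n_warmup=5, n_runs=20))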
sleap_nn/export/predictors/tensorrt.py
@@ -0,0 +1,312 @@
+"""TensorRT predictor for exported models."""
+
+from __future__ import annotations
+
+import time
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import numpy as np
+import torch
+
+from sleap_nn.export.predictors.base import ExportPredictor
+
+
+class TensorRTPredictor(ExportPredictor):
+    """TensorRT inference for exported models.
+
+    This predictor loads a native TensorRT engine file (.trt) and provides
+    inference capabilities using CUDA for high-throughput predictions.
+
+    Args:
+        engine_path: Path to the TensorRT engine file (.trt).
+        device: Device to run inference on (only "cuda" supported for TRT).
+
+    Example:
+        >>> predictor = TensorRTPredictor("model.trt")
+        >>> outputs = predictor.predict(images)  # uint8 [B, C, H, W]
+    """
+
+    def __init__(
+        self,
+        engine_path: str | Path,
+        device: str = "cuda",
+    ) -> None:
+        """Initialize TensorRT predictor with a serialized engine.
+
+        Args:
+            engine_path: Path to the TensorRT engine file.
+            device: Device for inference ("cuda" or "auto"). TensorRT requires CUDA.
+        """
+        import tensorrt as trt
+
+        if device not in ("cuda", "auto"):
+            raise ValueError(f"TensorRT only supports CUDA devices, got: {device}")
+
+        self.engine_path = Path(engine_path)
+        if not self.engine_path.exists():
+            raise FileNotFoundError(f"TensorRT engine not found: {engine_path}")
+
+        self.logger = trt.Logger(trt.Logger.WARNING)
+
+        # Load engine
+        with open(self.engine_path, "rb") as f:
+            self.engine = trt.Runtime(self.logger).deserialize_cuda_engine(f.read())
+
+        if self.engine is None:
+            raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}")
+
+        # Create execution context
+        self.context = self.engine.create_execution_context()
+
+        # Get input/output info
+        self.input_names: List[str] = []
+        self.output_names: List[str] = []
+        self.input_shapes: Dict[str, tuple] = {}
+        self.output_shapes: Dict[str, tuple] = {}
+
+        for i in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(i)
+            shape = tuple(self.engine.get_tensor_shape(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.input_names.append(name)
+                self.input_shapes[name] = shape
+            else:
+                self.output_names.append(name)
+                self.output_shapes[name] = shape
+
+        self.device = torch.device("cuda")
+
+    def predict(self, image: np.ndarray) -> Dict[str, np.ndarray]:
+        """Run TensorRT inference.
+
+        Args:
+            image: Input image(s) as numpy array [B, C, H, W] with uint8 dtype.
+
+        Returns:
+            Dict mapping output names to numpy arrays.
+        """
+        import tensorrt as trt
+
+        # Convert input to torch tensor on GPU
+        input_tensor = torch.from_numpy(image).to(self.device)
+
+        # Check if engine expects uint8 or float32 and convert if needed
+        input_name = self.input_names[0]
+        expected_dtype = self.engine.get_tensor_dtype(input_name)
+        if expected_dtype == trt.DataType.UINT8:
+            # Engine expects uint8 - keep as uint8
+            if input_tensor.dtype != torch.uint8:
+                input_tensor = input_tensor.to(torch.uint8)
+        else:
+            # Engine expects float - convert uint8 to float32
+            if input_tensor.dtype == torch.uint8:
+                input_tensor = input_tensor.to(torch.float32)
+
+        # Ensure contiguous memory
+        input_tensor = input_tensor.contiguous()
+
+        # Set input shape for dynamic dimensions
+        input_name = self.input_names[0]
+        self.context.set_input_shape(input_name, tuple(input_tensor.shape))
+
+        # Allocate output tensors
+        outputs: Dict[str, torch.Tensor] = {}
+        bindings: Dict[str, int] = {}
+
+        # Set input binding
+        bindings[input_name] = input_tensor.data_ptr()
+
+        # Allocate outputs
+        for name in self.output_names:
+            shape = self.context.get_tensor_shape(name)
+            dtype = self._trt_dtype_to_torch(self.engine.get_tensor_dtype(name))
+            outputs[name] = torch.empty(tuple(shape), dtype=dtype, device=self.device)
+            bindings[name] = outputs[name].data_ptr()
+
+        # Set tensor addresses
+        for name, ptr in bindings.items():
+            self.context.set_tensor_address(name, ptr)
+
+        # Run inference
+        stream = torch.cuda.current_stream().cuda_stream
+        success = self.context.execute_async_v3(stream)
+        torch.cuda.current_stream().synchronize()
+
+        if not success:
+            raise RuntimeError("TensorRT inference failed")
+
+        # Convert outputs to numpy
+        return {name: tensor.cpu().numpy() for name, tensor in outputs.items()}
+
+    def benchmark(
+        self, image: np.ndarray, n_warmup: int = 50, n_runs: int = 200
+    ) -> Dict[str, float]:
+        """Benchmark inference performance.
+
+        Args:
+            image: Input image(s) as numpy array [B, C, H, W].
+            n_warmup: Number of warmup runs (not timed).
+            n_runs: Number of timed runs.
+
+        Returns:
+            Dict with timing statistics:
+            - mean_ms: Mean inference time in milliseconds
+            - std_ms: Standard deviation of inference time
+            - min_ms: Minimum inference time
+            - max_ms: Maximum inference time
+            - fps: Frames per second (based on mean time and batch size)
+        """
+        batch_size = image.shape[0]
+
+        # Warmup
+        for _ in range(n_warmup):
+            _ = self.predict(image)
+
+        # Timed runs
+        times = []
+        for _ in range(n_runs):
+            start = time.perf_counter()
+            _ = self.predict(image)
+            times.append((time.perf_counter() - start) * 1000)
+
+        times_arr = np.array(times)
+        mean_ms = float(np.mean(times_arr))
+
+        return {
+            "mean_ms": mean_ms,
+            "std_ms": float(np.std(times_arr)),
+            "min_ms": float(np.min(times_arr)),
+            "max_ms": float(np.max(times_arr)),
+            "fps": (batch_size * 1000) / mean_ms if mean_ms > 0 else 0.0,
+        }
+
+    @staticmethod
+    def _trt_dtype_to_torch(trt_dtype):
+        """Convert TensorRT dtype to PyTorch dtype."""
+        import tensorrt as trt
+
+        mapping = {
+            trt.DataType.FLOAT: torch.float32,
+            trt.DataType.HALF: torch.float16,
+            trt.DataType.INT32: torch.int32,
+            trt.DataType.INT8: torch.int8,
+            trt.DataType.BOOL: torch.bool,
+        }
+        return mapping.get(trt_dtype, torch.float32)
+
+
+class TensorRTEngine:
+    """Low-level wrapper for native TensorRT engine files (.trt).
+
+    Provides a callable interface similar to PyTorch models, returning
+    torch.Tensor outputs directly.
+
+    Args:
+        engine_path: Path to the TensorRT engine file (.trt).
+
+    Example:
+        >>> engine = TensorRTEngine("model.trt")
+        >>> input_tensor = torch.randn(1, 1, 512, 512, device="cuda")
+        >>> outputs = engine(input_tensor)  # Dict[str, torch.Tensor]
+    """
+
+    def __init__(self, engine_path: str | Path) -> None:
+        """Initialize TensorRT engine wrapper.
+
+        Args:
+            engine_path: Path to the serialized TensorRT engine file.
+        """
+        import tensorrt as trt
+
+        self.engine_path = Path(engine_path)
+        self.logger = trt.Logger(trt.Logger.WARNING)
+
+        # Load engine
+        with open(engine_path, "rb") as f:
+            self.engine = trt.Runtime(self.logger).deserialize_cuda_engine(f.read())
+
+        if self.engine is None:
+            raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}")
+
+        # Create execution context
+        self.context = self.engine.create_execution_context()
+
+        # Get input/output info
+        self.input_names: List[str] = []
+        self.output_names: List[str] = []
+        self.input_shapes: Dict[str, tuple] = {}
+        self.output_shapes: Dict[str, tuple] = {}
+
+        for i in range(self.engine.num_io_tensors):
+            name = self.engine.get_tensor_name(i)
+            shape = tuple(self.engine.get_tensor_shape(name))
+            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
+                self.input_names.append(name)
+                self.input_shapes[name] = shape
+            else:
+                self.output_names.append(name)
+                self.output_shapes[name] = shape
+
+    def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
+        """Run inference.
+
+        Args:
+            *args: Input tensors (positional) or
+            **kwargs: Input tensors by name
+
+        Returns:
+            Dict of output tensors on the same device as input.
+        """
+        import tensorrt as trt
+
+        # Handle inputs
+        if args:
+            inputs = {self.input_names[i]: arg for i, arg in enumerate(args)}
+        else:
+            inputs = kwargs
+
+        # Set input shapes and allocate outputs
+        outputs: Dict[str, torch.Tensor] = {}
+        bindings: Dict[str, int] = {}
+
+        for name in self.input_names:
+            tensor = inputs[name]
+            if not isinstance(tensor, torch.Tensor):
+                raise ValueError(f"Input {name} must be a torch.Tensor")
+            # Set actual shape for dynamic dimensions
+            self.context.set_input_shape(name, tuple(tensor.shape))
+            bindings[name] = tensor.contiguous().data_ptr()
+
+        # Allocate output tensors
+        device = next(iter(inputs.values())).device
+        for name in self.output_names:
+            shape = self.context.get_tensor_shape(name)
+            dtype = self._trt_dtype_to_torch(self.engine.get_tensor_dtype(name))
+            outputs[name] = torch.empty(tuple(shape), dtype=dtype, device=device)
+            bindings[name] = outputs[name].data_ptr()
+
+        # Set tensor addresses
+        for name, ptr in bindings.items():
+            self.context.set_tensor_address(name, ptr)
+
+        # Run inference
+        stream = torch.cuda.current_stream().cuda_stream
+        self.context.execute_async_v3(stream)
+        torch.cuda.current_stream().synchronize()
+
+        return outputs
+
+    @staticmethod
+    def _trt_dtype_to_torch(trt_dtype):
+        """Convert TensorRT dtype to PyTorch dtype."""
+        import tensorrt as trt
+
+        mapping = {
+            trt.DataType.FLOAT: torch.float32,
+            trt.DataType.HALF: torch.float16,
+            trt.DataType.INT32: torch.int32,
+            trt.DataType.INT8: torch.int8,
+            trt.DataType.BOOL: torch.bool,
+        }
+        return mapping.get(trt_dtype, torch.float32)
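The two classes serve different levels of the stack: TensorRTPredictor is the numpy-in / numpy-out ExportPredictor that load_exported_model returns, while TensorRTEngine is a lower-level callable that keeps inputs and outputs as CUDA torch tensors. A usage sketch, assuming a hypothetical exported/model.trt engine, a CUDA GPU, and the tensorrt Python package installed:

import numpy as np
import torch

from sleap_nn.export.predictors.tensorrt import TensorRTEngine, TensorRTPredictor

# High-level path: numpy in, numpy out, with uint8/float conversion handled internally.
predictor = TensorRTPredictor("exported/model.trt")
frames = np.zeros((1, 1, 512, 512), dtype=np.uint8)    # [B, C, H, W]; shape is illustrative
outputs = predictor.predict(frames)                    # dict: output name -> np.ndarray
print(predictor.benchmark(frames, n_warmup=10, n_runs=50))

# Low-level path: torch tensors stay on the GPU, as in the class docstring example.
engine = TensorRTEngine("exported/model.trt")
gpu_input = torch.randn(1, 1, 512, 512, device="cuda")
gpu_outputs = engine(gpu_input)                        # dict: output name -> torch.Tensor (CUDA)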