sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +168 -39
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/peak_finding.py +47 -17
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +213 -106
- sleap_nn/predict.py +35 -7
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +69 -22
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +67 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0

sleap_nn/export/predictors/__init__.py
@@ -0,0 +1,63 @@
"""Predictors for exported models."""

from __future__ import annotations

from pathlib import Path
from typing import Optional

import torch

from sleap_nn.export.predictors.base import ExportPredictor
from sleap_nn.export.predictors.onnx import ONNXPredictor
from sleap_nn.export.predictors.tensorrt import TensorRTPredictor


def detect_runtime(model_path: str | Path) -> str:
    """Auto-detect runtime from file extension or folder contents."""
    model_path = Path(model_path)
    if model_path.is_dir():
        onnx_path = model_path / "exported" / "model.onnx"
        trt_path = model_path / "exported" / "model.trt"
        if trt_path.exists():
            return "tensorrt"
        if onnx_path.exists():
            return "onnx"
        raise ValueError(f"No exported model found in {model_path}")

    ext = model_path.suffix.lower()
    if ext == ".onnx":
        return "onnx"
    if ext in (".trt", ".engine"):
        return "tensorrt"

    raise ValueError(f"Unknown model format: {ext}")


def load_exported_model(
    model_path: str | Path,
    runtime: str = "auto",
    device: str = "auto",
    **kwargs,
) -> ExportPredictor:
    """Load an exported model and return a predictor instance."""
    if runtime == "auto":
        runtime = detect_runtime(model_path)

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if runtime == "onnx":
        return ONNXPredictor(str(model_path), device=device, **kwargs)
    if runtime == "tensorrt":
        return TensorRTPredictor(str(model_path), device=device, **kwargs)

    raise ValueError(f"Unknown runtime: {runtime}")


__all__ = [
    "ExportPredictor",
    "ONNXPredictor",
    "TensorRTPredictor",
    "detect_runtime",
    "load_exported_model",
]
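
A minimal usage sketch of the new loader (not part of the diff); the model path and the 1x1x512x512 input shape below are illustrative assumptions:

# Hypothetical usage of the new export loader; paths and shapes are assumed.
import numpy as np

from sleap_nn.export.predictors import load_exported_model

predictor = load_exported_model(
    "models/my_model/exported/model.onnx", runtime="auto", device="auto"
)
frames = np.zeros((1, 1, 512, 512), dtype=np.uint8)  # [B, C, H, W]
outputs = predictor.predict(frames)  # dict: output name -> np.ndarray
print(predictor.benchmark(frames, n_warmup=5, n_runs=20))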

sleap_nn/export/predictors/base.py
@@ -0,0 +1,22 @@
"""Predictor base class for exported models."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict

import numpy as np


class ExportPredictor(ABC):
    """Base interface for exported model inference."""

    @abstractmethod
    def predict(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Run inference on a batch of images."""

    @abstractmethod
    def benchmark(
        self, image: np.ndarray, n_warmup: int = 50, n_runs: int = 200
    ) -> Dict[str, float]:
        """Benchmark inference latency and throughput."""
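
The base class only defines the interface; a minimal sketch of a concrete implementation (a hypothetical DummyPredictor, e.g. for tests, not part of the package) could look like:

# Hypothetical ExportPredictor subclass returning zeros; illustrative only.
import time
from typing import Dict

import numpy as np

from sleap_nn.export.predictors.base import ExportPredictor


class DummyPredictor(ExportPredictor):
    def predict(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        # Echo the batch size so callers can exercise the interface.
        return {"confmaps": np.zeros((image.shape[0], 1, 1, 1), dtype=np.float32)}

    def benchmark(
        self, image: np.ndarray, n_warmup: int = 50, n_runs: int = 200
    ) -> Dict[str, float]:
        for _ in range(n_warmup):
            self.predict(image)
        start = time.perf_counter()
        for _ in range(n_runs):
            self.predict(image)
        mean_ms = (time.perf_counter() - start) / n_runs * 1000.0
        return {"latency_ms_mean": mean_ms, "fps": 1000.0 / mean_ms if mean_ms else 0.0}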

sleap_nn/export/predictors/onnx.py
@@ -0,0 +1,154 @@
"""ONNX Runtime predictor."""

from __future__ import annotations

from typing import Dict, Iterable, Optional
import time

import numpy as np

from sleap_nn.export.predictors.base import ExportPredictor


class ONNXPredictor(ExportPredictor):
    """ONNX Runtime inference with provider selection."""

    def __init__(
        self,
        model_path: str,
        device: str = "auto",
        providers: Optional[Iterable[str]] = None,
    ) -> None:
        """Initialize ONNX predictor with execution providers.

        Args:
            model_path: Path to the ONNX model file.
            device: Device for inference ("auto", "cpu", or "cuda").
            providers: ONNX Runtime execution providers. Auto-selected if None.
        """
        try:
            import onnxruntime as ort
        except ImportError as exc:
            raise ImportError(
                "onnxruntime is required for ONNXPredictor. Install with "
                "`pip install onnxruntime` or `onnxruntime-gpu`."
            ) from exc

        # Preload CUDA/cuDNN libraries from pip-installed nvidia packages
        # This is required for onnxruntime-gpu to find the CUDA libraries
        if hasattr(ort, "preload_dlls"):
            ort.preload_dlls()

        self.ort = ort
        if providers is None:
            providers = _select_providers(device, ort.get_available_providers())

        self.session = ort.InferenceSession(model_path, providers=list(providers))
        input_info = self.session.get_inputs()[0]
        self.input_name = input_info.name
        self.input_type = input_info.type
        self.input_dtype = _onnx_type_to_numpy(self.input_type)
        self.output_names = [out.name for out in self.session.get_outputs()]

    def predict(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Run inference on a batch of images."""
        image = _as_numpy(image, expected_dtype=self.input_dtype)
        outputs = self.session.run(None, {self.input_name: image})
        return dict(zip(self.output_names, outputs))

    def benchmark(
        self, image: np.ndarray, n_warmup: int = 50, n_runs: int = 200
    ) -> Dict[str, float]:
        """Benchmark inference latency and throughput."""
        image = _as_numpy(image, expected_dtype=self.input_dtype)
        for _ in range(n_warmup):
            self.session.run(None, {self.input_name: image})

        times = []
        for _ in range(n_runs):
            start = time.perf_counter()
            self.session.run(None, {self.input_name: image})
            times.append(time.perf_counter() - start)

        times_ms = np.array(times) * 1000.0
        mean_ms = float(times_ms.mean())
        p50_ms = float(np.percentile(times_ms, 50))
        p95_ms = float(np.percentile(times_ms, 95))
        fps = float(1000.0 / mean_ms) if mean_ms > 0 else 0.0

        return {
            "latency_ms_mean": mean_ms,
            "latency_ms_p50": p50_ms,
            "latency_ms_p95": p95_ms,
            "fps": fps,
        }


def _select_providers(device: str, available: Iterable[str]) -> Iterable[str]:
    device = device.lower()
    available = list(available)

    if device in ("cpu", "host"):
        return ["CPUExecutionProvider"]

    if device.startswith("cuda") or device == "auto":
        # Note: We don't include TensorrtExecutionProvider here because:
        # 1. We have a dedicated TensorRTPredictor for native TRT inference
        # 2. ORT's TensorRT provider requires TRT libs in LD_LIBRARY_PATH
        preferred = [
            "CUDAExecutionProvider",
            "CPUExecutionProvider",
        ]
        return [p for p in preferred if p in available] or available

    if device in ("directml", "dml"):
        preferred = ["DmlExecutionProvider", "CPUExecutionProvider"]
        return [p for p in preferred if p in available] or available

    return available


def _onnx_type_to_numpy(type_str: str | None) -> np.dtype | None:
    if not type_str:
        return None
    if type_str.startswith("tensor(") and type_str.endswith(")"):
        key = type_str[len("tensor(") : -1]
    else:
        return None

    mapping = {
        "float": np.float32,
        "float16": np.float16,
        "double": np.float64,
        "uint8": np.uint8,
        "int8": np.int8,
        "uint16": np.uint16,
        "int16": np.int16,
        "uint32": np.uint32,
        "int32": np.int32,
        "uint64": np.uint64,
        "int64": np.int64,
    }
    return mapping.get(key)


def _as_numpy(image: np.ndarray, expected_dtype: np.dtype | None = None) -> np.ndarray:
    if isinstance(image, np.ndarray):
        data = image
    else:
        try:
            import torch

            if isinstance(image, torch.Tensor):
                data = image.detach().cpu().numpy()
            else:
                data = np.asarray(image)
        except ImportError:
            data = np.asarray(image)

    if expected_dtype is not None:
        if data.dtype != expected_dtype:
            data = data.astype(expected_dtype)
    elif data.dtype != np.float32:
        data = data.astype(np.float32)
    return data
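
A short CPU-only usage sketch for the ONNX predictor (not part of the diff); the model file name and input shape are assumptions:

# Hypothetical ONNXPredictor usage; "model.onnx" and the input shape are assumed.
import numpy as np

from sleap_nn.export.predictors.onnx import ONNXPredictor

predictor = ONNXPredictor("model.onnx", providers=["CPUExecutionProvider"])
batch = np.random.rand(1, 1, 256, 256).astype(np.float32)  # coerced to the model's input dtype internally
outputs = predictor.predict(batch)  # {output_name: np.ndarray}
stats = predictor.benchmark(batch, n_warmup=5, n_runs=20)
print(stats["latency_ms_mean"], stats["latency_ms_p95"], stats["fps"])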

sleap_nn/export/predictors/tensorrt.py
@@ -0,0 +1,312 @@
"""TensorRT predictor for exported models."""

from __future__ import annotations

import time
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch

from sleap_nn.export.predictors.base import ExportPredictor


class TensorRTPredictor(ExportPredictor):
    """TensorRT inference for exported models.

    This predictor loads a native TensorRT engine file (.trt) and provides
    inference capabilities using CUDA for high-throughput predictions.

    Args:
        engine_path: Path to the TensorRT engine file (.trt).
        device: Device to run inference on (only "cuda" supported for TRT).

    Example:
        >>> predictor = TensorRTPredictor("model.trt")
        >>> outputs = predictor.predict(images)  # uint8 [B, C, H, W]
    """

    def __init__(
        self,
        engine_path: str | Path,
        device: str = "cuda",
    ) -> None:
        """Initialize TensorRT predictor with a serialized engine.

        Args:
            engine_path: Path to the TensorRT engine file.
            device: Device for inference ("cuda" or "auto"). TensorRT requires CUDA.
        """
        import tensorrt as trt

        if device not in ("cuda", "auto"):
            raise ValueError(f"TensorRT only supports CUDA devices, got: {device}")

        self.engine_path = Path(engine_path)
        if not self.engine_path.exists():
            raise FileNotFoundError(f"TensorRT engine not found: {engine_path}")

        self.logger = trt.Logger(trt.Logger.WARNING)

        # Load engine
        with open(self.engine_path, "rb") as f:
            self.engine = trt.Runtime(self.logger).deserialize_cuda_engine(f.read())

        if self.engine is None:
            raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}")

        # Create execution context
        self.context = self.engine.create_execution_context()

        # Get input/output info
        self.input_names: List[str] = []
        self.output_names: List[str] = []
        self.input_shapes: Dict[str, tuple] = {}
        self.output_shapes: Dict[str, tuple] = {}

        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            shape = tuple(self.engine.get_tensor_shape(name))
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self.input_names.append(name)
                self.input_shapes[name] = shape
            else:
                self.output_names.append(name)
                self.output_shapes[name] = shape

        self.device = torch.device("cuda")

    def predict(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Run TensorRT inference.

        Args:
            image: Input image(s) as numpy array [B, C, H, W] with uint8 dtype.

        Returns:
            Dict mapping output names to numpy arrays.
        """
        import tensorrt as trt

        # Convert input to torch tensor on GPU
        input_tensor = torch.from_numpy(image).to(self.device)

        # Check if engine expects uint8 or float32 and convert if needed
        input_name = self.input_names[0]
        expected_dtype = self.engine.get_tensor_dtype(input_name)
        if expected_dtype == trt.DataType.UINT8:
            # Engine expects uint8 - keep as uint8
            if input_tensor.dtype != torch.uint8:
                input_tensor = input_tensor.to(torch.uint8)
        else:
            # Engine expects float - convert uint8 to float32
            if input_tensor.dtype == torch.uint8:
                input_tensor = input_tensor.to(torch.float32)

        # Ensure contiguous memory
        input_tensor = input_tensor.contiguous()

        # Set input shape for dynamic dimensions
        input_name = self.input_names[0]
        self.context.set_input_shape(input_name, tuple(input_tensor.shape))

        # Allocate output tensors
        outputs: Dict[str, torch.Tensor] = {}
        bindings: Dict[str, int] = {}

        # Set input binding
        bindings[input_name] = input_tensor.data_ptr()

        # Allocate outputs
        for name in self.output_names:
            shape = self.context.get_tensor_shape(name)
            dtype = self._trt_dtype_to_torch(self.engine.get_tensor_dtype(name))
            outputs[name] = torch.empty(tuple(shape), dtype=dtype, device=self.device)
            bindings[name] = outputs[name].data_ptr()

        # Set tensor addresses
        for name, ptr in bindings.items():
            self.context.set_tensor_address(name, ptr)

        # Run inference
        stream = torch.cuda.current_stream().cuda_stream
        success = self.context.execute_async_v3(stream)
        torch.cuda.current_stream().synchronize()

        if not success:
            raise RuntimeError("TensorRT inference failed")

        # Convert outputs to numpy
        return {name: tensor.cpu().numpy() for name, tensor in outputs.items()}

    def benchmark(
        self, image: np.ndarray, n_warmup: int = 50, n_runs: int = 200
    ) -> Dict[str, float]:
        """Benchmark inference performance.

        Args:
            image: Input image(s) as numpy array [B, C, H, W].
            n_warmup: Number of warmup runs (not timed).
            n_runs: Number of timed runs.

        Returns:
            Dict with timing statistics:
            - mean_ms: Mean inference time in milliseconds
            - std_ms: Standard deviation of inference time
            - min_ms: Minimum inference time
            - max_ms: Maximum inference time
            - fps: Frames per second (based on mean time and batch size)
        """
        batch_size = image.shape[0]

        # Warmup
        for _ in range(n_warmup):
            _ = self.predict(image)

        # Timed runs
        times = []
        for _ in range(n_runs):
            start = time.perf_counter()
            _ = self.predict(image)
            times.append((time.perf_counter() - start) * 1000)

        times_arr = np.array(times)
        mean_ms = float(np.mean(times_arr))

        return {
            "mean_ms": mean_ms,
            "std_ms": float(np.std(times_arr)),
            "min_ms": float(np.min(times_arr)),
            "max_ms": float(np.max(times_arr)),
            "fps": (batch_size * 1000) / mean_ms if mean_ms > 0 else 0.0,
        }

    @staticmethod
    def _trt_dtype_to_torch(trt_dtype):
        """Convert TensorRT dtype to PyTorch dtype."""
        import tensorrt as trt

        mapping = {
            trt.DataType.FLOAT: torch.float32,
            trt.DataType.HALF: torch.float16,
            trt.DataType.INT32: torch.int32,
            trt.DataType.INT8: torch.int8,
            trt.DataType.BOOL: torch.bool,
        }
        return mapping.get(trt_dtype, torch.float32)


class TensorRTEngine:
    """Low-level wrapper for native TensorRT engine files (.trt).

    Provides a callable interface similar to PyTorch models, returning
    torch.Tensor outputs directly.

    Args:
        engine_path: Path to the TensorRT engine file (.trt).

    Example:
        >>> engine = TensorRTEngine("model.trt")
        >>> input_tensor = torch.randn(1, 1, 512, 512, device="cuda")
        >>> outputs = engine(input_tensor)  # Dict[str, torch.Tensor]
    """

    def __init__(self, engine_path: str | Path) -> None:
        """Initialize TensorRT engine wrapper.

        Args:
            engine_path: Path to the serialized TensorRT engine file.
        """
        import tensorrt as trt

        self.engine_path = Path(engine_path)
        self.logger = trt.Logger(trt.Logger.WARNING)

        # Load engine
        with open(engine_path, "rb") as f:
            self.engine = trt.Runtime(self.logger).deserialize_cuda_engine(f.read())

        if self.engine is None:
            raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}")

        # Create execution context
        self.context = self.engine.create_execution_context()

        # Get input/output info
        self.input_names: List[str] = []
        self.output_names: List[str] = []
        self.input_shapes: Dict[str, tuple] = {}
        self.output_shapes: Dict[str, tuple] = {}

        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            shape = tuple(self.engine.get_tensor_shape(name))
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self.input_names.append(name)
                self.input_shapes[name] = shape
            else:
                self.output_names.append(name)
                self.output_shapes[name] = shape

    def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
        """Run inference.

        Args:
            *args: Input tensors (positional) or
            **kwargs: Input tensors by name

        Returns:
            Dict of output tensors on the same device as input.
        """
        import tensorrt as trt

        # Handle inputs
        if args:
            inputs = {self.input_names[i]: arg for i, arg in enumerate(args)}
        else:
            inputs = kwargs

        # Set input shapes and allocate outputs
        outputs: Dict[str, torch.Tensor] = {}
        bindings: Dict[str, int] = {}

        for name in self.input_names:
            tensor = inputs[name]
            if not isinstance(tensor, torch.Tensor):
                raise ValueError(f"Input {name} must be a torch.Tensor")
            # Set actual shape for dynamic dimensions
            self.context.set_input_shape(name, tuple(tensor.shape))
            bindings[name] = tensor.contiguous().data_ptr()

        # Allocate output tensors
        device = next(iter(inputs.values())).device
        for name in self.output_names:
            shape = self.context.get_tensor_shape(name)
            dtype = self._trt_dtype_to_torch(self.engine.get_tensor_dtype(name))
            outputs[name] = torch.empty(tuple(shape), dtype=dtype, device=device)
            bindings[name] = outputs[name].data_ptr()

        # Set tensor addresses
        for name, ptr in bindings.items():
            self.context.set_tensor_address(name, ptr)

        # Run inference
        stream = torch.cuda.current_stream().cuda_stream
        self.context.execute_async_v3(stream)
        torch.cuda.current_stream().synchronize()

        return outputs

    @staticmethod
    def _trt_dtype_to_torch(trt_dtype):
        """Convert TensorRT dtype to PyTorch dtype."""
        import tensorrt as trt

        mapping = {
            trt.DataType.FLOAT: torch.float32,
            trt.DataType.HALF: torch.float16,
            trt.DataType.INT32: torch.int32,
            trt.DataType.INT8: torch.int8,
            trt.DataType.BOOL: torch.bool,
        }
        return mapping.get(trt_dtype, torch.float32)
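
For comparison of the two TensorRT entry points (not part of the diff): TensorRTPredictor takes and returns NumPy arrays, while TensorRTEngine is the lower-level callable that keeps tensors on the GPU. A sketch, assuming a CUDA machine and an exported engine file named "model.trt" with a 1x1x512x512 input:

# Hypothetical usage; requires CUDA, TensorRT, and an exported "model.trt".
import numpy as np
import torch

from sleap_nn.export.predictors.tensorrt import TensorRTEngine, TensorRTPredictor

predictor = TensorRTPredictor("model.trt")
frames = np.zeros((1, 1, 512, 512), dtype=np.uint8)  # [B, C, H, W]
outputs = predictor.predict(frames)  # output name -> np.ndarray
print(predictor.benchmark(frames, n_warmup=10, n_runs=50)["mean_ms"])

engine = TensorRTEngine("model.trt")
gpu_batch = torch.zeros(1, 1, 512, 512, device="cuda")  # stays on the GPU
gpu_outputs = engine(gpu_batch)  # output name -> torch.Tensor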