sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +168 -39
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/peak_finding.py +47 -17
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +213 -106
- sleap_nn/predict.py +35 -7
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +69 -22
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +67 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0

sleap_nn/export/exporters/__init__.py
@@ -0,0 +1,51 @@
"""Exporters for serialized model formats."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, Optional

import torch

from sleap_nn.export.exporters.onnx_exporter import export_to_onnx
from sleap_nn.export.exporters.tensorrt_exporter import export_to_tensorrt


def export_model(
    model: torch.nn.Module,
    save_path: str | Path,
    fmt: str = "onnx",
    input_shape: Iterable[int] = (1, 1, 512, 512),
    opset_version: int = 17,
    output_names: Optional[list] = None,
    verify: bool = True,
    **kwargs,
) -> Path:
    """Export a model to the requested format."""
    fmt = fmt.lower()
    if fmt == "onnx":
        return export_to_onnx(
            model,
            save_path,
            input_shape=input_shape,
            opset_version=opset_version,
            output_names=output_names,
            verify=verify,
        )
    if fmt == "tensorrt":
        return export_to_tensorrt(model, save_path, input_shape=input_shape, **kwargs)
    if fmt == "both":
        export_to_onnx(
            model,
            save_path,
            input_shape=input_shape,
            opset_version=opset_version,
            output_names=output_names,
            verify=verify,
        )
        return export_to_tensorrt(model, save_path, input_shape=input_shape, **kwargs)

    raise ValueError(f"Unknown export format: {fmt}")


__all__ = ["export_model", "export_to_onnx", "export_to_tensorrt"]
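
A minimal usage sketch of the dispatcher above. The ToyWrapper (uint8 frames cast to float before a small convolution) and the output path are illustrative placeholders, not sleap-nn code; the verify step assumes the onnx package is installed.

# Sketch only: placeholder model and path, assuming onnx is installed for verify.
import torch
from sleap_nn.export.exporters import export_model


class ToyWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 4, kernel_size=3, padding=1)

    def forward(self, image):
        # Mirrors the documented preprocessing: uint8 [0, 255] in, float [0, 1] inside.
        return self.conv(image.to(torch.float32) / 255.0)


path = export_model(
    ToyWrapper(),
    "toy_model.onnx",
    fmt="onnx",                  # "tensorrt" or "both" additionally require a GPU + TensorRT
    input_shape=(1, 1, 64, 64),
    verify=True,
)
print(path)  # Path("toy_model.onnx")
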
sleap_nn/export/exporters/onnx_exporter.py
@@ -0,0 +1,80 @@
"""ONNX export utilities."""

from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, List, Optional

import torch


def export_to_onnx(
    model: torch.nn.Module,
    save_path: str | Path,
    input_shape: Iterable[int] = (1, 1, 512, 512),
    input_dtype: torch.dtype = torch.uint8,
    opset_version: int = 17,
    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
    input_names: Optional[List[str]] = None,
    output_names: Optional[List[str]] = None,
    do_constant_folding: bool = True,
    verify: bool = True,
) -> Path:
    """Export a PyTorch model to ONNX."""
    save_path = Path(save_path)
    model.eval()

    if input_names is None:
        input_names = ["image"]
    if dynamic_axes is None:
        dynamic_axes = {"image": {0: "batch", 2: "height", 3: "width"}}

    device = None
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    if input_dtype.is_floating_point:
        dummy_input = torch.randn(*input_shape, device=device, dtype=input_dtype)
    else:
        dummy_input = torch.randint(
            0, 256, input_shape, device=device, dtype=input_dtype
        )

    if output_names is None:
        with torch.no_grad():
            test_out = model(dummy_input)
        output_names = _infer_output_names(test_out)

    torch.onnx.export(
        model,
        dummy_input,
        save_path.as_posix(),
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        do_constant_folding=do_constant_folding,
        dynamo=False,
    )

    if verify:
        _verify_onnx(save_path)

    return save_path


def _infer_output_names(output) -> List[str]:
    if isinstance(output, dict):
        return list(output.keys())
    if isinstance(output, (list, tuple)):
        return [f"output_{idx}" for idx in range(len(output))]
    return ["output_0"]


def _verify_onnx(path: Path) -> None:
    import onnx

    model = onnx.load(path.as_posix())
    onnx.checker.check_model(model)
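
Calling export_to_onnx directly looks like the sketch below. The tiny model, the file name, and the onnxruntime round trip are assumptions for illustration and are not part of the package.

# Sketch only: placeholder model/path; onnxruntime is an assumed extra dependency.
import numpy as np
import torch

from sleap_nn.export.exporters.onnx_exporter import export_to_onnx

model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 2, kernel_size=3, padding=1),
    torch.nn.ReLU(),
)

path = export_to_onnx(
    model,
    "tiny.onnx",
    input_shape=(1, 1, 32, 32),
    input_dtype=torch.float32,  # skip the uint8 default for this plain conv model
    verify=True,                # runs onnx.checker.check_model on the saved file
)

# The default dynamic_axes mark batch/height/width as dynamic, so a different
# input size than the export shape should still run.
import onnxruntime as ort

session = ort.InferenceSession(path.as_posix())
out = session.run(None, {"image": np.zeros((2, 1, 64, 64), dtype=np.float32)})
print([o.shape for o in out])  # e.g. [(2, 2, 64, 64)]
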
sleap_nn/export/exporters/tensorrt_exporter.py
@@ -0,0 +1,291 @@
"""TensorRT export utilities."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, Optional, Tuple

import torch
from torch import nn


def export_to_tensorrt(
    model: nn.Module,
    save_path: str | Path,
    input_shape: Tuple[int, int, int, int] = (1, 1, 512, 512),
    input_dtype: torch.dtype = torch.uint8,
    precision: str = "fp16",
    min_shape: Optional[Tuple[int, int, int, int]] = None,
    opt_shape: Optional[Tuple[int, int, int, int]] = None,
    max_shape: Optional[Tuple[int, int, int, int]] = None,
    workspace_size: int = 2 << 30,  # 2GB default
    method: str = "onnx",
    verbose: bool = True,
) -> Path:
    """Export a PyTorch model to TensorRT format.

    This function supports multiple compilation methods:
    - "onnx": Exports to ONNX first, then compiles with TensorRT (most reliable)
    - "jit": Uses torch.jit.trace + torch_tensorrt.compile (alternative)

    Args:
        model: The PyTorch model to export (typically an ONNX wrapper).
        save_path: Path to save the TensorRT engine (.trt file).
        input_shape: (B, C, H, W) optimal input tensor shape.
        input_dtype: Input tensor dtype (torch.uint8 or torch.float32).
        precision: Model precision - "fp32" or "fp16".
        min_shape: Minimum input shape for dynamic shapes (default: batch=1, H/W halved).
        opt_shape: Optimal input shape (default: same as input_shape).
        max_shape: Maximum input shape (default: batch=16, H/W doubled).
        workspace_size: TensorRT workspace size in bytes (default 2GB).
        method: Compilation method - "onnx" or "jit".
        verbose: Print export info.

    Returns:
        Path to the exported TensorRT engine.

    Note:
        TensorRT models are NOT cross-platform. The exported model will only
        work on the same GPU architecture and TensorRT version used for export.
    """
    import tensorrt as trt

    model.eval()
    device = next(model.parameters()).device

    save_path = Path(save_path)
    if not save_path.suffix:
        save_path = save_path.with_suffix(".trt")

    B, C, H, W = input_shape

    if min_shape is None:
        min_shape = (1, C, H // 2, W // 2)
    if opt_shape is None:
        opt_shape = input_shape
    if max_shape is None:
        max_shape = (min(16, B * 4), C, H * 2, W * 2)

    if verbose:
        print(f"Exporting model to TensorRT...")
        print(f" Input shape: {input_shape}")
        print(f" Min/Opt/Max: {min_shape} / {opt_shape} / {max_shape}")
        print(f" Precision: {precision}")
        print(f" Workspace: {workspace_size / 1e9:.1f} GB")
        print(f" Method: {method}")

    if method == "onnx":
        return _export_tensorrt_onnx(
            model,
            save_path,
            input_shape,
            input_dtype,
            min_shape,
            opt_shape,
            max_shape,
            precision,
            workspace_size,
            verbose,
        )
    elif method == "jit":
        return _export_tensorrt_jit(
            model,
            save_path,
            input_shape,
            input_dtype,
            min_shape,
            opt_shape,
            max_shape,
            precision,
            workspace_size,
            verbose,
        )
    else:
        raise ValueError(f"Unknown method: {method}. Use 'onnx' or 'jit'.")


def _export_tensorrt_onnx(
    model: nn.Module,
    save_path: Path,
    input_shape: Tuple[int, int, int, int],
    input_dtype: torch.dtype,
    min_shape: Tuple[int, int, int, int],
    opt_shape: Tuple[int, int, int, int],
    max_shape: Tuple[int, int, int, int],
    precision: str,
    workspace_size: int,
    verbose: bool,
) -> Path:
    """Export via ONNX, then compile to TensorRT engine."""
    import tensorrt as trt

    # Check if ONNX file already exists (from prior export step)
    onnx_path = save_path.with_suffix(".onnx")

    if onnx_path.exists():
        if verbose:
            print(f" Using existing ONNX file: {onnx_path}")
    else:
        # Need to export to ONNX first
        device = next(model.parameters()).device

        if verbose:
            print(" Exporting to ONNX first...")

        # Create example input with correct dtype
        if input_dtype == torch.uint8:
            example_input = torch.randint(
                0, 255, input_shape, dtype=torch.uint8, device=device
            )
        else:
            example_input = torch.randn(*input_shape, dtype=input_dtype, device=device)

        # Export to ONNX
        torch.onnx.export(
            model,
            example_input,
            onnx_path,
            opset_version=17,
            input_names=["images"],
            dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"}},
            do_constant_folding=True,
        )

        if verbose:
            print(f" ONNX export complete: {onnx_path}")

    if verbose:
        print(f" Building TensorRT engine (this may take a while)...")

    # Create TensorRT builder
    logger = trt.Logger(trt.Logger.WARNING if not verbose else trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, logger)

    # Parse ONNX model
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            errors = []
            for i in range(parser.num_errors):
                errors.append(str(parser.get_error(i)))
            raise RuntimeError(f"ONNX parsing failed: {errors}")

    # Configure builder
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size)

    if precision == "fp16":
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            if verbose:
                print(" Enabled FP16 mode")
        else:
            if verbose:
                print(" WARNING: Platform does not have fast FP16, using FP32")

    # Set up optimization profile for dynamic shapes
    profile = builder.create_optimization_profile()
    input_name = network.get_input(0).name

    profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    # Build engine
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError("Failed to build TensorRT engine")

    # Save engine
    engine_path = save_path.with_suffix(".trt")
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    if verbose:
        import os

        print(f" Exported TensorRT engine to: {engine_path}")
        print(f" Engine size: {os.path.getsize(engine_path) / 1e6:.2f} MB")

    return engine_path


def _export_tensorrt_jit(
    model: nn.Module,
    save_path: Path,
    input_shape: Tuple[int, int, int, int],
    input_dtype: torch.dtype,
    min_shape: Tuple[int, int, int, int],
    opt_shape: Tuple[int, int, int, int],
    max_shape: Tuple[int, int, int, int],
    precision: str,
    workspace_size: int,
    verbose: bool,
) -> Path:
    """Export using torch.jit.trace + torch_tensorrt.compile."""
    import torch_tensorrt

    device = next(model.parameters()).device

    if verbose:
        print(" Tracing model with torch.jit...")

    # Create example input (float32 for tracing)
    if input_dtype == torch.uint8:
        example_input = torch.randint(
            0, 255, input_shape, dtype=torch.uint8, device=device
        )
    else:
        example_input = torch.randn(*input_shape, dtype=input_dtype, device=device)

    # Trace the model
    with torch.no_grad():
        traced_model = torch.jit.trace(model, example_input)

    # Map precision to torch dtype
    precision_map = {
        "fp32": torch.float32,
        "fp16": torch.float16,
    }
    if precision not in precision_map:
        raise ValueError(f"Unknown precision: {precision}. Use 'fp32' or 'fp16'")

    enabled_precisions = {precision_map[precision]}
    if precision == "fp16":
        enabled_precisions.add(torch.float32)

    # Create input specs for TensorRT
    trt_inputs = [
        torch_tensorrt.Input(
            min_shape=min_shape,
            opt_shape=opt_shape,
            max_shape=max_shape,
            dtype=torch.float32,  # TRT internally uses float32 input spec
        )
    ]

    if verbose:
        print(" Compiling with TensorRT...")

    # Compile with TensorRT
    trt_model = torch_tensorrt.compile(
        traced_model,
        inputs=trt_inputs,
        enabled_precisions=enabled_precisions,
        workspace_size=workspace_size,
        truncate_long_and_double=True,
    )

    # Save as TorchScript
    ts_path = save_path.with_suffix(".ts")
    torch.jit.save(trt_model, ts_path)

    if verbose:
        import os

        print(f" Exported TensorRT-accelerated model to: {ts_path}")
        print(f" Model size: {os.path.getsize(ts_path) / 1e6:.2f} MB")

    return ts_path
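
A hedged sketch of the ONNX-based compilation path above. It assumes an NVIDIA GPU with the tensorrt Python package installed; the toy conv model and the "toy_engine" path are placeholders, not part of sleap-nn.

# Sketch only: requires CUDA + tensorrt; placeholder model and engine path.
import torch
from sleap_nn.export.exporters.tensorrt_exporter import export_to_tensorrt

if torch.cuda.is_available():
    model = torch.nn.Conv2d(1, 4, kernel_size=3, padding=1).cuda().eval()
    engine_path = export_to_tensorrt(
        model,
        "toy_engine",                  # ".trt" suffix is added automatically
        input_shape=(1, 1, 512, 512),  # opt shape; min/max default to half/double H and W
        input_dtype=torch.float32,     # this toy model expects float input, not uint8
        precision="fp16",              # falls back to FP32 if fast FP16 is unavailable
        method="onnx",                 # ONNX export + TensorRT builder path above
    )
    print(engine_path)
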
sleap_nn/export/metadata.py
@@ -0,0 +1,225 @@
"""Metadata helpers for exported models."""

from __future__ import annotations

from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import hashlib
import json

from sleap_nn import __version__


@dataclass
class ExportMetadata:
    """Metadata embedded or saved alongside exported models."""

    # Version info
    sleap_nn_version: str
    export_timestamp: str
    export_format: str  # "onnx" or "tensorrt"

    # Model info
    model_type: str  # "centroid", "centered_instance", "topdown", "bottomup"
    model_name: str
    checkpoint_path: str

    # Architecture
    backbone: str
    n_nodes: int
    n_edges: int
    node_names: List[str]
    edge_inds: List[Tuple[int, int]]

    # Input/output spec
    input_scale: float
    input_channels: int
    output_stride: int
    crop_size: Optional[Tuple[int, int]] = None

    # Export parameters
    max_instances: Optional[int] = None
    max_peaks_per_node: Optional[int] = None
    max_batch_size: int = 1
    precision: str = "fp32"

    # Preprocessing - input is uint8 [0,255], normalized internally to float32 [0,1]
    input_dtype: str = "uint8"
    normalization: str = "0_to_1"

    # Multiclass model fields (optional)
    n_classes: Optional[int] = None
    class_names: Optional[List[str]] = None

    # Training config reference
    training_config_embedded: bool = False
    training_config_hash: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert to JSON-serializable dict."""
        data = asdict(self)
        data["edge_inds"] = [list(pair) for pair in self.edge_inds]
        if self.crop_size is not None:
            data["crop_size"] = list(self.crop_size)
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExportMetadata":
        """Load from dict."""
        edge_inds = [tuple(pair) for pair in data.get("edge_inds", [])]
        crop_size = data.get("crop_size")
        if crop_size is not None:
            crop_size = tuple(crop_size)
        return cls(
            sleap_nn_version=data.get("sleap_nn_version", ""),
            export_timestamp=data.get("export_timestamp", ""),
            export_format=data.get("export_format", ""),
            model_type=data.get("model_type", ""),
            model_name=data.get("model_name", ""),
            checkpoint_path=data.get("checkpoint_path", ""),
            backbone=data.get("backbone", ""),
            n_nodes=int(data.get("n_nodes", 0)),
            n_edges=int(data.get("n_edges", 0)),
            node_names=list(data.get("node_names", [])),
            edge_inds=edge_inds,
            input_scale=float(data.get("input_scale", 1.0)),
            input_channels=int(data.get("input_channels", 1)),
            output_stride=int(data.get("output_stride", 1)),
            crop_size=crop_size,
            max_instances=data.get("max_instances"),
            max_peaks_per_node=data.get("max_peaks_per_node"),
            max_batch_size=int(data.get("max_batch_size", 1)),
            precision=data.get("precision", "fp32"),
            input_dtype=data.get("input_dtype", "uint8"),
            normalization=data.get("normalization", "0_to_1"),
            n_classes=data.get("n_classes"),
            class_names=data.get("class_names"),
            training_config_embedded=bool(data.get("training_config_embedded", False)),
            training_config_hash=data.get("training_config_hash", ""),
        )

    def save(self, path: str | Path) -> None:
        """Save to JSON file."""
        path = Path(path)
        path.write_text(json.dumps(self.to_dict(), indent=2, sort_keys=True))

    @classmethod
    def load(cls, path: str | Path) -> "ExportMetadata":
        """Load from JSON file."""
        path = Path(path)
        data = json.loads(path.read_text())
        return cls.from_dict(data)

    @classmethod
    def default_timestamp(cls) -> str:
        """Return an ISO timestamp for export."""
        return datetime.now().isoformat()


def hash_file(path: str | Path) -> str:
    """Compute SHA256 hash for a file."""
    path = Path(path)
    hasher = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(8192), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def build_base_metadata(
    *,
    export_format: str,
    model_type: str,
    model_name: str,
    checkpoint_path: str,
    backbone: str,
    n_nodes: int,
    n_edges: int,
    node_names: List[str],
    edge_inds: List[Tuple[int, int]],
    input_scale: float,
    input_channels: int,
    output_stride: int,
    crop_size: Optional[Tuple[int, int]] = None,
    max_instances: Optional[int] = None,
    max_peaks_per_node: Optional[int] = None,
    max_batch_size: int = 1,
    precision: str = "fp32",
    training_config_hash: str = "",
    training_config_embedded: bool = False,
    input_dtype: str = "uint8",
    normalization: str = "0_to_1",
    n_classes: Optional[int] = None,
    class_names: Optional[List[str]] = None,
) -> ExportMetadata:
    """Create an ExportMetadata instance with standard defaults."""
    return ExportMetadata(
        sleap_nn_version=__version__,
        export_timestamp=ExportMetadata.default_timestamp(),
        export_format=export_format,
        model_type=model_type,
        model_name=model_name,
        checkpoint_path=checkpoint_path,
        backbone=backbone,
        n_nodes=n_nodes,
        n_edges=n_edges,
        node_names=node_names,
        edge_inds=edge_inds,
        input_scale=input_scale,
        input_channels=input_channels,
        output_stride=output_stride,
        crop_size=crop_size,
        max_instances=max_instances,
        max_peaks_per_node=max_peaks_per_node,
        max_batch_size=max_batch_size,
        precision=precision,
        input_dtype=input_dtype,
        normalization=normalization,
        n_classes=n_classes,
        class_names=class_names,
        training_config_embedded=training_config_embedded,
        training_config_hash=training_config_hash,
    )


def embed_metadata_in_onnx(
    model_path: str | Path,
    metadata: ExportMetadata,
    training_config_text: Optional[str] = None,
) -> None:
    """Embed metadata into an ONNX model file.

    Raises ImportError if onnx is unavailable.
    """
    import onnx  # local import to keep dependency optional

    model_path = Path(model_path)
    model = onnx.load(model_path.as_posix())
    model.metadata_props.append(
        onnx.StringStringEntryProto(
            key="sleap_nn_metadata", value=json.dumps(metadata.to_dict())
        )
    )
    if training_config_text:
        model.metadata_props.append(
            onnx.StringStringEntryProto(
                key="training_config", value=training_config_text
            )
        )
    onnx.save(model, model_path.as_posix())


def extract_metadata_from_onnx(model_path: str | Path) -> ExportMetadata:
    """Extract metadata from an ONNX model file.

    Raises ValueError if metadata is missing.
    """
    import onnx  # local import to keep dependency optional

    model = onnx.load(Path(model_path).as_posix())
    for prop in model.metadata_props:
        if prop.key == "sleap_nn_metadata":
            return ExportMetadata.from_dict(json.loads(prop.value))
    raise ValueError("No sleap_nn metadata found in ONNX model")
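
A sketch of how the metadata helpers fit together. The skeleton fields, the checkpoint path string, and the pre-existing "model.onnx" file are illustrative assumptions rather than real artifacts.

# Sketch only: placeholder field values and paths; "model.onnx" must already exist
# (e.g. produced by the exporters above) for the embed/extract round trip.
from sleap_nn.export.metadata import (
    ExportMetadata,
    build_base_metadata,
    embed_metadata_in_onnx,
    extract_metadata_from_onnx,
)

meta = build_base_metadata(
    export_format="onnx",
    model_type="centroid",
    model_name="toy",
    checkpoint_path="models/toy.ckpt",
    backbone="unet",
    n_nodes=2,
    n_edges=1,
    node_names=["head", "tail"],
    edge_inds=[(0, 1)],
    input_scale=1.0,
    input_channels=1,
    output_stride=2,
)

# JSON sidecar round trip.
meta.save("model.metadata.json")
print(ExportMetadata.load("model.metadata.json").node_names)

# Embed into / read back from the ONNX file's metadata_props.
embed_metadata_in_onnx("model.onnx", meta)
print(extract_metadata_from_onnx("model.onnx").export_timestamp)
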