sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +168 -39
  6. sleap_nn/evaluation.py +8 -0
  7. sleap_nn/export/__init__.py +21 -0
  8. sleap_nn/export/cli.py +1778 -0
  9. sleap_nn/export/exporters/__init__.py +51 -0
  10. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  11. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  12. sleap_nn/export/metadata.py +225 -0
  13. sleap_nn/export/predictors/__init__.py +63 -0
  14. sleap_nn/export/predictors/base.py +22 -0
  15. sleap_nn/export/predictors/onnx.py +154 -0
  16. sleap_nn/export/predictors/tensorrt.py +312 -0
  17. sleap_nn/export/utils.py +307 -0
  18. sleap_nn/export/wrappers/__init__.py +25 -0
  19. sleap_nn/export/wrappers/base.py +96 -0
  20. sleap_nn/export/wrappers/bottomup.py +243 -0
  21. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  22. sleap_nn/export/wrappers/centered_instance.py +56 -0
  23. sleap_nn/export/wrappers/centroid.py +58 -0
  24. sleap_nn/export/wrappers/single_instance.py +83 -0
  25. sleap_nn/export/wrappers/topdown.py +180 -0
  26. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  27. sleap_nn/inference/peak_finding.py +47 -17
  28. sleap_nn/inference/postprocessing.py +284 -0
  29. sleap_nn/inference/predictors.py +213 -106
  30. sleap_nn/predict.py +35 -7
  31. sleap_nn/train.py +64 -0
  32. sleap_nn/training/callbacks.py +69 -22
  33. sleap_nn/training/lightning_modules.py +332 -30
  34. sleap_nn/training/model_trainer.py +67 -67
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
  36. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
  37. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
  38. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
  39. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
  40. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/export/exporters/__init__.py
@@ -0,0 +1,51 @@
+"""Exporters for serialized model formats."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Iterable, Optional
+
+import torch
+
+from sleap_nn.export.exporters.onnx_exporter import export_to_onnx
+from sleap_nn.export.exporters.tensorrt_exporter import export_to_tensorrt
+
+
+def export_model(
+    model: torch.nn.Module,
+    save_path: str | Path,
+    fmt: str = "onnx",
+    input_shape: Iterable[int] = (1, 1, 512, 512),
+    opset_version: int = 17,
+    output_names: Optional[list] = None,
+    verify: bool = True,
+    **kwargs,
+) -> Path:
+    """Export a model to the requested format."""
+    fmt = fmt.lower()
+    if fmt == "onnx":
+        return export_to_onnx(
+            model,
+            save_path,
+            input_shape=input_shape,
+            opset_version=opset_version,
+            output_names=output_names,
+            verify=verify,
+        )
+    if fmt == "tensorrt":
+        return export_to_tensorrt(model, save_path, input_shape=input_shape, **kwargs)
+    if fmt == "both":
+        export_to_onnx(
+            model,
+            save_path,
+            input_shape=input_shape,
+            opset_version=opset_version,
+            output_names=output_names,
+            verify=verify,
+        )
+        return export_to_tensorrt(model, save_path, input_shape=input_shape, **kwargs)
+
+    raise ValueError(f"Unknown export format: {fmt}")
+
+
+__all__ = ["export_model", "export_to_onnx", "export_to_tensorrt"]
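The dispatcher above selects an exporter by the `fmt` string. A minimal usage sketch, assuming a toy wrapper module that accepts uint8 images and normalizes them internally (as the package's export wrappers do); `ToyWrapper` is hypothetical and not part of this package:

import torch
from sleap_nn.export.exporters import export_model

class ToyWrapper(torch.nn.Module):
    """Hypothetical stand-in for an export wrapper: uint8 images in, normalized to [0, 1]."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, kernel_size=3, padding=1)

    def forward(self, image):
        return self.conv(image.to(torch.float32) / 255.0)

# Writes model.onnx using the defaults above and returns the saved path.
onnx_path = export_model(ToyWrapper().eval(), "model.onnx", fmt="onnx")
print(onnx_path)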
sleap_nn/export/exporters/onnx_exporter.py
@@ -0,0 +1,80 @@
+"""ONNX export utilities."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Dict, Iterable, List, Optional
+
+import torch
+
+
+def export_to_onnx(
+    model: torch.nn.Module,
+    save_path: str | Path,
+    input_shape: Iterable[int] = (1, 1, 512, 512),
+    input_dtype: torch.dtype = torch.uint8,
+    opset_version: int = 17,
+    dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
+    input_names: Optional[List[str]] = None,
+    output_names: Optional[List[str]] = None,
+    do_constant_folding: bool = True,
+    verify: bool = True,
+) -> Path:
+    """Export a PyTorch model to ONNX."""
+    save_path = Path(save_path)
+    model.eval()
+
+    if input_names is None:
+        input_names = ["image"]
+    if dynamic_axes is None:
+        dynamic_axes = {"image": {0: "batch", 2: "height", 3: "width"}}
+
+    device = None
+    try:
+        device = next(model.parameters()).device
+    except StopIteration:
+        device = torch.device("cpu")
+
+    if input_dtype.is_floating_point:
+        dummy_input = torch.randn(*input_shape, device=device, dtype=input_dtype)
+    else:
+        dummy_input = torch.randint(
+            0, 256, input_shape, device=device, dtype=input_dtype
+        )
+
+    if output_names is None:
+        with torch.no_grad():
+            test_out = model(dummy_input)
+        output_names = _infer_output_names(test_out)
+
+    torch.onnx.export(
+        model,
+        dummy_input,
+        save_path.as_posix(),
+        opset_version=opset_version,
+        input_names=input_names,
+        output_names=output_names,
+        dynamic_axes=dynamic_axes,
+        do_constant_folding=do_constant_folding,
+        dynamo=False,
+    )
+
+    if verify:
+        _verify_onnx(save_path)
+
+    return save_path
+
+
+def _infer_output_names(output) -> List[str]:
+    if isinstance(output, dict):
+        return list(output.keys())
+    if isinstance(output, (list, tuple)):
+        return [f"output_{idx}" for idx in range(len(output))]
+    return ["output_0"]
+
+
+def _verify_onnx(path: Path) -> None:
+    import onnx
+
+    model = onnx.load(path.as_posix())
+    onnx.checker.check_model(model)
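Beyond the `onnx.checker` verification above, the exported file can be exercised with onnxruntime (an assumed extra dependency, not introduced by this diff), using the default "image" input name and uint8 dtype from `export_to_onnx`:

import numpy as np
import onnxruntime as ort  # assumed extra dependency for this check

# Load the file written by export_to_onnx and run one uint8 image through it.
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
image = np.random.randint(0, 256, size=(1, 1, 512, 512), dtype=np.uint8)
outputs = sess.run(None, {"image": image})
print([o.shape for o in outputs])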
sleap_nn/export/exporters/tensorrt_exporter.py
@@ -0,0 +1,291 @@
+"""TensorRT export utilities."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+
+def export_to_tensorrt(
+    model: nn.Module,
+    save_path: str | Path,
+    input_shape: Tuple[int, int, int, int] = (1, 1, 512, 512),
+    input_dtype: torch.dtype = torch.uint8,
+    precision: str = "fp16",
+    min_shape: Optional[Tuple[int, int, int, int]] = None,
+    opt_shape: Optional[Tuple[int, int, int, int]] = None,
+    max_shape: Optional[Tuple[int, int, int, int]] = None,
+    workspace_size: int = 2 << 30,  # 2GB default
+    method: str = "onnx",
+    verbose: bool = True,
+) -> Path:
+    """Export a PyTorch model to TensorRT format.
+
+    This function supports multiple compilation methods:
+    - "onnx": Exports to ONNX first, then compiles with TensorRT (most reliable)
+    - "jit": Uses torch.jit.trace + torch_tensorrt.compile (alternative)
+
+    Args:
+        model: The PyTorch model to export (typically an ONNX wrapper).
+        save_path: Path to save the TensorRT engine (.trt file).
+        input_shape: (B, C, H, W) optimal input tensor shape.
+        input_dtype: Input tensor dtype (torch.uint8 or torch.float32).
+        precision: Model precision - "fp32" or "fp16".
+        min_shape: Minimum input shape for dynamic shapes (default: batch=1, H/W halved).
+        opt_shape: Optimal input shape (default: same as input_shape).
+        max_shape: Maximum input shape (default: batch=16, H/W doubled).
+        workspace_size: TensorRT workspace size in bytes (default 2GB).
+        method: Compilation method - "onnx" or "jit".
+        verbose: Print export info.
+
+    Returns:
+        Path to the exported TensorRT engine.
+
+    Note:
+        TensorRT models are NOT cross-platform. The exported model will only
+        work on the same GPU architecture and TensorRT version used for export.
+    """
+    import tensorrt as trt
+
+    model.eval()
+    device = next(model.parameters()).device
+
+    save_path = Path(save_path)
+    if not save_path.suffix:
+        save_path = save_path.with_suffix(".trt")
+
+    B, C, H, W = input_shape
+
+    if min_shape is None:
+        min_shape = (1, C, H // 2, W // 2)
+    if opt_shape is None:
+        opt_shape = input_shape
+    if max_shape is None:
+        max_shape = (min(16, B * 4), C, H * 2, W * 2)
+
+    if verbose:
+        print(f"Exporting model to TensorRT...")
+        print(f" Input shape: {input_shape}")
+        print(f" Min/Opt/Max: {min_shape} / {opt_shape} / {max_shape}")
+        print(f" Precision: {precision}")
+        print(f" Workspace: {workspace_size / 1e9:.1f} GB")
+        print(f" Method: {method}")
+
+    if method == "onnx":
+        return _export_tensorrt_onnx(
+            model,
+            save_path,
+            input_shape,
+            input_dtype,
+            min_shape,
+            opt_shape,
+            max_shape,
+            precision,
+            workspace_size,
+            verbose,
+        )
+    elif method == "jit":
+        return _export_tensorrt_jit(
+            model,
+            save_path,
+            input_shape,
+            input_dtype,
+            min_shape,
+            opt_shape,
+            max_shape,
+            precision,
+            workspace_size,
+            verbose,
+        )
+    else:
+        raise ValueError(f"Unknown method: {method}. Use 'onnx' or 'jit'.")
+
+
+def _export_tensorrt_onnx(
+    model: nn.Module,
+    save_path: Path,
+    input_shape: Tuple[int, int, int, int],
+    input_dtype: torch.dtype,
+    min_shape: Tuple[int, int, int, int],
+    opt_shape: Tuple[int, int, int, int],
+    max_shape: Tuple[int, int, int, int],
+    precision: str,
+    workspace_size: int,
+    verbose: bool,
+) -> Path:
+    """Export via ONNX, then compile to TensorRT engine."""
+    import tensorrt as trt
+
+    # Check if ONNX file already exists (from prior export step)
+    onnx_path = save_path.with_suffix(".onnx")
+
+    if onnx_path.exists():
+        if verbose:
+            print(f" Using existing ONNX file: {onnx_path}")
+    else:
+        # Need to export to ONNX first
+        device = next(model.parameters()).device
+
+        if verbose:
+            print(" Exporting to ONNX first...")
+
+        # Create example input with correct dtype
+        if input_dtype == torch.uint8:
+            example_input = torch.randint(
+                0, 255, input_shape, dtype=torch.uint8, device=device
+            )
+        else:
+            example_input = torch.randn(*input_shape, dtype=input_dtype, device=device)
+
+        # Export to ONNX
+        torch.onnx.export(
+            model,
+            example_input,
+            onnx_path,
+            opset_version=17,
+            input_names=["images"],
+            dynamic_axes={"images": {0: "batch", 2: "height", 3: "width"}},
+            do_constant_folding=True,
+        )
+
+        if verbose:
+            print(f" ONNX export complete: {onnx_path}")
+
+    if verbose:
+        print(f" Building TensorRT engine (this may take a while)...")
+
+    # Create TensorRT builder
+    logger = trt.Logger(trt.Logger.WARNING if not verbose else trt.Logger.INFO)
+    builder = trt.Builder(logger)
+    network = builder.create_network(
+        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    )
+    parser = trt.OnnxParser(network, logger)
+
+    # Parse ONNX model
+    with open(onnx_path, "rb") as f:
+        if not parser.parse(f.read()):
+            errors = []
+            for i in range(parser.num_errors):
+                errors.append(str(parser.get_error(i)))
+            raise RuntimeError(f"ONNX parsing failed: {errors}")
+
+    # Configure builder
+    config = builder.create_builder_config()
+    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size)
+
+    if precision == "fp16":
+        if builder.platform_has_fast_fp16:
+            config.set_flag(trt.BuilderFlag.FP16)
+            if verbose:
+                print(" Enabled FP16 mode")
+        else:
+            if verbose:
+                print(" WARNING: Platform does not have fast FP16, using FP32")
+
+    # Set up optimization profile for dynamic shapes
+    profile = builder.create_optimization_profile()
+    input_name = network.get_input(0).name
+
+    profile.set_shape(input_name, min_shape, opt_shape, max_shape)
+    config.add_optimization_profile(profile)
+
+    # Build engine
+    serialized_engine = builder.build_serialized_network(network, config)
+    if serialized_engine is None:
+        raise RuntimeError("Failed to build TensorRT engine")
+
+    # Save engine
+    engine_path = save_path.with_suffix(".trt")
+    with open(engine_path, "wb") as f:
+        f.write(serialized_engine)
+
+    if verbose:
+        import os
+
+        print(f" Exported TensorRT engine to: {engine_path}")
+        print(f" Engine size: {os.path.getsize(engine_path) / 1e6:.2f} MB")
+
+    return engine_path
+
+
+def _export_tensorrt_jit(
+    model: nn.Module,
+    save_path: Path,
+    input_shape: Tuple[int, int, int, int],
+    input_dtype: torch.dtype,
+    min_shape: Tuple[int, int, int, int],
+    opt_shape: Tuple[int, int, int, int],
+    max_shape: Tuple[int, int, int, int],
+    precision: str,
+    workspace_size: int,
+    verbose: bool,
+) -> Path:
+    """Export using torch.jit.trace + torch_tensorrt.compile."""
+    import torch_tensorrt
+
+    device = next(model.parameters()).device
+
+    if verbose:
+        print(" Tracing model with torch.jit...")
+
+    # Create example input (float32 for tracing)
+    if input_dtype == torch.uint8:
+        example_input = torch.randint(
+            0, 255, input_shape, dtype=torch.uint8, device=device
+        )
+    else:
+        example_input = torch.randn(*input_shape, dtype=input_dtype, device=device)
+
+    # Trace the model
+    with torch.no_grad():
+        traced_model = torch.jit.trace(model, example_input)
+
+    # Map precision to torch dtype
+    precision_map = {
+        "fp32": torch.float32,
+        "fp16": torch.float16,
+    }
+    if precision not in precision_map:
+        raise ValueError(f"Unknown precision: {precision}. Use 'fp32' or 'fp16'")
+
+    enabled_precisions = {precision_map[precision]}
+    if precision == "fp16":
+        enabled_precisions.add(torch.float32)
+
+    # Create input specs for TensorRT
+    trt_inputs = [
+        torch_tensorrt.Input(
+            min_shape=min_shape,
+            opt_shape=opt_shape,
+            max_shape=max_shape,
+            dtype=torch.float32,  # TRT internally uses float32 input spec
+        )
+    ]
+
+    if verbose:
+        print(" Compiling with TensorRT...")
+
+    # Compile with TensorRT
+    trt_model = torch_tensorrt.compile(
+        traced_model,
+        inputs=trt_inputs,
+        enabled_precisions=enabled_precisions,
+        workspace_size=workspace_size,
+        truncate_long_and_double=True,
+    )
+
+    # Save as TorchScript
+    ts_path = save_path.with_suffix(".ts")
+    torch.jit.save(trt_model, ts_path)
+
+    if verbose:
+        import os
+
+        print(f" Exported TensorRT-accelerated model to: {ts_path}")
+        print(f" Model size: {os.path.getsize(ts_path) / 1e6:.2f} MB")
+
+    return ts_path
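A hedged usage sketch of the ONNX-path export above. It assumes a CUDA device with TensorRT installed and reuses the hypothetical `ToyWrapper` from the earlier sketch; engines built this way are tied to the local GPU and TensorRT version:

from pathlib import Path
from sleap_nn.export.exporters.tensorrt_exporter import export_to_tensorrt

# Hypothetical wrapper on GPU; writes model.onnx (if missing) then model.trt.
model = ToyWrapper().eval().cuda()
engine_path = export_to_tensorrt(
    model,
    Path("model.trt"),
    input_shape=(1, 1, 512, 512),
    precision="fp16",
    method="onnx",
)
print(engine_path)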
sleap_nn/export/metadata.py
@@ -0,0 +1,225 @@
+"""Metadata helpers for exported models."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, asdict
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+import hashlib
+import json
+
+from sleap_nn import __version__
+
+
+@dataclass
+class ExportMetadata:
+    """Metadata embedded or saved alongside exported models."""
+
+    # Version info
+    sleap_nn_version: str
+    export_timestamp: str
+    export_format: str  # "onnx" or "tensorrt"
+
+    # Model info
+    model_type: str  # "centroid", "centered_instance", "topdown", "bottomup"
+    model_name: str
+    checkpoint_path: str
+
+    # Architecture
+    backbone: str
+    n_nodes: int
+    n_edges: int
+    node_names: List[str]
+    edge_inds: List[Tuple[int, int]]
+
+    # Input/output spec
+    input_scale: float
+    input_channels: int
+    output_stride: int
+    crop_size: Optional[Tuple[int, int]] = None
+
+    # Export parameters
+    max_instances: Optional[int] = None
+    max_peaks_per_node: Optional[int] = None
+    max_batch_size: int = 1
+    precision: str = "fp32"
+
+    # Preprocessing - input is uint8 [0,255], normalized internally to float32 [0,1]
+    input_dtype: str = "uint8"
+    normalization: str = "0_to_1"
+
+    # Multiclass model fields (optional)
+    n_classes: Optional[int] = None
+    class_names: Optional[List[str]] = None
+
+    # Training config reference
+    training_config_embedded: bool = False
+    training_config_hash: str = ""
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to JSON-serializable dict."""
+        data = asdict(self)
+        data["edge_inds"] = [list(pair) for pair in self.edge_inds]
+        if self.crop_size is not None:
+            data["crop_size"] = list(self.crop_size)
+        return data
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ExportMetadata":
+        """Load from dict."""
+        edge_inds = [tuple(pair) for pair in data.get("edge_inds", [])]
+        crop_size = data.get("crop_size")
+        if crop_size is not None:
+            crop_size = tuple(crop_size)
+        return cls(
+            sleap_nn_version=data.get("sleap_nn_version", ""),
+            export_timestamp=data.get("export_timestamp", ""),
+            export_format=data.get("export_format", ""),
+            model_type=data.get("model_type", ""),
+            model_name=data.get("model_name", ""),
+            checkpoint_path=data.get("checkpoint_path", ""),
+            backbone=data.get("backbone", ""),
+            n_nodes=int(data.get("n_nodes", 0)),
+            n_edges=int(data.get("n_edges", 0)),
+            node_names=list(data.get("node_names", [])),
+            edge_inds=edge_inds,
+            input_scale=float(data.get("input_scale", 1.0)),
+            input_channels=int(data.get("input_channels", 1)),
+            output_stride=int(data.get("output_stride", 1)),
+            crop_size=crop_size,
+            max_instances=data.get("max_instances"),
+            max_peaks_per_node=data.get("max_peaks_per_node"),
+            max_batch_size=int(data.get("max_batch_size", 1)),
+            precision=data.get("precision", "fp32"),
+            input_dtype=data.get("input_dtype", "uint8"),
+            normalization=data.get("normalization", "0_to_1"),
+            n_classes=data.get("n_classes"),
+            class_names=data.get("class_names"),
+            training_config_embedded=bool(data.get("training_config_embedded", False)),
+            training_config_hash=data.get("training_config_hash", ""),
+        )
+
+    def save(self, path: str | Path) -> None:
+        """Save to JSON file."""
+        path = Path(path)
+        path.write_text(json.dumps(self.to_dict(), indent=2, sort_keys=True))
+
+    @classmethod
+    def load(cls, path: str | Path) -> "ExportMetadata":
+        """Load from JSON file."""
+        path = Path(path)
+        data = json.loads(path.read_text())
+        return cls.from_dict(data)
+
+    @classmethod
+    def default_timestamp(cls) -> str:
+        """Return an ISO timestamp for export."""
+        return datetime.now().isoformat()
+
+
+def hash_file(path: str | Path) -> str:
+    """Compute SHA256 hash for a file."""
+    path = Path(path)
+    hasher = hashlib.sha256()
+    with path.open("rb") as handle:
+        for chunk in iter(lambda: handle.read(8192), b""):
+            hasher.update(chunk)
+    return hasher.hexdigest()
+
+
+def build_base_metadata(
+    *,
+    export_format: str,
+    model_type: str,
+    model_name: str,
+    checkpoint_path: str,
+    backbone: str,
+    n_nodes: int,
+    n_edges: int,
+    node_names: List[str],
+    edge_inds: List[Tuple[int, int]],
+    input_scale: float,
+    input_channels: int,
+    output_stride: int,
+    crop_size: Optional[Tuple[int, int]] = None,
+    max_instances: Optional[int] = None,
+    max_peaks_per_node: Optional[int] = None,
+    max_batch_size: int = 1,
+    precision: str = "fp32",
+    training_config_hash: str = "",
+    training_config_embedded: bool = False,
+    input_dtype: str = "uint8",
+    normalization: str = "0_to_1",
+    n_classes: Optional[int] = None,
+    class_names: Optional[List[str]] = None,
+) -> ExportMetadata:
+    """Create an ExportMetadata instance with standard defaults."""
+    return ExportMetadata(
+        sleap_nn_version=__version__,
+        export_timestamp=ExportMetadata.default_timestamp(),
+        export_format=export_format,
+        model_type=model_type,
+        model_name=model_name,
+        checkpoint_path=checkpoint_path,
+        backbone=backbone,
+        n_nodes=n_nodes,
+        n_edges=n_edges,
+        node_names=node_names,
+        edge_inds=edge_inds,
+        input_scale=input_scale,
+        input_channels=input_channels,
+        output_stride=output_stride,
+        crop_size=crop_size,
+        max_instances=max_instances,
+        max_peaks_per_node=max_peaks_per_node,
+        max_batch_size=max_batch_size,
+        precision=precision,
+        input_dtype=input_dtype,
+        normalization=normalization,
+        n_classes=n_classes,
+        class_names=class_names,
+        training_config_embedded=training_config_embedded,
+        training_config_hash=training_config_hash,
+    )
+
+
+def embed_metadata_in_onnx(
+    model_path: str | Path,
+    metadata: ExportMetadata,
+    training_config_text: Optional[str] = None,
+) -> None:
+    """Embed metadata into an ONNX model file.
+
+    Raises ImportError if onnx is unavailable.
+    """
+    import onnx  # local import to keep dependency optional
+
+    model_path = Path(model_path)
+    model = onnx.load(model_path.as_posix())
+    model.metadata_props.append(
+        onnx.StringStringEntryProto(
+            key="sleap_nn_metadata", value=json.dumps(metadata.to_dict())
+        )
+    )
+    if training_config_text:
+        model.metadata_props.append(
+            onnx.StringStringEntryProto(
+                key="training_config", value=training_config_text
+            )
+        )
+    onnx.save(model, model_path.as_posix())
+
+
+def extract_metadata_from_onnx(model_path: str | Path) -> ExportMetadata:
+    """Extract metadata from an ONNX model file.
+
+    Raises ValueError if metadata is missing.
+    """
+    import onnx  # local import to keep dependency optional
+
+    model = onnx.load(Path(model_path).as_posix())
+    for prop in model.metadata_props:
+        if prop.key == "sleap_nn_metadata":
+            return ExportMetadata.from_dict(json.loads(prop.value))
+    raise ValueError("No sleap_nn metadata found in ONNX model")
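A short sketch of the metadata round trip defined above, with hypothetical skeleton values (the field names follow `build_base_metadata`); it assumes a model.onnx file such as the one produced in the earlier export sketch:

from sleap_nn.export.metadata import (
    build_base_metadata,
    embed_metadata_in_onnx,
    extract_metadata_from_onnx,
)

# Hypothetical two-node skeleton; required keyword-only fields filled in.
meta = build_base_metadata(
    export_format="onnx",
    model_type="single_instance",
    model_name="toy_model",
    checkpoint_path="models/toy/best.ckpt",
    backbone="unet",
    n_nodes=2,
    n_edges=1,
    node_names=["head", "tail"],
    edge_inds=[(0, 1)],
    input_scale=1.0,
    input_channels=1,
    output_stride=2,
)

embed_metadata_in_onnx("model.onnx", meta)  # write into ONNX metadata_props
print(extract_metadata_from_onnx("model.onnx").node_names)  # round-trip check
meta.save("model.metadata.json")  # sidecar JSON copy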