sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +62 -20
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +35 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/export/cli.py
ADDED
|
@@ -0,0 +1,1778 @@
|
|
|
1
|
+
"""CLI entry points for export workflows."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
import json
|
|
8
|
+
import shutil
|
|
9
|
+
|
|
10
|
+
import click
|
|
11
|
+
from omegaconf import OmegaConf
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from sleap_nn.export.exporters import export_to_onnx, export_to_tensorrt
|
|
15
|
+
from sleap_nn.export.metadata import (
|
|
16
|
+
build_base_metadata,
|
|
17
|
+
embed_metadata_in_onnx,
|
|
18
|
+
hash_file,
|
|
19
|
+
)
|
|
20
|
+
from sleap_nn.export.utils import (
|
|
21
|
+
load_training_config,
|
|
22
|
+
resolve_backbone_type,
|
|
23
|
+
resolve_class_maps_output_stride,
|
|
24
|
+
resolve_class_names,
|
|
25
|
+
resolve_crop_size,
|
|
26
|
+
resolve_edge_inds,
|
|
27
|
+
resolve_input_channels,
|
|
28
|
+
resolve_input_scale,
|
|
29
|
+
resolve_input_shape,
|
|
30
|
+
resolve_model_type,
|
|
31
|
+
resolve_n_classes,
|
|
32
|
+
resolve_node_names,
|
|
33
|
+
resolve_output_stride,
|
|
34
|
+
resolve_pafs_output_stride,
|
|
35
|
+
)
|
|
36
|
+
from sleap_nn.export.wrappers import (
|
|
37
|
+
BottomUpMultiClassONNXWrapper,
|
|
38
|
+
BottomUpONNXWrapper,
|
|
39
|
+
CenteredInstanceONNXWrapper,
|
|
40
|
+
CentroidONNXWrapper,
|
|
41
|
+
SingleInstanceONNXWrapper,
|
|
42
|
+
TopDownMultiClassCombinedONNXWrapper,
|
|
43
|
+
TopDownMultiClassONNXWrapper,
|
|
44
|
+
TopDownONNXWrapper,
|
|
45
|
+
)
|
|
46
|
+
from sleap_nn.training.lightning_modules import (
|
|
47
|
+
BottomUpLightningModule,
|
|
48
|
+
BottomUpMultiClassLightningModule,
|
|
49
|
+
CentroidLightningModule,
|
|
50
|
+
SingleInstanceLightningModule,
|
|
51
|
+
TopDownCenteredInstanceLightningModule,
|
|
52
|
+
TopDownCenteredInstanceMultiClassLightningModule,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Model types that can be exported from a single training run directory.
_SINGLE_EXPORT_MODEL_TYPES = (
    "centroid",
    "centered_instance",
    "bottomup",
    "single_instance",
    "multi_class_topdown",
    "multi_class_bottomup",
)

# Instance-model types that can be chained behind a centroid model into one
# combined top-down export, in the order they are matched.
_TOPDOWN_INSTANCE_TYPES = ("centered_instance", "multi_class_topdown")


def _require_best_checkpoint(model_path: Path) -> Path:
    """Return ``model_path / "best.ckpt"``.

    Raises:
        click.ClickException: If the checkpoint file does not exist.
    """
    ckpt_path = model_path / "best.ckpt"
    if not ckpt_path.exists():
        raise click.ClickException(f"Checkpoint not found: {ckpt_path}")
    return ckpt_path


def _load_export_model(model_type, backbone_type, cfg, ckpt_path: Path, device: str):
    """Load a checkpoint and return its ``.model`` in eval mode on ``device``."""
    torch_model = _load_lightning_model(
        model_type=model_type,
        backbone_type=backbone_type,
        cfg=cfg,
        ckpt_path=ckpt_path,
        device=device,
    ).model
    torch_model.eval()
    torch_model.to(device)
    return torch_model


def _embed_onnx_metadata(model_out_path: Path, metadata, training_config_text) -> None:
    """Best-effort embedding of metadata + training config(s) into the ONNX file.

    Does nothing when no training config text is available. If the embedding
    raises ``ImportError`` (an optional dependency is missing), a warning is
    written to stderr instead of failing the whole export.
    """
    if training_config_text is None:
        return
    try:
        embed_metadata_in_onnx(model_out_path, metadata, training_config_text)
    except ImportError as err:
        click.echo(
            f"Warning: could not embed metadata in {model_out_path}: {err}",
            err=True,
        )


def _export_single_model(
    model_path: Path,
    cfg,
    model_type: str,
    backbone_type: str,
    *,
    output: Optional[Path],
    fmt: str,
    opset_version: int,
    max_instances: int,
    max_batch_size: int,
    input_scale: Optional[float],
    input_height: Optional[int],
    input_width: Optional[int],
    crop_size: Optional[int],
    max_peaks_per_node: int,
    n_line_points: int,
    max_edge_length_ratio: float,
    dist_penalty_weight: float,
    device: str,
    precision: str,
    verify: bool,
) -> None:
    """Export one trained model (one run directory) to ONNX and optionally TensorRT.

    Writes ``model.onnx`` and ``export_metadata.json`` into the export directory
    (``output`` or ``<model_path>/exported``). When ``fmt`` is ``"tensorrt"`` or
    ``"both"`` it also writes ``model.trt`` and ``model.trt.metadata.json``.

    Raises:
        click.ClickException: For unsupported model types or a missing checkpoint.
    """
    if model_type not in _SINGLE_EXPORT_MODEL_TYPES:
        raise click.ClickException(
            f"Model type '{model_type}' is not supported for export yet."
        )

    ckpt_path = _require_best_checkpoint(model_path)
    torch_model = _load_export_model(model_type, backbone_type, cfg, ckpt_path, device)

    export_dir = output or (model_path / "exported")
    export_dir.mkdir(parents=True, exist_ok=True)

    resolved_scale = (
        input_scale if input_scale is not None else resolve_input_scale(cfg)
    )
    output_stride = resolve_output_stride(cfg, model_type)
    resolved_crop_size = (
        (crop_size, crop_size) if crop_size is not None else resolve_crop_size(cfg)
    )
    node_names = resolve_node_names(cfg, model_type)
    edge_inds = resolve_edge_inds(cfg, node_names)

    # Only some model types record these fields in the metadata; the rest keep None.
    metadata_max_instances = None
    metadata_max_peaks = None
    metadata_n_classes = None
    metadata_class_names = None

    if model_type == "centroid":
        wrapper = CentroidONNXWrapper(
            torch_model,
            max_instances=max_instances,
            output_stride=output_stride,
            input_scale=resolved_scale,
        )
        output_names = ["centroids", "centroid_vals", "instance_valid"]
        metadata_max_instances = max_instances
    elif model_type == "centered_instance":
        wrapper = CenteredInstanceONNXWrapper(
            torch_model,
            output_stride=output_stride,
            input_scale=resolved_scale,
        )
        output_names = ["peaks", "peak_vals"]
    elif model_type == "single_instance":
        wrapper = SingleInstanceONNXWrapper(
            torch_model,
            output_stride=output_stride,
            input_scale=resolved_scale,
        )
        output_names = ["peaks", "peak_vals"]
    elif model_type == "bottomup":
        wrapper = BottomUpONNXWrapper(
            torch_model,
            skeleton_edges=edge_inds,
            n_nodes=len(node_names),
            max_peaks_per_node=max_peaks_per_node,
            n_line_points=n_line_points,
            cms_output_stride=output_stride,
            pafs_output_stride=resolve_pafs_output_stride(cfg),
            max_edge_length_ratio=max_edge_length_ratio,
            dist_penalty_weight=dist_penalty_weight,
            input_scale=resolved_scale,
        )
        output_names = [
            "peaks",
            "peak_vals",
            "peak_mask",
            "line_scores",
            "candidate_mask",
        ]
        metadata_max_peaks = max_peaks_per_node
    elif model_type == "multi_class_topdown":
        metadata_n_classes = resolve_n_classes(cfg, model_type)
        metadata_class_names = resolve_class_names(cfg, model_type)
        wrapper = TopDownMultiClassONNXWrapper(
            torch_model,
            output_stride=output_stride,
            input_scale=resolved_scale,
            n_classes=metadata_n_classes,
        )
        output_names = ["peaks", "peak_vals", "class_logits"]
    else:  # "multi_class_bottomup" -- guaranteed by the membership check above.
        metadata_n_classes = resolve_n_classes(cfg, model_type)
        metadata_class_names = resolve_class_names(cfg, model_type)
        wrapper = BottomUpMultiClassONNXWrapper(
            torch_model,
            n_nodes=len(node_names),
            n_classes=metadata_n_classes,
            max_peaks_per_node=max_peaks_per_node,
            cms_output_stride=output_stride,
            class_maps_output_stride=resolve_class_maps_output_stride(cfg),
            input_scale=resolved_scale,
        )
        output_names = ["peaks", "peak_vals", "peak_mask", "class_probs"]
        metadata_max_peaks = max_peaks_per_node

    wrapper.eval()
    wrapper.to(device)

    input_shape = resolve_input_shape(
        cfg, input_height=input_height, input_width=input_width
    )
    model_out_path = export_dir / "model.onnx"
    export_to_onnx(
        wrapper,
        model_out_path,
        input_shape=input_shape,
        input_dtype=torch.uint8,
        opset_version=opset_version,
        output_names=output_names,
        verify=verify,
    )

    training_config_path = _copy_training_config(model_path, export_dir, None)
    if training_config_path is not None:
        training_config_hash = hash_file(training_config_path)
        training_config_text = training_config_path.read_text()
    else:
        training_config_hash = ""
        training_config_text = None

    # Fields shared by the ONNX and TensorRT metadata files; only the export
    # format and precision differ between the two.
    metadata_kwargs = dict(
        model_type=model_type,
        model_name=model_path.name,
        checkpoint_path=str(ckpt_path),
        backbone=backbone_type,
        n_nodes=len(node_names),
        n_edges=len(edge_inds),
        node_names=node_names,
        edge_inds=edge_inds,
        input_scale=resolved_scale,
        input_channels=resolve_input_channels(cfg),
        output_stride=output_stride,
        crop_size=resolved_crop_size,
        max_instances=metadata_max_instances,
        max_peaks_per_node=metadata_max_peaks,
        max_batch_size=max_batch_size,
        training_config_hash=training_config_hash,
        training_config_embedded=training_config_text is not None,
        input_dtype="uint8",
        normalization="0_to_1",
        n_classes=metadata_n_classes,
        class_names=metadata_class_names,
    )

    metadata_path = export_dir / "export_metadata.json"
    metadata = build_base_metadata(
        export_format="onnx", precision="fp32", **metadata_kwargs
    )
    metadata.save(metadata_path)
    _embed_onnx_metadata(model_out_path, metadata, training_config_text)
    click.echo(f"ONNX model exported to: {model_out_path}")
    click.echo(f"Metadata saved to: {metadata_path}")

    if fmt not in ("tensorrt", "both"):
        return

    _, channels, height, width = input_shape
    trt_input_shape = input_shape
    trt_min_shape = None
    trt_opt_shape = None
    trt_max_shape = (max_batch_size, channels, height * 2, width * 2)
    # Centered-instance and single-instance inference runs on cropped inputs,
    # so build the TensorRT shape profile around the crop size when known.
    if (
        model_type in ("centered_instance", "single_instance")
        and resolved_crop_size is not None
    ):
        crop_h, crop_w = resolved_crop_size
        trt_input_shape = (1, channels, crop_h, crop_w)
        trt_min_shape = trt_input_shape
        trt_opt_shape = trt_input_shape
        trt_max_shape = (max_batch_size, channels, crop_h * 2, crop_w * 2)

    export_to_tensorrt(
        wrapper,
        export_dir / "model.trt",
        input_shape=trt_input_shape,
        input_dtype=torch.uint8,
        precision=precision,
        min_shape=trt_min_shape,
        opt_shape=trt_opt_shape,
        max_shape=trt_max_shape,
        verbose=True,
    )
    build_base_metadata(
        export_format="tensorrt", precision=precision, **metadata_kwargs
    ).save(export_dir / "model.trt.metadata.json")


def _export_topdown_pair(
    model_paths,
    cfgs,
    model_types,
    backbone_types,
    *,
    instance_type: str,
    output: Optional[Path],
    fmt: str,
    opset_version: int,
    max_instances: int,
    max_batch_size: int,
    input_scale: Optional[float],
    input_height: Optional[int],
    input_width: Optional[int],
    crop_size: Optional[int],
    device: str,
    precision: str,
    verify: bool,
) -> None:
    """Export a centroid model chained with an instance model as one graph.

    Args:
        instance_type: ``"centered_instance"`` for plain top-down export, or
            ``"multi_class_topdown"`` for top-down export with class logits.

    Both models' training configs are copied into the export directory, hashed,
    and (as a JSON object keyed by model type) embedded into the ONNX file.

    Raises:
        click.ClickException: If a checkpoint is missing or no crop size is known.
    """
    multiclass = instance_type == "multi_class_topdown"
    if multiclass:
        combined_model_type = "multi_class_topdown_combined"
        default_dir_name = "exported_multi_class_topdown"
        crop_error = (
            "Multiclass top-down export requires crop_size. Provide --crop-size or "
            "ensure data_config.preprocessing.crop_size is set."
        )
    else:
        combined_model_type = "topdown"
        default_dir_name = "exported_topdown"
        crop_error = (
            "Top-down export requires crop_size. Provide --crop-size or ensure "
            "data_config.preprocessing.crop_size is set."
        )

    centroid_idx = model_types.index("centroid")
    instance_idx = model_types.index(instance_type)
    centroid_path = model_paths[centroid_idx]
    instance_path = model_paths[instance_idx]
    centroid_cfg = cfgs[centroid_idx]
    instance_cfg = cfgs[instance_idx]
    centroid_backbone = backbone_types[centroid_idx]
    instance_backbone = backbone_types[instance_idx]

    # Validate both checkpoints before loading either model.
    centroid_ckpt = _require_best_checkpoint(centroid_path)
    instance_ckpt = _require_best_checkpoint(instance_path)

    centroid_model = _load_export_model(
        "centroid", centroid_backbone, centroid_cfg, centroid_ckpt, device
    )
    instance_model = _load_export_model(
        instance_type, instance_backbone, instance_cfg, instance_ckpt, device
    )

    export_dir = output or (centroid_path / default_dir_name)
    export_dir.mkdir(parents=True, exist_ok=True)

    centroid_scale = (
        input_scale if input_scale is not None else resolve_input_scale(centroid_cfg)
    )
    instance_scale = (
        input_scale if input_scale is not None else resolve_input_scale(instance_cfg)
    )
    centroid_stride = resolve_output_stride(centroid_cfg, "centroid")
    instance_stride = resolve_output_stride(instance_cfg, instance_type)

    # An explicit --crop-size wins; only fall back to the config otherwise.
    resolved_crop = (
        (crop_size, crop_size)
        if crop_size is not None
        else resolve_crop_size(instance_cfg)
    )
    if resolved_crop is None:
        raise click.ClickException(crop_error)

    node_names = resolve_node_names(instance_cfg, instance_type)
    edge_inds = resolve_edge_inds(instance_cfg, node_names)

    wrapper_kwargs = dict(
        centroid_model=centroid_model,
        instance_model=instance_model,
        max_instances=max_instances,
        crop_size=resolved_crop,
        centroid_output_stride=centroid_stride,
        instance_output_stride=instance_stride,
        centroid_input_scale=centroid_scale,
        instance_input_scale=instance_scale,
        n_nodes=len(node_names),
    )
    class_kwargs = {}
    if multiclass:
        n_classes = resolve_n_classes(instance_cfg, instance_type)
        class_names = resolve_class_names(instance_cfg, instance_type)
        class_kwargs = dict(n_classes=n_classes, class_names=class_names)
        wrapper = TopDownMultiClassCombinedONNXWrapper(
            **wrapper_kwargs, n_classes=n_classes
        )
        output_names = [
            "centroids",
            "centroid_vals",
            "peaks",
            "peak_vals",
            "class_logits",
            "instance_valid",
        ]
    else:
        wrapper = TopDownONNXWrapper(**wrapper_kwargs)
        output_names = [
            "centroids",
            "centroid_vals",
            "peaks",
            "peak_vals",
            "instance_valid",
        ]
    wrapper.eval()
    wrapper.to(device)

    # The combined graph takes full frames, so the centroid config sets the shape.
    input_shape = resolve_input_shape(
        centroid_cfg, input_height=input_height, input_width=input_width
    )
    model_out_path = export_dir / "model.onnx"
    export_to_onnx(
        wrapper,
        model_out_path,
        input_shape=input_shape,
        input_dtype=torch.uint8,
        opset_version=opset_version,
        output_names=output_names,
        verify=verify,
    )

    # Copy both configs first, then hash/read whichever copies succeeded.
    copied_configs = (
        ("centroid", _copy_training_config(centroid_path, export_dir, "centroid")),
        (
            instance_type,
            _copy_training_config(instance_path, export_dir, instance_type),
        ),
    )
    config_payload = {}
    config_hashes = []
    for key, cfg_path in copied_configs:
        if cfg_path is not None:
            config_payload[key] = cfg_path.read_text()
            config_hashes.append(f"{key}:{hash_file(cfg_path)}")
    training_config_hash = ";".join(config_hashes)
    training_config_text = json.dumps(config_payload) if config_payload else None

    # Fields shared by the ONNX and TensorRT metadata files.
    metadata_kwargs = dict(
        model_type=combined_model_type,
        model_name=f"{centroid_path.name}+{instance_path.name}",
        checkpoint_path=(
            f"centroid:{centroid_ckpt};{instance_type}:{instance_ckpt}"
        ),
        backbone=(
            f"centroid:{centroid_backbone};{instance_type}:{instance_backbone}"
        ),
        n_nodes=len(node_names),
        n_edges=len(edge_inds),
        node_names=node_names,
        edge_inds=edge_inds,
        input_scale=centroid_scale,
        input_channels=resolve_input_channels(centroid_cfg),
        output_stride=instance_stride,
        crop_size=resolved_crop,
        max_instances=max_instances,
        max_batch_size=max_batch_size,
        training_config_hash=training_config_hash,
        training_config_embedded=training_config_text is not None,
        input_dtype="uint8",
        normalization="0_to_1",
        **class_kwargs,
    )

    metadata_path = export_dir / "export_metadata.json"
    metadata = build_base_metadata(
        export_format="onnx", precision="fp32", **metadata_kwargs
    )
    metadata.save(metadata_path)
    # Previously skipped for the multiclass variant even though the metadata
    # reported training_config_embedded=True.
    _embed_onnx_metadata(model_out_path, metadata, training_config_text)
    click.echo(f"ONNX model exported to: {model_out_path}")
    click.echo(f"Metadata saved to: {metadata_path}")

    if fmt not in ("tensorrt", "both"):
        return

    _, channels, height, width = input_shape
    export_to_tensorrt(
        wrapper,
        export_dir / "model.trt",
        input_shape=input_shape,
        input_dtype=torch.uint8,
        precision=precision,
        max_shape=(max_batch_size, channels, height * 2, width * 2),
        verbose=True,
    )
    build_base_metadata(
        export_format="tensorrt", precision=precision, **metadata_kwargs
    ).save(export_dir / "model.trt.metadata.json")


@click.command()
@click.argument(
    "model_paths",
    nargs=-1,
    type=click.Path(exists=True, file_okay=False, path_type=Path),
)
@click.option(
    "--output",
    "-o",
    type=click.Path(file_okay=False, path_type=Path),
    default=None,
    help="Output directory for exported model files.",
)
@click.option(
    "--format",
    "-f",
    "fmt",
    type=click.Choice(["onnx", "tensorrt", "both"], case_sensitive=False),
    default="onnx",
    show_default=True,
)
@click.option("--opset-version", type=int, default=17, show_default=True)
@click.option("--max-instances", type=int, default=20, show_default=True)
@click.option("--max-batch-size", type=int, default=8, show_default=True)
@click.option("--input-scale", type=float, default=None)
@click.option("--input-height", type=int, default=None)
@click.option("--input-width", type=int, default=None)
@click.option("--crop-size", type=int, default=None)
@click.option("--max-peaks-per-node", type=int, default=20, show_default=True)
@click.option("--n-line-points", type=int, default=10, show_default=True)
@click.option("--max-edge-length-ratio", type=float, default=0.25, show_default=True)
@click.option("--dist-penalty-weight", type=float, default=1.0, show_default=True)
@click.option("--device", type=str, default="cpu", show_default=True)
@click.option(
    "--precision",
    type=click.Choice(["fp32", "fp16"], case_sensitive=False),
    default="fp16",
    show_default=True,
    help="TensorRT precision mode.",
)
@click.option("--verify/--no-verify", default=True, show_default=True)
def export(
    model_paths: tuple[Path, ...],
    output: Optional[Path],
    fmt: str,
    opset_version: int,
    max_instances: int,
    max_batch_size: int,
    input_scale: Optional[float],
    input_height: Optional[int],
    input_width: Optional[int],
    crop_size: Optional[int],
    max_peaks_per_node: int,
    n_line_points: int,
    max_edge_length_ratio: float,
    dist_penalty_weight: float,
    device: str,
    precision: str,
    verify: bool,
) -> None:
    """Export trained models to ONNX/TensorRT formats.

    Pass one model directory to export a single model, or two directories
    (centroid + centered_instance, or centroid + multi_class_topdown) to
    export a combined top-down pipeline. ONNX is always written; TensorRT is
    added when --format is tensorrt or both.
    """
    fmt = fmt.lower()

    if not model_paths:
        raise click.ClickException("Provide at least one model path to export.")

    model_paths = list(model_paths)
    cfgs = [load_training_config(path) for path in model_paths]
    model_types = [resolve_model_type(cfg) for cfg in cfgs]
    backbone_types = [resolve_backbone_type(cfg) for cfg in cfgs]

    shared_kwargs = dict(
        output=output,
        fmt=fmt,
        opset_version=opset_version,
        max_instances=max_instances,
        max_batch_size=max_batch_size,
        input_scale=input_scale,
        input_height=input_height,
        input_width=input_width,
        crop_size=crop_size,
        device=device,
        precision=precision,
        verify=verify,
    )

    if len(model_paths) == 1:
        _export_single_model(
            model_paths[0],
            cfgs[0],
            model_types[0],
            backbone_types[0],
            max_peaks_per_node=max_peaks_per_node,
            n_line_points=n_line_points,
            max_edge_length_ratio=max_edge_length_ratio,
            dist_penalty_weight=dist_penalty_weight,
            **shared_kwargs,
        )
        return

    if len(model_paths) == 2:
        for instance_type in _TOPDOWN_INSTANCE_TYPES:
            if set(model_types) == {"centroid", instance_type}:
                _export_topdown_pair(
                    model_paths,
                    cfgs,
                    model_types,
                    backbone_types,
                    instance_type=instance_type,
                    **shared_kwargs,
                )
                return

    raise click.ClickException(
        "Provide one model path for single-model export (centroid, "
        "centered_instance, bottomup, single_instance, multi_class_topdown or "
        "multi_class_bottomup), "
        "or two paths (centroid + centered_instance or centroid + multi_class_topdown) "
        "for combined top-down export."
    )
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
@click.command()
@click.argument(
    "export_dir",
    type=click.Path(exists=True, file_okay=False, path_type=Path),
)
@click.argument(
    "video_path",
    type=click.Path(exists=True, dir_okay=False, path_type=Path),
)
@click.option(
    "--output",
    "-o",
    type=click.Path(dir_okay=False, path_type=Path),
    default=None,
    help="Output SLP file path. Default: video_name.predictions.slp",
)
@click.option(
    "--runtime",
    type=click.Choice(["auto", "onnx", "tensorrt"], case_sensitive=False),
    default="auto",
    show_default=True,
    help="Runtime to use for inference.",
)
@click.option("--device", type=str, default="auto", show_default=True)
@click.option("--batch-size", type=int, default=4, show_default=True)
@click.option("--n-frames", type=int, default=None, help="Limit to first N frames.")
@click.option(
    "--max-edge-length-ratio",
    type=float,
    default=0.25,
    show_default=True,
    help="Bottom-up: max edge length as ratio of PAF dimensions.",
)
@click.option(
    "--dist-penalty-weight",
    type=float,
    default=1.0,
    show_default=True,
    help="Bottom-up: weight for distance penalty in PAF scoring.",
)
@click.option(
    "--n-points",
    type=int,
    default=10,
    show_default=True,
    help="Bottom-up: number of points to sample along PAF.",
)
@click.option(
    "--min-instance-peaks",
    type=float,
    default=0,
    show_default=True,
    help="Bottom-up: minimum peaks required per instance.",
)
@click.option(
    "--min-line-scores",
    type=float,
    default=-0.5,
    show_default=True,
    help="Bottom-up: minimum line score threshold.",
)
@click.option(
    "--peak-conf-threshold",
    type=float,
    default=0.1,
    show_default=True,
    help="Bottom-up: peak confidence threshold for filtering candidates.",
)
@click.option(
    "--max-instances",
    type=int,
    default=None,
    help="Maximum instances to output per frame.",
)
def predict(
    export_dir: Path,
    video_path: Path,
    output: Optional[Path],
    runtime: str,
    device: str,
    batch_size: int,
    n_frames: Optional[int],
    max_edge_length_ratio: float,
    dist_penalty_weight: float,
    n_points: int,
    min_instance_peaks: float,
    min_line_scores: float,
    peak_conf_threshold: float,
    max_instances: Optional[int],
) -> None:
    """Run inference on exported models and save predictions to SLP.

    EXPORT_DIR is the directory containing the exported model (model.onnx or model.trt)
    along with export_metadata.json and training_config.yaml.

    VIDEO_PATH is the path to the video file to process.
    """
    # NOTE: this docstring doubles as the click --help text; implementation
    # notes live in comments instead.
    import time
    from datetime import datetime

    import sleap_io as sio

    from sleap_nn.export.metadata import ExportMetadata
    from sleap_nn.export.predictors import load_exported_model
    from sleap_nn.export.utils import build_bottomup_candidate_template
    from sleap_nn.inference.paf_grouping import PAFScorer
    from sleap_nn.inference.utils import get_skeleton_from_config

    # The batching loop uses batch_size as a range() step. A step of 0 raises a
    # bare ValueError; a negative step silently processes no frames.
    if batch_size < 1:
        raise click.ClickException(f"--batch-size must be >= 1, got {batch_size}.")

    # Load metadata
    metadata_path = export_dir / "export_metadata.json"
    if not metadata_path.exists():
        raise click.ClickException(f"Metadata not found: {metadata_path}")
    metadata = ExportMetadata.load(metadata_path)

    # Find model file ("auto" prefers a TensorRT engine over ONNX).
    onnx_path = export_dir / "model.onnx"
    trt_path = export_dir / "model.trt"

    if runtime == "auto":
        if trt_path.exists():
            model_path = trt_path
            runtime = "tensorrt"
        elif onnx_path.exists():
            model_path = onnx_path
            runtime = "onnx"
        else:
            raise click.ClickException(
                f"No model found in {export_dir}. Expected model.onnx or model.trt."
            )
    elif runtime == "onnx":
        if not onnx_path.exists():
            raise click.ClickException(f"ONNX model not found: {onnx_path}")
        model_path = onnx_path
    elif runtime == "tensorrt":
        if not trt_path.exists():
            raise click.ClickException(f"TensorRT model not found: {trt_path}")
        model_path = trt_path
    else:
        raise click.ClickException(f"Unknown runtime: {runtime}")

    # Load training config for skeleton
    cfg_path = _find_training_config_for_predict(export_dir, metadata.model_type)
    if cfg_path.suffix in {".yaml", ".yml"}:
        cfg = OmegaConf.load(cfg_path.as_posix())
    else:
        from sleap_nn.config.training_job_config import TrainingJobConfig

        cfg = TrainingJobConfig.load_sleap_config(cfg_path.as_posix())
    skeletons = get_skeleton_from_config(cfg.data_config.skeletons)
    skeleton = skeletons[0]

    # Load video
    video = sio.Video.from_filename(video_path.as_posix())
    total_frames = len(video) if n_frames is None else min(n_frames, len(video))
    frame_indices = list(range(total_frames))

    click.echo(f"Loading model from: {model_path}")
    click.echo(f"  Model type: {metadata.model_type}")
    click.echo(f"  Runtime: {runtime}")
    click.echo(f"  Device: {device}")

    predictor = load_exported_model(
        model_path.as_posix(), runtime=runtime, device=device
    )

    click.echo(f"Processing video: {video_path}")
    click.echo(f"  Total frames: {total_frames}")
    click.echo(f"  Batch size: {batch_size}")

    # Set up centroid anchor node if needed
    anchor_node_idx = None
    if metadata.model_type == "centroid":
        anchor_part = cfg.model_config.head_configs.centroid.confmaps.anchor_part
        node_names = [n.name for n in skeleton.nodes]
        if anchor_part in node_names:
            anchor_node_idx = node_names.index(anchor_part)
        else:
            raise click.ClickException(
                f"Anchor part '{anchor_part}' not found in skeleton nodes: {node_names}"
            )

    # Set up bottom-up post-processing if needed
    paf_scorer = None
    candidate_template = None
    if metadata.model_type == "bottomup":
        paf_scorer = PAFScorer.from_config(
            cfg.model_config.head_configs.bottomup,
            max_edge_length_ratio=max_edge_length_ratio,
            dist_penalty_weight=dist_penalty_weight,
            n_points=n_points,
            min_instance_peaks=min_instance_peaks,
            min_line_scores=min_line_scores,
        )
        max_peaks = metadata.max_peaks_per_node
        if max_peaks is None:
            raise click.ClickException(
                "Bottom-up export metadata missing max_peaks_per_node."
            )
        edge_inds_tuples = [(int(e[0]), int(e[1])) for e in paf_scorer.edge_inds]
        peak_channel_inds, edge_inds_tensor, edge_peak_inds = (
            build_bottomup_candidate_template(
                n_nodes=metadata.n_nodes,
                max_peaks_per_node=max_peaks,
                edge_inds=edge_inds_tuples,
            )
        )
        candidate_template = {
            "peak_channel_inds": peak_channel_inds,
            "edge_inds": edge_inds_tensor,
            "edge_peak_inds": edge_peak_inds,
        }

    labeled_frames = []
    total_start = time.perf_counter()
    infer_time = 0.0
    post_time = 0.0

    for start in range(0, len(frame_indices), batch_size):
        batch_indices = frame_indices[start : start + batch_size]
        batch = _load_video_batch(video, batch_indices)

        infer_start = time.perf_counter()
        outputs = predictor.predict(batch)
        infer_time += time.perf_counter() - infer_start

        post_start = time.perf_counter()
        if metadata.model_type == "topdown":
            labeled_frames.extend(
                _predict_topdown_frames(
                    outputs,
                    batch_indices,
                    video,
                    skeleton,
                    max_instances=max_instances,
                )
            )
        elif metadata.model_type == "bottomup":
            labeled_frames.extend(
                _predict_bottomup_frames(
                    outputs,
                    batch_indices,
                    video,
                    skeleton,
                    paf_scorer,
                    candidate_template,
                    input_scale=metadata.input_scale,
                    peak_conf_threshold=peak_conf_threshold,
                    max_instances=max_instances,
                )
            )
        elif metadata.model_type == "single_instance":
            labeled_frames.extend(
                _predict_single_instance_frames(
                    outputs,
                    batch_indices,
                    video,
                    skeleton,
                )
            )
        elif metadata.model_type == "centroid":
            labeled_frames.extend(
                _predict_centroid_frames(
                    outputs,
                    batch_indices,
                    video,
                    skeleton,
                    anchor_node_idx=anchor_node_idx,
                    max_instances=max_instances,
                )
            )
        elif metadata.model_type == "multi_class_bottomup":
            labeled_frames.extend(
                _predict_multiclass_bottomup_frames(
                    outputs,
                    batch_indices,
                    video,
                    skeleton,
                    class_names=metadata.class_names or [],
                    input_scale=metadata.input_scale,
                    peak_conf_threshold=peak_conf_threshold,
                    max_instances=max_instances,
                )
            )
        elif metadata.model_type == "multi_class_topdown_combined":
            labeled_frames.extend(
                _predict_multiclass_topdown_combined_frames(
                    outputs,
                    batch_indices,
                    video,
                    skeleton,
                    class_names=metadata.class_names or [],
                    max_instances=max_instances,
                )
            )
        else:
            raise click.ClickException(
                f"Unsupported model_type for predict: {metadata.model_type}"
            )
        post_time += time.perf_counter() - post_start

        # Progress update
        processed = min(start + batch_size, len(frame_indices))
        click.echo(
            f"\r  Processed {processed}/{len(frame_indices)} frames...",
            nl=False,
        )

    click.echo()  # Newline after progress

    total_time = time.perf_counter() - total_start
    fps = len(frame_indices) / total_time if total_time > 0 else 0

    # The multi-class converters can attach a separate ``sio.Track`` object to
    # every instance. sleap-io ``Track`` compares by identity, so without this
    # step each instance would be saved as its own track. Collapse tracks that
    # share a name into one object per identity.
    _unify_tracks_by_name(labeled_frames)

    # Save predictions
    output_path = output or video_path.with_suffix(".predictions.slp")
    labels = sio.Labels(
        videos=[video],
        skeletons=[skeleton],
        labeled_frames=labeled_frames,
    )
    labels.provenance = {
        "sleap_nn_version": metadata.sleap_nn_version,
        "export_format": runtime,
        "model_type": metadata.model_type,
        "inference_timestamp": datetime.now().isoformat(),
    }
    sio.save_file(labels, output_path.as_posix())

    click.echo("\nInference complete:")
    click.echo(f"  Total time: {total_time:.2f}s")
    click.echo(f"  Inference time: {infer_time:.2f}s")
    click.echo(f"  Post-processing time: {post_time:.2f}s")
    click.echo(f"  FPS: {fps:.2f}")
    click.echo(f"  Frames with predictions: {len(labeled_frames)}")
    click.echo(f"  Output saved to: {output_path}")


def _unify_tracks_by_name(labeled_frames) -> None:
    """Make every instance whose track has a given name share one Track object.

    The first track object seen for each name becomes the canonical one, and
    ``inst.track`` is reassigned in place on later instances. Instances
    without a track are left untouched.

    Args:
        labeled_frames: Iterable of LabeledFrame-like objects exposing
            ``.instances``, each with a ``.track`` attribute (None or an
            object with ``.name``).
    """
    canonical = {}
    for lf in labeled_frames:
        for inst in lf.instances:
            if inst.track is not None:
                inst.track = canonical.setdefault(inst.track.name, inst.track)
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
def _find_training_config_for_predict(export_dir: Path, model_type: str) -> Path:
    """Return the training config file that ``predict`` should load.

    File stems are tried in priority order, and for each stem ``.yaml`` is
    checked before ``.json``:

    1. A model-type-specific stem, if there is one:
       ``training_config_centered_instance`` for ``"topdown"`` and
       ``training_config_multi_class_topdown`` for
       ``"multi_class_topdown_combined"``.
    2. The generic ``training_config``.
    3. ``training_config_<model_type>``.

    Args:
        export_dir: Directory produced by the export command.
        model_type: ``model_type`` recorded in the export metadata.

    Returns:
        Path of the first candidate that exists.

    Raises:
        click.ClickException: If no candidate file exists.
    """
    preferred_stem = {
        "topdown": "training_config_centered_instance",
        "multi_class_topdown_combined": "training_config_multi_class_topdown",
    }.get(model_type)

    stems = [] if preferred_stem is None else [preferred_stem]
    stems += ["training_config", f"training_config_{model_type}"]

    for stem in stems:
        for extension in (".yaml", ".json"):
            path = export_dir / f"{stem}{extension}"
            if path.exists():
                return path

    raise click.ClickException(
        f"No training_config found in {export_dir} for model_type={model_type}."
    )
|
|
1167
|
+
|
|
1168
|
+
|
|
1169
|
+
def _load_video_batch(video, frame_indices):
    """Load a batch of video frames as a uint8 NCHW array.

    Args:
        video: Indexable frame source; ``video[idx]`` must return an (H, W) or
            (H, W, C) array-like.
        frame_indices: Frame indices to read, in output order.

    Returns:
        ``np.ndarray`` of shape (len(frame_indices), C, H, W) and dtype uint8.
        2-D (grayscale) frames get a singleton channel axis.
    """
    import numpy as np

    def _to_chw_uint8(frame):
        """Convert one (H, W) / (H, W, C) frame to a uint8 (C, H, W) array."""
        frame = np.asarray(frame)
        if frame.ndim == 2:
            frame = frame[:, :, None]
        if frame.dtype != np.uint8:
            # A direct ``astype(np.uint8)`` wraps out-of-range integers
            # (e.g. 256 -> 0) and is undefined for out-of-range floats, so
            # saturate to [0, 255] first. float64 holds every in-range value
            # exactly, so those values truncate exactly as before.
            frame = np.clip(frame.astype(np.float64, copy=False), 0, 255).astype(
                np.uint8
            )
        return np.transpose(frame, (2, 0, 1))  # HWC -> CHW

    return np.stack([_to_chw_uint8(video[idx]) for idx in frame_indices], axis=0)
|
|
1183
|
+
|
|
1184
|
+
|
|
1185
|
+
def _predict_topdown_frames(
    outputs,
    frame_indices,
    video,
    skeleton,
    max_instances=None,
):
    """Build LabeledFrames from the outputs of an exported top-down model.

    Args:
        outputs: Mapping holding "centroids", "centroid_vals", "peaks",
            "peak_vals" and "instance_valid", each indexed as
            [batch, instance, ...].
        frame_indices: Video frame index for each batch row.
        video: sleap_io.Video the frames belong to.
        skeleton: sleap_io.Skeleton for the predicted instances.
        max_instances: If not None, keep only this many instances per frame,
            ranked by centroid score (highest first).

    Returns:
        List of sio.LabeledFrame; frames with no kept instances are omitted.
    """
    import sleap_io as sio

    # "centroids" is not used in the conversion, but it is still read so that
    # a missing key fails with the same KeyError as before.
    _centroids, centroid_vals, peaks, peak_vals, instance_valid = (
        outputs[key]
        for key in ("centroids", "centroid_vals", "peaks", "peak_vals", "instance_valid")
    )

    frames = []
    for row, frame_idx in enumerate(frame_indices):
        keep = instance_valid[row].astype(bool)
        instances = [
            sio.PredictedInstance.from_numpy(
                points_data=peaks[row, slot],
                point_scores=peak_vals[row, slot],
                score=float(centroid_vals[row, slot]),
                skeleton=skeleton,
            )
            for slot, ok in enumerate(keep)
            if ok
        ]

        if instances and max_instances is not None:
            ranked = sorted(instances, key=lambda inst: inst.score, reverse=True)
            instances = ranked[:max_instances]

        if not instances:
            continue

        frames.append(
            sio.LabeledFrame(
                video=video,
                frame_idx=int(frame_idx),
                instances=instances,
            )
        )

    return frames
|
|
1234
|
+
|
|
1235
|
+
|
|
1236
|
+
def _predict_multiclass_topdown_combined_frames(
    outputs,
    frame_indices,
    video,
    skeleton,
    class_names: list,
    max_instances=None,
    tracks=None,
):
    """Convert combined multiclass top-down model outputs to LabeledFrames.

    Within each frame, the valid instances are assigned distinct classes with
    Hungarian matching on the softmax of their class logits. The assigned
    class names the instance's track.

    Args:
        outputs: Model outputs with centroid_vals, peaks, peak_vals,
            class_logits, instance_valid (indexed as [batch, instance, ...]).
        frame_indices: Frame indices corresponding to batch.
        video: sleap_io.Video object.
        skeleton: sleap_io.Skeleton object.
        class_names: List of class names (e.g., ["female", "male"]). Class
            indices beyond this list are named ``class_<idx>``.
        max_instances: If not None, keep only this many instances per frame,
            ranked by centroid score. Independently of this, Hungarian
            matching assigns at most one instance per class.
        tracks: Optional dict mapping track name -> ``sio.Track``. It serves as
            a cache and is updated in place. Passing the same dict for every
            batch therefore yields one shared Track object per identity. When
            None, a fresh dict is used for this call only.

    Returns:
        List of LabeledFrame objects (frames without instances are omitted).
    """
    import numpy as np
    import sleap_io as sio
    from scipy.optimize import linear_sum_assignment

    if tracks is None:
        tracks = {}

    def get_track(name):
        # Reuse one Track per name. sleap-io compares tracks by identity, so
        # creating a new Track per instance would split an identity into many
        # tracks.
        track = tracks.get(name)
        if track is None:
            track = sio.Track(name=name)
            tracks[name] = track
        return track

    centroid_vals = outputs["centroid_vals"]
    peaks = outputs["peaks"]
    peak_vals = outputs["peak_vals"]
    class_logits = outputs["class_logits"]
    instance_valid = outputs["instance_valid"]

    labeled_frames = []
    for batch_idx, frame_idx in enumerate(frame_indices):
        valid_mask = instance_valid[batch_idx].astype(bool)
        if not valid_mask.any():
            continue

        # Gather valid instances
        valid_peaks = peaks[batch_idx, valid_mask]  # (n_valid, n_nodes, 2)
        valid_peak_vals = peak_vals[batch_idx, valid_mask]  # (n_valid, n_nodes)
        valid_centroid_vals = centroid_vals[batch_idx, valid_mask]  # (n_valid,)
        valid_class_logits = class_logits[batch_idx, valid_mask]  # (n_valid, n_classes)

        # Numerically stable softmax over the class axis.
        logits = valid_class_logits - np.max(valid_class_logits, axis=1, keepdims=True)
        probs = np.exp(logits)
        probs = probs / np.sum(probs, axis=1, keepdims=True)

        # Hungarian matching: maximize total probability (minimize its
        # negative). Each instance and each class is used at most once.
        row_inds, col_inds = linear_sum_assignment(-probs)

        instances = []
        for row_idx, class_idx in zip(row_inds, col_inds):
            track_name = (
                class_names[class_idx]
                if class_idx < len(class_names)
                else f"class_{class_idx}"
            )
            instances.append(
                sio.PredictedInstance.from_numpy(
                    points_data=valid_peaks[row_idx],
                    point_scores=valid_peak_vals[row_idx],
                    score=float(valid_centroid_vals[row_idx]),
                    skeleton=skeleton,
                    track=get_track(track_name),
                )
            )

        if max_instances is not None and instances:
            instances = sorted(instances, key=lambda inst: inst.score, reverse=True)
            instances = instances[:max_instances]

        if instances:
            labeled_frames.append(
                sio.LabeledFrame(
                    video=video,
                    frame_idx=int(frame_idx),
                    instances=instances,
                )
            )

    return labeled_frames
|
|
1333
|
+
|
|
1334
|
+
|
|
1335
|
+
def _predict_bottomup_frames(
    outputs,
    frame_indices,
    video,
    skeleton,
    paf_scorer,
    candidate_template,
    input_scale,
    peak_conf_threshold=0.1,
    max_instances=None,
):
    """Group bottom-up peaks into instances and wrap them as LabeledFrames.

    The candidate edges in ``candidate_template`` are filtered per batch
    sample. An edge survives only if the exported ``candidate_mask`` allows it
    and both of its endpoint peaks score above ``peak_conf_threshold``. The
    survivors are matched and grouped with ``paf_scorer``. Instance
    coordinates are divided by ``input_scale``, and instances whose points are
    all NaN are dropped.

    Args:
        outputs: Mapping with "peaks", "peak_vals", "line_scores" and
            "candidate_mask" numpy arrays; "peaks" is unpacked as
            (batch, n_nodes, k, 2).
        frame_indices: Video frame index for each batch row.
        video: sleap_io.Video the frames belong to.
        skeleton: sleap_io.Skeleton for the predicted instances.
        paf_scorer: PAFScorer providing ``match_candidates`` and
            ``group_instances``.
        candidate_template: Dict with "peak_channel_inds", "edge_inds" and
            "edge_peak_inds" tensors describing every candidate edge.
        input_scale: Factor the coordinates are divided by before output.
        peak_conf_threshold: Minimum (exclusive) peak value for an edge endpoint.
        max_instances: If not None, keep only this many instances per frame,
            ranked by instance score.

    Returns:
        List of sio.LabeledFrame; frames with no kept instances are omitted.
    """
    import numpy as np
    import sleap_io as sio
    import torch

    def as_tensor(key, dtype):
        return torch.from_numpy(outputs[key]).to(dtype)

    peaks = as_tensor("peaks", torch.float32)
    peak_vals = as_tensor("peak_vals", torch.float32)
    line_scores = as_tensor("line_scores", torch.float32)
    candidate_mask = as_tensor("candidate_mask", torch.bool)

    n_samples, n_nodes, k, _ = peaks.shape
    flat_peaks = peaks.reshape(n_samples, n_nodes * k, 2)
    flat_peak_vals = peak_vals.reshape(n_samples, n_nodes * k)

    tmpl_channels = candidate_template["peak_channel_inds"]
    tmpl_edges = candidate_template["edge_inds"]
    tmpl_edge_peaks = candidate_template["edge_peak_inds"]

    # Peaks are passed through unfiltered; filtering happens on the edges.
    per_sample_peaks = [flat_peaks[i] for i in range(n_samples)]
    per_sample_peak_vals = [flat_peak_vals[i] for i in range(n_samples)]
    per_sample_channels = [tmpl_channels for _ in range(n_samples)]

    per_sample_edges = []
    per_sample_edge_peaks = []
    per_sample_line_scores = []
    for sample in range(n_samples):
        mask = candidate_mask[sample].reshape(-1)
        scores = line_scores[sample].reshape(-1)

        if mask.numel() == 0:
            per_sample_edges.append(torch.empty((0,), dtype=torch.int32))
            per_sample_edge_peaks.append(torch.empty((0, 2), dtype=torch.int32))
            per_sample_line_scores.append(torch.empty((0,), dtype=torch.float32))
            continue

        # Both endpoints of an edge must be confident peaks.
        confident = flat_peak_vals[sample] > peak_conf_threshold
        keep = (
            mask
            & confident[tmpl_edge_peaks[:, 0].long()]
            & confident[tmpl_edge_peaks[:, 1].long()]
        )

        per_sample_edges.append(tmpl_edges[keep])
        per_sample_edge_peaks.append(tmpl_edge_peaks[keep])
        per_sample_line_scores.append(scores[keep])

    matched = paf_scorer.match_candidates(
        per_sample_edges,
        per_sample_edge_peaks,
        per_sample_line_scores,
    )
    match_edge_inds, match_src_peak_inds, match_dst_peak_inds, match_line_scores = matched

    grouped = paf_scorer.group_instances(
        per_sample_peaks,
        per_sample_peak_vals,
        per_sample_channels,
        match_edge_inds,
        match_src_peak_inds,
        match_dst_peak_inds,
        match_line_scores,
    )
    instance_points, instance_point_scores, instance_scores = grouped

    # Map coordinates from model-input space back to the original frame.
    instance_points = [p / input_scale for p in instance_points]

    frames = []
    for sample, frame_idx in enumerate(frame_indices):
        instances = []
        for pts, confs, score in zip(
            instance_points[sample],
            instance_point_scores[sample],
            instance_scores[sample],
        ):
            pts_np = pts.cpu().numpy()
            if np.isnan(pts_np).all():
                continue
            instances.append(
                sio.PredictedInstance.from_numpy(
                    points_data=pts_np,
                    point_scores=confs.cpu().numpy(),
                    score=float(score),
                    skeleton=skeleton,
                )
            )

        if instances and max_instances is not None:
            ranked = sorted(instances, key=lambda inst: inst.score, reverse=True)
            instances = ranked[:max_instances]

        if not instances:
            continue

        frames.append(
            sio.LabeledFrame(
                video=video,
                frame_idx=int(frame_idx),
                instances=instances,
            )
        )

    return frames
|
|
1458
|
+
|
|
1459
|
+
|
|
1460
|
+
def _predict_single_instance_frames(
    outputs,
    frame_indices,
    video,
    skeleton,
):
    """Build one LabeledFrame per batch row from single-instance outputs.

    Every frame receives exactly one instance, even if all of its points are
    NaN. The instance score is the mean peak value over nodes whose x
    coordinate is not NaN, or 0.0 when there are no such nodes.

    Args:
        outputs: Mapping with "peaks" (batch, n_nodes, 2) and "peak_vals"
            (batch, n_nodes) arrays.
        frame_indices: Video frame index for each batch row.
        video: sleap_io.Video the frames belong to.
        skeleton: sleap_io.Skeleton for the predicted instances.

    Returns:
        List of sio.LabeledFrame, one per entry in ``frame_indices``.
    """
    import numpy as np
    import sleap_io as sio

    peaks = outputs["peaks"]
    peak_vals = outputs["peak_vals"]

    frames = []
    for row, frame_idx in enumerate(frame_indices):
        pts, node_scores = peaks[row], peak_vals[row]

        located = ~np.isnan(pts[:, 0])
        score = float(np.mean(node_scores[located])) if located.any() else 0.0

        only_instance = sio.PredictedInstance.from_numpy(
            points_data=pts,
            point_scores=node_scores,
            score=score,
            skeleton=skeleton,
        )
        frames.append(
            sio.LabeledFrame(
                video=video,
                frame_idx=int(frame_idx),
                instances=[only_instance],
            )
        )

    return frames
|
|
1501
|
+
|
|
1502
|
+
|
|
1503
|
+
def _predict_centroid_frames(
    outputs,
    frame_indices,
    video,
    skeleton,
    anchor_node_idx: int,
    max_instances=None,
):
    """Build LabeledFrames from centroid-only model outputs.

    Each valid centroid becomes an instance whose anchor node holds the
    centroid position and score. Every other node's point and score are NaN.
    The instance score is the centroid score.

    Args:
        outputs: Mapping with "centroids", "centroid_vals" and
            "instance_valid" arrays indexed as [batch, instance, ...].
        frame_indices: Video frame index for each batch row.
        video: sleap_io.Video the frames belong to.
        skeleton: sleap_io.Skeleton for the predicted instances.
        anchor_node_idx: Skeleton node index that receives the centroid.
        max_instances: If not None, keep only this many instances per frame,
            ranked by centroid score (highest first).

    Returns:
        List of sio.LabeledFrame; frames with no kept instances are omitted.
    """
    import numpy as np
    import sleap_io as sio

    centroids = outputs["centroids"]
    centroid_vals = outputs["centroid_vals"]
    instance_valid = outputs["instance_valid"]

    n_nodes = len(skeleton.nodes)

    def make_instance(row, slot):
        # All nodes start as NaN; only the anchor node is filled in.
        pts = np.full((n_nodes, 2), np.nan, dtype=np.float32)
        pts[anchor_node_idx] = centroids[row, slot]

        node_scores = np.full((n_nodes,), np.nan, dtype=np.float32)
        node_scores[anchor_node_idx] = centroid_vals[row, slot]

        return sio.PredictedInstance.from_numpy(
            points_data=pts,
            point_scores=node_scores,
            score=float(centroid_vals[row, slot]),
            skeleton=skeleton,
        )

    frames = []
    for row, frame_idx in enumerate(frame_indices):
        flags = instance_valid[row].astype(bool)
        instances = [make_instance(row, slot) for slot, ok in enumerate(flags) if ok]

        if instances and max_instances is not None:
            ranked = sorted(instances, key=lambda inst: inst.score, reverse=True)
            instances = ranked[:max_instances]

        if not instances:
            continue

        frames.append(
            sio.LabeledFrame(
                video=video,
                frame_idx=int(frame_idx),
                instances=instances,
            )
        )

    return frames
|
|
1578
|
+
|
|
1579
|
+
|
|
1580
|
+
def _predict_multiclass_bottomup_frames(
    outputs,
    frame_indices,
    video,
    skeleton,
    class_names: list,
    input_scale: float = 1.0,
    peak_conf_threshold: float = 0.1,
    max_instances: Optional[int] = None,
    tracks: Optional[list] = None,
):
    """Convert bottom-up multiclass model outputs to LabeledFrames.

    Uses class probability maps to group peaks by identity rather than PAFs.
    For each node, the valid peaks are matched one-to-one to classes with the
    Hungarian algorithm (maximizing the summed class probability). Each class
    therefore receives at most one peak per node, and each peak at most one
    class.

    Args:
        outputs: Model outputs with peaks, peak_vals, peak_mask, class_probs.
        frame_indices: Frame indices corresponding to batch.
        video: sleap_io.Video object.
        skeleton: sleap_io.Skeleton object.
        class_names: List of class names (e.g., ["female", "male"]).
        input_scale: Scale factor applied to input. Peak coordinates are
            divided by it to map them back to original image coordinates.
        peak_conf_threshold: A peak is considered only if its value is
            strictly greater than this threshold.
        max_instances: Maximum instances per frame; the highest-scoring
            instances are kept. ``None`` keeps all of them (at most one per
            class).
        tracks: Optional sequence of ``sio.Track`` objects, one per entry of
            ``class_names``. Pass the same list on every call so that all
            batches share the same track objects. If ``None``, one track per
            class is created for this call.

    Returns:
        List of LabeledFrame objects. Frames without any predicted instance
        are omitted.

    Raises:
        ValueError: If ``tracks`` is given and its length differs from
            ``len(class_names)``.
    """
    import numpy as np
    import sleap_io as sio
    from scipy.optimize import linear_sum_assignment

    n_classes = len(class_names)

    # Build the tracks once rather than once per instance. A fresh sio.Track
    # for every frame would produce many distinct track objects that merely
    # share a name, instead of one track per identity.
    if tracks is None:
        tracks = [sio.Track(name=name) for name in class_names]
    elif len(tracks) != n_classes:
        raise ValueError(
            f"Expected {n_classes} tracks (one per class), got {len(tracks)}."
        )

    peaks = outputs["peaks"]  # (batch, n_nodes, max_peaks, 2)
    peak_vals = outputs["peak_vals"]  # (batch, n_nodes, max_peaks)
    peak_mask = outputs["peak_mask"]  # (batch, n_nodes, max_peaks)
    class_probs = outputs["class_probs"]  # (batch, n_nodes, max_peaks, n_classes)

    n_nodes_skel = len(skeleton.nodes)
    # Only nodes present in both the model output and the skeleton are filled.
    n_nodes_used = min(peaks.shape[1], n_nodes_skel)

    labeled_frames = []
    for batch_idx, frame_idx in enumerate(frame_indices):
        # One row per class; nodes without an assigned peak stay NaN.
        instance_points = np.full(
            (n_classes, n_nodes_skel, 2), np.nan, dtype=np.float32
        )
        instance_scores = np.full((n_classes, n_nodes_skel), np.nan, dtype=np.float32)

        # Process each node independently.
        for node_idx in range(n_nodes_used):
            node_vals = peak_vals[batch_idx, node_idx]
            valid = peak_mask[batch_idx, node_idx].astype(bool) & (
                node_vals > peak_conf_threshold
            )
            if not valid.any():
                continue

            valid_peaks = peaks[batch_idx, node_idx][valid]  # (n_valid, 2)
            valid_vals = node_vals[valid]  # (n_valid,)
            valid_class_probs = class_probs[batch_idx, node_idx][valid]

            # Hungarian matching of peaks (rows) to classes (columns),
            # maximizing the total class probability.
            peak_inds, class_inds = linear_sum_assignment(
                valid_class_probs, maximize=True
            )

            # Ignore model classes that have no entry in class_names.
            keep = class_inds < n_classes
            peak_inds, class_inds = peak_inds[keep], class_inds[keep]

            # class_inds are unique, so this fancy-index assignment is
            # equivalent to assigning each matched pair one at a time.
            instance_points[class_inds, node_idx] = (
                valid_peaks[peak_inds] / input_scale
            )
            instance_scores[class_inds, node_idx] = valid_vals[peak_inds]

        # Create predicted instances.
        instances = []
        for class_idx in range(n_classes):
            pts = instance_points[class_idx]
            scores = instance_scores[class_idx]

            # Assigned peak values passed the strict threshold test above and
            # so are never NaN; a finite score marks a node that got a peak.
            assigned = ~np.isnan(scores)
            if not assigned.any():
                continue  # This class was not detected in this frame.

            instances.append(
                sio.PredictedInstance.from_numpy(
                    points_data=pts,
                    point_scores=scores,
                    # Instance score is the mean value of its assigned peaks.
                    score=float(np.mean(scores[assigned])),
                    skeleton=skeleton,
                    track=tracks[class_idx],
                )
            )

        if max_instances is not None:
            instances = sorted(instances, key=lambda inst: inst.score, reverse=True)
            instances = instances[:max_instances]

        if instances:
            labeled_frames.append(
                sio.LabeledFrame(
                    video=video,
                    frame_idx=int(frame_idx),
                    instances=instances,
                )
            )

    return labeled_frames
|
|
1709
|
+
|
|
1710
|
+
|
|
1711
|
+
def _copy_training_config(
    model_path: Path, export_dir: Path, label: Optional[str]
) -> Optional[Path]:
    """Copy a model's training config file into an export directory.

    Args:
        model_path: Model directory searched for the training config.
        export_dir: Destination directory. It is not created here.
        label: Optional name suffix. When truthy, the copy is named
            ``training_config_<label><ext>``; otherwise the source file
            name is kept unchanged.

    Returns:
        Path of the copied file, or ``None`` when no training config was found
        in ``model_path``.
    """
    source = _training_config_path(model_path)
    if source is None:
        return None

    target_name = (
        f"training_config_{label}{source.suffix}" if label else source.name
    )
    target = export_dir / target_name
    shutil.copy(source, target)
    return target
|
|
1726
|
+
|
|
1727
|
+
|
|
1728
|
+
def _training_config_path(model_path: Path) -> Optional[Path]:
|
|
1729
|
+
yaml_path = model_path / "training_config.yaml"
|
|
1730
|
+
json_path = model_path / "training_config.json"
|
|
1731
|
+
if yaml_path.exists():
|
|
1732
|
+
return yaml_path
|
|
1733
|
+
if json_path.exists():
|
|
1734
|
+
return json_path
|
|
1735
|
+
return None
|
|
1736
|
+
|
|
1737
|
+
|
|
1738
|
+
def _load_lightning_model(
    *,
    model_type: str,
    backbone_type: str,
    cfg,
    ckpt_path: Path,
    device: str,
):
    """Restore the Lightning module matching ``model_type`` from a checkpoint.

    Constructor arguments (backbone/head configs, pretrained weights, optimizer,
    LR scheduler and online hard keypoint mining settings) are read from
    ``cfg`` and forwarded to ``load_from_checkpoint``.

    Args:
        model_type: One of ``"centroid"``, ``"centered_instance"``,
            ``"single_instance"``, ``"bottomup"``, ``"multi_class_topdown"``
            or ``"multi_class_bottomup"``.
        backbone_type: Backbone identifier forwarded to the module.
        cfg: Training config whose ``model_config`` and ``trainer_config``
            sections supply the constructor arguments.
        ckpt_path: Path to the checkpoint file.
        device: Passed as ``map_location`` when loading.

    Returns:
        The loaded Lightning module.

    Raises:
        click.ClickException: If ``model_type`` is not supported.
    """
    modules_by_type = {
        "centroid": CentroidLightningModule,
        "centered_instance": TopDownCenteredInstanceLightningModule,
        "single_instance": SingleInstanceLightningModule,
        "bottomup": BottomUpLightningModule,
        "multi_class_topdown": TopDownCenteredInstanceMultiClassLightningModule,
        "multi_class_bottomup": BottomUpMultiClassLightningModule,
    }
    lightning_cls = modules_by_type.get(model_type)
    if lightning_cls is None:
        raise click.ClickException(f"Unsupported model type: {model_type}")

    model_cfg = cfg.model_config
    trainer_cfg = cfg.trainer_config
    mining_cfg = trainer_cfg.online_hard_keypoint_mining

    # NOTE(review): weights_only=False disables torch.load's restricted
    # (weights-only) unpickling, so only load checkpoints from trusted sources.
    return lightning_cls.load_from_checkpoint(
        checkpoint_path=str(ckpt_path),
        model_type=model_type,
        backbone_type=backbone_type,
        backbone_config=model_cfg.backbone_config,
        head_configs=model_cfg.head_configs,
        pretrained_backbone_weights=model_cfg.pretrained_backbone_weights,
        pretrained_head_weights=model_cfg.pretrained_head_weights,
        init_weights=model_cfg.init_weights,
        lr_scheduler=trainer_cfg.lr_scheduler,
        online_mining=mining_cfg.online_mining,
        hard_to_easy_ratio=mining_cfg.hard_to_easy_ratio,
        min_hard_keypoints=mining_cfg.min_hard_keypoints,
        max_hard_keypoints=mining_cfg.max_hard_keypoints,
        loss_scale=mining_cfg.loss_scale,
        optimizer=trainer_cfg.optimizer_name,
        learning_rate=trainer_cfg.optimizer.lr,
        amsgrad=trainer_cfg.optimizer.amsgrad,
        map_location=device,
        weights_only=False,
    )
|