sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/evaluation.py +8 -0
  4. sleap_nn/export/__init__.py +21 -0
  5. sleap_nn/export/cli.py +1778 -0
  6. sleap_nn/export/exporters/__init__.py +51 -0
  7. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  8. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  9. sleap_nn/export/metadata.py +225 -0
  10. sleap_nn/export/predictors/__init__.py +63 -0
  11. sleap_nn/export/predictors/base.py +22 -0
  12. sleap_nn/export/predictors/onnx.py +154 -0
  13. sleap_nn/export/predictors/tensorrt.py +312 -0
  14. sleap_nn/export/utils.py +307 -0
  15. sleap_nn/export/wrappers/__init__.py +25 -0
  16. sleap_nn/export/wrappers/base.py +96 -0
  17. sleap_nn/export/wrappers/bottomup.py +243 -0
  18. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  19. sleap_nn/export/wrappers/centered_instance.py +56 -0
  20. sleap_nn/export/wrappers/centroid.py +58 -0
  21. sleap_nn/export/wrappers/single_instance.py +83 -0
  22. sleap_nn/export/wrappers/topdown.py +180 -0
  23. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  24. sleap_nn/inference/postprocessing.py +284 -0
  25. sleap_nn/predict.py +29 -0
  26. sleap_nn/train.py +64 -0
  27. sleap_nn/training/callbacks.py +62 -20
  28. sleap_nn/training/lightning_modules.py +332 -30
  29. sleap_nn/training/model_trainer.py +35 -67
  30. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
  31. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
  32. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  33. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  34. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/export/cli.py ADDED
@@ -0,0 +1,1778 @@
+"""CLI entry points for export workflows."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional
+import json
+import shutil
+
+import click
+from omegaconf import OmegaConf
+import torch
+
+from sleap_nn.export.exporters import export_to_onnx, export_to_tensorrt
+from sleap_nn.export.metadata import (
+    build_base_metadata,
+    embed_metadata_in_onnx,
+    hash_file,
+)
+from sleap_nn.export.utils import (
+    load_training_config,
+    resolve_backbone_type,
+    resolve_class_maps_output_stride,
+    resolve_class_names,
+    resolve_crop_size,
+    resolve_edge_inds,
+    resolve_input_channels,
+    resolve_input_scale,
+    resolve_input_shape,
+    resolve_model_type,
+    resolve_n_classes,
+    resolve_node_names,
+    resolve_output_stride,
+    resolve_pafs_output_stride,
+)
+from sleap_nn.export.wrappers import (
+    BottomUpMultiClassONNXWrapper,
+    BottomUpONNXWrapper,
+    CenteredInstanceONNXWrapper,
+    CentroidONNXWrapper,
+    SingleInstanceONNXWrapper,
+    TopDownMultiClassCombinedONNXWrapper,
+    TopDownMultiClassONNXWrapper,
+    TopDownONNXWrapper,
+)
+from sleap_nn.training.lightning_modules import (
+    BottomUpLightningModule,
+    BottomUpMultiClassLightningModule,
+    CentroidLightningModule,
+    SingleInstanceLightningModule,
+    TopDownCenteredInstanceLightningModule,
+    TopDownCenteredInstanceMultiClassLightningModule,
+)
+
+
+@click.command()
+@click.argument(
+    "model_paths",
+    nargs=-1,
+    type=click.Path(exists=True, file_okay=False, path_type=Path),
+)
+@click.option(
+    "--output",
+    "-o",
+    type=click.Path(file_okay=False, path_type=Path),
+    default=None,
+    help="Output directory for exported model files.",
+)
+@click.option(
+    "--format",
+    "-f",
+    "fmt",
+    type=click.Choice(["onnx", "tensorrt", "both"], case_sensitive=False),
+    default="onnx",
+    show_default=True,
+)
+@click.option("--opset-version", type=int, default=17, show_default=True)
+@click.option("--max-instances", type=int, default=20, show_default=True)
+@click.option("--max-batch-size", type=int, default=8, show_default=True)
+@click.option("--input-scale", type=float, default=None)
+@click.option("--input-height", type=int, default=None)
+@click.option("--input-width", type=int, default=None)
+@click.option("--crop-size", type=int, default=None)
+@click.option("--max-peaks-per-node", type=int, default=20, show_default=True)
+@click.option("--n-line-points", type=int, default=10, show_default=True)
+@click.option("--max-edge-length-ratio", type=float, default=0.25, show_default=True)
+@click.option("--dist-penalty-weight", type=float, default=1.0, show_default=True)
+@click.option("--device", type=str, default="cpu", show_default=True)
+@click.option(
+    "--precision",
+    type=click.Choice(["fp32", "fp16"], case_sensitive=False),
+    default="fp16",
+    show_default=True,
+    help="TensorRT precision mode.",
+)
+@click.option("--verify/--no-verify", default=True, show_default=True)
+def export(
+    model_paths: tuple[Path, ...],
+    output: Optional[Path],
+    fmt: str,
+    opset_version: int,
+    max_instances: int,
+    max_batch_size: int,
+    input_scale: Optional[float],
+    input_height: Optional[int],
+    input_width: Optional[int],
+    crop_size: Optional[int],
+    max_peaks_per_node: int,
+    n_line_points: int,
+    max_edge_length_ratio: float,
+    dist_penalty_weight: float,
+    device: str,
+    precision: str,
+    verify: bool,
+) -> None:
+    """Export trained models to ONNX/TensorRT formats."""
+    fmt = fmt.lower()
+
+    if not model_paths:
+        raise click.ClickException("Provide at least one model path to export.")
+
+    model_paths = list(model_paths)
+    cfgs = [load_training_config(path) for path in model_paths]
+    model_types = [resolve_model_type(cfg) for cfg in cfgs]
+    backbone_types = [resolve_backbone_type(cfg) for cfg in cfgs]
+
+    if len(model_paths) == 1:
+        model_path = model_paths[0]
+        cfg = cfgs[0]
+        model_type = model_types[0]
+        backbone_type = backbone_types[0]
+
+        if model_type not in (
+            "centroid",
+            "centered_instance",
+            "bottomup",
+            "single_instance",
+            "multi_class_topdown",
+            "multi_class_bottomup",
+        ):
+            raise click.ClickException(
+                f"Model type '{model_type}' is not supported for export yet."
+            )
+
+        ckpt_path = model_path / "best.ckpt"
+        if not ckpt_path.exists():
+            raise click.ClickException(f"Checkpoint not found: {ckpt_path}")
+
+        lightning_model = _load_lightning_model(
+            model_type=model_type,
+            backbone_type=backbone_type,
+            cfg=cfg,
+            ckpt_path=ckpt_path,
+            device=device,
+        )
+
+        torch_model = lightning_model.model
+        torch_model.eval()
+        torch_model.to(device)
+
+        export_dir = output or (model_path / "exported")
+        export_dir.mkdir(parents=True, exist_ok=True)
+
+        resolved_scale = (
+            input_scale if input_scale is not None else resolve_input_scale(cfg)
+        )
+        output_stride = resolve_output_stride(cfg, model_type)
+        resolved_crop_size = (
+            (crop_size, crop_size) if crop_size is not None else resolve_crop_size(cfg)
+        )
+        metadata_max_instances = None
+        metadata_max_peaks = None
+        metadata_n_classes = None
+        metadata_class_names = None
+
+        if model_type == "centroid":
+            wrapper = CentroidONNXWrapper(
+                torch_model,
+                max_instances=max_instances,
+                output_stride=output_stride,
+                input_scale=resolved_scale,
+            )
+            output_names = ["centroids", "centroid_vals", "instance_valid"]
+            metadata_max_instances = max_instances
+            node_names = resolve_node_names(cfg, model_type)
+            edge_inds = resolve_edge_inds(cfg, node_names)
+        elif model_type == "centered_instance":
+            wrapper = CenteredInstanceONNXWrapper(
+                torch_model,
+                output_stride=output_stride,
+                input_scale=resolved_scale,
+            )
+            output_names = ["peaks", "peak_vals"]
+            node_names = resolve_node_names(cfg, model_type)
+            edge_inds = resolve_edge_inds(cfg, node_names)
+        elif model_type == "bottomup":
+            node_names = resolve_node_names(cfg, model_type)
+            edge_inds = resolve_edge_inds(cfg, node_names)
+            pafs_output_stride = resolve_pafs_output_stride(cfg)
+            wrapper = BottomUpONNXWrapper(
+                torch_model,
+                skeleton_edges=edge_inds,
+                n_nodes=len(node_names),
+                max_peaks_per_node=max_peaks_per_node,
+                n_line_points=n_line_points,
+                cms_output_stride=output_stride,
+                pafs_output_stride=pafs_output_stride,
+                max_edge_length_ratio=max_edge_length_ratio,
+                dist_penalty_weight=dist_penalty_weight,
+                input_scale=resolved_scale,
+            )
+            output_names = [
+                "peaks",
+                "peak_vals",
+                "peak_mask",
+                "line_scores",
+                "candidate_mask",
+            ]
+            metadata_max_peaks = max_peaks_per_node
+        elif model_type == "single_instance":
+            wrapper = SingleInstanceONNXWrapper(
+                torch_model,
+                output_stride=output_stride,
+                input_scale=resolved_scale,
+            )
+            output_names = ["peaks", "peak_vals"]
+            node_names = resolve_node_names(cfg, model_type)
+            edge_inds = resolve_edge_inds(cfg, node_names)
+        elif model_type == "multi_class_topdown":
+            n_classes = resolve_n_classes(cfg, model_type)
+            class_names = resolve_class_names(cfg, model_type)
+            wrapper = TopDownMultiClassONNXWrapper(
+                torch_model,
+                output_stride=output_stride,
+                input_scale=resolved_scale,
+                n_classes=n_classes,
+            )
+            output_names = ["peaks", "peak_vals", "class_logits"]
+            node_names = resolve_node_names(cfg, model_type)
+            edge_inds = resolve_edge_inds(cfg, node_names)
+            metadata_n_classes = n_classes
+            metadata_class_names = class_names
+        elif model_type == "multi_class_bottomup":
+            node_names = resolve_node_names(cfg, model_type)
+            edge_inds = resolve_edge_inds(cfg, node_names)
+            n_classes = resolve_n_classes(cfg, model_type)
+            class_names = resolve_class_names(cfg, model_type)
+            class_maps_output_stride = resolve_class_maps_output_stride(cfg)
+            wrapper = BottomUpMultiClassONNXWrapper(
+                torch_model,
+                n_nodes=len(node_names),
+                n_classes=n_classes,
+                max_peaks_per_node=max_peaks_per_node,
+                cms_output_stride=output_stride,
+                class_maps_output_stride=class_maps_output_stride,
+                input_scale=resolved_scale,
+            )
+            output_names = ["peaks", "peak_vals", "peak_mask", "class_probs"]
+            metadata_max_peaks = max_peaks_per_node
+            metadata_n_classes = n_classes
+            metadata_class_names = class_names
+        else:
+            raise click.ClickException(
+                f"Model type '{model_type}' is not supported for export yet."
+            )
+
+        wrapper.eval()
+        wrapper.to(device)
+
+        input_shape = resolve_input_shape(
+            cfg, input_height=input_height, input_width=input_width
+        )
+        model_out_path = export_dir / "model.onnx"
+
+        export_to_onnx(
+            wrapper,
+            model_out_path,
+            input_shape=input_shape,
+            input_dtype=torch.uint8,
+            opset_version=opset_version,
+            output_names=output_names,
+            verify=verify,
+        )
+
+        training_config_path = _copy_training_config(model_path, export_dir, None)
+        if training_config_path is not None:
+            training_config_hash = hash_file(training_config_path)
+            training_config_text = training_config_path.read_text()
+        else:
+            training_config_hash = ""
+            training_config_text = None
+
+        metadata = build_base_metadata(
+            export_format="onnx",
+            model_type=model_type,
+            model_name=model_path.name,
+            checkpoint_path=str(ckpt_path),
+            backbone=backbone_type,
+            n_nodes=len(node_names),
+            n_edges=len(edge_inds),
+            node_names=node_names,
+            edge_inds=edge_inds,
+            input_scale=resolved_scale,
+            input_channels=resolve_input_channels(cfg),
+            output_stride=output_stride,
+            crop_size=resolved_crop_size,
+            max_instances=metadata_max_instances,
+            max_peaks_per_node=metadata_max_peaks,
+            max_batch_size=max_batch_size,
+            precision="fp32",
+            training_config_hash=training_config_hash,
+            training_config_embedded=training_config_text is not None,
+            input_dtype="uint8",
+            normalization="0_to_1",
+            n_classes=metadata_n_classes,
+            class_names=metadata_class_names,
+        )
+
+        metadata.save(export_dir / "export_metadata.json")
+
+        if training_config_text is not None:
+            try:
+                embed_metadata_in_onnx(model_out_path, metadata, training_config_text)
+            except ImportError:
+                pass
+
+        # Export to TensorRT if requested
+        if fmt in ("tensorrt", "both"):
+            trt_out_path = export_dir / "model.trt"
+            B, C, H, W = input_shape
+
+            # For centered_instance and single_instance models, use crop size
+            # for TensorRT shape profiles since inference uses cropped inputs
+            if model_type in ("centered_instance", "single_instance"):
+                if resolved_crop_size is not None:
+                    crop_h, crop_w = resolved_crop_size
+                    trt_input_shape = (1, C, crop_h, crop_w)
+                    # Use crop size for min/opt, allow flexibility for max
+                    trt_min_shape = (1, C, crop_h, crop_w)
+                    trt_opt_shape = (1, C, crop_h, crop_w)
+                    trt_max_shape = (max_batch_size, C, crop_h * 2, crop_w * 2)
+                else:
+                    trt_input_shape = input_shape
+                    trt_min_shape = None
+                    trt_opt_shape = None
+                    trt_max_shape = (max_batch_size, C, H * 2, W * 2)
+            else:
+                trt_input_shape = input_shape
+                trt_min_shape = None
+                trt_opt_shape = None
+                trt_max_shape = (max_batch_size, C, H * 2, W * 2)
+
+            export_to_tensorrt(
+                wrapper,
+                trt_out_path,
+                input_shape=trt_input_shape,
+                input_dtype=torch.uint8,
+                precision=precision,
+                min_shape=trt_min_shape,
+                opt_shape=trt_opt_shape,
+                max_shape=trt_max_shape,
+                verbose=True,
+            )
+            # Update metadata for TensorRT
+            trt_metadata = build_base_metadata(
+                export_format="tensorrt",
+                model_type=model_type,
+                model_name=model_path.name,
+                checkpoint_path=str(ckpt_path),
+                backbone=backbone_type,
+                n_nodes=len(node_names),
+                n_edges=len(edge_inds),
+                node_names=node_names,
+                edge_inds=edge_inds,
+                input_scale=resolved_scale,
+                input_channels=resolve_input_channels(cfg),
+                output_stride=output_stride,
+                crop_size=resolved_crop_size,
+                max_instances=metadata_max_instances,
+                max_peaks_per_node=metadata_max_peaks,
+                max_batch_size=max_batch_size,
+                precision=precision,
+                training_config_hash=training_config_hash,
+                training_config_embedded=training_config_text is not None,
+                input_dtype="uint8",
+                normalization="0_to_1",
+                n_classes=metadata_n_classes,
+                class_names=metadata_class_names,
+            )
+            trt_metadata.save(export_dir / "model.trt.metadata.json")
+        return
+
+    if len(model_paths) == 2 and set(model_types) == {
+        "centroid",
+        "centered_instance",
+    }:
+        centroid_idx = model_types.index("centroid")
+        instance_idx = model_types.index("centered_instance")
+
+        centroid_path = model_paths[centroid_idx]
+        instance_path = model_paths[instance_idx]
+        centroid_cfg = cfgs[centroid_idx]
+        instance_cfg = cfgs[instance_idx]
+        centroid_backbone = backbone_types[centroid_idx]
+        instance_backbone = backbone_types[instance_idx]
+
+        centroid_ckpt = centroid_path / "best.ckpt"
+        instance_ckpt = instance_path / "best.ckpt"
+        if not centroid_ckpt.exists():
+            raise click.ClickException(f"Checkpoint not found: {centroid_ckpt}")
+        if not instance_ckpt.exists():
+            raise click.ClickException(f"Checkpoint not found: {instance_ckpt}")
+
+        centroid_model = _load_lightning_model(
+            model_type="centroid",
+            backbone_type=centroid_backbone,
+            cfg=centroid_cfg,
+            ckpt_path=centroid_ckpt,
+            device=device,
+        ).model
+        instance_model = _load_lightning_model(
+            model_type="centered_instance",
+            backbone_type=instance_backbone,
+            cfg=instance_cfg,
+            ckpt_path=instance_ckpt,
+            device=device,
+        ).model
+
+        centroid_model.eval()
+        instance_model.eval()
+        centroid_model.to(device)
+        instance_model.to(device)
+
+        export_dir = output or (centroid_path / "exported_topdown")
+        export_dir.mkdir(parents=True, exist_ok=True)
+
+        centroid_scale = (
+            input_scale
+            if input_scale is not None
+            else resolve_input_scale(centroid_cfg)
+        )
+        instance_scale = (
+            input_scale
+            if input_scale is not None
+            else resolve_input_scale(instance_cfg)
+        )
+        centroid_stride = resolve_output_stride(centroid_cfg, "centroid")
+        instance_stride = resolve_output_stride(instance_cfg, "centered_instance")
+
+        resolved_crop = resolve_crop_size(instance_cfg)
+        if crop_size is not None:
+            resolved_crop = (crop_size, crop_size)
+        if resolved_crop is None:
+            raise click.ClickException(
+                "Top-down export requires crop_size. Provide --crop-size or ensure "
+                "data_config.preprocessing.crop_size is set."
+            )
+
+        node_names = resolve_node_names(instance_cfg, "centered_instance")
+        edge_inds = resolve_edge_inds(instance_cfg, node_names)
+
+        wrapper = TopDownONNXWrapper(
+            centroid_model=centroid_model,
+            instance_model=instance_model,
+            max_instances=max_instances,
+            crop_size=resolved_crop,
+            centroid_output_stride=centroid_stride,
+            instance_output_stride=instance_stride,
+            centroid_input_scale=centroid_scale,
+            instance_input_scale=instance_scale,
+            n_nodes=len(node_names),
+        )
+        wrapper.eval()
+        wrapper.to(device)
+
+        input_shape = resolve_input_shape(
+            centroid_cfg, input_height=input_height, input_width=input_width
+        )
+        model_out_path = export_dir / "model.onnx"
+
+        export_to_onnx(
+            wrapper,
+            model_out_path,
+            input_shape=input_shape,
+            input_dtype=torch.uint8,
+            opset_version=opset_version,
+            output_names=[
+                "centroids",
+                "centroid_vals",
+                "peaks",
+                "peak_vals",
+                "instance_valid",
+            ],
+            verify=verify,
+        )
+
+        centroid_cfg_path = _copy_training_config(centroid_path, export_dir, "centroid")
+        instance_cfg_path = _copy_training_config(
+            instance_path, export_dir, "centered_instance"
+        )
+        config_payload = {}
+        config_hashes = []
+        if centroid_cfg_path is not None:
+            config_payload["centroid"] = centroid_cfg_path.read_text()
+            config_hashes.append(f"centroid:{hash_file(centroid_cfg_path)}")
+        if instance_cfg_path is not None:
+            config_payload["centered_instance"] = instance_cfg_path.read_text()
+            config_hashes.append(f"centered_instance:{hash_file(instance_cfg_path)}")
+
+        training_config_hash = ";".join(config_hashes) if config_hashes else ""
+        training_config_text = json.dumps(config_payload) if config_payload else None
+
+        metadata = build_base_metadata(
+            export_format="onnx",
+            model_type="topdown",
+            model_name=f"{centroid_path.name}+{instance_path.name}",
+            checkpoint_path=(
+                f"centroid:{centroid_ckpt};centered_instance:{instance_ckpt}"
+            ),
+            backbone=(
+                f"centroid:{centroid_backbone};centered_instance:{instance_backbone}"
+            ),
+            n_nodes=len(node_names),
+            n_edges=len(edge_inds),
+            node_names=node_names,
+            edge_inds=edge_inds,
+            input_scale=centroid_scale,
+            input_channels=resolve_input_channels(centroid_cfg),
+            output_stride=instance_stride,
+            crop_size=resolved_crop,
+            max_instances=max_instances,
+            max_batch_size=max_batch_size,
+            precision="fp32",
+            training_config_hash=training_config_hash,
+            training_config_embedded=training_config_text is not None,
+            input_dtype="uint8",
+            normalization="0_to_1",
+        )
+
+        metadata.save(export_dir / "export_metadata.json")
+
+        if training_config_text is not None:
+            try:
+                embed_metadata_in_onnx(model_out_path, metadata, training_config_text)
+            except ImportError:
+                pass
+
+        # Export to TensorRT if requested
+        if fmt in ("tensorrt", "both"):
+            trt_out_path = export_dir / "model.trt"
+            B, C, H, W = input_shape
+            export_to_tensorrt(
+                wrapper,
+                trt_out_path,
+                input_shape=input_shape,
+                input_dtype=torch.uint8,
+                precision=precision,
+                max_shape=(max_batch_size, C, H * 2, W * 2),
+                verbose=True,
+            )
+            # Update metadata for TensorRT
+            trt_metadata = build_base_metadata(
+                export_format="tensorrt",
+                model_type="topdown",
+                model_name=f"{centroid_path.name}+{instance_path.name}",
+                checkpoint_path=(
+                    f"centroid:{centroid_ckpt};centered_instance:{instance_ckpt}"
+                ),
+                backbone=(
+                    f"centroid:{centroid_backbone};centered_instance:{instance_backbone}"
+                ),
+                n_nodes=len(node_names),
+                n_edges=len(edge_inds),
+                node_names=node_names,
+                edge_inds=edge_inds,
+                input_scale=centroid_scale,
+                input_channels=resolve_input_channels(centroid_cfg),
+                output_stride=instance_stride,
+                crop_size=resolved_crop,
+                max_instances=max_instances,
+                max_batch_size=max_batch_size,
+                precision=precision,
+                training_config_hash=training_config_hash,
+                training_config_embedded=training_config_text is not None,
+                input_dtype="uint8",
+                normalization="0_to_1",
+            )
+            trt_metadata.save(export_dir / "model.trt.metadata.json")
+        return
+
+    # Combined multiclass top-down export (centroid + multi_class_topdown)
+    if len(model_paths) == 2 and set(model_types) == {
+        "centroid",
+        "multi_class_topdown",
+    }:
+        centroid_idx = model_types.index("centroid")
+        instance_idx = model_types.index("multi_class_topdown")
+
+        centroid_path = model_paths[centroid_idx]
+        instance_path = model_paths[instance_idx]
+        centroid_cfg = cfgs[centroid_idx]
+        instance_cfg = cfgs[instance_idx]
+        centroid_backbone = backbone_types[centroid_idx]
+        instance_backbone = backbone_types[instance_idx]
+
+        centroid_ckpt = centroid_path / "best.ckpt"
+        instance_ckpt = instance_path / "best.ckpt"
+        if not centroid_ckpt.exists():
+            raise click.ClickException(f"Checkpoint not found: {centroid_ckpt}")
+        if not instance_ckpt.exists():
+            raise click.ClickException(f"Checkpoint not found: {instance_ckpt}")
+
+        centroid_model = _load_lightning_model(
+            model_type="centroid",
+            backbone_type=centroid_backbone,
+            cfg=centroid_cfg,
+            ckpt_path=centroid_ckpt,
+            device=device,
+        ).model
+        instance_model = _load_lightning_model(
+            model_type="multi_class_topdown",
+            backbone_type=instance_backbone,
+            cfg=instance_cfg,
+            ckpt_path=instance_ckpt,
+            device=device,
+        ).model
+
+        centroid_model.eval()
+        instance_model.eval()
+        centroid_model.to(device)
+        instance_model.to(device)
+
+        export_dir = output or (centroid_path / "exported_multi_class_topdown")
+        export_dir.mkdir(parents=True, exist_ok=True)
+
+        centroid_scale = (
+            input_scale
+            if input_scale is not None
+            else resolve_input_scale(centroid_cfg)
+        )
+        instance_scale = (
+            input_scale
+            if input_scale is not None
+            else resolve_input_scale(instance_cfg)
+        )
+        centroid_stride = resolve_output_stride(centroid_cfg, "centroid")
+        instance_stride = resolve_output_stride(instance_cfg, "multi_class_topdown")
+
+        resolved_crop = resolve_crop_size(instance_cfg)
+        if crop_size is not None:
+            resolved_crop = (crop_size, crop_size)
+        if resolved_crop is None:
+            raise click.ClickException(
+                "Multiclass top-down export requires crop_size. Provide --crop-size or "
+                "ensure data_config.preprocessing.crop_size is set."
+            )
+
+        node_names = resolve_node_names(instance_cfg, "multi_class_topdown")
+        edge_inds = resolve_edge_inds(instance_cfg, node_names)
+        n_classes = resolve_n_classes(instance_cfg, "multi_class_topdown")
+        class_names = resolve_class_names(instance_cfg, "multi_class_topdown")
+
+        wrapper = TopDownMultiClassCombinedONNXWrapper(
+            centroid_model=centroid_model,
+            instance_model=instance_model,
+            max_instances=max_instances,
+            crop_size=resolved_crop,
+            centroid_output_stride=centroid_stride,
+            instance_output_stride=instance_stride,
+            centroid_input_scale=centroid_scale,
+            instance_input_scale=instance_scale,
+            n_nodes=len(node_names),
+            n_classes=n_classes,
+        )
+        wrapper.eval()
+        wrapper.to(device)
+
+        input_shape = resolve_input_shape(
+            centroid_cfg, input_height=input_height, input_width=input_width
+        )
+        model_out_path = export_dir / "model.onnx"
+
+        export_to_onnx(
+            wrapper,
+            model_out_path,
+            input_shape=input_shape,
+            input_dtype=torch.uint8,
+            opset_version=opset_version,
+            output_names=[
+                "centroids",
+                "centroid_vals",
+                "peaks",
+                "peak_vals",
+                "class_logits",
+                "instance_valid",
+            ],
+            verify=verify,
+        )
+
+        centroid_cfg_path = _copy_training_config(centroid_path, export_dir, "centroid")
+        instance_cfg_path = _copy_training_config(
+            instance_path, export_dir, "multi_class_topdown"
+        )
+        config_payload = {}
+        config_hashes = []
+        if centroid_cfg_path is not None:
+            config_payload["centroid"] = centroid_cfg_path.read_text()
+            config_hashes.append(f"centroid:{hash_file(centroid_cfg_path)}")
+        if instance_cfg_path is not None:
+            config_payload["multi_class_topdown"] = instance_cfg_path.read_text()
+            config_hashes.append(f"multi_class_topdown:{hash_file(instance_cfg_path)}")
+
+        training_config_hash = ";".join(config_hashes) if config_hashes else ""
+        training_config_text = json.dumps(config_payload) if config_payload else None
+
+        metadata = build_base_metadata(
+            export_format="onnx",
+            model_type="multi_class_topdown_combined",
+            model_name=f"{centroid_path.name}+{instance_path.name}",
+            checkpoint_path=(
+                f"centroid:{centroid_ckpt};multi_class_topdown:{instance_ckpt}"
+            ),
+            backbone=(
+                f"centroid:{centroid_backbone};multi_class_topdown:{instance_backbone}"
+            ),
+            n_nodes=len(node_names),
+            n_edges=len(edge_inds),
+            node_names=node_names,
+            edge_inds=edge_inds,
+            input_scale=centroid_scale,
+            input_channels=resolve_input_channels(centroid_cfg),
+            output_stride=instance_stride,
+            crop_size=resolved_crop,
+            max_instances=max_instances,
+            max_batch_size=max_batch_size,
+            training_config_hash=training_config_hash,
+            training_config_embedded=training_config_text is not None,
+            input_dtype="uint8",
+            normalization="0_to_1",
+            n_classes=n_classes,
+            class_names=class_names,
+        )
+        metadata.save(export_dir / "export_metadata.json")
+        click.echo(f"ONNX model exported to: {model_out_path}")
+        click.echo(f"Metadata saved to: {export_dir / 'export_metadata.json'}")
+
+        # TensorRT export for combined multiclass top-down
+        if fmt in ("tensorrt", "both"):
+            trt_out_path = export_dir / "model.trt"
+            B, C, H, W = input_shape
+            export_to_tensorrt(
+                wrapper,
+                trt_out_path,
+                input_shape=input_shape,
+                input_dtype=torch.uint8,
+                precision=precision,
+                max_shape=(max_batch_size, C, H * 2, W * 2),
+                verbose=True,
+            )
+            trt_metadata = build_base_metadata(
+                export_format="tensorrt",
+                model_type="multi_class_topdown_combined",
+                model_name=f"{centroid_path.name}+{instance_path.name}",
+                checkpoint_path=(
+                    f"centroid:{centroid_ckpt};multi_class_topdown:{instance_ckpt}"
+                ),
+                backbone=(
+                    f"centroid:{centroid_backbone};multi_class_topdown:{instance_backbone}"
+                ),
+                n_nodes=len(node_names),
+                n_edges=len(edge_inds),
+                node_names=node_names,
+                edge_inds=edge_inds,
+                input_scale=centroid_scale,
+                input_channels=resolve_input_channels(centroid_cfg),
+                output_stride=instance_stride,
+                crop_size=resolved_crop,
+                max_instances=max_instances,
+                max_batch_size=max_batch_size,
+                precision=precision,
+                training_config_hash=training_config_hash,
+                training_config_embedded=training_config_text is not None,
+                input_dtype="uint8",
+                normalization="0_to_1",
+                n_classes=n_classes,
+                class_names=class_names,
+            )
+            trt_metadata.save(export_dir / "model.trt.metadata.json")
+        return
+
+    raise click.ClickException(
+        "Provide one model path for centroid/centered-instance/bottom-up export, "
+        "or two paths (centroid + centered_instance or centroid + multi_class_topdown) "
+        "for combined top-down export."
+    )
+
+
+@click.command()
+@click.argument(
+    "export_dir",
+    type=click.Path(exists=True, file_okay=False, path_type=Path),
+)
+@click.argument(
+    "video_path",
+    type=click.Path(exists=True, dir_okay=False, path_type=Path),
+)
+@click.option(
+    "--output",
+    "-o",
+    type=click.Path(dir_okay=False, path_type=Path),
+    default=None,
+    help="Output SLP file path. Default: video_name.predictions.slp",
+)
+@click.option(
+    "--runtime",
+    type=click.Choice(["auto", "onnx", "tensorrt"], case_sensitive=False),
+    default="auto",
+    show_default=True,
+    help="Runtime to use for inference.",
+)
+@click.option("--device", type=str, default="auto", show_default=True)
+@click.option("--batch-size", type=int, default=4, show_default=True)
+@click.option("--n-frames", type=int, default=None, help="Limit to first N frames.")
+@click.option(
+    "--max-edge-length-ratio",
+    type=float,
+    default=0.25,
+    show_default=True,
+    help="Bottom-up: max edge length as ratio of PAF dimensions.",
+)
+@click.option(
+    "--dist-penalty-weight",
+    type=float,
+    default=1.0,
+    show_default=True,
+    help="Bottom-up: weight for distance penalty in PAF scoring.",
+)
+@click.option(
+    "--n-points",
+    type=int,
+    default=10,
+    show_default=True,
+    help="Bottom-up: number of points to sample along PAF.",
+)
+@click.option(
+    "--min-instance-peaks",
+    type=float,
+    default=0,
+    show_default=True,
+    help="Bottom-up: minimum peaks required per instance.",
+)
+@click.option(
+    "--min-line-scores",
+    type=float,
+    default=-0.5,
+    show_default=True,
+    help="Bottom-up: minimum line score threshold.",
+)
+@click.option(
+    "--peak-conf-threshold",
+    type=float,
+    default=0.1,
+    show_default=True,
+    help="Bottom-up: peak confidence threshold for filtering candidates.",
+)
+@click.option(
+    "--max-instances",
+    type=int,
+    default=None,
+    help="Maximum instances to output per frame.",
+)
+def predict(
+    export_dir: Path,
+    video_path: Path,
+    output: Optional[Path],
+    runtime: str,
+    device: str,
+    batch_size: int,
+    n_frames: Optional[int],
+    max_edge_length_ratio: float,
+    dist_penalty_weight: float,
+    n_points: int,
+    min_instance_peaks: float,
+    min_line_scores: float,
+    peak_conf_threshold: float,
+    max_instances: Optional[int],
+) -> None:
+    """Run inference on exported models and save predictions to SLP.
+
+    EXPORT_DIR is the directory containing the exported model (model.onnx or model.trt)
+    along with export_metadata.json and training_config.yaml.
+
+    VIDEO_PATH is the path to the video file to process.
+    """
+    import time
+    from datetime import datetime
+
+    import numpy as np
+    import sleap_io as sio
+
+    from sleap_nn.export.metadata import ExportMetadata
+    from sleap_nn.export.predictors import load_exported_model
+    from sleap_nn.export.utils import build_bottomup_candidate_template
+    from sleap_nn.inference.paf_grouping import PAFScorer
+    from sleap_nn.inference.utils import get_skeleton_from_config
+
+    # Load metadata
+    metadata_path = export_dir / "export_metadata.json"
+    if not metadata_path.exists():
+        raise click.ClickException(f"Metadata not found: {metadata_path}")
+    metadata = ExportMetadata.load(metadata_path)
+
+    # Find model file
+    onnx_path = export_dir / "model.onnx"
+    trt_path = export_dir / "model.trt"
+
+    if runtime == "auto":
+        if trt_path.exists():
+            model_path = trt_path
+            runtime = "tensorrt"
+        elif onnx_path.exists():
+            model_path = onnx_path
+            runtime = "onnx"
+        else:
+            raise click.ClickException(
+                f"No model found in {export_dir}. Expected model.onnx or model.trt."
+            )
+    elif runtime == "onnx":
+        if not onnx_path.exists():
+            raise click.ClickException(f"ONNX model not found: {onnx_path}")
+        model_path = onnx_path
+    elif runtime == "tensorrt":
+        if not trt_path.exists():
+            raise click.ClickException(f"TensorRT model not found: {trt_path}")
+        model_path = trt_path
+    else:
+        raise click.ClickException(f"Unknown runtime: {runtime}")
+
+    # Load training config for skeleton
+    cfg_path = _find_training_config_for_predict(export_dir, metadata.model_type)
+    if cfg_path.suffix in {".yaml", ".yml"}:
+        cfg = OmegaConf.load(cfg_path.as_posix())
+    else:
+        from sleap_nn.config.training_job_config import TrainingJobConfig
+
+        cfg = TrainingJobConfig.load_sleap_config(cfg_path.as_posix())
+    skeletons = get_skeleton_from_config(cfg.data_config.skeletons)
+    skeleton = skeletons[0]
+
+    # Load video
+    video = sio.Video.from_filename(video_path.as_posix())
+    total_frames = len(video) if n_frames is None else min(n_frames, len(video))
+    frame_indices = list(range(total_frames))
+
+    click.echo(f"Loading model from: {model_path}")
+    click.echo(f"  Model type: {metadata.model_type}")
+    click.echo(f"  Runtime: {runtime}")
+    click.echo(f"  Device: {device}")
+
+    predictor = load_exported_model(
+        model_path.as_posix(), runtime=runtime, device=device
+    )
+
+    click.echo(f"Processing video: {video_path}")
+    click.echo(f"  Total frames: {total_frames}")
+    click.echo(f"  Batch size: {batch_size}")
+
+    # Set up centroid anchor node if needed
+    anchor_node_idx = None
+    if metadata.model_type == "centroid":
+        anchor_part = cfg.model_config.head_configs.centroid.confmaps.anchor_part
+        node_names = [n.name for n in skeleton.nodes]
+        if anchor_part in node_names:
+            anchor_node_idx = node_names.index(anchor_part)
+        else:
+            raise click.ClickException(
+                f"Anchor part '{anchor_part}' not found in skeleton nodes: {node_names}"
+            )
+
+    # Set up bottom-up post-processing if needed
+    paf_scorer = None
+    candidate_template = None
+    if metadata.model_type == "bottomup":
+        paf_scorer = PAFScorer.from_config(
+            cfg.model_config.head_configs.bottomup,
+            max_edge_length_ratio=max_edge_length_ratio,
+            dist_penalty_weight=dist_penalty_weight,
+            n_points=n_points,
+            min_instance_peaks=min_instance_peaks,
+            min_line_scores=min_line_scores,
+        )
+        max_peaks = metadata.max_peaks_per_node
+        if max_peaks is None:
+            raise click.ClickException(
+                "Bottom-up export metadata missing max_peaks_per_node."
+            )
+        edge_inds_tuples = [(int(e[0]), int(e[1])) for e in paf_scorer.edge_inds]
+        peak_channel_inds, edge_inds_tensor, edge_peak_inds = (
+            build_bottomup_candidate_template(
+                n_nodes=metadata.n_nodes,
+                max_peaks_per_node=max_peaks,
+                edge_inds=edge_inds_tuples,
+            )
+        )
+        candidate_template = {
+            "peak_channel_inds": peak_channel_inds,
+            "edge_inds": edge_inds_tensor,
+            "edge_peak_inds": edge_peak_inds,
+        }
+
+    labeled_frames = []
+    total_start = time.perf_counter()
+    infer_time = 0.0
+    post_time = 0.0
+
+    for start in range(0, len(frame_indices), batch_size):
+        batch_indices = frame_indices[start : start + batch_size]
+        batch = _load_video_batch(video, batch_indices)
+
+        infer_start = time.perf_counter()
+        outputs = predictor.predict(batch)
+        infer_time += time.perf_counter() - infer_start
+
+        post_start = time.perf_counter()
+        if metadata.model_type == "topdown":
+            labeled_frames.extend(
+                _predict_topdown_frames(
+                    outputs,
+                    batch_indices,
+                    video,
+                    skeleton,
+                    max_instances=max_instances,
+                )
+            )
+        elif metadata.model_type == "bottomup":
+            labeled_frames.extend(
+                _predict_bottomup_frames(
+                    outputs,
+                    batch_indices,
+                    video,
+                    skeleton,
+                    paf_scorer,
+                    candidate_template,
+                    input_scale=metadata.input_scale,
+                    peak_conf_threshold=peak_conf_threshold,
+                    max_instances=max_instances,
+                )
+            )
+        elif metadata.model_type == "single_instance":
+            labeled_frames.extend(
+                _predict_single_instance_frames(
+                    outputs,
+                    batch_indices,
+                    video,
+                    skeleton,
+                )
+            )
+        elif metadata.model_type == "centroid":
+            labeled_frames.extend(
+                _predict_centroid_frames(
+                    outputs,
+                    batch_indices,
+                    video,
+                    skeleton,
+                    anchor_node_idx=anchor_node_idx,
+                    max_instances=max_instances,
+                )
+            )
+        elif metadata.model_type == "multi_class_bottomup":
+            labeled_frames.extend(
+                _predict_multiclass_bottomup_frames(
+                    outputs,
+                    batch_indices,
+                    video,
+                    skeleton,
+                    class_names=metadata.class_names or [],
+                    input_scale=metadata.input_scale,
+                    peak_conf_threshold=peak_conf_threshold,
+                    max_instances=max_instances,
+                )
+            )
+        elif metadata.model_type == "multi_class_topdown_combined":
+            labeled_frames.extend(
+                _predict_multiclass_topdown_combined_frames(
+                    outputs,
+                    batch_indices,
+                    video,
+                    skeleton,
+                    class_names=metadata.class_names or [],
+                    max_instances=max_instances,
+                )
+            )
+        else:
+            raise click.ClickException(
+                f"Unsupported model_type for predict: {metadata.model_type}"
+            )
+        post_time += time.perf_counter() - post_start
+
+        # Progress update
+        processed = min(start + batch_size, len(frame_indices))
+        click.echo(
+            f"\r  Processed {processed}/{len(frame_indices)} frames...",
+            nl=False,
+        )
+
+    click.echo()  # Newline after progress
+
+    total_time = time.perf_counter() - total_start
+    fps = len(frame_indices) / total_time if total_time > 0 else 0
+
+    # Save predictions
+    output_path = output or video_path.with_suffix(".predictions.slp")
+    labels = sio.Labels(
+        videos=[video],
+        skeletons=[skeleton],
+        labeled_frames=labeled_frames,
+    )
+    labels.provenance = {
+        "sleap_nn_version": metadata.sleap_nn_version,
+        "export_format": runtime,
+        "model_type": metadata.model_type,
+        "inference_timestamp": datetime.now().isoformat(),
+    }
+    sio.save_file(labels, output_path.as_posix())
+
+    click.echo("\nInference complete:")
+    click.echo(f"  Total time: {total_time:.2f}s")
+    click.echo(f"  Inference time: {infer_time:.2f}s")
+    click.echo(f"  Post-processing time: {post_time:.2f}s")
+    click.echo(f"  FPS: {fps:.2f}")
+    click.echo(f"  Frames with predictions: {len(labeled_frames)}")
+    click.echo(f"  Output saved to: {output_path}")
+
+
+def _find_training_config_for_predict(export_dir: Path, model_type: str) -> Path:
+    """Find training config file in export directory."""
+    candidates = []
+    if model_type == "topdown":
+        candidates.extend(
+            [
+                export_dir / "training_config_centered_instance.yaml",
+                export_dir / "training_config_centered_instance.json",
+            ]
+        )
+    elif model_type == "multi_class_topdown_combined":
+        candidates.extend(
+            [
+                export_dir / "training_config_multi_class_topdown.yaml",
+                export_dir / "training_config_multi_class_topdown.json",
+            ]
+        )
+    candidates.extend(
+        [
+            export_dir / "training_config.yaml",
+            export_dir / "training_config.json",
+            export_dir / f"training_config_{model_type}.yaml",
+            export_dir / f"training_config_{model_type}.json",
+        ]
+    )
+
+    for candidate in candidates:
+        if candidate.exists():
+            return candidate
+    raise click.ClickException(
+        f"No training_config found in {export_dir} for model_type={model_type}."
+    )
+
+
+def _load_video_batch(video, frame_indices):
+    """Load a batch of video frames as uint8 NCHW array."""
+    import numpy as np
+
+    frames = []
+    for idx in frame_indices:
+        frame = np.asarray(video[idx])
+        if frame.ndim == 2:
+            frame = frame[:, :, None]
+        if frame.dtype != np.uint8:
+            frame = frame.astype(np.uint8)
+        frame = np.transpose(frame, (2, 0, 1))  # HWC -> CHW
+        frames.append(frame)
+    return np.stack(frames, axis=0)
+
+
+def _predict_topdown_frames(
+    outputs,
+    frame_indices,
+    video,
+    skeleton,
+    max_instances=None,
+):
+    """Convert top-down model outputs to LabeledFrames."""
+    import sleap_io as sio
+
+    labeled_frames = []
+    centroids = outputs["centroids"]
+    centroid_vals = outputs["centroid_vals"]
+    peaks = outputs["peaks"]
+    peak_vals = outputs["peak_vals"]
+    instance_valid = outputs["instance_valid"]
+
+    for batch_idx, frame_idx in enumerate(frame_indices):
+        instances = []
+        valid_mask = instance_valid[batch_idx].astype(bool)
+        for inst_idx, is_valid in enumerate(valid_mask):
+            if not is_valid:
+                continue
+            pts = peaks[batch_idx, inst_idx]
+            scores = peak_vals[batch_idx, inst_idx]
+            score = float(centroid_vals[batch_idx, inst_idx])
+            instances.append(
+                sio.PredictedInstance.from_numpy(
+                    points_data=pts,
+                    point_scores=scores,
+                    score=score,
+                    skeleton=skeleton,
+                )
+            )
+
+        if max_instances is not None and instances:
+            instances = sorted(instances, key=lambda inst: inst.score, reverse=True)
+            instances = instances[:max_instances]
+
+        if instances:
+            labeled_frames.append(
+                sio.LabeledFrame(
+                    video=video,
+                    frame_idx=int(frame_idx),
+                    instances=instances,
+                )
+            )
+
+    return labeled_frames
+
+
+def _predict_multiclass_topdown_combined_frames(
+    outputs,
+    frame_indices,
+    video,
+    skeleton,
+    class_names: list,
+    max_instances=None,
+):
+    """Convert combined multiclass top-down model outputs to LabeledFrames.
+
+    Args:
+        outputs: Model outputs with centroids, centroid_vals, peaks, peak_vals,
+            class_logits, instance_valid.
+        frame_indices: Frame indices corresponding to batch.
+        video: sleap_io.Video object.
+        skeleton: sleap_io.Skeleton object.
+        class_names: List of class names (e.g., ["female", "male"]).
+        max_instances: Maximum instances per frame (None = n_classes).
+
+    Returns:
+        List of LabeledFrame objects.
+    """
+    import numpy as np
+    import sleap_io as sio
+    from scipy.optimize import linear_sum_assignment
+
+    labeled_frames = []
+    centroids = outputs["centroids"]
+    centroid_vals = outputs["centroid_vals"]
+    peaks = outputs["peaks"]
+    peak_vals = outputs["peak_vals"]
+    class_logits = outputs["class_logits"]
+    instance_valid = outputs["instance_valid"]
+
+    n_classes = len(class_names)
+
+    for batch_idx, frame_idx in enumerate(frame_indices):
+        valid_mask = instance_valid[batch_idx].astype(bool)
+        n_valid = valid_mask.sum()
+
+        if n_valid == 0:
+            continue
+
+        # Gather valid instances
+        valid_peaks = peaks[batch_idx, valid_mask]  # (n_valid, n_nodes, 2)
+        valid_peak_vals = peak_vals[batch_idx, valid_mask]  # (n_valid, n_nodes)
+        valid_centroid_vals = centroid_vals[batch_idx, valid_mask]  # (n_valid,)
+        valid_class_logits = class_logits[batch_idx, valid_mask]  # (n_valid, n_classes)
+
+        # Compute softmax probabilities from logits
+        logits = valid_class_logits - np.max(valid_class_logits, axis=1, keepdims=True)
+        probs = np.exp(logits)
+        probs = probs / np.sum(probs, axis=1, keepdims=True)
+
+        # Use Hungarian matching to assign classes to instances
+        # Maximize total probability (minimize negative)
+        cost = -probs
+        row_inds, col_inds = linear_sum_assignment(cost)
+
+        # Create instances with class assignments
+        instances = []
+        for row_idx, class_idx in zip(row_inds, col_inds):
+            pts = valid_peaks[row_idx]
+            scores = valid_peak_vals[row_idx]
+            score = float(valid_centroid_vals[row_idx])
+
+            # Get track name from class names
+            track_name = (
+                class_names[class_idx]
+                if class_idx < len(class_names)
+                else f"class_{class_idx}"
+            )
+
+            instances.append(
+                sio.PredictedInstance.from_numpy(
+                    points_data=pts,
+                    point_scores=scores,
+                    score=score,
+                    skeleton=skeleton,
+                    track=sio.Track(name=track_name),
+                )
+            )
+
+        if max_instances is not None and instances:
+            instances = sorted(instances, key=lambda inst: inst.score, reverse=True)
+            instances = instances[:max_instances]
+
+        if instances:
+            labeled_frames.append(
+                sio.LabeledFrame(
+                    video=video,
+                    frame_idx=int(frame_idx),
+                    instances=instances,
+                )
+            )
+
+    return labeled_frames
+
+
+def _predict_bottomup_frames(
+    outputs,
+    frame_indices,
+    video,
+    skeleton,
+    paf_scorer,
+    candidate_template,
+    input_scale,
+    peak_conf_threshold=0.1,
+    max_instances=None,
+):
+    """Convert bottom-up model outputs to LabeledFrames."""
+    import numpy as np
+    import sleap_io as sio
+    import torch
+
+    labeled_frames = []
+
+    peaks = torch.from_numpy(outputs["peaks"]).to(torch.float32)
+    peak_vals = torch.from_numpy(outputs["peak_vals"]).to(torch.float32)
+    line_scores = torch.from_numpy(outputs["line_scores"]).to(torch.float32)
+    candidate_mask = torch.from_numpy(outputs["candidate_mask"]).to(torch.bool)
+
+    batch_size, n_nodes, k, _ = peaks.shape
+    peaks_flat = peaks.reshape(batch_size, n_nodes * k, 2)
+    peak_vals_flat = peak_vals.reshape(batch_size, n_nodes * k)
+
+    peak_channel_inds_base = candidate_template["peak_channel_inds"]
+    edge_inds_base = candidate_template["edge_inds"]
+    edge_peak_inds_base = candidate_template["edge_peak_inds"]
+
+    peaks_list = []
+    peak_vals_list = []
+    peak_channel_inds_list = []
+    edge_inds_list = []
+    edge_peak_inds_list = []
+    line_scores_list = []
+
+    for b in range(batch_size):
+        peaks_list.append(peaks_flat[b])
+        peak_vals_list.append(peak_vals_flat[b])
+        peak_channel_inds_list.append(peak_channel_inds_base)
+
+        candidate_mask_flat = candidate_mask[b].reshape(-1)
+        line_scores_flat = line_scores[b].reshape(-1)
+
+        if candidate_mask_flat.numel() == 0:
+            edge_inds_list.append(torch.empty((0,), dtype=torch.int32))
+            edge_peak_inds_list.append(torch.empty((0, 2), dtype=torch.int32))
+            line_scores_list.append(torch.empty((0,), dtype=torch.float32))
+            continue
+
+        # Filter candidates by peak confidence threshold
+        peak_vals_b = peak_vals_flat[b]
+        peak_conf_valid = peak_vals_b > peak_conf_threshold
+        src_valid = peak_conf_valid[edge_peak_inds_base[:, 0].long()]
+        dst_valid = peak_conf_valid[edge_peak_inds_base[:, 1].long()]
+        valid = candidate_mask_flat & src_valid & dst_valid
+
+        edge_inds_list.append(edge_inds_base[valid])
+        edge_peak_inds_list.append(edge_peak_inds_base[valid])
+        line_scores_list.append(line_scores_flat[valid])
+
+    (
+        match_edge_inds,
+        match_src_peak_inds,
+        match_dst_peak_inds,
+        match_line_scores,
+    ) = paf_scorer.match_candidates(
+        edge_inds_list,
+        edge_peak_inds_list,
+        line_scores_list,
+    )
+
+    (
+        predicted_instances,
+        predicted_peak_scores,
+        predicted_instance_scores,
+    ) = paf_scorer.group_instances(
+        peaks_list,
+        peak_vals_list,
+        peak_channel_inds_list,
+        match_edge_inds,
+        match_src_peak_inds,
+        match_dst_peak_inds,
+        match_line_scores,
+    )
+
+    predicted_instances = [p / input_scale for p in predicted_instances]
+
+    for batch_idx, frame_idx in enumerate(frame_indices):
+        instances = []
+        for pts, confs, score in zip(
+            predicted_instances[batch_idx],
+            predicted_peak_scores[batch_idx],
+            predicted_instance_scores[batch_idx],
+        ):
+            pts_np = pts.cpu().numpy()
+            if np.isnan(pts_np).all():
+                continue
+            instances.append(
+                sio.PredictedInstance.from_numpy(
+                    points_data=pts_np,
+                    point_scores=confs.cpu().numpy(),
+                    score=float(score),
+                    skeleton=skeleton,
+                )
+            )
+
+        if max_instances is not None and instances:
+            instances = sorted(instances, key=lambda inst: inst.score, reverse=True)
+            instances = instances[:max_instances]
+
+        if instances:
+            labeled_frames.append(
+                sio.LabeledFrame(
+                    video=video,
+                    frame_idx=int(frame_idx),
+                    instances=instances,
+                )
+            )
+
+    return labeled_frames
+
+
+def _predict_single_instance_frames(
+    outputs,
+    frame_indices,
+    video,
+    skeleton,
+):
+    """Convert single-instance model outputs to LabeledFrames."""
+    import numpy as np
+    import sleap_io as sio
+
+    labeled_frames = []
+    peaks = outputs["peaks"]  # (batch, n_nodes, 2)
+    peak_vals = outputs["peak_vals"]  # (batch, n_nodes)
+
+    for batch_idx, frame_idx in enumerate(frame_indices):
+        pts = peaks[batch_idx]
+        scores = peak_vals[batch_idx]
+
+        # Compute instance score as mean of valid peak values
+        valid_mask = ~np.isnan(pts[:, 0])
+        if valid_mask.any():
+            instance_score = float(np.mean(scores[valid_mask]))
+        else:
+            instance_score = 0.0
+
+        instance = sio.PredictedInstance.from_numpy(
+            points_data=pts,
+            point_scores=scores,
+            score=instance_score,
+            skeleton=skeleton,
+        )
+
+        labeled_frames.append(
+            sio.LabeledFrame(
+                video=video,
+                frame_idx=int(frame_idx),
+                instances=[instance],
+            )
+        )
+
+    return labeled_frames
+
+
+def _predict_centroid_frames(
+    outputs,
+    frame_indices,
+    video,
+    skeleton,
+    anchor_node_idx: int,
+    max_instances=None,
+):
+    """Convert centroid model outputs to LabeledFrames.
+
+    For centroid-only models, creates instances with only the anchor node filled in.
+    All other nodes are set to NaN.
+
+    Args:
+        outputs: Model outputs with centroids, centroid_vals, instance_valid.
+        frame_indices: Frame indices corresponding to batch.
+        video: sleap_io.Video object.
+        skeleton: sleap_io.Skeleton object.
+        anchor_node_idx: Index of the anchor node in the skeleton.
+        max_instances: Maximum instances to output per frame.
+
+    Returns:
+        List of LabeledFrame objects.
+    """
+    import numpy as np
+    import sleap_io as sio
+
+    labeled_frames = []
+    centroids = outputs["centroids"]  # (batch, max_instances, 2)
+    centroid_vals = outputs["centroid_vals"]  # (batch, max_instances)
+    instance_valid = outputs["instance_valid"]  # (batch, max_instances)
+
+    n_nodes = len(skeleton.nodes)
+
+    for batch_idx, frame_idx in enumerate(frame_indices):
+        instances = []
+        valid_mask = instance_valid[batch_idx].astype(bool)
+
+        for inst_idx, is_valid in enumerate(valid_mask):
+            if not is_valid:
+                continue
+
+            # Create points array with NaN for all nodes except anchor
+            pts = np.full((n_nodes, 2), np.nan, dtype=np.float32)
+            pts[anchor_node_idx] = centroids[batch_idx, inst_idx]
+
+            # Create scores array - anchor gets centroid score, others get NaN
+            scores = np.full((n_nodes,), np.nan, dtype=np.float32)
+            scores[anchor_node_idx] = centroid_vals[batch_idx, inst_idx]
+
+            instance_score = float(centroid_vals[batch_idx, inst_idx])
+
+            instances.append(
+                sio.PredictedInstance.from_numpy(
+                    points_data=pts,
+                    point_scores=scores,
+                    score=instance_score,
+                    skeleton=skeleton,
+                )
+            )
+
+        if max_instances is not None and instances:
+            instances = sorted(instances, key=lambda inst: inst.score, reverse=True)
+            instances = instances[:max_instances]
+
+        if instances:
+            labeled_frames.append(
+                sio.LabeledFrame(
+                    video=video,
+                    frame_idx=int(frame_idx),
+                    instances=instances,
+                )
+            )
+
+    return labeled_frames
+
+
+def _predict_multiclass_bottomup_frames(
+    outputs,
+    frame_indices,
+    video,
+    skeleton,
+    class_names: list,
+    input_scale: float = 1.0,
+    peak_conf_threshold: float = 0.1,
+    max_instances: Optional[int] = None,
+):
+    """Convert bottom-up multiclass model outputs to LabeledFrames.
+
+    Uses class probability maps to group peaks by identity rather than PAFs.
+
+    Args:
+        outputs: Model outputs with peaks, peak_vals, peak_mask, class_probs.
+        frame_indices: Frame indices corresponding to batch.
+        video: sleap_io.Video object.
+        skeleton: sleap_io.Skeleton object.
+        class_names: List of class names (e.g., ["female", "male"]).
+        input_scale: Scale factor applied to input.
+        peak_conf_threshold: Minimum peak confidence to include.
+        max_instances: Maximum instances per frame (None = n_classes).
+
+    Returns:
+        List of LabeledFrame objects.
+    """
+    import numpy as np
+    import sleap_io as sio
+    from scipy.optimize import linear_sum_assignment
+
+    labeled_frames = []
+    n_classes = len(class_names)
+
+    peaks = outputs["peaks"]  # (batch, n_nodes, max_peaks, 2)
+    peak_vals = outputs["peak_vals"]  # (batch, n_nodes, max_peaks)
+    peak_mask = outputs["peak_mask"]  # (batch, n_nodes, max_peaks)
+    class_probs = outputs["class_probs"]  # (batch, n_nodes, max_peaks, n_classes)
+
+    batch_size, n_nodes, max_peaks, _ = peaks.shape
+    n_nodes_skel = len(skeleton.nodes)
+
+    for batch_idx, frame_idx in enumerate(frame_indices):
+        # Initialize instances for each class
+        instance_points = np.full(
+            (n_classes, n_nodes_skel, 2), np.nan, dtype=np.float32
+        )
+        instance_scores = np.full((n_classes, n_nodes_skel), np.nan, dtype=np.float32)
+        instance_class_probs = np.full((n_classes,), 0.0, dtype=np.float32)
+
+        # Process each node independently
+        for node_idx in range(min(n_nodes, n_nodes_skel)):
+            # Get valid peaks for this node
+            valid = peak_mask[batch_idx, node_idx].astype(bool)
+            valid = valid & (peak_vals[batch_idx, node_idx] > peak_conf_threshold)
+
+            if not valid.any():
+                continue
+
+            valid_peaks = peaks[batch_idx, node_idx][valid]  # (n_valid, 2)
+            valid_vals = peak_vals[batch_idx, node_idx][valid]  # (n_valid,)
+            valid_class_probs = class_probs[batch_idx, node_idx][
+                valid
+            ]  # (n_valid, n_classes)
+
+            # Use Hungarian matching to assign peaks to classes
+            # Maximize class probabilities (minimize negative)
+            cost = -valid_class_probs
+            row_inds, col_inds = linear_sum_assignment(cost)
+
+            # Assign matched peaks to instances
+            for peak_idx, class_idx in zip(row_inds, col_inds):
+                if class_idx < n_classes:
+                    instance_points[class_idx, node_idx] = (
+                        valid_peaks[peak_idx] / input_scale
+                    )
+                    instance_scores[class_idx, node_idx] = valid_vals[peak_idx]
+                    instance_class_probs[class_idx] += valid_class_probs[
+                        peak_idx, class_idx
+                    ]
+
+        # Create predicted instances
+        instances = []
+        for class_idx in range(n_classes):
+            pts = instance_points[class_idx]
+            scores = instance_scores[class_idx]
+
+            # Skip if no valid points
+            if np.isnan(pts).all():
+                continue
+
+            # Compute instance score as mean of valid peak values
+            valid_mask = ~np.isnan(pts[:, 0])
+            if valid_mask.any():
+                instance_score = float(np.mean(scores[valid_mask]))
+            else:
+                instance_score = 0.0
+
+            # Get track name from class names
+            track_name = (
+                class_names[class_idx]
+                if class_idx < len(class_names)
+                else f"class_{class_idx}"
+            )
+
+            instances.append(
+                sio.PredictedInstance.from_numpy(
+                    points_data=pts,
+                    point_scores=scores,
+                    score=instance_score,
+                    skeleton=skeleton,
+                    track=sio.Track(name=track_name),
+                )
+            )
+
+        if max_instances is not None and instances:
+            instances = sorted(instances, key=lambda inst: inst.score, reverse=True)
+            instances = instances[:max_instances]
+
+        if instances:
+            labeled_frames.append(
+                sio.LabeledFrame(
+                    video=video,
+                    frame_idx=int(frame_idx),
+                    instances=instances,
+                )
+            )
+
+    return labeled_frames
+
+
+def _copy_training_config(
+    model_path: Path, export_dir: Path, label: Optional[str]
+) -> Optional[Path]:
+    training_config_path = _training_config_path(model_path)
+    if training_config_path is None:
+        return None
+
+    if label:
+        dest_name = f"training_config_{label}{training_config_path.suffix}"
+    else:
+        dest_name = training_config_path.name
+
+    dest_path = export_dir / dest_name
+    shutil.copy(training_config_path, dest_path)
+    return dest_path
+
+
+def _training_config_path(model_path: Path) -> Optional[Path]:
+    yaml_path = model_path / "training_config.yaml"
+    json_path = model_path / "training_config.json"
+    if yaml_path.exists():
+        return yaml_path
+    if json_path.exists():
+        return json_path
+    return None
+
+
+def _load_lightning_model(
+    *,
+    model_type: str,
+    backbone_type: str,
+    cfg,
+    ckpt_path: Path,
+    device: str,
+):
+    lightning_cls = {
+        "centroid": CentroidLightningModule,
+        "centered_instance": TopDownCenteredInstanceLightningModule,
+        "single_instance": SingleInstanceLightningModule,
+        "bottomup": BottomUpLightningModule,
+        "multi_class_topdown": TopDownCenteredInstanceMultiClassLightningModule,
+        "multi_class_bottomup": BottomUpMultiClassLightningModule,
+    }.get(model_type)
+
+    if lightning_cls is None:
+        raise click.ClickException(f"Unsupported model type: {model_type}")
+
+    return lightning_cls.load_from_checkpoint(
+        checkpoint_path=str(ckpt_path),
+        model_type=model_type,
+        backbone_type=backbone_type,
+        backbone_config=cfg.model_config.backbone_config,
+        head_configs=cfg.model_config.head_configs,
+        pretrained_backbone_weights=cfg.model_config.pretrained_backbone_weights,
+        pretrained_head_weights=cfg.model_config.pretrained_head_weights,
+        init_weights=cfg.model_config.init_weights,
+        lr_scheduler=cfg.trainer_config.lr_scheduler,
+        online_mining=cfg.trainer_config.online_hard_keypoint_mining.online_mining,
+        hard_to_easy_ratio=cfg.trainer_config.online_hard_keypoint_mining.hard_to_easy_ratio,
+        min_hard_keypoints=cfg.trainer_config.online_hard_keypoint_mining.min_hard_keypoints,
+        max_hard_keypoints=cfg.trainer_config.online_hard_keypoint_mining.max_hard_keypoints,
+        loss_scale=cfg.trainer_config.online_hard_keypoint_mining.loss_scale,
+        optimizer=cfg.trainer_config.optimizer_name,
+        learning_rate=cfg.trainer_config.optimizer.lr,
+        amsgrad=cfg.trainer_config.optimizer.amsgrad,
+        map_location=device,
+        weights_only=False,
+    )
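
For reference, the new commands can be exercised end to end. The following is a minimal sketch, not part of the package diff: the model and video paths are placeholders, driving the commands through click's CliRunner (rather than the installed console script) is an illustration choice, and the output shape noted in the last comment is inferred from the post-processing helpers above.

import numpy as np
from click.testing import CliRunner

from sleap_nn.export.cli import export, predict
from sleap_nn.export.predictors import load_exported_model

runner = CliRunner()

# Export a top-down pair (centroid + centered_instance) to ONNX.
# Both directories are placeholders and must contain best.ckpt and
# training_config.yaml/.json from training.
result = runner.invoke(
    export,
    [
        "models/centroid_model",
        "models/centered_instance_model",
        "--format", "onnx",
        "--output", "models/exported_topdown",
        "--max-instances", "10",
    ],
)
assert result.exit_code == 0, result.output

# Run the exported model on a video; writes video.predictions.slp by default.
result = runner.invoke(
    predict,
    [
        "models/exported_topdown",
        "video.mp4",  # placeholder path
        "--runtime", "onnx",
        "--batch-size", "4",
        "--n-frames", "100",
    ],
)
assert result.exit_code == 0, result.output

# The runtime predictors can also be driven directly. Inputs are uint8 NCHW
# batches, matching _load_video_batch above.
predictor = load_exported_model(
    "models/exported_topdown/model.onnx", runtime="onnx", device="cpu"
)
batch = np.zeros((4, 1, 384, 384), dtype=np.uint8)  # placeholder frames
outputs = predictor.predict(batch)
# For the top-down wrapper, outputs include "centroids", "centroid_vals",
# "peaks", "peak_vals", and "instance_valid"; peaks is inferred to be
# (batch, max_instances, n_nodes, 2) from _predict_topdown_frames.
print(outputs["peaks"].shape)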