sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/predict.py CHANGED
@@ -15,6 +15,11 @@ from sleap_nn.tracking.tracker import (
15
15
  connect_single_breaks,
16
16
  cull_instances,
17
17
  )
18
+ from sleap_nn.system_info import get_startup_info_string
19
+ from sleap_nn.inference.provenance import (
20
+ build_inference_provenance,
21
+ build_tracking_only_provenance,
22
+ )
18
23
  from omegaconf import OmegaConf
19
24
  import sleap_io as sio
20
25
  from pathlib import Path
@@ -58,15 +63,20 @@ def run_inference(
58
63
  anchor_part: Optional[str] = None,
59
64
  only_labeled_frames: bool = False,
60
65
  only_suggested_frames: bool = False,
66
+ exclude_user_labeled: bool = False,
67
+ only_predicted_frames: bool = False,
61
68
  no_empty_frames: bool = False,
62
69
  batch_size: int = 4,
63
- queue_maxsize: int = 8,
70
+ queue_maxsize: int = 32,
64
71
  video_index: Optional[int] = None,
65
72
  video_dataset: Optional[str] = None,
66
73
  video_input_format: str = "channels_last",
67
74
  frames: Optional[list] = None,
68
75
  crop_size: Optional[int] = None,
69
76
  peak_threshold: Union[float, List[float]] = 0.2,
77
+ filter_overlapping: bool = False,
78
+ filter_overlapping_method: str = "iou",
79
+ filter_overlapping_threshold: float = 0.8,
70
80
  integral_refinement: Optional[str] = "integral",
71
81
  integral_patch_size: int = 5,
72
82
  return_confmaps: bool = False,
@@ -103,6 +113,7 @@ def run_inference(
103
113
  tracking_pre_cull_iou_threshold: float = 0,
104
114
  tracking_clean_instance_count: int = 0,
105
115
  tracking_clean_iou_threshold: float = 0,
116
+ gui: bool = False,
106
117
  ):
107
118
  """Entry point to run inference on trained SLEAP-NN models.
108
119
 
@@ -136,21 +147,32 @@ def run_inference(
136
147
  provided, the anchor part in the `training_config.yaml` is used. Default: `None`.
137
148
  only_labeled_frames: (bool) `True` if inference should be run only on user-labeled frames. Default: `False`.
138
149
  only_suggested_frames: (bool) `True` if inference should be run only on unlabeled suggested frames. Default: `False`.
150
+ exclude_user_labeled: (bool) `True` to skip frames that have user-labeled instances. Default: `False`.
151
+ only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
139
152
  no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
140
153
  batch_size: (int) Number of samples per batch. Default: 4.
141
- queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
154
+ queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
142
155
  video_index: (int) Integer index of video in .slp file to predict on. To be used with
143
156
  an .slp path as an alternative to specifying the video path.
144
157
  video_dataset: (str) The dataset for HDF5 videos.
145
158
  video_input_format: (str) The input_format for HDF5 videos.
146
159
  frames: (list) List of frames indices. If `None`, all frames in the video are used. Default: None.
147
160
  crop_size: (int) Crop size. If not provided, the crop size from training_config.yaml is used.
148
- Default: None.
161
+ If `input_scale` is provided, then the cropped image will be resized according to `input_scale`. Default: None.
149
162
  peak_threshold: (float) Minimum confidence threshold. Peaks with values below
150
163
  this will be ignored. Default: 0.2. This can also be `List[float]` for topdown
151
164
  centroid and centered-instance model, where the first element corresponds
152
165
  to centroid model peak finding threshold and the second element is for
153
166
  centered-instance model peak finding.
167
+ filter_overlapping: (bool) If True, removes overlapping instances after
168
+ inference using greedy NMS. Applied independently of tracking.
169
+ Default: False.
170
+ filter_overlapping_method: (str) Similarity metric for filtering overlapping
171
+ instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
172
+ Default: "iou".
173
+ filter_overlapping_threshold: (float) Similarity threshold for filtering.
174
+ Instances with similarity > threshold are removed (keeping higher-scoring).
175
+ Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
154
176
  integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
155
177
  If `"integral"`, peaks will be refined with integral regression.
156
178
  Default: `"integral"`.
@@ -241,6 +263,8 @@ def run_inference(
241
263
  tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
242
264
  tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
243
265
  tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
266
+ gui: (bool) If True, outputs JSON progress lines for GUI integration instead
267
+ of Rich progress bars. Default: False.
244
268
 
245
269
  Returns:
246
270
  Returns `sio.Labels` object if `make_labels` is True. Else this function returns
@@ -256,6 +280,27 @@ def run_inference(
256
280
  "scale": input_scale,
257
281
  }
258
282
 
283
+ # Validate mutually exclusive frame filter flags
284
+ if only_labeled_frames and exclude_user_labeled:
285
+ message = (
286
+ "--only_labeled_frames and --exclude_user_labeled are mutually exclusive "
287
+ "(would result in zero frames)"
288
+ )
289
+ logger.error(message)
290
+ raise ValueError(message)
291
+
292
+ if (
293
+ only_predicted_frames
294
+ and data_path is not None
295
+ and not data_path.endswith(".slp")
296
+ ):
297
+ message = (
298
+ "--only_predicted_frames requires a .slp file input "
299
+ "(need Labels to know which frames have predictions)"
300
+ )
301
+ logger.error(message)
302
+ raise ValueError(message)
303
+
259
304
  if model_paths is None or not len(
260
305
  model_paths
261
306
  ): # if model paths is not provided, run tracking-only pipeline.
@@ -273,7 +318,8 @@ def run_inference(
273
318
  raise ValueError(message)
274
319
 
275
320
  start_inf_time = time()
276
- start_timestamp = str(datetime.now())
321
+ start_datetime = datetime.now()
322
+ start_timestamp = str(start_datetime)
277
323
  logger.info(f"Started tracking at: {start_timestamp}")
278
324
 
279
325
  labels = sio.load_slp(data_path) if input_labels is None else input_labels
@@ -302,7 +348,22 @@ def run_inference(
302
348
 
303
349
  if post_connect_single_breaks or tracking_pre_cull_to_target:
304
350
  if tracking_target_instance_count is None and max_instances is None:
305
- message = "Both tracking_target_instance_count and max_instances is set to 0. To connect single breaks or pre-cull to target, at least one of them should be set to an integer."
351
+ features_requested = []
352
+ if post_connect_single_breaks:
353
+ features_requested.append("--post_connect_single_breaks")
354
+ if tracking_pre_cull_to_target:
355
+ features_requested.append("--tracking_pre_cull_to_target")
356
+ features_str = " and ".join(features_requested)
357
+
358
+ if max_tracks is not None:
359
+ suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
360
+ else:
361
+ suggestion = "Add --tracking_target_instance_count N where N is the expected number of instances per frame."
362
+
363
+ message = (
364
+ f"{features_str} requires --tracking_target_instance_count to be set. "
365
+ f"{suggestion}"
366
+ )
306
367
  logger.error(message)
307
368
  raise ValueError(message)
308
369
  elif tracking_target_instance_count is None:
@@ -332,21 +393,53 @@ def run_inference(
332
393
  tracking_clean_iou_threshold=tracking_clean_iou_threshold,
333
394
  )
334
395
 
335
- finish_timestamp = str(datetime.now())
396
+ end_datetime = datetime.now()
397
+ finish_timestamp = str(end_datetime)
336
398
  total_elapsed = time() - start_inf_time
337
399
  logger.info(f"Finished tracking at: {finish_timestamp}")
338
400
  logger.info(f"Total runtime: {total_elapsed} secs")
339
401
 
402
+ # Build tracking-only provenance
403
+ tracking_params = {
404
+ "window_size": tracking_window_size,
405
+ "min_new_track_points": min_new_track_points,
406
+ "candidates_method": candidates_method,
407
+ "min_match_points": min_match_points,
408
+ "features": features,
409
+ "scoring_method": scoring_method,
410
+ "scoring_reduction": scoring_reduction,
411
+ "robust_best_instance": robust_best_instance,
412
+ "track_matching_method": track_matching_method,
413
+ "max_tracks": max_tracks,
414
+ "use_flow": use_flow,
415
+ "post_connect_single_breaks": post_connect_single_breaks,
416
+ }
417
+ provenance = build_tracking_only_provenance(
418
+ input_labels=labels,
419
+ input_path=data_path,
420
+ start_time=start_datetime,
421
+ end_time=end_datetime,
422
+ tracking_params=tracking_params,
423
+ frames_processed=len(tracked_frames),
424
+ )
425
+
340
426
  output = sio.Labels(
341
427
  labeled_frames=tracked_frames,
342
428
  videos=labels.videos,
343
429
  skeletons=labels.skeletons,
430
+ provenance=provenance,
344
431
  )
345
432
 
346
433
  else:
347
434
  start_inf_time = time()
348
- start_timestamp = str(datetime.now())
435
+ start_datetime = datetime.now()
436
+ start_timestamp = str(start_datetime)
349
437
  logger.info(f"Started inference at: {start_timestamp}")
438
+ logger.info(get_startup_info_string())
439
+
440
+ # Convert device to string if it's a torch.device object
441
+ if hasattr(device, "type"):
442
+ device = str(device)
350
443
 
351
444
  if device == "auto":
352
445
  device = (
@@ -355,13 +448,6 @@ def run_inference(
355
448
  else "mps" if torch.backends.mps.is_available() else "cpu"
356
449
  )
357
450
 
358
- if integral_refinement is not None and device == "mps": # TODO
359
- # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
360
- logger.info(
361
- "Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
362
- )
363
- integral_refinement = None
364
-
365
451
  logger.info(f"Using device: {device}")
366
452
 
367
453
  # initializes the inference model
@@ -380,6 +466,9 @@ def run_inference(
380
466
  anchor_part=anchor_part,
381
467
  )
382
468
 
469
+ # Set GUI mode for progress output
470
+ predictor.gui = gui
471
+
383
472
  if (
384
473
  tracking
385
474
  and not isinstance(predictor, BottomUpMultiClassPredictor)
@@ -387,7 +476,22 @@ def run_inference(
387
476
  ):
388
477
  if post_connect_single_breaks or tracking_pre_cull_to_target:
389
478
  if tracking_target_instance_count is None and max_instances is None:
390
- message = "Both tracking_target_instance_count and max_instances is set to 0. To connect single breaks or pre-cull to target, at least one of them should be set to an integer."
479
+ features_requested = []
480
+ if post_connect_single_breaks:
481
+ features_requested.append("--post_connect_single_breaks")
482
+ if tracking_pre_cull_to_target:
483
+ features_requested.append("--tracking_pre_cull_to_target")
484
+ features_str = " and ".join(features_requested)
485
+
486
+ if max_tracks is not None:
487
+ suggestion = f"Add --tracking_target_instance_count {max_tracks} to your command (using your --max_tracks value)."
488
+ else:
489
+ suggestion = "Add --tracking_target_instance_count N or --max_instances N where N is the expected number of instances per frame."
490
+
491
+ message = (
492
+ f"{features_str} requires --tracking_target_instance_count or --max_instances to be set. "
493
+ f"{suggestion}"
494
+ )
391
495
  logger.error(message)
392
496
  raise ValueError(message)
393
497
  elif tracking_target_instance_count is None:
@@ -448,6 +552,8 @@ def run_inference(
448
552
  frames=frames,
449
553
  only_labeled_frames=only_labeled_frames,
450
554
  only_suggested_frames=only_suggested_frames,
555
+ exclude_user_labeled=exclude_user_labeled,
556
+ only_predicted_frames=only_predicted_frames,
451
557
  video_index=video_index,
452
558
  video_dataset=video_dataset,
453
559
  video_input_format=video_input_format,
@@ -458,6 +564,20 @@ def run_inference(
458
564
  make_labels=make_labels,
459
565
  )
460
566
 
567
+ # Filter overlapping instances (independent of tracking)
568
+ if filter_overlapping and make_labels:
569
+ from sleap_nn.inference.postprocessing import filter_overlapping_instances
570
+
571
+ output = filter_overlapping_instances(
572
+ output,
573
+ threshold=filter_overlapping_threshold,
574
+ method=filter_overlapping_method,
575
+ )
576
+ logger.info(
577
+ f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
578
+ f"threshold: {filter_overlapping_threshold}"
579
+ )
580
+
461
581
  if tracking:
462
582
  lfs = [x for x in output]
463
583
  if tracking_clean_instance_count > 0:
@@ -492,12 +612,97 @@ def run_inference(
492
612
  skeletons=output.skeletons,
493
613
  )
494
614
 
495
- finish_timestamp = str(datetime.now())
615
+ end_datetime = datetime.now()
616
+ finish_timestamp = str(end_datetime)
496
617
  total_elapsed = time() - start_inf_time
497
618
  logger.info(f"Finished inference at: {finish_timestamp}")
498
- logger.info(
499
- f"Total runtime: {total_elapsed} secs"
500
- ) # TODO: add number of predicted frames
619
+ logger.info(f"Total runtime: {total_elapsed} secs")
620
+
621
+ # Determine input labels for provenance preservation
622
+ input_labels_for_prov = None
623
+ if input_labels is not None:
624
+ input_labels_for_prov = input_labels
625
+ elif data_path is not None and data_path.endswith(".slp"):
626
+ # Load input labels to preserve provenance (if not already loaded)
627
+ try:
628
+ input_labels_for_prov = sio.load_slp(data_path)
629
+ except Exception:
630
+ pass
631
+
632
+ # Build inference parameters for provenance
633
+ inference_params = {
634
+ "peak_threshold": peak_threshold,
635
+ "filter_overlapping": filter_overlapping,
636
+ "filter_overlapping_method": filter_overlapping_method,
637
+ "filter_overlapping_threshold": filter_overlapping_threshold,
638
+ "integral_refinement": integral_refinement,
639
+ "integral_patch_size": integral_patch_size,
640
+ "batch_size": batch_size,
641
+ "max_instances": max_instances,
642
+ "crop_size": crop_size,
643
+ "input_scale": input_scale,
644
+ "anchor_part": anchor_part,
645
+ }
646
+
647
+ # Build tracking parameters if tracking was enabled
648
+ tracking_params_prov = None
649
+ if tracking:
650
+ tracking_params_prov = {
651
+ "window_size": tracking_window_size,
652
+ "min_new_track_points": min_new_track_points,
653
+ "candidates_method": candidates_method,
654
+ "min_match_points": min_match_points,
655
+ "features": features,
656
+ "scoring_method": scoring_method,
657
+ "scoring_reduction": scoring_reduction,
658
+ "robust_best_instance": robust_best_instance,
659
+ "track_matching_method": track_matching_method,
660
+ "max_tracks": max_tracks,
661
+ "use_flow": use_flow,
662
+ "post_connect_single_breaks": post_connect_single_breaks,
663
+ }
664
+
665
+ # Determine frame selection method
666
+ frame_selection_method = "all"
667
+ if only_labeled_frames:
668
+ frame_selection_method = "labeled"
669
+ elif only_suggested_frames:
670
+ frame_selection_method = "suggested"
671
+ elif only_predicted_frames:
672
+ frame_selection_method = "predicted"
673
+ elif frames is not None:
674
+ frame_selection_method = "specified"
675
+
676
+ # Determine model type from predictor class
677
+ predictor_type_map = {
678
+ "TopDownPredictor": "top_down",
679
+ "SingleInstancePredictor": "single_instance",
680
+ "BottomUpPredictor": "bottom_up",
681
+ "BottomUpMultiClassPredictor": "bottom_up_multi_class",
682
+ "TopDownMultiClassPredictor": "top_down_multi_class",
683
+ }
684
+ model_type = predictor_type_map.get(type(predictor).__name__)
685
+
686
+ # Build and set provenance (only for Labels objects)
687
+ if make_labels and isinstance(output, sio.Labels):
688
+ provenance = build_inference_provenance(
689
+ model_paths=model_paths,
690
+ model_type=model_type,
691
+ start_time=start_datetime,
692
+ end_time=end_datetime,
693
+ input_labels=input_labels_for_prov,
694
+ input_path=data_path,
695
+ frames_processed=(
696
+ len(output.labeled_frames)
697
+ if hasattr(output, "labeled_frames")
698
+ else None
699
+ ),
700
+ frame_selection_method=frame_selection_method,
701
+ inference_params=inference_params,
702
+ tracking_params=tracking_params_prov,
703
+ device=device,
704
+ )
705
+ output.provenance = provenance
501
706
 
502
707
  if no_empty_frames:
503
708
  output.clean(frames=True, skeletons=False)