sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +168 -39
  6. sleap_nn/evaluation.py +8 -0
  7. sleap_nn/export/__init__.py +21 -0
  8. sleap_nn/export/cli.py +1778 -0
  9. sleap_nn/export/exporters/__init__.py +51 -0
  10. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  11. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  12. sleap_nn/export/metadata.py +225 -0
  13. sleap_nn/export/predictors/__init__.py +63 -0
  14. sleap_nn/export/predictors/base.py +22 -0
  15. sleap_nn/export/predictors/onnx.py +154 -0
  16. sleap_nn/export/predictors/tensorrt.py +312 -0
  17. sleap_nn/export/utils.py +307 -0
  18. sleap_nn/export/wrappers/__init__.py +25 -0
  19. sleap_nn/export/wrappers/base.py +96 -0
  20. sleap_nn/export/wrappers/bottomup.py +243 -0
  21. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  22. sleap_nn/export/wrappers/centered_instance.py +56 -0
  23. sleap_nn/export/wrappers/centroid.py +58 -0
  24. sleap_nn/export/wrappers/single_instance.py +83 -0
  25. sleap_nn/export/wrappers/topdown.py +180 -0
  26. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  27. sleap_nn/inference/peak_finding.py +47 -17
  28. sleap_nn/inference/postprocessing.py +284 -0
  29. sleap_nn/inference/predictors.py +213 -106
  30. sleap_nn/predict.py +35 -7
  31. sleap_nn/train.py +64 -0
  32. sleap_nn/training/callbacks.py +69 -22
  33. sleap_nn/training/lightning_modules.py +332 -30
  34. sleap_nn/training/model_trainer.py +67 -67
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
  36. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
  37. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
  38. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
  39. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
  40. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
@@ -85,10 +85,15 @@ class CSVLoggerCallback(Callback):
85
85
  if key == "epoch":
86
86
  log_data["epoch"] = trainer.current_epoch
87
87
  elif key == "learning_rate":
88
- # Handle both direct logging and LearningRateMonitor format (lr-*)
88
+ # Handle multiple formats:
89
+ # 1. Direct "learning_rate" key
90
+ # 2. "train/lr" key (current format from lightning modules)
91
+ # 3. "lr-*" keys from LearningRateMonitor (legacy)
89
92
  value = metrics.get(key, None)
90
93
  if value is None:
91
- # Look for lr-* keys from LearningRateMonitor
94
+ value = metrics.get("train/lr", None)
95
+ if value is None:
96
+ # Look for lr-* keys from LearningRateMonitor (legacy)
92
97
  for metric_key in metrics.keys():
93
98
  if metric_key.startswith("lr-"):
94
99
  value = metrics[metric_key]
@@ -295,8 +300,8 @@ class WandBVizCallback(Callback):
295
300
  suffix = "" if mode_name == "direct" else f"_{mode_name}"
296
301
  train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
297
302
  val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
298
- log_dict[f"train_predictions{suffix}"] = train_img
299
- log_dict[f"val_predictions{suffix}"] = val_img
303
+ log_dict[f"viz/train/predictions{suffix}"] = train_img
304
+ log_dict[f"viz/val/predictions{suffix}"] = val_img
300
305
 
301
306
  if log_dict:
302
307
  # Include epoch so wandb can use it as x-axis (via define_metric)
@@ -394,8 +399,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
394
399
  suffix = "" if mode_name == "direct" else f"_{mode_name}"
395
400
  train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
396
401
  val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
397
- log_dict[f"train_predictions{suffix}"] = train_img
398
- log_dict[f"val_predictions{suffix}"] = val_img
402
+ log_dict[f"viz/train/predictions{suffix}"] = train_img
403
+ log_dict[f"viz/val/predictions{suffix}"] = val_img
399
404
 
400
405
  # Render PAFs (always use matplotlib/direct for PAFs)
401
406
  from io import BytesIO
@@ -408,7 +413,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
408
413
  buf.seek(0)
409
414
  plt.close(train_pafs_fig)
410
415
  train_pafs_pil = Image.open(buf)
411
- log_dict["train_pafs"] = wandb.Image(
416
+ log_dict["viz/train/pafs"] = wandb.Image(
412
417
  train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
413
418
  )
414
419
 
@@ -418,7 +423,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
418
423
  buf.seek(0)
419
424
  plt.close(val_pafs_fig)
420
425
  val_pafs_pil = Image.open(buf)
421
- log_dict["val_pafs"] = wandb.Image(
426
+ log_dict["viz/val/pafs"] = wandb.Image(
422
427
  val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
423
428
  )
424
429
 
@@ -444,8 +449,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
444
449
  epoch,
445
450
  train_img,
446
451
  val_img,
447
- log_dict["train_pafs"],
448
- log_dict["val_pafs"],
452
+ log_dict["viz/train/pafs"],
453
+ log_dict["viz/val/pafs"],
449
454
  ]
450
455
  ],
451
456
  )
@@ -709,9 +714,13 @@ class EpochEndEvaluationCallback(Callback):
709
714
  "mOKS",
710
715
  "oks_voc.mAP",
711
716
  "oks_voc.mAR",
712
- "avg_distance",
713
- "p50_distance",
717
+ "distance/avg",
718
+ "distance/p50",
719
+ "distance/p95",
720
+ "distance/p99",
714
721
  "mPCK",
722
+ "PCK@5",
723
+ "PCK@10",
715
724
  "visibility_precision",
716
725
  "visibility_recall",
717
726
  ]
@@ -779,6 +788,7 @@ class EpochEndEvaluationCallback(Callback):
779
788
 
780
789
  logger.info(
781
790
  f"Epoch {trainer.current_epoch} evaluation: "
791
+ f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
782
792
  f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
783
793
  f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
784
794
  )
@@ -903,36 +913,73 @@ class EpochEndEvaluationCallback(Callback):
903
913
  log_dict = {"epoch": epoch}
904
914
 
905
915
  # Extract key metrics with consistent naming
916
+ # All eval metrics use eval/val/ prefix since they're computed on validation data
906
917
  if "mOKS" in self.metrics_to_log:
907
- log_dict["val_mOKS"] = metrics["mOKS"]["mOKS"]
918
+ log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]
908
919
 
909
920
  if "oks_voc.mAP" in self.metrics_to_log:
910
- log_dict["val_oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
921
+ log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
911
922
 
912
923
  if "oks_voc.mAR" in self.metrics_to_log:
913
- log_dict["val_oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
924
+ log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
914
925
 
915
- if "avg_distance" in self.metrics_to_log:
926
+ # Distance metrics grouped under eval/val/distance/
927
+ if "distance/avg" in self.metrics_to_log:
916
928
  val = metrics["distance_metrics"]["avg"]
917
929
  if not np.isnan(val):
918
- log_dict["val_avg_distance"] = val
930
+ log_dict["eval/val/distance/avg"] = val
919
931
 
920
- if "p50_distance" in self.metrics_to_log:
932
+ if "distance/p50" in self.metrics_to_log:
921
933
  val = metrics["distance_metrics"]["p50"]
922
934
  if not np.isnan(val):
923
- log_dict["val_p50_distance"] = val
935
+ log_dict["eval/val/distance/p50"] = val
936
+
937
+ if "distance/p95" in self.metrics_to_log:
938
+ val = metrics["distance_metrics"]["p95"]
939
+ if not np.isnan(val):
940
+ log_dict["eval/val/distance/p95"] = val
941
+
942
+ if "distance/p99" in self.metrics_to_log:
943
+ val = metrics["distance_metrics"]["p99"]
944
+ if not np.isnan(val):
945
+ log_dict["eval/val/distance/p99"] = val
924
946
 
947
+ # PCK metrics
925
948
  if "mPCK" in self.metrics_to_log:
926
- log_dict["val_mPCK"] = metrics["pck_metrics"]["mPCK"]
949
+ log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]
927
950
 
951
+ # PCK at specific thresholds (precomputed in evaluation.py)
952
+ if "PCK@5" in self.metrics_to_log:
953
+ log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
954
+
955
+ if "PCK@10" in self.metrics_to_log:
956
+ log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
957
+
958
+ # Visibility metrics
928
959
  if "visibility_precision" in self.metrics_to_log:
929
960
  val = metrics["visibility_metrics"]["precision"]
930
961
  if not np.isnan(val):
931
- log_dict["val_visibility_precision"] = val
962
+ log_dict["eval/val/visibility_precision"] = val
932
963
 
933
964
  if "visibility_recall" in self.metrics_to_log:
934
965
  val = metrics["visibility_metrics"]["recall"]
935
966
  if not np.isnan(val):
936
- log_dict["val_visibility_recall"] = val
967
+ log_dict["eval/val/visibility_recall"] = val
937
968
 
938
969
  wandb_logger.experiment.log(log_dict, commit=False)
970
+
971
+ # Update best metrics in summary (excluding epoch)
972
+ for key, value in log_dict.items():
973
+ if key == "epoch":
974
+ continue
975
+ # Create summary key like "best/eval/val/mOKS"
976
+ summary_key = f"best/{key}"
977
+ current_best = wandb_logger.experiment.summary.get(summary_key)
978
+ # For distance metrics, lower is better; for others, higher is better
979
+ is_distance = "distance" in key
980
+ if current_best is None:
981
+ wandb_logger.experiment.summary[summary_key] = value
982
+ elif is_distance and value < current_best:
983
+ wandb_logger.experiment.summary[summary_key] = value
984
+ elif not is_distance and value > current_best:
985
+ wandb_logger.experiment.summary[summary_key] = value