sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/evaluation.py +8 -0
  4. sleap_nn/export/__init__.py +21 -0
  5. sleap_nn/export/cli.py +1778 -0
  6. sleap_nn/export/exporters/__init__.py +51 -0
  7. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  8. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  9. sleap_nn/export/metadata.py +225 -0
  10. sleap_nn/export/predictors/__init__.py +63 -0
  11. sleap_nn/export/predictors/base.py +22 -0
  12. sleap_nn/export/predictors/onnx.py +154 -0
  13. sleap_nn/export/predictors/tensorrt.py +312 -0
  14. sleap_nn/export/utils.py +307 -0
  15. sleap_nn/export/wrappers/__init__.py +25 -0
  16. sleap_nn/export/wrappers/base.py +96 -0
  17. sleap_nn/export/wrappers/bottomup.py +243 -0
  18. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  19. sleap_nn/export/wrappers/centered_instance.py +56 -0
  20. sleap_nn/export/wrappers/centroid.py +58 -0
  21. sleap_nn/export/wrappers/single_instance.py +83 -0
  22. sleap_nn/export/wrappers/topdown.py +180 -0
  23. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  24. sleap_nn/inference/postprocessing.py +284 -0
  25. sleap_nn/predict.py +29 -0
  26. sleap_nn/train.py +64 -0
  27. sleap_nn/training/callbacks.py +62 -20
  28. sleap_nn/training/lightning_modules.py +332 -30
  29. sleap_nn/training/model_trainer.py +35 -67
  30. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
  31. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
  32. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  33. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  34. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/train.py CHANGED
@@ -118,6 +118,70 @@ def run_training(
118
118
  logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
119
119
  logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
120
120
 
121
+ # Log test metrics to wandb summary
122
+ if (
123
+ d_name.startswith("test")
124
+ and trainer.config.trainer_config.use_wandb
125
+ ):
126
+ import wandb
127
+
128
+ if wandb.run is not None:
129
+ summary_metrics = {
130
+ f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
131
+ f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
132
+ "oks_voc.mAP"
133
+ ],
134
+ f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
135
+ "oks_voc.mAR"
136
+ ],
137
+ f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
138
+ f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
139
+ f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
140
+ f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
141
+ "avg"
142
+ ],
143
+ f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
144
+ "p50"
145
+ ],
146
+ f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
147
+ "p95"
148
+ ],
149
+ f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
150
+ "p99"
151
+ ],
152
+ f"eval/{d_name}/visibility_precision": metrics[
153
+ "visibility_metrics"
154
+ ]["precision"],
155
+ f"eval/{d_name}/visibility_recall": metrics[
156
+ "visibility_metrics"
157
+ ]["recall"],
158
+ }
159
+ for key, value in summary_metrics.items():
160
+ wandb.run.summary[key] = value
161
+
162
+ # Finish wandb run and cleanup after all evaluation is complete
163
+ if trainer.config.trainer_config.use_wandb:
164
+ import wandb
165
+ import shutil
166
+
167
+ if wandb.run is not None:
168
+ wandb.finish()
169
+
170
+ # Delete local wandb logs if configured
171
+ wandb_config = trainer.config.trainer_config.wandb
172
+ should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
173
+ wandb_config.delete_local_logs is None
174
+ and wandb_config.wandb_mode != "offline"
175
+ )
176
+ if should_delete_wandb_logs:
177
+ wandb_dir = run_path / "wandb"
178
+ if wandb_dir.exists():
179
+ logger.info(
180
+ f"Deleting local wandb logs at {wandb_dir}... "
181
+ "(set trainer_config.wandb.delete_local_logs=false to disable)"
182
+ )
183
+ shutil.rmtree(wandb_dir, ignore_errors=True)
184
+
121
185
 
122
186
  def train(
123
187
  train_labels_path: Optional[List[str]] = None,
@@ -295,8 +295,8 @@ class WandBVizCallback(Callback):
295
295
  suffix = "" if mode_name == "direct" else f"_{mode_name}"
296
296
  train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
297
297
  val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
298
- log_dict[f"train_predictions{suffix}"] = train_img
299
- log_dict[f"val_predictions{suffix}"] = val_img
298
+ log_dict[f"viz/train/predictions{suffix}"] = train_img
299
+ log_dict[f"viz/val/predictions{suffix}"] = val_img
300
300
 
301
301
  if log_dict:
302
302
  # Include epoch so wandb can use it as x-axis (via define_metric)
@@ -394,8 +394,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
394
394
  suffix = "" if mode_name == "direct" else f"_{mode_name}"
395
395
  train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
396
396
  val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
397
- log_dict[f"train_predictions{suffix}"] = train_img
398
- log_dict[f"val_predictions{suffix}"] = val_img
397
+ log_dict[f"viz/train/predictions{suffix}"] = train_img
398
+ log_dict[f"viz/val/predictions{suffix}"] = val_img
399
399
 
400
400
  # Render PAFs (always use matplotlib/direct for PAFs)
401
401
  from io import BytesIO
@@ -408,7 +408,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
408
408
  buf.seek(0)
409
409
  plt.close(train_pafs_fig)
410
410
  train_pafs_pil = Image.open(buf)
411
- log_dict["train_pafs"] = wandb.Image(
411
+ log_dict["viz/train/pafs"] = wandb.Image(
412
412
  train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
413
413
  )
414
414
 
@@ -418,7 +418,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
418
418
  buf.seek(0)
419
419
  plt.close(val_pafs_fig)
420
420
  val_pafs_pil = Image.open(buf)
421
- log_dict["val_pafs"] = wandb.Image(
421
+ log_dict["viz/val/pafs"] = wandb.Image(
422
422
  val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
423
423
  )
424
424
 
@@ -444,8 +444,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
444
444
  epoch,
445
445
  train_img,
446
446
  val_img,
447
- log_dict["train_pafs"],
448
- log_dict["val_pafs"],
447
+ log_dict["viz/train/pafs"],
448
+ log_dict["viz/val/pafs"],
449
449
  ]
450
450
  ],
451
451
  )
@@ -709,9 +709,13 @@ class EpochEndEvaluationCallback(Callback):
709
709
  "mOKS",
710
710
  "oks_voc.mAP",
711
711
  "oks_voc.mAR",
712
- "avg_distance",
713
- "p50_distance",
712
+ "distance/avg",
713
+ "distance/p50",
714
+ "distance/p95",
715
+ "distance/p99",
714
716
  "mPCK",
717
+ "PCK@5",
718
+ "PCK@10",
715
719
  "visibility_precision",
716
720
  "visibility_recall",
717
721
  ]
@@ -779,6 +783,7 @@ class EpochEndEvaluationCallback(Callback):
779
783
 
780
784
  logger.info(
781
785
  f"Epoch {trainer.current_epoch} evaluation: "
786
+ f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
782
787
  f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
783
788
  f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
784
789
  )
@@ -903,36 +908,73 @@ class EpochEndEvaluationCallback(Callback):
903
908
  log_dict = {"epoch": epoch}
904
909
 
905
910
  # Extract key metrics with consistent naming
911
+ # All eval metrics use eval/val/ prefix since they're computed on validation data
906
912
  if "mOKS" in self.metrics_to_log:
907
- log_dict["val_mOKS"] = metrics["mOKS"]["mOKS"]
913
+ log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]
908
914
 
909
915
  if "oks_voc.mAP" in self.metrics_to_log:
910
- log_dict["val_oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
916
+ log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
911
917
 
912
918
  if "oks_voc.mAR" in self.metrics_to_log:
913
- log_dict["val_oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
919
+ log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
914
920
 
915
- if "avg_distance" in self.metrics_to_log:
921
+ # Distance metrics grouped under eval/val/distance/
922
+ if "distance/avg" in self.metrics_to_log:
916
923
  val = metrics["distance_metrics"]["avg"]
917
924
  if not np.isnan(val):
918
- log_dict["val_avg_distance"] = val
925
+ log_dict["eval/val/distance/avg"] = val
919
926
 
920
- if "p50_distance" in self.metrics_to_log:
927
+ if "distance/p50" in self.metrics_to_log:
921
928
  val = metrics["distance_metrics"]["p50"]
922
929
  if not np.isnan(val):
923
- log_dict["val_p50_distance"] = val
930
+ log_dict["eval/val/distance/p50"] = val
924
931
 
932
+ if "distance/p95" in self.metrics_to_log:
933
+ val = metrics["distance_metrics"]["p95"]
934
+ if not np.isnan(val):
935
+ log_dict["eval/val/distance/p95"] = val
936
+
937
+ if "distance/p99" in self.metrics_to_log:
938
+ val = metrics["distance_metrics"]["p99"]
939
+ if not np.isnan(val):
940
+ log_dict["eval/val/distance/p99"] = val
941
+
942
+ # PCK metrics
925
943
  if "mPCK" in self.metrics_to_log:
926
- log_dict["val_mPCK"] = metrics["pck_metrics"]["mPCK"]
944
+ log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]
927
945
 
946
+ # PCK at specific thresholds (precomputed in evaluation.py)
947
+ if "PCK@5" in self.metrics_to_log:
948
+ log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
949
+
950
+ if "PCK@10" in self.metrics_to_log:
951
+ log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
952
+
953
+ # Visibility metrics
928
954
  if "visibility_precision" in self.metrics_to_log:
929
955
  val = metrics["visibility_metrics"]["precision"]
930
956
  if not np.isnan(val):
931
- log_dict["val_visibility_precision"] = val
957
+ log_dict["eval/val/visibility_precision"] = val
932
958
 
933
959
  if "visibility_recall" in self.metrics_to_log:
934
960
  val = metrics["visibility_metrics"]["recall"]
935
961
  if not np.isnan(val):
936
- log_dict["val_visibility_recall"] = val
962
+ log_dict["eval/val/visibility_recall"] = val
937
963
 
938
964
  wandb_logger.experiment.log(log_dict, commit=False)
965
+
966
+ # Update best metrics in summary (excluding epoch)
967
+ for key, value in log_dict.items():
968
+ if key == "epoch":
969
+ continue
970
+ # Create summary key like "best/eval/val/mOKS"
971
+ summary_key = f"best/{key}"
972
+ current_best = wandb_logger.experiment.summary.get(summary_key)
973
+ # For distance metrics, lower is better; for others, higher is better
974
+ is_distance = "distance" in key
975
+ if current_best is None:
976
+ wandb_logger.experiment.summary[summary_key] = value
977
+ elif is_distance and value < current_best:
978
+ wandb_logger.experiment.summary[summary_key] = value
979
+ elif not is_distance and value > current_best:
980
+ wandb_logger.experiment.summary[summary_key] = value