sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +168 -39
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/peak_finding.py +47 -17
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +213 -106
- sleap_nn/predict.py +35 -7
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +69 -22
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +67 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/training/callbacks.py
CHANGED
|
@@ -85,10 +85,15 @@ class CSVLoggerCallback(Callback):
|
|
|
85
85
|
if key == "epoch":
|
|
86
86
|
log_data["epoch"] = trainer.current_epoch
|
|
87
87
|
elif key == "learning_rate":
|
|
88
|
-
# Handle
|
|
88
|
+
# Handle multiple formats:
|
|
89
|
+
# 1. Direct "learning_rate" key
|
|
90
|
+
# 2. "train/lr" key (current format from lightning modules)
|
|
91
|
+
# 3. "lr-*" keys from LearningRateMonitor (legacy)
|
|
89
92
|
value = metrics.get(key, None)
|
|
90
93
|
if value is None:
|
|
91
|
-
|
|
94
|
+
value = metrics.get("train/lr", None)
|
|
95
|
+
if value is None:
|
|
96
|
+
# Look for lr-* keys from LearningRateMonitor (legacy)
|
|
92
97
|
for metric_key in metrics.keys():
|
|
93
98
|
if metric_key.startswith("lr-"):
|
|
94
99
|
value = metrics[metric_key]
|
|
@@ -295,8 +300,8 @@ class WandBVizCallback(Callback):
|
|
|
295
300
|
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
296
301
|
train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
|
|
297
302
|
val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
|
|
298
|
-
log_dict[f"
|
|
299
|
-
log_dict[f"
|
|
303
|
+
log_dict[f"viz/train/predictions{suffix}"] = train_img
|
|
304
|
+
log_dict[f"viz/val/predictions{suffix}"] = val_img
|
|
300
305
|
|
|
301
306
|
if log_dict:
|
|
302
307
|
# Include epoch so wandb can use it as x-axis (via define_metric)
|
|
@@ -394,8 +399,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
|
|
|
394
399
|
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
395
400
|
train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
|
|
396
401
|
val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
|
|
397
|
-
log_dict[f"
|
|
398
|
-
log_dict[f"
|
|
402
|
+
log_dict[f"viz/train/predictions{suffix}"] = train_img
|
|
403
|
+
log_dict[f"viz/val/predictions{suffix}"] = val_img
|
|
399
404
|
|
|
400
405
|
# Render PAFs (always use matplotlib/direct for PAFs)
|
|
401
406
|
from io import BytesIO
|
|
@@ -408,7 +413,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
|
|
|
408
413
|
buf.seek(0)
|
|
409
414
|
plt.close(train_pafs_fig)
|
|
410
415
|
train_pafs_pil = Image.open(buf)
|
|
411
|
-
log_dict["
|
|
416
|
+
log_dict["viz/train/pafs"] = wandb.Image(
|
|
412
417
|
train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
|
|
413
418
|
)
|
|
414
419
|
|
|
@@ -418,7 +423,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
|
|
|
418
423
|
buf.seek(0)
|
|
419
424
|
plt.close(val_pafs_fig)
|
|
420
425
|
val_pafs_pil = Image.open(buf)
|
|
421
|
-
log_dict["
|
|
426
|
+
log_dict["viz/val/pafs"] = wandb.Image(
|
|
422
427
|
val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
|
|
423
428
|
)
|
|
424
429
|
|
|
@@ -444,8 +449,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
|
|
|
444
449
|
epoch,
|
|
445
450
|
train_img,
|
|
446
451
|
val_img,
|
|
447
|
-
log_dict["
|
|
448
|
-
log_dict["
|
|
452
|
+
log_dict["viz/train/pafs"],
|
|
453
|
+
log_dict["viz/val/pafs"],
|
|
449
454
|
]
|
|
450
455
|
],
|
|
451
456
|
)
|
|
@@ -709,9 +714,13 @@ class EpochEndEvaluationCallback(Callback):
|
|
|
709
714
|
"mOKS",
|
|
710
715
|
"oks_voc.mAP",
|
|
711
716
|
"oks_voc.mAR",
|
|
712
|
-
"
|
|
713
|
-
"
|
|
717
|
+
"distance/avg",
|
|
718
|
+
"distance/p50",
|
|
719
|
+
"distance/p95",
|
|
720
|
+
"distance/p99",
|
|
714
721
|
"mPCK",
|
|
722
|
+
"PCK@5",
|
|
723
|
+
"PCK@10",
|
|
715
724
|
"visibility_precision",
|
|
716
725
|
"visibility_recall",
|
|
717
726
|
]
|
|
@@ -779,6 +788,7 @@ class EpochEndEvaluationCallback(Callback):
|
|
|
779
788
|
|
|
780
789
|
logger.info(
|
|
781
790
|
f"Epoch {trainer.current_epoch} evaluation: "
|
|
791
|
+
f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
|
|
782
792
|
f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
|
|
783
793
|
f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
|
|
784
794
|
)
|
|
@@ -903,36 +913,73 @@ class EpochEndEvaluationCallback(Callback):
|
|
|
903
913
|
log_dict = {"epoch": epoch}
|
|
904
914
|
|
|
905
915
|
# Extract key metrics with consistent naming
|
|
916
|
+
# All eval metrics use eval/val/ prefix since they're computed on validation data
|
|
906
917
|
if "mOKS" in self.metrics_to_log:
|
|
907
|
-
log_dict["
|
|
918
|
+
log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]
|
|
908
919
|
|
|
909
920
|
if "oks_voc.mAP" in self.metrics_to_log:
|
|
910
|
-
log_dict["
|
|
921
|
+
log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
|
|
911
922
|
|
|
912
923
|
if "oks_voc.mAR" in self.metrics_to_log:
|
|
913
|
-
log_dict["
|
|
924
|
+
log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
|
|
914
925
|
|
|
915
|
-
|
|
926
|
+
# Distance metrics grouped under eval/val/distance/
|
|
927
|
+
if "distance/avg" in self.metrics_to_log:
|
|
916
928
|
val = metrics["distance_metrics"]["avg"]
|
|
917
929
|
if not np.isnan(val):
|
|
918
|
-
log_dict["
|
|
930
|
+
log_dict["eval/val/distance/avg"] = val
|
|
919
931
|
|
|
920
|
-
if "
|
|
932
|
+
if "distance/p50" in self.metrics_to_log:
|
|
921
933
|
val = metrics["distance_metrics"]["p50"]
|
|
922
934
|
if not np.isnan(val):
|
|
923
|
-
log_dict["
|
|
935
|
+
log_dict["eval/val/distance/p50"] = val
|
|
936
|
+
|
|
937
|
+
if "distance/p95" in self.metrics_to_log:
|
|
938
|
+
val = metrics["distance_metrics"]["p95"]
|
|
939
|
+
if not np.isnan(val):
|
|
940
|
+
log_dict["eval/val/distance/p95"] = val
|
|
941
|
+
|
|
942
|
+
if "distance/p99" in self.metrics_to_log:
|
|
943
|
+
val = metrics["distance_metrics"]["p99"]
|
|
944
|
+
if not np.isnan(val):
|
|
945
|
+
log_dict["eval/val/distance/p99"] = val
|
|
924
946
|
|
|
947
|
+
# PCK metrics
|
|
925
948
|
if "mPCK" in self.metrics_to_log:
|
|
926
|
-
log_dict["
|
|
949
|
+
log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]
|
|
927
950
|
|
|
951
|
+
# PCK at specific thresholds (precomputed in evaluation.py)
|
|
952
|
+
if "PCK@5" in self.metrics_to_log:
|
|
953
|
+
log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
|
|
954
|
+
|
|
955
|
+
if "PCK@10" in self.metrics_to_log:
|
|
956
|
+
log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
|
|
957
|
+
|
|
958
|
+
# Visibility metrics
|
|
928
959
|
if "visibility_precision" in self.metrics_to_log:
|
|
929
960
|
val = metrics["visibility_metrics"]["precision"]
|
|
930
961
|
if not np.isnan(val):
|
|
931
|
-
log_dict["
|
|
962
|
+
log_dict["eval/val/visibility_precision"] = val
|
|
932
963
|
|
|
933
964
|
if "visibility_recall" in self.metrics_to_log:
|
|
934
965
|
val = metrics["visibility_metrics"]["recall"]
|
|
935
966
|
if not np.isnan(val):
|
|
936
|
-
log_dict["
|
|
967
|
+
log_dict["eval/val/visibility_recall"] = val
|
|
937
968
|
|
|
938
969
|
wandb_logger.experiment.log(log_dict, commit=False)
|
|
970
|
+
|
|
971
|
+
# Update best metrics in summary (excluding epoch)
|
|
972
|
+
for key, value in log_dict.items():
|
|
973
|
+
if key == "epoch":
|
|
974
|
+
continue
|
|
975
|
+
# Create summary key like "best/eval/val/mOKS"
|
|
976
|
+
summary_key = f"best/{key}"
|
|
977
|
+
current_best = wandb_logger.experiment.summary.get(summary_key)
|
|
978
|
+
# For distance metrics, lower is better; for others, higher is better
|
|
979
|
+
is_distance = "distance" in key
|
|
980
|
+
if current_best is None:
|
|
981
|
+
wandb_logger.experiment.summary[summary_key] = value
|
|
982
|
+
elif is_distance and value < current_best:
|
|
983
|
+
wandb_logger.experiment.summary[summary_key] = value
|
|
984
|
+
elif not is_distance and value > current_best:
|
|
985
|
+
wandb_logger.experiment.summary[summary_key] = value
|