sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +62 -20
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +35 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/train.py
CHANGED
@@ -118,6 +118,70 @@ def run_training(
         logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
         logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")

+        # Log test metrics to wandb summary
+        if (
+            d_name.startswith("test")
+            and trainer.config.trainer_config.use_wandb
+        ):
+            import wandb
+
+            if wandb.run is not None:
+                summary_metrics = {
+                    f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
+                    f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
+                        "oks_voc.mAP"
+                    ],
+                    f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
+                        "oks_voc.mAR"
+                    ],
+                    f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
+                    f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
+                    f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
+                    f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
+                        "avg"
+                    ],
+                    f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
+                        "p50"
+                    ],
+                    f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
+                        "p95"
+                    ],
+                    f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
+                        "p99"
+                    ],
+                    f"eval/{d_name}/visibility_precision": metrics[
+                        "visibility_metrics"
+                    ]["precision"],
+                    f"eval/{d_name}/visibility_recall": metrics[
+                        "visibility_metrics"
+                    ]["recall"],
+                }
+                for key, value in summary_metrics.items():
+                    wandb.run.summary[key] = value
+
+    # Finish wandb run and cleanup after all evaluation is complete
+    if trainer.config.trainer_config.use_wandb:
+        import wandb
+        import shutil
+
+        if wandb.run is not None:
+            wandb.finish()
+
+        # Delete local wandb logs if configured
+        wandb_config = trainer.config.trainer_config.wandb
+        should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+            wandb_config.delete_local_logs is None
+            and wandb_config.wandb_mode != "offline"
+        )
+        if should_delete_wandb_logs:
+            wandb_dir = run_path / "wandb"
+            if wandb_dir.exists():
+                logger.info(
+                    f"Deleting local wandb logs at {wandb_dir}... "
+                    "(set trainer_config.wandb.delete_local_logs=false to disable)"
+                )
+                shutil.rmtree(wandb_dir, ignore_errors=True)
+

 def train(
     train_labels_path: Optional[List[str]] = None,
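The new cleanup block treats `delete_local_logs` as a tri-state option: an explicit `true` always deletes the local `wandb/` directory under the run folder, an explicit `false` never does, and the default `None` deletes only when the run was not in offline mode. Below is a minimal sketch of just that decision, using a hypothetical `WandbConfigStub` in place of the real `trainer_config.wandb` object (illustration only, not sleap-nn code):

from dataclasses import dataclass
from typing import Optional


@dataclass
class WandbConfigStub:
    """Hypothetical stand-in for trainer_config.wandb (illustration only)."""

    delete_local_logs: Optional[bool] = None  # None means "decide from wandb_mode"
    wandb_mode: str = "online"


def should_delete_local_logs(cfg: WandbConfigStub) -> bool:
    # Same decision as the block above: delete when explicitly enabled,
    # or when unset (None) and the run was not offline.
    return cfg.delete_local_logs is True or (
        cfg.delete_local_logs is None and cfg.wandb_mode != "offline"
    )


# Explicit opt-in deletes even for offline runs.
assert should_delete_local_logs(WandbConfigStub(delete_local_logs=True, wandb_mode="offline"))
# Default (None): delete only when the run was not offline.
assert should_delete_local_logs(WandbConfigStub(delete_local_logs=None, wandb_mode="online"))
assert not should_delete_local_logs(WandbConfigStub(delete_local_logs=None, wandb_mode="offline"))
# Explicit opt-out never deletes.
assert not should_delete_local_logs(WandbConfigStub(delete_local_logs=False, wandb_mode="online"))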
sleap_nn/training/callbacks.py
CHANGED
@@ -295,8 +295,8 @@ class WandBVizCallback(Callback):
             suffix = "" if mode_name == "direct" else f"_{mode_name}"
             train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
             val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
-            log_dict[f"
-            log_dict[f"
+            log_dict[f"viz/train/predictions{suffix}"] = train_img
+            log_dict[f"viz/val/predictions{suffix}"] = val_img

         if log_dict:
             # Include epoch so wandb can use it as x-axis (via define_metric)
@@ -394,8 +394,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
             suffix = "" if mode_name == "direct" else f"_{mode_name}"
             train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
             val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
-            log_dict[f"
-            log_dict[f"
+            log_dict[f"viz/train/predictions{suffix}"] = train_img
+            log_dict[f"viz/val/predictions{suffix}"] = val_img

         # Render PAFs (always use matplotlib/direct for PAFs)
         from io import BytesIO
@@ -408,7 +408,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
         buf.seek(0)
         plt.close(train_pafs_fig)
         train_pafs_pil = Image.open(buf)
-        log_dict["
+        log_dict["viz/train/pafs"] = wandb.Image(
             train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
         )

@@ -418,7 +418,7 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
         buf.seek(0)
         plt.close(val_pafs_fig)
         val_pafs_pil = Image.open(buf)
-        log_dict["
+        log_dict["viz/val/pafs"] = wandb.Image(
             val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
         )

@@ -444,8 +444,8 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
                     epoch,
                     train_img,
                     val_img,
-                    log_dict["
-                    log_dict["
+                    log_dict["viz/train/pafs"],
+                    log_dict["viz/val/pafs"],
                 ]
             ],
        )
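These renames move all visualization images under a shared `viz/` prefix (`viz/train/...`, `viz/val/...`). In the W&B UI, logged keys are grouped into panel sections by the text before the first slash, so the renamed keys collect under a single "viz" section instead of scattering across the workspace. A toy, standalone sketch of logging with that key scheme follows (hypothetical demo script, not the callback itself; run offline so nothing is uploaded):

import numpy as np
import wandb

# Offline mode keeps everything local; switch to "online" to actually sync.
run = wandb.init(project="sleap-nn-viz-demo", mode="offline")

# Stand-in for a rendered prediction frame.
frame = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
run.log(
    {
        "epoch": 0,
        "viz/train/predictions": wandb.Image(frame, caption="Train Epoch 0"),
        "viz/val/predictions": wandb.Image(frame, caption="Val Epoch 0"),
        "viz/train/pafs": wandb.Image(frame, caption="Train PAFs Epoch 0"),
    }
)
run.finish()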
@@ -709,9 +709,13 @@ class EpochEndEvaluationCallback(Callback):
             "mOKS",
             "oks_voc.mAP",
             "oks_voc.mAR",
-            "
-            "
+            "distance/avg",
+            "distance/p50",
+            "distance/p95",
+            "distance/p99",
             "mPCK",
+            "PCK@5",
+            "PCK@10",
             "visibility_precision",
             "visibility_recall",
         ]
@@ -779,6 +783,7 @@ class EpochEndEvaluationCallback(Callback):

         logger.info(
             f"Epoch {trainer.current_epoch} evaluation: "
+            f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
             f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
             f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
         )
@@ -903,36 +908,73 @@ class EpochEndEvaluationCallback(Callback):
         log_dict = {"epoch": epoch}

         # Extract key metrics with consistent naming
+        # All eval metrics use eval/val/ prefix since they're computed on validation data
         if "mOKS" in self.metrics_to_log:
-            log_dict["
+            log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]

         if "oks_voc.mAP" in self.metrics_to_log:
-            log_dict["
+            log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]

         if "oks_voc.mAR" in self.metrics_to_log:
-            log_dict["
+            log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]

-
+        # Distance metrics grouped under eval/val/distance/
+        if "distance/avg" in self.metrics_to_log:
             val = metrics["distance_metrics"]["avg"]
             if not np.isnan(val):
-                log_dict["
+                log_dict["eval/val/distance/avg"] = val

-        if "
+        if "distance/p50" in self.metrics_to_log:
             val = metrics["distance_metrics"]["p50"]
             if not np.isnan(val):
-                log_dict["
+                log_dict["eval/val/distance/p50"] = val

+        if "distance/p95" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p95"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p95"] = val
+
+        if "distance/p99" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p99"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p99"] = val
+
+        # PCK metrics
         if "mPCK" in self.metrics_to_log:
-            log_dict["
+            log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]

+        # PCK at specific thresholds (precomputed in evaluation.py)
+        if "PCK@5" in self.metrics_to_log:
+            log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
+
+        if "PCK@10" in self.metrics_to_log:
+            log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
+
+        # Visibility metrics
         if "visibility_precision" in self.metrics_to_log:
             val = metrics["visibility_metrics"]["precision"]
             if not np.isnan(val):
-                log_dict["
+                log_dict["eval/val/visibility_precision"] = val

         if "visibility_recall" in self.metrics_to_log:
             val = metrics["visibility_metrics"]["recall"]
             if not np.isnan(val):
-                log_dict["
+                log_dict["eval/val/visibility_recall"] = val

         wandb_logger.experiment.log(log_dict, commit=False)
+
+        # Update best metrics in summary (excluding epoch)
+        for key, value in log_dict.items():
+            if key == "epoch":
+                continue
+            # Create summary key like "best/eval/val/mOKS"
+            summary_key = f"best/{key}"
+            current_best = wandb_logger.experiment.summary.get(summary_key)
+            # For distance metrics, lower is better; for others, higher is better
+            is_distance = "distance" in key
+            if current_best is None:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif is_distance and value < current_best:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif not is_distance and value > current_best:
+                wandb_logger.experiment.summary[summary_key] = value