sleap-nn 0.1.0a1__tar.gz → 0.1.0a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/PKG-INFO +2 -2
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/pyproject.toml +1 -1
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/__init__.py +1 -1
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/trainer_config.py +18 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/evaluation.py +73 -22
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/bottomup.py +86 -20
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/callbacks.py +274 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/lightning_modules.py +210 -2
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/model_trainer.py +23 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/PKG-INFO +2 -2
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/requires.txt +1 -1
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_bottomup.py +91 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_evaluation.py +10 -10
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_callbacks.py +355 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/uv.lock +4 -4
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/commands/coverage.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/commands/lint.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/commands/pr-description.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/skills/investigation/SKILL.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.dockerignore +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/build.yml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/ci.yml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/codespell.yml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/docs.yml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.gitignore +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/CLAUDE.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/CONTRIBUTING.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/LICENSE +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/README.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/codecov.yml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/assets/favicon.ico +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/assets/sleap-logo.png +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/cli.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/colab_notebooks/Training_with_sleap_nn_on_colab.ipynb +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/colab_notebooks/index.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/config.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/core_components.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/example_notebooks.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/index.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/inference.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/installation.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/models.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_convnext.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_large_rf.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_medium_rf.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_swint.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_unet.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_multi_class_bottomup_unet.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_large_rf.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_medium_rf.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_large_rf.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_medium_rf.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_multi_class_centered_instance_unet.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/step_by_step_tutorial.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/training.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/README.md +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/augmentation_guide.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/receptive_field_guide.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/training_demo.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/mkdocs.yml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/scripts/cov_summary.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/scripts/gen_changelog.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/scripts/gen_ref_pages.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/setup.cfg +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/.DS_Store +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/common.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/convnext.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/encoder_decoder.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/heads.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/model.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/swint.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/unet.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/cli.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/data_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/get_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/model_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/training_job_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/augmentation.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/confidence_maps.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/custom_datasets.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/edge_maps.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/identity.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_centroids.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_cropping.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/normalization.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/providers.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/resizing.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/identity.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/paf_grouping.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/peak_finding.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/predictors.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/provenance.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/single_instance.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/topdown.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/legacy_models.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/predict.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/system_info.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/fixed_window.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/local_queues.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/track_instance.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/tracker.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/train.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/losses.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/SOURCES.txt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/dependency_links.txt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/entry_points.txt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/top_level.txt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_architecture_utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_common.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_convnext.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_encoder_decoder.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_heads.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_model.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_swint.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_unet.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/centered_pair_small.mp4 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/minimal_instance.pkg.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot.mp4 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot_minimal.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_bboxes.pt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_cms.pt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/get_dummy_activations.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/best_model.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/best_model.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/best_model.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/initial_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.train.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.val.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.train.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.val.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/best_model.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/initial_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.train.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.val.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.train.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.val.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/best_model.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/initial_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.train.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.val.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.train.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.val.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/best_model.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/initial_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_multiclass_training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_with_scaling_training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centroid_training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/single_instance_training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/topdown_training_config.json +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/best.ckpt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/best.ckpt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/best.ckpt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/best.ckpt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/best.ckpt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/best.ckpt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/best.ckpt +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_test.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_train_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_val_0.slp +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/test_pred_metrics.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/train_0_pred_metrics.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_config.yaml +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_log.csv +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/val_0_pred_metrics.npz +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_config_utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_data_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_model_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_trainer_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_training_job_config.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/conftest.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_augmentation.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_confmaps.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_custom_datasets.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_edge_maps.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_identity.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_instance_centroids.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_instance_cropping.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_normalization.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_providers.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_resizing.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/datasets.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/inference.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_models.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_sleap_json_configs.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/model_ckpts.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/__init__.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_paf_grouping.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_peak_finding.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_predictors.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_provenance.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_single_instance.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_topdown.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_utils.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_cli.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_legacy_models.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_predict.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_system_info.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_train.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_version.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_fixed_window.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_local_queues.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/tracking/test_tracker.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_lightning_modules.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_model_trainer.py +0 -0
- {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_training_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sleap-nn
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a2
|
|
4
4
|
Summary: Neural network backend for training and inference for animal pose estimation.
|
|
5
5
|
Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
|
|
6
6
|
License: BSD-3-Clause
|
|
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
|
|
|
13
13
|
Requires-Python: <3.14,>=3.11
|
|
14
14
|
Description-Content-Type: text/markdown
|
|
15
15
|
License-File: LICENSE
|
|
16
|
-
Requires-Dist: sleap-io<0.7.0,>=0.6.
|
|
16
|
+
Requires-Dist: sleap-io<0.7.0,>=0.6.2
|
|
17
17
|
Requires-Dist: numpy
|
|
18
18
|
Requires-Dist: lightning
|
|
19
19
|
Requires-Dist: kornia
|
|
@@ -208,6 +208,23 @@ class EarlyStoppingConfig:
|
|
|
208
208
|
stop_training_on_plateau: bool = True
|
|
209
209
|
|
|
210
210
|
|
|
211
|
+
@define
|
|
212
|
+
class EvalConfig:
|
|
213
|
+
"""Configuration for epoch-end evaluation.
|
|
214
|
+
|
|
215
|
+
Attributes:
|
|
216
|
+
enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
|
|
217
|
+
frequency: (int) Evaluate every N epochs. *Default*: `1`.
|
|
218
|
+
oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
|
|
219
|
+
oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
|
|
220
|
+
"""
|
|
221
|
+
|
|
222
|
+
enabled: bool = False
|
|
223
|
+
frequency: int = field(default=1, validator=validators.ge(1))
|
|
224
|
+
oks_stddev: float = field(default=0.025, validator=validators.gt(0))
|
|
225
|
+
oks_scale: Optional[float] = None
|
|
226
|
+
|
|
227
|
+
|
|
211
228
|
@define
|
|
212
229
|
class HardKeypointMiningConfig:
|
|
213
230
|
"""Configuration for online hard keypoint mining.
|
|
@@ -310,6 +327,7 @@ class TrainerConfig:
|
|
|
310
327
|
factory=HardKeypointMiningConfig
|
|
311
328
|
)
|
|
312
329
|
zmq: Optional[ZMQConfig] = field(factory=ZMQConfig) # Required for SLEAP GUI
|
|
330
|
+
eval: EvalConfig = field(factory=EvalConfig) # Epoch-end evaluation config
|
|
313
331
|
|
|
314
332
|
@staticmethod
|
|
315
333
|
def validate_optimizer_name(value):
|
|
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
|
|
|
29
29
|
"""
|
|
30
30
|
instance_list = []
|
|
31
31
|
frame_idx = labeled_frame.frame_idx
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
32
|
+
|
|
33
|
+
# Extract video path with fallbacks for embedded videos
|
|
34
|
+
video = labeled_frame.video
|
|
35
|
+
video_path = None
|
|
36
|
+
if video is not None:
|
|
37
|
+
backend = getattr(video, "backend", None)
|
|
38
|
+
if backend is not None:
|
|
39
|
+
# Try source_filename first (for embedded videos with provenance)
|
|
40
|
+
video_path = getattr(backend, "source_filename", None)
|
|
41
|
+
if video_path is None:
|
|
42
|
+
video_path = getattr(backend, "filename", None)
|
|
43
|
+
# Fallback to video.filename if backend doesn't have it
|
|
44
|
+
if video_path is None:
|
|
45
|
+
video_path = getattr(video, "filename", None)
|
|
46
|
+
# Handle list filenames (image sequences)
|
|
47
|
+
if isinstance(video_path, list) and video_path:
|
|
48
|
+
video_path = video_path[0]
|
|
49
|
+
# Final fallback: use a unique identifier
|
|
50
|
+
if video_path is None:
|
|
51
|
+
video_path = f"video_{id(video)}" if video is not None else "unknown"
|
|
52
|
+
|
|
37
53
|
for instance in labeled_frame.instances:
|
|
38
54
|
match_instance = MatchInstance(
|
|
39
55
|
instance=instance, frame_idx=frame_idx, video_path=video_path
|
|
@@ -47,6 +63,10 @@ def find_frame_pairs(
|
|
|
47
63
|
) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
|
|
48
64
|
"""Find corresponding frames across two sets of labels.
|
|
49
65
|
|
|
66
|
+
This function uses sleap-io's robust video matching API to handle various
|
|
67
|
+
scenarios including embedded videos, cross-platform paths, and videos with
|
|
68
|
+
different metadata.
|
|
69
|
+
|
|
50
70
|
Args:
|
|
51
71
|
labels_gt: A `sio.Labels` instance with ground truth instances.
|
|
52
72
|
labels_pr: A `sio.Labels` instance with predicted instances.
|
|
@@ -56,16 +76,15 @@ def find_frame_pairs(
|
|
|
56
76
|
Returns:
|
|
57
77
|
A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
|
|
58
78
|
"""
|
|
79
|
+
# Use sleap-io's robust video matching API (added in 0.6.2)
|
|
80
|
+
# The match() method returns a MatchResult with video_map: {pred_video: gt_video}
|
|
81
|
+
match_result = labels_gt.match(labels_pr)
|
|
82
|
+
|
|
59
83
|
frame_pairs = []
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
if video_gt.matches_content(video) and video_gt.matches_path(video):
|
|
65
|
-
video_pr = video
|
|
66
|
-
break
|
|
67
|
-
|
|
68
|
-
if video_pr is None:
|
|
84
|
+
# Iterate over matched video pairs (pred_video -> gt_video mapping)
|
|
85
|
+
for video_pr, video_gt in match_result.video_map.items():
|
|
86
|
+
if video_gt is None:
|
|
87
|
+
# No match found for this prediction video
|
|
69
88
|
continue
|
|
70
89
|
|
|
71
90
|
# Find labeled frames in this video.
|
|
@@ -786,11 +805,26 @@ def run_evaluation(
|
|
|
786
805
|
"""Evaluate SLEAP-NN model predictions against ground truth labels."""
|
|
787
806
|
logger.info("Loading ground truth labels...")
|
|
788
807
|
ground_truth_instances = sio.load_slp(ground_truth_path)
|
|
808
|
+
logger.info(
|
|
809
|
+
f" Ground truth: {len(ground_truth_instances.videos)} videos, "
|
|
810
|
+
f"{len(ground_truth_instances.labeled_frames)} frames"
|
|
811
|
+
)
|
|
789
812
|
|
|
790
813
|
logger.info("Loading predicted labels...")
|
|
791
814
|
predicted_instances = sio.load_slp(predicted_path)
|
|
815
|
+
logger.info(
|
|
816
|
+
f" Predictions: {len(predicted_instances.videos)} videos, "
|
|
817
|
+
f"{len(predicted_instances.labeled_frames)} frames"
|
|
818
|
+
)
|
|
819
|
+
|
|
820
|
+
logger.info("Matching videos and frames...")
|
|
821
|
+
# Get match stats before creating evaluator
|
|
822
|
+
match_result = ground_truth_instances.match(predicted_instances)
|
|
823
|
+
logger.info(
|
|
824
|
+
f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
|
|
825
|
+
)
|
|
792
826
|
|
|
793
|
-
logger.info("
|
|
827
|
+
logger.info("Matching instances...")
|
|
794
828
|
evaluator = Evaluator(
|
|
795
829
|
ground_truth_instances=ground_truth_instances,
|
|
796
830
|
predicted_instances=predicted_instances,
|
|
@@ -799,21 +833,38 @@ def run_evaluation(
|
|
|
799
833
|
match_threshold=match_threshold,
|
|
800
834
|
user_labels_only=user_labels_only,
|
|
801
835
|
)
|
|
836
|
+
logger.info(
|
|
837
|
+
f" Frame pairs: {len(evaluator.frame_pairs)}, "
|
|
838
|
+
f"Matched instances: {len(evaluator.positive_pairs)}, "
|
|
839
|
+
f"Unmatched GT: {len(evaluator.false_negatives)}"
|
|
840
|
+
)
|
|
802
841
|
|
|
803
842
|
logger.info("Computing evaluation metrics...")
|
|
804
843
|
metrics = evaluator.evaluate()
|
|
805
844
|
|
|
845
|
+
# Compute PCK at specific thresholds (5 and 10 pixels)
|
|
846
|
+
dists = metrics["distance_metrics"]["dists"]
|
|
847
|
+
dists_clean = np.copy(dists)
|
|
848
|
+
dists_clean[np.isnan(dists_clean)] = np.inf
|
|
849
|
+
pck_5 = (dists_clean < 5).mean()
|
|
850
|
+
pck_10 = (dists_clean < 10).mean()
|
|
851
|
+
|
|
806
852
|
# Print key metrics
|
|
807
853
|
logger.info("Evaluation Results:")
|
|
808
|
-
logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
|
|
809
|
-
logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
|
|
810
|
-
logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
|
|
811
|
-
logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.
|
|
812
|
-
logger.info(f"
|
|
854
|
+
logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
|
|
855
|
+
logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
|
|
856
|
+
logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
|
|
857
|
+
logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
|
|
858
|
+
logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
|
|
859
|
+
logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
|
|
860
|
+
logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
|
|
861
|
+
logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
|
|
862
|
+
logger.info(f" PCK@5px: {pck_5:.4f}")
|
|
863
|
+
logger.info(f" PCK@10px: {pck_10:.4f}")
|
|
813
864
|
logger.info(
|
|
814
|
-
f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
|
|
865
|
+
f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
|
|
815
866
|
)
|
|
816
|
-
logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
|
|
867
|
+
logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
|
|
817
868
|
|
|
818
869
|
# Save metrics if path provided
|
|
819
870
|
if save_metrics:
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Inference modules for BottomUp models."""
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
from typing import Dict, Optional
|
|
4
5
|
import torch
|
|
5
6
|
import lightning as L
|
|
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
|
|
|
7
8
|
from sleap_nn.inference.paf_grouping import PAFScorer
|
|
8
9
|
from sleap_nn.inference.identity import classify_peaks_from_maps
|
|
9
10
|
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
class BottomUpInferenceModel(L.LightningModule):
|
|
12
15
|
"""BottomUp Inference model.
|
|
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
|
|
|
63
66
|
return_pafs: Optional[bool] = False,
|
|
64
67
|
return_paf_graph: Optional[bool] = False,
|
|
65
68
|
input_scale: float = 1.0,
|
|
69
|
+
max_peaks_per_node: Optional[int] = None,
|
|
66
70
|
):
|
|
67
|
-
"""Initialise the model attributes.
|
|
71
|
+
"""Initialise the model attributes.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
torch_model: A `nn.Module` that accepts images and predicts confidence maps.
|
|
75
|
+
paf_scorer: A `PAFScorer` instance for grouping instances.
|
|
76
|
+
cms_output_stride: Output stride of confidence maps relative to images.
|
|
77
|
+
pafs_output_stride: Output stride of PAFs relative to images.
|
|
78
|
+
peak_threshold: Minimum confidence map value for valid peaks.
|
|
79
|
+
refinement: Peak refinement method: None, "integral", or "local".
|
|
80
|
+
integral_patch_size: Size of patches for integral refinement.
|
|
81
|
+
return_confmaps: If True, return confidence maps in output.
|
|
82
|
+
return_pafs: If True, return PAFs in output.
|
|
83
|
+
return_paf_graph: If True, return intermediate PAF graph in output.
|
|
84
|
+
input_scale: Scale factor applied to input images.
|
|
85
|
+
max_peaks_per_node: Maximum number of peaks allowed per node before
|
|
86
|
+
skipping PAF scoring. If any node has more peaks than this limit,
|
|
87
|
+
empty predictions are returned. This prevents combinatorial explosion
|
|
88
|
+
during early training when confidence maps are noisy. Set to None to
|
|
89
|
+
disable this check (default). Recommended value: 100.
|
|
90
|
+
"""
|
|
68
91
|
super().__init__()
|
|
69
92
|
self.torch_model = torch_model
|
|
70
93
|
self.paf_scorer = paf_scorer
|
|
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
|
|
|
77
100
|
self.return_pafs = return_pafs
|
|
78
101
|
self.return_paf_graph = return_paf_graph
|
|
79
102
|
self.input_scale = input_scale
|
|
103
|
+
self.max_peaks_per_node = max_peaks_per_node
|
|
80
104
|
|
|
81
105
|
def _generate_cms_peaks(self, cms):
|
|
82
106
|
# TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
|
|
@@ -124,26 +148,68 @@ class BottomUpInferenceModel(L.LightningModule):
|
|
|
124
148
|
) # (batch, h, w, 2*edges)
|
|
125
149
|
cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
|
|
126
150
|
|
|
127
|
-
(
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
151
|
+
# Check if too many peaks per node (prevents combinatorial explosion)
|
|
152
|
+
skip_paf_scoring = False
|
|
153
|
+
if self.max_peaks_per_node is not None:
|
|
154
|
+
n_nodes = cms.shape[1]
|
|
155
|
+
for b in range(self.batch_size):
|
|
156
|
+
for node_idx in range(n_nodes):
|
|
157
|
+
n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
|
|
158
|
+
if n_peaks > self.max_peaks_per_node:
|
|
159
|
+
logger.warning(
|
|
160
|
+
f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
|
|
161
|
+
f"(max_peaks_per_node={self.max_peaks_per_node}). "
|
|
162
|
+
f"Model may need more training."
|
|
163
|
+
)
|
|
164
|
+
skip_paf_scoring = True
|
|
165
|
+
break
|
|
166
|
+
if skip_paf_scoring:
|
|
167
|
+
break
|
|
168
|
+
|
|
169
|
+
if skip_paf_scoring:
|
|
170
|
+
# Return empty predictions for each sample
|
|
171
|
+
device = cms.device
|
|
172
|
+
n_nodes = cms.shape[1]
|
|
173
|
+
predicted_instances_adjusted = []
|
|
174
|
+
predicted_peak_scores = []
|
|
175
|
+
predicted_instance_scores = []
|
|
176
|
+
for _ in range(self.batch_size):
|
|
177
|
+
predicted_instances_adjusted.append(
|
|
178
|
+
torch.full((0, n_nodes, 2), float("nan"), device=device)
|
|
179
|
+
)
|
|
180
|
+
predicted_peak_scores.append(
|
|
181
|
+
torch.full((0, n_nodes), float("nan"), device=device)
|
|
182
|
+
)
|
|
183
|
+
predicted_instance_scores.append(torch.tensor([], device=device))
|
|
184
|
+
edge_inds = [
|
|
185
|
+
torch.tensor([], dtype=torch.int32, device=device)
|
|
186
|
+
] * self.batch_size
|
|
187
|
+
edge_peak_inds = [
|
|
188
|
+
torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
|
|
189
|
+
] * self.batch_size
|
|
190
|
+
line_scores = [torch.tensor([], device=device)] * self.batch_size
|
|
191
|
+
else:
|
|
192
|
+
(
|
|
193
|
+
predicted_instances,
|
|
194
|
+
predicted_peak_scores,
|
|
195
|
+
predicted_instance_scores,
|
|
196
|
+
edge_inds,
|
|
197
|
+
edge_peak_inds,
|
|
198
|
+
line_scores,
|
|
199
|
+
) = self.paf_scorer.predict(
|
|
200
|
+
pafs=pafs,
|
|
201
|
+
peaks=cms_peaks,
|
|
202
|
+
peak_vals=cms_peak_vals,
|
|
203
|
+
peak_channel_inds=cms_peak_channel_inds,
|
|
146
204
|
)
|
|
205
|
+
|
|
206
|
+
predicted_instances = [p / self.input_scale for p in predicted_instances]
|
|
207
|
+
predicted_instances_adjusted = []
|
|
208
|
+
for idx, p in enumerate(predicted_instances):
|
|
209
|
+
predicted_instances_adjusted.append(
|
|
210
|
+
p / inputs["eff_scale"][idx].to(p.device)
|
|
211
|
+
)
|
|
212
|
+
|
|
147
213
|
out = {
|
|
148
214
|
"pred_instance_peaks": predicted_instances_adjusted,
|
|
149
215
|
"pred_peak_values": predicted_peak_scores,
|
|
@@ -662,3 +662,277 @@ class ProgressReporterZMQ(Callback):
|
|
|
662
662
|
return {
|
|
663
663
|
k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
|
|
664
664
|
}
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
class EpochEndEvaluationCallback(Callback):
    """Callback to run full evaluation metrics at end of validation epochs.

    While collection is enabled (``pl_module._collect_val_predictions``), the
    LightningModule is expected to append per-frame prediction records to
    ``pl_module.val_predictions`` and ground-truth records to
    ``pl_module.val_ground_truth``. At the end of every ``eval_frequency``-th
    validation epoch, these records are converted to ``sio.Labels`` and scored
    with ``sleap_nn.evaluation.Evaluator`` (OKS, mAP/mAR, distances, PCK,
    visibility). The selected metrics are then logged to WandB when a
    ``WandbLogger`` is attached to the trainer.

    The record buffers are emptied at the end of every validation epoch,
    including epochs that are not evaluated. This keeps records from one epoch
    from leaking into a later evaluation.

    Attributes:
        skeleton: sio.Skeleton for creating instances.
        videos: List of sio.Video objects, indexed by each record's
            ``"video_idx"``.
        eval_frequency: Run evaluation every N epochs (default: 1).
        oks_stddev: OKS standard deviation (default: 0.025).
        oks_scale: Optional OKS scale override.
        metrics_to_log: List of metric keys to log.
    """

    # Maps each ``metrics_to_log`` key to:
    # (section of Evaluator.evaluate() output, metric name in that section, WandB key).
    # The dict order defines both the default key list and the logging order.
    _METRIC_SOURCES = {
        "mOKS": ("mOKS", "mOKS", "val_mOKS"),
        "oks_voc.mAP": ("voc_metrics", "oks_voc.mAP", "val_oks_voc_mAP"),
        "oks_voc.mAR": ("voc_metrics", "oks_voc.mAR", "val_oks_voc_mAR"),
        "avg_distance": ("distance_metrics", "avg", "val_avg_distance"),
        "p50_distance": ("distance_metrics", "p50", "val_p50_distance"),
        "mPCK": ("pck_metrics", "mPCK", "val_mPCK"),
        "visibility_precision": (
            "visibility_metrics",
            "precision",
            "val_visibility_precision",
        ),
        "visibility_recall": (
            "visibility_metrics",
            "recall",
            "val_visibility_recall",
        ),
    }

    def __init__(
        self,
        skeleton: "sio.Skeleton",
        videos: list,
        eval_frequency: int = 1,
        oks_stddev: float = 0.025,
        oks_scale: Optional[float] = None,
        metrics_to_log: Optional[list] = None,
    ):
        """Initialize the callback.

        Args:
            skeleton: sio.Skeleton for creating instances.
            videos: List of sio.Video objects, indexed by each record's
                ``"video_idx"``.
            eval_frequency: Run evaluation every N epochs (default: 1).
            oks_stddev: OKS standard deviation (default: 0.025).
            oks_scale: Optional OKS scale override.
            metrics_to_log: List of metric keys to log. If None (or empty), the
                default set is logged: mOKS, oks_voc.mAP, oks_voc.mAR,
                avg_distance, p50_distance, mPCK, visibility_precision,
                visibility_recall. Unknown keys are ignored.

        Raises:
            ValueError: If ``eval_frequency`` is less than 1. It is used as a
                modulus when deciding which epochs to evaluate.
        """
        super().__init__()
        if eval_frequency < 1:
            raise ValueError(
                f"eval_frequency must be a positive integer, got {eval_frequency}."
            )
        self.skeleton = skeleton
        self.videos = videos
        self.eval_frequency = eval_frequency
        self.oks_stddev = oks_stddev
        self.oks_scale = oks_scale
        self.metrics_to_log = metrics_to_log or list(self._METRIC_SOURCES)

    def on_validation_epoch_start(self, trainer, pl_module):
        """Enable prediction collection at the start of validation.

        Skip during sanity check to avoid inference issues.
        """
        if trainer.sanity_checking:
            return
        pl_module._collect_val_predictions = True

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run evaluation and log metrics at end of validation epoch.

        Evaluation runs only when all of these hold:
        - the epoch is an ``eval_frequency``-th epoch;
        - this process is the global-zero rank;
        - both prediction and ground-truth records were collected.

        Whatever the outcome, collection is disabled and both record buffers
        are emptied afterwards.
        """
        try:
            # Check frequency (epoch is 0-indexed, so add 1).
            if (trainer.current_epoch + 1) % self.eval_frequency != 0:
                return

            # Only run on rank 0 for distributed training.
            # NOTE(review): this evaluates whatever records rank 0 holds. Unless
            # the LightningModule gathers records across ranks, that is only
            # rank 0's share of the validation set -- confirm.
            if not trainer.is_global_zero:
                return

            if not pl_module.val_predictions or not pl_module.val_ground_truth:
                logger.warning("No predictions collected for epoch-end evaluation")
                return

            self._evaluate_and_log(trainer, pl_module)
        finally:
            # Reset on every path. Otherwise records from skipped epochs (or held
            # by non-zero ranks) would pile up and be evaluated in a later epoch.
            self._reset_collection(pl_module)

    def _evaluate_and_log(self, trainer, pl_module):
        """Build Labels from the collected records, evaluate them, and log metrics.

        Any exception raised while building labels, evaluating, or logging is
        reported as a warning instead of being raised. This keeps a metrics
        failure from interrupting training.
        """
        import sleap_io as sio
        import numpy as np
        from sleap_nn.evaluation import Evaluator

        try:
            # Build sio.Labels from accumulated predictions and ground truth.
            pred_labels = self._build_pred_labels(pl_module.val_predictions, sio, np)
            gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)

            # Check if we have valid frames to evaluate.
            if len(pred_labels) == 0:
                logger.warning(
                    "No valid predictions for epoch-end evaluation "
                    "(all predictions may be empty or NaN)"
                )
                return

            evaluator = Evaluator(
                ground_truth_instances=gt_labels,
                predicted_instances=pred_labels,
                oks_stddev=self.oks_stddev,
                oks_scale=self.oks_scale,
                user_labels_only=False,  # All validation frames are "user" frames
            )
            metrics = evaluator.evaluate()

            self._log_metrics(trainer, metrics, trainer.current_epoch)

            logger.info(
                f"Epoch {trainer.current_epoch} evaluation: "
                f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
                f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
            )
        except Exception as e:
            # Deliberately best-effort: report and keep training.
            logger.warning(f"Epoch-end evaluation failed: {e}")

    @staticmethod
    def _reset_collection(pl_module):
        """Disable prediction collection and drop all accumulated records."""
        pl_module._collect_val_predictions = False
        pl_module.val_predictions = []
        pl_module.val_ground_truth = []

    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
        """Convert prediction dicts to sio.Labels.

        Args:
            predictions: Records with keys ``"pred_peaks"``, ``"pred_scores"``,
                ``"video_idx"`` and ``"frame_idx"``.
                - ``"pred_peaks"`` is either ``(n_nodes, 2)`` for one instance
                  or ``(n_instances, n_nodes, 2)``; NaN marks missing points.
                - ``"pred_scores"`` holds per-point scores matching
                  ``"pred_peaks"`` without its last axis, or is None.
                Arrays are presumably numpy; only ``np.ndarray`` is checked for
                the all-NaN skip.
            sio: The ``sleap_io`` module.
            np: The ``numpy`` module.

        Returns:
            A ``sio.Labels`` over ``self.videos`` containing one
            ``LabeledFrame`` per record that yields at least one
            non-all-NaN ``PredictedInstance``. Missing point scores default to
            1.0, and an instance score is the NaN-ignoring mean of its point
            scores.

            NOTE(review): records with no valid instance produce no frame. If
            the evaluator only scores GT frames that have a paired prediction
            frame, those frames' GT instances are not counted as misses --
            confirm.
        """
        labeled_frames = []
        for pred in predictions:
            pred_peaks = pred["pred_peaks"]
            pred_scores = pred["pred_scores"]

            # Skip records with no usable prediction at all.
            if pred_peaks is None or (
                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
            ):
                continue

            # Promote single-instance output (n_nodes, 2) to (1, n_nodes, 2) so it
            # shares the multi-instance (bottom-up) path below. Scores may be
            # None and are only reshaped when present.
            if len(pred_peaks.shape) == 2:
                pred_peaks = pred_peaks.reshape(1, -1, 2)
                if pred_scores is not None:
                    pred_scores = pred_scores.reshape(1, -1)

            instances = []
            for inst_idx, inst_points in enumerate(pred_peaks):
                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None

                # Skip instances with no detected points.
                if np.isnan(inst_points).all():
                    continue

                if inst_scores is None:
                    point_scores = np.ones(len(inst_points))
                    instance_score = 1.0
                else:
                    point_scores = inst_scores
                    instance_score = float(np.nanmean(inst_scores))

                instances.append(
                    sio.PredictedInstance.from_numpy(
                        points_data=inst_points,
                        skeleton=self.skeleton,
                        point_scores=point_scores,
                        score=instance_score,
                    )
                )

            if instances:
                labeled_frames.append(
                    sio.LabeledFrame(
                        video=self.videos[pred["video_idx"]],
                        frame_idx=pred["frame_idx"],
                        instances=instances,
                    )
                )

        return sio.Labels(
            videos=self.videos,
            skeletons=[self.skeleton],
            labeled_frames=labeled_frames,
        )

    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
        """Convert ground truth dicts to sio.Labels.

        Args:
            ground_truth: Records with keys ``"gt_instances"``,
                ``"num_instances"``, ``"video_idx"`` and ``"frame_idx"``.
                - ``"gt_instances"`` is ``(n_nodes, 2)`` or
                  ``(max_instances, n_nodes, 2)``; NaN marks missing points.
                - Only the first ``"num_instances"`` rows are used.
            sio: The ``sleap_io`` module.
            np: The ``numpy`` module.

        Returns:
            A ``sio.Labels`` over ``self.videos`` containing one
            ``LabeledFrame`` per record that has at least one non-all-NaN
            ``sio.Instance``.
        """
        labeled_frames = []
        for gt in ground_truth:
            gt_instances = gt["gt_instances"]

            # Promote a single instance (n_nodes, 2) to (1, n_nodes, 2).
            if len(gt_instances.shape) == 2:
                gt_instances = gt_instances.reshape(1, -1, 2)

            instances = []
            # Rows past num_instances are presumably padding; never index past
            # the array either.
            for i in range(min(gt["num_instances"], len(gt_instances))):
                inst_data = gt_instances[i]
                if np.isnan(inst_data).all():
                    continue
                instances.append(
                    sio.Instance.from_numpy(
                        points_data=inst_data,
                        skeleton=self.skeleton,
                    )
                )

            if instances:
                labeled_frames.append(
                    sio.LabeledFrame(
                        video=self.videos[gt["video_idx"]],
                        frame_idx=gt["frame_idx"],
                        instances=instances,
                    )
                )

        return sio.Labels(
            videos=self.videos,
            skeletons=[self.skeleton],
            labeled_frames=labeled_frames,
        )

    def _log_metrics(self, trainer, metrics: dict, epoch: int):
        """Log evaluation metrics to WandB.

        Does nothing when no ``WandbLogger`` is attached to ``trainer``.
        - Always logs ``"epoch"``.
        - Logs each key in ``metrics_to_log`` that this callback knows, under
          its ``val_*`` name.
        - Skips values that are None or NaN, so an undefined metric neither
          shows up as a NaN point nor aborts logging.
        - Logs with ``commit=False``.

        Args:
            trainer: The Lightning trainer; its ``loggers`` are searched.
            metrics: Output of ``Evaluator.evaluate()``.
            epoch: Epoch index to record alongside the metrics.
        """
        import numpy as np
        from lightning.pytorch.loggers import WandbLogger

        wandb_logger = next(
            (lg for lg in trainer.loggers if isinstance(lg, WandbLogger)), None
        )
        if wandb_logger is None:
            return

        log_dict = {"epoch": epoch}
        for key, (section, name, log_key) in self._METRIC_SOURCES.items():
            if key not in self.metrics_to_log:
                continue
            value = metrics[section][name]
            if value is None or np.isnan(value):
                continue
            log_dict[log_key] = value

        wandb_logger.experiment.log(log_dict, commit=False)
|