sleap-nn 0.1.0a0__tar.gz → 0.1.0a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/PKG-INFO +2 -2
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/config.md +13 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/pyproject.toml +1 -1
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/__init__.py +4 -2
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/get_config.py +5 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/trainer_config.py +23 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/custom_datasets.py +53 -11
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/evaluation.py +73 -22
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/bottomup.py +86 -20
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/train.py +5 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/callbacks.py +274 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/lightning_modules.py +210 -2
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/model_trainer.py +53 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/PKG-INFO +2 -2
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/requires.txt +1 -1
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_trainer_config.py +46 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_bottomup.py +91 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_cli.py +44 -7
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_evaluation.py +10 -10
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_callbacks.py +355 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/uv.lock +4 -4
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/commands/coverage.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/commands/lint.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/commands/pr-description.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/skills/investigation/SKILL.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.dockerignore +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/build.yml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/ci.yml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/codespell.yml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/docs.yml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.gitignore +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/CLAUDE.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/CONTRIBUTING.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/LICENSE +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/README.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/codecov.yml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/assets/favicon.ico +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/assets/sleap-logo.png +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/cli.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/colab_notebooks/Training_with_sleap_nn_on_colab.ipynb +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/colab_notebooks/index.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/core_components.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/example_notebooks.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/index.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/inference.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/installation.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/models.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_convnext.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_large_rf.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_medium_rf.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_swint.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_unet.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_multi_class_bottomup_unet.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_large_rf.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_medium_rf.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_large_rf.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_medium_rf.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_multi_class_centered_instance_unet.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/step_by_step_tutorial.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/training.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/README.md +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/augmentation_guide.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/receptive_field_guide.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/training_demo.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/mkdocs.yml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/scripts/cov_summary.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/scripts/gen_changelog.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/scripts/gen_ref_pages.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/setup.cfg +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/.DS_Store +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/common.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/convnext.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/encoder_decoder.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/heads.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/model.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/swint.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/unet.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/cli.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/data_config.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/model_config.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/training_job_config.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/augmentation.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/confidence_maps.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/edge_maps.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/identity.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_centroids.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_cropping.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/normalization.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/providers.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/resizing.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/identity.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/paf_grouping.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/peak_finding.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/predictors.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/provenance.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/single_instance.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/topdown.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/legacy_models.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/predict.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/system_info.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/fixed_window.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/local_queues.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/track_instance.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/tracker.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/losses.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/SOURCES.txt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/dependency_links.txt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/entry_points.txt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/top_level.txt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_architecture_utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_common.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_convnext.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_encoder_decoder.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_heads.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_model.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_swint.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_unet.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/centered_pair_small.mp4 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/minimal_instance.pkg.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot.mp4 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot_minimal.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_bboxes.pt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_cms.pt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/get_dummy_activations.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/best_model.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/best_model.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/best_model.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/initial_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.train.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.val.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.train.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.val.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/best_model.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/initial_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.train.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.val.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.train.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.val.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/best_model.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/initial_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.train.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.val.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.train.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.val.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/best_model.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/dummy_activations.h5 +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/initial_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.train.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.val.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_multiclass_training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_with_scaling_training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centroid_training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/single_instance_training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/topdown_training_config.json +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/best.ckpt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/best.ckpt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/best.ckpt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/best.ckpt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/best.ckpt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/best.ckpt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/best.ckpt +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/initial_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_train_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_val_gt_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_test.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_train_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_val_0.slp +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/test_pred_metrics.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/train_0_pred_metrics.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_config.yaml +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_log.csv +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/val_0_pred_metrics.npz +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_config_utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_data_config.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_model_config.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_training_job_config.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/conftest.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_augmentation.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_confmaps.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_custom_datasets.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_edge_maps.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_identity.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_instance_centroids.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_instance_cropping.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_normalization.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_providers.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_resizing.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/datasets.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/inference.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_models.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_sleap_json_configs.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/model_ckpts.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/__init__.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_paf_grouping.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_peak_finding.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_predictors.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_provenance.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_single_instance.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_topdown.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_utils.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_legacy_models.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_predict.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_system_info.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_train.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_version.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_fixed_window.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_local_queues.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/tracking/test_tracker.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_lightning_modules.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_model_trainer.py +0 -0
- {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_training_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sleap-nn
|
|
3
|
-
Version: 0.1.0a0
|
|
3
|
+
Version: 0.1.0a2
|
|
4
4
|
Summary: Neural network backend for training and inference for animal pose estimation.
|
|
5
5
|
Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
|
|
6
6
|
License: BSD-3-Clause
|
|
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
|
|
|
13
13
|
Requires-Python: <3.14,>=3.11
|
|
14
14
|
Description-Content-Type: text/markdown
|
|
15
15
|
License-File: LICENSE
|
|
16
|
-
Requires-Dist: sleap-io<0.7.0,>=0.6.
|
|
16
|
+
Requires-Dist: sleap-io<0.7.0,>=0.6.2
|
|
17
17
|
Requires-Dist: numpy
|
|
18
18
|
Requires-Dist: lightning
|
|
19
19
|
Requires-Dist: kornia
|
|
@@ -845,6 +845,7 @@ trainer_config:
|
|
|
845
845
|
- `wandb_mode`: (str) "offline" if only local logging is required. **Default**: `"None"`
|
|
846
846
|
- `prv_runid`: (str) Previous run ID if training should be resumed from a previous ckpt. **Default**: `None`
|
|
847
847
|
- `group`: (str) Group name for the run.
|
|
848
|
+
- `delete_local_logs`: (bool, optional) If `True`, delete local wandb logs folder (`wandb/`) after training completes. If `False`, keep the folder. If `None` (default), automatically delete if logging online (`wandb_mode` != "offline") and keep if logging offline. This can save significant disk space since wandb local logs can be several GB. **Default**: `None`
|
|
848
849
|
|
|
849
850
|
**Example WandB configurations:**
|
|
850
851
|
|
|
@@ -876,6 +877,18 @@ trainer_config:
|
|
|
876
877
|
group: "continued_experiments"
|
|
877
878
|
```
|
|
878
879
|
|
|
880
|
+
**Keep local wandb logs (override auto-delete):**
|
|
881
|
+
```yaml
|
|
882
|
+
trainer_config:
|
|
883
|
+
use_wandb: true
|
|
884
|
+
wandb:
|
|
885
|
+
entity: "your_username"
|
|
886
|
+
project: "sleap_nn_experiments"
|
|
887
|
+
name: "training_run"
|
|
888
|
+
wandb_mode: "online"
|
|
889
|
+
delete_local_logs: false # Keep local logs even when syncing online
|
|
890
|
+
```
|
|
891
|
+
|
|
879
892
|
**No WandB logging:**
|
|
880
893
|
```yaml
|
|
881
894
|
trainer_config:
|
|
@@ -41,14 +41,16 @@ def _safe_print(msg):
|
|
|
41
41
|
|
|
42
42
|
|
|
43
43
|
# Add logger with the custom filter
|
|
44
|
+
# Disable colorization to avoid ANSI codes in captured output
|
|
44
45
|
logger.add(
|
|
45
46
|
_safe_print,
|
|
46
47
|
level="DEBUG",
|
|
47
48
|
filter=_should_log,
|
|
48
|
-
format="{time:YYYY-MM-DD HH:mm:ss} | {
|
|
49
|
+
format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
|
|
50
|
+
colorize=False,
|
|
49
51
|
)
|
|
50
52
|
|
|
51
|
-
__version__ = "0.1.0a0"
|
|
53
|
+
__version__ = "0.1.0a2"
|
|
52
54
|
|
|
53
55
|
# Public API
|
|
54
56
|
from sleap_nn.evaluation import load_metrics
|
|
@@ -677,6 +677,7 @@ def get_trainer_config(
|
|
|
677
677
|
wandb_save_viz_imgs_wandb: bool = False,
|
|
678
678
|
wandb_resume_prv_runid: Optional[str] = None,
|
|
679
679
|
wandb_group_name: Optional[str] = None,
|
|
680
|
+
wandb_delete_local_logs: Optional[bool] = None,
|
|
680
681
|
optimizer: str = "Adam",
|
|
681
682
|
learning_rate: float = 1e-3,
|
|
682
683
|
amsgrad: bool = False,
|
|
@@ -746,6 +747,9 @@ def get_trainer_config(
|
|
|
746
747
|
wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
|
|
747
748
|
ckpt. Default: None
|
|
748
749
|
wandb_group_name: Group name for the wandb run. Default: None.
|
|
750
|
+
wandb_delete_local_logs: If True, delete local wandb logs folder after training.
|
|
751
|
+
If False, keep the folder. If None (default), automatically delete if logging
|
|
752
|
+
online (wandb_mode != "offline") and keep if logging offline. Default: None.
|
|
749
753
|
optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
|
|
750
754
|
learning_rate: Learning rate of type float. Default: 1e-3.
|
|
751
755
|
amsgrad: Enable AMSGrad with the optimizer. Default: False.
|
|
@@ -846,6 +850,7 @@ def get_trainer_config(
|
|
|
846
850
|
save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
|
|
847
851
|
prv_runid=wandb_resume_prv_runid,
|
|
848
852
|
group=wandb_group_name,
|
|
853
|
+
delete_local_logs=wandb_delete_local_logs,
|
|
849
854
|
),
|
|
850
855
|
save_ckpt=save_ckpt,
|
|
851
856
|
ckpt_dir=ckpt_dir,
|
|
@@ -90,6 +90,10 @@ class WandBConfig:
|
|
|
90
90
|
viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
|
|
91
91
|
viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
|
|
92
92
|
log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
|
|
93
|
+
delete_local_logs: (bool, optional) If True, delete local wandb logs folder after
|
|
94
|
+
training. If False, keep the folder. If None (default), automatically delete
|
|
95
|
+
if logging online (wandb_mode != "offline") and keep if logging offline.
|
|
96
|
+
*Default*: `None`.
|
|
93
97
|
"""
|
|
94
98
|
|
|
95
99
|
entity: Optional[str] = None
|
|
@@ -107,6 +111,7 @@ class WandBConfig:
|
|
|
107
111
|
viz_box_size: float = 5.0
|
|
108
112
|
viz_confmap_threshold: float = 0.1
|
|
109
113
|
log_viz_table: bool = False
|
|
114
|
+
delete_local_logs: Optional[bool] = None
|
|
110
115
|
|
|
111
116
|
|
|
112
117
|
@define
|
|
@@ -203,6 +208,23 @@ class EarlyStoppingConfig:
|
|
|
203
208
|
stop_training_on_plateau: bool = True
|
|
204
209
|
|
|
205
210
|
|
|
211
|
+
@define
|
|
212
|
+
class EvalConfig:
|
|
213
|
+
"""Configuration for epoch-end evaluation.
|
|
214
|
+
|
|
215
|
+
Attributes:
|
|
216
|
+
enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
|
|
217
|
+
frequency: (int) Evaluate every N epochs. *Default*: `1`.
|
|
218
|
+
oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
|
|
219
|
+
oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
|
|
220
|
+
"""
|
|
221
|
+
|
|
222
|
+
enabled: bool = False
|
|
223
|
+
frequency: int = field(default=1, validator=validators.ge(1))
|
|
224
|
+
oks_stddev: float = field(default=0.025, validator=validators.gt(0))
|
|
225
|
+
oks_scale: Optional[float] = None
|
|
226
|
+
|
|
227
|
+
|
|
206
228
|
@define
|
|
207
229
|
class HardKeypointMiningConfig:
|
|
208
230
|
"""Configuration for online hard keypoint mining.
|
|
@@ -305,6 +327,7 @@ class TrainerConfig:
|
|
|
305
327
|
factory=HardKeypointMiningConfig
|
|
306
328
|
)
|
|
307
329
|
zmq: Optional[ZMQConfig] = field(factory=ZMQConfig) # Required for SLEAP GUI
|
|
330
|
+
eval: EvalConfig = field(factory=EvalConfig) # Epoch-end evaluation config
|
|
308
331
|
|
|
309
332
|
@staticmethod
|
|
310
333
|
def validate_optimizer_name(value):
|
|
@@ -13,6 +13,14 @@ from omegaconf import DictConfig, OmegaConf
|
|
|
13
13
|
import numpy as np
|
|
14
14
|
from PIL import Image
|
|
15
15
|
from loguru import logger
|
|
16
|
+
from rich.progress import (
|
|
17
|
+
Progress,
|
|
18
|
+
SpinnerColumn,
|
|
19
|
+
TextColumn,
|
|
20
|
+
BarColumn,
|
|
21
|
+
TimeElapsedColumn,
|
|
22
|
+
)
|
|
23
|
+
from rich.console import Console
|
|
16
24
|
import torch
|
|
17
25
|
import torchvision.transforms as T
|
|
18
26
|
from torch.utils.data import Dataset, DataLoader, DistributedSampler
|
|
@@ -215,17 +223,51 @@ class BaseDataset(Dataset):
|
|
|
215
223
|
def _fill_cache(self, labels: List[sio.Labels]):
|
|
216
224
|
"""Load all samples to cache."""
|
|
217
225
|
# TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
226
|
+
import os
|
|
227
|
+
import sys
|
|
228
|
+
|
|
229
|
+
total_samples = len(self.lf_idx_list)
|
|
230
|
+
cache_type = "disk" if self.cache_img == "disk" else "memory"
|
|
231
|
+
|
|
232
|
+
# Check for NO_COLOR env var or non-interactive terminal
|
|
233
|
+
no_color = (
|
|
234
|
+
os.environ.get("NO_COLOR") is not None
|
|
235
|
+
or os.environ.get("FORCE_COLOR") == "0"
|
|
236
|
+
)
|
|
237
|
+
use_progress = sys.stdout.isatty() and not no_color
|
|
238
|
+
|
|
239
|
+
def process_samples(progress=None, task=None):
|
|
240
|
+
for sample in self.lf_idx_list:
|
|
241
|
+
labels_idx = sample["labels_idx"]
|
|
242
|
+
lf_idx = sample["lf_idx"]
|
|
243
|
+
img = labels[labels_idx][lf_idx].image
|
|
244
|
+
if img.shape[-1] == 1:
|
|
245
|
+
img = np.squeeze(img)
|
|
246
|
+
if self.cache_img == "disk":
|
|
247
|
+
f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
|
|
248
|
+
Image.fromarray(img).save(f_name, format="JPEG")
|
|
249
|
+
if self.cache_img == "memory":
|
|
250
|
+
self.cache[(labels_idx, lf_idx)] = img
|
|
251
|
+
if progress is not None:
|
|
252
|
+
progress.update(task, advance=1)
|
|
253
|
+
|
|
254
|
+
if use_progress:
|
|
255
|
+
with Progress(
|
|
256
|
+
SpinnerColumn(),
|
|
257
|
+
TextColumn("[progress.description]{task.description}"),
|
|
258
|
+
BarColumn(),
|
|
259
|
+
TextColumn("{task.completed}/{task.total}"),
|
|
260
|
+
TimeElapsedColumn(),
|
|
261
|
+
console=Console(force_terminal=True),
|
|
262
|
+
transient=True,
|
|
263
|
+
) as progress:
|
|
264
|
+
task = progress.add_task(
|
|
265
|
+
f"Caching images to {cache_type}", total=total_samples
|
|
266
|
+
)
|
|
267
|
+
process_samples(progress, task)
|
|
268
|
+
else:
|
|
269
|
+
logger.info(f"Caching {total_samples} images to {cache_type}...")
|
|
270
|
+
process_samples()
|
|
229
271
|
|
|
230
272
|
def __len__(self) -> int:
|
|
231
273
|
"""Return the number of samples in the dataset."""
|
|
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
|
|
|
29
29
|
"""
|
|
30
30
|
instance_list = []
|
|
31
31
|
frame_idx = labeled_frame.frame_idx
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
32
|
+
|
|
33
|
+
# Extract video path with fallbacks for embedded videos
|
|
34
|
+
video = labeled_frame.video
|
|
35
|
+
video_path = None
|
|
36
|
+
if video is not None:
|
|
37
|
+
backend = getattr(video, "backend", None)
|
|
38
|
+
if backend is not None:
|
|
39
|
+
# Try source_filename first (for embedded videos with provenance)
|
|
40
|
+
video_path = getattr(backend, "source_filename", None)
|
|
41
|
+
if video_path is None:
|
|
42
|
+
video_path = getattr(backend, "filename", None)
|
|
43
|
+
# Fallback to video.filename if backend doesn't have it
|
|
44
|
+
if video_path is None:
|
|
45
|
+
video_path = getattr(video, "filename", None)
|
|
46
|
+
# Handle list filenames (image sequences)
|
|
47
|
+
if isinstance(video_path, list) and video_path:
|
|
48
|
+
video_path = video_path[0]
|
|
49
|
+
# Final fallback: use a unique identifier
|
|
50
|
+
if video_path is None:
|
|
51
|
+
video_path = f"video_{id(video)}" if video is not None else "unknown"
|
|
52
|
+
|
|
37
53
|
for instance in labeled_frame.instances:
|
|
38
54
|
match_instance = MatchInstance(
|
|
39
55
|
instance=instance, frame_idx=frame_idx, video_path=video_path
|
|
@@ -47,6 +63,10 @@ def find_frame_pairs(
|
|
|
47
63
|
) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
|
|
48
64
|
"""Find corresponding frames across two sets of labels.
|
|
49
65
|
|
|
66
|
+
This function uses sleap-io's robust video matching API to handle various
|
|
67
|
+
scenarios including embedded videos, cross-platform paths, and videos with
|
|
68
|
+
different metadata.
|
|
69
|
+
|
|
50
70
|
Args:
|
|
51
71
|
labels_gt: A `sio.Labels` instance with ground truth instances.
|
|
52
72
|
labels_pr: A `sio.Labels` instance with predicted instances.
|
|
@@ -56,16 +76,15 @@ def find_frame_pairs(
|
|
|
56
76
|
Returns:
|
|
57
77
|
A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
|
|
58
78
|
"""
|
|
79
|
+
# Use sleap-io's robust video matching API (added in 0.6.2)
|
|
80
|
+
# The match() method returns a MatchResult with video_map: {pred_video: gt_video}
|
|
81
|
+
match_result = labels_gt.match(labels_pr)
|
|
82
|
+
|
|
59
83
|
frame_pairs = []
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
if video_gt.matches_content(video) and video_gt.matches_path(video):
|
|
65
|
-
video_pr = video
|
|
66
|
-
break
|
|
67
|
-
|
|
68
|
-
if video_pr is None:
|
|
84
|
+
# Iterate over matched video pairs (pred_video -> gt_video mapping)
|
|
85
|
+
for video_pr, video_gt in match_result.video_map.items():
|
|
86
|
+
if video_gt is None:
|
|
87
|
+
# No match found for this prediction video
|
|
69
88
|
continue
|
|
70
89
|
|
|
71
90
|
# Find labeled frames in this video.
|
|
@@ -786,11 +805,26 @@ def run_evaluation(
|
|
|
786
805
|
"""Evaluate SLEAP-NN model predictions against ground truth labels."""
|
|
787
806
|
logger.info("Loading ground truth labels...")
|
|
788
807
|
ground_truth_instances = sio.load_slp(ground_truth_path)
|
|
808
|
+
logger.info(
|
|
809
|
+
f" Ground truth: {len(ground_truth_instances.videos)} videos, "
|
|
810
|
+
f"{len(ground_truth_instances.labeled_frames)} frames"
|
|
811
|
+
)
|
|
789
812
|
|
|
790
813
|
logger.info("Loading predicted labels...")
|
|
791
814
|
predicted_instances = sio.load_slp(predicted_path)
|
|
815
|
+
logger.info(
|
|
816
|
+
f" Predictions: {len(predicted_instances.videos)} videos, "
|
|
817
|
+
f"{len(predicted_instances.labeled_frames)} frames"
|
|
818
|
+
)
|
|
819
|
+
|
|
820
|
+
logger.info("Matching videos and frames...")
|
|
821
|
+
# Get match stats before creating evaluator
|
|
822
|
+
match_result = ground_truth_instances.match(predicted_instances)
|
|
823
|
+
logger.info(
|
|
824
|
+
f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
|
|
825
|
+
)
|
|
792
826
|
|
|
793
|
-
logger.info("
|
|
827
|
+
logger.info("Matching instances...")
|
|
794
828
|
evaluator = Evaluator(
|
|
795
829
|
ground_truth_instances=ground_truth_instances,
|
|
796
830
|
predicted_instances=predicted_instances,
|
|
@@ -799,21 +833,38 @@ def run_evaluation(
|
|
|
799
833
|
match_threshold=match_threshold,
|
|
800
834
|
user_labels_only=user_labels_only,
|
|
801
835
|
)
|
|
836
|
+
logger.info(
|
|
837
|
+
f" Frame pairs: {len(evaluator.frame_pairs)}, "
|
|
838
|
+
f"Matched instances: {len(evaluator.positive_pairs)}, "
|
|
839
|
+
f"Unmatched GT: {len(evaluator.false_negatives)}"
|
|
840
|
+
)
|
|
802
841
|
|
|
803
842
|
logger.info("Computing evaluation metrics...")
|
|
804
843
|
metrics = evaluator.evaluate()
|
|
805
844
|
|
|
845
|
+
# Compute PCK at specific thresholds (5 and 10 pixels)
|
|
846
|
+
dists = metrics["distance_metrics"]["dists"]
|
|
847
|
+
dists_clean = np.copy(dists)
|
|
848
|
+
dists_clean[np.isnan(dists_clean)] = np.inf
|
|
849
|
+
pck_5 = (dists_clean < 5).mean()
|
|
850
|
+
pck_10 = (dists_clean < 10).mean()
|
|
851
|
+
|
|
806
852
|
# Print key metrics
|
|
807
853
|
logger.info("Evaluation Results:")
|
|
808
|
-
logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
|
|
809
|
-
logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
|
|
810
|
-
logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
|
|
811
|
-
logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.
|
|
812
|
-
logger.info(f"
|
|
854
|
+
logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
|
|
855
|
+
logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
|
|
856
|
+
logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
|
|
857
|
+
logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
|
|
858
|
+
logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
|
|
859
|
+
logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
|
|
860
|
+
logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
|
|
861
|
+
logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
|
|
862
|
+
logger.info(f" PCK@5px: {pck_5:.4f}")
|
|
863
|
+
logger.info(f" PCK@10px: {pck_10:.4f}")
|
|
813
864
|
logger.info(
|
|
814
|
-
f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
|
|
865
|
+
f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
|
|
815
866
|
)
|
|
816
|
-
logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
|
|
867
|
+
logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
|
|
817
868
|
|
|
818
869
|
# Save metrics if path provided
|
|
819
870
|
if save_metrics:
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Inference modules for BottomUp models."""
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
from typing import Dict, Optional
|
|
4
5
|
import torch
|
|
5
6
|
import lightning as L
|
|
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
|
|
|
7
8
|
from sleap_nn.inference.paf_grouping import PAFScorer
|
|
8
9
|
from sleap_nn.inference.identity import classify_peaks_from_maps
|
|
9
10
|
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
10
13
|
|
|
11
14
|
class BottomUpInferenceModel(L.LightningModule):
|
|
12
15
|
"""BottomUp Inference model.
|
|
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
|
|
|
63
66
|
return_pafs: Optional[bool] = False,
|
|
64
67
|
return_paf_graph: Optional[bool] = False,
|
|
65
68
|
input_scale: float = 1.0,
|
|
69
|
+
max_peaks_per_node: Optional[int] = None,
|
|
66
70
|
):
|
|
67
|
-
"""Initialise the model attributes.
|
|
71
|
+
"""Initialise the model attributes.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
torch_model: A `nn.Module` that accepts images and predicts confidence maps.
|
|
75
|
+
paf_scorer: A `PAFScorer` instance for grouping instances.
|
|
76
|
+
cms_output_stride: Output stride of confidence maps relative to images.
|
|
77
|
+
pafs_output_stride: Output stride of PAFs relative to images.
|
|
78
|
+
peak_threshold: Minimum confidence map value for valid peaks.
|
|
79
|
+
refinement: Peak refinement method: None, "integral", or "local".
|
|
80
|
+
integral_patch_size: Size of patches for integral refinement.
|
|
81
|
+
return_confmaps: If True, return confidence maps in output.
|
|
82
|
+
return_pafs: If True, return PAFs in output.
|
|
83
|
+
return_paf_graph: If True, return intermediate PAF graph in output.
|
|
84
|
+
input_scale: Scale factor applied to input images.
|
|
85
|
+
max_peaks_per_node: Maximum number of peaks allowed per node before
|
|
86
|
+
skipping PAF scoring. If any node has more peaks than this limit,
|
|
87
|
+
empty predictions are returned. This prevents combinatorial explosion
|
|
88
|
+
during early training when confidence maps are noisy. Set to None to
|
|
89
|
+
disable this check (default). Recommended value: 100.
|
|
90
|
+
"""
|
|
68
91
|
super().__init__()
|
|
69
92
|
self.torch_model = torch_model
|
|
70
93
|
self.paf_scorer = paf_scorer
|
|
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
|
|
|
77
100
|
self.return_pafs = return_pafs
|
|
78
101
|
self.return_paf_graph = return_paf_graph
|
|
79
102
|
self.input_scale = input_scale
|
|
103
|
+
self.max_peaks_per_node = max_peaks_per_node
|
|
80
104
|
|
|
81
105
|
def _generate_cms_peaks(self, cms):
|
|
82
106
|
# TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
|
|
@@ -124,26 +148,68 @@ class BottomUpInferenceModel(L.LightningModule):
|
|
|
124
148
|
) # (batch, h, w, 2*edges)
|
|
125
149
|
cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
|
|
126
150
|
|
|
127
|
-
(
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
151
|
+
# Check if too many peaks per node (prevents combinatorial explosion)
|
|
152
|
+
skip_paf_scoring = False
|
|
153
|
+
if self.max_peaks_per_node is not None:
|
|
154
|
+
n_nodes = cms.shape[1]
|
|
155
|
+
for b in range(self.batch_size):
|
|
156
|
+
for node_idx in range(n_nodes):
|
|
157
|
+
n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
|
|
158
|
+
if n_peaks > self.max_peaks_per_node:
|
|
159
|
+
logger.warning(
|
|
160
|
+
f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
|
|
161
|
+
f"(max_peaks_per_node={self.max_peaks_per_node}). "
|
|
162
|
+
f"Model may need more training."
|
|
163
|
+
)
|
|
164
|
+
skip_paf_scoring = True
|
|
165
|
+
break
|
|
166
|
+
if skip_paf_scoring:
|
|
167
|
+
break
|
|
168
|
+
|
|
169
|
+
if skip_paf_scoring:
|
|
170
|
+
# Return empty predictions for each sample
|
|
171
|
+
device = cms.device
|
|
172
|
+
n_nodes = cms.shape[1]
|
|
173
|
+
predicted_instances_adjusted = []
|
|
174
|
+
predicted_peak_scores = []
|
|
175
|
+
predicted_instance_scores = []
|
|
176
|
+
for _ in range(self.batch_size):
|
|
177
|
+
predicted_instances_adjusted.append(
|
|
178
|
+
torch.full((0, n_nodes, 2), float("nan"), device=device)
|
|
179
|
+
)
|
|
180
|
+
predicted_peak_scores.append(
|
|
181
|
+
torch.full((0, n_nodes), float("nan"), device=device)
|
|
182
|
+
)
|
|
183
|
+
predicted_instance_scores.append(torch.tensor([], device=device))
|
|
184
|
+
edge_inds = [
|
|
185
|
+
torch.tensor([], dtype=torch.int32, device=device)
|
|
186
|
+
] * self.batch_size
|
|
187
|
+
edge_peak_inds = [
|
|
188
|
+
torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
|
|
189
|
+
] * self.batch_size
|
|
190
|
+
line_scores = [torch.tensor([], device=device)] * self.batch_size
|
|
191
|
+
else:
|
|
192
|
+
(
|
|
193
|
+
predicted_instances,
|
|
194
|
+
predicted_peak_scores,
|
|
195
|
+
predicted_instance_scores,
|
|
196
|
+
edge_inds,
|
|
197
|
+
edge_peak_inds,
|
|
198
|
+
line_scores,
|
|
199
|
+
) = self.paf_scorer.predict(
|
|
200
|
+
pafs=pafs,
|
|
201
|
+
peaks=cms_peaks,
|
|
202
|
+
peak_vals=cms_peak_vals,
|
|
203
|
+
peak_channel_inds=cms_peak_channel_inds,
|
|
146
204
|
)
|
|
205
|
+
|
|
206
|
+
predicted_instances = [p / self.input_scale for p in predicted_instances]
|
|
207
|
+
predicted_instances_adjusted = []
|
|
208
|
+
for idx, p in enumerate(predicted_instances):
|
|
209
|
+
predicted_instances_adjusted.append(
|
|
210
|
+
p / inputs["eff_scale"][idx].to(p.device)
|
|
211
|
+
)
|
|
212
|
+
|
|
147
213
|
out = {
|
|
148
214
|
"pred_instance_peaks": predicted_instances_adjusted,
|
|
149
215
|
"pred_peak_values": predicted_peak_scores,
|
|
@@ -175,6 +175,7 @@ def train(
|
|
|
175
175
|
wandb_save_viz_imgs_wandb: bool = False,
|
|
176
176
|
wandb_resume_prv_runid: Optional[str] = None,
|
|
177
177
|
wandb_group_name: Optional[str] = None,
|
|
178
|
+
wandb_delete_local_logs: Optional[bool] = None,
|
|
178
179
|
optimizer: str = "Adam",
|
|
179
180
|
learning_rate: float = 1e-3,
|
|
180
181
|
amsgrad: bool = False,
|
|
@@ -353,6 +354,9 @@ def train(
|
|
|
353
354
|
wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
|
|
354
355
|
ckpt. Default: None
|
|
355
356
|
wandb_group_name: Group name for the wandb run. Default: None.
|
|
357
|
+
wandb_delete_local_logs: If True, delete local wandb logs folder after training.
|
|
358
|
+
If False, keep the folder. If None (default), automatically delete if logging
|
|
359
|
+
online (wandb_mode != "offline") and keep if logging offline. Default: None.
|
|
356
360
|
optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
|
|
357
361
|
learning_rate: Learning rate of type float. Default: 1e-3.
|
|
358
362
|
amsgrad: Enable AMSGrad with the optimizer. Default: False.
|
|
@@ -456,6 +460,7 @@ def train(
|
|
|
456
460
|
wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
|
|
457
461
|
wandb_resume_prv_runid=wandb_resume_prv_runid,
|
|
458
462
|
wandb_group_name=wandb_group_name,
|
|
463
|
+
wandb_delete_local_logs=wandb_delete_local_logs,
|
|
459
464
|
optimizer=optimizer,
|
|
460
465
|
learning_rate=learning_rate,
|
|
461
466
|
amsgrad=amsgrad,
|