sleap-nn 0.1.0a0__tar.gz → 0.1.0a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/PKG-INFO +2 -2
  2. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/config.md +13 -0
  3. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/pyproject.toml +1 -1
  4. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/__init__.py +4 -2
  5. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/get_config.py +5 -0
  6. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/trainer_config.py +23 -0
  7. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/custom_datasets.py +53 -11
  8. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/evaluation.py +73 -22
  9. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/bottomup.py +86 -20
  10. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/train.py +5 -0
  11. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/callbacks.py +274 -0
  12. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/lightning_modules.py +210 -2
  13. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/model_trainer.py +53 -0
  14. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/PKG-INFO +2 -2
  15. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/requires.txt +1 -1
  16. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_trainer_config.py +46 -0
  17. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_bottomup.py +91 -0
  18. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_cli.py +44 -7
  19. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_evaluation.py +10 -10
  20. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_callbacks.py +355 -0
  21. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/uv.lock +4 -4
  22. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/commands/coverage.md +0 -0
  23. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/commands/lint.md +0 -0
  24. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/commands/pr-description.md +0 -0
  25. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.claude/skills/investigation/SKILL.md +0 -0
  26. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.dockerignore +0 -0
  27. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/build.yml +0 -0
  28. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/ci.yml +0 -0
  29. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/codespell.yml +0 -0
  30. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.github/workflows/docs.yml +0 -0
  31. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/.gitignore +0 -0
  32. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/CLAUDE.md +0 -0
  33. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/CONTRIBUTING.md +0 -0
  34. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/LICENSE +0 -0
  35. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/README.md +0 -0
  36. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/codecov.yml +0 -0
  37. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/assets/favicon.ico +0 -0
  38. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/assets/sleap-logo.png +0 -0
  39. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/cli.md +0 -0
  40. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/colab_notebooks/Training_with_sleap_nn_on_colab.ipynb +0 -0
  41. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/colab_notebooks/index.md +0 -0
  42. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/core_components.md +0 -0
  43. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/example_notebooks.md +0 -0
  44. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/index.md +0 -0
  45. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/inference.md +0 -0
  46. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/installation.md +0 -0
  47. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/models.md +0 -0
  48. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_convnext.yaml +0 -0
  49. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_large_rf.yaml +0 -0
  50. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_medium_rf.yaml +0 -0
  51. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_swint.yaml +0 -0
  52. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_unet.yaml +0 -0
  53. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_multi_class_bottomup_unet.yaml +0 -0
  54. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_large_rf.yaml +0 -0
  55. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_medium_rf.yaml +0 -0
  56. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_large_rf.yaml +0 -0
  57. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_medium_rf.yaml +0 -0
  58. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_multi_class_centered_instance_unet.yaml +0 -0
  59. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/step_by_step_tutorial.md +0 -0
  60. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/docs/training.md +0 -0
  61. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/README.md +0 -0
  62. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/augmentation_guide.py +0 -0
  63. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/receptive_field_guide.py +0 -0
  64. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/example_notebooks/training_demo.py +0 -0
  65. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/mkdocs.yml +0 -0
  66. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/scripts/cov_summary.py +0 -0
  67. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/scripts/gen_changelog.py +0 -0
  68. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/scripts/gen_ref_pages.py +0 -0
  69. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/setup.cfg +0 -0
  70. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/.DS_Store +0 -0
  71. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/__init__.py +0 -0
  72. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/common.py +0 -0
  73. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/convnext.py +0 -0
  74. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/encoder_decoder.py +0 -0
  75. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/heads.py +0 -0
  76. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/model.py +0 -0
  77. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/swint.py +0 -0
  78. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/unet.py +0 -0
  79. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/architectures/utils.py +0 -0
  80. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/cli.py +0 -0
  81. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/__init__.py +0 -0
  82. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/data_config.py +0 -0
  83. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/model_config.py +0 -0
  84. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/training_job_config.py +0 -0
  85. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/config/utils.py +0 -0
  86. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/__init__.py +0 -0
  87. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/augmentation.py +0 -0
  88. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/confidence_maps.py +0 -0
  89. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/edge_maps.py +0 -0
  90. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/identity.py +0 -0
  91. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_centroids.py +0 -0
  92. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_cropping.py +0 -0
  93. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/normalization.py +0 -0
  94. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/providers.py +0 -0
  95. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/resizing.py +0 -0
  96. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/data/utils.py +0 -0
  97. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/__init__.py +0 -0
  98. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/identity.py +0 -0
  99. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/paf_grouping.py +0 -0
  100. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/peak_finding.py +0 -0
  101. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/predictors.py +0 -0
  102. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/provenance.py +0 -0
  103. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/single_instance.py +0 -0
  104. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/topdown.py +0 -0
  105. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/inference/utils.py +0 -0
  106. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/legacy_models.py +0 -0
  107. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/predict.py +0 -0
  108. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/system_info.py +0 -0
  109. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/__init__.py +0 -0
  110. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/__init__.py +0 -0
  111. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/fixed_window.py +0 -0
  112. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/local_queues.py +0 -0
  113. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/track_instance.py +0 -0
  114. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/tracker.py +0 -0
  115. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/tracking/utils.py +0 -0
  116. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/__init__.py +0 -0
  117. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/losses.py +0 -0
  118. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn/training/utils.py +0 -0
  119. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/SOURCES.txt +0 -0
  120. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/dependency_links.txt +0 -0
  121. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/entry_points.txt +0 -0
  122. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/top_level.txt +0 -0
  123. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/__init__.py +0 -0
  124. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_architecture_utils.py +0 -0
  125. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_common.py +0 -0
  126. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_convnext.py +0 -0
  127. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_encoder_decoder.py +0 -0
  128. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_heads.py +0 -0
  129. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_model.py +0 -0
  130. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_swint.py +0 -0
  131. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/architectures/test_unet.py +0 -0
  132. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/centered_pair_small.mp4 +0 -0
  133. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/minimal_instance.pkg.slp +0 -0
  134. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot.mp4 +0 -0
  135. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot_minimal.slp +0 -0
  136. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_bboxes.pt +0 -0
  137. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_cms.pt +0 -0
  138. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/get_dummy_activations.py +0 -0
  139. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/best_model.h5 +0 -0
  140. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/dummy_activations.h5 +0 -0
  141. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +0 -0
  142. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +0 -0
  143. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/best_model.h5 +0 -0
  144. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/dummy_activations.h5 +0 -0
  145. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +0 -0
  146. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +0 -0
  147. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/best_model.h5 +0 -0
  148. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/dummy_activations.h5 +0 -0
  149. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/initial_config.json +0 -0
  150. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.train.slp +0 -0
  151. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.val.slp +0 -0
  152. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.train.slp +0 -0
  153. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.val.slp +0 -0
  154. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.train.npz +0 -0
  155. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.val.npz +0 -0
  156. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_config.json +0 -0
  157. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_log.csv +0 -0
  158. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/best_model.h5 +0 -0
  159. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/dummy_activations.h5 +0 -0
  160. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/initial_config.json +0 -0
  161. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.train.slp +0 -0
  162. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.val.slp +0 -0
  163. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.train.slp +0 -0
  164. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.val.slp +0 -0
  165. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.train.npz +0 -0
  166. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.val.npz +0 -0
  167. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_config.json +0 -0
  168. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_log.csv +0 -0
  169. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/best_model.h5 +0 -0
  170. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/dummy_activations.h5 +0 -0
  171. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/initial_config.json +0 -0
  172. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.train.slp +0 -0
  173. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.val.slp +0 -0
  174. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.train.slp +0 -0
  175. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.val.slp +0 -0
  176. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.train.npz +0 -0
  177. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.val.npz +0 -0
  178. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_config.json +0 -0
  179. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_log.csv +0 -0
  180. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/best_model.h5 +0 -0
  181. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/dummy_activations.h5 +0 -0
  182. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/initial_config.json +0 -0
  183. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.train.slp +0 -0
  184. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.val.slp +0 -0
  185. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_config.json +0 -0
  186. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_log.csv +0 -0
  187. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_multiclass_training_config.json +0 -0
  188. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_training_config.json +0 -0
  189. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_training_config.json +0 -0
  190. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_with_scaling_training_config.json +0 -0
  191. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centroid_training_config.json +0 -0
  192. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/single_instance_training_config.json +0 -0
  193. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/topdown_training_config.json +0 -0
  194. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/best.ckpt +0 -0
  195. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/initial_config.yaml +0 -0
  196. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_train_gt_0.slp +0 -0
  197. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_val_gt_0.slp +0 -0
  198. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_config.yaml +0 -0
  199. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_log.csv +0 -0
  200. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/best.ckpt +0 -0
  201. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/initial_config.yaml +0 -0
  202. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_train_gt_0.slp +0 -0
  203. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_val_gt_0.slp +0 -0
  204. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_config.yaml +0 -0
  205. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_log.csv +0 -0
  206. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/best.ckpt +0 -0
  207. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/initial_config.yaml +0 -0
  208. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_train_gt_0.slp +0 -0
  209. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_val_gt_0.slp +0 -0
  210. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_config.yaml +0 -0
  211. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_log.csv +0 -0
  212. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/best.ckpt +0 -0
  213. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/initial_config.yaml +0 -0
  214. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_train_gt_0.slp +0 -0
  215. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_val_gt_0.slp +0 -0
  216. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_config.yaml +0 -0
  217. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_log.csv +0 -0
  218. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/best.ckpt +0 -0
  219. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/initial_config.yaml +0 -0
  220. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_train_gt_0.slp +0 -0
  221. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_val_gt_0.slp +0 -0
  222. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_config.yaml +0 -0
  223. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_log.csv +0 -0
  224. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/best.ckpt +0 -0
  225. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/initial_config.yaml +0 -0
  226. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_train_gt_0.slp +0 -0
  227. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_val_gt_0.slp +0 -0
  228. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_config.yaml +0 -0
  229. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_log.csv +0 -0
  230. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/best.ckpt +0 -0
  231. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/initial_config.yaml +0 -0
  232. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_train_gt_0.slp +0 -0
  233. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_val_gt_0.slp +0 -0
  234. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_test.slp +0 -0
  235. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_train_0.slp +0 -0
  236. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_val_0.slp +0 -0
  237. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/test_pred_metrics.npz +0 -0
  238. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/train_0_pred_metrics.npz +0 -0
  239. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_config.yaml +0 -0
  240. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_log.csv +0 -0
  241. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/val_0_pred_metrics.npz +0 -0
  242. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_config_utils.py +0 -0
  243. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_data_config.py +0 -0
  244. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_model_config.py +0 -0
  245. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/config/test_training_job_config.py +0 -0
  246. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/conftest.py +0 -0
  247. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_augmentation.py +0 -0
  248. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_confmaps.py +0 -0
  249. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_custom_datasets.py +0 -0
  250. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_edge_maps.py +0 -0
  251. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_identity.py +0 -0
  252. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_instance_centroids.py +0 -0
  253. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_instance_cropping.py +0 -0
  254. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_normalization.py +0 -0
  255. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_providers.py +0 -0
  256. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_resizing.py +0 -0
  257. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/data/test_utils.py +0 -0
  258. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/__init__.py +0 -0
  259. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/datasets.py +0 -0
  260. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/inference.py +0 -0
  261. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_models.py +0 -0
  262. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_sleap_json_configs.py +0 -0
  263. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/fixtures/model_ckpts.py +0 -0
  264. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/__init__.py +0 -0
  265. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_paf_grouping.py +0 -0
  266. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_peak_finding.py +0 -0
  267. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_predictors.py +0 -0
  268. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_provenance.py +0 -0
  269. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_single_instance.py +0 -0
  270. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_topdown.py +0 -0
  271. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/inference/test_utils.py +0 -0
  272. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_legacy_models.py +0 -0
  273. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_predict.py +0 -0
  274. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_system_info.py +0 -0
  275. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_train.py +0 -0
  276. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/test_version.py +0 -0
  277. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_fixed_window.py +0 -0
  278. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_local_queues.py +0 -0
  279. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/tracking/test_tracker.py +0 -0
  280. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_lightning_modules.py +0 -0
  281. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_model_trainer.py +0 -0
  282. {sleap_nn-0.1.0a0 → sleap_nn-0.1.0a2}/tests/training/test_training_utils.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.1.0a0
+Version: 0.1.0a2
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
 Requires-Python: <3.14,>=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: sleap-io<0.7.0,>=0.6.0
+Requires-Dist: sleap-io<0.7.0,>=0.6.2
 Requires-Dist: numpy
 Requires-Dist: lightning
 Requires-Dist: kornia
@@ -845,6 +845,7 @@ trainer_config:
 - `wandb_mode`: (str) "offline" if only local logging is required. **Default**: `"None"`
 - `prv_runid`: (str) Previous run ID if training should be resumed from a previous ckpt. **Default**: `None`
 - `group`: (str) Group name for the run.
+- `delete_local_logs`: (bool, optional) If `True`, delete local wandb logs folder (`wandb/`) after training completes. If `False`, keep the folder. If `None` (default), automatically delete if logging online (`wandb_mode` != "offline") and keep if logging offline. This can save significant disk space since wandb local logs can be several GB. **Default**: `None`
 
 **Example WandB configurations:**
 
@@ -876,6 +877,18 @@ trainer_config:
     group: "continued_experiments"
 ```
 
+**Keep local wandb logs (override auto-delete):**
+```yaml
+trainer_config:
+  use_wandb: true
+  wandb:
+    entity: "your_username"
+    project: "sleap_nn_experiments"
+    name: "training_run"
+    wandb_mode: "online"
+    delete_local_logs: false  # Keep local logs even when syncing online
+```
+
 **No WandB logging:**
 ```yaml
 trainer_config:
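The `delete_local_logs` option documented above is tri-state. As an illustrative sketch (not code from the package; the actual cleanup lives in the new training callbacks, which this page does not show), the auto-delete rule described in the docs resolves like this:

```python
# Illustrative sketch of the documented tri-state behavior of `delete_local_logs`.
# The helper name is hypothetical; only the rule itself comes from the docs above.
from typing import Optional


def should_delete_local_wandb_logs(delete_local_logs: Optional[bool], wandb_mode: str) -> bool:
    """Resolve whether the local wandb/ folder should be removed after training."""
    if delete_local_logs is None:
        # Auto mode: delete when syncing online, keep when logging offline.
        return wandb_mode != "offline"
    return delete_local_logs


assert should_delete_local_wandb_logs(None, "online") is True
assert should_delete_local_wandb_logs(None, "offline") is False
assert should_delete_local_wandb_logs(False, "online") is False
```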
@@ -29,7 +29,7 @@ classifiers = [
     "Programming Language :: Python :: 3.13",
 ]
 dependencies = [
-    "sleap-io>=0.6.0,<0.7.0",
+    "sleap-io>=0.6.2,<0.7.0",
     "numpy",
     "lightning",
     "kornia",
@@ -41,14 +41,16 @@ def _safe_print(msg):
 
 
 # Add logger with the custom filter
+# Disable colorization to avoid ANSI codes in captured output
 logger.add(
     _safe_print,
     level="DEBUG",
     filter=_should_log,
-    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
+    format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
+    colorize=False,
 )
 
-__version__ = "0.1.0a0"
+__version__ = "0.1.0a2"
 
 # Public API
 from sleap_nn.evaluation import load_metrics
@@ -677,6 +677,7 @@ def get_trainer_config(
     wandb_save_viz_imgs_wandb: bool = False,
     wandb_resume_prv_runid: Optional[str] = None,
     wandb_group_name: Optional[str] = None,
+    wandb_delete_local_logs: Optional[bool] = None,
     optimizer: str = "Adam",
     learning_rate: float = 1e-3,
     amsgrad: bool = False,
@@ -746,6 +747,9 @@ def get_trainer_config(
         wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
             ckpt. Default: None
         wandb_group_name: Group name for the wandb run. Default: None.
+        wandb_delete_local_logs: If True, delete local wandb logs folder after training.
+            If False, keep the folder. If None (default), automatically delete if logging
+            online (wandb_mode != "offline") and keep if logging offline. Default: None.
         optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
         learning_rate: Learning rate of type float. Default: 1e-3.
         amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -846,6 +850,7 @@ def get_trainer_config(
             save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
             prv_runid=wandb_resume_prv_runid,
             group=wandb_group_name,
+            delete_local_logs=wandb_delete_local_logs,
         ),
         save_ckpt=save_ckpt,
         ckpt_dir=ckpt_dir,
@@ -90,6 +90,10 @@ class WandBConfig:
         viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
         viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
         log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
+        delete_local_logs: (bool, optional) If True, delete local wandb logs folder after
+            training. If False, keep the folder. If None (default), automatically delete
+            if logging online (wandb_mode != "offline") and keep if logging offline.
+            *Default*: `None`.
     """
 
     entity: Optional[str] = None
@@ -107,6 +111,7 @@ class WandBConfig:
     viz_box_size: float = 5.0
     viz_confmap_threshold: float = 0.1
     log_viz_table: bool = False
+    delete_local_logs: Optional[bool] = None
 
 
 @define
@@ -203,6 +208,23 @@ class EarlyStoppingConfig:
     stop_training_on_plateau: bool = True
 
 
+@define
+class EvalConfig:
+    """Configuration for epoch-end evaluation.
+
+    Attributes:
+        enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
+        frequency: (int) Evaluate every N epochs. *Default*: `1`.
+        oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
+        oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
+    """
+
+    enabled: bool = False
+    frequency: int = field(default=1, validator=validators.ge(1))
+    oks_stddev: float = field(default=0.025, validator=validators.gt(0))
+    oks_scale: Optional[float] = None
+
+
 @define
 class HardKeypointMiningConfig:
     """Configuration for online hard keypoint mining.
@@ -305,6 +327,7 @@ class TrainerConfig:
         factory=HardKeypointMiningConfig
     )
     zmq: Optional[ZMQConfig] = field(factory=ZMQConfig)  # Required for SLEAP GUI
+    eval: EvalConfig = field(factory=EvalConfig)  # Epoch-end evaluation config
 
     @staticmethod
     def validate_optimizer_name(value):
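For reference, the new `EvalConfig` above is a plain attrs class, so it can be constructed directly. A minimal sketch (field names and validators taken from the diff; the import path assumes the module shown above, and how the trainer consumes the config is not shown on this page):

```python
# Minimal sketch: constructing the EvalConfig added in this release.
from sleap_nn.config.trainer_config import EvalConfig

eval_cfg = EvalConfig(
    enabled=True,      # turn on epoch-end evaluation metrics
    frequency=5,       # evaluate every 5 epochs (must be >= 1 per the validator)
    oks_stddev=0.025,  # OKS standard deviation (must be > 0 per the validator)
    oks_scale=None,    # None -> use the default scale
)

# EvalConfig(frequency=0) would raise, since the attrs validator enforces ge(1).
print(eval_cfg)
```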
@@ -13,6 +13,14 @@ from omegaconf import DictConfig, OmegaConf
 import numpy as np
 from PIL import Image
 from loguru import logger
+from rich.progress import (
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    BarColumn,
+    TimeElapsedColumn,
+)
+from rich.console import Console
 import torch
 import torchvision.transforms as T
 from torch.utils.data import Dataset, DataLoader, DistributedSampler
@@ -215,17 +223,51 @@ class BaseDataset(Dataset):
     def _fill_cache(self, labels: List[sio.Labels]):
         """Load all samples to cache."""
         # TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
-        for sample in self.lf_idx_list:
-            labels_idx = sample["labels_idx"]
-            lf_idx = sample["lf_idx"]
-            img = labels[labels_idx][lf_idx].image
-            if img.shape[-1] == 1:
-                img = np.squeeze(img)
-            if self.cache_img == "disk":
-                f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
-                Image.fromarray(img).save(f_name, format="JPEG")
-            if self.cache_img == "memory":
-                self.cache[(labels_idx, lf_idx)] = img
+        import os
+        import sys
+
+        total_samples = len(self.lf_idx_list)
+        cache_type = "disk" if self.cache_img == "disk" else "memory"
+
+        # Check for NO_COLOR env var or non-interactive terminal
+        no_color = (
+            os.environ.get("NO_COLOR") is not None
+            or os.environ.get("FORCE_COLOR") == "0"
+        )
+        use_progress = sys.stdout.isatty() and not no_color
+
+        def process_samples(progress=None, task=None):
+            for sample in self.lf_idx_list:
+                labels_idx = sample["labels_idx"]
+                lf_idx = sample["lf_idx"]
+                img = labels[labels_idx][lf_idx].image
+                if img.shape[-1] == 1:
+                    img = np.squeeze(img)
+                if self.cache_img == "disk":
+                    f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
+                    Image.fromarray(img).save(f_name, format="JPEG")
+                if self.cache_img == "memory":
+                    self.cache[(labels_idx, lf_idx)] = img
+                if progress is not None:
+                    progress.update(task, advance=1)
+
+        if use_progress:
+            with Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                BarColumn(),
+                TextColumn("{task.completed}/{task.total}"),
+                TimeElapsedColumn(),
+                console=Console(force_terminal=True),
+                transient=True,
+            ) as progress:
+                task = progress.add_task(
+                    f"Caching images to {cache_type}", total=total_samples
+                )
+                process_samples(progress, task)
+        else:
+            logger.info(f"Caching {total_samples} images to {cache_type}...")
+            process_samples()
 
     def __len__(self) -> int:
         """Return the number of samples in the dataset."""
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
     """
     instance_list = []
     frame_idx = labeled_frame.frame_idx
-    video_path = (
-        labeled_frame.video.backend.source_filename
-        if hasattr(labeled_frame.video.backend, "source_filename")
-        else labeled_frame.video.backend.filename
-    )
+
+    # Extract video path with fallbacks for embedded videos
+    video = labeled_frame.video
+    video_path = None
+    if video is not None:
+        backend = getattr(video, "backend", None)
+        if backend is not None:
+            # Try source_filename first (for embedded videos with provenance)
+            video_path = getattr(backend, "source_filename", None)
+            if video_path is None:
+                video_path = getattr(backend, "filename", None)
+        # Fallback to video.filename if backend doesn't have it
+        if video_path is None:
+            video_path = getattr(video, "filename", None)
+    # Handle list filenames (image sequences)
+    if isinstance(video_path, list) and video_path:
+        video_path = video_path[0]
+    # Final fallback: use a unique identifier
+    if video_path is None:
+        video_path = f"video_{id(video)}" if video is not None else "unknown"
+
     for instance in labeled_frame.instances:
         match_instance = MatchInstance(
             instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -47,6 +63,10 @@ def find_frame_pairs(
 ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
     """Find corresponding frames across two sets of labels.
 
+    This function uses sleap-io's robust video matching API to handle various
+    scenarios including embedded videos, cross-platform paths, and videos with
+    different metadata.
+
     Args:
         labels_gt: A `sio.Labels` instance with ground truth instances.
         labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,16 +76,15 @@
     Returns:
         A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
     """
+    # Use sleap-io's robust video matching API (added in 0.6.2)
+    # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
+    match_result = labels_gt.match(labels_pr)
+
     frame_pairs = []
-    for video_gt in labels_gt.videos:
-        # Find matching video instance in predictions.
-        video_pr = None
-        for video in labels_pr.videos:
-            if video_gt.matches_content(video) and video_gt.matches_path(video):
-                video_pr = video
-                break
-
-        if video_pr is None:
+    # Iterate over matched video pairs (pred_video -> gt_video mapping)
+    for video_pr, video_gt in match_result.video_map.items():
+        if video_gt is None:
+            # No match found for this prediction video
             continue
 
         # Find labeled frames in this video.
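The rewritten `find_frame_pairs` above relies on `Labels.match()` from sleap-io, which is why the diff bumps the dependency to >= 0.6.2. A hedged usage sketch, assuming only the attributes actually referenced in this diff (`video_map`, `n_videos_matched`) plus placeholder file names:

```python
# Hedged sketch: inspecting video matching before evaluation.
# Only match(), video_map, and n_videos_matched are taken from the diff above.
import sleap_io as sio

labels_gt = sio.load_slp("labels_gt.slp")  # placeholder path
labels_pr = sio.load_slp("labels_pr.slp")  # placeholder path

match_result = labels_gt.match(labels_pr)
print(f"Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}")

# video_map maps each prediction video to its ground-truth match (or None).
for video_pr, video_gt in match_result.video_map.items():
    if video_gt is None:
        print(f"No ground-truth match for prediction video: {video_pr.filename}")
```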
@@ -786,11 +805,26 @@ def run_evaluation(
     """Evaluate SLEAP-NN model predictions against ground truth labels."""
     logger.info("Loading ground truth labels...")
     ground_truth_instances = sio.load_slp(ground_truth_path)
+    logger.info(
+        f" Ground truth: {len(ground_truth_instances.videos)} videos, "
+        f"{len(ground_truth_instances.labeled_frames)} frames"
+    )
 
     logger.info("Loading predicted labels...")
     predicted_instances = sio.load_slp(predicted_path)
+    logger.info(
+        f" Predictions: {len(predicted_instances.videos)} videos, "
+        f"{len(predicted_instances.labeled_frames)} frames"
+    )
+
+    logger.info("Matching videos and frames...")
+    # Get match stats before creating evaluator
+    match_result = ground_truth_instances.match(predicted_instances)
+    logger.info(
+        f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
+    )
 
-    logger.info("Creating evaluator...")
+    logger.info("Matching instances...")
     evaluator = Evaluator(
         ground_truth_instances=ground_truth_instances,
         predicted_instances=predicted_instances,
@@ -799,21 +833,38 @@
         match_threshold=match_threshold,
         user_labels_only=user_labels_only,
     )
+    logger.info(
+        f" Frame pairs: {len(evaluator.frame_pairs)}, "
+        f"Matched instances: {len(evaluator.positive_pairs)}, "
+        f"Unmatched GT: {len(evaluator.false_negatives)}"
+    )
 
     logger.info("Computing evaluation metrics...")
     metrics = evaluator.evaluate()
 
+    # Compute PCK at specific thresholds (5 and 10 pixels)
+    dists = metrics["distance_metrics"]["dists"]
+    dists_clean = np.copy(dists)
+    dists_clean[np.isnan(dists_clean)] = np.inf
+    pck_5 = (dists_clean < 5).mean()
+    pck_10 = (dists_clean < 10).mean()
+
     # Print key metrics
     logger.info("Evaluation Results:")
-    logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
-    logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
-    logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
-    logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.4f}")
-    logger.info(f"mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+    logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
+    logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+    logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+    logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
+    logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
+    logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
+    logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
+    logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+    logger.info(f" PCK@5px: {pck_5:.4f}")
+    logger.info(f" PCK@10px: {pck_10:.4f}")
     logger.info(
-        f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
+        f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
     )
-    logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
+    logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
 
     # Save metrics if path provided
     if save_metrics:
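The PCK@5px and PCK@10px values logged above come from the per-keypoint distances, with NaNs (missed keypoints) counted as misses. A standalone restatement of that computation (illustrative; the package does this inline in `run_evaluation`):

```python
# Illustrative helper mirroring the PCK@5px / PCK@10px computation added above.
import numpy as np


def pck_at(dists: np.ndarray, threshold_px: float) -> float:
    """Fraction of keypoint distances below a pixel threshold; NaNs count as misses."""
    clean = np.where(np.isnan(dists), np.inf, dists)
    return float((clean < threshold_px).mean())


dists = np.array([1.2, 4.9, np.nan, 8.0])
print(pck_at(dists, 5.0))   # 0.5  (two of four keypoints within 5 px)
print(pck_at(dists, 10.0))  # 0.75 (the NaN keypoint still counts as a miss)
```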
@@ -1,5 +1,6 @@
 """Inference modules for BottomUp models."""
 
+import logging
 from typing import Dict, Optional
 import torch
 import lightning as L
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.inference.identity import classify_peaks_from_maps
 
+logger = logging.getLogger(__name__)
+
 
 class BottomUpInferenceModel(L.LightningModule):
     """BottomUp Inference model.
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
         return_pafs: Optional[bool] = False,
         return_paf_graph: Optional[bool] = False,
         input_scale: float = 1.0,
+        max_peaks_per_node: Optional[int] = None,
     ):
-        """Initialise the model attributes."""
+        """Initialise the model attributes.
+
+        Args:
+            torch_model: A `nn.Module` that accepts images and predicts confidence maps.
+            paf_scorer: A `PAFScorer` instance for grouping instances.
+            cms_output_stride: Output stride of confidence maps relative to images.
+            pafs_output_stride: Output stride of PAFs relative to images.
+            peak_threshold: Minimum confidence map value for valid peaks.
+            refinement: Peak refinement method: None, "integral", or "local".
+            integral_patch_size: Size of patches for integral refinement.
+            return_confmaps: If True, return confidence maps in output.
+            return_pafs: If True, return PAFs in output.
+            return_paf_graph: If True, return intermediate PAF graph in output.
+            input_scale: Scale factor applied to input images.
+            max_peaks_per_node: Maximum number of peaks allowed per node before
+                skipping PAF scoring. If any node has more peaks than this limit,
+                empty predictions are returned. This prevents combinatorial explosion
+                during early training when confidence maps are noisy. Set to None to
+                disable this check (default). Recommended value: 100.
+        """
         super().__init__()
         self.torch_model = torch_model
         self.paf_scorer = paf_scorer
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
         self.return_pafs = return_pafs
         self.return_paf_graph = return_paf_graph
         self.input_scale = input_scale
+        self.max_peaks_per_node = max_peaks_per_node
 
     def _generate_cms_peaks(self, cms):
         # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -124,26 +148,68 @@
         )  # (batch, h, w, 2*edges)
         cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
 
-        (
-            predicted_instances,
-            predicted_peak_scores,
-            predicted_instance_scores,
-            edge_inds,
-            edge_peak_inds,
-            line_scores,
-        ) = self.paf_scorer.predict(
-            pafs=pafs,
-            peaks=cms_peaks,
-            peak_vals=cms_peak_vals,
-            peak_channel_inds=cms_peak_channel_inds,
-        )
-
-        predicted_instances = [p / self.input_scale for p in predicted_instances]
-        predicted_instances_adjusted = []
-        for idx, p in enumerate(predicted_instances):
-            predicted_instances_adjusted.append(
-                p / inputs["eff_scale"][idx].to(p.device)
+        # Check if too many peaks per node (prevents combinatorial explosion)
+        skip_paf_scoring = False
+        if self.max_peaks_per_node is not None:
+            n_nodes = cms.shape[1]
+            for b in range(self.batch_size):
+                for node_idx in range(n_nodes):
+                    n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
+                    if n_peaks > self.max_peaks_per_node:
+                        logger.warning(
+                            f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
+                            f"(max_peaks_per_node={self.max_peaks_per_node}). "
+                            f"Model may need more training."
+                        )
+                        skip_paf_scoring = True
+                        break
+                if skip_paf_scoring:
+                    break
+
+        if skip_paf_scoring:
+            # Return empty predictions for each sample
+            device = cms.device
+            n_nodes = cms.shape[1]
+            predicted_instances_adjusted = []
+            predicted_peak_scores = []
+            predicted_instance_scores = []
+            for _ in range(self.batch_size):
+                predicted_instances_adjusted.append(
+                    torch.full((0, n_nodes, 2), float("nan"), device=device)
+                )
+                predicted_peak_scores.append(
+                    torch.full((0, n_nodes), float("nan"), device=device)
+                )
+                predicted_instance_scores.append(torch.tensor([], device=device))
+            edge_inds = [
+                torch.tensor([], dtype=torch.int32, device=device)
+            ] * self.batch_size
+            edge_peak_inds = [
+                torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
+            ] * self.batch_size
+            line_scores = [torch.tensor([], device=device)] * self.batch_size
+        else:
+            (
+                predicted_instances,
+                predicted_peak_scores,
+                predicted_instance_scores,
+                edge_inds,
+                edge_peak_inds,
+                line_scores,
+            ) = self.paf_scorer.predict(
+                pafs=pafs,
+                peaks=cms_peaks,
+                peak_vals=cms_peak_vals,
+                peak_channel_inds=cms_peak_channel_inds,
             )
+
+            predicted_instances = [p / self.input_scale for p in predicted_instances]
+            predicted_instances_adjusted = []
+            for idx, p in enumerate(predicted_instances):
+                predicted_instances_adjusted.append(
+                    p / inputs["eff_scale"][idx].to(p.device)
+                )
+
         out = {
             "pred_instance_peaks": predicted_instances_adjusted,
             "pred_peak_values": predicted_peak_scores,
@@ -175,6 +175,7 @@ def train(
     wandb_save_viz_imgs_wandb: bool = False,
     wandb_resume_prv_runid: Optional[str] = None,
     wandb_group_name: Optional[str] = None,
+    wandb_delete_local_logs: Optional[bool] = None,
     optimizer: str = "Adam",
     learning_rate: float = 1e-3,
     amsgrad: bool = False,
@@ -353,6 +354,9 @@ def train(
         wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
             ckpt. Default: None
         wandb_group_name: Group name for the wandb run. Default: None.
+        wandb_delete_local_logs: If True, delete local wandb logs folder after training.
+            If False, keep the folder. If None (default), automatically delete if logging
+            online (wandb_mode != "offline") and keep if logging offline. Default: None.
         optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
         learning_rate: Learning rate of type float. Default: 1e-3.
         amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -456,6 +460,7 @@ def train(
         wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
         wandb_resume_prv_runid=wandb_resume_prv_runid,
         wandb_group_name=wandb_group_name,
+        wandb_delete_local_logs=wandb_delete_local_logs,
         optimizer=optimizer,
         learning_rate=learning_rate,
         amsgrad=amsgrad,