sleap-nn 0.1.0a1__tar.gz → 0.1.0a2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/PKG-INFO +2 -2
  2. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/pyproject.toml +1 -1
  3. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/__init__.py +1 -1
  4. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/trainer_config.py +18 -0
  5. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/evaluation.py +73 -22
  6. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/bottomup.py +86 -20
  7. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/callbacks.py +274 -0
  8. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/lightning_modules.py +210 -2
  9. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/model_trainer.py +23 -0
  10. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/PKG-INFO +2 -2
  11. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/requires.txt +1 -1
  12. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_bottomup.py +91 -0
  13. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_evaluation.py +10 -10
  14. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_callbacks.py +355 -0
  15. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/uv.lock +4 -4
  16. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/commands/coverage.md +0 -0
  17. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/commands/lint.md +0 -0
  18. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/commands/pr-description.md +0 -0
  19. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.claude/skills/investigation/SKILL.md +0 -0
  20. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.dockerignore +0 -0
  21. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/build.yml +0 -0
  22. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/ci.yml +0 -0
  23. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/codespell.yml +0 -0
  24. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.github/workflows/docs.yml +0 -0
  25. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/.gitignore +0 -0
  26. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/CLAUDE.md +0 -0
  27. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/CONTRIBUTING.md +0 -0
  28. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/LICENSE +0 -0
  29. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/README.md +0 -0
  30. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/codecov.yml +0 -0
  31. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/assets/favicon.ico +0 -0
  32. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/assets/sleap-logo.png +0 -0
  33. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/cli.md +0 -0
  34. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/colab_notebooks/Training_with_sleap_nn_on_colab.ipynb +0 -0
  35. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/colab_notebooks/index.md +0 -0
  36. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/config.md +0 -0
  37. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/core_components.md +0 -0
  38. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/example_notebooks.md +0 -0
  39. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/index.md +0 -0
  40. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/inference.md +0 -0
  41. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/installation.md +0 -0
  42. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/models.md +0 -0
  43. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_convnext.yaml +0 -0
  44. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_large_rf.yaml +0 -0
  45. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_bottomup_unet_medium_rf.yaml +0 -0
  46. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_swint.yaml +0 -0
  47. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_centroid_unet.yaml +0 -0
  48. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_multi_class_bottomup_unet.yaml +0 -0
  49. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_large_rf.yaml +0 -0
  50. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_single_instance_unet_medium_rf.yaml +0 -0
  51. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_large_rf.yaml +0 -0
  52. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_centered_instance_unet_medium_rf.yaml +0 -0
  53. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/sample_configs/config_topdown_multi_class_centered_instance_unet.yaml +0 -0
  54. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/step_by_step_tutorial.md +0 -0
  55. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/docs/training.md +0 -0
  56. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/README.md +0 -0
  57. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/augmentation_guide.py +0 -0
  58. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/receptive_field_guide.py +0 -0
  59. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/example_notebooks/training_demo.py +0 -0
  60. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/mkdocs.yml +0 -0
  61. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/scripts/cov_summary.py +0 -0
  62. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/scripts/gen_changelog.py +0 -0
  63. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/scripts/gen_ref_pages.py +0 -0
  64. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/setup.cfg +0 -0
  65. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/.DS_Store +0 -0
  66. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/__init__.py +0 -0
  67. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/common.py +0 -0
  68. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/convnext.py +0 -0
  69. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/encoder_decoder.py +0 -0
  70. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/heads.py +0 -0
  71. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/model.py +0 -0
  72. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/swint.py +0 -0
  73. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/unet.py +0 -0
  74. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/architectures/utils.py +0 -0
  75. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/cli.py +0 -0
  76. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/__init__.py +0 -0
  77. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/data_config.py +0 -0
  78. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/get_config.py +0 -0
  79. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/model_config.py +0 -0
  80. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/training_job_config.py +0 -0
  81. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/config/utils.py +0 -0
  82. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/__init__.py +0 -0
  83. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/augmentation.py +0 -0
  84. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/confidence_maps.py +0 -0
  85. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/custom_datasets.py +0 -0
  86. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/edge_maps.py +0 -0
  87. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/identity.py +0 -0
  88. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_centroids.py +0 -0
  89. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/instance_cropping.py +0 -0
  90. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/normalization.py +0 -0
  91. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/providers.py +0 -0
  92. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/resizing.py +0 -0
  93. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/data/utils.py +0 -0
  94. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/__init__.py +0 -0
  95. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/identity.py +0 -0
  96. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/paf_grouping.py +0 -0
  97. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/peak_finding.py +0 -0
  98. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/predictors.py +0 -0
  99. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/provenance.py +0 -0
  100. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/single_instance.py +0 -0
  101. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/topdown.py +0 -0
  102. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/inference/utils.py +0 -0
  103. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/legacy_models.py +0 -0
  104. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/predict.py +0 -0
  105. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/system_info.py +0 -0
  106. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/__init__.py +0 -0
  107. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/__init__.py +0 -0
  108. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/fixed_window.py +0 -0
  109. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/candidates/local_queues.py +0 -0
  110. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/track_instance.py +0 -0
  111. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/tracker.py +0 -0
  112. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/tracking/utils.py +0 -0
  113. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/train.py +0 -0
  114. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/__init__.py +0 -0
  115. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/losses.py +0 -0
  116. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn/training/utils.py +0 -0
  117. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/SOURCES.txt +0 -0
  118. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/dependency_links.txt +0 -0
  119. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/entry_points.txt +0 -0
  120. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/sleap_nn.egg-info/top_level.txt +0 -0
  121. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/__init__.py +0 -0
  122. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_architecture_utils.py +0 -0
  123. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_common.py +0 -0
  124. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_convnext.py +0 -0
  125. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_encoder_decoder.py +0 -0
  126. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_heads.py +0 -0
  127. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_model.py +0 -0
  128. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_swint.py +0 -0
  129. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/architectures/test_unet.py +0 -0
  130. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/centered_pair_small.mp4 +0 -0
  131. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/minimal_instance.pkg.slp +0 -0
  132. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot.mp4 +0 -0
  133. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/datasets/small_robot_minimal.slp +0 -0
  134. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_bboxes.pt +0 -0
  135. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/inference/minimal_cms.pt +0 -0
  136. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/get_dummy_activations.py +0 -0
  137. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/best_model.h5 +0 -0
  138. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/dummy_activations.h5 +0 -0
  139. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +0 -0
  140. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +0 -0
  141. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/best_model.h5 +0 -0
  142. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/dummy_activations.h5 +0 -0
  143. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +0 -0
  144. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +0 -0
  145. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/best_model.h5 +0 -0
  146. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/dummy_activations.h5 +0 -0
  147. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/initial_config.json +0 -0
  148. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.train.slp +0 -0
  149. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.val.slp +0 -0
  150. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.train.slp +0 -0
  151. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.val.slp +0 -0
  152. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.train.npz +0 -0
  153. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.val.npz +0 -0
  154. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_config.json +0 -0
  155. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_log.csv +0 -0
  156. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/best_model.h5 +0 -0
  157. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/dummy_activations.h5 +0 -0
  158. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/initial_config.json +0 -0
  159. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.train.slp +0 -0
  160. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.val.slp +0 -0
  161. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.train.slp +0 -0
  162. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.val.slp +0 -0
  163. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.train.npz +0 -0
  164. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.val.npz +0 -0
  165. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_config.json +0 -0
  166. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_log.csv +0 -0
  167. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/best_model.h5 +0 -0
  168. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/dummy_activations.h5 +0 -0
  169. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/initial_config.json +0 -0
  170. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.train.slp +0 -0
  171. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.val.slp +0 -0
  172. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.train.slp +0 -0
  173. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.val.slp +0 -0
  174. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.train.npz +0 -0
  175. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.val.npz +0 -0
  176. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_config.json +0 -0
  177. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_log.csv +0 -0
  178. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/best_model.h5 +0 -0
  179. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/dummy_activations.h5 +0 -0
  180. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/initial_config.json +0 -0
  181. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.train.slp +0 -0
  182. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.val.slp +0 -0
  183. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_config.json +0 -0
  184. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_log.csv +0 -0
  185. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_multiclass_training_config.json +0 -0
  186. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/bottomup_training_config.json +0 -0
  187. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_training_config.json +0 -0
  188. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centered_instance_with_scaling_training_config.json +0 -0
  189. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/centroid_training_config.json +0 -0
  190. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/single_instance_training_config.json +0 -0
  191. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/legacy_sleap_json_configs/topdown_training_config.json +0 -0
  192. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/best.ckpt +0 -0
  193. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/initial_config.yaml +0 -0
  194. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_train_gt_0.slp +0 -0
  195. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/labels_val_gt_0.slp +0 -0
  196. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_config.yaml +0 -0
  197. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_bottomup/training_log.csv +0 -0
  198. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/best.ckpt +0 -0
  199. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/initial_config.yaml +0 -0
  200. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_train_gt_0.slp +0 -0
  201. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_val_gt_0.slp +0 -0
  202. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_config.yaml +0 -0
  203. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centered_instance/training_log.csv +0 -0
  204. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/best.ckpt +0 -0
  205. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/initial_config.yaml +0 -0
  206. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_train_gt_0.slp +0 -0
  207. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/labels_val_gt_0.slp +0 -0
  208. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_config.yaml +0 -0
  209. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_centroid/training_log.csv +0 -0
  210. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/best.ckpt +0 -0
  211. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/initial_config.yaml +0 -0
  212. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_train_gt_0.slp +0 -0
  213. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_val_gt_0.slp +0 -0
  214. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_config.yaml +0 -0
  215. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_log.csv +0 -0
  216. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/best.ckpt +0 -0
  217. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/initial_config.yaml +0 -0
  218. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_train_gt_0.slp +0 -0
  219. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_val_gt_0.slp +0 -0
  220. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_config.yaml +0 -0
  221. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_log.csv +0 -0
  222. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/best.ckpt +0 -0
  223. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/initial_config.yaml +0 -0
  224. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_train_gt_0.slp +0 -0
  225. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/labels_val_gt_0.slp +0 -0
  226. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_config.yaml +0 -0
  227. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/minimal_instance_single_instance/training_log.csv +0 -0
  228. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/best.ckpt +0 -0
  229. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/initial_config.yaml +0 -0
  230. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_train_gt_0.slp +0 -0
  231. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/labels_val_gt_0.slp +0 -0
  232. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_test.slp +0 -0
  233. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_train_0.slp +0 -0
  234. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/pred_val_0.slp +0 -0
  235. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/test_pred_metrics.npz +0 -0
  236. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/train_0_pred_metrics.npz +0 -0
  237. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_config.yaml +0 -0
  238. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/training_log.csv +0 -0
  239. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/assets/model_ckpts/single_instance_with_metrics/val_0_pred_metrics.npz +0 -0
  240. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_config_utils.py +0 -0
  241. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_data_config.py +0 -0
  242. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_model_config.py +0 -0
  243. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_trainer_config.py +0 -0
  244. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/config/test_training_job_config.py +0 -0
  245. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/conftest.py +0 -0
  246. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_augmentation.py +0 -0
  247. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_confmaps.py +0 -0
  248. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_custom_datasets.py +0 -0
  249. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_edge_maps.py +0 -0
  250. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_identity.py +0 -0
  251. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_instance_centroids.py +0 -0
  252. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_instance_cropping.py +0 -0
  253. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_normalization.py +0 -0
  254. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_providers.py +0 -0
  255. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_resizing.py +0 -0
  256. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/data/test_utils.py +0 -0
  257. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/__init__.py +0 -0
  258. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/datasets.py +0 -0
  259. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/inference.py +0 -0
  260. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_models.py +0 -0
  261. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/legacy_sleap_json_configs.py +0 -0
  262. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/fixtures/model_ckpts.py +0 -0
  263. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/__init__.py +0 -0
  264. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_paf_grouping.py +0 -0
  265. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_peak_finding.py +0 -0
  266. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_predictors.py +0 -0
  267. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_provenance.py +0 -0
  268. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_single_instance.py +0 -0
  269. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_topdown.py +0 -0
  270. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/inference/test_utils.py +0 -0
  271. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_cli.py +0 -0
  272. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_legacy_models.py +0 -0
  273. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_predict.py +0 -0
  274. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_system_info.py +0 -0
  275. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_train.py +0 -0
  276. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/test_version.py +0 -0
  277. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_fixed_window.py +0 -0
  278. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/tracking/candidates/test_local_queues.py +0 -0
  279. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/tracking/test_tracker.py +0 -0
  280. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_lightning_modules.py +0 -0
  281. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_model_trainer.py +0 -0
  282. {sleap_nn-0.1.0a1 → sleap_nn-0.1.0a2}/tests/training/test_training_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sleap-nn
3
- Version: 0.1.0a1
3
+ Version: 0.1.0a2
4
4
  Summary: Neural network backend for training and inference for animal pose estimation.
5
5
  Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
6
6
  License: BSD-3-Clause
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
13
13
  Requires-Python: <3.14,>=3.11
14
14
  Description-Content-Type: text/markdown
15
15
  License-File: LICENSE
16
- Requires-Dist: sleap-io<0.7.0,>=0.6.0
16
+ Requires-Dist: sleap-io<0.7.0,>=0.6.2
17
17
  Requires-Dist: numpy
18
18
  Requires-Dist: lightning
19
19
  Requires-Dist: kornia
@@ -29,7 +29,7 @@ classifiers = [
29
29
  "Programming Language :: Python :: 3.13",
30
30
  ]
31
31
  dependencies = [
32
- "sleap-io>=0.6.0,<0.7.0",
32
+ "sleap-io>=0.6.2,<0.7.0",
33
33
  "numpy",
34
34
  "lightning",
35
35
  "kornia",
@@ -50,7 +50,7 @@ logger.add(
50
50
  colorize=False,
51
51
  )
52
52
 
53
- __version__ = "0.1.0a1"
53
+ __version__ = "0.1.0a2"
54
54
 
55
55
  # Public API
56
56
  from sleap_nn.evaluation import load_metrics
@@ -208,6 +208,23 @@ class EarlyStoppingConfig:
208
208
  stop_training_on_plateau: bool = True
209
209
 
210
210
 
211
+ @define
212
+ class EvalConfig:
213
+ """Configuration for epoch-end evaluation.
214
+
215
+ Attributes:
216
+ enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
217
+ frequency: (int) Evaluate every N epochs. *Default*: `1`.
218
+ oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
219
+ oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
220
+ """
221
+
222
+ enabled: bool = False
223
+ frequency: int = field(default=1, validator=validators.ge(1))
224
+ oks_stddev: float = field(default=0.025, validator=validators.gt(0))
225
+ oks_scale: Optional[float] = None
226
+
227
+
211
228
  @define
212
229
  class HardKeypointMiningConfig:
213
230
  """Configuration for online hard keypoint mining.
@@ -310,6 +327,7 @@ class TrainerConfig:
310
327
  factory=HardKeypointMiningConfig
311
328
  )
312
329
  zmq: Optional[ZMQConfig] = field(factory=ZMQConfig) # Required for SLEAP GUI
330
+ eval: EvalConfig = field(factory=EvalConfig) # Epoch-end evaluation config
313
331
 
314
332
  @staticmethod
315
333
  def validate_optimizer_name(value):
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
29
29
  """
30
30
  instance_list = []
31
31
  frame_idx = labeled_frame.frame_idx
32
- video_path = (
33
- labeled_frame.video.backend.source_filename
34
- if hasattr(labeled_frame.video.backend, "source_filename")
35
- else labeled_frame.video.backend.filename
36
- )
32
+
33
+ # Extract video path with fallbacks for embedded videos
34
+ video = labeled_frame.video
35
+ video_path = None
36
+ if video is not None:
37
+ backend = getattr(video, "backend", None)
38
+ if backend is not None:
39
+ # Try source_filename first (for embedded videos with provenance)
40
+ video_path = getattr(backend, "source_filename", None)
41
+ if video_path is None:
42
+ video_path = getattr(backend, "filename", None)
43
+ # Fallback to video.filename if backend doesn't have it
44
+ if video_path is None:
45
+ video_path = getattr(video, "filename", None)
46
+ # Handle list filenames (image sequences)
47
+ if isinstance(video_path, list) and video_path:
48
+ video_path = video_path[0]
49
+ # Final fallback: use a unique identifier
50
+ if video_path is None:
51
+ video_path = f"video_{id(video)}" if video is not None else "unknown"
52
+
37
53
  for instance in labeled_frame.instances:
38
54
  match_instance = MatchInstance(
39
55
  instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -47,6 +63,10 @@ def find_frame_pairs(
47
63
  ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
48
64
  """Find corresponding frames across two sets of labels.
49
65
 
66
+ This function uses sleap-io's robust video matching API to handle various
67
+ scenarios including embedded videos, cross-platform paths, and videos with
68
+ different metadata.
69
+
50
70
  Args:
51
71
  labels_gt: A `sio.Labels` instance with ground truth instances.
52
72
  labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,16 +76,15 @@ def find_frame_pairs(
56
76
  Returns:
57
77
  A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
58
78
  """
79
+ # Use sleap-io's robust video matching API (added in 0.6.2)
80
+ # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
81
+ match_result = labels_gt.match(labels_pr)
82
+
59
83
  frame_pairs = []
60
- for video_gt in labels_gt.videos:
61
- # Find matching video instance in predictions.
62
- video_pr = None
63
- for video in labels_pr.videos:
64
- if video_gt.matches_content(video) and video_gt.matches_path(video):
65
- video_pr = video
66
- break
67
-
68
- if video_pr is None:
84
+ # Iterate over matched video pairs (pred_video -> gt_video mapping)
85
+ for video_pr, video_gt in match_result.video_map.items():
86
+ if video_gt is None:
87
+ # No match found for this prediction video
69
88
  continue
70
89
 
71
90
  # Find labeled frames in this video.
@@ -786,11 +805,26 @@ def run_evaluation(
786
805
  """Evaluate SLEAP-NN model predictions against ground truth labels."""
787
806
  logger.info("Loading ground truth labels...")
788
807
  ground_truth_instances = sio.load_slp(ground_truth_path)
808
+ logger.info(
809
+ f" Ground truth: {len(ground_truth_instances.videos)} videos, "
810
+ f"{len(ground_truth_instances.labeled_frames)} frames"
811
+ )
789
812
 
790
813
  logger.info("Loading predicted labels...")
791
814
  predicted_instances = sio.load_slp(predicted_path)
815
+ logger.info(
816
+ f" Predictions: {len(predicted_instances.videos)} videos, "
817
+ f"{len(predicted_instances.labeled_frames)} frames"
818
+ )
819
+
820
+ logger.info("Matching videos and frames...")
821
+ # Get match stats before creating evaluator
822
+ match_result = ground_truth_instances.match(predicted_instances)
823
+ logger.info(
824
+ f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
825
+ )
792
826
 
793
- logger.info("Creating evaluator...")
827
+ logger.info("Matching instances...")
794
828
  evaluator = Evaluator(
795
829
  ground_truth_instances=ground_truth_instances,
796
830
  predicted_instances=predicted_instances,
@@ -799,21 +833,38 @@ def run_evaluation(
799
833
  match_threshold=match_threshold,
800
834
  user_labels_only=user_labels_only,
801
835
  )
836
+ logger.info(
837
+ f" Frame pairs: {len(evaluator.frame_pairs)}, "
838
+ f"Matched instances: {len(evaluator.positive_pairs)}, "
839
+ f"Unmatched GT: {len(evaluator.false_negatives)}"
840
+ )
802
841
 
803
842
  logger.info("Computing evaluation metrics...")
804
843
  metrics = evaluator.evaluate()
805
844
 
845
+ # Compute PCK at specific thresholds (5 and 10 pixels)
846
+ dists = metrics["distance_metrics"]["dists"]
847
+ dists_clean = np.copy(dists)
848
+ dists_clean[np.isnan(dists_clean)] = np.inf
849
+ pck_5 = (dists_clean < 5).mean()
850
+ pck_10 = (dists_clean < 10).mean()
851
+
806
852
  # Print key metrics
807
853
  logger.info("Evaluation Results:")
808
- logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
809
- logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
810
- logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
811
- logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.4f}")
812
- logger.info(f"mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
854
+ logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
855
+ logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
856
+ logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
857
+ logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
858
+ logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
859
+ logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
860
+ logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
861
+ logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
862
+ logger.info(f" PCK@5px: {pck_5:.4f}")
863
+ logger.info(f" PCK@10px: {pck_10:.4f}")
813
864
  logger.info(
814
- f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
865
+ f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
815
866
  )
816
- logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
867
+ logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
817
868
 
818
869
  # Save metrics if path provided
819
870
  if save_metrics:
@@ -1,5 +1,6 @@
1
1
  """Inference modules for BottomUp models."""
2
2
 
3
+ import logging
3
4
  from typing import Dict, Optional
4
5
  import torch
5
6
  import lightning as L
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
7
8
  from sleap_nn.inference.paf_grouping import PAFScorer
8
9
  from sleap_nn.inference.identity import classify_peaks_from_maps
9
10
 
11
+ logger = logging.getLogger(__name__)
12
+
10
13
 
11
14
  class BottomUpInferenceModel(L.LightningModule):
12
15
  """BottomUp Inference model.
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
63
66
  return_pafs: Optional[bool] = False,
64
67
  return_paf_graph: Optional[bool] = False,
65
68
  input_scale: float = 1.0,
69
+ max_peaks_per_node: Optional[int] = None,
66
70
  ):
67
- """Initialise the model attributes."""
71
+ """Initialise the model attributes.
72
+
73
+ Args:
74
+ torch_model: A `nn.Module` that accepts images and predicts confidence maps.
75
+ paf_scorer: A `PAFScorer` instance for grouping instances.
76
+ cms_output_stride: Output stride of confidence maps relative to images.
77
+ pafs_output_stride: Output stride of PAFs relative to images.
78
+ peak_threshold: Minimum confidence map value for valid peaks.
79
+ refinement: Peak refinement method: None, "integral", or "local".
80
+ integral_patch_size: Size of patches for integral refinement.
81
+ return_confmaps: If True, return confidence maps in output.
82
+ return_pafs: If True, return PAFs in output.
83
+ return_paf_graph: If True, return intermediate PAF graph in output.
84
+ input_scale: Scale factor applied to input images.
85
+ max_peaks_per_node: Maximum number of peaks allowed per node before
86
+ skipping PAF scoring. If any node has more peaks than this limit,
87
+ empty predictions are returned. This prevents combinatorial explosion
88
+ during early training when confidence maps are noisy. Set to None to
89
+ disable this check (default). Recommended value: 100.
90
+ """
68
91
  super().__init__()
69
92
  self.torch_model = torch_model
70
93
  self.paf_scorer = paf_scorer
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
77
100
  self.return_pafs = return_pafs
78
101
  self.return_paf_graph = return_paf_graph
79
102
  self.input_scale = input_scale
103
+ self.max_peaks_per_node = max_peaks_per_node
80
104
 
81
105
  def _generate_cms_peaks(self, cms):
82
106
  # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -124,26 +148,68 @@ class BottomUpInferenceModel(L.LightningModule):
124
148
  ) # (batch, h, w, 2*edges)
125
149
  cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
126
150
 
127
- (
128
- predicted_instances,
129
- predicted_peak_scores,
130
- predicted_instance_scores,
131
- edge_inds,
132
- edge_peak_inds,
133
- line_scores,
134
- ) = self.paf_scorer.predict(
135
- pafs=pafs,
136
- peaks=cms_peaks,
137
- peak_vals=cms_peak_vals,
138
- peak_channel_inds=cms_peak_channel_inds,
139
- )
140
-
141
- predicted_instances = [p / self.input_scale for p in predicted_instances]
142
- predicted_instances_adjusted = []
143
- for idx, p in enumerate(predicted_instances):
144
- predicted_instances_adjusted.append(
145
- p / inputs["eff_scale"][idx].to(p.device)
151
+ # Check if too many peaks per node (prevents combinatorial explosion)
152
+ skip_paf_scoring = False
153
+ if self.max_peaks_per_node is not None:
154
+ n_nodes = cms.shape[1]
155
+ for b in range(self.batch_size):
156
+ for node_idx in range(n_nodes):
157
+ n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
158
+ if n_peaks > self.max_peaks_per_node:
159
+ logger.warning(
160
+ f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
161
+ f"(max_peaks_per_node={self.max_peaks_per_node}). "
162
+ f"Model may need more training."
163
+ )
164
+ skip_paf_scoring = True
165
+ break
166
+ if skip_paf_scoring:
167
+ break
168
+
169
+ if skip_paf_scoring:
170
+ # Return empty predictions for each sample
171
+ device = cms.device
172
+ n_nodes = cms.shape[1]
173
+ predicted_instances_adjusted = []
174
+ predicted_peak_scores = []
175
+ predicted_instance_scores = []
176
+ for _ in range(self.batch_size):
177
+ predicted_instances_adjusted.append(
178
+ torch.full((0, n_nodes, 2), float("nan"), device=device)
179
+ )
180
+ predicted_peak_scores.append(
181
+ torch.full((0, n_nodes), float("nan"), device=device)
182
+ )
183
+ predicted_instance_scores.append(torch.tensor([], device=device))
184
+ edge_inds = [
185
+ torch.tensor([], dtype=torch.int32, device=device)
186
+ ] * self.batch_size
187
+ edge_peak_inds = [
188
+ torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
189
+ ] * self.batch_size
190
+ line_scores = [torch.tensor([], device=device)] * self.batch_size
191
+ else:
192
+ (
193
+ predicted_instances,
194
+ predicted_peak_scores,
195
+ predicted_instance_scores,
196
+ edge_inds,
197
+ edge_peak_inds,
198
+ line_scores,
199
+ ) = self.paf_scorer.predict(
200
+ pafs=pafs,
201
+ peaks=cms_peaks,
202
+ peak_vals=cms_peak_vals,
203
+ peak_channel_inds=cms_peak_channel_inds,
146
204
  )
205
+
206
+ predicted_instances = [p / self.input_scale for p in predicted_instances]
207
+ predicted_instances_adjusted = []
208
+ for idx, p in enumerate(predicted_instances):
209
+ predicted_instances_adjusted.append(
210
+ p / inputs["eff_scale"][idx].to(p.device)
211
+ )
212
+
147
213
  out = {
148
214
  "pred_instance_peaks": predicted_instances_adjusted,
149
215
  "pred_peak_values": predicted_peak_scores,
@@ -662,3 +662,277 @@ class ProgressReporterZMQ(Callback):
662
662
  return {
663
663
  k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
664
664
  }
665
+
666
+
667
class EpochEndEvaluationCallback(Callback):
    """Callback to run full evaluation metrics at end of validation epochs.

    During validation the Lightning module is asked (via its
    `_collect_val_predictions` flag) to accumulate per-frame predictions and
    ground truth in `val_predictions` / `val_ground_truth`. At the end of the
    validation epoch this callback converts both buffers to `sio.Labels`, runs
    `sleap_nn.evaluation.Evaluator` on them and logs selected metrics to WandB.

    The collection flag and both buffers are always reset when
    `on_validation_epoch_end` returns, so data from a skipped or failed epoch
    never leaks into a later evaluation.

    Attributes:
        skeleton: sio.Skeleton for creating instances.
        videos: List of sio.Video objects.
        eval_frequency: Run evaluation every N epochs (default: 1).
        oks_stddev: OKS standard deviation (default: 0.025).
        oks_scale: Optional OKS scale override.
        metrics_to_log: List of metric keys to log.
    """

    # (key in `metrics_to_log`, logged name, metrics group, field, drop if NaN)
    _METRIC_SPECS = (
        ("mOKS", "val_mOKS", "mOKS", "mOKS", False),
        ("oks_voc.mAP", "val_oks_voc_mAP", "voc_metrics", "oks_voc.mAP", False),
        ("oks_voc.mAR", "val_oks_voc_mAR", "voc_metrics", "oks_voc.mAR", False),
        ("avg_distance", "val_avg_distance", "distance_metrics", "avg", True),
        ("p50_distance", "val_p50_distance", "distance_metrics", "p50", True),
        ("mPCK", "val_mPCK", "pck_metrics", "mPCK", False),
        (
            "visibility_precision",
            "val_visibility_precision",
            "visibility_metrics",
            "precision",
            True,
        ),
        (
            "visibility_recall",
            "val_visibility_recall",
            "visibility_metrics",
            "recall",
            True,
        ),
    )

    def __init__(
        self,
        skeleton: "sio.Skeleton",
        videos: list,
        eval_frequency: int = 1,
        oks_stddev: float = 0.025,
        oks_scale: Optional[float] = None,
        metrics_to_log: Optional[list] = None,
    ):
        """Initialize the callback.

        Args:
            skeleton: sio.Skeleton for creating instances.
            videos: List of sio.Video objects.
            eval_frequency: Run evaluation every N epochs (default: 1).
            oks_stddev: OKS standard deviation (default: 0.025).
            oks_scale: Optional OKS scale override.
            metrics_to_log: List of metric keys to log. If None (or empty), a
                default set is used: "mOKS", "oks_voc.mAP", "oks_voc.mAR",
                "avg_distance", "p50_distance", "mPCK",
                "visibility_precision" and "visibility_recall". Keys outside
                this set are ignored.
        """
        super().__init__()
        self.skeleton = skeleton
        self.videos = videos
        self.eval_frequency = eval_frequency
        self.oks_stddev = oks_stddev
        self.oks_scale = oks_scale
        self.metrics_to_log = metrics_to_log or [
            "mOKS",
            "oks_voc.mAP",
            "oks_voc.mAR",
            "avg_distance",
            "p50_distance",
            "mPCK",
            "visibility_precision",
            "visibility_recall",
        ]

    @staticmethod
    def _reset_collection(pl_module):
        """Stop prediction collection and drop everything accumulated so far."""
        pl_module._collect_val_predictions = False
        pl_module.val_predictions = []
        pl_module.val_ground_truth = []

    def on_validation_epoch_start(self, trainer, pl_module):
        """Enable prediction collection at the start of validation.

        Skip during sanity check to avoid inference issues.
        """
        if trainer.sanity_checking:
            return
        pl_module._collect_val_predictions = True

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run evaluation and log metrics at end of validation epoch.

        Evaluation is skipped when the epoch does not fall on
        `eval_frequency`, on non-zero ranks, and when nothing was collected.
        Any exception raised while building labels or evaluating is logged as
        a warning instead of interrupting training.
        """
        import sleap_io as sio
        import numpy as np
        from sleap_nn.evaluation import Evaluator

        try:
            # Check frequency (epoch is 0-indexed, so add 1)
            if (trainer.current_epoch + 1) % self.eval_frequency != 0:
                return

            # Only run on rank 0 for distributed training
            if not trainer.is_global_zero:
                return

            # Check if we have predictions
            if not pl_module.val_predictions or not pl_module.val_ground_truth:
                logger.warning("No predictions collected for epoch-end evaluation")
                return

            try:
                # Build sio.Labels from accumulated predictions and ground truth
                pred_labels = self._build_pred_labels(
                    pl_module.val_predictions, sio, np
                )
                gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)

                # Check if we have valid frames to evaluate
                if len(pred_labels) == 0:
                    logger.warning(
                        "No valid predictions for epoch-end evaluation "
                        "(all predictions may be empty or NaN)"
                    )
                    return

                # Run evaluation
                evaluator = Evaluator(
                    ground_truth_instances=gt_labels,
                    predicted_instances=pred_labels,
                    oks_stddev=self.oks_stddev,
                    oks_scale=self.oks_scale,
                    user_labels_only=False,  # All validation frames are "user" frames
                )
                metrics = evaluator.evaluate()

                # Log to WandB
                self._log_metrics(trainer, metrics, trainer.current_epoch)

                logger.info(
                    f"Epoch {trainer.current_epoch} evaluation: "
                    f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
                    f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
                )

            except Exception as e:
                # Best effort: a failed evaluation must not abort training.
                logger.warning(f"Epoch-end evaluation failed: {e}")
        finally:
            # Always reset, including on the early returns above; otherwise
            # frames collected during a skipped epoch (or on a non-zero rank)
            # would pile up and contaminate the next evaluation.
            self._reset_collection(pl_module)

    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
        """Convert prediction dicts to sio.Labels.

        Args:
            predictions: List of dicts with keys "pred_peaks", "pred_scores",
                "video_idx" and "frame_idx". "pred_scores" may be None.
            sio: The imported `sleap_io` module.
            np: The imported `numpy` module.

        Returns:
            A `sio.Labels` holding one `LabeledFrame` per prediction dict that
            yielded at least one instance with a non-NaN point.
        """
        labeled_frames = []
        for pred in predictions:
            pred_peaks = pred["pred_peaks"]
            pred_scores = pred["pred_scores"]

            # Handle NaN/missing predictions
            if pred_peaks is None or (
                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
            ):
                continue

            # Handle multi-instance predictions (bottomup)
            if len(pred_peaks.shape) == 2:
                # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
                pred_peaks = pred_peaks.reshape(1, -1, 2)
                # Scores are optional (see the fallbacks below), so only
                # reshape them when they are actually present.
                if pred_scores is not None:
                    pred_scores = pred_scores.reshape(1, -1)

            instances = []
            for inst_idx in range(len(pred_peaks)):
                inst_points = pred_peaks[inst_idx]
                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None

                # Skip if all NaN
                if np.isnan(inst_points).all():
                    continue

                if inst_scores is not None:
                    point_scores = inst_scores
                    score = float(np.nanmean(inst_scores))
                else:
                    # No scores available: treat every point as fully confident.
                    point_scores = np.ones(len(inst_points))
                    score = 1.0

                instances.append(
                    sio.PredictedInstance.from_numpy(
                        points_data=inst_points,
                        skeleton=self.skeleton,
                        point_scores=point_scores,
                        score=score,
                    )
                )

            if instances:
                labeled_frames.append(
                    sio.LabeledFrame(
                        video=self.videos[pred["video_idx"]],
                        frame_idx=pred["frame_idx"],
                        instances=instances,
                    )
                )

        return sio.Labels(
            videos=self.videos,
            skeletons=[self.skeleton],
            labeled_frames=labeled_frames,
        )

    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
        """Convert ground truth dicts to sio.Labels.

        Args:
            ground_truth: List of dicts with keys "gt_instances",
                "num_instances", "video_idx" and "frame_idx".
            sio: The imported `sleap_io` module.
            np: The imported `numpy` module.

        Returns:
            A `sio.Labels` holding one `LabeledFrame` per ground truth dict
            that yielded at least one instance with a non-NaN point.
        """
        labeled_frames = []
        for gt in ground_truth:
            gt_instances = gt["gt_instances"]

            # Handle shape variations
            if len(gt_instances.shape) == 2:
                # (n_nodes, 2) -> (1, n_nodes, 2)
                gt_instances = gt_instances.reshape(1, -1, 2)

            instances = []
            # Only the first `num_instances` rows are considered.
            for i in range(min(gt["num_instances"], len(gt_instances))):
                inst_data = gt_instances[i]
                if np.isnan(inst_data).all():
                    continue
                instances.append(
                    sio.Instance.from_numpy(
                        points_data=inst_data,
                        skeleton=self.skeleton,
                    )
                )

            if instances:
                labeled_frames.append(
                    sio.LabeledFrame(
                        video=self.videos[gt["video_idx"]],
                        frame_idx=gt["frame_idx"],
                        instances=instances,
                    )
                )

        return sio.Labels(
            videos=self.videos,
            skeletons=[self.skeleton],
            labeled_frames=labeled_frames,
        )

    def _log_metrics(self, trainer, metrics: dict, epoch: int):
        """Log evaluation metrics to WandB.

        Does nothing when the trainer has no `WandbLogger`. Distance and
        visibility metrics are omitted when NaN; the other selected metrics
        are logged as-is.

        Args:
            trainer: Trainer whose `loggers` are searched for a `WandbLogger`.
            metrics: Nested metrics dict returned by `Evaluator.evaluate()`.
            epoch: Epoch number stored under the "epoch" key.
        """
        import numpy as np
        from lightning.pytorch.loggers import WandbLogger

        # Get WandB logger
        wandb_logger = next(
            (log for log in trainer.loggers if isinstance(log, WandbLogger)), None
        )
        if wandb_logger is None:
            return

        log_dict = {"epoch": epoch}

        # Extract key metrics with consistent naming
        for key, log_name, group, field_name, drop_nan in self._METRIC_SPECS:
            if key not in self.metrics_to_log:
                continue
            val = metrics[group][field_name]
            if drop_nan and np.isnan(val):
                continue
            log_dict[log_name] = val

        wandb_logger.experiment.log(log_dict, commit=False)