sleap-nn 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (254)
  1. sleap_nn-0.0.1/.claude/commands/coverage.md +14 -0
  2. sleap_nn-0.0.1/.claude/commands/lint.md +9 -0
  3. sleap_nn-0.0.1/.claude/commands/pr-description.md +7 -0
  4. sleap_nn-0.0.1/.dockerignore +15 -0
  5. sleap_nn-0.0.1/.github/workflows/build.yml +45 -0
  6. sleap_nn-0.0.1/.github/workflows/ci.yml +141 -0
  7. sleap_nn-0.0.1/.github/workflows/codespell.yml +25 -0
  8. sleap_nn-0.0.1/.github/workflows/docs.yml +59 -0
  9. sleap_nn-0.0.1/.gitignore +173 -0
  10. sleap_nn-0.0.1/CLAUDE.md +81 -0
  11. sleap_nn-0.0.1/CONTRIBUTING.md +163 -0
  12. sleap_nn-0.0.1/LICENSE +674 -0
  13. sleap_nn-0.0.1/PKG-INFO +176 -0
  14. sleap_nn-0.0.1/README.md +112 -0
  15. sleap_nn-0.0.1/codecov.yml +18 -0
  16. sleap_nn-0.0.1/docs/assets/favicon.ico +0 -0
  17. sleap_nn-0.0.1/docs/assets/sleap-logo.png +0 -0
  18. sleap_nn-0.0.1/docs/config.md +863 -0
  19. sleap_nn-0.0.1/docs/example_notebooks.md +75 -0
  20. sleap_nn-0.0.1/docs/index.md +155 -0
  21. sleap_nn-0.0.1/docs/inference.md +272 -0
  22. sleap_nn-0.0.1/docs/installation.md +98 -0
  23. sleap_nn-0.0.1/docs/models.md +67 -0
  24. sleap_nn-0.0.1/docs/sample_configs/config_bottomup_convnext.yaml +145 -0
  25. sleap_nn-0.0.1/docs/sample_configs/config_bottomup_unet.yaml +144 -0
  26. sleap_nn-0.0.1/docs/sample_configs/config_centroid_swint.yaml +148 -0
  27. sleap_nn-0.0.1/docs/sample_configs/config_centroid_unet.yaml +146 -0
  28. sleap_nn-0.0.1/docs/sample_configs/config_multi_class_bottomup_unet.yaml +145 -0
  29. sleap_nn-0.0.1/docs/sample_configs/config_single_instance_unet.yaml +138 -0
  30. sleap_nn-0.0.1/docs/sample_configs/config_topdown_centered_instance_unet.yaml +139 -0
  31. sleap_nn-0.0.1/docs/sample_configs/config_topdown_multi_class_centered_instance_unet.yaml +147 -0
  32. sleap_nn-0.0.1/docs/step_by_step_tutorial.md +522 -0
  33. sleap_nn-0.0.1/docs/training.md +214 -0
  34. sleap_nn-0.0.1/example_notebooks/README.md +67 -0
  35. sleap_nn-0.0.1/example_notebooks/augmentation_guide.py +713 -0
  36. sleap_nn-0.0.1/example_notebooks/receptive_field_guide.py +344 -0
  37. sleap_nn-0.0.1/example_notebooks/training_demo.py +912 -0
  38. sleap_nn-0.0.1/mkdocs.yml +93 -0
  39. sleap_nn-0.0.1/pyproject.toml +158 -0
  40. sleap_nn-0.0.1/scripts/gen_changelog.py +74 -0
  41. sleap_nn-0.0.1/scripts/gen_ref_pages.py +35 -0
  42. sleap_nn-0.0.1/setup.cfg +4 -0
  43. sleap_nn-0.0.1/sleap_nn/.DS_Store +0 -0
  44. sleap_nn-0.0.1/sleap_nn/__init__.py +35 -0
  45. sleap_nn-0.0.1/sleap_nn/architectures/__init__.py +1 -0
  46. sleap_nn-0.0.1/sleap_nn/architectures/common.py +107 -0
  47. sleap_nn-0.0.1/sleap_nn/architectures/convnext.py +356 -0
  48. sleap_nn-0.0.1/sleap_nn/architectures/encoder_decoder.py +696 -0
  49. sleap_nn-0.0.1/sleap_nn/architectures/heads.py +596 -0
  50. sleap_nn-0.0.1/sleap_nn/architectures/model.py +197 -0
  51. sleap_nn-0.0.1/sleap_nn/architectures/swint.py +385 -0
  52. sleap_nn-0.0.1/sleap_nn/architectures/unet.py +294 -0
  53. sleap_nn-0.0.1/sleap_nn/architectures/utils.py +66 -0
  54. sleap_nn-0.0.1/sleap_nn/cli.py +402 -0
  55. sleap_nn-0.0.1/sleap_nn/config/__init__.py +1 -0
  56. sleap_nn-0.0.1/sleap_nn/config/data_config.py +471 -0
  57. sleap_nn-0.0.1/sleap_nn/config/get_config.py +866 -0
  58. sleap_nn-0.0.1/sleap_nn/config/model_config.py +1217 -0
  59. sleap_nn-0.0.1/sleap_nn/config/trainer_config.py +578 -0
  60. sleap_nn-0.0.1/sleap_nn/config/training_job_config.py +143 -0
  61. sleap_nn-0.0.1/sleap_nn/config/utils.py +156 -0
  62. sleap_nn-0.0.1/sleap_nn/data/__init__.py +1 -0
  63. sleap_nn-0.0.1/sleap_nn/data/augmentation.py +283 -0
  64. sleap_nn-0.0.1/sleap_nn/data/confidence_maps.py +166 -0
  65. sleap_nn-0.0.1/sleap_nn/data/custom_datasets.py +2168 -0
  66. sleap_nn-0.0.1/sleap_nn/data/edge_maps.py +323 -0
  67. sleap_nn-0.0.1/sleap_nn/data/identity.py +137 -0
  68. sleap_nn-0.0.1/sleap_nn/data/instance_centroids.py +61 -0
  69. sleap_nn-0.0.1/sleap_nn/data/instance_cropping.py +153 -0
  70. sleap_nn-0.0.1/sleap_nn/data/normalization.py +45 -0
  71. sleap_nn-0.0.1/sleap_nn/data/providers.py +340 -0
  72. sleap_nn-0.0.1/sleap_nn/data/resizing.py +143 -0
  73. sleap_nn-0.0.1/sleap_nn/data/utils.py +149 -0
  74. sleap_nn-0.0.1/sleap_nn/evaluation.py +797 -0
  75. sleap_nn-0.0.1/sleap_nn/inference/__init__.py +1 -0
  76. sleap_nn-0.0.1/sleap_nn/inference/bottomup.py +305 -0
  77. sleap_nn-0.0.1/sleap_nn/inference/identity.py +173 -0
  78. sleap_nn-0.0.1/sleap_nn/inference/paf_grouping.py +1527 -0
  79. sleap_nn-0.0.1/sleap_nn/inference/peak_finding.py +338 -0
  80. sleap_nn-0.0.1/sleap_nn/inference/predictors.py +3194 -0
  81. sleap_nn-0.0.1/sleap_nn/inference/single_instance.py +96 -0
  82. sleap_nn-0.0.1/sleap_nn/inference/topdown.py +814 -0
  83. sleap_nn-0.0.1/sleap_nn/inference/utils.py +138 -0
  84. sleap_nn-0.0.1/sleap_nn/legacy_models.py +529 -0
  85. sleap_nn-0.0.1/sleap_nn/predict.py +767 -0
  86. sleap_nn-0.0.1/sleap_nn/tracking/__init__.py +1 -0
  87. sleap_nn-0.0.1/sleap_nn/tracking/candidates/__init__.py +1 -0
  88. sleap_nn-0.0.1/sleap_nn/tracking/candidates/fixed_window.py +154 -0
  89. sleap_nn-0.0.1/sleap_nn/tracking/candidates/local_queues.py +173 -0
  90. sleap_nn-0.0.1/sleap_nn/tracking/track_instance.py +47 -0
  91. sleap_nn-0.0.1/sleap_nn/tracking/tracker.py +812 -0
  92. sleap_nn-0.0.1/sleap_nn/tracking/utils.py +100 -0
  93. sleap_nn-0.0.1/sleap_nn/train.py +522 -0
  94. sleap_nn-0.0.1/sleap_nn/training/__init__.py +1 -0
  95. sleap_nn-0.0.1/sleap_nn/training/callbacks.py +352 -0
  96. sleap_nn-0.0.1/sleap_nn/training/lightning_modules.py +1780 -0
  97. sleap_nn-0.0.1/sleap_nn/training/losses.py +60 -0
  98. sleap_nn-0.0.1/sleap_nn/training/model_trainer.py +940 -0
  99. sleap_nn-0.0.1/sleap_nn/training/utils.py +193 -0
  100. sleap_nn-0.0.1/sleap_nn.egg-info/PKG-INFO +176 -0
  101. sleap_nn-0.0.1/sleap_nn.egg-info/SOURCES.txt +252 -0
  102. sleap_nn-0.0.1/sleap_nn.egg-info/dependency_links.txt +1 -0
  103. sleap_nn-0.0.1/sleap_nn.egg-info/entry_points.txt +5 -0
  104. sleap_nn-0.0.1/sleap_nn.egg-info/requires.txt +54 -0
  105. sleap_nn-0.0.1/sleap_nn.egg-info/top_level.txt +1 -0
  106. sleap_nn-0.0.1/tests/__init__.py +1 -0
  107. sleap_nn-0.0.1/tests/architectures/test_architecture_utils.py +42 -0
  108. sleap_nn-0.0.1/tests/architectures/test_common.py +14 -0
  109. sleap_nn-0.0.1/tests/architectures/test_convnext.py +212 -0
  110. sleap_nn-0.0.1/tests/architectures/test_encoder_decoder.py +252 -0
  111. sleap_nn-0.0.1/tests/architectures/test_heads.py +393 -0
  112. sleap_nn-0.0.1/tests/architectures/test_model.py +700 -0
  113. sleap_nn-0.0.1/tests/architectures/test_swint.py +238 -0
  114. sleap_nn-0.0.1/tests/architectures/test_unet.py +191 -0
  115. sleap_nn-0.0.1/tests/assets/datasets/centered_pair_small.mp4 +0 -0
  116. sleap_nn-0.0.1/tests/assets/datasets/minimal_instance.pkg.slp +0 -0
  117. sleap_nn-0.0.1/tests/assets/datasets/small_robot.mp4 +0 -0
  118. sleap_nn-0.0.1/tests/assets/datasets/small_robot_minimal.slp +0 -0
  119. sleap_nn-0.0.1/tests/assets/inference/minimal_bboxes.pt +0 -0
  120. sleap_nn-0.0.1/tests/assets/inference/minimal_cms.pt +0 -0
  121. sleap_nn-0.0.1/tests/assets/legacy_models/get_dummy_activations.py +135 -0
  122. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/best_model.h5 +0 -0
  123. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/dummy_activations.h5 +0 -0
  124. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/initial_config.json +155 -0
  125. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.bottomup_multiclass/training_config.json +218 -0
  126. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/best_model.h5 +0 -0
  127. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/dummy_activations.h5 +0 -0
  128. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/initial_config.json +173 -0
  129. sleap_nn-0.0.1/tests/assets/legacy_models/min_tracks_2node.UNet.topdown_multiclass/training_config.json +240 -0
  130. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/best_model.h5 +0 -0
  131. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/dummy_activations.h5 +0 -0
  132. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/initial_config.json +154 -0
  133. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.train.slp +0 -0
  134. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_gt.val.slp +0 -0
  135. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.train.slp +0 -0
  136. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/labels_pr.val.slp +0 -0
  137. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.train.npz +0 -0
  138. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/metrics.val.npz +0 -0
  139. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_config.json +219 -0
  140. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.bottomup/training_log.csv +11 -0
  141. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/best_model.h5 +0 -0
  142. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/dummy_activations.h5 +0 -0
  143. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/initial_config.json +146 -0
  144. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.train.slp +0 -0
  145. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_gt.val.slp +0 -0
  146. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.train.slp +0 -0
  147. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/labels_pr.val.slp +0 -0
  148. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.train.npz +0 -0
  149. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/metrics.val.npz +0 -0
  150. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_config.json +206 -0
  151. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centered_instance/training_log.csv +11 -0
  152. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/best_model.h5 +0 -0
  153. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/dummy_activations.h5 +0 -0
  154. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/initial_config.json +145 -0
  155. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.train.slp +0 -0
  156. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_gt.val.slp +0 -0
  157. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.train.slp +0 -0
  158. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/labels_pr.val.slp +0 -0
  159. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.train.npz +0 -0
  160. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/metrics.val.npz +0 -0
  161. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_config.json +202 -0
  162. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_instance.UNet.centroid/training_log.csv +25 -0
  163. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_robot.UNet.single_instance/best_model.h5 +0 -0
  164. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_robot.UNet.single_instance/dummy_activations.h5 +0 -0
  165. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_robot.UNet.single_instance/initial_config.json +147 -0
  166. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.train.slp +0 -0
  167. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_robot.UNet.single_instance/labels_gt.val.slp +0 -0
  168. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_config.json +207 -0
  169. sleap_nn-0.0.1/tests/assets/legacy_models/minimal_robot.UNet.single_instance/training_log.csv +13 -0
  170. sleap_nn-0.0.1/tests/assets/legacy_sleap_json_configs/bottomup_multiclass_training_config.json +218 -0
  171. sleap_nn-0.0.1/tests/assets/legacy_sleap_json_configs/bottomup_training_config.json +219 -0
  172. sleap_nn-0.0.1/tests/assets/legacy_sleap_json_configs/centered_instance_training_config.json +206 -0
  173. sleap_nn-0.0.1/tests/assets/legacy_sleap_json_configs/centered_instance_with_scaling_training_config.json +226 -0
  174. sleap_nn-0.0.1/tests/assets/legacy_sleap_json_configs/centroid_training_config.json +202 -0
  175. sleap_nn-0.0.1/tests/assets/legacy_sleap_json_configs/single_instance_training_config.json +207 -0
  176. sleap_nn-0.0.1/tests/assets/legacy_sleap_json_configs/topdown_training_config.json +240 -0
  177. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_bottomup/best.ckpt +0 -0
  178. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_bottomup/initial_config.yaml +169 -0
  179. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_bottomup/labels_train_gt_0.slp +0 -0
  180. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_bottomup/labels_val_gt_0.slp +0 -0
  181. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_bottomup/training_config.yaml +169 -0
  182. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_bottomup/training_log.csv +12 -0
  183. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centered_instance/best.ckpt +0 -0
  184. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centered_instance/initial_config.yaml +165 -0
  185. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_train_gt_0.slp +0 -0
  186. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centered_instance/labels_val_gt_0.slp +0 -0
  187. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centered_instance/training_config.yaml +165 -0
  188. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centered_instance/training_log.csv +12 -0
  189. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centroid/best.ckpt +0 -0
  190. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centroid/initial_config.yaml +159 -0
  191. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centroid/labels_train_gt_0.slp +0 -0
  192. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centroid/labels_val_gt_0.slp +0 -0
  193. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centroid/training_config.yaml +159 -0
  194. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_centroid/training_log.csv +24 -0
  195. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/best.ckpt +0 -0
  196. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/initial_config.yaml +169 -0
  197. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_train_gt_0.slp +0 -0
  198. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/labels_val_gt_0.slp +0 -0
  199. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_config.yaml +169 -0
  200. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_bottomup/training_log.csv +202 -0
  201. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/best.ckpt +0 -0
  202. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/initial_config.yaml +174 -0
  203. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_train_gt_0.slp +0 -0
  204. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/labels_val_gt_0.slp +0 -0
  205. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_config.yaml +174 -0
  206. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_multiclass_centered_instance/training_log.csv +52 -0
  207. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_single_instance/best.ckpt +0 -0
  208. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_single_instance/initial_config.yaml +162 -0
  209. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_single_instance/labels_train_gt_0.slp +0 -0
  210. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_single_instance/labels_val_gt_0.slp +0 -0
  211. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_single_instance/training_config.yaml +162 -0
  212. sleap_nn-0.0.1/tests/assets/model_ckpts/minimal_instance_single_instance/training_log.csv +102 -0
  213. sleap_nn-0.0.1/tests/config/test_config_utils.py +42 -0
  214. sleap_nn-0.0.1/tests/config/test_data_config.py +210 -0
  215. sleap_nn-0.0.1/tests/config/test_model_config.py +189 -0
  216. sleap_nn-0.0.1/tests/config/test_trainer_config.py +311 -0
  217. sleap_nn-0.0.1/tests/config/test_training_job_config.py +328 -0
  218. sleap_nn-0.0.1/tests/conftest.py +7 -0
  219. sleap_nn-0.0.1/tests/data/test_augmentation.py +56 -0
  220. sleap_nn-0.0.1/tests/data/test_confmaps.py +43 -0
  221. sleap_nn-0.0.1/tests/data/test_custom_datasets.py +1265 -0
  222. sleap_nn-0.0.1/tests/data/test_edge_maps.py +206 -0
  223. sleap_nn-0.0.1/tests/data/test_identity.py +53 -0
  224. sleap_nn-0.0.1/tests/data/test_instance_centroids.py +34 -0
  225. sleap_nn-0.0.1/tests/data/test_instance_cropping.py +58 -0
  226. sleap_nn-0.0.1/tests/data/test_normalization.py +22 -0
  227. sleap_nn-0.0.1/tests/data/test_providers.py +258 -0
  228. sleap_nn-0.0.1/tests/data/test_resizing.py +69 -0
  229. sleap_nn-0.0.1/tests/data/test_utils.py +122 -0
  230. sleap_nn-0.0.1/tests/fixtures/__init__.py +5 -0
  231. sleap_nn-0.0.1/tests/fixtures/datasets.py +178 -0
  232. sleap_nn-0.0.1/tests/fixtures/inference.py +29 -0
  233. sleap_nn-0.0.1/tests/fixtures/legacy_models.py +159 -0
  234. sleap_nn-0.0.1/tests/fixtures/legacy_sleap_json_configs.py +67 -0
  235. sleap_nn-0.0.1/tests/fixtures/model_ckpts.py +50 -0
  236. sleap_nn-0.0.1/tests/inference/__init__.py +1 -0
  237. sleap_nn-0.0.1/tests/inference/test_bottomup.py +183 -0
  238. sleap_nn-0.0.1/tests/inference/test_paf_grouping.py +606 -0
  239. sleap_nn-0.0.1/tests/inference/test_peak_finding.py +383 -0
  240. sleap_nn-0.0.1/tests/inference/test_predictors.py +980 -0
  241. sleap_nn-0.0.1/tests/inference/test_single_instance.py +84 -0
  242. sleap_nn-0.0.1/tests/inference/test_topdown.py +460 -0
  243. sleap_nn-0.0.1/tests/inference/test_utils.py +40 -0
  244. sleap_nn-0.0.1/tests/test_cli.py +193 -0
  245. sleap_nn-0.0.1/tests/test_evaluation.py +561 -0
  246. sleap_nn-0.0.1/tests/test_legacy_models.py +1125 -0
  247. sleap_nn-0.0.1/tests/test_predict.py +1069 -0
  248. sleap_nn-0.0.1/tests/test_train.py +640 -0
  249. sleap_nn-0.0.1/tests/test_version.py +5 -0
  250. sleap_nn-0.0.1/tests/tracking/candidates/test_fixed_window.py +61 -0
  251. sleap_nn-0.0.1/tests/tracking/candidates/test_local_queues.py +64 -0
  252. sleap_nn-0.0.1/tests/tracking/test_tracker.py +428 -0
  253. sleap_nn-0.0.1/tests/training/test_lightning_modules.py +661 -0
  254. sleap_nn-0.0.1/tests/training/test_model_trainer.py +1169 -0
@@ -0,0 +1,14 @@
1
+ Run tests with coverage.
2
+
3
+ Command to run:
4
+ ```
5
+ pytest -q --maxfail=1 --cov=sleap_nn --cov-branch tests/ && rm -f .coverage.* && python -m coverage annotate
6
+ ```
7
+
8
+ The summary is printed to the terminal, and line-by-line coverage is written to a file next to each module, named `{module_name}.py,cover`.
9
+
10
+ If you are working on a PR, figure out which files were changed and look for coverage specifically in those. If you don't know which files to look for coverage in, use this:
11
+
12
+ ```
13
+ git diff --name-only $(git merge-base origin/main HEAD) | jq -R . | jq -s .
14
+ ```
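Once `coverage annotate` has written the `,cover` files, the uncovered lines in the changed files can be listed programmatically. The sketch below assumes coverage.py's annotate convention: `> ` marks executed lines, `! ` marks lines that were not executed, and `- ` marks excluded lines. It also assumes the changed-file list is the JSON array printed by the `git diff ... | jq` pipeline above. The helpers `missing_lines` and `report_changed` are hypothetical and are not part of sleap-nn.

```python
import json
from pathlib import Path


def missing_lines(cover_text: str) -> list[int]:
    """Return 1-based line numbers marked as not executed ('!') in a ,cover file."""
    missing = []
    for lineno, line in enumerate(cover_text.splitlines(), start=1):
        # coverage annotate writes one output line per source line, prefixed by a marker.
        if line.startswith("!"):
            missing.append(lineno)
    return missing


def report_changed(changed_json: str, root: Path = Path(".")) -> dict[str, list[int]]:
    """Map each changed .py file to its uncovered lines, if its ,cover file exists."""
    report = {}
    for name in json.loads(changed_json):
        if not name.endswith(".py"):
            continue  # only Python modules get annotated
        cover_path = root / (name + ",cover")
        if cover_path.exists():
            report[name] = missing_lines(cover_path.read_text())
    return report
```

Feed `report_changed` the decoded stdout of the `git diff --name-only ... | jq -R . | jq -s .` pipeline and run it from the repository root. Files without a `,cover` companion (tests, docs, configs) are skipped.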
@@ -0,0 +1,9 @@
1
+ Run linting with `black` and `ruff`.
2
+
3
+ Command:
4
+
5
+ ```
6
+ black sleap_nn tests && ruff check --fix sleap_nn tests
7
+ ```
8
+
9
+ Then manually fix any remaining errors which cannot be automatically fixed by ruff.
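Because of the `&&`, `ruff` runs only if `black` succeeds. The same short-circuiting chain is sketched below in Python, assuming both tools are on `PATH`. `run_chain` and `LINT_COMMANDS` are hypothetical names. The runner can be swapped out so the control flow can be checked without the tools installed. A non-zero exit from `ruff check --fix` typically means violations remain that need the manual fixes described above.

```python
import subprocess
from typing import Callable, Sequence

# The two commands from lint.md, as argv lists.
LINT_COMMANDS = [
    ["black", "sleap_nn", "tests"],
    ["ruff", "check", "--fix", "sleap_nn", "tests"],
]


def run_chain(
    commands: Sequence[Sequence[str]],
    runner: Callable[[Sequence[str]], int] = lambda cmd: subprocess.run(cmd).returncode,
) -> int:
    """Run commands in order and stop at the first non-zero exit code, like `&&`."""
    for cmd in commands:
        code = runner(cmd)
        if code != 0:
            return code
    return 0
```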
@@ -0,0 +1,7 @@
1
+ Update PR description.
2
+
3
+ Use the `gh` CLI to fetch the current PR description, then update it with a comprehensive description of the changes made in this PR.
4
+
5
+ If there is an associated issue (linked in the PR metadata or mentioned in the PR description), then use the `gh` CLI to fetch that too to contextualize the work done in the PR.
6
+
7
+ Include a summary, example usage (for enhancements), API changes, and other notes for future consideration (including reasoning behind design decisions).
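The command asks for several sections: a summary, example usage for enhancements, API changes, and notes on design decisions. The sketch below shows one way to assemble them into a Markdown body. The function name, the headings, and the "Related issue" line are illustrative assumptions, not a format the command prescribes. The result could then be written to a file and passed to `gh pr edit --body-file`.

```python
from typing import Optional

FENCE = "`" * 3  # built at runtime so this example can sit inside a Markdown fence


def build_pr_description(
    summary: str,
    api_changes: list[str],
    notes: list[str],
    example_usage: Optional[str] = None,
    issue: Optional[int] = None,
) -> str:
    """Assemble a Markdown PR body from the sections listed in pr-description.md."""
    sections = [f"## Summary\n\n{summary.strip()}"]
    if issue is not None:
        # Reference the associated issue so reviewers can follow the context.
        sections.append(f"Related issue: #{issue}")
    if example_usage:
        # Only enhancements need an example; pass None for pure fixes.
        sections.append(f"## Example usage\n\n{FENCE}python\n{example_usage.strip()}\n{FENCE}")
    if api_changes:
        sections.append("## API changes\n\n" + "\n".join(f"- {c}" for c in api_changes))
    if notes:
        sections.append("## Notes\n\n" + "\n".join(f"- {n}" for n in notes))
    return "\n\n".join(sections) + "\n"
```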
@@ -0,0 +1,15 @@
1
+ README.md
2
+ docs/
3
+ *.egg-info/
4
+
5
+
6
+ # Test artifacts
7
+ tests/
8
+ *.pytest_cache/
9
+
10
+ *.ruff_cache
11
+ codecov.yml
12
+
13
+ # Git files
14
+ .github/
15
+ .gitignore
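Docker matches `.dockerignore` patterns against paths relative to the root of the build context. As a result, `tests/` excludes the top-level `tests` directory and everything under it, while `README.md` excludes only the root README. The sketch below roughly approximates that behavior with segment-wise glob matching. It deliberately ignores the `**` wildcard and `!` exception rules, both of which Docker also supports. `parse_patterns` and `is_excluded` are hypothetical helpers, not Docker's implementation.

```python
from fnmatch import fnmatchcase

# The .dockerignore from this package, verbatim (comments and blank lines included).
DOCKERIGNORE = """\
README.md
docs/
*.egg-info/

# Test artifacts
tests/
*.pytest_cache/

*.ruff_cache
codecov.yml

# Git files
.github/
.gitignore
"""


def parse_patterns(text: str) -> list[list[str]]:
    """Split each non-comment pattern into path segments (trailing '/' dropped)."""
    patterns = []
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        patterns.append(line.strip("/").split("/"))
    return patterns


def is_excluded(path: str, patterns: list[list[str]]) -> bool:
    """True if the path, or one of its parent directories, matches a pattern."""
    parts = path.strip("/").split("/")
    for pat in patterns:
        # Compare segment by segment so '*' never crosses a '/' boundary.
        if len(parts) >= len(pat) and all(fnmatchcase(s, p) for s, p in zip(parts, pat)):
            return True
    return False
```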
@@ -0,0 +1,45 @@
1
+ # Package builds
2
+ name: Build
3
+
4
+ on:
5
+ release:
6
+ types:
7
+ - published
8
+ workflow_dispatch:
9
+ inputs:
10
+ testpypi:
11
+ description: 'Publish to TestPyPI'
12
+ required: false
13
+ type: boolean
14
+ default: false
15
+
16
+ jobs:
17
+ pypi:
18
+ name: PyPI Wheel
19
+ runs-on: ubuntu-latest
20
+ permissions:
21
+ id-token: write # Required for PyPI trusted publishing
22
+ steps:
23
+
24
+ - name: Checkout repo
25
+ uses: actions/checkout@v4
26
+
27
+ - name: Setup UV
28
+ uses: astral-sh/setup-uv@v6
29
+ with:
30
+ python-version: "3.12"
31
+
32
+ - name: Build distributions
33
+ run: |
34
+ uv build
35
+
36
+ - name: Determine index and publish
37
+ run: |
38
+ # Manual override for TestPyPI
39
+ if [[ "${{ github.event.inputs.testpypi }}" == "true" ]]; then
40
+ echo "Manual TestPyPI publishing..."
41
+ uv publish --index testpypi --trusted-publishing always
42
+ else
43
+ echo "Publishing to PyPI..."
44
+ uv publish
45
+ fi
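The publish step picks its target from the `testpypi` input of a manual run. A `release` event never sets that input, so the expression renders as an empty string and the step falls through to PyPI. A manual run with the box left unchecked also publishes to PyPI. The function below mirrors that branching purely as an illustration; `publish_argv` is hypothetical. `--index testpypi` assumes an index named `testpypi` is defined in the project's uv configuration in `pyproject.toml`, which is not shown here. On the PyPI branch, uv checks for trusted publishing by default when it runs in GitHub Actions, which the `id-token: write` permission enables.

```python
def publish_argv(testpypi_input: str) -> list[str]:
    """Mirror the shell `if` in build.yml and return the `uv publish` invocation."""
    # Workflow inputs arrive as strings; on `release` events the value is "".
    if testpypi_input == "true":
        return ["uv", "publish", "--index", "testpypi", "--trusted-publishing", "always"]
    return ["uv", "publish"]
```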
@@ -0,0 +1,141 @@
1
+ name: CI
2
+
3
+ on:
4
+ pull_request:
5
+ types: [opened, reopened, synchronize]
6
+ paths:
7
+ - "sleap_nn/**"
8
+ - "tests/**"
9
+ - ".github/workflows/ci.yml"
10
+ - "pyproject.toml"
11
+
12
+ jobs:
13
+ lint:
14
+ name: Lint
15
+ runs-on: ubuntu-latest
16
+ steps:
17
+ - name: Checkout repo
18
+ uses: actions/checkout@v4
19
+
20
+ - name: Set up uv
21
+ uses: astral-sh/setup-uv@v5
22
+ with:
23
+ enable-cache: false
24
+
25
+ - name: Set up Python
26
+ run: uv python install 3.11
27
+
28
+ - name: Install dev dependencies and torch
29
+ run: uv sync --extra dev --extra torch-cpu
30
+
31
+ - name: Run Black
32
+ run: uv run black --check sleap_nn tests
33
+
34
+ - name: Run Ruff
35
+ run: uv run ruff check sleap_nn/
36
+
37
+ tests:
38
+ timeout-minutes: 30
39
+ strategy:
40
+ fail-fast: false
41
+ matrix:
42
+ os: ["ubuntu", "windows", "mac", "self-hosted-gpu"]
43
+ include:
44
+ - os: ubuntu
45
+ runs-on: ubuntu-latest
46
+ - os: windows
47
+ runs-on: windows-latest
48
+ - os: mac
49
+ runs-on: macos-14
50
+ - os: self-hosted-gpu
51
+ runs-on: [self-hosted, puma, gpu, 2xgpu]
52
+ python: [3.12]
53
+
54
+ name: Tests (${{ matrix.os }}, Python ${{ matrix.python }})
55
+ runs-on: ${{ matrix.runs-on }}
56
+
57
+ steps:
58
+ - name: Checkout repo
59
+ uses: actions/checkout@v4
60
+
61
+ - name: Set up uv
62
+ uses: astral-sh/setup-uv@v5
63
+ with:
64
+ enable-cache: false
65
+
66
+ - name: Set up Python (non-self-hosted GPU)
67
+ if: matrix.os != 'self-hosted-gpu'
68
+ run: uv python install ${{ matrix.python }}
69
+
70
+ - name: Install dev dependencies and torch (self-hosted GPU)
71
+ if: matrix.os == 'self-hosted-gpu'
72
+ run: uv sync --extra dev --extra torch-cuda128
73
+
74
+ - name: Install dev dependencies and torch (non-self-hosted GPU)
75
+ if: matrix.os != 'self-hosted-gpu'
76
+ run: uv sync --extra dev --extra torch-cpu
77
+
78
+ - name: Print environment info
79
+ run: |
80
+ echo "=== UV Environment ==="
81
+ uv run python --version
82
+ uv run python -c "import sys; print('Python executable:', sys.executable)"
83
+ echo "=== UV Environment NumPy Check ==="
84
+ uv run python -c "import numpy; print('NumPy version:', numpy.__version__); print('NumPy location:', numpy.__file__)" || echo "NumPy import failed in uv environment"
85
+ echo "=== CUDA Availability Check ==="
86
+ uv run python -c "
87
+ import torch
88
+ print(f'PyTorch version: {torch.__version__}')
89
+ print(f'CUDA available: {torch.cuda.is_available()}')
90
+ print(f'CUDA device count: {torch.cuda.device_count()}')
91
+ if torch.cuda.is_available():
92
+ print(f'CUDA version: {torch.version.cuda}')
93
+ print(f'Current device: {torch.cuda.current_device()}')
94
+ print(f'Device name: {torch.cuda.get_device_name(0)}')
95
+ else:
96
+ print('CUDA is not available')
97
+ " || echo "CUDA check failed"
98
+ echo "=== PIP EXECUTABLE COMPARISON ==="
99
+ uv run python -c "import subprocess; print('pip from uv run python:', subprocess.check_output(['pip', '--version']).decode().strip())" || echo "pip not found from python"
100
+ uv run pip --version || echo "uv run pip failed"
101
+ echo "=== UV pip list vs python -m pip list ==="
102
+ echo "--- uv run pip list ---"
103
+ uv run pip list | head -20
104
+ echo "--- uv run python -m pip list ---"
105
+ uv run python -m pip list | head -20
106
+ echo "=== UV ENVIRONMENT CHECK ==="
107
+ uv run python -c "import os; print('VIRTUAL_ENV:', os.environ.get('VIRTUAL_ENV', 'Not set'))"
108
+ echo "=== Import Test ==="
109
+ uv run python -c "import torch; import lightning; import kornia; print('All imports successful')" || echo "Import test failed"
110
+
111
+ - name: Check MPS backend (macOS only)
112
+ if: runner.os == 'macOS'
113
+ run: |
114
+ echo "=== macOS MPS Backend Check ==="
115
+ uv run python -c "
116
+ import torch
117
+ print(f'PyTorch version: {torch.__version__}')
118
+ print(f'MPS available: {torch.backends.mps.is_available()}')
119
+ print(f'MPS built: {torch.backends.mps.is_built()}')
120
+ if torch.backends.mps.is_available():
121
+ print('MPS backend is available and ready to use!')
122
+ device = torch.device('mps')
123
+ test_tensor = torch.randn(3, 3).to(device)
124
+ print(f'Test tensor on MPS: {test_tensor.device}')
125
+ else:
126
+ print('MPS backend is not available on this macOS system')
127
+ "
128
+
129
+ - name: Run pytest
130
+ run: |
131
+ echo "=== Final environment check before tests ==="
132
+ uv run python -c "import numpy, torch, lightning, kornia; print(f'All packages available: numpy={numpy.__version__}, torch={torch.__version__}')"
133
+ echo "=== Running pytest ==="
134
+ uv run pytest --cov=sleap_nn --cov-report=xml --durations=-1 tests/
135
+
136
+ - name: Upload coverage
137
+ uses: codecov/codecov-action@v5
138
+ with:
139
+ fail_ci_if_error: true
140
+ verbose: false
141
+ token: ${{ secrets.CODECOV_TOKEN }}
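GitHub Actions expands the `tests` matrix in two passes. It first takes the product of the `os` and `python` values. It then merges each `include` entry into every combination whose original values that entry does not contradict. Each `include` entry here adds only `runs-on` for one `os` label, so the result is four jobs, each running Python 3.12. The sketch below is a simplified model of those `include` rules. It assumes `python: [3.12]` sits at the matrix level, since the flattened YAML above hides the indentation. `expand_matrix` is hypothetical and ignores `exclude`.

```python
from itertools import product


def expand_matrix(axes: dict[str, list], include: list[dict]) -> list[dict]:
    """Simplified GitHub Actions matrix expansion: product of axes, then `include`."""
    keys = list(axes)
    combos = [dict(zip(keys, values)) for values in product(*(axes[k] for k in keys))]
    added = []
    for extra in include:
        matched = False
        for combo in combos:
            # An include entry may not overwrite an original axis value...
            if all(combo[k] == v for k, v in extra.items() if k in axes):
                # ...but it may add new keys such as `runs-on`.
                combo.update(extra)
                matched = True
        if not matched:
            added.append(dict(extra))  # no compatible combination: new job
    return combos + added


jobs = expand_matrix(
    {"os": ["ubuntu", "windows", "mac", "self-hosted-gpu"], "python": [3.12]},
    [
        {"os": "ubuntu", "runs-on": "ubuntu-latest"},
        {"os": "windows", "runs-on": "windows-latest"},
        {"os": "mac", "runs-on": "macos-14"},
        {"os": "self-hosted-gpu", "runs-on": ["self-hosted", "puma", "gpu", "2xgpu"]},
    ],
)
```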
@@ -0,0 +1,25 @@
1
+ # Codespell configuration is within pyproject.toml
2
+ ---
3
+ name: Codespell
4
+
5
+ on:
6
+ push:
7
+ branches: [main]
8
+ pull_request:
9
+ branches: [main]
10
+
11
+ permissions:
12
+ contents: read
13
+
14
+ jobs:
15
+ codespell:
16
+ name: Check for spelling errors
17
+ runs-on: ubuntu-latest
18
+
19
+ steps:
20
+ - name: Checkout
21
+ uses: actions/checkout@v4
22
+ - name: Annotate locations with typos
23
+ uses: codespell-project/codespell-problem-matcher@v1
24
+ - name: Codespell
25
+ uses: codespell-project/actions-codespell@v2
@@ -0,0 +1,59 @@
1
+ name: Docs
2
+
3
+ on:
4
+ release:
5
+ types:
6
+ - published
7
+ push:
8
+ branches:
9
+ - main
10
+ paths:
11
+ - "sleap_nn/**"
12
+ - "docs/**"
13
+ - "mkdocs.yml"
14
+ - ".github/workflows/docs.yml"
15
+
16
+ jobs:
17
+ docs:
18
+ name: Docs
19
+ runs-on: "ubuntu-latest"
20
+ permissions:
21
+ contents: write
22
+ steps:
23
+ - name: Checkout repo
24
+ uses: actions/checkout@v3
25
+ with:
26
+ fetch-depth: 0
27
+
28
+ - name: Setup Python
29
+ uses: actions/setup-python@v4
30
+ with:
31
+ python-version: "3.12"
32
+
33
+ - name: Install sleap-nn with docs dependencies
34
+ run: |
35
+ pip install -e ".[docs]"
36
+
37
+ - name: Print environment info
38
+ run: |
39
+ which python
40
+ pip freeze
41
+
42
+ - name: Setup Git user
43
+ run: |
44
+ git config --global user.name "github-actions[bot]"
45
+ git config --global user.email "github-actions[bot]@users.noreply.github.com"
46
+
47
+ - name: Build and upload docs (release)
48
+ if: ${{ github.event_name == 'release' && github.event.action == 'published' }}
49
+ env:
50
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
51
+ run: |
52
+ mike deploy --update-aliases --allow-empty --push "${{ github.event.release.tag_name }}" latest
53
+
54
+ - name: Build and upload docs (dev)
55
+ if: ${{ github.event_name == 'push' }}
56
+ env:
57
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
58
+ run: |
59
+ mike deploy --update-aliases --allow-empty --push dev
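The docs workflow publishes with `mike` in two cases. A published release deploys a version named after the release tag and moves the `latest` alias to it. A push to `main` that touches one of the listed paths redeploys the rolling `dev` version. The branching is sketched below in Python. `mike_deploy_argv` is hypothetical; the flags are copied from the workflow.

```python
from typing import Optional


def mike_deploy_argv(event_name: str, action: str = "", tag: str = "") -> Optional[list[str]]:
    """Return the `mike deploy` argv that docs.yml would run, or None if no step fires."""
    base = ["mike", "deploy", "--update-aliases", "--allow-empty", "--push"]
    if event_name == "release" and action == "published":
        # Versioned docs for the release, with the `latest` alias pointed at it.
        return base + [tag, "latest"]
    if event_name == "push":
        # Rolling docs built from main.
        return base + ["dev"]
    return None
```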
@@ -0,0 +1,173 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # ruff
147
+ .ruff_cache/
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+
165
+ .vscode/
166
+
167
+ wandb/
168
+
169
+ # Serena
170
+ .serena/
171
+
172
+ # macOS
173
+ .DS_Store
@@ -0,0 +1,81 @@
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Common Development Commands
6
+
7
+ ### Testing
8
+ - Run all tests: `pytest tests`
9
+ - Run specific test file: `pytest tests/path/to/test_file.py`
10
+ - Run with coverage: `pytest --cov=sleap_nn --cov-report=xml tests/`
11
+ - Run tests with duration info: `pytest --durations=-1 tests/`
12
+
13
+ ### Code Quality
14
+ - Format code: `black sleap_nn tests`
15
+ - Check formatting: `black --check sleap_nn tests`
16
+ - Run linter: `ruff check sleap_nn/`
17
+
18
+ ### Environment Setup
19
+ - GPU (Windows/Linux): `mamba env create -f environment.yml`
20
+ - CPU (Windows/Linux/Intel Mac): `mamba env create -f environment_cpu.yml`
21
+ - Apple Silicon (M1/M2 Mac): `mamba env create -f environment_osx-arm64.yml`
22
+ - Activate environment: `mamba activate sleap-nn`
23
+
24
+ ## Architecture Overview
25
+
26
+ sleap-nn is a PyTorch-based neural network backend for animal pose estimation. The codebase follows a modular architecture:
27
+
28
+ ### Core Components
29
+
30
+ 1. **Model Architecture** (`sleap_nn/architectures/`)
31
+ - Backbone networks: UNet, ConvNext, SwinT (via `model.py:get_backbone`)
32
+ - Head modules for different tasks: confidence maps, centroids, part affinity fields (PAFs), etc.
33
+ - Model configuration via Hydra/OmegaConf
34
+
35
+ 2. **Data Pipeline** (`sleap_nn/data/`)
36
+ - Providers for reading SLEAP files (`providers.py`)
37
+ - Data augmentation, normalization, resizing
38
+ - Confidence map and edge map generation
39
+ - Instance cropping and centroid computation
40
+
41
+ 3. **Training System** (`sleap_nn/training/`)
42
+ - Lightning-based training modules (`lightning_modules.py`)
43
+ - ModelTrainer class for orchestrating training (`model_trainer.py`)
44
+ - Custom losses and callbacks
45
+ - Configuration via `TrainingJobConfig`
46
+
47
+ 4. **Inference Pipeline** (`sleap_nn/inference/`)
48
+ - Different predictors for each model type (single instance, top-down, bottom-up)
49
+ - Peak finding and PAF grouping for multi-instance models
50
+ - Unified prediction interface via `predictors.py`
51
+
52
+ 5. **Tracking** (`sleap_nn/tracking/`)
53
+ - Instance tracking across frames
54
+ - Candidate generation with fixed windows and local queues
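Two items above are really the encode and decode halves of one idea: confidence-map generation in the data pipeline and peak finding in the inference pipeline. The plain-Python sketch below illustrates them. It rests on a few assumptions. Each keypoint is encoded as an unnormalized 2D Gaussian with a fixed `sigma`. Decoding takes either the single highest-scoring pixel (one animal) or every thresholded local maximum (several animals). The function names, `sigma`, and `threshold` defaults are hypothetical and are not sleap-nn's API. A PyTorch implementation would operate on tensors rather than nested lists, and would typically refine peaks to subpixel precision.

```python
import math


def make_confmap(x, y, height, width, sigma=1.5):
    """Render an unnormalized 2D Gaussian (peak value 1.0) centered on keypoint (x, y)."""
    return [
        [
            math.exp(-((col - x) ** 2 + (row - y) ** 2) / (2 * sigma**2))
            for col in range(width)
        ]
        for row in range(height)
    ]


def find_global_peak(confmap, threshold=0.2):
    """Single-instance decoding: the highest-scoring pixel as (x, y, score), or None."""
    best = max(
        (
            (col, row, value)
            for row, values in enumerate(confmap)
            for col, value in enumerate(values)
        ),
        key=lambda peak: peak[2],
    )
    return best if best[2] >= threshold else None


def find_local_peaks(confmap, threshold=0.2):
    """Multi-instance decoding: every pixel above threshold that is >= all 8 neighbors."""
    height, width = len(confmap), len(confmap[0])
    peaks = []
    for row in range(height):
        for col in range(width):
            value = confmap[row][col]
            if value < threshold:
                continue
            neighbors = (
                confmap[r][c]
                for r in range(max(0, row - 1), min(height, row + 2))
                for c in range(max(0, col - 1), min(width, col + 2))
                if (r, c) != (row, col)
            )
            if all(value >= n for n in neighbors):
                peaks.append((col, row, value))
    return peaks


# One animal: the global peak recovers the encoded keypoint.
print(find_global_peak(make_confmap(5, 3, height=8, width=10)))  # (5, 3, 1.0)

# Two animals: merge their maps with an element-wise max, then find local peaks.
merged = [
    [max(a, b) for a, b in zip(row_a, row_b)]
    for row_a, row_b in zip(make_confmap(2, 2, 10, 10), make_confmap(7, 7, 10, 10))
]
print(find_local_peaks(merged))  # [(2, 2, 1.0), (7, 7, 1.0)]
```

Local peak finding alone only yields candidate keypoints. Deciding which candidates belong to the same animal is the job of PAF grouping (bottom-up) or of cropping around centroids (top-down).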
55
+
56
+ ### Configuration System
57
+
58
+ The project uses a hierarchical configuration system with three main sections:
59
+ - `data_config`: Data pipeline configuration
60
+ - `model_config`: Model architecture configuration
61
+ - `trainer_config`: Training hyperparameters and Lightning configuration
62
+
63
+ Configurations are managed via Hydra and can be specified in YAML files (see the `docs/sample_configs/config_*.yaml` examples).
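Two operations make the hierarchical layout useful: checking that all three top-level sections are present, and overriding a nested value with a dotted `section.key=value` string in the style of Hydra's command-line overrides. The standard-library sketch below mimics both. In the project itself this job falls to Hydra/OmegaConf, so treat the helpers `missing_sections` and `apply_override` as hypothetical. The nested keys in `base` are illustrative placeholders, not sleap-nn's real schema; the sample configs under `docs/sample_configs/` show the actual keys.

```python
import copy
import json

# The three top-level sections described above.
REQUIRED_SECTIONS = ("data_config", "model_config", "trainer_config")


def missing_sections(config):
    """Return the required top-level sections that config lacks."""
    return [name for name in REQUIRED_SECTIONS if name not in config]


def apply_override(config, override):
    """Return a copy of config with one dotted 'section.key=value' override applied."""
    dotted_key, raw_value = override.split("=", 1)
    try:
        value = json.loads(raw_value)  # numbers, booleans, null, lists
    except json.JSONDecodeError:
        value = raw_value  # bare words such as unet stay strings
    updated = copy.deepcopy(config)  # never mutate the caller's config
    *parents, leaf = dotted_key.split(".")
    node = updated
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
    return updated


# Hypothetical nested keys, for illustration only.
base = {
    "data_config": {"train_labels_path": "train.slp"},
    "model_config": {"backbone": "unet"},
    "trainer_config": {"max_epochs": 100},
}
short_run = apply_override(base, "trainer_config.max_epochs=10")
print(missing_sections(base))                    # []
print(short_run["trainer_config"]["max_epochs"])  # 10
print(base["trainer_config"]["max_epochs"])       # 100
```

Returning a deep copy rather than editing in place keeps the base config reusable across several runs. OmegaConf's merge behaves similarly, since it produces a new config.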
64
+
65
+ ### Key Entry Points
66
+
67
+ - Training: `sleap_nn/train.py` - Hydra-based training entry point
68
+ - Inference: `sleap_nn/predict.py` - Run inference on trained models
69
+ - CLI: `sleap_nn/cli.py` - Command-line interface (currently minimal)
70
+ - Evaluation: `sleap_nn/evaluation.py` - Model evaluation utilities
71
+
72
+ ### Model Types
73
+
74
+ The system supports multiple model architectures for pose estimation:
75
+ - Single Instance: One animal per frame
76
+ - Centered Instance: Crop-based single instance
77
+ - Centroid: Animal center detection
78
+ - Top-Down: Centroid detection → centered-instance pose estimation on each crop
79
+ - Bottom-Up: Multi-instance detection, with keypoints grouped into animals via PAFs
80
+
81
+ Each model type has corresponding head modules, data processing, and inference pipelines.
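The top-down type chains two of the other types: a centroid model finds each animal, then a centered-instance model predicts keypoints on a crop around it. The plain-Python sketch below traces that bookkeeping under a few assumptions. A centroid is the mean of the visible keypoints, though an implementation may instead use a designated anchor part. Crops are fixed-size squares. The centered-instance model is replaced by a hypothetical `predict_in_crop` callable that returns keypoints in crop coordinates. None of these functions are sleap-nn's API.

```python
import math


def centroid(points):
    """Mean of the visible (non-NaN) keypoints, or None if none are visible."""
    visible = [(x, y) for x, y in points if not (math.isnan(x) or math.isnan(y))]
    if not visible:
        return None
    n = len(visible)
    return (sum(x for x, _ in visible) / n, sum(y for _, y in visible) / n)


def crop_box(center, crop_size):
    """Top-left corner (x0, y0) and side length of a square crop centered on center."""
    cx, cy = center
    return (cx - crop_size / 2, cy - crop_size / 2, crop_size)


def to_frame_coords(crop_points, box):
    """Shift keypoints predicted in crop coordinates back into full-frame coordinates."""
    x0, y0, _ = box
    return [(x + x0, y + y0) for x, y in crop_points]


def top_down(centroids, crop_size, predict_in_crop):
    """Second stage of top-down inference: one centered-instance prediction per centroid."""
    instances = []
    for center in centroids:
        box = crop_box(center, crop_size)
        instances.append(to_frame_coords(predict_in_crop(box), box))
    return instances


def dummy_model(box):
    """Hypothetical stand-in model: always predicts one keypoint at the crop center."""
    return [(box[2] / 2, box[2] / 2)]


nan = float("nan")
print(centroid([(10.0, 20.0), (30.0, 40.0), (nan, nan)]))      # (20.0, 30.0)
print(top_down([(20.0, 30.0), (100.0, 50.0)], 64, dummy_model))
# [[(20.0, 30.0)], [(100.0, 50.0)]]
```

The first crop's corner lies outside the frame at (-12.0, -2.0), so a real pipeline also has to pad or clamp crops near the image border. This sketch leaves that step out.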