sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -5
  8. sleap_nn/config/trainer_config.py +0 -71
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +34 -364
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -69
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +204 -201
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/predict.py CHANGED
@@ -67,16 +67,13 @@ def run_inference(
67
67
  only_predicted_frames: bool = False,
68
68
  no_empty_frames: bool = False,
69
69
  batch_size: int = 4,
70
- queue_maxsize: int = 32,
70
+ queue_maxsize: int = 8,
71
71
  video_index: Optional[int] = None,
72
72
  video_dataset: Optional[str] = None,
73
73
  video_input_format: str = "channels_last",
74
74
  frames: Optional[list] = None,
75
75
  crop_size: Optional[int] = None,
76
76
  peak_threshold: Union[float, List[float]] = 0.2,
77
- filter_overlapping: bool = False,
78
- filter_overlapping_method: str = "iou",
79
- filter_overlapping_threshold: float = 0.8,
80
77
  integral_refinement: Optional[str] = "integral",
81
78
  integral_patch_size: int = 5,
82
79
  return_confmaps: bool = False,
@@ -113,7 +110,6 @@ def run_inference(
113
110
  tracking_pre_cull_iou_threshold: float = 0,
114
111
  tracking_clean_instance_count: int = 0,
115
112
  tracking_clean_iou_threshold: float = 0,
116
- gui: bool = False,
117
113
  ):
118
114
  """Entry point to run inference on trained SLEAP-NN models.
119
115
 
@@ -151,7 +147,7 @@ def run_inference(
151
147
  only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
152
148
  no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
153
149
  batch_size: (int) Number of samples per batch. Default: 4.
154
- queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
150
+ queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
155
151
  video_index: (int) Integer index of video in .slp file to predict on. To be used with
156
152
  an .slp path as an alternative to specifying the video path.
157
153
  video_dataset: (str) The dataset for HDF5 videos.
@@ -164,15 +160,6 @@ def run_inference(
164
160
  centroid and centered-instance model, where the first element corresponds
165
161
  to centroid model peak finding threshold and the second element is for
166
162
  centered-instance model peak finding.
167
- filter_overlapping: (bool) If True, removes overlapping instances after
168
- inference using greedy NMS. Applied independently of tracking.
169
- Default: False.
170
- filter_overlapping_method: (str) Similarity metric for filtering overlapping
171
- instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
172
- Default: "iou".
173
- filter_overlapping_threshold: (float) Similarity threshold for filtering.
174
- Instances with similarity > threshold are removed (keeping higher-scoring).
175
- Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
176
163
  integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
177
164
  If `"integral"`, peaks will be refined with integral regression.
178
165
  Default: `"integral"`.
@@ -263,8 +250,6 @@ def run_inference(
263
250
  tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
264
251
  tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
265
252
  tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
266
- gui: (bool) If True, outputs JSON progress lines for GUI integration instead
267
- of Rich progress bars. Default: False.
268
253
 
269
254
  Returns:
270
255
  Returns `sio.Labels` object if `make_labels` is True. Else this function returns
@@ -448,6 +433,13 @@ def run_inference(
448
433
  else "mps" if torch.backends.mps.is_available() else "cpu"
449
434
  )
450
435
 
436
+ if integral_refinement is not None and device == "mps": # TODO
437
+ # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
438
+ logger.info(
439
+ "Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
440
+ )
441
+ integral_refinement = None
442
+
451
443
  logger.info(f"Using device: {device}")
452
444
 
453
445
  # initializes the inference model
@@ -466,9 +458,6 @@ def run_inference(
466
458
  anchor_part=anchor_part,
467
459
  )
468
460
 
469
- # Set GUI mode for progress output
470
- predictor.gui = gui
471
-
472
461
  if (
473
462
  tracking
474
463
  and not isinstance(predictor, BottomUpMultiClassPredictor)
@@ -564,20 +553,6 @@ def run_inference(
564
553
  make_labels=make_labels,
565
554
  )
566
555
 
567
- # Filter overlapping instances (independent of tracking)
568
- if filter_overlapping and make_labels:
569
- from sleap_nn.inference.postprocessing import filter_overlapping_instances
570
-
571
- output = filter_overlapping_instances(
572
- output,
573
- threshold=filter_overlapping_threshold,
574
- method=filter_overlapping_method,
575
- )
576
- logger.info(
577
- f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
578
- f"threshold: {filter_overlapping_threshold}"
579
- )
580
-
581
556
  if tracking:
582
557
  lfs = [x for x in output]
583
558
  if tracking_clean_instance_count > 0:
@@ -632,9 +607,6 @@ def run_inference(
632
607
  # Build inference parameters for provenance
633
608
  inference_params = {
634
609
  "peak_threshold": peak_threshold,
635
- "filter_overlapping": filter_overlapping,
636
- "filter_overlapping_method": filter_overlapping_method,
637
- "filter_overlapping_threshold": filter_overlapping_threshold,
638
610
  "integral_refinement": integral_refinement,
639
611
  "integral_patch_size": integral_patch_size,
640
612
  "batch_size": batch_size,
sleap_nn/train.py CHANGED
@@ -118,70 +118,6 @@ def run_training(
118
118
  logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
119
119
  logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
120
120
 
121
- # Log test metrics to wandb summary
122
- if (
123
- d_name.startswith("test")
124
- and trainer.config.trainer_config.use_wandb
125
- ):
126
- import wandb
127
-
128
- if wandb.run is not None:
129
- summary_metrics = {
130
- f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
131
- f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
132
- "oks_voc.mAP"
133
- ],
134
- f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
135
- "oks_voc.mAR"
136
- ],
137
- f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
138
- f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
139
- f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
140
- f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
141
- "avg"
142
- ],
143
- f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
144
- "p50"
145
- ],
146
- f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
147
- "p95"
148
- ],
149
- f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
150
- "p99"
151
- ],
152
- f"eval/{d_name}/visibility_precision": metrics[
153
- "visibility_metrics"
154
- ]["precision"],
155
- f"eval/{d_name}/visibility_recall": metrics[
156
- "visibility_metrics"
157
- ]["recall"],
158
- }
159
- for key, value in summary_metrics.items():
160
- wandb.run.summary[key] = value
161
-
162
- # Finish wandb run and cleanup after all evaluation is complete
163
- if trainer.config.trainer_config.use_wandb:
164
- import wandb
165
- import shutil
166
-
167
- if wandb.run is not None:
168
- wandb.finish()
169
-
170
- # Delete local wandb logs if configured
171
- wandb_config = trainer.config.trainer_config.wandb
172
- should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
173
- wandb_config.delete_local_logs is None
174
- and wandb_config.wandb_mode != "offline"
175
- )
176
- if should_delete_wandb_logs:
177
- wandb_dir = run_path / "wandb"
178
- if wandb_dir.exists():
179
- logger.info(
180
- f"Deleting local wandb logs at {wandb_dir}... "
181
- "(set trainer_config.wandb.delete_local_logs=false to disable)"
182
- )
183
- shutil.rmtree(wandb_dir, ignore_errors=True)
184
-
185
121
 
186
122
  def train(
187
123
  train_labels_path: Optional[List[str]] = None,
@@ -203,9 +139,9 @@ def train(
203
139
  crop_size: Optional[int] = None,
204
140
  min_crop_size: Optional[int] = 100,
205
141
  crop_padding: Optional[int] = None,
206
- use_augmentations_train: bool = True,
142
+ use_augmentations_train: bool = False,
207
143
  intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
208
- geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = "rotation",
144
+ geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
209
145
  init_weight: str = "default",
210
146
  pretrained_backbone_weights: Optional[str] = None,
211
147
  pretrained_head_weights: Optional[str] = None,
@@ -306,7 +242,7 @@ def train(
306
242
  crop size. If `None`, padding is auto-computed based on augmentation settings.
307
243
  Only used when `crop_size` is `None`. Default: None.
308
244
  use_augmentations_train: True if the data augmentation should be applied to the
309
- training data, else False. Default: True.
245
+ training data, else False. Default: False.
310
246
  intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
311
247
  or list of strings from the above allowed values. To have custom values, pass
312
248
  a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
@@ -318,8 +254,7 @@ def train(
318
254
  or list of strings from the above allowed values. To have custom values, pass
319
255
  a dict with the structure in `sleap_nn.config.data_config.GeometryConfig`.
320
256
  For eg: {
321
- "rotation_min": -45,
322
- "rotation_max": 45,
257
+ "rotation": 45,
323
258
  "affine_p": 1.0
324
259
  }
325
260
  init_weight: model weights initialization method. "default" uses kaiming uniform