sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +2 -4
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -10
  8. sleap_nn/config/trainer_config.py +0 -76
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +39 -411
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -74
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +184 -211
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/predict.py CHANGED
@@ -67,16 +67,13 @@ def run_inference(
     only_predicted_frames: bool = False,
     no_empty_frames: bool = False,
     batch_size: int = 4,
-    queue_maxsize: int = 32,
+    queue_maxsize: int = 8,
     video_index: Optional[int] = None,
     video_dataset: Optional[str] = None,
     video_input_format: str = "channels_last",
     frames: Optional[list] = None,
     crop_size: Optional[int] = None,
     peak_threshold: Union[float, List[float]] = 0.2,
-    filter_overlapping: bool = False,
-    filter_overlapping_method: str = "iou",
-    filter_overlapping_threshold: float = 0.8,
     integral_refinement: Optional[str] = "integral",
     integral_patch_size: int = 5,
     return_confmaps: bool = False,
@@ -113,7 +110,6 @@ def run_inference(
     tracking_pre_cull_iou_threshold: float = 0,
     tracking_clean_instance_count: int = 0,
     tracking_clean_iou_threshold: float = 0,
-    gui: bool = False,
 ):
     """Entry point to run inference on trained SLEAP-NN models.

@@ -151,7 +147,7 @@ def run_inference(
         only_predicted_frames: (bool) `True` to run inference only on frames that already have predictions. Default: `False`.
         no_empty_frames: (bool) `True` if empty frames that did not have predictions should be cleared before saving to output. Default: `False`.
         batch_size: (int) Number of samples per batch. Default: 4.
-        queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 32.
+        queue_maxsize: (int) Maximum size of the frame buffer queue. Default: 8.
         video_index: (int) Integer index of video in .slp file to predict on. To be used with
             an .slp path as an alternative to specifying the video path.
         video_dataset: (str) The dataset for HDF5 videos.
@@ -164,15 +160,6 @@ def run_inference(
            centroid and centered-instance model, where the first element corresponds
            to centroid model peak finding threshold and the second element is for
            centered-instance model peak finding.
-        filter_overlapping: (bool) If True, removes overlapping instances after
-            inference using greedy NMS. Applied independently of tracking.
-            Default: False.
-        filter_overlapping_method: (str) Similarity metric for filtering overlapping
-            instances. One of "iou" (bounding box) or "oks" (keypoint similarity).
-            Default: "iou".
-        filter_overlapping_threshold: (float) Similarity threshold for filtering.
-            Instances with similarity > threshold are removed (keeping higher-scoring).
-            Typical values: 0.3 (aggressive) to 0.8 (permissive). Default: 0.8.
         integral_refinement: (str) If `None`, returns the grid-aligned peaks with no refinement.
             If `"integral"`, peaks will be refined with integral regression.
             Default: `"integral"`.
@@ -263,8 +250,6 @@ def run_inference(
         tracking_pre_cull_iou_threshold: If non-zero and pre_cull_to_target also set, then use IOU threshold to remove overlapping instances over count *before* tracking. (default: 0)
         tracking_clean_instance_count: Target number of instances to clean *after* tracking. (default: 0)
         tracking_clean_iou_threshold: IOU to use when culling instances *after* tracking. (default: 0)
-        gui: (bool) If True, outputs JSON progress lines for GUI integration instead
-            of Rich progress bars. Default: False.

     Returns:
         Returns `sio.Labels` object if `make_labels` is True. Else this function returns
@@ -448,6 +433,13 @@ def run_inference(
         else "mps" if torch.backends.mps.is_available() else "cpu"
     )

+    if integral_refinement is not None and device == "mps":  # TODO
+        # kornia/geometry/transform/imgwarp.py:382: in get_perspective_transform. NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
+        logger.info(
+            "Integral refinement is not supported with MPS accelerator. Setting integral refinement to None."
+        )
+        integral_refinement = None
+
     logger.info(f"Using device: {device}")

     # initializes the inference model
@@ -466,9 +458,6 @@ def run_inference(
         anchor_part=anchor_part,
     )

-    # Set GUI mode for progress output
-    predictor.gui = gui
-
     if (
         tracking
         and not isinstance(predictor, BottomUpMultiClassPredictor)
@@ -564,20 +553,6 @@ def run_inference(
         make_labels=make_labels,
     )

-    # Filter overlapping instances (independent of tracking)
-    if filter_overlapping and make_labels:
-        from sleap_nn.inference.postprocessing import filter_overlapping_instances
-
-        output = filter_overlapping_instances(
-            output,
-            threshold=filter_overlapping_threshold,
-            method=filter_overlapping_method,
-        )
-        logger.info(
-            f"Filtered overlapping instances with {filter_overlapping_method.upper()} "
-            f"threshold: {filter_overlapping_threshold}"
-        )
-
     if tracking:
         lfs = [x for x in output]
         if tracking_clean_instance_count > 0:
@@ -632,9 +607,6 @@ def run_inference(
     # Build inference parameters for provenance
     inference_params = {
         "peak_threshold": peak_threshold,
-        "filter_overlapping": filter_overlapping,
-        "filter_overlapping_method": filter_overlapping_method,
-        "filter_overlapping_threshold": filter_overlapping_threshold,
         "integral_refinement": integral_refinement,
         "integral_patch_size": integral_patch_size,
         "batch_size": batch_size,
sleap_nn/train.py CHANGED
@@ -118,70 +118,6 @@ def run_training(
                 logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
                 logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")

-                # Log test metrics to wandb summary
-                if (
-                    d_name.startswith("test")
-                    and trainer.config.trainer_config.use_wandb
-                ):
-                    import wandb
-
-                    if wandb.run is not None:
-                        summary_metrics = {
-                            f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
-                            f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
-                                "oks_voc.mAP"
-                            ],
-                            f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
-                                "oks_voc.mAR"
-                            ],
-                            f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
-                            f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
-                            f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
-                            f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
-                                "avg"
-                            ],
-                            f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
-                                "p50"
-                            ],
-                            f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
-                                "p95"
-                            ],
-                            f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
-                                "p99"
-                            ],
-                            f"eval/{d_name}/visibility_precision": metrics[
-                                "visibility_metrics"
-                            ]["precision"],
-                            f"eval/{d_name}/visibility_recall": metrics[
-                                "visibility_metrics"
-                            ]["recall"],
-                        }
-                        for key, value in summary_metrics.items():
-                            wandb.run.summary[key] = value
-
-    # Finish wandb run and cleanup after all evaluation is complete
-    if trainer.config.trainer_config.use_wandb:
-        import wandb
-        import shutil
-
-        if wandb.run is not None:
-            wandb.finish()
-
-        # Delete local wandb logs if configured
-        wandb_config = trainer.config.trainer_config.wandb
-        should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
-            wandb_config.delete_local_logs is None
-            and wandb_config.wandb_mode != "offline"
-        )
-        if should_delete_wandb_logs:
-            wandb_dir = run_path / "wandb"
-            if wandb_dir.exists():
-                logger.info(
-                    f"Deleting local wandb logs at {wandb_dir}... "
-                    "(set trainer_config.wandb.delete_local_logs=false to disable)"
-                )
-                shutil.rmtree(wandb_dir, ignore_errors=True)
-

 def train(
     train_labels_path: Optional[List[str]] = None,
@@ -203,9 +139,9 @@ def train(
     crop_size: Optional[int] = None,
     min_crop_size: Optional[int] = 100,
     crop_padding: Optional[int] = None,
-    use_augmentations_train: bool = True,
+    use_augmentations_train: bool = False,
     intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
-    geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = "rotation",
+    geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
     init_weight: str = "default",
     pretrained_backbone_weights: Optional[str] = None,
     pretrained_head_weights: Optional[str] = None,
@@ -239,7 +175,6 @@ def train(
     wandb_save_viz_imgs_wandb: bool = False,
     wandb_resume_prv_runid: Optional[str] = None,
     wandb_group_name: Optional[str] = None,
-    wandb_delete_local_logs: Optional[bool] = None,
     optimizer: str = "Adam",
     learning_rate: float = 1e-3,
     amsgrad: bool = False,
@@ -306,7 +241,7 @@ def train(
            crop size. If `None`, padding is auto-computed based on augmentation settings.
            Only used when `crop_size` is `None`. Default: None.
         use_augmentations_train: True if the data augmentation should be applied to the
-            training data, else False. Default: True.
+            training data, else False. Default: False.
         intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
            or list of strings from the above allowed values. To have custom values, pass
            a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
@@ -318,8 +253,7 @@ def train(
            or list of strings from the above allowed values. To have custom values, pass
            a dict with the structure in `sleap_nn.config.data_config.GeometryConfig`.
            For eg: {
-                "rotation_min": -45,
-                "rotation_max": 45,
+                "rotation": 45,
                "affine_p": 1.0
            }
         init_weight: model weights initialization method. "default" uses kaiming uniform
@@ -419,9 +353,6 @@ def train(
         wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
            ckpt. Default: None
         wandb_group_name: Group name for the wandb run. Default: None.
-        wandb_delete_local_logs: If True, delete local wandb logs folder after training.
-            If False, keep the folder. If None (default), automatically delete if logging
-            online (wandb_mode != "offline") and keep if logging offline. Default: None.
         optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
         learning_rate: Learning rate of type float. Default: 1e-3.
         amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -525,7 +456,6 @@ def train(
         wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
         wandb_resume_prv_runid=wandb_resume_prv_runid,
         wandb_group_name=wandb_group_name,
-        wandb_delete_local_logs=wandb_delete_local_logs,
         optimizer=optimizer,
         learning_rate=learning_rate,
         amsgrad=amsgrad,
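
Note on the augmentation hunks above: in 0.1.0a0, training-time augmentation is opt-in (use_augmentations_train defaults to False and geometry_aug to None), and the geometry_aug example dict uses a single "rotation" bound instead of "rotation_min"/"rotation_max". The sketch below is an illustrative, non-authoritative call that opts back into rotation augmentation using only arguments that appear in this diff; "labels.train.slp" is a hypothetical path and every other argument keeps the default shown in the signature above.

from sleap_nn.train import train  # module and function names as shown in this diff

train(
    train_labels_path=["labels.train.slp"],  # hypothetical labels file
    use_augmentations_train=True,  # augmentation is no longer on by default in 0.1.0a0
    geometry_aug={"rotation": 45, "affine_p": 1.0},  # single rotation bound, per the docstring above
    intensity_aug="gaussian_noise",  # one of the documented intensity options
)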