sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/train.py CHANGED
@@ -6,6 +6,7 @@ from datetime import datetime
6
6
  from time import time
7
7
  from omegaconf import DictConfig, OmegaConf
8
8
  from typing import Any, Dict, Optional, List, Tuple, Union
9
+ import sleap_io as sio
9
10
  from sleap_nn.config.training_job_config import TrainingJobConfig
10
11
  from sleap_nn.training.model_trainer import ModelTrainer
11
12
  from sleap_nn.predict import run_inference as predict
@@ -15,15 +16,31 @@ from sleap_nn.config.get_config import (
15
16
  get_model_config,
16
17
  get_data_config,
17
18
  )
19
+ from sleap_nn.system_info import get_startup_info_string
18
20
 
19
21
 
20
- def run_training(config: DictConfig):
21
- """Create ModelTrainer instance and start training."""
22
+ def run_training(
23
+ config: DictConfig,
24
+ train_labels: Optional[List[sio.Labels]] = None,
25
+ val_labels: Optional[List[sio.Labels]] = None,
26
+ ):
27
+ """Create ModelTrainer instance and start training.
28
+
29
+ Args:
30
+ config: Training configuration as a DictConfig.
31
+ train_labels: List of Labels objects for training.
32
+ val_labels: List of Labels objects for validation.
33
+ If not provided, the labels will be loaded from paths in the config.
34
+ """
22
35
  start_train_time = time()
23
36
  start_timestamp = str(datetime.now())
24
37
  logger.info(f"Started training at: {start_timestamp}")
38
+ logger.info(get_startup_info_string())
25
39
 
26
- trainer = ModelTrainer.get_model_trainer_from_config(config)
40
+ # provide the labels as the train labels, val labels will be split from the train labels
41
+ trainer = ModelTrainer.get_model_trainer_from_config(
42
+ config, train_labels=train_labels, val_labels=val_labels
43
+ )
27
44
  trainer.train()
28
45
 
29
46
  finish_timestamp = str(datetime.now())
@@ -39,48 +56,44 @@ def run_training(config: DictConfig):
39
56
  # run inference on val dataset
40
57
  if trainer.config.trainer_config.save_ckpt:
41
58
  data_paths = {}
42
- for index, path in enumerate(trainer.config.data_config.train_labels_path):
43
- logger.info(
44
- f"Training labels path for index {index}: {(Path(trainer.config.trainer_config.ckpt_dir) / trainer.config.trainer_config.run_name).as_posix()}"
45
- )
46
- data_paths[f"train_{index}"] = (
47
- Path(trainer.config.trainer_config.ckpt_dir)
48
- / trainer.config.trainer_config.run_name
49
- / f"labels_train_gt_{index}.slp"
59
+ run_path = (
60
+ Path(trainer.config.trainer_config.ckpt_dir)
61
+ / trainer.config.trainer_config.run_name
62
+ )
63
+ for index, _ in enumerate(trainer.train_labels):
64
+ logger.info(f"Run path for index {index}: {run_path.as_posix()}")
65
+ data_paths[f"train.{index}"] = (
66
+ run_path / f"labels_gt.train.{index}.slp"
50
67
  ).as_posix()
51
- data_paths[f"val_{index}"] = (
52
- Path(trainer.config.trainer_config.ckpt_dir)
53
- / trainer.config.trainer_config.run_name
54
- / f"labels_val_gt_{index}.slp"
68
+ data_paths[f"val.{index}"] = (
69
+ run_path / f"labels_gt.val.{index}.slp"
55
70
  ).as_posix()
56
71
 
57
- if (
58
- OmegaConf.select(config, "data_config.test_file_path", default=None)
59
- is not None
60
- ):
61
- data_paths["test"] = config.data_config.test_file_path
72
+ # Handle test_file_path as either a string or list of strings
73
+ test_file_path = OmegaConf.select(
74
+ config, "data_config.test_file_path", default=None
75
+ )
76
+ if test_file_path is not None:
77
+ # Normalize to list of strings
78
+ if isinstance(test_file_path, str):
79
+ test_paths = [test_file_path]
80
+ else:
81
+ test_paths = list(test_file_path)
82
+ # Add each test path to data_paths (always use index for consistency)
83
+ for idx, test_path in enumerate(test_paths):
84
+ data_paths[f"test.{idx}"] = test_path
62
85
 
63
86
  for d_name, path in data_paths.items():
64
- pred_path = (
65
- Path(trainer.config.trainer_config.ckpt_dir)
66
- / trainer.config.trainer_config.run_name
67
- / f"pred_{d_name}.slp"
68
- )
69
- metrics_path = (
70
- Path(trainer.config.trainer_config.ckpt_dir)
71
- / trainer.config.trainer_config.run_name
72
- / f"{d_name}_pred_metrics.npz"
73
- )
87
+ # d_name is now in format: "train.0", "val.0", "test.0", etc.
88
+ pred_path = run_path / f"labels_pr.{d_name}.slp"
89
+ metrics_path = run_path / f"metrics.{d_name}.npz"
74
90
 
75
91
  pred_labels = predict(
76
92
  data_path=path,
77
- model_paths=[
78
- Path(trainer.config.trainer_config.ckpt_dir)
79
- / trainer.config.trainer_config.run_name
80
- ],
93
+ model_paths=[run_path],
81
94
  peak_threshold=0.2,
82
95
  make_labels=True,
83
- device=trainer.trainer.strategy.root_device,
96
+ device=str(trainer.trainer.strategy.root_device),
84
97
  output_path=pred_path,
85
98
  ensure_rgb=config.data_config.preprocessing.ensure_rgb,
86
99
  ensure_grayscale=config.data_config.preprocessing.ensure_grayscale,
@@ -105,12 +118,77 @@ def run_training(config: DictConfig):
105
118
  logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
106
119
  logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
107
120
 
121
+ # Log test metrics to wandb summary
122
+ if (
123
+ d_name.startswith("test")
124
+ and trainer.config.trainer_config.use_wandb
125
+ ):
126
+ import wandb
127
+
128
+ if wandb.run is not None:
129
+ summary_metrics = {
130
+ f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
131
+ f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
132
+ "oks_voc.mAP"
133
+ ],
134
+ f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
135
+ "oks_voc.mAR"
136
+ ],
137
+ f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
138
+ f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
139
+ f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
140
+ f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
141
+ "avg"
142
+ ],
143
+ f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
144
+ "p50"
145
+ ],
146
+ f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
147
+ "p95"
148
+ ],
149
+ f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
150
+ "p99"
151
+ ],
152
+ f"eval/{d_name}/visibility_precision": metrics[
153
+ "visibility_metrics"
154
+ ]["precision"],
155
+ f"eval/{d_name}/visibility_recall": metrics[
156
+ "visibility_metrics"
157
+ ]["recall"],
158
+ }
159
+ for key, value in summary_metrics.items():
160
+ wandb.run.summary[key] = value
161
+
162
+ # Finish wandb run and cleanup after all evaluation is complete
163
+ if trainer.config.trainer_config.use_wandb:
164
+ import wandb
165
+ import shutil
166
+
167
+ if wandb.run is not None:
168
+ wandb.finish()
169
+
170
+ # Delete local wandb logs if configured
171
+ wandb_config = trainer.config.trainer_config.wandb
172
+ should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
173
+ wandb_config.delete_local_logs is None
174
+ and wandb_config.wandb_mode != "offline"
175
+ )
176
+ if should_delete_wandb_logs:
177
+ wandb_dir = run_path / "wandb"
178
+ if wandb_dir.exists():
179
+ logger.info(
180
+ f"Deleting local wandb logs at {wandb_dir}... "
181
+ "(set trainer_config.wandb.delete_local_logs=false to disable)"
182
+ )
183
+ shutil.rmtree(wandb_dir, ignore_errors=True)
184
+
108
185
 
109
186
  def train(
110
187
  train_labels_path: Optional[List[str]] = None,
111
188
  val_labels_path: Optional[List[str]] = None,
112
189
  validation_fraction: float = 0.1,
113
- test_file_path: Optional[str] = None,
190
+ use_same_data_for_val: bool = False,
191
+ test_file_path: Optional[Union[str, List[str]]] = None,
114
192
  provider: str = "LabelsReader",
115
193
  user_instances_only: bool = True,
116
194
  data_pipeline_fw: str = "torch_dataset",
@@ -124,9 +202,10 @@ def train(
124
202
  max_width: Optional[int] = None,
125
203
  crop_size: Optional[int] = None,
126
204
  min_crop_size: Optional[int] = 100,
127
- use_augmentations_train: bool = False,
205
+ crop_padding: Optional[int] = None,
206
+ use_augmentations_train: bool = True,
128
207
  intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
129
- geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
208
+ geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = "rotation",
130
209
  init_weight: str = "default",
131
210
  pretrained_backbone_weights: Optional[str] = None,
132
211
  pretrained_head_weights: Optional[str] = None,
@@ -160,6 +239,7 @@ def train(
160
239
  wandb_save_viz_imgs_wandb: bool = False,
161
240
  wandb_resume_prv_runid: Optional[str] = None,
162
241
  wandb_group_name: Optional[str] = None,
242
+ wandb_delete_local_logs: Optional[bool] = None,
163
243
  optimizer: str = "Adam",
164
244
  learning_rate: float = 1e-3,
165
245
  amsgrad: bool = False,
@@ -188,7 +268,11 @@ def train(
188
268
  training set to sample for generating the validation set. The remaining
189
269
  labeled frames will be left in the training set. If the `validation_labels`
190
270
  are already specified, this has no effect. Default: 0.1.
191
- test_file_path: Path to test dataset (`.slp` file or `.mp4` file).
271
+ use_same_data_for_val: If `True`, use the same data for both training and
272
+ validation (train = val). Useful for intentional overfitting on small
273
+ datasets. When enabled, `val_labels_path` and `validation_fraction` are
274
+ ignored. Default: False.
275
+ test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
192
276
  Note: This is used to get evaluation on test set after training is completed.
193
277
  provider: Provider class to read the input sleap files. Only "LabelsReader"
194
278
  supported for the training pipeline. Default: "LabelsReader".
@@ -210,16 +294,19 @@ def train(
210
294
  is set to True, then we convert the image to grayscale (single-channel)
211
295
  image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
212
296
  scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
213
- max_height: Maximum height the image should be padded to. If not provided, the
297
+ max_height: Maximum height the original image should be resized and padded to. If not provided, the
214
298
  original image size will be retained. Default: None.
215
- max_width: Maximum width the image should be padded to. If not provided, the
299
+ max_width: Maximum width the original image should be resized and padded to. If not provided, the
216
300
  original image size will be retained. Default: None.
217
301
  crop_size: Crop size of each instance for centered-instance model.
218
302
  If `None`, this would be automatically computed based on the largest instance
219
- in the `sio.Labels` file. Default: None.
303
+ in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
220
304
  min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
305
+ crop_padding: Padding in pixels to add around instance bounding box when computing
306
+ crop size. If `None`, padding is auto-computed based on augmentation settings.
307
+ Only used when `crop_size` is `None`. Default: None.
221
308
  use_augmentations_train: True if the data augmentation should be applied to the
222
- training data, else False. Default: False.
309
+ training data, else False. Default: True.
223
310
  intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
224
311
  or list of strings from the above allowed values. To have custom values, pass
225
312
  a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
@@ -231,7 +318,8 @@ def train(
231
318
  or list of strings from the above allowed values. To have custom values, pass
232
319
  a dict with the structure in `sleap_nn.config.data_config.GeometryConfig`.
233
320
  For eg: {
234
- "rotation": 45,
321
+ "rotation_min": -45,
322
+ "rotation_max": 45,
235
323
  "affine_p": 1.0
236
324
  }
237
325
  init_weight: model weights initialization method. "default" uses kaiming uniform
@@ -331,6 +419,9 @@ def train(
331
419
  wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
332
420
  ckpt. Default: None
333
421
  wandb_group_name: Group name for the wandb run. Default: None.
422
+ wandb_delete_local_logs: If True, delete local wandb logs folder after training.
423
+ If False, keep the folder. If None (default), automatically delete if logging
424
+ online (wandb_mode != "offline") and keep if logging offline. Default: None.
334
425
  optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
335
426
  learning_rate: Learning rate of type float. Default: 1e-3.
336
427
  amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -376,6 +467,7 @@ def train(
376
467
  train_labels_path=train_labels_path,
377
468
  val_labels_path=val_labels_path,
378
469
  validation_fraction=validation_fraction,
470
+ use_same_data_for_val=use_same_data_for_val,
379
471
  test_file_path=test_file_path,
380
472
  provider=provider,
381
473
  user_instances_only=user_instances_only,
@@ -390,6 +482,7 @@ def train(
390
482
  max_width=max_width,
391
483
  crop_size=crop_size,
392
484
  min_crop_size=min_crop_size,
485
+ crop_padding=crop_padding,
393
486
  use_augmentations_train=use_augmentations_train,
394
487
  intensity_aug=intensity_aug,
395
488
  geometry_aug=geometry_aug,
@@ -432,6 +525,7 @@ def train(
432
525
  wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
433
526
  wandb_resume_prv_runid=wandb_resume_prv_runid,
434
527
  wandb_group_name=wandb_group_name,
528
+ wandb_delete_local_logs=wandb_delete_local_logs,
435
529
  optimizer=optimizer,
436
530
  learning_rate=learning_rate,
437
531
  amsgrad=amsgrad,