sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,7 @@ the parameters required to initialize the data config.
6
6
 
7
7
  from attrs import define, field, validators
8
8
  from omegaconf import MISSING
9
- from typing import Optional, Tuple, Any, List
9
+ from typing import Optional, Tuple, Any, List, Union
10
10
  from loguru import logger
11
11
  import sleap_io as sio
12
12
  import yaml
@@ -20,11 +20,15 @@ class PreprocessingConfig:
20
20
  Attributes:
21
21
  ensure_rgb: (bool) True if the input image should have 3 channels (RGB image). If the input has only one channel when this is set to `True`, then the single channel is replicated along the channel axis. If the image has three channels and this is set to False, then we retain the three channels. *Default*: `False`.
22
22
  ensure_grayscale: (bool) True if the input image should only have a single channel. If input has three channels (RGB) and this is set to True, then we convert the image to grayscale (single-channel) image. If the source image has only one channel and this is set to False, then we retain the single channel input. *Default*: `False`.
23
- max_height: (int) Maximum height the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
24
- max_width: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
23
+ max_height: (int) Maximum height the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
24
+ max_width: (int) Maximum width the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
25
25
  scale: (float) Factor to resize the image dimensions by, specified as a float. *Default*: `1.0`.
26
- crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file. *Default*: `None`.
26
+ crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file.
27
+ If `scale` is provided, then the cropped image will be resized according to `scale`. *Default*: `None`.
27
28
  min_crop_size: (int) Minimum crop size to be used if `crop_size` is `None`. *Default*: `100`.
29
+ crop_padding: (int) Padding in pixels to add around the instance bounding box when computing crop size.
30
+ If `None`, padding is auto-computed based on augmentation settings (rotation/scale).
31
+ Only used when `crop_size` is `None`. *Default*: `None`.
28
32
  """
29
33
 
30
34
  ensure_rgb: bool = False
@@ -36,6 +40,7 @@ class PreprocessingConfig:
36
40
  )
37
41
  crop_size: Optional[int] = None
38
42
  min_crop_size: Optional[int] = 100 # to help app work in case of error
43
+ crop_padding: Optional[int] = None
39
44
 
40
45
  def validate_scale(self):
41
46
  """Scale Validation.
@@ -104,11 +109,14 @@ class GeometricConfig:
104
109
  Attributes:
105
110
  rotation_min: (float) Minimum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `-15.0`.
106
111
  rotation_max: (float) Maximum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `15.0`.
112
+ rotation_p: (float, optional) Probability of applying random rotation independently. If set, rotation is applied separately from scale/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `1.0`.
107
113
  scale_min: (float) Minimum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `0.9`.
108
114
  scale_max: (float) Maximum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `1.1`.
115
+ scale_p: (float, optional) Probability of applying random scaling independently. If set, scaling is applied separately from rotation/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `1.0`.
109
116
  translate_width: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default. *Default*: `0.0`.
110
117
  translate_height: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default. *Default*: `0.0`.
111
- affine_p: (float) Probability of applying random affine transformations. *Default*: `0.0`.
118
+ translate_p: (float, optional) Probability of applying random translation independently. If set, translation is applied separately from rotation/scale. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
119
+ affine_p: (float) Probability of applying random affine transformations (rotation, scale, translate bundled together). Used for backwards compatibility when individual `*_p` params are not set. *Default*: `0.0`.
112
120
  erase_scale_min: (float) Minimum value of range of proportion of erased area against input image. *Default*: `0.0001`.
113
121
  erase_scale_max: (float) Maximum value of range of proportion of erased area against input image. *Default*: `0.01`.
114
122
  erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. *Default*: `1.0`.
@@ -121,10 +129,13 @@ class GeometricConfig:
121
129
 
122
130
  rotation_min: float = field(default=-15.0, validator=validators.ge(-180))
123
131
  rotation_max: float = field(default=15.0, validator=validators.le(180))
132
+ rotation_p: Optional[float] = field(default=1.0)
124
133
  scale_min: float = field(default=0.9, validator=validators.ge(0))
125
134
  scale_max: float = field(default=1.1, validator=validators.ge(0))
135
+ scale_p: Optional[float] = field(default=1.0)
126
136
  translate_width: float = 0.0
127
137
  translate_height: float = 0.0
138
+ translate_p: Optional[float] = field(default=None)
128
139
  affine_p: float = field(default=0.0, validator=validate_proportion)
129
140
  erase_scale_min: float = 0.0001
130
141
  erase_scale_max: float = 0.01
@@ -149,6 +160,28 @@ class AugmentationConfig:
149
160
  geometric: Optional[GeometricConfig] = None
150
161
 
151
162
 
163
+ def validate_test_file_path(instance, attribute, value):
164
+ """Validate test_file_path to accept str or List[str].
165
+
166
+ Args:
167
+ instance: The instance being validated.
168
+ attribute: The attribute being validated.
169
+ value: The value to validate.
170
+
171
+ Raises:
172
+ ValueError: If value is not None, str, or list of strings.
173
+ """
174
+ if value is None:
175
+ return
176
+ if isinstance(value, str):
177
+ return
178
+ if isinstance(value, (list, tuple)) and all(isinstance(p, str) for p in value):
179
+ return
180
+ message = f"{attribute.name} must be a string or list of strings, got {type(value).__name__}"
181
+ logger.error(message)
182
+ raise ValueError(message)
183
+
184
+
152
185
  @define
153
186
  class DataConfig:
154
187
  """Data configuration.
@@ -157,13 +190,16 @@ class DataConfig:
157
190
  train_labels_path: (List[str]) List of paths to training data (`.slp` file(s)). *Default*: `None`.
158
191
  val_labels_path: (List[str]) List of paths to validation data (`.slp` file(s)). *Default*: `None`.
159
192
  validation_fraction: (float) Float between 0 and 1 specifying the fraction of the training set to sample for generating the validation set. The remaining labeled frames will be left in the training set. If the `validation_labels` are already specified, this has no effect. *Default*: `0.1`.
160
- test_file_path: (str) Path to test dataset (`.slp` file or `.mp4` file). *Note*: This is used only with CLI to get evaluation on test set after training is completed. *Default*: `None`.
193
+ use_same_data_for_val: (bool) If `True`, use the same data for both training and validation (train = val). Useful for intentional overfitting on small datasets. When enabled, `val_labels_path` and `validation_fraction` are ignored. *Default*: `False`.
194
+ test_file_path: (str or List[str]) Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)). *Note*: This is used only with CLI to get evaluation on test set after training is completed. *Default*: `None`.
161
195
  provider: (str) Provider class to read the input sleap files. Only "LabelsReader" is currently supported for the training pipeline. *Default*: `"LabelsReader"`.
162
196
  user_instances_only: (bool) `True` if only user labeled instances should be used for training. If `False`, both user labeled and predicted instances would be used. *Default*: `True`.
163
197
  data_pipeline_fw: (str) Framework to create the data loaders. One of [`torch_dataset`, `torch_dataset_cache_img_memory`, `torch_dataset_cache_img_disk`]. *Default*: `"torch_dataset"`. (Note: When using `torch_dataset`, `num_workers` in `trainer_config` should be set to 0 as multiprocessing doesn't work with pickling video backends.)
164
198
  cache_img_path: (str) Path to save `.jpg` images created with `torch_dataset_cache_img_disk` data pipeline framework. If `None`, the path provided in `trainer_config.save_ckpt` is used. The `train_imgs` and `val_imgs` dirs are created inside this path. *Default*: `None`.
165
199
  use_existing_imgs: (bool) Use existing train and val images/ chunks in the `cache_img_path` for `torch_dataset_cache_img_disk` frameworks. If `True`, the `cache_img_path` should have `train_imgs` and `val_imgs` dirs. *Default*: `False`.
166
200
  delete_cache_imgs_after_training: (bool) If `False`, the images (torch_dataset_cache_img_disk) are retained after training. Else, the files are deleted. *Default*: `True`.
201
+ parallel_caching: (bool) If `True`, use parallel processing to cache images (significantly faster for large datasets). *Default*: `True`.
202
+ cache_workers: (int) Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). *Default*: `0`.
167
203
  preprocessing: Configuration options related to data preprocessing.
168
204
  use_augmentations_train: (bool) True if the data augmentation should be applied to the training data, else False. *Default*: `True`.
169
205
  augmentation_config: Configurations related to augmentation. (only if `use_augmentations_train` is `True`)
@@ -173,16 +209,23 @@ class DataConfig:
173
209
  train_labels_path: Optional[List[str]] = None
174
210
  val_labels_path: Optional[List[str]] = None # TODO : revisit MISSING!
175
211
  validation_fraction: float = 0.1
176
- test_file_path: Optional[str] = None
212
+ use_same_data_for_val: bool = False
213
+ test_file_path: Optional[Any] = field(
214
+ default=None, validator=validate_test_file_path
215
+ )
177
216
  provider: str = "LabelsReader"
178
217
  user_instances_only: bool = True
179
218
  data_pipeline_fw: str = "torch_dataset"
180
219
  cache_img_path: Optional[str] = None
181
220
  use_existing_imgs: bool = False
182
221
  delete_cache_imgs_after_training: bool = True
222
+ parallel_caching: bool = True
223
+ cache_workers: int = 0
183
224
  preprocessing: PreprocessingConfig = field(factory=PreprocessingConfig)
184
225
  use_augmentations_train: bool = True
185
- augmentation_config: Optional[AugmentationConfig] = None
226
+ augmentation_config: Optional[AugmentationConfig] = field(
227
+ factory=lambda: AugmentationConfig(geometric=GeometricConfig())
228
+ )
186
229
  skeletons: Optional[list] = None
187
230
 
188
231
 
@@ -131,27 +131,18 @@ def get_aug_config(
131
131
 
132
132
  for g in geometric_aug:
133
133
  if g == "rotation":
134
- aug_config.geometric.affine_p = 1.0
135
- aug_config.geometric.scale_min = 1.0
136
- aug_config.geometric.scale_max = 1.0
137
- aug_config.geometric.translate_height = 0
138
- aug_config.geometric.translate_width = 0
134
+ # Use new independent rotation probability
135
+ aug_config.geometric.rotation_p = 1.0
139
136
  elif g == "scale":
137
+ # Use new independent scale probability
140
138
  aug_config.geometric.scale_min = 0.9
141
139
  aug_config.geometric.scale_max = 1.1
142
- aug_config.geometric.affine_p = 1.0
143
- aug_config.geometric.rotation_min = 0
144
- aug_config.geometric.rotation_max = 0
145
- aug_config.geometric.translate_height = 0
146
- aug_config.geometric.translate_width = 0
140
+ aug_config.geometric.scale_p = 1.0
147
141
  elif g == "translate":
142
+ # Use new independent translate probability
148
143
  aug_config.geometric.translate_height = 0.2
149
144
  aug_config.geometric.translate_width = 0.2
150
- aug_config.geometric.affine_p = 1.0
151
- aug_config.geometric.rotation_min = 0
152
- aug_config.geometric.rotation_max = 0
153
- aug_config.geometric.scale_min = 1.0
154
- aug_config.geometric.scale_max = 1.0
145
+ aug_config.geometric.translate_p = 1.0
155
146
  elif g == "erase_scale":
156
147
  aug_config.geometric.erase_p = 1.0
157
148
  elif g == "mixup":
@@ -456,7 +447,8 @@ def get_data_config(
456
447
  train_labels_path: Optional[List[str]] = None,
457
448
  val_labels_path: Optional[List[str]] = None,
458
449
  validation_fraction: float = 0.1,
459
- test_file_path: Optional[str] = None,
450
+ use_same_data_for_val: bool = False,
451
+ test_file_path: Optional[Union[str, List[str]]] = None,
460
452
  provider: str = "LabelsReader",
461
453
  user_instances_only: bool = True,
462
454
  data_pipeline_fw: str = "torch_dataset",
@@ -470,9 +462,10 @@ def get_data_config(
470
462
  max_width: Optional[int] = None,
471
463
  crop_size: Optional[int] = None,
472
464
  min_crop_size: Optional[int] = 100,
473
- use_augmentations_train: bool = False,
465
+ crop_padding: Optional[int] = None,
466
+ use_augmentations_train: bool = True,
474
467
  intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
475
- geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
468
+ geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = "rotation",
476
469
  ):
477
470
  """Train a pose-estimation model with SLEAP-NN framework.
478
471
 
@@ -486,7 +479,11 @@ def get_data_config(
486
479
  training set to sample for generating the validation set. The remaining
487
480
  labeled frames will be left in the training set. If the `validation_labels`
488
481
  are already specified, this has no effect. Default: 0.1.
489
- test_file_path: Path to test dataset (`.slp` file or `.mp4` file).
482
+ use_same_data_for_val: If `True`, use the same data for both training and
483
+ validation (train = val). Useful for intentional overfitting on small
484
+ datasets. When enabled, `val_labels_path` and `validation_fraction` are
485
+ ignored. Default: False.
486
+ test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
490
487
  Note: This is used to get evaluation on test set after training is completed.
491
488
  provider: Provider class to read the input sleap files. Only "LabelsReader"
492
489
  supported for the training pipeline. Default: "LabelsReader".
@@ -508,16 +505,19 @@ def get_data_config(
508
505
  is set to True, then we convert the image to grayscale (single-channel)
509
506
  image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
510
507
  scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
511
- max_height: Maximum height the image should be padded to. If not provided, the
508
+ max_height: Maximum height the original image should be resized and padded to. If not provided, the
512
509
  original image size will be retained. Default: None.
513
- max_width: Maximum width the image should be padded to. If not provided, the
510
+ max_width: Maximum width the original image should be resized and padded to. If not provided, the
514
511
  original image size will be retained. Default: None.
515
512
  crop_size: Crop size of each instance for centered-instance model.
516
513
  If `None`, this would be automatically computed based on the largest instance
517
- in the `sio.Labels` file. Default: None.
514
+ in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
518
515
  min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
516
+ crop_padding: Padding in pixels to add around instance bounding box when computing
517
+ crop size. If `None`, padding is auto-computed based on augmentation settings.
518
+ Only used when `crop_size` is `None`. Default: None.
519
519
  use_augmentations_train: True if the data augmentation should be applied to the
520
- training data, else False. Default: False.
520
+ training data, else False. Default: True.
521
521
  intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
522
522
  or list of strings from the above allowed values. To have custom values, pass
523
523
  a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
@@ -529,7 +529,8 @@ def get_data_config(
529
529
  or list of strings from the above allowed values. To have custom values, pass
530
530
  a dict with the structure in `sleap_nn.config.data_config.GeometricConfig`.
531
531
  For eg: {
532
- "rotation": 45,
532
+ "rotation_min": -45,
533
+ "rotation_max": 45,
533
534
  "affine_p": 1.0
534
535
  }
535
536
  """
@@ -541,6 +542,7 @@ def get_data_config(
541
542
  scale=scale,
542
543
  crop_size=crop_size,
543
544
  min_crop_size=min_crop_size,
545
+ crop_padding=crop_padding,
544
546
  )
545
547
  augmentation_config = None
546
548
  if use_augmentations_train:
@@ -553,6 +555,7 @@ def get_data_config(
553
555
  train_labels_path=train_labels_path,
554
556
  val_labels_path=val_labels_path,
555
557
  validation_fraction=validation_fraction,
558
+ use_same_data_for_val=use_same_data_for_val,
556
559
  test_file_path=test_file_path,
557
560
  provider=provider,
558
561
  user_instances_only=user_instances_only,
@@ -675,6 +678,7 @@ def get_trainer_config(
675
678
  wandb_save_viz_imgs_wandb: bool = False,
676
679
  wandb_resume_prv_runid: Optional[str] = None,
677
680
  wandb_group_name: Optional[str] = None,
681
+ wandb_delete_local_logs: Optional[bool] = None,
678
682
  optimizer: str = "Adam",
679
683
  learning_rate: float = 1e-3,
680
684
  amsgrad: bool = False,
@@ -744,6 +748,9 @@ def get_trainer_config(
744
748
  wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
745
749
  ckpt. Default: None
746
750
  wandb_group_name: Group name for the wandb run. Default: None.
751
+ wandb_delete_local_logs: If True, delete local wandb logs folder after training.
752
+ If False, keep the folder. If None (default), automatically delete if logging
753
+ online (wandb_mode != "offline") and keep if logging offline. Default: None.
747
754
  optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
748
755
  learning_rate: Learning rate of type float. Default: 1e-3.
749
756
  amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -844,6 +851,7 @@ def get_trainer_config(
844
851
  save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
845
852
  prv_runid=wandb_resume_prv_runid,
846
853
  group=wandb_group_name,
854
+ delete_local_logs=wandb_delete_local_logs,
847
855
  ),
848
856
  save_ckpt=save_ckpt,
849
857
  ckpt_dir=ckpt_dir,
@@ -84,6 +84,16 @@ class WandBConfig:
84
84
  prv_runid: (str) Previous run ID if training should be resumed from a previous ckpt. *Default*: `None`.
85
85
  group: (str) Group for wandb logging. *Default*: `None`.
86
86
  current_run_id: (str) Run ID for the current model training. (stored once the training starts). *Default*: `None`.
87
+ viz_enabled: (bool) If True, log pre-rendered matplotlib images to wandb. *Default*: `True`.
88
+ viz_boxes: (bool) If True, log interactive keypoint boxes. *Default*: `False`.
89
+ viz_masks: (bool) If True, log confidence map overlay masks. *Default*: `False`.
90
+ viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
91
+ viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
92
+ log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
93
+ delete_local_logs: (bool, optional) If True, delete local wandb logs folder after
94
+ training. If False, keep the folder. If None (default), automatically delete
95
+ if logging online (wandb_mode != "offline") and keep if logging offline.
96
+ *Default*: `None`.
87
97
  """
88
98
 
89
99
  entity: Optional[str] = None
@@ -95,6 +105,13 @@ class WandBConfig:
95
105
  prv_runid: Optional[str] = None
96
106
  group: Optional[str] = None
97
107
  current_run_id: Optional[str] = None
108
+ viz_enabled: bool = True
109
+ viz_boxes: bool = False
110
+ viz_masks: bool = False
111
+ viz_box_size: float = 5.0
112
+ viz_confmap_threshold: float = 0.1
113
+ log_viz_table: bool = False
114
+ delete_local_logs: Optional[bool] = None
98
115
 
99
116
 
100
117
  @define
@@ -161,19 +178,69 @@ class ReduceLROnPlateauConfig:
161
178
  raise ValueError(message)
162
179
 
163
180
 
181
+ @define
182
+ class CosineAnnealingWarmupConfig:
183
+ """Configuration for Cosine Annealing with Linear Warmup scheduler.
184
+
185
+ The learning rate increases linearly during warmup, then decreases following
186
+ a cosine curve to the minimum value.
187
+
188
+ Attributes:
189
+ warmup_epochs: (int) Number of epochs for linear warmup phase. *Default*: `5`.
190
+ max_epochs: (int) Total number of training epochs. Will be overridden by
191
+ trainer's max_epochs if not specified. *Default*: `None`.
192
+ warmup_start_lr: (float) Learning rate at start of warmup. *Default*: `0.0`.
193
+ eta_min: (float) Minimum learning rate at end of cosine decay. *Default*: `0.0`.
194
+ """
195
+
196
+ warmup_epochs: int = field(default=5, validator=validators.ge(0))
197
+ max_epochs: Optional[int] = None
198
+ warmup_start_lr: float = field(default=0.0, validator=validators.ge(0))
199
+ eta_min: float = field(default=0.0, validator=validators.ge(0))
200
+
201
+
202
+ @define
203
+ class LinearWarmupLinearDecayConfig:
204
+ """Configuration for Linear Warmup + Linear Decay scheduler.
205
+
206
+ The learning rate increases linearly during warmup, then decreases linearly
207
+ to the end learning rate.
208
+
209
+ Attributes:
210
+ warmup_epochs: (int) Number of epochs for linear warmup phase. *Default*: `5`.
211
+ max_epochs: (int) Total number of training epochs. Will be overridden by
212
+ trainer's max_epochs if not specified. *Default*: `None`.
213
+ warmup_start_lr: (float) Learning rate at start of warmup. *Default*: `0.0`.
214
+ end_lr: (float) Learning rate at end of training. *Default*: `0.0`.
215
+ """
216
+
217
+ warmup_epochs: int = field(default=5, validator=validators.ge(0))
218
+ max_epochs: Optional[int] = None
219
+ warmup_start_lr: float = field(default=0.0, validator=validators.ge(0))
220
+ end_lr: float = field(default=0.0, validator=validators.ge(0))
221
+
222
+
164
223
  @define
165
224
  class LRSchedulerConfig:
166
225
  """Configuration for lr_scheduler.
167
226
 
227
+ Only one scheduler should be configured at a time. If multiple are set,
228
+ priority order is: cosine_annealing_warmup > linear_warmup_linear_decay >
229
+ step_lr > reduce_lr_on_plateau.
230
+
168
231
  Attributes:
169
232
  step_lr: Configuration for StepLR scheduler.
170
233
  reduce_lr_on_plateau: Configuration for ReduceLROnPlateau scheduler.
234
+ cosine_annealing_warmup: Configuration for Cosine Annealing with Linear Warmup scheduler.
235
+ linear_warmup_linear_decay: Configuration for Linear Warmup + Linear Decay scheduler.
171
236
  """
172
237
 
173
238
  step_lr: Optional[StepLRConfig] = None
174
239
  reduce_lr_on_plateau: Optional[ReduceLROnPlateauConfig] = field(
175
240
  factory=ReduceLROnPlateauConfig
176
241
  )
242
+ cosine_annealing_warmup: Optional[CosineAnnealingWarmupConfig] = None
243
+ linear_warmup_linear_decay: Optional[LinearWarmupLinearDecayConfig] = None
177
244
 
178
245
 
179
246
  @define
@@ -191,6 +258,26 @@ class EarlyStoppingConfig:
191
258
  stop_training_on_plateau: bool = True
192
259
 
193
260
 
261
+ @define
262
+ class EvalConfig:
263
+ """Configuration for epoch-end evaluation.
264
+
265
+ Attributes:
266
+ enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
267
+ frequency: (int) Evaluate every N epochs. *Default*: `1`.
268
+ oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
269
+ oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
270
+ match_threshold: (float) Maximum distance in pixels for centroid matching.
271
+ Only used for centroid model evaluation. *Default*: `50.0`.
272
+ """
273
+
274
+ enabled: bool = False
275
+ frequency: int = field(default=1, validator=validators.ge(1))
276
+ oks_stddev: float = field(default=0.025, validator=validators.gt(0))
277
+ oks_scale: Optional[float] = None
278
+ match_threshold: float = field(default=50.0, validator=validators.gt(0))
279
+
280
+
194
281
  @define
195
282
  class HardKeypointMiningConfig:
196
283
  """Configuration for online hard keypoint mining.
@@ -293,6 +380,7 @@ class TrainerConfig:
293
380
  factory=HardKeypointMiningConfig
294
381
  )
295
382
  zmq: Optional[ZMQConfig] = field(factory=ZMQConfig) # Required for SLEAP GUI
383
+ eval: EvalConfig = field(factory=EvalConfig) # Epoch-end evaluation config
296
384
 
297
385
  @staticmethod
298
386
  def validate_optimizer_name(value):