sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/config/data_config.py
CHANGED
```diff
@@ -6,7 +6,7 @@ the parameters required to initialize the data config.
 
 from attrs import define, field, validators
 from omegaconf import MISSING
-from typing import Optional, Tuple, Any, List
+from typing import Optional, Tuple, Any, List, Union
 from loguru import logger
 import sleap_io as sio
 import yaml
@@ -20,11 +20,15 @@ class PreprocessingConfig:
     Attributes:
         ensure_rgb: (bool) True if the input image should have 3 channels (RGB image). If input has only one channel when this is set to `True`, then the images from single-channel is replicated along the channel axis. If the image has three channels and this is set to False, then we retain the three channels. *Default*: `False`.
         ensure_grayscale: (bool) True if the input image should only have a single channel. If input has three channels (RGB) and this is set to True, then we convert the image to grayscale (single-channel) image. If the source image has only one channel and this is set to False, then we retain the single channel input. *Default*: `False`.
-        max_height: (int) Maximum height the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
-        max_width: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
+        max_height: (int) Maximum height the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
+        max_width: (int) Maximum width the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
         scale: (float) Factor to resize the image dimensions by, specified as a float. *Default*: `1.0`.
-        crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file.
+        crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file.
+            If `scale` is provided, then the cropped image will be resized according to `scale`.*Default*: `None`.
         min_crop_size: (int) Minimum crop size to be used if `crop_size` is `None`. *Default*: `100`.
+        crop_padding: (int) Padding in pixels to add around the instance bounding box when computing crop size.
+            If `None`, padding is auto-computed based on augmentation settings (rotation/scale).
+            Only used when `crop_size` is `None`. *Default*: `None`.
     """
 
     ensure_rgb: bool = False
@@ -36,6 +40,7 @@ class PreprocessingConfig:
     )
     crop_size: Optional[int] = None
     min_crop_size: Optional[int] = 100  # to help app work in case of error
+    crop_padding: Optional[int] = None
 
     def validate_scale(self):
         """Scale Validation.
@@ -104,11 +109,14 @@ class GeometricConfig:
     Attributes:
         rotation_min: (float) Minimum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `-15.0`.
         rotation_max: (float) Maximum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `15.0`.
+        rotation_p: (float, optional) Probability of applying random rotation independently. If set, rotation is applied separately from scale/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `1.0`.
         scale_min: (float) Minimum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `0.9`.
         scale_max: (float) Maximum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `1.1`.
+        scale_p: (float, optional) Probability of applying random scaling independently. If set, scaling is applied separately from rotation/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `1.0`.
         translate_width: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default. *Default*: `0.0`.
         translate_height: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default. *Default*: `0.0`.
-
+        translate_p: (float, optional) Probability of applying random translation independently. If set, translation is applied separately from rotation/scale. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
+        affine_p: (float) Probability of applying random affine transformations (rotation, scale, translate bundled together). Used for backwards compatibility when individual `*_p` params are not set. *Default*: `0.0`.
         erase_scale_min: (float) Minimum value of range of proportion of erased area against input image. *Default*: `0.0001`.
         erase_scale_max: (float) Maximum value of range of proportion of erased area against input image. *Default*: `0.01`.
         erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. *Default*: `1.0`.
@@ -121,10 +129,13 @@ class GeometricConfig:
 
     rotation_min: float = field(default=-15.0, validator=validators.ge(-180))
     rotation_max: float = field(default=15.0, validator=validators.le(180))
+    rotation_p: Optional[float] = field(default=1.0)
     scale_min: float = field(default=0.9, validator=validators.ge(0))
     scale_max: float = field(default=1.1, validator=validators.ge(0))
+    scale_p: Optional[float] = field(default=1.0)
     translate_width: float = 0.0
     translate_height: float = 0.0
+    translate_p: Optional[float] = field(default=None)
     affine_p: float = field(default=0.0, validator=validate_proportion)
     erase_scale_min: float = 0.0001
     erase_scale_max: float = 0.01
@@ -149,6 +160,28 @@ class AugmentationConfig:
     geometric: Optional[GeometricConfig] = None
 
 
+def validate_test_file_path(instance, attribute, value):
+    """Validate test_file_path to accept str or List[str].
+
+    Args:
+        instance: The instance being validated.
+        attribute: The attribute being validated.
+        value: The value to validate.
+
+    Raises:
+        ValueError: If value is not None, str, or list of strings.
+    """
+    if value is None:
+        return
+    if isinstance(value, str):
+        return
+    if isinstance(value, (list, tuple)) and all(isinstance(p, str) for p in value):
+        return
+    message = f"{attribute.name} must be a string or list of strings, got {type(value).__name__}"
+    logger.error(message)
+    raise ValueError(message)
+
+
 @define
 class DataConfig:
     """Data configuration.
@@ -157,13 +190,16 @@ class DataConfig:
         train_labels_path: (List[str]) List of paths to training data (`.slp` file(s)). *Default*: `None`.
         val_labels_path: (List[str]) List of paths to validation data (`.slp` file(s)). *Default*: `None`.
         validation_fraction: (float) Float between 0 and 1 specifying the fraction of the training set to sample for generating the validation set. The remaining labeled frames will be left in the training set. If the `validation_labels` are already specified, this has no effect. *Default*: `0.1`.
-
+        use_same_data_for_val: (bool) If `True`, use the same data for both training and validation (train = val). Useful for intentional overfitting on small datasets. When enabled, `val_labels_path` and `validation_fraction` are ignored. *Default*: `False`.
+        test_file_path: (str or List[str]) Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)). *Note*: This is used only with CLI to get evaluation on test set after training is completed. *Default*: `None`.
         provider: (str) Provider class to read the input sleap files. Only "LabelsReader" is currently supported for the training pipeline. *Default*: `"LabelsReader"`.
         user_instances_only: (bool) `True` if only user labeled instances should be used for training. If `False`, both user labeled and predicted instances would be used. *Default*: `True`.
         data_pipeline_fw: (str) Framework to create the data loaders. One of [`torch_dataset`, `torch_dataset_cache_img_memory`, `torch_dataset_cache_img_disk`]. *Default*: `"torch_dataset"`. (Note: When using `torch_dataset`, `num_workers` in `trainer_config` should be set to 0 as multiprocessing doesn't work with pickling video backends.)
         cache_img_path: (str) Path to save `.jpg` images created with `torch_dataset_cache_img_disk` data pipeline framework. If `None`, the path provided in `trainer_config.save_ckpt` is used. The `train_imgs` and `val_imgs` dirs are created inside this path. *Default*: `None`.
         use_existing_imgs: (bool) Use existing train and val images/ chunks in the `cache_img_path` for `torch_dataset_cache_img_disk` frameworks. If `True`, the `cache_img_path` should have `train_imgs` and `val_imgs` dirs. *Default*: `False`.
         delete_cache_imgs_after_training: (bool) If `False`, the images (torch_dataset_cache_img_disk) are retained after training. Else, the files are deleted. *Default*: `True`.
+        parallel_caching: (bool) If `True`, use parallel processing to cache images (significantly faster for large datasets). *Default*: `True`.
+        cache_workers: (int) Number of worker threads for parallel caching. If 0, uses min(4, cpu_count). *Default*: `0`.
         preprocessing: Configuration options related to data preprocessing.
         use_augmentations_train: (bool) True if the data augmentation should be applied to the training data, else False. *Default*: `True`.
         augmentation_config: Configurations related to augmentation. (only if `use_augmentations_train` is `True`)
@@ -173,16 +209,23 @@ class DataConfig:
     train_labels_path: Optional[List[str]] = None
     val_labels_path: Optional[List[str]] = None # TODO : revisit MISSING!
     validation_fraction: float = 0.1
-
+    use_same_data_for_val: bool = False
+    test_file_path: Optional[Any] = field(
+        default=None, validator=validate_test_file_path
+    )
     provider: str = "LabelsReader"
     user_instances_only: bool = True
     data_pipeline_fw: str = "torch_dataset"
     cache_img_path: Optional[str] = None
     use_existing_imgs: bool = False
     delete_cache_imgs_after_training: bool = True
+    parallel_caching: bool = True
+    cache_workers: int = 0
     preprocessing: PreprocessingConfig = field(factory=PreprocessingConfig)
     use_augmentations_train: bool = True
-    augmentation_config: Optional[AugmentationConfig] =
+    augmentation_config: Optional[AugmentationConfig] = field(
+        factory=lambda: AugmentationConfig(geometric=GeometricConfig())
+    )
     skeletons: Optional[list] = None
 
 
```
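For reference, a minimal sketch of how the new `data_config` fields might be set when building a config programmatically. Class and field names are taken from the diff above; the file paths and chosen values are placeholders, not part of the release.

```python
# Sketch only: assumes these classes are importable from sleap_nn.config.data_config
# as shown in this diff; paths and values below are illustrative.
from sleap_nn.config.data_config import (
    AugmentationConfig,
    DataConfig,
    GeometricConfig,
    PreprocessingConfig,
)

# The new per-transform probabilities decouple rotation/scale/translate from the
# bundled `affine_p` behavior.
geometric = GeometricConfig(
    rotation_min=-45.0,
    rotation_max=45.0,
    rotation_p=1.0,    # always apply rotation
    scale_p=0.5,       # apply scaling on half of the samples
    translate_p=None,  # fall back to `affine_p` for translation
)

data_cfg = DataConfig(
    train_labels_path=["train.pkg.slp"],       # placeholder path
    use_same_data_for_val=True,                # train == val, for intentional overfitting
    test_file_path=["test.pkg.slp"],           # validated to be a str or list of str
    preprocessing=PreprocessingConfig(crop_size=None, crop_padding=16),
    augmentation_config=AugmentationConfig(geometric=geometric),
)
```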
sleap_nn/config/get_config.py
CHANGED
```diff
@@ -131,27 +131,18 @@ def get_aug_config(
 
     for g in geometric_aug:
         if g == "rotation":
-
-            aug_config.geometric.
-            aug_config.geometric.scale_max = 1.0
-            aug_config.geometric.translate_height = 0
-            aug_config.geometric.translate_width = 0
+            # Use new independent rotation probability
+            aug_config.geometric.rotation_p = 1.0
         elif g == "scale":
+            # Use new independent scale probability
             aug_config.geometric.scale_min = 0.9
             aug_config.geometric.scale_max = 1.1
-            aug_config.geometric.
-            aug_config.geometric.rotation_min = 0
-            aug_config.geometric.rotation_max = 0
-            aug_config.geometric.translate_height = 0
-            aug_config.geometric.translate_width = 0
+            aug_config.geometric.scale_p = 1.0
         elif g == "translate":
+            # Use new independent translate probability
             aug_config.geometric.translate_height = 0.2
             aug_config.geometric.translate_width = 0.2
-            aug_config.geometric.
-            aug_config.geometric.rotation_min = 0
-            aug_config.geometric.rotation_max = 0
-            aug_config.geometric.scale_min = 1.0
-            aug_config.geometric.scale_max = 1.0
+            aug_config.geometric.translate_p = 1.0
         elif g == "erase_scale":
             aug_config.geometric.erase_p = 1.0
         elif g == "mixup":
@@ -456,7 +447,8 @@ def get_data_config(
     train_labels_path: Optional[List[str]] = None,
     val_labels_path: Optional[List[str]] = None,
     validation_fraction: float = 0.1,
-
+    use_same_data_for_val: bool = False,
+    test_file_path: Optional[Union[str, List[str]]] = None,
     provider: str = "LabelsReader",
     user_instances_only: bool = True,
    data_pipeline_fw: str = "torch_dataset",
@@ -470,9 +462,10 @@ def get_data_config(
     max_width: Optional[int] = None,
     crop_size: Optional[int] = None,
     min_crop_size: Optional[int] = 100,
-
+    crop_padding: Optional[int] = None,
+    use_augmentations_train: bool = True,
     intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
-    geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] =
+    geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = "rotation",
 ):
     """Train a pose-estimation model with SLEAP-NN framework.
 
@@ -486,7 +479,11 @@ def get_data_config(
             training set to sample for generating the validation set. The remaining
             labeled frames will be left in the training set. If the `validation_labels`
             are already specified, this has no effect. Default: 0.1.
-
+        use_same_data_for_val: If `True`, use the same data for both training and
+            validation (train = val). Useful for intentional overfitting on small
+            datasets. When enabled, `val_labels_path` and `validation_fraction` are
+            ignored. Default: False.
+        test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
             Note: This is used to get evaluation on test set after training is completed.
         provider: Provider class to read the input sleap files. Only "LabelsReader"
             supported for the training pipeline. Default: "LabelsReader".
@@ -508,16 +505,19 @@ def get_data_config(
             is set to True, then we convert the image to grayscale (single-channel)
             image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
         scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
-        max_height: Maximum height the image should be padded to. If not provided, the
+        max_height: Maximum height the original image should be resized and padded to. If not provided, the
             original image size will be retained. Default: None.
-        max_width: Maximum width the image should be padded to. If not provided, the
+        max_width: Maximum width the original image should be resized and padded to. If not provided, the
             original image size will be retained. Default: None.
         crop_size: Crop size of each instance for centered-instance model.
             If `None`, this would be automatically computed based on the largest instance
-            in the `sio.Labels` file. Default: None.
+            in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
         min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
+        crop_padding: Padding in pixels to add around instance bounding box when computing
+            crop size. If `None`, padding is auto-computed based on augmentation settings.
+            Only used when `crop_size` is `None`. Default: None.
         use_augmentations_train: True if the data augmentation should be applied to the
-            training data, else False. Default:
+            training data, else False. Default: True.
         intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
             or list of strings from the above allowed values. To have custom values, pass
             a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
@@ -529,7 +529,8 @@ def get_data_config(
             or list of strings from the above allowed values. To have custom values, pass
             a dict with the structure in `sleap_nn.config.data_config.GeometryConfig`.
             For eg: {
-                "
+                "rotation_min": -45,
+                "rotation_max": 45,
                 "affine_p": 1.0
             }
     """
@@ -541,6 +542,7 @@ def get_data_config(
         scale=scale,
         crop_size=crop_size,
         min_crop_size=min_crop_size,
+        crop_padding=crop_padding,
     )
     augmentation_config = None
     if use_augmentations_train:
@@ -553,6 +555,7 @@ def get_data_config(
        train_labels_path=train_labels_path,
        val_labels_path=val_labels_path,
        validation_fraction=validation_fraction,
+       use_same_data_for_val=use_same_data_for_val,
        test_file_path=test_file_path,
        provider=provider,
        user_instances_only=user_instances_only,
@@ -675,6 +678,7 @@ def get_trainer_config(
    wandb_save_viz_imgs_wandb: bool = False,
    wandb_resume_prv_runid: Optional[str] = None,
    wandb_group_name: Optional[str] = None,
+   wandb_delete_local_logs: Optional[bool] = None,
    optimizer: str = "Adam",
    learning_rate: float = 1e-3,
    amsgrad: bool = False,
@@ -744,6 +748,9 @@ def get_trainer_config(
        wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
            ckpt. Default: None
        wandb_group_name: Group name for the wandb run. Default: None.
+       wandb_delete_local_logs: If True, delete local wandb logs folder after training.
+           If False, keep the folder. If None (default), automatically delete if logging
+           online (wandb_mode != "offline") and keep if logging offline. Default: None.
        optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
        learning_rate: Learning rate of type float. Default: 1e-3.
        amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -844,6 +851,7 @@ def get_trainer_config(
            save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
            prv_runid=wandb_resume_prv_runid,
            group=wandb_group_name,
+           delete_local_logs=wandb_delete_local_logs,
        ),
        save_ckpt=save_ckpt,
        ckpt_dir=ckpt_dir,
```
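As a usage sketch, the updated helper might be called like this (assuming `get_data_config` is importable from `sleap_nn.config.get_config`, matching the file path above; the paths and values are placeholders):

```python
# Sketch: build a DataConfig through the updated helper. Paths are placeholders.
from sleap_nn.config.get_config import get_data_config

data_cfg = get_data_config(
    train_labels_path=["train.pkg.slp"],
    validation_fraction=0.1,
    use_same_data_for_val=False,
    test_file_path="test.pkg.slp",  # accepts a str or a list of str
    crop_padding=32,                # only consulted when crop_size is None
    use_augmentations_train=True,
    geometry_aug="rotation",        # new default; maps to rotation_p = 1.0
)
```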
sleap_nn/config/trainer_config.py
CHANGED

```diff
@@ -84,6 +84,16 @@ class WandBConfig:
         prv_runid: (str) Previous run ID if training should be resumed from a previous ckpt. *Default*: `None`.
         group: (str) Group for wandb logging. *Default*: `None`.
         current_run_id: (str) Run ID for the current model training. (stored once the training starts). *Default*: `None`.
+        viz_enabled: (bool) If True, log pre-rendered matplotlib images to wandb. *Default*: `True`.
+        viz_boxes: (bool) If True, log interactive keypoint boxes. *Default*: `False`.
+        viz_masks: (bool) If True, log confidence map overlay masks. *Default*: `False`.
+        viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
+        viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
+        log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
+        delete_local_logs: (bool, optional) If True, delete local wandb logs folder after
+            training. If False, keep the folder. If None (default), automatically delete
+            if logging online (wandb_mode != "offline") and keep if logging offline.
+            *Default*: `None`.
     """
 
     entity: Optional[str] = None
@@ -95,6 +105,13 @@ class WandBConfig:
     prv_runid: Optional[str] = None
     group: Optional[str] = None
     current_run_id: Optional[str] = None
+    viz_enabled: bool = True
+    viz_boxes: bool = False
+    viz_masks: bool = False
+    viz_box_size: float = 5.0
+    viz_confmap_threshold: float = 0.1
+    log_viz_table: bool = False
+    delete_local_logs: Optional[bool] = None
 
 
 @define
```
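A short sketch of how the new logging options might be set (assuming `WandBConfig` lives in `sleap_nn.config.trainer_config`, as the file list suggests; the entity name is a placeholder):

```python
# Sketch: wandb logging with the new visualization and cleanup options.
from sleap_nn.config.trainer_config import WandBConfig

wandb_cfg = WandBConfig(
    entity="my-team",          # placeholder entity
    viz_enabled=True,          # log pre-rendered matplotlib images
    viz_boxes=True,            # also log interactive keypoint boxes
    viz_box_size=5.0,
    viz_confmap_threshold=0.1,
    log_viz_table=False,
    delete_local_logs=None,    # auto: delete when logging online, keep when offline
)
```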
```diff
@@ -161,19 +178,69 @@ class ReduceLROnPlateauConfig:
         raise ValueError(message)
 
 
+@define
+class CosineAnnealingWarmupConfig:
+    """Configuration for Cosine Annealing with Linear Warmup scheduler.
+
+    The learning rate increases linearly during warmup, then decreases following
+    a cosine curve to the minimum value.
+
+    Attributes:
+        warmup_epochs: (int) Number of epochs for linear warmup phase. *Default*: `5`.
+        max_epochs: (int) Total number of training epochs. Will be overridden by
+            trainer's max_epochs if not specified. *Default*: `None`.
+        warmup_start_lr: (float) Learning rate at start of warmup. *Default*: `0.0`.
+        eta_min: (float) Minimum learning rate at end of cosine decay. *Default*: `0.0`.
+    """
+
+    warmup_epochs: int = field(default=5, validator=validators.ge(0))
+    max_epochs: Optional[int] = None
+    warmup_start_lr: float = field(default=0.0, validator=validators.ge(0))
+    eta_min: float = field(default=0.0, validator=validators.ge(0))
+
+
+@define
+class LinearWarmupLinearDecayConfig:
+    """Configuration for Linear Warmup + Linear Decay scheduler.
+
+    The learning rate increases linearly during warmup, then decreases linearly
+    to the end learning rate.
+
+    Attributes:
+        warmup_epochs: (int) Number of epochs for linear warmup phase. *Default*: `5`.
+        max_epochs: (int) Total number of training epochs. Will be overridden by
+            trainer's max_epochs if not specified. *Default*: `None`.
+        warmup_start_lr: (float) Learning rate at start of warmup. *Default*: `0.0`.
+        end_lr: (float) Learning rate at end of training. *Default*: `0.0`.
+    """
+
+    warmup_epochs: int = field(default=5, validator=validators.ge(0))
+    max_epochs: Optional[int] = None
+    warmup_start_lr: float = field(default=0.0, validator=validators.ge(0))
+    end_lr: float = field(default=0.0, validator=validators.ge(0))
+
+
 @define
 class LRSchedulerConfig:
     """Configuration for lr_scheduler.
 
+    Only one scheduler should be configured at a time. If multiple are set,
+    priority order is: cosine_annealing_warmup > linear_warmup_linear_decay >
+    step_lr > reduce_lr_on_plateau.
+
     Attributes:
         step_lr: Configuration for StepLR scheduler.
         reduce_lr_on_plateau: Configuration for ReduceLROnPlateau scheduler.
+        cosine_annealing_warmup: Configuration for Cosine Annealing with Linear Warmup scheduler.
+        linear_warmup_linear_decay: Configuration for Linear Warmup + Linear Decay scheduler.
     """
 
     step_lr: Optional[StepLRConfig] = None
     reduce_lr_on_plateau: Optional[ReduceLROnPlateauConfig] = field(
         factory=ReduceLROnPlateauConfig
     )
+    cosine_annealing_warmup: Optional[CosineAnnealingWarmupConfig] = None
+    linear_warmup_linear_decay: Optional[LinearWarmupLinearDecayConfig] = None
 
 
 @define
```
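For orientation, a minimal sketch of selecting the new cosine-annealing-with-warmup scheduler (class and field names from the diff above; the `eta_min` value is illustrative):

```python
# Sketch: configure exactly one scheduler; if several are set, the priority
# order documented above applies (cosine_annealing_warmup wins).
from sleap_nn.config.trainer_config import (
    CosineAnnealingWarmupConfig,
    LRSchedulerConfig,
)

lr_scheduler = LRSchedulerConfig(
    reduce_lr_on_plateau=None,  # turn off the default ReduceLROnPlateau
    cosine_annealing_warmup=CosineAnnealingWarmupConfig(
        warmup_epochs=5,     # linear ramp starting at warmup_start_lr
        max_epochs=None,     # overridden by the trainer's max_epochs
        warmup_start_lr=0.0,
        eta_min=1e-6,        # floor of the cosine decay
    ),
)
```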
```diff
@@ -191,6 +258,26 @@ class EarlyStoppingConfig:
     stop_training_on_plateau: bool = True
 
 
+@define
+class EvalConfig:
+    """Configuration for epoch-end evaluation.
+
+    Attributes:
+        enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
+        frequency: (int) Evaluate every N epochs. *Default*: `1`.
+        oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
+        oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
+        match_threshold: (float) Maximum distance in pixels for centroid matching.
+            Only used for centroid model evaluation. *Default*: `50.0`.
+    """
+
+    enabled: bool = False
+    frequency: int = field(default=1, validator=validators.ge(1))
+    oks_stddev: float = field(default=0.025, validator=validators.gt(0))
+    oks_scale: Optional[float] = None
+    match_threshold: float = field(default=50.0, validator=validators.gt(0))
+
+
 @define
 class HardKeypointMiningConfig:
     """Configuration for online hard keypoint mining.
@@ -293,6 +380,7 @@ class TrainerConfig:
         factory=HardKeypointMiningConfig
     )
     zmq: Optional[ZMQConfig] = field(factory=ZMQConfig) # Required for SLEAP GUI
+    eval: EvalConfig = field(factory=EvalConfig) # Epoch-end evaluation config
 
     @staticmethod
     def validate_optimizer_name(value):
```
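As a closing sketch, the new epoch-end evaluation block might be enabled like this (assuming `EvalConfig` is exported from `sleap_nn.config.trainer_config`; the chosen values are placeholders):

```python
# Sketch: enable epoch-end evaluation every 5 epochs.
from sleap_nn.config.trainer_config import EvalConfig

eval_cfg = EvalConfig(
    enabled=True,
    frequency=5,           # evaluate every 5 epochs
    oks_stddev=0.025,
    match_threshold=50.0,  # max centroid-match distance in pixels
)

# trainer_cfg.eval = eval_cfg  # attach to an existing TrainerConfig;
# all other TrainerConfig fields keep their defaults.
```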