sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED
|
@@ -48,4 +48,9 @@ logger.add(
|
|
|
48
48
|
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
|
|
49
49
|
)
|
|
50
50
|
|
|
51
|
-
__version__ = "0.0.5"
|
|
51
|
+
__version__ = "0.1.0a0"
|
|
52
|
+
|
|
53
|
+
# Public API
|
|
54
|
+
from sleap_nn.evaluation import load_metrics
|
|
55
|
+
|
|
56
|
+
__all__ = ["load_metrics", "__version__"]
|
sleap_nn/cli.py
CHANGED
|
@@ -4,14 +4,24 @@ import click
|
|
|
4
4
|
from loguru import logger
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from omegaconf import OmegaConf, DictConfig
|
|
7
|
+
import sleap_io as sio
|
|
7
8
|
from sleap_nn.predict import run_inference, frame_list
|
|
8
9
|
from sleap_nn.evaluation import run_evaluation
|
|
9
10
|
from sleap_nn.train import run_training
|
|
11
|
+
from sleap_nn import __version__
|
|
10
12
|
import hydra
|
|
11
13
|
import sys
|
|
12
14
|
from click import Command
|
|
13
15
|
|
|
14
16
|
|
|
17
|
+
def print_version(ctx, param, value):
|
|
18
|
+
"""Print version and exit."""
|
|
19
|
+
if not value or ctx.resilient_parsing:
|
|
20
|
+
return
|
|
21
|
+
click.echo(f"sleap-nn {__version__}")
|
|
22
|
+
ctx.exit()
|
|
23
|
+
|
|
24
|
+
|
|
15
25
|
class TrainCommand(Command):
|
|
16
26
|
"""Custom command class that overrides help behavior for train command."""
|
|
17
27
|
|
|
@@ -20,7 +30,26 @@ class TrainCommand(Command):
|
|
|
20
30
|
show_training_help()
|
|
21
31
|
|
|
22
32
|
|
|
33
|
+
def parse_path_map(ctx, param, value):
|
|
34
|
+
"""Parse (old, new) path pairs into a dictionary for path mapping options."""
|
|
35
|
+
if not value:
|
|
36
|
+
return None
|
|
37
|
+
result = {}
|
|
38
|
+
for old_path, new_path in value:
|
|
39
|
+
result[old_path] = Path(new_path).as_posix()
|
|
40
|
+
return result
|
|
41
|
+
|
|
42
|
+
|
|
23
43
|
@click.group()
|
|
44
|
+
@click.option(
|
|
45
|
+
"--version",
|
|
46
|
+
"-v",
|
|
47
|
+
is_flag=True,
|
|
48
|
+
callback=print_version,
|
|
49
|
+
expose_value=False,
|
|
50
|
+
is_eager=True,
|
|
51
|
+
help="Show version and exit.",
|
|
52
|
+
)
|
|
24
53
|
def cli():
|
|
25
54
|
"""SLEAP-NN: Neural network backend for training and inference for animal pose estimation.
|
|
26
55
|
|
|
@@ -29,6 +58,7 @@ def cli():
|
|
|
29
58
|
train - Run training workflow
|
|
30
59
|
track - Run inference/ tracking workflow
|
|
31
60
|
eval - Run evaluation workflow
|
|
61
|
+
system - Display system information and GPU status
|
|
32
62
|
"""
|
|
33
63
|
pass
|
|
34
64
|
|
|
@@ -67,8 +97,38 @@ For a detailed list of all available config options, please refer to https://nn.
|
|
|
67
97
|
@click.option(
|
|
68
98
|
"--config-dir", "-d", type=str, default=".", help="Configuration directory path"
|
|
69
99
|
)
|
|
100
|
+
@click.option(
|
|
101
|
+
"--video-paths",
|
|
102
|
+
"-v",
|
|
103
|
+
multiple=True,
|
|
104
|
+
help="Video paths to replace existing paths in the labels file. "
|
|
105
|
+
"Order must match the order of videos in the labels file. "
|
|
106
|
+
"Can be specified multiple times. "
|
|
107
|
+
"Example: --video-paths /path/to/vid1.mp4 --video-paths /path/to/vid2.mp4",
|
|
108
|
+
)
|
|
109
|
+
@click.option(
|
|
110
|
+
"--video-path-map",
|
|
111
|
+
nargs=2,
|
|
112
|
+
multiple=True,
|
|
113
|
+
callback=parse_path_map,
|
|
114
|
+
metavar="OLD NEW",
|
|
115
|
+
help="Map old video path to new path. Takes two arguments: old path and new path. "
|
|
116
|
+
"Can be specified multiple times. "
|
|
117
|
+
'Example: --video-path-map "/old/vid.mp`4" "/new/vid.mp4"',
|
|
118
|
+
)
|
|
119
|
+
@click.option(
|
|
120
|
+
"--prefix-map",
|
|
121
|
+
nargs=2,
|
|
122
|
+
multiple=True,
|
|
123
|
+
callback=parse_path_map,
|
|
124
|
+
metavar="OLD NEW",
|
|
125
|
+
help="Map old path prefix to new prefix. Takes two arguments: old prefix and new prefix. "
|
|
126
|
+
"Updates ALL videos that share the same prefix. Useful when moving data between machines. "
|
|
127
|
+
"Can be specified multiple times. "
|
|
128
|
+
'Example: --prefix-map "/old/server/path" "/new/local/path"',
|
|
129
|
+
)
|
|
70
130
|
@click.argument("overrides", nargs=-1, type=click.UNPROCESSED)
|
|
71
|
-
def train(config_name, config_dir, overrides):
|
|
131
|
+
def train(config_name, config_dir, video_paths, video_path_map, prefix_map, overrides):
|
|
72
132
|
"""Run training workflow with Hydra config overrides.
|
|
73
133
|
|
|
74
134
|
Examples:
|
|
@@ -97,7 +157,62 @@ def train(config_name, config_dir, overrides):
|
|
|
97
157
|
|
|
98
158
|
logger.info("Input config:")
|
|
99
159
|
logger.info("\n" + OmegaConf.to_yaml(cfg))
|
|
100
|
-
|
|
160
|
+
|
|
161
|
+
# Handle video path replacement options
|
|
162
|
+
train_labels = None
|
|
163
|
+
val_labels = None
|
|
164
|
+
|
|
165
|
+
# Check that only one replacement option is used
|
|
166
|
+
# video_paths is a tuple (empty if not used), others are None or dict
|
|
167
|
+
has_video_paths = len(video_paths) > 0
|
|
168
|
+
has_video_path_map = video_path_map is not None
|
|
169
|
+
has_prefix_map = prefix_map is not None
|
|
170
|
+
options_used = sum([has_video_paths, has_video_path_map, has_prefix_map])
|
|
171
|
+
|
|
172
|
+
if options_used > 1:
|
|
173
|
+
raise click.UsageError(
|
|
174
|
+
"Cannot use multiple path replacement options. "
|
|
175
|
+
"Choose one of: --video-paths, --video-path-map, or --prefix-map."
|
|
176
|
+
)
|
|
177
|
+
|
|
178
|
+
if options_used == 1:
|
|
179
|
+
# Load train labels
|
|
180
|
+
train_labels = [
|
|
181
|
+
sio.load_slp(path) for path in cfg.data_config.train_labels_path
|
|
182
|
+
]
|
|
183
|
+
|
|
184
|
+
# Load val labels if they exist
|
|
185
|
+
if (
|
|
186
|
+
cfg.data_config.val_labels_path is not None
|
|
187
|
+
and len(cfg.data_config.val_labels_path) > 0
|
|
188
|
+
):
|
|
189
|
+
val_labels = [
|
|
190
|
+
sio.load_slp(path) for path in cfg.data_config.val_labels_path
|
|
191
|
+
]
|
|
192
|
+
|
|
193
|
+
# Build replacement arguments based on option used
|
|
194
|
+
if has_video_paths:
|
|
195
|
+
# List of paths (order must match videos in labels file)
|
|
196
|
+
replace_kwargs = {
|
|
197
|
+
"new_filenames": [Path(p).as_posix() for p in video_paths]
|
|
198
|
+
}
|
|
199
|
+
elif has_video_path_map:
|
|
200
|
+
# Dictionary mapping old filenames to new filenames
|
|
201
|
+
replace_kwargs = {"filename_map": video_path_map}
|
|
202
|
+
else: # has_prefix_map
|
|
203
|
+
# Dictionary mapping old prefixes to new prefixes
|
|
204
|
+
replace_kwargs = {"prefix_map": prefix_map}
|
|
205
|
+
|
|
206
|
+
# Apply replacement to train labels
|
|
207
|
+
for labels in train_labels:
|
|
208
|
+
labels.replace_filenames(**replace_kwargs)
|
|
209
|
+
|
|
210
|
+
# Apply replacement to val labels if they exist
|
|
211
|
+
if val_labels:
|
|
212
|
+
for labels in val_labels:
|
|
213
|
+
labels.replace_filenames(**replace_kwargs)
|
|
214
|
+
|
|
215
|
+
run_training(config=cfg, train_labels=train_labels, val_labels=val_labels)
|
|
101
216
|
|
|
102
217
|
|
|
103
218
|
@cli.command()
|
|
@@ -209,6 +324,18 @@ def train(config_name, config_dir, overrides):
|
|
|
209
324
|
default=False,
|
|
210
325
|
help="Only run inference on unlabeled suggested frames when running on labels dataset. This is useful for generating predictions for initialization during labeling.",
|
|
211
326
|
)
|
|
327
|
+
@click.option(
|
|
328
|
+
"--exclude_user_labeled",
|
|
329
|
+
is_flag=True,
|
|
330
|
+
default=False,
|
|
331
|
+
help="Skip frames that have user-labeled instances. Useful when predicting on entire video but skipping already-labeled frames.",
|
|
332
|
+
)
|
|
333
|
+
@click.option(
|
|
334
|
+
"--only_predicted_frames",
|
|
335
|
+
is_flag=True,
|
|
336
|
+
default=False,
|
|
337
|
+
help="Only run inference on frames that already have predictions. Requires .slp input file. Useful for re-predicting with a different model.",
|
|
338
|
+
)
|
|
212
339
|
@click.option(
|
|
213
340
|
"--no_empty_frames",
|
|
214
341
|
is_flag=True,
|
|
@@ -282,7 +409,7 @@ def train(config_name, config_dir, overrides):
|
|
|
282
409
|
"--crop_size",
|
|
283
410
|
type=int,
|
|
284
411
|
default=None,
|
|
285
|
-
help="Crop size. If not provided, the crop size from training_config.yaml is used.",
|
|
412
|
+
help="Crop size. If not provided, the crop size from training_config.yaml is used. If `input_scale` is provided, then the cropped image will be resized according to `input_scale`.",
|
|
286
413
|
)
|
|
287
414
|
@click.option(
|
|
288
415
|
"--peak_threshold",
|
|
@@ -474,5 +601,17 @@ def eval(**kwargs):
|
|
|
474
601
|
run_evaluation(**kwargs)
|
|
475
602
|
|
|
476
603
|
|
|
604
|
+
@cli.command()
|
|
605
|
+
def system():
|
|
606
|
+
"""Display system information and GPU status.
|
|
607
|
+
|
|
608
|
+
Shows Python version, platform, PyTorch version, CUDA availability,
|
|
609
|
+
driver version with compatibility check, GPU details, and package versions.
|
|
610
|
+
"""
|
|
611
|
+
from sleap_nn.system_info import print_system_info
|
|
612
|
+
|
|
613
|
+
print_system_info()
|
|
614
|
+
|
|
615
|
+
|
|
477
616
|
if __name__ == "__main__":
|
|
478
617
|
cli()
|
sleap_nn/config/data_config.py
CHANGED
|
@@ -6,7 +6,7 @@ the parameters required to initialize the data config.
|
|
|
6
6
|
|
|
7
7
|
from attrs import define, field, validators
|
|
8
8
|
from omegaconf import MISSING
|
|
9
|
-
from typing import Optional, Tuple, Any, List
|
|
9
|
+
from typing import Optional, Tuple, Any, List, Union
|
|
10
10
|
from loguru import logger
|
|
11
11
|
import sleap_io as sio
|
|
12
12
|
import yaml
|
|
@@ -20,11 +20,15 @@ class PreprocessingConfig:
|
|
|
20
20
|
Attributes:
|
|
21
21
|
ensure_rgb: (bool) True if the input image should have 3 channels (RGB image). If input has only one channel when this is set to `True`, then the images from single-channel is replicated along the channel axis. If the image has three channels and this is set to False, then we retain the three channels. *Default*: `False`.
|
|
22
22
|
ensure_grayscale: (bool) True if the input image should only have a single channel. If input has three channels (RGB) and this is set to True, then we convert the image to grayscale (single-channel) image. If the source image has only one channel and this is set to False, then we retain the single channel input. *Default*: `False`.
|
|
23
|
-
max_height: (int) Maximum height the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
|
|
24
|
-
max_width: (int) Maximum width the image should be padded to. If not provided, the original image size will be retained. *Default*: `None`.
|
|
23
|
+
max_height: (int) Maximum height the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
|
|
24
|
+
max_width: (int) Maximum width the original image should be resized and padded to. If not provided, the original image size will be retained. *Default*: `None`.
|
|
25
25
|
scale: (float) Factor to resize the image dimensions by, specified as a float. *Default*: `1.0`.
|
|
26
|
-
crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file.
|
|
26
|
+
crop_size: (int) Crop size of each instance for centered-instance model. If `None`, this would be automatically computed based on the largest instance in the `sio.Labels` file.
|
|
27
|
+
If `scale` is provided, then the cropped image will be resized according to `scale`.*Default*: `None`.
|
|
27
28
|
min_crop_size: (int) Minimum crop size to be used if `crop_size` is `None`. *Default*: `100`.
|
|
29
|
+
crop_padding: (int) Padding in pixels to add around the instance bounding box when computing crop size.
|
|
30
|
+
If `None`, padding is auto-computed based on augmentation settings (rotation/scale).
|
|
31
|
+
Only used when `crop_size` is `None`. *Default*: `None`.
|
|
28
32
|
"""
|
|
29
33
|
|
|
30
34
|
ensure_rgb: bool = False
|
|
@@ -36,6 +40,7 @@ class PreprocessingConfig:
|
|
|
36
40
|
)
|
|
37
41
|
crop_size: Optional[int] = None
|
|
38
42
|
min_crop_size: Optional[int] = 100 # to help app work in case of error
|
|
43
|
+
crop_padding: Optional[int] = None
|
|
39
44
|
|
|
40
45
|
def validate_scale(self):
|
|
41
46
|
"""Scale Validation.
|
|
@@ -104,11 +109,14 @@ class GeometricConfig:
|
|
|
104
109
|
Attributes:
|
|
105
110
|
rotation_min: (float) Minimum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `-15.0`.
|
|
106
111
|
rotation_max: (float) Maximum rotation angle in degrees. A random angle in (rotation_min, rotation_max) will be sampled and applied to both images and keypoints. Set to 0 to disable rotation augmentation. *Default*: `15.0`.
|
|
112
|
+
rotation_p: (float, optional) Probability of applying random rotation independently. If set, rotation is applied separately from scale/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
|
|
107
113
|
scale_min: (float) Minimum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `0.9`.
|
|
108
114
|
scale_max: (float) Maximum scaling factor. If scale_min and scale_max are provided, the scale is randomly sampled from the range scale_min <= scale <= scale_max for isotropic scaling. *Default*: `1.1`.
|
|
115
|
+
scale_p: (float, optional) Probability of applying random scaling independently. If set, scaling is applied separately from rotation/translate. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
|
|
109
116
|
translate_width: (float) Maximum absolute fraction for horizontal translation. For example, if translate_width=a, then horizontal shift is randomly sampled in the range -img_width * a < dx < img_width * a. Will not translate by default. *Default*: `0.0`.
|
|
110
117
|
translate_height: (float) Maximum absolute fraction for vertical translation. For example, if translate_height=a, then vertical shift is randomly sampled in the range -img_height * a < dy < img_height * a. Will not translate by default. *Default*: `0.0`.
|
|
111
|
-
|
|
118
|
+
translate_p: (float, optional) Probability of applying random translation independently. If set, translation is applied separately from rotation/scale. If `None`, falls back to `affine_p` for bundled behavior. *Default*: `None`.
|
|
119
|
+
affine_p: (float) Probability of applying random affine transformations (rotation, scale, translate bundled together). Used for backwards compatibility when individual `*_p` params are not set. *Default*: `0.0`.
|
|
112
120
|
erase_scale_min: (float) Minimum value of range of proportion of erased area against input image. *Default*: `0.0001`.
|
|
113
121
|
erase_scale_max: (float) Maximum value of range of proportion of erased area against input image. *Default*: `0.01`.
|
|
114
122
|
erase_ratio_min: (float) Minimum value of range of aspect ratio of erased area. *Default*: `1.0`.
|
|
@@ -121,10 +129,13 @@ class GeometricConfig:
|
|
|
121
129
|
|
|
122
130
|
rotation_min: float = field(default=-15.0, validator=validators.ge(-180))
|
|
123
131
|
rotation_max: float = field(default=15.0, validator=validators.le(180))
|
|
132
|
+
rotation_p: Optional[float] = field(default=None)
|
|
124
133
|
scale_min: float = field(default=0.9, validator=validators.ge(0))
|
|
125
134
|
scale_max: float = field(default=1.1, validator=validators.ge(0))
|
|
135
|
+
scale_p: Optional[float] = field(default=None)
|
|
126
136
|
translate_width: float = 0.0
|
|
127
137
|
translate_height: float = 0.0
|
|
138
|
+
translate_p: Optional[float] = field(default=None)
|
|
128
139
|
affine_p: float = field(default=0.0, validator=validate_proportion)
|
|
129
140
|
erase_scale_min: float = 0.0001
|
|
130
141
|
erase_scale_max: float = 0.01
|
|
@@ -149,6 +160,28 @@ class AugmentationConfig:
|
|
|
149
160
|
geometric: Optional[GeometricConfig] = None
|
|
150
161
|
|
|
151
162
|
|
|
163
|
+
def validate_test_file_path(instance, attribute, value):
|
|
164
|
+
"""Validate test_file_path to accept str or List[str].
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
instance: The instance being validated.
|
|
168
|
+
attribute: The attribute being validated.
|
|
169
|
+
value: The value to validate.
|
|
170
|
+
|
|
171
|
+
Raises:
|
|
172
|
+
ValueError: If value is not None, str, or list of strings.
|
|
173
|
+
"""
|
|
174
|
+
if value is None:
|
|
175
|
+
return
|
|
176
|
+
if isinstance(value, str):
|
|
177
|
+
return
|
|
178
|
+
if isinstance(value, (list, tuple)) and all(isinstance(p, str) for p in value):
|
|
179
|
+
return
|
|
180
|
+
message = f"{attribute.name} must be a string or list of strings, got {type(value).__name__}"
|
|
181
|
+
logger.error(message)
|
|
182
|
+
raise ValueError(message)
|
|
183
|
+
|
|
184
|
+
|
|
152
185
|
@define
|
|
153
186
|
class DataConfig:
|
|
154
187
|
"""Data configuration.
|
|
@@ -157,7 +190,8 @@ class DataConfig:
|
|
|
157
190
|
train_labels_path: (List[str]) List of paths to training data (`.slp` file(s)). *Default*: `None`.
|
|
158
191
|
val_labels_path: (List[str]) List of paths to validation data (`.slp` file(s)). *Default*: `None`.
|
|
159
192
|
validation_fraction: (float) Float between 0 and 1 specifying the fraction of the training set to sample for generating the validation set. The remaining labeled frames will be left in the training set. If the `validation_labels` are already specified, this has no effect. *Default*: `0.1`.
|
|
160
|
-
|
|
193
|
+
use_same_data_for_val: (bool) If `True`, use the same data for both training and validation (train = val). Useful for intentional overfitting on small datasets. When enabled, `val_labels_path` and `validation_fraction` are ignored. *Default*: `False`.
|
|
194
|
+
test_file_path: (str or List[str]) Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)). *Note*: This is used only with CLI to get evaluation on test set after training is completed. *Default*: `None`.
|
|
161
195
|
provider: (str) Provider class to read the input sleap files. Only "LabelsReader" is currently supported for the training pipeline. *Default*: `"LabelsReader"`.
|
|
162
196
|
user_instances_only: (bool) `True` if only user labeled instances should be used for training. If `False`, both user labeled and predicted instances would be used. *Default*: `True`.
|
|
163
197
|
data_pipeline_fw: (str) Framework to create the data loaders. One of [`torch_dataset`, `torch_dataset_cache_img_memory`, `torch_dataset_cache_img_disk`]. *Default*: `"torch_dataset"`. (Note: When using `torch_dataset`, `num_workers` in `trainer_config` should be set to 0 as multiprocessing doesn't work with pickling video backends.)
|
|
@@ -173,7 +207,10 @@ class DataConfig:
|
|
|
173
207
|
train_labels_path: Optional[List[str]] = None
|
|
174
208
|
val_labels_path: Optional[List[str]] = None # TODO : revisit MISSING!
|
|
175
209
|
validation_fraction: float = 0.1
|
|
176
|
-
|
|
210
|
+
use_same_data_for_val: bool = False
|
|
211
|
+
test_file_path: Optional[Any] = field(
|
|
212
|
+
default=None, validator=validate_test_file_path
|
|
213
|
+
)
|
|
177
214
|
provider: str = "LabelsReader"
|
|
178
215
|
user_instances_only: bool = True
|
|
179
216
|
data_pipeline_fw: str = "torch_dataset"
|
sleap_nn/config/get_config.py
CHANGED
|
@@ -131,27 +131,18 @@ def get_aug_config(
|
|
|
131
131
|
|
|
132
132
|
for g in geometric_aug:
|
|
133
133
|
if g == "rotation":
|
|
134
|
-
|
|
135
|
-
aug_config.geometric.
|
|
136
|
-
aug_config.geometric.scale_max = 1.0
|
|
137
|
-
aug_config.geometric.translate_height = 0
|
|
138
|
-
aug_config.geometric.translate_width = 0
|
|
134
|
+
# Use new independent rotation probability
|
|
135
|
+
aug_config.geometric.rotation_p = 1.0
|
|
139
136
|
elif g == "scale":
|
|
137
|
+
# Use new independent scale probability
|
|
140
138
|
aug_config.geometric.scale_min = 0.9
|
|
141
139
|
aug_config.geometric.scale_max = 1.1
|
|
142
|
-
aug_config.geometric.
|
|
143
|
-
aug_config.geometric.rotation_min = 0
|
|
144
|
-
aug_config.geometric.rotation_max = 0
|
|
145
|
-
aug_config.geometric.translate_height = 0
|
|
146
|
-
aug_config.geometric.translate_width = 0
|
|
140
|
+
aug_config.geometric.scale_p = 1.0
|
|
147
141
|
elif g == "translate":
|
|
142
|
+
# Use new independent translate probability
|
|
148
143
|
aug_config.geometric.translate_height = 0.2
|
|
149
144
|
aug_config.geometric.translate_width = 0.2
|
|
150
|
-
aug_config.geometric.
|
|
151
|
-
aug_config.geometric.rotation_min = 0
|
|
152
|
-
aug_config.geometric.rotation_max = 0
|
|
153
|
-
aug_config.geometric.scale_min = 1.0
|
|
154
|
-
aug_config.geometric.scale_max = 1.0
|
|
145
|
+
aug_config.geometric.translate_p = 1.0
|
|
155
146
|
elif g == "erase_scale":
|
|
156
147
|
aug_config.geometric.erase_p = 1.0
|
|
157
148
|
elif g == "mixup":
|
|
@@ -456,7 +447,8 @@ def get_data_config(
|
|
|
456
447
|
train_labels_path: Optional[List[str]] = None,
|
|
457
448
|
val_labels_path: Optional[List[str]] = None,
|
|
458
449
|
validation_fraction: float = 0.1,
|
|
459
|
-
|
|
450
|
+
use_same_data_for_val: bool = False,
|
|
451
|
+
test_file_path: Optional[Union[str, List[str]]] = None,
|
|
460
452
|
provider: str = "LabelsReader",
|
|
461
453
|
user_instances_only: bool = True,
|
|
462
454
|
data_pipeline_fw: str = "torch_dataset",
|
|
@@ -470,6 +462,7 @@ def get_data_config(
|
|
|
470
462
|
max_width: Optional[int] = None,
|
|
471
463
|
crop_size: Optional[int] = None,
|
|
472
464
|
min_crop_size: Optional[int] = 100,
|
|
465
|
+
crop_padding: Optional[int] = None,
|
|
473
466
|
use_augmentations_train: bool = False,
|
|
474
467
|
intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
|
|
475
468
|
geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
|
|
@@ -486,7 +479,11 @@ def get_data_config(
|
|
|
486
479
|
training set to sample for generating the validation set. The remaining
|
|
487
480
|
labeled frames will be left in the training set. If the `validation_labels`
|
|
488
481
|
are already specified, this has no effect. Default: 0.1.
|
|
489
|
-
|
|
482
|
+
use_same_data_for_val: If `True`, use the same data for both training and
|
|
483
|
+
validation (train = val). Useful for intentional overfitting on small
|
|
484
|
+
datasets. When enabled, `val_labels_path` and `validation_fraction` are
|
|
485
|
+
ignored. Default: False.
|
|
486
|
+
test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
|
|
490
487
|
Note: This is used to get evaluation on test set after training is completed.
|
|
491
488
|
provider: Provider class to read the input sleap files. Only "LabelsReader"
|
|
492
489
|
supported for the training pipeline. Default: "LabelsReader".
|
|
@@ -508,14 +505,17 @@ def get_data_config(
|
|
|
508
505
|
is set to True, then we convert the image to grayscale (single-channel)
|
|
509
506
|
image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
|
|
510
507
|
scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
|
|
511
|
-
max_height: Maximum height the image should be padded to. If not provided, the
|
|
508
|
+
max_height: Maximum height the original image should be resized and padded to. If not provided, the
|
|
512
509
|
original image size will be retained. Default: None.
|
|
513
|
-
max_width: Maximum width the image should be padded to. If not provided, the
|
|
510
|
+
max_width: Maximum width the original image should be resized and padded to. If not provided, the
|
|
514
511
|
original image size will be retained. Default: None.
|
|
515
512
|
crop_size: Crop size of each instance for centered-instance model.
|
|
516
513
|
If `None`, this would be automatically computed based on the largest instance
|
|
517
|
-
in the `sio.Labels` file. Default: None.
|
|
514
|
+
in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
|
|
518
515
|
min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
|
|
516
|
+
crop_padding: Padding in pixels to add around instance bounding box when computing
|
|
517
|
+
crop size. If `None`, padding is auto-computed based on augmentation settings.
|
|
518
|
+
Only used when `crop_size` is `None`. Default: None.
|
|
519
519
|
use_augmentations_train: True if the data augmentation should be applied to the
|
|
520
520
|
training data, else False. Default: False.
|
|
521
521
|
intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
|
|
@@ -541,6 +541,7 @@ def get_data_config(
|
|
|
541
541
|
scale=scale,
|
|
542
542
|
crop_size=crop_size,
|
|
543
543
|
min_crop_size=min_crop_size,
|
|
544
|
+
crop_padding=crop_padding,
|
|
544
545
|
)
|
|
545
546
|
augmentation_config = None
|
|
546
547
|
if use_augmentations_train:
|
|
@@ -553,6 +554,7 @@ def get_data_config(
|
|
|
553
554
|
train_labels_path=train_labels_path,
|
|
554
555
|
val_labels_path=val_labels_path,
|
|
555
556
|
validation_fraction=validation_fraction,
|
|
557
|
+
use_same_data_for_val=use_same_data_for_val,
|
|
556
558
|
test_file_path=test_file_path,
|
|
557
559
|
provider=provider,
|
|
558
560
|
user_instances_only=user_instances_only,
|
|
sleap_nn/config/trainer_config.py
CHANGED
|
@@ -84,6 +84,12 @@ class WandBConfig:
|
|
|
84
84
|
prv_runid: (str) Previous run ID if training should be resumed from a previous ckpt. *Default*: `None`.
|
|
85
85
|
group: (str) Group for wandb logging. *Default*: `None`.
|
|
86
86
|
current_run_id: (str) Run ID for the current model training. (stored once the training starts). *Default*: `None`.
|
|
87
|
+
viz_enabled: (bool) If True, log pre-rendered matplotlib images to wandb. *Default*: `True`.
|
|
88
|
+
viz_boxes: (bool) If True, log interactive keypoint boxes. *Default*: `False`.
|
|
89
|
+
viz_masks: (bool) If True, log confidence map overlay masks. *Default*: `False`.
|
|
90
|
+
viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
|
|
91
|
+
viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
|
|
92
|
+
log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
|
|
87
93
|
"""
|
|
88
94
|
|
|
89
95
|
entity: Optional[str] = None
|
|
@@ -95,6 +101,12 @@ class WandBConfig:
|
|
|
95
101
|
prv_runid: Optional[str] = None
|
|
96
102
|
group: Optional[str] = None
|
|
97
103
|
current_run_id: Optional[str] = None
|
|
104
|
+
viz_enabled: bool = True
|
|
105
|
+
viz_boxes: bool = False
|
|
106
|
+
viz_masks: bool = False
|
|
107
|
+
viz_box_size: float = 5.0
|
|
108
|
+
viz_confmap_threshold: float = 0.1
|
|
109
|
+
log_viz_table: bool = False
|
|
98
110
|
|
|
99
111
|
|
|
100
112
|
@define
|
sleap_nn/data/augmentation.py
CHANGED
|
@@ -112,10 +112,13 @@ def apply_geometric_augmentation(
|
|
|
112
112
|
instances: torch.Tensor,
|
|
113
113
|
rotation_min: Optional[float] = -15.0,
|
|
114
114
|
rotation_max: Optional[float] = 15.0,
|
|
115
|
+
rotation_p: Optional[float] = None,
|
|
115
116
|
scale_min: Optional[float] = 0.9,
|
|
116
117
|
scale_max: Optional[float] = 1.1,
|
|
118
|
+
scale_p: Optional[float] = None,
|
|
117
119
|
translate_width: Optional[float] = 0.02,
|
|
118
120
|
translate_height: Optional[float] = 0.02,
|
|
121
|
+
translate_p: Optional[float] = None,
|
|
119
122
|
affine_p: float = 0.0,
|
|
120
123
|
erase_scale_min: Optional[float] = 0.0001,
|
|
121
124
|
erase_scale_max: Optional[float] = 0.01,
|
|
@@ -133,11 +136,18 @@ def apply_geometric_augmentation(
|
|
|
133
136
|
instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
|
|
134
137
|
rotation_min: Minimum rotation angle in degrees. Default: -15.0.
|
|
135
138
|
rotation_max: Maximum rotation angle in degrees. Default: 15.0.
|
|
139
|
+
rotation_p: Probability of applying random rotation independently. If None,
|
|
140
|
+
falls back to affine_p for bundled behavior. Default: None.
|
|
136
141
|
scale_min: Minimum scaling factor for isotropic scaling. Default: 0.9.
|
|
137
142
|
scale_max: Maximum scaling factor for isotropic scaling. Default: 1.1.
|
|
143
|
+
scale_p: Probability of applying random scaling independently. If None,
|
|
144
|
+
falls back to affine_p for bundled behavior. Default: None.
|
|
138
145
|
translate_width: Maximum absolute fraction for horizontal translation. Default: 0.02.
|
|
139
146
|
translate_height: Maximum absolute fraction for vertical translation. Default: 0.02.
|
|
140
|
-
|
|
147
|
+
translate_p: Probability of applying random translation independently. If None,
|
|
148
|
+
falls back to affine_p for bundled behavior. Default: None.
|
|
149
|
+
affine_p: Probability of applying random affine transformations (rotation, scale,
|
|
150
|
+
translate bundled). Used when individual *_p params are None. Default: 0.0.
|
|
141
151
|
erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
|
|
142
152
|
erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
|
|
143
153
|
erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
|
|
@@ -151,7 +161,49 @@ def apply_geometric_augmentation(
|
|
|
151
161
|
Returns tuple: (image, instances) with augmentation applied.
|
|
152
162
|
"""
|
|
153
163
|
aug_stack = []
|
|
154
|
-
|
|
164
|
+
|
|
165
|
+
# Check if any individual probability is set
|
|
166
|
+
use_independent = (
|
|
167
|
+
rotation_p is not None or scale_p is not None or translate_p is not None
|
|
168
|
+
)
|
|
169
|
+
|
|
170
|
+
if use_independent:
|
|
171
|
+
# New behavior: Apply augmentations independently with separate probabilities
|
|
172
|
+
if rotation_p is not None and rotation_p > 0:
|
|
173
|
+
aug_stack.append(
|
|
174
|
+
K.augmentation.RandomRotation(
|
|
175
|
+
degrees=(rotation_min, rotation_max),
|
|
176
|
+
p=rotation_p,
|
|
177
|
+
keepdim=True,
|
|
178
|
+
same_on_batch=True,
|
|
179
|
+
)
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
if scale_p is not None and scale_p > 0:
|
|
183
|
+
aug_stack.append(
|
|
184
|
+
K.augmentation.RandomAffine(
|
|
185
|
+
degrees=0, # No rotation
|
|
186
|
+
translate=None, # No translation
|
|
187
|
+
scale=(scale_min, scale_max),
|
|
188
|
+
p=scale_p,
|
|
189
|
+
keepdim=True,
|
|
190
|
+
same_on_batch=True,
|
|
191
|
+
)
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
if translate_p is not None and translate_p > 0:
|
|
195
|
+
aug_stack.append(
|
|
196
|
+
K.augmentation.RandomAffine(
|
|
197
|
+
degrees=0, # No rotation
|
|
198
|
+
translate=(translate_width, translate_height),
|
|
199
|
+
scale=None, # No scaling
|
|
200
|
+
p=translate_p,
|
|
201
|
+
keepdim=True,
|
|
202
|
+
same_on_batch=True,
|
|
203
|
+
)
|
|
204
|
+
)
|
|
205
|
+
elif affine_p > 0:
|
|
206
|
+
# Legacy behavior: Bundled affine transformation
|
|
155
207
|
aug_stack.append(
|
|
156
208
|
K.augmentation.RandomAffine(
|
|
157
209
|
degrees=(rotation_min, rotation_max),
|
sleap_nn/data/custom_datasets.py
CHANGED
|
@@ -177,6 +177,9 @@ class BaseDataset(Dataset):
|
|
|
177
177
|
if self.user_instances_only:
|
|
178
178
|
if lf.user_instances is not None and len(lf.user_instances) > 0:
|
|
179
179
|
lf.instances = lf.user_instances
|
|
180
|
+
else:
|
|
181
|
+
# Skip frames without user instances
|
|
182
|
+
continue
|
|
180
183
|
is_empty = True
|
|
181
184
|
for _, inst in enumerate(lf.instances):
|
|
182
185
|
if not inst.is_empty: # filter all NaN instances.
|
|
@@ -684,15 +687,12 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
684
687
|
the images aren't cached and loaded from the `.slp` file on each access.
|
|
685
688
|
cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
|
|
686
689
|
use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
|
|
687
|
-
crop_size: Crop size of each instance for centered-instance model.
|
|
690
|
+
crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
|
|
688
691
|
rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
|
|
689
692
|
disk occurs only once across all workers.
|
|
690
693
|
confmap_head_config: DictConfig object with all the keys in the `head_config` section.
|
|
691
694
|
(required keys: `sigma`, `output_stride`, `part_names` and `anchor_part` depending on the model type ).
|
|
692
695
|
labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
|
|
693
|
-
|
|
694
|
-
Note: If scale is provided for centered-instance model, the images are cropped out
|
|
695
|
-
from the scaled image with the given crop size.
|
|
696
696
|
"""
|
|
697
697
|
|
|
698
698
|
def __init__(
|
|
@@ -748,6 +748,9 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
748
748
|
if self.user_instances_only:
|
|
749
749
|
if lf.user_instances is not None and len(lf.user_instances) > 0:
|
|
750
750
|
lf.instances = lf.user_instances
|
|
751
|
+
else:
|
|
752
|
+
# Skip frames without user instances
|
|
753
|
+
continue
|
|
751
754
|
for inst_idx, inst in enumerate(lf.instances):
|
|
752
755
|
if not inst.is_empty: # filter all NaN instances.
|
|
753
756
|
video_idx = labels[labels_idx].videos.index(lf.video)
|
|
@@ -834,13 +837,6 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
834
837
|
)
|
|
835
838
|
instances = instances * eff_scale
|
|
836
839
|
|
|
837
|
-
# resize image
|
|
838
|
-
image, instances = apply_resizer(
|
|
839
|
-
image,
|
|
840
|
-
instances,
|
|
841
|
-
scale=self.scale,
|
|
842
|
-
)
|
|
843
|
-
|
|
844
840
|
# get the centroids based on the anchor idx
|
|
845
841
|
centroids = generate_centroids(instances, anchor_ind=self.anchor_ind)
|
|
846
842
|
|
|
@@ -901,6 +897,13 @@ class CenteredInstanceDataset(BaseDataset):
|
|
|
901
897
|
sample["instance"] = center_instance # (n_samples=1, n_nodes, 2)
|
|
902
898
|
sample["centroid"] = centered_centroid # (n_samples=1, 2)
|
|
903
899
|
|
|
900
|
+
# resize the cropped image
|
|
901
|
+
sample["instance_image"], sample["instance"] = apply_resizer(
|
|
902
|
+
sample["instance_image"],
|
|
903
|
+
sample["instance"],
|
|
904
|
+
scale=self.scale,
|
|
905
|
+
)
|
|
906
|
+
|
|
904
907
|
# Pad the image (if needed) according max stride
|
|
905
908
|
sample["instance_image"] = apply_pad_to_stride(
|
|
906
909
|
sample["instance_image"], max_stride=self.max_stride
|
|
@@ -959,7 +962,7 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
959
962
|
the images aren't cached and loaded from the `.slp` file on each access.
|
|
960
963
|
cache_img_path: Path to save the `.jpg` files. If `None`, current working dir is used.
|
|
961
964
|
use_existing_imgs: Use existing imgs/ chunks in the `cache_img_path`.
|
|
962
|
-
crop_size: Crop size of each instance for centered-instance model.
|
|
965
|
+
crop_size: Crop size of each instance for centered-instance model. If `scale` is provided, then the cropped image will be resized according to `scale`.
|
|
963
966
|
rank: Indicates the rank of the process. Used during distributed training to ensure that image storage to
|
|
964
967
|
disk occurs only once across all workers.
|
|
965
968
|
confmap_head_config: DictConfig object with all the keys in the `head_config` section.
|
|
@@ -967,9 +970,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
967
970
|
class_vectors_head_config: DictConfig object with all the keys in the `head_config` section.
|
|
968
971
|
(required keys: `classes`, `num_fc_layers`, `num_fc_units`, `output_stride`, `loss_weight`).
|
|
969
972
|
labels_list: List of `sio.Labels` objects. Used to store the labels in the cache. (only used if `cache_img` is `None`)
|
|
970
|
-
|
|
971
|
-
Note: If scale is provided for centered-instance model, the images are cropped out
|
|
972
|
-
from the scaled image with the given crop size.
|
|
973
973
|
"""
|
|
974
974
|
|
|
975
975
|
def __init__(
|
|
@@ -1082,13 +1082,6 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1082
1082
|
)
|
|
1083
1083
|
instances = instances * eff_scale
|
|
1084
1084
|
|
|
1085
|
-
# resize image
|
|
1086
|
-
image, instances = apply_resizer(
|
|
1087
|
-
image,
|
|
1088
|
-
instances,
|
|
1089
|
-
scale=self.scale,
|
|
1090
|
-
)
|
|
1091
|
-
|
|
1092
1085
|
# get class vectors
|
|
1093
1086
|
track_ids = torch.Tensor(
|
|
1094
1087
|
[
|
|
@@ -1165,6 +1158,13 @@ class TopDownCenteredInstanceMultiClassDataset(CenteredInstanceDataset):
|
|
|
1165
1158
|
sample["instance"] = center_instance # (n_samples=1, n_nodes, 2)
|
|
1166
1159
|
sample["centroid"] = centered_centroid # (n_samples=1, 2)
|
|
1167
1160
|
|
|
1161
|
+
# resize image
|
|
1162
|
+
sample["instance_image"], sample["instance"] = apply_resizer(
|
|
1163
|
+
sample["instance_image"],
|
|
1164
|
+
sample["instance"],
|
|
1165
|
+
scale=self.scale,
|
|
1166
|
+
)
|
|
1167
|
+
|
|
1168
1168
|
# Pad the image (if needed) according max stride
|
|
1169
1169
|
sample["instance_image"] = apply_pad_to_stride(
|
|
1170
1170
|
sample["instance_image"], max_stride=self.max_stride
|