sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/train.py
CHANGED
|
@@ -6,6 +6,7 @@ from datetime import datetime
|
|
|
6
6
|
from time import time
|
|
7
7
|
from omegaconf import DictConfig, OmegaConf
|
|
8
8
|
from typing import Any, Dict, Optional, List, Tuple, Union
|
|
9
|
+
import sleap_io as sio
|
|
9
10
|
from sleap_nn.config.training_job_config import TrainingJobConfig
|
|
10
11
|
from sleap_nn.training.model_trainer import ModelTrainer
|
|
11
12
|
from sleap_nn.predict import run_inference as predict
|
|
@@ -15,15 +16,31 @@ from sleap_nn.config.get_config import (
|
|
|
15
16
|
get_model_config,
|
|
16
17
|
get_data_config,
|
|
17
18
|
)
|
|
19
|
+
from sleap_nn.system_info import get_startup_info_string
|
|
18
20
|
|
|
19
21
|
|
|
20
|
-
def run_training(
|
|
21
|
-
|
|
22
|
+
def run_training(
|
|
23
|
+
config: DictConfig,
|
|
24
|
+
train_labels: Optional[List[sio.Labels]] = None,
|
|
25
|
+
val_labels: Optional[List[sio.Labels]] = None,
|
|
26
|
+
):
|
|
27
|
+
"""Create ModelTrainer instance and start training.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
config: Training configuration as a DictConfig.
|
|
31
|
+
train_labels: List of Labels objects for training.
|
|
32
|
+
val_labels: List of Labels objects for validation.
|
|
33
|
+
If not provided, the labels will be loaded from paths in the config.
|
|
34
|
+
"""
|
|
22
35
|
start_train_time = time()
|
|
23
36
|
start_timestamp = str(datetime.now())
|
|
24
37
|
logger.info(f"Started training at: {start_timestamp}")
|
|
38
|
+
logger.info(get_startup_info_string())
|
|
25
39
|
|
|
26
|
-
|
|
40
|
+
# pass through any pre-loaded train/val labels; if they are None, the labels are loaded from the paths in the config
|
|
41
|
+
trainer = ModelTrainer.get_model_trainer_from_config(
|
|
42
|
+
config, train_labels=train_labels, val_labels=val_labels
|
|
43
|
+
)
|
|
27
44
|
trainer.train()
|
|
28
45
|
|
|
29
46
|
finish_timestamp = str(datetime.now())
|
|
@@ -39,48 +56,44 @@ def run_training(config: DictConfig):
|
|
|
39
56
|
# run inference on val dataset
|
|
40
57
|
if trainer.config.trainer_config.save_ckpt:
|
|
41
58
|
data_paths = {}
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
/ f"
|
|
59
|
+
run_path = (
|
|
60
|
+
Path(trainer.config.trainer_config.ckpt_dir)
|
|
61
|
+
/ trainer.config.trainer_config.run_name
|
|
62
|
+
)
|
|
63
|
+
for index, _ in enumerate(trainer.train_labels):
|
|
64
|
+
logger.info(f"Run path for index {index}: {run_path.as_posix()}")
|
|
65
|
+
data_paths[f"train.{index}"] = (
|
|
66
|
+
run_path / f"labels_gt.train.{index}.slp"
|
|
50
67
|
).as_posix()
|
|
51
|
-
data_paths[f"
|
|
52
|
-
|
|
53
|
-
/ trainer.config.trainer_config.run_name
|
|
54
|
-
/ f"labels_val_gt_{index}.slp"
|
|
68
|
+
data_paths[f"val.{index}"] = (
|
|
69
|
+
run_path / f"labels_gt.val.{index}.slp"
|
|
55
70
|
).as_posix()
|
|
56
71
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
)
|
|
61
|
-
|
|
72
|
+
# Handle test_file_path as either a string or list of strings
|
|
73
|
+
test_file_path = OmegaConf.select(
|
|
74
|
+
config, "data_config.test_file_path", default=None
|
|
75
|
+
)
|
|
76
|
+
if test_file_path is not None:
|
|
77
|
+
# Normalize to list of strings
|
|
78
|
+
if isinstance(test_file_path, str):
|
|
79
|
+
test_paths = [test_file_path]
|
|
80
|
+
else:
|
|
81
|
+
test_paths = list(test_file_path)
|
|
82
|
+
# Add each test path to data_paths (always use index for consistency)
|
|
83
|
+
for idx, test_path in enumerate(test_paths):
|
|
84
|
+
data_paths[f"test.{idx}"] = test_path
|
|
62
85
|
|
|
63
86
|
for d_name, path in data_paths.items():
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
/ f"pred_{d_name}.slp"
|
|
68
|
-
)
|
|
69
|
-
metrics_path = (
|
|
70
|
-
Path(trainer.config.trainer_config.ckpt_dir)
|
|
71
|
-
/ trainer.config.trainer_config.run_name
|
|
72
|
-
/ f"{d_name}_pred_metrics.npz"
|
|
73
|
-
)
|
|
87
|
+
# d_name is now in format: "train.0", "val.0", "test.0", etc.
|
|
88
|
+
pred_path = run_path / f"labels_pr.{d_name}.slp"
|
|
89
|
+
metrics_path = run_path / f"metrics.{d_name}.npz"
|
|
74
90
|
|
|
75
91
|
pred_labels = predict(
|
|
76
92
|
data_path=path,
|
|
77
|
-
model_paths=[
|
|
78
|
-
Path(trainer.config.trainer_config.ckpt_dir)
|
|
79
|
-
/ trainer.config.trainer_config.run_name
|
|
80
|
-
],
|
|
93
|
+
model_paths=[run_path],
|
|
81
94
|
peak_threshold=0.2,
|
|
82
95
|
make_labels=True,
|
|
83
|
-
device=trainer.trainer.strategy.root_device,
|
|
96
|
+
device=str(trainer.trainer.strategy.root_device),
|
|
84
97
|
output_path=pred_path,
|
|
85
98
|
ensure_rgb=config.data_config.preprocessing.ensure_rgb,
|
|
86
99
|
ensure_grayscale=config.data_config.preprocessing.ensure_grayscale,
|
|
@@ -105,12 +118,77 @@ def run_training(config: DictConfig):
|
|
|
105
118
|
logger.info(f"p90 dist: {metrics['distance_metrics']['p90']}")
|
|
106
119
|
logger.info(f"p50 dist: {metrics['distance_metrics']['p50']}")
|
|
107
120
|
|
|
121
|
+
# Log test metrics to wandb summary
|
|
122
|
+
if (
|
|
123
|
+
d_name.startswith("test")
|
|
124
|
+
and trainer.config.trainer_config.use_wandb
|
|
125
|
+
):
|
|
126
|
+
import wandb
|
|
127
|
+
|
|
128
|
+
if wandb.run is not None:
|
|
129
|
+
summary_metrics = {
|
|
130
|
+
f"eval/{d_name}/mOKS": metrics["mOKS"]["mOKS"],
|
|
131
|
+
f"eval/{d_name}/oks_voc_mAP": metrics["voc_metrics"][
|
|
132
|
+
"oks_voc.mAP"
|
|
133
|
+
],
|
|
134
|
+
f"eval/{d_name}/oks_voc_mAR": metrics["voc_metrics"][
|
|
135
|
+
"oks_voc.mAR"
|
|
136
|
+
],
|
|
137
|
+
f"eval/{d_name}/mPCK": metrics["pck_metrics"]["mPCK"],
|
|
138
|
+
f"eval/{d_name}/PCK_5": metrics["pck_metrics"]["PCK@5"],
|
|
139
|
+
f"eval/{d_name}/PCK_10": metrics["pck_metrics"]["PCK@10"],
|
|
140
|
+
f"eval/{d_name}/distance_avg": metrics["distance_metrics"][
|
|
141
|
+
"avg"
|
|
142
|
+
],
|
|
143
|
+
f"eval/{d_name}/distance_p50": metrics["distance_metrics"][
|
|
144
|
+
"p50"
|
|
145
|
+
],
|
|
146
|
+
f"eval/{d_name}/distance_p95": metrics["distance_metrics"][
|
|
147
|
+
"p95"
|
|
148
|
+
],
|
|
149
|
+
f"eval/{d_name}/distance_p99": metrics["distance_metrics"][
|
|
150
|
+
"p99"
|
|
151
|
+
],
|
|
152
|
+
f"eval/{d_name}/visibility_precision": metrics[
|
|
153
|
+
"visibility_metrics"
|
|
154
|
+
]["precision"],
|
|
155
|
+
f"eval/{d_name}/visibility_recall": metrics[
|
|
156
|
+
"visibility_metrics"
|
|
157
|
+
]["recall"],
|
|
158
|
+
}
|
|
159
|
+
for key, value in summary_metrics.items():
|
|
160
|
+
wandb.run.summary[key] = value
|
|
161
|
+
|
|
162
|
+
# Finish wandb run and cleanup after all evaluation is complete
|
|
163
|
+
if trainer.config.trainer_config.use_wandb:
|
|
164
|
+
import wandb
|
|
165
|
+
import shutil
|
|
166
|
+
|
|
167
|
+
if wandb.run is not None:
|
|
168
|
+
wandb.finish()
|
|
169
|
+
|
|
170
|
+
# Delete local wandb logs if configured
|
|
171
|
+
wandb_config = trainer.config.trainer_config.wandb
|
|
172
|
+
should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
|
|
173
|
+
wandb_config.delete_local_logs is None
|
|
174
|
+
and wandb_config.wandb_mode != "offline"
|
|
175
|
+
)
|
|
176
|
+
if should_delete_wandb_logs:
|
|
177
|
+
wandb_dir = run_path / "wandb"
|
|
178
|
+
if wandb_dir.exists():
|
|
179
|
+
logger.info(
|
|
180
|
+
f"Deleting local wandb logs at {wandb_dir}... "
|
|
181
|
+
"(set trainer_config.wandb.delete_local_logs=false to disable)"
|
|
182
|
+
)
|
|
183
|
+
shutil.rmtree(wandb_dir, ignore_errors=True)
|
|
184
|
+
|
|
108
185
|
|
|
109
186
|
def train(
|
|
110
187
|
train_labels_path: Optional[List[str]] = None,
|
|
111
188
|
val_labels_path: Optional[List[str]] = None,
|
|
112
189
|
validation_fraction: float = 0.1,
|
|
113
|
-
|
|
190
|
+
use_same_data_for_val: bool = False,
|
|
191
|
+
test_file_path: Optional[Union[str, List[str]]] = None,
|
|
114
192
|
provider: str = "LabelsReader",
|
|
115
193
|
user_instances_only: bool = True,
|
|
116
194
|
data_pipeline_fw: str = "torch_dataset",
|
|
@@ -124,9 +202,10 @@ def train(
|
|
|
124
202
|
max_width: Optional[int] = None,
|
|
125
203
|
crop_size: Optional[int] = None,
|
|
126
204
|
min_crop_size: Optional[int] = 100,
|
|
127
|
-
|
|
205
|
+
crop_padding: Optional[int] = None,
|
|
206
|
+
use_augmentations_train: bool = True,
|
|
128
207
|
intensity_aug: Optional[Union[str, List[str], Dict[str, Any]]] = None,
|
|
129
|
-
geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] =
|
|
208
|
+
geometry_aug: Optional[Union[str, List[str], Dict[str, Any]]] = "rotation",
|
|
130
209
|
init_weight: str = "default",
|
|
131
210
|
pretrained_backbone_weights: Optional[str] = None,
|
|
132
211
|
pretrained_head_weights: Optional[str] = None,
|
|
@@ -160,6 +239,7 @@ def train(
|
|
|
160
239
|
wandb_save_viz_imgs_wandb: bool = False,
|
|
161
240
|
wandb_resume_prv_runid: Optional[str] = None,
|
|
162
241
|
wandb_group_name: Optional[str] = None,
|
|
242
|
+
wandb_delete_local_logs: Optional[bool] = None,
|
|
163
243
|
optimizer: str = "Adam",
|
|
164
244
|
learning_rate: float = 1e-3,
|
|
165
245
|
amsgrad: bool = False,
|
|
@@ -188,7 +268,11 @@ def train(
|
|
|
188
268
|
training set to sample for generating the validation set. The remaining
|
|
189
269
|
labeled frames will be left in the training set. If the `validation_labels`
|
|
190
270
|
are already specified, this has no effect. Default: 0.1.
|
|
191
|
-
|
|
271
|
+
use_same_data_for_val: If `True`, use the same data for both training and
|
|
272
|
+
validation (train = val). Useful for intentional overfitting on small
|
|
273
|
+
datasets. When enabled, `val_labels_path` and `validation_fraction` are
|
|
274
|
+
ignored. Default: False.
|
|
275
|
+
test_file_path: Path or list of paths to test dataset(s) (`.slp` file(s) or `.mp4` file(s)).
|
|
192
276
|
Note: This is used to run evaluation on the test set after training is completed.
|
|
193
277
|
provider: Provider class to read the input sleap files. Only "LabelsReader"
|
|
194
278
|
supported for the training pipeline. Default: "LabelsReader".
|
|
@@ -210,16 +294,19 @@ def train(
|
|
|
210
294
|
is set to True, then we convert the image to grayscale (single-channel)
|
|
211
295
|
image. If the source image has only one channel and this is set to False, then we retain the single channel input. Default: `False`.
|
|
212
296
|
scale: Factor to resize the image dimensions by, specified as a float. Default: 1.0.
|
|
213
|
-
max_height: Maximum height the image should be padded to. If not provided, the
|
|
297
|
+
max_height: Maximum height the original image should be resized and padded to. If not provided, the
|
|
214
298
|
original image size will be retained. Default: None.
|
|
215
|
-
max_width: Maximum width the image should be padded to. If not provided, the
|
|
299
|
+
max_width: Maximum width the original image should be resized and padded to. If not provided, the
|
|
216
300
|
original image size will be retained. Default: None.
|
|
217
301
|
crop_size: Crop size of each instance for centered-instance model.
|
|
218
302
|
If `None`, this would be automatically computed based on the largest instance
|
|
219
|
-
in the `sio.Labels` file. Default: None.
|
|
303
|
+
in the `sio.Labels` file. If `scale` is provided, then the cropped image will be resized according to `scale`. Default: None.
|
|
220
304
|
min_crop_size: Minimum crop size to be used if `crop_size` is `None`. Default: 100.
|
|
305
|
+
crop_padding: Padding in pixels to add around instance bounding box when computing
|
|
306
|
+
crop size. If `None`, padding is auto-computed based on augmentation settings.
|
|
307
|
+
Only used when `crop_size` is `None`. Default: None.
|
|
221
308
|
use_augmentations_train: True if the data augmentation should be applied to the
|
|
222
|
-
training data, else False. Default:
|
|
309
|
+
training data, else False. Default: True.
|
|
223
310
|
intensity_aug: One of ["uniform_noise", "gaussian_noise", "contrast", "brightness"]
|
|
224
311
|
or list of strings from the above allowed values. To have custom values, pass
|
|
225
312
|
a dict with the structure in `sleap_nn.config.data_config.IntensityConfig`.
|
|
@@ -231,7 +318,8 @@ def train(
|
|
|
231
318
|
or list of strings from the above allowed values. To have custom values, pass
|
|
232
319
|
a dict with the structure in `sleap_nn.config.data_config.GeometryConfig`.
|
|
233
320
|
For example: {
|
|
234
|
-
"
|
|
321
|
+
"rotation_min": -45,
|
|
322
|
+
"rotation_max": 45,
|
|
235
323
|
"affine_p": 1.0
|
|
236
324
|
}
|
|
237
325
|
init_weight: model weights initialization method. "default" uses kaiming uniform
|
|
@@ -331,6 +419,9 @@ def train(
|
|
|
331
419
|
wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
|
|
332
420
|
ckpt. Default: None
|
|
333
421
|
wandb_group_name: Group name for the wandb run. Default: None.
|
|
422
|
+
wandb_delete_local_logs: If True, delete local wandb logs folder after training.
|
|
423
|
+
If False, keep the folder. If None (default), automatically delete if logging
|
|
424
|
+
online (wandb_mode != "offline") and keep if logging offline. Default: None.
|
|
334
425
|
optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
|
|
335
426
|
learning_rate: Learning rate of type float. Default: 1e-3.
|
|
336
427
|
amsgrad: Enable AMSGrad with the optimizer. Default: False.
|
|
@@ -376,6 +467,7 @@ def train(
|
|
|
376
467
|
train_labels_path=train_labels_path,
|
|
377
468
|
val_labels_path=val_labels_path,
|
|
378
469
|
validation_fraction=validation_fraction,
|
|
470
|
+
use_same_data_for_val=use_same_data_for_val,
|
|
379
471
|
test_file_path=test_file_path,
|
|
380
472
|
provider=provider,
|
|
381
473
|
user_instances_only=user_instances_only,
|
|
@@ -390,6 +482,7 @@ def train(
|
|
|
390
482
|
max_width=max_width,
|
|
391
483
|
crop_size=crop_size,
|
|
392
484
|
min_crop_size=min_crop_size,
|
|
485
|
+
crop_padding=crop_padding,
|
|
393
486
|
use_augmentations_train=use_augmentations_train,
|
|
394
487
|
intensity_aug=intensity_aug,
|
|
395
488
|
geometry_aug=geometry_aug,
|
|
@@ -432,6 +525,7 @@ def train(
|
|
|
432
525
|
wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
|
|
433
526
|
wandb_resume_prv_runid=wandb_resume_prv_runid,
|
|
434
527
|
wandb_group_name=wandb_group_name,
|
|
528
|
+
wandb_delete_local_logs=wandb_delete_local_logs,
|
|
435
529
|
optimizer=optimizer,
|
|
436
530
|
learning_rate=learning_rate,
|
|
437
531
|
amsgrad=amsgrad,
|