sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/training/model_trainer.py

@@ -2,7 +2,6 @@
 
 import os
 import shutil
-import copy
 import attrs
 import torch
 import random
@@ -16,11 +15,14 @@ import yaml
 from pathlib import Path
 from typing import List, Optional
 from datetime import datetime
-from itertools import
+from itertools import count
 from omegaconf import DictConfig, OmegaConf
 from lightning.pytorch.loggers import WandbLogger
 from sleap_nn.data.utils import check_cache_memory
-from lightning.pytorch.callbacks import
+from lightning.pytorch.callbacks import (
+    ModelCheckpoint,
+    EarlyStopping,
+)
 from lightning.pytorch.profilers import (
     SimpleProfiler,
     AdvancedProfiler,
@@ -28,7 +30,11 @@ from lightning.pytorch.profilers import (
     PassThroughProfiler,
 )
 from sleap_io.io.skeleton import SkeletonYAMLEncoder
-from sleap_nn.data.instance_cropping import
+from sleap_nn.data.instance_cropping import (
+    find_instance_crop_size,
+    find_max_instance_bbox_size,
+    compute_augmentation_padding,
+)
 from sleap_nn.data.providers import get_max_height_width
 from sleap_nn.data.custom_datasets import (
     get_train_val_dataloaders,
@@ -47,9 +53,11 @@ from sleap_nn.config.training_job_config import verify_training_cfg
 from sleap_nn.training.callbacks import (
     ProgressReporterZMQ,
     TrainingControllerZMQ,
-    MatplotlibSaver,
-    WandBPredImageLogger,
     CSVLoggerCallback,
+    SleapProgressBar,
+    EpochEndEvaluationCallback,
+    CentroidEvaluationCallback,
+    UnifiedVizCallback,
 )
 from sleap_nn import RANK
 from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -207,6 +215,52 @@ class ModelTrainer:
             trainer_devices = 1
         return trainer_devices
 
+    def _count_labeled_frames(
+        self, labels_list: List[sio.Labels], user_only: bool = True
+    ) -> int:
+        """Count labeled frames, optionally filtering to user-labeled only.
+
+        Args:
+            labels_list: List of Labels objects to count frames from.
+            user_only: If True, count only frames with user instances.
+
+        Returns:
+            Total count of labeled frames.
+        """
+        total = 0
+        for label in labels_list:
+            if user_only:
+                total += sum(1 for lf in label if lf.has_user_instances)
+            else:
+                total += len(label)
+        return total
+
+    def _filter_to_user_labeled(self, labels: sio.Labels) -> sio.Labels:
+        """Filter a Labels object to only include user-labeled frames.
+
+        Args:
+            labels: Labels object to filter.
+
+        Returns:
+            New Labels object containing only frames with user instances.
+        """
+        # Filter labeled frames to only those with user instances
+        user_lfs = [lf for lf in labels if lf.has_user_instances]
+
+        # Set instances to user instances only
+        for lf in user_lfs:
+            lf.instances = lf.user_instances
+
+        # Create new Labels with filtered frames
+        return sio.Labels(
+            labeled_frames=user_lfs,
+            videos=labels.videos,
+            skeletons=labels.skeletons,
+            tracks=labels.tracks,
+            suggestions=labels.suggestions,
+            provenance=labels.provenance,
+        )
+
     def _setup_train_val_labels(
         self,
         labels: Optional[List[sio.Labels]] = None,
@@ -218,21 +272,35 @@ class ModelTrainer:
         total_val_lfs = 0
         self.skeletons = labels[0].skeletons
 
+        # Check if we should count only user-labeled frames
+        user_instances_only = OmegaConf.select(
+            self.config, "data_config.user_instances_only", default=True
+        )
+
         # check if all `.slp` file shave same skeleton structure (if multiple slp file paths are provided)
         skeleton = self.skeletons[0]
         for index, train_label in enumerate(labels):
             skel_temp = train_label.skeletons[0]
             skeletons_equal = skeleton.matches(skel_temp)
-            if skeletons_equal:
-                total_train_lfs += len(train_label)
-            else:
+            if not skeletons_equal:
                 message = f"The skeletons in the training labels: {index + 1} do not match the skeleton in the first training label file."
                 logger.error(message)
                 raise ValueError(message)
 
-
+        # Check for same-data mode (train = val, for intentional overfitting)
+        use_same = OmegaConf.select(
+            self.config, "data_config.use_same_data_for_val", default=False
+        )
+
+        if use_same:
+            # Same mode: use identical data for train and val (for overfitting)
+            logger.info("Using same data for train and val (overfit mode)")
+            self.train_labels = labels
+            self.val_labels = labels
+            total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
+            total_val_lfs = total_train_lfs
+        elif val_labels is None or not len(val_labels):
             # if val labels are not provided, split from train
-            total_train_lfs = 0
             val_fraction = OmegaConf.select(
                 self.config, "data_config.validation_fraction", default=0.1
             )
@@ -250,13 +318,14 @@ class ModelTrainer:
                 )
                 self.train_labels.append(train_split)
                 self.val_labels.append(val_split)
+                # make_training_splits returns only user-labeled frames
                 total_train_lfs += len(train_split)
                 total_val_lfs += len(val_split)
         else:
             self.train_labels = labels
             self.val_labels = val_labels
-
-
+            total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
+            total_val_lfs = self._count_labeled_frames(val_labels, user_instances_only)
 
         logger.info(f"# Train Labeled frames: {total_train_lfs}")
         logger.info(f"# Val Labeled frames: {total_val_lfs}")
@@ -291,13 +360,70 @@ class ModelTrainer:
         ):
             # compute crop size if not provided in config
             if crop_size is None:
+                # Get padding from config or auto-compute from augmentation settings
+                padding = self.config.data_config.preprocessing.crop_padding
+                if padding is None:
+                    # Auto-compute padding based on augmentation settings
+                    aug_config = self.config.data_config.augmentation_config
+                    if (
+                        self.config.data_config.use_augmentations_train
+                        and aug_config is not None
+                        and aug_config.geometric is not None
+                    ):
+                        geo = aug_config.geometric
+                        # Check if rotation is enabled (via rotation_p or affine_p)
+                        rotation_enabled = (
+                            geo.rotation_p is not None and geo.rotation_p > 0
+                        ) or (
+                            geo.rotation_p is None
+                            and geo.scale_p is None
+                            and geo.translate_p is None
+                            and geo.affine_p > 0
+                        )
+                        # Check if scale is enabled (via scale_p or affine_p)
+                        scale_enabled = (
+                            geo.scale_p is not None and geo.scale_p > 0
+                        ) or (
+                            geo.rotation_p is None
+                            and geo.scale_p is None
+                            and geo.translate_p is None
+                            and geo.affine_p > 0
+                        )
+
+                        if rotation_enabled or scale_enabled:
+                            # First find the actual max bbox size from labels
+                            bbox_size = find_max_instance_bbox_size(train_label)
+                            bbox_size = max(
+                                bbox_size,
+                                self.config.data_config.preprocessing.min_crop_size
+                                or 100,
+                            )
+                            rotation_max = (
+                                max(
+                                    abs(geo.rotation_min),
+                                    abs(geo.rotation_max),
+                                )
+                                if rotation_enabled
+                                else 0.0
+                            )
+                            scale_max = geo.scale_max if scale_enabled else 1.0
+                            padding = compute_augmentation_padding(
+                                bbox_size=bbox_size,
+                                rotation_max=rotation_max,
+                                scale_max=scale_max,
+                            )
+                        else:
+                            padding = 0
+                    else:
+                        padding = 0
+
                 crop_sz = find_instance_crop_size(
                     labels=train_label,
+                    padding=padding,
                     maximum_stride=self.config.model_config.backbone_config[
                         f"{self.backbone_type}"
                     ]["max_stride"],
                     min_crop_size=self.config.data_config.preprocessing.min_crop_size,
-                    input_scaling=self.config.data_config.preprocessing.scale,
                 )
 
                 if crop_sz > max_crop_size:
@@ -361,16 +487,36 @@ class ModelTrainer:
             ckpt_dir = "."
         self.config.trainer_config.ckpt_dir = ckpt_dir
         run_name = self.config.trainer_config.run_name
-
+        run_name_is_empty = run_name is None or run_name == "" or run_name == "None"
+
+        # Validate: multi-GPU + disk cache requires explicit run_name
+        if run_name_is_empty:
+            is_disk_caching = (
+                self.config.data_config.data_pipeline_fw
+                == "torch_dataset_cache_img_disk"
+            )
+            num_devices = self._get_trainer_devices()
+
+            if is_disk_caching and num_devices > 1:
+                raise ValueError(
+                    f"Multi-GPU training with disk caching requires an explicit `run_name`.\n\n"
+                    f"Detected {num_devices} device(s) with "
+                    f"`data_pipeline_fw='torch_dataset_cache_img_disk'`.\n"
+                    f"Without an explicit run_name, each GPU worker generates a different "
+                    f"timestamp-based directory, causing cache synchronization failures.\n\n"
+                    f"Please provide a run_name using one of these methods:\n"
+                    f"  - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
+                    f"  - Config file: Set `trainer_config.run_name: my_experiment`\n"
+                    f"  - Python API: train(..., run_name='my_experiment')"
+                )
+
+            # Auto-generate timestamp-based run_name (safe for single GPU or non-disk-cache)
             sum_train_lfs = sum([len(train_label) for train_label in self.train_labels])
             sum_val_lfs = sum([len(val_label) for val_label in self.val_labels])
-
-
-
-
-                datetime.now().strftime("%y%m%d_%H%M%S")
-                + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
-            )
+            run_name = (
+                datetime.now().strftime("%y%m%d_%H%M%S")
+                + f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
+            )
 
         # If checkpoint path already exists, add suffix to prevent overwriting
         if (Path(ckpt_dir) / run_name).exists() and (
@@ -509,6 +655,10 @@ class ModelTrainer:
         """Compute config parameters."""
         logger.info("Setting up config...")
 
+        # Normalize empty strings to None for optional wandb fields
+        if self.config.trainer_config.wandb.prv_runid == "":
+            self.config.trainer_config.wandb.prv_runid = None
+
         # compute preprocessing parameters from the labels objects and fill in the config
         self._setup_preprocessing_config()
 
@@ -558,39 +708,89 @@ class ModelTrainer:
             )
         )
 
-        # setup checkpoint path
+        # setup checkpoint path (generates run_name if not specified)
         self._setup_ckpt_path()
 
+        # Default wandb run name to trainer run_name if not specified
+        # Note: This must come after _setup_ckpt_path() which generates run_name
+        if self.config.trainer_config.wandb.name is None:
+            self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
+
         # verify input_channels in model_config based on input image and pretrained model weights
         self._verify_model_input_channels()
 
     def _setup_model_ckpt_dir(self):
-        """Create the model ckpt folder."""
+        """Create the model ckpt folder and save ground truth labels."""
        ckpt_path = (
             Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
         ).as_posix()
         logger.info(f"Setting up model ckpt dir: `{ckpt_path}`...")
 
-
-        try:
-            Path(ckpt_path).mkdir(parents=True, exist_ok=True)
-        except OSError as e:
-            message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
-            logger.error(message)
-            raise OSError(message)
-
+        # Only rank 0 (or non-distributed) should create directories and save files
         if RANK in [0, -1]:
+            if not Path(ckpt_path).exists():
+                try:
+                    Path(ckpt_path).mkdir(parents=True, exist_ok=True)
+                except OSError as e:
+                    message = f"Cannot create a new folder in {ckpt_path}.\n {e}"
+                    logger.error(message)
+                    raise OSError(message)
+            # Check if we should filter to user-labeled frames only
+            user_instances_only = OmegaConf.select(
+                self.config, "data_config.user_instances_only", default=True
+            )
+
+            # Save train and val ground truth labels
             for idx, (train, val) in enumerate(zip(self.train_labels, self.val_labels)):
-
-
+                # Filter to user-labeled frames if needed (for evaluation)
+                if user_instances_only:
+                    train_filtered = self._filter_to_user_labeled(train)
+                    val_filtered = self._filter_to_user_labeled(val)
+                else:
+                    train_filtered = train
+                    val_filtered = val
+
+                train_filtered.save(
+                    Path(ckpt_path) / f"labels_gt.train.{idx}.slp",
                     restore_original_videos=False,
                 )
-
-                Path(ckpt_path) / f"
+                val_filtered.save(
+                    Path(ckpt_path) / f"labels_gt.val.{idx}.slp",
                     restore_original_videos=False,
                 )
 
+            # Save test ground truth labels if test paths are provided
+            test_file_path = OmegaConf.select(
+                self.config, "data_config.test_file_path", default=None
+            )
+            if test_file_path is not None:
+                # Normalize to list of strings
+                if isinstance(test_file_path, str):
+                    test_paths = [test_file_path]
+                else:
+                    test_paths = list(test_file_path)
+
+                for idx, test_path in enumerate(test_paths):
+                    # Only save if it's a .slp file (not a video file)
+                    if test_path.endswith(".slp") or test_path.endswith(".pkg.slp"):
+                        try:
+                            test_labels = sio.load_slp(test_path)
+                            if user_instances_only:
+                                test_filtered = self._filter_to_user_labeled(
+                                    test_labels
+                                )
+                            else:
+                                test_filtered = test_labels
+                            test_filtered.save(
+                                Path(ckpt_path) / f"labels_gt.test.{idx}.slp",
+                                restore_original_videos=False,
+                            )
+                        except Exception as e:
+                            logger.warning(
+                                f"Could not save test ground truth for {test_path}: {e}"
+                            )
+
     def _setup_viz_datasets(self):
         """Setup dataloaders."""
         data_viz_config = self.config.copy()
@@ -608,10 +808,40 @@ class ModelTrainer:
         base_cache_img_path = None
         if self.config.data_config.data_pipeline_fw == "torch_dataset_cache_img_memory":
             # check available memory. If insufficient memory, default to disk caching.
+            # Account for DataLoader worker memory overhead
+            train_num_workers = self.config.trainer_config.train_data_loader.num_workers
+            val_num_workers = self.config.trainer_config.val_data_loader.num_workers
+            max_num_workers = max(train_num_workers, val_num_workers)
+
             mem_available = check_cache_memory(
-                self.train_labels,
+                self.train_labels,
+                self.val_labels,
+                memory_buffer=MEMORY_BUFFER,
+                num_workers=max_num_workers,
             )
             if not mem_available:
+                # Validate: multi-GPU + auto-generated run_name + fallback to disk cache
+                original_run_name = self._initial_config.trainer_config.run_name
+                run_name_was_auto = (
+                    original_run_name is None
+                    or original_run_name == ""
+                    or original_run_name == "None"
+                )
+                if run_name_was_auto and self.trainer.num_devices > 1:
+                    raise ValueError(
+                        f"Memory caching failed and disk caching fallback requires an "
+                        f"explicit `run_name` for multi-GPU training.\n\n"
+                        f"Detected {self.trainer.num_devices} device(s) with insufficient "
+                        f"memory for in-memory caching.\n"
+                        f"Without an explicit run_name, each GPU worker generates a different "
+                        f"timestamp-based directory, causing cache synchronization failures.\n\n"
+                        f"Please provide a run_name using one of these methods:\n"
+                        f"  - CLI: sleap-nn train config.yaml trainer_config.run_name=my_experiment\n"
+                        f"  - Config file: Set `trainer_config.run_name: my_experiment`\n"
+                        f"  - Python API: train(..., run_name='my_experiment')\n\n"
+                        f"Alternatively, use `data_pipeline_fw='torch_dataset'` to disable caching."
+                    )
+
                 self.config.data_config.data_pipeline_fw = (
                     "torch_dataset_cache_img_disk"
                 )
@@ -655,7 +885,7 @@ class ModelTrainer:
                 / self.config.trainer_config.run_name
             ).as_posix(),
             filename="best",
-            monitor="
+            monitor="val/loss",
             mode="min",
         )
         callbacks.append(checkpoint_callback)
@@ -663,18 +893,52 @@ class ModelTrainer:
         # csv log callback
         csv_log_keys = [
             "epoch",
-            "
-            "
+            "train/loss",
+            "val/loss",
             "learning_rate",
-            "
-            "
+            "train/time",
+            "val/time",
         ]
+        # Add model-specific keys for wandb parity
         if self.model_type in [
             "single_instance",
             "centered_instance",
             "multi_class_topdown",
         ]:
-            csv_log_keys.extend(
+            csv_log_keys.extend(
+                [f"train/confmaps/{name}" for name in self.skeletons[0].node_names]
+            )
+        if self.model_type == "bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/paf_loss",
+                    "val/confmaps_loss",
+                    "val/paf_loss",
+                ]
+            )
+        if self.model_type == "multi_class_bottomup":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classmap_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classmap_loss",
+                    "val/class_accuracy",
+                ]
+            )
+        if self.model_type == "multi_class_topdown":
+            csv_log_keys.extend(
+                [
+                    "train/confmaps_loss",
+                    "train/classvector_loss",
+                    "train/class_accuracy",
+                    "val/confmaps_loss",
+                    "val/classvector_loss",
+                    "val/class_accuracy",
+                ]
+            )
         csv_logger = CSVLoggerCallback(
             filepath=Path(self.config.trainer_config.ckpt_dir)
             / self.config.trainer_config.run_name
@@ -687,7 +951,7 @@ class ModelTrainer:
         # early stopping callback
         callbacks.append(
             EarlyStopping(
-                monitor="
+                monitor="val/loss",
                 mode="min",
                 verbose=False,
                 min_delta=self.config.trainer_config.early_stopping.min_delta,
@@ -716,6 +980,17 @@ class ModelTrainer:
             )
             loggers.append(wandb_logger)
 
+            # Log message about wandb local logs cleanup
+            should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+                wandb_config.delete_local_logs is None
+                and wandb_config.wandb_mode != "offline"
+            )
+            if should_delete_wandb_logs:
+                logger.info(
+                    "WandB local logs will be deleted after training completes. "
+                    "To keep logs, set trainer_config.wandb.delete_local_logs=false"
+                )
+
         # save the configs as yaml in the checkpoint dir
         # Mask API key in both configs to prevent saving to disk
         self.config.trainer_config.wandb.api_key = ""
@@ -734,11 +1009,8 @@ class ModelTrainer:
             )
             callbacks.append(ProgressReporterZMQ(address=publish_address))
 
-        # viz callbacks
+        # viz callbacks - use unified callback for all visualization outputs
         if self.config.trainer_config.visualize_preds_during_training:
-            train_viz_pipeline = cycle(viz_train_dataset)
-            val_viz_pipeline = cycle(viz_val_dataset)
-
             viz_dir = (
                 Path(self.config.trainer_config.ckpt_dir)
                 / self.config.trainer_config.run_name
@@ -748,77 +1020,74 @@ class ModelTrainer:
             if RANK in [0, -1]:
                 Path(viz_dir).mkdir(parents=True, exist_ok=True)
 
-
-
-
-                    plot_fn=lambda: self.lightning_model.visualize_example(
-                        next(train_viz_pipeline)
-                    ),
-                    prefix="train",
-                )
+            # Get wandb viz config options
+            log_wandb = self.config.trainer_config.use_wandb and OmegaConf.select(
+                self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
            )
+            wandb_modes = []
+            if log_wandb:
+                if OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_enabled", default=True
+                ):
+                    wandb_modes.append("direct")
+                if OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_boxes", default=False
+                ):
+                    wandb_modes.append("boxes")
+                if OmegaConf.select(
+                    self.config, "trainer_config.wandb.viz_masks", default=False
+                ):
+                    wandb_modes.append("masks")
+
+            # Single unified callback handles all visualization outputs
             callbacks.append(
-
-
-
-
+                UnifiedVizCallback(
+                    model_trainer=self,
+                    train_dataset=viz_train_dataset,
+                    val_dataset=viz_val_dataset,
+                    model_type=self.model_type,
+                    save_local=self.config.trainer_config.save_ckpt,
+                    local_save_dir=viz_dir,
+                    log_wandb=log_wandb,
+                    wandb_modes=wandb_modes if wandb_modes else ["direct"],
+                    wandb_box_size=OmegaConf.select(
+                        self.config, "trainer_config.wandb.viz_box_size", default=5.0
+                    ),
+                    wandb_confmap_threshold=OmegaConf.select(
+                        self.config,
+                        "trainer_config.wandb.viz_confmap_threshold",
+                        default=0.1,
+                    ),
+                    log_wandb_table=OmegaConf.select(
+                        self.config, "trainer_config.wandb.log_viz_table", default=False
                     ),
-                    prefix="validation",
                 )
             )
 
-
-
-
-            callbacks.append(
-                MatplotlibSaver(
-                    save_folder=viz_dir,
-                    plot_fn=lambda: self.lightning_model.visualize_pafs_example(
-                        next(train_viz_pipeline1)
-                    ),
-                    prefix="train.pafs_magnitude",
-                )
-            )
-            callbacks.append(
-                MatplotlibSaver(
-                    save_folder=viz_dir,
-                    plot_fn=lambda: self.lightning_model.visualize_pafs_example(
-                        next(val_viz_pipeline1)
-                    ),
-                    prefix="validation.pafs_magnitude",
-                )
-            )
+        # Add custom progress bar with better metric formatting
+        if self.config.trainer_config.enable_progress_bar:
+            callbacks.append(SleapProgressBar())
 
-
-
-
+        # Add epoch-end evaluation callback if enabled
+        if self.config.trainer_config.eval.enabled:
+            if self.model_type == "centroid":
+                # Use centroid-specific evaluation with distance-based metrics
                 callbacks.append(
-
-
-
-
-                    ),
-                    prefix="train.class_maps",
+                    CentroidEvaluationCallback(
+                        videos=self.val_labels[0].videos,
+                        eval_frequency=self.config.trainer_config.eval.frequency,
+                        match_threshold=self.config.trainer_config.eval.match_threshold,
                    )
                )
+            else:
+                # Use standard OKS/PCK evaluation for pose models
                 callbacks.append(
-
-
-
-
-
-
-                )
-            )
-
-        if self.config.trainer_config.use_wandb and OmegaConf.select(
-            self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
-        ):
-            callbacks.append(
-                WandBPredImageLogger(
-                    viz_folder=viz_dir,
-                    wandb_run_name=self.config.trainer_config.wandb.name,
-                    is_bottomup=(self.model_type == "bottomup"),
+                    EpochEndEvaluationCallback(
+                        skeleton=self.skeletons[0],
+                        videos=self.val_labels[0].videos,
+                        eval_frequency=self.config.trainer_config.eval.frequency,
+                        oks_stddev=self.config.trainer_config.eval.oks_stddev,
+                        oks_scale=self.config.trainer_config.eval.oks_scale,
                    )
                )
 
@@ -910,6 +1179,11 @@ class ModelTrainer:
                     : self.config.trainer_config.trainer_devices
                 ]
             ]
+            # Sort device indices in ascending order for NCCL compatibility.
+            # NCCL expects devices in consistent ascending order across ranks
+            # to properly set up communication rings. Without sorting, DDP may
+            # assign multiple ranks to the same GPU, causing "Duplicate GPU detected" errors.
+            devices.sort()
             logger.info(f"Using GPUs with most available memory: {devices}")
 
         # create lightning.Trainer instance.
@@ -931,6 +1205,10 @@ class ModelTrainer:
         # setup datasets
         train_dataset, val_dataset = self._setup_datasets()
 
+        # Barrier after dataset creation to ensure all workers wait for disk caching
+        # (rank 0 caches to disk, others must wait before reading cached files)
+        self.trainer.strategy.barrier()
+
         # set-up steps per epoch
         train_steps_per_epoch = self.config.trainer_config.train_steps_per_epoch
         if train_steps_per_epoch is None:
@@ -959,7 +1237,7 @@ class ModelTrainer:
         logger.info(f"Backbone model: {self.lightning_model.model.backbone}")
         logger.info(f"Head model: {self.lightning_model.model.head_layers}")
         total_params = sum(p.numel() for p in self.lightning_model.parameters())
-        logger.info(f"Total model parameters: {total_params}")
+        logger.info(f"Total model parameters: {total_params:,}")
         self.config.model_config.total_params = total_params
 
         # setup dataloaders
@@ -1000,6 +1278,26 @@ class ModelTrainer:
                 id=self.config.trainer_config.wandb.prv_runid,
                 group=self.config.trainer_config.wandb.group,
             )
+
+            # Define custom x-axes for wandb metrics
+            # Epoch-level metrics use epoch as x-axis, step-level use default global_step
+            wandb.define_metric("epoch")
+
+            # Training metrics (train/ prefix for grouping) - all use epoch x-axis
+            wandb.define_metric("train/*", step_metric="epoch")
+            wandb.define_metric("train/confmaps/*", step_metric="epoch")
+
+            # Validation metrics (val/ prefix for grouping)
+            wandb.define_metric("val/*", step_metric="epoch")
+
+            # Evaluation metrics (eval/ prefix for grouping)
+            wandb.define_metric("eval/*", step_metric="epoch")
+
+            # Visualization images (need explicit nested paths)
+            wandb.define_metric("viz/*", step_metric="epoch")
+            wandb.define_metric("viz/train/*", step_metric="epoch")
+            wandb.define_metric("viz/val/*", step_metric="epoch")
+
             self.config.trainer_config.wandb.current_run_id = wandb.run.id
             wandb.config["run_name"] = self.config.trainer_config.wandb.name
             wandb.config["run_config"] = OmegaConf.to_container(
@@ -1017,6 +1315,9 @@ class ModelTrainer:
 
         self.trainer.strategy.barrier()
 
+        # Flag to track if training was interrupted (not completed normally)
+        training_interrupted = False
+
         try:
            logger.info(
                 f"Finished trainer set up. [{time.time() - start_setup_time:.1f}s]"
@@ -1032,13 +1333,13 @@ class ModelTrainer:
 
         except KeyboardInterrupt:
             logger.info("Stopping training...")
+            training_interrupted = True
 
         finally:
             logger.info(
                 f"Finished training loop. [{(time.time() - start_train_time) / 60:.1f} min]"
             )
-
-            wandb.finish()
+            # Note: wandb.finish() is called in train.py after post-training evaluation
 
             # delete image disk caching
             if (
@@ -1063,3 +1364,15 @@ class ModelTrainer:
             if viz_dir.exists():
                 logger.info(f"Deleting viz folder at {viz_dir}...")
                 shutil.rmtree(viz_dir, ignore_errors=True)
+
+            # Clean up entire run folder if training was interrupted (KeyboardInterrupt)
+            if training_interrupted and self.trainer.global_rank == 0:
+                run_dir = (
+                    Path(self.config.trainer_config.ckpt_dir)
+                    / self.config.trainer_config.run_name
+                )
+                if run_dir.exists():
+                    logger.info(
+                        f"Training canceled - cleaning up run folder at {run_dir}..."
+                    )
+                    shutil.rmtree(run_dir, ignore_errors=True)