sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
|
@@ -20,7 +20,11 @@ from itertools import cycle, count
|
|
|
20
20
|
from omegaconf import DictConfig, OmegaConf
|
|
21
21
|
from lightning.pytorch.loggers import WandbLogger
|
|
22
22
|
from sleap_nn.data.utils import check_cache_memory
|
|
23
|
-
from lightning.pytorch.callbacks import
|
|
23
|
+
from lightning.pytorch.callbacks import (
|
|
24
|
+
ModelCheckpoint,
|
|
25
|
+
EarlyStopping,
|
|
26
|
+
LearningRateMonitor,
|
|
27
|
+
)
|
|
24
28
|
from lightning.pytorch.profilers import (
|
|
25
29
|
SimpleProfiler,
|
|
26
30
|
AdvancedProfiler,
|
|
@@ -28,7 +32,11 @@ from lightning.pytorch.profilers import (
|
|
|
28
32
|
PassThroughProfiler,
|
|
29
33
|
)
|
|
30
34
|
from sleap_io.io.skeleton import SkeletonYAMLEncoder
|
|
31
|
-
from sleap_nn.data.instance_cropping import
|
|
35
|
+
from sleap_nn.data.instance_cropping import (
|
|
36
|
+
find_instance_crop_size,
|
|
37
|
+
find_max_instance_bbox_size,
|
|
38
|
+
compute_augmentation_padding,
|
|
39
|
+
)
|
|
32
40
|
from sleap_nn.data.providers import get_max_height_width
|
|
33
41
|
from sleap_nn.data.custom_datasets import (
|
|
34
42
|
get_train_val_dataloaders,
|
|
@@ -49,7 +57,10 @@ from sleap_nn.training.callbacks import (
|
|
|
49
57
|
TrainingControllerZMQ,
|
|
50
58
|
MatplotlibSaver,
|
|
51
59
|
WandBPredImageLogger,
|
|
60
|
+
WandBVizCallback,
|
|
61
|
+
WandBVizCallbackWithPAFs,
|
|
52
62
|
CSVLoggerCallback,
|
|
63
|
+
SleapProgressBar,
|
|
53
64
|
)
|
|
54
65
|
from sleap_nn import RANK
|
|
55
66
|
from sleap_nn.legacy_models import get_keras_first_layer_channels
|
|
@@ -207,6 +218,52 @@ class ModelTrainer:
|
|
|
207
218
|
trainer_devices = 1
|
|
208
219
|
return trainer_devices
|
|
209
220
|
|
|
221
|
+
def _count_labeled_frames(
|
|
222
|
+
self, labels_list: List[sio.Labels], user_only: bool = True
|
|
223
|
+
) -> int:
|
|
224
|
+
"""Count labeled frames, optionally filtering to user-labeled only.
|
|
225
|
+
|
|
226
|
+
Args:
|
|
227
|
+
labels_list: List of Labels objects to count frames from.
|
|
228
|
+
user_only: If True, count only frames with user instances.
|
|
229
|
+
|
|
230
|
+
Returns:
|
|
231
|
+
Total count of labeled frames.
|
|
232
|
+
"""
|
|
233
|
+
total = 0
|
|
234
|
+
for label in labels_list:
|
|
235
|
+
if user_only:
|
|
236
|
+
total += sum(1 for lf in label if lf.has_user_instances)
|
|
237
|
+
else:
|
|
238
|
+
total += len(label)
|
|
239
|
+
return total
|
|
240
|
+
|
|
241
|
+
def _filter_to_user_labeled(self, labels: sio.Labels) -> sio.Labels:
|
|
242
|
+
"""Filter a Labels object to only include user-labeled frames.
|
|
243
|
+
|
|
244
|
+
Args:
|
|
245
|
+
labels: Labels object to filter.
|
|
246
|
+
|
|
247
|
+
Returns:
|
|
248
|
+
New Labels object containing only frames with user instances.
|
|
249
|
+
"""
|
|
250
|
+
# Filter labeled frames to only those with user instances
|
|
251
|
+
user_lfs = [lf for lf in labels if lf.has_user_instances]
|
|
252
|
+
|
|
253
|
+
# Set instances to user instances only
|
|
254
|
+
for lf in user_lfs:
|
|
255
|
+
lf.instances = lf.user_instances
|
|
256
|
+
|
|
257
|
+
# Create new Labels with filtered frames
|
|
258
|
+
return sio.Labels(
|
|
259
|
+
labeled_frames=user_lfs,
|
|
260
|
+
videos=labels.videos,
|
|
261
|
+
skeletons=labels.skeletons,
|
|
262
|
+
tracks=labels.tracks,
|
|
263
|
+
suggestions=labels.suggestions,
|
|
264
|
+
provenance=labels.provenance,
|
|
265
|
+
)
|
|
266
|
+
|
|
210
267
|
def _setup_train_val_labels(
|
|
211
268
|
self,
|
|
212
269
|
labels: Optional[List[sio.Labels]] = None,
|
|
@@ -218,21 +275,35 @@ class ModelTrainer:
|
|
|
218
275
|
total_val_lfs = 0
|
|
219
276
|
self.skeletons = labels[0].skeletons
|
|
220
277
|
|
|
278
|
+
# Check if we should count only user-labeled frames
|
|
279
|
+
user_instances_only = OmegaConf.select(
|
|
280
|
+
self.config, "data_config.user_instances_only", default=True
|
|
281
|
+
)
|
|
282
|
+
|
|
221
283
|
# check if all `.slp` file shave same skeleton structure (if multiple slp file paths are provided)
|
|
222
284
|
skeleton = self.skeletons[0]
|
|
223
285
|
for index, train_label in enumerate(labels):
|
|
224
286
|
skel_temp = train_label.skeletons[0]
|
|
225
287
|
skeletons_equal = skeleton.matches(skel_temp)
|
|
226
|
-
if skeletons_equal:
|
|
227
|
-
total_train_lfs += len(train_label)
|
|
228
|
-
else:
|
|
288
|
+
if not skeletons_equal:
|
|
229
289
|
message = f"The skeletons in the training labels: {index + 1} do not match the skeleton in the first training label file."
|
|
230
290
|
logger.error(message)
|
|
231
291
|
raise ValueError(message)
|
|
232
292
|
|
|
233
|
-
|
|
293
|
+
# Check for same-data mode (train = val, for intentional overfitting)
|
|
294
|
+
use_same = OmegaConf.select(
|
|
295
|
+
self.config, "data_config.use_same_data_for_val", default=False
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
if use_same:
|
|
299
|
+
# Same mode: use identical data for train and val (for overfitting)
|
|
300
|
+
logger.info("Using same data for train and val (overfit mode)")
|
|
301
|
+
self.train_labels = labels
|
|
302
|
+
self.val_labels = labels
|
|
303
|
+
total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
|
|
304
|
+
total_val_lfs = total_train_lfs
|
|
305
|
+
elif val_labels is None or not len(val_labels):
|
|
234
306
|
# if val labels are not provided, split from train
|
|
235
|
-
total_train_lfs = 0
|
|
236
307
|
val_fraction = OmegaConf.select(
|
|
237
308
|
self.config, "data_config.validation_fraction", default=0.1
|
|
238
309
|
)
|
|
@@ -250,13 +321,14 @@ class ModelTrainer:
|
|
|
250
321
|
)
|
|
251
322
|
self.train_labels.append(train_split)
|
|
252
323
|
self.val_labels.append(val_split)
|
|
324
|
+
# make_training_splits returns only user-labeled frames
|
|
253
325
|
total_train_lfs += len(train_split)
|
|
254
326
|
total_val_lfs += len(val_split)
|
|
255
327
|
else:
|
|
256
328
|
self.train_labels = labels
|
|
257
329
|
self.val_labels = val_labels
|
|
258
|
-
|
|
259
|
-
|
|
330
|
+
total_train_lfs = self._count_labeled_frames(labels, user_instances_only)
|
|
331
|
+
total_val_lfs = self._count_labeled_frames(val_labels, user_instances_only)
|
|
260
332
|
|
|
261
333
|
logger.info(f"# Train Labeled frames: {total_train_lfs}")
|
|
262
334
|
logger.info(f"# Val Labeled frames: {total_val_lfs}")
|
|
@@ -291,13 +363,70 @@ class ModelTrainer:
|
|
|
291
363
|
):
|
|
292
364
|
# compute crop size if not provided in config
|
|
293
365
|
if crop_size is None:
|
|
366
|
+
# Get padding from config or auto-compute from augmentation settings
|
|
367
|
+
padding = self.config.data_config.preprocessing.crop_padding
|
|
368
|
+
if padding is None:
|
|
369
|
+
# Auto-compute padding based on augmentation settings
|
|
370
|
+
aug_config = self.config.data_config.augmentation_config
|
|
371
|
+
if (
|
|
372
|
+
self.config.data_config.use_augmentations_train
|
|
373
|
+
and aug_config is not None
|
|
374
|
+
and aug_config.geometric is not None
|
|
375
|
+
):
|
|
376
|
+
geo = aug_config.geometric
|
|
377
|
+
# Check if rotation is enabled (via rotation_p or affine_p)
|
|
378
|
+
rotation_enabled = (
|
|
379
|
+
geo.rotation_p is not None and geo.rotation_p > 0
|
|
380
|
+
) or (
|
|
381
|
+
geo.rotation_p is None
|
|
382
|
+
and geo.scale_p is None
|
|
383
|
+
and geo.translate_p is None
|
|
384
|
+
and geo.affine_p > 0
|
|
385
|
+
)
|
|
386
|
+
# Check if scale is enabled (via scale_p or affine_p)
|
|
387
|
+
scale_enabled = (
|
|
388
|
+
geo.scale_p is not None and geo.scale_p > 0
|
|
389
|
+
) or (
|
|
390
|
+
geo.rotation_p is None
|
|
391
|
+
and geo.scale_p is None
|
|
392
|
+
and geo.translate_p is None
|
|
393
|
+
and geo.affine_p > 0
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
if rotation_enabled or scale_enabled:
|
|
397
|
+
# First find the actual max bbox size from labels
|
|
398
|
+
bbox_size = find_max_instance_bbox_size(train_label)
|
|
399
|
+
bbox_size = max(
|
|
400
|
+
bbox_size,
|
|
401
|
+
self.config.data_config.preprocessing.min_crop_size
|
|
402
|
+
or 100,
|
|
403
|
+
)
|
|
404
|
+
rotation_max = (
|
|
405
|
+
max(
|
|
406
|
+
abs(geo.rotation_min),
|
|
407
|
+
abs(geo.rotation_max),
|
|
408
|
+
)
|
|
409
|
+
if rotation_enabled
|
|
410
|
+
else 0.0
|
|
411
|
+
)
|
|
412
|
+
scale_max = geo.scale_max if scale_enabled else 1.0
|
|
413
|
+
padding = compute_augmentation_padding(
|
|
414
|
+
bbox_size=bbox_size,
|
|
415
|
+
rotation_max=rotation_max,
|
|
416
|
+
scale_max=scale_max,
|
|
417
|
+
)
|
|
418
|
+
else:
|
|
419
|
+
padding = 0
|
|
420
|
+
else:
|
|
421
|
+
padding = 0
|
|
422
|
+
|
|
294
423
|
crop_sz = find_instance_crop_size(
|
|
295
424
|
labels=train_label,
|
|
425
|
+
padding=padding,
|
|
296
426
|
maximum_stride=self.config.model_config.backbone_config[
|
|
297
427
|
f"{self.backbone_type}"
|
|
298
428
|
]["max_stride"],
|
|
299
429
|
min_crop_size=self.config.data_config.preprocessing.min_crop_size,
|
|
300
|
-
input_scaling=self.config.data_config.preprocessing.scale,
|
|
301
430
|
)
|
|
302
431
|
|
|
303
432
|
if crop_sz > max_crop_size:
|
|
@@ -509,6 +638,14 @@ class ModelTrainer:
|
|
|
509
638
|
"""Compute config parameters."""
|
|
510
639
|
logger.info("Setting up config...")
|
|
511
640
|
|
|
641
|
+
# Normalize empty strings to None for optional wandb fields
|
|
642
|
+
if self.config.trainer_config.wandb.prv_runid == "":
|
|
643
|
+
self.config.trainer_config.wandb.prv_runid = None
|
|
644
|
+
|
|
645
|
+
# Default wandb run name to trainer run_name if not specified
|
|
646
|
+
if self.config.trainer_config.wandb.name is None:
|
|
647
|
+
self.config.trainer_config.wandb.name = self.config.trainer_config.run_name
|
|
648
|
+
|
|
512
649
|
# compute preprocessing parameters from the labels objects and fill in the config
|
|
513
650
|
self._setup_preprocessing_config()
|
|
514
651
|
|
|
@@ -565,7 +702,7 @@ class ModelTrainer:
|
|
|
565
702
|
self._verify_model_input_channels()
|
|
566
703
|
|
|
567
704
|
def _setup_model_ckpt_dir(self):
|
|
568
|
-
"""Create the model ckpt folder."""
|
|
705
|
+
"""Create the model ckpt folder and save ground truth labels."""
|
|
569
706
|
ckpt_path = (
|
|
570
707
|
Path(self.config.trainer_config.ckpt_dir)
|
|
571
708
|
/ self.config.trainer_config.run_name
|
|
@@ -581,16 +718,61 @@ class ModelTrainer:
|
|
|
581
718
|
raise OSError(message)
|
|
582
719
|
|
|
583
720
|
if RANK in [0, -1]:
|
|
721
|
+
# Check if we should filter to user-labeled frames only
|
|
722
|
+
user_instances_only = OmegaConf.select(
|
|
723
|
+
self.config, "data_config.user_instances_only", default=True
|
|
724
|
+
)
|
|
725
|
+
|
|
726
|
+
# Save train and val ground truth labels
|
|
584
727
|
for idx, (train, val) in enumerate(zip(self.train_labels, self.val_labels)):
|
|
585
|
-
|
|
586
|
-
|
|
728
|
+
# Filter to user-labeled frames if needed (for evaluation)
|
|
729
|
+
if user_instances_only:
|
|
730
|
+
train_filtered = self._filter_to_user_labeled(train)
|
|
731
|
+
val_filtered = self._filter_to_user_labeled(val)
|
|
732
|
+
else:
|
|
733
|
+
train_filtered = train
|
|
734
|
+
val_filtered = val
|
|
735
|
+
|
|
736
|
+
train_filtered.save(
|
|
737
|
+
Path(ckpt_path) / f"labels_gt.train.{idx}.slp",
|
|
587
738
|
restore_original_videos=False,
|
|
588
739
|
)
|
|
589
|
-
|
|
590
|
-
Path(ckpt_path) / f"
|
|
740
|
+
val_filtered.save(
|
|
741
|
+
Path(ckpt_path) / f"labels_gt.val.{idx}.slp",
|
|
591
742
|
restore_original_videos=False,
|
|
592
743
|
)
|
|
593
744
|
|
|
745
|
+
# Save test ground truth labels if test paths are provided
|
|
746
|
+
test_file_path = OmegaConf.select(
|
|
747
|
+
self.config, "data_config.test_file_path", default=None
|
|
748
|
+
)
|
|
749
|
+
if test_file_path is not None:
|
|
750
|
+
# Normalize to list of strings
|
|
751
|
+
if isinstance(test_file_path, str):
|
|
752
|
+
test_paths = [test_file_path]
|
|
753
|
+
else:
|
|
754
|
+
test_paths = list(test_file_path)
|
|
755
|
+
|
|
756
|
+
for idx, test_path in enumerate(test_paths):
|
|
757
|
+
# Only save if it's a .slp file (not a video file)
|
|
758
|
+
if test_path.endswith(".slp") or test_path.endswith(".pkg.slp"):
|
|
759
|
+
try:
|
|
760
|
+
test_labels = sio.load_slp(test_path)
|
|
761
|
+
if user_instances_only:
|
|
762
|
+
test_filtered = self._filter_to_user_labeled(
|
|
763
|
+
test_labels
|
|
764
|
+
)
|
|
765
|
+
else:
|
|
766
|
+
test_filtered = test_labels
|
|
767
|
+
test_filtered.save(
|
|
768
|
+
Path(ckpt_path) / f"labels_gt.test.{idx}.slp",
|
|
769
|
+
restore_original_videos=False,
|
|
770
|
+
)
|
|
771
|
+
except Exception as e:
|
|
772
|
+
logger.warning(
|
|
773
|
+
f"Could not save test ground truth for {test_path}: {e}"
|
|
774
|
+
)
|
|
775
|
+
|
|
594
776
|
def _setup_viz_datasets(self):
|
|
595
777
|
"""Setup dataloaders."""
|
|
596
778
|
data_viz_config = self.config.copy()
|
|
@@ -716,6 +898,10 @@ class ModelTrainer:
|
|
|
716
898
|
)
|
|
717
899
|
loggers.append(wandb_logger)
|
|
718
900
|
|
|
901
|
+
# Learning rate monitor callback - logs LR at each step for dynamic schedulers
|
|
902
|
+
# Only added when wandb is enabled since it requires a logger
|
|
903
|
+
callbacks.append(LearningRateMonitor(logging_interval="step"))
|
|
904
|
+
|
|
719
905
|
# save the configs as yaml in the checkpoint dir
|
|
720
906
|
# Mask API key in both configs to prevent saving to disk
|
|
721
907
|
self.config.trainer_config.wandb.api_key = ""
|
|
@@ -814,13 +1000,80 @@ class ModelTrainer:
|
|
|
814
1000
|
if self.config.trainer_config.use_wandb and OmegaConf.select(
|
|
815
1001
|
self.config, "trainer_config.wandb.save_viz_imgs_wandb", default=False
|
|
816
1002
|
):
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
wandb_run_name=self.config.trainer_config.wandb.name,
|
|
821
|
-
is_bottomup=(self.model_type == "bottomup"),
|
|
822
|
-
)
|
|
1003
|
+
# Get wandb viz config options
|
|
1004
|
+
viz_enabled = OmegaConf.select(
|
|
1005
|
+
self.config, "trainer_config.wandb.viz_enabled", default=True
|
|
823
1006
|
)
|
|
1007
|
+
viz_boxes = OmegaConf.select(
|
|
1008
|
+
self.config, "trainer_config.wandb.viz_boxes", default=False
|
|
1009
|
+
)
|
|
1010
|
+
viz_masks = OmegaConf.select(
|
|
1011
|
+
self.config, "trainer_config.wandb.viz_masks", default=False
|
|
1012
|
+
)
|
|
1013
|
+
viz_box_size = OmegaConf.select(
|
|
1014
|
+
self.config, "trainer_config.wandb.viz_box_size", default=5.0
|
|
1015
|
+
)
|
|
1016
|
+
viz_confmap_threshold = OmegaConf.select(
|
|
1017
|
+
self.config,
|
|
1018
|
+
"trainer_config.wandb.viz_confmap_threshold",
|
|
1019
|
+
default=0.1,
|
|
1020
|
+
)
|
|
1021
|
+
log_viz_table = OmegaConf.select(
|
|
1022
|
+
self.config, "trainer_config.wandb.log_viz_table", default=False
|
|
1023
|
+
)
|
|
1024
|
+
|
|
1025
|
+
# Create viz data pipelines for wandb callback
|
|
1026
|
+
wandb_train_viz_pipeline = cycle(copy.deepcopy(viz_train_dataset))
|
|
1027
|
+
wandb_val_viz_pipeline = cycle(copy.deepcopy(viz_val_dataset))
|
|
1028
|
+
|
|
1029
|
+
if self.model_type == "bottomup":
|
|
1030
|
+
# Bottom-up model needs PAF visualizations
|
|
1031
|
+
wandb_train_pafs_pipeline = cycle(copy.deepcopy(viz_train_dataset))
|
|
1032
|
+
wandb_val_pafs_pipeline = cycle(copy.deepcopy(viz_val_dataset))
|
|
1033
|
+
callbacks.append(
|
|
1034
|
+
WandBVizCallbackWithPAFs(
|
|
1035
|
+
train_viz_fn=lambda: self.lightning_model.get_visualization_data(
|
|
1036
|
+
next(wandb_train_viz_pipeline)
|
|
1037
|
+
),
|
|
1038
|
+
val_viz_fn=lambda: self.lightning_model.get_visualization_data(
|
|
1039
|
+
next(wandb_val_viz_pipeline)
|
|
1040
|
+
),
|
|
1041
|
+
train_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
|
|
1042
|
+
next(wandb_train_pafs_pipeline), include_pafs=True
|
|
1043
|
+
),
|
|
1044
|
+
val_pafs_viz_fn=lambda: self.lightning_model.get_visualization_data(
|
|
1045
|
+
next(wandb_val_pafs_pipeline), include_pafs=True
|
|
1046
|
+
),
|
|
1047
|
+
viz_enabled=viz_enabled,
|
|
1048
|
+
viz_boxes=viz_boxes,
|
|
1049
|
+
viz_masks=viz_masks,
|
|
1050
|
+
box_size=viz_box_size,
|
|
1051
|
+
confmap_threshold=viz_confmap_threshold,
|
|
1052
|
+
log_table=log_viz_table,
|
|
1053
|
+
)
|
|
1054
|
+
)
|
|
1055
|
+
else:
|
|
1056
|
+
# Standard models
|
|
1057
|
+
callbacks.append(
|
|
1058
|
+
WandBVizCallback(
|
|
1059
|
+
train_viz_fn=lambda: self.lightning_model.get_visualization_data(
|
|
1060
|
+
next(wandb_train_viz_pipeline)
|
|
1061
|
+
),
|
|
1062
|
+
val_viz_fn=lambda: self.lightning_model.get_visualization_data(
|
|
1063
|
+
next(wandb_val_viz_pipeline)
|
|
1064
|
+
),
|
|
1065
|
+
viz_enabled=viz_enabled,
|
|
1066
|
+
viz_boxes=viz_boxes,
|
|
1067
|
+
viz_masks=viz_masks,
|
|
1068
|
+
box_size=viz_box_size,
|
|
1069
|
+
confmap_threshold=viz_confmap_threshold,
|
|
1070
|
+
log_table=log_viz_table,
|
|
1071
|
+
)
|
|
1072
|
+
)
|
|
1073
|
+
|
|
1074
|
+
# Add custom progress bar with better metric formatting
|
|
1075
|
+
if self.config.trainer_config.enable_progress_bar:
|
|
1076
|
+
callbacks.append(SleapProgressBar())
|
|
824
1077
|
|
|
825
1078
|
return loggers, callbacks
|
|
826
1079
|
|
|
@@ -959,7 +1212,7 @@ class ModelTrainer:
|
|
|
959
1212
|
logger.info(f"Backbone model: {self.lightning_model.model.backbone}")
|
|
960
1213
|
logger.info(f"Head model: {self.lightning_model.model.head_layers}")
|
|
961
1214
|
total_params = sum(p.numel() for p in self.lightning_model.parameters())
|
|
962
|
-
logger.info(f"Total model parameters: {total_params}")
|
|
1215
|
+
logger.info(f"Total model parameters: {total_params:,}")
|
|
963
1216
|
self.config.model_config.total_params = total_params
|
|
964
1217
|
|
|
965
1218
|
# setup dataloaders
|
|
@@ -1000,6 +1253,23 @@ class ModelTrainer:
|
|
|
1000
1253
|
id=self.config.trainer_config.wandb.prv_runid,
|
|
1001
1254
|
group=self.config.trainer_config.wandb.group,
|
|
1002
1255
|
)
|
|
1256
|
+
|
|
1257
|
+
# Define custom x-axes for wandb metrics
|
|
1258
|
+
# Epoch-level metrics use epoch as x-axis, step-level use default global_step
|
|
1259
|
+
wandb.define_metric("epoch")
|
|
1260
|
+
wandb.define_metric("val_loss", step_metric="epoch")
|
|
1261
|
+
wandb.define_metric("val_time", step_metric="epoch")
|
|
1262
|
+
wandb.define_metric("train_time", step_metric="epoch")
|
|
1263
|
+
# Per-node losses use epoch as x-axis
|
|
1264
|
+
for node_name in self.skeletons[0].node_names:
|
|
1265
|
+
wandb.define_metric(node_name, step_metric="epoch")
|
|
1266
|
+
|
|
1267
|
+
# Visualization images use epoch as x-axis
|
|
1268
|
+
wandb.define_metric("train_predictions*", step_metric="epoch")
|
|
1269
|
+
wandb.define_metric("val_predictions*", step_metric="epoch")
|
|
1270
|
+
wandb.define_metric("train_pafs*", step_metric="epoch")
|
|
1271
|
+
wandb.define_metric("val_pafs*", step_metric="epoch")
|
|
1272
|
+
|
|
1003
1273
|
self.config.trainer_config.wandb.current_run_id = wandb.run.id
|
|
1004
1274
|
wandb.config["run_name"] = self.config.trainer_config.wandb.name
|
|
1005
1275
|
wandb.config["run_config"] = OmegaConf.to_container(
|
|
@@ -1017,6 +1287,9 @@ class ModelTrainer:
|
|
|
1017
1287
|
|
|
1018
1288
|
self.trainer.strategy.barrier()
|
|
1019
1289
|
|
|
1290
|
+
# Flag to track if training was interrupted (not completed normally)
|
|
1291
|
+
training_interrupted = False
|
|
1292
|
+
|
|
1020
1293
|
try:
|
|
1021
1294
|
logger.info(
|
|
1022
1295
|
f"Finished trainer set up. [{time.time() - start_setup_time:.1f}s]"
|
|
@@ -1032,6 +1305,7 @@ class ModelTrainer:
|
|
|
1032
1305
|
|
|
1033
1306
|
except KeyboardInterrupt:
|
|
1034
1307
|
logger.info("Stopping training...")
|
|
1308
|
+
training_interrupted = True
|
|
1035
1309
|
|
|
1036
1310
|
finally:
|
|
1037
1311
|
logger.info(
|
|
@@ -1063,3 +1337,15 @@ class ModelTrainer:
|
|
|
1063
1337
|
if viz_dir.exists():
|
|
1064
1338
|
logger.info(f"Deleting viz folder at {viz_dir}...")
|
|
1065
1339
|
shutil.rmtree(viz_dir, ignore_errors=True)
|
|
1340
|
+
|
|
1341
|
+
# Clean up entire run folder if training was interrupted (KeyboardInterrupt)
|
|
1342
|
+
if training_interrupted and self.trainer.global_rank == 0:
|
|
1343
|
+
run_dir = (
|
|
1344
|
+
Path(self.config.trainer_config.ckpt_dir)
|
|
1345
|
+
/ self.config.trainer_config.run_name
|
|
1346
|
+
)
|
|
1347
|
+
if run_dir.exists():
|
|
1348
|
+
logger.info(
|
|
1349
|
+
f"Training canceled - cleaning up run folder at {run_dir}..."
|
|
1350
|
+
)
|
|
1351
|
+
shutil.rmtree(run_dir, ignore_errors=True)
|