sleap-nn 0.1.0a0__py3-none-any.whl → 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +4 -2
- sleap_nn/config/get_config.py +5 -0
- sleap_nn/config/trainer_config.py +23 -0
- sleap_nn/data/custom_datasets.py +53 -11
- sleap_nn/evaluation.py +73 -22
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/train.py +5 -0
- sleap_nn/training/callbacks.py +274 -0
- sleap_nn/training/lightning_modules.py +210 -2
- sleap_nn/training/model_trainer.py +53 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/METADATA +2 -2
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/RECORD +16 -16
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/top_level.txt +0 -0
sleap_nn/__init__.py
CHANGED
@@ -41,14 +41,16 @@ def _safe_print(msg):
 
 
 # Add logger with the custom filter
+# Disable colorization to avoid ANSI codes in captured output
 logger.add(
     _safe_print,
     level="DEBUG",
     filter=_should_log,
-    format="{time:YYYY-MM-DD HH:mm:ss} | {
+    format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
+    colorize=False,
 )
 
-__version__ = "0.1.
+__version__ = "0.1.0a2"
 
 # Public API
 from sleap_nn.evaluation import load_metrics
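
As context for the logging change above, a minimal loguru sketch of the new sink settings; the callable sink and the log message are stand-ins, and only `format=` and `colorize=False` come from this diff:

```python
from loguru import logger

logger.remove()  # drop the default stderr sink for a clean demo
logger.add(
    print,  # any callable sink works, standing in for _safe_print
    level="DEBUG",
    format="{time:YYYY-MM-DD HH:mm:ss} | {message}",
    colorize=False,  # keep ANSI escape codes out of captured output
)
logger.info("caching started")  # -> "2025-01-01 12:00:00 | caching started"
```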
sleap_nn/config/get_config.py
CHANGED
@@ -677,6 +677,7 @@ def get_trainer_config(
     wandb_save_viz_imgs_wandb: bool = False,
     wandb_resume_prv_runid: Optional[str] = None,
     wandb_group_name: Optional[str] = None,
+    wandb_delete_local_logs: Optional[bool] = None,
     optimizer: str = "Adam",
     learning_rate: float = 1e-3,
     amsgrad: bool = False,
@@ -746,6 +747,9 @@ def get_trainer_config(
         wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
             ckpt. Default: None
         wandb_group_name: Group name for the wandb run. Default: None.
+        wandb_delete_local_logs: If True, delete local wandb logs folder after training.
+            If False, keep the folder. If None (default), automatically delete if logging
+            online (wandb_mode != "offline") and keep if logging offline. Default: None.
         optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
         learning_rate: Learning rate of type float. Default: 1e-3.
         amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -846,6 +850,7 @@ def get_trainer_config(
             save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
             prv_runid=wandb_resume_prv_runid,
             group=wandb_group_name,
+            delete_local_logs=wandb_delete_local_logs,
         ),
         save_ckpt=save_ckpt,
         ckpt_dir=ckpt_dir,
sleap_nn/config/trainer_config.py
CHANGED
@@ -90,6 +90,10 @@ class WandBConfig:
         viz_box_size: (float) Size of keypoint boxes in pixels (for viz_boxes). *Default*: `5.0`.
         viz_confmap_threshold: (float) Threshold for confidence map masks (for viz_masks). *Default*: `0.1`.
         log_viz_table: (bool) If True, also log images to a wandb.Table for backwards compatibility. *Default*: `False`.
+        delete_local_logs: (bool, optional) If True, delete local wandb logs folder after
+            training. If False, keep the folder. If None (default), automatically delete
+            if logging online (wandb_mode != "offline") and keep if logging offline.
+            *Default*: `None`.
     """
 
     entity: Optional[str] = None
@@ -107,6 +111,7 @@ class WandBConfig:
     viz_box_size: float = 5.0
     viz_confmap_threshold: float = 0.1
     log_viz_table: bool = False
+    delete_local_logs: Optional[bool] = None
 
 
 @define
@@ -203,6 +208,23 @@ class EarlyStoppingConfig:
     stop_training_on_plateau: bool = True
 
 
+@define
+class EvalConfig:
+    """Configuration for epoch-end evaluation.
+
+    Attributes:
+        enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
+        frequency: (int) Evaluate every N epochs. *Default*: `1`.
+        oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
+        oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
+    """
+
+    enabled: bool = False
+    frequency: int = field(default=1, validator=validators.ge(1))
+    oks_stddev: float = field(default=0.025, validator=validators.gt(0))
+    oks_scale: Optional[float] = None
+
+
 @define
 class HardKeypointMiningConfig:
     """Configuration for online hard keypoint mining.
@@ -305,6 +327,7 @@ class TrainerConfig:
         factory=HardKeypointMiningConfig
     )
     zmq: Optional[ZMQConfig] = field(factory=ZMQConfig)  # Required for SLEAP GUI
+    eval: EvalConfig = field(factory=EvalConfig)  # Epoch-end evaluation config
 
     @staticmethod
     def validate_optimizer_name(value):
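
A short usage sketch for the two options introduced above; the field names come from this diff, while the surrounding usage (direct construction with attrs defaults) is assumed:

```python
from sleap_nn.config.trainer_config import EvalConfig, WandBConfig

# Enable epoch-end evaluation every 5 epochs with the default OKS settings.
eval_cfg = EvalConfig(enabled=True, frequency=5)

# Keep the local wandb/ folder even when logging online.
wandb_cfg = WandBConfig(delete_local_logs=False)

# frequency is validated with validators.ge(1), so EvalConfig(frequency=0)
# would raise a ValueError at construction time.
```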
sleap_nn/data/custom_datasets.py
CHANGED
@@ -13,6 +13,14 @@ from omegaconf import DictConfig, OmegaConf
 import numpy as np
 from PIL import Image
 from loguru import logger
+from rich.progress import (
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    BarColumn,
+    TimeElapsedColumn,
+)
+from rich.console import Console
 import torch
 import torchvision.transforms as T
 from torch.utils.data import Dataset, DataLoader, DistributedSampler
@@ -215,17 +223,51 @@ class BaseDataset(Dataset):
     def _fill_cache(self, labels: List[sio.Labels]):
         """Load all samples to cache."""
         # TODO: Implement parallel processing (using threads might cause error with MediaVideo backend)
-        [11 removed lines not shown in this diff view]
+        import os
+        import sys
+
+        total_samples = len(self.lf_idx_list)
+        cache_type = "disk" if self.cache_img == "disk" else "memory"
+
+        # Check for NO_COLOR env var or non-interactive terminal
+        no_color = (
+            os.environ.get("NO_COLOR") is not None
+            or os.environ.get("FORCE_COLOR") == "0"
+        )
+        use_progress = sys.stdout.isatty() and not no_color
+
+        def process_samples(progress=None, task=None):
+            for sample in self.lf_idx_list:
+                labels_idx = sample["labels_idx"]
+                lf_idx = sample["lf_idx"]
+                img = labels[labels_idx][lf_idx].image
+                if img.shape[-1] == 1:
+                    img = np.squeeze(img)
+                if self.cache_img == "disk":
+                    f_name = f"{self.cache_img_path}/sample_{labels_idx}_{lf_idx}.jpg"
+                    Image.fromarray(img).save(f_name, format="JPEG")
+                if self.cache_img == "memory":
+                    self.cache[(labels_idx, lf_idx)] = img
+                if progress is not None:
+                    progress.update(task, advance=1)
+
+        if use_progress:
+            with Progress(
+                SpinnerColumn(),
+                TextColumn("[progress.description]{task.description}"),
+                BarColumn(),
+                TextColumn("{task.completed}/{task.total}"),
+                TimeElapsedColumn(),
+                console=Console(force_terminal=True),
+                transient=True,
+            ) as progress:
+                task = progress.add_task(
+                    f"Caching images to {cache_type}", total=total_samples
+                )
+                process_samples(progress, task)
+        else:
+            logger.info(f"Caching {total_samples} images to {cache_type}...")
+            process_samples()
 
     def __len__(self) -> int:
         """Return the number of samples in the dataset."""
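
The caching change above switches to a TTY-aware rich progress bar. Below is a self-contained sketch of that pattern with a dummy loop standing in for the image-caching work; only the column setup and the NO_COLOR/isatty checks mirror the diff:

```python
import os
import sys

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn

items = range(100)  # stand-in for the dataset's lf_idx_list
no_color = os.environ.get("NO_COLOR") is not None or os.environ.get("FORCE_COLOR") == "0"

if sys.stdout.isatty() and not no_color:
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("{task.completed}/{task.total}"),
        TimeElapsedColumn(),
        console=Console(force_terminal=True),
        transient=True,  # clear the bar when done
    ) as progress:
        task = progress.add_task("Caching images", total=len(items))
        for _ in items:
            progress.update(task, advance=1)
else:
    # Plain loop when output is redirected or colors are disabled.
    for _ in items:
        pass
```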
sleap_nn/evaluation.py
CHANGED
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
     """
     instance_list = []
     frame_idx = labeled_frame.frame_idx
-    [5 removed lines not shown in this diff view]
+
+    # Extract video path with fallbacks for embedded videos
+    video = labeled_frame.video
+    video_path = None
+    if video is not None:
+        backend = getattr(video, "backend", None)
+        if backend is not None:
+            # Try source_filename first (for embedded videos with provenance)
+            video_path = getattr(backend, "source_filename", None)
+            if video_path is None:
+                video_path = getattr(backend, "filename", None)
+        # Fallback to video.filename if backend doesn't have it
+        if video_path is None:
+            video_path = getattr(video, "filename", None)
+        # Handle list filenames (image sequences)
+        if isinstance(video_path, list) and video_path:
+            video_path = video_path[0]
+    # Final fallback: use a unique identifier
+    if video_path is None:
+        video_path = f"video_{id(video)}" if video is not None else "unknown"
+
     for instance in labeled_frame.instances:
         match_instance = MatchInstance(
             instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -47,6 +63,10 @@ def find_frame_pairs(
 ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
     """Find corresponding frames across two sets of labels.
 
+    This function uses sleap-io's robust video matching API to handle various
+    scenarios including embedded videos, cross-platform paths, and videos with
+    different metadata.
+
     Args:
         labels_gt: A `sio.Labels` instance with ground truth instances.
         labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,16 +76,15 @@ def find_frame_pairs(
     Returns:
         A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
     """
+    # Use sleap-io's robust video matching API (added in 0.6.2)
+    # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
+    match_result = labels_gt.match(labels_pr)
+
     frame_pairs = []
-    [4 removed lines not shown in this diff view]
-            if video_gt.matches_content(video) and video_gt.matches_path(video):
-                video_pr = video
-                break
-
-        if video_pr is None:
+    # Iterate over matched video pairs (pred_video -> gt_video mapping)
+    for video_pr, video_gt in match_result.video_map.items():
+        if video_gt is None:
+            # No match found for this prediction video
             continue
 
         # Find labeled frames in this video.
@@ -786,11 +805,26 @@ def run_evaluation(
     """Evaluate SLEAP-NN model predictions against ground truth labels."""
     logger.info("Loading ground truth labels...")
     ground_truth_instances = sio.load_slp(ground_truth_path)
+    logger.info(
+        f"  Ground truth: {len(ground_truth_instances.videos)} videos, "
+        f"{len(ground_truth_instances.labeled_frames)} frames"
+    )
 
     logger.info("Loading predicted labels...")
     predicted_instances = sio.load_slp(predicted_path)
+    logger.info(
+        f"  Predictions: {len(predicted_instances.videos)} videos, "
+        f"{len(predicted_instances.labeled_frames)} frames"
+    )
+
+    logger.info("Matching videos and frames...")
+    # Get match stats before creating evaluator
+    match_result = ground_truth_instances.match(predicted_instances)
+    logger.info(
+        f"  Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
+    )
 
-    logger.info("
+    logger.info("Matching instances...")
     evaluator = Evaluator(
         ground_truth_instances=ground_truth_instances,
         predicted_instances=predicted_instances,
@@ -799,21 +833,38 @@
         match_threshold=match_threshold,
         user_labels_only=user_labels_only,
     )
+    logger.info(
+        f"  Frame pairs: {len(evaluator.frame_pairs)}, "
+        f"Matched instances: {len(evaluator.positive_pairs)}, "
+        f"Unmatched GT: {len(evaluator.false_negatives)}"
+    )
 
     logger.info("Computing evaluation metrics...")
     metrics = evaluator.evaluate()
 
+    # Compute PCK at specific thresholds (5 and 10 pixels)
+    dists = metrics["distance_metrics"]["dists"]
+    dists_clean = np.copy(dists)
+    dists_clean[np.isnan(dists_clean)] = np.inf
+    pck_5 = (dists_clean < 5).mean()
+    pck_10 = (dists_clean < 10).mean()
+
     # Print key metrics
     logger.info("Evaluation Results:")
-    logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
-    logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
-    logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
-    logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.
-    logger.info(f"
+    logger.info(f"  mOKS: {metrics['mOKS']['mOKS']:.4f}")
+    logger.info(f"  mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+    logger.info(f"  mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+    logger.info(f"  Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
+    logger.info(f"  dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
+    logger.info(f"  dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
+    logger.info(f"  dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
+    logger.info(f"  mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+    logger.info(f"  PCK@5px: {pck_5:.4f}")
+    logger.info(f"  PCK@10px: {pck_10:.4f}")
     logger.info(
-        f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
+        f"  Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
     )
-    logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
+    logger.info(f"  Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
 
     # Save metrics if path provided
     if save_metrics:
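
The PCK@5px/PCK@10px values logged above are plain threshold fractions over the keypoint distance matrix, with unmatched (NaN) nodes counted as misses. A tiny worked example with made-up distances:

```python
import numpy as np

# Hypothetical distance matrix (n_pairs, n_nodes); NaN = node not matched.
dists = np.array([[1.2, 4.8, np.nan], [7.5, 2.0, 11.0]])

dists_clean = np.copy(dists)
dists_clean[np.isnan(dists_clean)] = np.inf  # NaNs can never fall under a threshold
pck_5 = (dists_clean < 5).mean()    # 3 of 6 keypoints within 5 px  -> 0.5
pck_10 = (dists_clean < 10).mean()  # 4 of 6 keypoints within 10 px -> ~0.667
```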
sleap_nn/inference/bottomup.py
CHANGED
@@ -1,5 +1,6 @@
 """Inference modules for BottomUp models."""
 
+import logging
 from typing import Dict, Optional
 import torch
 import lightning as L
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.inference.identity import classify_peaks_from_maps
 
+logger = logging.getLogger(__name__)
+
 
 class BottomUpInferenceModel(L.LightningModule):
     """BottomUp Inference model.
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
         return_pafs: Optional[bool] = False,
         return_paf_graph: Optional[bool] = False,
         input_scale: float = 1.0,
+        max_peaks_per_node: Optional[int] = None,
     ):
-        """Initialise the model attributes.
+        """Initialise the model attributes.
+
+        Args:
+            torch_model: A `nn.Module` that accepts images and predicts confidence maps.
+            paf_scorer: A `PAFScorer` instance for grouping instances.
+            cms_output_stride: Output stride of confidence maps relative to images.
+            pafs_output_stride: Output stride of PAFs relative to images.
+            peak_threshold: Minimum confidence map value for valid peaks.
+            refinement: Peak refinement method: None, "integral", or "local".
+            integral_patch_size: Size of patches for integral refinement.
+            return_confmaps: If True, return confidence maps in output.
+            return_pafs: If True, return PAFs in output.
+            return_paf_graph: If True, return intermediate PAF graph in output.
+            input_scale: Scale factor applied to input images.
+            max_peaks_per_node: Maximum number of peaks allowed per node before
+                skipping PAF scoring. If any node has more peaks than this limit,
+                empty predictions are returned. This prevents combinatorial explosion
+                during early training when confidence maps are noisy. Set to None to
+                disable this check (default). Recommended value: 100.
+        """
         super().__init__()
         self.torch_model = torch_model
         self.paf_scorer = paf_scorer
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
         self.return_pafs = return_pafs
         self.return_paf_graph = return_paf_graph
         self.input_scale = input_scale
+        self.max_peaks_per_node = max_peaks_per_node
 
     def _generate_cms_peaks(self, cms):
         # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -124,26 +148,68 @@ class BottomUpInferenceModel(L.LightningModule):
         )  # (batch, h, w, 2*edges)
         cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
 
-        (
-        [18 removed lines not shown in this diff view]
+        # Check if too many peaks per node (prevents combinatorial explosion)
+        skip_paf_scoring = False
+        if self.max_peaks_per_node is not None:
+            n_nodes = cms.shape[1]
+            for b in range(self.batch_size):
+                for node_idx in range(n_nodes):
+                    n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
+                    if n_peaks > self.max_peaks_per_node:
+                        logger.warning(
+                            f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
+                            f"(max_peaks_per_node={self.max_peaks_per_node}). "
+                            f"Model may need more training."
+                        )
+                        skip_paf_scoring = True
+                        break
+                if skip_paf_scoring:
+                    break
+
+        if skip_paf_scoring:
+            # Return empty predictions for each sample
+            device = cms.device
+            n_nodes = cms.shape[1]
+            predicted_instances_adjusted = []
+            predicted_peak_scores = []
+            predicted_instance_scores = []
+            for _ in range(self.batch_size):
+                predicted_instances_adjusted.append(
+                    torch.full((0, n_nodes, 2), float("nan"), device=device)
+                )
+                predicted_peak_scores.append(
+                    torch.full((0, n_nodes), float("nan"), device=device)
+                )
+                predicted_instance_scores.append(torch.tensor([], device=device))
+            edge_inds = [
+                torch.tensor([], dtype=torch.int32, device=device)
+            ] * self.batch_size
+            edge_peak_inds = [
+                torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
+            ] * self.batch_size
+            line_scores = [torch.tensor([], device=device)] * self.batch_size
+        else:
+            (
+                predicted_instances,
+                predicted_peak_scores,
+                predicted_instance_scores,
+                edge_inds,
+                edge_peak_inds,
+                line_scores,
+            ) = self.paf_scorer.predict(
+                pafs=pafs,
+                peaks=cms_peaks,
+                peak_vals=cms_peak_vals,
+                peak_channel_inds=cms_peak_channel_inds,
             )
+
+            predicted_instances = [p / self.input_scale for p in predicted_instances]
+            predicted_instances_adjusted = []
+            for idx, p in enumerate(predicted_instances):
+                predicted_instances_adjusted.append(
+                    p / inputs["eff_scale"][idx].to(p.device)
+                )
+
         out = {
             "pred_instance_peaks": predicted_instances_adjusted,
             "pred_peak_values": predicted_peak_scores,
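
The `max_peaks_per_node` guard added above just counts peaks per node and bails out of PAF grouping when any count exceeds the limit. A standalone sketch of that check with made-up peak assignments:

```python
import torch

# Hypothetical peak-to-node assignments for one image: channel index per detected peak.
peak_channel_inds = torch.tensor([0, 0, 1, 1, 1, 2])
max_peaks_per_node = 2
n_nodes = 3

skip_paf_scoring = any(
    int((peak_channel_inds == node_idx).sum().item()) > max_peaks_per_node
    for node_idx in range(n_nodes)
)
# Node 1 has 3 peaks > 2, so PAF scoring would be skipped and empty predictions returned.
assert skip_paf_scoring is True
```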
sleap_nn/train.py
CHANGED
@@ -175,6 +175,7 @@ def train(
     wandb_save_viz_imgs_wandb: bool = False,
     wandb_resume_prv_runid: Optional[str] = None,
     wandb_group_name: Optional[str] = None,
+    wandb_delete_local_logs: Optional[bool] = None,
     optimizer: str = "Adam",
     learning_rate: float = 1e-3,
     amsgrad: bool = False,
@@ -353,6 +354,9 @@ def train(
         wandb_resume_prv_runid: Previous run ID if training should be resumed from a previous
             ckpt. Default: None
         wandb_group_name: Group name for the wandb run. Default: None.
+        wandb_delete_local_logs: If True, delete local wandb logs folder after training.
+            If False, keep the folder. If None (default), automatically delete if logging
+            online (wandb_mode != "offline") and keep if logging offline. Default: None.
         optimizer: Optimizer to be used. One of ["Adam", "AdamW"]. Default: "Adam".
         learning_rate: Learning rate of type float. Default: 1e-3.
         amsgrad: Enable AMSGrad with the optimizer. Default: False.
@@ -456,6 +460,7 @@ def train(
         wandb_save_viz_imgs_wandb=wandb_save_viz_imgs_wandb,
         wandb_resume_prv_runid=wandb_resume_prv_runid,
         wandb_group_name=wandb_group_name,
+        wandb_delete_local_logs=wandb_delete_local_logs,
         optimizer=optimizer,
         learning_rate=learning_rate,
         amsgrad=amsgrad,
sleap_nn/training/callbacks.py
CHANGED
@@ -662,3 +662,277 @@ class ProgressReporterZMQ(Callback):
         return {
             k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
         }
+
+
+class EpochEndEvaluationCallback(Callback):
+    """Callback to run full evaluation metrics at end of validation epochs.
+
+    This callback collects predictions and ground truth during validation,
+    then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
+    metrics to WandB.
+
+    Attributes:
+        skeleton: sio.Skeleton for creating instances.
+        videos: List of sio.Video objects.
+        eval_frequency: Run evaluation every N epochs (default: 1).
+        oks_stddev: OKS standard deviation (default: 0.025).
+        oks_scale: Optional OKS scale override.
+        metrics_to_log: List of metric keys to log.
+    """
+
+    def __init__(
+        self,
+        skeleton: "sio.Skeleton",
+        videos: list,
+        eval_frequency: int = 1,
+        oks_stddev: float = 0.025,
+        oks_scale: Optional[float] = None,
+        metrics_to_log: Optional[list] = None,
+    ):
+        """Initialize the callback.
+
+        Args:
+            skeleton: sio.Skeleton for creating instances.
+            videos: List of sio.Video objects.
+            eval_frequency: Run evaluation every N epochs (default: 1).
+            oks_stddev: OKS standard deviation (default: 0.025).
+            oks_scale: Optional OKS scale override.
+            metrics_to_log: List of metric keys to log. If None, logs all available.
+        """
+        super().__init__()
+        self.skeleton = skeleton
+        self.videos = videos
+        self.eval_frequency = eval_frequency
+        self.oks_stddev = oks_stddev
+        self.oks_scale = oks_scale
+        self.metrics_to_log = metrics_to_log or [
+            "mOKS",
+            "oks_voc.mAP",
+            "oks_voc.mAR",
+            "avg_distance",
+            "p50_distance",
+            "mPCK",
+            "visibility_precision",
+            "visibility_recall",
+        ]
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        """Enable prediction collection at the start of validation.
+
+        Skip during sanity check to avoid inference issues.
+        """
+        if trainer.sanity_checking:
+            return
+        pl_module._collect_val_predictions = True
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        """Run evaluation and log metrics at end of validation epoch."""
+        import sleap_io as sio
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+        from sleap_nn.evaluation import Evaluator
+
+        # Check frequency (epoch is 0-indexed, so add 1)
+        if (trainer.current_epoch + 1) % self.eval_frequency != 0:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Only run on rank 0 for distributed training
+        if not trainer.is_global_zero:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Check if we have predictions
+        if not pl_module.val_predictions or not pl_module.val_ground_truth:
+            logger.warning("No predictions collected for epoch-end evaluation")
+            pl_module._collect_val_predictions = False
+            return
+
+        try:
+            # Build sio.Labels from accumulated predictions and ground truth
+            pred_labels = self._build_pred_labels(pl_module.val_predictions, sio, np)
+            gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)
+
+            # Check if we have valid frames to evaluate
+            if len(pred_labels) == 0:
+                logger.warning(
+                    "No valid predictions for epoch-end evaluation "
+                    "(all predictions may be empty or NaN)"
+                )
+                pl_module._collect_val_predictions = False
+                pl_module.val_predictions = []
+                pl_module.val_ground_truth = []
+                return
+
+            # Run evaluation
+            evaluator = Evaluator(
+                ground_truth_instances=gt_labels,
+                predicted_instances=pred_labels,
+                oks_stddev=self.oks_stddev,
+                oks_scale=self.oks_scale,
+                user_labels_only=False,  # All validation frames are "user" frames
+            )
+            metrics = evaluator.evaluate()
+
+            # Log to WandB
+            self._log_metrics(trainer, metrics, trainer.current_epoch)
+
+            logger.info(
+                f"Epoch {trainer.current_epoch} evaluation: "
+                f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
+                f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
+            )
+
+        except Exception as e:
+            logger.warning(f"Epoch-end evaluation failed: {e}")
+
+        # Cleanup
+        pl_module._collect_val_predictions = False
+        pl_module.val_predictions = []
+        pl_module.val_ground_truth = []
+
+    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
+        """Convert prediction dicts to sio.Labels."""
+        labeled_frames = []
+        for pred in predictions:
+            pred_peaks = pred["pred_peaks"]
+            pred_scores = pred["pred_scores"]
+
+            # Handle NaN/missing predictions
+            if pred_peaks is None or (
+                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
+            ):
+                continue
+
+            # Handle multi-instance predictions (bottomup)
+            if len(pred_peaks.shape) == 2:
+                # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
+                pred_peaks = pred_peaks.reshape(1, -1, 2)
+                pred_scores = pred_scores.reshape(1, -1)
+
+            instances = []
+            for inst_idx in range(len(pred_peaks)):
+                inst_points = pred_peaks[inst_idx]
+                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None
+
+                # Skip if all NaN
+                if np.isnan(inst_points).all():
+                    continue
+
+                inst = sio.PredictedInstance.from_numpy(
+                    points_data=inst_points,
+                    skeleton=self.skeleton,
+                    point_scores=(
+                        inst_scores
+                        if inst_scores is not None
+                        else np.ones(len(inst_points))
+                    ),
+                    score=(
+                        float(np.nanmean(inst_scores))
+                        if inst_scores is not None
+                        else 1.0
+                    ),
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[pred["video_idx"]],
+                    frame_idx=pred["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
+        """Convert ground truth dicts to sio.Labels."""
+        labeled_frames = []
+        for gt in ground_truth:
+            instances = []
+            gt_instances = gt["gt_instances"]
+
+            # Handle shape variations
+            if len(gt_instances.shape) == 2:
+                # (n_nodes, 2) -> (1, n_nodes, 2)
+                gt_instances = gt_instances.reshape(1, -1, 2)
+
+            for i in range(min(gt["num_instances"], len(gt_instances))):
+                inst_data = gt_instances[i]
+                if np.isnan(inst_data).all():
+                    continue
+                inst = sio.Instance.from_numpy(
+                    points_data=inst_data,
+                    skeleton=self.skeleton,
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[gt["video_idx"]],
+                    frame_idx=gt["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _log_metrics(self, trainer, metrics: dict, epoch: int):
+        """Log evaluation metrics to WandB."""
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+
+        # Get WandB logger
+        wandb_logger = None
+        for log in trainer.loggers:
+            if isinstance(log, WandbLogger):
+                wandb_logger = log
+                break
+
+        if wandb_logger is None:
+            return
+
+        log_dict = {"epoch": epoch}
+
+        # Extract key metrics with consistent naming
+        if "mOKS" in self.metrics_to_log:
+            log_dict["val_mOKS"] = metrics["mOKS"]["mOKS"]
+
+        if "oks_voc.mAP" in self.metrics_to_log:
+            log_dict["val_oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
+
+        if "oks_voc.mAR" in self.metrics_to_log:
+            log_dict["val_oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
+
+        if "avg_distance" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["avg"]
+            if not np.isnan(val):
+                log_dict["val_avg_distance"] = val
+
+        if "p50_distance" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p50"]
+            if not np.isnan(val):
+                log_dict["val_p50_distance"] = val
+
+        if "mPCK" in self.metrics_to_log:
+            log_dict["val_mPCK"] = metrics["pck_metrics"]["mPCK"]
+
+        if "visibility_precision" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["precision"]
+            if not np.isnan(val):
+                log_dict["val_visibility_precision"] = val
+
+        if "visibility_recall" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["recall"]
+            if not np.isnan(val):
+                log_dict["val_visibility_recall"] = val
+
+        wandb_logger.experiment.log(log_dict, commit=False)
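
The callback above converts accumulated numpy arrays into sleap-io objects via `PredictedInstance.from_numpy`. A minimal sketch of that conversion, assuming sleap-io >= 0.6.2 as pinned by this release; the skeleton nodes and point values are made up:

```python
import numpy as np
import sleap_io as sio

skeleton = sio.Skeleton(["head", "thorax", "tail"])  # hypothetical node names
points = np.array([[10.0, 12.0], [20.5, 18.0], [np.nan, np.nan]])
scores = np.array([0.9, 0.8, 0.0])

inst = sio.PredictedInstance.from_numpy(
    points_data=points,
    skeleton=skeleton,
    point_scores=scores,
    score=float(np.nanmean(scores)),  # instance score as the mean point score
)
```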
sleap_nn/training/lightning_modules.py
CHANGED
@@ -1,6 +1,6 @@
 """This module has the LightningModule classes for all model types."""
 
-from typing import Optional, Union, Dict, Any
+from typing import Optional, Union, Dict, Any, List
 import time
 from torch import nn
 import numpy as np
@@ -184,6 +184,11 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}
 
+        # For epoch-end evaluation
+        self.val_predictions: List[Dict] = []
+        self.val_ground_truth: List[Dict] = []
+        self._collect_val_predictions: bool = False
+
         # Initialization for encoder and decoder stacks.
         if self.init_weights == "xavier":
             self.model.apply(xavier_init_weights)
@@ -331,6 +336,9 @@
     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
         self.val_start_time = time.time()
+        # Clear accumulated predictions for new epoch
+        self.val_predictions = []
+        self.val_ground_truth = []
 
     def on_validation_epoch_end(self):
         """Configure the val timer at the end of every epoch."""
@@ -639,6 +647,51 @@ class SingleInstanceLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
+                inference_batch = {k: v for k, v in batch.items()}
+                if inference_batch["image"].ndim == 5:
+                    inference_batch["image"] = inference_batch["image"].squeeze(1)
+                inference_output = self.single_instance_inf_layer(inference_batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original image space (inference divides by eff_scale)
+                pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Transform GT from preprocessed to original image space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance Model.
@@ -856,6 +909,62 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )
+
 
 class CentroidLightningModule(LightningModel):
     """Lightning Module for Centroid Model.
@@ -1034,6 +1143,57 @@ class CentroidLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                inference_output = self.centroid_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are in original image space (inference divides by eff_scale)
+                # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
+                pred_centroids = (
+                    inference_output["centroids"][i].squeeze(0).cpu().numpy()
+                )
+                pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
+
+                # Transform GT centroids from preprocessed to original image space
+                gt_centroids_prep = (
+                    batch["centroids"][i].cpu().numpy()
+                )  # (n_samples=1, max_inst, 2)
+                gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff  # (max_inst, 2)
+                num_inst = batch["num_instances"][i].item()
+
+                # Filter to valid instances (non-NaN)
+                valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
+                pred_centroids = pred_centroids[valid_pred_mask]
+                pred_vals = pred_vals[valid_pred_mask]
+
+                gt_centroids_valid = gt_centroids_orig[:num_inst]
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_centroids.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "pred_scores": pred_vals.reshape(-1, 1),  # (n_inst, 1)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_centroids_valid.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class BottomUpLightningModule(LightningModel):
     """Lightning Module for BottomUp Model.
@@ -1126,12 +1286,13 @@ class BottomUpLightningModule(LightningModel):
         self.bottomup_inf_layer = BottomUpInferenceModel(
             torch_model=self.forward,
             paf_scorer=paf_scorer,
-            peak_threshold=0.
+            peak_threshold=0.1,  # Lower threshold for epoch-end eval during training
             input_scale=1.0,
             return_confmaps=True,
             return_pafs=True,
            cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
             pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
+            max_peaks_per_node=100,  # Prevents combinatorial explosion in early training
         )
         self.node_names = list(self.head_configs.bottomup.confmaps.part_names)
 
@@ -1340,6 +1501,53 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class BottomUpMultiClassLightningModule(LightningModel):
     """Lightning Module for BottomUp ID Model.
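
The top-down collection block above maps crop-relative predictions back to full-image coordinates by dividing the saved bbox corner by `eff_scale` and adding it as an offset. A toy numeric check of that transform (all numbers invented):

```python
import numpy as np

eff_scale = 0.5                                  # preprocessing scale factor
bbox_top_left_prep = np.array([100.0, 40.0])     # crop origin in preprocessed space
pred_peaks_crop = np.array([[12.0, 8.0]])        # prediction relative to the crop

bbox_top_left_orig = bbox_top_left_prep / eff_scale    # -> [200., 80.]
pred_peaks_full = pred_peaks_crop + bbox_top_left_orig  # -> [[212., 88.]]
```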
sleap_nn/training/model_trainer.py
CHANGED
@@ -61,6 +61,7 @@ from sleap_nn.training.callbacks import (
     WandBVizCallbackWithPAFs,
     CSVLoggerCallback,
     SleapProgressBar,
+    EpochEndEvaluationCallback,
 )
 from sleap_nn import RANK
 from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -898,6 +899,17 @@ class ModelTrainer:
             )
             loggers.append(wandb_logger)
 
+            # Log message about wandb local logs cleanup
+            should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+                wandb_config.delete_local_logs is None
+                and wandb_config.wandb_mode != "offline"
+            )
+            if should_delete_wandb_logs:
+                logger.info(
+                    "WandB local logs will be deleted after training completes. "
+                    "To keep logs, set trainer_config.wandb.delete_local_logs=false"
+                )
+
             # Learning rate monitor callback - logs LR at each step for dynamic schedulers
             # Only added when wandb is enabled since it requires a logger
             callbacks.append(LearningRateMonitor(logging_interval="step"))
@@ -1075,6 +1087,18 @@ class ModelTrainer:
         if self.config.trainer_config.enable_progress_bar:
             callbacks.append(SleapProgressBar())
 
+        # Add epoch-end evaluation callback if enabled
+        if self.config.trainer_config.eval.enabled:
+            callbacks.append(
+                EpochEndEvaluationCallback(
+                    skeleton=self.skeletons[0],
+                    videos=self.val_labels[0].videos,
+                    eval_frequency=self.config.trainer_config.eval.frequency,
+                    oks_stddev=self.config.trainer_config.eval.oks_stddev,
+                    oks_scale=self.config.trainer_config.eval.oks_scale,
+                )
+            )
+
         return loggers, callbacks
 
     def _delete_cache_imgs(self):
@@ -1270,6 +1294,16 @@ class ModelTrainer:
         wandb.define_metric("train_pafs*", step_metric="epoch")
         wandb.define_metric("val_pafs*", step_metric="epoch")
 
+        # Evaluation metrics use epoch as x-axis
+        wandb.define_metric("val_mOKS", step_metric="epoch")
+        wandb.define_metric("val_oks_voc_mAP", step_metric="epoch")
+        wandb.define_metric("val_oks_voc_mAR", step_metric="epoch")
+        wandb.define_metric("val_avg_distance", step_metric="epoch")
+        wandb.define_metric("val_p50_distance", step_metric="epoch")
+        wandb.define_metric("val_mPCK", step_metric="epoch")
+        wandb.define_metric("val_visibility_precision", step_metric="epoch")
+        wandb.define_metric("val_visibility_recall", step_metric="epoch")
+
         self.config.trainer_config.wandb.current_run_id = wandb.run.id
         wandb.config["run_name"] = self.config.trainer_config.wandb.name
         wandb.config["run_config"] = OmegaConf.to_container(
@@ -1314,6 +1348,25 @@ class ModelTrainer:
         if self.trainer.global_rank == 0 and self.config.trainer_config.use_wandb:
             wandb.finish()
 
+            # Delete local wandb logs if configured
+            wandb_config = self.config.trainer_config.wandb
+            should_delete_wandb_logs = wandb_config.delete_local_logs is True or (
+                wandb_config.delete_local_logs is None
+                and wandb_config.wandb_mode != "offline"
+            )
+            if should_delete_wandb_logs:
+                wandb_dir = (
+                    Path(self.config.trainer_config.ckpt_dir)
+                    / self.config.trainer_config.run_name
+                    / "wandb"
+                )
+                if wandb_dir.exists():
+                    logger.info(
+                        f"Deleting local wandb logs at {wandb_dir}... "
+                        "(set trainer_config.wandb.delete_local_logs=false to disable)"
+                    )
+                    shutil.rmtree(wandb_dir, ignore_errors=True)
+
         # delete image disk caching
         if (
             self.config.data_config.data_pipeline_fw
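
The wandb cleanup added above follows one rule in two places (logger setup and teardown). A standalone restatement of that rule; the helper name is illustrative, not part of the package:

```python
def should_delete_wandb_logs(delete_local_logs, wandb_mode):
    """Explicit True/False wins; None follows the wandb mode (delete only when online)."""
    return delete_local_logs is True or (
        delete_local_logs is None and wandb_mode != "offline"
    )

assert should_delete_wandb_logs(None, "online") is True    # unset + online  -> delete
assert should_delete_wandb_logs(None, "offline") is False  # unset + offline -> keep
assert should_delete_wandb_logs(False, "online") is False  # explicit keep always wins
```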
{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.1.
+Version: 0.1.0a2
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
 Requires-Python: <3.14,>=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: sleap-io<0.7.0,>=0.6.
+Requires-Dist: sleap-io<0.7.0,>=0.6.2
 Requires-Dist: numpy
 Requires-Dist: lightning
 Requires-Dist: kornia
{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
 sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
-sleap_nn/__init__.py,sha256=
+sleap_nn/__init__.py,sha256=s3sIImYR5tiP-PfftEj7J8P1Au2nRXj4XWowznrVwm8,1362
 sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
-sleap_nn/evaluation.py,sha256=
+sleap_nn/evaluation.py,sha256=sKwLnHbCcaNzPs7CJtgRmFcDRFwPMjCxB92viZvinVI,33498
 sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
 sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
 sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
-sleap_nn/train.py,sha256=
+sleap_nn/train.py,sha256=XvVhzMXL9rNQLx1-6jIcp5BAO1pR7AZjdphMn5ZX-_I,27558
 sleap_nn/architectures/__init__.py,sha256=w0XxQcx-CYyooszzvxRkKWiJkUg-26IlwQoGna8gn40,46
 sleap_nn/architectures/common.py,sha256=MLv-zdHsWL5Q2ct_Wv6SQbRS-5hrFtjK_pvBEfwx-vU,3660
 sleap_nn/architectures/convnext.py,sha256=l9lMJDxIMb-9MI3ShOtVwbOUMuwOLtSQlxiVyYHqjvE,13953
@@ -17,15 +17,15 @@ sleap_nn/architectures/unet.py,sha256=rAy2Omi6tv1MNW2nBn0Tw-94Nw_-1wFfCT3-IUyPcg
 sleap_nn/architectures/utils.py,sha256=L0KVs0gbtG8U75Sl40oH_r_w2ySawh3oQPqIGi54HGo,2171
 sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,42
 sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
-sleap_nn/config/get_config.py,sha256=
+sleap_nn/config/get_config.py,sha256=rjNUffKU9z-ohLwrOVmJNGCqwUM93eh68h4KJfrSy8Y,42396
 sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
-sleap_nn/config/trainer_config.py,sha256=
+sleap_nn/config/trainer_config.py,sha256=Ob2UqU10DXsQOnDb0iJxy0qc82CfP6FkQZQkrCvTEEY,29120
 sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
 sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
 sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
 sleap_nn/data/augmentation.py,sha256=Kqw_DayPth_DBsmaO1G8Voou_-cYZuSPOjSQWSajgRI,13618
 sleap_nn/data/confidence_maps.py,sha256=PTRqZWSAz1S7viJhxu7QgIC1aHiek97c_dCUsKUwG1o,6217
-sleap_nn/data/custom_datasets.py,sha256=
+sleap_nn/data/custom_datasets.py,sha256=SO-aNB1-bB9DL5Zw-oGYDsliBxwI4iKX_FmwgZjKOgQ,99975
 sleap_nn/data/edge_maps.py,sha256=75qG_7zHRw7fC8JUCVI2tzYakIoxxneWWmcrTwjcHPo,12519
 sleap_nn/data/identity.py,sha256=7vNup6PudST4yDLyDT9wDO-cunRirTEvx4sP77xrlfk,5193
 sleap_nn/data/instance_centroids.py,sha256=SF-3zJt_VMTbZI5ssbrvmZQZDd3684bn55EAtvcbQ6o,2172
@@ -35,7 +35,7 @@ sleap_nn/data/providers.py,sha256=0x6GFP1s1c08ji4p0M5V6p-dhT4Z9c-SI_Aw1DWX-uM,14
 sleap_nn/data/resizing.py,sha256=YFpSQduIBkRK39FYmrqDL-v8zMySlEs6TJxh6zb_0ZU,5076
 sleap_nn/data/utils.py,sha256=rT0w7KMOTlzaeKWq1TqjbgC4Lvjz_G96McllvEOqXx8,5641
 sleap_nn/inference/__init__.py,sha256=eVkCmKrxHlDFJIlZTf8B5XEOcSyw-gPQymXMY5uShOM,170
-sleap_nn/inference/bottomup.py,sha256=
+sleap_nn/inference/bottomup.py,sha256=3s90aRlpIcRnSNe-R5-qiuX3S48kCWMpCl8YuNnTEDI,17084
 sleap_nn/inference/identity.py,sha256=GjNDL9MfGqNyQaK4AE8JQCAE8gpMuE_Y-3r3Gpa53CE,6540
 sleap_nn/inference/paf_grouping.py,sha256=7Fo9lCAj-zcHgv5rI5LIMYGcixCGNt_ZbSNs8Dik7l8,69973
 sleap_nn/inference/peak_finding.py,sha256=L9LdYKt_Bfw7cxo6xEpgF8wXcZAwq5plCfmKJ839N40,13014
@@ -52,14 +52,14 @@ sleap_nn/tracking/candidates/__init__.py,sha256=1O7NObIwshM7j1rLHmImbFphvkM9wY1j
 sleap_nn/tracking/candidates/fixed_window.py,sha256=D80KMlTnenuQveQVVhk9j0G8yx6K324C7nMLHgG76e0,6296
 sleap_nn/tracking/candidates/local_queues.py,sha256=Nx3R5wwEwq0gbfH-fi3oOumfkQo8_sYe5GN47pD9Be8,7305
 sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U,32
-sleap_nn/training/callbacks.py,sha256=
-sleap_nn/training/lightning_modules.py,sha256=
+sleap_nn/training/callbacks.py,sha256=ZO88NFGZi53Wn4qM6yp3Bk3HFmhkYSGqeMc1QJKirLo,35995
+sleap_nn/training/lightning_modules.py,sha256=slkVtQ7r6LatWLYzxcq6x1RALYNyHTRcqiXXwD-x0PA,95420
 sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
-sleap_nn/training/model_trainer.py,sha256=
+sleap_nn/training/model_trainer.py,sha256=mf6FOdGDal2mMP0F1xD9jVQ54wbUST0ovRt6OjXzVyg,60580
 sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
+sleap_nn-0.1.0a2.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+sleap_nn-0.1.0a2.dist-info/METADATA,sha256=w0dUxvJerGIpu4hlYgGbimjCAooCcf_4NcAzo8T5Sos,5637
+sleap_nn-0.1.0a2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sleap_nn-0.1.0a2.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
+sleap_nn-0.1.0a2.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
+sleap_nn-0.1.0a2.dist-info/RECORD,,

{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/WHEEL
File without changes

{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/entry_points.txt
File without changes

{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/licenses/LICENSE
File without changes

{sleap_nn-0.1.0a0.dist-info → sleap_nn-0.1.0a2.dist-info}/top_level.txt
File without changes