sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/config/trainer_config.py +18 -0
- sleap_nn/evaluation.py +73 -22
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/training/callbacks.py +274 -0
- sleap_nn/training/lightning_modules.py +210 -2
- sleap_nn/training/model_trainer.py +23 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/METADATA +2 -2
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/RECORD +13 -13
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/top_level.txt +0 -0
sleap_nn/config/trainer_config.py
CHANGED
@@ -208,6 +208,23 @@ class EarlyStoppingConfig:
     stop_training_on_plateau: bool = True
 
 
+@define
+class EvalConfig:
+    """Configuration for epoch-end evaluation.
+
+    Attributes:
+        enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
+        frequency: (int) Evaluate every N epochs. *Default*: `1`.
+        oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
+        oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
+    """
+
+    enabled: bool = False
+    frequency: int = field(default=1, validator=validators.ge(1))
+    oks_stddev: float = field(default=0.025, validator=validators.gt(0))
+    oks_scale: Optional[float] = None
+
+
 @define
 class HardKeypointMiningConfig:
     """Configuration for online hard keypoint mining.
@@ -310,6 +327,7 @@ class TrainerConfig:
         factory=HardKeypointMiningConfig
     )
     zmq: Optional[ZMQConfig] = field(factory=ZMQConfig)  # Required for SLEAP GUI
+    eval: EvalConfig = field(factory=EvalConfig)  # Epoch-end evaluation config
 
     @staticmethod
     def validate_optimizer_name(value):
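The new `EvalConfig` is wired into `TrainerConfig` as `trainer_config.eval` and read by `ModelTrainer` (see sleap_nn/training/model_trainer.py below). A minimal sketch of enabling it, assuming only the attrs-based config objects added in this diff:

import sys
from sleap_nn.config.trainer_config import EvalConfig

# Run the epoch-end evaluation every 5 epochs with a custom OKS sigma.
# frequency must be >= 1 and oks_stddev > 0 per the validators above.
eval_cfg = EvalConfig(enabled=True, frequency=5, oks_stddev=0.05)
print(eval_cfg, file=sys.stderr)

ModelTrainer then reads `config.trainer_config.eval.enabled`, `.frequency`, `.oks_stddev`, and `.oks_scale` when assembling callbacks.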
sleap_nn/evaluation.py
CHANGED
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
     """
     instance_list = []
     frame_idx = labeled_frame.frame_idx
-
-
-
-
-
+
+    # Extract video path with fallbacks for embedded videos
+    video = labeled_frame.video
+    video_path = None
+    if video is not None:
+        backend = getattr(video, "backend", None)
+        if backend is not None:
+            # Try source_filename first (for embedded videos with provenance)
+            video_path = getattr(backend, "source_filename", None)
+            if video_path is None:
+                video_path = getattr(backend, "filename", None)
+        # Fallback to video.filename if backend doesn't have it
+        if video_path is None:
+            video_path = getattr(video, "filename", None)
+    # Handle list filenames (image sequences)
+    if isinstance(video_path, list) and video_path:
+        video_path = video_path[0]
+    # Final fallback: use a unique identifier
+    if video_path is None:
+        video_path = f"video_{id(video)}" if video is not None else "unknown"
+
     for instance in labeled_frame.instances:
         match_instance = MatchInstance(
             instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -47,6 +63,10 @@ find_frame_pairs(
 ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
     """Find corresponding frames across two sets of labels.
 
+    This function uses sleap-io's robust video matching API to handle various
+    scenarios including embedded videos, cross-platform paths, and videos with
+    different metadata.
+
     Args:
         labels_gt: A `sio.Labels` instance with ground truth instances.
         labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,16 +76,15 @@ find_frame_pairs(
     Returns:
         A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
     """
+    # Use sleap-io's robust video matching API (added in 0.6.2)
+    # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
+    match_result = labels_gt.match(labels_pr)
+
     frame_pairs = []
-
-
-
-
-            if video_gt.matches_content(video) and video_gt.matches_path(video):
-                video_pr = video
-                break
-
-        if video_pr is None:
+    # Iterate over matched video pairs (pred_video -> gt_video mapping)
+    for video_pr, video_gt in match_result.video_map.items():
+        if video_gt is None:
+            # No match found for this prediction video
             continue
 
         # Find labeled frames in this video.
@@ -786,11 +805,26 @@ def run_evaluation(
     """Evaluate SLEAP-NN model predictions against ground truth labels."""
     logger.info("Loading ground truth labels...")
     ground_truth_instances = sio.load_slp(ground_truth_path)
+    logger.info(
+        f" Ground truth: {len(ground_truth_instances.videos)} videos, "
+        f"{len(ground_truth_instances.labeled_frames)} frames"
+    )
 
     logger.info("Loading predicted labels...")
     predicted_instances = sio.load_slp(predicted_path)
+    logger.info(
+        f" Predictions: {len(predicted_instances.videos)} videos, "
+        f"{len(predicted_instances.labeled_frames)} frames"
+    )
+
+    logger.info("Matching videos and frames...")
+    # Get match stats before creating evaluator
+    match_result = ground_truth_instances.match(predicted_instances)
+    logger.info(
+        f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
+    )
 
-    logger.info("
+    logger.info("Matching instances...")
     evaluator = Evaluator(
         ground_truth_instances=ground_truth_instances,
         predicted_instances=predicted_instances,
@@ -799,21 +833,38 @@
         match_threshold=match_threshold,
         user_labels_only=user_labels_only,
     )
+    logger.info(
+        f" Frame pairs: {len(evaluator.frame_pairs)}, "
+        f"Matched instances: {len(evaluator.positive_pairs)}, "
+        f"Unmatched GT: {len(evaluator.false_negatives)}"
+    )
 
     logger.info("Computing evaluation metrics...")
     metrics = evaluator.evaluate()
 
+    # Compute PCK at specific thresholds (5 and 10 pixels)
+    dists = metrics["distance_metrics"]["dists"]
+    dists_clean = np.copy(dists)
+    dists_clean[np.isnan(dists_clean)] = np.inf
+    pck_5 = (dists_clean < 5).mean()
+    pck_10 = (dists_clean < 10).mean()
+
     # Print key metrics
     logger.info("Evaluation Results:")
-    logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
-    logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
-    logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
-    logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.
-    logger.info(f"
+    logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
+    logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+    logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+    logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
+    logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
+    logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
+    logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
+    logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
+    logger.info(f" PCK@5px: {pck_5:.4f}")
+    logger.info(f" PCK@10px: {pck_10:.4f}")
     logger.info(
-        f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
+        f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
     )
-    logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
+    logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
 
     # Save metrics if path provided
     if save_metrics:
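The PCK block above treats NaN distances (unmatched or invisible points) as incorrect by replacing them with infinity before thresholding. A standalone numpy sketch of the same computation, with a hypothetical `dists` array standing in for `metrics["distance_metrics"]["dists"]`:

import numpy as np

# Distances for 3 matched instances x 2 nodes; NaN = point not matched.
dists = np.array([[1.2, 4.8], [np.nan, 9.5], [2.0, 11.0]])
dists_clean = np.copy(dists)
dists_clean[np.isnan(dists_clean)] = np.inf  # NaNs can never pass a threshold
pck_5 = (dists_clean < 5).mean()    # 3/6 = 0.5
pck_10 = (dists_clean < 10).mean()  # 4/6 ~= 0.667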
sleap_nn/inference/bottomup.py
CHANGED
@@ -1,5 +1,6 @@
 """Inference modules for BottomUp models."""
 
+import logging
 from typing import Dict, Optional
 import torch
 import lightning as L
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.inference.identity import classify_peaks_from_maps
 
+logger = logging.getLogger(__name__)
+
 
 class BottomUpInferenceModel(L.LightningModule):
     """BottomUp Inference model.
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
         return_pafs: Optional[bool] = False,
         return_paf_graph: Optional[bool] = False,
         input_scale: float = 1.0,
+        max_peaks_per_node: Optional[int] = None,
     ):
-        """Initialise the model attributes.
+        """Initialise the model attributes.
+
+        Args:
+            torch_model: A `nn.Module` that accepts images and predicts confidence maps.
+            paf_scorer: A `PAFScorer` instance for grouping instances.
+            cms_output_stride: Output stride of confidence maps relative to images.
+            pafs_output_stride: Output stride of PAFs relative to images.
+            peak_threshold: Minimum confidence map value for valid peaks.
+            refinement: Peak refinement method: None, "integral", or "local".
+            integral_patch_size: Size of patches for integral refinement.
+            return_confmaps: If True, return confidence maps in output.
+            return_pafs: If True, return PAFs in output.
+            return_paf_graph: If True, return intermediate PAF graph in output.
+            input_scale: Scale factor applied to input images.
+            max_peaks_per_node: Maximum number of peaks allowed per node before
+                skipping PAF scoring. If any node has more peaks than this limit,
+                empty predictions are returned. This prevents combinatorial explosion
+                during early training when confidence maps are noisy. Set to None to
+                disable this check (default). Recommended value: 100.
+        """
         super().__init__()
         self.torch_model = torch_model
         self.paf_scorer = paf_scorer
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
         self.return_pafs = return_pafs
         self.return_paf_graph = return_paf_graph
         self.input_scale = input_scale
+        self.max_peaks_per_node = max_peaks_per_node
 
     def _generate_cms_peaks(self, cms):
         # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -124,26 +148,68 @@ class BottomUpInferenceModel(L.LightningModule):
         )  # (batch, h, w, 2*edges)
         cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
 
-        (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Check if too many peaks per node (prevents combinatorial explosion)
+        skip_paf_scoring = False
+        if self.max_peaks_per_node is not None:
+            n_nodes = cms.shape[1]
+            for b in range(self.batch_size):
+                for node_idx in range(n_nodes):
+                    n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
+                    if n_peaks > self.max_peaks_per_node:
+                        logger.warning(
+                            f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
+                            f"(max_peaks_per_node={self.max_peaks_per_node}). "
+                            f"Model may need more training."
+                        )
+                        skip_paf_scoring = True
+                        break
+                if skip_paf_scoring:
+                    break
+
+        if skip_paf_scoring:
+            # Return empty predictions for each sample
+            device = cms.device
+            n_nodes = cms.shape[1]
+            predicted_instances_adjusted = []
+            predicted_peak_scores = []
+            predicted_instance_scores = []
+            for _ in range(self.batch_size):
+                predicted_instances_adjusted.append(
+                    torch.full((0, n_nodes, 2), float("nan"), device=device)
+                )
+                predicted_peak_scores.append(
+                    torch.full((0, n_nodes), float("nan"), device=device)
+                )
+                predicted_instance_scores.append(torch.tensor([], device=device))
+            edge_inds = [
+                torch.tensor([], dtype=torch.int32, device=device)
+            ] * self.batch_size
+            edge_peak_inds = [
+                torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
+            ] * self.batch_size
+            line_scores = [torch.tensor([], device=device)] * self.batch_size
+        else:
+            (
+                predicted_instances,
+                predicted_peak_scores,
+                predicted_instance_scores,
+                edge_inds,
+                edge_peak_inds,
+                line_scores,
+            ) = self.paf_scorer.predict(
+                pafs=pafs,
+                peaks=cms_peaks,
+                peak_vals=cms_peak_vals,
+                peak_channel_inds=cms_peak_channel_inds,
             )
+
+            predicted_instances = [p / self.input_scale for p in predicted_instances]
+            predicted_instances_adjusted = []
+            for idx, p in enumerate(predicted_instances):
+                predicted_instances_adjusted.append(
+                    p / inputs["eff_scale"][idx].to(p.device)
+                )
 
         out = {
             "pred_instance_peaks": predicted_instances_adjusted,
             "pred_peak_values": predicted_peak_scores,
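The per-node peak guard above counts peaks with nested Python loops over samples and nodes. The same early-exit decision can be computed with one `torch.bincount` per sample; this is an illustrative sketch with a hypothetical helper, not the package's code:

import torch

def too_many_peaks(cms_peak_channel_inds, n_nodes, max_peaks_per_node=100):
    """Return True if any node in any sample exceeds the peak limit."""
    for inds in cms_peak_channel_inds:  # one 1-D tensor of node indices per sample
        counts = torch.bincount(inds.long(), minlength=n_nodes)
        if int(counts.max()) > max_peaks_per_node:
            return True
    return False

When the limit is exceeded, the model returns zero-instance tensors rather than attempting PAF grouping on a combinatorially large candidate set.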
sleap_nn/training/callbacks.py
CHANGED
@@ -662,3 +662,277 @@ class ProgressReporterZMQ(Callback):
         return {
             k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
         }
+
+
+class EpochEndEvaluationCallback(Callback):
+    """Callback to run full evaluation metrics at end of validation epochs.
+
+    This callback collects predictions and ground truth during validation,
+    then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
+    metrics to WandB.
+
+    Attributes:
+        skeleton: sio.Skeleton for creating instances.
+        videos: List of sio.Video objects.
+        eval_frequency: Run evaluation every N epochs (default: 1).
+        oks_stddev: OKS standard deviation (default: 0.025).
+        oks_scale: Optional OKS scale override.
+        metrics_to_log: List of metric keys to log.
+    """
+
+    def __init__(
+        self,
+        skeleton: "sio.Skeleton",
+        videos: list,
+        eval_frequency: int = 1,
+        oks_stddev: float = 0.025,
+        oks_scale: Optional[float] = None,
+        metrics_to_log: Optional[list] = None,
+    ):
+        """Initialize the callback.
+
+        Args:
+            skeleton: sio.Skeleton for creating instances.
+            videos: List of sio.Video objects.
+            eval_frequency: Run evaluation every N epochs (default: 1).
+            oks_stddev: OKS standard deviation (default: 0.025).
+            oks_scale: Optional OKS scale override.
+            metrics_to_log: List of metric keys to log. If None, logs all available.
+        """
+        super().__init__()
+        self.skeleton = skeleton
+        self.videos = videos
+        self.eval_frequency = eval_frequency
+        self.oks_stddev = oks_stddev
+        self.oks_scale = oks_scale
+        self.metrics_to_log = metrics_to_log or [
+            "mOKS",
+            "oks_voc.mAP",
+            "oks_voc.mAR",
+            "avg_distance",
+            "p50_distance",
+            "mPCK",
+            "visibility_precision",
+            "visibility_recall",
+        ]
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        """Enable prediction collection at the start of validation.
+
+        Skip during sanity check to avoid inference issues.
+        """
+        if trainer.sanity_checking:
+            return
+        pl_module._collect_val_predictions = True
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        """Run evaluation and log metrics at end of validation epoch."""
+        import sleap_io as sio
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+        from sleap_nn.evaluation import Evaluator
+
+        # Check frequency (epoch is 0-indexed, so add 1)
+        if (trainer.current_epoch + 1) % self.eval_frequency != 0:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Only run on rank 0 for distributed training
+        if not trainer.is_global_zero:
+            pl_module._collect_val_predictions = False
+            return
+
+        # Check if we have predictions
+        if not pl_module.val_predictions or not pl_module.val_ground_truth:
+            logger.warning("No predictions collected for epoch-end evaluation")
+            pl_module._collect_val_predictions = False
+            return
+
+        try:
+            # Build sio.Labels from accumulated predictions and ground truth
+            pred_labels = self._build_pred_labels(pl_module.val_predictions, sio, np)
+            gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)
+
+            # Check if we have valid frames to evaluate
+            if len(pred_labels) == 0:
+                logger.warning(
+                    "No valid predictions for epoch-end evaluation "
+                    "(all predictions may be empty or NaN)"
+                )
+                pl_module._collect_val_predictions = False
+                pl_module.val_predictions = []
+                pl_module.val_ground_truth = []
+                return
+
+            # Run evaluation
+            evaluator = Evaluator(
+                ground_truth_instances=gt_labels,
+                predicted_instances=pred_labels,
+                oks_stddev=self.oks_stddev,
+                oks_scale=self.oks_scale,
+                user_labels_only=False,  # All validation frames are "user" frames
+            )
+            metrics = evaluator.evaluate()
+
+            # Log to WandB
+            self._log_metrics(trainer, metrics, trainer.current_epoch)
+
+            logger.info(
+                f"Epoch {trainer.current_epoch} evaluation: "
+                f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
+                f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
+            )
+
+        except Exception as e:
+            logger.warning(f"Epoch-end evaluation failed: {e}")
+
+        # Cleanup
+        pl_module._collect_val_predictions = False
+        pl_module.val_predictions = []
+        pl_module.val_ground_truth = []
+
+    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
+        """Convert prediction dicts to sio.Labels."""
+        labeled_frames = []
+        for pred in predictions:
+            pred_peaks = pred["pred_peaks"]
+            pred_scores = pred["pred_scores"]
+
+            # Handle NaN/missing predictions
+            if pred_peaks is None or (
+                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
+            ):
+                continue
+
+            # Handle multi-instance predictions (bottomup)
+            if len(pred_peaks.shape) == 2:
+                # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
+                pred_peaks = pred_peaks.reshape(1, -1, 2)
+                pred_scores = pred_scores.reshape(1, -1)
+
+            instances = []
+            for inst_idx in range(len(pred_peaks)):
+                inst_points = pred_peaks[inst_idx]
+                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None
+
+                # Skip if all NaN
+                if np.isnan(inst_points).all():
+                    continue
+
+                inst = sio.PredictedInstance.from_numpy(
+                    points_data=inst_points,
+                    skeleton=self.skeleton,
+                    point_scores=(
+                        inst_scores
+                        if inst_scores is not None
+                        else np.ones(len(inst_points))
+                    ),
+                    score=(
+                        float(np.nanmean(inst_scores))
+                        if inst_scores is not None
+                        else 1.0
+                    ),
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[pred["video_idx"]],
+                    frame_idx=pred["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
+        """Convert ground truth dicts to sio.Labels."""
+        labeled_frames = []
+        for gt in ground_truth:
+            instances = []
+            gt_instances = gt["gt_instances"]
+
+            # Handle shape variations
+            if len(gt_instances.shape) == 2:
+                # (n_nodes, 2) -> (1, n_nodes, 2)
+                gt_instances = gt_instances.reshape(1, -1, 2)
+
+            for i in range(min(gt["num_instances"], len(gt_instances))):
+                inst_data = gt_instances[i]
+                if np.isnan(inst_data).all():
+                    continue
+                inst = sio.Instance.from_numpy(
+                    points_data=inst_data,
+                    skeleton=self.skeleton,
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[gt["video_idx"]],
+                    frame_idx=gt["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _log_metrics(self, trainer, metrics: dict, epoch: int):
+        """Log evaluation metrics to WandB."""
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+
+        # Get WandB logger
+        wandb_logger = None
+        for log in trainer.loggers:
+            if isinstance(log, WandbLogger):
+                wandb_logger = log
+                break
+
+        if wandb_logger is None:
+            return
+
+        log_dict = {"epoch": epoch}
+
+        # Extract key metrics with consistent naming
+        if "mOKS" in self.metrics_to_log:
+            log_dict["val_mOKS"] = metrics["mOKS"]["mOKS"]
+
+        if "oks_voc.mAP" in self.metrics_to_log:
+            log_dict["val_oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
+
+        if "oks_voc.mAR" in self.metrics_to_log:
+            log_dict["val_oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
+
+        if "avg_distance" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["avg"]
+            if not np.isnan(val):
+                log_dict["val_avg_distance"] = val
+
+        if "p50_distance" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p50"]
+            if not np.isnan(val):
+                log_dict["val_p50_distance"] = val
+
+        if "mPCK" in self.metrics_to_log:
+            log_dict["val_mPCK"] = metrics["pck_metrics"]["mPCK"]
+
+        if "visibility_precision" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["precision"]
+            if not np.isnan(val):
+                log_dict["val_visibility_precision"] = val
+
+        if "visibility_recall" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["recall"]
+            if not np.isnan(val):
+                log_dict["val_visibility_recall"] = val
+
+        wandb_logger.experiment.log(log_dict, commit=False)
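The callback consumes per-frame dicts accumulated on the LightningModule during validation (the collection code is in sleap_nn/training/lightning_modules.py below). For orientation, one entry of each list looks roughly like this; the shapes are illustrative, with coordinates in original image space:

import numpy as np

val_prediction_entry = {
    "video_idx": 0,    # index into the callback's `videos` list
    "frame_idx": 42,   # frame index within that video
    "pred_peaks": np.zeros((1, 5, 2)),  # (n_instances, n_nodes, 2) xy
    "pred_scores": np.ones((1, 5)),     # (n_instances, n_nodes)
}

val_ground_truth_entry = {
    "video_idx": 0,
    "frame_idx": 42,
    "gt_instances": np.zeros((1, 5, 2)),  # (n_instances, n_nodes, 2) xy
    "num_instances": 1,
}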
sleap_nn/training/lightning_modules.py
CHANGED
@@ -1,6 +1,6 @@
 """This module has the LightningModule classes for all model types."""
 
-from typing import Optional, Union, Dict, Any
+from typing import Optional, Union, Dict, Any, List
 import time
 from torch import nn
 import numpy as np
@@ -184,6 +184,11 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}
 
+        # For epoch-end evaluation
+        self.val_predictions: List[Dict] = []
+        self.val_ground_truth: List[Dict] = []
+        self._collect_val_predictions: bool = False
+
         # Initialization for encoder and decoder stacks.
         if self.init_weights == "xavier":
             self.model.apply(xavier_init_weights)
@@ -331,6 +336,9 @@ class LightningModel(L.LightningModule):
     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
         self.val_start_time = time.time()
+        # Clear accumulated predictions for new epoch
+        self.val_predictions = []
+        self.val_ground_truth = []
 
     def on_validation_epoch_end(self):
         """Configure the val timer at the end of every epoch."""
@@ -639,6 +647,51 @@ class SingleInstanceLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
+                inference_batch = {k: v for k, v in batch.items()}
+                if inference_batch["image"].ndim == 5:
+                    inference_batch["image"] = inference_batch["image"].squeeze(1)
+                inference_output = self.single_instance_inf_layer(inference_batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original image space (inference divides by eff_scale)
+                pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Transform GT from preprocessed to original image space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance Model.
@@ -856,6 +909,62 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )
+
 
 class CentroidLightningModule(LightningModel):
     """Lightning Module for Centroid Model.
@@ -1034,6 +1143,57 @@ class CentroidLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                inference_output = self.centroid_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are in original image space (inference divides by eff_scale)
+                # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
+                pred_centroids = (
+                    inference_output["centroids"][i].squeeze(0).cpu().numpy()
+                )
+                pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
+
+                # Transform GT centroids from preprocessed to original image space
+                gt_centroids_prep = (
+                    batch["centroids"][i].cpu().numpy()
+                )  # (n_samples=1, max_inst, 2)
+                gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff  # (max_inst, 2)
+                num_inst = batch["num_instances"][i].item()
+
+                # Filter to valid instances (non-NaN)
+                valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
+                pred_centroids = pred_centroids[valid_pred_mask]
+                pred_vals = pred_vals[valid_pred_mask]
+
+                gt_centroids_valid = gt_centroids_orig[:num_inst]
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_centroids.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "pred_scores": pred_vals.reshape(-1, 1),  # (n_inst, 1)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_centroids_valid.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class BottomUpLightningModule(LightningModel):
     """Lightning Module for BottomUp Model.
@@ -1126,12 +1286,13 @@ class BottomUpLightningModule(LightningModel):
         self.bottomup_inf_layer = BottomUpInferenceModel(
             torch_model=self.forward,
             paf_scorer=paf_scorer,
-            peak_threshold=0.
+            peak_threshold=0.1,  # Lower threshold for epoch-end eval during training
            input_scale=1.0,
             return_confmaps=True,
             return_pafs=True,
             cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
             pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
+            max_peaks_per_node=100,  # Prevents combinatorial explosion in early training
         )
         self.node_names = list(self.head_configs.bottomup.confmaps.part_names)
 
@@ -1340,6 +1501,53 @@
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class BottomUpMultiClassLightningModule(LightningModel):
     """Lightning Module for BottomUp ID Model.
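The top-down path above maps crop-relative predictions back to full-image coordinates: divide the saved crop corner by the effective preprocessing scale, then add it to the crop-relative peaks (which inference already returns at original scale). A minimal numpy sketch of the arithmetic with made-up values:

import numpy as np

eff_scale = 0.5  # preprocessing downscaled the image 2x
bbox_top_left_prep = np.array([64.0, 32.0])  # crop corner, preprocessed space
pred_peaks_crop = np.array([[10.0, 12.0], [20.0, 8.0]])  # crop-relative peaks

bbox_top_left_orig = bbox_top_left_prep / eff_scale     # [128., 64.]
pred_peaks_full = pred_peaks_crop + bbox_top_left_orig  # [[138., 76.], [148., 72.]]

Ground truth follows the same path, except the crop-relative points are first divided by eff_scale as well, since they are stored in preprocessed space.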
sleap_nn/training/model_trainer.py
CHANGED
@@ -61,6 +61,7 @@ from sleap_nn.training.callbacks import (
     WandBVizCallbackWithPAFs,
     CSVLoggerCallback,
     SleapProgressBar,
+    EpochEndEvaluationCallback,
 )
 from sleap_nn import RANK
 from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -1086,6 +1087,18 @@ class ModelTrainer:
         if self.config.trainer_config.enable_progress_bar:
             callbacks.append(SleapProgressBar())
 
+        # Add epoch-end evaluation callback if enabled
+        if self.config.trainer_config.eval.enabled:
+            callbacks.append(
+                EpochEndEvaluationCallback(
+                    skeleton=self.skeletons[0],
+                    videos=self.val_labels[0].videos,
+                    eval_frequency=self.config.trainer_config.eval.frequency,
+                    oks_stddev=self.config.trainer_config.eval.oks_stddev,
+                    oks_scale=self.config.trainer_config.eval.oks_scale,
+                )
+            )
+
         return loggers, callbacks
 
     def _delete_cache_imgs(self):
@@ -1281,6 +1294,16 @@
             wandb.define_metric("train_pafs*", step_metric="epoch")
             wandb.define_metric("val_pafs*", step_metric="epoch")
 
+            # Evaluation metrics use epoch as x-axis
+            wandb.define_metric("val_mOKS", step_metric="epoch")
+            wandb.define_metric("val_oks_voc_mAP", step_metric="epoch")
+            wandb.define_metric("val_oks_voc_mAR", step_metric="epoch")
+            wandb.define_metric("val_avg_distance", step_metric="epoch")
+            wandb.define_metric("val_p50_distance", step_metric="epoch")
+            wandb.define_metric("val_mPCK", step_metric="epoch")
+            wandb.define_metric("val_visibility_precision", step_metric="epoch")
+            wandb.define_metric("val_visibility_recall", step_metric="epoch")
+
             self.config.trainer_config.wandb.current_run_id = wandb.run.id
             wandb.config["run_name"] = self.config.trainer_config.wandb.name
             wandb.config["run_config"] = OmegaConf.to_container(
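Pairing `define_metric(..., step_metric="epoch")` with the callback's `log(..., commit=False)` makes WandB plot the evaluation series against the epoch axis and merge them into the same step as the regular epoch metrics. A minimal standalone sketch of this wandb pattern (hypothetical run, not package code):

import wandb

run = wandb.init(project="demo")
wandb.define_metric("epoch")
wandb.define_metric("val_mOKS", step_metric="epoch")

# commit=False stages the values; they are committed with the next regular log.
wandb.log({"epoch": 3, "val_mOKS": 0.91}, commit=False)
wandb.log({"val_loss": 0.12})
run.finish()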
{sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sleap-nn
-Version: 0.1.0a1
+Version: 0.1.0a2
 Summary: Neural network backend for training and inference for animal pose estimation.
 Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
 License: BSD-3-Clause
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
 Requires-Python: <3.14,>=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: sleap-io<0.7.0,>=0.6.
+Requires-Dist: sleap-io<0.7.0,>=0.6.2
 Requires-Dist: numpy
 Requires-Dist: lightning
 Requires-Dist: kornia
{sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
-sleap_nn/__init__.py,sha256=
+sleap_nn/__init__.py,sha256=s3sIImYR5tiP-PfftEj7J8P1Au2nRXj4XWowznrVwm8,1362
 sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
-sleap_nn/evaluation.py,sha256=
+sleap_nn/evaluation.py,sha256=sKwLnHbCcaNzPs7CJtgRmFcDRFwPMjCxB92viZvinVI,33498
 sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
 sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
 sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
@@ -19,7 +19,7 @@ sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,4
 sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
 sleap_nn/config/get_config.py,sha256=rjNUffKU9z-ohLwrOVmJNGCqwUM93eh68h4KJfrSy8Y,42396
 sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
-sleap_nn/config/trainer_config.py,sha256=
+sleap_nn/config/trainer_config.py,sha256=Ob2UqU10DXsQOnDb0iJxy0qc82CfP6FkQZQkrCvTEEY,29120
 sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
 sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
 sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
@@ -35,7 +35,7 @@ sleap_nn/data/providers.py,sha256=0x6GFP1s1c08ji4p0M5V6p-dhT4Z9c-SI_Aw1DWX-uM,14
 sleap_nn/data/resizing.py,sha256=YFpSQduIBkRK39FYmrqDL-v8zMySlEs6TJxh6zb_0ZU,5076
 sleap_nn/data/utils.py,sha256=rT0w7KMOTlzaeKWq1TqjbgC4Lvjz_G96McllvEOqXx8,5641
 sleap_nn/inference/__init__.py,sha256=eVkCmKrxHlDFJIlZTf8B5XEOcSyw-gPQymXMY5uShOM,170
-sleap_nn/inference/bottomup.py,sha256=
+sleap_nn/inference/bottomup.py,sha256=3s90aRlpIcRnSNe-R5-qiuX3S48kCWMpCl8YuNnTEDI,17084
 sleap_nn/inference/identity.py,sha256=GjNDL9MfGqNyQaK4AE8JQCAE8gpMuE_Y-3r3Gpa53CE,6540
 sleap_nn/inference/paf_grouping.py,sha256=7Fo9lCAj-zcHgv5rI5LIMYGcixCGNt_ZbSNs8Dik7l8,69973
 sleap_nn/inference/peak_finding.py,sha256=L9LdYKt_Bfw7cxo6xEpgF8wXcZAwq5plCfmKJ839N40,13014
@@ -52,14 +52,14 @@ sleap_nn/tracking/candidates/__init__.py,sha256=1O7NObIwshM7j1rLHmImbFphvkM9wY1j
 sleap_nn/tracking/candidates/fixed_window.py,sha256=D80KMlTnenuQveQVVhk9j0G8yx6K324C7nMLHgG76e0,6296
 sleap_nn/tracking/candidates/local_queues.py,sha256=Nx3R5wwEwq0gbfH-fi3oOumfkQo8_sYe5GN47pD9Be8,7305
 sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U,32
-sleap_nn/training/callbacks.py,sha256=
-sleap_nn/training/lightning_modules.py,sha256=
+sleap_nn/training/callbacks.py,sha256=ZO88NFGZi53Wn4qM6yp3Bk3HFmhkYSGqeMc1QJKirLo,35995
+sleap_nn/training/lightning_modules.py,sha256=slkVtQ7r6LatWLYzxcq6x1RALYNyHTRcqiXXwD-x0PA,95420
 sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
-sleap_nn/training/model_trainer.py,sha256=
+sleap_nn/training/model_trainer.py,sha256=mf6FOdGDal2mMP0F1xD9jVQ54wbUST0ovRt6OjXzVyg,60580
 sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
-sleap_nn-0.1.
+sleap_nn-0.1.0a2.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+sleap_nn-0.1.0a2.dist-info/METADATA,sha256=w0dUxvJerGIpu4hlYgGbimjCAooCcf_4NcAzo8T5Sos,5637
+sleap_nn-0.1.0a2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sleap_nn-0.1.0a2.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
+sleap_nn-0.1.0a2.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
+sleap_nn-0.1.0a2.dist-info/RECORD,,

{sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/WHEEL
File without changes

{sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/entry_points.txt
File without changes

{sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/licenses/LICENSE
File without changes

{sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a2.dist-info}/top_level.txt
File without changes