sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sleap_nn/__init__.py CHANGED
@@ -50,7 +50,7 @@ logger.add(
50
50
  colorize=False,
51
51
  )
52
52
 
53
- __version__ = "0.1.0a1"
53
+ __version__ = "0.1.0a2"
54
54
 
55
55
  # Public API
56
56
  from sleap_nn.evaluation import load_metrics
@@ -208,6 +208,23 @@ class EarlyStoppingConfig:
208
208
  stop_training_on_plateau: bool = True
209
209
 
210
210
 
211
+ @define
212
+ class EvalConfig:
213
+ """Configuration for epoch-end evaluation.
214
+
215
+ Attributes:
216
+ enabled: (bool) Enable epoch-end evaluation metrics. *Default*: `False`.
217
+ frequency: (int) Evaluate every N epochs. *Default*: `1`.
218
+ oks_stddev: (float) OKS standard deviation for evaluation. *Default*: `0.025`.
219
+ oks_scale: (float) OKS scale override. If None, uses default. *Default*: `None`.
220
+ """
221
+
222
+ enabled: bool = False
223
+ frequency: int = field(default=1, validator=validators.ge(1))
224
+ oks_stddev: float = field(default=0.025, validator=validators.gt(0))
225
+ oks_scale: Optional[float] = None
226
+
227
+
211
228
  @define
212
229
  class HardKeypointMiningConfig:
213
230
  """Configuration for online hard keypoint mining.
@@ -310,6 +327,7 @@ class TrainerConfig:
310
327
  factory=HardKeypointMiningConfig
311
328
  )
312
329
  zmq: Optional[ZMQConfig] = field(factory=ZMQConfig) # Required for SLEAP GUI
330
+ eval: EvalConfig = field(factory=EvalConfig) # Epoch-end evaluation config
313
331
 
314
332
  @staticmethod
315
333
  def validate_optimizer_name(value):
sleap_nn/evaluation.py CHANGED
@@ -29,11 +29,27 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
29
29
  """
30
30
  instance_list = []
31
31
  frame_idx = labeled_frame.frame_idx
32
- video_path = (
33
- labeled_frame.video.backend.source_filename
34
- if hasattr(labeled_frame.video.backend, "source_filename")
35
- else labeled_frame.video.backend.filename
36
- )
32
+
33
+ # Extract video path with fallbacks for embedded videos
34
+ video = labeled_frame.video
35
+ video_path = None
36
+ if video is not None:
37
+ backend = getattr(video, "backend", None)
38
+ if backend is not None:
39
+ # Try source_filename first (for embedded videos with provenance)
40
+ video_path = getattr(backend, "source_filename", None)
41
+ if video_path is None:
42
+ video_path = getattr(backend, "filename", None)
43
+ # Fallback to video.filename if backend doesn't have it
44
+ if video_path is None:
45
+ video_path = getattr(video, "filename", None)
46
+ # Handle list filenames (image sequences)
47
+ if isinstance(video_path, list) and video_path:
48
+ video_path = video_path[0]
49
+ # Final fallback: use a unique identifier
50
+ if video_path is None:
51
+ video_path = f"video_{id(video)}" if video is not None else "unknown"
52
+
37
53
  for instance in labeled_frame.instances:
38
54
  match_instance = MatchInstance(
39
55
  instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -47,6 +63,10 @@ def find_frame_pairs(
47
63
  ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
48
64
  """Find corresponding frames across two sets of labels.
49
65
 
66
+ This function uses sleap-io's robust video matching API to handle various
67
+ scenarios including embedded videos, cross-platform paths, and videos with
68
+ different metadata.
69
+
50
70
  Args:
51
71
  labels_gt: A `sio.Labels` instance with ground truth instances.
52
72
  labels_pr: A `sio.Labels` instance with predicted instances.
@@ -56,16 +76,15 @@ def find_frame_pairs(
56
76
  Returns:
57
77
  A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
58
78
  """
79
+ # Use sleap-io's robust video matching API (added in 0.6.2)
80
+ # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
81
+ match_result = labels_gt.match(labels_pr)
82
+
59
83
  frame_pairs = []
60
- for video_gt in labels_gt.videos:
61
- # Find matching video instance in predictions.
62
- video_pr = None
63
- for video in labels_pr.videos:
64
- if video_gt.matches_content(video) and video_gt.matches_path(video):
65
- video_pr = video
66
- break
67
-
68
- if video_pr is None:
84
+ # Iterate over matched video pairs (pred_video -> gt_video mapping)
85
+ for video_pr, video_gt in match_result.video_map.items():
86
+ if video_gt is None:
87
+ # No match found for this prediction video
69
88
  continue
70
89
 
71
90
  # Find labeled frames in this video.
@@ -786,11 +805,26 @@ def run_evaluation(
786
805
  """Evaluate SLEAP-NN model predictions against ground truth labels."""
787
806
  logger.info("Loading ground truth labels...")
788
807
  ground_truth_instances = sio.load_slp(ground_truth_path)
808
+ logger.info(
809
+ f" Ground truth: {len(ground_truth_instances.videos)} videos, "
810
+ f"{len(ground_truth_instances.labeled_frames)} frames"
811
+ )
789
812
 
790
813
  logger.info("Loading predicted labels...")
791
814
  predicted_instances = sio.load_slp(predicted_path)
815
+ logger.info(
816
+ f" Predictions: {len(predicted_instances.videos)} videos, "
817
+ f"{len(predicted_instances.labeled_frames)} frames"
818
+ )
819
+
820
+ logger.info("Matching videos and frames...")
821
+ # Get match stats before creating evaluator
822
+ match_result = ground_truth_instances.match(predicted_instances)
823
+ logger.info(
824
+ f" Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
825
+ )
792
826
 
793
- logger.info("Creating evaluator...")
827
+ logger.info("Matching instances...")
794
828
  evaluator = Evaluator(
795
829
  ground_truth_instances=ground_truth_instances,
796
830
  predicted_instances=predicted_instances,
@@ -799,21 +833,38 @@ def run_evaluation(
799
833
  match_threshold=match_threshold,
800
834
  user_labels_only=user_labels_only,
801
835
  )
836
+ logger.info(
837
+ f" Frame pairs: {len(evaluator.frame_pairs)}, "
838
+ f"Matched instances: {len(evaluator.positive_pairs)}, "
839
+ f"Unmatched GT: {len(evaluator.false_negatives)}"
840
+ )
802
841
 
803
842
  logger.info("Computing evaluation metrics...")
804
843
  metrics = evaluator.evaluate()
805
844
 
845
+ # Compute PCK at specific thresholds (5 and 10 pixels)
846
+ dists = metrics["distance_metrics"]["dists"]
847
+ dists_clean = np.copy(dists)
848
+ dists_clean[np.isnan(dists_clean)] = np.inf
849
+ pck_5 = (dists_clean < 5).mean()
850
+ pck_10 = (dists_clean < 10).mean()
851
+
806
852
  # Print key metrics
807
853
  logger.info("Evaluation Results:")
808
- logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
809
- logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
810
- logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
811
- logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.4f}")
812
- logger.info(f"mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
854
+ logger.info(f" mOKS: {metrics['mOKS']['mOKS']:.4f}")
855
+ logger.info(f" mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
856
+ logger.info(f" mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
857
+ logger.info(f" Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
858
+ logger.info(f" dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
859
+ logger.info(f" dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
860
+ logger.info(f" dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
861
+ logger.info(f" mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
862
+ logger.info(f" PCK@5px: {pck_5:.4f}")
863
+ logger.info(f" PCK@10px: {pck_10:.4f}")
813
864
  logger.info(
814
- f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
865
+ f" Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
815
866
  )
816
- logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
867
+ logger.info(f" Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
817
868
 
818
869
  # Save metrics if path provided
819
870
  if save_metrics:
@@ -1,5 +1,6 @@
1
1
  """Inference modules for BottomUp models."""
2
2
 
3
+ import logging
3
4
  from typing import Dict, Optional
4
5
  import torch
5
6
  import lightning as L
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
7
8
  from sleap_nn.inference.paf_grouping import PAFScorer
8
9
  from sleap_nn.inference.identity import classify_peaks_from_maps
9
10
 
11
+ logger = logging.getLogger(__name__)
12
+
10
13
 
11
14
  class BottomUpInferenceModel(L.LightningModule):
12
15
  """BottomUp Inference model.
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
63
66
  return_pafs: Optional[bool] = False,
64
67
  return_paf_graph: Optional[bool] = False,
65
68
  input_scale: float = 1.0,
69
+ max_peaks_per_node: Optional[int] = None,
66
70
  ):
67
- """Initialise the model attributes."""
71
+ """Initialise the model attributes.
72
+
73
+ Args:
74
+ torch_model: A `nn.Module` that accepts images and predicts confidence maps.
75
+ paf_scorer: A `PAFScorer` instance for grouping instances.
76
+ cms_output_stride: Output stride of confidence maps relative to images.
77
+ pafs_output_stride: Output stride of PAFs relative to images.
78
+ peak_threshold: Minimum confidence map value for valid peaks.
79
+ refinement: Peak refinement method: None, "integral", or "local".
80
+ integral_patch_size: Size of patches for integral refinement.
81
+ return_confmaps: If True, return confidence maps in output.
82
+ return_pafs: If True, return PAFs in output.
83
+ return_paf_graph: If True, return intermediate PAF graph in output.
84
+ input_scale: Scale factor applied to input images.
85
+ max_peaks_per_node: Maximum number of peaks allowed per node before
86
+ skipping PAF scoring. If any node has more peaks than this limit,
87
+ empty predictions are returned. This prevents combinatorial explosion
88
+ during early training when confidence maps are noisy. Set to None to
89
+ disable this check (default). Recommended value: 100.
90
+ """
68
91
  super().__init__()
69
92
  self.torch_model = torch_model
70
93
  self.paf_scorer = paf_scorer
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
77
100
  self.return_pafs = return_pafs
78
101
  self.return_paf_graph = return_paf_graph
79
102
  self.input_scale = input_scale
103
+ self.max_peaks_per_node = max_peaks_per_node
80
104
 
81
105
  def _generate_cms_peaks(self, cms):
82
106
  # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -124,26 +148,68 @@ class BottomUpInferenceModel(L.LightningModule):
124
148
  ) # (batch, h, w, 2*edges)
125
149
  cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
126
150
 
127
- (
128
- predicted_instances,
129
- predicted_peak_scores,
130
- predicted_instance_scores,
131
- edge_inds,
132
- edge_peak_inds,
133
- line_scores,
134
- ) = self.paf_scorer.predict(
135
- pafs=pafs,
136
- peaks=cms_peaks,
137
- peak_vals=cms_peak_vals,
138
- peak_channel_inds=cms_peak_channel_inds,
139
- )
140
-
141
- predicted_instances = [p / self.input_scale for p in predicted_instances]
142
- predicted_instances_adjusted = []
143
- for idx, p in enumerate(predicted_instances):
144
- predicted_instances_adjusted.append(
145
- p / inputs["eff_scale"][idx].to(p.device)
151
+ # Check if too many peaks per node (prevents combinatorial explosion)
152
+ skip_paf_scoring = False
153
+ if self.max_peaks_per_node is not None:
154
+ n_nodes = cms.shape[1]
155
+ for b in range(self.batch_size):
156
+ for node_idx in range(n_nodes):
157
+ n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
158
+ if n_peaks > self.max_peaks_per_node:
159
+ logger.warning(
160
+ f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
161
+ f"(max_peaks_per_node={self.max_peaks_per_node}). "
162
+ f"Model may need more training."
163
+ )
164
+ skip_paf_scoring = True
165
+ break
166
+ if skip_paf_scoring:
167
+ break
168
+
169
+ if skip_paf_scoring:
170
+ # Return empty predictions for each sample
171
+ device = cms.device
172
+ n_nodes = cms.shape[1]
173
+ predicted_instances_adjusted = []
174
+ predicted_peak_scores = []
175
+ predicted_instance_scores = []
176
+ for _ in range(self.batch_size):
177
+ predicted_instances_adjusted.append(
178
+ torch.full((0, n_nodes, 2), float("nan"), device=device)
179
+ )
180
+ predicted_peak_scores.append(
181
+ torch.full((0, n_nodes), float("nan"), device=device)
182
+ )
183
+ predicted_instance_scores.append(torch.tensor([], device=device))
184
+ edge_inds = [
185
+ torch.tensor([], dtype=torch.int32, device=device)
186
+ ] * self.batch_size
187
+ edge_peak_inds = [
188
+ torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
189
+ ] * self.batch_size
190
+ line_scores = [torch.tensor([], device=device)] * self.batch_size
191
+ else:
192
+ (
193
+ predicted_instances,
194
+ predicted_peak_scores,
195
+ predicted_instance_scores,
196
+ edge_inds,
197
+ edge_peak_inds,
198
+ line_scores,
199
+ ) = self.paf_scorer.predict(
200
+ pafs=pafs,
201
+ peaks=cms_peaks,
202
+ peak_vals=cms_peak_vals,
203
+ peak_channel_inds=cms_peak_channel_inds,
146
204
  )
205
+
206
+ predicted_instances = [p / self.input_scale for p in predicted_instances]
207
+ predicted_instances_adjusted = []
208
+ for idx, p in enumerate(predicted_instances):
209
+ predicted_instances_adjusted.append(
210
+ p / inputs["eff_scale"][idx].to(p.device)
211
+ )
212
+
147
213
  out = {
148
214
  "pred_instance_peaks": predicted_instances_adjusted,
149
215
  "pred_peak_values": predicted_peak_scores,
@@ -662,3 +662,277 @@ class ProgressReporterZMQ(Callback):
662
662
  return {
663
663
  k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
664
664
  }
665
+
666
+
667
+ class EpochEndEvaluationCallback(Callback):
668
+ """Callback to run full evaluation metrics at end of validation epochs.
669
+
670
+ This callback collects predictions and ground truth during validation,
671
+ then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
672
+ metrics to WandB.
673
+
674
+ Attributes:
675
+ skeleton: sio.Skeleton for creating instances.
676
+ videos: List of sio.Video objects.
677
+ eval_frequency: Run evaluation every N epochs (default: 1).
678
+ oks_stddev: OKS standard deviation (default: 0.025).
679
+ oks_scale: Optional OKS scale override.
680
+ metrics_to_log: List of metric keys to log.
681
+ """
682
+
683
+ def __init__(
684
+ self,
685
+ skeleton: "sio.Skeleton",
686
+ videos: list,
687
+ eval_frequency: int = 1,
688
+ oks_stddev: float = 0.025,
689
+ oks_scale: Optional[float] = None,
690
+ metrics_to_log: Optional[list] = None,
691
+ ):
692
+ """Initialize the callback.
693
+
694
+ Args:
695
+ skeleton: sio.Skeleton for creating instances.
696
+ videos: List of sio.Video objects.
697
+ eval_frequency: Run evaluation every N epochs (default: 1).
698
+ oks_stddev: OKS standard deviation (default: 0.025).
699
+ oks_scale: Optional OKS scale override.
700
+ metrics_to_log: List of metric keys to log. If None, logs all available.
701
+ """
702
+ super().__init__()
703
+ self.skeleton = skeleton
704
+ self.videos = videos
705
+ self.eval_frequency = eval_frequency
706
+ self.oks_stddev = oks_stddev
707
+ self.oks_scale = oks_scale
708
+ self.metrics_to_log = metrics_to_log or [
709
+ "mOKS",
710
+ "oks_voc.mAP",
711
+ "oks_voc.mAR",
712
+ "avg_distance",
713
+ "p50_distance",
714
+ "mPCK",
715
+ "visibility_precision",
716
+ "visibility_recall",
717
+ ]
718
+
719
+ def on_validation_epoch_start(self, trainer, pl_module):
720
+ """Enable prediction collection at the start of validation.
721
+
722
+ Skip during sanity check to avoid inference issues.
723
+ """
724
+ if trainer.sanity_checking:
725
+ return
726
+ pl_module._collect_val_predictions = True
727
+
728
+ def on_validation_epoch_end(self, trainer, pl_module):
729
+ """Run evaluation and log metrics at end of validation epoch."""
730
+ import sleap_io as sio
731
+ import numpy as np
732
+ from lightning.pytorch.loggers import WandbLogger
733
+ from sleap_nn.evaluation import Evaluator
734
+
735
+ # Check frequency (epoch is 0-indexed, so add 1)
736
+ if (trainer.current_epoch + 1) % self.eval_frequency != 0:
737
+ pl_module._collect_val_predictions = False
738
+ return
739
+
740
+ # Only run on rank 0 for distributed training
741
+ if not trainer.is_global_zero:
742
+ pl_module._collect_val_predictions = False
743
+ return
744
+
745
+ # Check if we have predictions
746
+ if not pl_module.val_predictions or not pl_module.val_ground_truth:
747
+ logger.warning("No predictions collected for epoch-end evaluation")
748
+ pl_module._collect_val_predictions = False
749
+ return
750
+
751
+ try:
752
+ # Build sio.Labels from accumulated predictions and ground truth
753
+ pred_labels = self._build_pred_labels(pl_module.val_predictions, sio, np)
754
+ gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)
755
+
756
+ # Check if we have valid frames to evaluate
757
+ if len(pred_labels) == 0:
758
+ logger.warning(
759
+ "No valid predictions for epoch-end evaluation "
760
+ "(all predictions may be empty or NaN)"
761
+ )
762
+ pl_module._collect_val_predictions = False
763
+ pl_module.val_predictions = []
764
+ pl_module.val_ground_truth = []
765
+ return
766
+
767
+ # Run evaluation
768
+ evaluator = Evaluator(
769
+ ground_truth_instances=gt_labels,
770
+ predicted_instances=pred_labels,
771
+ oks_stddev=self.oks_stddev,
772
+ oks_scale=self.oks_scale,
773
+ user_labels_only=False, # All validation frames are "user" frames
774
+ )
775
+ metrics = evaluator.evaluate()
776
+
777
+ # Log to WandB
778
+ self._log_metrics(trainer, metrics, trainer.current_epoch)
779
+
780
+ logger.info(
781
+ f"Epoch {trainer.current_epoch} evaluation: "
782
+ f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
783
+ f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
784
+ )
785
+
786
+ except Exception as e:
787
+ logger.warning(f"Epoch-end evaluation failed: {e}")
788
+
789
+ # Cleanup
790
+ pl_module._collect_val_predictions = False
791
+ pl_module.val_predictions = []
792
+ pl_module.val_ground_truth = []
793
+
794
+ def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
795
+ """Convert prediction dicts to sio.Labels."""
796
+ labeled_frames = []
797
+ for pred in predictions:
798
+ pred_peaks = pred["pred_peaks"]
799
+ pred_scores = pred["pred_scores"]
800
+
801
+ # Handle NaN/missing predictions
802
+ if pred_peaks is None or (
803
+ isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
804
+ ):
805
+ continue
806
+
807
+ # Handle multi-instance predictions (bottomup)
808
+ if len(pred_peaks.shape) == 2:
809
+ # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
810
+ pred_peaks = pred_peaks.reshape(1, -1, 2)
811
+ pred_scores = pred_scores.reshape(1, -1)
812
+
813
+ instances = []
814
+ for inst_idx in range(len(pred_peaks)):
815
+ inst_points = pred_peaks[inst_idx]
816
+ inst_scores = pred_scores[inst_idx] if pred_scores is not None else None
817
+
818
+ # Skip if all NaN
819
+ if np.isnan(inst_points).all():
820
+ continue
821
+
822
+ inst = sio.PredictedInstance.from_numpy(
823
+ points_data=inst_points,
824
+ skeleton=self.skeleton,
825
+ point_scores=(
826
+ inst_scores
827
+ if inst_scores is not None
828
+ else np.ones(len(inst_points))
829
+ ),
830
+ score=(
831
+ float(np.nanmean(inst_scores))
832
+ if inst_scores is not None
833
+ else 1.0
834
+ ),
835
+ )
836
+ instances.append(inst)
837
+
838
+ if instances:
839
+ lf = sio.LabeledFrame(
840
+ video=self.videos[pred["video_idx"]],
841
+ frame_idx=pred["frame_idx"],
842
+ instances=instances,
843
+ )
844
+ labeled_frames.append(lf)
845
+
846
+ return sio.Labels(
847
+ videos=self.videos,
848
+ skeletons=[self.skeleton],
849
+ labeled_frames=labeled_frames,
850
+ )
851
+
852
+ def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
853
+ """Convert ground truth dicts to sio.Labels."""
854
+ labeled_frames = []
855
+ for gt in ground_truth:
856
+ instances = []
857
+ gt_instances = gt["gt_instances"]
858
+
859
+ # Handle shape variations
860
+ if len(gt_instances.shape) == 2:
861
+ # (n_nodes, 2) -> (1, n_nodes, 2)
862
+ gt_instances = gt_instances.reshape(1, -1, 2)
863
+
864
+ for i in range(min(gt["num_instances"], len(gt_instances))):
865
+ inst_data = gt_instances[i]
866
+ if np.isnan(inst_data).all():
867
+ continue
868
+ inst = sio.Instance.from_numpy(
869
+ points_data=inst_data,
870
+ skeleton=self.skeleton,
871
+ )
872
+ instances.append(inst)
873
+
874
+ if instances:
875
+ lf = sio.LabeledFrame(
876
+ video=self.videos[gt["video_idx"]],
877
+ frame_idx=gt["frame_idx"],
878
+ instances=instances,
879
+ )
880
+ labeled_frames.append(lf)
881
+
882
+ return sio.Labels(
883
+ videos=self.videos,
884
+ skeletons=[self.skeleton],
885
+ labeled_frames=labeled_frames,
886
+ )
887
+
888
+ def _log_metrics(self, trainer, metrics: dict, epoch: int):
889
+ """Log evaluation metrics to WandB."""
890
+ import numpy as np
891
+ from lightning.pytorch.loggers import WandbLogger
892
+
893
+ # Get WandB logger
894
+ wandb_logger = None
895
+ for log in trainer.loggers:
896
+ if isinstance(log, WandbLogger):
897
+ wandb_logger = log
898
+ break
899
+
900
+ if wandb_logger is None:
901
+ return
902
+
903
+ log_dict = {"epoch": epoch}
904
+
905
+ # Extract key metrics with consistent naming
906
+ if "mOKS" in self.metrics_to_log:
907
+ log_dict["val_mOKS"] = metrics["mOKS"]["mOKS"]
908
+
909
+ if "oks_voc.mAP" in self.metrics_to_log:
910
+ log_dict["val_oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
911
+
912
+ if "oks_voc.mAR" in self.metrics_to_log:
913
+ log_dict["val_oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
914
+
915
+ if "avg_distance" in self.metrics_to_log:
916
+ val = metrics["distance_metrics"]["avg"]
917
+ if not np.isnan(val):
918
+ log_dict["val_avg_distance"] = val
919
+
920
+ if "p50_distance" in self.metrics_to_log:
921
+ val = metrics["distance_metrics"]["p50"]
922
+ if not np.isnan(val):
923
+ log_dict["val_p50_distance"] = val
924
+
925
+ if "mPCK" in self.metrics_to_log:
926
+ log_dict["val_mPCK"] = metrics["pck_metrics"]["mPCK"]
927
+
928
+ if "visibility_precision" in self.metrics_to_log:
929
+ val = metrics["visibility_metrics"]["precision"]
930
+ if not np.isnan(val):
931
+ log_dict["val_visibility_precision"] = val
932
+
933
+ if "visibility_recall" in self.metrics_to_log:
934
+ val = metrics["visibility_metrics"]["recall"]
935
+ if not np.isnan(val):
936
+ log_dict["val_visibility_recall"] = val
937
+
938
+ wandb_logger.experiment.log(log_dict, commit=False)
@@ -1,6 +1,6 @@
1
1
  """This module has the LightningModule classes for all model types."""
2
2
 
3
- from typing import Optional, Union, Dict, Any
3
+ from typing import Optional, Union, Dict, Any, List
4
4
  import time
5
5
  from torch import nn
6
6
  import numpy as np
@@ -184,6 +184,11 @@ class LightningModel(L.LightningModule):
184
184
  self.val_loss = {}
185
185
  self.learning_rate = {}
186
186
 
187
+ # For epoch-end evaluation
188
+ self.val_predictions: List[Dict] = []
189
+ self.val_ground_truth: List[Dict] = []
190
+ self._collect_val_predictions: bool = False
191
+
187
192
  # Initialization for encoder and decoder stacks.
188
193
  if self.init_weights == "xavier":
189
194
  self.model.apply(xavier_init_weights)
@@ -331,6 +336,9 @@ class LightningModel(L.LightningModule):
331
336
  def on_validation_epoch_start(self):
332
337
  """Configure the val timer at the beginning of each epoch."""
333
338
  self.val_start_time = time.time()
339
+ # Clear accumulated predictions for new epoch
340
+ self.val_predictions = []
341
+ self.val_ground_truth = []
334
342
 
335
343
  def on_validation_epoch_end(self):
336
344
  """Configure the val timer at the end of every epoch."""
@@ -639,6 +647,51 @@ class SingleInstanceLightningModule(LightningModel):
639
647
  sync_dist=True,
640
648
  )
641
649
 
650
+ # Collect predictions for epoch-end evaluation if enabled
651
+ if self._collect_val_predictions:
652
+ with torch.no_grad():
653
+ # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
654
+ inference_batch = {k: v for k, v in batch.items()}
655
+ if inference_batch["image"].ndim == 5:
656
+ inference_batch["image"] = inference_batch["image"].squeeze(1)
657
+ inference_output = self.single_instance_inf_layer(inference_batch)
658
+ if isinstance(inference_output, list):
659
+ inference_output = inference_output[0]
660
+
661
+ batch_size = len(batch["frame_idx"])
662
+ for i in range(batch_size):
663
+ eff = batch["eff_scale"][i].cpu().numpy()
664
+
665
+ # Predictions are already in original image space (inference divides by eff_scale)
666
+ pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
667
+ pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
668
+
669
+ # Transform GT from preprocessed to original image space
670
+ # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
671
+ gt_prep = batch["instances"][i].cpu().numpy()
672
+ if gt_prep.ndim == 4:
673
+ gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
674
+ gt_orig = gt_prep / eff
675
+ num_inst = batch["num_instances"][i].item()
676
+ gt_orig = gt_orig[:num_inst] # Only valid instances
677
+
678
+ self.val_predictions.append(
679
+ {
680
+ "video_idx": batch["video_idx"][i].item(),
681
+ "frame_idx": batch["frame_idx"][i].item(),
682
+ "pred_peaks": pred_peaks,
683
+ "pred_scores": pred_scores,
684
+ }
685
+ )
686
+ self.val_ground_truth.append(
687
+ {
688
+ "video_idx": batch["video_idx"][i].item(),
689
+ "frame_idx": batch["frame_idx"][i].item(),
690
+ "gt_instances": gt_orig,
691
+ "num_instances": num_inst,
692
+ }
693
+ )
694
+
642
695
 
643
696
  class TopDownCenteredInstanceLightningModule(LightningModel):
644
697
  """Lightning Module for TopDownCenteredInstance Model.
@@ -856,6 +909,62 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
856
909
  sync_dist=True,
857
910
  )
858
911
 
912
+ # Collect predictions for epoch-end evaluation if enabled
913
+ if self._collect_val_predictions:
914
+ # SAVE bbox BEFORE inference (it modifies in-place!)
915
+ bbox_prep_saved = batch["instance_bbox"].clone()
916
+
917
+ with torch.no_grad():
918
+ inference_output = self.instance_peaks_inf_layer(batch)
919
+
920
+ batch_size = len(batch["frame_idx"])
921
+ for i in range(batch_size):
922
+ eff = batch["eff_scale"][i].cpu().numpy()
923
+
924
+ # Predictions from inference (crop-relative, original scale)
925
+ pred_peaks_crop = (
926
+ inference_output["pred_instance_peaks"][i].cpu().numpy()
927
+ )
928
+ pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
929
+
930
+ # Compute bbox offset in original space from SAVED prep bbox
931
+ # bbox has shape (n_samples=1, 4, 2) where 4 corners
932
+ bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy() # (4, 2)
933
+ bbox_top_left_orig = (
934
+ bbox_prep[0] / eff
935
+ ) # Top-left corner in original space
936
+
937
+ # Full image coordinates (original space)
938
+ pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
939
+
940
+ # GT transform: crop-relative preprocessed -> full image original
941
+ gt_crop_prep = (
942
+ batch["instance"][i].squeeze(0).cpu().numpy()
943
+ ) # (n_nodes, 2)
944
+ gt_crop_orig = gt_crop_prep / eff
945
+ gt_full_orig = gt_crop_orig + bbox_top_left_orig
946
+
947
+ self.val_predictions.append(
948
+ {
949
+ "video_idx": batch["video_idx"][i].item(),
950
+ "frame_idx": batch["frame_idx"][i].item(),
951
+ "pred_peaks": pred_peaks_full.reshape(
952
+ 1, -1, 2
953
+ ), # (1, n_nodes, 2)
954
+ "pred_scores": pred_scores.reshape(1, -1), # (1, n_nodes)
955
+ }
956
+ )
957
+ self.val_ground_truth.append(
958
+ {
959
+ "video_idx": batch["video_idx"][i].item(),
960
+ "frame_idx": batch["frame_idx"][i].item(),
961
+ "gt_instances": gt_full_orig.reshape(
962
+ 1, -1, 2
963
+ ), # (1, n_nodes, 2)
964
+ "num_instances": 1,
965
+ }
966
+ )
967
+
859
968
 
860
969
  class CentroidLightningModule(LightningModel):
861
970
  """Lightning Module for Centroid Model.
@@ -1034,6 +1143,57 @@ class CentroidLightningModule(LightningModel):
1034
1143
  sync_dist=True,
1035
1144
  )
1036
1145
 
1146
+ # Collect predictions for epoch-end evaluation if enabled
1147
+ if self._collect_val_predictions:
1148
+ with torch.no_grad():
1149
+ inference_output = self.centroid_inf_layer(batch)
1150
+
1151
+ batch_size = len(batch["frame_idx"])
1152
+ for i in range(batch_size):
1153
+ eff = batch["eff_scale"][i].cpu().numpy()
1154
+
1155
+ # Predictions are in original image space (inference divides by eff_scale)
1156
+ # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
1157
+ pred_centroids = (
1158
+ inference_output["centroids"][i].squeeze(0).cpu().numpy()
1159
+ )
1160
+ pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
1161
+
1162
+ # Transform GT centroids from preprocessed to original image space
1163
+ gt_centroids_prep = (
1164
+ batch["centroids"][i].cpu().numpy()
1165
+ ) # (n_samples=1, max_inst, 2)
1166
+ gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff # (max_inst, 2)
1167
+ num_inst = batch["num_instances"][i].item()
1168
+
1169
+ # Filter to valid instances (non-NaN)
1170
+ valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
1171
+ pred_centroids = pred_centroids[valid_pred_mask]
1172
+ pred_vals = pred_vals[valid_pred_mask]
1173
+
1174
+ gt_centroids_valid = gt_centroids_orig[:num_inst]
1175
+
1176
+ self.val_predictions.append(
1177
+ {
1178
+ "video_idx": batch["video_idx"][i].item(),
1179
+ "frame_idx": batch["frame_idx"][i].item(),
1180
+ "pred_peaks": pred_centroids.reshape(
1181
+ -1, 1, 2
1182
+ ), # (n_inst, 1, 2)
1183
+ "pred_scores": pred_vals.reshape(-1, 1), # (n_inst, 1)
1184
+ }
1185
+ )
1186
+ self.val_ground_truth.append(
1187
+ {
1188
+ "video_idx": batch["video_idx"][i].item(),
1189
+ "frame_idx": batch["frame_idx"][i].item(),
1190
+ "gt_instances": gt_centroids_valid.reshape(
1191
+ -1, 1, 2
1192
+ ), # (n_inst, 1, 2)
1193
+ "num_instances": num_inst,
1194
+ }
1195
+ )
1196
+
1037
1197
 
1038
1198
  class BottomUpLightningModule(LightningModel):
1039
1199
  """Lightning Module for BottomUp Model.
@@ -1126,12 +1286,13 @@ class BottomUpLightningModule(LightningModel):
1126
1286
  self.bottomup_inf_layer = BottomUpInferenceModel(
1127
1287
  torch_model=self.forward,
1128
1288
  paf_scorer=paf_scorer,
1129
- peak_threshold=0.2,
1289
+ peak_threshold=0.1, # Lower threshold for epoch-end eval during training
1130
1290
  input_scale=1.0,
1131
1291
  return_confmaps=True,
1132
1292
  return_pafs=True,
1133
1293
  cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
1134
1294
  pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
1295
+ max_peaks_per_node=100, # Prevents combinatorial explosion in early training
1135
1296
  )
1136
1297
  self.node_names = list(self.head_configs.bottomup.confmaps.part_names)
1137
1298
 
@@ -1340,6 +1501,53 @@ class BottomUpLightningModule(LightningModel):
1340
1501
  sync_dist=True,
1341
1502
  )
1342
1503
 
1504
+ # Collect predictions for epoch-end evaluation if enabled
1505
+ if self._collect_val_predictions:
1506
+ with torch.no_grad():
1507
+ # Note: Do NOT squeeze the image here - the forward() method expects
1508
+ # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
1509
+ inference_output = self.bottomup_inf_layer(batch)
1510
+ if isinstance(inference_output, list):
1511
+ inference_output = inference_output[0]
1512
+
1513
+ batch_size = len(batch["frame_idx"])
1514
+ for i in range(batch_size):
1515
+ eff = batch["eff_scale"][i].cpu().numpy()
1516
+
1517
+ # Predictions are already in original space (variable number of instances)
1518
+ pred_peaks = inference_output["pred_instance_peaks"][i]
1519
+ pred_scores = inference_output["pred_peak_values"][i]
1520
+ if torch.is_tensor(pred_peaks):
1521
+ pred_peaks = pred_peaks.cpu().numpy()
1522
+ if torch.is_tensor(pred_scores):
1523
+ pred_scores = pred_scores.cpu().numpy()
1524
+
1525
+ # Transform GT to original space
1526
+ # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
1527
+ gt_prep = batch["instances"][i].cpu().numpy()
1528
+ if gt_prep.ndim == 4:
1529
+ gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
1530
+ gt_orig = gt_prep / eff
1531
+ num_inst = batch["num_instances"][i].item()
1532
+ gt_orig = gt_orig[:num_inst] # Only valid instances
1533
+
1534
+ self.val_predictions.append(
1535
+ {
1536
+ "video_idx": batch["video_idx"][i].item(),
1537
+ "frame_idx": batch["frame_idx"][i].item(),
1538
+ "pred_peaks": pred_peaks, # Original space, variable instances
1539
+ "pred_scores": pred_scores,
1540
+ }
1541
+ )
1542
+ self.val_ground_truth.append(
1543
+ {
1544
+ "video_idx": batch["video_idx"][i].item(),
1545
+ "frame_idx": batch["frame_idx"][i].item(),
1546
+ "gt_instances": gt_orig, # Original space
1547
+ "num_instances": num_inst,
1548
+ }
1549
+ )
1550
+
1343
1551
 
1344
1552
  class BottomUpMultiClassLightningModule(LightningModel):
1345
1553
  """Lightning Module for BottomUp ID Model.
@@ -61,6 +61,7 @@ from sleap_nn.training.callbacks import (
61
61
  WandBVizCallbackWithPAFs,
62
62
  CSVLoggerCallback,
63
63
  SleapProgressBar,
64
+ EpochEndEvaluationCallback,
64
65
  )
65
66
  from sleap_nn import RANK
66
67
  from sleap_nn.legacy_models import get_keras_first_layer_channels
@@ -1086,6 +1087,18 @@ class ModelTrainer:
1086
1087
  if self.config.trainer_config.enable_progress_bar:
1087
1088
  callbacks.append(SleapProgressBar())
1088
1089
 
1090
+ # Add epoch-end evaluation callback if enabled
1091
+ if self.config.trainer_config.eval.enabled:
1092
+ callbacks.append(
1093
+ EpochEndEvaluationCallback(
1094
+ skeleton=self.skeletons[0],
1095
+ videos=self.val_labels[0].videos,
1096
+ eval_frequency=self.config.trainer_config.eval.frequency,
1097
+ oks_stddev=self.config.trainer_config.eval.oks_stddev,
1098
+ oks_scale=self.config.trainer_config.eval.oks_scale,
1099
+ )
1100
+ )
1101
+
1089
1102
  return loggers, callbacks
1090
1103
 
1091
1104
  def _delete_cache_imgs(self):
@@ -1281,6 +1294,16 @@ class ModelTrainer:
1281
1294
  wandb.define_metric("train_pafs*", step_metric="epoch")
1282
1295
  wandb.define_metric("val_pafs*", step_metric="epoch")
1283
1296
 
1297
+ # Evaluation metrics use epoch as x-axis
1298
+ wandb.define_metric("val_mOKS", step_metric="epoch")
1299
+ wandb.define_metric("val_oks_voc_mAP", step_metric="epoch")
1300
+ wandb.define_metric("val_oks_voc_mAR", step_metric="epoch")
1301
+ wandb.define_metric("val_avg_distance", step_metric="epoch")
1302
+ wandb.define_metric("val_p50_distance", step_metric="epoch")
1303
+ wandb.define_metric("val_mPCK", step_metric="epoch")
1304
+ wandb.define_metric("val_visibility_precision", step_metric="epoch")
1305
+ wandb.define_metric("val_visibility_recall", step_metric="epoch")
1306
+
1284
1307
  self.config.trainer_config.wandb.current_run_id = wandb.run.id
1285
1308
  wandb.config["run_name"] = self.config.trainer_config.wandb.name
1286
1309
  wandb.config["run_config"] = OmegaConf.to_container(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sleap-nn
3
- Version: 0.1.0a1
3
+ Version: 0.1.0a2
4
4
  Summary: Neural network backend for training and inference for animal pose estimation.
5
5
  Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
6
6
  License: BSD-3-Clause
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.13
13
13
  Requires-Python: <3.14,>=3.11
14
14
  Description-Content-Type: text/markdown
15
15
  License-File: LICENSE
16
- Requires-Dist: sleap-io<0.7.0,>=0.6.0
16
+ Requires-Dist: sleap-io<0.7.0,>=0.6.2
17
17
  Requires-Dist: numpy
18
18
  Requires-Dist: lightning
19
19
  Requires-Dist: kornia
@@ -1,7 +1,7 @@
1
1
  sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
2
- sleap_nn/__init__.py,sha256=l5Lwiad8GOurqkAhMwWw8-UcpH6af2TnMURf-oKj_U8,1362
2
+ sleap_nn/__init__.py,sha256=s3sIImYR5tiP-PfftEj7J8P1Au2nRXj4XWowznrVwm8,1362
3
3
  sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
4
- sleap_nn/evaluation.py,sha256=3u7y85wFoBgCwOB2xOGTJIDrd2dUPWOo4m0s0oW3da4,31095
4
+ sleap_nn/evaluation.py,sha256=sKwLnHbCcaNzPs7CJtgRmFcDRFwPMjCxB92viZvinVI,33498
5
5
  sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
6
6
  sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
7
7
  sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
@@ -19,7 +19,7 @@ sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,4
19
19
  sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
20
20
  sleap_nn/config/get_config.py,sha256=rjNUffKU9z-ohLwrOVmJNGCqwUM93eh68h4KJfrSy8Y,42396
21
21
  sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
22
- sleap_nn/config/trainer_config.py,sha256=ZMXxns6VYakgYHRhkM541Eje76DdaTdDi4FFPNjJtP4,28413
22
+ sleap_nn/config/trainer_config.py,sha256=Ob2UqU10DXsQOnDb0iJxy0qc82CfP6FkQZQkrCvTEEY,29120
23
23
  sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
24
24
  sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
25
25
  sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
@@ -35,7 +35,7 @@ sleap_nn/data/providers.py,sha256=0x6GFP1s1c08ji4p0M5V6p-dhT4Z9c-SI_Aw1DWX-uM,14
35
35
  sleap_nn/data/resizing.py,sha256=YFpSQduIBkRK39FYmrqDL-v8zMySlEs6TJxh6zb_0ZU,5076
36
36
  sleap_nn/data/utils.py,sha256=rT0w7KMOTlzaeKWq1TqjbgC4Lvjz_G96McllvEOqXx8,5641
37
37
  sleap_nn/inference/__init__.py,sha256=eVkCmKrxHlDFJIlZTf8B5XEOcSyw-gPQymXMY5uShOM,170
38
- sleap_nn/inference/bottomup.py,sha256=NqN-G8TzAOsvCoL3bttEjA1iGsuveLOnOCXIUeFCdSA,13684
38
+ sleap_nn/inference/bottomup.py,sha256=3s90aRlpIcRnSNe-R5-qiuX3S48kCWMpCl8YuNnTEDI,17084
39
39
  sleap_nn/inference/identity.py,sha256=GjNDL9MfGqNyQaK4AE8JQCAE8gpMuE_Y-3r3Gpa53CE,6540
40
40
  sleap_nn/inference/paf_grouping.py,sha256=7Fo9lCAj-zcHgv5rI5LIMYGcixCGNt_ZbSNs8Dik7l8,69973
41
41
  sleap_nn/inference/peak_finding.py,sha256=L9LdYKt_Bfw7cxo6xEpgF8wXcZAwq5plCfmKJ839N40,13014
@@ -52,14 +52,14 @@ sleap_nn/tracking/candidates/__init__.py,sha256=1O7NObIwshM7j1rLHmImbFphvkM9wY1j
52
52
  sleap_nn/tracking/candidates/fixed_window.py,sha256=D80KMlTnenuQveQVVhk9j0G8yx6K324C7nMLHgG76e0,6296
53
53
  sleap_nn/tracking/candidates/local_queues.py,sha256=Nx3R5wwEwq0gbfH-fi3oOumfkQo8_sYe5GN47pD9Be8,7305
54
54
  sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U,32
55
- sleap_nn/training/callbacks.py,sha256=TVnQ6plNC2MnlTiY2rSCRuw2WRk5cQSziek_VPUcOEg,25994
56
- sleap_nn/training/lightning_modules.py,sha256=G3c4xJkYWW-iSRawzkgTqkGd4lTsbPiMTcB5Nvq7jes,85512
55
+ sleap_nn/training/callbacks.py,sha256=ZO88NFGZi53Wn4qM6yp3Bk3HFmhkYSGqeMc1QJKirLo,35995
56
+ sleap_nn/training/lightning_modules.py,sha256=slkVtQ7r6LatWLYzxcq6x1RALYNyHTRcqiXXwD-x0PA,95420
57
57
  sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
58
- sleap_nn/training/model_trainer.py,sha256=loCmEX0DfBtdV_pN-W8s31fn2_L-lbpWaq3OQXeSp-0,59337
58
+ sleap_nn/training/model_trainer.py,sha256=mf6FOdGDal2mMP0F1xD9jVQ54wbUST0ovRt6OjXzVyg,60580
59
59
  sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
60
- sleap_nn-0.1.0a1.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
61
- sleap_nn-0.1.0a1.dist-info/METADATA,sha256=h3d4WPIu_JunY32jaRqJ4-fXp4KruTWT57FWb3L6dps,5637
62
- sleap_nn-0.1.0a1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
- sleap_nn-0.1.0a1.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
64
- sleap_nn-0.1.0a1.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
65
- sleap_nn-0.1.0a1.dist-info/RECORD,,
60
+ sleap_nn-0.1.0a2.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
61
+ sleap_nn-0.1.0a2.dist-info/METADATA,sha256=w0dUxvJerGIpu4hlYgGbimjCAooCcf_4NcAzo8T5Sos,5637
62
+ sleap_nn-0.1.0a2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
+ sleap_nn-0.1.0a2.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
64
+ sleap_nn-0.1.0a2.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
65
+ sleap_nn-0.1.0a2.dist-info/RECORD,,