sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

Files changed (53)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -5
  8. sleap_nn/config/trainer_config.py +0 -71
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +34 -364
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -69
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +204 -201
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/data/utils.py CHANGED
@@ -1,14 +1,12 @@
 """Miscellaneous utility functions for data processing."""
 
 from typing import Tuple, List, Any, Optional
-import sys
 import torch
 from omegaconf import DictConfig
 import sleap_io as sio
 from sleap_nn.config.utils import get_model_type_from_cfg
 import psutil
 import numpy as np
-from loguru import logger
 from sleap_nn.data.providers import get_max_instances
 
 
@@ -117,151 +115,35 @@ def check_memory(
     return img_mem
 
 
-def estimate_cache_memory(
-    train_labels: List[sio.Labels],
-    val_labels: List[sio.Labels],
-    num_workers: int = 0,
-    memory_buffer: float = 0.2,
-) -> dict:
-    """Estimate memory requirements for in-memory caching dataset pipeline.
-
-    This function calculates the total memory needed for caching images, accounting for:
-    - Raw image data size
-    - Python object overhead (dictionary keys, numpy array wrappers)
-    - DataLoader worker memory overhead (Copy-on-Write duplication on Unix systems)
-    - General memory buffer for training overhead
-
-    When using DataLoader with num_workers > 0, worker processes are spawned via fork()
-    on Unix systems. While Copy-on-Write (CoW) initially shares memory, Python's reference
-    counting can trigger memory page duplication when workers access cached data.
-
-    Args:
-        train_labels: List of `sleap_io.Labels` objects for training data.
-        val_labels: List of `sleap_io.Labels` objects for validation data.
-        num_workers: Number of DataLoader worker processes. When > 0, additional memory
-            overhead is estimated for worker process duplication.
-        memory_buffer: Fraction of memory to reserve as buffer for training overhead
-            (model weights, activations, gradients, etc.). Default: 0.2 (20%).
-
-    Returns:
-        dict: Memory estimation breakdown with keys:
-            - 'raw_cache_bytes': Raw image data size in bytes
-            - 'python_overhead_bytes': Estimated Python object overhead
-            - 'worker_overhead_bytes': Estimated memory for DataLoader workers
-            - 'buffer_bytes': Memory buffer for training overhead
-            - 'total_bytes': Total estimated memory requirement
-            - 'available_bytes': Available system memory
-            - 'sufficient': True if total <= available, False otherwise
-    """
-    # Calculate raw image cache size
-    train_cache_bytes = 0
-    val_cache_bytes = 0
-    num_train_samples = 0
-    num_val_samples = 0
-
-    for train, val in zip(train_labels, val_labels):
-        train_cache_bytes += check_memory(train)
-        val_cache_bytes += check_memory(val)
-        num_train_samples += len(train)
-        num_val_samples += len(val)
-
-    raw_cache_bytes = train_cache_bytes + val_cache_bytes
-    total_samples = num_train_samples + num_val_samples
-
-    # Python object overhead: dict keys, numpy array wrappers, tuple keys
-    # Estimate ~200 bytes per sample for Python object overhead
-    python_overhead_per_sample = 200
-    python_overhead_bytes = total_samples * python_overhead_per_sample
-
-    # Worker memory overhead
-    # When num_workers > 0, workers are forked or spawned depending on platform.
-    # Default start methods (Python 3.8+):
-    #   - Linux: fork (Copy-on-Write, partial memory duplication)
-    #   - macOS: spawn (full dataset copy to each worker, changed in Python 3.8)
-    #   - Windows: spawn (full dataset copy to each worker)
-    worker_overhead_bytes = 0
-    if num_workers > 0:
-        if sys.platform == "linux":
-            # Linux uses fork() with Copy-on-Write by default
-            # Estimate 25% duplication per worker due to Python refcounting
-            # triggering CoW page copies
-            worker_overhead_bytes = int(raw_cache_bytes * 0.25 * num_workers)
-            if num_workers >= 4:
-                logger.info(
-                    f"Using in-memory caching with {num_workers} DataLoader workers. "
-                    f"Estimated additional memory for workers: "
-                    f"{worker_overhead_bytes / (1024**3):.2f} GB"
-                )
-        else:
-            # macOS (darwin) and Windows use spawn - dataset is copied to each worker
-            # Since Python 3.8, macOS defaults to spawn due to fork safety issues
-            # With caching enabled, we avoid pickling labels_list, but the cache
-            # dict is still part of the dataset and gets copied to each worker
-            worker_overhead_bytes = int(raw_cache_bytes * 0.5 * num_workers)
-            platform_name = "macOS" if sys.platform == "darwin" else "Windows"
-            logger.warning(
-                f"Using in-memory caching with {num_workers} DataLoader workers on {platform_name}. "
-                f"Memory usage may be significantly higher than estimated (~{worker_overhead_bytes / (1024**3):.1f} GB extra) "
-                f"due to spawn-based multiprocessing. "
-                f"Consider using disk caching or num_workers=0 for large datasets."
-            )
-
-    # Memory buffer for training overhead (model, gradients, activations)
-    subtotal = raw_cache_bytes + python_overhead_bytes + worker_overhead_bytes
-    buffer_bytes = int(subtotal * memory_buffer)
-
-    total_bytes = subtotal + buffer_bytes
-    available_bytes = psutil.virtual_memory().available
-
-    return {
-        "raw_cache_bytes": raw_cache_bytes,
-        "python_overhead_bytes": python_overhead_bytes,
-        "worker_overhead_bytes": worker_overhead_bytes,
-        "buffer_bytes": buffer_bytes,
-        "total_bytes": total_bytes,
-        "available_bytes": available_bytes,
-        "sufficient": total_bytes <= available_bytes,
-        "num_samples": total_samples,
-    }
-
-
 def check_cache_memory(
     train_labels: List[sio.Labels],
     val_labels: List[sio.Labels],
     memory_buffer: float = 0.2,
-    num_workers: int = 0,
 ) -> bool:
     """Check memory requirements for in-memory caching dataset pipeline.
 
-    This function determines if the system has sufficient memory for in-memory
-    image caching, accounting for DataLoader worker processes.
-
     Args:
         train_labels: List of `sleap_io.Labels` objects for training data.
         val_labels: List of `sleap_io.Labels` objects for validation data.
-        memory_buffer: Fraction of memory to reserve as buffer. Default: 0.2 (20%).
-        num_workers: Number of DataLoader worker processes. When > 0, additional memory
-            overhead is estimated for worker process duplication.
+        memory_buffer: Fraction of the total image memory required for caching that
+            should be reserved as a buffer.
 
     Returns:
         bool: True if the total memory required for caching is within available system
             memory, False otherwise.
     """
-    estimate = estimate_cache_memory(
-        train_labels=train_labels,
-        val_labels=val_labels,
-        num_workers=num_workers,
-        memory_buffer=memory_buffer,
-    )
-
-    if not estimate["sufficient"]:
-        total_gb = estimate["total_bytes"] / (1024**3)
-        available_gb = estimate["available_bytes"] / (1024**3)
-        raw_gb = estimate["raw_cache_bytes"] / (1024**3)
-        logger.info(
-            f"Memory check failed: need ~{total_gb:.2f} GB "
-            f"(raw cache: {raw_gb:.2f} GB, {estimate['num_samples']} samples), "
-            f"available: {available_gb:.2f} GB"
-        )
-
-    return estimate["sufficient"]
+    train_cache_memory_final = 0
+    val_cache_memory_final = 0
+    for train, val in zip(train_labels, val_labels):
+        train_cache_memory = check_memory(train)
+        val_cache_memory = check_memory(val)
+        train_cache_memory_final += train_cache_memory
+        val_cache_memory_final += val_cache_memory
+
+    total_cache_memory = train_cache_memory_final + val_cache_memory_final
+    total_cache_memory += memory_buffer * total_cache_memory  # memory required in bytes
+    available_memory = psutil.virtual_memory().available  # available memory in bytes
+
+    if total_cache_memory > available_memory:
+        return False
+    return True
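The simplified check_cache_memory boils down to a single headroom comparison: the summed image bytes for the train and val labels, inflated by the buffer fraction, checked against available RAM. A minimal standalone sketch with hypothetical numbers:

import psutil

raw_cache_bytes = 6 * 1024**3  # hypothetical: 6 GB of decoded train + val images
memory_buffer = 0.2  # default: reserve 20% headroom on top of the cache size
required_bytes = raw_cache_bytes * (1 + memory_buffer)  # 7.2 GB required
fits = required_bytes <= psutil.virtual_memory().available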
sleap_nn/evaluation.py CHANGED
@@ -29,27 +29,11 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
     """
     instance_list = []
     frame_idx = labeled_frame.frame_idx
-
-    # Extract video path with fallbacks for embedded videos
-    video = labeled_frame.video
-    video_path = None
-    if video is not None:
-        backend = getattr(video, "backend", None)
-        if backend is not None:
-            # Try source_filename first (for embedded videos with provenance)
-            video_path = getattr(backend, "source_filename", None)
-            if video_path is None:
-                video_path = getattr(backend, "filename", None)
-        # Fallback to video.filename if backend doesn't have it
-        if video_path is None:
-            video_path = getattr(video, "filename", None)
-    # Handle list filenames (image sequences)
-    if isinstance(video_path, list) and video_path:
-        video_path = video_path[0]
-    # Final fallback: use a unique identifier
-    if video_path is None:
-        video_path = f"video_{id(video)}" if video is not None else "unknown"
-
+    video_path = (
+        labeled_frame.video.backend.source_filename
+        if hasattr(labeled_frame.video.backend, "source_filename")
+        else labeled_frame.video.backend.filename
+    )
     for instance in labeled_frame.instances:
         match_instance = MatchInstance(
             instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -63,10 +47,6 @@ def find_frame_pairs(
 ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
     """Find corresponding frames across two sets of labels.
 
-    This function uses sleap-io's robust video matching API to handle various
-    scenarios including embedded videos, cross-platform paths, and videos with
-    different metadata.
-
     Args:
         labels_gt: A `sio.Labels` instance with ground truth instances.
         labels_pr: A `sio.Labels` instance with predicted instances.
@@ -76,15 +56,16 @@
     Returns:
         A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
     """
-    # Use sleap-io's robust video matching API (added in 0.6.2)
-    # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
-    match_result = labels_gt.match(labels_pr)
-
     frame_pairs = []
-    # Iterate over matched video pairs (pred_video -> gt_video mapping)
-    for video_pr, video_gt in match_result.video_map.items():
-        if video_gt is None:
-            # No match found for this prediction video
+    for video_gt in labels_gt.videos:
+        # Find matching video instance in predictions.
+        video_pr = None
+        for video in labels_pr.videos:
+            if video_gt.matches_content(video) and video_gt.matches_path(video):
+                video_pr = video
+                break
+
+        if video_pr is None:
             continue
 
         # Find labeled frames in this video.
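The unchanged loop body that follows pairs labeled frames within each matched video, presumably keyed on frame_idx. The pairing logic, as a standalone sketch over any frame objects carrying a frame_idx attribute (names hypothetical):

def pair_frames(frames_gt, frames_pr):
    # Index ground-truth frames by frame index, then pair predictions that share it.
    by_idx = {lf.frame_idx: lf for lf in frames_gt}
    return [(by_idx[lf.frame_idx], lf) for lf in frames_pr if lf.frame_idx in by_idx]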
@@ -639,19 +620,11 @@ class Evaluator:
         mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
         mPCK = mPCK_parts.mean()
 
-        # Precompute PCK at common thresholds
-        idx_5 = np.argmin(np.abs(thresholds - 5))
-        idx_10 = np.argmin(np.abs(thresholds - 10))
-        pck5 = pcks[:, :, idx_5].mean()
-        pck10 = pcks[:, :, idx_10].mean()
-
         return {
             "thresholds": thresholds,
             "pcks": pcks,
             "mPCK_parts": mPCK_parts,
             "mPCK": mPCK,
-            "PCK@5": pck5,
-            "PCK@10": pck10,
         }
 
     def visibility_metrics(self):
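With the PCK@5/PCK@10 convenience keys removed, a fixed-threshold PCK can still be recovered from the returned `thresholds` and `pcks` arrays, mirroring the deleted computation (sketch; `pck_metrics` stands in for the dict returned above):

import numpy as np

def pck_at(pck_metrics: dict, px: float) -> float:
    # Pick the sampled threshold closest to the requested pixel radius.
    idx = np.argmin(np.abs(pck_metrics["thresholds"] - px))
    return pck_metrics["pcks"][:, :, idx].mean()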
@@ -813,26 +786,11 @@ def run_evaluation(
     """Evaluate SLEAP-NN model predictions against ground truth labels."""
     logger.info("Loading ground truth labels...")
    ground_truth_instances = sio.load_slp(ground_truth_path)
-    logger.info(
-        f"  Ground truth: {len(ground_truth_instances.videos)} videos, "
-        f"{len(ground_truth_instances.labeled_frames)} frames"
-    )
 
     logger.info("Loading predicted labels...")
     predicted_instances = sio.load_slp(predicted_path)
-    logger.info(
-        f"  Predictions: {len(predicted_instances.videos)} videos, "
-        f"{len(predicted_instances.labeled_frames)} frames"
-    )
-
-    logger.info("Matching videos and frames...")
-    # Get match stats before creating evaluator
-    match_result = ground_truth_instances.match(predicted_instances)
-    logger.info(
-        f"  Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
-    )
 
-    logger.info("Matching instances...")
+    logger.info("Creating evaluator...")
     evaluator = Evaluator(
         ground_truth_instances=ground_truth_instances,
         predicted_instances=predicted_instances,
@@ -841,38 +799,21 @@ def run_evaluation(
         match_threshold=match_threshold,
         user_labels_only=user_labels_only,
     )
-    logger.info(
-        f"  Frame pairs: {len(evaluator.frame_pairs)}, "
-        f"Matched instances: {len(evaluator.positive_pairs)}, "
-        f"Unmatched GT: {len(evaluator.false_negatives)}"
-    )
 
     logger.info("Computing evaluation metrics...")
     metrics = evaluator.evaluate()
 
-    # Compute PCK at specific thresholds (5 and 10 pixels)
-    dists = metrics["distance_metrics"]["dists"]
-    dists_clean = np.copy(dists)
-    dists_clean[np.isnan(dists_clean)] = np.inf
-    pck_5 = (dists_clean < 5).mean()
-    pck_10 = (dists_clean < 10).mean()
-
     # Print key metrics
     logger.info("Evaluation Results:")
-    logger.info(f"  mOKS: {metrics['mOKS']['mOKS']:.4f}")
-    logger.info(f"  mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
-    logger.info(f"  mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
-    logger.info(f"  Average Distance: {metrics['distance_metrics']['avg']:.2f} px")
-    logger.info(f"  dist.p50: {metrics['distance_metrics']['p50']:.2f} px")
-    logger.info(f"  dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
-    logger.info(f"  dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
-    logger.info(f"  mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
-    logger.info(f"  PCK@5px: {pck_5:.4f}")
-    logger.info(f"  PCK@10px: {pck_10:.4f}")
+    logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
+    logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+    logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+    logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.4f}")
+    logger.info(f"mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
     logger.info(
-        f"  Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
+        f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
    )
-    logger.info(f"  Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
+    logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
 
     # Save metrics if path provided
     if save_metrics:
sleap_nn/inference/bottomup.py CHANGED
@@ -1,6 +1,5 @@
 """Inference modules for BottomUp models."""
 
-import logging
 from typing import Dict, Optional
 import torch
 import lightning as L
@@ -8,8 +7,6 @@ from sleap_nn.inference.peak_finding import find_local_peaks
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.inference.identity import classify_peaks_from_maps
 
-logger = logging.getLogger(__name__)
-
 
 class BottomUpInferenceModel(L.LightningModule):
     """BottomUp Inference model.
@@ -66,28 +63,8 @@ class BottomUpInferenceModel(L.LightningModule):
         return_pafs: Optional[bool] = False,
         return_paf_graph: Optional[bool] = False,
         input_scale: float = 1.0,
-        max_peaks_per_node: Optional[int] = None,
     ):
-        """Initialise the model attributes.
-
-        Args:
-            torch_model: A `nn.Module` that accepts images and predicts confidence maps.
-            paf_scorer: A `PAFScorer` instance for grouping instances.
-            cms_output_stride: Output stride of confidence maps relative to images.
-            pafs_output_stride: Output stride of PAFs relative to images.
-            peak_threshold: Minimum confidence map value for valid peaks.
-            refinement: Peak refinement method: None, "integral", or "local".
-            integral_patch_size: Size of patches for integral refinement.
-            return_confmaps: If True, return confidence maps in output.
-            return_pafs: If True, return PAFs in output.
-            return_paf_graph: If True, return intermediate PAF graph in output.
-            input_scale: Scale factor applied to input images.
-            max_peaks_per_node: Maximum number of peaks allowed per node before
-                skipping PAF scoring. If any node has more peaks than this limit,
-                empty predictions are returned. This prevents combinatorial explosion
-                during early training when confidence maps are noisy. Set to None to
-                disable this check (default). Recommended value: 100.
-        """
+        """Initialise the model attributes."""
         super().__init__()
         self.torch_model = torch_model
         self.paf_scorer = paf_scorer
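The max_peaks_per_node guard removed here (and dropped from the forward pass in a later hunk below) counted peaks per node with nested Python loops over batch and nodes; a vectorized equivalent via torch.bincount, as a standalone sketch with made-up values:

import torch

peak_channel_inds = torch.tensor([0, 0, 1, 2, 2, 2])  # node index of each detected peak
n_nodes, max_peaks_per_node = 3, 2
counts = torch.bincount(peak_channel_inds, minlength=n_nodes)  # peaks per node: [2, 1, 3]
skip_paf_scoring = bool((counts > max_peaks_per_node).any())  # True: node 2 exceeds the cap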
@@ -100,7 +77,6 @@ class BottomUpInferenceModel(L.LightningModule):
         self.return_pafs = return_pafs
         self.return_paf_graph = return_paf_graph
         self.input_scale = input_scale
-        self.max_peaks_per_node = max_peaks_per_node
 
     def _generate_cms_peaks(self, cms):
         # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -148,68 +124,26 @@ class BottomUpInferenceModel(L.LightningModule):
         )  # (batch, h, w, 2*edges)
         cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
 
-        # Check if too many peaks per node (prevents combinatorial explosion)
-        skip_paf_scoring = False
-        if self.max_peaks_per_node is not None:
-            n_nodes = cms.shape[1]
-            for b in range(self.batch_size):
-                for node_idx in range(n_nodes):
-                    n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
-                    if n_peaks > self.max_peaks_per_node:
-                        logger.warning(
-                            f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
-                            f"(max_peaks_per_node={self.max_peaks_per_node}). "
-                            f"Model may need more training."
-                        )
-                        skip_paf_scoring = True
-                        break
-                if skip_paf_scoring:
-                    break
-
-        if skip_paf_scoring:
-            # Return empty predictions for each sample
-            device = cms.device
-            n_nodes = cms.shape[1]
-            predicted_instances_adjusted = []
-            predicted_peak_scores = []
-            predicted_instance_scores = []
-            for _ in range(self.batch_size):
-                predicted_instances_adjusted.append(
-                    torch.full((0, n_nodes, 2), float("nan"), device=device)
-                )
-                predicted_peak_scores.append(
-                    torch.full((0, n_nodes), float("nan"), device=device)
-                )
-                predicted_instance_scores.append(torch.tensor([], device=device))
-            edge_inds = [
-                torch.tensor([], dtype=torch.int32, device=device)
-            ] * self.batch_size
-            edge_peak_inds = [
-                torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
-            ] * self.batch_size
-            line_scores = [torch.tensor([], device=device)] * self.batch_size
-        else:
-            (
-                predicted_instances,
-                predicted_peak_scores,
-                predicted_instance_scores,
-                edge_inds,
-                edge_peak_inds,
-                line_scores,
-            ) = self.paf_scorer.predict(
-                pafs=pafs,
-                peaks=cms_peaks,
-                peak_vals=cms_peak_vals,
-                peak_channel_inds=cms_peak_channel_inds,
-            )
-
-            predicted_instances = [p / self.input_scale for p in predicted_instances]
-            predicted_instances_adjusted = []
-            for idx, p in enumerate(predicted_instances):
-                predicted_instances_adjusted.append(
-                    p / inputs["eff_scale"][idx].to(p.device)
-                )
+        (
+            predicted_instances,
+            predicted_peak_scores,
+            predicted_instance_scores,
+            edge_inds,
+            edge_peak_inds,
+            line_scores,
+        ) = self.paf_scorer.predict(
+            pafs=pafs,
+            peaks=cms_peaks,
+            peak_vals=cms_peak_vals,
+            peak_channel_inds=cms_peak_channel_inds,
+        )
 
+        predicted_instances = [p / self.input_scale for p in predicted_instances]
+        predicted_instances_adjusted = []
+        for idx, p in enumerate(predicted_instances):
+            predicted_instances_adjusted.append(
+                p / inputs["eff_scale"][idx].to(p.device)
+            )
         out = {
             "pred_instance_peaks": predicted_instances_adjusted,
             "pred_peak_values": predicted_peak_scores,
sleap_nn/inference/peak_finding.py CHANGED
@@ -2,60 +2,18 @@
 
 from typing import Optional, Tuple
 
+import kornia as K
+import numpy as np
 import torch
-import torch.nn.functional as F
+from kornia.geometry.transform import crop_and_resize
 
 from sleap_nn.data.instance_cropping import make_centered_bboxes
 
 
-def morphological_dilation(image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
-    """Apply morphological dilation using max pooling.
-
-    This is a pure PyTorch replacement for kornia.morphology.dilation.
-    For non-maximum suppression, it computes the maximum of 8 neighbors
-    (excluding the center pixel).
-
-    Args:
-        image: Input tensor of shape (B, 1, H, W).
-        kernel: Dilation kernel (3x3 expected for NMS).
-
-    Returns:
-        Dilated tensor of same shape as input.
-    """
-    # Pad the image to handle border pixels
-    padded = F.pad(image, (1, 1, 1, 1), mode="constant", value=float("-inf"))
-
-    # Extract 3x3 patches using unfold
-    # Shape: (B, 1, H, W, 3, 3)
-    patches = padded.unfold(2, 3, 1).unfold(3, 3, 1)
-
-    # Reshape to (B, 1, H, W, 9)
-    b, c, h, w, kh, kw = patches.shape
-    patches = patches.reshape(b, c, h, w, -1)
-
-    # Apply kernel mask (kernel has 0 at center, 1 elsewhere for NMS)
-    # Reshape kernel to (1, 1, 1, 1, 9)
-    kernel_flat = kernel.reshape(-1).to(patches.device)
-    kernel_mask = kernel_flat > 0
-
-    # Set non-kernel positions to -inf so they don't affect max
-    patches_masked = patches.clone()
-    patches_masked[..., ~kernel_mask] = float("-inf")
-
-    # Take max over the kernel neighborhood
-    max_vals = patches_masked.max(dim=-1)[0]
-
-    return max_vals
-
-
 def crop_bboxes(
     images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor
 ) -> torch.Tensor:
-    """Crop bounding boxes from a batch of images using fast tensor indexing.
-
-    This uses tensor unfold operations to extract patches, which is significantly
-    faster than kornia's crop_and_resize (17-51x speedup) as it avoids perspective
-    transform computations.
+    """Crop bounding boxes from a batch of images.
 
     Args:
         images: Tensor of shape (samples, channels, height, width) of a batch of images.
@@ -69,7 +27,7 @@ def crop_bboxes(
             box should be cropped from.
 
     Returns:
-        A tensor of shape (n_bboxes, channels, crop_height, crop_width) of the same
+        A tensor of shape (n_bboxes, crop_height, crop_width, channels) of the same
         dtype as the input image. The crop size is inferred from the bounding box
         coordinates.
 
@@ -84,52 +42,25 @@ def crop_bboxes(
 
     See also: `make_centered_bboxes`
     """
-    n_crops = bboxes.shape[0]
-    if n_crops == 0:
-        # Return empty tensor; use default crop size since we can't infer from bboxes
-        return torch.empty(
-            0, images.shape[1], 0, 0, device=images.device, dtype=images.dtype
-        )
-
     # Compute bounding box size to use for crops.
-    height = int(abs(bboxes[0, 3, 1] - bboxes[0, 0, 1]).item()) + 1
-    width = int(abs(bboxes[0, 1, 0] - bboxes[0, 0, 0]).item()) + 1
+    height = abs(bboxes[0, 3, 1] - bboxes[0, 0, 1])
+    width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
+    box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
 
     # Store original dtype for conversion back after cropping.
     original_dtype = images.dtype
-    device = images.device
-    n_samples, channels, img_h, img_w = images.shape
-    half_h, half_w = height // 2, width // 2
 
-    # Pad images for edge handling.
-    images_padded = F.pad(
-        images.float(), (half_w, half_w, half_h, half_h), mode="constant", value=0
-    )
+    # Kornia's crop_and_resize requires float32 input.
+    images_to_crop = images[sample_inds]
+    if not torch.is_floating_point(images_to_crop):
+        images_to_crop = images_to_crop.float()
 
-    # Extract all possible patches using unfold (creates a view, no copy).
-    # Shape after unfold: (n_samples, channels, img_h, img_w, height, width)
-    patches = images_padded.unfold(2, height, 1).unfold(3, width, 1)
-
-    # Get crop centers from bboxes.
-    # The bbox top-left is at index 0, with (x, y) coordinates.
-    # We need the center of the crop (peak location), which is top-left + half_size.
-    # Ensure bboxes are on the same device as images for index computation.
-    bboxes_on_device = bboxes.to(device)
-    crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
-    crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
-
-    # Clamp indices to valid bounds to handle edge cases where centroids
-    # might be at or beyond image boundaries.
-    crop_x = torch.clamp(crop_x, 0, patches.shape[3] - 1)
-    crop_y = torch.clamp(crop_y, 0, patches.shape[2] - 1)
-
-    # Select crops using advanced indexing.
-    # Convert sample_inds to tensor if it's a list.
-    if not isinstance(sample_inds, torch.Tensor):
-        sample_inds = torch.tensor(sample_inds, device=device)
-    sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
-    crops = patches[sample_inds_long, :, crop_y, crop_x]
-    # Shape: (n_crops, channels, height, width)
+    # Crop.
+    crops = crop_and_resize(
+        images_to_crop,  # (n_boxes, channels, height, width)
+        boxes=bboxes,
+        size=box_size,
+    )
 
     # Cast back to original dtype and return.
     crops = crops.to(original_dtype)
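For reference, kornia's crop_and_resize takes each box as four (x, y) corners in clockwise order (top-left, top-right, bottom-right, bottom-left), the same corner layout the height/width inference above assumes. A minimal standalone usage sketch:

import torch
from kornia.geometry.transform import crop_and_resize

images = torch.rand(1, 1, 128, 128)
# One 32x32 box with corners at (10, 10), (41, 10), (41, 41), (10, 41).
boxes = torch.tensor([[[10.0, 10.0], [41.0, 10.0], [41.0, 41.0], [10.0, 41.0]]])
crops = crop_and_resize(images, boxes=boxes, size=(32, 32))  # -> (1, 1, 32, 32)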
@@ -313,7 +244,7 @@ def find_local_peaks_rough(
     flat_img = cms.reshape(-1, 1, height, width)
 
     # Perform dilation filtering to find local maxima per channel and reshape back.
-    max_img = morphological_dilation(flat_img, kernel.to(flat_img.device))
+    max_img = K.morphology.dilation(flat_img, kernel.to(flat_img.device))
    max_img = max_img.reshape(-1, channels, height, width)
 
     # Filter for maxima and threshold.
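For context on the dilation-based non-maximum suppression used here: a pixel counts as a local peak when it is strictly greater than the maximum of its 8 neighbors (dilation with a center-zeroed kernel excludes the pixel itself) and also clears the score threshold. A standalone sketch:

import torch
import kornia as K

cms = torch.rand(1, 1, 64, 64)  # confidence maps: (samples, channels, height, width)
kernel = torch.ones(3, 3)
kernel[1, 1] = 0  # zero center: each pixel competes only with its 8 neighbors
max_img = K.morphology.dilation(cms, kernel)  # neighborhood max at every pixel
peaks = ((cms > max_img) & (cms > 0.2)).nonzero()  # rows: (sample, channel, y, x)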