sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -5
- sleap_nn/config/trainer_config.py +0 -71
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +34 -364
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -69
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +204 -201
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/data/utils.py
CHANGED
@@ -1,14 +1,12 @@
 """Miscellaneous utility functions for data processing."""
 
 from typing import Tuple, List, Any, Optional
-import sys
 import torch
 from omegaconf import DictConfig
 import sleap_io as sio
 from sleap_nn.config.utils import get_model_type_from_cfg
 import psutil
 import numpy as np
-from loguru import logger
 from sleap_nn.data.providers import get_max_instances
 
 
@@ -117,151 +115,35 @@ def check_memory(
     return img_mem
 
 
-def estimate_cache_memory(
-    train_labels: List[sio.Labels],
-    val_labels: List[sio.Labels],
-    num_workers: int = 0,
-    memory_buffer: float = 0.2,
-) -> dict:
-    """Estimate memory requirements for in-memory caching dataset pipeline.
-
-    This function calculates the total memory needed for caching images, accounting for:
-    - Raw image data size
-    - Python object overhead (dictionary keys, numpy array wrappers)
-    - DataLoader worker memory overhead (Copy-on-Write duplication on Unix systems)
-    - General memory buffer for training overhead
-
-    When using DataLoader with num_workers > 0, worker processes are spawned via fork()
-    on Unix systems. While Copy-on-Write (CoW) initially shares memory, Python's reference
-    counting can trigger memory page duplication when workers access cached data.
-
-    Args:
-        train_labels: List of `sleap_io.Labels` objects for training data.
-        val_labels: List of `sleap_io.Labels` objects for validation data.
-        num_workers: Number of DataLoader worker processes. When > 0, additional memory
-            overhead is estimated for worker process duplication.
-        memory_buffer: Fraction of memory to reserve as buffer for training overhead
-            (model weights, activations, gradients, etc.). Default: 0.2 (20%).
-
-    Returns:
-        dict: Memory estimation breakdown with keys:
-            - 'raw_cache_bytes': Raw image data size in bytes
-            - 'python_overhead_bytes': Estimated Python object overhead
-            - 'worker_overhead_bytes': Estimated memory for DataLoader workers
-            - 'buffer_bytes': Memory buffer for training overhead
-            - 'total_bytes': Total estimated memory requirement
-            - 'available_bytes': Available system memory
-            - 'sufficient': True if total <= available, False otherwise
-    """
-    # Calculate raw image cache size
-    train_cache_bytes = 0
-    val_cache_bytes = 0
-    num_train_samples = 0
-    num_val_samples = 0
-
-    for train, val in zip(train_labels, val_labels):
-        train_cache_bytes += check_memory(train)
-        val_cache_bytes += check_memory(val)
-        num_train_samples += len(train)
-        num_val_samples += len(val)
-
-    raw_cache_bytes = train_cache_bytes + val_cache_bytes
-    total_samples = num_train_samples + num_val_samples
-
-    # Python object overhead: dict keys, numpy array wrappers, tuple keys
-    # Estimate ~200 bytes per sample for Python object overhead
-    python_overhead_per_sample = 200
-    python_overhead_bytes = total_samples * python_overhead_per_sample
-
-    # Worker memory overhead
-    # When num_workers > 0, workers are forked or spawned depending on platform.
-    # Default start methods (Python 3.8+):
-    # - Linux: fork (Copy-on-Write, partial memory duplication)
-    # - macOS: spawn (full dataset copy to each worker, changed in Python 3.8)
-    # - Windows: spawn (full dataset copy to each worker)
-    worker_overhead_bytes = 0
-    if num_workers > 0:
-        if sys.platform == "linux":
-            # Linux uses fork() with Copy-on-Write by default
-            # Estimate 25% duplication per worker due to Python refcounting
-            # triggering CoW page copies
-            worker_overhead_bytes = int(raw_cache_bytes * 0.25 * num_workers)
-            if num_workers >= 4:
-                logger.info(
-                    f"Using in-memory caching with {num_workers} DataLoader workers. "
-                    f"Estimated additional memory for workers: "
-                    f"{worker_overhead_bytes / (1024**3):.2f} GB"
-                )
-        else:
-            # macOS (darwin) and Windows use spawn - dataset is copied to each worker
-            # Since Python 3.8, macOS defaults to spawn due to fork safety issues
-            # With caching enabled, we avoid pickling labels_list, but the cache
-            # dict is still part of the dataset and gets copied to each worker
-            worker_overhead_bytes = int(raw_cache_bytes * 0.5 * num_workers)
-            platform_name = "macOS" if sys.platform == "darwin" else "Windows"
-            logger.warning(
-                f"Using in-memory caching with {num_workers} DataLoader workers on {platform_name}. "
-                f"Memory usage may be significantly higher than estimated (~{worker_overhead_bytes / (1024**3):.1f} GB extra) "
-                f"due to spawn-based multiprocessing. "
-                f"Consider using disk caching or num_workers=0 for large datasets."
-            )
-
-    # Memory buffer for training overhead (model, gradients, activations)
-    subtotal = raw_cache_bytes + python_overhead_bytes + worker_overhead_bytes
-    buffer_bytes = int(subtotal * memory_buffer)
-
-    total_bytes = subtotal + buffer_bytes
-    available_bytes = psutil.virtual_memory().available
-
-    return {
-        "raw_cache_bytes": raw_cache_bytes,
-        "python_overhead_bytes": python_overhead_bytes,
-        "worker_overhead_bytes": worker_overhead_bytes,
-        "buffer_bytes": buffer_bytes,
-        "total_bytes": total_bytes,
-        "available_bytes": available_bytes,
-        "sufficient": total_bytes <= available_bytes,
-        "num_samples": total_samples,
-    }
-
-
 def check_cache_memory(
     train_labels: List[sio.Labels],
     val_labels: List[sio.Labels],
     memory_buffer: float = 0.2,
-    num_workers: int = 0,
 ) -> bool:
     """Check memory requirements for in-memory caching dataset pipeline.
 
-    This function determines if the system has sufficient memory for in-memory
-    image caching, accounting for DataLoader worker processes.
-
     Args:
         train_labels: List of `sleap_io.Labels` objects for training data.
         val_labels: List of `sleap_io.Labels` objects for validation data.
-        memory_buffer: Fraction of memory to reserve as buffer for training overhead.
-        num_workers: Number of DataLoader worker processes. When > 0, additional memory
-            overhead is estimated for worker process duplication.
+        memory_buffer: Fraction of the total image memory required for caching that
+            should be reserved as a buffer.
 
     Returns:
         bool: True if the total memory required for caching is within available system
            memory, False otherwise.
    """
-    estimate = estimate_cache_memory(
-        ...
-    )
-
-    return estimate["sufficient"]
+    train_cache_memory_final = 0
+    val_cache_memory_final = 0
+    for train, val in zip(train_labels, val_labels):
+        train_cache_memory = check_memory(train)
+        val_cache_memory = check_memory(val)
+        train_cache_memory_final += train_cache_memory
+        val_cache_memory_final += val_cache_memory
+
+    total_cache_memory = train_cache_memory_final + val_cache_memory_final
+    total_cache_memory += memory_buffer * total_cache_memory  # memory required in bytes
+    available_memory = psutil.virtual_memory().available  # available memory in bytes
+
+    if total_cache_memory > available_memory:
+        return False
+    return True
sleap_nn/evaluation.py
CHANGED
@@ -29,27 +29,11 @@ def get_instances(labeled_frame: sio.LabeledFrame) -> List[MatchInstance]:
     """
     instance_list = []
     frame_idx = labeled_frame.frame_idx
-    ...
-    backend = getattr(video, "backend", None)
-    if backend is not None:
-        # Try source_filename first (for embedded videos with provenance)
-        video_path = getattr(backend, "source_filename", None)
-        if video_path is None:
-            video_path = getattr(backend, "filename", None)
-    # Fallback to video.filename if backend doesn't have it
-    if video_path is None:
-        video_path = getattr(video, "filename", None)
-    # Handle list filenames (image sequences)
-    if isinstance(video_path, list) and video_path:
-        video_path = video_path[0]
-    # Final fallback: use a unique identifier
-    if video_path is None:
-        video_path = f"video_{id(video)}" if video is not None else "unknown"
+    video_path = (
+        labeled_frame.video.backend.source_filename
+        if hasattr(labeled_frame.video.backend, "source_filename")
+        else labeled_frame.video.backend.filename
+    )
     for instance in labeled_frame.instances:
         match_instance = MatchInstance(
             instance=instance, frame_idx=frame_idx, video_path=video_path
@@ -63,10 +47,6 @@ def find_frame_pairs(
 ) -> List[Tuple[sio.LabeledFrame, sio.LabeledFrame]]:
     """Find corresponding frames across two sets of labels.
 
-    This function uses sleap-io's robust video matching API to handle various
-    scenarios including embedded videos, cross-platform paths, and videos with
-    different metadata.
-
     Args:
         labels_gt: A `sio.Labels` instance with ground truth instances.
         labels_pr: A `sio.Labels` instance with predicted instances.
@@ -76,15 +56,16 @@ def find_frame_pairs(
     Returns:
         A list of pairs of `sio.LabeledFrame`s in the form `(frame_gt, frame_pr)`.
     """
-    # Use sleap-io's robust video matching API (added in 0.6.2)
-    # The match() method returns a MatchResult with video_map: {pred_video: gt_video}
-    match_result = labels_gt.match(labels_pr)
-
     frame_pairs = []
-    ...
+    for video_gt in labels_gt.videos:
+        # Find matching video instance in predictions.
+        video_pr = None
+        for video in labels_pr.videos:
+            if video_gt.matches_content(video) and video_gt.matches_path(video):
+                video_pr = video
+                break
+
+        if video_pr is None:
             continue
 
         # Find labeled frames in this video.
@@ -639,19 +620,11 @@ class Evaluator:
         mPCK_parts = pcks.mean(axis=0).mean(axis=-1)
         mPCK = mPCK_parts.mean()
 
-        # Precompute PCK at common thresholds
-        idx_5 = np.argmin(np.abs(thresholds - 5))
-        idx_10 = np.argmin(np.abs(thresholds - 10))
-        pck5 = pcks[:, :, idx_5].mean()
-        pck10 = pcks[:, :, idx_10].mean()
-
         return {
             "thresholds": thresholds,
             "pcks": pcks,
             "mPCK_parts": mPCK_parts,
             "mPCK": mPCK,
-            "PCK@5": pck5,
-            "PCK@10": pck10,
         }
 
     def visibility_metrics(self):
@@ -813,26 +786,11 @@ def run_evaluation(
     """Evaluate SLEAP-NN model predictions against ground truth labels."""
     logger.info("Loading ground truth labels...")
     ground_truth_instances = sio.load_slp(ground_truth_path)
-    logger.info(
-        f"  Ground truth: {len(ground_truth_instances.videos)} videos, "
-        f"{len(ground_truth_instances.labeled_frames)} frames"
-    )
 
     logger.info("Loading predicted labels...")
     predicted_instances = sio.load_slp(predicted_path)
-    logger.info(
-        f"  Predictions: {len(predicted_instances.videos)} videos, "
-        f"{len(predicted_instances.labeled_frames)} frames"
-    )
-
-    logger.info("Matching videos and frames...")
-    # Get match stats before creating evaluator
-    match_result = ground_truth_instances.match(predicted_instances)
-    logger.info(
-        f"  Videos matched: {match_result.n_videos_matched}/{len(match_result.video_map)}"
-    )
 
-    logger.info("
+    logger.info("Creating evaluator...")
     evaluator = Evaluator(
         ground_truth_instances=ground_truth_instances,
         predicted_instances=predicted_instances,
@@ -841,38 +799,21 @@ def run_evaluation(
         match_threshold=match_threshold,
         user_labels_only=user_labels_only,
     )
-    logger.info(
-        f"  Frame pairs: {len(evaluator.frame_pairs)}, "
-        f"Matched instances: {len(evaluator.positive_pairs)}, "
-        f"Unmatched GT: {len(evaluator.false_negatives)}"
-    )
 
     logger.info("Computing evaluation metrics...")
     metrics = evaluator.evaluate()
 
-    # Compute PCK at specific thresholds (5 and 10 pixels)
-    dists = metrics["distance_metrics"]["dists"]
-    dists_clean = np.copy(dists)
-    dists_clean[np.isnan(dists_clean)] = np.inf
-    pck_5 = (dists_clean < 5).mean()
-    pck_10 = (dists_clean < 10).mean()
-
     # Print key metrics
     logger.info("Evaluation Results:")
-    logger.info(f"
-    logger.info(f"
-    logger.info(f"
-    logger.info(f"
-    logger.info(f"
-    logger.info(f"  dist.p95: {metrics['distance_metrics']['p95']:.2f} px")
-    logger.info(f"  dist.p99: {metrics['distance_metrics']['p99']:.2f} px")
-    logger.info(f"  mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
-    logger.info(f"  PCK@5px: {pck_5:.4f}")
-    logger.info(f"  PCK@10px: {pck_10:.4f}")
+    logger.info(f"mOKS: {metrics['mOKS']['mOKS']:.4f}")
+    logger.info(f"mAP (OKS VOC): {metrics['voc_metrics']['oks_voc.mAP']:.4f}")
+    logger.info(f"mAR (OKS VOC): {metrics['voc_metrics']['oks_voc.mAR']:.4f}")
+    logger.info(f"Average Distance: {metrics['distance_metrics']['avg']:.4f}")
+    logger.info(f"mPCK: {metrics['pck_metrics']['mPCK']:.4f}")
     logger.info(
-        f"
+        f"Visibility Precision: {metrics['visibility_metrics']['precision']:.4f}"
     )
-    logger.info(f"
+    logger.info(f"Visibility Recall: {metrics['visibility_metrics']['recall']:.4f}")
 
     # Save metrics if path provided
     if save_metrics:
sleap_nn/inference/bottomup.py
CHANGED
@@ -1,6 +1,5 @@
 """Inference modules for BottomUp models."""
 
-import logging
 from typing import Dict, Optional
 import torch
 import lightning as L
@@ -8,8 +7,6 @@ from sleap_nn.inference.peak_finding import find_local_peaks
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.inference.identity import classify_peaks_from_maps
 
-logger = logging.getLogger(__name__)
-
 
 class BottomUpInferenceModel(L.LightningModule):
     """BottomUp Inference model.
@@ -66,28 +63,8 @@ class BottomUpInferenceModel(L.LightningModule):
         return_pafs: Optional[bool] = False,
         return_paf_graph: Optional[bool] = False,
         input_scale: float = 1.0,
-        max_peaks_per_node: Optional[int] = None,
     ):
-        """Initialise the model attributes.
-
-        Args:
-            torch_model: A `nn.Module` that accepts images and predicts confidence maps.
-            paf_scorer: A `PAFScorer` instance for grouping instances.
-            cms_output_stride: Output stride of confidence maps relative to images.
-            pafs_output_stride: Output stride of PAFs relative to images.
-            peak_threshold: Minimum confidence map value for valid peaks.
-            refinement: Peak refinement method: None, "integral", or "local".
-            integral_patch_size: Size of patches for integral refinement.
-            return_confmaps: If True, return confidence maps in output.
-            return_pafs: If True, return PAFs in output.
-            return_paf_graph: If True, return intermediate PAF graph in output.
-            input_scale: Scale factor applied to input images.
-            max_peaks_per_node: Maximum number of peaks allowed per node before
-                skipping PAF scoring. If any node has more peaks than this limit,
-                empty predictions are returned. This prevents combinatorial explosion
-                during early training when confidence maps are noisy. Set to None to
-                disable this check (default). Recommended value: 100.
-        """
+        """Initialise the model attributes."""
         super().__init__()
         self.torch_model = torch_model
         self.paf_scorer = paf_scorer
@@ -100,7 +77,6 @@ class BottomUpInferenceModel(L.LightningModule):
        self.return_pafs = return_pafs
        self.return_paf_graph = return_paf_graph
        self.input_scale = input_scale
-        self.max_peaks_per_node = max_peaks_per_node
 
     def _generate_cms_peaks(self, cms):
         # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -148,68 +124,26 @@ class BottomUpInferenceModel(L.LightningModule):
         ) # (batch, h, w, 2*edges)
         cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
 
-        ...
-                    skip_paf_scoring = True
-                    break
-            if skip_paf_scoring:
-                break
-
-        if skip_paf_scoring:
-            # Return empty predictions for each sample
-            device = cms.device
-            n_nodes = cms.shape[1]
-            predicted_instances_adjusted = []
-            predicted_peak_scores = []
-            predicted_instance_scores = []
-            for _ in range(self.batch_size):
-                predicted_instances_adjusted.append(
-                    torch.full((0, n_nodes, 2), float("nan"), device=device)
-                )
-                predicted_peak_scores.append(
-                    torch.full((0, n_nodes), float("nan"), device=device)
-                )
-                predicted_instance_scores.append(torch.tensor([], device=device))
-            edge_inds = [
-                torch.tensor([], dtype=torch.int32, device=device)
-            ] * self.batch_size
-            edge_peak_inds = [
-                torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
-            ] * self.batch_size
-            line_scores = [torch.tensor([], device=device)] * self.batch_size
-        else:
-            (
-                predicted_instances,
-                predicted_peak_scores,
-                predicted_instance_scores,
-                edge_inds,
-                edge_peak_inds,
-                line_scores,
-            ) = self.paf_scorer.predict(
-                pafs=pafs,
-                peaks=cms_peaks,
-                peak_vals=cms_peak_vals,
-                peak_channel_inds=cms_peak_channel_inds,
-            )
-
-            predicted_instances = [p / self.input_scale for p in predicted_instances]
-            predicted_instances_adjusted = []
-            for idx, p in enumerate(predicted_instances):
-                predicted_instances_adjusted.append(
-                    p / inputs["eff_scale"][idx].to(p.device)
-                )
+        (
+            predicted_instances,
+            predicted_peak_scores,
+            predicted_instance_scores,
+            edge_inds,
+            edge_peak_inds,
+            line_scores,
+        ) = self.paf_scorer.predict(
+            pafs=pafs,
+            peaks=cms_peaks,
+            peak_vals=cms_peak_vals,
+            peak_channel_inds=cms_peak_channel_inds,
+        )
 
+        predicted_instances = [p / self.input_scale for p in predicted_instances]
+        predicted_instances_adjusted = []
+        for idx, p in enumerate(predicted_instances):
+            predicted_instances_adjusted.append(
+                p / inputs["eff_scale"][idx].to(p.device)
+            )
         out = {
             "pred_instance_peaks": predicted_instances_adjusted,
             "pred_peak_values": predicted_peak_scores,
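With the `max_peaks_per_node` guard removed, `paf_scorer.predict` now always runs. The 0.1.0 guard existed because N candidate peaks per node yield on the order of N^2 candidate edges during PAF grouping, which explodes when early-training confidence maps are noisy. If that protection is still needed, an external pre-check can approximate it (a sketch only; the exact removed counting loop was truncated in this diff, and `max_peaks_per_node` is no longer a model argument):

    import torch

    def too_many_peaks_per_node(cms_peak_channel_inds, n_nodes, limit=100):
        # Count candidate peaks per node (channel) for each sample; 100 was
        # the value recommended in the removed 0.1.0 docstring.
        for channel_inds in cms_peak_channel_inds:  # one 1-D tensor per sample
            counts = torch.bincount(channel_inds.to(torch.int64), minlength=n_nodes)
            if int(counts.max()) > limit:
                return True
        return False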
sleap_nn/inference/peak_finding.py
CHANGED
@@ -2,60 +2,18 @@
 
 from typing import Optional, Tuple
 
+import kornia as K
+import numpy as np
 import torch
-
+from kornia.geometry.transform import crop_and_resize
 
 from sleap_nn.data.instance_cropping import make_centered_bboxes
 
 
-def morphological_dilation(image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
-    """Apply morphological dilation using max pooling.
-
-    This is a pure PyTorch replacement for kornia.morphology.dilation.
-    For non-maximum suppression, it computes the maximum of 8 neighbors
-    (excluding the center pixel).
-
-    Args:
-        image: Input tensor of shape (B, 1, H, W).
-        kernel: Dilation kernel (3x3 expected for NMS).
-
-    Returns:
-        Dilated tensor of same shape as input.
-    """
-    # Pad the image to handle border pixels
-    padded = F.pad(image, (1, 1, 1, 1), mode="constant", value=float("-inf"))
-
-    # Extract 3x3 patches using unfold
-    # Shape: (B, 1, H, W, 3, 3)
-    patches = padded.unfold(2, 3, 1).unfold(3, 3, 1)
-
-    # Reshape to (B, 1, H, W, 9)
-    b, c, h, w, kh, kw = patches.shape
-    patches = patches.reshape(b, c, h, w, -1)
-
-    # Apply kernel mask (kernel has 0 at center, 1 elsewhere for NMS)
-    # Reshape kernel to (1, 1, 1, 1, 9)
-    kernel_flat = kernel.reshape(-1).to(patches.device)
-    kernel_mask = kernel_flat > 0
-
-    # Set non-kernel positions to -inf so they don't affect max
-    patches_masked = patches.clone()
-    patches_masked[..., ~kernel_mask] = float("-inf")
-
-    # Take max over the kernel neighborhood
-    max_vals = patches_masked.max(dim=-1)[0]
-
-    return max_vals
-
-
 def crop_bboxes(
     images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor
 ) -> torch.Tensor:
-    """Crop bounding boxes from a batch of images
-
-    This uses tensor unfold operations to extract patches, which is significantly
-    faster than kornia's crop_and_resize (17-51x speedup) as it avoids perspective
-    transform computations.
+    """Crop bounding boxes from a batch of images.
 
     Args:
         images: Tensor of shape (samples, channels, height, width) of a batch of images.
@@ -69,7 +27,7 @@ def crop_bboxes(
         box should be cropped from.
 
     Returns:
-        A tensor of shape (n_bboxes,
+        A tensor of shape (n_bboxes, crop_height, crop_width, channels) of the same
         dtype as the input image. The crop size is inferred from the bounding box
         coordinates.
 
@@ -84,52 +42,25 @@ def crop_bboxes(
 
     See also: `make_centered_bboxes`
     """
-    n_crops = bboxes.shape[0]
-    if n_crops == 0:
-        # Return empty tensor; use default crop size since we can't infer from bboxes
-        return torch.empty(
-            0, images.shape[1], 0, 0, device=images.device, dtype=images.dtype
-        )
-
     # Compute bounding box size to use for crops.
-    height =
-    width =
+    height = abs(bboxes[0, 3, 1] - bboxes[0, 0, 1])
+    width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
+    box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
 
     # Store original dtype for conversion back after cropping.
     original_dtype = images.dtype
-    device = images.device
-    n_samples, channels, img_h, img_w = images.shape
-    half_h, half_w = height // 2, width // 2
 
-    # ...
+    # Kornia's crop_and_resize requires float32 input.
+    images_to_crop = images[sample_inds]
+    if not torch.is_floating_point(images_to_crop):
+        images_to_crop = images_to_crop.float()
 
-    # ...
-    # We need the center of the crop (peak location), which is top-left + half_size.
-    # Ensure bboxes are on the same device as images for index computation.
-    bboxes_on_device = bboxes.to(device)
-    crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
-    crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
-
-    # Clamp indices to valid bounds to handle edge cases where centroids
-    # might be at or beyond image boundaries.
-    crop_x = torch.clamp(crop_x, 0, patches.shape[3] - 1)
-    crop_y = torch.clamp(crop_y, 0, patches.shape[2] - 1)
-
-    # Select crops using advanced indexing.
-    # Convert sample_inds to tensor if it's a list.
-    if not isinstance(sample_inds, torch.Tensor):
-        sample_inds = torch.tensor(sample_inds, device=device)
-    sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
-    crops = patches[sample_inds_long, :, crop_y, crop_x]
-    # Shape: (n_crops, channels, height, width)
+    # Crop.
+    crops = crop_and_resize(
+        images_to_crop,  # (n_boxes, channels, height, width)
+        boxes=bboxes,
+        size=box_size,
+    )
 
     # Cast back to original dtype and return.
     crops = crops.to(original_dtype)
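`crop_bboxes` is thus back on kornia's `crop_and_resize`, which takes one quadrilateral per box as four (x, y) corner points in top-left, top-right, bottom-right, bottom-left order, plus a target size. A standalone sketch of the call shape (the box coordinates and sizes are illustrative, not values from the diff):

    import torch
    from kornia.geometry.transform import crop_and_resize

    images = torch.rand(3, 1, 64, 64)  # (n_boxes, channels, height, width), float32

    # Corners 10..21 inclusive span 12 pixels under align-corners semantics.
    box = torch.tensor([[10.0, 10.0], [21.0, 10.0], [21.0, 21.0], [10.0, 21.0]])
    boxes = box.unsqueeze(0).repeat(3, 1, 1)  # one box per image

    crops = crop_and_resize(images, boxes=boxes, size=(12, 12))
    print(crops.shape)  # torch.Size([3, 1, 12, 12])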
@@ -313,7 +244,7 @@ def find_local_peaks_rough(
     flat_img = cms.reshape(-1, 1, height, width)
 
     # Perform dilation filtering to find local maxima per channel and reshape back.
-    max_img = morphological_dilation(flat_img, kernel.to(flat_img.device))
+    max_img = K.morphology.dilation(flat_img, kernel.to(flat_img.device))
     max_img = max_img.reshape(-1, channels, height, width)
 
     # Filter for maxima and threshold.