sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/inference/provenance.py
ADDED
@@ -0,0 +1,292 @@
+"""Provenance metadata utilities for inference outputs.
+
+This module provides utilities for building and managing provenance metadata
+that is stored in SLP files produced during inference. Provenance metadata
+helps track where predictions came from and how they were generated.
+"""
+
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import sleap_io as sio
+
+import sleap_nn
+from sleap_nn.system_info import get_system_info_dict
+
+
+def build_inference_provenance(
+    model_paths: Optional[list[str]] = None,
+    model_type: Optional[str] = None,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    input_labels: Optional[sio.Labels] = None,
+    input_path: Optional[Union[str, Path]] = None,
+    frames_processed: Optional[int] = None,
+    frames_total: Optional[int] = None,
+    frame_selection_method: Optional[str] = None,
+    inference_params: Optional[dict[str, Any]] = None,
+    tracking_params: Optional[dict[str, Any]] = None,
+    device: Optional[str] = None,
+    cli_args: Optional[dict[str, Any]] = None,
+    include_system_info: bool = True,
+) -> dict[str, Any]:
+    """Build provenance metadata dictionary for inference output.
+
+    This function creates a comprehensive provenance dictionary that captures
+    all relevant metadata about an inference run, enabling reproducibility
+    and tracking of prediction origins.
+
+    Args:
+        model_paths: List of paths to model checkpoints used for inference.
+        model_type: Type of model used (e.g., "top_down", "bottom_up",
+            "single_instance").
+        start_time: Datetime when inference started.
+        end_time: Datetime when inference finished.
+        input_labels: Input Labels object if inference was run on an SLP file.
+            The provenance from this object will be preserved.
+        input_path: Path to input file (SLP or video).
+        frames_processed: Number of frames that were processed.
+        frames_total: Total number of frames in the input.
+        frame_selection_method: Method used to select frames (e.g., "all",
+            "labeled", "suggested", "range").
+        inference_params: Dictionary of inference parameters (peak_threshold,
+            integral_refinement, batch_size, etc.).
+        tracking_params: Dictionary of tracking parameters if tracking was run.
+        device: Device used for inference (e.g., "cuda:0", "cpu", "mps").
+        cli_args: Command-line arguments if available.
+        include_system_info: If True, include detailed system information.
+            Set to False for lighter-weight provenance.
+
+    Returns:
+        Dictionary containing provenance metadata suitable for storing in
+        Labels.provenance.
+
+    Example:
+        >>> from datetime import datetime
+        >>> provenance = build_inference_provenance(
+        ...     model_paths=["/path/to/model.ckpt"],
+        ...     model_type="top_down",
+        ...     start_time=datetime.now(),
+        ...     end_time=datetime.now(),
+        ...     device="cuda:0",
+        ... )
+        >>> labels.provenance = provenance
+        >>> labels.save("predictions.slp")
+    """
+    provenance: dict[str, Any] = {}
+
+    # Timestamps
+    if start_time is not None:
+        provenance["inference_start_timestamp"] = start_time.isoformat()
+    if end_time is not None:
+        provenance["inference_end_timestamp"] = end_time.isoformat()
+    if start_time is not None and end_time is not None:
+        runtime_seconds = (end_time - start_time).total_seconds()
+        provenance["inference_runtime_seconds"] = runtime_seconds
+
+    # Version information
+    provenance["sleap_nn_version"] = sleap_nn.__version__
+    provenance["sleap_io_version"] = sio.__version__
+
+    # Model information
+    if model_paths is not None:
+        # Store as absolute POSIX paths for cross-platform compatibility
+        provenance["model_paths"] = [
+            Path(p).resolve().as_posix() if isinstance(p, (str, Path)) else str(p)
+            for p in model_paths
+        ]
+    if model_type is not None:
+        provenance["model_type"] = model_type
+
+    # Input data lineage
+    if input_path is not None:
+        provenance["source_file"] = (
+            Path(input_path).resolve().as_posix()
+            if isinstance(input_path, (str, Path))
+            else str(input_path)
+        )
+
+    # Preserve input provenance if available
+    if input_labels is not None and hasattr(input_labels, "provenance"):
+        input_prov = dict(input_labels.provenance)
+        if input_prov:
+            provenance["input_provenance"] = input_prov
+            # Also set source_labels for compatibility with sleap-io conventions
+            if "filename" in input_prov:
+                provenance["source_labels"] = input_prov["filename"]
+
+    # Frame selection information
+    if frames_processed is not None or frames_total is not None:
+        frame_info: dict[str, Any] = {}
+        if frame_selection_method is not None:
+            frame_info["method"] = frame_selection_method
+        if frames_processed is not None:
+            frame_info["frames_processed"] = frames_processed
+        if frames_total is not None:
+            frame_info["frames_total"] = frames_total
+        if frame_info:
+            provenance["frame_selection"] = frame_info
+
+    # Inference parameters
+    if inference_params is not None:
+        # Filter out None values and convert paths
+        clean_params = {}
+        for key, value in inference_params.items():
+            if value is not None:
+                if isinstance(value, Path):
+                    clean_params[key] = value.as_posix()
+                else:
+                    clean_params[key] = value
+        if clean_params:
+            provenance["inference_config"] = clean_params
+
+    # Tracking parameters
+    if tracking_params is not None:
+        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
+        if clean_tracking:
+            provenance["tracking_config"] = clean_tracking
+
+    # Device information
+    if device is not None:
+        provenance["device"] = device
+
+    # CLI arguments
+    if cli_args is not None:
+        # Filter out None values
+        clean_cli = {k: v for k, v in cli_args.items() if v is not None}
+        if clean_cli:
+            provenance["cli_args"] = clean_cli
+
+    # System information (can be disabled for lighter provenance)
+    if include_system_info:
+        try:
+            system_info = get_system_info_dict()
+            # Extract key fields for provenance (avoid excessive nesting)
+            provenance["system_info"] = {
+                "python_version": system_info.get("python_version"),
+                "platform": system_info.get("platform"),
+                "pytorch_version": system_info.get("pytorch_version"),
+                "cuda_version": system_info.get("cuda_version"),
+                "accelerator": system_info.get("accelerator"),
+                "gpu_count": system_info.get("gpu_count"),
+            }
+            # Include GPU names if available
+            if system_info.get("gpus"):
+                provenance["system_info"]["gpus"] = [
+                    gpu.get("name") for gpu in system_info["gpus"]
+                ]
+        except Exception:
+            # Don't fail inference if system info collection fails
+            pass
+
+    return provenance
+
+
+def build_tracking_only_provenance(
+    input_labels: Optional[sio.Labels] = None,
+    input_path: Optional[Union[str, Path]] = None,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    tracking_params: Optional[dict[str, Any]] = None,
+    frames_processed: Optional[int] = None,
+    include_system_info: bool = True,
+) -> dict[str, Any]:
+    """Build provenance metadata for tracking-only pipeline.
+
+    This is a simplified version of build_inference_provenance for when
+    only tracking is run without model inference.
+
+    Args:
+        input_labels: Input Labels object with existing predictions.
+        input_path: Path to input SLP file.
+        start_time: Datetime when tracking started.
+        end_time: Datetime when tracking finished.
+        tracking_params: Dictionary of tracking parameters.
+        frames_processed: Number of frames that were tracked.
+        include_system_info: If True, include system information.
+
+    Returns:
+        Dictionary containing provenance metadata.
+    """
+    provenance: dict[str, Any] = {}
+
+    # Timestamps
+    if start_time is not None:
+        provenance["tracking_start_timestamp"] = start_time.isoformat()
+    if end_time is not None:
+        provenance["tracking_end_timestamp"] = end_time.isoformat()
+    if start_time is not None and end_time is not None:
+        runtime_seconds = (end_time - start_time).total_seconds()
+        provenance["tracking_runtime_seconds"] = runtime_seconds
+
+    # Version information
+    provenance["sleap_nn_version"] = sleap_nn.__version__
+    provenance["sleap_io_version"] = sio.__version__
+
+    # Note that this is tracking-only
+    provenance["pipeline_type"] = "tracking_only"
+
+    # Input data lineage
+    if input_path is not None:
+        provenance["source_file"] = (
+            Path(input_path).resolve().as_posix()
+            if isinstance(input_path, (str, Path))
+            else str(input_path)
+        )
+
+    # Preserve input provenance if available
+    if input_labels is not None and hasattr(input_labels, "provenance"):
+        input_prov = dict(input_labels.provenance)
+        if input_prov:
+            provenance["input_provenance"] = input_prov
+            if "filename" in input_prov:
+                provenance["source_labels"] = input_prov["filename"]
+
+    # Frame information
+    if frames_processed is not None:
+        provenance["frames_processed"] = frames_processed
+
+    # Tracking parameters
+    if tracking_params is not None:
+        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
+        if clean_tracking:
+            provenance["tracking_config"] = clean_tracking
+
+    # System information
+    if include_system_info:
+        try:
+            system_info = get_system_info_dict()
+            provenance["system_info"] = {
+                "python_version": system_info.get("python_version"),
+                "platform": system_info.get("platform"),
+                "pytorch_version": system_info.get("pytorch_version"),
+                "accelerator": system_info.get("accelerator"),
+            }
+        except Exception:
+            pass
+
+    return provenance
+
+
+def merge_provenance(
+    base_provenance: dict[str, Any],
+    additional: dict[str, Any],
+    overwrite: bool = True,
+) -> dict[str, Any]:
+    """Merge additional provenance fields into base provenance.
+
+    Args:
+        base_provenance: Base provenance dictionary.
+        additional: Additional fields to merge.
+        overwrite: If True, additional fields overwrite base fields.
+            If False, base fields take precedence.
+
+    Returns:
+        Merged provenance dictionary.
+    """
+    result = dict(base_provenance)
+    for key, value in additional.items():
+        if key not in result or overwrite:
+            result[key] = value
+    return result
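As a usage illustration (not part of the diff itself): a minimal sketch of how the new provenance helpers could be wired into an inference script. The `labels` object, file names, and extra fields below are placeholders; only the `sleap_nn.inference.provenance` functions and the `Labels.provenance` attribute come from the code above.

```python
from datetime import datetime

import sleap_io as sio

from sleap_nn.inference.provenance import build_inference_provenance, merge_provenance

# Hypothetical inputs, used only for illustration.
labels = sio.load_slp("input_labels.slp")
start = datetime.now()
# ... run inference here ...
end = datetime.now()

provenance = build_inference_provenance(
    model_paths=["models/centroid.ckpt", "models/centered_instance.ckpt"],
    model_type="top_down",
    start_time=start,
    end_time=end,
    input_labels=labels,
    input_path="input_labels.slp",
    device="cuda:0",
    include_system_info=False,  # lighter-weight provenance
)
# Merge in extra fields without overwriting what was already recorded.
provenance = merge_provenance(provenance, {"notes": "batch run 3"}, overwrite=False)

labels.provenance = provenance
labels.save("predictions.slp")
```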
sleap_nn/inference/topdown.py
CHANGED
@@ -47,9 +47,6 @@ class CentroidCrop(L.LightningModule):
         crop_hw: Tuple (height, width) representing the crop size.
         input_scale: Float indicating if the images should be resized before being
             passed to the model.
-        precrop_resize: Float indicating the factor by which the original images
-            (not images resized for centroid model) should be resized before cropping.
-            Note: This resize happens only after getting the predictions for centroid model.
         max_stride: Maximum stride in a model that the images must be divisible by.
             If > 1, this will pad the bottom and right of the images to ensure they meet
             this divisibility criteria. Padding is applied after the scaling specified
@@ -74,7 +71,6 @@ class CentroidCrop(L.LightningModule):
         return_crops: bool = False,
         crop_hw: Optional[List[int]] = None,
         input_scale: float = 1.0,
-        precrop_resize: float = 1.0,
         max_stride: int = 1,
         use_gt_centroids: bool = False,
         anchor_ind: Optional[int] = None,
@@ -92,22 +88,25 @@ class CentroidCrop(L.LightningModule):
         self.return_crops = return_crops
         self.crop_hw = crop_hw
         self.input_scale = input_scale
-        self.precrop_resize = precrop_resize
         self.max_stride = max_stride
         self.use_gt_centroids = use_gt_centroids
         self.anchor_ind = anchor_ind

-    def _generate_crops(self, inputs):
+    def _generate_crops(self, inputs, cms: Optional[torch.Tensor] = None):
         """Generate Crops from the predicted centroids."""
         crops_dict = []
-
-
-
-
-
-
-
-
+        if cms is not None:
+            cms = cms.detach()
+        for idx, (centroid, centroid_val, image, fidx, vidx, sz, eff_sc) in enumerate(
+            zip(
+                self.refined_peaks_batched,
+                self.peak_vals_batched,
+                inputs["image"],
+                inputs["frame_idx"],
+                inputs["video_idx"],
+                inputs["orig_size"],
+                inputs["eff_scale"],
+            )
         ):
             if torch.any(torch.isnan(centroid)):
                 if torch.all(torch.isnan(centroid)):
@@ -149,6 +148,11 @@ class CentroidCrop(L.LightningModule):
             ex["instance_image"] = instance_image.unsqueeze(dim=1)
             ex["orig_size"] = torch.cat([torch.Tensor(sz)] * n)
             ex["eff_scale"] = torch.Tensor([eff_sc] * n)
+            ex["pred_centroids"] = centroid
+            if self.return_confmaps:
+                ex["pred_centroid_confmaps"] = torch.cat(
+                    [cms[idx].unsqueeze(dim=0)] * n
+                )
             crops_dict.append(ex)

         return crops_dict
@@ -204,12 +208,6 @@ class CentroidCrop(L.LightningModule):

         if self.return_crops:
             crops_dict = self._generate_crops(inputs)
-            inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
-            inputs["centroids"] *= self.precrop_resize
-            scaled_refined_peaks = []
-            for ref_peak in self.refined_peaks_batched:
-                scaled_refined_peaks.append(ref_peak * self.precrop_resize)
-            self.refined_peaks_batched = scaled_refined_peaks
             return crops_dict
         else:
             return inputs
@@ -274,19 +272,13 @@ class CentroidCrop(L.LightningModule):

         # Generate crops if return_crops=True to pass the crops to CenteredInstance model.
         if self.return_crops:
-            inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
-            scaled_refined_peaks = []
-            for ref_peak in self.refined_peaks_batched:
-                scaled_refined_peaks.append(ref_peak * self.precrop_resize)
-            self.refined_peaks_batched = scaled_refined_peaks
-
             inputs.update(
                 {
                     "centroids": self.refined_peaks_batched,
                     "centroid_vals": self.peak_vals_batched,
                 }
             )
-            crops_dict = self._generate_crops(inputs)
+            crops_dict = self._generate_crops(inputs, cms)
             return crops_dict
         else:
             # batch the peaks to pass it to FindInstancePeaksGroundTruth class.
@@ -359,7 +351,11 @@ class FindInstancePeaksGroundTruth(L.LightningModule):

     def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, np.array]:
         """Return the ground truth instance peaks given a set of crops."""
-        b, _,
+        b, _, _, nodes, _ = batch["instances"].shape
+        # Use number of centroids as max_inst to ensure consistent output shape
+        # This handles the case where max_instances limits centroids but instances
+        # tensor has a different (global) max_instances from the labels file
+        num_centroids = batch["centroids"].shape[2]
         inst = (
             batch["instances"].unsqueeze(dim=-4).float()
         )  # (batch, 1, 1, n_inst, nodes, 2)
@@ -389,26 +385,26 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
         parsed = 0
         for i in range(b):
             if i not in matched_batch_inds:
-                batch_peaks = torch.full((
-                vals = torch.full((
+                batch_peaks = torch.full((num_centroids, nodes, 2), torch.nan)
+                vals = torch.full((num_centroids, nodes), torch.nan)
             else:
                 c = counts[i]
                 batch_peaks = peaks_list[parsed : parsed + c]
                 num_inst = len(batch_peaks)
                 vals = torch.ones((num_inst, nodes))
-                if c <
+                if c < num_centroids:
                     batch_peaks = torch.cat(
                         [
                             batch_peaks,
-                            torch.full((
+                            torch.full((num_centroids - num_inst, nodes, 2), torch.nan),
                         ]
                     )
                     vals = torch.cat(
-                        [vals, torch.full((
+                        [vals, torch.full((num_centroids - num_inst, nodes), torch.nan)]
                     )
                 else:
-                    batch_peaks = batch_peaks[:
-                    vals = vals[:
+                    batch_peaks = batch_peaks[:num_centroids]
+                    vals = vals[:num_centroids]
             parsed += c

         batch_peaks = batch_peaks.unsqueeze(dim=0)
@@ -432,33 +428,45 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
         peaks_output["pred_instance_peaks"] = peaks
         peaks_output["pred_peak_values"] = peaks_vals

-        batch_size
-            batch["centroids"].shape[0],
-            batch["centroids"].shape[2],
-        )
+        batch_size = batch["centroids"].shape[0]
         output_dict = {}
         output_dict["centroid"] = batch["centroids"].squeeze(dim=1).reshape(-1, 1, 2)
         output_dict["centroid_val"] = batch["centroid_vals"].reshape(-1)
-        output_dict["pred_instance_peaks"] =
-
+        output_dict["pred_instance_peaks"] = peaks_output[
+            "pred_instance_peaks"
+        ].reshape(-1, nodes, 2)
+        output_dict["pred_peak_values"] = peaks_output["pred_peak_values"].reshape(
+            -1, nodes
         )
-        output_dict["pred_peak_values"] = batch["pred_peak_values"].reshape(-1, nodes)
         output_dict["instance_bbox"] = torch.zeros(
             (batch_size * num_centroids, 1, 4, 2)
         )
         frame_inds = []
         video_inds = []
         orig_szs = []
+        images = []
+        centroid_confmaps = []
         for b_idx in range(b):
             curr_batch_size = len(batch["centroids"][b_idx][0])
             frame_inds.extend([batch["frame_idx"][b_idx]] * curr_batch_size)
             video_inds.extend([batch["video_idx"][b_idx]] * curr_batch_size)
             orig_szs.append(torch.cat([batch["orig_size"][b_idx]] * curr_batch_size))
+            images.append(
+                batch["image"][b_idx].unsqueeze(0).repeat(curr_batch_size, 1, 1, 1, 1)
+            )
+            if "pred_centroid_confmaps" in batch:
+                centroid_confmaps.append(
+                    batch["pred_centroid_confmaps"][b_idx]
+                    .unsqueeze(0)
+                    .repeat(curr_batch_size, 1, 1, 1)
+                )

         output_dict["frame_idx"] = torch.tensor(frame_inds)
         output_dict["video_idx"] = torch.tensor(video_inds)
         output_dict["orig_size"] = torch.concatenate(orig_szs, dim=0)
-
+        output_dict["image"] = torch.cat(images, dim=0)
+        if centroid_confmaps:
+            output_dict["pred_centroid_confmaps"] = torch.cat(centroid_confmaps, dim=0)
         return output_dict


@@ -548,6 +556,8 @@ class FindInstancePeaks(L.LightningModule):
         # Network forward pass.
         # resize and pad the input image
         input_image = inputs["instance_image"]
+        # resize the crop image
+        input_image = resize_image(input_image, self.input_scale)
         if self.max_stride != 1:
             input_image = apply_pad_to_stride(input_image, self.max_stride)

@@ -569,8 +579,6 @@ class FindInstancePeaks(L.LightningModule):
             inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
         )

-        inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
         inputs["instance_bbox"] = inputs["instance_bbox"] / (
             inputs["eff_scale"]
             .unsqueeze(dim=1)
@@ -679,6 +687,8 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
         # Network forward pass.
         # resize and pad the input image
         input_image = inputs["instance_image"]
+        # resize the crop image
+        input_image = resize_image(input_image, self.input_scale)
         if self.max_stride != 1:
             input_image = apply_pad_to_stride(input_image, self.max_stride)

@@ -702,8 +712,6 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
             inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
         )

-        inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
         inputs["instance_bbox"] = inputs["instance_bbox"] / (
             inputs["eff_scale"]
             .unsqueeze(dim=1)
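To make the shape bookkeeping in `FindInstancePeaksGroundTruth.forward` above concrete, here is a small self-contained sketch of the pad-or-truncate pattern that the change switches to `num_centroids`; the tensor sizes are invented purely for illustration.

```python
import torch

# Illustrative shapes only: 3 ground-truth instances matched, but the
# centroid model kept 5 centroids per frame.
nodes = 4
num_centroids = 5
batch_peaks = torch.rand(3, nodes, 2)  # matched instance peaks
vals = torch.ones(3, nodes)

num_inst = len(batch_peaks)
if num_inst < num_centroids:
    # Pad with NaNs so every frame yields (num_centroids, nodes, 2).
    batch_peaks = torch.cat(
        [batch_peaks, torch.full((num_centroids - num_inst, nodes, 2), torch.nan)]
    )
    vals = torch.cat([vals, torch.full((num_centroids - num_inst, nodes), torch.nan)])
else:
    # Truncate if more instances were matched than centroids kept.
    batch_peaks = batch_peaks[:num_centroids]
    vals = vals[:num_centroids]

assert batch_peaks.shape == (num_centroids, nodes, 2)
assert vals.shape == (num_centroids, nodes)
```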
sleap_nn/legacy_models.py
CHANGED
@@ -7,9 +7,8 @@ TensorFlow/Keras backend to PyTorch format compatible with sleap-nn.
 import h5py
 import numpy as np
 import torch
-from typing import Dict,
+from typing import Dict, Any, Optional
 from pathlib import Path
-from omegaconf import OmegaConf
 import re
 from loguru import logger

@@ -181,18 +180,61 @@ def parse_keras_layer_name(layer_path: str) -> Dict[str, Any]:
     return info


+def filter_legacy_weights_by_component(
+    legacy_weights: Dict[str, np.ndarray], component: Optional[str]
+) -> Dict[str, np.ndarray]:
+    """Filter legacy weights based on component type.
+
+    Args:
+        legacy_weights: Dictionary of legacy weights from load_keras_weights()
+        component: Component type to filter for. One of:
+            - "backbone": Keep only encoder/decoder weights (exclude heads)
+            - "head": Keep only head layer weights
+            - None: No filtering (keep all weights)
+
+    Returns:
+        Filtered dictionary of legacy weights
+    """
+    if component is None:
+        return legacy_weights
+
+    filtered = {}
+    for path, weight in legacy_weights.items():
+        # Check if this is a head layer (contains "Head" in the path)
+        is_head_layer = "Head" in path
+
+        if component == "backbone" and not is_head_layer:
+            filtered[path] = weight
+        elif component == "head" and is_head_layer:
+            filtered[path] = weight
+
+    return filtered
+
+
 def map_legacy_to_pytorch_layers(
-    legacy_weights: Dict[str, np.ndarray],
+    legacy_weights: Dict[str, np.ndarray],
+    pytorch_model: torch.nn.Module,
+    component: Optional[str] = None,
 ) -> Dict[str, str]:
     """Create mapping between legacy Keras layers and PyTorch model layers.

     Args:
         legacy_weights: Dictionary of legacy weights from load_keras_weights()
         pytorch_model: PyTorch model instance to map to
+        component: Optional component type for filtering weights before mapping.
+            One of "backbone", "head", or None (no filtering).

     Returns:
         Dictionary mapping legacy layer paths to PyTorch parameter names
     """
+    # Filter weights based on component type
+    filtered_weights = filter_legacy_weights_by_component(legacy_weights, component)
+
+    if component is not None:
+        logger.info(
+            f"Filtered legacy weights for {component}: "
+            f"{len(filtered_weights)}/{len(legacy_weights)} weights"
+        )
     mapping = {}

     # Get all PyTorch parameters with their shapes
@@ -201,7 +243,7 @@ def map_legacy_to_pytorch_layers(
         pytorch_params[name] = param.shape

     # For each legacy weight, find the corresponding PyTorch parameter
-    for legacy_path, weight in
+    for legacy_path, weight in filtered_weights.items():
         # Extract the layer name from the legacy path
         # Legacy path format: "model_weights/stack0_enc0_conv0/stack0_enc0_conv0/kernel:0"
         clean_path = legacy_path.replace("model_weights/", "")
@@ -220,8 +262,6 @@ def map_legacy_to_pytorch_layers(
         # This handles cases where Keras uses suffixes like _0, _1, etc.
         if "Head" in layer_name:
             # Remove trailing _N where N is a number
-            import re
-
             layer_name_clean = re.sub(r"_\d+$", "", layer_name)
         else:
             layer_name_clean = layer_name
@@ -266,12 +306,17 @@ def map_legacy_to_pytorch_layers(
     if not mapping:
         logger.info(
             f"No mappings could be created between legacy weights and PyTorch model. "
-            f"Legacy weights: {len(
+            f"Legacy weights: {len(filtered_weights)}, PyTorch parameters: {len(pytorch_params)}"
         )
     else:
         logger.info(
-            f"Successfully mapped {len(mapping)}/{len(
+            f"Successfully mapped {len(mapping)}/{len(pytorch_params)} PyTorch parameters from legacy weights"
         )
+        unmatched_count = len(filtered_weights) - len(mapping)
+        if unmatched_count > 0:
+            logger.warning(
+                f"({unmatched_count} legacy weights did not match any parameters in this model component)"
+            )

     return mapping

@@ -280,6 +325,7 @@ def load_legacy_model_weights(
     pytorch_model: torch.nn.Module,
     h5_path: str,
     mapping: Optional[Dict[str, str]] = None,
+    component: Optional[str] = None,
 ) -> None:
     """Load legacy Keras weights into a PyTorch model.

@@ -288,6 +334,10 @@ def load_legacy_model_weights(
         h5_path: Path to the legacy .h5 model file
         mapping: Optional manual mapping of layer names. If None,
             will attempt automatic mapping.
+        component: Optional component type for filtering weights. One of:
+            - "backbone": Only load encoder/decoder weights (exclude heads)
+            - "head": Only load head layer weights
+            - None: Load all weights (default, for full model loading)
     """
     # Load legacy weights
     legacy_weights = load_keras_weights(h5_path)
@@ -295,7 +345,9 @@ def load_legacy_model_weights(
     if mapping is None:
         # Attempt automatic mapping
         try:
-            mapping = map_legacy_to_pytorch_layers(
+            mapping = map_legacy_to_pytorch_layers(
+                legacy_weights, pytorch_model, component=component
+            )
         except Exception as e:
             logger.error(f"Failed to create weight mappings: {e}")
             return
@@ -417,7 +469,9 @@ def load_legacy_model_weights(
                 ).item()
                 diff = abs(keras_mean - torch_mean)
                 if diff > 1e-6:
-                    message = f"Weight verification failed for {pytorch_name} linear): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
+                    message = f"Weight verification failed for {pytorch_name} (linear): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
+                    logger.error(message)
+                    verification_errors.append(message)
             else:
                 # Bias : just compare all values
                 keras_mean = np.mean(original_weight)
@@ -426,7 +480,7 @@ def load_legacy_model_weights(
                 ).item()
                 diff = abs(keras_mean - torch_mean)
                 if diff > 1e-6:
-                    message = f"Weight verification failed for {pytorch_name} bias): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
+                    message = f"Weight verification failed for {pytorch_name} (bias): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
                     logger.error(message)
                     verification_errors.append(message)

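As a usage sketch of the new `component` filtering in `sleap_nn.legacy_models` (the weight paths below are toy values in the legacy Keras layout shown above; a real dictionary would come from `load_keras_weights()` on a legacy SLEAP .h5 checkpoint):

```python
import numpy as np

from sleap_nn.legacy_models import filter_legacy_weights_by_component

# Toy weight dict: one encoder conv kernel and one head kernel.
legacy_weights = {
    "model_weights/stack0_enc0_conv0/stack0_enc0_conv0/kernel:0": np.zeros((3, 3, 1, 16)),
    "model_weights/CentroidConfmapsHead_0/CentroidConfmapsHead_0/kernel:0": np.zeros((1, 1, 16, 1)),
}

backbone_only = filter_legacy_weights_by_component(legacy_weights, "backbone")
head_only = filter_legacy_weights_by_component(legacy_weights, "head")

print(list(backbone_only))  # encoder/decoder weights only
print(list(head_only))      # head weights only

# With a full model in hand, the same filtering happens internally via:
#   load_legacy_model_weights(pytorch_model, h5_path, component="backbone")
```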