sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/inference/provenance.py ADDED
@@ -0,0 +1,292 @@
+"""Provenance metadata utilities for inference outputs.
+
+This module provides utilities for building and managing provenance metadata
+that is stored in SLP files produced during inference. Provenance metadata
+helps track where predictions came from and how they were generated.
+"""
+
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import sleap_io as sio
+
+import sleap_nn
+from sleap_nn.system_info import get_system_info_dict
+
+
+def build_inference_provenance(
+    model_paths: Optional[list[str]] = None,
+    model_type: Optional[str] = None,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    input_labels: Optional[sio.Labels] = None,
+    input_path: Optional[Union[str, Path]] = None,
+    frames_processed: Optional[int] = None,
+    frames_total: Optional[int] = None,
+    frame_selection_method: Optional[str] = None,
+    inference_params: Optional[dict[str, Any]] = None,
+    tracking_params: Optional[dict[str, Any]] = None,
+    device: Optional[str] = None,
+    cli_args: Optional[dict[str, Any]] = None,
+    include_system_info: bool = True,
+) -> dict[str, Any]:
+    """Build provenance metadata dictionary for inference output.
+
+    This function creates a comprehensive provenance dictionary that captures
+    all relevant metadata about an inference run, enabling reproducibility
+    and tracking of prediction origins.
+
+    Args:
+        model_paths: List of paths to model checkpoints used for inference.
+        model_type: Type of model used (e.g., "top_down", "bottom_up",
+            "single_instance").
+        start_time: Datetime when inference started.
+        end_time: Datetime when inference finished.
+        input_labels: Input Labels object if inference was run on an SLP file.
+            The provenance from this object will be preserved.
+        input_path: Path to input file (SLP or video).
+        frames_processed: Number of frames that were processed.
+        frames_total: Total number of frames in the input.
+        frame_selection_method: Method used to select frames (e.g., "all",
+            "labeled", "suggested", "range").
+        inference_params: Dictionary of inference parameters (peak_threshold,
+            integral_refinement, batch_size, etc.).
+        tracking_params: Dictionary of tracking parameters if tracking was run.
+        device: Device used for inference (e.g., "cuda:0", "cpu", "mps").
+        cli_args: Command-line arguments if available.
+        include_system_info: If True, include detailed system information.
+            Set to False for lighter-weight provenance.
+
+    Returns:
+        Dictionary containing provenance metadata suitable for storing in
+        Labels.provenance.
+
+    Example:
+        >>> from datetime import datetime
+        >>> provenance = build_inference_provenance(
+        ...     model_paths=["/path/to/model.ckpt"],
+        ...     model_type="top_down",
+        ...     start_time=datetime.now(),
+        ...     end_time=datetime.now(),
+        ...     device="cuda:0",
+        ... )
+        >>> labels.provenance = provenance
+        >>> labels.save("predictions.slp")
+    """
+    provenance: dict[str, Any] = {}
+
+    # Timestamps
+    if start_time is not None:
+        provenance["inference_start_timestamp"] = start_time.isoformat()
+    if end_time is not None:
+        provenance["inference_end_timestamp"] = end_time.isoformat()
+    if start_time is not None and end_time is not None:
+        runtime_seconds = (end_time - start_time).total_seconds()
+        provenance["inference_runtime_seconds"] = runtime_seconds
+
+    # Version information
+    provenance["sleap_nn_version"] = sleap_nn.__version__
+    provenance["sleap_io_version"] = sio.__version__
+
+    # Model information
+    if model_paths is not None:
+        # Store as absolute POSIX paths for cross-platform compatibility
+        provenance["model_paths"] = [
+            Path(p).resolve().as_posix() if isinstance(p, (str, Path)) else str(p)
+            for p in model_paths
+        ]
+    if model_type is not None:
+        provenance["model_type"] = model_type
+
+    # Input data lineage
+    if input_path is not None:
+        provenance["source_file"] = (
+            Path(input_path).resolve().as_posix()
+            if isinstance(input_path, (str, Path))
+            else str(input_path)
+        )
+
+    # Preserve input provenance if available
+    if input_labels is not None and hasattr(input_labels, "provenance"):
+        input_prov = dict(input_labels.provenance)
+        if input_prov:
+            provenance["input_provenance"] = input_prov
+            # Also set source_labels for compatibility with sleap-io conventions
+            if "filename" in input_prov:
+                provenance["source_labels"] = input_prov["filename"]
+
+    # Frame selection information
+    if frames_processed is not None or frames_total is not None:
+        frame_info: dict[str, Any] = {}
+        if frame_selection_method is not None:
+            frame_info["method"] = frame_selection_method
+        if frames_processed is not None:
+            frame_info["frames_processed"] = frames_processed
+        if frames_total is not None:
+            frame_info["frames_total"] = frames_total
+        if frame_info:
+            provenance["frame_selection"] = frame_info
+
+    # Inference parameters
+    if inference_params is not None:
+        # Filter out None values and convert paths
+        clean_params = {}
+        for key, value in inference_params.items():
+            if value is not None:
+                if isinstance(value, Path):
+                    clean_params[key] = value.as_posix()
+                else:
+                    clean_params[key] = value
+        if clean_params:
+            provenance["inference_config"] = clean_params
+
+    # Tracking parameters
+    if tracking_params is not None:
+        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
+        if clean_tracking:
+            provenance["tracking_config"] = clean_tracking
+
+    # Device information
+    if device is not None:
+        provenance["device"] = device
+
+    # CLI arguments
+    if cli_args is not None:
+        # Filter out None values
+        clean_cli = {k: v for k, v in cli_args.items() if v is not None}
+        if clean_cli:
+            provenance["cli_args"] = clean_cli
+
+    # System information (can be disabled for lighter provenance)
+    if include_system_info:
+        try:
+            system_info = get_system_info_dict()
+            # Extract key fields for provenance (avoid excessive nesting)
+            provenance["system_info"] = {
+                "python_version": system_info.get("python_version"),
+                "platform": system_info.get("platform"),
+                "pytorch_version": system_info.get("pytorch_version"),
+                "cuda_version": system_info.get("cuda_version"),
+                "accelerator": system_info.get("accelerator"),
+                "gpu_count": system_info.get("gpu_count"),
+            }
+            # Include GPU names if available
+            if system_info.get("gpus"):
+                provenance["system_info"]["gpus"] = [
+                    gpu.get("name") for gpu in system_info["gpus"]
+                ]
+        except Exception:
+            # Don't fail inference if system info collection fails
+            pass
+
+    return provenance
+
+
+def build_tracking_only_provenance(
+    input_labels: Optional[sio.Labels] = None,
+    input_path: Optional[Union[str, Path]] = None,
+    start_time: Optional[datetime] = None,
+    end_time: Optional[datetime] = None,
+    tracking_params: Optional[dict[str, Any]] = None,
+    frames_processed: Optional[int] = None,
+    include_system_info: bool = True,
+) -> dict[str, Any]:
+    """Build provenance metadata for tracking-only pipeline.
+
+    This is a simplified version of build_inference_provenance for when
+    only tracking is run without model inference.
+
+    Args:
+        input_labels: Input Labels object with existing predictions.
+        input_path: Path to input SLP file.
+        start_time: Datetime when tracking started.
+        end_time: Datetime when tracking finished.
+        tracking_params: Dictionary of tracking parameters.
+        frames_processed: Number of frames that were tracked.
+        include_system_info: If True, include system information.
+
+    Returns:
+        Dictionary containing provenance metadata.
+    """
+    provenance: dict[str, Any] = {}
+
+    # Timestamps
+    if start_time is not None:
+        provenance["tracking_start_timestamp"] = start_time.isoformat()
+    if end_time is not None:
+        provenance["tracking_end_timestamp"] = end_time.isoformat()
+    if start_time is not None and end_time is not None:
+        runtime_seconds = (end_time - start_time).total_seconds()
+        provenance["tracking_runtime_seconds"] = runtime_seconds
+
+    # Version information
+    provenance["sleap_nn_version"] = sleap_nn.__version__
+    provenance["sleap_io_version"] = sio.__version__
+
+    # Note that this is tracking-only
+    provenance["pipeline_type"] = "tracking_only"
+
+    # Input data lineage
+    if input_path is not None:
+        provenance["source_file"] = (
+            Path(input_path).resolve().as_posix()
+            if isinstance(input_path, (str, Path))
+            else str(input_path)
+        )
+
+    # Preserve input provenance if available
+    if input_labels is not None and hasattr(input_labels, "provenance"):
+        input_prov = dict(input_labels.provenance)
+        if input_prov:
+            provenance["input_provenance"] = input_prov
+            if "filename" in input_prov:
+                provenance["source_labels"] = input_prov["filename"]
+
+    # Frame information
+    if frames_processed is not None:
+        provenance["frames_processed"] = frames_processed
+
+    # Tracking parameters
+    if tracking_params is not None:
+        clean_tracking = {k: v for k, v in tracking_params.items() if v is not None}
+        if clean_tracking:
+            provenance["tracking_config"] = clean_tracking
+
+    # System information
+    if include_system_info:
+        try:
+            system_info = get_system_info_dict()
+            provenance["system_info"] = {
+                "python_version": system_info.get("python_version"),
+                "platform": system_info.get("platform"),
+                "pytorch_version": system_info.get("pytorch_version"),
+                "accelerator": system_info.get("accelerator"),
+            }
+        except Exception:
+            pass
+
+    return provenance
+
+
+def merge_provenance(
+    base_provenance: dict[str, Any],
+    additional: dict[str, Any],
+    overwrite: bool = True,
+) -> dict[str, Any]:
+    """Merge additional provenance fields into base provenance.
+
+    Args:
+        base_provenance: Base provenance dictionary.
+        additional: Additional fields to merge.
+        overwrite: If True, additional fields overwrite base fields.
+            If False, base fields take precedence.
+
+    Returns:
+        Merged provenance dictionary.
+    """
+    result = dict(base_provenance)
+    for key, value in additional.items():
+        if key not in result or overwrite:
+            result[key] = value
+    return result
sleap_nn/inference/topdown.py CHANGED
@@ -47,9 +47,6 @@ class CentroidCrop(L.LightningModule):
         crop_hw: Tuple (height, width) representing the crop size.
         input_scale: Float indicating if the images should be resized before being
             passed to the model.
-        precrop_resize: Float indicating the factor by which the original images
-            (not images resized for centroid model) should be resized before cropping.
-            Note: This resize happens only after getting the predictions for centroid model.
         max_stride: Maximum stride in a model that the images must be divisible by.
             If > 1, this will pad the bottom and right of the images to ensure they meet
             this divisibility criteria. Padding is applied after the scaling specified
@@ -74,7 +71,6 @@ class CentroidCrop(L.LightningModule):
         return_crops: bool = False,
         crop_hw: Optional[List[int]] = None,
         input_scale: float = 1.0,
-        precrop_resize: float = 1.0,
         max_stride: int = 1,
         use_gt_centroids: bool = False,
         anchor_ind: Optional[int] = None,
@@ -92,22 +88,25 @@ class CentroidCrop(L.LightningModule):
         self.return_crops = return_crops
         self.crop_hw = crop_hw
         self.input_scale = input_scale
-        self.precrop_resize = precrop_resize
         self.max_stride = max_stride
         self.use_gt_centroids = use_gt_centroids
         self.anchor_ind = anchor_ind
 
-    def _generate_crops(self, inputs):
+    def _generate_crops(self, inputs, cms: Optional[torch.Tensor] = None):
         """Generate Crops from the predicted centroids."""
         crops_dict = []
-        for centroid, centroid_val, image, fidx, vidx, sz, eff_sc in zip(
-            self.refined_peaks_batched,
-            self.peak_vals_batched,
-            inputs["image"],
-            inputs["frame_idx"],
-            inputs["video_idx"],
-            inputs["orig_size"],
-            inputs["eff_scale"],
+        if cms is not None:
+            cms = cms.detach()
+        for idx, (centroid, centroid_val, image, fidx, vidx, sz, eff_sc) in enumerate(
+            zip(
+                self.refined_peaks_batched,
+                self.peak_vals_batched,
+                inputs["image"],
+                inputs["frame_idx"],
+                inputs["video_idx"],
+                inputs["orig_size"],
+                inputs["eff_scale"],
+            )
         ):
             if torch.any(torch.isnan(centroid)):
                 if torch.all(torch.isnan(centroid)):
@@ -149,6 +148,11 @@
             ex["instance_image"] = instance_image.unsqueeze(dim=1)
             ex["orig_size"] = torch.cat([torch.Tensor(sz)] * n)
             ex["eff_scale"] = torch.Tensor([eff_sc] * n)
+            ex["pred_centroids"] = centroid
+            if self.return_confmaps:
+                ex["pred_centroid_confmaps"] = torch.cat(
+                    [cms[idx].unsqueeze(dim=0)] * n
+                )
             crops_dict.append(ex)
 
         return crops_dict
@@ -204,12 +208,6 @@ class CentroidCrop(L.LightningModule):
 
             if self.return_crops:
                 crops_dict = self._generate_crops(inputs)
-                inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
-                inputs["centroids"] *= self.precrop_resize
-                scaled_refined_peaks = []
-                for ref_peak in self.refined_peaks_batched:
-                    scaled_refined_peaks.append(ref_peak * self.precrop_resize)
-                self.refined_peaks_batched = scaled_refined_peaks
                 return crops_dict
             else:
                 return inputs
@@ -274,19 +272,13 @@ class CentroidCrop(L.LightningModule):
 
         # Generate crops if return_crops=True to pass the crops to CenteredInstance model.
        if self.return_crops:
-            inputs["image"] = resize_image(inputs["image"], self.precrop_resize)
-            scaled_refined_peaks = []
-            for ref_peak in self.refined_peaks_batched:
-                scaled_refined_peaks.append(ref_peak * self.precrop_resize)
-            self.refined_peaks_batched = scaled_refined_peaks
-
             inputs.update(
                 {
                     "centroids": self.refined_peaks_batched,
                     "centroid_vals": self.peak_vals_batched,
                 }
             )
-            crops_dict = self._generate_crops(inputs)
+            crops_dict = self._generate_crops(inputs, cms)
             return crops_dict
         else:
             # batch the peaks to pass it to FindInstancePeaksGroundTruth class.
@@ -359,7 +351,11 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
 
     def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, np.array]:
         """Return the ground truth instance peaks given a set of crops."""
-        b, _, max_inst, nodes, _ = batch["instances"].shape
+        b, _, _, nodes, _ = batch["instances"].shape
+        # Use number of centroids as max_inst to ensure consistent output shape
+        # This handles the case where max_instances limits centroids but instances
+        # tensor has a different (global) max_instances from the labels file
+        num_centroids = batch["centroids"].shape[2]
         inst = (
             batch["instances"].unsqueeze(dim=-4).float()
         )  # (batch, 1, 1, n_inst, nodes, 2)
@@ -389,26 +385,26 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
         parsed = 0
         for i in range(b):
             if i not in matched_batch_inds:
-                batch_peaks = torch.full((max_inst, nodes, 2), torch.nan)
-                vals = torch.full((max_inst, nodes), torch.nan)
+                batch_peaks = torch.full((num_centroids, nodes, 2), torch.nan)
+                vals = torch.full((num_centroids, nodes), torch.nan)
             else:
                 c = counts[i]
                 batch_peaks = peaks_list[parsed : parsed + c]
                 num_inst = len(batch_peaks)
                 vals = torch.ones((num_inst, nodes))
-                if c < max_inst:
+                if c < num_centroids:
                     batch_peaks = torch.cat(
                         [
                             batch_peaks,
-                            torch.full((max_inst - num_inst, nodes, 2), torch.nan),
+                            torch.full((num_centroids - num_inst, nodes, 2), torch.nan),
                         ]
                     )
                     vals = torch.cat(
-                        [vals, torch.full((max_inst - num_inst, nodes), torch.nan)]
+                        [vals, torch.full((num_centroids - num_inst, nodes), torch.nan)]
                     )
                 else:
-                    batch_peaks = batch_peaks[:max_inst]
-                    vals = vals[:max_inst]
+                    batch_peaks = batch_peaks[:num_centroids]
+                    vals = vals[:num_centroids]
                 parsed += c
 
             batch_peaks = batch_peaks.unsqueeze(dim=0)
@@ -432,33 +428,45 @@ class FindInstancePeaksGroundTruth(L.LightningModule):
         peaks_output["pred_instance_peaks"] = peaks
         peaks_output["pred_peak_values"] = peaks_vals
 
-        batch_size, num_centroids = (
-            batch["centroids"].shape[0],
-            batch["centroids"].shape[2],
-        )
+        batch_size = batch["centroids"].shape[0]
         output_dict = {}
         output_dict["centroid"] = batch["centroids"].squeeze(dim=1).reshape(-1, 1, 2)
         output_dict["centroid_val"] = batch["centroid_vals"].reshape(-1)
-        output_dict["pred_instance_peaks"] = batch["pred_instance_peaks"].reshape(
-            -1, nodes, 2
+        output_dict["pred_instance_peaks"] = peaks_output[
+            "pred_instance_peaks"
+        ].reshape(-1, nodes, 2)
+        output_dict["pred_peak_values"] = peaks_output["pred_peak_values"].reshape(
+            -1, nodes
         )
-        output_dict["pred_peak_values"] = batch["pred_peak_values"].reshape(-1, nodes)
         output_dict["instance_bbox"] = torch.zeros(
             (batch_size * num_centroids, 1, 4, 2)
         )
         frame_inds = []
         video_inds = []
         orig_szs = []
+        images = []
+        centroid_confmaps = []
         for b_idx in range(b):
             curr_batch_size = len(batch["centroids"][b_idx][0])
             frame_inds.extend([batch["frame_idx"][b_idx]] * curr_batch_size)
             video_inds.extend([batch["video_idx"][b_idx]] * curr_batch_size)
             orig_szs.append(torch.cat([batch["orig_size"][b_idx]] * curr_batch_size))
+            images.append(
+                batch["image"][b_idx].unsqueeze(0).repeat(curr_batch_size, 1, 1, 1, 1)
+            )
+            if "pred_centroid_confmaps" in batch:
+                centroid_confmaps.append(
+                    batch["pred_centroid_confmaps"][b_idx]
+                    .unsqueeze(0)
+                    .repeat(curr_batch_size, 1, 1, 1)
+                )
 
         output_dict["frame_idx"] = torch.tensor(frame_inds)
         output_dict["video_idx"] = torch.tensor(video_inds)
         output_dict["orig_size"] = torch.concatenate(orig_szs, dim=0)
-
+        output_dict["image"] = torch.cat(images, dim=0)
+        if centroid_confmaps:
+            output_dict["pred_centroid_confmaps"] = torch.cat(centroid_confmaps, dim=0)
         return output_dict
 
 
@@ -548,6 +556,8 @@ class FindInstancePeaks(L.LightningModule):
         # Network forward pass.
         # resize and pad the input image
         input_image = inputs["instance_image"]
+        # resize the crop image
+        input_image = resize_image(input_image, self.input_scale)
         if self.max_stride != 1:
             input_image = apply_pad_to_stride(input_image, self.max_stride)
 
@@ -569,8 +579,6 @@ class FindInstancePeaks(L.LightningModule):
             inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
         )
 
-        inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
         inputs["instance_bbox"] = inputs["instance_bbox"] / (
             inputs["eff_scale"]
             .unsqueeze(dim=1)
@@ -679,6 +687,8 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
         # Network forward pass.
         # resize and pad the input image
         input_image = inputs["instance_image"]
+        # resize the crop image
+        input_image = resize_image(input_image, self.input_scale)
         if self.max_stride != 1:
             input_image = apply_pad_to_stride(input_image, self.max_stride)
 
@@ -702,8 +712,6 @@ class TopDownMultiClassFindInstancePeaks(L.LightningModule):
             inputs["eff_scale"].unsqueeze(dim=1).unsqueeze(dim=2).to(peak_points.device)
         )
 
-        inputs["instance_bbox"] = inputs["instance_bbox"] / self.input_scale
-
         inputs["instance_bbox"] = inputs["instance_bbox"] / (
             inputs["eff_scale"]
            .unsqueeze(dim=1)
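
The crop-resizing behavior that precrop_resize used to provide now happens inside the instance-peak modules themselves: each crop is rescaled by input_scale and padded to the model's max_stride right before the forward pass. A small sketch of that preprocessing order, assuming the helpers live in sleap_nn.data.resizing and using a dummy tensor in place of a real crop batch:

    import torch
    from sleap_nn.data.resizing import apply_pad_to_stride, resize_image

    # Dummy batch of instance crops: (batch, channels, height, width).
    instance_image = torch.rand(4, 1, 160, 160)

    input_scale = 0.5   # same role as FindInstancePeaks.input_scale
    max_stride = 16     # same role as FindInstancePeaks.max_stride

    # Resize the crop first, then pad so height/width are divisible by max_stride.
    input_image = resize_image(instance_image, input_scale)
    if max_stride != 1:
        input_image = apply_pad_to_stride(input_image, max_stride)

    print(input_image.shape)  # e.g. torch.Size([4, 1, 80, 80])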
sleap_nn/legacy_models.py CHANGED
@@ -7,9 +7,8 @@ TensorFlow/Keras backend to PyTorch format compatible with sleap-nn.
 import h5py
 import numpy as np
 import torch
-from typing import Dict, Tuple, Any, Optional, List
+from typing import Dict, Any, Optional
 from pathlib import Path
-from omegaconf import OmegaConf
 import re
 from loguru import logger
 
@@ -181,18 +180,61 @@ def parse_keras_layer_name(layer_path: str) -> Dict[str, Any]:
     return info
 
 
+def filter_legacy_weights_by_component(
+    legacy_weights: Dict[str, np.ndarray], component: Optional[str]
+) -> Dict[str, np.ndarray]:
+    """Filter legacy weights based on component type.
+
+    Args:
+        legacy_weights: Dictionary of legacy weights from load_keras_weights()
+        component: Component type to filter for. One of:
+            - "backbone": Keep only encoder/decoder weights (exclude heads)
+            - "head": Keep only head layer weights
+            - None: No filtering (keep all weights)
+
+    Returns:
+        Filtered dictionary of legacy weights
+    """
+    if component is None:
+        return legacy_weights
+
+    filtered = {}
+    for path, weight in legacy_weights.items():
+        # Check if this is a head layer (contains "Head" in the path)
+        is_head_layer = "Head" in path
+
+        if component == "backbone" and not is_head_layer:
+            filtered[path] = weight
+        elif component == "head" and is_head_layer:
+            filtered[path] = weight
+
+    return filtered
+
+
 def map_legacy_to_pytorch_layers(
-    legacy_weights: Dict[str, np.ndarray], pytorch_model: torch.nn.Module
+    legacy_weights: Dict[str, np.ndarray],
+    pytorch_model: torch.nn.Module,
+    component: Optional[str] = None,
 ) -> Dict[str, str]:
     """Create mapping between legacy Keras layers and PyTorch model layers.
 
     Args:
         legacy_weights: Dictionary of legacy weights from load_keras_weights()
         pytorch_model: PyTorch model instance to map to
+        component: Optional component type for filtering weights before mapping.
+            One of "backbone", "head", or None (no filtering).
 
     Returns:
         Dictionary mapping legacy layer paths to PyTorch parameter names
     """
+    # Filter weights based on component type
+    filtered_weights = filter_legacy_weights_by_component(legacy_weights, component)
+
+    if component is not None:
+        logger.info(
+            f"Filtered legacy weights for {component}: "
+            f"{len(filtered_weights)}/{len(legacy_weights)} weights"
+        )
     mapping = {}
 
     # Get all PyTorch parameters with their shapes
@@ -201,7 +243,7 @@ def map_legacy_to_pytorch_layers(
         pytorch_params[name] = param.shape
 
     # For each legacy weight, find the corresponding PyTorch parameter
-    for legacy_path, weight in legacy_weights.items():
+    for legacy_path, weight in filtered_weights.items():
         # Extract the layer name from the legacy path
         # Legacy path format: "model_weights/stack0_enc0_conv0/stack0_enc0_conv0/kernel:0"
         clean_path = legacy_path.replace("model_weights/", "")
@@ -220,8 +262,6 @@ def map_legacy_to_pytorch_layers(
         # This handles cases where Keras uses suffixes like _0, _1, etc.
         if "Head" in layer_name:
             # Remove trailing _N where N is a number
-            import re
-
             layer_name_clean = re.sub(r"_\d+$", "", layer_name)
         else:
             layer_name_clean = layer_name
@@ -266,12 +306,17 @@ def map_legacy_to_pytorch_layers(
     if not mapping:
         logger.info(
             f"No mappings could be created between legacy weights and PyTorch model. "
-            f"Legacy weights: {len(legacy_weights)}, PyTorch parameters: {len(pytorch_params)}"
+            f"Legacy weights: {len(filtered_weights)}, PyTorch parameters: {len(pytorch_params)}"
         )
     else:
         logger.info(
-            f"Successfully mapped {len(mapping)}/{len(legacy_weights)} legacy weights to PyTorch parameters"
+            f"Successfully mapped {len(mapping)}/{len(pytorch_params)} PyTorch parameters from legacy weights"
         )
+        unmatched_count = len(filtered_weights) - len(mapping)
+        if unmatched_count > 0:
+            logger.warning(
+                f"({unmatched_count} legacy weights did not match any parameters in this model component)"
+            )
 
     return mapping
 
@@ -280,6 +325,7 @@ def load_legacy_model_weights(
     pytorch_model: torch.nn.Module,
     h5_path: str,
     mapping: Optional[Dict[str, str]] = None,
+    component: Optional[str] = None,
 ) -> None:
     """Load legacy Keras weights into a PyTorch model.
 
@@ -288,6 +334,10 @@ def load_legacy_model_weights(
         h5_path: Path to the legacy .h5 model file
         mapping: Optional manual mapping of layer names. If None,
             will attempt automatic mapping.
+        component: Optional component type for filtering weights. One of:
+            - "backbone": Only load encoder/decoder weights (exclude heads)
+            - "head": Only load head layer weights
+            - None: Load all weights (default, for full model loading)
     """
     # Load legacy weights
     legacy_weights = load_keras_weights(h5_path)
@@ -295,7 +345,9 @@ def load_legacy_model_weights(
     if mapping is None:
         # Attempt automatic mapping
        try:
-            mapping = map_legacy_to_pytorch_layers(legacy_weights, pytorch_model)
+            mapping = map_legacy_to_pytorch_layers(
+                legacy_weights, pytorch_model, component=component
+            )
         except Exception as e:
             logger.error(f"Failed to create weight mappings: {e}")
             return
@@ -417,7 +469,9 @@ def load_legacy_model_weights(
                 ).item()
                 diff = abs(keras_mean - torch_mean)
                 if diff > 1e-6:
-                    message = f"Weight verification failed for {pytorch_name} linear): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
+                    message = f"Weight verification failed for {pytorch_name} (linear): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
+                    logger.error(message)
+                    verification_errors.append(message)
             else:
                 # Bias : just compare all values
                 keras_mean = np.mean(original_weight)
@@ -426,7 +480,7 @@ def load_legacy_model_weights(
                 ).item()
                 diff = abs(keras_mean - torch_mean)
                 if diff > 1e-6:
-                    message = f"Weight verification failed for {pytorch_name} bias): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
+                    message = f"Weight verification failed for {pytorch_name} (bias): keras={keras_mean:.6f}, torch={torch_mean:.6f}, diff={diff:.6e}"
                     logger.error(message)
                     verification_errors.append(message)
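
The new component argument makes it possible to pull only part of a legacy SLEAP model into a PyTorch module, for example reusing a trained backbone while leaving the head layers freshly initialized. A minimal sketch, assuming my_model is a sleap-nn PyTorch model whose layer names line up with the legacy checkpoint and that the .h5 path is hypothetical:

    from sleap_nn.legacy_models import (
        filter_legacy_weights_by_component,
        load_keras_weights,
        load_legacy_model_weights,
    )

    h5_path = "legacy_models/centroid_model/best_model.h5"  # hypothetical path

    # Inspect how many weights belong to the backbone vs. the heads.
    legacy_weights = load_keras_weights(h5_path)
    backbone_only = filter_legacy_weights_by_component(legacy_weights, "backbone")
    print(f"{len(backbone_only)}/{len(legacy_weights)} weights are backbone weights")

    # Load just the encoder/decoder weights; head layers keep their fresh init.
    load_legacy_model_weights(my_model, h5_path, component="backbone")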