sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,11 +8,79 @@ import torch
 from kornia.geometry.transform import crop_and_resize
 
 
+def compute_augmentation_padding(
+    bbox_size: float,
+    rotation_max: float = 0.0,
+    scale_max: float = 1.0,
+) -> int:
+    """Compute padding needed to accommodate augmentation transforms.
+
+    When rotation and scaling augmentations are applied, the bounding box of an
+    instance can expand beyond its original size. This function calculates the
+    padding needed to ensure the full instance remains visible after augmentation.
+
+    Args:
+        bbox_size: The size of the instance bounding box (max of width/height).
+        rotation_max: Maximum absolute rotation angle in degrees. For symmetric
+            rotation ranges like [-180, 180], pass 180.
+        scale_max: Maximum scaling factor. For scale range [0.9, 1.1], pass 1.1.
+
+    Returns:
+        Padding in pixels to add around the bounding box (total, not per side).
+    """
+    if rotation_max == 0.0 and scale_max <= 1.0:
+        return 0
+
+    # For a square bbox rotated by angle θ, the new bbox has side length:
+    #   L' = L * (|cos(θ)| + |sin(θ)|)
+    # This factor is 1 at 0° and 90° and peaks at 45°, where L' = L * sqrt(2).
+    # We use the worst case within the given rotation range.
+    rotation_rad = math.radians(min(abs(rotation_max), 90))
+    rotation_factor = abs(math.cos(rotation_rad)) + abs(math.sin(rotation_rad))
+
+    # The factor grows up to 45° and then falls back toward 1 as the angle
+    # approaches 90°, so the worst case for any rotation range that reaches 45°
+    # is sqrt(2).
+    if abs(rotation_max) >= 45:
+        rotation_factor = math.sqrt(2)
+
+    # Combined expansion factor.
+    expansion_factor = rotation_factor * max(scale_max, 1.0)
+
+    # Total padding needed (both sides).
+    expanded_size = bbox_size * expansion_factor
+    padding = expanded_size - bbox_size
+
+    return int(math.ceil(padding))
+
+
+def find_max_instance_bbox_size(labels: sio.Labels) -> float:
+    """Find the maximum bounding box dimension across all instances in labels.
+
+    Args:
+        labels: A `sio.Labels` containing user-labeled instances.
+
+    Returns:
+        The maximum bounding box dimension (max of width or height) across all instances.
+    """
+    max_length = 0.0
+    for lf in labels:
+        for inst in lf.instances:
+            if not inst.is_empty:
+                pts = inst.numpy()
+                diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
+                diff_x = 0 if np.isnan(diff_x) else diff_x
+                max_length = np.maximum(max_length, diff_x)
+                diff_y = np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
+                diff_y = 0 if np.isnan(diff_y) else diff_y
+                max_length = np.maximum(max_length, diff_y)
+    return float(max_length)
+
+
 def find_instance_crop_size(
     labels: sio.Labels,
     padding: int = 0,
     maximum_stride: int = 2,
-    input_scaling: float = 1.0,
     min_crop_size: Optional[int] = None,
 ) -> int:
     """Compute the size of the largest instance bounding box from labels.
@@ -23,8 +91,6 @@ def find_instance_crop_size(
         maximum_stride: Ensure that the returned crop size is divisible by this value.
             Useful for ensuring that the crop size will not be truncated in a given
            architecture.
-        input_scaling: Float factor indicating the scale of the input images if any
-            scaling will be done before cropping.
         min_crop_size: The crop size set by the user.
 
     Returns:
@@ -32,7 +98,7 @@ def find_instance_crop_size(
        will contain the instances when cropped. The returned crop size will be larger
        or equal to the input `min_crop_size`.
 
-    This accounts for stride, padding and scaling when ensuring divisibility.
+    This accounts for stride and padding when ensuring divisibility.
     """
     # Check if user-specified crop size is divisible by max stride
     min_crop_size = 0 if min_crop_size is None else min_crop_size
@@ -46,7 +112,6 @@ def find_instance_crop_size(
        for inst in lf.instances:
            if not inst.is_empty:  # only if at least one point is not nan
                pts = inst.numpy()
-                pts *= input_scaling
                diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
                diff_x = 0 if np.isnan(diff_x) else diff_x
                max_length = np.maximum(max_length, diff_x)
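Note: as a quick sanity check on the new padding helper (a hedged usage sketch only; the function is assumed to be importable from this module), a 100 px bounding box with rotation up to 180° and scale up to 1.1 expands by sqrt(2) * 1.1 ≈ 1.556, so roughly 56 px of total padding is expected:

    import math

    bbox = 100.0
    pad = compute_augmentation_padding(bbox, rotation_max=180.0, scale_max=1.1)
    # 100 * sqrt(2) * 1.1 - 100 ≈ 55.56 -> ceil -> 56
    assert pad == math.ceil(bbox * (math.sqrt(2) * 1.1 - 1))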
@@ -4,6 +4,36 @@ import torch
 import torchvision.transforms.v2.functional as F
 
 
+def normalize_on_gpu(image: torch.Tensor) -> torch.Tensor:
+    """Normalize image tensor on GPU after transfer.
+
+    This function is called in the model's forward() method after the image has been
+    transferred to GPU. It converts uint8 images to float32 and normalizes to [0, 1].
+
+    By performing normalization on GPU after transfer, we reduce PCIe bandwidth by 4x
+    (transferring 1 byte/pixel as uint8 instead of 4 bytes/pixel as float32). This
+    provides up to 17x speedup for the transfer+normalization stage.
+
+    This function handles two cases:
+    1. uint8 tensor with values in [0, 255] -> convert to float32 and divide by 255
+    2. float32 tensor with values in [0, 255] (e.g., from preprocessing that cast to
+       float32 without normalizing) -> divide by 255
+
+    Args:
+        image: Tensor image that may be uint8 or float32 with values in [0, 255] range.
+
+    Returns:
+        Float32 tensor normalized to [0, 1] range.
+    """
+    if not torch.is_floating_point(image):
+        # uint8 -> float32 normalized
+        image = image.float() / 255.0
+    elif image.max() > 1.0:
+        # float32 but not normalized (values > 1 indicate [0, 255] range)
+        image = image / 255.0
+    return image
+
+
 def convert_to_grayscale(image: torch.Tensor):
     """Convert given image to Grayscale image (single-channel).
 
@@ -38,8 +68,21 @@ def convert_to_rgb(image: torch.Tensor):
     return image
 
 
-def apply_normalization(image: torch.Tensor):
-    """Normalize image tensor."""
+def apply_normalization(image: torch.Tensor) -> torch.Tensor:
+    """Normalize image tensor from uint8 [0, 255] to float32 [0, 1].
+
+    This function is used during training data preprocessing where augmentation
+    operations (kornia) require float32 input.
+
+    For inference, normalization is deferred to GPU via `normalize_on_gpu()` in the
+    model's forward() method to reduce PCIe bandwidth.
+
+    Args:
+        image: Tensor image (typically uint8 with values in [0, 255]).
+
+    Returns:
+        Float32 tensor normalized to [0, 1] range.
+    """
     if not torch.is_floating_point(image):
         image = image.to(torch.float32) / 255.0
     return image
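Note: the split between the two helpers is deliberate. `apply_normalization` only rescales non-float tensors, while `normalize_on_gpu` also rescales float tensors that still hold [0, 255] values, so it is safe to call after uint8 frames have been moved to the device. A minimal behavioral sketch (assumed usage, not part of this diff):

    import torch

    raw = torch.full((1, 1, 2, 2), 255, dtype=torch.uint8)
    as_float = raw.to(torch.float32)                     # [0, 255] but already float32

    assert apply_normalization(raw).max() == 1.0         # uint8 -> scaled to [0, 1]
    assert apply_normalization(as_float).max() == 255.0  # float passes through unchanged
    assert normalize_on_gpu(as_float).max() == 1.0       # float in [0, 255] -> rescaled

    # Inference-style path: transfer uint8 (1 byte/pixel), then normalize on device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frames = torch.randint(0, 256, (4, 1, 512, 512), dtype=torch.uint8).to(device)
    frames = normalize_on_gpu(frames)                    # float32 in [0, 1]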
@@ -71,6 +71,8 @@ def process_lf(
     for inst in instances_list:
         if not inst.is_empty:
             instances.append(inst.numpy())
+    if len(instances) == 0:
+        return None
     instances = np.stack(instances, axis=0)
 
     # Add singleton time dimension for single frames.
@@ -233,6 +235,8 @@ class LabelsReader(Thread):
         instances_key: bool = False,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
     ):
         """Initialize attribute of the class."""
         super().__init__()
@@ -245,6 +249,8 @@ class LabelsReader(Thread):
 
         self.only_labeled_frames = only_labeled_frames
         self.only_suggested_frames = only_suggested_frames
+        self.exclude_user_labeled = exclude_user_labeled
+        self.only_predicted_frames = only_predicted_frames
 
         # Filter to only user labeled instances
         if self.only_labeled_frames:
@@ -265,6 +271,20 @@ class LabelsReader(Thread):
                 )
                 self.filtered_lfs.append(new_lf)
 
+        # Filter out user labeled frames
+        elif self.exclude_user_labeled:
+            self.filtered_lfs = []
+            for lf in self.labels:
+                if not lf.has_user_instances:
+                    self.filtered_lfs.append(lf)
+
+        # Filter to only predicted frames
+        elif self.only_predicted_frames:
+            self.filtered_lfs = []
+            for lf in self.labels:
+                if lf.has_predicted_instances:
+                    self.filtered_lfs.append(lf)
+
         else:
             self.filtered_lfs = [lf for lf in self.labels]
 
@@ -300,6 +320,8 @@ class LabelsReader(Thread):
         instances_key: bool = False,
         only_labeled_frames: bool = False,
         only_suggested_frames: bool = False,
+        exclude_user_labeled: bool = False,
+        only_predicted_frames: bool = False,
     ):
         """Create LabelsReader from a .slp filename."""
         labels = sio.load_slp(filename)
@@ -310,6 +332,8 @@ class LabelsReader(Thread):
             instances_key,
             only_labeled_frames,
             only_suggested_frames,
+            exclude_user_labeled,
+            only_predicted_frames,
         )
 
     def run(self):
@@ -333,6 +357,8 @@ class LabelsReader(Thread):
             for inst in lf:
                 if not inst.is_empty:
                     instances.append(inst.numpy())
+            if len(instances) == 0:
+                continue
             instances = np.stack(instances, axis=0)
 
             # Add singleton time dimension for single frames.
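Note: the two new flags extend the existing elif chain of frame filters, so at most one filter applies per reader. Their effect on a loaded `sio.Labels` object is equivalent to the following standalone filtering (the file path is a hypothetical placeholder; sketch only):

    import sleap_io as sio

    labels = sio.load_slp("labels.slp")

    # exclude_user_labeled=True: keep only frames with no user instances.
    unlabeled_lfs = [lf for lf in labels if not lf.has_user_instances]

    # only_predicted_frames=True: keep only frames that have predictions.
    predicted_lfs = [lf for lf in labels if lf.has_predicted_instances]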
sleap_nn/evaluation.py CHANGED
@@ -61,18 +61,9 @@ def find_frame_pairs(
         # Find matching video instance in predictions.
         video_pr = None
         for video in labels_pr.videos:
-            if (
-                isinstance(video.backend, type(video_gt.backend))
-                and video.filename == video_gt.filename
-            ):
-                same_dataset = (
-                    (video.backend.dataset == video_gt.backend.dataset)
-                    if hasattr(video.backend, "dataset")
-                    else True
-                )  # `dataset` attr exists only for hdf5 backend not for mediavideo
-                if same_dataset:
-                    video_pr = video
-                    break
+            if video_gt.matches_content(video) and video_gt.matches_path(video):
+                video_pr = video
+                break
 
         if video_pr is None:
             continue
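Note: the hand-rolled backend/dataset comparison is replaced with sleap-io's own video matching helpers. A hypothetical standalone check of ground-truth versus predicted videos might look like this (file names are placeholders; sketch only):

    import sleap_io as sio

    labels_gt = sio.load_slp("ground_truth.slp")
    labels_pr = sio.load_slp("predictions.slp")

    video_gt = labels_gt.videos[0]
    matches = [
        video
        for video in labels_pr.videos
        if video_gt.matches_content(video) and video_gt.matches_path(video)
    ]
    print(f"Found {len(matches)} matching video(s) in predictions.")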
@@ -678,24 +669,109 @@ class Evaluator:
         return metrics
 
 
-def load_metrics(model_path: str, split="val"):
-    """Load the metrics for a given model and split.
+def _find_metrics_file(model_dir: Path, split: str, dataset_idx: int) -> Path:
+    """Find the metrics file in a model directory.
+
+    Tries new naming format first, then falls back to old format.
+    If split is "test" and not found, falls back to "val".
+    """
+    # Try new naming format first: metrics.{split}.{idx}.npz
+    metrics_path = model_dir / f"metrics.{split}.{dataset_idx}.npz"
+    if metrics_path.exists():
+        return metrics_path
+
+    # Fall back to old naming format: {split}_{idx}_pred_metrics.npz
+    metrics_path = model_dir / f"{split}_{dataset_idx}_pred_metrics.npz"
+    if metrics_path.exists():
+        return metrics_path
+
+    # If split is "test" and not found, try "val" fallback
+    if split == "test":
+        return _find_metrics_file(model_dir, "val", dataset_idx)
+
+    # Return the new format path (will raise FileNotFoundError later)
+    return model_dir / f"metrics.{split}.{dataset_idx}.npz"
+
+
+def _load_npz_metrics(metrics_path: Path) -> dict:
+    """Load metrics from an npz file, supporting both old and new formats.
+
+    New format: single "metrics" key containing a dict with all metrics.
+    Old format: individual metric keys at top level (voc_metrics, mOKS, etc.).
+    """
+    with np.load(metrics_path, allow_pickle=True) as data:
+        keys = list(data.keys())
+
+        # New format: single "metrics" key containing dict
+        if "metrics" in keys:
+            return data["metrics"].item()
+
+        # Old format: individual metric keys at top level
+        expected_keys = {
+            "voc_metrics",
+            "mOKS",
+            "distance_metrics",
+            "pck_metrics",
+            "visibility_metrics",
+        }
+        if expected_keys.issubset(set(keys)):
+            return {
+                k: data[k].item() if data[k].ndim == 0 else data[k]
+                for k in expected_keys
+            }
+
+        # Unknown format - return all keys as dict
+        return {k: data[k].item() if data[k].ndim == 0 else data[k] for k in keys}
+
+
+def load_metrics(
+    path: str,
+    split: str = "test",
+    dataset_idx: int = 0,
+) -> dict:
+    """Load metrics from a model folder or metrics file.
+
+    This function supports both the new format (single "metrics" key) and the old
+    format (individual metric keys at top level). It also handles both old and new
+    file naming conventions in model folders.
 
     Args:
-        model_path: Path to a model folder or metrics file (.npz).
-        split: Name of the split to load the metrics for. Must be `"train"`, `"val"` or
-            `"test"` (default: `"val"`). Ignored if a path to a metrics NPZ file is
-            provided.
+        path: Path to a model folder or metrics file (.npz).
+        split: Name of the split to load. Must be "train", "val", or "test".
+            Default: "test". If "test" is not found, falls back to "val".
+            Ignored if path points directly to a .npz file.
+        dataset_idx: Index of the dataset (for multi-dataset training).
+            Default: 0. Ignored if path points directly to a .npz file.
+
+    Returns:
+        Dictionary containing metrics with keys: voc_metrics, mOKS,
+        distance_metrics, pck_metrics, visibility_metrics.
+
+    Raises:
+        FileNotFoundError: If no metrics file is found.
 
+    Examples:
+        >>> # Load from model folder (tries test, falls back to val)
+        >>> metrics = load_metrics("/path/to/model")
+        >>> print(metrics["mOKS"]["mOKS"])
+
+        >>> # Load specific split and dataset
+        >>> metrics = load_metrics("/path/to/model", split="val", dataset_idx=1)
+
+        >>> # Load directly from npz file
+        >>> metrics = load_metrics("/path/to/metrics.val.0.npz")
     """
-    if Path(model_path).suffix == ".npz":
-        metrics_path = Path(model_path)
+    path = Path(path)
+
+    if path.suffix == ".npz":
+        metrics_path = path
     else:
-        metrics_path = Path(model_path) / f"{split}_0_pred_metrics.npz"
+        metrics_path = _find_metrics_file(path, split, dataset_idx)
+
     if not metrics_path.exists():
         raise FileNotFoundError(f"Metrics file not found at {metrics_path}")
-    with np.load(metrics_path, allow_pickle=True) as data:
-        return data["metrics"].item()
+
+    return _load_npz_metrics(metrics_path)
 
 
 def run_evaluation(
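Note: the new on-disk format stores a single pickled dict under a "metrics" key, which is what `_load_npz_metrics` checks for first. A minimal round-trip sketch (the file name and metric values here are hypothetical):

    import numpy as np

    metrics = {"mOKS": {"mOKS": 0.87}}      # real files also hold voc_metrics, etc.
    np.savez("metrics.val.0.npz", metrics=metrics)

    loaded = load_metrics("metrics.val.0.npz")
    print(loaded["mOKS"]["mOKS"])           # 0.87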
@@ -1 +1,7 @@
 """Inference-related modules."""
+
+from sleap_nn.inference.provenance import (
+    build_inference_provenance,
+    build_tracking_only_provenance,
+    merge_provenance,
+)
@@ -47,15 +47,23 @@ def crop_bboxes(
     width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
     box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
 
+    # Store original dtype for conversion back after cropping.
+    original_dtype = images.dtype
+
+    # Kornia's crop_and_resize requires float32 input.
+    images_to_crop = images[sample_inds]
+    if not torch.is_floating_point(images_to_crop):
+        images_to_crop = images_to_crop.float()
+
     # Crop.
     crops = crop_and_resize(
-        images[sample_inds],  # (n_boxes, channels, height, width)
+        images_to_crop,  # (n_boxes, channels, height, width)
        boxes=bboxes,
        size=box_size,
     )
 
     # Cast back to original dtype and return.
-    crops = crops.to(images.dtype)
+    crops = crops.to(original_dtype)
     return crops
 
 
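Note: this change lets `crop_bboxes` accept the uint8 frames produced by the deferred-normalization path, converting to float32 only for kornia and casting back afterwards. A hedged sketch of the same dtype round-trip using kornia directly (box coordinates are arbitrary):

    import torch
    from kornia.geometry.transform import crop_and_resize

    frames = torch.randint(0, 256, (1, 1, 64, 64), dtype=torch.uint8)
    # Corners in (x, y) order: top-left, top-right, bottom-right, bottom-left.
    boxes = torch.tensor([[[8.0, 8.0], [39.0, 8.0], [39.0, 39.0], [8.0, 39.0]]])

    crops = crop_and_resize(frames.float(), boxes=boxes, size=(32, 32))
    crops = crops.to(frames.dtype)  # back to uint8, as crop_bboxes now does
    assert crops.shape == (1, 1, 32, 32) and crops.dtype == torch.uint8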