sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
|
@@ -8,11 +8,79 @@ import torch
|
|
|
8
8
|
from kornia.geometry.transform import crop_and_resize
|
|
9
9
|
|
|
10
10
|
|
|
11
|
+
def compute_augmentation_padding(
    bbox_size: float,
    rotation_max: float = 0.0,
    scale_max: float = 1.0,
) -> int:
    """Compute padding needed to accommodate augmentation transforms.

    When rotation and scaling augmentations are applied, the bounding box of an
    instance can expand beyond its original size. This function calculates the
    padding needed to ensure the full instance remains visible after augmentation.

    Args:
        bbox_size: The size of the instance bounding box (max of width/height).
        rotation_max: Maximum absolute rotation angle in degrees. For symmetric
            rotation ranges like [-180, 180], pass 180. The sign is ignored.
        scale_max: Maximum scaling factor. For scale range [0.9, 1.1], pass 1.1.
            Values below 1.0 are treated as 1.0 (shrinking never reduces padding).

    Returns:
        Padding in pixels to add around the bounding box (total, not per side).
    """
    rotation_abs = abs(rotation_max)
    if rotation_abs == 0.0 and scale_max <= 1.0:
        return 0

    # Rotating a square of side L by angle t gives an axis-aligned bounding box
    # of side L * (|cos t| + |sin t|). On [0, 45] degrees this factor grows
    # monotonically from 1 to sqrt(2); beyond 45 degrees it shrinks again (back
    # to 1 at 90). Any range that reaches 45 degrees thus has worst case sqrt(2).
    if rotation_abs >= 45:
        rotation_factor = math.sqrt(2)
    else:
        rotation_rad = math.radians(rotation_abs)
        rotation_factor = math.cos(rotation_rad) + math.sin(rotation_rad)

    # Combined expansion factor (rotation and scaling compound).
    expansion_factor = rotation_factor * max(scale_max, 1.0)

    # Total padding needed (both sides together).
    padding = bbox_size * expansion_factor - bbox_size

    # Strip floating-point noise before taking the ceiling; otherwise e.g.
    # 100 * 1.1 - 100 == 10.000000000000014 would be padded to 11 instead of 10.
    return int(math.ceil(round(padding, 6)))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def find_max_instance_bbox_size(labels: sio.Labels) -> float:
    """Return the largest bounding-box side length over all labeled instances.

    For every non-empty instance in every labeled frame, the extent of its
    points along x (column 0) and along y (column 1) is measured, ignoring NaN
    coordinates. Extents that come out as NaN are skipped. The largest extent
    seen is returned.

    Args:
        labels: A `sio.Labels` containing user-labeled instances.

    Returns:
        The maximum bounding box dimension (max of width or height) across all
        instances, or 0.0 when there is no non-empty instance.
    """
    largest = 0.0
    for frame in labels:
        for instance in frame.instances:
            if instance.is_empty:
                continue
            points = instance.numpy()
            for axis in (0, 1):
                extent = np.nanmax(points[:, axis]) - np.nanmin(points[:, axis])
                if not np.isnan(extent):
                    largest = np.maximum(largest, extent)
    return float(largest)
|
|
78
|
+
|
|
79
|
+
|
|
11
80
|
def find_instance_crop_size(
|
|
12
81
|
labels: sio.Labels,
|
|
13
82
|
padding: int = 0,
|
|
14
83
|
maximum_stride: int = 2,
|
|
15
|
-
input_scaling: float = 1.0,
|
|
16
84
|
min_crop_size: Optional[int] = None,
|
|
17
85
|
) -> int:
|
|
18
86
|
"""Compute the size of the largest instance bounding box from labels.
|
|
@@ -23,8 +91,6 @@ def find_instance_crop_size(
|
|
|
23
91
|
maximum_stride: Ensure that the returned crop size is divisible by this value.
|
|
24
92
|
Useful for ensuring that the crop size will not be truncated in a given
|
|
25
93
|
architecture.
|
|
26
|
-
input_scaling: Float factor indicating the scale of the input images if any
|
|
27
|
-
scaling will be done before cropping.
|
|
28
94
|
min_crop_size: The crop size set by the user.
|
|
29
95
|
|
|
30
96
|
Returns:
|
|
@@ -32,7 +98,7 @@ def find_instance_crop_size(
|
|
|
32
98
|
will contain the instances when cropped. The returned crop size will be larger
|
|
33
99
|
or equal to the input `min_crop_size`.
|
|
34
100
|
|
|
35
|
-
This accounts for stride
|
|
101
|
+
This accounts for stride and padding when ensuring divisibility.
|
|
36
102
|
"""
|
|
37
103
|
# Check if user-specified crop size is divisible by max stride
|
|
38
104
|
min_crop_size = 0 if min_crop_size is None else min_crop_size
|
|
@@ -46,7 +112,6 @@ def find_instance_crop_size(
|
|
|
46
112
|
for inst in lf.instances:
|
|
47
113
|
if not inst.is_empty: # only if at least one point is not nan
|
|
48
114
|
pts = inst.numpy()
|
|
49
|
-
pts *= input_scaling
|
|
50
115
|
diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
|
|
51
116
|
diff_x = 0 if np.isnan(diff_x) else diff_x
|
|
52
117
|
max_length = np.maximum(max_length, diff_x)
|
sleap_nn/data/normalization.py
CHANGED
|
@@ -4,6 +4,36 @@ import torch
|
|
|
4
4
|
import torchvision.transforms.v2.functional as F
|
|
5
5
|
|
|
6
6
|
|
|
7
|
+
def normalize_on_gpu(image: torch.Tensor) -> torch.Tensor:
    """Normalize image tensor on GPU after transfer.

    This function is called in the model's forward() method after the image has
    been transferred to GPU. Deferring normalization until after the transfer
    lets the host send 1 byte/pixel (uint8) instead of 4 bytes/pixel (float32).

    This function handles two cases:
        1. Integer (e.g. uint8) tensor with values in [0, 255] -> converted to
           float32 and divided by 255.
        2. Floating-point tensor whose maximum exceeds 1.0 (e.g. from
           preprocessing that cast to float without normalizing) -> divided by
           255, keeping its floating-point dtype.
    Floating-point tensors whose maximum is <= 1.0, and empty floating-point
    tensors, are returned unchanged.

    NOTE(review): the "max > 1.0" test is a heuristic — a float image in
    [0, 255] whose pixels are all <= 1 (a nearly black frame) is not rescaled.
    Evaluating the comparison in Python also reads a value back from the device.

    Args:
        image: Tensor image that may be uint8 or floating point with values in
            [0, 255] range.

    Returns:
        Tensor normalized to [0, 1] range (float32 for integer inputs).
    """
    if not torch.is_floating_point(image):
        # Integer pixels -> float32 normalized.
        return image.float() / 255.0
    # `max()` raises on empty tensors, so only inspect non-empty inputs.
    if image.numel() > 0 and image.max() > 1.0:
        # Floating point but not normalized (values > 1 indicate [0, 255] range).
        return image / 255.0
    return image
|
|
35
|
+
|
|
36
|
+
|
|
7
37
|
def convert_to_grayscale(image: torch.Tensor):
|
|
8
38
|
"""Convert given image to Grayscale image (single-channel).
|
|
9
39
|
|
|
@@ -38,8 +68,21 @@ def convert_to_rgb(image: torch.Tensor):
|
|
|
38
68
|
return image
|
|
39
69
|
|
|
40
70
|
|
|
41
|
-
def apply_normalization(image: torch.Tensor):
|
|
42
|
-
"""Normalize image tensor.
|
|
71
|
+
def apply_normalization(image: torch.Tensor) -> torch.Tensor:
    """Convert an integer image tensor to float32 scaled into [0, 1].

    Used during training data preprocessing, where the kornia augmentation
    operations need floating-point input. For inference, normalization is
    deferred to GPU via `normalize_on_gpu()` in the model's forward() method to
    reduce PCIe bandwidth.

    Floating-point inputs are passed through untouched; no rescaling is
    attempted for them.

    Args:
        image: Tensor image (typically uint8 with values in [0, 255]).

    Returns:
        The same tensor object when it is already floating point; otherwise a
        float32 tensor equal to ``image / 255``.
    """
    if torch.is_floating_point(image):
        return image
    as_float = image.to(torch.float32)
    return as_float / 255.0
|
sleap_nn/data/providers.py
CHANGED
|
@@ -71,6 +71,8 @@ def process_lf(
|
|
|
71
71
|
for inst in instances_list:
|
|
72
72
|
if not inst.is_empty:
|
|
73
73
|
instances.append(inst.numpy())
|
|
74
|
+
if len(instances) == 0:
|
|
75
|
+
return None
|
|
74
76
|
instances = np.stack(instances, axis=0)
|
|
75
77
|
|
|
76
78
|
# Add singleton time dimension for single frames.
|
|
@@ -233,6 +235,8 @@ class LabelsReader(Thread):
|
|
|
233
235
|
instances_key: bool = False,
|
|
234
236
|
only_labeled_frames: bool = False,
|
|
235
237
|
only_suggested_frames: bool = False,
|
|
238
|
+
exclude_user_labeled: bool = False,
|
|
239
|
+
only_predicted_frames: bool = False,
|
|
236
240
|
):
|
|
237
241
|
"""Initialize attribute of the class."""
|
|
238
242
|
super().__init__()
|
|
@@ -245,6 +249,8 @@ class LabelsReader(Thread):
|
|
|
245
249
|
|
|
246
250
|
self.only_labeled_frames = only_labeled_frames
|
|
247
251
|
self.only_suggested_frames = only_suggested_frames
|
|
252
|
+
self.exclude_user_labeled = exclude_user_labeled
|
|
253
|
+
self.only_predicted_frames = only_predicted_frames
|
|
248
254
|
|
|
249
255
|
# Filter to only user labeled instances
|
|
250
256
|
if self.only_labeled_frames:
|
|
@@ -265,6 +271,20 @@ class LabelsReader(Thread):
|
|
|
265
271
|
)
|
|
266
272
|
self.filtered_lfs.append(new_lf)
|
|
267
273
|
|
|
274
|
+
# Filter out user labeled frames
|
|
275
|
+
elif self.exclude_user_labeled:
|
|
276
|
+
self.filtered_lfs = []
|
|
277
|
+
for lf in self.labels:
|
|
278
|
+
if not lf.has_user_instances:
|
|
279
|
+
self.filtered_lfs.append(lf)
|
|
280
|
+
|
|
281
|
+
# Filter to only predicted frames
|
|
282
|
+
elif self.only_predicted_frames:
|
|
283
|
+
self.filtered_lfs = []
|
|
284
|
+
for lf in self.labels:
|
|
285
|
+
if lf.has_predicted_instances:
|
|
286
|
+
self.filtered_lfs.append(lf)
|
|
287
|
+
|
|
268
288
|
else:
|
|
269
289
|
self.filtered_lfs = [lf for lf in self.labels]
|
|
270
290
|
|
|
@@ -300,6 +320,8 @@ class LabelsReader(Thread):
|
|
|
300
320
|
instances_key: bool = False,
|
|
301
321
|
only_labeled_frames: bool = False,
|
|
302
322
|
only_suggested_frames: bool = False,
|
|
323
|
+
exclude_user_labeled: bool = False,
|
|
324
|
+
only_predicted_frames: bool = False,
|
|
303
325
|
):
|
|
304
326
|
"""Create LabelsReader from a .slp filename."""
|
|
305
327
|
labels = sio.load_slp(filename)
|
|
@@ -310,6 +332,8 @@ class LabelsReader(Thread):
|
|
|
310
332
|
instances_key,
|
|
311
333
|
only_labeled_frames,
|
|
312
334
|
only_suggested_frames,
|
|
335
|
+
exclude_user_labeled,
|
|
336
|
+
only_predicted_frames,
|
|
313
337
|
)
|
|
314
338
|
|
|
315
339
|
def run(self):
|
|
@@ -333,6 +357,8 @@ class LabelsReader(Thread):
|
|
|
333
357
|
for inst in lf:
|
|
334
358
|
if not inst.is_empty:
|
|
335
359
|
instances.append(inst.numpy())
|
|
360
|
+
if len(instances) == 0:
|
|
361
|
+
continue
|
|
336
362
|
instances = np.stack(instances, axis=0)
|
|
337
363
|
|
|
338
364
|
# Add singleton time dimension for single frames.
|
sleap_nn/evaluation.py
CHANGED
|
@@ -61,18 +61,9 @@ def find_frame_pairs(
|
|
|
61
61
|
# Find matching video instance in predictions.
|
|
62
62
|
video_pr = None
|
|
63
63
|
for video in labels_pr.videos:
|
|
64
|
-
if (
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
):
|
|
68
|
-
same_dataset = (
|
|
69
|
-
(video.backend.dataset == video_gt.backend.dataset)
|
|
70
|
-
if hasattr(video.backend, "dataset")
|
|
71
|
-
else True
|
|
72
|
-
) # `dataset` attr exists only for hdf5 backend not for mediavideo
|
|
73
|
-
if same_dataset:
|
|
74
|
-
video_pr = video
|
|
75
|
-
break
|
|
64
|
+
if video_gt.matches_content(video) and video_gt.matches_path(video):
|
|
65
|
+
video_pr = video
|
|
66
|
+
break
|
|
76
67
|
|
|
77
68
|
if video_pr is None:
|
|
78
69
|
continue
|
|
@@ -678,24 +669,109 @@ class Evaluator:
|
|
|
678
669
|
return metrics
|
|
679
670
|
|
|
680
671
|
|
|
681
|
-
def
|
|
682
|
-
"""
|
|
672
|
+
def _find_metrics_file(model_dir: Path, split: str, dataset_idx: int) -> Path:
|
|
673
|
+
"""Find the metrics file in a model directory.
|
|
674
|
+
|
|
675
|
+
Tries new naming format first, then falls back to old format.
|
|
676
|
+
If split is "test" and not found, falls back to "val".
|
|
677
|
+
"""
|
|
678
|
+
# Try new naming format first: metrics.{split}.{idx}.npz
|
|
679
|
+
metrics_path = model_dir / f"metrics.{split}.{dataset_idx}.npz"
|
|
680
|
+
if metrics_path.exists():
|
|
681
|
+
return metrics_path
|
|
682
|
+
|
|
683
|
+
# Fall back to old naming format: {split}_{idx}_pred_metrics.npz
|
|
684
|
+
metrics_path = model_dir / f"{split}_{dataset_idx}_pred_metrics.npz"
|
|
685
|
+
if metrics_path.exists():
|
|
686
|
+
return metrics_path
|
|
687
|
+
|
|
688
|
+
# If split is "test" and not found, try "val" fallback
|
|
689
|
+
if split == "test":
|
|
690
|
+
return _find_metrics_file(model_dir, "val", dataset_idx)
|
|
691
|
+
|
|
692
|
+
# Return the new format path (will raise FileNotFoundError later)
|
|
693
|
+
return model_dir / f"metrics.{split}.{dataset_idx}.npz"
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
def _load_npz_metrics(metrics_path: Path) -> dict:
|
|
697
|
+
"""Load metrics from an npz file, supporting both old and new formats.
|
|
698
|
+
|
|
699
|
+
New format: single "metrics" key containing a dict with all metrics.
|
|
700
|
+
Old format: individual metric keys at top level (voc_metrics, mOKS, etc.).
|
|
701
|
+
"""
|
|
702
|
+
with np.load(metrics_path, allow_pickle=True) as data:
|
|
703
|
+
keys = list(data.keys())
|
|
704
|
+
|
|
705
|
+
# New format: single "metrics" key containing dict
|
|
706
|
+
if "metrics" in keys:
|
|
707
|
+
return data["metrics"].item()
|
|
708
|
+
|
|
709
|
+
# Old format: individual metric keys at top level
|
|
710
|
+
expected_keys = {
|
|
711
|
+
"voc_metrics",
|
|
712
|
+
"mOKS",
|
|
713
|
+
"distance_metrics",
|
|
714
|
+
"pck_metrics",
|
|
715
|
+
"visibility_metrics",
|
|
716
|
+
}
|
|
717
|
+
if expected_keys.issubset(set(keys)):
|
|
718
|
+
return {
|
|
719
|
+
k: data[k].item() if data[k].ndim == 0 else data[k]
|
|
720
|
+
for k in expected_keys
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
# Unknown format - return all keys as dict
|
|
724
|
+
return {k: data[k].item() if data[k].ndim == 0 else data[k] for k in keys}
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
def load_metrics(
    path: str,
    split: str = "test",
    dataset_idx: int = 0,
) -> dict:
    """Load metrics from a model folder or metrics file.

    This function supports both the new format (single "metrics" key) and the old
    format (individual metric keys at top level). It also handles both old and new
    file naming conventions in model folders.

    Args:
        path: Path to a model folder or metrics file (.npz).
        split: Name of the split to load; expected to be "train", "val", or
            "test" (not validated here). Default: "test". If "test" is not
            found, falls back to "val". Ignored if path points directly to a
            .npz file.
        dataset_idx: Index of the dataset (for multi-dataset training).
            Default: 0. Ignored if path points directly to a .npz file.

    Returns:
        Dictionary containing metrics with keys: voc_metrics, mOKS,
        distance_metrics, pck_metrics, visibility_metrics.

    Raises:
        FileNotFoundError: If no metrics file is found.

    Examples:
        >>> # Load from model folder (tries test, falls back to val)
        >>> metrics = load_metrics("/path/to/model")
        >>> print(metrics["mOKS"]["mOKS"])

        >>> # Load specific split and dataset
        >>> metrics = load_metrics("/path/to/model", split="val", dataset_idx=1)

        >>> # Load directly from npz file
        >>> metrics = load_metrics("/path/to/metrics.val.0.npz")
    """
    target = Path(path)

    # A direct .npz path is used verbatim; anything else is a model directory.
    metrics_path = (
        target
        if target.suffix == ".npz"
        else _find_metrics_file(target, split, dataset_idx)
    )

    if not metrics_path.exists():
        raise FileNotFoundError(f"Metrics file not found at {metrics_path}")
    return _load_npz_metrics(metrics_path)
|
|
699
775
|
|
|
700
776
|
|
|
701
777
|
def run_evaluation(
|
sleap_nn/inference/__init__.py
CHANGED
|
@@ -47,15 +47,23 @@ def crop_bboxes(
|
|
|
47
47
|
width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
|
|
48
48
|
box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
|
|
49
49
|
|
|
50
|
+
# Store original dtype for conversion back after cropping.
|
|
51
|
+
original_dtype = images.dtype
|
|
52
|
+
|
|
53
|
+
# Kornia's crop_and_resize requires float32 input.
|
|
54
|
+
images_to_crop = images[sample_inds]
|
|
55
|
+
if not torch.is_floating_point(images_to_crop):
|
|
56
|
+
images_to_crop = images_to_crop.float()
|
|
57
|
+
|
|
50
58
|
# Crop.
|
|
51
59
|
crops = crop_and_resize(
|
|
52
|
-
|
|
60
|
+
images_to_crop, # (n_boxes, channels, height, width)
|
|
53
61
|
boxes=bboxes,
|
|
54
62
|
size=box_size,
|
|
55
63
|
)
|
|
56
64
|
|
|
57
65
|
# Cast back to original dtype and return.
|
|
58
|
-
crops = crops.to(
|
|
66
|
+
crops = crops.to(original_dtype)
|
|
59
67
|
return crops
|
|
60
68
|
|
|
61
69
|
|