nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,268 @@
1
+ from typing import Sequence, Optional
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+
7
def crop_and_pad_into_buffer(
    target_tensor: torch.Tensor, bbox: Sequence[Sequence[int]], source_tensor, source_leading_slice=None
) -> None:
    """
    Copies a sub-region from source_tensor into target_tensor based on a bounding box.

    Args:
        target_tensor (torch.Tensor): A preallocated tensor that will be updated in place.
        bbox (sequence of [int, int]): A bounding box for each dimension of the source tensor
            that is covered by the bbox. The bbox is defined as [start, end) (half-open interval)
            and may extend outside the source tensor. If source_tensor has more dimensions than
            len(bbox), the leading dimensions will be fully included.
        source_tensor: The tensor (or array-like such as a blosc2 NDArray / numpy array) to copy
            data from. Anything that is not a torch.Tensor after slicing is converted through
            np.asarray + torch.from_numpy.
        source_leading_slice: Optional slice to apply to the first leading dimension of the source
            instead of slice(None). Useful for reading only a subset of channels.

    Behavior:
        For each dimension that the bbox covers (i.e. the last len(bbox) dims of source_tensor):
            - Compute the overlapping region between the bbox and the source tensor.
            - Determine the corresponding indices in the target tensor where the data will be copied.
        For any extra leading dimensions (i.e. source_tensor.ndim > len(bbox)):
            - Use slice(None) to include the entire dimension (or source_leading_slice for the
              first leading dim when provided).
        If the bbox does not overlap the source in at least one dimension, nothing is copied and
        target_tensor is left untouched.
        If source_tensor and target_tensor are on different devices, only the overlapping subregion
        is transferred to the device of target_tensor.

    Returns:
        None. target_tensor is modified in place.
    """
    # Number of leading source dims that are not covered by the bbox.
    leading_dims = source_tensor.ndim - len(bbox)

    source_slices = []
    target_slices = []

    # Leading dimensions are taken whole (optionally restricted on the very first one).
    for d in range(leading_dims):
        if d == 0 and source_leading_slice is not None:
            source_slices.append(source_leading_slice)
        else:
            source_slices.append(slice(None))
        target_slices.append(slice(None))

    # Dimensions covered by the bbox.
    for d, (box_start, box_end) in enumerate(bbox):
        source_size = source_tensor.shape[leading_dims + d]

        # Overlapping region in source coordinates.
        copy_start_source = max(box_start, 0)
        copy_end_source = min(box_end, source_size)
        copy_size = copy_end_source - copy_start_source

        # No overlap in this dimension means there is nothing to copy at all. Bail out here:
        # a negative slice end would otherwise be interpreted as "counted from the end" and
        # select a region whose shape does not match the (empty) target region.
        if copy_size <= 0:
            return

        # Where the overlap starts inside the target: the part of the bbox that lies
        # before index 0 of the source is skipped in the target.
        copy_start_target = max(0, -box_start)
        copy_end_target = copy_start_target + copy_size

        source_slices.append(slice(copy_start_source, copy_end_source))
        target_slices.append(slice(copy_start_target, copy_end_target))

    # Extract the overlapping region from the source.
    sub_source = source_tensor[tuple(source_slices)]
    # Convert non-torch (e.g. blosc2 NDArray or numpy) to torch tensor.
    if not isinstance(sub_source, torch.Tensor):
        sub_source = torch.from_numpy(np.asarray(sub_source))
    # Transfer only this subregion to the target tensor's device.
    if isinstance(target_tensor, torch.Tensor):
        sub_source = sub_source.to(target_tensor.device)
    else:
        sub_source = sub_source.cpu()
    # Write the data into the preallocated target_tensor.
    target_tensor[tuple(target_slices)] = sub_source
77
+
78
+
79
def _paste_overlap(bbox, target_shape):
    """Return (target_slices, source_slices) for the part of bbox inside target_shape, or None if empty."""
    target_slices = []
    source_slices = []
    for (b0, b1), size in zip(bbox, target_shape):
        # Valid region in the target: [t_start, t_end)
        t_start = max(b0, 0)
        t_end = min(b1, size)
        if t_start >= t_end:
            return None
        # The source's coordinate 0 corresponds to b0 in the target.
        s_start = t_start - b0
        target_slices.append(slice(t_start, t_end))
        source_slices.append(slice(s_start, s_start + (t_end - t_start)))
    return tuple(target_slices), tuple(source_slices)


def paste_tensor(target, source, bbox, channel_idx=None):
    """
    Paste a source tensor into a target tensor using a given bounding box.

    The bounding box is specified in the coordinate system of the target as one half-open
    interval per spatial dimension, e.g. for volumes:
        [[x1, x2], [y1, y2], [z1, z2]]
    and its size is assumed to be equal to the shape of the source tensor.
    The bbox may exceed the boundaries of the target tensor; only the overlapping part is written.

    Args:
        target: The target to write into. Either a torch.Tensor / array-like whose dimensions
            match the bbox (e.g. shape (T0, T1, T2)), or - when channel_idx is provided - a
            target with one extra leading channel dimension (e.g. shape (C, T0, T1, T2)).
        source: The source data with one dimension per bbox interval (e.g. shape (S0, S1, S2)).
            It must be the same size as the bbox.
        bbox (list or tuple): List of [start, end) intervals, one per spatial dimension.
        channel_idx (int, optional): If provided, paste into this channel of the target.
            For torch.Tensor targets, delegates to target[channel_idx]. For any other target
            (e.g. a blosc2 NDArray), the overlapping source region is converted to a float16
            numpy array and written to target[channel_idx, ...].

    Returns:
        The target after pasting in the source. The exception is the non-torch target with
        channel_idx branch, which writes in place and returns None.
    """
    # When channel_idx is given and target is a torch Tensor, delegate to the channel view.
    if channel_idx is not None and isinstance(target, torch.Tensor):
        return paste_tensor(target[channel_idx], source, bbox)

    if channel_idx is not None:
        # Non-torch target with a leading channel dim; write to target[channel_idx] at bbox.
        overlap = _paste_overlap(bbox, target.shape[1:])
        if overlap is None:
            return  # no overlap in at least one dimension
        target_slices, source_slices = overlap

        src = source[source_slices]
        if isinstance(src, torch.Tensor):
            src = src.cpu().numpy().astype(np.float16)
        else:
            src = np.asarray(src).astype(np.float16)

        target[(channel_idx,) + target_slices] = src
        return

    overlap = _paste_overlap(bbox, target.shape)
    # If there's no overlap in any dimension, nothing gets pasted.
    if overlap is None:
        return target
    target_slices, source_slices = overlap

    src = source[source_slices]
    if isinstance(target, torch.Tensor):
        target[target_slices] = src.to(target.device)
    else:
        # Only torch tensors have (and need) .cpu(); other array types are written as they are.
        target[target_slices] = src.cpu() if isinstance(src, torch.Tensor) else src

    return target
191
+
192
+
193
def crop_to_valid(img, bbox):
    """
    Crops the image to the part of the bounding box that lies within the image.
    The first dimension of img is treated as channels and is never cropped; the bounding
    box holds one half-open interval per remaining (spatial) dimension, e.g.
    [[x1, x2], [y1, y2], [z1, z2]] for an input of shape (C, X, Y, Z).

    Args:
        img: Input tensor (or array-like such as a blosc2 NDArray) of shape (C, X, Y, Z).
        bbox (list or tuple): Bounding box as a list of intervals for the spatial dims:
                              [[x1, x2], [y1, y2], [z1, z2]].

    Returns:
        cropped: Cropped data of shape (C, cropped_x, cropped_y, cropped_z). A spatial
                 dimension has size 0 when the bbox does not overlap the image there.
        pad (list of tuples): A list [(pad_x_left, pad_x_right),
                                      (pad_y_left, pad_y_right),
                                      (pad_z_left, pad_z_right)]
                              indicating how much padding needs to be applied on each side so
                              that cropped size + padding equals the bbox size.
    """
    # Only spatial dimensions are cropped; channels are preserved.
    spatial_dims = img.shape[1:]
    crop_slices = []
    pad = []  # for each spatial dimension

    for i, (start, end) in enumerate(bbox):
        dim_size = spatial_dims[i]
        # Clamp both ends into [0, dim_size]. Without the lower clamp on crop_end, a bbox
        # lying entirely before the image would produce a negative slice end, which indexing
        # interprets as "counted from the end" and would return unrelated data.
        crop_start = min(max(start, 0), dim_size)
        crop_end = max(min(end, dim_size), crop_start)
        crop_slices.append(slice(crop_start, crop_end))

        # Padding fills whatever part of the bbox extent is not covered by the crop.
        extent = max(end - start, 0)
        pad_left = min(max(-start, 0), extent)
        pad_right = max(extent - pad_left - (crop_end - crop_start), 0)
        pad.append((pad_left, pad_right))

    # Crop the image on spatial dimensions, leaving the channel dimension intact.
    cropped = img[(slice(None), *crop_slices)]
    return cropped, pad
235
+
236
+
237
def pad_cropped(cropped: torch.Tensor, pad):
    """
    Pad a cropped tensor by the given per-dimension amounts using F.pad.

    The pad list is given first-spatial-dimension-first, whereas F.pad consumes padding
    pairs starting from the last dimension, so the list is reversed and flattened before
    the call. A 4D input (C, X, Y, Z) is padded through a temporary leading batch
    dimension of size 1 that is removed again afterwards; inputs with any other number
    of dimensions are passed to F.pad directly.

    Args:
        cropped (torch.Tensor): Cropped tensor of shape (C, X, Y, Z).
        pad (list of tuples): Padding for each spatial dimension, in order (x, y, z):
                              [(pad_x_left, pad_x_right),
                               (pad_y_left, pad_y_right),
                               (pad_z_left, pad_z_right)].

    Returns:
        padded (torch.Tensor): Padded tensor of shape (C, X + pad_x, Y + pad_y, Z + pad_z),
                               where pad_* is the sum of the left and right amounts.
    """
    # Build F.pad's flat argument: (z_left, z_right, y_left, y_right, x_left, x_right).
    flat_amounts = []
    for amounts in reversed(pad):
        flat_amounts.extend(amounts)

    if cropped.dim() != 4:
        return F.pad(cropped, flat_amounts)

    # (C, X, Y, Z) -> (1, C, X, Y, Z) -> pad -> back to (C, X', Y', Z')
    batched = cropped.unsqueeze(0)
    return F.pad(batched, flat_amounts).squeeze(0)
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from nnunetv2.utilities.helpers import empty_cache
5
+ from torch.backends import cudnn
6
+
7
+
8
@torch.inference_mode()
def iterative_3x3_same_padding_pool3d(x, kernel_size: int, use_min_pool: bool = False):
    """
    Applies 3D max (or min) pooling with an odd kernel size such that the output shape is
    the same as the input shape.

    The input is first padded with (kernel_size - 1) // 2 voxels on each side of the last
    three dimensions using "replicate" padding. Pooling is then done by applying a
    3x3x3, stride-1, unpadded max pool (kernel_size - 1) // 2 times. Each pass shrinks
    every spatial dimension by 2, which exactly undoes the initial padding.

    cudnn.benchmark is switched off for the duration of the call and restored afterwards,
    also when an exception is raised.

    Args:
        x (Tensor): Input tensor of shape (N, C, D, H, W)
        kernel_size (int): Odd kernel size, used for all three spatial dimensions.
        use_min_pool (bool): If True, compute a min pool instead (implemented as the negated
            max pool of the negated input).

    Returns:
        Tensor: Output tensor with the same (D, H, W) dimensions as the input.

    Raises:
        AssertionError: If kernel_size is even.
    """
    # Validate before touching global cuDNN state so a bad argument cannot leave
    # cudnn.benchmark disabled.
    assert kernel_size % 2 == 1, "Only works with odd kernels"

    # Padding per side; for odd kernels front and back are equal.
    pad_front = (kernel_size - 1) // 2
    pad_back = (kernel_size - 1) - pad_front
    iters = (kernel_size - 1) // 2

    benchmark = cudnn.benchmark
    cudnn.benchmark = False
    try:
        # For 3D (input shape: [N, C, D, H, W]), F.pad expects the padding in the order:
        # (pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back)
        x = F.pad(x, (pad_front, pad_back, pad_front, pad_back, pad_front, pad_back), mode="replicate")

        # Apply pooling with no additional padding.
        for _ in range(iters):
            if use_min_pool:
                x = -F.max_pool3d(-x, kernel_size=3, stride=1, padding=0)
            else:
                x = F.max_pool3d(x, kernel_size=3, stride=1, padding=0)
        empty_cache(x.device)
        return x
    finally:
        # Always restore the caller's cuDNN benchmark setting.
        cudnn.benchmark = benchmark
@@ -0,0 +1,45 @@
1
+ import re
2
+ from typing import List, Tuple, Union
3
+
4
+
5
def version_to_tuple(version: str) -> Tuple[int, ...]:
    """Return every run of decimal digits found in ``version``, in order, as a tuple of ints."""
    digit_runs = re.findall(r"\d+", version)
    return tuple(map(int, digit_runs))
7
+
8
+
9
def parse_channel_pair(channel_name: str, raw_channels) -> Tuple[int, int]:
    """
    Validate a two-element channel mapping and return it as a pair of ints.

    Args:
        channel_name: Name of the mapping entry; only used in the error message.
        raw_channels: Expected to be a list or tuple holding exactly two values that
            can be converted with int().

    Returns:
        (int(raw_channels[0]), int(raw_channels[1]))

    Raises:
        ValueError: If raw_channels is not a list/tuple of length 2 (or if int() rejects
            one of the two values).
    """
    is_pair = isinstance(raw_channels, (tuple, list)) and len(raw_channels) == 2
    if is_pair:
        first, second = raw_channels
        return int(first), int(second)
    raise ValueError(
        f"Invalid channel mapping for '{channel_name}': expected a pair [pos, neg], got {raw_channels}."
    )
15
+
16
+
17
def infer_num_interaction_channels_from_mapping(channel_mapping: dict) -> int:
    """
    Derive the smallest channel count that can hold every index used in a channel mapping.

    The value stored under the key "prev_seg" is read as a single index; every other
    value is read as an index pair via parse_channel_pair. Non-negative indices are
    0-based, so index i requires i + 1 channels. Negative indices count from the end,
    so index -k requires k channels.

    Args:
        channel_mapping: Mapping from channel name to an index ("prev_seg") or index pair.

    Returns:
        The largest requirement over all indices, and never less than 1.
    """
    highest_nonneg = -1
    largest_from_end = 0

    for name, raw in channel_mapping.items():
        if name == "prev_seg":
            candidates = [int(raw)]
        else:
            candidates = list(parse_channel_pair(name, raw))

        for channel in candidates:
            if channel < 0:
                largest_from_end = max(largest_from_end, -channel)
            else:
                highest_nonneg = max(highest_nonneg, channel)

    return max(highest_nonneg + 1, largest_from_end, 1)
36
+
37
+
38
def transform_coordinates_noresampling(
    coords_orig: Union[List[int], Tuple[int, ...]],
    nnunet_preprocessing_crop_bbox: List[Tuple[int, int]],
) -> Tuple[int, ...]:
    """
    Converts coordinates in the original uncropped image to the internal cropped representation.

    Each coordinate is shifted by the lower bound (first entry) of the crop bounding box
    interval of the same axis. No resampling/scaling is applied.
    """
    return tuple(
        coordinate - nnunet_preprocessing_crop_bbox[axis][0] for axis, coordinate in enumerate(coords_orig)
    )
@@ -0,0 +1,16 @@
1
+ import platform
2
+ import sys
3
+
4
+
5
def is_linux_kernel_6_11():
    """
    Report whether the process runs on Linux with a 6.11 kernel.

    Returns False when sys.platform is not "linux". Otherwise the first two dot-separated
    fields of platform.release() (e.g. '6.11.0-24-generic') are parsed as integers and
    compared against 6 and 11. A release string with fewer than two fields, or with
    fields that are not plain integers, yields False.
    """
    if sys.platform == "linux":
        release_fields = platform.release().split(".")
        try:
            major_minor = (int(release_fields[0]), int(release_fields[1]))
        except (IndexError, ValueError):
            return False
        return major_minor == (6, 11)
    return False
@@ -0,0 +1,13 @@
1
+ from math import ceil, floor
2
+
3
+
4
def round_to_nearest_odd(number: float):
    """
    Round a positive number to an odd integer.

    If the ceiling of ``number`` is odd it is returned; otherwise, if the floor is odd,
    the floor is returned. When both are even (``number`` is an even whole number) the
    result is round(number) + 1.

    Raises:
        AssertionError: If ``number`` is not greater than 0.
    """
    assert number > 0
    # Ceiling is tried first, then floor; the first odd one wins.
    for candidate in (ceil(number), floor(number)):
        if candidate % 2 == 1:
            return candidate
    return round(number) + 1