nnInteractive 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. nnInteractive/__init__.py +3 -0
  2. nnInteractive/inference/__init__.py +0 -0
  3. nnInteractive/inference/cvpr2025_challenge_baseline/__init__.py +0 -0
  4. nnInteractive/inference/cvpr2025_challenge_baseline/predict.py +173 -0
  5. nnInteractive/inference/inference_session.py +1400 -0
  6. nnInteractive/interaction/__init__.py +0 -0
  7. nnInteractive/interaction/point.py +166 -0
  8. nnInteractive/supervoxel/setup.py +4 -0
  9. nnInteractive/supervoxel/src/metadata.py +118 -0
  10. nnInteractive/supervoxel/src/reader.py +175 -0
  11. nnInteractive/supervoxel/src/run.py +136 -0
  12. nnInteractive/supervoxel/src/sam2/__init__.py +2 -0
  13. nnInteractive/supervoxel/src/sam2/sam2/__init__.py +11 -0
  14. nnInteractive/supervoxel/src/sam2/sam2/automatic_mask_generator.py +434 -0
  15. nnInteractive/supervoxel/src/sam2/sam2/benchmark.py +86 -0
  16. nnInteractive/supervoxel/src/sam2/sam2/build_sam.py +172 -0
  17. nnInteractive/supervoxel/src/sam2/sam2/modeling/__init__.py +5 -0
  18. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/__init__.py +5 -0
  19. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/hieradet.py +305 -0
  20. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/image_encoder.py +132 -0
  21. nnInteractive/supervoxel/src/sam2/sam2/modeling/backbones/utils.py +89 -0
  22. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_attention.py +167 -0
  23. nnInteractive/supervoxel/src/sam2/sam2/modeling/memory_encoder.py +179 -0
  24. nnInteractive/supervoxel/src/sam2/sam2/modeling/position_encoding.py +217 -0
  25. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/__init__.py +5 -0
  26. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/mask_decoder.py +274 -0
  27. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/prompt_encoder.py +194 -0
  28. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam/transformer.py +293 -0
  29. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_base.py +879 -0
  30. nnInteractive/supervoxel/src/sam2/sam2/modeling/sam2_utils.py +315 -0
  31. nnInteractive/supervoxel/src/sam2/sam2/sam2_image_predictor.py +433 -0
  32. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor.py +1171 -0
  33. nnInteractive/supervoxel/src/sam2/sam2/sam2_video_predictor_legacy.py +1125 -0
  34. nnInteractive/supervoxel/src/sam2/sam2/utils/__init__.py +5 -0
  35. nnInteractive/supervoxel/src/sam2/sam2/utils/amg.py +332 -0
  36. nnInteractive/supervoxel/src/sam2/sam2/utils/misc.py +488 -0
  37. nnInteractive/supervoxel/src/sam2/sam2/utils/transforms.py +108 -0
  38. nnInteractive/supervoxel/src/sam2/setup.py +174 -0
  39. nnInteractive/supervoxel/src/sam2/training/__init__.py +5 -0
  40. nnInteractive/supervoxel/src/sam2/training/dataset/__init__.py +5 -0
  41. nnInteractive/supervoxel/src/sam2/training/dataset/sam2_datasets.py +176 -0
  42. nnInteractive/supervoxel/src/sam2/training/dataset/transforms.py +481 -0
  43. nnInteractive/supervoxel/src/sam2/training/dataset/utils.py +102 -0
  44. nnInteractive/supervoxel/src/sam2/training/dataset/vos_dataset.py +154 -0
  45. nnInteractive/supervoxel/src/sam2/training/dataset/vos_raw_dataset.py +290 -0
  46. nnInteractive/supervoxel/src/sam2/training/dataset/vos_sampler.py +103 -0
  47. nnInteractive/supervoxel/src/sam2/training/dataset/vos_segment_loader.py +289 -0
  48. nnInteractive/supervoxel/src/sam2/training/loss_fns.py +290 -0
  49. nnInteractive/supervoxel/src/sam2/training/model/__init__.py +5 -0
  50. nnInteractive/supervoxel/src/sam2/training/model/sam2.py +515 -0
  51. nnInteractive/supervoxel/src/sam2/training/optimizer.py +462 -0
  52. nnInteractive/supervoxel/src/sam2/training/scripts/sav_frame_extraction_submitit.py +157 -0
  53. nnInteractive/supervoxel/src/sam2/training/train.py +232 -0
  54. nnInteractive/supervoxel/src/sam2/training/trainer.py +1051 -0
  55. nnInteractive/supervoxel/src/sam2/training/utils/__init__.py +5 -0
  56. nnInteractive/supervoxel/src/sam2/training/utils/checkpoint_utils.py +328 -0
  57. nnInteractive/supervoxel/src/sam2/training/utils/data_utils.py +166 -0
  58. nnInteractive/supervoxel/src/sam2/training/utils/distributed.py +560 -0
  59. nnInteractive/supervoxel/src/sam2/training/utils/logger.py +236 -0
  60. nnInteractive/supervoxel/src/sam2/training/utils/train_utils.py +275 -0
  61. nnInteractive/supervoxel/src/supervoxel.py +198 -0
  62. nnInteractive/trainer/__init__.py +0 -0
  63. nnInteractive/trainer/nnInteractiveTrainer.py +24 -0
  64. nnInteractive/utils/__init__.py +0 -0
  65. nnInteractive/utils/bboxes.py +217 -0
  66. nnInteractive/utils/checkpoint_cleansing.py +9 -0
  67. nnInteractive/utils/crop.py +268 -0
  68. nnInteractive/utils/erosion_dilation.py +48 -0
  69. nnInteractive/utils/inference_helpers.py +45 -0
  70. nnInteractive/utils/os_shennanigans.py +16 -0
  71. nnInteractive/utils/rounding.py +13 -0
  72. nninteractive-2.0.0.dist-info/METADATA +511 -0
  73. nninteractive-2.0.0.dist-info/RECORD +76 -0
  74. nninteractive-2.0.0.dist-info/WHEEL +5 -0
  75. nninteractive-2.0.0.dist-info/licenses/LICENSE +201 -0
  76. nninteractive-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,315 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import copy
9
+ from typing import Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from sam2.utils.misc import mask_to_box
17
+
18
+
19
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Pick at most `max_cond_frame_num` entries of `cond_frame_outputs` (a dict keyed
    by frame index) that lie temporally closest to `frame_idx`.

    Selection order:
    - a) the nearest conditioning frame strictly before `frame_idx` (if any);
    - b) the nearest conditioning frame at or after `frame_idx` (if any);
    - c) the remaining frames ordered by |t - frame_idx| (ties keep dict order)
      until `max_cond_frame_num` frames have been chosen.

    If `max_cond_frame_num` is -1, or there are no more frames than the budget,
    the input dict itself is returned as the selection.

    Outputs:
    - selected_outputs: chosen items (keys & values) from `cond_frame_outputs`.
    - unselected_outputs: the items of `cond_frame_outputs` that were not chosen.
    """
    # No pruning needed: keep everything.
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"

    earlier = [t for t in cond_frame_outputs if t < frame_idx]
    later = [t for t in cond_frame_outputs if t >= frame_idx]

    chosen = {}
    if earlier:
        nearest_before = max(earlier)
        chosen[nearest_before] = cond_frame_outputs[nearest_before]
    if later:
        nearest_after = min(later)
        chosen[nearest_after] = cond_frame_outputs[nearest_after]

    # Fill the remaining budget with the temporally closest leftovers.
    budget = max_cond_frame_num - len(chosen)
    leftovers = sorted(
        (t for t in cond_frame_outputs if t not in chosen),
        key=lambda t: abs(t - frame_idx),
    )
    for t in leftovers[:budget]:
        chosen[t] = cond_frame_outputs[t]

    rest = {t: v for t, v in cond_frame_outputs.items() if t not in chosen}
    return chosen, rest
60
+
61
+
62
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Build a 1D sine/cosine positional embedding as in the original Transformer paper.

    Args:
        pos_inds: tensor of positions; a trailing axis is added for the embedding.
        dim: embedding size; the output last axis has size 2 * (dim // 2).
        temperature: base of the geometric frequency progression.

    Returns:
        Tensor of shape pos_inds.shape + (2 * (dim // 2),) holding the sines of
        all frequencies followed by the cosines.
    """
    half = dim // 2
    exponents = torch.arange(half, dtype=torch.float32, device=pos_inds.device)
    # Consecutive channel pairs share one frequency (note the floor division).
    freqs = temperature ** (2 * (exponents // 2) / half)

    angles = pos_inds.unsqueeze(-1) / freqs
    return torch.cat([angles.sin(), angles.cos()], dim=-1)
73
+
74
+
75
def get_activation_fn(activation):
    """
    Return the functional activation named by `activation`.

    Args:
        activation: one of "relu", "gelu" or "glu".

    Returns:
        The matching function from `torch.nn.functional` (F.relu, F.gelu or F.glu).

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name (previously it omitted "glu").
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
84
+
85
+
86
def get_clones(module, N):
    """Return an nn.ModuleList of `N` independent deep copies of `module`."""
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)
88
+
89
+
90
class DropPath(nn.Module):
    """
    Randomly zero entire samples (along dim 0) of the input while training.

    Adapted from timm:
    https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py

    Args:
        drop_prob: probability that a given sample is zeroed.
        scale_by_keep: if True, surviving samples are divided by (1 - drop_prob)
            so the expected output equals the input.
    """

    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Identity in eval mode or when nothing can be dropped.
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining dim.
        mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
        keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        # Skip rescaling when keep_prob is 0 to avoid dividing by zero.
        if self.scale_by_keep and keep_prob > 0.0:
            keep_mask.div_(keep_prob)
        return x * keep_mask
106
+
107
+
108
+ # Lightly adapted from
109
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
110
class MLP(nn.Module):
    """
    Stack of `num_layers` nn.Linear layers with `activation` between them.

    The first layer maps input_dim -> hidden_dim, intermediate layers map
    hidden_dim -> hidden_dim and the last maps to output_dim. The activation is
    applied after every layer except the last; an optional sigmoid is applied to
    the final output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden
        out_dims = hidden + [output_dim]
        # Attribute names ("layers", "act") determine state_dict keys; keep them stable.
        self.layers = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)])
        self.sigmoid_output = sigmoid_output
        self.act = activation()

    def forward(self, x):
        last = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last:
                x = self.act(x)
        return torch.sigmoid(x) if self.sigmoid_output else x
133
+
134
+
135
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
136
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
137
class LayerNorm2d(nn.Module):
    """
    Layer normalization over the channel axis (dim 1) of a [N, C, H, W] tensor,
    followed by a learned per-channel affine transform.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased (population) variance across channels.
        var = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(var + self.eps)
        # Broadcast per-channel parameters over the two trailing spatial dims.
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normed + shift
150
+
151
+
152
def sample_box_points(
    masks: torch.Tensor,
    noise: float = 0.1,  # SAM default
    noise_bound: int = 20,  # SAM default
    top_left_label: int = 2,
    bottom_right_label: int = 3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a noised version of the top-left and bottom-right corners of each mask's bounding box.

    Inputs:
    - masks: [B, 1, H, W] masks, dtype=torch.Tensor
    - noise: noise as a fraction of box width and height, dtype=float
    - noise_bound: maximum amount of noise (in pixels), dtype=int or torch.Tensor
    - top_left_label / bottom_right_label: labels assigned to the two corners

    Returns:
    - box_coords: [B, 2, 2], (x, y) coordinates of the top-left and bottom-right box
      corners, dtype=torch.float (shape assumes mask_to_box returns one
      (x0, y0, x1, y1) box per mask, i.e. [B, 1, 4] — TODO confirm)
    - box_labels: [B, 2], `top_left_label` for the top-left and `bottom_right_label`
      for the bottom-right corner, dtype=torch.int32
    """
    device = masks.device
    # mask_to_box may return integer coordinates; cast up front so the output is
    # float whether or not noise is applied (the noisy path already produced float).
    box_coords = mask_to_box(masks).to(torch.float)
    B, _, H, W = masks.shape
    box_labels = torch.tensor([top_left_label, bottom_right_label], dtype=torch.int, device=device).repeat(B)
    if noise > 0.0:
        if not isinstance(noise_bound, torch.Tensor):
            noise_bound = torch.tensor(noise_bound, device=device)
        bbox_w = box_coords[..., 2] - box_coords[..., 0]
        bbox_h = box_coords[..., 3] - box_coords[..., 1]
        # Noise is proportional to box size but capped at `noise_bound` pixels.
        max_dx = torch.min(bbox_w * noise, noise_bound)
        max_dy = torch.min(bbox_h * noise, noise_bound)
        # Uniform noise in [-1, 1) scaled per coordinate.
        box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
        box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)

        box_coords = box_coords + box_noise
        img_bounds = torch.tensor([W, H, W, H], device=device) - 1  # uncentered pixel coords
        box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds)  # In place clamping

    box_coords = box_coords.reshape(-1, 2, 2)  # always 2 points
    box_labels = box_labels.reshape(-1, 2)
    return box_coords, box_labels
192
+
193
+
194
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
    """
    Independently sample `num_pt` random clicks (with labels) from the error regions.

    A click inside a false-positive region is negative (label 0); one inside a
    false-negative region is positive (label 1). If a prediction matches its
    ground truth exactly, a negative click is drawn from the background (~gt).

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None (treated as empty)
    - num_pt: int, number of points to sample independently for each of the B error maps

    Outputs:
    - points: [B, num_pt, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, num_pt], dtype=torch.int32, 1 for positive and 0 for negative clicks
    """
    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    assert num_pt >= 0

    batch, _, height, width = gt_masks.shape

    false_pos = ~gt_masks & pred_masks
    false_neg = gt_masks & ~pred_masks
    # [B, 1, 1, 1]: True where the prediction matches the ground truth everywhere.
    perfect = torch.all((gt_masks == pred_masks).flatten(2), dim=2)[..., None, None]

    # Last axis: slot 0 holds negative-click candidates, slot 1 positive ones.
    candidates = torch.stack([false_pos | (perfect & ~gt_masks), false_neg], dim=-1)
    scores = torch.rand(batch, num_pt, height, width, 2, device=gt_masks.device) * candidates

    # The argmax over (pixel, slot) picks a random candidate; its slot is the label.
    flat_idx = scores.flatten(2).argmax(dim=2)
    labels = (flat_idx % 2).to(torch.int32)
    pixel_idx = flat_idx // 2
    points = torch.stack([pixel_idx % width, pixel_idx // width], dim=2).to(torch.float)
    return points, labels
242
+
243
+
244
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Sample one click (with its label) per mask at the "center" of the error regions,
    i.e. the error pixel farthest from its region boundary (RITM-style clicker:
    https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py).

    The false-negative and false-positive distance maps are compared; the click is
    positive (label 1) if the deepest FN pixel is strictly deeper than the deepest
    FP pixel, otherwise negative (label 0) at the deepest FP pixel.

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None (treated as empty)
    - padding: if True, pad each map with a 1 px zero border before the distance
      transform so the image edge counts as a boundary

    Outputs:
    - points: [B, 1, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, 1 for positive and 0 for negative clicks
    """
    import cv2

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape

    batch, _, _, width = gt_masks.shape
    device = gt_masks.device

    # FP pixels call for negative clicks, FN pixels for positive ones.
    fp_np = (~gt_masks & pred_masks).cpu().numpy()
    fn_np = (gt_masks & ~pred_masks).cpu().numpy()

    def _flat_boundary_distance(region):
        # Euclidean distance of each region pixel to the region boundary, flattened.
        if padding:
            region = np.pad(region, ((1, 1), (1, 1)), "constant")
        dist = cv2.distanceTransform(region.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            dist = dist[1:-1, 1:-1]
        return dist.reshape(-1)

    points = torch.zeros(batch, 1, 2, dtype=torch.float)
    labels = torch.ones(batch, 1, dtype=torch.int32)
    for b in range(batch):
        fn_dist = _flat_boundary_distance(fn_np[b, 0])
        fp_dist = _flat_boundary_distance(fp_np[b, 0])
        fn_best = np.argmax(fn_dist)
        fp_best = np.argmax(fp_dist)
        positive = fn_dist[fn_best] > fp_dist[fp_best]
        chosen = fn_best if positive else fp_best
        points[b, 0, 0] = chosen % width  # x
        points[b, 0, 1] = chosen // width  # y
        labels[b, 0] = int(positive)

    return points.to(device), labels.to(device)
307
+
308
+
309
def get_next_point(gt_masks, pred_masks, method):
    """
    Sample the next corrective click for each mask.

    Args:
        gt_masks, pred_masks: forwarded unchanged to the chosen sampler.
        method: "uniform" -> sample_random_points_from_errors (one point per mask),
            "center" -> sample_one_point_from_error_center.

    Raises:
        ValueError: for any other `method`.
    """
    if method not in ("uniform", "center"):
        raise ValueError(f"unknown sampling method {method}")
    if method == "uniform":
        return sample_random_points_from_errors(gt_masks, pred_masks)
    return sample_one_point_from_error_center(gt_masks, pred_masks)