frontveg 0.1.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. frontveg/__init__.py +11 -0
  2. frontveg/_tests/__init__.py +0 -0
  3. frontveg/_tests/test_widget.py +66 -0
  4. frontveg/_version.py +21 -0
  5. frontveg/_widget.py +132 -0
  6. frontveg/napari.yaml +14 -0
  7. frontveg/utils.py +95 -0
  8. frontveg-0.1.dev1.dist-info/METADATA +143 -0
  9. frontveg-0.1.dev1.dist-info/RECORD +44 -0
  10. frontveg-0.1.dev1.dist-info/WHEEL +5 -0
  11. frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
  12. frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
  13. frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
  14. sam2/__init__.py +11 -0
  15. sam2/automatic_mask_generator.py +454 -0
  16. sam2/build_sam.py +167 -0
  17. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  18. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  19. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  20. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  21. sam2/modeling/__init__.py +5 -0
  22. sam2/modeling/backbones/__init__.py +5 -0
  23. sam2/modeling/backbones/hieradet.py +317 -0
  24. sam2/modeling/backbones/image_encoder.py +134 -0
  25. sam2/modeling/backbones/utils.py +95 -0
  26. sam2/modeling/memory_attention.py +169 -0
  27. sam2/modeling/memory_encoder.py +181 -0
  28. sam2/modeling/position_encoding.py +221 -0
  29. sam2/modeling/sam/__init__.py +5 -0
  30. sam2/modeling/sam/mask_decoder.py +295 -0
  31. sam2/modeling/sam/prompt_encoder.py +182 -0
  32. sam2/modeling/sam/transformer.py +360 -0
  33. sam2/modeling/sam2_base.py +907 -0
  34. sam2/modeling/sam2_utils.py +323 -0
  35. sam2/sam2_hiera_b+.yaml +1 -0
  36. sam2/sam2_hiera_l.yaml +1 -0
  37. sam2/sam2_hiera_s.yaml +1 -0
  38. sam2/sam2_hiera_t.yaml +1 -0
  39. sam2/sam2_image_predictor.py +466 -0
  40. sam2/sam2_video_predictor.py +1172 -0
  41. sam2/utils/__init__.py +5 -0
  42. sam2/utils/amg.py +348 -0
  43. sam2/utils/misc.py +349 -0
  44. sam2/utils/transforms.py +118 -0
@@ -0,0 +1,323 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import copy
9
+ from typing import Tuple
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from sam2.utils.misc import mask_to_box
17
+
18
+
19
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
    """
    Pick at most `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
    that lie temporally closest to the current frame `frame_idx`.

    Selection priority:
    - a) the nearest conditioning frame strictly before `frame_idx` (if any);
    - b) the nearest conditioning frame at or after `frame_idx` (if any);
    - c) the remaining frames ordered by absolute distance to `frame_idx`,
      until `max_cond_frame_num` frames have been chosen.

    A `max_cond_frame_num` of -1 disables the limit.

    Outputs:
    - selected_outputs: chosen items (keys & values) from `cond_frame_outputs`.
      When no pruning is needed this is `cond_frame_outputs` itself (not a copy).
    - unselected_outputs: items (keys & values) of `cond_frame_outputs` that were not chosen.
    """
    # Nothing to prune: the limit is disabled or already satisfied.
    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
        return cond_frame_outputs, {}

    assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"

    chosen = {}

    earlier = [t for t in cond_frame_outputs if t < frame_idx]
    if earlier:
        nearest_before = max(earlier)
        chosen[nearest_before] = cond_frame_outputs[nearest_before]

    later = [t for t in cond_frame_outputs if t >= frame_idx]
    if later:
        nearest_after = min(later)
        chosen[nearest_after] = cond_frame_outputs[nearest_after]

    # Fill the remaining budget with the temporally closest leftover frames;
    # `list.sort` is stable, so ties keep the dict's iteration order.
    budget = max_cond_frame_num - len(chosen)
    leftovers = [t for t in cond_frame_outputs if t not in chosen]
    leftovers.sort(key=lambda t: abs(t - frame_idx))
    for t in leftovers[:budget]:
        chosen[t] = cond_frame_outputs[t]

    rest = {t: v for t, v in cond_frame_outputs.items() if t not in chosen}
    return chosen, rest
62
+
63
+
64
def get_1d_sine_pe(pos_inds, dim, temperature=10000):
    """
    Compute 1D sinusoidal positional embeddings as in the original Transformer paper.

    A trailing axis of size `2 * (dim // 2)` is appended to `pos_inds`. It holds
    the sine terms followed by the cosine terms.
    """
    half_dim = dim // 2
    freq = torch.arange(half_dim, dtype=torch.float32, device=pos_inds.device)
    # Consecutive channel pairs share the same frequency exponent.
    freq = temperature ** (2 * torch.div(freq, 2, rounding_mode="floor") / half_dim)

    angles = pos_inds[..., None] / freq
    return torch.cat((angles.sin(), angles.cos()), dim=-1)
75
+
76
+
77
def get_activation_fn(activation):
    """
    Return the functional activation named by `activation`.

    Supported names: "relu", "gelu" and "glu". They map to `F.relu`, `F.gelu`
    and `F.glu` respectively.

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # List every accepted name (including "glu") so the error is actionable.
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
86
+
87
+
88
def get_clones(module, N):
    """Return an `nn.ModuleList` holding `N` independent deep copies of `module`."""
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
90
+
91
+
92
class DropPath(nn.Module):
    """
    Per-sample path dropping (stochastic depth).

    In training mode, each sample along dim 0 is zeroed with probability `drop_prob`.
    When `scale_by_keep` is set, the surviving samples are divided by the keep
    probability. In eval mode, or when `drop_prob == 0.0`, the input is returned unchanged.
    """

    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
    def __init__(self, drop_prob=0.0, scale_by_keep=True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Identity at inference time or when dropping is disabled.
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining dim.
        mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
        keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            # Rescale survivors so the expected output matches the input.
            keep_mask.div_(keep_prob)
        return x * keep_mask
108
+
109
+
110
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    """
    Stack of linear layers with widths input_dim -> hidden_dim -> ... -> output_dim.

    `activation` is a module class. It is instantiated once and applied after
    every layer except the last. If `sigmoid_output` is True, a sigmoid is
    applied to the final output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        hidden_dims = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim, *hidden_dims]
        out_dims = [*hidden_dims, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )
        self.sigmoid_output = sigmoid_output
        self.act = activation()

    def forward(self, x):
        last_idx = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            # No activation after the final projection.
            if idx < last_idx:
                x = self.act(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
137
+
138
+
139
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """
    Normalize over dim 1 (channels) using biased mean/variance, then apply a
    per-channel affine transform (`weight`, `bias`).

    The affine parameters are broadcast over the trailing two dims.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(dim=1, keepdim=True)
        variance = centered.pow(2).mean(dim=1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        # weight/bias are per-channel; [:, None, None] broadcasts them over the last two dims.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
154
+
155
+
156
def sample_box_points(
    masks: torch.Tensor,
    noise: float = 0.1,  # SAM default
    noise_bound: int = 20,  # SAM default
    top_left_label: int = 2,
    bottom_right_label: int = 3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a noised version of the top-left and bottom-right corners of each mask's bounding box.

    Inputs:
    - masks: [B, 1, H, W] masks, dtype=torch.Tensor
    - noise: noise as a fraction of box width and height, dtype=float
    - noise_bound: maximum amount of noise (in pixels), dtype=int
    - top_left_label / bottom_right_label: labels assigned to the two corner points

    Returns:
    - box_coords: [B, 2, 2], (x, y) coordinates of the top-left and bottom-right box corners,
      dtype=torch.float (regardless of whether noise is applied)
    - box_labels: [B, 2], `top_left_label` (default 2) for the top-left and `bottom_right_label`
      (default 3) for the bottom-right corner, dtype=torch.int32
    """
    device = masks.device
    # Cast up front so the documented float dtype also holds when noise == 0.
    # mask_to_box may return integer coordinates; previously only the noisy
    # path (which adds a float tensor) produced float output.
    box_coords = mask_to_box(masks).to(torch.float)
    B, _, H, W = masks.shape
    box_labels = torch.tensor(
        [top_left_label, bottom_right_label], dtype=torch.int, device=device
    ).repeat(B)
    if noise > 0.0:
        if not isinstance(noise_bound, torch.Tensor):
            noise_bound = torch.tensor(noise_bound, device=device)
        bbox_w = box_coords[..., 2] - box_coords[..., 0]
        bbox_h = box_coords[..., 3] - box_coords[..., 1]
        # Per-axis noise amplitude: a fraction of the box size, capped at noise_bound.
        max_dx = torch.min(bbox_w * noise, noise_bound)
        max_dy = torch.min(bbox_h * noise, noise_bound)
        # Uniform noise in [-1, 1), scaled per coordinate (x0, y0, x1, y1).
        box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
        box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)

        box_coords = box_coords + box_noise
        img_bounds = torch.tensor([W, H, W, H], device=device) - 1  # uncentered pixel coords
        box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds)  # in-place clamping

    box_coords = box_coords.reshape(-1, 2, 2)  # always 2 points
    box_labels = box_labels.reshape(-1, 2)
    return box_coords, box_labels
200
+
201
+
202
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
    """
    Independently sample `num_pt` random corrective clicks (with labels) per mask
    from the error regions between `gt_masks` and `pred_masks`.

    Each click is the candidate pixel with the largest random score, i.e. a
    uniform draw over the admissible pixels:
    - false-negative pixels give positive clicks (label 1);
    - false-positive pixels give negative clicks (label 0);
    - if a prediction matches its ground truth exactly, a negative click is
      drawn from the background instead.

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None (None is treated as empty)
    - num_pt: int, number of points to sample independently for each of the B error maps

    Outputs:
    - points: [B, num_pt, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, num_pt], dtype=torch.int32, 1 for positive and 0 for negative clicks
    """
    if pred_masks is None:  # a missing prediction is treated as empty
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    assert num_pt >= 0

    batch, _, height, width = gt_masks.shape

    # FP pixels call for a negative click, FN pixels for a positive one.
    false_pos = pred_masks & ~gt_masks
    false_neg = gt_masks & ~pred_masks
    # [B, 1, 1, 1]: True where the prediction matches the ground truth exactly.
    perfect = (gt_masks == pred_masks).flatten(2).all(dim=2)[..., None, None]

    # Random score per (mask, point, y, x, channel); channel 0 = negative, 1 = positive.
    scores = torch.rand(batch, num_pt, height, width, 2, device=gt_masks.device)
    # Zero the scores outside the admissible regions. A perfect prediction has
    # no error, so it falls back to a negative click on the background.
    scores[..., 0] *= false_pos | (perfect & ~gt_masks)
    scores[..., 1] *= false_neg

    # The flat argmax index encodes (y * W + x) * 2 + channel.
    best = scores.flatten(2).argmax(dim=2)
    labels = (best % 2).to(torch.int32)
    pixel = best // 2
    points = torch.stack([pixel % width, pixel // width], dim=2).to(torch.float)
    return points, labels
250
+
251
+
252
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Sample one click per mask at the "center" of the error.

    The click is the false-negative or false-positive pixel with the largest
    distance to the boundary of its error region. This is the RITM sampling method from
    https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None (None is treated as empty)
    - padding: if True, pad each error map with a 1 px border of zeros before the
      distance transform (the border is cropped off afterwards)

    Outputs:
    - points: [B, 1, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, 1 when the FN region has the strictly larger
      distance (positive click), otherwise 0 (negative click)
    """
    import cv2

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape

    num_masks, _, _, width = gt_masks.shape
    device = gt_masks.device

    # FP pixels call for a negative click, FN pixels for a positive one.
    fp_np = (~gt_masks & pred_masks).cpu().numpy()
    fn_np = (gt_masks & ~pred_masks).cpu().numpy()

    def _flat_boundary_distance(region):
        # Distance of each pixel in `region` to the region boundary, flattened row-major.
        if padding:
            region = np.pad(region, ((1, 1), (1, 1)), "constant")
        dist = cv2.distanceTransform(region.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            dist = dist[1:-1, 1:-1]
        return dist.reshape(-1)

    points = torch.zeros(num_masks, 1, 2, dtype=torch.float)
    labels = torch.ones(num_masks, 1, dtype=torch.int32)
    for i in range(num_masks):
        fn_dist = _flat_boundary_distance(fn_np[i, 0])
        fp_dist = _flat_boundary_distance(fp_np[i, 0])
        fn_best = np.argmax(fn_dist)
        fp_best = np.argmax(fp_dist)
        # Click in whichever error region reaches deeper; ties go to the negative click.
        positive = fn_dist[fn_best] > fp_dist[fp_best]
        flat_idx = fn_best if positive else fp_best
        points[i, 0, 0] = flat_idx % width  # x
        points[i, 0, 1] = flat_idx // width  # y
        labels[i, 0] = int(positive)

    return points.to(device), labels.to(device)
315
+
316
+
317
def get_next_point(gt_masks, pred_masks, method):
    """
    Sample the next corrective click from the error between `gt_masks` and `pred_masks`.

    `method` selects the sampler:
    - "uniform": `sample_random_points_from_errors` with its default of one point
    - "center": `sample_one_point_from_error_center`

    Returns the sampler's `(points, labels)` pair.

    Raises:
        ValueError: for any other `method`.
    """
    if method == "uniform":
        return sample_random_points_from_errors(gt_masks, pred_masks)
    if method == "center":
        return sample_one_point_from_error_center(gt_masks, pred_masks)
    raise ValueError(f"unknown sampling method {method}")
@@ -0,0 +1 @@
1
+ configs/sam2/sam2_hiera_b+.yaml
sam2/sam2_hiera_l.yaml ADDED
@@ -0,0 +1 @@
1
+ configs/sam2/sam2_hiera_l.yaml
sam2/sam2_hiera_s.yaml ADDED
@@ -0,0 +1 @@
1
+ configs/sam2/sam2_hiera_s.yaml
sam2/sam2_hiera_t.yaml ADDED
@@ -0,0 +1 @@
1
+ configs/sam2/sam2_hiera_t.yaml