frontveg 0.1.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- frontveg/__init__.py +11 -0
- frontveg/_tests/__init__.py +0 -0
- frontveg/_tests/test_widget.py +66 -0
- frontveg/_version.py +21 -0
- frontveg/_widget.py +132 -0
- frontveg/napari.yaml +14 -0
- frontveg/utils.py +95 -0
- frontveg-0.1.dev1.dist-info/METADATA +143 -0
- frontveg-0.1.dev1.dist-info/RECORD +44 -0
- frontveg-0.1.dev1.dist-info/WHEEL +5 -0
- frontveg-0.1.dev1.dist-info/entry_points.txt +2 -0
- frontveg-0.1.dev1.dist-info/licenses/LICENSE +28 -0
- frontveg-0.1.dev1.dist-info/top_level.txt +2 -0
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/build_sam.py +167 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +95 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +221 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +182 -0
- sam2/modeling/sam/transformer.py +360 -0
- sam2/modeling/sam2_base.py +907 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +1 -0
- sam2/sam2_hiera_l.yaml +1 -0
- sam2/sam2_hiera_s.yaml +1 -0
- sam2/sam2_hiera_t.yaml +1 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
sam2/utils/transforms.py
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
|
4
|
+
# This source code is licensed under the license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import warnings
|
8
|
+
|
9
|
+
import torch
|
10
|
+
import torch.nn as nn
|
11
|
+
import torch.nn.functional as F
|
12
|
+
from torchvision.transforms import Normalize, Resize, ToTensor
|
13
|
+
|
14
|
+
|
15
|
+
class SAM2Transforms(nn.Module):
    """
    Image, prompt and mask transforms used around the SAM2 model.

    Images are converted with ``ToTensor``, resized to a square
    ``resolution`` x ``resolution`` and normalized with a fixed per-channel
    mean/std (the commonly used ImageNet statistics). Point and box prompts are
    rescaled into the same ``resolution``-sized frame. Mask scores can optionally
    be cleaned up (small background holes filled, small foreground "sprinkles"
    removed) before being resized back to the original image size.
    """

    def __init__(
        self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
    ):
        """
        Transforms for SAM2.

        Args:
            resolution: side length of the square image fed to the model; also
                the factor prompt coordinates are scaled by.
            mask_threshold: score threshold used by ``postprocess_masks``;
                scores > threshold are foreground, <= threshold background.
            max_hole_area: background connected components with area <= this
                value are filled in ``postprocess_masks``; 0 disables it.
            max_sprinkle_area: foreground connected components with area <= this
                value are removed in ``postprocess_masks``; 0 disables it.
        """
        super().__init__()
        self.resolution = resolution
        self.mask_threshold = mask_threshold
        self.max_hole_area = max_hole_area
        self.max_sprinkle_area = max_sprinkle_area
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.to_tensor = ToTensor()
        # Resize + Normalize are scripted once with TorchScript and reused for
        # every image.
        self.transforms = torch.jit.script(
            nn.Sequential(
                Resize((self.resolution, self.resolution)),
                Normalize(self.mean, self.std),
            )
        )

    def __call__(self, x):
        """
        Convert one image to a tensor, then resize and normalize it.

        ``x`` is anything torchvision's ``ToTensor`` accepts (e.g. a PIL image or
        an HxWxC numpy array). NOTE(review): this overrides ``nn.Module.__call__``
        directly instead of defining ``forward``, so module hooks are not run.
        """
        x = self.to_tensor(x)
        return self.transforms(x)

    def forward_batch(self, img_list):
        """
        Transform each image in ``img_list`` and stack them along a new dim 0.

        Every image is resized to the same square size, so the results can be
        stacked into a single (N, C, resolution, resolution) batch.
        """
        img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
        img_batch = torch.stack(img_batch, dim=0)
        return img_batch

    def transform_coords(
        self, coords: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Map prompt coordinates into the model's ``resolution``-sized input frame.

        Args:
            coords: tensor whose last dimension has length 2, holding (x, y)
                pairs, either already normalized to [0, 1] or in absolute
                pixel coordinates of the original image.
            normalize: set to True when ``coords`` are absolute pixel
                coordinates; x is then divided by the original width and y by
                the original height before scaling. Integer tensors are promoted
                to float32 first so the division is not truncated.
            orig_hw: (height, width) of the original image; required when
                ``normalize`` is True.

        Returns:
            A new tensor of the same shape with the coordinates multiplied by
            ``resolution`` (i.e. within [0, resolution] for in-image points).
            The input tensor is never modified.

        Raises:
            ValueError: if ``normalize`` is True and ``orig_hw`` is None.
        """
        if normalize:
            # Explicit check instead of ``assert``, which ``python -O`` strips.
            if orig_hw is None:
                raise ValueError("orig_hw is required when normalize=True")
            h, w = orig_hw
            # Writing a fractional quotient back into an integer tensor would
            # silently truncate it (typically to 0), so promote integer inputs.
            # ``.float()`` on a non-float tensor already returns a copy; clone
            # float inputs so the caller's tensor is not changed in place.
            coords = coords.clone() if coords.is_floating_point() else coords.float()
            coords[..., 0] = coords[..., 0] / w
            coords[..., 1] = coords[..., 1] / h

        coords = coords * self.resolution  # unnormalize coords
        return coords

    def transform_boxes(
        self, boxes: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Map box prompts into the model's ``resolution``-sized input frame.

        Args:
            boxes: tensor of shape (B, 4); each row is treated as two (x, y)
                corner points. Coordinates may be normalized to [0, 1] or
                absolute image coordinates (then set ``normalize=True`` and pass
                ``orig_hw``), exactly as in ``transform_coords``.
            normalize: see ``transform_coords``.
            orig_hw: see ``transform_coords``.

        Returns:
            Tensor of shape (B, 2, 2): the two corner points of each box after
            ``transform_coords``.
        """
        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
        return boxes

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Optionally clean up mask scores, then resize them to ``orig_hw``.

        Args:
            masks: 4-D tensor of mask scores. Its first two dims are flattened
                for connected-component analysis, and ``F.interpolate`` with a
                2-element size requires 4-D input (presumably (B, C, H, W) —
                confirm against callers).
            orig_hw: target (height, width) for the bilinear resize.

        Returns:
            Float tensor of mask scores bilinearly resized to ``orig_hw``. If the
            connected-component step fails, a ``UserWarning`` is emitted and the
            un-cleaned scores are resized instead.
        """
        from sam2.utils.misc import get_connected_components

        masks = masks.float()
        input_masks = masks
        mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
        try:
            if self.max_hole_area > 0:
                # Holes are those connected components in background with area <= self.max_hole_area
                # (background regions are those with mask scores <= self.mask_threshold)
                labels, areas = get_connected_components(
                    mask_flat <= self.mask_threshold
                )
                is_hole = (labels > 0) & (areas <= self.max_hole_area)
                is_hole = is_hole.reshape_as(masks)
                # We fill holes with a small positive mask score (10.0) to change them to foreground.
                masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)

            if self.max_sprinkle_area > 0:
                # Sprinkles are small foreground components with area <= self.max_sprinkle_area.
                # NOTE(review): they are detected on the scores *before* hole
                # filling (mask_flat is not recomputed) — confirm this is intended.
                labels, areas = get_connected_components(
                    mask_flat > self.mask_threshold
                )
                is_sprinkle = (labels > 0) & (areas <= self.max_sprinkle_area)
                is_sprinkle = is_sprinkle.reshape_as(masks)
                # We fill sprinkles with negative mask score (-10.0) to change them to background.
                masks = torch.where(is_sprinkle, self.mask_threshold - 10.0, masks)
        except Exception as e:
            # Deliberate best-effort: skip the post-processing step if the
            # CUDA kernel fails, and fall back to the unprocessed scores.
            warnings.warn(
                f"{e}\n\nSkipping the post-processing step due to the error above. You can "
                "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
                "functionality may be limited (which doesn't affect the results in most cases; see "
                "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
                category=UserWarning,
                stacklevel=2,
            )
            masks = input_masks

        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks
|