ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics might be problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/amg.py
CHANGED
@@ -8,10 +8,9 @@ import numpy as np
 import torch
 
 
-def is_box_near_crop_edge(boxes: torch.Tensor,
-                          crop_box: List[int],
-                          orig_box: List[int],
-                          atol: float = 20.0) -> torch.Tensor:
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
     """Return a boolean tensor indicating if boxes are near the crop edge."""
     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
@@ -24,10 +23,10 @@ def is_box_near_crop_edge(boxes: torch.Tensor,
 
 def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
     """Yield batches of data from the input arguments."""
-    assert args and all(len(a) == len(args[0]) for a in args), 'Batched iteration must have same-size inputs.'
+    assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
     for b in range(n_batches):
-        yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
 
 
 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
@@ -39,9 +38,8 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     """
     # One mask is always contained inside the other.
     # Save memory by preventing unnecessary cast to torch.int64
-    intersections = (
-        (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
-    unions = ((masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
+    intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
+    unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
     return intersections / unions
 
 
@@ -56,11 +54,12 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
 
 def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
     """Generate point grids for all crop layers."""
-    return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]
+    return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
 
 
-def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
-                        overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
     """
     Generates a list of crop boxes of different sizes.
 
@@ -132,8 +131,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
     """Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
     import cv2  # type: ignore
 
-    assert mode in {'holes', 'islands'}
-    correct_holes = mode == 'holes'
+    assert mode in {"holes", "islands"}
+    correct_holes = mode == "holes"
     working_mask = (correct_holes ^ mask).astype(np.uint8)
     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
     sizes = stats[:, -1][1:]  # Row 0 is background label
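The batch_iterator and calculate_stability_score changes above are purely stylistic. As a minimal sketch of the unchanged behavior (toy inputs, illustrative values only):

import torch
from ultralytics.models.sam.amg import batch_iterator, calculate_stability_score

# batch_iterator yields same-size slices across all inputs; a short final batch is kept.
data = list(range(10))
labels = list(range(10, 20))
for batch in batch_iterator(4, data, labels):
    print(batch)  # [[0, 1, 2, 3], [10, 11, 12, 13]], then items 4..7, then the 2-item remainder

# calculate_stability_score is the IoU of the mask binarized at threshold +offset vs. -offset.
logits = torch.randn(2, 32, 32)
print(calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0))  # one score per mask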
ultralytics/models/sam/build.py
CHANGED
@@ -64,46 +64,47 @@ def build_mobile_sam(checkpoint=None):
     )
 
 
-def _build_sam(encoder_embed_dim,
-               encoder_depth,
-               encoder_num_heads,
-               encoder_global_attn_indexes,
-               checkpoint=None,
-               mobile_sam=False):
+def _build_sam(
+    encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
+):
     """Builds the selected SAM model architecture."""
     prompt_embed_dim = 256
     image_size = 1024
     vit_patch_size = 16
     image_embedding_size = image_size // vit_patch_size
-    image_encoder = (TinyViT(
-        img_size=1024,
-        in_chans=3,
-        num_classes=1000,
-        embed_dims=encoder_embed_dim,
-        depths=encoder_depth,
-        num_heads=encoder_num_heads,
-        window_sizes=[7, 7, 14, 7],
-        mlp_ratio=4.0,
-        drop_rate=0.0,
-        drop_path_rate=0.0,
-        use_checkpoint=False,
-        mbconv_expand_ratio=4.0,
-        local_conv_size=3,
-        layer_lr_decay=0.8,
-    ) if mobile_sam else ImageEncoderViT(
-        depth=encoder_depth,
-        embed_dim=encoder_embed_dim,
-        img_size=image_size,
-        mlp_ratio=4,
-        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
-        num_heads=encoder_num_heads,
-        patch_size=vit_patch_size,
-        qkv_bias=True,
-        use_rel_pos=True,
-        global_attn_indexes=encoder_global_attn_indexes,
-        window_size=14,
-        out_chans=prompt_embed_dim,
-    ))
+    image_encoder = (
+        TinyViT(
+            img_size=1024,
+            in_chans=3,
+            num_classes=1000,
+            embed_dims=encoder_embed_dim,
+            depths=encoder_depth,
+            num_heads=encoder_num_heads,
+            window_sizes=[7, 7, 14, 7],
+            mlp_ratio=4.0,
+            drop_rate=0.0,
+            drop_path_rate=0.0,
+            use_checkpoint=False,
+            mbconv_expand_ratio=4.0,
+            local_conv_size=3,
+            layer_lr_decay=0.8,
+        )
+        if mobile_sam
+        else ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        )
+    )
     sam = Sam(
         image_encoder=image_encoder,
         prompt_encoder=PromptEncoder(
@@ -129,7 +130,7 @@ def _build_sam(encoder_embed_dim,
     )
     if checkpoint is not None:
         checkpoint = attempt_download_asset(checkpoint)
-        with open(checkpoint, 'rb') as f:
+        with open(checkpoint, "rb") as f:
             state_dict = torch.load(f)
         sam.load_state_dict(state_dict)
     sam.eval()
@@ -139,13 +140,14 @@ def _build_sam(encoder_embed_dim,
 
 
 sam_model_map = {
-    'sam_h.pt': build_sam_vit_h,
-    'sam_l.pt': build_sam_vit_l,
-    'sam_b.pt': build_sam_vit_b,
-    'mobile_sam.pt': build_mobile_sam, }
+    "sam_h.pt": build_sam_vit_h,
+    "sam_l.pt": build_sam_vit_l,
+    "sam_b.pt": build_sam_vit_b,
+    "mobile_sam.pt": build_mobile_sam,
+}
 
 
-def build_sam(ckpt='sam_b.pt'):
+def build_sam(ckpt="sam_b.pt"):
     """Build a SAM model specified by ckpt."""
     model_builder = None
     ckpt = str(ckpt)  # to allow Path ckpt types
@@ -154,6 +156,6 @@ def build_sam(ckpt='sam_b.pt'):
         model_builder = sam_model_map.get(k)
 
     if not model_builder:
-        raise FileNotFoundError(f'{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}')
+        raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
 
     return model_builder(ckpt)
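For orientation on the new sam_model_map layout: build_sam matches the checkpoint filename against the map keys and raises for anything else. A hedged usage sketch (note that the first call fetches the sam_b.pt weights if they are not already present locally):

from ultralytics.models.sam.build import build_sam

sam = build_sam("sam_b.pt")  # builds the ViT-B variant, loads weights, returns the model after eval()
try:
    build_sam("sam_x.pt")  # hypothetical name, not a key in sam_model_map
except FileNotFoundError as e:
    print(e)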
ultralytics/models/sam/model.py
CHANGED
@@ -32,7 +32,7 @@ class SAM(Model):
     dataset.
     """
 
-    def __init__(self, model='sam_b.pt') -> None:
+    def __init__(self, model="sam_b.pt") -> None:
         """
         Initializes the SAM model with a pre-trained model file.
 
@@ -42,9 +42,9 @@ class SAM(Model):
         Raises:
             NotImplementedError: If the model file extension is not .pt or .pth.
         """
-        if model and Path(model).suffix not in ('.pt', '.pth'):
-            raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
-        super().__init__(model=model, task='segment')
+        if model and Path(model).suffix not in (".pt", ".pth"):
+            raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        super().__init__(model=model, task="segment")
 
     def _load(self, weights: str, task=None):
         """
@@ -70,7 +70,7 @@ class SAM(Model):
         Returns:
             (list): The model predictions.
         """
-        overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
+        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
         kwargs.update(overrides)
         prompts = dict(bboxes=bboxes, points=points, labels=labels)
         return super().predict(source, stream, prompts=prompts, **kwargs)
@@ -112,4 +112,4 @@ class SAM(Model):
         Returns:
             (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
         """
-        return {'segment': {'predictor': Predictor}}
+        return {"segment": {"predictor": Predictor}}
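Usage is unchanged by this reformat. For context, a minimal prompted-prediction sketch against the SAM class above (the image path is an assumption, not from the package):

from ultralytics import SAM

model = SAM("sam_b.pt")  # only *.pt / *.pth weights are accepted, otherwise NotImplementedError
# predict() applies overrides conf=0.25, task="segment", mode="predict", imgsz=1024, then forwards the prompts.
results = model.predict("bus.jpg", bboxes=[100, 100, 400, 500])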
ultralytics/models/sam/modules/decoders.py
CHANGED
@@ -64,8 +64,9 @@ class MaskDecoder(nn.Module):
             nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
             activation(),
         )
-        self.output_hypernetworks_mlps = nn.ModuleList([
-            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)])
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
+        )
 
         self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
 
@@ -132,13 +133,14 @@ class MaskDecoder(nn.Module):
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
-        mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
 
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
         hyper_in_list: List[torch.Tensor] = [
-            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)]
+            self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
+        ]
         hyper_in = torch.stack(hyper_in_list, dim=1)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
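The hypernetwork block above changed only in bracket placement; the mask math is unchanged. A standalone shape check of what it computes (all dimensions illustrative, not the real model sizes):

import torch

b, n_tokens, c, h, w = 2, 4, 32, 64, 64
hyper_in = torch.randn(b, n_tokens, c)  # one weight vector per mask token, from the MLPs
upscaled_embedding = torch.randn(b, c, h, w)  # output of the upscaling convs

# Each token's vector is dotted with every spatial location to form one mask logit map per token.
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)  # torch.Size([2, 4, 64, 64])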
ultralytics/models/sam/modules/encoders.py
CHANGED
@@ -28,23 +28,23 @@ class ImageEncoderViT(nn.Module):
     """
 
     def __init__(
-            self,
-            img_size: int = 1024,
-            patch_size: int = 16,
-            in_chans: int = 3,
-            embed_dim: int = 768,
-            depth: int = 12,
-            num_heads: int = 12,
-            mlp_ratio: float = 4.0,
-            out_chans: int = 256,
-            qkv_bias: bool = True,
-            norm_layer: Type[nn.Module] = nn.LayerNorm,
-            act_layer: Type[nn.Module] = nn.GELU,
-            use_abs_pos: bool = True,
-            use_rel_pos: bool = False,
-            rel_pos_zero_init: bool = True,
-            window_size: int = 0,
-            global_attn_indexes: Tuple[int, ...] = (),
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
     ) -> None:
         """
         Args:
@@ -283,9 +283,9 @@ class PromptEncoder(nn.Module):
         if masks is not None:
             dense_embeddings = self._embed_masks(masks)
         else:
-            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1,
-                                                                 1).expand(bs, -1, self.image_embedding_size[0],
-                                                                           self.image_embedding_size[1])
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
 
         return sparse_embeddings, dense_embeddings
 
@@ -298,7 +298,7 @@ class PositionEmbeddingRandom(nn.Module):
         super().__init__()
         if scale is None or scale <= 0.0:
             scale = 1.0
-        self.register_buffer('positional_encoding_gaussian_matrix', scale * torch.randn((2, num_pos_feats)))
+        self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
 
         # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
         torch.use_deterministic_algorithms(False)
@@ -425,14 +425,14 @@ class Attention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
+        self.scale = head_dim**-0.5
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
 
         self.use_rel_pos = use_rel_pos
         if self.use_rel_pos:
-            assert input_size is not None, 'Input size must be provided if using relative positional encoding.'
+            assert input_size is not None, "Input size must be provided if using relative positional encoding."
             # Initialize relative positional embeddings
             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
@@ -479,8 +479,9 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
     return windows, (Hp, Wp)
 
 
-def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
-                       hw: Tuple[int, int]) -> torch.Tensor:
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
     """
     Window unpartition into original sequences and removing padding.
 
@@ -523,7 +524,7 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
         rel_pos_resized = F.interpolate(
             rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
             size=max_rel_dist,
-            mode='linear',
+            mode="linear",
         )
         rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
     else:
@@ -567,11 +568,12 @@ def add_decomposed_rel_pos(
 
     B, _, dim = q.shape
     r_q = q.reshape(B, q_h, q_w, dim)
-    rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
-    rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
 
     attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
-        B, q_h * q_w, k_h * k_w)
+        B, q_h * q_w, k_h * k_w
+    )
 
     return attn
 
@@ -580,12 +582,12 @@ class PatchEmbed(nn.Module):
     """Image to Patch Embedding."""
 
     def __init__(
-            self,
-            kernel_size: Tuple[int, int] = (16, 16),
-            stride: Tuple[int, int] = (16, 16),
-            padding: Tuple[int, int] = (0, 0),
-            in_chans: int = 3,
-            embed_dim: int = 768,
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        padding: Tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
     ) -> None:
         """
         Initialize PatchEmbed module.
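The einsum and attention lines above changed only quote style and wrapping. A standalone shape check of the decomposed relative-position addition (sizes illustrative):

import torch

B, q_h, q_w, k_h, k_w, dim = 2, 7, 7, 7, 7, 8
r_q = torch.randn(B, q_h, q_w, dim)
Rh = torch.randn(q_h, k_h, dim)  # relative position embeddings along height
Rw = torch.randn(q_w, k_w, dim)  # relative position embeddings along width

rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # (B, q_h, q_w, k_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # (B, q_h, q_w, k_w)

# Height and width terms broadcast against each other before flattening back to (B, Q, K).
attn = torch.randn(B, q_h * q_w, k_h * k_w)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
    B, q_h * q_w, k_h * k_w
)
print(attn.shape)  # torch.Size([2, 49, 49])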
ultralytics/models/sam/modules/sam.py
CHANGED
@@ -30,8 +30,9 @@ class Sam(nn.Module):
         pixel_mean (List[float]): Mean pixel values for image normalization.
         pixel_std (List[float]): Standard deviation values for image normalization.
     """
+
     mask_threshold: float = 0.0
-    image_format: str = 'RGB'
+    image_format: str = "RGB"
 
     def __init__(
         self,
@@ -39,7 +40,7 @@ class Sam(nn.Module):
         prompt_encoder: PromptEncoder,
         mask_decoder: MaskDecoder,
         pixel_mean: List[float] = (123.675, 116.28, 103.53),
-        pixel_std: List[float] = (58.395, 57.12, 57.375)
+        pixel_std: List[float] = (58.395, 57.12, 57.375),
     ) -> None:
         """
         Initialize the Sam class to predict object masks from an image and input prompts.
@@ -60,5 +61,5 @@ class Sam(nn.Module):
         self.image_encoder = image_encoder
         self.prompt_encoder = prompt_encoder
         self.mask_decoder = mask_decoder
-        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
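The trailing False in register_buffer marks pixel_mean and pixel_std as non-persistent buffers: they move with the module across devices but are excluded from the state dict. A minimal sketch of the channel-wise normalization these buffers support (a stand-in class for illustration, not the package's API):

import torch
import torch.nn as nn

class PreprocessSketch(nn.Module):
    """Illustrative stand-in for Sam's normalization buffers."""

    def __init__(self):
        super().__init__()
        self.register_buffer("pixel_mean", torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1), False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize a (3, H, W) or (B, 3, H, W) image tensor channel-wise.
        return (x - self.pixel_mean) / self.pixel_std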