ultralytics-8.0.195-py3-none-any.whl → ultralytics-8.0.196-py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release.
This version of ultralytics might be problematic.
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +5 -6
- ultralytics/data/augment.py +234 -29
- ultralytics/data/base.py +2 -1
- ultralytics/data/build.py +9 -3
- ultralytics/data/converter.py +5 -2
- ultralytics/data/dataset.py +16 -2
- ultralytics/data/loaders.py +111 -7
- ultralytics/data/utils.py +3 -3
- ultralytics/engine/exporter.py +1 -3
- ultralytics/engine/model.py +3 -9
- ultralytics/engine/predictor.py +10 -6
- ultralytics/engine/results.py +18 -8
- ultralytics/engine/trainer.py +19 -31
- ultralytics/engine/tuner.py +20 -20
- ultralytics/engine/validator.py +3 -4
- ultralytics/hub/__init__.py +2 -2
- ultralytics/hub/auth.py +18 -3
- ultralytics/hub/session.py +1 -0
- ultralytics/hub/utils.py +1 -3
- ultralytics/models/fastsam/model.py +2 -1
- ultralytics/models/fastsam/predict.py +2 -0
- ultralytics/models/fastsam/prompt.py +15 -1
- ultralytics/models/nas/model.py +3 -1
- ultralytics/models/rtdetr/model.py +4 -6
- ultralytics/models/rtdetr/predict.py +2 -1
- ultralytics/models/rtdetr/train.py +2 -1
- ultralytics/models/rtdetr/val.py +1 -0
- ultralytics/models/sam/amg.py +12 -6
- ultralytics/models/sam/model.py +5 -6
- ultralytics/models/sam/modules/decoders.py +5 -1
- ultralytics/models/sam/modules/encoders.py +15 -12
- ultralytics/models/sam/modules/tiny_encoder.py +38 -2
- ultralytics/models/sam/modules/transformer.py +2 -4
- ultralytics/models/sam/predict.py +8 -4
- ultralytics/models/utils/loss.py +35 -8
- ultralytics/models/utils/ops.py +14 -18
- ultralytics/models/yolo/classify/predict.py +1 -0
- ultralytics/models/yolo/classify/train.py +4 -2
- ultralytics/models/yolo/classify/val.py +1 -0
- ultralytics/models/yolo/detect/train.py +4 -3
- ultralytics/models/yolo/model.py +2 -4
- ultralytics/models/yolo/pose/predict.py +1 -0
- ultralytics/models/yolo/segment/predict.py +2 -0
- ultralytics/models/yolo/segment/val.py +1 -1
- ultralytics/nn/autobackend.py +45 -32
- ultralytics/nn/modules/__init__.py +13 -9
- ultralytics/nn/modules/block.py +11 -5
- ultralytics/nn/modules/conv.py +16 -7
- ultralytics/nn/modules/head.py +6 -3
- ultralytics/nn/modules/transformer.py +47 -15
- ultralytics/nn/modules/utils.py +6 -4
- ultralytics/nn/tasks.py +61 -21
- ultralytics/trackers/bot_sort.py +53 -6
- ultralytics/trackers/byte_tracker.py +71 -15
- ultralytics/trackers/track.py +0 -1
- ultralytics/trackers/utils/gmc.py +23 -0
- ultralytics/trackers/utils/kalman_filter.py +6 -6
- ultralytics/utils/__init__.py +31 -18
- ultralytics/utils/autobatch.py +1 -3
- ultralytics/utils/benchmarks.py +14 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/comet.py +11 -3
- ultralytics/utils/callbacks/dvc.py +9 -0
- ultralytics/utils/callbacks/neptune.py +5 -6
- ultralytics/utils/callbacks/wb.py +1 -0
- ultralytics/utils/checks.py +13 -9
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +7 -3
- ultralytics/utils/files.py +3 -3
- ultralytics/utils/instance.py +12 -3
- ultralytics/utils/loss.py +97 -22
- ultralytics/utils/metrics.py +34 -34
- ultralytics/utils/ops.py +10 -9
- ultralytics/utils/patches.py +9 -7
- ultralytics/utils/plotting.py +4 -3
- ultralytics/utils/torch_utils.py +8 -6
- ultralytics/utils/triton.py +2 -1
- {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/METADATA +1 -1
- {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/RECORD +84 -84
- {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.195.dist-info → ultralytics-8.0.196.dist-info}/top_level.txt +0 -0
ultralytics/models/fastsam/prompt.py
CHANGED

@@ -15,6 +15,7 @@ from ultralytics.utils import TQDM
 class FastSAMPrompt:
 
     def __init__(self, source, results, device='cuda') -> None:
+        """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
         self.device = device
         self.results = results
         self.source = source
@@ -30,6 +31,7 @@ class FastSAMPrompt:
 
     @staticmethod
     def _segment_image(image, bbox):
+        """Segments the given image according to the provided bounding box coordinates."""
         image_array = np.array(image)
         segmented_image_array = np.zeros_like(image_array)
         x1, y1, x2, y2 = bbox
@@ -45,6 +47,9 @@ class FastSAMPrompt:
 
     @staticmethod
     def _format_results(result, filter=0):
+        """Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and
+        area.
+        """
         annotations = []
         n = len(result.masks.data) if result.masks is not None else 0
         for i in range(n):
@@ -61,6 +66,9 @@ class FastSAMPrompt:
 
     @staticmethod
     def _get_bbox_from_mask(mask):
+        """Applies morphological transformations to the mask, displays it, and if with_contours is True, draws
+        contours.
+        """
         mask = mask.astype(np.uint8)
         contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
         x1, y1, w, h = cv2.boundingRect(contours[0])
@@ -195,6 +203,7 @@ class FastSAMPrompt:
 
     @torch.no_grad()
     def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+        """Processes images and text with a model, calculates similarity, and returns softmax score."""
         preprocessed_images = [preprocess(image).to(device) for image in elements]
         tokenized_text = self.clip.tokenize([search_text]).to(device)
         stacked_images = torch.stack(preprocessed_images)
@@ -206,6 +215,7 @@ class FastSAMPrompt:
         return probs[:, 0].softmax(dim=0)
 
     def _crop_image(self, format_results):
+        """Crops an image based on provided annotation format and returns cropped images and related data."""
         if os.path.isdir(self.source):
             raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
         image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
@@ -229,6 +239,7 @@ class FastSAMPrompt:
         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
 
     def box_prompt(self, bbox):
+        """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
         if self.results[0].masks is not None:
             assert (bbox[2] != 0 and bbox[3] != 0)
             if os.path.isdir(self.source):
@@ -261,7 +272,8 @@ class FastSAMPrompt:
             self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
         return self.results
 
-    def point_prompt(self, points, pointlabel): # numpy
+    def point_prompt(self, points, pointlabel):  # numpy
+        """Adjusts points on detected masks based on user input and returns the modified results."""
         if self.results[0].masks is not None:
             if os.path.isdir(self.source):
                 raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
@@ -284,6 +296,7 @@ class FastSAMPrompt:
         return self.results
 
     def text_prompt(self, text):
+        """Processes a text prompt, applies it to existing results and returns the updated results."""
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
             cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
@@ -296,4 +309,5 @@ class FastSAMPrompt:
         return self.results
 
     def everything_prompt(self):
+        """Returns the processed results from the previous methods in the class."""
         return self.results
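The hunks above only add docstrings, but together they outline the FastSAMPrompt workflow: build the helper from a source image, model results and a device, then narrow the masks with a box, point or text prompt. A minimal usage sketch follows; the FastSAM model call that produces the results object is assumed from the wider ultralytics.models.fastsam package (it is not part of this diff), and the weight file name and prompt values are illustrative only.

from ultralytics.models.fastsam import FastSAM, FastSAMPrompt

# Run FastSAM to obtain segmentation results (assumed API, not shown in this diff)
model = FastSAM('FastSAM-s.pt')
everything_results = model('bus.jpg', device='cuda', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

# Build the prompt helper documented above and refine the masks
prompt_process = FastSAMPrompt('bus.jpg', everything_results, device='cuda')
ann = prompt_process.everything_prompt()                                # all processed results
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])              # keep the mask with highest IoU to the box
ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])  # foreground point prompt
ann = prompt_process.text_prompt(text='a photo of a dog')               # CLIP-based text prompt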
ultralytics/models/nas/model.py
CHANGED

@@ -25,12 +25,13 @@ from .val import NASValidator
 class NAS(Model):
 
     def __init__(self, model='yolo_nas_s.pt') -> None:
+        """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
         assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
         super().__init__(model, task='detect')
 
     @smart_inference_mode()
     def _load(self, weights: str, task: str):
-
+        """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
         import super_gradients
         suffix = Path(weights).suffix
         if suffix == '.pt':
@@ -58,4 +59,5 @@ class NAS(Model):
 
     @property
     def task_map(self):
+        """Returns a dictionary mapping tasks to respective predictor and validator classes."""
         return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
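For context, the NAS interface documented above is a thin wrapper: the constructor rejects YAML configs, _load pulls weights in via super-gradients, and task_map wires in NASPredictor and NASValidator. A short, hedged usage sketch (the predict call is inherited from the shared Model base class and is assumed here, not shown in this hunk):

from ultralytics import NAS

model = NAS('yolo_nas_s.pt')        # only pre-trained *.pt weights; *.yaml/*.yml are rejected by the assert above
                                    # loading requires the super-gradients package imported in _load()
results = model.predict('bus.jpg')  # assumed inherited Model.predict(), routed via task_map to NASPredictor
print(model.task_map)               # {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}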
ultralytics/models/rtdetr/model.py
CHANGED

@@ -1,7 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-RT-DETR model interface
-"""
+"""RT-DETR model interface."""
 from ultralytics.engine.model import Model
 from ultralytics.nn.tasks import RTDETRDetectionModel
 
@@ -11,17 +9,17 @@ from .val import RTDETRValidator
 
 
 class RTDETR(Model):
-    """
-    RTDETR model interface.
-    """
+    """RTDETR model interface."""
 
     def __init__(self, model='rtdetr-l.pt') -> None:
+        """Initializes the RTDETR model with the given model file, defaulting to 'rtdetr-l.pt'."""
         if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
             raise NotImplementedError('RT-DETR only supports creating from *.pt file or *.yaml file.')
         super().__init__(model=model, task='detect')
 
     @property
     def task_map(self):
+        """Returns a dictionary mapping task names to corresponding Ultralytics task classes for RTDETR model."""
         return {
             'detect': {
                 'predictor': RTDETRPredictor,
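As with NAS, the RTDETR class above is a small interface layer: the constructor checks the file suffix and defers to Model with task='detect', and task_map routes detection to RTDETRPredictor (the mapping is truncated in the hunk above). A hedged usage sketch; predict is assumed from the Model base class:

from ultralytics import RTDETR

model = RTDETR('rtdetr-l.pt')           # *.pt or *.yaml/*.yml only, per the suffix check above
results = model.predict('bus.jpg')      # assumed inherited Model.predict(); dispatched to RTDETRPredictor
print(model.task_map['detect'].keys())  # includes 'predictor', per the mapping shown above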
ultralytics/models/rtdetr/predict.py
CHANGED

@@ -48,7 +48,8 @@ class RTDETRPredictor(BasePredictor):
         return results
 
     def pre_transform(self, im):
-        """
+        """
+        Pre-transform input image before inference.
 
         Args:
             im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
ultralytics/models/rtdetr/val.py
CHANGED

@@ -16,6 +16,7 @@ __all__ = 'RTDETRValidator',  # tuple or list
 class RTDETRDataset(YOLODataset):
 
     def __init__(self, *args, data=None, **kwargs):
+        """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
         super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
 
     # NOTE: add stretch version load_image for rtdetr mosaic
ultralytics/models/sam/amg.py
CHANGED

@@ -32,9 +32,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
 
 def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
     """
-    Computes the stability score for a batch of masks.
-
-    the predicted mask logits at high
+    Computes the stability score for a batch of masks.
+
+    The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
+    and low values.
     """
     # One mask is always contained inside the other.
     # Save memory by preventing unnecessary cast to torch.int64
@@ -60,7 +61,11 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
 
 def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
                         overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
-    """
+    """
+    Generates a list of crop boxes of different sizes.
+
+    Each layer has (2**i)**2 boxes for the ith layer.
+    """
     crop_boxes, layer_idxs = [], []
     im_h, im_w = im_size
     short_side = min(im_h, im_w)
@@ -145,8 +150,9 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
 
 def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
     """
-    Calculates boxes in XYXY format around masks.
-
+    Calculates boxes in XYXY format around masks.
+
+    Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
     """
     # torch.max below raises an error on empty inputs, just skip in this case
     if torch.numel(masks) == 0:
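The reflowed docstrings above describe pure tensor utilities, so their behaviour is easy to sanity-check in isolation. Below is a small sketch under the documented shapes (mask logits of shape (N, H, W) for the stability score; boolean masks for batched_mask_to_box, which returns XYXY boxes and [0, 0, 0, 0] for an empty mask); the expected numbers in the comments come from this toy input, not from the diff itself:

import torch
from ultralytics.models.sam.amg import batched_mask_to_box, calculate_stability_score

logits = torch.zeros(2, 8, 8)   # two mask logit maps: one with a blob, one empty
logits[0, 2:5, 3:7] = 10.0

boxes = batched_mask_to_box(logits > 0)   # XYXY boxes; expected [[3, 2, 6, 4], [0, 0, 0, 0]]
scores = calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0)
print(boxes.shape, scores.shape)          # torch.Size([2, 4]) torch.Size([2])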
ultralytics/models/sam/model.py
CHANGED

@@ -1,7 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-SAM model interface
-"""
+"""SAM model interface."""
 
 from pathlib import Path
 
@@ -13,16 +11,16 @@ from .predict import Predictor
 
 
 class SAM(Model):
-    """
-    SAM model interface.
-    """
+    """SAM model interface."""
 
     def __init__(self, model='sam_b.pt') -> None:
+        """Initializes the SAM model instance with the specified pre-trained model file."""
         if model and Path(model).suffix not in ('.pt', '.pth'):
             raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
         super().__init__(model=model, task='segment')
 
     def _load(self, weights: str, task=None):
+        """Loads the provided weights into the SAM model."""
         self.model = build_sam(weights)
 
     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
@@ -48,4 +46,5 @@ class SAM(Model):
 
     @property
     def task_map(self):
+        """Returns a dictionary mapping the 'segment' task to its corresponding 'Predictor'."""
         return {'segment': {'predictor': Predictor}}
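The SAM interface above mirrors the other model wrappers, but its predict signature (visible in the second hunk) additionally accepts bboxes, points and labels prompts. A hedged usage sketch; the weight file name and prompt coordinates are illustrative only:

from ultralytics import SAM

model = SAM('sam_b.pt')                                             # *.pt / *.pth checkpoints only, per the check above
results = model.predict('bus.jpg', bboxes=[100, 100, 400, 500])     # box prompt
results = model.predict('bus.jpg', points=[250, 300], labels=[1])   # foreground point prompt
print(model.task_map)                                               # {'segment': {'predictor': Predictor}}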
ultralytics/models/sam/modules/decoders.py
CHANGED

@@ -98,7 +98,11 @@ class MaskDecoder(nn.Module):
         sparse_prompt_embeddings: torch.Tensor,
         dense_prompt_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
+        """
+        Predicts masks.
+
+        See 'forward' for more details.
+        """
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
         output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
ultralytics/models/sam/modules/encoders.py
CHANGED

@@ -100,6 +100,9 @@ class ImageEncoderViT(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
+        and neck.
+        """
         x = self.patch_embed(x)
         if self.pos_embed is not None:
             x = x + self.pos_embed
@@ -157,8 +160,8 @@ class PromptEncoder(nn.Module):
 
     def get_dense_pe(self) -> torch.Tensor:
         """
-        Returns the positional encoding used to encode point prompts,
-
+        Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
+        image encoding.
 
         Returns:
           torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
@@ -204,9 +207,7 @@ class PromptEncoder(nn.Module):
         boxes: Optional[torch.Tensor],
         masks: Optional[torch.Tensor],
     ) -> int:
-        """
-        Gets the batch size of the output given the batch size of the input prompts.
-        """
+        """Gets the batch size of the output given the batch size of the input prompts."""
         if points is not None:
             return points[0].shape[0]
         elif boxes is not None:
@@ -217,6 +218,7 @@ class PromptEncoder(nn.Module):
             return 1
 
     def _get_device(self) -> torch.device:
+        """Returns the device of the first point embedding's weight tensor."""
         return self.point_embeddings[0].weight.device
 
     def forward(
@@ -259,11 +261,10 @@ class PromptEncoder(nn.Module):
 
 
 class PositionEmbeddingRandom(nn.Module):
-    """
-    Positional encoding using random spatial frequencies.
-    """
+    """Positional encoding using random spatial frequencies."""
 
     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        """Initializes a position embedding using random spatial frequencies."""
         super().__init__()
         if scale is None or scale <= 0.0:
             scale = 1.0
@@ -304,7 +305,7 @@ class PositionEmbeddingRandom(nn.Module):
 
 
 class Block(nn.Module):
-    """Transformer blocks with support of window attention and residual propagation blocks"""
+    """Transformer blocks with support of window attention and residual propagation blocks."""
 
     def __init__(
         self,
@@ -351,6 +352,7 @@ class Block(nn.Module):
         self.window_size = window_size
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
         shortcut = x
         x = self.norm1(x)
         # Window partition
@@ -404,6 +406,7 @@ class Attention(nn.Module):
             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
         B, H, W, _ = x.shape
         # qkv with shape (3, B, nHead, H * W, C)
         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -448,6 +451,7 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in
                        hw: Tuple[int, int]) -> torch.Tensor:
     """
     Window unpartition into original sequences and removing padding.
+
     Args:
         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
         window_size (int): window size.
@@ -540,9 +544,7 @@ def add_decomposed_rel_pos(
 
 
 class PatchEmbed(nn.Module):
-    """
-    Image to Patch Embedding.
-    """
+    """Image to Patch Embedding."""
 
     def __init__(
         self,
@@ -565,4 +567,5 @@ class PatchEmbed(nn.Module):
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Computes patch embedding by applying convolution and transposing resulting tensor."""
         return self.proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C
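Several of the docstrings above concern PositionEmbeddingRandom, whose get_dense_pe output shape is documented as 1x(embed_dim)x(embedding_h)x(embedding_w). A quick shape check is sketched below; calling the module with an (h, w) grid size follows the upstream SAM implementation and is an assumption, since the forward signature itself is not part of this diff:

import torch
from ultralytics.models.sam.modules.encoders import PositionEmbeddingRandom

pe = PositionEmbeddingRandom(num_pos_feats=64)   # constructor shown above
dense = pe((32, 32))                             # assumed forward(size) over a 32x32 grid
print(dense.shape)                               # expected torch.Size([128, 32, 32]) -> 2 * num_pos_feats channels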
ultralytics/models/sam/modules/tiny_encoder.py
CHANGED

@@ -23,6 +23,9 @@ from ultralytics.utils.instance import to_2tuple
 class Conv2d_BN(torch.nn.Sequential):
 
     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
+        """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
+        drop path.
+        """
         super().__init__()
         self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
@@ -34,6 +37,9 @@ class Conv2d_BN(torch.nn.Sequential):
 class PatchEmbed(nn.Module):
 
     def __init__(self, in_chans, embed_dim, resolution, activation):
+        """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
+        function.
+        """
         super().__init__()
         img_size: Tuple[int, int] = to_2tuple(resolution)
         self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@@ -48,12 +54,16 @@ class PatchEmbed(nn.Module):
         )
 
     def forward(self, x):
+        """Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
         return self.seq(x)
 
 
 class MBConv(nn.Module):
 
     def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
+        """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
+        function.
+        """
         super().__init__()
         self.in_chans = in_chans
         self.hidden_chans = int(in_chans * expand_ratio)
@@ -73,6 +83,7 @@ class MBConv(nn.Module):
             self.drop_path = nn.Identity()
 
     def forward(self, x):
+        """Implements the forward pass for the model architecture."""
         shortcut = x
         x = self.conv1(x)
         x = self.act1(x)
@@ -87,6 +98,9 @@ class MBConv(nn.Module):
 class PatchMerging(nn.Module):
 
     def __init__(self, input_resolution, dim, out_dim, activation):
+        """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
+        optional parameters.
+        """
         super().__init__()
 
         self.input_resolution = input_resolution
@@ -99,6 +113,7 @@ class PatchMerging(nn.Module):
         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
 
     def forward(self, x):
+        """Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
         if x.ndim == 3:
             H, W = self.input_resolution
             B = len(x)
@@ -149,6 +164,7 @@ class ConvLayer(nn.Module):
                 input_resolution, dim=dim, out_dim=out_dim, activation=activation)
 
     def forward(self, x):
+        """Processes the input through a series of convolutional layers and returns the activated output."""
         for blk in self.blocks:
             x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         return x if self.downsample is None else self.downsample(x)
@@ -157,6 +173,7 @@ class ConvLayer(nn.Module):
 class Mlp(nn.Module):
 
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -167,6 +184,7 @@ class Mlp(nn.Module):
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
+        """Applies operations on input x and returns modified x, runs downsample if not None."""
         x = self.norm(x)
         x = self.fc1(x)
         x = self.act(x)
@@ -216,6 +234,7 @@ class Attention(torch.nn.Module):
 
     @torch.no_grad()
     def train(self, mode=True):
+        """Sets the module in training mode and handles attribute 'ab' based on the mode."""
        super().train(mode)
         if mode and hasattr(self, 'ab'):
             del self.ab
@@ -298,6 +317,9 @@ class TinyViTBlock(nn.Module):
         self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
 
     def forward(self, x):
+        """Applies attention-based transformation or padding to input 'x' before passing it through a local
+        convolution.
+        """
         H, W = self.input_resolution
         B, L, C = x.shape
         assert L == H * W, 'input feature has wrong size'
@@ -337,6 +359,9 @@ class TinyViTBlock(nn.Module):
         return x + self.drop_path(self.mlp(x))
 
     def extra_repr(self) -> str:
+        """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
+        attentions heads, window size, and MLP ratio.
+        """
         return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
                f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
 
@@ -402,23 +427,28 @@ class BasicLayer(nn.Module):
                 input_resolution, dim=dim, out_dim=out_dim, activation=activation)
 
     def forward(self, x):
+        """Performs forward propagation on the input tensor and returns a normalized tensor."""
         for blk in self.blocks:
             x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         return x if self.downsample is None else self.downsample(x)
 
     def extra_repr(self) -> str:
+        """Returns a string representation of the extra_repr function with the layer's parameters."""
         return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
 
 
 class LayerNorm2d(nn.Module):
+    """A PyTorch implementation of Layer Normalization in 2D."""
 
     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        """Initialize LayerNorm2d with the number of channels and an optional epsilon."""
         super().__init__()
         self.weight = nn.Parameter(torch.ones(num_channels))
         self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps = eps
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Perform a forward pass, normalizing the input tensor."""
         u = x.mean(1, keepdim=True)
         s = (x - u).pow(2).mean(1, keepdim=True)
         x = (x - u) / torch.sqrt(s + self.eps)
@@ -518,6 +548,7 @@ class TinyViT(nn.Module):
         )
 
     def set_layer_lr_decay(self, layer_lr_decay):
+        """Sets the learning rate decay for each layer in the TinyViT model."""
         decay_rate = layer_lr_decay
 
         # layers -> blocks (depth)
@@ -525,6 +556,7 @@
         lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
 
         def _set_lr_scale(m, scale):
+            """Sets the learning rate scale for each layer in the model based on the layer's depth."""
             for p in m.parameters():
                 p.lr_scale = scale
 
@@ -544,12 +576,14 @@
             p.param_name = k
 
         def _check_lr_scale(m):
+            """Checks if the learning rate scale attribute is present in module's parameters."""
             for p in m.parameters():
                 assert hasattr(p, 'lr_scale'), p.param_name
 
         self.apply(_check_lr_scale)
 
     def _init_weights(self, m):
+        """Initializes weights for linear layers and layer normalization in the given module."""
         if isinstance(m, nn.Linear):
             # NOTE: This initialization is needed only for training.
             # trunc_normal_(m.weight, std=.02)
@@ -561,11 +595,12 @@
 
     @torch.jit.ignore
     def no_weight_decay_keywords(self):
+        """Returns a dictionary of parameter names where weight decay should not be applied."""
         return {'attention_biases'}
 
     def forward_features(self, x):
-
-        x = self.patch_embed(x)
+        """Runs the input through the model layers and returns the transformed output."""
+        x = self.patch_embed(x)  # x input is (N, C, H, W)
 
         x = self.layers[0](x)
         start_i = 1
@@ -579,4 +614,5 @@
         return self.neck(x)
 
     def forward(self, x):
+        """Executes a forward pass on the input tensor through the constructed model layers."""
         return self.forward_features(x)
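The LayerNorm2d forward shown above normalizes an NCHW tensor across its channel dimension before an affine step. Below is a small equivalence check of that math; the final weight/bias line sits just past the end of the hunk, so its exact form is an assumption based on the parameters initialized above:

import torch
from ultralytics.models.sam.modules.tiny_encoder import LayerNorm2d

ln = LayerNorm2d(num_channels=8, eps=1e-6)
x = torch.randn(2, 8, 16, 16)                                        # N, C, H, W

u = x.mean(1, keepdim=True)                                          # per-pixel channel mean
s = (x - u).pow(2).mean(1, keepdim=True)                             # per-pixel channel variance
manual = (x - u) / torch.sqrt(s + ln.eps)
manual = ln.weight[:, None, None] * manual + ln.bias[:, None, None]  # assumed affine step
print(torch.allclose(ln(x), manual, atol=1e-6))                      # expected: True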
ultralytics/models/sam/modules/transformer.py
CHANGED

@@ -21,8 +21,7 @@ class TwoWayTransformer(nn.Module):
         attention_downsample_rate: int = 2,
     ) -> None:
         """
-        A transformer decoder that attends to an input image using
-        queries whose positional embedding is supplied.
+        A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
 
         Args:
           depth (int): number of layers in the transformer
@@ -171,8 +170,7 @@ class TwoWayAttentionBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    """
-    An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
+    """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
     values.
     """
 
ultralytics/models/sam/predict.py
CHANGED

@@ -19,6 +19,7 @@ from .build import build_sam
 class Predictor(BasePredictor):
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes the Predictor class with default or provided configuration, overrides, and callbacks."""
         if overrides is None:
             overrides = {}
         overrides.update(dict(task='segment', mode='predict', imgsz=1024))
@@ -34,7 +35,8 @@ class Predictor(BasePredictor):
         self.segment_all = False
 
     def preprocess(self, im):
-        """
+        """
+        Prepares input image before inference.
 
         Args:
             im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
@@ -189,7 +191,8 @@ class Predictor(BasePredictor):
                  stability_score_thresh=0.95,
                  stability_score_offset=0.95,
                  crop_nms_thresh=0.7):
-        """
+        """
+        Segment the whole image.
 
         Args:
             im (torch.Tensor): The preprocessed image, (N, C, H, W).
@@ -360,14 +363,15 @@ class Predictor(BasePredictor):
         self.prompts = prompts
 
     def reset_image(self):
+        """Resets the image and its features to None."""
         self.im = None
         self.features = None
 
     @staticmethod
     def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
         """
-        Removes small disconnected regions and holes in masks, then reruns
-
+        Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates.
+        Requires open-cv as a dependency.
 
         Args:
             masks (torch.Tensor): Masks, (N, H, W).