ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +122 -0
- tests/test_cuda.py +155 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +216 -0
- tests/test_integrations.py +150 -0
- tests/test_python.py +615 -0
- tests/test_solutions.py +94 -0
- ultralytics/__init__.py +11 -8
- ultralytics/cfg/__init__.py +569 -131
- ultralytics/cfg/datasets/Argoverse.yaml +2 -1
- ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
- ultralytics/cfg/datasets/ImageNet.yaml +2 -1
- ultralytics/cfg/datasets/Objects365.yaml +5 -4
- ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
- ultralytics/cfg/datasets/VOC.yaml +3 -2
- ultralytics/cfg/datasets/VisDrone.yaml +6 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco-pose.yaml +7 -6
- ultralytics/cfg/datasets/coco.yaml +3 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
- ultralytics/cfg/datasets/coco128.yaml +4 -3
- ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
- ultralytics/cfg/datasets/coco8.yaml +3 -2
- ultralytics/cfg/datasets/crack-seg.yaml +3 -2
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8.yaml +3 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1236 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
- ultralytics/cfg/datasets/package-seg.yaml +5 -4
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
- ultralytics/cfg/datasets/xView.yaml +2 -1
- ultralytics/cfg/default.yaml +14 -11
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
- ultralytics/cfg/models/v3/yolov3.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
- ultralytics/cfg/models/v5/yolov5.yaml +5 -2
- ultralytics/cfg/models/v6/yolov6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
- ultralytics/cfg/models/v8/yolov8.yaml +5 -2
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/solutions/default.yaml +24 -0
- ultralytics/cfg/trackers/botsort.yaml +8 -5
- ultralytics/cfg/trackers/bytetrack.yaml +8 -5
- ultralytics/data/__init__.py +14 -3
- ultralytics/data/annotator.py +37 -15
- ultralytics/data/augment.py +1783 -289
- ultralytics/data/base.py +62 -27
- ultralytics/data/build.py +37 -8
- ultralytics/data/converter.py +196 -36
- ultralytics/data/dataset.py +233 -94
- ultralytics/data/loaders.py +199 -96
- ultralytics/data/split_dota.py +39 -29
- ultralytics/data/utils.py +111 -41
- ultralytics/engine/__init__.py +1 -1
- ultralytics/engine/exporter.py +579 -244
- ultralytics/engine/model.py +604 -252
- ultralytics/engine/predictor.py +22 -11
- ultralytics/engine/results.py +1228 -218
- ultralytics/engine/trainer.py +191 -129
- ultralytics/engine/tuner.py +18 -18
- ultralytics/engine/validator.py +18 -15
- ultralytics/hub/__init__.py +31 -13
- ultralytics/hub/auth.py +11 -7
- ultralytics/hub/google/__init__.py +159 -0
- ultralytics/hub/session.py +128 -94
- ultralytics/hub/utils.py +20 -21
- ultralytics/models/__init__.py +4 -2
- ultralytics/models/fastsam/__init__.py +2 -3
- ultralytics/models/fastsam/model.py +26 -4
- ultralytics/models/fastsam/predict.py +127 -63
- ultralytics/models/fastsam/utils.py +1 -44
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +21 -10
- ultralytics/models/nas/predict.py +3 -6
- ultralytics/models/nas/val.py +4 -4
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +1 -1
- ultralytics/models/rtdetr/predict.py +6 -8
- ultralytics/models/rtdetr/train.py +6 -2
- ultralytics/models/rtdetr/val.py +3 -3
- ultralytics/models/sam/__init__.py +3 -3
- ultralytics/models/sam/amg.py +29 -23
- ultralytics/models/sam/build.py +211 -13
- ultralytics/models/sam/model.py +91 -30
- ultralytics/models/sam/modules/__init__.py +1 -1
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +381 -53
- ultralytics/models/sam/modules/encoders.py +515 -324
- ultralytics/models/sam/modules/memory_attention.py +237 -0
- ultralytics/models/sam/modules/sam.py +969 -21
- ultralytics/models/sam/modules/tiny_encoder.py +425 -154
- ultralytics/models/sam/modules/transformer.py +159 -60
- ultralytics/models/sam/modules/utils.py +293 -0
- ultralytics/models/sam/predict.py +1263 -132
- ultralytics/models/utils/__init__.py +1 -1
- ultralytics/models/utils/loss.py +36 -24
- ultralytics/models/utils/ops.py +3 -7
- ultralytics/models/yolo/__init__.py +3 -3
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +7 -8
- ultralytics/models/yolo/classify/train.py +17 -22
- ultralytics/models/yolo/classify/val.py +8 -4
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +3 -5
- ultralytics/models/yolo/detect/train.py +11 -4
- ultralytics/models/yolo/detect/val.py +90 -52
- ultralytics/models/yolo/model.py +14 -9
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +2 -2
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +41 -23
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +3 -5
- ultralytics/models/yolo/pose/train.py +2 -2
- ultralytics/models/yolo/pose/val.py +51 -17
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +3 -5
- ultralytics/models/yolo/segment/train.py +2 -2
- ultralytics/models/yolo/segment/val.py +60 -19
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +92 -0
- ultralytics/models/yolo/world/train_world.py +109 -0
- ultralytics/nn/__init__.py +1 -1
- ultralytics/nn/autobackend.py +228 -93
- ultralytics/nn/modules/__init__.py +39 -14
- ultralytics/nn/modules/activation.py +21 -0
- ultralytics/nn/modules/block.py +526 -66
- ultralytics/nn/modules/conv.py +24 -7
- ultralytics/nn/modules/head.py +177 -34
- ultralytics/nn/modules/transformer.py +6 -5
- ultralytics/nn/modules/utils.py +1 -2
- ultralytics/nn/tasks.py +226 -82
- ultralytics/solutions/__init__.py +30 -1
- ultralytics/solutions/ai_gym.py +96 -143
- ultralytics/solutions/analytics.py +247 -0
- ultralytics/solutions/distance_calculation.py +78 -135
- ultralytics/solutions/heatmap.py +93 -247
- ultralytics/solutions/object_counter.py +184 -259
- ultralytics/solutions/parking_management.py +246 -0
- ultralytics/solutions/queue_management.py +112 -0
- ultralytics/solutions/region_counter.py +116 -0
- ultralytics/solutions/security_alarm.py +144 -0
- ultralytics/solutions/solutions.py +178 -0
- ultralytics/solutions/speed_estimation.py +86 -174
- ultralytics/solutions/streamlit_inference.py +190 -0
- ultralytics/solutions/trackzone.py +68 -0
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +32 -13
- ultralytics/trackers/bot_sort.py +61 -28
- ultralytics/trackers/byte_tracker.py +83 -51
- ultralytics/trackers/track.py +21 -6
- ultralytics/trackers/utils/__init__.py +1 -1
- ultralytics/trackers/utils/gmc.py +62 -48
- ultralytics/trackers/utils/kalman_filter.py +166 -35
- ultralytics/trackers/utils/matching.py +40 -21
- ultralytics/utils/__init__.py +511 -239
- ultralytics/utils/autobatch.py +40 -22
- ultralytics/utils/benchmarks.py +266 -85
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +1 -3
- ultralytics/utils/callbacks/clearml.py +7 -6
- ultralytics/utils/callbacks/comet.py +39 -17
- ultralytics/utils/callbacks/dvc.py +1 -1
- ultralytics/utils/callbacks/hub.py +16 -16
- ultralytics/utils/callbacks/mlflow.py +28 -24
- ultralytics/utils/callbacks/neptune.py +6 -2
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +18 -18
- ultralytics/utils/callbacks/wb.py +27 -20
- ultralytics/utils/checks.py +172 -100
- ultralytics/utils/dist.py +2 -1
- ultralytics/utils/downloads.py +40 -34
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +72 -38
- ultralytics/utils/instance.py +41 -19
- ultralytics/utils/loss.py +83 -55
- ultralytics/utils/metrics.py +61 -56
- ultralytics/utils/ops.py +94 -89
- ultralytics/utils/patches.py +30 -14
- ultralytics/utils/plotting.py +600 -269
- ultralytics/utils/tal.py +67 -26
- ultralytics/utils/torch_utils.py +305 -112
- ultralytics/utils/triton.py +2 -1
- ultralytics/utils/tuner.py +21 -12
- ultralytics-8.3.63.dist-info/METADATA +370 -0
- ultralytics-8.3.63.dist-info/RECORD +241 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
- ultralytics/data/explorer/__init__.py +0 -5
- ultralytics/data/explorer/explorer.py +0 -472
- ultralytics/data/explorer/gui/__init__.py +0 -1
- ultralytics/data/explorer/gui/dash.py +0 -268
- ultralytics/data/explorer/utils.py +0 -166
- ultralytics/models/fastsam/prompt.py +0 -357
- ultralytics-8.1.29.dist-info/METADATA +0 -373
- ultralytics-8.1.29.dist-info/RECORD +0 -197
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
ultralytics/models/fastsam/predict.py
CHANGED
@@ -1,86 +1,150 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch
+from PIL import Image

-from ultralytics.
-from ultralytics.
-from ultralytics.
-from ultralytics.utils import
+from ultralytics.models.yolo.segment import SegmentationPredictor
+from ultralytics.utils import DEFAULT_CFG, checks
+from ultralytics.utils.metrics import box_iou
+from ultralytics.utils.ops import scale_masks

+from .utils import adjust_bboxes_to_image_border

-
+
+class FastSAMPredictor(SegmentationPredictor):
     """
     FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
     YOLO framework.

-    This class extends the
-
-
-
-    Attributes:
-        cfg (dict): Configuration parameters for prediction.
-        overrides (dict, optional): Optional parameter overrides for custom behavior.
-        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
+    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
+    adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
+    class segmentation.
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.prompts = {}
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Applies box postprocess for FastSAM predictions."""
+        bboxes = self.prompts.pop("bboxes", None)
+        points = self.prompts.pop("points", None)
+        labels = self.prompts.pop("labels", None)
+        texts = self.prompts.pop("texts", None)
+        results = super().postprocess(preds, img, orig_imgs)
+        for result in results:
+            full_box = torch.tensor(
+                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
+            )
+            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
+            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
+            if idx.numel() != 0:
+                result.boxes.xyxy[idx] = full_box
+
+        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
+
+    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
         """
-
+        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
+        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.

         Args:
-
-
-
+            results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            texts (str | List[str], optional): Textual prompts, a list contains string objects.
+
+        Returns:
+            (List[Results]): The output results determined by prompts.
         """
-
-
+        if bboxes is None and points is None and texts is None:
+            return results
+        prompt_results = []
+        if not isinstance(results, list):
+            results = [results]
+        for result in results:
+            if len(result) == 0:
+                prompt_results.append(result)
+                continue
+            masks = result.masks.data
+            if masks.shape[1:] != result.orig_shape:
+                masks = scale_masks(masks[None], result.orig_shape)[0]
+            # bboxes prompt
+            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
+            if bboxes is not None:
+                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
+                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
+                full_mask_areas = torch.sum(masks, dim=(1, 2))

-
+                union = bbox_areas[:, None] + full_mask_areas - mask_areas
+                idx[torch.argmax(mask_areas / union, dim=1)] = True
+            if points is not None:
+                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
+                points = points[None] if points.ndim == 1 else points
+                if labels is None:
+                    labels = torch.ones(points.shape[0])
+                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+                assert len(labels) == len(points), (
+                    f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
+                )
+                point_idx = (
+                    torch.ones(len(result), dtype=torch.bool, device=self.device)
+                    if labels.sum() == 0  # all negative points
+                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
+                )
+                for point, label in zip(points, labels):
+                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
+                idx |= point_idx
+            if texts is not None:
+                if isinstance(texts, str):
+                    texts = [texts]
+                crop_ims, filter_idx = [], []
+                for i, b in enumerate(result.boxes.xyxy.tolist()):
+                    x1, y1, x2, y2 = (int(x) for x in b)
+                    if masks[i].sum() <= 100:
+                        filter_idx.append(i)
+                        continue
+                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
+                similarity = self._clip_inference(crop_ims, texts)
+                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
+                if len(filter_idx):
+                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
+                idx[text_idx] = True
+
+            prompt_results.append(result[idx])
+
+        return prompt_results
+
+    def _clip_inference(self, images, texts):
         """
-
-        size, and returns the final results.
+        CLIP Inference process.

         Args:
-
-
-            orig_imgs (list | torch.Tensor): The original image or list of images.
+            images (List[PIL.Image]): A list of source images and each of them should be PIL.Image type with RGB channel order.
+            texts (List[str]): A list of prompt texts and each of them should be string object.

         Returns:
-            (
+            (torch.Tensor): The similarity between given images and texts.
         """
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
-        for i, pred in enumerate(p):
-            orig_img = orig_imgs[i]
-            img_path = self.batch[0][i]
-            if not len(pred):  # save empty boxes
-                masks = None
-            elif self.args.retina_masks:
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
-            else:
-                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
-        return results
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
+            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
+        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
+        tokenized_text = clip.tokenize(texts).to(self.device)
+        image_features = self.clip_model.encode_image(images)
+        text_features = self.clip_model.encode_text(tokenized_text)
+        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
+        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
+        return (image_features * text_features[:, None]).sum(-1)  # (M, N)
+
+    def set_prompts(self, prompts):
+        """Set prompts in advance."""
+        self.prompts = prompts
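The rewrite above turns FastSAM's predictor into a prompt-driven flow: `set_prompts()` stashes prompts, `postprocess()` pops and applies them, and `prompt()` filters the everything-mode masks by box overlap, point hits, or CLIP text similarity. A minimal usage sketch of the public entry point (the `FastSAM-s.pt` weight name and asset path are illustrative, not taken from this diff):

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")  # assumed downloadable weight file

# Everything-mode segmentation, then prompt-based filtering of the results
results = model("ultralytics/assets/bus.jpg", bboxes=[439, 437, 524, 709])  # box prompt (XYXY)
results = model("ultralytics/assets/bus.jpg", points=[[200, 200]], labels=[1])  # foreground point
results = model("ultralytics/assets/bus.jpg", texts="a photo of a bus")  # CLIP-scored text prompt
```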
ultralytics/models/fastsam/utils.py
CHANGED
@@ -1,6 +1,4 @@
-# Ultralytics
-
-import torch
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license


 def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
@@ -15,7 +13,6 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     Returns:
         adjusted_boxes (torch.Tensor): adjusted bounding boxes
     """
-
     # Image dimensions
     h, w = image_shape

@@ -25,43 +22,3 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
     boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
     return boxes
-
-
-def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
-    """
-    Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
-
-    Args:
-        box1 (torch.Tensor): (4, )
-        boxes (torch.Tensor): (n, 4)
-        iou_thres (float): IoU threshold
-        image_shape (tuple): (height, width)
-        raw_output (bool): If True, return the raw IoU values instead of the indices
-
-    Returns:
-        high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
-    """
-    boxes = adjust_bboxes_to_image_border(boxes, image_shape)
-    # Obtain coordinates for intersections
-    x1 = torch.max(box1[0], boxes[:, 0])
-    y1 = torch.max(box1[1], boxes[:, 1])
-    x2 = torch.min(box1[2], boxes[:, 2])
-    y2 = torch.min(box1[3], boxes[:, 3])
-
-    # Compute the area of intersection
-    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
-
-    # Compute the area of both individual boxes
-    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
-    box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
-
-    # Compute the area of union
-    union = box1_area + box2_area - intersection
-
-    # Compute the IoU
-    iou = intersection / union  # Should be shape (n, )
-    if raw_output:
-        return 0 if iou.numel() == 0 else iou
-
-    # return indices of boxes with IoU > thres
-    return torch.nonzero(iou > iou_thres).flatten()
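With `bbox_iou` removed (its role is now covered by `ultralytics.utils.metrics.box_iou` in the predictor), `adjust_bboxes_to_image_border` is the only helper left in this module. A quick sketch of what it does: any box edge within `threshold` pixels of the image border is snapped flush onto that border.

```python
import torch

from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

# Every edge of this box lies within the default 20 px threshold of a 640x640 border.
boxes = torch.tensor([[15.0, 10.0, 625.0, 630.0]])
print(adjust_bboxes_to_image_border(boxes, image_shape=(640, 640)))
# tensor([[  0.,   0., 640., 640.]]) -> snapped to the full image extent
```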
ultralytics/models/nas/model.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 YOLO-NAS model interface.

@@ -6,8 +6,8 @@ Example:
     ```python
     from ultralytics import NAS

-    model = NAS(
-    results = model.predict(
+    model = NAS("yolo_nas_s")
+    results = model.predict("ultralytics/assets/bus.jpg")
     ```
 """

@@ -16,7 +16,9 @@ from pathlib import Path
 import torch

 from ultralytics.engine.model import Model
-from ultralytics.utils.
+from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.torch_utils import model_info
+
 from .predict import NASPredictor
 from .val import NASValidator

@@ -32,8 +34,8 @@ class NAS(Model):
         ```python
         from ultralytics import NAS

-        model = NAS(
-        results = model.predict(
+        model = NAS("yolo_nas_s")
+        results = model.predict("ultralytics/assets/bus.jpg")
         ```

     Attributes:
@@ -45,19 +47,28 @@ class NAS(Model):

     def __init__(self, model="yolo_nas_s.pt") -> None:
         """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
-        assert Path(model).suffix not in
+        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
         super().__init__(model, task="detect")

-
-    def _load(self, weights: str, task: str):
+    def _load(self, weights: str, task=None) -> None:
         """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
         import super_gradients

         suffix = Path(weights).suffix
         if suffix == ".pt":
-            self.model = torch.load(weights)
+            self.model = torch.load(attempt_download_asset(weights))
+
         elif suffix == "":
             self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
+
+            # Override the forward method to ignore additional arguments
+            def new_forward(x, *args, **kwargs):
+                """Ignore additional __call__ arguments."""
+                return self.model._original_forward(x)
+
+            self.model._original_forward = self.model.forward
+            self.model.forward = new_forward
+
         # Standardize model
         self.model.fuse = lambda verbose=True: self.model
         self.model.stride = torch.tensor([32])
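The new `_load` logic wraps the super-gradients model's `forward` so extra keyword arguments passed along by Ultralytics internals (e.g. `augment`, `visualize`) are silently dropped instead of raising a `TypeError`. A standalone sketch of the same monkey-patch pattern on a toy module (class and argument names here are illustrative):

```python
import torch
import torch.nn as nn


class Strict(nn.Module):
    """Toy stand-in for a third-party model whose forward() accepts only x."""

    def forward(self, x):
        return x * 2


model = Strict()
model._original_forward = model.forward  # keep a handle on the original bound method
model.forward = lambda x, *args, **kwargs: model._original_forward(x)  # shadow it per-instance

print(model(torch.ones(1), augment=True, visualize=False))  # extra kwargs ignored -> tensor([2.])
```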
ultralytics/models/nas/predict.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch

@@ -22,7 +22,7 @@ class NASPredictor(BasePredictor):
         ```python
         from ultralytics import NAS

-        model = NAS(
+        model = NAS("yolo_nas_s")
         predictor = model.predictor
         # Assumes that raw_preds, img, orig_imgs are available
         results = predictor.postprocess(raw_preds, img, orig_imgs)
@@ -34,7 +34,6 @@ class NASPredictor(BasePredictor):

     def postprocess(self, preds_in, img, orig_imgs):
         """Postprocess predictions and returns a list of Results objects."""
-
         # Cat boxes and class scores
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
@@ -52,9 +51,7 @@ class NASPredictor(BasePredictor):
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

         results = []
-        for
-            orig_img = orig_imgs[i]
+        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
ultralytics/models/nas/val.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch

@@ -17,14 +17,14 @@ class NASValidator(DetectionValidator):
     ultimately producing the final detections.

     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
         lb (torch.Tensor): Optional tensor for multilabel NMS.

     Example:
         ```python
         from ultralytics import NAS

-        model = NAS(
+        model = NAS("yolo_nas_s")
         validator = model.validator
         # Assumes that raw_preds are available
         final_preds = validator.postprocess(raw_preds)
@@ -44,7 +44,7 @@ class NASValidator(DetectionValidator):
             self.args.iou,
             labels=self.lb,
             multi_label=False,
-            agnostic=self.args.single_cls,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
             max_det=self.args.max_det,
             max_time_img=0.5,
         )
ultralytics/models/rtdetr/model.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
 performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
ultralytics/models/rtdetr/predict.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch

@@ -21,7 +21,7 @@ class RTDETRPredictor(BasePredictor):
         from ultralytics.utils import ASSETS
         from ultralytics.models.rtdetr import RTDETRPredictor

-        args = dict(model=
+        args = dict(model="rtdetr-l.pt", source=ASSETS)
         predictor = RTDETRPredictor(overrides=args)
         predictor.predict_cli()
         ```
@@ -56,18 +56,16 @@ class RTDETRPredictor(BasePredictor):
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

         results = []
-        for
+        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
             bbox = ops.xywh2xyxy(bbox)
-
-            idx =
+            max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
+            idx = max_score.squeeze(-1) > self.args.conf  # (300, )
             if self.args.classes is not None:
                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
-            pred = torch.cat([bbox,
-            orig_img = orig_imgs[i]
+            pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
             oh, ow = orig_img.shape[:2]
             pred[..., [0, 2]] *= ow
             pred[..., [1, 3]] *= oh
-            img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results

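The reworked loop above filters RT-DETR's fixed set of decoder queries by confidence before scaling, with no NMS step. A self-contained sketch of the same tensor logic on dummy data (shapes and the 0.25 threshold are illustrative):

```python
import torch

num_queries, num_classes, conf = 300, 80, 0.25
bbox = torch.rand(num_queries, 4)  # normalized (x, y, w, h) per query
score = torch.rand(num_queries, num_classes)  # per-class scores per query

max_score, cls = score.max(-1, keepdim=True)  # (300, 1) best score and class id per query
idx = max_score.squeeze(-1) > conf  # (300,) boolean keep-mask
pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # (n, 6) filtered detections
print(pred.shape)
```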
ultralytics/models/rtdetr/train.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from copy import copy

@@ -7,6 +7,7 @@ import torch
 from ultralytics.models.yolo.detect import DetectionTrainer
 from ultralytics.nn.tasks import RTDETRDetectionModel
 from ultralytics.utils import RANK, colorstr
+
 from .val import RTDETRDataset, RTDETRValidator


@@ -24,7 +25,7 @@ class RTDETRTrainer(DetectionTrainer):
         ```python
         from ultralytics.models.rtdetr.train import RTDETRTrainer

-        args = dict(model=
+        args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
         trainer = RTDETRTrainer(overrides=args)
         trainer.train()
         ```
@@ -67,8 +68,11 @@ class RTDETRTrainer(DetectionTrainer):
             hyp=self.args,
             rect=False,
             cache=self.args.cache or None,
+            single_cls=self.args.single_cls or False,
             prefix=colorstr(f"{mode}: "),
+            classes=self.args.classes,
             data=self.data,
+            fraction=self.args.fraction if mode == "train" else 1.0,
         )

     def get_validator(self):
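`build_dataset` now forwards `single_cls`, `classes`, and `fraction` to the RT-DETR dataset, so the standard training arguments behave the same as for YOLO models (note `fraction` applies to the train split only; validation always uses 1.0). A hedged sketch of driving these through the public API (the weight and dataset names are the usual downloadable defaults, not pinned by this diff):

```python
from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")
model.train(
    data="coco8.yaml",
    epochs=1,
    imgsz=640,
    fraction=0.5,  # subsample half of the train split
    classes=[0],  # keep only class-0 labels
    single_cls=False,
)
```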
ultralytics/models/rtdetr/val.py
CHANGED
@@ -1,4 +1,4 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 import torch

@@ -62,7 +62,7 @@ class RTDETRValidator(DetectionValidator):
         ```python
         from ultralytics.models.rtdetr import RTDETRValidator

-        args = dict(model=
+        args = dict(model="rtdetr-l.pt", data="coco8.yaml")
         validator = RTDETRValidator(args=args)
         validator()
         ```
@@ -125,7 +125,7 @@ class RTDETRValidator(DetectionValidator):
         bbox = ops.xywh2xyxy(bbox)  # target boxes
         bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
         bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
-        return
+        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}

     def _prepare_pred(self, pred, pbatch):
         """Prepares and returns a batch with transformed bounding boxes and class labels."""
ultralytics/models/sam/__init__.py
CHANGED
@@ -1,6 +1,6 @@
-# Ultralytics
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

 from .model import SAM
-from .predict import Predictor
+from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor

-__all__ = "SAM", "Predictor"  # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list