ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
ultralytics/models/fastsam/predict.py
@@ -1,86 +1,150 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import torch
+from PIL import Image
 
-from ultralytics.engine.results import Results
-from ultralytics.models.fastsam.utils import bbox_iou
-from ultralytics.models.yolo.detect.predict import DetectionPredictor
-from ultralytics.utils import DEFAULT_CFG, ops
+from ultralytics.models.yolo.segment import SegmentationPredictor
+from ultralytics.utils import DEFAULT_CFG, checks
+from ultralytics.utils.metrics import box_iou
+from ultralytics.utils.ops import scale_masks
 
+from .utils import adjust_bboxes_to_image_border
 
-class FastSAMPredictor(DetectionPredictor):
+
+class FastSAMPredictor(SegmentationPredictor):
     """
     FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
     YOLO framework.
 
-    This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
-    It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
-    for single-class segmentation.
-
-    Attributes:
-        cfg (dict): Configuration parameters for prediction.
-        overrides (dict, optional): Optional parameter overrides for custom behavior.
-        _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
+    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
+    adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
+    class segmentation.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework."""
+        super().__init__(cfg, overrides, _callbacks)
+        self.prompts = {}
+
+    def postprocess(self, preds, img, orig_imgs):
+        """Applies box postprocess for FastSAM predictions."""
+        bboxes = self.prompts.pop("bboxes", None)
+        points = self.prompts.pop("points", None)
+        labels = self.prompts.pop("labels", None)
+        texts = self.prompts.pop("texts", None)
+        results = super().postprocess(preds, img, orig_imgs)
+        for result in results:
+            full_box = torch.tensor(
+                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
+            )
+            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
+            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
+            if idx.numel() != 0:
+                result.boxes.xyxy[idx] = full_box
+
+        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
+
+    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
         """
-        Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.
+        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
+        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
 
         Args:
-            cfg (dict): Configuration parameters for prediction.
-            overrides (dict, optional): Optional parameter overrides for custom behavior.
-            _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
+            results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            texts (str | List[str], optional): Textual prompts, a list containing string objects.
+
+        Returns:
+            (List[Results]): The output results determined by prompts.
         """
-        super().__init__(cfg, overrides, _callbacks)
-        self.args.task = "segment"
+        if bboxes is None and points is None and texts is None:
+            return results
+        prompt_results = []
+        if not isinstance(results, list):
+            results = [results]
+        for result in results:
+            if len(result) == 0:
+                prompt_results.append(result)
+                continue
+            masks = result.masks.data
+            if masks.shape[1:] != result.orig_shape:
+                masks = scale_masks(masks[None], result.orig_shape)[0]
+            # bboxes prompt
+            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
+            if bboxes is not None:
+                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
+                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
+                full_mask_areas = torch.sum(masks, dim=(1, 2))
 
-    def postprocess(self, preds, img, orig_imgs):
+                union = bbox_areas[:, None] + full_mask_areas - mask_areas
+                idx[torch.argmax(mask_areas / union, dim=1)] = True
+            if points is not None:
+                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
+                points = points[None] if points.ndim == 1 else points
+                if labels is None:
+                    labels = torch.ones(points.shape[0])
+                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+                assert len(labels) == len(points), (
+                    f"Expected `labels` to have the same size as `points`, but got {len(labels)} and {len(points)}"
+                )
+                point_idx = (
+                    torch.ones(len(result), dtype=torch.bool, device=self.device)
+                    if labels.sum() == 0  # all negative points
+                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
+                )
+                for point, label in zip(points, labels):
+                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
+                idx |= point_idx
+            if texts is not None:
+                if isinstance(texts, str):
+                    texts = [texts]
+                crop_ims, filter_idx = [], []
+                for i, b in enumerate(result.boxes.xyxy.tolist()):
+                    x1, y1, x2, y2 = (int(x) for x in b)
+                    if masks[i].sum() <= 100:
+                        filter_idx.append(i)
+                        continue
+                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
+                similarity = self._clip_inference(crop_ims, texts)
+                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
+                if len(filter_idx):
+                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
+                idx[text_idx] = True
+
+            prompt_results.append(result[idx])
+
+        return prompt_results
+
+    def _clip_inference(self, images, texts):
         """
-        Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
-        size, and returns the final results.
+        CLIP Inference process.
 
         Args:
-            preds (list): The raw output predictions from the model.
-            img (torch.Tensor): The processed image tensor.
-            orig_imgs (list | torch.Tensor): The original image or list of images.
+            images (List[PIL.Image]): A list of source images, each of which should be a PIL.Image with RGB channel order.
+            texts (List[str]): A list of prompt texts, each of which should be a string object.
 
         Returns:
-            (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
+            (torch.Tensor): The similarity between given images and texts.
         """
-        p = ops.non_max_suppression(
-            preds[0],
-            self.args.conf,
-            self.args.iou,
-            agnostic=self.args.agnostic_nms,
-            max_det=self.args.max_det,
-            nc=1,  # set to 1 class since SAM has no class predictions
-            classes=self.args.classes,
-        )
-        full_box = torch.zeros(p[0].shape[1], device=p[0].device)
-        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
-        full_box = full_box.view(1, -1)
-        critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
-        if critical_iou_index.numel() != 0:
-            full_box[0][4] = p[0][critical_iou_index][:, 4]
-            full_box[0][6:] = p[0][critical_iou_index][:, 6:]
-            p[0][critical_iou_index] = full_box
-
-        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
-            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
-
-        results = []
-        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
-        for i, pred in enumerate(p):
-            orig_img = orig_imgs[i]
-            img_path = self.batch[0][i]
-            if not len(pred):  # save empty boxes
-                masks = None
-            elif self.args.retina_masks:
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
-            else:
-                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
-                pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
-        return results
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
+            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
+        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
+        tokenized_text = clip.tokenize(texts).to(self.device)
+        image_features = self.clip_model.encode_image(images)
+        text_features = self.clip_model.encode_text(tokenized_text)
+        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
+        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
+        return (image_features * text_features[:, None]).sum(-1)  # (M, N)
+
+    def set_prompts(self, prompts):
+        """Set prompts in advance."""
+        self.prompts = prompts
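The new set_prompts()/prompt() pair means FastSAM inference can now be steered with boxes, points, or text after the everything-mode pass. A minimal usage sketch, assuming the FastSAM model class forwards the bboxes/points/labels/texts keyword arguments to this predictor (consistent with the ultralytics/models/fastsam/model.py +26 -4 entry in the table above, but not shown in this hunk):

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")  # weight name assumed, not taken from this diff

# Text prompt: everything-mode masks filtered by CLIP similarity (see _clip_inference above)
results = model("path/to/image.jpg", texts="a photo of a dog")

# Box and point prompts: keep the masks that best overlap the given cues
results = model("path/to/image.jpg", bboxes=[100, 100, 400, 400], points=[[250, 250]], labels=[1])
```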
ultralytics/models/fastsam/utils.py
@@ -1,6 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
-
-import torch
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 
 def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
@@ -15,7 +13,6 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     Returns:
         adjusted_boxes (torch.Tensor): adjusted bounding boxes
     """
-
     # Image dimensions
     h, w = image_shape
 
@@ -25,43 +22,3 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
     boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
     boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
     return boxes
-
-
-def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
-    """
-    Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
-
-    Args:
-        box1 (torch.Tensor): (4, )
-        boxes (torch.Tensor): (n, 4)
-        iou_thres (float): IoU threshold
-        image_shape (tuple): (height, width)
-        raw_output (bool): If True, return the raw IoU values instead of the indices
-
-    Returns:
-        high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
-    """
-    boxes = adjust_bboxes_to_image_border(boxes, image_shape)
-    # Obtain coordinates for intersections
-    x1 = torch.max(box1[0], boxes[:, 0])
-    y1 = torch.max(box1[1], boxes[:, 1])
-    x2 = torch.min(box1[2], boxes[:, 2])
-    y2 = torch.min(box1[3], boxes[:, 3])
-
-    # Compute the area of intersection
-    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
-
-    # Compute the area of both individual boxes
-    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
-    box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
-
-    # Compute the area of union
-    union = box1_area + box2_area - intersection
-
-    # Compute the IoU
-    iou = intersection / union  # Should be shape (n, )
-    if raw_output:
-        return 0 if iou.numel() == 0 else iou
-
-    # return indices of boxes with IoU > thres
-    return torch.nonzero(iou > iou_thres).flatten()
ultralytics/models/fastsam/val.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from ultralytics.models.yolo.segment import SegmentationValidator
 from ultralytics.utils.metrics import SegmentMetrics
ultralytics/models/nas/__init__.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from .model import NAS
 from .predict import NASPredictor
ultralytics/models/nas/model.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 YOLO-NAS model interface.
 
@@ -6,8 +6,8 @@ Example:
     ```python
     from ultralytics import NAS
 
-    model = NAS('yolo_nas_s')
-    results = model.predict('ultralytics/assets/bus.jpg')
+    model = NAS("yolo_nas_s")
+    results = model.predict("ultralytics/assets/bus.jpg")
     ```
 """
 
@@ -16,7 +16,9 @@ from pathlib import Path
 import torch
 
 from ultralytics.engine.model import Model
-from ultralytics.utils.torch_utils import model_info, smart_inference_mode
+from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.torch_utils import model_info
+
 from .predict import NASPredictor
 from .val import NASValidator
 
@@ -32,8 +34,8 @@ class NAS(Model):
        ```python
        from ultralytics import NAS

-        model = NAS('yolo_nas_s')
-        results = model.predict('ultralytics/assets/bus.jpg')
+        model = NAS("yolo_nas_s")
+        results = model.predict("ultralytics/assets/bus.jpg")
        ```
 
     Attributes:
@@ -45,19 +47,28 @@ class NAS(Model):
 
     def __init__(self, model="yolo_nas_s.pt") -> None:
         """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
-        assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
+        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
         super().__init__(model, task="detect")
 
-    @smart_inference_mode()
-    def _load(self, weights: str, task: str):
+    def _load(self, weights: str, task=None) -> None:
         """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
         import super_gradients
 
         suffix = Path(weights).suffix
         if suffix == ".pt":
-            self.model = torch.load(weights)
+            self.model = torch.load(attempt_download_asset(weights))
+
         elif suffix == "":
             self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
+
+        # Override the forward method to ignore additional arguments
+        def new_forward(x, *args, **kwargs):
+            """Ignore additional __call__ arguments."""
+            return self.model._original_forward(x)
+
+        self.model._original_forward = self.model.forward
+        self.model.forward = new_forward
+
         # Standardize model
         self.model.fuse = lambda verbose=True: self.model
         self.model.stride = torch.tensor([32])
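The _load() change wraps the super-gradients model's forward so that extra arguments passed through Ultralytics' call pipeline (e.g. augment=...) never reach a signature that cannot accept them. A self-contained sketch of the same wrap-the-forward pattern, using a stand-in nn.Linear rather than an actual YOLO-NAS model:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the loaded super-gradients model

def new_forward(x, *args, **kwargs):
    """Ignore additional __call__ arguments, forwarding only the input tensor."""
    return model._original_forward(x)

model._original_forward = model.forward  # keep the original bound method
model.forward = new_forward  # instance attribute shadows the class method

out = model(torch.randn(1, 4), "extra", augment=True)  # extras are silently dropped
print(out.shape)  # torch.Size([1, 2])
```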
ultralytics/models/nas/predict.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import torch
 
@@ -22,7 +22,7 @@ class NASPredictor(BasePredictor):
        ```python
        from ultralytics import NAS

-        model = NAS('yolo_nas_s')
+        model = NAS("yolo_nas_s")
        predictor = model.predictor
        # Assumes that raw_preds, img, orig_imgs are available
        results = predictor.postprocess(raw_preds, img, orig_imgs)
@@ -34,7 +34,6 @@ class NASPredictor(BasePredictor):
 
     def postprocess(self, preds_in, img, orig_imgs):
         """Postprocess predictions and returns a list of Results objects."""
-
         # Cat boxes and class scores
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
@@ -52,9 +51,7 @@ class NASPredictor(BasePredictor):
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
 
         results = []
-        for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i]
+        for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
-            img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
ultralytics/models/nas/val.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import torch
 
@@ -17,14 +17,14 @@ class NASValidator(DetectionValidator):
     ultimately producing the final detections.
 
     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds.
+        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
         lb (torch.Tensor): Optional tensor for multilabel NMS.
 
     Example:
        ```python
        from ultralytics import NAS

-        model = NAS('yolo_nas_s')
+        model = NAS("yolo_nas_s")
        validator = model.validator
        # Assumes that raw_preds are available
        final_preds = validator.postprocess(raw_preds)
@@ -44,7 +44,7 @@ class NASValidator(DetectionValidator):
             self.args.iou,
             labels=self.lb,
             multi_label=False,
-            agnostic=self.args.single_cls,
+            agnostic=self.args.single_cls or self.args.agnostic_nms,
             max_det=self.args.max_det,
             max_time_img=0.5,
         )
ultralytics/models/rtdetr/__init__.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from .model import RTDETR
 from .predict import RTDETRPredictor
ultralytics/models/rtdetr/model.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 """
 Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
 performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
ultralytics/models/rtdetr/predict.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import torch
 
@@ -21,7 +21,7 @@ class RTDETRPredictor(BasePredictor):
        from ultralytics.utils import ASSETS
        from ultralytics.models.rtdetr import RTDETRPredictor

-        args = dict(model='rtdetr-l.pt', source=ASSETS)
+        args = dict(model="rtdetr-l.pt", source=ASSETS)
        predictor = RTDETRPredictor(overrides=args)
        predictor.predict_cli()
        ```
@@ -56,18 +56,16 @@ class RTDETRPredictor(BasePredictor):
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
 
         results = []
-        for i, bbox in enumerate(bboxes):  # (300, 4)
+        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
             bbox = ops.xywh2xyxy(bbox)
-            score, cls = scores[i].max(-1, keepdim=True)  # (300, 1)
-            idx = score.squeeze(-1) > self.args.conf  # (300, )
+            max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
+            idx = max_score.squeeze(-1) > self.args.conf  # (300, )
             if self.args.classes is not None:
                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
-            pred = torch.cat([bbox, score, cls], dim=-1)[idx]  # filter
-            orig_img = orig_imgs[i]
+            pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
             oh, ow = orig_img.shape[:2]
             pred[..., [0, 2]] *= ow
             pred[..., [1, 3]] *= oh
-            img_path = self.batch[0][i]
             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
         return results
 
ultralytics/models/rtdetr/train.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from copy import copy
 
@@ -7,6 +7,7 @@ import torch
 from ultralytics.models.yolo.detect import DetectionTrainer
 from ultralytics.nn.tasks import RTDETRDetectionModel
 from ultralytics.utils import RANK, colorstr
+
 from .val import RTDETRDataset, RTDETRValidator
 
 
@@ -24,7 +25,7 @@ class RTDETRTrainer(DetectionTrainer):
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

-        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
+        args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
@@ -67,8 +68,11 @@ class RTDETRTrainer(DetectionTrainer):
             hyp=self.args,
             rect=False,
             cache=self.args.cache or None,
+            single_cls=self.args.single_cls or False,
             prefix=colorstr(f"{mode}: "),
+            classes=self.args.classes,
             data=self.data,
+            fraction=self.args.fraction if mode == "train" else 1.0,
         )
 
     def get_validator(self):
ultralytics/models/rtdetr/val.py
@@ -1,4 +1,4 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 import torch
 
@@ -62,7 +62,7 @@ class RTDETRValidator(DetectionValidator):
        ```python
        from ultralytics.models.rtdetr import RTDETRValidator

-        args = dict(model='rtdetr-l.pt', data='coco8.yaml')
+        args = dict(model="rtdetr-l.pt", data="coco8.yaml")
        validator = RTDETRValidator(args=args)
        validator()
        ```
@@ -125,7 +125,7 @@ class RTDETRValidator(DetectionValidator):
         bbox = ops.xywh2xyxy(bbox)  # target boxes
         bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
         bbox[..., [1, 3]] *= ori_shape[0]  # native-space pred
-        return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
+        return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
 
     def _prepare_pred(self, pred, pbatch):
         """Prepares and returns a batch with transformed bounding boxes and class labels."""
ultralytics/models/sam/__init__.py
@@ -1,6 +1,6 @@
-# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from .model import SAM
-from .predict import Predictor
+from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
 
-__all__ = "SAM", "Predictor"  # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list
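With SAM2Predictor and SAM2VideoPredictor now exported, SAM 2 models are reachable through the same SAM entry point. A hedged usage sketch; the weight name "sam2_b.pt" is an assumption based on the SAM 2 modules added in this release (modules/sam.py, memory_attention.py), not something this hunk confirms:

```python
from ultralytics import SAM

model = SAM("sam2_b.pt")  # assumed SAM 2 checkpoint name
results = model("path/to/image.jpg", bboxes=[100, 100, 400, 400])  # box-prompted segmentation
```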