ultralytics-opencv-headless 8.3.242__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/models/fastsam/predict.py
@@ -0,0 +1,169 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import torch
+ from PIL import Image
+
+ from ultralytics.models.yolo.segment import SegmentationPredictor
+ from ultralytics.utils import DEFAULT_CFG
+ from ultralytics.utils.metrics import box_iou
+ from ultralytics.utils.ops import scale_masks
+ from ultralytics.utils.torch_utils import TORCH_1_10
+
+ from .utils import adjust_bboxes_to_image_border
+
+
+ class FastSAMPredictor(SegmentationPredictor):
+     """FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
+
+     This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
+     adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
+     single-class segmentation.
+
+     Attributes:
+         prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
+         device (torch.device): Device on which model and tensors are processed.
+         clip (Any, optional): CLIP model used for text-based prompting, loaded on demand.
+
+     Methods:
+         postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
+         prompt: Perform image segmentation inference based on various prompt types.
+         set_prompts: Set prompts to be used during inference.
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """Initialize the FastSAMPredictor with configuration and callbacks.
+
+         This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
+         extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
+         optimized for single-class segmentation.
+
+         Args:
+             cfg (dict): Configuration for the predictor.
+             overrides (dict, optional): Configuration overrides.
+             _callbacks (list, optional): List of callback functions.
+         """
+         super().__init__(cfg, overrides, _callbacks)
+         self.prompts = {}
+
+     def postprocess(self, preds, img, orig_imgs):
+         """Apply postprocessing to FastSAM predictions and handle prompts.
+
+         Args:
+             preds (list[torch.Tensor]): Raw predictions from the model.
+             img (torch.Tensor): Input image tensor that was fed to the model.
+             orig_imgs (list[np.ndarray]): Original images before preprocessing.
+
+         Returns:
+             (list[Results]): Processed results with prompts applied.
+         """
+         bboxes = self.prompts.pop("bboxes", None)
+         points = self.prompts.pop("points", None)
+         labels = self.prompts.pop("labels", None)
+         texts = self.prompts.pop("texts", None)
+         results = super().postprocess(preds, img, orig_imgs)
+         for result in results:
+             full_box = torch.tensor(
+                 [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
+             )
+             boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
+             idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
+             if idx.numel() != 0:
+                 result.boxes.xyxy[idx] = full_box
+
+         return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
+
+     def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
+         """Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
+
+         Args:
+             results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
+             bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
+             points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
+             labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+             texts (str | list[str], optional): Textual prompts, a list containing string objects.
+
+         Returns:
+             (list[Results]): Output results filtered and determined by the provided prompts.
+         """
+         if bboxes is None and points is None and texts is None:
+             return results
+         prompt_results = []
+         if not isinstance(results, list):
+             results = [results]
+         for result in results:
+             if len(result) == 0:
+                 prompt_results.append(result)
+                 continue
+             masks = result.masks.data
+             if masks.shape[1:] != result.orig_shape:
+                 masks = (scale_masks(masks[None].float(), result.orig_shape)[0] > 0.5).byte()
+             # bboxes prompt
+             idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
+             if bboxes is not None:
+                 bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
+                 bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+                 bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+                 mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
+                 full_mask_areas = torch.sum(masks, dim=(1, 2))
+
+                 union = bbox_areas[:, None] + full_mask_areas - mask_areas
+                 idx[torch.argmax(mask_areas / union, dim=1)] = True
+             if points is not None:
+                 points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
+                 points = points[None] if points.ndim == 1 else points
+                 if labels is None:
+                     labels = torch.ones(points.shape[0])
+                 labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+                 assert len(labels) == len(points), (
+                     f"Expected `labels` to have the same length as `points`, but got {len(labels)} and {len(points)}."
+                 )
+                 point_idx = (
+                     torch.ones(len(result), dtype=torch.bool, device=self.device)
+                     if labels.sum() == 0  # all negative points
+                     else torch.zeros(len(result), dtype=torch.bool, device=self.device)
+                 )
+                 for point, label in zip(points, labels):
+                     point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
+                 idx |= point_idx
+             if texts is not None:
+                 if isinstance(texts, str):
+                     texts = [texts]
+                 crop_ims, filter_idx = [], []
+                 for i, b in enumerate(result.boxes.xyxy.tolist()):
+                     x1, y1, x2, y2 = (int(x) for x in b)
+                     if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
+                         filter_idx.append(i)
+                         continue
+                     crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
+                 similarity = self._clip_inference(crop_ims, texts)
+                 text_idx = torch.argmax(similarity, dim=-1)  # (M, )
+                 if len(filter_idx):
+                     text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
+                 idx[text_idx] = True
+
+             prompt_results.append(result[idx])
+
+         return prompt_results
+
+     def _clip_inference(self, images, texts):
+         """Perform CLIP inference to calculate similarity between images and text prompts.
+
+         Args:
+             images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
+             texts (list[str]): List of prompt texts, each should be a string object.
+
+         Returns:
+             (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
+         """
+         from ultralytics.nn.text_model import CLIP
+
+         if not hasattr(self, "clip"):
+             self.clip = CLIP("ViT-B/32", device=self.device)
+         images = torch.stack([self.clip.image_preprocess(image).to(self.device) for image in images])
+         image_features = self.clip.encode_image(images)
+         text_features = self.clip.encode_text(self.clip.tokenize(texts))
+         return text_features @ image_features.T  # (M, N)
+
+     def set_prompts(self, prompts):
+         """Set prompts to be used during inference."""
+         self.prompts = prompts
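The prompt-handling pipeline above is normally driven through the FastSAM model wrapper rather than by instantiating FastSAMPredictor directly. A minimal usage sketch, assuming the FastSAM-s.pt checkpoint and that the wrapper forwards the bboxes/points/labels/texts keyword arguments to set_prompts():

# Hedged sketch: prompt kwargs are assumed to reach FastSAMPredictor.set_prompts()
# via the FastSAM model wrapper; weights name and source image are illustrative.
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")
results = model(
    "ultralytics/assets/bus.jpg",
    bboxes=[100, 200, 300, 400],  # xyxy box prompt
    texts="a photo of a bus",     # CLIP text prompt, loads ViT-B/32 on first use
    conf=0.4,
)
print(len(results[0]))  # number of instances kept after prompt filtering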
ultralytics/models/fastsam/utils.py
@@ -0,0 +1,23 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+
+ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
+     """Adjust bounding boxes to stick to image border if they are within a certain threshold.
+
+     Args:
+         boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
+         image_shape (tuple): Image dimensions as (height, width).
+         threshold (int): Pixel threshold for considering a box close to the border.
+
+     Returns:
+         (torch.Tensor): Adjusted bounding boxes with shape (N, 4).
+     """
+     # Image dimensions
+     h, w = image_shape
+
+     # Adjust boxes that are close to image borders
+     boxes[boxes[:, 0] < threshold, 0] = 0  # x1
+     boxes[boxes[:, 1] < threshold, 1] = 0  # y1
+     boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
+     boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
+     return boxes
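A quick illustration of the snapping behavior with hypothetical values, for a 480x640 image and the default 20-pixel threshold; the adjustment happens in place on the input tensor:

# Toy example: coordinates within 20 px of a border snap to that border.
import torch

from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

boxes = torch.tensor(
    [[5.0, 15.0, 630.0, 470.0],      # near all four borders -> snaps to [0, 0, 640, 480]
     [100.0, 100.0, 200.0, 200.0]]   # interior box -> unchanged
)
adjusted = adjust_bboxes_to_image_border(boxes, image_shape=(480, 640), threshold=20)
print(adjusted)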
ultralytics/models/fastsam/val.py
@@ -0,0 +1,38 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.models.yolo.segment import SegmentationValidator
+
+
+ class FastSAMValidator(SegmentationValidator):
+     """Custom validation class for FastSAM (Segment Anything Model) segmentation in the Ultralytics YOLO framework.
+
+     Extends the SegmentationValidator class, customizing the validation process specifically for FastSAM. This class
+     sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
+     to avoid errors during validation.
+
+     Attributes:
+         dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
+         save_dir (Path): The directory where validation results will be saved.
+         args (SimpleNamespace): Additional arguments for customization of the validation process.
+         _callbacks (list): List of callback functions to be invoked during validation.
+         metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.
+
+     Methods:
+         __init__: Initialize the FastSAMValidator with custom settings for FastSAM.
+     """
+
+     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
+         """Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
+
+         Args:
+             dataloader (torch.utils.data.DataLoader, optional): DataLoader to be used for validation.
+             save_dir (Path, optional): Directory to save results.
+             args (SimpleNamespace, optional): Configuration for the validator.
+             _callbacks (list, optional): List of callback functions to be invoked during validation.
+
+         Notes:
+             Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
+         """
+         super().__init__(dataloader, save_dir, args, _callbacks)
+         self.args.task = "segment"
+         self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
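Because plotting is switched off, FastSAM validation runs headless. A hedged sketch of invoking it through the FastSAM model class; the dataset file and metric attribute are assumptions based on the standard Ultralytics segmentation workflow:

# Sketch: validation routed to FastSAMValidator via the FastSAM wrapper.
# coco8-seg.yaml is the small segmentation sample dataset shipped with this package.
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")
metrics = model.val(data="coco8-seg.yaml", imgsz=640)
print(metrics.seg.map)  # mask mAP50-95, assuming the SegmentMetrics interface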
ultralytics/models/nas/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from .model import NAS
+ from .predict import NASPredictor
+ from .val import NASValidator
+
+ __all__ = "NAS", "NASPredictor", "NASValidator"
ultralytics/models/nas/model.py
@@ -0,0 +1,98 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+
+ from ultralytics.engine.model import Model
+ from ultralytics.utils import DEFAULT_CFG_DICT
+ from ultralytics.utils.downloads import attempt_download_asset
+ from ultralytics.utils.patches import torch_load
+ from ultralytics.utils.torch_utils import model_info
+
+ from .predict import NASPredictor
+ from .val import NASValidator
+
+
+ class NAS(Model):
+     """YOLO-NAS model for object detection.
+
+     This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. It
+     is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+
+     Attributes:
+         model (torch.nn.Module): The loaded YOLO-NAS model.
+         task (str): The task type for the model, defaults to 'detect'.
+         predictor (NASPredictor): The predictor instance for making predictions.
+         validator (NASValidator): The validator instance for model validation.
+
+     Methods:
+         info: Log model information and return model details.
+
+     Examples:
+         >>> from ultralytics import NAS
+         >>> model = NAS("yolo_nas_s")
+         >>> results = model.predict("ultralytics/assets/bus.jpg")
+
+     Notes:
+         YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
+     """
+
+     def __init__(self, model: str = "yolo_nas_s.pt") -> None:
+         """Initialize the NAS model with the provided or default model."""
+         assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
+         super().__init__(model, task="detect")
+
+     def _load(self, weights: str, task=None) -> None:
+         """Load an existing NAS model weights or create a new NAS model with pretrained weights.
+
+         Args:
+             weights (str): Path to the model weights file or model name.
+             task (str, optional): Task type for the model.
+         """
+         import super_gradients
+
+         suffix = Path(weights).suffix
+         if suffix == ".pt":
+             self.model = torch_load(attempt_download_asset(weights))
+         elif suffix == "":
+             self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
+
+         # Override the forward method to ignore additional arguments
+         def new_forward(x, *args, **kwargs):
+             """Ignore additional __call__ arguments."""
+             return self.model._original_forward(x)
+
+         self.model._original_forward = self.model.forward
+         self.model.forward = new_forward
+
+         # Standardize model attributes for compatibility
+         self.model.fuse = lambda verbose=True: self.model
+         self.model.stride = torch.tensor([32])
+         self.model.names = dict(enumerate(self.model._class_names))
+         self.model.is_fused = lambda: False  # for info()
+         self.model.yaml = {}  # for info()
+         self.model.pt_path = weights  # for export()
+         self.model.task = "detect"  # for export()
+         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
+         self.model.eval()
+
+     def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
+         """Log model information.
+
+         Args:
+             detailed (bool): Show detailed information about model.
+             verbose (bool): Controls verbosity.
+
+         Returns:
+             (dict[str, Any]): Model information dictionary.
+         """
+         return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
+
+     @property
+     def task_map(self) -> dict[str, dict[str, Any]]:
+         """Return a dictionary mapping tasks to respective predictor and validator classes."""
+         return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
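Note that _load imports super_gradients lazily, so YOLO-NAS needs that package installed alongside this one (the PyPI name super-gradients is assumed from the import). A brief usage sketch:

# Sketch: YOLO-NAS detection; assumes `pip install super-gradients` in addition
# to this package. Weights are fetched via attempt_download_asset() if missing.
from ultralytics import NAS

model = NAS("yolo_nas_s.pt")
model.info(detailed=False)  # model_info() summary at imgsz=640
results = model.predict("ultralytics/assets/bus.jpg", conf=0.35)
results[0].show()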
ultralytics/models/nas/predict.py
@@ -0,0 +1,56 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import torch
+
+ from ultralytics.models.yolo.detect.predict import DetectionPredictor
+ from ultralytics.utils import ops
+
+
+ class NASPredictor(DetectionPredictor):
+     """Ultralytics YOLO NAS Predictor for object detection.
+
+     This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the raw
+     predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and scaling the
+     bounding boxes to fit the original image dimensions.
+
+     Attributes:
+         args (Namespace): Namespace containing various configurations for post-processing including confidence
+             threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
+         model (torch.nn.Module): The YOLO NAS model used for inference.
+         batch (list): Batch of inputs for processing.
+
+     Examples:
+         >>> from ultralytics import NAS
+         >>> model = NAS("yolo_nas_s")
+         >>> predictor = model.predictor
+
+         Assume that raw_preds, img, orig_imgs are available
+         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
+
+     Notes:
+         Typically, this class is not instantiated directly. It is used internally within the NAS class.
+     """
+
+     def postprocess(self, preds_in, img, orig_imgs):
+         """Postprocess NAS model predictions to generate final detection results.
+
+         This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
+         post-processing operations to generate the final detection results compatible with Ultralytics result
+         visualization and analysis tools.
+
+         Args:
+             preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
+             img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).
+             orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling
+                 coordinates back to original dimensions.
+
+         Returns:
+             (list): List of Results objects containing the processed predictions for each image in the batch.
+
+         Examples:
+             >>> predictor = NAS("yolo_nas_s").predictor
+             >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
+         """
+         boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding boxes from xyxy to xywh format
+         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with class scores
+         return super().postprocess(preds, img, orig_imgs)
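The two executable lines of postprocess simply re-pack the NAS head output into the (B, 4 + nc, N) layout that the parent DetectionPredictor expects before NMS. A shape-only sketch with dummy tensors (dimensions are illustrative):

# Shape-only sketch; a real preds_in comes from the YOLO-NAS model itself.
import torch

from ultralytics.utils import ops

batch, anchors, nc = 1, 1000, 80
boxes_xyxy = torch.rand(batch, anchors, 4) * 640  # (B, N, 4) xyxy boxes
class_scores = torch.rand(batch, anchors, nc)     # (B, N, nc) class scores
preds_in = [(boxes_xyxy, class_scores)]

boxes = ops.xyxy2xywh(preds_in[0][0])              # (B, N, 4) xywh
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
print(preds.shape)  # torch.Size([1, 84, 1000]) == (B, 4 + nc, N)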
ultralytics/models/nas/val.py
@@ -0,0 +1,38 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import torch
+
+ from ultralytics.models.yolo.detect import DetectionValidator
+ from ultralytics.utils import ops
+
+ __all__ = ["NASValidator"]
+
+
+ class NASValidator(DetectionValidator):
+     """Ultralytics YOLO NAS Validator for object detection.
+
+     Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
+     generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
+     ultimately producing the final detections.
+
+     Attributes:
+         args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
+             thresholds.
+         lb (torch.Tensor): Optional tensor for multilabel NMS.
+
+     Examples:
+         >>> from ultralytics import NAS
+         >>> model = NAS("yolo_nas_s")
+         >>> validator = model.validator
+         >>> # Assumes that raw_preds are available
+         >>> final_preds = validator.postprocess(raw_preds)
+
+     Notes:
+         This class is generally not instantiated directly but is used internally within the NAS class.
+     """
+
+     def postprocess(self, preds_in):
+         """Apply Non-maximum suppression to prediction outputs."""
+         boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding box format from xyxy to xywh
+         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with scores and permute
+         return super().postprocess(preds)
ultralytics/models/rtdetr/__init__.py
@@ -0,0 +1,7 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from .model import RTDETR
+ from .predict import RTDETRPredictor
+ from .val import RTDETRValidator
+
+ __all__ = "RTDETR", "RTDETRPredictor", "RTDETRValidator"
ultralytics/models/rtdetr/model.py
@@ -0,0 +1,63 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+ """
+ Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.
+
+ RT-DETR offers real-time performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT.
+ It features an efficient hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
+
+ References:
+     https://arxiv.org/pdf/2304.08069.pdf
+ """
+
+ from ultralytics.engine.model import Model
+ from ultralytics.nn.tasks import RTDETRDetectionModel
+ from ultralytics.utils.torch_utils import TORCH_1_11
+
+ from .predict import RTDETRPredictor
+ from .train import RTDETRTrainer
+ from .val import RTDETRValidator
+
+
+ class RTDETR(Model):
+     """Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
+
+     This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
+     selection, and adaptable inference speed.
+
+     Attributes:
+         model (str): Path to the pre-trained model.
+
+     Methods:
+         task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
+     Examples:
+         Initialize RT-DETR with a pre-trained model
+         >>> from ultralytics import RTDETR
+         >>> model = RTDETR("rtdetr-l.pt")
+         >>> results = model("image.jpg")
+     """
+
+     def __init__(self, model: str = "rtdetr-l.pt") -> None:
+         """Initialize the RT-DETR model with the given pre-trained model file.
+
+         Args:
+             model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
+         """
+         assert TORCH_1_11, "RTDETR requires torch>=1.11"
+         super().__init__(model=model, task="detect")
+
+     @property
+     def task_map(self) -> dict:
+         """Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
+         Returns:
+             (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
+         """
+         return {
+             "detect": {
+                 "predictor": RTDETRPredictor,
+                 "validator": RTDETRValidator,
+                 "trainer": RTDETRTrainer,
+                 "model": RTDETRDetectionModel,
+             }
+         }
ultralytics/models/rtdetr/predict.py
@@ -0,0 +1,88 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import torch
+
+ from ultralytics.data.augment import LetterBox
+ from ultralytics.engine.predictor import BasePredictor
+ from ultralytics.engine.results import Results
+ from ultralytics.utils import ops
+
+
+ class RTDETRPredictor(BasePredictor):
+     """RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
+
+     This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy. It
+     supports key features like efficient hybrid encoding and IoU-aware query selection.
+
+     Attributes:
+         imgsz (int): Image size for inference (must be square and scale-filled).
+         args (dict): Argument overrides for the predictor.
+         model (torch.nn.Module): The loaded RT-DETR model.
+         batch (list): Current batch of processed inputs.
+
+     Methods:
+         postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.
+         pre_transform: Pre-transform input images before feeding them into the model for inference.
+
+     Examples:
+         >>> from ultralytics.utils import ASSETS
+         >>> from ultralytics.models.rtdetr import RTDETRPredictor
+         >>> args = dict(model="rtdetr-l.pt", source=ASSETS)
+         >>> predictor = RTDETRPredictor(overrides=args)
+         >>> predictor.predict_cli()
+     """
+
+     def postprocess(self, preds, img, orig_imgs):
+         """Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+
+         The method filters detections based on confidence and class if specified in `self.args`. It converts model
+         predictions to Results objects containing properly scaled bounding boxes.
+
+         Args:
+             preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
+                 and scores.
+             img (torch.Tensor): Processed input images with shape (N, 3, H, W).
+             orig_imgs (list | torch.Tensor): Original, unprocessed images.
+
+         Returns:
+             results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
+                 scores, and class labels.
+         """
+         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
+             preds = [preds, None]
+
+         nd = preds[0].shape[-1]
+         bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
+
+         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)[..., ::-1]
+
+         results = []
+         for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
+             bbox = ops.xywh2xyxy(bbox)
+             max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
+             idx = max_score.squeeze(-1) > self.args.conf  # (300, )
+             if self.args.classes is not None:
+                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
+             pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
+             pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
+             oh, ow = orig_img.shape[:2]
+             pred[..., [0, 2]] *= ow  # scale x coordinates to original width
+             pred[..., [1, 3]] *= oh  # scale y coordinates to original height
+             results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+         return results
+
+     def pre_transform(self, im):
+         """Pre-transform input images before feeding them into the model for inference.
+
+         The input images are letterboxed to ensure a square aspect ratio and scale-filled.
+
+         Args:
+             im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
+                 list.
+
+         Returns:
+             (list): List of pre-transformed images ready for model inference.
+         """
+         letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)
+         return [letterbox(image=x) for x in im]
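RT-DETR decodes boxes as normalized xywh in [0, 1], which is why postprocess multiplies by the original width and height rather than undoing a letterbox offset. A minimal end-to-end sketch using the checkpoint named in the docstring example; the source image is illustrative:

# Minimal inference sketch; rtdetr-l.pt is the pre-trained checkpoint referenced above.
from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")
results = model("ultralytics/assets/bus.jpg", conf=0.5)  # runs RTDETRPredictor internally
for box in results[0].boxes:
    print(int(box.cls.item()), float(box.conf.item()), box.xyxy.tolist())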