dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/fastsam/predict.py

@@ -4,16 +4,16 @@ import torch
 from PIL import Image
 
 from ultralytics.models.yolo.segment import SegmentationPredictor
-from ultralytics.utils import DEFAULT_CFG, checks
+from ultralytics.utils import DEFAULT_CFG
 from ultralytics.utils.metrics import box_iou
 from ultralytics.utils.ops import scale_masks
+from ultralytics.utils.torch_utils import TORCH_1_10
 
 from .utils import adjust_bboxes_to_image_border
 
 
 class FastSAMPredictor(SegmentationPredictor):
-    """
-    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
+    """FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
 
     This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
     adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
@@ -26,22 +26,20 @@ class FastSAMPredictor(SegmentationPredictor):
         clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
 
     Methods:
-        postprocess: Applies box postprocessing for FastSAM predictions.
-        prompt: Performs image segmentation inference based on various prompt types.
-        _clip_inference: Performs CLIP inference to calculate similarity between images and text prompts.
-        set_prompts: Sets prompts to be used during inference.
+        postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
+        prompt: Perform image segmentation inference based on various prompt types.
+        set_prompts: Set prompts to be used during inference.
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize the FastSAMPredictor with configuration and callbacks.
+        """Initialize the FastSAMPredictor with configuration and callbacks.
 
         This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
         extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
         optimized for single-class segmentation.
 
         Args:
-            cfg (dict): Configuration for the predictor. Defaults to Ultralytics DEFAULT_CFG.
+            cfg (dict): Configuration for the predictor.
             overrides (dict, optional): Configuration overrides.
             _callbacks (list, optional): List of callback functions.
         """
@@ -49,16 +47,15 @@ class FastSAMPredictor(SegmentationPredictor):
         self.prompts = {}
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Apply postprocessing to FastSAM predictions and handle prompts.
+        """Apply postprocessing to FastSAM predictions and handle prompts.
 
         Args:
-            preds (List[torch.Tensor]): Raw predictions from the model.
+            preds (list[torch.Tensor]): Raw predictions from the model.
             img (torch.Tensor): Input image tensor that was fed to the model.
-            orig_imgs (List[numpy.ndarray]): Original images before preprocessing.
+            orig_imgs (list[np.ndarray]): Original images before preprocessing.
 
         Returns:
-            (List[Results]): Processed results with prompts applied.
+            (list[Results]): Processed results with prompts applied.
         """
         bboxes = self.prompts.pop("bboxes", None)
         points = self.prompts.pop("points", None)
@@ -77,18 +74,17 @@ class FastSAMPredictor(SegmentationPredictor):
         return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
 
     def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
-        """
-        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
+        """Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
 
         Args:
-            results (Results | List[Results]): Original inference results from FastSAM models without any prompts.
-            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
-            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
-            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
-            texts (str | List[str], optional): Textual prompts, a list containing string objects.
+            results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
+            bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            texts (str | list[str], optional): Textual prompts, a list containing string objects.
 
         Returns:
-            (List[Results]): Output results filtered and determined by the provided prompts.
+            (list[Results]): Output results filtered and determined by the provided prompts.
         """
         if bboxes is None and points is None and texts is None:
             return results
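For context, the prompt types documented above map onto the public FastSAM entry point. A minimal usage sketch, assuming a local `FastSAM-s.pt` checkpoint and the bundled sample image:

```python
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")  # assumes the checkpoint is available locally
# Segment everything, then filter the results with a box prompt and a text prompt
results = model("ultralytics/assets/bus.jpg", bboxes=[100, 100, 400, 500], texts="a bus")
```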
@@ -101,7 +97,7 @@
                 continue
             masks = result.masks.data
             if masks.shape[1:] != result.orig_shape:
-                masks = scale_masks(masks[None], result.orig_shape)[0]
+                masks = (scale_masks(masks[None].float(), result.orig_shape)[0] > 0.5).byte()
             # bboxes prompt
             idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
             if bboxes is not None:
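The new right-hand side casts to float before resizing and re-binarizes afterwards, since `scale_masks` interpolates and interpolation is not meaningful on integer or boolean masks. The same pattern in isolation, with illustrative shapes:

```python
import torch
from ultralytics.utils.ops import scale_masks

masks = torch.rand(3, 160, 160) > 0.5  # boolean masks at model resolution
# Upsample in float space, then threshold back to a binary uint8 mask
masks = (scale_masks(masks[None].float(), (480, 640))[0] > 0.5).byte()
print(masks.shape, masks.dtype)  # torch.Size([3, 480, 640]) torch.uint8
```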
@@ -120,7 +116,7 @@
                     labels = torch.ones(points.shape[0])
                 labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
                 assert len(labels) == len(points), (
-                    f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
+                    f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
                 )
                 point_idx = (
                     torch.ones(len(result), dtype=torch.bool, device=self.device)
@@ -136,7 +132,7 @@
                 crop_ims, filter_idx = [], []
                 for i, b in enumerate(result.boxes.xyxy.tolist()):
                     x1, y1, x2, y2 = (int(x) for x in b)
-                    if masks[i].sum() <= 100:
+                    if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
                         filter_idx.append(i)
                         continue
                     crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
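The guard keeps the plain `.sum()` on torch>=1.10 and falls back to a per-axis reduction on older builds, which the inline comment attributes to a torch 1.9 bug with full reductions on these masks. Both paths count the same pixels:

```python
import torch

mask = (torch.rand(160, 160) > 0.5).byte()  # binary mask, as produced above
assert int(mask.sum()) == int(mask.sum(0).sum())  # per-axis reduction, torch<1.10-safe
```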
@@ -151,30 +147,23 @@
         return prompt_results
 
     def _clip_inference(self, images, texts):
-        """
-        Perform CLIP inference to calculate similarity between images and text prompts.
+        """Perform CLIP inference to calculate similarity between images and text prompts.
 
         Args:
-            images (List[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
-            texts (List[str]): List of prompt texts, each should be a string object.
+            images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
+            texts (list[str]): List of prompt texts, each should be a string object.
 
         Returns:
             (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
         """
-        try:
-            import clip
-        except ImportError:
-            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
-            import clip
-        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
-            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
-        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
-        tokenized_text = clip.tokenize(texts).to(self.device)
-        image_features = self.clip_model.encode_image(images)
-        text_features = self.clip_model.encode_text(tokenized_text)
-        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
-        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
-        return (image_features * text_features[:, None]).sum(-1)  # (M, N)
+        from ultralytics.nn.text_model import CLIP
+
+        if not hasattr(self, "clip"):
+            self.clip = CLIP("ViT-B/32", device=self.device)
+        images = torch.stack([self.clip.image_preprocess(image).to(self.device) for image in images])
+        image_features = self.clip.encode_image(images)
+        text_features = self.clip.encode_text(self.clip.tokenize(texts))
+        return text_features @ image_features.T  # (M, N)
 
     def set_prompts(self, prompts):
         """Set prompts to be used during inference."""
ultralytics/models/fastsam/utils.py

@@ -2,16 +2,15 @@
 
 
 def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
-    """
-    Adjust bounding boxes to stick to image border if they are within a certain threshold.
+    """Adjust bounding boxes to stick to image border if they are within a certain threshold.
 
     Args:
-        boxes (torch.Tensor): Bounding boxes with shape (n, 4) in xyxy format.
-        image_shape (Tuple[int, int]): Image dimensions as (height, width).
+        boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
+        image_shape (tuple): Image dimensions as (height, width).
         threshold (int): Pixel threshold for considering a box close to the border.
 
     Returns:
-        boxes (torch.Tensor): Adjusted bounding boxes with shape (n, 4).
+        (torch.Tensor): Adjusted bounding boxes with shape (N, 4).
     """
     # Image dimensions
     h, w = image_shape
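A hedged sketch of the snapping rule the docstring describes (the function body itself is unchanged in this diff): any box edge lying within `threshold` pixels of the image border is clamped onto the border:

```python
import torch

def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    """Snap near-border box edges onto the border (sketch of the documented behavior)."""
    h, w = image_shape
    boxes[boxes[:, 0] < threshold, 0] = 0      # left edge
    boxes[boxes[:, 1] < threshold, 1] = 0      # top edge
    boxes[boxes[:, 2] > w - threshold, 2] = w  # right edge
    boxes[boxes[:, 3] > h - threshold, 3] = h  # bottom edge
    return boxes

print(adjust_bboxes_to_image_border(torch.tensor([[5.0, 30.0, 630.0, 400.0]]), (480, 640)))
# tensor([[  0.,  30., 640., 400.]])
```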
ultralytics/models/fastsam/val.py

@@ -1,40 +1,38 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
 from ultralytics.models.yolo.segment import SegmentationValidator
-from ultralytics.utils.metrics import SegmentMetrics
 
 
 class FastSAMValidator(SegmentationValidator):
-    """
-    Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+    """Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
 
-    Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
+    Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
     sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
     to avoid errors during validation.
 
     Attributes:
         dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
         save_dir (Path): The directory where validation results will be saved.
-        pbar (tqdm.tqdm): A progress bar object for displaying validation progress.
         args (SimpleNamespace): Additional arguments for customization of the validation process.
         _callbacks (list): List of callback functions to be invoked during validation.
+        metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.
+
+    Methods:
+        __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
-        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
+        """Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
 
         Args:
-            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (Path, optional): Directory to save results.
-            pbar (tqdm.tqdm): Progress bar for displaying progress.
-            args (SimpleNamespace): Configuration for the validator.
-            _callbacks (list): List of callback functions to be invoked during validation.
+            args (SimpleNamespace, optional): Configuration for the validator.
+            _callbacks (list, optional): List of callback functions to be invoked during validation.
 
         Notes:
            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
        """
-        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        super().__init__(dataloader, save_dir, args, _callbacks)
         self.args.task = "segment"
         self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
-        self.metrics = SegmentMetrics(save_dir=self.save_dir)
ultralytics/models/nas/__init__.py

@@ -4,4 +4,4 @@ from .model import NAS
 from .predict import NASPredictor
 from .val import NASValidator
 
-__all__ = "NASPredictor", "NASValidator", "NAS"
+__all__ = "NAS", "NASPredictor", "NASValidator"
ultralytics/models/nas/model.py

@@ -1,20 +1,16 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""
-YOLO-NAS model interface.
 
-Examples:
-    >>> from ultralytics import NAS
-    >>> model = NAS("yolo_nas_s")
-    >>> results = model.predict("ultralytics/assets/bus.jpg")
-"""
+from __future__ import annotations
 
 from pathlib import Path
+from typing import Any
 
 import torch
 
 from ultralytics.engine.model import Model
 from ultralytics.utils import DEFAULT_CFG_DICT
 from ultralytics.utils.downloads import attempt_download_asset
+from ultralytics.utils.patches import torch_load
 from ultralytics.utils.torch_utils import model_info
 
 from .predict import NASPredictor
@@ -22,11 +18,10 @@ from .val import NASValidator
 
 
 class NAS(Model):
-    """
-    YOLO NAS model for object detection.
+    """YOLO-NAS model for object detection.
 
-    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
-    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+    This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. It
+    is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
 
     Attributes:
         model (torch.nn.Module): The loaded YOLO-NAS model.
@@ -34,6 +29,9 @@ class NAS(Model):
         predictor (NASPredictor): The predictor instance for making predictions.
         validator (NASValidator): The validator instance for model validation.
 
+    Methods:
+        info: Log model information and return model details.
+
     Examples:
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
@@ -49,8 +47,7 @@
         super().__init__(model, task="detect")
 
     def _load(self, weights: str, task=None) -> None:
-        """
-        Load an existing NAS model weights or create a new NAS model with pretrained weights.
+        """Load an existing NAS model weights or create a new NAS model with pretrained weights.
 
         Args:
             weights (str): Path to the model weights file or model name.
@@ -60,7 +57,7 @@
 
         suffix = Path(weights).suffix
         if suffix == ".pt":
-            self.model = torch.load(attempt_download_asset(weights))
+            self.model = torch_load(attempt_download_asset(weights))
         elif suffix == "":
             self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
 
@@ -72,7 +69,7 @@
         self.model._original_forward = self.model.forward
         self.model.forward = new_forward
 
-        # Standardize model
+        # Standardize model attributes for compatibility
         self.model.fuse = lambda verbose=True: self.model
         self.model.stride = torch.tensor([32])
         self.model.names = dict(enumerate(self.model._class_names))
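The `_load` hunk above swaps the bare `torch.load` for `torch_load` from `ultralytics.utils.patches`. A minimal sketch of what such a wrapper typically does, so the change reads in context; the packaged implementation may differ:

```python
import torch

def torch_load(*args, **kwargs):
    """Sketch of a torch.load wrapper; the real ultralytics.utils.patches helper may differ."""
    if "weights_only" not in kwargs:
        # PyTorch >= 2.6 defaults weights_only=True, which rejects fully pickled
        # model objects such as NAS checkpoints; restore the legacy behavior.
        kwargs["weights_only"] = False
    return torch.load(*args, **kwargs)
```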
@@ -83,20 +80,19 @@ class NAS(Model):
         self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
         self.model.eval()
 
-    def info(self, detailed: bool = False, verbose: bool = True):
-        """
-        Log model information.
+    def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
+        """Log model information.
 
         Args:
             detailed (bool): Show detailed information about model.
             verbose (bool): Controls verbosity.
 
         Returns:
-            (dict): Model information dictionary.
+            (dict[str, Any]): Model information dictionary.
         """
         return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
 
     @property
-    def task_map(self):
+    def task_map(self) -> dict[str, dict[str, Any]]:
         """Return a dictionary mapping tasks to respective predictor and validator classes."""
         return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
ultralytics/models/nas/predict.py

@@ -7,16 +7,15 @@ from ultralytics.utils import ops
 
 
 class NASPredictor(DetectionPredictor):
-    """
-    Ultralytics YOLO NAS Predictor for object detection.
+    """Ultralytics YOLO NAS Predictor for object detection.
 
-    This class extends the `DetectionPredictor` from Ultralytics engine and is responsible for post-processing the
-    raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
-    scaling the bounding boxes to fit the original image dimensions.
+    This class extends the DetectionPredictor from Ultralytics engine and is responsible for post-processing the raw
+    predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and scaling the
+    bounding boxes to fit the original image dimensions.
 
     Attributes:
-        args (Namespace): Namespace containing various configurations for post-processing including confidence threshold,
-            IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
+        args (Namespace): Namespace containing various configurations for post-processing including confidence
+            threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
         model (torch.nn.Module): The YOLO NAS model used for inference.
         batch (list): Batch of inputs for processing.
 
@@ -29,16 +28,15 @@ class NASPredictor(DetectionPredictor):
         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
 
     Notes:
-        Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
+        Typically, this class is not instantiated directly. It is used internally within the NAS class.
     """
 
     def postprocess(self, preds_in, img, orig_imgs):
-        """
-        Postprocess NAS model predictions to generate final detection results.
+        """Postprocess NAS model predictions to generate final detection results.
 
         This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
-        post-processing operations to generate the final detection results compatible with Ultralytics
-        result visualization and analysis tools.
+        post-processing operations to generate the final detection results compatible with Ultralytics result
+        visualization and analysis tools.
 
         Args:
             preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
@@ -53,6 +51,6 @@ class NASPredictor(DetectionPredictor):
         >>> predictor = NAS("yolo_nas_s").predictor
         >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
         """
-        boxes = ops.xyxy2xywh(preds_in[0][0])
-        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # concatenate with class scores
+        boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding boxes from xyxy to xywh format
+        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with class scores
         return super().postprocess(preds, img, orig_imgs)
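The two commented lines adapt NAS's native output pair `(boxes_xyxy, class_scores)` to the `(batch, 4 + num_classes, num_anchors)` xywh layout that the shared `DetectionPredictor.postprocess` NMS expects. The reshaping in isolation, with dummy tensors:

```python
import torch
from ultralytics.utils import ops

boxes_xyxy = torch.rand(1, 8400, 4) * 640  # dummy NAS box output, (B, N, 4) in xyxy
scores = torch.rand(1, 8400, 80)           # dummy class scores, (B, N, 80)
boxes = ops.xyxy2xywh(boxes_xyxy)                        # corner -> center format
preds = torch.cat((boxes, scores), -1).permute(0, 2, 1)  # (B, 84, N) for shared NMS
print(preds.shape)  # torch.Size([1, 84, 8400])
```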
ultralytics/models/nas/val.py

@@ -9,10 +9,9 @@ __all__ = ["NASValidator"]
 
 
 class NASValidator(DetectionValidator):
-    """
-    Ultralytics YOLO NAS Validator for object detection.
+    """Ultralytics YOLO NAS Validator for object detection.
 
-    Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
+    Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
     generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
     ultimately producing the final detections.
 
@@ -25,11 +24,11 @@ class NASValidator(DetectionValidator):
         >>> from ultralytics import NAS
         >>> model = NAS("yolo_nas_s")
         >>> validator = model.validator
-        Assumes that raw_preds are available
+        >>> # Assumes that raw_preds are available
         >>> final_preds = validator.postprocess(raw_preds)
 
     Notes:
-        This class is generally not instantiated directly but is used internally within the `NAS` class.
+        This class is generally not instantiated directly but is used internally within the NAS class.
     """
 
     def postprocess(self, preds_in):
ultralytics/models/rtdetr/__init__.py

@@ -4,4 +4,4 @@ from .model import RTDETR
 from .predict import RTDETRPredictor
 from .val import RTDETRValidator
 
-__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"
+__all__ = "RTDETR", "RTDETRPredictor", "RTDETRValidator"
ultralytics/models/rtdetr/model.py

@@ -11,6 +11,7 @@ References:
 
 from ultralytics.engine.model import Model
 from ultralytics.nn.tasks import RTDETRDetectionModel
+from ultralytics.utils.torch_utils import TORCH_1_11
 
 from .predict import RTDETRPredictor
 from .train import RTDETRTrainer
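`TORCH_1_11` is a module-level feature gate exposed by `ultralytics.utils.torch_utils` and consumed by the new assertion in `__init__` below. A hedged sketch of how such a gate is commonly defined; the package's own helper may use its own version checker:

```python
import torch
from packaging import version

# True when the installed torch is at least 1.11 (local build tags stripped)
TORCH_1_11 = version.parse(torch.__version__.split("+")[0]) >= version.parse("1.11.0")
assert TORCH_1_11, "RTDETR requires torch>=1.11"
```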
@@ -18,8 +19,7 @@ from .val import RTDETRValidator
 
 
 class RTDETR(Model):
-    """
-    Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
+    """Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
 
     This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware query
     selection, and adaptable inference speed.
@@ -27,28 +27,28 @@ class RTDETR(Model):
     Attributes:
         model (str): Path to the pre-trained model.
 
+    Methods:
+        task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+
     Examples:
+        Initialize RT-DETR with a pre-trained model
         >>> from ultralytics import RTDETR
         >>> model = RTDETR("rtdetr-l.pt")
        >>> results = model("image.jpg")
    """

    def __init__(self, model: str = "rtdetr-l.pt") -> None:
-        """
-        Initialize the RT-DETR model with the given pre-trained model file.
+        """Initialize the RT-DETR model with the given pre-trained model file.
 
         Args:
             model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
-
-        Raises:
-            NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
         """
+        assert TORCH_1_11, "RTDETR requires torch>=1.11"
         super().__init__(model=model, task="detect")
 
     @property
     def task_map(self) -> dict:
-        """
-        Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
+        """Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
 
         Returns:
             (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
ultralytics/models/rtdetr/predict.py

@@ -9,11 +9,10 @@ from ultralytics.utils import ops
 
 
 class RTDETRPredictor(BasePredictor):
-    """
-    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
+    """RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
 
-    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
-    It supports key features like efficient hybrid encoding and IoU-aware query selection.
+    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy. It
+    supports key features like efficient hybrid encoding and IoU-aware query selection.
 
     Attributes:
         imgsz (int): Image size for inference (must be square and scale-filled).
@@ -21,6 +20,10 @@ class RTDETRPredictor(BasePredictor):
         model (torch.nn.Module): The loaded RT-DETR model.
         batch (list): Current batch of processed inputs.
 
+    Methods:
+        postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.
+        pre_transform: Pre-transform input images before feeding them into the model for inference.
+
     Examples:
         >>> from ultralytics.utils import ASSETS
         >>> from ultralytics.models.rtdetr import RTDETRPredictor
@@ -30,21 +33,20 @@ class RTDETRPredictor(BasePredictor):
     """
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
+        """Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
 
-        The method filters detections based on confidence and class if specified in `self.args`. It converts
-        model predictions to Results objects containing properly scaled bounding boxes.
+        The method filters detections based on confidence and class if specified in `self.args`. It converts model
+        predictions to Results objects containing properly scaled bounding boxes.
 
         Args:
-            preds (List | Tuple): List of [predictions, extra] from the model, where predictions contain
-                bounding boxes and scores.
+            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain bounding boxes
+                and scores.
             img (torch.Tensor): Processed input images with shape (N, 3, H, W).
-            orig_imgs (List | torch.Tensor): Original, unprocessed images.
+            orig_imgs (list | torch.Tensor): Original, unprocessed images.
 
         Returns:
-            (List[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
-                and class labels.
+            results (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence
+                scores, and class labels.
         """
         if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
             preds = [preds, None]
@@ -63,6 +65,7 @@ class RTDETRPredictor(BasePredictor):
             if self.args.classes is not None:
                 idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
             pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
+            pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
             oh, ow = orig_img.shape[:2]
             pred[..., [0, 2]] *= ow  # scale x coordinates to original width
             pred[..., [1, 3]] *= oh  # scale y coordinates to original height
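The added line is a top-k cap: column 4 of `pred` holds the confidence after `torch.cat([bbox, max_score, cls], dim=-1)`, so sorting on it in descending order and slicing keeps only the `max_det` strongest detections. In isolation:

```python
import torch

max_det = 2  # stand-in for self.args.max_det
pred = torch.tensor(
    [[0.0, 0.0, 10.0, 10.0, 0.3, 0.0],
     [0.0, 0.0, 5.0, 5.0, 0.9, 1.0],
     [1.0, 1.0, 8.0, 8.0, 0.6, 2.0]]
)  # rows: box coordinates, confidence, class
pred = pred[pred[:, 4].argsort(descending=True)][:max_det]  # keep the 0.9 and 0.6 rows
print(pred[:, 4])  # tensor([0.9000, 0.6000])
```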
@@ -70,12 +73,14 @@ class RTDETRPredictor(BasePredictor):
         return results
 
     def pre_transform(self, im):
-        """
-        Pre-transforms the input images before feeding them into the model for inference. The input images are
-        letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scale_filled.
+        """Pre-transform input images before feeding them into the model for inference.
+
+        The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square (640)
+        and scale_filled.
 
         Args:
-            im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.
+            im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for
+                list.
 
         Returns:
             (list): List of pre-transformed images ready for model inference.
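A usage sketch of the square, scale-filled letterboxing the docstring describes, assuming the `LetterBox` transform from `ultralytics.data.augment` and its `scale_fill` flag (stretch to the target size instead of padding); the flag name has varied across releases:

```python
import numpy as np
from ultralytics.data.augment import LetterBox

letterbox = LetterBox((640, 640), auto=False, scale_fill=True)  # stretch, no padding
im = letterbox(image=np.zeros((480, 360, 3), dtype=np.uint8))
print(im.shape)  # (640, 640, 3)
```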
ultralytics/models/rtdetr/train.py

@@ -1,5 +1,7 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from copy import copy
 
 from ultralytics.models.yolo.detect import DetectionTrainer
@@ -10,34 +12,37 @@ from .val import RTDETRDataset, RTDETRValidator
 
 
 class RTDETRTrainer(DetectionTrainer):
-    """
-    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
+    """Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
 
-    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.
-    The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference
-    speed.
+    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
+    RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
+    inference speed.
 
     Attributes:
-        loss_names (Tuple[str]): Names of the loss components used for training.
+        loss_names (tuple): Names of the loss components used for training.
         data (dict): Dataset configuration containing class count and other parameters.
         args (dict): Training arguments and hyperparameters.
         save_dir (Path): Directory to save training results.
         test_loader (DataLoader): DataLoader for validation/testing data.
 
-    Notes:
-        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
-        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
+    Methods:
+        get_model: Initialize and return an RT-DETR model for object detection tasks.
+        build_dataset: Build and return an RT-DETR dataset for training or validation.
+        get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
 
     Examples:
         >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
         >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
         >>> trainer = RTDETRTrainer(overrides=args)
         >>> trainer.train()
+
+    Notes:
+        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
+        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
     """
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Initialize and return an RT-DETR model for object detection tasks.
+    def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
+        """Initialize and return an RT-DETR model for object detection tasks.
 
         Args:
             cfg (dict, optional): Model configuration.
@@ -52,9 +57,8 @@ class RTDETRTrainer(DetectionTrainer):
             model.load(weights)
         return model
 
-    def build_dataset(self, img_path, mode="val", batch=None):
-        """
-        Build and return an RT-DETR dataset for training or validation.
+    def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
+        """Build and return an RT-DETR dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -80,6 +84,6 @@ class RTDETRTrainer(DetectionTrainer):
         )
 
     def get_validator(self):
-        """Returns a DetectionValidator suitable for RT-DETR model validation."""
+        """Return a DetectionValidator suitable for RT-DETR model validation."""
         self.loss_names = "giou_loss", "cls_loss", "l1_loss"
         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))