ultralytics-8.0.238-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
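Both wheels are public, so a comparison like the one below can be reproduced locally. The short Python sketch that follows is purely illustrative and not part of this report: the wheel filenames are assumptions about where the two archives were downloaded (for example via `pip download ultralytics==8.0.238 --no-deps`), and the script simply unpacks both archives with the standard library and prints a unified diff for each changed file.

import difflib
import zipfile

# Hypothetical local paths; download the two wheels first, e.g. with
# `pip download ultralytics==8.0.238 --no-deps` and likewise for 8.0.239.
OLD_WHEEL = "ultralytics-8.0.238-py3-none-any.whl"
NEW_WHEEL = "ultralytics-8.0.239-py3-none-any.whl"


def read_members(path):
    """Return a mapping of archive member name -> decoded text for a wheel (zip) file."""
    with zipfile.ZipFile(path) as zf:
        return {name: zf.read(name).decode("utf-8", errors="replace") for name in zf.namelist()}


old, new = read_members(OLD_WHEEL), read_members(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    diff = difflib.unified_diff(
        old.get(name, "").splitlines(keepends=True),
        new.get(name, "").splitlines(keepends=True),
        fromfile=f"a/{name}",
        tofile=f"b/{name}",
    )
    text = "".join(diff)
    if text:  # print only files whose contents changed
        print(text)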

Potentially problematic release: this version of ultralytics might be problematic.

Files changed (134)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/data/__init__.py +9 -2
  4. ultralytics/data/annotator.py +4 -4
  5. ultralytics/data/augment.py +186 -169
  6. ultralytics/data/base.py +54 -48
  7. ultralytics/data/build.py +34 -23
  8. ultralytics/data/converter.py +242 -70
  9. ultralytics/data/dataset.py +117 -95
  10. ultralytics/data/explorer/__init__.py +3 -1
  11. ultralytics/data/explorer/explorer.py +120 -100
  12. ultralytics/data/explorer/gui/__init__.py +1 -0
  13. ultralytics/data/explorer/gui/dash.py +123 -89
  14. ultralytics/data/explorer/utils.py +37 -39
  15. ultralytics/data/loaders.py +75 -62
  16. ultralytics/data/split_dota.py +44 -36
  17. ultralytics/data/utils.py +160 -142
  18. ultralytics/engine/exporter.py +348 -292
  19. ultralytics/engine/model.py +102 -66
  20. ultralytics/engine/predictor.py +74 -55
  21. ultralytics/engine/results.py +61 -41
  22. ultralytics/engine/trainer.py +192 -144
  23. ultralytics/engine/tuner.py +66 -59
  24. ultralytics/engine/validator.py +31 -26
  25. ultralytics/hub/__init__.py +54 -31
  26. ultralytics/hub/auth.py +28 -25
  27. ultralytics/hub/session.py +282 -133
  28. ultralytics/hub/utils.py +64 -42
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/model.py +6 -6
  32. ultralytics/models/fastsam/predict.py +3 -2
  33. ultralytics/models/fastsam/prompt.py +55 -48
  34. ultralytics/models/fastsam/val.py +1 -1
  35. ultralytics/models/nas/__init__.py +1 -1
  36. ultralytics/models/nas/model.py +9 -8
  37. ultralytics/models/nas/predict.py +8 -6
  38. ultralytics/models/nas/val.py +11 -9
  39. ultralytics/models/rtdetr/__init__.py +1 -1
  40. ultralytics/models/rtdetr/model.py +11 -9
  41. ultralytics/models/rtdetr/train.py +18 -16
  42. ultralytics/models/rtdetr/val.py +25 -19
  43. ultralytics/models/sam/__init__.py +1 -1
  44. ultralytics/models/sam/amg.py +13 -14
  45. ultralytics/models/sam/build.py +44 -42
  46. ultralytics/models/sam/model.py +6 -6
  47. ultralytics/models/sam/modules/decoders.py +6 -4
  48. ultralytics/models/sam/modules/encoders.py +37 -35
  49. ultralytics/models/sam/modules/sam.py +5 -4
  50. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  51. ultralytics/models/sam/modules/transformer.py +3 -2
  52. ultralytics/models/sam/predict.py +39 -27
  53. ultralytics/models/utils/loss.py +99 -95
  54. ultralytics/models/utils/ops.py +34 -31
  55. ultralytics/models/yolo/__init__.py +1 -1
  56. ultralytics/models/yolo/classify/__init__.py +1 -1
  57. ultralytics/models/yolo/classify/predict.py +8 -6
  58. ultralytics/models/yolo/classify/train.py +37 -31
  59. ultralytics/models/yolo/classify/val.py +26 -24
  60. ultralytics/models/yolo/detect/__init__.py +1 -1
  61. ultralytics/models/yolo/detect/predict.py +8 -6
  62. ultralytics/models/yolo/detect/train.py +47 -37
  63. ultralytics/models/yolo/detect/val.py +100 -82
  64. ultralytics/models/yolo/model.py +31 -25
  65. ultralytics/models/yolo/obb/__init__.py +1 -1
  66. ultralytics/models/yolo/obb/predict.py +13 -11
  67. ultralytics/models/yolo/obb/train.py +3 -3
  68. ultralytics/models/yolo/obb/val.py +70 -59
  69. ultralytics/models/yolo/pose/__init__.py +1 -1
  70. ultralytics/models/yolo/pose/predict.py +17 -12
  71. ultralytics/models/yolo/pose/train.py +28 -25
  72. ultralytics/models/yolo/pose/val.py +91 -64
  73. ultralytics/models/yolo/segment/__init__.py +1 -1
  74. ultralytics/models/yolo/segment/predict.py +10 -8
  75. ultralytics/models/yolo/segment/train.py +16 -15
  76. ultralytics/models/yolo/segment/val.py +90 -68
  77. ultralytics/nn/__init__.py +26 -6
  78. ultralytics/nn/autobackend.py +144 -112
  79. ultralytics/nn/modules/__init__.py +96 -13
  80. ultralytics/nn/modules/block.py +28 -7
  81. ultralytics/nn/modules/conv.py +41 -23
  82. ultralytics/nn/modules/head.py +60 -52
  83. ultralytics/nn/modules/transformer.py +49 -32
  84. ultralytics/nn/modules/utils.py +20 -15
  85. ultralytics/nn/tasks.py +215 -141
  86. ultralytics/solutions/ai_gym.py +59 -47
  87. ultralytics/solutions/distance_calculation.py +17 -14
  88. ultralytics/solutions/heatmap.py +57 -55
  89. ultralytics/solutions/object_counter.py +46 -39
  90. ultralytics/solutions/speed_estimation.py +13 -16
  91. ultralytics/trackers/__init__.py +1 -1
  92. ultralytics/trackers/basetrack.py +1 -0
  93. ultralytics/trackers/bot_sort.py +2 -1
  94. ultralytics/trackers/byte_tracker.py +10 -7
  95. ultralytics/trackers/track.py +7 -7
  96. ultralytics/trackers/utils/gmc.py +25 -25
  97. ultralytics/trackers/utils/kalman_filter.py +85 -42
  98. ultralytics/trackers/utils/matching.py +8 -7
  99. ultralytics/utils/__init__.py +173 -152
  100. ultralytics/utils/autobatch.py +10 -10
  101. ultralytics/utils/benchmarks.py +76 -86
  102. ultralytics/utils/callbacks/__init__.py +1 -1
  103. ultralytics/utils/callbacks/base.py +29 -29
  104. ultralytics/utils/callbacks/clearml.py +51 -43
  105. ultralytics/utils/callbacks/comet.py +81 -66
  106. ultralytics/utils/callbacks/dvc.py +33 -26
  107. ultralytics/utils/callbacks/hub.py +44 -26
  108. ultralytics/utils/callbacks/mlflow.py +31 -24
  109. ultralytics/utils/callbacks/neptune.py +35 -25
  110. ultralytics/utils/callbacks/raytune.py +9 -4
  111. ultralytics/utils/callbacks/tensorboard.py +16 -11
  112. ultralytics/utils/callbacks/wb.py +39 -33
  113. ultralytics/utils/checks.py +189 -141
  114. ultralytics/utils/dist.py +15 -12
  115. ultralytics/utils/downloads.py +112 -96
  116. ultralytics/utils/errors.py +1 -1
  117. ultralytics/utils/files.py +11 -11
  118. ultralytics/utils/instance.py +22 -22
  119. ultralytics/utils/loss.py +117 -67
  120. ultralytics/utils/metrics.py +224 -158
  121. ultralytics/utils/ops.py +38 -28
  122. ultralytics/utils/patches.py +3 -3
  123. ultralytics/utils/plotting.py +217 -120
  124. ultralytics/utils/tal.py +19 -13
  125. ultralytics/utils/torch_utils.py +138 -109
  126. ultralytics/utils/triton.py +12 -10
  127. ultralytics/utils/tuner.py +49 -47
  128. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
  129. ultralytics-8.0.239.dist-info/RECORD +188 -0
  130. ultralytics-8.0.238.dist-info/RECORD +0 -188
  131. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  132. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  133. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  134. {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0

ultralytics/models/__init__.py
@@ -4,4 +4,4 @@ from .rtdetr import RTDETR
  from .sam import SAM
  from .yolo import YOLO

- __all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import
+ __all__ = "YOLO", "RTDETR", "SAM" # allow simpler import

ultralytics/models/fastsam/__init__.py
@@ -5,4 +5,4 @@ from .predict import FastSAMPredictor
  from .prompt import FastSAMPrompt
  from .val import FastSAMValidator

- __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMValidator'
+ __all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"

ultralytics/models/fastsam/model.py
@@ -21,14 +21,14 @@ class FastSAM(Model):
  ```
  """

- def __init__(self, model='FastSAM-x.pt'):
+ def __init__(self, model="FastSAM-x.pt"):
  """Call the __init__ method of the parent class (YOLO) with the updated default model."""
- if str(model) == 'FastSAM.pt':
- model = 'FastSAM-x.pt'
- assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
- super().__init__(model=model, task='segment')
+ if str(model) == "FastSAM.pt":
+ model = "FastSAM-x.pt"
+ assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
+ super().__init__(model=model, task="segment")

  @property
  def task_map(self):
  """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
- return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
+ return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}

ultralytics/models/fastsam/predict.py
@@ -33,7 +33,7 @@ class FastSAMPredictor(DetectionPredictor):
  _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
  """
  super().__init__(cfg, overrides, _callbacks)
- self.args.task = 'segment'
+ self.args.task = "segment"

  def postprocess(self, preds, img, orig_imgs):
  """
@@ -55,7 +55,8 @@ class FastSAMPredictor(DetectionPredictor):
  agnostic=self.args.agnostic_nms,
  max_det=self.args.max_det,
  nc=1, # set to 1 class since SAM has no class predictions
- classes=self.args.classes)
+ classes=self.args.classes,
+ )
  full_box = torch.zeros(p[0].shape[1], device=p[0].device)
  full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
  full_box = full_box.view(1, -1)

ultralytics/models/fastsam/prompt.py
@@ -23,7 +23,7 @@ class FastSAMPrompt:
  clip: CLIP model for linear assignment.
  """

- def __init__(self, source, results, device='cuda') -> None:
+ def __init__(self, source, results, device="cuda") -> None:
  """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
  self.device = device
  self.results = results
@@ -34,7 +34,8 @@ class FastSAMPrompt:
  import clip # for linear_assignment
  except ImportError:
  from ultralytics.utils.checks import check_requirements
- check_requirements('git+https://github.com/openai/CLIP.git')
+
+ check_requirements("git+https://github.com/openai/CLIP.git")
  import clip
  self.clip = clip

@@ -46,11 +47,11 @@ class FastSAMPrompt:
  x1, y1, x2, y2 = bbox
  segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
  segmented_image = Image.fromarray(segmented_image_array)
- black_image = Image.new('RGB', image.size, (255, 255, 255))
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
  # transparency_mask = np.zeros_like((), dtype=np.uint8)
  transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
  transparency_mask[y1:y2, x1:x2] = 255
- transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
  black_image.paste(segmented_image, mask=transparency_mask_image)
  return black_image

@@ -65,11 +66,12 @@ class FastSAMPrompt:
  mask = result.masks.data[i] == 1.0
  if torch.sum(mask) >= filter:
  annotation = {
- 'id': i,
- 'segmentation': mask.cpu().numpy(),
- 'bbox': result.boxes.data[i],
- 'score': result.boxes.conf[i]}
- annotation['area'] = annotation['segmentation'].sum()
+ "id": i,
+ "segmentation": mask.cpu().numpy(),
+ "bbox": result.boxes.data[i],
+ "score": result.boxes.conf[i],
+ }
+ annotation["area"] = annotation["segmentation"].sum()
  annotations.append(annotation)
  return annotations

@@ -91,16 +93,18 @@ class FastSAMPrompt:
  y2 = max(y2, y_t + h_t)
  return [x1, y1, x2, y2]

- def plot(self,
- annotations,
- output,
- bbox=None,
- points=None,
- point_label=None,
- mask_random_color=True,
- better_quality=True,
- retina=False,
- with_contours=True):
+ def plot(
+ self,
+ annotations,
+ output,
+ bbox=None,
+ points=None,
+ point_label=None,
+ mask_random_color=True,
+ better_quality=True,
+ retina=False,
+ with_contours=True,
+ ):
  """
  Plots annotations, bounding boxes, and points on images and saves the output.

@@ -139,15 +143,17 @@ class FastSAMPrompt:
  mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
  masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

- self.fast_show_mask(masks,
- plt.gca(),
- random_color=mask_random_color,
- bbox=bbox,
- points=points,
- pointlabel=point_label,
- retinamask=retina,
- target_height=original_h,
- target_width=original_w)
+ self.fast_show_mask(
+ masks,
+ plt.gca(),
+ random_color=mask_random_color,
+ bbox=bbox,
+ points=points,
+ pointlabel=point_label,
+ retinamask=retina,
+ target_height=original_h,
+ target_width=original_w,
+ )

  if with_contours:
  contour_all = []
@@ -166,10 +172,10 @@ class FastSAMPrompt:
  # Save the figure
  save_path = Path(output) / result_name
  save_path.parent.mkdir(exist_ok=True, parents=True)
- plt.axis('off')
- plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True)
+ plt.axis("off")
+ plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
  plt.close()
- pbar.set_description(f'Saving {result_name} to {save_path}')
+ pbar.set_description(f"Saving {result_name} to {save_path}")

  @staticmethod
  def fast_show_mask(
@@ -212,26 +218,26 @@ class FastSAMPrompt:
  mask_image = np.expand_dims(annotation, -1) * visual

  show = np.zeros((h, w, 4))
- h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
+ h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))

  show[h_indices, w_indices, :] = mask_image[indices]
  if bbox is not None:
  x1, y1, x2, y2 = bbox
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
  # Draw point
  if points is not None:
  plt.scatter(
  [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
  [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
  s=20,
- c='y',
+ c="y",
  )
  plt.scatter(
  [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
  [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
  s=20,
- c='m',
+ c="m",
  )

  if not retinamask:
@@ -258,7 +264,7 @@ class FastSAMPrompt:
  image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
  ori_w, ori_h = image.size
  annotations = format_results
- mask_h, mask_w = annotations[0]['segmentation'].shape
+ mask_h, mask_w = annotations[0]["segmentation"].shape
  if ori_w != mask_w or ori_h != mask_h:
  image = image.resize((mask_w, mask_h))
  cropped_boxes = []
@@ -266,19 +272,19 @@ class FastSAMPrompt:
  not_crop = []
  filter_id = []
  for _, mask in enumerate(annotations):
- if np.sum(mask['segmentation']) <= 100:
+ if np.sum(mask["segmentation"]) <= 100:
  filter_id.append(_)
  continue
- bbox = self._get_bbox_from_mask(mask['segmentation']) # mask bbox
- cropped_boxes.append(self._segment_image(image, bbox)) # 保存裁剪的图片
- cropped_images.append(bbox) # 保存裁剪的图片的bbox
+ bbox = self._get_bbox_from_mask(mask["segmentation"]) # bbox from mask
+ cropped_boxes.append(self._segment_image(image, bbox)) # save cropped image
+ cropped_images.append(bbox) # save cropped image bbox

  return cropped_boxes, cropped_images, not_crop, filter_id, annotations

  def box_prompt(self, bbox):
  """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
  if self.results[0].masks is not None:
- assert (bbox[2] != 0 and bbox[3] != 0)
+ assert bbox[2] != 0 and bbox[3] != 0
  if os.path.isdir(self.source):
  raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
  masks = self.results[0].masks.data
@@ -290,7 +296,8 @@ class FastSAMPrompt:
  int(bbox[0] * w / target_width),
  int(bbox[1] * h / target_height),
  int(bbox[2] * w / target_width),
- int(bbox[3] * h / target_height), ]
+ int(bbox[3] * h / target_height),
+ ]
  bbox[0] = max(round(bbox[0]), 0)
  bbox[1] = max(round(bbox[1]), 0)
  bbox[2] = min(round(bbox[2]), w)
@@ -299,7 +306,7 @@ class FastSAMPrompt:
  # IoUs = torch.zeros(len(masks), dtype=torch.float32)
  bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])

- masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+ masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
  orig_masks_area = torch.sum(masks, dim=(1, 2))

  union = bbox_area + orig_masks_area - masks_area
@@ -316,13 +323,13 @@ class FastSAMPrompt:
  raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
  masks = self._format_results(self.results[0], 0)
  target_height, target_width = self.results[0].orig_shape
- h = masks[0]['segmentation'].shape[0]
- w = masks[0]['segmentation'].shape[1]
+ h = masks[0]["segmentation"].shape[0]
+ w = masks[0]["segmentation"].shape[1]
  if h != target_height or w != target_width:
  points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
  onemask = np.zeros((h, w))
  for annotation in masks:
- mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
+ mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
  for i, point in enumerate(points):
  if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
  onemask += mask
@@ -337,12 +344,12 @@ class FastSAMPrompt:
  if self.results[0].masks is not None:
  format_results = self._format_results(self.results[0], 0)
  cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
- clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
+ clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
  scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
  max_idx = scores.argsort()
  max_idx = max_idx[-1]
  max_idx += sum(np.array(filter_id) <= int(max_idx))
- self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]['segmentation']]))
+ self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
  return self.results

  def everything_prompt(self):

ultralytics/models/fastsam/val.py
@@ -35,6 +35,6 @@ class FastSAMValidator(SegmentationValidator):
  Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
  """
  super().__init__(dataloader, save_dir, pbar, args, _callbacks)
- self.args.task = 'segment'
+ self.args.task = "segment"
  self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
  self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

ultralytics/models/nas/__init__.py
@@ -4,4 +4,4 @@ from .model import NAS
  from .predict import NASPredictor
  from .val import NASValidator

- __all__ = 'NASPredictor', 'NASValidator', 'NAS'
+ __all__ = "NASPredictor", "NASValidator", "NAS"

ultralytics/models/nas/model.py
@@ -44,20 +44,21 @@ class NAS(Model):
  YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
  """

- def __init__(self, model='yolo_nas_s.pt') -> None:
+ def __init__(self, model="yolo_nas_s.pt") -> None:
  """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
- assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
- super().__init__(model, task='detect')
+ assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
+ super().__init__(model, task="detect")

  @smart_inference_mode()
  def _load(self, weights: str, task: str):
  """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
  import super_gradients
+
  suffix = Path(weights).suffix
- if suffix == '.pt':
+ if suffix == ".pt":
  self.model = torch.load(weights)
- elif suffix == '':
- self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
+ elif suffix == "":
+ self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
  # Standardize model
  self.model.fuse = lambda verbose=True: self.model
  self.model.stride = torch.tensor([32])
@@ -65,7 +66,7 @@ class NAS(Model):
  self.model.is_fused = lambda: False # for info()
  self.model.yaml = {} # for info()
  self.model.pt_path = weights # for export()
- self.model.task = 'detect' # for export()
+ self.model.task = "detect" # for export()

  def info(self, detailed=False, verbose=True):
  """
@@ -80,4 +81,4 @@ class NAS(Model):
  @property
  def task_map(self):
  """Returns a dictionary mapping tasks to respective predictor and validator classes."""
- return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
+ return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}

ultralytics/models/nas/predict.py
@@ -39,12 +39,14 @@ class NASPredictor(BasePredictor):
  boxes = ops.xyxy2xywh(preds_in[0][0])
  preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)

- preds = ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- classes=self.args.classes)
+ preds = ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ agnostic=self.args.agnostic_nms,
+ max_det=self.args.max_det,
+ classes=self.args.classes,
+ )

  if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

ultralytics/models/nas/val.py
@@ -5,7 +5,7 @@ import torch
  from ultralytics.models.yolo.detect import DetectionValidator
  from ultralytics.utils import ops

- __all__ = ['NASValidator']
+ __all__ = ["NASValidator"]


  class NASValidator(DetectionValidator):
@@ -38,11 +38,13 @@ class NASValidator(DetectionValidator):
  """Apply Non-maximum suppression to prediction outputs."""
  boxes = ops.xyxy2xywh(preds_in[0][0])
  preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
- return ops.non_max_suppression(preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=False,
- agnostic=self.args.single_cls,
- max_det=self.args.max_det,
- max_time_img=0.5)
+ return ops.non_max_suppression(
+ preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=False,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ max_time_img=0.5,
+ )

ultralytics/models/rtdetr/__init__.py
@@ -4,4 +4,4 @@ from .model import RTDETR
  from .predict import RTDETRPredictor
  from .val import RTDETRValidator

- __all__ = 'RTDETRPredictor', 'RTDETRValidator', 'RTDETR'
+ __all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"

ultralytics/models/rtdetr/model.py
@@ -24,7 +24,7 @@ class RTDETR(Model):
  model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
  """

- def __init__(self, model='rtdetr-l.pt') -> None:
+ def __init__(self, model="rtdetr-l.pt") -> None:
  """
  Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.

@@ -34,9 +34,9 @@ class RTDETR(Model):
  Raises:
  NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
  """
- if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
- raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.')
- super().__init__(model=model, task='detect')
+ if model and model.split(".")[-1] not in ("pt", "yaml", "yml"):
+ raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
+ super().__init__(model=model, task="detect")

  @property
  def task_map(self) -> dict:
@@ -47,8 +47,10 @@ class RTDETR(Model):
  dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
  """
  return {
- 'detect': {
- 'predictor': RTDETRPredictor,
- 'validator': RTDETRValidator,
- 'trainer': RTDETRTrainer,
- 'model': RTDETRDetectionModel}}
+ "detect": {
+ "predictor": RTDETRPredictor,
+ "validator": RTDETRValidator,
+ "trainer": RTDETRTrainer,
+ "model": RTDETRDetectionModel,
+ }
+ }

ultralytics/models/rtdetr/train.py
@@ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer):
  Returns:
  (RTDETRDetectionModel): Initialized model.
  """
- model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+ model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
  if weights:
  model.load(weights)
  return model

- def build_dataset(self, img_path, mode='val', batch=None):
+ def build_dataset(self, img_path, mode="val", batch=None):
  """
  Build and return an RT-DETR dataset for training or validation.

@@ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer):
  Returns:
  (RTDETRDataset): Dataset object for the specific mode.
  """
- return RTDETRDataset(img_path=img_path,
- imgsz=self.args.imgsz,
- batch_size=batch,
- augment=mode == 'train',
- hyp=self.args,
- rect=False,
- cache=self.args.cache or None,
- prefix=colorstr(f'{mode}: '),
- data=self.data)
+ return RTDETRDataset(
+ img_path=img_path,
+ imgsz=self.args.imgsz,
+ batch_size=batch,
+ augment=mode == "train",
+ hyp=self.args,
+ rect=False,
+ cache=self.args.cache or None,
+ prefix=colorstr(f"{mode}: "),
+ data=self.data,
+ )

  def get_validator(self):
  """
@@ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer):
  Returns:
  (RTDETRValidator): Validator object for model validation.
  """
- self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
+ self.loss_names = "giou_loss", "cls_loss", "l1_loss"
  return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

  def preprocess_batch(self, batch):
@@ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer):
  (dict): Preprocessed batch.
  """
  batch = super().preprocess_batch(batch)
- bs = len(batch['img'])
- batch_idx = batch['batch_idx']
+ bs = len(batch["img"])
+ batch_idx = batch["batch_idx"]
  gt_bbox, gt_class = [], []
  for i in range(bs):
- gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
- gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
+ gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
+ gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
  return batch

ultralytics/models/rtdetr/val.py
@@ -7,7 +7,7 @@ from ultralytics.data.augment import Compose, Format, v8_transforms
  from ultralytics.models.yolo.detect import DetectionValidator
  from ultralytics.utils import colorstr, ops

- __all__ = 'RTDETRValidator', # tuple or list
+ __all__ = ("RTDETRValidator",) # tuple or list


  class RTDETRDataset(YOLODataset):
@@ -37,13 +37,16 @@ class RTDETRDataset(YOLODataset):
  # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
  transforms = Compose([])
  transforms.append(
- Format(bbox_format='xywh',
- normalize=True,
- return_mask=self.use_segments,
- return_keypoint=self.use_keypoints,
- batch_idx=True,
- mask_ratio=hyp.mask_ratio,
- mask_overlap=hyp.overlap_mask))
+ Format(
+ bbox_format="xywh",
+ normalize=True,
+ return_mask=self.use_segments,
+ return_keypoint=self.use_keypoints,
+ batch_idx=True,
+ mask_ratio=hyp.mask_ratio,
+ mask_overlap=hyp.overlap_mask,
+ )
+ )
  return transforms


@@ -68,7 +71,7 @@ class RTDETRValidator(DetectionValidator):
  For further details on the attributes and methods, refer to the parent DetectionValidator class.
  """

- def build_dataset(self, img_path, mode='val', batch=None):
+ def build_dataset(self, img_path, mode="val", batch=None):
  """
  Build an RTDETR Dataset.

@@ -85,8 +88,9 @@ class RTDETRValidator(DetectionValidator):
  hyp=self.args,
  rect=False, # no rect
  cache=self.args.cache or None,
- prefix=colorstr(f'{mode}: '),
- data=self.data)
+ prefix=colorstr(f"{mode}: "),
+ data=self.data,
+ )

  def postprocess(self, preds):
  """Apply Non-maximum suppression to prediction outputs."""
@@ -107,12 +111,13 @@ class RTDETRValidator(DetectionValidator):
  return outputs

  def _prepare_batch(self, si, batch):
- idx = batch['batch_idx'] == si
- cls = batch['cls'][idx].squeeze(-1)
- bbox = batch['bboxes'][idx]
- ori_shape = batch['ori_shape'][si]
- imgsz = batch['img'].shape[2:]
- ratio_pad = batch['ratio_pad'][si]
+ """Prepares a batch for training or inference by applying transformations."""
+ idx = batch["batch_idx"] == si
+ cls = batch["cls"][idx].squeeze(-1)
+ bbox = batch["bboxes"][idx]
+ ori_shape = batch["ori_shape"][si]
+ imgsz = batch["img"].shape[2:]
+ ratio_pad = batch["ratio_pad"][si]
  if len(cls):
  bbox = ops.xywh2xyxy(bbox) # target boxes
  bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
@@ -121,7 +126,8 @@ class RTDETRValidator(DetectionValidator):
  return prepared_batch

  def _prepare_pred(self, pred, pbatch):
+ """Prepares and returns a batch with transformed bounding boxes and class labels."""
  predn = pred.clone()
- predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz # native-space pred
- predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz # native-space pred
+ predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
+ predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
  return predn.float()

ultralytics/models/sam/__init__.py
@@ -3,4 +3,4 @@
  from .model import SAM
  from .predict import Predictor

- __all__ = 'SAM', 'Predictor' # tuple or list
+ __all__ = "SAM", "Predictor" # tuple or list