ultralytics-8.0.237-py3-none-any.whl → ultralytics-8.0.239-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
This version of ultralytics has been flagged as a potentially problematic release.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/models/fastsam/model.py
CHANGED
@@ -21,14 +21,14 @@ class FastSAM(Model):
         ```
     """
 
-    def __init__(self, model='FastSAM-x.pt'):
+    def __init__(self, model="FastSAM-x.pt"):
         """Call the __init__ method of the parent class (YOLO) with the updated default model."""
-        if str(model) == 'FastSAM.pt':
-            model = 'FastSAM-x.pt'
-        assert Path(model).suffix not in ('.yaml', '.yml'), 'FastSAM models only support pre-trained models.'
-        super().__init__(model=model, task='segment')
+        if str(model) == "FastSAM.pt":
+            model = "FastSAM-x.pt"
+        assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
+        super().__init__(model=model, task="segment")
 
     @property
     def task_map(self):
         """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
-        return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
+        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
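A minimal usage sketch consistent with the constructor change above; the weights file is the new default and the image path is illustrative only:

from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")  # "FastSAM.pt" is remapped to "FastSAM-x.pt"; *.yaml/*.yml configs are rejected
results = model("path/to/image.jpg", device="cpu", imgsz=640)  # runs the "segment" task via FastSAMPredictor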
ultralytics/models/fastsam/predict.py
CHANGED
@@ -33,7 +33,7 @@ class FastSAMPredictor(DetectionPredictor):
             _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
         """
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task = 'segment'
+        self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
         """
@@ -55,7 +55,8 @@ class FastSAMPredictor(DetectionPredictor):
             agnostic=self.args.agnostic_nms,
             max_det=self.args.max_det,
             nc=1,  # set to 1 class since SAM has no class predictions
-            classes=self.args.classes)
+            classes=self.args.classes,
+        )
         full_box = torch.zeros(p[0].shape[1], device=p[0].device)
         full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
         full_box = full_box.view(1, -1)
ultralytics/models/fastsam/prompt.py
CHANGED
@@ -23,7 +23,7 @@ class FastSAMPrompt:
         clip: CLIP model for linear assignment.
     """
 
-    def __init__(self, source, results, device='cuda') -> None:
+    def __init__(self, source, results, device="cuda") -> None:
         """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
         self.device = device
         self.results = results
@@ -34,7 +34,8 @@ class FastSAMPrompt:
             import clip  # for linear_assignment
         except ImportError:
             from ultralytics.utils.checks import check_requirements
-            check_requirements('git+https://github.com/openai/CLIP.git')
+
+            check_requirements("git+https://github.com/openai/CLIP.git")
             import clip
         self.clip = clip
 
@@ -46,11 +47,11 @@ class FastSAMPrompt:
         x1, y1, x2, y2 = bbox
         segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
         segmented_image = Image.fromarray(segmented_image_array)
-        black_image = Image.new('RGB', image.size, (255, 255, 255))
+        black_image = Image.new("RGB", image.size, (255, 255, 255))
         # transparency_mask = np.zeros_like((), dtype=np.uint8)
         transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
         transparency_mask[y1:y2, x1:x2] = 255
-        transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
+        transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
         black_image.paste(segmented_image, mask=transparency_mask_image)
         return black_image
 
@@ -65,11 +66,12 @@ class FastSAMPrompt:
             mask = result.masks.data[i] == 1.0
             if torch.sum(mask) >= filter:
                 annotation = {
-                    'id': i,
-                    'segmentation': mask.cpu().numpy(),
-                    'bbox': result.boxes.data[i],
-                    'score': result.boxes.conf[i]}
-                annotation['area'] = annotation['segmentation'].sum()
+                    "id": i,
+                    "segmentation": mask.cpu().numpy(),
+                    "bbox": result.boxes.data[i],
+                    "score": result.boxes.conf[i],
+                }
+                annotation["area"] = annotation["segmentation"].sum()
                 annotations.append(annotation)
         return annotations
 
@@ -91,16 +93,18 @@ class FastSAMPrompt:
                 y2 = max(y2, y_t + h_t)
         return [x1, y1, x2, y2]
 
-    def plot(self,
-             annotations,
-             output,
-             bbox=None,
-             points=None,
-             point_label=None,
-             mask_random_color=True,
-             better_quality=True,
-             retina=False,
-             with_contours=True):
+    def plot(
+        self,
+        annotations,
+        output,
+        bbox=None,
+        points=None,
+        point_label=None,
+        mask_random_color=True,
+        better_quality=True,
+        retina=False,
+        with_contours=True,
+    ):
         """
         Plots annotations, bounding boxes, and points on images and saves the output.
 
@@ -139,15 +143,17 @@ class FastSAMPrompt:
                     mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                     masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
 
-            self.fast_show_mask(masks,
-                                plt.gca(),
-                                random_color=mask_random_color,
-                                bbox=bbox,
-                                points=points,
-                                pointlabel=point_label,
-                                retinamask=retina,
-                                target_height=original_h,
-                                target_width=original_w)
+            self.fast_show_mask(
+                masks,
+                plt.gca(),
+                random_color=mask_random_color,
+                bbox=bbox,
+                points=points,
+                pointlabel=point_label,
+                retinamask=retina,
+                target_height=original_h,
+                target_width=original_w,
+            )
 
             if with_contours:
                 contour_all = []
@@ -166,10 +172,10 @@ class FastSAMPrompt:
             # Save the figure
             save_path = Path(output) / result_name
             save_path.parent.mkdir(exist_ok=True, parents=True)
-            plt.axis('off')
-            plt.savefig(save_path, bbox_inches='tight', pad_inches=0, transparent=True)
+            plt.axis("off")
+            plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
             plt.close()
-            pbar.set_description(f'Saving {result_name} to {save_path}')
+            pbar.set_description(f"Saving {result_name} to {save_path}")
 
     @staticmethod
     def fast_show_mask(
@@ -212,26 +218,26 @@ class FastSAMPrompt:
         mask_image = np.expand_dims(annotation, -1) * visual
 
         show = np.zeros((h, w, 4))
-        h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
+        h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
         indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
 
         show[h_indices, w_indices, :] = mask_image[indices]
         if bbox is not None:
             x1, y1, x2, y2 = bbox
-            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
         # Draw point
         if points is not None:
             plt.scatter(
                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                 s=20,
-                c='y',
+                c="y",
             )
             plt.scatter(
                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                 s=20,
-                c='m',
+                c="m",
             )
 
         if not retinamask:
@@ -258,7 +264,7 @@ class FastSAMPrompt:
         image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
         ori_w, ori_h = image.size
         annotations = format_results
-        mask_h, mask_w = annotations[0]['segmentation'].shape
+        mask_h, mask_w = annotations[0]["segmentation"].shape
         if ori_w != mask_w or ori_h != mask_h:
             image = image.resize((mask_w, mask_h))
         cropped_boxes = []
@@ -266,19 +272,19 @@ class FastSAMPrompt:
         not_crop = []
         filter_id = []
         for _, mask in enumerate(annotations):
-            if np.sum(mask['segmentation']) <= 100:
+            if np.sum(mask["segmentation"]) <= 100:
                 filter_id.append(_)
                 continue
-            bbox = self._get_bbox_from_mask(mask['segmentation'])
-            cropped_boxes.append(self._segment_image(image, bbox))
-            cropped_images.append(bbox)
+            bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
+            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
+            cropped_images.append(bbox)  # save cropped image bbox
 
         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
 
     def box_prompt(self, bbox):
         """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
         if self.results[0].masks is not None:
-            assert (bbox[2] != 0 and bbox[3] != 0)
+            assert bbox[2] != 0 and bbox[3] != 0
             if os.path.isdir(self.source):
                 raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
             masks = self.results[0].masks.data
@@ -290,7 +296,8 @@ class FastSAMPrompt:
                 int(bbox[0] * w / target_width),
                 int(bbox[1] * h / target_height),
                 int(bbox[2] * w / target_width),
-                int(bbox[3] * h / target_height), ]
+                int(bbox[3] * h / target_height),
+            ]
             bbox[0] = max(round(bbox[0]), 0)
             bbox[1] = max(round(bbox[1]), 0)
             bbox[2] = min(round(bbox[2]), w)
@@ -299,7 +306,7 @@ class FastSAMPrompt:
             # IoUs = torch.zeros(len(masks), dtype=torch.float32)
             bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
 
-            masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+            masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
             orig_masks_area = torch.sum(masks, dim=(1, 2))
 
             union = bbox_area + orig_masks_area - masks_area
@@ -316,13 +323,13 @@ class FastSAMPrompt:
                 raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
             masks = self._format_results(self.results[0], 0)
             target_height, target_width = self.results[0].orig_shape
-            h = masks[0]['segmentation'].shape[0]
-            w = masks[0]['segmentation'].shape[1]
+            h = masks[0]["segmentation"].shape[0]
+            w = masks[0]["segmentation"].shape[1]
             if h != target_height or w != target_width:
                 points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
             onemask = np.zeros((h, w))
             for annotation in masks:
-                mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
+                mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
                 for i, point in enumerate(points):
                     if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                         onemask += mask
@@ -337,12 +344,12 @@ class FastSAMPrompt:
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
             cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
-            clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
+            clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
             scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
             max_idx = scores.argsort()
            max_idx = max_idx[-1]
            max_idx += sum(np.array(filter_id) <= int(max_idx))
-            self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]['segmentation']]))
+            self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
         return self.results
 
     def everything_prompt(self):
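A small prompt-workflow sketch exercising the box/point/text paths touched above, assuming a local image file (text_prompt installs OpenAI CLIP via check_requirements if it is missing):

from ultralytics.models.fastsam import FastSAM, FastSAMPrompt

model = FastSAM("FastSAM-x.pt")
everything = model("path/to/image.jpg", device="cpu", retina_masks=True, imgsz=640, conf=0.4, iou=0.9)
prompt = FastSAMPrompt("path/to/image.jpg", everything, device="cpu")

ann = prompt.text_prompt(text="a photo of a dog")  # CLIP ViT-B/32 retrieval, as in the hunk above
# ann = prompt.box_prompt(bbox=[200, 200, 300, 300])
# ann = prompt.point_prompt(points=[[220, 220]], pointlabel=[1])
prompt.plot(annotations=ann, output="./output/")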
ultralytics/models/fastsam/val.py
CHANGED
@@ -35,6 +35,6 @@ class FastSAMValidator(SegmentationValidator):
             Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
         """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
-        self.args.task = 'segment'
+        self.args.task = "segment"
         self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
         self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
ultralytics/models/nas/model.py
CHANGED
@@ -44,20 +44,21 @@ class NAS(Model):
         YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
     """
 
-    def __init__(self, model='yolo_nas_s.pt') -> None:
+    def __init__(self, model="yolo_nas_s.pt") -> None:
         """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
-        assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
-        super().__init__(model, task='detect')
+        assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
+        super().__init__(model, task="detect")
 
     @smart_inference_mode()
     def _load(self, weights: str, task: str):
         """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
         import super_gradients
+
         suffix = Path(weights).suffix
-        if suffix == '.pt':
+        if suffix == ".pt":
             self.model = torch.load(weights)
-        elif suffix == '':
-            self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
+        elif suffix == "":
+            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
         # Standardize model
         self.model.fuse = lambda verbose=True: self.model
         self.model.stride = torch.tensor([32])
@@ -65,7 +66,7 @@ class NAS(Model):
         self.model.is_fused = lambda: False  # for info()
         self.model.yaml = {}  # for info()
         self.model.pt_path = weights  # for export()
-        self.model.task = 'detect'  # for export()
+        self.model.task = "detect"  # for export()
 
     def info(self, detailed=False, verbose=True):
         """
@@ -80,4 +81,4 @@ class NAS(Model):
     @property
     def task_map(self):
         """Returns a dictionary mapping tasks to respective predictor and validator classes."""
-        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
+        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
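For context, a minimal sketch of the NAS entry point whose defaults are reformatted above (requires the super-gradients package; the image path is illustrative):

from ultralytics import NAS

model = NAS("yolo_nas_s.pt")  # pre-trained weights only; *.yaml/*.yml configs are rejected
model.info()  # task is "detect"
results = model.predict("path/to/image.jpg")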
ultralytics/models/nas/predict.py
CHANGED
@@ -39,12 +39,14 @@ class NASPredictor(BasePredictor):
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
 
-        preds = ops.non_max_suppression(preds,
-                                        self.args.conf,
-                                        self.args.iou,
-                                        agnostic=self.args.agnostic_nms,
-                                        max_det=self.args.max_det,
-                                        classes=self.args.classes)
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            classes=self.args.classes,
+        )
 
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
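Both the NAS predictor above and the validator below funnel raw outputs through ops.non_max_suppression; a standalone sketch with random scores (shapes are illustrative):

import torch

from ultralytics.utils import ops

raw = torch.rand(1, 84, 8400)  # (batch, 4 box coordinates + 80 class scores, candidate boxes)
detections = ops.non_max_suppression(raw, 0.25, 0.45, max_det=300)  # conf and iou thresholds, as in the diff
print(detections[0].shape)  # one (n, 6) tensor per image: x1, y1, x2, y2, confidence, class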
ultralytics/models/nas/val.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import ops
 
-__all__ = ['NASValidator']
+__all__ = ["NASValidator"]
 
 
 class NASValidator(DetectionValidator):
@@ -38,11 +38,13 @@ class NASValidator(DetectionValidator):
         """Apply Non-maximum suppression to prediction outputs."""
         boxes = ops.xyxy2xywh(preds_in[0][0])
         preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
-        return ops.non_max_suppression(preds,
-                                       self.args.conf,
-                                       self.args.iou,
-                                       labels=self.lb,
-                                       multi_label=False,
-                                       agnostic=self.args.single_cls,
-                                       max_det=self.args.max_det,
-                                       max_time_img=0.5)
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=False,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+            max_time_img=0.5,
+        )
ultralytics/models/rtdetr/model.py
CHANGED
@@ -24,7 +24,7 @@ class RTDETR(Model):
         model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
     """
 
-    def __init__(self, model='rtdetr-l.pt') -> None:
+    def __init__(self, model="rtdetr-l.pt") -> None:
         """
         Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
 
@@ -34,9 +34,9 @@ class RTDETR(Model):
         Raises:
             NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
         """
-        if model and model.split('.')[-1] not in ('pt', 'yaml', 'yml'):
-            raise NotImplementedError('RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.')
-        super().__init__(model=model, task='detect')
+        if model and model.split(".")[-1] not in ("pt", "yaml", "yml"):
+            raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
+        super().__init__(model=model, task="detect")
 
     @property
     def task_map(self) -> dict:
@@ -47,8 +47,10 @@ class RTDETR(Model):
             dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
         """
         return {
-            'detect': {
-                'predictor': RTDETRPredictor,
-                'validator': RTDETRValidator,
-                'trainer': RTDETRTrainer,
-                'model': RTDETRDetectionModel}}
+            "detect": {
+                "predictor": RTDETRPredictor,
+                "validator": RTDETRValidator,
+                "trainer": RTDETRTrainer,
+                "model": RTDETRDetectionModel,
+            }
+        }
ultralytics/models/rtdetr/train.py
CHANGED
@@ -43,12 +43,12 @@ class RTDETRTrainer(DetectionTrainer):
         Returns:
             (RTDETRDetectionModel): Initialized model.
         """
-        model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
+        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
         return model
 
-    def build_dataset(self, img_path, mode='val', batch=None):
+    def build_dataset(self, img_path, mode="val", batch=None):
         """
         Build and return an RT-DETR dataset for training or validation.
 
@@ -60,15 +60,17 @@ class RTDETRTrainer(DetectionTrainer):
         Returns:
             (RTDETRDataset): Dataset object for the specific mode.
         """
-        return RTDETRDataset(img_path=img_path,
-                             imgsz=self.args.imgsz,
-                             batch_size=batch,
-                             augment=mode == 'train',
-                             hyp=self.args,
-                             rect=False,
-                             cache=self.args.cache or None,
-                             prefix=colorstr(f'{mode}: '),
-                             data=self.data)
+        return RTDETRDataset(
+            img_path=img_path,
+            imgsz=self.args.imgsz,
+            batch_size=batch,
+            augment=mode == "train",
+            hyp=self.args,
+            rect=False,
+            cache=self.args.cache or None,
+            prefix=colorstr(f"{mode}: "),
+            data=self.data,
+        )
 
     def get_validator(self):
         """
@@ -77,7 +79,7 @@ class RTDETRTrainer(DetectionTrainer):
         Returns:
             (RTDETRValidator): Validator object for model validation.
         """
-        self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
+        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
         return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
 
     def preprocess_batch(self, batch):
@@ -91,10 +93,10 @@ class RTDETRTrainer(DetectionTrainer):
             (dict): Preprocessed batch.
         """
         batch = super().preprocess_batch(batch)
-        bs = len(batch['img'])
-        batch_idx = batch['batch_idx']
+        bs = len(batch["img"])
+        batch_idx = batch["batch_idx"]
         gt_bbox, gt_class = [], []
         for i in range(bs):
-            gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
-            gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
+            gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
+            gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
         return batch
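A short end-to-end sketch consistent with the trainer changes above; the dataset and epoch count are chosen only for illustration:

from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")  # .pt, .yaml or .yml only
model.train(data="coco8.yaml", epochs=1, imgsz=640)  # RTDETRTrainer builds RTDETRDataset with rect=False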
ultralytics/models/rtdetr/val.py
CHANGED
@@ -7,7 +7,7 @@ from ultralytics.data.augment import Compose, Format, v8_transforms
 from ultralytics.models.yolo.detect import DetectionValidator
 from ultralytics.utils import colorstr, ops
 
-__all__ = 'RTDETRValidator',  # tuple or list
+__all__ = ("RTDETRValidator",)  # tuple or list
 
 
 class RTDETRDataset(YOLODataset):
@@ -37,13 +37,16 @@ class RTDETRDataset(YOLODataset):
             # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
             transforms = Compose([])
         transforms.append(
-            Format(bbox_format='xywh',
-                   normalize=True,
-                   return_mask=self.use_segments,
-                   return_keypoint=self.use_keypoints,
-                   batch_idx=True,
-                   mask_ratio=hyp.mask_ratio,
-                   mask_overlap=hyp.overlap_mask))
+            Format(
+                bbox_format="xywh",
+                normalize=True,
+                return_mask=self.use_segments,
+                return_keypoint=self.use_keypoints,
+                batch_idx=True,
+                mask_ratio=hyp.mask_ratio,
+                mask_overlap=hyp.overlap_mask,
+            )
+        )
         return transforms
 
 
@@ -68,7 +71,7 @@ class RTDETRValidator(DetectionValidator):
     For further details on the attributes and methods, refer to the parent DetectionValidator class.
     """
 
-    def build_dataset(self, img_path, mode='val', batch=None):
+    def build_dataset(self, img_path, mode="val", batch=None):
         """
         Build an RTDETR Dataset.
 
@@ -85,8 +88,9 @@ class RTDETRValidator(DetectionValidator):
             hyp=self.args,
             rect=False,  # no rect
             cache=self.args.cache or None,
-            prefix=colorstr(f'{mode}: '),
-            data=self.data)
+            prefix=colorstr(f"{mode}: "),
+            data=self.data,
+        )
 
     def postprocess(self, preds):
         """Apply Non-maximum suppression to prediction outputs."""
@@ -107,12 +111,13 @@ class RTDETRValidator(DetectionValidator):
         return outputs
 
     def _prepare_batch(self, si, batch):
-        idx = batch['batch_idx'] == si
-        cls = batch['cls'][idx].squeeze(-1)
-        bbox = batch['bboxes'][idx]
-        ori_shape = batch['ori_shape'][si]
-        imgsz = batch['img'].shape[2:]
-        ratio_pad = batch['ratio_pad'][si]
+        """Prepares a batch for training or inference by applying transformations."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
         if len(cls):
             bbox = ops.xywh2xyxy(bbox)  # target boxes
             bbox[..., [0, 2]] *= ori_shape[1]  # native-space pred
@@ -121,7 +126,8 @@ class RTDETRValidator(DetectionValidator):
         return prepared_batch
 
     def _prepare_pred(self, pred, pbatch):
+        """Prepares and returns a batch with transformed bounding boxes and class labels."""
         predn = pred.clone()
-        predn[..., [0, 2]] *= pbatch['ori_shape'][1] / self.args.imgsz  # native-space pred
-        predn[..., [1, 3]] *= pbatch['ori_shape'][0] / self.args.imgsz  # native-space pred
+        predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
+        predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
         return predn.float()
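And a matching validation sketch; predictions are rescaled from imgsz back to ori_shape as in _prepare_pred above (dataset chosen only for illustration):

from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")
metrics = model.val(data="coco8.yaml", imgsz=640)  # runs the RTDETRValidator from this file
print(metrics.box.map50)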