ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ultralytics has been flagged as potentially problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +34 -0
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +5 -0
- ultralytics/data/explorer/explorer.py +170 -97
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +146 -76
- ultralytics/data/explorer/utils.py +87 -25
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +63 -40
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -12
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +80 -58
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +67 -59
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +22 -15
- ultralytics/solutions/heatmap.py +76 -54
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -151
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +39 -29
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.237.dist-info/RECORD +0 -187
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
CHANGED
@@ -12,28 +12,34 @@ class YOLO(Model):
     def task_map(self):
         """Map head to model, trainer, validator, and predictor classes."""
         return {
+            "classify": {
+                "model": ClassificationModel,
+                "trainer": yolo.classify.ClassificationTrainer,
+                "validator": yolo.classify.ClassificationValidator,
+                "predictor": yolo.classify.ClassificationPredictor,
+            },
+            "detect": {
+                "model": DetectionModel,
+                "trainer": yolo.detect.DetectionTrainer,
+                "validator": yolo.detect.DetectionValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+            },
+            "segment": {
+                "model": SegmentationModel,
+                "trainer": yolo.segment.SegmentationTrainer,
+                "validator": yolo.segment.SegmentationValidator,
+                "predictor": yolo.segment.SegmentationPredictor,
+            },
+            "pose": {
+                "model": PoseModel,
+                "trainer": yolo.pose.PoseTrainer,
+                "validator": yolo.pose.PoseValidator,
+                "predictor": yolo.pose.PosePredictor,
+            },
+            "obb": {
+                "model": OBBModel,
+                "trainer": yolo.obb.OBBTrainer,
+                "validator": yolo.obb.OBBValidator,
+                "predictor": yolo.obb.OBBPredictor,
+            },
+        }
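For context, this task_map is the dispatch table the YOLO wrapper consults once a checkpoint or YAML sets a task. A minimal sketch of how that plays out from the user side (the weight file name is illustrative, not part of this diff):

from ultralytics import YOLO

# Loading an OBB checkpoint sets the task to "obb", so task_map["obb"] supplies
# the OBBTrainer / OBBValidator / OBBPredictor classes behind train/val/predict.
model = YOLO("yolov8n-obb.pt")
print(model.task)  # expected to print "obb"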
ultralytics/models/yolo/obb/predict.py
CHANGED
@@ -23,28 +23,29 @@ class OBBPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes OBBPredictor with optional model and data configuration overrides."""
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task =
+        self.args.task = "obb"

     def postprocess(self, preds, img, orig_imgs):
         """Post-processes predictions and returns a list of Results objects."""
-        preds = ops.non_max_suppression(
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=len(self.model.names),
+            classes=self.args.classes,
+            rotated=True,
+        )

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

         results = []
-        for i, pred in enumerate(preds):
-            orig_img = orig_imgs[i]
+        for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])):
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
-            img_path = self.batch[0][i]
             # xywh, r, conf, cls
             obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
             results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
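As a usage sketch of this predictor (file names are illustrative, and the obb result attributes are an assumption about this release rather than a guarantee):

from ultralytics import YOLO

model = YOLO("yolov8n-obb.pt")
results = model.predict("aerial.jpg", conf=0.25, iou=0.5)
for r in results:
    # Results built by OBBPredictor carry rotated boxes in r.obb rather than r.boxes
    print(r.obb.xywhr)  # (N, 5) tensor: cx, cy, w, h, rotation
    print(r.obb.conf, r.obb.cls)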
ultralytics/models/yolo/obb/train.py
CHANGED
@@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         """Initialize a OBBTrainer object with given arguments."""
         if overrides is None:
             overrides = {}
-        overrides[
+        overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)

     def get_model(self, cfg=None, weights=None, verbose=True):
         """Return OBBModel initialized with specified config and weights."""
-        model = OBBModel(cfg, ch=3, nc=self.data[
+        model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)

@@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer):

     def get_validator(self):
         """Return an instance of OBBValidator for validation of YOLO model."""
-        self.loss_names =
+        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
         return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
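A minimal training sketch for the OBB task (dota8.yaml is the small sample dataset added in this release per the file list above; epochs and imgsz values are illustrative):

from ultralytics import YOLO

# OBBTrainer is selected automatically once task="obb" comes from the model config.
model = YOLO("yolov8n-obb.yaml")
model.train(data="dota8.yaml", epochs=3, imgsz=640)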
ultralytics/models/yolo/obb/val.py
CHANGED
@@ -27,26 +27,28 @@ class OBBValidator(DetectionValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
         """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
-        self.args.task =
+        self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)

     def init_metrics(self, model):
         """Initialize evaluation metrics for YOLO."""
         super().init_metrics(model)
-        val = self.data.get(self.args.split,
-        self.is_dota = isinstance(val, str) and
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO

     def postprocess(self, preds):
         """Apply Non-maximum suppression to prediction outputs."""
-        return ops.non_max_suppression(
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            nc=self.nc,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+            rotated=True,
+        )

     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
@@ -61,16 +63,17 @@ class OBBValidator(DetectionValidator):
         Returns:
             (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
         """
-        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -
+        iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
         return self.match_predictions(detections[:, 5], gt_cls, iou)

     def _prepare_batch(self, si, batch):
+        """Prepares and returns a batch for OBB validation."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
         if len(cls):
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
@@ -78,19 +81,23 @@ class OBBValidator(DetectionValidator):
         return prepared_batch

     def _prepare_pred(self, pred, pbatch):
+        """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
         predn = pred.clone()
-        ops.scale_boxes(
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+        )  # native-space pred
         return predn

     def plot_predictions(self, batch, preds, ni):
         """Plots predicted bounding boxes on input images and saves the result."""
-        plot_images(
+        plot_images(
+            batch["img"],
+            *output_to_rotated_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred

     def pred_to_json(self, predn, filename):
         """Serialize YOLO predictions to COCO json format."""
@@ -99,12 +106,26 @@ class OBBValidator(DetectionValidator):
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
         for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
-            self.jdict.append(
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(predn[i, 5].item())],
+                    "score": round(predn[i, 4].item(), 5),
+                    "rbox": [round(x, 3) for x in r],
+                    "poly": [round(x, 3) for x in b],
+                }
+            )
+
+    def save_one_txt(self, predn, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
+        for *xyxy, conf, cls, angle in predn.tolist():
+            xywha = torch.tensor([*xyxy, angle]).view(1, 5)
+            xywha[:, :4] /= gn
+            xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist()  # normalized xywh
+            line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy)  # label format
+            with open(file, "a") as f:
+                f.write(("%g " * len(line)).rstrip() % line + "\n")

     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -112,42 +133,43 @@ class OBBValidator(DetectionValidator):
            import json
            import re
            from collections import defaultdict
+
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            pred_txt = self.save_dir / "predictions_txt"  # predictions
            pred_txt.mkdir(parents=True, exist_ok=True)
            data = json.load(open(pred_json))
            # Save split results
-           LOGGER.info(f
+           LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
            for d in data:
-               image_id = d[
-               score = d[
-               classname = self.names[d[
+               image_id = d["image_id"]
+               score = d["score"]
+               classname = self.names[d["category_id"]].replace(" ", "-")

-               lines =
+               lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
                    image_id,
                    score,
-                   d[
+                   d["poly"][0],
+                   d["poly"][1],
+                   d["poly"][2],
+                   d["poly"][3],
+                   d["poly"][4],
+                   d["poly"][5],
+                   d["poly"][6],
+                   d["poly"][7],
                )
-               with open(str(pred_txt / f
+               with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
                    f.writelines(lines)
            # Save merged results, this could result slightly lower map than using official merging script,
            # because of the probiou calculation.
-           pred_merged_txt = self.save_dir /
+           pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
            pred_merged_txt.mkdir(parents=True, exist_ok=True)
            merged_results = defaultdict(list)
-           LOGGER.info(f
+           LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
            for d in data:
-               image_id = d[
-               pattern = re.compile(r
-               x, y = (int(c) for c in re.findall(pattern, d[
-               bbox, score, cls = d[
+               image_id = d["image_id"].split("__")[0]
+               pattern = re.compile(r"\d+___\d+")
+               x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
+               bbox, score, cls = d["rbox"], d["score"], d["category_id"]
                bbox[0] += x
                bbox[1] += y
                bbox.extend([score, cls])
@@ -165,11 +187,11 @@ class OBBValidator(DetectionValidator):

            b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
            for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
-               classname = self.names[int(x[-1])].replace(
+               classname = self.names[int(x[-1])].replace(" ", "-")
                poly = [round(i, 3) for i in x[:-2]]
                score = round(x[-2], 3)

-               lines =
+               lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
                    image_id,
                    score,
                    poly[0],
@@ -181,7 +203,7 @@ class OBBValidator(DetectionValidator):
                    poly[6],
                    poly[7],
                )
-               with open(str(pred_merged_txt / f
+               with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
                    f.writelines(lines)

        return stats
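Relatedly, the DOTA text export above is only reached when predictions are saved as JSON on a DOTA split. A hedged sketch of triggering it (dataset and weight names are illustrative, and the exact trigger condition is inferred from the hunks above rather than shown in full):

from ultralytics import YOLO

model = YOLO("yolov8n-obb.pt")
# With save_json=True and a validation path containing "DOTA", eval_json also writes
# Task1_<classname>.txt files to predictions_txt/ and predictions_merged_txt/.
model.val(data="DOTAv1.yaml", save_json=True)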
ultralytics/models/yolo/pose/predict.py
CHANGED
@@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor):
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task =
-        if isinstance(self.args.device, str) and self.args.device.lower() ==
-            LOGGER.warning(
+        self.args.task = "pose"
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )

     def postprocess(self, preds, img, orig_imgs):
         """Return detection results for a given input image or list of images."""
-        preds = ops.non_max_suppression(
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            classes=self.args.classes,
+            nc=len(self.model.names),
+        )

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor):
             pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
             img_path = self.batch[0][i]
             results.append(
-                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
+                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
+            )
         return results
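A short usage sketch of this predictor, sidestepping the Apple MPS issue it warns about (the image path is illustrative):

from ultralytics import YOLO

model = YOLO("yolov8n-pose.pt")
results = model.predict("person.jpg", device="cpu")  # avoid MPS for pose (issue #4031)
for r in results:
    print(r.keypoints.xy.shape)  # (num_persons, num_keypoints, 2)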
ultralytics/models/yolo/pose/train.py
CHANGED
@@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         """Initialize a PoseTrainer object with specified configurations and overrides."""
         if overrides is None:
             overrides = {}
-        overrides[
+        overrides["task"] = "pose"
         super().__init__(cfg, overrides, _callbacks)

-        if isinstance(self.args.device, str) and self.args.device.lower() ==
-            LOGGER.warning(
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )

     def get_model(self, cfg=None, weights=None, verbose=True):
         """Get pose estimation model with specified configuration and weights."""
-        model = PoseModel(cfg, ch=3, nc=self.data[
+        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
         if weights:
             model.load(weights)

@@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
     def set_model_attributes(self):
         """Sets keypoints shape attribute of PoseModel."""
         super().set_model_attributes()
-        self.model.kpt_shape = self.data[
+        self.model.kpt_shape = self.data["kpt_shape"]

     def get_validator(self):
         """Returns an instance of the PoseValidator class for validation."""
-        self.loss_names =
-        return yolo.pose.PoseValidator(
-            _callbacks=self.callbacks)
+        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
+        return yolo.pose.PoseValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )

     def plot_training_samples(self, batch, ni):
         """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
-        images = batch[
-        kpts = batch[
-        cls = batch[
-        bboxes = batch[
-        paths = batch[
-        batch_idx = batch[
-        plot_images(
+        images = batch["img"]
+        kpts = batch["keypoints"]
+        cls = batch["cls"].squeeze(-1)
+        bboxes = batch["bboxes"]
+        paths = batch["im_file"]
+        batch_idx = batch["batch_idx"]
+        plot_images(
+            images,
+            batch_idx,
+            cls,
+            bboxes,
+            kpts=kpts,
+            paths=paths,
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )

     def plot_metrics(self):
         """Plots training/val metrics."""
|