ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/model.py
CHANGED
@@ -12,28 +12,34 @@ class YOLO(Model):
     def task_map(self):
         """Map head to model, trainer, validator, and predictor classes."""
         return {
-            ...  # 25 removed lines not captured
+            "classify": {
+                "model": ClassificationModel,
+                "trainer": yolo.classify.ClassificationTrainer,
+                "validator": yolo.classify.ClassificationValidator,
+                "predictor": yolo.classify.ClassificationPredictor,
+            },
+            "detect": {
+                "model": DetectionModel,
+                "trainer": yolo.detect.DetectionTrainer,
+                "validator": yolo.detect.DetectionValidator,
+                "predictor": yolo.detect.DetectionPredictor,
+            },
+            "segment": {
+                "model": SegmentationModel,
+                "trainer": yolo.segment.SegmentationTrainer,
+                "validator": yolo.segment.SegmentationValidator,
+                "predictor": yolo.segment.SegmentationPredictor,
+            },
+            "pose": {
+                "model": PoseModel,
+                "trainer": yolo.pose.PoseTrainer,
+                "validator": yolo.pose.PoseValidator,
+                "predictor": yolo.pose.PosePredictor,
+            },
+            "obb": {
+                "model": OBBModel,
+                "trainer": yolo.obb.OBBTrainer,
+                "validator": yolo.obb.OBBValidator,
+                "predictor": yolo.obb.OBBPredictor,
+            },
+        }
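For orientation, here is a hedged sketch of how the task_map above is exercised through the public API; the weight name and image URL are illustrative examples, not taken from this diff.

from ultralytics import YOLO

# The task is inferred from the checkpoint, and task_map routes it to the
# matching model/trainer/validator/predictor classes.
model = YOLO("yolov8n.pt")  # resolves to the "detect" entry of task_map
results = model.predict("https://ultralytics.com/images/bus.jpg", conf=0.25)
print(model.task)             # "detect"
print(type(model.predictor))  # yolo.detect.DetectionPredictor, per task_map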

ultralytics/models/yolo/obb/predict.py
CHANGED
@@ -23,27 +23,29 @@ class OBBPredictor(DetectionPredictor):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initializes OBBPredictor with optional model and data configuration overrides."""
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task = 'obb'
+        self.args.task = "obb"

     def postprocess(self, preds, img, orig_imgs):
         """Post-processes predictions and returns a list of Results objects."""
-        preds = ops.non_max_suppression(
-        ...  # 7 removed lines not captured
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=len(self.model.names),
+            classes=self.args.classes,
+            rotated=True,
+        )

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

         results = []
-        for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)):
+        for i, (pred, orig_img, img_path) in enumerate(zip(preds, orig_imgs, self.batch[0])):
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
-            img_path = self.batch[0][i]
             # xywh, r, conf, cls
             obb = torch.cat([pred[:, :4], pred[:, -1:], pred[:, 4:6]], dim=-1)
             results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
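A short, hedged usage sketch of the OBB prediction path changed above; the checkpoint and image path are illustrative, and the Results.obb accessors are assumed to follow the released OBB results API.

from ultralytics import YOLO

model = YOLO("yolov8n-obb.pt")
results = model("boats.jpg")  # illustrative source

for r in results:
    # r.obb carries the (xywh, rotation, conf, cls) values assembled in postprocess()
    print(r.obb.xywhr)            # rotated boxes as (x_center, y_center, w, h, angle)
    print(r.obb.conf, r.obb.cls)  # confidences and class indices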

ultralytics/models/yolo/obb/train.py
CHANGED
@@ -25,12 +25,12 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
         """Initialize a OBBTrainer object with given arguments."""
         if overrides is None:
             overrides = {}
-        overrides['task'] = 'obb'
+        overrides["task"] = "obb"
         super().__init__(cfg, overrides, _callbacks)

     def get_model(self, cfg=None, weights=None, verbose=True):
         """Return OBBModel initialized with specified config and weights."""
-        model = OBBModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
+        model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)

@@ -38,5 +38,5 @@ class OBBTrainer(yolo.detect.DetectionTrainer):

     def get_validator(self):
         """Return an instance of OBBValidator for validation of YOLO model."""
-        self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
+        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
         return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
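A minimal, hedged sketch of driving OBBTrainer through the high-level API; dota8.yaml and the model config name are the stock Ultralytics examples and are used here only for illustration.

from ultralytics import YOLO

# OBBTrainer.__init__ forces task="obb", and get_model() builds an OBBModel
# with nc read from the dataset YAML.
model = YOLO("yolov8n-obb.yaml")  # build from config, or pass .pt weights instead
model.train(data="dota8.yaml", epochs=3, imgsz=640)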

ultralytics/models/yolo/obb/val.py
CHANGED
@@ -27,26 +27,28 @@ class OBBValidator(DetectionValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
         """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
-        self.args.task = 'obb'
+        self.args.task = "obb"
         self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot)

     def init_metrics(self, model):
         """Initialize evaluation metrics for YOLO."""
         super().init_metrics(model)
-        val = self.data.get(self.args.split, '')  # validation path
-        self.is_dota = isinstance(val, str) and 'DOTA' in val  # is COCO
+        val = self.data.get(self.args.split, "")  # validation path
+        self.is_dota = isinstance(val, str) and "DOTA" in val  # is COCO

     def postprocess(self, preds):
         """Apply Non-maximum suppression to prediction outputs."""
-        return ops.non_max_suppression(
-        ...  # 8 removed lines not captured
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            nc=self.nc,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+            rotated=True,
+        )

     def _process_batch(self, detections, gt_bboxes, gt_cls):
         """
@@ -65,12 +67,13 @@ class OBBValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)

     def _prepare_batch(self, si, batch):
-        ...  # 6 removed lines not captured
+        """Prepares and returns a batch for OBB validation."""
+        idx = batch["batch_idx"] == si
+        cls = batch["cls"][idx].squeeze(-1)
+        bbox = batch["bboxes"][idx]
+        ori_shape = batch["ori_shape"][si]
+        imgsz = batch["img"].shape[2:]
+        ratio_pad = batch["ratio_pad"][si]
         if len(cls):
             bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
             ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True)  # native-space labels
@@ -78,19 +81,23 @@ class OBBValidator(DetectionValidator):
         return prepared_batch

     def _prepare_pred(self, pred, pbatch):
+        """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
         predn = pred.clone()
-        ops.scale_boxes(
-        ...  # 1 removed line not captured
+        ops.scale_boxes(
+            pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
+        )  # native-space pred
         return predn

     def plot_predictions(self, batch, preds, ni):
         """Plots predicted bounding boxes on input images and saves the result."""
-        plot_images(
-        ...  # 5 removed lines not captured
+        plot_images(
+            batch["img"],
+            *output_to_rotated_target(preds, max_det=self.args.max_det),
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred

     def pred_to_json(self, predn, filename):
         """Serialize YOLO predictions to COCO json format."""
@@ -99,12 +106,15 @@ class OBBValidator(DetectionValidator):
         rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
         poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
         for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())):
-            self.jdict.append(
-            ...  # 5 removed lines not captured
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(predn[i, 5].item())],
+                    "score": round(predn[i, 4].item(), 5),
+                    "rbox": [round(x, 3) for x in r],
+                    "poly": [round(x, 3) for x in b],
+                }
+            )

     def save_one_txt(self, predn, save_conf, shape, file):
         """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
@@ -114,8 +124,8 @@ class OBBValidator(DetectionValidator):
             xywha[:, :4] /= gn
             xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist()  # normalized xywh
             line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy)  # label format
-            with open(file, 'a') as f:
-                f.write(('%g ' * len(line)).rstrip() % line + '\n')
+            with open(file, "a") as f:
+                f.write(("%g " * len(line)).rstrip() % line + "\n")

     def eval_json(self, stats):
         """Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -123,42 +133,43 @@ class OBBValidator(DetectionValidator):
             import json
             import re
             from collections import defaultdict
-            ...  # 2 removed lines not captured
+
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            pred_txt = self.save_dir / "predictions_txt"  # predictions
             pred_txt.mkdir(parents=True, exist_ok=True)
             data = json.load(open(pred_json))
             # Save split results
-            LOGGER.info(f'Saving predictions with DOTA format to {str(pred_txt)}...')
+            LOGGER.info(f"Saving predictions with DOTA format to {str(pred_txt)}...")
             for d in data:
-                image_id = d['image_id']
-                score = d['score']
-                classname = self.names[d['category_id']].replace(' ', '-')
+                image_id = d["image_id"]
+                score = d["score"]
+                classname = self.names[d["category_id"]].replace(" ", "-")

-                lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
+                lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
                     image_id,
                     score,
-                    d['poly'][0],
-                    d['poly'][1],
-                    d['poly'][2],
-                    d['poly'][3],
-                    d['poly'][4],
-                    d['poly'][5],
-                    d['poly'][6],
-                    d['poly'][7],
+                    d["poly"][0],
+                    d["poly"][1],
+                    d["poly"][2],
+                    d["poly"][3],
+                    d["poly"][4],
+                    d["poly"][5],
+                    d["poly"][6],
+                    d["poly"][7],
                 )
-                with open(str(pred_txt / f'Task1_{classname}') + '.txt', 'a') as f:
+                with open(str(pred_txt / f"Task1_{classname}") + ".txt", "a") as f:
                     f.writelines(lines)
             # Save merged results, this could result slightly lower map than using official merging script,
             # because of the probiou calculation.
-            pred_merged_txt = self.save_dir / 'predictions_merged_txt'  # predictions
+            pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
             pred_merged_txt.mkdir(parents=True, exist_ok=True)
             merged_results = defaultdict(list)
-            LOGGER.info(f'Saving merged predictions with DOTA format to {str(pred_merged_txt)}...')
+            LOGGER.info(f"Saving merged predictions with DOTA format to {str(pred_merged_txt)}...")
             for d in data:
-                image_id = d['image_id'].split('__')[0]
-                pattern = re.compile(r'\d+___\d+')
-                x, y = (int(c) for c in re.findall(pattern, d['image_id'])[0].split('___'))
-                bbox, score, cls = d['rbox'], d['score'], d['category_id']
+                image_id = d["image_id"].split("__")[0]
+                pattern = re.compile(r"\d+___\d+")
+                x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
+                bbox, score, cls = d["rbox"], d["score"], d["category_id"]
                 bbox[0] += x
                 bbox[1] += y
                 bbox.extend([score, cls])
@@ -176,11 +187,11 @@ class OBBValidator(DetectionValidator):

                 b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
                 for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
-                    classname = self.names[int(x[-1])].replace(' ', '-')
+                    classname = self.names[int(x[-1])].replace(" ", "-")
                     poly = [round(i, 3) for i in x[:-2]]
                     score = round(x[-2], 3)

-                    lines = '{} {} {} {} {} {} {} {} {} {}\n'.format(
+                    lines = "{} {} {} {} {} {} {} {} {} {}\n".format(
                         image_id,
                         score,
                         poly[0],
@@ -192,7 +203,7 @@ class OBBValidator(DetectionValidator):
                         poly[6],
                         poly[7],
                     )
-                    with open(str(pred_merged_txt / f'Task1_{classname}') + '.txt', 'a') as f:
+                    with open(str(pred_merged_txt / f"Task1_{classname}") + ".txt", "a") as f:
                         f.writelines(lines)

         return stats
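To put eval_json() in context, a hedged sketch of a validation run that exercises the DOTA export path above; the dataset YAML name is illustrative, and the output layout is as described by the code (predictions.json plus per-class Task1_*.txt files under the run directory).

from ultralytics import YOLO

model = YOLO("yolov8n-obb.pt")
# save_json=True makes the validator write predictions.json; when the split path
# contains "DOTA", eval_json() also emits Task1_<classname>.txt files, both split
# and merged, under the run's save directory.
metrics = model.val(data="DOTAv1.yaml", split="val", save_json=True)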

ultralytics/models/yolo/pose/predict.py
CHANGED
@@ -23,20 +23,24 @@ class PosePredictor(DetectionPredictor):
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device."""
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task = 'pose'
-        if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
-            LOGGER.warning(
-            ...  # 1 removed line not captured
+        self.args.task = "pose"
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )

     def postprocess(self, preds, img, orig_imgs):
         """Return detection results for a given input image or list of images."""
-        preds = ops.non_max_suppression(
-        ...  # 6 removed lines not captured
+        preds = ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            classes=self.args.classes,
+            nc=len(self.model.names),
+        )

         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
@@ -49,5 +53,6 @@ class PosePredictor(DetectionPredictor):
             pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
             img_path = self.batch[0][i]
             results.append(
-                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts))
+                Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
+            )
         return results
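A hedged sketch of the pose prediction path touched above, including the device workaround the new warning points to; the checkpoint and image path are illustrative.

from ultralytics import YOLO

model = YOLO("yolov8n-pose.pt")
# On Apple silicon the warning above recommends CPU for pose models.
results = model.predict("people.jpg", device="cpu", conf=0.25)
for r in results:
    print(r.keypoints.xy.shape)  # (num_detections, num_keypoints, 2)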

ultralytics/models/yolo/pose/train.py
CHANGED
@@ -26,16 +26,18 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
         """Initialize a PoseTrainer object with specified configurations and overrides."""
         if overrides is None:
             overrides = {}
-        overrides['task'] = 'pose'
+        overrides["task"] = "pose"
         super().__init__(cfg, overrides, _callbacks)

-        if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
-            LOGGER.warning(
-            ...  # 1 removed line not captured
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )

     def get_model(self, cfg=None, weights=None, verbose=True):
         """Get pose estimation model with specified configuration and weights."""
-        model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
+        model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose)
         if weights:
             model.load(weights)

@@ -44,32 +46,33 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
     def set_model_attributes(self):
         """Sets keypoints shape attribute of PoseModel."""
         super().set_model_attributes()
-        self.model.kpt_shape = self.data['kpt_shape']
+        self.model.kpt_shape = self.data["kpt_shape"]

     def get_validator(self):
         """Returns an instance of the PoseValidator class for validation."""
-        self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
-        return yolo.pose.PoseValidator(
-        ...  # 2 removed lines not captured
-            _callbacks=self.callbacks)
+        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
+        return yolo.pose.PoseValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )

     def plot_training_samples(self, batch, ni):
         """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
-        images = batch['img']
-        kpts = batch['keypoints']
-        cls = batch['cls'].squeeze(-1)
-        bboxes = batch['bboxes']
-        paths = batch['im_file']
-        batch_idx = batch['batch_idx']
-        plot_images(
-        ...  # 7 removed lines not captured
+        images = batch["img"]
+        kpts = batch["keypoints"]
+        cls = batch["cls"].squeeze(-1)
+        bboxes = batch["bboxes"]
+        paths = batch["im_file"]
+        batch_idx = batch["batch_idx"]
+        plot_images(
+            images,
+            batch_idx,
+            cls,
+            bboxes,
+            kpts=kpts,
+            paths=paths,
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )

     def plot_metrics(self):
         """Plots training/val metrics."""
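Finally, a hedged sketch of how PoseTrainer is typically reached; coco8-pose.yaml is the stock Ultralytics example dataset and is used here only for illustration.

from ultralytics import YOLO

# get_model() above passes data_kpt_shape from the dataset YAML, so the pose head
# matches the dataset's keypoint layout (e.g. 17 keypoints x 3 values for COCO).
model = YOLO("yolov8n-pose.pt")
model.train(data="coco8-pose.yaml", epochs=3, imgsz=640)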