ultralytics 8.0.238__py3-none-any.whl → 8.0.239__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ultralytics might be problematic; see the registry page for details.
- ultralytics/__init__.py +2 -2
- ultralytics/cfg/__init__.py +241 -138
- ultralytics/data/__init__.py +9 -2
- ultralytics/data/annotator.py +4 -4
- ultralytics/data/augment.py +186 -169
- ultralytics/data/base.py +54 -48
- ultralytics/data/build.py +34 -23
- ultralytics/data/converter.py +242 -70
- ultralytics/data/dataset.py +117 -95
- ultralytics/data/explorer/__init__.py +3 -1
- ultralytics/data/explorer/explorer.py +120 -100
- ultralytics/data/explorer/gui/__init__.py +1 -0
- ultralytics/data/explorer/gui/dash.py +123 -89
- ultralytics/data/explorer/utils.py +37 -39
- ultralytics/data/loaders.py +75 -62
- ultralytics/data/split_dota.py +44 -36
- ultralytics/data/utils.py +160 -142
- ultralytics/engine/exporter.py +348 -292
- ultralytics/engine/model.py +102 -66
- ultralytics/engine/predictor.py +74 -55
- ultralytics/engine/results.py +61 -41
- ultralytics/engine/trainer.py +192 -144
- ultralytics/engine/tuner.py +66 -59
- ultralytics/engine/validator.py +31 -26
- ultralytics/hub/__init__.py +54 -31
- ultralytics/hub/auth.py +28 -25
- ultralytics/hub/session.py +282 -133
- ultralytics/hub/utils.py +64 -42
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +6 -6
- ultralytics/models/fastsam/predict.py +3 -2
- ultralytics/models/fastsam/prompt.py +55 -48
- ultralytics/models/fastsam/val.py +1 -1
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +9 -8
- ultralytics/models/nas/predict.py +8 -6
- ultralytics/models/nas/val.py +11 -9
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +11 -9
- ultralytics/models/rtdetr/train.py +18 -16
- ultralytics/models/rtdetr/val.py +25 -19
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +13 -14
- ultralytics/models/sam/build.py +44 -42
- ultralytics/models/sam/model.py +6 -6
- ultralytics/models/sam/modules/decoders.py +6 -4
- ultralytics/models/sam/modules/encoders.py +37 -35
- ultralytics/models/sam/modules/sam.py +5 -4
- ultralytics/models/sam/modules/tiny_encoder.py +95 -73
- ultralytics/models/sam/modules/transformer.py +3 -2
- ultralytics/models/sam/predict.py +39 -27
- ultralytics/models/utils/loss.py +99 -95
- ultralytics/models/utils/ops.py +34 -31
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +8 -6
- ultralytics/models/yolo/classify/train.py +37 -31
- ultralytics/models/yolo/classify/val.py +26 -24
- ultralytics/models/yolo/detect/__init__.py +1 -1
- ultralytics/models/yolo/detect/predict.py +8 -6
- ultralytics/models/yolo/detect/train.py +47 -37
- ultralytics/models/yolo/detect/val.py +100 -82
- ultralytics/models/yolo/model.py +31 -25
- ultralytics/models/yolo/obb/__init__.py +1 -1
- ultralytics/models/yolo/obb/predict.py +13 -11
- ultralytics/models/yolo/obb/train.py +3 -3
- ultralytics/models/yolo/obb/val.py +70 -59
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +17 -12
- ultralytics/models/yolo/pose/train.py +28 -25
- ultralytics/models/yolo/pose/val.py +91 -64
- ultralytics/models/yolo/segment/__init__.py +1 -1
- ultralytics/models/yolo/segment/predict.py +10 -8
- ultralytics/models/yolo/segment/train.py +16 -15
- ultralytics/models/yolo/segment/val.py +90 -68
- ultralytics/nn/__init__.py +26 -6
- ultralytics/nn/autobackend.py +144 -112
- ultralytics/nn/modules/__init__.py +96 -13
- ultralytics/nn/modules/block.py +28 -7
- ultralytics/nn/modules/conv.py +41 -23
- ultralytics/nn/modules/head.py +60 -52
- ultralytics/nn/modules/transformer.py +49 -32
- ultralytics/nn/modules/utils.py +20 -15
- ultralytics/nn/tasks.py +215 -141
- ultralytics/solutions/ai_gym.py +59 -47
- ultralytics/solutions/distance_calculation.py +17 -14
- ultralytics/solutions/heatmap.py +57 -55
- ultralytics/solutions/object_counter.py +46 -39
- ultralytics/solutions/speed_estimation.py +13 -16
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +1 -0
- ultralytics/trackers/bot_sort.py +2 -1
- ultralytics/trackers/byte_tracker.py +10 -7
- ultralytics/trackers/track.py +7 -7
- ultralytics/trackers/utils/gmc.py +25 -25
- ultralytics/trackers/utils/kalman_filter.py +85 -42
- ultralytics/trackers/utils/matching.py +8 -7
- ultralytics/utils/__init__.py +173 -152
- ultralytics/utils/autobatch.py +10 -10
- ultralytics/utils/benchmarks.py +76 -86
- ultralytics/utils/callbacks/__init__.py +1 -1
- ultralytics/utils/callbacks/base.py +29 -29
- ultralytics/utils/callbacks/clearml.py +51 -43
- ultralytics/utils/callbacks/comet.py +81 -66
- ultralytics/utils/callbacks/dvc.py +33 -26
- ultralytics/utils/callbacks/hub.py +44 -26
- ultralytics/utils/callbacks/mlflow.py +31 -24
- ultralytics/utils/callbacks/neptune.py +35 -25
- ultralytics/utils/callbacks/raytune.py +9 -4
- ultralytics/utils/callbacks/tensorboard.py +16 -11
- ultralytics/utils/callbacks/wb.py +39 -33
- ultralytics/utils/checks.py +189 -141
- ultralytics/utils/dist.py +15 -12
- ultralytics/utils/downloads.py +112 -96
- ultralytics/utils/errors.py +1 -1
- ultralytics/utils/files.py +11 -11
- ultralytics/utils/instance.py +22 -22
- ultralytics/utils/loss.py +117 -67
- ultralytics/utils/metrics.py +224 -158
- ultralytics/utils/ops.py +38 -28
- ultralytics/utils/patches.py +3 -3
- ultralytics/utils/plotting.py +217 -120
- ultralytics/utils/tal.py +19 -13
- ultralytics/utils/torch_utils.py +138 -109
- ultralytics/utils/triton.py +12 -10
- ultralytics/utils/tuner.py +49 -47
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/METADATA +2 -1
- ultralytics-8.0.239.dist-info/RECORD +188 -0
- ultralytics-8.0.238.dist-info/RECORD +0 -188
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.0.238.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/pose/val.py
@@ -31,59 +31,76 @@ class PoseValidator(DetectionValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
-        self.args.task = 'pose'
+        self.args.task = "pose"
         self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
-        if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
-            LOGGER.warning("WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
-                           'See https://github.com/ultralytics/ultralytics/issues/4031.')
+        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
+            LOGGER.warning(
+                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
+                "See https://github.com/ultralytics/ultralytics/issues/4031."
+            )
 
     def preprocess(self, batch):
         """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
         batch = super().preprocess(batch)
-        batch['keypoints'] = batch['keypoints'].to(self.device).float()
+        batch["keypoints"] = batch["keypoints"].to(self.device).float()
         return batch
 
     def get_desc(self):
         """Returns description of evaluation metrics in string format."""
-        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
-                                         'R', 'mAP50', 'mAP50-95)')
+        return ("%22s" + "%11s" * 10) % (
+            "Class",
+            "Images",
+            "Instances",
+            "Box(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+            "Pose(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+        )
 
     def postprocess(self, preds):
         """Apply non-maximum suppression and return detections with high confidence scores."""
-        return ops.non_max_suppression(preds,
-                                       self.args.conf,
-                                       self.args.iou,
-                                       labels=self.lb,
-                                       multi_label=True,
-                                       agnostic=self.args.single_cls,
-                                       max_det=self.args.max_det,
-                                       nc=self.nc)
+        return ops.non_max_suppression(
+            preds,
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+            nc=self.nc,
+        )
 
     def init_metrics(self, model):
         """Initiate pose estimation metrics for YOLO model."""
         super().init_metrics(model)
-        self.kpt_shape = self.data['kpt_shape']
+        self.kpt_shape = self.data["kpt_shape"]
         is_pose = self.kpt_shape == [17, 3]
         nkpt = self.kpt_shape[0]
         self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
         self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])
 
     def _prepare_batch(self, si, batch):
+        """Prepares a batch for processing by converting keypoints to float and moving to device."""
         pbatch = super()._prepare_batch(si, batch)
-        kpts = batch['keypoints'][batch['batch_idx'] == si]
-        h, w = pbatch['imgsz']
+        kpts = batch["keypoints"][batch["batch_idx"] == si]
+        h, w = pbatch["imgsz"]
         kpts = kpts.clone()
         kpts[..., 0] *= w
         kpts[..., 1] *= h
-        kpts = ops.scale_coords(pbatch['imgsz'], kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
-        pbatch['kpts'] = kpts
+        kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
+        pbatch["kpts"] = kpts
         return pbatch
 
     def _prepare_pred(self, pred, pbatch):
+        """Prepares and scales keypoints in a batch for pose processing."""
         predn = super()._prepare_pred(pred, pbatch)
-        nk = pbatch['kpts'].shape[1]
+        nk = pbatch["kpts"].shape[1]
         pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
-        ops.scale_coords(pbatch['imgsz'], pred_kpts, pbatch['ori_shape'], ratio_pad=pbatch['ratio_pad'])
+        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
         return predn, pred_kpts
 
     def update_metrics(self, preds, batch):
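Note: the reformatted get_desc() above changes only layout, not output. A quick standalone sketch of what that format string renders to (plain Python, nothing from ultralytics needed):

    header = ("%22s" + "%11s" * 10) % (
        "Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)",
        "Pose(P", "R", "mAP50", "mAP50-95)",
    )
    print(header)  # one 22-char column plus ten 11-char right-aligned columns, 132 chars wide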
@@ -91,14 +108,16 @@ class PoseValidator(DetectionValidator):
         for si, pred in enumerate(preds):
             self.seen += 1
             npr = len(pred)
-            stat = dict(conf=torch.zeros(0, device=self.device),
-                        pred_cls=torch.zeros(0, device=self.device),
-                        tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                        tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
             pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
             nl = len(cls)
-            stat['target_cls'] = cls
+            stat["target_cls"] = cls
             if npr == 0:
                 if nl:
                     for k in self.stats.keys():
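Note: the stat dict initialized above holds the per-image matching state; self.niou is 10 in the stock validator (IoU thresholds 0.50:0.95). A standalone shape sketch with assumed sizes:

    import torch

    npr, niou = 3, 10  # assumed: 3 predictions, 10 IoU thresholds from 0.50 to 0.95
    stat = dict(
        conf=torch.zeros(0),                            # later filled from predn[:, 4]
        pred_cls=torch.zeros(0),                        # later filled from predn[:, 5]
        tp=torch.zeros(npr, niou, dtype=torch.bool),    # box true positives per threshold
        tp_p=torch.zeros(npr, niou, dtype=torch.bool),  # keypoint (OKS) true positives
    )
    print(stat["tp"].shape)  # torch.Size([3, 10])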
@@ -111,13 +130,13 @@ class PoseValidator(DetectionValidator):
             if self.args.single_cls:
                 pred[:, 5] = 0
             predn, pred_kpts = self._prepare_pred(pred, pbatch)
-            stat['conf'] = predn[:, 4]
-            stat['pred_cls'] = predn[:, 5]
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
 
             # Evaluate
             if nl:
-                stat['tp'] = self._process_batch(predn, bbox, cls)
-                stat['tp_p'] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch['kpts'])
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
             if self.args.plots:
                 self.confusion_matrix.process_batch(predn, bbox, cls)
 
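Note: the tp_p entry computed above depends on the per-keypoint sigmas set in init_metrics earlier in this file: the 17 standard COCO OKS sigmas when kpt_shape == [17, 3], otherwise a uniform fallback. A sketch of that fallback with an assumed 5-keypoint layout:

    import numpy as np

    nkpt = 5  # assumed non-COCO layout with 5 keypoints
    uniform_sigma = np.ones(nkpt) / nkpt  # every keypoint weighted equally in the OKS score
    print(uniform_sigma)  # [0.2 0.2 0.2 0.2 0.2]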
@@ -126,7 +145,7 @@ class PoseValidator(DetectionValidator):
 
             # Save
             if self.args.save_json:
-                self.pred_to_json(predn, batch['im_file'][si])
+                self.pred_to_json(predn, batch["im_file"][si])
             # if self.args.save_txt:
             #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 
@@ -157,26 +176,30 @@ class PoseValidator(DetectionValidator):
 
     def plot_val_samples(self, batch, ni):
         """Plots and saves validation set samples with predicted bounding boxes and keypoints."""
-        plot_images(batch['img'],
-                    batch['batch_idx'],
-                    batch['cls'].squeeze(-1),
-                    batch['bboxes'],
-                    kpts=batch['keypoints'],
-                    paths=batch['im_file'],
-                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
-                    names=self.names,
-                    on_plot=self.on_plot)
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            kpts=batch["keypoints"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
 
     def plot_predictions(self, batch, preds, ni):
         """Plots predictions for YOLO model."""
         pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
-        plot_images(batch['img'],
-                    *output_to_target(preds, max_det=self.args.max_det),
-                    kpts=pred_kpts,
-                    paths=batch['im_file'],
-                    fname=self.save_dir / f'val_batch{ni}_pred.jpg',
-                    names=self.names,
-                    on_plot=self.on_plot)  # pred
+        plot_images(
+            batch["img"],
+            *output_to_target(preds, max_det=self.args.max_det),
+            kpts=pred_kpts,
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
 
     def pred_to_json(self, predn, filename):
         """Converts YOLO predictions to COCO JSON format."""
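Note: plot_predictions reshapes each flat prediction row back into per-keypoint triplets via p[:, 6:].view(-1, *self.kpt_shape). A standalone sketch assuming the COCO shape:

    import torch

    kpt_shape = (17, 3)  # assumed COCO layout: 17 keypoints as (x, y, visibility/conf)
    preds = torch.zeros(5, 6 + 17 * 3)  # 5 detections: xyxy box, conf, cls, then flattened keypoints
    pred_kpts = preds[:, 6:].view(-1, *kpt_shape)
    print(pred_kpts.shape)  # torch.Size([5, 17, 3])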
@@ -185,37 +208,41 @@ class PoseValidator(DetectionValidator):
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
         box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
         for p, b in zip(predn.tolist(), box.tolist()):
-            self.jdict.append({
-                'image_id': image_id,
-                'category_id': self.class_map[int(p[5])],
-                'bbox': [round(x, 3) for x in b],
-                'keypoints': p[6:],
-                'score': round(p[4], 5)})
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "keypoints": p[6:],
+                    "score": round(p[4], 5),
+                }
+            )
 
     def eval_json(self, stats):
         """Evaluates object detection model using COCO JSON format."""
         if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json'  # annotations
-            pred_json = self.save_dir / 'predictions.json'  # predictions
-            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+            anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements('pycocotools>=2.0.6')
+                check_requirements("pycocotools>=2.0.6")
                 from pycocotools.coco import COCO  # noqa
                 from pycocotools.cocoeval import COCOeval  # noqa
 
                 for x in anno_json, pred_json:
-                    assert x.is_file(), f'{x} file not found'
+                    assert x.is_file(), f"{x} file not found"
                 anno = COCO(str(anno_json))  # init annotations api
                 pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'keypoints')]):
+                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]):
                     if self.is_coco:
                         eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                     eval.evaluate()
                     eval.accumulate()
                     eval.summarize()
                     idx = i * 4 + 2
-                    stats[self.metrics.keys[idx + 1]], stats[
-                        self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50
+                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
+                        :2
+                    ]  # update mAP50-95 and mAP50
             except Exception as e:
-                LOGGER.warning(f'pycocotools unable to run: {e}')
+                LOGGER.warning(f"pycocotools unable to run: {e}")
             return stats
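Note: each record appended by pred_to_json above is a standard COCO keypoints result. A hand-built sketch with made-up values:

    record = {
        "image_id": 139,                   # int(Path(filename).stem) for COCO val images
        "category_id": 1,                  # "person" after mapping through self.class_map
        "bbox": [round(v, 3) for v in (73.35, 41.02, 91.7, 144.87)],  # xywh, top-left origin
        "keypoints": [120.5, 80.2, 0.98],  # flattened (x, y, conf) triplets, one per keypoint
        "score": round(0.971234, 5),
    }
    print(record["score"])  # 0.97123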
ultralytics/models/yolo/segment/__init__.py
@@ -4,4 +4,4 @@ from .predict import SegmentationPredictor
 from .train import SegmentationTrainer
 from .val import SegmentationValidator
 
-__all__ = 'SegmentationPredictor', 'SegmentationTrainer', 'SegmentationValidator'
+__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
ultralytics/models/yolo/segment/predict.py
@@ -23,17 +23,19 @@ class SegmentationPredictor(DetectionPredictor):
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks."""
         super().__init__(cfg, overrides, _callbacks)
-        self.args.task = 'segment'
+        self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
         """Applies non-max suppression and processes detections for each image in an input batch."""
-        p = ops.non_max_suppression(preds[0],
-                                    self.args.conf,
-                                    self.args.iou,
-                                    agnostic=self.args.agnostic_nms,
-                                    max_det=self.args.max_det,
-                                    nc=len(self.model.names),
-                                    classes=self.args.classes)
+        p = ops.non_max_suppression(
+            preds[0],
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=len(self.model.names),
+            classes=self.args.classes,
+        )
 
         if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
             orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
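Note: a minimal usage sketch for the predictor above (assumes a local yolov8n-seg.pt checkpoint and a bus.jpg test image; the keyword arguments mirror the NMS parameters in postprocess):

    from ultralytics import YOLO

    model = YOLO("yolov8n-seg.pt")  # segmentation checkpoint, so task resolves to "segment"
    results = model.predict("bus.jpg", conf=0.25, iou=0.7, max_det=300)
    for r in results:
        print(len(r.boxes), r.masks.data.shape if r.masks is not None else None)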
ultralytics/models/yolo/segment/train.py
@@ -26,12 +26,12 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         """Initialize a SegmentationTrainer object with given arguments."""
         if overrides is None:
             overrides = {}
-        overrides['task'] = 'segment'
+        overrides["task"] = "segment"
         super().__init__(cfg, overrides, _callbacks)
 
     def get_model(self, cfg=None, weights=None, verbose=True):
         """Return SegmentationModel initialized with specified config and weights."""
-        model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
+        model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1)
         if weights:
             model.load(weights)
 
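Note: a minimal training sketch for the trainer above (assumes the coco8-seg.yaml toy dataset bundled with the package):

    from ultralytics import YOLO

    model = YOLO("yolov8n-seg.yaml")  # builds a SegmentationModel from the architecture config
    model.train(data="coco8-seg.yaml", epochs=1, imgsz=640)  # overrides["task"] = "segment" is set internally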
@@ -39,22 +39,23 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
 
     def get_validator(self):
         """Return an instance of SegmentationValidator for validation of YOLO model."""
-        self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
-        return yolo.segment.SegmentationValidator(self.test_loader,
-                                                  save_dir=self.save_dir,
-                                                  args=copy(self.args),
-                                                  _callbacks=self.callbacks)
+        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
+        return yolo.segment.SegmentationValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
 
     def plot_training_samples(self, batch, ni):
         """Creates a plot of training sample images with labels and box coordinates."""
-        plot_images(batch['img'],
-                    batch['batch_idx'],
-                    batch['cls'].squeeze(-1),
-                    batch['bboxes'],
-                    masks=batch['masks'],
-                    paths=batch['im_file'],
-                    fname=self.save_dir / f'train_batch{ni}.jpg',
-                    on_plot=self.on_plot)
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            masks=batch["masks"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"train_batch{ni}.jpg",
+            on_plot=self.on_plot,
+        )
 
     def plot_metrics(self):
         """Plots training/val metrics."""
ultralytics/models/yolo/segment/val.py
@@ -33,13 +33,13 @@ class SegmentationValidator(DetectionValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.plot_masks = None
         self.process = None
-        self.args.task = 'segment'
+        self.args.task = "segment"
         self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
 
     def preprocess(self, batch):
         """Preprocesses batch by converting masks to float and sending to device."""
         batch = super().preprocess(batch)
-        batch['masks'] = batch['masks'].to(self.device).float()
+        batch["masks"] = batch["masks"].to(self.device).float()
         return batch
 
     def init_metrics(self, model):
@@ -47,7 +47,7 @@ class SegmentationValidator(DetectionValidator):
         super().init_metrics(model)
         self.plot_masks = []
         if self.args.save_json:
-            check_requirements('pycocotools>=2.0.6')
+            check_requirements("pycocotools>=2.0.6")
             self.process = ops.process_mask_upsample  # more accurate
         else:
             self.process = ops.process_mask  # faster
@@ -55,31 +55,46 @@ class SegmentationValidator(DetectionValidator):
 
     def get_desc(self):
         """Return a formatted description of evaluation metrics."""
-        return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
-                                         'R', 'mAP50', 'mAP50-95)')
+        return ("%22s" + "%11s" * 10) % (
+            "Class",
+            "Images",
+            "Instances",
+            "Box(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+            "Mask(P",
+            "R",
+            "mAP50",
+            "mAP50-95)",
+        )
 
     def postprocess(self, preds):
         """Post-processes YOLO predictions and returns output detections with proto."""
-        p = ops.non_max_suppression(preds[0],
-                                    self.args.conf,
-                                    self.args.iou,
-                                    labels=self.lb,
-                                    multi_label=True,
-                                    agnostic=self.args.single_cls,
-                                    max_det=self.args.max_det,
-                                    nc=self.nc)
+        p = ops.non_max_suppression(
+            preds[0],
+            self.args.conf,
+            self.args.iou,
+            labels=self.lb,
+            multi_label=True,
+            agnostic=self.args.single_cls,
+            max_det=self.args.max_det,
+            nc=self.nc,
+        )
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         return p, proto
 
     def _prepare_batch(self, si, batch):
+        """Prepares a batch for training or inference by processing images and targets."""
         prepared_batch = super()._prepare_batch(si, batch)
-        midx = [si] if self.args.overlap_mask else batch['batch_idx'] == si
-        prepared_batch['masks'] = batch['masks'][midx]
+        midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
+        prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch
 
     def _prepare_pred(self, pred, pbatch, proto):
+        """Prepares a batch for training or inference by processing images and targets."""
         predn = super()._prepare_pred(pred, pbatch)
-        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch['imgsz'])
+        pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
         return predn, pred_masks
 
     def update_metrics(self, preds, batch):
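Note: after NMS each row of p carries the box, score, class, and the mask coefficients that _prepare_pred later combines with proto. A shape sketch under the common assumption of 32 prototype masks:

    import torch

    nm = 32  # assumed number of prototype mask coefficients
    pred = torch.zeros(7, 4 + 2 + nm)  # 7 detections: xyxy, conf, cls, then nm coefficients
    proto = torch.zeros(nm, 160, 160)  # assumed prototype resolution
    mask_logits = pred[:, 6:] @ proto.view(nm, -1)  # linear combination, as ops.process_mask does internally
    print(mask_logits.view(-1, 160, 160).shape)  # torch.Size([7, 160, 160])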
@@ -87,14 +102,16 @@ class SegmentationValidator(DetectionValidator):
         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
             self.seen += 1
             npr = len(pred)
-            stat = dict(conf=torch.zeros(0, device=self.device),
-                        pred_cls=torch.zeros(0, device=self.device),
-                        tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
-                        tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device))
+            stat = dict(
+                conf=torch.zeros(0, device=self.device),
+                pred_cls=torch.zeros(0, device=self.device),
+                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+                tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
+            )
             pbatch = self._prepare_batch(si, batch)
-            cls, bbox = pbatch.pop('cls'), pbatch.pop('bbox')
+            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
             nl = len(cls)
-            stat['target_cls'] = cls
+            stat["target_cls"] = cls
             if npr == 0:
                 if nl:
                     for k in self.stats.keys():
@@ -104,24 +121,20 @@ class SegmentationValidator(DetectionValidator):
                 continue
 
             # Masks
-            gt_masks = pbatch.pop('masks')
+            gt_masks = pbatch.pop("masks")
             # Predictions
             if self.args.single_cls:
                 pred[:, 5] = 0
             predn, pred_masks = self._prepare_pred(pred, pbatch, proto)
-            stat['conf'] = predn[:, 4]
-            stat['pred_cls'] = predn[:, 5]
+            stat["conf"] = predn[:, 4]
+            stat["pred_cls"] = predn[:, 5]
 
             # Evaluate
             if nl:
-                stat['tp'] = self._process_batch(predn, bbox, cls)
-                stat['tp_m'] = self._process_batch(predn,
-                                                   bbox,
-                                                   cls,
-                                                   pred_masks,
-                                                   gt_masks,
-                                                   self.args.overlap_mask,
-                                                   masks=True)
+                stat["tp"] = self._process_batch(predn, bbox, cls)
+                stat["tp_m"] = self._process_batch(
+                    predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
+                )
             if self.args.plots:
                 self.confusion_matrix.process_batch(predn, bbox, cls)
 
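Note: with masks=True, _process_batch matches predictions to ground truth by mask IoU instead of box IoU. A self-contained re-implementation of that flattened-mask IoU for illustration (not the library call itself):

    import torch

    def flat_mask_iou(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """IoU between binary mask sets flattened to (N, H*W) and (M, H*W)."""
        inter = a.float() @ b.float().T                     # (N, M) intersection pixel counts
        union = a.sum(1)[:, None] + b.sum(1)[None] - inter  # inclusion-exclusion
        return inter / union.clamp(min=1e-7)

    gt = torch.zeros(2, 16, dtype=torch.bool)
    gt[0, :8] = True
    gt[1, 8:] = True
    pr = torch.zeros(1, 16, dtype=torch.bool)
    pr[0, 4:12] = True
    print(flat_mask_iou(gt, pr))  # both rows overlap pr by 4 of 12 union pixels -> 0.3333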
@@ -134,10 +147,12 @@ class SegmentationValidator(DetectionValidator):
 
             # Save
             if self.args.save_json:
-                pred_masks = ops.scale_image(pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
-                                             pbatch['ori_shape'],
-                                             ratio_pad=batch['ratio_pad'][si])
-                self.pred_to_json(predn, batch['im_file'][si], pred_masks)
+                pred_masks = ops.scale_image(
+                    pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
+                    pbatch["ori_shape"],
+                    ratio_pad=batch["ratio_pad"][si],
+                )
+                self.pred_to_json(predn, batch["im_file"][si], pred_masks)
             # if self.args.save_txt:
             #    save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
 
@@ -164,7 +179,7 @@ class SegmentationValidator(DetectionValidator):
                 gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
                 gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
                 if gt_masks.shape[1:] != pred_masks.shape[1:]:
-                    gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode='bilinear', align_corners=False)[0]
+                    gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
                     gt_masks = gt_masks.gt_(0.5)
             iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
         else:  # boxes
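Note: the interpolation above upsamples low-resolution ground-truth masks to the prediction resolution before thresholding. A standalone sketch with assumed sizes:

    import torch
    import torch.nn.functional as F

    gt_masks = torch.rand(3, 160, 160)  # assumed low-res ground-truth masks
    gt_masks = F.interpolate(gt_masks[None], (640, 640), mode="bilinear", align_corners=False)[0]
    gt_masks = gt_masks.gt_(0.5)  # in-place threshold; values become 0./1. but dtype stays float32
    print(gt_masks.shape, gt_masks.dtype)  # torch.Size([3, 640, 640]) torch.float32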
@@ -174,26 +189,29 @@ class SegmentationValidator(DetectionValidator):
 
     def plot_val_samples(self, batch, ni):
         """Plots validation samples with bounding box labels."""
-        plot_images(batch['img'],
-                    batch['batch_idx'],
-                    batch['cls'].squeeze(-1),
-                    batch['bboxes'],
-                    masks=batch['masks'],
-                    paths=batch['im_file'],
-                    fname=self.save_dir / f'val_batch{ni}_labels.jpg',
-                    names=self.names,
-                    on_plot=self.on_plot)
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            masks=batch["masks"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
 
     def plot_predictions(self, batch, preds, ni):
         """Plots batch predictions with masks and bounding boxes."""
         plot_images(
-            batch['img'],
+            batch["img"],
             *output_to_target(preds[0], max_det=15),  # not set to self.args.max_det due to slow plotting speed
             torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
-            paths=batch['im_file'],
-            fname=self.save_dir / f'val_batch{ni}_pred.jpg',
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
             names=self.names,
-            on_plot=self.on_plot)  # pred
+            on_plot=self.on_plot,
+        )  # pred
         self.plot_masks.clear()
 
     def pred_to_json(self, predn, filename, pred_masks):
@@ -203,8 +221,8 @@ class SegmentationValidator(DetectionValidator):
 
         def single_encode(x):
             """Encode predicted masks as RLE and append results to jdict."""
-            rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
-            rle['counts'] = rle['counts'].decode('utf-8')
+            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
+            rle["counts"] = rle["counts"].decode("utf-8")
             return rle
 
         stem = Path(filename).stem
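Note: single_encode relies on pycocotools run-length encoding; a round-trip sketch (requires pycocotools):

    import numpy as np
    from pycocotools.mask import decode, encode

    mask = np.zeros((4, 4), dtype=np.uint8)
    mask[1:3, 1:3] = 1  # a 2x2 blob
    rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]
    print(decode(rle).sum())  # 4, the blob survives the round trip
    rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str so the dict is JSON-serializable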
@@ -215,37 +233,41 @@ class SegmentationValidator(DetectionValidator):
         with ThreadPool(NUM_THREADS) as pool:
             rles = pool.map(single_encode, pred_masks)
         for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
-            self.jdict.append({
-                'image_id': image_id,
-                'category_id': self.class_map[int(p[5])],
-                'bbox': [round(x, 3) for x in b],
-                'score': round(p[4], 5),
-                'segmentation': rles[i]})
+            self.jdict.append(
+                {
+                    "image_id": image_id,
+                    "category_id": self.class_map[int(p[5])],
+                    "bbox": [round(x, 3) for x in b],
+                    "score": round(p[4], 5),
+                    "segmentation": rles[i],
+                }
+            )
 
     def eval_json(self, stats):
         """Return COCO-style object detection evaluation metrics."""
         if self.args.save_json and self.is_coco and len(self.jdict):
-            anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
-            pred_json = self.save_dir / 'predictions.json'  # predictions
-            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+            anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
+            pred_json = self.save_dir / "predictions.json"  # predictions
+            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
-                check_requirements('pycocotools>=2.0.6')
+                check_requirements("pycocotools>=2.0.6")
                 from pycocotools.coco import COCO  # noqa
                 from pycocotools.cocoeval import COCOeval  # noqa
 
                 for x in anno_json, pred_json:
-                    assert x.is_file(), f'{x} file not found'
+                    assert x.is_file(), f"{x} file not found"
                 anno = COCO(str(anno_json))  # init annotations api
                 pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
-                for i, eval in enumerate([COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm')]):
+                for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]):
                     if self.is_coco:
                         eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # im to eval
                     eval.evaluate()
                     eval.accumulate()
                     eval.summarize()
                     idx = i * 4 + 2
-                    stats[self.metrics.keys[idx + 1]], stats[
-                        self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50
+                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[
+                        :2
+                    ]  # update mAP50-95 and mAP50
             except Exception as e:
-                LOGGER.warning(f'pycocotools unable to run: {e}')
+                LOGGER.warning(f"pycocotools unable to run: {e}")
             return stats
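Note: eval_json above follows the standard pycocotools demo flow; a condensed standalone sketch (both file paths are assumptions):

    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    anno = COCO("annotations/instances_val2017.json")         # ground-truth annotations
    pred = anno.loadRes("runs/segment/val/predictions.json")  # must be a str path, not a Path
    for iou_type in ("bbox", "segm"):
        ev = COCOeval(anno, pred, iou_type)
        ev.evaluate()
        ev.accumulate()
        ev.summarize()
        print(iou_type, "mAP50-95:", ev.stats[0], "mAP50:", ev.stats[1])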