ultralytics 8.3.88__py3-none-any.whl → 8.3.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +125 -39
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +34 -33
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +33 -47
- ultralytics/engine/exporter.py +19 -17
- ultralytics/engine/model.py +69 -90
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +31 -38
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +21 -26
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +23 -17
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +29 -24
- ultralytics/models/nas/predict.py +14 -11
- ultralytics/models/nas/val.py +11 -13
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +21 -21
- ultralytics/models/rtdetr/train.py +25 -24
- ultralytics/models/rtdetr/val.py +47 -14
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +30 -12
- ultralytics/models/yolo/classify/train.py +83 -19
- ultralytics/models/yolo/classify/val.py +45 -23
- ultralytics/models/yolo/detect/predict.py +29 -19
- ultralytics/models/yolo/detect/train.py +90 -23
- ultralytics/models/yolo/detect/val.py +150 -29
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +18 -13
- ultralytics/models/yolo/obb/train.py +12 -8
- ultralytics/models/yolo/obb/val.py +35 -22
- ultralytics/models/yolo/pose/predict.py +28 -15
- ultralytics/models/yolo/pose/train.py +21 -8
- ultralytics/models/yolo/pose/val.py +51 -31
- ultralytics/models/yolo/segment/predict.py +27 -16
- ultralytics/models/yolo/segment/train.py +11 -8
- ultralytics/models/yolo/segment/val.py +110 -29
- ultralytics/models/yolo/world/train.py +43 -16
- ultralytics/models/yolo/world/train_world.py +61 -36
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +12 -12
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +226 -79
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +37 -35
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +35 -22
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +139 -68
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +37 -56
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +117 -52
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +65 -61
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +72 -59
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +202 -64
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +13 -25
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/METADATA +2 -2
- ultralytics-8.3.90.dist-info/RECORD +250 -0
- ultralytics-8.3.88.dist-info/RECORD +0 -250
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.88.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/pose/val.py

@@ -16,18 +16,38 @@ class PoseValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on a pose model.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.pose import PoseValidator
-
-        args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
-        validator = PoseValidator(args=args)
-        validator()
-        ```
+    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
+    specialized metrics for pose evaluation.
+
+    Attributes:
+        sigma (np.ndarray): Sigma values for OKS calculation, either from OKS_SIGMA or ones divided by number of keypoints.
+        kpt_shape (List[int]): Shape of the keypoints, typically [17, 3] for COCO format.
+        args (Dict): Arguments for the validator including task set to "pose".
+        metrics (PoseMetrics): Metrics object for pose evaluation.
+
+    Methods:
+        preprocess: Preprocesses batch data for pose validation.
+        get_desc: Returns description of evaluation metrics.
+        init_metrics: Initializes pose metrics for the model.
+        _prepare_batch: Prepares a batch for processing.
+        _prepare_pred: Prepares and scales predictions for evaluation.
+        update_metrics: Updates metrics with new predictions.
+        _process_batch: Processes batch to compute IoU between detections and ground truth.
+        plot_val_samples: Plots validation samples with ground truth annotations.
+        plot_predictions: Plots model predictions.
+        save_one_txt: Saves detections to a text file.
+        pred_to_json: Converts predictions to COCO JSON format.
+        eval_json: Evaluates model using COCO JSON format.
+
+    Examples:
+        >>> from ultralytics.models.yolo.pose import PoseValidator
+        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
+        >>> validator = PoseValidator(args=args)
+        >>> validator()
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """Initialize a
+        """Initialize a PoseValidator object with custom parameters and assigned attributes."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.sigma = None
         self.kpt_shape = None
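This hunk swaps the old fenced Example block for the standardized doctest-style Examples. The validator is more commonly reached through the public YOLO API; a minimal sketch, assuming the yolo11n-pose.pt weights and coco8-pose.yaml dataset can be fetched:

```python
from ultralytics import YOLO

model = YOLO("yolo11n-pose.pt")  # val() on a pose model builds a PoseValidator internally
metrics = model.val(data="coco8-pose.yaml")
print(metrics.box.map50, metrics.pose.map50)  # detection and pose mAP@0.5
```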
@@ -40,13 +60,13 @@ class PoseValidator(DetectionValidator):
         )
 
     def preprocess(self, batch):
-        """
+        """Preprocess batch by converting keypoints data to float and moving it to the device."""
         batch = super().preprocess(batch)
         batch["keypoints"] = batch["keypoints"].to(self.device).float()
         return batch
 
     def get_desc(self):
-        """
+        """Return description of evaluation metrics in string format."""
         return ("%22s" + "%11s" * 10) % (
             "Class",
             "Images",
@@ -62,7 +82,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def init_metrics(self, model):
-        """
+        """Initialize pose estimation metrics for YOLO model."""
         super().init_metrics(model)
         self.kpt_shape = self.data["kpt_shape"]
         is_pose = self.kpt_shape == [17, 3]
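The new class docstring states that sigma comes "either from OKS_SIGMA or ones divided by number of keypoints", and the `is_pose` flag above is what gates that choice. A sketch of the selection (the assigning line itself is outside this hunk):

```python
import numpy as np
from ultralytics.utils.metrics import OKS_SIGMA  # the 17 per-keypoint COCO sigmas

kpt_shape = [17, 3]  # from the dataset config, e.g. coco8-pose.yaml
is_pose = kpt_shape == [17, 3]
nkpt = kpt_shape[0]
sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt  # uniform fallback for custom layouts
```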
@@ -71,7 +91,7 @@ class PoseValidator(DetectionValidator):
         self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
 
     def _prepare_batch(self, si, batch):
-        """
+        """Prepare a batch for processing by converting keypoints to float and scaling to original dimensions."""
         pbatch = super()._prepare_batch(si, batch)
         kpts = batch["keypoints"][batch["batch_idx"] == si]
         h, w = pbatch["imgsz"]
@@ -83,7 +103,7 @@ class PoseValidator(DetectionValidator):
         return pbatch
 
     def _prepare_pred(self, pred, pbatch):
-        """
+        """Prepare and scale keypoints in predictions for pose processing."""
         predn = super()._prepare_pred(pred, pbatch)
         nk = pbatch["kpts"].shape[1]
         pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
@@ -91,7 +111,16 @@ class PoseValidator(DetectionValidator):
         return predn, pred_kpts
 
     def update_metrics(self, preds, batch):
-        """
+        """
+        Update metrics with new predictions and ground truth data.
+
+        This method processes each prediction, compares it with ground truth, and updates various statistics
+        for performance evaluation.
+
+        Args:
+            preds (List[torch.Tensor]): List of prediction tensors from the model.
+            batch (Dict): Batch data containing images and ground truth annotations.
+        """
         for si, pred in enumerate(preds):
             self.seen += 1
             npr = len(pred)
@@ -161,18 +190,9 @@ class PoseValidator(DetectionValidator):
             (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
                 where N is the number of detections.
 
-        Example:
-            ```python
-            detections = torch.rand(100, 6)  # 100 predictions: (x1, y1, x2, y2, conf, class)
-            gt_bboxes = torch.rand(50, 4)  # 50 ground truth boxes: (x1, y1, x2, y2)
-            gt_cls = torch.randint(0, 2, (50,))  # 50 ground truth class indices
-            pred_kpts = torch.rand(100, 51)  # 100 predicted keypoints
-            gt_kpts = torch.rand(50, 51)  # 50 ground truth keypoints
-            correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
-            ```
-
-        Note:
-            `0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
+        Notes:
+            `0.53` scale factor used in area computation is referenced from
+            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
         """
         if pred_kpts is not None and gt_kpts is not None:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
@@ -184,7 +204,7 @@ class PoseValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def plot_val_samples(self, batch, ni):
-        """
+        """Plot and save validation set samples with ground truth bounding boxes and keypoints."""
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -198,7 +218,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """Plot and save model predictions with bounding boxes and keypoints."""
         pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
         plot_images(
             batch["img"],
@@ -223,7 +243,7 @@ class PoseValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn, filename):
-        """
+        """Convert YOLO predictions to COCO JSON format."""
         stem = Path(filename).stem
         image_id = int(stem) if stem.isnumeric() else stem
         box = ops.xyxy2xywh(predn[:, :4])  # xywh
@@ -240,7 +260,7 @@ class PoseValidator(DetectionValidator):
         )
 
     def eval_json(self, stats):
-        """
+        """Evaluate object detection model using COCO JSON format."""
         if self.args.save_json and self.is_coco and len(self.jdict):
             anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
             pred_json = self.save_dir / "predictions.json"  # predictions
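`pred_to_json` emits one record per detection in the COCO keypoint-results format, which `eval_json` then scores against person_keypoints_val2017.json via pycocotools. An illustrative record (all values invented):

```python
record = {
    "image_id": 42,           # numeric image filename stem
    "category_id": 1,         # "person" in COCO
    "bbox": [258.15, 41.29, 348.26, 243.78],  # xywh, top-left origin
    "keypoints": [130.5, 206.7, 0.9] * 17,    # 17 flattened (x, y, score) triplets
    "score": 0.897,
}
```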
ultralytics/models/yolo/segment/predict.py

@@ -9,31 +9,41 @@ class SegmentationPredictor(DetectionPredictor):
     """
     A class extending the DetectionPredictor class for prediction based on a segmentation model.
 
-    Example:
-        ```python
-        from ultralytics.utils import ASSETS
-        from ultralytics.models.yolo.segment import SegmentationPredictor
+    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
+    prediction results.
 
-        args = dict(model="yolo11n-seg.pt", source=ASSETS)
-        predictor = SegmentationPredictor(overrides=args)
-        predictor.predict_cli()
-        ```
+    Attributes:
+        args (Dict): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded YOLO segmentation model.
+        batch (List): Current batch of images being processed.
+
+    Methods:
+        postprocess: Applies non-max suppression and processes detections.
+        construct_results: Constructs a list of result objects from predictions.
+        construct_result: Constructs a single result object from a prediction.
+
+    Examples:
+        >>> from ultralytics.utils import ASSETS
+        >>> from ultralytics.models.yolo.segment import SegmentationPredictor
+        >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
+        >>> predictor = SegmentationPredictor(overrides=args)
+        >>> predictor.predict_cli()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """Initialize the SegmentationPredictor with configuration, overrides, and callbacks."""
         super().__init__(cfg, overrides, _callbacks)
         self.args.task = "segment"
 
     def postprocess(self, preds, img, orig_imgs):
-        """
-        # tuple if PyTorch model or array if exported
+        """Apply non-max suppression and process detections for each image in the input batch."""
+        # Extract protos - tuple if PyTorch model or array if exported
         protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
         return super().postprocess(preds[0], img, orig_imgs, protos=protos)
 
     def construct_results(self, preds, img, orig_imgs, protos):
         """
-
+        Construct a list of result objects from the predictions.
 
         Args:
             preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
@@ -42,7 +52,8 @@ class SegmentationPredictor(DetectionPredictor):
             protos (List[torch.Tensor]): List of prototype masks.
 
         Returns:
-            (
+            (List[Results]): List of result objects containing the original images, image paths, class names,
+                bounding boxes, and masks.
         """
         return [
             self.construct_result(pred, img, orig_img, img_path, proto)
@@ -51,7 +62,7 @@ class SegmentationPredictor(DetectionPredictor):
 
     def construct_result(self, pred, img, orig_img, img_path, proto):
         """
-
+        Construct a single result object from the prediction.
 
         Args:
             pred (np.ndarray): The predicted bounding boxes, scores, and masks.
@@ -61,7 +72,7 @@ class SegmentationPredictor(DetectionPredictor):
             proto (torch.Tensor): The prototype masks.
 
         Returns:
-            (Results):
+            (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
         """
         if not len(pred):  # save empty boxes
             masks = None
@@ -72,6 +83,6 @@ class SegmentationPredictor(DetectionPredictor):
             masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
         if masks is not None:
-            keep = masks.sum((-2, -1)) > 0  # only keep
+            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
            pred, masks = pred[keep], masks[keep]
         return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
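The clarified comment in the last hunk refers to dropping detections whose processed masks contain no positive pixels; the boxes are filtered together with the masks. A self-contained sketch of that filter:

```python
import torch

pred = torch.rand(5, 38)           # 5 detections: box(4) + conf + cls + 32 mask coefficients
masks = torch.zeros(5, 160, 160)   # binary masks, e.g. from ops.process_mask
masks[0, 40:80, 40:80] = 1         # only the first detection has any mask pixels

keep = masks.sum((-2, -1)) > 0     # True where a mask has at least one positive pixel
pred, masks = pred[keep], masks[keep]
print(pred.shape, masks.shape)     # torch.Size([1, 38]) torch.Size([1, 160, 160])
```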
ultralytics/models/yolo/segment/train.py

@@ -12,14 +12,17 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """
     A class extending the DetectionTrainer class for training based on a segmentation model.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.segment import SegmentationTrainer
-
-        args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
-        trainer = SegmentationTrainer(overrides=args)
-        trainer.train()
-        ```
+    This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
+    functionality including model initialization, validation, and visualization.
+
+    Attributes:
+        loss_names (Tuple[str]): Names of the loss components used during training.
+
+    Examples:
+        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
+        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
+        >>> trainer = SegmentationTrainer(overrides=args)
+        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
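As with the other components, the trainer is usually driven through the model-level API rather than instantiated directly; for the segment task, train() builds a SegmentationTrainer under the hood. A minimal sketch, assuming the referenced weights and dataset can be downloaded:

```python
from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")  # task is inferred from the -seg weights
results = model.train(data="coco8-seg.yaml", epochs=3, imgsz=640)
```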
ultralytics/models/yolo/segment/val.py

@@ -18,18 +18,34 @@ class SegmentationValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on a segmentation model.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.segment import SegmentationValidator
-
-        args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
-        validator = SegmentationValidator(args=args)
-        validator()
-        ```
+    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
+    to compute metrics such as mAP for both detection and segmentation tasks.
+
+    Attributes:
+        plot_masks (List): List to store masks for plotting.
+        process (callable): Function to process masks based on save_json and save_txt flags.
+        args (namespace): Arguments for the validator.
+        metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
+        stats (Dict): Dictionary to store statistics during validation.
+
+    Examples:
+        >>> from ultralytics.models.yolo.segment import SegmentationValidator
+        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
+        >>> validator = SegmentationValidator(args=args)
+        >>> validator()
     """
 
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
+        """
+        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (namespace, optional): Arguments for the validator.
+            _callbacks (List, optional): List of callback functions.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.plot_masks = None
         self.process = None
@@ -37,13 +53,18 @@ class SegmentationValidator(DetectionValidator):
         self.metrics = SegmentMetrics(save_dir=self.save_dir)
 
     def preprocess(self, batch):
-        """
+        """Preprocess batch by converting masks to float and sending to device."""
         batch = super().preprocess(batch)
         batch["masks"] = batch["masks"].to(self.device).float()
         return batch
 
     def init_metrics(self, model):
-        """
+        """
+        Initialize metrics and select mask processing function based on save_json flag.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         self.plot_masks = []
         if self.args.save_json:
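Per the new attribute description ("process (callable): Function to process masks based on save_json and save_txt flags"), the branch beginning at `if self.args.save_json:` selects between the two mask decoders in `ultralytics.utils.ops`; the branch body itself is outside this hunk, so the sketch below is an assumption about its shape:

```python
import torch
from ultralytics.utils import ops

save_json, save_txt = True, False   # validator flags (illustrative)
proto = torch.rand(32, 160, 160)    # prototype masks from the model head
pred = torch.rand(8, 38)            # 8 detections: box(4) + conf + cls + 32 coefficients

# process_mask_native upsamples prototypes to full resolution before cropping:
# typically slower, but gives pixel-faithful masks for JSON/txt export.
process = ops.process_mask_native if (save_json or save_txt) else ops.process_mask
masks = process(proto, pred[:, 6:], pred[:, :4], (640, 640))
```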
@@ -69,26 +90,61 @@ class SegmentationValidator(DetectionValidator):
         )
 
     def postprocess(self, preds):
-        """
+        """
+        Post-process YOLO predictions and return output detections with proto.
+
+        Args:
+            preds (List): Raw predictions from the model.
+
+        Returns:
+            p (torch.Tensor): Processed detection predictions.
+            proto (torch.Tensor): Prototype masks for segmentation.
+        """
         p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         return p, proto
 
     def _prepare_batch(self, si, batch):
-        """
+        """
+        Prepare a batch for training or inference by processing images and targets.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and targets.
+
+        Returns:
+            (Dict): Prepared batch with processed images and targets.
+        """
         prepared_batch = super()._prepare_batch(si, batch)
         midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
         prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch
 
     def _prepare_pred(self, pred, pbatch, proto):
-        """
+        """
+        Prepare predictions for evaluation by processing bounding boxes and masks.
+
+        Args:
+            pred (torch.Tensor): Raw predictions from the model.
+            pbatch (Dict): Prepared batch data.
+            proto (torch.Tensor): Prototype masks for segmentation.
+
+        Returns:
+            predn (torch.Tensor): Processed bounding box predictions.
+            pred_masks (torch.Tensor): Processed mask predictions.
+        """
         predn = super()._prepare_pred(pred, pbatch)
         pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
         return predn, pred_masks
 
     def update_metrics(self, preds, batch):
-        """
+        """
+        Update metrics with the current batch predictions and targets.
+
+        Args:
+            preds (List): Predictions from the model.
+            batch (Dict): Batch data containing images and targets.
+        """
         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
             self.seen += 1
             npr = len(pred)
@@ -157,7 +213,7 @@ class SegmentationValidator(DetectionValidator):
         )
 
     def finalize_metrics(self, *args, **kwargs):
-        """
+        """Set speed and confusion matrix for evaluation metrics."""
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix
 
@@ -171,9 +227,9 @@ class SegmentationValidator(DetectionValidator):
             gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
                 Each row is of the format [x1, y1, x2, y2].
             gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
-            pred_masks (torch.Tensor
+            pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
                 match the ground truth masks.
-            gt_masks (torch.Tensor
+            gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
             overlap (bool): Flag indicating if overlapping masks should be considered.
             masks (bool): Flag indicating if the batch contains mask data.
 
@@ -184,13 +240,11 @@ class SegmentationValidator(DetectionValidator):
             - If `masks` is True, the function computes IoU between predicted and ground truth masks.
             - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
 
-        Example:
-            ```python
-            detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
-            gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
-            gt_cls = torch.tensor([1, 0])
-            correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
-            ```
+        Examples:
+            >>> detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
+            >>> gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
+            >>> gt_cls = torch.tensor([1, 0])
+            >>> correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
         """
         if masks:
             if overlap:
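When `overlap=True`, the ground-truth masks of an image arrive as one index-encoded map in which pixel value i marks instance i, so they must be expanded into per-instance binary masks before mask IoU can be computed, as the notes above describe. A self-contained sketch of that expansion:

```python
import torch
from ultralytics.utils.metrics import mask_iou

nl = 3                                                     # ground-truth instances in the image
gt_masks = torch.randint(0, nl + 1, (1, 64, 64)).float()   # 0 = background, i = instance i
pred_masks = (torch.rand(5, 64, 64) > 0.5).float()         # 5 binary predicted masks

index = torch.arange(nl).view(nl, 1, 1) + 1                # instance ids 1..nl
gt_binary = (gt_masks.repeat(nl, 1, 1) == index).float()   # (nl, 64, 64) binary masks

iou = mask_iou(gt_binary.flatten(1), pred_masks.flatten(1))  # (nl, 5) IoU matrix
```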
@@ -208,7 +262,13 @@ class SegmentationValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)
 
     def plot_val_samples(self, batch, ni):
-        """
+        """
+        Plot validation samples with bounding box labels and masks.
+
+        Args:
+            batch (Dict): Batch data containing images and targets.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -222,7 +282,14 @@ class SegmentationValidator(DetectionValidator):
         )
 
     def plot_predictions(self, batch, preds, ni):
-        """
+        """
+        Plot batch predictions with masks and bounding boxes.
+
+        Args:
+            batch (Dict): Batch data containing images.
+            preds (List): Predictions from the model.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             *output_to_target(preds[0], max_det=15),  # not set to self.args.max_det due to slow plotting speed
@@ -235,7 +302,16 @@ class SegmentationValidator(DetectionValidator):
         self.plot_masks.clear()
 
     def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
-        """
+        """
+        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
+            pred_masks (torch.Tensor): Predicted masks.
+            save_conf (bool): Whether to save confidence scores.
+            shape (Tuple): Original image shape.
+            file (Path): File path to save the detections.
+        """
         from ultralytics.engine.results import Results
 
         Results(
@@ -248,7 +324,12 @@ class SegmentationValidator(DetectionValidator):
 
     def pred_to_json(self, predn, filename, pred_masks):
         """
-        Save one JSON result.
+        Save one JSON result for COCO evaluation.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
+            filename (str): Image filename.
+            pred_masks (numpy.ndarray): Predicted masks.
 
         Examples:
             >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
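Alongside the fields shown in the Examples record, segmentation results attach one RLE-encoded mask per detection, serialized with pycocotools. A hedged sketch of encoding a single binary mask (assumes pycocotools is installed):

```python
import numpy as np
from pycocotools.mask import encode  # COCO run-length encoding codec

mask = np.zeros((160, 160), dtype=np.uint8)
mask[40:80, 40:80] = 1  # one binary instance mask

rle = encode(np.asarray(mask[:, :, None], order="F"))[0]  # expects Fortran-ordered (H, W, N)
rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str so the record is JSON-serializable
```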
ultralytics/models/yolo/world/train.py

@@ -10,9 +10,9 @@ from ultralytics.utils.torch_utils import de_parallel
 
 
 def on_pretrain_routine_end(trainer):
-    """Callback."""
+    """Callback to set up model classes and text encoder at the end of the pretrain routine."""
     if RANK in {-1, 0}:
-        #
+        # Set class names for evaluation
         names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
         de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
         device = next(trainer.model.parameters()).device
@@ -25,18 +25,32 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
     """
     A class to fine-tune a world model on a close-set dataset.
 
-    Example:
-        ```python
-        from ultralytics.models.yolo.world import WorldModel
-
-        args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
-        trainer = WorldTrainer(overrides=args)
-        trainer.train()
-        ```
+    This trainer extends the DetectionTrainer to support training YOLO World models, which combine
+    visual and textual features for improved object detection and understanding.
+
+    Attributes:
+        clip (module): The CLIP module for text-image understanding.
+        text_model (module): The text encoder model from CLIP.
+        model (WorldModel): The YOLO World model being trained.
+        data (Dict): Dataset configuration containing class information.
+        args (Dict): Training arguments and configuration.
+
+    Examples:
+        >>> from ultralytics.models.yolo.world import WorldModel
+        >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
+        >>> trainer = WorldTrainer(overrides=args)
+        >>> trainer.train()
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize a WorldTrainer object with given arguments.
+
+        Args:
+            cfg (Dict): Configuration for the trainer.
+            overrides (Dict, optional): Configuration overrides.
+            _callbacks (List, optional): List of callback functions.
+        """
         if overrides is None:
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)
@@ -50,7 +64,17 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
         self.clip = clip
 
     def get_model(self, cfg=None, weights=None, verbose=True):
-        """
+        """
+        Return WorldModel initialized with specified config and weights.
+
+        Args:
+            cfg (Dict | str, optional): Model configuration.
+            weights (str, optional): Path to pretrained weights.
+            verbose (bool): Whether to display model info.
+
+        Returns:
+            (WorldModel): Initialized WorldModel.
+        """
         # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
         # NOTE: Following the official config, nc hard-coded to 80 for now.
         model = WorldModel(
@@ -67,12 +91,15 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
 
     def build_dataset(self, img_path, mode="train", batch=None):
         """
-        Build YOLO Dataset.
+        Build YOLO Dataset for training or validation.
 
         Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (Dataset): YOLO dataset configured for training or validation.
         """
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
         return build_yolo_dataset(
@@ -80,10 +107,10 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
         )
 
     def preprocess_batch(self, batch):
-        """
+        """Preprocess a batch of images and text for YOLOWorld training."""
         batch = super().preprocess_batch(batch)
 
-        #
+        # Add text features
         texts = list(itertools.chain(*batch["texts"]))
         text_token = self.clip.tokenize(texts).to(batch["img"].device)
         txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32