ultralytics 8.3.89__py3-none-any.whl → 8.3.91__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/conftest.py +2 -2
- tests/test_cli.py +13 -11
- tests/test_cuda.py +10 -1
- tests/test_exports.py +2 -2
- tests/test_integrations.py +1 -5
- tests/test_python.py +16 -16
- tests/test_solutions.py +9 -9
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +3 -1
- ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
- ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
- ultralytics/cfg/models/11/yolo11.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
- ultralytics/cfg/models/v8/yolov8.yaml +5 -5
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
- ultralytics/data/annotator.py +9 -14
- ultralytics/data/base.py +118 -30
- ultralytics/data/build.py +63 -24
- ultralytics/data/converter.py +5 -5
- ultralytics/data/dataset.py +207 -53
- ultralytics/data/loaders.py +1 -0
- ultralytics/data/split_dota.py +39 -12
- ultralytics/data/utils.py +15 -19
- ultralytics/engine/exporter.py +24 -23
- ultralytics/engine/model.py +67 -88
- ultralytics/engine/predictor.py +106 -21
- ultralytics/engine/trainer.py +32 -23
- ultralytics/engine/tuner.py +21 -18
- ultralytics/engine/validator.py +75 -41
- ultralytics/hub/__init__.py +12 -13
- ultralytics/hub/auth.py +9 -12
- ultralytics/hub/session.py +76 -21
- ultralytics/hub/utils.py +19 -17
- ultralytics/models/fastsam/model.py +20 -11
- ultralytics/models/fastsam/predict.py +36 -16
- ultralytics/models/fastsam/utils.py +5 -5
- ultralytics/models/fastsam/val.py +6 -6
- ultralytics/models/nas/model.py +22 -11
- ultralytics/models/nas/predict.py +9 -4
- ultralytics/models/nas/val.py +5 -5
- ultralytics/models/rtdetr/model.py +20 -11
- ultralytics/models/rtdetr/predict.py +18 -15
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +42 -6
- ultralytics/models/sam/__init__.py +1 -1
- ultralytics/models/sam/amg.py +50 -4
- ultralytics/models/sam/model.py +8 -14
- ultralytics/models/sam/modules/decoders.py +18 -21
- ultralytics/models/sam/modules/encoders.py +25 -46
- ultralytics/models/sam/modules/memory_attention.py +19 -15
- ultralytics/models/sam/modules/sam.py +18 -25
- ultralytics/models/sam/modules/tiny_encoder.py +19 -29
- ultralytics/models/sam/modules/transformer.py +35 -57
- ultralytics/models/sam/modules/utils.py +15 -15
- ultralytics/models/sam/predict.py +0 -3
- ultralytics/models/utils/loss.py +87 -36
- ultralytics/models/utils/ops.py +26 -31
- ultralytics/models/yolo/classify/predict.py +24 -3
- ultralytics/models/yolo/classify/train.py +77 -10
- ultralytics/models/yolo/classify/val.py +40 -15
- ultralytics/models/yolo/detect/predict.py +23 -10
- ultralytics/models/yolo/detect/train.py +85 -15
- ultralytics/models/yolo/detect/val.py +145 -21
- ultralytics/models/yolo/model.py +1 -2
- ultralytics/models/yolo/obb/predict.py +12 -4
- ultralytics/models/yolo/obb/train.py +7 -0
- ultralytics/models/yolo/obb/val.py +25 -7
- ultralytics/models/yolo/pose/predict.py +22 -6
- ultralytics/models/yolo/pose/train.py +17 -1
- ultralytics/models/yolo/pose/val.py +46 -21
- ultralytics/models/yolo/segment/predict.py +22 -8
- ultralytics/models/yolo/segment/train.py +6 -0
- ultralytics/models/yolo/segment/val.py +100 -14
- ultralytics/models/yolo/world/train.py +38 -8
- ultralytics/models/yolo/world/train_world.py +39 -10
- ultralytics/nn/autobackend.py +28 -14
- ultralytics/nn/modules/__init__.py +3 -0
- ultralytics/nn/modules/activation.py +12 -3
- ultralytics/nn/modules/block.py +587 -84
- ultralytics/nn/modules/conv.py +418 -54
- ultralytics/nn/modules/head.py +3 -4
- ultralytics/nn/modules/transformer.py +320 -34
- ultralytics/nn/modules/utils.py +17 -3
- ultralytics/nn/tasks.py +221 -69
- ultralytics/solutions/ai_gym.py +2 -2
- ultralytics/solutions/analytics.py +4 -4
- ultralytics/solutions/heatmap.py +4 -4
- ultralytics/solutions/instance_segmentation.py +10 -4
- ultralytics/solutions/object_blurrer.py +2 -2
- ultralytics/solutions/object_counter.py +2 -2
- ultralytics/solutions/object_cropper.py +2 -2
- ultralytics/solutions/parking_management.py +9 -9
- ultralytics/solutions/queue_management.py +1 -1
- ultralytics/solutions/region_counter.py +2 -2
- ultralytics/solutions/security_alarm.py +7 -7
- ultralytics/solutions/solutions.py +7 -4
- ultralytics/solutions/speed_estimation.py +2 -2
- ultralytics/solutions/streamlit_inference.py +6 -6
- ultralytics/solutions/trackzone.py +9 -2
- ultralytics/solutions/vision_eye.py +4 -4
- ultralytics/trackers/basetrack.py +1 -1
- ultralytics/trackers/bot_sort.py +23 -22
- ultralytics/trackers/byte_tracker.py +4 -4
- ultralytics/trackers/track.py +2 -1
- ultralytics/trackers/utils/gmc.py +26 -27
- ultralytics/trackers/utils/kalman_filter.py +31 -29
- ultralytics/trackers/utils/matching.py +7 -7
- ultralytics/utils/__init__.py +32 -27
- ultralytics/utils/autobatch.py +5 -5
- ultralytics/utils/benchmarks.py +111 -18
- ultralytics/utils/callbacks/base.py +3 -3
- ultralytics/utils/callbacks/clearml.py +11 -11
- ultralytics/utils/callbacks/comet.py +42 -24
- ultralytics/utils/callbacks/dvc.py +11 -10
- ultralytics/utils/callbacks/hub.py +8 -8
- ultralytics/utils/callbacks/mlflow.py +1 -1
- ultralytics/utils/callbacks/neptune.py +12 -10
- ultralytics/utils/callbacks/raytune.py +1 -1
- ultralytics/utils/callbacks/tensorboard.py +6 -6
- ultralytics/utils/callbacks/wb.py +16 -16
- ultralytics/utils/checks.py +116 -35
- ultralytics/utils/dist.py +15 -2
- ultralytics/utils/downloads.py +13 -9
- ultralytics/utils/files.py +12 -13
- ultralytics/utils/instance.py +112 -45
- ultralytics/utils/loss.py +28 -33
- ultralytics/utils/metrics.py +246 -181
- ultralytics/utils/ops.py +61 -53
- ultralytics/utils/patches.py +8 -6
- ultralytics/utils/plotting.py +65 -45
- ultralytics/utils/tal.py +88 -57
- ultralytics/utils/torch_utils.py +181 -33
- ultralytics/utils/triton.py +13 -3
- ultralytics/utils/tuner.py +8 -16
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/METADATA +1 -1
- ultralytics-8.3.91.dist-info/RECORD +250 -0
- ultralytics-8.3.89.dist-info/RECORD +0 -250
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/LICENSE +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/WHEEL +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/entry_points.txt +0 -0
- {ultralytics-8.3.89.dist-info → ultralytics-8.3.91.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/segment/val.py
CHANGED
@@ -18,6 +18,16 @@ class SegmentationValidator(DetectionValidator):
     """
     A class extending the DetectionValidator class for validation based on a segmentation model.

+    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
+    to compute metrics such as mAP for both detection and segmentation tasks.
+
+    Attributes:
+        plot_masks (List): List to store masks for plotting.
+        process (callable): Function to process masks based on save_json and save_txt flags.
+        args (namespace): Arguments for the validator.
+        metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
+        stats (Dict): Dictionary to store statistics during validation.
+
     Examples:
         >>> from ultralytics.models.yolo.segment import SegmentationValidator
         >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
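For orientation, the docstring's example corresponds to a standalone run like the following sketch (the weights and dataset YAML are the docstring's own placeholders and must resolve locally):

    from ultralytics.models.yolo.segment import SegmentationValidator

    # Run validation directly, following the docstring example above
    args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
    validator = SegmentationValidator(args=args)
    validator()  # computes box and mask mAP over the validation split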
@@ -26,7 +36,16 @@ class SegmentationValidator(DetectionValidator):
     """

     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
+        """
+        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
+
+        Args:
+            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            save_dir (Path, optional): Directory to save results.
+            pbar (Any, optional): Progress bar for displaying progress.
+            args (namespace, optional): Arguments for the validator.
+            _callbacks (List, optional): List of callback functions.
+        """
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.plot_masks = None
         self.process = None
@@ -34,13 +53,18 @@ class SegmentationValidator(DetectionValidator):
         self.metrics = SegmentMetrics(save_dir=self.save_dir)

     def preprocess(self, batch):
-        """
+        """Preprocess batch by converting masks to float and sending to device."""
         batch = super().preprocess(batch)
         batch["masks"] = batch["masks"].to(self.device).float()
         return batch

     def init_metrics(self, model):
-        """
+        """
+        Initialize metrics and select mask processing function based on save_json flag.
+
+        Args:
+            model (torch.nn.Module): Model to validate.
+        """
         super().init_metrics(model)
         self.plot_masks = []
         if self.args.save_json:
@@ -66,26 +90,61 @@ class SegmentationValidator(DetectionValidator):
         )

     def postprocess(self, preds):
-        """
+        """
+        Post-process YOLO predictions and return output detections with proto.
+
+        Args:
+            preds (List): Raw predictions from the model.
+
+        Returns:
+            p (torch.Tensor): Processed detection predictions.
+            proto (torch.Tensor): Prototype masks for segmentation.
+        """
         p = super().postprocess(preds[0])
         proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
         return p, proto

     def _prepare_batch(self, si, batch):
-        """
+        """
+        Prepare a batch for training or inference by processing images and targets.
+
+        Args:
+            si (int): Batch index.
+            batch (Dict): Batch data containing images and targets.
+
+        Returns:
+            (Dict): Prepared batch with processed images and targets.
+        """
         prepared_batch = super()._prepare_batch(si, batch)
         midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
         prepared_batch["masks"] = batch["masks"][midx]
         return prepared_batch

     def _prepare_pred(self, pred, pbatch, proto):
-        """
+        """
+        Prepare predictions for evaluation by processing bounding boxes and masks.
+
+        Args:
+            pred (torch.Tensor): Raw predictions from the model.
+            pbatch (Dict): Prepared batch data.
+            proto (torch.Tensor): Prototype masks for segmentation.
+
+        Returns:
+            predn (torch.Tensor): Processed bounding box predictions.
+            pred_masks (torch.Tensor): Processed mask predictions.
+        """
         predn = super()._prepare_pred(pred, pbatch)
         pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"])
         return predn, pred_masks

     def update_metrics(self, preds, batch):
-        """
+        """
+        Update metrics with the current batch predictions and targets.
+
+        Args:
+            preds (List): Predictions from the model.
+            batch (Dict): Batch data containing images and targets.
+        """
         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
             self.seen += 1
             npr = len(pred)
@@ -154,7 +213,7 @@ class SegmentationValidator(DetectionValidator):
         )

     def finalize_metrics(self, *args, **kwargs):
-        """
+        """Set speed and confusion matrix for evaluation metrics."""
         self.metrics.speed = self.speed
         self.metrics.confusion_matrix = self.confusion_matrix

@@ -168,9 +227,9 @@ class SegmentationValidator(DetectionValidator):
             gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
                 Each row is of the format [x1, y1, x2, y2].
             gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
-            pred_masks (torch.Tensor
+            pred_masks (torch.Tensor, optional): Tensor representing predicted masks, if available. The shape should
                 match the ground truth masks.
-            gt_masks (torch.Tensor
+            gt_masks (torch.Tensor, optional): Tensor of shape (M, H, W) representing ground truth masks, if available.
             overlap (bool): Flag indicating if overlapping masks should be considered.
             masks (bool): Flag indicating if the batch contains mask data.

@@ -203,7 +262,13 @@ class SegmentationValidator(DetectionValidator):
         return self.match_predictions(detections[:, 5], gt_cls, iou)

     def plot_val_samples(self, batch, ni):
-        """
+        """
+        Plot validation samples with bounding box labels and masks.
+
+        Args:
+            batch (Dict): Batch data containing images and targets.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             batch["batch_idx"],
@@ -217,7 +282,14 @@ class SegmentationValidator(DetectionValidator):
         )

     def plot_predictions(self, batch, preds, ni):
-        """
+        """
+        Plot batch predictions with masks and bounding boxes.
+
+        Args:
+            batch (Dict): Batch data containing images.
+            preds (List): Predictions from the model.
+            ni (int): Batch index.
+        """
         plot_images(
             batch["img"],
             *output_to_target(preds[0], max_det=15),  # not set to self.args.max_det due to slow plotting speed
@@ -230,7 +302,16 @@ class SegmentationValidator(DetectionValidator):
         self.plot_masks.clear()

     def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
-        """
+        """
+        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
+            pred_masks (torch.Tensor): Predicted masks.
+            save_conf (bool): Whether to save confidence scores.
+            shape (Tuple): Original image shape.
+            file (Path): File path to save the detections.
+        """
         from ultralytics.engine.results import Results

         Results(
@@ -243,7 +324,12 @@ class SegmentationValidator(DetectionValidator):

     def pred_to_json(self, predn, filename, pred_masks):
         """
-        Save one JSON result.
+        Save one JSON result for COCO evaluation.
+
+        Args:
+            predn (torch.Tensor): Predictions in the format [x1, y1, x2, y2, conf, cls].
+            filename (str): Image filename.
+            pred_masks (numpy.ndarray): Predicted masks.

         Examples:
             >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
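COCO-style segmentation results of the kind pred_to_json produces pair each box record with an RLE-encoded mask. A hedged sketch of that encoding step using pycocotools (the helper name and shapes are illustrative, not the module's actual code):

    import numpy as np
    from pycocotools import mask as mask_utils  # requires pycocotools

    def rle_encode(binary_mask):
        """RLE-encode one H x W binary mask for a COCO 'segmentation' field."""
        rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8)))
        rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str so the dict is JSON-serializable
        return rle

    result = {
        "image_id": 42,
        "category_id": 18,
        "bbox": [258.15, 41.29, 348.26, 243.78],
        "score": 0.236,
        "segmentation": rle_encode(np.zeros((640, 640))),
    }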
ultralytics/models/yolo/world/train.py
CHANGED
@@ -10,9 +10,9 @@ from ultralytics.utils.torch_utils import de_parallel


 def on_pretrain_routine_end(trainer):
-    """Callback."""
+    """Callback to set up model classes and text encoder at the end of the pretrain routine."""
     if RANK in {-1, 0}:
-        #
+        # Set class names for evaluation
         names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
         de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
         device = next(trainer.model.parameters()).device
@@ -25,6 +25,16 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
     """
     A class to fine-tune a world model on a close-set dataset.

+    This trainer extends the DetectionTrainer to support training YOLO World models, which combine
+    visual and textual features for improved object detection and understanding.
+
+    Attributes:
+        clip (module): The CLIP module for text-image understanding.
+        text_model (module): The text encoder model from CLIP.
+        model (WorldModel): The YOLO World model being trained.
+        data (Dict): Dataset configuration containing class information.
+        args (Dict): Training arguments and configuration.
+
     Examples:
         >>> from ultralytics.models.yolo.world import WorldModel
         >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
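A typical fine-tuning run matching the docstring example would look like this sketch (weights and dataset names are the docstring's placeholders):

    from ultralytics.models.yolo.world import WorldTrainer

    args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
    trainer = WorldTrainer(overrides=args)
    trainer.train()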
@@ -33,7 +43,14 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
+        """
+        Initialize a WorldTrainer object with given arguments.
+
+        Args:
+            cfg (Dict): Configuration for the trainer.
+            overrides (Dict, optional): Configuration overrides.
+            _callbacks (List, optional): List of callback functions.
+        """
         if overrides is None:
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)
@@ -47,7 +64,17 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
         self.clip = clip

     def get_model(self, cfg=None, weights=None, verbose=True):
-        """
+        """
+        Return WorldModel initialized with specified config and weights.
+
+        Args:
+            cfg (Dict | str, optional): Model configuration.
+            weights (str, optional): Path to pretrained weights.
+            verbose (bool): Whether to display model info.
+
+        Returns:
+            (WorldModel): Initialized WorldModel.
+        """
         # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
         # NOTE: Following the official config, nc hard-coded to 80 for now.
         model = WorldModel(
@@ -64,12 +91,15 @@ class WorldTrainer(yolo.detect.DetectionTrainer):

     def build_dataset(self, img_path, mode="train", batch=None):
         """
-        Build YOLO Dataset.
+        Build YOLO Dataset for training or validation.

         Args:
             img_path (str): Path to the folder containing images.
             mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
-            batch (int, optional): Size of batches, this is for `rect`.
+            batch (int, optional): Size of batches, this is for `rect`.
+
+        Returns:
+            (Dataset): YOLO dataset configured for training or validation.
         """
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
         return build_yolo_dataset(
@@ -77,10 +107,10 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
         )

     def preprocess_batch(self, batch):
-        """
+        """Preprocess a batch of images and text for YOLOWorld training."""
         batch = super().preprocess_batch(batch)

-        #
+        # Add text features
         texts = list(itertools.chain(*batch["texts"]))
         text_token = self.clip.tokenize(texts).to(batch["img"].device)
         txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype)  # torch.float32
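The text branch in preprocess_batch() flattens the per-image text lists, tokenizes them with CLIP, and encodes them into text features. A standalone sketch of the same pipeline (the ViT-B/32 variant and example strings are assumptions; WorldTrainer loads a CLIP fork with the same interface):

    import itertools

    import clip  # OpenAI CLIP package
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    text_model, _ = clip.load("ViT-B/32", device=device)

    batch_texts = [["person", "bicycle"], ["dog"]]  # one list of class texts per image
    texts = list(itertools.chain(*batch_texts))  # flatten, as in preprocess_batch()
    text_token = clip.tokenize(texts).to(device)
    with torch.no_grad():
        txt_feats = text_model.encode_text(text_token)
    txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)  # unit-normalize features
    print(txt_feats.shape)  # torch.Size([3, 512]) for ViT-B/32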
ultralytics/models/yolo/world/train_world.py
CHANGED
@@ -9,7 +9,15 @@ from ultralytics.utils.torch_utils import de_parallel

 class WorldTrainerFromScratch(WorldTrainer):
     """
-    A class extending the WorldTrainer
+    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
+
+    This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
+    supporting training YOLO-World models with combined vision-language capabilities.
+
+    Attributes:
+        cfg (Dict): Configuration dictionary with default parameters for model training.
+        overrides (Dict): Dictionary of parameter overrides to customize the configuration.
+        _callbacks (List): List of callback functions to be executed during different stages of training.

     Examples:
         >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
@@ -35,19 +43,25 @@ class WorldTrainerFromScratch(WorldTrainer):
     """

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """Initialize a
+        """Initialize a WorldTrainerFromScratch object with given configuration and callbacks."""
         if overrides is None:
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)

     def build_dataset(self, img_path, mode="train", batch=None):
         """
-        Build YOLO Dataset.
+        Build YOLO Dataset for training or validation.
+
+        This method constructs appropriate datasets based on the mode and input paths, handling both
+        standard YOLO datasets and grounding datasets with different formats.

         Args:
-            img_path (List[str] | str): Path to the folder containing images.
-            mode (str):
-            batch (int, optional): Size of batches,
+            img_path (List[str] | str): Path to the folder containing images or list of paths.
+            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+            batch (int, optional): Size of batches, used for rectangular training/validation.
+
+        Returns:
+            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
         """
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
         if mode != "train":
@@ -62,9 +76,17 @@ class WorldTrainerFromScratch(WorldTrainer):

     def get_dataset(self):
         """
-        Get train
+        Get train and validation paths from data dictionary.
+
+        Processes the data configuration to extract paths for training and validation datasets,
+        handling both YOLO detection datasets and grounding datasets.

-        Returns
+        Returns:
+            (str): Train dataset path.
+            (str): Validation dataset path.
+
+        Raises:
+            AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
         """
         final_data = {}
         data_yaml = self.args.data
@@ -94,11 +116,18 @@ class WorldTrainerFromScratch(WorldTrainer):
         return final_data["train"], final_data["val"][0]

     def plot_training_labels(self):
-        """
+        """Do not plot labels for YOLO-World training."""
         pass

     def final_eval(self):
-        """
+        """
+        Perform final evaluation and validation for the YOLO-World model.
+
+        Configures the validator with appropriate dataset and split information before running evaluation.
+
+        Returns:
+            (Dict): Dictionary containing evaluation metrics and results.
+        """
         val = self.args.data["val"]["yolo_data"][0]
         self.validator.args.data = val
         self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
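Because this trainer mixes detection and grounding data, its data argument is a nested dict rather than a single YAML file. A hedged sketch of the expected layout, inferred from the data["val"]["yolo_data"][0] access in final_eval() above (all dataset names and file paths are illustrative placeholders):

    from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch

    data = dict(
        train=dict(
            yolo_data=["Objects365.yaml"],  # standard YOLO detection datasets
            grounding_data=[
                dict(
                    img_path="flickr30k/images",
                    json_file="flickr30k/final_flickr_separateGT_train.json",
                ),
            ],
        ),
        val=dict(yolo_data=["lvis.yaml"]),  # validation expects exactly one dataset
    )
    trainer = WorldTrainerFromScratch(overrides=dict(model="yolov8s-worldv2.yaml", data=data, epochs=2))
    trainer.train()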
ultralytics/nn/autobackend.py
CHANGED
@@ -19,11 +19,7 @@ from ultralytics.utils.downloads import attempt_download_asset, is_url


 def check_class_names(names):
-    """
-    Check class names.
-
-    Map imagenet class codes to human-readable names if required. Convert lists to dicts.
-    """
+    """Check class names and convert to dict format if needed."""
     if isinstance(names, list):  # names is a list
         names = dict(enumerate(names))  # convert to dict
     if isinstance(names, dict):
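The list branch is the common case: a model that stores names as a plain list gets them converted to an index-keyed dict. For example:

    names = ["person", "bicycle", "car"]
    names = dict(enumerate(names))  # {0: 'person', 1: 'bicycle', 2: 'car'}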
@@ -78,8 +74,23 @@ class AutoBackend(nn.Module):
         | IMX | *_imx_model/ |
         | RKNN | *_rknn_model/ |

-
-
+    Attributes:
+        model (torch.nn.Module): The loaded YOLO model.
+        device (torch.device): The device (CPU or GPU) on which the model is loaded.
+        task (str): The type of task the model performs (detect, segment, classify, pose).
+        names (Dict): A dictionary of class names that the model can detect.
+        stride (int): The model stride, typically 32 for YOLO models.
+        fp16 (bool): Whether the model uses half-precision (FP16) inference.
+
+    Methods:
+        forward: Run inference on an input image.
+        from_numpy: Convert numpy array to tensor.
+        warmup: Warm up the model with a dummy input.
+        _model_type: Determine the model type from file path.
+
+    Examples:
+        >>> model = AutoBackend(weights="yolov8n.pt", device="cuda")
+        >>> results = model(img)
     """

     @torch.no_grad()
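The docstring example maps to a short end-to-end sketch (CPU device, dummy input, and input size are assumptions):

    import torch
    from ultralytics.nn.autobackend import AutoBackend

    model = AutoBackend(weights="yolov8n.pt", device=torch.device("cpu"))
    model.warmup()  # one dummy forward pass to initialize the backend
    img = torch.zeros(1, 3, 640, 640)  # BCHW, float32 in [0, 1]
    results = model(img)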
@@ -101,7 +112,7 @@ class AutoBackend(nn.Module):
             weights (str | torch.nn.Module): Path to the model weights file or a module instance. Defaults to 'yolo11n.pt'.
             device (torch.device): Device to run the model on. Defaults to CPU.
             dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False.
-            data (str | Path | optional): Path to the additional data.yaml file containing class names.
+            data (str | Path | optional): Path to the additional data.yaml file containing class names.
             fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False.
             batch (int): Batch-size to assume for inference.
             fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True.
@@ -539,12 +550,12 @@ class AutoBackend(nn.Module):

         Args:
             im (torch.Tensor): The image tensor to perform inference on.
-            augment (bool):
-            visualize (bool):
-            embed (
+            augment (bool): Whether to perform data augmentation during inference. Defaults to False.
+            visualize (bool): Whether to visualize the output predictions. Defaults to False.
+            embed (List, optional): A list of feature vectors/embeddings to return.

         Returns:
-            (
+            (torch.Tensor | List[torch.Tensor]): The raw output tensor(s) from the model.
         """
         b, ch, h, w = im.shape  # batch, channel, height, width
         if self.fp16 and im.dtype != torch.float16:
@@ -776,10 +787,13 @@ class AutoBackend(nn.Module):
     def _model_type(p="path/to/model.pt"):
         """
         Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml,
-        saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
+        saved_model, pb, tflite, edgetpu, tfjs, ncnn, mnn, imx or paddle.

         Args:
-            p (str):
+            p (str): Path to the model file. Defaults to path/to/model.pt
+
+        Returns:
+            (List[bool]): List of booleans indicating the model type.

         Examples:
             >>> model = AutoBackend(weights="path/to/model.onnx")
ultralytics/nn/modules/__init__.py
CHANGED
@@ -2,6 +2,9 @@
 """
 Ultralytics modules.

+This module provides access to various neural network components used in Ultralytics models, including convolution blocks,
+attention mechanisms, transformer components, and detection/segmentation heads.
+
 Examples:
     Visualize a module with Netron.
     >>> from ultralytics.nn.modules import *
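The Netron example in this docstring continues by exporting a single module to ONNX for inspection; a sketch of the full flow (module choice and filename are illustrative):

    import torch
    from ultralytics.nn.modules import Conv

    x = torch.ones(1, 128, 40, 40)  # dummy BCHW input
    m = Conv(128, 128)
    f = f"{m._get_name()}.onnx"
    torch.onnx.export(m, x, f)
    # Open the exported file at https://netron.app to visualize the module graph.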
ultralytics/nn/modules/activation.py
CHANGED
@@ -6,10 +6,19 @@ import torch.nn as nn


 class AGLU(nn.Module):
-    """
+    """
+    Unified activation function module from https://github.com/kostas1515/AGLU.
+
+    This class implements a parameterized activation function with learnable parameters lambda and kappa.
+
+    Attributes:
+        act (nn.Softplus): Softplus activation function with negative beta.
+        lambd (nn.Parameter): Learnable lambda parameter initialized with uniform distribution.
+        kappa (nn.Parameter): Learnable kappa parameter initialized with uniform distribution.
+    """

     def __init__(self, device=None, dtype=None) -> None:
-        """Initialize the Unified activation function."""
+        """Initialize the Unified activation function with learnable parameters."""
         super().__init__()
         self.act = nn.Softplus(beta=-1.0)
         self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype)))  # lambda parameter
@@ -17,5 +26,5 @@ class AGLU(nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Compute the forward pass of the Unified activation function."""
-        lam = torch.clamp(self.lambd, min=0.0001)
+        lam = torch.clamp(self.lambd, min=0.0001)  # Clamp lambda to avoid division by zero
         return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam)))
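A quick smoke test of the forward pass shown above (input values are arbitrary; outputs depend on the randomly initialized lambd and kappa):

    import torch
    from ultralytics.nn.modules.activation import AGLU

    act = AGLU()
    x = torch.linspace(-3.0, 3.0, 7)
    y = act(x)  # exp((1/lam) * softplus_{beta=-1}(kappa * x - log(lam)))
    print(y.shape)  # torch.Size([7])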