dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that public registry.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/data/dataset.py
CHANGED
@@ -1,10 +1,13 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

+from __future__ import annotations
+
 import json
 from collections import defaultdict
 from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
+from typing import Any

 import cv2
 import numpy as np
@@ -44,8 +47,7 @@ DATASET_CACHE_VERSION = "1.0.3"


 class YOLODataset(BaseDataset):
-    """
-    Dataset class for loading object detection and/or segmentation labels in YOLO format.
+    """Dataset class for loading object detection and/or segmentation labels in YOLO format.

     This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
     (OBB) tasks using the YOLO format.
@@ -58,20 +60,19 @@ class YOLODataset(BaseDataset):

     Methods:
         cache_labels: Cache dataset labels, check images and read shapes.
-        get_labels:
-        build_transforms:
-        close_mosaic:
-        update_labels_info:
-        collate_fn:
+        get_labels: Return dictionary of labels for YOLO training.
+        build_transforms: Build and append transforms to the list.
+        close_mosaic: Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
+        update_labels_info: Update label format for different tasks.
+        collate_fn: Collate data samples into batches.

     Examples:
         >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
         >>> dataset.get_labels()
     """

-    def __init__(self, *args, data=None, task="detect", **kwargs):
-        """
-        Initialize the YOLODataset.
+    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
+        """Initialize the YOLODataset.

         Args:
             data (dict, optional): Dataset configuration dictionary.
@@ -84,11 +85,10 @@ class YOLODataset(BaseDataset):
         self.use_obb = task == "obb"
         self.data = data
         assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
-        super().__init__(*args, channels=self.data
+        super().__init__(*args, channels=self.data.get("channels", 3), **kwargs)

-    def cache_labels(self, path=Path("./labels.cache")):
-        """
-        Cache dataset labels, check images and read shapes.
+    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
+        """Cache dataset labels, check images and read shapes.

         Args:
             path (Path): Path where to save the cache file.
@@ -154,14 +154,13 @@ class YOLODataset(BaseDataset):
         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return x

-    def get_labels(self):
-        """
-        Returns dictionary of labels for YOLO training.
+    def get_labels(self) -> list[dict]:
+        """Return dictionary of labels for YOLO training.

         This method loads labels from disk or cache, verifies their integrity, and prepares them for training.

         Returns:
-            (
+            (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
         """
         self.label_files = img2label_paths(self.im_files)
         cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
@@ -169,7 +168,7 @@ class YOLODataset(BaseDataset):
             cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
             assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
             assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
-        except (FileNotFoundError, AssertionError, AttributeError):
+        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
             cache, exists = self.cache_labels(cache_path), False  # run cache ops

         # Display cache
@@ -204,9 +203,8 @@ class YOLODataset(BaseDataset):
             LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
         return labels

-    def build_transforms(self, hyp=None):
-        """
-        Builds and appends transforms to the list.
+    def build_transforms(self, hyp: dict | None = None) -> Compose:
+        """Build and append transforms to the list.

         Args:
             hyp (dict, optional): Hyperparameters for transforms.
@@ -236,9 +234,8 @@ class YOLODataset(BaseDataset):
         )
         return transforms

-    def close_mosaic(self, hyp):
-        """
-        Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.
+    def close_mosaic(self, hyp: dict) -> None:
+        """Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.

         Args:
             hyp (dict): Hyperparameters for transforms.
@@ -249,9 +246,8 @@ class YOLODataset(BaseDataset):
         hyp.cutmix = 0.0
         self.transforms = self.build_transforms(hyp)

-    def update_labels_info(self, label):
-        """
-        Custom your label format here.
+    def update_labels_info(self, label: dict) -> dict:
+        """Update label format for different tasks.

         Args:
             label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@@ -259,7 +255,7 @@ class YOLODataset(BaseDataset):
         Returns:
             (dict): Updated label dictionary with instances.

-
+        Notes:
             cls is not with bboxes now, classification and semantic segmentation need an independent cls label
             Can also support classification and semantic segmentation by adding or removing dict keys there.
         """
@@ -283,12 +279,11 @@ class YOLODataset(BaseDataset):
         return label

     @staticmethod
-    def collate_fn(batch):
-        """
-        Collates data samples into batches.
+    def collate_fn(batch: list[dict]) -> dict:
+        """Collate data samples into batches.

         Args:
-            batch (
+            batch (list[dict]): List of dictionaries containing sample data.

         Returns:
             (dict): Collated batch with stacked tensors.
@@ -314,15 +309,14 @@ class YOLODataset(BaseDataset):


 class YOLOMultiModalDataset(YOLODataset):
-    """
-    Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.
+    """Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.

-    This class extends YOLODataset to add text information for multi-modal model training, enabling models to
-
+    This class extends YOLODataset to add text information for multi-modal model training, enabling models to process
+    both image and text data.

     Methods:
-        update_labels_info:
-        build_transforms:
+        update_labels_info: Add text information for multi-modal model training.
+        build_transforms: Enhance data transformations with text augmentation.

     Examples:
         >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
@@ -330,9 +324,8 @@ class YOLOMultiModalDataset(YOLODataset):
         >>> print(batch.keys())  # Should include 'texts'
     """

-    def __init__(self, *args, data=None, task="detect", **kwargs):
-        """
-        Initialize a YOLOMultiModalDataset.
+    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
+        """Initialize a YOLOMultiModalDataset.

         Args:
             data (dict, optional): Dataset configuration dictionary.
@@ -342,9 +335,8 @@ class YOLOMultiModalDataset(YOLODataset):
         """
         super().__init__(*args, data=data, task=task, **kwargs)

-    def update_labels_info(self, label):
-        """
-        Add texts information for multi-modal model training.
+    def update_labels_info(self, label: dict) -> dict:
+        """Add text information for multi-modal model training.

         Args:
             label (dict): Label dictionary containing bboxes, segments, keypoints, etc.
@@ -359,9 +351,8 @@ class YOLOMultiModalDataset(YOLODataset):

         return labels

-    def build_transforms(self, hyp=None):
-        """
-        Enhances data transformations with optional text augmentation for multi-modal training.
+    def build_transforms(self, hyp: dict | None = None) -> Compose:
+        """Enhance data transformations with optional text augmentation for multi-modal training.

         Args:
             hyp (dict, optional): Hyperparameters for transforms.
@@ -385,11 +376,10 @@ class YOLOMultiModalDataset(YOLODataset):

     @property
     def category_names(self):
-        """
-        Return category names for the dataset.
+        """Return category names for the dataset.

         Returns:
-            (
+            (set[str]): List of class names.
         """
         names = self.data["names"].values()
         return {n.strip() for name in names for n in name.split("/")}  # category names
@@ -408,48 +398,48 @@ class YOLOMultiModalDataset(YOLODataset):
         return category_freq

     @staticmethod
-    def _get_neg_texts(category_freq, threshold=100):
+    def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
         """Get negative text samples based on frequency threshold."""
+        threshold = min(max(category_freq.values()), 100)
         return [k for k, v in category_freq.items() if v >= threshold]


 class GroundingDataset(YOLODataset):
-    """
-    Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format.
+    """Dataset class for object detection tasks using annotations from a JSON file in grounding format.

-    This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than
-
+    This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than the standard
+    YOLO format text files.

     Attributes:
         json_file (str): Path to the JSON file containing annotations.

     Methods:
-        get_img_files:
-        get_labels:
-        build_transforms:
+        get_img_files: Return empty list as image files are read in get_labels.
+        get_labels: Load annotations from a JSON file and prepare them for training.
+        build_transforms: Configure augmentations for training with optional text loading.

     Examples:
         >>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect")
         >>> len(dataset)  # Number of valid images with annotations
     """

-    def __init__(self, *args, task="detect", json_file="", **kwargs):
-        """
-        Initialize a GroundingDataset for object detection.
+    def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs):
+        """Initialize a GroundingDataset for object detection.

         Args:
             json_file (str): Path to the JSON file containing annotations.
             task (str): Must be 'detect' or 'segment' for GroundingDataset.
+            max_samples (int): Maximum number of samples to load for text augmentation.
             *args (Any): Additional positional arguments for the parent class.
             **kwargs (Any): Additional keyword arguments for the parent class.
         """
         assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
         self.json_file = json_file
+        self.max_samples = max_samples
         super().__init__(*args, task=task, data={"channels": 3}, **kwargs)

-    def get_img_files(self, img_path):
-        """
-        The image files would be read in `get_labels` function, return empty list here.
+    def get_img_files(self, img_path: str) -> list:
+        """The image files would be read in `get_labels` function, return empty list here.

         Args:
             img_path (str): Path to the directory containing images.
@@ -459,29 +449,47 @@ class GroundingDataset(YOLODataset):
         """
         return []

-    def verify_labels(self, labels):
-        """Verify the number of instances in the dataset matches expected counts.
-
-        if
-
-
-
-
-
-
-
-
+    def verify_labels(self, labels: list[dict[str, Any]]) -> None:
+        """Verify the number of instances in the dataset matches expected counts.
+
+        This method checks if the total number of bounding box instances in the provided labels matches the expected
+        count for known datasets. It performs validation against a predefined set of datasets with known instance
+        counts.
+
+        Args:
+            labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary contains dataset
+                annotations. Each label dict must have a 'bboxes' key with a numpy array or tensor containing bounding
+                box coordinates.
+
+        Raises:
+            AssertionError: If the actual instance count doesn't match the expected count for a recognized dataset.

-
+        Notes:
+            For unrecognized datasets (those not in the predefined expected_counts),
+            a warning is logged and verification is skipped.
         """
-
+        expected_counts = {
+            "final_mixed_train_no_coco_segm": 3662412,
+            "final_mixed_train_no_coco": 3681235,
+            "final_flickr_separateGT_train_segm": 638214,
+            "final_flickr_separateGT_train": 640704,
+        }
+
+        instance_count = sum(label["bboxes"].shape[0] for label in labels)
+        for data_name, count in expected_counts.items():
+            if data_name in self.json_file:
+                assert instance_count == count, f"'{self.json_file}' has {instance_count} instances, expected {count}."
+                return
+        LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'")
+
+    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]:
+        """Load annotations from a JSON file, filter, and normalize bounding boxes for each image.

         Args:
             path (Path): Path where to save the cache file.

         Returns:
-            (dict): Dictionary containing cached labels and related information.
+            (dict[str, Any]): Dictionary containing cached labels and related information.
         """
         x = {"labels": []}
         LOGGER.info("Loading annotation file...")
@@ -521,7 +529,7 @@ class GroundingDataset(YOLODataset):
                     cat2id[cat_name] = len(cat2id)
                     texts.append([cat_name])
                 cls = cat2id[cat_name]  # class
-                box = [cls
+                box = [cls, *box.tolist()]
                 if box not in bboxes:
                     bboxes.append(box)
                     if ann.get("segmentation") is not None:
@@ -538,7 +546,7 @@ class GroundingDataset(YOLODataset):
                                 .reshape(-1)
                                 .tolist()
                             )
-                        s = [cls
+                        s = [cls, *s]
                         segments.append(s)
             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)

@@ -564,31 +572,29 @@ class GroundingDataset(YOLODataset):
         save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
         return x

-    def get_labels(self):
-        """
-        Load labels from cache or generate them from JSON file.
+    def get_labels(self) -> list[dict]:
+        """Load labels from cache or generate them from JSON file.

         Returns:
-            (
+            (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
         """
         cache_path = Path(self.json_file).with_suffix(".cache")
         try:
             cache, _ = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
             assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
             assert cache["hash"] == get_hash(self.json_file)  # identical hash
-        except (FileNotFoundError, AssertionError, AttributeError):
+        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
             cache, _ = self.cache_labels(cache_path), False  # run cache ops
         [cache.pop(k) for k in ("hash", "version")]  # remove items
         labels = cache["labels"]
-
+        self.verify_labels(labels)
         self.im_files = [str(label["im_file"]) for label in labels]
         if LOCAL_RANK in {-1, 0}:
             LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
         return labels

-    def build_transforms(self, hyp=None):
-        """
-        Configures augmentations for training with optional text loading.
+    def build_transforms(self, hyp: dict | None = None) -> Compose:
+        """Configure augmentations for training with optional text loading.

         Args:
             hyp (dict, optional): Hyperparameters for transforms.
@@ -603,7 +609,7 @@ class GroundingDataset(YOLODataset):
             # the strategy of selecting negative is restricted in one dataset,
             # while official pre-saved neg embeddings from all datasets at once.
             transform = RandomLoadText(
-                max_samples=80,
+                max_samples=min(self.max_samples, 80),
                 padding=True,
                 padding_value=self._get_neg_texts(self.category_freq),
             )
@@ -627,17 +633,17 @@ class GroundingDataset(YOLODataset):
         return category_freq

     @staticmethod
-    def _get_neg_texts(category_freq, threshold=100):
+    def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
         """Get negative text samples based on frequency threshold."""
+        threshold = min(max(category_freq.values()), 100)
         return [k for k, v in category_freq.items() if v >= threshold]


 class YOLOConcatDataset(ConcatDataset):
-    """
-    Dataset as a concatenation of multiple datasets.
+    """Dataset as a concatenation of multiple datasets.

-    This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same
-
+    This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same collation
+    function.

     Methods:
         collate_fn: Static method that collates data samples into batches using YOLODataset's collation function.
@@ -649,21 +655,19 @@ class YOLOConcatDataset(ConcatDataset):
     """

     @staticmethod
-    def collate_fn(batch):
-        """
-        Collates data samples into batches.
+    def collate_fn(batch: list[dict]) -> dict:
+        """Collate data samples into batches.

         Args:
-            batch (
+            batch (list[dict]): List of dictionaries containing sample data.

         Returns:
             (dict): Collated batch with stacked tensors.
         """
         return YOLODataset.collate_fn(batch)

-    def close_mosaic(self, hyp):
-        """
-        Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations.
+    def close_mosaic(self, hyp: dict) -> None:
+        """Set mosaic, copy_paste and mixup options to 0.0 and build transformations.

         Args:
             hyp (dict): Hyperparameters for transforms.
@@ -684,8 +688,7 @@ class SemanticDataset(BaseDataset):


 class ClassificationDataset:
-    """
-    Extends torchvision ImageFolder to support YOLO classification tasks.
+    """Dataset class for image classification tasks extending torchvision ImageFolder functionality.

     This class offers functionalities like image augmentation, caching, and verification. It's designed to efficiently
     handle large datasets for training deep learning models, with optional image transformations and caching mechanisms
@@ -695,20 +698,19 @@ class ClassificationDataset:
         cache_ram (bool): Indicates if caching in RAM is enabled.
         cache_disk (bool): Indicates if caching on disk is enabled.
         samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
-
+            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
         torch_transforms (callable): PyTorch transforms to be applied to the images.
         root (str): Root directory of the dataset.
         prefix (str): Prefix for logging and cache filenames.

     Methods:
-        __getitem__:
-        __len__:
-        verify_images:
+        __getitem__: Return subset of data and targets corresponding to given indices.
+        __len__: Return the total number of samples in the dataset.
+        verify_images: Verify all images in dataset.
     """

-    def __init__(self, root, args, augment=False, prefix=""):
-        """
-        Initialize YOLO object with root, image size, augmentations, and cache settings.
+    def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
+        """Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.

         Args:
             root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
@@ -740,7 +742,7 @@ class ClassificationDataset:
             self.cache_ram = False
         self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
         self.samples = self.verify_images()  # filter out bad images
-        self.samples = [list(x)
+        self.samples = [[*list(x), Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
         self.torch_transforms = (
             classify_augmentations(
@@ -758,9 +760,8 @@ class ClassificationDataset:
             else classify_transforms(size=args.imgsz)
         )

-    def __getitem__(self, i):
-        """
-        Returns subset of data and targets corresponding to given indices.
+    def __getitem__(self, i: int) -> dict:
+        """Return subset of data and targets corresponding to given indices.

         Args:
             i (int): Index of the sample to retrieve.
@@ -787,9 +788,8 @@ class ClassificationDataset:
         """Return the total number of samples in the dataset."""
         return len(self.samples)

-    def verify_images(self):
-        """
-        Verify all images in dataset.
+    def verify_images(self) -> list[tuple]:
+        """Verify all images in dataset.

         Returns:
             (list): List of valid samples after verification.