dgenerate_ultralytics_headless-8.3.134-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
ultralytics/models/yolo/world/train_world.py
@@ -0,0 +1,176 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
+ from ultralytics.data.utils import check_det_dataset
+ from ultralytics.models.yolo.world import WorldTrainer
+ from ultralytics.utils import DEFAULT_CFG, LOGGER
+ from ultralytics.utils.torch_utils import de_parallel
+
+
+ class WorldTrainerFromScratch(WorldTrainer):
+     """
+     A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
+
+     This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
+     supporting training YOLO-World models with combined vision-language capabilities.
+
+     Attributes:
+         cfg (dict): Configuration dictionary with default parameters for model training.
+         overrides (dict): Dictionary of parameter overrides to customize the configuration.
+         _callbacks (list): List of callback functions to be executed during different stages of training.
+
+     Examples:
+         >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+         >>> from ultralytics import YOLOWorld
+         >>> data = dict(
+         ...     train=dict(
+         ...         yolo_data=["Objects365.yaml"],
+         ...         grounding_data=[
+         ...             dict(
+         ...                 img_path="../datasets/flickr30k/images",
+         ...                 json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+         ...             ),
+         ...             dict(
+         ...                 img_path="../datasets/GQA/images",
+         ...                 json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+         ...             ),
+         ...         ],
+         ...     ),
+         ...     val=dict(yolo_data=["lvis.yaml"]),
+         ... )
+         >>> model = YOLOWorld("yolov8s-worldv2.yaml")
+         >>> model.train(data=data, trainer=WorldTrainerFromScratch)
+     """
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """
+         Initialize a WorldTrainerFromScratch object.
+
+         This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
+         object detection and grounding datasets for vision-language capabilities.
+
+         Args:
+             cfg (dict): Configuration dictionary with default parameters for model training.
+             overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.
+             _callbacks (list, optional): List of callback functions to be executed during different stages of training.
+
+         Examples:
+             >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+             >>> from ultralytics import YOLOWorld
+             >>> data = dict(
+             ...     train=dict(
+             ...         yolo_data=["Objects365.yaml"],
+             ...         grounding_data=[
+             ...             dict(
+             ...                 img_path="../datasets/flickr30k/images",
+             ...                 json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+             ...             ),
+             ...         ],
+             ...     ),
+             ...     val=dict(yolo_data=["lvis.yaml"]),
+             ... )
+             >>> model = YOLOWorld("yolov8s-worldv2.yaml")
+             >>> model.train(data=data, trainer=WorldTrainerFromScratch)
+         """
+         if overrides is None:
+             overrides = {}
+         super().__init__(cfg, overrides, _callbacks)
+
+     def build_dataset(self, img_path, mode="train", batch=None):
+         """
+         Build YOLO Dataset for training or validation.
+
+         This method constructs appropriate datasets based on the mode and input paths, handling both
+         standard YOLO datasets and grounding datasets with different formats.
+
+         Args:
+             img_path (List[str] | str): Path to the folder containing images or list of paths.
+             mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+             batch (int, optional): Size of batches, used for rectangular training/validation.
+
+         Returns:
+             (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
+         """
+         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+         if mode != "train":
+             return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)
+         datasets = [
+             build_yolo_dataset(self.args, im_path, batch, self.training_data[im_path], stride=gs, multi_modal=True)
+             if isinstance(im_path, str)
+             else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+             for im_path in img_path
+         ]
+         return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
+
+     def get_dataset(self):
+         """
+         Get train and validation paths from data dictionary.
+
+         Processes the data configuration to extract paths for training and validation datasets,
+         handling both YOLO detection datasets and grounding datasets.
+
+         Returns:
+             (dict): Processed data dictionary with train and validation dataset paths, plus
+                 `nc`, `names`, `path`, and `channels` taken from the validation dataset.
+
+         Raises:
+             AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
+         """
+         final_data = {}
+         data_yaml = self.args.data
+         assert data_yaml.get("train", False), "train dataset not found"  # object365.yaml
+         assert data_yaml.get("val", False), "validation dataset not found"  # lvis.yaml
+         data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
+         assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
+         val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
+         for d in data["val"]:
+             if d.get("minival") is None:  # for lvis dataset
+                 continue
+             d["minival"] = str(d["path"] / d["minival"])
+         for s in ["train", "val"]:
+             final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
+             # save grounding data if there's one
+             grounding_data = data_yaml[s].get("grounding_data")
+             if grounding_data is None:
+                 continue
+             grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
+             for g in grounding_data:
+                 assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+             final_data[s] += grounding_data
+         data["val"] = data["val"][0]  # assign the first val dataset as currently only one validation set is supported
+         # NOTE: to make training work properly, set `nc` and `names`
+         final_data["nc"] = data["val"]["nc"]
+         final_data["names"] = data["val"]["names"]
+         # NOTE: add path with lvis path
+         final_data["path"] = data["val"]["path"]
+         final_data["channels"] = data["val"]["channels"]
+         self.data = final_data
+         if self.args.single_cls:  # consistent with base trainer
+             LOGGER.info("Overriding class names with single class.")
+             self.data["names"] = {0: "object"}
+             self.data["nc"] = 1
+         self.training_data = {}
+         for d in data["train"]:
+             if self.args.single_cls:
+                 d["names"] = {0: "object"}
+                 d["nc"] = 1
+             self.training_data[d["train"]] = d
+         return final_data
+
+     def plot_training_labels(self):
+         """Do not plot labels for YOLO-World training."""
+         pass
+
+     def final_eval(self):
+         """
+         Perform final evaluation and validation for the YOLO-World model.
+
+         Configures the validator with appropriate dataset and split information before running evaluation.
+
+         Returns:
+             (dict): Dictionary containing evaluation metrics and results.
+         """
+         val = self.args.data["val"]["yolo_data"][0]
+         self.validator.args.data = val
+         self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
+         return super().final_eval()
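
For reference, here is a hedged sketch of the dictionary shape that `get_dataset` above assembles into `self.data`. All paths and counts are illustrative placeholders, not values shipped in this package:

```python
# Hypothetical illustration of the dict returned by get_dataset().
final_data = {
    "train": [
        "../datasets/Objects365/images/train",  # resolved from each yolo_data YAML
        {
            "img_path": "../datasets/flickr30k/images",  # grounding entries pass through as dicts
            "json_file": "../datasets/flickr30k/final_flickr_separateGT_train.json",
        },
    ],
    "val": ["../datasets/lvis/minival.txt"],  # an LVIS val dataset triggers the "minival" split
    "nc": 1203,  # class count copied from the single val dataset
    "names": {0: "aerosol_can"},  # truncated here; the full mapping comes from the val dataset
    "path": "../datasets/lvis",  # val dataset root (a Path object in practice)
    "channels": 3,
}
```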
ultralytics/models/yolo/yoloe/__init__.py
@@ -0,0 +1,22 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor
+ from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
+ from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer
+ from .val import YOLOEDetectValidator, YOLOESegValidator
+
+ __all__ = [
+     "YOLOETrainer",
+     "YOLOEPETrainer",
+     "YOLOESegTrainer",
+     "YOLOEDetectValidator",
+     "YOLOESegValidator",
+     "YOLOEPESegTrainer",
+     "YOLOESegTrainerFromScratch",
+     "YOLOESegVPTrainer",
+     "YOLOEVPTrainer",
+     "YOLOEPEFreeTrainer",
+     "YOLOEVPDetectPredictor",
+     "YOLOEVPSegPredictor",
+     "YOLOETrainerFromScratch",
+ ]
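
A small sanity-check sketch for the public surface declared in `__all__` above, assuming the wheel is installed and that `ultralytics.models.yolo` re-exports the `yoloe` subpackage as in upstream ultralytics:

```python
# Verify the yoloe subpackage exposes everything listed in its __all__.
from ultralytics.models.yolo import yoloe

for name in yoloe.__all__:
    assert hasattr(yoloe, name), f"missing export: {name}"
print(f"{len(yoloe.__all__)} public symbols available")  # expect 13
```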
ultralytics/models/yolo/yoloe/predict.py
@@ -0,0 +1,169 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import numpy as np
+ import torch
+
+ from ultralytics.data.augment import LoadVisualPrompt
+ from ultralytics.models.yolo.detect import DetectionPredictor
+ from ultralytics.models.yolo.segment import SegmentationPredictor
+
+
+ class YOLOEVPDetectPredictor(DetectionPredictor):
+     """
+     A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+
+     This mixin provides common functionality for YOLO models that use visual prompting, including
+     model setup, prompt handling, and preprocessing transformations.
+
+     Attributes:
+         model (torch.nn.Module): The YOLO model for inference.
+         device (torch.device): Device to run the model on (CPU or CUDA).
+         prompts (dict): Visual prompts containing class indices and bounding boxes or masks.
+
+     Methods:
+         setup_model: Initialize the YOLO model and set it to evaluation mode.
+         set_return_vpe: Set whether to return visual prompt embeddings.
+         set_prompts: Set the visual prompts for the model.
+         pre_transform: Preprocess images and prompts before inference.
+         inference: Run inference with visual prompts.
+     """
+
+     def setup_model(self, model, verbose=True):
+         """
+         Sets up the model for prediction.
+
+         Args:
+             model (torch.nn.Module): Model to load or use.
+             verbose (bool): If True, provides detailed logging.
+         """
+         super().setup_model(model, verbose=verbose)
+         self.done_warmup = True
+
+     def set_prompts(self, prompts):
+         """
+         Set the visual prompts for the model.
+
+         Args:
+             prompts (dict): Dictionary containing class indices and bounding boxes or masks.
+                 Must include a 'cls' key with class indices.
+         """
+         self.prompts = prompts
+
+     def pre_transform(self, im):
+         """
+         Preprocess images and prompts before inference.
+
+         This method applies letterboxing to the input image and transforms the visual prompts
+         (bounding boxes or masks) accordingly.
+
+         Args:
+             im (list): List containing a single input image.
+
+         Returns:
+             (list): Preprocessed image ready for model inference.
+
+         Raises:
+             ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
+         """
+         img = super().pre_transform(im)
+         bboxes = self.prompts.pop("bboxes", None)
+         masks = self.prompts.pop("masks", None)
+         category = self.prompts["cls"]
+         if len(img) == 1:
+             visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
+             self.prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
+         else:
+             # NOTE: only supports bboxes as prompts for now
+             assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
+             # NOTE: needs List[np.ndarray]
+             assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
+                 f"Expected List[np.ndarray], but got {bboxes}!"
+             )
+             assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
+                 f"Expected List[np.ndarray], but got {category}!"
+             )
+             assert len(im) == len(category) == len(bboxes), (
+                 f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
+             )
+             visuals = [
+                 self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
+                 for i in range(len(img))
+             ]
+             self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)
+
+         return img
+
+     def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
+         """
+         Processes a single image by resizing bounding boxes or masks and generating visuals.
+
+         Args:
+             dst_shape (tuple): The target shape (height, width) of the image.
+             src_shape (tuple): The original shape (height, width) of the image.
+             category (str): The category of the image for visual prompts.
+             bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2]. Defaults to None.
+             masks (np.ndarray, optional): A list of masks corresponding to the image. Defaults to None.
+
+         Returns:
+             visuals: The processed visuals for the image.
+
+         Raises:
+             ValueError: If neither `bboxes` nor `masks` are provided.
+         """
+         if bboxes is not None and len(bboxes):
+             bboxes = np.array(bboxes, dtype=np.float32)
+             if bboxes.ndim == 1:
+                 bboxes = bboxes[None, :]
+             # Calculate scaling factor and adjust bounding boxes
+             gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # gain = old / new
+             bboxes *= gain
+             bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
+             bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
+         elif masks is not None:
+             # Resize and process masks
+             resized_masks = super().pre_transform(masks)
+             masks = np.stack(resized_masks)  # (N, H, W)
+             masks[masks == 114] = 0  # Reset padding values to 0
+         else:
+             raise ValueError("Please provide valid bboxes or masks")
+
+         # Generate visuals using the visual prompt loader
+         return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
+
+     def inference(self, im, *args, **kwargs):
+         """
+         Run inference with visual prompts.
+
+         Args:
+             im (torch.Tensor): Input image tensor.
+             *args (Any): Variable length argument list.
+             **kwargs (Any): Arbitrary keyword arguments.
+
+         Returns:
+             (torch.Tensor): Model prediction results.
+         """
+         return super().inference(im, vpe=self.prompts, *args, **kwargs)
+
+     def get_vpe(self, source):
+         """
+         Processes the source to get the visual prompt embeddings (VPE).
+
+         Args:
+             source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
+                 of the image to make predictions on. Accepts various types including file paths, URLs, PIL
+                 images, numpy arrays, and torch tensors.
+
+         Returns:
+             (torch.Tensor): The visual prompt embeddings (VPE) from the model.
+         """
+         self.setup_source(source)
+         assert len(self.dataset) == 1, "get_vpe only supports one image!"
+         for _, im0s, _ in self.dataset:
+             im = self.preprocess(im0s)
+             return self.model(im, vpe=self.prompts, return_vpe=True)
+
+
+ class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
+     """Predictor for YOLOE VP segmentation."""
+
+     pass
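
The intended calling pattern for these predictors, as a hedged sketch mirroring the upstream Ultralytics visual-prompting docs; the checkpoint name and box coordinates below are placeholders, not part of this diff:

```python
import numpy as np

from ultralytics import YOLOE  # exported by this package, as in upstream ultralytics
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

model = YOLOE("yoloe-v8s-seg.pt")  # placeholder checkpoint name
# Prompts follow the dict contract of set_prompts(): 'cls' is required,
# plus 'bboxes' (or 'masks') in the source image's pixel coordinates.
visual_prompts = {
    "bboxes": np.array([[221.5, 405.8, 344.9, 857.5]], dtype=np.float32),  # x1, y1, x2, y2
    "cls": np.array([0]),  # one class index per box
}
results = model.predict(
    "ultralytics/assets/bus.jpg",
    visual_prompts=visual_prompts,
    predictor=YOLOEVPSegPredictor,
)
results[0].show()
```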
ultralytics/models/yolo/yoloe/train.py
@@ -0,0 +1,298 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+ import itertools
+ from copy import copy, deepcopy
+ from pathlib import Path
+
+ import torch
+
+ from ultralytics.data import YOLOConcatDataset, build_yolo_dataset
+ from ultralytics.data.augment import LoadVisualPrompt
+ from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
+ from ultralytics.nn.tasks import YOLOEModel
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
+ from ultralytics.utils.torch_utils import de_parallel
+
+ from ..world.train_world import WorldTrainerFromScratch
+ from .val import YOLOEDetectValidator
+
+
+ class YOLOETrainer(DetectionTrainer):
+     """A base trainer for YOLOE training."""
+
+     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+         """
+         Initialize the YOLOE Trainer with specified configurations.
+
+         This method sets up the YOLOE trainer with the provided configuration and overrides, initializing
+         the training environment, model, and callbacks for YOLOE object detection training.
+
+         Args:
+             cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
+             overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
+             _callbacks (list, optional): List of callback functions to be applied during training.
+         """
+         if overrides is None:
+             overrides = {}
+         overrides["overlap_mask"] = False
+         super().__init__(cfg, overrides, _callbacks)
+
+     def get_model(self, cfg=None, weights=None, verbose=True):
+         """
+         Return a YOLOEModel initialized with the specified configuration and weights.
+
+         Args:
+             cfg (dict | str | None): Model configuration. Can be a dictionary containing a 'yaml_file' key,
+                 a direct path to a YAML file, or None to use default configuration.
+             weights (str | Path | None): Path to pretrained weights file to load into the model.
+             verbose (bool): Whether to display model information during initialization.
+
+         Returns:
+             (YOLOEModel): The initialized YOLOE model.
+
+         Notes:
+             - The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.
+             - The nc parameter here represents the maximum number of different text samples in one image,
+               rather than the actual number of classes.
+         """
+         # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+         # NOTE: Following the official config, nc hard-coded to 80 for now.
+         model = YOLOEModel(
+             cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+             ch=self.data["channels"],
+             nc=min(self.data["nc"], 80),
+             verbose=verbose and RANK == -1,
+         )
+         if weights:
+             model.load(weights)
+
+         return model
+
+     def get_validator(self):
+         """Returns a DetectionValidator for YOLO model validation."""
+         self.loss_names = "box", "cls", "dfl"
+         return YOLOEDetectValidator(
+             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+         )
+
+     def build_dataset(self, img_path, mode="train", batch=None):
+         """
+         Build YOLO Dataset.
+
+         Args:
+             img_path (str): Path to the folder containing images.
+             mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
+             batch (int, optional): Size of batches, this is for `rect`.
+
+         Returns:
+             (Dataset): YOLO dataset configured for training or validation.
+         """
+         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+         return build_yolo_dataset(
+             self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
+         )
+
+
+ class YOLOEPETrainer(DetectionTrainer):
+     """Fine-tune YOLOE model in linear probing way."""
+
+     def get_model(self, cfg=None, weights=None, verbose=True):
+         """
+         Return YOLOEModel initialized with specified config and weights.
+
+         Args:
+             cfg (dict | str, optional): Model configuration.
+             weights (str, optional): Path to pretrained weights.
+             verbose (bool): Whether to display model information.
+
+         Returns:
+             (YOLOEModel): Initialized model with frozen layers except for specific projection layers.
+         """
+         # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+         # NOTE: Following the official config, nc hard-coded to 80 for now.
+         model = YOLOEModel(
+             cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+             ch=self.data["channels"],
+             nc=self.data["nc"],
+             verbose=verbose and RANK == -1,
+         )
+
+         del model.model[-1].savpe
+
+         assert weights is not None, "Pretrained weights must be provided for linear probing."
+         if weights:
+             model.load(weights)
+
+         model.eval()
+         names = list(self.data["names"].values())
+         # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,
+         # it'd get correct results as long as loading proper pretrained weights.
+         tpe = model.get_text_pe(names)
+         model.set_classes(names, tpe)
+         model.model[-1].fuse(model.pe)  # fuse text embeddings to classify head
+         model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
+         model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
+         model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
+         del model.pe
+         model.train()
+
+         return model
+
+
+ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
+     """Train YOLOE models from scratch."""
+
+     def build_dataset(self, img_path, mode="train", batch=None):
+         """
+         Build YOLO Dataset for training or validation.
+
+         This method constructs appropriate datasets based on the mode and input paths, handling both
+         standard YOLO datasets and grounding datasets with different formats.
+
+         Args:
+             img_path (List[str] | str): Path to the folder containing images or list of paths.
+             mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+             batch (int, optional): Size of batches, used for rectangular training/validation.
+
+         Returns:
+             (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
+         """
+         datasets = WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
+         if mode == "train":
+             self.set_text_embeddings(
+                 datasets.datasets if hasattr(datasets, "datasets") else [datasets], batch
+             )  # cache text embeddings to accelerate training
+         return datasets
+
+     def set_text_embeddings(self, datasets, batch):
+         """
+         Set text embeddings for datasets to accelerate training by caching category names.
+
+         This method collects unique category names from all datasets, then generates and caches text embeddings
+         for these categories to improve training efficiency.
+
+         Args:
+             datasets (List[Dataset]): List of datasets from which to extract category names.
+             batch (int | None): Batch size used for processing.
+
+         Notes:
+             This method collects category names from datasets that have the 'category_names' attribute,
+             then uses the first dataset's image path to determine where to cache the generated text embeddings.
+         """
+         # TODO: open up an interface to determine whether to do cache
+         category_names = set()
+         for dataset in datasets:
+             if not hasattr(dataset, "category_names"):
+                 continue
+             category_names |= dataset.category_names
+
+         # TODO: enable to update the path or use a more general way to get the path
+         img_path = datasets[0].img_path
+         self.text_embeddings = self.generate_text_embeddings(
+             category_names, batch, cache_path=Path(img_path).parent / "text_embeddings.pt"
+         )
+
+     def preprocess_batch(self, batch):
+         """Process batch for training, moving text features to the appropriate device."""
+         batch = DetectionTrainer.preprocess_batch(self, batch)
+
+         texts = list(itertools.chain(*batch["texts"]))
+         txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
+         txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
+         batch["txt_feats"] = txt_feats
+         return batch
+
+     def generate_text_embeddings(self, texts, batch, cache_path="embeddings.pt"):
+         """
+         Generate text embeddings for a list of text samples.
+
+         Args:
+             texts (List[str]): List of text samples to encode.
+             batch (int): Batch size for processing.
+             cache_path (str | Path): Path to save/load cached embeddings.
+
+         Returns:
+             (dict): Dictionary mapping text samples to their embeddings.
+         """
+         if cache_path.exists():
+             LOGGER.info(f"Reading existed cache from '{cache_path}'")
+             return torch.load(cache_path)
+         assert self.model is not None
+         txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True)
+         txt_map = dict(zip(texts, txt_feats.squeeze(0)))
+         torch.save(txt_map, cache_path)
+         return txt_map
+
+
+ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
+     """Train prompt-free YOLOE model."""
+
+     def get_validator(self):
+         """Returns a DetectionValidator for YOLO model validation."""
+         self.loss_names = "box", "cls", "dfl"
+         return DetectionValidator(
+             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+         )
+
+     def preprocess_batch(self, batch):
+         """Preprocesses a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
+         batch = DetectionTrainer.preprocess_batch(self, batch)
+         return batch
+
+     def set_text_embeddings(self, datasets, batch):
+         """
+         Set text embeddings for datasets to accelerate training by caching category names.
+
+         This method collects unique category names from all datasets, generates text embeddings for them,
+         and caches these embeddings to improve training efficiency. The embeddings are stored in a file
+         in the parent directory of the first dataset's image path.
+
+         Args:
+             datasets (List[Dataset]): List of datasets containing category names to process.
+             batch (int): Batch size for processing text embeddings.
+
+         Notes:
+             The method creates a dictionary mapping text samples to their embeddings and stores it
+             at the path specified by 'cache_path'. If the cache file already exists, it will be loaded
+             instead of regenerating the embeddings.
+         """
+         pass
+
+
+ class YOLOEVPTrainer(YOLOETrainerFromScratch):
+     """Train YOLOE model with visual prompts."""
+
+     def build_dataset(self, img_path, mode="train", batch=None):
+         """
+         Build YOLO Dataset for training or validation with visual prompts.
+
+         Args:
+             img_path (List[str] | str): Path to the folder containing images or list of paths.
+             mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
+             batch (int, optional): Size of batches, used for rectangular training/validation.
+
+         Returns:
+             (Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.
+         """
+         dataset = super().build_dataset(img_path, mode, batch)
+         if isinstance(dataset, YOLOConcatDataset):
+             for d in dataset.datasets:
+                 d.transforms.append(LoadVisualPrompt())
+         else:
+             dataset.transforms.append(LoadVisualPrompt())
+         return dataset
+
+     def _close_dataloader_mosaic(self):
+         """Close mosaic augmentation and add visual prompt loading to the training dataset."""
+         super()._close_dataloader_mosaic()
+         if isinstance(self.train_loader.dataset, YOLOConcatDataset):
+             for d in self.train_loader.dataset.datasets:
+                 d.transforms.append(LoadVisualPrompt())
+         else:
+             self.train_loader.dataset.transforms.append(LoadVisualPrompt())
+
+     def preprocess_batch(self, batch):
+         """Preprocesses a batch of images for YOLOE training, moving visual prompts to the appropriate device."""
+         batch = super().preprocess_batch(batch)
+         batch["visuals"] = batch["visuals"].to(self.device)
+         return batch
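
Finally, a from-scratch training sketch for the trainers above, analogous to the WorldTrainerFromScratch docstring example; the model YAML, epochs, batch size, and dataset paths are illustrative assumptions to be adapted to your environment:

```python
from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOETrainerFromScratch

# Mixed detection + grounding data, in the same layout the
# WorldTrainerFromScratch docstring documents; all paths/YAMLs are placeholders.
data = dict(
    train=dict(
        yolo_data=["Objects365.yaml"],
        grounding_data=[
            dict(
                img_path="../datasets/flickr30k/images",
                json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
            ),
        ],
    ),
    val=dict(yolo_data=["lvis.yaml"]),
)
model = YOLOE("yoloe-v8s.yaml")  # placeholder model config
model.train(data=data, trainer=YOLOETrainerFromScratch, epochs=1, batch=16)
```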