dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/models/yolo/world/train.py
@@ -1,7 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 import itertools
 from pathlib import Path
+from typing import Any
 
 import torch
 
@@ -9,58 +12,65 @@ from ultralytics.data import build_yolo_dataset
 from ultralytics.models.yolo.detect import DetectionTrainer
 from ultralytics.nn.tasks import WorldModel
 from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
-from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils.torch_utils import unwrap_model
 
 
-def on_pretrain_routine_end(trainer):
-    """Callback to set up model classes and text encoder at the end of the pretrain routine."""
+def on_pretrain_routine_end(trainer) -> None:
+    """Set up model classes and text encoder at the end of the pretrain routine."""
     if RANK in {-1, 0}:
         # Set class names for evaluation
-        names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
-        de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+        names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
+        unwrap_model(trainer.ema.ema).set_classes(names, cache_clip_model=False)
 
 
 class WorldTrainer(DetectionTrainer):
-    """
-    A class to fine-tune a world model on a close-set dataset.
+    """A trainer class for fine-tuning YOLO World models on close-set datasets.
 
-    This trainer extends the DetectionTrainer to support training YOLO World models, which combine
-    visual and textual features for improved object detection and understanding.
+    This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
+    features for improved object detection and understanding. It handles text embedding generation and caching to
+    accelerate training with multi-modal data.
 
     Attributes:
-        clip (module): The CLIP module for text-image understanding.
-        text_model (module): The text encoder model from CLIP.
+        text_embeddings (dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate
+            training.
         model (WorldModel): The YOLO World model being trained.
-        data (dict): Dataset configuration containing class information.
-        args (dict): Training arguments and configuration.
+        data (dict[str, Any]): Dataset configuration containing class information.
+        args (Any): Training arguments and configuration.
+
+    Methods:
+        get_model: Return WorldModel initialized with specified config and weights.
+        build_dataset: Build YOLO Dataset for training or validation.
+        set_text_embeddings: Set text embeddings for datasets to accelerate training.
+        generate_text_embeddings: Generate text embeddings for a list of text samples.
+        preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.
 
     Examples:
-        >>> from ultralytics.models.yolo.world import WorldModel
+        Initialize and train a YOLO World model
+        >>> from ultralytics.models.yolo.world import WorldTrainer
         >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
         >>> trainer = WorldTrainer(overrides=args)
         >>> trainer.train()
     """
 
-    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a WorldTrainer object with given arguments.
+    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments.
 
         Args:
-            cfg (dict): Configuration for the trainer.
-            overrides (dict, optional): Configuration overrides.
-            _callbacks (list, optional): List of callback functions.
+            cfg (dict[str, Any]): Configuration for the trainer.
+            overrides (dict[str, Any], optional): Configuration overrides.
+            _callbacks (list[Any], optional): List of callback functions.
         """
         if overrides is None:
             overrides = {}
+        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
         super().__init__(cfg, overrides, _callbacks)
         self.text_embeddings = None
 
-    def get_model(self, cfg=None, weights=None, verbose=True):
-        """
-        Return WorldModel initialized with specified config and weights.
+    def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
+        """Return WorldModel initialized with specified config and weights.
 
         Args:
-            cfg (Dict | str, optional): Model configuration.
+            cfg (dict[str, Any] | str, optional): Model configuration.
             weights (str, optional): Path to pretrained weights.
             verbose (bool): Whether to display model info.
 
@@ -81,9 +91,8 @@ class WorldTrainer(DetectionTrainer):
 
         return model
 
-    def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation.
+    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
+        """Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -91,9 +100,9 @@ class WorldTrainer(DetectionTrainer):
             batch (int, optional): Size of batches, this is for `rect`.
 
         Returns:
-            (Dataset): YOLO dataset configured for training or validation.
+            (Any): YOLO dataset configured for training or validation.
         """
-        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
         dataset = build_yolo_dataset(
             self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
         )
@@ -101,15 +110,14 @@ class WorldTrainer(DetectionTrainer):
         self.set_text_embeddings([dataset], batch)  # cache text embeddings to accelerate training
         return dataset
 
-    def set_text_embeddings(self, datasets, batch):
-        """
-        Set text embeddings for datasets to accelerate training by caching category names.
+    def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
+        """Set text embeddings for datasets to accelerate training by caching category names.
 
-        This method collects unique category names from all datasets, then generates and caches text embeddings
-        for these categories to improve training efficiency.
+        This method collects unique category names from all datasets, then generates and caches text embeddings for
+        these categories to improve training efficiency.
 
         Args:
-            datasets (List[Dataset]): List of datasets from which to extract category names.
+            datasets (list[Any]): List of datasets from which to extract category names.
             batch (int | None): Batch size used for processing.
 
         Notes:
@@ -127,39 +135,39 @@ class WorldTrainer(DetectionTrainer):
         )
         self.text_embeddings = text_embeddings
 
-    def generate_text_embeddings(self, texts, batch, cache_dir):
-        """
-        Generate text embeddings for a list of text samples.
+    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
+        """Generate text embeddings for a list of text samples.
 
         Args:
-            texts (List[str]): List of text samples to encode.
+            texts (list[str]): List of text samples to encode.
             batch (int): Batch size for processing.
             cache_dir (Path): Directory to save/load cached embeddings.
 
         Returns:
-            (dict): Dictionary mapping text samples to their embeddings.
+            (dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.
         """
         model = "clip:ViT-B/32"
         cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
         if cache_path.exists():
             LOGGER.info(f"Reading existed cache from '{cache_path}'")
-            txt_map = torch.load(cache_path)
+            txt_map = torch.load(cache_path, map_location=self.device)
             if sorted(txt_map.keys()) == sorted(texts):
                 return txt_map
         LOGGER.info(f"Caching text embeddings to '{cache_path}'")
         assert self.model is not None
-        txt_feats = self.model.get_text_pe(texts, batch, cache_clip_model=False)
+        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, cache_clip_model=False)
         txt_map = dict(zip(texts, txt_feats.squeeze(0)))
         torch.save(txt_map, cache_path)
         return txt_map
 
-    def preprocess_batch(self, batch):
+    def preprocess_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
         """Preprocess a batch of images and text for YOLOWorld training."""
        batch = DetectionTrainer.preprocess_batch(self, batch)
 
         # Add text features
         texts = list(itertools.chain(*batch["texts"]))
-        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
-        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
+            self.device, non_blocking=self.device.type == "cuda"
+        )
         batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
         return batch
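
The rewritten generate_text_embeddings only reuses a text_embeddings_*.pt cache when its keys match the requested texts exactly, and the new map_location=self.device argument keeps cached tensors on the trainer's device instead of whatever device they were saved from. A minimal sketch of the same cache round-trip, with random vectors standing in for the CLIP ViT-B/32 text encoder (the function name and 512-dim features here are illustrative, not part of the package):

from pathlib import Path

import torch


def load_or_build_text_cache(texts: list[str], cache_path: Path, device: str = "cpu") -> dict[str, torch.Tensor]:
    """Reuse a cached text-embedding dict only when its keys exactly match the requested texts."""
    if cache_path.exists():
        cached = torch.load(cache_path, map_location=device)  # remap saved tensors onto the current device
        if sorted(cached.keys()) == sorted(texts):
            return cached
    feats = {t: torch.randn(512) for t in texts}  # stand-in for real CLIP text features
    torch.save(feats, cache_path)
    return feats


embeddings = load_or_build_text_cache(["person", "bus"], Path("text_embeddings_demo.pt"))
print(sorted(embeddings), embeddings["person"].shape)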
ultralytics/models/yolo/world/train_world.py
@@ -1,15 +1,16 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from pathlib import Path
+
 from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
 from ultralytics.data.utils import check_det_dataset
 from ultralytics.models.yolo.world import WorldTrainer
-from ultralytics.utils import DEFAULT_CFG, LOGGER
-from ultralytics.utils.torch_utils import de_parallel
+from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER
+from ultralytics.utils.torch_utils import unwrap_model
 
 
 class WorldTrainerFromScratch(WorldTrainer):
-    """
-    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
+    """A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
 
     This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
     supporting training YOLO-World models with combined vision-language capabilities.
@@ -18,6 +19,14 @@ class WorldTrainerFromScratch(WorldTrainer):
         cfg (dict): Configuration dictionary with default parameters for model training.
         overrides (dict): Dictionary of parameter overrides to customize the configuration.
         _callbacks (list): List of callback functions to be executed during different stages of training.
+        data (dict): Final processed data configuration containing train/val paths and metadata.
+        training_data (dict): Dictionary mapping training dataset paths to their configurations.
+
+    Methods:
+        build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.
+        get_dataset: Get train and validation paths from data dictionary.
+        plot_training_labels: Skip label plotting for YOLO-World training.
+        final_eval: Perform final evaluation and validation for the YOLO-World model.
 
     Examples:
         >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
@@ -27,12 +36,12 @@ class WorldTrainerFromScratch(WorldTrainer):
         ...     yolo_data=["Objects365.yaml"],
         ...     grounding_data=[
         ...         dict(
-        ...             img_path="../datasets/flickr30k/images",
-        ...             json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+        ...             img_path="flickr30k/images",
+        ...             json_file="flickr30k/final_flickr_separateGT_train.json",
         ...         ),
         ...         dict(
-        ...             img_path="../datasets/GQA/images",
-        ...             json_file="../datasets/GQA/final_mixed_train_no_coco.json",
+        ...             img_path="GQA/images",
+        ...             json_file="GQA/final_mixed_train_no_coco.json",
         ...         ),
         ...     ],
         ... ),
@@ -43,11 +52,10 @@ class WorldTrainerFromScratch(WorldTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a WorldTrainerFromScratch object.
+        """Initialize a WorldTrainerFromScratch object.
 
-        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
-        object detection and grounding datasets for vision-language capabilities.
+        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
+        detection and grounding datasets for vision-language capabilities.
 
         Args:
             cfg (dict): Configuration dictionary with default parameters for model training.
@@ -62,8 +70,8 @@ class WorldTrainerFromScratch(WorldTrainer):
         ...     yolo_data=["Objects365.yaml"],
         ...     grounding_data=[
         ...         dict(
-        ...             img_path="../datasets/flickr30k/images",
-        ...             json_file="../datasets/flickr30k/final_flickr_separateGT_train.json",
+        ...             img_path="flickr30k/images",
+        ...             json_file="flickr30k/final_flickr_separateGT_train.json",
         ...         ),
         ...     ],
         ... ),
@@ -77,42 +85,48 @@ class WorldTrainerFromScratch(WorldTrainer):
         super().__init__(cfg, overrides, _callbacks)
 
     def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
-        This method constructs appropriate datasets based on the mode and input paths, handling both
-        standard YOLO datasets and grounding datasets with different formats.
+        This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
+        datasets and grounding datasets with different formats.
 
         Args:
-            img_path (List[str] | str): Path to the folder containing images or list of paths.
+            img_path (list[str] | str): Path to the folder containing images or list of paths.
             mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
             batch (int, optional): Size of batches, used for rectangular training/validation.
 
         Returns:
             (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
         """
-        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
+        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
         if mode != "train":
             return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)
         datasets = [
             build_yolo_dataset(self.args, im_path, batch, self.training_data[im_path], stride=gs, multi_modal=True)
             if isinstance(im_path, str)
-            else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
+            else build_grounding(
+                # assign `nc` from validation set to max number of text samples for training consistency
+                self.args,
+                im_path["img_path"],
+                im_path["json_file"],
+                batch,
+                stride=gs,
+                max_samples=self.data["nc"],
+            )
             for im_path in img_path
         ]
         self.set_text_embeddings(datasets, batch)  # cache text embeddings to accelerate training
         return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
 
     def get_dataset(self):
-        """
-        Get train and validation paths from data dictionary.
+        """Get train and validation paths from data dictionary.
 
-        Processes the data configuration to extract paths for training and validation datasets,
-        handling both YOLO detection datasets and grounding datasets.
+        Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
+        detection datasets and grounding datasets.
 
         Returns:
-            (str): Train dataset path.
-            (str): Validation dataset path.
+            train_path (str): Train dataset path.
+            val_path (str): Validation dataset path.
 
         Raises:
            AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
@@ -128,7 +142,7 @@ class WorldTrainerFromScratch(WorldTrainer):
             if d.get("minival") is None:  # for lvis dataset
                 continue
             d["minival"] = str(d["path"] / d["minival"])
-        for s in ["train", "val"]:
+        for s in {"train", "val"}:
             final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
             # save grounding data if there's one
             grounding_data = data_yaml[s].get("grounding_data")
@@ -137,8 +151,14 @@ class WorldTrainerFromScratch(WorldTrainer):
             grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
             for g in grounding_data:
                 assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
+                for k in {"img_path", "json_file"}:
+                    path = Path(g[k])
+                    if not path.exists() and not path.is_absolute():
+                        g[k] = str((DATASETS_DIR / g[k]).resolve())  # path relative to DATASETS_DIR
             final_data[s] += grounding_data
-        data["val"] = data["val"][0]  # assign the first val dataset as currently only one validation set is supported
+        # assign the first val dataset as currently only one validation set is supported
+        data["val"] = data["val"][0]
+        final_data["val"] = final_data["val"][0]
         # NOTE: to make training work properly, set `nc` and `names`
         final_data["nc"] = data["val"]["nc"]
         final_data["names"] = data["val"]["names"]
@@ -159,12 +179,11 @@ class WorldTrainerFromScratch(WorldTrainer):
         return final_data
 
     def plot_training_labels(self):
-        """Do not plot labels for YOLO-World training."""
+        """Skip label plotting for YOLO-World training."""
        pass
 
     def final_eval(self):
-        """
-        Perform final evaluation and validation for the YOLO-World model.
+        """Perform final evaluation and validation for the YOLO-World model.
 
         Configures the validator with appropriate dataset and split information before running evaluation.
 
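
With the added Path check in get_dataset, grounding entries in a data YAML can now be written relative to the Ultralytics datasets directory (as in the updated docstring examples above) instead of carrying ../datasets/ prefixes. A standalone sketch of the resolution rule; DATASETS_DIR is hard-coded here for illustration, while the real value is imported from ultralytics.utils:

from pathlib import Path

DATASETS_DIR = Path.home() / "datasets"  # illustrative stand-in for ultralytics.utils.DATASETS_DIR


def resolve_grounding_paths(g: dict) -> dict:
    """Anchor relative, non-existent grounding paths at DATASETS_DIR, mirroring the diff above."""
    for k in {"img_path", "json_file"}:
        path = Path(g[k])
        if not path.exists() and not path.is_absolute():
            g[k] = str((DATASETS_DIR / g[k]).resolve())
    return g


print(resolve_grounding_paths({"img_path": "flickr30k/images", "json_file": "flickr30k/final_flickr_separateGT_train.json"}))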
ultralytics/models/yolo/yoloe/__init__.py
@@ -6,17 +6,17 @@ from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch
 from .val import YOLOEDetectValidator, YOLOESegValidator
 
 __all__ = [
-    "YOLOETrainer",
-    "YOLOEPETrainer",
-    "YOLOESegTrainer",
     "YOLOEDetectValidator",
-    "YOLOESegValidator",
+    "YOLOEPEFreeTrainer",
     "YOLOEPESegTrainer",
+    "YOLOEPETrainer",
+    "YOLOESegTrainer",
     "YOLOESegTrainerFromScratch",
     "YOLOESegVPTrainer",
-    "YOLOEVPTrainer",
-    "YOLOEPEFreeTrainer",
+    "YOLOESegValidator",
+    "YOLOETrainer",
+    "YOLOETrainerFromScratch",
     "YOLOEVPDetectPredictor",
     "YOLOEVPSegPredictor",
-    "YOLOETrainerFromScratch",
+    "YOLOEVPTrainer",
 ]
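
The only change to this module is ordering: the same 13 names are exported, re-sorted alphabetically. A quick sanity check, assuming the new version is importable:

from ultralytics.models.yolo import yoloe

assert yoloe.__all__ == sorted(yoloe.__all__), "__all__ should be alphabetically sorted"
print(f"{len(yoloe.__all__)} exports")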
ultralytics/models/yolo/yoloe/predict.py
@@ -9,52 +9,48 @@ from ultralytics.models.yolo.segment import SegmentationPredictor
 
 
 class YOLOEVPDetectPredictor(DetectionPredictor):
-    """
-    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+    """A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
 
-    This mixin provides common functionality for YOLO models that use visual prompting, including
-    model setup, prompt handling, and preprocessing transformations.
+    This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
+    handling, and preprocessing transformations.
 
     Attributes:
         model (torch.nn.Module): The YOLO model for inference.
         device (torch.device): Device to run the model on (CPU or CUDA).
-        prompts (dict): Visual prompts containing class indices and bounding boxes or masks.
+        prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.
 
     Methods:
         setup_model: Initialize the YOLO model and set it to evaluation mode.
-        set_return_vpe: Set whether to return visual prompt embeddings.
         set_prompts: Set the visual prompts for the model.
         pre_transform: Preprocess images and prompts before inference.
         inference: Run inference with visual prompts.
+        get_vpe: Process source to get visual prompt embeddings.
     """
 
-    def setup_model(self, model, verbose=True):
-        """
-        Sets up the model for prediction.
+    def setup_model(self, model, verbose: bool = True):
+        """Set up the model for prediction.
 
         Args:
             model (torch.nn.Module): Model to load or use.
-            verbose (bool): If True, provides detailed logging.
+            verbose (bool, optional): If True, provides detailed logging.
         """
         super().setup_model(model, verbose=verbose)
         self.done_warmup = True
 
     def set_prompts(self, prompts):
-        """
-        Set the visual prompts for the model.
+        """Set the visual prompts for the model.
 
         Args:
-            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
-                Must include a 'cls' key with class indices.
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
+                with class indices.
         """
         self.prompts = prompts
 
     def pre_transform(self, im):
-        """
-        Preprocess images and prompts before inference.
+        """Preprocess images and prompts before inference.
 
-        This method applies letterboxing to the input image and transforms the visual prompts
-        (bounding boxes or masks) accordingly.
+        This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
+        accordingly.
 
         Args:
             im (list): List containing a single input image.
@@ -71,16 +67,16 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         category = self.prompts["cls"]
         if len(img) == 1:
             visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
-            self.prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
+            prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
         else:
             # NOTE: only supports bboxes as prompts for now
             assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
-            # NOTE: needs List[np.ndarray]
+            # NOTE: needs list[np.ndarray]
             assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
-                f"Expected List[np.ndarray], but got {bboxes}!"
+                f"Expected list[np.ndarray], but got {bboxes}!"
             )
             assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
-                f"Expected List[np.ndarray], but got {category}!"
+                f"Expected list[np.ndarray], but got {category}!"
             )
             assert len(im) == len(category) == len(bboxes), (
                 f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
@@ -89,23 +85,22 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
                 self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
                 for i in range(len(img))
             ]
-            self.prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)
-
+            prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)  # (B, N, H, W)
+        self.prompts = prompts.half() if self.model.fp16 else prompts.float()
         return img
 
     def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
-        """
-        Processes a single image by resizing bounding boxes or masks and generating visuals.
+        """Process a single image by resizing bounding boxes or masks and generating visuals.
 
         Args:
             dst_shape (tuple): The target shape (height, width) of the image.
             src_shape (tuple): The original shape (height, width) of the image.
             category (str): The category of the image for visual prompts.
-            bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2]. Defaults to None.
-            masks (np.ndarray, optional): A list of masks corresponding to the image. Defaults to None.
+            bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2].
+            masks (np.ndarray, optional): A list of masks corresponding to the image.
 
         Returns:
-            visuals: The processed visuals for the image.
+            (torch.Tensor): The processed visuals for the image.
 
         Raises:
             ValueError: If neither `bboxes` nor `masks` are provided.
@@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
 
     def inference(self, im, *args, **kwargs):
-        """
-        Run inference with visual prompts.
+        """Run inference with visual prompts.
 
         Args:
             im (torch.Tensor): Input image tensor.
@@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return super().inference(im, vpe=self.prompts, *args, **kwargs)
 
     def get_vpe(self, source):
-        """
-        Processes the source to get the visual prompt embeddings (VPE).
+        """Process the source to get the visual prompt embeddings (VPE).
 
         Args:
-            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source
-                of the image to make predictions on. Accepts various types including file paths, URLs, PIL
-                images, numpy arrays, and torch tensors.
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
+                make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+                torch tensors.
 
         Returns:
             (torch.Tensor): The visual prompt embeddings (VPE) from the model.
@@ -164,6 +157,6 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
 
 
 class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
-    """Predictor for YOLOE VP segmentation."""
+    """Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities."""
 
     pass
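
The net effect of these changes is that pre_transform now always leaves self.prompts as a single padded tensor cast to the model's precision. In practice these predictors are driven through the model's predict call rather than instantiated directly; a sketch modeled on the Ultralytics YOLOE visual-prompting examples, where the weights file and source image are placeholders to swap for your own:

import numpy as np

from ultralytics import YOLOE
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

model = YOLOE("yoloe-11l-seg.pt")  # visual-prompt capable weights
prompts = dict(
    bboxes=np.array([[221.2, 405.8, 344.5, 857.4]]),  # one reference box in xyxy pixels
    cls=np.array([0]),  # the 'cls' key is required by set_prompts()
)
results = model.predict("https://ultralytics.com/images/bus.jpg", visual_prompts=prompts, predictor=YOLOEVPSegPredictor)
results[0].show()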