dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (243)
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- import itertools
6
5
  from copy import copy, deepcopy
7
6
  from pathlib import Path
8
7
 
@@ -20,11 +19,10 @@ from .val import YOLOEDetectValidator
20
19
 
21
20
 
22
21
  class YOLOETrainer(DetectionTrainer):
23
- """
24
- A trainer class for YOLOE object detection models.
22
+ """A trainer class for YOLOE object detection models.
25
23
 
26
- This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
27
- including custom model initialization, validation, and dataset building with multi-modal support.
24
+ This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
25
+ model initialization, validation, and dataset building with multi-modal support.
28
26
 
29
27
  Attributes:
30
28
  loss_names (tuple): Names of loss components used during training.
@@ -36,8 +34,7 @@ class YOLOETrainer(DetectionTrainer):
36
34
  """
37
35
 
38
36
  def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
39
- """
40
- Initialize the YOLOE Trainer with specified configurations.
37
+ """Initialize the YOLOE Trainer with specified configurations.
41
38
 
42
39
  Args:
43
40
  cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
@@ -46,16 +43,16 @@ class YOLOETrainer(DetectionTrainer):
46
43
  """
47
44
  if overrides is None:
48
45
  overrides = {}
46
+ assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
49
47
  overrides["overlap_mask"] = False
50
48
  super().__init__(cfg, overrides, _callbacks)
51
49
 
52
50
  def get_model(self, cfg=None, weights=None, verbose: bool = True):
53
- """
54
- Return a YOLOEModel initialized with the specified configuration and weights.
51
+ """Return a YOLOEModel initialized with the specified configuration and weights.
55
52
 
56
53
  Args:
57
- cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
58
- a direct path to a YAML file, or None to use default configuration.
54
+ cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
55
+ path to a YAML file, or None to use default configuration.
59
56
  weights (str | Path, optional): Path to pretrained weights file to load into the model.
60
57
  verbose (bool): Whether to display model information during initialization.
61
58
 
@@ -88,8 +85,7 @@ class YOLOETrainer(DetectionTrainer):
88
85
  )
89
86
 
90
87
  def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
91
- """
92
- Build YOLO Dataset.
88
+ """Build YOLO Dataset.
93
89
 
94
90
  Args:
95
91
  img_path (str): Path to the folder containing images.
@@ -106,19 +102,17 @@ class YOLOETrainer(DetectionTrainer):
106
102
 
107
103
 
108
104
  class YOLOEPETrainer(DetectionTrainer):
109
- """
110
- Fine-tune YOLOE model using linear probing approach.
105
+ """Fine-tune YOLOE model using linear probing approach.
111
106
 
112
- This trainer freezes most model layers and only trains specific projection layers for efficient
113
- fine-tuning on new datasets while preserving pretrained features.
107
+ This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
108
+ datasets while preserving pretrained features.
114
109
 
115
110
  Methods:
116
111
  get_model: Initialize YOLOEModel with frozen layers except projection layers.
117
112
  """
118
113
 
119
114
  def get_model(self, cfg=None, weights=None, verbose: bool = True):
120
- """
121
- Return YOLOEModel initialized with specified config and weights.
115
+ """Return YOLOEModel initialized with specified config and weights.
122
116
 
123
117
  Args:
124
118
  cfg (dict | str, optional): Model configuration.
@@ -160,24 +154,21 @@ class YOLOEPETrainer(DetectionTrainer):
160
154
 
161
155
 
162
156
  class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
163
- """
164
- Train YOLOE models from scratch with text embedding support.
157
+ """Train YOLOE models from scratch with text embedding support.
165
158
 
166
- This trainer combines YOLOE training capabilities with world training features, enabling
167
- training from scratch with text embeddings and grounding datasets.
159
+ This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
160
+ text embeddings and grounding datasets.
168
161
 
169
162
  Methods:
170
163
  build_dataset: Build datasets for training with grounding support.
171
- preprocess_batch: Process batches with text features.
172
164
  generate_text_embeddings: Generate and cache text embeddings for training.
173
165
  """
174
166
 
175
167
  def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
176
- """
177
- Build YOLO Dataset for training or validation.
168
+ """Build YOLO Dataset for training or validation.
178
169
 
179
- This method constructs appropriate datasets based on the mode and input paths, handling both
180
- standard YOLO datasets and grounding datasets with different formats.
170
+ This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
171
+ datasets and grounding datasets with different formats.
181
172
 
182
173
  Args:
183
174
  img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -189,19 +180,8 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
189
180
  """
190
181
  return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
191
182
 
192
- def preprocess_batch(self, batch):
193
- """Process batch for training, moving text features to the appropriate device."""
194
- batch = DetectionTrainer.preprocess_batch(self, batch)
195
-
196
- texts = list(itertools.chain(*batch["texts"]))
197
- txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device, non_blocking=True)
198
- txt_feats = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
199
- batch["txt_feats"] = txt_feats
200
- return batch
201
-
202
183
  def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
203
- """
204
- Generate text embeddings for a list of text samples.
184
+ """Generate text embeddings for a list of text samples.
205
185
 
206
186
  Args:
207
187
  texts (list[str]): List of text samples to encode.
@@ -227,11 +207,10 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
227
207
 
228
208
 
229
209
  class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
230
- """
231
- Train prompt-free YOLOE model.
210
+ """Train prompt-free YOLOE model.
232
211
 
233
- This trainer combines linear probing capabilities with from-scratch training for prompt-free
234
- YOLOE models that don't require text prompts during inference.
212
+ This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
213
+ require text prompts during inference.
235
214
 
236
215
  Methods:
237
216
  get_validator: Return standard DetectionValidator for validation.
@@ -251,12 +230,11 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
251
230
  return DetectionTrainer.preprocess_batch(self, batch)
252
231
 
253
232
  def set_text_embeddings(self, datasets, batch: int):
254
- """
255
- Set text embeddings for datasets to accelerate training by caching category names.
233
+ """Set text embeddings for datasets to accelerate training by caching category names.
256
234
 
257
- This method collects unique category names from all datasets, generates text embeddings for them,
258
- and caches these embeddings to improve training efficiency. The embeddings are stored in a file
259
- in the parent directory of the first dataset's image path.
235
+ This method collects unique category names from all datasets, generates text embeddings for them, and caches
236
+ these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
237
+ the first dataset's image path.
260
238
 
261
239
  Args:
262
240
  datasets (list[Dataset]): List of datasets containing category names to process.
@@ -271,20 +249,17 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
271
249
 
272
250
 
273
251
  class YOLOEVPTrainer(YOLOETrainerFromScratch):
274
- """
275
- Train YOLOE model with visual prompts.
252
+ """Train YOLOE model with visual prompts.
276
253
 
277
- This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
278
- where visual cues are provided alongside images to guide the detection process.
254
+ This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
255
+ alongside images to guide the detection process.
279
256
 
280
257
  Methods:
281
258
  build_dataset: Build dataset with visual prompt loading transforms.
282
- preprocess_batch: Preprocess batches with visual prompts.
283
259
  """
284
260
 
285
261
  def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
286
- """
287
- Build YOLO Dataset for training or validation with visual prompts.
262
+ """Build YOLO Dataset for training or validation with visual prompts.
288
263
 
289
264
  Args:
290
265
  img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -11,8 +11,7 @@ from .val import YOLOESegValidator
11
11
 
12
12
 
13
13
  class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
14
- """
15
- Trainer class for YOLOE segmentation models.
14
+ """Trainer class for YOLOE segmentation models.
16
15
 
17
16
  This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
18
17
  segmentation models, enabling both object detection and instance segmentation capabilities.
@@ -24,8 +23,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
24
23
  """
25
24
 
26
25
  def get_model(self, cfg=None, weights=None, verbose=True):
27
- """
28
- Return YOLOESegModel initialized with specified config and weights.
26
+ """Return YOLOESegModel initialized with specified config and weights.
29
27
 
30
28
  Args:
31
29
  cfg (dict | str, optional): Model configuration dictionary or YAML file path.
@@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
49
47
  return model
50
48
 
51
49
  def get_validator(self):
52
- """
53
- Create and return a validator for YOLOE segmentation model evaluation.
50
+ """Create and return a validator for YOLOE segmentation model evaluation.
54
51
 
55
52
  Returns:
56
53
  (YOLOESegValidator): Validator for YOLOE segmentation models.
@@ -62,8 +59,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
62
59
 
63
60
 
64
61
  class YOLOEPESegTrainer(SegmentationTrainer):
65
- """
66
- Fine-tune YOLOESeg model in linear probing way.
62
+ """Fine-tune YOLOESeg model in linear probing way.
67
63
 
68
64
  This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
69
65
  most of the model and only training specific layers for efficient adaptation to new tasks.
@@ -73,8 +69,7 @@ class YOLOEPESegTrainer(SegmentationTrainer):
73
69
  """
74
70
 
75
71
  def get_model(self, cfg=None, weights=None, verbose=True):
76
- """
77
- Return YOLOESegModel initialized with specified config and weights for linear probing.
72
+ """Return YOLOESegModel initialized with specified config and weights for linear probing.
78
73
 
79
74
  Args:
80
75
  cfg (dict | str, optional): Model configuration dictionary or YAML file path.
@@ -21,12 +21,11 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
21
21
 
22
22
 
23
23
  class YOLOEDetectValidator(DetectionValidator):
24
- """
25
- A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
24
+ """A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
26
25
 
27
- This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
28
- It supports validation using either text prompts or visual prompt embeddings extracted from training samples,
29
- enabling flexible evaluation strategies for prompt-based object detection.
26
+ This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
27
+ validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
28
+ evaluation strategies for prompt-based object detection.
30
29
 
31
30
  Attributes:
32
31
  device (torch.device): The device on which validation is performed.
@@ -50,12 +49,11 @@ class YOLOEDetectValidator(DetectionValidator):
50
49
 
51
50
  @smart_inference_mode()
52
51
  def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
53
- """
54
- Extract visual prompt embeddings from training samples.
52
+ """Extract visual prompt embeddings from training samples.
55
53
 
56
- This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
57
- It normalizes the embeddings and handles cases where no samples exist for a class by setting their
58
- embeddings to zero.
54
+ This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
55
+ normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
56
+ zero.
59
57
 
60
58
  Args:
61
59
  dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
@@ -89,7 +87,7 @@ class YOLOEDetectValidator(DetectionValidator):
89
87
  for i in range(preds.shape[0]):
90
88
  cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
91
89
  pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
92
- pad_cls[: len(cls)] = cls
90
+ pad_cls[: cls.shape[0]] = cls
93
91
  for c in cls:
94
92
  visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
95
93
 
@@ -99,12 +97,10 @@ class YOLOEDetectValidator(DetectionValidator):
99
97
  return visual_pe.unsqueeze(0)
100
98
 
101
99
  def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
102
- """
103
- Create a dataloader for LVIS training visual prompt samples.
100
+ """Create a dataloader for LVIS training visual prompt samples.
104
101
 
105
- This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
106
- It applies necessary transformations including LoadVisualPrompt and configurations to the dataset
107
- for validation purposes.
102
+ This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
103
+ necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
108
104
 
109
105
  Args:
110
106
  data (dict): Dataset configuration dictionary containing paths and settings.
@@ -141,12 +137,11 @@ class YOLOEDetectValidator(DetectionValidator):
141
137
  refer_data: str | None = None,
142
138
  load_vp: bool = False,
143
139
  ) -> dict[str, Any]:
144
- """
145
- Run validation on the model using either text or visual prompt embeddings.
140
+ """Run validation on the model using either text or visual prompt embeddings.
146
141
 
147
- This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
148
- It supports validation during training (using a trainer object) or standalone validation with a provided
149
- model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.
142
+ This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
143
+ supports validation during training (using a trainer object) or standalone validation with a provided model. For
144
+ visual prompts, reference data can be specified to extract embeddings from a different dataset.
150
145
 
151
146
  Args:
152
147
  trainer (object, optional): Trainer object containing the model and device.
@@ -14,14 +14,14 @@ from .tasks import (
14
14
  )
15
15
 
16
16
  __all__ = (
17
+ "BaseModel",
18
+ "ClassificationModel",
19
+ "DetectionModel",
20
+ "SegmentationModel",
21
+ "guess_model_scale",
22
+ "guess_model_task",
17
23
  "load_checkpoint",
18
24
  "parse_model",
19
- "yaml_model_load",
20
- "guess_model_task",
21
- "guess_model_scale",
22
25
  "torch_safe_load",
23
- "DetectionModel",
24
- "SegmentationModel",
25
- "ClassificationModel",
26
- "BaseModel",
26
+ "yaml_model_load",
27
27
  )