dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236) hide show
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +1 -1
  6. tests/test_cuda.py +5 -8
  7. tests/test_engine.py +1 -1
  8. tests/test_exports.py +57 -12
  9. tests/test_integrations.py +4 -4
  10. tests/test_python.py +84 -53
  11. tests/test_solutions.py +160 -151
  12. ultralytics/__init__.py +1 -1
  13. ultralytics/cfg/__init__.py +56 -62
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/VOC.yaml +15 -16
  19. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  20. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  21. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  22. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  24. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  25. ultralytics/cfg/datasets/dota8.yaml +2 -2
  26. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  27. ultralytics/cfg/datasets/kitti.yaml +27 -0
  28. ultralytics/cfg/datasets/lvis.yaml +5 -5
  29. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  30. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  31. ultralytics/cfg/datasets/xView.yaml +16 -16
  32. ultralytics/cfg/default.yaml +1 -1
  33. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  34. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  35. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  36. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  37. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  38. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  39. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  40. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  41. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  42. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  43. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  44. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  45. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  46. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  47. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  48. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  49. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  50. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  51. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  52. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  53. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  54. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  55. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  58. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  59. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  62. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  63. ultralytics/data/__init__.py +4 -4
  64. ultralytics/data/annotator.py +3 -4
  65. ultralytics/data/augment.py +285 -475
  66. ultralytics/data/base.py +18 -26
  67. ultralytics/data/build.py +147 -25
  68. ultralytics/data/converter.py +36 -46
  69. ultralytics/data/dataset.py +46 -74
  70. ultralytics/data/loaders.py +42 -49
  71. ultralytics/data/split.py +5 -6
  72. ultralytics/data/split_dota.py +8 -15
  73. ultralytics/data/utils.py +34 -43
  74. ultralytics/engine/exporter.py +319 -237
  75. ultralytics/engine/model.py +148 -188
  76. ultralytics/engine/predictor.py +29 -38
  77. ultralytics/engine/results.py +177 -311
  78. ultralytics/engine/trainer.py +83 -59
  79. ultralytics/engine/tuner.py +23 -34
  80. ultralytics/engine/validator.py +39 -22
  81. ultralytics/hub/__init__.py +16 -19
  82. ultralytics/hub/auth.py +6 -12
  83. ultralytics/hub/google/__init__.py +7 -10
  84. ultralytics/hub/session.py +15 -25
  85. ultralytics/hub/utils.py +5 -8
  86. ultralytics/models/__init__.py +1 -1
  87. ultralytics/models/fastsam/__init__.py +1 -1
  88. ultralytics/models/fastsam/model.py +8 -10
  89. ultralytics/models/fastsam/predict.py +17 -29
  90. ultralytics/models/fastsam/utils.py +1 -2
  91. ultralytics/models/fastsam/val.py +5 -7
  92. ultralytics/models/nas/__init__.py +1 -1
  93. ultralytics/models/nas/model.py +5 -8
  94. ultralytics/models/nas/predict.py +7 -9
  95. ultralytics/models/nas/val.py +1 -2
  96. ultralytics/models/rtdetr/__init__.py +1 -1
  97. ultralytics/models/rtdetr/model.py +5 -8
  98. ultralytics/models/rtdetr/predict.py +15 -19
  99. ultralytics/models/rtdetr/train.py +10 -13
  100. ultralytics/models/rtdetr/val.py +21 -23
  101. ultralytics/models/sam/__init__.py +15 -2
  102. ultralytics/models/sam/amg.py +14 -20
  103. ultralytics/models/sam/build.py +26 -19
  104. ultralytics/models/sam/build_sam3.py +377 -0
  105. ultralytics/models/sam/model.py +29 -32
  106. ultralytics/models/sam/modules/blocks.py +83 -144
  107. ultralytics/models/sam/modules/decoders.py +19 -37
  108. ultralytics/models/sam/modules/encoders.py +44 -101
  109. ultralytics/models/sam/modules/memory_attention.py +16 -30
  110. ultralytics/models/sam/modules/sam.py +200 -73
  111. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  112. ultralytics/models/sam/modules/transformer.py +18 -28
  113. ultralytics/models/sam/modules/utils.py +174 -50
  114. ultralytics/models/sam/predict.py +2248 -350
  115. ultralytics/models/sam/sam3/__init__.py +3 -0
  116. ultralytics/models/sam/sam3/decoder.py +546 -0
  117. ultralytics/models/sam/sam3/encoder.py +529 -0
  118. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  119. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  120. ultralytics/models/sam/sam3/model_misc.py +199 -0
  121. ultralytics/models/sam/sam3/necks.py +129 -0
  122. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  123. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  124. ultralytics/models/sam/sam3/vitdet.py +547 -0
  125. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  126. ultralytics/models/utils/loss.py +14 -26
  127. ultralytics/models/utils/ops.py +13 -17
  128. ultralytics/models/yolo/__init__.py +1 -1
  129. ultralytics/models/yolo/classify/predict.py +9 -12
  130. ultralytics/models/yolo/classify/train.py +11 -32
  131. ultralytics/models/yolo/classify/val.py +29 -28
  132. ultralytics/models/yolo/detect/predict.py +7 -10
  133. ultralytics/models/yolo/detect/train.py +11 -20
  134. ultralytics/models/yolo/detect/val.py +70 -58
  135. ultralytics/models/yolo/model.py +36 -53
  136. ultralytics/models/yolo/obb/predict.py +5 -14
  137. ultralytics/models/yolo/obb/train.py +11 -14
  138. ultralytics/models/yolo/obb/val.py +39 -36
  139. ultralytics/models/yolo/pose/__init__.py +1 -1
  140. ultralytics/models/yolo/pose/predict.py +6 -21
  141. ultralytics/models/yolo/pose/train.py +10 -15
  142. ultralytics/models/yolo/pose/val.py +38 -57
  143. ultralytics/models/yolo/segment/predict.py +14 -18
  144. ultralytics/models/yolo/segment/train.py +3 -6
  145. ultralytics/models/yolo/segment/val.py +93 -45
  146. ultralytics/models/yolo/world/train.py +8 -14
  147. ultralytics/models/yolo/world/train_world.py +11 -34
  148. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  149. ultralytics/models/yolo/yoloe/predict.py +16 -23
  150. ultralytics/models/yolo/yoloe/train.py +30 -43
  151. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  152. ultralytics/models/yolo/yoloe/val.py +15 -20
  153. ultralytics/nn/__init__.py +7 -7
  154. ultralytics/nn/autobackend.py +145 -77
  155. ultralytics/nn/modules/__init__.py +60 -60
  156. ultralytics/nn/modules/activation.py +4 -6
  157. ultralytics/nn/modules/block.py +132 -216
  158. ultralytics/nn/modules/conv.py +52 -97
  159. ultralytics/nn/modules/head.py +50 -103
  160. ultralytics/nn/modules/transformer.py +76 -88
  161. ultralytics/nn/modules/utils.py +16 -21
  162. ultralytics/nn/tasks.py +94 -154
  163. ultralytics/nn/text_model.py +40 -67
  164. ultralytics/solutions/__init__.py +12 -12
  165. ultralytics/solutions/ai_gym.py +11 -17
  166. ultralytics/solutions/analytics.py +15 -16
  167. ultralytics/solutions/config.py +5 -6
  168. ultralytics/solutions/distance_calculation.py +10 -13
  169. ultralytics/solutions/heatmap.py +7 -13
  170. ultralytics/solutions/instance_segmentation.py +5 -8
  171. ultralytics/solutions/object_blurrer.py +7 -10
  172. ultralytics/solutions/object_counter.py +12 -19
  173. ultralytics/solutions/object_cropper.py +8 -14
  174. ultralytics/solutions/parking_management.py +33 -31
  175. ultralytics/solutions/queue_management.py +10 -12
  176. ultralytics/solutions/region_counter.py +9 -12
  177. ultralytics/solutions/security_alarm.py +15 -20
  178. ultralytics/solutions/similarity_search.py +10 -15
  179. ultralytics/solutions/solutions.py +75 -74
  180. ultralytics/solutions/speed_estimation.py +7 -10
  181. ultralytics/solutions/streamlit_inference.py +2 -4
  182. ultralytics/solutions/templates/similarity-search.html +7 -18
  183. ultralytics/solutions/trackzone.py +7 -10
  184. ultralytics/solutions/vision_eye.py +5 -8
  185. ultralytics/trackers/__init__.py +1 -1
  186. ultralytics/trackers/basetrack.py +3 -5
  187. ultralytics/trackers/bot_sort.py +10 -27
  188. ultralytics/trackers/byte_tracker.py +14 -30
  189. ultralytics/trackers/track.py +3 -6
  190. ultralytics/trackers/utils/gmc.py +11 -22
  191. ultralytics/trackers/utils/kalman_filter.py +37 -48
  192. ultralytics/trackers/utils/matching.py +12 -15
  193. ultralytics/utils/__init__.py +116 -116
  194. ultralytics/utils/autobatch.py +2 -4
  195. ultralytics/utils/autodevice.py +17 -18
  196. ultralytics/utils/benchmarks.py +32 -46
  197. ultralytics/utils/callbacks/base.py +8 -10
  198. ultralytics/utils/callbacks/clearml.py +5 -13
  199. ultralytics/utils/callbacks/comet.py +32 -46
  200. ultralytics/utils/callbacks/dvc.py +13 -18
  201. ultralytics/utils/callbacks/mlflow.py +4 -5
  202. ultralytics/utils/callbacks/neptune.py +7 -15
  203. ultralytics/utils/callbacks/platform.py +314 -38
  204. ultralytics/utils/callbacks/raytune.py +3 -4
  205. ultralytics/utils/callbacks/tensorboard.py +23 -31
  206. ultralytics/utils/callbacks/wb.py +10 -13
  207. ultralytics/utils/checks.py +99 -76
  208. ultralytics/utils/cpu.py +3 -8
  209. ultralytics/utils/dist.py +8 -12
  210. ultralytics/utils/downloads.py +20 -30
  211. ultralytics/utils/errors.py +6 -14
  212. ultralytics/utils/events.py +2 -4
  213. ultralytics/utils/export/__init__.py +4 -236
  214. ultralytics/utils/export/engine.py +237 -0
  215. ultralytics/utils/export/imx.py +91 -55
  216. ultralytics/utils/export/tensorflow.py +231 -0
  217. ultralytics/utils/files.py +24 -28
  218. ultralytics/utils/git.py +9 -11
  219. ultralytics/utils/instance.py +30 -51
  220. ultralytics/utils/logger.py +212 -114
  221. ultralytics/utils/loss.py +14 -22
  222. ultralytics/utils/metrics.py +126 -155
  223. ultralytics/utils/nms.py +13 -16
  224. ultralytics/utils/ops.py +107 -165
  225. ultralytics/utils/patches.py +33 -21
  226. ultralytics/utils/plotting.py +72 -80
  227. ultralytics/utils/tal.py +25 -39
  228. ultralytics/utils/torch_utils.py +52 -78
  229. ultralytics/utils/tqdm.py +20 -20
  230. ultralytics/utils/triton.py +13 -19
  231. ultralytics/utils/tuner.py +17 -5
  232. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  233. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  234. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  235. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  236. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
@@ -19,11 +19,10 @@ from .val import YOLOEDetectValidator
19
19
 
20
20
 
21
21
  class YOLOETrainer(DetectionTrainer):
22
- """
23
- A trainer class for YOLOE object detection models.
22
+ """A trainer class for YOLOE object detection models.
24
23
 
25
- This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
26
- including custom model initialization, validation, and dataset building with multi-modal support.
24
+ This class extends DetectionTrainer to provide specialized training functionality for YOLOE models, including custom
25
+ model initialization, validation, and dataset building with multi-modal support.
27
26
 
28
27
  Attributes:
29
28
  loss_names (tuple): Names of loss components used during training.
@@ -35,8 +34,7 @@ class YOLOETrainer(DetectionTrainer):
35
34
  """
36
35
 
37
36
  def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
38
- """
39
- Initialize the YOLOE Trainer with specified configurations.
37
+ """Initialize the YOLOE Trainer with specified configurations.
40
38
 
41
39
  Args:
42
40
  cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
@@ -50,12 +48,11 @@ class YOLOETrainer(DetectionTrainer):
50
48
  super().__init__(cfg, overrides, _callbacks)
51
49
 
52
50
  def get_model(self, cfg=None, weights=None, verbose: bool = True):
53
- """
54
- Return a YOLOEModel initialized with the specified configuration and weights.
51
+ """Return a YOLOEModel initialized with the specified configuration and weights.
55
52
 
56
53
  Args:
57
- cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
58
- a direct path to a YAML file, or None to use default configuration.
54
+ cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key, a direct
55
+ path to a YAML file, or None to use default configuration.
59
56
  weights (str | Path, optional): Path to pretrained weights file to load into the model.
60
57
  verbose (bool): Whether to display model information during initialization.
61
58
 
@@ -88,8 +85,7 @@ class YOLOETrainer(DetectionTrainer):
88
85
  )
89
86
 
90
87
  def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
91
- """
92
- Build YOLO Dataset.
88
+ """Build YOLO Dataset.
93
89
 
94
90
  Args:
95
91
  img_path (str): Path to the folder containing images.
@@ -106,19 +102,17 @@ class YOLOETrainer(DetectionTrainer):
106
102
 
107
103
 
108
104
  class YOLOEPETrainer(DetectionTrainer):
109
- """
110
- Fine-tune YOLOE model using linear probing approach.
105
+ """Fine-tune YOLOE model using linear probing approach.
111
106
 
112
- This trainer freezes most model layers and only trains specific projection layers for efficient
113
- fine-tuning on new datasets while preserving pretrained features.
107
+ This trainer freezes most model layers and only trains specific projection layers for efficient fine-tuning on new
108
+ datasets while preserving pretrained features.
114
109
 
115
110
  Methods:
116
111
  get_model: Initialize YOLOEModel with frozen layers except projection layers.
117
112
  """
118
113
 
119
114
  def get_model(self, cfg=None, weights=None, verbose: bool = True):
120
- """
121
- Return YOLOEModel initialized with specified config and weights.
115
+ """Return YOLOEModel initialized with specified config and weights.
122
116
 
123
117
  Args:
124
118
  cfg (dict | str, optional): Model configuration.
@@ -160,11 +154,10 @@ class YOLOEPETrainer(DetectionTrainer):
160
154
 
161
155
 
162
156
  class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
163
- """
164
- Train YOLOE models from scratch with text embedding support.
157
+ """Train YOLOE models from scratch with text embedding support.
165
158
 
166
- This trainer combines YOLOE training capabilities with world training features, enabling
167
- training from scratch with text embeddings and grounding datasets.
159
+ This trainer combines YOLOE training capabilities with world training features, enabling training from scratch with
160
+ text embeddings and grounding datasets.
168
161
 
169
162
  Methods:
170
163
  build_dataset: Build datasets for training with grounding support.
@@ -172,11 +165,10 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
172
165
  """
173
166
 
174
167
  def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
175
- """
176
- Build YOLO Dataset for training or validation.
168
+ """Build YOLO Dataset for training or validation.
177
169
 
178
- This method constructs appropriate datasets based on the mode and input paths, handling both
179
- standard YOLO datasets and grounding datasets with different formats.
170
+ This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
171
+ datasets and grounding datasets with different formats.
180
172
 
181
173
  Args:
182
174
  img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -189,8 +181,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
189
181
  return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
190
182
 
191
183
  def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
192
- """
193
- Generate text embeddings for a list of text samples.
184
+ """Generate text embeddings for a list of text samples.
194
185
 
195
186
  Args:
196
187
  texts (list[str]): List of text samples to encode.
@@ -216,11 +207,10 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
216
207
 
217
208
 
218
209
  class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
219
- """
220
- Train prompt-free YOLOE model.
210
+ """Train prompt-free YOLOE model.
221
211
 
222
- This trainer combines linear probing capabilities with from-scratch training for prompt-free
223
- YOLOE models that don't require text prompts during inference.
212
+ This trainer combines linear probing capabilities with from-scratch training for prompt-free YOLOE models that don't
213
+ require text prompts during inference.
224
214
 
225
215
  Methods:
226
216
  get_validator: Return standard DetectionValidator for validation.
@@ -240,12 +230,11 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
240
230
  return DetectionTrainer.preprocess_batch(self, batch)
241
231
 
242
232
  def set_text_embeddings(self, datasets, batch: int):
243
- """
244
- Set text embeddings for datasets to accelerate training by caching category names.
233
+ """Set text embeddings for datasets to accelerate training by caching category names.
245
234
 
246
- This method collects unique category names from all datasets, generates text embeddings for them,
247
- and caches these embeddings to improve training efficiency. The embeddings are stored in a file
248
- in the parent directory of the first dataset's image path.
235
+ This method collects unique category names from all datasets, generates text embeddings for them, and caches
236
+ these embeddings to improve training efficiency. The embeddings are stored in a file in the parent directory of
237
+ the first dataset's image path.
249
238
 
250
239
  Args:
251
240
  datasets (list[Dataset]): List of datasets containing category names to process.
@@ -260,19 +249,17 @@ class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
260
249
 
261
250
 
262
251
  class YOLOEVPTrainer(YOLOETrainerFromScratch):
263
- """
264
- Train YOLOE model with visual prompts.
252
+ """Train YOLOE model with visual prompts.
265
253
 
266
- This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
267
- where visual cues are provided alongside images to guide the detection process.
254
+ This trainer extends YOLOETrainerFromScratch to support visual prompt-based training, where visual cues are provided
255
+ alongside images to guide the detection process.
268
256
 
269
257
  Methods:
270
258
  build_dataset: Build dataset with visual prompt loading transforms.
271
259
  """
272
260
 
273
261
  def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
274
- """
275
- Build YOLO Dataset for training or validation with visual prompts.
262
+ """Build YOLO Dataset for training or validation with visual prompts.
276
263
 
277
264
  Args:
278
265
  img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -11,8 +11,7 @@ from .val import YOLOESegValidator
11
11
 
12
12
 
13
13
  class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
14
- """
15
- Trainer class for YOLOE segmentation models.
14
+ """Trainer class for YOLOE segmentation models.
16
15
 
17
16
  This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
18
17
  segmentation models, enabling both object detection and instance segmentation capabilities.
@@ -24,8 +23,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
24
23
  """
25
24
 
26
25
  def get_model(self, cfg=None, weights=None, verbose=True):
27
- """
28
- Return YOLOESegModel initialized with specified config and weights.
26
+ """Return YOLOESegModel initialized with specified config and weights.
29
27
 
30
28
  Args:
31
29
  cfg (dict | str, optional): Model configuration dictionary or YAML file path.
@@ -49,8 +47,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
49
47
  return model
50
48
 
51
49
  def get_validator(self):
52
- """
53
- Create and return a validator for YOLOE segmentation model evaluation.
50
+ """Create and return a validator for YOLOE segmentation model evaluation.
54
51
 
55
52
  Returns:
56
53
  (YOLOESegValidator): Validator for YOLOE segmentation models.
@@ -62,8 +59,7 @@ class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
62
59
 
63
60
 
64
61
  class YOLOEPESegTrainer(SegmentationTrainer):
65
- """
66
- Fine-tune YOLOESeg model in linear probing way.
62
+ """Fine-tune YOLOESeg model in linear probing way.
67
63
 
68
64
  This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
69
65
  most of the model and only training specific layers for efficient adaptation to new tasks.
@@ -73,8 +69,7 @@ class YOLOEPESegTrainer(SegmentationTrainer):
73
69
  """
74
70
 
75
71
  def get_model(self, cfg=None, weights=None, verbose=True):
76
- """
77
- Return YOLOESegModel initialized with specified config and weights for linear probing.
72
+ """Return YOLOESegModel initialized with specified config and weights for linear probing.
78
73
 
79
74
  Args:
80
75
  cfg (dict | str, optional): Model configuration dictionary or YAML file path.
@@ -21,12 +21,11 @@ from ultralytics.utils.torch_utils import select_device, smart_inference_mode
21
21
 
22
22
 
23
23
  class YOLOEDetectValidator(DetectionValidator):
24
- """
25
- A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
24
+ """A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
26
25
 
27
- This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
28
- It supports validation using either text prompts or visual prompt embeddings extracted from training samples,
29
- enabling flexible evaluation strategies for prompt-based object detection.
26
+ This class extends DetectionValidator to provide specialized validation functionality for YOLOE models. It supports
27
+ validation using either text prompts or visual prompt embeddings extracted from training samples, enabling flexible
28
+ evaluation strategies for prompt-based object detection.
30
29
 
31
30
  Attributes:
32
31
  device (torch.device): The device on which validation is performed.
@@ -50,12 +49,11 @@ class YOLOEDetectValidator(DetectionValidator):
50
49
 
51
50
  @smart_inference_mode()
52
51
  def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
53
- """
54
- Extract visual prompt embeddings from training samples.
52
+ """Extract visual prompt embeddings from training samples.
55
53
 
56
- This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
57
- It normalizes the embeddings and handles cases where no samples exist for a class by setting their
58
- embeddings to zero.
54
+ This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model. It
55
+ normalizes the embeddings and handles cases where no samples exist for a class by setting their embeddings to
56
+ zero.
59
57
 
60
58
  Args:
61
59
  dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
@@ -99,12 +97,10 @@ class YOLOEDetectValidator(DetectionValidator):
99
97
  return visual_pe.unsqueeze(0)
100
98
 
101
99
  def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
102
- """
103
- Create a dataloader for LVIS training visual prompt samples.
100
+ """Create a dataloader for LVIS training visual prompt samples.
104
101
 
105
- This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
106
- It applies necessary transformations including LoadVisualPrompt and configurations to the dataset
107
- for validation purposes.
102
+ This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset. It applies
103
+ necessary transformations including LoadVisualPrompt and configurations to the dataset for validation purposes.
108
104
 
109
105
  Args:
110
106
  data (dict): Dataset configuration dictionary containing paths and settings.
@@ -141,12 +137,11 @@ class YOLOEDetectValidator(DetectionValidator):
141
137
  refer_data: str | None = None,
142
138
  load_vp: bool = False,
143
139
  ) -> dict[str, Any]:
144
- """
145
- Run validation on the model using either text or visual prompt embeddings.
140
+ """Run validation on the model using either text or visual prompt embeddings.
146
141
 
147
- This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
148
- It supports validation during training (using a trainer object) or standalone validation with a provided
149
- model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.
142
+ This method validates the model using either text prompts or visual prompts, depending on the load_vp flag. It
143
+ supports validation during training (using a trainer object) or standalone validation with a provided model. For
144
+ visual prompts, reference data can be specified to extract embeddings from a different dataset.
150
145
 
151
146
  Args:
152
147
  trainer (object, optional): Trainer object containing the model and device.
@@ -14,14 +14,14 @@ from .tasks import (
14
14
  )
15
15
 
16
16
  __all__ = (
17
+ "BaseModel",
18
+ "ClassificationModel",
19
+ "DetectionModel",
20
+ "SegmentationModel",
21
+ "guess_model_scale",
22
+ "guess_model_task",
17
23
  "load_checkpoint",
18
24
  "parse_model",
19
- "yaml_model_load",
20
- "guess_model_task",
21
- "guess_model_scale",
22
25
  "torch_safe_load",
23
- "DetectionModel",
24
- "SegmentationModel",
25
- "ClassificationModel",
26
- "BaseModel",
26
+ "yaml_model_load",
27
27
  )