dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (243)
  1. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
  2. dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +8 -15
  5. tests/test_cli.py +8 -10
  6. tests/test_cuda.py +9 -10
  7. tests/test_engine.py +29 -2
  8. tests/test_exports.py +69 -21
  9. tests/test_integrations.py +8 -11
  10. tests/test_python.py +109 -71
  11. tests/test_solutions.py +170 -159
  12. ultralytics/__init__.py +27 -9
  13. ultralytics/cfg/__init__.py +57 -64
  14. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  16. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  17. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  18. ultralytics/cfg/datasets/Objects365.yaml +19 -15
  19. ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
  20. ultralytics/cfg/datasets/VOC.yaml +19 -21
  21. ultralytics/cfg/datasets/VisDrone.yaml +5 -5
  22. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  23. ultralytics/cfg/datasets/coco-pose.yaml +24 -2
  24. ultralytics/cfg/datasets/coco.yaml +2 -2
  25. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  26. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  27. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  28. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  29. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  30. ultralytics/cfg/datasets/dota8.yaml +2 -2
  31. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  32. ultralytics/cfg/datasets/kitti.yaml +27 -0
  33. ultralytics/cfg/datasets/lvis.yaml +7 -7
  34. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  35. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  36. ultralytics/cfg/datasets/xView.yaml +16 -16
  37. ultralytics/cfg/default.yaml +96 -94
  38. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  39. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  40. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  41. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  42. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  43. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  44. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  45. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  46. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  47. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  48. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  49. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  50. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  51. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  52. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  53. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  54. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  55. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  56. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  57. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  58. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  59. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  60. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  61. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  62. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  65. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  66. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  67. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  68. ultralytics/cfg/trackers/botsort.yaml +16 -17
  69. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  70. ultralytics/data/__init__.py +4 -4
  71. ultralytics/data/annotator.py +3 -4
  72. ultralytics/data/augment.py +286 -476
  73. ultralytics/data/base.py +18 -26
  74. ultralytics/data/build.py +151 -26
  75. ultralytics/data/converter.py +38 -50
  76. ultralytics/data/dataset.py +47 -75
  77. ultralytics/data/loaders.py +42 -49
  78. ultralytics/data/split.py +5 -6
  79. ultralytics/data/split_dota.py +8 -15
  80. ultralytics/data/utils.py +41 -45
  81. ultralytics/engine/exporter.py +462 -462
  82. ultralytics/engine/model.py +150 -191
  83. ultralytics/engine/predictor.py +30 -40
  84. ultralytics/engine/results.py +177 -311
  85. ultralytics/engine/trainer.py +193 -120
  86. ultralytics/engine/tuner.py +77 -63
  87. ultralytics/engine/validator.py +39 -22
  88. ultralytics/hub/__init__.py +16 -19
  89. ultralytics/hub/auth.py +6 -12
  90. ultralytics/hub/google/__init__.py +7 -10
  91. ultralytics/hub/session.py +15 -25
  92. ultralytics/hub/utils.py +5 -8
  93. ultralytics/models/__init__.py +1 -1
  94. ultralytics/models/fastsam/__init__.py +1 -1
  95. ultralytics/models/fastsam/model.py +8 -10
  96. ultralytics/models/fastsam/predict.py +19 -30
  97. ultralytics/models/fastsam/utils.py +1 -2
  98. ultralytics/models/fastsam/val.py +5 -7
  99. ultralytics/models/nas/__init__.py +1 -1
  100. ultralytics/models/nas/model.py +5 -8
  101. ultralytics/models/nas/predict.py +7 -9
  102. ultralytics/models/nas/val.py +1 -2
  103. ultralytics/models/rtdetr/__init__.py +1 -1
  104. ultralytics/models/rtdetr/model.py +7 -8
  105. ultralytics/models/rtdetr/predict.py +15 -19
  106. ultralytics/models/rtdetr/train.py +10 -13
  107. ultralytics/models/rtdetr/val.py +21 -23
  108. ultralytics/models/sam/__init__.py +15 -2
  109. ultralytics/models/sam/amg.py +14 -20
  110. ultralytics/models/sam/build.py +26 -19
  111. ultralytics/models/sam/build_sam3.py +377 -0
  112. ultralytics/models/sam/model.py +29 -32
  113. ultralytics/models/sam/modules/blocks.py +83 -144
  114. ultralytics/models/sam/modules/decoders.py +22 -40
  115. ultralytics/models/sam/modules/encoders.py +44 -101
  116. ultralytics/models/sam/modules/memory_attention.py +16 -30
  117. ultralytics/models/sam/modules/sam.py +206 -79
  118. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  119. ultralytics/models/sam/modules/transformer.py +18 -28
  120. ultralytics/models/sam/modules/utils.py +174 -50
  121. ultralytics/models/sam/predict.py +2268 -366
  122. ultralytics/models/sam/sam3/__init__.py +3 -0
  123. ultralytics/models/sam/sam3/decoder.py +546 -0
  124. ultralytics/models/sam/sam3/encoder.py +529 -0
  125. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  126. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  127. ultralytics/models/sam/sam3/model_misc.py +199 -0
  128. ultralytics/models/sam/sam3/necks.py +129 -0
  129. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  130. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  131. ultralytics/models/sam/sam3/vitdet.py +547 -0
  132. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  133. ultralytics/models/utils/loss.py +14 -26
  134. ultralytics/models/utils/ops.py +13 -17
  135. ultralytics/models/yolo/__init__.py +1 -1
  136. ultralytics/models/yolo/classify/predict.py +9 -12
  137. ultralytics/models/yolo/classify/train.py +15 -41
  138. ultralytics/models/yolo/classify/val.py +34 -32
  139. ultralytics/models/yolo/detect/predict.py +8 -11
  140. ultralytics/models/yolo/detect/train.py +13 -32
  141. ultralytics/models/yolo/detect/val.py +75 -63
  142. ultralytics/models/yolo/model.py +37 -53
  143. ultralytics/models/yolo/obb/predict.py +5 -14
  144. ultralytics/models/yolo/obb/train.py +11 -14
  145. ultralytics/models/yolo/obb/val.py +42 -39
  146. ultralytics/models/yolo/pose/__init__.py +1 -1
  147. ultralytics/models/yolo/pose/predict.py +7 -22
  148. ultralytics/models/yolo/pose/train.py +10 -22
  149. ultralytics/models/yolo/pose/val.py +40 -59
  150. ultralytics/models/yolo/segment/predict.py +16 -20
  151. ultralytics/models/yolo/segment/train.py +3 -12
  152. ultralytics/models/yolo/segment/val.py +106 -56
  153. ultralytics/models/yolo/world/train.py +12 -16
  154. ultralytics/models/yolo/world/train_world.py +11 -34
  155. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  156. ultralytics/models/yolo/yoloe/predict.py +16 -23
  157. ultralytics/models/yolo/yoloe/train.py +31 -56
  158. ultralytics/models/yolo/yoloe/train_seg.py +5 -10
  159. ultralytics/models/yolo/yoloe/val.py +16 -21
  160. ultralytics/nn/__init__.py +7 -7
  161. ultralytics/nn/autobackend.py +152 -80
  162. ultralytics/nn/modules/__init__.py +60 -60
  163. ultralytics/nn/modules/activation.py +4 -6
  164. ultralytics/nn/modules/block.py +133 -217
  165. ultralytics/nn/modules/conv.py +52 -97
  166. ultralytics/nn/modules/head.py +64 -116
  167. ultralytics/nn/modules/transformer.py +79 -89
  168. ultralytics/nn/modules/utils.py +16 -21
  169. ultralytics/nn/tasks.py +111 -156
  170. ultralytics/nn/text_model.py +40 -67
  171. ultralytics/solutions/__init__.py +12 -12
  172. ultralytics/solutions/ai_gym.py +11 -17
  173. ultralytics/solutions/analytics.py +15 -16
  174. ultralytics/solutions/config.py +5 -6
  175. ultralytics/solutions/distance_calculation.py +10 -13
  176. ultralytics/solutions/heatmap.py +7 -13
  177. ultralytics/solutions/instance_segmentation.py +5 -8
  178. ultralytics/solutions/object_blurrer.py +7 -10
  179. ultralytics/solutions/object_counter.py +12 -19
  180. ultralytics/solutions/object_cropper.py +8 -14
  181. ultralytics/solutions/parking_management.py +33 -31
  182. ultralytics/solutions/queue_management.py +10 -12
  183. ultralytics/solutions/region_counter.py +9 -12
  184. ultralytics/solutions/security_alarm.py +15 -20
  185. ultralytics/solutions/similarity_search.py +13 -17
  186. ultralytics/solutions/solutions.py +75 -74
  187. ultralytics/solutions/speed_estimation.py +7 -10
  188. ultralytics/solutions/streamlit_inference.py +4 -7
  189. ultralytics/solutions/templates/similarity-search.html +7 -18
  190. ultralytics/solutions/trackzone.py +7 -10
  191. ultralytics/solutions/vision_eye.py +5 -8
  192. ultralytics/trackers/__init__.py +1 -1
  193. ultralytics/trackers/basetrack.py +3 -5
  194. ultralytics/trackers/bot_sort.py +10 -27
  195. ultralytics/trackers/byte_tracker.py +14 -30
  196. ultralytics/trackers/track.py +3 -6
  197. ultralytics/trackers/utils/gmc.py +11 -22
  198. ultralytics/trackers/utils/kalman_filter.py +37 -48
  199. ultralytics/trackers/utils/matching.py +12 -15
  200. ultralytics/utils/__init__.py +116 -116
  201. ultralytics/utils/autobatch.py +2 -4
  202. ultralytics/utils/autodevice.py +17 -18
  203. ultralytics/utils/benchmarks.py +70 -70
  204. ultralytics/utils/callbacks/base.py +8 -10
  205. ultralytics/utils/callbacks/clearml.py +5 -13
  206. ultralytics/utils/callbacks/comet.py +32 -46
  207. ultralytics/utils/callbacks/dvc.py +13 -18
  208. ultralytics/utils/callbacks/mlflow.py +4 -5
  209. ultralytics/utils/callbacks/neptune.py +7 -15
  210. ultralytics/utils/callbacks/platform.py +314 -38
  211. ultralytics/utils/callbacks/raytune.py +3 -4
  212. ultralytics/utils/callbacks/tensorboard.py +23 -31
  213. ultralytics/utils/callbacks/wb.py +10 -13
  214. ultralytics/utils/checks.py +151 -87
  215. ultralytics/utils/cpu.py +3 -8
  216. ultralytics/utils/dist.py +19 -15
  217. ultralytics/utils/downloads.py +29 -41
  218. ultralytics/utils/errors.py +6 -14
  219. ultralytics/utils/events.py +2 -4
  220. ultralytics/utils/export/__init__.py +7 -0
  221. ultralytics/utils/{export.py → export/engine.py} +16 -16
  222. ultralytics/utils/export/imx.py +325 -0
  223. ultralytics/utils/export/tensorflow.py +231 -0
  224. ultralytics/utils/files.py +24 -28
  225. ultralytics/utils/git.py +9 -11
  226. ultralytics/utils/instance.py +30 -51
  227. ultralytics/utils/logger.py +212 -114
  228. ultralytics/utils/loss.py +15 -24
  229. ultralytics/utils/metrics.py +131 -160
  230. ultralytics/utils/nms.py +21 -30
  231. ultralytics/utils/ops.py +107 -165
  232. ultralytics/utils/patches.py +33 -21
  233. ultralytics/utils/plotting.py +122 -119
  234. ultralytics/utils/tal.py +28 -44
  235. ultralytics/utils/torch_utils.py +70 -187
  236. ultralytics/utils/tqdm.py +20 -20
  237. ultralytics/utils/triton.py +13 -19
  238. ultralytics/utils/tuner.py +17 -5
  239. dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
  240. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
  241. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
  242. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
  243. {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0

ultralytics/models/yolo/segment/train.py +3 -12
@@ -8,12 +8,10 @@ from pathlib import Path
 from ultralytics.models import yolo
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.utils import DEFAULT_CFG, RANK
-from ultralytics.utils.plotting import plot_results
 
 
 class SegmentationTrainer(yolo.detect.DetectionTrainer):
-    """
-    A class extending the DetectionTrainer class for training based on a segmentation model.
+    """A class extending the DetectionTrainer class for training based on a segmentation model.
 
     This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
     functionality including model initialization, validation, and visualization.
@@ -29,8 +27,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
-        """
-        Initialize a SegmentationTrainer object.
+        """Initialize a SegmentationTrainer object.
 
         Args:
             cfg (dict): Configuration dictionary with default training settings.
@@ -41,11 +38,9 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
             overrides = {}
         overrides["task"] = "segment"
         super().__init__(cfg, overrides, _callbacks)
-        self.dynamic_tensors = ["batch_idx", "cls", "bboxes", "masks"]
 
     def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
-        """
-        Initialize and return a SegmentationModel with specified configuration and weights.
+        """Initialize and return a SegmentationModel with specified configuration and weights.
 
         Args:
             cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
@@ -72,7 +67,3 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
         return yolo.segment.SegmentationValidator(
             self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
         )
-
-    def plot_metrics(self):
-        """Plot training/validation metrics."""
-        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
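
With the `plot_metrics` override and its `plot_results` import gone, metric plotting falls through to the parent `DetectionTrainer` implementation. The trainer's entry point is unchanged; a minimal usage sketch (the model and dataset names here are illustrative, not taken from this diff):

    from ultralytics.models.yolo.segment import SegmentationTrainer

    # Overrides mirror the __init__ signature shown above; "task" is forced to "segment" internally.
    args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=1)  # illustrative asset names
    trainer = SegmentationTrainer(overrides=args)
    trainer.train()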

ultralytics/models/yolo/segment/val.py +106 -56
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from multiprocessing.pool import ThreadPool
 from pathlib import Path
 from typing import Any
 
@@ -11,17 +10,16 @@ import torch
 import torch.nn.functional as F
 
 from ultralytics.models.yolo.detect import DetectionValidator
-from ultralytics.utils import LOGGER, NUM_THREADS, ops
+from ultralytics.utils import LOGGER, ops
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import SegmentMetrics, mask_iou
 
 
 class SegmentationValidator(DetectionValidator):
-    """
-    A class extending the DetectionValidator class for validation based on a segmentation model.
+    """A class extending the DetectionValidator class for validation based on a segmentation model.
 
-    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
-    to compute metrics such as mAP for both detection and segmentation tasks.
+    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions to
+    compute metrics such as mAP for both detection and segmentation tasks.
 
     Attributes:
         plot_masks (list): List to store masks for plotting.
@@ -38,11 +36,10 @@ class SegmentationValidator(DetectionValidator):
     """
 
     def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
-        """
-        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
+        """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
 
         Args:
-            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
+            dataloader (torch.utils.data.DataLoader, optional): DataLoader to use for validation.
             save_dir (Path, optional): Directory to save results.
             args (namespace, optional): Arguments for the validator.
             _callbacks (list, optional): List of callback functions.
@@ -53,8 +50,7 @@ class SegmentationValidator(DetectionValidator):
         self.metrics = SegmentMetrics()
 
     def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Preprocess batch of images for YOLO segmentation validation.
+        """Preprocess batch of images for YOLO segmentation validation.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -67,8 +63,7 @@ class SegmentationValidator(DetectionValidator):
        return batch
 
     def init_metrics(self, model: torch.nn.Module) -> None:
-        """
-        Initialize metrics and select mask processing function based on save_json flag.
+        """Initialize metrics and select mask processing function based on save_json flag.
 
         Args:
             model (torch.nn.Module): Model to validate.
@@ -96,8 +91,7 @@ class SegmentationValidator(DetectionValidator):
         )
 
     def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
-        """
-        Post-process YOLO predictions and return output detections with proto.
+        """Post-process YOLO predictions and return output detections with proto.
 
         Args:
             preds (list[torch.Tensor]): Raw predictions from the model.
@@ -112,7 +106,7 @@ class SegmentationValidator(DetectionValidator):
             coefficient = pred.pop("extra")
             pred["masks"] = (
                 self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
-                if len(coefficient)
+                if coefficient.shape[0]
                 else torch.zeros(
                     (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
                     dtype=torch.uint8,
@@ -122,8 +116,7 @@ class SegmentationValidator(DetectionValidator):
         return preds
 
     def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
-        """
-        Prepare a batch for training or inference by processing images and targets.
+        """Prepare a batch for training or inference by processing images and targets.
 
         Args:
             si (int): Batch index.
@@ -133,22 +126,23 @@ class SegmentationValidator(DetectionValidator):
             (dict[str, Any]): Prepared batch with processed annotations.
         """
         prepared_batch = super()._prepare_batch(si, batch)
-        nl = len(prepared_batch["cls"])
+        nl = prepared_batch["cls"].shape[0]
         if self.args.overlap_mask:
             masks = batch["masks"][si]
             index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
             masks = (masks == index).float()
         else:
             masks = batch["masks"][batch["batch_idx"] == si]
-        if nl and self.process is ops.process_mask_native:
-            masks = F.interpolate(masks[None], prepared_batch["imgsz"], mode="bilinear", align_corners=False)[0]
-            masks = masks.gt_(0.5)
+        if nl:
+            mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
+            if masks.shape[1:] != mask_size:
+                masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
+                masks = masks.gt_(0.5)
         prepared_batch["masks"] = masks
         return prepared_batch
 
     def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
-        """
-        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
+        """Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
 
         Args:
             preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
@@ -157,28 +151,27 @@ class SegmentationValidator(DetectionValidator):
         Returns:
             (dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
 
-        Notes:
-            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
-            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
-
         Examples:
             >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
             >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
            >>> correct_preds = validator._process_batch(preds, batch)
+
+        Notes:
+            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
+            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
         """
         tp = super()._process_batch(preds, batch)
         gt_cls = batch["cls"]
-        if len(gt_cls) == 0 or len(preds["cls"]) == 0:
-            tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
+        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
+            tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
         else:
-            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
+            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1).float())  # float, uint8
             tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
         tp.update({"tp_m": tp_m})  # update tp with mask IoU
         return tp
 
     def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
-        """
-        Plot batch predictions with masks and bounding boxes.
+        """Plot batch predictions with masks and bounding boxes.
 
         Args:
             batch (dict[str, Any]): Batch containing images and annotations.
@@ -187,14 +180,13 @@ class SegmentationValidator(DetectionValidator):
         """
         for p in preds:
             masks = p["masks"]
-            if masks.shape[0] > 50:
-                LOGGER.warning("Limiting validation plots to first 50 items per image for speed...")
-            p["masks"] = torch.as_tensor(masks[:50], dtype=torch.uint8).cpu()
-        super().plot_predictions(batch, preds, ni, max_det=50)  # plot bboxes
+            if masks.shape[0] > self.args.max_det:
+                LOGGER.warning(f"Limiting validation plots to 'max_det={self.args.max_det}' items.")
+            p["masks"] = torch.as_tensor(masks[: self.args.max_det], dtype=torch.uint8).cpu()
+        super().plot_predictions(batch, preds, ni, max_det=self.args.max_det)  # plot bboxes
 
     def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
-        """
-        Save YOLO detections to a txt file in normalized coordinates in a specific format.
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format.
 
         Args:
             predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
@@ -213,24 +205,84 @@ class SegmentationValidator(DetectionValidator):
         ).save_txt(file, save_conf=save_conf)
 
     def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
-        """
-        Save one JSON result for COCO evaluation.
+        """Save one JSON result for COCO evaluation.
 
         Args:
             predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
             pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
         """
-        from faster_coco_eval.core.mask import encode  # noqa
-
-        def single_encode(x):
-            """Encode predicted masks as RLE and append results to jdict."""
-            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
-            rle["counts"] = rle["counts"].decode("utf-8")
-            return rle
 
-        pred_masks = np.transpose(predn["masks"], (2, 0, 1))
-        with ThreadPool(NUM_THREADS) as pool:
-            rles = pool.map(single_encode, pred_masks)
+        def to_string(counts: list[int]) -> str:
+            """Converts the RLE object into a compact string representation. Each count is delta-encoded and
+            variable-length encoded as a string.
+
+            Args:
+                counts (list[int]): List of RLE counts.
+            """
+            result = []
+
+            for i in range(len(counts)):
+                x = int(counts[i])
+
+                # Apply delta encoding for all counts after the second entry
+                if i > 2:
+                    x -= int(counts[i - 2])
+
+                # Variable-length encode the value
+                while True:
+                    c = x & 0x1F  # Take 5 bits
+                    x >>= 5
+
+                    # If the sign bit (0x10) is set, continue if x != -1;
+                    # otherwise, continue if x != 0
+                    more = (x != -1) if (c & 0x10) else (x != 0)
+                    if more:
+                        c |= 0x20  # Set continuation bit
+                    c += 48  # Shift to ASCII
+                    result.append(chr(c))
+                    if not more:
+                        break
+
+            return "".join(result)
+
+        def multi_encode(pixels: torch.Tensor) -> list[int]:
+            """Convert multiple binary masks using Run-Length Encoding (RLE).
+
+            Args:
+                pixels (torch.Tensor): A 2D tensor where each row represents a flattened binary mask with shape [N,
+                    H*W].
+
+            Returns:
+                (list[int]): A list of RLE counts for each mask.
+            """
+            transitions = pixels[:, 1:] != pixels[:, :-1]
+            row_idx, col_idx = torch.where(transitions)
+            col_idx = col_idx + 1
+
+            # Compute run lengths
+            counts = []
+            for i in range(pixels.shape[0]):
+                positions = col_idx[row_idx == i]
+                if len(positions):
+                    count = torch.diff(positions).tolist()
+                    count.insert(0, positions[0].item())
+                    count.append(len(pixels[i]) - positions[-1].item())
+                else:
+                    count = [len(pixels[i])]
+
+                # Ensure starting with background (0) count
+                if pixels[i][0].item() == 1:
+                    count = [0, *count]
+                counts.append(count)
+
+            return counts
+
+        pred_masks = predn["masks"].transpose(2, 1).contiguous().view(len(predn["masks"]), -1)  # N, H*W
+        h, w = predn["masks"].shape[1:3]
+        counts = multi_encode(pred_masks)
+        rles = []
+        for c in counts:
+            rles.append({"size": [h, w], "counts": to_string(c)})
         super().pred_to_json(predn, pbatch)
         for i, r in enumerate(rles):
            self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation
@@ -239,11 +291,9 @@ class SegmentationValidator(DetectionValidator):
         """Scales predictions to the original image size."""
         return {
             **super().scale_preds(predn, pbatch),
-            "masks": ops.scale_image(
-                torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
-                pbatch["ori_shape"],
-                ratio_pad=pbatch["ratio_pad"],
-            ),
+            "masks": ops.scale_masks(predn["masks"][None], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])[
+                0
+            ].byte(),
         }
 
     def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
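
The headline change in this file replaces `faster_coco_eval`'s C-backed mask encoding (previously fanned out over a `ThreadPool`) with a dependency-free, pure-PyTorch COCO RLE encoder; masks also now stay tensors end to end (`ops.scale_masks` instead of a NumPy round-trip through `ops.scale_image`). Below is a minimal standalone sketch of the same encoding scheme on a tiny 4x4 mask. The helper names mirror the diff, but this is an illustration, not the validator's exact code path; note the column-major (Fortran-order) flatten, which is why the diff transposes H and W before viewing each mask as a row.

    import torch

    def rle_counts(flat: torch.Tensor) -> list[int]:
        """Run lengths of one flattened binary mask; COCO RLE always starts with a run of 0s."""
        idx = torch.where(flat[1:] != flat[:-1])[0] + 1  # indices where a new run begins
        if len(idx):
            counts = [idx[0].item(), *torch.diff(idx).tolist(), len(flat) - idx[-1].item()]
        else:
            counts = [len(flat)]  # uniform mask: a single run
        if flat[0].item() == 1:
            counts = [0, *counts]  # prepend an empty background run if the mask starts with 1s
        return counts

    def rle_to_string(counts: list[int]) -> str:
        """COCO compressed RLE: delta-encode the counts, then base-32 varint shifted into ASCII."""
        out = []
        for i, x in enumerate(counts):
            x = int(x)
            if i > 2:
                x -= int(counts[i - 2])  # delta vs. the previous run of the same pixel value
            while True:
                c = x & 0x1F  # low 5 bits
                x >>= 5
                more = (x != -1) if (c & 0x10) else (x != 0)  # sign-aware continuation test
                if more:
                    c |= 0x20  # set continuation bit
                out.append(chr(c + 48))  # shift into printable ASCII
                if not more:
                    break
        return "".join(out)

    mask = torch.tensor(
        [[0, 1, 1, 0],
         [0, 1, 1, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]], dtype=torch.uint8
    )
    flat = mask.t().contiguous().view(-1)  # column-major flatten, as COCO expects
    counts = rle_counts(flat)
    print(counts)  # [4, 2, 2, 2, 6]: alternating runs of 0s and 1s
    print({"size": list(mask.shape), "counts": rle_to_string(counts)})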

ultralytics/models/yolo/world/train.py +12 -16
@@ -24,8 +24,7 @@ def on_pretrain_routine_end(trainer) -> None:
 
 
 class WorldTrainer(DetectionTrainer):
-    """
-    A trainer class for fine-tuning YOLO World models on close-set datasets.
+    """A trainer class for fine-tuning YOLO World models on close-set datasets.
 
     This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
     features for improved object detection and understanding. It handles text embedding generation and caching to
@@ -54,8 +53,7 @@ class WorldTrainer(DetectionTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
-        """
-        Initialize a WorldTrainer object with given arguments.
+        """Initialize a WorldTrainer object with given arguments.
 
         Args:
             cfg (dict[str, Any]): Configuration for the trainer.
@@ -64,12 +62,12 @@ class WorldTrainer(DetectionTrainer):
         """
         if overrides is None:
             overrides = {}
+        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
         super().__init__(cfg, overrides, _callbacks)
         self.text_embeddings = None
 
     def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
-        """
-        Return WorldModel initialized with specified config and weights.
+        """Return WorldModel initialized with specified config and weights.
 
         Args:
             cfg (dict[str, Any] | str, optional): Model configuration.
@@ -94,8 +92,7 @@ class WorldTrainer(DetectionTrainer):
         return model
 
     def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
         Args:
             img_path (str): Path to the folder containing images.
@@ -114,11 +111,10 @@ class WorldTrainer(DetectionTrainer):
         return dataset
 
     def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
-        """
-        Set text embeddings for datasets to accelerate training by caching category names.
+        """Set text embeddings for datasets to accelerate training by caching category names.
 
-        This method collects unique category names from all datasets, then generates and caches text embeddings
-        for these categories to improve training efficiency.
+        This method collects unique category names from all datasets, then generates and caches text embeddings for
+        these categories to improve training efficiency.
 
         Args:
             datasets (list[Any]): List of datasets from which to extract category names.
@@ -140,8 +136,7 @@ class WorldTrainer(DetectionTrainer):
         self.text_embeddings = text_embeddings
 
     def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
-        """
-        Generate text embeddings for a list of text samples.
+        """Generate text embeddings for a list of text samples.
 
         Args:
             texts (list[str]): List of text samples to encode.
@@ -171,7 +166,8 @@ class WorldTrainer(DetectionTrainer):
 
         # Add text features
         texts = list(itertools.chain(*batch["texts"]))
-        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device, non_blocking=True)
-        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
+        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
+            self.device, non_blocking=self.device.type == "cuda"
+        )
         batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
         return batch
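
Besides the docstring reflows, two substantive changes land here: `__init__` now rejects `compile=True` up front, and the text-feature transfer drops the explicit L2 re-normalization of cached embeddings while requesting asynchronous copies only for CUDA targets. A small sketch of that transfer pattern, with a toy embedding dict standing in for the trainer's cached state:

    import torch

    def stack_to_device(cache: dict[str, torch.Tensor], texts: list[str], device: torch.device) -> torch.Tensor:
        """Stack cached per-text embeddings and move them to the compute device."""
        feats = torch.stack([cache[t] for t in texts])  # (len(texts), dim)
        # non_blocking copies only pay off on CUDA (async pinned-memory H2D); elsewhere the flag is a no-op
        return feats.to(device, non_blocking=device.type == "cuda")

    cache = {"person": torch.randn(512), "bus": torch.randn(512)}  # toy stand-in for self.text_embeddings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(stack_to_device(cache, ["person", "bus", "person"], device).shape)  # torch.Size([3, 512])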

ultralytics/models/yolo/world/train_world.py +11 -34
@@ -10,8 +10,7 @@ from ultralytics.utils.torch_utils import unwrap_model
 
 
 class WorldTrainerFromScratch(WorldTrainer):
-    """
-    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
+    """A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
 
     This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
     supporting training YOLO-World models with combined vision-language capabilities.
@@ -53,45 +52,25 @@ class WorldTrainerFromScratch(WorldTrainer):
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
-        """
-        Initialize a WorldTrainerFromScratch object.
+        """Initialize a WorldTrainerFromScratch object.
 
-        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
-        object detection and grounding datasets for vision-language capabilities.
+        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both object
+        detection and grounding datasets for vision-language capabilities.
 
         Args:
             cfg (dict): Configuration dictionary with default parameters for model training.
             overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.
             _callbacks (list, optional): List of callback functions to be executed during different stages of training.
-
-        Examples:
-            >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
-            >>> from ultralytics import YOLOWorld
-            >>> data = dict(
-            ...     train=dict(
-            ...         yolo_data=["Objects365.yaml"],
-            ...         grounding_data=[
-            ...             dict(
-            ...                 img_path="flickr30k/images",
-            ...                 json_file="flickr30k/final_flickr_separateGT_train.json",
-            ...             ),
-            ...         ],
-            ...     ),
-            ...     val=dict(yolo_data=["lvis.yaml"]),
-            ... )
-            >>> model = YOLOWorld("yolov8s-worldv2.yaml")
-            >>> model.train(data=data, trainer=WorldTrainerFromScratch)
         """
         if overrides is None:
             overrides = {}
         super().__init__(cfg, overrides, _callbacks)
 
     def build_dataset(self, img_path, mode="train", batch=None):
-        """
-        Build YOLO Dataset for training or validation.
+        """Build YOLO Dataset for training or validation.
 
-        This method constructs appropriate datasets based on the mode and input paths, handling both
-        standard YOLO datasets and grounding datasets with different formats.
+        This method constructs appropriate datasets based on the mode and input paths, handling both standard YOLO
+        datasets and grounding datasets with different formats.
 
         Args:
             img_path (list[str] | str): Path to the folder containing images or list of paths.
@@ -122,11 +101,10 @@ class WorldTrainerFromScratch(WorldTrainer):
         return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
 
     def get_dataset(self):
-        """
-        Get train and validation paths from data dictionary.
+        """Get train and validation paths from data dictionary.
 
-        Processes the data configuration to extract paths for training and validation datasets,
-        handling both YOLO detection datasets and grounding datasets.
+        Processes the data configuration to extract paths for training and validation datasets, handling both YOLO
+        detection datasets and grounding datasets.
 
         Returns:
             train_path (str): Train dataset path.
@@ -187,8 +165,7 @@ class WorldTrainerFromScratch(WorldTrainer):
         pass
 
     def final_eval(self):
-        """
-        Perform final evaluation and validation for the YOLO-World model.
+        """Perform final evaluation and validation for the YOLO-World model.
 
         Configures the validator with appropriate dataset and split information before running evaluation.
 
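
The docstring `Examples` block removed above remains the clearest illustration of the mixed-dataset format this trainer expects, so it is reproduced here for reference, exactly as it appeared before removal (minus the `>>>` prompts):

    from ultralytics import YOLOWorld
    from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch

    data = dict(
        train=dict(
            yolo_data=["Objects365.yaml"],  # standard YOLO detection dataset(s)
            grounding_data=[
                dict(
                    img_path="flickr30k/images",
                    json_file="flickr30k/final_flickr_separateGT_train.json",
                ),
            ],
        ),
        val=dict(yolo_data=["lvis.yaml"]),
    )
    model = YOLOWorld("yolov8s-worldv2.yaml")
    model.train(data=data, trainer=WorldTrainerFromScratch)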

ultralytics/models/yolo/yoloe/__init__.py +7 -7
@@ -6,17 +6,17 @@ from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromSc
 from .val import YOLOEDetectValidator, YOLOESegValidator
 
 __all__ = [
-    "YOLOETrainer",
-    "YOLOEPETrainer",
-    "YOLOESegTrainer",
     "YOLOEDetectValidator",
-    "YOLOESegValidator",
+    "YOLOEPEFreeTrainer",
     "YOLOEPESegTrainer",
+    "YOLOEPETrainer",
+    "YOLOESegTrainer",
     "YOLOESegTrainerFromScratch",
     "YOLOESegVPTrainer",
-    "YOLOEVPTrainer",
-    "YOLOEPEFreeTrainer",
+    "YOLOESegValidator",
+    "YOLOETrainer",
+    "YOLOETrainerFromScratch",
     "YOLOEVPDetectPredictor",
     "YOLOEVPSegPredictor",
-    "YOLOETrainerFromScratch",
+    "YOLOEVPTrainer",
 ]
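
The `__all__` edit is a pure reorder into alphabetical (ASCII) order: the same 13 names before and after, none added or dropped. A two-line check confirms it:

    import ultralytics.models.yolo.yoloe as yoloe

    assert list(yoloe.__all__) == sorted(yoloe.__all__)  # new ordering is alphabetical
    assert len(set(yoloe.__all__)) == 13  # no names added or removed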

ultralytics/models/yolo/yoloe/predict.py +16 -23
@@ -9,11 +9,10 @@ from ultralytics.models.yolo.segment import SegmentationPredictor
 
 
 class YOLOEVPDetectPredictor(DetectionPredictor):
-    """
-    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
+    """A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
 
-    This mixin provides common functionality for YOLO models that use visual prompting, including
-    model setup, prompt handling, and preprocessing transformations.
+    This mixin provides common functionality for YOLO models that use visual prompting, including model setup, prompt
+    handling, and preprocessing transformations.
 
     Attributes:
         model (torch.nn.Module): The YOLO model for inference.
@@ -29,8 +28,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
     """
 
     def setup_model(self, model, verbose: bool = True):
-        """
-        Set up the model for prediction.
+        """Set up the model for prediction.
 
         Args:
             model (torch.nn.Module): Model to load or use.
@@ -40,21 +38,19 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         self.done_warmup = True
 
     def set_prompts(self, prompts):
-        """
-        Set the visual prompts for the model.
+        """Set the visual prompts for the model.
 
         Args:
-            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
-                Must include a 'cls' key with class indices.
+            prompts (dict): Dictionary containing class indices and bounding boxes or masks. Must include a 'cls' key
+                with class indices.
         """
         self.prompts = prompts
 
     def pre_transform(self, im):
-        """
-        Preprocess images and prompts before inference.
+        """Preprocess images and prompts before inference.
 
-        This method applies letterboxing to the input image and transforms the visual prompts
-        (bounding boxes or masks) accordingly.
+        This method applies letterboxing to the input image and transforms the visual prompts (bounding boxes or masks)
+        accordingly.
 
         Args:
             im (list): List containing a single input image.
@@ -94,8 +90,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return img
 
     def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
-        """
-        Process a single image by resizing bounding boxes or masks and generating visuals.
+        """Process a single image by resizing bounding boxes or masks and generating visuals.
 
         Args:
             dst_shape (tuple): The target shape (height, width) of the image.
@@ -131,8 +126,7 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
 
     def inference(self, im, *args, **kwargs):
-        """
-        Run inference with visual prompts.
+        """Run inference with visual prompts.
 
         Args:
             im (torch.Tensor): Input image tensor.
@@ -145,13 +139,12 @@ class YOLOEVPDetectPredictor(DetectionPredictor):
         return super().inference(im, vpe=self.prompts, *args, **kwargs)
 
     def get_vpe(self, source):
-        """
-        Process the source to get the visual prompt embeddings (VPE).
+        """Process the source to get the visual prompt embeddings (VPE).
 
         Args:
-            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
-                of the image to make predictions on. Accepts various types including file paths, URLs, PIL
-                images, numpy arrays, and torch tensors.
+            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source of the image to
+                make predictions on. Accepts various types including file paths, URLs, PIL images, numpy arrays, and
+                torch tensors.
 
         Returns:
             (torch.Tensor): The visual prompt embeddings (VPE) from the model.
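
For context, a hedged sketch of how this visual-prompt predictor is typically driven. The prompts-dict shape (a 'cls' key plus 'bboxes' or 'masks') comes from the `set_prompts` docstring above; the checkpoint name and the `visual_prompts` keyword follow the Ultralytics YOLOE documentation and are not part of this diff:

    import numpy as np
    from ultralytics import YOLOE
    from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

    # One box prompt per target object; 'cls' assigns a class index to each prompt.
    visual_prompts = dict(
        bboxes=np.array([[221.5, 405.8, 344.9, 857.5]]),  # xyxy, pixel coordinates
        cls=np.array([0]),
    )
    model = YOLOE("yoloe-11l-seg.pt")  # illustrative checkpoint name
    results = model.predict("bus.jpg", visual_prompts=visual_prompts, predictor=YOLOEVPSegPredictor)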