ultralytics 8.1.28__py3-none-any.whl → 8.3.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +36 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +110 -40
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +569 -242
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +190 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +527 -67
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +225 -77
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +160 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +44 -37
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +84 -56
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +302 -102
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.62.dist-info/METADATA +370 -0
  235. ultralytics-8.3.62.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.28.dist-info/METADATA +0 -373
  244. ultralytics-8.1.28.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.28.dist-info → ultralytics-8.3.62.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,21 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from pathlib import Path
4
4
 
5
5
  from ultralytics.engine.model import Model
6
6
  from ultralytics.models import yolo
7
7
  from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
8
- from ultralytics.utils import yaml_load, ROOT
8
+ from ultralytics.utils import ROOT, yaml_load
9
9
 
10
10
 
11
11
  class YOLO(Model):
12
12
  """YOLO (You Only Look Once) object detection model."""
13
13
 
14
- def __init__(self, model="yolov8n.pt", task=None, verbose=False):
14
+ def __init__(self, model="yolo11n.pt", task=None, verbose=False):
15
15
  """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
16
16
  path = Path(model)
17
17
  if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
18
- new_instance = YOLOWorld(path)
18
+ new_instance = YOLOWorld(path, verbose=verbose)
19
19
  self.__class__ = type(new_instance)
20
20
  self.__dict__ = new_instance.__dict__
21
21
  else:
@@ -62,14 +62,18 @@ class YOLO(Model):
62
62
  class YOLOWorld(Model):
63
63
  """YOLO-World object detection model."""
64
64
 
65
- def __init__(self, model="yolov8s-world.pt") -> None:
65
+ def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
66
66
  """
67
- Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
67
+ Initialize YOLOv8-World model with a pre-trained model file.
68
+
69
+ Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
70
+ COCO class names.
68
71
 
69
72
  Args:
70
- model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
73
+ model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
74
+ verbose (bool): If True, prints additional information during initialization.
71
75
  """
72
- super().__init__(model=model, task="detect")
76
+ super().__init__(model=model, task="detect", verbose=verbose)
73
77
 
74
78
  # Assign default COCO class names when there are no custom names
75
79
  if not hasattr(self.model, "names"):
@@ -83,6 +87,7 @@ class YOLOWorld(Model):
83
87
  "model": WorldModel,
84
88
  "validator": yolo.detect.DetectionValidator,
85
89
  "predictor": yolo.detect.DetectionPredictor,
90
+ "trainer": yolo.world.WorldTrainer,
86
91
  }
87
92
  }
88
93
 
@@ -91,7 +96,7 @@ class YOLOWorld(Model):
91
96
  Set classes.
92
97
 
93
98
  Args:
94
- classes (List(str)): A list of categories i.e ["person"].
99
+ classes (List(str)): A list of categories i.e. ["person"].
95
100
  """
96
101
  self.model.set_classes(classes)
97
102
  # Remove background if it's given
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from .predict import OBBPredictor
4
4
  from .train import OBBTrainer
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  import torch
4
4
 
@@ -16,7 +16,7 @@ class OBBPredictor(DetectionPredictor):
16
16
  from ultralytics.utils import ASSETS
17
17
  from ultralytics.models.yolo.obb import OBBPredictor
18
18
 
19
- args = dict(model='yolov8n-obb.pt', source=ASSETS)
19
+ args = dict(model="yolov8n-obb.pt", source=ASSETS)
20
20
  predictor = OBBPredictor(overrides=args)
21
21
  predictor.predict_cli()
22
22
  ```
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from copy import copy
4
4
 
@@ -15,7 +15,7 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
15
15
  ```python
16
16
  from ultralytics.models.yolo.obb import OBBTrainer
17
17
 
18
- args = dict(model='yolov8n-obb.pt', data='dota8.yaml', epochs=3)
18
+ args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3)
19
19
  trainer = OBBTrainer(overrides=args)
20
20
  trainer.train()
21
21
  ```
@@ -39,4 +39,6 @@ class OBBTrainer(yolo.detect.DetectionTrainer):
39
39
  def get_validator(self):
40
40
  """Return an instance of OBBValidator for validation of YOLO model."""
41
41
  self.loss_names = "box_loss", "cls_loss", "dfl_loss"
42
- return yolo.obb.OBBValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
42
+ return yolo.obb.OBBValidator(
43
+ self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
44
+ )
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from pathlib import Path
4
4
 
@@ -18,9 +18,9 @@ class OBBValidator(DetectionValidator):
18
18
  ```python
19
19
  from ultralytics.models.yolo.obb import OBBValidator
20
20
 
21
- args = dict(model='yolov8n-obb.pt', data='dota8.yaml')
21
+ args = dict(model="yolov8n-obb.pt", data="dota8.yaml")
22
22
  validator = OBBValidator(args=args)
23
- validator(model=args['model'])
23
+ validator(model=args["model"])
24
24
  ```
25
25
  """
26
26
 
@@ -45,24 +45,36 @@ class OBBValidator(DetectionValidator):
45
45
  labels=self.lb,
46
46
  nc=self.nc,
47
47
  multi_label=True,
48
- agnostic=self.args.single_cls,
48
+ agnostic=self.args.single_cls or self.args.agnostic_nms,
49
49
  max_det=self.args.max_det,
50
50
  rotated=True,
51
51
  )
52
52
 
53
53
  def _process_batch(self, detections, gt_bboxes, gt_cls):
54
54
  """
55
- Return correct prediction matrix.
55
+ Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
56
56
 
57
57
  Args:
58
- detections (torch.Tensor): Tensor of shape [N, 7] representing detections.
59
- Each detection is of the format: x1, y1, x2, y2, conf, class, angle.
60
- gt_bboxes (torch.Tensor): Tensor of shape [M, 5] representing rotated boxes.
61
- Each box is of the format: x1, y1, x2, y2, angle.
62
- labels (torch.Tensor): Tensor of shape [M] representing labels.
58
+ detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated
59
+ data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle).
60
+ gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is
61
+ represented as (x1, y1, x2, y2, angle).
62
+ gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes.
63
63
 
64
64
  Returns:
65
- (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
65
+ (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over
66
+ Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth.
67
+
68
+ Example:
69
+ ```python
70
+ detections = torch.rand(100, 7) # 100 sample detections
71
+ gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes
72
+ gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels
73
+ correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls)
74
+ ```
75
+
76
+ Note:
77
+ This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes.
66
78
  """
67
79
  iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1))
68
80
  return self.match_predictions(detections[:, 5], gt_cls, iou)
@@ -78,7 +90,7 @@ class OBBValidator(DetectionValidator):
78
90
  if len(cls):
79
91
  bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
80
92
  ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels
81
- return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
93
+ return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
82
94
 
83
95
  def _prepare_pred(self, pred, pbatch):
84
96
  """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes."""
@@ -118,13 +130,19 @@ class OBBValidator(DetectionValidator):
118
130
 
119
131
  def save_one_txt(self, predn, save_conf, shape, file):
120
132
  """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
121
- gn = torch.tensor(shape)[[1, 0]] # normalization gain whwh
122
- for *xywh, conf, cls, angle in predn.tolist():
123
- xywha = torch.tensor([*xywh, angle]).view(1, 5)
124
- xyxyxyxy = (ops.xywhr2xyxyxyxy(xywha) / gn).view(-1).tolist() # normalized xywh
125
- line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format
126
- with open(file, "a") as f:
127
- f.write(("%g " * len(line)).rstrip() % line + "\n")
133
+ import numpy as np
134
+
135
+ from ultralytics.engine.results import Results
136
+
137
+ rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1)
138
+ # xywh, r, conf, cls
139
+ obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1)
140
+ Results(
141
+ np.zeros((shape[0], shape[1]), dtype=np.uint8),
142
+ path=None,
143
+ names=self.names,
144
+ obb=obb,
145
+ ).save_txt(file, save_conf=save_conf)
128
146
 
129
147
  def eval_json(self, stats):
130
148
  """Evaluates YOLO output in JSON format and returns performance statistics."""
@@ -142,10 +160,10 @@ class OBBValidator(DetectionValidator):
142
160
  for d in data:
143
161
  image_id = d["image_id"]
144
162
  score = d["score"]
145
- classname = self.names[d["category_id"]].replace(" ", "-")
163
+ classname = self.names[d["category_id"] - 1].replace(" ", "-")
146
164
  p = d["poly"]
147
165
 
148
- with open(f'{pred_txt / f"Task1_{classname}"}.txt', "a") as f:
166
+ with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a") as f:
149
167
  f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
150
168
  # Save merged results, this could result slightly lower map than using official merging script,
151
169
  # because of the probiou calculation.
@@ -157,7 +175,7 @@ class OBBValidator(DetectionValidator):
157
175
  image_id = d["image_id"].split("__")[0]
158
176
  pattern = re.compile(r"\d+___\d+")
159
177
  x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
160
- bbox, score, cls = d["rbox"], d["score"], d["category_id"]
178
+ bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
161
179
  bbox[0] += x
162
180
  bbox[1] += y
163
181
  bbox.extend([score, cls])
@@ -179,7 +197,7 @@ class OBBValidator(DetectionValidator):
179
197
  p = [round(i, 3) for i in x[:-2]] # poly
180
198
  score = round(x[-2], 3)
181
199
 
182
- with open(f'{pred_merged_txt / f"Task1_{classname}"}.txt', "a") as f:
200
+ with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a") as f:
183
201
  f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
184
202
 
185
203
  return stats
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from .predict import PosePredictor
4
4
  from .train import PoseTrainer
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from ultralytics.engine.results import Results
4
4
  from ultralytics.models.yolo.detect.predict import DetectionPredictor
@@ -14,7 +14,7 @@ class PosePredictor(DetectionPredictor):
14
14
  from ultralytics.utils import ASSETS
15
15
  from ultralytics.models.yolo.pose import PosePredictor
16
16
 
17
- args = dict(model='yolov8n-pose.pt', source=ASSETS)
17
+ args = dict(model="yolov8n-pose.pt", source=ASSETS)
18
18
  predictor = PosePredictor(overrides=args)
19
19
  predictor.predict_cli()
20
20
  ```
@@ -46,12 +46,10 @@ class PosePredictor(DetectionPredictor):
46
46
  orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
47
47
 
48
48
  results = []
49
- for i, pred in enumerate(preds):
50
- orig_img = orig_imgs[i]
49
+ for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
51
50
  pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
52
51
  pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
53
52
  pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
54
- img_path = self.batch[0][i]
55
53
  results.append(
56
54
  Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
57
55
  )
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from copy import copy
4
4
 
@@ -16,7 +16,7 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
16
16
  ```python
17
17
  from ultralytics.models.yolo.pose import PoseTrainer
18
18
 
19
- args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml', epochs=3)
19
+ args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3)
20
20
  trainer = PoseTrainer(overrides=args)
21
21
  trainer.train()
22
22
  ```
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from pathlib import Path
4
4
 
@@ -20,7 +20,7 @@ class PoseValidator(DetectionValidator):
20
20
  ```python
21
21
  from ultralytics.models.yolo.pose import PoseValidator
22
22
 
23
- args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
23
+ args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml")
24
24
  validator = PoseValidator(args=args)
25
25
  validator()
26
26
  ```
@@ -69,7 +69,7 @@ class PoseValidator(DetectionValidator):
69
69
  self.args.iou,
70
70
  labels=self.lb,
71
71
  multi_label=True,
72
- agnostic=self.args.single_cls,
72
+ agnostic=self.args.single_cls or self.args.agnostic_nms,
73
73
  max_det=self.args.max_det,
74
74
  nc=self.nc,
75
75
  )
@@ -81,7 +81,7 @@ class PoseValidator(DetectionValidator):
81
81
  is_pose = self.kpt_shape == [17, 3]
82
82
  nkpt = self.kpt_shape[0]
83
83
  self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
84
- self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])
84
+ self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
85
85
 
86
86
  def _prepare_batch(self, si, batch):
87
87
  """Prepares a batch for processing by converting keypoints to float and moving to device."""
@@ -118,6 +118,7 @@ class PoseValidator(DetectionValidator):
118
118
  cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
119
119
  nl = len(cls)
120
120
  stat["target_cls"] = cls
121
+ stat["target_img"] = cls.unique()
121
122
  if npr == 0:
122
123
  if nl:
123
124
  for k in self.stats.keys():
@@ -137,8 +138,8 @@ class PoseValidator(DetectionValidator):
137
138
  if nl:
138
139
  stat["tp"] = self._process_batch(predn, bbox, cls)
139
140
  stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
140
- if self.args.plots:
141
- self.confusion_matrix.process_batch(predn, bbox, cls)
141
+ if self.args.plots:
142
+ self.confusion_matrix.process_batch(predn, bbox, cls)
142
143
 
143
144
  for k in self.stats.keys():
144
145
  self.stats[k].append(stat[k])
@@ -146,24 +147,45 @@ class PoseValidator(DetectionValidator):
146
147
  # Save
147
148
  if self.args.save_json:
148
149
  self.pred_to_json(predn, batch["im_file"][si])
149
- # if self.args.save_txt:
150
- # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
150
+ if self.args.save_txt:
151
+ self.save_one_txt(
152
+ predn,
153
+ pred_kpts,
154
+ self.args.save_conf,
155
+ pbatch["ori_shape"],
156
+ self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
157
+ )
151
158
 
152
159
  def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
153
160
  """
154
- Return correct prediction matrix.
161
+ Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
155
162
 
156
163
  Args:
157
- detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
158
- Each detection is of the format: x1, y1, x2, y2, conf, class.
159
- labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
160
- Each label is of the format: class, x1, y1, x2, y2.
161
- pred_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing predicted keypoints.
162
- 51 corresponds to 17 keypoints each with 3 values.
163
- gt_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing ground truth keypoints.
164
+ detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
165
+ detection is of the format (x1, y1, x2, y2, conf, class).
166
+ gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
167
+ box is of the format (x1, y1, x2, y2).
168
+ gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
169
+ pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
170
+ 51 corresponds to 17 keypoints each having 3 values.
171
+ gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
164
172
 
165
173
  Returns:
166
- torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels.
174
+ torch.Tensor: A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
175
+ where N is the number of detections.
176
+
177
+ Example:
178
+ ```python
179
+ detections = torch.rand(100, 6) # 100 predictions: (x1, y1, x2, y2, conf, class)
180
+ gt_bboxes = torch.rand(50, 4) # 50 ground truth boxes: (x1, y1, x2, y2)
181
+ gt_cls = torch.randint(0, 2, (50,)) # 50 ground truth class indices
182
+ pred_kpts = torch.rand(100, 51) # 100 predicted keypoints
183
+ gt_kpts = torch.rand(50, 51) # 50 ground truth keypoints
184
+ correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
185
+ ```
186
+
187
+ Note:
188
+ `0.53` scale factor used in area computation is referenced from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
167
189
  """
168
190
  if pred_kpts is not None and gt_kpts is not None:
169
191
  # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
@@ -201,6 +223,18 @@ class PoseValidator(DetectionValidator):
201
223
  on_plot=self.on_plot,
202
224
  ) # pred
203
225
 
226
+ def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
227
+ """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
228
+ from ultralytics.engine.results import Results
229
+
230
+ Results(
231
+ np.zeros((shape[0], shape[1]), dtype=np.uint8),
232
+ path=None,
233
+ names=self.names,
234
+ boxes=predn[:, :6],
235
+ keypoints=pred_kpts,
236
+ ).save_txt(file, save_conf=save_conf)
237
+
204
238
  def pred_to_json(self, predn, filename):
205
239
  """Converts YOLO predictions to COCO JSON format."""
206
240
  stem = Path(filename).stem
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from .predict import SegmentationPredictor
4
4
  from .train import SegmentationTrainer
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from ultralytics.engine.results import Results
4
4
  from ultralytics.models.yolo.detect.predict import DetectionPredictor
@@ -14,7 +14,7 @@ class SegmentationPredictor(DetectionPredictor):
14
14
  from ultralytics.utils import ASSETS
15
15
  from ultralytics.models.yolo.segment import SegmentationPredictor
16
16
 
17
- args = dict(model='yolov8n-seg.pt', source=ASSETS)
17
+ args = dict(model="yolov8n-seg.pt", source=ASSETS)
18
18
  predictor = SegmentationPredictor(overrides=args)
19
19
  predictor.predict_cli()
20
20
  ```
@@ -42,9 +42,7 @@ class SegmentationPredictor(DetectionPredictor):
42
42
 
43
43
  results = []
44
44
  proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] # tuple if PyTorch model or array if exported
45
- for i, pred in enumerate(p):
46
- orig_img = orig_imgs[i]
47
- img_path = self.batch[0][i]
45
+ for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
48
46
  if not len(pred): # save empty boxes
49
47
  masks = None
50
48
  elif self.args.retina_masks:
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from copy import copy
4
4
 
@@ -16,7 +16,7 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
16
16
  ```python
17
17
  from ultralytics.models.yolo.segment import SegmentationTrainer
18
18
 
19
- args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
19
+ args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3)
20
20
  trainer = SegmentationTrainer(overrides=args)
21
21
  trainer.train()
22
22
  ```
@@ -1,4 +1,4 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from multiprocessing.pool import ThreadPool
4
4
  from pathlib import Path
@@ -22,7 +22,7 @@ class SegmentationValidator(DetectionValidator):
22
22
  ```python
23
23
  from ultralytics.models.yolo.segment import SegmentationValidator
24
24
 
25
- args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
25
+ args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml")
26
26
  validator = SegmentationValidator(args=args)
27
27
  validator()
28
28
  ```
@@ -48,10 +48,9 @@ class SegmentationValidator(DetectionValidator):
48
48
  self.plot_masks = []
49
49
  if self.args.save_json:
50
50
  check_requirements("pycocotools>=2.0.6")
51
- self.process = ops.process_mask_upsample # more accurate
52
- else:
53
- self.process = ops.process_mask # faster
54
- self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[])
51
+ # process_mask_native is more accurate (used when saving JSON/TXT); process_mask is faster
52
+ self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
53
+ self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
55
54
 
56
55
  def get_desc(self):
57
56
  """Return a formatted description of evaluation metrics."""
@@ -77,7 +76,7 @@ class SegmentationValidator(DetectionValidator):
77
76
  self.args.iou,
78
77
  labels=self.lb,
79
78
  multi_label=True,
80
- agnostic=self.args.single_cls,
79
+ agnostic=self.args.single_cls or self.args.agnostic_nms,
81
80
  max_det=self.args.max_det,
82
81
  nc=self.nc,
83
82
  )
@@ -112,6 +111,7 @@ class SegmentationValidator(DetectionValidator):
112
111
  cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
113
112
  nl = len(cls)
114
113
  stat["target_cls"] = cls
114
+ stat["target_img"] = cls.unique()
115
115
  if npr == 0:
116
116
  if nl:
117
117
  for k in self.stats.keys():
@@ -135,8 +135,8 @@ class SegmentationValidator(DetectionValidator):
135
135
  stat["tp_m"] = self._process_batch(
136
136
  predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True
137
137
  )
138
- if self.args.plots:
139
- self.confusion_matrix.process_batch(predn, bbox, cls)
138
+ if self.args.plots:
139
+ self.confusion_matrix.process_batch(predn, bbox, cls)
140
140
 
141
141
  for k in self.stats.keys():
142
142
  self.stats[k].append(stat[k])
@@ -147,14 +147,23 @@ class SegmentationValidator(DetectionValidator):
147
147
 
148
148
  # Save
149
149
  if self.args.save_json:
150
- pred_masks = ops.scale_image(
151
- pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
150
+ self.pred_to_json(
151
+ predn,
152
+ batch["im_file"][si],
153
+ ops.scale_image(
154
+ pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(),
155
+ pbatch["ori_shape"],
156
+ ratio_pad=batch["ratio_pad"][si],
157
+ ),
158
+ )
159
+ if self.args.save_txt:
160
+ self.save_one_txt(
161
+ predn,
162
+ pred_masks,
163
+ self.args.save_conf,
152
164
  pbatch["ori_shape"],
153
- ratio_pad=batch["ratio_pad"][si],
165
+ self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt",
154
166
  )
155
- self.pred_to_json(predn, batch["im_file"][si], pred_masks)
156
- # if self.args.save_txt:
157
- # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
158
167
 
159
168
  def finalize_metrics(self, *args, **kwargs):
160
169
  """Sets speed and confusion matrix for evaluation metrics."""
@@ -163,14 +172,34 @@ class SegmentationValidator(DetectionValidator):
163
172
 
164
173
  def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False):
165
174
  """
166
- Return correct prediction matrix.
175
+ Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
167
176
 
168
177
  Args:
169
- detections (array[N, 6]), x1, y1, x2, y2, conf, class
170
- labels (array[M, 5]), class, x1, y1, x2, y2
178
+ detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and
179
+ associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class].
180
+ gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates.
181
+ Each row is of the format [x1, y1, x2, y2].
182
+ gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices.
183
+ pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should
184
+ match the ground truth masks.
185
+ gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available.
186
+ overlap (bool): Flag indicating if overlapping masks should be considered.
187
+ masks (bool): Flag indicating if the batch contains mask data.
171
188
 
172
189
  Returns:
173
- correct (array[N, 10]), for 10 IoU levels
190
+ (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels.
191
+
192
+ Note:
193
+ - If `masks` is True, the function computes IoU between predicted and ground truth masks.
194
+ - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
195
+
196
+ Example:
197
+ ```python
198
+ detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]])
199
+ gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]])
200
+ gt_cls = torch.tensor([1, 0])
201
+ correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls)
202
+ ```
174
203
  """
175
204
  if masks:
176
205
  if overlap:
@@ -214,6 +243,18 @@ class SegmentationValidator(DetectionValidator):
214
243
  ) # pred
215
244
  self.plot_masks.clear()
216
245
 
246
+ def save_one_txt(self, predn, pred_masks, save_conf, shape, file):
247
+ """Save YOLO segmentation detections (boxes and masks) to a txt file in normalized coordinates."""
248
+ from ultralytics.engine.results import Results
249
+
250
+ Results(
251
+ np.zeros((shape[0], shape[1]), dtype=np.uint8),
252
+ path=None,
253
+ names=self.names,
254
+ boxes=predn[:, :6],
255
+ masks=pred_masks,
256
+ ).save_txt(file, save_conf=save_conf)
257
+
217
258
  def pred_to_json(self, predn, filename, pred_masks):
218
259
  """
219
260
  Save one JSON result.
@@ -0,0 +1,5 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from .train import WorldTrainer
4
+
5
+ __all__ = ["WorldTrainer"]