ultralytics 8.0.196.tar.gz → 8.0.198.tar.gz

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

This version of ultralytics might be problematic.

Files changed (186)
  1. {ultralytics-8.0.196/ultralytics.egg-info → ultralytics-8.0.198}/PKG-INFO +1 -1
  2. {ultralytics-8.0.196 → ultralytics-8.0.198}/tests/test_cli.py +2 -2
  3. {ultralytics-8.0.196 → ultralytics-8.0.198}/tests/test_cuda.py +3 -20
  4. {ultralytics-8.0.196 → ultralytics-8.0.198}/tests/test_engine.py +1 -1
  5. ultralytics-8.0.198/tests/test_integrations.py +26 -0
  6. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/__init__.py +1 -1
  7. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/__init__.py +4 -5
  8. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/augment.py +2 -2
  9. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/converter.py +12 -13
  10. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/dataset.py +1 -1
  11. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/engine/exporter.py +1 -1
  12. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/engine/trainer.py +2 -1
  13. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/hub/session.py +1 -1
  14. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/fastsam/predict.py +33 -2
  15. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/fastsam/prompt.py +38 -1
  16. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/fastsam/utils.py +5 -5
  17. ultralytics-8.0.198/ultralytics/models/fastsam/val.py +40 -0
  18. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/nas/model.py +20 -0
  19. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/nas/predict.py +23 -0
  20. ultralytics-8.0.198/ultralytics/models/nas/val.py +48 -0
  21. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/rtdetr/val.py +17 -5
  22. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/modules/decoders.py +26 -1
  23. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/modules/encoders.py +31 -3
  24. ultralytics-8.0.198/ultralytics/models/sam/modules/sam.py +64 -0
  25. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/modules/tiny_encoder.py +147 -45
  26. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/modules/transformer.py +47 -2
  27. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/predict.py +19 -2
  28. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/utils/loss.py +20 -2
  29. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/utils/ops.py +5 -5
  30. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/modules/block.py +33 -10
  31. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/modules/conv.py +16 -4
  32. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/modules/head.py +48 -17
  33. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/modules/transformer.py +2 -2
  34. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/tasks.py +7 -7
  35. ultralytics-8.0.198/ultralytics/trackers/utils/__init__.py +1 -0
  36. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/__init__.py +2 -1
  37. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/benchmarks.py +13 -0
  38. ultralytics-8.0.198/ultralytics/utils/callbacks/mlflow.py +107 -0
  39. ultralytics-8.0.198/ultralytics/utils/callbacks/wb.py +156 -0
  40. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/checks.py +4 -4
  41. ultralytics-8.0.198/ultralytics/utils/errors.py +22 -0
  42. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/files.py +1 -1
  43. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/instance.py +41 -3
  44. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/loss.py +22 -19
  45. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/metrics.py +106 -24
  46. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/tal.py +1 -1
  47. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/torch_utils.py +4 -2
  48. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/tuner.py +10 -4
  49. {ultralytics-8.0.196 → ultralytics-8.0.198/ultralytics.egg-info}/PKG-INFO +1 -1
  50. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics.egg-info/SOURCES.txt +1 -0
  51. ultralytics-8.0.196/ultralytics/engine/__init__.py +0 -0
  52. ultralytics-8.0.196/ultralytics/models/fastsam/val.py +0 -14
  53. ultralytics-8.0.196/ultralytics/models/nas/val.py +0 -24
  54. ultralytics-8.0.196/ultralytics/models/sam/modules/sam.py +0 -49
  55. ultralytics-8.0.196/ultralytics/utils/callbacks/mlflow.py +0 -67
  56. ultralytics-8.0.196/ultralytics/utils/callbacks/wb.py +0 -65
  57. ultralytics-8.0.196/ultralytics/utils/errors.py +0 -10
  58. {ultralytics-8.0.196 → ultralytics-8.0.198}/CONTRIBUTING.md +0 -0
  59. {ultralytics-8.0.196 → ultralytics-8.0.198}/LICENSE +0 -0
  60. {ultralytics-8.0.196 → ultralytics-8.0.198}/MANIFEST.in +0 -0
  61. {ultralytics-8.0.196 → ultralytics-8.0.198}/README.md +0 -0
  62. {ultralytics-8.0.196 → ultralytics-8.0.198}/README.zh-CN.md +0 -0
  63. {ultralytics-8.0.196 → ultralytics-8.0.198}/requirements.txt +0 -0
  64. {ultralytics-8.0.196 → ultralytics-8.0.198}/setup.cfg +0 -0
  65. {ultralytics-8.0.196 → ultralytics-8.0.198}/setup.py +0 -0
  66. {ultralytics-8.0.196 → ultralytics-8.0.198}/tests/conftest.py +0 -0
  67. {ultralytics-8.0.196 → ultralytics-8.0.198}/tests/test_python.py +0 -0
  68. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/assets/bus.jpg +0 -0
  69. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/assets/zidane.jpg +0 -0
  70. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/Argoverse.yaml +0 -0
  71. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/DOTAv2.yaml +0 -0
  72. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/GlobalWheat2020.yaml +0 -0
  73. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/ImageNet.yaml +0 -0
  74. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/Objects365.yaml +0 -0
  75. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/SKU-110K.yaml +0 -0
  76. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/VOC.yaml +0 -0
  77. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/VisDrone.yaml +0 -0
  78. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/coco-pose.yaml +0 -0
  79. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/coco.yaml +0 -0
  80. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/coco128-seg.yaml +0 -0
  81. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/coco128.yaml +0 -0
  82. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/coco8-pose.yaml +0 -0
  83. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/coco8-seg.yaml +0 -0
  84. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/coco8.yaml +0 -0
  85. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/open-images-v7.yaml +0 -0
  86. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/tiger-pose.yaml +0 -0
  87. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/datasets/xView.yaml +0 -0
  88. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/default.yaml +0 -0
  89. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +0 -0
  90. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +0 -0
  91. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v3/yolov3-spp.yaml +0 -0
  92. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v3/yolov3-tiny.yaml +0 -0
  93. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v3/yolov3.yaml +0 -0
  94. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v5/yolov5-p6.yaml +0 -0
  95. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v5/yolov5.yaml +0 -0
  96. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v6/yolov6.yaml +0 -0
  97. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-cls.yaml +0 -0
  98. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-p2.yaml +0 -0
  99. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-p6.yaml +0 -0
  100. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +0 -0
  101. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-pose.yaml +0 -0
  102. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +0 -0
  103. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +0 -0
  104. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8-seg.yaml +0 -0
  105. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/models/v8/yolov8.yaml +0 -0
  106. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/trackers/botsort.yaml +0 -0
  107. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/cfg/trackers/bytetrack.yaml +0 -0
  108. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/__init__.py +0 -0
  109. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/annotator.py +0 -0
  110. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/base.py +0 -0
  111. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/build.py +0 -0
  112. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/loaders.py +0 -0
  113. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/data/utils.py +0 -0
  114. {ultralytics-8.0.196/ultralytics/models/sam/modules → ultralytics-8.0.198/ultralytics/engine}/__init__.py +0 -0
  115. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/engine/model.py +0 -0
  116. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/engine/predictor.py +0 -0
  117. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/engine/results.py +0 -0
  118. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/engine/tuner.py +0 -0
  119. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/engine/validator.py +0 -0
  120. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/hub/__init__.py +0 -0
  121. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/hub/auth.py +0 -0
  122. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/hub/utils.py +0 -0
  123. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/__init__.py +0 -0
  124. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/fastsam/__init__.py +0 -0
  125. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/fastsam/model.py +0 -0
  126. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/nas/__init__.py +0 -0
  127. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/rtdetr/__init__.py +0 -0
  128. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/rtdetr/model.py +0 -0
  129. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/rtdetr/predict.py +0 -0
  130. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/rtdetr/train.py +0 -0
  131. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/__init__.py +0 -0
  132. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/amg.py +0 -0
  133. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/build.py +0 -0
  134. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/sam/model.py +0 -0
  135. {ultralytics-8.0.196/ultralytics/models/utils → ultralytics-8.0.198/ultralytics/models/sam/modules}/__init__.py +0 -0
  136. {ultralytics-8.0.196/ultralytics/trackers → ultralytics-8.0.198/ultralytics/models}/utils/__init__.py +0 -0
  137. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/__init__.py +0 -0
  138. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/classify/__init__.py +0 -0
  139. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/classify/predict.py +0 -0
  140. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/classify/train.py +0 -0
  141. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/classify/val.py +0 -0
  142. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/detect/__init__.py +0 -0
  143. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/detect/predict.py +0 -0
  144. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/detect/train.py +0 -0
  145. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/detect/val.py +0 -0
  146. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/model.py +0 -0
  147. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/pose/__init__.py +0 -0
  148. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/pose/predict.py +0 -0
  149. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/pose/train.py +0 -0
  150. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/pose/val.py +0 -0
  151. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/segment/__init__.py +0 -0
  152. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/segment/predict.py +0 -0
  153. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/segment/train.py +0 -0
  154. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/models/yolo/segment/val.py +0 -0
  155. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/__init__.py +0 -0
  156. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/autobackend.py +0 -0
  157. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/modules/__init__.py +0 -0
  158. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/nn/modules/utils.py +0 -0
  159. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/__init__.py +0 -0
  160. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/basetrack.py +0 -0
  161. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/bot_sort.py +0 -0
  162. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/byte_tracker.py +0 -0
  163. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/track.py +0 -0
  164. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/utils/gmc.py +0 -0
  165. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/utils/kalman_filter.py +0 -0
  166. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/trackers/utils/matching.py +0 -0
  167. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/autobatch.py +0 -0
  168. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/__init__.py +0 -0
  169. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/base.py +0 -0
  170. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/clearml.py +0 -0
  171. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/comet.py +0 -0
  172. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/dvc.py +0 -0
  173. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/hub.py +0 -0
  174. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/neptune.py +0 -0
  175. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/raytune.py +0 -0
  176. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/callbacks/tensorboard.py +0 -0
  177. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/dist.py +0 -0
  178. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/downloads.py +0 -0
  179. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/ops.py +0 -0
  180. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/patches.py +0 -0
  181. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/plotting.py +0 -0
  182. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics/utils/triton.py +0 -0
  183. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics.egg-info/dependency_links.txt +0 -0
  184. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics.egg-info/entry_points.txt +0 -0
  185. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics.egg-info/requires.txt +0 -0
  186. {ultralytics-8.0.196 → ultralytics-8.0.198}/ultralytics.egg-info/top_level.txt +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ultralytics
- Version: 8.0.196
+ Version: 8.0.198
  Summary: Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
  Home-page: https://github.com/ultralytics/ultralytics
  Author: Ultralytics

tests/test_cli.py

@@ -97,8 +97,8 @@ def test_fastsam(task='segment', model=WEIGHTS_DIR / 'FastSAM-s.pt', data='coco8
  ann = prompt_process.text_prompt(text='a photo of a dog')

  # Point prompt
- # points default [[0,0]] [[x1,y1],[x2,y2]]
- # point_label default [0] [1,0] 0:background, 1:foreground
+ # Points default [[0,0]] [[x1,y1],[x2,y2]]
+ # Point_label default [0] [1,0] 0:background, 1:foreground
  ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
  prompt_process.plot(annotations=ann, output='./')

tests/test_cuda.py

@@ -1,16 +1,13 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

- import contextlib
-
  import pytest
  import torch

  from ultralytics import YOLO, download
- from ultralytics.utils import ASSETS, DATASETS_DIR, WEIGHTS_DIR
- from ultralytics.utils.checks import cuda_device_count, cuda_is_available
+ from ultralytics.utils import ASSETS, DATASETS_DIR, WEIGHTS_DIR, checks

- CUDA_IS_AVAILABLE = cuda_is_available()
- CUDA_DEVICE_COUNT = cuda_device_count()
+ CUDA_IS_AVAILABLE = checks.cuda_is_available()
+ CUDA_DEVICE_COUNT = checks.cuda_device_count()

  MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt' # test spaces in path
  DATA = 'coco8.yaml'
@@ -107,20 +104,6 @@ def test_predict_sam():
  predictor.reset_image()


- @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
- def test_model_ray_tune():
- """Tune YOLO model with Ray optimization library."""
- with contextlib.suppress(RuntimeError): # RuntimeError may be caused by out-of-memory
- YOLO('yolov8n-cls.yaml').tune(use_ray=True,
- data='imagenet10',
- grace_period=1,
- iterations=1,
- imgsz=32,
- epochs=1,
- plots=False,
- device='cpu')
-
-
  @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
  def test_model_tune():
  """Tune YOLO model for performance."""

tests/test_engine.py

@@ -70,7 +70,7 @@ def test_segment():
  CFG.imgsz = 32
  # YOLO(CFG_SEG).train(**overrides) # works

- # trainer
+ # Trainer
  trainer = segment.SegmentationTrainer(overrides=overrides)
  trainer.add_callback('on_train_start', test_func)
  assert test_func in trainer.callbacks['on_train_start'], 'callback test failed'

tests/test_integrations.py (new file)

@@ -0,0 +1,26 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+
+ import pytest
+
+ from ultralytics import YOLO
+ from ultralytics.utils import SETTINGS, checks
+
+
+ @pytest.mark.skipif(not checks.check_requirements('ray', install=False), reason='RayTune not installed')
+ def test_model_ray_tune():
+ """Tune YOLO model with Ray optimization library."""
+ YOLO('yolov8n-cls.yaml').tune(use_ray=True,
+ data='imagenet10',
+ grace_period=1,
+ iterations=1,
+ imgsz=32,
+ epochs=1,
+ plots=False,
+ device='cpu')
+
+
+ @pytest.mark.skipif(not checks.check_requirements('mlflow', install=False), reason='MLflow not installed')
+ def test_mlflow():
+ """Test training with MLflow tracking enabled."""
+ SETTINGS['mlflow'] = True
+ YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=3, plots=False, device='cpu')
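
The two new tests above also read as usage documentation: Ray Tune is driven through `tune(use_ray=True, ...)`, and MLflow tracking is a settings toggle rather than a CLI flag. A minimal hedged sketch of enabling MLflow outside the test suite, reusing the exact values from `test_mlflow` (the tiny model, dataset and epoch count are only illustrative):

```python
from ultralytics import YOLO
from ultralytics.utils import SETTINGS

# Toggle the MLflow integration via Ultralytics settings; the callback added in
# ultralytics/utils/callbacks/mlflow.py is expected to activate only when the
# 'mlflow' package is importable.
SETTINGS['mlflow'] = True

# Any subsequent training run is then tracked; arguments mirror test_mlflow() above.
YOLO('yolov8n-cls.yaml').train(data='imagenet10', imgsz=32, epochs=3, plots=False, device='cpu')
```
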
ultralytics/__init__.py

@@ -1,6 +1,6 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

- __version__ = '8.0.196'
+ __version__ = '8.0.198'

  from ultralytics.models import RTDETR, SAM, YOLO
  from ultralytics.models.fastsam import FastSAM

ultralytics/cfg/__init__.py

@@ -7,9 +7,9 @@ from pathlib import Path
  from types import SimpleNamespace
  from typing import Dict, List, Union

- from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, SETTINGS,
- SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks, colorstr,
- deprecation_warn, yaml_load, yaml_print)
+ from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, ROOT, RUNS_DIR,
+ SETTINGS, SETTINGS_YAML, TESTS_RUNNING, IterableSimpleNamespace, __version__, checks,
+ colorstr, deprecation_warn, yaml_load, yaml_print)

  # Define valid tasks and modes
  MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
@@ -153,8 +153,7 @@ def get_save_dir(args, name=None):
  else:
  from ultralytics.utils.files import increment_path

- project = args.project or (ROOT /
- '../tests/tmp/runs' if TESTS_RUNNING else Path(SETTINGS['runs_dir'])) / args.task
+ project = args.project or (ROOT.parent / 'tests/tmp/runs' if TESTS_RUNNING else RUNS_DIR) / args.task
  name = name or args.name or f'{args.mode}'
  save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)

ultralytics/data/augment.py

@@ -491,7 +491,7 @@ class RandomPerspective:
  border = labels.pop('mosaic_border', self.border)
  self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h
  # M is affine matrix
- # scale for func:`box_candidates`
+ # Scale for func:`box_candidates`
  img, M, scale = self.affine_transform(img, border)

  bboxes = self.apply_bboxes(instances.bboxes, M)
@@ -894,7 +894,7 @@ class Format:
  return labels

  def _format_img(self, img):
- """Format the image for YOLOv5 from Numpy array to PyTorch tensor."""
+ """Format the image for YOLO from Numpy array to PyTorch tensor."""
  if len(img.shape) < 3:
  img = np.expand_dims(img, -1)
  img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])

ultralytics/data/converter.py

@@ -1,14 +1,14 @@
  # Ultralytics YOLO 🚀, AGPL-3.0 license

  import json
- import shutil
  from collections import defaultdict
  from pathlib import Path

  import cv2
  import numpy as np

- from ultralytics.utils import TQDM
+ from ultralytics.utils import LOGGER, TQDM
+ from ultralytics.utils.files import increment_path


  def coco91_to_coco80_class():
@@ -48,12 +48,12 @@ def coco80_to_coco91_class(): #


  def convert_coco(labels_dir='../coco/annotations/',
- save_dir='.',
+ save_dir='coco_converted/',
  use_segments=False,
  use_keypoints=False,
  cls91to80=True):
  """
- Converts COCO dataset annotations to a format suitable for training YOLOv5 models.
+ Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.

  Args:
  labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
@@ -74,9 +74,7 @@ def convert_coco(labels_dir='../coco/annotations/',
  """

  # Create dataset directory
- save_dir = Path(save_dir)
- if save_dir.exists():
- shutil.rmtree(save_dir) # delete dir
+ save_dir = increment_path(save_dir) # increment if save directory already exists
  for p in save_dir / 'labels', save_dir / 'images':
  p.mkdir(parents=True, exist_ok=True) # make dir

@@ -147,6 +145,8 @@ def convert_coco(labels_dir='../coco/annotations/',
  if use_segments and len(segments[i]) > 0 else bboxes[i]), # cls, box or segments
  file.write(('%g ' * len(line)).rstrip() % line + '\n')

+ LOGGER.info(f'COCO data converted successfully.\nResults saved to {save_dir.resolve()}')
+

  def convert_dota_to_yolo_obb(dota_root_path: str):
  """
@@ -271,26 +271,25 @@ def merge_multi_segment(segments):
  segments = [np.array(i).reshape(-1, 2) for i in segments]
  idx_list = [[] for _ in range(len(segments))]

- # record the indexes with min distance between each segment
+ # Record the indexes with min distance between each segment
  for i in range(1, len(segments)):
  idx1, idx2 = min_index(segments[i - 1], segments[i])
  idx_list[i - 1].append(idx1)
  idx_list[i].append(idx2)

- # use two round to connect all the segments
+ # Use two round to connect all the segments
  for k in range(2):
- # forward connection
+ # Forward connection
  if k == 0:
  for i, idx in enumerate(idx_list):
- # middle segments have two indexes
- # reverse the index of middle segments
+ # Middle segments have two indexes, reverse the index of middle segments
  if len(idx) == 2 and idx[0] > idx[1]:
  idx = idx[::-1]
  segments[i] = segments[i][::-1, :]

  segments[i] = np.roll(segments[i], -idx[0], axis=0)
  segments[i] = np.concatenate([segments[i], segments[i][:1]])
- # deal with the first segment and the last one
+ # Deal with the first segment and the last one
  if i in [0, len(idx_list) - 1]:
  s.append(segments[i])
  else:
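
A practical consequence of the converter changes above: `convert_coco` now writes to `coco_converted/` by default and routes the path through `increment_path`, so repeated conversions no longer delete earlier output the way the `shutil.rmtree` call in 8.0.196 did. A hedged sketch (the incremented directory name is an assumption based on `increment_path`'s usual numbering):

```python
from ultralytics.data.converter import convert_coco

# First call writes labels/ and images/ under 'coco_converted/'.
convert_coco(labels_dir='../coco/annotations/', use_segments=True)

# A second call gets a fresh, incremented directory (e.g. 'coco_converted2/')
# instead of wiping the previous results.
convert_coco(labels_dir='../coco/annotations/', use_segments=True)
```
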
ultralytics/data/dataset.py

@@ -162,7 +162,7 @@ class YOLODataset(BaseDataset):
  def update_labels_info(self, label):
  """Custom your label format here."""
  # NOTE: cls is not with bboxes now, classification and semantic segmentation need an independent cls label
- # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
+ # We can make it also support classification and semantic segmentation by add or remove some dict keys there.
  bboxes = label.pop('bboxes')
  segments = label.pop('segments')
  keypoints = label.pop('keypoints', None)

ultralytics/engine/exporter.py

@@ -140,7 +140,7 @@ class Exporter:
  Args:
  cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
  overrides (dict, optional): Configuration overrides. Defaults to None.
- _callbacks (list, optional): List of callback functions. Defaults to None.
+ _callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
  """
  self.args = get_cfg(cfg, overrides)
  if self.args.format.lower() in ('coreml', 'mlmodel'): # fix attempt for protobuf<3.20.x errors

ultralytics/engine/trainer.py

@@ -91,6 +91,7 @@ class BaseTrainer:

  # Dirs
  self.save_dir = get_save_dir(self.args)
+ self.args.name = self.save_dir.name # update name for loggers
  self.wdir = self.save_dir / 'weights' # weights dir
  if RANK in (-1, 0):
  self.wdir.mkdir(parents=True, exist_ok=True) # make dir
@@ -526,7 +527,7 @@ class BaseTrainer:

  # TODO: may need to put these following functions into callback
  def plot_training_samples(self, batch, ni):
- """Plots training samples during YOLOv5 training."""
+ """Plots training samples during YOLO training."""
  pass

  def plot_training_labels(self):

ultralytics/hub/session.py

@@ -23,7 +23,7 @@ class HUBTrainingSession:

  Attributes:
  agent_id (str): Identifier for the instance communicating with the server.
- model_id (str): Identifier for the YOLOv5 model being trained.
+ model_id (str): Identifier for the YOLO model being trained.
  model_url (str): URL for the model in Ultralytics HUB.
  api_url (str): API URL for the model in Ultralytics HUB.
  auth_header (dict): Authentication header for the Ultralytics HUB API requests.

ultralytics/models/fastsam/predict.py

@@ -9,14 +9,45 @@ from ultralytics.utils import DEFAULT_CFG, ops


  class FastSAMPredictor(DetectionPredictor):
+ """
+ FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
+ YOLO framework.
+
+ This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
+ It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
+ for single-class segmentation.
+
+ Attributes:
+ cfg (dict): Configuration parameters for prediction.
+ overrides (dict, optional): Optional parameter overrides for custom behavior.
+ _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
+ """

  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
- """Initializes FastSAMPredictor class by inheriting from DetectionPredictor and setting task to 'segment'."""
+ """
+ Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.
+
+ Args:
+ cfg (dict): Configuration parameters for prediction.
+ overrides (dict, optional): Optional parameter overrides for custom behavior.
+ _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
+ """
  super().__init__(cfg, overrides, _callbacks)
  self.args.task = 'segment'

  def postprocess(self, preds, img, orig_imgs):
- """Postprocesses the predictions, applies non-max suppression, scales the boxes, and returns the results."""
+ """
+ Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
+ size, and returns the final results.
+
+ Args:
+ preds (list): The raw output predictions from the model.
+ img (torch.Tensor): The processed image tensor.
+ orig_imgs (list | torch.Tensor): The original image or list of images.
+
+ Returns:
+ (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
+ """
  p = ops.non_max_suppression(
  preds[0],
  self.args.conf,

ultralytics/models/fastsam/prompt.py

@@ -13,6 +13,15 @@ from ultralytics.utils import TQDM


  class FastSAMPrompt:
+ """
+ Fast Segment Anything Model class for image annotation and visualization.
+
+ Attributes:
+ device (str): Computing device ('cuda' or 'cpu').
+ results: Object detection or segmentation results.
+ source: Source image or image path.
+ clip: CLIP model for linear assignment.
+ """

  def __init__(self, source, results, device='cuda') -> None:
  """Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
@@ -92,12 +101,26 @@ class FastSAMPrompt:
  better_quality=True,
  retina=False,
  with_contours=True):
+ """
+ Plots annotations, bounding boxes, and points on images and saves the output.
+
+ Args:
+ annotations (list): Annotations to be plotted.
+ output (str or Path): Output directory for saving the plots.
+ bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
+ points (list, optional): Points to be plotted. Defaults to None.
+ point_label (list, optional): Labels for the points. Defaults to None.
+ mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
+ better_quality (bool, optional): Whether to apply morphological transformations for better mask quality. Defaults to True.
+ retina (bool, optional): Whether to use retina mask. Defaults to False.
+ with_contours (bool, optional): Whether to plot contours. Defaults to True.
+ """
  pbar = TQDM(annotations, total=len(annotations))
  for ann in pbar:
  result_name = os.path.basename(ann.path)
  image = ann.orig_img[..., ::-1] # BGR to RGB
  original_h, original_w = ann.orig_shape
- # for macOS only
+ # For macOS only
  # plt.switch_backend('TkAgg')
  plt.figure(figsize=(original_w / 100, original_h / 100))
  # Add subplot with no margin.
@@ -160,6 +183,20 @@ class FastSAMPrompt:
  target_height=960,
  target_width=960,
  ):
+ """
+ Quickly shows the mask annotations on the given matplotlib axis.
+
+ Args:
+ annotation (array-like): Mask annotation.
+ ax (matplotlib.axes.Axes): Matplotlib axis.
+ random_color (bool, optional): Whether to use random color for masks. Defaults to False.
+ bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
+ points (list, optional): Points to be plotted. Defaults to None.
+ pointlabel (list, optional): Labels for the points. Defaults to None.
+ retinamask (bool, optional): Whether to use retina mask. Defaults to True.
+ target_height (int, optional): Target height for resizing. Defaults to 960.
+ target_width (int, optional): Target width for resizing. Defaults to 960.
+ """
  n, h, w = annotation.shape # batch, height, width

  areas = np.sum(annotation, axis=(1, 2))

ultralytics/models/fastsam/utils.py

@@ -42,23 +42,23 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
  high_iou_indices (torch.Tensor): Indices of boxes with IoU > thres
  """
  boxes = adjust_bboxes_to_image_border(boxes, image_shape)
- # obtain coordinates for intersections
+ # Obtain coordinates for intersections
  x1 = torch.max(box1[0], boxes[:, 0])
  y1 = torch.max(box1[1], boxes[:, 1])
  x2 = torch.min(box1[2], boxes[:, 2])
  y2 = torch.min(box1[3], boxes[:, 3])

- # compute the area of intersection
+ # Compute the area of intersection
  intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

- # compute the area of both individual boxes
+ # Compute the area of both individual boxes
  box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
  box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

- # compute the area of union
+ # Compute the area of union
  union = box1_area + box2_area - intersection

- # compute the IoU
+ # Compute the IoU
  iou = intersection / union # Should be shape (n, )
  if raw_output:
  return 0 if iou.numel() == 0 else iou
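
For reference, a small hedged sketch of calling this FastSAM helper directly; the signature and the index-style return value follow the hunk above, while the tensor values are made up:

```python
import torch

from ultralytics.models.fastsam.utils import bbox_iou

box1 = torch.tensor([100.0, 100.0, 300.0, 300.0])     # query box, xyxy
boxes = torch.tensor([[102.0, 102.0, 298.0, 298.0],   # heavy overlap with box1
                      [400.0, 400.0, 500.0, 500.0]])  # no overlap
# Returns the indices of `boxes` whose IoU with `box1` exceeds iou_thres
# (boxes near the border are first snapped by adjust_bboxes_to_image_border).
idx = bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640))
```
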
ultralytics/models/fastsam/val.py (new file)

@@ -0,0 +1,40 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+
+ from ultralytics.models.yolo.segment import SegmentationValidator
+ from ultralytics.utils.metrics import SegmentMetrics
+
+
+ class FastSAMValidator(SegmentationValidator):
+ """
+ Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
+
+ Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
+ sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
+ to avoid errors during validation.
+
+ Attributes:
+ dataloader: The data loader object used for validation.
+ save_dir (str): The directory where validation results will be saved.
+ pbar: A progress bar object.
+ args: Additional arguments for customization.
+ _callbacks: List of callback functions to be invoked during validation.
+ """
+
+ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
+ """
+ Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
+ save_dir (Path, optional): Directory to save results.
+ pbar (tqdm.tqdm): Progress bar for displaying progress.
+ args (SimpleNamespace): Configuration for the validator.
+ _callbacks (dict): Dictionary to store various callback functions.
+
+ Notes:
+ Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
+ """
+ super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+ self.args.task = 'segment'
+ self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
+ self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

ultralytics/models/nas/model.py

@@ -23,6 +23,26 @@ from .val import NASValidator


  class NAS(Model):
+ """
+ YOLO NAS model for object detection.
+
+ This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
+ It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
+
+ Example:
+ ```python
+ from ultralytics import NAS
+
+ model = NAS('yolo_nas_s')
+ results = model.predict('ultralytics/assets/bus.jpg')
+ ```
+
+ Attributes:
+ model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
+
+ Note:
+ YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
+ """

  def __init__(self, model='yolo_nas_s.pt') -> None:
  """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""

ultralytics/models/nas/predict.py

@@ -8,6 +8,29 @@ from ultralytics.utils import ops


  class NASPredictor(BasePredictor):
+ """
+ Ultralytics YOLO NAS Predictor for object detection.
+
+ This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
+ raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
+ scaling the bounding boxes to fit the original image dimensions.
+
+ Attributes:
+ args (Namespace): Namespace containing various configurations for post-processing.
+
+ Example:
+ ```python
+ from ultralytics import NAS
+
+ model = NAS('yolo_nas_s')
+ predictor = model.predictor
+ # Assumes that raw_preds, img, orig_imgs are available
+ results = predictor.postprocess(raw_preds, img, orig_imgs)
+ ```
+
+ Note:
+ Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
+ """

  def postprocess(self, preds_in, img, orig_imgs):
  """Postprocess predictions and returns a list of Results objects."""

ultralytics/models/nas/val.py (new file)

@@ -0,0 +1,48 @@
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
+
+ import torch
+
+ from ultralytics.models.yolo.detect import DetectionValidator
+ from ultralytics.utils import ops
+
+ __all__ = ['NASValidator']
+
+
+ class NASValidator(DetectionValidator):
+ """
+ Ultralytics YOLO NAS Validator for object detection.
+
+ Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
+ generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
+ ultimately producing the final detections.
+
+ Attributes:
+ args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU thresholds.
+ lb (torch.Tensor): Optional tensor for multilabel NMS.
+
+ Example:
+ ```python
+ from ultralytics import NAS
+
+ model = NAS('yolo_nas_s')
+ validator = model.validator
+ # Assumes that raw_preds are available
+ final_preds = validator.postprocess(raw_preds)
+ ```
+
+ Note:
+ This class is generally not instantiated directly but is used internally within the `NAS` class.
+ """
+
+ def postprocess(self, preds_in):
+ """Apply Non-maximum suppression to prediction outputs."""
+ boxes = ops.xyxy2xywh(preds_in[0][0])
+ preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
+ return ops.non_max_suppression(preds,
+ self.args.conf,
+ self.args.iou,
+ labels=self.lb,
+ multi_label=False,
+ agnostic=self.args.single_cls,
+ max_det=self.args.max_det,
+ max_time_img=0.5)

ultralytics/models/rtdetr/val.py

@@ -12,14 +12,19 @@ from ultralytics.utils import colorstr, ops
  __all__ = 'RTDETRValidator', # tuple or list


- # TODO: Temporarily RT-DETR does not need padding.
  class RTDETRDataset(YOLODataset):
+ """
+ Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
+
+ This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
+ real-time detection and tracking tasks.
+ """

  def __init__(self, *args, data=None, **kwargs):
  """Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
  super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)

- # NOTE: add stretch version load_image for rtdetr mosaic
+ # NOTE: add stretch version load_image for RTDETR mosaic
  def load_image(self, i, rect_mode=False):
  """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
  return super().load_image(i=i, rect_mode=rect_mode)
@@ -46,7 +51,11 @@ class RTDETRDataset(YOLODataset):

  class RTDETRValidator(DetectionValidator):
  """
- A class extending the DetectionValidator class for validation based on an RT-DETR detection model.
+ RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
+ the RT-DETR (Real-Time DETR) object detection model.
+
+ The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
+ post-processing, and updates evaluation metrics accordingly.

  Example:
  ```python
@@ -56,6 +65,9 @@ class RTDETRValidator(DetectionValidator):
  validator = RTDETRValidator(args=args)
  validator()
  ```
+
+ Note:
+ For further details on the attributes and methods, refer to the parent DetectionValidator class.
  """

  def build_dataset(self, img_path, mode='val', batch=None):
@@ -87,10 +99,10 @@ class RTDETRValidator(DetectionValidator):
  for i, bbox in enumerate(bboxes): # (300, 4)
  bbox = ops.xywh2xyxy(bbox)
  score, cls = scores[i].max(-1) # (300, )
- # Do not need threshold for evaluation as only got 300 boxes here.
+ # Do not need threshold for evaluation as only got 300 boxes here
  # idx = score > self.args.conf
  pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
- # sort by confidence to correctly get internal metrics.
+ # Sort by confidence to correctly get internal metrics
  pred = pred[score.argsort(descending=True)]
  outputs[i] = pred # [idx]
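
The new RTDETRValidator docstring shows direct construction, but validation is normally reached through the high-level RTDETR model API, which should route to this validator internally. A hedged sketch (the weight file and dataset names are assumptions, not taken from this diff):

```python
from ultralytics import RTDETR

# Validate an RT-DETR checkpoint; 'rtdetr-l.pt' and 'coco8.yaml' are illustrative choices.
model = RTDETR('rtdetr-l.pt')
metrics = model.val(data='coco8.yaml', imgsz=640)
```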