dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272) hide show
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,116 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from typing import List
4
+ from urllib.parse import urlsplit
5
+
6
+ import numpy as np
7
+
8
+
9
class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    This class provides a convenient interface for sending inference requests to a Triton Inference Server
    and processing the responses.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs, as reported by the model config
            (e.g. "TYPE_FP32").
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs, sorted alphabetically.
        metadata: The metadata associated with the model (None if the config carries no metadata).

    Methods:
        __call__: Call the model with the given inputs and return the outputs.

    Examples:
        Initialize a Triton client with HTTP
        >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
        Make inference with numpy arrays
        >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
    """

    # Model-config data type -> numpy dtype. Only types whose name, once the "TYPE_" prefix is stripped in
    # __call__, is passed directly as the InferInput datatype are listed here (e.g. "TYPE_INT32" -> "INT32").
    _TYPE_MAP = {
        "TYPE_BOOL": np.bool_,
        "TYPE_UINT8": np.uint8,
        "TYPE_UINT16": np.uint16,
        "TYPE_UINT32": np.uint32,
        "TYPE_UINT64": np.uint64,
        "TYPE_INT8": np.int8,
        "TYPE_INT16": np.int16,
        "TYPE_INT32": np.int32,
        "TYPE_INT64": np.int64,
        "TYPE_FP16": np.float16,
        "TYPE_FP32": np.float32,
        "TYPE_FP64": np.float64,
    }

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
        <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc'). Any value other than 'http' selects gRPC.

        Raises:
            KeyError: If a model input uses a data type not present in `_TYPE_MAP`.

        Examples:
            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
            >>> model = TritonRemoteModel(url="http://localhost:8000/yolov8")
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/")[0]  # first path segment is the model name
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            # The gRPC client wraps the JSON config under a "config" key
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [self._TYPE_MAP[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
        # NOTE(security): `eval` executes the metadata string taken from the server's model config, so a malicious
        # server could run arbitrary code here. Only connect to trusted servers; if the metadata is always a plain
        # Python literal, `ast.literal_eval` would be a safer parser.
        self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Args:
            *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type
                for the corresponding model input; arrays with a different dtype are cast to the model's dtype.

        Returns:
            (List[np.ndarray]): Model outputs cast to the dtype of the first input. Each element in the list
                corresponds to one of the model's output tensors, in `output_names` order.

        Raises:
            ValueError: If no input arrays are given.

        Examples:
            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
            >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
        """
        if not inputs:
            raise ValueError("TritonRemoteModel requires at least one input array.")

        infer_inputs = []
        input_format = inputs[0].dtype  # outputs are cast back to the caller's dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            # "TYPE_FP32" -> "FP32": the datatype string the client input class expects
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
@@ -0,0 +1,159 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
4
+ from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr
5
+
6
+
7
def run_ray_tune(
    model,
    space: dict = None,
    grace_period: int = 10,
    gpu_per_trial: int = None,
    max_samples: int = 10,
    **train_args,
):
    """
    Run hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space. When empty/None and not resuming, a built-in
            default space is used. The dict is copied, so the caller's object is never modified.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial.
        max_samples (int, optional): The maximum number of trials to run.
        **train_args (Any): Additional arguments to pass to the `train()` method. `resume` and `name` are consumed
            here to locate the tuning directory and are not forwarded to the per-trial `train()` calls.

    Returns:
        The results returned by `tuner.get_results()` (a Ray Tune result grid).

    Raises:
        ModuleNotFoundError: If Ray Tune is not installed.

    Examples:
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")  # Load a YOLO11n model

        Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
        >>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
    """
    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")

    try:
        checks.check_requirements("ray[tune]")

        import ray
        from ray import tune
        from ray.air import RunConfig
        from ray.air.integrations.wandb import WandbLoggerCallback
        from ray.tune.schedulers import ASHAScheduler
    except ImportError as e:
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"') from e

    try:
        import wandb

        assert hasattr(wandb, "__version__")
    except (ImportError, AssertionError):
        wandb = False  # W&B logging is optional; disable it when unavailable

    checks.check_version(ray.__version__, ">=2.0.0", "ray")
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": tune.uniform(1e-5, 1e-1),
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        "box": tune.uniform(0.02, 0.2),  # box loss gain
        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        "bgr": tune.uniform(0.0, 1.0),  # image channel BGR (probability)
        "mosaic": tune.uniform(0.0, 1.0),  # image mosaic (probability)
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "cutmix": tune.uniform(0.0, 1.0),  # image cutmix (probability)
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
    }

    # Put the model in ray store
    task = model.task
    model_in_store = ray.put(model)

    def _tune(config):
        """Train the YOLO model with the specified hyperparameters."""
        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
        model_to_train.reset_callbacks()
        # Read at trial time, i.e. after `resume`/`name` have been popped from train_args below
        config.update(train_args)
        results = model_to_train.train(**config)
        return results.results_dict

    # Get search space
    if not space and not train_args.get("resume"):
        space = default_space
        LOGGER.warning("search space not provided, using default search space.")

    # Get dataset
    data = train_args.get("data", TASK2DATA[task])
    # Build a new dict: never mutate the caller's `space`, and tolerate space=None when resuming
    space = {**(space or {}), "data": data}
    if "data" not in train_args:
        LOGGER.warning(f'data not provided, using default "data={data}".')

    # Define the trainable function with allocated resources
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

    # Define the ASHA scheduler for hyperparameter search
    asha_scheduler = ASHAScheduler(
        time_attr="epoch",
        metric=TASK2METRIC[task],
        mode="max",
        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
        grace_period=grace_period,
        reduction_factor=3,
    )

    # Define the callbacks for the hyperparameter search
    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

    # Create the Ray Tune hyperparameter search tuner.
    # Evaluation order matters: `**train_args` is unpacked before `resume` is popped, and get_cfg() runs before
    # `name` is popped; afterwards neither key remains in train_args for the per-trial train() calls.
    tune_dir = get_save_dir(
        get_cfg(
            DEFAULT_CFG,
            {**train_args, **{"exist_ok": train_args.pop("resume", False)}},  # resume w/ same tune_dir
        ),
        name=train_args.pop("name", "tune"),  # runs/{task}/{tune_dir}
    ).resolve()  # must be absolute dir
    tune_dir.mkdir(parents=True, exist_ok=True)
    if tune.Tuner.can_restore(tune_dir):
        LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")
        tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)
    else:
        tuner = tune.Tuner(
            trainable_with_resources,
            param_space=space,
            tune_config=tune.TuneConfig(
                scheduler=asha_scheduler,
                num_samples=max_samples,
                trial_name_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
                trial_dirname_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
            ),
            run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),
        )

    try:
        # Run the hyperparameter search and collect its results
        tuner.fit()
        results = tuner.get_results()
    finally:
        # Shut down Ray to clean up workers, even if the search fails
        ray.shutdown()

    return results