dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272)
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,132 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils
4
+
5
try:
    # Enable TensorBoard logging only outside pytest and only when the user setting allows it. Any failed assert or
    # import below lands in the except branch, which sets SummaryWriter = None so the `callbacks` dict at the bottom
    # of this module is empty and none of the functions here are registered.
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
    WRITER = None  # TensorBoard SummaryWriter instance
    PREFIX = colorstr("TensorBoard: ")

    # Imports below only required if TensorBoard enabled
    import warnings
    from copy import deepcopy

    import torch
    from torch.utils.tensorboard import SummaryWriter

except (ImportError, AssertionError, TypeError, AttributeError):
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
    # NOTE: on this path WRITER/PREFIX may never have been assigned; that is safe only because no callbacks are
    # registered when SummaryWriter is falsy.
    SummaryWriter = None
22
+
23
+
24
def _log_scalars(scalars: dict, step: int = 0) -> None:
    """
    Write every entry of a dictionary to TensorBoard as a scalar.

    Does nothing until a SummaryWriter has been created (i.e. while the module-level WRITER is falsy).

    Args:
        scalars (dict): Mapping of scalar tag -> scalar value to record.
        step (int): Global step attached to each scalar; TensorBoard uses it as the x-axis value.

    Examples:
        >>> # Log training metrics
        >>> _log_scalars({"loss": 0.5, "accuracy": 0.95}, step=100)
    """
    if not WRITER:
        return
    for tag, value in scalars.items():
        WRITER.add_scalar(tag, value, step)
41
+
42
+
43
def _log_tensorboard_graph(trainer) -> None:
    """
    Log model graph to TensorBoard.

    This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input
    tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex
    approach for models like RTDETR that may require special handling.

    The dummy input's channel count comes from the dataset dict (``trainer.data["channels"]``, e.g. for multispectral
    datasets). It falls back to 3 (RGB) when no usable integer value is present, so the previous RGB behavior is
    unchanged while non-RGB models can also be traced.

    Args:
        trainer (BaseTrainer): The trainer object containing the model to visualize. Must have attributes:
            - model: PyTorch model to visualize
            - args: Configuration arguments with 'imgsz' attribute
            - data (optional): dataset dict; a positive integer 'channels' entry sets the dummy input's channels

    Notes:
        This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
        Exceptions raised while tracing are not propagated: a failure of the simple method triggers the fallback,
        and a failure of the fallback is reported with LOGGER.warning.
        The simple method calls trainer.model.eval() and does not switch the model back to train mode.
    """
    # Input image
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
    data = getattr(trainer, "data", None)
    channels = data.get("channels") if isinstance(data, dict) else None
    if not isinstance(channels, int) or channels < 1:
        channels = 3  # default to RGB when the dataset does not declare a usable channel count
    p = next(trainer.model.parameters())  # for device, type
    im = torch.zeros((1, channels, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning

        # Try simple method first (YOLO)
        try:
            # place in .eval() mode to avoid BatchNorm statistics changes
            # NOTE(review): train mode is not restored here; presumably the trainer calls model.train() before the
            # next training step — confirm.
            trainer.model.eval()
            WRITER.add_graph(torch.jit.trace(torch_utils.de_parallel(trainer.model), im, strict=False), [])
            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            return

        except Exception:
            # Fallback to TorchScript export steps (RTDETR), applied to a deep copy so the training model is untouched
            try:
                model = deepcopy(torch_utils.de_parallel(trainer.model))
                model.eval()
                model = model.fuse(verbose=False)
                for m in model.modules():
                    if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
                        m.export = True
                        m.format = "torchscript"
                model(im)  # dry run
                WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
                LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            except Exception as e:
                LOGGER.warning(f"{PREFIX}TensorBoard graph visualization failure {e}")
93
+
94
+
95
def on_pretrain_routine_start(trainer) -> None:
    """Create the module-level SummaryWriter in the trainer's save directory when TensorBoard is available."""
    global WRITER
    if not SummaryWriter:
        return
    try:
        WRITER = SummaryWriter(str(trainer.save_dir))
        LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
    except Exception as e:
        # Best-effort: a writer failure must not abort training, it only disables TensorBoard for this run
        LOGGER.warning(f"{PREFIX}TensorBoard not initialized correctly, not logging this run. {e}")
104
+
105
+
106
def on_train_start(trainer) -> None:
    """Add the model graph to TensorBoard at the start of training, provided a writer exists."""
    if not WRITER:
        return
    _log_tensorboard_graph(trainer)
110
+
111
+
112
def on_train_epoch_end(trainer) -> None:
    """Record the epoch's labelled training losses and the learning rates as TensorBoard scalars."""
    step = trainer.epoch + 1  # TensorBoard step is trainer.epoch shifted by one
    losses = trainer.label_loss_items(trainer.tloss, prefix="train")
    _log_scalars(losses, step)
    _log_scalars(trainer.lr, step)
116
+
117
+
118
def on_fit_epoch_end(trainer) -> None:
    """Record the trainer's current metrics dict as TensorBoard scalars at the end of a fit epoch."""
    epoch_step = trainer.epoch + 1
    _log_scalars(trainer.metrics, epoch_step)
121
+
122
+
123
# Trainer event -> handler; nothing is registered when TensorBoard is unavailable or disabled.
if SummaryWriter:
    callbacks = {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_start": on_train_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_epoch_end": on_train_epoch_end,
    }
else:
    callbacks = {}
@@ -0,0 +1,185 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from ultralytics.utils import SETTINGS, TESTS_RUNNING
4
+ from ultralytics.utils.torch_utils import model_info_for_loggers
5
+
6
try:
    # Enable W&B logging only outside pytest and only when the user setting allows it. Any failed assert or import
    # routes to the except branch, which sets wb = None so the `callbacks` dict at the bottom of this module is empty.
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["wandb"] is True  # verify integration is enabled
    import wandb as wb

    assert hasattr(wb, "__version__")  # verify package is not directory
    _processed_plots = {}  # plot key (a Path, see _log_plots) -> timestamp of the version last sent to W&B

except (ImportError, AssertionError):
    wb = None
16
+
17
+
18
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Build a W&B custom chart (area-under-curve preset) from per-point curve data.

    The points are collected into a DataFrame rounded to 3 decimals and wrapped in a wandb Table. They are then bound
    to the "wandb/area-under-curve/v0" chart preset. This function only builds the chart; the caller is responsible
    for logging it.

    Args:
        x (list): x-axis values, one per point (length N).
        y (list): y-axis values, one per point (length N).
        classes (list): Series label for each point (length N); points sharing a label form one curve.
        title (str): Chart title.
        x_title (str): x-axis label.
        y_title (str): y-axis label.

    Returns:
        (wandb.Object): The chart object returned by ``wb.plot_table``, ready to be passed to a wandb log call.
    """
    import pandas  # scope for faster 'import ultralytics'

    frame = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
    table = wb.Table(dataframe=frame)
    return wb.plot_table(
        "wandb/area-under-curve/v0",
        table,
        fields={"x": "x", "y": "y", "class": "class"},
        string_fields={"title": title, "x-axis-title": x_title, "y-axis-title": y_title},
    )
45
+
46
+
47
def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    All curves are resampled (linear interpolation) onto ``num_x`` evenly spaced points between ``x[0]`` and
    ``x[-1]``. The mean curve over classes is always the first series.

    With ``only_mean=True``, only the mean curve is logged, as a line plot keyed by ``title``, via ``wb.run.log``.
    Otherwise, one extra series per row of ``y`` is appended. The result is logged as a custom chart keyed by ``id``,
    via ``wb.log(..., commit=False)``.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
        names (list | dict, optional): Class names indexed by row of ``y``. Rows without a name (including when
            ``names`` is None) are labelled with their row index as a string.
        id (str): Unique identifier for the logged data in wandb (name kept for backward compatibility with callers).
        title (str): Title for the visualization plot.
        x_title (str): Label for the x-axis.
        y_title (str): Label for the y-axis.
        num_x (int): Number of interpolated data points for visualization.
        only_mean (bool): Flag to indicate if only the mean curve should be plotted.

    Notes:
        The per-class chart is produced by the '_custom_table' function.
    """
    import numpy as np

    if names is None:
        names = []
    # Common x grid shared by every series
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Mean curve over classes (axis 0) is the first series
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        x_grid = x_new.tolist()  # plain floats so every series has the same element type
        for i, yi in enumerate(y):
            # Previously names[i] raised IndexError when names was omitted/short; fall back to the row index instead
            label = names[i] if i < len(names) else str(i)
            x_log.extend(x_grid)  # add new x
            y_log.extend(np.interp(x_new, x, yi).tolist())  # interpolate y to new x
            classes.extend([label] * len(x_grid))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
99
+
100
+
101
def _log_plots(plots, step):
    """
    Send plot images to W&B, skipping any plot whose timestamp matches the version already sent.

    Args:
        plots (dict): Mapping of plot path (a Path; its ``.stem`` becomes the W&B key) -> metadata dict containing
            a "timestamp" entry.
        step (int): W&B step at which the images are logged.

    Notes:
        - Iteration runs over a snapshot of ``plots`` so the caller's dict may change while logging.
        - The module-level ``_processed_plots`` remembers the timestamp last sent for each plot.
    """
    for path, meta in list(plots.items()):
        stamp = meta["timestamp"]
        if _processed_plots.get(path) == stamp:
            continue  # this exact version has already been logged
        wb.run.log({path.stem: wb.Image(str(path))}, step=step)
        _processed_plots[path] = stamp
123
+
124
+
125
def on_pretrain_routine_start(trainer):
    """Start a W&B run named after the trainer's project/name arguments, unless a run is already active."""
    if wb.run:
        return
    args = trainer.args
    # "/" is replaced with "-" in project and run names; the project defaults to "Ultralytics"
    project = str(args.project).replace("/", "-") if args.project else "Ultralytics"
    run_name = str(args.name).replace("/", "-")
    wb.init(project=project, name=run_name, config=vars(args))
133
+
134
+
135
def on_fit_epoch_end(trainer):
    """Log epoch metrics and any new trainer/validator plots; add model info when trainer.epoch is 0."""
    step = trainer.epoch + 1
    wb.run.log(trainer.metrics, step=step)
    _log_plots(trainer.plots, step=step)
    _log_plots(trainer.validator.plots, step=step)
    if trainer.epoch == 0:  # model info is logged only once
        wb.run.log(model_info_for_loggers(trainer), step=step)
142
+
143
+
144
def on_train_epoch_end(trainer):
    """Log labelled training losses and learning rates; also log trainer plots when trainer.epoch == 1."""
    step = trainer.epoch + 1
    losses = trainer.label_loss_items(trainer.tloss, prefix="train")
    wb.run.log(losses, step=step)
    wb.run.log(trainer.lr, step=step)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=step)
150
+
151
+
152
def on_train_end(trainer):
    """
    Wrap up the W&B run at the end of training.

    Logs any remaining validator/trainer plots and uploads ``trainer.best`` (if it exists) as a model artifact aliased
    "best". When plotting is enabled and the validator metrics expose ``curves_results``, it logs each metric curve.
    Finally, it finishes the run.
    """
    final_step = trainer.epoch + 1
    _log_plots(trainer.validator.plots, step=final_step)
    _log_plots(trainer.plots, step=final_step)

    artifact = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        artifact.add_file(trainer.best)
        wb.run.log_artifact(artifact, aliases=["best"])

    # Check if we actually have plots to save
    if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"):
        metrics = trainer.validator.metrics
        for curve_name, (x, y, x_title, y_title) in zip(metrics.curves, metrics.curves_results):
            _plot_curve(
                x,
                y,
                names=list(metrics.names.values()),
                id=f"curves/{curve_name}",
                title=curve_name,
                x_title=x_title,
                y_title=y_title,
            )
    wb.run.finish()  # required or run continues on dashboard
174
+
175
+
176
# Trainer event -> handler; nothing is registered when wandb is unavailable or the integration is disabled.
if wb:
    callbacks = {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
else:
    callbacks = {}