dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
@@ -3,8 +3,9 @@
3
3
  import json
4
4
  from time import time
5
5
 
6
- from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events
6
+ from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession
7
7
  from ultralytics.utils import LOGGER, RANK, SETTINGS
8
+ from ultralytics.utils.events import events
8
9
 
9
10
 
10
11
  def on_pretrain_routine_start(trainer):
@@ -73,22 +74,23 @@ def on_train_end(trainer):
73
74
 
74
75
  def on_train_start(trainer):
75
76
  """Run events on train start."""
76
- events(trainer.args)
77
+ events(trainer.args, trainer.device)
77
78
 
78
79
 
79
80
  def on_val_start(validator):
80
81
  """Run events on validation start."""
81
- events(validator.args)
82
+ if not validator.training:
83
+ events(validator.args, validator.device)
82
84
 
83
85
 
84
86
  def on_predict_start(predictor):
85
87
  """Run events on predict start."""
86
- events(predictor.args)
88
+ events(predictor.args, predictor.device)
87
89
 
88
90
 
89
91
  def on_export_start(exporter):
90
92
  """Run events on export start."""
91
- events(exporter.args)
93
+ events(exporter.args, exporter.device)
92
94
 
93
95
 
94
96
  callbacks = (
@@ -105,4 +107,4 @@ callbacks = (
105
107
  }
106
108
  if SETTINGS["hub"] is True
107
109
  else {}
108
- ) # verify hub is enabled before registering callbacks
110
+ )
@@ -45,24 +45,20 @@ def sanitize_dict(x: dict) -> dict:
45
45
 
46
46
 
47
47
  def on_pretrain_routine_end(trainer):
48
- """
49
- Log training parameters to MLflow at the end of the pretraining routine.
48
+ """Log training parameters to MLflow at the end of the pretraining routine.
50
49
 
51
50
  This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
52
- experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
53
- from the trainer.
51
+ experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters from
52
+ the trainer.
54
53
 
55
54
  Args:
56
55
  trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
57
56
 
58
- Global:
59
- mlflow: The imported mlflow module to use for logging.
60
-
61
- Environment Variables:
57
+ Notes:
62
58
  MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
63
59
  MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
64
60
  MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
65
- MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of training.
61
+ MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.
66
62
  """
67
63
  global mlflow
68
64
 
@@ -107,7 +103,7 @@ def on_fit_epoch_end(trainer):
107
103
 
108
104
 
109
105
  def on_train_end(trainer):
110
- """Log model artifacts at the end of the training."""
106
+ """Log model artifacts at the end of training."""
111
107
  if not mlflow:
112
108
  return
113
109
  mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
@@ -18,12 +18,11 @@ except (ImportError, AssertionError):
18
18
 
19
19
 
20
20
  def _log_scalars(scalars: dict, step: int = 0) -> None:
21
- """
22
- Log scalars to the NeptuneAI experiment logger.
21
+ """Log scalars to the NeptuneAI experiment logger.
23
22
 
24
23
  Args:
25
24
  scalars (dict): Dictionary of scalar values to log to NeptuneAI.
26
- step (int): The current step or iteration number for logging.
25
+ step (int, optional): The current step or iteration number for logging.
27
26
 
28
27
  Examples:
29
28
  >>> metrics = {"mAP": 0.85, "loss": 0.32}
@@ -35,11 +34,10 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:
35
34
 
36
35
 
37
36
  def _log_images(imgs_dict: dict, group: str = "") -> None:
38
- """
39
- Log images to the NeptuneAI experiment logger.
37
+ """Log images to the NeptuneAI experiment logger.
40
38
 
41
- This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized
42
- under the specified group name.
39
+ This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized under the
40
+ specified group name.
43
41
 
44
42
  Args:
45
43
  imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.
@@ -55,13 +53,7 @@ def _log_images(imgs_dict: dict, group: str = "") -> None:
55
53
 
56
54
 
57
55
  def _log_plot(title: str, plot_path: str) -> None:
58
- """
59
- Log plots to the NeptuneAI experiment logger.
60
-
61
- Args:
62
- title (str): Title of the plot.
63
- plot_path (str): Path to the saved image file.
64
- """
56
+ """Log plots to the NeptuneAI experiment logger."""
65
57
  import matplotlib.image as mpimg
66
58
  import matplotlib.pyplot as plt
67
59
 
@@ -73,7 +65,7 @@ def _log_plot(title: str, plot_path: str) -> None:
73
65
 
74
66
 
75
67
  def on_pretrain_routine_start(trainer) -> None:
76
- """Callback function called before the training routine starts."""
68
+ """Initialize NeptuneAI run and log hyperparameters before training starts."""
77
69
  try:
78
70
  global run
79
71
  run = neptune.init_run(
@@ -87,7 +79,7 @@ def on_pretrain_routine_start(trainer) -> None:
87
79
 
88
80
 
89
81
  def on_train_epoch_end(trainer) -> None:
90
- """Callback function called at end of each training epoch."""
82
+ """Log training metrics and learning rate at the end of each training epoch."""
91
83
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
92
84
  _log_scalars(trainer.lr, trainer.epoch + 1)
93
85
  if trainer.epoch == 1:
@@ -95,7 +87,7 @@ def on_train_epoch_end(trainer) -> None:
95
87
 
96
88
 
97
89
  def on_fit_epoch_end(trainer) -> None:
98
- """Callback function called at end of each fit (train+val) epoch."""
90
+ """Log model info and validation metrics at the end of each fit epoch."""
99
91
  if run and trainer.epoch == 0:
100
92
  from ultralytics.utils.torch_utils import model_info_for_loggers
101
93
 
@@ -104,14 +96,14 @@ def on_fit_epoch_end(trainer) -> None:
104
96
 
105
97
 
106
98
  def on_val_end(validator) -> None:
107
- """Callback function called at end of each validation."""
99
+ """Log validation images at the end of validation."""
108
100
  if run:
109
101
  # Log val_labels and val_pred
110
102
  _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
111
103
 
112
104
 
113
105
  def on_train_end(trainer) -> None:
114
- """Callback function called at end of training."""
106
+ """Log final results, plots, and model weights at the end of training."""
115
107
  if run:
116
108
  # Log final results, CM matrix + PR plots
117
109
  files = [
@@ -0,0 +1,73 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from ultralytics.utils import RANK, SETTINGS
4
+
5
+
6
+ def on_pretrain_routine_start(trainer):
7
+ """Initialize and start console logging immediately at the very beginning."""
8
+ if RANK in {-1, 0}:
9
+ from ultralytics.utils.logger import DEFAULT_LOG_PATH, ConsoleLogger, SystemLogger
10
+
11
+ trainer.system_logger = SystemLogger()
12
+ trainer.console_logger = ConsoleLogger(DEFAULT_LOG_PATH)
13
+ trainer.console_logger.start_capture()
14
+
15
+
16
+ def on_pretrain_routine_end(trainer):
17
+ """Handle pre-training routine completion event."""
18
+ pass
19
+
20
+
21
+ def on_fit_epoch_end(trainer):
22
+ """Handle end of training epoch event and collect system metrics."""
23
+ if RANK in {-1, 0} and hasattr(trainer, "system_logger"):
24
+ system_metrics = trainer.system_logger.get_metrics()
25
+ print(system_metrics) # for debug
26
+
27
+
28
+ def on_model_save(trainer):
29
+ """Handle model checkpoint save event."""
30
+ pass
31
+
32
+
33
+ def on_train_end(trainer):
34
+ """Stop console capture and finalize logs."""
35
+ if logger := getattr(trainer, "console_logger", None):
36
+ logger.stop_capture()
37
+
38
+
39
+ def on_train_start(trainer):
40
+ """Handle training start event."""
41
+ pass
42
+
43
+
44
+ def on_val_start(validator):
45
+ """Handle validation start event."""
46
+ pass
47
+
48
+
49
+ def on_predict_start(predictor):
50
+ """Handle prediction start event."""
51
+ pass
52
+
53
+
54
+ def on_export_start(exporter):
55
+ """Handle model export start event."""
56
+ pass
57
+
58
+
59
+ callbacks = (
60
+ {
61
+ "on_pretrain_routine_start": on_pretrain_routine_start,
62
+ "on_pretrain_routine_end": on_pretrain_routine_end,
63
+ "on_fit_epoch_end": on_fit_epoch_end,
64
+ "on_model_save": on_model_save,
65
+ "on_train_end": on_train_end,
66
+ "on_train_start": on_train_start,
67
+ "on_val_start": on_val_start,
68
+ "on_predict_start": on_predict_start,
69
+ "on_export_start": on_export_start,
70
+ }
71
+ if SETTINGS.get("platform", False) is True # disabled for debugging
72
+ else {}
73
+ )
@@ -13,11 +13,10 @@ except (ImportError, AssertionError):
13
13
 
14
14
 
15
15
  def on_fit_epoch_end(trainer):
16
- """
17
- Reports training metrics to Ray Tune at epoch end when a Ray session is active.
16
+ """Report training metrics to Ray Tune at epoch end when a Ray session is active.
18
17
 
19
- Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,
20
- enabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session.
18
+ Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number, enabling
19
+ hyperparameter tuning optimization. Only executes when within an active Ray Tune session.
21
20
 
22
21
  Args:
23
22
  trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.
@@ -22,8 +22,7 @@ except (ImportError, AssertionError, TypeError, AttributeError):
22
22
 
23
23
 
24
24
  def _log_scalars(scalars: dict, step: int = 0) -> None:
25
- """
26
- Log scalar values to TensorBoard.
25
+ """Log scalar values to TensorBoard.
27
26
 
28
27
  Args:
29
28
  scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the
@@ -31,7 +30,7 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:
31
30
  step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.
32
31
 
33
32
  Examples:
34
- >>> # Log training metrics
33
+ Log training metrics
35
34
  >>> metrics = {"loss": 0.5, "accuracy": 0.95}
36
35
  >>> _log_scalars(metrics, step=100)
37
36
  """
@@ -41,17 +40,15 @@ def _log_scalars(scalars: dict, step: int = 0) -> None:
41
40
 
42
41
 
43
42
  def _log_tensorboard_graph(trainer) -> None:
44
- """
45
- Log model graph to TensorBoard.
43
+ """Log model graph to TensorBoard.
46
44
 
47
45
  This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input
48
46
  tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex
49
47
  approach for models like RTDETR that may require special handling.
50
48
 
51
49
  Args:
52
- trainer (BaseTrainer): The trainer object containing the model to visualize. Must have attributes:
53
- - model: PyTorch model to visualize
54
- - args: Configuration arguments with 'imgsz' attribute
50
+ trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize. Must
51
+ have attributes model and args with imgsz.
55
52
 
56
53
  Notes:
57
54
  This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
@@ -71,14 +68,14 @@ def _log_tensorboard_graph(trainer) -> None:
71
68
  # Try simple method first (YOLO)
72
69
  try:
73
70
  trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
74
- WRITER.add_graph(torch.jit.trace(torch_utils.de_parallel(trainer.model), im, strict=False), [])
71
+ WRITER.add_graph(torch.jit.trace(torch_utils.unwrap_model(trainer.model), im, strict=False), [])
75
72
  LOGGER.info(f"{PREFIX}model graph visualization added ✅")
76
73
  return
77
74
 
78
75
  except Exception:
79
76
  # Fallback to TorchScript export steps (RTDETR)
80
77
  try:
81
- model = deepcopy(torch_utils.de_parallel(trainer.model))
78
+ model = deepcopy(torch_utils.unwrap_model(trainer.model))
82
79
  model.eval()
83
80
  model = model.fuse(verbose=False)
84
81
  for m in model.modules():
@@ -110,13 +107,13 @@ def on_train_start(trainer) -> None:
110
107
 
111
108
 
112
109
  def on_train_epoch_end(trainer) -> None:
113
- """Logs scalar statistics at the end of a training epoch."""
110
+ """Log scalar statistics at the end of a training epoch."""
114
111
  _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
115
112
  _log_scalars(trainer.lr, trainer.epoch + 1)
116
113
 
117
114
 
118
115
  def on_fit_epoch_end(trainer) -> None:
119
- """Logs epoch metrics at end of training epoch."""
116
+ """Log epoch metrics at end of training epoch."""
120
117
  _log_scalars(trainer.metrics, trainer.epoch + 1)
121
118
 
122
119
 
@@ -16,8 +16,7 @@ except (ImportError, AssertionError):
16
16
 
17
17
 
18
18
  def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
19
- """
20
- Create and log a custom metric visualization to wandb.plot.pr_curve.
19
+ """Create and log a custom metric visualization to wandb.plot.pr_curve.
21
20
 
22
21
  This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
23
22
  curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
@@ -27,20 +26,26 @@ def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall
27
26
  x (list): Values for the x-axis; expected to have length N.
28
27
  y (list): Corresponding values for the y-axis; also expected to have length N.
29
28
  classes (list): Labels identifying the class of each point; length N.
30
- title (str): Title for the plot; defaults to 'Precision Recall Curve'.
31
- x_title (str): Label for the x-axis; defaults to 'Recall'.
32
- y_title (str): Label for the y-axis; defaults to 'Precision'.
29
+ title (str, optional): Title for the plot.
30
+ x_title (str, optional): Label for the x-axis.
31
+ y_title (str, optional): Label for the y-axis.
33
32
 
34
33
  Returns:
35
34
  (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
36
35
  """
37
- import pandas # scope for faster 'import ultralytics'
36
+ import polars as pl # scope for faster 'import ultralytics'
37
+ import polars.selectors as cs
38
+
39
+ df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
40
+ data = df.select(["class", "y", "x"]).rows()
38
41
 
39
- df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3)
40
42
  fields = {"x": "x", "y": "y", "class": "class"}
41
43
  string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
42
44
  return wb.plot_table(
43
- "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields
45
+ "wandb/area-under-curve/v0",
46
+ wb.Table(data=data, columns=["class", "y", "x"]),
47
+ fields=fields,
48
+ string_fields=string_fields,
44
49
  )
45
50
 
46
51
 
@@ -55,22 +60,21 @@ def _plot_curve(
55
60
  num_x=100,
56
61
  only_mean=False,
57
62
  ):
58
- """
59
- Log a metric curve visualization.
63
+ """Log a metric curve visualization.
60
64
 
61
- This function generates a metric curve based on input data and logs the visualization to wandb.
62
- The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
65
+ This function generates a metric curve based on input data and logs the visualization to wandb. The curve can
66
+ represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
63
67
 
64
68
  Args:
65
69
  x (np.ndarray): Data points for the x-axis with length N.
66
70
  y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
67
- names (list): Names of the classes corresponding to the y-axis data; length C.
68
- id (str): Unique identifier for the logged data in wandb.
69
- title (str): Title for the visualization plot.
70
- x_title (str): Label for the x-axis.
71
- y_title (str): Label for the y-axis.
72
- num_x (int): Number of interpolated data points for visualization.
73
- only_mean (bool): Flag to indicate if only the mean curve should be plotted.
71
+ names (list, optional): Names of the classes corresponding to the y-axis data; length C.
72
+ id (str, optional): Unique identifier for the logged data in wandb.
73
+ title (str, optional): Title for the visualization plot.
74
+ x_title (str, optional): Label for the x-axis.
75
+ y_title (str, optional): Label for the y-axis.
76
+ num_x (int, optional): Number of interpolated data points for visualization.
77
+ only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.
74
78
 
75
79
  Notes:
76
80
  The function leverages the '_custom_table' function to generate the actual visualization.
@@ -99,21 +103,20 @@ def _plot_curve(
99
103
 
100
104
 
101
105
  def _log_plots(plots, step):
102
- """
103
- Log plots to WandB at a specific step if they haven't been logged already.
106
+ """Log plots to WandB at a specific step if they haven't been logged already.
104
107
 
105
- This function checks each plot in the input dictionary against previously processed plots and logs
106
- new or updated plots to WandB at the specified step.
108
+ This function checks each plot in the input dictionary against previously processed plots and logs new or updated
109
+ plots to WandB at the specified step.
107
110
 
108
111
  Args:
109
- plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries
110
- containing plot metadata including timestamps.
112
+ plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries containing plot
113
+ metadata including timestamps.
111
114
  step (int): The step/epoch at which to log the plots in the WandB run.
112
115
 
113
116
  Notes:
114
- - The function uses a shallow copy of the plots dictionary to prevent modification during iteration
115
- - Plots are identified by their stem name (filename without extension)
116
- - Each plot is logged as a WandB Image object
117
+ The function uses a shallow copy of the plots dictionary to prevent modification during iteration.
118
+ Plots are identified by their stem name (filename without extension).
119
+ Each plot is logged as a WandB Image object.
117
120
  """
118
121
  for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration
119
122
  timestamp = params["timestamp"]
@@ -123,7 +126,7 @@ def _log_plots(plots, step):
123
126
 
124
127
 
125
128
  def on_pretrain_routine_start(trainer):
126
- """Initiate and start wandb project if module is present."""
129
+ """Initialize and start wandb project if module is present."""
127
130
  if not wb.run:
128
131
  wb.init(
129
132
  project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
@@ -134,11 +137,11 @@ def on_pretrain_routine_start(trainer):
134
137
 
135
138
  def on_fit_epoch_end(trainer):
136
139
  """Log training metrics and model information at the end of an epoch."""
137
- wb.run.log(trainer.metrics, step=trainer.epoch + 1)
138
140
  _log_plots(trainer.plots, step=trainer.epoch + 1)
139
141
  _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
140
142
  if trainer.epoch == 0:
141
143
  wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
144
+ wb.run.log(trainer.metrics, step=trainer.epoch + 1, commit=True) # commit forces sync
142
145
 
143
146
 
144
147
  def on_train_epoch_end(trainer):