ultralytics 8.3.89__py3-none-any.whl → 8.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155) hide show
  1. tests/conftest.py +2 -2
  2. tests/test_cli.py +13 -11
  3. tests/test_cuda.py +10 -1
  4. tests/test_integrations.py +1 -5
  5. tests/test_python.py +16 -16
  6. tests/test_solutions.py +9 -9
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +3 -1
  9. ultralytics/cfg/models/11/yolo11-cls.yaml +5 -5
  10. ultralytics/cfg/models/11/yolo11-obb.yaml +5 -5
  11. ultralytics/cfg/models/11/yolo11-pose.yaml +5 -5
  12. ultralytics/cfg/models/11/yolo11-seg.yaml +5 -5
  13. ultralytics/cfg/models/11/yolo11.yaml +5 -5
  14. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +5 -5
  15. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +5 -5
  16. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -5
  17. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -5
  18. ultralytics/cfg/models/v8/yolov8-p6.yaml +5 -5
  19. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -5
  20. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -5
  21. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -5
  22. ultralytics/cfg/models/v8/yolov8.yaml +5 -5
  23. ultralytics/cfg/models/v9/yolov9c-seg.yaml +1 -1
  24. ultralytics/cfg/models/v9/yolov9c.yaml +1 -1
  25. ultralytics/cfg/models/v9/yolov9e-seg.yaml +1 -1
  26. ultralytics/cfg/models/v9/yolov9e.yaml +1 -1
  27. ultralytics/cfg/models/v9/yolov9m.yaml +1 -1
  28. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  29. ultralytics/cfg/models/v9/yolov9t.yaml +1 -1
  30. ultralytics/data/annotator.py +9 -14
  31. ultralytics/data/base.py +118 -30
  32. ultralytics/data/build.py +63 -24
  33. ultralytics/data/converter.py +5 -5
  34. ultralytics/data/dataset.py +207 -53
  35. ultralytics/data/loaders.py +1 -0
  36. ultralytics/data/split_dota.py +39 -12
  37. ultralytics/data/utils.py +13 -19
  38. ultralytics/engine/exporter.py +19 -17
  39. ultralytics/engine/model.py +67 -88
  40. ultralytics/engine/predictor.py +106 -21
  41. ultralytics/engine/trainer.py +32 -23
  42. ultralytics/engine/tuner.py +21 -18
  43. ultralytics/engine/validator.py +75 -41
  44. ultralytics/hub/__init__.py +12 -13
  45. ultralytics/hub/auth.py +9 -12
  46. ultralytics/hub/session.py +76 -21
  47. ultralytics/hub/utils.py +19 -17
  48. ultralytics/models/fastsam/model.py +20 -11
  49. ultralytics/models/fastsam/predict.py +36 -16
  50. ultralytics/models/fastsam/utils.py +5 -5
  51. ultralytics/models/fastsam/val.py +6 -6
  52. ultralytics/models/nas/model.py +22 -11
  53. ultralytics/models/nas/predict.py +9 -4
  54. ultralytics/models/nas/val.py +5 -5
  55. ultralytics/models/rtdetr/model.py +20 -11
  56. ultralytics/models/rtdetr/predict.py +18 -15
  57. ultralytics/models/rtdetr/train.py +20 -16
  58. ultralytics/models/rtdetr/val.py +42 -6
  59. ultralytics/models/sam/__init__.py +1 -1
  60. ultralytics/models/sam/amg.py +50 -4
  61. ultralytics/models/sam/model.py +8 -14
  62. ultralytics/models/sam/modules/decoders.py +18 -21
  63. ultralytics/models/sam/modules/encoders.py +25 -46
  64. ultralytics/models/sam/modules/memory_attention.py +19 -15
  65. ultralytics/models/sam/modules/sam.py +18 -25
  66. ultralytics/models/sam/modules/tiny_encoder.py +19 -29
  67. ultralytics/models/sam/modules/transformer.py +35 -57
  68. ultralytics/models/sam/modules/utils.py +15 -15
  69. ultralytics/models/sam/predict.py +0 -3
  70. ultralytics/models/utils/loss.py +87 -36
  71. ultralytics/models/utils/ops.py +26 -31
  72. ultralytics/models/yolo/classify/predict.py +24 -3
  73. ultralytics/models/yolo/classify/train.py +77 -10
  74. ultralytics/models/yolo/classify/val.py +40 -15
  75. ultralytics/models/yolo/detect/predict.py +23 -10
  76. ultralytics/models/yolo/detect/train.py +85 -15
  77. ultralytics/models/yolo/detect/val.py +145 -21
  78. ultralytics/models/yolo/model.py +1 -2
  79. ultralytics/models/yolo/obb/predict.py +12 -4
  80. ultralytics/models/yolo/obb/train.py +7 -0
  81. ultralytics/models/yolo/obb/val.py +25 -7
  82. ultralytics/models/yolo/pose/predict.py +22 -6
  83. ultralytics/models/yolo/pose/train.py +17 -1
  84. ultralytics/models/yolo/pose/val.py +46 -21
  85. ultralytics/models/yolo/segment/predict.py +22 -8
  86. ultralytics/models/yolo/segment/train.py +6 -0
  87. ultralytics/models/yolo/segment/val.py +100 -14
  88. ultralytics/models/yolo/world/train.py +38 -8
  89. ultralytics/models/yolo/world/train_world.py +39 -10
  90. ultralytics/nn/autobackend.py +28 -14
  91. ultralytics/nn/modules/__init__.py +3 -0
  92. ultralytics/nn/modules/activation.py +12 -3
  93. ultralytics/nn/modules/block.py +587 -84
  94. ultralytics/nn/modules/conv.py +418 -54
  95. ultralytics/nn/modules/head.py +3 -4
  96. ultralytics/nn/modules/transformer.py +320 -34
  97. ultralytics/nn/modules/utils.py +17 -3
  98. ultralytics/nn/tasks.py +221 -69
  99. ultralytics/solutions/ai_gym.py +2 -2
  100. ultralytics/solutions/analytics.py +4 -4
  101. ultralytics/solutions/heatmap.py +4 -4
  102. ultralytics/solutions/instance_segmentation.py +10 -4
  103. ultralytics/solutions/object_blurrer.py +2 -2
  104. ultralytics/solutions/object_counter.py +2 -2
  105. ultralytics/solutions/object_cropper.py +2 -2
  106. ultralytics/solutions/parking_management.py +9 -9
  107. ultralytics/solutions/queue_management.py +1 -1
  108. ultralytics/solutions/region_counter.py +2 -2
  109. ultralytics/solutions/security_alarm.py +7 -7
  110. ultralytics/solutions/solutions.py +7 -4
  111. ultralytics/solutions/speed_estimation.py +2 -2
  112. ultralytics/solutions/streamlit_inference.py +6 -6
  113. ultralytics/solutions/trackzone.py +9 -2
  114. ultralytics/solutions/vision_eye.py +4 -4
  115. ultralytics/trackers/basetrack.py +1 -1
  116. ultralytics/trackers/bot_sort.py +23 -22
  117. ultralytics/trackers/byte_tracker.py +4 -4
  118. ultralytics/trackers/track.py +2 -1
  119. ultralytics/trackers/utils/gmc.py +26 -27
  120. ultralytics/trackers/utils/kalman_filter.py +31 -29
  121. ultralytics/trackers/utils/matching.py +7 -7
  122. ultralytics/utils/__init__.py +32 -27
  123. ultralytics/utils/autobatch.py +5 -5
  124. ultralytics/utils/benchmarks.py +111 -18
  125. ultralytics/utils/callbacks/base.py +3 -3
  126. ultralytics/utils/callbacks/clearml.py +11 -11
  127. ultralytics/utils/callbacks/comet.py +35 -22
  128. ultralytics/utils/callbacks/dvc.py +11 -10
  129. ultralytics/utils/callbacks/hub.py +8 -8
  130. ultralytics/utils/callbacks/mlflow.py +1 -1
  131. ultralytics/utils/callbacks/neptune.py +12 -10
  132. ultralytics/utils/callbacks/raytune.py +1 -1
  133. ultralytics/utils/callbacks/tensorboard.py +6 -6
  134. ultralytics/utils/callbacks/wb.py +16 -16
  135. ultralytics/utils/checks.py +116 -35
  136. ultralytics/utils/dist.py +15 -2
  137. ultralytics/utils/downloads.py +13 -9
  138. ultralytics/utils/files.py +12 -13
  139. ultralytics/utils/instance.py +112 -45
  140. ultralytics/utils/loss.py +28 -33
  141. ultralytics/utils/metrics.py +246 -181
  142. ultralytics/utils/ops.py +61 -53
  143. ultralytics/utils/patches.py +8 -6
  144. ultralytics/utils/plotting.py +64 -45
  145. ultralytics/utils/tal.py +88 -57
  146. ultralytics/utils/torch_utils.py +181 -33
  147. ultralytics/utils/triton.py +13 -3
  148. ultralytics/utils/tuner.py +8 -16
  149. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/METADATA +1 -1
  150. ultralytics-8.3.90.dist-info/RECORD +250 -0
  151. ultralytics-8.3.89.dist-info/RECORD +0 -250
  152. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/LICENSE +0 -0
  153. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/WHEEL +0 -0
  154. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/entry_points.txt +0 -0
  155. {ultralytics-8.3.89.dist-info → ultralytics-8.3.90.dist-info}/top_level.txt +0 -0
@@ -201,7 +201,23 @@ def benchmark(
201
201
 
202
202
 
203
203
  class RF100Benchmark:
204
- """Benchmark YOLO model performance across various formats for speed and accuracy."""
204
+ """
205
+ Benchmark YOLO model performance across various formats for speed and accuracy.
206
+
207
+ This class provides functionality to benchmark YOLO models on the RF100 dataset collection.
208
+
209
+ Attributes:
210
+ ds_names (List[str]): Names of datasets used for benchmarking.
211
+ ds_cfg_list (List[Path]): List of paths to dataset configuration files.
212
+ rf (Roboflow): Roboflow instance for accessing datasets.
213
+ val_metrics (List[str]): Metrics used for validation.
214
+
215
+ Methods:
216
+ set_key: Set Roboflow API key for accessing datasets.
217
+ parse_dataset: Parse dataset links and download datasets.
218
+ fix_yaml: Fix train and validation paths in YAML files.
219
+ evaluate: Evaluate model performance on validation results.
220
+ """
205
221
 
206
222
  def __init__(self):
207
223
  """Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats."""
@@ -234,6 +250,10 @@ class RF100Benchmark:
234
250
  Args:
235
251
  ds_link_txt (str): Path to the file containing dataset links.
236
252
 
253
+ Returns:
254
+ ds_names (List[str]): List of dataset names.
255
+ ds_cfg_list (List[Path]): List of paths to dataset configuration files.
256
+
237
257
  Examples:
238
258
  >>> benchmark = RF100Benchmark()
239
259
  >>> benchmark.set_key("api_key")
@@ -262,15 +282,7 @@ class RF100Benchmark:
262
282
 
263
283
  @staticmethod
264
284
  def fix_yaml(path):
265
- """
266
- Fixes the train and validation paths in a given YAML file.
267
-
268
- Args:
269
- path (str): Path to the YAML file to be fixed.
270
-
271
- Examples:
272
- >>> RF100Benchmark.fix_yaml("path/to/data.yaml")
273
- """
285
+ """Fix the train and validation paths in a given YAML file."""
274
286
  with open(path, encoding="utf-8") as file:
275
287
  yaml_data = yaml.safe_load(file)
276
288
  yaml_data["train"] = "train/images"
@@ -353,6 +365,14 @@ class ProfileModels:
353
365
 
354
366
  Methods:
355
367
  profile: Profiles the models and prints the result.
368
+ get_files: Gets all relevant model files.
369
+ get_onnx_model_info: Extracts metadata from an ONNX model.
370
+ iterative_sigma_clipping: Applies sigma clipping to remove outliers.
371
+ profile_tensorrt_model: Profiles a TensorRT model.
372
+ profile_onnx_model: Profiles an ONNX model.
373
+ generate_table_row: Generates a table row with model metrics.
374
+ generate_results_dict: Generates a dictionary of profiling results.
375
+ print_table: Prints a formatted table of results.
356
376
 
357
377
  Examples:
358
378
  Profile models and print results
@@ -404,7 +424,18 @@ class ProfileModels:
404
424
  self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu")
405
425
 
406
426
  def profile(self):
407
- """Profiles YOLO models for speed and accuracy across various formats including ONNX and TensorRT."""
427
+ """
428
+ Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.
429
+
430
+ Returns:
431
+ (List[Dict]): List of dictionaries containing profiling results for each model.
432
+
433
+ Examples:
434
+ Profile models and print results
435
+ >>> from ultralytics.utils.benchmarks import ProfileModels
436
+ >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"])
437
+ >>> results = profiler.profile()
438
+ """
408
439
  files = self.get_files()
409
440
 
410
441
  if not files:
@@ -448,7 +479,12 @@ class ProfileModels:
448
479
  return output
449
480
 
450
481
  def get_files(self):
451
- """Returns a list of paths for all relevant model files given by the user."""
482
+ """
483
+ Return a list of paths for all relevant model files given by the user.
484
+
485
+ Returns:
486
+ (List[Path]): List of Path objects for the model files.
487
+ """
452
488
  files = []
453
489
  for path in self.paths:
454
490
  path = Path(path)
@@ -470,7 +506,17 @@ class ProfileModels:
470
506
 
471
507
  @staticmethod
472
508
  def iterative_sigma_clipping(data, sigma=2, max_iters=3):
473
- """Applies iterative sigma clipping to data to remove outliers based on specified sigma and iteration count."""
509
+ """
510
+ Apply iterative sigma clipping to data to remove outliers.
511
+
512
+ Args:
513
+ data (numpy.ndarray): Input data array.
514
+ sigma (float): Number of standard deviations to use for clipping.
515
+ max_iters (int): Maximum number of iterations for the clipping process.
516
+
517
+ Returns:
518
+ (numpy.ndarray): Clipped data array with outliers removed.
519
+ """
474
520
  data = np.array(data)
475
521
  for _ in range(max_iters):
476
522
  mean, std = np.mean(data), np.std(data)
@@ -481,7 +527,17 @@ class ProfileModels:
481
527
  return data
482
528
 
483
529
  def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
484
- """Profiles YOLO model performance with TensorRT, measuring average run time and standard deviation."""
530
+ """
531
+ Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.
532
+
533
+ Args:
534
+ engine_file (str): Path to the TensorRT engine file.
535
+ eps (float): Small epsilon value to prevent division by zero.
536
+
537
+ Returns:
538
+ mean_time (float): Mean inference time in milliseconds.
539
+ std_time (float): Standard deviation of inference time in milliseconds.
540
+ """
485
541
  if not self.trt or not Path(engine_file).is_file():
486
542
  return 0.0, 0.0
487
543
 
@@ -510,7 +566,17 @@ class ProfileModels:
510
566
  return np.mean(run_times), np.std(run_times)
511
567
 
512
568
  def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
513
- """Profiles an ONNX model, measuring average inference time and standard deviation across multiple runs."""
569
+ """
570
+ Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.
571
+
572
+ Args:
573
+ onnx_file (str): Path to the ONNX model file.
574
+ eps (float): Small epsilon value to prevent division by zero.
575
+
576
+ Returns:
577
+ mean_time (float): Mean inference time in milliseconds.
578
+ std_time (float): Standard deviation of inference time in milliseconds.
579
+ """
514
580
  check_requirements("onnxruntime")
515
581
  import onnxruntime as ort
516
582
 
@@ -565,7 +631,18 @@ class ProfileModels:
565
631
  return np.mean(run_times), np.std(run_times)
566
632
 
567
633
  def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
568
- """Generates a table row string with model performance metrics including inference times and model details."""
634
+ """
635
+ Generate a table row string with model performance metrics.
636
+
637
+ Args:
638
+ model_name (str): Name of the model.
639
+ t_onnx (tuple): ONNX model inference time statistics (mean, std).
640
+ t_engine (tuple): TensorRT engine inference time statistics (mean, std).
641
+ model_info (tuple): Model information (layers, params, gradients, flops).
642
+
643
+ Returns:
644
+ (str): Formatted table row string with model metrics.
645
+ """
569
646
  layers, params, gradients, flops = model_info
570
647
  return (
571
648
  f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±"
@@ -574,7 +651,18 @@ class ProfileModels:
574
651
 
575
652
  @staticmethod
576
653
  def generate_results_dict(model_name, t_onnx, t_engine, model_info):
577
- """Generates a dictionary of profiling results including model name, parameters, GFLOPs, and speed metrics."""
654
+ """
655
+ Generate a dictionary of profiling results.
656
+
657
+ Args:
658
+ model_name (str): Name of the model.
659
+ t_onnx (tuple): ONNX model inference time statistics (mean, std).
660
+ t_engine (tuple): TensorRT engine inference time statistics (mean, std).
661
+ model_info (tuple): Model information (layers, params, gradients, flops).
662
+
663
+ Returns:
664
+ (dict): Dictionary containing profiling results.
665
+ """
578
666
  layers, params, gradients, flops = model_info
579
667
  return {
580
668
  "model/name": model_name,
@@ -586,7 +674,12 @@ class ProfileModels:
586
674
 
587
675
  @staticmethod
588
676
  def print_table(table_rows):
589
- """Prints a formatted table of model profiling results, including speed and accuracy metrics."""
677
+ """
678
+ Print a formatted table of model profiling results.
679
+
680
+ Args:
681
+ table_rows (List[str]): List of formatted table row strings.
682
+ """
590
683
  gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
591
684
  headers = [
592
685
  "Model",
@@ -1,5 +1,5 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
- """Base callbacks."""
2
+ """Base callbacks for Ultralytics training, validation, prediction, and export processes."""
3
3
 
4
4
  from collections import defaultdict
5
5
  from copy import deepcopy
@@ -189,8 +189,8 @@ def add_integration_callbacks(instance):
189
189
  Add integration callbacks from various sources to the instance's callbacks.
190
190
 
191
191
  Args:
192
- instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
193
- of callback lists.
192
+ instance (Trainer | Predictor | Validator | Exporter): An object with a 'callbacks' attribute that is a
193
+ dictionary of callback lists.
194
194
  """
195
195
  # Load HUB callbacks
196
196
  from .hub import callbacks as hub_cb
@@ -14,12 +14,12 @@ except (ImportError, AssertionError):
14
14
  clearml = None
15
15
 
16
16
 
17
- def _log_debug_samples(files, title="Debug Samples") -> None:
17
+ def _log_debug_samples(files, title: str = "Debug Samples") -> None:
18
18
  """
19
19
  Log files (images) as debug samples in the ClearML task.
20
20
 
21
21
  Args:
22
- files (list): A list of file paths in PosixPath format.
22
+ files (List[Path]): A list of file paths in PosixPath format.
23
23
  title (str): A title that groups together images with the same values.
24
24
  """
25
25
  import re
@@ -34,7 +34,7 @@ def _log_debug_samples(files, title="Debug Samples") -> None:
34
34
  )
35
35
 
36
36
 
37
- def _log_plot(title, plot_path) -> None:
37
+ def _log_plot(title: str, plot_path: str) -> None:
38
38
  """
39
39
  Log an image as a plot in the plot section of ClearML.
40
40
 
@@ -55,8 +55,8 @@ def _log_plot(title, plot_path) -> None:
55
55
  )
56
56
 
57
57
 
58
- def on_pretrain_routine_start(trainer):
59
- """Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
58
+ def on_pretrain_routine_start(trainer) -> None:
59
+ """Runs at start of pretraining routine; initializes and connects/logs task to ClearML."""
60
60
  try:
61
61
  if task := Task.current_task():
62
62
  # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
@@ -84,8 +84,8 @@ def on_pretrain_routine_start(trainer):
84
84
  LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}")
85
85
 
86
86
 
87
- def on_train_epoch_end(trainer):
88
- """Logs debug samples for the first epoch of YOLO training and report current training progress."""
87
+ def on_train_epoch_end(trainer) -> None:
88
+ """Logs debug samples for the first epoch of YOLO training and reports current training progress."""
89
89
  if task := Task.current_task():
90
90
  # Log debug samples
91
91
  if trainer.epoch == 1:
@@ -97,10 +97,10 @@ def on_train_epoch_end(trainer):
97
97
  task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
98
98
 
99
99
 
100
- def on_fit_epoch_end(trainer):
100
+ def on_fit_epoch_end(trainer) -> None:
101
101
  """Reports model information to logger at the end of an epoch."""
102
102
  if task := Task.current_task():
103
- # You should have access to the validation bboxes under jdict
103
+ # Report epoch time and validation metrics
104
104
  task.get_logger().report_scalar(
105
105
  title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
106
106
  )
@@ -113,14 +113,14 @@ def on_fit_epoch_end(trainer):
113
113
  task.get_logger().report_single_value(k, v)
114
114
 
115
115
 
116
- def on_val_end(validator):
116
+ def on_val_end(validator) -> None:
117
117
  """Logs validation results including labels and predictions."""
118
118
  if Task.current_task():
119
119
  # Log val_labels and val_pred
120
120
  _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
121
121
 
122
122
 
123
- def on_train_end(trainer):
123
+ def on_train_end(trainer) -> None:
124
124
  """Logs final model and its name on training completion."""
125
125
  if task := Task.current_task():
126
126
  # Log final results, CM matrix + PR plots
@@ -50,33 +50,33 @@ def _get_comet_mode() -> str:
50
50
  return "online"
51
51
 
52
52
 
53
- def _get_comet_model_name():
53
+ def _get_comet_model_name() -> str:
54
54
  """Returns the model name for Comet from the environment variable COMET_MODEL_NAME or defaults to 'Ultralytics'."""
55
55
  return os.getenv("COMET_MODEL_NAME", "Ultralytics")
56
56
 
57
57
 
58
- def _get_eval_batch_logging_interval():
58
+ def _get_eval_batch_logging_interval() -> int:
59
59
  """Get the evaluation batch logging interval from environment variable or use default value 1."""
60
60
  return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
61
61
 
62
62
 
63
- def _get_max_image_predictions_to_log():
63
+ def _get_max_image_predictions_to_log() -> int:
64
64
  """Get the maximum number of image predictions to log from the environment variables."""
65
65
  return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
66
66
 
67
67
 
68
- def _scale_confidence_score(score):
68
+ def _scale_confidence_score(score: float) -> float:
69
69
  """Scales the given confidence score by a factor specified in an environment variable."""
70
70
  scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
71
71
  return score * scale
72
72
 
73
73
 
74
- def _should_log_confusion_matrix():
74
+ def _should_log_confusion_matrix() -> bool:
75
75
  """Determines if the confusion matrix should be logged based on the environment variable settings."""
76
76
  return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
77
77
 
78
78
 
79
- def _should_log_image_predictions():
79
+ def _should_log_image_predictions() -> bool:
80
80
  """Determines whether to log image predictions based on a specified environment variable."""
81
81
  return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
82
82
 
@@ -114,7 +114,7 @@ def _resume_or_create_experiment(args: SimpleNamespace) -> None:
114
114
  LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}")
115
115
 
116
116
 
117
- def _fetch_trainer_metadata(trainer):
117
+ def _fetch_trainer_metadata(trainer) -> dict:
118
118
  """Returns metadata for YOLO training including epoch and asset saving status."""
119
119
  curr_epoch = trainer.epoch + 1
120
120
 
@@ -130,7 +130,9 @@ def _fetch_trainer_metadata(trainer):
130
130
  return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
131
131
 
132
132
 
133
- def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad):
133
+ def _scale_bounding_box_to_original_image_shape(
134
+ box, resized_image_shape, original_image_shape, ratio_pad
135
+ ) -> List[float]:
134
136
  """
135
137
  YOLO resizes images during training and the label values are normalized based on this resized shape.
136
138
 
@@ -151,7 +153,7 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
151
153
  return box
152
154
 
153
155
 
154
- def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
156
+ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> Optional[dict]:
155
157
  """Format ground truth annotations for detection."""
156
158
  indices = batch["batch_idx"] == img_idx
157
159
  bboxes = batch["bboxes"][indices]
@@ -181,7 +183,7 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c
181
183
  return {"name": "ground_truth", "data": data}
182
184
 
183
185
 
184
- def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None):
186
+ def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> Optional[dict]:
185
187
  """Format YOLO predictions for object detection visualization."""
186
188
  stem = image_path.stem
187
189
  image_id = int(stem) if stem.isnumeric() else stem
@@ -227,7 +229,16 @@ def _format_prediction_annotations(image_path, metadata, class_label_map=None, c
227
229
 
228
230
 
229
231
  def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> Optional[List[List[Any]]]:
230
- """Extracts segmentation annotation from compressed segmentations as list of polygons."""
232
+ """
233
+ Extracts segmentation annotation from compressed segmentations as list of polygons.
234
+
235
+ Args:
236
+ segmentation_raw: Raw segmentation data in compressed format.
237
+ decode: Function to decode the compressed segmentation data.
238
+
239
+ Returns:
240
+ (Optional[List[List[Any]]]): List of polygon points or None if extraction fails.
241
+ """
231
242
  try:
232
243
  mask = decode(segmentation_raw)
233
244
  contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
@@ -238,7 +249,9 @@ def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) ->
238
249
  return None
239
250
 
240
251
 
241
- def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map):
252
+ def _fetch_annotations(
253
+ img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map
254
+ ) -> Optional[List]:
242
255
  """Join the ground truth and prediction annotations if they exist."""
243
256
  ground_truth_annotations = _format_ground_truth_annotations_for_detection(
244
257
  img_idx, image_path, batch, class_label_map
@@ -253,7 +266,7 @@ def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, clas
253
266
  return [annotations] if annotations else None
254
267
 
255
268
 
256
- def _create_prediction_metadata_map(model_predictions):
269
+ def _create_prediction_metadata_map(model_predictions) -> dict:
257
270
  """Create metadata map for model predictions by groupings them based on image ID."""
258
271
  pred_metadata_map = {}
259
272
  for prediction in model_predictions:
@@ -263,7 +276,7 @@ def _create_prediction_metadata_map(model_predictions):
263
276
  return pred_metadata_map
264
277
 
265
278
 
266
- def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
279
+ def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:
267
280
  """Log the confusion matrix to Comet experiment."""
268
281
  conf_mat = trainer.validator.confusion_matrix.matrix
269
282
  names = list(trainer.data["names"].values()) + ["background"]
@@ -272,7 +285,7 @@ def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
272
285
  )
273
286
 
274
287
 
275
- def _log_images(experiment, image_paths, curr_step, annotations=None):
288
+ def _log_images(experiment, image_paths, curr_step, annotations=None) -> None:
276
289
  """Logs images to the experiment with optional annotations."""
277
290
  if annotations:
278
291
  for image_path, annotation in zip(image_paths, annotations):
@@ -283,7 +296,7 @@ def _log_images(experiment, image_paths, curr_step, annotations=None):
283
296
  experiment.log_image(image_path, name=image_path.stem, step=curr_step)
284
297
 
285
298
 
286
- def _log_image_predictions(experiment, validator, curr_step):
299
+ def _log_image_predictions(experiment, validator, curr_step) -> None:
287
300
  """Logs predicted boxes for a single image during training."""
288
301
  global _comet_image_prediction_count
289
302
 
@@ -330,7 +343,7 @@ def _log_image_predictions(experiment, validator, curr_step):
330
343
  _comet_image_prediction_count += 1
331
344
 
332
345
 
333
- def _log_plots(experiment, trainer):
346
+ def _log_plots(experiment, trainer) -> None:
334
347
  """Logs evaluation plots and label plots for the experiment."""
335
348
  plot_filenames = None
336
349
  if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment":
@@ -359,18 +372,18 @@ def _log_plots(experiment, trainer):
359
372
  _log_images(experiment, label_plot_filenames, None)
360
373
 
361
374
 
362
- def _log_model(experiment, trainer):
375
+ def _log_model(experiment, trainer) -> None:
363
376
  """Log the best-trained model to Comet.ml."""
364
377
  model_name = _get_comet_model_name()
365
378
  experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
366
379
 
367
380
 
368
- def on_pretrain_routine_start(trainer):
381
+ def on_pretrain_routine_start(trainer) -> None:
369
382
  """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
370
383
  _resume_or_create_experiment(trainer.args)
371
384
 
372
385
 
373
- def on_train_epoch_end(trainer):
386
+ def on_train_epoch_end(trainer) -> None:
374
387
  """Log metrics and save batch images at the end of training epochs."""
375
388
  experiment = comet_ml.get_running_experiment()
376
389
  if not experiment:
@@ -383,7 +396,7 @@ def on_train_epoch_end(trainer):
383
396
  experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
384
397
 
385
398
 
386
- def on_fit_epoch_end(trainer):
399
+ def on_fit_epoch_end(trainer) -> None:
387
400
  """Logs model assets at the end of each epoch."""
388
401
  experiment = comet_ml.get_running_experiment()
389
402
  if not experiment:
@@ -411,7 +424,7 @@ def on_fit_epoch_end(trainer):
411
424
  _log_image_predictions(experiment, trainer.validator, curr_step)
412
425
 
413
426
 
414
- def on_train_end(trainer):
427
+ def on_train_end(trainer) -> None:
415
428
  """Perform operations at the end of training."""
416
429
  experiment = comet_ml.get_running_experiment()
417
430
  if not experiment:
@@ -1,5 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from pathlib import Path
4
+
3
5
  from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
4
6
 
5
7
  try:
@@ -11,7 +13,6 @@ try:
11
13
 
12
14
  import os
13
15
  import re
14
- from pathlib import Path
15
16
 
16
17
  # DVCLive logger instance
17
18
  live = None
@@ -25,7 +26,7 @@ except (ImportError, AssertionError, TypeError):
25
26
  dvclive = None
26
27
 
27
28
 
28
- def _log_images(path, prefix=""):
29
+ def _log_images(path: Path, prefix: str = "") -> None:
29
30
  """Logs images at specified path with an optional prefix using DVCLive."""
30
31
  if live:
31
32
  name = path.name
@@ -39,7 +40,7 @@ def _log_images(path, prefix=""):
39
40
  live.log_image(os.path.join(prefix, name), path)
40
41
 
41
42
 
42
- def _log_plots(plots, prefix=""):
43
+ def _log_plots(plots: dict, prefix: str = "") -> None:
43
44
  """Logs plot images for training progress if they have not been previously processed."""
44
45
  for name, params in plots.items():
45
46
  timestamp = params["timestamp"]
@@ -48,7 +49,7 @@ def _log_plots(plots, prefix=""):
48
49
  _processed_plots[name] = timestamp
49
50
 
50
51
 
51
- def _log_confusion_matrix(validator):
52
+ def _log_confusion_matrix(validator) -> None:
52
53
  """Logs the confusion matrix for the given validator using DVCLive."""
53
54
  targets = []
54
55
  preds = []
@@ -65,7 +66,7 @@ def _log_confusion_matrix(validator):
65
66
  live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
66
67
 
67
68
 
68
- def on_pretrain_routine_start(trainer):
69
+ def on_pretrain_routine_start(trainer) -> None:
69
70
  """Initializes DVCLive logger for training metadata during pre-training routine."""
70
71
  try:
71
72
  global live
@@ -75,24 +76,24 @@ def on_pretrain_routine_start(trainer):
75
76
  LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. {e}")
76
77
 
77
78
 
78
- def on_pretrain_routine_end(trainer):
79
+ def on_pretrain_routine_end(trainer) -> None:
79
80
  """Logs plots related to the training process at the end of the pretraining routine."""
80
81
  _log_plots(trainer.plots, "train")
81
82
 
82
83
 
83
- def on_train_start(trainer):
84
+ def on_train_start(trainer) -> None:
84
85
  """Logs the training parameters if DVCLive logging is active."""
85
86
  if live:
86
87
  live.log_params(trainer.args)
87
88
 
88
89
 
89
- def on_train_epoch_start(trainer):
90
+ def on_train_epoch_start(trainer) -> None:
90
91
  """Sets the global variable _training_epoch value to True at the start of training each epoch."""
91
92
  global _training_epoch
92
93
  _training_epoch = True
93
94
 
94
95
 
95
- def on_fit_epoch_end(trainer):
96
+ def on_fit_epoch_end(trainer) -> None:
96
97
  """Logs training metrics and model info, and advances to next step on the end of each fit epoch."""
97
98
  global _training_epoch
98
99
  if live and _training_epoch:
@@ -113,7 +114,7 @@ def on_fit_epoch_end(trainer):
113
114
  _training_epoch = False
114
115
 
115
116
 
116
- def on_train_end(trainer):
117
+ def on_train_end(trainer) -> None:
117
118
  """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active."""
118
119
  if live:
119
120
  # At the end log the best metrics. It runs validator on the best model internally.
@@ -14,16 +14,16 @@ def on_pretrain_routine_start(trainer):
14
14
 
15
15
 
16
16
  def on_pretrain_routine_end(trainer):
17
- """Logs info before starting timer for upload rate limit."""
17
+ """Initialize timers for upload rate limiting before training begins."""
18
18
  if session := getattr(trainer, "hub_session", None):
19
19
  # Start timer for upload rate limit
20
- session.timers = {"metrics": time(), "ckpt": time()} # start timer on session.rate_limit
20
+ session.timers = {"metrics": time(), "ckpt": time()} # start timer for session rate limiting
21
21
 
22
22
 
23
23
  def on_fit_epoch_end(trainer):
24
- """Uploads training progress metrics at the end of each epoch."""
24
+ """Upload training progress metrics to Ultralytics HUB at the end of each epoch."""
25
25
  if session := getattr(trainer, "hub_session", None):
26
- # Upload metrics after val end
26
+ # Upload metrics after validation ends
27
27
  all_plots = {
28
28
  **trainer.label_loss_items(trainer.tloss, prefix="train"),
29
29
  **trainer.metrics,
@@ -35,7 +35,7 @@ def on_fit_epoch_end(trainer):
35
35
 
36
36
  session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
37
37
 
38
- # If any metrics fail to upload, add them to the queue to attempt uploading again.
38
+ # If any metrics failed to upload previously, add them to the queue to attempt uploading again
39
39
  if session.metrics_upload_failed_queue:
40
40
  session.metrics_queue.update(session.metrics_upload_failed_queue)
41
41
 
@@ -46,7 +46,7 @@ def on_fit_epoch_end(trainer):
46
46
 
47
47
 
48
48
  def on_model_save(trainer):
49
- """Saves checkpoints to Ultralytics HUB with rate limiting."""
49
+ """Upload model checkpoints to Ultralytics HUB with rate limiting."""
50
50
  if session := getattr(trainer, "hub_session", None):
51
51
  # Upload checkpoints with rate limiting
52
52
  is_best = trainer.best_fitness == trainer.fitness
@@ -77,7 +77,7 @@ def on_train_start(trainer):
77
77
 
78
78
 
79
79
  def on_val_start(validator):
80
- """Runs events on validation start."""
80
+ """Run events on validation start."""
81
81
  events(validator.args)
82
82
 
83
83
 
@@ -105,4 +105,4 @@ callbacks = (
105
105
  }
106
106
  if SETTINGS["hub"] is True
107
107
  else {}
108
- ) # verify enabled
108
+ ) # verify hub is enabled before registering callbacks
@@ -39,7 +39,7 @@ except (ImportError, AssertionError):
39
39
  mlflow = None
40
40
 
41
41
 
42
- def sanitize_dict(x):
42
+ def sanitize_dict(x: dict) -> dict:
43
43
  """Sanitize dictionary keys by removing parentheses and converting values to floats."""
44
44
  return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
45
45