dgenerate-ultralytics-headless 8.3.143__py3-none-any.whl → 8.3.145__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/METADATA +2 -2
  2. dgenerate_ultralytics_headless-8.3.145.dist-info/RECORD +272 -0
  3. tests/conftest.py +7 -24
  4. tests/test_cli.py +1 -1
  5. tests/test_cuda.py +7 -2
  6. tests/test_engine.py +7 -8
  7. tests/test_exports.py +16 -16
  8. tests/test_integrations.py +1 -1
  9. tests/test_solutions.py +11 -11
  10. ultralytics/__init__.py +1 -1
  11. ultralytics/cfg/__init__.py +16 -13
  12. ultralytics/data/annotator.py +6 -5
  13. ultralytics/data/augment.py +127 -126
  14. ultralytics/data/base.py +54 -51
  15. ultralytics/data/build.py +47 -23
  16. ultralytics/data/converter.py +47 -43
  17. ultralytics/data/dataset.py +51 -50
  18. ultralytics/data/loaders.py +77 -44
  19. ultralytics/data/split.py +22 -9
  20. ultralytics/data/split_dota.py +63 -39
  21. ultralytics/data/utils.py +59 -39
  22. ultralytics/engine/exporter.py +79 -27
  23. ultralytics/engine/model.py +52 -51
  24. ultralytics/engine/predictor.py +37 -28
  25. ultralytics/engine/results.py +191 -161
  26. ultralytics/engine/trainer.py +36 -19
  27. ultralytics/engine/tuner.py +12 -9
  28. ultralytics/engine/validator.py +7 -9
  29. ultralytics/hub/__init__.py +11 -13
  30. ultralytics/hub/auth.py +22 -2
  31. ultralytics/hub/google/__init__.py +19 -19
  32. ultralytics/hub/session.py +37 -51
  33. ultralytics/hub/utils.py +19 -5
  34. ultralytics/models/fastsam/model.py +30 -12
  35. ultralytics/models/fastsam/predict.py +5 -6
  36. ultralytics/models/fastsam/utils.py +3 -3
  37. ultralytics/models/fastsam/val.py +10 -6
  38. ultralytics/models/nas/model.py +9 -5
  39. ultralytics/models/nas/predict.py +6 -6
  40. ultralytics/models/nas/val.py +3 -3
  41. ultralytics/models/rtdetr/model.py +7 -6
  42. ultralytics/models/rtdetr/predict.py +14 -7
  43. ultralytics/models/rtdetr/train.py +10 -4
  44. ultralytics/models/rtdetr/val.py +36 -9
  45. ultralytics/models/sam/amg.py +30 -12
  46. ultralytics/models/sam/build.py +22 -22
  47. ultralytics/models/sam/model.py +10 -9
  48. ultralytics/models/sam/modules/blocks.py +76 -80
  49. ultralytics/models/sam/modules/decoders.py +6 -8
  50. ultralytics/models/sam/modules/encoders.py +23 -26
  51. ultralytics/models/sam/modules/memory_attention.py +13 -1
  52. ultralytics/models/sam/modules/sam.py +57 -26
  53. ultralytics/models/sam/modules/tiny_encoder.py +232 -237
  54. ultralytics/models/sam/modules/transformer.py +13 -13
  55. ultralytics/models/sam/modules/utils.py +11 -19
  56. ultralytics/models/sam/predict.py +114 -101
  57. ultralytics/models/utils/loss.py +98 -77
  58. ultralytics/models/utils/ops.py +116 -67
  59. ultralytics/models/yolo/classify/predict.py +5 -5
  60. ultralytics/models/yolo/classify/train.py +32 -28
  61. ultralytics/models/yolo/classify/val.py +7 -8
  62. ultralytics/models/yolo/detect/predict.py +1 -0
  63. ultralytics/models/yolo/detect/train.py +15 -14
  64. ultralytics/models/yolo/detect/val.py +37 -36
  65. ultralytics/models/yolo/model.py +106 -23
  66. ultralytics/models/yolo/obb/predict.py +3 -4
  67. ultralytics/models/yolo/obb/train.py +14 -6
  68. ultralytics/models/yolo/obb/val.py +29 -23
  69. ultralytics/models/yolo/pose/predict.py +9 -8
  70. ultralytics/models/yolo/pose/train.py +24 -16
  71. ultralytics/models/yolo/pose/val.py +44 -26
  72. ultralytics/models/yolo/segment/predict.py +5 -5
  73. ultralytics/models/yolo/segment/train.py +11 -7
  74. ultralytics/models/yolo/segment/val.py +2 -2
  75. ultralytics/models/yolo/world/train.py +33 -23
  76. ultralytics/models/yolo/world/train_world.py +11 -3
  77. ultralytics/models/yolo/yoloe/predict.py +11 -11
  78. ultralytics/models/yolo/yoloe/train.py +73 -21
  79. ultralytics/models/yolo/yoloe/train_seg.py +10 -7
  80. ultralytics/models/yolo/yoloe/val.py +42 -18
  81. ultralytics/nn/autobackend.py +59 -15
  82. ultralytics/nn/modules/__init__.py +4 -4
  83. ultralytics/nn/modules/activation.py +4 -1
  84. ultralytics/nn/modules/block.py +178 -111
  85. ultralytics/nn/modules/conv.py +6 -5
  86. ultralytics/nn/modules/head.py +469 -121
  87. ultralytics/nn/modules/transformer.py +147 -58
  88. ultralytics/nn/tasks.py +227 -20
  89. ultralytics/nn/text_model.py +30 -33
  90. ultralytics/solutions/ai_gym.py +4 -6
  91. ultralytics/solutions/analytics.py +7 -4
  92. ultralytics/solutions/config.py +10 -10
  93. ultralytics/solutions/distance_calculation.py +11 -10
  94. ultralytics/solutions/heatmap.py +2 -2
  95. ultralytics/solutions/instance_segmentation.py +7 -4
  96. ultralytics/solutions/object_blurrer.py +3 -3
  97. ultralytics/solutions/object_counter.py +15 -11
  98. ultralytics/solutions/object_cropper.py +3 -2
  99. ultralytics/solutions/parking_management.py +29 -28
  100. ultralytics/solutions/queue_management.py +6 -6
  101. ultralytics/solutions/region_counter.py +10 -3
  102. ultralytics/solutions/security_alarm.py +3 -3
  103. ultralytics/solutions/similarity_search.py +85 -24
  104. ultralytics/solutions/solutions.py +189 -79
  105. ultralytics/solutions/speed_estimation.py +28 -22
  106. ultralytics/solutions/streamlit_inference.py +17 -12
  107. ultralytics/solutions/trackzone.py +4 -4
  108. ultralytics/trackers/basetrack.py +16 -23
  109. ultralytics/trackers/bot_sort.py +30 -20
  110. ultralytics/trackers/byte_tracker.py +70 -64
  111. ultralytics/trackers/track.py +4 -8
  112. ultralytics/trackers/utils/gmc.py +31 -58
  113. ultralytics/trackers/utils/kalman_filter.py +37 -37
  114. ultralytics/trackers/utils/matching.py +1 -1
  115. ultralytics/utils/__init__.py +105 -89
  116. ultralytics/utils/autobatch.py +16 -3
  117. ultralytics/utils/autodevice.py +54 -24
  118. ultralytics/utils/benchmarks.py +45 -29
  119. ultralytics/utils/callbacks/base.py +3 -3
  120. ultralytics/utils/callbacks/clearml.py +9 -9
  121. ultralytics/utils/callbacks/comet.py +67 -25
  122. ultralytics/utils/callbacks/dvc.py +7 -10
  123. ultralytics/utils/callbacks/mlflow.py +2 -5
  124. ultralytics/utils/callbacks/neptune.py +7 -13
  125. ultralytics/utils/callbacks/raytune.py +1 -1
  126. ultralytics/utils/callbacks/tensorboard.py +5 -6
  127. ultralytics/utils/callbacks/wb.py +14 -14
  128. ultralytics/utils/checks.py +14 -13
  129. ultralytics/utils/dist.py +5 -5
  130. ultralytics/utils/downloads.py +94 -67
  131. ultralytics/utils/errors.py +5 -5
  132. ultralytics/utils/export.py +61 -47
  133. ultralytics/utils/files.py +23 -22
  134. ultralytics/utils/instance.py +48 -52
  135. ultralytics/utils/loss.py +78 -40
  136. ultralytics/utils/metrics.py +186 -130
  137. ultralytics/utils/ops.py +186 -190
  138. ultralytics/utils/patches.py +15 -17
  139. ultralytics/utils/plotting.py +71 -27
  140. ultralytics/utils/tal.py +21 -15
  141. ultralytics/utils/torch_utils.py +53 -50
  142. ultralytics/utils/triton.py +5 -4
  143. ultralytics/utils/tuner.py +5 -5
  144. dgenerate_ultralytics_headless-8.3.143.dist-info/RECORD +0 -272
  145. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/WHEEL +0 -0
  146. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/entry_points.txt +0 -0
  147. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/licenses/LICENSE +0 -0
  148. {dgenerate_ultralytics_headless-8.3.143.dist-info → dgenerate_ultralytics_headless-8.3.145.dist-info}/top_level.txt +0 -0
@@ -60,6 +60,9 @@ class BaseTrainer:
60
60
  """
61
61
  A base class for creating trainers.
62
62
 
63
+ This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
64
+ and various training utilities. It supports both single-GPU and multi-GPU distributed training.
65
+
63
66
  Attributes:
64
67
  args (SimpleNamespace): Configuration for the trainer.
65
68
  validator (BaseValidator): Validator instance.
@@ -89,6 +92,19 @@ class BaseTrainer:
89
92
  csv (Path): Path to results CSV file.
90
93
  metrics (dict): Dictionary of metrics.
91
94
  plots (dict): Dictionary of plots.
95
+
96
+ Methods:
97
+ train: Execute the training process.
98
+ validate: Run validation on the test set.
99
+ save_model: Save model training checkpoints.
100
+ get_dataset: Get train and validation datasets.
101
+ setup_model: Load, create, or download model.
102
+ build_optimizer: Construct an optimizer for the model.
103
+
104
+ Examples:
105
+ Initialize a trainer and start training
106
+ >>> trainer = BaseTrainer(cfg="config.yaml")
107
+ >>> trainer.train()
92
108
  """
93
109
 
94
110
  def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -96,14 +112,14 @@ class BaseTrainer:
96
112
  Initialize the BaseTrainer class.
97
113
 
98
114
  Args:
99
- cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
100
- overrides (dict, optional): Configuration overrides. Defaults to None.
101
- _callbacks (list, optional): List of callback functions. Defaults to None.
115
+ cfg (str, optional): Path to a configuration file.
116
+ overrides (dict, optional): Configuration overrides.
117
+ _callbacks (list, optional): List of callback functions.
102
118
  """
103
119
  self.args = get_cfg(cfg, overrides)
104
120
  self.check_resume(overrides)
105
121
  self.device = select_device(self.args.device, self.args.batch)
106
- # update "-1" devices so post-training val does not repeat search
122
+ # Update "-1" devices so post-training val does not repeat search
107
123
  self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
108
124
  self.validator = None
109
125
  self.metrics = None
@@ -626,7 +642,7 @@ class BaseTrainer:
626
642
  self.ema.update(self.model)
627
643
 
628
644
  def preprocess_batch(self, batch):
629
- """Allows custom preprocessing model inputs and ground truths depending on task type."""
645
+ """Allow custom preprocessing model inputs and ground truths depending on task type."""
630
646
  return batch
631
647
 
632
648
  def validate(self):
@@ -634,7 +650,8 @@ class BaseTrainer:
634
650
  Run validation on test set using self.validator.
635
651
 
636
652
  Returns:
637
- (tuple): A tuple containing metrics dictionary and fitness score.
653
+ metrics (dict): Dictionary of validation metrics.
654
+ fitness (float): Fitness score for the validation.
638
655
  """
639
656
  metrics = self.validator(self)
640
657
  fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
@@ -647,11 +664,11 @@ class BaseTrainer:
647
664
  raise NotImplementedError("This task trainer doesn't support loading cfg files")
648
665
 
649
666
  def get_validator(self):
650
- """Returns a NotImplementedError when the get_validator function is called."""
667
+ """Raise a NotImplementedError when the get_validator function is called."""
651
668
  raise NotImplementedError("get_validator function not implemented in trainer")
652
669
 
653
670
  def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
654
- """Returns dataloader derived from torch.data.Dataloader."""
671
+ """Return dataloader derived from torch.utils.data.DataLoader."""
655
672
  raise NotImplementedError("get_dataloader function not implemented in trainer")
656
673
 
657
674
  def build_dataset(self, img_path, mode="train", batch=None):
@@ -660,7 +677,7 @@ class BaseTrainer:
660
677
 
661
678
  def label_loss_items(self, loss_items=None, prefix="train"):
662
679
  """
663
- Returns a loss dict with labelled training loss items tensor.
680
+ Return a loss dict with labelled training loss items tensor.
664
681
 
665
682
  Note:
666
683
  This is not needed for classification but necessary for segmentation & detection
@@ -672,20 +689,20 @@ class BaseTrainer:
672
689
  self.model.names = self.data["names"]
673
690
 
674
691
  def build_targets(self, preds, targets):
675
- """Builds target tensors for training YOLO model."""
692
+ """Build target tensors for training YOLO model."""
676
693
  pass
677
694
 
678
695
  def progress_string(self):
679
- """Returns a string describing training progress."""
696
+ """Return a string describing training progress."""
680
697
  return ""
681
698
 
682
699
  # TODO: may need to put these following functions into callback
683
700
  def plot_training_samples(self, batch, ni):
684
- """Plots training samples during YOLO training."""
701
+ """Plot training samples during YOLO training."""
685
702
  pass
686
703
 
687
704
  def plot_training_labels(self):
688
- """Plots training labels for YOLO model."""
705
+ """Plot training labels for YOLO model."""
689
706
  pass
690
707
 
691
708
  def save_metrics(self, metrics):
@@ -702,7 +719,7 @@ class BaseTrainer:
702
719
  pass
703
720
 
704
721
  def on_plot(self, name, data=None):
705
- """Registers plots (e.g. to be consumed in callbacks)."""
722
+ """Register plots (e.g. to be consumed in callbacks)."""
706
723
  path = Path(name)
707
724
  self.plots[path] = {"data": data, "timestamp": time.time()}
708
725
 
@@ -796,12 +813,12 @@ class BaseTrainer:
796
813
  Args:
797
814
  model (torch.nn.Module): The model for which to build an optimizer.
798
815
  name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
799
- based on the number of iterations. Default: 'auto'.
800
- lr (float, optional): The learning rate for the optimizer. Default: 0.001.
801
- momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
802
- decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
816
+ based on the number of iterations.
817
+ lr (float, optional): The learning rate for the optimizer.
818
+ momentum (float, optional): The momentum factor for the optimizer.
819
+ decay (float, optional): The weight decay for the optimizer.
803
820
  iterations (float, optional): The number of iterations, which determines the optimizer if
804
- name is 'auto'. Default: 1e5.
821
+ name is 'auto'.
805
822
 
806
823
  Returns:
807
824
  (torch.optim.Optimizer): The constructed optimizer.
@@ -18,6 +18,7 @@ import random
18
18
  import shutil
19
19
  import subprocess
20
20
  import time
21
+ from typing import Dict, List, Optional
21
22
 
22
23
  import numpy as np
23
24
  import torch
@@ -35,7 +36,7 @@ class Tuner:
35
36
  search space and retraining the model to evaluate their performance.
36
37
 
37
38
  Attributes:
38
- space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
39
+ space (Dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
39
40
  tune_dir (Path): Directory where evolution logs and results will be saved.
40
41
  tune_csv (Path): Path to the CSV file where evolution logs are saved.
41
42
  args (dict): Configuration arguments for the tuning process.
@@ -43,8 +44,8 @@ class Tuner:
43
44
  prefix (str): Prefix string for logging messages.
44
45
 
45
46
  Methods:
46
- _mutate: Mutates the given hyperparameters within the specified bounds.
47
- __call__: Executes the hyperparameter evolution across multiple iterations.
47
+ _mutate: Mutate hyperparameters based on bounds and scaling factors.
48
+ __call__: Execute the hyperparameter evolution across multiple iterations.
48
49
 
49
50
  Examples:
50
51
  Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
@@ -58,13 +59,13 @@ class Tuner:
58
59
  >>> model.tune(space={key1: val1, key2: val2}) # custom search space dictionary
59
60
  """
60
61
 
61
- def __init__(self, args=DEFAULT_CFG, _callbacks=None):
62
+ def __init__(self, args=DEFAULT_CFG, _callbacks: Optional[List] = None):
62
63
  """
63
64
  Initialize the Tuner with configurations.
64
65
 
65
66
  Args:
66
67
  args (dict): Configuration for hyperparameter evolution.
67
- _callbacks (list, optional): Callback functions to be executed during tuning.
68
+ _callbacks (List, optional): Callback functions to be executed during tuning.
68
69
  """
69
70
  self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
70
71
  # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
@@ -106,7 +107,9 @@ class Tuner:
106
107
  f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
107
108
  )
108
109
 
109
- def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
110
+ def _mutate(
111
+ self, parent: str = "single", n: int = 5, mutation: float = 0.8, sigma: float = 0.2
112
+ ) -> Dict[str, float]:
110
113
  """
111
114
  Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
112
115
 
@@ -117,7 +120,7 @@ class Tuner:
117
120
  sigma (float): Standard deviation for Gaussian random number generator.
118
121
 
119
122
  Returns:
120
- (dict): A dictionary containing mutated hyperparameters.
123
+ (Dict[str, float]): A dictionary containing mutated hyperparameters.
121
124
  """
122
125
  if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
123
126
  # Select parent(s)
@@ -152,14 +155,14 @@ class Tuner:
152
155
 
153
156
  return hyp
154
157
 
155
- def __call__(self, model=None, iterations=10, cleanup=True):
158
+ def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
156
159
  """
157
160
  Execute the hyperparameter evolution process when the Tuner instance is called.
158
161
 
159
162
  This method iterates through the number of iterations, performing the following steps in each iteration:
160
163
 
161
164
  1. Load the existing hyperparameters or initialize new ones.
162
- 2. Mutate the hyperparameters using the `mutate` method.
165
+ 2. Mutate the hyperparameters using the `_mutate` method.
163
166
  3. Train a YOLO model with the mutated hyperparameters.
164
167
  4. Log the fitness score and mutated hyperparameters to a CSV file.
165
168
 
@@ -67,6 +67,8 @@ class BaseValidator:
67
67
  save_dir (Path): Directory to save results.
68
68
  plots (dict): Dictionary to store plots for visualization.
69
69
  callbacks (dict): Dictionary to store various callback functions.
70
+ stride (int): Model stride for padding calculations.
71
+ loss (torch.Tensor): Accumulated loss during training validation.
70
72
 
71
73
  Methods:
72
74
  __call__: Execute validation process, running inference on dataloader and computing performance metrics.
@@ -84,7 +86,7 @@ class BaseValidator:
84
86
  check_stats: Check statistics.
85
87
  print_results: Print the results of the model's predictions.
86
88
  get_desc: Get description of the YOLO model.
87
- on_plot: Register plots (e.g. to be consumed in callbacks).
89
+ on_plot: Register plots for visualization.
88
90
  plot_val_samples: Plot validation samples during training.
89
91
  plot_predictions: Plot YOLO model predictions on batch images.
90
92
  pred_to_json: Convert predictions to JSON format.
@@ -138,7 +140,7 @@ class BaseValidator:
138
140
  model (nn.Module, optional): Model to validate if not using a trainer.
139
141
 
140
142
  Returns:
141
- stats (dict): Dictionary containing validation statistics.
143
+ (dict): Dictionary containing validation statistics.
142
144
  """
143
145
  self.training = trainer is not None
144
146
  augment = self.args.augment and (not self.training)
@@ -149,7 +151,6 @@ class BaseValidator:
149
151
  self.args.half = self.device.type != "cpu" and trainer.amp
150
152
  model = trainer.ema.ema or trainer.model
151
153
  model = model.half() if self.args.half else model.float()
152
- # self.model = model
153
154
  self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
154
155
  self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
155
156
  model.eval()
@@ -164,7 +165,6 @@ class BaseValidator:
164
165
  data=self.args.data,
165
166
  fp16=self.args.half,
166
167
  )
167
- # self.model = model
168
168
  self.device = model.device # update device
169
169
  self.args.half = model.fp16 # update half
170
170
  stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
@@ -184,7 +184,7 @@ class BaseValidator:
184
184
 
185
185
  if self.device.type in {"cpu", "mps"}:
186
186
  self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
187
- if not (pt or getattr(model, "dynamic", False)):
187
+ if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
188
188
  self.args.rect = False
189
189
  self.stride = model.stride # used in get_dataloader() for padding
190
190
  self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
@@ -263,7 +263,7 @@ class BaseValidator:
263
263
  pred_classes (torch.Tensor): Predicted class indices of shape (N,).
264
264
  true_classes (torch.Tensor): Target class indices of shape (M,).
265
265
  iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
266
- use_scipy (bool): Whether to use scipy for matching (more precise).
266
+ use_scipy (bool, optional): Whether to use scipy for matching (more precise).
267
267
 
268
268
  Returns:
269
269
  (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
@@ -292,7 +292,6 @@ class BaseValidator:
292
292
  if matches.shape[0] > 1:
293
293
  matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
294
294
  matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
295
- # matches = matches[matches[:, 2].argsort()[::-1]]
296
295
  matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
297
296
  correct[matches[:, 1].astype(int), i] = True
298
297
  return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
@@ -356,10 +355,9 @@ class BaseValidator:
356
355
  return []
357
356
 
358
357
  def on_plot(self, name, data=None):
359
- """Register plots (e.g. to be consumed in callbacks)."""
358
+ """Register plots for visualization."""
360
359
  self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
361
360
 
362
- # TODO: may need to put these following functions into callback
363
361
  def plot_val_samples(self, batch, ni):
364
362
  """Plot validation samples during training."""
365
363
  pass
@@ -31,8 +31,8 @@ def login(api_key: str = None, save: bool = True) -> bool:
31
31
  environment variable if successfully authenticated.
32
32
 
33
33
  Args:
34
- api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from SETTINGS
35
- or HUB_API_KEY environment variable.
34
+ api_key (str, optional): API key to use for authentication. If not provided, it will be retrieved from
35
+ SETTINGS or HUB_API_KEY environment variable.
36
36
  save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
37
37
 
38
38
  Returns:
@@ -68,13 +68,7 @@ def login(api_key: str = None, save: bool = True) -> bool:
68
68
 
69
69
 
70
70
  def logout():
71
- """
72
- Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo login'.
73
-
74
- Examples:
75
- >>> from ultralytics import hub
76
- >>> hub.logout()
77
- """
71
+ """Log out of Ultralytics HUB by removing the API key from the settings file."""
78
72
  SETTINGS["api_key"] = ""
79
73
  LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
80
74
 
@@ -89,7 +83,7 @@ def reset_model(model_id: str = ""):
89
83
 
90
84
 
91
85
  def export_fmts_hub():
92
- """Returns a list of HUB-supported export formats."""
86
+ """Return a list of HUB-supported export formats."""
93
87
  from ultralytics.engine.exporter import export_formats
94
88
 
95
89
  return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
@@ -125,14 +119,18 @@ def get_export(model_id: str = "", format: str = "torchscript"):
125
119
 
126
120
  Args:
127
121
  model_id (str): The ID of the model to retrieve from Ultralytics HUB.
128
- format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().
122
+ format (str): The export format to retrieve. Must be one of the supported formats returned by
123
+ export_fmts_hub().
124
+
125
+ Returns:
126
+ (dict): JSON response containing the exported model information.
129
127
 
130
128
  Raises:
131
129
  AssertionError: If the specified format is not supported or if the API request fails.
132
130
 
133
131
  Examples:
134
132
  >>> from ultralytics import hub
135
- >>> hub.get_export(model_id="your_model_id", format="torchscript")
133
+ >>> result = hub.get_export(model_id="your_model_id", format="torchscript")
136
134
  """
137
135
  assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
138
136
  r = requests.post(
@@ -160,7 +158,7 @@ def check_dataset(path: str, task: str) -> None:
160
158
  >>> check_dataset("path/to/dota8.zip", task="obb") # OBB dataset
161
159
  >>> check_dataset("path/to/imagenet10.zip", task="classify") # classification dataset
162
160
 
163
- Note:
161
+ Notes:
164
162
  Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
165
163
  i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
166
164
  """
ultralytics/hub/auth.py CHANGED
@@ -21,6 +21,19 @@ class Auth:
21
21
  id_token (str | bool): Token used for identity verification, initialized as False.
22
22
  api_key (str | bool): API key for authentication, initialized as False.
23
23
  model_key (bool): Placeholder for model key, initialized as False.
24
+
25
+ Methods:
26
+ authenticate: Attempt to authenticate with the server using either id_token or API key.
27
+ auth_with_cookies: Attempt to fetch authentication via cookies and set id_token.
28
+ get_auth_header: Get the authentication header for making API requests.
29
+ request_api_key: Prompt the user to input their API key.
30
+
31
+ Examples:
32
+ Initialize Auth with an API key
33
+ >>> auth = Auth(api_key="your_api_key_here")
34
+
35
+ Initialize Auth without API key (will prompt for input)
36
+ >>> auth = Auth()
24
37
  """
25
38
 
26
39
  id_token = api_key = model_key = False
@@ -71,7 +84,15 @@ class Auth:
71
84
  LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
72
85
 
73
86
  def request_api_key(self, max_attempts: int = 3) -> bool:
74
- """Prompt the user to input their API key."""
87
+ """
88
+ Prompt the user to input their API key.
89
+
90
+ Args:
91
+ max_attempts (int): Maximum number of authentication attempts.
92
+
93
+ Returns:
94
+ (bool): True if authentication is successful, False otherwise.
95
+ """
75
96
  import getpass
76
97
 
77
98
  for attempts in range(max_attempts):
@@ -134,4 +155,3 @@ class Auth:
134
155
  return {"authorization": f"Bearer {self.id_token}"}
135
156
  elif self.api_key:
136
157
  return {"x-api-key": self.api_key}
137
- # else returns None
@@ -31,7 +31,7 @@ class GCPRegions:
31
31
  """
32
32
 
33
33
  def __init__(self):
34
- """Initializes the GCPRegions class with predefined Google Cloud Platform regions and their details."""
34
+ """Initialize the GCPRegions class with predefined Google Cloud Platform regions and their details."""
35
35
  self.regions = {
36
36
  "asia-east1": (1, "Taiwan", "China"),
37
37
  "asia-east2": (2, "Hong Kong", "China"),
@@ -74,11 +74,11 @@ class GCPRegions:
74
74
  }
75
75
 
76
76
  def tier1(self) -> List[str]:
77
- """Returns a list of GCP regions classified as tier 1 based on predefined criteria."""
77
+ """Return a list of GCP regions classified as tier 1 based on predefined criteria."""
78
78
  return [region for region, info in self.regions.items() if info[0] == 1]
79
79
 
80
80
  def tier2(self) -> List[str]:
81
- """Returns a list of GCP regions classified as tier 2 based on predefined criteria."""
81
+ """Return a list of GCP regions classified as tier 2 based on predefined criteria."""
82
82
  return [region for region, info in self.regions.items() if info[0] == 2]
83
83
 
84
84
  @staticmethod
@@ -87,19 +87,19 @@ class GCPRegions:
87
87
  Ping a specified GCP region and measure network latency statistics.
88
88
 
89
89
  Args:
90
- region (str): The GCP region identifier to ping (e.g., 'us-central1').
91
- attempts (int): Number of ping attempts to make for calculating statistics.
90
+ region (str): The GCP region identifier to ping (e.g., 'us-central1').
91
+ attempts (int, optional): Number of ping attempts to make for calculating statistics.
92
92
 
93
93
  Returns:
94
- region (str): The GCP region identifier that was pinged.
95
- mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.
96
- min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.
97
- max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.
98
- std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.
94
+ region (str): The GCP region identifier that was pinged.
95
+ mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.
96
+ std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.
97
+ min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.
98
+ max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.
99
99
 
100
100
  Examples:
101
- >>> region, mean, min_lat, max_lat, std = GCPRegions._ping_region("us-central1", attempts=3)
102
- >>> print(f"Region {region} has mean latency: {mean:.2f}ms")
101
+ >>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region("us-central1", attempts=3)
102
+ >>> print(f"Region {region} has mean latency: {mean:.2f}ms")
103
103
  """
104
104
  url = f"https://{region}-docker.pkg.dev"
105
105
  latencies = []
@@ -107,7 +107,7 @@ class GCPRegions:
107
107
  try:
108
108
  start_time = time.time()
109
109
  _ = requests.head(url, timeout=5)
110
- latency = (time.time() - start_time) * 1000 # convert latency to milliseconds
110
+ latency = (time.time() - start_time) * 1000 # Convert latency to milliseconds
111
111
  if latency != float("inf"):
112
112
  latencies.append(latency)
113
113
  except requests.RequestException:
@@ -126,17 +126,17 @@ class GCPRegions:
126
126
  attempts: int = 1,
127
127
  ) -> List[Tuple[str, float, float, float, float]]:
128
128
  """
129
- Determines the GCP regions with the lowest latency based on ping tests.
129
+ Determine the GCP regions with the lowest latency based on ping tests.
130
130
 
131
131
  Args:
132
- top (int): Number of top regions to return.
133
- verbose (bool): If True, prints detailed latency information for all tested regions.
134
- tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
135
- attempts (int): Number of ping attempts per region.
132
+ top (int, optional): Number of top regions to return.
133
+ verbose (bool, optional): If True, prints detailed latency information for all tested regions.
134
+ tier (int | None, optional): Filter regions by tier (1 or 2). If None, all regions are tested.
135
+ attempts (int, optional): Number of ping attempts per region.
136
136
 
137
137
  Returns:
138
138
  (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
139
- latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
139
+ latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
140
140
 
141
141
  Examples:
142
142
  >>> regions = GCPRegions()