dgenerate-ultralytics-headless 8.3.137-py3-none-any.whl → 8.3.224-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
  2. dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
  3. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -6
  5. tests/conftest.py +15 -39
  6. tests/test_cli.py +17 -17
  7. tests/test_cuda.py +17 -8
  8. tests/test_engine.py +36 -10
  9. tests/test_exports.py +98 -37
  10. tests/test_integrations.py +12 -15
  11. tests/test_python.py +126 -82
  12. tests/test_solutions.py +319 -135
  13. ultralytics/__init__.py +27 -9
  14. ultralytics/cfg/__init__.py +83 -87
  15. ultralytics/cfg/datasets/Argoverse.yaml +4 -4
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
  17. ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
  18. ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
  19. ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
  20. ultralytics/cfg/datasets/ImageNet.yaml +3 -3
  21. ultralytics/cfg/datasets/Objects365.yaml +24 -20
  22. ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
  23. ultralytics/cfg/datasets/VOC.yaml +10 -13
  24. ultralytics/cfg/datasets/VisDrone.yaml +43 -33
  25. ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
  26. ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
  27. ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
  28. ultralytics/cfg/datasets/coco-pose.yaml +26 -4
  29. ultralytics/cfg/datasets/coco.yaml +4 -4
  30. ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
  31. ultralytics/cfg/datasets/coco128.yaml +2 -2
  32. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  33. ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
  34. ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
  35. ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
  36. ultralytics/cfg/datasets/coco8.yaml +2 -2
  37. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  38. ultralytics/cfg/datasets/crack-seg.yaml +5 -5
  39. ultralytics/cfg/datasets/dog-pose.yaml +32 -4
  40. ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
  41. ultralytics/cfg/datasets/dota8.yaml +2 -2
  42. ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
  43. ultralytics/cfg/datasets/lvis.yaml +9 -9
  44. ultralytics/cfg/datasets/medical-pills.yaml +4 -5
  45. ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
  46. ultralytics/cfg/datasets/package-seg.yaml +5 -5
  47. ultralytics/cfg/datasets/signature.yaml +4 -4
  48. ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
  49. ultralytics/cfg/datasets/xView.yaml +5 -5
  50. ultralytics/cfg/default.yaml +96 -93
  51. ultralytics/cfg/trackers/botsort.yaml +16 -17
  52. ultralytics/cfg/trackers/bytetrack.yaml +9 -11
  53. ultralytics/data/__init__.py +4 -4
  54. ultralytics/data/annotator.py +12 -12
  55. ultralytics/data/augment.py +531 -564
  56. ultralytics/data/base.py +76 -81
  57. ultralytics/data/build.py +206 -42
  58. ultralytics/data/converter.py +179 -78
  59. ultralytics/data/dataset.py +121 -121
  60. ultralytics/data/loaders.py +114 -91
  61. ultralytics/data/split.py +28 -15
  62. ultralytics/data/split_dota.py +67 -48
  63. ultralytics/data/utils.py +110 -89
  64. ultralytics/engine/exporter.py +422 -460
  65. ultralytics/engine/model.py +224 -252
  66. ultralytics/engine/predictor.py +94 -89
  67. ultralytics/engine/results.py +345 -595
  68. ultralytics/engine/trainer.py +231 -134
  69. ultralytics/engine/tuner.py +279 -73
  70. ultralytics/engine/validator.py +53 -46
  71. ultralytics/hub/__init__.py +26 -28
  72. ultralytics/hub/auth.py +30 -16
  73. ultralytics/hub/google/__init__.py +34 -36
  74. ultralytics/hub/session.py +53 -77
  75. ultralytics/hub/utils.py +23 -109
  76. ultralytics/models/__init__.py +1 -1
  77. ultralytics/models/fastsam/__init__.py +1 -1
  78. ultralytics/models/fastsam/model.py +36 -18
  79. ultralytics/models/fastsam/predict.py +33 -44
  80. ultralytics/models/fastsam/utils.py +4 -5
  81. ultralytics/models/fastsam/val.py +12 -14
  82. ultralytics/models/nas/__init__.py +1 -1
  83. ultralytics/models/nas/model.py +16 -20
  84. ultralytics/models/nas/predict.py +12 -14
  85. ultralytics/models/nas/val.py +4 -5
  86. ultralytics/models/rtdetr/__init__.py +1 -1
  87. ultralytics/models/rtdetr/model.py +9 -9
  88. ultralytics/models/rtdetr/predict.py +22 -17
  89. ultralytics/models/rtdetr/train.py +20 -16
  90. ultralytics/models/rtdetr/val.py +79 -59
  91. ultralytics/models/sam/__init__.py +8 -2
  92. ultralytics/models/sam/amg.py +53 -38
  93. ultralytics/models/sam/build.py +29 -31
  94. ultralytics/models/sam/model.py +33 -38
  95. ultralytics/models/sam/modules/blocks.py +159 -182
  96. ultralytics/models/sam/modules/decoders.py +38 -47
  97. ultralytics/models/sam/modules/encoders.py +114 -133
  98. ultralytics/models/sam/modules/memory_attention.py +38 -31
  99. ultralytics/models/sam/modules/sam.py +114 -93
  100. ultralytics/models/sam/modules/tiny_encoder.py +268 -291
  101. ultralytics/models/sam/modules/transformer.py +59 -66
  102. ultralytics/models/sam/modules/utils.py +55 -72
  103. ultralytics/models/sam/predict.py +745 -341
  104. ultralytics/models/utils/loss.py +118 -107
  105. ultralytics/models/utils/ops.py +118 -71
  106. ultralytics/models/yolo/__init__.py +1 -1
  107. ultralytics/models/yolo/classify/predict.py +28 -26
  108. ultralytics/models/yolo/classify/train.py +50 -81
  109. ultralytics/models/yolo/classify/val.py +68 -61
  110. ultralytics/models/yolo/detect/predict.py +12 -15
  111. ultralytics/models/yolo/detect/train.py +56 -46
  112. ultralytics/models/yolo/detect/val.py +279 -223
  113. ultralytics/models/yolo/model.py +167 -86
  114. ultralytics/models/yolo/obb/predict.py +7 -11
  115. ultralytics/models/yolo/obb/train.py +23 -25
  116. ultralytics/models/yolo/obb/val.py +107 -99
  117. ultralytics/models/yolo/pose/__init__.py +1 -1
  118. ultralytics/models/yolo/pose/predict.py +12 -14
  119. ultralytics/models/yolo/pose/train.py +31 -69
  120. ultralytics/models/yolo/pose/val.py +119 -254
  121. ultralytics/models/yolo/segment/predict.py +21 -25
  122. ultralytics/models/yolo/segment/train.py +12 -66
  123. ultralytics/models/yolo/segment/val.py +126 -305
  124. ultralytics/models/yolo/world/train.py +53 -45
  125. ultralytics/models/yolo/world/train_world.py +51 -32
  126. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  127. ultralytics/models/yolo/yoloe/predict.py +30 -37
  128. ultralytics/models/yolo/yoloe/train.py +89 -71
  129. ultralytics/models/yolo/yoloe/train_seg.py +15 -17
  130. ultralytics/models/yolo/yoloe/val.py +56 -41
  131. ultralytics/nn/__init__.py +9 -11
  132. ultralytics/nn/autobackend.py +179 -107
  133. ultralytics/nn/modules/__init__.py +67 -67
  134. ultralytics/nn/modules/activation.py +8 -7
  135. ultralytics/nn/modules/block.py +302 -323
  136. ultralytics/nn/modules/conv.py +61 -104
  137. ultralytics/nn/modules/head.py +488 -186
  138. ultralytics/nn/modules/transformer.py +183 -123
  139. ultralytics/nn/modules/utils.py +15 -20
  140. ultralytics/nn/tasks.py +327 -203
  141. ultralytics/nn/text_model.py +81 -65
  142. ultralytics/py.typed +1 -0
  143. ultralytics/solutions/__init__.py +12 -12
  144. ultralytics/solutions/ai_gym.py +19 -27
  145. ultralytics/solutions/analytics.py +36 -26
  146. ultralytics/solutions/config.py +29 -28
  147. ultralytics/solutions/distance_calculation.py +23 -24
  148. ultralytics/solutions/heatmap.py +17 -19
  149. ultralytics/solutions/instance_segmentation.py +21 -19
  150. ultralytics/solutions/object_blurrer.py +16 -17
  151. ultralytics/solutions/object_counter.py +48 -53
  152. ultralytics/solutions/object_cropper.py +22 -16
  153. ultralytics/solutions/parking_management.py +61 -58
  154. ultralytics/solutions/queue_management.py +19 -19
  155. ultralytics/solutions/region_counter.py +63 -50
  156. ultralytics/solutions/security_alarm.py +22 -25
  157. ultralytics/solutions/similarity_search.py +107 -60
  158. ultralytics/solutions/solutions.py +343 -262
  159. ultralytics/solutions/speed_estimation.py +35 -31
  160. ultralytics/solutions/streamlit_inference.py +104 -40
  161. ultralytics/solutions/templates/similarity-search.html +31 -24
  162. ultralytics/solutions/trackzone.py +24 -24
  163. ultralytics/solutions/vision_eye.py +11 -12
  164. ultralytics/trackers/__init__.py +1 -1
  165. ultralytics/trackers/basetrack.py +18 -27
  166. ultralytics/trackers/bot_sort.py +48 -39
  167. ultralytics/trackers/byte_tracker.py +94 -94
  168. ultralytics/trackers/track.py +7 -16
  169. ultralytics/trackers/utils/gmc.py +37 -69
  170. ultralytics/trackers/utils/kalman_filter.py +68 -76
  171. ultralytics/trackers/utils/matching.py +13 -17
  172. ultralytics/utils/__init__.py +251 -275
  173. ultralytics/utils/autobatch.py +19 -7
  174. ultralytics/utils/autodevice.py +68 -38
  175. ultralytics/utils/benchmarks.py +169 -130
  176. ultralytics/utils/callbacks/base.py +12 -13
  177. ultralytics/utils/callbacks/clearml.py +14 -15
  178. ultralytics/utils/callbacks/comet.py +139 -66
  179. ultralytics/utils/callbacks/dvc.py +19 -27
  180. ultralytics/utils/callbacks/hub.py +8 -6
  181. ultralytics/utils/callbacks/mlflow.py +6 -10
  182. ultralytics/utils/callbacks/neptune.py +11 -19
  183. ultralytics/utils/callbacks/platform.py +73 -0
  184. ultralytics/utils/callbacks/raytune.py +3 -4
  185. ultralytics/utils/callbacks/tensorboard.py +9 -12
  186. ultralytics/utils/callbacks/wb.py +33 -30
  187. ultralytics/utils/checks.py +163 -114
  188. ultralytics/utils/cpu.py +89 -0
  189. ultralytics/utils/dist.py +24 -20
  190. ultralytics/utils/downloads.py +176 -146
  191. ultralytics/utils/errors.py +11 -13
  192. ultralytics/utils/events.py +113 -0
  193. ultralytics/utils/export/__init__.py +7 -0
  194. ultralytics/utils/{export.py → export/engine.py} +81 -63
  195. ultralytics/utils/export/imx.py +294 -0
  196. ultralytics/utils/export/tensorflow.py +217 -0
  197. ultralytics/utils/files.py +33 -36
  198. ultralytics/utils/git.py +137 -0
  199. ultralytics/utils/instance.py +105 -120
  200. ultralytics/utils/logger.py +404 -0
  201. ultralytics/utils/loss.py +99 -61
  202. ultralytics/utils/metrics.py +649 -478
  203. ultralytics/utils/nms.py +337 -0
  204. ultralytics/utils/ops.py +263 -451
  205. ultralytics/utils/patches.py +70 -31
  206. ultralytics/utils/plotting.py +253 -223
  207. ultralytics/utils/tal.py +48 -61
  208. ultralytics/utils/torch_utils.py +244 -251
  209. ultralytics/utils/tqdm.py +438 -0
  210. ultralytics/utils/triton.py +22 -23
  211. ultralytics/utils/tuner.py +11 -10
  212. dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
  213. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
  214. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
  215. {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/engine/validator.py CHANGED
@@ -29,19 +29,19 @@ from pathlib import Path
 
 import numpy as np
 import torch
+import torch.distributed as dist
 
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.autobackend import AutoBackend
-from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
+from ultralytics.utils import LOGGER, RANK, TQDM, callbacks, colorstr, emojis
 from ultralytics.utils.checks import check_imgsz
 from ultralytics.utils.ops import Profile
-from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
+from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
 
 
 class BaseValidator:
-    """
-    A base class for creating validators.
+    """A base class for creating validators.
 
     This class provides the foundation for validation processes, including model evaluation, metric computation, and
     result visualization.
@@ -49,7 +49,6 @@ class BaseValidator:
     Attributes:
         args (SimpleNamespace): Configuration for the validator.
         dataloader (DataLoader): Dataloader to use for validation.
-        pbar (tqdm): Progress bar to update during validation.
         model (nn.Module): Model to validate.
         data (dict): Data dictionary containing dataset information.
         device (torch.device): Device to use for validation.
@@ -62,11 +61,13 @@ class BaseValidator:
         nc (int): Number of classes.
         iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
         jdict (list): List to store JSON validation results.
-        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
-            batch processing times in milliseconds.
+        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective batch
+            processing times in milliseconds.
         save_dir (Path): Directory to save results.
         plots (dict): Dictionary to store plots for visualization.
         callbacks (dict): Dictionary to store various callback functions.
+        stride (int): Model stride for padding calculations.
+        loss (torch.Tensor): Accumulated loss during training validation.
 
     Methods:
         __call__: Execute validation process, running inference on dataloader and computing performance metrics.
@@ -81,30 +82,28 @@ class BaseValidator:
         update_metrics: Update metrics based on predictions and batch.
         finalize_metrics: Finalize and return all metrics.
         get_stats: Return statistics about the model's performance.
-        check_stats: Check statistics.
         print_results: Print the results of the model's predictions.
         get_desc: Get description of the YOLO model.
-        on_plot: Register plots (e.g. to be consumed in callbacks).
+        on_plot: Register plots for visualization.
         plot_val_samples: Plot validation samples during training.
         plot_predictions: Plot YOLO model predictions on batch images.
        pred_to_json: Convert predictions to JSON format.
         eval_json: Evaluate and return JSON format of prediction statistics.
     """
 
-    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
-        """
-        Initialize a BaseValidator instance.
+    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
+        """Initialize a BaseValidator instance.
 
         Args:
             dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
             save_dir (Path, optional): Directory to save results.
-            pbar (tqdm.tqdm, optional): Progress bar for displaying progress.
             args (SimpleNamespace, optional): Configuration for the validator.
             _callbacks (dict, optional): Dictionary to store various callback functions.
         """
+        import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)
+
         self.args = get_cfg(overrides=args)
         self.dataloader = dataloader
-        self.pbar = pbar
         self.stride = None
         self.data = None
         self.device = None
@@ -122,7 +121,7 @@ class BaseValidator:
         self.save_dir = save_dir or get_save_dir(self.args)
         (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
         if self.args.conf is None:
-            self.args.conf = 0.001  # default conf=0.001
+            self.args.conf = 0.01 if self.args.task == "obb" else 0.001  # reduce OBB val memory usage
         self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
 
         self.plots = {}
@@ -130,15 +129,14 @@ class BaseValidator:
 
     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):
-        """
-        Execute validation process, running inference on dataloader and computing performance metrics.
+        """Execute validation process, running inference on dataloader and computing performance metrics.
 
         Args:
             trainer (object, optional): Trainer object that contains the model to validate.
             model (nn.Module, optional): Model to validate if not using a trainer.
 
         Returns:
-            stats (dict): Dictionary containing validation statistics.
+            (dict): Dictionary containing validation statistics.
         """
         self.training = trainer is not None
         augment = self.args.augment and (not self.training)
@@ -148,8 +146,9 @@ class BaseValidator:
             # Force FP16 val during training
             self.args.half = self.device.type != "cpu" and trainer.amp
             model = trainer.ema.ema or trainer.model
+            if trainer.args.compile and hasattr(model, "_orig_mod"):
+                model = model._orig_mod  # validate non-compiled original model to avoid issues
             model = model.half() if self.args.half else model.float()
-            # self.model = model
             self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
             self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
             model.eval()
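A note on the `_orig_mod` check added above: `torch.compile` wraps an `nn.Module` in an `OptimizedModule` that keeps the uncompiled module in its `_orig_mod` attribute, which is what the new code unwraps before validating. A minimal sketch of the pattern (the `Net` class here is a hypothetical toy module, not from the diff):

    import torch
    import torch.nn as nn

    class Net(nn.Module):  # hypothetical module for illustration
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 2)

        def forward(self, x):
            return self.fc(x)

    model = torch.compile(Net())        # requires PyTorch >= 2.0; returns an OptimizedModule wrapper
    assert hasattr(model, "_orig_mod")  # the uncompiled module is preserved here
    plain = model._orig_mod             # unwrap, e.g. to validate without compiled graphs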
@@ -158,24 +157,21 @@ class BaseValidator:
                 LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
             callbacks.add_integration_callbacks(self)
             model = AutoBackend(
-                weights=model or self.args.model,
-                device=select_device(self.args.device, self.args.batch),
+                model=model or self.args.model,
+                device=select_device(self.args.device) if RANK == -1 else torch.device("cuda", RANK),
                 dnn=self.args.dnn,
                 data=self.args.data,
                 fp16=self.args.half,
             )
-            # self.model = model
             self.device = model.device  # update device
             self.args.half = model.fp16  # update half
-            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+            stride, pt, jit = model.stride, model.pt, model.jit
             imgsz = check_imgsz(self.args.imgsz, stride=stride)
-            if engine:
-                self.args.batch = model.batch_size
-            elif not (pt or jit or getattr(model, "dynamic", False)):
+            if not (pt or jit or getattr(model, "dynamic", False)):
                 self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
                 LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")
 
-        if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
+        if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
            self.data = check_det_dataset(self.args.data)
         elif self.args.task == "classify":
             self.data = check_cls_dataset(self.args.data, split=self.args.split)
@@ -184,12 +180,14 @@ class BaseValidator:
 
         if self.device.type in {"cpu", "mps"}:
             self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
-        if not (pt or getattr(model, "dynamic", False)):
+        if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
            self.args.rect = False
         self.stride = model.stride  # used in get_dataloader() for padding
         self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
 
         model.eval()
+        if self.args.compile:
+            model = attempt_compile(model, device=self.device)
         model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz))  # warmup
 
         self.run_callbacks("on_val_start")
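The `attempt_compile` helper used here is new in this release (it lives in `ultralytics/utils/torch_utils.py` per the import hunk above), but the diff does not show its body. A guarded wrapper in this spirit might look like the following sketch; the fallback behavior is an assumption, not the actual implementation:

    import torch

    def attempt_compile(model, device=None):
        """Try to torch.compile a model, falling back to the eager model on failure (sketch)."""
        try:
            return torch.compile(model)  # requires PyTorch >= 2.0 and a supported backend
        except Exception as e:
            print(f"torch.compile failed ({e}), using eager model")
            return model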
@@ -200,7 +198,7 @@ class BaseValidator:
             Profile(device=self.device),
         )
         bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
-        self.init_metrics(de_parallel(model))
+        self.init_metrics(unwrap_model(model))
         self.jdict = []  # empty before each val
         for batch_i, batch in enumerate(bar):
             self.run_callbacks("on_val_batch_start")
@@ -223,22 +221,34 @@ class BaseValidator:
             preds = self.postprocess(preds)
 
             self.update_metrics(preds, batch)
-            if self.args.plots and batch_i < 3:
+            if self.args.plots and batch_i < 3 and RANK in {-1, 0}:
                 self.plot_val_samples(batch, batch_i)
                 self.plot_predictions(batch, preds, batch_i)
 
             self.run_callbacks("on_val_batch_end")
-        stats = self.get_stats()
-        self.check_stats(stats)
-        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
-        self.finalize_metrics()
-        self.print_results()
-        self.run_callbacks("on_val_end")
+
+        stats = {}
+        self.gather_stats()
+        if RANK in {-1, 0}:
+            stats = self.get_stats()
+            self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
+            self.finalize_metrics()
+            self.print_results()
+            self.run_callbacks("on_val_end")
+
         if self.training:
             model.float()
-            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
+            # Reduce loss across all GPUs
+            loss = self.loss.clone().detach()
+            if trainer.world_size > 1:
+                dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)
+            if RANK > 0:
+                return
+            results = {**stats, **trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val")}
             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
+            if RANK > 0:
+                return stats
             LOGGER.info(
                 "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
                     *tuple(self.speed.values())
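The DDP change in this hunk averages the per-rank validation loss onto rank 0 with `dist.reduce` before rank 0 assembles results, and all other ranks return early. A standalone sketch of the same pattern, assuming a process group has already been initialized (e.g. via torchrun):

    import torch
    import torch.distributed as dist

    def reduce_val_loss(loss: torch.Tensor) -> torch.Tensor:
        """Average a per-rank loss tensor onto rank 0 (sketch; ReduceOp.AVG needs the NCCL backend)."""
        loss = loss.clone().detach()
        if dist.is_initialized() and dist.get_world_size() > 1:
            dist.reduce(loss, dst=0, op=dist.ReduceOp.AVG)  # only rank 0 holds the averaged value
        return loss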
@@ -256,14 +266,13 @@ class BaseValidator:
     def match_predictions(
         self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
     ) -> torch.Tensor:
-        """
-        Match predictions to ground truth objects using IoU.
+        """Match predictions to ground truth objects using IoU.
 
         Args:
             pred_classes (torch.Tensor): Predicted class indices of shape (N,).
             true_classes (torch.Tensor): Target class indices of shape (M,).
             iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
-            use_scipy (bool): Whether to use scipy for matching (more precise).
+            use_scipy (bool, optional): Whether to use scipy for matching (more precise).
 
         Returns:
             (torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
@@ -292,7 +301,6 @@ class BaseValidator:
                 if matches.shape[0] > 1:
                     matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                     matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
-                    # matches = matches[matches[:, 2].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                 correct[matches[:, 1].astype(int), i] = True
         return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
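The two `np.unique(..., return_index=True)` calls in this hunk are what enforce one-to-one matching: with candidate pairs already sorted by IoU descending, `np.unique` returns the index of the first (i.e. highest-IoU) occurrence of each value, so each prediction and each ground truth keeps at most one match. A small self-contained illustration with toy data:

    import numpy as np

    # Candidate (gt_idx, pred_idx) pairs, already sorted by IoU descending (toy data)
    matches = np.array([[1, 0], [1, 2], [3, 0], [5, 4]])

    # np.unique keeps the index of the FIRST occurrence, i.e. the highest-IoU pair
    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]  # one match per prediction
    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]  # one match per ground truth
    print(matches)  # -> [[1 0] [5 4]]: each gt and each pred appears at most once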
@@ -330,7 +338,7 @@ class BaseValidator:
         """Update metrics based on predictions and batch."""
         pass
 
-    def finalize_metrics(self, *args, **kwargs):
+    def finalize_metrics(self):
         """Finalize and return all metrics."""
         pass
 
@@ -338,8 +346,8 @@ class BaseValidator:
         """Return statistics about the model's performance."""
         return {}
 
-    def check_stats(self, stats):
-        """Check statistics."""
+    def gather_stats(self):
+        """Gather statistics from all the GPUs during DDP training to GPU 0."""
         pass
 
     def print_results(self):
@@ -356,10 +364,9 @@ class BaseValidator:
         return []
 
     def on_plot(self, name, data=None):
-        """Register plots (e.g. to be consumed in callbacks)."""
+        """Register plots for visualization."""
         self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
 
-    # TODO: may need to put these following functions into callback
     def plot_val_samples(self, batch, ni):
         """Plot validation samples during training."""
         pass
ultralytics/hub/__init__.py CHANGED
@@ -1,31 +1,29 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-import requests
+from __future__ import annotations
 
 from ultralytics.data.utils import HUBDatasetStats
 from ultralytics.hub.auth import Auth
 from ultralytics.hub.session import HUBTrainingSession
-from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events
+from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
 from ultralytics.utils import LOGGER, SETTINGS, checks
 
 __all__ = (
-    "PREFIX",
     "HUB_WEB_ROOT",
+    "PREFIX",
     "HUBTrainingSession",
-    "login",
-    "logout",
-    "reset_model",
+    "check_dataset",
     "export_fmts_hub",
     "export_model",
     "get_export",
-    "check_dataset",
-    "events",
+    "login",
+    "logout",
+    "reset_model",
 )
 
 
-def login(api_key: str = None, save: bool = True) -> bool:
-    """
-    Log in to the Ultralytics HUB API using the provided API key.
+def login(api_key: str | None = None, save: bool = True) -> bool:
+    """Log in to the Ultralytics HUB API using the provided API key.
 
     The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
     environment variable if successfully authenticated.
@@ -68,19 +66,15 @@ def login(api_key: str = None, save: bool = True) -> bool:
 
 
 def logout():
-    """
-    Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo login'.
-
-    Examples:
-        >>> from ultralytics import hub
-        >>> hub.logout()
-    """
+    """Log out of Ultralytics HUB by removing the API key from the settings file."""
     SETTINGS["api_key"] = ""
     LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo login'.")
 
 
 def reset_model(model_id: str = ""):
     """Reset a trained model to an untrained state."""
+    import requests  # scoped as slow import
+
     r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
     if r.status_code == 200:
         LOGGER.info(f"{PREFIX}Model reset successfully")
@@ -89,15 +83,14 @@ def reset_model(model_id: str = ""):
 
 
 def export_fmts_hub():
-    """Returns a list of HUB-supported export formats."""
+    """Return a list of HUB-supported export formats."""
     from ultralytics.engine.exporter import export_formats
 
-    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
+    return [*list(export_formats()["Argument"][1:]), "ultralytics_tflite", "ultralytics_coreml"]
 
 
 def export_model(model_id: str = "", format: str = "torchscript"):
-    """
-    Export a model to a specified format for deployment via the Ultralytics HUB API.
+    """Export a model to a specified format for deployment via the Ultralytics HUB API.
 
     Args:
         model_id (str): The ID of the model to export. An empty string will use the default model.
@@ -111,6 +104,8 @@ def export_model(model_id: str = "", format: str = "torchscript"):
         >>> from ultralytics import hub
         >>> hub.export_model(model_id="your_model_id", format="torchscript")
     """
+    import requests  # scoped as slow import
+
     assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
     r = requests.post(
         f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
@@ -120,20 +115,24 @@ def export_model(model_id: str = "", format: str = "torchscript"):
 
 
 def get_export(model_id: str = "", format: str = "torchscript"):
-    """
-    Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
+    """Retrieve an exported model in the specified format from Ultralytics HUB using the model ID.
 
     Args:
         model_id (str): The ID of the model to retrieve from Ultralytics HUB.
         format (str): The export format to retrieve. Must be one of the supported formats returned by export_fmts_hub().
 
+    Returns:
+        (dict): JSON response containing the exported model information.
+
     Raises:
         AssertionError: If the specified format is not supported or if the API request fails.
 
     Examples:
         >>> from ultralytics import hub
-        >>> hub.get_export(model_id="your_model_id", format="torchscript")
+        >>> result = hub.get_export(model_id="your_model_id", format="torchscript")
     """
+    import requests  # scoped as slow import
+
     assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
     r = requests.post(
         f"{HUB_API_ROOT}/get-export",
@@ -145,8 +144,7 @@ def get_export(model_id: str = "", format: str = "torchscript"):
 
 
 def check_dataset(path: str, task: str) -> None:
-    """
-    Check HUB dataset Zip file for errors before upload.
+    """Check HUB dataset Zip file for errors before upload.
 
     Args:
         path (str): Path to data.zip (with data.yaml inside data.zip).
@@ -160,7 +158,7 @@ def check_dataset(path: str, task: str) -> None:
         >>> check_dataset("path/to/dota8.zip", task="obb")  # OBB dataset
         >>> check_dataset("path/to/imagenet10.zip", task="classify")  # classification dataset
 
-    Note:
+    Notes:
         Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets
         i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.
     """
ultralytics/hub/auth.py CHANGED
@@ -1,7 +1,5 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-import requests
-
 from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
 from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis
 
@@ -9,8 +7,7 @@ API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"
 
 
 class Auth:
-    """
-    Manages authentication processes including API key handling, cookie-based authentication, and header generation.
+    """Manages authentication processes including API key handling, cookie-based authentication, and header generation.
 
     The class supports different methods of authentication:
     1. Directly using an API key.
@@ -21,13 +18,25 @@ class Auth:
         id_token (str | bool): Token used for identity verification, initialized as False.
         api_key (str | bool): API key for authentication, initialized as False.
         model_key (bool): Placeholder for model key, initialized as False.
+
+    Methods:
+        authenticate: Attempt to authenticate with the server using either id_token or API key.
+        auth_with_cookies: Attempt to fetch authentication via cookies and set id_token.
+        get_auth_header: Get the authentication header for making API requests.
+        request_api_key: Prompt the user to input their API key.
+
+    Examples:
+        Initialize Auth with an API key
+        >>> auth = Auth(api_key="your_api_key_here")
+
+        Initialize Auth without API key (will prompt for input)
+        >>> auth = Auth()
     """
 
     id_token = api_key = model_key = False
 
     def __init__(self, api_key: str = "", verbose: bool = False):
-        """
-        Initialize Auth class and authenticate user.
+        """Initialize Auth class and authenticate user.
 
         Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful
         authentication.
@@ -37,7 +46,7 @@ class Auth:
             verbose (bool): Enable verbose logging.
         """
         # Split the input API key in case it contains a combined key_model and keep only the API key part
-        api_key = api_key.split("_")[0]
+        api_key = api_key.split("_", 1)[0]
 
         # Set API key attribute as value passed or SETTINGS API key if none passed
         self.api_key = api_key or SETTINGS.get("api_key", "")
@@ -71,24 +80,32 @@ class Auth:
             LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'")
 
     def request_api_key(self, max_attempts: int = 3) -> bool:
-        """Prompt the user to input their API key."""
+        """Prompt the user to input their API key.
+
+        Args:
+            max_attempts (int): Maximum number of authentication attempts.
+
+        Returns:
+            (bool): True if authentication is successful, False otherwise.
+        """
         import getpass
 
         for attempts in range(max_attempts):
             LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
             input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
-            self.api_key = input_key.split("_")[0]  # remove model id if present
+            self.api_key = input_key.split("_", 1)[0]  # remove model id if present
             if self.authenticate():
                 return True
         raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))
 
     def authenticate(self) -> bool:
-        """
-        Attempt to authenticate with the server using either id_token or API key.
+        """Attempt to authenticate with the server using either id_token or API key.
 
         Returns:
             (bool): True if authentication is successful, False otherwise.
         """
+        import requests  # scoped as slow import
+
         try:
             if header := self.get_auth_header():
                 r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
@@ -102,8 +119,7 @@ class Auth:
             return False
 
     def auth_with_cookies(self) -> bool:
-        """
-        Attempt to fetch authentication via cookies and set id_token.
+        """Attempt to fetch authentication via cookies and set id_token.
 
         User must be logged in to HUB and running in a supported browser.
 
@@ -124,8 +140,7 @@ class Auth:
             return False
 
     def get_auth_header(self):
-        """
-        Get the authentication header for making API requests.
+        """Get the authentication header for making API requests.
 
         Returns:
             (dict | None): The authentication header if id_token or API key is set, None otherwise.
@@ -134,4 +149,3 @@ class Auth:
             return {"authorization": f"Bearer {self.id_token}"}
         elif self.api_key:
             return {"x-api-key": self.api_key}
-        # else returns None
ultralytics/hub/google/__init__.py CHANGED
@@ -1,22 +1,20 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 import concurrent.futures
 import statistics
 import time
-from typing import List, Optional, Tuple
-
-import requests
 
 
 class GCPRegions:
-    """
-    A class for managing and analyzing Google Cloud Platform (GCP) regions.
+    """A class for managing and analyzing Google Cloud Platform (GCP) regions.
 
-    This class provides functionality to initialize, categorize, and analyze GCP regions based on their
-    geographical location, tier classification, and network latency.
+    This class provides functionality to initialize, categorize, and analyze GCP regions based on their geographical
+    location, tier classification, and network latency.
 
     Attributes:
-        regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.
+        regions (dict[str, tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country.
 
     Methods:
         tier1: Returns a list of tier 1 GCP regions.
@@ -31,7 +29,7 @@ class GCPRegions:
     """
 
     def __init__(self):
-        """Initializes the GCPRegions class with predefined Google Cloud Platform regions and their details."""
+        """Initialize the GCPRegions class with predefined Google Cloud Platform regions and their details."""
         self.regions = {
             "asia-east1": (1, "Taiwan", "China"),
             "asia-east2": (2, "Hong Kong", "China"),
@@ -73,41 +71,42 @@ class GCPRegions:
             "us-west4": (2, "Las Vegas", "United States"),
         }
 
-    def tier1(self) -> List[str]:
-        """Returns a list of GCP regions classified as tier 1 based on predefined criteria."""
+    def tier1(self) -> list[str]:
+        """Return a list of GCP regions classified as tier 1 based on predefined criteria."""
         return [region for region, info in self.regions.items() if info[0] == 1]
 
-    def tier2(self) -> List[str]:
-        """Returns a list of GCP regions classified as tier 2 based on predefined criteria."""
+    def tier2(self) -> list[str]:
+        """Return a list of GCP regions classified as tier 2 based on predefined criteria."""
         return [region for region, info in self.regions.items() if info[0] == 2]
 
     @staticmethod
-    def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, float, float, float]:
-        """
-        Ping a specified GCP region and measure network latency statistics.
+    def _ping_region(region: str, attempts: int = 1) -> tuple[str, float, float, float, float]:
+        """Ping a specified GCP region and measure network latency statistics.
 
         Args:
-            region (str): The GCP region identifier to ping (e.g., 'us-central1').
-            attempts (int): Number of ping attempts to make for calculating statistics.
+            region (str): The GCP region identifier to ping (e.g., 'us-central1').
+            attempts (int, optional): Number of ping attempts to make for calculating statistics.
 
         Returns:
-            region (str): The GCP region identifier that was pinged.
-            mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.
-            min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.
-            max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.
-            std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.
+            region (str): The GCP region identifier that was pinged.
+            mean_latency (float): Mean latency in milliseconds, or infinity if all pings failed.
+            std_dev (float): Standard deviation of latencies in milliseconds, or infinity if all pings failed.
+            min_latency (float): Minimum latency in milliseconds, or infinity if all pings failed.
+            max_latency (float): Maximum latency in milliseconds, or infinity if all pings failed.
 
         Examples:
-            >>> region, mean, min_lat, max_lat, std = GCPRegions._ping_region("us-central1", attempts=3)
-            >>> print(f"Region {region} has mean latency: {mean:.2f}ms")
+            >>> region, mean, std, min_lat, max_lat = GCPRegions._ping_region("us-central1", attempts=3)
+            >>> print(f"Region {region} has mean latency: {mean:.2f}ms")
         """
+        import requests  # scoped as slow import
+
         url = f"https://{region}-docker.pkg.dev"
         latencies = []
         for _ in range(attempts):
            try:
                 start_time = time.time()
                 _ = requests.head(url, timeout=5)
-                latency = (time.time() - start_time) * 1000  # convert latency to milliseconds
+                latency = (time.time() - start_time) * 1000  # Convert latency to milliseconds
                 if latency != float("inf"):
                     latencies.append(latency)
             except requests.RequestException:
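As the hunk shows, `_ping_region` times an HTTPS `requests.head` call against the region's Artifact Registry endpoint rather than sending ICMP pings. A condensed sketch of that measurement loop (the endpoint pattern is taken from the diff; the function name is hypothetical):

    import statistics
    import time

    import requests

    def measure_latency(region: str, attempts: int = 3) -> float:
        """Return mean HTTPS round-trip time in ms to a GCP region, or inf if unreachable."""
        latencies = []
        for _ in range(attempts):
            try:
                start = time.time()
                requests.head(f"https://{region}-docker.pkg.dev", timeout=5)
                latencies.append((time.time() - start) * 1000)  # seconds -> milliseconds
            except requests.RequestException:
                pass
        return statistics.mean(latencies) if latencies else float("inf")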
@@ -122,21 +121,20 @@ class GCPRegions:
         self,
         top: int = 1,
         verbose: bool = False,
-        tier: Optional[int] = None,
+        tier: int | None = None,
         attempts: int = 1,
-    ) -> List[Tuple[str, float, float, float, float]]:
-        """
-        Determines the GCP regions with the lowest latency based on ping tests.
+    ) -> list[tuple[str, float, float, float, float]]:
+        """Determine the GCP regions with the lowest latency based on ping tests.
 
         Args:
-            top (int): Number of top regions to return.
-            verbose (bool): If True, prints detailed latency information for all tested regions.
-            tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested.
-            attempts (int): Number of ping attempts per region.
+            top (int, optional): Number of top regions to return.
+            verbose (bool, optional): If True, prints detailed latency information for all tested regions.
+            tier (int | None, optional): Filter regions by tier (1 or 2). If None, all regions are tested.
+            attempts (int, optional): Number of ping attempts per region.
 
         Returns:
-            (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and
-                latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
+            (list[tuple[str, float, float, float, float]]): List of tuples containing region information and latency
+                statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency).
 
         Examples:
             >>> regions = GCPRegions()