dgenerate-ultralytics-headless 8.3.187__py3-none-any.whl → 8.3.190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {dgenerate_ultralytics_headless-8.3.187.dist-info → dgenerate_ultralytics_headless-8.3.190.dist-info}/METADATA +3 -2
  2. {dgenerate_ultralytics_headless-8.3.187.dist-info → dgenerate_ultralytics_headless-8.3.190.dist-info}/RECORD +38 -37
  3. ultralytics/__init__.py +1 -1
  4. ultralytics/data/utils.py +2 -2
  5. ultralytics/engine/exporter.py +9 -6
  6. ultralytics/engine/predictor.py +1 -1
  7. ultralytics/engine/results.py +5 -5
  8. ultralytics/engine/trainer.py +2 -0
  9. ultralytics/engine/validator.py +3 -1
  10. ultralytics/hub/__init__.py +6 -2
  11. ultralytics/hub/auth.py +2 -2
  12. ultralytics/hub/google/__init__.py +2 -2
  13. ultralytics/hub/session.py +3 -5
  14. ultralytics/hub/utils.py +5 -5
  15. ultralytics/models/rtdetr/val.py +3 -1
  16. ultralytics/models/yolo/detect/predict.py +2 -2
  17. ultralytics/models/yolo/detect/val.py +15 -4
  18. ultralytics/models/yolo/obb/val.py +5 -2
  19. ultralytics/models/yolo/segment/val.py +0 -3
  20. ultralytics/nn/autobackend.py +29 -36
  21. ultralytics/nn/modules/__init__.py +3 -3
  22. ultralytics/nn/modules/head.py +5 -1
  23. ultralytics/nn/tasks.py +2 -2
  24. ultralytics/utils/__init__.py +49 -14
  25. ultralytics/utils/benchmarks.py +12 -6
  26. ultralytics/utils/callbacks/platform.py +2 -1
  27. ultralytics/utils/checks.py +3 -3
  28. ultralytics/utils/downloads.py +46 -40
  29. ultralytics/utils/logger.py +7 -6
  30. ultralytics/utils/nms.py +346 -0
  31. ultralytics/utils/ops.py +80 -249
  32. ultralytics/utils/tal.py +1 -1
  33. ultralytics/utils/torch_utils.py +50 -47
  34. ultralytics/utils/tqdm.py +58 -59
  35. {dgenerate_ultralytics_headless-8.3.187.dist-info → dgenerate_ultralytics_headless-8.3.190.dist-info}/WHEEL +0 -0
  36. {dgenerate_ultralytics_headless-8.3.187.dist-info → dgenerate_ultralytics_headless-8.3.190.dist-info}/entry_points.txt +0 -0
  37. {dgenerate_ultralytics_headless-8.3.187.dist-info → dgenerate_ultralytics_headless-8.3.190.dist-info}/licenses/LICENSE +0 -0
  38. {dgenerate_ultralytics_headless-8.3.187.dist-info → dgenerate_ultralytics_headless-8.3.190.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,7 @@ import torch
14
14
  import torch.nn as nn
15
15
  from PIL import Image
16
16
 
17
- from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML
17
+ from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML, is_jetson
18
18
  from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml, is_rockchip
19
19
  from ultralytics.utils.downloads import attempt_download_asset, is_url
20
20
 
@@ -127,14 +127,14 @@ class AutoBackend(nn.Module):
127
127
  _model_type: Determine the model type from file path.
128
128
 
129
129
  Examples:
130
- >>> model = AutoBackend(weights="yolo11n.pt", device="cuda")
130
+ >>> model = AutoBackend(model="yolo11n.pt", device="cuda")
131
131
  >>> results = model(img)
132
132
  """
133
133
 
134
134
  @torch.no_grad()
135
135
  def __init__(
136
136
  self,
137
- weights: Union[str, List[str], torch.nn.Module] = "yolo11n.pt",
137
+ model: Union[str, List[str], torch.nn.Module] = "yolo11n.pt",
138
138
  device: torch.device = torch.device("cpu"),
139
139
  dnn: bool = False,
140
140
  data: Optional[Union[str, Path]] = None,
@@ -146,7 +146,7 @@ class AutoBackend(nn.Module):
146
146
  Initialize the AutoBackend for inference.
147
147
 
148
148
  Args:
149
- weights (str | List[str] | torch.nn.Module): Path to the model weights file or a module instance.
149
+ model (str | List[str] | torch.nn.Module): Path to the model weights file or a module instance.
150
150
  device (torch.device): Device to run the model on.
151
151
  dnn (bool): Use OpenCV DNN module for ONNX inference.
152
152
  data (str | Path, optional): Path to the additional data.yaml file containing class names.
@@ -155,8 +155,8 @@ class AutoBackend(nn.Module):
155
155
  verbose (bool): Enable verbose logging.
156
156
  """
157
157
  super().__init__()
158
- w = str(weights[0] if isinstance(weights, list) else weights)
159
- nn_module = isinstance(weights, torch.nn.Module)
158
+ w = str(model[0] if isinstance(model, list) else model)
159
+ nn_module = isinstance(model, torch.nn.Module)
160
160
  (
161
161
  pt,
162
162
  jit,
@@ -180,7 +180,7 @@ class AutoBackend(nn.Module):
180
180
  nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn # BHWC formats (vs torch BCWH)
181
181
  stride, ch = 32, 3 # default stride and channels
182
182
  end2end, dynamic = False, False
183
- model, metadata, task = None, None, None
183
+ metadata, task = None, None
184
184
 
185
185
  # Set device
186
186
  cuda = isinstance(device, torch.device) and torch.cuda.is_available() and device.type != "cpu" # use CUDA
@@ -192,33 +192,32 @@ class AutoBackend(nn.Module):
192
192
  if not (pt or triton or nn_module):
193
193
  w = attempt_download_asset(w)
194
194
 
195
- # In-memory PyTorch model
196
- if nn_module:
197
- model = weights.to(device)
198
- if fuse:
199
- model = model.fuse(verbose=verbose)
200
- if hasattr(model, "kpt_shape"):
201
- kpt_shape = model.kpt_shape # pose-only
202
- stride = max(int(model.stride.max()), 32) # model stride
203
- names = model.module.names if hasattr(model, "module") else model.names # get class names
204
- model.half() if fp16 else model.float()
205
- ch = model.yaml.get("channels", 3)
206
- self.model = model # explicitly assign for to(), cpu(), cuda(), half()
207
- pt = True
208
-
209
- # PyTorch
210
- elif pt:
211
- from ultralytics.nn.tasks import attempt_load_weights
195
+ # PyTorch (in-memory or file)
196
+ if nn_module or pt:
197
+ if nn_module:
198
+ pt = True
199
+ if fuse:
200
+ if IS_JETSON and is_jetson(jetpack=5):
201
+ # Jetson Jetpack5 requires device before fuse https://github.com/ultralytics/ultralytics/pull/21028
202
+ model = model.to(device)
203
+ model = model.fuse(verbose=verbose)
204
+ model = model.to(device)
205
+ else: # pt file
206
+ from ultralytics.nn.tasks import attempt_load_weights
207
+
208
+ model = attempt_load_weights(
209
+ model if isinstance(model, list) else w, device=device, inplace=True, fuse=fuse
210
+ )
212
211
 
213
- model = attempt_load_weights(
214
- weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
215
- )
212
+ # Common PyTorch model processing
216
213
  if hasattr(model, "kpt_shape"):
217
214
  kpt_shape = model.kpt_shape # pose-only
218
215
  stride = max(int(model.stride.max()), 32) # model stride
219
216
  names = model.module.names if hasattr(model, "module") else model.names # get class names
220
217
  model.half() if fp16 else model.float()
221
218
  ch = model.yaml.get("channels", 3)
219
+ for p in model.parameters():
220
+ p.requires_grad = False
222
221
  self.model = model # explicitly assign for to(), cpu(), cuda(), half()
223
222
 
224
223
  # TorchScript
@@ -404,6 +403,7 @@ class AutoBackend(nn.Module):
404
403
 
405
404
  # CoreML
406
405
  elif coreml:
406
+ check_requirements("coremltools>=8.0")
407
407
  LOGGER.info(f"Loading {w} for CoreML inference...")
408
408
  import coremltools as ct
409
409
 
@@ -598,18 +598,13 @@ class AutoBackend(nn.Module):
598
598
  dynamic = metadata.get("args", {}).get("dynamic", dynamic)
599
599
  ch = metadata.get("channels", 3)
600
600
  elif not (pt or triton or nn_module):
601
- LOGGER.warning(f"Metadata not found for 'model={weights}'")
601
+ LOGGER.warning(f"Metadata not found for 'model={w}'")
602
602
 
603
603
  # Check names
604
604
  if "names" not in locals(): # names missing
605
605
  names = default_class_names(data)
606
606
  names = check_class_names(names)
607
607
 
608
- # Disable gradients
609
- if pt:
610
- for p in model.parameters():
611
- p.requires_grad = False
612
-
613
608
  self.__dict__.update(locals()) # assign all variables to self
614
609
 
615
610
  def forward(
@@ -855,8 +850,6 @@ class AutoBackend(nn.Module):
855
850
  Args:
856
851
  imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
857
852
  """
858
- import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)
859
-
860
853
  warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
861
854
  if any(warmup_types) and (self.device.type != "cpu" or self.triton):
862
855
  im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
@@ -875,7 +868,7 @@ class AutoBackend(nn.Module):
875
868
  (List[bool]): List of booleans indicating the model type.
876
869
 
877
870
  Examples:
878
- >>> model = AutoBackend(weights="path/to/model.onnx")
871
+ >>> model = AutoBackend(model="path/to/model.onnx")
879
872
  >>> model_type = model._model_type() # returns "onnx"
880
873
  """
881
874
  from ultralytics.engine.exporter import export_formats
@@ -7,14 +7,14 @@ blocks, attention mechanisms, transformer components, and detection/segmentation
7
7
 
8
8
  Examples:
9
9
  Visualize a module with Netron
10
- >>> from ultralytics.nn.modules import *
10
+ >>> from ultralytics.nn.modules import Conv
11
11
  >>> import torch
12
- >>> import os
12
+ >>> import subprocess
13
13
  >>> x = torch.ones(1, 128, 40, 40)
14
14
  >>> m = Conv(128, 128)
15
15
  >>> f = f"{m._get_name()}.onnx"
16
16
  >>> torch.onnx.export(m, x, f)
17
- >>> os.system(f"onnxslim {f} {f} && open {f}") # pip install onnxslim
17
+ >>> subprocess.run(f"onnxslim {f} {f} && open {f}", shell=True, check=True) # pip install onnxslim
18
18
  """
19
19
 
20
20
  from .block import (
@@ -10,6 +10,7 @@ import torch.nn as nn
10
10
  import torch.nn.functional as F
11
11
  from torch.nn.init import constant_, xavier_uniform_
12
12
 
13
+ from ultralytics.utils import NOT_MACOS14
13
14
  from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
14
15
  from ultralytics.utils.torch_utils import fuse_conv_and_bn, smart_inference_mode
15
16
 
@@ -408,7 +409,10 @@ class Pose(Detect):
408
409
  else:
409
410
  y = kpts.clone()
410
411
  if ndim == 3:
411
- y[:, 2::ndim] = y[:, 2::ndim].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
412
+ if NOT_MACOS14:
413
+ y[:, 2::ndim].sigmoid_()
414
+ else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
415
+ y[:, 2::ndim] = y[:, 2::ndim].sigmoid()
412
416
  y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
413
417
  y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
414
418
  return y
ultralytics/nn/tasks.py CHANGED
@@ -1500,7 +1500,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
1500
1500
  for w in weights if isinstance(weights, list) else [weights]:
1501
1501
  ckpt, w = torch_safe_load(w) # load ckpt
1502
1502
  args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
1503
- model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
1503
+ model = (ckpt.get("ema") or ckpt["model"]).float() # FP32 model
1504
1504
 
1505
1505
  # Model compatibility updates
1506
1506
  model.args = args # attach args to model
@@ -1510,7 +1510,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
1510
1510
  model.stride = torch.tensor([32.0])
1511
1511
 
1512
1512
  # Append
1513
- ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
1513
+ ensemble.append((model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()).to(device))
1514
1514
 
1515
1515
  # Module updates
1516
1516
  for m in ensemble.modules():
@@ -8,10 +8,12 @@ import logging
8
8
  import os
9
9
  import platform
10
10
  import re
11
+ import socket
11
12
  import subprocess
12
13
  import sys
13
14
  import threading
14
15
  import time
16
+ from functools import lru_cache
15
17
  from pathlib import Path
16
18
  from threading import Lock
17
19
  from types import SimpleNamespace
@@ -43,6 +45,7 @@ VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbo
43
45
  LOGGING_NAME = "ultralytics"
44
46
  MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
45
47
  MACOS_VERSION = platform.mac_ver()[0] if MACOS else None
48
+ NOT_MACOS14 = not (MACOS and MACOS_VERSION.startswith("14."))
46
49
  ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
47
50
  PYTHON_VERSION = platform.python_version()
48
51
  TORCH_VERSION = torch.__version__
@@ -727,32 +730,45 @@ def is_raspberrypi() -> bool:
727
730
  return "rpi" in DEVICE_MODEL
728
731
 
729
732
 
730
- def is_jetson() -> bool:
733
+ @lru_cache(maxsize=3)
734
+ def is_jetson(jetpack=None) -> bool:
731
735
  """
732
736
  Determine if the Python environment is running on an NVIDIA Jetson device.
733
737
 
738
+ Args:
739
+ jetpack (int | None): If specified, check for specific JetPack version (4, 5, 6).
740
+
734
741
  Returns:
735
742
  (bool): True if running on an NVIDIA Jetson device, False otherwise.
736
743
  """
737
- return "tegra" in DEVICE_MODEL
744
+ if jetson := ("tegra" in DEVICE_MODEL):
745
+ if jetpack:
746
+ try:
747
+ content = open("/etc/nv_tegra_release").read()
748
+ version_map = {4: "R32", 5: "R35", 6: "R36"} # JetPack to L4T major version mapping
749
+ return jetpack in version_map and version_map[jetpack] in content
750
+ except Exception:
751
+ return False
752
+ return jetson
738
753
 
739
754
 
740
755
  def is_online() -> bool:
741
756
  """
742
- Check internet connectivity by attempting to connect to a known online host.
757
+ Fast online check using DNS (v4/v6) resolution (Cloudflare + Google).
743
758
 
744
759
  Returns:
745
760
  (bool): True if connection is successful, False otherwise.
746
761
  """
747
- try:
748
- assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True"
749
- import socket
762
+ if str(os.getenv("YOLO_OFFLINE", "")).lower() == "true":
763
+ return False
750
764
 
751
- for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS
752
- socket.create_connection(address=(dns, 80), timeout=2.0).close()
765
+ for host in ("one.one.one.one", "dns.google"):
766
+ try:
767
+ socket.getaddrinfo(host, 0, socket.AF_UNSPEC, 0, 0, socket.AI_ADDRCONFIG)
753
768
  return True
754
- except Exception:
755
- return False
769
+ except OSError:
770
+ continue
771
+ return False
756
772
 
757
773
 
758
774
  def is_pip_package(filepath: str = __name__) -> bool:
@@ -829,6 +845,7 @@ def is_git_dir():
829
845
  return GIT_DIR is not None
830
846
 
831
847
 
848
+ @lru_cache(maxsize=1)
832
849
  def get_git_origin_url():
833
850
  """
834
851
  Retrieve the origin URL of a git repository.
@@ -838,12 +855,14 @@ def get_git_origin_url():
838
855
  """
839
856
  if IS_GIT_DIR:
840
857
  try:
841
- origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
842
- return origin.decode().strip()
858
+ return subprocess.check_output(
859
+ ["git", "config", "--get", "remote.origin.url"], stderr=subprocess.DEVNULL, text=True
860
+ ).strip()
843
861
  except subprocess.CalledProcessError:
844
862
  return None
845
863
 
846
864
 
865
+ @lru_cache(maxsize=1)
847
866
  def get_git_branch():
848
867
  """
849
868
  Return the current git branch name. If not in a git repository, return None.
@@ -853,8 +872,24 @@ def get_git_branch():
853
872
  """
854
873
  if IS_GIT_DIR:
855
874
  try:
856
- origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
857
- return origin.decode().strip()
875
+ return subprocess.check_output(
876
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=subprocess.DEVNULL, text=True
877
+ ).strip()
878
+ except subprocess.CalledProcessError:
879
+ return None
880
+
881
+
882
+ @lru_cache(maxsize=1)
883
+ def get_git_commit():
884
+ """
885
+ Return the current git commit hash. If not in a git repository, return None.
886
+
887
+ Returns:
888
+ (str | None): The current git commit hash or None if not a git directory.
889
+ """
890
+ if IS_GIT_DIR:
891
+ try:
892
+ return subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL, text=True).strip()
858
893
  except subprocess.CalledProcessError:
859
894
  return None
860
895
 
@@ -90,9 +90,13 @@ def benchmark(
90
90
 
91
91
  import polars as pl # scope for faster 'import ultralytics'
92
92
 
93
- pl.Config.set_tbl_cols(10)
94
- pl.Config.set_tbl_width_chars(120)
95
- pl.Config.set_tbl_hide_dataframe_shape(True)
93
+ pl.Config.set_tbl_cols(-1) # Show all columns
94
+ pl.Config.set_tbl_rows(-1) # Show all rows
95
+ pl.Config.set_tbl_width_chars(-1) # No width limit
96
+ pl.Config.set_tbl_hide_column_data_types(True) # Hide data types
97
+ pl.Config.set_tbl_hide_dataframe_shape(True) # Hide shape info
98
+ pl.Config.set_tbl_formatting("ASCII_BORDERS_ONLY_CONDENSED")
99
+
96
100
  device = select_device(device, verbose=False)
97
101
  if isinstance(model, (str, Path)):
98
102
  model = YOLO(model)
@@ -194,12 +198,14 @@ def benchmark(
194
198
 
195
199
  # Print results
196
200
  check_yolo(device=device) # print system info
197
- df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"])
201
+ df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"], orient="row")
202
+ df = df.with_row_index(" ", offset=1) # add index info
203
+ df_display = df.with_columns(pl.all().cast(pl.String).fill_null("-"))
198
204
 
199
205
  name = model.model_name
200
206
  dt = time.time() - t0
201
207
  legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
202
- s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df.fill_null('-')}\n"
208
+ s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df_display}\n"
203
209
  LOGGER.info(s)
204
210
  with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
205
211
  f.write(s)
@@ -209,7 +215,7 @@ def benchmark(
209
215
  floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
210
216
  assert all(x > floor for x in metrics if not np.isnan(x)), f"Benchmark failure: metric(s) < floor {floor}"
211
217
 
212
- return df
218
+ return df_display
213
219
 
214
220
 
215
221
  class RF100Benchmark:
@@ -1,12 +1,13 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
3
  from ultralytics.utils import RANK, SETTINGS
4
- from ultralytics.utils.logger import DEFAULT_LOG_PATH, ConsoleLogger, SystemLogger
5
4
 
6
5
 
7
6
  def on_pretrain_routine_start(trainer):
8
7
  """Initialize and start console logging immediately at the very beginning."""
9
8
  if RANK in {-1, 0}:
9
+ from ultralytics.utils.logger import DEFAULT_LOG_PATH, ConsoleLogger, SystemLogger
10
+
10
11
  trainer.system_logger = SystemLogger()
11
12
  trainer.console_logger = ConsoleLogger(DEFAULT_LOG_PATH)
12
13
  trainer.console_logger.start_capture()
@@ -274,7 +274,7 @@ def check_latest_pypi_version(package_name="ultralytics"):
274
274
  Returns:
275
275
  (str): The latest version of the package.
276
276
  """
277
- import requests # slow import
277
+ import requests # scoped as slow import
278
278
 
279
279
  try:
280
280
  requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
@@ -637,7 +637,7 @@ def check_yolo(verbose=True, device=""):
637
637
  verbose (bool): Whether to print verbose information.
638
638
  device (str | torch.device): Device to use for YOLO.
639
639
  """
640
- import psutil
640
+ import psutil # scoped as slow import
641
641
 
642
642
  from ultralytics.utils.torch_utils import select_device
643
643
 
@@ -670,7 +670,7 @@ def collect_system_info():
670
670
  Returns:
671
671
  (dict): Dictionary containing system information.
672
672
  """
673
- import psutil
673
+ import psutil # scoped as slow import
674
674
 
675
675
  from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
676
676
  from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info
@@ -1,12 +1,13 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import re
4
6
  import shutil
5
7
  import subprocess
6
8
  from itertools import repeat
7
9
  from multiprocessing.pool import ThreadPool
8
10
  from pathlib import Path
9
- from typing import List, Tuple
10
11
  from urllib import parse, request
11
12
 
12
13
  from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
@@ -41,7 +42,7 @@ GITHUB_ASSETS_NAMES = frozenset(
41
42
  GITHUB_ASSETS_STEMS = frozenset(k.rpartition(".")[0] for k in GITHUB_ASSETS_NAMES)
42
43
 
43
44
 
44
- def is_url(url, check: bool = False) -> bool:
45
+ def is_url(url: str | Path, check: bool = False) -> bool:
45
46
  """
46
47
  Validate if the given string is a URL and optionally check if the URL exists online.
47
48
 
@@ -68,7 +69,7 @@ def is_url(url, check: bool = False) -> bool:
68
69
  return False
69
70
 
70
71
 
71
- def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
72
+ def delete_dsstore(path: str | Path, files_to_delete: tuple[str, ...] = (".DS_Store", "__MACOSX")) -> None:
72
73
  """
73
74
  Delete all specified system files in a directory.
74
75
 
@@ -91,7 +92,12 @@ def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")):
91
92
  f.unlink()
92
93
 
93
94
 
94
- def zip_directory(directory, compress: bool = True, exclude=(".DS_Store", "__MACOSX"), progress: bool = True) -> Path:
95
+ def zip_directory(
96
+ directory: str | Path,
97
+ compress: bool = True,
98
+ exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
99
+ progress: bool = True,
100
+ ) -> Path:
95
101
  """
96
102
  Zip the contents of a directory, excluding specified files.
97
103
 
@@ -129,9 +135,9 @@ def zip_directory(directory, compress: bool = True, exclude=(".DS_Store", "__MAC
129
135
 
130
136
 
131
137
  def unzip_file(
132
- file,
133
- path=None,
134
- exclude=(".DS_Store", "__MACOSX"),
138
+ file: str | Path,
139
+ path: str | Path | None = None,
140
+ exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
135
141
  exist_ok: bool = False,
136
142
  progress: bool = True,
137
143
  ) -> Path:
@@ -198,8 +204,8 @@ def unzip_file(
198
204
 
199
205
 
200
206
  def check_disk_space(
201
- url: str = "https://ultralytics.com/assets/coco8.zip",
202
- path=Path.cwd(),
207
+ file_bytes: int,
208
+ path: str | Path = Path.cwd(),
203
209
  sf: float = 1.5,
204
210
  hard: bool = True,
205
211
  ) -> bool:
@@ -207,7 +213,7 @@ def check_disk_space(
207
213
  Check if there is sufficient disk space to download and store a file.
208
214
 
209
215
  Args:
210
- url (str, optional): The URL to the file.
216
+ file_bytes (int): The file size in bytes.
211
217
  path (str | Path, optional): The path or drive to check the available free space on.
212
218
  sf (float, optional): Safety factor, the multiplier for the required free space.
213
219
  hard (bool, optional): Whether to throw an error or not on insufficient disk space.
@@ -215,26 +221,14 @@ def check_disk_space(
215
221
  Returns:
216
222
  (bool): True if there is sufficient disk space, False otherwise.
217
223
  """
218
- import requests # slow import
219
-
220
- try:
221
- r = requests.head(url) # response
222
- assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response
223
- except Exception:
224
- return True # requests issue, default to True
225
-
226
- # Check file size
227
- gib = 1 << 30 # bytes per GiB
228
- data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB)
229
- total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes
230
-
231
- if data * sf < free:
224
+ total, used, free = shutil.disk_usage(path) # bytes
225
+ if file_bytes * sf < free:
232
226
  return True # sufficient space
233
227
 
234
228
  # Insufficient space
235
229
  text = (
236
- f"Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, "
237
- f"Please free {data * sf - free:.1f} GB additional disk space and try again."
230
+ f"Insufficient free disk space {free >> 30:.3f} GB < {int(file_bytes * sf) >> 30:.3f} GB required, "
231
+ f"Please free {int(file_bytes * sf - free) >> 30:.3f} GB additional disk space and try again."
238
232
  )
239
233
  if hard:
240
234
  raise MemoryError(text)
@@ -242,7 +236,7 @@ def check_disk_space(
242
236
  return False
243
237
 
244
238
 
245
- def get_google_drive_file_info(link: str) -> Tuple[str, str]:
239
+ def get_google_drive_file_info(link: str) -> tuple[str, str | None]:
246
240
  """
247
241
  Retrieve the direct download link and filename for a shareable Google Drive file link.
248
242
 
@@ -258,7 +252,7 @@ def get_google_drive_file_info(link: str) -> Tuple[str, str]:
258
252
  >>> link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
259
253
  >>> url, filename = get_google_drive_file_info(link)
260
254
  """
261
- import requests # slow import
255
+ import requests # scoped as slow import
262
256
 
263
257
  file_id = link.split("/d/")[1].split("/view", 1)[0]
264
258
  drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
@@ -283,9 +277,9 @@ def get_google_drive_file_info(link: str) -> Tuple[str, str]:
283
277
 
284
278
 
285
279
  def safe_download(
286
- url,
287
- file=None,
288
- dir=None,
280
+ url: str | Path,
281
+ file: str | Path | None = None,
282
+ dir: str | Path | None = None,
289
283
  unzip: bool = True,
290
284
  delete: bool = False,
291
285
  curl: bool = False,
@@ -293,7 +287,7 @@ def safe_download(
293
287
  min_bytes: float = 1e0,
294
288
  exist_ok: bool = False,
295
289
  progress: bool = True,
296
- ):
290
+ ) -> Path | str:
297
291
  """
298
292
  Download files from a URL with options for retrying, unzipping, and deleting the downloaded file. Enhanced with
299
293
  robust partial download detection using Content-Length validation.
@@ -335,7 +329,6 @@ def safe_download(
335
329
  )
336
330
  desc = f"Downloading {uri} to '{f}'"
337
331
  f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
338
- check_disk_space(url, path=f.parent)
339
332
  curl_installed = shutil.which("curl")
340
333
  for i in range(retry + 1):
341
334
  try:
@@ -347,6 +340,9 @@ def safe_download(
347
340
  else: # urllib download
348
341
  with request.urlopen(url) as response:
349
342
  expected_size = int(response.getheader("Content-Length", 0))
343
+ if i == 0 and expected_size > 1048576:
344
+ check_disk_space(expected_size, path=f.parent)
345
+ buffer_size = max(8192, min(1048576, expected_size // 1000)) if expected_size else 8192
350
346
  with TQDM(
351
347
  total=expected_size,
352
348
  desc=desc,
@@ -356,7 +352,10 @@ def safe_download(
356
352
  unit_divisor=1024,
357
353
  ) as pbar:
358
354
  with open(f, "wb") as f_opened:
359
- for data in response:
355
+ while True:
356
+ data = response.read(buffer_size)
357
+ if not data:
358
+ break
360
359
  f_opened.write(data)
361
360
  pbar.update(len(data))
362
361
 
@@ -371,6 +370,8 @@ def safe_download(
371
370
  else:
372
371
  break # success
373
372
  f.unlink() # remove partial downloads
373
+ except MemoryError:
374
+ raise # Re-raise immediately - no point retrying if insufficient disk space
374
375
  except Exception as e:
375
376
  if i == 0 and not is_online():
376
377
  raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e
@@ -397,7 +398,7 @@ def get_github_assets(
397
398
  repo: str = "ultralytics/assets",
398
399
  version: str = "latest",
399
400
  retry: bool = False,
400
- ) -> Tuple[str, List[str]]:
401
+ ) -> tuple[str, list[str]]:
401
402
  """
402
403
  Retrieve the specified version's tag and assets from a GitHub repository.
403
404
 
@@ -415,7 +416,7 @@ def get_github_assets(
415
416
  Examples:
416
417
  >>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
417
418
  """
418
- import requests # slow import
419
+ import requests # scoped as slow import
419
420
 
420
421
  if version != "latest":
421
422
  version = f"tags/{version}" # i.e. tags/v6.2
@@ -430,7 +431,12 @@ def get_github_assets(
430
431
  return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...]
431
432
 
432
433
 
433
- def attempt_download_asset(file, repo: str = "ultralytics/assets", release: str = "v8.3.0", **kwargs) -> str:
434
+ def attempt_download_asset(
435
+ file: str | Path,
436
+ repo: str = "ultralytics/assets",
437
+ release: str = "v8.3.0",
438
+ **kwargs,
439
+ ) -> str:
434
440
  """
435
441
  Attempt to download a file from GitHub release assets if it is not found locally.
436
442
 
@@ -482,15 +488,15 @@ def attempt_download_asset(file, repo: str = "ultralytics/assets", release: str
482
488
 
483
489
 
484
490
  def download(
485
- url,
486
- dir=Path.cwd(),
491
+ url: str | list[str] | Path,
492
+ dir: Path = Path.cwd(),
487
493
  unzip: bool = True,
488
494
  delete: bool = False,
489
495
  curl: bool = False,
490
496
  threads: int = 1,
491
497
  retry: int = 3,
492
498
  exist_ok: bool = False,
493
- ):
499
+ ) -> None:
494
500
  """
495
501
  Download files from specified URLs to a given directory.
496
502