dgenerate-ultralytics-headless 8.3.192__py3-none-any.whl → 8.3.194__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {dgenerate_ultralytics_headless-8.3.192.dist-info → dgenerate_ultralytics_headless-8.3.194.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.192.dist-info → dgenerate_ultralytics_headless-8.3.194.dist-info}/RECORD +33 -32
  3. tests/test_exports.py +8 -5
  4. tests/test_python.py +1 -1
  5. ultralytics/__init__.py +1 -1
  6. ultralytics/cfg/__init__.py +7 -5
  7. ultralytics/cfg/datasets/xView.yaml +1 -1
  8. ultralytics/data/utils.py +1 -1
  9. ultralytics/engine/exporter.py +7 -6
  10. ultralytics/engine/model.py +4 -4
  11. ultralytics/engine/predictor.py +7 -3
  12. ultralytics/engine/trainer.py +4 -4
  13. ultralytics/hub/__init__.py +1 -2
  14. ultralytics/hub/utils.py +0 -101
  15. ultralytics/models/sam/predict.py +3 -3
  16. ultralytics/models/yolo/segment/val.py +13 -13
  17. ultralytics/models/yolo/yoloe/val.py +2 -2
  18. ultralytics/nn/__init__.py +2 -4
  19. ultralytics/nn/autobackend.py +10 -13
  20. ultralytics/nn/tasks.py +2 -51
  21. ultralytics/utils/__init__.py +6 -3
  22. ultralytics/utils/callbacks/hub.py +2 -1
  23. ultralytics/utils/checks.py +2 -1
  24. ultralytics/utils/events.py +115 -0
  25. ultralytics/utils/ops.py +3 -1
  26. ultralytics/utils/tal.py +2 -2
  27. ultralytics/utils/torch_utils.py +7 -6
  28. ultralytics/utils/tqdm.py +49 -74
  29. ultralytics/utils/tuner.py +1 -1
  30. {dgenerate_ultralytics_headless-8.3.192.dist-info → dgenerate_ultralytics_headless-8.3.194.dist-info}/WHEEL +0 -0
  31. {dgenerate_ultralytics_headless-8.3.192.dist-info → dgenerate_ultralytics_headless-8.3.194.dist-info}/entry_points.txt +0 -0
  32. {dgenerate_ultralytics_headless-8.3.192.dist-info → dgenerate_ultralytics_headless-8.3.194.dist-info}/licenses/LICENSE +0 -0
  33. {dgenerate_ultralytics_headless-8.3.192.dist-info → dgenerate_ultralytics_headless-8.3.194.dist-info}/top_level.txt +0 -0
@@ -133,8 +133,17 @@ class SegmentationValidator(DetectionValidator):
133
133
  (Dict[str, Any]): Prepared batch with processed annotations.
134
134
  """
135
135
  prepared_batch = super()._prepare_batch(si, batch)
136
- midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si
137
- prepared_batch["masks"] = batch["masks"][midx]
136
+ nl = len(prepared_batch["cls"])
137
+ if self.args.overlap_mask:
138
+ masks = batch["masks"][si]
139
+ index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
140
+ masks = (masks == index).float()
141
+ else:
142
+ masks = batch["masks"][batch["batch_idx"] == si]
143
+ if nl and self.process is ops.process_mask_native:
144
+ masks = F.interpolate(masks[None], prepared_batch["imgsz"], mode="bilinear", align_corners=False)[0]
145
+ masks = masks.gt_(0.5)
146
+ prepared_batch["masks"] = masks
138
147
  return prepared_batch
139
148
 
140
149
  def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
@@ -158,20 +167,11 @@ class SegmentationValidator(DetectionValidator):
158
167
  >>> correct_preds = validator._process_batch(preds, batch)
159
168
  """
160
169
  tp = super()._process_batch(preds, batch)
161
- gt_cls, gt_masks = batch["cls"], batch["masks"]
170
+ gt_cls = batch["cls"]
162
171
  if len(gt_cls) == 0 or len(preds["cls"]) == 0:
163
172
  tp_m = np.zeros((len(preds["cls"]), self.niou), dtype=bool)
164
173
  else:
165
- pred_masks = preds["masks"]
166
- if self.args.overlap_mask:
167
- nl = len(gt_cls)
168
- index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
169
- gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
170
- gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
171
- if gt_masks.shape[1:] != pred_masks.shape[1:]:
172
- gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
173
- gt_masks = gt_masks.gt_(0.5)
174
- iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
174
+ iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
175
175
  tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
176
176
  tp.update({"tp_m": tp_m}) # update tp with mask IoU
177
177
  return tp
@@ -186,9 +186,9 @@ class YOLOEDetectValidator(DetectionValidator):
186
186
  self.device = select_device(self.args.device, verbose=False)
187
187
 
188
188
  if isinstance(model, (str, Path)):
189
- from ultralytics.nn.tasks import attempt_load_weights
189
+ from ultralytics.nn.tasks import load_checkpoint
190
190
 
191
- model = attempt_load_weights(model, device=self.device)
191
+ model, _ = load_checkpoint(model, device=self.device) # model, ckpt
192
192
  model.eval().to(self.device)
193
193
  data = check_det_dataset(refer_data or self.args.data)
194
194
  names = [name.split("/", 1)[0] for name in list(data["names"].values())]
@@ -5,18 +5,16 @@ from .tasks import (
5
5
  ClassificationModel,
6
6
  DetectionModel,
7
7
  SegmentationModel,
8
- attempt_load_one_weight,
9
- attempt_load_weights,
10
8
  guess_model_scale,
11
9
  guess_model_task,
10
+ load_checkpoint,
12
11
  parse_model,
13
12
  torch_safe_load,
14
13
  yaml_model_load,
15
14
  )
16
15
 
17
16
  __all__ = (
18
- "attempt_load_one_weight",
19
- "attempt_load_weights",
17
+ "load_checkpoint",
20
18
  "parse_model",
21
19
  "yaml_model_load",
22
20
  "guess_model_task",
@@ -203,9 +203,9 @@ class AutoBackend(nn.Module):
203
203
  model = model.fuse(verbose=verbose)
204
204
  model = model.to(device)
205
205
  else: # pt file
206
- from ultralytics.nn.tasks import attempt_load_one_weight
206
+ from ultralytics.nn.tasks import load_checkpoint
207
207
 
208
- model, _ = attempt_load_one_weight(model, device=device, fuse=fuse) # load model, ckpt
208
+ model, _ = load_checkpoint(model, device=device, fuse=fuse) # load model, ckpt
209
209
 
210
210
  # Common PyTorch model processing
211
211
  if hasattr(model, "kpt_shape"):
@@ -724,17 +724,14 @@ class AutoBackend(nn.Module):
724
724
  im_pil = Image.fromarray((im * 255).astype("uint8"))
725
725
  # im = im.resize((192, 320), Image.BILINEAR)
726
726
  y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized
727
- if "confidence" in y:
728
- raise TypeError(
729
- "Ultralytics only supports inference of non-pipelined CoreML models exported with "
730
- f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
731
- )
732
- # TODO: CoreML NMS inference handling
733
- # from ultralytics.utils.ops import xywh2xyxy
734
- # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
735
- # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
736
- # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
737
- y = list(y.values())
727
+ if "confidence" in y: # NMS included
728
+ from ultralytics.utils.ops import xywh2xyxy
729
+
730
+ box = xywh2xyxy(y["coordinates"] * [[w, h, w, h]]) # xyxy pixels
731
+ cls = y["confidence"].argmax(1, keepdims=True)
732
+ y = np.concatenate((box, np.take_along_axis(y["confidence"], cls, axis=1), cls), 1)[None]
733
+ else:
734
+ y = list(y.values())
738
735
  if len(y) == 2 and len(y[1].shape) != 4: # segmentation model
739
736
  y = list(reversed(y)) # reversed for segmentation models (pred, proto)
740
737
 
ultralytics/nn/tasks.py CHANGED
@@ -1483,61 +1483,12 @@ def torch_safe_load(weight, safe_only=False):
1483
1483
  return ckpt, file
1484
1484
 
1485
1485
 
1486
- def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
1487
- """
1488
- Load an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.
1489
-
1490
- Args:
1491
- weights (str | List[str]): Model weights path(s).
1492
- device (torch.device, optional): Device to load model to.
1493
- inplace (bool): Whether to do inplace operations.
1494
- fuse (bool): Whether to fuse model.
1495
-
1496
- Returns:
1497
- (torch.nn.Module): Loaded model.
1498
- """
1499
- ensemble = Ensemble()
1500
- for w in weights if isinstance(weights, list) else [weights]:
1501
- ckpt, w = torch_safe_load(w) # load ckpt
1502
- args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
1503
- model = (ckpt.get("ema") or ckpt["model"]).float() # FP32 model
1504
-
1505
- # Model compatibility updates
1506
- model.args = args # attach args to model
1507
- model.pt_path = w # attach *.pt file path to model
1508
- model.task = getattr(model, "task", guess_model_task(model))
1509
- if not hasattr(model, "stride"):
1510
- model.stride = torch.tensor([32.0])
1511
-
1512
- # Append
1513
- ensemble.append((model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()).to(device))
1514
-
1515
- # Module updates
1516
- for m in ensemble.modules():
1517
- if hasattr(m, "inplace"):
1518
- m.inplace = inplace
1519
- elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
1520
- m.recompute_scale_factor = None # torch 1.11.0 compatibility
1521
-
1522
- # Return model
1523
- if len(ensemble) == 1:
1524
- return ensemble[-1]
1525
-
1526
- # Return ensemble
1527
- LOGGER.info(f"Ensemble created with {weights}\n")
1528
- for k in "names", "nc", "yaml":
1529
- setattr(ensemble, k, getattr(ensemble[0], k))
1530
- ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride
1531
- assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
1532
- return ensemble
1533
-
1534
-
1535
- def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
1486
+ def load_checkpoint(weight, device=None, inplace=True, fuse=False):
1536
1487
  """
1537
1488
  Load a single model weights.
1538
1489
 
1539
1490
  Args:
1540
- weight (str): Model weight path.
1491
+ weight (str | Path): Model weight path.
1541
1492
  device (torch.device, optional): Device to load model to.
1542
1493
  inplace (bool): Whether to do inplace operations.
1543
1494
  fuse (bool): Whether to fuse model.
@@ -49,7 +49,7 @@ MACOS_VERSION = platform.mac_ver()[0] if MACOS else None
49
49
  NOT_MACOS14 = not (MACOS and MACOS_VERSION.startswith("14."))
50
50
  ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
51
51
  PYTHON_VERSION = platform.python_version()
52
- TORCH_VERSION = torch.__version__
52
+ TORCH_VERSION = str(torch.__version__) # Normalize torch.__version__ (PyTorch>1.9 returns TorchVersion objects)
53
53
  TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
54
54
  IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode"
55
55
  RKNN_CHIPS = frozenset(
@@ -132,6 +132,10 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress verbose TF compiler warning
132
132
  os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings
133
133
  os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output when computing FLOPs
134
134
 
135
+ # Precompiled type tuples for faster isinstance() checks
136
+ FLOAT_OR_INT = (float, int)
137
+ STR_OR_PATH = (str, Path)
138
+
135
139
 
136
140
  class DataExportMixin:
137
141
  """
@@ -456,8 +460,7 @@ def set_logging(name="LOGGING_NAME", verbose=True):
456
460
 
457
461
  # Set logger
458
462
  LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.)
459
- for logger in "sentry_sdk", "urllib3.connectionpool":
460
- logging.getLogger(logger).setLevel(logging.CRITICAL + 1)
463
+ logging.getLogger("sentry_sdk").setLevel(logging.CRITICAL + 1)
461
464
 
462
465
 
463
466
  def emojis(string=""):
@@ -3,8 +3,9 @@
3
3
  import json
4
4
  from time import time
5
5
 
6
- from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events
6
+ from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession
7
7
  from ultralytics.utils import LOGGER, RANK, SETTINGS
8
+ from ultralytics.utils.events import events
8
9
 
9
10
 
10
11
  def on_pretrain_routine_start(trainer):
@@ -36,6 +36,7 @@ from ultralytics.utils import (
36
36
  PYTHON_VERSION,
37
37
  RKNN_CHIPS,
38
38
  ROOT,
39
+ TORCH_VERSION,
39
40
  TORCHVISION_VERSION,
40
41
  USER_CONFIG_DIR,
41
42
  WINDOWS,
@@ -464,7 +465,7 @@ def check_torchvision():
464
465
  }
465
466
 
466
467
  # Check major and minor versions
467
- v_torch = ".".join(torch.__version__.split("+", 1)[0].split(".")[:2])
468
+ v_torch = ".".join(TORCH_VERSION.split("+", 1)[0].split(".")[:2])
468
469
  if v_torch in compatibility_table:
469
470
  compatible_versions = compatibility_table[v_torch]
470
471
  v_torchvision = ".".join(TORCHVISION_VERSION.split("+", 1)[0].split(".")[:2])
@@ -0,0 +1,115 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import json
4
+ import random
5
+ import time
6
+ from pathlib import Path
7
+ from threading import Thread
8
+ from urllib.request import Request, urlopen
9
+
10
+ from ultralytics import SETTINGS, __version__
11
+ from ultralytics.utils import ARGV, ENVIRONMENT, GIT, IS_PIP_PACKAGE, ONLINE, PYTHON_VERSION, RANK, TESTS_RUNNING
12
+ from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
13
+ from ultralytics.utils.torch_utils import get_cpu_info
14
+
15
+
16
+ def _post(url: str, data: dict, timeout: float = 5.0) -> None:
17
+ """Send a one-shot JSON POST request."""
18
+ try:
19
+ body = json.dumps(data, separators=(",", ":")).encode() # compact JSON
20
+ req = Request(url, data=body, headers={"Content-Type": "application/json"})
21
+ urlopen(req, timeout=timeout).close()
22
+ except Exception:
23
+ pass
24
+
25
+
26
+ class Events:
27
+ """
28
+ Collect and send anonymous usage analytics with rate-limiting.
29
+
30
+ Event collection and transmission are enabled when sync is enabled in settings, the current process is rank -1 or 0,
31
+ tests are not running, the environment is online, and the installation source is either pip or the official
32
+ Ultralytics GitHub repository.
33
+
34
+ Attributes:
35
+ url (str): Measurement Protocol endpoint for receiving anonymous events.
36
+ events (list[dict]): In-memory queue of event payloads awaiting transmission.
37
+ rate_limit (float): Minimum time in seconds between POST requests.
38
+ t (float): Timestamp of the last transmission in seconds since the epoch.
39
+ metadata (dict): Static metadata describing runtime, installation source, and environment.
40
+ enabled (bool): Flag indicating whether analytics collection is active.
41
+
42
+ Methods:
43
+ __init__: Initialize the event queue, rate limiter, and runtime metadata.
44
+ __call__: Queue an event and trigger a non-blocking send when the rate limit elapses.
45
+ """
46
+
47
+ url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
48
+
49
+ def __init__(self) -> None:
50
+ """Initialize the Events instance with queue, rate limiter, and environment metadata."""
51
+ self.events = [] # pending events
52
+ self.rate_limit = 30.0 # rate limit (seconds)
53
+ self.t = 0.0 # last send timestamp (seconds)
54
+ self.metadata = {
55
+ "cli": Path(ARGV[0]).name == "yolo",
56
+ "install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
57
+ "python": PYTHON_VERSION.rsplit(".", 1)[0], # i.e. 3.13
58
+ "CPU": get_cpu_info(),
59
+ # "GPU": get_gpu_info(index=0) if cuda else None,
60
+ "version": __version__,
61
+ "env": ENVIRONMENT,
62
+ "session_id": round(random.random() * 1e15),
63
+ "engagement_time_msec": 1000,
64
+ }
65
+ self.enabled = (
66
+ SETTINGS["sync"]
67
+ and RANK in {-1, 0}
68
+ and not TESTS_RUNNING
69
+ and ONLINE
70
+ and (IS_PIP_PACKAGE or GIT.origin == "https://github.com/ultralytics/ultralytics.git")
71
+ )
72
+
73
+ def __call__(self, cfg, device=None) -> None:
74
+ """
75
+ Queue an event and flush the queue asynchronously when the rate limit elapses.
76
+
77
+ Args:
78
+ cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
79
+ device (torch.device | str, optional): The device type (e.g., 'cpu', 'cuda').
80
+ """
81
+ if not self.enabled:
82
+ # Events disabled, do nothing
83
+ return
84
+
85
+ # Attempt to enqueue a new event
86
+ if len(self.events) < 25: # Queue limited to 25 events to bound memory and traffic
87
+ params = {
88
+ **self.metadata,
89
+ "task": cfg.task,
90
+ "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
91
+ "device": str(device),
92
+ }
93
+ if cfg.mode == "export":
94
+ params["format"] = cfg.format
95
+ self.events.append({"name": cfg.mode, "params": params})
96
+
97
+ # Check rate limit and return early if under limit
98
+ t = time.time()
99
+ if (t - self.t) < self.rate_limit:
100
+ return
101
+
102
+ # Overrate limit: send a snapshot of queued events in a background thread
103
+ payload_events = list(self.events) # snapshot to avoid race with queue reset
104
+ Thread(
105
+ target=_post,
106
+ args=(self.url, {"client_id": SETTINGS["uuid"], "events": payload_events}), # SHA-256 anonymized
107
+ daemon=True,
108
+ ).start()
109
+
110
+ # Reset queue and rate limit timer
111
+ self.events = []
112
+ self.t = t
113
+
114
+
115
+ events = Events()
ultralytics/utils/ops.py CHANGED
@@ -244,7 +244,9 @@ def scale_image(masks, im0_shape, ratio_pad=None):
244
244
  if len(masks.shape) < 2:
245
245
  raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
246
246
  masks = masks[top:bottom, left:right]
247
- masks = cv2.resize(masks, (im0_w, im0_h))
247
+ # handle the cv2.resize 512 channels limitation: https://github.com/ultralytics/ultralytics/pull/21947
248
+ masks = [cv2.resize(array, (im0_w, im0_h)) for array in np.array_split(masks, masks.shape[-1] // 512 + 1, axis=-1)]
249
+ masks = np.concatenate(masks, axis=-1) if len(masks) > 1 else masks[0]
248
250
  if len(masks.shape) == 2:
249
251
  masks = masks[:, :, None]
250
252
 
ultralytics/utils/tal.py CHANGED
@@ -3,12 +3,12 @@
3
3
  import torch
4
4
  import torch.nn as nn
5
5
 
6
- from . import LOGGER
6
+ from . import LOGGER, TORCH_VERSION
7
7
  from .checks import check_version
8
8
  from .metrics import bbox_iou, probiou
9
9
  from .ops import xywhr2xyxyxyxy
10
10
 
11
- TORCH_1_10 = check_version(torch.__version__, "1.10.0")
11
+ TORCH_1_10 = check_version(TORCH_VERSION, "1.10.0")
12
12
 
13
13
 
14
14
  class TaskAlignedAssigner(nn.Module):
@@ -27,6 +27,7 @@ from ultralytics.utils import (
27
27
  LOGGER,
28
28
  NUM_THREADS,
29
29
  PYTHON_VERSION,
30
+ TORCH_VERSION,
30
31
  TORCHVISION_VERSION,
31
32
  WINDOWS,
32
33
  colorstr,
@@ -35,15 +36,15 @@ from ultralytics.utils.checks import check_version
35
36
  from ultralytics.utils.patches import torch_load
36
37
 
37
38
  # Version checks (all default to version>=min_version)
38
- TORCH_1_9 = check_version(torch.__version__, "1.9.0")
39
- TORCH_1_13 = check_version(torch.__version__, "1.13.0")
40
- TORCH_2_0 = check_version(torch.__version__, "2.0.0")
41
- TORCH_2_4 = check_version(torch.__version__, "2.4.0")
39
+ TORCH_1_9 = check_version(TORCH_VERSION, "1.9.0")
40
+ TORCH_1_13 = check_version(TORCH_VERSION, "1.13.0")
41
+ TORCH_2_0 = check_version(TORCH_VERSION, "2.0.0")
42
+ TORCH_2_4 = check_version(TORCH_VERSION, "2.4.0")
42
43
  TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
43
44
  TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
44
45
  TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
45
46
  TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
46
- if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows
47
+ if WINDOWS and check_version(TORCH_VERSION, "==2.4.0"): # reject version 2.4.0 on Windows
47
48
  LOGGER.warning(
48
49
  "Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
49
50
  "https://github.com/ultralytics/ultralytics/issues/15049"
@@ -165,7 +166,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
165
166
  if isinstance(device, torch.device) or str(device).startswith(("tpu", "intel")):
166
167
  return device
167
168
 
168
- s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
169
+ s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{TORCH_VERSION} "
169
170
  device = str(device).lower()
170
171
  for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
171
172
  device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
ultralytics/utils/tqdm.py CHANGED
@@ -88,11 +88,11 @@ class TQDM:
88
88
  mininterval: float = 0.1,
89
89
  disable: bool | None = None,
90
90
  unit: str = "it",
91
- unit_scale: bool = False,
91
+ unit_scale: bool = True,
92
92
  unit_divisor: int = 1000,
93
- bar_format: str | None = None,
93
+ bar_format: str | None = None, # kept for API compatibility; not used for formatting
94
94
  initial: int = 0,
95
- **kwargs, # Accept unused args for compatibility
95
+ **kwargs,
96
96
  ) -> None:
97
97
  """
98
98
  Initialize the TQDM progress bar with specified configuration options.
@@ -138,11 +138,8 @@ class TQDM:
138
138
  self.mininterval = max(mininterval, self.NONINTERACTIVE_MIN_INTERVAL) if self.noninteractive else mininterval
139
139
  self.initial = initial
140
140
 
141
- # Set bar format based on whether we have a total
142
- if self.total:
143
- self.bar_format = bar_format or "{desc}: {percent:.0f}% {bar} {n}/{total} {rate} {elapsed}<{remaining}"
144
- else:
145
- self.bar_format = bar_format or "{desc}: {bar} {n} {rate} {elapsed}"
141
+ # Kept for API compatibility (unused for f-string formatting)
142
+ self.bar_format = bar_format
146
143
 
147
144
  self.file = file or sys.stdout
148
145
 
@@ -151,48 +148,31 @@ class TQDM:
151
148
  self.last_print_n = self.initial
152
149
  self.last_print_t = time.time()
153
150
  self.start_t = time.time()
154
- self.last_rate = 0
151
+ self.last_rate = 0.0
155
152
  self.closed = False
153
+ self.is_bytes = unit_scale and unit in ("B", "bytes")
154
+ self.scales = (
155
+ [(1073741824, "GB/s"), (1048576, "MB/s"), (1024, "KB/s")]
156
+ if self.is_bytes
157
+ else [(1e9, f"G{self.unit}/s"), (1e6, f"M{self.unit}/s"), (1e3, f"K{self.unit}/s")]
158
+ )
156
159
 
157
- # Display initial bar if we have total and not disabled
158
160
  if not self.disable and self.total and not self.noninteractive:
159
161
  self._display()
160
162
 
161
163
  def _format_rate(self, rate: float) -> str:
162
- """Format rate with proper units and reasonable precision."""
164
+ """Format rate with units."""
163
165
  if rate <= 0:
164
166
  return ""
167
+ fallback = f"{rate:.1f}B/s" if self.is_bytes else f"{rate:.1f}{self.unit}/s"
168
+ return next((f"{rate / t:.1f}{u}" for t, u in self.scales if rate >= t), fallback)
165
169
 
166
- # For bytes with scaling, use binary units
167
- if self.unit in ("B", "bytes") and self.unit_scale:
168
- return next(
169
- (
170
- f"{rate / threshold:.1f}{unit}"
171
- for threshold, unit in [
172
- (1073741824, "GB/s"),
173
- (1048576, "MB/s"),
174
- (1024, "KB/s"),
175
- ]
176
- if rate >= threshold
177
- ),
178
- f"{rate:.1f}B/s",
179
- )
180
- # For other scalable units, use decimal units
181
- if self.unit_scale and self.unit in ("it", "items", ""):
182
- for threshold, prefix in [(1000000, "M"), (1000, "K")]:
183
- if rate >= threshold:
184
- return f"{rate / threshold:.1f}{prefix}{self.unit}/s"
185
-
186
- # Default formatting
187
- precision = ".1f" if rate >= 1 else ".2f"
188
- return f"{rate:{precision}}{self.unit}/s"
189
-
190
- def _format_num(self, num: int) -> str:
170
+ def _format_num(self, num: int | float) -> str:
191
171
  """Format number with optional unit scaling."""
192
- if not self.unit_scale or self.unit not in ("B", "bytes"):
172
+ if not self.unit_scale or not self.is_bytes:
193
173
  return str(num)
194
174
 
195
- for unit in ["", "K", "M", "G", "T"]:
175
+ for unit in ("", "K", "M", "G", "T"):
196
176
  if abs(num) < self.unit_divisor:
197
177
  return f"{num:3.1f}{unit}B" if unit else f"{num:.0f}B"
198
178
  num /= self.unit_divisor
@@ -224,8 +204,7 @@ class TQDM:
224
204
  """Check if display should update."""
225
205
  if self.noninteractive:
226
206
  return False
227
-
228
- return True if self.total and self.n >= self.total else dt >= self.mininterval
207
+ return (self.total is not None and self.n >= self.total) or (dt >= self.mininterval)
229
208
 
230
209
  def _display(self, final: bool = False) -> None:
231
210
  """Display progress bar."""
@@ -240,8 +219,8 @@ class TQDM:
240
219
  return
241
220
 
242
221
  # Calculate rate (avoid crazy numbers)
243
- if dt > self.MIN_RATE_CALC_INTERVAL: # Only calculate rate if enough time has passed
244
- rate = dn / dt
222
+ if dt > self.MIN_RATE_CALC_INTERVAL:
223
+ rate = dn / dt if dt else 0.0
245
224
  # Smooth rate for reasonable values, use raw rate for very high values
246
225
  if rate < self.MAX_SMOOTHED_RATE:
247
226
  self.last_rate = self.RATE_SMOOTHING_FACTOR * rate + (1 - self.RATE_SMOOTHING_FACTOR) * self.last_rate
@@ -249,8 +228,8 @@ class TQDM:
249
228
  else:
250
229
  rate = self.last_rate
251
230
 
252
- # At completion, use the overall rate for more accurate display
253
- if self.n >= (self.total or float("inf")) and self.total and self.total > 0:
231
+ # At completion, use overall rate
232
+ if self.total and self.n >= self.total:
254
233
  overall_elapsed = current_time - self.start_t
255
234
  if overall_elapsed > 0:
256
235
  rate = self.n / overall_elapsed
@@ -260,45 +239,41 @@ class TQDM:
260
239
  self.last_print_t = current_time
261
240
  elapsed = current_time - self.start_t
262
241
 
263
- # Calculate remaining time
242
+ # Remaining time
264
243
  remaining_str = ""
265
244
  if self.total and 0 < self.n < self.total and elapsed > 0:
266
- est_rate = rate or self.n / elapsed
267
- remaining_str = self._format_time((self.total - self.n) / est_rate)
245
+ est_rate = rate or (self.n / elapsed)
246
+ remaining_str = f"<{self._format_time((self.total - self.n) / est_rate)}"
268
247
 
269
- # Build progress components
248
+ # Numbers and percent
270
249
  if self.total:
271
250
  percent = (self.n / self.total) * 100
272
- # For bytes with unit scaling, avoid repeating units: show "5.4/5.4MB" not "5.4MB/5.4MB"
273
- n = self._format_num(self.n)
274
- total = self._format_num(self.total)
275
- if self.unit_scale and self.unit in ("B", "bytes"):
276
- n = n.rstrip("KMGTPB") # Remove unit suffix from current
251
+ n_str = self._format_num(self.n)
252
+ t_str = self._format_num(self.total)
253
+ if self.is_bytes:
254
+ # Collapse suffix only when identical (e.g. "5.4/5.4MB")
255
+ if n_str[-2] == t_str[-2]:
256
+ n_str = n_str.rstrip("KMGTPB") # Remove unit suffix from current if different than total
277
257
  else:
278
- percent = 0
279
- n = self._format_num(self.n)
280
- total = "?"
258
+ percent = 0.0
259
+ n_str, t_str = self._format_num(self.n), "?"
281
260
 
282
261
  elapsed_str = self._format_time(elapsed)
262
+ rate_str = self._format_rate(rate) or (self._format_rate(self.n / elapsed) if elapsed > 0 else "")
283
263
 
284
- # Use different format for completion
285
- if self.total and self.n >= self.total:
286
- format_str = self.bar_format.replace("<{remaining}", "")
264
+ bar = self._generate_bar()
265
+
266
+ # Compose progress line via f-strings (two shapes: with/without total)
267
+ if self.total:
268
+ if self.is_bytes and self.n >= self.total:
269
+ # Completed bytes: show only final size
270
+ progress_str = f"{self.desc}: {percent:.0f}% {bar} {t_str} {rate_str} {elapsed_str}"
271
+ else:
272
+ progress_str = (
273
+ f"{self.desc}: {percent:.0f}% {bar} {n_str}/{t_str} {rate_str} {elapsed_str}{remaining_str}"
274
+ )
287
275
  else:
288
- format_str = self.bar_format
289
-
290
- # Format progress string
291
- progress_str = format_str.format(
292
- desc=self.desc,
293
- percent=percent,
294
- bar=self._generate_bar(),
295
- n=n,
296
- total=total,
297
- rate=self._format_rate(rate) or (self._format_rate(self.n / elapsed) if elapsed > 0 else ""),
298
- remaining=remaining_str,
299
- elapsed=elapsed_str,
300
- unit=self.unit,
301
- )
276
+ progress_str = f"{self.desc}: {bar} {n_str} {rate_str} {elapsed_str}"
302
277
 
303
278
  # Write to output
304
279
  try:
@@ -336,7 +311,7 @@ class TQDM:
336
311
  if self.closed:
337
312
  return
338
313
 
339
- self.closed = True # Set before final display
314
+ self.closed = True
340
315
 
341
316
  if not self.disable:
342
317
  # Final display
@@ -129,7 +129,7 @@ def run_ray_tune(
129
129
  {**train_args, **{"exist_ok": train_args.pop("resume", False)}}, # resume w/ same tune_dir
130
130
  ),
131
131
  name=train_args.pop("name", "tune"), # runs/{task}/{tune_dir}
132
- ).resolve() # must be absolute dir
132
+ ) # must be absolute dir
133
133
  tune_dir.mkdir(parents=True, exist_ok=True)
134
134
  if tune.Tuner.can_restore(tune_dir):
135
135
  LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")