dgenerate-ultralytics-headless 8.3.189__py3-none-any.whl → 8.3.191__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/RECORD +111 -109
  3. tests/test_cuda.py +6 -5
  4. tests/test_exports.py +1 -6
  5. tests/test_python.py +1 -4
  6. tests/test_solutions.py +1 -1
  7. ultralytics/__init__.py +1 -1
  8. ultralytics/cfg/__init__.py +16 -14
  9. ultralytics/cfg/datasets/VisDrone.yaml +4 -4
  10. ultralytics/data/annotator.py +6 -6
  11. ultralytics/data/augment.py +53 -51
  12. ultralytics/data/base.py +15 -13
  13. ultralytics/data/build.py +7 -4
  14. ultralytics/data/converter.py +9 -10
  15. ultralytics/data/dataset.py +24 -22
  16. ultralytics/data/loaders.py +13 -11
  17. ultralytics/data/split.py +4 -3
  18. ultralytics/data/split_dota.py +14 -12
  19. ultralytics/data/utils.py +31 -25
  20. ultralytics/engine/exporter.py +7 -4
  21. ultralytics/engine/model.py +16 -14
  22. ultralytics/engine/predictor.py +9 -7
  23. ultralytics/engine/results.py +59 -57
  24. ultralytics/engine/trainer.py +7 -0
  25. ultralytics/engine/tuner.py +4 -3
  26. ultralytics/engine/validator.py +3 -1
  27. ultralytics/hub/__init__.py +6 -2
  28. ultralytics/hub/auth.py +2 -2
  29. ultralytics/hub/google/__init__.py +9 -8
  30. ultralytics/hub/session.py +11 -11
  31. ultralytics/hub/utils.py +8 -9
  32. ultralytics/models/fastsam/model.py +8 -6
  33. ultralytics/models/nas/model.py +5 -3
  34. ultralytics/models/rtdetr/train.py +4 -3
  35. ultralytics/models/rtdetr/val.py +6 -4
  36. ultralytics/models/sam/amg.py +13 -10
  37. ultralytics/models/sam/model.py +3 -2
  38. ultralytics/models/sam/modules/blocks.py +21 -21
  39. ultralytics/models/sam/modules/decoders.py +11 -11
  40. ultralytics/models/sam/modules/encoders.py +25 -25
  41. ultralytics/models/sam/modules/memory_attention.py +9 -8
  42. ultralytics/models/sam/modules/sam.py +8 -10
  43. ultralytics/models/sam/modules/tiny_encoder.py +21 -20
  44. ultralytics/models/sam/modules/transformer.py +6 -5
  45. ultralytics/models/sam/modules/utils.py +7 -5
  46. ultralytics/models/sam/predict.py +32 -31
  47. ultralytics/models/utils/loss.py +29 -27
  48. ultralytics/models/utils/ops.py +10 -8
  49. ultralytics/models/yolo/classify/train.py +7 -5
  50. ultralytics/models/yolo/classify/val.py +10 -8
  51. ultralytics/models/yolo/detect/predict.py +3 -3
  52. ultralytics/models/yolo/detect/train.py +8 -6
  53. ultralytics/models/yolo/detect/val.py +23 -21
  54. ultralytics/models/yolo/model.py +14 -14
  55. ultralytics/models/yolo/obb/train.py +5 -3
  56. ultralytics/models/yolo/obb/val.py +13 -10
  57. ultralytics/models/yolo/pose/train.py +7 -5
  58. ultralytics/models/yolo/pose/val.py +11 -9
  59. ultralytics/models/yolo/segment/train.py +4 -5
  60. ultralytics/models/yolo/segment/val.py +12 -10
  61. ultralytics/models/yolo/world/train.py +9 -7
  62. ultralytics/models/yolo/yoloe/train.py +7 -6
  63. ultralytics/models/yolo/yoloe/val.py +10 -8
  64. ultralytics/nn/autobackend.py +40 -52
  65. ultralytics/nn/modules/__init__.py +3 -3
  66. ultralytics/nn/modules/block.py +12 -12
  67. ultralytics/nn/modules/conv.py +4 -3
  68. ultralytics/nn/modules/head.py +46 -38
  69. ultralytics/nn/modules/transformer.py +22 -21
  70. ultralytics/nn/tasks.py +2 -2
  71. ultralytics/nn/text_model.py +6 -5
  72. ultralytics/solutions/analytics.py +7 -5
  73. ultralytics/solutions/config.py +12 -10
  74. ultralytics/solutions/distance_calculation.py +3 -3
  75. ultralytics/solutions/heatmap.py +4 -2
  76. ultralytics/solutions/object_counter.py +5 -3
  77. ultralytics/solutions/parking_management.py +4 -2
  78. ultralytics/solutions/region_counter.py +7 -5
  79. ultralytics/solutions/similarity_search.py +5 -3
  80. ultralytics/solutions/solutions.py +38 -36
  81. ultralytics/solutions/streamlit_inference.py +8 -7
  82. ultralytics/trackers/bot_sort.py +11 -9
  83. ultralytics/trackers/byte_tracker.py +17 -15
  84. ultralytics/trackers/utils/gmc.py +4 -3
  85. ultralytics/utils/__init__.py +27 -77
  86. ultralytics/utils/autobatch.py +3 -2
  87. ultralytics/utils/autodevice.py +10 -10
  88. ultralytics/utils/benchmarks.py +11 -10
  89. ultralytics/utils/callbacks/comet.py +9 -9
  90. ultralytics/utils/callbacks/platform.py +2 -1
  91. ultralytics/utils/checks.py +20 -29
  92. ultralytics/utils/downloads.py +2 -2
  93. ultralytics/utils/export.py +12 -11
  94. ultralytics/utils/files.py +8 -7
  95. ultralytics/utils/git.py +139 -0
  96. ultralytics/utils/instance.py +8 -7
  97. ultralytics/utils/logger.py +7 -6
  98. ultralytics/utils/loss.py +15 -13
  99. ultralytics/utils/metrics.py +62 -62
  100. ultralytics/utils/nms.py +346 -0
  101. ultralytics/utils/ops.py +83 -251
  102. ultralytics/utils/patches.py +6 -4
  103. ultralytics/utils/plotting.py +18 -16
  104. ultralytics/utils/tal.py +1 -1
  105. ultralytics/utils/torch_utils.py +4 -2
  106. ultralytics/utils/tqdm.py +47 -33
  107. ultralytics/utils/triton.py +3 -2
  108. {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/WHEEL +0 -0
  109. {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/entry_points.txt +0 -0
  110. {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/licenses/LICENSE +0 -0
  111. {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import contextlib
4
6
  import glob
5
7
  import os
@@ -8,7 +10,6 @@ import tempfile
8
10
  from contextlib import contextmanager
9
11
  from datetime import datetime
10
12
  from pathlib import Path
11
- from typing import Union
12
13
 
13
14
 
14
15
  class WorkingDirectory(contextlib.ContextDecorator):
@@ -39,7 +40,7 @@ class WorkingDirectory(contextlib.ContextDecorator):
39
40
  >>> pass
40
41
  """
41
42
 
42
- def __init__(self, new_dir: Union[str, Path]):
43
+ def __init__(self, new_dir: str | Path):
43
44
  """Initialize the WorkingDirectory context manager with the target directory."""
44
45
  self.dir = new_dir # new dir
45
46
  self.cwd = Path.cwd().resolve() # current dir
@@ -54,7 +55,7 @@ class WorkingDirectory(contextlib.ContextDecorator):
54
55
 
55
56
 
56
57
  @contextmanager
57
- def spaces_in_path(path: Union[str, Path]):
58
+ def spaces_in_path(path: str | Path):
58
59
  """
59
60
  Context manager to handle paths with spaces in their names.
60
61
 
@@ -105,7 +106,7 @@ def spaces_in_path(path: Union[str, Path]):
105
106
  yield path
106
107
 
107
108
 
108
- def increment_path(path: Union[str, Path], exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
109
+ def increment_path(path: str | Path, exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
109
110
  """
110
111
  Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
111
112
 
@@ -153,19 +154,19 @@ def increment_path(path: Union[str, Path], exist_ok: bool = False, sep: str = ""
153
154
  return path
154
155
 
155
156
 
156
- def file_age(path: Union[str, Path] = __file__) -> int:
157
+ def file_age(path: str | Path = __file__) -> int:
157
158
  """Return days since the last modification of the specified file."""
158
159
  dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
159
160
  return dt.days # + dt.seconds / 86400 # fractional days
160
161
 
161
162
 
162
- def file_date(path: Union[str, Path] = __file__) -> str:
163
+ def file_date(path: str | Path = __file__) -> str:
163
164
  """Return the file modification date in 'YYYY-M-D' format."""
164
165
  t = datetime.fromtimestamp(Path(path).stat().st_mtime)
165
166
  return f"{t.year}-{t.month}-{t.day}"
166
167
 
167
168
 
168
- def file_size(path: Union[str, Path]) -> float:
169
+ def file_size(path: str | Path) -> float:
169
170
  """Return the size of a file or directory in megabytes (MB)."""
170
171
  if isinstance(path, (str, Path)):
171
172
  mb = 1 << 20 # bytes to MiB (1024 ** 2)
@@ -0,0 +1,139 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ from __future__ import annotations
4
+
5
+ from functools import cached_property
6
+ from pathlib import Path
7
+
8
+
9
+ class GitRepo:
10
+ """
11
+ Represent a local Git repository and expose branch, commit, and remote metadata.
12
+
13
+ This class discovers the repository root by searching for a .git entry from the given path upward, resolves the
14
+ actual .git directory (including worktrees), and reads Git metadata directly from on-disk files. It does not
15
+ invoke the git binary and therefore works in restricted environments. All metadata properties are resolved
16
+ lazily and cached; construct a new instance to refresh state.
17
+
18
+ Attributes:
19
+ root (Path | None): Repository root directory containing the .git entry; None if not in a repository.
20
+ gitdir (Path | None): Resolved .git directory path; handles worktrees; None if unresolved.
21
+ head (str | None): Raw contents of HEAD; a SHA for detached HEAD or "ref: <refname>" for branch heads.
22
+ is_repo (bool): Whether the provided path resides inside a Git repository.
23
+ branch (str | None): Current branch name when HEAD points to a branch; None for detached HEAD or non-repo.
24
+ commit (str | None): Current commit SHA for HEAD; None if not determinable.
25
+ origin (str | None): URL of the "origin" remote as read from gitdir/config; None if unset or unavailable.
26
+
27
+ Examples:
28
+ Initialize from the current working directory and read metadata
29
+ >>> from pathlib import Path
30
+ >>> repo = GitRepo(Path.cwd())
31
+ >>> repo.is_repo
32
+ True
33
+ >>> repo.branch, repo.commit[:7], repo.origin
34
+ ('main', '1a2b3c4', 'https://example.com/owner/repo.git')
35
+
36
+ Notes:
37
+ - Resolves metadata by reading files: HEAD, packed-refs, and config; no subprocess calls are used.
38
+ - Caches properties on first access using cached_property; recreate the object to reflect repository changes.
39
+ """
40
+
41
+ def __init__(self, path: Path = Path(__file__).resolve()):
42
+ """
43
+ Initialize a Git repository context by discovering the repository root from a starting path.
44
+
45
+ Args:
46
+ path (Path, optional): File or directory path used as the starting point to locate the repository root.
47
+ """
48
+ self.root = self._find_root(path)
49
+ self.gitdir = self._gitdir(self.root) if self.root else None
50
+
51
+ @staticmethod
52
+ def _find_root(p: Path) -> Path | None:
53
+ """Return repo root or None."""
54
+ return next((d for d in [p] + list(p.parents) if (d / ".git").exists()), None)
55
+
56
+ @staticmethod
57
+ def _gitdir(root: Path) -> Path | None:
58
+ """Resolve actual .git directory (handles worktrees)."""
59
+ g = root / ".git"
60
+ if g.is_dir():
61
+ return g
62
+ if g.is_file():
63
+ t = g.read_text(errors="ignore").strip()
64
+ if t.startswith("gitdir:"):
65
+ return (root / t.split(":", 1)[1].strip()).resolve()
66
+ return None
67
+
68
+ def _read(self, p: Path | None) -> str | None:
69
+ """Read and strip file if exists."""
70
+ return p.read_text(errors="ignore").strip() if p and p.exists() else None
71
+
72
+ @cached_property
73
+ def head(self) -> str | None:
74
+ """HEAD file contents."""
75
+ return self._read(self.gitdir / "HEAD" if self.gitdir else None)
76
+
77
+ def _ref_commit(self, ref: str) -> str | None:
78
+ """Commit for ref (handles packed-refs)."""
79
+ rf = self.gitdir / ref
80
+ s = self._read(rf)
81
+ if s:
82
+ return s
83
+ pf = self.gitdir / "packed-refs"
84
+ b = pf.read_bytes().splitlines() if pf.exists() else []
85
+ tgt = ref.encode()
86
+ for line in b:
87
+ if line[:1] in (b"#", b"^") or b" " not in line:
88
+ continue
89
+ sha, name = line.split(b" ", 1)
90
+ if name.strip() == tgt:
91
+ return sha.decode()
92
+ return None
93
+
94
+ @property
95
+ def is_repo(self) -> bool:
96
+ """True if inside a git repo."""
97
+ return self.gitdir is not None
98
+
99
+ @cached_property
100
+ def branch(self) -> str | None:
101
+ """Current branch or None."""
102
+ if not self.is_repo or not self.head or not self.head.startswith("ref: "):
103
+ return None
104
+ ref = self.head[5:].strip()
105
+ return ref[len("refs/heads/") :] if ref.startswith("refs/heads/") else ref
106
+
107
+ @cached_property
108
+ def commit(self) -> str | None:
109
+ """Current commit SHA or None."""
110
+ if not self.is_repo or not self.head:
111
+ return None
112
+ return self._ref_commit(self.head[5:].strip()) if self.head.startswith("ref: ") else self.head
113
+
114
+ @cached_property
115
+ def origin(self) -> str | None:
116
+ """Origin URL or None."""
117
+ if not self.is_repo:
118
+ return None
119
+ cfg = self.gitdir / "config"
120
+ remote, url = None, None
121
+ for s in (self._read(cfg) or "").splitlines():
122
+ t = s.strip()
123
+ if t.startswith("[") and t.endswith("]"):
124
+ remote = t.lower()
125
+ elif t.lower().startswith("url =") and remote == '[remote "origin"]':
126
+ url = t.split("=", 1)[1].strip()
127
+ break
128
+ return url
129
+
130
+
131
if __name__ == "__main__":
    import time

    # Time repository discovery plus lazy metadata resolution, but not the console output
    t0 = time.perf_counter()
    g = GitRepo()
    if g.is_repo:
        branch, commit, origin = g.branch, g.commit, g.origin  # force (and cache) the lazy properties
        dt = (time.perf_counter() - t0) * 1000
        print(f"repo={g.root}\nbranch={branch}\ncommit={commit}\norigin={origin}")
        print(f"\n⏱️ Profiling: total {dt:.3f} ms")
    else:
        print(f"Not inside a Git repository: {Path(__file__).resolve()}")
@@ -1,9 +1,10 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  from collections import abc
4
6
  from itertools import repeat
5
7
  from numbers import Number
6
- from typing import List, Union
7
8
 
8
9
  import numpy as np
9
10
 
@@ -101,7 +102,7 @@ class Bboxes:
101
102
  else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh
102
103
  )
103
104
 
104
- def mul(self, scale: Union[int, tuple, list]) -> None:
105
+ def mul(self, scale: int | tuple | list) -> None:
105
106
  """
106
107
  Multiply bounding box coordinates by scale factor(s).
107
108
 
@@ -118,7 +119,7 @@ class Bboxes:
118
119
  self.bboxes[:, 2] *= scale[2]
119
120
  self.bboxes[:, 3] *= scale[3]
120
121
 
121
- def add(self, offset: Union[int, tuple, list]) -> None:
122
+ def add(self, offset: int | tuple | list) -> None:
122
123
  """
123
124
  Add offset to bounding box coordinates.
124
125
 
@@ -140,7 +141,7 @@ class Bboxes:
140
141
  return len(self.bboxes)
141
142
 
142
143
  @classmethod
143
- def concatenate(cls, boxes_list: List["Bboxes"], axis: int = 0) -> "Bboxes":
144
+ def concatenate(cls, boxes_list: list[Bboxes], axis: int = 0) -> Bboxes:
144
145
  """
145
146
  Concatenate a list of Bboxes objects into a single Bboxes object.
146
147
 
@@ -163,7 +164,7 @@ class Bboxes:
163
164
  return boxes_list[0]
164
165
  return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
165
166
 
166
- def __getitem__(self, index: Union[int, np.ndarray, slice]) -> "Bboxes":
167
+ def __getitem__(self, index: int | np.ndarray | slice) -> Bboxes:
167
168
  """
168
169
  Retrieve a specific bounding box or a set of bounding boxes using indexing.
169
170
 
@@ -327,7 +328,7 @@ class Instances:
327
328
  self.keypoints[..., 0] += padw
328
329
  self.keypoints[..., 1] += padh
329
330
 
330
- def __getitem__(self, index: Union[int, np.ndarray, slice]) -> "Instances":
331
+ def __getitem__(self, index: int | np.ndarray | slice) -> Instances:
331
332
  """
332
333
  Retrieve a specific instance or a set of instances using indexing.
333
334
 
@@ -452,7 +453,7 @@ class Instances:
452
453
  return len(self.bboxes)
453
454
 
454
455
  @classmethod
455
- def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
456
+ def concatenate(cls, instances_list: list[Instances], axis=0) -> Instances:
456
457
  """
457
458
  Concatenate a list of Instances objects into a single Instances object.
458
459
 
@@ -9,9 +9,6 @@ import time
9
9
  from datetime import datetime
10
10
  from pathlib import Path
11
11
 
12
- import psutil
13
- import requests
14
-
15
12
  from ultralytics.utils import MACOS, RANK
16
13
  from ultralytics.utils.checks import check_requirements
17
14
 
@@ -189,8 +186,10 @@ class ConsoleLogger:
189
186
  """Write log to API endpoint or local file destination."""
190
187
  try:
191
188
  if self.is_api:
189
+ import requests # scoped as slow import
190
+
192
191
  payload = {"timestamp": datetime.now().isoformat(), "message": text.strip()}
193
- requests.post(self.destination, json=payload, timeout=5)
192
+ requests.post(str(self.destination), json=payload, timeout=5)
194
193
  else:
195
194
  self.destination.parent.mkdir(parents=True, exist_ok=True)
196
195
  with self.destination.open("a", encoding="utf-8") as f:
@@ -237,7 +236,6 @@ class SystemLogger:
237
236
  Attributes:
238
237
  pynvml: NVIDIA pynvml module instance if successfully imported, None otherwise.
239
238
  nvidia_initialized (bool): Whether NVIDIA GPU monitoring is available and initialized.
240
- process (psutil.Process): Current psutil.Process instance for process-specific metrics.
241
239
  net_start: Initial network I/O counters for calculating cumulative usage.
242
240
  disk_start: Initial disk I/O counters for calculating cumulative usage.
243
241
 
@@ -260,9 +258,10 @@ class SystemLogger:
260
258
 
261
259
  def __init__(self):
262
260
  """Initialize the system logger."""
261
+ import psutil # scoped as slow import
262
+
263
263
  self.pynvml = None
264
264
  self.nvidia_initialized = self._init_nvidia()
265
- self.process = psutil.Process()
266
265
  self.net_start = psutil.net_io_counters()
267
266
  self.disk_start = psutil.disk_io_counters()
268
267
 
@@ -315,6 +314,8 @@ class SystemLogger:
315
314
  Returns:
316
315
  metrics (dict): System metrics containing 'cpu', 'ram', 'disk', 'network', 'gpus' with respective usage data.
317
316
  """
317
+ import psutil # scoped as slow import
318
+
318
319
  net = psutil.net_io_counters()
319
320
  disk = psutil.disk_io_counters()
320
321
  memory = psutil.virtual_memory()
ultralytics/utils/loss.py CHANGED
@@ -1,6 +1,8 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
- from typing import Any, Dict, List, Tuple
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
4
6
 
5
7
  import torch
6
8
  import torch.nn as nn
@@ -122,7 +124,7 @@ class BboxLoss(nn.Module):
122
124
  target_scores: torch.Tensor,
123
125
  target_scores_sum: torch.Tensor,
124
126
  fg_mask: torch.Tensor,
125
- ) -> Tuple[torch.Tensor, torch.Tensor]:
127
+ ) -> tuple[torch.Tensor, torch.Tensor]:
126
128
  """Compute IoU and DFL losses for bounding boxes."""
127
129
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
128
130
  iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
@@ -155,7 +157,7 @@ class RotatedBboxLoss(BboxLoss):
155
157
  target_scores: torch.Tensor,
156
158
  target_scores_sum: torch.Tensor,
157
159
  fg_mask: torch.Tensor,
158
- ) -> Tuple[torch.Tensor, torch.Tensor]:
160
+ ) -> tuple[torch.Tensor, torch.Tensor]:
159
161
  """Compute IoU and DFL losses for rotated bounding boxes."""
160
162
  weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
161
163
  iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
@@ -240,7 +242,7 @@ class v8DetectionLoss:
240
242
  # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
241
243
  return dist2bbox(pred_dist, anchor_points, xywh=False)
242
244
 
243
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
245
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
244
246
  """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
245
247
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
246
248
  feats = preds[1] if isinstance(preds, tuple) else preds
@@ -305,7 +307,7 @@ class v8SegmentationLoss(v8DetectionLoss):
305
307
  super().__init__(model)
306
308
  self.overlap = model.args.overlap_mask
307
309
 
308
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
309
311
  """Calculate and return the combined loss for detection and segmentation."""
310
312
  loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
311
313
  feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
@@ -493,7 +495,7 @@ class v8PoseLoss(v8DetectionLoss):
493
495
  sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
494
496
  self.keypoint_loss = KeypointLoss(sigmas=sigmas)
495
497
 
496
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
498
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
497
499
  """Calculate the total loss and detach it for pose estimation."""
498
500
  loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
499
501
  feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
@@ -577,7 +579,7 @@ class v8PoseLoss(v8DetectionLoss):
577
579
  stride_tensor: torch.Tensor,
578
580
  target_bboxes: torch.Tensor,
579
581
  pred_kpts: torch.Tensor,
580
- ) -> Tuple[torch.Tensor, torch.Tensor]:
582
+ ) -> tuple[torch.Tensor, torch.Tensor]:
581
583
  """
582
584
  Calculate the keypoints loss for the model.
583
585
 
@@ -645,7 +647,7 @@ class v8PoseLoss(v8DetectionLoss):
645
647
  class v8ClassificationLoss:
646
648
  """Criterion class for computing training losses for classification."""
647
649
 
648
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
650
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
649
651
  """Compute the classification loss between predictions and true labels."""
650
652
  preds = preds[1] if isinstance(preds, (list, tuple)) else preds
651
653
  loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
@@ -678,7 +680,7 @@ class v8OBBLoss(v8DetectionLoss):
678
680
  out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
679
681
  return out
680
682
 
681
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
683
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
682
684
  """Calculate and return the loss for oriented bounding box detection."""
683
685
  loss = torch.zeros(3, device=self.device) # box, cls, dfl
684
686
  feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
@@ -778,7 +780,7 @@ class E2EDetectLoss:
778
780
  self.one2many = v8DetectionLoss(model, tal_topk=10)
779
781
  self.one2one = v8DetectionLoss(model, tal_topk=1)
780
782
 
781
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
783
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
782
784
  """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
783
785
  preds = preds[1] if isinstance(preds, tuple) else preds
784
786
  one2many = preds["one2many"]
@@ -799,7 +801,7 @@ class TVPDetectLoss:
799
801
  self.ori_no = self.vp_criterion.no
800
802
  self.ori_reg_max = self.vp_criterion.reg_max
801
803
 
802
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
804
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
803
805
  """Calculate the loss for text-visual prompt detection."""
804
806
  feats = preds[1] if isinstance(preds, tuple) else preds
805
807
  assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
@@ -813,7 +815,7 @@ class TVPDetectLoss:
813
815
  box_loss = vp_loss[0][1]
814
816
  return box_loss, vp_loss[1]
815
817
 
816
- def _get_vp_features(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
818
+ def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
817
819
  """Extract visual-prompt features from the model output."""
818
820
  vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
819
821
 
@@ -835,7 +837,7 @@ class TVPSegmentLoss(TVPDetectLoss):
835
837
  super().__init__(model)
836
838
  self.vp_criterion = v8SegmentationLoss(model)
837
839
 
838
- def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
840
+ def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
839
841
  """Calculate the loss for text-visual prompt segmentation."""
840
842
  feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
841
843
  assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it