dgenerate-ultralytics-headless 8.3.190__py3-none-any.whl → 8.3.192__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.190.dist-info → dgenerate_ultralytics_headless-8.3.192.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.190.dist-info → dgenerate_ultralytics_headless-8.3.192.dist-info}/RECORD +103 -102
- tests/test_cuda.py +6 -5
- tests/test_exports.py +1 -6
- tests/test_python.py +1 -4
- tests/test_solutions.py +1 -1
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -14
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VisDrone.yaml +4 -4
- ultralytics/data/annotator.py +6 -6
- ultralytics/data/augment.py +53 -51
- ultralytics/data/base.py +15 -13
- ultralytics/data/build.py +7 -4
- ultralytics/data/converter.py +9 -10
- ultralytics/data/dataset.py +24 -22
- ultralytics/data/loaders.py +13 -11
- ultralytics/data/split.py +4 -3
- ultralytics/data/split_dota.py +14 -12
- ultralytics/data/utils.py +29 -23
- ultralytics/engine/exporter.py +2 -2
- ultralytics/engine/model.py +16 -14
- ultralytics/engine/predictor.py +8 -6
- ultralytics/engine/results.py +54 -52
- ultralytics/engine/trainer.py +8 -3
- ultralytics/engine/tuner.py +230 -42
- ultralytics/hub/google/__init__.py +7 -6
- ultralytics/hub/session.py +8 -6
- ultralytics/hub/utils.py +3 -4
- ultralytics/models/fastsam/model.py +8 -6
- ultralytics/models/nas/model.py +5 -3
- ultralytics/models/rtdetr/train.py +4 -3
- ultralytics/models/rtdetr/val.py +6 -4
- ultralytics/models/sam/amg.py +13 -10
- ultralytics/models/sam/model.py +3 -2
- ultralytics/models/sam/modules/blocks.py +21 -21
- ultralytics/models/sam/modules/decoders.py +11 -11
- ultralytics/models/sam/modules/encoders.py +25 -25
- ultralytics/models/sam/modules/memory_attention.py +9 -8
- ultralytics/models/sam/modules/sam.py +8 -10
- ultralytics/models/sam/modules/tiny_encoder.py +21 -20
- ultralytics/models/sam/modules/transformer.py +6 -5
- ultralytics/models/sam/modules/utils.py +7 -5
- ultralytics/models/sam/predict.py +32 -31
- ultralytics/models/utils/loss.py +29 -27
- ultralytics/models/utils/ops.py +10 -8
- ultralytics/models/yolo/classify/train.py +9 -7
- ultralytics/models/yolo/classify/val.py +11 -9
- ultralytics/models/yolo/detect/predict.py +1 -1
- ultralytics/models/yolo/detect/train.py +8 -6
- ultralytics/models/yolo/detect/val.py +22 -20
- ultralytics/models/yolo/model.py +14 -14
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +11 -9
- ultralytics/models/yolo/pose/train.py +7 -5
- ultralytics/models/yolo/pose/val.py +12 -10
- ultralytics/models/yolo/segment/train.py +4 -5
- ultralytics/models/yolo/segment/val.py +13 -11
- ultralytics/models/yolo/world/train.py +10 -8
- ultralytics/models/yolo/yoloe/train.py +10 -10
- ultralytics/models/yolo/yoloe/val.py +11 -9
- ultralytics/nn/autobackend.py +17 -19
- ultralytics/nn/modules/block.py +12 -12
- ultralytics/nn/modules/conv.py +4 -3
- ultralytics/nn/modules/head.py +41 -37
- ultralytics/nn/modules/transformer.py +22 -21
- ultralytics/nn/tasks.py +2 -2
- ultralytics/nn/text_model.py +6 -5
- ultralytics/solutions/analytics.py +7 -5
- ultralytics/solutions/config.py +12 -10
- ultralytics/solutions/distance_calculation.py +3 -3
- ultralytics/solutions/heatmap.py +4 -2
- ultralytics/solutions/object_counter.py +5 -3
- ultralytics/solutions/parking_management.py +4 -2
- ultralytics/solutions/region_counter.py +7 -5
- ultralytics/solutions/similarity_search.py +5 -3
- ultralytics/solutions/solutions.py +38 -36
- ultralytics/solutions/streamlit_inference.py +8 -7
- ultralytics/trackers/bot_sort.py +11 -9
- ultralytics/trackers/byte_tracker.py +17 -15
- ultralytics/trackers/utils/gmc.py +4 -3
- ultralytics/utils/__init__.py +16 -88
- ultralytics/utils/autobatch.py +3 -2
- ultralytics/utils/autodevice.py +10 -10
- ultralytics/utils/benchmarks.py +11 -10
- ultralytics/utils/callbacks/comet.py +9 -9
- ultralytics/utils/checks.py +17 -26
- ultralytics/utils/export.py +12 -11
- ultralytics/utils/files.py +8 -7
- ultralytics/utils/git.py +139 -0
- ultralytics/utils/instance.py +8 -7
- ultralytics/utils/loss.py +15 -13
- ultralytics/utils/metrics.py +62 -62
- ultralytics/utils/ops.py +3 -2
- ultralytics/utils/patches.py +6 -4
- ultralytics/utils/plotting.py +20 -18
- ultralytics/utils/torch_utils.py +4 -2
- ultralytics/utils/tqdm.py +18 -14
- ultralytics/utils/triton.py +3 -2
- {dgenerate_ultralytics_headless-8.3.190.dist-info → dgenerate_ultralytics_headless-8.3.192.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.190.dist-info → dgenerate_ultralytics_headless-8.3.192.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.190.dist-info → dgenerate_ultralytics_headless-8.3.192.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.190.dist-info → dgenerate_ultralytics_headless-8.3.192.dist-info}/top_level.txt +0 -0
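Most of the per-file changes below follow one housekeeping pattern: typing-module annotations (Union, List, and friends) are replaced with PEP 604 unions and built-in generics, made safe on older interpreters by adding from __future__ import annotations at the top of each module. A minimal, self-contained sketch of that pattern; the Box class and every name in it are illustrative only, not Ultralytics code:

from __future__ import annotations  # annotations are stored as strings, so the new syntax parses on older Pythons

import numpy as np


class Box:
    """Toy container using the annotation style adopted across these modules."""

    def __init__(self, xyxy: np.ndarray) -> None:
        self.xyxy = xyxy

    def mul(self, scale: int | tuple | list) -> None:
        """Accept a scalar or a per-coordinate sequence, mirroring the Bboxes.mul signature."""
        scale = (scale,) * 4 if isinstance(scale, (int, float)) else tuple(scale)
        self.xyxy = self.xyxy * np.asarray(scale)

    @classmethod
    def concatenate(cls, boxes: list[Box]) -> Box:
        """Built-in generics and the bare Box forward reference replace typing.List."""
        return cls(np.concatenate([b.xyxy for b in boxes], axis=0))


b = Box(np.array([[0.0, 0.0, 10.0, 10.0]]))
b.mul(2)  # scalar is broadcast to all four coordinates
merged = Box.concatenate([b, Box(np.array([[1.0, 1.0, 2.0, 2.0]]))])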
ultralytics/utils/git.py
ADDED
@@ -0,0 +1,139 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+from __future__ import annotations
+
+from functools import cached_property
+from pathlib import Path
+
+
+class GitRepo:
+    """
+    Represent a local Git repository and expose branch, commit, and remote metadata.
+
+    This class discovers the repository root by searching for a .git entry from the given path upward, resolves the
+    actual .git directory (including worktrees), and reads Git metadata directly from on-disk files. It does not
+    invoke the git binary and therefore works in restricted environments. All metadata properties are resolved
+    lazily and cached; construct a new instance to refresh state.
+
+    Attributes:
+        root (Path | None): Repository root directory containing the .git entry; None if not in a repository.
+        gitdir (Path | None): Resolved .git directory path; handles worktrees; None if unresolved.
+        head (str | None): Raw contents of HEAD; a SHA for detached HEAD or "ref: <refname>" for branch heads.
+        is_repo (bool): Whether the provided path resides inside a Git repository.
+        branch (str | None): Current branch name when HEAD points to a branch; None for detached HEAD or non-repo.
+        commit (str | None): Current commit SHA for HEAD; None if not determinable.
+        origin (str | None): URL of the "origin" remote as read from gitdir/config; None if unset or unavailable.
+
+    Examples:
+        Initialize from the current working directory and read metadata
+        >>> from pathlib import Path
+        >>> repo = GitRepo(Path.cwd())
+        >>> repo.is_repo
+        True
+        >>> repo.branch, repo.commit[:7], repo.origin
+        ('main', '1a2b3c4', 'https://example.com/owner/repo.git')
+
+    Notes:
+        - Resolves metadata by reading files: HEAD, packed-refs, and config; no subprocess calls are used.
+        - Caches properties on first access using cached_property; recreate the object to reflect repository changes.
+    """
+
+    def __init__(self, path: Path = Path(__file__).resolve()):
+        """
+        Initialize a Git repository context by discovering the repository root from a starting path.
+
+        Args:
+            path (Path, optional): File or directory path used as the starting point to locate the repository root.
+        """
+        self.root = self._find_root(path)
+        self.gitdir = self._gitdir(self.root) if self.root else None
+
+    @staticmethod
+    def _find_root(p: Path) -> Path | None:
+        """Return repo root or None."""
+        return next((d for d in [p] + list(p.parents) if (d / ".git").exists()), None)
+
+    @staticmethod
+    def _gitdir(root: Path) -> Path | None:
+        """Resolve actual .git directory (handles worktrees)."""
+        g = root / ".git"
+        if g.is_dir():
+            return g
+        if g.is_file():
+            t = g.read_text(errors="ignore").strip()
+            if t.startswith("gitdir:"):
+                return (root / t.split(":", 1)[1].strip()).resolve()
+        return None
+
+    def _read(self, p: Path | None) -> str | None:
+        """Read and strip file if exists."""
+        return p.read_text(errors="ignore").strip() if p and p.exists() else None
+
+    @cached_property
+    def head(self) -> str | None:
+        """HEAD file contents."""
+        return self._read(self.gitdir / "HEAD" if self.gitdir else None)
+
+    def _ref_commit(self, ref: str) -> str | None:
+        """Commit for ref (handles packed-refs)."""
+        rf = self.gitdir / ref
+        s = self._read(rf)
+        if s:
+            return s
+        pf = self.gitdir / "packed-refs"
+        b = pf.read_bytes().splitlines() if pf.exists() else []
+        tgt = ref.encode()
+        for line in b:
+            if line[:1] in (b"#", b"^") or b" " not in line:
+                continue
+            sha, name = line.split(b" ", 1)
+            if name.strip() == tgt:
+                return sha.decode()
+        return None
+
+    @property
+    def is_repo(self) -> bool:
+        """True if inside a git repo."""
+        return self.gitdir is not None
+
+    @cached_property
+    def branch(self) -> str | None:
+        """Current branch or None."""
+        if not self.is_repo or not self.head or not self.head.startswith("ref: "):
+            return None
+        ref = self.head[5:].strip()
+        return ref[len("refs/heads/") :] if ref.startswith("refs/heads/") else ref
+
+    @cached_property
+    def commit(self) -> str | None:
+        """Current commit SHA or None."""
+        if not self.is_repo or not self.head:
+            return None
+        return self._ref_commit(self.head[5:].strip()) if self.head.startswith("ref: ") else self.head
+
+    @cached_property
+    def origin(self) -> str | None:
+        """Origin URL or None."""
+        if not self.is_repo:
+            return None
+        cfg = self.gitdir / "config"
+        remote, url = None, None
+        for s in (self._read(cfg) or "").splitlines():
+            t = s.strip()
+            if t.startswith("[") and t.endswith("]"):
+                remote = t.lower()
+            elif t.lower().startswith("url =") and remote == '[remote "origin"]':
+                url = t.split("=", 1)[1].strip()
+                break
+        return url
+
+
+if __name__ == "__main__":
+    import time
+
+    g = GitRepo()
+    if g.is_repo:
+        t0 = time.perf_counter()
+        print(f"repo={g.root}\nbranch={g.branch}\ncommit={g.commit}\norigin={g.origin}")
+        dt = (time.perf_counter() - t0) * 1000
+        print(f"\n⏱️ Profiling: total {dt:.3f} ms")
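Because GitRepo reads HEAD, packed-refs, and .git/config directly instead of shelling out to the git binary, its metadata can be queried with plain attribute access. A minimal usage sketch, assuming the code runs from inside a Git checkout (the printed values are illustrative):

from pathlib import Path

from ultralytics.utils.git import GitRepo

repo = GitRepo(Path.cwd())  # walk upward from the current directory to find a .git entry
if repo.is_repo:
    print(repo.root)    # repository root containing the .git entry
    print(repo.branch)  # branch name, or None when HEAD is detached
    print(repo.commit)  # commit SHA resolved from HEAD or packed-refs
    print(repo.origin)  # "origin" URL parsed from .git/config, or None if unset
else:
    print("not inside a git repository")

Note that branch, commit, and origin are cached_property values, so a long-lived instance will not observe later checkouts; construct a new GitRepo to refresh.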
ultralytics/utils/instance.py
CHANGED
@@ -1,9 +1,10 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
+from __future__ import annotations
+
 from collections import abc
 from itertools import repeat
 from numbers import Number
-from typing import List, Union
 
 import numpy as np
 
@@ -101,7 +102,7 @@ class Bboxes:
             else self.bboxes[:, 3] * self.bboxes[:, 2]  # format xywh or ltwh
         )
 
-    def mul(self, scale:
+    def mul(self, scale: int | tuple | list) -> None:
         """
         Multiply bounding box coordinates by scale factor(s).
 
@@ -118,7 +119,7 @@ class Bboxes:
         self.bboxes[:, 2] *= scale[2]
         self.bboxes[:, 3] *= scale[3]
 
-    def add(self, offset:
+    def add(self, offset: int | tuple | list) -> None:
         """
         Add offset to bounding box coordinates.
 
@@ -140,7 +141,7 @@ class Bboxes:
         return len(self.bboxes)
 
     @classmethod
-    def concatenate(cls, boxes_list:
+    def concatenate(cls, boxes_list: list[Bboxes], axis: int = 0) -> Bboxes:
         """
         Concatenate a list of Bboxes objects into a single Bboxes object.
 
@@ -163,7 +164,7 @@ class Bboxes:
             return boxes_list[0]
         return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
 
-    def __getitem__(self, index:
+    def __getitem__(self, index: int | np.ndarray | slice) -> Bboxes:
         """
         Retrieve a specific bounding box or a set of bounding boxes using indexing.
 
@@ -327,7 +328,7 @@ class Instances:
         self.keypoints[..., 0] += padw
         self.keypoints[..., 1] += padh
 
-    def __getitem__(self, index:
+    def __getitem__(self, index: int | np.ndarray | slice) -> Instances:
         """
         Retrieve a specific instance or a set of instances using indexing.
 
@@ -452,7 +453,7 @@ class Instances:
         return len(self.bboxes)
 
     @classmethod
-    def concatenate(cls, instances_list:
+    def concatenate(cls, instances_list: list[Instances], axis=0) -> Instances:
         """
         Concatenate a list of Instances objects into a single Instances object.
 
ultralytics/utils/loss.py
CHANGED
@@ -1,6 +1,8 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-from
+from __future__ import annotations
+
+from typing import Any
 
 import torch
 import torch.nn as nn
 
@@ -122,7 +124,7 @@ class BboxLoss(nn.Module):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
         iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
 
@@ -155,7 +157,7 @@ class RotatedBboxLoss(BboxLoss):
         target_scores: torch.Tensor,
         target_scores_sum: torch.Tensor,
         fg_mask: torch.Tensor,
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute IoU and DFL losses for rotated bounding boxes."""
         weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
         iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
 
@@ -240,7 +242,7 @@ class v8DetectionLoss:
         # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
         return dist2bbox(pred_dist, anchor_points, xywh=False)
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
         feats = preds[1] if isinstance(preds, tuple) else preds
 
@@ -305,7 +307,7 @@ class v8SegmentationLoss(v8DetectionLoss):
         super().__init__(model)
         self.overlap = model.args.overlap_mask
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the combined loss for detection and segmentation."""
         loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
 
@@ -493,7 +495,7 @@ class v8PoseLoss(v8DetectionLoss):
         sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
         self.keypoint_loss = KeypointLoss(sigmas=sigmas)
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the total loss and detach it for pose estimation."""
         loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
         feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
 
@@ -577,7 +579,7 @@ class v8PoseLoss(v8DetectionLoss):
         stride_tensor: torch.Tensor,
         target_bboxes: torch.Tensor,
         pred_kpts: torch.Tensor,
-    ) ->
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Calculate the keypoints loss for the model.
 
@@ -645,7 +647,7 @@ class v8PoseLoss(v8DetectionLoss):
 class v8ClassificationLoss:
     """Criterion class for computing training losses for classification."""
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Compute the classification loss between predictions and true labels."""
         preds = preds[1] if isinstance(preds, (list, tuple)) else preds
         loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
 
@@ -678,7 +680,7 @@ class v8OBBLoss(v8DetectionLoss):
                     out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
         return out
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate and return the loss for oriented bounding box detection."""
         loss = torch.zeros(3, device=self.device)  # box, cls, dfl
         feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
 
@@ -778,7 +780,7 @@ class E2EDetectLoss:
         self.one2many = v8DetectionLoss(model, tal_topk=10)
         self.one2one = v8DetectionLoss(model, tal_topk=1)
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
         preds = preds[1] if isinstance(preds, tuple) else preds
         one2many = preds["one2many"]
 
@@ -799,7 +801,7 @@ class TVPDetectLoss:
         self.ori_no = self.vp_criterion.no
         self.ori_reg_max = self.vp_criterion.reg_max
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt detection."""
         feats = preds[1] if isinstance(preds, tuple) else preds
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it
 
@@ -813,7 +815,7 @@ class TVPDetectLoss:
         box_loss = vp_loss[0][1]
         return box_loss, vp_loss[1]
 
-    def _get_vp_features(self, feats:
+    def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
         """Extract visual-prompt features from the model output."""
         vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
 
@@ -835,7 +837,7 @@ class TVPSegmentLoss(TVPDetectLoss):
         super().__init__(model)
         self.vp_criterion = v8SegmentationLoss(model)
 
-    def __call__(self, preds: Any, batch:
+    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
         """Calculate the loss for text-visual prompt segmentation."""
         feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
         assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it