dgenerate-ultralytics-headless 8.3.189__py3-none-any.whl → 8.3.191__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/METADATA +1 -1
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/RECORD +111 -109
- tests/test_cuda.py +6 -5
- tests/test_exports.py +1 -6
- tests/test_python.py +1 -4
- tests/test_solutions.py +1 -1
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +16 -14
- ultralytics/cfg/datasets/VisDrone.yaml +4 -4
- ultralytics/data/annotator.py +6 -6
- ultralytics/data/augment.py +53 -51
- ultralytics/data/base.py +15 -13
- ultralytics/data/build.py +7 -4
- ultralytics/data/converter.py +9 -10
- ultralytics/data/dataset.py +24 -22
- ultralytics/data/loaders.py +13 -11
- ultralytics/data/split.py +4 -3
- ultralytics/data/split_dota.py +14 -12
- ultralytics/data/utils.py +31 -25
- ultralytics/engine/exporter.py +7 -4
- ultralytics/engine/model.py +16 -14
- ultralytics/engine/predictor.py +9 -7
- ultralytics/engine/results.py +59 -57
- ultralytics/engine/trainer.py +7 -0
- ultralytics/engine/tuner.py +4 -3
- ultralytics/engine/validator.py +3 -1
- ultralytics/hub/__init__.py +6 -2
- ultralytics/hub/auth.py +2 -2
- ultralytics/hub/google/__init__.py +9 -8
- ultralytics/hub/session.py +11 -11
- ultralytics/hub/utils.py +8 -9
- ultralytics/models/fastsam/model.py +8 -6
- ultralytics/models/nas/model.py +5 -3
- ultralytics/models/rtdetr/train.py +4 -3
- ultralytics/models/rtdetr/val.py +6 -4
- ultralytics/models/sam/amg.py +13 -10
- ultralytics/models/sam/model.py +3 -2
- ultralytics/models/sam/modules/blocks.py +21 -21
- ultralytics/models/sam/modules/decoders.py +11 -11
- ultralytics/models/sam/modules/encoders.py +25 -25
- ultralytics/models/sam/modules/memory_attention.py +9 -8
- ultralytics/models/sam/modules/sam.py +8 -10
- ultralytics/models/sam/modules/tiny_encoder.py +21 -20
- ultralytics/models/sam/modules/transformer.py +6 -5
- ultralytics/models/sam/modules/utils.py +7 -5
- ultralytics/models/sam/predict.py +32 -31
- ultralytics/models/utils/loss.py +29 -27
- ultralytics/models/utils/ops.py +10 -8
- ultralytics/models/yolo/classify/train.py +7 -5
- ultralytics/models/yolo/classify/val.py +10 -8
- ultralytics/models/yolo/detect/predict.py +3 -3
- ultralytics/models/yolo/detect/train.py +8 -6
- ultralytics/models/yolo/detect/val.py +23 -21
- ultralytics/models/yolo/model.py +14 -14
- ultralytics/models/yolo/obb/train.py +5 -3
- ultralytics/models/yolo/obb/val.py +13 -10
- ultralytics/models/yolo/pose/train.py +7 -5
- ultralytics/models/yolo/pose/val.py +11 -9
- ultralytics/models/yolo/segment/train.py +4 -5
- ultralytics/models/yolo/segment/val.py +12 -10
- ultralytics/models/yolo/world/train.py +9 -7
- ultralytics/models/yolo/yoloe/train.py +7 -6
- ultralytics/models/yolo/yoloe/val.py +10 -8
- ultralytics/nn/autobackend.py +40 -52
- ultralytics/nn/modules/__init__.py +3 -3
- ultralytics/nn/modules/block.py +12 -12
- ultralytics/nn/modules/conv.py +4 -3
- ultralytics/nn/modules/head.py +46 -38
- ultralytics/nn/modules/transformer.py +22 -21
- ultralytics/nn/tasks.py +2 -2
- ultralytics/nn/text_model.py +6 -5
- ultralytics/solutions/analytics.py +7 -5
- ultralytics/solutions/config.py +12 -10
- ultralytics/solutions/distance_calculation.py +3 -3
- ultralytics/solutions/heatmap.py +4 -2
- ultralytics/solutions/object_counter.py +5 -3
- ultralytics/solutions/parking_management.py +4 -2
- ultralytics/solutions/region_counter.py +7 -5
- ultralytics/solutions/similarity_search.py +5 -3
- ultralytics/solutions/solutions.py +38 -36
- ultralytics/solutions/streamlit_inference.py +8 -7
- ultralytics/trackers/bot_sort.py +11 -9
- ultralytics/trackers/byte_tracker.py +17 -15
- ultralytics/trackers/utils/gmc.py +4 -3
- ultralytics/utils/__init__.py +27 -77
- ultralytics/utils/autobatch.py +3 -2
- ultralytics/utils/autodevice.py +10 -10
- ultralytics/utils/benchmarks.py +11 -10
- ultralytics/utils/callbacks/comet.py +9 -9
- ultralytics/utils/callbacks/platform.py +2 -1
- ultralytics/utils/checks.py +20 -29
- ultralytics/utils/downloads.py +2 -2
- ultralytics/utils/export.py +12 -11
- ultralytics/utils/files.py +8 -7
- ultralytics/utils/git.py +139 -0
- ultralytics/utils/instance.py +8 -7
- ultralytics/utils/logger.py +7 -6
- ultralytics/utils/loss.py +15 -13
- ultralytics/utils/metrics.py +62 -62
- ultralytics/utils/nms.py +346 -0
- ultralytics/utils/ops.py +83 -251
- ultralytics/utils/patches.py +6 -4
- ultralytics/utils/plotting.py +18 -16
- ultralytics/utils/tal.py +1 -1
- ultralytics/utils/torch_utils.py +4 -2
- ultralytics/utils/tqdm.py +47 -33
- ultralytics/utils/triton.py +3 -2
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.189.dist-info → dgenerate_ultralytics_headless-8.3.191.dist-info}/top_level.txt +0 -0
ultralytics/nn/autobackend.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import ast
|
4
6
|
import json
|
5
7
|
import platform
|
6
8
|
import zipfile
|
7
9
|
from collections import OrderedDict, namedtuple
|
8
10
|
from pathlib import Path
|
9
|
-
from typing import Any
|
11
|
+
from typing import Any
|
10
12
|
|
11
13
|
import cv2
|
12
14
|
import numpy as np
|
@@ -19,7 +21,7 @@ from ultralytics.utils.checks import check_requirements, check_suffix, check_ver
|
|
19
21
|
from ultralytics.utils.downloads import attempt_download_asset, is_url
|
20
22
|
|
21
23
|
|
22
|
-
def check_class_names(names: Union[List, Dict]) -> Dict[int, str]:
|
24
|
+
def check_class_names(names: list | dict) -> dict[int, str]:
|
23
25
|
"""
|
24
26
|
Check class names and convert to dict format if needed.
|
25
27
|
|
@@ -49,7 +51,7 @@ def check_class_names(names: Union[List, Dict]) -> Dict[int, str]:
|
|
49
51
|
return names
|
50
52
|
|
51
53
|
|
52
|
-
def default_class_names(data:
|
54
|
+
def default_class_names(data: str | Path | None = None) -> dict[int, str]:
|
53
55
|
"""
|
54
56
|
Apply default class names to an input YAML file or return numerical class names.
|
55
57
|
|
@@ -127,17 +129,17 @@ class AutoBackend(nn.Module):
|
|
127
129
|
_model_type: Determine the model type from file path.
|
128
130
|
|
129
131
|
Examples:
|
130
|
-
>>> model = AutoBackend(
|
132
|
+
>>> model = AutoBackend(model="yolo11n.pt", device="cuda")
|
131
133
|
>>> results = model(img)
|
132
134
|
"""
|
133
135
|
|
134
136
|
@torch.no_grad()
|
135
137
|
def __init__(
|
136
138
|
self,
|
137
|
-
|
139
|
+
model: str | torch.nn.Module = "yolo11n.pt",
|
138
140
|
device: torch.device = torch.device("cpu"),
|
139
141
|
dnn: bool = False,
|
140
|
-
data:
|
142
|
+
data: str | Path | None = None,
|
141
143
|
fp16: bool = False,
|
142
144
|
fuse: bool = True,
|
143
145
|
verbose: bool = True,
|
@@ -146,7 +148,7 @@ class AutoBackend(nn.Module):
|
|
146
148
|
Initialize the AutoBackend for inference.
|
147
149
|
|
148
150
|
Args:
|
149
|
-
|
151
|
+
model (str | torch.nn.Module): Path to the model weights file or a module instance.
|
150
152
|
device (torch.device): Device to run the model on.
|
151
153
|
dnn (bool): Use OpenCV DNN module for ONNX inference.
|
152
154
|
data (str | Path, optional): Path to the additional data.yaml file containing class names.
|
@@ -155,8 +157,7 @@ class AutoBackend(nn.Module):
|
|
155
157
|
verbose (bool): Enable verbose logging.
|
156
158
|
"""
|
157
159
|
super().__init__()
|
158
|
-
|
159
|
-
nn_module = isinstance(weights, torch.nn.Module)
|
160
|
+
nn_module = isinstance(model, torch.nn.Module)
|
160
161
|
(
|
161
162
|
pt,
|
162
163
|
jit,
|
@@ -175,12 +176,12 @@ class AutoBackend(nn.Module):
|
|
175
176
|
imx,
|
176
177
|
rknn,
|
177
178
|
triton,
|
178
|
-
) = self._model_type(
|
179
|
+
) = self._model_type("" if nn_module else model)
|
179
180
|
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
|
180
181
|
nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn # BHWC formats (vs torch BCWH)
|
181
182
|
stride, ch = 32, 3 # default stride and channels
|
182
183
|
end2end, dynamic = False, False
|
183
|
-
|
184
|
+
metadata, task = None, None
|
184
185
|
|
185
186
|
# Set device
|
186
187
|
cuda = isinstance(device, torch.device) and torch.cuda.is_available() and device.type != "cpu" # use CUDA
|
@@ -189,39 +190,32 @@ class AutoBackend(nn.Module):
|
|
189
190
|
cuda = False
|
190
191
|
|
191
192
|
# Download if not local
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
if
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
pt = True
|
211
|
-
|
212
|
-
# PyTorch
|
213
|
-
elif pt:
|
214
|
-
from ultralytics.nn.tasks import attempt_load_weights
|
215
|
-
|
216
|
-
model = attempt_load_weights(
|
217
|
-
weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse
|
218
|
-
)
|
193
|
+
w = attempt_download_asset(model) if pt else model # weights path
|
194
|
+
|
195
|
+
# PyTorch (in-memory or file)
|
196
|
+
if nn_module or pt:
|
197
|
+
if nn_module:
|
198
|
+
pt = True
|
199
|
+
if fuse:
|
200
|
+
if IS_JETSON and is_jetson(jetpack=5):
|
201
|
+
# Jetson Jetpack5 requires device before fuse https://github.com/ultralytics/ultralytics/pull/21028
|
202
|
+
model = model.to(device)
|
203
|
+
model = model.fuse(verbose=verbose)
|
204
|
+
model = model.to(device)
|
205
|
+
else: # pt file
|
206
|
+
from ultralytics.nn.tasks import attempt_load_one_weight
|
207
|
+
|
208
|
+
model, _ = attempt_load_one_weight(model, device=device, fuse=fuse) # load model, ckpt
|
209
|
+
|
210
|
+
# Common PyTorch model processing
|
219
211
|
if hasattr(model, "kpt_shape"):
|
220
212
|
kpt_shape = model.kpt_shape # pose-only
|
221
213
|
stride = max(int(model.stride.max()), 32) # model stride
|
222
214
|
names = model.module.names if hasattr(model, "module") else model.names # get class names
|
223
215
|
model.half() if fp16 else model.float()
|
224
216
|
ch = model.yaml.get("channels", 3)
|
217
|
+
for p in model.parameters():
|
218
|
+
p.requires_grad = False
|
225
219
|
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
226
220
|
|
227
221
|
# TorchScript
|
@@ -407,6 +401,7 @@ class AutoBackend(nn.Module):
|
|
407
401
|
|
408
402
|
# CoreML
|
409
403
|
elif coreml:
|
404
|
+
check_requirements("coremltools>=8.0")
|
410
405
|
LOGGER.info(f"Loading {w} for CoreML inference...")
|
411
406
|
import coremltools as ct
|
412
407
|
|
@@ -483,7 +478,7 @@ class AutoBackend(nn.Module):
|
|
483
478
|
|
484
479
|
# TF.js
|
485
480
|
elif tfjs:
|
486
|
-
raise NotImplementedError("
|
481
|
+
raise NotImplementedError("Ultralytics TF.js inference is not currently supported.")
|
487
482
|
|
488
483
|
# PaddlePaddle
|
489
484
|
elif paddle:
|
@@ -601,18 +596,13 @@ class AutoBackend(nn.Module):
|
|
601
596
|
dynamic = metadata.get("args", {}).get("dynamic", dynamic)
|
602
597
|
ch = metadata.get("channels", 3)
|
603
598
|
elif not (pt or triton or nn_module):
|
604
|
-
LOGGER.warning(f"Metadata not found for 'model={
|
599
|
+
LOGGER.warning(f"Metadata not found for 'model={w}'")
|
605
600
|
|
606
601
|
# Check names
|
607
602
|
if "names" not in locals(): # names missing
|
608
603
|
names = default_class_names(data)
|
609
604
|
names = check_class_names(names)
|
610
605
|
|
611
|
-
# Disable gradients
|
612
|
-
if pt:
|
613
|
-
for p in model.parameters():
|
614
|
-
p.requires_grad = False
|
615
|
-
|
616
606
|
self.__dict__.update(locals()) # assign all variables to self
|
617
607
|
|
618
608
|
def forward(
|
@@ -620,9 +610,9 @@ class AutoBackend(nn.Module):
|
|
620
610
|
im: torch.Tensor,
|
621
611
|
augment: bool = False,
|
622
612
|
visualize: bool = False,
|
623
|
-
embed:
|
613
|
+
embed: list | None = None,
|
624
614
|
**kwargs: Any,
|
625
|
-
) ->
|
615
|
+
) -> torch.Tensor | list[torch.Tensor]:
|
626
616
|
"""
|
627
617
|
Run inference on an AutoBackend model.
|
628
618
|
|
@@ -851,15 +841,13 @@ class AutoBackend(nn.Module):
|
|
851
841
|
"""
|
852
842
|
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
|
853
843
|
|
854
|
-
def warmup(self, imgsz:
|
844
|
+
def warmup(self, imgsz: tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
|
855
845
|
"""
|
856
846
|
Warm up the model by running one forward pass with a dummy input.
|
857
847
|
|
858
848
|
Args:
|
859
849
|
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
|
860
850
|
"""
|
861
|
-
import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)
|
862
|
-
|
863
851
|
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
|
864
852
|
if any(warmup_types) and (self.device.type != "cpu" or self.triton):
|
865
853
|
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
|
@@ -867,7 +855,7 @@ class AutoBackend(nn.Module):
|
|
867
855
|
self.forward(im) # warmup
|
868
856
|
|
869
857
|
@staticmethod
|
870
|
-
def _model_type(p: str = "path/to/model.pt") ->
|
858
|
+
def _model_type(p: str = "path/to/model.pt") -> list[bool]:
|
871
859
|
"""
|
872
860
|
Take a path to a model file and return the model type.
|
873
861
|
|
@@ -878,7 +866,7 @@ class AutoBackend(nn.Module):
|
|
878
866
|
(List[bool]): List of booleans indicating the model type.
|
879
867
|
|
880
868
|
Examples:
|
881
|
-
>>> model = AutoBackend(
|
869
|
+
>>> model = AutoBackend(model="path/to/model.onnx")
|
882
870
|
>>> model_type = model._model_type() # returns "onnx"
|
883
871
|
"""
|
884
872
|
from ultralytics.engine.exporter import export_formats
|
@@ -7,14 +7,14 @@ blocks, attention mechanisms, transformer components, and detection/segmentation
|
|
7
7
|
|
8
8
|
Examples:
|
9
9
|
Visualize a module with Netron
|
10
|
-
>>> from ultralytics.nn.modules import
|
10
|
+
>>> from ultralytics.nn.modules import Conv
|
11
11
|
>>> import torch
|
12
|
-
>>> import
|
12
|
+
>>> import subprocess
|
13
13
|
>>> x = torch.ones(1, 128, 40, 40)
|
14
14
|
>>> m = Conv(128, 128)
|
15
15
|
>>> f = f"{m._get_name()}.onnx"
|
16
16
|
>>> torch.onnx.export(m, x, f)
|
17
|
-
>>>
|
17
|
+
>>> subprocess.run(f"onnxslim {f} {f} && open {f}", shell=True, check=True) # pip install onnxslim
|
18
18
|
"""
|
19
19
|
|
20
20
|
from .block import (
|
ultralytics/nn/modules/block.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
"""Block modules."""
|
3
3
|
|
4
|
-
from
|
4
|
+
from __future__ import annotations
|
5
5
|
|
6
6
|
import torch
|
7
7
|
import torch.nn as nn
|
@@ -192,7 +192,7 @@ class HGBlock(nn.Module):
|
|
192
192
|
class SPP(nn.Module):
|
193
193
|
"""Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
|
194
194
|
|
195
|
-
def __init__(self, c1: int, c2: int, k:
|
195
|
+
def __init__(self, c1: int, c2: int, k: tuple[int, ...] = (5, 9, 13)):
|
196
196
|
"""
|
197
197
|
Initialize the SPP layer with input/output channels and pooling kernel sizes.
|
198
198
|
|
@@ -471,7 +471,7 @@ class Bottleneck(nn.Module):
|
|
471
471
|
"""Standard bottleneck."""
|
472
472
|
|
473
473
|
def __init__(
|
474
|
-
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k:
|
474
|
+
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: tuple[int, int] = (3, 3), e: float = 0.5
|
475
475
|
):
|
476
476
|
"""
|
477
477
|
Initialize a standard bottleneck module.
|
@@ -711,7 +711,7 @@ class ImagePoolingAttn(nn.Module):
|
|
711
711
|
"""ImagePoolingAttn: Enhance the text embeddings with image-aware information."""
|
712
712
|
|
713
713
|
def __init__(
|
714
|
-
self, ec: int = 256, ch:
|
714
|
+
self, ec: int = 256, ch: tuple[int, ...] = (), ct: int = 512, nh: int = 8, k: int = 3, scale: bool = False
|
715
715
|
):
|
716
716
|
"""
|
717
717
|
Initialize ImagePoolingAttn module.
|
@@ -740,7 +740,7 @@ class ImagePoolingAttn(nn.Module):
|
|
740
740
|
self.hc = ec // nh
|
741
741
|
self.k = k
|
742
742
|
|
743
|
-
def forward(self, x:
|
743
|
+
def forward(self, x: list[torch.Tensor], text: torch.Tensor) -> torch.Tensor:
|
744
744
|
"""
|
745
745
|
Forward pass of ImagePoolingAttn.
|
746
746
|
|
@@ -856,7 +856,7 @@ class RepBottleneck(Bottleneck):
|
|
856
856
|
"""Rep bottleneck."""
|
857
857
|
|
858
858
|
def __init__(
|
859
|
-
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k:
|
859
|
+
self, c1: int, c2: int, shortcut: bool = True, g: int = 1, k: tuple[int, int] = (3, 3), e: float = 0.5
|
860
860
|
):
|
861
861
|
"""
|
862
862
|
Initialize RepBottleneck.
|
@@ -1026,7 +1026,7 @@ class SPPELAN(nn.Module):
|
|
1026
1026
|
class CBLinear(nn.Module):
|
1027
1027
|
"""CBLinear."""
|
1028
1028
|
|
1029
|
-
def __init__(self, c1: int, c2s:
|
1029
|
+
def __init__(self, c1: int, c2s: list[int], k: int = 1, s: int = 1, p: int | None = None, g: int = 1):
|
1030
1030
|
"""
|
1031
1031
|
Initialize CBLinear module.
|
1032
1032
|
|
@@ -1042,7 +1042,7 @@ class CBLinear(nn.Module):
|
|
1042
1042
|
self.c2s = c2s
|
1043
1043
|
self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
|
1044
1044
|
|
1045
|
-
def forward(self, x: torch.Tensor) ->
|
1045
|
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
1046
1046
|
"""Forward pass through CBLinear layer."""
|
1047
1047
|
return self.conv(x).split(self.c2s, dim=1)
|
1048
1048
|
|
@@ -1050,7 +1050,7 @@ class CBLinear(nn.Module):
|
|
1050
1050
|
class CBFuse(nn.Module):
|
1051
1051
|
"""CBFuse."""
|
1052
1052
|
|
1053
|
-
def __init__(self, idx:
|
1053
|
+
def __init__(self, idx: list[int]):
|
1054
1054
|
"""
|
1055
1055
|
Initialize CBFuse module.
|
1056
1056
|
|
@@ -1060,7 +1060,7 @@ class CBFuse(nn.Module):
|
|
1060
1060
|
super().__init__()
|
1061
1061
|
self.idx = idx
|
1062
1062
|
|
1063
|
-
def forward(self, xs:
|
1063
|
+
def forward(self, xs: list[torch.Tensor]) -> torch.Tensor:
|
1064
1064
|
"""
|
1065
1065
|
Forward pass through CBFuse layer.
|
1066
1066
|
|
@@ -1974,7 +1974,7 @@ class Residual(nn.Module):
|
|
1974
1974
|
class SAVPE(nn.Module):
|
1975
1975
|
"""Spatial-Aware Visual Prompt Embedding module for feature enhancement."""
|
1976
1976
|
|
1977
|
-
def __init__(self, ch:
|
1977
|
+
def __init__(self, ch: list[int], c3: int, embed: int):
|
1978
1978
|
"""
|
1979
1979
|
Initialize SAVPE module with channels, intermediate channels, and embedding dimension.
|
1980
1980
|
|
@@ -2002,7 +2002,7 @@ class SAVPE(nn.Module):
|
|
2002
2002
|
self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
|
2003
2003
|
self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))
|
2004
2004
|
|
2005
|
-
def forward(self, x:
|
2005
|
+
def forward(self, x: list[torch.Tensor], vp: torch.Tensor) -> torch.Tensor:
|
2006
2006
|
"""Process input features and visual prompts to generate enhanced embeddings."""
|
2007
2007
|
y = [self.cv2[i](xi) for i, xi in enumerate(x)]
|
2008
2008
|
y = self.cv4(torch.cat(y, dim=1))
|
ultralytics/nn/modules/conv.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
2
|
"""Convolution modules."""
|
3
3
|
|
4
|
+
from __future__ import annotations
|
5
|
+
|
4
6
|
import math
|
5
|
-
from typing import List
|
6
7
|
|
7
8
|
import numpy as np
|
8
9
|
import torch
|
@@ -669,7 +670,7 @@ class Concat(nn.Module):
|
|
669
670
|
super().__init__()
|
670
671
|
self.d = dimension
|
671
672
|
|
672
|
-
def forward(self, x:
|
673
|
+
def forward(self, x: list[torch.Tensor]):
|
673
674
|
"""
|
674
675
|
Concatenate input tensors along specified dimension.
|
675
676
|
|
@@ -700,7 +701,7 @@ class Index(nn.Module):
|
|
700
701
|
super().__init__()
|
701
702
|
self.index = index
|
702
703
|
|
703
|
-
def forward(self, x:
|
704
|
+
def forward(self, x: list[torch.Tensor]):
|
704
705
|
"""
|
705
706
|
Select and return a particular index from input.
|
706
707
|
|