dgenerate-ultralytics-headless 8.3.196__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +33 -34
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +8 -10
- tests/test_cuda.py +9 -10
- tests/test_engine.py +29 -2
- tests/test_exports.py +69 -21
- tests/test_integrations.py +8 -11
- tests/test_python.py +109 -71
- tests/test_solutions.py +170 -159
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +57 -64
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/Objects365.yaml +19 -15
- ultralytics/cfg/datasets/SKU-110K.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +19 -21
- ultralytics/cfg/datasets/VisDrone.yaml +5 -5
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +24 -2
- ultralytics/cfg/datasets/coco.yaml +2 -2
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +7 -7
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +96 -94
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +286 -476
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +151 -26
- ultralytics/data/converter.py +38 -50
- ultralytics/data/dataset.py +47 -75
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +41 -45
- ultralytics/engine/exporter.py +462 -462
- ultralytics/engine/model.py +150 -191
- ultralytics/engine/predictor.py +30 -40
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +193 -120
- ultralytics/engine/tuner.py +77 -63
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +19 -30
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +7 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +22 -40
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +206 -79
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2268 -366
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +15 -41
- ultralytics/models/yolo/classify/val.py +34 -32
- ultralytics/models/yolo/detect/predict.py +8 -11
- ultralytics/models/yolo/detect/train.py +13 -32
- ultralytics/models/yolo/detect/val.py +75 -63
- ultralytics/models/yolo/model.py +37 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +42 -39
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +7 -22
- ultralytics/models/yolo/pose/train.py +10 -22
- ultralytics/models/yolo/pose/val.py +40 -59
- ultralytics/models/yolo/segment/predict.py +16 -20
- ultralytics/models/yolo/segment/train.py +3 -12
- ultralytics/models/yolo/segment/val.py +106 -56
- ultralytics/models/yolo/world/train.py +12 -16
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +31 -56
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +16 -21
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +152 -80
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +133 -217
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +64 -116
- ultralytics/nn/modules/transformer.py +79 -89
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +111 -156
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +13 -17
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +4 -7
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +70 -70
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +151 -87
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +19 -15
- ultralytics/utils/downloads.py +29 -41
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +16 -16
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +15 -24
- ultralytics/utils/metrics.py +131 -160
- ultralytics/utils/nms.py +21 -30
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +122 -119
- ultralytics/utils/tal.py +28 -44
- ultralytics/utils/torch_utils.py +70 -187
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.196.dist-info/RECORD +0 -281
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.196.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/nn/autobackend.py
CHANGED
@@ -19,11 +19,11 @@ from PIL import Image
 from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML, is_jetson
 from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml, is_rockchip
 from ultralytics.utils.downloads import attempt_download_asset, is_url
+from ultralytics.utils.nms import non_max_suppression


 def check_class_names(names: list | dict) -> dict[int, str]:
-    """
-    Check class names and convert to dict format if needed.
+    """Check class names and convert to dict format if needed.

     Args:
         names (list | dict): Class names as list or dict format.
@@ -52,8 +52,7 @@ def check_class_names(names: list | dict) -> dict[int, str]:


 def default_class_names(data: str | Path | None = None) -> dict[int, str]:
-    """
-    Apply default class names to an input YAML file or return numerical class names.
+    """Apply default class names to an input YAML file or return numerical class names.

     Args:
         data (str | Path, optional): Path to YAML file containing class names.
@@ -70,8 +69,7 @@ def default_class_names(data: str | Path | None = None) -> dict[int, str]:


 class AutoBackend(nn.Module):
-    """
-    Handle dynamic backend selection for running inference using Ultralytics YOLO models.
+    """Handle dynamic backend selection for running inference using Ultralytics YOLO models.

     The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
     range of formats, each with specific naming conventions as outlined below:
@@ -95,6 +93,9 @@ class AutoBackend(nn.Module):
         | NCNN | *_ncnn_model/ |
         | IMX | *_imx_model/ |
         | RKNN | *_rknn_model/ |
+        | Triton Inference | triton://model |
+        | ExecuTorch | *.pte |
+        | Axelera | *_axelera_model/ |

     Attributes:
         model (torch.nn.Module): The loaded YOLO model.
@@ -121,10 +122,12 @@ class AutoBackend(nn.Module):
         imx (bool): Whether the model is an IMX model.
         rknn (bool): Whether the model is an RKNN model.
         triton (bool): Whether the model is a Triton Inference Server model.
+        pte (bool): Whether the model is a PyTorch ExecuTorch model.
+        axelera (bool): Whether the model is an Axelera model.

     Methods:
         forward: Run inference on an input image.
-        from_numpy: Convert
+        from_numpy: Convert NumPy arrays to tensors on the model device.
         warmup: Warm up the model with a dummy input.
         _model_type: Determine the model type from file path.

@@ -144,8 +147,7 @@ class AutoBackend(nn.Module):
         fuse: bool = True,
         verbose: bool = True,
     ):
-        """
-        Initialize the AutoBackend for inference.
+        """Initialize the AutoBackend for inference.

         Args:
             model (str | torch.nn.Module): Path to the model weights file or a module instance.
@@ -175,10 +177,12 @@ class AutoBackend(nn.Module):
             ncnn,
             imx,
             rknn,
+            pte,
+            axelera,
             triton,
         ) = self._model_type("" if nn_module else model)
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
-        nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch
+        nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch BCHW)
         stride, ch = 32, 3  # default stride and channels
         end2end, dynamic = False, False
         metadata, task = None, None
@@ -241,25 +245,28 @@ class AutoBackend(nn.Module):
             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
             import onnxruntime

-            … (9 lines not captured in this view)
+            # Select execution provider: CUDA > CoreML (mps) > CPU
+            available = onnxruntime.get_available_providers()
+            if cuda and "CUDAExecutionProvider" in available:
+                providers = [("CUDAExecutionProvider", {"device_id": device.index}), "CPUExecutionProvider"]
+            elif device.type == "mps" and "CoreMLExecutionProvider" in available:
+                providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
+            else:
+                providers = ["CPUExecutionProvider"]
+                if cuda:
+                    LOGGER.warning("CUDA requested but CUDAExecutionProvider not available. Using CPU...")
+                    device, cuda = torch.device("cpu"), False
+            LOGGER.info(
+                f"Using ONNX Runtime {onnxruntime.__version__} with {providers[0] if isinstance(providers[0], str) else providers[0][0]}"
+            )
             if onnx:
                 session = onnxruntime.InferenceSession(w, providers=providers)
             else:
-                check_requirements(
-                    ["model-compression-toolkit>=2.4.1", "sony-custom-layers[torch]>=0.3.0", "onnxruntime-extensions"]
-                )
+                check_requirements(("model-compression-toolkit>=2.4.1", "edge-mdt-cl<1.1.0", "onnxruntime-extensions"))
                 w = next(Path(w).glob("*.onnx"))
                 LOGGER.info(f"Loading {w} for ONNX IMX inference...")
                 import mct_quantizers as mctq
-                from
+                from edgemdt_cl.pytorch.nms import nms_ort  # noqa - register custom NMS ops

                 session_options = mctq.get_ort_session_options()
                 session_options.enable_mem_reuse = False  # fix the shape mismatch from onnxruntime
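The rewritten block above replaces a fixed provider list with a query of what the installed ONNX Runtime actually offers, preferring CUDA, then CoreML on Apple devices, then CPU. A minimal standalone sketch of the same selection idea follows; the model path and device id are placeholders:

import onnxruntime as ort


def make_session(path: str, prefer_cuda: bool = True) -> ort.InferenceSession:
    """Build an InferenceSession, preferring CUDA when it is actually available (sketch)."""
    available = ort.get_available_providers()
    if prefer_cuda and "CUDAExecutionProvider" in available:
        # Provider-specific options ride along as a (name, options_dict) tuple
        providers = [("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    return ort.InferenceSession(path, providers=providers)


session = make_session("yolo11n.onnx")
print(session.get_providers())  # providers the session actually uses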
@@ -269,7 +276,10 @@ class AutoBackend(nn.Module):
             metadata = session.get_modelmeta().custom_metadata_map
             dynamic = isinstance(session.get_outputs()[0].shape[0], str)
             fp16 = "float16" in session.get_inputs()[0].type
-            … (1 line not captured in this view)
+
+            # Setup IO binding for optimized inference (CUDA only, not supported for CoreML)
+            use_io_binding = not dynamic and cuda
+            if use_io_binding:
                 io = session.io_binding()
                 bindings = []
                 for output in session.get_outputs():
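IO binding lets ONNX Runtime read inputs and write outputs through pre-allocated buffers instead of round-tripping every call through NumPy, which is why it is now enabled only for the static-shape CUDA case. A rough usage sketch, with an illustrative model path and input shape:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("yolo11n.onnx", providers=["CPUExecutionProvider"])
io = session.io_binding()

x = np.zeros((1, 3, 640, 640), dtype=np.float32)
io.bind_cpu_input(session.get_inputs()[0].name, x)  # bind the input from host memory
for out in session.get_outputs():
    io.bind_output(out.name)  # let ORT allocate the output buffers

session.run_with_iobinding(io)
y = io.copy_outputs_to_cpu()  # list of NumPy arrays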
@@ -332,11 +342,11 @@ class AutoBackend(nn.Module):
             check_requirements("numpy==1.23.5")

             try:  # https://developer.nvidia.com/nvidia-tensorrt-download
-                import tensorrt as trt
+                import tensorrt as trt
             except ImportError:
                 if LINUX:
                     check_requirements("tensorrt>7.0.0,!=10.1.0")
-                import tensorrt as trt
+                    import tensorrt as trt
             check_version(trt.__version__, ">=7.0.0", hard=True)
             check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
             if device.type == "cpu":
@@ -369,43 +379,47 @@ class AutoBackend(nn.Module):
             is_trt10 = not hasattr(model, "num_bindings")
             num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
             for i in num:
+                # Get tensor info using TRT10+ or legacy API
                 if is_trt10:
                     name = model.get_tensor_name(i)
                     dtype = trt.nptype(model.get_tensor_dtype(name))
                     is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
-                    … (3 lines not captured in this view)
-                    context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
-                    if dtype == np.float16:
-                        fp16 = True
-                    else:
-                        output_names.append(name)
-                    shape = tuple(context.get_tensor_shape(name))
-                else:  # TensorRT < 10.0
+                    shape = tuple(model.get_tensor_shape(name))
+                    profile_shape = tuple(model.get_tensor_profile_shape(name, 0)[2]) if is_input else None
+                else:
                     name = model.get_binding_name(i)
                     dtype = trt.nptype(model.get_binding_dtype(i))
                     is_input = model.binding_is_input(i)
-                    … (9 lines not captured in this view)
+                    shape = tuple(model.get_binding_shape(i))
+                    profile_shape = tuple(model.get_profile_shape(0, i)[1]) if is_input else None
+
+                # Process input/output tensors
+                if is_input:
+                    if -1 in shape:
+                        dynamic = True
+                        if is_trt10:
+                            context.set_input_shape(name, profile_shape)
+                        else:
+                            context.set_binding_shape(i, profile_shape)
+                    if dtype == np.float16:
+                        fp16 = True
+                else:
+                    output_names.append(name)
+                    shape = tuple(context.get_tensor_shape(name)) if is_trt10 else tuple(context.get_binding_shape(i))
                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())

         # CoreML
         elif coreml:
-            check_requirements(
+            check_requirements(
+                ["coremltools>=9.0", "numpy>=1.14.5,<=2.3.5"]
+            )  # latest numpy 2.4.0rc1 breaks coremltools exports
             LOGGER.info(f"Loading {w} for CoreML inference...")
             import coremltools as ct

             model = ct.models.MLModel(w)
+            dynamic = model.get_spec().description.input[0].type.HasField("multiArrayType")
             metadata = dict(model.user_defined_metadata)

         # TF SavedModel
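The TensorRT hunk above folds the TRT 10 tensor API and the pre-10 binding API into a single shape/profile lookup. For orientation, a minimal sketch of the TRT 10 side of that enumeration; `engine` is assumed to be an already-deserialized engine:

import tensorrt as trt


def describe_io(engine: trt.ICudaEngine) -> None:
    """Print name, dtype, shape, and direction of each IO tensor (TensorRT >= 10 API, sketch)."""
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        dtype = trt.nptype(engine.get_tensor_dtype(name))
        is_input = engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
        shape = tuple(engine.get_tensor_shape(name))  # -1 marks a dynamic axis
        print(f"{name}: {'input' if is_input else 'output'}, {dtype}, {shape}")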
@@ -413,8 +427,7 @@ class AutoBackend(nn.Module):
             LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
             import tensorflow as tf

-            … (1 line not captured in this view)
-            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+            model = tf.saved_model.load(w)
             metadata = Path(w) / "metadata.yaml"

         # TF GraphDef
@@ -422,7 +435,7 @@ class AutoBackend(nn.Module):
             LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
             import tensorflow as tf

-            from ultralytics.
+            from ultralytics.utils.export.tensorflow import gd_outputs

             def wrap_frozen_graph(gd, inputs, outputs):
                 """Wrap frozen graphs for deployment."""
@@ -490,7 +503,7 @@ class AutoBackend(nn.Module):
                 if ARM64
                 else "paddlepaddle>=3.0.0"
             )
-            import paddle.inference as pdi
+            import paddle.inference as pdi

             w = Path(w)
             model_file, params_file = None, None
@@ -568,6 +581,51 @@ class AutoBackend(nn.Module):
             rknn_model.init_runtime()
             metadata = w.parent / "metadata.yaml"

+        # Axelera
+        elif axelera:
+            import os
+
+            if not os.environ.get("AXELERA_RUNTIME_DIR"):
+                LOGGER.warning(
+                    "Axelera runtime environment is not activated."
+                    "\nPlease run: source /opt/axelera/sdk/latest/axelera_activate.sh"
+                    "\n\nIf this fails, verify driver installation: https://docs.ultralytics.com/integrations/axelera/#axelera-driver-installation"
+                )
+            try:
+                from axelera.runtime import op
+            except ImportError:
+                check_requirements(
+                    "axelera_runtime2==0.1.2",
+                    cmds="--extra-index-url https://software.axelera.ai/artifactory/axelera-runtime-pypi",
+                )
+                from axelera.runtime import op
+
+            w = Path(w)
+            if (found := next(w.rglob("*.axm"), None)) is None:
+                raise FileNotFoundError(f"No .axm file found in: {w}")
+
+            ax_model = op.load(str(found))
+            metadata = found.parent / "metadata.yaml"
+
+        # ExecuTorch
+        elif pte:
+            LOGGER.info(f"Loading {w} for ExecuTorch inference...")
+            # TorchAO release compatibility table bug https://github.com/pytorch/ao/issues/2919
+            check_requirements("setuptools<71.0.0")  # Setuptools bug: https://github.com/pypa/setuptools/issues/4483
+            check_requirements(("executorch==1.0.1", "flatbuffers"))
+            from executorch.runtime import Runtime
+
+            w = Path(w)
+            if w.is_dir():
+                model_file = next(w.rglob("*.pte"))
+                metadata = w / "metadata.yaml"
+            else:
+                model_file = w
+                metadata = w.parent / "metadata.yaml"
+
+            program = Runtime.get().load_program(str(model_file))
+            model = program.load_method("forward")
+
         # Any other format (unsupported)
         else:
             from ultralytics.engine.exporter import export_formats
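The ExecuTorch branch reduces to three runtime calls: load the .pte program, materialize its "forward" method, and execute it with a list of tensors. The same flow as a standalone sketch (file name and input shape are placeholders):

import torch
from executorch.runtime import Runtime

runtime = Runtime.get()  # process-wide runtime singleton
program = runtime.load_program("yolo11n.pte")  # parse the .pte program
method = program.load_method("forward")  # plan and allocate the entry point
outputs = method.execute([torch.zeros(1, 3, 640, 640)])  # list in, list of tensors out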
@@ -584,14 +642,15 @@ class AutoBackend(nn.Module):
             for k, v in metadata.items():
                 if k in {"stride", "batch", "channels"}:
                     metadata[k] = int(v)
-                elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
-                    metadata[k] =
+                elif k in {"imgsz", "names", "kpt_shape", "kpt_names", "args"} and isinstance(v, str):
+                    metadata[k] = ast.literal_eval(v)
             stride = metadata["stride"]
             task = metadata["task"]
             batch = metadata["batch"]
             imgsz = metadata["imgsz"]
             names = metadata["names"]
             kpt_shape = metadata.get("kpt_shape")
+            kpt_names = metadata.get("kpt_names")
             end2end = metadata.get("args", {}).get("nms", False)
             dynamic = metadata.get("args", {}).get("dynamic", dynamic)
             ch = metadata.get("channels", 3)
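Backends such as ONNX store metadata values as plain strings, so structured entries are recovered with `ast.literal_eval`, which parses Python literals without executing arbitrary code. For example:

import ast

metadata = {"stride": "32", "imgsz": "[640, 640]", "names": "{0: 'person', 1: 'bicycle'}"}
stride = int(metadata["stride"])  # "32" -> 32
imgsz = ast.literal_eval(metadata["imgsz"])  # "[640, 640]" -> [640, 640]
names = ast.literal_eval(metadata["names"])  # string -> dict[int, str]
assert names[0] == "person"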
@@ -613,8 +672,7 @@ class AutoBackend(nn.Module):
         embed: list | None = None,
         **kwargs: Any,
     ) -> torch.Tensor | list[torch.Tensor]:
-        """
-        Run inference on an AutoBackend model.
+        """Run inference on an AutoBackend model.

         Args:
             im (torch.Tensor): The image tensor to perform inference on.
@@ -626,7 +684,7 @@ class AutoBackend(nn.Module):
         Returns:
             (torch.Tensor | list[torch.Tensor]): The raw output tensor(s) from the model.
         """
-        … (1 line not captured in this view)
+        _b, _ch, h, w = im.shape  # batch, channel, height, width
         if self.fp16 and im.dtype != torch.float16:
             im = im.half()  # to FP16
         if self.nhwc:
@@ -648,10 +706,7 @@ class AutoBackend(nn.Module):

         # ONNX Runtime
         elif self.onnx or self.imx:
-            if self.
-                im = im.cpu().numpy()  # torch to numpy
-                y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
-            else:
+            if self.use_io_binding:
                 if not self.cuda:
                     im = im.cpu()
                 self.io.bind_input(
@@ -664,13 +719,21 @@ class AutoBackend(nn.Module):
                 )
                 self.session.run_with_iobinding(self.io)
                 y = self.bindings
+            else:
+                im = im.cpu().numpy()  # torch to numpy
+                y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
             if self.imx:
                 if self.task == "detect":
                     # boxes, conf, cls
                     y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)
                 elif self.task == "pose":
                     # boxes, conf, kpts
-                    y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1)
+                    y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype)
+                elif self.task == "segment":
+                    y = (
+                        np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype),
+                        y[4],
+                    )

         # OpenVINO
         elif self.xml:
@@ -720,10 +783,13 @@ class AutoBackend(nn.Module):

         # CoreML
         elif self.coreml:
-            im = im
-            … (1 line not captured in this view)
+            im = im.cpu().numpy()
+            if self.dynamic:
+                im = im.transpose(0, 3, 1, 2)
+            else:
+                im = Image.fromarray((im[0] * 255).astype("uint8"))
             # im = im.resize((192, 320), Image.BILINEAR)
-            y = self.model.predict({"image":
+            y = self.model.predict({"image": im})  # coordinates are xywh normalized
             if "confidence" in y:  # NMS included
                 from ultralytics.utils.ops import xywh2xyxy

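On the static-shape CoreML path the input must be an image feature rather than a raw tensor, hence the rescale to 0-255 and the PIL wrapper; the new dynamic path feeds an NCHW multiarray instead. A small sketch of the static path, assuming a model whose input feature is named "image" (the file name is a placeholder):

import numpy as np
from PIL import Image
import coremltools as ct

model = ct.models.MLModel("yolo11n.mlpackage")
x = np.random.rand(1, 640, 640, 3).astype(np.float32)  # NHWC, values in [0, 1]
img = Image.fromarray((x[0] * 255).astype("uint8"))  # CoreML image inputs are 8-bit PIL images
y = model.predict({"image": img})  # dict mapping output names to arrays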
@@ -767,11 +833,19 @@ class AutoBackend(nn.Module):
             im = im if isinstance(im, (list, tuple)) else [im]
             y = self.rknn_model.inference(inputs=im)

+        # Axelera
+        elif self.axelera:
+            y = self.ax_model(im.cpu())
+
+        # ExecuTorch
+        elif self.pte:
+            y = self.model.execute([im])
+
         # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
         else:
             im = im.cpu().numpy()
             if self.saved_model:  # SavedModel
-                y = self.model
+                y = self.model.serving_default(im)
                 if not isinstance(y, list):
                     y = [y]
             elif self.pb:  # GraphDef
@@ -816,8 +890,6 @@ class AutoBackend(nn.Module):
                 y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
             y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

-        # for x in y:
-        #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
         if isinstance(y, (list, tuple)):
             if len(self.names) == 999 and (self.task == "segment" or len(y) == 2):  # segments and names not defined
                 nc = y[0].shape[1] - y[1].shape[1] - 4  # y = (1, 32, 160, 160), (1, 116, 8400)
@@ -826,35 +898,35 @@ class AutoBackend(nn.Module):
         else:
             return self.from_numpy(y)

-    def from_numpy(self, x: np.ndarray) -> torch.Tensor:
-        """
-        Convert a numpy array to a tensor.
+    def from_numpy(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
+        """Convert a NumPy array to a torch tensor on the model device.

         Args:
-            x (np.ndarray):
+            x (np.ndarray | torch.Tensor): Input array or tensor.

         Returns:
-            (torch.Tensor):
+            (torch.Tensor): Tensor on `self.device`.
         """
         return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x

     def warmup(self, imgsz: tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
-        """
-        Warm up the model by running one forward pass with a dummy input.
+        """Warm up the model by running one forward pass with a dummy input.

         Args:
-            imgsz (tuple
+            imgsz (tuple[int, int, int, int]): Dummy input shape in (batch, channels, height, width) format.
         """
         warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
         if any(warmup_types) and (self.device.type != "cpu" or self.triton):
             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
             for _ in range(2 if self.jit else 1):
-                self.forward(im)  # warmup
+                self.forward(im)  # warmup model
+            warmup_boxes = torch.rand(1, 84, 16, device=self.device)  # 16 boxes works best empirically
+            warmup_boxes[:, :4] *= imgsz[-1]
+            non_max_suppression(warmup_boxes)  # warmup NMS

     @staticmethod
     def _model_type(p: str = "path/to/model.pt") -> list[bool]:
-        """
-        Take a path to a model file and return the model type.
+        """Take a path to a model file and return the model type.

         Args:
             p (str): Path to the model file.
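The added warmup step runs NMS once on fake predictions so the first real frame does not pay one-off costs such as kernel compilation. The (1, 84, 16) shape mirrors a YOLO detection head on COCO: 4 box coordinates plus 80 class scores per candidate, with 16 candidate boxes. A sketch of constructing that tensor under the same layout assumption:

import torch

imgsz = 640
warmup_boxes = torch.rand(1, 84, 16)  # (batch, 4 box coords + 80 class scores, candidates)
warmup_boxes[:, :4] *= imgsz  # scale normalized xywh up to pixel coordinates
# non_max_suppression(warmup_boxes) then exercises the decode + NMS path once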
@@ -863,8 +935,8 @@ class AutoBackend(nn.Module):
             (list[bool]): List of booleans indicating the model type.

         Examples:
-            >>>
-            >>>
+            >>> types = AutoBackend._model_type("path/to/model.onnx")
+            >>> assert types[2]  # onnx
         """
         from ultralytics.engine.exporter import export_formats

@@ -883,4 +955,4 @@ class AutoBackend(nn.Module):
         url = urlsplit(p)
         triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}

-        return types
+        return [*types, triton]
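`_model_type` now returns the Triton flag together with the suffix-based flags instead of leaving callers to recompute it. The detection itself is a plain `urlsplit` check: a Triton "model path" is any http/grpc URL with both a host and a path. For example:

from urllib.parse import urlsplit


def is_triton(p: str) -> bool:
    """Mirror the check from the diff: http/grpc URL with a host and a path."""
    url = urlsplit(p)
    return bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}


assert is_triton("http://localhost:8000/yolo11n")
assert not is_triton("path/to/model.pt")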
ultralytics/nn/modules/__init__.py
CHANGED
@@ -103,80 +103,80 @@ from .transformer import (
 )

 __all__ = (
-    "…",
-    "…",
-    "…",
-    "…",
-    "…",
-    "…",
-    "ConvTranspose",
-    "Focus",
-    "GhostConv",
-    "ChannelAttention",
-    "SpatialAttention",
+    "AIFI",
+    "C1",
+    "C2",
+    "C2PSA",
+    "C3",
+    "C3TR",
     "CBAM",
-    "…",
-    "TransformerLayer",
-    "TransformerBlock",
-    "MLPBlock",
-    "LayerNorm2d",
+    "CIB",
     "DFL",
-    "…",
-    "…",
+    "ELAN1",
+    "MLP",
+    "OBB",
+    "PSA",
     "SPP",
+    "SPPELAN",
     "SPPF",
-    "…",
-    "…",
-    "…",
+    "A2C2f",
+    "AConv",
+    "ADown",
+    "Attention",
+    "BNContrastiveHead",
+    "Bottleneck",
+    "BottleneckCSP",
     "C2f",
-    "C3k2",
-    "SCDown",
-    "C2fPSA",
-    "C2PSA",
     "C2fAttn",
-    "…",
-    "…",
+    "C2fCIB",
+    "C2fPSA",
     "C3Ghost",
-    "…",
-    "…",
-    "…",
-    "…",
-    "…",
-    "Segment",
-    "Pose",
+    "C3k2",
+    "C3x",
+    "CBFuse",
+    "CBLinear",
+    "ChannelAttention",
     "Classify",
-    "…",
-    "…",
-    "…",
-    "…",
+    "Concat",
+    "ContrastiveHead",
+    "Conv",
+    "Conv2",
+    "ConvTranspose",
+    "DWConv",
+    "DWConvTranspose2d",
     "DeformableTransformerDecoder",
     "DeformableTransformerDecoderLayer",
+    "Detect",
+    "Focus",
+    "GhostBottleneck",
+    "GhostConv",
+    "HGBlock",
+    "HGStem",
+    "ImagePoolingAttn",
+    "Index",
+    "LRPCHead",
+    "LayerNorm2d",
+    "LightConv",
+    "MLPBlock",
     "MSDeformAttn",
-    "…",
+    "MaxSigmoidAttnBlock",
+    "Pose",
+    "Proto",
+    "RTDETRDecoder",
+    "RepC3",
+    "RepConv",
+    "RepNCSPELAN4",
+    "RepVGGDW",
     "ResNetLayer",
-    "…",
+    "SCDown",
+    "Segment",
+    "SpatialAttention",
+    "TorchVision",
+    "TransformerBlock",
+    "TransformerEncoderLayer",
+    "TransformerLayer",
     "WorldDetect",
     "YOLOEDetect",
     "YOLOESegment",
     "v10Detect",
-    "LRPCHead",
-    "ImagePoolingAttn",
-    "MaxSigmoidAttnBlock",
-    "ContrastiveHead",
-    "BNContrastiveHead",
-    "RepNCSPELAN4",
-    "ADown",
-    "SPPELAN",
-    "CBFuse",
-    "CBLinear",
-    "AConv",
-    "ELAN1",
-    "RepVGGDW",
-    "CIB",
-    "C2fCIB",
-    "Attention",
-    "PSA",
-    "TorchVision",
-    "Index",
-    "A2C2f",
 )
ultralytics/nn/modules/activation.py
CHANGED
@@ -6,8 +6,7 @@ import torch.nn as nn


 class AGLU(nn.Module):
-    """
-    Unified activation function module from AGLU.
+    """Unified activation function module from AGLU.

     This class implements a parameterized activation function with learnable parameters lambda and kappa, based on the
     AGLU (Adaptive Gated Linear Unit) approach.
@@ -40,11 +39,10 @@ class AGLU(nn.Module):
         self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype)))  # kappa parameter

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Apply the Adaptive Gated Linear Unit (AGLU) activation function.
+        """Apply the Adaptive Gated Linear Unit (AGLU) activation function.

-        This forward method implements the AGLU activation function with learnable parameters lambda and kappa.
-
+        This forward method implements the AGLU activation function with learnable parameters lambda and kappa. The
+        function applies a transformation that adaptively combines linear and non-linear components.

         Args:
             x (torch.Tensor): Input tensor to apply the activation function to.