dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.3.248__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/METADATA +13 -14
- dgenerate_ultralytics_headless-8.3.248.dist-info/RECORD +298 -0
- tests/__init__.py +5 -7
- tests/conftest.py +8 -15
- tests/test_cli.py +1 -1
- tests/test_cuda.py +5 -8
- tests/test_engine.py +1 -1
- tests/test_exports.py +57 -12
- tests/test_integrations.py +4 -4
- tests/test_python.py +84 -53
- tests/test_solutions.py +160 -151
- ultralytics/__init__.py +1 -1
- ultralytics/cfg/__init__.py +56 -62
- ultralytics/cfg/datasets/Argoverse.yaml +7 -6
- ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
- ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
- ultralytics/cfg/datasets/ImageNet.yaml +1 -1
- ultralytics/cfg/datasets/VOC.yaml +15 -16
- ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
- ultralytics/cfg/datasets/coco-pose.yaml +21 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
- ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
- ultralytics/cfg/datasets/dog-pose.yaml +28 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +5 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
- ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
- ultralytics/cfg/datasets/xView.yaml +16 -16
- ultralytics/cfg/default.yaml +1 -1
- ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
- ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
- ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
- ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
- ultralytics/cfg/models/v6/yolov6.yaml +1 -1
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
- ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
- ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
- ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +3 -4
- ultralytics/data/augment.py +285 -475
- ultralytics/data/base.py +18 -26
- ultralytics/data/build.py +147 -25
- ultralytics/data/converter.py +36 -46
- ultralytics/data/dataset.py +46 -74
- ultralytics/data/loaders.py +42 -49
- ultralytics/data/split.py +5 -6
- ultralytics/data/split_dota.py +8 -15
- ultralytics/data/utils.py +34 -43
- ultralytics/engine/exporter.py +319 -237
- ultralytics/engine/model.py +148 -188
- ultralytics/engine/predictor.py +29 -38
- ultralytics/engine/results.py +177 -311
- ultralytics/engine/trainer.py +83 -59
- ultralytics/engine/tuner.py +23 -34
- ultralytics/engine/validator.py +39 -22
- ultralytics/hub/__init__.py +16 -19
- ultralytics/hub/auth.py +6 -12
- ultralytics/hub/google/__init__.py +7 -10
- ultralytics/hub/session.py +15 -25
- ultralytics/hub/utils.py +5 -8
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +8 -10
- ultralytics/models/fastsam/predict.py +17 -29
- ultralytics/models/fastsam/utils.py +1 -2
- ultralytics/models/fastsam/val.py +5 -7
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +5 -8
- ultralytics/models/nas/predict.py +7 -9
- ultralytics/models/nas/val.py +1 -2
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +5 -8
- ultralytics/models/rtdetr/predict.py +15 -19
- ultralytics/models/rtdetr/train.py +10 -13
- ultralytics/models/rtdetr/val.py +21 -23
- ultralytics/models/sam/__init__.py +15 -2
- ultralytics/models/sam/amg.py +14 -20
- ultralytics/models/sam/build.py +26 -19
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +29 -32
- ultralytics/models/sam/modules/blocks.py +83 -144
- ultralytics/models/sam/modules/decoders.py +19 -37
- ultralytics/models/sam/modules/encoders.py +44 -101
- ultralytics/models/sam/modules/memory_attention.py +16 -30
- ultralytics/models/sam/modules/sam.py +200 -73
- ultralytics/models/sam/modules/tiny_encoder.py +64 -83
- ultralytics/models/sam/modules/transformer.py +18 -28
- ultralytics/models/sam/modules/utils.py +174 -50
- ultralytics/models/sam/predict.py +2248 -350
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/loss.py +14 -26
- ultralytics/models/utils/ops.py +13 -17
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +9 -12
- ultralytics/models/yolo/classify/train.py +11 -32
- ultralytics/models/yolo/classify/val.py +29 -28
- ultralytics/models/yolo/detect/predict.py +7 -10
- ultralytics/models/yolo/detect/train.py +11 -20
- ultralytics/models/yolo/detect/val.py +70 -58
- ultralytics/models/yolo/model.py +36 -53
- ultralytics/models/yolo/obb/predict.py +5 -14
- ultralytics/models/yolo/obb/train.py +11 -14
- ultralytics/models/yolo/obb/val.py +39 -36
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +6 -21
- ultralytics/models/yolo/pose/train.py +10 -15
- ultralytics/models/yolo/pose/val.py +38 -57
- ultralytics/models/yolo/segment/predict.py +14 -18
- ultralytics/models/yolo/segment/train.py +3 -6
- ultralytics/models/yolo/segment/val.py +93 -45
- ultralytics/models/yolo/world/train.py +8 -14
- ultralytics/models/yolo/world/train_world.py +11 -34
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +16 -23
- ultralytics/models/yolo/yoloe/train.py +30 -43
- ultralytics/models/yolo/yoloe/train_seg.py +5 -10
- ultralytics/models/yolo/yoloe/val.py +15 -20
- ultralytics/nn/__init__.py +7 -7
- ultralytics/nn/autobackend.py +145 -77
- ultralytics/nn/modules/__init__.py +60 -60
- ultralytics/nn/modules/activation.py +4 -6
- ultralytics/nn/modules/block.py +132 -216
- ultralytics/nn/modules/conv.py +52 -97
- ultralytics/nn/modules/head.py +50 -103
- ultralytics/nn/modules/transformer.py +76 -88
- ultralytics/nn/modules/utils.py +16 -21
- ultralytics/nn/tasks.py +94 -154
- ultralytics/nn/text_model.py +40 -67
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +11 -17
- ultralytics/solutions/analytics.py +15 -16
- ultralytics/solutions/config.py +5 -6
- ultralytics/solutions/distance_calculation.py +10 -13
- ultralytics/solutions/heatmap.py +7 -13
- ultralytics/solutions/instance_segmentation.py +5 -8
- ultralytics/solutions/object_blurrer.py +7 -10
- ultralytics/solutions/object_counter.py +12 -19
- ultralytics/solutions/object_cropper.py +8 -14
- ultralytics/solutions/parking_management.py +33 -31
- ultralytics/solutions/queue_management.py +10 -12
- ultralytics/solutions/region_counter.py +9 -12
- ultralytics/solutions/security_alarm.py +15 -20
- ultralytics/solutions/similarity_search.py +10 -15
- ultralytics/solutions/solutions.py +75 -74
- ultralytics/solutions/speed_estimation.py +7 -10
- ultralytics/solutions/streamlit_inference.py +2 -4
- ultralytics/solutions/templates/similarity-search.html +7 -18
- ultralytics/solutions/trackzone.py +7 -10
- ultralytics/solutions/vision_eye.py +5 -8
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +3 -5
- ultralytics/trackers/bot_sort.py +10 -27
- ultralytics/trackers/byte_tracker.py +14 -30
- ultralytics/trackers/track.py +3 -6
- ultralytics/trackers/utils/gmc.py +11 -22
- ultralytics/trackers/utils/kalman_filter.py +37 -48
- ultralytics/trackers/utils/matching.py +12 -15
- ultralytics/utils/__init__.py +116 -116
- ultralytics/utils/autobatch.py +2 -4
- ultralytics/utils/autodevice.py +17 -18
- ultralytics/utils/benchmarks.py +32 -46
- ultralytics/utils/callbacks/base.py +8 -10
- ultralytics/utils/callbacks/clearml.py +5 -13
- ultralytics/utils/callbacks/comet.py +32 -46
- ultralytics/utils/callbacks/dvc.py +13 -18
- ultralytics/utils/callbacks/mlflow.py +4 -5
- ultralytics/utils/callbacks/neptune.py +7 -15
- ultralytics/utils/callbacks/platform.py +314 -38
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +23 -31
- ultralytics/utils/callbacks/wb.py +10 -13
- ultralytics/utils/checks.py +99 -76
- ultralytics/utils/cpu.py +3 -8
- ultralytics/utils/dist.py +8 -12
- ultralytics/utils/downloads.py +20 -30
- ultralytics/utils/errors.py +6 -14
- ultralytics/utils/events.py +2 -4
- ultralytics/utils/export/__init__.py +4 -236
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +91 -55
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +24 -28
- ultralytics/utils/git.py +9 -11
- ultralytics/utils/instance.py +30 -51
- ultralytics/utils/logger.py +212 -114
- ultralytics/utils/loss.py +14 -22
- ultralytics/utils/metrics.py +126 -155
- ultralytics/utils/nms.py +13 -16
- ultralytics/utils/ops.py +107 -165
- ultralytics/utils/patches.py +33 -21
- ultralytics/utils/plotting.py +72 -80
- ultralytics/utils/tal.py +25 -39
- ultralytics/utils/torch_utils.py +52 -78
- ultralytics/utils/tqdm.py +20 -20
- ultralytics/utils/triton.py +13 -19
- ultralytics/utils/tuner.py +17 -5
- dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/WHEEL +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.3.248.dist-info}/top_level.txt +0 -0
ultralytics/nn/autobackend.py
CHANGED
```diff
@@ -19,11 +19,11 @@ from PIL import Image
 from ultralytics.utils import ARM64, IS_JETSON, LINUX, LOGGER, PYTHON_VERSION, ROOT, YAML, is_jetson
 from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml, is_rockchip
 from ultralytics.utils.downloads import attempt_download_asset, is_url
+from ultralytics.utils.nms import non_max_suppression


 def check_class_names(names: list | dict) -> dict[int, str]:
-    """
-    Check class names and convert to dict format if needed.
+    """Check class names and convert to dict format if needed.

     Args:
         names (list | dict): Class names as list or dict format.
@@ -52,8 +52,7 @@ def check_class_names(names: list | dict) -> dict[int, str]:


 def default_class_names(data: str | Path | None = None) -> dict[int, str]:
-    """
-    Apply default class names to an input YAML file or return numerical class names.
+    """Apply default class names to an input YAML file or return numerical class names.

     Args:
         data (str | Path, optional): Path to YAML file containing class names.
@@ -70,8 +69,7 @@ def default_class_names(data: str | Path | None = None) -> dict[int, str]:


 class AutoBackend(nn.Module):
-    """
-    Handle dynamic backend selection for running inference using Ultralytics YOLO models.
+    """Handle dynamic backend selection for running inference using Ultralytics YOLO models.

     The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide
     range of formats, each with specific naming conventions as outlined below:
@@ -95,6 +93,9 @@ class AutoBackend(nn.Module):
         | NCNN             | *_ncnn_model/    |
         | IMX              | *_imx_model/     |
         | RKNN             | *_rknn_model/    |
+        | Triton Inference | triton://model   |
+        | ExecuTorch       | *.pte            |
+        | Axelera          | *_axelera_model/ |

     Attributes:
         model (torch.nn.Module): The loaded YOLO model.
```
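The three new table rows correspond to the new `pte`, `axelera`, and `triton` flags returned by `_model_type` (see the final hunk of this file). A minimal sketch of pointing `AutoBackend` at the new formats — the file paths and the Triton URL below are illustrative, not shipped assets:

```python
# Illustrative only: each "weight" below routes to one of the newly documented backends.
import torch

from ultralytics.nn.autobackend import AutoBackend

im = torch.zeros(1, 3, 640, 640)  # BCHW dummy input

backend = AutoBackend("yolo11n.pte")                      # ExecuTorch: a single *.pte file
# backend = AutoBackend("yolo11n_axelera_model/")         # Axelera: directory holding an *.axm file
# backend = AutoBackend("grpc://localhost:8001/yolo11n")  # Triton: http/grpc URL with host + model path
y = backend(im)
```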
```diff
@@ -121,10 +122,12 @@ class AutoBackend(nn.Module):
         imx (bool): Whether the model is an IMX model.
         rknn (bool): Whether the model is an RKNN model.
         triton (bool): Whether the model is a Triton Inference Server model.
+        pte (bool): Whether the model is a PyTorch ExecuTorch model.
+        axelera (bool): Whether the model is an Axelera model.

     Methods:
         forward: Run inference on an input image.
-        from_numpy: Convert
+        from_numpy: Convert NumPy arrays to tensors on the model device.
         warmup: Warm up the model with a dummy input.
         _model_type: Determine the model type from file path.

@@ -144,8 +147,7 @@ class AutoBackend(nn.Module):
         fuse: bool = True,
         verbose: bool = True,
     ):
-        """
-        Initialize the AutoBackend for inference.
+        """Initialize the AutoBackend for inference.

         Args:
             model (str | torch.nn.Module): Path to the model weights file or a module instance.
@@ -175,10 +177,12 @@ class AutoBackend(nn.Module):
             ncnn,
             imx,
             rknn,
+            pte,
+            axelera,
             triton,
         ) = self._model_type("" if nn_module else model)
         fp16 &= pt or jit or onnx or xml or engine or nn_module or triton  # FP16
-        nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch
+        nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn  # BHWC formats (vs torch BCHW)
         stride, ch = 32, 3  # default stride and channels
         end2end, dynamic = False, False
         metadata, task = None, None
@@ -241,25 +245,28 @@ class AutoBackend(nn.Module):
             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
             import onnxruntime

-
-
-
-
-
-
-
-
-
+            # Select execution provider: CUDA > CoreML (mps) > CPU
+            available = onnxruntime.get_available_providers()
+            if cuda and "CUDAExecutionProvider" in available:
+                providers = [("CUDAExecutionProvider", {"device_id": device.index}), "CPUExecutionProvider"]
+            elif device.type == "mps" and "CoreMLExecutionProvider" in available:
+                providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
+            else:
+                providers = ["CPUExecutionProvider"]
+                if cuda:
+                    LOGGER.warning("CUDA requested but CUDAExecutionProvider not available. Using CPU...")
+                    device, cuda = torch.device("cpu"), False
+            LOGGER.info(
+                f"Using ONNX Runtime {onnxruntime.__version__} with {providers[0] if isinstance(providers[0], str) else providers[0][0]}"
+            )
             if onnx:
                 session = onnxruntime.InferenceSession(w, providers=providers)
             else:
-                check_requirements(
-                    ("model-compression-toolkit>=2.4.1", "sony-custom-layers[torch]>=0.3.0", "onnxruntime-extensions")
-                )
+                check_requirements(("model-compression-toolkit>=2.4.1", "edge-mdt-cl<1.1.0", "onnxruntime-extensions"))
                 w = next(Path(w).glob("*.onnx"))
                 LOGGER.info(f"Loading {w} for ONNX IMX inference...")
                 import mct_quantizers as mctq
-                from
+                from edgemdt_cl.pytorch.nms import nms_ort  # noqa - register custom NMS ops

                 session_options = mctq.get_ort_session_options()
                 session_options.enable_mem_reuse = False  # fix the shape mismatch from onnxruntime
```
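The replaced provider block is the interesting part of this hunk: instead of asking for CUDA unconditionally, the new code inspects what ONNX Runtime actually offers and falls back in a fixed CUDA > CoreML (mps) > CPU order. A standalone sketch of that priority logic (model path illustrative):

```python
import onnxruntime as ort


def pick_providers(want_cuda: bool, device_type: str = "cpu", device_id: int = 0) -> list:
    """Mirror the CUDA > CoreML (mps) > CPU fallback from the hunk above."""
    available = ort.get_available_providers()
    if want_cuda and "CUDAExecutionProvider" in available:
        return [("CUDAExecutionProvider", {"device_id": device_id}), "CPUExecutionProvider"]
    if device_type == "mps" and "CoreMLExecutionProvider" in available:
        return ["CoreMLExecutionProvider", "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]  # the diff also logs a warning and downgrades `device` here


session = ort.InferenceSession("yolo11n.onnx", providers=pick_providers(want_cuda=False))
```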
```diff
@@ -269,7 +276,10 @@ class AutoBackend(nn.Module):
             metadata = session.get_modelmeta().custom_metadata_map
             dynamic = isinstance(session.get_outputs()[0].shape[0], str)
             fp16 = "float16" in session.get_inputs()[0].type
-
+
+            # Setup IO binding for optimized inference (CUDA only, not supported for CoreML)
+            use_io_binding = not dynamic and cuda
+            if use_io_binding:
                 io = session.io_binding()
                 bindings = []
                 for output in session.get_outputs():
```
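`io_binding` is ONNX Runtime's mechanism for keeping inputs and outputs in pre-allocated device buffers instead of round-tripping through NumPy on every call; the hunk gates it on `not dynamic and cuda` because static shapes are needed to pre-allocate. A hedged standalone sketch (model path and output shape are assumptions, not values from this diff):

```python
import numpy as np
import onnxruntime as ort
import torch

session = ort.InferenceSession("yolo11n.onnx", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
io = session.io_binding()

im = torch.zeros(1, 3, 640, 640, device="cuda").contiguous()
io.bind_input(
    name=session.get_inputs()[0].name, device_type="cuda", device_id=0,
    element_type=np.float32, shape=tuple(im.shape), buffer_ptr=im.data_ptr(),
)
out = torch.empty(1, 84, 8400, device="cuda")  # assumed static output shape
io.bind_output(
    name=session.get_outputs()[0].name, device_type="cuda", device_id=0,
    element_type=np.float32, shape=tuple(out.shape), buffer_ptr=out.data_ptr(),
)
session.run_with_iobinding(io)  # result is written directly into `out`, no host copy
```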
```diff
@@ -332,11 +342,11 @@ class AutoBackend(nn.Module):
             check_requirements("numpy==1.23.5")

             try:  # https://developer.nvidia.com/nvidia-tensorrt-download
-                import tensorrt as trt
+                import tensorrt as trt
             except ImportError:
                 if LINUX:
                     check_requirements("tensorrt>7.0.0,!=10.1.0")
-                import tensorrt as trt
+                import tensorrt as trt
             check_version(trt.__version__, ">=7.0.0", hard=True)
             check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239")
             if device.type == "cpu":
@@ -369,39 +379,42 @@ class AutoBackend(nn.Module):
             is_trt10 = not hasattr(model, "num_bindings")
             num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
             for i in num:
+                # Get tensor info using TRT10+ or legacy API
                 if is_trt10:
                     name = model.get_tensor_name(i)
                     dtype = trt.nptype(model.get_tensor_dtype(name))
                     is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
-
-
-
-                            context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
-                        if dtype == np.float16:
-                            fp16 = True
-                    else:
-                        output_names.append(name)
-                    shape = tuple(context.get_tensor_shape(name))
-                else:  # TensorRT < 10.0
+                    shape = tuple(model.get_tensor_shape(name))
+                    profile_shape = tuple(model.get_tensor_profile_shape(name, 0)[2]) if is_input else None
+                else:
                     name = model.get_binding_name(i)
                     dtype = trt.nptype(model.get_binding_dtype(i))
                     is_input = model.binding_is_input(i)
-
-
-
-
-
-
-
-
-
+                    shape = tuple(model.get_binding_shape(i))
+                    profile_shape = tuple(model.get_profile_shape(0, i)[1]) if is_input else None
+
+                # Process input/output tensors
+                if is_input:
+                    if -1 in shape:
+                        dynamic = True
+                        if is_trt10:
+                            context.set_input_shape(name, profile_shape)
+                        else:
+                            context.set_binding_shape(i, profile_shape)
+                    if dtype == np.float16:
+                        fp16 = True
+                else:
+                    output_names.append(name)
+                shape = tuple(context.get_tensor_shape(name)) if is_trt10 else tuple(context.get_binding_shape(i))
                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())

         # CoreML
         elif coreml:
-            check_requirements(
+            check_requirements(
+                ["coremltools>=9.0", "numpy>=1.14.5,<=2.3.5"]
+            )  # latest numpy 2.4.0rc1 breaks coremltools exports
             LOGGER.info(f"Loading {w} for CoreML inference...")
             import coremltools as ct

```
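The TensorRT rewrite above flattens two near-duplicate branches into "read tensor info via whichever API exists, then process inputs and outputs once". A compact sketch of that shape, given an already-deserialized engine (API names are the real TensorRT 10 and legacy binding calls used in the hunk):

```python
import tensorrt as trt


def iter_io_tensors(engine):
    """Yield (name, is_input, shape) across TensorRT >= 10 and legacy binding APIs."""
    is_trt10 = not hasattr(engine, "num_bindings")
    for i in range(engine.num_io_tensors if is_trt10 else engine.num_bindings):
        if is_trt10:
            name = engine.get_tensor_name(i)
            is_input = engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
            shape = tuple(engine.get_tensor_shape(name))
        else:
            name = engine.get_binding_name(i)
            is_input = engine.binding_is_input(i)
            shape = tuple(engine.get_binding_shape(i))
        yield name, is_input, shape  # -1 in an input shape marks a dynamic axis
```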
```diff
@@ -414,8 +427,7 @@ class AutoBackend(nn.Module):
             LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
             import tensorflow as tf

-
-            model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
+            model = tf.saved_model.load(w)
             metadata = Path(w) / "metadata.yaml"

         # TF GraphDef
@@ -423,7 +435,7 @@ class AutoBackend(nn.Module):
             LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
             import tensorflow as tf

-            from ultralytics.
+            from ultralytics.utils.export.tensorflow import gd_outputs

             def wrap_frozen_graph(gd, inputs, outputs):
                 """Wrap frozen graphs for deployment."""
@@ -491,7 +503,7 @@ class AutoBackend(nn.Module):
                 if ARM64
                 else "paddlepaddle>=3.0.0"
             )
-            import paddle.inference as pdi
+            import paddle.inference as pdi

             w = Path(w)
             model_file, params_file = None, None
@@ -569,6 +581,51 @@ class AutoBackend(nn.Module):
             rknn_model.init_runtime()
             metadata = w.parent / "metadata.yaml"

+        # Axelera
+        elif axelera:
+            import os
+
+            if not os.environ.get("AXELERA_RUNTIME_DIR"):
+                LOGGER.warning(
+                    "Axelera runtime environment is not activated."
+                    "\nPlease run: source /opt/axelera/sdk/latest/axelera_activate.sh"
+                    "\n\nIf this fails, verify driver installation: https://docs.ultralytics.com/integrations/axelera/#axelera-driver-installation"
+                )
+            try:
+                from axelera.runtime import op
+            except ImportError:
+                check_requirements(
+                    "axelera_runtime2==0.1.2",
+                    cmds="--extra-index-url https://software.axelera.ai/artifactory/axelera-runtime-pypi",
+                )
+                from axelera.runtime import op
+
+            w = Path(w)
+            if (found := next(w.rglob("*.axm"), None)) is None:
+                raise FileNotFoundError(f"No .axm file found in: {w}")
+
+            ax_model = op.load(str(found))
+            metadata = found.parent / "metadata.yaml"
+
+        # ExecuTorch
+        elif pte:
+            LOGGER.info(f"Loading {w} for ExecuTorch inference...")
+            # TorchAO release compatibility table bug https://github.com/pytorch/ao/issues/2919
+            check_requirements("setuptools<71.0.0")  # Setuptools bug: https://github.com/pypa/setuptools/issues/4483
+            check_requirements(("executorch==1.0.1", "flatbuffers"))
+            from executorch.runtime import Runtime
+
+            w = Path(w)
+            if w.is_dir():
+                model_file = next(w.rglob("*.pte"))
+                metadata = w / "metadata.yaml"
+            else:
+                model_file = w
+                metadata = w.parent / "metadata.yaml"
+
+            program = Runtime.get().load_program(str(model_file))
+            model = program.load_method("forward")
+
         # Any other format (unsupported)
         else:
             from ultralytics.engine.exporter import export_formats
```
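The ExecuTorch branch is small because the runtime API amounts to: load a program, look up its "forward" method, execute with a list of tensors. A standalone sketch using the same calls as the hunk (the `.pte` path is illustrative):

```python
import torch
from executorch.runtime import Runtime

program = Runtime.get().load_program("yolo11n.pte")       # parse the .pte flatbuffer
method = program.load_method("forward")                   # materialize the entry point
outputs = method.execute([torch.zeros(1, 3, 640, 640)])   # list of inputs in, list of outputs back
```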
```diff
@@ -585,14 +642,15 @@ class AutoBackend(nn.Module):
             for k, v in metadata.items():
                 if k in {"stride", "batch", "channels"}:
                     metadata[k] = int(v)
-                elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
-                    metadata[k] =
+                elif k in {"imgsz", "names", "kpt_shape", "kpt_names", "args"} and isinstance(v, str):
+                    metadata[k] = ast.literal_eval(v)
             stride = metadata["stride"]
             task = metadata["task"]
             batch = metadata["batch"]
             imgsz = metadata["imgsz"]
             names = metadata["names"]
             kpt_shape = metadata.get("kpt_shape")
+            kpt_names = metadata.get("kpt_names")
             end2end = metadata.get("args", {}).get("nms", False)
             dynamic = metadata.get("args", {}).get("dynamic", dynamic)
             ch = metadata.get("channels", 3)
```
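Two things change here: `kpt_names` joins the set of metadata keys stored as stringified Python literals, and the parsing is `ast.literal_eval(v)`, which evaluates literals without `eval`'s code-execution risk. A runnable illustration with made-up metadata values:

```python
import ast

metadata = {"stride": "32", "imgsz": "[640, 640]", "kpt_names": "['nose', 'left_eye']"}
for k, v in metadata.items():
    if k in {"stride", "batch", "channels"}:
        metadata[k] = int(v)
    elif k in {"imgsz", "names", "kpt_shape", "kpt_names", "args"} and isinstance(v, str):
        metadata[k] = ast.literal_eval(v)  # safe for literals; raises on arbitrary expressions

assert metadata["imgsz"] == [640, 640] and metadata["kpt_names"][0] == "nose"
```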
```diff
@@ -614,8 +672,7 @@ class AutoBackend(nn.Module):
         embed: list | None = None,
         **kwargs: Any,
     ) -> torch.Tensor | list[torch.Tensor]:
-        """
-        Run inference on an AutoBackend model.
+        """Run inference on an AutoBackend model.

         Args:
             im (torch.Tensor): The image tensor to perform inference on.
@@ -627,7 +684,7 @@ class AutoBackend(nn.Module):
         Returns:
             (torch.Tensor | list[torch.Tensor]): The raw output tensor(s) from the model.
         """
-
+        _b, _ch, h, w = im.shape  # batch, channel, height, width
         if self.fp16 and im.dtype != torch.float16:
             im = im.half()  # to FP16
         if self.nhwc:
@@ -649,10 +706,7 @@ class AutoBackend(nn.Module):

         # ONNX Runtime
         elif self.onnx or self.imx:
-            if self.
-                im = im.cpu().numpy()  # torch to numpy
-                y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
-            else:
+            if self.use_io_binding:
                 if not self.cuda:
                     im = im.cpu()
                 self.io.bind_input(
@@ -665,13 +719,21 @@ class AutoBackend(nn.Module):
                 )
                 self.session.run_with_iobinding(self.io)
                 y = self.bindings
+            else:
+                im = im.cpu().numpy()  # torch to numpy
+                y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
             if self.imx:
                 if self.task == "detect":
                     # boxes, conf, cls
                     y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1)
                 elif self.task == "pose":
                     # boxes, conf, kpts
-                    y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1)
+                    y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype)
+                elif self.task == "segment":
+                    y = (
+                        np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None], y[3]], axis=-1, dtype=y[0].dtype),
+                        y[4],
+                    )

         # OpenVINO
         elif self.xml:
```
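For orientation, the IMX branch fuses per-head outputs into one array; the new `dtype=y[0].dtype` keeps FP16 model outputs FP16 through the concatenation, and segment models now return the fused detections together with the remaining head output (presumably the mask prototypes). A shape-only illustration of the detect case:

```python
import numpy as np

n = 300  # illustrative detection count
boxes = np.zeros((1, n, 4), dtype=np.float16)
conf = np.zeros((1, n), dtype=np.float16)
cls = np.zeros((1, n), dtype=np.float16)

y = np.concatenate([boxes, conf[:, :, None], cls[:, :, None]], axis=-1, dtype=boxes.dtype)
assert y.shape == (1, n, 6) and y.dtype == np.float16  # x1, y1, x2, y2, conf, cls
```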
```diff
@@ -771,11 +833,19 @@ class AutoBackend(nn.Module):
             im = im if isinstance(im, (list, tuple)) else [im]
             y = self.rknn_model.inference(inputs=im)

+        # Axelera
+        elif self.axelera:
+            y = self.ax_model(im.cpu())
+
+        # ExecuTorch
+        elif self.pte:
+            y = self.model.execute([im])
+
         # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
         else:
             im = im.cpu().numpy()
             if self.saved_model:  # SavedModel
-                y = self.model
+                y = self.model.serving_default(im)
                 if not isinstance(y, list):
                     y = [y]
             elif self.pb:  # GraphDef
@@ -820,8 +890,6 @@ class AutoBackend(nn.Module):
                 y[1] = np.transpose(y[1], (0, 3, 1, 2))  # should be y = (1, 116, 8400), (1, 32, 160, 160)
             y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]

-        # for x in y:
-        #     print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape)  # debug shapes
         if isinstance(y, (list, tuple)):
             if len(self.names) == 999 and (self.task == "segment" or len(y) == 2):  # segments and names not defined
                 nc = y[0].shape[1] - y[1].shape[1] - 4  # y = (1, 32, 160, 160), (1, 116, 8400)
```
```diff
@@ -830,35 +898,35 @@ class AutoBackend(nn.Module):
         else:
             return self.from_numpy(y)

-    def from_numpy(self, x: np.ndarray) -> torch.Tensor:
-        """
-        Convert a numpy array to a tensor.
+    def from_numpy(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
+        """Convert a NumPy array to a torch tensor on the model device.

         Args:
-            x (np.ndarray):
+            x (np.ndarray | torch.Tensor): Input array or tensor.

         Returns:
-            (torch.Tensor):
+            (torch.Tensor): Tensor on `self.device`.
         """
         return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x

     def warmup(self, imgsz: tuple[int, int, int, int] = (1, 3, 640, 640)) -> None:
-        """
-        Warm up the model by running one forward pass with a dummy input.
+        """Warm up the model by running one forward pass with a dummy input.

         Args:
-            imgsz (tuple
+            imgsz (tuple[int, int, int, int]): Dummy input shape in (batch, channels, height, width) format.
         """
         warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
         if any(warmup_types) and (self.device.type != "cpu" or self.triton):
             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
             for _ in range(2 if self.jit else 1):
-                self.forward(im)  # warmup
+                self.forward(im)  # warmup model
+            warmup_boxes = torch.rand(1, 84, 16, device=self.device)  # 16 boxes works best empirically
+            warmup_boxes[:, :4] *= imgsz[-1]
+            non_max_suppression(warmup_boxes)  # warmup NMS

     @staticmethod
     def _model_type(p: str = "path/to/model.pt") -> list[bool]:
-        """
-        Take a path to a model file and return the model type.
+        """Take a path to a model file and return the model type.

         Args:
             p (str): Path to the model file.
```
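The functional change in the hunk above is the NMS warmup: after the dummy forward pass, a tiny random prediction tensor (4 box coordinates plus 80 class scores across 16 candidate boxes) is pushed through `non_max_suppression` so any lazy compilation or caching happens before the first real frame. Standalone, using the same import the first hunk of this file adds:

```python
import torch
from ultralytics.utils.nms import non_max_suppression

warmup_boxes = torch.rand(1, 84, 16)     # (batch, 4 box coords + 80 classes, 16 boxes)
warmup_boxes[:, :4] *= 640               # scale coordinates up to image size
out = non_max_suppression(warmup_boxes)  # list with one (n, 6) tensor per image
```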
```diff
@@ -867,8 +935,8 @@ class AutoBackend(nn.Module):
         (list[bool]): List of booleans indicating the model type.

         Examples:
-            >>>
-            >>>
+            >>> types = AutoBackend._model_type("path/to/model.onnx")
+            >>> assert types[2]  # onnx
         """
         from ultralytics.engine.exporter import export_formats

@@ -887,4 +955,4 @@ class AutoBackend(nn.Module):
         url = urlsplit(p)
         triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}

-        return types
+        return [*types, triton]
```
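The return now explicitly appends the Triton flag to the suffix-derived list, so callers unpack it alongside the other backend booleans. The detection itself is just URL parsing, as in this sketch:

```python
from urllib.parse import urlsplit


def is_triton(p: str) -> bool:
    """A model 'path' counts as Triton when it is an http/grpc URL with a host and a model path."""
    url = urlsplit(p)
    return bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"}


assert is_triton("grpc://localhost:8001/yolo11n")
assert not is_triton("path/to/model.pt")
```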
ultralytics/nn/modules/__init__.py
CHANGED
```diff
@@ -103,80 +103,80 @@ from .transformer import (
 )

 __all__ = (
-    "Conv",
-    "Conv2",
-    "LightConv",
-    "RepConv",
-    "DWConv",
-    "DWConvTranspose2d",
-    "ConvTranspose",
-    "Focus",
-    "GhostConv",
-    "ChannelAttention",
-    "SpatialAttention",
+    "AIFI",
+    "C1",
+    "C2",
+    "C2PSA",
+    "C3",
+    "C3TR",
     "CBAM",
-    "Concat",
-    "TransformerLayer",
-    "TransformerBlock",
-    "MLPBlock",
-    "LayerNorm2d",
+    "CIB",
     "DFL",
-    "HGBlock",
-    "HGStem",
+    "ELAN1",
+    "MLP",
+    "OBB",
+    "PSA",
     "SPP",
+    "SPPELAN",
     "SPPF",
-    "C1",
-    "C2",
-    "C3",
+    "A2C2f",
+    "AConv",
+    "ADown",
+    "Attention",
+    "BNContrastiveHead",
+    "Bottleneck",
+    "BottleneckCSP",
     "C2f",
-    "C3k2",
-    "SCDown",
-    "C2fPSA",
-    "C2PSA",
     "C2fAttn",
-    "C3x",
-    "C3TR",
+    "C2fCIB",
+    "C2fPSA",
     "C3Ghost",
-    "GhostBottleneck",
-    "Bottleneck",
-    "BottleneckCSP",
-    "Proto",
-    "Detect",
-    "Segment",
-    "Pose",
+    "C3k2",
+    "C3x",
+    "CBFuse",
+    "CBLinear",
+    "ChannelAttention",
     "Classify",
-    "TransformerEncoderLayer",
-    "RepC3",
-    "RTDETRDecoder",
-    "AIFI",
+    "Concat",
+    "ContrastiveHead",
+    "Conv",
+    "Conv2",
+    "ConvTranspose",
+    "DWConv",
+    "DWConvTranspose2d",
     "DeformableTransformerDecoder",
     "DeformableTransformerDecoderLayer",
+    "Detect",
+    "Focus",
+    "GhostBottleneck",
+    "GhostConv",
+    "HGBlock",
+    "HGStem",
+    "ImagePoolingAttn",
+    "Index",
+    "LRPCHead",
+    "LayerNorm2d",
+    "LightConv",
+    "MLPBlock",
     "MSDeformAttn",
-    "MLP",
+    "MaxSigmoidAttnBlock",
+    "Pose",
+    "Proto",
+    "RTDETRDecoder",
+    "RepC3",
+    "RepConv",
+    "RepNCSPELAN4",
+    "RepVGGDW",
     "ResNetLayer",
-    "OBB",
+    "SCDown",
+    "Segment",
+    "SpatialAttention",
+    "TorchVision",
+    "TransformerBlock",
+    "TransformerEncoderLayer",
+    "TransformerLayer",
     "WorldDetect",
     "YOLOEDetect",
    "YOLOESegment",
     "v10Detect",
-    "LRPCHead",
-    "ImagePoolingAttn",
-    "MaxSigmoidAttnBlock",
-    "ContrastiveHead",
-    "BNContrastiveHead",
-    "RepNCSPELAN4",
-    "ADown",
-    "SPPELAN",
-    "CBFuse",
-    "CBLinear",
-    "AConv",
-    "ELAN1",
-    "RepVGGDW",
-    "CIB",
-    "C2fCIB",
-    "Attention",
-    "PSA",
-    "TorchVision",
-    "Index",
-    "A2C2f",
 )
```
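Per the file summary this is a pure reorder of `__all__` (+60/−60, same names). A cheap guard that such reorders stay consistent with the package contents:

```python
import ultralytics.nn.modules as m

missing = [name for name in m.__all__ if not hasattr(m, name)]
assert not missing, f"__all__ references missing attributes: {missing}"
```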
ultralytics/nn/modules/activation.py
CHANGED
```diff
@@ -6,8 +6,7 @@ import torch.nn as nn


 class AGLU(nn.Module):
-    """
-    Unified activation function module from AGLU.
+    """Unified activation function module from AGLU.

     This class implements a parameterized activation function with learnable parameters lambda and kappa, based on the
     AGLU (Adaptive Gated Linear Unit) approach.
@@ -40,11 +39,10 @@ class AGLU(nn.Module):
         self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype)))  # kappa parameter

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Apply the Adaptive Gated Linear Unit (AGLU) activation function.
+        """Apply the Adaptive Gated Linear Unit (AGLU) activation function.

-        This forward method implements the AGLU activation function with learnable parameters lambda and kappa.
-
+        This forward method implements the AGLU activation function with learnable parameters lambda and kappa. The
+        function applies a transformation that adaptively combines linear and non-linear components.

         Args:
             x (torch.Tensor): Input tensor to apply the activation function to.
```