dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
- tests/__init__.py +22 -0
- tests/conftest.py +83 -0
- tests/test_cli.py +138 -0
- tests/test_cuda.py +215 -0
- tests/test_engine.py +131 -0
- tests/test_exports.py +236 -0
- tests/test_integrations.py +154 -0
- tests/test_python.py +694 -0
- tests/test_solutions.py +187 -0
- ultralytics/__init__.py +30 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1023 -0
- ultralytics/cfg/datasets/Argoverse.yaml +77 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +443 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/VOC.yaml +106 -0
- ultralytics/cfg/datasets/VisDrone.yaml +77 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +42 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +24 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +22 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +127 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +22 -0
- ultralytics/cfg/trackers/bytetrack.yaml +14 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2945 -0
- ultralytics/data/base.py +438 -0
- ultralytics/data/build.py +258 -0
- ultralytics/data/converter.py +754 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +676 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +125 -0
- ultralytics/data/split_dota.py +325 -0
- ultralytics/data/utils.py +777 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1519 -0
- ultralytics/engine/model.py +1156 -0
- ultralytics/engine/predictor.py +502 -0
- ultralytics/engine/results.py +1840 -0
- ultralytics/engine/trainer.py +853 -0
- ultralytics/engine/tuner.py +243 -0
- ultralytics/engine/validator.py +377 -0
- ultralytics/hub/__init__.py +168 -0
- ultralytics/hub/auth.py +137 -0
- ultralytics/hub/google/__init__.py +176 -0
- ultralytics/hub/session.py +446 -0
- ultralytics/hub/utils.py +248 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +61 -0
- ultralytics/models/fastsam/predict.py +181 -0
- ultralytics/models/fastsam/utils.py +24 -0
- ultralytics/models/fastsam/val.py +40 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +102 -0
- ultralytics/models/nas/predict.py +58 -0
- ultralytics/models/nas/val.py +39 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +84 -0
- ultralytics/models/rtdetr/train.py +85 -0
- ultralytics/models/rtdetr/val.py +191 -0
- ultralytics/models/sam/__init__.py +6 -0
- ultralytics/models/sam/amg.py +260 -0
- ultralytics/models/sam/build.py +358 -0
- ultralytics/models/sam/model.py +170 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1129 -0
- ultralytics/models/sam/modules/decoders.py +515 -0
- ultralytics/models/sam/modules/encoders.py +854 -0
- ultralytics/models/sam/modules/memory_attention.py +299 -0
- ultralytics/models/sam/modules/sam.py +1006 -0
- ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
- ultralytics/models/sam/modules/transformer.py +351 -0
- ultralytics/models/sam/modules/utils.py +394 -0
- ultralytics/models/sam/predict.py +1605 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +455 -0
- ultralytics/models/utils/ops.py +268 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +88 -0
- ultralytics/models/yolo/classify/train.py +233 -0
- ultralytics/models/yolo/classify/val.py +215 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +124 -0
- ultralytics/models/yolo/detect/train.py +217 -0
- ultralytics/models/yolo/detect/val.py +451 -0
- ultralytics/models/yolo/model.py +354 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +66 -0
- ultralytics/models/yolo/obb/train.py +81 -0
- ultralytics/models/yolo/obb/val.py +283 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +79 -0
- ultralytics/models/yolo/pose/train.py +154 -0
- ultralytics/models/yolo/pose/val.py +394 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +113 -0
- ultralytics/models/yolo/segment/train.py +123 -0
- ultralytics/models/yolo/segment/val.py +428 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +119 -0
- ultralytics/models/yolo/world/train_world.py +176 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +169 -0
- ultralytics/models/yolo/yoloe/train.py +298 -0
- ultralytics/models/yolo/yoloe/train_seg.py +124 -0
- ultralytics/models/yolo/yoloe/val.py +191 -0
- ultralytics/nn/__init__.py +29 -0
- ultralytics/nn/autobackend.py +842 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +53 -0
- ultralytics/nn/modules/block.py +1966 -0
- ultralytics/nn/modules/conv.py +712 -0
- ultralytics/nn/modules/head.py +880 -0
- ultralytics/nn/modules/transformer.py +713 -0
- ultralytics/nn/modules/utils.py +164 -0
- ultralytics/nn/tasks.py +1627 -0
- ultralytics/nn/text_model.py +351 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +116 -0
- ultralytics/solutions/analytics.py +252 -0
- ultralytics/solutions/config.py +106 -0
- ultralytics/solutions/distance_calculation.py +124 -0
- ultralytics/solutions/heatmap.py +127 -0
- ultralytics/solutions/instance_segmentation.py +84 -0
- ultralytics/solutions/object_blurrer.py +90 -0
- ultralytics/solutions/object_counter.py +195 -0
- ultralytics/solutions/object_cropper.py +84 -0
- ultralytics/solutions/parking_management.py +273 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +120 -0
- ultralytics/solutions/security_alarm.py +154 -0
- ultralytics/solutions/similarity_search.py +172 -0
- ultralytics/solutions/solutions.py +724 -0
- ultralytics/solutions/speed_estimation.py +110 -0
- ultralytics/solutions/streamlit_inference.py +196 -0
- ultralytics/solutions/templates/similarity-search.html +160 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +68 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +124 -0
- ultralytics/trackers/bot_sort.py +260 -0
- ultralytics/trackers/byte_tracker.py +480 -0
- ultralytics/trackers/track.py +125 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +376 -0
- ultralytics/trackers/utils/kalman_filter.py +493 -0
- ultralytics/trackers/utils/matching.py +157 -0
- ultralytics/utils/__init__.py +1435 -0
- ultralytics/utils/autobatch.py +106 -0
- ultralytics/utils/autodevice.py +174 -0
- ultralytics/utils/benchmarks.py +695 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +234 -0
- ultralytics/utils/callbacks/clearml.py +153 -0
- ultralytics/utils/callbacks/comet.py +552 -0
- ultralytics/utils/callbacks/dvc.py +205 -0
- ultralytics/utils/callbacks/hub.py +108 -0
- ultralytics/utils/callbacks/mlflow.py +138 -0
- ultralytics/utils/callbacks/neptune.py +140 -0
- ultralytics/utils/callbacks/raytune.py +43 -0
- ultralytics/utils/callbacks/tensorboard.py +132 -0
- ultralytics/utils/callbacks/wb.py +185 -0
- ultralytics/utils/checks.py +897 -0
- ultralytics/utils/dist.py +119 -0
- ultralytics/utils/downloads.py +499 -0
- ultralytics/utils/errors.py +43 -0
- ultralytics/utils/export.py +219 -0
- ultralytics/utils/files.py +221 -0
- ultralytics/utils/instance.py +499 -0
- ultralytics/utils/loss.py +813 -0
- ultralytics/utils/metrics.py +1356 -0
- ultralytics/utils/ops.py +885 -0
- ultralytics/utils/patches.py +143 -0
- ultralytics/utils/plotting.py +1011 -0
- ultralytics/utils/tal.py +416 -0
- ultralytics/utils/torch_utils.py +990 -0
- ultralytics/utils/triton.py +116 -0
- ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,182 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
"""
|
3
|
+
Ultralytics modules.
|
4
|
+
|
5
|
+
This module provides access to various neural network components used in Ultralytics models, including convolution blocks,
|
6
|
+
attention mechanisms, transformer components, and detection/segmentation heads.
|
7
|
+
|
8
|
+
Examples:
|
9
|
+
Visualize a module with Netron.
|
10
|
+
>>> from ultralytics.nn.modules import *
|
11
|
+
>>> import torch
|
12
|
+
>>> import os
|
13
|
+
>>> x = torch.ones(1, 128, 40, 40)
|
14
|
+
>>> m = Conv(128, 128)
|
15
|
+
>>> f = f"{m._get_name()}.onnx"
|
16
|
+
>>> torch.onnx.export(m, x, f)
|
17
|
+
>>> os.system(f"onnxslim {f} {f} && open {f}") # pip install onnxslim
|
18
|
+
"""
|
19
|
+
|
20
|
+
from .block import (
|
21
|
+
C1,
|
22
|
+
C2,
|
23
|
+
C2PSA,
|
24
|
+
C3,
|
25
|
+
C3TR,
|
26
|
+
CIB,
|
27
|
+
DFL,
|
28
|
+
ELAN1,
|
29
|
+
PSA,
|
30
|
+
SPP,
|
31
|
+
SPPELAN,
|
32
|
+
SPPF,
|
33
|
+
A2C2f,
|
34
|
+
AConv,
|
35
|
+
ADown,
|
36
|
+
Attention,
|
37
|
+
BNContrastiveHead,
|
38
|
+
Bottleneck,
|
39
|
+
BottleneckCSP,
|
40
|
+
C2f,
|
41
|
+
C2fAttn,
|
42
|
+
C2fCIB,
|
43
|
+
C2fPSA,
|
44
|
+
C3Ghost,
|
45
|
+
C3k2,
|
46
|
+
C3x,
|
47
|
+
CBFuse,
|
48
|
+
CBLinear,
|
49
|
+
ContrastiveHead,
|
50
|
+
GhostBottleneck,
|
51
|
+
HGBlock,
|
52
|
+
HGStem,
|
53
|
+
ImagePoolingAttn,
|
54
|
+
MaxSigmoidAttnBlock,
|
55
|
+
Proto,
|
56
|
+
RepC3,
|
57
|
+
RepNCSPELAN4,
|
58
|
+
RepVGGDW,
|
59
|
+
ResNetLayer,
|
60
|
+
SCDown,
|
61
|
+
TorchVision,
|
62
|
+
)
|
63
|
+
from .conv import (
|
64
|
+
CBAM,
|
65
|
+
ChannelAttention,
|
66
|
+
Concat,
|
67
|
+
Conv,
|
68
|
+
Conv2,
|
69
|
+
ConvTranspose,
|
70
|
+
DWConv,
|
71
|
+
DWConvTranspose2d,
|
72
|
+
Focus,
|
73
|
+
GhostConv,
|
74
|
+
Index,
|
75
|
+
LightConv,
|
76
|
+
RepConv,
|
77
|
+
SpatialAttention,
|
78
|
+
)
|
79
|
+
from .head import (
|
80
|
+
OBB,
|
81
|
+
Classify,
|
82
|
+
Detect,
|
83
|
+
LRPCHead,
|
84
|
+
Pose,
|
85
|
+
RTDETRDecoder,
|
86
|
+
Segment,
|
87
|
+
WorldDetect,
|
88
|
+
YOLOEDetect,
|
89
|
+
YOLOESegment,
|
90
|
+
v10Detect,
|
91
|
+
)
|
92
|
+
from .transformer import (
|
93
|
+
AIFI,
|
94
|
+
MLP,
|
95
|
+
DeformableTransformerDecoder,
|
96
|
+
DeformableTransformerDecoderLayer,
|
97
|
+
LayerNorm2d,
|
98
|
+
MLPBlock,
|
99
|
+
MSDeformAttn,
|
100
|
+
TransformerBlock,
|
101
|
+
TransformerEncoderLayer,
|
102
|
+
TransformerLayer,
|
103
|
+
)
|
104
|
+
|
105
|
+
# Public API of this package: every name below is imported in this file from the .block, .conv, .head and
# .transformer submodules, and this tuple is what `from ultralytics.nn.modules import *` binds.
# Keep it a static literal (not built dynamically) so linters and IDEs can check each entry against the imports;
# when adding a module to one of the imports, add its name here as well.
__all__ = (
    # Convolution / attention building blocks (.conv)
    "Conv",
    "Conv2",
    "LightConv",
    "RepConv",
    "DWConv",
    "DWConvTranspose2d",
    "ConvTranspose",
    "Focus",
    "GhostConv",
    "ChannelAttention",
    "SpatialAttention",
    "CBAM",
    "Concat",
    "TransformerLayer",
    "TransformerBlock",
    "MLPBlock",
    "LayerNorm2d",
    "DFL",
    "HGBlock",
    "HGStem",
    "SPP",
    "SPPF",
    "C1",
    "C2",
    "C3",
    "C2f",
    "C3k2",
    "SCDown",
    "C2fPSA",
    "C2PSA",
    "C2fAttn",
    "C3x",
    "C3TR",
    "C3Ghost",
    "GhostBottleneck",
    "Bottleneck",
    "BottleneckCSP",
    "Proto",
    "Detect",
    "Segment",
    "Pose",
    "Classify",
    "TransformerEncoderLayer",
    "RepC3",
    "RTDETRDecoder",
    "AIFI",
    "DeformableTransformerDecoder",
    "DeformableTransformerDecoderLayer",
    "MSDeformAttn",
    "MLP",
    "ResNetLayer",
    "OBB",
    "WorldDetect",
    "YOLOEDetect",
    "YOLOESegment",
    "v10Detect",
    "LRPCHead",
    "ImagePoolingAttn",
    "MaxSigmoidAttnBlock",
    "ContrastiveHead",
    "BNContrastiveHead",
    "RepNCSPELAN4",
    "ADown",
    "SPPELAN",
    "CBFuse",
    "CBLinear",
    "AConv",
    "ELAN1",
    "RepVGGDW",
    "CIB",
    "C2fCIB",
    "Attention",
    "PSA",
    "TorchVision",
    "Index",
    "A2C2f",
)
|
@@ -0,0 +1,53 @@
|
|
1
|
+
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
2
|
+
"""Activation modules."""
|
3
|
+
|
4
|
+
import torch
|
5
|
+
import torch.nn as nn
|
6
|
+
|
7
|
+
|
8
|
+
class AGLU(nn.Module):
    """
    Adaptive Gated Linear Unit (AGLU) activation with learnable lambda and kappa.

    Implements the unified activation from https://github.com/kostas1515/AGLU:

        y = exp((1 / lambda) * Softplus_{beta=-1}(kappa * x - log(lambda)))

    where lambda is clamped to at least 1e-4 before use. Both lambda and kappa are single-element learnable
    parameters drawn from ``nn.init.uniform_`` (default range [0, 1)), so they broadcast over any input shape.

    Attributes:
        act (nn.Softplus): Softplus with ``beta=-1.0``, i.e. -log(1 + exp(-z)) = log(sigmoid(z)).
        lambd (nn.Parameter): Learnable lambda, shape (1,).
        kappa (nn.Parameter): Learnable kappa, shape (1,).

    Methods:
        forward: Apply the AGLU activation element-wise.

    Examples:
        >>> import torch
        >>> m = AGLU()
        >>> x = torch.randn(2)
        >>> y = m(x)
        >>> print(y.shape)
        torch.Size([2])
    """

    def __init__(self, device=None, dtype=None) -> None:
        """
        Create the negative-beta Softplus gate and the two scalar learnable parameters.

        Args:
            device (torch.device | str | None): Device on which the parameters are allocated.
            dtype (torch.dtype | None): Data type of the parameters.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.act = nn.Softplus(beta=-1.0)
        # lambda is drawn before kappa, so seeded initialisation yields the same values as before
        self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, **factory_kwargs)))  # lambda parameter
        self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, **factory_kwargs)))  # kappa parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the AGLU activation element-wise.

        Args:
            x (torch.Tensor): Input tensor of any shape.

        Returns:
            (torch.Tensor): Activated tensor with the same shape as ``x``.
        """
        # Lower-bound lambda so that both 1 / lambda and log(lambda) stay finite.
        lam = self.lambd.clamp(min=0.0001)
        gated = self.act(self.kappa * x - lam.log())
        return ((1 / lam) * gated).exp()
|