dgenerate-ultralytics-headless 8.3.134__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (272) hide show
  1. dgenerate_ultralytics_headless-8.3.134.dist-info/METADATA +400 -0
  2. dgenerate_ultralytics_headless-8.3.134.dist-info/RECORD +272 -0
  3. dgenerate_ultralytics_headless-8.3.134.dist-info/WHEEL +5 -0
  4. dgenerate_ultralytics_headless-8.3.134.dist-info/entry_points.txt +3 -0
  5. dgenerate_ultralytics_headless-8.3.134.dist-info/licenses/LICENSE +661 -0
  6. dgenerate_ultralytics_headless-8.3.134.dist-info/top_level.txt +1 -0
  7. tests/__init__.py +22 -0
  8. tests/conftest.py +83 -0
  9. tests/test_cli.py +138 -0
  10. tests/test_cuda.py +215 -0
  11. tests/test_engine.py +131 -0
  12. tests/test_exports.py +236 -0
  13. tests/test_integrations.py +154 -0
  14. tests/test_python.py +694 -0
  15. tests/test_solutions.py +187 -0
  16. ultralytics/__init__.py +30 -0
  17. ultralytics/assets/bus.jpg +0 -0
  18. ultralytics/assets/zidane.jpg +0 -0
  19. ultralytics/cfg/__init__.py +1023 -0
  20. ultralytics/cfg/datasets/Argoverse.yaml +77 -0
  21. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  22. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  23. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  24. ultralytics/cfg/datasets/HomeObjects-3K.yaml +33 -0
  25. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  26. ultralytics/cfg/datasets/Objects365.yaml +443 -0
  27. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  28. ultralytics/cfg/datasets/VOC.yaml +106 -0
  29. ultralytics/cfg/datasets/VisDrone.yaml +77 -0
  30. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  31. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  32. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  33. ultralytics/cfg/datasets/coco-pose.yaml +42 -0
  34. ultralytics/cfg/datasets/coco.yaml +118 -0
  35. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  36. ultralytics/cfg/datasets/coco128.yaml +101 -0
  37. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  38. ultralytics/cfg/datasets/coco8-pose.yaml +26 -0
  39. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  40. ultralytics/cfg/datasets/coco8.yaml +101 -0
  41. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  42. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  43. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  44. ultralytics/cfg/datasets/dota8.yaml +35 -0
  45. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  46. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  47. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  48. ultralytics/cfg/datasets/open-images-v7.yaml +666 -0
  49. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  50. ultralytics/cfg/datasets/signature.yaml +21 -0
  51. ultralytics/cfg/datasets/tiger-pose.yaml +25 -0
  52. ultralytics/cfg/datasets/xView.yaml +155 -0
  53. ultralytics/cfg/default.yaml +127 -0
  54. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  55. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  56. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  57. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  58. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  59. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  60. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  61. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  62. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  63. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  64. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  65. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  66. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  68. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  69. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  70. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  71. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  74. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  75. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  76. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  77. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  78. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  79. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  80. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  81. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  82. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  83. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +45 -0
  84. ultralytics/cfg/models/v8/yoloe-v8.yaml +45 -0
  85. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  86. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  87. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  88. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  89. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  90. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  91. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  92. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  93. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  94. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  95. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  96. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  97. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  98. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  99. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  100. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  101. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  102. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  103. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  105. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  106. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  107. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  108. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  109. ultralytics/cfg/trackers/botsort.yaml +22 -0
  110. ultralytics/cfg/trackers/bytetrack.yaml +14 -0
  111. ultralytics/data/__init__.py +26 -0
  112. ultralytics/data/annotator.py +66 -0
  113. ultralytics/data/augment.py +2945 -0
  114. ultralytics/data/base.py +438 -0
  115. ultralytics/data/build.py +258 -0
  116. ultralytics/data/converter.py +754 -0
  117. ultralytics/data/dataset.py +834 -0
  118. ultralytics/data/loaders.py +676 -0
  119. ultralytics/data/scripts/download_weights.sh +18 -0
  120. ultralytics/data/scripts/get_coco.sh +61 -0
  121. ultralytics/data/scripts/get_coco128.sh +18 -0
  122. ultralytics/data/scripts/get_imagenet.sh +52 -0
  123. ultralytics/data/split.py +125 -0
  124. ultralytics/data/split_dota.py +325 -0
  125. ultralytics/data/utils.py +777 -0
  126. ultralytics/engine/__init__.py +1 -0
  127. ultralytics/engine/exporter.py +1519 -0
  128. ultralytics/engine/model.py +1156 -0
  129. ultralytics/engine/predictor.py +502 -0
  130. ultralytics/engine/results.py +1840 -0
  131. ultralytics/engine/trainer.py +853 -0
  132. ultralytics/engine/tuner.py +243 -0
  133. ultralytics/engine/validator.py +377 -0
  134. ultralytics/hub/__init__.py +168 -0
  135. ultralytics/hub/auth.py +137 -0
  136. ultralytics/hub/google/__init__.py +176 -0
  137. ultralytics/hub/session.py +446 -0
  138. ultralytics/hub/utils.py +248 -0
  139. ultralytics/models/__init__.py +9 -0
  140. ultralytics/models/fastsam/__init__.py +7 -0
  141. ultralytics/models/fastsam/model.py +61 -0
  142. ultralytics/models/fastsam/predict.py +181 -0
  143. ultralytics/models/fastsam/utils.py +24 -0
  144. ultralytics/models/fastsam/val.py +40 -0
  145. ultralytics/models/nas/__init__.py +7 -0
  146. ultralytics/models/nas/model.py +102 -0
  147. ultralytics/models/nas/predict.py +58 -0
  148. ultralytics/models/nas/val.py +39 -0
  149. ultralytics/models/rtdetr/__init__.py +7 -0
  150. ultralytics/models/rtdetr/model.py +63 -0
  151. ultralytics/models/rtdetr/predict.py +84 -0
  152. ultralytics/models/rtdetr/train.py +85 -0
  153. ultralytics/models/rtdetr/val.py +191 -0
  154. ultralytics/models/sam/__init__.py +6 -0
  155. ultralytics/models/sam/amg.py +260 -0
  156. ultralytics/models/sam/build.py +358 -0
  157. ultralytics/models/sam/model.py +170 -0
  158. ultralytics/models/sam/modules/__init__.py +1 -0
  159. ultralytics/models/sam/modules/blocks.py +1129 -0
  160. ultralytics/models/sam/modules/decoders.py +515 -0
  161. ultralytics/models/sam/modules/encoders.py +854 -0
  162. ultralytics/models/sam/modules/memory_attention.py +299 -0
  163. ultralytics/models/sam/modules/sam.py +1006 -0
  164. ultralytics/models/sam/modules/tiny_encoder.py +1002 -0
  165. ultralytics/models/sam/modules/transformer.py +351 -0
  166. ultralytics/models/sam/modules/utils.py +394 -0
  167. ultralytics/models/sam/predict.py +1605 -0
  168. ultralytics/models/utils/__init__.py +1 -0
  169. ultralytics/models/utils/loss.py +455 -0
  170. ultralytics/models/utils/ops.py +268 -0
  171. ultralytics/models/yolo/__init__.py +7 -0
  172. ultralytics/models/yolo/classify/__init__.py +7 -0
  173. ultralytics/models/yolo/classify/predict.py +88 -0
  174. ultralytics/models/yolo/classify/train.py +233 -0
  175. ultralytics/models/yolo/classify/val.py +215 -0
  176. ultralytics/models/yolo/detect/__init__.py +7 -0
  177. ultralytics/models/yolo/detect/predict.py +124 -0
  178. ultralytics/models/yolo/detect/train.py +217 -0
  179. ultralytics/models/yolo/detect/val.py +451 -0
  180. ultralytics/models/yolo/model.py +354 -0
  181. ultralytics/models/yolo/obb/__init__.py +7 -0
  182. ultralytics/models/yolo/obb/predict.py +66 -0
  183. ultralytics/models/yolo/obb/train.py +81 -0
  184. ultralytics/models/yolo/obb/val.py +283 -0
  185. ultralytics/models/yolo/pose/__init__.py +7 -0
  186. ultralytics/models/yolo/pose/predict.py +79 -0
  187. ultralytics/models/yolo/pose/train.py +154 -0
  188. ultralytics/models/yolo/pose/val.py +394 -0
  189. ultralytics/models/yolo/segment/__init__.py +7 -0
  190. ultralytics/models/yolo/segment/predict.py +113 -0
  191. ultralytics/models/yolo/segment/train.py +123 -0
  192. ultralytics/models/yolo/segment/val.py +428 -0
  193. ultralytics/models/yolo/world/__init__.py +5 -0
  194. ultralytics/models/yolo/world/train.py +119 -0
  195. ultralytics/models/yolo/world/train_world.py +176 -0
  196. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  197. ultralytics/models/yolo/yoloe/predict.py +169 -0
  198. ultralytics/models/yolo/yoloe/train.py +298 -0
  199. ultralytics/models/yolo/yoloe/train_seg.py +124 -0
  200. ultralytics/models/yolo/yoloe/val.py +191 -0
  201. ultralytics/nn/__init__.py +29 -0
  202. ultralytics/nn/autobackend.py +842 -0
  203. ultralytics/nn/modules/__init__.py +182 -0
  204. ultralytics/nn/modules/activation.py +53 -0
  205. ultralytics/nn/modules/block.py +1966 -0
  206. ultralytics/nn/modules/conv.py +712 -0
  207. ultralytics/nn/modules/head.py +880 -0
  208. ultralytics/nn/modules/transformer.py +713 -0
  209. ultralytics/nn/modules/utils.py +164 -0
  210. ultralytics/nn/tasks.py +1627 -0
  211. ultralytics/nn/text_model.py +351 -0
  212. ultralytics/solutions/__init__.py +41 -0
  213. ultralytics/solutions/ai_gym.py +116 -0
  214. ultralytics/solutions/analytics.py +252 -0
  215. ultralytics/solutions/config.py +106 -0
  216. ultralytics/solutions/distance_calculation.py +124 -0
  217. ultralytics/solutions/heatmap.py +127 -0
  218. ultralytics/solutions/instance_segmentation.py +84 -0
  219. ultralytics/solutions/object_blurrer.py +90 -0
  220. ultralytics/solutions/object_counter.py +195 -0
  221. ultralytics/solutions/object_cropper.py +84 -0
  222. ultralytics/solutions/parking_management.py +273 -0
  223. ultralytics/solutions/queue_management.py +93 -0
  224. ultralytics/solutions/region_counter.py +120 -0
  225. ultralytics/solutions/security_alarm.py +154 -0
  226. ultralytics/solutions/similarity_search.py +172 -0
  227. ultralytics/solutions/solutions.py +724 -0
  228. ultralytics/solutions/speed_estimation.py +110 -0
  229. ultralytics/solutions/streamlit_inference.py +196 -0
  230. ultralytics/solutions/templates/similarity-search.html +160 -0
  231. ultralytics/solutions/trackzone.py +88 -0
  232. ultralytics/solutions/vision_eye.py +68 -0
  233. ultralytics/trackers/__init__.py +7 -0
  234. ultralytics/trackers/basetrack.py +124 -0
  235. ultralytics/trackers/bot_sort.py +260 -0
  236. ultralytics/trackers/byte_tracker.py +480 -0
  237. ultralytics/trackers/track.py +125 -0
  238. ultralytics/trackers/utils/__init__.py +1 -0
  239. ultralytics/trackers/utils/gmc.py +376 -0
  240. ultralytics/trackers/utils/kalman_filter.py +493 -0
  241. ultralytics/trackers/utils/matching.py +157 -0
  242. ultralytics/utils/__init__.py +1435 -0
  243. ultralytics/utils/autobatch.py +106 -0
  244. ultralytics/utils/autodevice.py +174 -0
  245. ultralytics/utils/benchmarks.py +695 -0
  246. ultralytics/utils/callbacks/__init__.py +5 -0
  247. ultralytics/utils/callbacks/base.py +234 -0
  248. ultralytics/utils/callbacks/clearml.py +153 -0
  249. ultralytics/utils/callbacks/comet.py +552 -0
  250. ultralytics/utils/callbacks/dvc.py +205 -0
  251. ultralytics/utils/callbacks/hub.py +108 -0
  252. ultralytics/utils/callbacks/mlflow.py +138 -0
  253. ultralytics/utils/callbacks/neptune.py +140 -0
  254. ultralytics/utils/callbacks/raytune.py +43 -0
  255. ultralytics/utils/callbacks/tensorboard.py +132 -0
  256. ultralytics/utils/callbacks/wb.py +185 -0
  257. ultralytics/utils/checks.py +897 -0
  258. ultralytics/utils/dist.py +119 -0
  259. ultralytics/utils/downloads.py +499 -0
  260. ultralytics/utils/errors.py +43 -0
  261. ultralytics/utils/export.py +219 -0
  262. ultralytics/utils/files.py +221 -0
  263. ultralytics/utils/instance.py +499 -0
  264. ultralytics/utils/loss.py +813 -0
  265. ultralytics/utils/metrics.py +1356 -0
  266. ultralytics/utils/ops.py +885 -0
  267. ultralytics/utils/patches.py +143 -0
  268. ultralytics/utils/plotting.py +1011 -0
  269. ultralytics/utils/tal.py +416 -0
  270. ultralytics/utils/torch_utils.py +990 -0
  271. ultralytics/utils/triton.py +116 -0
  272. ultralytics/utils/tuner.py +159 -0
@@ -0,0 +1,1966 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+ """Block modules."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ultralytics.utils.torch_utils import fuse_conv_and_bn
9
+
10
+ from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
11
+ from .transformer import TransformerBlock
12
+
13
# Public names re-exported by `from ultralytics.nn.modules.block import *`.
__all__ = (
    "DFL", "HGBlock", "HGStem", "SPP", "SPPF",
    "C1", "C2", "C3", "C2f", "C2fAttn",
    "ImagePoolingAttn", "ContrastiveHead", "BNContrastiveHead", "C3x", "C3TR",
    "C3Ghost", "GhostBottleneck", "Bottleneck", "BottleneckCSP", "Proto",
    "RepC3", "ResNetLayer", "RepNCSPELAN4", "ELAN1", "ADown",
    "AConv", "SPPELAN", "CBFuse", "CBLinear", "C3k2",
    "C2fPSA", "C2PSA", "RepVGGDW", "CIB", "C2fCIB",
    "Attention", "PSA", "SCDown", "TorchVision",
)
54
+
55
+
56
class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Turns a discrete distribution over `c1` bins (one distribution per box side) into its expected bin index.

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """
        Initialize a frozen 1x1 convolution whose weights are the bin indices 0..c1-1.

        Args:
            c1 (int): Number of distribution bins per box side.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        # Copy the bin indices straight into the frozen weight; wrapping them in an nn.Parameter first is unnecessary.
        self.conv.weight.data[:] = torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)
        self.c1 = c1

    def forward(self, x):
        """
        Decode per-side bin distributions into expected values.

        Args:
            x (torch.Tensor): Logits of shape (b, 4 * c1, a); channel index is side * c1 + bin.

        Returns:
            (torch.Tensor): Tensor of shape (b, 4, a) with the softmax-weighted mean bin index for each of the 4 sides.
        """
        b, _, a = x.shape  # batch, channels, anchors
        # (b, 4*c1, a) -> (b, 4, c1, a) -> (b, c1, 4, a): softmax and the 1x1 conv then act over the bin axis.
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
76
+
77
+
78
class Proto(nn.Module):
    """Mask-prototype branch used by Ultralytics YOLO segmentation models."""

    def __init__(self, c1, c_=256, c2=32):
        """
        Build the prototype branch: Conv(k=3) -> 2x transposed-conv upsample -> Conv(k=3) -> output Conv.

        Args:
            c1 (int): Input channels.
            c_ (int): Intermediate channels.
            c2 (int): Output channels (number of protos).
        """
        super().__init__()
        hidden = c_
        self.cv1 = Conv(c1, hidden, k=3)
        # kernel 2 / stride 2 / padding 0 doubles height and width (learned alternative to nearest upsampling)
        self.upsample = nn.ConvTranspose2d(hidden, hidden, 2, 2, 0, bias=True)
        self.cv2 = Conv(hidden, hidden, k=3)
        self.cv3 = Conv(hidden, c2)

    def forward(self, x):
        """Return prototype masks computed from `x` at twice its spatial resolution."""
        feat = self.cv1(x)
        feat = self.upsample(feat)
        feat = self.cv2(feat)
        return self.cv3(feat)
99
+
100
+
101
class HGStem(nn.Module):
    """
    Stem block of PPHGNetV2: five convolutions plus a parallel max-pool branch.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
    """

    def __init__(self, c1, cm, c2):
        """
        Build the PPHGNetV2 stem.

        Args:
            c1 (int): Input channels.
            cm (int): Middle channels.
            c2 (int): Output channels.
        """
        super().__init__()
        half = cm // 2
        self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
        self.stem2a = Conv(cm, half, 2, 1, 0, act=nn.ReLU())
        self.stem2b = Conv(half, cm, 2, 1, 0, act=nn.ReLU())
        self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())  # input is the concat of pool and conv branches
        self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)

    def forward(self, x):
        """Run the stem: shared first conv, then pooled and convolved branches merged by concatenation."""
        # F.pad with [0, 1, 0, 1] adds one column on the right and one row at the bottom.
        stem = F.pad(self.stem1(x), [0, 1, 0, 1])
        conv_branch = self.stem2b(F.pad(self.stem2a(stem), [0, 1, 0, 1]))
        pool_branch = self.pool(stem)
        merged = torch.cat([pool_branch, conv_branch], dim=1)
        return self.stem4(self.stem3(merged))
137
+
138
+
139
class HGBlock(nn.Module):
    """
    HG_Block of PPHGNetV2 with 2 convolutions and LightConv.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
    """

    # Sentinel for "act not supplied". A module-instance default (act=nn.ReLU()) would be evaluated once at
    # class-definition time and the same object would be shared by every HGBlock ever created.
    _DEFAULT_ACT = object()

    def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=_DEFAULT_ACT):
        """
        Initialize HGBlock with specified parameters.

        Args:
            c1 (int): Input channels.
            cm (int): Middle channels.
            c2 (int): Output channels.
            k (int): Kernel size.
            n (int): Number of LightConv or Conv blocks.
            lightconv (bool): Whether to use LightConv.
            shortcut (bool): Whether to use shortcut connection (only active when c1 == c2).
            act (nn.Module, optional): Activation used by every conv in this block. When omitted, a fresh
                nn.ReLU() is created for this instance; any explicitly passed value (including None) is forwarded
                to the convs unchanged.
        """
        super().__init__()
        if act is HGBlock._DEFAULT_ACT:
            act = nn.ReLU()
        block = LightConv if lightconv else Conv
        self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
        self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act)  # squeeze conv over the concatenated features
        self.ec = Conv(c2 // 2, c2, 1, 1, act=act)  # excitation conv back to c2 channels
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Concatenate the input with every intermediate block output, squeeze/excite, and optionally add x."""
        y = [x]
        y.extend(m(y[-1]) for m in self.m)  # each block consumes the previous block's output
        y = self.ec(self.sc(torch.cat(y, 1)))
        return y + x if self.add else y
173
+
174
+
175
class SPP(nn.Module):
    """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""

    def __init__(self, c1, c2, k=(5, 9, 13)):
        """
        Build an SPP layer that pools the reduced input with several kernel sizes in parallel.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            k (Tuple[int, int, int]): Kernel sizes for max pooling.
        """
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        # concat of the un-pooled tensor plus one tensor per kernel size
        self.cv2 = Conv(hidden * (len(k) + 1), c2, 1, 1)
        pools = [nn.MaxPool2d(kernel_size=size, stride=1, padding=size // 2) for size in k]
        self.m = nn.ModuleList(pools)

    def forward(self, x):
        """Reduce channels, max-pool in parallel, concatenate with the unpooled features and fuse."""
        x = self.cv1(x)
        branches = [x]
        for pool in self.m:
            branches.append(pool(x))
        return self.cv2(torch.cat(branches, 1))
197
+
198
+
199
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""

    def __init__(self, c1, c2, k=5):
        """
        Build an SPPF layer that applies one max-pool three times in sequence.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            k (int): Kernel size.

        Notes:
            With the default k=5 this module is equivalent to SPP(k=(5, 9, 13)).
        """
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)  # reduced input + 3 successive pooled maps
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Pool the reduced input three times in a chain and fuse all four maps."""
        out = self.cv1(x)
        feats = [out]
        for _ in range(3):
            out = self.m(out)
            feats.append(out)
        return self.cv2(torch.cat(feats, 1))
225
+
226
+
227
class C1(nn.Module):
    """CSP Bottleneck with 1 convolution."""

    def __init__(self, c1, c2, n=1):
        """
        Build the block: an entry Conv followed by n stacked Conv(k=3) layers with a residual sum.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of convolutions.
        """
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        layers = [Conv(c2, c2, 3) for _ in range(n)]
        self.m = nn.Sequential(*layers)

    def forward(self, x):
        """Return m(cv1(x)) + cv1(x)."""
        base = self.cv1(x)
        return self.m(base) + base
247
+
248
+
249
class C2(nn.Module):
    """CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        Build a CSP block that processes half of the expanded channels through Bottlenecks.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels per split
        width = self.c
        self.cv1 = Conv(c1, 2 * width, 1, 1)
        self.cv2 = Conv(2 * width, c2, 1)
        self.m = nn.Sequential(
            *[Bottleneck(width, width, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)]
        )

    def forward(self, x):
        """Split cv1 output in two, run Bottlenecks on the first half, then fuse both halves."""
        first, second = self.cv1(x).chunk(2, 1)
        return self.cv2(torch.cat((self.m(first), second), 1))
275
+
276
+
277
class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """
        Build a C2f block whose Bottlenecks are chained and all of their outputs concatenated.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels per split
        width = self.c
        self.cv1 = Conv(c1, 2 * width, 1, 1)
        self.cv2 = Conv((2 + n) * width, c2, 1)  # two cv1 halves + one output per Bottleneck
        self.m = nn.ModuleList([Bottleneck(width, width, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)])

    def forward(self, x):
        """Chain the Bottlenecks on the second cv1 half and fuse every intermediate map."""
        outputs = list(self.cv1(x).chunk(2, 1))
        for block in self.m:
            outputs.append(block(outputs[-1]))
        return self.cv2(torch.cat(outputs, 1))

    def forward_split(self, x):
        """Same as forward(), but splits the cv1 output with split() instead of chunk()."""
        outputs = list(self.cv1(x).split((self.c, self.c), 1))
        for block in self.m:
            outputs.append(block(outputs[-1]))
        return self.cv2(torch.cat(outputs, 1))
310
+
311
+
312
class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        Build a C3 block with a Bottleneck path, a bypass path and a fusing conv.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)
        self.m = nn.Sequential(*[Bottleneck(hidden, hidden, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)])

    def forward(self, x):
        """Concatenate the Bottleneck path and the bypass path, then fuse with cv3."""
        deep = self.m(self.cv1(x))
        bypass = self.cv2(x)
        return self.cv3(torch.cat((deep, bypass), 1))
337
+
338
+
339
class C3x(C3):
    """C3 module whose Bottlenecks use cross-shaped (1x3 then 3x1) kernels."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        Build C3, then replace its Bottleneck stack with cross-convolution Bottlenecks.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        self.c_ = int(c2 * e)
        cross = ((1, 3), (3, 1))
        self.m = nn.Sequential(*[Bottleneck(self.c_, self.c_, shortcut, g, k=cross, e=1) for _ in range(n)])
357
+
358
+
359
class RepC3(nn.Module):
    """Rep C3."""

    def __init__(self, c1, c2, n=3, e=1.0):
        """
        Build a C3-style block whose main path is a stack of RepConv layers summed with a bypass conv.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of RepConv blocks.
            e (float): Expansion ratio.
        """
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.m = nn.Sequential(*(RepConv(hidden, hidden) for _ in range(n)))
        # only project when the hidden width differs from the requested output width
        self.cv3 = nn.Identity() if hidden == c2 else Conv(hidden, c2, 1, 1)

    def forward(self, x):
        """Sum the RepConv path and the bypass path, then project to c2 channels."""
        fused = self.m(self.cv1(x)) + self.cv2(x)
        return self.cv3(fused)
382
+
383
+
384
class C3TR(C3):
    """C3 module whose Bottleneck stack is replaced by a TransformerBlock."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        Build C3, then swap its inner stack for a TransformerBlock of width int(c2 * e).

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Transformer blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)
        # NOTE(review): assumes TransformerBlock(c1, c2, num_heads, num_layers) — confirm in .transformer
        self.m = TransformerBlock(hidden, hidden, 4, n)
402
+
403
+
404
class C3Ghost(C3):
    """C3 module whose Bottleneck stack is made of GhostBottleneck blocks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        Build C3, then replace its inner stack with n GhostBottleneck blocks.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Ghost bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)
        self.m = nn.Sequential(*[GhostBottleneck(hidden, hidden) for _ in range(n)])
422
+
423
+
424
class GhostBottleneck(nn.Module):
    """Ghost Bottleneck https://github.com/huawei-noah/Efficient-AI-Backbones."""

    def __init__(self, c1, c2, k=3, s=1):
        """
        Build a Ghost Bottleneck with an optional stride-2 depthwise downsampling stage.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            k (int): Kernel size.
            s (int): Stride; only s == 2 enables the depthwise stage and the projected shortcut.

        Notes:
            When s != 2 the shortcut is an identity, so the residual sum requires c1 == c2.
        """
        super().__init__()
        c_ = c2 // 2
        downsample = s == 2
        self.conv = nn.Sequential(
            GhostConv(c1, c_, 1, 1),  # pw
            DWConv(c_, c_, k, s, act=False) if downsample else nn.Identity(),  # dw
            GhostConv(c_, c2, 1, 1, act=False),  # pw-linear
        )
        if downsample:
            self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False))
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return the element-wise sum of the ghost path and the shortcut path."""
        return self.conv(x) + self.shortcut(x)
451
+
452
+
453
class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """
        Build two stacked convs with an optional identity shortcut.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            shortcut (bool): Whether to use shortcut connection (only active when c1 == c2).
            g (int): Groups for the second convolution.
            k (Tuple[int, int]): Kernel sizes for cv1 and cv2.
            e (float): Expansion ratio for the hidden width.
        """
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, k[0], 1)
        self.cv2 = Conv(hidden, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply cv1 then cv2, adding the input back when the shortcut is active."""
        out = self.cv2(self.cv1(x))
        return x + out if self.add else out
477
+
478
+
479
class BottleneckCSP(nn.Module):
    """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        Build a CSP block with a Bottleneck path, a plain-conv bypass, and BN + SiLU on their concatenation.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = nn.Conv2d(c1, hidden, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(hidden, hidden, 1, 1, bias=False)
        self.cv4 = Conv(2 * hidden, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * hidden)  # normalizes cat(cv3 path, cv2 path)
        self.act = nn.SiLU()
        self.m = nn.Sequential(*[Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        """Concatenate both CSP paths, apply BN + SiLU, and fuse with cv4."""
        main = self.cv3(self.m(self.cv1(x)))
        bypass = self.cv2(x)
        merged = torch.cat((main, bypass), 1)
        return self.cv4(self.act(self.bn(merged)))
509
+
510
+
511
class ResNetBlock(nn.Module):
    """ResNet block with standard convolution layers."""

    def __init__(self, c1, c2, s=1, e=4):
        """
        Build a 1-3-1 bottleneck residual block producing e * c2 channels.

        Args:
            c1 (int): Input channels.
            c2 (int): Bottleneck channels.
            s (int): Stride of the middle convolution (and of the projection shortcut).
            e (int): Expansion factor; output channels are e * c2.
        """
        super().__init__()
        c3 = e * c2
        self.cv1 = Conv(c1, c2, k=1, s=1, act=True)
        self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True)
        self.cv3 = Conv(c2, c3, k=1, act=False)
        # project the input whenever stride or channel count would otherwise mismatch the residual
        needs_projection = s != 1 or c1 != c3
        self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if needs_projection else nn.Identity()

    def forward(self, x):
        """Return relu(residual_path(x) + shortcut(x))."""
        residual = self.cv3(self.cv2(self.cv1(x)))
        return F.relu(residual + self.shortcut(x))
534
+
535
+
536
class ResNetLayer(nn.Module):
    """ResNet layer with multiple ResNet blocks."""

    def __init__(self, c1, c2, s=1, is_first=False, n=1, e=4):
        """
        Build either the ResNet stem or a stack of n ResNetBlocks.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels (bottleneck width for the block stack).
            s (int): Stride of the first ResNetBlock (unused for the stem).
            is_first (bool): If True, build the stem (Conv k=7 s=2 + MaxPool k=3 s=2); n and e are then unused.
            n (int): Number of ResNet blocks.
            e (int): Expansion ratio.
        """
        super().__init__()
        self.is_first = is_first
        if is_first:
            layers = [Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]
        else:
            # first block may downsample / change width; the rest keep stride 1 on e * c2 inputs
            layers = [ResNetBlock(c1, c2, s, e=e)] + [ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)]
        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem or block stack on `x`."""
        return self.layer(x)
566
+
567
+
568
class MaxSigmoidAttnBlock(nn.Module):
    """Max-sigmoid attention: gates projected image features by their best match against guide embeddings."""

    def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False):
        """
        Initialize MaxSigmoidAttnBlock.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels, split evenly across ``nh`` heads.
            nh (int): Number of heads.
            ec (int): Embedding channels. The reshapes in ``forward`` require ``ec == nh * (c2 // nh)``.
            gc (int): Guide channels.
            scale (bool): Use a learnable per-head scale instead of the constant 1.0.
        """
        super().__init__()
        self.nh = nh
        self.hc = c2 // nh  # channels per head
        # Project x into the embedding space only when its width differs from ec.
        self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None
        self.gl = nn.Linear(gc, ec)
        self.bias = nn.Parameter(torch.zeros(nh))
        self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)
        self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0

    def forward(self, x, guide):
        """
        Forward pass of MaxSigmoidAttnBlock.

        Args:
            x (torch.Tensor): Image features of shape (bs, c1, h, w).
            guide (torch.Tensor): Guide features of shape (bs, n, gc).

        Returns:
            (torch.Tensor): Gated features of shape (bs, c2, h, w).
        """
        batch, _, height, width = x.shape

        guide_emb = self.gl(guide)
        guide_emb = guide_emb.view(batch, guide_emb.shape[1], self.nh, self.hc)
        embed = x if self.ec is None else self.ec(x)
        embed = embed.view(batch, self.nh, self.hc, height, width)

        # Per head and pixel, keep the similarity to the best-matching guide, then turn it into a gate.
        weights = torch.einsum("bmchw,bnmc->bmhwn", embed, guide_emb).max(dim=-1)[0]
        weights = weights / (self.hc**0.5)
        weights = (weights + self.bias[None, :, None, None]).sigmoid() * self.scale

        out = self.proj_conv(x).view(batch, self.nh, -1, height, width)
        out = out * weights.unsqueeze(2)
        return out.view(batch, -1, height, width)
620
+
621
+
622
class C2fAttn(nn.Module):
    """C2f module whose last branch is a MaxSigmoidAttnBlock conditioned on a guide tensor."""

    def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5):
        """
        Initialize C2f module with attention mechanism.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            ec (int): Embedding channels for attention.
            nh (int): Number of heads for attention.
            gc (int): Guide channels for attention.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels per branch
        hidden = self.c
        self.cv1 = Conv(c1, 2 * hidden, 1, 1)
        # cv1 gives 2 branches, the bottlenecks add n more and attention one more: (3 + n) branches in total.
        self.cv2 = Conv((3 + n) * hidden, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(hidden, hidden, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
        self.attn = MaxSigmoidAttnBlock(hidden, hidden, gc=gc, ec=ec, nh=nh)

    def _aggregate(self, branches, guide):
        """Chain every bottleneck and the attention block onto ``branches``, then fuse all of them with cv2."""
        for block in self.m:
            branches.append(block(branches[-1]))
        branches.append(self.attn(branches[-1], guide))
        return self.cv2(torch.cat(branches, 1))

    def forward(self, x, guide):
        """
        Forward pass through C2f layer with attention.

        Args:
            x (torch.Tensor): Input tensor.
            guide (torch.Tensor): Guide tensor for attention.

        Returns:
            (torch.Tensor): Output tensor after processing.
        """
        return self._aggregate(list(self.cv1(x).chunk(2, 1)), guide)

    def forward_split(self, x, guide):
        """
        Same as ``forward`` but splits cv1's output with split() instead of chunk().

        Args:
            x (torch.Tensor): Input tensor.
            guide (torch.Tensor): Guide tensor for attention.

        Returns:
            (torch.Tensor): Output tensor after processing.
        """
        return self._aggregate(list(self.cv1(x).split((self.c, self.c), 1)), guide)
678
+
679
+
680
class ImagePoolingAttn(nn.Module):
    """ImagePoolingAttn: enhance text embeddings with image-aware information via pooled cross-attention."""

    def __init__(self, ec=256, ch=(), ct=512, nh=8, k=3, scale=False):
        """
        Initialize ImagePoolingAttn module.

        Args:
            ec (int): Embedding channels.
            ch (tuple): Channel dimensions of the input feature maps (one entry per map).
            ct (int): Channel dimension of the text embeddings.
            nh (int): Number of attention heads.
            k (int): Each feature map is adaptively max-pooled to a k x k grid.
            scale (bool): Use a learnable residual scale initialized to 0 instead of the constant 1.0.
        """
        super().__init__()

        nf = len(ch)
        self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec))
        self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
        self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
        self.proj = nn.Linear(ec, ct)
        self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0
        self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch])
        self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)])
        self.ec = ec
        self.nh = nh
        self.nf = nf
        self.hc = ec // nh
        self.k = k

    def forward(self, x, text):
        """
        Forward pass of ImagePoolingAttn.

        Args:
            x (List[torch.Tensor]): ``nf`` feature maps; the i-th has ``ch[i]`` channels.
            text (torch.Tensor): Text embeddings of shape (bs, n, ct).

        Returns:
            (torch.Tensor): Enhanced text embeddings, same shape as ``text``.
        """
        bs = x[0].shape[0]
        assert len(x) == self.nf, f"expected {self.nf} feature maps, got {len(x)}"
        num_patches = self.k**2
        # Project each map to ec channels and pool it to k*k patches: (bs, ec, k*k) per map.
        pooled = [
            pool(proj(feat)).view(bs, -1, num_patches)
            for feat, proj, pool in zip(x, self.projections, self.im_pools)
        ]
        tokens = torch.cat(pooled, dim=-1).transpose(1, 2)  # (bs, nf * k*k, ec)

        queries = self.query(text).reshape(bs, -1, self.nh, self.hc)
        keys = self.key(tokens).reshape(bs, -1, self.nh, self.hc)
        values = self.value(tokens).reshape(bs, -1, self.nh, self.hc)

        # Scaled dot-product attention per head: text tokens attend over image patches.
        aw = torch.einsum("bnmc,bkmc->bmnk", queries, keys)
        aw = aw / (self.hc**0.5)
        aw = F.softmax(aw, dim=-1)

        out = torch.einsum("bmnk,bkmc->bnmc", aw, values)
        out = self.proj(out.reshape(bs, -1, self.ec))
        return out * self.scale + text
743
+
744
+
745
class ContrastiveHead(nn.Module):
    """Region-text similarity head: cosine similarity with a learnable temperature and bias."""

    def __init__(self):
        """Create the learnable bias and log-temperature."""
        super().__init__()
        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
        self.bias = nn.Parameter(torch.tensor([-10.0]))
        # Stored in log space; exp(logit_scale) starts at 1 / 0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())

    def forward(self, x, w):
        """
        Score every pixel of ``x`` against every text embedding in ``w``.

        Args:
            x (torch.Tensor): Image features of shape (b, c, h, w).
            w (torch.Tensor): Text features of shape (b, k, c).

        Returns:
            (torch.Tensor): Similarity logits of shape (b, k, h, w).
        """
        img = F.normalize(x, dim=1, p=2)
        txt = F.normalize(w, dim=-1, p=2)
        cosine = torch.einsum("bchw,bkc->bkhw", img, txt)
        return cosine * self.logit_scale.exp() + self.bias
770
+
771
+
772
class BNContrastiveHead(nn.Module):
    """
    Contrastive head that normalizes image features with BatchNorm2d (instead of L2) before matching text features.

    Args:
        embed_dims (int): Embed dimensions of text and image features.
    """

    def __init__(self, embed_dims: int):
        """
        Initialize BNContrastiveHead.

        Args:
            embed_dims (int): Channel count of the image features (the BatchNorm2d width).
        """
        super().__init__()
        self.norm = nn.BatchNorm2d(embed_dims)
        # NOTE: use -10.0 to keep the init cls loss consistency with other losses
        self.bias = nn.Parameter(torch.tensor([-10.0]))
        # An initial log-scale of -1.0 is used for stability.
        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

    def fuse(self):
        """Drop the norm/bias/scale parameters and make ``forward`` return the image features unchanged."""
        for attr in ("norm", "bias", "logit_scale"):
            delattr(self, attr)
        self.forward = self.forward_fuse

    def forward_fuse(self, x, w):
        """
        Return ``x`` untouched and ignore ``w``; this becomes ``forward`` after ``fuse()``.

        NOTE(review): presumably the text matching is folded into another layer by the caller — confirm.
        """
        return x

    def forward(self, x, w):
        """
        Score every pixel of ``x`` against every text embedding in ``w``.

        Args:
            x (torch.Tensor): Image features of shape (b, c, h, w).
            w (torch.Tensor): Text features of shape (b, k, c).

        Returns:
            (torch.Tensor): Similarity logits of shape (b, k, h, w).
        """
        feats = self.norm(x)
        txt = F.normalize(w, dim=-1, p=2)
        logits = torch.einsum("bchw,bkc->bkhw", feats, txt)
        return logits * self.logit_scale.exp() + self.bias
825
+
826
+
827
class RepBottleneck(Bottleneck):
    """Bottleneck whose first convolution is a re-parameterizable RepConv."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """
        Initialize RepBottleneck.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            shortcut (bool): Whether to use shortcut connection.
            g (int): Groups for convolutions.
            k (Tuple[int, int]): Kernel sizes for the two convolutions.
            e (float): Expansion ratio.
        """
        super().__init__(c1, c2, shortcut, g, k, e)
        hidden = int(c2 * e)
        # Replace the parent's first convolution; everything else is inherited unchanged.
        self.cv1 = RepConv(c1, hidden, k[0], 1)
845
+
846
+
847
class RepCSP(C3):
    """C3-style Cross Stage Partial block built from RepBottleneck units."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """
        Initialize RepCSP layer.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of RepBottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)
        # Swap the inherited bottleneck stack for RepBottlenecks of the same width.
        blocks = [RepBottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*blocks)
865
+
866
+
867
class RepNCSPELAN4(nn.Module):
    """CSP-ELAN block: split, two chained RepCSP+Conv branches, then concatenate everything and fuse."""

    def __init__(self, c1, c2, c3, c4, n=1):
        """
        Initialize CSP-ELAN layer.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            c3 (int): Intermediate channels (split in half after cv1).
            c4 (int): Intermediate channels for RepCSP.
            n (int): Number of RepCSP blocks.
        """
        super().__init__()
        self.c = c3 // 2
        self.cv1 = Conv(c1, c3, 1, 1)
        self.cv2 = nn.Sequential(RepCSP(self.c, c4, n), Conv(c4, c4, 3, 1))
        self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1))
        # Concatenation holds both halves of cv1 (c3 channels) plus the two c4-wide branch outputs.
        self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)

    def forward(self, x):
        """Forward pass through RepNCSPELAN4 layer."""
        parts = list(self.cv1(x).chunk(2, 1))
        for branch in (self.cv2, self.cv3):
            parts.append(branch(parts[-1]))
        return self.cv4(torch.cat(parts, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        parts = list(self.cv1(x).split((self.c, self.c), 1))
        for branch in (self.cv2, self.cv3):
            parts.append(branch(parts[-1]))
        return self.cv4(torch.cat(parts, 1))
899
+
900
+
901
class ELAN1(RepNCSPELAN4):
    """Lightweight ELAN: the RepNCSPELAN4 topology with both branches reduced to single 3x3 Convs."""

    def __init__(self, c1, c2, c3, c4):
        """
        Initialize ELAN1 layer.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            c3 (int): Intermediate channels.
            c4 (int): Intermediate channels for convolutions.
        """
        super().__init__(c1, c2, c3, c4)
        # Every submodule from the parent is replaced under the same name, so the inherited forward() still works.
        half = c3 // 2
        self.c = half
        self.cv1 = Conv(c1, c3, 1, 1)
        self.cv2 = Conv(half, c4, 3, 1)
        self.cv3 = Conv(c4, c4, 3, 1)
        self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
920
+
921
+
922
class AConv(nn.Module):
    """Downsampling block: a 2x2 stride-1 average pool followed by a stride-2 3x3 Conv."""

    def __init__(self, c1, c2):
        """
        Initialize AConv module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
        """
        super().__init__()
        self.cv1 = Conv(c1, c2, 3, 2, 1)

    def forward(self, x):
        """Smooth with a 2x2 average pool, then downsample with cv1."""
        smoothed = F.avg_pool2d(x, kernel_size=2, stride=1, padding=0, ceil_mode=False, count_include_pad=True)
        return self.cv1(smoothed)
940
+
941
+
942
class ADown(nn.Module):
    """Downsampling block that splits channels into a strided-conv path and a max-pool + 1x1 conv path."""

    def __init__(self, c1, c2):
        """
        Initialize ADown module.

        Args:
            c1 (int): Input channels (split in half between the two paths).
            c2 (int): Output channels (each path contributes ``c2 // 2``).
        """
        super().__init__()
        self.c = c2 // 2
        self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
        self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)

    def forward(self, x):
        """Average-pool, split channels, downsample each half differently, and concatenate."""
        x = F.avg_pool2d(x, kernel_size=2, stride=1, padding=0, ceil_mode=False, count_include_pad=True)
        left, right = x.chunk(2, 1)
        left = self.cv1(left)
        right = self.cv2(F.max_pool2d(right, kernel_size=3, stride=2, padding=1))
        return torch.cat((left, right), 1)
966
+
967
+
968
class SPPELAN(nn.Module):
    """SPP-ELAN: a 1x1 Conv followed by three chained same-size max-pools, all concatenated and fused."""

    def __init__(self, c1, c2, c3, k=5):
        """
        Initialize SPP-ELAN block.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            c3 (int): Intermediate channels.
            k (int): Kernel size for max pooling.
        """
        super().__init__()
        self.c = c3
        self.cv1 = Conv(c1, c3, 1, 1)
        pad = k // 2  # keeps the spatial size unchanged for odd k
        self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=pad)
        self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=pad)
        self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=pad)
        self.cv5 = Conv(4 * c3, c2, 1, 1)

    def forward(self, x):
        """Forward pass through SPPELAN layer."""
        feats = [self.cv1(x)]
        for pool in (self.cv2, self.cv3, self.cv4):
            feats.append(pool(feats[-1]))
        return self.cv5(torch.cat(feats, 1))
994
+
995
+
996
class CBLinear(nn.Module):
    """A single convolution whose output is split into several channel groups."""

    def __init__(self, c1, c2s, k=1, s=1, p=None, g=1):
        """
        Initialize CBLinear module.

        Args:
            c1 (int): Input channels.
            c2s (List[int]): Channel count of each output group.
            k (int): Kernel size.
            s (int): Stride.
            p (int | None): Padding (resolved through ``autopad`` when None).
            g (int): Groups.
        """
        super().__init__()
        self.c2s = c2s
        total_out = sum(c2s)
        self.conv = nn.Conv2d(c1, total_out, k, s, autopad(k, p), groups=g, bias=True)

    def forward(self, x):
        """Return a tuple of tensors, one per entry of ``c2s``."""
        out = self.conv(x)
        return out.split(self.c2s, dim=1)
1018
+
1019
+
1020
class CBFuse(nn.Module):
    """Pick one tensor from each earlier group, resize it to the last tensor's size, and sum everything."""

    def __init__(self, idx):
        """
        Initialize CBFuse module.

        Args:
            idx (List[int]): For each entry of ``xs[:-1]``, which of its tensors to use.
        """
        super().__init__()
        self.idx = idx

    def forward(self, xs):
        """
        Forward pass through CBFuse layer.

        Args:
            xs (List): Indexable groups of tensors followed by a final target tensor; must be a list.

        Returns:
            (torch.Tensor): Sum of the selected, nearest-resized tensors and the target tensor.
        """
        target_hw = xs[-1].shape[2:]
        resized = []
        for i, group in enumerate(xs[:-1]):
            resized.append(F.interpolate(group[self.idx[i]], size=target_hw, mode="nearest"))
        return torch.stack(resized + xs[-1:]).sum(dim=0)
1046
+
1047
+
1048
class C3f(nn.Module):
    """Faster CSP bottleneck: two parallel 1x1 projections, chained bottlenecks, then concatenate and fuse."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """
        Initialize CSP bottleneck layer with two convolutions.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
        """
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        # Two projections plus one output per bottleneck: (2 + n) concatenated branches.
        self.cv3 = Conv((2 + n) * hidden, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(hidden, hidden, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        """Forward pass through C3f layer."""
        outputs = [self.cv2(x), self.cv1(x)]
        for block in self.m:
            outputs.append(block(outputs[-1]))
        return self.cv3(torch.cat(outputs, 1))
1075
+
1076
+
1077
class C3k2(C2f):
    """C2f variant whose inner blocks are either C3k modules or plain Bottlenecks."""

    def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True):
        """
        Initialize C3k2 module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of blocks.
            c3k (bool): Use C3k blocks instead of Bottlenecks.
            e (float): Expansion ratio.
            g (int): Groups for convolutions.
            shortcut (bool): Whether to use shortcut connections.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        blocks = []
        for _ in range(n):
            if c3k:
                blocks.append(C3k(self.c, self.c, 2, shortcut, g))
            else:
                blocks.append(Bottleneck(self.c, self.c, shortcut, g))
        self.m = nn.ModuleList(blocks)
1097
+
1098
+
1099
class C3k(C3):
    """C3 block whose bottlenecks use a configurable square kernel size."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
        """
        Initialize C3k module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of Bottleneck blocks.
            shortcut (bool): Whether to use shortcut connections.
            g (int): Groups for convolutions.
            e (float): Expansion ratio.
            k (int): Kernel size used by both convolutions of every bottleneck.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)
        blocks = [Bottleneck(hidden, hidden, shortcut, g, k=(k, k), e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*blocks)
1119
+
1120
+
1121
class RepVGGDW(torch.nn.Module):
    """
    Re-parameterizable depthwise block: parallel 7x7 and 3x3 depthwise Convs whose outputs are summed, then SiLU.

    ``fuse()`` merges both branches into a single 7x7 depthwise convolution for inference.
    """

    def __init__(self, ed) -> None:
        """
        Initialize RepVGGDW module.

        Args:
            ed (int): Input and output channels; also the group count, so both convolutions are depthwise.
        """
        super().__init__()
        self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
        self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
        self.dim = ed
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Training-time forward pass: sum of both depthwise branches, followed by SiLU.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        return self.act(self.conv(x) + self.conv1(x))

    def forward_fuse(self, x):
        """
        Forward pass after ``fuse()``: the single merged depthwise convolution followed by SiLU.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        return self.act(self.conv(x))

    @torch.no_grad()
    def fuse(self):
        """
        Merge both branches (including their BatchNorms) into one depthwise convolution.

        Afterwards ``forward`` is routed to ``forward_fuse``, so the module stays callable. Calling ``fuse()`` a second
        time does nothing.
        """
        if not hasattr(self, "conv1"):
            return  # already fused
        conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
        conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)

        # Zero-pad the smaller kernel (3x3) to the larger one (7x7) on every side so the kernels can be summed;
        # both are centred because their paddings are 3 and 1 respectively.
        pad = (conv.weight.shape[-1] - conv1.weight.shape[-1]) // 2
        conv1_w = F.pad(conv1.weight, [pad, pad, pad, pad])

        final_conv_w = conv.weight + conv1_w
        final_conv_b = conv.bias + conv1.bias
        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)

        self.conv = conv
        del self.conv1
        # The two-branch forward() needs conv1, which no longer exists.
        self.forward = self.forward_fuse
1186
+
1187
+
1188
class CIB(nn.Module):
    """
    CIB block: alternating depthwise and pointwise Convs with an optional residual connection.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        shortcut (bool, optional): Add the input to the output when ``c1 == c2``. Defaults to True.
        e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
        lk (bool, optional): Use RepVGGDW instead of a 3x3 depthwise Conv for the third layer. Defaults to False.
    """

    def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
        """
        Initialize the CIB module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            shortcut (bool): Whether to use shortcut connection.
            e (float): Expansion ratio.
            lk (bool): Whether to use RepVGGDW.
        """
        super().__init__()
        wide = 2 * int(c2 * e)
        self.cv1 = nn.Sequential(
            Conv(c1, c1, 3, g=c1),  # depthwise
            Conv(c1, wide, 1),  # pointwise expand
            RepVGGDW(wide) if lk else Conv(wide, wide, 3, g=wide),  # depthwise
            Conv(wide, c2, 1),  # pointwise project
            Conv(c2, c2, 3, g=c2),  # depthwise
        )
        # A residual is only possible when input and output widths match.
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """
        Forward pass of the CIB module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        y = self.cv1(x)
        if self.add:
            return x + y
        return y
1234
+
1235
+
1236
class C2fCIB(C2f):
    """
    C2f block whose inner modules are CIB blocks.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        n (int, optional): Number of CIB modules to stack. Defaults to 1.
        shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.
        lk (bool, optional): Use RepVGGDW (large-kernel depthwise) inside each CIB. Defaults to False.
        g (int, optional): Number of groups for grouped convolution. Defaults to 1.
        e (float, optional): Expansion ratio of the C2f hidden channels. Defaults to 0.5.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
        """
        Initialize C2fCIB module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            n (int): Number of CIB modules.
            shortcut (bool): Whether to use shortcut connection.
            lk (bool): Use RepVGGDW for the middle depthwise layer of each CIB.
            g (int): Groups for convolutions (used by the inherited C2f setup).
            e (float): Expansion ratio.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        # Replace the inherited bottlenecks with CIB blocks of the same hidden width.
        self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
1265
+
1266
+
1267
class Attention(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map, plus a depthwise positional term.

    Args:
        dim (int): The input tensor dimension.
        num_heads (int): The number of attention heads.
        attn_ratio (float): The ratio of the attention key dimension to the head dimension.

    Attributes:
        num_heads (int): The number of attention heads.
        head_dim (int): Value channels per head (``dim // num_heads``).
        key_dim (int): Query/key channels per head.
        scale (float): ``key_dim ** -0.5``, applied to the attention scores.
        qkv (Conv): 1x1 convolution producing queries, keys and values.
        proj (Conv): 1x1 output projection.
        pe (Conv): 3x3 depthwise convolution applied to the values as a positional term.
    """

    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
        """
        Initialize multi-head attention module.

        Args:
            dim (int): Input dimension.
            num_heads (int): Number of attention heads.
            attn_ratio (float): Attention ratio for key dimension.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim**-0.5
        qk_channels = self.key_dim * num_heads
        # Queries and keys take qk_channels each; values take dim.
        self.qkv = Conv(dim, dim + 2 * qk_channels, 1, act=False)
        self.proj = Conv(dim, dim, 1, act=False)
        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)

    def forward(self, x):
        """
        Forward pass of the Attention module.

        Args:
            x (torch.Tensor): Input of shape (B, C, H, W).

        Returns:
            (torch.Tensor): Output of shape (B, C, H, W).
        """
        B, C, H, W = x.shape
        N = H * W
        per_head = self.key_dim * 2 + self.head_dim
        q, k, v = self.qkv(x).view(B, self.num_heads, per_head, N).split(
            [self.key_dim, self.key_dim, self.head_dim], dim=2
        )

        # scores: (B, heads, N, N)
        scores = (q.transpose(-2, -1) @ k) * self.scale
        weights = scores.softmax(dim=-1)
        attended = (v @ weights.transpose(-2, -1)).view(B, C, H, W)
        positional = self.pe(v.reshape(B, C, H, W))
        return self.proj(attended + positional)
1328
+
1329
+
1330
class PSABlock(nn.Module):
    """
    Position-Sensitive Attention block: self-attention followed by a 2x-expansion feed-forward network.

    Attributes:
        attn (Attention): Multi-head attention module.
        ffn (nn.Sequential): Feed-forward network (expand to 2c, project back to c).
        add (bool): Whether each sub-layer is wrapped in a residual connection.

    Methods:
        forward: Applies attention then the feed-forward network.

    Examples:
        Create a PSABlock and perform a forward pass
        >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)
        >>> input_tensor = torch.randn(1, 128, 32, 32)
        >>> output_tensor = psablock(input_tensor)
    """

    def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None:
        """
        Initialize the PSABlock.

        Args:
            c (int): Input and output channels.
            attn_ratio (float): Attention ratio for key dimension.
            num_heads (int): Number of attention heads.
            shortcut (bool): Whether to use shortcut connections.
        """
        super().__init__()

        self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
        self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False))
        self.add = shortcut

    def forward(self, x):
        """
        Execute a forward pass through PSABlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after attention and feed-forward processing.
        """
        if not self.add:
            return self.ffn(self.attn(x))
        x = x + self.attn(x)
        return x + self.ffn(x)
1381
+
1382
+
1383
class PSA(nn.Module):
    """
    PSA class for implementing Position-Sensitive Attention in neural networks.

    The input is projected and split in half. One half passes through residual attention and a residual feed-forward
    network, then both halves are concatenated and projected back.

    Attributes:
        c (int): Number of hidden channels per half after the initial convolution.
        cv1 (Conv): 1x1 convolution mapping the input channels to 2*c.
        cv2 (Conv): 1x1 convolution mapping the 2*c concatenated channels back to c1.
        attn (Attention): Attention module for position-sensitive attention.
        ffn (nn.Sequential): Feed-forward network for further processing.

    Methods:
        forward: Applies position-sensitive attention and feed-forward network to the input tensor.

    Examples:
        Create a PSA module and apply it to an input tensor
        >>> psa = PSA(c1=128, c2=128, e=0.5)
        >>> input_tensor = torch.randn(1, 128, 64, 64)
        >>> output_tensor = psa.forward(input_tensor)
    """

    def __init__(self, c1, c2, e=0.5):
        """
        Initialize PSA module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels (must equal c1).
            e (float): Expansion ratio.
        """
        super().__init__()
        assert c1 == c2
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)

        # One head per 64 hidden channels, but at least one: 0 heads would divide by zero in Attention.
        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=max(1, self.c // 64))
        self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))

    def forward(self, x):
        """
        Execute forward pass in PSA module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after attention and feed-forward processing.
        """
        a, b = self.cv1(x).split((self.c, self.c), dim=1)
        b = b + self.attn(b)
        b = b + self.ffn(b)
        return self.cv2(torch.cat((a, b), 1))
1439
+
1440
+
1441
class C2PSA(nn.Module):
    """
    C2PSA module with attention mechanism for enhanced feature extraction and processing.

    The input is projected and split in half. One half runs through a stack of PSABlock modules, then both halves are
    concatenated and projected back.

    Attributes:
        c (int): Number of hidden channels per half.
        cv1 (Conv): 1x1 convolution mapping the input channels to 2*c.
        cv2 (Conv): 1x1 convolution mapping the 2*c concatenated channels back to c1.
        m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations.

    Methods:
        forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.

    Notes:
        This module is essentially the same as the PSA module, refactored to allow stacking more PSABlock modules.

    Examples:
        >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)
        >>> input_tensor = torch.randn(1, 256, 64, 64)
        >>> output_tensor = c2psa(input_tensor)
    """

    def __init__(self, c1, c2, n=1, e=0.5):
        """
        Initialize C2PSA module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels (must equal c1).
            n (int): Number of PSABlock modules.
            e (float): Expansion ratio.
        """
        super().__init__()
        assert c1 == c2
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)

        # One head per 64 hidden channels, but at least one: 0 heads would divide by zero in Attention.
        num_heads = max(1, self.c // 64)
        self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=num_heads) for _ in range(n)))

    def forward(self, x):
        """
        Process the input tensor through a series of PSA blocks.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after processing.
        """
        a, b = self.cv1(x).split((self.c, self.c), dim=1)
        b = self.m(b)
        return self.cv2(torch.cat((a, b), 1))
1497
+
1498
+
1499
class C2fPSA(C2f):
    """
    C2fPSA module with enhanced feature extraction using PSA blocks.

    This class extends the C2f module by replacing its inner blocks with PSABlock modules.

    Attributes:
        c (int): Number of hidden channels (set by C2f).
        cv1 (Conv): Input projection inherited from C2f.
        cv2 (Conv): Output projection inherited from C2f.
        m (nn.ModuleList): List of PSA blocks for feature extraction.

    Methods:
        forward: Performs a forward pass through the C2fPSA module.
        forward_split: Performs a forward pass using split() instead of chunk().

    Examples:
        >>> import torch
        >>> from ultralytics.nn.modules.block import C2fPSA
        >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5)
        >>> x = torch.randn(1, 64, 128, 128)
        >>> output = model(x)
        >>> print(output.shape)
    """

    def __init__(self, c1, c2, n=1, e=0.5):
        """
        Initialize C2fPSA module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels (must equal c1).
            n (int): Number of PSABlock modules.
            e (float): Expansion ratio.
        """
        assert c1 == c2
        super().__init__(c1, c2, n=n, e=e)
        # One head per 64 hidden channels, but at least one: 0 heads would divide by zero in Attention
        # (e.g. the docstring example gives c = 32).
        num_heads = max(1, self.c // 64)
        self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=num_heads) for _ in range(n))
1537
+
1538
+
1539
class SCDown(nn.Module):
    """
    SCDown module for downsampling with separable convolutions.

    A pointwise (1x1) convolution first maps the input from `c1` to `c2` channels. A depthwise convolution (groups equal
    to `c2`) with kernel `k` and stride `s` then reduces the spatial resolution. The depthwise stage is built with
    `act=False`.

    Attributes:
        cv1 (Conv): Pointwise convolution mapping c1 -> c2 channels.
        cv2 (Conv): Depthwise convolution (g=c2, act=False) that performs spatial downsampling.

    Methods:
        forward: Applies the SCDown module to the input tensor.

    Examples:
        >>> import torch
        >>> from ultralytics.nn.modules.block import SCDown
        >>> model = SCDown(c1=64, c2=128, k=3, s=2)
        >>> x = torch.randn(1, 64, 128, 128)
        >>> y = model(x)
        >>> print(y.shape)
        torch.Size([1, 128, 64, 64])
    """

    def __init__(self, c1, c2, k, s):
        """
        Initialize SCDown module.

        Args:
            c1 (int): Input channels.
            c2 (int): Output channels.
            k (int): Kernel size of the depthwise convolution.
            s (int): Stride of the depthwise convolution.
        """
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)

    def forward(self, x):
        """
        Apply the pointwise convolution, then the strided depthwise convolution.

        Args:
            x (torch.Tensor): Input tensor with `c1` channels.

        Returns:
            (torch.Tensor): Downsampled output tensor with `c2` channels.
        """
        return self.cv2(self.cv1(x))
1588
+
1589
+
1590
class TorchVision(nn.Module):
    """
    Wrapper that builds any torchvision model so it can be used as a YOLO layer.

    The model is created by name (optionally with pre-trained weights). It is then either flattened into an
    `nn.Sequential` of its child layers with the last `truncate` removed, or kept intact with its `head`/`heads`
    attributes replaced by identities.

    Attributes:
        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
        split (bool): When True, forward returns the input plus every intermediate child output as a list.

    Args:
        model (str): Name of the torchvision model to load.
        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
        split (bool, optional): Returns output from intermediate child modules as list. Default is False.
    """

    def __init__(self, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
        """
        Build the torchvision model and optionally flatten/truncate it.

        Args:
            model (str): Name of the torchvision model to load.
            weights (str): Pre-trained weights to load.
            unwrap (bool): Whether to unwrap the model into an nn.Sequential of its children.
            truncate (int): Number of trailing children to drop when unwrapping (0 keeps all).
            split (bool): Whether forward should return every intermediate output (only honoured when unwrapping).
        """
        import torchvision  # scope for faster 'import ultralytics'

        super().__init__()
        registry = torchvision.models
        if hasattr(registry, "get_model"):
            self.m = registry.get_model(model, weights=weights)
        else:
            # Older torchvision releases lack get_model(); use the legacy constructor with a boolean flag.
            self.m = registry.__dict__[model](pretrained=bool(weights))

        if not unwrap:
            # A wrapped model is not iterable layer-by-layer, so split output is unavailable.
            self.split = False
            self.m.head = self.m.heads = nn.Identity()
            return

        children = list(self.m.children())
        if isinstance(children[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
            children = list(children[0].children()) + children[1:]
        kept = children[:-truncate] if truncate else children
        self.m = nn.Sequential(*kept)
        self.split = split

    def forward(self, x):
        """
        Run the wrapped model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor | List[torch.Tensor]): The model output, or, when `split` is set, a list starting with `x`
                followed by the output of each child layer in order.
        """
        if not self.split:
            return self.m(x)
        outputs = [x]
        for layer in self.m:
            outputs.append(layer(outputs[-1]))
        return outputs
1651
+
1652
+
1653
class AAttn(nn.Module):
    """
    Area-attention module for YOLO models, providing efficient attention mechanisms.

    The flattened token sequence (H * W positions) is split into `area` contiguous chunks. These are folded into the
    batch dimension, so multi-head self-attention is computed within each chunk rather than across the whole map. A
    convolution over the value tensor is added to the attention result before the output projection.

    Attributes:
        area (int): Number of areas the feature map is divided into.
        num_heads (int): Number of attention heads.
        head_dim (int): Dimension of each attention head (dim // num_heads).
        qkv (Conv): 1x1 convolution producing query, key and value channels in a single pass.
        proj (Conv): 1x1 output projection convolution.
        pe (Conv): Grouped (g=dim) 7x7 convolution applied to the values as a position encoding.

    Methods:
        forward: Applies area-attention to input tensor.

    Examples:
        >>> attn = AAttn(dim=256, num_heads=8, area=4)
        >>> x = torch.randn(1, 256, 32, 32)
        >>> output = attn(x)
        >>> print(output.shape)
        torch.Size([1, 256, 32, 32])
    """

    def __init__(self, dim, num_heads, area=1):
        """
        Initialize an Area-attention module for YOLO models.

        Args:
            dim (int): Number of hidden channels.
            num_heads (int): Number of attention heads.
            area (int): Number of areas the feature map is divided into, default is 1.
        """
        super().__init__()
        self.area = area
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        inner_dim = self.head_dim * num_heads

        # Creation order is kept as qkv -> proj -> pe so parameter initialisation is unchanged.
        self.qkv = Conv(dim, inner_dim * 3, 1, act=False)
        self.proj = Conv(inner_dim, dim, 1, act=False)
        self.pe = Conv(inner_dim, dim, 7, 1, 3, g=dim, act=False)

    def forward(self, x):
        """
        Process the input tensor through the area-attention.

        The reshapes below require H * W to be divisible by `area`, and the channel count to equal
        `num_heads * head_dim`.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            (torch.Tensor): Output tensor after area-attention, shape (B, C, H, W).
        """
        batch, channels, height, width = x.shape
        tokens = height * width

        # (B, 3C, H, W) -> (B, N, 3C)
        qkv = self.qkv(x).flatten(2).transpose(1, 2)
        if self.area > 1:
            # Fold each area into the batch dimension so attention stays local to the area.
            qkv = qkv.reshape(batch * self.area, tokens // self.area, channels * 3)
            batch, tokens, _ = qkv.shape

        per_head = qkv.view(batch, tokens, self.num_heads, self.head_dim * 3).permute(0, 2, 3, 1)
        q, k, v = per_head.split([self.head_dim] * 3, dim=2)  # each (B, heads, head_dim, N)

        scale = self.head_dim**-0.5
        weights = ((q.transpose(-2, -1) @ k) * scale).softmax(dim=-1)  # (B, heads, N, N)
        out = (v @ weights.transpose(-2, -1)).permute(0, 3, 1, 2)  # (B, N, heads, head_dim)
        v = v.permute(0, 3, 1, 2)

        if self.area > 1:
            # Unfold the areas back out of the batch dimension.
            out = out.reshape(batch // self.area, tokens * self.area, channels)
            v = v.reshape(batch // self.area, tokens * self.area, channels)
            batch = out.shape[0]

        out = out.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        v = v.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()

        return self.proj(out + self.pe(v))
1737
+
1738
+
1739
class ABlock(nn.Module):
    """
    Area-attention block module for efficient feature extraction in YOLO models.

    The block applies two residual stages: area-attention (AAttn), then a two-layer 1x1-convolution MLP whose second
    layer has no activation. All nn.Conv2d weights inside the block are re-initialised from a truncated normal
    distribution (std 0.02), and their biases (when present) are set to zero.

    Attributes:
        attn (AAttn): Area-attention module for processing spatial features.
        mlp (nn.Sequential): Conv(dim -> hidden) followed by Conv(hidden -> dim, act=False).

    Methods:
        _init_weights: Initializes module weights using truncated normal distribution.
        forward: Applies area-attention and feed-forward processing to input tensor.

    Examples:
        >>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1)
        >>> x = torch.randn(1, 256, 32, 32)
        >>> output = block(x)
        >>> print(output.shape)
        torch.Size([1, 256, 32, 32])
    """

    def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1):
        """
        Initialize an Area-attention block module.

        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Expansion ratio for MLP hidden dimension (hidden = int(dim * mlp_ratio)).
            area (int): Number of areas the feature map is divided into.
        """
        super().__init__()

        self.attn = AAttn(dim, num_heads=num_heads, area=area)
        hidden = int(dim * mlp_ratio)
        expand = Conv(dim, hidden, 1)
        reduce = Conv(hidden, dim, 1, act=False)
        self.mlp = nn.Sequential(expand, reduce)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Re-initialise a single submodule if it is an nn.Conv2d.

        Args:
            m (nn.Module): Module visited by `nn.Module.apply`.
        """
        if not isinstance(m, nn.Conv2d):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Forward pass through ABlock.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after residual area-attention and residual MLP.
        """
        attended = x + self.attn(x)
        return attended + self.mlp(attended)
1805
+
1806
+
1807
class A2C2f(nn.Module):
    """
    Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.

    This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature
    processing. It supports both area-attention and standard convolution modes.

    Attributes:
        cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels.
        cv2 (Conv): Final 1x1 convolution fusing the (1 + n) concatenated feature maps into c2 channels.
        gamma (nn.Parameter | None): Per-channel residual scale (initialised to 0.01), present only when
            `a2` and `residual` are both True.
        m (nn.ModuleList): Either pairs of ABlock modules (when `a2`) or C3k modules.

    Methods:
        forward: Processes input through area-attention or standard convolution pathway.

    Examples:
        >>> m = A2C2f(512, 512, n=1, a2=True, area=1)
        >>> x = torch.randn(1, 512, 32, 32)
        >>> output = m(x)
        >>> print(output.shape)
        torch.Size([1, 512, 32, 32])
    """

    def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True):
        """
        Initialize Area-Attention C2f module.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels. Must equal `c1` when the gamma residual is enabled.
            n (int): Number of ABlock-pair or C3k modules to stack.
            a2 (bool): Whether to use area attention blocks. If False, uses C3k blocks instead.
            area (int): Number of areas the feature map is divided into.
            residual (bool): Whether to add a residual connection scaled by a learnable gamma (only with `a2`).
            mlp_ratio (float): Expansion ratio for MLP hidden dimension.
            e (float): Channel expansion ratio for hidden channels.
            g (int): Number of groups for grouped convolutions in C3k blocks.
            shortcut (bool): Whether to use shortcut connections in C3k blocks.

        Raises:
            AssertionError: If `a2` is True and the hidden width int(c2 * e) is not a multiple of 32.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        if a2:
            # ABlock is built with c_ // 32 heads; the C3k path has no such constraint.
            assert c_ % 32 == 0, "Dimension of ABlock must be a multiple of 32."

        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv((1 + n) * c_, c2, 1)

        self.gamma = nn.Parameter(0.01 * torch.ones(c2), requires_grad=True) if a2 and residual else None
        self.m = nn.ModuleList(
            nn.Sequential(*(ABlock(c_, c_ // 32, mlp_ratio, area) for _ in range(2)))
            if a2
            else C3k(c_, c_, 2, shortcut, g)
            for _ in range(n)
        )

    def forward(self, x):
        """
        Forward pass through A2C2f layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after processing.
        """
        # Each stage consumes the previous stage's output; all intermediate maps are concatenated.
        y = [self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        y = self.cv2(torch.cat(y, 1))
        if self.gamma is not None:
            # Per-channel scaled residual; x (c1 channels) and y (c2 channels) must match for the addition.
            return x + self.gamma.view(1, -1, 1, 1) * y
        return y
1878
+
1879
+
1880
class SwiGLUFFN(nn.Module):
    """
    SwiGLU Feed-Forward Network for transformer-based architectures.

    A single linear layer `w12` projects to `e * ec` features. These are split in half into a gate and a value, and
    SiLU(gate) * value is projected back to `ec` features by `w3`. The split is done with chunk(2), while `w3` expects
    `e * ec // 2` inputs, so `e * ec` is presumably meant to be even.
    """

    def __init__(self, gc, ec, e=4) -> None:
        """
        Initialize SwiGLU FFN with input dimension, output dimension, and expansion factor.

        Args:
            gc (int): Input feature dimension.
            ec (int): Output feature dimension.
            e (int): Expansion factor for the fused gate/value projection.
        """
        super().__init__()
        fused_width = e * ec
        self.w12 = nn.Linear(gc, fused_width)
        self.w3 = nn.Linear(fused_width // 2, ec)

    def forward(self, x):
        """
        Apply SwiGLU transformation to input features.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is `gc`.

        Returns:
            (torch.Tensor): Tensor whose last dimension is `ec`.
        """
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
1895
+
1896
+
1897
class Residual(nn.Module):
    """
    Residual connection wrapper for neural network modules.

    Computes `x + m(x)`. The wrapped module must expose a `w3` sublayer with `weight` and `bias`; both are
    zero-initialised at construction. When `w3` produces the final output of `m` (as in SwiGLUFFN), the wrapper
    therefore starts out as an identity mapping.
    """

    def __init__(self, m) -> None:
        """
        Initialize residual module with the wrapped module.

        Args:
            m (nn.Module): Module to wrap; must have a `w3` attribute with `weight` and `bias` tensors.
        """
        super().__init__()
        self.m = m
        output_layer = self.m.w3
        nn.init.zeros_(output_layer.bias)
        # For models with l scale, please change the initialization to
        # nn.init.constant_(self.m.w3.weight, 1e-6)
        nn.init.zeros_(output_layer.weight)

    def forward(self, x):
        """
        Apply residual connection to input features.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): `x + m(x)`.
        """
        update = self.m(x)
        return x + update
1912
+
1913
+
1914
class SAVPE(nn.Module):
    """
    Spatial-Aware Visual Prompt Embedding module for feature enhancement.

    Each input level is processed by two branches. Levels 1 and 2 are upsampled by factors 2 and 4 so the levels can be
    concatenated along channels. One branch (cv1 -> cv3) produces `embed`-channel features. The other (cv2 -> cv4,
    fused with the prompt masks via cv5/cv6) produces `self.c` = 16 per-prompt score maps. The embedding features are
    softmax-pooled over spatial positions inside each prompt and returned as L2-normalised per-prompt vectors.

    Attributes:
        cv1 (nn.ModuleList): Per-level Conv(3x3) -> Conv(3x3) -> optional upsample for the embedding branch.
        cv2 (nn.ModuleList): Per-level Conv(1x1) -> optional upsample for the score branch.
        c (int): Number of score groups (16).
        cv3 (nn.Conv2d): 1x1 conv mapping 3 * c3 channels to `embed`.
        cv4 (nn.Conv2d): 3x3 conv mapping 3 * c3 channels to `c`.
        cv5 (nn.Conv2d): 3x3 conv lifting the 1-channel prompt mask to `c` channels.
        cv6 (nn.Sequential): Fuses score features and prompt features into `c` score maps.
    """

    def __init__(self, ch, c3, embed):
        """
        Initialize SAVPE module with channels, intermediate channels, and embedding dimension.

        Args:
            ch (Sequence[int]): Input channels per feature level. cv3/cv4 take 3 * c3 inputs, so three levels are
                expected.
            c3 (int): Intermediate channels per level.
            embed (int): Output embedding dimension (presumably divisible by 16, see forward).
        """
        super().__init__()

        def _resize(level):
            # Levels 1 and 2 are upsampled x2 and x4; other levels pass through unchanged.
            return nn.Upsample(scale_factor=level * 2) if level in {1, 2} else nn.Identity()

        self.cv1 = nn.ModuleList(
            [nn.Sequential(Conv(c_in, c3, 3), Conv(c3, c3, 3), _resize(level)) for level, c_in in enumerate(ch)]
        )
        self.cv2 = nn.ModuleList([nn.Sequential(Conv(c_in, c3, 1), _resize(level)) for level, c_in in enumerate(ch)])

        self.c = 16
        self.cv3 = nn.Conv2d(3 * c3, embed, 1)
        self.cv4 = nn.Conv2d(3 * c3, self.c, 3, padding=1)
        self.cv5 = nn.Conv2d(1, self.c, 3, padding=1)
        self.cv6 = nn.Sequential(Conv(2 * self.c, self.c, 3), nn.Conv2d(self.c, self.c, 3, padding=1))

    def forward(self, x, vp):
        """
        Process input features and visual prompts to generate enhanced embeddings.

        Args:
            x (List[torch.Tensor]): Per-level feature maps, one per entry of `ch`.
            vp (torch.Tensor): Prompt masks reshapeable to (B, Q, 1, H, W), where (H, W) is the fused feature size.

        Returns:
            (torch.Tensor): L2-normalised prompt embeddings of shape (B, Q, embed).
        """
        score_feats = self.cv4(torch.cat([self.cv2[i](xi) for i, xi in enumerate(x)], dim=1))
        embed_feats = self.cv3(torch.cat([self.cv1[i](xi) for i, xi in enumerate(x)], dim=1))

        bs, dim, h, w = embed_feats.shape
        nq = vp.shape[1]
        flat_embed = embed_feats.view(bs, dim, -1)

        # Repeat the shared score features once per prompt and fuse them with that prompt's mask.
        score_feats = score_feats.reshape(bs, 1, self.c, h, w).expand(-1, nq, -1, -1, -1).reshape(bs * nq, self.c, h, w)
        masks = vp.reshape(bs, nq, 1, h, w).reshape(bs * nq, 1, h, w)
        score_feats = self.cv6(torch.cat((score_feats, self.cv5(masks)), dim=1))

        logits = score_feats.reshape(bs, nq, self.c, -1)
        masks = masks.reshape(bs, nq, 1, -1)

        # Positions where the mask is zero get the dtype's most negative value, so softmax assigns them ~0 weight.
        masked_logits = logits * masks + torch.logical_not(masks) * torch.finfo(logits.dtype).min
        weights = F.softmax(masked_logits, dim=-1, dtype=torch.float).to(masked_logits.dtype)

        # Split embed channels into self.c groups; each score map pools its own group over spatial positions.
        grouped = flat_embed.reshape(bs, self.c, dim // self.c, -1).transpose(-1, -2)
        pooled = weights.transpose(-2, -3) @ grouped

        return F.normalize(pooled.transpose(-2, -3).reshape(bs, nq, -1), dim=-1, p=2)