ultralytics_opencv_headless-8.3.242-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (298)
  1. tests/__init__.py +23 -0
  2. tests/conftest.py +59 -0
  3. tests/test_cli.py +131 -0
  4. tests/test_cuda.py +216 -0
  5. tests/test_engine.py +157 -0
  6. tests/test_exports.py +309 -0
  7. tests/test_integrations.py +151 -0
  8. tests/test_python.py +777 -0
  9. tests/test_solutions.py +371 -0
  10. ultralytics/__init__.py +48 -0
  11. ultralytics/assets/bus.jpg +0 -0
  12. ultralytics/assets/zidane.jpg +0 -0
  13. ultralytics/cfg/__init__.py +1026 -0
  14. ultralytics/cfg/datasets/Argoverse.yaml +78 -0
  15. ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
  16. ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
  17. ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
  18. ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
  19. ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
  20. ultralytics/cfg/datasets/Objects365.yaml +447 -0
  21. ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
  22. ultralytics/cfg/datasets/VOC.yaml +102 -0
  23. ultralytics/cfg/datasets/VisDrone.yaml +87 -0
  24. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  25. ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
  26. ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
  27. ultralytics/cfg/datasets/coco-pose.yaml +64 -0
  28. ultralytics/cfg/datasets/coco.yaml +118 -0
  29. ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
  30. ultralytics/cfg/datasets/coco128.yaml +101 -0
  31. ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
  32. ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
  33. ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
  34. ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
  35. ultralytics/cfg/datasets/coco8.yaml +101 -0
  36. ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
  37. ultralytics/cfg/datasets/crack-seg.yaml +22 -0
  38. ultralytics/cfg/datasets/dog-pose.yaml +52 -0
  39. ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
  40. ultralytics/cfg/datasets/dota8.yaml +35 -0
  41. ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
  42. ultralytics/cfg/datasets/kitti.yaml +27 -0
  43. ultralytics/cfg/datasets/lvis.yaml +1240 -0
  44. ultralytics/cfg/datasets/medical-pills.yaml +21 -0
  45. ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
  46. ultralytics/cfg/datasets/package-seg.yaml +22 -0
  47. ultralytics/cfg/datasets/signature.yaml +21 -0
  48. ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
  49. ultralytics/cfg/datasets/xView.yaml +155 -0
  50. ultralytics/cfg/default.yaml +130 -0
  51. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
  52. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  53. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  54. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  55. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  56. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  57. ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
  58. ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
  59. ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
  60. ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
  61. ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
  62. ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
  63. ultralytics/cfg/models/12/yolo12.yaml +48 -0
  64. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
  65. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
  66. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
  67. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
  68. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  69. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  70. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  71. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  72. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  73. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  74. ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
  75. ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
  76. ultralytics/cfg/models/v3/yolov3.yaml +49 -0
  77. ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
  78. ultralytics/cfg/models/v5/yolov5.yaml +51 -0
  79. ultralytics/cfg/models/v6/yolov6.yaml +56 -0
  80. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
  81. ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
  82. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
  83. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
  84. ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
  85. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
  86. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
  87. ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
  88. ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
  89. ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
  90. ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
  91. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
  92. ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
  93. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
  94. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
  95. ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
  96. ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
  97. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
  98. ultralytics/cfg/models/v8/yolov8.yaml +49 -0
  99. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  100. ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
  101. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  102. ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
  103. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  104. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  105. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  106. ultralytics/cfg/trackers/botsort.yaml +21 -0
  107. ultralytics/cfg/trackers/bytetrack.yaml +12 -0
  108. ultralytics/data/__init__.py +26 -0
  109. ultralytics/data/annotator.py +66 -0
  110. ultralytics/data/augment.py +2801 -0
  111. ultralytics/data/base.py +435 -0
  112. ultralytics/data/build.py +437 -0
  113. ultralytics/data/converter.py +855 -0
  114. ultralytics/data/dataset.py +834 -0
  115. ultralytics/data/loaders.py +704 -0
  116. ultralytics/data/scripts/download_weights.sh +18 -0
  117. ultralytics/data/scripts/get_coco.sh +61 -0
  118. ultralytics/data/scripts/get_coco128.sh +18 -0
  119. ultralytics/data/scripts/get_imagenet.sh +52 -0
  120. ultralytics/data/split.py +138 -0
  121. ultralytics/data/split_dota.py +344 -0
  122. ultralytics/data/utils.py +798 -0
  123. ultralytics/engine/__init__.py +1 -0
  124. ultralytics/engine/exporter.py +1574 -0
  125. ultralytics/engine/model.py +1124 -0
  126. ultralytics/engine/predictor.py +508 -0
  127. ultralytics/engine/results.py +1522 -0
  128. ultralytics/engine/trainer.py +974 -0
  129. ultralytics/engine/tuner.py +448 -0
  130. ultralytics/engine/validator.py +384 -0
  131. ultralytics/hub/__init__.py +166 -0
  132. ultralytics/hub/auth.py +151 -0
  133. ultralytics/hub/google/__init__.py +174 -0
  134. ultralytics/hub/session.py +422 -0
  135. ultralytics/hub/utils.py +162 -0
  136. ultralytics/models/__init__.py +9 -0
  137. ultralytics/models/fastsam/__init__.py +7 -0
  138. ultralytics/models/fastsam/model.py +79 -0
  139. ultralytics/models/fastsam/predict.py +169 -0
  140. ultralytics/models/fastsam/utils.py +23 -0
  141. ultralytics/models/fastsam/val.py +38 -0
  142. ultralytics/models/nas/__init__.py +7 -0
  143. ultralytics/models/nas/model.py +98 -0
  144. ultralytics/models/nas/predict.py +56 -0
  145. ultralytics/models/nas/val.py +38 -0
  146. ultralytics/models/rtdetr/__init__.py +7 -0
  147. ultralytics/models/rtdetr/model.py +63 -0
  148. ultralytics/models/rtdetr/predict.py +88 -0
  149. ultralytics/models/rtdetr/train.py +89 -0
  150. ultralytics/models/rtdetr/val.py +216 -0
  151. ultralytics/models/sam/__init__.py +25 -0
  152. ultralytics/models/sam/amg.py +275 -0
  153. ultralytics/models/sam/build.py +365 -0
  154. ultralytics/models/sam/build_sam3.py +377 -0
  155. ultralytics/models/sam/model.py +169 -0
  156. ultralytics/models/sam/modules/__init__.py +1 -0
  157. ultralytics/models/sam/modules/blocks.py +1067 -0
  158. ultralytics/models/sam/modules/decoders.py +495 -0
  159. ultralytics/models/sam/modules/encoders.py +794 -0
  160. ultralytics/models/sam/modules/memory_attention.py +298 -0
  161. ultralytics/models/sam/modules/sam.py +1160 -0
  162. ultralytics/models/sam/modules/tiny_encoder.py +979 -0
  163. ultralytics/models/sam/modules/transformer.py +344 -0
  164. ultralytics/models/sam/modules/utils.py +512 -0
  165. ultralytics/models/sam/predict.py +3940 -0
  166. ultralytics/models/sam/sam3/__init__.py +3 -0
  167. ultralytics/models/sam/sam3/decoder.py +546 -0
  168. ultralytics/models/sam/sam3/encoder.py +529 -0
  169. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  170. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  171. ultralytics/models/sam/sam3/model_misc.py +199 -0
  172. ultralytics/models/sam/sam3/necks.py +129 -0
  173. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  174. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  175. ultralytics/models/sam/sam3/vitdet.py +547 -0
  176. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  177. ultralytics/models/utils/__init__.py +1 -0
  178. ultralytics/models/utils/loss.py +466 -0
  179. ultralytics/models/utils/ops.py +315 -0
  180. ultralytics/models/yolo/__init__.py +7 -0
  181. ultralytics/models/yolo/classify/__init__.py +7 -0
  182. ultralytics/models/yolo/classify/predict.py +90 -0
  183. ultralytics/models/yolo/classify/train.py +202 -0
  184. ultralytics/models/yolo/classify/val.py +216 -0
  185. ultralytics/models/yolo/detect/__init__.py +7 -0
  186. ultralytics/models/yolo/detect/predict.py +122 -0
  187. ultralytics/models/yolo/detect/train.py +227 -0
  188. ultralytics/models/yolo/detect/val.py +507 -0
  189. ultralytics/models/yolo/model.py +430 -0
  190. ultralytics/models/yolo/obb/__init__.py +7 -0
  191. ultralytics/models/yolo/obb/predict.py +56 -0
  192. ultralytics/models/yolo/obb/train.py +79 -0
  193. ultralytics/models/yolo/obb/val.py +302 -0
  194. ultralytics/models/yolo/pose/__init__.py +7 -0
  195. ultralytics/models/yolo/pose/predict.py +65 -0
  196. ultralytics/models/yolo/pose/train.py +110 -0
  197. ultralytics/models/yolo/pose/val.py +248 -0
  198. ultralytics/models/yolo/segment/__init__.py +7 -0
  199. ultralytics/models/yolo/segment/predict.py +109 -0
  200. ultralytics/models/yolo/segment/train.py +69 -0
  201. ultralytics/models/yolo/segment/val.py +307 -0
  202. ultralytics/models/yolo/world/__init__.py +5 -0
  203. ultralytics/models/yolo/world/train.py +173 -0
  204. ultralytics/models/yolo/world/train_world.py +178 -0
  205. ultralytics/models/yolo/yoloe/__init__.py +22 -0
  206. ultralytics/models/yolo/yoloe/predict.py +162 -0
  207. ultralytics/models/yolo/yoloe/train.py +287 -0
  208. ultralytics/models/yolo/yoloe/train_seg.py +122 -0
  209. ultralytics/models/yolo/yoloe/val.py +206 -0
  210. ultralytics/nn/__init__.py +27 -0
  211. ultralytics/nn/autobackend.py +958 -0
  212. ultralytics/nn/modules/__init__.py +182 -0
  213. ultralytics/nn/modules/activation.py +54 -0
  214. ultralytics/nn/modules/block.py +1947 -0
  215. ultralytics/nn/modules/conv.py +669 -0
  216. ultralytics/nn/modules/head.py +1183 -0
  217. ultralytics/nn/modules/transformer.py +793 -0
  218. ultralytics/nn/modules/utils.py +159 -0
  219. ultralytics/nn/tasks.py +1768 -0
  220. ultralytics/nn/text_model.py +356 -0
  221. ultralytics/py.typed +1 -0
  222. ultralytics/solutions/__init__.py +41 -0
  223. ultralytics/solutions/ai_gym.py +108 -0
  224. ultralytics/solutions/analytics.py +264 -0
  225. ultralytics/solutions/config.py +107 -0
  226. ultralytics/solutions/distance_calculation.py +123 -0
  227. ultralytics/solutions/heatmap.py +125 -0
  228. ultralytics/solutions/instance_segmentation.py +86 -0
  229. ultralytics/solutions/object_blurrer.py +89 -0
  230. ultralytics/solutions/object_counter.py +190 -0
  231. ultralytics/solutions/object_cropper.py +87 -0
  232. ultralytics/solutions/parking_management.py +280 -0
  233. ultralytics/solutions/queue_management.py +93 -0
  234. ultralytics/solutions/region_counter.py +133 -0
  235. ultralytics/solutions/security_alarm.py +151 -0
  236. ultralytics/solutions/similarity_search.py +219 -0
  237. ultralytics/solutions/solutions.py +828 -0
  238. ultralytics/solutions/speed_estimation.py +114 -0
  239. ultralytics/solutions/streamlit_inference.py +260 -0
  240. ultralytics/solutions/templates/similarity-search.html +156 -0
  241. ultralytics/solutions/trackzone.py +88 -0
  242. ultralytics/solutions/vision_eye.py +67 -0
  243. ultralytics/trackers/__init__.py +7 -0
  244. ultralytics/trackers/basetrack.py +115 -0
  245. ultralytics/trackers/bot_sort.py +257 -0
  246. ultralytics/trackers/byte_tracker.py +469 -0
  247. ultralytics/trackers/track.py +116 -0
  248. ultralytics/trackers/utils/__init__.py +1 -0
  249. ultralytics/trackers/utils/gmc.py +339 -0
  250. ultralytics/trackers/utils/kalman_filter.py +482 -0
  251. ultralytics/trackers/utils/matching.py +154 -0
  252. ultralytics/utils/__init__.py +1450 -0
  253. ultralytics/utils/autobatch.py +118 -0
  254. ultralytics/utils/autodevice.py +205 -0
  255. ultralytics/utils/benchmarks.py +728 -0
  256. ultralytics/utils/callbacks/__init__.py +5 -0
  257. ultralytics/utils/callbacks/base.py +233 -0
  258. ultralytics/utils/callbacks/clearml.py +146 -0
  259. ultralytics/utils/callbacks/comet.py +625 -0
  260. ultralytics/utils/callbacks/dvc.py +197 -0
  261. ultralytics/utils/callbacks/hub.py +110 -0
  262. ultralytics/utils/callbacks/mlflow.py +134 -0
  263. ultralytics/utils/callbacks/neptune.py +126 -0
  264. ultralytics/utils/callbacks/platform.py +73 -0
  265. ultralytics/utils/callbacks/raytune.py +42 -0
  266. ultralytics/utils/callbacks/tensorboard.py +123 -0
  267. ultralytics/utils/callbacks/wb.py +188 -0
  268. ultralytics/utils/checks.py +998 -0
  269. ultralytics/utils/cpu.py +85 -0
  270. ultralytics/utils/dist.py +123 -0
  271. ultralytics/utils/downloads.py +529 -0
  272. ultralytics/utils/errors.py +35 -0
  273. ultralytics/utils/events.py +113 -0
  274. ultralytics/utils/export/__init__.py +7 -0
  275. ultralytics/utils/export/engine.py +237 -0
  276. ultralytics/utils/export/imx.py +315 -0
  277. ultralytics/utils/export/tensorflow.py +231 -0
  278. ultralytics/utils/files.py +219 -0
  279. ultralytics/utils/git.py +137 -0
  280. ultralytics/utils/instance.py +484 -0
  281. ultralytics/utils/logger.py +444 -0
  282. ultralytics/utils/loss.py +849 -0
  283. ultralytics/utils/metrics.py +1560 -0
  284. ultralytics/utils/nms.py +337 -0
  285. ultralytics/utils/ops.py +664 -0
  286. ultralytics/utils/patches.py +201 -0
  287. ultralytics/utils/plotting.py +1045 -0
  288. ultralytics/utils/tal.py +403 -0
  289. ultralytics/utils/torch_utils.py +984 -0
  290. ultralytics/utils/tqdm.py +440 -0
  291. ultralytics/utils/triton.py +112 -0
  292. ultralytics/utils/tuner.py +160 -0
  293. ultralytics_opencv_headless-8.3.242.dist-info/METADATA +374 -0
  294. ultralytics_opencv_headless-8.3.242.dist-info/RECORD +298 -0
  295. ultralytics_opencv_headless-8.3.242.dist-info/WHEEL +5 -0
  296. ultralytics_opencv_headless-8.3.242.dist-info/entry_points.txt +3 -0
  297. ultralytics_opencv_headless-8.3.242.dist-info/licenses/LICENSE +661 -0
  298. ultralytics_opencv_headless-8.3.242.dist-info/top_level.txt +1 -0
ultralytics/nn/modules/transformer.py
@@ -0,0 +1,793 @@
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+ """Transformer modules."""
+
+ from __future__ import annotations
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.init import constant_, xavier_uniform_
+
+ from ultralytics.utils.torch_utils import TORCH_1_11
+
+ from .conv import Conv
+ from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
+
+ __all__ = (
+     "AIFI",
+     "MLP",
+     "DeformableTransformerDecoder",
+     "DeformableTransformerDecoderLayer",
+     "LayerNorm2d",
+     "MLPBlock",
+     "MSDeformAttn",
+     "TransformerBlock",
+     "TransformerEncoderLayer",
+     "TransformerLayer",
+ )
+
+
+ class TransformerEncoderLayer(nn.Module):
+     """A single layer of the transformer encoder.
+
+     This class implements a standard transformer encoder layer with multi-head attention and feedforward network,
+     supporting both pre-normalization and post-normalization configurations.
+
+     Attributes:
+         ma (nn.MultiheadAttention): Multi-head attention module.
+         fc1 (nn.Linear): First linear layer in the feedforward network.
+         fc2 (nn.Linear): Second linear layer in the feedforward network.
+         norm1 (nn.LayerNorm): Layer normalization after attention.
+         norm2 (nn.LayerNorm): Layer normalization after feedforward network.
+         dropout (nn.Dropout): Dropout layer for the feedforward network.
+         dropout1 (nn.Dropout): Dropout layer after attention.
+         dropout2 (nn.Dropout): Dropout layer after feedforward network.
+         act (nn.Module): Activation function.
+         normalize_before (bool): Whether to apply normalization before attention and feedforward.
+     """
+
+     def __init__(
+         self,
+         c1: int,
+         cm: int = 2048,
+         num_heads: int = 8,
+         dropout: float = 0.0,
+         act: nn.Module = nn.GELU(),
+         normalize_before: bool = False,
+     ):
+         """Initialize the TransformerEncoderLayer with specified parameters.
+
+         Args:
+             c1 (int): Input dimension.
+             cm (int): Hidden dimension in the feedforward network.
+             num_heads (int): Number of attention heads.
+             dropout (float): Dropout probability.
+             act (nn.Module): Activation function.
+             normalize_before (bool): Whether to apply normalization before attention and feedforward.
+         """
+         super().__init__()
+         from ...utils.torch_utils import TORCH_1_9
+
+         if not TORCH_1_9:
+             raise ModuleNotFoundError(
+                 "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
+             )
+         self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
+         # Implementation of Feedforward model
+         self.fc1 = nn.Linear(c1, cm)
+         self.fc2 = nn.Linear(cm, c1)
+
+         self.norm1 = nn.LayerNorm(c1)
+         self.norm2 = nn.LayerNorm(c1)
+         self.dropout = nn.Dropout(dropout)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.act = act
+         self.normalize_before = normalize_before
+
+     @staticmethod
+     def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None = None) -> torch.Tensor:
+         """Add position embeddings to the tensor if provided."""
+         return tensor if pos is None else tensor + pos
+
+     def forward_post(
+         self,
+         src: torch.Tensor,
+         src_mask: torch.Tensor | None = None,
+         src_key_padding_mask: torch.Tensor | None = None,
+         pos: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Perform forward pass with post-normalization.
+
+         Args:
+             src (torch.Tensor): Input tensor.
+             src_mask (torch.Tensor, optional): Mask for the src sequence.
+             src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
+             pos (torch.Tensor, optional): Positional encoding.
+
+         Returns:
+             (torch.Tensor): Output tensor after attention and feedforward.
+         """
+         q = k = self.with_pos_embed(src, pos)
+         src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
+         src = src + self.dropout2(src2)
+         return self.norm2(src)
+
+     def forward_pre(
+         self,
+         src: torch.Tensor,
+         src_mask: torch.Tensor | None = None,
+         src_key_padding_mask: torch.Tensor | None = None,
+         pos: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Perform forward pass with pre-normalization.
+
+         Args:
+             src (torch.Tensor): Input tensor.
+             src_mask (torch.Tensor, optional): Mask for the src sequence.
+             src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
+             pos (torch.Tensor, optional): Positional encoding.
+
+         Returns:
+             (torch.Tensor): Output tensor after attention and feedforward.
+         """
+         src2 = self.norm1(src)
+         q = k = self.with_pos_embed(src2, pos)
+         src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+         src = src + self.dropout1(src2)
+         src2 = self.norm2(src)
+         src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
+         return src + self.dropout2(src2)
+
+     def forward(
+         self,
+         src: torch.Tensor,
+         src_mask: torch.Tensor | None = None,
+         src_key_padding_mask: torch.Tensor | None = None,
+         pos: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Forward propagate the input through the encoder module.
+
+         Args:
+             src (torch.Tensor): Input tensor.
+             src_mask (torch.Tensor, optional): Mask for the src sequence.
+             src_key_padding_mask (torch.Tensor, optional): Mask for the src keys per batch.
+             pos (torch.Tensor, optional): Positional encoding.
+
+         Returns:
+             (torch.Tensor): Output tensor after transformer encoder layer.
+         """
+         if self.normalize_before:
+             return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+         return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
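Editor's note: a minimal usage sketch (not part of the packaged file) showing that the layer operates on batch-first [B, N, C] sequences and that normalize_before only selects the pre-norm path; all shapes below are assumed for illustration.

# Illustrative sketch only; assumes torch>=1.9 and the shapes shown here.
import torch
from ultralytics.nn.modules.transformer import TransformerEncoderLayer

layer = TransformerEncoderLayer(c1=256, cm=1024, num_heads=8, normalize_before=True)
tokens = torch.randn(2, 300, 256)  # [batch, tokens, channels], batch_first=True attention
out = layer(tokens)                # pre-norm path because normalize_before=True
assert out.shape == tokens.shape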
+ class AIFI(TransformerEncoderLayer):
+     """AIFI transformer layer for 2D data with positional embeddings.
+
+     This class extends TransformerEncoderLayer to work with 2D feature maps by adding 2D sine-cosine positional
+     embeddings and handling the spatial dimensions appropriately.
+     """
+
+     def __init__(
+         self,
+         c1: int,
+         cm: int = 2048,
+         num_heads: int = 8,
+         dropout: float = 0,
+         act: nn.Module = nn.GELU(),
+         normalize_before: bool = False,
+     ):
+         """Initialize the AIFI instance with specified parameters.
+
+         Args:
+             c1 (int): Input dimension.
+             cm (int): Hidden dimension in the feedforward network.
+             num_heads (int): Number of attention heads.
+             dropout (float): Dropout probability.
+             act (nn.Module): Activation function.
+             normalize_before (bool): Whether to apply normalization before attention and feedforward.
+         """
+         super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass for the AIFI transformer layer.
+
+         Args:
+             x (torch.Tensor): Input tensor with shape [B, C, H, W].
+
+         Returns:
+             (torch.Tensor): Output tensor with shape [B, C, H, W].
+         """
+         c, h, w = x.shape[1:]
+         pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
+         # Flatten [B, C, H, W] to [B, HxW, C]
+         x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
+         return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()
+
+     @staticmethod
+     def build_2d_sincos_position_embedding(
+         w: int, h: int, embed_dim: int = 256, temperature: float = 10000.0
+     ) -> torch.Tensor:
+         """Build 2D sine-cosine position embedding.
+
+         Args:
+             w (int): Width of the feature map.
+             h (int): Height of the feature map.
+             embed_dim (int): Embedding dimension.
+             temperature (float): Temperature for the sine/cosine functions.
+
+         Returns:
+             (torch.Tensor): Position embedding with shape [1, h*w, embed_dim].
+         """
+         assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
+         grid_w = torch.arange(w, dtype=torch.float32)
+         grid_h = torch.arange(h, dtype=torch.float32)
+         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") if TORCH_1_11 else torch.meshgrid(grid_w, grid_h)
+         pos_dim = embed_dim // 4
+         omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+         omega = 1.0 / (temperature**omega)
+
+         out_w = grid_w.flatten()[..., None] @ omega[None]
+         out_h = grid_h.flatten()[..., None] @ omega[None]
+
+         return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
+
+
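Editor's note: a small shape sketch (assumed values, not from the diff) of how build_2d_sincos_position_embedding pairs with the flatten/permute round trip in AIFI.forward; embed_dim must be divisible by 4.

# Illustrative sketch only; one embedding row per H*W location, [B, C, H, W] preserved end to end.
import torch
from ultralytics.nn.modules.transformer import AIFI

x = torch.randn(1, 256, 20, 20)  # [B, C, H, W]
pos = AIFI.build_2d_sincos_position_embedding(20, 20, 256)
assert pos.shape == (1, 20 * 20, 256)
out = AIFI(c1=256, cm=1024, num_heads=8)(x)
assert out.shape == x.shape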
+ class TransformerLayer(nn.Module):
+     """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
+
+     def __init__(self, c: int, num_heads: int):
+         """Initialize a self-attention mechanism using linear transformations and multi-head attention.
+
+         Args:
+             c (int): Input and output channel dimension.
+             num_heads (int): Number of attention heads.
+         """
+         super().__init__()
+         self.q = nn.Linear(c, c, bias=False)
+         self.k = nn.Linear(c, c, bias=False)
+         self.v = nn.Linear(c, c, bias=False)
+         self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
+         self.fc1 = nn.Linear(c, c, bias=False)
+         self.fc2 = nn.Linear(c, c, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply a transformer block to the input x and return the output.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             (torch.Tensor): Output tensor after transformer layer.
+         """
+         x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
+         return self.fc2(self.fc1(x)) + x
+
+
+ class TransformerBlock(nn.Module):
+     """Vision Transformer block based on https://arxiv.org/abs/2010.11929.
+
+     This class implements a complete transformer block with optional convolution layer for channel adjustment, learnable
+     position embedding, and multiple transformer layers.
+
+     Attributes:
+         conv (Conv, optional): Convolution layer if input and output channels differ.
+         linear (nn.Linear): Learnable position embedding.
+         tr (nn.Sequential): Sequential container of transformer layers.
+         c2 (int): Output channel dimension.
+     """
+
+     def __init__(self, c1: int, c2: int, num_heads: int, num_layers: int):
+         """Initialize a Transformer module with position embedding and specified number of heads and layers.
+
+         Args:
+             c1 (int): Input channel dimension.
+             c2 (int): Output channel dimension.
+             num_heads (int): Number of attention heads.
+             num_layers (int): Number of transformer layers.
+         """
+         super().__init__()
+         self.conv = None
+         if c1 != c2:
+             self.conv = Conv(c1, c2)
+         self.linear = nn.Linear(c2, c2)  # learnable position embedding
+         self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
+         self.c2 = c2
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward propagate the input through the transformer block.
+
+         Args:
+             x (torch.Tensor): Input tensor with shape [b, c1, w, h].
+
+         Returns:
+             (torch.Tensor): Output tensor with shape [b, c2, w, h].
+         """
+         if self.conv is not None:
+             x = self.conv(x)
+         b, _, w, h = x.shape
+         p = x.flatten(2).permute(2, 0, 1)
+         return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
+
+
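Editor's note: an illustrative sketch (assumed shapes) of TransformerBlock with a channel change, where the optional Conv runs first and the spatial grid is then flattened into a token sequence for the stacked TransformerLayers.

# Illustrative sketch only.
import torch
from ultralytics.nn.modules.transformer import TransformerBlock

block = TransformerBlock(c1=128, c2=256, num_heads=8, num_layers=2)
x = torch.randn(2, 128, 16, 16)
y = block(x)
assert y.shape == (2, 256, 16, 16)  # channels adjusted by Conv, spatial size unchanged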
+ class MLPBlock(nn.Module):
+     """A single block of a multi-layer perceptron."""
+
+     def __init__(self, embedding_dim: int, mlp_dim: int, act=nn.GELU):
+         """Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.
+
+         Args:
+             embedding_dim (int): Input and output dimension.
+             mlp_dim (int): Hidden dimension.
+             act (nn.Module): Activation function.
+         """
+         super().__init__()
+         self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+         self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+         self.act = act()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass for the MLPBlock.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             (torch.Tensor): Output tensor after MLP block.
+         """
+         return self.lin2(self.act(self.lin1(x)))
+
+
+ class MLP(nn.Module):
+     """A simple multi-layer perceptron (also called FFN).
+
+     This class implements a configurable MLP with multiple linear layers, activation functions, and optional sigmoid
+     output activation.
+
+     Attributes:
+         num_layers (int): Number of layers in the MLP.
+         layers (nn.ModuleList): List of linear layers.
+         sigmoid (bool): Whether to apply sigmoid to the output.
+         act (nn.Module): Activation function.
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dim: int,
+         output_dim: int,
+         num_layers: int,
+         act=nn.ReLU,
+         sigmoid: bool = False,
+         residual: bool = False,
+         out_norm: nn.Module = None,
+     ):
+         """Initialize the MLP with specified input, hidden, output dimensions and number of layers.
+
+         Args:
+             input_dim (int): Input dimension.
+             hidden_dim (int): Hidden dimension.
+             output_dim (int): Output dimension.
+             num_layers (int): Number of layers.
+             act (nn.Module): Activation function.
+             sigmoid (bool): Whether to apply sigmoid to the output.
+             residual (bool): Whether to use residual connections.
+             out_norm (nn.Module, optional): Normalization layer for the output.
+         """
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim]))
+         self.sigmoid = sigmoid
+         self.act = act()
+         if residual and input_dim != output_dim:
+             raise ValueError("residual is only supported if input_dim == output_dim")
+         self.residual = residual
+         # whether to apply a normalization layer to the output
+         assert isinstance(out_norm, nn.Module) or out_norm is None
+         self.out_norm = out_norm or nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass for the entire MLP.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             (torch.Tensor): Output tensor after MLP.
+         """
+         orig_x = x
+         for i, layer in enumerate(self.layers):
+             x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)
+         if getattr(self, "residual", False):
+             x = x + orig_x
+         x = getattr(self, "out_norm", nn.Identity())(x)
+         return x.sigmoid() if getattr(self, "sigmoid", False) else x
+
+
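Editor's note: an illustrative sketch of MLP used as a small prediction head; the 3-layer, 4-output, sigmoid configuration mirrors common DETR-style box heads and is an assumption here, not something taken from this diff.

# Illustrative sketch only; sigmoid=True squashes outputs into [0, 1].
import torch
from ultralytics.nn.modules.transformer import MLP

head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3, sigmoid=True)
queries = torch.randn(2, 300, 256)
boxes = head(queries)
assert boxes.shape == (2, 300, 4) and boxes.min() >= 0 and boxes.max() <= 1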
+ class LayerNorm2d(nn.Module):
+     """2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations.
+
+     This class implements layer normalization for 2D feature maps, normalizing across the channel dimension while
+     preserving spatial dimensions.
+
+     Attributes:
+         weight (nn.Parameter): Learnable scale parameter.
+         bias (nn.Parameter): Learnable bias parameter.
+         eps (float): Small constant for numerical stability.
+
+     References:
+         https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
+         https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
+     """
+
+     def __init__(self, num_channels: int, eps: float = 1e-6):
+         """Initialize LayerNorm2d with the given parameters.
+
+         Args:
+             num_channels (int): Number of channels in the input.
+             eps (float): Small constant for numerical stability.
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_channels))
+         self.bias = nn.Parameter(torch.zeros(num_channels))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Perform forward pass for 2D layer normalization.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             (torch.Tensor): Normalized output tensor.
+         """
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         return self.weight[:, None, None] * x + self.bias[:, None, None]
+
+
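Editor's note: a quick numerical sketch (assumed tolerance) showing that, because the forward pass above normalizes only over the channel dimension, a freshly initialized LayerNorm2d should match nn.LayerNorm applied in channels-last layout.

# Illustrative equivalence check only; both use biased variance and eps inside the sqrt.
import torch
import torch.nn as nn
from ultralytics.nn.modules.transformer import LayerNorm2d

x = torch.randn(2, 64, 8, 8)
y1 = LayerNorm2d(64)(x)
y2 = nn.LayerNorm(64, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert torch.allclose(y1, y2, atol=1e-5)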
+ class MSDeformAttn(nn.Module):
+     """Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations.
+
+     This module implements multiscale deformable attention that can attend to features at multiple scales with learnable
+     sampling locations and attention weights.
+
+     Attributes:
+         im2col_step (int): Step size for im2col operations.
+         d_model (int): Model dimension.
+         n_levels (int): Number of feature levels.
+         n_heads (int): Number of attention heads.
+         n_points (int): Number of sampling points per attention head per feature level.
+         sampling_offsets (nn.Linear): Linear layer for generating sampling offsets.
+         attention_weights (nn.Linear): Linear layer for generating attention weights.
+         value_proj (nn.Linear): Linear layer for projecting values.
+         output_proj (nn.Linear): Linear layer for projecting output.
+
+     References:
+         https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
+     """
+
+     def __init__(self, d_model: int = 256, n_levels: int = 4, n_heads: int = 8, n_points: int = 4):
+         """Initialize MSDeformAttn with the given parameters.
+
+         Args:
+             d_model (int): Model dimension.
+             n_levels (int): Number of feature levels.
+             n_heads (int): Number of attention heads.
+             n_points (int): Number of sampling points per attention head per feature level.
+         """
+         super().__init__()
+         if d_model % n_heads != 0:
+             raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}")
+         _d_per_head = d_model // n_heads
+         # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation
+         assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`"
+
+         self.im2col_step = 64
+
+         self.d_model = d_model
+         self.n_levels = n_levels
+         self.n_heads = n_heads
+         self.n_points = n_points
+
+         self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+         self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+         self.value_proj = nn.Linear(d_model, d_model)
+         self.output_proj = nn.Linear(d_model, d_model)
+
+         self._reset_parameters()
+
+     def _reset_parameters(self):
+         """Reset module parameters."""
+         constant_(self.sampling_offsets.weight.data, 0.0)
+         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+         grid_init = (
+             (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+             .view(self.n_heads, 1, 1, 2)
+             .repeat(1, self.n_levels, self.n_points, 1)
+         )
+         for i in range(self.n_points):
+             grid_init[:, :, i, :] *= i + 1
+         with torch.no_grad():
+             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+         constant_(self.attention_weights.weight.data, 0.0)
+         constant_(self.attention_weights.bias.data, 0.0)
+         xavier_uniform_(self.value_proj.weight.data)
+         constant_(self.value_proj.bias.data, 0.0)
+         xavier_uniform_(self.output_proj.weight.data)
+         constant_(self.output_proj.bias.data, 0.0)
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         refer_bbox: torch.Tensor,
+         value: torch.Tensor,
+         value_shapes: list,
+         value_mask: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Perform forward pass for multiscale deformable attention.
+
+         Args:
+             query (torch.Tensor): Query tensor with shape [bs, query_length, C].
+             refer_bbox (torch.Tensor): Reference bounding boxes with shape [bs, query_length, n_levels, 2], range in [0,
+                 1], top-left (0,0), bottom-right (1, 1), including padding area.
+             value (torch.Tensor): Value tensor with shape [bs, value_length, C].
+             value_shapes (list): List with shape [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})].
+             value_mask (torch.Tensor, optional): Mask tensor with shape [bs, value_length], True for non-padding
+                 elements, False for padding elements.
+
+         Returns:
+             (torch.Tensor): Output tensor with shape [bs, Length_{query}, C].
+
+         References:
+             https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
+         """
+         bs, len_q = query.shape[:2]
+         len_v = value.shape[1]
+         assert sum(s[0] * s[1] for s in value_shapes) == len_v
+
+         value = self.value_proj(value)
+         if value_mask is not None:
+             value = value.masked_fill(value_mask[..., None], float(0))
+         value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
+         sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
+         attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
+         attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
+         # N, Len_q, n_heads, n_levels, n_points, 2
+         num_points = refer_bbox.shape[-1]
+         if num_points == 2:
+             offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
+             add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+             sampling_locations = refer_bbox[:, :, None, :, None, :] + add
+         elif num_points == 4:
+             add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
+             sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
+         else:
+             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.")
+         output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
+         return self.output_proj(output)
+
+
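Editor's note: a shape-oriented sketch of calling MSDeformAttn directly; the two-level pyramid and query count are assumptions for illustration, and value_shapes must sum to the flattened value length.

# Illustrative sketch only; refer here exercises the 2-coordinate (point) branch of forward().
import torch
from ultralytics.nn.modules.transformer import MSDeformAttn

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4)
shapes = [(32, 32), (16, 16)]                   # (H, W) per level
value = torch.randn(2, 32 * 32 + 16 * 16, 256)  # [bs, sum(H*W), C]
query = torch.randn(2, 300, 256)
refer = torch.rand(2, 300, 2, 2)                # [bs, n_query, n_levels, 2] in [0, 1]
out = attn(query, refer, value, shapes)
assert out.shape == (2, 300, 256)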
+ class DeformableTransformerDecoderLayer(nn.Module):
+     """Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations.
+
+     This class implements a single decoder layer with self-attention, cross-attention using multiscale deformable
+     attention, and a feedforward network.
+
+     Attributes:
+         self_attn (nn.MultiheadAttention): Self-attention module.
+         dropout1 (nn.Dropout): Dropout after self-attention.
+         norm1 (nn.LayerNorm): Layer normalization after self-attention.
+         cross_attn (MSDeformAttn): Cross-attention module.
+         dropout2 (nn.Dropout): Dropout after cross-attention.
+         norm2 (nn.LayerNorm): Layer normalization after cross-attention.
+         linear1 (nn.Linear): First linear layer in the feedforward network.
+         act (nn.Module): Activation function.
+         dropout3 (nn.Dropout): Dropout in the feedforward network.
+         linear2 (nn.Linear): Second linear layer in the feedforward network.
+         dropout4 (nn.Dropout): Dropout after the feedforward network.
+         norm3 (nn.LayerNorm): Layer normalization after the feedforward network.
+
+     References:
+         https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
+         https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
+     """
+
+     def __init__(
+         self,
+         d_model: int = 256,
+         n_heads: int = 8,
+         d_ffn: int = 1024,
+         dropout: float = 0.0,
+         act: nn.Module = nn.ReLU(),
+         n_levels: int = 4,
+         n_points: int = 4,
+     ):
+         """Initialize the DeformableTransformerDecoderLayer with the given parameters.
+
+         Args:
+             d_model (int): Model dimension.
+             n_heads (int): Number of attention heads.
+             d_ffn (int): Dimension of the feedforward network.
+             dropout (float): Dropout probability.
+             act (nn.Module): Activation function.
+             n_levels (int): Number of feature levels.
+             n_points (int): Number of sampling points.
+         """
+         super().__init__()
+
+         # Self attention
+         self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
+         self.dropout1 = nn.Dropout(dropout)
+         self.norm1 = nn.LayerNorm(d_model)
+
+         # Cross attention
+         self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
+         self.dropout2 = nn.Dropout(dropout)
+         self.norm2 = nn.LayerNorm(d_model)
+
+         # FFN
+         self.linear1 = nn.Linear(d_model, d_ffn)
+         self.act = act
+         self.dropout3 = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(d_ffn, d_model)
+         self.dropout4 = nn.Dropout(dropout)
+         self.norm3 = nn.LayerNorm(d_model)
+
+     @staticmethod
+     def with_pos_embed(tensor: torch.Tensor, pos: torch.Tensor | None) -> torch.Tensor:
+         """Add positional embeddings to the input tensor, if provided."""
+         return tensor if pos is None else tensor + pos
+
+     def forward_ffn(self, tgt: torch.Tensor) -> torch.Tensor:
+         """Perform forward pass through the Feed-Forward Network part of the layer.
+
+         Args:
+             tgt (torch.Tensor): Input tensor.
+
+         Returns:
+             (torch.Tensor): Output tensor after FFN.
+         """
+         tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
+         tgt = tgt + self.dropout4(tgt2)
+         return self.norm3(tgt)
+
+     def forward(
+         self,
+         embed: torch.Tensor,
+         refer_bbox: torch.Tensor,
+         feats: torch.Tensor,
+         shapes: list,
+         padding_mask: torch.Tensor | None = None,
+         attn_mask: torch.Tensor | None = None,
+         query_pos: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         """Perform the forward pass through the entire decoder layer.
+
+         Args:
+             embed (torch.Tensor): Input embeddings.
+             refer_bbox (torch.Tensor): Reference bounding boxes.
+             feats (torch.Tensor): Feature maps.
+             shapes (list): Feature shapes.
+             padding_mask (torch.Tensor, optional): Padding mask.
+             attn_mask (torch.Tensor, optional): Attention mask.
+             query_pos (torch.Tensor, optional): Query position embeddings.
+
+         Returns:
+             (torch.Tensor): Output tensor after decoder layer.
+         """
+         # Self attention
+         q = k = self.with_pos_embed(embed, query_pos)
+         tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[
+             0
+         ].transpose(0, 1)
+         embed = embed + self.dropout1(tgt)
+         embed = self.norm1(embed)
+
+         # Cross attention
+         tgt = self.cross_attn(
+             self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask
+         )
+         embed = embed + self.dropout2(tgt)
+         embed = self.norm2(embed)
+
+         # FFN
+         return self.forward_ffn(embed)
+
+
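Editor's note: a minimal sketch of a single decoder-layer step, with n_levels chosen to match len(shapes); all tensors are assumed shapes for illustration.

# Illustrative sketch only; refer_bbox is read as normalized (center, size) boxes by the cross-attention above.
import torch
from ultralytics.nn.modules.transformer import DeformableTransformerDecoderLayer

layer = DeformableTransformerDecoderLayer(d_model=256, n_heads=8, d_ffn=1024, n_levels=2)
shapes = [(32, 32), (16, 16)]
feats = torch.randn(2, 32 * 32 + 16 * 16, 256)  # flattened multi-scale features
embed = torch.randn(2, 300, 256)                # object queries
refer_bbox = torch.rand(2, 300, 4)              # normalized boxes in [0, 1]
query_pos = torch.randn(2, 300, 256)
out = layer(embed, refer_bbox, feats, shapes, query_pos=query_pos)
assert out.shape == (2, 300, 256)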
+ class DeformableTransformerDecoder(nn.Module):
+     """Deformable Transformer Decoder based on PaddleDetection implementation.
+
+     This class implements a complete deformable transformer decoder with multiple decoder layers and prediction heads
+     for bounding box regression and classification.
+
+     Attributes:
+         layers (nn.ModuleList): List of decoder layers.
+         num_layers (int): Number of decoder layers.
+         hidden_dim (int): Hidden dimension.
+         eval_idx (int): Index of the layer to use during evaluation.
+
+     References:
+         https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
+     """
+
+     def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1):
+         """Initialize the DeformableTransformerDecoder with the given parameters.
+
+         Args:
+             hidden_dim (int): Hidden dimension.
+             decoder_layer (nn.Module): Decoder layer module.
+             num_layers (int): Number of decoder layers.
+             eval_idx (int): Index of the layer to use during evaluation.
+         """
+         super().__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.hidden_dim = hidden_dim
+         self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
+
+     def forward(
+         self,
+         embed: torch.Tensor,  # decoder embeddings
+         refer_bbox: torch.Tensor,  # anchor
+         feats: torch.Tensor,  # image features
+         shapes: list,  # feature shapes
+         bbox_head: nn.Module,
+         score_head: nn.Module,
+         pos_mlp: nn.Module,
+         attn_mask: torch.Tensor | None = None,
+         padding_mask: torch.Tensor | None = None,
+     ):
+         """Perform the forward pass through the entire decoder.
+
+         Args:
+             embed (torch.Tensor): Decoder embeddings.
+             refer_bbox (torch.Tensor): Reference bounding boxes.
+             feats (torch.Tensor): Image features.
+             shapes (list): Feature shapes.
+             bbox_head (nn.Module): Bounding box prediction head.
+             score_head (nn.Module): Score prediction head.
+             pos_mlp (nn.Module): Position MLP.
+             attn_mask (torch.Tensor, optional): Attention mask.
+             padding_mask (torch.Tensor, optional): Padding mask.
+
+         Returns:
+             dec_bboxes (torch.Tensor): Decoded bounding boxes.
+             dec_cls (torch.Tensor): Decoded classification scores.
+         """
+         output = embed
+         dec_bboxes = []
+         dec_cls = []
+         last_refined_bbox = None
+         refer_bbox = refer_bbox.sigmoid()
+         for i, layer in enumerate(self.layers):
+             output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
+
+             bbox = bbox_head[i](output)
+             refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox))
+
+             if self.training:
+                 dec_cls.append(score_head[i](output))
+                 if i == 0:
+                     dec_bboxes.append(refined_bbox)
+                 else:
+                     dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox)))
+             elif i == self.eval_idx:
+                 dec_cls.append(score_head[i](output))
+                 dec_bboxes.append(refined_bbox)
+                 break
+
+             last_refined_bbox = refined_bbox
+             refer_bbox = refined_bbox.detach() if self.training else refined_bbox
+
+         return torch.stack(dec_bboxes), torch.stack(dec_cls)
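Editor's note: an end-to-end sketch of DeformableTransformerDecoder with per-layer prediction heads. The head and positional-MLP configurations are assumptions chosen to satisfy the bbox_head[i] / score_head[i] indexing and the pos_mlp(refer_bbox) call above, not values taken from this package.

# Illustrative sketch only: 2 decoder layers, 2 feature levels, 80 classes; eval mode keeps only the eval_idx layer.
import torch
import torch.nn as nn
from ultralytics.nn.modules.transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer

num_layers, nc = 2, 80
decoder = DeformableTransformerDecoder(
    hidden_dim=256,
    decoder_layer=DeformableTransformerDecoderLayer(d_model=256, n_levels=2),
    num_layers=num_layers,
)
shapes = [(32, 32), (16, 16)]
feats = torch.randn(2, 32 * 32 + 16 * 16, 256)
embed = torch.randn(2, 300, 256)
refer_bbox = torch.randn(2, 300, 4)  # unactivated anchors; forward() applies sigmoid
bbox_head = nn.ModuleList(MLP(256, 256, 4, num_layers=3) for _ in range(num_layers))
score_head = nn.ModuleList(nn.Linear(256, nc) for _ in range(num_layers))
pos_mlp = MLP(4, 256, 256, num_layers=2)  # maps boxes to query position embeddings

decoder.eval()
with torch.no_grad():
    dec_bboxes, dec_cls = decoder(embed, refer_bbox, feats, shapes, bbox_head, score_head, pos_mlp)
assert dec_bboxes.shape == (1, 2, 300, 4) and dec_cls.shape == (1, 2, 300, nc)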