dgenerate-ultralytics-headless 8.3.214__py3-none-any.whl → 8.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (249)
  1. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/METADATA +64 -74
  2. dgenerate_ultralytics_headless-8.4.7.dist-info/RECORD +311 -0
  3. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/WHEEL +1 -1
  4. tests/__init__.py +7 -9
  5. tests/conftest.py +8 -15
  6. tests/test_cli.py +1 -1
  7. tests/test_cuda.py +13 -10
  8. tests/test_engine.py +9 -9
  9. tests/test_exports.py +65 -13
  10. tests/test_integrations.py +13 -13
  11. tests/test_python.py +125 -69
  12. tests/test_solutions.py +161 -152
  13. ultralytics/__init__.py +1 -1
  14. ultralytics/cfg/__init__.py +86 -92
  15. ultralytics/cfg/datasets/Argoverse.yaml +7 -6
  16. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  17. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  18. ultralytics/cfg/datasets/ImageNet.yaml +1 -1
  19. ultralytics/cfg/datasets/TT100K.yaml +346 -0
  20. ultralytics/cfg/datasets/VOC.yaml +15 -16
  21. ultralytics/cfg/datasets/african-wildlife.yaml +1 -1
  22. ultralytics/cfg/datasets/coco-pose.yaml +21 -0
  23. ultralytics/cfg/datasets/coco12-formats.yaml +101 -0
  24. ultralytics/cfg/datasets/coco128-seg.yaml +1 -1
  25. ultralytics/cfg/datasets/coco8-pose.yaml +21 -0
  26. ultralytics/cfg/datasets/dog-pose.yaml +28 -0
  27. ultralytics/cfg/datasets/dota8-multispectral.yaml +1 -1
  28. ultralytics/cfg/datasets/dota8.yaml +2 -2
  29. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -2
  30. ultralytics/cfg/datasets/kitti.yaml +27 -0
  31. ultralytics/cfg/datasets/lvis.yaml +5 -5
  32. ultralytics/cfg/datasets/open-images-v7.yaml +1 -1
  33. ultralytics/cfg/datasets/tiger-pose.yaml +16 -0
  34. ultralytics/cfg/datasets/xView.yaml +16 -16
  35. ultralytics/cfg/default.yaml +4 -2
  36. ultralytics/cfg/models/11/yolo11-pose.yaml +1 -1
  37. ultralytics/cfg/models/11/yoloe-11-seg.yaml +2 -2
  38. ultralytics/cfg/models/11/yoloe-11.yaml +2 -2
  39. ultralytics/cfg/models/26/yolo26-cls.yaml +33 -0
  40. ultralytics/cfg/models/26/yolo26-obb.yaml +52 -0
  41. ultralytics/cfg/models/26/yolo26-p2.yaml +60 -0
  42. ultralytics/cfg/models/26/yolo26-p6.yaml +62 -0
  43. ultralytics/cfg/models/26/yolo26-pose.yaml +53 -0
  44. ultralytics/cfg/models/26/yolo26-seg.yaml +52 -0
  45. ultralytics/cfg/models/26/yolo26.yaml +52 -0
  46. ultralytics/cfg/models/26/yoloe-26-seg.yaml +53 -0
  47. ultralytics/cfg/models/26/yoloe-26.yaml +53 -0
  48. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +1 -1
  49. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +1 -1
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +1 -1
  51. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +1 -1
  52. ultralytics/cfg/models/v10/yolov10b.yaml +2 -2
  53. ultralytics/cfg/models/v10/yolov10l.yaml +2 -2
  54. ultralytics/cfg/models/v10/yolov10m.yaml +2 -2
  55. ultralytics/cfg/models/v10/yolov10n.yaml +2 -2
  56. ultralytics/cfg/models/v10/yolov10s.yaml +2 -2
  57. ultralytics/cfg/models/v10/yolov10x.yaml +2 -2
  58. ultralytics/cfg/models/v3/yolov3-tiny.yaml +1 -1
  59. ultralytics/cfg/models/v6/yolov6.yaml +1 -1
  60. ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +9 -6
  61. ultralytics/cfg/models/v8/yoloe-v8.yaml +9 -6
  62. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +1 -1
  63. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +1 -1
  64. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +2 -2
  65. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +2 -2
  66. ultralytics/cfg/models/v8/yolov8-ghost.yaml +2 -2
  67. ultralytics/cfg/models/v8/yolov8-obb.yaml +1 -1
  68. ultralytics/cfg/models/v8/yolov8-p2.yaml +1 -1
  69. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +1 -1
  70. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +1 -1
  71. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +1 -1
  72. ultralytics/cfg/models/v8/yolov8-world.yaml +1 -1
  73. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +6 -6
  74. ultralytics/cfg/models/v9/yolov9s.yaml +1 -1
  75. ultralytics/data/__init__.py +4 -4
  76. ultralytics/data/annotator.py +5 -6
  77. ultralytics/data/augment.py +300 -475
  78. ultralytics/data/base.py +18 -26
  79. ultralytics/data/build.py +147 -25
  80. ultralytics/data/converter.py +108 -87
  81. ultralytics/data/dataset.py +47 -75
  82. ultralytics/data/loaders.py +42 -49
  83. ultralytics/data/split.py +5 -6
  84. ultralytics/data/split_dota.py +8 -15
  85. ultralytics/data/utils.py +36 -45
  86. ultralytics/engine/exporter.py +351 -263
  87. ultralytics/engine/model.py +186 -225
  88. ultralytics/engine/predictor.py +45 -54
  89. ultralytics/engine/results.py +198 -325
  90. ultralytics/engine/trainer.py +165 -106
  91. ultralytics/engine/tuner.py +41 -43
  92. ultralytics/engine/validator.py +55 -38
  93. ultralytics/hub/__init__.py +16 -19
  94. ultralytics/hub/auth.py +6 -12
  95. ultralytics/hub/google/__init__.py +7 -10
  96. ultralytics/hub/session.py +15 -25
  97. ultralytics/hub/utils.py +5 -8
  98. ultralytics/models/__init__.py +1 -1
  99. ultralytics/models/fastsam/__init__.py +1 -1
  100. ultralytics/models/fastsam/model.py +8 -10
  101. ultralytics/models/fastsam/predict.py +18 -30
  102. ultralytics/models/fastsam/utils.py +1 -2
  103. ultralytics/models/fastsam/val.py +5 -7
  104. ultralytics/models/nas/__init__.py +1 -1
  105. ultralytics/models/nas/model.py +5 -8
  106. ultralytics/models/nas/predict.py +7 -9
  107. ultralytics/models/nas/val.py +1 -2
  108. ultralytics/models/rtdetr/__init__.py +1 -1
  109. ultralytics/models/rtdetr/model.py +5 -8
  110. ultralytics/models/rtdetr/predict.py +15 -19
  111. ultralytics/models/rtdetr/train.py +10 -13
  112. ultralytics/models/rtdetr/val.py +21 -23
  113. ultralytics/models/sam/__init__.py +15 -2
  114. ultralytics/models/sam/amg.py +14 -20
  115. ultralytics/models/sam/build.py +26 -19
  116. ultralytics/models/sam/build_sam3.py +377 -0
  117. ultralytics/models/sam/model.py +29 -32
  118. ultralytics/models/sam/modules/blocks.py +83 -144
  119. ultralytics/models/sam/modules/decoders.py +19 -37
  120. ultralytics/models/sam/modules/encoders.py +44 -101
  121. ultralytics/models/sam/modules/memory_attention.py +16 -30
  122. ultralytics/models/sam/modules/sam.py +200 -73
  123. ultralytics/models/sam/modules/tiny_encoder.py +64 -83
  124. ultralytics/models/sam/modules/transformer.py +18 -28
  125. ultralytics/models/sam/modules/utils.py +174 -50
  126. ultralytics/models/sam/predict.py +2248 -350
  127. ultralytics/models/sam/sam3/__init__.py +3 -0
  128. ultralytics/models/sam/sam3/decoder.py +546 -0
  129. ultralytics/models/sam/sam3/encoder.py +529 -0
  130. ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
  131. ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
  132. ultralytics/models/sam/sam3/model_misc.py +199 -0
  133. ultralytics/models/sam/sam3/necks.py +129 -0
  134. ultralytics/models/sam/sam3/sam3_image.py +339 -0
  135. ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
  136. ultralytics/models/sam/sam3/vitdet.py +547 -0
  137. ultralytics/models/sam/sam3/vl_combiner.py +160 -0
  138. ultralytics/models/utils/loss.py +14 -26
  139. ultralytics/models/utils/ops.py +13 -17
  140. ultralytics/models/yolo/__init__.py +1 -1
  141. ultralytics/models/yolo/classify/predict.py +10 -13
  142. ultralytics/models/yolo/classify/train.py +12 -33
  143. ultralytics/models/yolo/classify/val.py +30 -29
  144. ultralytics/models/yolo/detect/predict.py +9 -12
  145. ultralytics/models/yolo/detect/train.py +17 -23
  146. ultralytics/models/yolo/detect/val.py +77 -59
  147. ultralytics/models/yolo/model.py +43 -60
  148. ultralytics/models/yolo/obb/predict.py +7 -16
  149. ultralytics/models/yolo/obb/train.py +14 -17
  150. ultralytics/models/yolo/obb/val.py +40 -37
  151. ultralytics/models/yolo/pose/__init__.py +1 -1
  152. ultralytics/models/yolo/pose/predict.py +7 -22
  153. ultralytics/models/yolo/pose/train.py +13 -16
  154. ultralytics/models/yolo/pose/val.py +39 -58
  155. ultralytics/models/yolo/segment/predict.py +17 -21
  156. ultralytics/models/yolo/segment/train.py +7 -10
  157. ultralytics/models/yolo/segment/val.py +95 -47
  158. ultralytics/models/yolo/world/train.py +8 -14
  159. ultralytics/models/yolo/world/train_world.py +11 -34
  160. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  161. ultralytics/models/yolo/yoloe/predict.py +16 -23
  162. ultralytics/models/yolo/yoloe/train.py +36 -44
  163. ultralytics/models/yolo/yoloe/train_seg.py +11 -11
  164. ultralytics/models/yolo/yoloe/val.py +15 -20
  165. ultralytics/nn/__init__.py +7 -7
  166. ultralytics/nn/autobackend.py +159 -85
  167. ultralytics/nn/modules/__init__.py +68 -60
  168. ultralytics/nn/modules/activation.py +4 -6
  169. ultralytics/nn/modules/block.py +260 -224
  170. ultralytics/nn/modules/conv.py +52 -97
  171. ultralytics/nn/modules/head.py +831 -299
  172. ultralytics/nn/modules/transformer.py +76 -88
  173. ultralytics/nn/modules/utils.py +16 -21
  174. ultralytics/nn/tasks.py +180 -195
  175. ultralytics/nn/text_model.py +45 -69
  176. ultralytics/optim/__init__.py +5 -0
  177. ultralytics/optim/muon.py +338 -0
  178. ultralytics/solutions/__init__.py +12 -12
  179. ultralytics/solutions/ai_gym.py +13 -19
  180. ultralytics/solutions/analytics.py +15 -16
  181. ultralytics/solutions/config.py +6 -7
  182. ultralytics/solutions/distance_calculation.py +10 -13
  183. ultralytics/solutions/heatmap.py +8 -14
  184. ultralytics/solutions/instance_segmentation.py +6 -9
  185. ultralytics/solutions/object_blurrer.py +7 -10
  186. ultralytics/solutions/object_counter.py +12 -19
  187. ultralytics/solutions/object_cropper.py +8 -14
  188. ultralytics/solutions/parking_management.py +34 -32
  189. ultralytics/solutions/queue_management.py +10 -12
  190. ultralytics/solutions/region_counter.py +9 -12
  191. ultralytics/solutions/security_alarm.py +15 -20
  192. ultralytics/solutions/similarity_search.py +10 -15
  193. ultralytics/solutions/solutions.py +77 -76
  194. ultralytics/solutions/speed_estimation.py +7 -10
  195. ultralytics/solutions/streamlit_inference.py +2 -4
  196. ultralytics/solutions/templates/similarity-search.html +7 -18
  197. ultralytics/solutions/trackzone.py +7 -10
  198. ultralytics/solutions/vision_eye.py +5 -8
  199. ultralytics/trackers/__init__.py +1 -1
  200. ultralytics/trackers/basetrack.py +3 -5
  201. ultralytics/trackers/bot_sort.py +10 -27
  202. ultralytics/trackers/byte_tracker.py +21 -37
  203. ultralytics/trackers/track.py +4 -7
  204. ultralytics/trackers/utils/gmc.py +11 -22
  205. ultralytics/trackers/utils/kalman_filter.py +37 -48
  206. ultralytics/trackers/utils/matching.py +12 -15
  207. ultralytics/utils/__init__.py +124 -124
  208. ultralytics/utils/autobatch.py +2 -4
  209. ultralytics/utils/autodevice.py +17 -18
  210. ultralytics/utils/benchmarks.py +57 -71
  211. ultralytics/utils/callbacks/base.py +8 -10
  212. ultralytics/utils/callbacks/clearml.py +5 -13
  213. ultralytics/utils/callbacks/comet.py +32 -46
  214. ultralytics/utils/callbacks/dvc.py +13 -18
  215. ultralytics/utils/callbacks/mlflow.py +4 -5
  216. ultralytics/utils/callbacks/neptune.py +7 -15
  217. ultralytics/utils/callbacks/platform.py +423 -38
  218. ultralytics/utils/callbacks/raytune.py +3 -4
  219. ultralytics/utils/callbacks/tensorboard.py +25 -31
  220. ultralytics/utils/callbacks/wb.py +16 -14
  221. ultralytics/utils/checks.py +127 -85
  222. ultralytics/utils/cpu.py +3 -8
  223. ultralytics/utils/dist.py +9 -12
  224. ultralytics/utils/downloads.py +25 -33
  225. ultralytics/utils/errors.py +6 -14
  226. ultralytics/utils/events.py +2 -4
  227. ultralytics/utils/export/__init__.py +4 -236
  228. ultralytics/utils/export/engine.py +246 -0
  229. ultralytics/utils/export/imx.py +117 -63
  230. ultralytics/utils/export/tensorflow.py +231 -0
  231. ultralytics/utils/files.py +26 -30
  232. ultralytics/utils/git.py +9 -11
  233. ultralytics/utils/instance.py +30 -51
  234. ultralytics/utils/logger.py +212 -114
  235. ultralytics/utils/loss.py +601 -215
  236. ultralytics/utils/metrics.py +128 -156
  237. ultralytics/utils/nms.py +13 -16
  238. ultralytics/utils/ops.py +117 -166
  239. ultralytics/utils/patches.py +75 -21
  240. ultralytics/utils/plotting.py +75 -80
  241. ultralytics/utils/tal.py +125 -59
  242. ultralytics/utils/torch_utils.py +53 -79
  243. ultralytics/utils/tqdm.py +24 -21
  244. ultralytics/utils/triton.py +13 -19
  245. ultralytics/utils/tuner.py +19 -10
  246. dgenerate_ultralytics_headless-8.3.214.dist-info/RECORD +0 -283
  247. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/entry_points.txt +0 -0
  248. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/licenses/LICENSE +0 -0
  249. {dgenerate_ultralytics_headless-8.3.214.dist-info → dgenerate_ultralytics_headless-8.4.7.dist-info}/top_level.txt +0 -0
ultralytics/models/sam/sam3/vitdet.py (new file)
@@ -0,0 +1,547 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+ViTDet backbone adapted from Detectron2.
+This module implements Vision Transformer (ViT) backbone for object detection.
+
+Rope embedding code adopted from:
+1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+2. https://github.com/naver-ai/rope-vit
+3. https://github.com/lucidrains/rotary-embedding-torch
+"""
+
+from __future__ import annotations
+
+import math
+from functools import partial
+from typing import Callable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from torch import Tensor
+
+from ultralytics.models.sam.modules.blocks import PatchEmbed
+from ultralytics.models.sam.modules.utils import (
+    apply_rotary_enc,
+    compute_axial_cis,
+    concat_rel_pos,
+    get_abs_pos,
+    window_partition,
+    window_unpartition,
+)
+from ultralytics.utils.checks import check_requirements
+
+from .model_misc import LayerScale
+
+
+class Attention(nn.Module):
+    """Multi-head Attention block with relative position embeddings and 2d-rope."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        input_size: tuple[int, int] | None = None,
+        cls_token: bool = False,
+        use_rope: bool = False,
+        rope_theta: float = 10000.0,
+        rope_pt_size: tuple[int, int] | None = None,
+        rope_interp: bool = False,
+    ):
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            input_size (tuple[int, int] | None): Input resolution for calculating the relative positional parameter
+                size or the rope size.
+            cls_token (bool): Whether a cls_token is present.
+            use_rope (bool): Whether to use 2d-rope (independent of use_rel_pos; the two can be combined).
+            rope_theta (float): Controls the rope frequencies.
+            rope_pt_size (tuple[int, int] | None): Rope size from the previous training stage, needed for
+                interpolation or tiling.
+            rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the input size.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.cls_token = cls_token
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        # rel_pos embeddings and rope
+        self.use_rel_pos = use_rel_pos
+        self.input_size = input_size
+
+        self.use_rope = use_rope
+        self.rope_theta = rope_theta
+        self.rope_pt_size = rope_pt_size
+        self.rope_interp = rope_interp
+
+        # init rel_pos embeddings and rope
+        self._setup_rel_pos(rel_pos_zero_init, input_size)
+        self._setup_rope_freqs(input_size)
+
+    def _setup_rel_pos(self, rel_pos_zero_init: bool = True, input_size: tuple[int, int] | None = None) -> None:
+        """Setup relative positional embeddings."""
+        if not self.use_rel_pos:
+            self.rel_pos_h = None
+            self.rel_pos_w = None
+            return
+
+        assert input_size is not None
+        assert self.cls_token is False, "not supported"
+        # initialize relative positional embeddings
+        self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
+        self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
+
+        if not rel_pos_zero_init:
+            nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
+            nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
+
+        # Precompute the relative coords
+        H, W = input_size
+        q_coords = torch.arange(H)[:, None]
+        k_coords = torch.arange(W)[None, :]
+        relative_coords = (q_coords - k_coords) + (H - 1)
+        self.relative_coords = relative_coords.long()
+
+    def _setup_rope_freqs(self, input_size: tuple[int, int] | None = None) -> None:
+        """Setup 2d-rope frequencies."""
+        if not self.use_rope:
+            self.freqs_cis = None
+            return
+
+        assert input_size is not None
+        # determine rope input size
+        if self.rope_pt_size is None:
+            self.rope_pt_size = input_size
+
+        # initialize 2d rope freqs
+        self.compute_cis = partial(
+            compute_axial_cis,
+            dim=self.head_dim,
+            theta=self.rope_theta,
+        )
+
+        # interpolate rope
+        scale_pos = 1.0
+        if self.rope_interp:
+            scale_pos = self.rope_pt_size[0] / input_size[0]
+        # get scaled freqs_cis
+        freqs_cis = self.compute_cis(
+            end_x=input_size[0],
+            end_y=input_size[1],
+            scale_pos=scale_pos,
+        )
+        if self.cls_token:
+            t = torch.zeros(
+                self.head_dim // 2,
+                dtype=torch.float32,
+                device=freqs_cis.device,
+            )
+            cls_freqs_cis = torch.polar(torch.ones_like(t), t)[None, :]
+            freqs_cis = torch.cat([cls_freqs_cis, freqs_cis], dim=0)
+
+        self.freqs_cis = freqs_cis
+
+    def _apply_rope(self, q, k) -> tuple[Tensor, Tensor]:
+        """Apply 2d-rope to q and k."""
+        if not self.use_rope:
+            return q, k
+
+        assert self.freqs_cis is not None
+        return apply_rotary_enc(q, k, freqs_cis=self.freqs_cis.to(q.device))
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass of attention block."""
+        s = 1 if self.cls_token else 0  # used to exclude cls_token
+        if x.ndim == 4:
+            B, H, W, _ = x.shape
+            assert s == 0  # no cls_token
+            L = H * W
+            ndim = 4
+        else:
+            assert x.ndim == 3
+            B, L, _ = x.shape
+            ndim = 3
+            H = W = int(math.sqrt(L - s))
+
+        # qkv with shape (3, B, nHead, L, C)
+        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, -1)
+        # q, k, v with shape (B, nHead, L, C)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
+
+        # handle rope and rel pos embeddings
+        q, k = self._apply_rope(q, k)
+        if self.use_rel_pos:
+            q, k = concat_rel_pos(
+                q.flatten(0, 1),
+                k.flatten(0, 1),
+                (H, W),
+                x.shape[1:3],
+                self.rel_pos_h,
+                self.rel_pos_w,
+                rescale=True,
+                relative_coords=self.relative_coords,
+            )
+
+        # sdpa expects [B, nheads, H*W, C] so we transpose back
+        q = q.reshape(B, self.num_heads, H * W, -1)
+        k = k.reshape(B, self.num_heads, H * W, -1)
+
+        x = F.scaled_dot_product_attention(q, k, v)
+
+        if ndim == 4:
+            x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        else:
+            x = x.view(B, self.num_heads, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)
+
+        x = self.proj(x)
+
+        return x
+
+
+class Block(nn.Module):
+    """Transformer block with support for window attention."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        drop_path: float = 0.0,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: tuple[int, int] | None = None,
+        use_rope: bool = False,
+        rope_pt_size: tuple[int, int] | None = None,
+        rope_interp: bool = False,
+        cls_token: bool = False,
+        dropout: float = 0.0,
+        init_values: float | None = None,
+    ):
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            drop_path (float): Stochastic depth rate.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If 0, window attention is not used.
+            input_size (tuple[int, int] | None): Input resolution for calculating the relative positional parameter
+                size.
+            dropout (float): Dropout rate.
+            cls_token (bool): Whether a cls_token is present.
+            use_rope (bool): Whether to use 2d-rope (independent of use_rel_pos; the two can be combined).
+            rope_pt_size (tuple[int, int] | None): Rope size from the previous training stage, needed for
+                interpolation or tiling.
+            rope_interp (bool): Whether to interpolate (or extrapolate) rope to match the target input size; the
+                source size is given by rope_pt_size.
+            init_values (float | None): Layer scale init value, None for no layer scale.
+        """
+        super().__init__()
+
+        check_requirements("timm")
+        from timm.layers import DropPath, Mlp
+
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+            use_rope=use_rope,
+            rope_pt_size=rope_pt_size,
+            rope_interp=rope_interp,
+            cls_token=cls_token,
+        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=(dropout, 0.0),
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.dropout = nn.Dropout(dropout)
+        self.window_size = window_size
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass of the transformer block."""
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.ls1(self.attn(x))
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + self.dropout(self.drop_path(x))
+        x = x + self.dropout(self.drop_path(self.ls2(self.mlp(self.norm2(x)))))
+
+        return x
+
+
+class ViT(nn.Module):
+    """This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. "Exploring Plain Vision Transformer
+    Backbones for Object Detection", https://arxiv.org/abs/2203.16527.
+    """
+
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        drop_path_rate: float = 0.0,
+        norm_layer: Callable[..., nn.Module] | str = "LayerNorm",
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        tile_abs_pos: bool = True,
+        rel_pos_blocks: tuple[int, ...] | bool = (2, 5, 8, 11),
+        rel_pos_zero_init: bool = True,
+        window_size: int = 14,
+        global_att_blocks: tuple[int, ...] = (2, 5, 8, 11),
+        use_rope: bool = False,
+        rope_pt_size: int | None = None,
+        use_interp_rope: bool = False,
+        pretrain_img_size: int = 224,
+        pretrain_use_cls_token: bool = True,
+        retain_cls_token: bool = True,
+        dropout: float = 0.0,
+        return_interm_layers: bool = False,
+        init_values: float | None = None,  # for layerscale
+        ln_pre: bool = False,
+        ln_post: bool = False,
+        bias_patch_embed: bool = True,
+        compile_mode: str | None = None,
+        use_act_checkpoint: bool = True,
+    ):
+        """
+        Args:
+            img_size (int): Input image size. Only relevant for rel pos or rope.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            drop_path_rate (float): Stochastic depth rate.
+            norm_layer (Callable | str): Normalization layer, or the name of one in torch.nn.
+            act_layer (nn.Module): Activation layer.
+            use_abs_pos (bool): If True, use absolute positional embeddings.
+            tile_abs_pos (bool): If True, tile absolute positional embeddings instead of interpolating them.
+            rel_pos_blocks (tuple | bool): Blocks which have rel pos embeddings, or True for all blocks.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks.
+            global_att_blocks (tuple): Indexes of blocks using global attention (other blocks use window attention).
+            use_rope (bool): Whether to use 2d-rope (independent of rel_pos_blocks; the two can be combined).
+            rope_pt_size (int | None): Rope size from the previous training stage, needed for interpolation or tiling.
+            use_interp_rope (bool): Whether to interpolate (or extrapolate) rope to match the target input size; the
+                source size is given by rope_pt_size.
+            use_act_checkpoint (bool): If True, use activation checkpointing.
+            pretrain_img_size (int): Input image size for pretraining models.
+            pretrain_use_cls_token (bool): If True, pretraining models use a class token.
+            retain_cls_token (bool): Whether the cls_token should be retained.
+            dropout (float): Dropout rate. Applied in residual blocks of attn, mlp and inside the mlp.
+            return_interm_layers (bool): Whether to return intermediate layers (all global attention blocks).
+            init_values (float | None): Layer scale init value, None for no layer scale.
+            ln_pre (bool): If True, apply layer norm before transformer blocks.
+            ln_post (bool): If True, apply layer norm after transformer blocks.
+            bias_patch_embed (bool): If True, use a bias in the patch embedding convolution.
+            compile_mode (str | None): torch.compile mode for the forward pass; None disables compilation.
+        """
+        super().__init__()
+        self.pretrain_use_cls_token = pretrain_use_cls_token
+
+        window_block_indexes = [i for i in range(depth) if i not in global_att_blocks]
+        self.full_attn_ids = list(global_att_blocks)
+        self.rel_pos_blocks = [False] * depth
+        if isinstance(rel_pos_blocks, bool) and rel_pos_blocks:
+            self.rel_pos_blocks = [True] * depth
+        else:
+            for i in rel_pos_blocks:
+                self.rel_pos_blocks[i] = True
+
+        self.retain_cls_token = retain_cls_token
+        if self.retain_cls_token:
+            assert pretrain_use_cls_token
+            assert len(window_block_indexes) == 0, "windowing not supported with cls token"
+
+            assert sum(self.rel_pos_blocks) == 0, "rel pos not supported with cls token"
+
+            scale = embed_dim**-0.5
+            self.class_embedding = nn.Parameter(scale * torch.randn(1, 1, embed_dim))
+
+        if isinstance(norm_layer, str):
+            norm_layer = partial(getattr(nn, norm_layer), eps=1e-5)
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            bias=bias_patch_embed,
+        )
+
+        # Handle absolute positional embedding
+        self.tile_abs_pos = tile_abs_pos
+        self.use_abs_pos = use_abs_pos
+        if self.tile_abs_pos:
+            assert self.use_abs_pos
+
+        if self.use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size.
+            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
+            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
+        else:
+            self.pos_embed = None
+
+        # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+        self.patch_size = patch_size
+        self.window_size = window_size
+        self.blocks = nn.ModuleList()
+        cur_stage = 1
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=self.rel_pos_blocks[i],
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i in window_block_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+                use_rope=use_rope,
+                rope_pt_size=((window_size, window_size) if rope_pt_size is None else (rope_pt_size, rope_pt_size)),
+                rope_interp=use_interp_rope,
+                cls_token=self.retain_cls_token,
+                dropout=dropout,
+                init_values=init_values,
+            )
+
+            if i not in window_block_indexes:
+                cur_stage += 1
+
+            self.use_act_checkpoint = use_act_checkpoint
+
+            self.blocks.append(block)
+
+        self.return_interm_layers = return_interm_layers
+        self.channel_list = [embed_dim] * len(self.full_attn_ids) if return_interm_layers else [embed_dim]
+
+        if self.pos_embed is not None:
+            nn.init.trunc_normal_(self.pos_embed, std=0.02)
+
+        self.ln_pre = norm_layer(embed_dim) if ln_pre else nn.Identity()
+        self.ln_post = norm_layer(embed_dim) if ln_post else nn.Identity()
+
+        self.apply(self._init_weights)
+
+        if compile_mode is not None:
+            self.forward = torch.compile(self.forward, mode=compile_mode, fullgraph=True)
+            if self.use_act_checkpoint and self.training:
+                torch._dynamo.config.optimize_ddp = False
+
+    @staticmethod
+    def _init_weights(m: nn.Module) -> None:
+        """Initialize the weights."""
+        if isinstance(m, nn.Linear):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+        """Run the ViT forward pass and return feature maps."""
+        x = self.patch_embed(x)
+        h, w = x.shape[1], x.shape[2]
+
+        s = 0
+        if self.retain_cls_token:
+            # If cls_token is retained, the spatial shape is not maintained
+            x = torch.cat([self.class_embedding, x.flatten(1, 2)], dim=1)
+            s = 1
+
+        if self.pos_embed is not None:
+            x = x + get_abs_pos(
+                self.pos_embed,
+                self.pretrain_use_cls_token,
+                (h, w),
+                self.retain_cls_token,
+                tiling=self.tile_abs_pos,
+            )
+
+        x = self.ln_pre(x)
+
+        outputs = []
+        for i, blk in enumerate(self.blocks):
+            if self.use_act_checkpoint and self.training:
+                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
+            else:
+                x = blk(x)
+            if (i == self.full_attn_ids[-1]) or (self.return_interm_layers and i in self.full_attn_ids):
+                if i == self.full_attn_ids[-1]:
+                    x = self.ln_post(x)
+
+                feats = x[:, s:]
+                if feats.ndim == 4:
+                    feats = feats.permute(0, 3, 1, 2)
+                else:
+                    assert feats.ndim == 3
+                    h = w = int(math.sqrt(feats.shape[1]))
+                    feats = feats.reshape(feats.shape[0], h, w, feats.shape[-1]).permute(0, 3, 1, 2)
+
+                outputs.append(feats)
+
+        return outputs
+
+    def set_imgsz(self, imgsz: list[int] = [1008, 1008]):
+        """Setup rel pos embeddings and rope freqs for a new input image size."""
+        for block in self.blocks:
+            if block.window_size != 0:
+                continue
+            block.attn._setup_rel_pos(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
+            block.attn._setup_rope_freqs(input_size=(imgsz[0] // self.patch_size, imgsz[1] // self.patch_size))
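
The sketch below is an editor-added assumption, not part of the published diff: it shows how the new backbone might be smoke-tested on its own. It assumes the 8.4.7 wheel and its timm dependency are installed (Block calls check_requirements("timm")), that the helpers imported from ultralytics.models.sam.modules.utils behave as their call sites above suggest, and it picks arguments so the defaults stay self-consistent (a 224 px input matches pretrain_img_size, and retain_cls_token=False avoids the cls-token asserts that forbid window attention and relative positions). For the full SAM 3 model, the build_sam3.py module added in this release is presumably the intended assembly point.

import torch

from ultralytics.models.sam.sam3.vitdet import ViT  # module added in this diff

# Hypothetical smoke test: argument values chosen by the editor, not taken from package docs.
backbone = ViT(
    img_size=224,  # equals pretrain_img_size, so the absolute pos embedding needs no resizing
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    window_size=14,  # 224 / 16 = 14, so each windowed block sees exactly one full window
    retain_cls_token=False,  # keeps window attention and relative positions legal
    use_act_checkpoint=False,
).eval()

with torch.inference_mode():
    feats = backbone(torch.zeros(1, 3, 224, 224))  # list of channels-first feature maps

print([tuple(f.shape) for f in feats])  # expected: [(1, 768, 14, 14)]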