ultralytics 8.1.29__py3-none-any.whl → 8.3.63__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (247)
  1. tests/__init__.py +22 -0
  2. tests/conftest.py +83 -0
  3. tests/test_cli.py +122 -0
  4. tests/test_cuda.py +155 -0
  5. tests/test_engine.py +131 -0
  6. tests/test_exports.py +216 -0
  7. tests/test_integrations.py +150 -0
  8. tests/test_python.py +615 -0
  9. tests/test_solutions.py +94 -0
  10. ultralytics/__init__.py +11 -8
  11. ultralytics/cfg/__init__.py +569 -131
  12. ultralytics/cfg/datasets/Argoverse.yaml +2 -1
  13. ultralytics/cfg/datasets/DOTAv1.5.yaml +3 -2
  14. ultralytics/cfg/datasets/DOTAv1.yaml +3 -2
  15. ultralytics/cfg/datasets/GlobalWheat2020.yaml +3 -2
  16. ultralytics/cfg/datasets/ImageNet.yaml +2 -1
  17. ultralytics/cfg/datasets/Objects365.yaml +5 -4
  18. ultralytics/cfg/datasets/SKU-110K.yaml +2 -1
  19. ultralytics/cfg/datasets/VOC.yaml +3 -2
  20. ultralytics/cfg/datasets/VisDrone.yaml +6 -5
  21. ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
  22. ultralytics/cfg/datasets/brain-tumor.yaml +23 -0
  23. ultralytics/cfg/datasets/carparts-seg.yaml +3 -2
  24. ultralytics/cfg/datasets/coco-pose.yaml +7 -6
  25. ultralytics/cfg/datasets/coco.yaml +3 -2
  26. ultralytics/cfg/datasets/coco128-seg.yaml +4 -3
  27. ultralytics/cfg/datasets/coco128.yaml +4 -3
  28. ultralytics/cfg/datasets/coco8-pose.yaml +3 -2
  29. ultralytics/cfg/datasets/coco8-seg.yaml +3 -2
  30. ultralytics/cfg/datasets/coco8.yaml +3 -2
  31. ultralytics/cfg/datasets/crack-seg.yaml +3 -2
  32. ultralytics/cfg/datasets/dog-pose.yaml +24 -0
  33. ultralytics/cfg/datasets/dota8.yaml +3 -2
  34. ultralytics/cfg/datasets/hand-keypoints.yaml +26 -0
  35. ultralytics/cfg/datasets/lvis.yaml +1236 -0
  36. ultralytics/cfg/datasets/medical-pills.yaml +22 -0
  37. ultralytics/cfg/datasets/open-images-v7.yaml +2 -1
  38. ultralytics/cfg/datasets/package-seg.yaml +5 -4
  39. ultralytics/cfg/datasets/signature.yaml +21 -0
  40. ultralytics/cfg/datasets/tiger-pose.yaml +3 -2
  41. ultralytics/cfg/datasets/xView.yaml +2 -1
  42. ultralytics/cfg/default.yaml +14 -11
  43. ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +24 -0
  44. ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
  45. ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
  46. ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
  47. ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
  48. ultralytics/cfg/models/11/yolo11.yaml +50 -0
  49. ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +5 -2
  50. ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +5 -2
  51. ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +5 -2
  52. ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +5 -2
  53. ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
  54. ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
  55. ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
  56. ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
  57. ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
  58. ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
  59. ultralytics/cfg/models/v3/yolov3-spp.yaml +5 -2
  60. ultralytics/cfg/models/v3/yolov3-tiny.yaml +5 -2
  61. ultralytics/cfg/models/v3/yolov3.yaml +5 -2
  62. ultralytics/cfg/models/v5/yolov5-p6.yaml +5 -2
  63. ultralytics/cfg/models/v5/yolov5.yaml +5 -2
  64. ultralytics/cfg/models/v6/yolov6.yaml +5 -2
  65. ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +5 -2
  66. ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +5 -2
  67. ultralytics/cfg/models/v8/yolov8-cls.yaml +5 -2
  68. ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +6 -2
  69. ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +6 -2
  70. ultralytics/cfg/models/v8/yolov8-ghost.yaml +5 -2
  71. ultralytics/cfg/models/v8/yolov8-obb.yaml +5 -2
  72. ultralytics/cfg/models/v8/yolov8-p2.yaml +5 -2
  73. ultralytics/cfg/models/v8/yolov8-p6.yaml +10 -7
  74. ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +5 -2
  75. ultralytics/cfg/models/v8/yolov8-pose.yaml +5 -2
  76. ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +5 -2
  77. ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +5 -2
  78. ultralytics/cfg/models/v8/yolov8-seg.yaml +5 -2
  79. ultralytics/cfg/models/v8/yolov8-world.yaml +5 -2
  80. ultralytics/cfg/models/v8/yolov8-worldv2.yaml +5 -2
  81. ultralytics/cfg/models/v8/yolov8.yaml +5 -2
  82. ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
  83. ultralytics/cfg/models/v9/yolov9c.yaml +30 -25
  84. ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
  85. ultralytics/cfg/models/v9/yolov9e.yaml +46 -42
  86. ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
  87. ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
  88. ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
  89. ultralytics/cfg/solutions/default.yaml +24 -0
  90. ultralytics/cfg/trackers/botsort.yaml +8 -5
  91. ultralytics/cfg/trackers/bytetrack.yaml +8 -5
  92. ultralytics/data/__init__.py +14 -3
  93. ultralytics/data/annotator.py +37 -15
  94. ultralytics/data/augment.py +1783 -289
  95. ultralytics/data/base.py +62 -27
  96. ultralytics/data/build.py +37 -8
  97. ultralytics/data/converter.py +196 -36
  98. ultralytics/data/dataset.py +233 -94
  99. ultralytics/data/loaders.py +199 -96
  100. ultralytics/data/split_dota.py +39 -29
  101. ultralytics/data/utils.py +111 -41
  102. ultralytics/engine/__init__.py +1 -1
  103. ultralytics/engine/exporter.py +579 -244
  104. ultralytics/engine/model.py +604 -252
  105. ultralytics/engine/predictor.py +22 -11
  106. ultralytics/engine/results.py +1228 -218
  107. ultralytics/engine/trainer.py +191 -129
  108. ultralytics/engine/tuner.py +18 -18
  109. ultralytics/engine/validator.py +18 -15
  110. ultralytics/hub/__init__.py +31 -13
  111. ultralytics/hub/auth.py +11 -7
  112. ultralytics/hub/google/__init__.py +159 -0
  113. ultralytics/hub/session.py +128 -94
  114. ultralytics/hub/utils.py +20 -21
  115. ultralytics/models/__init__.py +4 -2
  116. ultralytics/models/fastsam/__init__.py +2 -3
  117. ultralytics/models/fastsam/model.py +26 -4
  118. ultralytics/models/fastsam/predict.py +127 -63
  119. ultralytics/models/fastsam/utils.py +1 -44
  120. ultralytics/models/fastsam/val.py +1 -1
  121. ultralytics/models/nas/__init__.py +1 -1
  122. ultralytics/models/nas/model.py +21 -10
  123. ultralytics/models/nas/predict.py +3 -6
  124. ultralytics/models/nas/val.py +4 -4
  125. ultralytics/models/rtdetr/__init__.py +1 -1
  126. ultralytics/models/rtdetr/model.py +1 -1
  127. ultralytics/models/rtdetr/predict.py +6 -8
  128. ultralytics/models/rtdetr/train.py +6 -2
  129. ultralytics/models/rtdetr/val.py +3 -3
  130. ultralytics/models/sam/__init__.py +3 -3
  131. ultralytics/models/sam/amg.py +29 -23
  132. ultralytics/models/sam/build.py +211 -13
  133. ultralytics/models/sam/model.py +91 -30
  134. ultralytics/models/sam/modules/__init__.py +1 -1
  135. ultralytics/models/sam/modules/blocks.py +1129 -0
  136. ultralytics/models/sam/modules/decoders.py +381 -53
  137. ultralytics/models/sam/modules/encoders.py +515 -324
  138. ultralytics/models/sam/modules/memory_attention.py +237 -0
  139. ultralytics/models/sam/modules/sam.py +969 -21
  140. ultralytics/models/sam/modules/tiny_encoder.py +425 -154
  141. ultralytics/models/sam/modules/transformer.py +159 -60
  142. ultralytics/models/sam/modules/utils.py +293 -0
  143. ultralytics/models/sam/predict.py +1263 -132
  144. ultralytics/models/utils/__init__.py +1 -1
  145. ultralytics/models/utils/loss.py +36 -24
  146. ultralytics/models/utils/ops.py +3 -7
  147. ultralytics/models/yolo/__init__.py +3 -3
  148. ultralytics/models/yolo/classify/__init__.py +1 -1
  149. ultralytics/models/yolo/classify/predict.py +7 -8
  150. ultralytics/models/yolo/classify/train.py +17 -22
  151. ultralytics/models/yolo/classify/val.py +8 -4
  152. ultralytics/models/yolo/detect/__init__.py +1 -1
  153. ultralytics/models/yolo/detect/predict.py +3 -5
  154. ultralytics/models/yolo/detect/train.py +11 -4
  155. ultralytics/models/yolo/detect/val.py +90 -52
  156. ultralytics/models/yolo/model.py +14 -9
  157. ultralytics/models/yolo/obb/__init__.py +1 -1
  158. ultralytics/models/yolo/obb/predict.py +2 -2
  159. ultralytics/models/yolo/obb/train.py +5 -3
  160. ultralytics/models/yolo/obb/val.py +41 -23
  161. ultralytics/models/yolo/pose/__init__.py +1 -1
  162. ultralytics/models/yolo/pose/predict.py +3 -5
  163. ultralytics/models/yolo/pose/train.py +2 -2
  164. ultralytics/models/yolo/pose/val.py +51 -17
  165. ultralytics/models/yolo/segment/__init__.py +1 -1
  166. ultralytics/models/yolo/segment/predict.py +3 -5
  167. ultralytics/models/yolo/segment/train.py +2 -2
  168. ultralytics/models/yolo/segment/val.py +60 -19
  169. ultralytics/models/yolo/world/__init__.py +5 -0
  170. ultralytics/models/yolo/world/train.py +92 -0
  171. ultralytics/models/yolo/world/train_world.py +109 -0
  172. ultralytics/nn/__init__.py +1 -1
  173. ultralytics/nn/autobackend.py +228 -93
  174. ultralytics/nn/modules/__init__.py +39 -14
  175. ultralytics/nn/modules/activation.py +21 -0
  176. ultralytics/nn/modules/block.py +526 -66
  177. ultralytics/nn/modules/conv.py +24 -7
  178. ultralytics/nn/modules/head.py +177 -34
  179. ultralytics/nn/modules/transformer.py +6 -5
  180. ultralytics/nn/modules/utils.py +1 -2
  181. ultralytics/nn/tasks.py +226 -82
  182. ultralytics/solutions/__init__.py +30 -1
  183. ultralytics/solutions/ai_gym.py +96 -143
  184. ultralytics/solutions/analytics.py +247 -0
  185. ultralytics/solutions/distance_calculation.py +78 -135
  186. ultralytics/solutions/heatmap.py +93 -247
  187. ultralytics/solutions/object_counter.py +184 -259
  188. ultralytics/solutions/parking_management.py +246 -0
  189. ultralytics/solutions/queue_management.py +112 -0
  190. ultralytics/solutions/region_counter.py +116 -0
  191. ultralytics/solutions/security_alarm.py +144 -0
  192. ultralytics/solutions/solutions.py +178 -0
  193. ultralytics/solutions/speed_estimation.py +86 -174
  194. ultralytics/solutions/streamlit_inference.py +190 -0
  195. ultralytics/solutions/trackzone.py +68 -0
  196. ultralytics/trackers/__init__.py +1 -1
  197. ultralytics/trackers/basetrack.py +32 -13
  198. ultralytics/trackers/bot_sort.py +61 -28
  199. ultralytics/trackers/byte_tracker.py +83 -51
  200. ultralytics/trackers/track.py +21 -6
  201. ultralytics/trackers/utils/__init__.py +1 -1
  202. ultralytics/trackers/utils/gmc.py +62 -48
  203. ultralytics/trackers/utils/kalman_filter.py +166 -35
  204. ultralytics/trackers/utils/matching.py +40 -21
  205. ultralytics/utils/__init__.py +511 -239
  206. ultralytics/utils/autobatch.py +40 -22
  207. ultralytics/utils/benchmarks.py +266 -85
  208. ultralytics/utils/callbacks/__init__.py +1 -1
  209. ultralytics/utils/callbacks/base.py +1 -3
  210. ultralytics/utils/callbacks/clearml.py +7 -6
  211. ultralytics/utils/callbacks/comet.py +39 -17
  212. ultralytics/utils/callbacks/dvc.py +1 -1
  213. ultralytics/utils/callbacks/hub.py +16 -16
  214. ultralytics/utils/callbacks/mlflow.py +28 -24
  215. ultralytics/utils/callbacks/neptune.py +6 -2
  216. ultralytics/utils/callbacks/raytune.py +3 -4
  217. ultralytics/utils/callbacks/tensorboard.py +18 -18
  218. ultralytics/utils/callbacks/wb.py +27 -20
  219. ultralytics/utils/checks.py +172 -100
  220. ultralytics/utils/dist.py +2 -1
  221. ultralytics/utils/downloads.py +40 -34
  222. ultralytics/utils/errors.py +1 -1
  223. ultralytics/utils/files.py +72 -38
  224. ultralytics/utils/instance.py +41 -19
  225. ultralytics/utils/loss.py +83 -55
  226. ultralytics/utils/metrics.py +61 -56
  227. ultralytics/utils/ops.py +94 -89
  228. ultralytics/utils/patches.py +30 -14
  229. ultralytics/utils/plotting.py +600 -269
  230. ultralytics/utils/tal.py +67 -26
  231. ultralytics/utils/torch_utils.py +305 -112
  232. ultralytics/utils/triton.py +2 -1
  233. ultralytics/utils/tuner.py +21 -12
  234. ultralytics-8.3.63.dist-info/METADATA +370 -0
  235. ultralytics-8.3.63.dist-info/RECORD +241 -0
  236. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/WHEEL +1 -1
  237. ultralytics/data/explorer/__init__.py +0 -5
  238. ultralytics/data/explorer/explorer.py +0 -472
  239. ultralytics/data/explorer/gui/__init__.py +0 -1
  240. ultralytics/data/explorer/gui/dash.py +0 -268
  241. ultralytics/data/explorer/utils.py +0 -166
  242. ultralytics/models/fastsam/prompt.py +0 -357
  243. ultralytics-8.1.29.dist-info/METADATA +0 -373
  244. ultralytics-8.1.29.dist-info/RECORD +0 -197
  245. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/LICENSE +0 -0
  246. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/entry_points.txt +0 -0
  247. {ultralytics-8.1.29.dist-info → ultralytics-8.3.63.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,237 @@
1
+ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
+
3
+ import copy
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+
9
+ from .blocks import RoPEAttention
10
+
11
+
12
class MemoryAttentionLayer(nn.Module):
    """
    Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.

    This class combines self-attention, cross-attention, and feedforward components to process input tensors and
    generate memory-based attention outputs. Each of the three sub-blocks is applied in pre-norm residual form:
    the running tensor is layer-normalized, passed through the sub-block, dropped out, and added back.

    Attributes:
        d_model (int): Dimensionality of the model (width of the target tokens).
        dim_feedforward (int): Dimensionality of the feedforward network.
        dropout_value (float): Dropout rate for regularization.
        self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
        cross_attn_image (RoPEAttention): Cross-attention from the target tokens to the memory tokens; its keys
            and values are `kv_in_dim`-dimensional.
        linear1 (nn.Linear): First linear layer of the feedforward network.
        dropout (nn.Dropout): Dropout between the activation and `linear2` inside the feedforward network.
        linear2 (nn.Linear): Second linear layer of the feedforward network.
        norm1 (nn.LayerNorm): Layer normalization applied before self-attention.
        norm2 (nn.LayerNorm): Layer normalization applied before cross-attention.
        norm3 (nn.LayerNorm): Layer normalization applied before the feedforward network.
        dropout1 (nn.Dropout): Dropout layer after self-attention.
        dropout2 (nn.Dropout): Dropout layer after cross-attention.
        dropout3 (nn.Dropout): Dropout layer after feedforward network.
        activation (nn.ReLU): Activation function for the feedforward network.
        pos_enc_at_attn (bool): Whether to add `query_pos` to the self-attention queries and keys.
        pos_enc_at_cross_attn_queries (bool): Whether to add `query_pos` to the cross-attention queries.
        pos_enc_at_cross_attn_keys (bool): Whether to add `pos` to the cross-attention keys.

    Methods:
        forward: Performs the full memory attention operation on input tensors.
        _forward_sa: Performs self-attention on input tensor.
        _forward_ca: Performs cross-attention between target and memory tensors.

    Examples:
        >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
        >>> tgt = torch.randn(1, 100, 256)
        >>> memory = torch.randn(1, 100, 64)
        >>> pos = torch.randn(1, 100, 64)  # same width as `memory`: it is added to the cross-attention keys
        >>> query_pos = torch.randn(1, 100, 256)
        >>> output = layer(tgt, memory, pos, query_pos)
        >>> print(output.shape)
        torch.Size([1, 100, 256])
    """

    def __init__(
        self,
        d_model: int = 256,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        pos_enc_at_attn: bool = False,
        pos_enc_at_cross_attn_keys: bool = True,
        pos_enc_at_cross_attn_queries: bool = False,
        kv_in_dim: int = 64,
    ):
        """
        Initializes a memory attention layer with self-attention, cross-attention, and feedforward components.

        Args:
            d_model (int): Width of the target tokens; used for both attention modules, the norms and the MLP.
            dim_feedforward (int): Hidden width of the feedforward network.
            dropout (float): Dropout probability used by every dropout layer.
            pos_enc_at_attn (bool): Add `query_pos` to the self-attention queries/keys.
            pos_enc_at_cross_attn_keys (bool): Add `pos` to the cross-attention keys.
            pos_enc_at_cross_attn_queries (bool): Add `query_pos` to the cross-attention queries.
            kv_in_dim (int): Width of the memory tokens used as cross-attention keys/values.
        """
        super().__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout_value = dropout
        # The attention width must match `d_model`: the norms and residual adds below operate at that width.
        self.self_attn = RoPEAttention(embedding_dim=d_model, num_heads=1, downsample_rate=1)
        self.cross_attn_image = RoPEAttention(
            rope_k_repeat=True,
            embedding_dim=d_model,
            num_heads=1,
            downsample_rate=1,
            kv_in_dim=kv_in_dim,
        )

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

        # Where to add pos enc
        self.pos_enc_at_attn = pos_enc_at_attn
        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys

    def _forward_sa(self, tgt: Tensor, query_pos: Optional[Tensor]) -> Tensor:
        """
        Performs pre-norm residual self-attention on the target tokens.

        Args:
            tgt (Tensor): Target tokens.
            query_pos (Tensor | None): Positional encoding added to queries/keys when `pos_enc_at_attn` is set;
                `None` means no positional encoding is added.

        Returns:
            (Tensor): `tgt` plus the dropped-out self-attention output.
        """
        tgt2 = self.norm1(tgt)
        # Queries and keys share the (optionally position-encoded) input; values never get the encoding.
        q = k = (tgt2 + query_pos) if (self.pos_enc_at_attn and query_pos is not None) else tgt2
        tgt2 = self.self_attn(q, k, v=tgt2)
        return tgt + self.dropout1(tgt2)

    def _forward_ca(
        self,
        tgt: Tensor,
        memory: Tensor,
        query_pos: Optional[Tensor],
        pos: Optional[Tensor],
        num_k_exclude_rope: int = 0,
    ) -> Tensor:
        """
        Performs pre-norm residual cross-attention from the target tokens to the memory tokens.

        Args:
            tgt (Tensor): Target tokens (queries).
            memory (Tensor): Memory tokens (keys/values).
            query_pos (Tensor | None): Added to the queries when `pos_enc_at_cross_attn_queries` is set.
            pos (Tensor | None): Added to the keys when `pos_enc_at_cross_attn_keys` is set.
            num_k_exclude_rope (int): Forwarded to the RoPE attention when positive; per the argument name,
                presumably the number of trailing keys that skip rotary encoding (defined in `RoPEAttention`).

        Returns:
            (Tensor): `tgt` plus the dropped-out cross-attention output.
        """
        kwds = {}
        if num_k_exclude_rope > 0:
            assert isinstance(self.cross_attn_image, RoPEAttention)
            kwds = {"num_k_exclude_rope": num_k_exclude_rope}

        # Cross-Attention
        tgt2 = self.norm2(tgt)
        q = (tgt2 + query_pos) if (self.pos_enc_at_cross_attn_queries and query_pos is not None) else tgt2
        k = (memory + pos) if (self.pos_enc_at_cross_attn_keys and pos is not None) else memory
        tgt2 = self.cross_attn_image(q=q, k=k, v=memory, **kwds)
        return tgt + self.dropout2(tgt2)

    def forward(
        self,
        tgt,
        memory,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
        num_k_exclude_rope: int = 0,
    ) -> torch.Tensor:
        """
        Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention.

        Args:
            tgt (Tensor): Target tokens of width `d_model`.
            memory (Tensor): Memory tokens of width `kv_in_dim`.
            pos (Tensor | None): Positional encoding for the memory keys; skipped when `None`.
            query_pos (Tensor | None): Positional encoding for the target tokens; skipped when `None`.
            num_k_exclude_rope (int): Forwarded to the cross-attention (see `_forward_ca`).

        Returns:
            (torch.Tensor): Updated target tokens, same shape as `tgt`.
        """
        tgt = self._forward_sa(tgt, query_pos)
        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
        # MLP
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        return tgt + self.dropout3(tgt2)
138
+
139
+
140
class MemoryAttention(nn.Module):
    """
    Memory attention module for processing sequential data with self and cross-attention mechanisms.

    This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
    for processing sequential data, particularly useful in transformer-like architectures. Inputs are expected
    in sequence-first layout (the batch dimension is axis 1); when `batch_first` is set they are transposed to
    batch-first before the layers run and the output is transposed back.

    Attributes:
        d_model (int): The dimension of the model's hidden state.
        layers (nn.ModuleList): `num_layers` independent deep copies of the template layer.
        num_layers (int): The number of attention layers.
        norm (nn.LayerNorm): Layer normalization applied to the output.
        pos_enc_at_input (bool): Whether to add a scaled copy of `curr_pos` to the input before the layers.
        batch_first (bool): Whether the layers expect batch-first tensors.

    Methods:
        forward: Processes input tensors through the attention layers.

    Examples:
        >>> d_model = 256
        >>> layer = MemoryAttentionLayer(d_model)
        >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
        >>> curr = torch.randn(64, 2, d_model)  # (seq_len, batch_size, d_model)
        >>> memory = torch.randn(64, 2, 64)  # (mem_len, batch_size, 64); the default layer's key/value width is 64
        >>> curr_pos = torch.randn(64, 2, d_model)
        >>> memory_pos = torch.randn(64, 2, 64)
        >>> output = attention(curr, memory, curr_pos, memory_pos)
        >>> print(output.shape)
        torch.Size([64, 2, 256])
    """

    def __init__(
        self,
        d_model: int,
        pos_enc_at_input: bool,
        layer: nn.Module,
        num_layers: int,
        batch_first: bool = True,  # Do layers expect batch first input?
    ):
        """
        Initializes MemoryAttention module with layers and normalization for attention processing.

        Args:
            d_model (int): Width of the `curr` tokens and of the output normalization.
            pos_enc_at_input (bool): Add `0.1 * curr_pos` to `curr` before the first layer.
            layer (nn.Module): Template layer; deep-copied `num_layers` times so no weights are shared.
            num_layers (int): Number of stacked layers.
            batch_first (bool): Whether the layers expect batch-first input.
        """
        super().__init__()
        self.d_model = d_model
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = nn.LayerNorm(d_model)
        self.pos_enc_at_input = pos_enc_at_input
        self.batch_first = batch_first

    def forward(
        self,
        curr: torch.Tensor,  # self-attention inputs
        memory: torch.Tensor,  # cross-attention inputs
        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
    ) -> torch.Tensor:
        """
        Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms.

        Args:
            curr (torch.Tensor | list[torch.Tensor]): Sequence-first self-attention input; a single-element list
                is unwrapped (and then `curr_pos` must be a single-element list too).
            memory (torch.Tensor): Sequence-first cross-attention input; axis 1 must match `curr`'s batch size.
            curr_pos (Tensor | None): Positional encoding for `curr`; may be `None`.
            memory_pos (Tensor | None): Positional encoding for `memory`; may be `None`.
            num_obj_ptr_tokens (int): Forwarded to each layer as `num_k_exclude_rope` when its cross-attention is
                a `RoPEAttention`.

        Returns:
            (torch.Tensor): Layer-normalized output in the same sequence-first layout as `curr`.
        """
        if isinstance(curr, list):
            assert isinstance(curr_pos, list)
            assert len(curr) == len(curr_pos) == 1
            curr, curr_pos = curr[0], curr_pos[0]

        assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"

        output = curr
        if self.pos_enc_at_input and curr_pos is not None:
            output = output + 0.1 * curr_pos

        if self.batch_first:
            # Convert to batch first; positional encodings are optional, so only transpose those that exist.
            output = output.transpose(0, 1)
            memory = memory.transpose(0, 1)
            if curr_pos is not None:
                curr_pos = curr_pos.transpose(0, 1)
            if memory_pos is not None:
                memory_pos = memory_pos.transpose(0, 1)

        for layer in self.layers:
            kwds = {}
            if isinstance(layer.cross_attn_image, RoPEAttention):
                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}

            output = layer(
                tgt=output,
                memory=memory,
                pos=memory_pos,
                query_pos=curr_pos,
                **kwds,
            )
        normed_output = self.norm(output)

        if self.batch_first:
            # Convert back to seq first
            normed_output = normed_output.transpose(0, 1)

        return normed_output