python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/__init__.py +1 -0
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/loader.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1102 -54
- doctr/file_utils.py +9 -0
- doctr/io/elements.py +37 -3
- doctr/models/_utils.py +1 -1
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +1 -2
- doctr/models/classification/magc_resnet/tensorflow.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/mobilenet/tensorflow.py +11 -2
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/resnet/tensorflow.py +25 -4
- doctr/models/classification/textnet/pytorch.py +10 -1
- doctr/models/classification/textnet/tensorflow.py +11 -2
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vgg/tensorflow.py +11 -2
- doctr/models/classification/vip/__init__.py +4 -0
- doctr/models/classification/vip/layers/__init__.py +4 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/pytorch.py +10 -1
- doctr/models/classification/vit/tensorflow.py +9 -0
- doctr/models/classification/zoo.py +4 -0
- doctr/models/detection/differentiable_binarization/base.py +3 -4
- doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
- doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
- doctr/models/detection/fast/base.py +2 -3
- doctr/models/detection/fast/pytorch.py +13 -4
- doctr/models/detection/fast/tensorflow.py +10 -2
- doctr/models/detection/linknet/base.py +2 -3
- doctr/models/detection/linknet/pytorch.py +10 -1
- doctr/models/detection/linknet/tensorflow.py +10 -2
- doctr/models/factory/hub.py +3 -3
- doctr/models/kie_predictor/pytorch.py +1 -1
- doctr/models/kie_predictor/tensorflow.py +1 -1
- doctr/models/modules/layers/pytorch.py +49 -1
- doctr/models/predictor/pytorch.py +1 -1
- doctr/models/predictor/tensorflow.py +1 -1
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/pytorch.py +10 -1
- doctr/models/recognition/crnn/tensorflow.py +10 -1
- doctr/models/recognition/master/pytorch.py +10 -1
- doctr/models/recognition/master/tensorflow.py +10 -3
- doctr/models/recognition/parseq/pytorch.py +23 -5
- doctr/models/recognition/parseq/tensorflow.py +13 -5
- doctr/models/recognition/predictor/_utils.py +107 -45
- doctr/models/recognition/predictor/pytorch.py +3 -3
- doctr/models/recognition/predictor/tensorflow.py +3 -3
- doctr/models/recognition/sar/pytorch.py +10 -1
- doctr/models/recognition/sar/tensorflow.py +10 -3
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +4 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/pytorch.py +10 -1
- doctr/models/recognition/vitstr/tensorflow.py +10 -3
- doctr/models/recognition/zoo.py +5 -0
- doctr/models/utils/pytorch.py +28 -18
- doctr/models/utils/tensorflow.py +15 -8
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/classification/vip/layers/pytorch.py (new file)
@@ -0,0 +1,615 @@
# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import torch
import torch.nn as nn

from doctr.models.modules.layers import DropPath
from doctr.models.modules.transformer import PositionwiseFeedForward
from doctr.models.utils import conv_sequence_pt

__all__ = [
    "PermuteLayer",
    "SqueezeLayer",
    "PatchEmbed",
    "Attention",
    "MultiHeadSelfAttention",
    "OverlappedSpatialReductionAttention",
    "OSRABlock",
    "PatchMerging",
    "LePEAttention",
    "CrossShapedWindowAttention",
]


class PermuteLayer(nn.Module):
    """Custom layer to permute dimensions in a Sequential model."""

    def __init__(self, dims: tuple[int, int, int, int] = (0, 2, 3, 1)):
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.permute(self.dims).contiguous()


class SqueezeLayer(nn.Module):
    """Custom layer to squeeze out a dimension in a Sequential model."""

    def __init__(self, dim: int = 3):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.squeeze(self.dim)


class PatchEmbed(nn.Module):
    """
    Patch embedding layer for Vision Permutable Extractor.

    This layer reduces the spatial resolution of the input tensor by a factor of 4 in total
    (two consecutive strides of 2). It then permutes the output into `(b, h, w, c)` form.

    Args:
        in_channels: Number of channels in the input images.
        embed_dim: Dimensionality of the embedding (i.e., output channels).
    """

    def __init__(self, in_channels: int = 3, embed_dim: int = 128) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.proj = nn.Sequential(
            *conv_sequence_pt(
                in_channels, embed_dim // 2, kernel_size=3, stride=2, padding=1, bias=False, bn=True, relu=False
            ),
            nn.GELU(),
            *conv_sequence_pt(
                embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False, bn=True, relu=False
            ),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for PatchEmbed.

        Args:
            x: A float tensor of shape (b, c, h, w).

        Returns:
            A float tensor of shape (b, h/4, w/4, embed_dim).
        """
        return self.proj(x).permute(0, 2, 3, 1)


class Attention(nn.Module):
    """
    Standard multi-head attention module.

    This module applies self-attention across the input sequence using 'num_heads' heads.

    Args:
        dim: Dimensionality of the input embeddings.
        num_heads: Number of attention heads.
        qkv_bias: If True, adds a learnable bias to the query, key, value projections.
        attn_drop: Dropout rate applied to the attention map.
        proj_drop: Dropout rate applied to the final output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for Attention.

        Args:
            x: A float tensor of shape (b, n, c), where n is the sequence length and c is
                the embedding dimension.

        Returns:
            A float tensor of shape (b, n, c) with attended information.
        """
        _, n, c = x.shape
        qkv = self.qkv(x).reshape((-1, n, 3, self.num_heads, c // self.num_heads)).permute((2, 0, 3, 1, 4))
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = q.matmul(k.permute((0, 1, 3, 2)))
        attn = nn.functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = attn.matmul(v).permute((0, 2, 1, 3)).contiguous().reshape((-1, n, c))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head Self Attention block with an MLP for feed-forward processing.

    This block normalizes the input, applies attention mixing, adds a residual connection,
    then applies an MLP with another residual connection.

    Args:
        dim: Dimensionality of input embeddings.
        num_heads: Number of attention heads.
        mlp_ratio: Expansion factor for the internal dimension of the MLP.
        qkv_bias: If True, adds a learnable bias to the query, key, value projections.
        drop_path_rate: Drop path rate. If > 0, applies stochastic depth.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop_path_rate: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = PositionwiseFeedForward(d_model=dim, ffd=mlp_hidden_dim, dropout=0.0, activation_fct=nn.GELU())

    def forward(self, x: torch.Tensor, size: tuple[int, int] | None = None) -> torch.Tensor:
        """
        Forward pass for MultiHeadSelfAttention.

        Args:
            x: A float tensor of shape (b, n, c).
            size: An optional (h, w) if needed by some modules (unused here).

        Returns:
            A float tensor of shape (b, n, c) after self-attention and MLP.
        """
        x = x + self.drop_path(self.mixer(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class OverlappedSpatialReductionAttention(nn.Module):
    """
    Overlapped Spatial Reduction Attention (OSRA).

    This attention mechanism downsamples the input according to 'sr_ratio' (spatial reduction ratio),
    applies a local convolution for feature enhancement. It captures dependencies in an overlapping manner.

    Args:
        dim: The embedding dimension of the tokens.
        num_heads: Number of attention heads.
        qk_scale: Optionally override q-k scaling. Defaults to head_dim^-0.5 if None.
        attn_drop: Dropout rate for attention weights.
        sr_ratio: Spatial reduction ratio. If > 1, a depthwise conv-based downsampling is applied.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 1,
        qk_scale: float | None = None,
        attn_drop: float = 0.0,
        sr_ratio: int = 1,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.sr_ratio = sr_ratio
        self.q = nn.Conv2d(dim, dim, kernel_size=1)
        self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1)
        self.attn_drop = nn.Dropout(attn_drop)

        if sr_ratio > 1:
            self.sr = nn.Sequential(
                *conv_sequence_pt(
                    dim,
                    dim,
                    kernel_size=sr_ratio + 3,
                    stride=sr_ratio,
                    padding=(sr_ratio + 3) // 2,
                    groups=dim,
                    bias=False,
                    bn=True,
                    relu=False,
                ),
                nn.GELU(),
                *conv_sequence_pt(dim, dim, kernel_size=1, groups=dim, bias=False, bn=True, relu=False),
            )
        else:
            self.sr = nn.Identity()  # type: ignore[assignment]

        self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)

    def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
        """
        Forward pass for OverlappedSpatialReductionAttention.

        Args:
            x: A float tensor of shape (b, n, c) where n = h * w.
            size: A tuple (h, w) giving the height and width of the original feature map.

        Returns:
            A float tensor of shape (b, n, c) with updated representations.
        """
        b, n, c = x.shape
        h, w = size
        x = x.permute(0, 2, 1).reshape(b, -1, h, w)

        q = self.q(x).reshape(b, self.num_heads, c // self.num_heads, -1).transpose(-1, -2)
        kv = self.sr(x)
        kv = self.local_conv(kv) + kv
        k, v = torch.chunk(self.kv(kv), chunks=2, dim=1)
        k = k.reshape(b, self.num_heads, c // self.num_heads, -1)
        v = v.reshape(b, self.num_heads, c // self.num_heads, -1).transpose(-1, -2)

        attn = (q @ k) * self.scale
        attn = torch.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(-1, -2).reshape(b, c, -1)
        x = x.permute(0, 2, 1)
        return x


class OSRABlock(nn.Module):
    """
    Global token mixing block using Overlapped Spatial Reduction Attention (OSRA).

    Captures global dependencies by aggregating context from a wider spatial area,
    followed by a position-wise feed-forward layer.

    Args:
        dim: Embedding dimension of tokens.
        sr_ratio: Spatial reduction ratio for OSRA.
        num_heads: Number of attention heads.
        mlp_ratio: Expansion factor for the MLP hidden dimension.
        drop_path: Drop path rate. If > 0, applies stochastic depth.
    """

    def __init__(
        self,
        dim: int = 64,
        sr_ratio: int = 1,
        num_heads: int = 1,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
    ) -> None:
        super().__init__()
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.token_mixer = OverlappedSpatialReductionAttention(dim, num_heads=num_heads, sr_ratio=sr_ratio)
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = PositionwiseFeedForward(d_model=dim, ffd=mlp_hidden_dim, dropout=0.0, activation_fct=nn.GELU())
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
        """
        Forward pass for OSRABlock.

        Args:
            x: A float tensor of shape (b, n, c).
            size: A tuple (h, w) giving the height and width of the original feature map.

        Returns:
            A float tensor of shape (b, n, c) with globally mixed features.
        """
        x = x + self.drop_path(self.token_mixer(self.norm1(x), size))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Reduces the spatial dimension by half along the height. If the input has shape
    (b, h, w, c), the output shape becomes (b, h//2, w, out_dim).

    Args:
        dim: Number of input channels.
        out_dim: Number of output channels after merging.
    """

    def __init__(self, dim: int, out_dim: int) -> None:
        super().__init__()
        self.dim = dim
        self.reduction = nn.Conv2d(dim, out_dim, 3, (2, 1), 1)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for PatchMerging.

        Args:
            x: A float tensor of shape (b, h, w, c).

        Returns:
            A float tensor of shape (b, h//2, w, out_dim).
        """
        x = x.permute(0, 3, 1, 2)
        x = self.reduction(x).permute(0, 2, 3, 1)
        return self.norm(x)


class LePEAttention(nn.Module):
    """
    Local Enhancement Positional Encoding (LePE) Attention.

    This is used for computing attention in cross-shaped windows (part of CrossShapedWindowAttention),
    and includes a learnable position encoding via depthwise convolution.

    Args:
        dim: Embedding dimension.
        idx: Index used to determine the direction/split dimension for cross-shaped windows:
            - idx == -1: no splitting (attend to all).
            - idx == 0: vertical split.
            - idx == 1: horizontal split.
        split_size: Size of the split window.
        dim_out: Output dimension; if None, defaults to `dim`.
        num_heads: Number of attention heads.
        attn_drop: Dropout rate for attention weights.
    """

    def __init__(
        self,
        dim: int,
        idx: int,
        split_size: int = 7,
        dim_out: int | None = None,
        num_heads: int = 8,
        attn_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out or dim
        self.split_size = split_size
        self.num_heads = num_heads
        self.idx = idx
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
        self.attn_drop = nn.Dropout(attn_drop)

    def img2windows(self, img: torch.Tensor, h_sp: int, w_sp: int) -> torch.Tensor:
        """
        Slice an image into windows of shape (h_sp, w_sp).

        Args:
            img: A float tensor of shape (b, c, h, w).
            h_sp: The window's height.
            w_sp: The window's width.

        Returns:
            A float tensor of shape (b', h_sp*w_sp, c), where b' = b * (h//h_sp) * (w//w_sp).
        """
        b, c, h, w = img.shape
        img_reshape = img.view(b, c, h // h_sp, h_sp, w // w_sp, w_sp)
        img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).reshape(-1, h_sp * w_sp, c)
        return img_perm

    def windows2img(self, img_splits_hw: torch.Tensor, h_sp: int, w_sp: int, h: int, w: int) -> torch.Tensor:
        """
        Merge windowed images back to the original spatial shape.

        Args:
            img_splits_hw: A float tensor of shape (b', h_sp*w_sp, c).
            h_sp: Window height.
            w_sp: Window width.
            h: Original height.
            w: Original width.

        Returns:
            A float tensor of shape (b, h, w, c).
        """
        b_merged = int(img_splits_hw.shape[0] / (h * w / h_sp / w_sp))
        img = img_splits_hw.view(b_merged, h // h_sp, w // w_sp, h_sp, w_sp, -1)
        # contiguous() required to ensure the tensor has a contiguous memory layout
        # after permute, allowing the subsequent view operation to work correctly.
        img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(b_merged, h, w, -1)
        return img

    def _get_split(self, size: tuple[int, int]) -> tuple[int, int]:
        """
        Determine how to split the height/width for the cross-shaped windows.

        Args:
            size: A tuple (h, w).

        Returns:
            A tuple (h_sp, w_sp) indicating split window dimensions.
        """
        h, w = size
        if self.idx == -1:
            return h, w
        elif self.idx == 0:
            return h, self.split_size
        elif self.idx == 1:
            return self.split_size, w
        else:
            raise ValueError("idx must be -1, 0, or 1")

    def im2cswin(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
        """
        Re-arrange features into cross-shaped windows for Q/K.

        Args:
            x: A float tensor of shape (b, n, c).
            size: A tuple (h, w).

        Returns:
            A float tensor of shape (b', num_heads, h_sp*w_sp, c//num_heads).
        """
        b, n, c = x.shape
        h, w = size
        x = x.transpose(-2, -1).view(b, c, h, w)
        h_sp, w_sp = self._get_split(size)

        x = self.img2windows(x, h_sp, w_sp)
        x = x.reshape(-1, h_sp * w_sp, self.num_heads, c // self.num_heads).permute(0, 2, 1, 3)
        return x

    def get_lepe(self, x: torch.Tensor, size: tuple[int, int]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the learnable position encoding via depthwise convolution.

        Args:
            x: A float tensor of shape (b, n, c).
            size: A tuple (h, w).

        Returns:
            x: A float tensor rearranged for V in shape (b', num_heads, n_window, c//num_heads).
            lepe: A position encoding tensor of the same shape as x.
        """
        b, n, c = x.shape
        h, w = size
        x = x.transpose(-2, -1).view(b, c, h, w)
        h_sp, w_sp = self._get_split(size)

        x = x.view(b, c, h // h_sp, h_sp, w // w_sp, w_sp)
        x = x.permute(0, 2, 4, 1, 3, 5).reshape(-1, c, h_sp, w_sp)  # b', c, h_sp, w_sp

        lepe = self.get_v(x)
        lepe = lepe.reshape(-1, self.num_heads, c // self.num_heads, h_sp * w_sp).permute(0, 1, 3, 2)

        x = x.reshape(-1, self.num_heads, c // self.num_heads, h_sp * w_sp).permute(0, 1, 3, 2)
        return x, lepe

    def forward(self, qkv: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
        """
        Forward pass for LePEAttention.

        Splits Q/K/V according to cross-shaped windows, computes attention,
        and returns the combined features.

        Args:
            qkv: A tensor of shape (3, b, n, c) containing Q, K, and V.
            size: A tuple (h, w) giving the height and width of the image/feature map.

        Returns:
            A float tensor of shape (b, n, c) after cross-shaped window attention with LePE.
        """
        q, k, v = qkv[0], qkv[1], qkv[2]

        h, w = size
        b, n, c = q.shape

        h_sp, w_sp = self._get_split(size)
        q = self.im2cswin(q, size)
        k = self.im2cswin(k, size)
        v, lepe = self.get_lepe(v, size)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # (b', head, n_window, n_window)
        attn = nn.functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v) + lepe
        x = x.transpose(1, 2).reshape(-1, h_sp * w_sp, c)
        # Window2Img
        x = self.windows2img(x, h_sp, w_sp, h, w).view(b, -1, c)
        return x


class CrossShapedWindowAttention(nn.Module):
    """
    Local mixing module, performing attention within cross-shaped windows.

    This captures local patterns by splitting the feature map into two cross-shaped windows:
    vertical and horizontal slices. Each slice is passed to a LePEAttention. Outputs are
    concatenated and projected, followed by an MLP for mixing.

    Args:
        dim: Embedding dimension.
        num_heads: Number of attention heads.
        split_size: Window size for splitting.
        mlp_ratio: Expansion factor for MLP hidden dimension.
        qkv_bias: If True, adds a bias term to Q/K/V projections.
        drop_path: Drop path rate. If > 0, applies stochastic depth.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        split_size: int = 7,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop_path: float = 0.0,
    ) -> None:
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm1 = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)

        self.attns = nn.ModuleList([
            LePEAttention(
                dim // 2,
                idx=i,
                split_size=split_size,
                num_heads=num_heads // 2,
                dim_out=dim // 2,
            )
            for i in range(2)
        ])

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.mlp = PositionwiseFeedForward(d_model=dim, ffd=mlp_hidden_dim, dropout=0.0, activation_fct=nn.GELU())
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
        """
        Forward pass for CrossShapedWindowAttention.

        Args:
            x: A float tensor of shape (b, n, c), where n = h * w.
            size: A tuple (h, w) for the height and width of the feature map.

        Returns:
            A float tensor of shape (b, n, c) after cross-shaped window attention.
        """
        b, _, c = x.shape
        qkv = self.qkv(self.norm1(x)).reshape(b, -1, 3, c).permute(2, 0, 1, 3)

        # Split QKV for each half, then apply cross-shaped window attention
        x1 = self.attns[0](qkv[:, :, :, : c // 2], size)
        x2 = self.attns[1](qkv[:, :, :, c // 2 :], size)

        # Project and merge
        merged = self.proj(torch.cat([x1, x2], dim=2))
        x = x + self.drop_path(merged)

        # MLP
        return x + self.drop_path(self.mlp(self.norm2(x)))
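
For orientation, since the diff itself ships no usage notes, below is a minimal sketch of how the layers above compose. It is hypothetical and not part of the released package: it assumes python-doctr 0.12.0 is installed with the PyTorch backend, and the input size, embed_dim=64, num_heads=2, split_size=4 and sr_ratio=2 are illustrative values chosen so the cross-shaped window splits and the stride-2 reduction tile the feature map evenly, not configurations taken from the VIP models added elsewhere in this release.

# Hypothetical usage sketch (not part of the diff); assumes python-doctr 0.12.0 with PyTorch installed.
import torch

from doctr.models.classification.vip.layers.pytorch import (
    CrossShapedWindowAttention,
    OSRABlock,
    PatchEmbed,
    PatchMerging,
)

x = torch.rand(2, 3, 32, 128)  # (b, c, h, w), illustrative input size

# 4x spatial reduction, output permuted to (b, h/4, w/4, embed_dim)
embed = PatchEmbed(in_channels=3, embed_dim=64)
feats = embed(x)  # (2, 8, 32, 64)
b, h, w, c = feats.shape

# Token-mixing blocks expect flattened (b, n, c) tokens plus the (h, w) grid size
tokens = feats.flatten(1, 2)  # (2, 256, 64)
local_mix = CrossShapedWindowAttention(dim=c, num_heads=2, split_size=4)  # split_size must divide h and w here
global_mix = OSRABlock(dim=c, num_heads=2, sr_ratio=2)
tokens = local_mix(tokens, (h, w))   # (2, 256, 64)
tokens = global_mix(tokens, (h, w))  # (2, 256, 64)

# Patch merging halves the height only: (b, h, w, c) -> (b, h//2, w, out_dim)
merged = PatchMerging(dim=c, out_dim=128)(tokens.reshape(b, h, w, c))
print(merged.shape)  # torch.Size([2, 4, 32, 128])

Within the wheel these layers are presumably consumed by doctr/models/classification/vip/pytorch.py and doctr/models/recognition/viptr/pytorch.py (both added in this release) rather than instantiated directly.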