python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. doctr/datasets/__init__.py +1 -0
  2. doctr/datasets/coco_text.py +139 -0
  3. doctr/datasets/cord.py +2 -1
  4. doctr/datasets/funsd.py +2 -2
  5. doctr/datasets/ic03.py +1 -1
  6. doctr/datasets/ic13.py +2 -1
  7. doctr/datasets/iiit5k.py +4 -1
  8. doctr/datasets/imgur5k.py +9 -2
  9. doctr/datasets/loader.py +1 -1
  10. doctr/datasets/ocr.py +1 -1
  11. doctr/datasets/recognition.py +1 -1
  12. doctr/datasets/svhn.py +1 -1
  13. doctr/datasets/svt.py +2 -2
  14. doctr/datasets/synthtext.py +15 -2
  15. doctr/datasets/utils.py +7 -6
  16. doctr/datasets/vocabs.py +1102 -54
  17. doctr/file_utils.py +9 -0
  18. doctr/io/elements.py +37 -3
  19. doctr/models/_utils.py +1 -1
  20. doctr/models/classification/__init__.py +1 -0
  21. doctr/models/classification/magc_resnet/pytorch.py +1 -2
  22. doctr/models/classification/magc_resnet/tensorflow.py +3 -3
  23. doctr/models/classification/mobilenet/pytorch.py +15 -1
  24. doctr/models/classification/mobilenet/tensorflow.py +11 -2
  25. doctr/models/classification/predictor/pytorch.py +1 -1
  26. doctr/models/classification/resnet/pytorch.py +26 -3
  27. doctr/models/classification/resnet/tensorflow.py +25 -4
  28. doctr/models/classification/textnet/pytorch.py +10 -1
  29. doctr/models/classification/textnet/tensorflow.py +11 -2
  30. doctr/models/classification/vgg/pytorch.py +16 -1
  31. doctr/models/classification/vgg/tensorflow.py +11 -2
  32. doctr/models/classification/vip/__init__.py +4 -0
  33. doctr/models/classification/vip/layers/__init__.py +4 -0
  34. doctr/models/classification/vip/layers/pytorch.py +615 -0
  35. doctr/models/classification/vip/pytorch.py +505 -0
  36. doctr/models/classification/vit/pytorch.py +10 -1
  37. doctr/models/classification/vit/tensorflow.py +9 -0
  38. doctr/models/classification/zoo.py +4 -0
  39. doctr/models/detection/differentiable_binarization/base.py +3 -4
  40. doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
  41. doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
  42. doctr/models/detection/fast/base.py +2 -3
  43. doctr/models/detection/fast/pytorch.py +13 -4
  44. doctr/models/detection/fast/tensorflow.py +10 -2
  45. doctr/models/detection/linknet/base.py +2 -3
  46. doctr/models/detection/linknet/pytorch.py +10 -1
  47. doctr/models/detection/linknet/tensorflow.py +10 -2
  48. doctr/models/factory/hub.py +3 -3
  49. doctr/models/kie_predictor/pytorch.py +1 -1
  50. doctr/models/kie_predictor/tensorflow.py +1 -1
  51. doctr/models/modules/layers/pytorch.py +49 -1
  52. doctr/models/predictor/pytorch.py +1 -1
  53. doctr/models/predictor/tensorflow.py +1 -1
  54. doctr/models/recognition/__init__.py +1 -0
  55. doctr/models/recognition/crnn/pytorch.py +10 -1
  56. doctr/models/recognition/crnn/tensorflow.py +10 -1
  57. doctr/models/recognition/master/pytorch.py +10 -1
  58. doctr/models/recognition/master/tensorflow.py +10 -3
  59. doctr/models/recognition/parseq/pytorch.py +23 -5
  60. doctr/models/recognition/parseq/tensorflow.py +13 -5
  61. doctr/models/recognition/predictor/_utils.py +107 -45
  62. doctr/models/recognition/predictor/pytorch.py +3 -3
  63. doctr/models/recognition/predictor/tensorflow.py +3 -3
  64. doctr/models/recognition/sar/pytorch.py +10 -1
  65. doctr/models/recognition/sar/tensorflow.py +10 -3
  66. doctr/models/recognition/utils.py +56 -47
  67. doctr/models/recognition/viptr/__init__.py +4 -0
  68. doctr/models/recognition/viptr/pytorch.py +277 -0
  69. doctr/models/recognition/vitstr/pytorch.py +10 -1
  70. doctr/models/recognition/vitstr/tensorflow.py +10 -3
  71. doctr/models/recognition/zoo.py +5 -0
  72. doctr/models/utils/pytorch.py +28 -18
  73. doctr/models/utils/tensorflow.py +15 -8
  74. doctr/utils/data.py +1 -1
  75. doctr/utils/geometry.py +1 -1
  76. doctr/version.py +1 -1
  77. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
  78. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
  79. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  80. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  81. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  82. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/classification/vip/layers/pytorch.py (new file)
@@ -0,0 +1,615 @@
+ # Copyright (C) 2021-2025, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ import torch
+ import torch.nn as nn
+
+ from doctr.models.modules.layers import DropPath
+ from doctr.models.modules.transformer import PositionwiseFeedForward
+ from doctr.models.utils import conv_sequence_pt
+
+ __all__ = [
+     "PermuteLayer",
+     "SqueezeLayer",
+     "PatchEmbed",
+     "Attention",
+     "MultiHeadSelfAttention",
+     "OverlappedSpatialReductionAttention",
+     "OSRABlock",
+     "PatchMerging",
+     "LePEAttention",
+     "CrossShapedWindowAttention",
+ ]
+
+
+ class PermuteLayer(nn.Module):
+     """Custom layer to permute dimensions in a Sequential model."""
+
+     def __init__(self, dims: tuple[int, int, int, int] = (0, 2, 3, 1)):
+         super().__init__()
+         self.dims = dims
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x.permute(self.dims).contiguous()
+
+
+ class SqueezeLayer(nn.Module):
+     """Custom layer to squeeze out a dimension in a Sequential model."""
+
+     def __init__(self, dim: int = 3):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x.squeeze(self.dim)
+
+
+ class PatchEmbed(nn.Module):
+     """
+     Patch embedding layer for Vision Permutable Extractor.
+
+     This layer reduces the spatial resolution of the input tensor by a factor of 4 in total
+     (two consecutive strides of 2). It then permutes the output into `(b, h, w, c)` form.
+
+     Args:
+         in_channels: Number of channels in the input images.
+         embed_dim: Dimensionality of the embedding (i.e., output channels).
+     """
+
+     def __init__(self, in_channels: int = 3, embed_dim: int = 128) -> None:
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.proj = nn.Sequential(
+             *conv_sequence_pt(
+                 in_channels, embed_dim // 2, kernel_size=3, stride=2, padding=1, bias=False, bn=True, relu=False
+             ),
+             nn.GELU(),
+             *conv_sequence_pt(
+                 embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False, bn=True, relu=False
+             ),
+             nn.GELU(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass for PatchEmbed.
+
+         Args:
+             x: A float tensor of shape (b, c, h, w).
+
+         Returns:
+             A float tensor of shape (b, h/4, w/4, embed_dim).
+         """
+         return self.proj(x).permute(0, 2, 3, 1)
+
+
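For reference, a minimal shape-check sketch for PatchEmbed, assuming the PyTorch backend is installed and importing directly from this module; the input size is arbitrary:

import torch

from doctr.models.classification.vip.layers.pytorch import PatchEmbed

# A dummy batch of two 3-channel images, 32x128 pixels (b, c, h, w)
images = torch.rand(2, 3, 32, 128)
# Two stride-2 convolutions reduce each spatial dimension by 4, then the
# result is permuted to channels-last form
patches = PatchEmbed(in_channels=3, embed_dim=128)(images)
print(patches.shape)  # torch.Size([2, 8, 32, 128]) -> (b, h/4, w/4, embed_dim)
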
+ class Attention(nn.Module):
+     """
+     Standard multi-head attention module.
+
+     This module applies self-attention across the input sequence using 'num_heads' heads.
+
+     Args:
+         dim: Dimensionality of the input embeddings.
+         num_heads: Number of attention heads.
+         qkv_bias: If True, adds a learnable bias to the query, key, value projections.
+         attn_drop: Dropout rate applied to the attention map.
+         proj_drop: Dropout rate applied to the final output projection.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass for Attention.
+
+         Args:
+             x: A float tensor of shape (b, n, c), where n is the sequence length and c is
+                 the embedding dimension.
+
+         Returns:
+             A float tensor of shape (b, n, c) with attended information.
+         """
+         _, n, c = x.shape
+         qkv = self.qkv(x).reshape((-1, n, 3, self.num_heads, c // self.num_heads)).permute((2, 0, 3, 1, 4))
+         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+         attn = q.matmul(k.permute((0, 1, 3, 2)))
+         attn = nn.functional.softmax(attn, dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = attn.matmul(v).permute((0, 2, 1, 3)).contiguous().reshape((-1, n, c))
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
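A hedged sketch of Attention on a flattened feature map (shapes chosen arbitrarily); the sequence shape is preserved end to end:

import torch

from doctr.models.classification.vip.layers.pytorch import Attention

# 2 samples, 8 * 32 = 256 tokens, embedding dimension 128 (b, n, c)
tokens = torch.rand(2, 256, 128)
attn = Attention(dim=128, num_heads=8, qkv_bias=True)
out = attn(tokens)
print(out.shape)  # torch.Size([2, 256, 128])
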
+ class MultiHeadSelfAttention(nn.Module):
+     """
+     Multi-head Self Attention block with an MLP for feed-forward processing.
+
+     This block normalizes the input, applies attention mixing, adds a residual connection,
+     then applies an MLP with another residual connection.
+
+     Args:
+         dim: Dimensionality of input embeddings.
+         num_heads: Number of attention heads.
+         mlp_ratio: Expansion factor for the internal dimension of the MLP.
+         qkv_bias: If True, adds a learnable bias to the query, key, value projections.
+         drop_path_rate: Drop path rate. If > 0, applies stochastic depth.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         drop_path_rate: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.norm1 = nn.LayerNorm(dim)
+
+         self.mixer = Attention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+         )
+
+         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+         self.norm2 = nn.LayerNorm(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = PositionwiseFeedForward(d_model=dim, ffd=mlp_hidden_dim, dropout=0.0, activation_fct=nn.GELU())
+
+     def forward(self, x: torch.Tensor, size: tuple[int, int] | None = None) -> torch.Tensor:
+         """
+         Forward pass for MultiHeadSelfAttention.
+
+         Args:
+             x: A float tensor of shape (b, n, c).
+             size: An optional (h, w) if needed by some modules (unused here).
+
+         Returns:
+             A float tensor of shape (b, n, c) after self-attention and MLP.
+         """
+         x = x + self.drop_path(self.mixer(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
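A corresponding sketch for MultiHeadSelfAttention; since both residual branches keep the token shape, the output matches the input:

import torch

from doctr.models.classification.vip.layers.pytorch import MultiHeadSelfAttention

tokens = torch.rand(2, 256, 128)  # (b, n, c)
block = MultiHeadSelfAttention(dim=128, num_heads=8, mlp_ratio=4.0)
# Pre-norm attention followed by a pre-norm MLP, each with a residual add
print(block(tokens).shape)  # torch.Size([2, 256, 128])
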
+ class OverlappedSpatialReductionAttention(nn.Module):
+     """
+     Overlapped Spatial Reduction Attention (OSRA).
+
+     This attention mechanism downsamples the input according to 'sr_ratio' (spatial reduction ratio),
+     applies a local convolution for feature enhancement. It captures dependencies in an overlapping manner.
+
+     Args:
+         dim: The embedding dimension of the tokens.
+         num_heads: Number of attention heads.
+         qk_scale: Optionally override q-k scaling. Defaults to head_dim^-0.5 if None.
+         attn_drop: Dropout rate for attention weights.
+         sr_ratio: Spatial reduction ratio. If > 1, a depthwise conv-based downsampling is applied.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 1,
+         qk_scale: float | None = None,
+         attn_drop: float = 0.0,
+         sr_ratio: int = 1,
+     ) -> None:
+         super().__init__()
+         assert dim % num_heads == 0, f"dim {dim} should be divisible by num_heads {num_heads}."
+         self.dim = dim
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim**-0.5
+         self.sr_ratio = sr_ratio
+         self.q = nn.Conv2d(dim, dim, kernel_size=1)
+         self.kv = nn.Conv2d(dim, dim * 2, kernel_size=1)
+         self.attn_drop = nn.Dropout(attn_drop)
+
+         if sr_ratio > 1:
+             self.sr = nn.Sequential(
+                 *conv_sequence_pt(
+                     dim,
+                     dim,
+                     kernel_size=sr_ratio + 3,
+                     stride=sr_ratio,
+                     padding=(sr_ratio + 3) // 2,
+                     groups=dim,
+                     bias=False,
+                     bn=True,
+                     relu=False,
+                 ),
+                 nn.GELU(),
+                 *conv_sequence_pt(dim, dim, kernel_size=1, groups=dim, bias=False, bn=True, relu=False),
+             )
+         else:
+             self.sr = nn.Identity()  # type: ignore[assignment]
+
+         self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
+
+     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+         """
+         Forward pass for OverlappedSpatialReductionAttention.
+
+         Args:
+             x: A float tensor of shape (b, n, c) where n = h * w.
+             size: A tuple (h, w) giving the height and width of the original feature map.
+
+         Returns:
+             A float tensor of shape (b, n, c) with updated representations.
+         """
+         b, n, c = x.shape
+         h, w = size
+         x = x.permute(0, 2, 1).reshape(b, -1, h, w)
+
+         q = self.q(x).reshape(b, self.num_heads, c // self.num_heads, -1).transpose(-1, -2)
+         kv = self.sr(x)
+         kv = self.local_conv(kv) + kv
+         k, v = torch.chunk(self.kv(kv), chunks=2, dim=1)
+         k = k.reshape(b, self.num_heads, c // self.num_heads, -1)
+         v = v.reshape(b, self.num_heads, c // self.num_heads, -1).transpose(-1, -2)
+
+         attn = (q @ k) * self.scale
+         attn = torch.softmax(attn, dim=-1)
+         attn = self.attn_drop(attn)
+         x = (attn @ v).transpose(-1, -2).reshape(b, c, -1)
+         x = x.permute(0, 2, 1)
+         return x
+
+
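A hedged sketch for OverlappedSpatialReductionAttention; unlike the plain Attention above, it needs the (h, w) size of the feature map so the tokens can be folded back into a 2D grid before the strided reduction:

import torch

from doctr.models.classification.vip.layers.pytorch import OverlappedSpatialReductionAttention

# Tokens from an 8x32 feature map: n = 8 * 32 = 256
tokens = torch.rand(2, 256, 128)
osra = OverlappedSpatialReductionAttention(dim=128, num_heads=2, sr_ratio=2)
# Keys/values are computed on a feature map downsampled by sr_ratio, but the
# output still has one token per input position
print(osra(tokens, size=(8, 32)).shape)  # torch.Size([2, 256, 128])
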
+ class OSRABlock(nn.Module):
+     """
+     Global token mixing block using Overlapped Spatial Reduction Attention (OSRA).
+
+     Captures global dependencies by aggregating context from a wider spatial area,
+     followed by a position-wise feed-forward layer.
+
+     Args:
+         dim: Embedding dimension of tokens.
+         sr_ratio: Spatial reduction ratio for OSRA.
+         num_heads: Number of attention heads.
+         mlp_ratio: Expansion factor for the MLP hidden dimension.
+         drop_path: Drop path rate. If > 0, applies stochastic depth.
+     """
+
+     def __init__(
+         self,
+         dim: int = 64,
+         sr_ratio: int = 1,
+         num_heads: int = 1,
+         mlp_ratio: float = 4.0,
+         drop_path: float = 0.0,
+     ) -> None:
+         super().__init__()
+         mlp_hidden_dim = int(dim * mlp_ratio)
+
+         self.norm1 = nn.LayerNorm(dim)
+         self.token_mixer = OverlappedSpatialReductionAttention(dim, num_heads=num_heads, sr_ratio=sr_ratio)
+         self.norm2 = nn.LayerNorm(dim)
+
+         self.mlp = PositionwiseFeedForward(d_model=dim, ffd=mlp_hidden_dim, dropout=0.0, activation_fct=nn.GELU())
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+         """
+         Forward pass for OSRABlock.
+
+         Args:
+             x: A float tensor of shape (b, n, c).
+             size: A tuple (h, w) giving the height and width of the original feature map.
+
+         Returns:
+             A float tensor of shape (b, n, c) with globally mixed features.
+         """
+         x = x + self.drop_path(self.token_mixer(self.norm1(x), size))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class PatchMerging(nn.Module):
+     """
+     Patch Merging Layer.
+
+     Reduces the spatial dimension by half along the height. If the input has shape
+     (b, h, w, c), the output shape becomes (b, h//2, w, out_dim).
+
+     Args:
+         dim: Number of input channels.
+         out_dim: Number of output channels after merging.
+     """
+
+     def __init__(self, dim: int, out_dim: int) -> None:
+         super().__init__()
+         self.dim = dim
+         self.reduction = nn.Conv2d(dim, out_dim, 3, (2, 1), 1)
+         self.norm = nn.LayerNorm(out_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass for PatchMerging.
+
+         Args:
+             x: A float tensor of shape (b, h, w, c).
+
+         Returns:
+             A float tensor of shape (b, h//2, w, out_dim).
+         """
+         x = x.permute(0, 3, 1, 2)
+         x = self.reduction(x).permute(0, 2, 3, 1)
+         return self.norm(x)
+
+
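A sketch for PatchMerging on a channels-last feature map (shapes arbitrary); only the height is halved, which matches the asymmetric (2, 1) stride:

import torch

from doctr.models.classification.vip.layers.pytorch import PatchMerging

feat = torch.rand(2, 8, 32, 128)  # (b, h, w, c)
merging = PatchMerging(dim=128, out_dim=256)
# Height is halved by the (2, 1)-strided convolution, width is preserved
print(merging(feat).shape)  # torch.Size([2, 4, 32, 256])
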
+ class LePEAttention(nn.Module):
+     """
+     Local Enhancement Positional Encoding (LePE) Attention.
+
+     This is used for computing attention in cross-shaped windows (part of CrossShapedWindowAttention),
+     and includes a learnable position encoding via depthwise convolution.
+
+     Args:
+         dim: Embedding dimension.
+         idx: Index used to determine the direction/split dimension for cross-shaped windows:
+             - idx == -1: no splitting (attend to all).
+             - idx == 0: vertical split.
+             - idx == 1: horizontal split.
+         split_size: Size of the split window.
+         dim_out: Output dimension; if None, defaults to `dim`.
+         num_heads: Number of attention heads.
+         attn_drop: Dropout rate for attention weights.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         idx: int,
+         split_size: int = 7,
+         dim_out: int | None = None,
+         num_heads: int = 8,
+         attn_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.dim = dim
+         self.dim_out = dim_out or dim
+         self.split_size = split_size
+         self.num_heads = num_heads
+         self.idx = idx
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
+         self.attn_drop = nn.Dropout(attn_drop)
+
+     def img2windows(self, img: torch.Tensor, h_sp: int, w_sp: int) -> torch.Tensor:
+         """
+         Slice an image into windows of shape (h_sp, w_sp).
+
+         Args:
+             img: A float tensor of shape (b, c, h, w).
+             h_sp: The window's height.
+             w_sp: The window's width.
+
+         Returns:
+             A float tensor of shape (b', h_sp*w_sp, c), where b' = b * (h//h_sp) * (w//w_sp).
+         """
+         b, c, h, w = img.shape
+         img_reshape = img.view(b, c, h // h_sp, h_sp, w // w_sp, w_sp)
+         img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).reshape(-1, h_sp * w_sp, c)
+         return img_perm
+
+     def windows2img(self, img_splits_hw: torch.Tensor, h_sp: int, w_sp: int, h: int, w: int) -> torch.Tensor:
+         """
+         Merge windowed images back to the original spatial shape.
+
+         Args:
+             img_splits_hw: A float tensor of shape (b', h_sp*w_sp, c).
+             h_sp: Window height.
+             w_sp: Window width.
+             h: Original height.
+             w: Original width.
+
+         Returns:
+             A float tensor of shape (b, h, w, c).
+         """
+         b_merged = int(img_splits_hw.shape[0] / (h * w / h_sp / w_sp))
+         img = img_splits_hw.view(b_merged, h // h_sp, w // w_sp, h_sp, w_sp, -1)
+         # contiguous() required to ensure the tensor has a contiguous memory layout
+         # after permute, allowing the subsequent view operation to work correctly.
+         img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(b_merged, h, w, -1)
+         return img
+
+     def _get_split(self, size: tuple[int, int]) -> tuple[int, int]:
+         """
+         Determine how to split the height/width for the cross-shaped windows.
+
+         Args:
+             size: A tuple (h, w).
+
+         Returns:
+             A tuple (h_sp, w_sp) indicating split window dimensions.
+         """
+         h, w = size
+         if self.idx == -1:
+             return h, w
+         elif self.idx == 0:
+             return h, self.split_size
+         elif self.idx == 1:
+             return self.split_size, w
+         else:
+             raise ValueError("idx must be -1, 0, or 1")
+
+     def im2cswin(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+         """
+         Re-arrange features into cross-shaped windows for Q/K.
+
+         Args:
+             x: A float tensor of shape (b, n, c).
+             size: A tuple (h, w).
+
+         Returns:
+             A float tensor of shape (b', num_heads, h_sp*w_sp, c//num_heads).
+         """
+         b, n, c = x.shape
+         h, w = size
+         x = x.transpose(-2, -1).view(b, c, h, w)
+         h_sp, w_sp = self._get_split(size)
+
+         x = self.img2windows(x, h_sp, w_sp)
+         x = x.reshape(-1, h_sp * w_sp, self.num_heads, c // self.num_heads).permute(0, 2, 1, 3)
+         return x
+
+     def get_lepe(self, x: torch.Tensor, size: tuple[int, int]) -> tuple[torch.Tensor, torch.Tensor]:
+         """
+         Compute the learnable position encoding via depthwise convolution.
+
+         Args:
+             x: A float tensor of shape (b, n, c).
+             size: A tuple (h, w).
+
+         Returns:
+             x: A float tensor rearranged for V in shape (b', num_heads, n_window, c//num_heads).
+             lepe: A position encoding tensor of the same shape as x.
+         """
+         b, n, c = x.shape
+         h, w = size
+         x = x.transpose(-2, -1).view(b, c, h, w)
+         h_sp, w_sp = self._get_split(size)
+
+         x = x.view(b, c, h // h_sp, h_sp, w // w_sp, w_sp)
+         x = x.permute(0, 2, 4, 1, 3, 5).reshape(-1, c, h_sp, w_sp)  # b', c, h_sp, w_sp
+
+         lepe = self.get_v(x)
+         lepe = lepe.reshape(-1, self.num_heads, c // self.num_heads, h_sp * w_sp).permute(0, 1, 3, 2)
+
+         x = x.reshape(-1, self.num_heads, c // self.num_heads, h_sp * w_sp).permute(0, 1, 3, 2)
+         return x, lepe
+
+     def forward(self, qkv: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+         """
+         Forward pass for LePEAttention.
+
+         Splits Q/K/V according to cross-shaped windows, computes attention,
+         and returns the combined features.
+
+         Args:
+             qkv: A tensor of shape (3, b, n, c) containing Q, K, and V.
+             size: A tuple (h, w) giving the height and width of the image/feature map.
+
+         Returns:
+             A float tensor of shape (b, n, c) after cross-shaped window attention with LePE.
+         """
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         h, w = size
+         b, n, c = q.shape
+
+         h_sp, w_sp = self._get_split(size)
+         q = self.im2cswin(q, size)
+         k = self.im2cswin(k, size)
+         v, lepe = self.get_lepe(v, size)
+
+         q = q * self.scale
+         attn = q @ k.transpose(-2, -1)  # (b', head, n_window, n_window)
+         attn = nn.functional.softmax(attn, dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v) + lepe
+         x = x.transpose(1, 2).reshape(-1, h_sp * w_sp, c)
+         # Window2Img
+         x = self.windows2img(x, h_sp, w_sp, h, w).view(b, -1, c)
+         return x
+
+
+ class CrossShapedWindowAttention(nn.Module):
+     """
+     Local mixing module, performing attention within cross-shaped windows.
+
+     This captures local patterns by splitting the feature map into two cross-shaped windows:
+     vertical and horizontal slices. Each slice is passed to a LePEAttention. Outputs are
+     concatenated and projected, followed by an MLP for mixing.
+
+     Args:
+         dim: Embedding dimension.
+         num_heads: Number of attention heads.
+         split_size: Window size for splitting.
+         mlp_ratio: Expansion factor for MLP hidden dimension.
+         qkv_bias: If True, adds a bias term to Q/K/V projections.
+         drop_path: Drop path rate. If > 0, applies stochastic depth.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         split_size: int = 7,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         drop_path: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.norm1 = nn.LayerNorm(dim)
+         self.proj = nn.Linear(dim, dim)
+
+         self.attns = nn.ModuleList([
+             LePEAttention(
+                 dim // 2,
+                 idx=i,
+                 split_size=split_size,
+                 num_heads=num_heads // 2,
+                 dim_out=dim // 2,
+             )
+             for i in range(2)
+         ])
+
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+         self.mlp = PositionwiseFeedForward(d_model=dim, ffd=mlp_hidden_dim, dropout=0.0, activation_fct=nn.GELU())
+         self.norm2 = nn.LayerNorm(dim)
+
+     def forward(self, x: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
+         """
+         Forward pass for CrossShapedWindowAttention.
+
+         Args:
+             x: A float tensor of shape (b, n, c), where n = h * w.
+             size: A tuple (h, w) for the height and width of the feature map.
+
+         Returns:
+             A float tensor of shape (b, n, c) after cross-shaped window attention.
+         """
+         b, _, c = x.shape
+         qkv = self.qkv(self.norm1(x)).reshape(b, -1, 3, c).permute(2, 0, 1, 3)
+
+         # Split QKV for each half, then apply cross-shaped window attention
+         x1 = self.attns[0](qkv[:, :, :, : c // 2], size)
+         x2 = self.attns[1](qkv[:, :, :, c // 2 :], size)
+
+         # Project and merge
+         merged = self.proj(torch.cat([x1, x2], dim=2))
+         x = x + self.drop_path(merged)
+
+         # MLP
+         return x + self.drop_path(self.mlp(self.norm2(x)))
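
Finally, a hedged sketch for CrossShapedWindowAttention, which drives the two LePEAttention branches internally; split_size is chosen so that it divides both the height and width of the 8x32 example map:

import torch

from doctr.models.classification.vip.layers.pytorch import CrossShapedWindowAttention

tokens = torch.rand(2, 256, 128)  # tokens of an 8x32 feature map, n = h * w
# dim and num_heads are halved internally for each of the two LePEAttention branches
cswa = CrossShapedWindowAttention(dim=128, num_heads=4, split_size=2)
print(cswa(tokens, size=(8, 32)).shape)  # torch.Size([2, 256, 128])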