python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
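The listing above removes every TensorFlow backend module (the */tensorflow.py files and doctr/datasets/loader.py) and adds a PyTorch-only VIP classification backbone (doctr/models/classification/vip/) together with a VIPTR recognizer (doctr/models/recognition/viptr/). A minimal sketch of post-upgrade usage, assuming the public predictor API keeps its current shape in 1.0.0:

    # Minimal sketch, assuming ocr_predictor / DocumentFile are unchanged (PyTorch-only backend)
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    doc = DocumentFile.from_images(["page.jpg"])   # hypothetical sample image path
    predictor = ocr_predictor(pretrained=True)     # no TF/PyTorch backend switch needed anymore
    result = predictor(doc)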
@@ -0,0 +1,505 @@
+ # Copyright (C) 2021-2025, Mindee.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ from copy import deepcopy
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+
+ from doctr.datasets import VOCABS
+ from doctr.models.modules.layers import AdaptiveAvgPool2d
+
+ from ...utils import load_pretrained_params
+ from .layers import (
+     CrossShapedWindowAttention,
+     MultiHeadSelfAttention,
+     OSRABlock,
+     PatchEmbed,
+     PatchMerging,
+     PermuteLayer,
+     SqueezeLayer,
+ )
+
+ __all__ = ["vip_tiny", "vip_base"]
+
+ default_cfgs: dict[str, dict[str, Any]] = {
+     "vip_tiny": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 32),
+         "classes": list(VOCABS["french"]),
+         "url": "https://doctr-static.mindee.com/models?id=v0.11.0/vip_tiny-033ed51c.pt&src=0",
+     },
+     "vip_base": {
+         "mean": (0.694, 0.695, 0.693),
+         "std": (0.299, 0.296, 0.301),
+         "input_shape": (3, 32, 32),
+         "classes": list(VOCABS["french"]),
+         "url": "https://doctr-static.mindee.com/models?id=v0.11.0/vip_base-f6ea2ff5.pt&src=0",
+     },
+ }
+
+
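The default_cfgs entries above carry the normalization statistics and the channels-first input_shape consumed by doctr's preprocessing. A minimal sketch (plain torch rather than doctr's PreProcessor) of normalizing a crop batch with the vip_tiny values:

    import torch

    # hedged sketch: normalize a crop batch with the vip_tiny cfg values above
    x = torch.rand(8, 3, 32, 32)  # batch in [0, 1], matching input_shape (3, 32, 32)
    mean = torch.tensor((0.694, 0.695, 0.693)).view(1, 3, 1, 1)
    std = torch.tensor((0.299, 0.296, 0.301)).view(1, 3, 1, 1)
    x = (x - mean) / std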
+ class ClassifierHead(nn.Module):
+     """Classification head which averages the features and applies a linear layer."""
+
+     def __init__(self, in_features: int, out_features: int):
+         super().__init__()
+         self.fc = nn.Linear(in_features, out_features)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.fc(x.mean(dim=1))
+
+
+ class VIPBlock(nn.Module):
+     """Unified block for Local, Global, and Mixed feature mixing in VIP architecture."""
+
+     def __init__(
+         self,
+         embed_dim: int,
+         local_unit: nn.ModuleList,
+         global_unit: nn.ModuleList | None = None,
+         proj: nn.Module | None = None,
+         downsample: bool = False,
+         out_dim: int | None = None,
+     ):
+         """
+         Args:
+             embed_dim: dimension of embeddings
+             local_unit: local mixing block(s)
+             global_unit: global mixing block(s)
+             proj: projection layer used for mixed mixing
+             downsample: whether to downsample at the end
+             out_dim: out channels if downsampling
+         """
+         super().__init__()
+         if downsample and out_dim is None:  # pragma: no cover
+             raise ValueError("`out_dim` must be specified if `downsample=True`")
+
+         self.local_unit = local_unit
+         self.global_unit = global_unit
+         self.proj = proj
+         self.downsample = PatchMerging(dim=embed_dim, out_dim=out_dim) if downsample else None  # type: ignore[arg-type]
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass for VIPBlock.
+
+         Args:
+             x: input tensor (B, H, W, C)
+
+         Returns:
+             Transformed tensor
+         """
+         b, h, w, C = x.shape
+
+         # Local or Mixed
+         if self.global_unit is None:
+             # local or global only
+             for blk in self.local_unit:
+                 # Flatten to (B, H*W, C)
+                 x = x.reshape(b, -1, C)
+                 x = blk(x, (h, w))
+                 x = x.reshape(b, h, w, -1)
+         else:
+             # Mixed
+             for lblk, gblk in zip(self.local_unit, self.global_unit):
+                 x = x.reshape(b, -1, C)
+                 # chunk into two halves
+                 x1, x2 = torch.chunk(x, chunks=2, dim=2)
+                 x1 = lblk(x1, (h, w))
+                 x2 = gblk(x2, (h, w))
+                 x = torch.cat([x1, x2], dim=2)
+                 x = x.transpose(1, 2).contiguous().reshape(b, -1, h, w)
+                 x = self.proj(x) + x  # type: ignore[misc]
+                 x = x.permute(0, 2, 3, 1).contiguous()
+
+         if isinstance(self.downsample, nn.Module):
+             x = self.downsample(x)
+
+         return x
+
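In the mixed branch of VIPBlock.forward, the channel dimension is split in two, each half is mixed separately, and a convolutional projection is applied as a residual in NCHW layout before returning to (B, H, W, C). A toy sketch of that tensor plumbing, with shape-preserving stand-ins for the attention blocks:

    import torch
    import torch.nn as nn

    b, h, w, c = 2, 4, 16, 64
    x = torch.rand(b, h, w, c)                         # (B, H, W, C), as VIPBlock expects
    tokens = x.reshape(b, -1, c)                       # (B, H*W, C)
    x1, x2 = torch.chunk(tokens, chunks=2, dim=2)      # local half and global half, C // 2 each
    # x1 = local_attention(x1, (h, w)); x2 = global_attention(x2, (h, w))  # stand-ins, shape-preserving
    mixed = torch.cat([x1, x2], dim=2)                 # back to (B, H*W, C)
    spatial = mixed.transpose(1, 2).contiguous().reshape(b, c, h, w)
    proj = nn.Conv2d(c, c, kernel_size=3, padding=1, groups=c)  # simplified stand-in for `proj`
    out = (proj(spatial) + spatial).permute(0, 2, 3, 1).contiguous()
    print(out.shape)  # torch.Size([2, 4, 16, 64])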
+
+ class VIPNet(nn.Sequential):
+     """
+     VIP (Vision Permutable) encoder architecture, adapted for text recognition.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_dim: int,
+         embed_dims: list[int],
+         depths: list[int],
+         num_heads: list[int],
+         mlp_ratios: list[int],
+         split_sizes: list[int],
+         sr_ratios: list[int],
+         input_shape: tuple[int, int, int] = (3, 32, 32),
+         num_classes: int = 1000,
+         include_top: bool = True,
+         cfg: dict[str, Any] | None = None,
+     ) -> None:
+         """
+         Args:
+             in_channels: number of input channels
+             out_dim: final embedding dimension
+             embed_dims: list of embedding dims per stage
+             depths: number of blocks per stage
+             num_heads: number of heads for attention blocks
+             mlp_ratios: ratio for MLP expansion
+             split_sizes: local window split sizes
+             sr_ratios: used for some global block adjustments
+             input_shape: (C, H, W)
+             num_classes: number of output classes
+             include_top: if True, append a classification head
+             cfg: optional config dictionary
+         """
+         self.cfg = cfg
+
+         dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]
+         drop_paths = [dpr[sum(depths[:i]) : sum(depths[: i + 1])] for i in range(len(depths))]
+         layers: list[Any] = [PatchEmbed(in_channels=in_channels, embed_dim=embed_dims[0])]
+
+         # Construct mixers
+         # e.g. local, mixed, global
+         mixer_functions = [
+             _vip_local_mixer,
+             _vip_mixed_mixer,
+             _vip_global_mha_mixer,
+         ]
+
+         for i, mixer_fn in enumerate(mixer_functions):
+             embed_dim = embed_dims[i]
+             depth_i = depths[i]
+             num_head = num_heads[i]
+             mlp_ratio = mlp_ratios[i]
+             sp_size = split_sizes[i]
+             sr_ratio = sr_ratios[i]
+             drop_path = drop_paths[i]
+
+             next_dim = embed_dims[i + 1] if i < len(embed_dims) - 1 else None
+
+             block = mixer_fn(
+                 embed_dim=embed_dim,
+                 depth=depth_i,
+                 num_heads=num_head,
+                 mlp_ratio=mlp_ratio,
+                 split_size=sp_size,
+                 sr_ratio=sr_ratio,
+                 drop_path=drop_path,
+                 downsample=(next_dim is not None),
+                 out_dim=next_dim,
+             )
+             layers.append(block)
+
+         # LN -> permute -> GAP -> squeeze -> MLP
+         layers.append(
+             nn.Sequential(
+                 nn.LayerNorm(embed_dims[-1], eps=1e-6),
+                 PermuteLayer((0, 2, 3, 1)),
+                 AdaptiveAvgPool2d((embed_dims[-1], 1)),
+                 SqueezeLayer(dim=3),
+             )
+         )
+
+         mlp_head = nn.Sequential(
+             nn.Linear(embed_dims[-1], out_dim, bias=False),
+             nn.Hardswish(),
+             nn.Dropout(p=0.1),
+         )
+         layers.append(mlp_head)
+         if include_top:
+             layers.append(ClassifierHead(out_dim, num_classes))
+
+         super().__init__(*layers)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.Conv2d):
+             nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+         elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+         """Load pretrained parameters onto the model
+
+         Args:
+             path_or_url: the path or URL to the model parameters (checkpoint)
+             **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+         """
+         load_pretrained_params(self, path_or_url, **kwargs)
+
+
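The stochastic-depth schedule at the top of VIPNet.__init__ ramps the drop-path rate linearly from 0 to 0.1 over all blocks and slices it per stage. A quick check with the vip_tiny depths:

    import torch

    depths = [3, 3, 3]  # vip_tiny
    dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]
    drop_paths = [dpr[sum(depths[:i]) : sum(depths[: i + 1])] for i in range(len(depths))]
    print([len(d) for d in drop_paths])           # [3, 3, 3]
    print([round(d, 3) for d in drop_paths[-1]])  # last stage gets the highest rates, up to 0.1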
+ def vip_tiny(pretrained: bool = False, **kwargs: Any) -> VIPNet:
+     """
+     VIP-Tiny encoder architecture. Corresponds to SVIPTRv2-T variant in the paper (VIPTRv2 function
+     in the official implementation:
+     https://github.com/cxfyxl/VIPTR/blob/main/modules/VIPTRv2.py)
+
+     Args:
+         pretrained: whether to load pretrained weights
+         **kwargs: optional arguments
+
+     Returns:
+         VIPNet model
+     """
+     return _vip(
+         "vip_tiny",
+         pretrained,
+         in_channels=3,
+         out_dim=192,
+         embed_dims=[64, 128, 256],
+         depths=[3, 3, 3],
+         num_heads=[2, 4, 8],
+         mlp_ratios=[3, 4, 4],
+         split_sizes=[1, 2, 4],
+         sr_ratios=[4, 2, 2],
+         ignore_keys=["6.fc.weight", "6.fc.bias"],
+         **kwargs,
+     )
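A hedged usage sketch of the new backbone, assuming it is re-exported by doctr.models.classification (both __init__.py files gain one line in this release); the dummy input matches the cfg input_shape:

    import torch
    from doctr.models.classification import vip_tiny

    model = vip_tiny(pretrained=False).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 32, 32))  # cfg input_shape is (3, 32, 32)
    print(out.shape)  # expected (1, num_classes), i.e. (1, len(VOCABS["french"])) with the default head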
+
+
+ def vip_base(pretrained: bool = False, **kwargs: Any) -> VIPNet:
+     """
+     VIP-Base encoder architecture. Corresponds to SVIPTRv2-B variant in the paper (VIPTRv2B function
+     in the official implementation:
+     https://github.com/cxfyxl/VIPTR/blob/main/modules/VIPTRv2.py)
+
+     Args:
+         pretrained: whether to load pretrained weights
+         **kwargs: optional arguments
+
+     Returns:
+         VIPNet model
+     """
+     return _vip(
+         "vip_base",
+         pretrained,
+         in_channels=3,
+         out_dim=256,
+         embed_dims=[128, 256, 384],
+         depths=[3, 6, 9],
+         num_heads=[4, 8, 12],
+         mlp_ratios=[4, 4, 4],
+         split_sizes=[1, 2, 4],
+         sr_ratios=[4, 2, 2],
+         ignore_keys=["6.fc.weight", "6.fc.bias"],
+         **kwargs,
+     )
+
+
+ def _vip(
+     arch: str,
+     pretrained: bool,
+     ignore_keys: list[str],
+     **kwargs: Any,
+ ) -> VIPNet:
+     """
+     Internal constructor for the VIPNet models.
+
+     Args:
+         arch: architecture key
+         pretrained: load pretrained weights?
+         ignore_keys: layer keys to ignore
+         **kwargs: arguments passed to VIPNet
+
+     Returns:
+         VIPNet instance
+     """
+     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+     kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+     kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+     _cfg = deepcopy(default_cfgs[arch])
+     _cfg["num_classes"] = kwargs["num_classes"]
+     _cfg["input_shape"] = kwargs["input_shape"]
+     _cfg["classes"] = kwargs["classes"]
+     kwargs.pop("classes")
+
+     model = VIPNet(cfg=_cfg, **kwargs)
+     if pretrained:
+         # The number of classes is not the same as the number of classes in the pretrained model =>
+         # remove the last layer weights
+         _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+         model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+     return model
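Note that _vip only forwards ignore_keys when the requested head size differs from the pretrained vocabulary, so the classifier weights are dropped exactly when they could not be loaded anyway. Continuing the previous sketch:

    # hedged sketch of the ignore_keys behaviour (vip_tiny imported as in the previous sketch)
    model = vip_tiny(pretrained=True)                  # head matches the checkpoint vocabulary, nothing skipped
    model = vip_tiny(pretrained=True, num_classes=42)  # "6.fc.weight" / "6.fc.bias" are dropped on load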
+
+
+ ############################################
+ # _vip_local_mixer
+ ############################################
+ def _vip_local_mixer(
+     embed_dim: int,
+     depth: int,
+     num_heads: int,
+     mlp_ratio: float,
+     drop_path: list[float],
+     split_size: int = 1,
+     sr_ratio: int = 1,
+     downsample: bool = False,
+     out_dim: int | None = None,
+ ) -> nn.Module:
+     """Builds a VIPBlock performing local (cross-shaped) window attention.
+
+     Args:
+         embed_dim: embedding dimension.
+         depth: number of attention blocks in this stage.
+         num_heads: number of attention heads.
+         mlp_ratio: ratio used to expand the hidden dimension in MLP.
+         split_size: size of the local window splits.
+         sr_ratio: parameter needed for cross-compatibility between different mixers
+         drop_path: list of per-block drop path rates.
+         downsample: whether to apply PatchMerging at the end.
+         out_dim: output embedding dimension if downsampling.
+
+     Returns:
+         A VIPBlock (local attention) for one stage of the VIP network.
+     """
+     blocks = nn.ModuleList([
+         CrossShapedWindowAttention(
+             dim=embed_dim,
+             num_heads=num_heads,
+             mlp_ratio=mlp_ratio,
+             qkv_bias=True,
+             split_size=split_size,
+             drop_path=drop_path[i],
+         )
+         for i in range(depth)
+     ])
+     return VIPBlock(embed_dim, local_unit=blocks, downsample=downsample, out_dim=out_dim)
+
+
+ ############################################
+ # _vip_global_mha_mixer
+ ############################################
+ def _vip_global_mha_mixer(
+     embed_dim: int,
+     depth: int,
+     num_heads: int,
+     mlp_ratio: float,
+     drop_path: list[float],
+     split_size: int = 1,
+     sr_ratio: int = 1,
+     downsample: bool = False,
+     out_dim: int | None = None,
+ ) -> nn.Module:
+     """Builds a VIPBlock performing global multi-head self-attention.
+
+     Args:
+         embed_dim: embedding dimension.
+         depth: number of attention blocks in this stage.
+         num_heads: number of attention heads.
+         mlp_ratio: ratio used to expand the hidden dimension in MLP.
+         drop_path: list of per-block drop path rates.
+         split_size: parameter needed for cross-compatibility between different mixers
+         sr_ratio: parameter needed for cross-compatibility between different mixers
+         downsample: whether to apply PatchMerging at the end.
+         out_dim: output embedding dimension if downsampling.
+
+     Returns:
+         A VIPBlock (global MHA) for one stage of the VIP network.
+     """
+     blocks = nn.ModuleList([
+         MultiHeadSelfAttention(
+             dim=embed_dim,
+             num_heads=num_heads,
+             mlp_ratio=mlp_ratio,
+             qkv_bias=True,
+             drop_path_rate=drop_path[i],
+         )
+         for i in range(depth)
+     ])
+     return VIPBlock(
+         embed_dim,
+         local_unit=blocks,  # In this context, they are "global" blocks but stored in local_unit
+         downsample=downsample,
+         out_dim=out_dim,
+     )
+
+
+ ############################################
+ # _vip_mixed_mixer
+ ############################################
+ def _vip_mixed_mixer(
+     embed_dim: int,
+     depth: int,
+     num_heads: int,
+     mlp_ratio: float,
+     drop_path: list[float],
+     split_size: int = 1,
+     sr_ratio: int = 1,
+     downsample: bool = False,
+     out_dim: int | None = None,
+ ) -> nn.Module:
+     """Builds a VIPBlock performing mixed local+global attention.
+
+     Args:
+         embed_dim: embedding dimension.
+         depth: number of attention blocks in this stage.
+         num_heads: total number of attention heads.
+         mlp_ratio: ratio used to expand the hidden dimension in MLP.
+         drop_path: list of per-block drop path rates.
+         split_size: size of the local window splits (for the local half).
+         sr_ratio: reduce spatial resolution in the global half (OSRA).
+         downsample: whether to apply PatchMerging at the end.
+         out_dim: output embedding dimension if downsampling.
+
+     Returns:
+         A VIPBlock (mixed local+global) for one stage of the VIP network.
+     """
+     # an inner dimension for the conv-projection
+     inner_dim = max(16, embed_dim // 8)
+     proj = nn.Sequential(
+         nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, groups=embed_dim),
+         nn.GELU(),
+         nn.BatchNorm2d(embed_dim),
+         nn.Conv2d(embed_dim, inner_dim, kernel_size=1),
+         nn.GELU(),
+         nn.BatchNorm2d(inner_dim),
+         nn.Conv2d(inner_dim, embed_dim, kernel_size=1),
+         nn.BatchNorm2d(embed_dim),
+     )
+
+     # local half blocks
+     local_unit = nn.ModuleList([
+         CrossShapedWindowAttention(
+             dim=embed_dim // 2,
+             num_heads=num_heads,
+             mlp_ratio=mlp_ratio,
+             qkv_bias=True,
+             split_size=split_size,
+             drop_path=drop_path[i],
+         )
+         for i in range(depth)
+     ])
+
+     # global half blocks
+     global_unit = nn.ModuleList([
+         OSRABlock(
+             dim=embed_dim // 2,
+             sr_ratio=sr_ratio,
+             num_heads=num_heads // 2,
+             mlp_ratio=mlp_ratio,
+             drop_path=drop_path[i],
+         )
+         for i in range(depth)
+     ])
+
+     return VIPBlock(
+         embed_dim,
+         local_unit=local_unit,
+         global_unit=global_unit,
+         proj=proj,
+         downsample=downsample,
+         out_dim=out_dim,
+     )
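In the mixed stage, each branch receives half of the channels: the cross-shaped-window half keeps all num_heads, the OSRA half gets num_heads // 2, and the residual projection squeezes through max(16, embed_dim // 8) channels. A quick numeric check for the vip_tiny middle stage:

    embed_dim, num_heads = 128, 4         # vip_tiny stage 1 (the mixed stage)
    half_dim = embed_dim // 2             # 64 channels per branch
    local_heads = num_heads               # 4 heads for CrossShapedWindowAttention
    global_heads = num_heads // 2         # 2 heads for OSRABlock
    inner_dim = max(16, embed_dim // 8)   # 16-channel bottleneck inside `proj`
    print(half_dim, local_heads, global_heads, inner_dim)  # 64 4 2 16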
@@ -1,6 +1 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
- if is_torch_available():
-     from .pytorch import *
- elif is_tf_available():
-     from .tensorflow import *
+ from .pytorch import *
@@ -11,9 +11,9 @@ from torch import nn

  from doctr.datasets import VOCABS
  from doctr.models.modules.transformer import EncoderBlock
- from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding
+ from doctr.models.modules.vision_transformer import PatchEmbedding

- from ...utils.pytorch import load_pretrained_params
+ from ...utils import load_pretrained_params

  __all__ = ["vit_s", "vit_b"]

@@ -98,6 +98,15 @@ class VisionTransformer(nn.Sequential):
          super().__init__(*_layers)
          self.cfg = cfg

+     def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+         """Load pretrained parameters onto the model
+
+         Args:
+             path_or_url: the path or URL to the model parameters (checkpoint)
+             **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+         """
+         load_pretrained_params(self, path_or_url, **kwargs)
+

  def _vit(
      arch: str,
@@ -122,7 +131,7 @@ def _vit(
          # The number of classes is not the same as the number of classes in the pretrained model =>
          # remove the last layer weights
          _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-         load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+         model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

      return model
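The from_pretrained method added here (and mirrored on the other PyTorch backbones in this release) replaces direct load_pretrained_params calls in the factory functions. A hedged usage sketch with a hypothetical checkpoint location:

    from doctr.models.classification import vit_s

    model = vit_s(pretrained=False)
    model.from_pretrained("path/to/vit_s_checkpoint.pt")  # hypothetical local path; a URL also works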
 
@@ -5,7 +5,7 @@

  from typing import Any

- from doctr.file_utils import is_tf_available, is_torch_available
+ from doctr.models.utils import _CompiledModule

  from .. import classification
  from ..preprocessor import PreProcessor
@@ -30,7 +30,10 @@ ARCHS: list[str] = [
      "vgg16_bn_r",
      "vit_s",
      "vit_b",
+     "vip_tiny",
+     "vip_base",
  ]
+
  ORIENTATION_ARCHS: list[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


@@ -48,12 +51,8 @@ def _orientation_predictor(
          # Load directly classifier from backbone
          _model = classification.__dict__[arch](pretrained=pretrained)
      else:
-         allowed_archs = [classification.MobileNetV3]
-         if is_torch_available():
-             # Adding the type for torch compiled models to the allowed architectures
-             from doctr.models.utils import _CompiledModule
-
-             allowed_archs.append(_CompiledModule)
+         # Adding the type for torch compiled models to the allowed architectures
+         allowed_archs = [classification.MobileNetV3, _CompiledModule]

      if not isinstance(arch, tuple(allowed_archs)):
          raise ValueError(f"unknown architecture: {type(arch)}")
@@ -62,7 +61,7 @@
      kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
      kwargs["std"] = kwargs.get("std", _model.cfg["std"])
      kwargs["batch_size"] = kwargs.get("batch_size", 128 if model_type == "crop" else 4)
-     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
+     input_shape = _model.cfg["input_shape"][1:]
      predictor = OrientationPredictor(
          PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
      )
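With the TensorFlow branch gone, _orientation_predictor always allows _CompiledModule, so a torch.compile-d classifier can be passed in directly. A hedged sketch, assuming the public crop_orientation_predictor factory keeps its current signature:

    import torch
    from doctr.models import crop_orientation_predictor
    from doctr.models.classification import mobilenet_v3_small_crop_orientation

    # assumption: passing a model instance as `arch` remains supported by the public factory
    compiled = torch.compile(mobilenet_v3_small_crop_orientation(pretrained=True))
    predictor = crop_orientation_predictor(arch=compiled, pretrained=True)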
@@ -1,7 +1,2 @@
- from doctr.file_utils import is_tf_available, is_torch_available
  from .base import *
-
- if is_torch_available():
-     from .pytorch import *
- elif is_tf_available():
-     from .tensorflow import *
+ from .pytorch import *
@@ -53,7 +53,7 @@ class DetectionPostProcessor(NestedObject):

          else:
              mask: np.ndarray = np.zeros((h, w), np.int32)
-             cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+             cv2.fillPoly(mask, [points.astype(np.int32)], 1.0)
              product = pred * mask
              return np.sum(product) / np.count_nonzero(product)

@@ -1,6 +1 @@
- from doctr.file_utils import is_tf_available, is_torch_available
-
- if is_torch_available():
-     from .pytorch import *
- elif is_tf_available():
-     from .tensorflow import *  # type: ignore[assignment]
+ from .pytorch import *
@@ -58,9 +58,8 @@ class DBPostProcessor(DetectionPostProcessor):
              area = (rect[1][0] + 1) * (1 + rect[1][1])
              length = 2 * (rect[1][0] + rect[1][1]) + 2
          else:
-             poly = Polygon(points)
-             area = poly.area
-             length = poly.length
+             area = cv2.contourArea(points)
+             length = cv2.arcLength(points, closed=True)
          distance = area * self.unclip_ratio / length  # compute distance to expand polygon
          offset = pyclipper.PyclipperOffset()
          offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
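The unclip distance area * unclip_ratio / length is now computed with OpenCV instead of Shapely on this code path; for a closed contour both give the same numbers, as this small check on a 100 x 40 rectangle illustrates:

    import cv2
    import numpy as np
    from shapely.geometry import Polygon

    points = np.array([[0, 0], [100, 0], [100, 40], [0, 40]], dtype=np.float32)
    unclip_ratio = 1.5

    dist_old = Polygon(points).area * unclip_ratio / Polygon(points).length          # previous code path
    dist_new = cv2.contourArea(points) * unclip_ratio / cv2.arcLength(points, True)  # new code path
    assert abs(dist_old - dist_new) < 1e-3  # 4000 * 1.5 / 280 in both cases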
@@ -206,7 +205,7 @@ class _DBNet:
          canvas: np.ndarray,
          mask: np.ndarray,
      ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-         """Draw a polygon treshold map on a canvas, as described in the DB paper
+         """Draw a polygon threshold map on a canvas, as described in the DB paper

          Args:
              polygon : array of coord., to draw the boundary of the polygon
@@ -225,7 +224,7 @@
          padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0])

          # Fill the mask with 1 on the new padded polygon
-         cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+         cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)

          # Get min/max to recover polygon after distance computation
          xmin = padded_polygon[:, 0].min()
@@ -270,7 +269,6 @@
          self,
          target: list[dict[str, np.ndarray]],
          output_shape: tuple[int, int, int],
-         channels_last: bool = True,
      ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
          if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
              raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
@@ -281,10 +279,8 @@

          h: int
          w: int
-         if channels_last:
-             h, w, num_classes = output_shape
-         else:
-             num_classes, h, w = output_shape
+
+         num_classes, h, w = output_shape
          target_shape = (len(target), num_classes, h, w)

          seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -344,17 +340,12 @@
                      if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                          seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                          continue
-                     cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+                     cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)

                      # Draw on both thresh map and thresh mask
                      poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map(
                          poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx]
                      )
-         if channels_last:
-             seg_target = seg_target.transpose((0, 2, 3, 1))
-             seg_mask = seg_mask.transpose((0, 2, 3, 1))
-             thresh_target = thresh_target.transpose((0, 2, 3, 1))
-             thresh_mask = thresh_mask.transpose((0, 2, 3, 1))

          thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min