python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172) hide show
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
@@ -1,9 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, Dict, List, Optional, Union
6
+ from collections.abc import Callable
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
  import torch
@@ -21,7 +22,7 @@ from .base import _FAST, FASTPostProcessor
21
22
  __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
22
23
 
23
24
 
24
- default_cfgs: Dict[str, Dict[str, Any]] = {
25
+ default_cfgs: dict[str, dict[str, Any]] = {
25
26
  "fast_tiny": {
26
27
  "input_shape": (3, 1024, 1024),
27
28
  "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,6 @@ class FastNeck(nn.Module):
47
48
  """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
48
49
 
49
50
  Args:
50
- ----
51
51
  in_channels: number of input channels
52
52
  out_channels: number of output channels
53
53
  """
@@ -77,7 +77,6 @@ class FastHead(nn.Sequential):
77
77
  """Head of the FAST architecture
78
78
 
79
79
  Args:
80
- ----
81
80
  in_channels: number of input channels
82
81
  num_classes: number of output classes
83
82
  out_channels: number of output channels
@@ -91,7 +90,7 @@ class FastHead(nn.Sequential):
91
90
  out_channels: int = 128,
92
91
  dropout: float = 0.1,
93
92
  ) -> None:
94
- _layers: List[nn.Module] = [
93
+ _layers: list[nn.Module] = [
95
94
  FASTConvLayer(in_channels, out_channels, kernel_size=3),
96
95
  nn.Dropout(dropout),
97
96
  nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
@@ -104,7 +103,6 @@ class FAST(_FAST, nn.Module):
104
103
  <https://arxiv.org/pdf/2111.02394.pdf>`_.
105
104
 
106
105
  Args:
107
- ----
108
106
  feat extractor: the backbone serving as feature extractor
109
107
  bin_thresh: threshold for binarization
110
108
  box_thresh: minimal objectness score to consider a box
@@ -125,8 +123,8 @@ class FAST(_FAST, nn.Module):
125
123
  pooling_size: int = 4, # different from paper performs better on close text-rich images
126
124
  assume_straight_pages: bool = True,
127
125
  exportable: bool = False,
128
- cfg: Optional[Dict[str, Any]] = {},
129
- class_names: List[str] = [CLASS_NAME],
126
+ cfg: dict[str, Any] = {},
127
+ class_names: list[str] = [CLASS_NAME],
130
128
  ) -> None:
131
129
  super().__init__()
132
130
  self.class_names = class_names
@@ -172,13 +170,22 @@ class FAST(_FAST, nn.Module):
172
170
  m.weight.data.fill_(1.0)
173
171
  m.bias.data.zero_()
174
172
 
173
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
174
+ """Load pretrained parameters onto the model
175
+
176
+ Args:
177
+ path_or_url: the path or URL to the model parameters (checkpoint)
178
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
179
+ """
180
+ load_pretrained_params(self, path_or_url, **kwargs)
181
+
175
182
  def forward(
176
183
  self,
177
184
  x: torch.Tensor,
178
- target: Optional[List[np.ndarray]] = None,
185
+ target: list[np.ndarray] | None = None,
179
186
  return_model_output: bool = False,
180
187
  return_preds: bool = False,
181
- ) -> Dict[str, torch.Tensor]:
188
+ ) -> dict[str, torch.Tensor]:
182
189
  # Extract feature maps at different stages
183
190
  feats = self.feat_extractor(x)
184
191
  feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -186,7 +193,7 @@ class FAST(_FAST, nn.Module):
186
193
  feat_concat = self.neck(feats)
187
194
  logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")
188
195
 
189
- out: Dict[str, Any] = {}
196
+ out: dict[str, Any] = {}
190
197
  if self.exportable:
191
198
  out["logits"] = logits
192
199
  return out
@@ -198,11 +205,16 @@ class FAST(_FAST, nn.Module):
198
205
  out["out_map"] = prob_map
199
206
 
200
207
  if target is None or return_preds:
208
+ # Disable for torch.compile compatibility
209
+ @torch.compiler.disable # type: ignore[attr-defined]
210
+ def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
211
+ return [
212
+ dict(zip(self.class_names, preds))
213
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
214
+ ]
215
+
201
216
  # Post-process boxes (keep only text predictions)
202
- out["preds"] = [
203
- dict(zip(self.class_names, preds))
204
- for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
205
- ]
217
+ out["preds"] = _postprocess(prob_map)
206
218
 
207
219
  if target is not None:
208
220
  loss = self.compute_loss(logits, target)
@@ -213,19 +225,17 @@ class FAST(_FAST, nn.Module):
213
225
  def compute_loss(
214
226
  self,
215
227
  out_map: torch.Tensor,
216
- target: List[np.ndarray],
228
+ target: list[np.ndarray],
217
229
  eps: float = 1e-6,
218
230
  ) -> torch.Tensor:
219
231
  """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
220
232
 
221
233
  Args:
222
- ----
223
234
  out_map: output feature map of the model of shape (N, num_classes, H, W)
224
235
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
225
236
  eps: epsilon factor in dice loss
226
237
 
227
238
  Returns:
228
- -------
229
239
  A loss tensor
230
240
  """
231
241
  targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
@@ -279,15 +289,13 @@ class FAST(_FAST, nn.Module):
279
289
  return text_loss + kernel_loss
280
290
 
281
291
 
282
- def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
292
+ def reparameterize(model: FAST | nn.Module) -> FAST:
283
293
  """Fuse batchnorm and conv layers and reparameterize the model
284
294
 
285
- args:
286
- ----
295
+ Args:
287
296
  model: the FAST model to reparameterize
288
297
 
289
298
  Returns:
290
- -------
291
299
  the reparameterized model
292
300
  """
293
301
  last_conv = None
@@ -303,12 +311,12 @@ def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
303
311
  if last_conv is None:
304
312
  continue
305
313
  conv_w = last_conv.weight
306
- conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)
314
+ conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean) # type: ignore[arg-type]
307
315
 
308
- factor = child.weight / torch.sqrt(child.running_var + child.eps)
316
+ factor = child.weight / torch.sqrt(child.running_var + child.eps) # type: ignore
309
317
  last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
310
318
  last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
311
- model._modules[last_conv_name] = last_conv
319
+ model._modules[last_conv_name] = last_conv # type: ignore[index]
312
320
  model._modules[name] = nn.Identity()
313
321
  last_conv = None
314
322
  elif isinstance(child, nn.Conv2d):
@@ -324,9 +332,9 @@ def _fast(
324
332
  arch: str,
325
333
  pretrained: bool,
326
334
  backbone_fn: Callable[[bool], nn.Module],
327
- feat_layers: List[str],
335
+ feat_layers: list[str],
328
336
  pretrained_backbone: bool = True,
329
- ignore_keys: Optional[List[str]] = None,
337
+ ignore_keys: list[str] | None = None,
330
338
  **kwargs: Any,
331
339
  ) -> FAST:
332
340
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -350,7 +358,7 @@ def _fast(
350
358
  _ignore_keys = (
351
359
  ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
352
360
  )
353
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
361
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
354
362
 
355
363
  return model
356
364
 
@@ -366,12 +374,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
366
374
  >>> out = model(input_tensor)
367
375
 
368
376
  Args:
369
- ----
370
377
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
371
378
  **kwargs: keyword arguments of the FAST architecture
372
379
 
373
380
  Returns:
374
- -------
375
381
  text detection architecture
376
382
  """
377
383
  return _fast(
@@ -395,12 +401,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
395
401
  >>> out = model(input_tensor)
396
402
 
397
403
  Args:
398
- ----
399
404
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
400
405
  **kwargs: keyword arguments of the FAST architecture
401
406
 
402
407
  Returns:
403
- -------
404
408
  text detection architecture
405
409
  """
406
410
  return _fast(
@@ -424,12 +428,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
424
428
  >>> out = model(input_tensor)
425
429
 
426
430
  Args:
427
- ----
428
431
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
429
432
  **kwargs: keyword arguments of the FAST architecture
430
433
 
431
434
  Returns:
432
- -------
433
435
  text detection architecture
434
436
  """
435
437
  return _fast(
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
8
  from copy import deepcopy
9
- from typing import Any, Dict, List, Optional, Tuple, Union
9
+ from typing import Any
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
@@ -23,7 +23,7 @@ from .base import _FAST, FASTPostProcessor
23
23
  __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
24
24
 
25
25
 
26
- default_cfgs: Dict[str, Dict[str, Any]] = {
26
+ default_cfgs: dict[str, dict[str, Any]] = {
27
27
  "fast_tiny": {
28
28
  "input_shape": (1024, 1024, 3),
29
29
  "mean": (0.798, 0.785, 0.772),
@@ -49,7 +49,6 @@ class FastNeck(layers.Layer, NestedObject):
49
49
  """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
50
50
 
51
51
  Args:
52
- ----
53
52
  in_channels: number of input channels
54
53
  out_channels: number of output channels
55
54
  """
@@ -77,7 +76,6 @@ class FastHead(Sequential):
77
76
  """Head of the FAST architecture
78
77
 
79
78
  Args:
80
- ----
81
79
  in_channels: number of input channels
82
80
  num_classes: number of output classes
83
81
  out_channels: number of output channels
@@ -104,7 +102,6 @@ class FAST(_FAST, Model, NestedObject):
104
102
  <https://arxiv.org/pdf/2111.02394.pdf>`_.
105
103
 
106
104
  Args:
107
- ----
108
105
  feature extractor: the backbone serving as feature extractor
109
106
  bin_thresh: threshold for binarization
110
107
  box_thresh: minimal objectness score to consider a box
@@ -116,7 +113,7 @@ class FAST(_FAST, Model, NestedObject):
116
113
  class_names: list of class names
117
114
  """
118
115
 
119
- _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"]
116
+ _children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]
120
117
 
121
118
  def __init__(
122
119
  self,
@@ -127,8 +124,8 @@ class FAST(_FAST, Model, NestedObject):
127
124
  pooling_size: int = 4, # different from paper performs better on close text-rich images
128
125
  assume_straight_pages: bool = True,
129
126
  exportable: bool = False,
130
- cfg: Optional[Dict[str, Any]] = {},
131
- class_names: List[str] = [CLASS_NAME],
127
+ cfg: dict[str, Any] = {},
128
+ class_names: list[str] = [CLASS_NAME],
132
129
  ) -> None:
133
130
  super().__init__()
134
131
  self.class_names = class_names
@@ -156,22 +153,29 @@ class FAST(_FAST, Model, NestedObject):
156
153
  # Pooling layer as erosion reversal as described in the paper
157
154
  self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")
158
155
 
156
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
157
+ """Load pretrained parameters onto the model
158
+
159
+ Args:
160
+ path_or_url: the path or URL to the model parameters (checkpoint)
161
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
162
+ """
163
+ load_pretrained_params(self, path_or_url, **kwargs)
164
+
159
165
  def compute_loss(
160
166
  self,
161
167
  out_map: tf.Tensor,
162
- target: List[Dict[str, np.ndarray]],
168
+ target: list[dict[str, np.ndarray]],
163
169
  eps: float = 1e-6,
164
170
  ) -> tf.Tensor:
165
171
  """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
166
172
 
167
173
  Args:
168
- ----
169
174
  out_map: output feature map of the model of shape (N, num_classes, H, W)
170
175
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
171
176
  eps: epsilon factor in dice loss
172
177
 
173
178
  Returns:
174
- -------
175
179
  A loss tensor
176
180
  """
177
181
  targets = self.build_target(target, out_map.shape[1:], True)
@@ -222,18 +226,18 @@ class FAST(_FAST, Model, NestedObject):
222
226
  def call(
223
227
  self,
224
228
  x: tf.Tensor,
225
- target: Optional[List[Dict[str, np.ndarray]]] = None,
229
+ target: list[dict[str, np.ndarray]] | None = None,
226
230
  return_model_output: bool = False,
227
231
  return_preds: bool = False,
228
232
  **kwargs: Any,
229
- ) -> Dict[str, Any]:
233
+ ) -> dict[str, Any]:
230
234
  feat_maps = self.feat_extractor(x, **kwargs)
231
235
  # Pass through the Neck & Head & Upsample
232
236
  feat_concat = self.neck(feat_maps, **kwargs)
233
237
  logits: tf.Tensor = self.head(feat_concat, **kwargs)
234
238
  logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
235
239
 
236
- out: Dict[str, tf.Tensor] = {}
240
+ out: dict[str, tf.Tensor] = {}
237
241
  if self.exportable:
238
242
  out["logits"] = logits
239
243
  return out
@@ -255,15 +259,14 @@ class FAST(_FAST, Model, NestedObject):
255
259
  return out
256
260
 
257
261
 
258
- def reparameterize(model: Union[FAST, layers.Layer]) -> FAST:
262
+ def reparameterize(model: FAST | layers.Layer) -> FAST:
259
263
  """Fuse batchnorm and conv layers and reparameterize the model
260
264
 
261
265
  Args:
262
- ----
266
+
263
267
  model: the FAST model to reparameterize
264
268
 
265
269
  Returns:
266
- -------
267
270
  the reparameterized model
268
271
  """
269
272
  last_conv = None
@@ -306,9 +309,9 @@ def _fast(
306
309
  arch: str,
307
310
  pretrained: bool,
308
311
  backbone_fn,
309
- feat_layers: List[str],
312
+ feat_layers: list[str],
310
313
  pretrained_backbone: bool = True,
311
- input_shape: Optional[Tuple[int, int, int]] = None,
314
+ input_shape: tuple[int, int, int] | None = None,
312
315
  **kwargs: Any,
313
316
  ) -> FAST:
314
317
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -338,8 +341,7 @@ def _fast(
338
341
  # Load pretrained parameters
339
342
  if pretrained:
340
343
  # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
341
- load_pretrained_params(
342
- model,
344
+ model.from_pretrained(
343
345
  _cfg["url"],
344
346
  skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
345
347
  )
@@ -358,12 +360,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
358
360
  >>> out = model(input_tensor)
359
361
 
360
362
  Args:
361
- ----
362
363
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
363
364
  **kwargs: keyword arguments of the DBNet architecture
364
365
 
365
366
  Returns:
366
- -------
367
367
  text detection architecture
368
368
  """
369
369
  return _fast(
@@ -386,12 +386,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
386
386
  >>> out = model(input_tensor)
387
387
 
388
388
  Args:
389
- ----
390
389
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
391
390
  **kwargs: keyword arguments of the DBNet architecture
392
391
 
393
392
  Returns:
394
- -------
395
393
  text detection architecture
396
394
  """
397
395
  return _fast(
@@ -414,12 +412,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
414
412
  >>> out = model(input_tensor)
415
413
 
416
414
  Args:
417
- ----
418
415
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
419
416
  **kwargs: keyword arguments of the DBNet architecture
420
417
 
421
418
  Returns:
422
- -------
423
419
  text detection architecture
424
420
  """
425
421
  return _fast(
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,11 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
- from typing import Dict, List, Tuple, Union
9
8
 
10
9
  import cv2
11
10
  import numpy as np
@@ -23,7 +22,6 @@ class LinkNetPostProcessor(DetectionPostProcessor):
23
22
  """Implements a post processor for LinkNet model.
24
23
 
25
24
  Args:
26
- ----
27
25
  bin_thresh: threshold used to binarize p_map at inference time
28
26
  box_thresh: minimal objectness score to consider a box
29
27
  assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class LinkNetPostProcessor(DetectionPostProcessor):
45
43
  """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
46
44
 
47
45
  Args:
48
- ----
49
46
  points: the points of the polygon to expand
50
47
 
51
48
  Returns:
52
- -------
53
49
  a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
54
50
  """
55
51
  if not self.assume_straight_pages:
@@ -60,9 +56,8 @@ class LinkNetPostProcessor(DetectionPostProcessor):
60
56
  area = (rect[1][0] + 1) * (1 + rect[1][1])
61
57
  length = 2 * (rect[1][0] + rect[1][1]) + 2
62
58
  else:
63
- poly = Polygon(points)
64
- area = poly.area
65
- length = poly.length
59
+ area = cv2.contourArea(points)
60
+ length = cv2.arcLength(points, closed=True)
66
61
  distance = area * self.unclip_ratio / length # compute distance to expand polygon
67
62
  offset = pyclipper.PyclipperOffset()
68
63
  offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -94,24 +89,22 @@ class LinkNetPostProcessor(DetectionPostProcessor):
94
89
  """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
95
90
 
96
91
  Args:
97
- ----
98
92
  pred: Pred map from differentiable linknet output
99
93
  bitmap: Bitmap map computed from pred (binarized)
100
94
  angle_tol: Comparison tolerance of the angle with the median angle across the page
101
95
  ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
102
96
 
103
97
  Returns:
104
- -------
105
98
  np tensor boxes for the bitmap, each box is a 6-element list
106
99
  containing x, y, w, h, alpha, score for the box
107
100
  """
108
101
  height, width = bitmap.shape[:2]
109
- boxes: List[Union[np.ndarray, List[float]]] = []
102
+ boxes: list[np.ndarray | list[float]] = []
110
103
  # get contours from connected components on the bitmap
111
104
  contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
112
105
  for contour in contours:
113
106
  # Check whether smallest enclosing bounding box is not too small
114
- if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
107
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
115
108
  continue
116
109
  # Compute objectness
117
110
  if self.assume_straight_pages:
@@ -152,7 +145,6 @@ class _LinkNet(BaseModel):
152
145
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
153
146
 
154
147
  Args:
155
- ----
156
148
  out_chan: number of channels for the output
157
149
  """
158
150
 
@@ -162,20 +154,18 @@ class _LinkNet(BaseModel):
162
154
 
163
155
  def build_target(
164
156
  self,
165
- target: List[Dict[str, np.ndarray]],
166
- output_shape: Tuple[int, int, int],
157
+ target: list[dict[str, np.ndarray]],
158
+ output_shape: tuple[int, int, int],
167
159
  channels_last: bool = True,
168
- ) -> Tuple[np.ndarray, np.ndarray]:
160
+ ) -> tuple[np.ndarray, np.ndarray]:
169
161
  """Build the target and its mask, to be used for loss computation.
170
162
 
171
163
  Args:
172
- ----
173
164
  target: target coming from dataset
174
165
  output_shape: shape of the output of the model without batch_size
175
166
  channels_last: whether channels are last or not
176
167
 
177
168
  Returns:
178
- -------
179
169
  the new formatted target and the mask
180
170
  """
181
171
  if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):