python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,9 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, Dict, List, Optional, Union
6
+ from collections.abc import Callable
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
  import torch
@@ -21,7 +22,7 @@ from .base import _FAST, FASTPostProcessor
21
22
  __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
22
23
 
23
24
 
24
- default_cfgs: Dict[str, Dict[str, Any]] = {
25
+ default_cfgs: dict[str, dict[str, Any]] = {
25
26
  "fast_tiny": {
26
27
  "input_shape": (3, 1024, 1024),
27
28
  "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,6 @@ class FastNeck(nn.Module):
47
48
  """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
48
49
 
49
50
  Args:
50
- ----
51
51
  in_channels: number of input channels
52
52
  out_channels: number of output channels
53
53
  """
@@ -77,7 +77,6 @@ class FastHead(nn.Sequential):
77
77
  """Head of the FAST architecture
78
78
 
79
79
  Args:
80
- ----
81
80
  in_channels: number of input channels
82
81
  num_classes: number of output classes
83
82
  out_channels: number of output channels
@@ -91,7 +90,7 @@ class FastHead(nn.Sequential):
91
90
  out_channels: int = 128,
92
91
  dropout: float = 0.1,
93
92
  ) -> None:
94
- _layers: List[nn.Module] = [
93
+ _layers: list[nn.Module] = [
95
94
  FASTConvLayer(in_channels, out_channels, kernel_size=3),
96
95
  nn.Dropout(dropout),
97
96
  nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
@@ -104,7 +103,6 @@ class FAST(_FAST, nn.Module):
104
103
  <https://arxiv.org/pdf/2111.02394.pdf>`_.
105
104
 
106
105
  Args:
107
- ----
108
106
  feat extractor: the backbone serving as feature extractor
109
107
  bin_thresh: threshold for binarization
110
108
  box_thresh: minimal objectness score to consider a box
@@ -125,8 +123,8 @@ class FAST(_FAST, nn.Module):
125
123
  pooling_size: int = 4, # different from paper performs better on close text-rich images
126
124
  assume_straight_pages: bool = True,
127
125
  exportable: bool = False,
128
- cfg: Optional[Dict[str, Any]] = {},
129
- class_names: List[str] = [CLASS_NAME],
126
+ cfg: dict[str, Any] = {},
127
+ class_names: list[str] = [CLASS_NAME],
130
128
  ) -> None:
131
129
  super().__init__()
132
130
  self.class_names = class_names
@@ -175,10 +173,10 @@ class FAST(_FAST, nn.Module):
175
173
  def forward(
176
174
  self,
177
175
  x: torch.Tensor,
178
- target: Optional[List[np.ndarray]] = None,
176
+ target: list[np.ndarray] | None = None,
179
177
  return_model_output: bool = False,
180
178
  return_preds: bool = False,
181
- ) -> Dict[str, torch.Tensor]:
179
+ ) -> dict[str, torch.Tensor]:
182
180
  # Extract feature maps at different stages
183
181
  feats = self.feat_extractor(x)
184
182
  feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -186,7 +184,7 @@ class FAST(_FAST, nn.Module):
186
184
  feat_concat = self.neck(feats)
187
185
  logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")
188
186
 
189
- out: Dict[str, Any] = {}
187
+ out: dict[str, Any] = {}
190
188
  if self.exportable:
191
189
  out["logits"] = logits
192
190
  return out
@@ -198,11 +196,16 @@ class FAST(_FAST, nn.Module):
198
196
  out["out_map"] = prob_map
199
197
 
200
198
  if target is None or return_preds:
199
+ # Disable for torch.compile compatibility
200
+ @torch.compiler.disable # type: ignore[attr-defined]
201
+ def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
202
+ return [
203
+ dict(zip(self.class_names, preds))
204
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
205
+ ]
206
+
201
207
  # Post-process boxes (keep only text predictions)
202
- out["preds"] = [
203
- dict(zip(self.class_names, preds))
204
- for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
205
- ]
208
+ out["preds"] = _postprocess(prob_map)
206
209
 
207
210
  if target is not None:
208
211
  loss = self.compute_loss(logits, target)
@@ -213,19 +216,17 @@ class FAST(_FAST, nn.Module):
213
216
  def compute_loss(
214
217
  self,
215
218
  out_map: torch.Tensor,
216
- target: List[np.ndarray],
219
+ target: list[np.ndarray],
217
220
  eps: float = 1e-6,
218
221
  ) -> torch.Tensor:
219
222
  """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
220
223
 
221
224
  Args:
222
- ----
223
225
  out_map: output feature map of the model of shape (N, num_classes, H, W)
224
226
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
225
227
  eps: epsilon factor in dice loss
226
228
 
227
229
  Returns:
228
- -------
229
230
  A loss tensor
230
231
  """
231
232
  targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
@@ -279,15 +280,13 @@ class FAST(_FAST, nn.Module):
279
280
  return text_loss + kernel_loss
280
281
 
281
282
 
282
- def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
283
+ def reparameterize(model: FAST | nn.Module) -> FAST:
283
284
  """Fuse batchnorm and conv layers and reparameterize the model
284
285
 
285
- args:
286
- ----
286
+ Args:
287
287
  model: the FAST model to reparameterize
288
288
 
289
289
  Returns:
290
- -------
291
290
  the reparameterized model
292
291
  """
293
292
  last_conv = None
@@ -324,9 +323,9 @@ def _fast(
324
323
  arch: str,
325
324
  pretrained: bool,
326
325
  backbone_fn: Callable[[bool], nn.Module],
327
- feat_layers: List[str],
326
+ feat_layers: list[str],
328
327
  pretrained_backbone: bool = True,
329
- ignore_keys: Optional[List[str]] = None,
328
+ ignore_keys: list[str] | None = None,
330
329
  **kwargs: Any,
331
330
  ) -> FAST:
332
331
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -366,12 +365,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
366
365
  >>> out = model(input_tensor)
367
366
 
368
367
  Args:
369
- ----
370
368
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
371
369
  **kwargs: keyword arguments of the FAST architecture
372
370
 
373
371
  Returns:
374
- -------
375
372
  text detection architecture
376
373
  """
377
374
  return _fast(
@@ -395,12 +392,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
395
392
  >>> out = model(input_tensor)
396
393
 
397
394
  Args:
398
- ----
399
395
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
400
396
  **kwargs: keyword arguments of the FAST architecture
401
397
 
402
398
  Returns:
403
- -------
404
399
  text detection architecture
405
400
  """
406
401
  return _fast(
@@ -424,12 +419,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
424
419
  >>> out = model(input_tensor)
425
420
 
426
421
  Args:
427
- ----
428
422
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
429
423
  **kwargs: keyword arguments of the FAST architecture
430
424
 
431
425
  Returns:
432
- -------
433
426
  text detection architecture
434
427
  """
435
428
  return _fast(
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,15 +6,14 @@
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
8
  from copy import deepcopy
9
- from typing import Any, Dict, List, Optional, Tuple, Union
9
+ from typing import Any
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
13
- from tensorflow import keras
14
- from tensorflow.keras import Sequential, layers
13
+ from tensorflow.keras import Model, Sequential, layers
15
14
 
16
15
  from doctr.file_utils import CLASS_NAME
17
- from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
16
+ from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
18
17
  from doctr.utils.repr import NestedObject
19
18
 
20
19
  from ...classification import textnet_base, textnet_small, textnet_tiny
@@ -24,24 +23,24 @@ from .base import _FAST, FASTPostProcessor
24
23
  __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
25
24
 
26
25
 
27
- default_cfgs: Dict[str, Dict[str, Any]] = {
26
+ default_cfgs: dict[str, dict[str, Any]] = {
28
27
  "fast_tiny": {
29
28
  "input_shape": (1024, 1024, 3),
30
29
  "mean": (0.798, 0.785, 0.772),
31
30
  "std": (0.264, 0.2749, 0.287),
32
- "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_tiny-959daecb.zip&src=0",
31
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
33
32
  },
34
33
  "fast_small": {
35
34
  "input_shape": (1024, 1024, 3),
36
35
  "mean": (0.798, 0.785, 0.772),
37
36
  "std": (0.264, 0.2749, 0.287),
38
- "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_small-f1617503.zip&src=0",
37
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
39
38
  },
40
39
  "fast_base": {
41
40
  "input_shape": (1024, 1024, 3),
42
41
  "mean": (0.798, 0.785, 0.772),
43
42
  "std": (0.264, 0.2749, 0.287),
44
- "url": "https://doctr-static.mindee.com/models?id=v0.8.1/fast_base-255e2ac3.zip&src=0",
43
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
45
44
  },
46
45
  }
47
46
 
@@ -50,7 +49,6 @@ class FastNeck(layers.Layer, NestedObject):
50
49
  """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
51
50
 
52
51
  Args:
53
- ----
54
52
  in_channels: number of input channels
55
53
  out_channels: number of output channels
56
54
  """
@@ -78,7 +76,6 @@ class FastHead(Sequential):
78
76
  """Head of the FAST architecture
79
77
 
80
78
  Args:
81
- ----
82
79
  in_channels: number of input channels
83
80
  num_classes: number of output classes
84
81
  out_channels: number of output channels
@@ -100,12 +97,11 @@ class FastHead(Sequential):
100
97
  super().__init__(_layers)
101
98
 
102
99
 
103
- class FAST(_FAST, keras.Model, NestedObject):
100
+ class FAST(_FAST, Model, NestedObject):
104
101
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
105
102
  <https://arxiv.org/pdf/2111.02394.pdf>`_.
106
103
 
107
104
  Args:
108
- ----
109
105
  feature extractor: the backbone serving as feature extractor
110
106
  bin_thresh: threshold for binarization
111
107
  box_thresh: minimal objectness score to consider a box
@@ -117,7 +113,7 @@ class FAST(_FAST, keras.Model, NestedObject):
117
113
  class_names: list of class names
118
114
  """
119
115
 
120
- _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"]
116
+ _children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]
121
117
 
122
118
  def __init__(
123
119
  self,
@@ -128,8 +124,8 @@ class FAST(_FAST, keras.Model, NestedObject):
128
124
  pooling_size: int = 4, # different from paper performs better on close text-rich images
129
125
  assume_straight_pages: bool = True,
130
126
  exportable: bool = False,
131
- cfg: Optional[Dict[str, Any]] = {},
132
- class_names: List[str] = [CLASS_NAME],
127
+ cfg: dict[str, Any] = {},
128
+ class_names: list[str] = [CLASS_NAME],
133
129
  ) -> None:
134
130
  super().__init__()
135
131
  self.class_names = class_names
@@ -160,19 +156,17 @@ class FAST(_FAST, keras.Model, NestedObject):
160
156
  def compute_loss(
161
157
  self,
162
158
  out_map: tf.Tensor,
163
- target: List[Dict[str, np.ndarray]],
159
+ target: list[dict[str, np.ndarray]],
164
160
  eps: float = 1e-6,
165
161
  ) -> tf.Tensor:
166
162
  """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
167
163
 
168
164
  Args:
169
- ----
170
165
  out_map: output feature map of the model of shape (N, num_classes, H, W)
171
166
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
172
167
  eps: epsilon factor in dice loss
173
168
 
174
169
  Returns:
175
- -------
176
170
  A loss tensor
177
171
  """
178
172
  targets = self.build_target(target, out_map.shape[1:], True)
@@ -223,18 +217,18 @@ class FAST(_FAST, keras.Model, NestedObject):
223
217
  def call(
224
218
  self,
225
219
  x: tf.Tensor,
226
- target: Optional[List[Dict[str, np.ndarray]]] = None,
220
+ target: list[dict[str, np.ndarray]] | None = None,
227
221
  return_model_output: bool = False,
228
222
  return_preds: bool = False,
229
223
  **kwargs: Any,
230
- ) -> Dict[str, Any]:
224
+ ) -> dict[str, Any]:
231
225
  feat_maps = self.feat_extractor(x, **kwargs)
232
226
  # Pass through the Neck & Head & Upsample
233
227
  feat_concat = self.neck(feat_maps, **kwargs)
234
228
  logits: tf.Tensor = self.head(feat_concat, **kwargs)
235
229
  logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
236
230
 
237
- out: Dict[str, tf.Tensor] = {}
231
+ out: dict[str, tf.Tensor] = {}
238
232
  if self.exportable:
239
233
  out["logits"] = logits
240
234
  return out
@@ -256,15 +250,14 @@ class FAST(_FAST, keras.Model, NestedObject):
256
250
  return out
257
251
 
258
252
 
259
- def reparameterize(model: Union[FAST, layers.Layer]) -> FAST:
253
+ def reparameterize(model: FAST | layers.Layer) -> FAST:
260
254
  """Fuse batchnorm and conv layers and reparameterize the model
261
255
 
262
256
  args:
263
- ----
257
+
264
258
  model: the FAST model to reparameterize
265
259
 
266
260
  Returns:
267
- -------
268
261
  the reparameterized model
269
262
  """
270
263
  last_conv = None
@@ -307,9 +300,9 @@ def _fast(
307
300
  arch: str,
308
301
  pretrained: bool,
309
302
  backbone_fn,
310
- feat_layers: List[str],
303
+ feat_layers: list[str],
311
304
  pretrained_backbone: bool = True,
312
- input_shape: Optional[Tuple[int, int, int]] = None,
305
+ input_shape: tuple[int, int, int] | None = None,
313
306
  **kwargs: Any,
314
307
  ) -> FAST:
315
308
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -334,12 +327,16 @@ def _fast(
334
327
 
335
328
  # Build the model
336
329
  model = FAST(feat_extractor, cfg=_cfg, **kwargs)
330
+ _build_model(model)
331
+
337
332
  # Load pretrained parameters
338
333
  if pretrained:
339
- load_pretrained_params(model, _cfg["url"])
340
-
341
- # Build the model for reparameterization to access the layers
342
- _ = model(tf.random.uniform(shape=[1, *_cfg["input_shape"]], maxval=1, dtype=tf.float32), training=False)
334
+ # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
335
+ load_pretrained_params(
336
+ model,
337
+ _cfg["url"],
338
+ skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
339
+ )
343
340
 
344
341
  return model
345
342
 
@@ -355,12 +352,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
355
352
  >>> out = model(input_tensor)
356
353
 
357
354
  Args:
358
- ----
359
355
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
360
356
  **kwargs: keyword arguments of the FAST architecture
361
357
 
362
358
  Returns:
363
- -------
364
359
  text detection architecture
365
360
  """
366
361
  return _fast(
@@ -383,12 +378,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
383
378
  >>> out = model(input_tensor)
384
379
 
385
380
  Args:
386
- ----
387
381
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
388
382
  **kwargs: keyword arguments of the FAST architecture
389
383
 
390
384
  Returns:
391
- -------
392
385
  text detection architecture
393
386
  """
394
387
  return _fast(
@@ -411,12 +404,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
411
404
  >>> out = model(input_tensor)
412
405
 
413
406
  Args:
414
- ----
415
407
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
416
408
  **kwargs: keyword arguments of the FAST architecture
417
409
 
418
410
  Returns:
419
- -------
420
411
  text detection architecture
421
412
  """
422
413
  return _fast(
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,11 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
- from typing import Dict, List, Tuple, Union
9
8
 
10
9
  import cv2
11
10
  import numpy as np
@@ -23,7 +22,6 @@ class LinkNetPostProcessor(DetectionPostProcessor):
23
22
  """Implements a post processor for LinkNet model.
24
23
 
25
24
  Args:
26
- ----
27
25
  bin_thresh: threshold used to binarize p_map at inference time
28
26
  box_thresh: minimal objectness score to consider a box
29
27
  assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class LinkNetPostProcessor(DetectionPostProcessor):
45
43
  """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
46
44
 
47
45
  Args:
48
- ----
49
46
  points: The first parameter.
50
47
 
51
48
  Returns:
52
- -------
53
49
  a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
54
50
  """
55
51
  if not self.assume_straight_pages:
@@ -94,24 +90,22 @@ class LinkNetPostProcessor(DetectionPostProcessor):
94
90
  """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
95
91
 
96
92
  Args:
97
- ----
98
93
  pred: Pred map from differentiable linknet output
99
94
  bitmap: Bitmap map computed from pred (binarized)
100
95
  angle_tol: Comparison tolerance of the angle with the median angle across the page
101
96
  ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
102
97
 
103
98
  Returns:
104
- -------
105
99
  np tensor boxes for the bitmap, each box is a 6-element list
106
100
  containing x, y, w, h, alpha, score for the box
107
101
  """
108
102
  height, width = bitmap.shape[:2]
109
- boxes: List[Union[np.ndarray, List[float]]] = []
103
+ boxes: list[np.ndarray | list[float]] = []
110
104
  # get contours from connected components on the bitmap
111
105
  contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
112
106
  for contour in contours:
113
107
  # Check whether smallest enclosing bounding box is not too small
114
- if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
108
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
115
109
  continue
116
110
  # Compute objectness
117
111
  if self.assume_straight_pages:
@@ -152,7 +146,6 @@ class _LinkNet(BaseModel):
152
146
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
153
147
 
154
148
  Args:
155
- ----
156
149
  out_chan: number of channels for the output
157
150
  """
158
151
 
@@ -162,20 +155,18 @@ class _LinkNet(BaseModel):
162
155
 
163
156
  def build_target(
164
157
  self,
165
- target: List[Dict[str, np.ndarray]],
166
- output_shape: Tuple[int, int, int],
158
+ target: list[dict[str, np.ndarray]],
159
+ output_shape: tuple[int, int, int],
167
160
  channels_last: bool = True,
168
- ) -> Tuple[np.ndarray, np.ndarray]:
161
+ ) -> tuple[np.ndarray, np.ndarray]:
169
162
  """Build the target and its mask to be used for loss computation.
170
163
 
171
164
  Args:
172
- ----
173
165
  target: target coming from dataset
174
166
  output_shape: shape of the output of the model without batch_size
175
167
  channels_last: whether channels are last or not
176
168
 
177
169
  Returns:
178
- -------
179
170
  the new formatted target and the mask
180
171
  """
181
172
  if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):