python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
8
  from copy import deepcopy
9
- from typing import Any, Dict, List, Optional, Tuple
9
+ from typing import Any
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
@@ -29,7 +29,7 @@ from .base import DBPostProcessor, _DBNet
29
29
  __all__ = ["DBNet", "db_resnet50", "db_mobilenet_v3_large"]
30
30
 
31
31
 
32
- default_cfgs: Dict[str, Dict[str, Any]] = {
32
+ default_cfgs: dict[str, dict[str, Any]] = {
33
33
  "db_resnet50": {
34
34
  "mean": (0.798, 0.785, 0.772),
35
35
  "std": (0.264, 0.2749, 0.287),
@@ -50,7 +50,6 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
50
50
  <https://arxiv.org/pdf/1612.03144.pdf>`_.
51
51
 
52
52
  Args:
53
- ----
54
53
  channels: number of channel to output
55
54
  """
56
55
 
@@ -72,12 +71,10 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
72
71
  """Module which performs a 3x3 convolution followed by up-sampling
73
72
 
74
73
  Args:
75
- ----
76
74
  channels: number of output channels
77
75
  dilation_factor (int): dilation factor to scale the convolution output before concatenation
78
76
 
79
77
  Returns:
80
- -------
81
78
  a keras.layers.Layer object, wrapping these operations in a sequential module
82
79
 
83
80
  """
@@ -95,7 +92,7 @@ class FeaturePyramidNetwork(layers.Layer, NestedObject):
95
92
 
96
93
  def call(
97
94
  self,
98
- x: List[tf.Tensor],
95
+ x: list[tf.Tensor],
99
96
  **kwargs: Any,
100
97
  ) -> tf.Tensor:
101
98
  # Channel mapping
@@ -114,7 +111,6 @@ class DBNet(_DBNet, Model, NestedObject):
114
111
  <https://arxiv.org/pdf/1911.08947.pdf>`_.
115
112
 
116
113
  Args:
117
- ----
118
114
  feature extractor: the backbone serving as feature extractor
119
115
  fpn_channels: number of channels each extracted feature maps is mapped to
120
116
  bin_thresh: threshold for binarization
@@ -125,7 +121,7 @@ class DBNet(_DBNet, Model, NestedObject):
125
121
  class_names: list of class names
126
122
  """
127
123
 
128
- _children_names: List[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
124
+ _children_names: list[str] = ["feat_extractor", "fpn", "probability_head", "threshold_head", "postprocessor"]
129
125
 
130
126
  def __init__(
131
127
  self,
@@ -135,8 +131,8 @@ class DBNet(_DBNet, Model, NestedObject):
135
131
  box_thresh: float = 0.1,
136
132
  assume_straight_pages: bool = True,
137
133
  exportable: bool = False,
138
- cfg: Optional[Dict[str, Any]] = None,
139
- class_names: List[str] = [CLASS_NAME],
134
+ cfg: dict[str, Any] | None = None,
135
+ class_names: list[str] = [CLASS_NAME],
140
136
  ) -> None:
141
137
  super().__init__()
142
138
  self.class_names = class_names
@@ -175,7 +171,7 @@ class DBNet(_DBNet, Model, NestedObject):
175
171
  self,
176
172
  out_map: tf.Tensor,
177
173
  thresh_map: tf.Tensor,
178
- target: List[Dict[str, np.ndarray]],
174
+ target: list[dict[str, np.ndarray]],
179
175
  gamma: float = 2.0,
180
176
  alpha: float = 0.5,
181
177
  eps: float = 1e-8,
@@ -184,7 +180,6 @@ class DBNet(_DBNet, Model, NestedObject):
184
180
  and a list of masks for each image. From there it computes the loss with the model output
185
181
 
186
182
  Args:
187
- ----
188
183
  out_map: output feature map of the model of shape (N, H, W, C)
189
184
  thresh_map: threshold map of shape (N, H, W, C)
190
185
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
@@ -193,7 +188,6 @@ class DBNet(_DBNet, Model, NestedObject):
193
188
  eps: epsilon factor in dice loss
194
189
 
195
190
  Returns:
196
- -------
197
191
  A loss tensor
198
192
  """
199
193
  if gamma < 0:
@@ -246,16 +240,16 @@ class DBNet(_DBNet, Model, NestedObject):
246
240
  def call(
247
241
  self,
248
242
  x: tf.Tensor,
249
- target: Optional[List[Dict[str, np.ndarray]]] = None,
243
+ target: list[dict[str, np.ndarray]] | None = None,
250
244
  return_model_output: bool = False,
251
245
  return_preds: bool = False,
252
246
  **kwargs: Any,
253
- ) -> Dict[str, Any]:
247
+ ) -> dict[str, Any]:
254
248
  feat_maps = self.feat_extractor(x, **kwargs)
255
249
  feat_concat = self.fpn(feat_maps, **kwargs)
256
250
  logits = self.probability_head(feat_concat, **kwargs)
257
251
 
258
- out: Dict[str, tf.Tensor] = {}
252
+ out: dict[str, tf.Tensor] = {}
259
253
  if self.exportable:
260
254
  out["logits"] = logits
261
255
  return out
@@ -282,9 +276,9 @@ def _db_resnet(
282
276
  arch: str,
283
277
  pretrained: bool,
284
278
  backbone_fn,
285
- fpn_layers: List[str],
279
+ fpn_layers: list[str],
286
280
  pretrained_backbone: bool = True,
287
- input_shape: Optional[Tuple[int, int, int]] = None,
281
+ input_shape: tuple[int, int, int] | None = None,
288
282
  **kwargs: Any,
289
283
  ) -> DBNet:
290
284
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -328,9 +322,9 @@ def _db_mobilenet(
328
322
  arch: str,
329
323
  pretrained: bool,
330
324
  backbone_fn,
331
- fpn_layers: List[str],
325
+ fpn_layers: list[str],
332
326
  pretrained_backbone: bool = True,
333
- input_shape: Optional[Tuple[int, int, int]] = None,
327
+ input_shape: tuple[int, int, int] | None = None,
334
328
  **kwargs: Any,
335
329
  ) -> DBNet:
336
330
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -379,12 +373,10 @@ def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet:
379
373
  >>> out = model(input_tensor)
380
374
 
381
375
  Args:
382
- ----
383
376
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
384
377
  **kwargs: keyword arguments of the DBNet architecture
385
378
 
386
379
  Returns:
387
- -------
388
380
  text detection architecture
389
381
  """
390
382
  return _db_resnet(
@@ -407,12 +399,10 @@ def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet:
407
399
  >>> out = model(input_tensor)
408
400
 
409
401
  Args:
410
- ----
411
402
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
412
403
  **kwargs: keyword arguments of the DBNet architecture
413
404
 
414
405
  Returns:
415
- -------
416
406
  text detection architecture
417
407
  """
418
408
  return _db_mobilenet(
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,11 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
- from typing import Dict, List, Tuple, Union
9
8
 
10
9
  import cv2
11
10
  import numpy as np
@@ -23,7 +22,6 @@ class FASTPostProcessor(DetectionPostProcessor):
23
22
  """Implements a post processor for FAST model.
24
23
 
25
24
  Args:
26
- ----
27
25
  bin_thresh: threshold used to binzarized p_map at inference time
28
26
  box_thresh: minimal objectness score to consider a box
29
27
  assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -45,11 +43,9 @@ class FASTPostProcessor(DetectionPostProcessor):
45
43
  """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
46
44
 
47
45
  Args:
48
- ----
49
46
  points: The first parameter.
50
47
 
51
48
  Returns:
52
- -------
53
49
  a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
54
50
  """
55
51
  if not self.assume_straight_pages:
@@ -94,24 +90,22 @@ class FASTPostProcessor(DetectionPostProcessor):
94
90
  """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
95
91
 
96
92
  Args:
97
- ----
98
93
  pred: Pred map from differentiable linknet output
99
94
  bitmap: Bitmap map computed from pred (binarized)
100
95
  angle_tol: Comparison tolerance of the angle with the median angle across the page
101
96
  ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
102
97
 
103
98
  Returns:
104
- -------
105
99
  np tensor boxes for the bitmap, each box is a 6-element list
106
100
  containing x, y, w, h, alpha, score for the box
107
101
  """
108
102
  height, width = bitmap.shape[:2]
109
- boxes: List[Union[np.ndarray, List[float]]] = []
103
+ boxes: list[np.ndarray | list[float]] = []
110
104
  # get contours from connected components on the bitmap
111
105
  contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
112
106
  for contour in contours:
113
107
  # Check whether smallest enclosing bounding box is not too small
114
- if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index]
108
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
115
109
  continue
116
110
  # Compute objectness
117
111
  if self.assume_straight_pages:
@@ -158,20 +152,18 @@ class _FAST(BaseModel):
158
152
 
159
153
  def build_target(
160
154
  self,
161
- target: List[Dict[str, np.ndarray]],
162
- output_shape: Tuple[int, int, int],
155
+ target: list[dict[str, np.ndarray]],
156
+ output_shape: tuple[int, int, int],
163
157
  channels_last: bool = True,
164
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
158
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
165
159
  """Build the target, and it's mask to be used from loss computation.
166
160
 
167
161
  Args:
168
- ----
169
162
  target: target coming from dataset
170
163
  output_shape: shape of the output of the model without batch_size
171
164
  channels_last: whether channels are last or not
172
165
 
173
166
  Returns:
174
- -------
175
167
  the new formatted target, mask and shrunken text kernel
176
168
  """
177
169
  if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
@@ -1,9 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Callable, Dict, List, Optional, Union
6
+ from collections.abc import Callable
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
  import torch
@@ -21,7 +22,7 @@ from .base import _FAST, FASTPostProcessor
21
22
  __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
22
23
 
23
24
 
24
- default_cfgs: Dict[str, Dict[str, Any]] = {
25
+ default_cfgs: dict[str, dict[str, Any]] = {
25
26
  "fast_tiny": {
26
27
  "input_shape": (3, 1024, 1024),
27
28
  "mean": (0.798, 0.785, 0.772),
@@ -47,7 +48,6 @@ class FastNeck(nn.Module):
47
48
  """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layers.
48
49
 
49
50
  Args:
50
- ----
51
51
  in_channels: number of input channels
52
52
  out_channels: number of output channels
53
53
  """
@@ -77,7 +77,6 @@ class FastHead(nn.Sequential):
77
77
  """Head of the FAST architecture
78
78
 
79
79
  Args:
80
- ----
81
80
  in_channels: number of input channels
82
81
  num_classes: number of output classes
83
82
  out_channels: number of output channels
@@ -91,7 +90,7 @@ class FastHead(nn.Sequential):
91
90
  out_channels: int = 128,
92
91
  dropout: float = 0.1,
93
92
  ) -> None:
94
- _layers: List[nn.Module] = [
93
+ _layers: list[nn.Module] = [
95
94
  FASTConvLayer(in_channels, out_channels, kernel_size=3),
96
95
  nn.Dropout(dropout),
97
96
  nn.Conv2d(out_channels, num_classes, kernel_size=1, bias=False),
@@ -104,7 +103,6 @@ class FAST(_FAST, nn.Module):
104
103
  <https://arxiv.org/pdf/2111.02394.pdf>`_.
105
104
 
106
105
  Args:
107
- ----
108
106
  feat extractor: the backbone serving as feature extractor
109
107
  bin_thresh: threshold for binarization
110
108
  box_thresh: minimal objectness score to consider a box
@@ -125,8 +123,8 @@ class FAST(_FAST, nn.Module):
125
123
  pooling_size: int = 4, # different from paper performs better on close text-rich images
126
124
  assume_straight_pages: bool = True,
127
125
  exportable: bool = False,
128
- cfg: Optional[Dict[str, Any]] = {},
129
- class_names: List[str] = [CLASS_NAME],
126
+ cfg: dict[str, Any] = {},
127
+ class_names: list[str] = [CLASS_NAME],
130
128
  ) -> None:
131
129
  super().__init__()
132
130
  self.class_names = class_names
@@ -175,10 +173,10 @@ class FAST(_FAST, nn.Module):
175
173
  def forward(
176
174
  self,
177
175
  x: torch.Tensor,
178
- target: Optional[List[np.ndarray]] = None,
176
+ target: list[np.ndarray] | None = None,
179
177
  return_model_output: bool = False,
180
178
  return_preds: bool = False,
181
- ) -> Dict[str, torch.Tensor]:
179
+ ) -> dict[str, torch.Tensor]:
182
180
  # Extract feature maps at different stages
183
181
  feats = self.feat_extractor(x)
184
182
  feats = [feats[str(idx)] for idx in range(len(feats))]
@@ -186,7 +184,7 @@ class FAST(_FAST, nn.Module):
186
184
  feat_concat = self.neck(feats)
187
185
  logits = F.interpolate(self.prob_head(feat_concat), size=x.shape[-2:], mode="bilinear")
188
186
 
189
- out: Dict[str, Any] = {}
187
+ out: dict[str, Any] = {}
190
188
  if self.exportable:
191
189
  out["logits"] = logits
192
190
  return out
@@ -198,11 +196,16 @@ class FAST(_FAST, nn.Module):
198
196
  out["out_map"] = prob_map
199
197
 
200
198
  if target is None or return_preds:
199
+ # Disable for torch.compile compatibility
200
+ @torch.compiler.disable # type: ignore[attr-defined]
201
+ def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
202
+ return [
203
+ dict(zip(self.class_names, preds))
204
+ for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
205
+ ]
206
+
201
207
  # Post-process boxes (keep only text predictions)
202
- out["preds"] = [
203
- dict(zip(self.class_names, preds))
204
- for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
205
- ]
208
+ out["preds"] = _postprocess(prob_map)
206
209
 
207
210
  if target is not None:
208
211
  loss = self.compute_loss(logits, target)
@@ -213,19 +216,17 @@ class FAST(_FAST, nn.Module):
213
216
  def compute_loss(
214
217
  self,
215
218
  out_map: torch.Tensor,
216
- target: List[np.ndarray],
219
+ target: list[np.ndarray],
217
220
  eps: float = 1e-6,
218
221
  ) -> torch.Tensor:
219
222
  """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
220
223
 
221
224
  Args:
222
- ----
223
225
  out_map: output feature map of the model of shape (N, num_classes, H, W)
224
226
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
225
227
  eps: epsilon factor in dice loss
226
228
 
227
229
  Returns:
228
- -------
229
230
  A loss tensor
230
231
  """
231
232
  targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
@@ -279,15 +280,13 @@ class FAST(_FAST, nn.Module):
279
280
  return text_loss + kernel_loss
280
281
 
281
282
 
282
- def reparameterize(model: Union[FAST, nn.Module]) -> FAST:
283
+ def reparameterize(model: FAST | nn.Module) -> FAST:
283
284
  """Fuse batchnorm and conv layers and reparameterize the model
284
285
 
285
- args:
286
- ----
286
+ Args:
287
287
  model: the FAST model to reparameterize
288
288
 
289
289
  Returns:
290
- -------
291
290
  the reparameterized model
292
291
  """
293
292
  last_conv = None
@@ -324,9 +323,9 @@ def _fast(
324
323
  arch: str,
325
324
  pretrained: bool,
326
325
  backbone_fn: Callable[[bool], nn.Module],
327
- feat_layers: List[str],
326
+ feat_layers: list[str],
328
327
  pretrained_backbone: bool = True,
329
- ignore_keys: Optional[List[str]] = None,
328
+ ignore_keys: list[str] | None = None,
330
329
  **kwargs: Any,
331
330
  ) -> FAST:
332
331
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -366,12 +365,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
366
365
  >>> out = model(input_tensor)
367
366
 
368
367
  Args:
369
- ----
370
368
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
371
369
  **kwargs: keyword arguments of the DBNet architecture
372
370
 
373
371
  Returns:
374
- -------
375
372
  text detection architecture
376
373
  """
377
374
  return _fast(
@@ -395,12 +392,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
395
392
  >>> out = model(input_tensor)
396
393
 
397
394
  Args:
398
- ----
399
395
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
400
396
  **kwargs: keyword arguments of the DBNet architecture
401
397
 
402
398
  Returns:
403
- -------
404
399
  text detection architecture
405
400
  """
406
401
  return _fast(
@@ -424,12 +419,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
424
419
  >>> out = model(input_tensor)
425
420
 
426
421
  Args:
427
- ----
428
422
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
429
423
  **kwargs: keyword arguments of the DBNet architecture
430
424
 
431
425
  Returns:
432
- -------
433
426
  text detection architecture
434
427
  """
435
428
  return _fast(
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
6
6
  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
7
7
 
8
8
  from copy import deepcopy
9
- from typing import Any, Dict, List, Optional, Tuple, Union
9
+ from typing import Any
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
@@ -23,7 +23,7 @@ from .base import _FAST, FASTPostProcessor
23
23
  __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]
24
24
 
25
25
 
26
- default_cfgs: Dict[str, Dict[str, Any]] = {
26
+ default_cfgs: dict[str, dict[str, Any]] = {
27
27
  "fast_tiny": {
28
28
  "input_shape": (1024, 1024, 3),
29
29
  "mean": (0.798, 0.785, 0.772),
@@ -49,7 +49,6 @@ class FastNeck(layers.Layer, NestedObject):
49
49
  """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.
50
50
 
51
51
  Args:
52
- ----
53
52
  in_channels: number of input channels
54
53
  out_channels: number of output channels
55
54
  """
@@ -77,7 +76,6 @@ class FastHead(Sequential):
77
76
  """Head of the FAST architecture
78
77
 
79
78
  Args:
80
- ----
81
79
  in_channels: number of input channels
82
80
  num_classes: number of output classes
83
81
  out_channels: number of output channels
@@ -104,7 +102,6 @@ class FAST(_FAST, Model, NestedObject):
104
102
  <https://arxiv.org/pdf/2111.02394.pdf>`_.
105
103
 
106
104
  Args:
107
- ----
108
105
  feature extractor: the backbone serving as feature extractor
109
106
  bin_thresh: threshold for binarization
110
107
  box_thresh: minimal objectness score to consider a box
@@ -116,7 +113,7 @@ class FAST(_FAST, Model, NestedObject):
116
113
  class_names: list of class names
117
114
  """
118
115
 
119
- _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"]
116
+ _children_names: list[str] = ["feat_extractor", "neck", "head", "postprocessor"]
120
117
 
121
118
  def __init__(
122
119
  self,
@@ -127,8 +124,8 @@ class FAST(_FAST, Model, NestedObject):
127
124
  pooling_size: int = 4, # different from paper performs better on close text-rich images
128
125
  assume_straight_pages: bool = True,
129
126
  exportable: bool = False,
130
- cfg: Optional[Dict[str, Any]] = {},
131
- class_names: List[str] = [CLASS_NAME],
127
+ cfg: dict[str, Any] = {},
128
+ class_names: list[str] = [CLASS_NAME],
132
129
  ) -> None:
133
130
  super().__init__()
134
131
  self.class_names = class_names
@@ -159,19 +156,17 @@ class FAST(_FAST, Model, NestedObject):
159
156
  def compute_loss(
160
157
  self,
161
158
  out_map: tf.Tensor,
162
- target: List[Dict[str, np.ndarray]],
159
+ target: list[dict[str, np.ndarray]],
163
160
  eps: float = 1e-6,
164
161
  ) -> tf.Tensor:
165
162
  """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.
166
163
 
167
164
  Args:
168
- ----
169
165
  out_map: output feature map of the model of shape (N, num_classes, H, W)
170
166
  target: list of dictionary where each dict has a `boxes` and a `flags` entry
171
167
  eps: epsilon factor in dice loss
172
168
 
173
169
  Returns:
174
- -------
175
170
  A loss tensor
176
171
  """
177
172
  targets = self.build_target(target, out_map.shape[1:], True)
@@ -222,18 +217,18 @@ class FAST(_FAST, Model, NestedObject):
222
217
  def call(
223
218
  self,
224
219
  x: tf.Tensor,
225
- target: Optional[List[Dict[str, np.ndarray]]] = None,
220
+ target: list[dict[str, np.ndarray]] | None = None,
226
221
  return_model_output: bool = False,
227
222
  return_preds: bool = False,
228
223
  **kwargs: Any,
229
- ) -> Dict[str, Any]:
224
+ ) -> dict[str, Any]:
230
225
  feat_maps = self.feat_extractor(x, **kwargs)
231
226
  # Pass through the Neck & Head & Upsample
232
227
  feat_concat = self.neck(feat_maps, **kwargs)
233
228
  logits: tf.Tensor = self.head(feat_concat, **kwargs)
234
229
  logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)
235
230
 
236
- out: Dict[str, tf.Tensor] = {}
231
+ out: dict[str, tf.Tensor] = {}
237
232
  if self.exportable:
238
233
  out["logits"] = logits
239
234
  return out
@@ -255,15 +250,14 @@ class FAST(_FAST, Model, NestedObject):
255
250
  return out
256
251
 
257
252
 
258
- def reparameterize(model: Union[FAST, layers.Layer]) -> FAST:
253
+ def reparameterize(model: FAST | layers.Layer) -> FAST:
259
254
  """Fuse batchnorm and conv layers and reparameterize the model
260
255
 
261
256
  args:
262
- ----
257
+
263
258
  model: the FAST model to reparameterize
264
259
 
265
260
  Returns:
266
- -------
267
261
  the reparameterized model
268
262
  """
269
263
  last_conv = None
@@ -306,9 +300,9 @@ def _fast(
306
300
  arch: str,
307
301
  pretrained: bool,
308
302
  backbone_fn,
309
- feat_layers: List[str],
303
+ feat_layers: list[str],
310
304
  pretrained_backbone: bool = True,
311
- input_shape: Optional[Tuple[int, int, int]] = None,
305
+ input_shape: tuple[int, int, int] | None = None,
312
306
  **kwargs: Any,
313
307
  ) -> FAST:
314
308
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -358,12 +352,10 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
358
352
  >>> out = model(input_tensor)
359
353
 
360
354
  Args:
361
- ----
362
355
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
363
356
  **kwargs: keyword arguments of the DBNet architecture
364
357
 
365
358
  Returns:
366
- -------
367
359
  text detection architecture
368
360
  """
369
361
  return _fast(
@@ -386,12 +378,10 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
386
378
  >>> out = model(input_tensor)
387
379
 
388
380
  Args:
389
- ----
390
381
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
391
382
  **kwargs: keyword arguments of the DBNet architecture
392
383
 
393
384
  Returns:
394
- -------
395
385
  text detection architecture
396
386
  """
397
387
  return _fast(
@@ -414,12 +404,10 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
414
404
  >>> out = model(input_tensor)
415
405
 
416
406
  Args:
417
- ----
418
407
  pretrained (bool): If True, returns a model pre-trained on our text detection dataset
419
408
  **kwargs: keyword arguments of the DBNet architecture
420
409
 
421
410
  Returns:
422
- -------
423
411
  text detection architecture
424
412
  """
425
413
  return _fast(
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]