python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/tensorflow.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
@@ -17,7 +17,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -40,7 +40,6 @@ class ViTSTR(_ViTSTR, Model):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -51,7 +50,7 @@ class ViTSTR(_ViTSTR, Model):
         cfg: dictionary containing information about the model
     """
 
-    _children_names: List[str] = ["feat_extractor", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
     def __init__(
         self,
@@ -60,9 +59,9 @@ class ViTSTR(_ViTSTR, Model):
         embedding_units: int,
         max_length: int = 32,
         dropout_prob: float = 0.0,
-        input_shape: Tuple[int, int, int] = (32, 128, 3),  # different from paper
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -79,19 +78,17 @@ class ViTSTR(_ViTSTR, Model):
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len: List[int],
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
@@ -114,11 +111,11 @@ class ViTSTR(_ViTSTR, Model):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -136,7 +133,7 @@ class ViTSTR(_ViTSTR, Model):
         )  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -158,14 +155,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -191,7 +187,7 @@ def _vitstr(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -239,12 +235,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -268,12 +262,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
    Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
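
Aside from the copyright bump, the churn in this file (and across most of the release) is mechanical: the `----`/`-------` separator lines are dropped from docstrings, and type hints migrate from `typing.Dict/List/Optional/Tuple/Union` to builtin generics (PEP 585) and `|` unions (PEP 604). A minimal before/after sketch with a hypothetical function, not taken from the package:

    from typing import Any

    # 0.10.0 style, which needed: from typing import Dict, List, Optional, Tuple
    # def decode(logits, cfg: Optional[Dict[str, Any]] = None) -> List[Tuple[str, float]]: ...

    # 0.11.0 style: builtin generics and union syntax; only Any still comes from typing
    def decode(logits, cfg: dict[str, Any] | None = None) -> list[tuple[str, float]]:
        # hypothetical stub, purely to illustrate the annotation change
        return [("word", 0.99)]

Since a bare `X | None` annotation is evaluated when the function is defined, this syntax implies a Python 3.10+ floor (or a `from __future__ import annotations` escape hatch).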
doctr/models/recognition/zoo.py CHANGED
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
 
 from .. import recognition
@@ -14,7 +14,7 @@ from .predictor import RecognitionPredictor
 __all__ = ["recognition_predictor"]
 
 
-ARCHS: List[str] = [
+ARCHS: list[str] = [
     "crnn_vgg16_bn",
     "crnn_mobilenet_v3_small",
     "crnn_mobilenet_v3_large",
@@ -35,9 +35,14 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-        if not isinstance(
-            arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
-        ):
+        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
+        if is_torch_available():
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
@@ -52,7 +57,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
     return predictor
 
 
-def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor:
+def recognition_predictor(
+    arch: Any = "crnn_vgg16_bn",
+    pretrained: bool = False,
+    symmetric_pad: bool = False,
+    batch_size: int = 128,
+    **kwargs: Any,
+) -> RecognitionPredictor:
     """Text recognition architecture.
 
     Example::
@@ -63,13 +74,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
         pretrained: If True, returns a model pre-trained on our text recognition dataset
+        symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+        batch_size: number of samples the model processes in parallel
         **kwargs: optional parameters to be passed to the architecture
 
     Returns:
-    -------
         Recognition predictor
     """
-    return _predictor(arch, pretrained, **kwargs)
+    return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)
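
For reference, a usage sketch of the widened `recognition_predictor` signature; the placeholder crop and the parameter values are illustrative, not from the package docs:

    import numpy as np
    from doctr.models import recognition_predictor

    predictor = recognition_predictor(
        arch="crnn_vgg16_bn",
        pretrained=True,
        symmetric_pad=True,  # pad word crops on both sides instead of bottom-right
        batch_size=64,       # crops per forward pass (signature default: 128)
    )
    word_crop = np.zeros((32, 128, 3), dtype=np.uint8)  # placeholder crop image
    out = predictor([word_crop])

Promoting `symmetric_pad` and `batch_size` from `**kwargs` to explicit keyword arguments also makes them visible in the signature and docstring.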
doctr/models/utils/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
doctr/models/utils/pytorch.py CHANGED
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any
 
 import torch
 from torch import nn
@@ -18,8 +18,12 @@ __all__ = [
     "export_model_to_onnx",
     "_copy_tensor",
     "_bf16_to_float32",
+    "_CompiledModule",
 ]
 
+# torch compiled model type
+_CompiledModule = torch._dynamo.eval_frame.OptimizedModule
+
 
 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
     return x.clone().detach()
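
`_CompiledModule` is an alias for the wrapper type returned by `torch.compile`, and exporting it is what lets the predictor factories (see the `recognition/zoo.py` hunk above) accept compiled models. A sketch, assuming a PyTorch build with `torch.compile` support:

    import torch
    from doctr.models import recognition, recognition_predictor

    model = recognition.crnn_vgg16_bn(pretrained=True)
    compiled = torch.compile(model)  # instance of torch._dynamo.eval_frame.OptimizedModule

    # 0.10.0 rejected this with ValueError("unknown architecture: ...");
    # 0.11.0 adds _CompiledModule to allowed_archs, so the isinstance check passes
    predictor = recognition_predictor(arch=compiled, pretrained=True)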
@@ -32,9 +36,9 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
 
 def load_pretrained_params(
     model: nn.Module,
-    url: Optional[str] = None,
-    hash_prefix: Optional[str] = None,
-    ignore_keys: Optional[List[str]] = None,
+    url: str | None = None,
+    hash_prefix: str | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
@@ -43,7 +47,6 @@ def load_pretrained_params(
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
 
     Args:
-    ----
         model: the PyTorch model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
@@ -76,7 +79,7 @@ def conv_sequence_pt(
     relu: bool = False,
     bn: bool = False,
     **kwargs: Any,
-) -> List[nn.Module]:
+) -> list[nn.Module]:
     """Builds a convolutional-based layer sequence
 
     >>> from torch.nn import Sequential
@@ -84,7 +87,6 @@ def conv_sequence_pt(
     >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))
 
     Args:
-    ----
         in_channels: number of input channels
         out_channels: number of output channels
         relu: whether ReLU should be used
@@ -92,13 +94,12 @@ def conv_sequence_pt(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
     kwargs["bias"] = kwargs.get("bias", not bn)
     # Add activation directly to the conv if there is no BN
-    conv_seq: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
+    conv_seq: list[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
 
     if bn:
         conv_seq.append(nn.BatchNorm2d(out_channels))
@@ -110,8 +111,8 @@ def conv_sequence_pt(
 
 
 def set_device_and_dtype(
-    model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype
-) -> Tuple[Any, List[torch.Tensor]]:
+    model: Any, batches: list[torch.Tensor], device: str | torch.device, dtype: torch.dtype
+) -> tuple[Any, list[torch.Tensor]]:
     """Set the device and dtype of a model and its batches
 
     >>> import torch
@@ -122,14 +123,12 @@ def set_device_and_dtype(
     >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)
 
     Args:
-    ----
         model: the model to be set
         batches: the batches to be set
         device: the device to be used
         dtype: the dtype to be used
 
     Returns:
-    -------
         the model and batches set
     """
     return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
@@ -145,19 +144,17 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
 
     Args:
-    ----
         model: the PyTorch model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to torch.onnx.export
 
     Returns:
-    -------
         the path to the exported model
     """
     torch.onnx.export(
         model,
-        dummy_input,  # type: ignore[arg-type]
+        dummy_input,
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
doctr/models/utils/tensorflow.py CHANGED
@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any, Callable, List, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import Any
 
 import tensorflow as tf
 import tf2onnx
@@ -39,7 +40,6 @@ def _build_model(model: Model):
     """Build a model by calling it once with dummy input
 
     Args:
-    ----
         model: the model to be built
     """
     model(tf.zeros((1, *model.cfg["input_shape"])), training=False)
@@ -47,8 +47,8 @@
 
 def load_pretrained_params(
     model: Model,
-    url: Optional[str] = None,
-    hash_prefix: Optional[str] = None,
+    url: str | None = None,
+    hash_prefix: str | None = None,
     skip_mismatch: bool = False,
     **kwargs: Any,
 ) -> None:
@@ -58,7 +58,6 @@ def load_pretrained_params(
     >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.weights.h5")
 
     Args:
-    ----
         model: the keras model to be loaded
         url: URL of the zipped set of parameters
         hash_prefix: first characters of SHA256 expected hash
@@ -75,12 +74,12 @@
 
 def conv_sequence(
     out_channels: int,
-    activation: Optional[Union[str, Callable]] = None,
+    activation: str | Callable | None = None,
     bn: bool = False,
     padding: str = "same",
     kernel_initializer: str = "he_normal",
     **kwargs: Any,
-) -> List[layers.Layer]:
+) -> list[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
     >>> from tensorflow.keras import Sequential
@@ -88,7 +87,6 @@ def conv_sequence(
     >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
 
     Args:
-    ----
         out_channels: number of output channels
         activation: activation to be used (default: no activation)
         bn: should a batch normalization layer be added
@@ -97,7 +95,6 @@ def conv_sequence(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
@@ -125,12 +122,11 @@ class IntermediateLayerGetter(Model):
     >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers)
 
     Args:
-    ----
         model: the model to extract feature maps from
         layer_names: the list of layers to retrieve the feature map from
     """
 
-    def __init__(self, model: Model, layer_names: List[str]) -> None:
+    def __init__(self, model: Model, layer_names: list[str]) -> None:
         intermediate_fmaps = [model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names]
         super().__init__(model.input, outputs=intermediate_fmaps)
 
@@ -139,8 +135,8 @@ class IntermediateLayerGetter(Model):
 
 
 def export_model_to_onnx(
-    model: Model, model_name: str, dummy_input: List[tf.TensorSpec], **kwargs: Any
-) -> Tuple[str, List[str]]:
+    model: Model, model_name: str, dummy_input: list[tf.TensorSpec], **kwargs: Any
+) -> tuple[str, list[str]]:
     """Export model to ONNX format.
 
     >>> import tensorflow as tf
@@ -151,16 +147,18 @@ def export_model_to_onnx(
     >>> dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])
 
     Args:
-    ----
         model: the keras model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to tf2onnx
 
     Returns:
-    -------
         the path to the exported model and a list with the output layer names
     """
+    # get the users eager mode
+    eager_mode = tf.executing_eagerly()
+    # set eager mode to true to avoid issues with tf2onnx
+    tf.config.run_functions_eagerly(True)
     large_model = kwargs.get("large_model", False)
     model_proto, _ = tf2onnx.convert.from_keras(
         model,
@@ -171,6 +169,9 @@ def export_model_to_onnx(
     # Get the output layer names
     output = [n.name for n in model_proto.graph.output]
 
+    # reset the eager mode to the users mode
+    tf.config.run_functions_eagerly(eager_mode)
+
     # models which are too large (weights > 2GB while converting to ONNX) needs to be handled
     # about an external tensor storage where the graph and weights are seperatly stored in a archive
     if large_model:
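
tf2onnx conversion can misbehave when functions are not run eagerly, so the export now switches eager execution on for the duration of the conversion and restores the caller's setting afterwards. A usage sketch following the docstring example above, assuming doctr runs with its TensorFlow backend; the model choice is illustrative:

    import tensorflow as tf
    from doctr.models import classification
    from doctr.models.utils import export_model_to_onnx

    model = classification.mobilenet_v3_small(pretrained=True)
    model_path, output_names = export_model_to_onnx(
        model,
        "mobilenet_v3_small",
        dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")],
    )
    # with 0.11.0 the caller's eager/graph setting is back in place here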
doctr/models/zoo.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -83,7 +83,6 @@ def ocr_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -108,7 +107,6 @@ def ocr_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         OCR predictor
     """
     return _predictor(
@@ -197,7 +195,6 @@ def kie_predictor(
     >>> out = model([input_page])
 
     Args:
-    ----
         det_arch: name of the detection architecture or the model itself to use
             (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
         reco_arch: name of the recognition architecture or the model itself to use
@@ -222,7 +219,6 @@ def kie_predictor(
         kwargs: keyword args of `OCRPredictor`
 
     Returns:
-    -------
         KIE predictor
     """
     return _kie_predictor(
doctr/transforms/functional/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
+if is_torch_available():
     from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *
doctr/transforms/functional/base.py CHANGED
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Tuple, Union
 
 import cv2
 import numpy as np
@@ -15,17 +14,15 @@ __all__ = ["crop_boxes", "create_shadow_mask"]
 
 def crop_boxes(
     boxes: np.ndarray,
-    crop_box: Union[Tuple[int, int, int, int], Tuple[float, float, float, float]],
+    crop_box: tuple[int, int, int, int] | tuple[float, float, float, float],
 ) -> np.ndarray:
     """Crop localization boxes
 
     Args:
-    ----
         boxes: ndarray of shape (N, 4) in relative or abs coordinates
         crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes
 
     Returns:
-    -------
         the cropped boxes
     """
     is_box_rel = boxes.max() <= 1
@@ -49,17 +46,15 @@
     return boxes[is_valid]
 
 
-def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float, float]:
+def expand_line(line: np.ndarray, target_shape: tuple[int, int]) -> tuple[float, float]:
     """Expands a 2-point line, so that the first is on the edge. In other terms, we extend the line in
     the same direction until we meet one of the edges.
 
     Args:
-    ----
         line: array of shape (2, 2) of the point supposed to be on one edge, and the shadow tip.
         target_shape: the desired mask shape
 
     Returns:
-    -------
         2D coordinates of the first point once we extended the line (on one of the edges)
     """
     if any(coord == 0 or coord == size for coord, size in zip(line[0], target_shape[::-1])):
@@ -112,7 +107,7 @@ def expand_line(line: np.ndarray, target_shape: Tuple[int, int]) -> Tuple[float,
 
 
 def create_shadow_mask(
-    target_shape: Tuple[int, int],
+    target_shape: tuple[int, int],
     min_base_width=0.3,
     max_tip_width=0.5,
     max_tip_height=0.3,
@@ -120,14 +115,12 @@ def create_shadow_mask(
     """Creates a random shadow mask
 
     Args:
-    ----
         target_shape: the target shape (H, W)
         min_base_width: the relative minimum shadow base width
        max_tip_width: the relative maximum shadow tip width
         max_tip_height: the relative maximum shadow tip height
 
     Returns:
-    -------
         a numpy ndarray of shape (H, W, 1) with values in the range [0, 1]
     """
     # Default base is top