python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/pytorch.py

@@ -1,10 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

+ from collections.abc import Callable
  from copy import deepcopy
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from typing import Any

  import torch
  from torch import nn
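The dominant change across these hunks is typing modernization: `Dict`, `List`, `Optional`, `Tuple`, and `Union` from `typing` give way to PEP 585 built-in generics and PEP 604 `X | None` unions, and `Callable` now comes from `collections.abc`. A minimal sketch of the pattern (illustrative names, not doctr code):

# 0.9.0 style
from typing import Callable, Dict, List, Optional, Tuple

Hook = Callable[[str], str]

def decode(cfg: Optional[Dict[str, int]] = None) -> List[Tuple[str, float]]:
    return []

# 0.11.0 style: built-in generics (PEP 585) and "X | None" unions (PEP 604)
from collections.abc import Callable

Hook = Callable[[str], str]

def decode(cfg: dict[str, int] | None = None) -> list[tuple[str, float]]:
    return []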
@@ -19,7 +20,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor

  __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
      "vitstr_small": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
@@ -42,7 +43,6 @@ class ViTSTR(_ViTSTR, nn.Module):
      Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.

      Args:
-     ----
          feature_extractor: the backbone serving as feature extractor
          vocab: vocabulary used for encoding
          embedding_units: number of embedding units
@@ -59,9 +59,9 @@
          vocab: str,
          embedding_units: int,
          max_length: int = 32,  # different from paper
-         input_shape: Tuple[int, int, int] = (3, 32, 128),  # different from paper
+         input_shape: tuple[int, int, int] = (3, 32, 128),  # different from paper
          exportable: bool = False,
-         cfg: Optional[Dict[str, Any]] = None,
+         cfg: dict[str, Any] | None = None,
      ) -> None:
          super().__init__()
          self.vocab = vocab
@@ -77,10 +77,10 @@ class ViTSTR(_ViTSTR, nn.Module):
      def forward(
          self,
          x: torch.Tensor,
-         target: Optional[List[str]] = None,
+         target: list[str] | None = None,
          return_model_output: bool = False,
          return_preds: bool = False,
-     ) -> Dict[str, Any]:
+     ) -> dict[str, Any]:
          features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)

          if target is not None:
@@ -98,7 +98,7 @@ class ViTSTR(_ViTSTR, nn.Module):
          logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
          decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token

-         out: Dict[str, Any] = {}
+         out: dict[str, Any] = {}
          if self.exportable:
              out["logits"] = decoded_features
              return out
@@ -107,8 +107,13 @@ class ViTSTR(_ViTSTR, nn.Module):
          out["out_map"] = decoded_features

          if target is None or return_preds:
+             # Disable for torch.compile compatibility
+             @torch.compiler.disable  # type: ignore[attr-defined]
+             def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                 return self.postprocessor(decoded_features)
+
              # Post-process boxes
-             out["preds"] = self.postprocessor(decoded_features)
+             out["preds"] = _postprocess(decoded_features)

          if target is not None:
              out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
@@ -125,19 +130,17 @@ class ViTSTR(_ViTSTR, nn.Module):
          Sequences are masked after the EOS character.

          Args:
-         ----
              model_output: predicted logits of the model
              gt: the encoded tensor with gt labels
              seq_len: lengths of each gt word inside the batch

          Returns:
-         -------
              The loss of the model on the batch
          """
          # Input length : number of steps
          input_len = model_output.shape[1]
          # Add one for additional <eos> token (sos disappear in shift!)
-         seq_len = seq_len + 1
+         seq_len = seq_len + 1  # type: ignore[assignment]
          # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
          # The "masked" first gt char is <sos>.
          cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -153,14 +156,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
      """Post processor for ViTSTR architecture

      Args:
-     ----
          vocab: string containing the ordered sequence of supported characters
      """

      def __call__(
          self,
          logits: torch.Tensor,
-     ) -> List[Tuple[str, float]]:
+     ) -> list[tuple[str, float]]:
          # compute pred with argmax for attention models
          out_idxs = logits.argmax(-1)
          preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -183,7 +185,7 @@ def _vitstr(
      pretrained: bool,
      backbone_fn: Callable[[bool], nn.Module],
      layer: str,
-     ignore_keys: Optional[List[str]] = None,
+     ignore_keys: list[str] | None = None,
      **kwargs: Any,
  ) -> ViTSTR:
      # Patch the config
@@ -228,12 +230,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
          kwargs: keyword arguments of the ViTSTR architecture

      Returns:
-     -------
          text recognition architecture
      """
      return _vitstr(
@@ -259,12 +259,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
          kwargs: keyword arguments of the ViTSTR architecture

      Returns:
-     -------
          text recognition architecture
      """
      return _vitstr(
doctr/models/recognition/vitstr/tensorflow.py

@@ -1,10 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  from copy import deepcopy
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any

  import tensorflow as tf
  from tensorflow.keras import Model, layers
@@ -12,25 +12,25 @@ from tensorflow.keras import Model, layers
  from doctr.datasets import VOCABS

  from ...classification import vit_b, vit_s
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
  from .base import _ViTSTR, _ViTSTRPostProcessor

  __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]

- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
      "vitstr_small": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 128, 3),
          "vocab": VOCABS["french"],
-         "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_small-358fab2e.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_small-d28b8d92.weights.h5&src=0",
      },
      "vitstr_base": {
          "mean": (0.694, 0.695, 0.693),
          "std": (0.299, 0.296, 0.301),
          "input_shape": (32, 128, 3),
          "vocab": VOCABS["french"],
-         "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vitstr_base-2889159a.zip&src=0",
+         "url": "https://doctr-static.mindee.com/models?id=v0.9.0/vitstr_base-9ad6eb84.weights.h5&src=0",
      },
  }

@@ -40,7 +40,6 @@ class ViTSTR(_ViTSTR, Model):
      Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.

      Args:
-     ----
          feature_extractor: the backbone serving as feature extractor
          vocab: vocabulary used for encoding
          embedding_units: number of embedding units
@@ -51,7 +50,7 @@ class ViTSTR(_ViTSTR, Model):
          cfg: dictionary containing information about the model
      """

-     _children_names: List[str] = ["feat_extractor", "postprocessor"]
+     _children_names: list[str] = ["feat_extractor", "postprocessor"]

      def __init__(
          self,
@@ -60,9 +59,9 @@ class ViTSTR(_ViTSTR, Model):
          embedding_units: int,
          max_length: int = 32,
          dropout_prob: float = 0.0,
-         input_shape: Tuple[int, int, int] = (32, 128, 3),  # different from paper
+         input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
          exportable: bool = False,
-         cfg: Optional[Dict[str, Any]] = None,
+         cfg: dict[str, Any] | None = None,
      ) -> None:
          super().__init__()
          self.vocab = vocab
@@ -79,19 +78,17 @@ class ViTSTR(_ViTSTR, Model):
      def compute_loss(
          model_output: tf.Tensor,
          gt: tf.Tensor,
-         seq_len: List[int],
+         seq_len: list[int],
      ) -> tf.Tensor:
          """Compute categorical cross-entropy loss for the model.
          Sequences are masked after the EOS character.

          Args:
-         ----
              model_output: predicted logits of the model
              gt: the encoded tensor with gt labels
              seq_len: lengths of each gt word inside the batch

          Returns:
-         -------
              The loss of the model on the batch
          """
          # Input length : number of steps
@@ -114,11 +111,11 @@ class ViTSTR(_ViTSTR, Model):
      def call(
          self,
          x: tf.Tensor,
-         target: Optional[List[str]] = None,
+         target: list[str] | None = None,
          return_model_output: bool = False,
          return_preds: bool = False,
          **kwargs: Any,
-     ) -> Dict[str, Any]:
+     ) -> dict[str, Any]:
          features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)

          if target is not None:
@@ -136,7 +133,7 @@ class ViTSTR(_ViTSTR, Model):
          )  # (batch_size, max_length, vocab + 1)
          decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token

-         out: Dict[str, tf.Tensor] = {}
+         out: dict[str, tf.Tensor] = {}
          if self.exportable:
              out["logits"] = decoded_features
              return out
@@ -158,14 +155,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
      """Post processor for ViTSTR architecture

      Args:
-     ----
          vocab: string containing the ordered sequence of supported characters
      """

      def __call__(
          self,
          logits: tf.Tensor,
-     ) -> List[Tuple[str, float]]:
+     ) -> list[tuple[str, float]]:
          # compute pred with argmax for attention models
          out_idxs = tf.math.argmax(logits, axis=2)
          preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -191,7 +187,7 @@ def _vitstr(
      arch: str,
      pretrained: bool,
      backbone_fn,
-     input_shape: Optional[Tuple[int, int, int]] = None,
+     input_shape: tuple[int, int, int] | None = None,
      **kwargs: Any,
  ) -> ViTSTR:
      # Patch the config
@@ -216,9 +212,14 @@

      # Build the model
      model = ViTSTR(feat_extractor, cfg=_cfg, **kwargs)
+     _build_model(model)
+
      # Load pretrained parameters
      if pretrained:
-         load_pretrained_params(model, default_cfgs[arch]["url"])
+         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+         load_pretrained_params(
+             model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+         )

      return model
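Two behavioral changes land in this hunk: the Keras model is now built explicitly before weights are touched (`_build_model`), and `skip_mismatch` is derived automatically, so requesting a vocab different from the pretrained one reinitializes only the shape-mismatching layers. A hedged usage sketch, assuming the TensorFlow backend is active:

from doctr.models import recognition

# A custom vocab differs from the pretrained French vocab, so skip_mismatch=True
# is passed under the hood: the backbone loads pretrained weights while the
# vocab-dependent head starts fresh, ready for fine-tuning.
model = recognition.vitstr_small(pretrained=True, vocab="0123456789")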
@@ -234,12 +235,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
          **kwargs: keyword arguments of the ViTSTR architecture

      Returns:
-     -------
          text recognition architecture
      """
      return _vitstr(
@@ -263,12 +262,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
      >>> out = model(input_tensor)

      Args:
-     ----
          pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
          **kwargs: keyword arguments of the ViTSTR architecture

      Returns:
-     -------
          text recognition architecture
      """
      return _vitstr(
doctr/models/recognition/zoo.py

@@ -1,11 +1,11 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List
+ from typing import Any

- from doctr.file_utils import is_tf_available
+ from doctr.file_utils import is_tf_available, is_torch_available
  from doctr.models.preprocessor import PreProcessor

  from .. import recognition
@@ -14,7 +14,7 @@ from .predictor import RecognitionPredictor
  __all__ = ["recognition_predictor"]


- ARCHS: List[str] = [
+ ARCHS: list[str] = [
      "crnn_vgg16_bn",
      "crnn_mobilenet_v3_small",
      "crnn_mobilenet_v3_large",
@@ -35,9 +35,14 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
              pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
          )
      else:
-         if not isinstance(
-             arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
-         ):
+         allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
+         if is_torch_available():
+             # Adding the type for torch compiled models to the allowed architectures
+             from doctr.models.utils import _CompiledModule
+
+             allowed_archs.append(_CompiledModule)
+
+         if not isinstance(arch, tuple(allowed_archs)):
              raise ValueError(f"unknown architecture: {type(arch)}")
          _model = arch
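With `_CompiledModule` added to the accepted types on the PyTorch backend, a model wrapped by `torch.compile` can be handed to the predictor directly instead of being rejected as an unknown architecture. A sketch, assuming the PyTorch backend:

import torch
from doctr.models import recognition, recognition_predictor

model = recognition.crnn_vgg16_bn(pretrained=True)
compiled = torch.compile(model)  # returns an OptimizedModule, i.e. a _CompiledModule

# 0.9.0 raised "unknown architecture" for the wrapper; 0.11.0 accepts it
predictor = recognition_predictor(arch=compiled)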
@@ -52,7 +57,13 @@
      return predictor


- def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor:
+ def recognition_predictor(
+     arch: Any = "crnn_vgg16_bn",
+     pretrained: bool = False,
+     symmetric_pad: bool = False,
+     batch_size: int = 128,
+     **kwargs: Any,
+ ) -> RecognitionPredictor:
      """Text recognition architecture.

      Example::
@@ -63,13 +74,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
      >>> out = model([input_page])

      Args:
-     ----
          arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
          pretrained: If True, returns a model pre-trained on our text recognition dataset
+         symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+         batch_size: number of samples the model processes in parallel
          **kwargs: optional parameters to be passed to the architecture

      Returns:
-     -------
          Recognition predictor
      """
-     return _predictor(arch, pretrained, **kwargs)
+     return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)
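`symmetric_pad` and `batch_size` are promoted from undocumented `**kwargs` to named parameters with explicit defaults. Typical use:

from doctr.models import recognition_predictor

predictor = recognition_predictor(
    arch="crnn_vgg16_bn",
    pretrained=True,
    symmetric_pad=True,  # pad crops on both sides instead of bottom-right only
    batch_size=256,  # crops per forward pass (default: 128)
)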
doctr/models/utils/__init__.py

@@ -1,6 +1,6 @@
  from doctr.file_utils import is_tf_available, is_torch_available

- if is_tf_available():
-     from .tensorflow import *
- elif is_torch_available():
-     from .pytorch import *  # type: ignore[assignment]
+ if is_torch_available():
+     from .pytorch import *
+ elif is_tf_available():
+     from .tensorflow import *  # type: ignore[assignment]
doctr/models/utils/pytorch.py

@@ -1,10 +1,10 @@
- # Copyright (C) 2021-2024, Mindee.
+ # Copyright (C) 2021-2025, Mindee.

  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import logging
- from typing import Any, List, Optional, Tuple, Union
+ from typing import Any

  import torch
  from torch import nn
@@ -18,8 +18,12 @@ __all__ = [
      "export_model_to_onnx",
      "_copy_tensor",
      "_bf16_to_float32",
+     "_CompiledModule",
  ]

+ # torch compiled model type
+ _CompiledModule = torch._dynamo.eval_frame.OptimizedModule
+

  def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
      return x.clone().detach()
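`_CompiledModule` aliases the `OptimizedModule` wrapper that `torch.compile` returns for an `nn.Module`, so downstream code can type-check compiled models without reaching into `torch._dynamo` at every call site. A minimal check, assuming the PyTorch backend:

import torch
from torch import nn

from doctr.models.utils import _CompiledModule

net = nn.Linear(4, 2)
compiled = torch.compile(net)

assert not isinstance(net, _CompiledModule)
assert isinstance(compiled, _CompiledModule)  # the OptimizedModule wrapper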
@@ -32,9 +36,9 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:

  def load_pretrained_params(
      model: nn.Module,
-     url: Optional[str] = None,
-     hash_prefix: Optional[str] = None,
-     ignore_keys: Optional[List[str]] = None,
+     url: str | None = None,
+     hash_prefix: str | None = None,
+     ignore_keys: list[str] | None = None,
      **kwargs: Any,
  ) -> None:
      """Load a set of parameters onto a model
@@ -43,7 +47,6 @@ def load_pretrained_params(
      >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")

      Args:
-     ----
          model: the PyTorch model to be loaded
          url: URL of the zipped set of parameters
          hash_prefix: first characters of SHA256 expected hash
@@ -76,7 +79,7 @@
      relu: bool = False,
      bn: bool = False,
      **kwargs: Any,
- ) -> List[nn.Module]:
+ ) -> list[nn.Module]:
      """Builds a convolutional-based layer sequence

      >>> from torch.nn import Sequential
@@ -84,7 +87,6 @@
      >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))

      Args:
-     ----
          in_channels: number of input channels
          out_channels: number of output channels
          relu: whether ReLU should be used
@@ -92,13 +94,12 @@
          **kwargs: additional arguments to be passed to the convolutional layer

      Returns:
-     -------
          list of layers
      """
      # No bias before Batch norm
      kwargs["bias"] = kwargs.get("bias", not bn)
      # Add activation directly to the conv if there is no BN
-     conv_seq: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
+     conv_seq: list[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]

      if bn:
          conv_seq.append(nn.BatchNorm2d(out_channels))
@@ -110,8 +111,8 @@


  def set_device_and_dtype(
-     model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype
- ) -> Tuple[Any, List[torch.Tensor]]:
+     model: Any, batches: list[torch.Tensor], device: str | torch.device, dtype: torch.dtype
+ ) -> tuple[Any, list[torch.Tensor]]:
      """Set the device and dtype of a model and its batches

      >>> import torch
@@ -122,14 +123,12 @@ def set_device_and_dtype(
      >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)

      Args:
-     ----
          model: the model to be set
          batches: the batches to be set
          device: the device to be used
          dtype: the dtype to be used

      Returns:
-     -------
          the model and batches set
      """
      return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
@@ -145,14 +144,12 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
      >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

      Args:
-     ----
          model: the PyTorch model to be exported
          model_name: the name for the exported model
          dummy_input: the dummy input to the model
          kwargs: additional arguments to be passed to torch.onnx.export

      Returns:
-     -------
          the path to the exported model
      """
      torch.onnx.export(
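The hunk ends inside the `torch.onnx.export(` call. Based on the docstring above, a hedged end-to-end sketch of exporting a model and running the resulting file (onnxruntime is an assumption here, not a dependency shown in this diff):

import numpy as np
import torch
from torch import nn

from doctr.models.utils import export_model_to_onnx

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 32 * 32, 10))
onnx_path = export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

# Run the exported graph (assumes `pip install onnxruntime`)
import onnxruntime as ort

session = ort.InferenceSession(onnx_path)
inputs = {session.get_inputs()[0].name: np.random.rand(1, 3, 32, 32).astype(np.float32)}
print(session.run(None, inputs)[0].shape)  # (1, 10)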