python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/recognition/vitstr/pytorch.py
@@ -1,10 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
+from collections.abc import Callable
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any
 
 import torch
 from torch import nn
@@ -19,7 +20,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -42,7 +43,6 @@ class ViTSTR(_ViTSTR, nn.Module):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -59,9 +59,9 @@ class ViTSTR(_ViTSTR, nn.Module):
         vocab: str,
         embedding_units: int,
         max_length: int = 32,  # different from paper
-        input_shape: Tuple[int, int, int] = (3, 32, 128),  # different from paper
+        input_shape: tuple[int, int, int] = (3, 32, 128),  # different from paper
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -74,13 +74,22 @@ class ViTSTR(_ViTSTR, nn.Module):
 
         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -98,7 +107,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         logits = self.head(features).view(B, N, len(self.vocab) + 1)  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -107,8 +116,13 @@ class ViTSTR(_ViTSTR, nn.Module):
             out["out_map"] = decoded_features
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(decoded_features)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(decoded_features)
+            out["preds"] = _postprocess(decoded_features)
 
         if target is not None:
             out["loss"] = self.compute_loss(decoded_features, gt, seq_len)
@@ -125,19 +139,17 @@ class ViTSTR(_ViTSTR, nn.Module):
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1  # type: ignore[assignment]
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -153,14 +165,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -183,7 +194,7 @@ def _vitstr(
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -212,7 +223,7 @@ def _vitstr(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -228,12 +239,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -259,12 +268,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
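
A note on the hunks above: the recognition models now expose a from_pretrained method wrapping load_pretrained_params, so a checkpoint can be loaded from a local file as well as a URL. A minimal sketch of the new entry point (the checkpoint filename is a placeholder, not a file shipped with doctr):

import torch
from doctr.models import vitstr_small

model = vitstr_small(pretrained=True)

# Save a (fine-tuned) state_dict, then reload it through the new API;
# the same call also accepts an http(s) URL.
torch.save(model.state_dict(), "my_checkpoint.pt")
model.from_pretrained("my_checkpoint.pt")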

doctr/models/recognition/vitstr/tensorflow.py
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
@@ -17,7 +17,7 @@ from .base import _ViTSTR, _ViTSTRPostProcessor
 
 __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "vitstr_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -40,7 +40,6 @@ class ViTSTR(_ViTSTR, Model):
     Efficient Scene Text Recognition" <https://arxiv.org/pdf/2105.08582.pdf>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -51,7 +50,7 @@ class ViTSTR(_ViTSTR, Model):
         cfg: dictionary containing information about the model
     """
 
-    _children_names: List[str] = ["feat_extractor", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
     def __init__(
         self,
@@ -60,9 +59,9 @@ class ViTSTR(_ViTSTR, Model):
         embedding_units: int,
         max_length: int = 32,
         dropout_prob: float = 0.0,
-        input_shape: Tuple[int, int, int] = (32, 128, 3),  # different from paper
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from paper
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -75,23 +74,30 @@ class ViTSTR(_ViTSTR, Model):
 
         self.postprocessor = ViTSTRPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @staticmethod
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len: List[int],
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             model_output: predicted logits of the model
             gt: the encoded tensor with gt labels
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of steps
@@ -114,11 +120,11 @@ class ViTSTR(_ViTSTR, Model):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
 
         if target is not None:
@@ -136,7 +142,7 @@ class ViTSTR(_ViTSTR, Model):
         )  # (batch_size, max_length, vocab + 1)
         decoded_features = _bf16_to_float32(logits[:, 1:])  # remove cls_token
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = decoded_features
             return out
@@ -158,14 +164,13 @@ class ViTSTRPostProcessor(_ViTSTRPostProcessor):
     """Post processor for ViTSTR architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -191,7 +196,7 @@ def _vitstr(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> ViTSTR:
     # Patch the config
@@ -221,9 +226,7 @@ def _vitstr(
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -239,12 +242,10 @@ def vitstr_small(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(
@@ -268,12 +269,10 @@ def vitstr_base(pretrained: bool = False, **kwargs: Any) -> ViTSTR:
     >>> out = model(input_tensor)
 
    Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the ViTSTR architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _vitstr(

doctr/models/recognition/zoo.py
@@ -1,11 +1,11 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, List
+from typing import Any
 
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 from doctr.models.preprocessor import PreProcessor
 
 from .. import recognition
@@ -14,7 +14,7 @@ from .predictor import RecognitionPredictor
 __all__ = ["recognition_predictor"]
 
 
-ARCHS: List[str] = [
+ARCHS: list[str] = [
     "crnn_vgg16_bn",
     "crnn_mobilenet_v3_small",
     "crnn_mobilenet_v3_large",
@@ -25,6 +25,9 @@ ARCHS: List[str] = [
     "parseq",
 ]
 
+if is_torch_available():
+    ARCHS.extend(["viptr_tiny"])
+
 
 def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
     if isinstance(arch, str):
@@ -35,9 +38,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
             pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
         )
     else:
-        if not isinstance(
-            arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
-        ):
+        allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq]
+        if is_torch_available():
+            # Add VIPTR which is only available in torch at the moment
+            allowed_archs.append(recognition.VIPTR)
+            # Adding the type for torch compiled models to the allowed architectures
+            from doctr.models.utils import _CompiledModule
+
+            allowed_archs.append(_CompiledModule)
+
+        if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
         _model = arch
 
@@ -52,7 +62,13 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
     return predictor
 
 
-def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor:
+def recognition_predictor(
+    arch: Any = "crnn_vgg16_bn",
+    pretrained: bool = False,
+    symmetric_pad: bool = False,
+    batch_size: int = 128,
+    **kwargs: Any,
+) -> RecognitionPredictor:
     """Text recognition architecture.
 
     Example::
@@ -63,13 +79,13 @@ def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False,
     >>> out = model([input_page])
 
     Args:
-    ----
         arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
         pretrained: If True, returns a model pre-trained on our text recognition dataset
+        symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+        batch_size: number of samples the model processes in parallel
         **kwargs: optional parameters to be passed to the architecture
 
     Returns:
-    -------
         Recognition predictor
     """
-    return _predictor(arch, pretrained, **kwargs)
+    return _predictor(arch=arch, pretrained=pretrained, symmetric_pad=symmetric_pad, batch_size=batch_size, **kwargs)
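
The zoo change above promotes symmetric_pad and batch_size to explicit arguments of recognition_predictor and, on PyTorch, accepts viptr_tiny plus torch-compiled models as architectures. A usage sketch of the updated signature (the dummy crop is illustrative):

import numpy as np
from doctr.models import recognition_predictor

predictor = recognition_predictor(
    arch="crnn_vgg16_bn",
    pretrained=True,
    symmetric_pad=True,  # pad word crops on both sides instead of bottom-right
    batch_size=64,  # crops per forward pass (the default is now 128)
)
dummy_crop = np.zeros((32, 128, 3), dtype=np.uint8)  # placeholder word crop
out = predictor([dummy_crop])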

doctr/models/utils/__init__.py
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]
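
Note the flipped priority: with both backends installed, doctr.models.utils now resolves to the PyTorch implementation first. A small sketch, assuming the USE_TORCH/USE_TF environment variables read by doctr.file_utils to force a backend:

import os

# Must be set before doctr is imported
os.environ["USE_TORCH"] = "1"

from doctr.models.utils import load_pretrained_params  # resolves to the pytorch module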

doctr/models/utils/pytorch.py
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import logging
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any
 
 import torch
+import validators
 from torch import nn
 
 from doctr.utils.data import download_from_url
@@ -18,8 +19,12 @@ __all__ = [
     "export_model_to_onnx",
     "_copy_tensor",
     "_bf16_to_float32",
+    "_CompiledModule",
 ]
 
+# torch compiled model type
+_CompiledModule = torch._dynamo.eval_frame.OptimizedModule
+
 
 def _copy_tensor(x: torch.Tensor) -> torch.Tensor:
     return x.clone().detach()
@@ -32,42 +37,50 @@ def _bf16_to_float32(x: torch.Tensor) -> torch.Tensor:
 
 def load_pretrained_params(
     model: nn.Module,
-    url: Optional[str] = None,
-    hash_prefix: Optional[str] = None,
-    ignore_keys: Optional[List[str]] = None,
+    path_or_url: str | None = None,
+    hash_prefix: str | None = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> None:
     """Load a set of parameters onto a model
 
     >>> from doctr.models import load_pretrained_params
-    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+    >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")
 
     Args:
-    ----
         model: the PyTorch model to be loaded
-        url: URL of the zipped set of parameters
+        path_or_url: the path or URL to the model parameters (checkpoint)
         hash_prefix: first characters of SHA256 expected hash
         ignore_keys: list of weights to be ignored from the state_dict
         **kwargs: additional arguments to be passed to `doctr.utils.data.download_from_url`
     """
-    if url is None:
-        logging.warning("Invalid model URL, using default initialization.")
-    else:
-        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+    if path_or_url is None:
+        logging.warning("No model URL or Path provided, using default initialization.")
+        return
+
+    archive_path = (
+        download_from_url(path_or_url, hash_prefix=hash_prefix, cache_subdir="models", **kwargs)
+        if validators.url(path_or_url)
+        else path_or_url
+    )
 
-        # Read state_dict
-        state_dict = torch.load(archive_path, map_location="cpu")
+    # Read state_dict
+    state_dict = torch.load(archive_path, map_location="cpu")
 
-        # Remove weights from the state_dict
-        if ignore_keys is not None and len(ignore_keys) > 0:
-            for key in ignore_keys:
+    # Remove weights from the state_dict
+    if ignore_keys is not None and len(ignore_keys) > 0:
+        for key in ignore_keys:
+            if key in state_dict:
                 state_dict.pop(key)
-            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-            if set(missing_keys) != set(ignore_keys) or len(unexpected_keys) > 0:
-                raise ValueError("unable to load state_dict, due to non-matching keys.")
-        else:
-            # Load weights
-            model.load_state_dict(state_dict)
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        if any(k not in ignore_keys for k in missing_keys + unexpected_keys):
+            raise ValueError(
+                "Unable to load state_dict, due to non-matching keys.\n"
+                + f"Unexpected keys: {unexpected_keys}\nMissing keys: {missing_keys}"
+            )
+    else:
+        # Load weights
+        model.load_state_dict(state_dict)
 
 
 def conv_sequence_pt(
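
The rewrite above flattens load_pretrained_params with an early return and routes the input through validators.url to decide between downloading and reading a local file. A minimal sketch of both branches (checkpoint names are placeholders, following the docstring):

from doctr.models import load_pretrained_params, vitstr_small

model = vitstr_small(pretrained=False)

# URL branch: validators.url() is truthy, so the file is fetched and cached
load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.pt")

# Path branch: validators.url() is falsy, so the path goes straight to torch.load
load_pretrained_params(model, "checkpoints/my_finetuned_vitstr.pt")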

@@ -76,7 +89,7 @@ conv_sequence_pt(
     relu: bool = False,
     bn: bool = False,
     **kwargs: Any,
-) -> List[nn.Module]:
+) -> list[nn.Module]:
     """Builds a convolutional-based layer sequence
 
     >>> from torch.nn import Sequential
@@ -84,7 +97,6 @@ conv_sequence_pt(
     >>> module = Sequential(conv_sequence(3, 32, True, True, kernel_size=3))
 
     Args:
-    ----
         in_channels: number of input channels
         out_channels: number of output channels
         relu: whether ReLU should be used
@@ -92,13 +104,12 @@ conv_sequence_pt(
         **kwargs: additional arguments to be passed to the convolutional layer
 
     Returns:
-    -------
         list of layers
     """
     # No bias before Batch norm
     kwargs["bias"] = kwargs.get("bias", not bn)
     # Add activation directly to the conv if there is no BN
-    conv_seq: List[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
+    conv_seq: list[nn.Module] = [nn.Conv2d(in_channels, out_channels, **kwargs)]
 
     if bn:
         conv_seq.append(nn.BatchNorm2d(out_channels))
@@ -110,8 +121,8 @@
 
 
 def set_device_and_dtype(
-    model: Any, batches: List[torch.Tensor], device: Union[str, torch.device], dtype: torch.dtype
-) -> Tuple[Any, List[torch.Tensor]]:
+    model: Any, batches: list[torch.Tensor], device: str | torch.device, dtype: torch.dtype
+) -> tuple[Any, list[torch.Tensor]]:
     """Set the device and dtype of a model and its batches
 
     >>> import torch
@@ -122,14 +133,12 @@ set_device_and_dtype(
     >>> model, batches = set_device_and_dtype(model, batches, device="cuda", dtype=torch.float16)
 
     Args:
-    ----
         model: the model to be set
         batches: the batches to be set
         device: the device to be used
         dtype: the dtype to be used
 
     Returns:
-    -------
         the model and batches set
     """
     return model.to(device=device, dtype=dtype), [batch.to(device=device, dtype=dtype) for batch in batches]
@@ -145,19 +154,17 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor
     >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
 
     Args:
-    ----
         model: the PyTorch model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
         kwargs: additional arguments to be passed to torch.onnx.export
 
     Returns:
-    -------
         the path to the exported model
     """
     torch.onnx.export(
         model,
-        dummy_input,  # type: ignore[arg-type]
+        dummy_input,
         f"{model_name}.onnx",
         input_names=["input"],
         output_names=["logits"],
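
The type: ignore dropped in the last hunk reflects newer torch stubs accepting a plain tensor as dummy_input. A usage sketch following the docstring above, with mobilenet_v3_small assumed as an example backbone:

import torch
from doctr.models.classification import mobilenet_v3_small
from doctr.models.utils import export_model_to_onnx

model = mobilenet_v3_small(pretrained=False).eval()
# Writes "my_model.onnx" in the current working directory and returns its path
onnx_path = export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))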