python-doctr 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162):
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +17 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +17 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +14 -5
  17. doctr/datasets/ic13.py +13 -5
  18. doctr/datasets/iiit5k.py +31 -20
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +15 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +16 -5
  27. doctr/datasets/svhn.py +16 -5
  28. doctr/datasets/svt.py +14 -5
  29. doctr/datasets/synthtext.py +14 -5
  30. doctr/datasets/utils.py +37 -27
  31. doctr/datasets/vocabs.py +21 -7
  32. doctr/datasets/wildreceipt.py +25 -10
  33. doctr/file_utils.py +18 -4
  34. doctr/io/elements.py +69 -81
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +32 -50
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +21 -17
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +7 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +22 -29
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +13 -11
  52. doctr/models/classification/predictor/tensorflow.py +13 -11
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +41 -39
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +19 -20
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +18 -15
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +16 -16
  65. doctr/models/classification/zoo.py +36 -19
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +49 -37
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +28 -37
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +36 -33
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +7 -8
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +8 -13
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +8 -5
  91. doctr/models/kie_predictor/pytorch.py +22 -19
  92. doctr/models/kie_predictor/tensorflow.py +21 -15
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -12
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +3 -4
  101. doctr/models/modules/vision_transformer/tensorflow.py +4 -4
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +52 -41
  104. doctr/models/predictor/pytorch.py +16 -13
  105. doctr/models/predictor/tensorflow.py +16 -10
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +11 -15
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +19 -29
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +21 -26
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +26 -30
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +19 -24
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +21 -24
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +13 -16
  136. doctr/models/utils/tensorflow.py +31 -30
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +21 -29
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +65 -28
  145. doctr/transforms/modules/tensorflow.py +33 -44
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +120 -64
  150. doctr/utils/metrics.py +18 -38
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +157 -75
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +59 -57
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.9.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.9.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,11 +1,12 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
+ from collections.abc import Callable
6
7
  from copy import deepcopy
7
8
  from itertools import groupby
8
- from typing import Any, Callable, Dict, List, Optional, Tuple
9
+ from typing import Any
9
10
 
10
11
  import torch
11
12
  from torch import nn
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
19
20
 
20
21
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
21
22
 
22
- default_cfgs: Dict[str, Dict[str, Any]] = {
23
+ default_cfgs: dict[str, dict[str, Any]] = {
23
24
  "crnn_vgg16_bn": {
24
25
  "mean": (0.694, 0.695, 0.693),
25
26
  "std": (0.299, 0.296, 0.301),
@@ -48,7 +49,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
48
49
  """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
49
50
 
50
51
  Args:
51
- ----
52
52
  vocab: string containing the ordered sequence of supported characters
53
53
  """
54
54
 
@@ -57,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
57
57
  logits: torch.Tensor,
58
58
  vocab: str = VOCABS["french"],
59
59
  blank: int = 0,
60
- ) -> List[Tuple[str, float]]:
60
+ ) -> list[tuple[str, float]]:
61
61
  """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
62
62
  <https://github.com/githubharald/CTCDecoder>`_.
63
63
 
64
64
  Args:
65
- ----
66
65
  logits: model output, shape: N x T x C
67
66
  vocab: vocabulary to use
68
67
  blank: index of blank label
69
68
 
70
69
  Returns:
71
- -------
72
70
  A list of tuples: (word, confidence)
73
71
  """
74
72
  # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
@@ -82,16 +80,14 @@ class CTCPostProcessor(RecognitionPostProcessor):
82
80
 
83
81
  return list(zip(words, probs.tolist()))
84
82
 
85
- def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]:
83
+ def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
86
84
  """Performs decoding of raw output with CTC and decoding of CTC predictions
87
85
  with label_to_idx mapping dictionnary
88
86
 
89
87
  Args:
90
- ----
91
88
  logits: raw output of the model, shape (N, C + 1, seq_len)
92
89
 
93
90
  Returns:
94
- -------
95
91
  A tuple of 2 lists: a list of str (words) and a list of float (probs)
96
92
 
97
93
  """
@@ -104,7 +100,6 @@ class CRNN(RecognitionModel, nn.Module):
104
100
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
105
101
 
106
102
  Args:
107
- ----
108
103
  feature_extractor: the backbone serving as feature extractor
109
104
  vocab: vocabulary used for encoding
110
105
  rnn_units: number of units in the LSTM layers
@@ -112,16 +107,16 @@ class CRNN(RecognitionModel, nn.Module):
112
107
  cfg: configuration dictionary
113
108
  """
114
109
 
115
- _children_names: List[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]
110
+ _children_names: list[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]
116
111
 
117
112
  def __init__(
118
113
  self,
119
114
  feature_extractor: nn.Module,
120
115
  vocab: str,
121
116
  rnn_units: int = 128,
122
- input_shape: Tuple[int, int, int] = (3, 32, 128),
117
+ input_shape: tuple[int, int, int] = (3, 32, 128),
123
118
  exportable: bool = False,
124
- cfg: Optional[Dict[str, Any]] = None,
119
+ cfg: dict[str, Any] | None = None,
125
120
  ) -> None:
126
121
  super().__init__()
127
122
  self.vocab = vocab
@@ -163,17 +158,15 @@ class CRNN(RecognitionModel, nn.Module):
163
158
  def compute_loss(
164
159
  self,
165
160
  model_output: torch.Tensor,
166
- target: List[str],
161
+ target: list[str],
167
162
  ) -> torch.Tensor:
168
163
  """Compute CTC loss for the model.
169
164
 
170
165
  Args:
171
- ----
172
166
  model_output: predicted logits of the model
173
167
  target: list of target strings
174
168
 
175
169
  Returns:
176
- -------
177
170
  The loss of the model on the batch
178
171
  """
179
172
  gt, seq_len = self.build_target(target)
@@ -196,10 +189,10 @@ class CRNN(RecognitionModel, nn.Module):
196
189
  def forward(
197
190
  self,
198
191
  x: torch.Tensor,
199
- target: Optional[List[str]] = None,
192
+ target: list[str] | None = None,
200
193
  return_model_output: bool = False,
201
194
  return_preds: bool = False,
202
- ) -> Dict[str, Any]:
195
+ ) -> dict[str, Any]:
203
196
  if self.training and target is None:
204
197
  raise ValueError("Need to provide labels during training")
205
198
 
@@ -211,7 +204,7 @@ class CRNN(RecognitionModel, nn.Module):
211
204
  logits, _ = self.decoder(features_seq)
212
205
  logits = self.linear(logits)
213
206
 
214
- out: Dict[str, Any] = {}
207
+ out: dict[str, Any] = {}
215
208
  if self.exportable:
216
209
  out["logits"] = logits
217
210
  return out
@@ -220,8 +213,13 @@ class CRNN(RecognitionModel, nn.Module):
220
213
  out["out_map"] = logits
221
214
 
222
215
  if target is None or return_preds:
216
+ # Disable for torch.compile compatibility
217
+ @torch.compiler.disable # type: ignore[attr-defined]
218
+ def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
219
+ return self.postprocessor(logits)
220
+
223
221
  # Post-process boxes
224
- out["preds"] = self.postprocessor(logits)
222
+ out["preds"] = _postprocess(logits)
225
223
 
226
224
  if target is not None:
227
225
  out["loss"] = self.compute_loss(logits, target)
@@ -234,7 +232,7 @@ def _crnn(
234
232
  pretrained: bool,
235
233
  backbone_fn: Callable[[Any], nn.Module],
236
234
  pretrained_backbone: bool = True,
237
- ignore_keys: Optional[List[str]] = None,
235
+ ignore_keys: list[str] | None = None,
238
236
  **kwargs: Any,
239
237
  ) -> CRNN:
240
238
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -272,12 +270,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
272
270
  >>> out = model(input_tensor)
273
271
 
274
272
  Args:
275
- ----
276
273
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
277
274
  **kwargs: keyword arguments of the CRNN architecture
278
275
 
279
276
  Returns:
280
- -------
281
277
  text recognition architecture
282
278
  """
283
279
  return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs)
@@ -294,12 +290,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
294
290
  >>> out = model(input_tensor)
295
291
 
296
292
  Args:
297
- ----
298
293
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
299
294
  **kwargs: keyword arguments of the CRNN architecture
300
295
 
301
296
  Returns:
302
- -------
303
297
  text recognition architecture
304
298
  """
305
299
  return _crnn(
@@ -322,12 +316,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
322
316
  >>> out = model(input_tensor)
323
317
 
324
318
  Args:
325
- ----
326
319
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
327
320
  **kwargs: keyword arguments of the CRNN architecture
328
321
 
329
322
  Returns:
330
- -------
331
323
  text recognition architecture
332
324
  """
333
325
  return _crnn(
@@ -1,10 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
- from typing import Any, Dict, List, Optional, Tuple, Union
7
+ from typing import Any
8
8
 
9
9
  import tensorflow as tf
10
10
  from tensorflow.keras import layers
@@ -13,32 +13,32 @@ from tensorflow.keras.models import Model, Sequential
13
13
  from doctr.datasets import VOCABS
14
14
 
15
15
  from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
16
- from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
16
+ from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
17
17
  from ..core import RecognitionModel, RecognitionPostProcessor
18
18
 
19
19
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
20
20
 
21
- default_cfgs: Dict[str, Dict[str, Any]] = {
21
+ default_cfgs: dict[str, dict[str, Any]] = {
22
22
  "crnn_vgg16_bn": {
23
23
  "mean": (0.694, 0.695, 0.693),
24
24
  "std": (0.299, 0.296, 0.301),
25
25
  "input_shape": (32, 128, 3),
26
26
  "vocab": VOCABS["legacy_french"],
27
- "url": "https://doctr-static.mindee.com/models?id=v0.3.0/crnn_vgg16_bn-76b7f2c6.zip&src=0",
27
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
28
28
  },
29
29
  "crnn_mobilenet_v3_small": {
30
30
  "mean": (0.694, 0.695, 0.693),
31
31
  "std": (0.299, 0.296, 0.301),
32
32
  "input_shape": (32, 128, 3),
33
33
  "vocab": VOCABS["french"],
34
- "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip&src=0",
34
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
35
35
  },
36
36
  "crnn_mobilenet_v3_large": {
37
37
  "mean": (0.694, 0.695, 0.693),
38
38
  "std": (0.299, 0.296, 0.301),
39
39
  "input_shape": (32, 128, 3),
40
40
  "vocab": VOCABS["french"],
41
- "url": "https://doctr-static.mindee.com/models?id=v0.6.0/crnn_mobilenet_v3_large-cccc50b1.zip&src=0",
41
+ "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
42
42
  },
43
43
  }
44
44
 
@@ -47,7 +47,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
47
47
  """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
48
48
 
49
49
  Args:
50
- ----
51
50
  vocab: string containing the ordered sequence of supported characters
52
51
  ignore_case: if True, ignore case of letters
53
52
  ignore_accents: if True, ignore accents of letters
@@ -58,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
58
57
  logits: tf.Tensor,
59
58
  beam_width: int = 1,
60
59
  top_paths: int = 1,
61
- ) -> Union[List[Tuple[str, float]], List[Tuple[List[str], List[float]]]]:
60
+ ) -> list[tuple[str, float]] | list[tuple[list[str] | list[float]]]:
62
61
  """Performs decoding of raw output with CTC and decoding of CTC predictions
63
62
  with label_to_idx mapping dictionnary
64
63
 
65
64
  Args:
66
- ----
67
65
  logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
68
66
  beam_width: An int scalar >= 0 (beam search beam width).
69
67
  top_paths: An int scalar >= 0, <= beam_width (controls output size).
70
68
 
71
69
  Returns:
72
- -------
73
70
  A list of decoded words of length BATCH_SIZE
74
71
 
75
72
 
@@ -114,7 +111,6 @@ class CRNN(RecognitionModel, Model):
114
111
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
115
112
 
116
113
  Args:
117
- ----
118
114
  feature_extractor: the backbone serving as feature extractor
119
115
  vocab: vocabulary used for encoding
120
116
  rnn_units: number of units in the LSTM layers
@@ -124,17 +120,17 @@ class CRNN(RecognitionModel, Model):
124
120
  cfg: configuration dictionary
125
121
  """
126
122
 
127
- _children_names: List[str] = ["feat_extractor", "decoder", "postprocessor"]
123
+ _children_names: list[str] = ["feat_extractor", "decoder", "postprocessor"]
128
124
 
129
125
  def __init__(
130
126
  self,
131
- feature_extractor: tf.keras.Model,
127
+ feature_extractor: Model,
132
128
  vocab: str,
133
129
  rnn_units: int = 128,
134
130
  exportable: bool = False,
135
131
  beam_width: int = 1,
136
132
  top_paths: int = 1,
137
- cfg: Optional[Dict[str, Any]] = None,
133
+ cfg: dict[str, Any] | None = None,
138
134
  ) -> None:
139
135
  # Initialize kernels
140
136
  h, w, c = feature_extractor.output_shape[1:]
@@ -161,17 +157,15 @@ class CRNN(RecognitionModel, Model):
161
157
  def compute_loss(
162
158
  self,
163
159
  model_output: tf.Tensor,
164
- target: List[str],
160
+ target: list[str],
165
161
  ) -> tf.Tensor:
166
162
  """Compute CTC loss for the model.
167
163
 
168
164
  Args:
169
- ----
170
165
  model_output: predicted logits of the model
171
166
  target: lengths of each gt word inside the batch
172
167
 
173
168
  Returns:
174
- -------
175
169
  The loss of the model on the batch
176
170
  """
177
171
  gt, seq_len = self.build_target(target)
@@ -185,13 +179,13 @@ class CRNN(RecognitionModel, Model):
185
179
  def call(
186
180
  self,
187
181
  x: tf.Tensor,
188
- target: Optional[List[str]] = None,
182
+ target: list[str] | None = None,
189
183
  return_model_output: bool = False,
190
184
  return_preds: bool = False,
191
185
  beam_width: int = 1,
192
186
  top_paths: int = 1,
193
187
  **kwargs: Any,
194
- ) -> Dict[str, Any]:
188
+ ) -> dict[str, Any]:
195
189
  if kwargs.get("training", False) and target is None:
196
190
  raise ValueError("Need to provide labels during training")
197
191
 
@@ -203,7 +197,7 @@ class CRNN(RecognitionModel, Model):
203
197
  features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
204
198
  logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
205
199
 
206
- out: Dict[str, tf.Tensor] = {}
200
+ out: dict[str, tf.Tensor] = {}
207
201
  if self.exportable:
208
202
  out["logits"] = logits
209
203
  return out
@@ -226,7 +220,7 @@ def _crnn(
226
220
  pretrained: bool,
227
221
  backbone_fn,
228
222
  pretrained_backbone: bool = True,
229
- input_shape: Optional[Tuple[int, int, int]] = None,
223
+ input_shape: tuple[int, int, int] | None = None,
230
224
  **kwargs: Any,
231
225
  ) -> CRNN:
232
226
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -245,9 +239,11 @@ def _crnn(
245
239
 
246
240
  # Build the model
247
241
  model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
242
+ _build_model(model)
248
243
  # Load pretrained parameters
249
244
  if pretrained:
250
- load_pretrained_params(model, _cfg["url"])
245
+ # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
246
+ load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
251
247
 
252
248
  return model
253
249
 
@@ -263,12 +259,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
263
259
  >>> out = model(input_tensor)
264
260
 
265
261
  Args:
266
- ----
267
262
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
268
263
  **kwargs: keyword arguments of the CRNN architecture
269
264
 
270
265
  Returns:
271
- -------
272
266
  text recognition architecture
273
267
  """
274
268
  return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
@@ -285,12 +279,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
285
279
  >>> out = model(input_tensor)
286
280
 
287
281
  Args:
288
- ----
289
282
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
290
283
  **kwargs: keyword arguments of the CRNN architecture
291
284
 
292
285
  Returns:
293
- -------
294
286
  text recognition architecture
295
287
  """
296
288
  return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
@@ -307,12 +299,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
307
299
  >>> out = model(input_tensor)
308
300
 
309
301
  Args:
310
- ----
311
302
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
312
303
  **kwargs: keyword arguments of the CRNN architecture
313
304
 
314
305
  Returns:
315
- -------
316
306
  text recognition architecture
317
307
  """
318
308
  return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
4
6
  from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List, Tuple
7
6
 
8
7
  import numpy as np
9
8
 
@@ -17,17 +16,15 @@ class _MASTER:
17
16
 
18
17
  def build_target(
19
18
  self,
20
- gts: List[str],
21
- ) -> Tuple[np.ndarray, List[int]]:
19
+ gts: list[str],
20
+ ) -> tuple[np.ndarray, list[int]]:
22
21
  """Encode a list of gts sequences into a np array and gives the corresponding*
23
22
  sequence lengths.
24
23
 
25
24
  Args:
26
- ----
27
25
  gts: list of ground-truth labels
28
26
 
29
27
  Returns:
30
- -------
31
28
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
32
29
  """
33
30
  encoded = encode_sequences(
@@ -46,7 +43,6 @@ class _MASTERPostProcessor(RecognitionPostProcessor):
46
43
  """Abstract class to postprocess the raw output of the model
47
44
 
48
45
  Args:
49
- ----
50
46
  vocab: string containing the ordered sequence of supported characters
51
47
  """
52
48
 
@@ -1,10 +1,11 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
+ from collections.abc import Callable
6
7
  from copy import deepcopy
7
- from typing import Any, Callable, Dict, List, Optional, Tuple
8
+ from typing import Any
8
9
 
9
10
  import torch
10
11
  from torch import nn
@@ -21,7 +22,7 @@ from .base import _MASTER, _MASTERPostProcessor
21
22
  __all__ = ["MASTER", "master"]
22
23
 
23
24
 
24
- default_cfgs: Dict[str, Dict[str, Any]] = {
25
+ default_cfgs: dict[str, dict[str, Any]] = {
25
26
  "master": {
26
27
  "mean": (0.694, 0.695, 0.693),
27
28
  "std": (0.299, 0.296, 0.301),
@@ -37,7 +38,6 @@ class MASTER(_MASTER, nn.Module):
37
38
  Implementation based on the official Pytorch implementation: <https://github.com/wenwenyu/MASTER-pytorch>`_.
38
39
 
39
40
  Args:
40
- ----
41
41
  feature_extractor: the backbone serving as feature extractor
42
42
  vocab: vocabulary, (without EOS, SOS, PAD)
43
43
  d_model: d parameter for the transformer decoder
@@ -61,9 +61,9 @@ class MASTER(_MASTER, nn.Module):
61
61
  num_layers: int = 3,
62
62
  max_length: int = 50,
63
63
  dropout: float = 0.2,
64
- input_shape: Tuple[int, int, int] = (3, 32, 128), # different from the paper
64
+ input_shape: tuple[int, int, int] = (3, 32, 128), # different from the paper
65
65
  exportable: bool = False,
66
- cfg: Optional[Dict[str, Any]] = None,
66
+ cfg: dict[str, Any] | None = None,
67
67
  ) -> None:
68
68
  super().__init__()
69
69
 
@@ -102,12 +102,12 @@ class MASTER(_MASTER, nn.Module):
102
102
 
103
103
  def make_source_and_target_mask(
104
104
  self, source: torch.Tensor, target: torch.Tensor
105
- ) -> Tuple[torch.Tensor, torch.Tensor]:
105
+ ) -> tuple[torch.Tensor, torch.Tensor]:
106
106
  # borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
107
107
  # NOTE: nn.TransformerDecoder takes the inverse from this implementation
108
108
  # [True, True, True, ..., False, False, False] -> False is masked
109
109
  # (N, 1, 1, max_length)
110
- target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
110
+ target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
111
111
  target_length = target.size(1)
112
112
  # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
113
113
  # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -130,19 +130,17 @@ class MASTER(_MASTER, nn.Module):
130
130
  Sequences are masked after the EOS character.
131
131
 
132
132
  Args:
133
- ----
134
133
  gt: the encoded tensor with gt labels
135
134
  model_output: predicted logits of the model
136
135
  seq_len: lengths of each gt word inside the batch
137
136
 
138
137
  Returns:
139
- -------
140
138
  The loss of the model on the batch
141
139
  """
142
140
  # Input length : number of timesteps
143
141
  input_len = model_output.shape[1]
144
142
  # Add one for additional <eos> token (sos disappear in shift!)
145
- seq_len = seq_len + 1
143
+ seq_len = seq_len + 1 # type: ignore[assignment]
146
144
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
147
145
  # The "masked" first gt char is <sos>. Delete last logit of the model output.
148
146
  cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -156,21 +154,19 @@ class MASTER(_MASTER, nn.Module):
156
154
  def forward(
157
155
  self,
158
156
  x: torch.Tensor,
159
- target: Optional[List[str]] = None,
157
+ target: list[str] | None = None,
160
158
  return_model_output: bool = False,
161
159
  return_preds: bool = False,
162
- ) -> Dict[str, Any]:
160
+ ) -> dict[str, Any]:
163
161
  """Call function for training
164
162
 
165
163
  Args:
166
- ----
167
164
  x: images
168
165
  target: list of str labels
169
166
  return_model_output: if True, return logits
170
167
  return_preds: if True, decode logits
171
168
 
172
169
  Returns:
173
- -------
174
170
  A dictionnary containing eventually loss, logits and predictions.
175
171
  """
176
172
  # Encode
@@ -181,7 +177,7 @@ class MASTER(_MASTER, nn.Module):
181
177
  # add positional encoding to features
182
178
  encoded = self.positional_encoding(features)
183
179
 
184
- out: Dict[str, Any] = {}
180
+ out: dict[str, Any] = {}
185
181
 
186
182
  if self.training and target is None:
187
183
  raise ValueError("Need to provide labels during training")
@@ -213,7 +209,13 @@ class MASTER(_MASTER, nn.Module):
213
209
  out["out_map"] = logits
214
210
 
215
211
  if return_preds:
216
- out["preds"] = self.postprocessor(logits)
212
+ # Disable for torch.compile compatibility
213
+ @torch.compiler.disable # type: ignore[attr-defined]
214
+ def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
215
+ return self.postprocessor(logits)
216
+
217
+ # Post-process boxes
218
+ out["preds"] = _postprocess(logits)
217
219
 
218
220
  return out
219
221
 
@@ -221,12 +223,10 @@ class MASTER(_MASTER, nn.Module):
221
223
  """Decode function for prediction
222
224
 
223
225
  Args:
224
- ----
225
226
  encoded: input tensor
226
227
 
227
228
  Returns:
228
- -------
229
- A Tuple of torch.Tensor: predictions, logits
229
+ A tuple of torch.Tensor: predictions, logits
230
230
  """
231
231
  b = encoded.size(0)
232
232
 
@@ -254,7 +254,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
254
254
  def __call__(
255
255
  self,
256
256
  logits: torch.Tensor,
257
- ) -> List[Tuple[str, float]]:
257
+ ) -> list[tuple[str, float]]:
258
258
  # compute pred with argmax for attention models
259
259
  out_idxs = logits.argmax(-1)
260
260
  # N x L
@@ -277,7 +277,7 @@ def _master(
277
277
  backbone_fn: Callable[[bool], nn.Module],
278
278
  layer: str,
279
279
  pretrained_backbone: bool = True,
280
- ignore_keys: Optional[List[str]] = None,
280
+ ignore_keys: list[str] | None = None,
281
281
  **kwargs: Any,
282
282
  ) -> MASTER:
283
283
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -316,12 +316,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
316
316
  >>> out = model(input_tensor)
317
317
 
318
318
  Args:
319
- ----
320
319
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
321
320
  **kwargs: keywoard arguments passed to the MASTER architecture
322
321
 
323
322
  Returns:
324
- -------
325
323
  text recognition architecture
326
324
  """
327
325
  return _master(