python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
@@ -1,11 +1,12 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
+ from collections.abc import Callable
6
7
  from copy import deepcopy
7
8
  from itertools import groupby
8
- from typing import Any, Callable, Dict, List, Optional, Tuple
9
+ from typing import Any
9
10
 
10
11
  import torch
11
12
  from torch import nn
@@ -19,7 +20,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
19
20
 
20
21
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
21
22
 
22
- default_cfgs: Dict[str, Dict[str, Any]] = {
23
+ default_cfgs: dict[str, dict[str, Any]] = {
23
24
  "crnn_vgg16_bn": {
24
25
  "mean": (0.694, 0.695, 0.693),
25
26
  "std": (0.299, 0.296, 0.301),
@@ -48,7 +49,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
48
49
  """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
49
50
 
50
51
  Args:
51
- ----
52
52
  vocab: string containing the ordered sequence of supported characters
53
53
  """
54
54
 
@@ -57,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
57
57
  logits: torch.Tensor,
58
58
  vocab: str = VOCABS["french"],
59
59
  blank: int = 0,
60
- ) -> List[Tuple[str, float]]:
60
+ ) -> list[tuple[str, float]]:
61
61
  """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from
62
62
  <https://github.com/githubharald/CTCDecoder>`_.
63
63
 
64
64
  Args:
65
- ----
66
65
  logits: model output, shape: N x T x C
67
66
  vocab: vocabulary to use
68
67
  blank: index of blank label
69
68
 
70
69
  Returns:
71
- -------
72
70
  A list of tuples: (word, confidence)
73
71
  """
74
72
  # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
@@ -82,16 +80,14 @@ class CTCPostProcessor(RecognitionPostProcessor):
82
80
 
83
81
  return list(zip(words, probs.tolist()))
84
82
 
85
- def __call__(self, logits: torch.Tensor) -> List[Tuple[str, float]]:
83
+ def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
86
84
  """Performs decoding of raw output with CTC and decoding of CTC predictions
87
85
  with label_to_idx mapping dictionnary
88
86
 
89
87
  Args:
90
- ----
91
88
  logits: raw output of the model, shape (N, C + 1, seq_len)
92
89
 
93
90
  Returns:
94
- -------
95
91
  A tuple of 2 lists: a list of str (words) and a list of float (probs)
96
92
 
97
93
  """
@@ -104,7 +100,6 @@ class CRNN(RecognitionModel, nn.Module):
104
100
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
105
101
 
106
102
  Args:
107
- ----
108
103
  feature_extractor: the backbone serving as feature extractor
109
104
  vocab: vocabulary used for encoding
110
105
  rnn_units: number of units in the LSTM layers
@@ -112,16 +107,16 @@ class CRNN(RecognitionModel, nn.Module):
112
107
  cfg: configuration dictionary
113
108
  """
114
109
 
115
- _children_names: List[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]
110
+ _children_names: list[str] = ["feat_extractor", "decoder", "linear", "postprocessor"]
116
111
 
117
112
  def __init__(
118
113
  self,
119
114
  feature_extractor: nn.Module,
120
115
  vocab: str,
121
116
  rnn_units: int = 128,
122
- input_shape: Tuple[int, int, int] = (3, 32, 128),
117
+ input_shape: tuple[int, int, int] = (3, 32, 128),
123
118
  exportable: bool = False,
124
- cfg: Optional[Dict[str, Any]] = None,
119
+ cfg: dict[str, Any] | None = None,
125
120
  ) -> None:
126
121
  super().__init__()
127
122
  self.vocab = vocab
@@ -160,20 +155,27 @@ class CRNN(RecognitionModel, nn.Module):
160
155
  m.weight.data.fill_(1.0)
161
156
  m.bias.data.zero_()
162
157
 
158
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
159
+ """Load pretrained parameters onto the model
160
+
161
+ Args:
162
+ path_or_url: the path or URL to the model parameters (checkpoint)
163
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
164
+ """
165
+ load_pretrained_params(self, path_or_url, **kwargs)
166
+
163
167
  def compute_loss(
164
168
  self,
165
169
  model_output: torch.Tensor,
166
- target: List[str],
170
+ target: list[str],
167
171
  ) -> torch.Tensor:
168
172
  """Compute CTC loss for the model.
169
173
 
170
174
  Args:
171
- ----
172
175
  model_output: predicted logits of the model
173
176
  target: list of target strings
174
177
 
175
178
  Returns:
176
- -------
177
179
  The loss of the model on the batch
178
180
  """
179
181
  gt, seq_len = self.build_target(target)
@@ -196,10 +198,10 @@ class CRNN(RecognitionModel, nn.Module):
196
198
  def forward(
197
199
  self,
198
200
  x: torch.Tensor,
199
- target: Optional[List[str]] = None,
201
+ target: list[str] | None = None,
200
202
  return_model_output: bool = False,
201
203
  return_preds: bool = False,
202
- ) -> Dict[str, Any]:
204
+ ) -> dict[str, Any]:
203
205
  if self.training and target is None:
204
206
  raise ValueError("Need to provide labels during training")
205
207
 
@@ -211,7 +213,7 @@ class CRNN(RecognitionModel, nn.Module):
211
213
  logits, _ = self.decoder(features_seq)
212
214
  logits = self.linear(logits)
213
215
 
214
- out: Dict[str, Any] = {}
216
+ out: dict[str, Any] = {}
215
217
  if self.exportable:
216
218
  out["logits"] = logits
217
219
  return out
@@ -220,8 +222,13 @@ class CRNN(RecognitionModel, nn.Module):
220
222
  out["out_map"] = logits
221
223
 
222
224
  if target is None or return_preds:
225
+ # Disable for torch.compile compatibility
226
+ @torch.compiler.disable # type: ignore[attr-defined]
227
+ def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
228
+ return self.postprocessor(logits)
229
+
223
230
  # Post-process boxes
224
- out["preds"] = self.postprocessor(logits)
231
+ out["preds"] = _postprocess(logits)
225
232
 
226
233
  if target is not None:
227
234
  out["loss"] = self.compute_loss(logits, target)
@@ -234,7 +241,7 @@ def _crnn(
234
241
  pretrained: bool,
235
242
  backbone_fn: Callable[[Any], nn.Module],
236
243
  pretrained_backbone: bool = True,
237
- ignore_keys: Optional[List[str]] = None,
244
+ ignore_keys: list[str] | None = None,
238
245
  **kwargs: Any,
239
246
  ) -> CRNN:
240
247
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -256,7 +263,7 @@ def _crnn(
256
263
  # The number of classes is not the same as the number of classes in the pretrained model =>
257
264
  # remove the last layer weights
258
265
  _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
259
- load_pretrained_params(model, _cfg["url"], ignore_keys=_ignore_keys)
266
+ model.from_pretrained(_cfg["url"], ignore_keys=_ignore_keys)
260
267
 
261
268
  return model
262
269
 
@@ -272,12 +279,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
272
279
  >>> out = model(input_tensor)
273
280
 
274
281
  Args:
275
- ----
276
282
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
277
283
  **kwargs: keyword arguments of the CRNN architecture
278
284
 
279
285
  Returns:
280
- -------
281
286
  text recognition architecture
282
287
  """
283
288
  return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, ignore_keys=["linear.weight", "linear.bias"], **kwargs)
@@ -294,12 +299,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
294
299
  >>> out = model(input_tensor)
295
300
 
296
301
  Args:
297
- ----
298
302
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
299
303
  **kwargs: keyword arguments of the CRNN architecture
300
304
 
301
305
  Returns:
302
- -------
303
306
  text recognition architecture
304
307
  """
305
308
  return _crnn(
@@ -322,12 +325,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
322
325
  >>> out = model(input_tensor)
323
326
 
324
327
  Args:
325
- ----
326
328
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
327
329
  **kwargs: keyword arguments of the CRNN architecture
328
330
 
329
331
  Returns:
330
- -------
331
332
  text recognition architecture
332
333
  """
333
334
  return _crnn(
@@ -1,10 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
- from typing import Any, Dict, List, Optional, Tuple, Union
7
+ from typing import Any
8
8
 
9
9
  import tensorflow as tf
10
10
  from tensorflow.keras import layers
@@ -18,7 +18,7 @@ from ..core import RecognitionModel, RecognitionPostProcessor
18
18
 
19
19
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
20
20
 
21
- default_cfgs: Dict[str, Dict[str, Any]] = {
21
+ default_cfgs: dict[str, dict[str, Any]] = {
22
22
  "crnn_vgg16_bn": {
23
23
  "mean": (0.694, 0.695, 0.693),
24
24
  "std": (0.299, 0.296, 0.301),
@@ -47,7 +47,6 @@ class CTCPostProcessor(RecognitionPostProcessor):
47
47
  """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding
48
48
 
49
49
  Args:
50
- ----
51
50
  vocab: string containing the ordered sequence of supported characters
52
51
  ignore_case: if True, ignore case of letters
53
52
  ignore_accents: if True, ignore accents of letters
@@ -58,18 +57,16 @@ class CTCPostProcessor(RecognitionPostProcessor):
58
57
  logits: tf.Tensor,
59
58
  beam_width: int = 1,
60
59
  top_paths: int = 1,
61
- ) -> Union[List[Tuple[str, float]], List[Tuple[List[str], List[float]]]]:
60
+ ) -> list[tuple[str, float]] | list[tuple[list[str] | list[float]]]:
62
61
  """Performs decoding of raw output with CTC and decoding of CTC predictions
63
62
  with label_to_idx mapping dictionnary
64
63
 
65
64
  Args:
66
- ----
67
65
  logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1
68
66
  beam_width: An int scalar >= 0 (beam search beam width).
69
67
  top_paths: An int scalar >= 0, <= beam_width (controls output size).
70
68
 
71
69
  Returns:
72
- -------
73
70
  A list of decoded words of length BATCH_SIZE
74
71
 
75
72
 
@@ -114,7 +111,6 @@ class CRNN(RecognitionModel, Model):
114
111
  Sequence Recognition and Its Application to Scene Text Recognition" <https://arxiv.org/pdf/1507.05717.pdf>`_.
115
112
 
116
113
  Args:
117
- ----
118
114
  feature_extractor: the backbone serving as feature extractor
119
115
  vocab: vocabulary used for encoding
120
116
  rnn_units: number of units in the LSTM layers
@@ -124,7 +120,7 @@ class CRNN(RecognitionModel, Model):
124
120
  cfg: configuration dictionary
125
121
  """
126
122
 
127
- _children_names: List[str] = ["feat_extractor", "decoder", "postprocessor"]
123
+ _children_names: list[str] = ["feat_extractor", "decoder", "postprocessor"]
128
124
 
129
125
  def __init__(
130
126
  self,
@@ -134,7 +130,7 @@ class CRNN(RecognitionModel, Model):
134
130
  exportable: bool = False,
135
131
  beam_width: int = 1,
136
132
  top_paths: int = 1,
137
- cfg: Optional[Dict[str, Any]] = None,
133
+ cfg: dict[str, Any] | None = None,
138
134
  ) -> None:
139
135
  # Initialize kernels
140
136
  h, w, c = feature_extractor.output_shape[1:]
@@ -158,20 +154,27 @@ class CRNN(RecognitionModel, Model):
158
154
  self.beam_width = beam_width
159
155
  self.top_paths = top_paths
160
156
 
157
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
158
+ """Load pretrained parameters onto the model
159
+
160
+ Args:
161
+ path_or_url: the path or URL to the model parameters (checkpoint)
162
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
163
+ """
164
+ load_pretrained_params(self, path_or_url, **kwargs)
165
+
161
166
  def compute_loss(
162
167
  self,
163
168
  model_output: tf.Tensor,
164
- target: List[str],
169
+ target: list[str],
165
170
  ) -> tf.Tensor:
166
171
  """Compute CTC loss for the model.
167
172
 
168
173
  Args:
169
- ----
170
174
  model_output: predicted logits of the model
171
175
  target: lengths of each gt word inside the batch
172
176
 
173
177
  Returns:
174
- -------
175
178
  The loss of the model on the batch
176
179
  """
177
180
  gt, seq_len = self.build_target(target)
@@ -185,13 +188,13 @@ class CRNN(RecognitionModel, Model):
185
188
  def call(
186
189
  self,
187
190
  x: tf.Tensor,
188
- target: Optional[List[str]] = None,
191
+ target: list[str] | None = None,
189
192
  return_model_output: bool = False,
190
193
  return_preds: bool = False,
191
194
  beam_width: int = 1,
192
195
  top_paths: int = 1,
193
196
  **kwargs: Any,
194
- ) -> Dict[str, Any]:
197
+ ) -> dict[str, Any]:
195
198
  if kwargs.get("training", False) and target is None:
196
199
  raise ValueError("Need to provide labels during training")
197
200
 
@@ -203,7 +206,7 @@ class CRNN(RecognitionModel, Model):
203
206
  features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c))
204
207
  logits = _bf16_to_float32(self.decoder(features_seq, **kwargs))
205
208
 
206
- out: Dict[str, tf.Tensor] = {}
209
+ out: dict[str, tf.Tensor] = {}
207
210
  if self.exportable:
208
211
  out["logits"] = logits
209
212
  return out
@@ -226,7 +229,7 @@ def _crnn(
226
229
  pretrained: bool,
227
230
  backbone_fn,
228
231
  pretrained_backbone: bool = True,
229
- input_shape: Optional[Tuple[int, int, int]] = None,
232
+ input_shape: tuple[int, int, int] | None = None,
230
233
  **kwargs: Any,
231
234
  ) -> CRNN:
232
235
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -249,7 +252,7 @@ def _crnn(
249
252
  # Load pretrained parameters
250
253
  if pretrained:
251
254
  # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
252
- load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
255
+ model.from_pretrained(_cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
253
256
 
254
257
  return model
255
258
 
@@ -265,12 +268,10 @@ def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN:
265
268
  >>> out = model(input_tensor)
266
269
 
267
270
  Args:
268
- ----
269
271
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
270
272
  **kwargs: keyword arguments of the CRNN architecture
271
273
 
272
274
  Returns:
273
- -------
274
275
  text recognition architecture
275
276
  """
276
277
  return _crnn("crnn_vgg16_bn", pretrained, vgg16_bn_r, **kwargs)
@@ -287,12 +288,10 @@ def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN:
287
288
  >>> out = model(input_tensor)
288
289
 
289
290
  Args:
290
- ----
291
291
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
292
292
  **kwargs: keyword arguments of the CRNN architecture
293
293
 
294
294
  Returns:
295
- -------
296
295
  text recognition architecture
297
296
  """
298
297
  return _crnn("crnn_mobilenet_v3_small", pretrained, mobilenet_v3_small_r, **kwargs)
@@ -309,12 +308,10 @@ def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN:
309
308
  >>> out = model(input_tensor)
310
309
 
311
310
  Args:
312
- ----
313
311
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
314
312
  **kwargs: keyword arguments of the CRNN architecture
315
313
 
316
314
  Returns:
317
- -------
318
315
  text recognition architecture
319
316
  """
320
317
  return _crnn("crnn_mobilenet_v3_large", pretrained, mobilenet_v3_large_r, **kwargs)
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
4
6
  from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List, Tuple
7
6
 
8
7
  import numpy as np
9
8
 
@@ -17,17 +16,15 @@ class _MASTER:
17
16
 
18
17
  def build_target(
19
18
  self,
20
- gts: List[str],
21
- ) -> Tuple[np.ndarray, List[int]]:
19
+ gts: list[str],
20
+ ) -> tuple[np.ndarray, list[int]]:
22
21
  """Encode a list of gts sequences into a np array and gives the corresponding*
23
22
  sequence lengths.
24
23
 
25
24
  Args:
26
- ----
27
25
  gts: list of ground-truth labels
28
26
 
29
27
  Returns:
30
- -------
31
28
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
32
29
  """
33
30
  encoded = encode_sequences(
@@ -46,7 +43,6 @@ class _MASTERPostProcessor(RecognitionPostProcessor):
46
43
  """Abstract class to postprocess the raw output of the model
47
44
 
48
45
  Args:
49
- ----
50
46
  vocab: string containing the ordered sequence of supported characters
51
47
  """
52
48
 
@@ -1,10 +1,11 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
+ from collections.abc import Callable
6
7
  from copy import deepcopy
7
- from typing import Any, Callable, Dict, List, Optional, Tuple
8
+ from typing import Any
8
9
 
9
10
  import torch
10
11
  from torch import nn
@@ -21,7 +22,7 @@ from .base import _MASTER, _MASTERPostProcessor
21
22
  __all__ = ["MASTER", "master"]
22
23
 
23
24
 
24
- default_cfgs: Dict[str, Dict[str, Any]] = {
25
+ default_cfgs: dict[str, dict[str, Any]] = {
25
26
  "master": {
26
27
  "mean": (0.694, 0.695, 0.693),
27
28
  "std": (0.299, 0.296, 0.301),
@@ -37,7 +38,6 @@ class MASTER(_MASTER, nn.Module):
37
38
  Implementation based on the official Pytorch implementation: <https://github.com/wenwenyu/MASTER-pytorch>`_.
38
39
 
39
40
  Args:
40
- ----
41
41
  feature_extractor: the backbone serving as feature extractor
42
42
  vocab: vocabulary, (without EOS, SOS, PAD)
43
43
  d_model: d parameter for the transformer decoder
@@ -61,9 +61,9 @@ class MASTER(_MASTER, nn.Module):
61
61
  num_layers: int = 3,
62
62
  max_length: int = 50,
63
63
  dropout: float = 0.2,
64
- input_shape: Tuple[int, int, int] = (3, 32, 128), # different from the paper
64
+ input_shape: tuple[int, int, int] = (3, 32, 128), # different from the paper
65
65
  exportable: bool = False,
66
- cfg: Optional[Dict[str, Any]] = None,
66
+ cfg: dict[str, Any] | None = None,
67
67
  ) -> None:
68
68
  super().__init__()
69
69
 
@@ -102,12 +102,12 @@ class MASTER(_MASTER, nn.Module):
102
102
 
103
103
  def make_source_and_target_mask(
104
104
  self, source: torch.Tensor, target: torch.Tensor
105
- ) -> Tuple[torch.Tensor, torch.Tensor]:
105
+ ) -> tuple[torch.Tensor, torch.Tensor]:
106
106
  # borrowed and slightly modified from https://github.com/wenwenyu/MASTER-pytorch
107
107
  # NOTE: nn.TransformerDecoder takes the inverse from this implementation
108
108
  # [True, True, True, ..., False, False, False] -> False is masked
109
109
  # (N, 1, 1, max_length)
110
- target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
110
+ target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
111
111
  target_length = target.size(1)
112
112
  # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
113
113
  # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -130,19 +130,17 @@ class MASTER(_MASTER, nn.Module):
130
130
  Sequences are masked after the EOS character.
131
131
 
132
132
  Args:
133
- ----
134
133
  gt: the encoded tensor with gt labels
135
134
  model_output: predicted logits of the model
136
135
  seq_len: lengths of each gt word inside the batch
137
136
 
138
137
  Returns:
139
- -------
140
138
  The loss of the model on the batch
141
139
  """
142
140
  # Input length : number of timesteps
143
141
  input_len = model_output.shape[1]
144
142
  # Add one for additional <eos> token (sos disappear in shift!)
145
- seq_len = seq_len + 1
143
+ seq_len = seq_len + 1 # type: ignore[assignment]
146
144
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
147
145
  # The "masked" first gt char is <sos>. Delete last logit of the model output.
148
146
  cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -153,24 +151,31 @@ class MASTER(_MASTER, nn.Module):
153
151
  ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
154
152
  return ce_loss.mean()
155
153
 
154
+ def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
155
+ """Load pretrained parameters onto the model
156
+
157
+ Args:
158
+ path_or_url: the path or URL to the model parameters (checkpoint)
159
+ **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
160
+ """
161
+ load_pretrained_params(self, path_or_url, **kwargs)
162
+
156
163
  def forward(
157
164
  self,
158
165
  x: torch.Tensor,
159
- target: Optional[List[str]] = None,
166
+ target: list[str] | None = None,
160
167
  return_model_output: bool = False,
161
168
  return_preds: bool = False,
162
- ) -> Dict[str, Any]:
169
+ ) -> dict[str, Any]:
163
170
  """Call function for training
164
171
 
165
172
  Args:
166
- ----
167
173
  x: images
168
174
  target: list of str labels
169
175
  return_model_output: if True, return logits
170
176
  return_preds: if True, decode logits
171
177
 
172
178
  Returns:
173
- -------
174
179
  A dictionnary containing eventually loss, logits and predictions.
175
180
  """
176
181
  # Encode
@@ -181,7 +186,7 @@ class MASTER(_MASTER, nn.Module):
181
186
  # add positional encoding to features
182
187
  encoded = self.positional_encoding(features)
183
188
 
184
- out: Dict[str, Any] = {}
189
+ out: dict[str, Any] = {}
185
190
 
186
191
  if self.training and target is None:
187
192
  raise ValueError("Need to provide labels during training")
@@ -213,7 +218,13 @@ class MASTER(_MASTER, nn.Module):
213
218
  out["out_map"] = logits
214
219
 
215
220
  if return_preds:
216
- out["preds"] = self.postprocessor(logits)
221
+ # Disable for torch.compile compatibility
222
+ @torch.compiler.disable # type: ignore[attr-defined]
223
+ def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
224
+ return self.postprocessor(logits)
225
+
226
+ # Post-process boxes
227
+ out["preds"] = _postprocess(logits)
217
228
 
218
229
  return out
219
230
 
@@ -221,12 +232,10 @@ class MASTER(_MASTER, nn.Module):
221
232
  """Decode function for prediction
222
233
 
223
234
  Args:
224
- ----
225
235
  encoded: input tensor
226
236
 
227
237
  Returns:
228
- -------
229
- A Tuple of torch.Tensor: predictions, logits
238
+ A tuple of torch.Tensor: predictions, logits
230
239
  """
231
240
  b = encoded.size(0)
232
241
 
@@ -254,7 +263,7 @@ class MASTERPostProcessor(_MASTERPostProcessor):
254
263
  def __call__(
255
264
  self,
256
265
  logits: torch.Tensor,
257
- ) -> List[Tuple[str, float]]:
266
+ ) -> list[tuple[str, float]]:
258
267
  # compute pred with argmax for attention models
259
268
  out_idxs = logits.argmax(-1)
260
269
  # N x L
@@ -277,7 +286,7 @@ def _master(
277
286
  backbone_fn: Callable[[bool], nn.Module],
278
287
  layer: str,
279
288
  pretrained_backbone: bool = True,
280
- ignore_keys: Optional[List[str]] = None,
289
+ ignore_keys: list[str] | None = None,
281
290
  **kwargs: Any,
282
291
  ) -> MASTER:
283
292
  pretrained_backbone = pretrained_backbone and not pretrained
@@ -301,7 +310,7 @@ def _master(
301
310
  # The number of classes is not the same as the number of classes in the pretrained model =>
302
311
  # remove the last layer weights
303
312
  _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
304
- load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
313
+ model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
305
314
 
306
315
  return model
307
316
 
@@ -316,12 +325,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
316
325
  >>> out = model(input_tensor)
317
326
 
318
327
  Args:
319
- ----
320
328
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
321
329
  **kwargs: keywoard arguments passed to the MASTER architecture
322
330
 
323
331
  Returns:
324
- -------
325
332
  text recognition architecture
326
333
  """
327
334
  return _master(