python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162):
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/cord.py +8 -7
  5. doctr/datasets/datasets/__init__.py +4 -4
  6. doctr/datasets/datasets/base.py +16 -16
  7. doctr/datasets/datasets/pytorch.py +12 -12
  8. doctr/datasets/datasets/tensorflow.py +10 -10
  9. doctr/datasets/detection.py +6 -9
  10. doctr/datasets/doc_artefacts.py +3 -4
  11. doctr/datasets/funsd.py +7 -6
  12. doctr/datasets/generator/__init__.py +4 -4
  13. doctr/datasets/generator/base.py +16 -17
  14. doctr/datasets/generator/pytorch.py +1 -3
  15. doctr/datasets/generator/tensorflow.py +1 -3
  16. doctr/datasets/ic03.py +4 -5
  17. doctr/datasets/ic13.py +4 -5
  18. doctr/datasets/iiit5k.py +6 -5
  19. doctr/datasets/iiithws.py +4 -5
  20. doctr/datasets/imgur5k.py +6 -5
  21. doctr/datasets/loader.py +4 -7
  22. doctr/datasets/mjsynth.py +6 -5
  23. doctr/datasets/ocr.py +3 -4
  24. doctr/datasets/orientation.py +3 -4
  25. doctr/datasets/recognition.py +3 -4
  26. doctr/datasets/sroie.py +6 -5
  27. doctr/datasets/svhn.py +6 -5
  28. doctr/datasets/svt.py +4 -5
  29. doctr/datasets/synthtext.py +4 -5
  30. doctr/datasets/utils.py +34 -29
  31. doctr/datasets/vocabs.py +17 -7
  32. doctr/datasets/wildreceipt.py +14 -10
  33. doctr/file_utils.py +2 -7
  34. doctr/io/elements.py +59 -79
  35. doctr/io/html.py +1 -3
  36. doctr/io/image/__init__.py +3 -3
  37. doctr/io/image/base.py +2 -5
  38. doctr/io/image/pytorch.py +3 -12
  39. doctr/io/image/tensorflow.py +2 -11
  40. doctr/io/pdf.py +5 -7
  41. doctr/io/reader.py +5 -11
  42. doctr/models/_utils.py +14 -22
  43. doctr/models/builder.py +30 -48
  44. doctr/models/classification/magc_resnet/__init__.py +3 -3
  45. doctr/models/classification/magc_resnet/pytorch.py +10 -13
  46. doctr/models/classification/magc_resnet/tensorflow.py +8 -11
  47. doctr/models/classification/mobilenet/__init__.py +3 -3
  48. doctr/models/classification/mobilenet/pytorch.py +5 -17
  49. doctr/models/classification/mobilenet/tensorflow.py +8 -21
  50. doctr/models/classification/predictor/__init__.py +4 -4
  51. doctr/models/classification/predictor/pytorch.py +6 -8
  52. doctr/models/classification/predictor/tensorflow.py +6 -8
  53. doctr/models/classification/resnet/__init__.py +4 -4
  54. doctr/models/classification/resnet/pytorch.py +21 -31
  55. doctr/models/classification/resnet/tensorflow.py +20 -31
  56. doctr/models/classification/textnet/__init__.py +3 -3
  57. doctr/models/classification/textnet/pytorch.py +10 -17
  58. doctr/models/classification/textnet/tensorflow.py +8 -15
  59. doctr/models/classification/vgg/__init__.py +3 -3
  60. doctr/models/classification/vgg/pytorch.py +5 -7
  61. doctr/models/classification/vgg/tensorflow.py +9 -12
  62. doctr/models/classification/vit/__init__.py +3 -3
  63. doctr/models/classification/vit/pytorch.py +8 -14
  64. doctr/models/classification/vit/tensorflow.py +6 -12
  65. doctr/models/classification/zoo.py +19 -14
  66. doctr/models/core.py +3 -3
  67. doctr/models/detection/_utils/__init__.py +4 -4
  68. doctr/models/detection/_utils/base.py +4 -7
  69. doctr/models/detection/_utils/pytorch.py +1 -5
  70. doctr/models/detection/_utils/tensorflow.py +1 -5
  71. doctr/models/detection/core.py +2 -8
  72. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  73. doctr/models/detection/differentiable_binarization/base.py +7 -17
  74. doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
  75. doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
  76. doctr/models/detection/fast/__init__.py +4 -4
  77. doctr/models/detection/fast/base.py +6 -14
  78. doctr/models/detection/fast/pytorch.py +24 -31
  79. doctr/models/detection/fast/tensorflow.py +14 -26
  80. doctr/models/detection/linknet/__init__.py +4 -4
  81. doctr/models/detection/linknet/base.py +6 -15
  82. doctr/models/detection/linknet/pytorch.py +24 -27
  83. doctr/models/detection/linknet/tensorflow.py +14 -23
  84. doctr/models/detection/predictor/__init__.py +5 -5
  85. doctr/models/detection/predictor/pytorch.py +6 -7
  86. doctr/models/detection/predictor/tensorflow.py +5 -6
  87. doctr/models/detection/zoo.py +27 -7
  88. doctr/models/factory/hub.py +3 -7
  89. doctr/models/kie_predictor/__init__.py +5 -5
  90. doctr/models/kie_predictor/base.py +4 -5
  91. doctr/models/kie_predictor/pytorch.py +18 -19
  92. doctr/models/kie_predictor/tensorflow.py +13 -14
  93. doctr/models/modules/layers/__init__.py +3 -3
  94. doctr/models/modules/layers/pytorch.py +6 -9
  95. doctr/models/modules/layers/tensorflow.py +5 -7
  96. doctr/models/modules/transformer/__init__.py +3 -3
  97. doctr/models/modules/transformer/pytorch.py +12 -13
  98. doctr/models/modules/transformer/tensorflow.py +9 -10
  99. doctr/models/modules/vision_transformer/__init__.py +3 -3
  100. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  101. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  102. doctr/models/predictor/__init__.py +5 -5
  103. doctr/models/predictor/base.py +28 -29
  104. doctr/models/predictor/pytorch.py +12 -13
  105. doctr/models/predictor/tensorflow.py +8 -9
  106. doctr/models/preprocessor/__init__.py +4 -4
  107. doctr/models/preprocessor/pytorch.py +13 -17
  108. doctr/models/preprocessor/tensorflow.py +10 -14
  109. doctr/models/recognition/core.py +3 -7
  110. doctr/models/recognition/crnn/__init__.py +4 -4
  111. doctr/models/recognition/crnn/pytorch.py +20 -28
  112. doctr/models/recognition/crnn/tensorflow.py +11 -23
  113. doctr/models/recognition/master/__init__.py +3 -3
  114. doctr/models/recognition/master/base.py +3 -7
  115. doctr/models/recognition/master/pytorch.py +22 -24
  116. doctr/models/recognition/master/tensorflow.py +12 -22
  117. doctr/models/recognition/parseq/__init__.py +3 -3
  118. doctr/models/recognition/parseq/base.py +3 -7
  119. doctr/models/recognition/parseq/pytorch.py +26 -26
  120. doctr/models/recognition/parseq/tensorflow.py +16 -22
  121. doctr/models/recognition/predictor/__init__.py +5 -5
  122. doctr/models/recognition/predictor/_utils.py +7 -10
  123. doctr/models/recognition/predictor/pytorch.py +6 -6
  124. doctr/models/recognition/predictor/tensorflow.py +5 -6
  125. doctr/models/recognition/sar/__init__.py +4 -4
  126. doctr/models/recognition/sar/pytorch.py +20 -21
  127. doctr/models/recognition/sar/tensorflow.py +12 -21
  128. doctr/models/recognition/utils.py +5 -10
  129. doctr/models/recognition/vitstr/__init__.py +4 -4
  130. doctr/models/recognition/vitstr/base.py +3 -7
  131. doctr/models/recognition/vitstr/pytorch.py +18 -20
  132. doctr/models/recognition/vitstr/tensorflow.py +12 -20
  133. doctr/models/recognition/zoo.py +22 -11
  134. doctr/models/utils/__init__.py +4 -4
  135. doctr/models/utils/pytorch.py +14 -17
  136. doctr/models/utils/tensorflow.py +17 -16
  137. doctr/models/zoo.py +1 -5
  138. doctr/transforms/functional/__init__.py +3 -3
  139. doctr/transforms/functional/base.py +4 -11
  140. doctr/transforms/functional/pytorch.py +20 -28
  141. doctr/transforms/functional/tensorflow.py +10 -22
  142. doctr/transforms/modules/__init__.py +4 -4
  143. doctr/transforms/modules/base.py +48 -55
  144. doctr/transforms/modules/pytorch.py +58 -22
  145. doctr/transforms/modules/tensorflow.py +18 -32
  146. doctr/utils/common_types.py +8 -9
  147. doctr/utils/data.py +8 -12
  148. doctr/utils/fonts.py +2 -7
  149. doctr/utils/geometry.py +16 -47
  150. doctr/utils/metrics.py +17 -37
  151. doctr/utils/multithreading.py +4 -6
  152. doctr/utils/reconstitution.py +9 -13
  153. doctr/utils/repr.py +2 -3
  154. doctr/utils/visualization.py +16 -29
  155. doctr/version.py +1 -1
  156. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
  157. python_doctr-0.11.0.dist-info/RECORD +173 -0
  158. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
  159. python_doctr-0.10.0.dist-info/RECORD +0 -173
  160. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
  161. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
  162. {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
@@ -1,10 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
- from typing import Any, Dict, List, Optional, Tuple
7
+ from typing import Any
8
8
 
9
9
  import tensorflow as tf
10
10
  from tensorflow.keras import Model, layers
@@ -19,7 +19,7 @@ from .base import _MASTER, _MASTERPostProcessor
19
19
  __all__ = ["MASTER", "master"]
20
20
 
21
21
 
22
- default_cfgs: Dict[str, Dict[str, Any]] = {
22
+ default_cfgs: dict[str, dict[str, Any]] = {
23
23
  "master": {
24
24
  "mean": (0.694, 0.695, 0.693),
25
25
  "std": (0.299, 0.296, 0.301),
@@ -35,7 +35,6 @@ class MASTER(_MASTER, Model):
35
35
  Implementation based on the official TF implementation: <https://github.com/jiangxiluning/MASTER-TF>`_.
36
36
 
37
37
  Args:
38
- ----
39
38
  feature_extractor: the backbone serving as feature extractor
40
39
  vocab: vocabulary, (without EOS, SOS, PAD)
41
40
  d_model: d parameter for the transformer decoder
@@ -59,9 +58,9 @@ class MASTER(_MASTER, Model):
59
58
  num_layers: int = 3,
60
59
  max_length: int = 50,
61
60
  dropout: float = 0.2,
62
- input_shape: Tuple[int, int, int] = (32, 128, 3), # different from the paper
61
+ input_shape: tuple[int, int, int] = (32, 128, 3), # different from the paper
63
62
  exportable: bool = False,
64
- cfg: Optional[Dict[str, Any]] = None,
63
+ cfg: dict[str, Any] | None = None,
65
64
  ) -> None:
66
65
  super().__init__()
67
66
 
@@ -89,7 +88,7 @@ class MASTER(_MASTER, Model):
89
88
  self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
90
89
 
91
90
  @tf.function
92
- def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
91
+ def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
93
92
  # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
94
93
  # (N, 1, 1, max_length)
95
94
  target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
@@ -109,19 +108,17 @@ class MASTER(_MASTER, Model):
109
108
  def compute_loss(
110
109
  model_output: tf.Tensor,
111
110
  gt: tf.Tensor,
112
- seq_len: List[int],
111
+ seq_len: list[int],
113
112
  ) -> tf.Tensor:
114
113
  """Compute categorical cross-entropy loss for the model.
115
114
  Sequences are masked after the EOS character.
116
115
 
117
116
  Args:
118
- ----
119
117
  gt: the encoded tensor with gt labels
120
118
  model_output: predicted logits of the model
121
119
  seq_len: lengths of each gt word inside the batch
122
120
 
123
121
  Returns:
124
- -------
125
122
  The loss of the model on the batch
126
123
  """
127
124
  # Input length : number of timesteps
@@ -144,15 +141,14 @@ class MASTER(_MASTER, Model):
144
141
  def call(
145
142
  self,
146
143
  x: tf.Tensor,
147
- target: Optional[List[str]] = None,
144
+ target: list[str] | None = None,
148
145
  return_model_output: bool = False,
149
146
  return_preds: bool = False,
150
147
  **kwargs: Any,
151
- ) -> Dict[str, Any]:
148
+ ) -> dict[str, Any]:
152
149
  """Call function for training
153
150
 
154
151
  Args:
155
- ----
156
152
  x: images
157
153
  target: list of str labels
158
154
  return_model_output: if True, return logits
@@ -160,7 +156,6 @@ class MASTER(_MASTER, Model):
160
156
  **kwargs: keyword arguments passed to the decoder
161
157
 
162
158
  Returns:
163
- -------
164
159
  A dictionnary containing eventually loss, logits and predictions.
165
160
  """
166
161
  # Encode
@@ -171,7 +166,7 @@ class MASTER(_MASTER, Model):
171
166
  # add positional encoding to features
172
167
  encoded = self.positional_encoding(feature, **kwargs)
173
168
 
174
- out: Dict[str, tf.Tensor] = {}
169
+ out: dict[str, tf.Tensor] = {}
175
170
 
176
171
  if kwargs.get("training", False) and target is None:
177
172
  raise ValueError("Need to provide labels during training")
@@ -209,13 +204,11 @@ class MASTER(_MASTER, Model):
209
204
  """Decode function for prediction
210
205
 
211
206
  Args:
212
- ----
213
207
  encoded: encoded features
214
208
  **kwargs: keyword arguments passed to the decoder
215
209
 
216
210
  Returns:
217
- -------
218
- A Tuple of tf.Tensor: predictions, logits
211
+ A tuple of tf.Tensor: predictions, logits
219
212
  """
220
213
  b = encoded.shape[0]
221
214
 
@@ -247,14 +240,13 @@ class MASTERPostProcessor(_MASTERPostProcessor):
247
240
  """Post processor for MASTER architectures
248
241
 
249
242
  Args:
250
- ----
251
243
  vocab: string containing the ordered sequence of supported characters
252
244
  """
253
245
 
254
246
  def __call__(
255
247
  self,
256
248
  logits: tf.Tensor,
257
- ) -> List[Tuple[str, float]]:
249
+ ) -> list[tuple[str, float]]:
258
250
  # compute pred with argmax for attention models
259
251
  out_idxs = tf.math.argmax(logits, axis=2)
260
252
  # N x L
@@ -312,12 +304,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
312
304
  >>> out = model(input_tensor)
313
305
 
314
306
  Args:
315
- ----
316
307
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
317
308
  **kwargs: keywoard arguments passed to the MASTER architecture
318
309
 
319
310
  Returns:
320
- -------
321
311
  text recognition architecture
322
312
  """
323
313
  return _master("master", pretrained, magc_resnet31, **kwargs)
@@ -1,6 +1,6 @@
1
1
  from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
4
6
  from .tensorflow import *
5
- elif is_torch_available():
6
- from .pytorch import * # type: ignore[assignment]
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List, Tuple
7
6
 
8
7
  import numpy as np
9
8
 
@@ -17,17 +16,15 @@ class _PARSeq:
17
16
 
18
17
  def build_target(
19
18
  self,
20
- gts: List[str],
21
- ) -> Tuple[np.ndarray, List[int]]:
19
+ gts: list[str],
20
+ ) -> tuple[np.ndarray, list[int]]:
22
21
  """Encode a list of gts sequences into a np array and gives the corresponding*
23
22
  sequence lengths.
24
23
 
25
24
  Args:
26
- ----
27
25
  gts: list of ground-truth labels
28
26
 
29
27
  Returns:
30
- -------
31
28
  A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
32
29
  """
33
30
  encoded = encode_sequences(
@@ -46,7 +43,6 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):
46
43
  """Abstract class to postprocess the raw output of the model
47
44
 
48
45
  Args:
49
- ----
50
46
  vocab: string containing the ordered sequence of supported characters
51
47
  """
52
48
 
@@ -1,12 +1,13 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  import math
7
+ from collections.abc import Callable
7
8
  from copy import deepcopy
8
9
  from itertools import permutations
9
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+ from typing import Any
10
11
 
11
12
  import numpy as np
12
13
  import torch
@@ -23,7 +24,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
23
24
 
24
25
  __all__ = ["PARSeq", "parseq"]
25
26
 
26
- default_cfgs: Dict[str, Dict[str, Any]] = {
27
+ default_cfgs: dict[str, dict[str, Any]] = {
27
28
  "parseq": {
28
29
  "mean": (0.694, 0.695, 0.693),
29
30
  "std": (0.299, 0.296, 0.301),
@@ -38,7 +39,6 @@ class CharEmbedding(nn.Module):
38
39
  """Implements the character embedding module
39
40
 
40
41
  Args:
41
- ----
42
42
  vocab_size: size of the vocabulary
43
43
  d_model: dimension of the model
44
44
  """
@@ -56,7 +56,6 @@ class PARSeqDecoder(nn.Module):
56
56
  """Implements decoder module of the PARSeq model
57
57
 
58
58
  Args:
59
- ----
60
59
  d_model: dimension of the model
61
60
  num_heads: number of attention heads
62
61
  ffd: dimension of the feed forward layer
@@ -92,7 +91,7 @@ class PARSeqDecoder(nn.Module):
92
91
  target,
93
92
  content,
94
93
  memory,
95
- target_mask: Optional[torch.Tensor] = None,
94
+ target_mask: torch.Tensor | None = None,
96
95
  ):
97
96
  query_norm = self.query_norm(target)
98
97
  content_norm = self.content_norm(content)
@@ -112,7 +111,6 @@ class PARSeq(_PARSeq, nn.Module):
112
111
  Slightly modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
113
112
 
114
113
  Args:
115
- ----
116
114
  feature_extractor: the backbone serving as feature extractor
117
115
  vocab: vocabulary used for encoding
118
116
  embedding_units: number of embedding units
@@ -136,9 +134,9 @@ class PARSeq(_PARSeq, nn.Module):
136
134
  dec_num_heads: int = 12,
137
135
  dec_ff_dim: int = 384, # we use it from the original implementation instead of 2048
138
136
  dec_ffd_ratio: int = 4,
139
- input_shape: Tuple[int, int, int] = (3, 32, 128),
137
+ input_shape: tuple[int, int, int] = (3, 32, 128),
140
138
  exportable: bool = False,
141
- cfg: Optional[Dict[str, Any]] = None,
139
+ cfg: dict[str, Any] | None = None,
142
140
  ) -> None:
143
141
  super().__init__()
144
142
  self.vocab = vocab
@@ -212,12 +210,12 @@ class PARSeq(_PARSeq, nn.Module):
212
210
 
213
211
  sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
214
212
  eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
215
- combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
213
+ combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int() # type: ignore[list-item]
216
214
  if len(combined) > 1:
217
215
  combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
218
216
  return combined
219
217
 
220
- def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
218
+ def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
221
219
  # Generate source and target mask for the decoder attention.
222
220
  sz = permutation.shape[0]
223
221
  mask = torch.ones((sz, sz), device=permutation.device)
@@ -236,8 +234,8 @@ class PARSeq(_PARSeq, nn.Module):
236
234
  self,
237
235
  target: torch.Tensor,
238
236
  memory: torch.Tensor,
239
- target_mask: Optional[torch.Tensor] = None,
240
- target_query: Optional[torch.Tensor] = None,
237
+ target_mask: torch.Tensor | None = None,
238
+ target_query: torch.Tensor | None = None,
241
239
  ) -> torch.Tensor:
242
240
  """Add positional information to the target sequence and pass it through the decoder."""
243
241
  batch_size, sequence_length = target.shape
@@ -250,7 +248,7 @@ class PARSeq(_PARSeq, nn.Module):
250
248
  target_query = self.dropout(target_query)
251
249
  return self.decoder(target_query, content, memory, target_mask)
252
250
 
253
- def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
251
+ def decode_autoregressive(self, features: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
254
252
  """Generate predictions for the given features."""
255
253
  max_length = max_len if max_len is not None else self.max_length
256
254
  max_length = min(max_length, self.max_length) + 1
@@ -283,7 +281,7 @@ class PARSeq(_PARSeq, nn.Module):
283
281
 
284
282
  # Stop decoding if all sequences have reached the EOS token
285
283
  # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
286
- if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
284
+ if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all(): # type: ignore[attr-defined]
287
285
  break
288
286
 
289
287
  logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)
@@ -298,7 +296,7 @@ class PARSeq(_PARSeq, nn.Module):
298
296
 
299
297
  # Create padding mask for refined target input maskes all behind EOS token as False
300
298
  # (N, 1, 1, max_length)
301
- target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
299
+ target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
302
300
  mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
303
301
  logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
304
302
 
@@ -307,10 +305,10 @@ class PARSeq(_PARSeq, nn.Module):
307
305
  def forward(
308
306
  self,
309
307
  x: torch.Tensor,
310
- target: Optional[List[str]] = None,
308
+ target: list[str] | None = None,
311
309
  return_model_output: bool = False,
312
310
  return_preds: bool = False,
313
- ) -> Dict[str, Any]:
311
+ ) -> dict[str, Any]:
314
312
  features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model)
315
313
  # remove cls token
316
314
  features = features[:, 1:, :]
@@ -337,7 +335,7 @@ class PARSeq(_PARSeq, nn.Module):
337
335
  ).unsqueeze(1).unsqueeze(1) # (N, 1, 1, seq_len)
338
336
 
339
337
  loss = torch.tensor(0.0, device=features.device)
340
- loss_numel: Union[int, float] = 0
338
+ loss_numel: int | float = 0
341
339
  n = (gt_out != self.vocab_size + 2).sum().item()
342
340
  for i, perm in enumerate(tgt_perms):
343
341
  _, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
@@ -365,7 +363,7 @@ class PARSeq(_PARSeq, nn.Module):
365
363
 
366
364
  logits = _bf16_to_float32(logits)
367
365
 
368
- out: Dict[str, Any] = {}
366
+ out: dict[str, Any] = {}
369
367
  if self.exportable:
370
368
  out["logits"] = logits
371
369
  return out
@@ -374,8 +372,13 @@ class PARSeq(_PARSeq, nn.Module):
374
372
  out["out_map"] = logits
375
373
 
376
374
  if target is None or return_preds:
375
+ # Disable for torch.compile compatibility
376
+ @torch.compiler.disable # type: ignore[attr-defined]
377
+ def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
378
+ return self.postprocessor(logits)
379
+
377
380
  # Post-process boxes
378
- out["preds"] = self.postprocessor(logits)
381
+ out["preds"] = _postprocess(logits)
379
382
 
380
383
  if target is not None:
381
384
  out["loss"] = loss
@@ -387,14 +390,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
387
390
  """Post processor for PARSeq architecture
388
391
 
389
392
  Args:
390
- ----
391
393
  vocab: string containing the ordered sequence of supported characters
392
394
  """
393
395
 
394
396
  def __call__(
395
397
  self,
396
398
  logits: torch.Tensor,
397
- ) -> List[Tuple[str, float]]:
399
+ ) -> list[tuple[str, float]]:
398
400
  # compute pred with argmax for attention models
399
401
  out_idxs = logits.argmax(-1)
400
402
  preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -417,7 +419,7 @@ def _parseq(
417
419
  pretrained: bool,
418
420
  backbone_fn: Callable[[bool], nn.Module],
419
421
  layer: str,
420
- ignore_keys: Optional[List[str]] = None,
422
+ ignore_keys: list[str] | None = None,
421
423
  **kwargs: Any,
422
424
  ) -> PARSeq:
423
425
  # Patch the config
@@ -462,12 +464,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
462
464
  >>> out = model(input_tensor)
463
465
 
464
466
  Args:
465
- ----
466
467
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
467
468
  **kwargs: keyword arguments of the PARSeq architecture
468
469
 
469
470
  Returns:
470
- -------
471
471
  text recognition architecture
472
472
  """
473
473
  return _parseq(
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
6
6
  import math
7
7
  from copy import deepcopy
8
8
  from itertools import permutations
9
- from typing import Any, Dict, List, Optional, Tuple
9
+ from typing import Any
10
10
 
11
11
  import numpy as np
12
12
  import tensorflow as tf
@@ -21,7 +21,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
21
21
 
22
22
  __all__ = ["PARSeq", "parseq"]
23
23
 
24
- default_cfgs: Dict[str, Dict[str, Any]] = {
24
+ default_cfgs: dict[str, dict[str, Any]] = {
25
25
  "parseq": {
26
26
  "mean": (0.694, 0.695, 0.693),
27
27
  "std": (0.299, 0.296, 0.301),
@@ -36,7 +36,7 @@ class CharEmbedding(layers.Layer):
36
36
  """Implements the character embedding module
37
37
 
38
38
  Args:
39
- ----
39
+ -
40
40
  vocab_size: size of the vocabulary
41
41
  d_model: dimension of the model
42
42
  """
@@ -54,7 +54,6 @@ class PARSeqDecoder(layers.Layer):
54
54
  """Implements decoder module of the PARSeq model
55
55
 
56
56
  Args:
57
- ----
58
57
  d_model: dimension of the model
59
58
  num_heads: number of attention heads
60
59
  ffd: dimension of the feed forward layer
@@ -115,7 +114,6 @@ class PARSeq(_PARSeq, Model):
115
114
  Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
116
115
 
117
116
  Args:
118
- ----
119
117
  feature_extractor: the backbone serving as feature extractor
120
118
  vocab: vocabulary used for encoding
121
119
  embedding_units: number of embedding units
@@ -129,7 +127,7 @@ class PARSeq(_PARSeq, Model):
129
127
  cfg: dictionary containing information about the model
130
128
  """
131
129
 
132
- _children_names: List[str] = ["feat_extractor", "postprocessor"]
130
+ _children_names: list[str] = ["feat_extractor", "postprocessor"]
133
131
 
134
132
  def __init__(
135
133
  self,
@@ -141,9 +139,9 @@ class PARSeq(_PARSeq, Model):
141
139
  dec_num_heads: int = 12,
142
140
  dec_ff_dim: int = 384, # we use it from the original implementation instead of 2048
143
141
  dec_ffd_ratio: int = 4,
144
- input_shape: Tuple[int, int, int] = (32, 128, 3),
142
+ input_shape: tuple[int, int, int] = (32, 128, 3),
145
143
  exportable: bool = False,
146
- cfg: Optional[Dict[str, Any]] = None,
144
+ cfg: dict[str, Any] | None = None,
147
145
  ) -> None:
148
146
  super().__init__()
149
147
  self.vocab = vocab
@@ -213,7 +211,7 @@ class PARSeq(_PARSeq, Model):
213
211
  )
214
212
  return combined
215
213
 
216
- def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
214
+ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
217
215
  # Generate source and target mask for the decoder attention.
218
216
  sz = permutation.shape[0]
219
217
  mask = tf.ones((sz, sz), dtype=tf.float32)
@@ -236,8 +234,8 @@ class PARSeq(_PARSeq, Model):
236
234
  self,
237
235
  target: tf.Tensor,
238
236
  memory: tf.Tensor,
239
- target_mask: Optional[tf.Tensor] = None,
240
- target_query: Optional[tf.Tensor] = None,
237
+ target_mask: tf.Tensor | None = None,
238
+ target_query: tf.Tensor | None = None,
241
239
  **kwargs: Any,
242
240
  ) -> tf.Tensor:
243
241
  batch_size, sequence_length = target.shape
@@ -250,8 +248,7 @@ class PARSeq(_PARSeq, Model):
250
248
  target_query = self.dropout(target_query, **kwargs)
251
249
  return self.decoder(target_query, content, memory, target_mask, **kwargs)
252
250
 
253
- @tf.function
254
- def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
251
+ def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
255
252
  """Generate predictions for the given features."""
256
253
  max_length = max_len if max_len is not None else self.max_length
257
254
  max_length = min(max_length, self.max_length) + 1
@@ -318,11 +315,11 @@ class PARSeq(_PARSeq, Model):
318
315
  def call(
319
316
  self,
320
317
  x: tf.Tensor,
321
- target: Optional[List[str]] = None,
318
+ target: list[str] | None = None,
322
319
  return_model_output: bool = False,
323
320
  return_preds: bool = False,
324
321
  **kwargs: Any,
325
- ) -> Dict[str, Any]:
322
+ ) -> dict[str, Any]:
326
323
  features = self.feat_extractor(x, **kwargs) # (batch_size, patches_seqlen, d_model)
327
324
  # remove cls token
328
325
  features = features[:, 1:, :]
@@ -393,7 +390,7 @@ class PARSeq(_PARSeq, Model):
393
390
 
394
391
  logits = _bf16_to_float32(logits)
395
392
 
396
- out: Dict[str, tf.Tensor] = {}
393
+ out: dict[str, tf.Tensor] = {}
397
394
  if self.exportable:
398
395
  out["logits"] = logits
399
396
  return out
@@ -415,14 +412,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
415
412
  """Post processor for PARSeq architecture
416
413
 
417
414
  Args:
418
- ----
419
415
  vocab: string containing the ordered sequence of supported characters
420
416
  """
421
417
 
422
418
  def __call__(
423
419
  self,
424
420
  logits: tf.Tensor,
425
- ) -> List[Tuple[str, float]]:
421
+ ) -> list[tuple[str, float]]:
426
422
  # compute pred with argmax for attention models
427
423
  out_idxs = tf.math.argmax(logits, axis=2)
428
424
  preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -448,7 +444,7 @@ def _parseq(
448
444
  arch: str,
449
445
  pretrained: bool,
450
446
  backbone_fn,
451
- input_shape: Optional[Tuple[int, int, int]] = None,
447
+ input_shape: tuple[int, int, int] | None = None,
452
448
  **kwargs: Any,
453
449
  ) -> PARSeq:
454
450
  # Patch the config
@@ -496,12 +492,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
496
492
  >>> out = model(input_tensor)
497
493
 
498
494
  Args:
499
- ----
500
495
  pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
501
496
  **kwargs: keyword arguments of the PARSeq architecture
502
497
 
503
498
  Returns:
504
- -------
505
499
  text recognition architecture
506
500
  """
507
501
  return _parseq(
@@ -1,6 +1,6 @@
1
- from doctr.file_utils import is_tf_available
1
+ from doctr.file_utils import is_tf_available, is_torch_available
2
2
 
3
- if is_tf_available():
4
- from .tensorflow import *
5
- else:
6
- from .pytorch import * # type: ignore[assignment]
3
+ if is_torch_available():
4
+ from .pytorch import *
5
+ elif is_tf_available():
6
+ from .tensorflow import * # type: ignore[assignment]
@@ -1,9 +1,8 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List, Tuple, Union
7
6
 
8
7
  import numpy as np
9
8
 
@@ -13,16 +12,15 @@ __all__ = ["split_crops", "remap_preds"]
13
12
 
14
13
 
15
14
  def split_crops(
16
- crops: List[np.ndarray],
15
+ crops: list[np.ndarray],
17
16
  max_ratio: float,
18
17
  target_ratio: int,
19
18
  dilation: float,
20
19
  channels_last: bool = True,
21
- ) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
20
+ ) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
22
21
  """Chunk crops horizontally to match a given aspect ratio
23
22
 
24
23
  Args:
25
- ----
26
24
  crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
27
25
  max_ratio: the maximum aspect ratio that won't trigger the chunk
28
26
  target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
@@ -30,12 +28,11 @@ def split_crops(
30
28
  channels_last: whether the numpy array has dimensions in channels last order
31
29
 
32
30
  Returns:
33
- -------
34
31
  a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
35
32
  """
36
33
  _remap_required = False
37
- crop_map: List[Union[int, Tuple[int, int]]] = []
38
- new_crops: List[np.ndarray] = []
34
+ crop_map: list[int | tuple[int, int]] = []
35
+ new_crops: list[np.ndarray] = []
39
36
  for crop in crops:
40
37
  h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
41
38
  aspect_ratio = w / h
@@ -71,8 +68,8 @@ def split_crops(
71
68
 
72
69
 
73
70
  def remap_preds(
74
- preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
75
- ) -> List[Tuple[str, float]]:
71
+ preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
72
+ ) -> list[tuple[str, float]]:
76
73
  remapped_out = []
77
74
  for _idx in crop_map:
78
75
  # Crop hasn't been split
@@ -1,9 +1,10 @@
1
- # Copyright (C) 2021-2024, Mindee.
1
+ # Copyright (C) 2021-2025, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Sequence, Tuple, Union
6
+ from collections.abc import Sequence
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
  import torch
@@ -21,7 +22,6 @@ class RecognitionPredictor(nn.Module):
21
22
  """Implements an object able to identify character sequences in images
22
23
 
23
24
  Args:
24
- ----
25
25
  pre_processor: transform inputs for easier batched model inference
26
26
  model: core detection architecture
27
27
  split_wide_crops: wether to use crop splitting for high aspect ratio crops
@@ -44,9 +44,9 @@ class RecognitionPredictor(nn.Module):
44
44
  @torch.inference_mode()
45
45
  def forward(
46
46
  self,
47
- crops: Sequence[Union[np.ndarray, torch.Tensor]],
47
+ crops: Sequence[np.ndarray | torch.Tensor],
48
48
  **kwargs: Any,
49
- ) -> List[Tuple[str, float]]:
49
+ ) -> list[tuple[str, float]]:
50
50
  if len(crops) == 0:
51
51
  return []
52
52
  # Dimension check
@@ -67,7 +67,7 @@ class RecognitionPredictor(nn.Module):
67
67
  crops = new_crops
68
68
 
69
69
  # Resize & batch them
70
- processed_batches = self.pre_processor(crops)
70
+ processed_batches = self.pre_processor(crops) # type: ignore[arg-type]
71
71
 
72
72
  # Forward it
73
73
  _params = next(self.model.parameters())