python-doctr 0.10.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (172)
  1. doctr/contrib/__init__.py +1 -0
  2. doctr/contrib/artefacts.py +7 -9
  3. doctr/contrib/base.py +8 -17
  4. doctr/datasets/__init__.py +1 -0
  5. doctr/datasets/coco_text.py +139 -0
  6. doctr/datasets/cord.py +10 -8
  7. doctr/datasets/datasets/__init__.py +4 -4
  8. doctr/datasets/datasets/base.py +16 -16
  9. doctr/datasets/datasets/pytorch.py +12 -12
  10. doctr/datasets/datasets/tensorflow.py +10 -10
  11. doctr/datasets/detection.py +6 -9
  12. doctr/datasets/doc_artefacts.py +3 -4
  13. doctr/datasets/funsd.py +9 -8
  14. doctr/datasets/generator/__init__.py +4 -4
  15. doctr/datasets/generator/base.py +16 -17
  16. doctr/datasets/generator/pytorch.py +1 -3
  17. doctr/datasets/generator/tensorflow.py +1 -3
  18. doctr/datasets/ic03.py +5 -6
  19. doctr/datasets/ic13.py +6 -6
  20. doctr/datasets/iiit5k.py +10 -6
  21. doctr/datasets/iiithws.py +4 -5
  22. doctr/datasets/imgur5k.py +15 -7
  23. doctr/datasets/loader.py +4 -7
  24. doctr/datasets/mjsynth.py +6 -5
  25. doctr/datasets/ocr.py +3 -4
  26. doctr/datasets/orientation.py +3 -4
  27. doctr/datasets/recognition.py +4 -5
  28. doctr/datasets/sroie.py +6 -5
  29. doctr/datasets/svhn.py +7 -6
  30. doctr/datasets/svt.py +6 -7
  31. doctr/datasets/synthtext.py +19 -7
  32. doctr/datasets/utils.py +41 -35
  33. doctr/datasets/vocabs.py +1107 -49
  34. doctr/datasets/wildreceipt.py +14 -10
  35. doctr/file_utils.py +11 -7
  36. doctr/io/elements.py +96 -82
  37. doctr/io/html.py +1 -3
  38. doctr/io/image/__init__.py +3 -3
  39. doctr/io/image/base.py +2 -5
  40. doctr/io/image/pytorch.py +3 -12
  41. doctr/io/image/tensorflow.py +2 -11
  42. doctr/io/pdf.py +5 -7
  43. doctr/io/reader.py +5 -11
  44. doctr/models/_utils.py +15 -23
  45. doctr/models/builder.py +30 -48
  46. doctr/models/classification/__init__.py +1 -0
  47. doctr/models/classification/magc_resnet/__init__.py +3 -3
  48. doctr/models/classification/magc_resnet/pytorch.py +11 -15
  49. doctr/models/classification/magc_resnet/tensorflow.py +11 -14
  50. doctr/models/classification/mobilenet/__init__.py +3 -3
  51. doctr/models/classification/mobilenet/pytorch.py +20 -18
  52. doctr/models/classification/mobilenet/tensorflow.py +19 -23
  53. doctr/models/classification/predictor/__init__.py +4 -4
  54. doctr/models/classification/predictor/pytorch.py +7 -9
  55. doctr/models/classification/predictor/tensorflow.py +6 -8
  56. doctr/models/classification/resnet/__init__.py +4 -4
  57. doctr/models/classification/resnet/pytorch.py +47 -34
  58. doctr/models/classification/resnet/tensorflow.py +45 -35
  59. doctr/models/classification/textnet/__init__.py +3 -3
  60. doctr/models/classification/textnet/pytorch.py +20 -18
  61. doctr/models/classification/textnet/tensorflow.py +19 -17
  62. doctr/models/classification/vgg/__init__.py +3 -3
  63. doctr/models/classification/vgg/pytorch.py +21 -8
  64. doctr/models/classification/vgg/tensorflow.py +20 -14
  65. doctr/models/classification/vip/__init__.py +4 -0
  66. doctr/models/classification/vip/layers/__init__.py +4 -0
  67. doctr/models/classification/vip/layers/pytorch.py +615 -0
  68. doctr/models/classification/vip/pytorch.py +505 -0
  69. doctr/models/classification/vit/__init__.py +3 -3
  70. doctr/models/classification/vit/pytorch.py +18 -15
  71. doctr/models/classification/vit/tensorflow.py +15 -12
  72. doctr/models/classification/zoo.py +23 -14
  73. doctr/models/core.py +3 -3
  74. doctr/models/detection/_utils/__init__.py +4 -4
  75. doctr/models/detection/_utils/base.py +4 -7
  76. doctr/models/detection/_utils/pytorch.py +1 -5
  77. doctr/models/detection/_utils/tensorflow.py +1 -5
  78. doctr/models/detection/core.py +2 -8
  79. doctr/models/detection/differentiable_binarization/__init__.py +4 -4
  80. doctr/models/detection/differentiable_binarization/base.py +10 -21
  81. doctr/models/detection/differentiable_binarization/pytorch.py +37 -31
  82. doctr/models/detection/differentiable_binarization/tensorflow.py +26 -29
  83. doctr/models/detection/fast/__init__.py +4 -4
  84. doctr/models/detection/fast/base.py +8 -17
  85. doctr/models/detection/fast/pytorch.py +37 -35
  86. doctr/models/detection/fast/tensorflow.py +24 -28
  87. doctr/models/detection/linknet/__init__.py +4 -4
  88. doctr/models/detection/linknet/base.py +8 -18
  89. doctr/models/detection/linknet/pytorch.py +34 -28
  90. doctr/models/detection/linknet/tensorflow.py +24 -25
  91. doctr/models/detection/predictor/__init__.py +5 -5
  92. doctr/models/detection/predictor/pytorch.py +6 -7
  93. doctr/models/detection/predictor/tensorflow.py +5 -6
  94. doctr/models/detection/zoo.py +27 -7
  95. doctr/models/factory/hub.py +6 -10
  96. doctr/models/kie_predictor/__init__.py +5 -5
  97. doctr/models/kie_predictor/base.py +4 -5
  98. doctr/models/kie_predictor/pytorch.py +19 -20
  99. doctr/models/kie_predictor/tensorflow.py +14 -15
  100. doctr/models/modules/layers/__init__.py +3 -3
  101. doctr/models/modules/layers/pytorch.py +55 -10
  102. doctr/models/modules/layers/tensorflow.py +5 -7
  103. doctr/models/modules/transformer/__init__.py +3 -3
  104. doctr/models/modules/transformer/pytorch.py +12 -13
  105. doctr/models/modules/transformer/tensorflow.py +9 -10
  106. doctr/models/modules/vision_transformer/__init__.py +3 -3
  107. doctr/models/modules/vision_transformer/pytorch.py +2 -3
  108. doctr/models/modules/vision_transformer/tensorflow.py +3 -3
  109. doctr/models/predictor/__init__.py +5 -5
  110. doctr/models/predictor/base.py +28 -29
  111. doctr/models/predictor/pytorch.py +13 -14
  112. doctr/models/predictor/tensorflow.py +9 -10
  113. doctr/models/preprocessor/__init__.py +4 -4
  114. doctr/models/preprocessor/pytorch.py +13 -17
  115. doctr/models/preprocessor/tensorflow.py +10 -14
  116. doctr/models/recognition/__init__.py +1 -0
  117. doctr/models/recognition/core.py +3 -7
  118. doctr/models/recognition/crnn/__init__.py +4 -4
  119. doctr/models/recognition/crnn/pytorch.py +30 -29
  120. doctr/models/recognition/crnn/tensorflow.py +21 -24
  121. doctr/models/recognition/master/__init__.py +3 -3
  122. doctr/models/recognition/master/base.py +3 -7
  123. doctr/models/recognition/master/pytorch.py +32 -25
  124. doctr/models/recognition/master/tensorflow.py +22 -25
  125. doctr/models/recognition/parseq/__init__.py +3 -3
  126. doctr/models/recognition/parseq/base.py +3 -7
  127. doctr/models/recognition/parseq/pytorch.py +47 -29
  128. doctr/models/recognition/parseq/tensorflow.py +29 -27
  129. doctr/models/recognition/predictor/__init__.py +5 -5
  130. doctr/models/recognition/predictor/_utils.py +111 -52
  131. doctr/models/recognition/predictor/pytorch.py +9 -9
  132. doctr/models/recognition/predictor/tensorflow.py +8 -9
  133. doctr/models/recognition/sar/__init__.py +4 -4
  134. doctr/models/recognition/sar/pytorch.py +30 -22
  135. doctr/models/recognition/sar/tensorflow.py +22 -24
  136. doctr/models/recognition/utils.py +57 -53
  137. doctr/models/recognition/viptr/__init__.py +4 -0
  138. doctr/models/recognition/viptr/pytorch.py +277 -0
  139. doctr/models/recognition/vitstr/__init__.py +4 -4
  140. doctr/models/recognition/vitstr/base.py +3 -7
  141. doctr/models/recognition/vitstr/pytorch.py +28 -21
  142. doctr/models/recognition/vitstr/tensorflow.py +22 -23
  143. doctr/models/recognition/zoo.py +27 -11
  144. doctr/models/utils/__init__.py +4 -4
  145. doctr/models/utils/pytorch.py +41 -34
  146. doctr/models/utils/tensorflow.py +31 -23
  147. doctr/models/zoo.py +1 -5
  148. doctr/transforms/functional/__init__.py +3 -3
  149. doctr/transforms/functional/base.py +4 -11
  150. doctr/transforms/functional/pytorch.py +20 -28
  151. doctr/transforms/functional/tensorflow.py +10 -22
  152. doctr/transforms/modules/__init__.py +4 -4
  153. doctr/transforms/modules/base.py +48 -55
  154. doctr/transforms/modules/pytorch.py +58 -22
  155. doctr/transforms/modules/tensorflow.py +18 -32
  156. doctr/utils/common_types.py +8 -9
  157. doctr/utils/data.py +9 -13
  158. doctr/utils/fonts.py +2 -7
  159. doctr/utils/geometry.py +17 -48
  160. doctr/utils/metrics.py +17 -37
  161. doctr/utils/multithreading.py +4 -6
  162. doctr/utils/reconstitution.py +9 -13
  163. doctr/utils/repr.py +2 -3
  164. doctr/utils/visualization.py +16 -29
  165. doctr/version.py +1 -1
  166. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +70 -52
  167. python_doctr-0.12.0.dist-info/RECORD +180 -0
  168. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  169. python_doctr-0.10.0.dist-info/RECORD +0 -173
  170. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  171. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  172. {python_doctr-0.10.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
@@ -1,10 +1,10 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import tensorflow as tf
 from tensorflow.keras import Model, layers
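A pattern that recurs throughout this release, and is visible in the hunk above, is the migration from `typing.Dict`/`List`/`Optional`/`Tuple` annotations to built-in generics and the `X | None` union syntax (PEP 585/604). A minimal sketch of the before/after annotation style; the helper function below is hypothetical and only illustrates the change:

```python
from typing import Any

# 0.10.0 style:
#   def build_cfg(cfg: Optional[Dict[str, Any]] = None) -> Tuple[int, int]: ...

# 0.12.0 style: built-in generics plus `| None` unions (hypothetical helper)
def build_cfg(cfg: dict[str, Any] | None = None) -> tuple[int, int]:
    cfg = cfg or {}
    return cfg.get("height", 32), cfg.get("width", 128)
```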
@@ -19,7 +19,7 @@ from .base import _MASTER, _MASTERPostProcessor
 __all__ = ["MASTER", "master"]
 
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "master": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -35,7 +35,6 @@ class MASTER(_MASTER, Model):
     Implementation based on the official TF implementation: <https://github.com/jiangxiluning/MASTER-TF>`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary, (without EOS, SOS, PAD)
         d_model: d parameter for the transformer decoder
@@ -59,9 +58,9 @@ class MASTER(_MASTER, Model):
         num_layers: int = 3,
         max_length: int = 50,
         dropout: float = 0.2,
-        input_shape: Tuple[int, int, int] = (32, 128, 3),  # different from the paper
+        input_shape: tuple[int, int, int] = (32, 128, 3),  # different from the paper
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
 
@@ -88,8 +87,17 @@ class MASTER(_MASTER, Model):
         self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())
         self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @tf.function
-    def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
         # (N, 1, 1, max_length)
         target_pad_mask = tf.cast(tf.math.not_equal(target, self.vocab_size + 2), dtype=tf.uint8)
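The hunk above adds a `from_pretrained` method directly on the model, which simply forwards to `doctr.models.utils.load_pretrained_params`. A hedged sketch of how an instantiated recognition model could load a checkpoint through the new hook; the checkpoint path is a placeholder, not a real artifact:

```python
from doctr.models import master

# Build the architecture without the default pretrained weights
model = master(pretrained=False)

# New in this release: load a local path or URL straight onto the model
# (placeholder path below, shown only to illustrate the API)
model.from_pretrained("path/or/url/to/master_checkpoint")
```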
@@ -109,19 +117,17 @@ class MASTER(_MASTER, Model):
     def compute_loss(
         model_output: tf.Tensor,
         gt: tf.Tensor,
-        seq_len: List[int],
+        seq_len: list[int],
     ) -> tf.Tensor:
         """Compute categorical cross-entropy loss for the model.
         Sequences are masked after the EOS character.
 
         Args:
-        ----
             gt: the encoded tensor with gt labels
             model_output: predicted logits of the model
             seq_len: lengths of each gt word inside the batch
 
         Returns:
-        -------
             The loss of the model on the batch
         """
         # Input length : number of timesteps
@@ -144,15 +150,14 @@ class MASTER(_MASTER, Model):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """Call function for training
 
         Args:
-        ----
             x: images
             target: list of str labels
             return_model_output: if True, return logits
@@ -160,7 +165,6 @@ class MASTER(_MASTER, Model):
             **kwargs: keyword arguments passed to the decoder
 
         Returns:
-        -------
             A dictionnary containing eventually loss, logits and predictions.
         """
         # Encode
@@ -171,7 +175,7 @@ class MASTER(_MASTER, Model):
         # add positional encoding to features
         encoded = self.positional_encoding(feature, **kwargs)
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
 
         if kwargs.get("training", False) and target is None:
             raise ValueError("Need to provide labels during training")
@@ -209,13 +213,11 @@ class MASTER(_MASTER, Model):
         """Decode function for prediction
 
         Args:
-        ----
             encoded: encoded features
             **kwargs: keyword arguments passed to the decoder
 
         Returns:
-        -------
-            A Tuple of tf.Tensor: predictions, logits
+            A tuple of tf.Tensor: predictions, logits
         """
         b = encoded.shape[0]
 
@@ -247,14 +249,13 @@ class MASTERPostProcessor(_MASTERPostProcessor):
     """Post processor for MASTER architectures
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         # N x L
@@ -295,9 +296,7 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -312,12 +311,10 @@ def master(pretrained: bool = False, **kwargs: Any) -> MASTER:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keywoard arguments passed to the MASTER architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _master("master", pretrained, magc_resnet31, **kwargs)
@@ -1,6 +1,6 @@
 from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
     from .tensorflow import *
-elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
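This `__init__` hunk (and the matching one at the end of this section) flips the backend priority: when both frameworks are installed, the PyTorch implementation is now imported first and TensorFlow becomes the fallback. A small sketch that mirrors the new selection logic using the same flags from `doctr.file_utils`:

```python
from doctr.file_utils import is_tf_available, is_torch_available

# Mirrors the 0.12.0 import order: PyTorch wins when both backends are present
if is_torch_available():
    backend = "pytorch"
elif is_tf_available():
    backend = "tensorflow"
else:
    backend = "none"
print(f"docTR will load the {backend} implementations")
```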
@@ -1,9 +1,8 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import List, Tuple
 
 import numpy as np
 
@@ -17,17 +16,15 @@ class _PARSeq:
 
     def build_target(
         self,
-        gts: List[str],
-    ) -> Tuple[np.ndarray, List[int]]:
+        gts: list[str],
+    ) -> tuple[np.ndarray, list[int]]:
         """Encode a list of gts sequences into a np array and gives the corresponding*
         sequence lengths.
 
         Args:
-        ----
             gts: list of ground-truth labels
 
         Returns:
-        -------
             A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch)
         """
         encoded = encode_sequences(
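The `build_target` helper above wraps `encode_sequences` to turn a batch of ground-truth strings into an encoded label array plus per-sample sequence lengths, which the recognition losses then use for masking. A hedged usage sketch, assuming a PyTorch install and using `parseq` (one of the architectures touched by this diff) purely as an example:

```python
from doctr.models import parseq

# Random-weight model; only the vocab/encoding machinery matters here
model = parseq(pretrained=False)

# Encoded labels (np.ndarray) and the length of each ground-truth word
encoded, seq_len = model.build_target(["hello", "world"])
print(encoded.shape, seq_len)
```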
@@ -46,7 +43,6 @@ class _PARSeqPostProcessor(RecognitionPostProcessor):
     """Abstract class to postprocess the raw output of the model
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
@@ -1,12 +1,13 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 import math
+from collections.abc import Callable
 from copy import deepcopy
 from itertools import permutations
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any
 
 import numpy as np
 import torch
@@ -23,7 +24,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "parseq": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -38,7 +39,6 @@ class CharEmbedding(nn.Module):
     """Implements the character embedding module
 
     Args:
-    ----
         vocab_size: size of the vocabulary
         d_model: dimension of the model
     """
@@ -56,7 +56,6 @@ class PARSeqDecoder(nn.Module):
     """Implements decoder module of the PARSeq model
 
     Args:
-    ----
         d_model: dimension of the model
         num_heads: number of attention heads
         ffd: dimension of the feed forward layer
@@ -77,8 +76,6 @@ class PARSeqDecoder(nn.Module):
         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
         self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())
 
-        self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
-        self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.query_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.content_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5)
@@ -92,7 +89,7 @@ class PARSeqDecoder(nn.Module):
         target,
         content,
         memory,
-        target_mask: Optional[torch.Tensor] = None,
+        target_mask: torch.Tensor | None = None,
     ):
         query_norm = self.query_norm(target)
         content_norm = self.content_norm(content)
@@ -112,7 +109,6 @@ class PARSeq(_PARSeq, nn.Module):
     Slightly modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -136,9 +132,9 @@ class PARSeq(_PARSeq, nn.Module):
         dec_num_heads: int = 12,
         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
         dec_ffd_ratio: int = 4,
-        input_shape: Tuple[int, int, int] = (3, 32, 128),
+        input_shape: tuple[int, int, int] = (3, 32, 128),
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -175,6 +171,26 @@ class PARSeq(_PARSeq, nn.Module):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        if kwargs.get("ignore_keys") is None:
+            kwargs["ignore_keys"] = []
+
+        kwargs["ignore_keys"].extend([
+            "decoder.attention_norm.weight",
+            "decoder.attention_norm.bias",
+            "decoder.cross_attention_norm.weight",
+            "decoder.cross_attention_norm.bias",
+        ])
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
         # Generates permutations of the target sequence.
         # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
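Because the unused `attention_norm` and `cross_attention_norm` layers were removed from `PARSeqDecoder` (see the earlier hunk), checkpoints trained with docTR < 0.11.1 still contain those parameters; the `from_pretrained` override above appends their keys to `ignore_keys` so older weights keep loading (ref. https://github.com/mindee/doctr/issues/1911). A minimal sketch of the same filtering idea applied to a raw state dict; the checkpoint filename is hypothetical:

```python
import torch

# Hypothetical checkpoint produced with docTR < 0.11.1
state_dict = torch.load("parseq_pre_0_11_1.pt", map_location="cpu")

# Keys of the PARSeqDecoder layers that no longer exist in 0.12.0
stale_keys = {
    "decoder.attention_norm.weight",
    "decoder.attention_norm.bias",
    "decoder.cross_attention_norm.weight",
    "decoder.cross_attention_norm.bias",
}
filtered = {k: v for k, v in state_dict.items() if k not in stale_keys}
# `filtered` can then be loaded onto the 0.12.0 model, e.g. via load_state_dict
```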
@@ -217,7 +233,7 @@ class PARSeq(_PARSeq, nn.Module):
         combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
 
-    def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
         mask = torch.ones((sz, sz), device=permutation.device)
@@ -236,8 +252,8 @@ class PARSeq(_PARSeq, nn.Module):
         self,
         target: torch.Tensor,
         memory: torch.Tensor,
-        target_mask: Optional[torch.Tensor] = None,
-        target_query: Optional[torch.Tensor] = None,
+        target_mask: torch.Tensor | None = None,
+        target_query: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Add positional information to the target sequence and pass it through the decoder."""
         batch_size, sequence_length = target.shape
@@ -250,7 +266,7 @@ class PARSeq(_PARSeq, nn.Module):
             target_query = self.dropout(target_query)
         return self.decoder(target_query, content, memory, target_mask)
 
-    def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
+    def decode_autoregressive(self, features: torch.Tensor, max_len: int | None = None) -> torch.Tensor:
         """Generate predictions for the given features."""
         max_length = max_len if max_len is not None else self.max_length
         max_length = min(max_length, self.max_length) + 1
@@ -283,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Stop decoding if all sequences have reached the EOS token
             # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():  # type: ignore[attr-defined]
                 break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -298,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Create padding mask for refined target input maskes all behind EOS token as False
             # (N, 1, 1, max_length)
-            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
             mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
             logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
 
@@ -307,10 +323,10 @@ class PARSeq(_PARSeq, nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x)["features"]  # (batch_size, patches_seqlen, d_model)
         # remove cls token
         features = features[:, 1:, :]
@@ -337,7 +353,7 @@ class PARSeq(_PARSeq, nn.Module):
             ).unsqueeze(1).unsqueeze(1)  # (N, 1, 1, seq_len)
 
             loss = torch.tensor(0.0, device=features.device)
-            loss_numel: Union[int, float] = 0
+            loss_numel: int | float = 0
             n = (gt_out != self.vocab_size + 2).sum().item()
             for i, perm in enumerate(tgt_perms):
                 _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
@@ -351,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
                 # remove the [EOS] tokens for the succeeding perms
                 if i == 1:
                     gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-                    n = (gt_out != self.vocab_size + 2).sum().item()
+                    n = (gt_out != self.vocab_size + 2).sum().item()  # type: ignore[attr-defined]
 
             loss /= loss_numel
 
@@ -365,7 +381,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         logits = _bf16_to_float32(logits)
 
-        out: Dict[str, Any] = {}
+        out: dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -374,8 +390,13 @@ class PARSeq(_PARSeq, nn.Module):
             out["out_map"] = logits
 
         if target is None or return_preds:
+            # Disable for torch.compile compatibility
+            @torch.compiler.disable  # type: ignore[attr-defined]
+            def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
+                return self.postprocessor(logits)
+
             # Post-process boxes
-            out["preds"] = self.postprocessor(logits)
+            out["preds"] = _postprocess(logits)
 
         if target is not None:
             out["loss"] = loss
@@ -387,14 +408,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     """Post processor for PARSeq architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: torch.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = logits.argmax(-1)
         preds_prob = torch.softmax(logits, -1).max(dim=-1)[0]
@@ -417,7 +437,7 @@ def _parseq(
     pretrained: bool,
     backbone_fn: Callable[[bool], nn.Module],
     layer: str,
-    ignore_keys: Optional[List[str]] = None,
+    ignore_keys: list[str] | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config
@@ -446,7 +466,7 @@
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -462,12 +482,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _parseq(
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2024, Mindee.
+# Copyright (C) 2021-2025, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -6,7 +6,7 @@
 import math
 from copy import deepcopy
 from itertools import permutations
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any
 
 import numpy as np
 import tensorflow as tf
@@ -21,7 +21,7 @@ from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
 
-default_cfgs: Dict[str, Dict[str, Any]] = {
+default_cfgs: dict[str, dict[str, Any]] = {
     "parseq": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
@@ -36,7 +36,7 @@ class CharEmbedding(layers.Layer):
     """Implements the character embedding module
 
     Args:
-    ----
+    -
         vocab_size: size of the vocabulary
         d_model: dimension of the model
     """
@@ -54,7 +54,6 @@ class PARSeqDecoder(layers.Layer):
     """Implements decoder module of the PARSeq model
 
     Args:
-    ----
         d_model: dimension of the model
         num_heads: number of attention heads
         ffd: dimension of the feed forward layer
@@ -77,8 +76,6 @@ class PARSeqDecoder(layers.Layer):
             d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
         )
 
-        self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5)
         self.query_norm = layers.LayerNormalization(epsilon=1e-5)
         self.content_norm = layers.LayerNormalization(epsilon=1e-5)
         self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
@@ -115,7 +112,6 @@ class PARSeq(_PARSeq, Model):
     Modified implementation based on the official Pytorch implementation: <https://github.com/baudm/parseq/tree/main`_.
 
     Args:
-    ----
         feature_extractor: the backbone serving as feature extractor
         vocab: vocabulary used for encoding
         embedding_units: number of embedding units
@@ -129,7 +125,7 @@ class PARSeq(_PARSeq, Model):
         cfg: dictionary containing information about the model
     """
 
-    _children_names: List[str] = ["feat_extractor", "postprocessor"]
+    _children_names: list[str] = ["feat_extractor", "postprocessor"]
 
     def __init__(
         self,
@@ -141,9 +137,9 @@ class PARSeq(_PARSeq, Model):
         dec_num_heads: int = 12,
         dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
         dec_ffd_ratio: int = 4,
-        input_shape: Tuple[int, int, int] = (32, 128, 3),
+        input_shape: tuple[int, int, int] = (32, 128, 3),
         exportable: bool = False,
-        cfg: Optional[Dict[str, Any]] = None,
+        cfg: dict[str, Any] | None = None,
     ) -> None:
         super().__init__()
         self.vocab = vocab
@@ -167,6 +163,18 @@ class PARSeq(_PARSeq, Model):
 
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        kwargs["skip_mismatch"] = True
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
         # Generates permutations of the target sequence.
         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -213,7 +221,7 @@ class PARSeq(_PARSeq, Model):
         )
         return combined
 
-    def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
         mask = tf.ones((sz, sz), dtype=tf.float32)
@@ -236,8 +244,8 @@ class PARSeq(_PARSeq, Model):
         self,
         target: tf.Tensor,
         memory: tf.Tensor,
-        target_mask: Optional[tf.Tensor] = None,
-        target_query: Optional[tf.Tensor] = None,
+        target_mask: tf.Tensor | None = None,
+        target_query: tf.Tensor | None = None,
         **kwargs: Any,
     ) -> tf.Tensor:
         batch_size, sequence_length = target.shape
@@ -250,8 +258,7 @@ class PARSeq(_PARSeq, Model):
             target_query = self.dropout(target_query, **kwargs)
         return self.decoder(target_query, content, memory, target_mask, **kwargs)
 
-    @tf.function
-    def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
+    def decode_autoregressive(self, features: tf.Tensor, max_len: int | None = None, **kwargs) -> tf.Tensor:
         """Generate predictions for the given features."""
         max_length = max_len if max_len is not None else self.max_length
         max_length = min(max_length, self.max_length) + 1
@@ -318,11 +325,11 @@ class PARSeq(_PARSeq, Model):
     def call(
         self,
         x: tf.Tensor,
-        target: Optional[List[str]] = None,
+        target: list[str] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
         **kwargs: Any,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
         # remove cls token
         features = features[:, 1:, :]
@@ -393,7 +400,7 @@ class PARSeq(_PARSeq, Model):
 
         logits = _bf16_to_float32(logits)
 
-        out: Dict[str, tf.Tensor] = {}
+        out: dict[str, tf.Tensor] = {}
         if self.exportable:
             out["logits"] = logits
             return out
@@ -415,14 +422,13 @@ class PARSeqPostProcessor(_PARSeqPostProcessor):
     """Post processor for PARSeq architecture
 
     Args:
-    ----
         vocab: string containing the ordered sequence of supported characters
     """
 
     def __call__(
         self,
         logits: tf.Tensor,
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
         # compute pred with argmax for attention models
         out_idxs = tf.math.argmax(logits, axis=2)
         preds_prob = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)
@@ -448,7 +454,7 @@ def _parseq(
     arch: str,
     pretrained: bool,
     backbone_fn,
-    input_shape: Optional[Tuple[int, int, int]] = None,
+    input_shape: tuple[int, int, int] | None = None,
     **kwargs: Any,
 ) -> PARSeq:
     # Patch the config
@@ -478,9 +484,7 @@
     # Load pretrained parameters
     if pretrained:
         # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
@@ -496,12 +500,10 @@ def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
     >>> out = model(input_tensor)
 
     Args:
-    ----
         pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
         **kwargs: keyword arguments of the PARSeq architecture
 
     Returns:
-    -------
         text recognition architecture
     """
     return _parseq(
@@ -1,6 +1,6 @@
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_torch_available
 
-if is_tf_available():
-    from .tensorflow import *
-else:
-    from .pytorch import *  # type: ignore[assignment]
+if is_torch_available():
+    from .pytorch import *
+elif is_tf_available():
+    from .tensorflow import *  # type: ignore[assignment]