python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -116,18 +116,14 @@ class _OCRPredictor:
     def _generate_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> list[list[np.ndarray]]:
         if assume_straight_pages:
-            crops = [
-                extract_crops(page, _boxes[:, :4], channels_last=channels_last)
-                for page, _boxes in zip(pages, loc_preds)
-            ]
+            crops = [extract_crops(page, _boxes[:, :4]) for page, _boxes in zip(pages, loc_preds)]
         else:
             crops = [
-                extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
+                extract_rcrops(page, _boxes[:, :4], assume_horizontal=assume_horizontal)
                 for page, _boxes in zip(pages, loc_preds)
             ]
         return crops
@@ -136,11 +132,10 @@ class _OCRPredictor:
     def _prepare_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
        assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
-        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
+        crops = _OCRPredictor._generate_crops(pages, loc_preds, assume_straight_pages, assume_horizontal)
 
         # Avoid sending zero-sized crops
         is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
@@ -68,14 +68,14 @@ class OCRPredictor(nn.Module, _OCRPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
 
-        origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
+        origin_page_shapes = [page.shape[:2] for page in pages]
 
         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -109,8 +109,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
         # Detach objectness scores from loc_preds
         loc_preds, objectness_scores = detach_scores(loc_preds)
-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
@@ -120,7 +118,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         crops, loc_preds = self._prepare_crops(
             pages,
             loc_preds,
-            channels_last=channels_last,
             assume_straight_pages=self.assume_straight_pages,
             assume_horizontal=self._page_orientation_disabled,
         )
@@ -150,7 +147,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             boxes,
             objectness_scores,
             text_preds,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             crop_orientations,
             orientations,
             languages_dict,
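
Note: the net effect of these changes is that the predictor consumes numpy pages only. A minimal usage sketch under the new contract (the image file name is illustrative):

    import numpy as np
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor

    predictor = ocr_predictor(pretrained=True)
    # Pages are (H, W, C) numpy arrays; torch.Tensor pages are no longer accepted
    pages: list[np.ndarray] = DocumentFile.from_images("sample.jpg")
    document = predictor(pages)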
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -60,65 +60,60 @@ class PreProcessor(nn.Module):
 
         return batches
 
-    def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
+    def sample_transforms(self, x: np.ndarray) -> torch.Tensor:
         if x.ndim != 3:
             raise AssertionError("expected list of 3D Tensors")
-        if isinstance(x, np.ndarray):
-            if x.dtype not in (np.uint8, np.float32):
-                raise TypeError("unsupported data type for numpy.ndarray")
-            x = torch.from_numpy(x.copy()).permute(2, 0, 1)
-        elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-            raise TypeError("unsupported data type for torch.Tensor")
+        if x.dtype not in (np.uint8, np.float32, np.float16):
+            raise TypeError("unsupported data type for numpy.ndarray")
+        tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
         # Resizing
-        x = self.resize(x)
+        tensor = self.resize(tensor)
         # Data type
-        if x.dtype == torch.uint8:
-            x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
+        if tensor.dtype == torch.uint8:
+            tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
         else:
-            x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
+            tensor = tensor.to(dtype=torch.float32)
 
-        return x  # type: ignore[return-value]
+        return tensor
 
-    def __call__(self, x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray]) -> list[torch.Tensor]:
+    def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
         """Prepare document data for model forwarding
 
         Args:
-            x: list of images (np.array) or tensors (already resized and batched)
+            x: list of images (np.array) or a single image (np.array) of shape (H, W, C)
 
         Returns:
-            list of page batches
+            list of page batches (*, C, H, W) ready for model inference
         """
         # Input type check
-        if isinstance(x, (np.ndarray, torch.Tensor)):
+        if isinstance(x, np.ndarray):
             if x.ndim != 4:
                 raise AssertionError("expected 4D Tensor")
-            if isinstance(x, np.ndarray):
-                if x.dtype not in (np.uint8, np.float32):
-                    raise TypeError("unsupported data type for numpy.ndarray")
-                x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
-            elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-                raise TypeError("unsupported data type for torch.Tensor")
+            if x.dtype not in (np.uint8, np.float32, np.float16):
+                raise TypeError("unsupported data type for numpy.ndarray")
+            tensor = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
+
             # Resizing
-            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:  # type: ignore[union-attr]
-                x = F.resize(
-                    x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
+            if tensor.shape[-2] != self.resize.size[0] or tensor.shape[-1] != self.resize.size[1]:
+                tensor = F.resize(
+                    tensor, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
                 )
             # Data type
-            if x.dtype == torch.uint8:  # type: ignore[union-attr]
-                x = x.to(dtype=torch.float32).div(255).clip(0, 1)  # type: ignore[union-attr]
+            if tensor.dtype == torch.uint8:
+                tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
             else:
-                x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
-            batches = [x]
+                tensor = tensor.to(dtype=torch.float32)
+            batches = [tensor]
 
-        elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x):
+        elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)  # type: ignore[assignment]
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))
 
-        return batches  # type: ignore[return-value]
+        return batches
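
Note: the preprocessor now accepts numpy arrays only (uint8, float16, or float32, in channels-last layout). A minimal sketch, assuming the constructor keeps its output_size and batch_size arguments (the shapes below are illustrative):

    import numpy as np
    from doctr.models.preprocessor.pytorch import PreProcessor

    preproc = PreProcessor(output_size=(1024, 1024), batch_size=2)
    pages = [np.zeros((900, 700, 3), dtype=np.uint8) for _ in range(3)]  # dummy (H, W, C) pages
    batches = preproc(pages)  # list of float32 tensors of shape (*, 3, 1024, 1024)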
@@ -3,4 +3,5 @@ from .master import *
 from .sar import *
 from .vitstr import *
 from .parseq import *
+from .viptr import *
 from .zoo import *
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -15,7 +15,7 @@ from torch.nn import functional as F
 from doctr.datasets import VOCABS, decode_sequence
 
 from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils.pytorch import load_pretrained_params
+from ...utils import load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -25,8 +25,8 @@ default_cfgs: dict[str, dict[str, Any]] = {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
-        "vocab": VOCABS["legacy_french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_vgg16_bn-9762b0b0.pt&src=0",
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.12.0/crnn_vgg16_bn-0417f351.pt&src=0",
     },
     "crnn_mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
@@ -82,7 +82,7 @@ class CTCPostProcessor(RecognitionPostProcessor):
 
     def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
         """Performs decoding of raw output with CTC and decoding of CTC predictions
-        with label_to_idx mapping dictionnary
+        with label_to_idx mapping dictionary
 
         Args:
             logits: raw output of the model, shape (N, C + 1, seq_len)
@@ -155,6 +155,15 @@ class CRNN(RecognitionModel, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         model_output: torch.Tensor,
@@ -214,7 +223,7 @@ class CRNN(RecognitionModel, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)
 
@@ -248,13 +257,13 @@ def _crnn(
         _cfg["input_shape"] = kwargs["input_shape"]
 
     # Build the model
-    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)  # type: ignore[arg-type]
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, _cfg["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(_cfg["url"], ignore_keys=_ignore_keys)
 
     return model
 
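Note: recognition models now expose a from_pretrained method, so checkpoints can be reloaded without importing load_pretrained_params directly. A hedged sketch (the checkpoint path is illustrative):

    from doctr.models import crnn_vgg16_bn

    model = crnn_vgg16_bn(pretrained=True)
    # Reload a locally fine-tuned checkpoint through the new uniform entry point
    model.from_pretrained("my_finetuned_crnn.pt")
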
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding
 
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor
 
 __all__ = ["MASTER", "master"]
@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
         # NOTE: nn.TransformerDecoder takes the inverse from this implementation
         # [True, True, True, ..., False, False, False] -> False is masked
         # (N, 1, 1, max_length)
-        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
+        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
         target_length = target.size(1)
         # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
         # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -140,7 +140,7 @@ class MASTER(_MASTER, nn.Module):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1  # type: ignore[assignment]
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>. Delete last logit of the model output.
         cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -151,6 +151,15 @@ class MASTER(_MASTER, nn.Module):
         ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
         return ce_loss.mean()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -167,7 +176,7 @@ class MASTER(_MASTER, nn.Module):
             return_preds: if True, decode logits
 
         Returns:
-            A dictionnary containing eventually loss, logits and predictions.
+            A dictionary containing eventually loss, logits and predictions.
         """
         # Encode
         features = self.feat_extractor(x)["features"]
@@ -210,7 +219,7 @@ class MASTER(_MASTER, nn.Module):
 
         if return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)
 
@@ -301,7 +310,7 @@ def _master(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -19,7 +19,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
@@ -76,8 +76,6 @@ class PARSeqDecoder(nn.Module):
         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
         self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())
 
-        self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
-        self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.query_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.content_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5)
@@ -173,6 +171,26 @@ class PARSeq(_PARSeq, nn.Module):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        if kwargs.get("ignore_keys") is None:
+            kwargs["ignore_keys"] = []
+
+        kwargs["ignore_keys"].extend([
+            "decoder.attention_norm.weight",
+            "decoder.attention_norm.bias",
+            "decoder.cross_attention_norm.weight",
+            "decoder.cross_attention_norm.bias",
+        ])
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
         # Generates permutations of the target sequence.
         # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -210,7 +228,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()  # type: ignore[list-item]
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
@@ -281,7 +299,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Stop decoding if all sequences have reached the EOS token
             # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():  # type: ignore[attr-defined]
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
                 break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -296,7 +314,7 @@ class PARSeq(_PARSeq, nn.Module):
 
             # Create padding mask for refined target input maskes all behind EOS token as False
             # (N, 1, 1, max_length)
-            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)  # type: ignore[attr-defined]
+            target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
             mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
             logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
@@ -373,7 +391,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)
 
@@ -448,7 +466,7 @@ def _parseq(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
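Note: PARSeq.from_pretrained always extends ignore_keys with the removed decoder norm parameters, so checkpoints trained before this change still load. A sketch (the checkpoint path is illustrative):

    from doctr.models import parseq

    model = parseq(pretrained=False)
    # Weights saved by docTR < 0.11.1 may still contain decoder.attention_norm /
    # decoder.cross_attention_norm entries; they are ignored on load
    model.from_pretrained("parseq_docTR_0.10_checkpoint.pt")
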
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -4,6 +4,8 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+import math
+
 import numpy as np
 
 from ..utils import merge_multi_strings
@@ -15,69 +17,120 @@ def split_crops(
     crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-    dilation: float,
-    channels_last: bool = True,
-) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
-    """Chunk crops horizontally to match a given aspect ratio
+    split_overlap_ratio: float,
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
-        max_ratio: the maximum aspect ratio that won't trigger the chunk
-        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
-        dilation: the width dilation of final chunks (to provide some overlaps)
-        channels_last: whether the numpy array has dimensions in channels last order
+        crops: List of image crops (H, W, C).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
 
     Returns:
-        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-    _remap_required = False
-    crop_map: list[int | tuple[int, int]] = []
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
     new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
-        h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
+        h, w = crop.shape[:2]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-            # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
-            num_subcrops = int(aspect_ratio // target_ratio)
-            # Find the new widths, additional dilation factor to overlap crops
-            width = dilation * w / num_subcrops
-            centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
-            # Get the crops
-            if channels_last:
-                _crops = [
-                    crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
-                    for center in centers
-                ]
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-                _crops = [
-                    crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
-                    for center in centers
-                ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)
 
-    return new_crops, crop_map, _remap_required
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(image: np.ndarray, split_width: int, overlap_width: int) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        splits.append(image[:, start_col:end_col, :])
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio
 
 
 def remap_preds(
-    preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
 ) -> list[tuple[str, float]]:
-    remapped_out = []
-    for _idx in crop_map:
-        # Crop hasn't been split
-        if isinstance(_idx, int):
-            remapped_out.append(preds[_idx])
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-            # unzip
-            vals, probs = zip(*preds[_idx[0] : _idx[1]])
-            # Merge the string values
-            remapped_out.append((merge_multi_strings(vals, dilation), min(probs)))  # type: ignore[arg-type]
-    return remapped_out
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
+    return remapped
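
Note: crop splitting is now driven by an explicit overlap ratio instead of a dilation factor, and merged confidences are averaged rather than taking the minimum. A minimal sketch of the new contract (array sizes and the 0.5 overlap are illustrative):

    import numpy as np
    from doctr.models.recognition.predictor._utils import remap_preds, split_crops

    # A 32x320 crop (aspect ratio 10) exceeds max_ratio=8, so it is cut into
    # 32x128 slices (target_ratio=4) that overlap by half a slice width
    crop = np.zeros((32, 320, 3), dtype=np.uint8)
    crops, crop_map, remap_required = split_crops([crop], max_ratio=8, target_ratio=4, split_overlap_ratio=0.5)
    preds = [(f"part{i}", 0.9) for i in range(len(crops))]  # placeholder predictions, one per split
    merged = remap_preds(preds, crop_map, overlap_ratio=0.5)  # one merged (text, confidence) per original crop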