python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170):
  1. doctr/__init__.py +0 -1
  2. doctr/contrib/artefacts.py +1 -1
  3. doctr/contrib/base.py +1 -1
  4. doctr/datasets/__init__.py +0 -5
  5. doctr/datasets/coco_text.py +1 -1
  6. doctr/datasets/cord.py +1 -1
  7. doctr/datasets/datasets/__init__.py +1 -6
  8. doctr/datasets/datasets/base.py +1 -1
  9. doctr/datasets/datasets/pytorch.py +3 -3
  10. doctr/datasets/detection.py +1 -1
  11. doctr/datasets/doc_artefacts.py +1 -1
  12. doctr/datasets/funsd.py +1 -1
  13. doctr/datasets/generator/__init__.py +1 -6
  14. doctr/datasets/generator/base.py +1 -1
  15. doctr/datasets/generator/pytorch.py +1 -1
  16. doctr/datasets/ic03.py +1 -1
  17. doctr/datasets/ic13.py +1 -1
  18. doctr/datasets/iiit5k.py +1 -1
  19. doctr/datasets/iiithws.py +1 -1
  20. doctr/datasets/imgur5k.py +1 -1
  21. doctr/datasets/mjsynth.py +1 -1
  22. doctr/datasets/ocr.py +1 -1
  23. doctr/datasets/orientation.py +1 -1
  24. doctr/datasets/recognition.py +1 -1
  25. doctr/datasets/sroie.py +1 -1
  26. doctr/datasets/svhn.py +1 -1
  27. doctr/datasets/svt.py +1 -1
  28. doctr/datasets/synthtext.py +1 -1
  29. doctr/datasets/utils.py +1 -1
  30. doctr/datasets/vocabs.py +1 -3
  31. doctr/datasets/wildreceipt.py +1 -1
  32. doctr/file_utils.py +3 -102
  33. doctr/io/elements.py +1 -1
  34. doctr/io/html.py +1 -1
  35. doctr/io/image/__init__.py +1 -7
  36. doctr/io/image/base.py +1 -1
  37. doctr/io/image/pytorch.py +2 -2
  38. doctr/io/pdf.py +1 -1
  39. doctr/io/reader.py +1 -1
  40. doctr/models/_utils.py +56 -18
  41. doctr/models/builder.py +1 -1
  42. doctr/models/classification/magc_resnet/__init__.py +1 -6
  43. doctr/models/classification/magc_resnet/pytorch.py +3 -3
  44. doctr/models/classification/mobilenet/__init__.py +1 -6
  45. doctr/models/classification/mobilenet/pytorch.py +1 -1
  46. doctr/models/classification/predictor/__init__.py +1 -6
  47. doctr/models/classification/predictor/pytorch.py +2 -2
  48. doctr/models/classification/resnet/__init__.py +1 -6
  49. doctr/models/classification/resnet/pytorch.py +1 -1
  50. doctr/models/classification/textnet/__init__.py +1 -6
  51. doctr/models/classification/textnet/pytorch.py +2 -2
  52. doctr/models/classification/vgg/__init__.py +1 -6
  53. doctr/models/classification/vgg/pytorch.py +1 -1
  54. doctr/models/classification/vip/__init__.py +1 -4
  55. doctr/models/classification/vip/layers/__init__.py +1 -4
  56. doctr/models/classification/vip/layers/pytorch.py +2 -2
  57. doctr/models/classification/vip/pytorch.py +1 -1
  58. doctr/models/classification/vit/__init__.py +1 -6
  59. doctr/models/classification/vit/pytorch.py +3 -3
  60. doctr/models/classification/zoo.py +7 -12
  61. doctr/models/core.py +1 -1
  62. doctr/models/detection/_utils/__init__.py +1 -6
  63. doctr/models/detection/_utils/base.py +1 -1
  64. doctr/models/detection/_utils/pytorch.py +1 -1
  65. doctr/models/detection/core.py +2 -2
  66. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  67. doctr/models/detection/differentiable_binarization/base.py +5 -13
  68. doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
  69. doctr/models/detection/fast/__init__.py +1 -6
  70. doctr/models/detection/fast/base.py +5 -15
  71. doctr/models/detection/fast/pytorch.py +5 -5
  72. doctr/models/detection/linknet/__init__.py +1 -6
  73. doctr/models/detection/linknet/base.py +4 -13
  74. doctr/models/detection/linknet/pytorch.py +3 -3
  75. doctr/models/detection/predictor/__init__.py +1 -6
  76. doctr/models/detection/predictor/pytorch.py +2 -2
  77. doctr/models/detection/zoo.py +16 -33
  78. doctr/models/factory/hub.py +26 -34
  79. doctr/models/kie_predictor/__init__.py +1 -6
  80. doctr/models/kie_predictor/base.py +1 -1
  81. doctr/models/kie_predictor/pytorch.py +3 -7
  82. doctr/models/modules/layers/__init__.py +1 -6
  83. doctr/models/modules/layers/pytorch.py +4 -4
  84. doctr/models/modules/transformer/__init__.py +1 -6
  85. doctr/models/modules/transformer/pytorch.py +3 -3
  86. doctr/models/modules/vision_transformer/__init__.py +1 -6
  87. doctr/models/modules/vision_transformer/pytorch.py +1 -1
  88. doctr/models/predictor/__init__.py +1 -6
  89. doctr/models/predictor/base.py +4 -9
  90. doctr/models/predictor/pytorch.py +3 -6
  91. doctr/models/preprocessor/__init__.py +1 -6
  92. doctr/models/preprocessor/pytorch.py +28 -33
  93. doctr/models/recognition/core.py +1 -1
  94. doctr/models/recognition/crnn/__init__.py +1 -6
  95. doctr/models/recognition/crnn/pytorch.py +7 -7
  96. doctr/models/recognition/master/__init__.py +1 -6
  97. doctr/models/recognition/master/base.py +1 -1
  98. doctr/models/recognition/master/pytorch.py +6 -6
  99. doctr/models/recognition/parseq/__init__.py +1 -6
  100. doctr/models/recognition/parseq/base.py +1 -1
  101. doctr/models/recognition/parseq/pytorch.py +6 -6
  102. doctr/models/recognition/predictor/__init__.py +1 -6
  103. doctr/models/recognition/predictor/_utils.py +8 -17
  104. doctr/models/recognition/predictor/pytorch.py +2 -3
  105. doctr/models/recognition/sar/__init__.py +1 -6
  106. doctr/models/recognition/sar/pytorch.py +4 -4
  107. doctr/models/recognition/utils.py +1 -1
  108. doctr/models/recognition/viptr/__init__.py +1 -4
  109. doctr/models/recognition/viptr/pytorch.py +4 -4
  110. doctr/models/recognition/vitstr/__init__.py +1 -6
  111. doctr/models/recognition/vitstr/base.py +1 -1
  112. doctr/models/recognition/vitstr/pytorch.py +4 -4
  113. doctr/models/recognition/zoo.py +14 -14
  114. doctr/models/utils/__init__.py +1 -6
  115. doctr/models/utils/pytorch.py +3 -2
  116. doctr/models/zoo.py +1 -1
  117. doctr/transforms/functional/__init__.py +1 -6
  118. doctr/transforms/functional/base.py +3 -2
  119. doctr/transforms/functional/pytorch.py +5 -5
  120. doctr/transforms/modules/__init__.py +1 -7
  121. doctr/transforms/modules/base.py +28 -94
  122. doctr/transforms/modules/pytorch.py +29 -27
  123. doctr/utils/common_types.py +1 -1
  124. doctr/utils/data.py +1 -2
  125. doctr/utils/fonts.py +1 -1
  126. doctr/utils/geometry.py +7 -11
  127. doctr/utils/metrics.py +1 -1
  128. doctr/utils/multithreading.py +1 -1
  129. doctr/utils/reconstitution.py +1 -1
  130. doctr/utils/repr.py +1 -1
  131. doctr/utils/visualization.py +2 -2
  132. doctr/version.py +1 -1
  133. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
  134. python_doctr-1.0.1.dist-info/RECORD +149 -0
  135. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
  136. doctr/datasets/datasets/tensorflow.py +0 -59
  137. doctr/datasets/generator/tensorflow.py +0 -58
  138. doctr/datasets/loader.py +0 -94
  139. doctr/io/image/tensorflow.py +0 -101
  140. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  141. doctr/models/classification/mobilenet/tensorflow.py +0 -442
  142. doctr/models/classification/predictor/tensorflow.py +0 -60
  143. doctr/models/classification/resnet/tensorflow.py +0 -418
  144. doctr/models/classification/textnet/tensorflow.py +0 -275
  145. doctr/models/classification/vgg/tensorflow.py +0 -125
  146. doctr/models/classification/vit/tensorflow.py +0 -201
  147. doctr/models/detection/_utils/tensorflow.py +0 -34
  148. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
  149. doctr/models/detection/fast/tensorflow.py +0 -427
  150. doctr/models/detection/linknet/tensorflow.py +0 -377
  151. doctr/models/detection/predictor/tensorflow.py +0 -70
  152. doctr/models/kie_predictor/tensorflow.py +0 -187
  153. doctr/models/modules/layers/tensorflow.py +0 -171
  154. doctr/models/modules/transformer/tensorflow.py +0 -235
  155. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  156. doctr/models/predictor/tensorflow.py +0 -155
  157. doctr/models/preprocessor/tensorflow.py +0 -122
  158. doctr/models/recognition/crnn/tensorflow.py +0 -317
  159. doctr/models/recognition/master/tensorflow.py +0 -320
  160. doctr/models/recognition/parseq/tensorflow.py +0 -516
  161. doctr/models/recognition/predictor/tensorflow.py +0 -79
  162. doctr/models/recognition/sar/tensorflow.py +0 -423
  163. doctr/models/recognition/vitstr/tensorflow.py +0 -285
  164. doctr/models/utils/tensorflow.py +0 -189
  165. doctr/transforms/functional/tensorflow.py +0 -254
  166. doctr/transforms/modules/tensorflow.py +0 -562
  167. python_doctr-0.12.0.dist-info/RECORD +0 -180
  168. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
  169. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
  170. {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -7,26 +7,21 @@
7
7
 
8
8
  import json
9
9
  import logging
10
- import os
11
10
  import subprocess
11
+ import tempfile
12
12
  import textwrap
13
13
  from pathlib import Path
14
14
  from typing import Any
15
15
 
16
+ import torch
16
17
  from huggingface_hub import (
17
18
  HfApi,
18
- Repository,
19
19
  get_token,
20
- get_token_permission,
21
20
  hf_hub_download,
22
21
  login,
23
22
  )
24
23
 
25
24
  from doctr import models
26
- from doctr.file_utils import is_tf_available, is_torch_available
27
-
28
- if is_torch_available():
29
- import torch
30
25
 
31
26
  __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
32
27
 
@@ -41,9 +36,9 @@ AVAILABLE_ARCHS = {
41
36
  def login_to_hub() -> None: # pragma: no cover
42
37
  """Login to huggingface hub"""
43
38
  access_token = get_token()
44
- if access_token is not None and get_token_permission(access_token):
39
+ if access_token is not None:
45
40
  logging.info("Huggingface Hub token found and valid")
46
- login(token=access_token, write_permission=True)
41
+ login(token=access_token)
47
42
  else:
48
43
  login()
49
44
  # check if git lfs is installed
@@ -61,19 +56,14 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
61
56
  """Save model and config to disk for pushing to huggingface hub
62
57
 
63
58
  Args:
64
- model: TF or PyTorch model to be saved
59
+ model: PyTorch model to be saved
65
60
  save_dir: directory to save model and config
66
61
  arch: architecture name
67
62
  task: task name
68
63
  """
69
64
  save_directory = Path(save_dir)
70
-
71
- if is_torch_available():
72
- weights_path = save_directory / "pytorch_model.bin"
73
- torch.save(model.state_dict(), weights_path)
74
- elif is_tf_available():
75
- weights_path = save_directory / "tf_model.weights.h5"
76
- model.save_weights(str(weights_path))
65
+ weights_path = save_directory / "pytorch_model.bin"
66
+ torch.save(model.state_dict(), weights_path)
77
67
 
78
68
  config_path = save_directory / "config.json"
79
69
 
@@ -96,7 +86,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
96
86
  >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
97
87
 
98
88
  Args:
99
- model: TF or PyTorch model to be saved
89
+ model: PyTorch model to be saved
100
90
  model_name: name of the model which is also the repository name
101
91
  task: task name
102
92
  **kwargs: keyword arguments for push_to_hf_hub
@@ -120,7 +110,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
120
110
  <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
121
111
  </p>
122
112
 
123
- **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
113
+ **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**
124
114
 
125
115
  ## Task: {task}
126
116
 
@@ -169,16 +159,23 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
169
159
 
170
160
  commit_message = f"Add {model_name} model"
171
161
 
172
- local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
173
- repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
174
- repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)
162
+ # Create repository
163
+ api = HfApi()
164
+ api.create_repo(model_name, token=get_token(), exist_ok=False)
175
165
 
176
- with repo.commit(commit_message):
177
- _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
178
- readme_path = Path(repo.local_dir) / "README.md"
166
+ # Save model files to a temporary directory
167
+ with tempfile.TemporaryDirectory() as tmp_dir:
168
+ _save_model_and_config_for_hf_hub(model, tmp_dir, arch=arch, task=task)
169
+ readme_path = Path(tmp_dir) / "README.md"
179
170
  readme_path.write_text(readme)
180
171
 
181
- repo.git_push()
172
+ # Upload all files to the hub
173
+ api.upload_folder(
174
+ folder_path=tmp_dir,
175
+ repo_id=model_name,
176
+ commit_message=commit_message,
177
+ token=get_token(),
178
+ )
182
179
 
183
180
 
184
181
  def from_hub(repo_id: str, **kwargs: Any):
@@ -214,13 +211,8 @@ def from_hub(repo_id: str, **kwargs: Any):
214
211
 
215
212
  # update model cfg
216
213
  model.cfg = cfg
217
-
218
- # Load checkpoint
219
- if is_torch_available():
220
- weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
221
- else: # tf
222
- weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
223
-
214
+ # load the weights
215
+ weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
224
216
  model.from_pretrained(weights)
225
217
 
226
218
  return model
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -68,14 +68,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
68
68
  @torch.inference_mode()
69
69
  def forward(
70
70
  self,
71
- pages: list[np.ndarray | torch.Tensor],
71
+ pages: list[np.ndarray],
72
72
  **kwargs: Any,
73
73
  ) -> Document:
74
74
  # Dimension check
75
75
  if any(page.ndim != 3 for page in pages):
76
76
  raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
77
77
 
78
- origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
78
+ origin_page_shapes = [page.shape[:2] for page in pages]
79
79
 
80
80
  # Localize text elements
81
81
  loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -113,9 +113,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
113
113
  dict_loc_preds[class_name] = _loc_preds
114
114
  objectness_scores[class_name] = _scores
115
115
 
116
- # Check whether crop mode should be switched to channels first
117
- channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
118
-
119
116
  # Apply hooks to loc_preds if any
120
117
  for hook in self.hooks:
121
118
  dict_loc_preds = hook(dict_loc_preds)
@@ -126,7 +123,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
126
123
  crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
127
124
  pages,
128
125
  dict_loc_preds[class_name],
129
- channels_last=channels_last,
130
126
  assume_straight_pages=self.assume_straight_pages,
131
127
  assume_horizontal=self._page_orientation_disabled,
132
128
  )
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -151,16 +151,16 @@ class FASTConvLayer(nn.Module):
151
151
  id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
152
152
  self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
153
153
  kernel = self.id_tensor
154
- std = (identity.running_var + identity.eps).sqrt() # type: ignore
154
+ std = (identity.running_var + identity.eps).sqrt()
155
155
  t = (identity.weight / std).reshape(-1, 1, 1, 1)
156
- return kernel * t, identity.bias - identity.running_mean * identity.weight / std
156
+ return kernel * t, identity.bias - identity.running_mean * identity.weight / std # type: ignore[operator]
157
157
 
158
158
  def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
159
159
  kernel = conv.weight
160
160
  kernel = self._pad_to_mxn_tensor(kernel)
161
161
  std = (bn.running_var + bn.eps).sqrt() # type: ignore
162
162
  t = (bn.weight / std).reshape(-1, 1, 1, 1)
163
- return kernel * t, bn.bias - bn.running_mean * bn.weight / std
163
+ return kernel * t, bn.bias - bn.running_mean * bn.weight / std # type: ignore[operator]
164
164
 
165
165
  def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
166
166
  kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -50,8 +50,8 @@ def scaled_dot_product_attention(
50
50
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
51
51
  if mask is not None:
52
52
  # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
53
- scores = scores.masked_fill(mask == 0, float("-inf")) # type: ignore[attr-defined]
54
- p_attn = torch.softmax(scores, dim=-1) # type: ignore[call-overload]
53
+ scores = scores.masked_fill(mask == 0, float("-inf"))
54
+ p_attn = torch.softmax(scores, dim=-1)
55
55
  return torch.matmul(p_attn, value), p_attn
56
56
 
57
57
 
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -116,18 +116,14 @@ class _OCRPredictor:
116
116
  def _generate_crops(
117
117
  pages: list[np.ndarray],
118
118
  loc_preds: list[np.ndarray],
119
- channels_last: bool,
120
119
  assume_straight_pages: bool = False,
121
120
  assume_horizontal: bool = False,
122
121
  ) -> list[list[np.ndarray]]:
123
122
  if assume_straight_pages:
124
- crops = [
125
- extract_crops(page, _boxes[:, :4], channels_last=channels_last)
126
- for page, _boxes in zip(pages, loc_preds)
127
- ]
123
+ crops = [extract_crops(page, _boxes[:, :4]) for page, _boxes in zip(pages, loc_preds)]
128
124
  else:
129
125
  crops = [
130
- extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
126
+ extract_rcrops(page, _boxes[:, :4], assume_horizontal=assume_horizontal)
131
127
  for page, _boxes in zip(pages, loc_preds)
132
128
  ]
133
129
  return crops
@@ -136,11 +132,10 @@ class _OCRPredictor:
136
132
  def _prepare_crops(
137
133
  pages: list[np.ndarray],
138
134
  loc_preds: list[np.ndarray],
139
- channels_last: bool,
140
135
  assume_straight_pages: bool = False,
141
136
  assume_horizontal: bool = False,
142
137
  ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
143
- crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
138
+ crops = _OCRPredictor._generate_crops(pages, loc_preds, assume_straight_pages, assume_horizontal)
144
139
 
145
140
  # Avoid sending zero-sized crops
146
141
  is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -68,14 +68,14 @@ class OCRPredictor(nn.Module, _OCRPredictor):
68
68
  @torch.inference_mode()
69
69
  def forward(
70
70
  self,
71
- pages: list[np.ndarray | torch.Tensor],
71
+ pages: list[np.ndarray],
72
72
  **kwargs: Any,
73
73
  ) -> Document:
74
74
  # Dimension check
75
75
  if any(page.ndim != 3 for page in pages):
76
76
  raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
77
77
 
78
- origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
78
+ origin_page_shapes = [page.shape[:2] for page in pages]
79
79
 
80
80
  # Localize text elements
81
81
  loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -109,8 +109,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
109
109
  loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
110
110
  # Detach objectness scores from loc_preds
111
111
  loc_preds, objectness_scores = detach_scores(loc_preds)
112
- # Check whether crop mode should be switched to channels first
113
- channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
114
112
 
115
113
  # Apply hooks to loc_preds if any
116
114
  for hook in self.hooks:
@@ -120,7 +118,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
120
118
  crops, loc_preds = self._prepare_crops(
121
119
  pages,
122
120
  loc_preds,
123
- channels_last=channels_last,
124
121
  assume_straight_pages=self.assume_straight_pages,
125
122
  assume_horizontal=self._page_orientation_disabled,
126
123
  )
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -60,65 +60,60 @@ class PreProcessor(nn.Module):
60
60
 
61
61
  return batches
62
62
 
63
- def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
63
+ def sample_transforms(self, x: np.ndarray) -> torch.Tensor:
64
64
  if x.ndim != 3:
65
65
  raise AssertionError("expected list of 3D Tensors")
66
- if isinstance(x, np.ndarray):
67
- if x.dtype not in (np.uint8, np.float32):
68
- raise TypeError("unsupported data type for numpy.ndarray")
69
- x = torch.from_numpy(x.copy()).permute(2, 0, 1)
70
- elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
71
- raise TypeError("unsupported data type for torch.Tensor")
66
+ if x.dtype not in (np.uint8, np.float32, np.float16):
67
+ raise TypeError("unsupported data type for numpy.ndarray")
68
+ tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
72
69
  # Resizing
73
- x = self.resize(x)
70
+ tensor = self.resize(tensor)
74
71
  # Data type
75
- if x.dtype == torch.uint8:
76
- x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr]
72
+ if tensor.dtype == torch.uint8:
73
+ tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
77
74
  else:
78
- x = x.to(dtype=torch.float32) # type: ignore[union-attr]
75
+ tensor = tensor.to(dtype=torch.float32)
79
76
 
80
- return x # type: ignore[return-value]
77
+ return tensor
81
78
 
82
- def __call__(self, x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray]) -> list[torch.Tensor]:
79
+ def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
83
80
  """Prepare document data for model forwarding
84
81
 
85
82
  Args:
86
- x: list of images (np.array) or tensors (already resized and batched)
83
+ x: list of images (np.array) or a single image (np.array) of shape (H, W, C)
87
84
 
88
85
  Returns:
89
- list of page batches
86
+ list of page batches (*, C, H, W) ready for model inference
90
87
  """
91
88
  # Input type check
92
- if isinstance(x, (np.ndarray, torch.Tensor)):
89
+ if isinstance(x, np.ndarray):
93
90
  if x.ndim != 4:
94
91
  raise AssertionError("expected 4D Tensor")
95
- if isinstance(x, np.ndarray):
96
- if x.dtype not in (np.uint8, np.float32):
97
- raise TypeError("unsupported data type for numpy.ndarray")
98
- x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
99
- elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
100
- raise TypeError("unsupported data type for torch.Tensor")
92
+ if x.dtype not in (np.uint8, np.float32, np.float16):
93
+ raise TypeError("unsupported data type for numpy.ndarray")
94
+ tensor = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
95
+
101
96
  # Resizing
102
- if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: # type: ignore[union-attr]
103
- x = F.resize(
104
- x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
97
+ if tensor.shape[-2] != self.resize.size[0] or tensor.shape[-1] != self.resize.size[1]:
98
+ tensor = F.resize(
99
+ tensor, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
105
100
  )
106
101
  # Data type
107
- if x.dtype == torch.uint8: # type: ignore[union-attr]
108
- x = x.to(dtype=torch.float32).div(255).clip(0, 1) # type: ignore[union-attr]
102
+ if tensor.dtype == torch.uint8:
103
+ tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
109
104
  else:
110
- x = x.to(dtype=torch.float32) # type: ignore[union-attr]
111
- batches = [x]
105
+ tensor = tensor.to(dtype=torch.float32)
106
+ batches = [tensor]
112
107
 
113
- elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x):
108
+ elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
114
109
  # Sample transform (to tensor, resize)
115
110
  samples = list(multithread_exec(self.sample_transforms, x))
116
111
  # Batching
117
- batches = self.batch_inputs(samples) # type: ignore[assignment]
112
+ batches = self.batch_inputs(samples)
118
113
  else:
119
114
  raise TypeError(f"invalid input type: {type(x)}")
120
115
 
121
116
  # Batch transforms (normalize)
122
117
  batches = list(multithread_exec(self.normalize, batches))
123
118
 
124
- return batches # type: ignore[return-value]
119
+ return batches
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import * # type: ignore[assignment]
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -15,7 +15,7 @@ from torch.nn import functional as F
15
15
  from doctr.datasets import VOCABS, decode_sequence
16
16
 
17
17
  from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
18
- from ...utils.pytorch import load_pretrained_params
18
+ from ...utils import load_pretrained_params
19
19
  from ..core import RecognitionModel, RecognitionPostProcessor
20
20
 
21
21
  __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -25,8 +25,8 @@ default_cfgs: dict[str, dict[str, Any]] = {
25
25
  "mean": (0.694, 0.695, 0.693),
26
26
  "std": (0.299, 0.296, 0.301),
27
27
  "input_shape": (3, 32, 128),
28
- "vocab": VOCABS["legacy_french"],
29
- "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_vgg16_bn-9762b0b0.pt&src=0",
28
+ "vocab": VOCABS["french"],
29
+ "url": "https://doctr-static.mindee.com/models?id=v0.12.0/crnn_vgg16_bn-0417f351.pt&src=0",
30
30
  },
31
31
  "crnn_mobilenet_v3_small": {
32
32
  "mean": (0.694, 0.695, 0.693),
@@ -82,7 +82,7 @@ class CTCPostProcessor(RecognitionPostProcessor):
82
82
 
83
83
  def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
84
84
  """Performs decoding of raw output with CTC and decoding of CTC predictions
85
- with label_to_idx mapping dictionnary
85
+ with label_to_idx mapping dictionary
86
86
 
87
87
  Args:
88
88
  logits: raw output of the model, shape (N, C + 1, seq_len)
@@ -223,7 +223,7 @@ class CRNN(RecognitionModel, nn.Module):
223
223
 
224
224
  if target is None or return_preds:
225
225
  # Disable for torch.compile compatibility
226
- @torch.compiler.disable # type: ignore[attr-defined]
226
+ @torch.compiler.disable
227
227
  def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
228
228
  return self.postprocessor(logits)
229
229
 
@@ -257,7 +257,7 @@ def _crnn(
257
257
  _cfg["input_shape"] = kwargs["input_shape"]
258
258
 
259
259
  # Build the model
260
- model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
260
+ model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # type: ignore[arg-type]
261
261
  # Load pretrained parameters
262
262
  if pretrained:
263
263
  # The number of classes is not the same as the number of classes in the pretrained model =>
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
16
16
  from doctr.models.classification import magc_resnet31
17
17
  from doctr.models.modules.transformer import Decoder, PositionalEncoding
18
18
 
19
- from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
19
+ from ...utils import _bf16_to_float32, load_pretrained_params
20
20
  from .base import _MASTER, _MASTERPostProcessor
21
21
 
22
22
  __all__ = ["MASTER", "master"]
@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
107
107
  # NOTE: nn.TransformerDecoder takes the inverse from this implementation
108
108
  # [True, True, True, ..., False, False, False] -> False is masked
109
109
  # (N, 1, 1, max_length)
110
- target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1) # type: ignore[attr-defined]
110
+ target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
111
111
  target_length = target.size(1)
112
112
  # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
113
113
  # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -140,7 +140,7 @@ class MASTER(_MASTER, nn.Module):
140
140
  # Input length : number of timesteps
141
141
  input_len = model_output.shape[1]
142
142
  # Add one for additional <eos> token (sos disappear in shift!)
143
- seq_len = seq_len + 1 # type: ignore[assignment]
143
+ seq_len = seq_len + 1
144
144
  # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
145
145
  # The "masked" first gt char is <sos>. Delete last logit of the model output.
146
146
  cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -176,7 +176,7 @@ class MASTER(_MASTER, nn.Module):
176
176
  return_preds: if True, decode logits
177
177
 
178
178
  Returns:
179
- A dictionnary containing eventually loss, logits and predictions.
179
+ A dictionary containing eventually loss, logits and predictions.
180
180
  """
181
181
  # Encode
182
182
  features = self.feat_extractor(x)["features"]
@@ -219,7 +219,7 @@ class MASTER(_MASTER, nn.Module):
219
219
 
220
220
  if return_preds:
221
221
  # Disable for torch.compile compatibility
222
- @torch.compiler.disable # type: ignore[attr-defined]
222
+ @torch.compiler.disable
223
223
  def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
224
224
  return self.postprocessor(logits)
225
225
 
@@ -1,6 +1 @@
1
- from doctr.file_utils import is_tf_available, is_torch_available
2
-
3
- if is_torch_available():
4
- from .pytorch import *
5
- elif is_tf_available():
6
- from .tensorflow import *
1
+ from .pytorch import *
@@ -1,4 +1,4 @@
1
- # Copyright (C) 2021-2025, Mindee.
1
+ # Copyright (C) 2021-2026, Mindee.
2
2
 
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.