python-doctr 0.12.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/contrib/artefacts.py +1 -1
- doctr/contrib/base.py +1 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/coco_text.py +1 -1
- doctr/datasets/cord.py +1 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/base.py +1 -1
- doctr/datasets/datasets/pytorch.py +3 -3
- doctr/datasets/detection.py +1 -1
- doctr/datasets/doc_artefacts.py +1 -1
- doctr/datasets/funsd.py +1 -1
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/generator/base.py +1 -1
- doctr/datasets/generator/pytorch.py +1 -1
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +1 -1
- doctr/datasets/iiit5k.py +1 -1
- doctr/datasets/iiithws.py +1 -1
- doctr/datasets/imgur5k.py +1 -1
- doctr/datasets/mjsynth.py +1 -1
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/orientation.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/sroie.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +1 -1
- doctr/datasets/synthtext.py +1 -1
- doctr/datasets/utils.py +1 -1
- doctr/datasets/vocabs.py +1 -3
- doctr/datasets/wildreceipt.py +1 -1
- doctr/file_utils.py +3 -102
- doctr/io/elements.py +1 -1
- doctr/io/html.py +1 -1
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/base.py +1 -1
- doctr/io/image/pytorch.py +2 -2
- doctr/io/pdf.py +1 -1
- doctr/io/reader.py +1 -1
- doctr/models/_utils.py +56 -18
- doctr/models/builder.py +1 -1
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -3
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +1 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +1 -1
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +2 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +1 -1
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +2 -2
- doctr/models/classification/vip/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +3 -3
- doctr/models/classification/zoo.py +7 -12
- doctr/models/core.py +1 -1
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/_utils/base.py +1 -1
- doctr/models/detection/_utils/pytorch.py +1 -1
- doctr/models/detection/core.py +2 -2
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +5 -13
- doctr/models/detection/differentiable_binarization/pytorch.py +4 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +5 -15
- doctr/models/detection/fast/pytorch.py +5 -5
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +4 -13
- doctr/models/detection/linknet/pytorch.py +3 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +2 -2
- doctr/models/detection/zoo.py +16 -33
- doctr/models/factory/hub.py +26 -34
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/base.py +1 -1
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +4 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +3 -3
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +4 -9
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +28 -33
- doctr/models/recognition/core.py +1 -1
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +7 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/base.py +1 -1
- doctr/models/recognition/master/pytorch.py +6 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/base.py +1 -1
- doctr/models/recognition/parseq/pytorch.py +6 -6
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +8 -17
- doctr/models/recognition/predictor/pytorch.py +2 -3
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +4 -4
- doctr/models/recognition/utils.py +1 -1
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +4 -4
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/base.py +1 -1
- doctr/models/recognition/vitstr/pytorch.py +4 -4
- doctr/models/recognition/zoo.py +14 -14
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +3 -2
- doctr/models/zoo.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/base.py +3 -2
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +28 -94
- doctr/transforms/modules/pytorch.py +29 -27
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +1 -2
- doctr/utils/fonts.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/metrics.py +1 -1
- doctr/utils/multithreading.py +1 -1
- doctr/utils/reconstitution.py +1 -1
- doctr/utils/repr.py +1 -1
- doctr/utils/visualization.py +2 -2
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/METADATA +30 -80
- python_doctr-1.0.1.dist-info/RECORD +149 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.1.dist-info}/zip-safe +0 -0
doctr/models/factory/hub.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -7,26 +7,21 @@
 
 import json
 import logging
-import os
 import subprocess
+import tempfile
 import textwrap
 from pathlib import Path
 from typing import Any
 
+import torch
 from huggingface_hub import (
     HfApi,
-    Repository,
     get_token,
-    get_token_permission,
     hf_hub_download,
     login,
 )
 
 from doctr import models
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    import torch
 
 __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
 
@@ -41,9 +36,9 @@ AVAILABLE_ARCHS = {
 def login_to_hub() -> None:  # pragma: no cover
     """Login to huggingface hub"""
     access_token = get_token()
-    if access_token is not None
+    if access_token is not None:
         logging.info("Huggingface Hub token found and valid")
-        login(token=access_token
+        login(token=access_token)
     else:
         login()
     # check if git lfs is installed
@@ -61,19 +56,14 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
     """Save model and config to disk for pushing to huggingface hub
 
     Args:
-        model:
+        model: PyTorch model to be saved
         save_dir: directory to save model and config
         arch: architecture name
         task: task name
     """
     save_directory = Path(save_dir)
-
-    if is_torch_available():
-        weights_path = save_directory / "pytorch_model.bin"
-        torch.save(model.state_dict(), weights_path)
-    elif is_tf_available():
-        weights_path = save_directory / "tf_model.weights.h5"
-        model.save_weights(str(weights_path))
+    weights_path = save_directory / "pytorch_model.bin"
+    torch.save(model.state_dict(), weights_path)
 
     config_path = save_directory / "config.json"
 
@@ -96,7 +86,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
     >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
 
     Args:
-        model:
+        model: PyTorch model to be saved
         model_name: name of the model which is also the repository name
         task: task name
         **kwargs: keyword arguments for push_to_hf_hub
@@ -120,7 +110,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
     <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
 </p>
 
-**Optical Character Recognition made seamless & accessible to anyone, powered by
+**Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**
 
 ## Task: {task}
 
@@ -169,16 +159,23 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
 
     commit_message = f"Add {model_name} model"
 
-
-
-
+    # Create repository
+    api = HfApi()
+    api.create_repo(model_name, token=get_token(), exist_ok=False)
 
-
-
-
+    # Save model files to a temporary directory
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        _save_model_and_config_for_hf_hub(model, tmp_dir, arch=arch, task=task)
+        readme_path = Path(tmp_dir) / "README.md"
         readme_path.write_text(readme)
 
-
+        # Upload all files to the hub
+        api.upload_folder(
+            folder_path=tmp_dir,
+            repo_id=model_name,
+            commit_message=commit_message,
+            token=get_token(),
+        )
 
 
 def from_hub(repo_id: str, **kwargs: Any):
@@ -214,13 +211,8 @@ def from_hub(repo_id: str, **kwargs: Any):
 
     # update model cfg
     model.cfg = cfg
-
-
-    if is_torch_available():
-        weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
-    else:  # tf
-        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-
+    # load the weights
+    weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
    model.from_pretrained(weights)
 
    return model
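In short, push_to_hf_hub now stages everything in a temporary directory and pushes it in one commit via HfApi.upload_folder, replacing the removed git-based Repository workflow, and from_hub always fetches pytorch_model.bin. A minimal round-trip sketch, assuming a working Hugging Face token; the model name and username are placeholders:

from doctr.models import recognition
from doctr.models.factory.hub import from_hub, push_to_hf_hub

# Any doctr model works; here we reuse a pretrained recognition head
model = recognition.crnn_mobilenet_v3_small(pretrained=True)
push_to_hf_hub(model, "my-crnn", "recognition", arch="crnn_mobilenet_v3_small")

# Restoring is now torch-only: from_hub downloads pytorch_model.bin
restored = from_hub("<username>/my-crnn")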
doctr/models/kie_predictor/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -68,14 +68,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
 
-        origin_page_shapes = [page.shape[:2]
+        origin_page_shapes = [page.shape[:2] for page in pages]
 
         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -113,9 +113,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             dict_loc_preds[class_name] = _loc_preds
             objectness_scores[class_name] = _scores
 
-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
-
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
             dict_loc_preds = hook(dict_loc_preds)
@@ -126,7 +123,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
                 pages,
                 dict_loc_preds[class_name],
-                channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
                 assume_horizontal=self._page_orientation_disabled,
             )
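Since inputs are now plain numpy pages, calling the predictor stays simple. A short usage sketch; the zero-filled page is just a stand-in for a real document image:

import numpy as np
from doctr.models import kie_predictor

predictor = kie_predictor(pretrained=True)
# pages must be HxWxC numpy arrays; the channels_last switch is gone
pages = [np.zeros((1024, 768, 3), dtype=np.uint8)]
out = predictor(pages)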
doctr/models/modules/layers/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -151,16 +151,16 @@ class FASTConvLayer(nn.Module):
         id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
         self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
         kernel = self.id_tensor
-        std = (identity.running_var + identity.eps).sqrt()
+        std = (identity.running_var + identity.eps).sqrt()
         t = (identity.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, identity.bias - identity.running_mean * identity.weight / std
+        return kernel * t, identity.bias - identity.running_mean * identity.weight / std  # type: ignore[operator]
 
     def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
         kernel = conv.weight
         kernel = self._pad_to_mxn_tensor(kernel)
         std = (bn.running_var + bn.eps).sqrt()  # type: ignore
         t = (bn.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, bn.bias - bn.running_mean * bn.weight / std
+        return kernel * t, bn.bias - bn.running_mean * bn.weight / std  # type: ignore[operator]
 
     def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
         kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
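For context, the arithmetic this hunk touches is standard conv-BN fusion: fold the BatchNorm affine transform into the preceding convolution's kernel and bias. A self-contained sketch of the same math, not doctr's exact code:

import torch
from torch import nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
    # BN(conv(x)) = gamma * (W x + b - mean) / sqrt(var + eps) + beta,
    # so the fused kernel is W * gamma / std and the bias absorbs the rest.
    std = (bn.running_var + bn.eps).sqrt()
    scale = (bn.weight / std).reshape(-1, 1, 1, 1)
    kernel = conv.weight * scale
    bias = bn.bias - bn.running_mean * bn.weight / std
    if conv.bias is not None:
        bias = bias + conv.bias * bn.weight / std
    return kernel, bias

# Sanity check: the fused conv matches conv + BN in eval mode
conv, bn = nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8)
bn.eval()
x = torch.randn(1, 3, 16, 16)
k, b = fuse_conv_bn(conv, bn)
fused = torch.nn.functional.conv2d(x, k, b, padding=1)
assert torch.allclose(fused, bn(conv(x)), atol=1e-5)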
doctr/models/modules/transformer/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -50,8 +50,8 @@ def scaled_dot_product_attention(
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))
-    p_attn = torch.softmax(scores, dim=-1)
+        scores = scores.masked_fill(mask == 0, float("-inf"))
+    p_attn = torch.softmax(scores, dim=-1)
     return torch.matmul(p_attn, value), p_attn
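The hunk shows the whole attention kernel: scores = QK^T / sqrt(d_k), an optional integer mask, softmax, then a weighted sum of values. A quick shape check of the same math, self-contained rather than importing doctr:

import math
import torch

def sdpa(query, key, value, mask=None):
    # softmax(Q K^T / sqrt(d_k)) V, masking scores where mask == 0
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    p_attn = torch.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn

q = k = v = torch.randn(2, 8, 16, 32)             # (N, heads, seq, d_k)
mask = torch.ones(2, 1, 16, 16, dtype=torch.int)  # int mask, ONNX-friendly
out, attn = sdpa(q, k, v, mask)
assert out.shape == (2, 8, 16, 32) and attn.shape == (2, 8, 16, 16)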
doctr/models/predictor/base.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -116,18 +116,14 @@ class _OCRPredictor:
     def _generate_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> list[list[np.ndarray]]:
         if assume_straight_pages:
-            crops = [
-                extract_crops(page, _boxes[:, :4], channels_last=channels_last)
-                for page, _boxes in zip(pages, loc_preds)
-            ]
+            crops = [extract_crops(page, _boxes[:, :4]) for page, _boxes in zip(pages, loc_preds)]
         else:
             crops = [
-                extract_rcrops(page, _boxes[:, :4],
+                extract_rcrops(page, _boxes[:, :4], assume_horizontal=assume_horizontal)
                 for page, _boxes in zip(pages, loc_preds)
             ]
         return crops
@@ -136,11 +132,10 @@ class _OCRPredictor:
     def _prepare_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
-        crops = _OCRPredictor._generate_crops(pages, loc_preds,
+        crops = _OCRPredictor._generate_crops(pages, loc_preds, assume_straight_pages, assume_horizontal)
 
         # Avoid sending zero-sized crops
         is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
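With channels_last gone, straight-box cropping reduces to plain numpy slicing on HxWxC pages. A hypothetical stand-in for what extract_crops does, assuming doctr's relative (xmin, ymin, xmax, ymax) box convention; the helper name and details are illustrative, not the library's exact implementation:

import numpy as np

def extract_crops_sketch(page: np.ndarray, boxes: np.ndarray) -> list[np.ndarray]:
    # page: HxWxC image; boxes: (n, 4) relative (xmin, ymin, xmax, ymax)
    h, w = page.shape[:2]
    abs_boxes = (boxes * np.array([w, h, w, h])).round().astype(int)
    return [page[ymin:ymax, xmin:xmax] for xmin, ymin, xmax, ymax in abs_boxes]

page = np.zeros((100, 200, 3), dtype=np.uint8)
boxes = np.array([[0.1, 0.2, 0.5, 0.8]])
crops = extract_crops_sketch(page, boxes)
assert crops[0].shape == (60, 80, 3)  # (0.8-0.2)*100 rows x (0.5-0.1)*200 cols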
doctr/models/predictor/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -68,14 +68,14 @@ class OCRPredictor(nn.Module, _OCRPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
 
-        origin_page_shapes = [page.shape[:2]
+        origin_page_shapes = [page.shape[:2] for page in pages]
 
         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -109,8 +109,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
         # Detach objectness scores from loc_preds
         loc_preds, objectness_scores = detach_scores(loc_preds)
-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
@@ -120,7 +118,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         crops, loc_preds = self._prepare_crops(
             pages,
             loc_preds,
-            channels_last=channels_last,
             assume_straight_pages=self.assume_straight_pages,
             assume_horizontal=self._page_orientation_disabled,
         )
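From the caller's perspective, end-to-end usage is unchanged. A short sketch; the PDF path is a placeholder:

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

predictor = ocr_predictor(pretrained=True)
pages = DocumentFile.from_pdf("path/to/document.pdf")  # list of HxWxC numpy arrays
result = predictor(pages)
print(result.render())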
doctr/models/preprocessor/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -60,65 +60,60 @@ class PreProcessor(nn.Module):
 
         return batches
 
-    def sample_transforms(self, x: np.ndarray
+    def sample_transforms(self, x: np.ndarray) -> torch.Tensor:
         if x.ndim != 3:
             raise AssertionError("expected list of 3D Tensors")
-        if
-
-
-            x = torch.from_numpy(x.copy()).permute(2, 0, 1)
-        elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-            raise TypeError("unsupported data type for torch.Tensor")
+        if x.dtype not in (np.uint8, np.float32, np.float16):
+            raise TypeError("unsupported data type for numpy.ndarray")
+        tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
         # Resizing
-
+        tensor = self.resize(tensor)
         # Data type
-        if
-
+        if tensor.dtype == torch.uint8:
+            tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
         else:
-
+            tensor = tensor.to(dtype=torch.float32)
 
-        return
+        return tensor
 
-    def __call__(self, x:
+    def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
         """Prepare document data for model forwarding
 
         Args:
-            x: list of images (np.array) or
+            x: list of images (np.array) or a single image (np.array) of shape (H, W, C)
 
         Returns:
-            list of page batches
+            list of page batches (*, C, H, W) ready for model inference
         """
         # Input type check
-        if isinstance(x,
+        if isinstance(x, np.ndarray):
             if x.ndim != 4:
                 raise AssertionError("expected 4D Tensor")
-            if
-
-
-
-            elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-                raise TypeError("unsupported data type for torch.Tensor")
+            if x.dtype not in (np.uint8, np.float32, np.float16):
+                raise TypeError("unsupported data type for numpy.ndarray")
+            tensor = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
+
             # Resizing
-            if
-
-
+            if tensor.shape[-2] != self.resize.size[0] or tensor.shape[-1] != self.resize.size[1]:
+                tensor = F.resize(
+                    tensor, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
                 )
             # Data type
-            if
-
+            if tensor.dtype == torch.uint8:
+                tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
             else:
-
-            batches = [
+                tensor = tensor.to(dtype=torch.float32)
+            batches = [tensor]
 
-        elif isinstance(x, list) and all(isinstance(sample,
+        elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))
 
-        return batches
+        return batches
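The sample path now boils down to: HWC numpy in, resized CHW float tensor in [0, 1] out. The core dtype handling from the hunk, isolated into a runnable snippet (shapes chosen to match the CRNN input below):

import numpy as np
import torch

# uint8 HxWxC image -> float32 CxHxW tensor scaled to [0, 1]
img = np.random.randint(0, 256, (32, 128, 3), dtype=np.uint8)
tensor = torch.from_numpy(img.copy()).permute(2, 0, 1)
tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
assert tensor.shape == (3, 32, 128)
assert 0.0 <= tensor.min() and tensor.max() <= 1.0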
doctr/models/recognition/core.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
doctr/models/recognition/crnn/pytorch.py
CHANGED

@@ -15,7 +15,7 @@ from torch.nn import functional as F
 from doctr.datasets import VOCABS, decode_sequence
 
 from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils
+from ...utils import load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -25,8 +25,8 @@ default_cfgs: dict[str, dict[str, Any]] = {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
-        "vocab": VOCABS["
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.12.0/crnn_vgg16_bn-0417f351.pt&src=0",
     },
     "crnn_mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
@@ -82,7 +82,7 @@ class CTCPostProcessor(RecognitionPostProcessor):
 
     def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
         """Performs decoding of raw output with CTC and decoding of CTC predictions
-        with label_to_idx mapping
+        with label_to_idx mapping dictionary
 
         Args:
             logits: raw output of the model, shape (N, C + 1, seq_len)
@@ -223,7 +223,7 @@ class CRNN(RecognitionModel, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)
 
@@ -257,7 +257,7 @@ def _crnn(
     _cfg["input_shape"] = kwargs["input_shape"]
 
     # Build the model
-    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)  # type: ignore[arg-type]
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
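For readers unfamiliar with the postprocessor: CTC greedy decoding collapses repeated argmax symbols and drops the blank class. A self-contained sketch, not doctr's implementation; logits here are laid out (N, seq_len, C + 1) for simplicity, the blank is the last index, and the minimum step probability stands in for a word confidence:

import torch

def ctc_greedy_decode(logits: torch.Tensor, vocab: str) -> list[tuple[str, float]]:
    # logits: (N, seq_len, C + 1), blank = last class
    probs, indices = torch.softmax(logits, dim=-1).max(dim=-1)
    blank = len(vocab)
    results = []
    for idx_seq, prob_seq in zip(indices, probs):
        chars, prev = [], blank
        for idx in idx_seq.tolist():
            if idx != prev and idx != blank:  # collapse repeats, drop blanks
                chars.append(vocab[idx])
            prev = idx
        results.append(("".join(chars), prob_seq.min().item()))
    return results

logits = torch.randn(2, 32, len("abc") + 1)
print(ctc_greedy_decode(logits, "abc"))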
doctr/models/recognition/master/pytorch.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright (C) 2021-
+# Copyright (C) 2021-2026, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding
 
-from ...utils
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor
 
 __all__ = ["MASTER", "master"]
@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
         # NOTE: nn.TransformerDecoder takes the inverse from this implementation
         # [True, True, True, ..., False, False, False] -> False is masked
         # (N, 1, 1, max_length)
-        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
         target_length = target.size(1)
         # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
         # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -140,7 +140,7 @@ class MASTER(_MASTER, nn.Module):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>. Delete last logit of the model output.
         cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
@@ -176,7 +176,7 @@ class MASTER(_MASTER, nn.Module):
             return_preds: if True, decode logits
 
         Returns:
-            A
+            A dictionary containing eventually loss, logits and predictions.
         """
         # Encode
         features = self.feat_extractor(x)["features"]
@@ -219,7 +219,7 @@ class MASTER(_MASTER, nn.Module):
 
         if return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)
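The target_pad_mask hunk is half of a standard decoder mask; combined with the causal sub-mask described in the surrounding comments, the full construction looks roughly like this. The pad id vocab_size + 2 follows the hunk; a bool tril is used here for brevity, whereas the source notes onnxruntime needs float tril/triu:

import torch

def make_target_mask(target: torch.Tensor, pad_id: int) -> torch.Tensor:
    # (N, 1, 1, T): True for real tokens, False for padding
    pad_mask = (target != pad_id).unsqueeze(1).unsqueeze(1)
    t = target.size(1)
    # (T, T) causal sub-mask: position i may only attend to positions <= i
    causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=target.device))
    # broadcasts to (N, 1, T, T); False = masked, matching the NOTE above
    return pad_mask & causal

target = torch.tensor([[1, 5, 9, 102, 102]])  # 102 standing in for vocab_size + 2
mask = make_target_mask(target, pad_id=102)
assert mask.shape == (1, 1, 5, 5)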