python-doctr 0.12.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +0 -5
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/vocabs.py +0 -2
- doctr/file_utils.py +2 -101
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +3 -3
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +2 -2
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +1 -1
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +1 -1
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vip/__init__.py +1 -4
- doctr/models/classification/vip/layers/__init__.py +1 -4
- doctr/models/classification/vip/layers/pytorch.py +1 -1
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +2 -2
- doctr/models/classification/zoo.py +6 -11
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +4 -12
- doctr/models/detection/differentiable_binarization/pytorch.py +3 -3
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +4 -14
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +3 -12
- doctr/models/detection/linknet/pytorch.py +2 -2
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +8 -21
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +2 -6
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +3 -3
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +2 -5
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +6 -6
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +5 -5
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +5 -5
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +7 -16
- doctr/models/recognition/predictor/pytorch.py +1 -2
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +3 -3
- doctr/models/recognition/viptr/__init__.py +1 -4
- doctr/models/recognition/viptr/pytorch.py +3 -3
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +3 -3
- doctr/models/recognition/zoo.py +13 -13
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +1 -1
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/geometry.py +6 -10
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +18 -75
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -442
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -418
- doctr/models/classification/textnet/tensorflow.py +0 -275
- doctr/models/classification/vgg/tensorflow.py +0 -125
- doctr/models/classification/vit/tensorflow.py +0 -201
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -421
- doctr/models/detection/fast/tensorflow.py +0 -427
- doctr/models/detection/linknet/tensorflow.py +0 -377
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -317
- doctr/models/recognition/master/tensorflow.py +0 -320
- doctr/models/recognition/parseq/tensorflow.py +0 -516
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -423
- doctr/models/recognition/vitstr/tensorflow.py +0 -285
- doctr/models/utils/tensorflow.py +0 -189
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.12.0.dist-info/RECORD +0 -180
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/licenses/LICENSE +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.12.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/models/detection/linknet/base.py
CHANGED

@@ -156,14 +156,12 @@ class _LinkNet(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.

         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not

         Returns:
             the new formatted target and the mask

@@ -175,10 +173,8 @@ class _LinkNet(BaseModel):

         h: int
         w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)

         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)

@@ -237,11 +233,6 @@ class _LinkNet(BaseModel):
                 if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                     seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                     continue
-                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
+                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)

         return seg_target, seg_mask
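With the TensorFlow transpose gone, `build_target` always emits channels-first numpy arrays. A self-contained sketch of the remaining rasterization step, with an illustrative polygon and sizes (not taken from the package):

```python
import cv2
import numpy as np

# seg_target stays in channels-first layout (N, num_classes, H, W)
seg_target = np.zeros((1, 1, 64, 64), dtype=np.uint8)
poly = np.array([[10, 10], [50, 12], [48, 40], [12, 38]], dtype=np.int32)
cv2.fillPoly(seg_target[0, 0], [poly], 1.0)
assert seg_target.shape == (1, 1, 64, 64)  # no (0, 2, 3, 1) transpose anymore
```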
doctr/models/detection/linknet/pytorch.py
CHANGED

@@ -193,7 +193,7 @@ class LinkNet(nn.Module, _LinkNet):

         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))

@@ -230,7 +230,7 @@ class LinkNet(nn.Module, _LinkNet):
         Returns:
             A loss tensor
         """
-        _target, _mask = self.build_target(target, out_map.shape[1:], channels_last=False)  # type: ignore[arg-type]
+        _target, _mask = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]

         seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask)
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
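The decorator change above appears to be only a dropped type-ignore comment, but the pattern itself is worth noting: `@torch.compiler.disable` keeps numpy-heavy post-processing out of the compiled graph. A minimal sketch with a hypothetical post-processing function:

```python
import torch

@torch.compiler.disable
def _postprocess(prob_map: torch.Tensor) -> list[float]:
    # Runs eagerly even when invoked from compiled code
    return prob_map.sigmoid().flatten().tolist()

model = torch.compile(torch.nn.Linear(8, 4))
print(_postprocess(model(torch.rand(2, 8)))[:3])
```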
doctr/models/detection/predictor/pytorch.py
CHANGED

@@ -36,7 +36,7 @@ class DetectionPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
         return_maps: bool = False,
         **kwargs: Any,
     ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
doctr/models/detection/zoo.py
CHANGED

@@ -5,7 +5,7 @@

 from typing import Any

-from doctr.file_utils import is_tf_available, is_torch_available
+from doctr.models.utils import _CompiledModule

 from .. import detection
 from ..detection.fast import reparameterize

@@ -16,30 +16,17 @@ __all__ = ["detection_predictor"]

 ARCHS: list[str]

-
-if is_tf_available():
-    ARCHS = [
-        "db_resnet50",
-        "db_mobilenet_v3_large",
-        "linknet_resnet18",
-        "linknet_resnet34",
-        "linknet_resnet50",
-        "fast_tiny",
-        "fast_small",
-        "fast_base",
-    ]
-elif is_torch_available():
-    ARCHS = [
-        "db_resnet34",
-        "db_resnet50",
-        "db_mobilenet_v3_large",
-        "linknet_resnet18",
-        "linknet_resnet34",
-        "linknet_resnet50",
-        "fast_tiny",
-        "fast_small",
-        "fast_base",
-    ]
+ARCHS = [
+    "db_resnet34",
+    "db_resnet50",
+    "db_mobilenet_v3_large",
+    "linknet_resnet18",
+    "linknet_resnet34",
+    "linknet_resnet50",
+    "fast_tiny",
+    "fast_small",
+    "fast_base",
+]


 def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:

@@ -56,12 +43,8 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
         if isinstance(_model, detection.FAST):
             _model = reparameterize(_model)
     else:
-        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
-        if is_torch_available():
-            # Adding the type for torch compiled models to the allowed architectures
-            from doctr.models.utils import _CompiledModule
-
-            allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST, _CompiledModule]

     if not isinstance(arch, tuple(allowed_archs)):
         raise ValueError(f"unknown architecture: {type(arch)}")

@@ -76,7 +59,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 2)
     predictor = DetectionPredictor(
-        PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
+        PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
         _model,
     )
     return predictor
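The public `detection_predictor` entrypoint is unchanged by the flattened `ARCHS` list. A usage sketch under 1.0.0 (the architecture choice and the random page are illustrative; `pretrained=True` downloads weights):

```python
import numpy as np
from doctr.models import detection_predictor

predictor = detection_predictor("fast_base", pretrained=True, assume_straight_pages=True)
page = (np.random.rand(1024, 768, 3) * 255).astype(np.uint8)  # (H, W, C) uint8
preds = predictor([page])  # one dict of per-class box arrays per page
```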
doctr/models/factory/hub.py
CHANGED

@@ -13,6 +13,7 @@ import textwrap
 from pathlib import Path
 from typing import Any

+import torch
 from huggingface_hub import (
     HfApi,
     Repository,

@@ -23,10 +24,6 @@ from huggingface_hub import (
 )

 from doctr import models
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    import torch

 __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]

@@ -61,19 +58,14 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
     """Save model and config to disk for pushing to huggingface hub

     Args:
-        model: TF or PyTorch model to be saved
+        model: PyTorch model to be saved
         save_dir: directory to save model and config
         arch: architecture name
         task: task name
     """
     save_directory = Path(save_dir)
-
-    if is_torch_available():
-        weights_path = save_directory / "pytorch_model.bin"
-        torch.save(model.state_dict(), weights_path)
-    elif is_tf_available():
-        weights_path = save_directory / "tf_model.weights.h5"
-        model.save_weights(str(weights_path))
+    weights_path = save_directory / "pytorch_model.bin"
+    torch.save(model.state_dict(), weights_path)

     config_path = save_directory / "config.json"

@@ -96,7 +88,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:
     >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

     Args:
-        model: TF or PyTorch model to be saved
+        model: PyTorch model to be saved
         model_name: name of the model which is also the repository name
         task: task name
         **kwargs: keyword arguments for push_to_hf_hub

@@ -120,7 +112,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:
     <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
 </p>

-**Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
+**Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**

 ## Task: {task}

@@ -214,13 +206,8 @@ def from_hub(repo_id: str, **kwargs: Any):

     # update model cfg
     model.cfg = cfg
-
-    # load the weights
-    if is_torch_available():
-        weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
-    else:  # tf
-        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-
+    # load the weights
+    weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
     model.from_pretrained(weights)

     return model
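Weights now always travel as `pytorch_model.bin`, both when saving and in `from_hub`. A sketch of the round-trip, assuming a logged-in Hugging Face account; the repository name is hypothetical:

```python
from doctr.models import recognition
from doctr.models.factory import from_hub, push_to_hf_hub

model = recognition.crnn_vgg16_bn(pretrained=True)
push_to_hf_hub(model, model_name="my-crnn-vgg16-bn", task="recognition", arch="crnn_vgg16_bn")
reloaded = from_hub("my-user/my-crnn-vgg16-bn")  # fetches pytorch_model.bin
```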
doctr/models/kie_predictor/pytorch.py
CHANGED

@@ -68,14 +68,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

-        origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
+        origin_page_shapes = [page.shape[:2] for page in pages]

         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

@@ -113,9 +113,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             dict_loc_preds[class_name] = _loc_preds
             objectness_scores[class_name] = _scores

-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
-
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
             dict_loc_preds = hook(dict_loc_preds)

@@ -126,7 +123,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
                 pages,
                 dict_loc_preds[class_name],
-                channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
                 assume_horizontal=self._page_orientation_disabled,
             )
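As with the OCR predictor further down, KIE pages are now plain channels-last numpy arrays, so the crop-mode probe disappears. A usage sketch with a random page:

```python
import numpy as np
from doctr.models import kie_predictor

predictor = kie_predictor(pretrained=True)
page = (np.random.rand(896, 640, 3) * 255).astype(np.uint8)  # (H, W, C)
out = predictor([page])  # predictions grouped by class name
```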
doctr/models/modules/layers/pytorch.py
CHANGED

@@ -151,16 +151,16 @@ class FASTConvLayer(nn.Module):
             id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
             self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
         kernel = self.id_tensor
-        std = (identity.running_var + identity.eps).sqrt()  # type: ignore
+        std = (identity.running_var + identity.eps).sqrt()
         t = (identity.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, identity.bias - identity.running_mean * identity.weight / std
+        return kernel * t, identity.bias - identity.running_mean * identity.weight / std  # type: ignore[operator]

     def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
         kernel = conv.weight
         kernel = self._pad_to_mxn_tensor(kernel)
         std = (bn.running_var + bn.eps).sqrt()  # type: ignore
         t = (bn.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, bn.bias - bn.running_mean * bn.weight / std
+        return kernel * t, bn.bias - bn.running_mean * bn.weight / std  # type: ignore[operator]

     def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
         kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
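Both `_fuse_bn_tensor` variants fold a BatchNorm into the preceding kernel: w' = w * gamma / sqrt(var + eps) and b' = beta - mean * gamma / sqrt(var + eps). A self-contained check of that identity with arbitrary layer sizes:

```python
import torch
from torch import nn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()

std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
fused_w = conv.weight * t
fused_b = bn.bias - bn.running_mean * bn.weight / std

x = torch.rand(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), nn.functional.conv2d(x, fused_w, fused_b, padding=1), atol=1e-5)
```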
doctr/models/modules/transformer/pytorch.py
CHANGED

@@ -50,8 +50,8 @@ def scaled_dot_product_attention(
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))
-    p_attn = torch.softmax(scores, dim=-1)
+        scores = scores.masked_fill(mask == 0, float("-inf"))
+    p_attn = torch.softmax(scores, dim=-1)
     return torch.matmul(p_attn, value), p_attn

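The hunk above is formatting-only, but it documents the masking convention: an integer-valued mask compared against 0 (for ONNX export), with masked scores pushed to -inf before the softmax. A standalone sketch:

```python
import math

import torch

q = k = v = torch.rand(1, 4, 10, 16)  # (N, heads, seq_len, head_dim)
mask = torch.ones(1, 1, 10, 10, dtype=torch.uint8).tril()  # causal, int-valued
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(mask == 0, float("-inf"))
p_attn = torch.softmax(scores, dim=-1)  # masked positions get zero weight
out = torch.matmul(p_attn, v)
```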
doctr/models/predictor/base.py
CHANGED

@@ -116,18 +116,14 @@ class _OCRPredictor:
     def _generate_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> list[list[np.ndarray]]:
         if assume_straight_pages:
-            crops = [
-                extract_crops(page, _boxes[:, :4], channels_last=channels_last)
-                for page, _boxes in zip(pages, loc_preds)
-            ]
+            crops = [extract_crops(page, _boxes[:, :4]) for page, _boxes in zip(pages, loc_preds)]
         else:
             crops = [
-                extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
+                extract_rcrops(page, _boxes[:, :4], assume_horizontal=assume_horizontal)
                 for page, _boxes in zip(pages, loc_preds)
             ]
         return crops

@@ -136,11 +132,10 @@ class _OCRPredictor:
     def _prepare_crops(
         pages: list[np.ndarray],
         loc_preds: list[np.ndarray],
-        channels_last: bool,
         assume_straight_pages: bool = False,
         assume_horizontal: bool = False,
     ) -> tuple[list[list[np.ndarray]], list[np.ndarray]]:
-        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
+        crops = _OCRPredictor._generate_crops(pages, loc_preds, assume_straight_pages, assume_horizontal)

         # Avoid sending zero-sized crops
         is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
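With `channels_last` gone, straight-box cropping reduces to numpy slicing on a (H, W, C) page; roughly what `extract_crops` does per box. Coordinates below are illustrative:

```python
import numpy as np

page = (np.random.rand(200, 300, 3) * 255).astype(np.uint8)
box = np.array([40, 50, 120, 90])  # xmin, ymin, xmax, ymax in pixels
crop = page[box[1] : box[3], box[0] : box[2]]  # rows are y, columns are x
assert crop.shape == (40, 80, 3)
```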
doctr/models/predictor/pytorch.py
CHANGED

@@ -68,14 +68,14 @@ class OCRPredictor(nn.Module, _OCRPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")

-        origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
+        origin_page_shapes = [page.shape[:2] for page in pages]

         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

@@ -109,8 +109,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
         # Detach objectness scores from loc_preds
         loc_preds, objectness_scores = detach_scores(loc_preds)
-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)

         # Apply hooks to loc_preds if any
         for hook in self.hooks:

@@ -120,7 +118,6 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         crops, loc_preds = self._prepare_crops(
             pages,
             loc_preds,
-            channels_last=channels_last,
             assume_straight_pages=self.assume_straight_pages,
             assume_horizontal=self._page_orientation_disabled,
         )
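Where 0.12.0 also tolerated channels-first `torch.Tensor` pages, 1.0.0 standardizes on channels-last numpy input. A usage sketch:

```python
import numpy as np
from doctr.models import ocr_predictor

predictor = ocr_predictor(pretrained=True)
pages = [(np.random.rand(896, 640, 3) * 255).astype(np.uint8)]  # numpy only now
document = predictor(pages)
print(document.render())
```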
doctr/models/preprocessor/pytorch.py
CHANGED

@@ -60,65 +60,60 @@ class PreProcessor(nn.Module):

         return batches

-    def sample_transforms(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
+    def sample_transforms(self, x: np.ndarray) -> torch.Tensor:
         if x.ndim != 3:
             raise AssertionError("expected list of 3D Tensors")
-        if isinstance(x, np.ndarray):
-            if x.dtype not in (np.uint8, np.float32, np.float16):
-                raise TypeError("unsupported data type for numpy.ndarray")
-            x = torch.from_numpy(x.copy()).permute(2, 0, 1)
-        elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-            raise TypeError("unsupported data type for torch.Tensor")
+        if x.dtype not in (np.uint8, np.float32, np.float16):
+            raise TypeError("unsupported data type for numpy.ndarray")
+        tensor = torch.from_numpy(x.copy()).permute(2, 0, 1)
         # Resizing
-        x = self.resize(x)
+        tensor = self.resize(tensor)
         # Data type
-        if x.dtype == torch.uint8:
-            x = x.to(dtype=torch.float32).div(255).clip(0, 1)
+        if tensor.dtype == torch.uint8:
+            tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
         else:
-            x = x.to(dtype=torch.float32)
+            tensor = tensor.to(dtype=torch.float32)

-        return x
+        return tensor

-    def __call__(self, x: torch.Tensor | np.ndarray | list[torch.Tensor | np.ndarray]) -> list[torch.Tensor]:
+    def __call__(self, x: np.ndarray | list[np.ndarray]) -> list[torch.Tensor]:
         """Prepare document data for model forwarding

         Args:
-            x: list of images (np.array) or tensors (already resized and batched)
+            x: list of images (np.array) or a single image (np.array) of shape (H, W, C)

         Returns:
-            list of page batches
+            list of page batches (*, C, H, W) ready for model inference
         """
         # Input type check
-        if isinstance(x, (np.ndarray, torch.Tensor)):
+        if isinstance(x, np.ndarray):
             if x.ndim != 4:
                 raise AssertionError("expected 4D Tensor")
-            if isinstance(x, np.ndarray):
-                if x.dtype not in (np.uint8, np.float32, np.float16):
-                    raise TypeError("unsupported data type for numpy.ndarray")
-                x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
-            elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
-                raise TypeError("unsupported data type for torch.Tensor")
+            if x.dtype not in (np.uint8, np.float32, np.float16):
+                raise TypeError("unsupported data type for numpy.ndarray")
+            tensor = torch.from_numpy(x.copy()).permute(0, 3, 1, 2)
+
             # Resizing
-            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
-                x = F.resize(
-                    x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
+            if tensor.shape[-2] != self.resize.size[0] or tensor.shape[-1] != self.resize.size[1]:
+                tensor = F.resize(
+                    tensor, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
                 )
             # Data type
-            if x.dtype == torch.uint8:
-                x = x.to(dtype=torch.float32).div(255).clip(0, 1)
+            if tensor.dtype == torch.uint8:
+                tensor = tensor.to(dtype=torch.float32).div(255).clip(0, 1)
             else:
-                x = x.to(dtype=torch.float32)
-            batches = [x]
+                tensor = tensor.to(dtype=torch.float32)
+            batches = [tensor]

-        elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x):
+        elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")

         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))

-        return batches
+        return batches
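The rewritten `PreProcessor` converts, resizes, rescales uint8 pages to [0, 1] and then normalizes, all from numpy input. A sketch of that contract (output size and batch size are illustrative):

```python
import numpy as np
from doctr.models.preprocessor import PreProcessor

pre = PreProcessor(output_size=(1024, 1024), batch_size=2)
pages = [(np.random.rand(768, 512, 3) * 255).astype(np.uint8) for _ in range(3)]
batches = pre(pages)  # list of float32 tensors shaped (*, 3, 1024, 1024)
assert all(b.dtype.is_floating_point for b in batches)
```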
doctr/models/recognition/crnn/pytorch.py
CHANGED

@@ -15,7 +15,7 @@ from torch.nn import functional as F
 from doctr.datasets import VOCABS, decode_sequence

 from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils.pytorch import load_pretrained_params
+from ...utils import load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor

 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]

@@ -25,8 +25,8 @@ default_cfgs: dict[str, dict[str, Any]] = {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 128),
-        "vocab": VOCABS["legacy_french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.3.1/crnn_vgg16_bn-9762b0b0.pt&src=0",
+        "vocab": VOCABS["french"],
+        "url": "https://doctr-static.mindee.com/models?id=v0.12.0/crnn_vgg16_bn-0417f351.pt&src=0",
     },
     "crnn_mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),

@@ -82,7 +82,7 @@ class CTCPostProcessor(RecognitionPostProcessor):

     def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
         """Performs decoding of raw output with CTC and decoding of CTC predictions
-        with label_to_idx mapping
+        with label_to_idx mapping dictionary

         Args:
             logits: raw output of the model, shape (N, C + 1, seq_len)

@@ -223,7 +223,7 @@ class CRNN(RecognitionModel, nn.Module):

         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)

@@ -257,7 +257,7 @@ def _crnn(
     _cfg["input_shape"] = kwargs["input_shape"]

     # Build the model
-    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+    model = CRNN(feat_extractor, cfg=_cfg, **kwargs)  # type: ignore[arg-type]
     # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
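Because `crnn_vgg16_bn` swaps its default vocab from `legacy_french` to `french` and points at a retrained v0.12.0 checkpoint, recognition output is not byte-for-byte comparable across the upgrade. A usage sketch:

```python
import numpy as np
from doctr.models import recognition_predictor

reco = recognition_predictor("crnn_vgg16_bn", pretrained=True)
crop = (np.random.rand(32, 128, 3) * 255).astype(np.uint8)  # one word crop
print(reco([crop]))  # [(text, confidence)]
```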
doctr/models/recognition/master/pytorch.py
CHANGED

@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding

-from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
+from ...utils import _bf16_to_float32, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor

 __all__ = ["MASTER", "master"]

@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
         # NOTE: nn.TransformerDecoder takes the inverse from this implementation
         # [True, True, True, ..., False, False, False] -> False is masked
         # (N, 1, 1, max_length)
-        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
         target_length = target.size(1)
         # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
         # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)

@@ -140,7 +140,7 @@ class MASTER(_MASTER, nn.Module):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>. Delete last logit of the model output.
         cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")

@@ -176,7 +176,7 @@ class MASTER(_MASTER, nn.Module):
             return_preds: if True, decode logits

         Returns:
-            A dictionnary containing eventually loss, logits and predictions.
+            A dictionary containing eventually loss, logits and predictions.
         """
         # Encode
         features = self.feat_extractor(x)["features"]

@@ -219,7 +219,7 @@ class MASTER(_MASTER, nn.Module):

         if return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]:
                 return self.postprocessor(logits)

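The `compute_loss` context above relies on the usual teacher-forcing shift: the logit at step t is scored against gt[t+1] and the last logit is dropped. A minimal reproduction of that slicing, with arbitrary sizes:

```python
import torch
from torch.nn import functional as F

vocab_size, seq_len = 10, 6
model_output = torch.rand(2, seq_len, vocab_size)  # (N, T, C)
gt = torch.randint(0, vocab_size, (2, seq_len))
# Drop the last logit, shift the ground truth left by one
cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
assert cce.shape == (2, seq_len - 1)
```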