python-doctr 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/contrib/__init__.py +1 -0
- doctr/contrib/artefacts.py +7 -9
- doctr/contrib/base.py +8 -17
- doctr/datasets/cord.py +8 -7
- doctr/datasets/datasets/__init__.py +4 -4
- doctr/datasets/datasets/base.py +16 -16
- doctr/datasets/datasets/pytorch.py +12 -12
- doctr/datasets/datasets/tensorflow.py +10 -10
- doctr/datasets/detection.py +6 -9
- doctr/datasets/doc_artefacts.py +3 -4
- doctr/datasets/funsd.py +7 -6
- doctr/datasets/generator/__init__.py +4 -4
- doctr/datasets/generator/base.py +16 -17
- doctr/datasets/generator/pytorch.py +1 -3
- doctr/datasets/generator/tensorflow.py +1 -3
- doctr/datasets/ic03.py +4 -5
- doctr/datasets/ic13.py +4 -5
- doctr/datasets/iiit5k.py +6 -5
- doctr/datasets/iiithws.py +4 -5
- doctr/datasets/imgur5k.py +6 -5
- doctr/datasets/loader.py +4 -7
- doctr/datasets/mjsynth.py +6 -5
- doctr/datasets/ocr.py +3 -4
- doctr/datasets/orientation.py +3 -4
- doctr/datasets/recognition.py +3 -4
- doctr/datasets/sroie.py +6 -5
- doctr/datasets/svhn.py +6 -5
- doctr/datasets/svt.py +4 -5
- doctr/datasets/synthtext.py +4 -5
- doctr/datasets/utils.py +34 -29
- doctr/datasets/vocabs.py +17 -7
- doctr/datasets/wildreceipt.py +14 -10
- doctr/file_utils.py +2 -7
- doctr/io/elements.py +59 -79
- doctr/io/html.py +1 -3
- doctr/io/image/__init__.py +3 -3
- doctr/io/image/base.py +2 -5
- doctr/io/image/pytorch.py +3 -12
- doctr/io/image/tensorflow.py +2 -11
- doctr/io/pdf.py +5 -7
- doctr/io/reader.py +5 -11
- doctr/models/_utils.py +14 -22
- doctr/models/builder.py +30 -48
- doctr/models/classification/magc_resnet/__init__.py +3 -3
- doctr/models/classification/magc_resnet/pytorch.py +10 -13
- doctr/models/classification/magc_resnet/tensorflow.py +8 -11
- doctr/models/classification/mobilenet/__init__.py +3 -3
- doctr/models/classification/mobilenet/pytorch.py +5 -17
- doctr/models/classification/mobilenet/tensorflow.py +8 -21
- doctr/models/classification/predictor/__init__.py +4 -4
- doctr/models/classification/predictor/pytorch.py +6 -8
- doctr/models/classification/predictor/tensorflow.py +6 -8
- doctr/models/classification/resnet/__init__.py +4 -4
- doctr/models/classification/resnet/pytorch.py +21 -31
- doctr/models/classification/resnet/tensorflow.py +20 -31
- doctr/models/classification/textnet/__init__.py +3 -3
- doctr/models/classification/textnet/pytorch.py +10 -17
- doctr/models/classification/textnet/tensorflow.py +8 -15
- doctr/models/classification/vgg/__init__.py +3 -3
- doctr/models/classification/vgg/pytorch.py +5 -7
- doctr/models/classification/vgg/tensorflow.py +9 -12
- doctr/models/classification/vit/__init__.py +3 -3
- doctr/models/classification/vit/pytorch.py +8 -14
- doctr/models/classification/vit/tensorflow.py +6 -12
- doctr/models/classification/zoo.py +19 -14
- doctr/models/core.py +3 -3
- doctr/models/detection/_utils/__init__.py +4 -4
- doctr/models/detection/_utils/base.py +4 -7
- doctr/models/detection/_utils/pytorch.py +1 -5
- doctr/models/detection/_utils/tensorflow.py +1 -5
- doctr/models/detection/core.py +2 -8
- doctr/models/detection/differentiable_binarization/__init__.py +4 -4
- doctr/models/detection/differentiable_binarization/base.py +7 -17
- doctr/models/detection/differentiable_binarization/pytorch.py +27 -30
- doctr/models/detection/differentiable_binarization/tensorflow.py +15 -25
- doctr/models/detection/fast/__init__.py +4 -4
- doctr/models/detection/fast/base.py +6 -14
- doctr/models/detection/fast/pytorch.py +24 -31
- doctr/models/detection/fast/tensorflow.py +14 -26
- doctr/models/detection/linknet/__init__.py +4 -4
- doctr/models/detection/linknet/base.py +6 -15
- doctr/models/detection/linknet/pytorch.py +24 -27
- doctr/models/detection/linknet/tensorflow.py +14 -23
- doctr/models/detection/predictor/__init__.py +5 -5
- doctr/models/detection/predictor/pytorch.py +6 -7
- doctr/models/detection/predictor/tensorflow.py +5 -6
- doctr/models/detection/zoo.py +27 -7
- doctr/models/factory/hub.py +3 -7
- doctr/models/kie_predictor/__init__.py +5 -5
- doctr/models/kie_predictor/base.py +4 -5
- doctr/models/kie_predictor/pytorch.py +18 -19
- doctr/models/kie_predictor/tensorflow.py +13 -14
- doctr/models/modules/layers/__init__.py +3 -3
- doctr/models/modules/layers/pytorch.py +6 -9
- doctr/models/modules/layers/tensorflow.py +5 -7
- doctr/models/modules/transformer/__init__.py +3 -3
- doctr/models/modules/transformer/pytorch.py +12 -13
- doctr/models/modules/transformer/tensorflow.py +9 -10
- doctr/models/modules/vision_transformer/__init__.py +3 -3
- doctr/models/modules/vision_transformer/pytorch.py +2 -3
- doctr/models/modules/vision_transformer/tensorflow.py +3 -3
- doctr/models/predictor/__init__.py +5 -5
- doctr/models/predictor/base.py +28 -29
- doctr/models/predictor/pytorch.py +12 -13
- doctr/models/predictor/tensorflow.py +8 -9
- doctr/models/preprocessor/__init__.py +4 -4
- doctr/models/preprocessor/pytorch.py +13 -17
- doctr/models/preprocessor/tensorflow.py +10 -14
- doctr/models/recognition/core.py +3 -7
- doctr/models/recognition/crnn/__init__.py +4 -4
- doctr/models/recognition/crnn/pytorch.py +20 -28
- doctr/models/recognition/crnn/tensorflow.py +11 -23
- doctr/models/recognition/master/__init__.py +3 -3
- doctr/models/recognition/master/base.py +3 -7
- doctr/models/recognition/master/pytorch.py +22 -24
- doctr/models/recognition/master/tensorflow.py +12 -22
- doctr/models/recognition/parseq/__init__.py +3 -3
- doctr/models/recognition/parseq/base.py +3 -7
- doctr/models/recognition/parseq/pytorch.py +26 -26
- doctr/models/recognition/parseq/tensorflow.py +16 -22
- doctr/models/recognition/predictor/__init__.py +5 -5
- doctr/models/recognition/predictor/_utils.py +7 -10
- doctr/models/recognition/predictor/pytorch.py +6 -6
- doctr/models/recognition/predictor/tensorflow.py +5 -6
- doctr/models/recognition/sar/__init__.py +4 -4
- doctr/models/recognition/sar/pytorch.py +20 -21
- doctr/models/recognition/sar/tensorflow.py +12 -21
- doctr/models/recognition/utils.py +5 -10
- doctr/models/recognition/vitstr/__init__.py +4 -4
- doctr/models/recognition/vitstr/base.py +3 -7
- doctr/models/recognition/vitstr/pytorch.py +18 -20
- doctr/models/recognition/vitstr/tensorflow.py +12 -20
- doctr/models/recognition/zoo.py +22 -11
- doctr/models/utils/__init__.py +4 -4
- doctr/models/utils/pytorch.py +14 -17
- doctr/models/utils/tensorflow.py +17 -16
- doctr/models/zoo.py +1 -5
- doctr/transforms/functional/__init__.py +3 -3
- doctr/transforms/functional/base.py +4 -11
- doctr/transforms/functional/pytorch.py +20 -28
- doctr/transforms/functional/tensorflow.py +10 -22
- doctr/transforms/modules/__init__.py +4 -4
- doctr/transforms/modules/base.py +48 -55
- doctr/transforms/modules/pytorch.py +58 -22
- doctr/transforms/modules/tensorflow.py +18 -32
- doctr/utils/common_types.py +8 -9
- doctr/utils/data.py +8 -12
- doctr/utils/fonts.py +2 -7
- doctr/utils/geometry.py +16 -47
- doctr/utils/metrics.py +17 -37
- doctr/utils/multithreading.py +4 -6
- doctr/utils/reconstitution.py +9 -13
- doctr/utils/repr.py +2 -3
- doctr/utils/visualization.py +16 -29
- doctr/version.py +1 -1
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/METADATA +54 -52
- python_doctr-0.11.0.dist-info/RECORD +173 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/WHEEL +1 -1
- python_doctr-0.10.0.dist-info/RECORD +0 -173
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.10.0.dist-info → python_doctr-0.11.0.dist-info}/zip-safe +0 -0
|
@@ -1,11 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
|
-
from typing import Dict, List, Tuple, Union
|
|
9
8
|
|
|
10
9
|
import cv2
|
|
11
10
|
import numpy as np
|
|
@@ -23,7 +22,6 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
23
22
|
"""Implements a post processor for LinkNet model.
|
|
24
23
|
|
|
25
24
|
Args:
|
|
26
|
-
----
|
|
27
25
|
bin_thresh: threshold used to binzarized p_map at inference time
|
|
28
26
|
box_thresh: minimal objectness score to consider a box
|
|
29
27
|
assume_straight_pages: whether the inputs were expected to have horizontal text elements
|
|
@@ -45,11 +43,9 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
45
43
|
"""Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
|
|
46
44
|
|
|
47
45
|
Args:
|
|
48
|
-
----
|
|
49
46
|
points: The first parameter.
|
|
50
47
|
|
|
51
48
|
Returns:
|
|
52
|
-
-------
|
|
53
49
|
a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
|
|
54
50
|
"""
|
|
55
51
|
if not self.assume_straight_pages:
|
|
@@ -94,24 +90,22 @@ class LinkNetPostProcessor(DetectionPostProcessor):
|
|
|
94
90
|
"""Compute boxes from a bitmap/pred_map: find connected components then filter boxes
|
|
95
91
|
|
|
96
92
|
Args:
|
|
97
|
-
----
|
|
98
93
|
pred: Pred map from differentiable linknet output
|
|
99
94
|
bitmap: Bitmap map computed from pred (binarized)
|
|
100
95
|
angle_tol: Comparison tolerance of the angle with the median angle across the page
|
|
101
96
|
ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop
|
|
102
97
|
|
|
103
98
|
Returns:
|
|
104
|
-
-------
|
|
105
99
|
np tensor boxes for the bitmap, each box is a 6-element list
|
|
106
100
|
containing x, y, w, h, alpha, score for the box
|
|
107
101
|
"""
|
|
108
102
|
height, width = bitmap.shape[:2]
|
|
109
|
-
boxes:
|
|
103
|
+
boxes: list[np.ndarray | list[float]] = []
|
|
110
104
|
# get contours from connected components on the bitmap
|
|
111
105
|
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
112
106
|
for contour in contours:
|
|
113
107
|
# Check whether smallest enclosing bounding box is not too small
|
|
114
|
-
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
108
|
+
if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
|
|
115
109
|
continue
|
|
116
110
|
# Compute objectness
|
|
117
111
|
if self.assume_straight_pages:
|
|
@@ -152,7 +146,6 @@ class _LinkNet(BaseModel):
|
|
|
152
146
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
153
147
|
|
|
154
148
|
Args:
|
|
155
|
-
----
|
|
156
149
|
out_chan: number of channels for the output
|
|
157
150
|
"""
|
|
158
151
|
|
|
@@ -162,20 +155,18 @@ class _LinkNet(BaseModel):
|
|
|
162
155
|
|
|
163
156
|
def build_target(
|
|
164
157
|
self,
|
|
165
|
-
target:
|
|
166
|
-
output_shape:
|
|
158
|
+
target: list[dict[str, np.ndarray]],
|
|
159
|
+
output_shape: tuple[int, int, int],
|
|
167
160
|
channels_last: bool = True,
|
|
168
|
-
) ->
|
|
161
|
+
) -> tuple[np.ndarray, np.ndarray]:
|
|
169
162
|
"""Build the target, and it's mask to be used from loss computation.
|
|
170
163
|
|
|
171
164
|
Args:
|
|
172
|
-
----
|
|
173
165
|
target: target coming from dataset
|
|
174
166
|
output_shape: shape of the output of the model without batch_size
|
|
175
167
|
channels_last: whether channels are last or not
|
|
176
168
|
|
|
177
169
|
Returns:
|
|
178
|
-
-------
|
|
179
170
|
the new formatted target and the mask
|
|
180
171
|
"""
|
|
181
172
|
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
|
|
@@ -1,9 +1,10 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import torch
|
|
@@ -20,7 +21,7 @@ from .base import LinkNetPostProcessor, _LinkNet
|
|
|
20
21
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
default_cfgs:
|
|
24
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
24
25
|
"linknet_resnet18": {
|
|
25
26
|
"input_shape": (3, 1024, 1024),
|
|
26
27
|
"mean": (0.798, 0.785, 0.772),
|
|
@@ -43,7 +44,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
|
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
class LinkNetFPN(nn.Module):
|
|
46
|
-
def __init__(self, layer_shapes:
|
|
47
|
+
def __init__(self, layer_shapes: list[tuple[int, int, int]]) -> None:
|
|
47
48
|
super().__init__()
|
|
48
49
|
strides = [
|
|
49
50
|
1 if (in_shape[-1] == out_shape[-1]) else 2
|
|
@@ -74,7 +75,7 @@ class LinkNetFPN(nn.Module):
|
|
|
74
75
|
nn.ReLU(inplace=True),
|
|
75
76
|
)
|
|
76
77
|
|
|
77
|
-
def forward(self, feats:
|
|
78
|
+
def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
|
|
78
79
|
out = feats[-1]
|
|
79
80
|
for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
|
|
80
81
|
out = decoder(out) + fmap
|
|
@@ -89,7 +90,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
89
90
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
90
91
|
|
|
91
92
|
Args:
|
|
92
|
-
----
|
|
93
93
|
feature extractor: the backbone serving as feature extractor
|
|
94
94
|
bin_thresh: threshold for binarization of the output feature map
|
|
95
95
|
box_thresh: minimal objectness score to consider a box
|
|
@@ -108,8 +108,8 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
108
108
|
head_chans: int = 32,
|
|
109
109
|
assume_straight_pages: bool = True,
|
|
110
110
|
exportable: bool = False,
|
|
111
|
-
cfg:
|
|
112
|
-
class_names:
|
|
111
|
+
cfg: dict[str, Any] | None = None,
|
|
112
|
+
class_names: list[str] = [CLASS_NAME],
|
|
113
113
|
) -> None:
|
|
114
114
|
super().__init__()
|
|
115
115
|
self.class_names = class_names
|
|
@@ -163,16 +163,16 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
163
163
|
def forward(
|
|
164
164
|
self,
|
|
165
165
|
x: torch.Tensor,
|
|
166
|
-
target:
|
|
166
|
+
target: list[np.ndarray] | None = None,
|
|
167
167
|
return_model_output: bool = False,
|
|
168
168
|
return_preds: bool = False,
|
|
169
169
|
**kwargs: Any,
|
|
170
|
-
) ->
|
|
170
|
+
) -> dict[str, Any]:
|
|
171
171
|
feats = self.feat_extractor(x)
|
|
172
172
|
logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
|
|
173
173
|
logits = self.classifier(logits)
|
|
174
174
|
|
|
175
|
-
out:
|
|
175
|
+
out: dict[str, Any] = {}
|
|
176
176
|
if self.exportable:
|
|
177
177
|
out["logits"] = logits
|
|
178
178
|
return out
|
|
@@ -183,11 +183,16 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
183
183
|
out["out_map"] = prob_map
|
|
184
184
|
|
|
185
185
|
if target is None or return_preds:
|
|
186
|
-
#
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
186
|
+
# Disable for torch.compile compatibility
|
|
187
|
+
@torch.compiler.disable # type: ignore[attr-defined]
|
|
188
|
+
def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
|
|
189
|
+
return [
|
|
190
|
+
dict(zip(self.class_names, preds))
|
|
191
|
+
for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
|
|
192
|
+
]
|
|
193
|
+
|
|
194
|
+
# Post-process boxes (keep only text predictions)
|
|
195
|
+
out["preds"] = _postprocess(prob_map)
|
|
191
196
|
|
|
192
197
|
if target is not None:
|
|
193
198
|
loss = self.compute_loss(logits, target)
|
|
@@ -198,7 +203,7 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
198
203
|
def compute_loss(
|
|
199
204
|
self,
|
|
200
205
|
out_map: torch.Tensor,
|
|
201
|
-
target:
|
|
206
|
+
target: list[np.ndarray],
|
|
202
207
|
gamma: float = 2.0,
|
|
203
208
|
alpha: float = 0.5,
|
|
204
209
|
eps: float = 1e-8,
|
|
@@ -207,7 +212,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
207
212
|
<https://github.com/tensorflow/addons/>`_.
|
|
208
213
|
|
|
209
214
|
Args:
|
|
210
|
-
----
|
|
211
215
|
out_map: output feature map of the model of shape (N, num_classes, H, W)
|
|
212
216
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
213
217
|
gamma: modulating factor in the focal loss formula
|
|
@@ -215,7 +219,6 @@ class LinkNet(nn.Module, _LinkNet):
|
|
|
215
219
|
eps: epsilon factor in dice loss
|
|
216
220
|
|
|
217
221
|
Returns:
|
|
218
|
-
-------
|
|
219
222
|
A loss tensor
|
|
220
223
|
"""
|
|
221
224
|
_target, _mask = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
|
|
@@ -252,9 +255,9 @@ def _linknet(
|
|
|
252
255
|
arch: str,
|
|
253
256
|
pretrained: bool,
|
|
254
257
|
backbone_fn: Callable[[bool], nn.Module],
|
|
255
|
-
fpn_layers:
|
|
258
|
+
fpn_layers: list[str],
|
|
256
259
|
pretrained_backbone: bool = True,
|
|
257
|
-
ignore_keys:
|
|
260
|
+
ignore_keys: list[str] | None = None,
|
|
258
261
|
**kwargs: Any,
|
|
259
262
|
) -> LinkNet:
|
|
260
263
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -295,12 +298,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
295
298
|
>>> out = model(input_tensor)
|
|
296
299
|
|
|
297
300
|
Args:
|
|
298
|
-
----
|
|
299
301
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
300
302
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
301
303
|
|
|
302
304
|
Returns:
|
|
303
|
-
-------
|
|
304
305
|
text detection architecture
|
|
305
306
|
"""
|
|
306
307
|
return _linknet(
|
|
@@ -327,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
327
328
|
>>> out = model(input_tensor)
|
|
328
329
|
|
|
329
330
|
Args:
|
|
330
|
-
----
|
|
331
331
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
332
332
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
333
333
|
|
|
334
334
|
Returns:
|
|
335
|
-
-------
|
|
336
335
|
text detection architecture
|
|
337
336
|
"""
|
|
338
337
|
return _linknet(
|
|
@@ -359,12 +358,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
359
358
|
>>> out = model(input_tensor)
|
|
360
359
|
|
|
361
360
|
Args:
|
|
362
|
-
----
|
|
363
361
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
364
362
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
365
363
|
|
|
366
364
|
Returns:
|
|
367
|
-
-------
|
|
368
365
|
text detection architecture
|
|
369
366
|
"""
|
|
370
367
|
return _linknet(
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
|
|
7
7
|
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
12
|
import tensorflow as tf
|
|
@@ -27,7 +27,7 @@ from .base import LinkNetPostProcessor, _LinkNet
|
|
|
27
27
|
|
|
28
28
|
__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]
|
|
29
29
|
|
|
30
|
-
default_cfgs:
|
|
30
|
+
default_cfgs: dict[str, dict[str, Any]] = {
|
|
31
31
|
"linknet_resnet18": {
|
|
32
32
|
"mean": (0.798, 0.785, 0.772),
|
|
33
33
|
"std": (0.264, 0.2749, 0.287),
|
|
@@ -73,7 +73,7 @@ class LinkNetFPN(Model, NestedObject):
|
|
|
73
73
|
def __init__(
|
|
74
74
|
self,
|
|
75
75
|
out_chans: int,
|
|
76
|
-
in_shapes:
|
|
76
|
+
in_shapes: list[tuple[int, ...]],
|
|
77
77
|
) -> None:
|
|
78
78
|
super().__init__()
|
|
79
79
|
self.out_chans = out_chans
|
|
@@ -85,7 +85,7 @@ class LinkNetFPN(Model, NestedObject):
|
|
|
85
85
|
for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
|
|
86
86
|
]
|
|
87
87
|
|
|
88
|
-
def call(self, x:
|
|
88
|
+
def call(self, x: list[tf.Tensor], **kwargs: Any) -> tf.Tensor:
|
|
89
89
|
out = 0
|
|
90
90
|
for decoder, fmap in zip(self.decoders, x[::-1]):
|
|
91
91
|
out = decoder(out + fmap, **kwargs)
|
|
@@ -100,7 +100,6 @@ class LinkNet(_LinkNet, Model):
|
|
|
100
100
|
<https://arxiv.org/pdf/1707.03718.pdf>`_.
|
|
101
101
|
|
|
102
102
|
Args:
|
|
103
|
-
----
|
|
104
103
|
feature extractor: the backbone serving as feature extractor
|
|
105
104
|
fpn_channels: number of channels each extracted feature maps is mapped to
|
|
106
105
|
bin_thresh: threshold for binarization of the output feature map
|
|
@@ -111,7 +110,7 @@ class LinkNet(_LinkNet, Model):
|
|
|
111
110
|
class_names: list of class names
|
|
112
111
|
"""
|
|
113
112
|
|
|
114
|
-
_children_names:
|
|
113
|
+
_children_names: list[str] = ["feat_extractor", "fpn", "classifier", "postprocessor"]
|
|
115
114
|
|
|
116
115
|
def __init__(
|
|
117
116
|
self,
|
|
@@ -121,8 +120,8 @@ class LinkNet(_LinkNet, Model):
|
|
|
121
120
|
box_thresh: float = 0.1,
|
|
122
121
|
assume_straight_pages: bool = True,
|
|
123
122
|
exportable: bool = False,
|
|
124
|
-
cfg:
|
|
125
|
-
class_names:
|
|
123
|
+
cfg: dict[str, Any] | None = None,
|
|
124
|
+
class_names: list[str] = [CLASS_NAME],
|
|
126
125
|
) -> None:
|
|
127
126
|
super().__init__(cfg=cfg)
|
|
128
127
|
|
|
@@ -167,7 +166,7 @@ class LinkNet(_LinkNet, Model):
|
|
|
167
166
|
def compute_loss(
|
|
168
167
|
self,
|
|
169
168
|
out_map: tf.Tensor,
|
|
170
|
-
target:
|
|
169
|
+
target: list[dict[str, np.ndarray]],
|
|
171
170
|
gamma: float = 2.0,
|
|
172
171
|
alpha: float = 0.5,
|
|
173
172
|
eps: float = 1e-8,
|
|
@@ -176,7 +175,6 @@ class LinkNet(_LinkNet, Model):
|
|
|
176
175
|
<https://github.com/tensorflow/addons/>`_.
|
|
177
176
|
|
|
178
177
|
Args:
|
|
179
|
-
----
|
|
180
178
|
out_map: output feature map of the model of shape N x H x W x 1
|
|
181
179
|
target: list of dictionary where each dict has a `boxes` and a `flags` entry
|
|
182
180
|
gamma: modulating factor in the focal loss formula
|
|
@@ -184,7 +182,6 @@ class LinkNet(_LinkNet, Model):
|
|
|
184
182
|
eps: epsilon factor in dice loss
|
|
185
183
|
|
|
186
184
|
Returns:
|
|
187
|
-
-------
|
|
188
185
|
A loss tensor
|
|
189
186
|
"""
|
|
190
187
|
seg_target, seg_mask = self.build_target(target, out_map.shape[1:], True)
|
|
@@ -218,16 +215,16 @@ class LinkNet(_LinkNet, Model):
|
|
|
218
215
|
def call(
|
|
219
216
|
self,
|
|
220
217
|
x: tf.Tensor,
|
|
221
|
-
target:
|
|
218
|
+
target: list[dict[str, np.ndarray]] | None = None,
|
|
222
219
|
return_model_output: bool = False,
|
|
223
220
|
return_preds: bool = False,
|
|
224
221
|
**kwargs: Any,
|
|
225
|
-
) ->
|
|
222
|
+
) -> dict[str, Any]:
|
|
226
223
|
feat_maps = self.feat_extractor(x, **kwargs)
|
|
227
224
|
logits = self.fpn(feat_maps, **kwargs)
|
|
228
225
|
logits = self.classifier(logits, **kwargs)
|
|
229
226
|
|
|
230
|
-
out:
|
|
227
|
+
out: dict[str, tf.Tensor] = {}
|
|
231
228
|
if self.exportable:
|
|
232
229
|
out["logits"] = logits
|
|
233
230
|
return out
|
|
@@ -253,9 +250,9 @@ def _linknet(
|
|
|
253
250
|
arch: str,
|
|
254
251
|
pretrained: bool,
|
|
255
252
|
backbone_fn,
|
|
256
|
-
fpn_layers:
|
|
253
|
+
fpn_layers: list[str],
|
|
257
254
|
pretrained_backbone: bool = True,
|
|
258
|
-
input_shape:
|
|
255
|
+
input_shape: tuple[int, int, int] | None = None,
|
|
259
256
|
**kwargs: Any,
|
|
260
257
|
) -> LinkNet:
|
|
261
258
|
pretrained_backbone = pretrained_backbone and not pretrained
|
|
@@ -305,12 +302,10 @@ def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
305
302
|
>>> out = model(input_tensor)
|
|
306
303
|
|
|
307
304
|
Args:
|
|
308
|
-
----
|
|
309
305
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
310
306
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
311
307
|
|
|
312
308
|
Returns:
|
|
313
|
-
-------
|
|
314
309
|
text detection architecture
|
|
315
310
|
"""
|
|
316
311
|
return _linknet(
|
|
@@ -333,12 +328,10 @@ def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
333
328
|
>>> out = model(input_tensor)
|
|
334
329
|
|
|
335
330
|
Args:
|
|
336
|
-
----
|
|
337
331
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
338
332
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
339
333
|
|
|
340
334
|
Returns:
|
|
341
|
-
-------
|
|
342
335
|
text detection architecture
|
|
343
336
|
"""
|
|
344
337
|
return _linknet(
|
|
@@ -361,12 +354,10 @@ def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
|
|
|
361
354
|
>>> out = model(input_tensor)
|
|
362
355
|
|
|
363
356
|
Args:
|
|
364
|
-
----
|
|
365
357
|
pretrained (bool): If True, returns a model pre-trained on our text detection dataset
|
|
366
358
|
**kwargs: keyword arguments of the LinkNet architecture
|
|
367
359
|
|
|
368
360
|
Returns:
|
|
369
|
-
-------
|
|
370
361
|
text detection architecture
|
|
371
362
|
"""
|
|
372
363
|
return _linknet(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import torch
|
|
@@ -20,7 +20,6 @@ class DetectionPredictor(nn.Module):
|
|
|
20
20
|
"""Implements an object able to localize text elements in a document
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
pre_processor: transform inputs for easier batched model inference
|
|
25
24
|
model: core detection architecture
|
|
26
25
|
"""
|
|
@@ -37,10 +36,10 @@ class DetectionPredictor(nn.Module):
|
|
|
37
36
|
@torch.inference_mode()
|
|
38
37
|
def forward(
|
|
39
38
|
self,
|
|
40
|
-
pages:
|
|
39
|
+
pages: list[np.ndarray | torch.Tensor],
|
|
41
40
|
return_maps: bool = False,
|
|
42
41
|
**kwargs: Any,
|
|
43
|
-
) ->
|
|
42
|
+
) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
|
|
44
43
|
# Extract parameters from the preprocessor
|
|
45
44
|
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
46
45
|
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
|
@@ -60,11 +59,11 @@ class DetectionPredictor(nn.Module):
|
|
|
60
59
|
]
|
|
61
60
|
# Remove padding from loc predictions
|
|
62
61
|
preds = _remove_padding(
|
|
63
|
-
pages,
|
|
62
|
+
pages,
|
|
64
63
|
[pred for batch in predicted_batches for pred in batch["preds"]],
|
|
65
64
|
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
66
65
|
symmetric_pad=symmetric_pad,
|
|
67
|
-
assume_straight_pages=assume_straight_pages,
|
|
66
|
+
assume_straight_pages=assume_straight_pages, # type: ignore[arg-type]
|
|
68
67
|
)
|
|
69
68
|
|
|
70
69
|
if return_maps:
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
import tensorflow as tf
|
|
@@ -20,12 +20,11 @@ class DetectionPredictor(NestedObject):
|
|
|
20
20
|
"""Implements an object able to localize text elements in a document
|
|
21
21
|
|
|
22
22
|
Args:
|
|
23
|
-
----
|
|
24
23
|
pre_processor: transform inputs for easier batched model inference
|
|
25
24
|
model: core detection architecture
|
|
26
25
|
"""
|
|
27
26
|
|
|
28
|
-
_children_names:
|
|
27
|
+
_children_names: list[str] = ["pre_processor", "model"]
|
|
29
28
|
|
|
30
29
|
def __init__(
|
|
31
30
|
self,
|
|
@@ -37,10 +36,10 @@ class DetectionPredictor(NestedObject):
|
|
|
37
36
|
|
|
38
37
|
def __call__(
|
|
39
38
|
self,
|
|
40
|
-
pages:
|
|
39
|
+
pages: list[np.ndarray | tf.Tensor],
|
|
41
40
|
return_maps: bool = False,
|
|
42
41
|
**kwargs: Any,
|
|
43
|
-
) ->
|
|
42
|
+
) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
|
|
44
43
|
# Extract parameters from the preprocessor
|
|
45
44
|
preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
|
|
46
45
|
symmetric_pad = self.pre_processor.resize.symmetric_pad
|
doctr/models/detection/zoo.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
5
5
|
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
8
|
from doctr.file_utils import is_tf_available, is_torch_available
|
|
9
9
|
|
|
@@ -14,7 +14,7 @@ from .predictor import DetectionPredictor
|
|
|
14
14
|
|
|
15
15
|
__all__ = ["detection_predictor"]
|
|
16
16
|
|
|
17
|
-
ARCHS:
|
|
17
|
+
ARCHS: list[str]
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
if is_tf_available():
|
|
@@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
|
|
|
56
56
|
if isinstance(_model, detection.FAST):
|
|
57
57
|
_model = reparameterize(_model)
|
|
58
58
|
else:
|
|
59
|
-
|
|
59
|
+
allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
|
|
60
|
+
if is_torch_available():
|
|
61
|
+
# Adding the type for torch compiled models to the allowed architectures
|
|
62
|
+
from doctr.models.utils import _CompiledModule
|
|
63
|
+
|
|
64
|
+
allowed_archs.append(_CompiledModule)
|
|
65
|
+
|
|
66
|
+
if not isinstance(arch, tuple(allowed_archs)):
|
|
60
67
|
raise ValueError(f"unknown architecture: {type(arch)}")
|
|
61
68
|
|
|
62
69
|
_model = arch
|
|
@@ -79,6 +86,9 @@ def detection_predictor(
|
|
|
79
86
|
arch: Any = "fast_base",
|
|
80
87
|
pretrained: bool = False,
|
|
81
88
|
assume_straight_pages: bool = True,
|
|
89
|
+
preserve_aspect_ratio: bool = True,
|
|
90
|
+
symmetric_pad: bool = True,
|
|
91
|
+
batch_size: int = 2,
|
|
82
92
|
**kwargs: Any,
|
|
83
93
|
) -> DetectionPredictor:
|
|
84
94
|
"""Text detection architecture.
|
|
@@ -90,14 +100,24 @@ def detection_predictor(
|
|
|
90
100
|
>>> out = model([input_page])
|
|
91
101
|
|
|
92
102
|
Args:
|
|
93
|
-
----
|
|
94
103
|
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
|
|
95
104
|
pretrained: If True, returns a model pre-trained on our text detection dataset
|
|
96
105
|
assume_straight_pages: If True, fit straight boxes to the page
|
|
106
|
+
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
|
|
107
|
+
running the detection model on it
|
|
108
|
+
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
|
|
109
|
+
batch_size: number of samples the model processes in parallel
|
|
97
110
|
**kwargs: optional keyword arguments passed to the architecture
|
|
98
111
|
|
|
99
112
|
Returns:
|
|
100
|
-
-------
|
|
101
113
|
Detection predictor
|
|
102
114
|
"""
|
|
103
|
-
return _predictor(
|
|
115
|
+
return _predictor(
|
|
116
|
+
arch=arch,
|
|
117
|
+
pretrained=pretrained,
|
|
118
|
+
assume_straight_pages=assume_straight_pages,
|
|
119
|
+
preserve_aspect_ratio=preserve_aspect_ratio,
|
|
120
|
+
symmetric_pad=symmetric_pad,
|
|
121
|
+
batch_size=batch_size,
|
|
122
|
+
**kwargs,
|
|
123
|
+
)
|
doctr/models/factory/hub.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright (C) 2021-
|
|
1
|
+
# Copyright (C) 2021-2025, Mindee.
|
|
2
2
|
|
|
3
3
|
# This program is licensed under the Apache License 2.0.
|
|
4
4
|
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
|
@@ -61,7 +61,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
|
|
|
61
61
|
"""Save model and config to disk for pushing to huggingface hub
|
|
62
62
|
|
|
63
63
|
Args:
|
|
64
|
-
----
|
|
65
64
|
model: TF or PyTorch model to be saved
|
|
66
65
|
save_dir: directory to save model and config
|
|
67
66
|
arch: architecture name
|
|
@@ -97,7 +96,6 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
|
|
|
97
96
|
>>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
|
|
98
97
|
|
|
99
98
|
Args:
|
|
100
|
-
----
|
|
101
99
|
model: TF or PyTorch model to be saved
|
|
102
100
|
model_name: name of the model which is also the repository name
|
|
103
101
|
task: task name
|
|
@@ -114,9 +112,9 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
|
|
|
114
112
|
# default readme
|
|
115
113
|
readme = textwrap.dedent(
|
|
116
114
|
f"""
|
|
117
|
-
|
|
115
|
+
|
|
118
116
|
language: en
|
|
119
|
-
|
|
117
|
+
|
|
120
118
|
|
|
121
119
|
<p align="center">
|
|
122
120
|
<img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
|
|
@@ -190,12 +188,10 @@ def from_hub(repo_id: str, **kwargs: Any):
|
|
|
190
188
|
>>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")
|
|
191
189
|
|
|
192
190
|
Args:
|
|
193
|
-
----
|
|
194
191
|
repo_id: HuggingFace model hub repo
|
|
195
192
|
kwargs: kwargs of `hf_hub_download` or `snapshot_download`
|
|
196
193
|
|
|
197
194
|
Returns:
|
|
198
|
-
-------
|
|
199
195
|
Model loaded with the checkpoint
|
|
200
196
|
"""
|
|
201
197
|
# Get the config
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
from doctr.file_utils import is_tf_available
|
|
1
|
+
from doctr.file_utils import is_tf_available, is_torch_available
|
|
2
2
|
|
|
3
|
-
if
|
|
4
|
-
from .
|
|
5
|
-
|
|
6
|
-
from .
|
|
3
|
+
if is_torch_available():
|
|
4
|
+
from .pytorch import *
|
|
5
|
+
elif is_tf_available():
|
|
6
|
+
from .tensorflow import * # type: ignore[assignment]
|