python-doctr 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/datasets/cord.py +10 -1
- doctr/datasets/funsd.py +11 -1
- doctr/datasets/ic03.py +11 -1
- doctr/datasets/ic13.py +10 -1
- doctr/datasets/iiit5k.py +26 -16
- doctr/datasets/imgur5k.py +10 -1
- doctr/datasets/sroie.py +11 -1
- doctr/datasets/svhn.py +11 -1
- doctr/datasets/svt.py +11 -1
- doctr/datasets/synthtext.py +11 -1
- doctr/datasets/utils.py +7 -2
- doctr/datasets/vocabs.py +6 -2
- doctr/datasets/wildreceipt.py +12 -1
- doctr/file_utils.py +19 -0
- doctr/io/elements.py +12 -4
- doctr/models/builder.py +2 -2
- doctr/models/classification/magc_resnet/tensorflow.py +13 -6
- doctr/models/classification/mobilenet/pytorch.py +2 -0
- doctr/models/classification/mobilenet/tensorflow.py +14 -8
- doctr/models/classification/predictor/pytorch.py +11 -7
- doctr/models/classification/predictor/tensorflow.py +10 -6
- doctr/models/classification/resnet/tensorflow.py +21 -8
- doctr/models/classification/textnet/tensorflow.py +11 -5
- doctr/models/classification/vgg/tensorflow.py +9 -3
- doctr/models/classification/vit/tensorflow.py +10 -4
- doctr/models/classification/zoo.py +22 -10
- doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
- doctr/models/detection/fast/tensorflow.py +14 -11
- doctr/models/detection/linknet/tensorflow.py +23 -11
- doctr/models/detection/predictor/tensorflow.py +2 -2
- doctr/models/factory/hub.py +5 -6
- doctr/models/kie_predictor/base.py +4 -0
- doctr/models/kie_predictor/pytorch.py +4 -0
- doctr/models/kie_predictor/tensorflow.py +8 -1
- doctr/models/modules/transformer/tensorflow.py +0 -2
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/modules/vision_transformer/tensorflow.py +1 -1
- doctr/models/predictor/base.py +24 -12
- doctr/models/predictor/pytorch.py +4 -0
- doctr/models/predictor/tensorflow.py +8 -1
- doctr/models/preprocessor/tensorflow.py +1 -1
- doctr/models/recognition/crnn/tensorflow.py +8 -6
- doctr/models/recognition/master/tensorflow.py +9 -4
- doctr/models/recognition/parseq/tensorflow.py +10 -8
- doctr/models/recognition/sar/tensorflow.py +7 -3
- doctr/models/recognition/vitstr/tensorflow.py +9 -4
- doctr/models/utils/pytorch.py +1 -1
- doctr/models/utils/tensorflow.py +15 -15
- doctr/transforms/functional/pytorch.py +1 -1
- doctr/transforms/modules/pytorch.py +7 -6
- doctr/transforms/modules/tensorflow.py +15 -12
- doctr/utils/geometry.py +106 -19
- doctr/utils/metrics.py +1 -1
- doctr/utils/reconstitution.py +151 -65
- doctr/version.py +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/METADATA +11 -11
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/RECORD +61 -61
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.9.0.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
doctr/models/detection/fast/tensorflow.py CHANGED

@@ -10,11 +10,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import Sequential, layers
+from tensorflow.keras import Model, Sequential, layers
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, load_pretrained_params
+from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
 from doctr.utils.repr import NestedObject
 
 from ...classification import textnet_base, textnet_small, textnet_tiny
@@ -29,19 +28,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
     },
     "fast_small": {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
     },
     "fast_base": {
         "input_shape": (1024, 1024, 3),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
     },
 }
 
@@ -100,7 +99,7 @@ class FastHead(Sequential):
         super().__init__(_layers)
 
 
-class FAST(_FAST, keras.Model, NestedObject):
+class FAST(_FAST, Model, NestedObject):
     """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
     <https://arxiv.org/pdf/2111.02394.pdf>`_.
 
@@ -334,12 +333,16 @@ def _fast(
 
     # Build the model
     model = FAST(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
-
-
-
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
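Every TensorFlow builder in this release gains a `_build_model(model)` call before `load_pretrained_params`. The likely reason: the checkpoints moved to single `.weights.h5` files, and Keras can only load such a file into a model whose variables already exist. A minimal sketch of what a helper like this plausibly does (the real one lives in `doctr/models/utils/tensorflow.py`, which this diff touches but does not show; `model.cfg` follows the doctr convention of storing the config on the model):

```python
import tensorflow as tf

def _build_model(model: tf.keras.Model) -> None:
    # Sketch only: force variable creation with one dummy forward pass,
    # so that model.load_weights("...weights.h5") finds every weight.
    input_shape = model.cfg["input_shape"]  # e.g. (1024, 1024, 3) for FAST
    model(tf.zeros((1, *input_shape)), training=False)
```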
doctr/models/detection/linknet/tensorflow.py CHANGED

@@ -10,12 +10,17 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import Model, Sequential, layers
+from tensorflow.keras import Model, Sequential, layers, losses
 
 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification import resnet18, resnet34, resnet50
-from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, conv_sequence, load_pretrained_params
+from doctr.models.utils import (
+    IntermediateLayerGetter,
+    _bf16_to_float32,
+    _build_model,
+    conv_sequence,
+    load_pretrained_params,
+)
 from doctr.utils.repr import NestedObject
 
 from .base import LinkNetPostProcessor, _LinkNet
@@ -27,19 +32,19 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet18-615a82c5.weights.h5&src=0",
    },
     "linknet_resnet34": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet34-9d772be5.weights.h5&src=0",
     },
     "linknet_resnet50": {
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
         "input_shape": (1024, 1024, 3),
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/linknet_resnet50-6bf6c8b5.weights.h5&src=0",
     },
 }
 
@@ -80,17 +85,17 @@ class LinkNetFPN(Model, NestedObject):
             for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1])
         ]
 
-    def call(self, x: List[tf.Tensor]) -> tf.Tensor:
+    def call(self, x: List[tf.Tensor], **kwargs: Any) -> tf.Tensor:
         out = 0
         for decoder, fmap in zip(self.decoders, x[::-1]):
-            out = decoder(out + fmap)
+            out = decoder(out + fmap, **kwargs)
         return out
 
     def extra_repr(self) -> str:
         return f"out_chans={self.out_chans}"
 
 
-class LinkNet(_LinkNet, keras.Model):
+class LinkNet(_LinkNet, Model):
     """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
     <https://arxiv.org/pdf/1707.03718.pdf>`_.
 
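Threading `**kwargs` through `LinkNetFPN.call` is how Keras propagates the `training=` flag into sub-layers; without it, layers such as `BatchNormalization` inside the decoders would behave the same way at train and inference time. A toy illustration (the names here are illustrative, not from doctr):

```python
import tensorflow as tf
from tensorflow.keras import layers

class Block(layers.Layer):
    def __init__(self) -> None:
        super().__init__()
        self.bn = layers.BatchNormalization()

    def call(self, x: tf.Tensor, **kwargs) -> tf.Tensor:
        # Forwarding kwargs lets `training=True/False` reach BatchNorm,
        # which switches between batch statistics and moving averages.
        return self.bn(x, **kwargs)

x = tf.random.normal((2, 8))
block = Block()
out_train = block(x, training=True)   # batch statistics
out_infer = block(x, training=False)  # moving averages
```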
@@ -187,7 +192,7 @@ class LinkNet(_LinkNet, keras.Model):
         seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
         seg_mask = tf.cast(seg_mask, tf.float32)
 
-        bce_loss = keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
+        bce_loss = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
         proba_map = tf.sigmoid(out_map)
 
         # Focal loss
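The loss line computes per-pixel binary cross-entropy directly on logits. The `[..., None]` indexing appends a trailing channel axis because `losses.binary_crossentropy` averages over the last axis; with a size-1 channel axis the result stays one value per pixel. A quick shape check:

```python
import tensorflow as tf
from tensorflow.keras import losses

out_map = tf.constant([[2.0, -1.0]])    # raw logits, shape (1, 2)
seg_target = tf.constant([[1.0, 0.0]])  # binary targets, shape (1, 2)

# The trailing axis of size 1 keeps one loss value per pixel instead of
# averaging across the width dimension.
bce = losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
print(bce.shape)  # (1, 2)
```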
@@ -275,9 +280,16 @@ def _linknet(
 
     # Build the model
     model = LinkNet(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model,
+            _cfg["url"],
+            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
+        )
 
     return model
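With `skip_mismatch` wired to the `class_names` comparison, passing non-default class names presumably loads every pretrained weight except the layers whose shapes no longer match, which is exactly what fine-tuning on custom detection classes needs. A hedged usage sketch (the custom class names are placeholders):

```python
from doctr.models import linknet_resnet18

# Default class_names: the full checkpoint loads cleanly.
model = linknet_resnet18(pretrained=True)

# Custom class_names (hypothetical): skip_mismatch=True should leave the
# mismatching output layers randomly initialized for fine-tuning.
ft_model = linknet_resnet18(pretrained=True, class_names=["words", "stamps"])
```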
doctr/models/detection/predictor/tensorflow.py CHANGED

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import keras
+from tensorflow.keras import Model
 
 from doctr.models.detection._utils import _remove_padding
 from doctr.models.preprocessor import PreProcessor
@@ -30,7 +30,7 @@ class DetectionPredictor(NestedObject):
     def __init__(
         self,
         pre_processor: PreProcessor,
-        model: keras.Model,
+        model: Model,
     ) -> None:
         self.pre_processor = pre_processor
         self.model = model
doctr/models/factory/hub.py CHANGED

@@ -20,7 +20,6 @@ from huggingface_hub import (
     get_token_permission,
     hf_hub_download,
     login,
-    snapshot_download,
 )
 
 from doctr import models
@@ -33,7 +32,7 @@ __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
 
 
 AVAILABLE_ARCHS = {
-    "classification": models.classification.zoo.ARCHS,
+    "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
     "detection": models.detection.zoo.ARCHS,
     "recognition": models.recognition.zoo.ARCHS,
 }
@@ -74,7 +73,7 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
         weights_path = save_directory / "pytorch_model.bin"
         torch.save(model.state_dict(), weights_path)
     elif is_tf_available():
-        weights_path = save_directory / "tf_model
+        weights_path = save_directory / "tf_model.weights.h5"
         model.save_weights(str(weights_path))
 
     config_path = save_directory / "config.json"
@@ -174,7 +173,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None: #
 
     local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
     repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
-    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url
+    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)
 
     with repo.commit(commit_message):
         _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
@@ -225,7 +224,7 @@ def from_hub(repo_id: str, **kwargs: Any):
         state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
         model.load_state_dict(state_dict)
     else:  # tf
-
-        model.load_weights(
+        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
+        model.load_weights(weights)
 
     return model
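On the TensorFlow side the hub integration now round-trips a single `tf_model.weights.h5` file instead of a full repo snapshot, which is why `snapshot_download` drops out of the imports. A hedged usage sketch based on the functions defined in this file (`username/my-model` is a placeholder repo id, and the `arch` kwarg is an assumption drawn from the hunk context above):

```python
from doctr.models import recognition
from doctr.models.factory.hub import from_hub, push_to_hf_hub

# Publishing: under TF this now writes tf_model.weights.h5 next to config.json.
model = recognition.crnn_vgg16_bn(pretrained=True)
push_to_hf_hub(model, model_name="my-model", task="recognition", arch="crnn_vgg16_bn")

# Loading: downloads the single weights file and calls model.load_weights on it.
restored = from_hub("username/my-model")
```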
doctr/models/kie_predictor/base.py CHANGED

@@ -46,4 +46,8 @@ class _KIEPredictor(_OCRPredictor):
             assume_straight_pages, straighten_pages, preserve_aspect_ratio, symmetric_pad, detect_orientation, **kwargs
         )
 
+        # Remove the following arguments from kwargs after initialization of the parent class
+        kwargs.pop("disable_page_orientation", None)
+        kwargs.pop("disable_crop_orientation", None)
+
         self.doc_builder: KIEDocumentBuilder = KIEDocumentBuilder(**kwargs)
doctr/models/kie_predictor/pytorch.py CHANGED

@@ -99,6 +99,9 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         origin_pages_orientations = None
         if self.straighten_pages:
             pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)  # type: ignore
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
         # Forward again to get predictions on straight pages
         loc_preds = self.det_predictor(pages, **kwargs)
 
@@ -126,6 +129,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
                 dict_loc_preds[class_name],
                 channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
+                assume_horizontal=self._page_orientation_disabled,
             )
         # Rectify crop orientation
         crop_orientations: Any = {}
doctr/models/kie_predictor/tensorflow.py CHANGED

@@ -99,6 +99,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         origin_pages_orientations = None
         if self.straighten_pages:
             pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
         # Forward again to get predictions on straight pages
         loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
@@ -119,7 +122,11 @@ class KIEPredictor(NestedObject, _KIEPredictor):
         crops = {}
         for class_name in dict_loc_preds.keys():
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
-                pages, dict_loc_preds[class_name], channels_last=True, assume_straight_pages=self.assume_straight_pages
+                pages,
+                dict_loc_preds[class_name],
+                channels_last=True,
+                assume_straight_pages=self.assume_straight_pages,
+                assume_horizontal=self._page_orientation_disabled,
             )
 
         # Rectify crop orientation
doctr/models/modules/transformer/tensorflow.py CHANGED

@@ -13,8 +13,6 @@ from doctr.utils.repr import NestedObject
 
 __all__ = ["Decoder", "PositionalEncoding", "EncoderBlock", "PositionwiseFeedForward", "MultiHeadAttention"]
 
-tf.config.run_functions_eagerly(True)
-
 
 class PositionalEncoding(layers.Layer, NestedObject):
     """Compute positional encoding"""
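Deleting the module-level `tf.config.run_functions_eagerly(True)` restores graph execution for all `tf.function`-compiled code in the process; the matching `@tf.function` decorators are dropped from PARSeq later in this diff. The practical difference, in miniature:

```python
import tensorflow as tf

@tf.function
def f(x):
    # In graph mode this Python print runs once, at tracing time; with
    # run_functions_eagerly(True) it would run on every call.
    print("tracing")
    return x + 1

f(tf.constant(1))  # prints "tracing"
f(tf.constant(2))  # no print: the traced graph is reused
```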
doctr/models/modules/vision_transformer/pytorch.py CHANGED

@@ -20,7 +20,7 @@ class PatchEmbedding(nn.Module):
         channels, height, width = input_shape
         self.patch_size = patch_size
         self.interpolate = True if patch_size[0] == patch_size[1] else False
-        self.grid_size = tuple(
+        self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
         self.num_patches = self.grid_size[0] * self.grid_size[1]
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
doctr/models/modules/vision_transformer/tensorflow.py CHANGED

@@ -22,7 +22,7 @@ class PatchEmbedding(layers.Layer, NestedObject):
         height, width, _ = input_shape
         self.patch_size = patch_size
         self.interpolate = True if patch_size[0] == patch_size[1] else False
-        self.grid_size = tuple(
+        self.grid_size = tuple(s // p for s, p in zip((height, width), self.patch_size))
         self.num_patches = self.grid_size[0] * self.grid_size[1]
 
         self.cls_token = self.add_weight(shape=(1, 1, embed_dim), initializer="zeros", trainable=True, name="cls_token")
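Both patch-embedding fixes restore the same expression: the patch grid is the integer division of the spatial dims by the patch size. Worked through with the recognition input shape from this diff and an illustrative patch size:

```python
height, width = 32, 128  # "input_shape": (32, 128, 3) in the recognition configs
patch_size = (4, 8)      # illustrative value, not taken from the diff

grid_size = tuple(s // p for s, p in zip((height, width), patch_size))
num_patches = grid_size[0] * grid_size[1]
print(grid_size, num_patches)  # (8, 16) 128
```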
doctr/models/predictor/base.py CHANGED

@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import numpy as np
 
 from doctr.models.builder import DocumentBuilder
-from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image
+from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image
 
 from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds
 from ..classification import crop_orientation_predictor, page_orientation_predictor
@@ -48,9 +48,15 @@ class _OCRPredictor:
     ) -> None:
         self.assume_straight_pages = assume_straight_pages
         self.straighten_pages = straighten_pages
-        self.crop_orientation_predictor = None if assume_straight_pages else crop_orientation_predictor(pretrained=True)
+        self._page_orientation_disabled = kwargs.pop("disable_page_orientation", False)
+        self._crop_orientation_disabled = kwargs.pop("disable_crop_orientation", False)
+        self.crop_orientation_predictor = (
+            None
+            if assume_straight_pages
+            else crop_orientation_predictor(pretrained=True, disabled=self._crop_orientation_disabled)
+        )
         self.page_orientation_predictor = (
-            page_orientation_predictor(pretrained=True)
+            page_orientation_predictor(pretrained=True, disabled=self._page_orientation_disabled)
             if detect_orientation or straighten_pages or not assume_straight_pages
             else None
         )
@@ -101,8 +107,8 @@ class _OCRPredictor:
             ]
         )
         return [
-            # expand if height and width are not equal
-            rotate_image(page, angle, expand=page.shape[0] != page.shape[1])
+            # expand if height and width are not equal, then remove the padding
+            remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1]))
             for page, angle in zip(pages, origin_pages_orientations)
         ]
 
@@ -112,13 +118,18 @@ class _OCRPredictor:
         loc_preds: List[np.ndarray],
         channels_last: bool,
         assume_straight_pages: bool = False,
+        assume_horizontal: bool = False,
     ) -> List[List[np.ndarray]]:
-
-
-
-
-
-
+        if assume_straight_pages:
+            crops = [
+                extract_crops(page, _boxes[:, :4], channels_last=channels_last)
+                for page, _boxes in zip(pages, loc_preds)
+            ]
+        else:
+            crops = [
+                extract_rcrops(page, _boxes[:, :4], channels_last=channels_last, assume_horizontal=assume_horizontal)
+                for page, _boxes in zip(pages, loc_preds)
+            ]
         return crops
 
     @staticmethod
@@ -127,8 +138,9 @@ class _OCRPredictor:
         loc_preds: List[np.ndarray],
         channels_last: bool,
         assume_straight_pages: bool = False,
+        assume_horizontal: bool = False,
    ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]:
-        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages)
+        crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages, assume_horizontal)
 
         # Avoid sending zero-sized crops
         is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops]
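The two new kwargs are popped in `_OCRPredictor.__init__`, so they should be passable through the predictor factories; `disable_page_orientation` doubles as the `assume_horizontal` hint when extracting rotated crops. A hedged sketch of how the flags would be used:

```python
from doctr.models import ocr_predictor

# Sketch only: both flag names come from this diff; they skip the page/crop
# orientation models and treat content as already upright/horizontal.
predictor = ocr_predictor(
    pretrained=True,
    assume_straight_pages=False,
    disable_page_orientation=True,
    disable_crop_orientation=True,
)
```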
doctr/models/predictor/pytorch.py CHANGED

@@ -97,6 +97,9 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         origin_pages_orientations = None
         if self.straighten_pages:
             pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)  # type: ignore
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
         # Forward again to get predictions on straight pages
         loc_preds = self.det_predictor(pages, **kwargs)
 
@@ -120,6 +123,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             loc_preds,
             channels_last=channels_last,
             assume_straight_pages=self.assume_straight_pages,
+            assume_horizontal=self._page_orientation_disabled,
         )
         # Rectify crop orientation and get crop orientation predictions
         crop_orientations: Any = []
doctr/models/predictor/tensorflow.py CHANGED

@@ -97,6 +97,9 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         origin_pages_orientations = None
         if self.straighten_pages:
             pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
         # forward again to get predictions on straight pages
         loc_preds_dict = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
@@ -113,7 +116,11 @@ class OCRPredictor(NestedObject, _OCRPredictor):
 
         # Crop images
         crops, loc_preds = self._prepare_crops(
-            pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages
+            pages,
+            loc_preds,
+            channels_last=True,
+            assume_straight_pages=self.assume_straight_pages,
+            assume_horizontal=self._page_orientation_disabled,
         )
         # Rectify crop orientation and get crop orientation predictions
         crop_orientations: Any = []
doctr/models/preprocessor/tensorflow.py CHANGED

@@ -41,7 +41,7 @@ class PreProcessor(NestedObject):
         self.resize = Resize(output_size, **kwargs)
         # Perform the division by 255 at the same time
         self.normalize = Normalize(mean, std)
-        self._runs_on_cuda = tf.
+        self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
 
     def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
         """Gather samples into batches for inference purposes
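The rebuilt line replaces whatever GPU probe 0.9.0 used (truncated in this view) with the current TensorFlow API: `tf.config.list_physical_devices("GPU")` returns an empty list on CPU-only hosts, so comparing against `[]` yields a boolean:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
runs_on_cuda = gpus != []  # True iff at least one GPU is visible
print(runs_on_cuda)
```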
doctr/models/recognition/crnn/tensorflow.py CHANGED

@@ -13,7 +13,7 @@ from tensorflow.keras.models import Model, Sequential
 from doctr.datasets import VOCABS
 
 from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -24,21 +24,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["legacy_french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
     },
     "crnn_mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
     },
     "crnn_mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
     },
 }
 
@@ -128,7 +128,7 @@ class CRNN(RecognitionModel, Model):
 
     def __init__(
         self,
-        feature_extractor:
+        feature_extractor: Model,
         vocab: str,
         rnn_units: int = 128,
         exportable: bool = False,
@@ -245,9 +245,11 @@ def _crnn(
 
     # Build the model
     model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
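The same `skip_mismatch` pattern applies to the recognition builders, keyed on the vocab instead of class names: a custom vocab changes the classifier width, so those layers are presumably skipped while the backbone weights still load. A hedged sketch (the digits-only vocab is illustrative):

```python
from doctr.models import crnn_vgg16_bn

# Vocab matches the pretrained config: the full checkpoint loads.
model = crnn_vgg16_bn(pretrained=True)

# Custom vocab (illustrative): skip_mismatch should leave the final
# classification layer randomly initialized, ready for fine-tuning.
digits_model = crnn_vgg16_bn(pretrained=True, vocab="0123456789")
```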
doctr/models/recognition/master/tensorflow.py CHANGED

@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding
 
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor
 
 __all__ = ["MASTER", "master"]
@@ -25,7 +25,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0",
     },
 }
 
@@ -51,7 +51,7 @@ class MASTER(_MASTER, Model):
 
     def __init__(
         self,
-        feature_extractor:
+        feature_extractor: Model,
         vocab: str,
         d_model: int = 512,
         dff: int = 2048,
@@ -290,9 +290,14 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
doctr/models/recognition/parseq/tensorflow.py CHANGED

@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
@@ -27,7 +27,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
     },
 }
 
@@ -43,7 +43,7 @@ class CharEmbedding(layers.Layer):
 
     def __init__(self, vocab_size: int, d_model: int):
         super(CharEmbedding, self).__init__()
-        self.embedding =
+        self.embedding = layers.Embedding(vocab_size, d_model)
         self.d_model = d_model
 
     def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
@@ -167,7 +167,6 @@ class PARSeq(_PARSeq, Model):
 
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
 
-    @tf.function
     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
         # Generates permutations of the target sequence.
         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -214,7 +213,6 @@ class PARSeq(_PARSeq, Model):
         )
         return combined
 
-    @tf.function
     def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
@@ -234,11 +232,10 @@ class PARSeq(_PARSeq, Model):
         target_mask = mask[1:, :-1]
         return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
 
-    @tf.function
     def decode(
         self,
         target: tf.Tensor,
-        memory: tf,
+        memory: tf.Tensor,
         target_mask: Optional[tf.Tensor] = None,
         target_query: Optional[tf.Tensor] = None,
         **kwargs: Any,
@@ -476,9 +473,14 @@ def _parseq(
 
     # Build the model
     model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
doctr/models/recognition/sar/tensorflow.py CHANGED

@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
 from doctr.utils.repr import NestedObject
 
 from ...classification import resnet31
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
@@ -24,7 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
     },
 }
 
@@ -392,9 +392,13 @@ def _sar(
 
     # Build the model
     model = SAR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model