deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. Click here for more details.
- deepdoctection/__init__.py +8 -25
- deepdoctection/analyzer/dd.py +84 -71
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +78 -56
- deepdoctection/datapoint/box.py +7 -7
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +157 -75
- deepdoctection/datapoint/view.py +175 -151
- deepdoctection/datasets/adapter.py +30 -24
- deepdoctection/datasets/base.py +10 -10
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +23 -25
- deepdoctection/datasets/instances/doclaynet.py +48 -49
- deepdoctection/datasets/instances/fintabnet.py +44 -45
- deepdoctection/datasets/instances/funsd.py +23 -23
- deepdoctection/datasets/instances/iiitar13k.py +8 -8
- deepdoctection/datasets/instances/layouttest.py +2 -2
- deepdoctection/datasets/instances/publaynet.py +3 -3
- deepdoctection/datasets/instances/pubtables1m.py +18 -18
- deepdoctection/datasets/instances/pubtabnet.py +30 -29
- deepdoctection/datasets/instances/rvlcdip.py +28 -29
- deepdoctection/datasets/instances/xfund.py +51 -30
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +13 -12
- deepdoctection/eval/eval.py +32 -26
- deepdoctection/eval/tedsmetric.py +16 -12
- deepdoctection/eval/tp_eval_callback.py +7 -16
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +69 -89
- deepdoctection/extern/deskew.py +11 -10
- deepdoctection/extern/doctrocr.py +81 -64
- deepdoctection/extern/fastlang.py +23 -16
- deepdoctection/extern/hfdetr.py +53 -38
- deepdoctection/extern/hflayoutlm.py +216 -155
- deepdoctection/extern/hflm.py +35 -30
- deepdoctection/extern/model.py +433 -255
- deepdoctection/extern/pdftext.py +15 -15
- deepdoctection/extern/pt/ptutils.py +4 -2
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +14 -16
- deepdoctection/extern/tp/tfutils.py +16 -2
- deepdoctection/extern/tp/tpcompat.py +11 -7
- deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
- deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
- deepdoctection/extern/tpdetect.py +40 -45
- deepdoctection/mapper/cats.py +36 -40
- deepdoctection/mapper/cocostruct.py +16 -12
- deepdoctection/mapper/d2struct.py +22 -22
- deepdoctection/mapper/hfstruct.py +7 -7
- deepdoctection/mapper/laylmstruct.py +22 -24
- deepdoctection/mapper/maputils.py +9 -10
- deepdoctection/mapper/match.py +33 -2
- deepdoctection/mapper/misc.py +6 -7
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +6 -6
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +3 -3
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/anngen.py +39 -14
- deepdoctection/pipe/base.py +68 -99
- deepdoctection/pipe/common.py +181 -85
- deepdoctection/pipe/concurrency.py +14 -10
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +18 -16
- deepdoctection/pipe/lm.py +49 -47
- deepdoctection/pipe/order.py +63 -65
- deepdoctection/pipe/refine.py +102 -109
- deepdoctection/pipe/segment.py +157 -162
- deepdoctection/pipe/sub_layout.py +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/d2_frcnn_train.py +27 -25
- deepdoctection/train/hf_detr_train.py +22 -18
- deepdoctection/train/hf_layoutlm_train.py +49 -48
- deepdoctection/train/tp_frcnn_train.py +10 -11
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +52 -14
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +41 -14
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +15 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/pdf_utils.py +39 -14
- deepdoctection/utils/settings.py +188 -182
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +70 -69
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
- deepdoctection-0.34.dist-info/RECORD +146 -0
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.32.dist-info/RECORD +0 -146
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
|
@@ -23,25 +23,24 @@ from __future__ import annotations
|
|
|
23
23
|
import os
|
|
24
24
|
from abc import ABC
|
|
25
25
|
from pathlib import Path
|
|
26
|
-
from typing import Any,
|
|
26
|
+
from typing import Any, Literal, Mapping, Optional, Union
|
|
27
27
|
from zipfile import ZipFile
|
|
28
28
|
|
|
29
29
|
from lazy_imports import try_import
|
|
30
30
|
|
|
31
|
-
from ..utils.
|
|
31
|
+
from ..utils.env_info import ENV_VARS_TRUE
|
|
32
32
|
from ..utils.error import DependencyError
|
|
33
33
|
from ..utils.file_utils import (
|
|
34
34
|
get_doctr_requirement,
|
|
35
35
|
get_pytorch_requirement,
|
|
36
36
|
get_tensorflow_requirement,
|
|
37
37
|
get_tf_addons_requirements,
|
|
38
|
-
pytorch_available,
|
|
39
|
-
tf_available,
|
|
40
38
|
)
|
|
41
39
|
from ..utils.fs import load_json
|
|
42
40
|
from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
|
|
41
|
+
from ..utils.types import PathLikeOrStr, PixelValues, Requirement
|
|
43
42
|
from ..utils.viz import viz_handler
|
|
44
|
-
from .base import DetectionResult, ImageTransformer,
|
|
43
|
+
from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector, TextRecognizer
|
|
45
44
|
from .pt.ptutils import get_torch_device
|
|
46
45
|
from .tp.tfutils import get_tf_device
|
|
47
46
|
|
|
@@ -60,13 +59,24 @@ with try_import() as doctr_import_guard:
|
|
|
60
59
|
from doctr.models.recognition.zoo import ARCHS, recognition
|
|
61
60
|
|
|
62
61
|
|
|
62
|
+
def _get_doctr_requirements() -> list[Requirement]:
|
|
63
|
+
if os.environ.get("DD_USE_TF", "0") in ENV_VARS_TRUE:
|
|
64
|
+
return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
|
|
65
|
+
if os.environ.get("DD_USE_TORCH", "0") in ENV_VARS_TRUE:
|
|
66
|
+
return [get_pytorch_requirement(), get_doctr_requirement()]
|
|
67
|
+
raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextRecognizer")
|
|
68
|
+
|
|
69
|
+
|
|
63
70
|
def _load_model(
|
|
64
|
-
path_weights:
|
|
71
|
+
path_weights: PathLikeOrStr,
|
|
72
|
+
doctr_predictor: Union[DetectionPredictor, RecognitionPredictor],
|
|
73
|
+
device: Union[torch.device, tf.device],
|
|
74
|
+
lib: Literal["PT", "TF"],
|
|
65
75
|
) -> None:
|
|
66
76
|
"""Loading a model either in TF or PT. We only shift the model to the device when using PyTorch. The shift of
|
|
67
77
|
the model to the device in Tensorflow is done in the predict function."""
|
|
68
78
|
if lib == "PT":
|
|
69
|
-
state_dict = torch.load(path_weights, map_location=device)
|
|
79
|
+
state_dict = torch.load(os.fspath(path_weights), map_location=device)
|
|
70
80
|
for key in list(state_dict.keys()):
|
|
71
81
|
state_dict["model." + key] = state_dict.pop(key)
|
|
72
82
|
doctr_predictor.load_state_dict(state_dict)
|
|
@@ -74,27 +84,27 @@ def _load_model(
|
|
|
74
84
|
elif lib == "TF":
|
|
75
85
|
# Unzip the archive
|
|
76
86
|
params_path = Path(path_weights).parent
|
|
77
|
-
is_zip_path = path_weights.endswith(".zip")
|
|
87
|
+
is_zip_path = os.fspath(path_weights).endswith(".zip")
|
|
78
88
|
if is_zip_path:
|
|
79
89
|
with ZipFile(path_weights, "r") as file:
|
|
80
90
|
file.extractall(path=params_path)
|
|
81
91
|
doctr_predictor.model.load_weights(params_path / "weights")
|
|
82
92
|
else:
|
|
83
|
-
doctr_predictor.model.load_weights(path_weights)
|
|
93
|
+
doctr_predictor.model.load_weights(os.fspath(path_weights))
|
|
84
94
|
|
|
85
95
|
|
|
86
96
|
def auto_select_lib_for_doctr() -> Literal["PT", "TF"]:
|
|
87
97
|
"""Auto select the DL library from environment variables"""
|
|
88
|
-
if os.environ.get("USE_TORCH"):
|
|
98
|
+
if os.environ.get("USE_TORCH", "0") in ENV_VARS_TRUE:
|
|
89
99
|
return "PT"
|
|
90
|
-
if os.environ.get("USE_TF"):
|
|
100
|
+
if os.environ.get("USE_TF", "0") in ENV_VARS_TRUE:
|
|
91
101
|
return "TF"
|
|
92
102
|
raise DependencyError("At least one of the env variables USE_TORCH or USE_TF must be set.")
|
|
93
103
|
|
|
94
104
|
|
|
95
105
|
def doctr_predict_text_lines(
|
|
96
|
-
np_img:
|
|
97
|
-
) ->
|
|
106
|
+
np_img: PixelValues, predictor: DetectionPredictor, device: Union[torch.device, tf.device], lib: Literal["TF", "PT"]
|
|
107
|
+
) -> list[DetectionResult]:
|
|
98
108
|
"""
|
|
99
109
|
Generating text line DetectionResult based on Doctr DetectionPredictor.
|
|
100
110
|
|
|
@@ -113,7 +123,7 @@ def doctr_predict_text_lines(
|
|
|
113
123
|
raise DependencyError("Tensorflow or PyTorch must be installed.")
|
|
114
124
|
detection_results = [
|
|
115
125
|
DetectionResult(
|
|
116
|
-
box=box[:4].tolist(), class_id=1, score=box[4], absolute_coords=False, class_name=LayoutType.
|
|
126
|
+
box=box[:4].tolist(), class_id=1, score=box[4], absolute_coords=False, class_name=LayoutType.WORD
|
|
117
127
|
)
|
|
118
128
|
for box in raw_output[0]["words"]
|
|
119
129
|
]
|
|
@@ -121,11 +131,11 @@ def doctr_predict_text_lines(
|
|
|
121
131
|
|
|
122
132
|
|
|
123
133
|
def doctr_predict_text(
|
|
124
|
-
inputs:
|
|
134
|
+
inputs: list[tuple[str, PixelValues]],
|
|
125
135
|
predictor: RecognitionPredictor,
|
|
126
136
|
device: Union[torch.device, tf.device],
|
|
127
137
|
lib: Literal["TF", "PT"],
|
|
128
|
-
) ->
|
|
138
|
+
) -> list[DetectionResult]:
|
|
129
139
|
"""
|
|
130
140
|
Calls Doctr text recognition model on a batch of numpy arrays (text lines predicted from a text line detector) and
|
|
131
141
|
returns the recognized text as DetectionResult
|
|
@@ -155,15 +165,15 @@ def doctr_predict_text(
|
|
|
155
165
|
class DoctrTextlineDetectorMixin(ObjectDetector, ABC):
|
|
156
166
|
"""Base class for Doctr textline detector. This class only implements the basic wrapper functions"""
|
|
157
167
|
|
|
158
|
-
def __init__(self, categories: Mapping[
|
|
159
|
-
self.categories = categories
|
|
168
|
+
def __init__(self, categories: Mapping[int, TypeOrStr], lib: Optional[Literal["PT", "TF"]] = None):
|
|
169
|
+
self.categories = ModelCategories(init_categories=categories)
|
|
160
170
|
self.lib = lib if lib is not None else self.auto_select_lib()
|
|
161
171
|
|
|
162
|
-
def
|
|
163
|
-
return
|
|
172
|
+
def get_category_names(self) -> tuple[ObjectTypes, ...]:
|
|
173
|
+
return self.categories.get_categories(as_dict=False)
|
|
164
174
|
|
|
165
175
|
@staticmethod
|
|
166
|
-
def get_name(path_weights:
|
|
176
|
+
def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
|
|
167
177
|
"""Returns the name of the model"""
|
|
168
178
|
return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
|
|
169
179
|
|
|
@@ -211,8 +221,8 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
|
|
|
211
221
|
def __init__(
|
|
212
222
|
self,
|
|
213
223
|
architecture: str,
|
|
214
|
-
path_weights:
|
|
215
|
-
categories: Mapping[
|
|
224
|
+
path_weights: PathLikeOrStr,
|
|
225
|
+
categories: Mapping[int, TypeOrStr],
|
|
216
226
|
device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
|
|
217
227
|
lib: Optional[Literal["PT", "TF"]] = None,
|
|
218
228
|
) -> None:
|
|
@@ -227,7 +237,7 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
|
|
|
227
237
|
"""
|
|
228
238
|
super().__init__(categories, lib)
|
|
229
239
|
self.architecture = architecture
|
|
230
|
-
self.path_weights = path_weights
|
|
240
|
+
self.path_weights = Path(path_weights)
|
|
231
241
|
|
|
232
242
|
self.name = self.get_name(self.path_weights, self.architecture)
|
|
233
243
|
self.model_id = self.get_model_id()
|
|
@@ -239,37 +249,37 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
|
|
|
239
249
|
|
|
240
250
|
self.doctr_predictor = self.get_wrapped_model(self.architecture, self.path_weights, self.device, self.lib)
|
|
241
251
|
|
|
242
|
-
def predict(self, np_img:
|
|
252
|
+
def predict(self, np_img: PixelValues) -> list[DetectionResult]:
|
|
243
253
|
"""
|
|
244
254
|
Prediction per image.
|
|
245
255
|
|
|
246
256
|
:param np_img: image as numpy array
|
|
247
257
|
:return: A list of DetectionResult
|
|
248
258
|
"""
|
|
249
|
-
|
|
250
|
-
return detection_results
|
|
259
|
+
return doctr_predict_text_lines(np_img, self.doctr_predictor, self.device, self.lib)
|
|
251
260
|
|
|
252
261
|
@classmethod
|
|
253
|
-
def get_requirements(cls) ->
|
|
254
|
-
|
|
255
|
-
return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
|
|
256
|
-
if os.environ.get("DD_USE_TORCH"):
|
|
257
|
-
return [get_pytorch_requirement(), get_doctr_requirement()]
|
|
258
|
-
raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextlineDetector")
|
|
262
|
+
def get_requirements(cls) -> list[Requirement]:
|
|
263
|
+
return _get_doctr_requirements()
|
|
259
264
|
|
|
260
|
-
def clone(self) ->
|
|
261
|
-
return self.__class__(
|
|
265
|
+
def clone(self) -> DoctrTextlineDetector:
|
|
266
|
+
return self.__class__(
|
|
267
|
+
self.architecture, self.path_weights, self.categories.get_categories(), self.device, self.lib
|
|
268
|
+
)
|
|
262
269
|
|
|
263
270
|
@staticmethod
|
|
264
271
|
def load_model(
|
|
265
|
-
path_weights:
|
|
272
|
+
path_weights: PathLikeOrStr,
|
|
273
|
+
doctr_predictor: DetectionPredictor,
|
|
274
|
+
device: Union[torch.device, tf.device],
|
|
275
|
+
lib: Literal["PT", "TF"],
|
|
266
276
|
) -> None:
|
|
267
277
|
"""Loading model weights"""
|
|
268
278
|
_load_model(path_weights, doctr_predictor, device, lib)
|
|
269
279
|
|
|
270
280
|
@staticmethod
|
|
271
281
|
def get_wrapped_model(
|
|
272
|
-
architecture: str, path_weights:
|
|
282
|
+
architecture: str, path_weights: PathLikeOrStr, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
|
|
273
283
|
) -> Any:
|
|
274
284
|
"""
|
|
275
285
|
Get the inner (wrapped) model.
|
|
@@ -290,6 +300,9 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
|
|
|
290
300
|
DoctrTextlineDetector.load_model(path_weights, doctr_predictor, device, lib)
|
|
291
301
|
return doctr_predictor
|
|
292
302
|
|
|
303
|
+
def clear_model(self) -> None:
|
|
304
|
+
self.doctr_predictor = None
|
|
305
|
+
|
|
293
306
|
|
|
294
307
|
class DoctrTextRecognizer(TextRecognizer):
|
|
295
308
|
"""
|
|
@@ -330,10 +343,10 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
330
343
|
def __init__(
|
|
331
344
|
self,
|
|
332
345
|
architecture: str,
|
|
333
|
-
path_weights:
|
|
346
|
+
path_weights: PathLikeOrStr,
|
|
334
347
|
device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
|
|
335
348
|
lib: Optional[Literal["PT", "TF"]] = None,
|
|
336
|
-
path_config_json: Optional[
|
|
349
|
+
path_config_json: Optional[PathLikeOrStr] = None,
|
|
337
350
|
) -> None:
|
|
338
351
|
"""
|
|
339
352
|
:param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
|
|
@@ -349,7 +362,7 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
349
362
|
self.lib = lib if lib is not None else self.auto_select_lib()
|
|
350
363
|
|
|
351
364
|
self.architecture = architecture
|
|
352
|
-
self.path_weights = path_weights
|
|
365
|
+
self.path_weights = Path(path_weights)
|
|
353
366
|
|
|
354
367
|
self.name = self.get_name(self.path_weights, self.architecture)
|
|
355
368
|
self.model_id = self.get_model_id()
|
|
@@ -360,13 +373,13 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
360
373
|
self.device = get_torch_device(device)
|
|
361
374
|
|
|
362
375
|
self.path_config_json = path_config_json
|
|
363
|
-
self.doctr_predictor = self.build_model(self.architecture, self.path_config_json)
|
|
376
|
+
self.doctr_predictor = self.build_model(self.architecture, self.lib, self.path_config_json)
|
|
364
377
|
self.load_model(self.path_weights, self.doctr_predictor, self.device, self.lib)
|
|
365
378
|
self.doctr_predictor = self.get_wrapped_model(
|
|
366
379
|
self.architecture, self.path_weights, self.device, self.lib, self.path_config_json
|
|
367
380
|
)
|
|
368
381
|
|
|
369
|
-
def predict(self, images:
|
|
382
|
+
def predict(self, images: list[tuple[str, PixelValues]]) -> list[DetectionResult]:
|
|
370
383
|
"""
|
|
371
384
|
Prediction on a batch of text lines
|
|
372
385
|
|
|
@@ -378,25 +391,26 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
378
391
|
return []
|
|
379
392
|
|
|
380
393
|
@classmethod
|
|
381
|
-
def get_requirements(cls) ->
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
return [get_pytorch_requirement(), get_doctr_requirement()]
|
|
386
|
-
raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextRecognizer")
|
|
387
|
-
|
|
388
|
-
def clone(self) -> PredictorBase:
|
|
394
|
+
def get_requirements(cls) -> list[Requirement]:
|
|
395
|
+
return _get_doctr_requirements()
|
|
396
|
+
|
|
397
|
+
def clone(self) -> DoctrTextRecognizer:
|
|
389
398
|
return self.__class__(self.architecture, self.path_weights, self.device, self.lib)
|
|
390
399
|
|
|
391
400
|
@staticmethod
|
|
392
401
|
def load_model(
|
|
393
|
-
path_weights:
|
|
402
|
+
path_weights: PathLikeOrStr,
|
|
403
|
+
doctr_predictor: RecognitionPredictor,
|
|
404
|
+
device: Union[torch.device, tf.device],
|
|
405
|
+
lib: Literal["PT", "TF"],
|
|
394
406
|
) -> None:
|
|
395
407
|
"""Loading model weights"""
|
|
396
408
|
_load_model(path_weights, doctr_predictor, device, lib)
|
|
397
409
|
|
|
398
410
|
@staticmethod
|
|
399
|
-
def build_model(
|
|
411
|
+
def build_model(
|
|
412
|
+
architecture: str, lib: Literal["TF", "PT"], path_config_json: Optional[PathLikeOrStr] = None
|
|
413
|
+
) -> RecognitionPredictor:
|
|
400
414
|
"""Building the model"""
|
|
401
415
|
|
|
402
416
|
# inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
|
|
@@ -419,6 +433,7 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
419
433
|
|
|
420
434
|
model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
|
|
421
435
|
else:
|
|
436
|
+
# This is not documented, but you can also directly pass the model class to architecture
|
|
422
437
|
if not isinstance(
|
|
423
438
|
architecture,
|
|
424
439
|
(recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
|
|
@@ -426,16 +441,16 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
426
441
|
raise ValueError(f"unknown architecture: {type(architecture)}")
|
|
427
442
|
model = architecture
|
|
428
443
|
|
|
429
|
-
input_shape = model.cfg["input_shape"][:2] if
|
|
444
|
+
input_shape = model.cfg["input_shape"][:2] if lib == "TF" else model.cfg["input_shape"][-2:]
|
|
430
445
|
return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model)
|
|
431
446
|
|
|
432
447
|
@staticmethod
|
|
433
448
|
def get_wrapped_model(
|
|
434
449
|
architecture: str,
|
|
435
|
-
path_weights:
|
|
450
|
+
path_weights: PathLikeOrStr,
|
|
436
451
|
device: Union[torch.device, tf.device],
|
|
437
452
|
lib: Literal["PT", "TF"],
|
|
438
|
-
path_config_json: Optional[
|
|
453
|
+
path_config_json: Optional[PathLikeOrStr] = None,
|
|
439
454
|
) -> Any:
|
|
440
455
|
"""
|
|
441
456
|
Get the inner (wrapped) model.
|
|
@@ -450,12 +465,12 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
450
465
|
a model trained on custom vocab.
|
|
451
466
|
:return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
|
|
452
467
|
"""
|
|
453
|
-
doctr_predictor = DoctrTextRecognizer.build_model(architecture, path_config_json)
|
|
468
|
+
doctr_predictor = DoctrTextRecognizer.build_model(architecture, lib, path_config_json)
|
|
454
469
|
DoctrTextRecognizer.load_model(path_weights, doctr_predictor, device, lib)
|
|
455
470
|
return doctr_predictor
|
|
456
471
|
|
|
457
472
|
@staticmethod
|
|
458
|
-
def get_name(path_weights:
|
|
473
|
+
def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
|
|
459
474
|
"""Returns the name of the model"""
|
|
460
475
|
return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
|
|
461
476
|
|
|
@@ -464,6 +479,9 @@ class DoctrTextRecognizer(TextRecognizer):
|
|
|
464
479
|
"""Auto select the DL library from the installed and from environment variables"""
|
|
465
480
|
return auto_select_lib_for_doctr()
|
|
466
481
|
|
|
482
|
+
def clear_model(self) -> None:
|
|
483
|
+
self.doctr_predictor = None
|
|
484
|
+
|
|
467
485
|
|
|
468
486
|
class DocTrRotationTransformer(ImageTransformer):
|
|
469
487
|
"""
|
|
@@ -497,7 +515,7 @@ class DocTrRotationTransformer(ImageTransformer):
|
|
|
497
515
|
self.ratio_threshold_for_lines = ratio_threshold_for_lines
|
|
498
516
|
self.name = "doctr_rotation_transformer"
|
|
499
517
|
|
|
500
|
-
def transform(self, np_img:
|
|
518
|
+
def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
|
|
501
519
|
"""
|
|
502
520
|
Applies the predicted rotation to the image, effectively rotating the image backwards.
|
|
503
521
|
This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
|
|
@@ -508,19 +526,18 @@ class DocTrRotationTransformer(ImageTransformer):
|
|
|
508
526
|
"""
|
|
509
527
|
return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
|
|
510
528
|
|
|
511
|
-
def predict(self, np_img:
|
|
529
|
+
def predict(self, np_img: PixelValues) -> DetectionResult:
|
|
512
530
|
angle = estimate_orientation(np_img, self.number_contours, self.ratio_threshold_for_lines)
|
|
513
531
|
if angle < 0:
|
|
514
532
|
angle += 360
|
|
515
533
|
return DetectionResult(angle=round(angle, 2))
|
|
516
534
|
|
|
517
535
|
@classmethod
|
|
518
|
-
def get_requirements(cls) ->
|
|
536
|
+
def get_requirements(cls) -> list[Requirement]:
|
|
519
537
|
return [get_doctr_requirement()]
|
|
520
538
|
|
|
521
|
-
def clone(self) ->
|
|
539
|
+
def clone(self) -> DocTrRotationTransformer:
|
|
522
540
|
return self.__class__(self.number_contours, self.ratio_threshold_for_lines)
|
|
523
541
|
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
return PageType.angle
|
|
542
|
+
def get_category_names(self) -> tuple[ObjectTypes, ...]:
|
|
543
|
+
return (PageType.ANGLE,)
|
|
@@ -18,16 +18,20 @@
|
|
|
18
18
|
"""
|
|
19
19
|
Deepdoctection wrappers for fasttext language detection models
|
|
20
20
|
"""
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import os
|
|
21
24
|
from abc import ABC
|
|
22
|
-
from copy import copy
|
|
23
25
|
from pathlib import Path
|
|
24
|
-
from
|
|
26
|
+
from types import MappingProxyType
|
|
27
|
+
from typing import Any, Mapping, Union
|
|
25
28
|
|
|
26
29
|
from lazy_imports import try_import
|
|
27
30
|
|
|
28
31
|
from ..utils.file_utils import Requirement, get_fasttext_requirement
|
|
29
32
|
from ..utils.settings import TypeOrStr, get_type
|
|
30
|
-
from .
|
|
33
|
+
from ..utils.types import PathLikeOrStr
|
|
34
|
+
from .base import DetectionResult, LanguageDetector, ModelCategories
|
|
31
35
|
|
|
32
36
|
with try_import() as import_guard:
|
|
33
37
|
from fasttext import load_model # type: ignore
|
|
@@ -38,22 +42,23 @@ class FasttextLangDetectorMixin(LanguageDetector, ABC):
|
|
|
38
42
|
Base class for Fasttext language detection implementation. This class only implements the basic wrapper functions.
|
|
39
43
|
"""
|
|
40
44
|
|
|
41
|
-
def __init__(self, categories: Mapping[str, TypeOrStr]) -> None:
|
|
45
|
+
def __init__(self, categories: Mapping[int, TypeOrStr], categories_orig: Mapping[str, TypeOrStr]) -> None:
|
|
42
46
|
"""
|
|
43
47
|
:param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language
|
|
44
48
|
"""
|
|
45
|
-
self.categories =
|
|
49
|
+
self.categories = ModelCategories(init_categories=categories)
|
|
50
|
+
self.categories_orig = MappingProxyType({cat_orig: get_type(cat) for cat_orig, cat in categories_orig.items()})
|
|
46
51
|
|
|
47
|
-
def output_to_detection_result(self, output: Union[
|
|
52
|
+
def output_to_detection_result(self, output: Union[tuple[Any, Any]]) -> DetectionResult:
|
|
48
53
|
"""
|
|
49
54
|
Generating `DetectionResult` from model output
|
|
50
55
|
:param output: FastText model output
|
|
51
56
|
:return: `DetectionResult` filled with `text` and `score`
|
|
52
57
|
"""
|
|
53
|
-
return DetectionResult(text=self.
|
|
58
|
+
return DetectionResult(text=self.categories_orig[output[0][0]], score=output[1][0])
|
|
54
59
|
|
|
55
60
|
@staticmethod
|
|
56
|
-
def get_name(path_weights:
|
|
61
|
+
def get_name(path_weights: PathLikeOrStr) -> str:
|
|
57
62
|
"""Returns the name of the model"""
|
|
58
63
|
return "fasttext_" + "_".join(Path(path_weights).parts[-2:])
|
|
59
64
|
|
|
@@ -80,15 +85,17 @@ class FasttextLangDetector(FasttextLangDetectorMixin):
|
|
|
80
85
|
|
|
81
86
|
"""
|
|
82
87
|
|
|
83
|
-
def __init__(
|
|
88
|
+
def __init__(
|
|
89
|
+
self, path_weights: PathLikeOrStr, categories: Mapping[int, TypeOrStr], categories_orig: Mapping[str, TypeOrStr]
|
|
90
|
+
):
|
|
84
91
|
"""
|
|
85
92
|
:param path_weights: path to model weights
|
|
86
93
|
:param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language
|
|
87
94
|
code.
|
|
88
95
|
"""
|
|
89
|
-
super().__init__(categories)
|
|
96
|
+
super().__init__(categories, categories_orig)
|
|
90
97
|
|
|
91
|
-
self.path_weights = path_weights
|
|
98
|
+
self.path_weights = Path(path_weights)
|
|
92
99
|
|
|
93
100
|
self.name = self.get_name(self.path_weights)
|
|
94
101
|
self.model_id = self.get_model_id()
|
|
@@ -100,16 +107,16 @@ class FasttextLangDetector(FasttextLangDetectorMixin):
|
|
|
100
107
|
return self.output_to_detection_result(output)
|
|
101
108
|
|
|
102
109
|
@classmethod
|
|
103
|
-
def get_requirements(cls) ->
|
|
110
|
+
def get_requirements(cls) -> list[Requirement]:
|
|
104
111
|
return [get_fasttext_requirement()]
|
|
105
112
|
|
|
106
|
-
def clone(self) ->
|
|
107
|
-
return self.__class__(self.path_weights, self.categories)
|
|
113
|
+
def clone(self) -> FasttextLangDetector:
|
|
114
|
+
return self.__class__(self.path_weights, self.categories.get_categories(), self.categories_orig)
|
|
108
115
|
|
|
109
116
|
@staticmethod
|
|
110
|
-
def get_wrapped_model(path_weights:
|
|
117
|
+
def get_wrapped_model(path_weights: PathLikeOrStr) -> Any:
|
|
111
118
|
"""
|
|
112
119
|
Get the wrapped model
|
|
113
120
|
:param path_weights: path to model weights
|
|
114
121
|
"""
|
|
115
|
-
return load_model(path_weights)
|
|
122
|
+
return load_model(os.fspath(path_weights))
|