deepdoctection 0.31__py3-none-any.whl → 0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of deepdoctection might be problematic.
- deepdoctection/__init__.py +16 -29
- deepdoctection/analyzer/dd.py +70 -59
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +41 -56
- deepdoctection/datapoint/box.py +9 -8
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +56 -44
- deepdoctection/datapoint/view.py +245 -150
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +35 -26
- deepdoctection/datasets/base.py +14 -12
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +24 -26
- deepdoctection/datasets/instances/doclaynet.py +51 -51
- deepdoctection/datasets/instances/fintabnet.py +46 -46
- deepdoctection/datasets/instances/funsd.py +25 -24
- deepdoctection/datasets/instances/iiitar13k.py +13 -10
- deepdoctection/datasets/instances/layouttest.py +4 -3
- deepdoctection/datasets/instances/publaynet.py +5 -5
- deepdoctection/datasets/instances/pubtables1m.py +24 -21
- deepdoctection/datasets/instances/pubtabnet.py +32 -30
- deepdoctection/datasets/instances/rvlcdip.py +30 -30
- deepdoctection/datasets/instances/xfund.py +26 -26
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +15 -13
- deepdoctection/eval/eval.py +41 -37
- deepdoctection/eval/tedsmetric.py +30 -23
- deepdoctection/eval/tp_eval_callback.py +16 -19
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +85 -113
- deepdoctection/extern/deskew.py +14 -11
- deepdoctection/extern/doctrocr.py +141 -130
- deepdoctection/extern/fastlang.py +27 -18
- deepdoctection/extern/hfdetr.py +71 -62
- deepdoctection/extern/hflayoutlm.py +504 -211
- deepdoctection/extern/hflm.py +230 -0
- deepdoctection/extern/model.py +488 -302
- deepdoctection/extern/pdftext.py +23 -19
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +29 -19
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +18 -18
- deepdoctection/extern/tp/tfutils.py +57 -9
- deepdoctection/extern/tp/tpcompat.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +45 -53
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/cats.py +27 -29
- deepdoctection/mapper/cocostruct.py +10 -10
- deepdoctection/mapper/d2struct.py +27 -26
- deepdoctection/mapper/hfstruct.py +13 -8
- deepdoctection/mapper/laylmstruct.py +178 -37
- deepdoctection/mapper/maputils.py +12 -11
- deepdoctection/mapper/match.py +2 -2
- deepdoctection/mapper/misc.py +11 -9
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +5 -5
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +5 -5
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +12 -14
- deepdoctection/pipe/base.py +52 -106
- deepdoctection/pipe/common.py +72 -59
- deepdoctection/pipe/concurrency.py +16 -11
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +20 -16
- deepdoctection/pipe/lm.py +75 -105
- deepdoctection/pipe/order.py +194 -89
- deepdoctection/pipe/refine.py +111 -124
- deepdoctection/pipe/segment.py +156 -161
- deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +48 -41
- deepdoctection/train/hf_detr_train.py +41 -30
- deepdoctection/train/hf_layoutlm_train.py +153 -135
- deepdoctection/train/tp_frcnn_train.py +32 -31
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +87 -125
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +22 -18
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +16 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +11 -11
- deepdoctection/utils/settings.py +185 -181
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +74 -72
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
- deepdoctection-0.33.dist-info/RECORD +146 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.31.dist-info/RECORD +0 -144
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
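The file list shows `deepdoctection/utils/detection_types.py` deleted and `deepdoctection/utils/types.py` added, so the package's type aliases now live in a new module. A minimal sketch of what downstream code would import in 0.33 — the alias names are taken from the rewritten imports in the hunks below; that `PixelValues` is the numpy-image alias is an inference from its use on every `np_img` parameter, and the helper function is hypothetical:

```python
import numpy as np

# Aliases visible in the 0.33 imports below. PixelValues annotates numpy images;
# PathLikeOrStr replaces plain str paths and pairs with the new Path(...) handling.
from deepdoctection.utils.types import PathLikeOrStr, PixelValues


def mean_brightness(np_img: PixelValues) -> float:
    """Toy helper typed with the new aliases (illustration only, not part of the diff)."""
    return float(np.mean(np_img))
```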
deepdoctection/extern/d2detect.py CHANGED

@@ -18,43 +18,40 @@
 """
 D2 GeneralizedRCNN model as predictor for deepdoctection pipeline
 """
+from __future__ import annotations
+
 import io
+import os
 from abc import ABC
 from copy import copy
 from pathlib import Path
-from typing import
+from typing import Literal, Mapping, Optional, Sequence, Union
 
 import numpy as np
+from lazy_imports import try_import
 
-from ..utils.
-from ..utils.file_utils import (
-    detectron2_available,
-    get_detectron2_requirement,
-    get_pytorch_requirement,
-    pytorch_available,
-)
+from ..utils.file_utils import get_detectron2_requirement, get_pytorch_requirement
 from ..utils.metacfg import AttrDict, set_config_by_yaml
-from ..utils.settings import ObjectTypes, TypeOrStr, get_type
+from ..utils.settings import DefaultType, ObjectTypes, TypeOrStr, get_type
 from ..utils.transform import InferenceResize, ResizeTransform
-from .
+from ..utils.types import PathLikeOrStr, PixelValues, Requirement
+from .base import DetectionResult, ModelCategories, ObjectDetector
 from .pt.nms import batched_nms
-from .pt.ptutils import
+from .pt.ptutils import get_torch_device
 
-
+with try_import() as pt_import_guard:
     import torch
     import torch.cuda
     from torch import nn  # pylint: disable=W0611
 
-
+with try_import() as d2_import_guard:
     from detectron2.checkpoint import DetectionCheckpointer
     from detectron2.config import CfgNode, get_cfg  # pylint: disable=W0611
     from detectron2.modeling import GeneralizedRCNN, build_model  # pylint: disable=W0611
     from detectron2.structures import Instances  # pylint: disable=W0611
 
 
-def _d2_post_processing(
-    predictions: Dict[str, "Instances"], nms_thresh_class_agnostic: float
-) -> Dict[str, "Instances"]:
+def _d2_post_processing(predictions: dict[str, Instances], nms_thresh_class_agnostic: float) -> dict[str, Instances]:
     """
     D2 postprocessing steps, so that detection outputs are aligned with outputs of other packages (e.g. Tensorpack).
     Apply a class agnostic NMS.

@@ -71,11 +68,11 @@ def _d2_post_processing(
 
 
 def d2_predict_image(
-    np_img:
-    predictor:
+    np_img: PixelValues,
+    predictor: nn.Module,
     resizer: InferenceResize,
     nms_thresh_class_agnostic: float,
-) ->
+) -> list[DetectionResult]:
     """
     Run detection on one image, using the D2 model callable. It will also handle the preprocessing internally which
     is using a custom resizing within some bounds.

@@ -107,8 +104,8 @@ def d2_predict_image(
 
 
 def d2_jit_predict_image(
-    np_img:
-) ->
+    np_img: PixelValues, d2_predictor: nn.Module, resizer: InferenceResize, nms_thresh_class_agnostic: float
+) -> list[DetectionResult]:
     """
     Run detection on an image using torchscript. It will also handle the preprocessing internally which
     is using a custom resizing within some bounds. Moreover, and different from the setting where Detectron2 is used

@@ -152,7 +149,7 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
 
     def __init__(
         self,
-        categories: Mapping[
+        categories: Mapping[int, TypeOrStr],
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """

@@ -163,37 +160,31 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
                                   be filtered. Pass a list of category names that must not be returned
         """
 
+        self.categories = ModelCategories(init_categories=categories)
         if filter_categories:
-            filter_categories =
-        self.filter_categories = filter_categories
-        self._categories_d2 = self._map_to_d2_categories(copy(categories))
-        self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
+            self.categories.filter_categories = tuple(get_type(cat) for cat in filter_categories)
 
-    def _map_category_names(self, detection_results:
+    def _map_category_names(self, detection_results: list[DetectionResult]) -> list[DetectionResult]:
         """
         Populating category names to detection results
 
         :param detection_results: list of detection results. Will also filter categories
         :return: List of detection results with attribute class_name populated
         """
-        filtered_detection_result:
+        filtered_detection_result: list[DetectionResult] = []
+        shifted_categories = self.categories.shift_category_ids(shift_by=-1)
         for result in detection_results:
-            result.class_name =
-
-
-            if
-            if result.
+            result.class_name = shifted_categories.get(
+                result.class_id if result.class_id is not None else -1, DefaultType.DEFAULT_TYPE
+            )
+            if result.class_name != DefaultType.DEFAULT_TYPE:
+                if result.class_id is not None:
+                    result.class_id += 1
                 filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
         return filtered_detection_result
 
-
-
-        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
-
-    def possible_categories(self) -> List[ObjectTypes]:
-        return list(self.categories.values())
+    def get_category_names(self) -> tuple[ObjectTypes, ...]:
+        return self.categories.get_categories(as_dict=False)
 
     @staticmethod
     def get_inference_resizer(min_size_test: int, max_size_test: int) -> InferenceResize:

@@ -205,7 +196,7 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
         return InferenceResize(min_size_test, max_size_test)
 
     @staticmethod
-    def get_name(path_weights:
+    def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
         """Returns the name of the model"""
         return f"detectron2_{architecture}" + "_".join(Path(path_weights).parts[-2:])
 

@@ -234,11 +225,11 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
 
     def __init__(
         self,
-        path_yaml:
-        path_weights:
-        categories: Mapping[
-        config_overwrite: Optional[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        path_yaml: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        config_overwrite: Optional[list[str]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """

@@ -261,18 +252,15 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         """
         super().__init__(categories, filter_categories)
 
-        self.path_weights = path_weights
-        self.path_yaml = path_yaml
+        self.path_weights = Path(path_weights)
+        self.path_yaml = Path(path_yaml)
 
         config_overwrite = config_overwrite if config_overwrite else []
         self.config_overwrite = config_overwrite
-
-            self.device = device
-        else:
-            self.device = set_torch_auto_device()
+        self.device = get_torch_device(device)
 
         d2_conf_list = self._get_d2_config_list(path_weights, config_overwrite)
-        self.cfg = self._set_config(path_yaml, d2_conf_list, device)
+        self.cfg = self._set_config(path_yaml, d2_conf_list, self.device)
 
         self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
         self.model_id = self.get_model_id()

@@ -282,21 +270,18 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
 
     @staticmethod
-    def _set_config(
-        path_yaml: str, d2_conf_list: List[str], device: Optional[Literal["cpu", "cuda"]] = None
-    ) -> "CfgNode":
+    def _set_config(path_yaml: PathLikeOrStr, d2_conf_list: list[str], device: torch.device) -> CfgNode:
         cfg = get_cfg()
         # additional attribute with default value, so that the true value can be loaded from the configs
         cfg.NMS_THRESH_CLASS_AGNOSTIC = 0.1
-        cfg.merge_from_file(path_yaml)
+        cfg.merge_from_file(os.fspath(path_yaml))
         cfg.merge_from_list(d2_conf_list)
-
-        cfg.MODEL.DEVICE = "cpu"
+        cfg.MODEL.DEVICE = str(device)
         cfg.freeze()
         return cfg
 
     @staticmethod
-    def _set_model(config:
+    def _set_model(config: CfgNode) -> GeneralizedRCNN:
         """
         Build the D2 model. It uses the available builtin tools of D2
 

@@ -306,11 +291,11 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         return build_model(config.clone()).eval()
 
     @staticmethod
-    def _instantiate_d2_predictor(wrapped_model:
+    def _instantiate_d2_predictor(wrapped_model: GeneralizedRCNN, path_weights: PathLikeOrStr) -> None:
         checkpointer = DetectionCheckpointer(wrapped_model)
-        checkpointer.load(path_weights)
+        checkpointer.load(os.fspath(path_weights))
 
-    def predict(self, np_img:
+    def predict(self, np_img: PixelValues) -> list[DetectionResult]:
         """
         Prediction per image.
 

@@ -326,23 +311,26 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         return self._map_category_names(detection_results)
 
     @classmethod
-    def get_requirements(cls) ->
+    def get_requirements(cls) -> list[Requirement]:
         return [get_pytorch_requirement(), get_detectron2_requirement()]
 
-    def clone(self) ->
+    def clone(self) -> D2FrcnnDetector:
         return self.__class__(
             self.path_yaml,
             self.path_weights,
-            self.categories,
+            self.categories.get_categories(),
            self.config_overwrite,
             self.device,
-            self.filter_categories,
+            self.categories.filter_categories,
         )
 
     @staticmethod
     def get_wrapped_model(
-        path_yaml:
-
+        path_yaml: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        config_overwrite: list[str],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+    ) -> GeneralizedRCNN:
         """
         Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
 

@@ -365,8 +353,7 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         :return: Detectron2 GeneralizedRCNN model
         """
 
-
-            device = set_torch_auto_device()
+        device = get_torch_device(device)
         d2_conf_list = D2FrcnnDetector._get_d2_config_list(path_weights, config_overwrite)
         cfg = D2FrcnnDetector._set_config(path_yaml, d2_conf_list, device)
         model = D2FrcnnDetector._set_model(cfg)

@@ -374,14 +361,17 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         return model
 
     @staticmethod
-    def _get_d2_config_list(path_weights:
-        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
+    def _get_d2_config_list(path_weights: PathLikeOrStr, config_overwrite: list[str]) -> list[str]:
+        d2_conf_list = ["MODEL.WEIGHTS", os.fspath(path_weights)]
         config_overwrite = config_overwrite if config_overwrite else []
         for conf in config_overwrite:
             key, val = conf.split("=", maxsplit=1)
             d2_conf_list.extend([key, val])
         return d2_conf_list
 
+    def clear_model(self) -> None:
+        self.d2_predictor = None
+
 
 class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
     """

@@ -409,10 +399,10 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
 
     def __init__(
         self,
-        path_yaml:
-        path_weights:
-        categories: Mapping[
-        config_overwrite: Optional[
+        path_yaml: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        config_overwrite: Optional[list[str]] = None,
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """

@@ -432,8 +422,8 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
 
         super().__init__(categories, filter_categories)
 
-        self.path_weights = path_weights
-        self.path_yaml = path_yaml
+        self.path_weights = Path(path_weights)
+        self.path_yaml = Path(path_yaml)
 
         self.config_overwrite = copy(config_overwrite)
         self.cfg = self._set_config(self.path_yaml, self.path_weights, self.config_overwrite)

@@ -445,14 +435,16 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
         self.d2_predictor = self.get_wrapped_model(self.path_weights)
 
     @staticmethod
-    def _set_config(
+    def _set_config(
+        path_yaml: PathLikeOrStr, path_weights: PathLikeOrStr, config_overwrite: Optional[list[str]]
+    ) -> AttrDict:
         cfg = set_config_by_yaml(path_yaml)
         config_overwrite = config_overwrite if config_overwrite else []
-        config_overwrite.extend([f"MODEL.WEIGHTS={path_weights}"])
+        config_overwrite.extend([f"MODEL.WEIGHTS={os.fspath(path_weights)}"])
         cfg.update_args(config_overwrite)
         return cfg
 
-    def predict(self, np_img:
+    def predict(self, np_img: PixelValues) -> list[DetectionResult]:
         """
         Prediction per image.
 

@@ -468,46 +460,23 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
         return self._map_category_names(detection_results)
 
     @classmethod
-    def get_requirements(cls) ->
+    def get_requirements(cls) -> list[Requirement]:
         return [get_pytorch_requirement()]
 
-
-    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
-        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
-
-    def clone(self) -> PredictorBase:
+    def clone(self) -> D2FrcnnTracingDetector:
         return self.__class__(
             self.path_yaml,
             self.path_weights,
-            self.categories,
+            self.categories.get_categories(),
             self.config_overwrite,
-            self.filter_categories,
+            self.categories.filter_categories,
         )
 
-    def
-
-        Populating category names to detection results
-
-        :param detection_results: list of detection results. Will also filter categories
-        :return: List of detection results with attribute class_name populated
-        """
-        filtered_detection_result: List[DetectionResult] = []
-        for result in detection_results:
-            result.class_name = self._categories_d2[str(result.class_id)]
-            if isinstance(result.class_id, int):
-                result.class_id += 1
-            if self.filter_categories:
-                if result.class_name not in self.filter_categories:
-                    filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
-        return filtered_detection_result
-
-    def possible_categories(self) -> List[ObjectTypes]:
-        return list(self.categories.values())
+    def get_category_names(self) -> tuple[ObjectTypes, ...]:
+        return self.categories.get_categories(as_dict=False)
 
     @staticmethod
-    def get_wrapped_model(path_weights:
+    def get_wrapped_model(path_weights: PathLikeOrStr) -> torch.jit.ScriptModule:
         """
         Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
 

@@ -518,3 +487,6 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
         buffer = io.BytesIO(file.read())
         # Load all tensors to the original device
         return torch.jit.load(buffer)
+
+    def clear_model(self) -> None:
+        self.d2_predictor = None  # type: ignore
deepdoctection/extern/deskew.py CHANGED

@@ -18,16 +18,17 @@
 """
 jdeskew estimator and rotator to deskew images: <https://github.com/phamquiluan/jdeskew>
 """
+from __future__ import annotations
 
-from
+from lazy_imports import try_import
 
-from ..utils.
-from ..utils.
-from ..utils.
+from ..utils.file_utils import get_jdeskew_requirement
+from ..utils.settings import ObjectTypes, PageType
+from ..utils.types import PixelValues, Requirement
 from ..utils.viz import viz_handler
 from .base import DetectionResult, ImageTransformer
 
-
+with try_import() as import_guard:
     from jdeskew.estimator import get_angle
 
 

@@ -42,7 +43,7 @@ class Jdeskewer(ImageTransformer):
         self.model_id = self.get_model_id()
         self.min_angle_rotation = min_angle_rotation
 
-    def transform(self, np_img:
+    def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
         Rotation of the image according to the angle determined by the jdeskew estimator.
 

@@ -59,7 +60,7 @@ class Jdeskewer(ImageTransformer):
             return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore
         return np_img
 
-    def predict(self, np_img:
+    def predict(self, np_img: PixelValues) -> DetectionResult:
         """
         Predict the angle of the image to deskew it.
 

@@ -69,12 +70,14 @@ class Jdeskewer(ImageTransformer):
         return DetectionResult(angle=round(float(get_angle(np_img)), 4))
 
     @classmethod
-    def get_requirements(cls) ->
+    def get_requirements(cls) -> list[Requirement]:
         """
         Get a list of requirements for running the detector
         """
         return [get_jdeskew_requirement()]
 
-
-
-
+    def clone(self) -> Jdeskewer:
+        return self.__class__(self.min_angle_rotation)
+
+    def get_category_names(self) -> tuple[ObjectTypes, ...]:
+        return (PageType.ANGLE,)