deepdoctection-0.30-py3-none-any.whl → deepdoctection-0.32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of deepdoctection has been flagged by the registry.
- deepdoctection/__init__.py +38 -29
- deepdoctection/analyzer/dd.py +36 -29
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/base.py +0 -19
- deepdoctection/dataflow/custom.py +4 -3
- deepdoctection/dataflow/custom_serialize.py +14 -5
- deepdoctection/dataflow/parallel_map.py +12 -11
- deepdoctection/dataflow/serialize.py +5 -4
- deepdoctection/datapoint/annotation.py +35 -13
- deepdoctection/datapoint/box.py +3 -5
- deepdoctection/datapoint/convert.py +3 -1
- deepdoctection/datapoint/image.py +79 -36
- deepdoctection/datapoint/view.py +152 -49
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +6 -3
- deepdoctection/datasets/base.py +86 -11
- deepdoctection/datasets/dataflow_builder.py +1 -1
- deepdoctection/datasets/info.py +4 -4
- deepdoctection/datasets/instances/doclaynet.py +3 -2
- deepdoctection/datasets/instances/fintabnet.py +2 -1
- deepdoctection/datasets/instances/funsd.py +2 -1
- deepdoctection/datasets/instances/iiitar13k.py +5 -2
- deepdoctection/datasets/instances/layouttest.py +4 -8
- deepdoctection/datasets/instances/publaynet.py +2 -2
- deepdoctection/datasets/instances/pubtables1m.py +6 -3
- deepdoctection/datasets/instances/pubtabnet.py +2 -1
- deepdoctection/datasets/instances/rvlcdip.py +2 -1
- deepdoctection/datasets/instances/xfund.py +2 -1
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +1 -1
- deepdoctection/eval/base.py +5 -4
- deepdoctection/eval/cocometric.py +2 -1
- deepdoctection/eval/eval.py +19 -15
- deepdoctection/eval/tedsmetric.py +14 -11
- deepdoctection/eval/tp_eval_callback.py +14 -7
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +39 -13
- deepdoctection/extern/d2detect.py +182 -90
- deepdoctection/extern/deskew.py +36 -9
- deepdoctection/extern/doctrocr.py +265 -83
- deepdoctection/extern/fastlang.py +49 -9
- deepdoctection/extern/hfdetr.py +106 -55
- deepdoctection/extern/hflayoutlm.py +441 -122
- deepdoctection/extern/hflm.py +225 -0
- deepdoctection/extern/model.py +56 -47
- deepdoctection/extern/pdftext.py +10 -5
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +27 -18
- deepdoctection/extern/tessocr.py +134 -22
- deepdoctection/extern/texocr.py +6 -2
- deepdoctection/extern/tp/tfutils.py +43 -9
- deepdoctection/extern/tp/tpcompat.py +14 -11
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +8 -9
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +54 -30
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/d2struct.py +9 -7
- deepdoctection/mapper/hfstruct.py +7 -2
- deepdoctection/mapper/laylmstruct.py +164 -21
- deepdoctection/mapper/maputils.py +16 -3
- deepdoctection/mapper/misc.py +6 -3
- deepdoctection/mapper/prodigystruct.py +1 -1
- deepdoctection/mapper/pubstruct.py +10 -10
- deepdoctection/mapper/tpstruct.py +3 -3
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +35 -8
- deepdoctection/pipe/base.py +53 -19
- deepdoctection/pipe/common.py +23 -13
- deepdoctection/pipe/concurrency.py +2 -1
- deepdoctection/pipe/doctectionpipe.py +2 -2
- deepdoctection/pipe/language.py +3 -2
- deepdoctection/pipe/layout.py +6 -3
- deepdoctection/pipe/lm.py +34 -66
- deepdoctection/pipe/order.py +142 -35
- deepdoctection/pipe/refine.py +26 -24
- deepdoctection/pipe/segment.py +21 -16
- deepdoctection/pipe/{cell.py → sub_layout.py} +30 -9
- deepdoctection/pipe/text.py +14 -8
- deepdoctection/pipe/transform.py +16 -9
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +36 -28
- deepdoctection/train/hf_detr_train.py +26 -17
- deepdoctection/train/hf_layoutlm_train.py +133 -111
- deepdoctection/train/tp_frcnn_train.py +21 -19
- deepdoctection/utils/__init__.py +3 -0
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +2 -2
- deepdoctection/utils/env_info.py +41 -84
- deepdoctection/utils/error.py +84 -0
- deepdoctection/utils/file_utils.py +4 -15
- deepdoctection/utils/fs.py +7 -7
- deepdoctection/utils/logger.py +1 -0
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +5 -4
- deepdoctection/utils/settings.py +6 -1
- deepdoctection/utils/transform.py +1 -1
- deepdoctection/utils/utils.py +0 -6
- deepdoctection/utils/viz.py +48 -5
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/METADATA +57 -73
- deepdoctection-0.32.dist-info/RECORD +146 -0
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/WHEEL +1 -1
- deepdoctection-0.30.dist-info/RECORD +0 -143
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
deepdoctection/extern/d2detect.py
CHANGED

@@ -18,42 +18,39 @@
 """
 D2 GeneralizedRCNN model as predictor for deepdoctection pipeline
 """
+from __future__ import annotations
+
 import io
+from abc import ABC
 from copy import copy
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union

 import numpy as np
+from lazy_imports import try_import

 from ..utils.detection_types import ImageType, Requirement
-from ..utils.file_utils import (
-    detectron2_available,
-    get_detectron2_requirement,
-    get_pytorch_requirement,
-    pytorch_available,
-)
-from ..utils.metacfg import set_config_by_yaml
+from ..utils.file_utils import get_detectron2_requirement, get_pytorch_requirement
+from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.settings import ObjectTypes, TypeOrStr, get_type
 from ..utils.transform import InferenceResize, ResizeTransform
 from .base import DetectionResult, ObjectDetector, PredictorBase
 from .pt.nms import batched_nms
-from .pt.ptutils import set_torch_auto_device
+from .pt.ptutils import get_torch_device

-if pytorch_available():
+with try_import() as pt_import_guard:
     import torch
     import torch.cuda
     from torch import nn  # pylint: disable=W0611

-if detectron2_available():
+with try_import() as d2_import_guard:
     from detectron2.checkpoint import DetectionCheckpointer
     from detectron2.config import CfgNode, get_cfg  # pylint: disable=W0611
     from detectron2.modeling import GeneralizedRCNN, build_model  # pylint: disable=W0611
     from detectron2.structures import Instances  # pylint: disable=W0611


-def _d2_post_processing(
-    predictions: Dict[str, "Instances"], nms_thresh_class_agnostic: float
-) -> Dict[str, "Instances"]:
+def _d2_post_processing(predictions: Dict[str, Instances], nms_thresh_class_agnostic: float) -> Dict[str, Instances]:
     """
     D2 postprocessing steps, so that detection outputs are aligned with outputs of other packages (e.g. Tensorpack).
     Apply a class agnostic NMS.
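Note: this hunk swaps the `pytorch_available()` / `detectron2_available()` guards for `lazy_imports.try_import`, so importing the module no longer hard-fails when torch or detectron2 is missing. A minimal sketch of that pattern, assuming the `check()` method documented by the lazy-imports package (the guard and function names here are illustrative):

```python
from lazy_imports import try_import

# The context manager swallows a failing import instead of raising,
# so this module can always be imported.
with try_import() as torch_guard:
    import torch

def function_that_needs_torch() -> None:
    # Re-raises the deferred ImportError only when torch is actually used.
    torch_guard.check()
    print(torch.cuda.is_available())
```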
@@ -71,7 +68,7 @@ def _d2_post_processing(

 def d2_predict_image(
     np_img: ImageType,
-    predictor: "nn.Module",
+    predictor: nn.Module,
     resizer: InferenceResize,
     nms_thresh_class_agnostic: float,
 ) -> List[DetectionResult]:
@@ -106,7 +103,7 @@ def d2_predict_image(


 def d2_jit_predict_image(
-    np_img: ImageType, d2_predictor: "nn.Module", resizer: InferenceResize, nms_thresh_class_agnostic: float
+    np_img: ImageType, d2_predictor: nn.Module, resizer: InferenceResize, nms_thresh_class_agnostic: float
 ) -> List[DetectionResult]:
     """
     Run detection on an image using torchscript. It will also handle the preprocessing internally which
@@ -144,7 +141,72 @@ def d2_jit_predict_image(
     return detect_result_list


-class D2FrcnnDetector(ObjectDetector):
+class D2FrcnnDetectorMixin(ObjectDetector, ABC):
+    """
+    Base class for D2 Faster-RCNN implementation. This class only implements the basic wrapper functions
+    """
+
+    def __init__(
+        self,
+        categories: Mapping[str, TypeOrStr],
+        filter_categories: Optional[Sequence[TypeOrStr]] = None,
+    ):
+        """
+        :param categories: A dict with key (indices) and values (category names). Index 0 must be reserved for a
+                           dummy 'BG' category. Note, that this convention is different from the builtin D2 framework,
+                           where models in the model zoo are trained with 'BG' class having the highest index.
+        :param filter_categories: The model might return objects that are not supposed to be predicted and that should
+                                  be filtered. Pass a list of category names that must not be returned
+        """
+
+        if filter_categories:
+            filter_categories = [get_type(cat) for cat in filter_categories]
+        self.filter_categories = filter_categories
+        self._categories_d2 = self._map_to_d2_categories(copy(categories))
+        self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
+
+    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
+        """
+        Populating category names to detection results
+
+        :param detection_results: list of detection results. Will also filter categories
+        :return: List of detection results with attribute class_name populated
+        """
+        filtered_detection_result: List[DetectionResult] = []
+        for result in detection_results:
+            result.class_name = self._categories_d2[str(result.class_id)]
+            if isinstance(result.class_id, int):
+                result.class_id += 1
+            if self.filter_categories:
+                if result.class_name not in self.filter_categories:
+                    filtered_detection_result.append(result)
+            else:
+                filtered_detection_result.append(result)
+        return filtered_detection_result
+
+    @classmethod
+    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
+        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
+
+    def possible_categories(self) -> List[ObjectTypes]:
+        return list(self.categories.values())
+
+    @staticmethod
+    def get_inference_resizer(min_size_test: int, max_size_test: int) -> InferenceResize:
+        """Returns the resizer for the inference
+
+        :param min_size_test: minimum size of the resized image
+        :param max_size_test: maximum size of the resized image
+        """
+        return InferenceResize(min_size_test, max_size_test)
+
+    @staticmethod
+    def get_name(path_weights: str, architecture: str) -> str:
+        """Returns the name of the model"""
+        return f"detectron2_{architecture}" + "_".join(Path(path_weights).parts[-2:])
+
+
+class D2FrcnnDetector(D2FrcnnDetectorMixin):
     """
     D2 Faster-RCNN implementation with all the available backbones, normalizations throughout the model
     as well as FPN, optional Cascade-RCNN and many more.
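Note: the new `D2FrcnnDetectorMixin` keeps two category dicts: `self.categories` in the deepdoctection convention (one-based, index 0 reserved for the dummy 'BG' class) and `self._categories_d2`, shifted down by one to match the zero-based class ids D2 emits. A small worked example of the two mappings (the category names are illustrative):

```python
# deepdoctection convention: one-based keys, "0" implicitly reserved for BG
categories = {"1": "text", "2": "title", "3": "table"}

# _map_to_d2_categories shifts every key down by one for D2's zero-based ids
categories_d2 = {str(int(k) - 1): v for k, v in categories.items()}
assert categories_d2 == {"0": "text", "1": "title", "2": "table"}

# _map_category_names resolves a raw D2 class_id against the shifted dict,
# then increments it back to the one-based deepdoctection id:
class_id = 2
class_name = categories_d2[str(class_id)]  # "table"
class_id += 1                              # 3, i.e. categories["3"]
```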
@@ -155,6 +217,7 @@ class D2FrcnnDetector(ObjectDetector):
     the standard D2 output that takes into account of the situation that detected objects are disjoint. For more infos
     on this topic, see <https://github.com/facebookresearch/detectron2/issues/978> .

+        ```python
         config_path = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
         weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
         categories = ModelCatalog.get_profile("item/d2_model-800000-layout.pkl").categories
@@ -162,6 +225,7 @@ class D2FrcnnDetector(ObjectDetector):
         d2_predictor = D2FrcnnDetector(config_path,weights_path,categories,device="cpu")

         detection_results = d2_predictor.predict(bgr_image_np_array)
+        ```
     """

     def __init__(
@@ -170,7 +234,7 @@ class D2FrcnnDetector(ObjectDetector):
         path_weights: str,
         categories: Mapping[str, TypeOrStr],
         config_overwrite: Optional[List[str]] = None,
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """
@@ -191,47 +255,38 @@ class D2FrcnnDetector(ObjectDetector):
         :param filter_categories: The model might return objects that are not supposed to be predicted and that should
                                   be filtered. Pass a list of category names that must not be returned
         """
+        super().__init__(categories, filter_categories)

-        self.name = "_".join(Path(path_weights).parts[-3:])
-        self._categories_d2 = self._map_to_d2_categories(copy(categories))
         self.path_weights = path_weights
-        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
-        config_overwrite = config_overwrite if config_overwrite else []
-        for conf in config_overwrite:
-            key, val = conf.split("=", maxsplit=1)
-            d2_conf_list.extend([key, val])
-
         self.path_yaml = path_yaml
-
+
+        config_overwrite = config_overwrite if config_overwrite else []
         self.config_overwrite = config_overwrite
-
-
-
-
-
-
-        self.
-
-        self.d2_predictor =
-        self.
-        self.
+        self.device = get_torch_device(device)
+
+        d2_conf_list = self._get_d2_config_list(path_weights, config_overwrite)
+        self.cfg = self._set_config(path_yaml, d2_conf_list, self.device)
+
+        self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
+        self.model_id = self.get_model_id()
+
+        self.d2_predictor = self._set_model(self.cfg)
+        self._instantiate_d2_predictor(self.d2_predictor, path_weights)
+        self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)

     @staticmethod
-    def _set_config(
-        path_yaml: str, d2_conf_list: List[str], device: Optional[Literal["cpu", "cuda"]] = None
-    ) -> "CfgNode":
+    def _set_config(path_yaml: str, d2_conf_list: List[str], device: torch.device) -> CfgNode:
         cfg = get_cfg()
         # additional attribute with default value, so that the true value can be loaded from the configs
         cfg.NMS_THRESH_CLASS_AGNOSTIC = 0.1
         cfg.merge_from_file(path_yaml)
         cfg.merge_from_list(d2_conf_list)
-
-        cfg.MODEL.DEVICE = "cpu"
+        cfg.MODEL.DEVICE = str(device)
         cfg.freeze()
         return cfg

     @staticmethod
-    def _set_model(config: "CfgNode") -> "GeneralizedRCNN":
+    def _set_model(config: CfgNode) -> GeneralizedRCNN:
         """
         Build the D2 model. It uses the available builtin tools of D2
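Note: `config_overwrite` entries are `KEY=VALUE` strings; the refactored `_get_d2_config_list` (defined further down in this diff) splits each on the first `=` and flattens everything into the alternating key/value list that `cfg.merge_from_list` expects. A short sketch of the transformation (the path and keys are illustrative):

```python
path_weights = "item/d2_model-800000-layout.pkl"
config_overwrite = ["OUTPUT.FRCNN_NMS_THRESH=0.3", "OUTPUT.RESULT_SCORE_THRESH=0.6"]

d2_conf_list = ["MODEL.WEIGHTS", path_weights]
for conf in config_overwrite:
    key, val = conf.split("=", maxsplit=1)  # split only on the first '='
    d2_conf_list.extend([key, val])

# d2_conf_list is now the flat sequence cfg.merge_from_list expects:
# ["MODEL.WEIGHTS", "item/d2_model-800000-layout.pkl",
#  "OUTPUT.FRCNN_NMS_THRESH", "0.3", "OUTPUT.RESULT_SCORE_THRESH", "0.6"]
```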
@@ -240,9 +295,10 @@ class D2FrcnnDetector(ObjectDetector):
         """
         return build_model(config.clone()).eval()

-
-
-        checkpointer
+    @staticmethod
+    def _instantiate_d2_predictor(wrapped_model: GeneralizedRCNN, path_weights: str) -> None:
+        checkpointer = DetectionCheckpointer(wrapped_model)
+        checkpointer.load(path_weights)

     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         """
@@ -259,33 +315,10 @@ class D2FrcnnDetector(ObjectDetector):
         )
         return self._map_category_names(detection_results)

-    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
-        """
-        Populating category names to detection results
-
-        :param detection_results: list of detection results. Will also filter categories
-        :return: List of detection results with attribute class_name populated
-        """
-        filtered_detection_result: List[DetectionResult] = []
-        for result in detection_results:
-            result.class_name = self._categories_d2[str(result.class_id)]
-            if isinstance(result.class_id, int):
-                result.class_id += 1
-            if self.filter_categories:
-                if result.class_name not in self.filter_categories:
-                    filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
-        return filtered_detection_result
-
     @classmethod
     def get_requirements(cls) -> List[Requirement]:
         return [get_pytorch_requirement(), get_detectron2_requirement()]

-    @classmethod
-    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
-        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
-
     def clone(self) -> PredictorBase:
         return self.__class__(
             self.path_yaml,
@@ -296,11 +329,53 @@ class D2FrcnnDetector(ObjectDetector):
             self.filter_categories,
         )

-
-
+    @staticmethod
+    def get_wrapped_model(
+        path_yaml: str,
+        path_weights: str,
+        config_overwrite: List[str],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+    ) -> GeneralizedRCNN:
+        """
+        Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
+
+        Example:
+        ```python
+
+            path_yaml = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
+            weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
+            model = D2FrcnnDetector.get_wrapped_model(path_yaml,weights_path,["OUTPUT.FRCNN_NMS_THRESH=0.3",
+                                                                              "OUTPUT.RESULT_SCORE_THRESH=0.6"],
+                                                      "cpu")
+            detect_result_list = d2_predict_image(np_img,model,InferenceResize(800,1333),0.3)
+        ```
+        :param path_yaml: The path to the yaml config. If the model is built using several config files, always use
+                          the highest level .yaml file.
+        :param path_weights: The path to the model checkpoint.
+        :param config_overwrite: Overwrite some hyperparameters defined by the yaml file with some new values. E.g.
+                                 ["OUTPUT.FRCNN_NMS_THRESH=0.3","OUTPUT.RESULT_SCORE_THRESH=0.6"].
+        :param device: "cpu" or "cuda". If not specified will auto select depending on what is available
+        :return: Detectron2 GeneralizedRCNN model
+        """
+
+        device = get_torch_device(device)
+        d2_conf_list = D2FrcnnDetector._get_d2_config_list(path_weights, config_overwrite)
+        cfg = D2FrcnnDetector._set_config(path_yaml, d2_conf_list, device)
+        model = D2FrcnnDetector._set_model(cfg)
+        D2FrcnnDetector._instantiate_d2_predictor(model, path_weights)
+        return model
+
+    @staticmethod
+    def _get_d2_config_list(path_weights: str, config_overwrite: List[str]) -> List[str]:
+        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
+        config_overwrite = config_overwrite if config_overwrite else []
+        for conf in config_overwrite:
+            key, val = conf.split("=", maxsplit=1)
+            d2_conf_list.extend([key, val])
+        return d2_conf_list


-class D2FrcnnTracingDetector(ObjectDetector):
+class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
     """
     D2 Faster-RCNN exported torchscript model. Using this predictor has the advantage that Detectron2 does not have to
     be installed. The associated config setting only contains parameters that are involved in pre-and post-processing.
@@ -312,6 +387,8 @@ class D2FrcnnTracingDetector(ObjectDetector):
     the standard D2 output that takes into account of the situation that detected objects are disjoint. For more infos
     on this topic, see <https://github.com/facebookresearch/detectron2/issues/978> .

+    Example:
+        ```python
         config_path = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
         weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
         categories = ModelCatalog.get_profile("item/d2_model-800000-layout.pkl").categories
@@ -319,6 +396,7 @@ class D2FrcnnTracingDetector(ObjectDetector):
         d2_predictor = D2FrcnnDetector(config_path,weights_path,categories)

         detection_results = d2_predictor.predict(bgr_image_np_array)
+        ```
     """

     def __init__(
@@ -343,27 +421,28 @@ class D2FrcnnTracingDetector(ObjectDetector):
         :param filter_categories: The model might return objects that are not supposed to be predicted and that should
                                   be filtered. Pass a list of category names that must not be returned
         """
-
-
+
+        super().__init__(categories, filter_categories)
+
         self.path_weights = path_weights
         self.path_yaml = path_yaml
-
-        self.config_overwrite = config_overwrite
-
-
-        self.
-        self.
+
+        self.config_overwrite = copy(config_overwrite)
+        self.cfg = self._set_config(self.path_yaml, self.path_weights, self.config_overwrite)
+
+        self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
+        self.model_id = self.get_model_id()
+
+        self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
+        self.d2_predictor = self.get_wrapped_model(self.path_weights)
+
+    @staticmethod
+    def _set_config(path_yaml: str, path_weights: str, config_overwrite: Optional[List[str]]) -> AttrDict:
+        cfg = set_config_by_yaml(path_yaml)
         config_overwrite = config_overwrite if config_overwrite else []
         config_overwrite.extend([f"MODEL.WEIGHTS={path_weights}"])
-
-
-        self.d2_predictor = self._instantiate_d2_predictor()
-
-    def _instantiate_d2_predictor(self) -> Any:
-        with open(self.path_weights, "rb") as file:
-            buffer = io.BytesIO(file.read())
-        # Load all tensors to the original device
-        return torch.jit.load(buffer)
+        cfg.update_args(config_overwrite)
+        return cfg

     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         """
@@ -418,3 +497,16 @@ class D2FrcnnTracingDetector(ObjectDetector):

     def possible_categories(self) -> List[ObjectTypes]:
         return list(self.categories.values())
+
+    @staticmethod
+    def get_wrapped_model(path_weights: str) -> Any:
+        """
+        Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
+
+        :param path_weights:
+        :return:
+        """
+        with open(path_weights, "rb") as file:
+            buffer = io.BytesIO(file.read())
+        # Load all tensors to the original device
+        return torch.jit.load(buffer)
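Note: both detectors now expose a static `get_wrapped_model`, so the raw model can be obtained without constructing the wrapper. A hedged sketch of the torchscript path, combining it with `d2_jit_predict_image` from earlier in this file (the weights path and image are illustrative stand-ins):

```python
import numpy as np

# Load an exported torchscript checkpoint directly (illustrative path):
ts_model = D2FrcnnTracingDetector.get_wrapped_model("item/d2_model.ts")

np_img = np.zeros((1000, 800, 3), dtype=np.uint8)  # stand-in for a BGR page image
detect_results = d2_jit_predict_image(np_img, ts_model, InferenceResize(800, 1333), 0.3)
```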
deepdoctection/extern/deskew.py
CHANGED
@@ -21,13 +21,16 @@ jdeskew estimator and rotator to deskew images: <https://github.com/phamquiluan/jdeskew>

 from typing import List

+from lazy_imports import try_import
+
 from ..utils.detection_types import ImageType, Requirement
-from ..utils.file_utils import get_jdeskew_requirement
-from .base import ImageTransformer
+from ..utils.file_utils import get_jdeskew_requirement
+from ..utils.settings import PageType
+from ..utils.viz import viz_handler
+from .base import DetectionResult, ImageTransformer

-
+with try_import() as import_guard:
     from jdeskew.estimator import get_angle
-    from jdeskew.utility import rotate


 class Jdeskewer(ImageTransformer):
@@ -37,19 +40,43 @@ class Jdeskewer(ImageTransformer):
     """

     def __init__(self, min_angle_rotation: float = 2.0):
-        self.name = "
+        self.name = "jdeskewer"
+        self.model_id = self.get_model_id()
         self.min_angle_rotation = min_angle_rotation

-    def transform(self, np_img: ImageType) -> ImageType:
-
+    def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
+        """
+        Rotation of the image according to the angle determined by the jdeskew estimator.
+
+        **Example**:
+            jdeskew_predictor = Jdeskewer()
+            detection_result = jdeskew_predictor.predict(np_image)
+            jdeskew_predictor.transform(np_image, DetectionResult(angle=5.0))

-
-
+        :param np_img: image as numpy array
+        :param specification: DetectionResult with angle value
+        :return: image rotated by the angle
+        """
+        if abs(specification.angle) > self.min_angle_rotation:  # type: ignore
+            return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore
         return np_img

+    def predict(self, np_img: ImageType) -> DetectionResult:
+        """
+        Predict the angle of the image to deskew it.
+
+        :param np_img: image as numpy array
+        :return: DetectionResult with angle value
+        """
+        return DetectionResult(angle=round(float(get_angle(np_img)), 4))
+
     @classmethod
     def get_requirements(cls) -> List[Requirement]:
         """
         Get a list of requirements for running the detector
         """
         return [get_jdeskew_requirement()]
+
+    @staticmethod
+    def possible_category() -> PageType:
+        return PageType.angle
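Note: `Jdeskewer` now follows the two-step `ImageTransformer` protocol: `predict` estimates the skew angle and returns it in a `DetectionResult`, while `transform` rotates only when the angle's magnitude exceeds `min_angle_rotation`. A short usage sketch built from the docstring example above (the input image is a stand-in):

```python
import numpy as np

np_image = np.zeros((1000, 800, 3), dtype=np.uint8)  # stand-in page image

deskewer = Jdeskewer(min_angle_rotation=2.0)
spec = deskewer.predict(np_image)              # e.g. DetectionResult(angle=5.0)
deskewed = deskewer.transform(np_image, spec)  # rotates only if |angle| > 2.0
```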
|