deepdoctection 0.30-py3-none-any.whl → 0.32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (120)
  1. deepdoctection/__init__.py +38 -29
  2. deepdoctection/analyzer/dd.py +36 -29
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/base.py +0 -19
  5. deepdoctection/dataflow/custom.py +4 -3
  6. deepdoctection/dataflow/custom_serialize.py +14 -5
  7. deepdoctection/dataflow/parallel_map.py +12 -11
  8. deepdoctection/dataflow/serialize.py +5 -4
  9. deepdoctection/datapoint/annotation.py +35 -13
  10. deepdoctection/datapoint/box.py +3 -5
  11. deepdoctection/datapoint/convert.py +3 -1
  12. deepdoctection/datapoint/image.py +79 -36
  13. deepdoctection/datapoint/view.py +152 -49
  14. deepdoctection/datasets/__init__.py +1 -4
  15. deepdoctection/datasets/adapter.py +6 -3
  16. deepdoctection/datasets/base.py +86 -11
  17. deepdoctection/datasets/dataflow_builder.py +1 -1
  18. deepdoctection/datasets/info.py +4 -4
  19. deepdoctection/datasets/instances/doclaynet.py +3 -2
  20. deepdoctection/datasets/instances/fintabnet.py +2 -1
  21. deepdoctection/datasets/instances/funsd.py +2 -1
  22. deepdoctection/datasets/instances/iiitar13k.py +5 -2
  23. deepdoctection/datasets/instances/layouttest.py +4 -8
  24. deepdoctection/datasets/instances/publaynet.py +2 -2
  25. deepdoctection/datasets/instances/pubtables1m.py +6 -3
  26. deepdoctection/datasets/instances/pubtabnet.py +2 -1
  27. deepdoctection/datasets/instances/rvlcdip.py +2 -1
  28. deepdoctection/datasets/instances/xfund.py +2 -1
  29. deepdoctection/eval/__init__.py +1 -4
  30. deepdoctection/eval/accmetric.py +1 -1
  31. deepdoctection/eval/base.py +5 -4
  32. deepdoctection/eval/cocometric.py +2 -1
  33. deepdoctection/eval/eval.py +19 -15
  34. deepdoctection/eval/tedsmetric.py +14 -11
  35. deepdoctection/eval/tp_eval_callback.py +14 -7
  36. deepdoctection/extern/__init__.py +2 -7
  37. deepdoctection/extern/base.py +39 -13
  38. deepdoctection/extern/d2detect.py +182 -90
  39. deepdoctection/extern/deskew.py +36 -9
  40. deepdoctection/extern/doctrocr.py +265 -83
  41. deepdoctection/extern/fastlang.py +49 -9
  42. deepdoctection/extern/hfdetr.py +106 -55
  43. deepdoctection/extern/hflayoutlm.py +441 -122
  44. deepdoctection/extern/hflm.py +225 -0
  45. deepdoctection/extern/model.py +56 -47
  46. deepdoctection/extern/pdftext.py +10 -5
  47. deepdoctection/extern/pt/__init__.py +1 -3
  48. deepdoctection/extern/pt/nms.py +6 -2
  49. deepdoctection/extern/pt/ptutils.py +27 -18
  50. deepdoctection/extern/tessocr.py +134 -22
  51. deepdoctection/extern/texocr.py +6 -2
  52. deepdoctection/extern/tp/tfutils.py +43 -9
  53. deepdoctection/extern/tp/tpcompat.py +14 -11
  54. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  55. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  56. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
  58. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
  60. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  61. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
  62. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  67. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  68. deepdoctection/extern/tp/tpfrcnn/preproc.py +8 -9
  69. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  70. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  71. deepdoctection/extern/tpdetect.py +54 -30
  72. deepdoctection/mapper/__init__.py +3 -8
  73. deepdoctection/mapper/d2struct.py +9 -7
  74. deepdoctection/mapper/hfstruct.py +7 -2
  75. deepdoctection/mapper/laylmstruct.py +164 -21
  76. deepdoctection/mapper/maputils.py +16 -3
  77. deepdoctection/mapper/misc.py +6 -3
  78. deepdoctection/mapper/prodigystruct.py +1 -1
  79. deepdoctection/mapper/pubstruct.py +10 -10
  80. deepdoctection/mapper/tpstruct.py +3 -3
  81. deepdoctection/pipe/__init__.py +1 -1
  82. deepdoctection/pipe/anngen.py +35 -8
  83. deepdoctection/pipe/base.py +53 -19
  84. deepdoctection/pipe/common.py +23 -13
  85. deepdoctection/pipe/concurrency.py +2 -1
  86. deepdoctection/pipe/doctectionpipe.py +2 -2
  87. deepdoctection/pipe/language.py +3 -2
  88. deepdoctection/pipe/layout.py +6 -3
  89. deepdoctection/pipe/lm.py +34 -66
  90. deepdoctection/pipe/order.py +142 -35
  91. deepdoctection/pipe/refine.py +26 -24
  92. deepdoctection/pipe/segment.py +21 -16
  93. deepdoctection/pipe/{cell.py → sub_layout.py} +30 -9
  94. deepdoctection/pipe/text.py +14 -8
  95. deepdoctection/pipe/transform.py +16 -9
  96. deepdoctection/train/__init__.py +6 -12
  97. deepdoctection/train/d2_frcnn_train.py +36 -28
  98. deepdoctection/train/hf_detr_train.py +26 -17
  99. deepdoctection/train/hf_layoutlm_train.py +133 -111
  100. deepdoctection/train/tp_frcnn_train.py +21 -19
  101. deepdoctection/utils/__init__.py +3 -0
  102. deepdoctection/utils/concurrency.py +1 -1
  103. deepdoctection/utils/context.py +2 -2
  104. deepdoctection/utils/env_info.py +41 -84
  105. deepdoctection/utils/error.py +84 -0
  106. deepdoctection/utils/file_utils.py +4 -15
  107. deepdoctection/utils/fs.py +7 -7
  108. deepdoctection/utils/logger.py +1 -0
  109. deepdoctection/utils/mocks.py +93 -0
  110. deepdoctection/utils/pdf_utils.py +5 -4
  111. deepdoctection/utils/settings.py +6 -1
  112. deepdoctection/utils/transform.py +1 -1
  113. deepdoctection/utils/utils.py +0 -6
  114. deepdoctection/utils/viz.py +48 -5
  115. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/METADATA +57 -73
  116. deepdoctection-0.32.dist-info/RECORD +146 -0
  117. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/WHEEL +1 -1
  118. deepdoctection-0.30.dist-info/RECORD +0 -143
  119. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
  120. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
deepdoctection/extern/d2detect.py

@@ -18,42 +18,39 @@
 """
 D2 GeneralizedRCNN model as predictor for deepdoctection pipeline
 """
+from __future__ import annotations
+
 import io
+from abc import ABC
 from copy import copy
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence
+from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union

 import numpy as np
+from lazy_imports import try_import

 from ..utils.detection_types import ImageType, Requirement
-from ..utils.file_utils import (
-    detectron2_available,
-    get_detectron2_requirement,
-    get_pytorch_requirement,
-    pytorch_available,
-)
-from ..utils.metacfg import set_config_by_yaml
+from ..utils.file_utils import get_detectron2_requirement, get_pytorch_requirement
+from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.settings import ObjectTypes, TypeOrStr, get_type
 from ..utils.transform import InferenceResize, ResizeTransform
 from .base import DetectionResult, ObjectDetector, PredictorBase
 from .pt.nms import batched_nms
-from .pt.ptutils import set_torch_auto_device
+from .pt.ptutils import get_torch_device

-if pytorch_available():
+with try_import() as pt_import_guard:
     import torch
     import torch.cuda
     from torch import nn  # pylint: disable=W0611

-if detectron2_available():
+with try_import() as d2_import_guard:
     from detectron2.checkpoint import DetectionCheckpointer
     from detectron2.config import CfgNode, get_cfg  # pylint: disable=W0611
     from detectron2.modeling import GeneralizedRCNN, build_model  # pylint: disable=W0611
     from detectron2.structures import Instances  # pylint: disable=W0611


-def _d2_post_processing(
-    predictions: Dict[str, "Instances"], nms_thresh_class_agnostic: float
-) -> Dict[str, "Instances"]:
+def _d2_post_processing(predictions: Dict[str, Instances], nms_thresh_class_agnostic: float) -> Dict[str, Instances]:
     """
     D2 postprocessing steps, so that detection outputs are aligned with outputs of other packages (e.g. Tensorpack).
     Apply a class agnostic NMS.
@@ -71,7 +68,7 @@ def _d2_post_processing(

 def d2_predict_image(
     np_img: ImageType,
-    predictor: "nn.Module",
+    predictor: nn.Module,
     resizer: InferenceResize,
     nms_thresh_class_agnostic: float,
 ) -> List[DetectionResult]:
@@ -106,7 +103,7 @@ def d2_predict_image(


 def d2_jit_predict_image(
-    np_img: ImageType, d2_predictor: "nn.Module", resizer: InferenceResize, nms_thresh_class_agnostic: float
+    np_img: ImageType, d2_predictor: nn.Module, resizer: InferenceResize, nms_thresh_class_agnostic: float
 ) -> List[DetectionResult]:
     """
     Run detection on an image using torchscript. It will also handle the preprocessing internally which
@@ -144,7 +141,72 @@ def d2_jit_predict_image(
     return detect_result_list


-class D2FrcnnDetector(ObjectDetector):
+class D2FrcnnDetectorMixin(ObjectDetector, ABC):
+    """
+    Base class for D2 Faster-RCNN implementation. This class only implements the basic wrapper functions
+    """
+
+    def __init__(
+        self,
+        categories: Mapping[str, TypeOrStr],
+        filter_categories: Optional[Sequence[TypeOrStr]] = None,
+    ):
+        """
+        :param categories: A dict with key (indices) and values (category names). Index 0 must be reserved for a
+                           dummy 'BG' category. Note, that this convention is different from the builtin D2 framework,
+                           where models in the model zoo are trained with 'BG' class having the highest index.
+        :param filter_categories: The model might return objects that are not supposed to be predicted and that should
+                                  be filtered. Pass a list of category names that must not be returned
+        """
+
+        if filter_categories:
+            filter_categories = [get_type(cat) for cat in filter_categories]
+        self.filter_categories = filter_categories
+        self._categories_d2 = self._map_to_d2_categories(copy(categories))
+        self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
+
+    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
+        """
+        Populating category names to detection results
+
+        :param detection_results: list of detection results. Will also filter categories
+        :return: List of detection results with attribute class_name populated
+        """
+        filtered_detection_result: List[DetectionResult] = []
+        for result in detection_results:
+            result.class_name = self._categories_d2[str(result.class_id)]
+            if isinstance(result.class_id, int):
+                result.class_id += 1
+            if self.filter_categories:
+                if result.class_name not in self.filter_categories:
+                    filtered_detection_result.append(result)
+            else:
+                filtered_detection_result.append(result)
+        return filtered_detection_result
+
+    @classmethod
+    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
+        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
+
+    def possible_categories(self) -> List[ObjectTypes]:
+        return list(self.categories.values())
+
+    @staticmethod
+    def get_inference_resizer(min_size_test: int, max_size_test: int) -> InferenceResize:
+        """Returns the resizer for the inference
+
+        :param min_size_test: minimum size of the resized image
+        :param max_size_test: maximum size of the resized image
+        """
+        return InferenceResize(min_size_test, max_size_test)
+
+    @staticmethod
+    def get_name(path_weights: str, architecture: str) -> str:
+        """Returns the name of the model"""
+        return f"detectron2_{architecture}" + "_".join(Path(path_weights).parts[-2:])
+
+
+class D2FrcnnDetector(D2FrcnnDetectorMixin):
     """
     D2 Faster-RCNN implementation with all the available backbones, normalizations throughout the model
     as well as FPN, optional Cascade-RCNN and many more.
@@ -155,6 +217,7 @@ class D2FrcnnDetector(ObjectDetector):
     the standard D2 output that takes into account of the situation that detected objects are disjoint. For more infos
     on this topic, see <https://github.com/facebookresearch/detectron2/issues/978> .

+    ```python
     config_path = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
     weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
     categories = ModelCatalog.get_profile("item/d2_model-800000-layout.pkl").categories
@@ -162,6 +225,7 @@ class D2FrcnnDetector(ObjectDetector):
     d2_predictor = D2FrcnnDetector(config_path,weights_path,categories,device="cpu")

     detection_results = d2_predictor.predict(bgr_image_np_array)
+    ```
     """

     def __init__(
@@ -170,7 +234,7 @@ class D2FrcnnDetector(ObjectDetector):
         path_weights: str,
         categories: Mapping[str, TypeOrStr],
         config_overwrite: Optional[List[str]] = None,
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """
@@ -191,47 +255,38 @@ class D2FrcnnDetector(ObjectDetector):
         :param filter_categories: The model might return objects that are not supposed to be predicted and that should
                                   be filtered. Pass a list of category names that must not be returned
         """
+        super().__init__(categories, filter_categories)

-        self.name = "_".join(Path(path_weights).parts[-3:])
-        self._categories_d2 = self._map_to_d2_categories(copy(categories))
         self.path_weights = path_weights
-        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
-        config_overwrite = config_overwrite if config_overwrite else []
-        for conf in config_overwrite:
-            key, val = conf.split("=", maxsplit=1)
-            d2_conf_list.extend([key, val])
-
         self.path_yaml = path_yaml
-        self.categories = copy(categories)  # type: ignore
+
+        config_overwrite = config_overwrite if config_overwrite else []
         self.config_overwrite = config_overwrite
-        if device is not None:
-            self.device = device
-        else:
-            self.device = set_torch_auto_device()
-        if filter_categories:
-            filter_categories = [get_type(cat) for cat in filter_categories]
-        self.filter_categories = filter_categories
-        self.cfg = self._set_config(path_yaml, d2_conf_list, device)
-        self.d2_predictor = D2FrcnnDetector.set_model(self.cfg)
-        self.resizer = InferenceResize(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
-        self._instantiate_d2_predictor()
+        self.device = get_torch_device(device)
+
+        d2_conf_list = self._get_d2_config_list(path_weights, config_overwrite)
+        self.cfg = self._set_config(path_yaml, d2_conf_list, self.device)
+
+        self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
+        self.model_id = self.get_model_id()
+
+        self.d2_predictor = self._set_model(self.cfg)
+        self._instantiate_d2_predictor(self.d2_predictor, path_weights)
+        self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)

     @staticmethod
-    def _set_config(
-        path_yaml: str, d2_conf_list: List[str], device: Optional[Literal["cpu", "cuda"]] = None
-    ) -> "CfgNode":
+    def _set_config(path_yaml: str, d2_conf_list: List[str], device: torch.device) -> CfgNode:
         cfg = get_cfg()
         # additional attribute with default value, so that the true value can be loaded from the configs
         cfg.NMS_THRESH_CLASS_AGNOSTIC = 0.1
         cfg.merge_from_file(path_yaml)
         cfg.merge_from_list(d2_conf_list)
-        if not torch.cuda.is_available() or device == "cpu":
-            cfg.MODEL.DEVICE = "cpu"
+        cfg.MODEL.DEVICE = str(device)
         cfg.freeze()
         return cfg

     @staticmethod
-    def set_model(config: "CfgNode") -> "GeneralizedRCNN":
+    def _set_model(config: CfgNode) -> GeneralizedRCNN:
         """
         Build the D2 model. It uses the available builtin tools of D2

@@ -240,9 +295,10 @@ class D2FrcnnDetector(ObjectDetector):
         """
         return build_model(config.clone()).eval()

-    def _instantiate_d2_predictor(self) -> None:
-        checkpointer = DetectionCheckpointer(self.d2_predictor)
-        checkpointer.load(self.cfg.MODEL.WEIGHTS)
+    @staticmethod
+    def _instantiate_d2_predictor(wrapped_model: GeneralizedRCNN, path_weights: str) -> None:
+        checkpointer = DetectionCheckpointer(wrapped_model)
+        checkpointer.load(path_weights)

     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         """
@@ -259,33 +315,10 @@ class D2FrcnnDetector(ObjectDetector):
         )
         return self._map_category_names(detection_results)

-    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
-        """
-        Populating category names to detection results
-
-        :param detection_results: list of detection results. Will also filter categories
-        :return: List of detection results with attribute class_name populated
-        """
-        filtered_detection_result: List[DetectionResult] = []
-        for result in detection_results:
-            result.class_name = self._categories_d2[str(result.class_id)]
-            if isinstance(result.class_id, int):
-                result.class_id += 1
-            if self.filter_categories:
-                if result.class_name not in self.filter_categories:
-                    filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
-        return filtered_detection_result
-
     @classmethod
     def get_requirements(cls) -> List[Requirement]:
         return [get_pytorch_requirement(), get_detectron2_requirement()]

-    @classmethod
-    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
-        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
-
     def clone(self) -> PredictorBase:
         return self.__class__(
             self.path_yaml,
@@ -296,11 +329,53 @@ class D2FrcnnDetector(ObjectDetector):
             self.filter_categories,
         )

-    def possible_categories(self) -> List[ObjectTypes]:
-        return list(self.categories.values())
+    @staticmethod
+    def get_wrapped_model(
+        path_yaml: str,
+        path_weights: str,
+        config_overwrite: List[str],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+    ) -> GeneralizedRCNN:
+        """
+        Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
+
+        Example:
+            ```python
+
+            path_yaml = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
+            weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
+            model = D2FrcnnDetector.get_wrapped_model(path_yaml,weights_path,["OUTPUT.FRCNN_NMS_THRESH=0.3",
+                                                                              "OUTPUT.RESULT_SCORE_THRESH=0.6"],
+                                                      "cpu")
+            detect_result_list = d2_predict_image(np_img,model,InferenceResize(800,1333),0.3)
+            ```
+        :param path_yaml: The path to the yaml config. If the model is built using several config files, always use
+                          the highest level .yaml file.
+        :param path_weights: The path to the model checkpoint.
+        :param config_overwrite: Overwrite some hyperparameters defined by the yaml file with some new values. E.g.
+                                 ["OUTPUT.FRCNN_NMS_THRESH=0.3","OUTPUT.RESULT_SCORE_THRESH=0.6"].
+        :param device: "cpu" or "cuda". If not specified will auto select depending on what is available
+        :return: Detectron2 GeneralizedRCNN model
+        """
+
+        device = get_torch_device(device)
+        d2_conf_list = D2FrcnnDetector._get_d2_config_list(path_weights, config_overwrite)
+        cfg = D2FrcnnDetector._set_config(path_yaml, d2_conf_list, device)
+        model = D2FrcnnDetector._set_model(cfg)
+        D2FrcnnDetector._instantiate_d2_predictor(model, path_weights)
+        return model
+
+    @staticmethod
+    def _get_d2_config_list(path_weights: str, config_overwrite: List[str]) -> List[str]:
+        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
+        config_overwrite = config_overwrite if config_overwrite else []
+        for conf in config_overwrite:
+            key, val = conf.split("=", maxsplit=1)
+            d2_conf_list.extend([key, val])
+        return d2_conf_list


-class D2FrcnnTracingDetector(ObjectDetector):
+class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
     """
     D2 Faster-RCNN exported torchscript model. Using this predictor has the advantage that Detectron2 does not have to
     be installed. The associated config setting only contains parameters that are involved in pre-and post-processing.
@@ -312,6 +387,8 @@ class D2FrcnnTracingDetector(ObjectDetector):
     the standard D2 output that takes into account of the situation that detected objects are disjoint. For more infos
     on this topic, see <https://github.com/facebookresearch/detectron2/issues/978> .

+    Example:
+        ```python
     config_path = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
     weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
     categories = ModelCatalog.get_profile("item/d2_model-800000-layout.pkl").categories
@@ -319,6 +396,7 @@ class D2FrcnnTracingDetector(ObjectDetector):
     d2_predictor = D2FrcnnDetector(config_path,weights_path,categories)

     detection_results = d2_predictor.predict(bgr_image_np_array)
+        ```
     """

     def __init__(
@@ -343,27 +421,28 @@ class D2FrcnnTracingDetector(ObjectDetector):
         :param filter_categories: The model might return objects that are not supposed to be predicted and that should
                                   be filtered. Pass a list of category names that must not be returned
         """
-        self.name = "_".join(Path(path_weights).parts[-2:])
-        self._categories_d2 = self._map_to_d2_categories(copy(categories))
+
+        super().__init__(categories, filter_categories)
+
         self.path_weights = path_weights
         self.path_yaml = path_yaml
-        self.categories = copy(categories)  # type: ignore
-        self.config_overwrite = config_overwrite
-        if filter_categories:
-            filter_categories = [get_type(cat) for cat in filter_categories]
-        self.filter_categories = filter_categories
-        self.cfg = set_config_by_yaml(self.path_yaml)
+
+        self.config_overwrite = copy(config_overwrite)
+        self.cfg = self._set_config(self.path_yaml, self.path_weights, self.config_overwrite)
+
+        self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
+        self.model_id = self.get_model_id()
+
+        self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
+        self.d2_predictor = self.get_wrapped_model(self.path_weights)
+
+    @staticmethod
+    def _set_config(path_yaml: str, path_weights: str, config_overwrite: Optional[List[str]]) -> AttrDict:
+        cfg = set_config_by_yaml(path_yaml)
         config_overwrite = config_overwrite if config_overwrite else []
         config_overwrite.extend([f"MODEL.WEIGHTS={path_weights}"])
-        self.cfg.update_args(config_overwrite)
-        self.resizer = InferenceResize(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
-        self.d2_predictor = self._instantiate_d2_predictor()
-
-    def _instantiate_d2_predictor(self) -> Any:
-        with open(self.path_weights, "rb") as file:
-            buffer = io.BytesIO(file.read())
-        # Load all tensors to the original device
-        return torch.jit.load(buffer)
+        cfg.update_args(config_overwrite)
+        return cfg

     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         """
@@ -418,3 +497,16 @@ class D2FrcnnTracingDetector(ObjectDetector):

     def possible_categories(self) -> List[ObjectTypes]:
         return list(self.categories.values())
+
+    @staticmethod
+    def get_wrapped_model(path_weights: str) -> Any:
+        """
+        Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
+
+        :param path_weights:
+        :return:
+        """
+        with open(path_weights, "rb") as file:
+            buffer = io.BytesIO(file.read())
+        # Load all tensors to the original device
+        return torch.jit.load(buffer)
deepdoctection/extern/deskew.py

@@ -21,13 +21,16 @@ jdeskew estimator and rotator to deskew images: <https://github.com/phamquiluan/

 from typing import List

+from lazy_imports import try_import
+
 from ..utils.detection_types import ImageType, Requirement
-from ..utils.file_utils import get_jdeskew_requirement, jdeskew_available
-from .base import ImageTransformer
+from ..utils.file_utils import get_jdeskew_requirement
+from ..utils.settings import PageType
+from ..utils.viz import viz_handler
+from .base import DetectionResult, ImageTransformer

-if jdeskew_available():
+with try_import() as import_guard:
     from jdeskew.estimator import get_angle
-    from jdeskew.utility import rotate


 class Jdeskewer(ImageTransformer):
@@ -37,19 +40,43 @@ class Jdeskewer(ImageTransformer):
     """

     def __init__(self, min_angle_rotation: float = 2.0):
-        self.name = "jdeskew_transform"
+        self.name = "jdeskewer"
+        self.model_id = self.get_model_id()
         self.min_angle_rotation = min_angle_rotation

-    def transform(self, np_img: ImageType) -> ImageType:
-        angle = get_angle(np_img)
+    def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
+        """
+        Rotation of the image according to the angle determined by the jdeskew estimator.
+
+        **Example**:
+            jdeskew_predictor = Jdeskewer()
+            detection_result = jdeskew_predictor.predict(np_image)
+            jdeskew_predictor.transform(np_image, DetectionResult(angle=5.0))

-        if angle > self.min_angle_rotation:
-            return rotate(np_img, angle)
+        :param np_img: image as numpy array
+        :param specification: DetectionResult with angle value
+        :return: image rotated by the angle
+        """
+        if abs(specification.angle) > self.min_angle_rotation:  # type: ignore
+            return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore
         return np_img

+    def predict(self, np_img: ImageType) -> DetectionResult:
+        """
+        Predict the angle of the image to deskew it.
+
+        :param np_img: image as numpy array
+        :return: DetectionResult with angle value
+        """
+        return DetectionResult(angle=round(float(get_angle(np_img)), 4))
+
     @classmethod
     def get_requirements(cls) -> List[Requirement]:
         """
         Get a list of requirements for running the detector
         """
         return [get_jdeskew_requirement()]
+
+    @staticmethod
+    def possible_category() -> PageType:
+        return PageType.angle