deepdoctection 0.30__py3-none-any.whl → 0.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (74)
  1. deepdoctection/__init__.py +4 -2
  2. deepdoctection/analyzer/dd.py +6 -5
  3. deepdoctection/dataflow/base.py +0 -19
  4. deepdoctection/dataflow/custom.py +4 -3
  5. deepdoctection/dataflow/custom_serialize.py +14 -5
  6. deepdoctection/dataflow/parallel_map.py +12 -11
  7. deepdoctection/dataflow/serialize.py +5 -4
  8. deepdoctection/datapoint/annotation.py +33 -12
  9. deepdoctection/datapoint/box.py +1 -4
  10. deepdoctection/datapoint/convert.py +3 -1
  11. deepdoctection/datapoint/image.py +66 -29
  12. deepdoctection/datapoint/view.py +57 -25
  13. deepdoctection/datasets/adapter.py +1 -1
  14. deepdoctection/datasets/base.py +83 -10
  15. deepdoctection/datasets/dataflow_builder.py +1 -1
  16. deepdoctection/datasets/info.py +2 -2
  17. deepdoctection/datasets/instances/layouttest.py +2 -7
  18. deepdoctection/eval/accmetric.py +1 -1
  19. deepdoctection/eval/base.py +5 -4
  20. deepdoctection/eval/eval.py +2 -2
  21. deepdoctection/eval/tp_eval_callback.py +5 -4
  22. deepdoctection/extern/base.py +39 -13
  23. deepdoctection/extern/d2detect.py +164 -64
  24. deepdoctection/extern/deskew.py +32 -7
  25. deepdoctection/extern/doctrocr.py +227 -39
  26. deepdoctection/extern/fastlang.py +45 -7
  27. deepdoctection/extern/hfdetr.py +90 -33
  28. deepdoctection/extern/hflayoutlm.py +109 -22
  29. deepdoctection/extern/pdftext.py +2 -1
  30. deepdoctection/extern/pt/ptutils.py +3 -2
  31. deepdoctection/extern/tessocr.py +134 -22
  32. deepdoctection/extern/texocr.py +2 -0
  33. deepdoctection/extern/tp/tpcompat.py +4 -4
  34. deepdoctection/extern/tp/tpfrcnn/preproc.py +2 -7
  35. deepdoctection/extern/tpdetect.py +50 -23
  36. deepdoctection/mapper/d2struct.py +1 -1
  37. deepdoctection/mapper/hfstruct.py +1 -1
  38. deepdoctection/mapper/laylmstruct.py +1 -1
  39. deepdoctection/mapper/maputils.py +13 -2
  40. deepdoctection/mapper/prodigystruct.py +1 -1
  41. deepdoctection/mapper/pubstruct.py +10 -10
  42. deepdoctection/mapper/tpstruct.py +1 -1
  43. deepdoctection/pipe/anngen.py +35 -8
  44. deepdoctection/pipe/base.py +53 -19
  45. deepdoctection/pipe/cell.py +29 -8
  46. deepdoctection/pipe/common.py +12 -4
  47. deepdoctection/pipe/doctectionpipe.py +2 -2
  48. deepdoctection/pipe/language.py +3 -2
  49. deepdoctection/pipe/layout.py +3 -2
  50. deepdoctection/pipe/lm.py +2 -2
  51. deepdoctection/pipe/refine.py +18 -10
  52. deepdoctection/pipe/segment.py +21 -16
  53. deepdoctection/pipe/text.py +14 -8
  54. deepdoctection/pipe/transform.py +16 -9
  55. deepdoctection/train/d2_frcnn_train.py +15 -12
  56. deepdoctection/train/hf_detr_train.py +8 -6
  57. deepdoctection/train/hf_layoutlm_train.py +16 -11
  58. deepdoctection/utils/__init__.py +3 -0
  59. deepdoctection/utils/concurrency.py +1 -1
  60. deepdoctection/utils/context.py +2 -2
  61. deepdoctection/utils/env_info.py +55 -22
  62. deepdoctection/utils/error.py +84 -0
  63. deepdoctection/utils/file_utils.py +4 -15
  64. deepdoctection/utils/fs.py +7 -7
  65. deepdoctection/utils/pdf_utils.py +5 -4
  66. deepdoctection/utils/settings.py +5 -1
  67. deepdoctection/utils/transform.py +1 -1
  68. deepdoctection/utils/utils.py +0 -6
  69. deepdoctection/utils/viz.py +44 -2
  70. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/METADATA +33 -58
  71. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/RECORD +74 -73
  72. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/WHEEL +1 -1
  73. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/LICENSE +0 -0
  74. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/top_level.txt +0 -0
deepdoctection/eval/tp_eval_callback.py

@@ -83,10 +83,11 @@ class EvalCallback(Callback):  # pylint: disable=R0903
         self.num_gpu = get_num_gpu()
         self.category_names = category_names
         self.sub_categories = sub_categories
-        assert isinstance(pipeline_component.predictor, TPFrcnnDetector), (
-            f"pipeline_component.predictor must be of "
-            f"type TPFrcnnDetector but is type {type(pipeline_component.predictor)}"
-        )
+        if not isinstance(pipeline_component.predictor, TPFrcnnDetector):
+            raise TypeError(
+                f"pipeline_component.predictor must be of type TPFrcnnDetector but is "
+                f"type {type(pipeline_component.predictor)}"
+            )
         self.cfg = pipeline_component.predictor.model.cfg
         if _use_replicated(self.cfg):
             self.evaluator = Evaluator(dataset, pipeline_component, metric, num_threads=self.num_gpu * 2)
deepdoctection/extern/base.py

@@ -25,6 +25,7 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

 from ..utils.detection_types import ImageType, JsonDict, Requirement
+from ..utils.identifier import get_uuid_from_str
 from ..utils.settings import DefaultType, ObjectTypes, TypeOrStr, get_type


@@ -34,6 +35,7 @@ class PredictorBase(ABC):
     """

     name: str
+    model_id: str

     def __new__(cls, *args, **kwargs):  # type: ignore # pylint: disable=W0613
         requirements = cls.get_requirements()
@@ -53,14 +55,22 @@ class PredictorBase(ABC):
         """
         Get a list of requirements for running the detector
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     @abstractmethod
     def clone(self) -> "PredictorBase":
         """
         Clone an instance
         """
-        raise NotImplementedError
+        raise NotImplementedError()
+
+    def get_model_id(self) -> str:
+        """
+        Get the generating model
+        """
+        if self.name is not None:
+            return get_uuid_from_str(self.name)[:8]
+        raise ValueError("name must be set before calling get_model_id")


 @dataclass
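Every predictor now carries a `model_id`, derived deterministically from its `name` via `get_uuid_from_str` and truncated to eight characters. A minimal sketch of the idea, assuming a UUID3-style helper (the package's actual `get_uuid_from_str` in `deepdoctection/utils/identifier.py` may hash differently):

```python
import uuid


def get_uuid_from_str(input_str: str) -> str:
    # Stand-in for deepdoctection.utils.identifier.get_uuid_from_str:
    # any deterministic string-to-UUID mapping yields a stable id.
    return str(uuid.uuid3(uuid.NAMESPACE_DNS, input_str)).replace("-", "")


# The same predictor name always maps to the same 8-char model_id:
name = "detectron2_GeneralizedRCNN_item_d2_model-800000-layout.pkl"
print(get_uuid_from_str(name)[:8])
```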
@@ -102,6 +112,7 @@ class DetectionResult:
     line: Optional[str] = None
     uuid: Optional[str] = None
     relationships: Optional[Dict[str, Any]] = None
+    angle: Optional[float] = None


 class ObjectDetector(PredictorBase):
@@ -133,7 +144,7 @@ class ObjectDetector(PredictorBase):
         """
         Abstract method predict
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     @property
     def accepts_batch(self) -> bool:
@@ -174,14 +185,14 @@ class PdfMiner(PredictorBase):
         """
         Abstract method predict
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     @abstractmethod
     def get_width_height(self, pdf_bytes: bytes) -> Tuple[float, float]:
         """
         Abstract method get_width_height
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     def clone(self) -> PredictorBase:
         return self.__class__()
@@ -212,7 +223,7 @@ class TextRecognizer(PredictorBase):
         """
         Abstract method predict
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     @property
     def accepts_batch(self) -> bool:
@@ -294,7 +305,7 @@ class LMTokenClassifier(PredictorBase):
         """
         Abstract method predict
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     def possible_tokens(self) -> List[ObjectTypes]:
         """
@@ -307,7 +318,7 @@ class LMTokenClassifier(PredictorBase):
         """
         Clone an instance
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     @staticmethod
     def default_kwargs_for_input_mapping() -> JsonDict:
@@ -341,7 +352,7 @@ class LMSequenceClassifier(PredictorBase):
         """
         Abstract method predict
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     def possible_categories(self) -> List[ObjectTypes]:
         """
@@ -354,7 +365,7 @@ class LMSequenceClassifier(PredictorBase):
         """
         Clone an instance
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     @staticmethod
     def default_kwargs_for_input_mapping() -> JsonDict:
@@ -388,7 +399,7 @@ class LanguageDetector(PredictorBase):
         """
         Abstract method predict
         """
-        raise NotImplementedError
+        raise NotImplementedError()

     def possible_languages(self) -> List[ObjectTypes]:
         """
@@ -403,11 +414,26 @@ class ImageTransformer(PredictorBase):
     """

     @abstractmethod
-    def transform(self, np_img: ImageType) -> ImageType:
+    def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
         """
         Abstract method transform
         """
-        raise NotImplementedError
+        raise NotImplementedError()
+
+    @abstractmethod
+    def predict(self, np_img: ImageType) -> DetectionResult:
+        """
+        Abstract method predict
+        """
+        raise NotImplementedError()

     def clone(self) -> PredictorBase:
         return self.__class__()
+
+    @staticmethod
+    @abstractmethod
+    def possible_category() -> ObjectTypes:
+        """
+        Returns a (single) category the `ImageTransformer` can predict
+        """
+        raise NotImplementedError()
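With the widened `ImageTransformer` contract, a transformer must implement `predict` (estimate the transformation parameters as a `DetectionResult`), `transform` (apply them) and `possible_category` (the single page-level category it produces). A minimal sketch of a conforming subclass; `UpsideDownFlipper` and its flip logic are illustrative, not part of the package, and the module paths are those used inside the package:

```python
from typing import List

import numpy as np

from deepdoctection.extern.base import DetectionResult, ImageTransformer
from deepdoctection.utils.detection_types import ImageType, Requirement
from deepdoctection.utils.settings import ObjectTypes, PageType


class UpsideDownFlipper(ImageTransformer):
    """Hypothetical transformer that turns pages rotated by 180 degrees upright."""

    def __init__(self) -> None:
        self.name = "upside_down_flipper"
        self.model_id = self.get_model_id()

    def predict(self, np_img: ImageType) -> DetectionResult:
        # Estimation step only: a real implementation would detect orientation here.
        return DetectionResult(angle=180.0)

    def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
        # Application step: rotate by whatever angle predict() estimated.
        if specification.angle == 180.0:
            return np.rot90(np_img, k=2).copy()
        return np_img

    @classmethod
    def get_requirements(cls) -> List[Requirement]:
        return []  # numpy only

    @staticmethod
    def possible_category() -> ObjectTypes:
        return PageType.angle
```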
deepdoctection/extern/d2detect.py

@@ -19,6 +19,7 @@
 D2 GeneralizedRCNN model as predictor for deepdoctection pipeline
 """
 import io
+from abc import ABC
 from copy import copy
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence
@@ -32,7 +33,7 @@ from ..utils.file_utils import (
     get_pytorch_requirement,
     pytorch_available,
 )
-from ..utils.metacfg import set_config_by_yaml
+from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.settings import ObjectTypes, TypeOrStr, get_type
 from ..utils.transform import InferenceResize, ResizeTransform
 from .base import DetectionResult, ObjectDetector, PredictorBase
@@ -144,7 +145,72 @@ def d2_jit_predict_image(
     return detect_result_list


-class D2FrcnnDetector(ObjectDetector):
+class D2FrcnnDetectorMixin(ObjectDetector, ABC):
+    """
+    Base class for D2 Faster-RCNN implementation. This class only implements the basic wrapper functions
+    """
+
+    def __init__(
+        self,
+        categories: Mapping[str, TypeOrStr],
+        filter_categories: Optional[Sequence[TypeOrStr]] = None,
+    ):
+        """
+        :param categories: A dict with key (indices) and values (category names). Index 0 must be reserved for a
+                           dummy 'BG' category. Note, that this convention is different from the builtin D2 framework,
+                           where models in the model zoo are trained with 'BG' class having the highest index.
+        :param filter_categories: The model might return objects that are not supposed to be predicted and that should
+                                  be filtered. Pass a list of category names that must not be returned
+        """
+
+        if filter_categories:
+            filter_categories = [get_type(cat) for cat in filter_categories]
+        self.filter_categories = filter_categories
+        self._categories_d2 = self._map_to_d2_categories(copy(categories))
+        self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
+
+    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
+        """
+        Populating category names to detection results
+
+        :param detection_results: list of detection results. Will also filter categories
+        :return: List of detection results with attribute class_name populated
+        """
+        filtered_detection_result: List[DetectionResult] = []
+        for result in detection_results:
+            result.class_name = self._categories_d2[str(result.class_id)]
+            if isinstance(result.class_id, int):
+                result.class_id += 1
+            if self.filter_categories:
+                if result.class_name not in self.filter_categories:
+                    filtered_detection_result.append(result)
+            else:
+                filtered_detection_result.append(result)
+        return filtered_detection_result
+
+    @classmethod
+    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
+        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
+
+    def possible_categories(self) -> List[ObjectTypes]:
+        return list(self.categories.values())
+
+    @staticmethod
+    def get_inference_resizer(min_size_test: int, max_size_test: int) -> InferenceResize:
+        """Returns the resizer for the inference
+
+        :param min_size_test: minimum size of the resized image
+        :param max_size_test: maximum size of the resized image
+        """
+        return InferenceResize(min_size_test, max_size_test)
+
+    @staticmethod
+    def get_name(path_weights: str, architecture: str) -> str:
+        """Returns the name of the model"""
+        return f"detectron2_{architecture}" + "_".join(Path(path_weights).parts[-2:])
+
+
+class D2FrcnnDetector(D2FrcnnDetectorMixin):
     """
     D2 Faster-RCNN implementation with all the available backbones, normalizations throughout the model
     as well as FPN, optional Cascade-RCNN and many more.
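The new mixin keeps two category mappings: `categories` with deepdoctection's 1-based string indices, and `_categories_d2` shifted down by one because D2 models use 0-based class ids (index 0 of `categories` is the dummy 'BG' slot). A small illustration of the shift and of the inverse step in `_map_category_names`, with plain strings standing in for the `ObjectTypes` values:

```python
# 1-based deepdoctection convention ("0" would be the dummy 'BG' class)
categories = {"1": "text", "2": "table"}

# _map_to_d2_categories: shift keys down by one for 0-based D2 class ids
categories_d2 = {str(int(k) - 1): v for k, v in categories.items()}
assert categories_d2 == {"0": "text", "1": "table"}

# _map_category_names reverses this per prediction: a raw D2 class_id of 0
# resolves to "text" and is then shifted back to the 1-based index 1.
raw_class_id = 0
class_name = categories_d2[str(raw_class_id)]
assert (class_name, raw_class_id + 1) == ("text", 1)
```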
@@ -155,6 +221,7 @@ class D2FrcnnDetector(ObjectDetector):
     the standard D2 output that takes into account of the situation that detected objects are disjoint. For more infos
     on this topic, see <https://github.com/facebookresearch/detectron2/issues/978> .

+    ```python
     config_path = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
     weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
     categories = ModelCatalog.get_profile("item/d2_model-800000-layout.pkl").categories
@@ -162,6 +229,7 @@ class D2FrcnnDetector(ObjectDetector):
     d2_predictor = D2FrcnnDetector(config_path,weights_path,categories,device="cpu")

     detection_results = d2_predictor.predict(bgr_image_np_array)
+    ```
     """

     def __init__(
@@ -191,30 +259,27 @@ class D2FrcnnDetector(ObjectDetector):
         :param filter_categories: The model might return objects that are not supposed to be predicted and that should
                                   be filtered. Pass a list of category names that must not be returned
         """
+        super().__init__(categories, filter_categories)

-        self.name = "_".join(Path(path_weights).parts[-3:])
-        self._categories_d2 = self._map_to_d2_categories(copy(categories))
         self.path_weights = path_weights
-        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
-        config_overwrite = config_overwrite if config_overwrite else []
-        for conf in config_overwrite:
-            key, val = conf.split("=", maxsplit=1)
-            d2_conf_list.extend([key, val])
-
         self.path_yaml = path_yaml
-        self.categories = copy(categories)  # type: ignore
+
+        config_overwrite = config_overwrite if config_overwrite else []
         self.config_overwrite = config_overwrite
         if device is not None:
             self.device = device
         else:
             self.device = set_torch_auto_device()
-        if filter_categories:
-            filter_categories = [get_type(cat) for cat in filter_categories]
-        self.filter_categories = filter_categories
+
+        d2_conf_list = self._get_d2_config_list(path_weights, config_overwrite)
         self.cfg = self._set_config(path_yaml, d2_conf_list, device)
-        self.d2_predictor = D2FrcnnDetector.set_model(self.cfg)
-        self.resizer = InferenceResize(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
-        self._instantiate_d2_predictor()
+
+        self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
+        self.model_id = self.get_model_id()
+
+        self.d2_predictor = self._set_model(self.cfg)
+        self._instantiate_d2_predictor(self.d2_predictor, path_weights)
+        self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)

     @staticmethod
     def _set_config(
@@ -231,7 +296,7 @@ class D2FrcnnDetector(ObjectDetector):
         return cfg

     @staticmethod
-    def set_model(config: "CfgNode") -> "GeneralizedRCNN":
+    def _set_model(config: "CfgNode") -> "GeneralizedRCNN":
         """
         Build the D2 model. It uses the available builtin tools of D2

@@ -240,9 +305,10 @@ class D2FrcnnDetector(ObjectDetector):
         """
         return build_model(config.clone()).eval()

-    def _instantiate_d2_predictor(self) -> None:
-        checkpointer = DetectionCheckpointer(self.d2_predictor)
-        checkpointer.load(self.cfg.MODEL.WEIGHTS)
+    @staticmethod
+    def _instantiate_d2_predictor(wrapped_model: "GeneralizedRCNN", path_weights: str) -> None:
+        checkpointer = DetectionCheckpointer(wrapped_model)
+        checkpointer.load(path_weights)

     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         """
@@ -259,33 +325,10 @@ class D2FrcnnDetector(ObjectDetector):
         )
         return self._map_category_names(detection_results)

-    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
-        """
-        Populating category names to detection results
-
-        :param detection_results: list of detection results. Will also filter categories
-        :return: List of detection results with attribute class_name populated
-        """
-        filtered_detection_result: List[DetectionResult] = []
-        for result in detection_results:
-            result.class_name = self._categories_d2[str(result.class_id)]
-            if isinstance(result.class_id, int):
-                result.class_id += 1
-            if self.filter_categories:
-                if result.class_name not in self.filter_categories:
-                    filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
-        return filtered_detection_result
-
     @classmethod
     def get_requirements(cls) -> List[Requirement]:
         return [get_pytorch_requirement(), get_detectron2_requirement()]

-    @classmethod
-    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
-        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
-
     def clone(self) -> PredictorBase:
         return self.__class__(
             self.path_yaml,
@@ -296,11 +339,51 @@ class D2FrcnnDetector(ObjectDetector):
             self.filter_categories,
         )

-    def possible_categories(self) -> List[ObjectTypes]:
-        return list(self.categories.values())
+    @staticmethod
+    def get_wrapped_model(
+        path_yaml: str, path_weights: str, config_overwrite: List[str], device: Literal["cpu", "cuda"]
+    ) -> "GeneralizedRCNN":
+        """
+        Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
+
+        Example:
+            ```python
+            path_yaml = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
+            weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
+            model = D2FrcnnDetector.get_wrapped_model(path_yaml,weights_path,["OUTPUT.FRCNN_NMS_THRESH=0.3",
+                                                                              "OUTPUT.RESULT_SCORE_THRESH=0.6"],
+                                                      "cpu")
+            detect_result_list = d2_predict_image(np_img,model,InferenceResize(800,1333),0.3)
+            ```
+
+        :param path_yaml: The path to the yaml config. If the model is built using several config files, always use
+                          the highest level .yaml file.
+        :param path_weights: The path to the model checkpoint.
+        :param config_overwrite: Overwrite some hyperparameters defined by the yaml file with some new values. E.g.
+                                 ["OUTPUT.FRCNN_NMS_THRESH=0.3","OUTPUT.RESULT_SCORE_THRESH=0.6"].
+        :param device: "cpu" or "cuda". If not specified will auto select depending on what is available
+        :return: Detectron2 GeneralizedRCNN model
+        """
+
+        if device is None:
+            device = set_torch_auto_device()
+        d2_conf_list = D2FrcnnDetector._get_d2_config_list(path_weights, config_overwrite)
+        cfg = D2FrcnnDetector._set_config(path_yaml, d2_conf_list, device)
+        model = D2FrcnnDetector._set_model(cfg)
+        D2FrcnnDetector._instantiate_d2_predictor(model, path_weights)
+        return model
+
+    @staticmethod
+    def _get_d2_config_list(path_weights: str, config_overwrite: List[str]) -> List[str]:
+        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
+        config_overwrite = config_overwrite if config_overwrite else []
+        for conf in config_overwrite:
+            key, val = conf.split("=", maxsplit=1)
+            d2_conf_list.extend([key, val])
+        return d2_conf_list


-class D2FrcnnTracingDetector(ObjectDetector):
+class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
     """
     D2 Faster-RCNN exported torchscript model. Using this predictor has the advantage that Detectron2 does not have to
     be installed. The associated config setting only contains parameters that are involved in pre-and post-processing.
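`_get_d2_config_list` flattens the `KEY=VALUE` overwrite strings into the alternating key/value list that Detectron2's config merging consumes, always prepending the checkpoint path. A quick check of the output shape, mirroring the static method above:

```python
def get_d2_config_list(path_weights, config_overwrite):
    # Mirrors D2FrcnnDetector._get_d2_config_list
    d2_conf_list = ["MODEL.WEIGHTS", path_weights]
    for conf in config_overwrite or []:
        key, val = conf.split("=", maxsplit=1)  # split only on the first '='
        d2_conf_list.extend([key, val])
    return d2_conf_list


assert get_d2_config_list("weights/model.pkl", ["OUTPUT.RESULT_SCORE_THRESH=0.6"]) == [
    "MODEL.WEIGHTS", "weights/model.pkl", "OUTPUT.RESULT_SCORE_THRESH", "0.6",
]
```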
@@ -312,6 +395,8 @@ class D2FrcnnTracingDetector(ObjectDetector):
     the standard D2 output that takes into account of the situation that detected objects are disjoint. For more infos
     on this topic, see <https://github.com/facebookresearch/detectron2/issues/978> .

+    Example:
+        ```python
     config_path = ModelCatalog.get_full_path_configs("dd/d2/item/CASCADE_RCNN_R_50_FPN_GN.yaml")
     weights_path = ModelDownloadManager.maybe_download_weights_and_configs("item/d2_model-800000-layout.pkl")
     categories = ModelCatalog.get_profile("item/d2_model-800000-layout.pkl").categories
@@ -319,6 +404,7 @@ class D2FrcnnTracingDetector(ObjectDetector):
     d2_predictor = D2FrcnnDetector(config_path,weights_path,categories)

     detection_results = d2_predictor.predict(bgr_image_np_array)
+        ```
     """

     def __init__(
@@ -343,27 +429,28 @@ class D2FrcnnTracingDetector(ObjectDetector):
         :param filter_categories: The model might return objects that are not supposed to be predicted and that should
                                   be filtered. Pass a list of category names that must not be returned
         """
-        self.name = "_".join(Path(path_weights).parts[-2:])
-        self._categories_d2 = self._map_to_d2_categories(copy(categories))
+
+        super().__init__(categories, filter_categories)
+
         self.path_weights = path_weights
         self.path_yaml = path_yaml
-        self.categories = copy(categories)  # type: ignore
-        self.config_overwrite = config_overwrite
-        if filter_categories:
-            filter_categories = [get_type(cat) for cat in filter_categories]
-        self.filter_categories = filter_categories
-        self.cfg = set_config_by_yaml(self.path_yaml)
+
+        self.config_overwrite = copy(config_overwrite)
+        self.cfg = self._set_config(self.path_yaml, self.path_weights, self.config_overwrite)
+
+        self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
+        self.model_id = self.get_model_id()
+
+        self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
+        self.d2_predictor = self.get_wrapped_model(self.path_weights)
+
+    @staticmethod
+    def _set_config(path_yaml: str, path_weights: str, config_overwrite: Optional[List[str]]) -> AttrDict:
+        cfg = set_config_by_yaml(path_yaml)
         config_overwrite = config_overwrite if config_overwrite else []
         config_overwrite.extend([f"MODEL.WEIGHTS={path_weights}"])
-        self.cfg.update_args(config_overwrite)
-        self.resizer = InferenceResize(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
-        self.d2_predictor = self._instantiate_d2_predictor()
-
-    def _instantiate_d2_predictor(self) -> Any:
-        with open(self.path_weights, "rb") as file:
-            buffer = io.BytesIO(file.read())
-        # Load all tensors to the original device
-        return torch.jit.load(buffer)
+        cfg.update_args(config_overwrite)
+        return cfg

     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         """
@@ -418,3 +505,16 @@

     def possible_categories(self) -> List[ObjectTypes]:
         return list(self.categories.values())
+
+    @staticmethod
+    def get_wrapped_model(path_weights: str) -> Any:
+        """
+        Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
+
+        :param path_weights:
+        :return:
+        """
+        with open(path_weights, "rb") as file:
+            buffer = io.BytesIO(file.read())
+        # Load all tensors to the original device
+        return torch.jit.load(buffer)
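Because this variant's `get_wrapped_model` returns a plain TorchScript module, it can be loaded without Detectron2 installed. A short usage sketch; the weights path is illustrative, and the input format the scripted module expects is defined by the export, not shown here:

```python
import torch  # only torch is required, not detectron2

# Hypothetical path to an exported torchscript checkpoint
path_weights = "weights/d2_model.ts"

scripted_model = D2FrcnnTracingDetector.get_wrapped_model(path_weights)
assert isinstance(scripted_model, torch.jit.ScriptModule)
```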
deepdoctection/extern/deskew.py

@@ -23,11 +23,12 @@ from typing import List

 from ..utils.detection_types import ImageType, Requirement
 from ..utils.file_utils import get_jdeskew_requirement, jdeskew_available
-from .base import ImageTransformer
+from ..utils.settings import PageType
+from ..utils.viz import viz_handler
+from .base import DetectionResult, ImageTransformer

 if jdeskew_available():
     from jdeskew.estimator import get_angle
-    from jdeskew.utility import rotate


 class Jdeskewer(ImageTransformer):
  class Jdeskewer(ImageTransformer):
@@ -37,19 +38,43 @@ class Jdeskewer(ImageTransformer):
37
38
  """
38
39
 
39
40
  def __init__(self, min_angle_rotation: float = 2.0):
40
- self.name = "jdeskew_transform"
41
+ self.name = "jdeskewer"
42
+ self.model_id = self.get_model_id()
41
43
  self.min_angle_rotation = min_angle_rotation
42
44
 
43
- def transform(self, np_img: ImageType) -> ImageType:
44
- angle = get_angle(np_img)
45
+ def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
46
+ """
47
+ Rotation of the image according to the angle determined by the jdeskew estimator.
48
+
49
+ **Example**:
50
+ jdeskew_predictor = Jdeskewer()
51
+ detection_result = jdeskew_predictor.predict(np_image)
52
+ jdeskew_predictor.transform(np_image, DetectionResult(angle=5.0))
45
53
 
46
- if angle > self.min_angle_rotation:
47
- return rotate(np_img, angle)
54
+ :param np_img: image as numpy array
55
+ :param specification: DetectionResult with angle value
56
+ :return: image rotated by the angle
57
+ """
58
+ if abs(specification.angle) > self.min_angle_rotation: # type: ignore
59
+ return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
48
60
  return np_img
49
61
 
62
+ def predict(self, np_img: ImageType) -> DetectionResult:
63
+ """
64
+ Predict the angle of the image to deskew it.
65
+
66
+ :param np_img: image as numpy array
67
+ :return: DetectionResult with angle value
68
+ """
69
+ return DetectionResult(angle=round(float(get_angle(np_img)), 4))
70
+
50
71
  @classmethod
51
72
  def get_requirements(cls) -> List[Requirement]:
52
73
  """
53
74
  Get a list of requirements for running the detector
54
75
  """
55
76
  return [get_jdeskew_requirement()]
77
+
78
+ @staticmethod
79
+ def possible_category() -> PageType:
80
+ return PageType.angle
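With the estimation/application split, deskewing is now a two-step call: `predict` returns a `DetectionResult` carrying only the estimated angle, and `transform` rotates through the viz backend. A short usage sketch, assuming `np_image` is a numpy page image:

```python
jdeskewer = Jdeskewer(min_angle_rotation=2.0)

detection_result = jdeskewer.predict(np_image)  # e.g. DetectionResult(angle=2.37)
deskewed = jdeskewer.transform(np_image, detection_result)
# transform() is a no-op for |angle| <= min_angle_rotation, so calling it
# unconditionally is safe for already-straight pages.
```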