deepdoctection 0.31__py3-none-any.whl → 0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (131) hide show
  1. deepdoctection/__init__.py +16 -29
  2. deepdoctection/analyzer/dd.py +70 -59
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/common.py +9 -5
  5. deepdoctection/dataflow/custom.py +5 -5
  6. deepdoctection/dataflow/custom_serialize.py +75 -18
  7. deepdoctection/dataflow/parallel_map.py +3 -3
  8. deepdoctection/dataflow/serialize.py +4 -4
  9. deepdoctection/dataflow/stats.py +3 -3
  10. deepdoctection/datapoint/annotation.py +41 -56
  11. deepdoctection/datapoint/box.py +9 -8
  12. deepdoctection/datapoint/convert.py +6 -6
  13. deepdoctection/datapoint/image.py +56 -44
  14. deepdoctection/datapoint/view.py +245 -150
  15. deepdoctection/datasets/__init__.py +1 -4
  16. deepdoctection/datasets/adapter.py +35 -26
  17. deepdoctection/datasets/base.py +14 -12
  18. deepdoctection/datasets/dataflow_builder.py +3 -3
  19. deepdoctection/datasets/info.py +24 -26
  20. deepdoctection/datasets/instances/doclaynet.py +51 -51
  21. deepdoctection/datasets/instances/fintabnet.py +46 -46
  22. deepdoctection/datasets/instances/funsd.py +25 -24
  23. deepdoctection/datasets/instances/iiitar13k.py +13 -10
  24. deepdoctection/datasets/instances/layouttest.py +4 -3
  25. deepdoctection/datasets/instances/publaynet.py +5 -5
  26. deepdoctection/datasets/instances/pubtables1m.py +24 -21
  27. deepdoctection/datasets/instances/pubtabnet.py +32 -30
  28. deepdoctection/datasets/instances/rvlcdip.py +30 -30
  29. deepdoctection/datasets/instances/xfund.py +26 -26
  30. deepdoctection/datasets/save.py +6 -6
  31. deepdoctection/eval/__init__.py +1 -4
  32. deepdoctection/eval/accmetric.py +32 -33
  33. deepdoctection/eval/base.py +8 -9
  34. deepdoctection/eval/cocometric.py +15 -13
  35. deepdoctection/eval/eval.py +41 -37
  36. deepdoctection/eval/tedsmetric.py +30 -23
  37. deepdoctection/eval/tp_eval_callback.py +16 -19
  38. deepdoctection/extern/__init__.py +2 -7
  39. deepdoctection/extern/base.py +339 -134
  40. deepdoctection/extern/d2detect.py +85 -113
  41. deepdoctection/extern/deskew.py +14 -11
  42. deepdoctection/extern/doctrocr.py +141 -130
  43. deepdoctection/extern/fastlang.py +27 -18
  44. deepdoctection/extern/hfdetr.py +71 -62
  45. deepdoctection/extern/hflayoutlm.py +504 -211
  46. deepdoctection/extern/hflm.py +230 -0
  47. deepdoctection/extern/model.py +488 -302
  48. deepdoctection/extern/pdftext.py +23 -19
  49. deepdoctection/extern/pt/__init__.py +1 -3
  50. deepdoctection/extern/pt/nms.py +6 -2
  51. deepdoctection/extern/pt/ptutils.py +29 -19
  52. deepdoctection/extern/tessocr.py +39 -38
  53. deepdoctection/extern/texocr.py +18 -18
  54. deepdoctection/extern/tp/tfutils.py +57 -9
  55. deepdoctection/extern/tp/tpcompat.py +21 -14
  56. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
  60. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  61. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
  62. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
  67. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
  68. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  69. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  70. deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
  71. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  72. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  73. deepdoctection/extern/tpdetect.py +45 -53
  74. deepdoctection/mapper/__init__.py +3 -8
  75. deepdoctection/mapper/cats.py +27 -29
  76. deepdoctection/mapper/cocostruct.py +10 -10
  77. deepdoctection/mapper/d2struct.py +27 -26
  78. deepdoctection/mapper/hfstruct.py +13 -8
  79. deepdoctection/mapper/laylmstruct.py +178 -37
  80. deepdoctection/mapper/maputils.py +12 -11
  81. deepdoctection/mapper/match.py +2 -2
  82. deepdoctection/mapper/misc.py +11 -9
  83. deepdoctection/mapper/pascalstruct.py +4 -4
  84. deepdoctection/mapper/prodigystruct.py +5 -5
  85. deepdoctection/mapper/pubstruct.py +84 -92
  86. deepdoctection/mapper/tpstruct.py +5 -5
  87. deepdoctection/mapper/xfundstruct.py +33 -33
  88. deepdoctection/pipe/__init__.py +1 -1
  89. deepdoctection/pipe/anngen.py +12 -14
  90. deepdoctection/pipe/base.py +52 -106
  91. deepdoctection/pipe/common.py +72 -59
  92. deepdoctection/pipe/concurrency.py +16 -11
  93. deepdoctection/pipe/doctectionpipe.py +24 -21
  94. deepdoctection/pipe/language.py +20 -25
  95. deepdoctection/pipe/layout.py +20 -16
  96. deepdoctection/pipe/lm.py +75 -105
  97. deepdoctection/pipe/order.py +194 -89
  98. deepdoctection/pipe/refine.py +111 -124
  99. deepdoctection/pipe/segment.py +156 -161
  100. deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
  101. deepdoctection/pipe/text.py +37 -36
  102. deepdoctection/pipe/transform.py +19 -16
  103. deepdoctection/train/__init__.py +6 -12
  104. deepdoctection/train/d2_frcnn_train.py +48 -41
  105. deepdoctection/train/hf_detr_train.py +41 -30
  106. deepdoctection/train/hf_layoutlm_train.py +153 -135
  107. deepdoctection/train/tp_frcnn_train.py +32 -31
  108. deepdoctection/utils/concurrency.py +1 -1
  109. deepdoctection/utils/context.py +13 -6
  110. deepdoctection/utils/develop.py +4 -4
  111. deepdoctection/utils/env_info.py +87 -125
  112. deepdoctection/utils/file_utils.py +6 -11
  113. deepdoctection/utils/fs.py +22 -18
  114. deepdoctection/utils/identifier.py +2 -2
  115. deepdoctection/utils/logger.py +16 -15
  116. deepdoctection/utils/metacfg.py +7 -7
  117. deepdoctection/utils/mocks.py +93 -0
  118. deepdoctection/utils/pdf_utils.py +11 -11
  119. deepdoctection/utils/settings.py +185 -181
  120. deepdoctection/utils/tqdm.py +1 -1
  121. deepdoctection/utils/transform.py +14 -9
  122. deepdoctection/utils/types.py +104 -0
  123. deepdoctection/utils/utils.py +7 -7
  124. deepdoctection/utils/viz.py +74 -72
  125. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
  126. deepdoctection-0.33.dist-info/RECORD +146 -0
  127. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
  128. deepdoctection/utils/detection_types.py +0 -68
  129. deepdoctection-0.31.dist-info/RECORD +0 -144
  130. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
  131. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
@@ -19,24 +19,26 @@
19
19
  PDFPlumber text extraction engine
20
20
  """
21
21
 
22
- from typing import Dict, List, Tuple
22
+ from typing import Optional
23
+
24
+ from lazy_imports import try_import
23
25
 
24
26
  from ..utils.context import save_tmp_file
25
- from ..utils.detection_types import Requirement
26
- from ..utils.file_utils import get_pdfplumber_requirement, pdfplumber_available
27
+ from ..utils.file_utils import get_pdfplumber_requirement
27
28
  from ..utils.settings import LayoutType, ObjectTypes
28
- from .base import DetectionResult, PdfMiner
29
+ from ..utils.types import Requirement
30
+ from .base import DetectionResult, ModelCategories, PdfMiner
29
31
 
30
- if pdfplumber_available():
31
- from pdfplumber.pdf import PDF
32
+ with try_import() as import_guard:
33
+ from pdfplumber.pdf import PDF, Page
32
34
 
33
35
 
34
- def _to_detect_result(word: Dict[str, str]) -> DetectionResult:
36
+ def _to_detect_result(word: dict[str, str]) -> DetectionResult:
35
37
  return DetectionResult(
36
38
  box=[float(word["x0"]), float(word["top"]), float(word["x1"]), float(word["bottom"])],
37
39
  class_id=1,
38
40
  text=word["text"],
39
- class_name=LayoutType.word,
41
+ class_name=LayoutType.WORD,
40
42
  )
41
43
 
42
44
 
@@ -64,12 +66,15 @@ class PdfPlumberTextDetector(PdfMiner):
64
66
 
65
67
  """
66
68
 
67
- def __init__(self) -> None:
69
+ def __init__(self, x_tolerance: int = 3, y_tolerance: int = 3) -> None:
68
70
  self.name = "Pdfplumber"
69
71
  self.model_id = self.get_model_id()
70
- self.categories = {"1": LayoutType.word}
72
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD})
73
+ self.x_tolerance = x_tolerance
74
+ self.y_tolerance = y_tolerance
75
+ self._page: Optional[Page] = None
71
76
 
72
- def predict(self, pdf_bytes: bytes) -> List[DetectionResult]:
77
+ def predict(self, pdf_bytes: bytes) -> list[DetectionResult]:
73
78
  """
74
79
  Calls pdfminer.six and returns the detected text as detection results
75
80
 
@@ -79,25 +84,24 @@ class PdfPlumberTextDetector(PdfMiner):
79
84
 
80
85
  with save_tmp_file(pdf_bytes, "pdf_") as (tmp_name, _):
81
86
  with open(tmp_name, "rb") as fin:
82
- _pdf = PDF(fin)
83
- self._page = _pdf.pages[0]
87
+ self._page = PDF(fin).pages[0]
84
88
  self._pdf_bytes = pdf_bytes
85
- words = self._page.extract_words()
89
+ words = self._page.extract_words(x_tolerance=self.x_tolerance, y_tolerance=self.y_tolerance)
86
90
  detect_results = list(map(_to_detect_result, words))
87
91
  return detect_results
88
92
 
89
93
  @classmethod
90
- def get_requirements(cls) -> List[Requirement]:
94
+ def get_requirements(cls) -> list[Requirement]:
91
95
  return [get_pdfplumber_requirement()]
92
96
 
93
- def get_width_height(self, pdf_bytes: bytes) -> Tuple[float, float]:
97
+ def get_width_height(self, pdf_bytes: bytes) -> tuple[float, float]:
94
98
  """
95
99
  Get the width and height of the full page
96
100
  :param pdf_bytes: pdf_bytes generating the pdf
97
101
  :return: width and height
98
102
  """
99
103
 
100
- if self._pdf_bytes == pdf_bytes:
104
+ if self._pdf_bytes == pdf_bytes and self._page is not None:
101
105
  return self._page.bbox[2], self._page.bbox[3]
102
106
  # if the pdf bytes is not equal to the cached pdf, will recalculate values
103
107
  with save_tmp_file(pdf_bytes, "pdf_") as (tmp_name, _):
@@ -107,5 +111,5 @@ class PdfPlumberTextDetector(PdfMiner):
107
111
  self._pdf_bytes = pdf_bytes
108
112
  return self._page.bbox[2], self._page.bbox[3]
109
113
 
110
- def possible_categories(self) -> List[ObjectTypes]:
111
- return [LayoutType.word]
114
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
115
+ return self.categories.get_categories(as_dict=False)
@@ -19,7 +19,5 @@
19
19
  Init file for pytorch compatibility package
20
20
  """
21
21
 
22
+ from .nms import *
22
23
  from .ptutils import *
23
-
24
- if pytorch_available():
25
- from .nms import *
@@ -18,9 +18,13 @@
18
18
  """
19
19
  Module for custom NMS functions.
20
20
  """
21
+ from __future__ import annotations
21
22
 
22
- import torch
23
- from torchvision.ops import boxes as box_ops # type: ignore
23
+ from lazy_imports import try_import
24
+
25
+ with try_import() as import_guard:
26
+ import torch
27
+ from torchvision.ops import boxes as box_ops # type: ignore
24
28
 
25
29
 
26
30
  # Copy & paste from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/nms.py
@@ -18,32 +18,42 @@
18
18
  """
19
19
  Torch related utils
20
20
  """
21
+ from __future__ import annotations
21
22
 
23
+ import os
24
+ from typing import Optional, Union
22
25
 
23
- from ...utils.error import DependencyError
24
- from ...utils.file_utils import pytorch_available
26
+ from lazy_imports import try_import
25
27
 
28
+ from ...utils.env_info import ENV_VARS_TRUE
26
29
 
27
- def set_torch_auto_device() -> "torch.device": # type: ignore
28
- """
29
- Returns cuda device if available, otherwise cpu
30
- """
31
- if pytorch_available():
32
- from torch import cuda, device # pylint: disable=C0415
33
-
34
- return device("cuda" if cuda.is_available() else "cpu")
35
- raise DependencyError("Pytorch must be installed")
30
+ with try_import() as import_guard:
31
+ import torch
36
32
 
37
33
 
38
- def get_num_gpu() -> int:
34
+ def get_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
39
35
  """
40
- Returns number of CUDA devices if pytorch is available
36
+ Selecting a device on which to load a model. The selection follows a cascade of priorities:
41
37
 
42
- :return:
43
- """
38
+ - If a device string is provided, it is used.
39
+ - If the environment variable "USE_CUDA" is set, a GPU is used. If more GPUs are available, it will use all of them
40
+ unless something else is specified by CUDA_VISIBLE_DEVICES:
41
+
42
+ https://stackoverflow.com/questions/54216920/how-to-use-multiple-gpus-in-pytorch
44
43
 
45
- if pytorch_available():
46
- from torch import cuda # pylint: disable=C0415
44
+ - If an MPS device is available, it is used.
45
+ - Otherwise, the CPU is used.
47
46
 
48
- return cuda.device_count()
49
- raise DependencyError("Pytorch must be installed")
47
+ :param device: Device either as string or torch.device
48
+ :return: Torch device
49
+ """
50
+ if device is not None:
51
+ if isinstance(device, torch.device):
52
+ return device
53
+ if isinstance(device, str):
54
+ return torch.device(device)
55
+ if os.environ.get("USE_CUDA", "False") in ENV_VARS_TRUE:
56
+ return torch.device("cuda")
57
+ if os.environ.get("USE_MPS", "False") in ENV_VARS_TRUE:
58
+ return torch.device("mps")
59
+ return torch.device("cpu")
@@ -18,25 +18,28 @@
18
18
  """
19
19
  Tesseract OCR engine for text extraction
20
20
  """
21
+ from __future__ import annotations
22
+
21
23
  import shlex
22
24
  import string
23
25
  import subprocess
24
26
  import sys
25
27
  from errno import ENOENT
26
28
  from itertools import groupby
27
- from os import environ
28
- from typing import Any, Dict, List, Mapping, Optional, Union
29
+ from os import environ, fspath
30
+ from pathlib import Path
31
+ from typing import Any, Mapping, Optional, Union
29
32
 
30
33
  from packaging.version import InvalidVersion, Version, parse
31
34
 
32
35
  from ..utils.context import save_tmp_file, timeout_manager
33
- from ..utils.detection_types import ImageType, Requirement
34
36
  from ..utils.error import DependencyError, TesseractError
35
37
  from ..utils.file_utils import _TESS_PATH, get_tesseract_requirement
36
38
  from ..utils.metacfg import config_to_cli_str, set_config_by_yaml
37
39
  from ..utils.settings import LayoutType, ObjectTypes, PageType
40
+ from ..utils.types import PathLikeOrStr, PixelValues, Requirement
38
41
  from ..utils.viz import viz_handler
39
- from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase
42
+ from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector
40
43
 
41
44
  # copy and paste with some light modifications from https://github.com/madmaze/pytesseract/tree/master/pytesseract
42
45
 
@@ -60,7 +63,7 @@ _LANG_CODE_TO_TESS_LANG_CODE = {
60
63
  }
61
64
 
62
65
 
63
- def _subprocess_args() -> Dict[str, Any]:
66
+ def _subprocess_args() -> dict[str, Any]:
64
67
  # See https://github.com/pyinstaller/pyinstaller/wiki/Recipe-subprocess
65
68
  # for reference and comments.
66
69
 
@@ -75,16 +78,16 @@ def _subprocess_args() -> Dict[str, Any]:
75
78
  return kwargs
76
79
 
77
80
 
78
- def _input_to_cli_str(lang: str, config: str, nice: int, input_file_name: str, output_file_name_base: str) -> List[str]:
81
+ def _input_to_cli_str(lang: str, config: str, nice: int, input_file_name: str, output_file_name_base: str) -> list[str]:
79
82
  """
80
83
  Generates a tesseract cmd as list of string with given inputs
81
84
  """
82
- cmd_args: List[str] = []
85
+ cmd_args: list[str] = []
83
86
 
84
87
  if not sys.platform.startswith("win32") and nice != 0:
85
88
  cmd_args += ("nice", "-n", str(nice))
86
89
 
87
- cmd_args += (_TESS_PATH, input_file_name, output_file_name_base, "-l", lang)
90
+ cmd_args += (fspath(_TESS_PATH), input_file_name, output_file_name_base, "-l", lang)
88
91
 
89
92
  if config:
90
93
  cmd_args += shlex.split(config)
@@ -94,7 +97,7 @@ def _input_to_cli_str(lang: str, config: str, nice: int, input_file_name: str, o
94
97
  return cmd_args
95
98
 
96
99
 
97
- def _run_tesseract(tesseract_args: List[str]) -> None:
100
+ def _run_tesseract(tesseract_args: list[str]) -> None:
98
101
  try:
99
102
  proc = subprocess.Popen(tesseract_args, **_subprocess_args()) # pylint: disable=R1732
100
103
  except OSError as error:
@@ -137,7 +140,7 @@ def get_tesseract_version() -> Version:
137
140
  return version
138
141
 
139
142
 
140
- def image_to_angle(image: ImageType) -> Mapping[str, str]:
143
+ def image_to_angle(image: PixelValues) -> Mapping[str, str]:
141
144
  """
142
145
  Generating a tmp file and running tesseract to get the orientation of the image.
143
146
 
@@ -154,7 +157,7 @@ def image_to_angle(image: ImageType) -> Mapping[str, str]:
154
157
  }
155
158
 
156
159
 
157
- def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Union[str, int, float]]]:
160
+ def image_to_dict(image: PixelValues, lang: str, config: str) -> dict[str, list[Union[str, int, float]]]:
158
161
  """
159
162
  This is more or less pytesseract.image_to_data with a dict as returned value.
160
163
  What happens under the hood is:
@@ -177,7 +180,7 @@ def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Un
177
180
  _run_tesseract(_input_to_cli_str(lang, config, 0, input_file_name, tmp_name))
178
181
  with open(tmp_name + ".tsv", "rb") as output_file:
179
182
  output = output_file.read().decode("utf-8")
180
- result: Dict[str, List[Union[str, int, float]]] = {}
183
+ result: dict[str, list[Union[str, int, float]]] = {}
181
184
  rows = [row.split("\t") for row in output.strip().split("\n")]
182
185
  if len(rows) < 2:
183
186
  return result
@@ -208,7 +211,7 @@ def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Un
208
211
  return result
209
212
 
210
213
 
211
- def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) -> List[DetectionResult]:
214
+ def tesseract_line_to_detectresult(detect_result_list: list[DetectionResult]) -> list[DetectionResult]:
212
215
  """
213
216
  Generating text line DetectionResult based on Tesseract word grouping. It generates line bounding boxes from
214
217
  word bounding boxes.
@@ -216,7 +219,7 @@ def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) ->
216
219
  :return: An extended list of detection result
217
220
  """
218
221
 
219
- line_detect_result: List[DetectionResult] = []
222
+ line_detect_result: list[DetectionResult] = []
220
223
  for _, block_group_iter in groupby(detect_result_list, key=lambda x: x.block):
221
224
  block_group = []
222
225
  for _, line_group_iter in groupby(list(block_group_iter), key=lambda x: x.line):
@@ -231,7 +234,7 @@ def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) ->
231
234
  DetectionResult(
232
235
  box=[ulx, uly, lrx, lry],
233
236
  class_id=2,
234
- class_name=LayoutType.line,
237
+ class_name=LayoutType.LINE,
235
238
  text=" ".join(
236
239
  [detect_result.text for detect_result in block_group if isinstance(detect_result.text, str)]
237
240
  ),
@@ -242,7 +245,7 @@ def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) ->
242
245
  return detect_result_list
243
246
 
244
247
 
245
- def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool, config: str) -> List[DetectionResult]:
248
+ def predict_text(np_img: PixelValues, supported_languages: str, text_lines: bool, config: str) -> list[DetectionResult]:
246
249
  """
247
250
  Calls tesseract directly with some given configs. Requires Tesseract to be installed.
248
251
 
@@ -275,7 +278,7 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
275
278
  score=score / 100,
276
279
  text=caption[5],
277
280
  class_id=1,
278
- class_name=LayoutType.word,
281
+ class_name=LayoutType.WORD,
279
282
  )
280
283
  all_results.append(word)
281
284
  if text_lines:
@@ -283,7 +286,7 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
283
286
  return all_results
284
287
 
285
288
 
286
- def predict_rotation(np_img: ImageType) -> Mapping[str, str]:
289
+ def predict_rotation(np_img: PixelValues) -> Mapping[str, str]:
287
290
  """
288
291
  Predicts the rotation of an image using the Tesseract OCR engine.
289
292
 
@@ -326,8 +329,8 @@ class TesseractOcrDetector(ObjectDetector):
326
329
 
327
330
  def __init__(
328
331
  self,
329
- path_yaml: str,
330
- config_overwrite: Optional[List[str]] = None,
332
+ path_yaml: PathLikeOrStr,
333
+ config_overwrite: Optional[list[str]] = None,
331
334
  ):
332
335
  """
333
336
  Set up the configuration which is stored in a yaml-file, that need to be passed through.
@@ -346,16 +349,16 @@ class TesseractOcrDetector(ObjectDetector):
346
349
  if len(config_overwrite):
347
350
  hyper_param_config.update_args(config_overwrite)
348
351
 
349
- self.path_yaml = path_yaml
352
+ self.path_yaml = Path(path_yaml)
350
353
  self.config_overwrite = config_overwrite
351
354
  self.config = hyper_param_config
352
355
 
353
356
  if self.config.LINES:
354
- self.categories = {"1": LayoutType.word, "2": LayoutType.line}
357
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD, 2: LayoutType.LINE})
355
358
  else:
356
- self.categories = {"1": LayoutType.word}
359
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD})
357
360
 
358
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
361
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
359
362
  """
360
363
  Transfer of a numpy array and call of pytesseract. Return of the detection results.
361
364
 
@@ -371,16 +374,14 @@ class TesseractOcrDetector(ObjectDetector):
371
374
  )
372
375
 
373
376
  @classmethod
374
- def get_requirements(cls) -> List[Requirement]:
377
+ def get_requirements(cls) -> list[Requirement]:
375
378
  return [get_tesseract_requirement()]
376
379
 
377
- def clone(self) -> PredictorBase:
380
+ def clone(self) -> TesseractOcrDetector:
378
381
  return self.__class__(self.path_yaml, self.config_overwrite)
379
382
 
380
- def possible_categories(self) -> List[ObjectTypes]:
381
- if self.config.LINES:
382
- return [LayoutType.word, LayoutType.line]
383
- return [LayoutType.word]
383
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
384
+ return self.categories.get_categories(as_dict=False)
384
385
 
385
386
  def set_language(self, language: ObjectTypes) -> None:
386
387
  """
@@ -418,9 +419,10 @@ class TesseractRotationTransformer(ImageTransformer):
418
419
  """
419
420
 
420
421
  def __init__(self) -> None:
421
- self.name = _TESS_PATH + "-rotation"
422
+ self.name = fspath(_TESS_PATH) + "-rotation"
423
+ self.categories = ModelCategories(init_categories={1: PageType.ANGLE})
422
424
 
423
- def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
425
+ def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
424
426
  """
425
427
  Applies the predicted rotation to the image, effectively rotating the image backwards.
426
428
  This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
@@ -431,7 +433,7 @@ class TesseractRotationTransformer(ImageTransformer):
431
433
  """
432
434
  return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
433
435
 
434
- def predict(self, np_img: ImageType) -> DetectionResult:
436
+ def predict(self, np_img: PixelValues) -> DetectionResult:
435
437
  """
436
438
  Determines the angle of the rotated image. It can only handle angles that are multiples of 90 degrees.
437
439
  This method uses the Tesseract OCR engine to predict the rotation angle of an image.
@@ -445,12 +447,11 @@ class TesseractRotationTransformer(ImageTransformer):
445
447
  )
446
448
 
447
449
  @classmethod
448
- def get_requirements(cls) -> List[Requirement]:
450
+ def get_requirements(cls) -> list[Requirement]:
449
451
  return [get_tesseract_requirement()]
450
452
 
451
- def clone(self) -> PredictorBase:
453
+ def clone(self) -> TesseractRotationTransformer:
452
454
  return self.__class__()
453
455
 
454
- @staticmethod
455
- def possible_category() -> PageType:
456
- return PageType.angle
456
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
457
+ return self.categories.get_categories(as_dict=False)
@@ -18,24 +18,26 @@
18
18
  """
19
19
  AWS Textract OCR engine for text extraction
20
20
  """
21
+ from __future__ import annotations
21
22
 
22
23
  import sys
23
24
  import traceback
24
- from typing import List
25
+
26
+ from lazy_imports import try_import
25
27
 
26
28
  from ..datapoint.convert import convert_np_array_to_b64_b
27
- from ..utils.detection_types import ImageType, JsonDict, Requirement
28
- from ..utils.file_utils import boto3_available, get_boto3_requirement
29
+ from ..utils.file_utils import get_boto3_requirement
29
30
  from ..utils.logger import LoggingRecord, logger
30
31
  from ..utils.settings import LayoutType, ObjectTypes
31
- from .base import DetectionResult, ObjectDetector, PredictorBase
32
+ from ..utils.types import JsonDict, PixelValues, Requirement
33
+ from .base import DetectionResult, ModelCategories, ObjectDetector
32
34
 
33
- if boto3_available():
35
+ with try_import() as import_guard:
34
36
  import boto3 # type:ignore
35
37
 
36
38
 
37
- def _textract_to_detectresult(response: JsonDict, width: int, height: int, text_lines: bool) -> List[DetectionResult]:
38
- all_results: List[DetectionResult] = []
39
+ def _textract_to_detectresult(response: JsonDict, width: int, height: int, text_lines: bool) -> list[DetectionResult]:
40
+ all_results: list[DetectionResult] = []
39
41
  blocks = response.get("Blocks")
40
42
 
41
43
  if blocks:
@@ -51,14 +53,14 @@ def _textract_to_detectresult(response: JsonDict, width: int, height: int, text_
51
53
  score=block["Confidence"] / 100,
52
54
  text=block["Text"],
53
55
  class_id=1 if block["BlockType"] == "WORD" else 2,
54
- class_name=LayoutType.word if block["BlockType"] == "WORD" else LayoutType.line,
56
+ class_name=LayoutType.WORD if block["BlockType"] == "WORD" else LayoutType.LINE,
55
57
  )
56
58
  all_results.append(word)
57
59
 
58
60
  return all_results
59
61
 
60
62
 
61
- def predict_text(np_img: ImageType, client, text_lines: bool) -> List[DetectionResult]: # type: ignore
63
+ def predict_text(np_img: PixelValues, client, text_lines: bool) -> list[DetectionResult]: # type: ignore
62
64
  """
63
65
  Calls AWS Textract client (`detect_document_text`) and returns plain OCR results.
64
66
  AWS account required.
@@ -125,11 +127,11 @@ class TextractOcrDetector(ObjectDetector):
125
127
  self.text_lines = text_lines
126
128
  self.client = boto3.client("textract", **credentials_kwargs)
127
129
  if self.text_lines:
128
- self.categories = {"1": LayoutType.word, "2": LayoutType.line}
130
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD, 2: LayoutType.LINE})
129
131
  else:
130
- self.categories = {"1": LayoutType.word}
132
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD})
131
133
 
132
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
134
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
133
135
  """
134
136
  Transfer of a numpy array and call textract client. Return of the detection results.
135
137
 
@@ -140,13 +142,11 @@ class TextractOcrDetector(ObjectDetector):
140
142
  return predict_text(np_img, self.client, self.text_lines)
141
143
 
142
144
  @classmethod
143
- def get_requirements(cls) -> List[Requirement]:
145
+ def get_requirements(cls) -> list[Requirement]:
144
146
  return [get_boto3_requirement()]
145
147
 
146
- def clone(self) -> PredictorBase:
148
+ def clone(self) -> TextractOcrDetector:
147
149
  return self.__class__()
148
150
 
149
- def possible_categories(self) -> List[ObjectTypes]:
150
- if self.text_lines:
151
- return [LayoutType.word, LayoutType.line]
152
- return [LayoutType.word]
151
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
152
+ return self.categories.get_categories(as_dict=False)
@@ -19,7 +19,20 @@
19
19
  Tensorflow related utils.
20
20
  """
21
21
 
22
- from tensorpack.models import disable_layer_logging # pylint: disable=E0401
22
+ from __future__ import annotations
23
+
24
+ import os
25
+ from typing import ContextManager, Optional, Union
26
+
27
+ from lazy_imports import try_import
28
+
29
+ from ...utils.env_info import ENV_VARS_TRUE
30
+
31
+ with try_import() as import_guard:
32
+ from tensorpack.models import disable_layer_logging # pylint: disable=E0401
33
+
34
+ with try_import() as tf_import_guard:
35
+ import tensorflow as tf # pylint: disable=E0401
23
36
 
24
37
 
25
38
  def is_tfv2() -> bool:
@@ -38,16 +51,13 @@ def disable_tfv2() -> bool:
38
51
  """
39
52
  Disable TF in V2 mode.
40
53
  """
41
- try:
42
- import tensorflow as tf # pylint: disable=C0415
43
54
 
44
- tfv1 = tf.compat.v1
45
- if is_tfv2():
46
- tfv1.disable_v2_behavior()
47
- tfv1.disable_eager_execution()
55
+ tfv1 = tf.compat.v1
56
+ if is_tfv2():
57
+ tfv1.disable_v2_behavior()
58
+ tfv1.disable_eager_execution()
48
59
  return True
49
- except ModuleNotFoundError:
50
- return False
60
+ return False
51
61
 
52
62
 
53
63
  def disable_tp_layer_logging() -> None:
@@ -55,3 +65,41 @@ def disable_tp_layer_logging() -> None:
55
65
  Disables TP layer logging, if not already set
56
66
  """
57
67
  disable_layer_logging()
68
+
69
+
70
def _first_logical_tf_device(device_type: str, error_message: str) -> tf.device:
    """
    Return a `tf.device` context for the first logical device of the given type.

    :param device_type: Device type passed to `tf.config.list_logical_devices`, e.g. "GPU" or "CPU"
    :param error_message: Message of the `EnvironmentError` raised when no such device is found
    :return: `tf.device` context manager for the first device found
    :raises EnvironmentError: If Tensorflow does not report any logical device of the requested type
    """
    device_names = [
        logical_device.name for logical_device in tf.config.list_logical_devices(device_type=device_type)
    ]
    if not device_names:
        raise EnvironmentError(error_message)
    # `device_names` already holds plain name strings, so the first entry is passed on as is
    return tf.device(device_names[0])


def get_tf_device(device: Optional[Union[str, tf.device]] = None) -> tf.device:
    """
    Selecting a device on which to load a model. The selection follows a cascade of priorities:

    - If a device context manager is provided, it is returned unchanged.
    - If a device string is provided, it is used. If the string is "cuda" or "GPU", the first GPU is used.
    - If the environment variable "USE_CUDA" is set, a GPU is used. If more GPUs are available it will use the first one
    - Otherwise, the first CPU device is used.

    An argument that is neither a context manager nor a string is ignored and the selection continues with the
    environment variable.

    :param device: Device string or an already created `tf.device` context manager
    :return: Tensorflow device
    :raises EnvironmentError: If a GPU is requested (by argument or by "USE_CUDA") but no GPU device is found, or if
                              no CPU device is found
    """
    if device is not None:
        if isinstance(device, ContextManager):
            return device
        if isinstance(device, str):
            if device in ("cuda", "GPU"):
                return _first_logical_tf_device(
                    "GPU",
                    f"Device '{device}' was requested but tf.config.list_logical_devices cannot find any GPU "
                    "device. It looks like there is an issue with your Tensorflow installation. "
                    "You can LOG_LEVEL='DEBUG' to get more information about installation.",
                )
            # The input must be something sensible
            return tf.device(device)
    if os.environ.get("USE_CUDA", "False") in ENV_VARS_TRUE:
        return _first_logical_tf_device(
            "GPU",
            "USE_CUDA is set but tf.config.list_logical_devices cannot find any device. "
            "It looks like there is an issue with your Tensorflow installation. "
            "You can LOG_LEVEL='DEBUG' to get more information about installation.",
        )
    return _first_logical_tf_device(
        "CPU",
        "Cannot find any CPU device. It looks like there is an issue with your "
        "Tensorflow installation. You can LOG_LEVEL='DEBUG' to get more information about "
        "installation.",
    )