deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (111) hide show
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -19,26 +19,26 @@
19
19
  PDFPlumber text extraction engine
20
20
  """
21
21
 
22
- from typing import Dict, List, Tuple
22
+ from typing import Optional
23
23
 
24
24
  from lazy_imports import try_import
25
25
 
26
26
  from ..utils.context import save_tmp_file
27
- from ..utils.detection_types import Requirement
28
27
  from ..utils.file_utils import get_pdfplumber_requirement
29
28
  from ..utils.settings import LayoutType, ObjectTypes
30
- from .base import DetectionResult, PdfMiner
29
+ from ..utils.types import Requirement
30
+ from .base import DetectionResult, ModelCategories, PdfMiner
31
31
 
32
32
  with try_import() as import_guard:
33
- from pdfplumber.pdf import PDF
33
+ from pdfplumber.pdf import PDF, Page
34
34
 
35
35
 
36
- def _to_detect_result(word: Dict[str, str]) -> DetectionResult:
36
+ def _to_detect_result(word: dict[str, str]) -> DetectionResult:
37
37
  return DetectionResult(
38
38
  box=[float(word["x0"]), float(word["top"]), float(word["x1"]), float(word["bottom"])],
39
39
  class_id=1,
40
40
  text=word["text"],
41
- class_name=LayoutType.word,
41
+ class_name=LayoutType.WORD,
42
42
  )
43
43
 
44
44
 
@@ -69,11 +69,12 @@ class PdfPlumberTextDetector(PdfMiner):
69
69
  def __init__(self, x_tolerance: int = 3, y_tolerance: int = 3) -> None:
70
70
  self.name = "Pdfplumber"
71
71
  self.model_id = self.get_model_id()
72
- self.categories = {"1": LayoutType.word}
72
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD})
73
73
  self.x_tolerance = x_tolerance
74
74
  self.y_tolerance = y_tolerance
75
+ self._page: Optional[Page] = None
75
76
 
76
- def predict(self, pdf_bytes: bytes) -> List[DetectionResult]:
77
+ def predict(self, pdf_bytes: bytes) -> list[DetectionResult]:
77
78
  """
78
79
  Call pdfminer.six and returns detected text as detection results
79
80
 
@@ -83,25 +84,24 @@ class PdfPlumberTextDetector(PdfMiner):
83
84
 
84
85
  with save_tmp_file(pdf_bytes, "pdf_") as (tmp_name, _):
85
86
  with open(tmp_name, "rb") as fin:
86
- _pdf = PDF(fin)
87
- self._page = _pdf.pages[0]
87
+ self._page = PDF(fin).pages[0]
88
88
  self._pdf_bytes = pdf_bytes
89
89
  words = self._page.extract_words(x_tolerance=self.x_tolerance, y_tolerance=self.y_tolerance)
90
90
  detect_results = list(map(_to_detect_result, words))
91
91
  return detect_results
92
92
 
93
93
  @classmethod
94
- def get_requirements(cls) -> List[Requirement]:
94
+ def get_requirements(cls) -> list[Requirement]:
95
95
  return [get_pdfplumber_requirement()]
96
96
 
97
- def get_width_height(self, pdf_bytes: bytes) -> Tuple[float, float]:
97
+ def get_width_height(self, pdf_bytes: bytes) -> tuple[float, float]:
98
98
  """
99
99
  Get the width and height of the full page
100
100
  :param pdf_bytes: pdf_bytes generating the pdf
101
101
  :return: width and height
102
102
  """
103
103
 
104
- if self._pdf_bytes == pdf_bytes:
104
+ if self._pdf_bytes == pdf_bytes and self._page is not None:
105
105
  return self._page.bbox[2], self._page.bbox[3]
106
106
  # if the pdf bytes is not equal to the cached pdf, will recalculate values
107
107
  with save_tmp_file(pdf_bytes, "pdf_") as (tmp_name, _):
@@ -111,5 +111,5 @@ class PdfPlumberTextDetector(PdfMiner):
111
111
  self._pdf_bytes = pdf_bytes
112
112
  return self._page.bbox[2], self._page.bbox[3]
113
113
 
114
- def possible_categories(self) -> List[ObjectTypes]:
115
- return [LayoutType.word]
114
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
115
+ return self.categories.get_categories(as_dict=False)
@@ -25,6 +25,8 @@ from typing import Optional, Union
25
25
 
26
26
  from lazy_imports import try_import
27
27
 
28
+ from ...utils.env_info import ENV_VARS_TRUE
29
+
28
30
  with try_import() as import_guard:
29
31
  import torch
30
32
 
@@ -50,8 +52,8 @@ def get_torch_device(device: Optional[Union[str, torch.device]] = None) -> torch
50
52
  return device
51
53
  if isinstance(device, str):
52
54
  return torch.device(device)
53
- if os.environ.get("USE_CUDA"):
55
+ if os.environ.get("USE_CUDA", "False") in ENV_VARS_TRUE:
54
56
  return torch.device("cuda")
55
- if os.environ.get("USE_MPS"):
57
+ if os.environ.get("USE_MPS", "False") in ENV_VARS_TRUE:
56
58
  return torch.device("mps")
57
59
  return torch.device("cpu")
@@ -18,25 +18,28 @@
18
18
  """
19
19
  Tesseract OCR engine for text extraction
20
20
  """
21
+ from __future__ import annotations
22
+
21
23
  import shlex
22
24
  import string
23
25
  import subprocess
24
26
  import sys
25
27
  from errno import ENOENT
26
28
  from itertools import groupby
27
- from os import environ
28
- from typing import Any, Dict, List, Mapping, Optional, Union
29
+ from os import environ, fspath
30
+ from pathlib import Path
31
+ from typing import Any, Mapping, Optional, Union
29
32
 
30
33
  from packaging.version import InvalidVersion, Version, parse
31
34
 
32
35
  from ..utils.context import save_tmp_file, timeout_manager
33
- from ..utils.detection_types import ImageType, Requirement
34
36
  from ..utils.error import DependencyError, TesseractError
35
37
  from ..utils.file_utils import _TESS_PATH, get_tesseract_requirement
36
38
  from ..utils.metacfg import config_to_cli_str, set_config_by_yaml
37
39
  from ..utils.settings import LayoutType, ObjectTypes, PageType
40
+ from ..utils.types import PathLikeOrStr, PixelValues, Requirement
38
41
  from ..utils.viz import viz_handler
39
- from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase
42
+ from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector
40
43
 
41
44
  # copy and paste with some light modifications from https://github.com/madmaze/pytesseract/tree/master/pytesseract
42
45
 
@@ -60,7 +63,7 @@ _LANG_CODE_TO_TESS_LANG_CODE = {
60
63
  }
61
64
 
62
65
 
63
- def _subprocess_args() -> Dict[str, Any]:
66
+ def _subprocess_args() -> dict[str, Any]:
64
67
  # See https://github.com/pyinstaller/pyinstaller/wiki/Recipe-subprocess
65
68
  # for reference and comments.
66
69
 
@@ -75,16 +78,16 @@ def _subprocess_args() -> Dict[str, Any]:
75
78
  return kwargs
76
79
 
77
80
 
78
- def _input_to_cli_str(lang: str, config: str, nice: int, input_file_name: str, output_file_name_base: str) -> List[str]:
81
+ def _input_to_cli_str(lang: str, config: str, nice: int, input_file_name: str, output_file_name_base: str) -> list[str]:
79
82
  """
80
83
  Generates a tesseract cmd as list of string with given inputs
81
84
  """
82
- cmd_args: List[str] = []
85
+ cmd_args: list[str] = []
83
86
 
84
87
  if not sys.platform.startswith("win32") and nice != 0:
85
88
  cmd_args += ("nice", "-n", str(nice))
86
89
 
87
- cmd_args += (_TESS_PATH, input_file_name, output_file_name_base, "-l", lang)
90
+ cmd_args += (fspath(_TESS_PATH), input_file_name, output_file_name_base, "-l", lang)
88
91
 
89
92
  if config:
90
93
  cmd_args += shlex.split(config)
@@ -94,7 +97,7 @@ def _input_to_cli_str(lang: str, config: str, nice: int, input_file_name: str, o
94
97
  return cmd_args
95
98
 
96
99
 
97
- def _run_tesseract(tesseract_args: List[str]) -> None:
100
+ def _run_tesseract(tesseract_args: list[str]) -> None:
98
101
  try:
99
102
  proc = subprocess.Popen(tesseract_args, **_subprocess_args()) # pylint: disable=R1732
100
103
  except OSError as error:
@@ -137,7 +140,7 @@ def get_tesseract_version() -> Version:
137
140
  return version
138
141
 
139
142
 
140
- def image_to_angle(image: ImageType) -> Mapping[str, str]:
143
+ def image_to_angle(image: PixelValues) -> Mapping[str, str]:
141
144
  """
142
145
  Generating a tmp file and running tesseract to get the orientation of the image.
143
146
 
@@ -154,7 +157,7 @@ def image_to_angle(image: ImageType) -> Mapping[str, str]:
154
157
  }
155
158
 
156
159
 
157
- def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Union[str, int, float]]]:
160
+ def image_to_dict(image: PixelValues, lang: str, config: str) -> dict[str, list[Union[str, int, float]]]:
158
161
  """
159
162
  This is more or less pytesseract.image_to_data with a dict as returned value.
160
163
  What happens under the hood is:
@@ -177,7 +180,7 @@ def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Un
177
180
  _run_tesseract(_input_to_cli_str(lang, config, 0, input_file_name, tmp_name))
178
181
  with open(tmp_name + ".tsv", "rb") as output_file:
179
182
  output = output_file.read().decode("utf-8")
180
- result: Dict[str, List[Union[str, int, float]]] = {}
183
+ result: dict[str, list[Union[str, int, float]]] = {}
181
184
  rows = [row.split("\t") for row in output.strip().split("\n")]
182
185
  if len(rows) < 2:
183
186
  return result
@@ -208,7 +211,7 @@ def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Un
208
211
  return result
209
212
 
210
213
 
211
- def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) -> List[DetectionResult]:
214
+ def tesseract_line_to_detectresult(detect_result_list: list[DetectionResult]) -> list[DetectionResult]:
212
215
  """
213
216
  Generating text line DetectionResult based on Tesseract word grouping. It generates line bounding boxes from
214
217
  word bounding boxes.
@@ -216,7 +219,7 @@ def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) ->
216
219
  :return: An extended list of detection result
217
220
  """
218
221
 
219
- line_detect_result: List[DetectionResult] = []
222
+ line_detect_result: list[DetectionResult] = []
220
223
  for _, block_group_iter in groupby(detect_result_list, key=lambda x: x.block):
221
224
  block_group = []
222
225
  for _, line_group_iter in groupby(list(block_group_iter), key=lambda x: x.line):
@@ -231,7 +234,7 @@ def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) ->
231
234
  DetectionResult(
232
235
  box=[ulx, uly, lrx, lry],
233
236
  class_id=2,
234
- class_name=LayoutType.line,
237
+ class_name=LayoutType.LINE,
235
238
  text=" ".join(
236
239
  [detect_result.text for detect_result in block_group if isinstance(detect_result.text, str)]
237
240
  ),
@@ -242,7 +245,7 @@ def tesseract_line_to_detectresult(detect_result_list: List[DetectionResult]) ->
242
245
  return detect_result_list
243
246
 
244
247
 
245
- def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool, config: str) -> List[DetectionResult]:
248
+ def predict_text(np_img: PixelValues, supported_languages: str, text_lines: bool, config: str) -> list[DetectionResult]:
246
249
  """
247
250
  Calls tesseract directly with some given configs. Requires Tesseract to be installed.
248
251
 
@@ -275,7 +278,7 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
275
278
  score=score / 100,
276
279
  text=caption[5],
277
280
  class_id=1,
278
- class_name=LayoutType.word,
281
+ class_name=LayoutType.WORD,
279
282
  )
280
283
  all_results.append(word)
281
284
  if text_lines:
@@ -283,7 +286,7 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
283
286
  return all_results
284
287
 
285
288
 
286
- def predict_rotation(np_img: ImageType) -> Mapping[str, str]:
289
+ def predict_rotation(np_img: PixelValues) -> Mapping[str, str]:
287
290
  """
288
291
  Predicts the rotation of an image using the Tesseract OCR engine.
289
292
 
@@ -326,8 +329,8 @@ class TesseractOcrDetector(ObjectDetector):
326
329
 
327
330
  def __init__(
328
331
  self,
329
- path_yaml: str,
330
- config_overwrite: Optional[List[str]] = None,
332
+ path_yaml: PathLikeOrStr,
333
+ config_overwrite: Optional[list[str]] = None,
331
334
  ):
332
335
  """
333
336
  Set up the configuration which is stored in a yaml-file, that need to be passed through.
@@ -346,16 +349,16 @@ class TesseractOcrDetector(ObjectDetector):
346
349
  if len(config_overwrite):
347
350
  hyper_param_config.update_args(config_overwrite)
348
351
 
349
- self.path_yaml = path_yaml
352
+ self.path_yaml = Path(path_yaml)
350
353
  self.config_overwrite = config_overwrite
351
354
  self.config = hyper_param_config
352
355
 
353
356
  if self.config.LINES:
354
- self.categories = {"1": LayoutType.word, "2": LayoutType.line}
357
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD, 2: LayoutType.LINE})
355
358
  else:
356
- self.categories = {"1": LayoutType.word}
359
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD})
357
360
 
358
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
361
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
359
362
  """
360
363
  Transfer of a numpy array and call of pytesseract. Return of the detection results.
361
364
 
@@ -371,16 +374,14 @@ class TesseractOcrDetector(ObjectDetector):
371
374
  )
372
375
 
373
376
  @classmethod
374
- def get_requirements(cls) -> List[Requirement]:
377
+ def get_requirements(cls) -> list[Requirement]:
375
378
  return [get_tesseract_requirement()]
376
379
 
377
- def clone(self) -> PredictorBase:
380
+ def clone(self) -> TesseractOcrDetector:
378
381
  return self.__class__(self.path_yaml, self.config_overwrite)
379
382
 
380
- def possible_categories(self) -> List[ObjectTypes]:
381
- if self.config.LINES:
382
- return [LayoutType.word, LayoutType.line]
383
- return [LayoutType.word]
383
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
384
+ return self.categories.get_categories(as_dict=False)
384
385
 
385
386
  def set_language(self, language: ObjectTypes) -> None:
386
387
  """
@@ -418,9 +419,10 @@ class TesseractRotationTransformer(ImageTransformer):
418
419
  """
419
420
 
420
421
  def __init__(self) -> None:
421
- self.name = _TESS_PATH + "-rotation"
422
+ self.name = fspath(_TESS_PATH) + "-rotation"
423
+ self.categories = ModelCategories(init_categories={1: PageType.ANGLE})
422
424
 
423
- def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
425
+ def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
424
426
  """
425
427
  Applies the predicted rotation to the image, effectively rotating the image backwards.
426
428
  This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
@@ -431,7 +433,7 @@ class TesseractRotationTransformer(ImageTransformer):
431
433
  """
432
434
  return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
433
435
 
434
- def predict(self, np_img: ImageType) -> DetectionResult:
436
+ def predict(self, np_img: PixelValues) -> DetectionResult:
435
437
  """
436
438
  Determines the angle of the rotated image. It can only handle angles that are multiples of 90 degrees.
437
439
  This method uses the Tesseract OCR engine to predict the rotation angle of an image.
@@ -445,12 +447,11 @@ class TesseractRotationTransformer(ImageTransformer):
445
447
  )
446
448
 
447
449
  @classmethod
448
- def get_requirements(cls) -> List[Requirement]:
450
+ def get_requirements(cls) -> list[Requirement]:
449
451
  return [get_tesseract_requirement()]
450
452
 
451
- def clone(self) -> PredictorBase:
453
+ def clone(self) -> TesseractRotationTransformer:
452
454
  return self.__class__()
453
455
 
454
- @staticmethod
455
- def possible_category() -> PageType:
456
- return PageType.angle
456
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
457
+ return self.categories.get_categories(as_dict=False)
@@ -18,26 +18,26 @@
18
18
  """
19
19
  AWS Textract OCR engine for text extraction
20
20
  """
21
+ from __future__ import annotations
21
22
 
22
23
  import sys
23
24
  import traceback
24
- from typing import List
25
25
 
26
26
  from lazy_imports import try_import
27
27
 
28
28
  from ..datapoint.convert import convert_np_array_to_b64_b
29
- from ..utils.detection_types import ImageType, JsonDict, Requirement
30
29
  from ..utils.file_utils import get_boto3_requirement
31
30
  from ..utils.logger import LoggingRecord, logger
32
31
  from ..utils.settings import LayoutType, ObjectTypes
33
- from .base import DetectionResult, ObjectDetector, PredictorBase
32
+ from ..utils.types import JsonDict, PixelValues, Requirement
33
+ from .base import DetectionResult, ModelCategories, ObjectDetector
34
34
 
35
35
  with try_import() as import_guard:
36
36
  import boto3 # type:ignore
37
37
 
38
38
 
39
- def _textract_to_detectresult(response: JsonDict, width: int, height: int, text_lines: bool) -> List[DetectionResult]:
40
- all_results: List[DetectionResult] = []
39
+ def _textract_to_detectresult(response: JsonDict, width: int, height: int, text_lines: bool) -> list[DetectionResult]:
40
+ all_results: list[DetectionResult] = []
41
41
  blocks = response.get("Blocks")
42
42
 
43
43
  if blocks:
@@ -53,14 +53,14 @@ def _textract_to_detectresult(response: JsonDict, width: int, height: int, text_
53
53
  score=block["Confidence"] / 100,
54
54
  text=block["Text"],
55
55
  class_id=1 if block["BlockType"] == "WORD" else 2,
56
- class_name=LayoutType.word if block["BlockType"] == "WORD" else LayoutType.line,
56
+ class_name=LayoutType.WORD if block["BlockType"] == "WORD" else LayoutType.LINE,
57
57
  )
58
58
  all_results.append(word)
59
59
 
60
60
  return all_results
61
61
 
62
62
 
63
- def predict_text(np_img: ImageType, client, text_lines: bool) -> List[DetectionResult]: # type: ignore
63
+ def predict_text(np_img: PixelValues, client, text_lines: bool) -> list[DetectionResult]: # type: ignore
64
64
  """
65
65
  Calls AWS Textract client (`detect_document_text`) and returns plain OCR results.
66
66
  AWS account required.
@@ -127,11 +127,11 @@ class TextractOcrDetector(ObjectDetector):
127
127
  self.text_lines = text_lines
128
128
  self.client = boto3.client("textract", **credentials_kwargs)
129
129
  if self.text_lines:
130
- self.categories = {"1": LayoutType.word, "2": LayoutType.line}
130
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD, 2: LayoutType.LINE})
131
131
  else:
132
- self.categories = {"1": LayoutType.word}
132
+ self.categories = ModelCategories(init_categories={1: LayoutType.WORD})
133
133
 
134
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
134
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
135
135
  """
136
136
  Transfer of a numpy array and call textract client. Return of the detection results.
137
137
 
@@ -142,13 +142,11 @@ class TextractOcrDetector(ObjectDetector):
142
142
  return predict_text(np_img, self.client, self.text_lines)
143
143
 
144
144
  @classmethod
145
- def get_requirements(cls) -> List[Requirement]:
145
+ def get_requirements(cls) -> list[Requirement]:
146
146
  return [get_boto3_requirement()]
147
147
 
148
- def clone(self) -> PredictorBase:
148
+ def clone(self) -> TextractOcrDetector:
149
149
  return self.__class__()
150
150
 
151
- def possible_categories(self) -> List[ObjectTypes]:
152
- if self.text_lines:
153
- return [LayoutType.word, LayoutType.line]
154
- return [LayoutType.word]
151
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
152
+ return self.categories.get_categories(as_dict=False)
@@ -22,10 +22,12 @@ Tensorflow related utils.
22
22
  from __future__ import annotations
23
23
 
24
24
  import os
25
- from typing import Optional, Union, ContextManager
25
+ from typing import ContextManager, Optional, Union
26
26
 
27
27
  from lazy_imports import try_import
28
28
 
29
+ from ...utils.env_info import ENV_VARS_TRUE
30
+
29
31
  with try_import() as import_guard:
30
32
  from tensorpack.models import disable_layer_logging # pylint: disable=E0401
31
33
 
@@ -84,8 +86,20 @@ def get_tf_device(device: Optional[Union[str, tf.device]] = None) -> tf.device:
84
86
  return tf.device(device_names[0].name)
85
87
  # The input must be something sensible
86
88
  return tf.device(device)
87
- if os.environ.get("USE_CUDA"):
89
+ if os.environ.get("USE_CUDA", "False") in ENV_VARS_TRUE:
88
90
  device_names = [device.name for device in tf.config.list_logical_devices(device_type="GPU")]
91
+ if not device_names:
92
+ raise EnvironmentError(
93
+ "USE_CUDA is set but tf.config.list_logical_devices cannot find any device. "
94
+ "It looks like there is an issue with your Tensorflow installation. "
95
+ "You can set LOG_LEVEL='DEBUG' to get more information about installation."
96
+ )
89
97
  return tf.device(device_names[0])
90
98
  device_names = [device.name for device in tf.config.list_logical_devices(device_type="CPU")]
99
+ if not device_names:
100
+ raise EnvironmentError(
101
+ "Cannot find any CPU device. It looks like there is an issue with your "
102
+ "Tensorflow installation. You can set LOG_LEVEL='DEBUG' to get more information about "
103
+ "installation."
104
+ )
91
105
  return tf.device(device_names[0])
@@ -20,13 +20,16 @@ Compatibility classes and methods related to Tensorpack package
20
20
  """
21
21
  from __future__ import annotations
22
22
 
23
+ import os
23
24
  from abc import ABC, abstractmethod
24
- from typing import Any, List, Mapping, Tuple, Union
25
+ from pathlib import Path
26
+ from typing import Any, Mapping, Union
25
27
 
26
28
  from lazy_imports import try_import
27
29
 
28
30
  from ...utils.metacfg import AttrDict
29
31
  from ...utils.settings import ObjectTypes
32
+ from ...utils.types import PathLikeOrStr, PixelValues
30
33
 
31
34
  with try_import() as import_guard:
32
35
  from tensorpack.predict import OfflinePredictor, PredictConfig # pylint: disable=E0401
@@ -51,7 +54,7 @@ class ModelDescWithConfig(ModelDesc, ABC): # type: ignore
51
54
  super().__init__()
52
55
  self.cfg = config
53
56
 
54
- def get_inference_tensor_names(self) -> Tuple[List[str], List[str]]:
57
+ def get_inference_tensor_names(self) -> tuple[list[str], list[str]]:
55
58
  """
56
59
  Returns lists of tensor names to be used to create an inference callable. "build_graph" must create tensors
57
60
  of these names when called under inference context.
@@ -77,7 +80,7 @@ class TensorpackPredictor(ABC):
77
80
  as there is an explicit class available for this.
78
81
  """
79
82
 
80
- def __init__(self, model: ModelDescWithConfig, path_weights: str, ignore_mismatch: bool) -> None:
83
+ def __init__(self, model: ModelDescWithConfig, path_weights: PathLikeOrStr, ignore_mismatch: bool) -> None:
81
84
  """
82
85
  :param model: Model, either as ModelDescWithConfig or derived from that class.
83
86
  :param path_weights: Model weights of the prediction config.
@@ -85,7 +88,7 @@ class TensorpackPredictor(ABC):
85
88
  if a pre-trained model is to be fine-tuned on a custom dataset.
86
89
  """
87
90
  self._model = model
88
- self.path_weights = path_weights
91
+ self.path_weights = Path(path_weights)
89
92
  self.ignore_mismatch = ignore_mismatch
90
93
  self._number_gpus = get_num_gpu()
91
94
  self.predict_config = self._build_config()
@@ -98,9 +101,10 @@ class TensorpackPredictor(ABC):
98
101
  return OfflinePredictor(self.predict_config)
99
102
 
100
103
  def _build_config(self) -> PredictConfig:
104
+ path_weights = os.fspath(self.path_weights) if os.fspath(self.path_weights) != "." else ""
101
105
  predict_config = PredictConfig(
102
106
  model=self._model,
103
- session_init=SmartInit(self.path_weights, ignore_mismatch=self.ignore_mismatch),
107
+ session_init=SmartInit(path_weights, ignore_mismatch=self.ignore_mismatch),
104
108
  input_names=self._model.get_inference_tensor_names()[0],
105
109
  output_names=self._model.get_inference_tensor_names()[1],
106
110
  )
@@ -110,7 +114,7 @@ class TensorpackPredictor(ABC):
110
114
  @staticmethod
111
115
  @abstractmethod
112
116
  def get_wrapped_model(
113
- path_yaml: str, categories: Mapping[str, ObjectTypes], config_overwrite: Union[List[str], None]
117
+ path_yaml: PathLikeOrStr, categories: Mapping[int, ObjectTypes], config_overwrite: Union[list[str], None]
114
118
  ) -> ModelDescWithConfig:
115
119
  """
116
120
  Implement the config generation, its modification and instantiate a version of the model. See
@@ -119,7 +123,7 @@ class TensorpackPredictor(ABC):
119
123
  raise NotImplementedError()
120
124
 
121
125
  @abstractmethod
122
- def predict(self, np_img: Any) -> Any:
126
+ def predict(self, np_img: PixelValues) -> Any:
123
127
  """
124
128
  Implement, how `self.tp_predictor` is invoked and raw prediction results are generated. Do use only raw
125
129
  objects and nothing, which is related to the DD API.
@@ -194,7 +194,7 @@ import numpy as np
194
194
  from lazy_imports import try_import
195
195
 
196
196
  from .....utils.metacfg import AttrDict
197
- from .....utils.settings import ObjectTypes
197
+ from .....utils.settings import TypeOrStr, get_type
198
198
 
199
199
  with try_import() as import_guard:
200
200
  from tensorpack.tfutils import collect_env_info # pylint: disable=E0401
@@ -209,7 +209,7 @@ with try_import() as import_guard:
209
209
  __all__ = ["train_frcnn_config", "model_frcnn_config"]
210
210
 
211
211
 
212
- def model_frcnn_config(config: AttrDict, categories: Mapping[str, ObjectTypes], print_summary: bool = True) -> None:
212
+ def model_frcnn_config(config: AttrDict, categories: Mapping[int, TypeOrStr], print_summary: bool = True) -> None:
213
213
  """
214
214
  Sanity checks for Tensorpack Faster-RCNN config settings, where the focus lies on the model for predicting.
215
215
  It will update the config instance.
@@ -221,8 +221,8 @@ def model_frcnn_config(config: AttrDict, categories: Mapping[str, ObjectTypes],
221
221
 
222
222
  config.freeze(False)
223
223
 
224
- categories = {str(key): categories[val] for key, val in enumerate(categories, 1)}
225
- categories[0] = "BG"
224
+ categories = {key: get_type(categories[val]) for key, val in enumerate(categories, 1)}
225
+ categories[0] = get_type("background")
226
226
  config.DATA.CLASS_NAMES = list(categories.values())
227
227
  config.DATA.CLASS_DICT = categories
228
228
  config.DATA.NUM_CATEGORY = len(config.DATA.CLASS_NAMES) - 1
@@ -71,7 +71,7 @@ def freeze_affine_getter(getter, *args, **kwargs):
71
71
  if name.endswith("/gamma") or name.endswith("/beta"):
72
72
  kwargs["trainable"] = False
73
73
  ret = getter(*args, **kwargs)
74
- tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, ret)
74
+ tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, ret) # pylint: disable=E1101
75
75
  else:
76
76
  ret = getter(*args, **kwargs)
77
77
  return ret
@@ -66,7 +66,7 @@ def decode_bbox_target(box_predictions, anchors, preproc_max_size):
66
66
  xbyb = box_pred_txty * waha + xaya
67
67
  x1y1 = xbyb - wbhb * 0.5
68
68
  x2y2 = xbyb + wbhb * 0.5 # (...)x1x2
69
- out = tf.concat([x1y1, x2y2], axis=-2)
69
+ out = tf.concat([x1y1, x2y2], axis=-2) # pylint: disable=E1123
70
70
  return tf.reshape(out, orig_shape)
71
71
 
72
72
 
@@ -93,7 +93,7 @@ def encode_bbox_target(boxes, anchors):
93
93
  # Note that here not all boxes are valid. Some may be zero
94
94
  txty = (xbyb - xaya) / waha
95
95
  twth = tf.math.log(wbhb / waha) # may contain -inf for invalid boxes
96
- encoded = tf.concat([txty, twth], axis=1) # (-1x2x2)
96
+ encoded = tf.concat([txty, twth], axis=1) # (-1x2x2) # pylint: disable=E1123
97
97
  return tf.reshape(encoded, tf.shape(boxes))
98
98
 
99
99
 
@@ -153,7 +153,7 @@ def crop_and_resize(image, boxes, box_ind, crop_size, pad_border=True):
153
153
  n_w = spacing_w * tf.cast(crop_shape[1] - 1, tf.float32) / imshape[1]
154
154
  n_h = spacing_h * tf.cast(crop_shape[0] - 1, tf.float32) / imshape[0]
155
155
 
156
- return tf.concat([ny0, nx0, ny0 + n_h, nx0 + n_w], axis=1)
156
+ return tf.concat([ny0, nx0, ny0 + n_h, nx0 + n_w], axis=1) # pylint: disable=E1123
157
157
 
158
158
  image_shape = tf.shape(image)[2:]
159
159
 
@@ -213,8 +213,8 @@ class RPNAnchors(namedtuple("_RPNAnchors", ["boxes", "gt_labels", "gt_boxes"])):
213
213
  Slice anchors to the spatial size of this feature map.
214
214
  """
215
215
  shape2d = tf.shape(featuremap)[2:] # h,w
216
- slice3d = tf.concat([shape2d, [-1]], axis=0)
217
- slice4d = tf.concat([shape2d, [-1, -1]], axis=0)
216
+ slice3d = tf.concat([shape2d, [-1]], axis=0) # pylint: disable=E1123
217
+ slice4d = tf.concat([shape2d, [-1, -1]], axis=0) # pylint: disable=E1123
218
218
  boxes = tf.slice(self.boxes, [0, 0, 0, 0], slice4d)
219
219
  gt_labels = tf.slice(self.gt_labels, [0, 0, 0], slice3d)
220
220
  gt_boxes = tf.slice(self.gt_boxes, [0, 0, 0, 0], slice4d)
@@ -151,9 +151,9 @@ def multilevel_roi_align(features, rcnn_boxes, resolution, fpn_anchor_strides):
151
151
  all_rois.append(roi_align(featuremap, boxes_on_featuremap, resolution))
152
152
 
153
153
  # this can fail if using TF<=1.8 with MKL build
154
- all_rois = tf.concat(all_rois, axis=0) # NCHW
154
+ all_rois = tf.concat(all_rois, axis=0) # NCHW # pylint: disable=E1123
155
155
  # Unshuffle to the original order, to match the original samples
156
- level_id_perm = tf.concat(level_ids, axis=0) # A permutation of 1~N
156
+ level_id_perm = tf.concat(level_ids, axis=0) # A permutation of 1~N # pylint: disable=E1123
157
157
  level_id_invert_perm = tf.math.invert_permutation(level_id_perm)
158
158
  all_rois = tf.gather(all_rois, level_id_invert_perm, name="output")
159
159
  return all_rois
@@ -258,8 +258,8 @@ def generate_fpn_proposals(
258
258
  all_boxes.append(proposal_boxes)
259
259
  all_scores.append(proposal_scores)
260
260
 
261
- proposal_boxes = tf.concat(all_boxes, axis=0) # nx4
262
- proposal_scores = tf.concat(all_scores, axis=0) # n
261
+ proposal_boxes = tf.concat(all_boxes, axis=0) # nx4 # pylint: disable=E1123
262
+ proposal_scores = tf.concat(all_scores, axis=0) # n # pylint: disable=E1123
263
263
  # Here we are different from Detectron.
264
264
  # Detectron picks top-k within the batch, rather than within an image, however we do not have a batch.
265
265
  proposal_topk = tf.minimum(tf.size(proposal_scores), fpn_nms_top_k)
@@ -271,8 +271,8 @@ def generate_fpn_proposals(
271
271
  pred_boxes_decoded = multilevel_pred_boxes[lvl]
272
272
  all_boxes.append(tf.reshape(pred_boxes_decoded, [-1, 4]))
273
273
  all_scores.append(tf.reshape(multilevel_label_logits[lvl], [-1]))
274
- all_boxes = tf.concat(all_boxes, axis=0)
275
- all_scores = tf.concat(all_scores, axis=0)
274
+ all_boxes = tf.concat(all_boxes, axis=0) # pylint: disable=E1123
275
+ all_scores = tf.concat(all_scores, axis=0) # pylint: disable=E1123
276
276
  proposal_boxes, proposal_scores = generate_rpn_proposals(
277
277
  all_boxes,
278
278
  all_scores,