deepdoctection-0.30-py3-none-any.whl → deepdoctection-0.32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection has been flagged as potentially problematic; see the registry's advisory page for more details.

Files changed (120)
  1. deepdoctection/__init__.py +38 -29
  2. deepdoctection/analyzer/dd.py +36 -29
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/base.py +0 -19
  5. deepdoctection/dataflow/custom.py +4 -3
  6. deepdoctection/dataflow/custom_serialize.py +14 -5
  7. deepdoctection/dataflow/parallel_map.py +12 -11
  8. deepdoctection/dataflow/serialize.py +5 -4
  9. deepdoctection/datapoint/annotation.py +35 -13
  10. deepdoctection/datapoint/box.py +3 -5
  11. deepdoctection/datapoint/convert.py +3 -1
  12. deepdoctection/datapoint/image.py +79 -36
  13. deepdoctection/datapoint/view.py +152 -49
  14. deepdoctection/datasets/__init__.py +1 -4
  15. deepdoctection/datasets/adapter.py +6 -3
  16. deepdoctection/datasets/base.py +86 -11
  17. deepdoctection/datasets/dataflow_builder.py +1 -1
  18. deepdoctection/datasets/info.py +4 -4
  19. deepdoctection/datasets/instances/doclaynet.py +3 -2
  20. deepdoctection/datasets/instances/fintabnet.py +2 -1
  21. deepdoctection/datasets/instances/funsd.py +2 -1
  22. deepdoctection/datasets/instances/iiitar13k.py +5 -2
  23. deepdoctection/datasets/instances/layouttest.py +4 -8
  24. deepdoctection/datasets/instances/publaynet.py +2 -2
  25. deepdoctection/datasets/instances/pubtables1m.py +6 -3
  26. deepdoctection/datasets/instances/pubtabnet.py +2 -1
  27. deepdoctection/datasets/instances/rvlcdip.py +2 -1
  28. deepdoctection/datasets/instances/xfund.py +2 -1
  29. deepdoctection/eval/__init__.py +1 -4
  30. deepdoctection/eval/accmetric.py +1 -1
  31. deepdoctection/eval/base.py +5 -4
  32. deepdoctection/eval/cocometric.py +2 -1
  33. deepdoctection/eval/eval.py +19 -15
  34. deepdoctection/eval/tedsmetric.py +14 -11
  35. deepdoctection/eval/tp_eval_callback.py +14 -7
  36. deepdoctection/extern/__init__.py +2 -7
  37. deepdoctection/extern/base.py +39 -13
  38. deepdoctection/extern/d2detect.py +182 -90
  39. deepdoctection/extern/deskew.py +36 -9
  40. deepdoctection/extern/doctrocr.py +265 -83
  41. deepdoctection/extern/fastlang.py +49 -9
  42. deepdoctection/extern/hfdetr.py +106 -55
  43. deepdoctection/extern/hflayoutlm.py +441 -122
  44. deepdoctection/extern/hflm.py +225 -0
  45. deepdoctection/extern/model.py +56 -47
  46. deepdoctection/extern/pdftext.py +10 -5
  47. deepdoctection/extern/pt/__init__.py +1 -3
  48. deepdoctection/extern/pt/nms.py +6 -2
  49. deepdoctection/extern/pt/ptutils.py +27 -18
  50. deepdoctection/extern/tessocr.py +134 -22
  51. deepdoctection/extern/texocr.py +6 -2
  52. deepdoctection/extern/tp/tfutils.py +43 -9
  53. deepdoctection/extern/tp/tpcompat.py +14 -11
  54. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  55. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  56. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
  58. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
  60. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  61. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
  62. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  67. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  68. deepdoctection/extern/tp/tpfrcnn/preproc.py +8 -9
  69. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  70. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  71. deepdoctection/extern/tpdetect.py +54 -30
  72. deepdoctection/mapper/__init__.py +3 -8
  73. deepdoctection/mapper/d2struct.py +9 -7
  74. deepdoctection/mapper/hfstruct.py +7 -2
  75. deepdoctection/mapper/laylmstruct.py +164 -21
  76. deepdoctection/mapper/maputils.py +16 -3
  77. deepdoctection/mapper/misc.py +6 -3
  78. deepdoctection/mapper/prodigystruct.py +1 -1
  79. deepdoctection/mapper/pubstruct.py +10 -10
  80. deepdoctection/mapper/tpstruct.py +3 -3
  81. deepdoctection/pipe/__init__.py +1 -1
  82. deepdoctection/pipe/anngen.py +35 -8
  83. deepdoctection/pipe/base.py +53 -19
  84. deepdoctection/pipe/common.py +23 -13
  85. deepdoctection/pipe/concurrency.py +2 -1
  86. deepdoctection/pipe/doctectionpipe.py +2 -2
  87. deepdoctection/pipe/language.py +3 -2
  88. deepdoctection/pipe/layout.py +6 -3
  89. deepdoctection/pipe/lm.py +34 -66
  90. deepdoctection/pipe/order.py +142 -35
  91. deepdoctection/pipe/refine.py +26 -24
  92. deepdoctection/pipe/segment.py +21 -16
  93. deepdoctection/pipe/{cell.py → sub_layout.py} +30 -9
  94. deepdoctection/pipe/text.py +14 -8
  95. deepdoctection/pipe/transform.py +16 -9
  96. deepdoctection/train/__init__.py +6 -12
  97. deepdoctection/train/d2_frcnn_train.py +36 -28
  98. deepdoctection/train/hf_detr_train.py +26 -17
  99. deepdoctection/train/hf_layoutlm_train.py +133 -111
  100. deepdoctection/train/tp_frcnn_train.py +21 -19
  101. deepdoctection/utils/__init__.py +3 -0
  102. deepdoctection/utils/concurrency.py +1 -1
  103. deepdoctection/utils/context.py +2 -2
  104. deepdoctection/utils/env_info.py +41 -84
  105. deepdoctection/utils/error.py +84 -0
  106. deepdoctection/utils/file_utils.py +4 -15
  107. deepdoctection/utils/fs.py +7 -7
  108. deepdoctection/utils/logger.py +1 -0
  109. deepdoctection/utils/mocks.py +93 -0
  110. deepdoctection/utils/pdf_utils.py +5 -4
  111. deepdoctection/utils/settings.py +6 -1
  112. deepdoctection/utils/transform.py +1 -1
  113. deepdoctection/utils/utils.py +0 -6
  114. deepdoctection/utils/viz.py +48 -5
  115. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/METADATA +57 -73
  116. deepdoctection-0.32.dist-info/RECORD +146 -0
  117. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/WHEEL +1 -1
  118. deepdoctection-0.30.dist-info/RECORD +0 -143
  119. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
  120. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
@@ -19,21 +19,24 @@
19
19
  Tesseract OCR engine for text extraction
20
20
  """
21
21
  import shlex
22
+ import string
22
23
  import subprocess
23
24
  import sys
24
25
  from errno import ENOENT
25
26
  from itertools import groupby
26
27
  from os import environ
27
- from typing import Any, Dict, List, Optional, Union
28
+ from typing import Any, Dict, List, Mapping, Optional, Union
28
29
 
29
- import numpy as np
30
+ from packaging.version import InvalidVersion, Version, parse
30
31
 
31
32
  from ..utils.context import save_tmp_file, timeout_manager
32
33
  from ..utils.detection_types import ImageType, Requirement
33
- from ..utils.file_utils import _TESS_PATH, TesseractNotFound, get_tesseract_requirement
34
+ from ..utils.error import DependencyError, TesseractError
35
+ from ..utils.file_utils import _TESS_PATH, get_tesseract_requirement
34
36
  from ..utils.metacfg import config_to_cli_str, set_config_by_yaml
35
- from ..utils.settings import LayoutType, ObjectTypes
36
- from .base import DetectionResult, ObjectDetector, PredictorBase
37
+ from ..utils.settings import LayoutType, ObjectTypes, PageType
38
+ from ..utils.viz import viz_handler
39
+ from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase
37
40
 
38
41
  # copy and paste with some light modifications from https://github.com/madmaze/pytesseract/tree/master/pytesseract
39
42
 
@@ -57,18 +60,6 @@ _LANG_CODE_TO_TESS_LANG_CODE = {
57
60
  }
58
61
 
59
62
 
60
- class TesseractError(RuntimeError):
61
- """
62
- Tesseract Error
63
- """
64
-
65
- def __init__(self, status: int, message: str) -> None:
66
- super().__init__()
67
- self.status = status
68
- self.message = message
69
- self.args = (status, message)
70
-
71
-
72
63
  def _subprocess_args() -> Dict[str, Any]:
73
64
  # See https://github.com/pyinstaller/pyinstaller/wiki/Recipe-subprocess
74
65
  # for reference and comments.
@@ -109,7 +100,7 @@ def _run_tesseract(tesseract_args: List[str]) -> None:
109
100
  except OSError as error:
110
101
  if error.errno != ENOENT:
111
102
  raise error from error
112
- raise TesseractNotFound("Tesseract not found. Please install or add to your PATH.") from error
103
+ raise DependencyError("Tesseract not found. Please install or add to your PATH.") from error
113
104
 
114
105
  with timeout_manager(proc, 0) as error_string:
115
106
  if proc.returncode:
@@ -119,6 +110,50 @@ def _run_tesseract(tesseract_args: List[str]) -> None:
119
110
  )
120
111
 
121
112
 
113
+ def get_tesseract_version() -> Version:
114
+ """
115
+ Returns Version object of the Tesseract version
116
+ """
117
+ try:
118
+ output = subprocess.check_output(
119
+ ["tesseract", "--version"],
120
+ stderr=subprocess.STDOUT,
121
+ env=environ,
122
+ stdin=subprocess.DEVNULL,
123
+ )
124
+ except OSError as error:
125
+ raise DependencyError("Tesseract not found. Please install or add to your PATH.") from error
126
+
127
+ raw_version = output.decode("utf-8")
128
+ str_version, *_ = raw_version.lstrip(string.printable[10:]).partition(" ")
129
+ str_version, *_ = str_version.partition("-")
130
+
131
+ try:
132
+ version = parse(str_version)
133
+ assert version >= Version("3.05")
134
+ except (AssertionError, InvalidVersion) as error:
135
+ raise SystemExit(f'Invalid tesseract version: "{raw_version}"') from error
136
+
137
+ return version
138
+
139
+
140
+ def image_to_angle(image: ImageType) -> Mapping[str, str]:
141
+ """
142
+ Generating a tmp file and running tesseract to get the orientation of the image.
143
+
144
+ :param image: Image in np.array.
145
+ :return: A dictionary with keys 'Orientation in degrees' and 'Orientation confidence'.
146
+ """
147
+ with save_tmp_file(image, "tess_") as (tmp_name, input_file_name):
148
+ _run_tesseract(_input_to_cli_str("osd", "--psm 0", 0, input_file_name, tmp_name))
149
+ with open(tmp_name + ".osd", "rb") as output_file:
150
+ output = output_file.read().decode("utf-8")
151
+
152
+ return {
153
+ key_value[0]: key_value[1] for key_value in (line.split(": ") for line in output.split("\n") if len(line) >= 2)
154
+ }
155
+
156
+
122
157
  def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Union[str, int, float]]]:
123
158
  """
124
159
  This is more or less pytesseract.image_to_data with a dict as returned value.
@@ -220,7 +255,6 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
220
255
  :return: A list of tesseract extractions wrapped in DetectionResult
221
256
  """
222
257
 
223
- np_img = np_img.astype(np.uint8)
224
258
  results = image_to_dict(np_img, supported_languages, config)
225
259
  all_results = []
226
260
 
@@ -249,6 +283,16 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
249
283
  return all_results
250
284
 
251
285
 
286
+ def predict_rotation(np_img: ImageType) -> Mapping[str, str]:
287
+ """
288
+ Predicts the rotation of an image using the Tesseract OCR engine.
289
+
290
+ :param np_img: numpy array of the image
291
+ :return: A dictionary with keys 'Orientation in degrees' and 'Orientation confidence'
292
+ """
293
+ return image_to_angle(np_img)
294
+
295
+
252
296
  class TesseractOcrDetector(ObjectDetector):
253
297
  """
254
298
  Text object detector based on Tesseracts OCR engine. Note that tesseract has to be installed separately.
@@ -292,7 +336,9 @@ class TesseractOcrDetector(ObjectDetector):
292
336
  :param config_overwrite: Overwrite config parameters defined by the yaml file with new values.
293
337
  E.g. ["oem=14"]
294
338
  """
295
- self.name = _TESS_PATH
339
+ self.name = self.get_name()
340
+ self.model_id = self.get_model_id()
341
+
296
342
  if config_overwrite is None:
297
343
  config_overwrite = []
298
344
 
@@ -316,13 +362,13 @@ class TesseractOcrDetector(ObjectDetector):
316
362
  :param np_img: image as numpy array
317
363
  :return: A list of DetectionResult
318
364
  """
319
- detection_results = predict_text(
365
+
366
+ return predict_text(
320
367
  np_img,
321
368
  supported_languages=self.config.LANGUAGES,
322
369
  text_lines=self.config.LINES,
323
370
  config=config_to_cli_str(self.config, "LANGUAGES", "LINES"),
324
371
  )
325
- return detection_results
326
372
 
327
373
  @classmethod
328
374
  def get_requirements(cls) -> List[Requirement]:
@@ -342,3 +388,69 @@ class TesseractOcrDetector(ObjectDetector):
342
388
  :param language: `Languages`
343
389
  """
344
390
  self.config.LANGUAGES = _LANG_CODE_TO_TESS_LANG_CODE.get(language, language.value)
391
+
392
+ @staticmethod
393
+ def get_name() -> str:
394
+ """Returns the name of the model"""
395
+ return f"Tesseract_{get_tesseract_version()}"
396
+
397
+
398
+ class TesseractRotationTransformer(ImageTransformer):
399
+ """
400
+ The `TesseractRotationTransformer` class is a specialized image transformer that is designed to handle image
401
+ rotation in the context of Optical Character Recognition (OCR) tasks. It inherits from the `ImageTransformer`
402
+ base class and implements methods for predicting and applying rotation transformations to images.
403
+
404
+ The `predict` method determines the angle of the rotated image. It can only handle angles that are multiples of 90
405
+ degrees.
406
+ This method uses the Tesseract OCR engine to predict the rotation angle of an image.
407
+
408
+ The `transform` method applies the predicted rotation to the image, effectively rotating the image backwards.
409
+ This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
410
+
411
+ This class can be particularly useful in OCR tasks where the orientation of the text in the image matters.
412
+ The class also provides methods for cloning itself and for getting the requirements of the Tesseract OCR system.
413
+
414
+ **Example:**
415
+ transformer = TesseractRotationTransformer()
416
+ detection_result = transformer.predict(np_img)
417
+ rotated_image = transformer.transform(np_img, detection_result)
418
+ """
419
+
420
+ def __init__(self) -> None:
421
+ self.name = _TESS_PATH + "-rotation"
422
+
423
+ def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
424
+ """
425
+ Applies the predicted rotation to the image, effectively rotating the image backwards.
426
+ This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
427
+
428
+ :param np_img: The input image as a numpy array.
429
+ :param specification: A `DetectionResult` object containing the predicted rotation angle.
430
+ :return: The rotated image as a numpy array.
431
+ """
432
+ return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
433
+
434
+ def predict(self, np_img: ImageType) -> DetectionResult:
435
+ """
436
+ Determines the angle of the rotated image. It can only handle angles that are multiples of 90 degrees.
437
+ This method uses the Tesseract OCR engine to predict the rotation angle of an image.
438
+
439
+ :param np_img: The input image as a numpy array.
440
+ :return: A `DetectionResult` object containing the predicted rotation angle and confidence.
441
+ """
442
+ output_dict = predict_rotation(np_img)
443
+ return DetectionResult(
444
+ angle=float(output_dict["Orientation in degrees"]), score=float(output_dict["Orientation confidence"])
445
+ )
446
+
447
+ @classmethod
448
+ def get_requirements(cls) -> List[Requirement]:
449
+ return [get_tesseract_requirement()]
450
+
451
+ def clone(self) -> PredictorBase:
452
+ return self.__class__()
453
+
454
+ @staticmethod
455
+ def possible_category() -> PageType:
456
+ return PageType.angle
@@ -23,14 +23,16 @@ import sys
23
23
  import traceback
24
24
  from typing import List
25
25
 
26
+ from lazy_imports import try_import
27
+
26
28
  from ..datapoint.convert import convert_np_array_to_b64_b
27
29
  from ..utils.detection_types import ImageType, JsonDict, Requirement
28
- from ..utils.file_utils import boto3_available, get_boto3_requirement
30
+ from ..utils.file_utils import get_boto3_requirement
29
31
  from ..utils.logger import LoggingRecord, logger
30
32
  from ..utils.settings import LayoutType, ObjectTypes
31
33
  from .base import DetectionResult, ObjectDetector, PredictorBase
32
34
 
33
- if boto3_available():
35
+ with try_import() as import_guard:
34
36
  import boto3 # type:ignore
35
37
 
36
38
 
@@ -120,6 +122,8 @@ class TextractOcrDetector(ObjectDetector):
120
122
  :param credentials_kwargs: `aws_access_key_id`, `aws_secret_access_key` or `aws_session_token`
121
123
  """
122
124
  self.name = "textract"
125
+ self.model_id = self.get_model_id()
126
+
123
127
  self.text_lines = text_lines
124
128
  self.client = boto3.client("textract", **credentials_kwargs)
125
129
  if self.text_lines:
@@ -19,7 +19,18 @@
19
19
  Tensorflow related utils.
20
20
  """
21
21
 
22
- from tensorpack.models import disable_layer_logging # pylint: disable=E0401
22
+ from __future__ import annotations
23
+
24
+ import os
25
+ from typing import Optional, Union, ContextManager
26
+
27
+ from lazy_imports import try_import
28
+
29
+ with try_import() as import_guard:
30
+ from tensorpack.models import disable_layer_logging # pylint: disable=E0401
31
+
32
+ with try_import() as tf_import_guard:
33
+ import tensorflow as tf # pylint: disable=E0401
23
34
 
24
35
 
25
36
  def is_tfv2() -> bool:
@@ -38,16 +49,13 @@ def disable_tfv2() -> bool:
38
49
  """
39
50
  Disable TF in V2 mode.
40
51
  """
41
- try:
42
- import tensorflow as tf # pylint: disable=C0415
43
52
 
44
- tfv1 = tf.compat.v1
45
- if is_tfv2():
46
- tfv1.disable_v2_behavior()
47
- tfv1.disable_eager_execution()
53
+ tfv1 = tf.compat.v1
54
+ if is_tfv2():
55
+ tfv1.disable_v2_behavior()
56
+ tfv1.disable_eager_execution()
48
57
  return True
49
- except ModuleNotFoundError:
50
- return False
58
+ return False
51
59
 
52
60
 
53
61
  def disable_tp_layer_logging() -> None:
@@ -55,3 +63,29 @@ def disable_tp_layer_logging() -> None:
55
63
  Disables TP layer logging, if not already set
56
64
  """
57
65
  disable_layer_logging()
66
+
67
+
68
+ def get_tf_device(device: Optional[Union[str, tf.device]] = None) -> tf.device:
69
+ """
70
+ Selecting a device on which to load a model. The selection follows a cascade of priorities:
71
+
72
+ - If a device string is provided, it is used. If the string is "cuda" or "GPU", the first GPU is used.
73
+ - If the environment variable "USE_CUDA" is set, a GPU is used. If more GPUs are available it will use the first one
74
+
75
+ :param device: Device string
76
+ :return: Tensorflow device
77
+ """
78
+ if device is not None:
79
+ if isinstance(device, ContextManager):
80
+ return device
81
+ if isinstance(device, str):
82
+ if device in ("cuda", "GPU"):
83
+ device_names = [device.name for device in tf.config.list_logical_devices(device_type="GPU")]
84
+ return tf.device(device_names[0].name)
85
+ # The input must be something sensible
86
+ return tf.device(device)
87
+ if os.environ.get("USE_CUDA"):
88
+ device_names = [device.name for device in tf.config.list_logical_devices(device_type="GPU")]
89
+ return tf.device(device_names[0])
90
+ device_names = [device.name for device in tf.config.list_logical_devices(device_type="CPU")]
91
+ return tf.device(device_names[0])
@@ -18,21 +18,24 @@
18
18
  """
19
19
  Compatibility classes and methods related to Tensorpack package
20
20
  """
21
+ from __future__ import annotations
21
22
 
22
23
  from abc import ABC, abstractmethod
23
24
  from typing import Any, List, Mapping, Tuple, Union
24
25
 
25
- from tensorpack.predict import OfflinePredictor, PredictConfig # pylint: disable=E0401
26
- from tensorpack.tfutils import SmartInit # pylint: disable=E0401
27
-
28
- # pylint: disable=import-error
29
- from tensorpack.train.model_desc import ModelDesc
30
- from tensorpack.utils.gpu import get_num_gpu
26
+ from lazy_imports import try_import
31
27
 
32
28
  from ...utils.metacfg import AttrDict
33
29
  from ...utils.settings import ObjectTypes
34
30
 
35
- # pylint: enable=import-error
31
+ with try_import() as import_guard:
32
+ from tensorpack.predict import OfflinePredictor, PredictConfig # pylint: disable=E0401
33
+ from tensorpack.tfutils import SmartInit # pylint: disable=E0401
34
+ from tensorpack.train.model_desc import ModelDesc # pylint: disable=E0401
35
+ from tensorpack.utils.gpu import get_num_gpu # pylint: disable=E0401
36
+
37
+ if not import_guard.is_successful():
38
+ from ...utils.mocks import ModelDesc
36
39
 
37
40
 
38
41
  class ModelDescWithConfig(ModelDesc, ABC): # type: ignore
@@ -55,7 +58,7 @@ class ModelDescWithConfig(ModelDesc, ABC): # type: ignore
55
58
 
56
59
  :return: Tuple of list input and list output names. The names must coincide with tensor within the model.
57
60
  """
58
- raise NotImplementedError
61
+ raise NotImplementedError()
59
62
 
60
63
 
61
64
  class TensorpackPredictor(ABC):
@@ -106,14 +109,14 @@ class TensorpackPredictor(ABC):
106
109
 
107
110
  @staticmethod
108
111
  @abstractmethod
109
- def set_model(
112
+ def get_wrapped_model(
110
113
  path_yaml: str, categories: Mapping[str, ObjectTypes], config_overwrite: Union[List[str], None]
111
114
  ) -> ModelDescWithConfig:
112
115
  """
113
116
  Implement the config generation, its modification and instantiate a version of the model. See
114
117
  `pipe.tpfrcnn.TPFrcnnDetector` for an example
115
118
  """
116
- raise NotImplementedError
119
+ raise NotImplementedError()
117
120
 
118
121
  @abstractmethod
119
122
  def predict(self, np_img: Any) -> Any:
@@ -121,7 +124,7 @@ class TensorpackPredictor(ABC):
121
124
  Implement, how `self.tp_predictor` is invoked and raw prediction results are generated. Do use only raw
122
125
  objects and nothing, which is related to the DD API.
123
126
  """
124
- raise NotImplementedError
127
+ raise NotImplementedError()
125
128
 
126
129
  @property
127
130
  def model(self) -> ModelDescWithConfig:
@@ -0,0 +1,20 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File: __init__.py
3
+
4
+ # Copyright 2021 Dr. Janis Meyer. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Init file for code for Tensorpack FRCNN example
20
+ """
@@ -11,13 +11,17 @@ This file is modified from
11
11
 
12
12
 
13
13
  import numpy as np
14
- from tensorpack.dataflow.imgaug import ImageAugmentor, ResizeTransform # pylint: disable=E0401
14
+ from lazy_imports import try_import
15
15
 
16
- from ....utils.file_utils import cocotools_available
16
+ with try_import() as import_guard:
17
+ from tensorpack.dataflow.imgaug import ImageAugmentor, ResizeTransform # pylint: disable=E0401
17
18
 
18
- if cocotools_available():
19
+ with try_import() as cc_import_guard:
19
20
  import pycocotools.mask as coco_mask
20
21
 
22
+ if not import_guard.is_successful():
23
+ from ....utils.mocks import ImageAugmentor
24
+
21
25
 
22
26
  class CustomResize(ImageAugmentor):
23
27
  """
@@ -0,0 +1,20 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File: __init__.py
3
+
4
+ # Copyright 2021 Dr. Janis Meyer. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Init file for code for Tensorpack's FRCNN configs
20
+ """
@@ -191,16 +191,19 @@ import os
191
191
  from typing import List, Mapping, Tuple
192
192
 
193
193
  import numpy as np
194
- from tensorpack.tfutils import collect_env_info # pylint: disable=E0401
195
- from tensorpack.utils import logger # pylint: disable=E0401
196
-
197
- # pylint: disable=import-error
198
- from tensorpack.utils.gpu import get_num_gpu
194
+ from lazy_imports import try_import
199
195
 
200
196
  from .....utils.metacfg import AttrDict
201
197
  from .....utils.settings import ObjectTypes
202
198
 
203
- # pylint: enable=import-error
199
+ with try_import() as import_guard:
200
+ from tensorpack.tfutils import collect_env_info # pylint: disable=E0401
201
+ from tensorpack.utils import logger # pylint: disable=E0401
202
+
203
+ # pylint: disable=import-error
204
+ from tensorpack.utils.gpu import get_num_gpu
205
+
206
+ # pylint: enable=import-error
204
207
 
205
208
 
206
209
  __all__ = ["train_frcnn_config", "model_frcnn_config"]
@@ -0,0 +1,20 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File: __init__.py
3
+
4
+ # Copyright 2021 Dr. Janis Meyer. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Init file for code for Tensorpack's FRCNN configs
20
+ """
@@ -12,22 +12,30 @@ This file is modified from
12
12
  from contextlib import ExitStack, contextmanager
13
13
 
14
14
  import numpy as np
15
+ from lazy_imports import try_import
15
16
 
16
17
  # pylint: disable=import-error
17
- import tensorflow as tf
18
- from tensorpack import tfv1
19
- from tensorpack.models import BatchNorm, Conv2D, MaxPooling, layer_register
20
- from tensorpack.tfutils import argscope
21
- from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables
18
+
19
+ with try_import() as import_guard:
20
+ import tensorflow as tf
21
+ from tensorpack import tfv1
22
+ from tensorpack.models import BatchNorm, Conv2D, MaxPooling, layer_register
23
+ from tensorpack.tfutils import argscope
24
+ from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables
22
25
 
23
26
  # pylint: enable=import-error
24
27
 
28
+ if not import_guard.is_successful():
29
+ from .....utils.mocks import layer_register
30
+
25
31
 
26
32
  @layer_register(log_shape=True)
27
- def GroupNorm(x, group=32, gamma_initializer=tf.constant_initializer(1.0)):
33
+ def GroupNorm(x, group=32, gamma_initializer=None):
28
34
  """
29
35
  More code that reproduces the paper can be found at <https://github.com/ppwwyyxx/GroupNorm-reproduce/>.
30
36
  """
37
+ if gamma_initializer is None:
38
+ gamma_initializer = tf.constant_initializer(1.0)
31
39
  shape = x.get_shape().as_list()
32
40
  ndims = len(shape)
33
41
  assert ndims == 4, shape
@@ -153,7 +161,7 @@ def get_norm(cfg, zero_init=False):
153
161
  return lambda x: norm(layer_name, x, gamma_initializer=tf.zeros_initializer() if zero_init else None)
154
162
 
155
163
 
156
- def resnet_shortcut(l, n_out, stride, activation=tf.identity):
164
+ def resnet_shortcut(l, n_out, stride, activation=None):
157
165
  """
158
166
  Defining the skip connection in bottleneck
159
167
 
@@ -163,6 +171,8 @@ def resnet_shortcut(l, n_out, stride, activation=tf.identity):
163
171
  :param activation: An activation function
164
172
  :return: tf.Tensor
165
173
  """
174
+ if activation is None:
175
+ activation = tf.identity
166
176
  n_in = l.shape[1]
167
177
  if n_in != n_out: # change dimension when channel is not the same
168
178
  return Conv2D("convshortcut", l, n_out, 1, strides=stride, activation=activation) # pylint: disable=E1124
@@ -9,12 +9,8 @@ This file is modified from
9
9
  <https://github.com/tensorpack/tensorpack/blob/master/examples/FasterRCNN/modeling/generalized_rcnn.py>
10
10
  """
11
11
 
12
- # pylint: disable=import-error
13
- import tensorflow as tf
14
- from tensorpack import tfv1
15
- from tensorpack.models import l2_regularizer, regularize_cost
16
- from tensorpack.tfutils import optimizer
17
- from tensorpack.tfutils.summary import add_moving_summary
12
+
13
+ from lazy_imports import try_import
18
14
 
19
15
  from ...tpcompat import ModelDescWithConfig
20
16
  from ..utils.box_ops import area as tf_area
@@ -40,6 +36,16 @@ from .model_frcnn import (
40
36
  from .model_mrcnn import maskrcnn_loss, unpackbits_masks
41
37
  from .model_rpn import rpn_head
42
38
 
39
+ with try_import() as import_guard:
40
+ # pylint: disable=import-error
41
+ import tensorflow as tf
42
+ from tensorpack import tfv1
43
+ from tensorpack.models import l2_regularizer, regularize_cost
44
+ from tensorpack.tfutils import optimizer
45
+ from tensorpack.tfutils.summary import add_moving_summary
46
+
47
+ # pylint: enable=import-error
48
+
43
49
 
44
50
  class GeneralizedRCNN(ModelDescWithConfig):
45
51
  """
@@ -11,12 +11,17 @@ This file is modified from
11
11
  from collections import namedtuple
12
12
 
13
13
  import numpy as np
14
+ from lazy_imports import try_import
14
15
 
15
- # pylint: disable=import-error
16
- import tensorflow as tf
17
- from tensorpack.tfutils.scope_utils import under_name_scope
16
+ with try_import() as import_guard:
17
+ # pylint: disable=import-error
18
+ import tensorflow as tf
19
+ from tensorpack.tfutils.scope_utils import under_name_scope
18
20
 
19
- # pylint: enable=import-error
21
+ # pylint: enable=import-error
22
+
23
+ if not import_guard.is_successful():
24
+ from .....utils.mocks import under_name_scope
20
25
 
21
26
 
22
27
  @under_name_scope()
@@ -9,17 +9,20 @@ This file is modified from
9
9
  <https://github.com/tensorpack/tensorpack/blob/master/examples/FasterRCNN/modeling/model_cascade.py>
10
10
  """
11
11
 
12
- # pylint: disable=import-error
13
- import tensorflow as tf
14
- from tensorpack import tfv1
15
- from tensorpack.tfutils import get_current_tower_context
12
+ from lazy_imports import try_import
16
13
 
17
14
  from ..utils.box_ops import area as tf_area
18
15
  from ..utils.box_ops import pairwise_iou
19
16
  from .model_box import clip_boxes
20
17
  from .model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
21
18
 
22
- # pylint: enable=import-error
19
+ with try_import() as import_guard:
20
+ # pylint: disable=import-error
21
+ import tensorflow as tf
22
+ from tensorpack import tfv1
23
+ from tensorpack.tfutils import get_current_tower_context
24
+
25
+ # pylint: enable=import-error
23
26
 
24
27
 
25
28
  class CascadeRCNNHead: