deepdoctection 0.30__py3-none-any.whl → 0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. Click here for more details.
- deepdoctection/__init__.py +38 -29
- deepdoctection/analyzer/dd.py +36 -29
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/base.py +0 -19
- deepdoctection/dataflow/custom.py +4 -3
- deepdoctection/dataflow/custom_serialize.py +14 -5
- deepdoctection/dataflow/parallel_map.py +12 -11
- deepdoctection/dataflow/serialize.py +5 -4
- deepdoctection/datapoint/annotation.py +35 -13
- deepdoctection/datapoint/box.py +3 -5
- deepdoctection/datapoint/convert.py +3 -1
- deepdoctection/datapoint/image.py +79 -36
- deepdoctection/datapoint/view.py +152 -49
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +6 -3
- deepdoctection/datasets/base.py +86 -11
- deepdoctection/datasets/dataflow_builder.py +1 -1
- deepdoctection/datasets/info.py +4 -4
- deepdoctection/datasets/instances/doclaynet.py +3 -2
- deepdoctection/datasets/instances/fintabnet.py +2 -1
- deepdoctection/datasets/instances/funsd.py +2 -1
- deepdoctection/datasets/instances/iiitar13k.py +5 -2
- deepdoctection/datasets/instances/layouttest.py +4 -8
- deepdoctection/datasets/instances/publaynet.py +2 -2
- deepdoctection/datasets/instances/pubtables1m.py +6 -3
- deepdoctection/datasets/instances/pubtabnet.py +2 -1
- deepdoctection/datasets/instances/rvlcdip.py +2 -1
- deepdoctection/datasets/instances/xfund.py +2 -1
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +1 -1
- deepdoctection/eval/base.py +5 -4
- deepdoctection/eval/cocometric.py +2 -1
- deepdoctection/eval/eval.py +19 -15
- deepdoctection/eval/tedsmetric.py +14 -11
- deepdoctection/eval/tp_eval_callback.py +14 -7
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +39 -13
- deepdoctection/extern/d2detect.py +182 -90
- deepdoctection/extern/deskew.py +36 -9
- deepdoctection/extern/doctrocr.py +265 -83
- deepdoctection/extern/fastlang.py +49 -9
- deepdoctection/extern/hfdetr.py +106 -55
- deepdoctection/extern/hflayoutlm.py +441 -122
- deepdoctection/extern/hflm.py +225 -0
- deepdoctection/extern/model.py +56 -47
- deepdoctection/extern/pdftext.py +10 -5
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +27 -18
- deepdoctection/extern/tessocr.py +134 -22
- deepdoctection/extern/texocr.py +6 -2
- deepdoctection/extern/tp/tfutils.py +43 -9
- deepdoctection/extern/tp/tpcompat.py +14 -11
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +8 -9
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +54 -30
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/d2struct.py +9 -7
- deepdoctection/mapper/hfstruct.py +7 -2
- deepdoctection/mapper/laylmstruct.py +164 -21
- deepdoctection/mapper/maputils.py +16 -3
- deepdoctection/mapper/misc.py +6 -3
- deepdoctection/mapper/prodigystruct.py +1 -1
- deepdoctection/mapper/pubstruct.py +10 -10
- deepdoctection/mapper/tpstruct.py +3 -3
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +35 -8
- deepdoctection/pipe/base.py +53 -19
- deepdoctection/pipe/common.py +23 -13
- deepdoctection/pipe/concurrency.py +2 -1
- deepdoctection/pipe/doctectionpipe.py +2 -2
- deepdoctection/pipe/language.py +3 -2
- deepdoctection/pipe/layout.py +6 -3
- deepdoctection/pipe/lm.py +34 -66
- deepdoctection/pipe/order.py +142 -35
- deepdoctection/pipe/refine.py +26 -24
- deepdoctection/pipe/segment.py +21 -16
- deepdoctection/pipe/{cell.py → sub_layout.py} +30 -9
- deepdoctection/pipe/text.py +14 -8
- deepdoctection/pipe/transform.py +16 -9
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +36 -28
- deepdoctection/train/hf_detr_train.py +26 -17
- deepdoctection/train/hf_layoutlm_train.py +133 -111
- deepdoctection/train/tp_frcnn_train.py +21 -19
- deepdoctection/utils/__init__.py +3 -0
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +2 -2
- deepdoctection/utils/env_info.py +41 -84
- deepdoctection/utils/error.py +84 -0
- deepdoctection/utils/file_utils.py +4 -15
- deepdoctection/utils/fs.py +7 -7
- deepdoctection/utils/logger.py +1 -0
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +5 -4
- deepdoctection/utils/settings.py +6 -1
- deepdoctection/utils/transform.py +1 -1
- deepdoctection/utils/utils.py +0 -6
- deepdoctection/utils/viz.py +48 -5
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/METADATA +57 -73
- deepdoctection-0.32.dist-info/RECORD +146 -0
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/WHEEL +1 -1
- deepdoctection-0.30.dist-info/RECORD +0 -143
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
deepdoctection/extern/tessocr.py
CHANGED
|
@@ -19,21 +19,24 @@
|
|
|
19
19
|
Tesseract OCR engine for text extraction
|
|
20
20
|
"""
|
|
21
21
|
import shlex
|
|
22
|
+
import string
|
|
22
23
|
import subprocess
|
|
23
24
|
import sys
|
|
24
25
|
from errno import ENOENT
|
|
25
26
|
from itertools import groupby
|
|
26
27
|
from os import environ
|
|
27
|
-
from typing import Any, Dict, List, Optional, Union
|
|
28
|
+
from typing import Any, Dict, List, Mapping, Optional, Union
|
|
28
29
|
|
|
29
|
-
import
|
|
30
|
+
from packaging.version import InvalidVersion, Version, parse
|
|
30
31
|
|
|
31
32
|
from ..utils.context import save_tmp_file, timeout_manager
|
|
32
33
|
from ..utils.detection_types import ImageType, Requirement
|
|
33
|
-
from ..utils.
|
|
34
|
+
from ..utils.error import DependencyError, TesseractError
|
|
35
|
+
from ..utils.file_utils import _TESS_PATH, get_tesseract_requirement
|
|
34
36
|
from ..utils.metacfg import config_to_cli_str, set_config_by_yaml
|
|
35
|
-
from ..utils.settings import LayoutType, ObjectTypes
|
|
36
|
-
from .
|
|
37
|
+
from ..utils.settings import LayoutType, ObjectTypes, PageType
|
|
38
|
+
from ..utils.viz import viz_handler
|
|
39
|
+
from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase
|
|
37
40
|
|
|
38
41
|
# copy and paste with some light modifications from https://github.com/madmaze/pytesseract/tree/master/pytesseract
|
|
39
42
|
|
|
@@ -57,18 +60,6 @@ _LANG_CODE_TO_TESS_LANG_CODE = {
|
|
|
57
60
|
}
|
|
58
61
|
|
|
59
62
|
|
|
60
|
-
class TesseractError(RuntimeError):
|
|
61
|
-
"""
|
|
62
|
-
Tesseract Error
|
|
63
|
-
"""
|
|
64
|
-
|
|
65
|
-
def __init__(self, status: int, message: str) -> None:
|
|
66
|
-
super().__init__()
|
|
67
|
-
self.status = status
|
|
68
|
-
self.message = message
|
|
69
|
-
self.args = (status, message)
|
|
70
|
-
|
|
71
|
-
|
|
72
63
|
def _subprocess_args() -> Dict[str, Any]:
|
|
73
64
|
# See https://github.com/pyinstaller/pyinstaller/wiki/Recipe-subprocess
|
|
74
65
|
# for reference and comments.
|
|
@@ -109,7 +100,7 @@ def _run_tesseract(tesseract_args: List[str]) -> None:
|
|
|
109
100
|
except OSError as error:
|
|
110
101
|
if error.errno != ENOENT:
|
|
111
102
|
raise error from error
|
|
112
|
-
raise
|
|
103
|
+
raise DependencyError("Tesseract not found. Please install or add to your PATH.") from error
|
|
113
104
|
|
|
114
105
|
with timeout_manager(proc, 0) as error_string:
|
|
115
106
|
if proc.returncode:
|
|
@@ -119,6 +110,50 @@ def _run_tesseract(tesseract_args: List[str]) -> None:
|
|
|
119
110
|
)
|
|
120
111
|
|
|
121
112
|
|
|
113
|
+
def get_tesseract_version() -> Version:
|
|
114
|
+
"""
|
|
115
|
+
Returns Version object of the Tesseract version
|
|
116
|
+
"""
|
|
117
|
+
try:
|
|
118
|
+
output = subprocess.check_output(
|
|
119
|
+
["tesseract", "--version"],
|
|
120
|
+
stderr=subprocess.STDOUT,
|
|
121
|
+
env=environ,
|
|
122
|
+
stdin=subprocess.DEVNULL,
|
|
123
|
+
)
|
|
124
|
+
except OSError as error:
|
|
125
|
+
raise DependencyError("Tesseract not found. Please install or add to your PATH.") from error
|
|
126
|
+
|
|
127
|
+
raw_version = output.decode("utf-8")
|
|
128
|
+
str_version, *_ = raw_version.lstrip(string.printable[10:]).partition(" ")
|
|
129
|
+
str_version, *_ = str_version.partition("-")
|
|
130
|
+
|
|
131
|
+
try:
|
|
132
|
+
version = parse(str_version)
|
|
133
|
+
assert version >= Version("3.05")
|
|
134
|
+
except (AssertionError, InvalidVersion) as error:
|
|
135
|
+
raise SystemExit(f'Invalid tesseract version: "{raw_version}"') from error
|
|
136
|
+
|
|
137
|
+
return version
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def image_to_angle(image: ImageType) -> Mapping[str, str]:
|
|
141
|
+
"""
|
|
142
|
+
Generating a tmp file and running tesseract to get the orientation of the image.
|
|
143
|
+
|
|
144
|
+
:param image: Image in np.array.
|
|
145
|
+
:return: A dictionary with keys 'Orientation in degrees' and 'Orientation confidence'.
|
|
146
|
+
"""
|
|
147
|
+
with save_tmp_file(image, "tess_") as (tmp_name, input_file_name):
|
|
148
|
+
_run_tesseract(_input_to_cli_str("osd", "--psm 0", 0, input_file_name, tmp_name))
|
|
149
|
+
with open(tmp_name + ".osd", "rb") as output_file:
|
|
150
|
+
output = output_file.read().decode("utf-8")
|
|
151
|
+
|
|
152
|
+
return {
|
|
153
|
+
key_value[0]: key_value[1] for key_value in (line.split(": ") for line in output.split("\n") if len(line) >= 2)
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
|
|
122
157
|
def image_to_dict(image: ImageType, lang: str, config: str) -> Dict[str, List[Union[str, int, float]]]:
|
|
123
158
|
"""
|
|
124
159
|
This is more or less pytesseract.image_to_data with a dict as returned value.
|
|
@@ -220,7 +255,6 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
|
|
|
220
255
|
:return: A list of tesseract extractions wrapped in DetectionResult
|
|
221
256
|
"""
|
|
222
257
|
|
|
223
|
-
np_img = np_img.astype(np.uint8)
|
|
224
258
|
results = image_to_dict(np_img, supported_languages, config)
|
|
225
259
|
all_results = []
|
|
226
260
|
|
|
@@ -249,6 +283,16 @@ def predict_text(np_img: ImageType, supported_languages: str, text_lines: bool,
|
|
|
249
283
|
return all_results
|
|
250
284
|
|
|
251
285
|
|
|
286
|
+
def predict_rotation(np_img: ImageType) -> Mapping[str, str]:
|
|
287
|
+
"""
|
|
288
|
+
Predicts the rotation of an image using the Tesseract OCR engine.
|
|
289
|
+
|
|
290
|
+
:param np_img: numpy array of the image
|
|
291
|
+
:return: A dictionary with keys 'Orientation in degrees' and 'Orientation confidence'
|
|
292
|
+
"""
|
|
293
|
+
return image_to_angle(np_img)
|
|
294
|
+
|
|
295
|
+
|
|
252
296
|
class TesseractOcrDetector(ObjectDetector):
|
|
253
297
|
"""
|
|
254
298
|
Text object detector based on Tesseracts OCR engine. Note that tesseract has to be installed separately.
|
|
@@ -292,7 +336,9 @@ class TesseractOcrDetector(ObjectDetector):
|
|
|
292
336
|
:param config_overwrite: Overwrite config parameters defined by the yaml file with new values.
|
|
293
337
|
E.g. ["oem=14"]
|
|
294
338
|
"""
|
|
295
|
-
self.name =
|
|
339
|
+
self.name = self.get_name()
|
|
340
|
+
self.model_id = self.get_model_id()
|
|
341
|
+
|
|
296
342
|
if config_overwrite is None:
|
|
297
343
|
config_overwrite = []
|
|
298
344
|
|
|
@@ -316,13 +362,13 @@ class TesseractOcrDetector(ObjectDetector):
|
|
|
316
362
|
:param np_img: image as numpy array
|
|
317
363
|
:return: A list of DetectionResult
|
|
318
364
|
"""
|
|
319
|
-
|
|
365
|
+
|
|
366
|
+
return predict_text(
|
|
320
367
|
np_img,
|
|
321
368
|
supported_languages=self.config.LANGUAGES,
|
|
322
369
|
text_lines=self.config.LINES,
|
|
323
370
|
config=config_to_cli_str(self.config, "LANGUAGES", "LINES"),
|
|
324
371
|
)
|
|
325
|
-
return detection_results
|
|
326
372
|
|
|
327
373
|
@classmethod
|
|
328
374
|
def get_requirements(cls) -> List[Requirement]:
|
|
@@ -342,3 +388,69 @@ class TesseractOcrDetector(ObjectDetector):
|
|
|
342
388
|
:param language: `Languages`
|
|
343
389
|
"""
|
|
344
390
|
self.config.LANGUAGES = _LANG_CODE_TO_TESS_LANG_CODE.get(language, language.value)
|
|
391
|
+
|
|
392
|
+
@staticmethod
|
|
393
|
+
def get_name() -> str:
|
|
394
|
+
"""Returns the name of the model"""
|
|
395
|
+
return f"Tesseract_{get_tesseract_version()}"
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class TesseractRotationTransformer(ImageTransformer):
|
|
399
|
+
"""
|
|
400
|
+
The `TesseractRotationTransformer` class is a specialized image transformer that is designed to handle image
|
|
401
|
+
rotation in the context of Optical Character Recognition (OCR) tasks. It inherits from the `ImageTransformer`
|
|
402
|
+
base class and implements methods for predicting and applying rotation transformations to images.
|
|
403
|
+
|
|
404
|
+
The `predict` method determines the angle of the rotated image. It can only handle angles that are multiples of 90
|
|
405
|
+
degrees.
|
|
406
|
+
This method uses the Tesseract OCR engine to predict the rotation angle of an image.
|
|
407
|
+
|
|
408
|
+
The `transform` method applies the predicted rotation to the image, effectively rotating the image backwards.
|
|
409
|
+
This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
|
|
410
|
+
|
|
411
|
+
This class can be particularly useful in OCR tasks where the orientation of the text in the image matters.
|
|
412
|
+
The class also provides methods for cloning itself and for getting the requirements of the Tesseract OCR system.
|
|
413
|
+
|
|
414
|
+
**Example:**
|
|
415
|
+
transformer = TesseractRotationTransformer()
|
|
416
|
+
detection_result = transformer.predict(np_img)
|
|
417
|
+
rotated_image = transformer.transform(np_img, detection_result)
|
|
418
|
+
"""
|
|
419
|
+
|
|
420
|
+
def __init__(self) -> None:
|
|
421
|
+
self.name = _TESS_PATH + "-rotation"
|
|
422
|
+
|
|
423
|
+
def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
|
|
424
|
+
"""
|
|
425
|
+
Applies the predicted rotation to the image, effectively rotating the image backwards.
|
|
426
|
+
This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
|
|
427
|
+
|
|
428
|
+
:param np_img: The input image as a numpy array.
|
|
429
|
+
:param specification: A `DetectionResult` object containing the predicted rotation angle.
|
|
430
|
+
:return: The rotated image as a numpy array.
|
|
431
|
+
"""
|
|
432
|
+
return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
|
|
433
|
+
|
|
434
|
+
def predict(self, np_img: ImageType) -> DetectionResult:
|
|
435
|
+
"""
|
|
436
|
+
Determines the angle of the rotated image. It can only handle angles that are multiples of 90 degrees.
|
|
437
|
+
This method uses the Tesseract OCR engine to predict the rotation angle of an image.
|
|
438
|
+
|
|
439
|
+
:param np_img: The input image as a numpy array.
|
|
440
|
+
:return: A `DetectionResult` object containing the predicted rotation angle and confidence.
|
|
441
|
+
"""
|
|
442
|
+
output_dict = predict_rotation(np_img)
|
|
443
|
+
return DetectionResult(
|
|
444
|
+
angle=float(output_dict["Orientation in degrees"]), score=float(output_dict["Orientation confidence"])
|
|
445
|
+
)
|
|
446
|
+
|
|
447
|
+
@classmethod
|
|
448
|
+
def get_requirements(cls) -> List[Requirement]:
|
|
449
|
+
return [get_tesseract_requirement()]
|
|
450
|
+
|
|
451
|
+
def clone(self) -> PredictorBase:
|
|
452
|
+
return self.__class__()
|
|
453
|
+
|
|
454
|
+
@staticmethod
|
|
455
|
+
def possible_category() -> PageType:
|
|
456
|
+
return PageType.angle
|
deepdoctection/extern/texocr.py
CHANGED
|
@@ -23,14 +23,16 @@ import sys
|
|
|
23
23
|
import traceback
|
|
24
24
|
from typing import List
|
|
25
25
|
|
|
26
|
+
from lazy_imports import try_import
|
|
27
|
+
|
|
26
28
|
from ..datapoint.convert import convert_np_array_to_b64_b
|
|
27
29
|
from ..utils.detection_types import ImageType, JsonDict, Requirement
|
|
28
|
-
from ..utils.file_utils import
|
|
30
|
+
from ..utils.file_utils import get_boto3_requirement
|
|
29
31
|
from ..utils.logger import LoggingRecord, logger
|
|
30
32
|
from ..utils.settings import LayoutType, ObjectTypes
|
|
31
33
|
from .base import DetectionResult, ObjectDetector, PredictorBase
|
|
32
34
|
|
|
33
|
-
|
|
35
|
+
with try_import() as import_guard:
|
|
34
36
|
import boto3 # type:ignore
|
|
35
37
|
|
|
36
38
|
|
|
@@ -120,6 +122,8 @@ class TextractOcrDetector(ObjectDetector):
|
|
|
120
122
|
:param credentials_kwargs: `aws_access_key_id`, `aws_secret_access_key` or `aws_session_token`
|
|
121
123
|
"""
|
|
122
124
|
self.name = "textract"
|
|
125
|
+
self.model_id = self.get_model_id()
|
|
126
|
+
|
|
123
127
|
self.text_lines = text_lines
|
|
124
128
|
self.client = boto3.client("textract", **credentials_kwargs)
|
|
125
129
|
if self.text_lines:
|
|
@@ -19,7 +19,18 @@
|
|
|
19
19
|
Tensorflow related utils.
|
|
20
20
|
"""
|
|
21
21
|
|
|
22
|
-
from
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import os
|
|
25
|
+
from typing import Optional, Union, ContextManager
|
|
26
|
+
|
|
27
|
+
from lazy_imports import try_import
|
|
28
|
+
|
|
29
|
+
with try_import() as import_guard:
|
|
30
|
+
from tensorpack.models import disable_layer_logging # pylint: disable=E0401
|
|
31
|
+
|
|
32
|
+
with try_import() as tf_import_guard:
|
|
33
|
+
import tensorflow as tf # pylint: disable=E0401
|
|
23
34
|
|
|
24
35
|
|
|
25
36
|
def is_tfv2() -> bool:
|
|
@@ -38,16 +49,13 @@ def disable_tfv2() -> bool:
|
|
|
38
49
|
"""
|
|
39
50
|
Disable TF in V2 mode.
|
|
40
51
|
"""
|
|
41
|
-
try:
|
|
42
|
-
import tensorflow as tf # pylint: disable=C0415
|
|
43
52
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
53
|
+
tfv1 = tf.compat.v1
|
|
54
|
+
if is_tfv2():
|
|
55
|
+
tfv1.disable_v2_behavior()
|
|
56
|
+
tfv1.disable_eager_execution()
|
|
48
57
|
return True
|
|
49
|
-
|
|
50
|
-
return False
|
|
58
|
+
return False
|
|
51
59
|
|
|
52
60
|
|
|
53
61
|
def disable_tp_layer_logging() -> None:
|
|
@@ -55,3 +63,29 @@ def disable_tp_layer_logging() -> None:
|
|
|
55
63
|
Disables TP layer logging, if not already set
|
|
56
64
|
"""
|
|
57
65
|
disable_layer_logging()
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_tf_device(device: Optional[Union[str, tf.device]] = None) -> tf.device:
|
|
69
|
+
"""
|
|
70
|
+
Selecting a device on which to load a model. The selection follows a cascade of priorities:
|
|
71
|
+
|
|
72
|
+
- If a device string is provided, it is used. If the string is "cuda" or "GPU", the first GPU is used.
|
|
73
|
+
- If the environment variable "USE_CUDA" is set, a GPU is used. If more GPUs are available it will use the first one
|
|
74
|
+
|
|
75
|
+
:param device: Device string
|
|
76
|
+
:return: Tensorflow device
|
|
77
|
+
"""
|
|
78
|
+
if device is not None:
|
|
79
|
+
if isinstance(device, ContextManager):
|
|
80
|
+
return device
|
|
81
|
+
if isinstance(device, str):
|
|
82
|
+
if device in ("cuda", "GPU"):
|
|
83
|
+
device_names = [device.name for device in tf.config.list_logical_devices(device_type="GPU")]
|
|
84
|
+
return tf.device(device_names[0].name)
|
|
85
|
+
# The input must be something sensible
|
|
86
|
+
return tf.device(device)
|
|
87
|
+
if os.environ.get("USE_CUDA"):
|
|
88
|
+
device_names = [device.name for device in tf.config.list_logical_devices(device_type="GPU")]
|
|
89
|
+
return tf.device(device_names[0])
|
|
90
|
+
device_names = [device.name for device in tf.config.list_logical_devices(device_type="CPU")]
|
|
91
|
+
return tf.device(device_names[0])
|
|
@@ -18,21 +18,24 @@
|
|
|
18
18
|
"""
|
|
19
19
|
Compatibility classes and methods related to Tensorpack package
|
|
20
20
|
"""
|
|
21
|
+
from __future__ import annotations
|
|
21
22
|
|
|
22
23
|
from abc import ABC, abstractmethod
|
|
23
24
|
from typing import Any, List, Mapping, Tuple, Union
|
|
24
25
|
|
|
25
|
-
from
|
|
26
|
-
from tensorpack.tfutils import SmartInit # pylint: disable=E0401
|
|
27
|
-
|
|
28
|
-
# pylint: disable=import-error
|
|
29
|
-
from tensorpack.train.model_desc import ModelDesc
|
|
30
|
-
from tensorpack.utils.gpu import get_num_gpu
|
|
26
|
+
from lazy_imports import try_import
|
|
31
27
|
|
|
32
28
|
from ...utils.metacfg import AttrDict
|
|
33
29
|
from ...utils.settings import ObjectTypes
|
|
34
30
|
|
|
35
|
-
|
|
31
|
+
with try_import() as import_guard:
|
|
32
|
+
from tensorpack.predict import OfflinePredictor, PredictConfig # pylint: disable=E0401
|
|
33
|
+
from tensorpack.tfutils import SmartInit # pylint: disable=E0401
|
|
34
|
+
from tensorpack.train.model_desc import ModelDesc # pylint: disable=E0401
|
|
35
|
+
from tensorpack.utils.gpu import get_num_gpu # pylint: disable=E0401
|
|
36
|
+
|
|
37
|
+
if not import_guard.is_successful():
|
|
38
|
+
from ...utils.mocks import ModelDesc
|
|
36
39
|
|
|
37
40
|
|
|
38
41
|
class ModelDescWithConfig(ModelDesc, ABC): # type: ignore
|
|
@@ -55,7 +58,7 @@ class ModelDescWithConfig(ModelDesc, ABC): # type: ignore
|
|
|
55
58
|
|
|
56
59
|
:return: Tuple of list input and list output names. The names must coincide with tensor within the model.
|
|
57
60
|
"""
|
|
58
|
-
raise NotImplementedError
|
|
61
|
+
raise NotImplementedError()
|
|
59
62
|
|
|
60
63
|
|
|
61
64
|
class TensorpackPredictor(ABC):
|
|
@@ -106,14 +109,14 @@ class TensorpackPredictor(ABC):
|
|
|
106
109
|
|
|
107
110
|
@staticmethod
|
|
108
111
|
@abstractmethod
|
|
109
|
-
def
|
|
112
|
+
def get_wrapped_model(
|
|
110
113
|
path_yaml: str, categories: Mapping[str, ObjectTypes], config_overwrite: Union[List[str], None]
|
|
111
114
|
) -> ModelDescWithConfig:
|
|
112
115
|
"""
|
|
113
116
|
Implement the config generation, its modification and instantiate a version of the model. See
|
|
114
117
|
`pipe.tpfrcnn.TPFrcnnDetector` for an example
|
|
115
118
|
"""
|
|
116
|
-
raise NotImplementedError
|
|
119
|
+
raise NotImplementedError()
|
|
117
120
|
|
|
118
121
|
@abstractmethod
|
|
119
122
|
def predict(self, np_img: Any) -> Any:
|
|
@@ -121,7 +124,7 @@ class TensorpackPredictor(ABC):
|
|
|
121
124
|
Implement, how `self.tp_predictor` is invoked and raw prediction results are generated. Do use only raw
|
|
122
125
|
objects and nothing, which is related to the DD API.
|
|
123
126
|
"""
|
|
124
|
-
raise NotImplementedError
|
|
127
|
+
raise NotImplementedError()
|
|
125
128
|
|
|
126
129
|
@property
|
|
127
130
|
def model(self) -> ModelDescWithConfig:
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# File: __init__.py
|
|
3
|
+
|
|
4
|
+
# Copyright 2021 Dr. Janis Meyer. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
"""
|
|
19
|
+
Init file for code for Tensorpack FRCNN example
|
|
20
|
+
"""
|
|
@@ -11,13 +11,17 @@ This file is modified from
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
|
-
from
|
|
14
|
+
from lazy_imports import try_import
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
with try_import() as import_guard:
|
|
17
|
+
from tensorpack.dataflow.imgaug import ImageAugmentor, ResizeTransform # pylint: disable=E0401
|
|
17
18
|
|
|
18
|
-
|
|
19
|
+
with try_import() as cc_import_guard:
|
|
19
20
|
import pycocotools.mask as coco_mask
|
|
20
21
|
|
|
22
|
+
if not import_guard.is_successful():
|
|
23
|
+
from ....utils.mocks import ImageAugmentor
|
|
24
|
+
|
|
21
25
|
|
|
22
26
|
class CustomResize(ImageAugmentor):
|
|
23
27
|
"""
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# File: __init__.py
|
|
3
|
+
|
|
4
|
+
# Copyright 2021 Dr. Janis Meyer. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
"""
|
|
19
|
+
Init file for code for Tensorpack's FRCNN configs
|
|
20
|
+
"""
|
|
@@ -191,16 +191,19 @@ import os
|
|
|
191
191
|
from typing import List, Mapping, Tuple
|
|
192
192
|
|
|
193
193
|
import numpy as np
|
|
194
|
-
from
|
|
195
|
-
from tensorpack.utils import logger # pylint: disable=E0401
|
|
196
|
-
|
|
197
|
-
# pylint: disable=import-error
|
|
198
|
-
from tensorpack.utils.gpu import get_num_gpu
|
|
194
|
+
from lazy_imports import try_import
|
|
199
195
|
|
|
200
196
|
from .....utils.metacfg import AttrDict
|
|
201
197
|
from .....utils.settings import ObjectTypes
|
|
202
198
|
|
|
203
|
-
|
|
199
|
+
with try_import() as import_guard:
|
|
200
|
+
from tensorpack.tfutils import collect_env_info # pylint: disable=E0401
|
|
201
|
+
from tensorpack.utils import logger # pylint: disable=E0401
|
|
202
|
+
|
|
203
|
+
# pylint: disable=import-error
|
|
204
|
+
from tensorpack.utils.gpu import get_num_gpu
|
|
205
|
+
|
|
206
|
+
# pylint: enable=import-error
|
|
204
207
|
|
|
205
208
|
|
|
206
209
|
__all__ = ["train_frcnn_config", "model_frcnn_config"]
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# File: __init__.py
|
|
3
|
+
|
|
4
|
+
# Copyright 2021 Dr. Janis Meyer. All rights reserved.
|
|
5
|
+
#
|
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
7
|
+
# you may not use this file except in compliance with the License.
|
|
8
|
+
# You may obtain a copy of the License at
|
|
9
|
+
#
|
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
#
|
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
15
|
+
# See the License for the specific language governing permissions and
|
|
16
|
+
# limitations under the License.
|
|
17
|
+
|
|
18
|
+
"""
|
|
19
|
+
Init file for code for Tensorpack's FRCNN configs
|
|
20
|
+
"""
|
|
@@ -12,22 +12,30 @@ This file is modified from
|
|
|
12
12
|
from contextlib import ExitStack, contextmanager
|
|
13
13
|
|
|
14
14
|
import numpy as np
|
|
15
|
+
from lazy_imports import try_import
|
|
15
16
|
|
|
16
17
|
# pylint: disable=import-error
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
from tensorpack
|
|
21
|
-
from tensorpack.
|
|
18
|
+
|
|
19
|
+
with try_import() as import_guard:
|
|
20
|
+
import tensorflow as tf
|
|
21
|
+
from tensorpack import tfv1
|
|
22
|
+
from tensorpack.models import BatchNorm, Conv2D, MaxPooling, layer_register
|
|
23
|
+
from tensorpack.tfutils import argscope
|
|
24
|
+
from tensorpack.tfutils.varreplace import custom_getter_scope, freeze_variables
|
|
22
25
|
|
|
23
26
|
# pylint: enable=import-error
|
|
24
27
|
|
|
28
|
+
if not import_guard.is_successful():
|
|
29
|
+
from .....utils.mocks import layer_register
|
|
30
|
+
|
|
25
31
|
|
|
26
32
|
@layer_register(log_shape=True)
|
|
27
|
-
def GroupNorm(x, group=32, gamma_initializer=
|
|
33
|
+
def GroupNorm(x, group=32, gamma_initializer=None):
|
|
28
34
|
"""
|
|
29
35
|
More code that reproduces the paper can be found at <https://github.com/ppwwyyxx/GroupNorm-reproduce/>.
|
|
30
36
|
"""
|
|
37
|
+
if gamma_initializer is None:
|
|
38
|
+
gamma_initializer = tf.constant_initializer(1.0)
|
|
31
39
|
shape = x.get_shape().as_list()
|
|
32
40
|
ndims = len(shape)
|
|
33
41
|
assert ndims == 4, shape
|
|
@@ -153,7 +161,7 @@ def get_norm(cfg, zero_init=False):
|
|
|
153
161
|
return lambda x: norm(layer_name, x, gamma_initializer=tf.zeros_initializer() if zero_init else None)
|
|
154
162
|
|
|
155
163
|
|
|
156
|
-
def resnet_shortcut(l, n_out, stride, activation=
|
|
164
|
+
def resnet_shortcut(l, n_out, stride, activation=None):
|
|
157
165
|
"""
|
|
158
166
|
Defining the skip connection in bottleneck
|
|
159
167
|
|
|
@@ -163,6 +171,8 @@ def resnet_shortcut(l, n_out, stride, activation=tf.identity):
|
|
|
163
171
|
:param activation: An activation function
|
|
164
172
|
:return: tf.Tensor
|
|
165
173
|
"""
|
|
174
|
+
if activation is None:
|
|
175
|
+
activation = tf.identity
|
|
166
176
|
n_in = l.shape[1]
|
|
167
177
|
if n_in != n_out: # change dimension when channel is not the same
|
|
168
178
|
return Conv2D("convshortcut", l, n_out, 1, strides=stride, activation=activation) # pylint: disable=E1124
|
|
@@ -9,12 +9,8 @@ This file is modified from
|
|
|
9
9
|
<https://github.com/tensorpack/tensorpack/blob/master/examples/FasterRCNN/modeling/generalized_rcnn.py>
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
from tensorpack import tfv1
|
|
15
|
-
from tensorpack.models import l2_regularizer, regularize_cost
|
|
16
|
-
from tensorpack.tfutils import optimizer
|
|
17
|
-
from tensorpack.tfutils.summary import add_moving_summary
|
|
12
|
+
|
|
13
|
+
from lazy_imports import try_import
|
|
18
14
|
|
|
19
15
|
from ...tpcompat import ModelDescWithConfig
|
|
20
16
|
from ..utils.box_ops import area as tf_area
|
|
@@ -40,6 +36,16 @@ from .model_frcnn import (
|
|
|
40
36
|
from .model_mrcnn import maskrcnn_loss, unpackbits_masks
|
|
41
37
|
from .model_rpn import rpn_head
|
|
42
38
|
|
|
39
|
+
with try_import() as import_guard:
|
|
40
|
+
# pylint: disable=import-error
|
|
41
|
+
import tensorflow as tf
|
|
42
|
+
from tensorpack import tfv1
|
|
43
|
+
from tensorpack.models import l2_regularizer, regularize_cost
|
|
44
|
+
from tensorpack.tfutils import optimizer
|
|
45
|
+
from tensorpack.tfutils.summary import add_moving_summary
|
|
46
|
+
|
|
47
|
+
# pylint: enable=import-error
|
|
48
|
+
|
|
43
49
|
|
|
44
50
|
class GeneralizedRCNN(ModelDescWithConfig):
|
|
45
51
|
"""
|
|
@@ -11,12 +11,17 @@ This file is modified from
|
|
|
11
11
|
from collections import namedtuple
|
|
12
12
|
|
|
13
13
|
import numpy as np
|
|
14
|
+
from lazy_imports import try_import
|
|
14
15
|
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
16
|
+
with try_import() as import_guard:
|
|
17
|
+
# pylint: disable=import-error
|
|
18
|
+
import tensorflow as tf
|
|
19
|
+
from tensorpack.tfutils.scope_utils import under_name_scope
|
|
18
20
|
|
|
19
|
-
# pylint: enable=import-error
|
|
21
|
+
# pylint: enable=import-error
|
|
22
|
+
|
|
23
|
+
if not import_guard.is_successful():
|
|
24
|
+
from .....utils.mocks import under_name_scope
|
|
20
25
|
|
|
21
26
|
|
|
22
27
|
@under_name_scope()
|
|
@@ -9,17 +9,20 @@ This file is modified from
|
|
|
9
9
|
<https://github.com/tensorpack/tensorpack/blob/master/examples/FasterRCNN/modeling/model_cascade.py>
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
|
|
13
|
-
import tensorflow as tf
|
|
14
|
-
from tensorpack import tfv1
|
|
15
|
-
from tensorpack.tfutils import get_current_tower_context
|
|
12
|
+
from lazy_imports import try_import
|
|
16
13
|
|
|
17
14
|
from ..utils.box_ops import area as tf_area
|
|
18
15
|
from ..utils.box_ops import pairwise_iou
|
|
19
16
|
from .model_box import clip_boxes
|
|
20
17
|
from .model_frcnn import BoxProposals, FastRCNNHead, fastrcnn_outputs
|
|
21
18
|
|
|
22
|
-
|
|
19
|
+
with try_import() as import_guard:
|
|
20
|
+
# pylint: disable=import-error
|
|
21
|
+
import tensorflow as tf
|
|
22
|
+
from tensorpack import tfv1
|
|
23
|
+
from tensorpack.tfutils import get_current_tower_context
|
|
24
|
+
|
|
25
|
+
# pylint: enable=import-error
|
|
23
26
|
|
|
24
27
|
|
|
25
28
|
class CascadeRCNNHead:
|