deepdoctection 0.26__py3-none-any.whl → 0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic. Click here for more details.
- deepdoctection/__init__.py +7 -1
- deepdoctection/analyzer/dd.py +15 -3
- deepdoctection/configs/conf_dd_one.yaml +4 -0
- deepdoctection/datapoint/convert.py +5 -10
- deepdoctection/datapoint/image.py +2 -2
- deepdoctection/datapoint/view.py +38 -18
- deepdoctection/datasets/save.py +3 -3
- deepdoctection/extern/d2detect.py +1 -2
- deepdoctection/extern/doctrocr.py +14 -9
- deepdoctection/extern/tp/tpfrcnn/common.py +2 -3
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +6 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +3 -3
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -2
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +5 -3
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +3 -1
- deepdoctection/extern/tp/tpfrcnn/predict.py +1 -0
- deepdoctection/mapper/laylmstruct.py +2 -3
- deepdoctection/utils/context.py +2 -2
- deepdoctection/utils/file_utils.py +63 -26
- deepdoctection/utils/fs.py +6 -6
- deepdoctection/utils/pdf_utils.py +2 -2
- deepdoctection/utils/settings.py +8 -1
- deepdoctection/utils/transform.py +9 -9
- deepdoctection/utils/viz.py +405 -86
- {deepdoctection-0.26.dist-info → deepdoctection-0.27.dist-info}/METADATA +93 -94
- {deepdoctection-0.26.dist-info → deepdoctection-0.27.dist-info}/RECORD +31 -31
- {deepdoctection-0.26.dist-info → deepdoctection-0.27.dist-info}/WHEEL +1 -1
- tests/analyzer/test_dd.py +6 -57
- tests/conftest.py +2 -0
- {deepdoctection-0.26.dist-info → deepdoctection-0.27.dist-info}/LICENSE +0 -0
- {deepdoctection-0.26.dist-info → deepdoctection-0.27.dist-info}/top_level.txt +0 -0
|
@@ -25,6 +25,8 @@ from .detection_types import Requirement
|
|
|
25
25
|
from .logger import logger
|
|
26
26
|
from .metacfg import AttrDict
|
|
27
27
|
|
|
28
|
+
_GENERIC_ERR_MSG = "Please check the required version either in the docs or in the setup file"
|
|
29
|
+
|
|
28
30
|
# Tensorflow and Tensorpack dependencies
|
|
29
31
|
_TF_AVAILABLE = False
|
|
30
32
|
|
|
@@ -33,7 +35,7 @@ try:
|
|
|
33
35
|
except ValueError:
|
|
34
36
|
pass
|
|
35
37
|
|
|
36
|
-
_TF_ERR_MSG = "Tensorflow
|
|
38
|
+
_TF_ERR_MSG = f"Tensorflow must be installed. {_GENERIC_ERR_MSG}"
|
|
37
39
|
|
|
38
40
|
|
|
39
41
|
def tf_available() -> bool:
|
|
@@ -90,8 +92,8 @@ def get_tensorflow_requirement() -> Requirement:
|
|
|
90
92
|
|
|
91
93
|
_TF_ADDONS_AVAILABLE = importlib.util.find_spec("tensorflow_addons") is not None
|
|
92
94
|
_TF_ADDONS_ERR_MSG = (
|
|
93
|
-
"Tensorflow Addons must be installed
|
|
94
|
-
"
|
|
95
|
+
"Tensorflow Addons must be installed. Please check the required version either in the docs or in the setup file."
|
|
96
|
+
"Please note, that it has been announced, the this package will be deprecated in the near future."
|
|
95
97
|
)
|
|
96
98
|
|
|
97
99
|
|
|
@@ -110,10 +112,7 @@ def get_tf_addons_requirements() -> Requirement:
|
|
|
110
112
|
|
|
111
113
|
|
|
112
114
|
_TP_AVAILABLE = importlib.util.find_spec("tensorpack") is not None
|
|
113
|
-
_TP_ERR_MSG =
|
|
114
|
-
"Tensorflow models all use the Tensorpack modeling API. Therefore, Tensorpack must be installed: "
|
|
115
|
-
">>make install-dd-tf"
|
|
116
|
-
)
|
|
115
|
+
_TP_ERR_MSG = f"Tensorpack must be installed. {_GENERIC_ERR_MSG}"
|
|
117
116
|
|
|
118
117
|
|
|
119
118
|
def tensorpack_available() -> bool:
|
|
@@ -132,7 +131,7 @@ def get_tensorpack_requirement() -> Requirement:
|
|
|
132
131
|
|
|
133
132
|
# Pytorch related dependencies
|
|
134
133
|
_PYTORCH_AVAILABLE = importlib.util.find_spec("torch") is not None
|
|
135
|
-
_PYTORCH_ERR_MSG = "Pytorch must be installed
|
|
134
|
+
_PYTORCH_ERR_MSG = f"Pytorch must be installed. {_GENERIC_ERR_MSG}"
|
|
136
135
|
|
|
137
136
|
|
|
138
137
|
def pytorch_available() -> bool:
|
|
@@ -151,7 +150,7 @@ def get_pytorch_requirement() -> Requirement:
|
|
|
151
150
|
|
|
152
151
|
# lxml
|
|
153
152
|
_LXML_AVAILABLE = importlib.util.find_spec("lxml") is not None
|
|
154
|
-
_LXML_ERR_MSG = "lxml must be installed
|
|
153
|
+
_LXML_ERR_MSG = f"lxml must be installed. {_GENERIC_ERR_MSG}"
|
|
155
154
|
|
|
156
155
|
|
|
157
156
|
def lxml_available() -> bool:
|
|
@@ -170,7 +169,7 @@ def get_lxml_requirement() -> Requirement:
|
|
|
170
169
|
|
|
171
170
|
# apted
|
|
172
171
|
_APTED_AVAILABLE = importlib.util.find_spec("apted") is not None
|
|
173
|
-
_APTED_ERR_MSG = "
|
|
172
|
+
_APTED_ERR_MSG = f"apted must be installed. {_GENERIC_ERR_MSG}"
|
|
174
173
|
|
|
175
174
|
|
|
176
175
|
def apted_available() -> bool:
|
|
@@ -189,7 +188,7 @@ def get_apted_requirement() -> Requirement:
|
|
|
189
188
|
|
|
190
189
|
# distance
|
|
191
190
|
_DISTANCE_AVAILABLE = importlib.util.find_spec("distance") is not None
|
|
192
|
-
_DISTANCE_ERR_MSG = "distance must be installed
|
|
191
|
+
_DISTANCE_ERR_MSG = f"distance must be installed. {_GENERIC_ERR_MSG}"
|
|
193
192
|
|
|
194
193
|
|
|
195
194
|
def distance_available() -> bool:
|
|
@@ -208,7 +207,7 @@ def get_distance_requirement() -> Requirement:
|
|
|
208
207
|
|
|
209
208
|
# Transformers
|
|
210
209
|
_TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None
|
|
211
|
-
_TRANSFORMERS_ERR_MSG = "
|
|
210
|
+
_TRANSFORMERS_ERR_MSG = f"transformers must be installed. {_GENERIC_ERR_MSG}"
|
|
212
211
|
|
|
213
212
|
|
|
214
213
|
def transformers_available() -> bool:
|
|
@@ -228,7 +227,7 @@ def get_transformers_requirement() -> Requirement:
|
|
|
228
227
|
# Detectron2 related requirements
|
|
229
228
|
_DETECTRON2_AVAILABLE = importlib.util.find_spec("detectron2") is not None
|
|
230
229
|
_DETECTRON2_ERR_MSG = (
|
|
231
|
-
"Detectron2 must be installed
|
|
230
|
+
"Detectron2 must be installed. Please follow the official installation instructions "
|
|
232
231
|
"https://detectron2.readthedocs.io/en/latest/tutorials/install.html"
|
|
233
232
|
)
|
|
234
233
|
|
|
@@ -251,7 +250,10 @@ def get_detectron2_requirement() -> Requirement:
|
|
|
251
250
|
_TESS_AVAILABLE = which("tesseract") is not None
|
|
252
251
|
# Tesseract installation path
|
|
253
252
|
_TESS_PATH = "tesseract"
|
|
254
|
-
_TESS_ERR_MSG =
|
|
253
|
+
_TESS_ERR_MSG = (
|
|
254
|
+
"Tesseract >=4.0 must be installed. Please follow the official installation instructions. "
|
|
255
|
+
"https://tesseract-ocr.github.io/tessdoc/Installation.html"
|
|
256
|
+
)
|
|
255
257
|
|
|
256
258
|
|
|
257
259
|
def set_tesseract_path(tesseract_path: str) -> None:
|
|
@@ -304,7 +306,7 @@ def get_tesseract_version() -> Union[int, version.Version]:
|
|
|
304
306
|
stdin=subprocess.DEVNULL,
|
|
305
307
|
)
|
|
306
308
|
except OSError:
|
|
307
|
-
raise TesseractNotFound() from OSError
|
|
309
|
+
raise TesseractNotFound(_TESS_ERR_MSG) from OSError
|
|
308
310
|
|
|
309
311
|
raw_version = output.decode("utf-8")
|
|
310
312
|
str_version, *_ = raw_version.lstrip(string.printable[10:]).partition(" ")
|
|
@@ -390,7 +392,7 @@ def get_poppler_requirement() -> Requirement:
|
|
|
390
392
|
|
|
391
393
|
# Pdfplumber.six related dependencies
|
|
392
394
|
_PDFPLUMBER_AVAILABLE = importlib.util.find_spec("pdfplumber") is not None
|
|
393
|
-
_PDFPLUMBER_ERR_MSG = "pdfplumber must be installed.
|
|
395
|
+
_PDFPLUMBER_ERR_MSG = f"pdfplumber must be installed. {_GENERIC_ERR_MSG}"
|
|
394
396
|
|
|
395
397
|
|
|
396
398
|
def pdfplumber_available() -> bool:
|
|
@@ -409,7 +411,7 @@ def get_pdfplumber_requirement() -> Requirement:
|
|
|
409
411
|
|
|
410
412
|
# pycocotools dependencies
|
|
411
413
|
_COCOTOOLS_AVAILABLE = importlib.util.find_spec("pycocotools") is not None
|
|
412
|
-
_COCOTOOLS_ERR_MSG = "pycocotools must be installed.
|
|
414
|
+
_COCOTOOLS_ERR_MSG = f"pycocotools must be installed. {_GENERIC_ERR_MSG}"
|
|
413
415
|
|
|
414
416
|
|
|
415
417
|
def cocotools_available() -> bool:
|
|
@@ -439,7 +441,7 @@ def scipy_available() -> bool:
|
|
|
439
441
|
|
|
440
442
|
# jdeskew dependency
|
|
441
443
|
_JDESKEW_AVAILABLE = importlib.util.find_spec("jdeskew") is not None
|
|
442
|
-
_JDESKEW_ERR_MSG = "jdeskew must be installed.
|
|
444
|
+
_JDESKEW_ERR_MSG = f"jdeskew must be installed. {_GENERIC_ERR_MSG}"
|
|
443
445
|
|
|
444
446
|
|
|
445
447
|
def jdeskew_available() -> bool:
|
|
@@ -458,7 +460,7 @@ def get_jdeskew_requirement() -> Requirement:
|
|
|
458
460
|
|
|
459
461
|
# scikit-learn dependencies
|
|
460
462
|
_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None
|
|
461
|
-
_SKLEARN_ERR_MSG = "scikit-learn must be installed.
|
|
463
|
+
_SKLEARN_ERR_MSG = f"scikit-learn must be installed. {_GENERIC_ERR_MSG}"
|
|
462
464
|
|
|
463
465
|
|
|
464
466
|
def sklearn_available() -> bool:
|
|
@@ -488,7 +490,7 @@ def qpdf_available() -> bool:
|
|
|
488
490
|
|
|
489
491
|
# Textract related dependencies
|
|
490
492
|
_BOTO3_AVAILABLE = importlib.util.find_spec("boto3") is not None
|
|
491
|
-
_BOTO3_ERR_MSG = "Boto3 must be installed
|
|
493
|
+
_BOTO3_ERR_MSG = f"Boto3 must be installed. {_GENERIC_ERR_MSG}"
|
|
492
494
|
|
|
493
495
|
_AWS_CLI_AVAILABLE = which("aws") is not None
|
|
494
496
|
_AWS_ERR_MSG = "AWS CLI must be installed https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html"
|
|
@@ -525,10 +527,7 @@ def get_aws_requirement() -> Requirement:
|
|
|
525
527
|
|
|
526
528
|
# DocTr related dependencies
|
|
527
529
|
_DOCTR_AVAILABLE = importlib.util.find_spec("doctr") is not None
|
|
528
|
-
_DOCTR_ERR_MSG =
|
|
529
|
-
"DocTr must be installed. Please read the necessary requirements at https://github.com/mindee/doctr"
|
|
530
|
-
"and use >> pip install python-doctr"
|
|
531
|
-
)
|
|
530
|
+
_DOCTR_ERR_MSG = f"DocTr must be installed. {_GENERIC_ERR_MSG}"
|
|
532
531
|
|
|
533
532
|
|
|
534
533
|
def doctr_available() -> bool:
|
|
@@ -552,7 +551,7 @@ def get_doctr_requirement() -> Requirement:
|
|
|
552
551
|
|
|
553
552
|
# Fasttext related dependencies
|
|
554
553
|
_FASTTEXT_AVAILABLE = importlib.util.find_spec("fasttext") is not None
|
|
555
|
-
_FASTTEXT_ERR_MSG = "
|
|
554
|
+
_FASTTEXT_ERR_MSG = f"fasttext must be installed. {_GENERIC_ERR_MSG}"
|
|
556
555
|
|
|
557
556
|
|
|
558
557
|
def fasttext_available() -> bool:
|
|
@@ -571,7 +570,7 @@ def get_fasttext_requirement() -> Requirement:
|
|
|
571
570
|
|
|
572
571
|
# Wandb related dependencies
|
|
573
572
|
_WANDB_AVAILABLE = importlib.util.find_spec("wandb") is not None
|
|
574
|
-
_WANDB_ERR_MSG = "WandB must be installed.
|
|
573
|
+
_WANDB_ERR_MSG = f"WandB must be installed. {_GENERIC_ERR_MSG}"
|
|
575
574
|
|
|
576
575
|
|
|
577
576
|
def wandb_available() -> bool:
|
|
@@ -592,6 +591,44 @@ _S = AttrDict()
|
|
|
592
591
|
_S.mp_context_set = False
|
|
593
592
|
_S.freeze()
|
|
594
593
|
|
|
594
|
+
# Image libraries: OpenCV and Pillow
|
|
595
|
+
# OpenCV
|
|
596
|
+
_CV2_AVAILABLE = importlib.util.find_spec("cv2") is not None
|
|
597
|
+
_CV2_ERR_MSG = f"OpenCV must be installed. {_GENERIC_ERR_MSG}"
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def opencv_available() -> bool:
|
|
601
|
+
"""
|
|
602
|
+
Returns True if OpenCV is installed
|
|
603
|
+
"""
|
|
604
|
+
return bool(_CV2_AVAILABLE)
|
|
605
|
+
|
|
606
|
+
|
|
607
|
+
def get_opencv_requirement() -> Requirement:
|
|
608
|
+
"""
|
|
609
|
+
Return OpenCV requirement
|
|
610
|
+
"""
|
|
611
|
+
return "opencv", opencv_available(), _CV2_ERR_MSG
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
# Pillow
|
|
615
|
+
_PILLOW_AVAILABLE = importlib.util.find_spec("PIL") is not None
|
|
616
|
+
_PILLOW_ERR_MSG = f"pillow must be installed. {_GENERIC_ERR_MSG}"
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def pillow_available() -> bool:
|
|
620
|
+
"""
|
|
621
|
+
Returns True if Pillow is installed
|
|
622
|
+
"""
|
|
623
|
+
return bool(_PILLOW_AVAILABLE)
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
def get_pillow_requirement() -> Requirement:
|
|
627
|
+
"""
|
|
628
|
+
Return Pillow requirement
|
|
629
|
+
"""
|
|
630
|
+
return "pillow", pillow_available(), _PILLOW_ERR_MSG
|
|
631
|
+
|
|
595
632
|
|
|
596
633
|
def set_mp_spawn() -> None:
|
|
597
634
|
"""
|
deepdoctection/utils/fs.py
CHANGED
|
@@ -28,13 +28,12 @@ from pathlib import Path
|
|
|
28
28
|
from typing import Callable, Literal, Optional, Protocol, Union, overload
|
|
29
29
|
from urllib.request import urlretrieve
|
|
30
30
|
|
|
31
|
-
from cv2 import IMREAD_COLOR, imread
|
|
32
|
-
|
|
33
31
|
from .detection_types import ImageType, JsonDict, Pathlike
|
|
34
32
|
from .logger import logger
|
|
35
33
|
from .pdf_utils import get_pdf_file_reader, get_pdf_file_writer
|
|
36
34
|
from .tqdm import get_tqdm
|
|
37
35
|
from .utils import FileExtensionError, is_file_extension
|
|
36
|
+
from .viz import viz_handler
|
|
38
37
|
|
|
39
38
|
__all__ = [
|
|
40
39
|
"load_image_from_file",
|
|
@@ -90,7 +89,7 @@ def download(url: str, directory: Pathlike, file_name: Optional[str] = None, exp
|
|
|
90
89
|
f_path = os.path.join(directory, file_name)
|
|
91
90
|
|
|
92
91
|
if os.path.isfile(f_path):
|
|
93
|
-
if expect_size is not None and os.stat(f_path).st_size == expect_size:
|
|
92
|
+
if (expect_size is not None and os.stat(f_path).st_size == expect_size) or expect_size is None:
|
|
94
93
|
logger.info("File %s exists! Skip download.", file_name)
|
|
95
94
|
return f_path
|
|
96
95
|
logger.warning("File %s exists. Will overwrite with a new download!", file_name)
|
|
@@ -156,19 +155,20 @@ def load_image_from_file(path: Pathlike, type_id: Literal["np", "b64"] = "np") -
|
|
|
156
155
|
with open(path, "rb") as file:
|
|
157
156
|
image = b64encode(file.read()).decode("utf-8")
|
|
158
157
|
else:
|
|
159
|
-
image =
|
|
158
|
+
image = viz_handler.read_image(path)
|
|
160
159
|
except (FileNotFoundError, ValueError):
|
|
161
160
|
logger.info("file not found or value error: %s", path)
|
|
162
161
|
|
|
163
162
|
return image
|
|
164
163
|
|
|
165
164
|
|
|
166
|
-
def load_bytes_from_pdf_file(path: Pathlike) -> bytes:
|
|
165
|
+
def load_bytes_from_pdf_file(path: Pathlike, page_number: int = 0) -> bytes:
|
|
167
166
|
"""
|
|
168
167
|
Loads a pdf file with one single page and passes back a bytes' representation of this file. Can be converted into
|
|
169
168
|
a numpy array or directly passed to the attr: image of Image.
|
|
170
169
|
|
|
171
170
|
:param path: A path to a pdf file. If more pages are available, it will take the page selected by page_number.
|
|
171
|
+
:param page_number: Zero-based index of the page to extract. Raises an `IndexError` if the document has no such page.
|
|
172
172
|
:return: A bytes' representation of a pdf file containing only the selected page
|
|
173
173
|
"""
|
|
174
174
|
|
|
@@ -177,7 +177,7 @@ def load_bytes_from_pdf_file(path: Pathlike) -> bytes:
|
|
|
177
177
|
file_reader = get_pdf_file_reader(path)
|
|
178
178
|
buffer = BytesIO()
|
|
179
179
|
writer = get_pdf_file_writer()
|
|
180
|
-
writer.
|
|
180
|
+
writer.add_page(file_reader.pages[page_number])
|
|
181
181
|
writer.write(buffer)
|
|
182
182
|
return buffer.getvalue()
|
|
183
183
|
|
|
@@ -27,7 +27,6 @@ from io import BytesIO
|
|
|
27
27
|
from shutil import copyfile
|
|
28
28
|
from typing import Generator, List, Optional, Tuple
|
|
29
29
|
|
|
30
|
-
from cv2 import IMREAD_COLOR, imread
|
|
31
30
|
from numpy import uint8
|
|
32
31
|
from PyPDF2 import PdfReader, PdfWriter, errors
|
|
33
32
|
|
|
@@ -36,6 +35,7 @@ from .detection_types import ImageType, Pathlike
|
|
|
36
35
|
from .file_utils import PopplerNotFound, pdf_to_cairo_available, pdf_to_ppm_available, qpdf_available
|
|
37
36
|
from .logger import logger
|
|
38
37
|
from .utils import FileExtensionError, is_file_extension
|
|
38
|
+
from .viz import viz_handler
|
|
39
39
|
|
|
40
40
|
__all__ = ["decrypt_pdf_document", "get_pdf_file_reader", "get_pdf_file_writer", "PDFStreamer", "pdf_to_np_array"]
|
|
41
41
|
|
|
@@ -215,6 +215,6 @@ def pdf_to_np_array(pdf_bytes: bytes, size: Optional[Tuple[int, int]] = None, dp
|
|
|
215
215
|
|
|
216
216
|
with save_tmp_file(pdf_bytes, "pdf_") as (tmp_name, input_file_name):
|
|
217
217
|
_run_poppler(_input_to_cli_str(input_file_name, tmp_name, dpi, size))
|
|
218
|
-
image =
|
|
218
|
+
image = viz_handler.read_image(tmp_name + "-1.png")
|
|
219
219
|
|
|
220
220
|
return image.astype(uint8)
|
deepdoctection/utils/settings.py
CHANGED
|
@@ -165,6 +165,9 @@ class WordType(ObjectTypes):
|
|
|
165
165
|
tag = "tag"
|
|
166
166
|
token_tag = "token_tag"
|
|
167
167
|
text_line = "text_line"
|
|
168
|
+
character_type = "character_type"
|
|
169
|
+
printed = "printed"
|
|
170
|
+
handwritten = "handwritten"
|
|
168
171
|
|
|
169
172
|
|
|
170
173
|
@object_types_registry.register("TokenClasses")
|
|
@@ -411,7 +414,11 @@ file_path = Path(os.path.split(__file__)[0])
|
|
|
411
414
|
PATH = file_path.parent.parent
|
|
412
415
|
|
|
413
416
|
# model cache directory
|
|
414
|
-
|
|
417
|
+
if os.environ.get("DEEPDOCTECTION_CACHE"):
|
|
418
|
+
dd_cache_home = Path(os.environ["DEEPDOCTECTION_CACHE"])
|
|
419
|
+
else:
|
|
420
|
+
dd_cache_home = Path(os.getenv("XDG_CACHE_HOME", Path.home() / ".cache")) / "deepdoctection"
|
|
421
|
+
|
|
415
422
|
MODEL_DIR = dd_cache_home / "weights"
|
|
416
423
|
|
|
417
424
|
# configs cache directory
|
|
@@ -24,14 +24,14 @@ of coordinates. Most have the ideas have been taken from
|
|
|
24
24
|
from abc import ABC, abstractmethod
|
|
25
25
|
from typing import Literal, Optional, Union
|
|
26
26
|
|
|
27
|
-
import cv2
|
|
28
27
|
import numpy as np
|
|
29
28
|
import numpy.typing as npt
|
|
30
29
|
from numpy import float32
|
|
31
30
|
|
|
32
31
|
from .detection_types import ImageType
|
|
32
|
+
from .viz import viz_handler
|
|
33
33
|
|
|
34
|
-
__all__ = ["ResizeTransform", "InferenceResize", "PadTransform"]
|
|
34
|
+
__all__ = ["ResizeTransform", "InferenceResize", "PadTransform", "normalize_image"]
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class BaseTransform(ABC):
|
|
@@ -61,25 +61,25 @@ class ResizeTransform(BaseTransform):
|
|
|
61
61
|
w: Union[int, float],
|
|
62
62
|
new_h: Union[int, float],
|
|
63
63
|
new_w: Union[int, float],
|
|
64
|
-
interp:
|
|
64
|
+
interp: str,
|
|
65
65
|
):
|
|
66
66
|
"""
|
|
67
67
|
:param h: height
|
|
68
68
|
:param w: width
|
|
69
69
|
:param new_h: target height
|
|
70
70
|
:param new_w: target width
|
|
71
|
-
:param interp:
|
|
72
|
-
|
|
71
|
+
:param interp: interpolation method, that depends on the image processing library. Currently, it supports
|
|
72
|
+
NEAREST, BOX, BILINEAR, BICUBIC and VIZ for PIL or INTER_NEAREST, INTER_LINEAR, INTER_AREA or VIZ for OpenCV
|
|
73
73
|
"""
|
|
74
74
|
self.h = h
|
|
75
75
|
self.w = w
|
|
76
|
-
self.new_h = new_h
|
|
77
|
-
self.new_w = new_w
|
|
76
|
+
self.new_h = int(new_h)
|
|
77
|
+
self.new_w = int(new_w)
|
|
78
78
|
self.interp = interp
|
|
79
79
|
|
|
80
80
|
def apply_image(self, img: ImageType) -> ImageType:
|
|
81
81
|
assert img.shape[:2] == (self.h, self.w)
|
|
82
|
-
ret =
|
|
82
|
+
ret = viz_handler.resize(img, self.new_w, self.new_h, self.interp)
|
|
83
83
|
if img.ndim == 3 and ret.ndim == 2:
|
|
84
84
|
ret = ret[:, :, np.newaxis]
|
|
85
85
|
return ret
|
|
@@ -97,7 +97,7 @@ class InferenceResize:
|
|
|
97
97
|
the inference version of `extern.tp.frcnn.common.CustomResize` .
|
|
98
98
|
"""
|
|
99
99
|
|
|
100
|
-
def __init__(self, short_edge_length: int, max_size: int, interp:
|
|
100
|
+
def __init__(self, short_edge_length: int, max_size: int, interp: str = "VIZ") -> None:
|
|
101
101
|
"""
|
|
102
102
|
:param short_edge_length: a [min, max] interval from which to sample the shortest edge length.
|
|
103
103
|
:param max_size: maximum allowed longest edge length.
|