deepdoctection-0.30-py3-none-any.whl → deepdoctection-0.31-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. (The original page linked to further details; that link was not preserved in this copy.)

Files changed (74)
  1. deepdoctection/__init__.py +4 -2
  2. deepdoctection/analyzer/dd.py +6 -5
  3. deepdoctection/dataflow/base.py +0 -19
  4. deepdoctection/dataflow/custom.py +4 -3
  5. deepdoctection/dataflow/custom_serialize.py +14 -5
  6. deepdoctection/dataflow/parallel_map.py +12 -11
  7. deepdoctection/dataflow/serialize.py +5 -4
  8. deepdoctection/datapoint/annotation.py +33 -12
  9. deepdoctection/datapoint/box.py +1 -4
  10. deepdoctection/datapoint/convert.py +3 -1
  11. deepdoctection/datapoint/image.py +66 -29
  12. deepdoctection/datapoint/view.py +57 -25
  13. deepdoctection/datasets/adapter.py +1 -1
  14. deepdoctection/datasets/base.py +83 -10
  15. deepdoctection/datasets/dataflow_builder.py +1 -1
  16. deepdoctection/datasets/info.py +2 -2
  17. deepdoctection/datasets/instances/layouttest.py +2 -7
  18. deepdoctection/eval/accmetric.py +1 -1
  19. deepdoctection/eval/base.py +5 -4
  20. deepdoctection/eval/eval.py +2 -2
  21. deepdoctection/eval/tp_eval_callback.py +5 -4
  22. deepdoctection/extern/base.py +39 -13
  23. deepdoctection/extern/d2detect.py +164 -64
  24. deepdoctection/extern/deskew.py +32 -7
  25. deepdoctection/extern/doctrocr.py +227 -39
  26. deepdoctection/extern/fastlang.py +45 -7
  27. deepdoctection/extern/hfdetr.py +90 -33
  28. deepdoctection/extern/hflayoutlm.py +109 -22
  29. deepdoctection/extern/pdftext.py +2 -1
  30. deepdoctection/extern/pt/ptutils.py +3 -2
  31. deepdoctection/extern/tessocr.py +134 -22
  32. deepdoctection/extern/texocr.py +2 -0
  33. deepdoctection/extern/tp/tpcompat.py +4 -4
  34. deepdoctection/extern/tp/tpfrcnn/preproc.py +2 -7
  35. deepdoctection/extern/tpdetect.py +50 -23
  36. deepdoctection/mapper/d2struct.py +1 -1
  37. deepdoctection/mapper/hfstruct.py +1 -1
  38. deepdoctection/mapper/laylmstruct.py +1 -1
  39. deepdoctection/mapper/maputils.py +13 -2
  40. deepdoctection/mapper/prodigystruct.py +1 -1
  41. deepdoctection/mapper/pubstruct.py +10 -10
  42. deepdoctection/mapper/tpstruct.py +1 -1
  43. deepdoctection/pipe/anngen.py +35 -8
  44. deepdoctection/pipe/base.py +53 -19
  45. deepdoctection/pipe/cell.py +29 -8
  46. deepdoctection/pipe/common.py +12 -4
  47. deepdoctection/pipe/doctectionpipe.py +2 -2
  48. deepdoctection/pipe/language.py +3 -2
  49. deepdoctection/pipe/layout.py +3 -2
  50. deepdoctection/pipe/lm.py +2 -2
  51. deepdoctection/pipe/refine.py +18 -10
  52. deepdoctection/pipe/segment.py +21 -16
  53. deepdoctection/pipe/text.py +14 -8
  54. deepdoctection/pipe/transform.py +16 -9
  55. deepdoctection/train/d2_frcnn_train.py +15 -12
  56. deepdoctection/train/hf_detr_train.py +8 -6
  57. deepdoctection/train/hf_layoutlm_train.py +16 -11
  58. deepdoctection/utils/__init__.py +3 -0
  59. deepdoctection/utils/concurrency.py +1 -1
  60. deepdoctection/utils/context.py +2 -2
  61. deepdoctection/utils/env_info.py +55 -22
  62. deepdoctection/utils/error.py +84 -0
  63. deepdoctection/utils/file_utils.py +4 -15
  64. deepdoctection/utils/fs.py +7 -7
  65. deepdoctection/utils/pdf_utils.py +5 -4
  66. deepdoctection/utils/settings.py +5 -1
  67. deepdoctection/utils/transform.py +1 -1
  68. deepdoctection/utils/utils.py +0 -6
  69. deepdoctection/utils/viz.py +44 -2
  70. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/METADATA +33 -58
  71. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/RECORD +74 -73
  72. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/WHEEL +1 -1
  73. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/LICENSE +0 -0
  74. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/top_level.txt +0 -0
@@ -53,7 +53,7 @@ import re
53
53
  import subprocess
54
54
  import sys
55
55
  from collections import defaultdict
56
- from typing import List, Optional, Tuple
56
+ from typing import List, Literal, Optional, Tuple
57
57
 
58
58
  import numpy as np
59
59
  from tabulate import tabulate
@@ -420,7 +420,7 @@ def collect_env_info() -> str:
420
420
  try:
421
421
  import prctl # type: ignore
422
422
 
423
- _ = prctl.set_pdeathsig # noqa
423
+ _ = prctl.set_pdeathsig # pylint: disable=E1101
424
424
  except ModuleNotFoundError:
425
425
  has_prctl = False
426
426
  data.append(("python-prctl", str(has_prctl)))
@@ -452,6 +452,20 @@ def collect_env_info() -> str:
452
452
  return env_str
453
453
 
454
454
 
455
+ def set_env(name: str, value: str) -> None:
456
+ """
457
+ Set an environment variable if it is not already set.
458
+
459
+ :param name: The name of the environment variable
460
+ :param value: The value of the environment variable
461
+ """
462
+
463
+ if os.environ.get(name):
464
+ return
465
+ os.environ[name] = value
466
+ return
467
+
468
+
455
469
  def auto_select_lib_and_device() -> None:
456
470
  """
457
471
  Select the DL library and subsequently the device.
@@ -461,41 +475,60 @@ def auto_select_lib_and_device() -> None:
461
475
  is not installed raise ImportError.
462
476
  """
463
477
 
478
+ # USE_TF and USE_TORCH are env variables that steer DL library selection for Doctr.
464
479
  if tf_available() and tensorpack_available():
465
480
  from tensorpack.utils.gpu import get_num_gpu # pylint: disable=E0401
466
481
 
467
482
  if get_num_gpu() >= 1:
468
- os.environ["USE_TENSORFLOW"] = "True"
469
- os.environ["USE_PYTORCH"] = "False"
470
- os.environ["USE_CUDA"] = "True"
471
- os.environ["USE_MPS"] = "False"
483
+ set_env("USE_TENSORFLOW", "True")
484
+ set_env("USE_PYTORCH", "False")
485
+ set_env("USE_CUDA", "True")
486
+ set_env("USE_MPS", "False")
487
+ set_env("USE_TF", "TRUE")
488
+ set_env("USE_TORCH", "False")
472
489
  return
473
490
  if pytorch_available():
474
- os.environ["USE_TENSORFLOW"] = "False"
475
- os.environ["USE_PYTORCH"] = "True"
476
- os.environ["USE_CUDA"] = "False"
491
+ set_env("USE_TENSORFLOW", "False")
492
+ set_env("USE_PYTORCH", "True")
493
+ set_env("USE_CUDA", "False")
494
+ set_env("USE_TF", "False")
495
+ set_env("USE_TORCH", "TRUE")
477
496
  return
478
497
  logger.warning(
479
498
  LoggingRecord("You have Tensorflow installed but no GPU is available. All Tensorflow models require a GPU.")
480
499
  )
500
+ if tf_available():
501
+ set_env("USE_TENSORFLOW", "False")
502
+ set_env("USE_PYTORCH", "False")
503
+ set_env("USE_CUDA", "False")
504
+ set_env("USE_TF", "AUTO")
505
+ set_env("USE_TORCH", "AUTO")
506
+ return
507
+
481
508
  if pytorch_available():
482
509
  import torch
483
510
 
484
511
  if torch.cuda.is_available():
485
- os.environ["USE_TENSORFLOW"] = "False"
486
- os.environ["USE_PYTORCH"] = "True"
487
- os.environ["USE_CUDA"] = "True"
512
+ set_env("USE_TENSORFLOW", "False")
513
+ set_env("USE_PYTORCH", "True")
514
+ set_env("USE_CUDA", "True")
515
+ set_env("USE_TF", "False")
516
+ set_env("USE_TORCH", "TRUE")
488
517
  return
489
518
  if torch.backends.mps.is_available():
490
- os.environ["USE_TENSORFLOW"] = "False"
491
- os.environ["USE_PYTORCH"] = "True"
492
- os.environ["USE_CUDA"] = "False"
493
- os.environ["USE_MPS"] = "True"
519
+ set_env("USE_TENSORFLOW", "False")
520
+ set_env("USE_PYTORCH", "True")
521
+ set_env("USE_CUDA", "False")
522
+ set_env("USE_MPS", "True")
523
+ set_env("USE_TF", "False")
524
+ set_env("USE_TORCH", "TRUE")
494
525
  return
495
- os.environ["USE_TENSORFLOW"] = "False"
496
- os.environ["USE_PYTORCH"] = "True"
497
- os.environ["USE_CUDA"] = "False"
498
- os.environ["USE_MPS"] = "False"
526
+ set_env("USE_TENSORFLOW", "False")
527
+ set_env("USE_PYTORCH", "True")
528
+ set_env("USE_CUDA", "False")
529
+ set_env("USE_MPS", "False")
530
+ set_env("USE_TF", "AUTO")
531
+ set_env("USE_TORCH", "AUTO")
499
532
  return
500
533
  logger.warning(
501
534
  LoggingRecord(
@@ -505,7 +538,7 @@ def auto_select_lib_and_device() -> None:
505
538
  )
506
539
 
507
540
 
508
- def get_device(ignore_cpu: bool = True) -> str:
541
+ def get_device(ignore_cpu: bool = True) -> Literal["cuda", "mps", "cpu"]:
509
542
  """
510
543
  Device checks for running PyTorch with CUDA, MPS or optionall CPU.
511
544
  If nothing can be found and if `disable_cpu` is deactivated it will raise a `ValueError`
@@ -520,7 +553,7 @@ def get_device(ignore_cpu: bool = True) -> str:
520
553
  return "mps"
521
554
  if not ignore_cpu:
522
555
  return "cpu"
523
- raise ValueError("Could not find either GPU nor MPS")
556
+ raise RuntimeWarning("Could not find either GPU nor MPS")
524
557
 
525
558
 
526
559
  def auto_select_viz_library() -> None:
@@ -0,0 +1,84 @@
1
+ # -*- coding: utf-8 -*-
2
+ # File: error.py
3
+
4
+ # Copyright 2024 Dr. Janis Meyer. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Module for custom exceptions
20
+ """
21
+
22
+
23
+ class BoundingBoxError(BaseException):
24
+ """Special exception only for `datapoint.box.BoundingBox`"""
25
+
26
+
27
+ class AnnotationError(BaseException):
28
+ """Special exception only for `datapoint.annotation.Annotation`"""
29
+
30
+
31
+ class ImageError(BaseException):
32
+ """Special exception only for `datapoint.image.Image`"""
33
+
34
+
35
+ class UUIDError(BaseException):
36
+ """Special exception only for `utils.identifier`"""
37
+
38
+
39
+ class DependencyError(BaseException):
40
+ """Special exception only for missing dependencies. We do not use the internals ImportError or
41
+ ModuleNotFoundError."""
42
+
43
+
44
+ class DataFlowTerminatedError(BaseException):
45
+ """
46
+ An exception indicating that the DataFlow is unable to produce any more
47
+ data, i.e. something wrong happened so that calling `__iter__`
48
+ cannot give a valid iterator anymore.
49
+ In most DataFlow this will never be raised.
50
+ """
51
+
52
+
53
+ class DataFlowResetStateNotCalledError(BaseException):
54
+ """
55
+ An exception indicating that `reset_state()` has not been called before starting
56
+ iteration.
57
+ """
58
+
59
+ def __init__(self) -> None:
60
+ super().__init__("Iterating a dataflow requires .reset_state() to be called first")
61
+
62
+
63
+ class MalformedData(BaseException):
64
+ """
65
+ Exception class for malformed data. Use this class if something does not look right with the data
66
+ """
67
+
68
+
69
+ class FileExtensionError(BaseException):
70
+ """
71
+ Exception class for wrong file extensions.
72
+ """
73
+
74
+
75
+ class TesseractError(RuntimeError):
76
+ """
77
+ Tesseract Error
78
+ """
79
+
80
+ def __init__(self, status: int, message: str) -> None:
81
+ super().__init__()
82
+ self.status = status
83
+ self.message = message
84
+ self.args = (status, message)
@@ -22,6 +22,7 @@ import importlib_metadata
22
22
  from packaging import version
23
23
 
24
24
  from .detection_types import Requirement
25
+ from .error import DependencyError
25
26
  from .logger import LoggingRecord, logger
26
27
  from .metacfg import AttrDict
27
28
 
@@ -263,7 +264,7 @@ def set_tesseract_path(tesseract_path: str) -> None:
263
264
  :param tesseract_path: Tesseract installation path.
264
265
  """
265
266
  if tesseract_path is None:
266
- raise ValueError("tesseract_path is empty.")
267
+ raise TypeError("tesseract_path cannot be None")
267
268
 
268
269
  global _TESS_AVAILABLE # pylint: disable=W0603
269
270
  global _TESS_PATH # pylint: disable=W0603
@@ -288,12 +289,6 @@ def tesseract_available() -> bool:
288
289
  # copy paste from https://github.com/madmaze/pytesseract/blob/master/pytesseract/pytesseract.py
289
290
 
290
291
 
291
- class TesseractNotFound(BaseException):
292
- """
293
- Exception class for Tesseract being not found
294
- """
295
-
296
-
297
292
  def get_tesseract_version() -> Union[int, version.Version]:
298
293
  """
299
294
  Returns Version object of the Tesseract version. We need at least Tesseract 3.05
@@ -306,7 +301,7 @@ def get_tesseract_version() -> Union[int, version.Version]:
306
301
  stdin=subprocess.DEVNULL,
307
302
  )
308
303
  except OSError:
309
- raise TesseractNotFound(_TESS_ERR_MSG) from OSError
304
+ raise DependencyError(_TESS_ERR_MSG) from OSError
310
305
 
311
306
  raw_version = output.decode("utf-8")
312
307
  str_version, *_ = raw_version.lstrip(string.printable[10:]).partition(" ")
@@ -348,12 +343,6 @@ def pdf_to_cairo_available() -> bool:
348
343
  return bool(_PDF_TO_CAIRO_AVAILABLE)
349
344
 
350
345
 
351
- class PopplerNotFound(BaseException):
352
- """
353
- Exception class for Poppler being not found
354
- """
355
-
356
-
357
346
  def get_poppler_version() -> Union[int, version.Version]:
358
347
  """
359
348
  Returns Version object of the Poppler version. We need at least Tesseract 3.05
@@ -371,7 +360,7 @@ def get_poppler_version() -> Union[int, version.Version]:
371
360
  [command, "-v"], stderr=subprocess.STDOUT, env=environ, stdin=subprocess.DEVNULL
372
361
  )
373
362
  except OSError:
374
- raise PopplerNotFound() from OSError
363
+ raise DependencyError(_POPPLER_ERR_MSG) from OSError
375
364
 
376
365
  raw_version = output.decode("utf-8")
377
366
  list_version = raw_version.split("\n", maxsplit=1)[0].split(" ")[-1].split(".")
@@ -34,7 +34,7 @@ from .logger import LoggingRecord, logger
34
34
  from .pdf_utils import get_pdf_file_reader, get_pdf_file_writer
35
35
  from .settings import CONFIGS, DATASET_DIR, MODEL_DIR, PATH
36
36
  from .tqdm import get_tqdm
37
- from .utils import FileExtensionError, is_file_extension
37
+ from .utils import is_file_extension
38
38
  from .viz import viz_handler
39
39
 
40
40
  __all__ = [
@@ -44,9 +44,7 @@ __all__ = [
44
44
  "maybe_path_or_pdf",
45
45
  "download",
46
46
  "mkdir_p",
47
- "is_file_extension",
48
47
  "load_json",
49
- "FileExtensionError",
50
48
  "sub_path",
51
49
  "get_package_path",
52
50
  "get_configs_dir_path",
@@ -125,8 +123,8 @@ def download(url: str, directory: Pathlike, file_name: Optional[str] = None, exp
125
123
  assert size > 0, f"Downloaded an empty file from {url}!"
126
124
 
127
125
  if expect_size is not None and size != expect_size:
128
- logger.error(LoggingRecord(f"File downloaded from {url} does not match the expected size!"))
129
- logger.error(
126
+ logger.warning(LoggingRecord(f"File downloaded from {url} does not match the expected size!"))
127
+ logger.warning(
130
128
  LoggingRecord("You may have downloaded a broken file, or the upstream may have modified the file.")
131
129
  )
132
130
 
@@ -210,13 +208,15 @@ def get_load_image_func(
210
208
  :return: The function loading the file (and converting to its desired format)
211
209
  """
212
210
 
213
- assert is_file_extension(path, [".png", ".jpeg", ".jpg", ".pdf", ".tif"]), f"image type not allowed: {path}"
211
+ assert is_file_extension(path, [".png", ".jpeg", ".jpg", ".pdf", ".tif"]), f"image type not allowed: " f"{path}"
214
212
 
215
213
  if is_file_extension(path, [".png", ".jpeg", ".jpg", ".tif"]):
216
214
  return load_image_from_file
217
215
  if is_file_extension(path, [".pdf"]):
218
216
  return load_bytes_from_pdf_file
219
- return NotImplemented
217
+ raise NotImplementedError(
218
+ "File extension not supported by any loader. Please specify a file type and raise an issue"
219
+ )
220
220
 
221
221
 
222
222
  def maybe_path_or_pdf(path: Pathlike) -> int:
@@ -32,9 +32,10 @@ from pypdf import PdfReader, PdfWriter, errors
32
32
 
33
33
  from .context import save_tmp_file, timeout_manager
34
34
  from .detection_types import ImageType, Pathlike
35
- from .file_utils import PopplerNotFound, pdf_to_cairo_available, pdf_to_ppm_available, qpdf_available
35
+ from .error import DependencyError, FileExtensionError
36
+ from .file_utils import pdf_to_cairo_available, pdf_to_ppm_available, qpdf_available
36
37
  from .logger import LoggingRecord, logger
37
- from .utils import FileExtensionError, is_file_extension
38
+ from .utils import is_file_extension
38
39
  from .viz import viz_handler
39
40
 
40
41
  __all__ = ["decrypt_pdf_document", "get_pdf_file_reader", "get_pdf_file_writer", "PDFStreamer", "pdf_to_np_array"]
@@ -165,7 +166,7 @@ def _input_to_cli_str(
165
166
  elif pdf_to_cairo_available():
166
167
  command = "pdftocairo"
167
168
  else:
168
- raise PopplerNotFound("Poppler not found. Please install or add to your PATH.")
169
+ raise DependencyError("Poppler not found. Please install or add to your PATH.")
169
170
 
170
171
  if platform.system() == "Windows":
171
172
  command = command + ".exe"
@@ -201,7 +202,7 @@ def _run_poppler(poppler_args: List[str]) -> None:
201
202
  except OSError as error:
202
203
  if error.errno != ENOENT:
203
204
  raise error from error
204
- raise PopplerNotFound("Poppler not found. Please install or add to your PATH.") from error
205
+ raise DependencyError("Poppler not found. Please install or add to your PATH.") from error
205
206
 
206
207
  with timeout_manager(proc, 0):
207
208
  if proc.returncode:
@@ -65,6 +65,7 @@ class PageType(ObjectTypes):
65
65
 
66
66
  document_type = "document_type"
67
67
  language = "language"
68
+ angle = "angle"
68
69
 
69
70
 
70
71
  @object_types_registry.register("SummaryType")
@@ -125,6 +126,7 @@ class LayoutType(ObjectTypes):
125
126
  column = "column"
126
127
  word = "word"
127
128
  line = "line"
129
+ background = "background"
128
130
 
129
131
 
130
132
  @object_types_registry.register("TableType")
@@ -324,7 +326,9 @@ def token_class_tag_to_token_class_with_tag(token: ObjectTypes, tag: ObjectTypes
324
326
  """
325
327
  if isinstance(token, TokenClasses) and isinstance(tag, BioTag):
326
328
  return _TOKEN_AND_TAG_TO_TOKEN_CLASS_WITH_TAG[(token, tag)]
327
- raise TypeError("Token must be of type TokenClasses and tag must be of type BioTag")
329
+ raise TypeError(
330
+ f"Token must be of type TokenClasses, is of {type(token)} and tag " f"{type(tag)} must be of type BioTag"
331
+ )
328
332
 
329
333
 
330
334
  def token_class_with_tag_to_token_class_and_tag(
@@ -47,7 +47,7 @@ class BaseTransform(ABC):
47
47
  @abstractmethod
48
48
  def apply_image(self, img: ImageType) -> ImageType:
49
49
  """The transformation that should be applied to the image"""
50
- raise NotImplementedError
50
+ raise NotImplementedError()
51
51
 
52
52
 
53
53
  class ResizeTransform(BaseTransform):
@@ -144,12 +144,6 @@ def get_rng(obj: Any = None) -> np.random.RandomState:
144
144
  return np.random.RandomState(seed)
145
145
 
146
146
 
147
- class FileExtensionError(BaseException):
148
- """
149
- An exception indicating that a file does not seem to have an expected type
150
- """
151
-
152
-
153
147
  def is_file_extension(file_name: Pathlike, extension: Union[str, Sequence[str]]) -> bool:
154
148
  """
155
149
  Check if a given file name has a given extension
@@ -38,6 +38,7 @@ from numpy import float32, uint8
38
38
 
39
39
  from .detection_types import ImageType
40
40
  from .env_info import auto_select_viz_library
41
+ from .error import DependencyError
41
42
  from .file_utils import get_opencv_requirement, get_pillow_requirement, opencv_available, pillow_available
42
43
 
43
44
  if opencv_available():
@@ -307,6 +308,7 @@ class VizPackageHandler:
307
308
  "draw_text": "_cv2_draw_text",
308
309
  "interactive_imshow": "_cv2_interactive_imshow",
309
310
  "encode": "_cv2_encode",
311
+ "rotate_image": "_cv2_rotate_image",
310
312
  },
311
313
  "pillow": {
312
314
  "read_image": "_pillow_read_image",
@@ -319,6 +321,7 @@ class VizPackageHandler:
319
321
  "draw_text": "_pillow_draw_text",
320
322
  "interactive_imshow": "_pillow_interactive_imshow",
321
323
  "encode": "_pillow_encode",
324
+ "rotate_image": "_pillow_rotate_image",
322
325
  },
323
326
  }
324
327
 
@@ -352,12 +355,12 @@ class VizPackageHandler:
352
355
  if maybe_cv2:
353
356
  requirements = get_opencv_requirement()
354
357
  if not requirements[1]:
355
- raise ImportError(requirements[2])
358
+ raise DependencyError(requirements[2])
356
359
  return maybe_cv2
357
360
 
358
361
  requirements = get_pillow_requirement()
359
362
  if not requirements[1]:
360
- raise ImportError(requirements[2])
363
+ raise DependencyError(requirements[2])
361
364
  return "pillow"
362
365
 
363
366
  def _set_vars(self, package: str) -> None:
@@ -690,6 +693,45 @@ class VizPackageHandler:
690
693
  pil_image = Image.fromarray(np.uint8(np_image[:, :, ::-1]))
691
694
  pil_image.show(name)
692
695
 
696
+ def rotate_image(self, np_image: ImageType, angle: int) -> ImageType:
697
+ """Rotating an image by some angle"""
698
+ return getattr(self, self.pkg_func_dict["rotate_image"])(np_image, angle)
699
+
700
+ @staticmethod
701
+ def _cv2_rotate_image(np_image: ImageType, angle: float) -> ImageType:
702
+ # copy & paste from https://stackoverflow.com/questions/43892506
703
+ # /opencv-python-rotate-image-without-cropping-sides
704
+
705
+ height, width = np_image.shape[:2]
706
+ image_center = (width / 2, height / 2)
707
+ rotation_mat = cv2.getRotationMatrix2D(center=image_center, angle=angle, scale=1.0)
708
+
709
+ # rotation calculates the cos and sin, taking absolutes of those.
710
+ abs_cos = abs(rotation_mat[0, 0])
711
+ abs_sin = abs(rotation_mat[0, 1])
712
+
713
+ # find the new width and height bounds
714
+ bound_w = int(height * abs_sin + width * abs_cos)
715
+ bound_h = int(height * abs_cos + width * abs_sin)
716
+
717
+ # subtract old image center (bringing image back to origo) and adding the new image center coordinates
718
+ rotation_mat[0, 2] += bound_w / 2 - image_center[0]
719
+ rotation_mat[1, 2] += bound_h / 2 - image_center[1]
720
+
721
+ np_image = cv2.warpAffine( # type: ignore
722
+ src=np_image,
723
+ M=rotation_mat,
724
+ dsize=(bound_w, bound_h),
725
+ )
726
+
727
+ return np_image
728
+
729
+ @staticmethod
730
+ def _pillow_rotate_image(np_image: ImageType, angle: int) -> ImageType:
731
+ pil_image = Image.fromarray(np.uint8(np_image[:, :, ::-1]))
732
+ pil_image_rotated = pil_image.rotate(angle, expand=True)
733
+ return np.array(pil_image_rotated)[:, :, ::-1]
734
+
693
735
 
694
736
  auto_select_viz_library()
695
737
  viz_handler = VizPackageHandler()