deepdoctection-0.31-py3-none-any.whl → deepdoctection-0.33-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (131)
  1. deepdoctection/__init__.py +16 -29
  2. deepdoctection/analyzer/dd.py +70 -59
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/common.py +9 -5
  5. deepdoctection/dataflow/custom.py +5 -5
  6. deepdoctection/dataflow/custom_serialize.py +75 -18
  7. deepdoctection/dataflow/parallel_map.py +3 -3
  8. deepdoctection/dataflow/serialize.py +4 -4
  9. deepdoctection/dataflow/stats.py +3 -3
  10. deepdoctection/datapoint/annotation.py +41 -56
  11. deepdoctection/datapoint/box.py +9 -8
  12. deepdoctection/datapoint/convert.py +6 -6
  13. deepdoctection/datapoint/image.py +56 -44
  14. deepdoctection/datapoint/view.py +245 -150
  15. deepdoctection/datasets/__init__.py +1 -4
  16. deepdoctection/datasets/adapter.py +35 -26
  17. deepdoctection/datasets/base.py +14 -12
  18. deepdoctection/datasets/dataflow_builder.py +3 -3
  19. deepdoctection/datasets/info.py +24 -26
  20. deepdoctection/datasets/instances/doclaynet.py +51 -51
  21. deepdoctection/datasets/instances/fintabnet.py +46 -46
  22. deepdoctection/datasets/instances/funsd.py +25 -24
  23. deepdoctection/datasets/instances/iiitar13k.py +13 -10
  24. deepdoctection/datasets/instances/layouttest.py +4 -3
  25. deepdoctection/datasets/instances/publaynet.py +5 -5
  26. deepdoctection/datasets/instances/pubtables1m.py +24 -21
  27. deepdoctection/datasets/instances/pubtabnet.py +32 -30
  28. deepdoctection/datasets/instances/rvlcdip.py +30 -30
  29. deepdoctection/datasets/instances/xfund.py +26 -26
  30. deepdoctection/datasets/save.py +6 -6
  31. deepdoctection/eval/__init__.py +1 -4
  32. deepdoctection/eval/accmetric.py +32 -33
  33. deepdoctection/eval/base.py +8 -9
  34. deepdoctection/eval/cocometric.py +15 -13
  35. deepdoctection/eval/eval.py +41 -37
  36. deepdoctection/eval/tedsmetric.py +30 -23
  37. deepdoctection/eval/tp_eval_callback.py +16 -19
  38. deepdoctection/extern/__init__.py +2 -7
  39. deepdoctection/extern/base.py +339 -134
  40. deepdoctection/extern/d2detect.py +85 -113
  41. deepdoctection/extern/deskew.py +14 -11
  42. deepdoctection/extern/doctrocr.py +141 -130
  43. deepdoctection/extern/fastlang.py +27 -18
  44. deepdoctection/extern/hfdetr.py +71 -62
  45. deepdoctection/extern/hflayoutlm.py +504 -211
  46. deepdoctection/extern/hflm.py +230 -0
  47. deepdoctection/extern/model.py +488 -302
  48. deepdoctection/extern/pdftext.py +23 -19
  49. deepdoctection/extern/pt/__init__.py +1 -3
  50. deepdoctection/extern/pt/nms.py +6 -2
  51. deepdoctection/extern/pt/ptutils.py +29 -19
  52. deepdoctection/extern/tessocr.py +39 -38
  53. deepdoctection/extern/texocr.py +18 -18
  54. deepdoctection/extern/tp/tfutils.py +57 -9
  55. deepdoctection/extern/tp/tpcompat.py +21 -14
  56. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
  60. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  61. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
  62. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
  67. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
  68. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  69. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  70. deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
  71. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  72. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  73. deepdoctection/extern/tpdetect.py +45 -53
  74. deepdoctection/mapper/__init__.py +3 -8
  75. deepdoctection/mapper/cats.py +27 -29
  76. deepdoctection/mapper/cocostruct.py +10 -10
  77. deepdoctection/mapper/d2struct.py +27 -26
  78. deepdoctection/mapper/hfstruct.py +13 -8
  79. deepdoctection/mapper/laylmstruct.py +178 -37
  80. deepdoctection/mapper/maputils.py +12 -11
  81. deepdoctection/mapper/match.py +2 -2
  82. deepdoctection/mapper/misc.py +11 -9
  83. deepdoctection/mapper/pascalstruct.py +4 -4
  84. deepdoctection/mapper/prodigystruct.py +5 -5
  85. deepdoctection/mapper/pubstruct.py +84 -92
  86. deepdoctection/mapper/tpstruct.py +5 -5
  87. deepdoctection/mapper/xfundstruct.py +33 -33
  88. deepdoctection/pipe/__init__.py +1 -1
  89. deepdoctection/pipe/anngen.py +12 -14
  90. deepdoctection/pipe/base.py +52 -106
  91. deepdoctection/pipe/common.py +72 -59
  92. deepdoctection/pipe/concurrency.py +16 -11
  93. deepdoctection/pipe/doctectionpipe.py +24 -21
  94. deepdoctection/pipe/language.py +20 -25
  95. deepdoctection/pipe/layout.py +20 -16
  96. deepdoctection/pipe/lm.py +75 -105
  97. deepdoctection/pipe/order.py +194 -89
  98. deepdoctection/pipe/refine.py +111 -124
  99. deepdoctection/pipe/segment.py +156 -161
  100. deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
  101. deepdoctection/pipe/text.py +37 -36
  102. deepdoctection/pipe/transform.py +19 -16
  103. deepdoctection/train/__init__.py +6 -12
  104. deepdoctection/train/d2_frcnn_train.py +48 -41
  105. deepdoctection/train/hf_detr_train.py +41 -30
  106. deepdoctection/train/hf_layoutlm_train.py +153 -135
  107. deepdoctection/train/tp_frcnn_train.py +32 -31
  108. deepdoctection/utils/concurrency.py +1 -1
  109. deepdoctection/utils/context.py +13 -6
  110. deepdoctection/utils/develop.py +4 -4
  111. deepdoctection/utils/env_info.py +87 -125
  112. deepdoctection/utils/file_utils.py +6 -11
  113. deepdoctection/utils/fs.py +22 -18
  114. deepdoctection/utils/identifier.py +2 -2
  115. deepdoctection/utils/logger.py +16 -15
  116. deepdoctection/utils/metacfg.py +7 -7
  117. deepdoctection/utils/mocks.py +93 -0
  118. deepdoctection/utils/pdf_utils.py +11 -11
  119. deepdoctection/utils/settings.py +185 -181
  120. deepdoctection/utils/tqdm.py +1 -1
  121. deepdoctection/utils/transform.py +14 -9
  122. deepdoctection/utils/types.py +104 -0
  123. deepdoctection/utils/utils.py +7 -7
  124. deepdoctection/utils/viz.py +74 -72
  125. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
  126. deepdoctection-0.33.dist-info/RECORD +146 -0
  127. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
  128. deepdoctection/utils/detection_types.py +0 -68
  129. deepdoctection-0.31.dist-info/RECORD +0 -144
  130. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
  131. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
@@ -18,32 +18,39 @@
18
18
  """
19
19
  Deepdoctection wrappers for DocTr OCR text line detection and text recognition models
20
20
  """
21
+ from __future__ import annotations
22
+
21
23
  import os
22
24
  from abc import ABC
23
25
  from pathlib import Path
24
- from typing import Any, List, Literal, Mapping, Optional, Tuple
26
+ from typing import Any, Literal, Mapping, Optional, Union
25
27
  from zipfile import ZipFile
26
28
 
27
- from ..utils.detection_types import ImageType, Requirement
28
- from ..utils.env_info import get_device
29
+ from lazy_imports import try_import
30
+
31
+ from ..utils.env_info import ENV_VARS_TRUE
29
32
  from ..utils.error import DependencyError
30
33
  from ..utils.file_utils import (
31
- doctr_available,
32
34
  get_doctr_requirement,
33
35
  get_pytorch_requirement,
34
36
  get_tensorflow_requirement,
35
37
  get_tf_addons_requirements,
36
- pytorch_available,
37
- tf_addons_available,
38
- tf_available,
39
38
  )
40
39
  from ..utils.fs import load_json
41
40
  from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
41
+ from ..utils.types import PathLikeOrStr, PixelValues, Requirement
42
42
  from ..utils.viz import viz_handler
43
- from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase, TextRecognizer
44
- from .pt.ptutils import set_torch_auto_device
43
+ from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector, TextRecognizer
44
+ from .pt.ptutils import get_torch_device
45
+ from .tp.tfutils import get_tf_device
46
+
47
+ with try_import() as pt_import_guard:
48
+ import torch
45
49
 
46
- if doctr_available() and ((tf_addons_available() and tf_available()) or pytorch_available()):
50
+ with try_import() as tf_import_guard:
51
+ import tensorflow as tf # type: ignore # pylint: disable=E0401
52
+
53
+ with try_import() as doctr_import_guard:
47
54
  from doctr.models._utils import estimate_orientation
48
55
  from doctr.models.detection.predictor import DetectionPredictor # pylint: disable=W0611
49
56
  from doctr.models.detection.zoo import detection_predictor
@@ -51,70 +58,72 @@ if doctr_available() and ((tf_addons_available() and tf_available()) or pytorch_
51
58
  from doctr.models.recognition.predictor import RecognitionPredictor # pylint: disable=W0611
52
59
  from doctr.models.recognition.zoo import ARCHS, recognition
53
60
 
54
- if pytorch_available():
55
- import torch
56
61
 
57
- if tf_available():
58
- import tensorflow as tf # type: ignore # pylint: disable=E0401
59
-
60
-
61
- def _set_device_str(device: Optional[str] = None) -> str:
62
- if device is not None:
63
- if tf_available():
64
- device = "/" + device.replace("cuda", "gpu") + ":0"
65
- elif pytorch_available():
66
- device = set_torch_auto_device()
67
- else:
68
- device = "/gpu:0" # we impose to install tensorflow-gpu because of Tensorpack models
69
- return device
70
-
71
-
72
- def _load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
73
- if lib == "PT" and pytorch_available():
74
- state_dict = torch.load(path_weights, map_location=device)
62
+ def _get_doctr_requirements() -> list[Requirement]:
63
+ if os.environ.get("DD_USE_TF", "0") in ENV_VARS_TRUE:
64
+ return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
65
+ if os.environ.get("DD_USE_TORCH", "0") in ENV_VARS_TRUE:
66
+ return [get_pytorch_requirement(), get_doctr_requirement()]
67
+ raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextRecognizer")
68
+
69
+
70
+ def _load_model(
71
+ path_weights: PathLikeOrStr,
72
+ doctr_predictor: Union[DetectionPredictor, RecognitionPredictor],
73
+ device: Union[torch.device, tf.device],
74
+ lib: Literal["PT", "TF"],
75
+ ) -> None:
76
+ """Loading a model either in TF or PT. We only shift the model to the device when using PyTorch. The shift of
77
+ the model to the device in Tensorflow is done in the predict function."""
78
+ if lib == "PT":
79
+ state_dict = torch.load(os.fspath(path_weights), map_location=device)
75
80
  for key in list(state_dict.keys()):
76
81
  state_dict["model." + key] = state_dict.pop(key)
77
82
  doctr_predictor.load_state_dict(state_dict)
78
83
  doctr_predictor.to(device)
79
- elif lib == "TF" and tf_available():
84
+ elif lib == "TF":
80
85
  # Unzip the archive
81
86
  params_path = Path(path_weights).parent
82
- is_zip_path = path_weights.endswith(".zip")
87
+ is_zip_path = os.fspath(path_weights).endswith(".zip")
83
88
  if is_zip_path:
84
89
  with ZipFile(path_weights, "r") as file:
85
90
  file.extractall(path=params_path)
86
91
  doctr_predictor.model.load_weights(params_path / "weights")
87
92
  else:
88
- doctr_predictor.model.load_weights(path_weights)
93
+ doctr_predictor.model.load_weights(os.fspath(path_weights))
89
94
 
90
95
 
91
96
  def auto_select_lib_for_doctr() -> Literal["PT", "TF"]:
92
- """Auto select the DL library from the installed and from environment variables"""
93
- if tf_available() and os.environ.get("USE_TF", os.environ.get("USE_TENSORFLOW", False)):
94
- os.environ["USE_TF"] = "TRUE"
95
- return "TF"
96
- if pytorch_available() and os.environ.get("USE_TORCH", os.environ.get("USE_PYTORCH", False)):
97
+ """Auto select the DL library from environment variables"""
98
+ if os.environ.get("USE_TORCH", "0") in ENV_VARS_TRUE:
97
99
  return "PT"
98
- raise DependencyError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextlineDetector")
100
+ if os.environ.get("USE_TF", "0") in ENV_VARS_TRUE:
101
+ return "TF"
102
+ raise DependencyError("At least one of the env variables USE_TORCH or USE_TF must be set.")
99
103
 
100
104
 
101
- def doctr_predict_text_lines(np_img: ImageType, predictor: "DetectionPredictor", device: str) -> List[DetectionResult]:
105
+ def doctr_predict_text_lines(
106
+ np_img: PixelValues, predictor: DetectionPredictor, device: Union[torch.device, tf.device], lib: Literal["TF", "PT"]
107
+ ) -> list[DetectionResult]:
102
108
  """
103
109
  Generating text line DetectionResult based on Doctr DetectionPredictor.
104
110
 
105
111
  :param np_img: Image in np.array.
106
112
  :param predictor: `doctr.models.detection.predictor.DetectionPredictor`
107
113
  :param device: Will only be used in tensorflow settings. Either /gpu:0 or /cpu:0
114
+ :param lib: "TF" or "PT"
108
115
  :return: A list of text line detection results (without text).
109
116
  """
110
- if tf_available() and device is not None:
111
- with tf.device(device):
117
+ if lib == "TF":
118
+ with device:
112
119
  raw_output = predictor([np_img])
113
- else:
120
+ elif lib == "PT":
114
121
  raw_output = predictor([np_img])
122
+ else:
123
+ raise DependencyError("Tensorflow or PyTorch must be installed.")
115
124
  detection_results = [
116
125
  DetectionResult(
117
- box=box[:4].tolist(), class_id=1, score=box[4], absolute_coords=False, class_name=LayoutType.word
126
+ box=box[:4].tolist(), class_id=1, score=box[4], absolute_coords=False, class_name=LayoutType.WORD
118
127
  )
119
128
  for box in raw_output[0]["words"]
120
129
  ]
@@ -122,8 +131,11 @@ def doctr_predict_text_lines(np_img: ImageType, predictor: "DetectionPredictor",
122
131
 
123
132
 
124
133
  def doctr_predict_text(
125
- inputs: List[Tuple[str, ImageType]], predictor: "RecognitionPredictor", device: str
126
- ) -> List[DetectionResult]:
134
+ inputs: list[tuple[str, PixelValues]],
135
+ predictor: RecognitionPredictor,
136
+ device: Union[torch.device, tf.device],
137
+ lib: Literal["TF", "PT"],
138
+ ) -> list[DetectionResult]:
127
139
  """
128
140
  Calls Doctr text recognition model on a batch of numpy arrays (text lines predicted from a text line detector) and
129
141
  returns the recognized text as DetectionResult
@@ -132,15 +144,18 @@ def doctr_predict_text(
132
144
  text line
133
145
  :param predictor: `doctr.models.detection.predictor.RecognitionPredictor`
134
146
  :param device: Will only be used in tensorflow settings. Either /gpu:0 or /cpu:0
147
+ :param lib: "TF" or "PT"
135
148
  :return: A list of DetectionResult containing recognized text.
136
149
  """
137
150
 
138
151
  uuids, images = list(zip(*inputs))
139
- if tf_available() and device is not None:
140
- with tf.device(device):
152
+ if lib == "TF":
153
+ with device:
141
154
  raw_output = predictor(list(images))
142
- else:
155
+ elif lib == "PT":
143
156
  raw_output = predictor(list(images))
157
+ else:
158
+ raise DependencyError("Tensorflow or PyTorch must be installed.")
144
159
  detection_results = [
145
160
  DetectionResult(score=output[1], text=output[0], uuid=uuid) for uuid, output in zip(uuids, raw_output)
146
161
  ]
@@ -150,15 +165,15 @@ def doctr_predict_text(
150
165
  class DoctrTextlineDetectorMixin(ObjectDetector, ABC):
151
166
  """Base class for Doctr textline detector. This class only implements the basic wrapper functions"""
152
167
 
153
- def __init__(self, categories: Mapping[str, TypeOrStr], lib: Optional[Literal["PT", "TF"]] = None):
154
- self.categories = categories # type: ignore
168
+ def __init__(self, categories: Mapping[int, TypeOrStr], lib: Optional[Literal["PT", "TF"]] = None):
169
+ self.categories = ModelCategories(init_categories=categories)
155
170
  self.lib = lib if lib is not None else self.auto_select_lib()
156
171
 
157
- def possible_categories(self) -> List[ObjectTypes]:
158
- return [LayoutType.word]
172
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
173
+ return self.categories.get_categories(as_dict=False)
159
174
 
160
175
  @staticmethod
161
- def get_name(path_weights: str, architecture: str) -> str:
176
+ def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
162
177
  """Returns the name of the model"""
163
178
  return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
164
179
 
@@ -206,9 +221,9 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
206
221
  def __init__(
207
222
  self,
208
223
  architecture: str,
209
- path_weights: str,
210
- categories: Mapping[str, TypeOrStr],
211
- device: Optional[Literal["cpu", "cuda"]] = None,
224
+ path_weights: PathLikeOrStr,
225
+ categories: Mapping[int, TypeOrStr],
226
+ device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
212
227
  lib: Optional[Literal["PT", "TF"]] = None,
213
228
  ) -> None:
214
229
  """
@@ -217,58 +232,54 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
217
232
  https://github.com/mindee/doctr/blob/main/doctr/models/detection/zoo.py#L20
218
233
  :param path_weights: Path to the weights of the model
219
234
  :param categories: A dict with the model output label and value
220
- :param device: "cpu" or "cuda". Will default to "cuda" if the required hardware is available.
235
+ :param device: "cpu" or "cuda" or any tf.device or torch.device. The device must be compatible with the dll
221
236
  :param lib: "TF" or "PT" or None. If None, env variables USE_TENSORFLOW, USE_PYTORCH will be used.
222
237
  """
223
238
  super().__init__(categories, lib)
224
239
  self.architecture = architecture
225
- self.path_weights = path_weights
240
+ self.path_weights = Path(path_weights)
226
241
 
227
242
  self.name = self.get_name(self.path_weights, self.architecture)
228
243
  self.model_id = self.get_model_id()
229
244
 
230
- if device is None:
231
- if self.lib == "TF":
232
- device = "cuda" if tf.test.is_gpu_available() else "cpu"
233
- elif self.lib == "PT":
234
- auto_device = get_device(False)
235
- device = "cpu" if auto_device == "mps" else auto_device
236
- else:
237
- raise DependencyError("Cannot select device automatically. Please set the device manually.")
245
+ if self.lib == "TF":
246
+ self.device = get_tf_device(device)
247
+ if self.lib == "PT":
248
+ self.device = get_torch_device(device)
238
249
 
239
- self.device_input = device
240
- self.device = _set_device_str(device)
241
- self.doctr_predictor = self.get_wrapped_model(self.architecture, self.path_weights, self.device_input, self.lib)
250
+ self.doctr_predictor = self.get_wrapped_model(self.architecture, self.path_weights, self.device, self.lib)
242
251
 
243
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
252
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
244
253
  """
245
254
  Prediction per image.
246
255
 
247
256
  :param np_img: image as numpy array
248
257
  :return: A list of DetectionResult
249
258
  """
250
- detection_results = doctr_predict_text_lines(np_img, self.doctr_predictor, self.device)
251
- return detection_results
259
+ return doctr_predict_text_lines(np_img, self.doctr_predictor, self.device, self.lib)
252
260
 
253
261
  @classmethod
254
- def get_requirements(cls) -> List[Requirement]:
255
- if tf_available():
256
- return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
257
- if pytorch_available():
258
- return [get_pytorch_requirement(), get_doctr_requirement()]
259
- raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextlineDetector")
262
+ def get_requirements(cls) -> list[Requirement]:
263
+ return _get_doctr_requirements()
260
264
 
261
- def clone(self) -> PredictorBase:
262
- return self.__class__(self.architecture, self.path_weights, self.categories, self.device_input, self.lib)
265
+ def clone(self) -> DoctrTextlineDetector:
266
+ return self.__class__(
267
+ self.architecture, self.path_weights, self.categories.get_categories(), self.device, self.lib
268
+ )
263
269
 
264
270
  @staticmethod
265
- def load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
271
+ def load_model(
272
+ path_weights: PathLikeOrStr,
273
+ doctr_predictor: DetectionPredictor,
274
+ device: Union[torch.device, tf.device],
275
+ lib: Literal["PT", "TF"],
276
+ ) -> None:
266
277
  """Loading model weights"""
267
278
  _load_model(path_weights, doctr_predictor, device, lib)
268
279
 
269
280
  @staticmethod
270
281
  def get_wrapped_model(
271
- architecture: str, path_weights: str, device: Literal["cpu", "cuda"], lib: Literal["PT", "TF"]
282
+ architecture: str, path_weights: PathLikeOrStr, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
272
283
  ) -> Any:
273
284
  """
274
285
  Get the inner (wrapped) model.
@@ -286,10 +297,12 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
286
297
  :return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
287
298
  """
288
299
  doctr_predictor = detection_predictor(arch=architecture, pretrained=False, pretrained_backbone=False)
289
- device_str = _set_device_str(device)
290
- DoctrTextlineDetector.load_model(path_weights, doctr_predictor, device_str, lib)
300
+ DoctrTextlineDetector.load_model(path_weights, doctr_predictor, device, lib)
291
301
  return doctr_predictor
292
302
 
303
+ def clear_model(self) -> None:
304
+ self.doctr_predictor = None
305
+
293
306
 
294
307
  class DoctrTextRecognizer(TextRecognizer):
295
308
  """
@@ -325,16 +338,15 @@ class DoctrTextRecognizer(TextRecognizer):
325
338
 
326
339
  for dp in df:
327
340
  ...
328
-
329
341
  """
330
342
 
331
343
  def __init__(
332
344
  self,
333
345
  architecture: str,
334
- path_weights: str,
335
- device: Optional[Literal["cpu", "cuda"]] = None,
346
+ path_weights: PathLikeOrStr,
347
+ device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
336
348
  lib: Optional[Literal["PT", "TF"]] = None,
337
- path_config_json: Optional[str] = None,
349
+ path_config_json: Optional[PathLikeOrStr] = None,
338
350
  ) -> None:
339
351
  """
340
352
  :param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
@@ -350,30 +362,24 @@ class DoctrTextRecognizer(TextRecognizer):
350
362
  self.lib = lib if lib is not None else self.auto_select_lib()
351
363
 
352
364
  self.architecture = architecture
353
- self.path_weights = path_weights
365
+ self.path_weights = Path(path_weights)
354
366
 
355
367
  self.name = self.get_name(self.path_weights, self.architecture)
356
368
  self.model_id = self.get_model_id()
357
369
 
358
- if device is None:
359
- if self.lib == "TF":
360
- device = "cuda" if tf.test.is_gpu_available() else "cpu"
361
- if self.lib == "PT":
362
- auto_device = get_device(False)
363
- device = "cpu" if auto_device == "mps" else auto_device
364
- else:
365
- raise DependencyError("Cannot select device automatically. Please set the device manually.")
366
-
367
- self.device_input = device
368
- self.device = _set_device_str(device)
370
+ if self.lib == "TF":
371
+ self.device = get_tf_device(device)
372
+ if self.lib == "PT":
373
+ self.device = get_torch_device(device)
374
+
369
375
  self.path_config_json = path_config_json
370
- self.doctr_predictor = self.build_model(self.architecture, self.path_config_json)
376
+ self.doctr_predictor = self.build_model(self.architecture, self.lib, self.path_config_json)
371
377
  self.load_model(self.path_weights, self.doctr_predictor, self.device, self.lib)
372
378
  self.doctr_predictor = self.get_wrapped_model(
373
- self.architecture, self.path_weights, self.device_input, self.lib, self.path_config_json
379
+ self.architecture, self.path_weights, self.device, self.lib, self.path_config_json
374
380
  )
375
381
 
376
- def predict(self, images: List[Tuple[str, ImageType]]) -> List[DetectionResult]:
382
+ def predict(self, images: list[tuple[str, PixelValues]]) -> list[DetectionResult]:
377
383
  """
378
384
  Prediction on a batch of text lines
379
385
 
@@ -381,27 +387,30 @@ class DoctrTextRecognizer(TextRecognizer):
381
387
  :return: A list of DetectionResult
382
388
  """
383
389
  if images:
384
- return doctr_predict_text(images, self.doctr_predictor, self.device)
390
+ return doctr_predict_text(images, self.doctr_predictor, self.device, self.lib)
385
391
  return []
386
392
 
387
393
  @classmethod
388
- def get_requirements(cls) -> List[Requirement]:
389
- if tf_available():
390
- return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
391
- if pytorch_available():
392
- return [get_pytorch_requirement(), get_doctr_requirement()]
393
- raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextRecognizer")
394
+ def get_requirements(cls) -> list[Requirement]:
395
+ return _get_doctr_requirements()
394
396
 
395
- def clone(self) -> PredictorBase:
396
- return self.__class__(self.architecture, self.path_weights, self.device_input, self.lib)
397
+ def clone(self) -> DoctrTextRecognizer:
398
+ return self.__class__(self.architecture, self.path_weights, self.device, self.lib)
397
399
 
398
400
  @staticmethod
399
- def load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
401
+ def load_model(
402
+ path_weights: PathLikeOrStr,
403
+ doctr_predictor: RecognitionPredictor,
404
+ device: Union[torch.device, tf.device],
405
+ lib: Literal["PT", "TF"],
406
+ ) -> None:
400
407
  """Loading model weights"""
401
408
  _load_model(path_weights, doctr_predictor, device, lib)
402
409
 
403
410
  @staticmethod
404
- def build_model(architecture: str, path_config_json: Optional[str] = None) -> "RecognitionPredictor":
411
+ def build_model(
412
+ architecture: str, lib: Literal["TF", "PT"], path_config_json: Optional[PathLikeOrStr] = None
413
+ ) -> RecognitionPredictor:
405
414
  """Building the model"""
406
415
 
407
416
  # inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
@@ -424,6 +433,7 @@ class DoctrTextRecognizer(TextRecognizer):
424
433
 
425
434
  model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
426
435
  else:
436
+ # This is not documented, but you can also directly pass the model class to architecture
427
437
  if not isinstance(
428
438
  architecture,
429
439
  (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
@@ -431,16 +441,16 @@ class DoctrTextRecognizer(TextRecognizer):
431
441
  raise ValueError(f"unknown architecture: {type(architecture)}")
432
442
  model = architecture
433
443
 
434
- input_shape = model.cfg["input_shape"][:2] if tf_available() else model.cfg["input_shape"][-2:]
444
+ input_shape = model.cfg["input_shape"][:2] if lib == "TF" else model.cfg["input_shape"][-2:]
435
445
  return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model)
436
446
 
437
447
  @staticmethod
438
448
  def get_wrapped_model(
439
449
  architecture: str,
440
- path_weights: str,
441
- device: Literal["cpu", "cuda"],
450
+ path_weights: PathLikeOrStr,
451
+ device: Union[torch.device, tf.device],
442
452
  lib: Literal["PT", "TF"],
443
- path_config_json: Optional[str] = None,
453
+ path_config_json: Optional[PathLikeOrStr] = None,
444
454
  ) -> Any:
445
455
  """
446
456
  Get the inner (wrapped) model.
@@ -455,13 +465,12 @@ class DoctrTextRecognizer(TextRecognizer):
455
465
  a model trained on custom vocab.
456
466
  :return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
457
467
  """
458
- doctr_predictor = DoctrTextRecognizer.build_model(architecture, path_config_json)
459
- device_str = _set_device_str(device)
460
- DoctrTextRecognizer.load_model(path_weights, doctr_predictor, device_str, lib)
468
+ doctr_predictor = DoctrTextRecognizer.build_model(architecture, lib, path_config_json)
469
+ DoctrTextRecognizer.load_model(path_weights, doctr_predictor, device, lib)
461
470
  return doctr_predictor
462
471
 
463
472
  @staticmethod
464
- def get_name(path_weights: str, architecture: str) -> str:
473
+ def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
465
474
  """Returns the name of the model"""
466
475
  return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
467
476
 
@@ -470,6 +479,9 @@ class DoctrTextRecognizer(TextRecognizer):
470
479
  """Auto select the DL library from the installed and from environment variables"""
471
480
  return auto_select_lib_for_doctr()
472
481
 
482
+ def clear_model(self) -> None:
483
+ self.doctr_predictor = None
484
+
473
485
 
474
486
  class DocTrRotationTransformer(ImageTransformer):
475
487
  """
@@ -503,7 +515,7 @@ class DocTrRotationTransformer(ImageTransformer):
503
515
  self.ratio_threshold_for_lines = ratio_threshold_for_lines
504
516
  self.name = "doctr_rotation_transformer"
505
517
 
506
- def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
518
+ def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
507
519
  """
508
520
  Applies the predicted rotation to the image, effectively rotating the image backwards.
509
521
  This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
@@ -514,19 +526,18 @@ class DocTrRotationTransformer(ImageTransformer):
514
526
  """
515
527
  return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
516
528
 
517
- def predict(self, np_img: ImageType) -> DetectionResult:
529
+ def predict(self, np_img: PixelValues) -> DetectionResult:
518
530
  angle = estimate_orientation(np_img, self.number_contours, self.ratio_threshold_for_lines)
519
531
  if angle < 0:
520
532
  angle += 360
521
533
  return DetectionResult(angle=round(angle, 2))
522
534
 
523
535
  @classmethod
524
- def get_requirements(cls) -> List[Requirement]:
536
+ def get_requirements(cls) -> list[Requirement]:
525
537
  return [get_doctr_requirement()]
526
538
 
527
- def clone(self) -> PredictorBase:
539
+ def clone(self) -> DocTrRotationTransformer:
528
540
  return self.__class__(self.number_contours, self.ratio_threshold_for_lines)
529
541
 
530
- @staticmethod
531
- def possible_category() -> PageType:
532
- return PageType.angle
542
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
543
+ return (PageType.ANGLE,)
@@ -18,16 +18,22 @@
18
18
  """
19
19
  Deepdoctection wrappers for fasttext language detection models
20
20
  """
21
+ from __future__ import annotations
22
+
23
+ import os
21
24
  from abc import ABC
22
- from copy import copy
23
25
  from pathlib import Path
24
- from typing import Any, List, Mapping, Tuple, Union
26
+ from types import MappingProxyType
27
+ from typing import Any, Mapping, Union
28
+
29
+ from lazy_imports import try_import
25
30
 
26
- from ..utils.file_utils import Requirement, fasttext_available, get_fasttext_requirement
31
+ from ..utils.file_utils import Requirement, get_fasttext_requirement
27
32
  from ..utils.settings import TypeOrStr, get_type
28
- from .base import DetectionResult, LanguageDetector, PredictorBase
33
+ from ..utils.types import PathLikeOrStr
34
+ from .base import DetectionResult, LanguageDetector, ModelCategories
29
35
 
30
- if fasttext_available():
36
+ with try_import() as import_guard:
31
37
  from fasttext import load_model # type: ignore
32
38
 
33
39
 
@@ -36,22 +42,23 @@ class FasttextLangDetectorMixin(LanguageDetector, ABC):
36
42
  Base class for Fasttext language detection implementation. This class only implements the basic wrapper functions.
37
43
  """
38
44
 
39
- def __init__(self, categories: Mapping[str, TypeOrStr]) -> None:
45
+ def __init__(self, categories: Mapping[int, TypeOrStr], categories_orig: Mapping[str, TypeOrStr]) -> None:
40
46
  """
41
47
  :param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language
42
48
  """
43
- self.categories = copy({idx: get_type(cat) for idx, cat in categories.items()})
49
+ self.categories = ModelCategories(init_categories=categories)
50
+ self.categories_orig = MappingProxyType({cat_orig: get_type(cat) for cat_orig, cat in categories_orig.items()})
44
51
 
45
- def output_to_detection_result(self, output: Union[Tuple[Any, Any]]) -> DetectionResult:
52
+ def output_to_detection_result(self, output: Union[tuple[Any, Any]]) -> DetectionResult:
46
53
  """
47
54
  Generating `DetectionResult` from model output
48
55
  :param output: FastText model output
49
56
  :return: `DetectionResult` filled with `text` and `score`
50
57
  """
51
- return DetectionResult(text=self.categories[output[0][0]], score=output[1][0])
58
+ return DetectionResult(text=self.categories_orig[output[0][0]], score=output[1][0])
52
59
 
53
60
  @staticmethod
54
- def get_name(path_weights: str) -> str:
61
+ def get_name(path_weights: PathLikeOrStr) -> str:
55
62
  """Returns the name of the model"""
56
63
  return "fasttext_" + "_".join(Path(path_weights).parts[-2:])
57
64
 
@@ -78,15 +85,17 @@ class FasttextLangDetector(FasttextLangDetectorMixin):
78
85
 
79
86
  """
80
87
 
81
- def __init__(self, path_weights: str, categories: Mapping[str, TypeOrStr]):
88
+ def __init__(
89
+ self, path_weights: PathLikeOrStr, categories: Mapping[int, TypeOrStr], categories_orig: Mapping[str, TypeOrStr]
90
+ ):
82
91
  """
83
92
  :param path_weights: path to model weights
84
93
  :param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language
85
94
  code.
86
95
  """
87
- super().__init__(categories)
96
+ super().__init__(categories, categories_orig)
88
97
 
89
- self.path_weights = path_weights
98
+ self.path_weights = Path(path_weights)
90
99
 
91
100
  self.name = self.get_name(self.path_weights)
92
101
  self.model_id = self.get_model_id()
@@ -98,16 +107,16 @@ class FasttextLangDetector(FasttextLangDetectorMixin):
98
107
  return self.output_to_detection_result(output)
99
108
 
100
109
  @classmethod
101
- def get_requirements(cls) -> List[Requirement]:
110
+ def get_requirements(cls) -> list[Requirement]:
102
111
  return [get_fasttext_requirement()]
103
112
 
104
- def clone(self) -> PredictorBase:
105
- return self.__class__(self.path_weights, self.categories)
113
+ def clone(self) -> FasttextLangDetector:
114
+ return self.__class__(self.path_weights, self.categories.get_categories(), self.categories_orig)
106
115
 
107
116
  @staticmethod
108
- def get_wrapped_model(path_weights: str) -> Any:
117
+ def get_wrapped_model(path_weights: PathLikeOrStr) -> Any:
109
118
  """
110
119
  Get the wrapped model
111
120
  :param path_weights: path to model weights
112
121
  """
113
- return load_model(path_weights)
122
+ return load_model(os.fspath(path_weights))