deepdoctection 0.31__py3-none-any.whl → 0.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (91)
  1. deepdoctection/__init__.py +35 -28
  2. deepdoctection/analyzer/dd.py +30 -24
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/datapoint/annotation.py +2 -1
  5. deepdoctection/datapoint/box.py +2 -1
  6. deepdoctection/datapoint/image.py +13 -7
  7. deepdoctection/datapoint/view.py +95 -24
  8. deepdoctection/datasets/__init__.py +1 -4
  9. deepdoctection/datasets/adapter.py +5 -2
  10. deepdoctection/datasets/base.py +5 -3
  11. deepdoctection/datasets/info.py +2 -2
  12. deepdoctection/datasets/instances/doclaynet.py +3 -2
  13. deepdoctection/datasets/instances/fintabnet.py +2 -1
  14. deepdoctection/datasets/instances/funsd.py +2 -1
  15. deepdoctection/datasets/instances/iiitar13k.py +5 -2
  16. deepdoctection/datasets/instances/layouttest.py +2 -1
  17. deepdoctection/datasets/instances/publaynet.py +2 -2
  18. deepdoctection/datasets/instances/pubtables1m.py +6 -3
  19. deepdoctection/datasets/instances/pubtabnet.py +2 -1
  20. deepdoctection/datasets/instances/rvlcdip.py +2 -1
  21. deepdoctection/datasets/instances/xfund.py +2 -1
  22. deepdoctection/eval/__init__.py +1 -4
  23. deepdoctection/eval/cocometric.py +2 -1
  24. deepdoctection/eval/eval.py +17 -13
  25. deepdoctection/eval/tedsmetric.py +14 -11
  26. deepdoctection/eval/tp_eval_callback.py +9 -3
  27. deepdoctection/extern/__init__.py +2 -7
  28. deepdoctection/extern/d2detect.py +24 -32
  29. deepdoctection/extern/deskew.py +4 -2
  30. deepdoctection/extern/doctrocr.py +75 -81
  31. deepdoctection/extern/fastlang.py +4 -2
  32. deepdoctection/extern/hfdetr.py +22 -28
  33. deepdoctection/extern/hflayoutlm.py +335 -103
  34. deepdoctection/extern/hflm.py +225 -0
  35. deepdoctection/extern/model.py +56 -47
  36. deepdoctection/extern/pdftext.py +8 -4
  37. deepdoctection/extern/pt/__init__.py +1 -3
  38. deepdoctection/extern/pt/nms.py +6 -2
  39. deepdoctection/extern/pt/ptutils.py +27 -19
  40. deepdoctection/extern/texocr.py +4 -2
  41. deepdoctection/extern/tp/tfutils.py +43 -9
  42. deepdoctection/extern/tp/tpcompat.py +10 -7
  43. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  44. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  45. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  46. deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
  47. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  48. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
  49. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  50. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  56. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  57. deepdoctection/extern/tp/tpfrcnn/preproc.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  60. deepdoctection/extern/tpdetect.py +5 -8
  61. deepdoctection/mapper/__init__.py +3 -8
  62. deepdoctection/mapper/d2struct.py +8 -6
  63. deepdoctection/mapper/hfstruct.py +6 -1
  64. deepdoctection/mapper/laylmstruct.py +163 -20
  65. deepdoctection/mapper/maputils.py +3 -1
  66. deepdoctection/mapper/misc.py +6 -3
  67. deepdoctection/mapper/tpstruct.py +2 -2
  68. deepdoctection/pipe/__init__.py +1 -1
  69. deepdoctection/pipe/common.py +11 -9
  70. deepdoctection/pipe/concurrency.py +2 -1
  71. deepdoctection/pipe/layout.py +3 -1
  72. deepdoctection/pipe/lm.py +32 -64
  73. deepdoctection/pipe/order.py +142 -35
  74. deepdoctection/pipe/refine.py +8 -14
  75. deepdoctection/pipe/{cell.py → sub_layout.py} +1 -1
  76. deepdoctection/train/__init__.py +6 -12
  77. deepdoctection/train/d2_frcnn_train.py +21 -16
  78. deepdoctection/train/hf_detr_train.py +18 -11
  79. deepdoctection/train/hf_layoutlm_train.py +118 -101
  80. deepdoctection/train/tp_frcnn_train.py +21 -19
  81. deepdoctection/utils/env_info.py +41 -117
  82. deepdoctection/utils/logger.py +1 -0
  83. deepdoctection/utils/mocks.py +93 -0
  84. deepdoctection/utils/settings.py +1 -0
  85. deepdoctection/utils/viz.py +4 -3
  86. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/METADATA +27 -18
  87. deepdoctection-0.32.dist-info/RECORD +146 -0
  88. deepdoctection-0.31.dist-info/RECORD +0 -144
  89. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
  90. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/WHEEL +0 -0
  91. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
@@ -18,43 +18,39 @@
18
18
  """
19
19
  D2 GeneralizedRCNN model as predictor for deepdoctection pipeline
20
20
  """
21
+ from __future__ import annotations
22
+
21
23
  import io
22
24
  from abc import ABC
23
25
  from copy import copy
24
26
  from pathlib import Path
25
- from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence
27
+ from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union
26
28
 
27
29
  import numpy as np
30
+ from lazy_imports import try_import
28
31
 
29
32
  from ..utils.detection_types import ImageType, Requirement
30
- from ..utils.file_utils import (
31
- detectron2_available,
32
- get_detectron2_requirement,
33
- get_pytorch_requirement,
34
- pytorch_available,
35
- )
33
+ from ..utils.file_utils import get_detectron2_requirement, get_pytorch_requirement
36
34
  from ..utils.metacfg import AttrDict, set_config_by_yaml
37
35
  from ..utils.settings import ObjectTypes, TypeOrStr, get_type
38
36
  from ..utils.transform import InferenceResize, ResizeTransform
39
37
  from .base import DetectionResult, ObjectDetector, PredictorBase
40
38
  from .pt.nms import batched_nms
41
- from .pt.ptutils import set_torch_auto_device
39
+ from .pt.ptutils import get_torch_device
42
40
 
43
- if pytorch_available():
41
+ with try_import() as pt_import_guard:
44
42
  import torch
45
43
  import torch.cuda
46
44
  from torch import nn # pylint: disable=W0611
47
45
 
48
- if detectron2_available():
46
+ with try_import() as d2_import_guard:
49
47
  from detectron2.checkpoint import DetectionCheckpointer
50
48
  from detectron2.config import CfgNode, get_cfg # pylint: disable=W0611
51
49
  from detectron2.modeling import GeneralizedRCNN, build_model # pylint: disable=W0611
52
50
  from detectron2.structures import Instances # pylint: disable=W0611
53
51
 
54
52
 
55
- def _d2_post_processing(
56
- predictions: Dict[str, "Instances"], nms_thresh_class_agnostic: float
57
- ) -> Dict[str, "Instances"]:
53
+ def _d2_post_processing(predictions: Dict[str, Instances], nms_thresh_class_agnostic: float) -> Dict[str, Instances]:
58
54
  """
59
55
  D2 postprocessing steps, so that detection outputs are aligned with outputs of other packages (e.g. Tensorpack).
60
56
  Apply a class agnostic NMS.
@@ -72,7 +68,7 @@ def _d2_post_processing(
72
68
 
73
69
  def d2_predict_image(
74
70
  np_img: ImageType,
75
- predictor: "nn.Module",
71
+ predictor: nn.Module,
76
72
  resizer: InferenceResize,
77
73
  nms_thresh_class_agnostic: float,
78
74
  ) -> List[DetectionResult]:
@@ -107,7 +103,7 @@ def d2_predict_image(
107
103
 
108
104
 
109
105
  def d2_jit_predict_image(
110
- np_img: ImageType, d2_predictor: "nn.Module", resizer: InferenceResize, nms_thresh_class_agnostic: float
106
+ np_img: ImageType, d2_predictor: nn.Module, resizer: InferenceResize, nms_thresh_class_agnostic: float
111
107
  ) -> List[DetectionResult]:
112
108
  """
113
109
  Run detection on an image using torchscript. It will also handle the preprocessing internally which
@@ -238,7 +234,7 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
238
234
  path_weights: str,
239
235
  categories: Mapping[str, TypeOrStr],
240
236
  config_overwrite: Optional[List[str]] = None,
241
- device: Optional[Literal["cpu", "cuda"]] = None,
237
+ device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
242
238
  filter_categories: Optional[Sequence[TypeOrStr]] = None,
243
239
  ):
244
240
  """
@@ -266,13 +262,10 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
266
262
 
267
263
  config_overwrite = config_overwrite if config_overwrite else []
268
264
  self.config_overwrite = config_overwrite
269
- if device is not None:
270
- self.device = device
271
- else:
272
- self.device = set_torch_auto_device()
265
+ self.device = get_torch_device(device)
273
266
 
274
267
  d2_conf_list = self._get_d2_config_list(path_weights, config_overwrite)
275
- self.cfg = self._set_config(path_yaml, d2_conf_list, device)
268
+ self.cfg = self._set_config(path_yaml, d2_conf_list, self.device)
276
269
 
277
270
  self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
278
271
  self.model_id = self.get_model_id()
@@ -282,21 +275,18 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
282
275
  self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
283
276
 
284
277
  @staticmethod
285
- def _set_config(
286
- path_yaml: str, d2_conf_list: List[str], device: Optional[Literal["cpu", "cuda"]] = None
287
- ) -> "CfgNode":
278
+ def _set_config(path_yaml: str, d2_conf_list: List[str], device: torch.device) -> CfgNode:
288
279
  cfg = get_cfg()
289
280
  # additional attribute with default value, so that the true value can be loaded from the configs
290
281
  cfg.NMS_THRESH_CLASS_AGNOSTIC = 0.1
291
282
  cfg.merge_from_file(path_yaml)
292
283
  cfg.merge_from_list(d2_conf_list)
293
- if not torch.cuda.is_available() or device == "cpu":
294
- cfg.MODEL.DEVICE = "cpu"
284
+ cfg.MODEL.DEVICE = str(device)
295
285
  cfg.freeze()
296
286
  return cfg
297
287
 
298
288
  @staticmethod
299
- def _set_model(config: "CfgNode") -> "GeneralizedRCNN":
289
+ def _set_model(config: CfgNode) -> GeneralizedRCNN:
300
290
  """
301
291
  Build the D2 model. It uses the available builtin tools of D2
302
292
 
@@ -306,7 +296,7 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
306
296
  return build_model(config.clone()).eval()
307
297
 
308
298
  @staticmethod
309
- def _instantiate_d2_predictor(wrapped_model: "GeneralizedRCNN", path_weights: str) -> None:
299
+ def _instantiate_d2_predictor(wrapped_model: GeneralizedRCNN, path_weights: str) -> None:
310
300
  checkpointer = DetectionCheckpointer(wrapped_model)
311
301
  checkpointer.load(path_weights)
312
302
 
@@ -341,8 +331,11 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
341
331
 
342
332
  @staticmethod
343
333
  def get_wrapped_model(
344
- path_yaml: str, path_weights: str, config_overwrite: List[str], device: Literal["cpu", "cuda"]
345
- ) -> "GeneralizedRCNN":
334
+ path_yaml: str,
335
+ path_weights: str,
336
+ config_overwrite: List[str],
337
+ device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
338
+ ) -> GeneralizedRCNN:
346
339
  """
347
340
  Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
348
341
 
@@ -365,8 +358,7 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
365
358
  :return: Detectron2 GeneralizedRCNN model
366
359
  """
367
360
 
368
- if device is None:
369
- device = set_torch_auto_device()
361
+ device = get_torch_device(device)
370
362
  d2_conf_list = D2FrcnnDetector._get_d2_config_list(path_weights, config_overwrite)
371
363
  cfg = D2FrcnnDetector._set_config(path_yaml, d2_conf_list, device)
372
364
  model = D2FrcnnDetector._set_model(cfg)
@@ -21,13 +21,15 @@ jdeskew estimator and rotator to deskew images: <https://github.com/phamquiluan/
21
21
 
22
22
  from typing import List
23
23
 
24
+ from lazy_imports import try_import
25
+
24
26
  from ..utils.detection_types import ImageType, Requirement
25
- from ..utils.file_utils import get_jdeskew_requirement, jdeskew_available
27
+ from ..utils.file_utils import get_jdeskew_requirement
26
28
  from ..utils.settings import PageType
27
29
  from ..utils.viz import viz_handler
28
30
  from .base import DetectionResult, ImageTransformer
29
31
 
30
- if jdeskew_available():
32
+ with try_import() as import_guard:
31
33
  from jdeskew.estimator import get_angle
32
34
 
33
35
 
@@ -18,32 +18,40 @@
18
18
  """
19
19
  Deepdoctection wrappers for DocTr OCR text line detection and text recognition models
20
20
  """
21
+ from __future__ import annotations
22
+
21
23
  import os
22
24
  from abc import ABC
23
25
  from pathlib import Path
24
- from typing import Any, List, Literal, Mapping, Optional, Tuple
26
+ from typing import Any, List, Literal, Mapping, Optional, Tuple, Union
25
27
  from zipfile import ZipFile
26
28
 
29
+ from lazy_imports import try_import
30
+
27
31
  from ..utils.detection_types import ImageType, Requirement
28
- from ..utils.env_info import get_device
29
32
  from ..utils.error import DependencyError
30
33
  from ..utils.file_utils import (
31
- doctr_available,
32
34
  get_doctr_requirement,
33
35
  get_pytorch_requirement,
34
36
  get_tensorflow_requirement,
35
37
  get_tf_addons_requirements,
36
38
  pytorch_available,
37
- tf_addons_available,
38
39
  tf_available,
39
40
  )
40
41
  from ..utils.fs import load_json
41
42
  from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
42
43
  from ..utils.viz import viz_handler
43
44
  from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase, TextRecognizer
44
- from .pt.ptutils import set_torch_auto_device
45
+ from .pt.ptutils import get_torch_device
46
+ from .tp.tfutils import get_tf_device
47
+
48
+ with try_import() as pt_import_guard:
49
+ import torch
45
50
 
46
- if doctr_available() and ((tf_addons_available() and tf_available()) or pytorch_available()):
51
+ with try_import() as tf_import_guard:
52
+ import tensorflow as tf # type: ignore # pylint: disable=E0401
53
+
54
+ with try_import() as doctr_import_guard:
47
55
  from doctr.models._utils import estimate_orientation
48
56
  from doctr.models.detection.predictor import DetectionPredictor # pylint: disable=W0611
49
57
  from doctr.models.detection.zoo import detection_predictor
@@ -51,32 +59,19 @@ if doctr_available() and ((tf_addons_available() and tf_available()) or pytorch_
51
59
  from doctr.models.recognition.predictor import RecognitionPredictor # pylint: disable=W0611
52
60
  from doctr.models.recognition.zoo import ARCHS, recognition
53
61
 
54
- if pytorch_available():
55
- import torch
56
-
57
- if tf_available():
58
- import tensorflow as tf # type: ignore # pylint: disable=E0401
59
-
60
-
61
- def _set_device_str(device: Optional[str] = None) -> str:
62
- if device is not None:
63
- if tf_available():
64
- device = "/" + device.replace("cuda", "gpu") + ":0"
65
- elif pytorch_available():
66
- device = set_torch_auto_device()
67
- else:
68
- device = "/gpu:0" # we impose to install tensorflow-gpu because of Tensorpack models
69
- return device
70
-
71
62
 
72
- def _load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
73
- if lib == "PT" and pytorch_available():
63
+ def _load_model(
64
+ path_weights: str, doctr_predictor: Any, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
65
+ ) -> None:
66
+ """Loading a model either in TF or PT. We only shift the model to the device when using PyTorch. The shift of
67
+ the model to the device in Tensorflow is done in the predict function."""
68
+ if lib == "PT":
74
69
  state_dict = torch.load(path_weights, map_location=device)
75
70
  for key in list(state_dict.keys()):
76
71
  state_dict["model." + key] = state_dict.pop(key)
77
72
  doctr_predictor.load_state_dict(state_dict)
78
73
  doctr_predictor.to(device)
79
- elif lib == "TF" and tf_available():
74
+ elif lib == "TF":
80
75
  # Unzip the archive
81
76
  params_path = Path(path_weights).parent
82
77
  is_zip_path = path_weights.endswith(".zip")
@@ -89,29 +84,33 @@ def _load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Liter
89
84
 
90
85
 
91
86
  def auto_select_lib_for_doctr() -> Literal["PT", "TF"]:
92
- """Auto select the DL library from the installed and from environment variables"""
93
- if tf_available() and os.environ.get("USE_TF", os.environ.get("USE_TENSORFLOW", False)):
94
- os.environ["USE_TF"] = "TRUE"
95
- return "TF"
96
- if pytorch_available() and os.environ.get("USE_TORCH", os.environ.get("USE_PYTORCH", False)):
87
+ """Auto select the DL library from environment variables"""
88
+ if os.environ.get("USE_TORCH"):
97
89
  return "PT"
98
- raise DependencyError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextlineDetector")
90
+ if os.environ.get("USE_TF"):
91
+ return "TF"
92
+ raise DependencyError("At least one of the env variables USE_TORCH or USE_TF must be set.")
99
93
 
100
94
 
101
- def doctr_predict_text_lines(np_img: ImageType, predictor: "DetectionPredictor", device: str) -> List[DetectionResult]:
95
+ def doctr_predict_text_lines(
96
+ np_img: ImageType, predictor: DetectionPredictor, device: Union[torch.device, tf.device], lib: Literal["TF", "PT"]
97
+ ) -> List[DetectionResult]:
102
98
  """
103
99
  Generating text line DetectionResult based on Doctr DetectionPredictor.
104
100
 
105
101
  :param np_img: Image in np.array.
106
102
  :param predictor: `doctr.models.detection.predictor.DetectionPredictor`
107
103
  :param device: Will only be used in tensorflow settings. Either /gpu:0 or /cpu:0
104
+ :param lib: "TF" or "PT"
108
105
  :return: A list of text line detection results (without text).
109
106
  """
110
- if tf_available() and device is not None:
111
- with tf.device(device):
107
+ if lib == "TF":
108
+ with device:
112
109
  raw_output = predictor([np_img])
113
- else:
110
+ elif lib == "PT":
114
111
  raw_output = predictor([np_img])
112
+ else:
113
+ raise DependencyError("Tensorflow or PyTorch must be installed.")
115
114
  detection_results = [
116
115
  DetectionResult(
117
116
  box=box[:4].tolist(), class_id=1, score=box[4], absolute_coords=False, class_name=LayoutType.word
@@ -122,7 +121,10 @@ def doctr_predict_text_lines(np_img: ImageType, predictor: "DetectionPredictor",
122
121
 
123
122
 
124
123
  def doctr_predict_text(
125
- inputs: List[Tuple[str, ImageType]], predictor: "RecognitionPredictor", device: str
124
+ inputs: List[Tuple[str, ImageType]],
125
+ predictor: RecognitionPredictor,
126
+ device: Union[torch.device, tf.device],
127
+ lib: Literal["TF", "PT"],
126
128
  ) -> List[DetectionResult]:
127
129
  """
128
130
  Calls Doctr text recognition model on a batch of numpy arrays (text lines predicted from a text line detector) and
@@ -132,15 +134,18 @@ def doctr_predict_text(
132
134
  text line
133
135
  :param predictor: `doctr.models.detection.predictor.RecognitionPredictor`
134
136
  :param device: Will only be used in tensorflow settings. Either /gpu:0 or /cpu:0
137
+ :param lib: "TF" or "PT"
135
138
  :return: A list of DetectionResult containing recognized text.
136
139
  """
137
140
 
138
141
  uuids, images = list(zip(*inputs))
139
- if tf_available() and device is not None:
140
- with tf.device(device):
142
+ if lib == "TF":
143
+ with device:
141
144
  raw_output = predictor(list(images))
142
- else:
145
+ elif lib == "PT":
143
146
  raw_output = predictor(list(images))
147
+ else:
148
+ raise DependencyError("Tensorflow or PyTorch must be installed.")
144
149
  detection_results = [
145
150
  DetectionResult(score=output[1], text=output[0], uuid=uuid) for uuid, output in zip(uuids, raw_output)
146
151
  ]
@@ -208,7 +213,7 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
208
213
  architecture: str,
209
214
  path_weights: str,
210
215
  categories: Mapping[str, TypeOrStr],
211
- device: Optional[Literal["cpu", "cuda"]] = None,
216
+ device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
212
217
  lib: Optional[Literal["PT", "TF"]] = None,
213
218
  ) -> None:
214
219
  """
@@ -217,7 +222,7 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
217
222
  https://github.com/mindee/doctr/blob/main/doctr/models/detection/zoo.py#L20
218
223
  :param path_weights: Path to the weights of the model
219
224
  :param categories: A dict with the model output label and value
220
- :param device: "cpu" or "cuda". Will default to "cuda" if the required hardware is available.
225
+ :param device: "cpu" or "cuda" or any tf.device or torch.device. The device must be compatible with the dll
221
226
  :param lib: "TF" or "PT" or None. If None, env variables USE_TENSORFLOW, USE_PYTORCH will be used.
222
227
  """
223
228
  super().__init__(categories, lib)
@@ -227,18 +232,12 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
227
232
  self.name = self.get_name(self.path_weights, self.architecture)
228
233
  self.model_id = self.get_model_id()
229
234
 
230
- if device is None:
231
- if self.lib == "TF":
232
- device = "cuda" if tf.test.is_gpu_available() else "cpu"
233
- elif self.lib == "PT":
234
- auto_device = get_device(False)
235
- device = "cpu" if auto_device == "mps" else auto_device
236
- else:
237
- raise DependencyError("Cannot select device automatically. Please set the device manually.")
235
+ if self.lib == "TF":
236
+ self.device = get_tf_device(device)
237
+ if self.lib == "PT":
238
+ self.device = get_torch_device(device)
238
239
 
239
- self.device_input = device
240
- self.device = _set_device_str(device)
241
- self.doctr_predictor = self.get_wrapped_model(self.architecture, self.path_weights, self.device_input, self.lib)
240
+ self.doctr_predictor = self.get_wrapped_model(self.architecture, self.path_weights, self.device, self.lib)
242
241
 
243
242
  def predict(self, np_img: ImageType) -> List[DetectionResult]:
244
243
  """
@@ -247,28 +246,30 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
247
246
  :param np_img: image as numpy array
248
247
  :return: A list of DetectionResult
249
248
  """
250
- detection_results = doctr_predict_text_lines(np_img, self.doctr_predictor, self.device)
249
+ detection_results = doctr_predict_text_lines(np_img, self.doctr_predictor, self.device, self.lib)
251
250
  return detection_results
252
251
 
253
252
  @classmethod
254
253
  def get_requirements(cls) -> List[Requirement]:
255
- if tf_available():
254
+ if os.environ.get("DD_USE_TF"):
256
255
  return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
257
- if pytorch_available():
256
+ if os.environ.get("DD_USE_TORCH"):
258
257
  return [get_pytorch_requirement(), get_doctr_requirement()]
259
258
  raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextlineDetector")
260
259
 
261
260
  def clone(self) -> PredictorBase:
262
- return self.__class__(self.architecture, self.path_weights, self.categories, self.device_input, self.lib)
261
+ return self.__class__(self.architecture, self.path_weights, self.categories, self.device, self.lib)
263
262
 
264
263
  @staticmethod
265
- def load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
264
+ def load_model(
265
+ path_weights: str, doctr_predictor: Any, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
266
+ ) -> None:
266
267
  """Loading model weights"""
267
268
  _load_model(path_weights, doctr_predictor, device, lib)
268
269
 
269
270
  @staticmethod
270
271
  def get_wrapped_model(
271
- architecture: str, path_weights: str, device: Literal["cpu", "cuda"], lib: Literal["PT", "TF"]
272
+ architecture: str, path_weights: str, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
272
273
  ) -> Any:
273
274
  """
274
275
  Get the inner (wrapped) model.
@@ -286,8 +287,7 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
286
287
  :return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
287
288
  """
288
289
  doctr_predictor = detection_predictor(arch=architecture, pretrained=False, pretrained_backbone=False)
289
- device_str = _set_device_str(device)
290
- DoctrTextlineDetector.load_model(path_weights, doctr_predictor, device_str, lib)
290
+ DoctrTextlineDetector.load_model(path_weights, doctr_predictor, device, lib)
291
291
  return doctr_predictor
292
292
 
293
293
 
@@ -325,14 +325,13 @@ class DoctrTextRecognizer(TextRecognizer):
325
325
 
326
326
  for dp in df:
327
327
  ...
328
-
329
328
  """
330
329
 
331
330
  def __init__(
332
331
  self,
333
332
  architecture: str,
334
333
  path_weights: str,
335
- device: Optional[Literal["cpu", "cuda"]] = None,
334
+ device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
336
335
  lib: Optional[Literal["PT", "TF"]] = None,
337
336
  path_config_json: Optional[str] = None,
338
337
  ) -> None:
@@ -355,22 +354,16 @@ class DoctrTextRecognizer(TextRecognizer):
355
354
  self.name = self.get_name(self.path_weights, self.architecture)
356
355
  self.model_id = self.get_model_id()
357
356
 
358
- if device is None:
359
- if self.lib == "TF":
360
- device = "cuda" if tf.test.is_gpu_available() else "cpu"
361
- if self.lib == "PT":
362
- auto_device = get_device(False)
363
- device = "cpu" if auto_device == "mps" else auto_device
364
- else:
365
- raise DependencyError("Cannot select device automatically. Please set the device manually.")
366
-
367
- self.device_input = device
368
- self.device = _set_device_str(device)
357
+ if self.lib == "TF":
358
+ self.device = get_tf_device(device)
359
+ if self.lib == "PT":
360
+ self.device = get_torch_device(device)
361
+
369
362
  self.path_config_json = path_config_json
370
363
  self.doctr_predictor = self.build_model(self.architecture, self.path_config_json)
371
364
  self.load_model(self.path_weights, self.doctr_predictor, self.device, self.lib)
372
365
  self.doctr_predictor = self.get_wrapped_model(
373
- self.architecture, self.path_weights, self.device_input, self.lib, self.path_config_json
366
+ self.architecture, self.path_weights, self.device, self.lib, self.path_config_json
374
367
  )
375
368
 
376
369
  def predict(self, images: List[Tuple[str, ImageType]]) -> List[DetectionResult]:
@@ -381,7 +374,7 @@ class DoctrTextRecognizer(TextRecognizer):
381
374
  :return: A list of DetectionResult
382
375
  """
383
376
  if images:
384
- return doctr_predict_text(images, self.doctr_predictor, self.device)
377
+ return doctr_predict_text(images, self.doctr_predictor, self.device, self.lib)
385
378
  return []
386
379
 
387
380
  @classmethod
@@ -393,10 +386,12 @@ class DoctrTextRecognizer(TextRecognizer):
393
386
  raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextRecognizer")
394
387
 
395
388
  def clone(self) -> PredictorBase:
396
- return self.__class__(self.architecture, self.path_weights, self.device_input, self.lib)
389
+ return self.__class__(self.architecture, self.path_weights, self.device, self.lib)
397
390
 
398
391
  @staticmethod
399
- def load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
392
+ def load_model(
393
+ path_weights: str, doctr_predictor: Any, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
394
+ ) -> None:
400
395
  """Loading model weights"""
401
396
  _load_model(path_weights, doctr_predictor, device, lib)
402
397
 
@@ -438,7 +433,7 @@ class DoctrTextRecognizer(TextRecognizer):
438
433
  def get_wrapped_model(
439
434
  architecture: str,
440
435
  path_weights: str,
441
- device: Literal["cpu", "cuda"],
436
+ device: Union[torch.device, tf.device],
442
437
  lib: Literal["PT", "TF"],
443
438
  path_config_json: Optional[str] = None,
444
439
  ) -> Any:
@@ -456,8 +451,7 @@ class DoctrTextRecognizer(TextRecognizer):
456
451
  :return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
457
452
  """
458
453
  doctr_predictor = DoctrTextRecognizer.build_model(architecture, path_config_json)
459
- device_str = _set_device_str(device)
460
- DoctrTextRecognizer.load_model(path_weights, doctr_predictor, device_str, lib)
454
+ DoctrTextRecognizer.load_model(path_weights, doctr_predictor, device, lib)
461
455
  return doctr_predictor
462
456
 
463
457
  @staticmethod
@@ -23,11 +23,13 @@ from copy import copy
23
23
  from pathlib import Path
24
24
  from typing import Any, List, Mapping, Tuple, Union
25
25
 
26
- from ..utils.file_utils import Requirement, fasttext_available, get_fasttext_requirement
26
+ from lazy_imports import try_import
27
+
28
+ from ..utils.file_utils import Requirement, get_fasttext_requirement
27
29
  from ..utils.settings import TypeOrStr, get_type
28
30
  from .base import DetectionResult, LanguageDetector, PredictorBase
29
31
 
30
- if fasttext_available():
32
+ with try_import() as import_guard:
31
33
  from fasttext import load_model # type: ignore
32
34
 
33
35
 
@@ -18,27 +18,25 @@
18
18
  """
19
19
  HF Detr model for object detection.
20
20
  """
21
+ from __future__ import annotations
21
22
 
22
23
  from abc import ABC
23
24
  from pathlib import Path
24
- from typing import List, Literal, Mapping, Optional, Sequence
25
+ from typing import List, Literal, Mapping, Optional, Sequence, Union
26
+
27
+ from lazy_imports import try_import
25
28
 
26
29
  from ..utils.detection_types import ImageType, Requirement
27
- from ..utils.file_utils import (
28
- get_pytorch_requirement,
29
- get_transformers_requirement,
30
- pytorch_available,
31
- transformers_available,
32
- )
30
+ from ..utils.file_utils import get_pytorch_requirement, get_transformers_requirement
33
31
  from ..utils.settings import TypeOrStr, get_type
34
32
  from .base import DetectionResult, ObjectDetector
35
- from .pt.ptutils import set_torch_auto_device
33
+ from .pt.ptutils import get_torch_device
36
34
 
37
- if pytorch_available():
35
+ with try_import() as pt_import_guard:
38
36
  import torch # pylint: disable=W0611
39
37
  from torchvision.ops import boxes as box_ops # type: ignore
40
38
 
41
- if transformers_available():
39
+ with try_import() as tr_import_guard:
42
40
  from transformers import ( # pylint: disable=W0611
43
41
  AutoFeatureExtractor,
44
42
  DetrFeatureExtractor,
@@ -48,16 +46,16 @@ if transformers_available():
48
46
 
49
47
 
50
48
  def _detr_post_processing(
51
- boxes: "torch.Tensor", scores: "torch.Tensor", labels: "torch.Tensor", nms_thresh: float
52
- ) -> "torch.Tensor":
49
+ boxes: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, nms_thresh: float
50
+ ) -> torch.Tensor:
53
51
  return box_ops.batched_nms(boxes.float(), scores, labels, nms_thresh)
54
52
 
55
53
 
56
54
  def detr_predict_image(
57
55
  np_img: ImageType,
58
- predictor: "TableTransformerForObjectDetection",
59
- feature_extractor: "DetrFeatureExtractor",
60
- device: Literal["cpu", "cuda"],
56
+ predictor: TableTransformerForObjectDetection,
57
+ feature_extractor: DetrFeatureExtractor,
58
+ device: torch.device,
61
59
  threshold: float,
62
60
  nms_threshold: float,
63
61
  ) -> List[DetectionResult]:
@@ -168,7 +166,7 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
168
166
  path_weights: str,
169
167
  path_feature_extractor_config_json: str,
170
168
  categories: Mapping[str, TypeOrStr],
171
- device: Optional[Literal["cpu", "cuda"]] = None,
169
+ device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
172
170
  filter_categories: Optional[Sequence[TypeOrStr]] = None,
173
171
  ):
174
172
  """
@@ -195,10 +193,7 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
195
193
  self.hf_detr_predictor = self.get_model(self.path_weights, self.config)
196
194
  self.feature_extractor = self.get_pre_processor(self.path_feature_extractor_config)
197
195
 
198
- if device is not None:
199
- self.device = device
200
- else:
201
- self.device = set_torch_auto_device()
196
+ self.device = get_torch_device(device)
202
197
  self.hf_detr_predictor.to(self.device)
203
198
 
204
199
  def predict(self, np_img: ImageType) -> List[DetectionResult]:
@@ -213,7 +208,7 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
213
208
  return self._map_category_names(results)
214
209
 
215
210
  @staticmethod
216
- def get_model(path_weights: str, config: "PretrainedConfig") -> "TableTransformerForObjectDetection":
211
+ def get_model(path_weights: str, config: PretrainedConfig) -> TableTransformerForObjectDetection:
217
212
  """
218
213
  Builds the Detr model
219
214
 
@@ -226,7 +221,7 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
226
221
  )
227
222
 
228
223
  @staticmethod
229
- def get_pre_processor(path_feature_extractor_config: str) -> "DetrFeatureExtractor":
224
+ def get_pre_processor(path_feature_extractor_config: str) -> DetrFeatureExtractor:
230
225
  """
231
226
  Builds the feature extractor
232
227
 
@@ -235,7 +230,7 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
235
230
  return AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path=path_feature_extractor_config)
236
231
 
237
232
  @staticmethod
238
- def get_config(path_config: str) -> "PretrainedConfig":
233
+ def get_config(path_config: str) -> PretrainedConfig:
239
234
  """
240
235
  Builds the config
241
236
 
@@ -252,15 +247,15 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
252
247
  def get_requirements(cls) -> List[Requirement]:
253
248
  return [get_pytorch_requirement(), get_transformers_requirement()]
254
249
 
255
- def clone(self) -> "HFDetrDerivedDetector":
250
+ def clone(self) -> HFDetrDerivedDetector:
256
251
  return self.__class__(
257
252
  self.path_config, self.path_weights, self.path_feature_extractor_config, self.categories, self.device
258
253
  )
259
254
 
260
255
  @staticmethod
261
256
  def get_wrapped_model(
262
- path_config_json: str, path_weights: str, device: Optional[Literal["cpu", "cuda"]] = None
263
- ) -> "TableTransformerForObjectDetection":
257
+ path_config_json: str, path_weights: str, device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None
258
+ ) -> TableTransformerForObjectDetection:
264
259
  """
265
260
  Get the wrapped model
266
261
 
@@ -271,6 +266,5 @@ class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
271
266
  """
272
267
  config = HFDetrDerivedDetector.get_config(path_config_json)
273
268
  hf_detr_predictor = HFDetrDerivedDetector.get_model(path_weights, config)
274
- if device is None:
275
- device = set_torch_auto_device()
269
+ device = get_torch_device()
276
270
  return hf_detr_predictor.to(device)