deepdoctection 0.30-py3-none-any.whl → 0.31-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (74)
  1. deepdoctection/__init__.py +4 -2
  2. deepdoctection/analyzer/dd.py +6 -5
  3. deepdoctection/dataflow/base.py +0 -19
  4. deepdoctection/dataflow/custom.py +4 -3
  5. deepdoctection/dataflow/custom_serialize.py +14 -5
  6. deepdoctection/dataflow/parallel_map.py +12 -11
  7. deepdoctection/dataflow/serialize.py +5 -4
  8. deepdoctection/datapoint/annotation.py +33 -12
  9. deepdoctection/datapoint/box.py +1 -4
  10. deepdoctection/datapoint/convert.py +3 -1
  11. deepdoctection/datapoint/image.py +66 -29
  12. deepdoctection/datapoint/view.py +57 -25
  13. deepdoctection/datasets/adapter.py +1 -1
  14. deepdoctection/datasets/base.py +83 -10
  15. deepdoctection/datasets/dataflow_builder.py +1 -1
  16. deepdoctection/datasets/info.py +2 -2
  17. deepdoctection/datasets/instances/layouttest.py +2 -7
  18. deepdoctection/eval/accmetric.py +1 -1
  19. deepdoctection/eval/base.py +5 -4
  20. deepdoctection/eval/eval.py +2 -2
  21. deepdoctection/eval/tp_eval_callback.py +5 -4
  22. deepdoctection/extern/base.py +39 -13
  23. deepdoctection/extern/d2detect.py +164 -64
  24. deepdoctection/extern/deskew.py +32 -7
  25. deepdoctection/extern/doctrocr.py +227 -39
  26. deepdoctection/extern/fastlang.py +45 -7
  27. deepdoctection/extern/hfdetr.py +90 -33
  28. deepdoctection/extern/hflayoutlm.py +109 -22
  29. deepdoctection/extern/pdftext.py +2 -1
  30. deepdoctection/extern/pt/ptutils.py +3 -2
  31. deepdoctection/extern/tessocr.py +134 -22
  32. deepdoctection/extern/texocr.py +2 -0
  33. deepdoctection/extern/tp/tpcompat.py +4 -4
  34. deepdoctection/extern/tp/tpfrcnn/preproc.py +2 -7
  35. deepdoctection/extern/tpdetect.py +50 -23
  36. deepdoctection/mapper/d2struct.py +1 -1
  37. deepdoctection/mapper/hfstruct.py +1 -1
  38. deepdoctection/mapper/laylmstruct.py +1 -1
  39. deepdoctection/mapper/maputils.py +13 -2
  40. deepdoctection/mapper/prodigystruct.py +1 -1
  41. deepdoctection/mapper/pubstruct.py +10 -10
  42. deepdoctection/mapper/tpstruct.py +1 -1
  43. deepdoctection/pipe/anngen.py +35 -8
  44. deepdoctection/pipe/base.py +53 -19
  45. deepdoctection/pipe/cell.py +29 -8
  46. deepdoctection/pipe/common.py +12 -4
  47. deepdoctection/pipe/doctectionpipe.py +2 -2
  48. deepdoctection/pipe/language.py +3 -2
  49. deepdoctection/pipe/layout.py +3 -2
  50. deepdoctection/pipe/lm.py +2 -2
  51. deepdoctection/pipe/refine.py +18 -10
  52. deepdoctection/pipe/segment.py +21 -16
  53. deepdoctection/pipe/text.py +14 -8
  54. deepdoctection/pipe/transform.py +16 -9
  55. deepdoctection/train/d2_frcnn_train.py +15 -12
  56. deepdoctection/train/hf_detr_train.py +8 -6
  57. deepdoctection/train/hf_layoutlm_train.py +16 -11
  58. deepdoctection/utils/__init__.py +3 -0
  59. deepdoctection/utils/concurrency.py +1 -1
  60. deepdoctection/utils/context.py +2 -2
  61. deepdoctection/utils/env_info.py +55 -22
  62. deepdoctection/utils/error.py +84 -0
  63. deepdoctection/utils/file_utils.py +4 -15
  64. deepdoctection/utils/fs.py +7 -7
  65. deepdoctection/utils/pdf_utils.py +5 -4
  66. deepdoctection/utils/settings.py +5 -1
  67. deepdoctection/utils/transform.py +1 -1
  68. deepdoctection/utils/utils.py +0 -6
  69. deepdoctection/utils/viz.py +44 -2
  70. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/METADATA +33 -58
  71. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/RECORD +74 -73
  72. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/WHEEL +1 -1
  73. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/LICENSE +0 -0
  74. {deepdoctection-0.30.dist-info → deepdoctection-0.31.dist-info}/top_level.txt +0 -0
deepdoctection/extern/doctrocr.py

@@ -18,12 +18,15 @@
 """
 Deepdoctection wrappers for DocTr OCR text line detection and text recognition models
 """
-
+import os
+from abc import ABC
 from pathlib import Path
 from typing import Any, List, Literal, Mapping, Optional, Tuple
 from zipfile import ZipFile
 
 from ..utils.detection_types import ImageType, Requirement
+from ..utils.env_info import get_device
+from ..utils.error import DependencyError
 from ..utils.file_utils import (
     doctr_available,
     get_doctr_requirement,
@@ -35,11 +38,13 @@ from ..utils.file_utils import (
     tf_available,
 )
 from ..utils.fs import load_json
-from ..utils.settings import LayoutType, ObjectTypes, TypeOrStr
-from .base import DetectionResult, ObjectDetector, PredictorBase, TextRecognizer
+from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
+from ..utils.viz import viz_handler
+from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase, TextRecognizer
 from .pt.ptutils import set_torch_auto_device
 
 if doctr_available() and ((tf_addons_available() and tf_available()) or pytorch_available()):
+    from doctr.models._utils import estimate_orientation
     from doctr.models.detection.predictor import DetectionPredictor  # pylint: disable=W0611
     from doctr.models.detection.zoo import detection_predictor
     from doctr.models.preprocessor import PreProcessor
@@ -64,7 +69,7 @@ def _set_device_str(device: Optional[str] = None) -> str:
     return device
 
 
-def _load_model(path_weights: str, doctr_predictor: Any, device: str, lib: str) -> None:
+def _load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
     if lib == "PT" and pytorch_available():
         state_dict = torch.load(path_weights, map_location=device)
         for key in list(state_dict.keys()):
@@ -83,6 +88,16 @@ def _load_model(path_weights: str, doctr_predictor: Any, device: str, lib: str)
         doctr_predictor.model.load_weights(path_weights)
 
 
+def auto_select_lib_for_doctr() -> Literal["PT", "TF"]:
+    """Auto-select the DL library from the installed frameworks and environment variables"""
+    if tf_available() and os.environ.get("USE_TF", os.environ.get("USE_TENSORFLOW", False)):
+        os.environ["USE_TF"] = "TRUE"
+        return "TF"
+    if pytorch_available() and os.environ.get("USE_TORCH", os.environ.get("USE_PYTORCH", False)):
+        return "PT"
+    raise DependencyError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextlineDetector")
+
+
 def doctr_predict_text_lines(np_img: ImageType, predictor: "DetectionPredictor", device: str) -> List[DetectionResult]:
     """
     Generating text line DetectionResult based on Doctr DetectionPredictor.
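Note: with this hunk the library choice is no longer hard-coded to "TF" but derived from the environment. A minimal sketch of the new selection logic (the env value is illustrative; any non-empty string counts as set):

    import os

    os.environ["USE_TORCH"] = "1"
    lib = auto_select_lib_for_doctr()  # "PT", provided PyTorch is installed
    # With neither USE_TF/USE_TENSORFLOW nor USE_TORCH/USE_PYTORCH set (or the
    # matching framework missing), a DependencyError is raised instead of the
    # silent "TF" default used in 0.30.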
@@ -132,7 +147,28 @@ def doctr_predict_text(
     return detection_results
 
 
-class DoctrTextlineDetector(ObjectDetector):
+class DoctrTextlineDetectorMixin(ObjectDetector, ABC):
+    """Base class for Doctr textline detector. This class only implements the basic wrapper functions"""
+
+    def __init__(self, categories: Mapping[str, TypeOrStr], lib: Optional[Literal["PT", "TF"]] = None):
+        self.categories = categories  # type: ignore
+        self.lib = lib if lib is not None else self.auto_select_lib()
+
+    def possible_categories(self) -> List[ObjectTypes]:
+        return [LayoutType.word]
+
+    @staticmethod
+    def get_name(path_weights: str, architecture: str) -> str:
+        """Returns the name of the model"""
+        return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
+
+    @staticmethod
+    def auto_select_lib() -> Literal["PT", "TF"]:
+        """Auto-select the DL library from the installed frameworks and environment variables"""
+        return auto_select_lib_for_doctr()
+
+
+class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
     """
     A deepdoctection wrapper of DocTr text line detector. We model text line detection as ObjectDetector
     and assume to use this detector in an ImageLayoutService.
@@ -165,8 +201,6 @@ class DoctrTextlineDetector(ObjectDetector):
 
         for dp in df:
             ...
-
-
     """
 
     def __init__(
@@ -175,20 +209,36 @@ class DoctrTextlineDetector(ObjectDetector):
         path_weights: str,
         categories: Mapping[str, TypeOrStr],
         device: Optional[Literal["cpu", "cuda"]] = None,
-        lib: str = "TF",
+        lib: Optional[Literal["PT", "TF"]] = None,
     ) -> None:
-        self.lib = lib
-        self.name = "doctr_text_detector"
+        """
+        :param architecture: DocTR supports various text line detection models, e.g. "db_resnet50",
+            "db_mobilenet_v3_large". The full list can be found here:
+            https://github.com/mindee/doctr/blob/main/doctr/models/detection/zoo.py#L20
+        :param path_weights: Path to the weights of the model
+        :param categories: A dict with the model output label and value
+        :param device: "cpu" or "cuda". Will default to "cuda" if the required hardware is available.
+        :param lib: "TF" or "PT" or None. If None, env variables USE_TENSORFLOW, USE_PYTORCH will be used.
+        """
+        super().__init__(categories, lib)
         self.architecture = architecture
         self.path_weights = path_weights
-        self.doctr_predictor = detection_predictor(
-            arch=self.architecture, pretrained=False, pretrained_backbone=False
-        )  # we will be loading the model
-        # later because there is no easy way in doctr to load a model by giving only a path to its weights
-        self.categories = categories  # type: ignore
+
+        self.name = self.get_name(self.path_weights, self.architecture)
+        self.model_id = self.get_model_id()
+
+        if device is None:
+            if self.lib == "TF":
+                device = "cuda" if tf.test.is_gpu_available() else "cpu"
+            elif self.lib == "PT":
+                auto_device = get_device(False)
+                device = "cpu" if auto_device == "mps" else auto_device
+            else:
+                raise DependencyError("Cannot select device automatically. Please set the device manually.")
+
         self.device_input = device
         self.device = _set_device_str(device)
-        self.load_model()
+        self.doctr_predictor = self.get_wrapped_model(self.architecture, self.path_weights, self.device_input, self.lib)
 
     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         """
@@ -211,12 +261,34 @@ class DoctrTextlineDetector(ObjectDetector):
     def clone(self) -> PredictorBase:
         return self.__class__(self.architecture, self.path_weights, self.categories, self.device_input, self.lib)
 
-    def possible_categories(self) -> List[ObjectTypes]:
-        return [LayoutType.word]
-
-    def load_model(self) -> None:
+    @staticmethod
+    def load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
         """Loading model weights"""
-        _load_model(self.path_weights, self.doctr_predictor, self.device, self.lib)
+        _load_model(path_weights, doctr_predictor, device, lib)
+
+    @staticmethod
+    def get_wrapped_model(
+        architecture: str, path_weights: str, device: Literal["cpu", "cuda"], lib: Literal["PT", "TF"]
+    ) -> Any:
+        """
+        Get the inner (wrapped) model.
+
+        :param architecture: DocTR supports various text line detection models, e.g. "db_resnet50",
+            "db_mobilenet_v3_large". The full list can be found here:
+            https://github.com/mindee/doctr/blob/main/doctr/models/detection/zoo.py#L20
+        :param path_weights: Path to the weights of the model
+        :param device: "cpu" or "cuda". Will default to "cuda" if the required hardware is available.
+        :param lib: "TF" or "PT". Make sure the env variables USE_TENSORFLOW or USE_PYTORCH
+            are set. If not, use
+
+            deepdoctection.utils.env_info.auto_select_lib_and_device
+
+        :return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
+        """
+        doctr_predictor = detection_predictor(arch=architecture, pretrained=False, pretrained_backbone=False)
+        device_str = _set_device_str(device)
+        DoctrTextlineDetector.load_model(path_weights, doctr_predictor, device_str, lib)
+        return doctr_predictor
 
 
 class DoctrTextRecognizer(TextRecognizer):
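Note: possible_categories and the model name now come from the new DoctrTextlineDetectorMixin, and the inner doctr DetectionPredictor can be built without the wrapper. A hedged sketch (weight path and category mapping are placeholders, not values shipped with the wheel):

    from deepdoctection.extern.doctrocr import DoctrTextlineDetector
    from deepdoctection.utils.settings import LayoutType

    categories = {"1": LayoutType.word}
    detector = DoctrTextlineDetector("db_resnet50", "/path/to/db_resnet50.pt", categories, device="cpu", lib="PT")
    # or skip the wrapper and build only the raw doctr predictor:
    predictor = DoctrTextlineDetector.get_wrapped_model("db_resnet50", "/path/to/db_resnet50.pt", "cpu", "PT")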
@@ -261,7 +333,7 @@ class DoctrTextRecognizer(TextRecognizer):
         architecture: str,
         path_weights: str,
         device: Optional[Literal["cpu", "cuda"]] = None,
-        lib: str = "TF",
+        lib: Optional[Literal["PT", "TF"]] = None,
         path_config_json: Optional[str] = None,
     ) -> None:
         """
@@ -270,19 +342,36 @@
             https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py#L16.
         :param path_weights: Path to the weights of the model
         :param device: "cpu" or "cuda". Will default to "cuda" if the required hardware is available.
-        :param lib: "TF" or "PT". Will default to "TF".
+        :param lib: "TF" or "PT" or None. If None, env variables USE_TENSORFLOW, USE_PYTORCH will be used.
         :param path_config_json: Path to a json file containing the configuration of the model. Useful, if you have
             a model trained on custom vocab.
         """
-        self.lib = lib
-        self.name = "doctr_text_recognizer"
+
+        self.lib = lib if lib is not None else self.auto_select_lib()
+
         self.architecture = architecture
         self.path_weights = path_weights
+
+        self.name = self.get_name(self.path_weights, self.architecture)
+        self.model_id = self.get_model_id()
+
+        if device is None:
+            if self.lib == "TF":
+                device = "cuda" if tf.test.is_gpu_available() else "cpu"
+            elif self.lib == "PT":
+                auto_device = get_device(False)
+                device = "cpu" if auto_device == "mps" else auto_device
+            else:
+                raise DependencyError("Cannot select device automatically. Please set the device manually.")
+
         self.device_input = device
         self.device = _set_device_str(device)
         self.path_config_json = path_config_json
-        self.doctr_predictor = self.build_model()
-        self.load_model()
+        self.doctr_predictor = self.build_model(self.architecture, self.path_config_json)
+        self.load_model(self.path_weights, self.doctr_predictor, self.device, self.lib)
+        self.doctr_predictor = self.get_wrapped_model(
+            self.architecture, self.path_weights, self.device_input, self.lib, self.path_config_json
+        )
 
     def predict(self, images: List[Tuple[str, ImageType]]) -> List[DetectionResult]:
         """
@@ -306,19 +395,21 @@ class DoctrTextRecognizer(TextRecognizer):
     def clone(self) -> PredictorBase:
         return self.__class__(self.architecture, self.path_weights, self.device_input, self.lib)
 
-    def load_model(self) -> None:
+    @staticmethod
+    def load_model(path_weights: str, doctr_predictor: Any, device: str, lib: Literal["PT", "TF"]) -> None:
         """Loading model weights"""
-        _load_model(self.path_weights, self.doctr_predictor, self.device, self.lib)
+        _load_model(path_weights, doctr_predictor, device, lib)
 
-    def build_model(self) -> "RecognitionPredictor":
+    @staticmethod
+    def build_model(architecture: str, path_config_json: Optional[str] = None) -> "RecognitionPredictor":
         """Building the model"""
 
         # inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
         custom_configs = {}
         batch_size = 32
         recognition_configs = {}
-        if self.path_config_json:
-            custom_configs = load_json(self.path_config_json)
+        if path_config_json:
+            custom_configs = load_json(path_config_json)
             custom_configs.pop("arch", None)
             custom_configs.pop("url", None)
             custom_configs.pop("task", None)
@@ -327,18 +418,115 @@ class DoctrTextRecognizer(TextRecognizer):
             batch_size = custom_configs.pop("batch_size")
         recognition_configs["batch_size"] = batch_size
 
-        if isinstance(self.architecture, str):
-            if self.architecture not in ARCHS:
-                raise ValueError(f"unknown architecture '{self.architecture}'")
+        if isinstance(architecture, str):
+            if architecture not in ARCHS:
+                raise ValueError(f"unknown architecture '{architecture}'")
 
-            model = recognition.__dict__[self.architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
+            model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
         else:
             if not isinstance(
-                self.architecture,
+                architecture,
                 (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
             ):
-                raise ValueError(f"unknown architecture: {type(self.architecture)}")
-            model = self.architecture
+                raise ValueError(f"unknown architecture: {type(architecture)}")
+            model = architecture
 
         input_shape = model.cfg["input_shape"][:2] if tf_available() else model.cfg["input_shape"][-2:]
         return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model)
+
+    @staticmethod
+    def get_wrapped_model(
+        architecture: str,
+        path_weights: str,
+        device: Literal["cpu", "cuda"],
+        lib: Literal["PT", "TF"],
+        path_config_json: Optional[str] = None,
+    ) -> Any:
+        """
+        Get the inner (wrapped) model.
+
+        :param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
+            "crnn_mobilenet_v3_small". The full list can be found here:
+            https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py#L16.
+        :param path_weights: Path to the weights of the model
+        :param device: "cpu" or "cuda". Will default to "cuda" if the required hardware is available.
+        :param lib: "TF" or "PT". Make sure the env variables USE_TENSORFLOW or USE_PYTORCH are set.
+        :param path_config_json: Path to a json file containing the configuration of the model. Useful, if you have
+            a model trained on custom vocab.
+        :return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
+        """
+        doctr_predictor = DoctrTextRecognizer.build_model(architecture, path_config_json)
+        device_str = _set_device_str(device)
+        DoctrTextRecognizer.load_model(path_weights, doctr_predictor, device_str, lib)
+        return doctr_predictor
+
+    @staticmethod
+    def get_name(path_weights: str, architecture: str) -> str:
+        """Returns the name of the model"""
+        return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
+
+    @staticmethod
+    def auto_select_lib() -> Literal["PT", "TF"]:
+        """Auto-select the DL library from the installed frameworks and environment variables"""
+        return auto_select_lib_for_doctr()
+
+
+class DocTrRotationTransformer(ImageTransformer):
+    """
+    The `DocTrRotationTransformer` class is a specialized image transformer that is designed to handle image rotation
+    in the context of Optical Character Recognition (OCR) tasks. It inherits from the `ImageTransformer` base class and
+    implements methods for predicting and applying rotation transformations to images.
+
+    The `predict` method determines the angle of the rotated image using the `estimate_orientation` function from the
+    `doctr.models._utils` module. The `n_ct` and `ratio_threshold_for_lines` parameters for this function can be
+    configured when instantiating the class.
+
+    The `transform` method applies the predicted rotation to the image, effectively rotating the image backwards.
+    This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
+
+    This class can be particularly useful in OCR tasks where the orientation of the text in the image matters.
+    The class also provides methods for cloning itself and for getting the requirements of the OCR system.
+
+    **Example:**
+        transformer = DocTrRotationTransformer()
+        detection_result = transformer.predict(np_img)
+        rotated_image = transformer.transform(np_img, detection_result)
+    """
+
+    def __init__(self, number_contours: int = 50, ratio_threshold_for_lines: float = 5):
+        """
+
+        :param number_contours: the number of contours used for the orientation estimation
+        :param ratio_threshold_for_lines: the ratio w/h used to discriminate lines
+        """
+        self.number_contours = number_contours
+        self.ratio_threshold_for_lines = ratio_threshold_for_lines
+        self.name = "doctr_rotation_transformer"
+
+    def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
+        """
+        Applies the predicted rotation to the image, effectively rotating the image backwards.
+        This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
+
+        :param np_img: The input image as a numpy array.
+        :param specification: A `DetectionResult` object containing the predicted rotation angle.
+        :return: The rotated image as a numpy array.
+        """
+        return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore
+
+    def predict(self, np_img: ImageType) -> DetectionResult:
+        angle = estimate_orientation(np_img, self.number_contours, self.ratio_threshold_for_lines)
+        if angle < 0:
+            angle += 360
+        return DetectionResult(angle=round(angle, 2))
+
+    @classmethod
+    def get_requirements(cls) -> List[Requirement]:
+        return [get_doctr_requirement()]
+
+    def clone(self) -> PredictorBase:
+        return self.__class__(self.number_contours, self.ratio_threshold_for_lines)
+
+    @staticmethod
+    def possible_category() -> PageType:
+        return PageType.angle
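Note: the new DocTrRotationTransformer normalizes estimated angles into [0, 360). A short usage sketch (np_img stands for any numpy image, e.g. a page scanned upside down):

    from deepdoctection.extern.doctrocr import DocTrRotationTransformer

    transformer = DocTrRotationTransformer()
    result = transformer.predict(np_img)             # e.g. DetectionResult(angle=180.0)
    upright = transformer.transform(np_img, result)  # rotates the page back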
deepdoctection/extern/fastlang.py

@@ -18,18 +18,45 @@
 """
 Deepdoctection wrappers for fasttext language detection models
 """
+from abc import ABC
 from copy import copy
-from typing import List, Mapping
+from pathlib import Path
+from typing import Any, List, Mapping, Tuple, Union
 
 from ..utils.file_utils import Requirement, fasttext_available, get_fasttext_requirement
-from ..utils.settings import TypeOrStr
+from ..utils.settings import TypeOrStr, get_type
 from .base import DetectionResult, LanguageDetector, PredictorBase
 
 if fasttext_available():
     from fasttext import load_model  # type: ignore
 
 
-class FasttextLangDetector(LanguageDetector):
+class FasttextLangDetectorMixin(LanguageDetector, ABC):
+    """
+    Base class for the Fasttext language detection implementation. This class only implements the basic wrapper functions.
+    """
+
+    def __init__(self, categories: Mapping[str, TypeOrStr]) -> None:
+        """
+        :param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language code.
+        """
+        self.categories = copy({idx: get_type(cat) for idx, cat in categories.items()})
+
+    def output_to_detection_result(self, output: Union[Tuple[Any, Any]]) -> DetectionResult:
+        """
+        Generating `DetectionResult` from model output
+        :param output: FastText model output
+        :return: `DetectionResult` filled with `text` and `score`
+        """
+        return DetectionResult(text=self.categories[output[0][0]], score=output[1][0])
+
+    @staticmethod
+    def get_name(path_weights: str) -> str:
+        """Returns the name of the model"""
+        return "fasttext_" + "_".join(Path(path_weights).parts[-2:])
+
+
+class FasttextLangDetector(FasttextLangDetectorMixin):
     """
     Fasttext language detector wrapper. Two models provided in the fasttext library can be used to identify languages.
     The background to the models can be found in the works:
@@ -57,15 +84,18 @@ class FasttextLangDetector(LanguageDetector):
         :param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language
             code.
         """
+        super().__init__(categories)
 
-        self.name = "fasttest_lang_detector"
         self.path_weights = path_weights
-        self.model = load_model(self.path_weights)
-        self.categories = copy(categories)  # type: ignore
+
+        self.name = self.get_name(self.path_weights)
+        self.model_id = self.get_model_id()
+
+        self.model = self.get_wrapped_model(self.path_weights)
 
     def predict(self, text_string: str) -> DetectionResult:
         output = self.model.predict(text_string)
-        return DetectionResult(text=self.categories[output[0][0]], score=output[1][0])
+        return self.output_to_detection_result(output)
 
     @classmethod
     def get_requirements(cls) -> List[Requirement]:
@@ -73,3 +103,11 @@ class FasttextLangDetector(LanguageDetector):
 
     def clone(self) -> PredictorBase:
         return self.__class__(self.path_weights, self.categories)
+
+    @staticmethod
+    def get_wrapped_model(path_weights: str) -> Any:
+        """
+        Get the wrapped model
+        :param path_weights: path to model weights
+        """
+        return load_model(path_weights)
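Note: the detector name is now derived from the weight file instead of the previously misspelled static "fasttest_lang_detector". A hedged usage sketch (the weight path and the label-to-category mapping are placeholders; fasttext emits labels of the form "__label__xx"):

    from deepdoctection.extern.fastlang import FasttextLangDetector
    from deepdoctection.utils.settings import Languages

    detector = FasttextLangDetector("/path/to/lid.176.bin", {"__label__en": Languages.english})
    result = detector.predict("This is an English sentence.")  # DetectionResult with text and score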
deepdoctection/extern/hfdetr.py

@@ -19,6 +19,8 @@
 HF Detr model for object detection.
 """
 
+from abc import ABC
+from pathlib import Path
 from typing import List, Literal, Mapping, Optional, Sequence
 
 from ..utils.detection_types import ImageType, Requirement
@@ -94,7 +96,48 @@ def detr_predict_image(
     ]
 
 
-class HFDetrDerivedDetector(ObjectDetector):
+class HFDetrDerivedDetectorMixin(ObjectDetector, ABC):
+    """Base class for Detr object detector. This class only implements the basic wrapper functions"""
+
+    def __init__(self, categories: Mapping[str, TypeOrStr], filter_categories: Optional[Sequence[TypeOrStr]] = None):
+        """
+
+        :param categories: A dict with key (indices) and values (category names).
+        :param filter_categories: The model might return objects that are not supposed to be predicted and that should
+            be filtered. Pass a list of category names that must not be returned
+        """
+        self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
+        if filter_categories:
+            filter_categories = [get_type(cat) for cat in filter_categories]
+        self.filter_categories = filter_categories
+
+    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
+        """
+        Populating category names to detection results. Will also filter categories
+
+        :param detection_results: list of detection results
+        :return: List of detection results with attribute class_name populated
+        """
+        filtered_detection_result: List[DetectionResult] = []
+        for result in detection_results:
+            result.class_name = self.categories[str(result.class_id + 1)]  # type: ignore
+            if isinstance(result.class_id, int):
+                result.class_id += 1
+            if self.filter_categories:
+                if result.class_name not in self.filter_categories:
+                    filtered_detection_result.append(result)
+            else:
+                filtered_detection_result.append(result)
+
+        return filtered_detection_result
+
+    @staticmethod
+    def get_name(path_weights: str) -> str:
+        """Returns the name of the model"""
+        return "Transformers_Tatr_" + "_".join(Path(path_weights).parts[-2:])
+
+
+class HFDetrDerivedDetector(HFDetrDerivedDetectorMixin):
     """
     Model wrapper for TableTransformerForObjectDetection that again is based on
 
@@ -138,26 +181,25 @@ class HFDetrDerivedDetector(ObjectDetector):
         :param filter_categories: The model might return objects that are not supposed to be predicted and that should
             be filtered. Pass a list of category names that must not be returned
         """
-        self.name = "Detr"
-        self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
+        super().__init__(categories, filter_categories)
+
         self.path_config = path_config_json
         self.path_weights = path_weights
         self.path_feature_extractor_config = path_feature_extractor_config_json
-        self.config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=self.path_config)
-        self.config.use_timm_backbone = True
-        self.config.threshold = 0.1
-        self.config.nms_threshold = 0.05
-        self.hf_detr_predictor = self.set_model(path_weights)
-        self.feature_extractor = self.set_pre_processor()
+
+        self.name = self.get_name(self.path_weights)
+        self.model_id = self.get_model_id()
+
+        self.config = self.get_config(path_config_json)
+
+        self.hf_detr_predictor = self.get_model(self.path_weights, self.config)
+        self.feature_extractor = self.get_pre_processor(self.path_feature_extractor_config)
 
         if device is not None:
             self.device = device
         else:
             self.device = set_torch_auto_device()
         self.hf_detr_predictor.to(self.device)
-        if filter_categories:
-            filter_categories = [get_type(cat) for cat in filter_categories]
-        self.filter_categories = filter_categories
 
     def predict(self, np_img: ImageType) -> List[DetectionResult]:
         results = detr_predict_image(
@@ -170,44 +212,41 @@ class HFDetrDerivedDetector(ObjectDetector):
         )
         return self._map_category_names(results)
 
-    def set_model(self, path_weights: str) -> "TableTransformerForObjectDetection":
+    @staticmethod
+    def get_model(path_weights: str, config: "PretrainedConfig") -> "TableTransformerForObjectDetection":
         """
         Builds the Detr model
 
-        :param path_weights: weights
+        :param path_weights: The path to the model checkpoint.
+        :param config: `PretrainedConfig`
        :return: TableTransformerForObjectDetection instance
         """
         return TableTransformerForObjectDetection.from_pretrained(
-            pretrained_model_name_or_path=path_weights, config=self.config
+            pretrained_model_name_or_path=path_weights, config=config
         )
 
-    def set_pre_processor(self) -> "DetrFeatureExtractor":
+    @staticmethod
+    def get_pre_processor(path_feature_extractor_config: str) -> "DetrFeatureExtractor":
         """
         Builds the feature extractor
 
         :return: DetrFeatureExtractor
         """
-        return AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path=self.path_feature_extractor_config)
+        return AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path=path_feature_extractor_config)
 
-    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
+    @staticmethod
+    def get_config(path_config: str) -> "PretrainedConfig":
         """
-        Populating category names to detection results. Will also filter categories
+        Builds the config
 
-        :param detection_results: list of detection results
-        :return: List of detection results with attribute class_name populated
+        :param path_config: The path to the json config.
+        :return: PretrainedConfig instance
         """
-        filtered_detection_result: List[DetectionResult] = []
-        for result in detection_results:
-            result.class_name = self.categories[str(result.class_id + 1)]  # type: ignore
-            if isinstance(result.class_id, int):
-                result.class_id += 1
-            if self.filter_categories:
-                if result.class_name not in self.filter_categories:
-                    filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
-
-        return filtered_detection_result
+        config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config)
+        config.use_timm_backbone = True
+        config.threshold = 0.1
+        config.nms_threshold = 0.05
+        return config
 
 
     @classmethod
@@ -217,3 +256,21 @@ class HFDetrDerivedDetector(ObjectDetector):
         return self.__class__(
             self.path_config, self.path_weights, self.path_feature_extractor_config, self.categories, self.device
         )
+
+    @staticmethod
+    def get_wrapped_model(
+        path_config_json: str, path_weights: str, device: Optional[Literal["cpu", "cuda"]] = None
+    ) -> "TableTransformerForObjectDetection":
+        """
+        Get the wrapped model
+
+        :param path_config_json: The path to the json config.
+        :param path_weights: The path to the model checkpoint.
+        :param device: "cpu" or "cuda". If not specified will auto select depending on what is available
+        :return: TableTransformerForObjectDetection instance
+        """
+        config = HFDetrDerivedDetector.get_config(path_config_json)
+        hf_detr_predictor = HFDetrDerivedDetector.get_model(path_weights, config)
+        if device is None:
+            device = set_torch_auto_device()
+        return hf_detr_predictor.to(device)
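Note: get_wrapped_model bundles get_config, get_model and device placement into a single call. A hedged sketch (both paths are placeholders for a table-transformer checkpoint and its config):

    from deepdoctection.extern.hfdetr import HFDetrDerivedDetector

    model = HFDetrDerivedDetector.get_wrapped_model("/path/to/config.json", "/path/to/pytorch_model.bin")
    # the device defaults to set_torch_auto_device(); pass device="cpu" to force CPU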