deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic; further details were linked from the original page and are not preserved in this copy.

Files changed (111):
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -23,25 +23,24 @@ from __future__ import annotations
23
23
  import os
24
24
  from abc import ABC
25
25
  from pathlib import Path
26
- from typing import Any, List, Literal, Mapping, Optional, Tuple, Union
26
+ from typing import Any, Literal, Mapping, Optional, Union
27
27
  from zipfile import ZipFile
28
28
 
29
29
  from lazy_imports import try_import
30
30
 
31
- from ..utils.detection_types import ImageType, Requirement
31
+ from ..utils.env_info import ENV_VARS_TRUE
32
32
  from ..utils.error import DependencyError
33
33
  from ..utils.file_utils import (
34
34
  get_doctr_requirement,
35
35
  get_pytorch_requirement,
36
36
  get_tensorflow_requirement,
37
37
  get_tf_addons_requirements,
38
- pytorch_available,
39
- tf_available,
40
38
  )
41
39
  from ..utils.fs import load_json
42
40
  from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
41
+ from ..utils.types import PathLikeOrStr, PixelValues, Requirement
43
42
  from ..utils.viz import viz_handler
44
- from .base import DetectionResult, ImageTransformer, ObjectDetector, PredictorBase, TextRecognizer
43
+ from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector, TextRecognizer
45
44
  from .pt.ptutils import get_torch_device
46
45
  from .tp.tfutils import get_tf_device
47
46
 
@@ -60,13 +59,24 @@ with try_import() as doctr_import_guard:
60
59
  from doctr.models.recognition.zoo import ARCHS, recognition
61
60
 
62
61
 
62
+ def _get_doctr_requirements() -> list[Requirement]:
63
+ if os.environ.get("DD_USE_TF", "0") in ENV_VARS_TRUE:
64
+ return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
65
+ if os.environ.get("DD_USE_TORCH", "0") in ENV_VARS_TRUE:
66
+ return [get_pytorch_requirement(), get_doctr_requirement()]
67
+ raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextRecognizer")
68
+
69
+
63
70
  def _load_model(
64
- path_weights: str, doctr_predictor: Any, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
71
+ path_weights: PathLikeOrStr,
72
+ doctr_predictor: Union[DetectionPredictor, RecognitionPredictor],
73
+ device: Union[torch.device, tf.device],
74
+ lib: Literal["PT", "TF"],
65
75
  ) -> None:
66
76
  """Loading a model either in TF or PT. We only shift the model to the device when using PyTorch. The shift of
67
77
  the model to the device in Tensorflow is done in the predict function."""
68
78
  if lib == "PT":
69
- state_dict = torch.load(path_weights, map_location=device)
79
+ state_dict = torch.load(os.fspath(path_weights), map_location=device)
70
80
  for key in list(state_dict.keys()):
71
81
  state_dict["model." + key] = state_dict.pop(key)
72
82
  doctr_predictor.load_state_dict(state_dict)
@@ -74,27 +84,27 @@ def _load_model(
74
84
  elif lib == "TF":
75
85
  # Unzip the archive
76
86
  params_path = Path(path_weights).parent
77
- is_zip_path = path_weights.endswith(".zip")
87
+ is_zip_path = os.fspath(path_weights).endswith(".zip")
78
88
  if is_zip_path:
79
89
  with ZipFile(path_weights, "r") as file:
80
90
  file.extractall(path=params_path)
81
91
  doctr_predictor.model.load_weights(params_path / "weights")
82
92
  else:
83
- doctr_predictor.model.load_weights(path_weights)
93
+ doctr_predictor.model.load_weights(os.fspath(path_weights))
84
94
 
85
95
 
86
96
  def auto_select_lib_for_doctr() -> Literal["PT", "TF"]:
87
97
  """Auto select the DL library from environment variables"""
88
- if os.environ.get("USE_TORCH"):
98
+ if os.environ.get("USE_TORCH", "0") in ENV_VARS_TRUE:
89
99
  return "PT"
90
- if os.environ.get("USE_TF"):
100
+ if os.environ.get("USE_TF", "0") in ENV_VARS_TRUE:
91
101
  return "TF"
92
102
  raise DependencyError("At least one of the env variables USE_TORCH or USE_TF must be set.")
93
103
 
94
104
 
95
105
  def doctr_predict_text_lines(
96
- np_img: ImageType, predictor: DetectionPredictor, device: Union[torch.device, tf.device], lib: Literal["TF", "PT"]
97
- ) -> List[DetectionResult]:
106
+ np_img: PixelValues, predictor: DetectionPredictor, device: Union[torch.device, tf.device], lib: Literal["TF", "PT"]
107
+ ) -> list[DetectionResult]:
98
108
  """
99
109
  Generating text line DetectionResult based on Doctr DetectionPredictor.
100
110
 
@@ -113,7 +123,7 @@ def doctr_predict_text_lines(
113
123
  raise DependencyError("Tensorflow or PyTorch must be installed.")
114
124
  detection_results = [
115
125
  DetectionResult(
116
- box=box[:4].tolist(), class_id=1, score=box[4], absolute_coords=False, class_name=LayoutType.word
126
+ box=box[:4].tolist(), class_id=1, score=box[4], absolute_coords=False, class_name=LayoutType.WORD
117
127
  )
118
128
  for box in raw_output[0]["words"]
119
129
  ]
@@ -121,11 +131,11 @@ def doctr_predict_text_lines(
121
131
 
122
132
 
123
133
  def doctr_predict_text(
124
- inputs: List[Tuple[str, ImageType]],
134
+ inputs: list[tuple[str, PixelValues]],
125
135
  predictor: RecognitionPredictor,
126
136
  device: Union[torch.device, tf.device],
127
137
  lib: Literal["TF", "PT"],
128
- ) -> List[DetectionResult]:
138
+ ) -> list[DetectionResult]:
129
139
  """
130
140
  Calls Doctr text recognition model on a batch of numpy arrays (text lines predicted from a text line detector) and
131
141
  returns the recognized text as DetectionResult
@@ -155,15 +165,15 @@ def doctr_predict_text(
155
165
  class DoctrTextlineDetectorMixin(ObjectDetector, ABC):
156
166
  """Base class for Doctr textline detector. This class only implements the basic wrapper functions"""
157
167
 
158
- def __init__(self, categories: Mapping[str, TypeOrStr], lib: Optional[Literal["PT", "TF"]] = None):
159
- self.categories = categories # type: ignore
168
+ def __init__(self, categories: Mapping[int, TypeOrStr], lib: Optional[Literal["PT", "TF"]] = None):
169
+ self.categories = ModelCategories(init_categories=categories)
160
170
  self.lib = lib if lib is not None else self.auto_select_lib()
161
171
 
162
- def possible_categories(self) -> List[ObjectTypes]:
163
- return [LayoutType.word]
172
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
173
+ return self.categories.get_categories(as_dict=False)
164
174
 
165
175
  @staticmethod
166
- def get_name(path_weights: str, architecture: str) -> str:
176
+ def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
167
177
  """Returns the name of the model"""
168
178
  return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
169
179
 
@@ -211,8 +221,8 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
211
221
  def __init__(
212
222
  self,
213
223
  architecture: str,
214
- path_weights: str,
215
- categories: Mapping[str, TypeOrStr],
224
+ path_weights: PathLikeOrStr,
225
+ categories: Mapping[int, TypeOrStr],
216
226
  device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
217
227
  lib: Optional[Literal["PT", "TF"]] = None,
218
228
  ) -> None:
@@ -227,7 +237,7 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
227
237
  """
228
238
  super().__init__(categories, lib)
229
239
  self.architecture = architecture
230
- self.path_weights = path_weights
240
+ self.path_weights = Path(path_weights)
231
241
 
232
242
  self.name = self.get_name(self.path_weights, self.architecture)
233
243
  self.model_id = self.get_model_id()
@@ -239,37 +249,37 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
239
249
 
240
250
  self.doctr_predictor = self.get_wrapped_model(self.architecture, self.path_weights, self.device, self.lib)
241
251
 
242
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
252
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
243
253
  """
244
254
  Prediction per image.
245
255
 
246
256
  :param np_img: image as numpy array
247
257
  :return: A list of DetectionResult
248
258
  """
249
- detection_results = doctr_predict_text_lines(np_img, self.doctr_predictor, self.device, self.lib)
250
- return detection_results
259
+ return doctr_predict_text_lines(np_img, self.doctr_predictor, self.device, self.lib)
251
260
 
252
261
  @classmethod
253
- def get_requirements(cls) -> List[Requirement]:
254
- if os.environ.get("DD_USE_TF"):
255
- return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
256
- if os.environ.get("DD_USE_TORCH"):
257
- return [get_pytorch_requirement(), get_doctr_requirement()]
258
- raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextlineDetector")
262
+ def get_requirements(cls) -> list[Requirement]:
263
+ return _get_doctr_requirements()
259
264
 
260
- def clone(self) -> PredictorBase:
261
- return self.__class__(self.architecture, self.path_weights, self.categories, self.device, self.lib)
265
+ def clone(self) -> DoctrTextlineDetector:
266
+ return self.__class__(
267
+ self.architecture, self.path_weights, self.categories.get_categories(), self.device, self.lib
268
+ )
262
269
 
263
270
  @staticmethod
264
271
  def load_model(
265
- path_weights: str, doctr_predictor: Any, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
272
+ path_weights: PathLikeOrStr,
273
+ doctr_predictor: DetectionPredictor,
274
+ device: Union[torch.device, tf.device],
275
+ lib: Literal["PT", "TF"],
266
276
  ) -> None:
267
277
  """Loading model weights"""
268
278
  _load_model(path_weights, doctr_predictor, device, lib)
269
279
 
270
280
  @staticmethod
271
281
  def get_wrapped_model(
272
- architecture: str, path_weights: str, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
282
+ architecture: str, path_weights: PathLikeOrStr, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
273
283
  ) -> Any:
274
284
  """
275
285
  Get the inner (wrapped) model.
@@ -290,6 +300,9 @@ class DoctrTextlineDetector(DoctrTextlineDetectorMixin):
290
300
  DoctrTextlineDetector.load_model(path_weights, doctr_predictor, device, lib)
291
301
  return doctr_predictor
292
302
 
303
+ def clear_model(self) -> None:
304
+ self.doctr_predictor = None
305
+
293
306
 
294
307
  class DoctrTextRecognizer(TextRecognizer):
295
308
  """
@@ -330,10 +343,10 @@ class DoctrTextRecognizer(TextRecognizer):
330
343
  def __init__(
331
344
  self,
332
345
  architecture: str,
333
- path_weights: str,
346
+ path_weights: PathLikeOrStr,
334
347
  device: Optional[Union[Literal["cpu", "cuda"], torch.device, tf.device]] = None,
335
348
  lib: Optional[Literal["PT", "TF"]] = None,
336
- path_config_json: Optional[str] = None,
349
+ path_config_json: Optional[PathLikeOrStr] = None,
337
350
  ) -> None:
338
351
  """
339
352
  :param architecture: DocTR supports various text recognition models, e.g. "crnn_vgg16_bn",
@@ -349,7 +362,7 @@ class DoctrTextRecognizer(TextRecognizer):
349
362
  self.lib = lib if lib is not None else self.auto_select_lib()
350
363
 
351
364
  self.architecture = architecture
352
- self.path_weights = path_weights
365
+ self.path_weights = Path(path_weights)
353
366
 
354
367
  self.name = self.get_name(self.path_weights, self.architecture)
355
368
  self.model_id = self.get_model_id()
@@ -360,13 +373,13 @@ class DoctrTextRecognizer(TextRecognizer):
360
373
  self.device = get_torch_device(device)
361
374
 
362
375
  self.path_config_json = path_config_json
363
- self.doctr_predictor = self.build_model(self.architecture, self.path_config_json)
376
+ self.doctr_predictor = self.build_model(self.architecture, self.lib, self.path_config_json)
364
377
  self.load_model(self.path_weights, self.doctr_predictor, self.device, self.lib)
365
378
  self.doctr_predictor = self.get_wrapped_model(
366
379
  self.architecture, self.path_weights, self.device, self.lib, self.path_config_json
367
380
  )
368
381
 
369
- def predict(self, images: List[Tuple[str, ImageType]]) -> List[DetectionResult]:
382
+ def predict(self, images: list[tuple[str, PixelValues]]) -> list[DetectionResult]:
370
383
  """
371
384
  Prediction on a batch of text lines
372
385
 
@@ -378,25 +391,26 @@ class DoctrTextRecognizer(TextRecognizer):
378
391
  return []
379
392
 
380
393
  @classmethod
381
- def get_requirements(cls) -> List[Requirement]:
382
- if tf_available():
383
- return [get_tensorflow_requirement(), get_doctr_requirement(), get_tf_addons_requirements()]
384
- if pytorch_available():
385
- return [get_pytorch_requirement(), get_doctr_requirement()]
386
- raise ModuleNotFoundError("Neither Tensorflow nor PyTorch has been installed. Cannot use DoctrTextRecognizer")
387
-
388
- def clone(self) -> PredictorBase:
394
+ def get_requirements(cls) -> list[Requirement]:
395
+ return _get_doctr_requirements()
396
+
397
+ def clone(self) -> DoctrTextRecognizer:
389
398
  return self.__class__(self.architecture, self.path_weights, self.device, self.lib)
390
399
 
391
400
  @staticmethod
392
401
  def load_model(
393
- path_weights: str, doctr_predictor: Any, device: Union[torch.device, tf.device], lib: Literal["PT", "TF"]
402
+ path_weights: PathLikeOrStr,
403
+ doctr_predictor: RecognitionPredictor,
404
+ device: Union[torch.device, tf.device],
405
+ lib: Literal["PT", "TF"],
394
406
  ) -> None:
395
407
  """Loading model weights"""
396
408
  _load_model(path_weights, doctr_predictor, device, lib)
397
409
 
398
410
  @staticmethod
399
- def build_model(architecture: str, path_config_json: Optional[str] = None) -> "RecognitionPredictor":
411
+ def build_model(
412
+ architecture: str, lib: Literal["TF", "PT"], path_config_json: Optional[PathLikeOrStr] = None
413
+ ) -> RecognitionPredictor:
400
414
  """Building the model"""
401
415
 
402
416
  # inspired and adapted from https://github.com/mindee/doctr/blob/main/doctr/models/recognition/zoo.py
@@ -419,6 +433,7 @@ class DoctrTextRecognizer(TextRecognizer):
419
433
 
420
434
  model = recognition.__dict__[architecture](pretrained=True, pretrained_backbone=True, **custom_configs)
421
435
  else:
436
+ # This is not documented, but you can also directly pass the model class to architecture
422
437
  if not isinstance(
423
438
  architecture,
424
439
  (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq),
@@ -426,16 +441,16 @@ class DoctrTextRecognizer(TextRecognizer):
426
441
  raise ValueError(f"unknown architecture: {type(architecture)}")
427
442
  model = architecture
428
443
 
429
- input_shape = model.cfg["input_shape"][:2] if tf_available() else model.cfg["input_shape"][-2:]
444
+ input_shape = model.cfg["input_shape"][:2] if lib == "TF" else model.cfg["input_shape"][-2:]
430
445
  return RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **recognition_configs), model)
431
446
 
432
447
  @staticmethod
433
448
  def get_wrapped_model(
434
449
  architecture: str,
435
- path_weights: str,
450
+ path_weights: PathLikeOrStr,
436
451
  device: Union[torch.device, tf.device],
437
452
  lib: Literal["PT", "TF"],
438
- path_config_json: Optional[str] = None,
453
+ path_config_json: Optional[PathLikeOrStr] = None,
439
454
  ) -> Any:
440
455
  """
441
456
  Get the inner (wrapped) model.
@@ -450,12 +465,12 @@ class DoctrTextRecognizer(TextRecognizer):
450
465
  a model trained on custom vocab.
451
466
  :return: Inner model which is a "nn.Module" in PyTorch or a "tf.keras.Model" in Tensorflow
452
467
  """
453
- doctr_predictor = DoctrTextRecognizer.build_model(architecture, path_config_json)
468
+ doctr_predictor = DoctrTextRecognizer.build_model(architecture, lib, path_config_json)
454
469
  DoctrTextRecognizer.load_model(path_weights, doctr_predictor, device, lib)
455
470
  return doctr_predictor
456
471
 
457
472
  @staticmethod
458
- def get_name(path_weights: str, architecture: str) -> str:
473
+ def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
459
474
  """Returns the name of the model"""
460
475
  return f"doctr_{architecture}" + "_".join(Path(path_weights).parts[-2:])
461
476
 
@@ -464,6 +479,9 @@ class DoctrTextRecognizer(TextRecognizer):
464
479
  """Auto select the DL library from the installed and from environment variables"""
465
480
  return auto_select_lib_for_doctr()
466
481
 
482
+ def clear_model(self) -> None:
483
+ self.doctr_predictor = None
484
+
467
485
 
468
486
  class DocTrRotationTransformer(ImageTransformer):
469
487
  """
@@ -497,7 +515,7 @@ class DocTrRotationTransformer(ImageTransformer):
497
515
  self.ratio_threshold_for_lines = ratio_threshold_for_lines
498
516
  self.name = "doctr_rotation_transformer"
499
517
 
500
- def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
518
+ def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
501
519
  """
502
520
  Applies the predicted rotation to the image, effectively rotating the image backwards.
503
521
  This method uses either the Pillow library or OpenCV for the rotation operation, depending on the configuration.
@@ -508,19 +526,18 @@ class DocTrRotationTransformer(ImageTransformer):
508
526
  """
509
527
  return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
510
528
 
511
- def predict(self, np_img: ImageType) -> DetectionResult:
529
+ def predict(self, np_img: PixelValues) -> DetectionResult:
512
530
  angle = estimate_orientation(np_img, self.number_contours, self.ratio_threshold_for_lines)
513
531
  if angle < 0:
514
532
  angle += 360
515
533
  return DetectionResult(angle=round(angle, 2))
516
534
 
517
535
  @classmethod
518
- def get_requirements(cls) -> List[Requirement]:
536
+ def get_requirements(cls) -> list[Requirement]:
519
537
  return [get_doctr_requirement()]
520
538
 
521
- def clone(self) -> PredictorBase:
539
+ def clone(self) -> DocTrRotationTransformer:
522
540
  return self.__class__(self.number_contours, self.ratio_threshold_for_lines)
523
541
 
524
- @staticmethod
525
- def possible_category() -> PageType:
526
- return PageType.angle
542
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
543
+ return (PageType.ANGLE,)
@@ -18,16 +18,20 @@
18
18
  """
19
19
  Deepdoctection wrappers for fasttext language detection models
20
20
  """
21
+ from __future__ import annotations
22
+
23
+ import os
21
24
  from abc import ABC
22
- from copy import copy
23
25
  from pathlib import Path
24
- from typing import Any, List, Mapping, Tuple, Union
26
+ from types import MappingProxyType
27
+ from typing import Any, Mapping, Union
25
28
 
26
29
  from lazy_imports import try_import
27
30
 
28
31
  from ..utils.file_utils import Requirement, get_fasttext_requirement
29
32
  from ..utils.settings import TypeOrStr, get_type
30
- from .base import DetectionResult, LanguageDetector, PredictorBase
33
+ from ..utils.types import PathLikeOrStr
34
+ from .base import DetectionResult, LanguageDetector, ModelCategories
31
35
 
32
36
  with try_import() as import_guard:
33
37
  from fasttext import load_model # type: ignore
@@ -38,22 +42,23 @@ class FasttextLangDetectorMixin(LanguageDetector, ABC):
38
42
  Base class for Fasttext language detection implementation. This class only implements the basic wrapper functions.
39
43
  """
40
44
 
41
- def __init__(self, categories: Mapping[str, TypeOrStr]) -> None:
45
+ def __init__(self, categories: Mapping[int, TypeOrStr], categories_orig: Mapping[str, TypeOrStr]) -> None:
42
46
  """
43
47
  :param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language
44
48
  """
45
- self.categories = copy({idx: get_type(cat) for idx, cat in categories.items()})
49
+ self.categories = ModelCategories(init_categories=categories)
50
+ self.categories_orig = MappingProxyType({cat_orig: get_type(cat) for cat_orig, cat in categories_orig.items()})
46
51
 
47
- def output_to_detection_result(self, output: Union[Tuple[Any, Any]]) -> DetectionResult:
52
+ def output_to_detection_result(self, output: Union[tuple[Any, Any]]) -> DetectionResult:
48
53
  """
49
54
  Generating `DetectionResult` from model output
50
55
  :param output: FastText model output
51
56
  :return: `DetectionResult` filled with `text` and `score`
52
57
  """
53
- return DetectionResult(text=self.categories[output[0][0]], score=output[1][0])
58
+ return DetectionResult(text=self.categories_orig[output[0][0]], score=output[1][0])
54
59
 
55
60
  @staticmethod
56
- def get_name(path_weights: str) -> str:
61
+ def get_name(path_weights: PathLikeOrStr) -> str:
57
62
  """Returns the name of the model"""
58
63
  return "fasttext_" + "_".join(Path(path_weights).parts[-2:])
59
64
 
@@ -80,15 +85,17 @@ class FasttextLangDetector(FasttextLangDetectorMixin):
80
85
 
81
86
  """
82
87
 
83
- def __init__(self, path_weights: str, categories: Mapping[str, TypeOrStr]):
88
+ def __init__(
89
+ self, path_weights: PathLikeOrStr, categories: Mapping[int, TypeOrStr], categories_orig: Mapping[str, TypeOrStr]
90
+ ):
84
91
  """
85
92
  :param path_weights: path to model weights
86
93
  :param categories: A dict with the model output label and value. We use as convention the ISO 639-2 language
87
94
  code.
88
95
  """
89
- super().__init__(categories)
96
+ super().__init__(categories, categories_orig)
90
97
 
91
- self.path_weights = path_weights
98
+ self.path_weights = Path(path_weights)
92
99
 
93
100
  self.name = self.get_name(self.path_weights)
94
101
  self.model_id = self.get_model_id()
@@ -100,16 +107,16 @@ class FasttextLangDetector(FasttextLangDetectorMixin):
100
107
  return self.output_to_detection_result(output)
101
108
 
102
109
  @classmethod
103
- def get_requirements(cls) -> List[Requirement]:
110
+ def get_requirements(cls) -> list[Requirement]:
104
111
  return [get_fasttext_requirement()]
105
112
 
106
- def clone(self) -> PredictorBase:
107
- return self.__class__(self.path_weights, self.categories)
113
+ def clone(self) -> FasttextLangDetector:
114
+ return self.__class__(self.path_weights, self.categories.get_categories(), self.categories_orig)
108
115
 
109
116
  @staticmethod
110
- def get_wrapped_model(path_weights: str) -> Any:
117
+ def get_wrapped_model(path_weights: PathLikeOrStr) -> Any:
111
118
  """
112
119
  Get the wrapped model
113
120
  :param path_weights: path to model weights
114
121
  """
115
- return load_model(path_weights)
122
+ return load_model(os.fspath(path_weights))