deepdoctection-0.31-py3-none-any.whl → deepdoctection-0.33-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of deepdoctection might be problematic.
Files changed (131)
  1. deepdoctection/__init__.py +16 -29
  2. deepdoctection/analyzer/dd.py +70 -59
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/common.py +9 -5
  5. deepdoctection/dataflow/custom.py +5 -5
  6. deepdoctection/dataflow/custom_serialize.py +75 -18
  7. deepdoctection/dataflow/parallel_map.py +3 -3
  8. deepdoctection/dataflow/serialize.py +4 -4
  9. deepdoctection/dataflow/stats.py +3 -3
  10. deepdoctection/datapoint/annotation.py +41 -56
  11. deepdoctection/datapoint/box.py +9 -8
  12. deepdoctection/datapoint/convert.py +6 -6
  13. deepdoctection/datapoint/image.py +56 -44
  14. deepdoctection/datapoint/view.py +245 -150
  15. deepdoctection/datasets/__init__.py +1 -4
  16. deepdoctection/datasets/adapter.py +35 -26
  17. deepdoctection/datasets/base.py +14 -12
  18. deepdoctection/datasets/dataflow_builder.py +3 -3
  19. deepdoctection/datasets/info.py +24 -26
  20. deepdoctection/datasets/instances/doclaynet.py +51 -51
  21. deepdoctection/datasets/instances/fintabnet.py +46 -46
  22. deepdoctection/datasets/instances/funsd.py +25 -24
  23. deepdoctection/datasets/instances/iiitar13k.py +13 -10
  24. deepdoctection/datasets/instances/layouttest.py +4 -3
  25. deepdoctection/datasets/instances/publaynet.py +5 -5
  26. deepdoctection/datasets/instances/pubtables1m.py +24 -21
  27. deepdoctection/datasets/instances/pubtabnet.py +32 -30
  28. deepdoctection/datasets/instances/rvlcdip.py +30 -30
  29. deepdoctection/datasets/instances/xfund.py +26 -26
  30. deepdoctection/datasets/save.py +6 -6
  31. deepdoctection/eval/__init__.py +1 -4
  32. deepdoctection/eval/accmetric.py +32 -33
  33. deepdoctection/eval/base.py +8 -9
  34. deepdoctection/eval/cocometric.py +15 -13
  35. deepdoctection/eval/eval.py +41 -37
  36. deepdoctection/eval/tedsmetric.py +30 -23
  37. deepdoctection/eval/tp_eval_callback.py +16 -19
  38. deepdoctection/extern/__init__.py +2 -7
  39. deepdoctection/extern/base.py +339 -134
  40. deepdoctection/extern/d2detect.py +85 -113
  41. deepdoctection/extern/deskew.py +14 -11
  42. deepdoctection/extern/doctrocr.py +141 -130
  43. deepdoctection/extern/fastlang.py +27 -18
  44. deepdoctection/extern/hfdetr.py +71 -62
  45. deepdoctection/extern/hflayoutlm.py +504 -211
  46. deepdoctection/extern/hflm.py +230 -0
  47. deepdoctection/extern/model.py +488 -302
  48. deepdoctection/extern/pdftext.py +23 -19
  49. deepdoctection/extern/pt/__init__.py +1 -3
  50. deepdoctection/extern/pt/nms.py +6 -2
  51. deepdoctection/extern/pt/ptutils.py +29 -19
  52. deepdoctection/extern/tessocr.py +39 -38
  53. deepdoctection/extern/texocr.py +18 -18
  54. deepdoctection/extern/tp/tfutils.py +57 -9
  55. deepdoctection/extern/tp/tpcompat.py +21 -14
  56. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
  60. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  61. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
  62. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
  67. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
  68. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  69. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  70. deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
  71. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  72. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  73. deepdoctection/extern/tpdetect.py +45 -53
  74. deepdoctection/mapper/__init__.py +3 -8
  75. deepdoctection/mapper/cats.py +27 -29
  76. deepdoctection/mapper/cocostruct.py +10 -10
  77. deepdoctection/mapper/d2struct.py +27 -26
  78. deepdoctection/mapper/hfstruct.py +13 -8
  79. deepdoctection/mapper/laylmstruct.py +178 -37
  80. deepdoctection/mapper/maputils.py +12 -11
  81. deepdoctection/mapper/match.py +2 -2
  82. deepdoctection/mapper/misc.py +11 -9
  83. deepdoctection/mapper/pascalstruct.py +4 -4
  84. deepdoctection/mapper/prodigystruct.py +5 -5
  85. deepdoctection/mapper/pubstruct.py +84 -92
  86. deepdoctection/mapper/tpstruct.py +5 -5
  87. deepdoctection/mapper/xfundstruct.py +33 -33
  88. deepdoctection/pipe/__init__.py +1 -1
  89. deepdoctection/pipe/anngen.py +12 -14
  90. deepdoctection/pipe/base.py +52 -106
  91. deepdoctection/pipe/common.py +72 -59
  92. deepdoctection/pipe/concurrency.py +16 -11
  93. deepdoctection/pipe/doctectionpipe.py +24 -21
  94. deepdoctection/pipe/language.py +20 -25
  95. deepdoctection/pipe/layout.py +20 -16
  96. deepdoctection/pipe/lm.py +75 -105
  97. deepdoctection/pipe/order.py +194 -89
  98. deepdoctection/pipe/refine.py +111 -124
  99. deepdoctection/pipe/segment.py +156 -161
  100. deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
  101. deepdoctection/pipe/text.py +37 -36
  102. deepdoctection/pipe/transform.py +19 -16
  103. deepdoctection/train/__init__.py +6 -12
  104. deepdoctection/train/d2_frcnn_train.py +48 -41
  105. deepdoctection/train/hf_detr_train.py +41 -30
  106. deepdoctection/train/hf_layoutlm_train.py +153 -135
  107. deepdoctection/train/tp_frcnn_train.py +32 -31
  108. deepdoctection/utils/concurrency.py +1 -1
  109. deepdoctection/utils/context.py +13 -6
  110. deepdoctection/utils/develop.py +4 -4
  111. deepdoctection/utils/env_info.py +87 -125
  112. deepdoctection/utils/file_utils.py +6 -11
  113. deepdoctection/utils/fs.py +22 -18
  114. deepdoctection/utils/identifier.py +2 -2
  115. deepdoctection/utils/logger.py +16 -15
  116. deepdoctection/utils/metacfg.py +7 -7
  117. deepdoctection/utils/mocks.py +93 -0
  118. deepdoctection/utils/pdf_utils.py +11 -11
  119. deepdoctection/utils/settings.py +185 -181
  120. deepdoctection/utils/tqdm.py +1 -1
  121. deepdoctection/utils/transform.py +14 -9
  122. deepdoctection/utils/types.py +104 -0
  123. deepdoctection/utils/utils.py +7 -7
  124. deepdoctection/utils/viz.py +74 -72
  125. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
  126. deepdoctection-0.33.dist-info/RECORD +146 -0
  127. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
  128. deepdoctection/utils/detection_types.py +0 -68
  129. deepdoctection-0.31.dist-info/RECORD +0 -144
  130. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
  131. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
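
Note on entries 122 and 128: `detection_types.py` (−68 lines) is removed and replaced by the new `types.py` (+104 lines). A minimal sketch of the corresponding import migration, with names taken from the hunks below (the `PixelValues` alias is assumed to cover what `ImageType` did):

```python
# 0.31: type aliases lived in deepdoctection.utils.detection_types
# from deepdoctection.utils.detection_types import ImageType, Requirement

# 0.33: the module is renamed to deepdoctection.utils.types and
# ImageType becomes PixelValues (see the d2detect.py hunks below)
from deepdoctection.utils.types import PixelValues, Requirement
```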
deepdoctection/extern/d2detect.py

@@ -18,43 +18,40 @@
 """
 D2 GeneralizedRCNN model as predictor for deepdoctection pipeline
 """
+from __future__ import annotations
+
 import io
+import os
 from abc import ABC
 from copy import copy
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence
+from typing import Literal, Mapping, Optional, Sequence, Union

 import numpy as np
+from lazy_imports import try_import

-from ..utils.detection_types import ImageType, Requirement
-from ..utils.file_utils import (
-    detectron2_available,
-    get_detectron2_requirement,
-    get_pytorch_requirement,
-    pytorch_available,
-)
+from ..utils.file_utils import get_detectron2_requirement, get_pytorch_requirement
 from ..utils.metacfg import AttrDict, set_config_by_yaml
-from ..utils.settings import ObjectTypes, TypeOrStr, get_type
+from ..utils.settings import DefaultType, ObjectTypes, TypeOrStr, get_type
 from ..utils.transform import InferenceResize, ResizeTransform
-from .base import DetectionResult, ObjectDetector, PredictorBase
+from ..utils.types import PathLikeOrStr, PixelValues, Requirement
+from .base import DetectionResult, ModelCategories, ObjectDetector
 from .pt.nms import batched_nms
-from .pt.ptutils import set_torch_auto_device
+from .pt.ptutils import get_torch_device

-if pytorch_available():
+with try_import() as pt_import_guard:
     import torch
     import torch.cuda
     from torch import nn  # pylint: disable=W0611

-if detectron2_available():
+with try_import() as d2_import_guard:
     from detectron2.checkpoint import DetectionCheckpointer
     from detectron2.config import CfgNode, get_cfg  # pylint: disable=W0611
     from detectron2.modeling import GeneralizedRCNN, build_model  # pylint: disable=W0611
     from detectron2.structures import Instances  # pylint: disable=W0611


-def _d2_post_processing(
-    predictions: Dict[str, "Instances"], nms_thresh_class_agnostic: float
-) -> Dict[str, "Instances"]:
+def _d2_post_processing(predictions: dict[str, Instances], nms_thresh_class_agnostic: float) -> dict[str, Instances]:
     """
     D2 postprocessing steps, so that detection outputs are aligned with outputs of other packages (e.g. Tensorpack).
     Apply a class agnostic NMS.
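
The guarded imports switch from `pytorch_available()` / `detectron2_available()` boolean checks to `lazy_imports.try_import`. A minimal sketch of the pattern, assuming the guard behaves like the one in the lazy-imports package (swallow the `ImportError` at import time, re-raise it on demand):

```python
from lazy_imports import try_import

with try_import() as torch_import_guard:
    import torch  # an ImportError is swallowed here if torch is not installed

def run_inference():
    # re-raises the deferred ImportError with an explanatory message if torch is missing
    torch_import_guard.check()
    ...
```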
@@ -71,11 +68,11 @@ def _d2_post_processing(


 def d2_predict_image(
-    np_img: ImageType,
-    predictor: "nn.Module",
+    np_img: PixelValues,
+    predictor: nn.Module,
     resizer: InferenceResize,
     nms_thresh_class_agnostic: float,
-) -> List[DetectionResult]:
+) -> list[DetectionResult]:
     """
     Run detection on one image, using the D2 model callable. It will also handle the preprocessing internally which
     is using a custom resizing within some bounds.
@@ -107,8 +104,8 @@ def d2_predict_image(


 def d2_jit_predict_image(
-    np_img: ImageType, d2_predictor: "nn.Module", resizer: InferenceResize, nms_thresh_class_agnostic: float
-) -> List[DetectionResult]:
+    np_img: PixelValues, d2_predictor: nn.Module, resizer: InferenceResize, nms_thresh_class_agnostic: float
+) -> list[DetectionResult]:
     """
     Run detection on an image using torchscript. It will also handle the preprocessing internally which
     is using a custom resizing within some bounds. Moreover, and different from the setting where Detectron2 is used
@@ -152,7 +149,7 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):

     def __init__(
         self,
-        categories: Mapping[str, TypeOrStr],
+        categories: Mapping[int, TypeOrStr],
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """
@@ -163,37 +160,31 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
                                   be filtered. Pass a list of category names that must not be returned
         """

+        self.categories = ModelCategories(init_categories=categories)
         if filter_categories:
-            filter_categories = [get_type(cat) for cat in filter_categories]
-        self.filter_categories = filter_categories
-        self._categories_d2 = self._map_to_d2_categories(copy(categories))
-        self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
+            self.categories.filter_categories = tuple(get_type(cat) for cat in filter_categories)

-    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
+    def _map_category_names(self, detection_results: list[DetectionResult]) -> list[DetectionResult]:
         """
         Populating category names to detection results

         :param detection_results: list of detection results. Will also filter categories
         :return: List of detection results with attribute class_name populated
         """
-        filtered_detection_result: List[DetectionResult] = []
+        filtered_detection_result: list[DetectionResult] = []
+        shifted_categories = self.categories.shift_category_ids(shift_by=-1)
         for result in detection_results:
-            result.class_name = self._categories_d2[str(result.class_id)]
-            if isinstance(result.class_id, int):
-                result.class_id += 1
-            if self.filter_categories:
-                if result.class_name not in self.filter_categories:
+            result.class_name = shifted_categories.get(
+                result.class_id if result.class_id is not None else -1, DefaultType.DEFAULT_TYPE
+            )
+            if result.class_name != DefaultType.DEFAULT_TYPE:
+                if result.class_id is not None:
+                    result.class_id += 1
                 filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
         return filtered_detection_result

-    @classmethod
-    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
-        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
-
-    def possible_categories(self) -> List[ObjectTypes]:
-        return list(self.categories.values())
+    def get_category_names(self) -> tuple[ObjectTypes, ...]:
+        return self.categories.get_categories(as_dict=False)

     @staticmethod
     def get_inference_resizer(min_size_test: int, max_size_test: int) -> InferenceResize:
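
The rewritten `_map_category_names` delegates the zero-/one-based id bookkeeping to `ModelCategories.shift_category_ids` instead of the removed `_map_to_d2_categories` dict. A hypothetical mimic of that bookkeeping (helper name and dict shape assumed from the hunk, not taken from the library):

```python
# deepdoctection categories are 1-based, Detectron2 class ids are 0-based
categories = {1: "text", 2: "title", 3: "table"}  # illustrative names

def shift_category_ids(cats: dict[int, str], shift_by: int) -> dict[int, str]:
    return {category_id + shift_by: name for category_id, name in cats.items()}

shifted = shift_category_ids(categories, shift_by=-1)  # {0: "text", 1: "title", 2: "table"}
print(shifted.get(0, "DEFAULT_TYPE"))  # "text"; unknown ids fall back to the sentinel
```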
@@ -205,7 +196,7 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
         return InferenceResize(min_size_test, max_size_test)

     @staticmethod
-    def get_name(path_weights: str, architecture: str) -> str:
+    def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
         """Returns the name of the model"""
         return f"detectron2_{architecture}" + "_".join(Path(path_weights).parts[-2:])

@@ -234,11 +225,11 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):

     def __init__(
         self,
-        path_yaml: str,
-        path_weights: str,
-        categories: Mapping[str, TypeOrStr],
-        config_overwrite: Optional[List[str]] = None,
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        path_yaml: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        config_overwrite: Optional[list[str]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """
@@ -261,18 +252,15 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         """
         super().__init__(categories, filter_categories)

-        self.path_weights = path_weights
-        self.path_yaml = path_yaml
+        self.path_weights = Path(path_weights)
+        self.path_yaml = Path(path_yaml)

         config_overwrite = config_overwrite if config_overwrite else []
         self.config_overwrite = config_overwrite
-        if device is not None:
-            self.device = device
-        else:
-            self.device = set_torch_auto_device()
+        self.device = get_torch_device(device)

         d2_conf_list = self._get_d2_config_list(path_weights, config_overwrite)
-        self.cfg = self._set_config(path_yaml, d2_conf_list, device)
+        self.cfg = self._set_config(path_yaml, d2_conf_list, self.device)

         self.name = self.get_name(path_weights, self.cfg.MODEL.META_ARCHITECTURE)
         self.model_id = self.get_model_id()
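
The four-line device branch collapses into a single `get_torch_device(device)` call, and the argument may now be a `torch.device` as well as a string. A sketch of plausible resolution logic (the real implementation lives in `extern/pt/ptutils.py` and may differ):

```python
from typing import Literal, Optional, Union

import torch

def get_torch_device(
    device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None
) -> torch.device:
    """Normalize strings and torch.device instances; fall back to CUDA if available."""
    if device is not None:
        return torch.device(device)  # torch.device() accepts both forms
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```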
@@ -282,21 +270,18 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)

     @staticmethod
-    def _set_config(
-        path_yaml: str, d2_conf_list: List[str], device: Optional[Literal["cpu", "cuda"]] = None
-    ) -> "CfgNode":
+    def _set_config(path_yaml: PathLikeOrStr, d2_conf_list: list[str], device: torch.device) -> CfgNode:
         cfg = get_cfg()
         # additional attribute with default value, so that the true value can be loaded from the configs
         cfg.NMS_THRESH_CLASS_AGNOSTIC = 0.1
-        cfg.merge_from_file(path_yaml)
+        cfg.merge_from_file(os.fspath(path_yaml))
         cfg.merge_from_list(d2_conf_list)
-        if not torch.cuda.is_available() or device == "cpu":
-            cfg.MODEL.DEVICE = "cpu"
+        cfg.MODEL.DEVICE = str(device)
         cfg.freeze()
         return cfg

     @staticmethod
-    def _set_model(config: "CfgNode") -> "GeneralizedRCNN":
+    def _set_model(config: CfgNode) -> GeneralizedRCNN:
         """
         Build the D2 model. It uses the available builtin tools of D2

@@ -306,11 +291,11 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         return build_model(config.clone()).eval()

     @staticmethod
-    def _instantiate_d2_predictor(wrapped_model: "GeneralizedRCNN", path_weights: str) -> None:
+    def _instantiate_d2_predictor(wrapped_model: GeneralizedRCNN, path_weights: PathLikeOrStr) -> None:
         checkpointer = DetectionCheckpointer(wrapped_model)
-        checkpointer.load(path_weights)
+        checkpointer.load(os.fspath(path_weights))

-    def predict(self, np_img: ImageType) -> List[DetectionResult]:
+    def predict(self, np_img: PixelValues) -> list[DetectionResult]:
         """
         Prediction per image.

@@ -326,23 +311,26 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         return self._map_category_names(detection_results)

     @classmethod
-    def get_requirements(cls) -> List[Requirement]:
+    def get_requirements(cls) -> list[Requirement]:
         return [get_pytorch_requirement(), get_detectron2_requirement()]

-    def clone(self) -> PredictorBase:
+    def clone(self) -> D2FrcnnDetector:
         return self.__class__(
             self.path_yaml,
             self.path_weights,
-            self.categories,
+            self.categories.get_categories(),
             self.config_overwrite,
             self.device,
-            self.filter_categories,
+            self.categories.filter_categories,
         )

     @staticmethod
     def get_wrapped_model(
-        path_yaml: str, path_weights: str, config_overwrite: List[str], device: Literal["cpu", "cuda"]
-    ) -> "GeneralizedRCNN":
+        path_yaml: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        config_overwrite: list[str],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+    ) -> GeneralizedRCNN:
         """
         Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.

@@ -365,8 +353,7 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         :return: Detectron2 GeneralizedRCNN model
         """

-        if device is None:
-            device = set_torch_auto_device()
+        device = get_torch_device(device)
         d2_conf_list = D2FrcnnDetector._get_d2_config_list(path_weights, config_overwrite)
         cfg = D2FrcnnDetector._set_config(path_yaml, d2_conf_list, device)
         model = D2FrcnnDetector._set_model(cfg)
@@ -374,14 +361,17 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
         return model

     @staticmethod
-    def _get_d2_config_list(path_weights: str, config_overwrite: List[str]) -> List[str]:
-        d2_conf_list = ["MODEL.WEIGHTS", path_weights]
+    def _get_d2_config_list(path_weights: PathLikeOrStr, config_overwrite: list[str]) -> list[str]:
+        d2_conf_list = ["MODEL.WEIGHTS", os.fspath(path_weights)]
         config_overwrite = config_overwrite if config_overwrite else []
         for conf in config_overwrite:
             key, val = conf.split("=", maxsplit=1)
             d2_conf_list.extend([key, val])
         return d2_conf_list

+    def clear_model(self) -> None:
+        self.d2_predictor = None
+

 class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
     """
@@ -409,10 +399,10 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):

     def __init__(
         self,
-        path_yaml: str,
-        path_weights: str,
-        categories: Mapping[str, TypeOrStr],
-        config_overwrite: Optional[List[str]] = None,
+        path_yaml: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        config_overwrite: Optional[list[str]] = None,
         filter_categories: Optional[Sequence[TypeOrStr]] = None,
     ):
         """
@@ -432,8 +422,8 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):

         super().__init__(categories, filter_categories)

-        self.path_weights = path_weights
-        self.path_yaml = path_yaml
+        self.path_weights = Path(path_weights)
+        self.path_yaml = Path(path_yaml)

         self.config_overwrite = copy(config_overwrite)
         self.cfg = self._set_config(self.path_yaml, self.path_weights, self.config_overwrite)
@@ -445,14 +435,16 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
         self.d2_predictor = self.get_wrapped_model(self.path_weights)

     @staticmethod
-    def _set_config(path_yaml: str, path_weights: str, config_overwrite: Optional[List[str]]) -> AttrDict:
+    def _set_config(
+        path_yaml: PathLikeOrStr, path_weights: PathLikeOrStr, config_overwrite: Optional[list[str]]
+    ) -> AttrDict:
         cfg = set_config_by_yaml(path_yaml)
         config_overwrite = config_overwrite if config_overwrite else []
-        config_overwrite.extend([f"MODEL.WEIGHTS={path_weights}"])
+        config_overwrite.extend([f"MODEL.WEIGHTS={os.fspath(path_weights)}"])
         cfg.update_args(config_overwrite)
         return cfg

-    def predict(self, np_img: ImageType) -> List[DetectionResult]:
+    def predict(self, np_img: PixelValues) -> list[DetectionResult]:
         """
         Prediction per image.

@@ -468,46 +460,23 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
         return self._map_category_names(detection_results)

     @classmethod
-    def get_requirements(cls) -> List[Requirement]:
+    def get_requirements(cls) -> list[Requirement]:
         return [get_pytorch_requirement()]

-    @classmethod
-    def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
-        return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
-
-    def clone(self) -> PredictorBase:
+    def clone(self) -> D2FrcnnTracingDetector:
         return self.__class__(
             self.path_yaml,
             self.path_weights,
-            self.categories,
+            self.categories.get_categories(),
             self.config_overwrite,
-            self.filter_categories,
+            self.categories.filter_categories,
         )

-    def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
-        """
-        Populating category names to detection results
-
-        :param detection_results: list of detection results. Will also filter categories
-        :return: List of detection results with attribute class_name populated
-        """
-        filtered_detection_result: List[DetectionResult] = []
-        for result in detection_results:
-            result.class_name = self._categories_d2[str(result.class_id)]
-            if isinstance(result.class_id, int):
-                result.class_id += 1
-            if self.filter_categories:
-                if result.class_name not in self.filter_categories:
-                    filtered_detection_result.append(result)
-            else:
-                filtered_detection_result.append(result)
-        return filtered_detection_result
-
-    def possible_categories(self) -> List[ObjectTypes]:
-        return list(self.categories.values())
+    def get_category_names(self) -> tuple[ObjectTypes, ...]:
+        return self.categories.get_categories(as_dict=False)

     @staticmethod
-    def get_wrapped_model(path_weights: str) -> Any:
+    def get_wrapped_model(path_weights: PathLikeOrStr) -> torch.jit.ScriptModule:
         """
         Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.

@@ -518,3 +487,6 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
         buffer = io.BytesIO(file.read())
         # Load all tensors to the original device
         return torch.jit.load(buffer)
+
+    def clear_model(self) -> None:
+        self.d2_predictor = None  # type: ignore
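
Taken together, the d2detect.py hunks change the public constructor contract: integer-keyed categories, path-like arguments, an optional `torch.device`, and a new `clear_model` hook. A hedged usage sketch (paths and category names are placeholders, not shipped defaults):

```python
import numpy as np
import torch

from deepdoctection.extern.d2detect import D2FrcnnDetector

detector = D2FrcnnDetector(
    path_yaml="configs/d2/frcnn.yaml",               # placeholder path
    path_weights="weights/model_final.pt",           # placeholder path
    categories={1: "text", 2: "title", 3: "table"},  # int keys in 0.33 (str keys in 0.31)
    device=torch.device("cuda") if torch.cuda.is_available() else "cpu",
)

np_img = np.zeros((800, 600, 3), dtype=np.uint8)  # stand-in page image
results = detector.predict(np_img)                # list[DetectionResult]
detector.clear_model()                            # new in 0.33: drop the model reference
```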
deepdoctection/extern/deskew.py

@@ -18,16 +18,17 @@
 """
 jdeskew estimator and rotator to deskew images: <https://github.com/phamquiluan/jdeskew>
 """
+from __future__ import annotations

-from typing import List
+from lazy_imports import try_import

-from ..utils.detection_types import ImageType, Requirement
-from ..utils.file_utils import get_jdeskew_requirement, jdeskew_available
-from ..utils.settings import PageType
+from ..utils.file_utils import get_jdeskew_requirement
+from ..utils.settings import ObjectTypes, PageType
+from ..utils.types import PixelValues, Requirement
 from ..utils.viz import viz_handler
 from .base import DetectionResult, ImageTransformer

-if jdeskew_available():
+with try_import() as import_guard:
     from jdeskew.estimator import get_angle


@@ -42,7 +43,7 @@ class Jdeskewer(ImageTransformer):
         self.model_id = self.get_model_id()
         self.min_angle_rotation = min_angle_rotation

-    def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
+    def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
         """
         Rotation of the image according to the angle determined by the jdeskew estimator.

@@ -59,7 +60,7 @@ class Jdeskewer(ImageTransformer):
             return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore
         return np_img

-    def predict(self, np_img: ImageType) -> DetectionResult:
+    def predict(self, np_img: PixelValues) -> DetectionResult:
         """
         Predict the angle of the image to deskew it.

@@ -69,12 +70,14 @@ class Jdeskewer(ImageTransformer):
         return DetectionResult(angle=round(float(get_angle(np_img)), 4))

     @classmethod
-    def get_requirements(cls) -> List[Requirement]:
+    def get_requirements(cls) -> list[Requirement]:
         """
         Get a list of requirements for running the detector
         """
         return [get_jdeskew_requirement()]

-    @staticmethod
-    def possible_category() -> PageType:
-        return PageType.angle
+    def clone(self) -> Jdeskewer:
+        return self.__class__(self.min_angle_rotation)
+
+    def get_category_names(self) -> tuple[ObjectTypes, ...]:
+        return (PageType.ANGLE,)
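
The deskew.py changes mirror the detector API migration: `possible_category()` (a single `PageType`) becomes `get_category_names()` (a tuple of `ObjectTypes`), and a `clone()` is added. A minimal usage sketch (the blank image and the `min_angle_rotation` value are stand-ins; the constructor may have a default for that argument):

```python
import numpy as np

from deepdoctection.extern.deskew import Jdeskewer

deskewer = Jdeskewer(min_angle_rotation=2.0)           # threshold below which no rotation happens
np_img = np.full((600, 400, 3), 255, dtype=np.uint8)   # stand-in page image

detection = deskewer.predict(np_img)                   # DetectionResult with .angle populated
deskewed = deskewer.transform(np_img, detection)       # rotates only above min_angle_rotation
print(deskewer.get_category_names())                   # (PageType.ANGLE,) as of 0.33
```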