deepdoctection-0.32-py3-none-any.whl → deepdoctection-0.34-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic; see the registry's advisory page for this release for more details.

Files changed (111)
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -21,20 +21,21 @@ D2 GeneralizedRCNN model as predictor for deepdoctection pipeline
21
21
  from __future__ import annotations
22
22
 
23
23
  import io
24
+ import os
24
25
  from abc import ABC
25
26
  from copy import copy
26
27
  from pathlib import Path
27
- from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Union
28
+ from typing import Literal, Mapping, Optional, Sequence, Union
28
29
 
29
30
  import numpy as np
30
31
  from lazy_imports import try_import
31
32
 
32
- from ..utils.detection_types import ImageType, Requirement
33
33
  from ..utils.file_utils import get_detectron2_requirement, get_pytorch_requirement
34
34
  from ..utils.metacfg import AttrDict, set_config_by_yaml
35
- from ..utils.settings import ObjectTypes, TypeOrStr, get_type
35
+ from ..utils.settings import DefaultType, ObjectTypes, TypeOrStr, get_type
36
36
  from ..utils.transform import InferenceResize, ResizeTransform
37
- from .base import DetectionResult, ObjectDetector, PredictorBase
37
+ from ..utils.types import PathLikeOrStr, PixelValues, Requirement
38
+ from .base import DetectionResult, ModelCategories, ObjectDetector
38
39
  from .pt.nms import batched_nms
39
40
  from .pt.ptutils import get_torch_device
40
41
 
@@ -50,7 +51,7 @@ with try_import() as d2_import_guard:
50
51
  from detectron2.structures import Instances # pylint: disable=W0611
51
52
 
52
53
 
53
- def _d2_post_processing(predictions: Dict[str, Instances], nms_thresh_class_agnostic: float) -> Dict[str, Instances]:
54
+ def _d2_post_processing(predictions: dict[str, Instances], nms_thresh_class_agnostic: float) -> dict[str, Instances]:
54
55
  """
55
56
  D2 postprocessing steps, so that detection outputs are aligned with outputs of other packages (e.g. Tensorpack).
56
57
  Apply a class agnostic NMS.
@@ -67,11 +68,11 @@ def _d2_post_processing(predictions: Dict[str, Instances], nms_thresh_class_agno
67
68
 
68
69
 
69
70
  def d2_predict_image(
70
- np_img: ImageType,
71
+ np_img: PixelValues,
71
72
  predictor: nn.Module,
72
73
  resizer: InferenceResize,
73
74
  nms_thresh_class_agnostic: float,
74
- ) -> List[DetectionResult]:
75
+ ) -> list[DetectionResult]:
75
76
  """
76
77
  Run detection on one image, using the D2 model callable. It will also handle the preprocessing internally which
77
78
  is using a custom resizing within some bounds.
@@ -103,8 +104,8 @@ def d2_predict_image(
103
104
 
104
105
 
105
106
  def d2_jit_predict_image(
106
- np_img: ImageType, d2_predictor: nn.Module, resizer: InferenceResize, nms_thresh_class_agnostic: float
107
- ) -> List[DetectionResult]:
107
+ np_img: PixelValues, d2_predictor: nn.Module, resizer: InferenceResize, nms_thresh_class_agnostic: float
108
+ ) -> list[DetectionResult]:
108
109
  """
109
110
  Run detection on an image using torchscript. It will also handle the preprocessing internally which
110
111
  is using a custom resizing within some bounds. Moreover, and different from the setting where Detectron2 is used
@@ -148,7 +149,7 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
148
149
 
149
150
  def __init__(
150
151
  self,
151
- categories: Mapping[str, TypeOrStr],
152
+ categories: Mapping[int, TypeOrStr],
152
153
  filter_categories: Optional[Sequence[TypeOrStr]] = None,
153
154
  ):
154
155
  """
@@ -159,37 +160,31 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
159
160
  be filtered. Pass a list of category names that must not be returned
160
161
  """
161
162
 
163
+ self.categories = ModelCategories(init_categories=categories)
162
164
  if filter_categories:
163
- filter_categories = [get_type(cat) for cat in filter_categories]
164
- self.filter_categories = filter_categories
165
- self._categories_d2 = self._map_to_d2_categories(copy(categories))
166
- self.categories = {idx: get_type(cat) for idx, cat in categories.items()}
165
+ self.categories.filter_categories = tuple(get_type(cat) for cat in filter_categories)
167
166
 
168
- def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
167
+ def _map_category_names(self, detection_results: list[DetectionResult]) -> list[DetectionResult]:
169
168
  """
170
169
  Populating category names to detection results
171
170
 
172
171
  :param detection_results: list of detection results. Will also filter categories
173
172
  :return: List of detection results with attribute class_name populated
174
173
  """
175
- filtered_detection_result: List[DetectionResult] = []
174
+ filtered_detection_result: list[DetectionResult] = []
175
+ shifted_categories = self.categories.shift_category_ids(shift_by=-1)
176
176
  for result in detection_results:
177
- result.class_name = self._categories_d2[str(result.class_id)]
178
- if isinstance(result.class_id, int):
179
- result.class_id += 1
180
- if self.filter_categories:
181
- if result.class_name not in self.filter_categories:
177
+ result.class_name = shifted_categories.get(
178
+ result.class_id if result.class_id is not None else -1, DefaultType.DEFAULT_TYPE
179
+ )
180
+ if result.class_name != DefaultType.DEFAULT_TYPE:
181
+ if result.class_id is not None:
182
+ result.class_id += 1
182
183
  filtered_detection_result.append(result)
183
- else:
184
- filtered_detection_result.append(result)
185
184
  return filtered_detection_result
186
185
 
187
- @classmethod
188
- def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
189
- return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
190
-
191
- def possible_categories(self) -> List[ObjectTypes]:
192
- return list(self.categories.values())
186
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
187
+ return self.categories.get_categories(as_dict=False)
193
188
 
194
189
  @staticmethod
195
190
  def get_inference_resizer(min_size_test: int, max_size_test: int) -> InferenceResize:
@@ -201,7 +196,7 @@ class D2FrcnnDetectorMixin(ObjectDetector, ABC):
201
196
  return InferenceResize(min_size_test, max_size_test)
202
197
 
203
198
  @staticmethod
204
- def get_name(path_weights: str, architecture: str) -> str:
199
+ def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
205
200
  """Returns the name of the model"""
206
201
  return f"detectron2_{architecture}" + "_".join(Path(path_weights).parts[-2:])
207
202
 
@@ -230,10 +225,10 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
230
225
 
231
226
  def __init__(
232
227
  self,
233
- path_yaml: str,
234
- path_weights: str,
235
- categories: Mapping[str, TypeOrStr],
236
- config_overwrite: Optional[List[str]] = None,
228
+ path_yaml: PathLikeOrStr,
229
+ path_weights: PathLikeOrStr,
230
+ categories: Mapping[int, TypeOrStr],
231
+ config_overwrite: Optional[list[str]] = None,
237
232
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
238
233
  filter_categories: Optional[Sequence[TypeOrStr]] = None,
239
234
  ):
@@ -257,8 +252,8 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
257
252
  """
258
253
  super().__init__(categories, filter_categories)
259
254
 
260
- self.path_weights = path_weights
261
- self.path_yaml = path_yaml
255
+ self.path_weights = Path(path_weights)
256
+ self.path_yaml = Path(path_yaml)
262
257
 
263
258
  config_overwrite = config_overwrite if config_overwrite else []
264
259
  self.config_overwrite = config_overwrite
@@ -275,11 +270,11 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
275
270
  self.resizer = self.get_inference_resizer(self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MAX_SIZE_TEST)
276
271
 
277
272
  @staticmethod
278
- def _set_config(path_yaml: str, d2_conf_list: List[str], device: torch.device) -> CfgNode:
273
+ def _set_config(path_yaml: PathLikeOrStr, d2_conf_list: list[str], device: torch.device) -> CfgNode:
279
274
  cfg = get_cfg()
280
275
  # additional attribute with default value, so that the true value can be loaded from the configs
281
276
  cfg.NMS_THRESH_CLASS_AGNOSTIC = 0.1
282
- cfg.merge_from_file(path_yaml)
277
+ cfg.merge_from_file(os.fspath(path_yaml))
283
278
  cfg.merge_from_list(d2_conf_list)
284
279
  cfg.MODEL.DEVICE = str(device)
285
280
  cfg.freeze()
@@ -296,11 +291,11 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
296
291
  return build_model(config.clone()).eval()
297
292
 
298
293
  @staticmethod
299
- def _instantiate_d2_predictor(wrapped_model: GeneralizedRCNN, path_weights: str) -> None:
294
+ def _instantiate_d2_predictor(wrapped_model: GeneralizedRCNN, path_weights: PathLikeOrStr) -> None:
300
295
  checkpointer = DetectionCheckpointer(wrapped_model)
301
- checkpointer.load(path_weights)
296
+ checkpointer.load(os.fspath(path_weights))
302
297
 
303
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
298
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
304
299
  """
305
300
  Prediction per image.
306
301
 
@@ -316,24 +311,24 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
316
311
  return self._map_category_names(detection_results)
317
312
 
318
313
  @classmethod
319
- def get_requirements(cls) -> List[Requirement]:
314
+ def get_requirements(cls) -> list[Requirement]:
320
315
  return [get_pytorch_requirement(), get_detectron2_requirement()]
321
316
 
322
- def clone(self) -> PredictorBase:
317
+ def clone(self) -> D2FrcnnDetector:
323
318
  return self.__class__(
324
319
  self.path_yaml,
325
320
  self.path_weights,
326
- self.categories,
321
+ self.categories.get_categories(),
327
322
  self.config_overwrite,
328
323
  self.device,
329
- self.filter_categories,
324
+ self.categories.filter_categories,
330
325
  )
331
326
 
332
327
  @staticmethod
333
328
  def get_wrapped_model(
334
- path_yaml: str,
335
- path_weights: str,
336
- config_overwrite: List[str],
329
+ path_yaml: PathLikeOrStr,
330
+ path_weights: PathLikeOrStr,
331
+ config_overwrite: list[str],
337
332
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
338
333
  ) -> GeneralizedRCNN:
339
334
  """
@@ -366,14 +361,17 @@ class D2FrcnnDetector(D2FrcnnDetectorMixin):
366
361
  return model
367
362
 
368
363
  @staticmethod
369
- def _get_d2_config_list(path_weights: str, config_overwrite: List[str]) -> List[str]:
370
- d2_conf_list = ["MODEL.WEIGHTS", path_weights]
364
+ def _get_d2_config_list(path_weights: PathLikeOrStr, config_overwrite: list[str]) -> list[str]:
365
+ d2_conf_list = ["MODEL.WEIGHTS", os.fspath(path_weights)]
371
366
  config_overwrite = config_overwrite if config_overwrite else []
372
367
  for conf in config_overwrite:
373
368
  key, val = conf.split("=", maxsplit=1)
374
369
  d2_conf_list.extend([key, val])
375
370
  return d2_conf_list
376
371
 
372
+ def clear_model(self) -> None:
373
+ self.d2_predictor = None
374
+
377
375
 
378
376
  class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
379
377
  """
@@ -401,10 +399,10 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
401
399
 
402
400
  def __init__(
403
401
  self,
404
- path_yaml: str,
405
- path_weights: str,
406
- categories: Mapping[str, TypeOrStr],
407
- config_overwrite: Optional[List[str]] = None,
402
+ path_yaml: PathLikeOrStr,
403
+ path_weights: PathLikeOrStr,
404
+ categories: Mapping[int, TypeOrStr],
405
+ config_overwrite: Optional[list[str]] = None,
408
406
  filter_categories: Optional[Sequence[TypeOrStr]] = None,
409
407
  ):
410
408
  """
@@ -424,8 +422,8 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
424
422
 
425
423
  super().__init__(categories, filter_categories)
426
424
 
427
- self.path_weights = path_weights
428
- self.path_yaml = path_yaml
425
+ self.path_weights = Path(path_weights)
426
+ self.path_yaml = Path(path_yaml)
429
427
 
430
428
  self.config_overwrite = copy(config_overwrite)
431
429
  self.cfg = self._set_config(self.path_yaml, self.path_weights, self.config_overwrite)
@@ -437,14 +435,16 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
437
435
  self.d2_predictor = self.get_wrapped_model(self.path_weights)
438
436
 
439
437
  @staticmethod
440
- def _set_config(path_yaml: str, path_weights: str, config_overwrite: Optional[List[str]]) -> AttrDict:
438
+ def _set_config(
439
+ path_yaml: PathLikeOrStr, path_weights: PathLikeOrStr, config_overwrite: Optional[list[str]]
440
+ ) -> AttrDict:
441
441
  cfg = set_config_by_yaml(path_yaml)
442
442
  config_overwrite = config_overwrite if config_overwrite else []
443
- config_overwrite.extend([f"MODEL.WEIGHTS={path_weights}"])
443
+ config_overwrite.extend([f"MODEL.WEIGHTS={os.fspath(path_weights)}"])
444
444
  cfg.update_args(config_overwrite)
445
445
  return cfg
446
446
 
447
- def predict(self, np_img: ImageType) -> List[DetectionResult]:
447
+ def predict(self, np_img: PixelValues) -> list[DetectionResult]:
448
448
  """
449
449
  Prediction per image.
450
450
 
@@ -460,46 +460,23 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
460
460
  return self._map_category_names(detection_results)
461
461
 
462
462
  @classmethod
463
- def get_requirements(cls) -> List[Requirement]:
463
+ def get_requirements(cls) -> list[Requirement]:
464
464
  return [get_pytorch_requirement()]
465
465
 
466
- @classmethod
467
- def _map_to_d2_categories(cls, categories: Mapping[str, TypeOrStr]) -> Dict[str, ObjectTypes]:
468
- return {str(int(k) - 1): get_type(v) for k, v in categories.items()}
469
-
470
- def clone(self) -> PredictorBase:
466
+ def clone(self) -> D2FrcnnTracingDetector:
471
467
  return self.__class__(
472
468
  self.path_yaml,
473
469
  self.path_weights,
474
- self.categories,
470
+ self.categories.get_categories(),
475
471
  self.config_overwrite,
476
- self.filter_categories,
472
+ self.categories.filter_categories,
477
473
  )
478
474
 
479
- def _map_category_names(self, detection_results: List[DetectionResult]) -> List[DetectionResult]:
480
- """
481
- Populating category names to detection results
482
-
483
- :param detection_results: list of detection results. Will also filter categories
484
- :return: List of detection results with attribute class_name populated
485
- """
486
- filtered_detection_result: List[DetectionResult] = []
487
- for result in detection_results:
488
- result.class_name = self._categories_d2[str(result.class_id)]
489
- if isinstance(result.class_id, int):
490
- result.class_id += 1
491
- if self.filter_categories:
492
- if result.class_name not in self.filter_categories:
493
- filtered_detection_result.append(result)
494
- else:
495
- filtered_detection_result.append(result)
496
- return filtered_detection_result
497
-
498
- def possible_categories(self) -> List[ObjectTypes]:
499
- return list(self.categories.values())
475
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
476
+ return self.categories.get_categories(as_dict=False)
500
477
 
501
478
  @staticmethod
502
- def get_wrapped_model(path_weights: str) -> Any:
479
+ def get_wrapped_model(path_weights: PathLikeOrStr) -> torch.jit.ScriptModule:
503
480
  """
504
481
  Get the wrapped model. Useful if one do not want to build the wrapper but only needs the instantiated model.
505
482
 
@@ -510,3 +487,6 @@ class D2FrcnnTracingDetector(D2FrcnnDetectorMixin):
510
487
  buffer = io.BytesIO(file.read())
511
488
  # Load all tensors to the original device
512
489
  return torch.jit.load(buffer)
490
+
491
+ def clear_model(self) -> None:
492
+ self.d2_predictor = None # type: ignore
@@ -18,14 +18,13 @@
18
18
  """
19
19
  jdeskew estimator and rotator to deskew images: <https://github.com/phamquiluan/jdeskew>
20
20
  """
21
-
22
- from typing import List
21
+ from __future__ import annotations
23
22
 
24
23
  from lazy_imports import try_import
25
24
 
26
- from ..utils.detection_types import ImageType, Requirement
27
25
  from ..utils.file_utils import get_jdeskew_requirement
28
- from ..utils.settings import PageType
26
+ from ..utils.settings import ObjectTypes, PageType
27
+ from ..utils.types import PixelValues, Requirement
29
28
  from ..utils.viz import viz_handler
30
29
  from .base import DetectionResult, ImageTransformer
31
30
 
@@ -44,7 +43,7 @@ class Jdeskewer(ImageTransformer):
44
43
  self.model_id = self.get_model_id()
45
44
  self.min_angle_rotation = min_angle_rotation
46
45
 
47
- def transform(self, np_img: ImageType, specification: DetectionResult) -> ImageType:
46
+ def transform(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
48
47
  """
49
48
  Rotation of the image according to the angle determined by the jdeskew estimator.
50
49
 
@@ -61,7 +60,7 @@ class Jdeskewer(ImageTransformer):
61
60
  return viz_handler.rotate_image(np_img, specification.angle) # type: ignore
62
61
  return np_img
63
62
 
64
- def predict(self, np_img: ImageType) -> DetectionResult:
63
+ def predict(self, np_img: PixelValues) -> DetectionResult:
65
64
  """
66
65
  Predict the angle of the image to deskew it.
67
66
 
@@ -71,12 +70,14 @@ class Jdeskewer(ImageTransformer):
71
70
  return DetectionResult(angle=round(float(get_angle(np_img)), 4))
72
71
 
73
72
  @classmethod
74
- def get_requirements(cls) -> List[Requirement]:
73
+ def get_requirements(cls) -> list[Requirement]:
75
74
  """
76
75
  Get a list of requirements for running the detector
77
76
  """
78
77
  return [get_jdeskew_requirement()]
79
78
 
80
- @staticmethod
81
- def possible_category() -> PageType:
82
- return PageType.angle
79
+ def clone(self) -> Jdeskewer:
80
+ return self.__class__(self.min_angle_rotation)
81
+
82
+ def get_category_names(self) -> tuple[ObjectTypes, ...]:
83
+ return (PageType.ANGLE,)