deepdoctection-0.32-py3-none-any.whl → deepdoctection-0.34-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (111)
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -18,9 +18,11 @@
18
18
  """
19
19
  Module for cell detection pipeline component
20
20
  """
21
+ from __future__ import annotations
22
+
21
23
  from collections import Counter
22
- from copy import deepcopy
23
- from typing import Dict, List, Mapping, Optional, Sequence, Union
24
+ from types import MappingProxyType
25
+ from typing import Mapping, Optional, Sequence, Union
24
26
 
25
27
  import numpy as np
26
28
 
@@ -28,10 +30,10 @@ from ..datapoint.annotation import ImageAnnotation
28
30
  from ..datapoint.box import crop_box_from_image
29
31
  from ..datapoint.image import Image
30
32
  from ..extern.base import DetectionResult, ObjectDetector, PdfMiner
31
- from ..utils.detection_types import ImageType, JsonDict
32
- from ..utils.settings import ObjectTypes, Relationships
33
+ from ..utils.settings import ObjectTypes, Relationships, TypeOrStr, get_type
33
34
  from ..utils.transform import PadTransform
34
- from .base import PredictorPipelineComponent
35
+ from ..utils.types import PixelValues
36
+ from .base import MetaAnnotation, PipelineComponent
35
37
  from .registry import pipeline_component_registry
36
38
 
37
39
 
@@ -47,9 +49,9 @@ class DetectResultGenerator:
47
49
 
48
50
  def __init__(
49
51
  self,
50
- categories: Mapping[str, ObjectTypes],
51
- group_categories: Optional[List[List[str]]] = None,
52
- exclude_category_ids: Optional[Sequence[str]] = None,
52
+ categories: Mapping[int, ObjectTypes],
53
+ group_categories: Optional[list[list[int]]] = None,
54
+ exclude_category_ids: Optional[Sequence[int]] = None,
53
55
  absolute_coords: bool = True,
54
56
  ) -> None:
55
57
  """
@@ -59,7 +61,7 @@ class DetectResultGenerator:
59
61
  grouping category ids.
60
62
  :param absolute_coords: 'absolute_coords' value to be set in 'DetectionResult'
61
63
  """
62
- self.categories = categories
64
+ self.categories = MappingProxyType(dict(categories.items()))
63
65
  self.width: Optional[int] = None
64
66
  self.height: Optional[int] = None
65
67
  if group_categories is None:
@@ -71,7 +73,7 @@ class DetectResultGenerator:
71
73
  self.dummy_for_group_generated = [False for _ in self.group_categories]
72
74
  self.absolute_coords = absolute_coords
73
75
 
74
- def create_detection_result(self, detect_result_list: List[DetectionResult]) -> List[DetectionResult]:
76
+ def create_detection_result(self, detect_result_list: list[DetectionResult]) -> list[DetectionResult]:
75
77
  """
76
78
  Adds DetectResults for which no object was detected to the list.
77
79
 
@@ -100,8 +102,8 @@ class DetectResultGenerator:
100
102
  self.dummy_for_group_generated = self._initialize_dummy_for_group_generated()
101
103
  return detect_result_list
102
104
 
103
- def _create_condition(self, detect_result_list: List[DetectionResult]) -> Dict[str, int]:
104
- count = Counter([str(ann.class_id) for ann in detect_result_list])
105
+ def _create_condition(self, detect_result_list: list[DetectionResult]) -> dict[int, int]:
106
+ count = Counter([ann.class_id for ann in detect_result_list])
105
107
  cat_to_group_sum = {}
106
108
  for group in self.group_categories:
107
109
  group_sum = 0
@@ -111,7 +113,7 @@ class DetectResultGenerator:
111
113
  cat_to_group_sum[el] = group_sum
112
114
  return cat_to_group_sum
113
115
 
114
- def _dummy_for_group_generated(self, category_id: str) -> bool:
116
+ def _dummy_for_group_generated(self, category_id: int) -> bool:
115
117
  for idx, group in enumerate(self.group_categories):
116
118
  if category_id in group:
117
119
  is_generated = self.dummy_for_group_generated[idx]
@@ -119,12 +121,12 @@ class DetectResultGenerator:
119
121
  return is_generated
120
122
  return False
121
123
 
122
- def _initialize_dummy_for_group_generated(self) -> List[bool]:
124
+ def _initialize_dummy_for_group_generated(self) -> list[bool]:
123
125
  return [False for _ in self.group_categories]
124
126
 
125
127
 
126
128
  @pipeline_component_registry.register("SubImageLayoutService")
127
- class SubImageLayoutService(PredictorPipelineComponent):
129
+ class SubImageLayoutService(PipelineComponent):
128
130
  """
129
131
  Component in which the selected ImageAnnotation can be selected with cropped images and presented to a detector.
130
132
 
@@ -144,8 +146,8 @@ class SubImageLayoutService(PredictorPipelineComponent):
144
146
  def __init__(
145
147
  self,
146
148
  sub_image_detector: ObjectDetector,
147
- sub_image_names: Union[str, List[str]],
148
- category_id_mapping: Optional[Dict[int, int]] = None,
149
+ sub_image_names: Union[str, Sequence[TypeOrStr]],
150
+ category_id_mapping: Optional[dict[int, int]] = None,
149
151
  detect_result_generator: Optional[DetectResultGenerator] = None,
150
152
  padder: Optional[PadTransform] = None,
151
153
  ):
@@ -163,16 +165,23 @@ class SubImageLayoutService(PredictorPipelineComponent):
163
165
  inverse coordinate transformation.
164
166
  """
165
167
 
166
- if isinstance(sub_image_names, str):
167
- sub_image_names = [sub_image_names]
168
-
169
- self.sub_image_name = sub_image_names
168
+ self.sub_image_name = (
169
+ (get_type(sub_image_names),)
170
+ if isinstance(sub_image_names, str)
171
+ else tuple((get_type(cat) for cat in sub_image_names))
172
+ )
170
173
  self.category_id_mapping = category_id_mapping
171
174
  self.detect_result_generator = detect_result_generator
172
175
  self.padder = padder
173
- super().__init__(self._get_name(sub_image_detector.name), sub_image_detector)
176
+ self.predictor = sub_image_detector
177
+ super().__init__(self._get_name(sub_image_detector.name), self.predictor.model_id)
174
178
  if self.detect_result_generator is not None:
175
- assert self.detect_result_generator.categories == self.predictor.categories # type: ignore
179
+ if self.detect_result_generator.categories != self.predictor.categories.get_categories():
180
+ raise ValueError(
181
+ f"The categories of the 'detect_result_generator' must be the same as the categories of the "
182
+ f"'sub_image_detector'. Got {self.detect_result_generator.categories} #"
183
+ f"and {self.predictor.categories.get_categories()}."
184
+ )
176
185
 
177
186
  def serve(self, dp: Image) -> None:
178
187
  """
@@ -181,10 +190,10 @@ class SubImageLayoutService(PredictorPipelineComponent):
181
190
  - Optionally invoke the DetectResultGenerator
182
191
  - Generate ImageAnnotations and dump to parent image and sub image.
183
192
  """
184
- sub_image_anns = dp.get_annotation_iter(category_names=self.sub_image_name)
193
+ sub_image_anns = dp.get_annotation(category_names=self.sub_image_name)
185
194
  for sub_image_ann in sub_image_anns:
186
195
  np_image = self.prepare_np_image(sub_image_ann)
187
- detect_result_list = self.predictor.predict(np_image) # type: ignore
196
+ detect_result_list = self.predictor.predict(np_image)
188
197
  if self.padder and detect_result_list:
189
198
  boxes = np.array([detect_result.box for detect_result in detect_result_list])
190
199
  boxes_orig = self.padder.inverse_apply_coords(boxes)
@@ -203,23 +212,21 @@ class SubImageLayoutService(PredictorPipelineComponent):
203
212
  )
204
213
  self.dp_manager.set_image_annotation(detect_result, sub_image_ann.annotation_id)
205
214
 
206
- def get_meta_annotation(self) -> JsonDict:
207
- assert isinstance(self.predictor, (ObjectDetector, PdfMiner))
208
- return dict(
209
- [
210
- ("image_annotations", self.predictor.possible_categories()),
211
- ("sub_categories", {}),
212
- # implicit setup of relations by using set_image_annotation with explicit annotation_id
213
- ("relationships", {parent: {Relationships.child} for parent in self.sub_image_name}),
214
- ("summaries", []),
215
- ]
215
+ def get_meta_annotation(self) -> MetaAnnotation:
216
+ if not isinstance(self.predictor, (ObjectDetector, PdfMiner)):
217
+ raise ValueError(f"predictor must be of type ObjectDetector but is of type {type(self.predictor)}")
218
+ return MetaAnnotation(
219
+ image_annotations=self.predictor.get_category_names(),
220
+ sub_categories={},
221
+ relationships={get_type(parent): {Relationships.CHILD} for parent in self.sub_image_name},
222
+ summaries=(),
216
223
  )
217
224
 
218
225
  @staticmethod
219
226
  def _get_name(predictor_name: str) -> str:
220
227
  return f"sub_image_{predictor_name}"
221
228
 
222
- def clone(self) -> "PredictorPipelineComponent":
229
+ def clone(self) -> SubImageLayoutService:
223
230
  predictor = self.predictor.clone()
224
231
  padder_clone = None
225
232
  if self.padder:
@@ -228,13 +235,13 @@ class SubImageLayoutService(PredictorPipelineComponent):
228
235
  raise ValueError(f"predictor must be of type ObjectDetector but is of type {type(predictor)}")
229
236
  return self.__class__(
230
237
  predictor,
231
- deepcopy(self.sub_image_name),
232
- deepcopy(self.category_id_mapping),
233
- deepcopy(self.detect_result_generator),
238
+ self.sub_image_name,
239
+ self.category_id_mapping,
240
+ self.detect_result_generator,
234
241
  padder_clone,
235
242
  )
236
243
 
237
- def prepare_np_image(self, sub_image_ann: ImageAnnotation) -> ImageType:
244
+ def prepare_np_image(self, sub_image_ann: ImageAnnotation) -> PixelValues:
238
245
  """Maybe crop and pad a np_array before passing it to the predictor.
239
246
 
240
247
  Note that we currently assume to a two level hierachy of images, e.g. we can crop a sub-image from the base
@@ -256,3 +263,6 @@ class SubImageLayoutService(PredictorPipelineComponent):
256
263
  if self.padder:
257
264
  np_image = self.padder.apply_image(np_image)
258
265
  return np_image
266
+
267
+ def clear_predictor(self) -> None:
268
+ self.predictor.clear_model()
@@ -18,24 +18,27 @@
18
18
  """
19
19
  Module for text extraction pipeline component
20
20
  """
21
+
22
+ from __future__ import annotations
23
+
21
24
  from copy import deepcopy
22
- from typing import List, Optional, Sequence, Tuple, Union
25
+ from typing import Optional, Sequence, Union
23
26
 
24
27
  from ..datapoint.annotation import ImageAnnotation
25
28
  from ..datapoint.image import Image
26
29
  from ..extern.base import ObjectDetector, PdfMiner, TextRecognizer
27
30
  from ..extern.tessocr import TesseractOcrDetector
28
- from ..utils.detection_types import ImageType, JsonDict
29
31
  from ..utils.error import ImageError
30
- from ..utils.settings import PageType, TypeOrStr, WordType, get_type
31
- from .base import PredictorPipelineComponent
32
+ from ..utils.settings import ObjectTypes, PageType, TypeOrStr, WordType, get_type
33
+ from ..utils.types import PixelValues
34
+ from .base import MetaAnnotation, PipelineComponent
32
35
  from .registry import pipeline_component_registry
33
36
 
34
37
  __all__ = ["TextExtractionService"]
35
38
 
36
39
 
37
40
  @pipeline_component_registry.register("TextExtractionService")
38
- class TextExtractionService(PredictorPipelineComponent):
41
+ class TextExtractionService(PipelineComponent):
39
42
  """
40
43
  Pipeline component for extracting text. Any detector can be selected, provided that it can evaluate a
41
44
  numpy array as an image.
@@ -83,11 +86,13 @@ class TextExtractionService(PredictorPipelineComponent):
83
86
  if extract_from_roi is None:
84
87
  extract_from_roi = []
85
88
  self.extract_from_category = (
86
- [get_type(extract_from_roi)]
89
+ (get_type(extract_from_roi),)
87
90
  if isinstance(extract_from_roi, str)
88
- else [get_type(roi_category) for roi_category in extract_from_roi]
91
+ else tuple((get_type(roi_category) for roi_category in extract_from_roi))
89
92
  )
90
- super().__init__(self._get_name(text_extract_detector.name), text_extract_detector)
93
+
94
+ self.predictor = text_extract_detector
95
+ super().__init__(self._get_name(text_extract_detector.name), self.predictor.model_id)
91
96
  if self.extract_from_category:
92
97
  if not isinstance(self.predictor, (ObjectDetector, TextRecognizer)):
93
98
  raise TypeError(
@@ -95,9 +100,8 @@ class TextExtractionService(PredictorPipelineComponent):
95
100
  f"TextRecognizer. Got {type(self.predictor)}"
96
101
  )
97
102
  if run_time_ocr_language_selection:
98
- assert isinstance(
99
- self.predictor, TesseractOcrDetector
100
- ), "Only TesseractOcrDetector supports multiple languages"
103
+ if not isinstance(self.predictor, TesseractOcrDetector):
104
+ raise TypeError("Only TesseractOcrDetector supports multiple languages")
101
105
 
102
106
  self.run_time_ocr_language_selection = run_time_ocr_language_selection
103
107
  self.skip_if_text_extracted = skip_if_text_extracted
@@ -120,7 +124,7 @@ class TextExtractionService(PredictorPipelineComponent):
120
124
  else:
121
125
  width, height = None, None
122
126
  if self.run_time_ocr_language_selection:
123
- self.predictor.set_language(dp.summary.get_sub_category(PageType.language).value) # type: ignore
127
+ self.predictor.set_language(dp.summary.get_sub_category(PageType.LANGUAGE).value) # type: ignore
124
128
  detect_result_list = self.predictor.predict(predictor_input) # type: ignore
125
129
  if isinstance(self.predictor, PdfMiner):
126
130
  width, height = self.predictor.get_width_height(predictor_input) # type: ignore
@@ -134,15 +138,15 @@ class TextExtractionService(PredictorPipelineComponent):
134
138
  )
135
139
  if detect_ann_id is not None:
136
140
  self.dp_manager.set_container_annotation(
137
- WordType.characters,
141
+ WordType.CHARACTERS,
138
142
  None,
139
- WordType.characters,
143
+ WordType.CHARACTERS,
140
144
  detect_ann_id,
141
145
  detect_result.text if detect_result.text is not None else "",
142
146
  detect_result.score,
143
147
  )
144
148
 
145
- def get_text_rois(self, dp: Image) -> Sequence[Union[Image, ImageAnnotation, List[ImageAnnotation]]]:
149
+ def get_text_rois(self, dp: Image) -> Sequence[Union[Image, ImageAnnotation, list[ImageAnnotation]]]:
146
150
  """
147
151
  Return image rois based on selected categories. As this selection makes only sense for specific text extractors
148
152
  (e.g. those who do proper OCR and do not mine from text from native pdfs) it will do some sanity checks.
@@ -151,7 +155,7 @@ class TextExtractionService(PredictorPipelineComponent):
151
155
  :return: list of ImageAnnotation or Image
152
156
  """
153
157
  if self.skip_if_text_extracted:
154
- text_categories = self.predictor.possible_categories() # type: ignore
158
+ text_categories = self.predictor.get_category_names()
155
159
  text_anns = dp.get_annotation(category_names=text_categories)
156
160
  if text_anns:
157
161
  return []
@@ -163,8 +167,8 @@ class TextExtractionService(PredictorPipelineComponent):
163
167
  return [dp]
164
168
 
165
169
  def get_predictor_input(
166
- self, text_roi: Union[Image, ImageAnnotation, List[ImageAnnotation]]
167
- ) -> Optional[Union[bytes, ImageType, List[Tuple[str, ImageType]], int]]:
170
+ self, text_roi: Union[Image, ImageAnnotation, list[ImageAnnotation]]
171
+ ) -> Optional[Union[bytes, PixelValues, list[tuple[str, PixelValues]], int]]:
168
172
  """
169
173
  Return raw input for a given `text_roi`. This can be a numpy array or pdf bytes and depends on the chosen
170
174
  predictor.
@@ -191,38 +195,35 @@ class TextExtractionService(PredictorPipelineComponent):
191
195
  return text_roi.pdf_bytes
192
196
  return 1
193
197
 
194
- def get_meta_annotation(self) -> JsonDict:
198
+ def get_meta_annotation(self) -> MetaAnnotation:
199
+ sub_cat_dict: dict[ObjectTypes, set[ObjectTypes]]
195
200
  if self.extract_from_category:
196
- sub_cat_dict = {category: {WordType.characters} for category in self.extract_from_category}
201
+ sub_cat_dict = {category: {WordType.CHARACTERS} for category in self.extract_from_category}
197
202
  else:
198
203
  if not isinstance(self.predictor, (ObjectDetector, PdfMiner)):
199
204
  raise TypeError(
200
205
  f"self.predictor must be of type ObjectDetector or PdfMiner but is of type "
201
206
  f"{type(self.predictor)}"
202
207
  )
203
- sub_cat_dict = {category: {WordType.characters} for category in self.predictor.possible_categories()}
204
- return dict(
205
- [
206
- (
207
- "image_annotations",
208
- (
209
- self.predictor.possible_categories()
210
- if isinstance(self.predictor, (ObjectDetector, PdfMiner))
211
- else []
212
- ),
213
- ),
214
- ("sub_categories", sub_cat_dict),
215
- ("relationships", {}),
216
- ("summaries", []),
217
- ]
208
+ sub_cat_dict = {category: {WordType.CHARACTERS} for category in self.predictor.get_category_names()}
209
+ return MetaAnnotation(
210
+ image_annotations=self.predictor.get_category_names()
211
+ if isinstance(self.predictor, (ObjectDetector, PdfMiner))
212
+ else (),
213
+ sub_categories=sub_cat_dict,
214
+ relationships={},
215
+ summaries=(),
218
216
  )
219
217
 
220
218
  @staticmethod
221
219
  def _get_name(text_detector_name: str) -> str:
222
220
  return f"text_extract_{text_detector_name}"
223
221
 
224
- def clone(self) -> "PredictorPipelineComponent":
222
+ def clone(self) -> TextExtractionService:
225
223
  predictor = self.predictor.clone()
226
224
  if not isinstance(predictor, (ObjectDetector, PdfMiner, TextRecognizer)):
227
225
  raise ImageError(f"predictor must be of type ObjectDetector or PdfMiner, but is of type {type(predictor)}")
228
226
  return self.__class__(predictor, deepcopy(self.extract_from_category), self.run_time_ocr_language_selection)
227
+
228
+ def clear_predictor(self) -> None:
229
+ self.predictor.clear_model()
@@ -20,15 +20,16 @@ Module for transform style pipeline components. These pipeline components are us
20
20
  on images (e.g. deskew, de-noising or more general GAN like operations.
21
21
  """
22
22
 
23
+ from __future__ import annotations
24
+
23
25
  from ..datapoint.image import Image
24
26
  from ..extern.base import ImageTransformer
25
- from ..utils.detection_types import JsonDict
26
- from .base import ImageTransformPipelineComponent
27
+ from .base import MetaAnnotation, PipelineComponent
27
28
  from .registry import pipeline_component_registry
28
29
 
29
30
 
30
31
  @pipeline_component_registry.register("SimpleTransformService")
31
- class SimpleTransformService(ImageTransformPipelineComponent):
32
+ class SimpleTransformService(PipelineComponent):
32
33
  """
33
34
  Pipeline component for transforming an image. The service is designed for applying transform predictors that
34
35
  take an image as numpy array as input and return the same. The service itself will change the underlying metadata
@@ -44,7 +45,8 @@ class SimpleTransformService(ImageTransformPipelineComponent):
44
45
 
45
46
  :param transform_predictor: image transformer
46
47
  """
47
- super().__init__(self._get_name(transform_predictor.name), transform_predictor)
48
+ self.transform_predictor = transform_predictor
49
+ super().__init__(self._get_name(transform_predictor.name), self.transform_predictor.model_id)
48
50
 
49
51
  def serve(self, dp: Image) -> None:
50
52
  if dp.annotations:
@@ -60,26 +62,27 @@ class SimpleTransformService(ImageTransformPipelineComponent):
60
62
  self.dp_manager.datapoint.clear_image(True)
61
63
  self.dp_manager.datapoint.image = transformed_image
62
64
  self.dp_manager.set_summary_annotation(
63
- summary_key=self.transform_predictor.possible_category(),
64
- summary_name=self.transform_predictor.possible_category(),
65
+ summary_key=self.transform_predictor.get_category_names()[0],
66
+ summary_name=self.transform_predictor.get_category_names()[0],
65
67
  summary_number=None,
66
- summary_value=getattr(detection_result, self.transform_predictor.possible_category().value, None),
68
+ summary_value=getattr(detection_result, self.transform_predictor.get_category_names()[0].value, None),
67
69
  summary_score=detection_result.score,
68
70
  )
69
71
 
70
- def clone(self) -> "SimpleTransformService":
72
+ def clone(self) -> SimpleTransformService:
71
73
  return self.__class__(self.transform_predictor)
72
74
 
73
- def get_meta_annotation(self) -> JsonDict:
74
- return dict(
75
- [
76
- ("image_annotations", []),
77
- ("sub_categories", {}),
78
- ("relationships", {}),
79
- ("summaries", [self.transform_predictor.possible_category()]),
80
- ]
75
+ def get_meta_annotation(self) -> MetaAnnotation:
76
+ return MetaAnnotation(
77
+ image_annotations=(),
78
+ sub_categories={},
79
+ relationships={},
80
+ summaries=self.transform_predictor.get_category_names(),
81
81
  )
82
82
 
83
83
  @staticmethod
84
84
  def _get_name(transform_name: str) -> str:
85
85
  return f"simple_transform_{transform_name}"
86
+
87
+ def clear_predictor(self) -> None:
88
+ pass
@@ -21,7 +21,9 @@ Module for training Detectron2 `GeneralizedRCNN`
21
21
  from __future__ import annotations
22
22
 
23
23
  import copy
24
- from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
24
+ import os
25
+ from pathlib import Path
26
+ from typing import Any, Mapping, Optional, Sequence, Type, Union
25
27
 
26
28
  from lazy_imports import try_import
27
29
 
@@ -33,11 +35,12 @@ from ..eval.eval import Evaluator
33
35
  from ..eval.registry import metric_registry
34
36
  from ..extern.d2detect import D2FrcnnDetector
35
37
  from ..mapper.d2struct import image_to_d2_frcnn_training
36
- from ..pipe.base import PredictorPipelineComponent
38
+ from ..pipe.base import PipelineComponent
37
39
  from ..pipe.registry import pipeline_component_registry
38
40
  from ..utils.error import DependencyError
39
41
  from ..utils.file_utils import get_wandb_requirement, wandb_available
40
42
  from ..utils.logger import LoggingRecord, logger
43
+ from ..utils.types import PathLikeOrStr
41
44
  from ..utils.utils import string_to_dict
42
45
 
43
46
  with try_import() as d2_import_guard:
@@ -58,8 +61,8 @@ with try_import() as wb_import_guard:
58
61
 
59
62
 
60
63
  def _set_config(
61
- path_config_yaml: str,
62
- conf_list: List[str],
64
+ path_config_yaml: PathLikeOrStr,
65
+ conf_list: list[str],
63
66
  dataset_train: DatasetBase,
64
67
  dataset_val: Optional[DatasetBase],
65
68
  metric_name: Optional[str],
@@ -74,7 +77,7 @@ def _set_config(
74
77
  cfg.WANDB.USE_WANDB = False
75
78
  cfg.WANDB.PROJECT = None
76
79
  cfg.WANDB.REPO = "deepdoctection"
77
- cfg.merge_from_file(path_config_yaml)
80
+ cfg.merge_from_file(path_config_yaml.as_posix() if isinstance(path_config_yaml, Path) else path_config_yaml)
78
81
  cfg.merge_from_list(conf_list)
79
82
 
80
83
  cfg.TEST.DO_EVAL = (
@@ -89,7 +92,7 @@ def _set_config(
89
92
  return cfg
90
93
 
91
94
 
92
- def _update_for_eval(config_overwrite: List[str]) -> List[str]:
95
+ def _update_for_eval(config_overwrite: list[str]) -> list[str]:
93
96
  ret = [item for item in config_overwrite if not "WANDB" in item]
94
97
  return ret
95
98
 
@@ -103,7 +106,7 @@ class WandbWriter(EventWriter):
103
106
  self,
104
107
  project: str,
105
108
  repo: str,
106
- config: Optional[Union[Dict[str, Any], CfgNode]] = None,
109
+ config: Optional[Union[dict[str, Any], CfgNode]] = None,
107
110
  window_size: int = 20,
108
111
  **kwargs: Any,
109
112
  ):
@@ -145,7 +148,7 @@ class D2Trainer(DefaultTrainer):
145
148
  self.build_val_dict: Mapping[str, str] = {}
146
149
  super().__init__(cfg)
147
150
 
148
- def build_hooks(self) -> List[HookBase]:
151
+ def build_hooks(self) -> list[HookBase]:
149
152
  """
150
153
  Overwritten from DefaultTrainer. This ensures that the EvalHook is being called before the writer and
151
154
  all metrics are being written to JSON, Tensorboard etc.
@@ -197,7 +200,7 @@ class D2Trainer(DefaultTrainer):
197
200
 
198
201
  return ret
199
202
 
200
- def build_writers(self) -> List[EventWriter]:
203
+ def build_writers(self) -> list[EventWriter]:
201
204
  """
202
205
  Build a list of writers to be using `default_writers()`.
203
206
  If you'd like a different list of writers, you can overwrite it in
@@ -226,7 +229,7 @@ class D2Trainer(DefaultTrainer):
226
229
  dataset=self.dataset, mapper=self.mapper, total_batch_size=cfg.SOLVER.IMS_PER_BATCH
227
230
  )
228
231
 
229
- def eval_with_dd_evaluator(self, **build_eval_kwargs: str) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
232
+ def eval_with_dd_evaluator(self, **build_eval_kwargs: str) -> Union[list[dict[str, Any]], dict[str, Any]]:
230
233
  """
231
234
  Running the Evaluator. This method will be called from the `EvalHook`
232
235
 
@@ -243,7 +246,7 @@ class D2Trainer(DefaultTrainer):
243
246
  def setup_evaluator(
244
247
  self,
245
248
  dataset_val: DatasetBase,
246
- pipeline_component: PredictorPipelineComponent,
249
+ pipeline_component: PipelineComponent,
247
250
  metric: Union[Type[MetricBase], MetricBase],
248
251
  build_val_dict: Optional[Mapping[str, str]] = None,
249
252
  ) -> None:
@@ -271,9 +274,7 @@ class D2Trainer(DefaultTrainer):
271
274
  self.build_val_dict = build_val_dict
272
275
  assert self.evaluator.pipe_component
273
276
  for comp in self.evaluator.pipe_component.pipe_components:
274
- assert isinstance(comp, PredictorPipelineComponent)
275
- assert isinstance(comp.predictor, D2FrcnnDetector)
276
- comp.predictor.d2_predictor = None
277
+ comp.clear_predictor()
277
278
 
278
279
  @classmethod
279
280
  def build_evaluator(cls, cfg, dataset_name): # type: ignore
@@ -281,11 +282,11 @@ class D2Trainer(DefaultTrainer):
281
282
 
282
283
 
283
284
  def train_d2_faster_rcnn(
284
- path_config_yaml: str,
285
+ path_config_yaml: PathLikeOrStr,
285
286
  dataset_train: Union[str, DatasetBase],
286
- path_weights: str,
287
- config_overwrite: Optional[List[str]] = None,
288
- log_dir: str = "train_log/frcnn",
287
+ path_weights: PathLikeOrStr,
288
+ config_overwrite: Optional[list[str]] = None,
289
+ log_dir: PathLikeOrStr = "train_log/frcnn",
289
290
  build_train_config: Optional[Sequence[str]] = None,
290
291
  dataset_val: Optional[DatasetBase] = None,
291
292
  build_val_config: Optional[Sequence[str]] = None,
@@ -342,13 +343,13 @@ def train_d2_faster_rcnn(
342
343
 
343
344
  assert cuda.device_count() > 0, "Has to train with GPU!"
344
345
 
345
- build_train_dict: Dict[str, str] = {}
346
+ build_train_dict: dict[str, str] = {}
346
347
  if build_train_config is not None:
347
348
  build_train_dict = string_to_dict(",".join(build_train_config))
348
349
  if "split" not in build_train_dict:
349
350
  build_train_dict["split"] = "train"
350
351
 
351
- build_val_dict: Dict[str, str] = {}
352
+ build_val_dict: dict[str, str] = {}
352
353
  if build_val_config is not None:
353
354
  build_val_dict = string_to_dict(",".join(build_val_config))
354
355
  if "split" not in build_val_dict:
@@ -358,9 +359,9 @@ def train_d2_faster_rcnn(
358
359
  config_overwrite = []
359
360
  conf_list = [
360
361
  "MODEL.WEIGHTS",
361
- path_weights,
362
+ os.fspath(path_weights),
362
363
  "OUTPUT_DIR",
363
- log_dir,
364
+ os.fspath(log_dir),
364
365
  ]
365
366
  for conf in config_overwrite:
366
367
  key, val = conf.split("=", maxsplit=1)
@@ -376,11 +377,13 @@ def train_d2_faster_rcnn(
376
377
  if metric_name is not None:
377
378
  metric = metric_registry.get(metric_name)
378
379
 
379
- dataset = DatasetAdapter(dataset_train, True, image_to_d2_frcnn_training(False), True, **build_train_dict)
380
+ dataset = DatasetAdapter(
381
+ dataset_train, True, image_to_d2_frcnn_training(False), True, number_repetitions=-1, **build_train_dict
382
+ )
380
383
  augment_list = [ResizeShortestEdge(cfg.INPUT.MIN_SIZE_TRAIN, cfg.INPUT.MAX_SIZE_TRAIN), RandomFlip()]
381
384
  mapper = DatasetMapper(is_train=True, augmentations=augment_list, image_format="BGR")
382
385
 
383
- logger.info(LoggingRecord(f"Config: \n {str(cfg)}", cfg.to_dict()))
386
+ logger.info(LoggingRecord(f"Config: \n {str(cfg)}", dict(cfg)))
384
387
 
385
388
  trainer = D2Trainer(cfg, dataset, mapper)
386
389
  trainer.resume_or_load()
@@ -391,7 +394,6 @@ def train_d2_faster_rcnn(
391
394
  detector = D2FrcnnDetector(path_config_yaml, path_weights, categories, config_overwrite, cfg.MODEL.DEVICE)
392
395
  pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
393
396
  pipeline_component = pipeline_component_cls(detector)
394
- assert isinstance(pipeline_component, PredictorPipelineComponent)
395
397
 
396
398
  if metric_name is not None:
397
399
  metric = metric_registry.get(metric_name)