deepdoctection-0.32-py3-none-any.whl → deepdoctection-0.34-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (111)
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
deepdoctection/pipe/language.py CHANGED
@@ -18,16 +18,14 @@
 """
 Module for language detection pipeline component
 """
-from copy import copy, deepcopy
 from typing import Optional, Sequence
 
 from ..datapoint.image import Image
 from ..datapoint.view import Page
 from ..extern.base import LanguageDetector, ObjectDetector
-from ..utils.detection_types import JsonDict
 from ..utils.error import ImageError
 from ..utils.settings import PageType, TypeOrStr, get_type
-from .base import PipelineComponent
+from .base import MetaAnnotation, PipelineComponent
 from .registry import pipeline_component_registry
 
 
@@ -74,26 +72,27 @@ class LanguageDetectionService(PipelineComponent):
         self.predictor = language_detector
         self.text_detector = text_detector
         self.text_container = get_type(text_container) if text_container is not None else text_container
-        if floating_text_block_categories:
-            floating_text_block_categories = [get_type(text_block) for text_block in floating_text_block_categories]
-        self.floating_text_block_categories = floating_text_block_categories if floating_text_block_categories else []
-        super().__init__(
-            self._get_name(self.predictor.name)
-        )  # cannot use PredictorPipelineComponent class because of return type of predict meth
+        self.floating_text_block_categories = (
+            tuple(get_type(text_block) for text_block in floating_text_block_categories)
+            if (floating_text_block_categories is not None)
+            else ()
+        )
+
+        super().__init__(self._get_name(self.predictor.name))
 
     def serve(self, dp: Image) -> None:
         if self.text_detector is None:
-            page = Page.from_image(dp, self.text_container, self.floating_text_block_categories)  # type: ignore
+            page = Page.from_image(dp, self.text_container, self.floating_text_block_categories)
            text = page.text_no_line_break
         else:
             if dp.image is None:
                 raise ImageError("image cannot be None")
             detect_result_list = self.text_detector.predict(dp.image)
             # this is a concatenation of all detection result. No reading order
-            text = " ".join([result.text for result in detect_result_list if result.text is not None])
+            text = " ".join((result.text for result in detect_result_list if result.text is not None))
         predict_result = self.predictor.predict(text)
         self.dp_manager.set_summary_annotation(
-            PageType.language, PageType.language, 1, predict_result.text, predict_result.score
+            PageType.LANGUAGE, PageType.LANGUAGE, 1, predict_result.text, predict_result.score
         )
 
     def clone(self) -> PipelineComponent:
@@ -101,22 +100,18 @@ class LanguageDetectionService(PipelineComponent):
         if not isinstance(predictor, LanguageDetector):
             raise TypeError(f"Predictor must be of type LanguageDetector, but is of type {type(predictor)}")
         return self.__class__(
-            predictor,
-            copy(self.text_container),
-            deepcopy(self.text_detector),
-            deepcopy(self.floating_text_block_categories),
+            language_detector=predictor,
+            text_container=self.text_container,
+            text_detector=self.text_detector.clone() if self.text_detector is not None else None,
+            floating_text_block_categories=self.floating_text_block_categories,
         )
 
-    def get_meta_annotation(self) -> JsonDict:
-        return dict(
-            [
-                ("image_annotations", []),
-                ("sub_categories", {}),
-                ("relationships", {}),
-                ("summaries", [PageType.language]),
-            ]
-        )
+    def get_meta_annotation(self) -> MetaAnnotation:
+        return MetaAnnotation(image_annotations=(), sub_categories={}, relationships={}, summaries=(PageType.LANGUAGE,))
 
     @staticmethod
     def _get_name(predictor_name: str) -> str:
         return f"language_detection_{predictor_name}"
+
+    def clear_predictor(self) -> None:
+        self.predictor.clear_model()
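A recurring change in this release: `get_meta_annotation` now returns a `MetaAnnotation` object imported from `.base` instead of a hand-built `JsonDict`, with tuples replacing lists. The real definition lives in `deepdoctection/pipe/base.py`; the following is only a rough sketch of what the call sites above imply, and the frozen-dataclass shape and defaults are assumptions, not the library's code:

    # Hypothetical reconstruction of MetaAnnotation, inferred from the call
    # sites in this diff; the actual class in deepdoctection/pipe/base.py
    # may differ.
    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class MetaAnnotation:
        # category names of ImageAnnotations the component may add
        image_annotations: tuple = ()
        # annotation category -> set of sub-category keys it may receive
        sub_categories: dict = field(default_factory=dict)
        # annotation category -> set of relationship keys
        relationships: dict = field(default_factory=dict)
        # summary sub-categories written at page level, e.g. (PageType.LANGUAGE,)
        summaries: tuple = ()

An immutable container with tuple defaults makes the declared meta information safe to share between cloned components, which fits the keyword-argument `clone` rewrite above.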
deepdoctection/pipe/layout.py CHANGED
@@ -26,15 +26,14 @@ import numpy as np
 
 from ..datapoint.image import Image
 from ..extern.base import ObjectDetector, PdfMiner
-from ..utils.detection_types import JsonDict
 from ..utils.error import ImageError
 from ..utils.transform import PadTransform
-from .base import PredictorPipelineComponent
+from .base import MetaAnnotation, PipelineComponent
 from .registry import pipeline_component_registry
 
 
 @pipeline_component_registry.register("ImageLayoutService")
-class ImageLayoutService(PredictorPipelineComponent):
+class ImageLayoutService(PipelineComponent):
     """
     Pipeline component for determining the layout. Which layout blocks are determined depends on the Detector and thus
     usually on the data set on which the Detector was pre-trained. If the Detector has been trained on Publaynet, these
@@ -65,6 +64,7 @@ class ImageLayoutService(PredictorPipelineComponent):
     :param crop_image: Do not only populate `ImageAnnotation.image` but also crop the detected block according
                        to its bounding box and populate the resulting sub image to
                        `ImageAnnotation.image.image`.
+    :param padder: If not `None`, will apply the padder to the image before prediction and inverse apply the padder
     :param skip_if_layout_extracted: When `True` will check, if there are already `ImageAnnotation` of a category
                                      available that will be predicted by the `layout_detector`. If yes, will skip
                                      the prediction process.
@@ -73,11 +73,12 @@ class ImageLayoutService(PredictorPipelineComponent):
         self.crop_image = crop_image
         self.padder = padder
         self.skip_if_layout_extracted = skip_if_layout_extracted
-        super().__init__(self._get_name(layout_detector.name), layout_detector)
+        self.predictor = layout_detector
+        super().__init__(self._get_name(layout_detector.name), self.predictor.model_id)
 
     def serve(self, dp: Image) -> None:
         if self.skip_if_layout_extracted:
-            categories = self.predictor.possible_categories()  # type: ignore
+            categories = self.predictor.get_category_names()
             anns = dp.get_annotation(category_names=categories)
             if anns:
                 return
@@ -86,7 +87,7 @@ class ImageLayoutService(PredictorPipelineComponent):
         np_image = dp.image
         if self.padder:
             np_image = self.padder.apply_image(np_image)
-        detect_result_list = self.predictor.predict(np_image)  # type: ignore
+        detect_result_list = self.predictor.predict(np_image)
         if self.padder and detect_result_list:
             boxes = np.array([detect_result.box for detect_result in detect_result_list])
             boxes_orig = self.padder.inverse_apply_coords(boxes)
@@ -96,22 +97,20 @@ class ImageLayoutService(PredictorPipelineComponent):
         for detect_result in detect_result_list:
             self.dp_manager.set_image_annotation(detect_result, to_image=self.to_image, crop_image=self.crop_image)
 
-    def get_meta_annotation(self) -> JsonDict:
-        assert isinstance(self.predictor, (ObjectDetector, PdfMiner))
-        return dict(
-            [
-                ("image_annotations", self.predictor.possible_categories()),
-                ("sub_categories", {}),
-                ("relationships", {}),
-                ("summaries", []),
-            ]
+    def get_meta_annotation(self) -> MetaAnnotation:
+        if not isinstance(self.predictor, (ObjectDetector, PdfMiner)):
+            raise TypeError(
+                f"self.predictor must be of type ObjectDetector or PdfMiner but is of type " f"{type(self.predictor)}"
+            )
+        return MetaAnnotation(
+            image_annotations=self.predictor.get_category_names(), sub_categories={}, relationships={}, summaries=()
         )
 
     @staticmethod
     def _get_name(predictor_name: str) -> str:
         return f"image_{predictor_name}"
 
-    def clone(self) -> PredictorPipelineComponent:
+    def clone(self) -> ImageLayoutService:
         predictor = self.predictor.clone()
         padder_clone = None
         if self.padder:
@@ -119,3 +118,6 @@ class ImageLayoutService(PredictorPipelineComponent):
         if not isinstance(predictor, ObjectDetector):
             raise TypeError(f"predictor must be of type ObjectDetector, but is of type {type(predictor)}")
         return self.__class__(predictor, self.to_image, self.crop_image, padder_clone, self.skip_if_layout_extracted)
+
+    def clear_predictor(self) -> None:
+        self.predictor.clear_model()
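The reworked `serve` above runs an optional pad → predict → un-pad round trip: the image is padded before prediction and the detected boxes are mapped back into the original coordinate system. A minimal self-contained sketch of that flow, assuming `PadTransform.apply_image` and `inverse_apply_coords` as they appear in the diff; the stand-in predictor and the mutable `box` attribute on its results are assumptions for illustration:

    # Sketch of the pad/predict/un-pad round trip in ImageLayoutService.serve().
    import numpy as np

    def predict_on_padded(predictor, padder, np_image):
        padded = padder.apply_image(np_image)        # pad to the detector's input size
        results = predictor.predict(padded)          # boxes are in padded coordinates
        if results:
            boxes = np.array([r.box for r in results])
            boxes_orig = padder.inverse_apply_coords(boxes)  # back to original coords
            for r, box in zip(results, boxes_orig):
                r.box = box.tolist()                 # assumption: box is assignable
        return results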
deepdoctection/pipe/lm.py CHANGED
@@ -21,19 +21,20 @@ Module for token classification pipeline
 from __future__ import annotations
 
 from copy import copy
-from typing import Any, Callable, List, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union
 
 from ..datapoint.image import Image
-from ..extern.hflayoutlm import HFLayoutLmSequenceClassifierBase, HFLayoutLmTokenClassifierBase
 from ..mapper.laylmstruct import image_to_layoutlm_features, image_to_lm_features
-from ..utils.detection_types import JsonDict
 from ..utils.settings import BioTag, LayoutType, ObjectTypes, PageType, TokenClasses, WordType
-from .base import LanguageModelPipelineComponent
+from .base import MetaAnnotation, PipelineComponent
 from .registry import pipeline_component_registry
 
+if TYPE_CHECKING:
+    from ..extern.hflayoutlm import HfLayoutSequenceModels, HfLayoutTokenModels
+
 
 @pipeline_component_registry.register("LMTokenClassifierService")
-class LMTokenClassifierService(LanguageModelPipelineComponent):
+class LMTokenClassifierService(PipelineComponent):
     """
     Pipeline component for token classification
 
@@ -65,7 +66,7 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
     def __init__(
         self,
         tokenizer: Any,
-        language_model: HFLayoutLmTokenClassifierBase,
+        language_model: HfLayoutTokenModels,
         padding: Literal["max_length", "do_not_pad", "longest"] = "max_length",
         truncation: bool = True,
         return_overflowing_tokens: bool = False,
@@ -109,15 +110,16 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
         self.segment_positions = segment_positions
         self.sliding_window_stride = sliding_window_stride
         if self.use_other_as_default_category:
-            categories_name_as_key = {val: key for key, val in self.language_model.categories.items()}
+            categories_name_as_key = {val: key for key, val in self.language_model.categories.categories.items()}
             self.default_key: ObjectTypes
-            if BioTag.outside in categories_name_as_key:
-                self.default_key = BioTag.outside
+            if BioTag.OUTSIDE in categories_name_as_key:
+                self.default_key = BioTag.OUTSIDE
             else:
-                self.default_key = TokenClasses.other
+                self.default_key = TokenClasses.OTHER
             self.other_name_as_key = {self.default_key: categories_name_as_key[self.default_key]}
-        image_to_features_func = self.image_to_features_func(self.language_model.image_to_features_mapping())
-        super().__init__(self._get_name(), tokenizer, image_to_features_func)
+        self.tokenizer = tokenizer
+        self.mapping_to_lm_input_func = self.image_to_features_func(self.language_model.image_to_features_mapping())
+        super().__init__(self._get_name(), self.language_model.model_id)
         self.required_kwargs = {
             "tokenizer": self.tokenizer,
             "padding": self.padding,
@@ -127,7 +129,7 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
             "segment_positions": self.segment_positions,
             "sliding_window_stride": self.sliding_window_stride,
         }
-        self.required_kwargs.update(self.language_model.default_kwargs_for_input_mapping())
+        self.required_kwargs.update(self.language_model.default_kwargs_for_image_to_features_mapping())
         self._init_sanity_checks()
 
     def serve(self, dp: Image) -> None:
@@ -145,7 +147,7 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
             and not token.token.startswith("##")
         ]
 
-        words_populated: List[str] = []
+        words_populated: list[str] = []
         for token in lm_output:
             if token.uuid not in words_populated:
                 if token.class_name == token.semantic_name:
@@ -153,31 +155,31 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
                 else:
                     token_class_name_id = None
                 self.dp_manager.set_category_annotation(
-                    token.semantic_name, token_class_name_id, WordType.token_class, token.uuid
+                    token.semantic_name, token_class_name_id, WordType.TOKEN_CLASS, token.uuid
                 )
-                self.dp_manager.set_category_annotation(token.bio_tag, None, WordType.tag, token.uuid)
+                self.dp_manager.set_category_annotation(token.bio_tag, None, WordType.TAG, token.uuid)
                 self.dp_manager.set_category_annotation(
-                    token.class_name, token.class_id, WordType.token_tag, token.uuid
+                    token.class_name, token.class_id, WordType.TOKEN_TAG, token.uuid
                 )
                 words_populated.append(token.uuid)
 
         if self.use_other_as_default_category:
-            word_anns = dp.get_annotation(LayoutType.word)
+            word_anns = dp.get_annotation(LayoutType.WORD)
             for word in word_anns:
-                if WordType.token_class not in word.sub_categories:
+                if WordType.TOKEN_CLASS not in word.sub_categories:
                     self.dp_manager.set_category_annotation(
-                        TokenClasses.other,
+                        TokenClasses.OTHER,
                         self.other_name_as_key[self.default_key],
-                        WordType.token_class,
+                        WordType.TOKEN_CLASS,
                         word.annotation_id,
                     )
-                if WordType.tag not in word.sub_categories:
-                    self.dp_manager.set_category_annotation(BioTag.outside, None, WordType.tag, word.annotation_id)
-                if WordType.token_tag not in word.sub_categories:
+                if WordType.TAG not in word.sub_categories:
+                    self.dp_manager.set_category_annotation(BioTag.OUTSIDE, None, WordType.TAG, word.annotation_id)
+                if WordType.TOKEN_TAG not in word.sub_categories:
                     self.dp_manager.set_category_annotation(
                         self.default_key,
                         self.other_name_as_key[self.default_key],
-                        WordType.token_tag,
+                        WordType.TOKEN_TAG,
                         word.annotation_id,
                     )
 
@@ -195,14 +197,12 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
             self.sliding_window_stride,
         )
 
-    def get_meta_annotation(self) -> JsonDict:
-        return dict(
-            [
-                ("image_annotations", []),
-                ("sub_categories", {LayoutType.word: {WordType.token_class, WordType.tag, WordType.token_tag}}),
-                ("relationships", {}),
-                ("summaries", []),
-            ]
+    def get_meta_annotation(self) -> MetaAnnotation:
+        return MetaAnnotation(
+            image_annotations=(),
+            sub_categories={LayoutType.WORD: {WordType.TOKEN_CLASS, WordType.TAG, WordType.TOKEN_TAG}},
+            relationships={},
+            summaries=(),
         )
 
     def _get_name(self) -> str:
@@ -223,9 +223,12 @@ class LMTokenClassifierService(LanguageModelPipelineComponent):
             mapping_str
         ]
 
+    def clear_predictor(self) -> None:
+        self.language_model.clear_model()
+
 
 @pipeline_component_registry.register("LMSequenceClassifierService")
-class LMSequenceClassifierService(LanguageModelPipelineComponent):
+class LMSequenceClassifierService(PipelineComponent):
     """
     Pipeline component for sequence classification
 
@@ -257,7 +260,7 @@ class LMSequenceClassifierService(LanguageModelPipelineComponent):
     def __init__(
         self,
         tokenizer: Any,
-        language_model: HFLayoutLmSequenceClassifierBase,
+        language_model: HfLayoutSequenceModels,
         padding: Literal["max_length", "do_not_pad", "longest"] = "max_length",
         truncation: bool = True,
         return_overflowing_tokens: bool = False,
@@ -281,8 +284,9 @@ class LMSequenceClassifierService(LanguageModelPipelineComponent):
         self.padding = padding
         self.truncation = truncation
         self.return_overflowing_tokens = return_overflowing_tokens
-        image_to_features_func = self.image_to_features_func(self.language_model.image_to_features_mapping())
-        super().__init__(self._get_name(), tokenizer, image_to_features_func)
+        self.tokenizer = tokenizer
+        self.mapping_to_lm_input_func = self.image_to_features_func(self.language_model.image_to_features_mapping())
+        super().__init__(self._get_name(), self.language_model.model_id)
         self.required_kwargs = {
             "tokenizer": self.tokenizer,
             "padding": self.padding,
@@ -290,7 +294,7 @@ class LMSequenceClassifierService(LanguageModelPipelineComponent):
             "return_overflowing_tokens": self.return_overflowing_tokens,
             "return_tensors": "pt",
         }
-        self.required_kwargs.update(self.language_model.default_kwargs_for_input_mapping())
+        self.required_kwargs.update(self.language_model.default_kwargs_for_image_to_features_mapping())
         self._init_sanity_checks()
 
     def serve(self, dp: Image) -> None:
@@ -299,7 +303,7 @@ class LMSequenceClassifierService(LanguageModelPipelineComponent):
             return
         lm_output = self.language_model.predict(**lm_input)
         self.dp_manager.set_summary_annotation(
-            PageType.document_type, lm_output.class_name, lm_output.class_id, None, lm_output.score
+            PageType.DOCUMENT_TYPE, lm_output.class_name, lm_output.class_id, None, lm_output.score
         )
 
     def clone(self) -> LMSequenceClassifierService:
@@ -311,14 +315,9 @@ class LMSequenceClassifierService(LanguageModelPipelineComponent):
             self.return_overflowing_tokens,
         )
 
-    def get_meta_annotation(self) -> JsonDict:
-        return dict(
-            [
-                ("image_annotations", []),
-                ("sub_categories", {}),
-                ("relationships", {}),
-                ("summaries", [PageType.document_type]),
-            ]
+    def get_meta_annotation(self) -> MetaAnnotation:
+        return MetaAnnotation(
+            image_annotations=(), sub_categories={}, relationships={}, summaries=(PageType.DOCUMENT_TYPE,)
         )
 
     def _get_name(self) -> str:
@@ -338,3 +337,6 @@ class LMSequenceClassifierService(LanguageModelPipelineComponent):
         return {"image_to_layoutlm_features": image_to_layoutlm_features, "image_to_lm_features": image_to_lm_features}[
             mapping_str
         ]
+
+    def clear_predictor(self) -> None:
+        self.language_model.clear_model()
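All three services in this section gain a `clear_predictor` method that delegates to the model's `clear_model`, so a pipeline can release model weights (e.g. GPU memory) without tearing down the component. A hedged usage sketch; `pipe_component_list` as the pipeline's container of components is an assumption here, not confirmed by this diff:

    # Hypothetical helper: free every component's model after a batch run.
    def clear_all_predictors(pipeline) -> None:
        for component in pipeline.pipe_component_list:
            component.clear_predictor()  # delegates to the model's clear_model()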