deepdoctection 0.44.0__py3-none-any.whl → 0.45.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of deepdoctection might be problematic.
- deepdoctection/__init__.py +6 -3
- deepdoctection/analyzer/config.py +41 -0
- deepdoctection/analyzer/factory.py +249 -1
- deepdoctection/configs/profiles.jsonl +2 -1
- deepdoctection/datapoint/image.py +1 -0
- deepdoctection/datapoint/view.py +162 -69
- deepdoctection/datasets/base.py +1 -0
- deepdoctection/extern/__init__.py +1 -0
- deepdoctection/extern/d2detect.py +1 -1
- deepdoctection/extern/fastlang.py +6 -4
- deepdoctection/extern/hflayoutlm.py +23 -10
- deepdoctection/extern/hflm.py +432 -7
- deepdoctection/mapper/laylmstruct.py +7 -7
- deepdoctection/pipe/language.py +4 -4
- deepdoctection/pipe/lm.py +7 -3
- deepdoctection/utils/file_utils.py +34 -0
- deepdoctection/utils/settings.py +2 -0
- deepdoctection/utils/types.py +0 -1
- deepdoctection/utils/viz.py +3 -3
- {deepdoctection-0.44.0.dist-info → deepdoctection-0.45.0.dist-info}/METADATA +15 -15
- {deepdoctection-0.44.0.dist-info → deepdoctection-0.45.0.dist-info}/RECORD +24 -24
- {deepdoctection-0.44.0.dist-info → deepdoctection-0.45.0.dist-info}/WHEEL +0 -0
- {deepdoctection-0.44.0.dist-info → deepdoctection-0.45.0.dist-info}/licenses/LICENSE +0 -0
- {deepdoctection-0.44.0.dist-info → deepdoctection-0.45.0.dist-info}/top_level.txt +0 -0
deepdoctection/datapoint/view.py
CHANGED
@@ -42,13 +42,60 @@ from ..utils.settings import (
     get_type,
 )
 from ..utils.transform import ResizeTransform, box_to_point4, point4_to_box
-from ..utils.types import HTML, AnnotationDict, Chunks, ImageDict, PathLikeOrStr, PixelValues,
+from ..utils.types import HTML, AnnotationDict, Chunks, ImageDict, PathLikeOrStr, PixelValues, csv
 from ..utils.viz import draw_boxes, interactive_imshow, viz_handler
 from .annotation import CategoryAnnotation, ContainerAnnotation, ImageAnnotation, ann_from_dict
 from .box import BoundingBox, crop_box_from_image
 from .image import Image


+@dataclass(frozen=True)
+class Text_:
+    """
+    Immutable dataclass for storing structured text extraction results.
+
+    Attributes:
+        text: The concatenated text string.
+        words: List of word strings.
+        ann_ids: List of annotation IDs for each word.
+        token_classes: List of token class names for each word.
+        token_class_ann_ids: List of annotation IDs for each token class.
+        token_tags: List of token tag names for each word.
+        token_tag_ann_ids: List of annotation IDs for each token tag.
+        token_class_ids: List of token class IDs.
+        token_tag_ids: List of token tag IDs.
+    """
+
+    text: str = ""
+    words: list[str] = field(default_factory=list)
+    ann_ids: list[str] = field(default_factory=list)
+    token_classes: list[str] = field(default_factory=list)
+    token_class_ann_ids: list[str] = field(default_factory=list)
+    token_tags: list[str] = field(default_factory=list)
+    token_tag_ann_ids: list[str] = field(default_factory=list)
+    token_class_ids: list[str] = field(default_factory=list)
+    token_tag_ids: list[str] = field(default_factory=list)
+
+    def as_dict(self) -> dict[str, Union[list[str], str]]:
+        """
+        Returns the Text_ as a dictionary.
+
+        Returns:
+            A dictionary representation of the Text_ dataclass.
+        """
+        return {
+            "text": self.text,
+            "words": self.words,
+            "ann_ids": self.ann_ids,
+            "token_classes": self.token_classes,
+            "token_class_ann_ids": self.token_class_ann_ids,
+            "token_tags": self.token_tags,
+            "token_tag_ann_ids": self.token_tag_ann_ids,
+            "token_class_ids": self.token_class_ids,
+            "token_tag_ids": self.token_tag_ids,
+        }
+
+
 class ImageAnnotationBaseView(ImageAnnotation):
     """
     Consumption class for having easier access to categories added to an `ImageAnnotation`.
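For orientation, a minimal sketch (not part of the diff) of how the new Text_ container can be consumed, using the public get_dd_analyzer/analyze API; the input file name is hypothetical:

    import deepdoctection as dd

    analyzer = dd.get_dd_analyzer()
    df = analyzer.analyze(path="sample.pdf")  # hypothetical input document
    df.reset_state()

    for page in df:
        text_ = page.text_          # now a Text_ instance with parallel per-word lists
        print(text_.text)           # concatenated page text
        print(text_.words[:5])      # first five words
        payload = text_.as_dict()   # plain dict, e.g. for JSON serialization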
@@ -263,41 +310,73 @@ class Layout(ImageAnnotationBaseView):
         """
         words = self.get_ordered_words()
         if words:
-            … (16 removed lines not shown)
+            (
+                characters,
+                ann_ids,
+                token_classes,
+                token_class_ann_ids,
+                token_tags,
+                token_tag_ann_ids,
+                token_classes_ids,
+                token_tag_ids,
+            ) = map(
+                list,
+                zip(
+                    *[
+                        (
+                            word.characters,
+                            word.annotation_id,
+                            word.token_class,
+                            word.get_sub_category(WordType.TOKEN_CLASS).annotation_id
+                            if WordType.TOKEN_CLASS in word.sub_categories
+                            else None,
+                            word.token_tag,
+                            word.get_sub_category(WordType.TOKEN_TAG).annotation_id
+                            if WordType.TOKEN_TAG in word.sub_categories
+                            else None,
+                            word.get_sub_category(WordType.TOKEN_CLASS).category_id
+                            if WordType.TOKEN_CLASS in word.sub_categories
+                            else None,
+                            word.get_sub_category(WordType.TOKEN_TAG).category_id
+                            if WordType.TOKEN_TAG in word.sub_categories
+                            else None,
+                        )
+                        for word in words
+                    ]
+                ),
             )
         else:
-            … (7 removed lines not shown)
+            (
+                characters,
+                ann_ids,
+                token_classes,
+                token_class_ann_ids,
+                token_tags,
+                token_tag_ann_ids,
+                token_classes_ids,
+                token_tag_ids,
+            ) = (
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
+                [],
             )
-            … (9 removed lines not shown)
+
+        return Text_(
+            text=" ".join(characters),  # type: ignore
+            words=characters,  # type: ignore
+            ann_ids=ann_ids,  # type: ignore
+            token_classes=token_classes,  # type: ignore
+            token_class_ann_ids=token_class_ann_ids,  # type: ignore
+            token_tags=token_tags,  # type: ignore
+            token_tag_ann_ids=token_tag_ann_ids,  # type: ignore
+            token_class_ids=token_classes_ids,  # type: ignore
+            token_tag_ids=token_tag_ids,  # type: ignore
+        )

     def get_attribute_names(self) -> set[str]:
         attr_names = (
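The map(list, zip(*...)) construction above transposes a list of per-word tuples into one list per field. A standalone sketch of the idiom:

    # Transpose per-word tuples into per-field lists.
    rows = [
        ("Invoice", "w-1", "header"),
        ("2025", "w-2", "date"),
    ]
    characters, ann_ids, token_classes = map(list, zip(*rows))
    assert characters == ["Invoice", "2025"]
    assert ann_ids == ["w-1", "w-2"]
    assert token_classes == ["header", "date"]

The separate else branch in the hunk is not cosmetic: unpacking zip(*[]) raises a ValueError instead of producing eight empty lists, so the empty-page case has to be handled explicitly.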
@@ -387,9 +466,9 @@ class Table(Layout):
             A list of a table cells.
         """
         cell_anns: list[Cell] = []
-        … (3 removed lines not shown)
+        if self.number_of_rows:
+            for row_number in range(1, self.number_of_rows + 1):  # type: ignore
+                cell_anns.extend(self.row(row_number))  # type: ignore
         return cell_anns

     @property

@@ -626,26 +705,33 @@ class Table(Layout):
         words: list[str] = []
         ann_ids: list[str] = []
         token_classes: list[str] = []
+        token_class_ann_ids: list[str] = []
         token_tags: list[str] = []
+        token_tag_ann_ids: list[str] = []
         token_class_ids: list[str] = []
         token_tag_ids: list[str] = []
         for cell in cells:
-            … (16 removed lines not shown)
+            text_ = cell.text_
+            text.append(text_.text)
+            words.extend(text_.words)
+            ann_ids.extend(text_.ann_ids)
+            token_classes.extend(text_.token_classes)
+            token_class_ann_ids.extend(text_.token_class_ann_ids)
+            token_tags.extend(text_.token_tags)
+            token_tag_ann_ids.extend(text_.token_tag_ann_ids)
+            token_class_ids.extend(text_.token_class_ids)
+            token_tag_ids.extend(text_.token_tag_ids)
+        return Text_(
+            text=" ".join(text),
+            words=words,
+            ann_ids=ann_ids,
+            token_classes=token_classes,
+            token_class_ann_ids=token_class_ann_ids,
+            token_tags=token_tags,
+            token_tag_ann_ids=token_tag_ann_ids,
+            token_class_ids=token_class_ids,
+            token_tag_ids=token_tag_ids,
+        )

     @property
     def words(self) -> list[ImageAnnotationBaseView]:
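Table.text_ here and Page.text_ further down repeat the same accumulate-and-rebuild loop over child Text_ objects. A hypothetical helper (not part of the package) condensing that pattern, assuming the Text_ dataclass introduced above:

    # assumes: from deepdoctection.datapoint.view import Text_  (added in 0.45.0)
    def merge_texts(parts: list[Text_]) -> Text_:
        """Concatenate several Text_ instances field by field."""
        return Text_(
            text=" ".join(p.text for p in parts),
            words=[w for p in parts for w in p.words],
            ann_ids=[a for p in parts for a in p.ann_ids],
            token_classes=[t for p in parts for t in p.token_classes],
            token_class_ann_ids=[a for p in parts for a in p.token_class_ann_ids],
            token_tags=[t for p in parts for t in p.token_tags],
            token_tag_ann_ids=[a for p in parts for a in p.token_tag_ann_ids],
            token_class_ids=[i for p in parts for i in p.token_class_ids],
            token_tag_ids=[i for p in parts for i in p.token_tag_ids],
        )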
@@ -1053,7 +1139,7 @@ class Page(Image):

         ```python
         {"text": text string,
-         "
+         "words": list of single words,
          "annotation_ids": word annotation ids}
         ```
         """

@@ -1062,26 +1148,33 @@ class Page(Image):
         words: list[str] = []
         ann_ids: list[str] = []
         token_classes: list[str] = []
+        token_class_ann_ids: list[str] = []
         token_tags: list[str] = []
+        token_tag_ann_ids: list[str] = []
         token_class_ids: list[str] = []
         token_tag_ids: list[str] = []
         for block in block_with_order:
-            … (16 removed lines not shown)
+            text_ = block.text_
+            text.append(text_.text)  # type: ignore
+            words.extend(text_.words)  # type: ignore
+            ann_ids.extend(text_.ann_ids)  # type: ignore
+            token_classes.extend(text_.token_classes)  # type: ignore
+            token_class_ann_ids.extend(text_.token_class_ann_ids)  # type: ignore
+            token_tags.extend(text_.token_tags)  # type: ignore
+            token_tag_ann_ids.extend(text_.token_tag_ann_ids)  # type: ignore
+            token_class_ids.extend(text_.token_class_ids)  # type: ignore
+            token_tag_ids.extend(text_.token_tag_ids)  # type: ignore
+        return Text_(
+            text=" ".join(text),
+            words=words,
+            ann_ids=ann_ids,
+            token_classes=token_classes,
+            token_class_ann_ids=token_class_ann_ids,
+            token_tags=token_tags,
+            token_tag_ann_ids=token_tag_ann_ids,
+            token_class_ids=token_class_ids,
+            token_tag_ids=token_tag_ann_ids,
+        )

     def get_layout_context(self, annotation_id: str, context_size: int = 3) -> list[ImageAnnotationBaseView]:
         """
deepdoctection/extern/d2detect.py
CHANGED
@@ -91,7 +91,7 @@ def d2_predict_image(
     """
     height, width = np_img.shape[:2]
     resized_img = resizer.get_transform(np_img).apply_image(np_img)
-    image = torch.as_tensor(resized_img.astype(
+    image = torch.as_tensor(resized_img.astype(np.float32).transpose(2, 0, 1))

     with torch.no_grad():
         inputs = {"image": image, "height": height, "width": width}
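A standalone sketch of what the completed line does: detectron2-style models consume float CHW tensors, while numpy/OpenCV images arrive as HWC uint8 arrays:

    import numpy as np
    import torch

    np_img = np.zeros((600, 800, 3), dtype=np.uint8)    # H x W x C
    chw = np_img.astype(np.float32).transpose(2, 0, 1)  # C x H x W, float32
    image = torch.as_tensor(chw)
    assert image.shape == (3, 600, 800)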
deepdoctection/extern/fastlang.py
CHANGED

@@ -29,13 +29,14 @@ from typing import Any, Mapping, Union

 from lazy_imports import try_import

-from ..utils.
+from ..utils.develop import deprecated
+from ..utils.file_utils import Requirement, get_fasttext_requirement, get_numpy_v1_requirement
 from ..utils.settings import TypeOrStr, get_type
 from ..utils.types import PathLikeOrStr
 from .base import DetectionResult, LanguageDetector, ModelCategories

 with try_import() as import_guard:
-    from fasttext import load_model  # type: ignore
+    from fasttext import load_model  # type: ignore # pylint: disable=E0401


 class FasttextLangDetectorMixin(LanguageDetector, ABC):

@@ -61,7 +62,7 @@ class FasttextLangDetectorMixin(LanguageDetector, ABC):
         Returns:
             `DetectionResult` filled with `text` and `score`
         """
-        return DetectionResult(
+        return DetectionResult(class_name=self.categories_orig[output[0][0]], score=output[1][0])

     @staticmethod
     def get_name(path_weights: PathLikeOrStr) -> str:
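The output[0][0]/output[1][0] indexing matches the (labels, probabilities) pair returned by fasttext's predict. A sketch, with a hypothetical model path:

    from fasttext import load_model

    model = load_model("lid.176.bin")  # hypothetical path to a language-ID model
    output = model.predict("Das ist ein deutscher Satz.")
    # output looks like (('__label__de',), array([0.99...])), hence:
    label = output[0][0]  # "__label__de", mapped via self.categories_orig
    score = output[1][0]  # confidence for that label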
@@ -69,6 +70,7 @@ class FasttextLangDetectorMixin(LanguageDetector, ABC):
         return "fasttext_" + "_".join(Path(path_weights).parts[-2:])


+@deprecated("As FastText archived, it will be deprecated in the near future.", "2025-08-17")
 class FasttextLangDetector(FasttextLangDetectorMixin):
     """
     Fasttext language detector wrapper. Two models provided in the fasttext library can be used to identify languages.
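The deprecated decorator is imported from deepdoctection.utils.develop; its implementation is not part of this diff. A rough guess at the behavior such a decorator typically has (warn on instantiation, keep the class usable):

    import warnings

    def deprecated(reason: str, deprecated_since: str):
        """Guess at the shape of deepdoctection.utils.develop.deprecated."""
        def wrap(cls):
            original_init = cls.__init__

            def patched_init(self, *args, **kwargs):
                # Emit a DeprecationWarning every time the class is instantiated.
                warnings.warn(
                    f"{cls.__name__} is deprecated since {deprecated_since}: {reason}",
                    DeprecationWarning,
                    stacklevel=2,
                )
                original_init(self, *args, **kwargs)

            cls.__init__ = patched_init
            return cls
        return wrap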
@@ -114,7 +116,7 @@ class FasttextLangDetector(FasttextLangDetectorMixin):

     @classmethod
     def get_requirements(cls) -> list[Requirement]:
-        return [get_fasttext_requirement()]
+        return [get_numpy_v1_requirement(), get_fasttext_requirement()]

     def clone(self) -> FasttextLangDetector:
         return self.__class__(self.path_weights, self.categories.get_categories(), self.categories_orig)
deepdoctection/extern/hflayoutlm.py
CHANGED

@@ -126,10 +126,13 @@ def get_tokenizer_from_model_class(model_class: str, use_xlm_tokenizer: bool) ->
         ("XLMRobertaForSequenceClassification", True): XLMRobertaTokenizerFast.from_pretrained(
             "FacebookAI/xlm-roberta-base"
         ),
+        ("XLMRobertaForTokenClassification", True): XLMRobertaTokenizerFast.from_pretrained(
+            "FacebookAI/xlm-roberta-base"
+        ),
     }[(model_class, use_xlm_tokenizer)]


-def
+def predict_token_classes_from_layoutlm(
     uuids: list[list[str]],
     input_ids: torch.Tensor,
     attention_mask: torch.Tensor,
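The added entry extends a dict-based dispatch keyed on the (model_class, use_xlm_tokenizer) pair; pairs without an entry raise a KeyError, so the new key presumably makes XLM-RoBERTa token classifiers resolvable. A reduced sketch of the pattern with illustrative values:

    def pick(model_class: str, use_xlm_tokenizer: bool) -> str:
        # Unsupported pairs raise KeyError on lookup.
        return {
            ("XLMRobertaForSequenceClassification", True): "xlm-roberta tokenizer",
            ("XLMRobertaForTokenClassification", True): "xlm-roberta tokenizer",
        }[(model_class, use_xlm_tokenizer)]

    assert pick("XLMRobertaForTokenClassification", True) == "xlm-roberta tokenizer"

Note that in the original, every from_pretrained call in the dict literal is evaluated before the lookup, since Python builds the whole dict first.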
@@ -192,7 +195,7 @@ def predict_token_classes(
     return all_token_classes


-def
+def predict_sequence_classes_from_layoutlm(
     input_ids: torch.Tensor,
     attention_mask: torch.Tensor,
     token_type_ids: torch.Tensor,

@@ -462,7 +465,7 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):

         ann_ids, _, input_ids, attention_mask, token_type_ids, boxes, tokens = self._validate_encodings(**encodings)

-        results =
+        results = predict_token_classes_from_layoutlm(
             ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, None
         )

@@ -586,7 +589,7 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
             images = images.to(self.device)
         else:
             raise ValueError(f"images must be list but is {type(images)}")
-        results =
+        results = predict_token_classes_from_layoutlm(
             ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, images
         )

@@ -710,7 +713,7 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
             images = images.to(self.device)
         else:
             raise ValueError(f"images must be list but is {type(images)}")
-        results =
+        results = predict_token_classes_from_layoutlm(
             ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, images
         )

@@ -909,7 +912,7 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
         """
         input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)

-        result =
+        result = predict_sequence_classes_from_layoutlm(
             input_ids,
             attention_mask,
             token_type_ids,

@@ -1021,7 +1024,12 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         else:
             raise ValueError(f"images must be list but is {type(images)}")

-        result =
+        result = predict_sequence_classes_from_layoutlm(input_ids,
+                                                        attention_mask,
+                                                        token_type_ids,
+                                                        boxes,
+                                                        self.model,
+                                                        images)

         result.class_id += 1
         result.class_name = self.categories.categories[result.class_id]

@@ -1115,7 +1123,12 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         else:
             raise ValueError(f"images must be list but is {type(images)}")

-        result =
+        result = predict_sequence_classes_from_layoutlm(input_ids,
+                                                        attention_mask,
+                                                        token_type_ids,
+                                                        boxes,
+                                                        self.model,
+                                                        images)

         result.class_id += 1
         result.class_name = self.categories.categories[result.class_id]

@@ -1245,7 +1258,7 @@ class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):

         ann_ids, _, input_ids, attention_mask, token_type_ids, boxes, tokens = self._validate_encodings(**encodings)

-        results =
+        results = predict_token_classes_from_layoutlm(
             ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, None
         )

@@ -1323,7 +1336,7 @@ class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
     def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
         input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)

-        result =
+        result = predict_sequence_classes_from_layoutlm(
             input_ids,
             attention_mask,
             token_type_ids,