deepdoctection 0.44.1__py3-none-any.whl → 0.46__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registries and is provided for informational purposes only.
Potentially problematic release: this version of deepdoctection might be problematic.
- deepdoctection/__init__.py +7 -3
- deepdoctection/analyzer/config.py +44 -0
- deepdoctection/analyzer/factory.py +264 -7
- deepdoctection/configs/profiles.jsonl +2 -1
- deepdoctection/dataflow/parallel_map.py +7 -1
- deepdoctection/datapoint/box.py +5 -5
- deepdoctection/datapoint/image.py +5 -5
- deepdoctection/datapoint/view.py +73 -52
- deepdoctection/eval/cocometric.py +1 -0
- deepdoctection/extern/__init__.py +1 -0
- deepdoctection/extern/base.py +8 -1
- deepdoctection/extern/d2detect.py +1 -1
- deepdoctection/extern/doctrocr.py +18 -2
- deepdoctection/extern/fastlang.py +2 -2
- deepdoctection/extern/hflayoutlm.py +17 -10
- deepdoctection/extern/hflm.py +432 -7
- deepdoctection/extern/tessocr.py +17 -1
- deepdoctection/pipe/language.py +4 -4
- deepdoctection/pipe/lm.py +7 -3
- deepdoctection/pipe/order.py +12 -6
- deepdoctection/pipe/refine.py +10 -1
- deepdoctection/pipe/text.py +6 -0
- deepdoctection/pipe/transform.py +3 -0
- deepdoctection/utils/file_utils.py +34 -5
- deepdoctection/utils/logger.py +38 -1
- deepdoctection/utils/settings.py +2 -0
- deepdoctection/utils/transform.py +43 -18
- deepdoctection/utils/viz.py +24 -15
- {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/METADATA +16 -21
- {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/RECORD +33 -33
- {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/WHEEL +0 -0
- {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/licenses/LICENSE +0 -0
- {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/top_level.txt +0 -0
deepdoctection/datapoint/view.py
CHANGED
@@ -319,29 +319,32 @@ class Layout(ImageAnnotationBaseView):
                token_tag_ann_ids,
                token_classes_ids,
                token_tag_ids,
-            ) = map(
-                [... removed lines not shown ...]
+            ) = map(
+                list,
+                zip(
+                    *[
+                        (
+                            word.characters,
+                            word.annotation_id,
+                            word.token_class,
+                            word.get_sub_category(WordType.TOKEN_CLASS).annotation_id
+                            if WordType.TOKEN_CLASS in word.sub_categories
+                            else None,
+                            word.token_tag,
+                            word.get_sub_category(WordType.TOKEN_TAG).annotation_id
+                            if WordType.TOKEN_TAG in word.sub_categories
+                            else None,
+                            word.get_sub_category(WordType.TOKEN_CLASS).category_id
+                            if WordType.TOKEN_CLASS in word.sub_categories
+                            else None,
+                            word.get_sub_category(WordType.TOKEN_TAG).category_id
+                            if WordType.TOKEN_TAG in word.sub_categories
+                            else None,
+                        )
+                        for word in words
+                    ]
+                ),
+            )
        else:
            (
                characters,

@@ -364,18 +367,17 @@ class Layout(ImageAnnotationBaseView):
        )

        return Text_(
-            text=" ".join(characters),
-            words=characters,
-            ann_ids=ann_ids,
-            token_classes=token_classes,
-            token_class_ann_ids=token_class_ann_ids,
-            token_tags=token_tags,
-            token_tag_ann_ids=token_tag_ann_ids,
-            token_class_ids=token_classes_ids,
-            token_tag_ids=token_tag_ids,
+            text=" ".join(characters),  # type: ignore
+            words=characters,  # type: ignore
+            ann_ids=ann_ids,  # type: ignore
+            token_classes=token_classes,  # type: ignore
+            token_class_ann_ids=token_class_ann_ids,  # type: ignore
+            token_tags=token_tags,  # type: ignore
+            token_tag_ann_ids=token_tag_ann_ids,  # type: ignore
+            token_class_ids=token_classes_ids,  # type: ignore
+            token_tag_ids=token_tag_ids,  # type: ignore
        )

-
    def get_attribute_names(self) -> set[str]:
        attr_names = (
            {"words", "text"}

@@ -426,6 +428,8 @@ class List(Layout):
        A list of words order by reading order. Words with no `reading_order` will not be returned"""
        try:
            list_items = self.list_items
+            if not list_items:
+                return super().get_ordered_words()
            all_words = []
            list_items.sort(key=lambda x: x.bbox[1])
            for list_item in list_items:

@@ -464,9 +468,9 @@ class Table(Layout):
        A list of a table cells.
        """
        cell_anns: list[Cell] = []
-        [... removed lines not shown ...]
+        if self.number_of_rows:
+            for row_number in range(1, self.number_of_rows + 1):  # type: ignore
+                cell_anns.extend(self.row(row_number))  # type: ignore
        return cell_anns

    @property

@@ -731,7 +735,6 @@ class Table(Layout):
            token_tag_ids=token_tag_ids,
        )

-
    @property
    def words(self) -> list[ImageAnnotationBaseView]:
        """

@@ -754,6 +757,8 @@ class Table(Layout):
        """
        try:
            cells = self.cells
+            if not cells:
+                return super().get_ordered_words()
            all_words = []
            cells.sort(key=lambda x: (x.ROW_NUMBER, x.COLUMN_NUMBER))
            for cell in cells:

@@ -1053,6 +1058,8 @@ class Page(Image):
        Returns:
            A `Page` instance with all annotations as `ImageAnnotationBaseView` subclasses.
        """
+        if isinstance(image_orig, Page):
+            raise ImageError("Page.from_image() cannot be called on a Page instance.")

        if text_container is None:
            text_container = IMAGE_DEFAULTS.TEXT_CONTAINER

@@ -1175,7 +1182,6 @@ class Page(Image):
            token_tag_ids=token_tag_ann_ids,
        )

-
    def get_layout_context(self, annotation_id: str, context_size: int = 3) -> list[ImageAnnotationBaseView]:
        """
        For a given `annotation_id` get a list of `ImageAnnotation` that are nearby in terms of `reading_order`.

@@ -1310,7 +1316,7 @@ class Page(Image):
        If `interactive=False` will return a `np.array`.
        """

-        category_names_list: list[Union[str, None]] = []
+        category_names_list: list[Tuple[Union[str, None], Union[str, None]]] = []
        box_stack = []
        cells_found = False

@@ -1323,22 +1329,23 @@ class Page(Image):
            anns = self.get_annotation(category_names=list(debug_kwargs.keys()))
            for ann in anns:
                box_stack.append(self._ann_viz_bbox(ann))
-                [... removed line not shown ...]
+                val = str(getattr(ann, debug_kwargs[ann.category_name]))
+                category_names_list.append((val, val))

        if show_layouts and not debug_kwargs:
            for item in self.layouts:
                box_stack.append(self._ann_viz_bbox(item))
-                category_names_list.append(item.category_name.value)
+                category_names_list.append((item.category_name.value, item.category_name.value))

        if show_figures and not debug_kwargs:
            for item in self.figures:
                box_stack.append(self._ann_viz_bbox(item))
-                category_names_list.append(item.category_name.value)
+                category_names_list.append((item.category_name.value, item.category_name.value))

        if show_tables and not debug_kwargs:
            for table in self.tables:
                box_stack.append(self._ann_viz_bbox(table))
-                category_names_list.append(LayoutType.TABLE.value)
+                category_names_list.append((LayoutType.TABLE.value, LayoutType.TABLE.value))
                if show_cells:
                    for cell in table.cells:
                        if cell.category_name in {

@@ -1347,21 +1354,21 @@ class Page(Image):
                        }:
                            cells_found = True
                            box_stack.append(self._ann_viz_bbox(cell))
-                            category_names_list.append(None)
+                            category_names_list.append((None, cell.category_name.value))
                if show_table_structure:
                    rows = table.rows
                    cols = table.columns
                    for row in rows:
                        box_stack.append(self._ann_viz_bbox(row))
-                        category_names_list.append(None)
+                        category_names_list.append((None, row.category_name.value))
                    for col in cols:
                        box_stack.append(self._ann_viz_bbox(col))
-                        category_names_list.append(None)
+                        category_names_list.append((None, col.category_name.value))

        if show_cells and not cells_found and not debug_kwargs:
            for ann in self.get_annotation(category_names=[LayoutType.CELL, CellType.SPANNING]):
                box_stack.append(self._ann_viz_bbox(ann))
-                category_names_list.append(None)
+                category_names_list.append((None, ann.category_name.value))

        if show_words and not debug_kwargs:
            all_words = []

@@ -1379,22 +1386,36 @@ class Page(Image):
                for word in all_words:
                    box_stack.append(self._ann_viz_bbox(word))
                    if show_token_class:
-                        category_names_list.append(
+                        category_names_list.append(
+                            (word.token_class.value, word.token_class.value)
+                            if word.token_class is not None
+                            else (None, None)
+                        )
                    else:
-                        category_names_list.append(
+                        category_names_list.append(
+                            (word.token_tag.value, word.token_tag.value) if word.token_tag is not None else (None, None)
+                        )
            else:
                for word in all_words:
                    if word.token_class is not None and word.token_class != TokenClasses.OTHER:
                        box_stack.append(self._ann_viz_bbox(word))
                        if show_token_class:
-                            category_names_list.append(
+                            category_names_list.append(
+                                (word.token_class.value, word.token_class.value)
+                                if word.token_class is not None
+                                else (None, None)
+                            )
                        else:
-                            category_names_list.append(
+                            category_names_list.append(
+                                (word.token_tag.value, word.token_tag.value)
+                                if word.token_tag is not None
+                                else (None, None)
+                            )

        if show_residual_layouts and not debug_kwargs:
            for item in self.residual_layouts:
                box_stack.append(item.bbox)
-                category_names_list.append(item.category_name.value)
+                category_names_list.append((item.category_name.value, item.category_name.value))

        if self.image is not None:
            scale_fx = scaled_width / self.width
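For readers unfamiliar with the map(list, zip(*[...])) idiom used in the first hunk above, here is a minimal standalone sketch with a simplified stand-in for the word view (not deepdoctection code; names are invented for illustration):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeWord:  # simplified stand-in for the package's word annotation view
    characters: str
    annotation_id: str
    token_class: Optional[str] = None

words = [FakeWord("Invoice", "a1", "header"), FakeWord("2024", "a2")]

# One tuple of attributes per word, transposed with zip(*...); map(list, ...)
# turns each resulting column into a list, exactly as in the refactored method.
characters, ann_ids, token_classes = map(
    list,
    zip(*[(w.characters, w.annotation_id, w.token_class) for w in words]),
)
assert characters == ["Invoice", "2024"]
assert token_classes == ["header", None]

Note that the real method keeps a separate else branch for the case of no words, since zip(*[]) would produce nothing to unpack.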
deepdoctection/eval/cocometric.py
CHANGED

@@ -275,6 +275,7 @@ class CocoMetric(MetricBase):
                get the ultimate F1-score.
            f1_iou: Use with `f1_score=True` and reset the f1 iou threshold
            per_category: Whether to calculate metrics per category
+            per_category: If set to True, f1 score will be returned by each category.
        """
        if max_detections is not None:
            assert len(max_detections) == 3, max_detections
deepdoctection/extern/base.py
CHANGED
@@ -263,7 +263,7 @@ class PredictorBase(ABC):
        requirements = cls.get_requirements()
        name = cls.__name__ if hasattr(cls, "__name__") else cls.__class__.__name__
        if not all(requirement[1] for requirement in requirements):
-            raise
+            raise ModuleNotFoundError(
                "\n".join(
                    [f"{name} has the following dependencies:"]
                    + [requirement[2] for requirement in requirements if not requirement[1]]

@@ -334,6 +334,11 @@ class DetectionResult:
        block: block number. For reading order from some ocr predictors
        line: line number. For reading order from some ocr predictors
        uuid: uuid. For assigning detection result (e.g. text to image annotations)
+        relationships: A dictionary of relationships. Each key is a relationship type and each value is a list of
+            uuids of the related annotations.
+        angle: angle of rotation in degrees. Only used for text detection.
+        image_width: image width
+        image_height: image height
    """

    box: Optional[list[float]] = None

@@ -348,6 +353,8 @@ class DetectionResult:
    uuid: Optional[str] = None
    relationships: Optional[dict[str, Any]] = None
    angle: Optional[float] = None
+    image_width: Optional[Union[int, float]] = None
+    image_height: Optional[Union[int, float]] = None


class ObjectDetector(PredictorBase, ABC):
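The two new fields record the size of the image a box was detected on, which the rotation transformer further below uses to map coordinates back. A construction sketch with invented values (the box format comment is an assumption, not taken from the package docs):

from deepdoctection.extern.base import DetectionResult

detection = DetectionResult(
    box=[12.0, 34.0, 180.0, 60.0],   # assumed absolute xyxy coordinates
    score=0.97,
    angle=90.0,                      # rotation that was applied before detection
    image_width=1654,                # dimensions of the (rotated) image the box refers to
    image_height=2339,
)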
deepdoctection/extern/d2detect.py
CHANGED

@@ -91,7 +91,7 @@ def d2_predict_image(
    """
    height, width = np_img.shape[:2]
    resized_img = resizer.get_transform(np_img).apply_image(np_img)
-    image = torch.as_tensor(resized_img.astype(
+    image = torch.as_tensor(resized_img.astype(np.float32).transpose(2, 0, 1))

    with torch.no_grad():
        inputs = {"image": image, "height": height, "width": width}
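The corrected line completes the usual NumPy-to-tensor preparation for Detectron2: cast to float32 and reorder the axes from HWC to CHW. A generic illustration of that conversion (plain NumPy/PyTorch, independent of deepdoctection):

import numpy as np
import torch

np_img = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # H x W x C uint8 image
image = torch.as_tensor(np_img.astype(np.float32).transpose(2, 0, 1))   # C x H x W float32 tensor
assert image.shape == (3, 480, 640) and image.dtype == torch.float32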
deepdoctection/extern/doctrocr.py
CHANGED

@@ -24,9 +24,10 @@ from __future__ import annotations
import os
from abc import ABC
from pathlib import Path
-from typing import Any, Literal, Mapping, Optional, Union
+from typing import Any, Literal, Mapping, Optional, Sequence, Union
from zipfile import ZipFile

+import numpy as np
from lazy_imports import try_import

from ..utils.env_info import ENV_VARS_TRUE

@@ -39,6 +40,7 @@ from ..utils.file_utils import (
)
from ..utils.fs import load_json
from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
+from ..utils.transform import RotationTransform
from ..utils.types import PathLikeOrStr, PixelValues, Requirement
from ..utils.viz import viz_handler
from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector, TextRecognizer

@@ -558,12 +560,13 @@ class DocTrRotationTransformer(ImageTransformer):
        """
        Args:
            number_contours: the number of contours used for the orientation estimation
-            ratio_threshold_for_lines: this is the ratio w/h used to
+            ratio_threshold_for_lines: this is the ratio w/h used to discriminate lines
        """
        self.number_contours = number_contours
        self.ratio_threshold_for_lines = ratio_threshold_for_lines
        self.name = "doctr_rotation_transformer"
        self.model_id = self.get_model_id()
+        self.rotator = RotationTransform(360)

    def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
        """

@@ -579,6 +582,19 @@ class DocTrRotationTransformer(ImageTransformer):
        """
        return viz_handler.rotate_image(np_img, specification.angle)  # type: ignore

+    def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+        if detect_results:
+            if detect_results[0].angle:
+                self.rotator.set_angle(detect_results[0].angle)  # type: ignore
+                self.rotator.set_image_width(detect_results[0].image_width)  # type: ignore
+                self.rotator.set_image_height(detect_results[0].image_height)  # type: ignore
+                transformed_coords = self.rotator.apply_coords(
+                    np.asarray([detect_result.box for detect_result in detect_results], dtype=float)
+                )
+                for idx, detect_result in enumerate(detect_results):
+                    detect_result.box = transformed_coords[idx, :].tolist()
+        return detect_results
+
    def predict(self, np_img: PixelValues) -> DetectionResult:
        angle = estimate_orientation(
            np_img, n_ct=self.number_contours, ratio_threshold_for_lines=self.ratio_threshold_for_lines
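The new transform_coords method maps boxes detected on the rotated (deskewed) image back into the original coordinate system, reading the angle and the rotated image's size from the first DetectionResult. A usage sketch under assumptions (constructor arguments and box values are invented; per the code above, results without an angle on the first element pass through unchanged):

from deepdoctection.extern.base import DetectionResult
from deepdoctection.extern.doctrocr import DocTrRotationTransformer

# Constructor arguments are illustrative guesses, not documented defaults.
transformer = DocTrRotationTransformer(number_contours=3, ratio_threshold_for_lines=5.0)

detect_results = [
    DetectionResult(box=[10.0, 20.0, 110.0, 60.0], angle=90.0, image_width=800, image_height=600),
    DetectionResult(box=[15.0, 80.0, 120.0, 120.0]),
]
detect_results = transformer.transform_coords(detect_results)  # boxes rotated back by the stored angle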
deepdoctection/extern/fastlang.py
CHANGED

@@ -36,7 +36,7 @@ from ..utils.types import PathLikeOrStr
from .base import DetectionResult, LanguageDetector, ModelCategories

with try_import() as import_guard:
-    from fasttext import load_model  # type: ignore
+    from fasttext import load_model  # type: ignore # pylint: disable=E0401


class FasttextLangDetectorMixin(LanguageDetector, ABC):

@@ -62,7 +62,7 @@ class FasttextLangDetectorMixin(LanguageDetector, ABC):
        Returns:
            `DetectionResult` filled with `text` and `score`
        """
-        return DetectionResult(
+        return DetectionResult(class_name=self.categories_orig[output[0][0]], score=output[1][0])

    @staticmethod
    def get_name(path_weights: PathLikeOrStr) -> str:
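For context on the output[0][0] / output[1][0] indexing in the rewritten return statement: fasttext's predict returns a pair of (labels, probabilities). A small sketch (the model path is hypothetical and requires the fasttext package plus a downloaded language-id model):

from fasttext import load_model

model = load_model("lid.176.bin")            # hypothetical path to a language-id model
output = model.predict("Das ist ein Satz.")  # e.g. (('__label__de',), array([0.99]))
label, score = output[0][0], output[1][0]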
deepdoctection/extern/hflayoutlm.py
CHANGED

@@ -126,10 +126,13 @@ def get_tokenizer_from_model_class(model_class: str, use_xlm_tokenizer: bool) ->
        ("XLMRobertaForSequenceClassification", True): XLMRobertaTokenizerFast.from_pretrained(
            "FacebookAI/xlm-roberta-base"
        ),
+        ("XLMRobertaForTokenClassification", True): XLMRobertaTokenizerFast.from_pretrained(
+            "FacebookAI/xlm-roberta-base"
+        ),
    }[(model_class, use_xlm_tokenizer)]


-def
+def predict_token_classes_from_layoutlm(
    uuids: list[list[str]],
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,

@@ -192,7 +195,7 @@ def predict_token_classes(
    return all_token_classes


-def
+def predict_sequence_classes_from_layoutlm(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: torch.Tensor,

@@ -462,7 +465,7 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):

        ann_ids, _, input_ids, attention_mask, token_type_ids, boxes, tokens = self._validate_encodings(**encodings)

-        results =
+        results = predict_token_classes_from_layoutlm(
            ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, None
        )

@@ -586,7 +589,7 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
            images = images.to(self.device)
        else:
            raise ValueError(f"images must be list but is {type(images)}")
-        results =
+        results = predict_token_classes_from_layoutlm(
            ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, images
        )

@@ -710,7 +713,7 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
            images = images.to(self.device)
        else:
            raise ValueError(f"images must be list but is {type(images)}")
-        results =
+        results = predict_token_classes_from_layoutlm(
            ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, images
        )

@@ -909,7 +912,7 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
        """
        input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)

-        result =
+        result = predict_sequence_classes_from_layoutlm(
            input_ids,
            attention_mask,
            token_type_ids,

@@ -1021,7 +1024,9 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
        else:
            raise ValueError(f"images must be list but is {type(images)}")

-        result =
+        result = predict_sequence_classes_from_layoutlm(
+            input_ids, attention_mask, token_type_ids, boxes, self.model, images
+        )

        result.class_id += 1
        result.class_name = self.categories.categories[result.class_id]

@@ -1115,7 +1120,9 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
        else:
            raise ValueError(f"images must be list but is {type(images)}")

-        result =
+        result = predict_sequence_classes_from_layoutlm(
+            input_ids, attention_mask, token_type_ids, boxes, self.model, images
+        )

        result.class_id += 1
        result.class_name = self.categories.categories[result.class_id]

@@ -1245,7 +1252,7 @@ class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):

        ann_ids, _, input_ids, attention_mask, token_type_ids, boxes, tokens = self._validate_encodings(**encodings)

-        results =
+        results = predict_token_classes_from_layoutlm(
            ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, None
        )

@@ -1323,7 +1330,7 @@ class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
        input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)

-        result =
+        result = predict_sequence_classes_from_layoutlm(
            input_ids,
            attention_mask,
            token_type_ids,
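The tokenizer lookup now also resolves XLM-R token classification models. A hedged call example (this triggers downloads of the mapped tokenizers from the Hugging Face Hub; an unmapped (model_class, use_xlm_tokenizer) pair raises a KeyError from the dictionary lookup):

from deepdoctection.extern.hflayoutlm import get_tokenizer_from_model_class

tokenizer = get_tokenizer_from_model_class("XLMRobertaForTokenClassification", use_xlm_tokenizer=True)
# Per the mapping above, this returns an XLMRobertaTokenizerFast loaded from "FacebookAI/xlm-roberta-base".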