deepdoctection 0.44.1__py3-none-any.whl → 0.46__py3-none-any.whl

This diff reflects the contents of the two publicly released package versions as they appear in the public registry.


Files changed (33)
  1. deepdoctection/__init__.py +7 -3
  2. deepdoctection/analyzer/config.py +44 -0
  3. deepdoctection/analyzer/factory.py +264 -7
  4. deepdoctection/configs/profiles.jsonl +2 -1
  5. deepdoctection/dataflow/parallel_map.py +7 -1
  6. deepdoctection/datapoint/box.py +5 -5
  7. deepdoctection/datapoint/image.py +5 -5
  8. deepdoctection/datapoint/view.py +73 -52
  9. deepdoctection/eval/cocometric.py +1 -0
  10. deepdoctection/extern/__init__.py +1 -0
  11. deepdoctection/extern/base.py +8 -1
  12. deepdoctection/extern/d2detect.py +1 -1
  13. deepdoctection/extern/doctrocr.py +18 -2
  14. deepdoctection/extern/fastlang.py +2 -2
  15. deepdoctection/extern/hflayoutlm.py +17 -10
  16. deepdoctection/extern/hflm.py +432 -7
  17. deepdoctection/extern/tessocr.py +17 -1
  18. deepdoctection/pipe/language.py +4 -4
  19. deepdoctection/pipe/lm.py +7 -3
  20. deepdoctection/pipe/order.py +12 -6
  21. deepdoctection/pipe/refine.py +10 -1
  22. deepdoctection/pipe/text.py +6 -0
  23. deepdoctection/pipe/transform.py +3 -0
  24. deepdoctection/utils/file_utils.py +34 -5
  25. deepdoctection/utils/logger.py +38 -1
  26. deepdoctection/utils/settings.py +2 -0
  27. deepdoctection/utils/transform.py +43 -18
  28. deepdoctection/utils/viz.py +24 -15
  29. {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/METADATA +16 -21
  30. {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/RECORD +33 -33
  31. {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/WHEEL +0 -0
  32. {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/licenses/LICENSE +0 -0
  33. {deepdoctection-0.44.1.dist-info → deepdoctection-0.46.dist-info}/top_level.txt +0 -0
deepdoctection/datapoint/view.py

@@ -319,29 +319,32 @@ class Layout(ImageAnnotationBaseView):
  token_tag_ann_ids,
  token_classes_ids,
  token_tag_ids,
- ) = map(list, zip(
- *[
- (
- word.characters,
- word.annotation_id,
- word.token_class,
- word.get_sub_category(WordType.TOKEN_CLASS).annotation_id
- if WordType.TOKEN_CLASS in word.sub_categories
- else None,
- word.token_tag,
- word.get_sub_category(WordType.TOKEN_TAG).annotation_id
- if WordType.TOKEN_TAG in word.sub_categories
- else None,
- word.get_sub_category(WordType.TOKEN_CLASS).category_id
- if WordType.TOKEN_CLASS in word.sub_categories
- else None,
- word.get_sub_category(WordType.TOKEN_TAG).category_id
- if WordType.TOKEN_TAG in word.sub_categories
- else None,
- )
- for word in words
- ]
- ))
+ ) = map(
+ list,
+ zip(
+ *[
+ (
+ word.characters,
+ word.annotation_id,
+ word.token_class,
+ word.get_sub_category(WordType.TOKEN_CLASS).annotation_id
+ if WordType.TOKEN_CLASS in word.sub_categories
+ else None,
+ word.token_tag,
+ word.get_sub_category(WordType.TOKEN_TAG).annotation_id
+ if WordType.TOKEN_TAG in word.sub_categories
+ else None,
+ word.get_sub_category(WordType.TOKEN_CLASS).category_id
+ if WordType.TOKEN_CLASS in word.sub_categories
+ else None,
+ word.get_sub_category(WordType.TOKEN_TAG).category_id
+ if WordType.TOKEN_TAG in word.sub_categories
+ else None,
+ )
+ for word in words
+ ]
+ ),
+ )
  else:
  (
  characters,

@@ -364,18 +367,17 @@ class Layout(ImageAnnotationBaseView):
  )

  return Text_(
- text=" ".join(characters), # type: ignore
- words=characters, # type: ignore
- ann_ids=ann_ids, # type: ignore
- token_classes=token_classes, # type: ignore
- token_class_ann_ids=token_class_ann_ids, # type: ignore
- token_tags=token_tags, # type: ignore
- token_tag_ann_ids=token_tag_ann_ids, # type: ignore
- token_class_ids=token_classes_ids, # type: ignore
- token_tag_ids=token_tag_ids, # type: ignore
+ text=" ".join(characters), # type: ignore
+ words=characters, # type: ignore
+ ann_ids=ann_ids, # type: ignore
+ token_classes=token_classes, # type: ignore
+ token_class_ann_ids=token_class_ann_ids, # type: ignore
+ token_tags=token_tags, # type: ignore
+ token_tag_ann_ids=token_tag_ann_ids, # type: ignore
+ token_class_ids=token_classes_ids, # type: ignore
+ token_tag_ids=token_tag_ids, # type: ignore
  )

-
  def get_attribute_names(self) -> set[str]:
  attr_names = (
  {"words", "text"}
@@ -426,6 +428,8 @@ class List(Layout):
  A list of words order by reading order. Words with no `reading_order` will not be returned"""
  try:
  list_items = self.list_items
+ if not list_items:
+ return super().get_ordered_words()
  all_words = []
  list_items.sort(key=lambda x: x.bbox[1])
  for list_item in list_items:
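
Note: the new guard makes `List.get_ordered_words()` degrade gracefully when no `list_items` were detected, falling back to the generic `Layout` ordering instead of yielding no words. A minimal usage sketch (analyzer setup and input file are assumed):

    import deepdoctection as dd

    analyzer = dd.get_dd_analyzer()
    df = analyzer.analyze(path="sample.pdf")  # hypothetical input document
    df.reset_state()
    page = next(iter(df))
    for layout in page.layouts:
        if layout.category_name == dd.LayoutType.LIST:
            # 0.46: falls back to Layout.get_ordered_words() when no
            # list_items were segmented (previously this yielded no words)
            words = layout.get_ordered_words()
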
@@ -464,9 +468,9 @@ class Table(Layout):
  A list of a table cells.
  """
  cell_anns: list[Cell] = []
- for row_number in range(1, self.number_of_rows + 1): # type: ignore
- cell_anns.extend(self.row(row_number)) # type: ignore
-
+ if self.number_of_rows:
+ for row_number in range(1, self.number_of_rows + 1): # type: ignore
+ cell_anns.extend(self.row(row_number)) # type: ignore
  return cell_anns

  @property

@@ -731,7 +735,6 @@ class Table(Layout):
  token_tag_ids=token_tag_ids,
  )

-
  @property
  def words(self) -> list[ImageAnnotationBaseView]:
  """
@@ -754,6 +757,8 @@ class Table(Layout):
  """
  try:
  cells = self.cells
+ if not cells:
+ return super().get_ordered_words()
  all_words = []
  cells.sort(key=lambda x: (x.ROW_NUMBER, x.COLUMN_NUMBER))
  for cell in cells:
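
Note: the Table hunks above harden partially recognized tables: `cells` no longer evaluates `range(1, None + 1)` when `number_of_rows` is unset, and `get_ordered_words()` falls back to the plain layout ordering when no cells exist. Sketch, reusing the `page` from the previous example:

    for table in page.tables:
        cells = table.cells                # [] instead of a TypeError when
                                           # number_of_rows is None
        words = table.get_ordered_words()  # falls back to Layout.get_ordered_words()
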
@@ -1053,6 +1058,8 @@ class Page(Image):
  Returns:
  A `Page` instance with all annotations as `ImageAnnotationBaseView` subclasses.
  """
+ if isinstance(image_orig, Page):
+ raise ImageError("Page.from_image() cannot be called on a Page instance.")

  if text_container is None:
  text_container = IMAGE_DEFAULTS.TEXT_CONTAINER
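
Note: `Page.from_image()` now fails fast when handed an already converted `Page` (which is itself an `Image` subclass) instead of silently re-wrapping it. Sketch of the guard (exception type as in the hunk; the `image` instance is assumed):

    from deepdoctection.datapoint.view import Page

    page = Page.from_image(image_orig=image)  # `image`: a deepdoctection Image
    try:
        Page.from_image(image_orig=page)      # a Page is also an Image
    except Exception as err:                  # raises ImageError in 0.46
        print(err)
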
@@ -1175,7 +1182,6 @@ class Page(Image):
  token_tag_ids=token_tag_ann_ids,
  )

-
  def get_layout_context(self, annotation_id: str, context_size: int = 3) -> list[ImageAnnotationBaseView]:
  """
  For a given `annotation_id` get a list of `ImageAnnotation` that are nearby in terms of `reading_order`.

@@ -1310,7 +1316,7 @@ class Page(Image):
  If `interactive=False` will return a `np.array`.
  """

- category_names_list: list[Union[str, None]] = []
+ category_names_list: list[Tuple[Union[str, None], Union[str, None]]] = []
  box_stack = []
  cells_found = False

@@ -1323,22 +1329,23 @@ class Page(Image):
  anns = self.get_annotation(category_names=list(debug_kwargs.keys()))
  for ann in anns:
  box_stack.append(self._ann_viz_bbox(ann))
- category_names_list.append(str(getattr(ann, debug_kwargs[ann.category_name])))
+ val = str(getattr(ann, debug_kwargs[ann.category_name]))
+ category_names_list.append((val, val))

  if show_layouts and not debug_kwargs:
  for item in self.layouts:
  box_stack.append(self._ann_viz_bbox(item))
- category_names_list.append(item.category_name.value)
+ category_names_list.append((item.category_name.value, item.category_name.value))

  if show_figures and not debug_kwargs:
  for item in self.figures:
  box_stack.append(self._ann_viz_bbox(item))
- category_names_list.append(item.category_name.value)
+ category_names_list.append((item.category_name.value, item.category_name.value))

  if show_tables and not debug_kwargs:
  for table in self.tables:
  box_stack.append(self._ann_viz_bbox(table))
- category_names_list.append(LayoutType.TABLE.value)
+ category_names_list.append((LayoutType.TABLE.value, LayoutType.TABLE.value))
  if show_cells:
  for cell in table.cells:
  if cell.category_name in {

@@ -1347,21 +1354,21 @@ class Page(Image):
  }:
  cells_found = True
  box_stack.append(self._ann_viz_bbox(cell))
- category_names_list.append(None)
+ category_names_list.append((None, cell.category_name.value))
  if show_table_structure:
  rows = table.rows
  cols = table.columns
  for row in rows:
  box_stack.append(self._ann_viz_bbox(row))
- category_names_list.append(None)
+ category_names_list.append((None, row.category_name.value))
  for col in cols:
  box_stack.append(self._ann_viz_bbox(col))
- category_names_list.append(None)
+ category_names_list.append((None, col.category_name.value))

  if show_cells and not cells_found and not debug_kwargs:
  for ann in self.get_annotation(category_names=[LayoutType.CELL, CellType.SPANNING]):
  box_stack.append(self._ann_viz_bbox(ann))
- category_names_list.append(None)
+ category_names_list.append((None, ann.category_name.value))

  if show_words and not debug_kwargs:
  all_words = []
@@ -1379,22 +1386,36 @@ class Page(Image):
  for word in all_words:
  box_stack.append(self._ann_viz_bbox(word))
  if show_token_class:
- category_names_list.append(word.token_class.value if word.token_class is not None else None)
+ category_names_list.append(
+ (word.token_class.value, word.token_class.value)
+ if word.token_class is not None
+ else (None, None)
+ )
  else:
- category_names_list.append(word.token_tag.value if word.token_tag is not None else None)
+ category_names_list.append(
+ (word.token_tag.value, word.token_tag.value) if word.token_tag is not None else (None, None)
+ )
  else:
  for word in all_words:
  if word.token_class is not None and word.token_class != TokenClasses.OTHER:
  box_stack.append(self._ann_viz_bbox(word))
  if show_token_class:
- category_names_list.append(word.token_class.value if word.token_class is not None else None)
+ category_names_list.append(
+ (word.token_class.value, word.token_class.value)
+ if word.token_class is not None
+ else (None, None)
+ )
  else:
- category_names_list.append(word.token_tag.value if word.token_tag is not None else None)
+ category_names_list.append(
+ (word.token_tag.value, word.token_tag.value)
+ if word.token_tag is not None
+ else (None, None)
+ )

  if show_residual_layouts and not debug_kwargs:
  for item in self.residual_layouts:
  box_stack.append(item.bbox)
- category_names_list.append(item.category_name.value)
+ category_names_list.append((item.category_name.value, item.category_name.value))

  if self.image is not None:
  scale_fx = scaled_width / self.width
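
Note: every branch of `viz()` now records a `(display_name, category_name)` pair instead of a bare string, so cells, rows and columns keep their category even when no label is drawn (the first slot stays `None`). The public call is unchanged; a display sketch using the `page` from above:

    import matplotlib.pyplot as plt  # only for display

    np_image = page.viz(
        show_cells=True,
        show_table_structure=True,
        interactive=False,  # per the docstring: returns a np.array
    )
    plt.imshow(np_image)
    plt.axis("off")
    plt.show()
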
deepdoctection/eval/cocometric.py

@@ -275,6 +275,7 @@ class CocoMetric(MetricBase):
  get the ultimate F1-score.
  f1_iou: Use with `f1_score=True` and reset the f1 iou threshold
  per_category: Whether to calculate metrics per category
+ per_category: If set to True, f1 score will be returned by each category.
  """
  if max_detections is not None:
  assert len(max_detections) == 3, max_detections
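
Note: a usage sketch for the per-category F1 option; this assumes the docstring belongs to `CocoMetric.set_params`, which the `max_detections` assert below it suggests:

    import deepdoctection as dd

    coco_metric = dd.metric_registry.get("coco")
    coco_metric.set_params(f1_score=True, f1_iou=0.9, per_category=True)
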
deepdoctection/extern/__init__.py

@@ -26,6 +26,7 @@ from .doctrocr import *
  from .fastlang import *
  from .hfdetr import *
  from .hflayoutlm import *
+ from .hflm import *
  from .model import *
  from .pdftext import *
  from .tessocr import *
deepdoctection/extern/base.py

@@ -263,7 +263,7 @@ class PredictorBase(ABC):
  requirements = cls.get_requirements()
  name = cls.__name__ if hasattr(cls, "__name__") else cls.__class__.__name__
  if not all(requirement[1] for requirement in requirements):
- raise ImportError(
+ raise ModuleNotFoundError(
  "\n".join(
  [f"{name} has the following dependencies:"]
  + [requirement[2] for requirement in requirements if not requirement[1]]
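
Note: `ModuleNotFoundError` is a subclass of `ImportError`, so existing `except ImportError` handlers around predictor construction keep working. Self-contained check:

    assert issubclass(ModuleNotFoundError, ImportError)

    try:
        raise ModuleNotFoundError(
            "SomePredictor has the following dependencies:\n..."  # illustrative message
        )
    except ImportError as err:  # still catches the narrower 0.46 exception
        print(type(err).__name__, err)
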
@@ -334,6 +334,11 @@ class DetectionResult:
  block: block number. For reading order from some ocr predictors
  line: line number. For reading order from some ocr predictors
  uuid: uuid. For assigning detection result (e.g. text to image annotations)
+ relationships: A dictionary of relationships. Each key is a relationship type and each value is a list of
+ uuids of the related annotations.
+ angle: angle of rotation in degrees. Only used for text detection.
+ image_width: image width
+ image_height: image height
  """

  box: Optional[list[float]] = None
@@ -348,6 +353,8 @@ class DetectionResult:
  uuid: Optional[str] = None
  relationships: Optional[dict[str, Any]] = None
  angle: Optional[float] = None
+ image_width: Optional[Union[int, float]] = None
+ image_height: Optional[Union[int, float]] = None


  class ObjectDetector(PredictorBase, ABC):
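
Note: the two new fields record the dimensions of the image that the box coordinates refer to; the rotation transformer further down needs them to map boxes back. Constructing a result with the new fields (values are made up, `xyxy` box format assumed):

    from deepdoctection.extern.base import DetectionResult

    dr = DetectionResult(
        box=[15.0, 20.0, 110.0, 45.0],  # assumed x1, y1, x2, y2
        angle=90.0,
        image_width=800,
        image_height=600,
    )
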
deepdoctection/extern/d2detect.py

@@ -91,7 +91,7 @@ def d2_predict_image(
  """
  height, width = np_img.shape[:2]
  resized_img = resizer.get_transform(np_img).apply_image(np_img)
- image = torch.as_tensor(resized_img.astype("float32").transpose(2, 0, 1))
+ image = torch.as_tensor(resized_img.astype(np.float32).transpose(2, 0, 1))

  with torch.no_grad():
  inputs = {"image": image, "height": height, "width": width}
deepdoctection/extern/doctrocr.py

@@ -24,9 +24,10 @@ from __future__ import annotations
  import os
  from abc import ABC
  from pathlib import Path
- from typing import Any, Literal, Mapping, Optional, Union
+ from typing import Any, Literal, Mapping, Optional, Sequence, Union
  from zipfile import ZipFile

+ import numpy as np
  from lazy_imports import try_import

  from ..utils.env_info import ENV_VARS_TRUE

@@ -39,6 +40,7 @@ from ..utils.file_utils import (
  )
  from ..utils.fs import load_json
  from ..utils.settings import LayoutType, ObjectTypes, PageType, TypeOrStr
+ from ..utils.transform import RotationTransform
  from ..utils.types import PathLikeOrStr, PixelValues, Requirement
  from ..utils.viz import viz_handler
  from .base import DetectionResult, ImageTransformer, ModelCategories, ObjectDetector, TextRecognizer

@@ -558,12 +560,13 @@ class DocTrRotationTransformer(ImageTransformer):
  """
  Args:
  number_contours: the number of contours used for the orientation estimation
- ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
+ ratio_threshold_for_lines: this is the ratio w/h used to discriminate lines
  """
  self.number_contours = number_contours
  self.ratio_threshold_for_lines = ratio_threshold_for_lines
  self.name = "doctr_rotation_transformer"
  self.model_id = self.get_model_id()
+ self.rotator = RotationTransform(360)

  def transform_image(self, np_img: PixelValues, specification: DetectionResult) -> PixelValues:
  """
@@ -579,6 +582,19 @@ class DocTrRotationTransformer(ImageTransformer):
  """
  return viz_handler.rotate_image(np_img, specification.angle) # type: ignore

+ def transform_coords(self, detect_results: Sequence[DetectionResult]) -> Sequence[DetectionResult]:
+ if detect_results:
+ if detect_results[0].angle:
+ self.rotator.set_angle(detect_results[0].angle) # type: ignore
+ self.rotator.set_image_width(detect_results[0].image_width) # type: ignore
+ self.rotator.set_image_height(detect_results[0].image_height) # type: ignore
+ transformed_coords = self.rotator.apply_coords(
+ np.asarray([detect_result.box for detect_result in detect_results], dtype=float)
+ )
+ for idx, detect_result in enumerate(detect_results):
+ detect_result.box = transformed_coords[idx, :].tolist()
+ return detect_results
+
  def predict(self, np_img: PixelValues) -> DetectionResult:
  angle = estimate_orientation(
  np_img, n_ct=self.number_contours, ratio_threshold_for_lines=self.ratio_threshold_for_lines
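
Note: `transform_coords` reads the angle and image size from the first result, configures the shared `RotationTransform`, and remaps all boxes in a single vectorized `apply_coords` call. A minimal sketch (constructor defaults and `xyxy` boxes are assumed):

    from deepdoctection.extern.base import DetectionResult
    from deepdoctection.extern.doctrocr import DocTrRotationTransformer

    transformer = DocTrRotationTransformer()  # constructor defaults assumed
    detect_results = [
        DetectionResult(box=[15.0, 20.0, 110.0, 45.0], angle=90.0,
                        image_width=800, image_height=600),
    ]
    # boxes are re-mapped according to the stored angle and image size
    for dr in transformer.transform_coords(detect_results):
        print(dr.box)
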
deepdoctection/extern/fastlang.py

@@ -36,7 +36,7 @@ from ..utils.types import PathLikeOrStr
  from .base import DetectionResult, LanguageDetector, ModelCategories

  with try_import() as import_guard:
- from fasttext import load_model # type: ignore
+ from fasttext import load_model # type: ignore # pylint: disable=E0401


  class FasttextLangDetectorMixin(LanguageDetector, ABC):
@@ -62,7 +62,7 @@ class FasttextLangDetectorMixin(LanguageDetector, ABC):
  Returns:
  `DetectionResult` filled with `text` and `score`
  """
- return DetectionResult(text=self.categories_orig[output[0][0]], score=output[1][0])
+ return DetectionResult(class_name=self.categories_orig[output[0][0]], score=output[1][0])

  @staticmethod
  def get_name(path_weights: PathLikeOrStr) -> str:
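
Note: this is a behavioral fix, not a rename: the detected language previously landed in `DetectionResult.text`, while the language pipeline component reads `class_name` (the docstring above still advertises `text`). Sketch of the corrected result (detector construction is elided and hypothetical):

    result = fasttext_detector.predict("Ein kurzer deutscher Beispieltext")
    # 0.44.1 filled result.text; 0.46 fills result.class_name instead,
    # which is what the language detection pipeline component consumes
    print(result.class_name, result.score)
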
deepdoctection/extern/hflayoutlm.py

@@ -126,10 +126,13 @@ def get_tokenizer_from_model_class(model_class: str, use_xlm_tokenizer: bool) ->
  ("XLMRobertaForSequenceClassification", True): XLMRobertaTokenizerFast.from_pretrained(
  "FacebookAI/xlm-roberta-base"
  ),
+ ("XLMRobertaForTokenClassification", True): XLMRobertaTokenizerFast.from_pretrained(
+ "FacebookAI/xlm-roberta-base"
+ ),
  }[(model_class, use_xlm_tokenizer)]


- def predict_token_classes(
+ def predict_token_classes_from_layoutlm(
  uuids: list[list[str]],
  input_ids: torch.Tensor,
  attention_mask: torch.Tensor,
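
Note: with the added entry the lookup also resolves a tokenizer for XLM-RoBERTa token classifiers; before 0.46 the second call below raised a `KeyError` from the dictionary lookup:

    from deepdoctection.extern.hflayoutlm import get_tokenizer_from_model_class

    tok_seq = get_tokenizer_from_model_class("XLMRobertaForSequenceClassification", use_xlm_tokenizer=True)
    tok_cls = get_tokenizer_from_model_class("XLMRobertaForTokenClassification", use_xlm_tokenizer=True)
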
@@ -192,7 +195,7 @@ def predict_token_classes(
  return all_token_classes


- def predict_sequence_classes(
+ def predict_sequence_classes_from_layoutlm(
  input_ids: torch.Tensor,
  attention_mask: torch.Tensor,
  token_type_ids: torch.Tensor,

@@ -462,7 +465,7 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):

  ann_ids, _, input_ids, attention_mask, token_type_ids, boxes, tokens = self._validate_encodings(**encodings)

- results = predict_token_classes(
+ results = predict_token_classes_from_layoutlm(
  ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, None
  )

@@ -586,7 +589,7 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
  images = images.to(self.device)
  else:
  raise ValueError(f"images must be list but is {type(images)}")
- results = predict_token_classes(
+ results = predict_token_classes_from_layoutlm(
  ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, images
  )

@@ -710,7 +713,7 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
  images = images.to(self.device)
  else:
  raise ValueError(f"images must be list but is {type(images)}")
- results = predict_token_classes(
+ results = predict_token_classes_from_layoutlm(
  ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, images
  )

@@ -909,7 +912,7 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
  """
  input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)

- result = predict_sequence_classes(
+ result = predict_sequence_classes_from_layoutlm(
  input_ids,
  attention_mask,
  token_type_ids,

@@ -1021,7 +1024,9 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
  else:
  raise ValueError(f"images must be list but is {type(images)}")

- result = predict_sequence_classes(input_ids, attention_mask, token_type_ids, boxes, self.model, images)
+ result = predict_sequence_classes_from_layoutlm(
+ input_ids, attention_mask, token_type_ids, boxes, self.model, images
+ )

  result.class_id += 1
  result.class_name = self.categories.categories[result.class_id]

@@ -1115,7 +1120,9 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
  else:
  raise ValueError(f"images must be list but is {type(images)}")

- result = predict_sequence_classes(input_ids, attention_mask, token_type_ids, boxes, self.model, images)
+ result = predict_sequence_classes_from_layoutlm(
+ input_ids, attention_mask, token_type_ids, boxes, self.model, images
+ )

  result.class_id += 1
  result.class_name = self.categories.categories[result.class_id]

@@ -1245,7 +1252,7 @@ class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):

  ann_ids, _, input_ids, attention_mask, token_type_ids, boxes, tokens = self._validate_encodings(**encodings)

- results = predict_token_classes(
+ results = predict_token_classes_from_layoutlm(
  ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, None
  )
@@ -1323,7 +1330,7 @@ class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
  def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
  input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)

- result = predict_sequence_classes(
+ result = predict_sequence_classes_from_layoutlm(
  input_ids,
  attention_mask,
  token_type_ids,
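
Note: the helper renames are applied consistently at every call site shown above, so only code that imported them directly needs updating:

    # 0.44.1
    # from deepdoctection.extern.hflayoutlm import predict_token_classes, predict_sequence_classes

    # 0.46
    from deepdoctection.extern.hflayoutlm import (
        predict_sequence_classes_from_layoutlm,
        predict_token_classes_from_layoutlm,
    )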