deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (111):
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -20,27 +20,27 @@ HF Layoutlm model for diverse downstream tasks.
20
20
  """
21
21
  from __future__ import annotations
22
22
 
23
+ import os
23
24
  from abc import ABC
24
25
  from collections import defaultdict
25
- from copy import copy
26
26
  from pathlib import Path
27
- from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union
27
+ from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union
28
28
 
29
29
  import numpy as np
30
30
  from lazy_imports import try_import
31
+ from typing_extensions import TypeAlias
31
32
 
32
- from ..utils.detection_types import JsonDict, Requirement
33
33
  from ..utils.file_utils import get_pytorch_requirement, get_transformers_requirement
34
- from ..utils.settings import (
35
- BioTag,
36
- ObjectTypes,
37
- TokenClasses,
38
- TypeOrStr,
39
- get_type,
40
- token_class_tag_to_token_class_with_tag,
41
- token_class_with_tag_to_token_class_and_tag,
34
+ from ..utils.settings import TypeOrStr
35
+ from ..utils.types import JsonDict, PathLikeOrStr, Requirement
36
+ from .base import (
37
+ LMSequenceClassifier,
38
+ LMTokenClassifier,
39
+ ModelCategories,
40
+ NerModelCategories,
41
+ SequenceClassResult,
42
+ TokenClassResult,
42
43
  )
43
- from .base import LMSequenceClassifier, LMTokenClassifier, SequenceClassResult, TokenClassResult
44
44
  from .pt.ptutils import get_torch_device
45
45
 
46
46
  with try_import() as pt_import_guard:
@@ -66,6 +66,35 @@ with try_import() as tr_import_guard:
66
66
  XLMRobertaTokenizerFast,
67
67
  )
68
68
 
69
+ if TYPE_CHECKING:
70
+ LayoutTokenModels: TypeAlias = Union[
71
+ LayoutLMForTokenClassification,
72
+ LayoutLMv2ForTokenClassification,
73
+ LayoutLMv3ForTokenClassification,
74
+ LiltForTokenClassification,
75
+ ]
76
+
77
+ LayoutSequenceModels: TypeAlias = Union[
78
+ LayoutLMForSequenceClassification,
79
+ LayoutLMv2ForSequenceClassification,
80
+ LayoutLMv3ForSequenceClassification,
81
+ LiltForSequenceClassification,
82
+ ]
83
+
84
+ HfLayoutTokenModels: TypeAlias = Union[
85
+ LayoutLMForTokenClassification,
86
+ LayoutLMv2ForTokenClassification,
87
+ LayoutLMv3ForTokenClassification,
88
+ LiltForTokenClassification,
89
+ ]
90
+
91
+ HfLayoutSequenceModels: TypeAlias = Union[
92
+ LayoutLMForSequenceClassification,
93
+ LayoutLMv2ForSequenceClassification,
94
+ LayoutLMv3ForSequenceClassification,
95
+ LiltForSequenceClassification,
96
+ ]
97
+
69
98
 
70
99
  def get_tokenizer_from_model_class(model_class: str, use_xlm_tokenizer: bool) -> Any:
71
100
  """
@@ -112,15 +141,15 @@ def get_tokenizer_from_model_class(model_class: str, use_xlm_tokenizer: bool) ->
112
141
 
113
142
 
114
143
  def predict_token_classes(
115
- uuids: List[List[str]],
144
+ uuids: list[list[str]],
116
145
  input_ids: torch.Tensor,
117
146
  attention_mask: torch.Tensor,
118
147
  token_type_ids: torch.Tensor,
119
148
  boxes: torch.Tensor,
120
- tokens: List[List[str]],
121
- model: Union[LayoutLMForTokenClassification, LayoutLMv2ForTokenClassification, LayoutLMv3ForTokenClassification],
149
+ tokens: list[list[str]],
150
+ model: LayoutTokenModels,
122
151
  images: Optional[torch.Tensor] = None,
123
- ) -> List[TokenClassResult]:
152
+ ) -> list[TokenClassResult]:
124
153
  """
125
154
  :param uuids: A list of uuids that correspond to a word that induces the resulting token
126
155
  :param input_ids: Token converted to ids to be taken from LayoutLMTokenizer
@@ -176,12 +205,7 @@ def predict_sequence_classes(
176
205
  attention_mask: torch.Tensor,
177
206
  token_type_ids: torch.Tensor,
178
207
  boxes: torch.Tensor,
179
- model: Union[
180
- LayoutLMForSequenceClassification,
181
- LayoutLMv2ForSequenceClassification,
182
- LayoutLMv3ForSequenceClassification,
183
- LiltForSequenceClassification,
184
- ],
208
+ model: LayoutSequenceModels,
185
209
  images: Optional[torch.Tensor] = None,
186
210
  ) -> SequenceClassResult:
187
211
  """
@@ -222,17 +246,14 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):
222
246
  Abstract base class for wrapping LayoutLM models for token classification into the deepdoctection framework.
223
247
  """
224
248
 
225
- model: Union[LayoutLMForTokenClassification, LayoutLMv2ForTokenClassification]
226
-
227
249
  def __init__(
228
250
  self,
229
- path_config_json: str,
230
- path_weights: str,
251
+ path_config_json: PathLikeOrStr,
252
+ path_weights: PathLikeOrStr,
231
253
  categories_semantics: Optional[Sequence[TypeOrStr]] = None,
232
254
  categories_bio: Optional[Sequence[TypeOrStr]] = None,
233
- categories: Optional[Mapping[str, TypeOrStr]] = None,
255
+ categories: Optional[Mapping[int, TypeOrStr]] = None,
234
256
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
235
- use_xlm_tokenizer: bool = False,
236
257
  ):
237
258
  """
238
259
  :param path_config_json: path to .json config file
@@ -254,43 +275,21 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):
254
275
  if categories_bio is None:
255
276
  raise ValueError("If categories is None then categories_bio cannot be None")
256
277
 
257
- self.path_config = path_config_json
258
- self.path_weights = path_weights
259
- self.categories_semantics = (
260
- [get_type(cat_sem) for cat_sem in categories_semantics] if categories_semantics is not None else []
278
+ self.path_config = Path(path_config_json)
279
+ self.path_weights = Path(path_weights)
280
+ self.categories = NerModelCategories(
281
+ init_categories=categories, categories_semantics=categories_semantics, categories_bio=categories_bio
261
282
  )
262
- self.categories_bio = [get_type(cat_bio) for cat_bio in categories_bio] if categories_bio is not None else []
263
- if categories:
264
- self.categories = copy(categories) # type: ignore
265
- else:
266
- self.categories = self._categories_orig_to_categories(
267
- self.categories_semantics, self.categories_bio # type: ignore
268
- )
269
283
  self.device = get_torch_device(device)
270
- self.model.to(self.device)
271
- self.model.config.tokenizer_class = self.get_tokenizer_class_name(use_xlm_tokenizer)
272
284
 
273
285
  @classmethod
274
- def get_requirements(cls) -> List[Requirement]:
286
+ def get_requirements(cls) -> list[Requirement]:
275
287
  return [get_pytorch_requirement(), get_transformers_requirement()]
276
288
 
277
- @staticmethod
278
- def _categories_orig_to_categories(
279
- categories_semantics: List[TokenClasses], categories_bio: List[BioTag]
280
- ) -> Dict[str, ObjectTypes]:
281
- categories_list = sorted(
282
- {
283
- token_class_tag_to_token_class_with_tag(token, tag)
284
- for token in categories_semantics
285
- for tag in categories_bio
286
- }
287
- )
288
- return {str(k): v for k, v in enumerate(categories_list, 1)}
289
-
290
- def _map_category_names(self, token_results: List[TokenClassResult]) -> List[TokenClassResult]:
289
+ def _map_category_names(self, token_results: list[TokenClassResult]) -> list[TokenClassResult]:
291
290
  for result in token_results:
292
- result.class_name = self.categories[str(result.class_id + 1)]
293
- output = token_class_with_tag_to_token_class_and_tag(result.class_name)
291
+ result.class_name = self.categories.categories[result.class_id + 1]
292
+ output = self.categories.disentangle_token_class_and_tag(result.class_name)
294
293
  if output is not None:
295
294
  token_class, tag = output
296
295
  result.semantic_name = token_class
@@ -302,7 +301,7 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):
302
301
 
303
302
  def _validate_encodings(
304
303
  self, **encodings: Any
305
- ) -> Tuple[List[List[str]], List[str], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[List[str]]]:
304
+ ) -> tuple[list[list[str]], list[str], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[list[str]]]:
306
305
  image_ids = encodings.get("image_ids", [])
307
306
  ann_ids = encodings.get("ann_ids")
308
307
  input_ids = encodings.get("input_ids")
@@ -339,23 +338,25 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):
339
338
  return self.__class__(
340
339
  self.path_config,
341
340
  self.path_weights,
342
- self.categories_semantics,
343
- self.categories_bio,
344
- self.categories,
341
+ self.categories.categories_semantics,
342
+ self.categories.categories_bio,
343
+ self.categories.get_categories(),
345
344
  self.device,
346
345
  )
347
346
 
348
347
  @staticmethod
349
- def get_name(path_weights: str, architecture: str) -> str:
348
+ def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
350
349
  """Returns the name of the model"""
351
350
  return f"Transformers_{architecture}_" + "_".join(Path(path_weights).parts[-2:])
352
351
 
353
- def get_tokenizer_class_name(self, use_xlm_tokenizer: bool) -> str:
352
+ @staticmethod
353
+ def get_tokenizer_class_name(model_class_name: str, use_xlm_tokenizer: bool) -> str:
354
354
  """A refinement for adding the tokenizer class name to the model configs.
355
355
 
356
+ :param model_class_name: The model name, e.g. model.__class__.__name__
356
357
  :param use_xlm_tokenizer: Whether to use a XLM tokenizer.
357
358
  """
358
- tokenizer = get_tokenizer_from_model_class(self.model.__class__.__name__, use_xlm_tokenizer)
359
+ tokenizer = get_tokenizer_from_model_class(model_class_name, use_xlm_tokenizer)
359
360
  return tokenizer.__class__.__name__
360
361
 
361
362
  @staticmethod
@@ -405,11 +406,11 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
405
406
 
406
407
  def __init__(
407
408
  self,
408
- path_config_json: str,
409
- path_weights: str,
409
+ path_config_json: PathLikeOrStr,
410
+ path_weights: PathLikeOrStr,
410
411
  categories_semantics: Optional[Sequence[TypeOrStr]] = None,
411
412
  categories_bio: Optional[Sequence[TypeOrStr]] = None,
412
- categories: Optional[Mapping[str, TypeOrStr]] = None,
413
+ categories: Optional[Mapping[int, TypeOrStr]] = None,
413
414
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
414
415
  use_xlm_tokenizer: bool = False,
415
416
  ):
@@ -426,14 +427,16 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
426
427
  :param use_xlm_tokenizer: Do not change this value unless you pre-trained a LayoutLM model with a different
427
428
  Tokenizer.
428
429
  """
430
+ super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
429
431
  self.name = self.get_name(path_weights, "LayoutLM")
430
432
  self.model_id = self.get_model_id()
431
433
  self.model = self.get_wrapped_model(path_config_json, path_weights)
432
- super().__init__(
433
- path_config_json, path_weights, categories_semantics, categories_bio, categories, device, use_xlm_tokenizer
434
+ self.model.to(self.device)
435
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
436
+ self.model.__class__.__name__, use_xlm_tokenizer
434
437
  )
435
438
 
436
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> List[TokenClassResult]:
439
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
437
440
  """
438
441
  Launch inference on LayoutLm for token classification. Pass the following arguments
439
442
 
@@ -459,7 +462,9 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
459
462
  return self._map_category_names(results)
460
463
 
461
464
  @staticmethod
462
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
465
+ def get_wrapped_model(
466
+ path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
467
+ ) -> LayoutLMForTokenClassification:
463
468
  """
464
469
  Get the inner (wrapped) model.
465
470
 
@@ -467,8 +472,13 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
467
472
  :param path_weights: path to model artifact
468
473
  :return: 'nn.Module'
469
474
  """
470
- config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
471
- return LayoutLMForTokenClassification.from_pretrained(pretrained_model_name_or_path=path_weights, config=config)
475
+ config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
476
+ return LayoutLMForTokenClassification.from_pretrained(
477
+ pretrained_model_name_or_path=os.fspath(path_weights), config=config
478
+ )
479
+
480
+ def clear_model(self) -> None:
481
+ self.model = None
472
482
 
473
483
 
474
484
  class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
@@ -509,11 +519,11 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
509
519
 
510
520
  def __init__(
511
521
  self,
512
- path_config_json: str,
513
- path_weights: str,
522
+ path_config_json: PathLikeOrStr,
523
+ path_weights: PathLikeOrStr,
514
524
  categories_semantics: Optional[Sequence[TypeOrStr]] = None,
515
525
  categories_bio: Optional[Sequence[TypeOrStr]] = None,
516
- categories: Optional[Mapping[str, TypeOrStr]] = None,
526
+ categories: Optional[Mapping[int, TypeOrStr]] = None,
517
527
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
518
528
  use_xlm_tokenizer: bool = False,
519
529
  ):
@@ -530,14 +540,16 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
530
540
  :param use_xlm_tokenizer: Set to True if you use a LayoutXLM model. If you use a LayoutLMv2 model keep the
531
541
  default value.
532
542
  """
543
+ super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
533
544
  self.name = self.get_name(path_weights, "LayoutLMv2")
534
545
  self.model_id = self.get_model_id()
535
546
  self.model = self.get_wrapped_model(path_config_json, path_weights)
536
- super().__init__(
537
- path_config_json, path_weights, categories_semantics, categories_bio, categories, device, use_xlm_tokenizer
547
+ self.model.to(self.device)
548
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
549
+ self.model.__class__.__name__, use_xlm_tokenizer
538
550
  )
539
551
 
540
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> List[TokenClassResult]:
552
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
541
553
  """
542
554
  Launch inference on LayoutLm for token classification. Pass the following arguments
543
555
 
@@ -568,7 +580,7 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
568
580
  return self._map_category_names(results)
569
581
 
570
582
  @staticmethod
571
- def default_kwargs_for_input_mapping() -> JsonDict:
583
+ def default_kwargs_for_image_to_features_mapping() -> JsonDict:
572
584
  """
573
585
  Add some default arguments that might be necessary when preparing a sample. Overwrite this method
574
586
  for some custom setting.
@@ -576,7 +588,9 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
576
588
  return {"image_width": 224, "image_height": 224}
577
589
 
578
590
  @staticmethod
579
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
591
+ def get_wrapped_model(
592
+ path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
593
+ ) -> LayoutLMv2ForTokenClassification:
580
594
  """
581
595
  Get the inner (wrapped) model.
582
596
 
@@ -584,11 +598,14 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
584
598
  :param path_weights: path to model artifact
585
599
  :return: 'nn.Module'
586
600
  """
587
- config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
601
+ config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
588
602
  return LayoutLMv2ForTokenClassification.from_pretrained(
589
- pretrained_model_name_or_path=path_weights, config=config
603
+ pretrained_model_name_or_path=os.fspath(path_weights), config=config
590
604
  )
591
605
 
606
+ def clear_model(self) -> None:
607
+ self.model = None
608
+
592
609
 
593
610
  class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
594
611
  """
@@ -628,11 +645,11 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
628
645
 
629
646
  def __init__(
630
647
  self,
631
- path_config_json: str,
632
- path_weights: str,
648
+ path_config_json: PathLikeOrStr,
649
+ path_weights: PathLikeOrStr,
633
650
  categories_semantics: Optional[Sequence[TypeOrStr]] = None,
634
651
  categories_bio: Optional[Sequence[TypeOrStr]] = None,
635
- categories: Optional[Mapping[str, TypeOrStr]] = None,
652
+ categories: Optional[Mapping[int, TypeOrStr]] = None,
636
653
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
637
654
  use_xlm_tokenizer: bool = False,
638
655
  ):
@@ -649,14 +666,16 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
649
666
  :param use_xlm_tokenizer: Do not change this value unless you pre-trained a LayoutLMv3 model with a different
650
667
  tokenizer.
651
668
  """
669
+ super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
652
670
  self.name = self.get_name(path_weights, "LayoutLMv3")
653
671
  self.model_id = self.get_model_id()
654
672
  self.model = self.get_wrapped_model(path_config_json, path_weights)
655
- super().__init__(
656
- path_config_json, path_weights, categories_semantics, categories_bio, categories, device, use_xlm_tokenizer
673
+ self.model.to(self.device)
674
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
675
+ self.model.__class__.__name__, use_xlm_tokenizer
657
676
  )
658
677
 
659
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> List[TokenClassResult]:
678
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
660
679
  """
661
680
  Launch inference on LayoutLm for token classification. Pass the following arguments
662
681
 
@@ -683,7 +702,7 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
683
702
  return self._map_category_names(results)
684
703
 
685
704
  @staticmethod
686
- def default_kwargs_for_input_mapping() -> JsonDict:
705
+ def default_kwargs_for_image_to_features_mapping() -> JsonDict:
687
706
  """
688
707
  Add some default arguments that might be necessary when preparing a sample. Overwrite this method
689
708
  for some custom setting.
@@ -697,7 +716,9 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
697
716
  }
698
717
 
699
718
  @staticmethod
700
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
719
+ def get_wrapped_model(
720
+ path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
721
+ ) -> LayoutLMv3ForTokenClassification:
701
722
  """
702
723
  Get the inner (wrapped) model.
703
724
 
@@ -705,45 +726,43 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
705
726
  :param path_weights: path to model artifact
706
727
  :return: 'nn.Module'
707
728
  """
708
- config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
729
+ config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
709
730
  return LayoutLMv3ForTokenClassification.from_pretrained(
710
- pretrained_model_name_or_path=path_weights, config=config
731
+ pretrained_model_name_or_path=os.fspath(path_weights), config=config
711
732
  )
712
733
 
734
+ def clear_model(self) -> None:
735
+ self.model = None
736
+
713
737
 
714
738
  class HFLayoutLmSequenceClassifierBase(LMSequenceClassifier, ABC):
715
739
  """
716
740
  Abstract base class for wrapping LayoutLM models for sequence classification into the deepdoctection framework.
717
741
  """
718
742
 
719
- model: Union[LayoutLMForSequenceClassification, LayoutLMv2ForSequenceClassification]
720
-
721
743
  def __init__(
722
744
  self,
723
- path_config_json: str,
724
- path_weights: str,
725
- categories: Mapping[str, TypeOrStr],
745
+ path_config_json: PathLikeOrStr,
746
+ path_weights: PathLikeOrStr,
747
+ categories: Mapping[int, TypeOrStr],
726
748
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
727
- use_xlm_tokenizer: bool = False,
728
749
  ):
729
- self.path_config = path_config_json
730
- self.path_weights = path_weights
731
- self.categories = copy(categories) # type: ignore
750
+ self.path_config = Path(path_config_json)
751
+ self.path_weights = Path(path_weights)
752
+ self.categories = ModelCategories(init_categories=categories)
732
753
 
733
754
  self.device = get_torch_device(device)
734
- self.model.to(self.device)
735
- self.model.config.tokenizer_class = self.get_tokenizer_class_name(use_xlm_tokenizer)
736
755
 
737
756
  @classmethod
738
- def get_requirements(cls) -> List[Requirement]:
757
+ def get_requirements(cls) -> list[Requirement]:
739
758
  return [get_pytorch_requirement(), get_transformers_requirement()]
740
759
 
741
760
  def clone(self) -> HFLayoutLmSequenceClassifierBase:
742
- return self.__class__(self.path_config, self.path_weights, self.categories, self.device)
761
+ return self.__class__(self.path_config, self.path_weights, self.categories.get_categories(), self.device)
743
762
 
744
763
  def _validate_encodings(
745
- self, **encodings: Union[List[List[str]], torch.Tensor]
746
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
764
+ self, **encodings: Union[list[list[str]], torch.Tensor]
765
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
747
766
  input_ids = encodings.get("input_ids")
748
767
  attention_mask = encodings.get("attention_mask")
749
768
  token_type_ids = encodings.get("token_type_ids")
@@ -773,16 +792,18 @@ class HFLayoutLmSequenceClassifierBase(LMSequenceClassifier, ABC):
773
792
  return input_ids, attention_mask, token_type_ids, boxes
774
793
 
775
794
  @staticmethod
776
- def get_name(path_weights: str, architecture: str) -> str:
795
+ def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
777
796
  """Returns the name of the model"""
778
797
  return f"Transformers_{architecture}_" + "_".join(Path(path_weights).parts[-2:])
779
798
 
780
- def get_tokenizer_class_name(self, use_xlm_tokenizer: bool) -> str:
799
+ @staticmethod
800
+ def get_tokenizer_class_name(model_class_name: str, use_xlm_tokenizer: bool) -> str:
781
801
  """A refinement for adding the tokenizer class name to the model configs.
782
802
 
803
+ :param model_class_name: The model name, e.g. model.__class__.__name__
783
804
  :param use_xlm_tokenizer: Whether to use a XLM tokenizer.
784
805
  """
785
- tokenizer = get_tokenizer_from_model_class(self.model.__class__.__name__, use_xlm_tokenizer)
806
+ tokenizer = get_tokenizer_from_model_class(model_class_name, use_xlm_tokenizer)
786
807
  return tokenizer.__class__.__name__
787
808
 
788
809
  @staticmethod
@@ -829,18 +850,22 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
829
850
 
830
851
  def __init__(
831
852
  self,
832
- path_config_json: str,
833
- path_weights: str,
834
- categories: Mapping[str, TypeOrStr],
853
+ path_config_json: PathLikeOrStr,
854
+ path_weights: PathLikeOrStr,
855
+ categories: Mapping[int, TypeOrStr],
835
856
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
836
857
  use_xlm_tokenizer: bool = False,
837
858
  ):
859
+ super().__init__(path_config_json, path_weights, categories, device)
838
860
  self.name = self.get_name(path_weights, "LayoutLM")
839
861
  self.model_id = self.get_model_id()
840
862
  self.model = self.get_wrapped_model(path_config_json, path_weights)
841
- super().__init__(path_config_json, path_weights, categories, device, use_xlm_tokenizer)
863
+ self.model.to(self.device)
864
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
865
+ self.model.__class__.__name__, use_xlm_tokenizer
866
+ )
842
867
 
843
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> SequenceClassResult:
868
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
844
869
  input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)
845
870
 
846
871
  result = predict_sequence_classes(
@@ -852,11 +877,13 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
852
877
  )
853
878
 
854
879
  result.class_id += 1
855
- result.class_name = self.categories[str(result.class_id)]
880
+ result.class_name = self.categories.categories[result.class_id]
856
881
  return result
857
882
 
858
883
  @staticmethod
859
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
884
+ def get_wrapped_model(
885
+ path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
886
+ ) -> LayoutLMForSequenceClassification:
860
887
  """
861
888
  Get the inner (wrapped) model.
862
889
 
@@ -864,11 +891,14 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
864
891
  :param path_weights: path to model artifact
865
892
  :return: 'nn.Module'
866
893
  """
867
- config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
894
+ config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
868
895
  return LayoutLMForSequenceClassification.from_pretrained(
869
- pretrained_model_name_or_path=path_weights, config=config
896
+ pretrained_model_name_or_path=os.fspath(path_weights), config=config
870
897
  )
871
898
 
899
+ def clear_model(self) -> None:
900
+ self.model = None
901
+
872
902
 
873
903
  class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
874
904
  """
@@ -903,18 +933,22 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
903
933
 
904
934
  def __init__(
905
935
  self,
906
- path_config_json: str,
907
- path_weights: str,
908
- categories: Mapping[str, TypeOrStr],
936
+ path_config_json: PathLikeOrStr,
937
+ path_weights: PathLikeOrStr,
938
+ categories: Mapping[int, TypeOrStr],
909
939
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
910
940
  use_xlm_tokenizer: bool = False,
911
941
  ):
942
+ super().__init__(path_config_json, path_weights, categories, device)
912
943
  self.name = self.get_name(path_weights, "LayoutLMv2")
913
944
  self.model_id = self.get_model_id()
914
945
  self.model = self.get_wrapped_model(path_config_json, path_weights)
915
- super().__init__(path_config_json, path_weights, categories, device, use_xlm_tokenizer)
946
+ self.model.to(self.device)
947
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
948
+ self.model.__class__.__name__, use_xlm_tokenizer
949
+ )
916
950
 
917
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> SequenceClassResult:
951
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
918
952
  input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)
919
953
  images = encodings.get("image")
920
954
  if isinstance(images, torch.Tensor):
@@ -925,11 +959,11 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
925
959
  result = predict_sequence_classes(input_ids, attention_mask, token_type_ids, boxes, self.model, images)
926
960
 
927
961
  result.class_id += 1
928
- result.class_name = self.categories[str(result.class_id)]
962
+ result.class_name = self.categories.categories[result.class_id]
929
963
  return result
930
964
 
931
965
  @staticmethod
932
- def default_kwargs_for_input_mapping() -> JsonDict:
966
+ def default_kwargs_for_image_to_features_mapping() -> JsonDict:
933
967
  """
934
968
  Add some default arguments that might be necessary when preparing a sample. Overwrite this method
935
969
  for some custom setting.
@@ -937,7 +971,9 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
937
971
  return {"image_width": 224, "image_height": 224}
938
972
 
939
973
  @staticmethod
940
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
974
+ def get_wrapped_model(
975
+ path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
976
+ ) -> LayoutLMv2ForSequenceClassification:
941
977
  """
942
978
  Get the inner (wrapped) model.
943
979
 
@@ -945,11 +981,14 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
945
981
  :param path_weights: path to model artifact
946
982
  :return: 'nn.Module'
947
983
  """
948
- config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
984
+ config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
949
985
  return LayoutLMv2ForSequenceClassification.from_pretrained(
950
- pretrained_model_name_or_path=path_weights, config=config
986
+ pretrained_model_name_or_path=os.fspath(path_weights), config=config
951
987
  )
952
988
 
989
+ def clear_model(self) -> None:
990
+ self.model = None
991
+
953
992
 
954
993
  class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
955
994
  """
@@ -984,18 +1023,22 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
984
1023
 
985
1024
  def __init__(
986
1025
  self,
987
- path_config_json: str,
988
- path_weights: str,
989
- categories: Mapping[str, TypeOrStr],
1026
+ path_config_json: PathLikeOrStr,
1027
+ path_weights: PathLikeOrStr,
1028
+ categories: Mapping[int, TypeOrStr],
990
1029
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
991
1030
  use_xlm_tokenizer: bool = False,
992
1031
  ):
1032
+ super().__init__(path_config_json, path_weights, categories, device)
993
1033
  self.name = self.get_name(path_weights, "LayoutLMv3")
994
1034
  self.model_id = self.get_model_id()
995
1035
  self.model = self.get_wrapped_model(path_config_json, path_weights)
996
- super().__init__(path_config_json, path_weights, categories, device, use_xlm_tokenizer)
1036
+ self.model.to(self.device)
1037
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
1038
+ self.model.__class__.__name__, use_xlm_tokenizer
1039
+ )
997
1040
 
998
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> SequenceClassResult:
1041
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
999
1042
  input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)
1000
1043
  images = encodings.get("pixel_values")
1001
1044
  if isinstance(images, torch.Tensor):
@@ -1006,11 +1049,11 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
1006
1049
  result = predict_sequence_classes(input_ids, attention_mask, token_type_ids, boxes, self.model, images)
1007
1050
 
1008
1051
  result.class_id += 1
1009
- result.class_name = self.categories[str(result.class_id)]
1052
+ result.class_name = self.categories.categories[result.class_id]
1010
1053
  return result
1011
1054
 
1012
1055
  @staticmethod
1013
- def default_kwargs_for_input_mapping() -> JsonDict:
1056
+ def default_kwargs_for_image_to_features_mapping() -> JsonDict:
1014
1057
  """
1015
1058
  Add some default arguments that might be necessary when preparing a sample. Overwrite this method
1016
1059
  for some custom setting.
@@ -1024,7 +1067,9 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
1024
1067
  }
1025
1068
 
1026
1069
  @staticmethod
1027
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
1070
+ def get_wrapped_model(
1071
+ path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
1072
+ ) -> LayoutLMv3ForSequenceClassification:
1028
1073
  """
1029
1074
  Get the inner (wrapped) model.
1030
1075
 
@@ -1032,11 +1077,14 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
1032
1077
  :param path_weights: path to model artifact
1033
1078
  :return: 'nn.Module'
1034
1079
  """
1035
- config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
1080
+ config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
1036
1081
  return LayoutLMv3ForSequenceClassification.from_pretrained(
1037
- pretrained_model_name_or_path=path_weights, config=config
1082
+ pretrained_model_name_or_path=os.fspath(path_weights), config=config
1038
1083
  )
1039
1084
 
1085
+ def clear_model(self) -> None:
1086
+ self.model = None
1087
+
1040
1088
 
1041
1089
  class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):
1042
1090
  """
@@ -1074,11 +1122,11 @@ class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):
1074
1122
 
1075
1123
  def __init__(
1076
1124
  self,
1077
- path_config_json: str,
1078
- path_weights: str,
1125
+ path_config_json: PathLikeOrStr,
1126
+ path_weights: PathLikeOrStr,
1079
1127
  categories_semantics: Optional[Sequence[TypeOrStr]] = None,
1080
1128
  categories_bio: Optional[Sequence[TypeOrStr]] = None,
1081
- categories: Optional[Mapping[str, TypeOrStr]] = None,
1129
+ categories: Optional[Mapping[int, TypeOrStr]] = None,
1082
1130
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
1083
1131
  use_xlm_tokenizer: bool = False,
1084
1132
  ):
@@ -1093,14 +1141,17 @@ class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):
1093
1141
  :param categories: If you have a pre-trained model you can pass a complete dict of NER categories
1094
1142
  :param device: The device (cpu,"cuda"), where to place the model.
1095
1143
  """
1144
+
1145
+ super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
1096
1146
  self.name = self.get_name(path_weights, "LiLT")
1097
1147
  self.model_id = self.get_model_id()
1098
1148
  self.model = self.get_wrapped_model(path_config_json, path_weights)
1099
- super().__init__(
1100
- path_config_json, path_weights, categories_semantics, categories_bio, categories, device, use_xlm_tokenizer
1149
+ self.model.to(self.device)
1150
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
1151
+ self.model.__class__.__name__, use_xlm_tokenizer
1101
1152
  )
1102
1153
 
1103
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> List[TokenClassResult]:
1154
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
1104
1155
  """
1105
1156
  Launch inference on LayoutLm for token classification. Pass the following arguments
1106
1157
 
@@ -1126,7 +1177,7 @@ class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):
1126
1177
  return self._map_category_names(results)
1127
1178
 
1128
1179
  @staticmethod
1129
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
1180
+ def get_wrapped_model(path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr) -> LiltForTokenClassification:
1130
1181
  """
1131
1182
  Get the inner (wrapped) model.
1132
1183
 
@@ -1137,6 +1188,9 @@ class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):
1137
1188
  config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
1138
1189
  return LiltForTokenClassification.from_pretrained(pretrained_model_name_or_path=path_weights, config=config)
1139
1190
 
1191
+ def clear_model(self) -> None:
1192
+ self.model = None
1193
+
1140
1194
 
1141
1195
  class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
1142
1196
  """
@@ -1172,18 +1226,22 @@ class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
1172
1226
 
1173
1227
  def __init__(
1174
1228
  self,
1175
- path_config_json: str,
1176
- path_weights: str,
1177
- categories: Mapping[str, TypeOrStr],
1229
+ path_config_json: PathLikeOrStr,
1230
+ path_weights: PathLikeOrStr,
1231
+ categories: Mapping[int, TypeOrStr],
1178
1232
  device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
1179
1233
  use_xlm_tokenizer: bool = False,
1180
1234
  ):
1235
+ super().__init__(path_config_json, path_weights, categories, device)
1181
1236
  self.name = self.get_name(path_weights, "LiLT")
1182
1237
  self.model_id = self.get_model_id()
1183
1238
  self.model = self.get_wrapped_model(path_config_json, path_weights)
1184
- super().__init__(path_config_json, path_weights, categories, device, use_xlm_tokenizer)
1239
+ self.model.to(self.device)
1240
+ self.model.config.tokenizer_class = self.get_tokenizer_class_name(
1241
+ self.model.__class__.__name__, use_xlm_tokenizer
1242
+ )
1185
1243
 
1186
- def predict(self, **encodings: Union[List[List[str]], torch.Tensor]) -> SequenceClassResult:
1244
+ def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
1187
1245
  input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)
1188
1246
 
1189
1247
  result = predict_sequence_classes(
@@ -1195,11 +1253,11 @@ class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
1195
1253
  )
1196
1254
 
1197
1255
  result.class_id += 1
1198
- result.class_name = self.categories[str(result.class_id)]
1256
+ result.class_name = self.categories.categories[result.class_id]
1199
1257
  return result
1200
1258
 
1201
1259
  @staticmethod
1202
- def get_wrapped_model(path_config_json: str, path_weights: str) -> Any:
1260
+ def get_wrapped_model(path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr) -> Any:
1203
1261
  """
1204
1262
  Get the inner (wrapped) model.
1205
1263
 
@@ -1209,3 +1267,6 @@ class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
1209
1267
  """
1210
1268
  config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
1211
1269
  return LiltForSequenceClassification.from_pretrained(pretrained_model_name_or_path=path_weights, config=config)
1270
+
1271
+ def clear_model(self) -> None:
1272
+ self.model = None