deepdoctection 0.31-py3-none-any.whl → 0.33-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic.
- deepdoctection/__init__.py +16 -29
- deepdoctection/analyzer/dd.py +70 -59
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +41 -56
- deepdoctection/datapoint/box.py +9 -8
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +56 -44
- deepdoctection/datapoint/view.py +245 -150
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +35 -26
- deepdoctection/datasets/base.py +14 -12
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +24 -26
- deepdoctection/datasets/instances/doclaynet.py +51 -51
- deepdoctection/datasets/instances/fintabnet.py +46 -46
- deepdoctection/datasets/instances/funsd.py +25 -24
- deepdoctection/datasets/instances/iiitar13k.py +13 -10
- deepdoctection/datasets/instances/layouttest.py +4 -3
- deepdoctection/datasets/instances/publaynet.py +5 -5
- deepdoctection/datasets/instances/pubtables1m.py +24 -21
- deepdoctection/datasets/instances/pubtabnet.py +32 -30
- deepdoctection/datasets/instances/rvlcdip.py +30 -30
- deepdoctection/datasets/instances/xfund.py +26 -26
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +15 -13
- deepdoctection/eval/eval.py +41 -37
- deepdoctection/eval/tedsmetric.py +30 -23
- deepdoctection/eval/tp_eval_callback.py +16 -19
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +85 -113
- deepdoctection/extern/deskew.py +14 -11
- deepdoctection/extern/doctrocr.py +141 -130
- deepdoctection/extern/fastlang.py +27 -18
- deepdoctection/extern/hfdetr.py +71 -62
- deepdoctection/extern/hflayoutlm.py +504 -211
- deepdoctection/extern/hflm.py +230 -0
- deepdoctection/extern/model.py +488 -302
- deepdoctection/extern/pdftext.py +23 -19
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +29 -19
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +18 -18
- deepdoctection/extern/tp/tfutils.py +57 -9
- deepdoctection/extern/tp/tpcompat.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +45 -53
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/cats.py +27 -29
- deepdoctection/mapper/cocostruct.py +10 -10
- deepdoctection/mapper/d2struct.py +27 -26
- deepdoctection/mapper/hfstruct.py +13 -8
- deepdoctection/mapper/laylmstruct.py +178 -37
- deepdoctection/mapper/maputils.py +12 -11
- deepdoctection/mapper/match.py +2 -2
- deepdoctection/mapper/misc.py +11 -9
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +5 -5
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +5 -5
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +12 -14
- deepdoctection/pipe/base.py +52 -106
- deepdoctection/pipe/common.py +72 -59
- deepdoctection/pipe/concurrency.py +16 -11
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +20 -16
- deepdoctection/pipe/lm.py +75 -105
- deepdoctection/pipe/order.py +194 -89
- deepdoctection/pipe/refine.py +111 -124
- deepdoctection/pipe/segment.py +156 -161
- deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +48 -41
- deepdoctection/train/hf_detr_train.py +41 -30
- deepdoctection/train/hf_layoutlm_train.py +153 -135
- deepdoctection/train/tp_frcnn_train.py +32 -31
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +87 -125
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +22 -18
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +16 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +11 -11
- deepdoctection/utils/settings.py +185 -181
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +74 -72
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
- deepdoctection-0.33.dist-info/RECORD +146 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.31.dist-info/RECORD +0 -144
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
@@ -18,66 +18,138 @@
 """
 HF Layoutlm model for diverse downstream tasks.
 """
+from __future__ import annotations

+import os
 from abc import ABC
 from collections import defaultdict
-from copy import copy
 from pathlib import Path
-from typing import
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union

 import numpy as np
-
-from
-
-
-
-
-
-
-
-
-
-
-    get_type,
-    token_class_tag_to_token_class_with_tag,
-    token_class_with_tag_to_token_class_and_tag,
+from lazy_imports import try_import
+from typing_extensions import TypeAlias
+
+from ..utils.file_utils import get_pytorch_requirement, get_transformers_requirement
+from ..utils.settings import TypeOrStr
+from ..utils.types import JsonDict, PathLikeOrStr, Requirement
+from .base import (
+    LMSequenceClassifier,
+    LMTokenClassifier,
+    ModelCategories,
+    NerModelCategories,
+    SequenceClassResult,
+    TokenClassResult,
 )
-from .
-from .pt.ptutils import set_torch_auto_device
+from .pt.ptutils import get_torch_device

-
+with try_import() as pt_import_guard:
     import torch
     import torch.nn.functional as F
-    from torch import Tensor  # pylint: disable=W0611

-
+with try_import() as tr_import_guard:
     from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD  # type: ignore
     from transformers import (
         LayoutLMForSequenceClassification,
         LayoutLMForTokenClassification,
+        LayoutLMTokenizerFast,
         LayoutLMv2Config,
         LayoutLMv2ForSequenceClassification,
         LayoutLMv2ForTokenClassification,
         LayoutLMv3Config,
         LayoutLMv3ForSequenceClassification,
         LayoutLMv3ForTokenClassification,
+        LiltForSequenceClassification,
+        LiltForTokenClassification,
         PretrainedConfig,
+        RobertaTokenizerFast,
+        XLMRobertaTokenizerFast,
     )

+if TYPE_CHECKING:
+    LayoutTokenModels: TypeAlias = Union[
+        LayoutLMForTokenClassification,
+        LayoutLMv2ForTokenClassification,
+        LayoutLMv3ForTokenClassification,
+        LiltForTokenClassification,
+    ]
+
+    LayoutSequenceModels: TypeAlias = Union[
+        LayoutLMForSequenceClassification,
+        LayoutLMv2ForSequenceClassification,
+        LayoutLMv3ForSequenceClassification,
+        LiltForSequenceClassification,
+    ]
+
+    HfLayoutTokenModels: TypeAlias = Union[
+        LayoutLMForTokenClassification,
+        LayoutLMv2ForTokenClassification,
+        LayoutLMv3ForTokenClassification,
+        LiltForTokenClassification,
+    ]
+
+    HfLayoutSequenceModels: TypeAlias = Union[
+        LayoutLMForSequenceClassification,
+        LayoutLMv2ForSequenceClassification,
+        LayoutLMv3ForSequenceClassification,
+        LiltForSequenceClassification,
+    ]
+
+
+def get_tokenizer_from_model_class(model_class: str, use_xlm_tokenizer: bool) -> Any:
+    """
+    We do not use the tokenizer for a particular model that the transformer library provides. Thie mapping therefore
+    returns the tokenizer that should be used for a particular model.
+
+    :param model_class: The model as stated in the transformer library.
+    :param use_xlm_tokenizer: True if one uses the LayoutXLM. (The model cannot be distinguished from LayoutLMv2).
+    :return: Tokenizer instance to use.
+    """
+    return {
+        ("LayoutLMForTokenClassification", False): LayoutLMTokenizerFast.from_pretrained(
+            "microsoft/layoutlm-base-uncased"
+        ),
+        ("LayoutLMForSequenceClassification", False): LayoutLMTokenizerFast.from_pretrained(
+            "microsoft/layoutlm-base-uncased"
+        ),
+        ("LayoutLMv2ForTokenClassification", False): LayoutLMTokenizerFast.from_pretrained(
+            "microsoft/layoutlm-base-uncased"
+        ),
+        ("LayoutLMv2ForSequenceClassification", False): LayoutLMTokenizerFast.from_pretrained(
+            "microsoft/layoutlm-base-uncased"
+        ),
+        ("LayoutLMv2ForTokenClassification", True): XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base"),
+        ("LayoutLMv2ForSequenceClassification", True): XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base"),
+        ("LayoutLMv3ForSequenceClassification", False): RobertaTokenizerFast.from_pretrained(
+            "roberta-base", add_prefix_space=True
+        ),
+        ("LayoutLMv3ForTokenClassification", False): RobertaTokenizerFast.from_pretrained(
+            "roberta-base", add_prefix_space=True
+        ),
+        ("LiltForTokenClassification", True): XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base"),
+        ("LiltForTokenClassification", False): RobertaTokenizerFast.from_pretrained(
+            "roberta-base", add_prefix_space=True
+        ),
+        ("LiltForSequenceClassification", True): XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base"),
+        ("LiltForSequenceClassification", False): RobertaTokenizerFast.from_pretrained(
+            "roberta-base", add_prefix_space=True
+        ),
+        ("XLMRobertaForSequenceClassification", True): XLMRobertaTokenizerFast.from_pretrained(
+            "FacebookAI/xlm-roberta-base"
+        ),
+    }[(model_class, use_xlm_tokenizer)]
+

 def predict_token_classes(
-    uuids:
-    input_ids:
-    attention_mask:
-    token_type_ids:
-    boxes:
-    tokens:
-    model:
-
-
-    images: Optional["Tensor"] = None,
-) -> List[TokenClassResult]:
+    uuids: list[list[str]],
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    token_type_ids: torch.Tensor,
+    boxes: torch.Tensor,
+    tokens: list[list[str]],
+    model: LayoutTokenModels,
+    images: Optional[torch.Tensor] = None,
+) -> list[TokenClassResult]:
     """
     :param uuids: A list of uuids that correspond to a word that induces the resulting token
     :param input_ids: Token converted to ids to be taken from LayoutLMTokenizer
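The new module-level helper above hard-wires which fast tokenizer belongs to which model class instead of relying on whatever transformers would pick by itself. A minimal sketch of the lookup, assuming deepdoctection 0.33 is installed and the referenced tokenizer files can be fetched from the Hugging Face hub (note the mapping is a dict literal, so every listed tokenizer is instantiated on each call):

    from deepdoctection.extern.hflayoutlm import get_tokenizer_from_model_class

    # LiLT checkpoints built on a RoBERTa backbone resolve to a fast RoBERTa tokenizer ...
    tok = get_tokenizer_from_model_class("LiltForTokenClassification", use_xlm_tokenizer=False)
    print(tok.__class__.__name__)   # RobertaTokenizerFast

    # ... while XLM-based checkpoints (LayoutXLM, xlm-roberta LiLT) resolve to the XLM-R tokenizer
    tok = get_tokenizer_from_model_class("LiltForTokenClassification", use_xlm_tokenizer=True)
    print(tok.__class__.__name__)   # XLMRobertaTokenizerFast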
@@ -129,26 +201,23 @@ def predict_token_classes(


 def predict_sequence_classes(
-    input_ids:
-    attention_mask:
-    token_type_ids:
-    boxes:
-    model:
-
-        "LayoutLMv2ForSequenceClassification",
-        "LayoutLMv3ForSequenceClassification",
-    ],
-    images: Optional["Tensor"] = None,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    token_type_ids: torch.Tensor,
+    boxes: torch.Tensor,
+    model: LayoutSequenceModels,
+    images: Optional[torch.Tensor] = None,
 ) -> SequenceClassResult:
     """
     :param input_ids: Token converted to ids to be taken from LayoutLMTokenizer
     :param attention_mask: The associated attention masks from padded sequences taken from LayoutLMTokenizer
     :param token_type_ids: Torch tensor of token type ids taken from LayoutLMTokenizer
     :param boxes: Torch tensor of bounding boxes of type 'xyxy'
-    :param model: layoutlm model for
+    :param model: layoutlm model for sequence classification
     :param images: A list of torch image tensors or None
     :return: SequenceClassResult
     """
+
     if images is None:
         outputs = model(input_ids=input_ids, bbox=boxes, attention_mask=attention_mask, token_type_ids=token_type_ids)
     elif isinstance(model, LayoutLMv2ForSequenceClassification):
@@ -177,16 +246,14 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):
     Abstract base class for wrapping LayoutLM models for token classification into the deepdoctection framework.
     """

-    model: Union["LayoutLMForTokenClassification", "LayoutLMv2ForTokenClassification"]
-
     def __init__(
         self,
-        path_config_json:
-        path_weights:
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
         categories_semantics: Optional[Sequence[TypeOrStr]] = None,
         categories_bio: Optional[Sequence[TypeOrStr]] = None,
-        categories: Optional[Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        categories: Optional[Mapping[int, TypeOrStr]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
     ):
         """
         :param path_config_json: path to .json config file
@@ -198,6 +265,8 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):
                                consistent with detectors use only values>0. Conversion will be done internally.
         :param categories: If you have a pre-trained model you can pass a complete dict of NER categories
         :param device: The device (cpu,"cuda"), where to place the model.
+        :param use_xlm_tokenizer: True if one uses the LayoutXLM or a lilt model built with a xlm language model, e.g.
+                                  info-xlm or roberta-xlm. (LayoutXLM cannot be distinguished from LayoutLMv2).
         """

         if categories is None:
@@ -206,45 +275,21 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):
             if categories_bio is None:
                 raise ValueError("If categories is None then categories_bio cannot be None")

-        self.path_config = path_config_json
-        self.path_weights = path_weights
-        self.
-
+        self.path_config = Path(path_config_json)
+        self.path_weights = Path(path_weights)
+        self.categories = NerModelCategories(
+            init_categories=categories, categories_semantics=categories_semantics, categories_bio=categories_bio
         )
-        self.
-        if categories:
-            self.categories = copy(categories)  # type: ignore
-        else:
-            self.categories = self._categories_orig_to_categories(
-                self.categories_semantics, self.categories_bio  # type: ignore
-            )
-        if device is not None:
-            self.device = device
-        else:
-            self.device = set_torch_auto_device()
-        self.model.to(self.device)
+        self.device = get_torch_device(device)

     @classmethod
-    def get_requirements(cls) ->
+    def get_requirements(cls) -> list[Requirement]:
         return [get_pytorch_requirement(), get_transformers_requirement()]

-
-    def _categories_orig_to_categories(
-        categories_semantics: List[TokenClasses], categories_bio: List[BioTag]
-    ) -> Dict[str, ObjectTypes]:
-        categories_list = sorted(
-            {
-                token_class_tag_to_token_class_with_tag(token, tag)
-                for token in categories_semantics
-                for tag in categories_bio
-            }
-        )
-        return {str(k): v for k, v in enumerate(categories_list, 1)}
-
-    def _map_category_names(self, token_results: List[TokenClassResult]) -> List[TokenClassResult]:
+    def _map_category_names(self, token_results: list[TokenClassResult]) -> list[TokenClassResult]:
         for result in token_results:
-            result.class_name = self.categories[
-            output =
+            result.class_name = self.categories.categories[result.class_id + 1]
+            output = self.categories.disentangle_token_class_and_tag(result.class_name)
             if output is not None:
                 token_class, tag = output
                 result.semantic_name = token_class
@@ -256,9 +301,7 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):

     def _validate_encodings(
         self, **encodings: Any
-    ) ->
-        List[List[str]], List[str], "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", List[List[str]]
-    ]:
+    ) -> tuple[list[list[str]], list[str], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[list[str]]]:
         image_ids = encodings.get("image_ids", [])
         ann_ids = encodings.get("ann_ids")
         input_ids = encodings.get("input_ids")
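For orientation, the removed `_categories_orig_to_categories` built the NER label space as the cartesian product of semantic classes and BIO tags, keyed by stringified indices. A rough, self-contained sketch of that old behaviour, with plain strings standing in for deepdoctection's ObjectTypes and an assumed "tag-class" formatting:

    def categories_orig_to_categories(categories_semantics, categories_bio):
        # e.g. semantics = ["question", "answer"], bio = ["B", "I"]
        categories_list = sorted({f"{tag}-{token}" for token in categories_semantics for tag in categories_bio})
        return {str(k): v for k, v in enumerate(categories_list, 1)}

    print(categories_orig_to_categories(["question", "answer"], ["B", "I"]))
    # {'1': 'B-answer', '2': 'B-question', '3': 'I-answer', '4': 'I-question'}

In 0.33 this bookkeeping moves into `NerModelCategories`, which also exposes `disentangle_token_class_and_tag` for the reverse split used in `_map_category_names`.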
@@ -291,21 +334,41 @@ class HFLayoutLmTokenClassifierBase(LMTokenClassifier, ABC):

         return ann_ids, image_ids, input_ids, attention_mask, token_type_ids, boxes, tokens

-    def clone(self) ->
+    def clone(self) -> HFLayoutLmTokenClassifierBase:
         return self.__class__(
             self.path_config,
             self.path_weights,
-            self.categories_semantics,
-            self.categories_bio,
-            self.categories,
+            self.categories.categories_semantics,
+            self.categories.categories_bio,
+            self.categories.get_categories(),
             self.device,
         )

     @staticmethod
-    def get_name(path_weights:
+    def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
         """Returns the name of the model"""
         return f"Transformers_{architecture}_" + "_".join(Path(path_weights).parts[-2:])

+    @staticmethod
+    def get_tokenizer_class_name(model_class_name: str, use_xlm_tokenizer: bool) -> str:
+        """A refinement for adding the tokenizer class name to the model configs.
+
+        :param model_class_name: The model name, e.g. model.__class__.__name__
+        :param use_xlm_tokenizer: Whether to use a XLM tokenizer.
+        """
+        tokenizer = get_tokenizer_from_model_class(model_class_name, use_xlm_tokenizer)
+        return tokenizer.__class__.__name__
+
+    @staticmethod
+    def image_to_raw_features_mapping() -> str:
+        """Returns the mapping function to convert images into raw features."""
+        return "image_to_raw_layoutlm_features"
+
+    @staticmethod
+    def image_to_features_mapping() -> str:
+        """Returns the mapping function to convert images into features."""
+        return "image_to_layoutlm_features"
+

 class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
     """
@@ -343,12 +406,13 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):

     def __init__(
         self,
-        path_config_json:
-        path_weights:
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
         categories_semantics: Optional[Sequence[TypeOrStr]] = None,
         categories_bio: Optional[Sequence[TypeOrStr]] = None,
-        categories: Optional[Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        categories: Optional[Mapping[int, TypeOrStr]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
     ):
         """
         :param path_config_json: path to .json config file
@@ -360,13 +424,19 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
                                consistent with detectors use only values>0. Conversion will be done internally.
         :param categories: If you have a pre-trained model you can pass a complete dict of NER categories
         :param device: The device (cpu,"cuda"), where to place the model.
+        :param use_xlm_tokenizer: Do not change this value unless you pre-trained a LayoutLM model with a different
+                                  Tokenizer.
         """
+        super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
         self.name = self.get_name(path_weights, "LayoutLM")
         self.model_id = self.get_model_id()
         self.model = self.get_wrapped_model(path_config_json, path_weights)
-
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
+        )

-    def predict(self, **encodings: Union[
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
         """
         Launch inference on LayoutLm for token classification. Pass the following arguments

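Taken together, the constructor changes mean a 0.33 token classifier now calls into the base class, places the model on the resolved device itself and records a tokenizer hint in the model config. A minimal sketch, with hypothetical local paths and an illustrative pre-trained label set passed via `categories` (keys are now plain ints, not strings):

    from deepdoctection.extern.hflayoutlm import HFLayoutLmTokenClassifier

    layoutlm = HFLayoutLmTokenClassifier(
        path_config_json="path/to/config.json",   # hypothetical paths
        path_weights="path/to/model.bin",
        categories={1: "B-answer", 2: "B-question", 3: "I-answer", 4: "I-question", 5: "O"},
        device="cpu",                              # a torch.device instance is accepted as well
    )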
@@ -392,7 +462,9 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
         return self._map_category_names(results)

     @staticmethod
-    def get_wrapped_model(
+    def get_wrapped_model(
+        path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
+    ) -> LayoutLMForTokenClassification:
         """
         Get the inner (wrapped) model.

@@ -400,8 +472,13 @@ class HFLayoutLmTokenClassifier(HFLayoutLmTokenClassifierBase):
         :param path_weights: path to model artifact
         :return: 'nn.Module'
         """
-        config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
-        return LayoutLMForTokenClassification.from_pretrained(
+        config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
+        return LayoutLMForTokenClassification.from_pretrained(
+            pretrained_model_name_or_path=os.fspath(path_weights), config=config
+        )
+
+    def clear_model(self) -> None:
+        self.model = None


 class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
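Since `path_config_json` and `path_weights` are now stored as `pathlib.Path` objects, the wrappers normalize them back to plain strings before handing them to `from_pretrained`. The conversion is just the standard library's `os.fspath`:

    import os
    from pathlib import Path

    print(os.fspath(Path("weights") / "layoutlm" / "model.bin"))  # a plain str, e.g. "weights/layoutlm/model.bin" on POSIX
    print(os.fspath("already/a/string"))                          # strings pass through unchanged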
@@ -442,12 +519,13 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):

     def __init__(
         self,
-        path_config_json:
-        path_weights:
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
         categories_semantics: Optional[Sequence[TypeOrStr]] = None,
         categories_bio: Optional[Sequence[TypeOrStr]] = None,
-        categories: Optional[Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        categories: Optional[Mapping[int, TypeOrStr]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
     ):
         """
         :param path_config_json: path to .json config file
@@ -459,13 +537,19 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
                                consistent with detectors use only values>0. Conversion will be done internally.
         :param categories: If you have a pre-trained model you can pass a complete dict of NER categories
         :param device: The device (cpu,"cuda"), where to place the model.
+        :param use_xlm_tokenizer: Set to True if you use a LayoutXLM model. If you use a LayoutLMv2 model keep the
+                                  default value.
         """
+        super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
         self.name = self.get_name(path_weights, "LayoutLMv2")
         self.model_id = self.get_model_id()
         self.model = self.get_wrapped_model(path_config_json, path_weights)
-
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
+        )

-    def predict(self, **encodings: Union[
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
         """
         Launch inference on LayoutLm for token classification. Pass the following arguments

@@ -496,7 +580,7 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
         return self._map_category_names(results)

     @staticmethod
-    def
+    def default_kwargs_for_image_to_features_mapping() -> JsonDict:
         """
         Add some default arguments that might be necessary when preparing a sample. Overwrite this method
         for some custom setting.
@@ -504,7 +588,9 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
         return {"image_width": 224, "image_height": 224}

     @staticmethod
-    def get_wrapped_model(
+    def get_wrapped_model(
+        path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
+    ) -> LayoutLMv2ForTokenClassification:
         """
         Get the inner (wrapped) model.

@@ -512,11 +598,14 @@ class HFLayoutLmv2TokenClassifier(HFLayoutLmTokenClassifierBase):
         :param path_weights: path to model artifact
         :return: 'nn.Module'
         """
-        config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
+        config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
         return LayoutLMv2ForTokenClassification.from_pretrained(
-            pretrained_model_name_or_path=path_weights, config=config
+            pretrained_model_name_or_path=os.fspath(path_weights), config=config
         )

+    def clear_model(self) -> None:
+        self.model = None
+

 class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
     """
@@ -556,12 +645,13 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):

     def __init__(
         self,
-        path_config_json:
-        path_weights:
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
         categories_semantics: Optional[Sequence[TypeOrStr]] = None,
         categories_bio: Optional[Sequence[TypeOrStr]] = None,
-        categories: Optional[Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        categories: Optional[Mapping[int, TypeOrStr]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
     ):
         """
         :param path_config_json: path to .json config file
@@ -573,13 +663,19 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
                                consistent with detectors use only values>0. Conversion will be done internally.
         :param categories: If you have a pre-trained model you can pass a complete dict of NER categories
         :param device: The device (cpu,"cuda"), where to place the model.
+        :param use_xlm_tokenizer: Do not change this value unless you pre-trained a LayoutLMv3 model with a different
+                                  tokenizer.
         """
+        super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
         self.name = self.get_name(path_weights, "LayoutLMv3")
         self.model_id = self.get_model_id()
         self.model = self.get_wrapped_model(path_config_json, path_weights)
-
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
+        )

-    def predict(self, **encodings: Union[
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
         """
         Launch inference on LayoutLm for token classification. Pass the following arguments

@@ -606,7 +702,7 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
         return self._map_category_names(results)

     @staticmethod
-    def
+    def default_kwargs_for_image_to_features_mapping() -> JsonDict:
         """
         Add some default arguments that might be necessary when preparing a sample. Overwrite this method
         for some custom setting.
@@ -620,7 +716,9 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
         }

     @staticmethod
-    def get_wrapped_model(
+    def get_wrapped_model(
+        path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
+    ) -> LayoutLMv3ForTokenClassification:
         """
         Get the inner (wrapped) model.

@@ -628,81 +726,43 @@ class HFLayoutLmv3TokenClassifier(HFLayoutLmTokenClassifierBase):
         :param path_weights: path to model artifact
         :return: 'nn.Module'
         """
-        config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
+        config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
         return LayoutLMv3ForTokenClassification.from_pretrained(
-            pretrained_model_name_or_path=path_weights, config=config
+            pretrained_model_name_or_path=os.fspath(path_weights), config=config
         )

+    def clear_model(self) -> None:
+        self.model = None
+

 class HFLayoutLmSequenceClassifierBase(LMSequenceClassifier, ABC):
     """
     Abstract base class for wrapping LayoutLM models for sequence classification into the deepdoctection framework.
     """

-    model: Union["LayoutLMForSequenceClassification", "LayoutLMv2ForSequenceClassification"]
-
     def __init__(
         self,
-        path_config_json:
-        path_weights:
-        categories: Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
     ):
-        self.path_config = path_config_json
-        self.path_weights = path_weights
-        self.categories =
-
-        if device is not None:
-            self.device = device
-        else:
-            self.device = set_torch_auto_device()
-        self.model.to(self.device)
+        self.path_config = Path(path_config_json)
+        self.path_weights = Path(path_weights)
+        self.categories = ModelCategories(init_categories=categories)

-
-        input_ids = encodings.get("input_ids")
-        attention_mask = encodings.get("attention_mask")
-        token_type_ids = encodings.get("token_type_ids")
-        boxes = encodings.get("bbox")
-
-        if isinstance(input_ids, torch.Tensor):
-            input_ids = input_ids.to(self.device)
-        else:
-            raise ValueError(f"input_ids must be list but is {type(input_ids)}")
-        if isinstance(attention_mask, torch.Tensor):
-            attention_mask = attention_mask.to(self.device)
-        else:
-            raise ValueError(f"attention_mask must be list but is {type(attention_mask)}")
-        if isinstance(token_type_ids, torch.Tensor):
-            token_type_ids = token_type_ids.to(self.device)
-        else:
-            raise ValueError(f"token_type_ids must be list but is {type(token_type_ids)}")
-        if isinstance(boxes, torch.Tensor):
-            boxes = boxes.to(self.device)
-        else:
-            raise ValueError(f"boxes must be list but is {type(boxes)}")
-
-        result = predict_sequence_classes(
-            input_ids,
-            attention_mask,
-            token_type_ids,
-            boxes,
-            self.model,
-        )
-
-        result.class_id += 1
-        result.class_name = self.categories[str(result.class_id)]
-        return result
+        self.device = get_torch_device(device)

     @classmethod
-    def get_requirements(cls) ->
+    def get_requirements(cls) -> list[Requirement]:
         return [get_pytorch_requirement(), get_transformers_requirement()]

-    def clone(self) ->
-        return self.__class__(self.path_config, self.path_weights, self.categories, self.device)
+    def clone(self) -> HFLayoutLmSequenceClassifierBase:
+        return self.__class__(self.path_config, self.path_weights, self.categories.get_categories(), self.device)

     def _validate_encodings(
-        self, **encodings: Union[
-    ) ->
+        self, **encodings: Union[list[list[str]], torch.Tensor]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         input_ids = encodings.get("input_ids")
         attention_mask = encodings.get("attention_mask")
         token_type_ids = encodings.get("token_type_ids")
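The sequence-classifier base no longer keeps a raw dict keyed by stringified ids; it wraps the mapping in `ModelCategories` and the concrete classifiers index it with integer class ids (shifted by one, since model outputs start at 0). A rough sketch of the lookup the `predict` methods now perform, using a hypothetical three-class document classifier:

    categories = {1: "handwritten", 2: "presentation", 3: "resume"}   # as passed to the classifier
    raw_class_id = 0                                                  # argmax over the model logits
    class_id = raw_class_id + 1
    class_name = categories[class_id]                                 # "handwritten"

    # in 0.31 the same lookup went through string keys: categories[str(class_id)]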
@@ -732,10 +792,30 @@ class HFLayoutLmSequenceClassifierBase(LMSequenceClassifier, ABC):
         return input_ids, attention_mask, token_type_ids, boxes

     @staticmethod
-    def get_name(path_weights:
+    def get_name(path_weights: PathLikeOrStr, architecture: str) -> str:
         """Returns the name of the model"""
         return f"Transformers_{architecture}_" + "_".join(Path(path_weights).parts[-2:])

+    @staticmethod
+    def get_tokenizer_class_name(model_class_name: str, use_xlm_tokenizer: bool) -> str:
+        """A refinement for adding the tokenizer class name to the model configs.
+
+        :param model_class_name: The model name, e.g. model.__class__.__name__
+        :param use_xlm_tokenizer: Whether to use a XLM tokenizer.
+        """
+        tokenizer = get_tokenizer_from_model_class(model_class_name, use_xlm_tokenizer)
+        return tokenizer.__class__.__name__
+
+    @staticmethod
+    def image_to_raw_features_mapping() -> str:
+        """Returns the mapping function to convert images into raw features."""
+        return "image_to_raw_layoutlm_features"
+
+    @staticmethod
+    def image_to_features_mapping() -> str:
+        """Returns the mapping function to convert images into features."""
+        return "image_to_layoutlm_features"
+

 class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
     """
@@ -770,20 +850,22 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):

     def __init__(
         self,
-        path_config_json:
-        path_weights:
-        categories: Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
     ):
+        super().__init__(path_config_json, path_weights, categories, device)
         self.name = self.get_name(path_weights, "LayoutLM")
         self.model_id = self.get_model_id()
-
-        self.model
-
+        self.model = self.get_wrapped_model(path_config_json, path_weights)
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
         )
-        super().__init__(path_config_json, path_weights, categories, device)

-    def predict(self, **encodings: Union[
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
         input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)

         result = predict_sequence_classes(
@@ -795,11 +877,13 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
         )

         result.class_id += 1
-        result.class_name = self.categories[
+        result.class_name = self.categories.categories[result.class_id]
         return result

     @staticmethod
-    def get_wrapped_model(
+    def get_wrapped_model(
+        path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
+    ) -> LayoutLMForSequenceClassification:
         """
         Get the inner (wrapped) model.

@@ -807,11 +891,14 @@ class HFLayoutLmSequenceClassifier(HFLayoutLmSequenceClassifierBase):
         :param path_weights: path to model artifact
         :return: 'nn.Module'
         """
-        config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
+        config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
         return LayoutLMForSequenceClassification.from_pretrained(
-            pretrained_model_name_or_path=path_weights, config=config
+            pretrained_model_name_or_path=os.fspath(path_weights), config=config
         )

+    def clear_model(self) -> None:
+        self.model = None
+

 class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
     """
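Every concrete wrapper now stamps the resolved tokenizer class into the Hugging Face config right after loading the weights, so downstream components can recover the matching tokenizer from the config alone. A minimal sketch for a sequence classifier, assuming hypothetical local paths to a LayoutLM checkpoint:

    from deepdoctection.extern.hflayoutlm import HFLayoutLmSequenceClassifier

    classifier = HFLayoutLmSequenceClassifier(
        "path/to/config.json", "path/to/model.bin",            # hypothetical paths
        categories={1: "handwritten", 2: "presentation", 3: "resume"},
    )
    print(classifier.model.config.tokenizer_class)             # "LayoutLMTokenizerFast" per the mapping above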
@@ -846,17 +933,22 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):

     def __init__(
         self,
-        path_config_json:
-        path_weights:
-        categories: Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
     ):
+        super().__init__(path_config_json, path_weights, categories, device)
         self.name = self.get_name(path_weights, "LayoutLMv2")
         self.model_id = self.get_model_id()
         self.model = self.get_wrapped_model(path_config_json, path_weights)
-
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
+        )

-    def predict(self, **encodings: Union[
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
         input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)
         images = encodings.get("image")
         if isinstance(images, torch.Tensor):
@@ -867,11 +959,11 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         result = predict_sequence_classes(input_ids, attention_mask, token_type_ids, boxes, self.model, images)

         result.class_id += 1
-        result.class_name = self.categories[
+        result.class_name = self.categories.categories[result.class_id]
         return result

     @staticmethod
-    def
+    def default_kwargs_for_image_to_features_mapping() -> JsonDict:
         """
         Add some default arguments that might be necessary when preparing a sample. Overwrite this method
         for some custom setting.
@@ -879,7 +971,9 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         return {"image_width": 224, "image_height": 224}

     @staticmethod
-    def get_wrapped_model(
+    def get_wrapped_model(
+        path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
+    ) -> LayoutLMv2ForSequenceClassification:
         """
         Get the inner (wrapped) model.

@@ -887,11 +981,14 @@ class HFLayoutLmv2SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         :param path_weights: path to model artifact
         :return: 'nn.Module'
         """
-        config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
+        config = LayoutLMv2Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
         return LayoutLMv2ForSequenceClassification.from_pretrained(
-            pretrained_model_name_or_path=path_weights, config=config
+            pretrained_model_name_or_path=os.fspath(path_weights), config=config
         )

+    def clear_model(self) -> None:
+        self.model = None
+

 class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
     """
@@ -926,17 +1023,22 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):

     def __init__(
         self,
-        path_config_json:
-        path_weights:
-        categories: Mapping[
-        device: Optional[Literal["cpu", "cuda"]] = None,
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
     ):
+        super().__init__(path_config_json, path_weights, categories, device)
         self.name = self.get_name(path_weights, "LayoutLMv3")
         self.model_id = self.get_model_id()
         self.model = self.get_wrapped_model(path_config_json, path_weights)
-
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
+        )

-    def predict(self, **encodings: Union[
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
         input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)
         images = encodings.get("pixel_values")
         if isinstance(images, torch.Tensor):
@@ -947,11 +1049,11 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         result = predict_sequence_classes(input_ids, attention_mask, token_type_ids, boxes, self.model, images)

         result.class_id += 1
-        result.class_name = self.categories[
+        result.class_name = self.categories.categories[result.class_id]
         return result

     @staticmethod
-    def
+    def default_kwargs_for_image_to_features_mapping() -> JsonDict:
         """
         Add some default arguments that might be necessary when preparing a sample. Overwrite this method
         for some custom setting.
@@ -965,7 +1067,9 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         }

     @staticmethod
-    def get_wrapped_model(
+    def get_wrapped_model(
+        path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr
+    ) -> LayoutLMv3ForSequenceClassification:
         """
         Get the inner (wrapped) model.

@@ -973,7 +1077,196 @@ class HFLayoutLmv3SequenceClassifier(HFLayoutLmSequenceClassifierBase):
         :param path_weights: path to model artifact
         :return: 'nn.Module'
         """
-        config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=path_config_json)
+        config = LayoutLMv3Config.from_pretrained(pretrained_model_name_or_path=os.fspath(path_config_json))
         return LayoutLMv3ForSequenceClassification.from_pretrained(
-            pretrained_model_name_or_path=path_weights, config=config
+            pretrained_model_name_or_path=os.fspath(path_weights), config=config
+        )
+
+    def clear_model(self) -> None:
+        self.model = None
+
+
+class HFLiltTokenClassifier(HFLayoutLmTokenClassifierBase):
+    """
+    A wrapper class for `transformers.LiltForTokenClassification` to use within a pipeline component.
+    Check <https://huggingface.co/docs/transformers/model_doc/lilt> for documentation of the model itself.
+    Note that this model is equipped with a head that is only useful when classifying tokens. For sequence
+    classification and other things please use another model of the family.
+
+    **Example**
+
+        # setting up compulsory ocr service
+        tesseract_config_path = ModelCatalog.get_full_path_configs("/dd/conf_tesseract.yaml")
+        tess = TesseractOcrDetector(tesseract_config_path)
+        ocr_service = TextExtractionService(tess)
+
+        # hf tokenizer and token classifier
+        tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
+        lilt = HFLiltTokenClassifier("path/to/config.json","path/to/model.bin",
+                                     categories= ['B-answer', 'B-header', 'B-question', 'E-answer',
+                                                  'E-header', 'E-question', 'I-answer', 'I-header',
+                                                  'I-question', 'O', 'S-answer', 'S-header',
+                                                  'S-question'])
+
+        # token classification service
+        lilt_service = LMTokenClassifierService(tokenizer,lilt)
+
+        pipe = DoctectionPipe(pipeline_component_list=[ocr_service,lilt_service])
+
+        path = "path/to/some/form"
+        df = pipe.analyze(path=path)
+
+        for dp in df:
+            ...
+    """
+
+    def __init__(
+        self,
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories_semantics: Optional[Sequence[TypeOrStr]] = None,
+        categories_bio: Optional[Sequence[TypeOrStr]] = None,
+        categories: Optional[Mapping[int, TypeOrStr]] = None,
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
+    ):
+        """
+        :param path_config_json: path to .json config file
+        :param path_weights: path to model artifact
+        :param categories_semantics: A dict with key (indices) and values (category names) for NER semantics, i.e. the
+                                     entities self. To be consistent with detectors use only values >0. Conversion will
+                                     be done internally.
+        :param categories_bio: A dict with key (indices) and values (category names) for NER tags (i.e. BIO). To be
+                               consistent with detectors use only values>0. Conversion will be done internally.
+        :param categories: If you have a pre-trained model you can pass a complete dict of NER categories
+        :param device: The device (cpu,"cuda"), where to place the model.
+        """
+
+        super().__init__(path_config_json, path_weights, categories_semantics, categories_bio, categories, device)
+        self.name = self.get_name(path_weights, "LiLT")
+        self.model_id = self.get_model_id()
+        self.model = self.get_wrapped_model(path_config_json, path_weights)
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
         )
+
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> list[TokenClassResult]:
+        """
+        Launch inference on LayoutLm for token classification. Pass the following arguments
+
+        `input_ids:` Token converted to ids to be taken from `LayoutLMTokenizer`
+
+        `attention_mask:` The associated attention masks from padded sequences taken from `LayoutLMTokenizer`
+
+        `token_type_ids:` Torch tensor of token type ids taken from `LayoutLMTokenizer`
+
+        `boxes:` Torch tensor of bounding boxes of type 'xyxy'
+
+        `tokens:` List of original tokens taken from `LayoutLMTokenizer`
+
+        :return: A list of TokenClassResults
+        """
+
+        ann_ids, _, input_ids, attention_mask, token_type_ids, boxes, tokens = self._validate_encodings(**encodings)
+
+        results = predict_token_classes(
+            ann_ids, input_ids, attention_mask, token_type_ids, boxes, tokens, self.model, None
+        )
+
+        return self._map_category_names(results)
+
+    @staticmethod
+    def get_wrapped_model(path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr) -> LiltForTokenClassification:
+        """
+        Get the inner (wrapped) model.
+
+        :param path_config_json: path to .json config file
+        :param path_weights: path to model artifact
+        :return: 'nn.Module'
+        """
+        config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
+        return LiltForTokenClassification.from_pretrained(pretrained_model_name_or_path=path_weights, config=config)
+
+    def clear_model(self) -> None:
+        self.model = None
+
+
+class HFLiltSequenceClassifier(HFLayoutLmSequenceClassifierBase):
+    """
+    A wrapper class for `transformers.LiLTForSequenceClassification` to use within a pipeline component.
+    Check <https://huggingface.co/docs/transformers/model_doc/lilt> for documentation of the model itself.
+    Note that this model is equipped with a head that is only useful for classifying the input sequence. For token
+    classification and other things please use another model of the family.
+
+    **Example**
+
+        # setting up compulsory ocr service
+        tesseract_config_path = ModelCatalog.get_full_path_configs("/dd/conf_tesseract.yaml")
+        tess = TesseractOcrDetector(tesseract_config_path)
+        ocr_service = TextExtractionService(tess)
+
+        # hf tokenizer and sequence classifier
+        tokenizer = LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased")
+        lilt = HFLiltSequenceClassifier("path/to/config.json",
+                                        "path/to/model.bin",
+                                        categories=["handwritten", "presentation", "resume"])
+
+        # sequence classification service
+        lilt_service = LMSequenceClassifierService(tokenizer,lilt)
+
+        pipe = DoctectionPipe(pipeline_component_list=[ocr_service,lilt_service])
+
+        path = "path/to/some/form"
+        df = pipe.analyze(path=path)
+
+        for dp in df:
+            ...
+    """
+
+    def __init__(
+        self,
+        path_config_json: PathLikeOrStr,
+        path_weights: PathLikeOrStr,
+        categories: Mapping[int, TypeOrStr],
+        device: Optional[Union[Literal["cpu", "cuda"], torch.device]] = None,
+        use_xlm_tokenizer: bool = False,
+    ):
+        super().__init__(path_config_json, path_weights, categories, device)
+        self.name = self.get_name(path_weights, "LiLT")
+        self.model_id = self.get_model_id()
+        self.model = self.get_wrapped_model(path_config_json, path_weights)
+        self.model.to(self.device)
+        self.model.config.tokenizer_class = self.get_tokenizer_class_name(
+            self.model.__class__.__name__, use_xlm_tokenizer
+        )
+
+    def predict(self, **encodings: Union[list[list[str]], torch.Tensor]) -> SequenceClassResult:
+        input_ids, attention_mask, token_type_ids, boxes = self._validate_encodings(**encodings)
+
+        result = predict_sequence_classes(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            boxes,
+            self.model,
+        )
+
+        result.class_id += 1
+        result.class_name = self.categories.categories[result.class_id]
+        return result
+
+    @staticmethod
+    def get_wrapped_model(path_config_json: PathLikeOrStr, path_weights: PathLikeOrStr) -> Any:
+        """
+        Get the inner (wrapped) model.
+
+        :param path_config_json: path to .json config file
+        :param path_weights: path to model artifact
+        :return: 'nn.Module'
+        """
+        config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=path_config_json)
+        return LiltForSequenceClassification.from_pretrained(pretrained_model_name_or_path=path_weights, config=config)
+
+    def clear_model(self) -> None:
+        self.model = None