deepdoctection 0.31__py3-none-any.whl → 0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of deepdoctection might be problematic.
- deepdoctection/__init__.py +16 -29
- deepdoctection/analyzer/dd.py +70 -59
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +41 -56
- deepdoctection/datapoint/box.py +9 -8
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +56 -44
- deepdoctection/datapoint/view.py +245 -150
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +35 -26
- deepdoctection/datasets/base.py +14 -12
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +24 -26
- deepdoctection/datasets/instances/doclaynet.py +51 -51
- deepdoctection/datasets/instances/fintabnet.py +46 -46
- deepdoctection/datasets/instances/funsd.py +25 -24
- deepdoctection/datasets/instances/iiitar13k.py +13 -10
- deepdoctection/datasets/instances/layouttest.py +4 -3
- deepdoctection/datasets/instances/publaynet.py +5 -5
- deepdoctection/datasets/instances/pubtables1m.py +24 -21
- deepdoctection/datasets/instances/pubtabnet.py +32 -30
- deepdoctection/datasets/instances/rvlcdip.py +30 -30
- deepdoctection/datasets/instances/xfund.py +26 -26
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +15 -13
- deepdoctection/eval/eval.py +41 -37
- deepdoctection/eval/tedsmetric.py +30 -23
- deepdoctection/eval/tp_eval_callback.py +16 -19
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +85 -113
- deepdoctection/extern/deskew.py +14 -11
- deepdoctection/extern/doctrocr.py +141 -130
- deepdoctection/extern/fastlang.py +27 -18
- deepdoctection/extern/hfdetr.py +71 -62
- deepdoctection/extern/hflayoutlm.py +504 -211
- deepdoctection/extern/hflm.py +230 -0
- deepdoctection/extern/model.py +488 -302
- deepdoctection/extern/pdftext.py +23 -19
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +29 -19
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +18 -18
- deepdoctection/extern/tp/tfutils.py +57 -9
- deepdoctection/extern/tp/tpcompat.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +45 -53
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/cats.py +27 -29
- deepdoctection/mapper/cocostruct.py +10 -10
- deepdoctection/mapper/d2struct.py +27 -26
- deepdoctection/mapper/hfstruct.py +13 -8
- deepdoctection/mapper/laylmstruct.py +178 -37
- deepdoctection/mapper/maputils.py +12 -11
- deepdoctection/mapper/match.py +2 -2
- deepdoctection/mapper/misc.py +11 -9
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +5 -5
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +5 -5
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +12 -14
- deepdoctection/pipe/base.py +52 -106
- deepdoctection/pipe/common.py +72 -59
- deepdoctection/pipe/concurrency.py +16 -11
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +20 -16
- deepdoctection/pipe/lm.py +75 -105
- deepdoctection/pipe/order.py +194 -89
- deepdoctection/pipe/refine.py +111 -124
- deepdoctection/pipe/segment.py +156 -161
- deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +48 -41
- deepdoctection/train/hf_detr_train.py +41 -30
- deepdoctection/train/hf_layoutlm_train.py +153 -135
- deepdoctection/train/tp_frcnn_train.py +32 -31
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +87 -125
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +22 -18
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +16 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +11 -11
- deepdoctection/utils/settings.py +185 -181
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +74 -72
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
- deepdoctection-0.33.dist-info/RECORD +146 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.31.dist-info/RECORD +0 -144
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
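The headline refactor of this release replaces deepdoctection/utils/detection_types.py with deepdoctection/utils/types.py (visible in the file list above and in the import hunks below). A minimal sketch of how downstream imports migrate, assuming the alias names visible in this diff (PixelValues replacing ImageType, plus JsonDict, PathLikeOrStr, QueueType):

    # 0.31 (module removed in 0.33)
    # from deepdoctection.utils.detection_types import ImageType, JsonDict

    # 0.33 (new module; PixelValues is the successor of ImageType)
    from deepdoctection.utils.types import JsonDict, PathLikeOrStr, PixelValues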
deepdoctection/train/hf_layoutlm_train.py
CHANGED
@@ -18,32 +18,15 @@
 """
 Module for training Huggingface implementation of LayoutLm
 """
+from __future__ import annotations
 
 import copy
 import json
 import os
 import pprint
-from typing import Any,
-
-from
-from torch.utils.data import Dataset
-from transformers import (
-    IntervalStrategy,
-    LayoutLMForSequenceClassification,
-    LayoutLMForTokenClassification,
-    LayoutLMTokenizerFast,
-    LayoutLMv2Config,
-    LayoutLMv2ForSequenceClassification,
-    LayoutLMv2ForTokenClassification,
-    LayoutLMv3Config,
-    LayoutLMv3ForSequenceClassification,
-    LayoutLMv3ForTokenClassification,
-    PretrainedConfig,
-    PreTrainedModel,
-    RobertaTokenizerFast,
-    XLMRobertaTokenizerFast,
-)
-from transformers.trainer import Trainer, TrainingArguments
+from typing import Any, Optional, Sequence, Type, Union
+
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
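The newly added `from __future__ import annotations` (PEP 563) makes every annotation a lazily evaluated string, which is what lets the later hunks annotate with built-in generics such as `dict[str, Any]` and `list[str]` while still running on Python versions before 3.9. A minimal illustration (the function name is made up):

    from __future__ import annotations

    # Without the future import, this definition raises TypeError on Python 3.8,
    # because dict[str, float] would be evaluated at definition time.
    def eval_metrics() -> dict[str, float]:
        return {"f1": 0.0}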
@@ -57,79 +40,109 @@ from ..extern.hflayoutlm import (
     HFLayoutLmv2TokenClassifier,
     HFLayoutLmv3SequenceClassifier,
     HFLayoutLmv3TokenClassifier,
+    HFLiltSequenceClassifier,
+    HFLiltTokenClassifier,
+    get_tokenizer_from_model_class,
 )
-from ..
-from ..
-from ..
+from ..extern.hflm import HFLmSequenceClassifier
+from ..extern.pt.ptutils import get_torch_device
+from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features, image_to_raw_lm_features
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
-from ..utils.env_info import get_device
 from ..utils.error import DependencyError
 from ..utils.file_utils import wandb_available
 from ..utils.logger import LoggingRecord, logger
-from ..utils.settings import DatasetType, LayoutType,
+from ..utils.settings import DatasetType, LayoutType, WordType
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict
 
-
-import
-
-_ARCHITECTURES_TO_MODEL_CLASS = {
-    "LayoutLMForTokenClassification": (LayoutLMForTokenClassification, HFLayoutLmTokenClassifier, PretrainedConfig),
-    "LayoutLMForSequenceClassification": (
-        LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    "LayoutLMv2ForTokenClassification": (
-        LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    "LayoutLMv2ForSequenceClassification": (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
-        LayoutLMv2Config,
-    ),
-}
-
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
 
-
-
+with try_import() as tr_import_guard:
+    from transformers import (
+        IntervalStrategy,
         LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlm", DatasetType.token_classification): (
         LayoutLMForTokenClassification,
-        HFLayoutLmTokenClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlmv2", DatasetType.sequence_classification): (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
         LayoutLMv2Config,
-
-    ("layoutlmv2", DatasetType.token_classification): (
+        LayoutLMv2ForSequenceClassification,
         LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    ("layoutlmv3", DatasetType.sequence_classification): (
-        LayoutLMv3ForSequenceClassification,
-        HFLayoutLmv3SequenceClassifier,
         LayoutLMv3Config,
-
-    ("layoutlmv3", DatasetType.token_classification): (
+        LayoutLMv3ForSequenceClassification,
         LayoutLMv3ForTokenClassification,
-
-
-
-
-
-
-
-
-
-
+        LiltForSequenceClassification,
+        LiltForTokenClassification,
+        PretrainedConfig,
+        PreTrainedModel,
+        XLMRobertaForSequenceClassification,
+    )
+    from transformers.trainer import Trainer, TrainingArguments
+
+with try_import() as wb_import_guard:
+    import wandb
+
+
+def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> tuple[Any, Any, Any]:
+    """
+    Get the model architecture, model wrapper and config class for a given model type and dataset type.
+
+    :param model_type: The model type
+    :param dataset_type: The dataset type
+    :return: Tuple of model architecture, model wrapper and config class
+    """
+    return {
+        ("layoutlm", DatasetType.SEQUENCE_CLASSIFICATION): (
+            LayoutLMForSequenceClassification,
+            HFLayoutLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlm", DatasetType.TOKEN_CLASSIFICATION): (
+            LayoutLMForTokenClassification,
+            HFLayoutLmTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlmv2", DatasetType.SEQUENCE_CLASSIFICATION): (
+            LayoutLMv2ForSequenceClassification,
+            HFLayoutLmv2SequenceClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv2", DatasetType.TOKEN_CLASSIFICATION): (
+            LayoutLMv2ForTokenClassification,
+            HFLayoutLmv2TokenClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv3", DatasetType.SEQUENCE_CLASSIFICATION): (
+            LayoutLMv3ForSequenceClassification,
+            HFLayoutLmv3SequenceClassifier,
+            LayoutLMv3Config,
+        ),
+        ("layoutlmv3", DatasetType.TOKEN_CLASSIFICATION): (
+            LayoutLMv3ForTokenClassification,
+            HFLayoutLmv3TokenClassifier,
+            LayoutLMv3Config,
+        ),
+        ("lilt", DatasetType.TOKEN_CLASSIFICATION): (
+            LiltForTokenClassification,
+            HFLiltTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("lilt", DatasetType.SEQUENCE_CLASSIFICATION): (
+            LiltForSequenceClassification,
+            HFLiltSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("xlm-roberta", DatasetType.SEQUENCE_CLASSIFICATION): (
+            XLMRobertaForSequenceClassification,
+            HFLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+    }[(model_type, dataset_type)]
+
+
+def maybe_remove_bounding_box_features(model_type: str) -> bool:
+    """Listing of models that do not need bounding box features."""
+    return {"xlm-roberta": True}.get(model_type, False)
 
 
 class LayoutLMTrainer(Trainer):
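Both trainer modules now defer heavyweight framework imports through `lazy_imports.try_import` guards instead of importing `torch`, `transformers` or `wandb` eagerly at module load. A sketch of that pattern, assuming the `check()` API the lazy-imports package documents (guard and function names here are illustrative, not from the diff):

    from lazy_imports import try_import

    with try_import() as import_guard:
        import torch  # an ImportError is recorded here, not raised

    def run_on_gpu() -> None:
        import_guard.check()  # re-raises the deferred ImportError, if any
        print(torch.cuda.is_available())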
@@ -145,21 +158,21 @@ class LayoutLMTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
         args: TrainingArguments,
         data_collator: LayoutLMDataCollator,
         train_dataset: Dataset[Any],
     ):
         self.evaluator: Optional[Evaluator] = None
-        self.build_eval_kwargs: Optional[
+        self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset)
 
     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component:
+        pipeline_component: PipelineComponent,
         metric: Union[Type[ClassificationMetric], ClassificationMetric],
-        run: Optional[
+        run: Optional[wandb.sdk.wandb_run.Run] = None,
         **build_eval_kwargs: Union[str, int],
     ) -> None:
         """
@@ -176,15 +189,15 @@ class LayoutLMTrainer(Trainer):
         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-            comp.
+            comp.clear_predictor()
         self.build_eval_kwargs = build_eval_kwargs
 
     def evaluate(
         self,
         eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-        ignore_keys: Optional[
+        ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
         metric_key_prefix: str = "eval",  # pylint: disable=W0613
-    ) ->
+    ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
@@ -208,34 +221,35 @@ class LayoutLMTrainer(Trainer):
 
 
 def _get_model_class_and_tokenizer(
-    path_config_json:
-) ->
+    path_config_json: PathLikeOrStr, dataset_type: DatasetType, use_xlm_tokenizer: bool
+) -> tuple[Any, Any, Any, Any, Any]:
     with open(path_config_json, "r", encoding="UTF-8") as file:
         config_json = json.load(file)
 
-    model_type
-
-
-        model_cls, model_wrapper_cls, config_cls = _ARCHITECTURES_TO_MODEL_CLASS[architectures[0]]
-        tokenizer_fast = get_tokenizer_from_architecture(architectures[0], use_xlm_tokenizer)
-    elif model_type:
-        model_cls, model_wrapper_cls, config_cls = _MODEL_TYPE_AND_TASK_TO_MODEL_CLASS[(model_type, dataset_type)]
-        tokenizer_fast = _MODEL_TYPE_TO_TOKENIZER[(model_type, use_xlm_tokenizer)]
+    if model_type := config_json.get("model_type"):
+        model_cls, model_wrapper_cls, config_cls = get_model_architectures_and_configs(model_type, dataset_type)
+        remove_box_features = maybe_remove_bounding_box_features(model_type)
     else:
-        raise KeyError("model_type
+        raise KeyError("model_type not available in configs. It seems that the config is not valid")
 
-
-
+    tokenizer_fast = get_tokenizer_from_model_class(model_cls.__name__, use_xlm_tokenizer)
+    return config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features
 
-
+
+def get_image_to_raw_features_mapping(input_str: str) -> Any:
+    """Replacing eval functions"""
+    return {
+        "image_to_raw_layoutlm_features": image_to_raw_layoutlm_features,
+        "image_to_raw_lm_features": image_to_raw_lm_features,
+    }[input_str]
 
 
 def train_hf_layoutlm(
-    path_config_json:
+    path_config_json: PathLikeOrStr,
     dataset_train: Union[str, DatasetBase],
-    path_weights:
-    config_overwrite: Optional[
-    log_dir:
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/layoutlm",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
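The two module-level lookup tables (`_ARCHITECTURES_TO_MODEL_CLASS` and `_MODEL_TYPE_AND_TASK_TO_MODEL_CLASS`) are gone; model selection is now a single dictionary lookup keyed on the config's `model_type` and the dataset type. An illustrative call against the table defined above:

    model_cls, wrapper_cls, config_cls = get_model_architectures_and_configs(
        "layoutlmv3", DatasetType.TOKEN_CLASSIFICATION
    )
    # -> (LayoutLMv3ForTokenClassification, HFLayoutLmv3TokenClassifier, LayoutLMv3Config)
    # Any unsupported (model_type, dataset_type) pair raises KeyError.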
@@ -310,13 +324,13 @@ def train_hf_layoutlm(
     appear as child, it will use the word bounding box.
     """
 
-    build_train_dict:
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict:
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
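`build_train_config` and `build_val_config` are sequences of `key=value` strings that `string_to_dict` (unchanged in this release) folds into kwargs for the dataflow builder. Assuming that parsing behaviour, the defaulting above works like this:

    build_train_config = ["max_datapoints=1000"]
    build_train_dict = string_to_dict(",".join(build_train_config))
    # {'max_datapoints': '1000'}
    if "split" not in build_train_dict:
        build_train_dict["split"] = "train"
    # {'max_datapoints': '1000', 'split': 'train'}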
@@ -330,40 +344,43 @@ def train_hf_layoutlm(
 
     # We wrap our dataset into a torch dataset
     dataset_type = dataset_train.dataset_info.type
-    if dataset_type == DatasetType.
+    if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
         categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(as_dict=True, name_as_key=True)
-    elif dataset_type == DatasetType.
+    elif dataset_type == DatasetType.TOKEN_CLASSIFICATION:
         if use_token_tag:
             categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                categories=LayoutType.
-                sub_categories={LayoutType.
+                categories=LayoutType.WORD,
+                sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]},
                 keys=False,
                 values_as_dict=True,
                 name_as_key=True,
-            )[LayoutType.
+            )[LayoutType.WORD][WordType.TOKEN_TAG]
         else:
             categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                categories=LayoutType.
-                sub_categories={LayoutType.
+                categories=LayoutType.WORD,
+                sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]},
                 keys=False,
                 values_as_dict=True,
                 name_as_key=True,
-            )[LayoutType.
+            )[LayoutType.WORD][WordType.TOKEN_CLASS]
     else:
         raise UserWarning("Dataset type not supported for training")
 
-    config_cls, model_cls, model_wrapper_cls, tokenizer_fast = _get_model_class_and_tokenizer(
+    config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features = _get_model_class_and_tokenizer(
         path_config_json, dataset_type, use_xlm_tokenizer
     )
-
+    image_to_raw_features_func = get_image_to_raw_features_mapping(model_wrapper_cls.image_to_raw_features_mapping())
+    image_to_raw_features_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
     if segment_positions:
-
-
+        image_to_raw_features_kwargs["segment_positions"] = segment_positions  # type: ignore
+    image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_image_to_features_mapping())
+
     dataset = DatasetAdapter(
         dataset_train,
         True,
-
+        image_to_raw_features_func(**image_to_raw_features_kwargs),
         use_token_tag,
+        number_repetitions=-1,
         **build_train_dict,
     )
 
@@ -373,7 +390,7 @@ def train_hf_layoutlm(
     # Need to set remove_unused_columns to False, as the DataCollator for column removal will remove some raw features
     # that are necessary for the tokenizer.
     conf_dict = {
-        "output_dir": log_dir,
+        "output_dir": os.fspath(log_dir),
         "remove_unused_columns": False,
         "per_device_train_batch_size": 8,
         "max_steps": number_samples,
@@ -414,16 +431,16 @@ def train_hf_layoutlm(
     )
 
     use_wandb = conf_dict.pop("use_wandb")
-    wandb_project = conf_dict.pop("wandb_project")
-    wandb_repo = conf_dict.pop("wandb_repo")
+    wandb_project = str(conf_dict.pop("wandb_project"))
+    wandb_repo = str(conf_dict.pop("wandb_repo"))
 
     # Initialize Wandb, if necessary
     run = None
     if use_wandb:
         if not wandb_available():
             raise DependencyError("WandB must be installed separately")
-        run = wandb.init(project=wandb_project, config=conf_dict)
-        run._label(repo=wandb_repo)  #
+        run = wandb.init(project=wandb_project, config=conf_dict)
+        run._label(repo=wandb_repo)  # pylint: disable=W0212
     else:
         os.environ["WANDB_DISABLED"] = "True"
 
@@ -453,32 +470,34 @@ def train_hf_layoutlm(
         return_tensors="pt",
         sliding_window_stride=sliding_window_stride,  # type: ignore
         max_batch_size=max_batch_size,  # type: ignore
+        remove_bounding_box_features=remove_box_features,
     )
     trainer = LayoutLMTrainer(model, arguments, data_collator, dataset)
 
     if arguments.evaluation_strategy in (IntervalStrategy.STEPS,):
         assert metric is not None  # silence mypy
-        if dataset_type == DatasetType.
+        if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
             categories = dataset_val.dataflow.categories.get_categories(filtered=True)  # type: ignore
         else:
             if use_token_tag:
                 categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                    categories=LayoutType.
-                )[LayoutType.
-                metric.set_categories(category_names=LayoutType.
+                    categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]}, keys=False
+                )[LayoutType.WORD][WordType.TOKEN_TAG]
+                metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_tag"]})
             else:
                 categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                    categories=LayoutType.
-                )[LayoutType.
-                metric.set_categories(category_names=LayoutType.
+                    categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]}, keys=False
+                )[LayoutType.WORD][WordType.TOKEN_CLASS]
+                metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_class"]})
         dd_model = model_wrapper_cls(
             path_config_json=path_config_json,
             path_weights=path_weights,
             categories=categories,
-            device=
+            device=get_torch_device(),
+            use_xlm_tokenizer=use_xlm_tokenizer,
         )
         pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
-        if dataset_type == DatasetType.
+        if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
             pipeline_component = pipeline_component_cls(tokenizer_fast, dd_model)
         else:
             pipeline_component = pipeline_component_cls(
@@ -487,7 +506,6 @@ def train_hf_layoutlm(
                 use_other_as_default_category=True,
                 sliding_window_stride=sliding_window_stride,
             )
-        assert isinstance(pipeline_component, LanguageModelPipelineComponent)
 
         trainer.setup_evaluator(dataset_val, pipeline_component, metric, run, **build_val_dict)  # type: ignore
 
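Taken together, `train_hf_layoutlm` now accepts `os.PathLike` values wherever a path is expected. A hypothetical invocation under the signature shown above (the `get_dataset` helper, import paths and file locations are placeholders, not taken from this diff):

    from pathlib import Path

    from deepdoctection.datasets import get_dataset  # assumed helper from the library docs
    from deepdoctection.train import train_hf_layoutlm

    funsd = get_dataset("funsd")
    train_hf_layoutlm(
        path_config_json=Path("/models/layoutlm-base-uncased/config.json"),
        dataset_train=funsd,
        path_weights=Path("/models/layoutlm-base-uncased/pytorch_model.bin"),
        log_dir=Path("train_log/layoutlm"),
        dataset_val=funsd,
    )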
deepdoctection/train/tp_frcnn_train.py
CHANGED
@@ -20,27 +20,9 @@ Module for training Tensorpack `GeneralizedRCNN`
 """
 
 import os
-from typing import
-
-
-from tensorpack.callbacks import (
-    EstimatedTimeLeft,
-    GPUMemoryTracker,
-    GPUUtilizationTracker,
-    HostMemoryTracker,
-    ModelSaver,
-    PeriodicCallback,
-    ScheduledHyperParamSetter,
-    SessionRunTimeout,
-    ThroughputTracker,
-)
-
-# todo: check how dataflow import is directly possible without having AssertionError
-from tensorpack.dataflow import ProxyDataFlow, imgaug
-from tensorpack.input_source import QueueInput
-from tensorpack.tfutils import SmartInit
-from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
-from tensorpack.utils import logger
+from typing import Optional, Sequence, Type, Union
+
+from lazy_imports import try_import
 
 from ..dataflow.base import DataFlow
 from ..dataflow.common import MapData
@@ -58,16 +40,35 @@ from ..extern.tp.tpfrcnn.preproc import anchors_and_labels, augment
 from ..extern.tpdetect import TPFrcnnDetector
 from ..mapper.maputils import LabelSummarizer
 from ..mapper.tpstruct import image_to_tp_frcnn_training
-from ..pipe.base import PredictorPipelineComponent
 from ..pipe.registry import pipeline_component_registry
-from ..utils.detection_types import JsonDict
 from ..utils.file_utils import set_mp_spawn
 from ..utils.fs import get_load_image_func
 from ..utils.logger import log_once
 from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.tqdm import get_tqdm
+from ..utils.types import JsonDict, PathLikeOrStr
 from ..utils.utils import string_to_dict
 
+with try_import() as tp_import_guard:
+    # todo: check how dataflow import is directly possible without having an AssertionError
+    # pylint: disable=import-error
+    from tensorpack.callbacks import (
+        EstimatedTimeLeft,
+        GPUMemoryTracker,
+        GPUUtilizationTracker,
+        HostMemoryTracker,
+        ModelSaver,
+        PeriodicCallback,
+        ScheduledHyperParamSetter,
+        SessionRunTimeout,
+        ThroughputTracker,
+    )
+    from tensorpack.dataflow import ProxyDataFlow, imgaug
+    from tensorpack.input_source import QueueInput
+    from tensorpack.tfutils import SmartInit
+    from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
+    from tensorpack.utils import logger
+
 
 __all__ = ["train_faster_rcnn"]
 
 
@@ -183,11 +184,11 @@ def get_train_dataflow(
 
 
 def train_faster_rcnn(
-    path_config_yaml:
+    path_config_yaml: PathLikeOrStr,
     dataset_train: DatasetBase,
-    path_weights:
-    config_overwrite: Optional[
-    log_dir:
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/frcnn",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -222,13 +223,13 @@ def train_faster_rcnn(
 
     assert disable_tfv2()  # TP works only in Graph mode
 
-    build_train_dict:
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict:
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
@@ -236,7 +237,7 @@ def train_faster_rcnn(
 
     config_overwrite = [] if config_overwrite is None else config_overwrite
 
-    log_dir = "TRAIN.LOG_DIR=" + log_dir
+    log_dir = "TRAIN.LOG_DIR=" + os.fspath(log_dir)
    config_overwrite.append(log_dir)
 
     config = set_config_by_yaml(path_config_yaml)
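`log_dir` is now a `PathLikeOrStr`, so the string concatenation has to normalize it first; `os.fspath` returns the `str` form of either a `str` or an `os.PathLike`. For example:

    import os
    from pathlib import Path

    "TRAIN.LOG_DIR=" + os.fspath(Path("train_log/frcnn"))  # 'TRAIN.LOG_DIR=train_log/frcnn'
    # "TRAIN.LOG_DIR=" + Path("train_log/frcnn")           # TypeError without os.fspath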
@@ -297,7 +298,6 @@ def train_faster_rcnn(
     )  # only a wrapper for the predictor itself. Will be replaced in Callback
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     pipeline_component = pipeline_component_cls(detector)
-    assert isinstance(pipeline_component, PredictorPipelineComponent)
     category_names = list(categories.values())
     callbacks.extend(
         [
@@ -308,6 +308,7 @@ def train_faster_rcnn(
                 metric,  # type: ignore
                 pipeline_component,
                 *model.get_inference_tensor_names(),  # type: ignore
+                cfg=detector.model.cfg,
                 **build_val_dict
             )
         ]
deepdoctection/utils/concurrency.py
CHANGED
@@ -28,8 +28,8 @@ import threading
 from contextlib import contextmanager
 from typing import Any, Generator, Optional, no_type_check
 
-from .detection_types import QueueType
 from .logger import log_once
+from .types import QueueType
 
 
 # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
deepdoctection/utils/context.py
CHANGED
@@ -26,12 +26,12 @@ from glob import iglob
 from os import path, remove
 from tempfile import NamedTemporaryFile
 from time import perf_counter as timer
-from typing import Any, Generator, Iterator, Optional,
+from typing import Any, Generator, Iterator, Optional, Union
 
 import numpy as np
 
-from .detection_types import ImageType
 from .logger import LoggingRecord, logger
+from .types import B64, B64Str, PixelValues
 from .viz import viz_handler
 
 __all__ = ["timeout_manager", "save_tmp_file", "timed_operation"]
@@ -72,7 +72,7 @@ def timeout_manager(proc, seconds: Optional[int] = None) -> Iterator[str]:  # ty
 
 
 @contextmanager
-def save_tmp_file(image: Union[
+def save_tmp_file(image: Union[B64Str, PixelValues, B64], prefix: str) -> Iterator[tuple[str, str]]:
     """
     Save image temporarily and handle the clean-up once not necessary anymore
 
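The hunk header below still shows the 0.31 signature, `Union[str, ImageType, bytes]`, so the new aliases map one-to-one onto the old types. A sketch of what `utils/types.py` presumably declares for these three names (inferred from the signature change, not copied from the release):

    import numpy as np
    import numpy.typing as npt

    B64Str = str                          # base64-encoded image as str (was plain str)
    B64 = bytes                           # base64-encoded image as bytes (was plain bytes)
    PixelValues = npt.NDArray[np.uint8]   # decoded image array (was ImageType)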
@@ -112,13 +112,20 @@ def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[
 @contextmanager
 def timed_operation(message: str, log_start: bool = False) -> Generator[Any, None, None]:
     """
-    Contextmanager with a timer.
+    Contextmanager with a timer.
 
-
+    ... code-block:: python
+
+        with timed_operation(message="Your stdout message", log_start=True):
+
+            with open("log.txt", "a") as file:
+                ...
+
+
     :param message: a log to stdout
     :param log_start: whether to print also the beginning
     """
 
-    assert len(message)
     if log_start:
         logger.info(LoggingRecord(f"start task: {message} ..."))
     start = timer()
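Based on the visible body (`start = timer()` with `perf_counter` and the `log_start` branch), `timed_operation` follows the standard context-manager timing pattern. A self-contained sketch of that pattern, not the exact library code:

    from contextlib import contextmanager
    from time import perf_counter as timer
    from typing import Any, Generator

    @contextmanager
    def timed_operation_sketch(message: str, log_start: bool = False) -> Generator[Any, None, None]:
        if log_start:
            print(f"start task: {message} ...")
        start = timer()
        yield
        print(f"{message} finished, time: {timer() - start:.4f} sec.")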
deepdoctection/utils/develop.py
CHANGED
@@ -26,19 +26,19 @@ import functools
 import inspect
 from collections import defaultdict
 from datetime import datetime
-from typing import Callable,
+from typing import Callable, Optional
 
-from .detection_types import T
 from .logger import LoggingRecord, logger
+from .types import T
 
-__all__:
+__all__: list[str] = ["deprecated"]
 
 # Copy and paste from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py
 
 _DEPRECATED_LOG_NUM = defaultdict(int)  # type: ignore
 
 
-def log_deprecated(name: str
+def log_deprecated(name: str, text: str, eos: str = "", max_num_warnings: Optional[int] = None) -> None:
     """
     Log deprecation warning.
 