deepdoctection-0.30-py3-none-any.whl → deepdoctection-0.32-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepdoctection/__init__.py +38 -29
- deepdoctection/analyzer/dd.py +36 -29
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/base.py +0 -19
- deepdoctection/dataflow/custom.py +4 -3
- deepdoctection/dataflow/custom_serialize.py +14 -5
- deepdoctection/dataflow/parallel_map.py +12 -11
- deepdoctection/dataflow/serialize.py +5 -4
- deepdoctection/datapoint/annotation.py +35 -13
- deepdoctection/datapoint/box.py +3 -5
- deepdoctection/datapoint/convert.py +3 -1
- deepdoctection/datapoint/image.py +79 -36
- deepdoctection/datapoint/view.py +152 -49
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +6 -3
- deepdoctection/datasets/base.py +86 -11
- deepdoctection/datasets/dataflow_builder.py +1 -1
- deepdoctection/datasets/info.py +4 -4
- deepdoctection/datasets/instances/doclaynet.py +3 -2
- deepdoctection/datasets/instances/fintabnet.py +2 -1
- deepdoctection/datasets/instances/funsd.py +2 -1
- deepdoctection/datasets/instances/iiitar13k.py +5 -2
- deepdoctection/datasets/instances/layouttest.py +4 -8
- deepdoctection/datasets/instances/publaynet.py +2 -2
- deepdoctection/datasets/instances/pubtables1m.py +6 -3
- deepdoctection/datasets/instances/pubtabnet.py +2 -1
- deepdoctection/datasets/instances/rvlcdip.py +2 -1
- deepdoctection/datasets/instances/xfund.py +2 -1
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +1 -1
- deepdoctection/eval/base.py +5 -4
- deepdoctection/eval/cocometric.py +2 -1
- deepdoctection/eval/eval.py +19 -15
- deepdoctection/eval/tedsmetric.py +14 -11
- deepdoctection/eval/tp_eval_callback.py +14 -7
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +39 -13
- deepdoctection/extern/d2detect.py +182 -90
- deepdoctection/extern/deskew.py +36 -9
- deepdoctection/extern/doctrocr.py +265 -83
- deepdoctection/extern/fastlang.py +49 -9
- deepdoctection/extern/hfdetr.py +106 -55
- deepdoctection/extern/hflayoutlm.py +441 -122
- deepdoctection/extern/hflm.py +225 -0
- deepdoctection/extern/model.py +56 -47
- deepdoctection/extern/pdftext.py +10 -5
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +27 -18
- deepdoctection/extern/tessocr.py +134 -22
- deepdoctection/extern/texocr.py +6 -2
- deepdoctection/extern/tp/tfutils.py +43 -9
- deepdoctection/extern/tp/tpcompat.py +14 -11
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +8 -9
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +54 -30
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/d2struct.py +9 -7
- deepdoctection/mapper/hfstruct.py +7 -2
- deepdoctection/mapper/laylmstruct.py +164 -21
- deepdoctection/mapper/maputils.py +16 -3
- deepdoctection/mapper/misc.py +6 -3
- deepdoctection/mapper/prodigystruct.py +1 -1
- deepdoctection/mapper/pubstruct.py +10 -10
- deepdoctection/mapper/tpstruct.py +3 -3
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +35 -8
- deepdoctection/pipe/base.py +53 -19
- deepdoctection/pipe/common.py +23 -13
- deepdoctection/pipe/concurrency.py +2 -1
- deepdoctection/pipe/doctectionpipe.py +2 -2
- deepdoctection/pipe/language.py +3 -2
- deepdoctection/pipe/layout.py +6 -3
- deepdoctection/pipe/lm.py +34 -66
- deepdoctection/pipe/order.py +142 -35
- deepdoctection/pipe/refine.py +26 -24
- deepdoctection/pipe/segment.py +21 -16
- deepdoctection/pipe/{cell.py → sub_layout.py} +30 -9
- deepdoctection/pipe/text.py +14 -8
- deepdoctection/pipe/transform.py +16 -9
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +36 -28
- deepdoctection/train/hf_detr_train.py +26 -17
- deepdoctection/train/hf_layoutlm_train.py +133 -111
- deepdoctection/train/tp_frcnn_train.py +21 -19
- deepdoctection/utils/__init__.py +3 -0
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +2 -2
- deepdoctection/utils/env_info.py +41 -84
- deepdoctection/utils/error.py +84 -0
- deepdoctection/utils/file_utils.py +4 -15
- deepdoctection/utils/fs.py +7 -7
- deepdoctection/utils/logger.py +1 -0
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +5 -4
- deepdoctection/utils/settings.py +6 -1
- deepdoctection/utils/transform.py +1 -1
- deepdoctection/utils/utils.py +0 -6
- deepdoctection/utils/viz.py +48 -5
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/METADATA +57 -73
- deepdoctection-0.32.dist-info/RECORD +146 -0
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/WHEEL +1 -1
- deepdoctection-0.30.dist-info/RECORD +0 -143
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
- {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
deepdoctection/train/d2_frcnn_train.py
CHANGED

@@ -18,19 +18,12 @@
 """
 Module for training Detectron2 `GeneralizedRCNN`
 """
-
+from __future__ import annotations
 
 import copy
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
 
-from detectron2.config import CfgNode, get_cfg
-from detectron2.data import DatasetMapper, build_detection_train_loader
-from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
-from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
-from detectron2.utils import comm
-from detectron2.utils.events import EventWriter, get_event_storage
-from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
-from torch.utils.data import DataLoader, IterableDataset
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -39,15 +32,28 @@ from ..eval.base import MetricBase
 from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.d2detect import D2FrcnnDetector
-from ..extern.pt.ptutils import get_num_gpu
 from ..mapper.d2struct import image_to_d2_frcnn_training
 from ..pipe.base import PredictorPipelineComponent
 from ..pipe.registry import pipeline_component_registry
+from ..utils.error import DependencyError
 from ..utils.file_utils import get_wandb_requirement, wandb_available
 from ..utils.logger import LoggingRecord, logger
 from ..utils.utils import string_to_dict
 
-
+with try_import() as d2_import_guard:
+    from detectron2.config import CfgNode, get_cfg
+    from detectron2.data import DatasetMapper, build_detection_train_loader
+    from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
+    from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
+    from detectron2.utils import comm
+    from detectron2.utils.events import EventWriter, get_event_storage
+    from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
+
+with try_import() as pt_import_guard:
+    from torch import cuda
+    from torch.utils.data import DataLoader, IterableDataset
+
+with try_import() as wb_import_guard:
     import wandb
 
 
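A note on the pattern above, since it recurs in every training script in this release: `try_import` comes from the `lazy-imports` package, newly pulled in here. It is a context manager that swallows an `ImportError` raised inside its block and records it on the guard object, so importing a deepdoctection training module no longer hard-fails when detectron2, torch or wandb are absent. A minimal sketch of the mechanism; the `build_cfg` helper and the placement of the `check()` call are illustrative, not taken from this diff:

from lazy_imports import try_import

with try_import() as d2_import_guard:
    # If detectron2 is not installed, the ImportError is recorded, not raised.
    from detectron2.config import get_cfg

def build_cfg():
    # Hypothetical point of use: re-raise the recorded ImportError, if any.
    d2_import_guard.check()
    return get_cfg()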
@@ -111,7 +117,7 @@ class WandbWriter(EventWriter):
             config = {}
         self._window_size = window_size
         self._run = wandb.init(project=project, config=config, **kwargs) if not wandb.run else wandb.run
-        self._run._label(repo=repo)
+        self._run._label(repo=repo)
 
     def write(self) -> None:
         storage = get_event_storage()
@@ -120,10 +126,10 @@ class WandbWriter(EventWriter):
         for key, (val, _) in storage.latest_with_smoothing_hint(self._window_size).items():
             log_dict[key] = val
 
-        self._run.log(log_dict)
+        self._run.log(log_dict)
 
     def close(self) -> None:
-        self._run.finish()
+        self._run.finish()
 
 
 class D2Trainer(DefaultTrainer):
@@ -153,16 +159,18 @@ class D2Trainer(DefaultTrainer):
         ret = [
             hooks.IterationTimer(),
             hooks.LRScheduler(),
-            hooks.PreciseBN(
-                # Run at the same freq as (but before) evaluation.
-                cfg.TEST.EVAL_PERIOD,
-                self.model,
-                # Build a new data loader to not affect training
-                self.build_train_loader(cfg),
-                cfg.TEST.PRECISE_BN.NUM_ITER,
-            )
-            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
-            else None,
+            (
+                hooks.PreciseBN(
+                    # Run at the same freq as (but before) evaluation.
+                    cfg.TEST.EVAL_PERIOD,
+                    self.model,  # pylint: disable=E1101
+                    # Build a new data loader to not affect training
+                    self.build_train_loader(cfg),
+                    cfg.TEST.PRECISE_BN.NUM_ITER,
+                )
+                if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)  # pylint: disable=E1101
+                else None
+            ),
         ]
 
         # Do PreciseBN before checkpointer, because it updates the model and need to
@@ -201,7 +209,7 @@ class D2Trainer(DefaultTrainer):
         if self.cfg.WANDB.USE_WANDB:
             _, _wandb_available, err_msg = get_wandb_requirement()
             if not _wandb_available:
-                raise
+                raise DependencyError(err_msg)
             if self.cfg.WANDB.PROJECT is None:
                 raise ValueError("When using W&B, you must specify a project, i.e. WANDB.PROJECT")
             writers_list.append(WandbWriter(self.cfg.WANDB.PROJECT, self.cfg.WANDB.REPO, self.cfg))
@@ -256,7 +264,7 @@
             dataset_val,
             pipeline_component,
             metric,
-            num_threads=get_num_gpu() * 2,
+            num_threads=cuda.device_count() * 2,
             run=run,
         )
         if build_val_dict:
@@ -269,7 +277,7 @@
 
     @classmethod
     def build_evaluator(cls, cfg, dataset_name):  # type: ignore
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 def train_d2_faster_rcnn(
@@ -332,7 +340,7 @@ def train_d2_faster_rcnn(
     :param pipeline_component_name: A pipeline component name to use for validation.
    """
 
-    assert get_num_gpu() > 0, "Has to train with GPU!"
+    assert cuda.device_count() > 0, "Has to train with GPU!"
 
     build_train_dict: Dict[str, str] = {}
     if build_train_config is not None:
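`DependencyError` is imported from `deepdoctection/utils/error.py`, a module new in 0.32 (+84 lines) whose body is not part of this excerpt. Judging only from its call sites in this diff, it replaces the earlier bare `raise`/import-error style when an optional dependency is missing. A hypothetical minimal shape, for orientation only:

class DependencyError(Exception):
    """Missing optional dependency for a requested feature.

    Hypothetical sketch; the real utils/error.py is not shown in this diff.
    """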
deepdoctection/train/hf_detr_train.py
CHANGED

@@ -19,20 +19,12 @@
 Module for training Hugging Face Detr implementation. Note, that this scripts only trans Tabletransformer like Detr
 models that are a slightly different from the plain Detr model that are provided by the transformer library.
 """
+from __future__ import annotations
 
 import copy
 from typing import Any, Dict, List, Optional, Sequence, Type, Union
 
-from torch.nn import Module
-from torch.utils.data import Dataset
-from transformers import (
-    AutoFeatureExtractor,
-    IntervalStrategy,
-    PretrainedConfig,
-    PreTrainedModel,
-    TableTransformerForObjectDetection,
-)
-from transformers.trainer import Trainer, TrainingArguments
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -47,6 +39,21 @@ from ..pipe.registry import pipeline_component_registry
 from ..utils.logger import LoggingRecord, logger
 from ..utils.utils import string_to_dict
 
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
+
+with try_import() as hf_import_guard:
+    from transformers import (
+        AutoFeatureExtractor,
+        IntervalStrategy,
+        PretrainedConfig,
+        PreTrainedModel,
+        TableTransformerForObjectDetection,
+        Trainer,
+        TrainingArguments,
+    )
+
 
 class DetrDerivedTrainer(Trainer):
     """
@@ -61,7 +68,7 @@ class DetrDerivedTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
         args: TrainingArguments,
         data_collator: DetrDataCollator,
         train_dataset: Dataset[Any],
@@ -97,9 +104,9 @@ class DetrDerivedTrainer(Trainer):
 
     def evaluate(
         self,
-        eval_dataset: Optional[Dataset[Any]] = None,
-        ignore_keys: Optional[List[str]] = None,
-        metric_key_prefix: str = "eval",
+        eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        metric_key_prefix: str = "eval",  # pylint: disable=W0613
     ) -> Dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
@@ -193,9 +200,11 @@ def train_hf_detr(
         "remove_unused_columns": False,
         "per_device_train_batch_size": 2,
         "max_steps": number_samples,
-        "evaluation_strategy": "steps"
-        if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
-        else "no",
+        "evaluation_strategy": (
+            "steps"
+            if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
+            else "no"
+        ),
         "eval_steps": 5000,
     }
 
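The `evaluation_strategy` hunk above (mirrored later in the LayoutLM trainer) is a formatting-only re-wrap, but the underlying rule deserves a sentence: step-wise evaluation is switched on only when a validation dataset, a metric and a pipeline component are all supplied; otherwise it is off. A sketch of the resulting training arguments; `output_dir` and the helper function are illustrative, not from the diff:

from transformers import TrainingArguments

def make_detr_args(dataset_val, metric, pipeline_component_name, number_samples):
    has_eval = dataset_val is not None and metric is not None and pipeline_component_name is not None
    return TrainingArguments(
        output_dir="/tmp/detr_run",  # placeholder
        remove_unused_columns=False,
        per_device_train_batch_size=2,
        max_steps=number_samples,
        evaluation_strategy="steps" if has_eval else "no",
        eval_steps=5000,  # evaluate every 5000 steps when enabled
    )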
deepdoctection/train/hf_layoutlm_train.py
CHANGED

@@ -18,32 +18,15 @@
 """
 Module for training Huggingface implementation of LayoutLm
 """
+from __future__ import annotations
 
 import copy
 import json
 import os
 import pprint
-from typing import Any, Dict, List,
-
-from torch.nn import Module
-from torch.utils.data import Dataset
-from transformers import (
-    IntervalStrategy,
-    LayoutLMForSequenceClassification,
-    LayoutLMForTokenClassification,
-    LayoutLMTokenizerFast,
-    LayoutLMv2Config,
-    LayoutLMv2ForSequenceClassification,
-    LayoutLMv2ForTokenClassification,
-    LayoutLMv3Config,
-    LayoutLMv3ForSequenceClassification,
-    LayoutLMv3ForTokenClassification,
-    PretrainedConfig,
-    PreTrainedModel,
-    RobertaTokenizerFast,
-    XLMRobertaTokenizerFast,
-)
-from transformers.trainer import Trainer, TrainingArguments
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -57,78 +40,108 @@ from ..extern.hflayoutlm import (
     HFLayoutLmv2TokenClassifier,
     HFLayoutLmv3SequenceClassifier,
     HFLayoutLmv3TokenClassifier,
+    HFLiltSequenceClassifier,
+    HFLiltTokenClassifier,
+    get_tokenizer_from_model_class,
 )
-from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features
+from ..extern.hflm import HFLmSequenceClassifier
+from ..extern.pt.ptutils import get_torch_device
+from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features, image_to_raw_lm_features
 from ..pipe.base import LanguageModelPipelineComponent
-from ..pipe.lm import get_tokenizer_from_architecture
 from ..pipe.registry import pipeline_component_registry
-from ..utils.
+from ..utils.error import DependencyError
 from ..utils.file_utils import wandb_available
 from ..utils.logger import LoggingRecord, logger
-from ..utils.settings import DatasetType, LayoutType,
+from ..utils.settings import DatasetType, LayoutType, WordType
 from ..utils.utils import string_to_dict
 
-
-import wandb
-
-_ARCHITECTURES_TO_MODEL_CLASS = {
-    "LayoutLMForTokenClassification": (LayoutLMForTokenClassification, HFLayoutLmTokenClassifier, PretrainedConfig),
-    "LayoutLMForSequenceClassification": (
-        LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    "LayoutLMv2ForTokenClassification": (
-        LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    "LayoutLMv2ForSequenceClassification": (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
-        LayoutLMv2Config,
-    ),
-}
-
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
 
-_MODEL_TYPE_AND_TASK_TO_MODEL_CLASS = {
-    ("layoutlm", DatasetType.sequence_classification): (
-        LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlm", DatasetType.token_classification): (
-        LayoutLMForTokenClassification,
-        HFLayoutLmTokenClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlmv2", DatasetType.sequence_classification): (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
-        LayoutLMv2Config,
-    ),
-    ("layoutlmv2", DatasetType.token_classification): (
-        LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    ("layoutlmv3", DatasetType.sequence_classification): (
-        LayoutLMv3ForSequenceClassification,
-        HFLayoutLmv3SequenceClassifier,
-        LayoutLMv3Config,
-    ),
-    ("layoutlmv3", DatasetType.token_classification): (
-        LayoutLMv3ForTokenClassification,
-        HFLayoutLmv3TokenClassifier,
-        LayoutLMv3Config,
-    ),
-}
-
-
-
-
-
-
-
+with try_import() as tr_import_guard:
+    from transformers import (
+        IntervalStrategy,
+        LayoutLMForSequenceClassification,
+        LayoutLMForTokenClassification,
+        LayoutLMv2Config,
+        LayoutLMv2ForSequenceClassification,
+        LayoutLMv2ForTokenClassification,
+        LayoutLMv3Config,
+        LayoutLMv3ForSequenceClassification,
+        LayoutLMv3ForTokenClassification,
+        LiltForSequenceClassification,
+        LiltForTokenClassification,
+        PretrainedConfig,
+        PreTrainedModel,
+        XLMRobertaForSequenceClassification,
+    )
+    from transformers.trainer import Trainer, TrainingArguments
+
+with try_import() as wb_import_guard:
+    import wandb
+
+
+def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> Tuple[Any, Any, Any]:
+    """
+    Get the model architecture, model wrapper and config class for a given model type and dataset type.
+
+    :param model_type: The model type
+    :param dataset_type: The dataset type
+    :return: Tuple of model architecture, model wrapper and config class
+    """
+    return {
+        ("layoutlm", DatasetType.sequence_classification): (
+            LayoutLMForSequenceClassification,
+            HFLayoutLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlm", DatasetType.token_classification): (
+            LayoutLMForTokenClassification,
+            HFLayoutLmTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlmv2", DatasetType.sequence_classification): (
+            LayoutLMv2ForSequenceClassification,
+            HFLayoutLmv2SequenceClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv2", DatasetType.token_classification): (
+            LayoutLMv2ForTokenClassification,
+            HFLayoutLmv2TokenClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv3", DatasetType.sequence_classification): (
+            LayoutLMv3ForSequenceClassification,
+            HFLayoutLmv3SequenceClassifier,
+            LayoutLMv3Config,
+        ),
+        ("layoutlmv3", DatasetType.token_classification): (
+            LayoutLMv3ForTokenClassification,
+            HFLayoutLmv3TokenClassifier,
+            LayoutLMv3Config,
+        ),
+        ("lilt", DatasetType.token_classification): (
+            LiltForTokenClassification,
+            HFLiltTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("lilt", DatasetType.sequence_classification): (
+            LiltForSequenceClassification,
+            HFLiltSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("xlm-roberta", DatasetType.sequence_classification): (
+            XLMRobertaForSequenceClassification,
+            HFLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+    }[(model_type, dataset_type)]
+
+
+def maybe_remove_bounding_box_features(model_type: str) -> bool:
+    """Listing of models that do not need bounding box features."""
+    return {"xlm-roberta": True}.get(model_type, False)
 
 
 class LayoutLMTrainer(Trainer):
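`get_model_architectures_and_configs` replaces the two module-level tables removed above; the lookup key is now the `(model_type, dataset_type)` pair instead of the architecture class name, which is how LiLT and XLM-RoBERTa slot in without extra branches, and an unsupported pair now surfaces as a `KeyError`. Usage, as wired up in `_get_model_class_and_tokenizer` further down:

model_cls, model_wrapper_cls, config_cls = get_model_architectures_and_configs(
    "layoutlm", DatasetType.sequence_classification
)
# model_cls         -> LayoutLMForSequenceClassification (transformers)
# model_wrapper_cls -> HFLayoutLmSequenceClassifier (deepdoctection wrapper)
# config_cls        -> PretrainedConfig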
@@ -144,7 +157,7 @@ class LayoutLMTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
         args: TrainingArguments,
         data_collator: LayoutLMDataCollator,
         train_dataset: Dataset[Any],
@@ -158,7 +171,7 @@ class LayoutLMTrainer(Trainer):
         dataset_val: DatasetBase,
         pipeline_component: LanguageModelPipelineComponent,
         metric: Union[Type[ClassificationMetric], ClassificationMetric],
-        run: Optional[
+        run: Optional[wandb.sdk.wandb_run.Run] = None,
         **build_eval_kwargs: Union[str, int],
     ) -> None:
         """
@@ -180,15 +193,17 @@ class LayoutLMTrainer(Trainer):
 
     def evaluate(
         self,
-        eval_dataset: Optional[Dataset[Any]] = None,
-        ignore_keys: Optional[List[str]] = None,
-        metric_key_prefix: str = "eval",
+        eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        metric_key_prefix: str = "eval",  # pylint: disable=W0613
     ) -> Dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
-
-
+        if self.evaluator is None:
+            raise ValueError("Evaluator not set up. Please use `setup_evaluator` before running evaluation")
+        if self.evaluator.pipe_component is None:
+            raise ValueError("Pipeline component not set up. Please use `setup_evaluator` before running evaluation")
 
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
@@ -205,26 +220,27 @@
 
 
 def _get_model_class_and_tokenizer(
-    path_config_json: str, dataset_type:
-) -> Tuple[Any, Any, Any, Any]:
+    path_config_json: str, dataset_type: DatasetType, use_xlm_tokenizer: bool
+) -> Tuple[Any, Any, Any, Any, Any]:
     with open(path_config_json, "r", encoding="UTF-8") as file:
         config_json = json.load(file)
 
-    model_type = config_json.get("model_type")
-    architectures = config_json.get("architectures")
-    if architectures:
-        model_cls, model_wrapper_cls, config_cls = _ARCHITECTURES_TO_MODEL_CLASS[architectures[0]]
-        tokenizer_fast = get_tokenizer_from_architecture(architectures[0], use_xlm_tokenizer)
-    elif model_type:
-        model_cls, model_wrapper_cls, config_cls = _MODEL_TYPE_AND_TASK_TO_MODEL_CLASS[(model_type, dataset_type)]
-        tokenizer_fast = _MODEL_TYPE_TO_TOKENIZER[(model_type, use_xlm_tokenizer)]
+    if model_type := config_json.get("model_type"):
+        model_cls, model_wrapper_cls, config_cls = get_model_architectures_and_configs(model_type, dataset_type)
+        remove_box_features = maybe_remove_bounding_box_features(model_type)
     else:
-        raise KeyError("model_type
+        raise KeyError("model_type not available in configs. It seems that the config is not valid")
+
+    tokenizer_fast = get_tokenizer_from_model_class(model_cls.__name__, use_xlm_tokenizer)
+    return config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features
 
-    if not model_cls:
-        raise ValueError("model not eligible to run with this framework")
 
-    return config_cls, model_cls, model_wrapper_cls, tokenizer_fast
+def get_image_to_raw_features_mapping(input_str: str) -> Any:
+    """Replacing eval functions"""
+    return {
+        "image_to_raw_layoutlm_features": image_to_raw_layoutlm_features,
+        "image_to_raw_lm_features": image_to_raw_lm_features,
+    }[input_str]
 
 
 def train_hf_layoutlm(
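The rewrite of `_get_model_class_and_tokenizer` above collapses the old two-branch lookup (first `architectures[0]`, then `model_type`) into a single walrus-operator check on the Hugging Face `config.json`. A self-contained sketch of just that check; `read_model_type` is a hypothetical helper, not part of the diff:

import json

def read_model_type(path_config_json: str) -> str:
    with open(path_config_json, "r", encoding="UTF-8") as file:
        config_json = json.load(file)
    # := binds and truth-tests in one expression; a missing or empty entry falls through.
    if model_type := config_json.get("model_type"):
        return model_type  # e.g. "layoutlmv3"
    raise KeyError("model_type not available in configs. It seems that the config is not valid")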
@@ -347,19 +363,21 @@ def train_hf_layoutlm(
             name_as_key=True,
         )[LayoutType.word][WordType.token_class]
     else:
-        raise
+        raise UserWarning("Dataset type not supported for training")
 
-    config_cls, model_cls, model_wrapper_cls, tokenizer_fast = _get_model_class_and_tokenizer(
+    config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features = _get_model_class_and_tokenizer(
         path_config_json, dataset_type, use_xlm_tokenizer
     )
-
+    image_to_raw_features_func = get_image_to_raw_features_mapping(model_wrapper_cls.image_to_raw_features_mapping())
+    image_to_raw_features_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
     if segment_positions:
-
-
+        image_to_raw_features_kwargs["segment_positions"] = segment_positions  # type: ignore
+    image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
+
     dataset = DatasetAdapter(
         dataset_train,
         True,
-
+        image_to_raw_features_func(**image_to_raw_features_kwargs),
         use_token_tag,
         **build_train_dict,
     )
@@ -374,9 +392,11 @@ def train_hf_layoutlm(
         "remove_unused_columns": False,
         "per_device_train_batch_size": 8,
         "max_steps": number_samples,
-        "evaluation_strategy": "steps"
-        if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
-        else "no",
+        "evaluation_strategy": (
+            "steps"
+            if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
+            else "no"
+        ),
         "eval_steps": 100,
         "use_wandb": False,
         "wandb_project": None,
@@ -416,7 +436,7 @@ def train_hf_layoutlm(
     run = None
     if use_wandb:
         if not wandb_available():
-            raise
+            raise DependencyError("WandB must be installed separately")
         run = wandb.init(project=wandb_project, config=conf_dict)  # type: ignore
         run._label(repo=wandb_repo)  # type: ignore # pylint: disable=W0212
     else:
@@ -448,6 +468,7 @@ def train_hf_layoutlm(
         return_tensors="pt",
         sliding_window_stride=sliding_window_stride,  # type: ignore
         max_batch_size=max_batch_size,  # type: ignore
+        remove_bounding_box_features=remove_box_features,
     )
     trainer = LayoutLMTrainer(model, arguments, data_collator, dataset)
 
@@ -470,7 +491,8 @@ def train_hf_layoutlm(
         path_config_json=path_config_json,
         path_weights=path_weights,
         categories=categories,
-        device=
+        device=get_torch_device(),
+        use_xlm_tokenizer=use_xlm_tokenizer,
     )
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     if dataset_type == DatasetType.sequence_classification:
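One more pattern from this file worth naming: `get_image_to_raw_features_mapping` resolves the mapper name that a model wrapper advertises via `image_to_raw_features_mapping()` through a plain dict lookup rather than calling `eval()` on the string (hence its docstring "Replacing eval functions"). Schematically, the training setup above now reads (kwargs trimmed, per the diff):

# The wrapper names its feature mapper; the trainer resolves and applies it.
image_to_raw_features_func = get_image_to_raw_features_mapping(
    model_wrapper_cls.image_to_raw_features_mapping()  # e.g. "image_to_raw_layoutlm_features"
)
kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
dataset = DatasetAdapter(dataset_train, True, image_to_raw_features_func(**kwargs), use_token_tag)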
deepdoctection/train/tp_frcnn_train.py
CHANGED

@@ -22,25 +22,7 @@ Module for training Tensorpack `GeneralizedRCNN`
 import os
 from typing import Dict, List, Optional, Sequence, Type, Union
 
-
-from tensorpack.callbacks import (
-    EstimatedTimeLeft,
-    GPUMemoryTracker,
-    GPUUtilizationTracker,
-    HostMemoryTracker,
-    ModelSaver,
-    PeriodicCallback,
-    ScheduledHyperParamSetter,
-    SessionRunTimeout,
-    ThroughputTracker,
-)
-
-# todo: check how dataflow import is directly possible without having AssertionError
-from tensorpack.dataflow import ProxyDataFlow, imgaug
-from tensorpack.input_source import QueueInput
-from tensorpack.tfutils import SmartInit
-from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
-from tensorpack.utils import logger
+from lazy_imports import try_import
 
 from ..dataflow.base import DataFlow
 from ..dataflow.common import MapData
@@ -68,6 +50,26 @@ from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.tqdm import get_tqdm
 from ..utils.utils import string_to_dict
 
+with try_import() as tp_import_guard:
+    # todo: check how dataflow import is directly possible without having an AssertionError
+    # pylint: disable=import-error
+    from tensorpack.callbacks import (
+        EstimatedTimeLeft,
+        GPUMemoryTracker,
+        GPUUtilizationTracker,
+        HostMemoryTracker,
+        ModelSaver,
+        PeriodicCallback,
+        ScheduledHyperParamSetter,
+        SessionRunTimeout,
+        ThroughputTracker,
+    )
+    from tensorpack.dataflow import ProxyDataFlow, imgaug
+    from tensorpack.input_source import QueueInput
+    from tensorpack.tfutils import SmartInit
+    from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
+    from tensorpack.utils import logger
+
 __all__ = ["train_faster_rcnn"]
 
 
deepdoctection/utils/__init__.py
CHANGED

@@ -6,7 +6,10 @@ Init file for utils package
 """
 from typing import Optional, Tuple, Union, no_type_check
 
+from .concurrency import *
 from .context import *
+from .env_info import *
+from .error import *
 from .file_utils import *
 from .fs import *
 from .identifier import *
deepdoctection/utils/concurrency.py
CHANGED

@@ -109,7 +109,7 @@ def enable_death_signal(_warn: bool = True) -> None:
         prctl, "set_pdeathsig"
     ), "prctl.set_pdeathsig does not exist! Note that you need to install 'python-prctl' instead of 'prctl'."
     # is SIGHUP a good choice?
-    prctl.set_pdeathsig(signal.SIGHUP)
+    prctl.set_pdeathsig(signal.SIGHUP)  # pylint: disable=E1101
 
 
 # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
deepdoctection/utils/context.py
CHANGED

@@ -61,7 +61,7 @@ def timeout_manager(proc, seconds: Optional[int] = None) -> Iterator[str]:  # ty
         proc.terminate()
         proc.kill()
         proc.returncode = -1
-        raise RuntimeError("
+        raise RuntimeError(f"timeout for process id: {proc.pid}")  # pylint: disable=W0707
     finally:
         if proc.stdin is not None:
             proc.stdin.close()
@@ -88,7 +88,7 @@ def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[
         yield file.name, path.realpath(path.normpath(path.normcase(image)))
         return
     if isinstance(image, (np.ndarray, np.generic)):
-        input_file_name = file.name + ".PNG"
+        input_file_name = file.name + "_input.PNG"
         viz_handler.write_image(input_file_name, image)
         yield file.name, input_file_name
     if isinstance(image, bytes):