deepdoctection 0.32-py3-none-any.whl → 0.34-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepdoctection/__init__.py +8 -25
- deepdoctection/analyzer/dd.py +84 -71
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +78 -56
- deepdoctection/datapoint/box.py +7 -7
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +157 -75
- deepdoctection/datapoint/view.py +175 -151
- deepdoctection/datasets/adapter.py +30 -24
- deepdoctection/datasets/base.py +10 -10
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +23 -25
- deepdoctection/datasets/instances/doclaynet.py +48 -49
- deepdoctection/datasets/instances/fintabnet.py +44 -45
- deepdoctection/datasets/instances/funsd.py +23 -23
- deepdoctection/datasets/instances/iiitar13k.py +8 -8
- deepdoctection/datasets/instances/layouttest.py +2 -2
- deepdoctection/datasets/instances/publaynet.py +3 -3
- deepdoctection/datasets/instances/pubtables1m.py +18 -18
- deepdoctection/datasets/instances/pubtabnet.py +30 -29
- deepdoctection/datasets/instances/rvlcdip.py +28 -29
- deepdoctection/datasets/instances/xfund.py +51 -30
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +13 -12
- deepdoctection/eval/eval.py +32 -26
- deepdoctection/eval/tedsmetric.py +16 -12
- deepdoctection/eval/tp_eval_callback.py +7 -16
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +69 -89
- deepdoctection/extern/deskew.py +11 -10
- deepdoctection/extern/doctrocr.py +81 -64
- deepdoctection/extern/fastlang.py +23 -16
- deepdoctection/extern/hfdetr.py +53 -38
- deepdoctection/extern/hflayoutlm.py +216 -155
- deepdoctection/extern/hflm.py +35 -30
- deepdoctection/extern/model.py +433 -255
- deepdoctection/extern/pdftext.py +15 -15
- deepdoctection/extern/pt/ptutils.py +4 -2
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +14 -16
- deepdoctection/extern/tp/tfutils.py +16 -2
- deepdoctection/extern/tp/tpcompat.py +11 -7
- deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
- deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
- deepdoctection/extern/tpdetect.py +40 -45
- deepdoctection/mapper/cats.py +36 -40
- deepdoctection/mapper/cocostruct.py +16 -12
- deepdoctection/mapper/d2struct.py +22 -22
- deepdoctection/mapper/hfstruct.py +7 -7
- deepdoctection/mapper/laylmstruct.py +22 -24
- deepdoctection/mapper/maputils.py +9 -10
- deepdoctection/mapper/match.py +33 -2
- deepdoctection/mapper/misc.py +6 -7
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +6 -6
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +3 -3
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/anngen.py +39 -14
- deepdoctection/pipe/base.py +68 -99
- deepdoctection/pipe/common.py +181 -85
- deepdoctection/pipe/concurrency.py +14 -10
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +18 -16
- deepdoctection/pipe/lm.py +49 -47
- deepdoctection/pipe/order.py +63 -65
- deepdoctection/pipe/refine.py +102 -109
- deepdoctection/pipe/segment.py +157 -162
- deepdoctection/pipe/sub_layout.py +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/d2_frcnn_train.py +27 -25
- deepdoctection/train/hf_detr_train.py +22 -18
- deepdoctection/train/hf_layoutlm_train.py +49 -48
- deepdoctection/train/tp_frcnn_train.py +10 -11
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +52 -14
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +41 -14
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +15 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/pdf_utils.py +39 -14
- deepdoctection/utils/settings.py +188 -182
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +70 -69
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
- deepdoctection-0.34.dist-info/RECORD +146 -0
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.32.dist-info/RECORD +0 -146
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
- {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
deepdoctection/train/hf_detr_train.py
CHANGED

```diff
@@ -22,7 +22,8 @@ models that are a slightly different from the plain Detr model that are provided
 from __future__ import annotations
 
 import copy
-from typing import Any, Dict, List, Optional, Sequence, Type, Union
+import os
+from typing import Any, Optional, Sequence, Type, Union
 
 from lazy_imports import try_import
 
@@ -34,9 +35,10 @@ from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.hfdetr import HFDetrDerivedDetector
 from ..mapper.hfstruct import DetrDataCollator, image_to_hf_detr_training
-from ..pipe.base import PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.logger import LoggingRecord, logger
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict
 
 with try_import() as pt_import_guard:
@@ -74,13 +76,13 @@ class DetrDerivedTrainer(Trainer):
         train_dataset: Dataset[Any],
     ):
         self.evaluator: Optional[Evaluator] = None
-        self.build_eval_kwargs: Optional[Dict[str, Any]] = None
+        self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset)
 
     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: PredictorPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[MetricBase], MetricBase],
         **build_eval_kwargs: Union[str, int],
     ) -> None:
@@ -97,17 +99,15 @@ class DetrDerivedTrainer(Trainer):
         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1)
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-
-            assert isinstance(comp.predictor, HFDetrDerivedDetector)
-            comp.predictor.hf_detr_predictor = None
+            comp.clear_predictor()
         self.build_eval_kwargs = build_eval_kwargs
 
     def evaluate(
         self,
         eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
         metric_key_prefix: str = "eval",  # pylint: disable=W0613
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
@@ -129,12 +129,12 @@ class DetrDerivedTrainer(Trainer):
 
 
 def train_hf_detr(
-    path_config_json: str,
+    path_config_json: PathLikeOrStr,
     dataset_train: Union[str, DatasetBase],
-    path_weights: str,
+    path_weights: PathLikeOrStr,
     path_feature_extractor_config_json: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/detr",
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/detr",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -169,13 +169,13 @@ def train_hf_detr(
     :param pipeline_component_name: A pipeline component name to use for validation
     """
 
-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
@@ -191,12 +191,17 @@ def train_hf_detr(
     categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(name_as_key=True, filtered=True)
 
     dataset = DatasetAdapter(
-        dataset_train, True, image_to_hf_detr_training(category_names=categories), True, **build_train_dict
+        dataset_train,
+        True,
+        image_to_hf_detr_training(category_names=categories),
+        True,
+        number_repetitions=-1,
+        **build_train_dict,
     )
 
     number_samples = len(dataset)
     conf_dict = {
-        "output_dir": log_dir,
+        "output_dir": os.fspath(log_dir),
         "remove_unused_columns": False,
         "per_device_train_batch_size": 2,
         "max_steps": number_samples,
@@ -256,7 +261,6 @@ def train_hf_detr(
     )
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
    pipeline_component = pipeline_component_cls(detector)
-    assert isinstance(pipeline_component, PredictorPipelineComponent)
 
     if metric_name is not None:
         metric = metric_registry.get(metric_name)
```
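The annotation changes in this file follow one pattern that repeats across the whole release: with `from __future__ import annotations` (PEP 563) at the top of the module, annotations are stored as strings instead of being evaluated at import time, so the builtin generics `dict[...]`, `list[...]` and `tuple[...]` can replace `typing.Dict`, `typing.List` and `typing.Tuple` even on Python versions older than 3.9. A minimal sketch of the pattern (illustrative, not code from the package):

```python
from __future__ import annotations  # PEP 563: annotations are no longer evaluated at runtime

from typing import Optional


def evaluate(ignore_keys: Optional[list[str]] = None) -> dict[str, float]:
    # list[str] and dict[str, float] would raise a TypeError at import time on
    # Python 3.8 without the __future__ import; as deferred (string) annotations
    # they are accepted everywhere.
    return {"eval/loss": 0.0}
```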
deepdoctection/train/hf_layoutlm_train.py
CHANGED

```diff
@@ -24,7 +24,7 @@ import copy
 import json
 import os
 import pprint
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Optional, Sequence, Type, Union
 
 from lazy_imports import try_import
 
@@ -47,12 +47,13 @@ from ..extern.hflayoutlm import (
 from ..extern.hflm import HFLmSequenceClassifier
 from ..extern.pt.ptutils import get_torch_device
 from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features, image_to_raw_lm_features
-from ..pipe.base import LanguageModelPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.error import DependencyError
 from ..utils.file_utils import wandb_available
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import DatasetType, LayoutType, WordType
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict
 
 with try_import() as pt_import_guard:
@@ -82,7 +83,7 @@ with try_import() as wb_import_guard:
     import wandb
 
 
-def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> Tuple[Any, Any, Any]:
+def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> tuple[Any, Any, Any]:
     """
     Get the model architecture, model wrapper and config class for a given model type and dataset type.
 
@@ -91,47 +92,47 @@ def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetTy
     :return: Tuple of model architecture, model wrapper and config class
     """
     return {
-        ("layoutlm", DatasetType.sequence_classification): (
+        ("layoutlm", DatasetType.SEQUENCE_CLASSIFICATION): (
             LayoutLMForSequenceClassification,
             HFLayoutLmSequenceClassifier,
             PretrainedConfig,
         ),
-        ("layoutlm", DatasetType.token_classification): (
+        ("layoutlm", DatasetType.TOKEN_CLASSIFICATION): (
             LayoutLMForTokenClassification,
             HFLayoutLmTokenClassifier,
             PretrainedConfig,
         ),
-        ("layoutlmv2", DatasetType.sequence_classification): (
+        ("layoutlmv2", DatasetType.SEQUENCE_CLASSIFICATION): (
             LayoutLMv2ForSequenceClassification,
             HFLayoutLmv2SequenceClassifier,
             LayoutLMv2Config,
         ),
-        ("layoutlmv2", DatasetType.token_classification): (
+        ("layoutlmv2", DatasetType.TOKEN_CLASSIFICATION): (
             LayoutLMv2ForTokenClassification,
             HFLayoutLmv2TokenClassifier,
             LayoutLMv2Config,
         ),
-        ("layoutlmv3", DatasetType.sequence_classification): (
+        ("layoutlmv3", DatasetType.SEQUENCE_CLASSIFICATION): (
             LayoutLMv3ForSequenceClassification,
             HFLayoutLmv3SequenceClassifier,
             LayoutLMv3Config,
         ),
-        ("layoutlmv3", DatasetType.token_classification): (
+        ("layoutlmv3", DatasetType.TOKEN_CLASSIFICATION): (
             LayoutLMv3ForTokenClassification,
             HFLayoutLmv3TokenClassifier,
             LayoutLMv3Config,
         ),
-        ("lilt", DatasetType.token_classification): (
+        ("lilt", DatasetType.TOKEN_CLASSIFICATION): (
             LiltForTokenClassification,
             HFLiltTokenClassifier,
             PretrainedConfig,
         ),
-        ("lilt", DatasetType.sequence_classification): (
+        ("lilt", DatasetType.SEQUENCE_CLASSIFICATION): (
             LiltForSequenceClassification,
             HFLiltSequenceClassifier,
             PretrainedConfig,
         ),
-        ("xlm-roberta", DatasetType.sequence_classification): (
+        ("xlm-roberta", DatasetType.SEQUENCE_CLASSIFICATION): (
             XLMRobertaForSequenceClassification,
             HFLmSequenceClassifier,
             PretrainedConfig,
@@ -163,13 +164,13 @@ class LayoutLMTrainer(Trainer):
         train_dataset: Dataset[Any],
     ):
         self.evaluator: Optional[Evaluator] = None
-        self.build_eval_kwargs: Optional[Dict[str, Any]] = None
+        self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset)
 
     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: LanguageModelPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[ClassificationMetric], ClassificationMetric],
         run: Optional[wandb.sdk.wandb_run.Run] = None,
         **build_eval_kwargs: Union[str, int],
@@ -188,15 +189,15 @@ class LayoutLMTrainer(Trainer):
         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-            comp.
+            comp.clear_predictor()
         self.build_eval_kwargs = build_eval_kwargs
 
     def evaluate(
         self,
         eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
         metric_key_prefix: str = "eval",  # pylint: disable=W0613
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
@@ -220,8 +221,8 @@ class LayoutLMTrainer(Trainer):
 
 
 def _get_model_class_and_tokenizer(
-    path_config_json: str, dataset_type: DatasetType, use_xlm_tokenizer: bool
-) -> Tuple[Any, Any, Any, Any, Any]:
+    path_config_json: PathLikeOrStr, dataset_type: DatasetType, use_xlm_tokenizer: bool
+) -> tuple[Any, Any, Any, Any, Any]:
     with open(path_config_json, "r", encoding="UTF-8") as file:
         config_json = json.load(file)
 
@@ -244,11 +245,11 @@ def get_image_to_raw_features_mapping(input_str: str) -> Any:
 
 
 def train_hf_layoutlm(
-    path_config_json: str,
+    path_config_json: PathLikeOrStr,
     dataset_train: Union[str, DatasetBase],
-    path_weights: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/layoutlm",
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/layoutlm",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -323,13 +324,13 @@ def train_hf_layoutlm(
     appear as child, it will use the word bounding box.
     """
 
-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
@@ -343,25 +344,25 @@ def train_hf_layoutlm(
 
     # We wrap our dataset into a torch dataset
     dataset_type = dataset_train.dataset_info.type
-    if dataset_type == DatasetType.sequence_classification:
+    if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
         categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(as_dict=True, name_as_key=True)
-    elif dataset_type == DatasetType.token_classification:
+    elif dataset_type == DatasetType.TOKEN_CLASSIFICATION:
         if use_token_tag:
             categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                categories=LayoutType.word,
-                sub_categories={LayoutType.word: [WordType.token_tag]},
+                categories=LayoutType.WORD,
+                sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]},
                 keys=False,
                 values_as_dict=True,
                 name_as_key=True,
-            )[LayoutType.word][WordType.token_tag]
+            )[LayoutType.WORD][WordType.TOKEN_TAG]
         else:
             categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                categories=LayoutType.word,
-                sub_categories={LayoutType.word: [WordType.token_class]},
+                categories=LayoutType.WORD,
+                sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]},
                 keys=False,
                 values_as_dict=True,
                 name_as_key=True,
-            )[LayoutType.word][WordType.token_class]
+            )[LayoutType.WORD][WordType.TOKEN_CLASS]
     else:
         raise UserWarning("Dataset type not supported for training")
 
@@ -372,13 +373,14 @@ def train_hf_layoutlm(
     image_to_raw_features_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
     if segment_positions:
         image_to_raw_features_kwargs["segment_positions"] = segment_positions  # type: ignore
-    image_to_raw_features_kwargs.update(model_wrapper_cls.
+    image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_image_to_features_mapping())
 
     dataset = DatasetAdapter(
         dataset_train,
         True,
         image_to_raw_features_func(**image_to_raw_features_kwargs),
         use_token_tag,
+        number_repetitions=-1,
         **build_train_dict,
     )
 
@@ -388,7 +390,7 @@ def train_hf_layoutlm(
     # Need to set remove_unused_columns to False, as the DataCollator for column removal will remove some raw features
     # that are necessary for the tokenizer.
     conf_dict = {
-        "output_dir": log_dir,
+        "output_dir": os.fspath(log_dir),
         "remove_unused_columns": False,
         "per_device_train_batch_size": 8,
         "max_steps": number_samples,
@@ -429,16 +431,16 @@ def train_hf_layoutlm(
     )
 
     use_wandb = conf_dict.pop("use_wandb")
-    wandb_project = conf_dict.pop("wandb_project")
-    wandb_repo = conf_dict.pop("wandb_repo")
+    wandb_project = str(conf_dict.pop("wandb_project"))
+    wandb_repo = str(conf_dict.pop("wandb_repo"))
 
     # Initialize Wandb, if necessary
     run = None
     if use_wandb:
         if not wandb_available():
             raise DependencyError("WandB must be installed separately")
-        run = wandb.init(project=wandb_project, config=conf_dict)
-        run._label(repo=wandb_repo)  #
+        run = wandb.init(project=wandb_project, config=conf_dict)
+        run._label(repo=wandb_repo)  # pylint: disable=W0212
     else:
         os.environ["WANDB_DISABLED"] = "True"
 
@@ -474,19 +476,19 @@ def train_hf_layoutlm(
 
     if arguments.evaluation_strategy in (IntervalStrategy.STEPS,):
         assert metric is not None  # silence mypy
-        if dataset_type == DatasetType.sequence_classification:
+        if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
             categories = dataset_val.dataflow.categories.get_categories(filtered=True)  # type: ignore
         else:
             if use_token_tag:
                 categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                    categories=LayoutType.word, sub_categories={LayoutType.word: [WordType.token_tag]}, keys=False
-                )[LayoutType.word][WordType.token_tag]
-                metric.set_categories(category_names=LayoutType.word, sub_category_names={"word": ["token_tag"]})
+                    categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]}, keys=False
+                )[LayoutType.WORD][WordType.TOKEN_TAG]
+                metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_tag"]})
             else:
                 categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                    categories=LayoutType.word, sub_categories={LayoutType.word: [WordType.token_class]}, keys=False
-                )[LayoutType.word][WordType.token_class]
-                metric.set_categories(category_names=LayoutType.word, sub_category_names={"word": ["token_class"]})
+                    categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]}, keys=False
+                )[LayoutType.WORD][WordType.TOKEN_CLASS]
+                metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_class"]})
         dd_model = model_wrapper_cls(
             path_config_json=path_config_json,
             path_weights=path_weights,
@@ -495,7 +497,7 @@ def train_hf_layoutlm(
             use_xlm_tokenizer=use_xlm_tokenizer,
         )
         pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
-        if dataset_type == DatasetType.sequence_classification:
+        if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
             pipeline_component = pipeline_component_cls(tokenizer_fast, dd_model)
         else:
             pipeline_component = pipeline_component_cls(
@@ -504,7 +506,6 @@ def train_hf_layoutlm(
                 use_other_as_default_category=True,
                 sliding_window_stride=sliding_window_stride,
             )
-        assert isinstance(pipeline_component, LanguageModelPipelineComponent)
 
         trainer.setup_evaluator(dataset_val, pipeline_component, metric, run, **build_val_dict)  # type: ignore
 
```
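Beyond the same typing and `PathLikeOrStr` migration, the recurring change in this file is enum-member casing: the lowercase members of 0.32 (`DatasetType.sequence_classification`, `LayoutType.word`, `WordType.token_class`, ...) become uppercase in 0.34. Code selecting a model family would move accordingly; a hypothetical lookup against the mapping shown above (the diff only shows the mapping literal, not how the function indexes it, so treat this as a sketch of the intended usage):

```python
# Hypothetical call; argument values are examples, not taken from the package.
model_cls, wrapper_cls, config_cls = get_model_architectures_and_configs(
    model_type="layoutlmv3",
    dataset_type=DatasetType.TOKEN_CLASSIFICATION,  # was DatasetType.token_classification in 0.32
)
```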
deepdoctection/train/tp_frcnn_train.py
CHANGED

```diff
@@ -20,7 +20,7 @@ Module for training Tensorpack `GeneralizedRCNN`
 """
 
 import os
-from typing import Dict, List, Optional, Sequence, Type, Union
+from typing import Optional, Sequence, Type, Union
 
 from lazy_imports import try_import
 
@@ -40,14 +40,13 @@ from ..extern.tp.tpfrcnn.preproc import anchors_and_labels, augment
 from ..extern.tpdetect import TPFrcnnDetector
 from ..mapper.maputils import LabelSummarizer
 from ..mapper.tpstruct import image_to_tp_frcnn_training
-from ..pipe.base import PredictorPipelineComponent
 from ..pipe.registry import pipeline_component_registry
-from ..utils.detection_types import JsonDict
 from ..utils.file_utils import set_mp_spawn
 from ..utils.fs import get_load_image_func
 from ..utils.logger import log_once
 from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.tqdm import get_tqdm
+from ..utils.types import JsonDict, PathLikeOrStr
 from ..utils.utils import string_to_dict
 
 with try_import() as tp_import_guard:
@@ -185,11 +184,11 @@ def get_train_dataflow(
 
 
 def train_faster_rcnn(
-    path_config_yaml: str,
+    path_config_yaml: PathLikeOrStr,
     dataset_train: DatasetBase,
-    path_weights: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/frcnn",
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/frcnn",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -224,13 +223,13 @@ def train_faster_rcnn(
 
     assert disable_tfv2()  # TP works only in Graph mode
 
-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
@@ -238,7 +237,7 @@ def train_faster_rcnn(
 
     config_overwrite = [] if config_overwrite is None else config_overwrite
 
-    log_dir = "TRAIN.LOG_DIR=" + log_dir
+    log_dir = "TRAIN.LOG_DIR=" + os.fspath(log_dir)
     config_overwrite.append(log_dir)
 
     config = set_config_by_yaml(path_config_yaml)
@@ -299,7 +298,6 @@ def train_faster_rcnn(
     )  # only a wrapper for the predictor itself. Will be replaced in Callback
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     pipeline_component = pipeline_component_cls(detector)
-    assert isinstance(pipeline_component, PredictorPipelineComponent)
     category_names = list(categories.values())
     callbacks.extend(
         [
@@ -310,6 +308,7 @@ def train_faster_rcnn(
                 metric,  # type: ignore
                 pipeline_component,
                 *model.get_inference_tensor_names(),  # type: ignore
+                cfg=detector.model.cfg,
                 **build_val_dict
             )
         ]
```
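As in the two Hugging Face trainers, the path parameters widen from `str` to `PathLikeOrStr`, and `os.fspath` normalizes the value where a plain string is needed (here when appending the `TRAIN.LOG_DIR` override). A call sketch under the 0.34 signature, with hypothetical paths and assuming the package's usual top-level re-exports of `train_faster_rcnn` and `get_dataset`:

```python
from pathlib import Path

import deepdoctection as dd  # assumes the Tensorpack/TF extras are installed

# Paths below are hypothetical; pathlib.Path works in 0.34 because the
# trainer converts with os.fspath() before building "TRAIN.LOG_DIR=...".
dd.train_faster_rcnn(
    path_config_yaml=Path("configs/tp/conf_frcnn_publaynet.yaml"),
    dataset_train=dd.get_dataset("publaynet"),
    path_weights=Path("weights/frcnn/model.data-00000-of-00001"),
    log_dir=Path("train_log/frcnn"),
)
```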
deepdoctection/utils/concurrency.py
CHANGED

```diff
@@ -28,8 +28,8 @@ import threading
 from contextlib import contextmanager
 from typing import Any, Generator, Optional, no_type_check
 
-from .detection_types import QueueType
 from .logger import log_once
+from .types import QueueType
 
 
 # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
```
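This one-line change is representative of most of the release: the `utils.detection_types` module (removed, `+0 -68` in the file list) is superseded by the new `utils.types` (`+104 -0`), and every internal import moves over. Downstream code that imported the old aliases would migrate the same way; a hypothetical example:

```python
# 0.32 — module no longer exists in 0.34:
# from deepdoctection.utils.detection_types import ImageType, JsonDict, QueueType

# 0.34 — the aliases live in utils.types, some renamed along the way
# (the context.py hunks below, for instance, swap ImageType for PixelValues):
from deepdoctection.utils.types import JsonDict, PixelValues, QueueType
```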
deepdoctection/utils/context.py
CHANGED

```diff
@@ -26,12 +26,12 @@ from glob import iglob
 from os import path, remove
 from tempfile import NamedTemporaryFile
 from time import perf_counter as timer
-from typing import Any, Generator, Iterator, Optional, Tuple, Union
+from typing import Any, Generator, Iterator, Optional, Union
 
 import numpy as np
 
-from .detection_types import ImageType
 from .logger import LoggingRecord, logger
+from .types import B64, B64Str, PixelValues
 from .viz import viz_handler
 
 __all__ = ["timeout_manager", "save_tmp_file", "timed_operation"]
@@ -72,7 +72,7 @@ def timeout_manager(proc, seconds: Optional[int] = None) -> Iterator[str]: # ty
 
 
 @contextmanager
-def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[Tuple[str, str]]:
+def save_tmp_file(image: Union[B64Str, PixelValues, B64], prefix: str) -> Iterator[tuple[str, str]]:
     """
     Save image temporarily and handle the clean-up once not necessary anymore
 
@@ -112,13 +112,20 @@ def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[
 @contextmanager
 def timed_operation(message: str, log_start: bool = False) -> Generator[Any, None, None]:
     """
-    Contextmanager with a timer.
+    Contextmanager with a timer.
 
-
+    ... code-block:: python
+
+        with timed_operation(message="Your stdout message", log_start=True):
+
+            with open("log.txt", "a") as file:
+                ...
+
+
+    :param message: a log to stdout
     :param log_start: whether to print also the beginning
     """
 
-    assert len(message)
     if log_start:
         logger.info(LoggingRecord(f"start task: {message} ..."))
     start = timer()
```
deepdoctection/utils/develop.py
CHANGED

```diff
@@ -26,19 +26,19 @@ import functools
 import inspect
 from collections import defaultdict
 from datetime import datetime
-from typing import Callable, List, Optional
+from typing import Callable, Optional
 
-from .detection_types import T
 from .logger import LoggingRecord, logger
+from .types import T
 
-__all__: List[str] = ["deprecated"]
+__all__: list[str] = ["deprecated"]
 
 # Copy and paste from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py
 
 _DEPRECATED_LOG_NUM = defaultdict(int)  # type: ignore
 
 
-def log_deprecated(name: str = "", text: str = "", eos: str = "", max_num_warnings: Optional[int] = None) -> None:
+def log_deprecated(name: str, text: str, eos: str = "", max_num_warnings: Optional[int] = None) -> None:
     """
     Log deprecation warning.
 
```
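With the new signature, `name` and `text` lose their empty-string defaults and become required. A usage sketch with hypothetical values (in the tensorpack original this function is copied from, `eos` is an end-of-service date string):

```python
from deepdoctection.utils.develop import log_deprecated

# name and text are now required; eos and max_num_warnings stay optional.
# All argument values here are hypothetical.
log_deprecated(
    "MyOldHelper",
    "Use MyNewHelper instead.",
    eos="2025-01-01",
    max_num_warnings=3,
)
```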