deepdoctection 0.42.0__py3-none-any.whl → 0.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of deepdoctection might be problematic.
- deepdoctection/__init__.py +2 -1
- deepdoctection/analyzer/__init__.py +2 -1
- deepdoctection/analyzer/config.py +904 -0
- deepdoctection/analyzer/dd.py +36 -62
- deepdoctection/analyzer/factory.py +311 -141
- deepdoctection/configs/conf_dd_one.yaml +100 -44
- deepdoctection/configs/profiles.jsonl +32 -0
- deepdoctection/dataflow/__init__.py +9 -6
- deepdoctection/dataflow/base.py +33 -15
- deepdoctection/dataflow/common.py +96 -75
- deepdoctection/dataflow/custom.py +36 -29
- deepdoctection/dataflow/custom_serialize.py +135 -91
- deepdoctection/dataflow/parallel_map.py +33 -31
- deepdoctection/dataflow/serialize.py +15 -10
- deepdoctection/dataflow/stats.py +41 -28
- deepdoctection/datapoint/__init__.py +4 -6
- deepdoctection/datapoint/annotation.py +104 -66
- deepdoctection/datapoint/box.py +190 -130
- deepdoctection/datapoint/convert.py +66 -39
- deepdoctection/datapoint/image.py +151 -95
- deepdoctection/datapoint/view.py +383 -236
- deepdoctection/datasets/__init__.py +2 -6
- deepdoctection/datasets/adapter.py +11 -11
- deepdoctection/datasets/base.py +118 -81
- deepdoctection/datasets/dataflow_builder.py +18 -12
- deepdoctection/datasets/info.py +76 -57
- deepdoctection/datasets/instances/__init__.py +6 -2
- deepdoctection/datasets/instances/doclaynet.py +17 -14
- deepdoctection/datasets/instances/fintabnet.py +16 -22
- deepdoctection/datasets/instances/funsd.py +11 -6
- deepdoctection/datasets/instances/iiitar13k.py +9 -9
- deepdoctection/datasets/instances/layouttest.py +9 -9
- deepdoctection/datasets/instances/publaynet.py +9 -9
- deepdoctection/datasets/instances/pubtables1m.py +13 -13
- deepdoctection/datasets/instances/pubtabnet.py +13 -15
- deepdoctection/datasets/instances/rvlcdip.py +8 -8
- deepdoctection/datasets/instances/xfund.py +11 -9
- deepdoctection/datasets/registry.py +18 -11
- deepdoctection/datasets/save.py +12 -11
- deepdoctection/eval/__init__.py +3 -2
- deepdoctection/eval/accmetric.py +72 -52
- deepdoctection/eval/base.py +29 -10
- deepdoctection/eval/cocometric.py +14 -12
- deepdoctection/eval/eval.py +56 -41
- deepdoctection/eval/registry.py +6 -3
- deepdoctection/eval/tedsmetric.py +24 -9
- deepdoctection/eval/tp_eval_callback.py +13 -12
- deepdoctection/extern/__init__.py +1 -1
- deepdoctection/extern/base.py +176 -97
- deepdoctection/extern/d2detect.py +127 -92
- deepdoctection/extern/deskew.py +19 -10
- deepdoctection/extern/doctrocr.py +157 -106
- deepdoctection/extern/fastlang.py +25 -17
- deepdoctection/extern/hfdetr.py +137 -60
- deepdoctection/extern/hflayoutlm.py +329 -248
- deepdoctection/extern/hflm.py +67 -33
- deepdoctection/extern/model.py +108 -762
- deepdoctection/extern/pdftext.py +37 -12
- deepdoctection/extern/pt/nms.py +15 -1
- deepdoctection/extern/pt/ptutils.py +13 -9
- deepdoctection/extern/tessocr.py +87 -54
- deepdoctection/extern/texocr.py +29 -14
- deepdoctection/extern/tp/tfutils.py +36 -8
- deepdoctection/extern/tp/tpcompat.py +54 -16
- deepdoctection/extern/tp/tpfrcnn/config/config.py +20 -4
- deepdoctection/extern/tpdetect.py +4 -2
- deepdoctection/mapper/__init__.py +1 -1
- deepdoctection/mapper/cats.py +117 -76
- deepdoctection/mapper/cocostruct.py +35 -17
- deepdoctection/mapper/d2struct.py +56 -29
- deepdoctection/mapper/hfstruct.py +32 -19
- deepdoctection/mapper/laylmstruct.py +221 -185
- deepdoctection/mapper/maputils.py +71 -35
- deepdoctection/mapper/match.py +76 -62
- deepdoctection/mapper/misc.py +68 -44
- deepdoctection/mapper/pascalstruct.py +13 -12
- deepdoctection/mapper/prodigystruct.py +33 -19
- deepdoctection/mapper/pubstruct.py +42 -32
- deepdoctection/mapper/tpstruct.py +39 -19
- deepdoctection/mapper/xfundstruct.py +20 -13
- deepdoctection/pipe/__init__.py +1 -2
- deepdoctection/pipe/anngen.py +104 -62
- deepdoctection/pipe/base.py +226 -107
- deepdoctection/pipe/common.py +206 -123
- deepdoctection/pipe/concurrency.py +74 -47
- deepdoctection/pipe/doctectionpipe.py +108 -47
- deepdoctection/pipe/language.py +41 -24
- deepdoctection/pipe/layout.py +45 -18
- deepdoctection/pipe/lm.py +146 -78
- deepdoctection/pipe/order.py +196 -113
- deepdoctection/pipe/refine.py +111 -63
- deepdoctection/pipe/registry.py +1 -1
- deepdoctection/pipe/segment.py +213 -142
- deepdoctection/pipe/sub_layout.py +76 -46
- deepdoctection/pipe/text.py +52 -33
- deepdoctection/pipe/transform.py +8 -6
- deepdoctection/train/d2_frcnn_train.py +87 -69
- deepdoctection/train/hf_detr_train.py +72 -40
- deepdoctection/train/hf_layoutlm_train.py +85 -46
- deepdoctection/train/tp_frcnn_train.py +56 -28
- deepdoctection/utils/concurrency.py +59 -16
- deepdoctection/utils/context.py +40 -19
- deepdoctection/utils/develop.py +25 -17
- deepdoctection/utils/env_info.py +85 -36
- deepdoctection/utils/error.py +16 -10
- deepdoctection/utils/file_utils.py +246 -62
- deepdoctection/utils/fs.py +162 -43
- deepdoctection/utils/identifier.py +29 -16
- deepdoctection/utils/logger.py +49 -32
- deepdoctection/utils/metacfg.py +83 -21
- deepdoctection/utils/pdf_utils.py +119 -62
- deepdoctection/utils/settings.py +24 -10
- deepdoctection/utils/tqdm.py +10 -5
- deepdoctection/utils/transform.py +182 -46
- deepdoctection/utils/utils.py +61 -28
- deepdoctection/utils/viz.py +150 -104
- deepdoctection-0.43.dist-info/METADATA +376 -0
- deepdoctection-0.43.dist-info/RECORD +149 -0
- {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/WHEEL +1 -1
- deepdoctection/analyzer/_config.py +0 -146
- deepdoctection-0.42.0.dist-info/METADATA +0 -431
- deepdoctection-0.42.0.dist-info/RECORD +0 -148
- {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/licenses/LICENSE +0 -0
- {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/top_level.txt +0 -0
deepdoctection/train/hf_detr_train.py

@@ -16,8 +16,7 @@
 # limitations under the License.

 """
-
-models that are a slightly different from the plain Detr model that are provided by the transformer library.
+Fine-tuning Hugging Face Detr implementation.
 """
 from __future__ import annotations

@@ -50,6 +49,7 @@ with try_import() as pt_import_guard:
 with try_import() as hf_import_guard:
     from transformers import (
         AutoFeatureExtractor,
+        DeformableDetrForObjectDetection,
         IntervalStrategy,
         PretrainedConfig,
         PreTrainedModel,

@@ -65,12 +65,11 @@ with try_import() as wb_import_guard:
 class DetrDerivedTrainer(Trainer):
     """
     Huggingface Trainer for training Transformer models with a custom evaluate method in order
-    to use dd Evaluator.
-    defined in `TrainingArguments`. Please check the Transformer documentation
+    to use dd Evaluator.

-
-
-
+    Train setting is not defined in the trainer itself but in config setting as defined in `TrainingArguments`.
+    Please check the Transformer documentation: https://huggingface.co/docs/transformers/main_classes/trainer for
+    custom training setting.
     """

     def __init__(

@@ -81,6 +80,16 @@ class DetrDerivedTrainer(Trainer):
         train_dataset: DatasetAdapter,
         eval_dataset: Optional[DatasetBase] = None,
     ):
+        """
+        Initializes `DetrDerivedTrainer`.
+
+        Args:
+            model: Model to be trained, either `PreTrainedModel` or `nn.Module`.
+            args: Training arguments.
+            data_collator: Data collator for Detr.
+            train_dataset: Training dataset.
+            eval_dataset: Optional evaluation dataset.
+        """
         self.evaluator: Optional[Evaluator] = None
         self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset, eval_dataset=eval_dataset)

@@ -94,14 +103,16 @@ class DetrDerivedTrainer(Trainer):
         **build_eval_kwargs: Union[str, int],
     ) -> None:
         """
-        Setup of evaluator before starting training.
-
-
-
-        :
-
-
-
+        Setup of evaluator before starting training.
+
+        During training, predictors will be replaced by current checkpoints.
+
+        Args:
+            dataset_val: Dataset on which to run evaluation.
+            pipeline_component: Pipeline component to plug into the evaluator.
+            metric: A metric class.
+            run: WandB run.
+            **build_eval_kwargs: Additional keyword arguments for evaluation.
         """

         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)
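The evaluator has to be wired in before `train()` is called. A hypothetical call pattern follows; every object below is a placeholder for a real training setup, and only the method names and argument order come from this diff.

```python
model = ...               # PreTrainedModel or nn.Module (placeholder)
training_args = ...       # transformers TrainingArguments (placeholder)
data_collator = ...       # Detr data collator (placeholder)
train_dataset = ...       # DatasetAdapter (placeholder)
dataset_val = ...         # DatasetBase used for evaluation (placeholder)
pipeline_component = ...  # pipeline component plugged into the evaluator (placeholder)
metric = ...              # metric class (placeholder)

trainer = DetrDerivedTrainer(model, training_args, data_collator, train_dataset, eval_dataset=dataset_val)
# max_datapoints is an assumed example of a **build_eval_kwargs entry
trainer.setup_evaluator(dataset_val, pipeline_component, metric, run=None, max_datapoints=500)
trainer.train()
```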
@@ -152,29 +163,32 @@ def train_hf_detr(
 ) -> None:
     """
     Train Tabletransformer from scratch or fine-tune using an adaptation of the transformer trainer.
+
     Allowing experiments by using different config settings.

-    :
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    Args:
+        path_config_json: Path to a Tabletransformer config file.
+        dataset_train: Dataset to use for training.
+        path_weights: Path to a checkpoint, if you want to resume training or fine-tune. Will train from scratch if an
+            empty string is passed.
+        path_feature_extractor_config_json: Path to a feature extractor config file. In many situations you can use the
+            standard config file:
+            Example:
+                ```python
+                ModelCatalog.get_full_path_preprocessor_configs
+                ("microsoft/table-transformer-detection/pytorch_model.bin")
+                ```
+
+        config_overwrite: Pass a list of arguments if some configs from the .json file are supposed to be replaced.
+            Use the list convention, e.g. `['per_device_train_batch_size=4']`.
+        log_dir: Will default to `train_log/detr`.
+        build_train_config: Dataflow build setting. Again, use list convention setting, e.g. `['max_datapoints=1000']`.
+        dataset_val: The dataset to use for validation.
+        build_val_config: Same as `build_train_config` but for dataflow validation.
+        metric_name: A metric name to choose for validation. Will use the default setting.
+            If you want a custom metric setting, pass a metric explicitly.
+        metric: A metric to choose for validation.
+        pipeline_component_name: A pipeline component name to use for validation.
     """

     build_train_dict: dict[str, str] = {}
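Putting the documented arguments together, a hypothetical `train_hf_detr` invocation could look as below. Only the keyword names come from the docstring above; the import location, dataset objects, metric name and pipeline component name are assumptions.

```python
from deepdoctection.train import train_hf_detr  # assumed import location

my_dataset_train = ...  # a deepdoctection object-detection dataset (placeholder)
my_dataset_val = ...    # placeholder validation dataset

train_hf_detr(
    path_config_json="path/to/config.json",
    dataset_train=my_dataset_train,
    path_weights="",  # empty string trains from scratch, per the docstring
    path_feature_extractor_config_json="path/to/preprocessor_config.json",
    config_overwrite=["per_device_train_batch_size=4"],  # list convention
    build_train_config=["max_datapoints=1000"],          # dataflow build setting
    dataset_val=my_dataset_val,
    metric_name="coco",                                  # assumed registered metric name
    pipeline_component_name="ImageLayoutService",        # assumed component name
)
```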
@@ -275,11 +289,29 @@ def train_hf_detr(
         config.use_timm_backbone = True

     if path_weights != "":
-
-
-
+        if "TableTransformerForObjectDetection" in config.architectures:
+            model = TableTransformerForObjectDetection.from_pretrained(
+                pretrained_model_name_or_path=path_weights, config=config, ignore_mismatched_sizes=True
+            )
+        elif "DeformableDetrForObjectDetection" in config.architectures:
+            return DeformableDetrForObjectDetection.from_pretrained(
+                pretrained_model_name_or_path=os.fspath(path_weights), config=config
+            )
+        else:
+            raise ValueError(
+                f"Model architecture {config.architectures} not eligible. Please use either "
+                "TableTransformerForObjectDetection or DeformableDetrForObjectDetection."
+            )
     else:
-
+        if "TableTransformerForObjectDetection" in config.architectures:
+            model = TableTransformerForObjectDetection(config)
+        elif "DeformableDetrForObjectDetection" in config.architectures:
+            model = DeformableDetrForObjectDetection(config)
+        else:
+            raise ValueError(
+                f"Model architecture {config.architectures} not eligible. Please use either "
+                "TableTransformerForObjectDetection or DeformableDetrForObjectDetection."
+            )

     feature_extractor = AutoFeatureExtractor.from_pretrained(
         pretrained_model_name_or_path=path_feature_extractor_config_json
deepdoctection/train/hf_layoutlm_train.py

@@ -16,7 +16,10 @@
 # limitations under the License.

 """
-
+Fine-tuning Huggingface implementation of LayoutLm.
+
+This module provides functions and classes for fine-tuning LayoutLM models for sequence or token classification using
+the Huggingface Trainer and custom evaluation. It supports LayoutLM, LayoutLMv2, LayoutLMv3, and LayoutXLM models.
 """
 from __future__ import annotations

@@ -85,11 +88,14 @@ with try_import() as wb_import_guard:

 def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> tuple[Any, Any, Any]:
     """
-
+    Gets the model architecture, model wrapper, and config class for a given `model_type` and `dataset_type`.
+
+    Args:
+        model_type: The model type.
+        dataset_type: The dataset type.

-    :
-
-    :return: Tuple of model architecture, model wrapper and config class
+    Returns:
+        Tuple of model architecture, model wrapper, and config class.
     """
     return {
         ("layoutlm", DatasetType.SEQUENCE_CLASSIFICATION): (

@@ -141,19 +147,28 @@ def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> tuple[Any, Any, Any]:


 def maybe_remove_bounding_box_features(model_type: str) -> bool:
-    """
+    """
+    Lists models that do not need bounding box features.
+
+    Args:
+        model_type: The model type.
+
+    Returns:
+        Whether the model does not need bounding box features.
+    """
     return {"xlm-roberta": True}.get(model_type, False)


 class LayoutLMTrainer(Trainer):
     """
-    Huggingface Trainer for training Transformer models with a custom evaluate method
-
-    defined in `TrainingArguments`. Please check the Transformer documentation
+    Huggingface Trainer for training Transformer models with a custom evaluate method to use the Deepdoctection
+    Evaluator.

-
+    Train settings are not defined in the trainer itself but in the config setting as defined in `TrainingArguments`.
+    Please check the Transformer documentation for custom training settings.

-
+    Info:
+        https://huggingface.co/docs/transformers/main_classes/trainer
     """

     def __init__(

@@ -164,6 +179,16 @@ class LayoutLMTrainer(Trainer):
         train_dataset: DatasetAdapter,
         eval_dataset: Optional[DatasetBase] = None,
     ):
+        """
+        Initializes the `LayoutLMTrainer`.
+
+        Args:
+            model: The model to train.
+            args: Training arguments.
+            data_collator: Data collator for batching.
+            train_dataset: Training dataset.
+            eval_dataset: Optional evaluation dataset.
+        """
         self.evaluator: Optional[Evaluator] = None
         self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset, eval_dataset=eval_dataset)

@@ -177,14 +202,15 @@ class LayoutLMTrainer(Trainer):
         **build_eval_kwargs: Union[str, int],
     ) -> None:
         """
-
+        Sets up the evaluator before starting training. During training, predictors will be replaced by current
         checkpoints.

-        :
-
-
-
-
+        Args:
+            dataset_val: Dataset on which to run evaluation.
+            pipeline_component: Pipeline component to plug into the evaluator.
+            metric: A metric class.
+            run: WandB run.
+            **build_eval_kwargs: Additional keyword arguments for evaluation.
         """

         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)

@@ -201,6 +227,14 @@ class LayoutLMTrainer(Trainer):
     ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
+
+        Args:
+            eval_dataset: Not used.
+            ignore_keys: Not used.
+            metric_key_prefix: Not used.
+
+        Returns:
+            Evaluation scores as a dictionary.
         """
         if self.evaluator is None:
             raise ValueError("Evaluator not set up. Please use `setup_evaluator` before running evaluation")

@@ -266,28 +300,32 @@ def train_hf_layoutlm(
     LayoutXLM. Training similar but different models like LILT <https://arxiv.org/abs/2202.13669> can be done by
     changing a few lines of code regarding the selection of the tokenizer.

-
-
-    <https://arxiv.org/abs/1912.13318>
+    Info:
+        The theoretical foundation can be taken from <https://arxiv.org/abs/1912.13318>.

-
+    This is not the pre-training script.

     In order to remain within the framework of this library, the base and uncased LayoutLM model must be downloaded
     from the HF-hub in a first step for fine-tuning. Models are available for this, which are registered in the
     ModelCatalog. It is possible to choose one of the following options:

-    "microsoft/layoutlm-base-uncased/pytorch_model.bin"
-    "microsoft/layoutlmv2-base-uncased/pytorch_model.bin"
-    "microsoft/layoutxlm-base/pytorch_model.bin"
-    "microsoft/layoutlmv3-base/pytorch_model.bin"

-
+    `microsoft/layoutlm-base-uncased/pytorch_model.bin`
+    `microsoft/layoutlmv2-base-uncased/pytorch_model.bin`
+    `microsoft/layoutxlm-base/pytorch_model.bin`
+    `microsoft/layoutlmv3-base/pytorch_model.bin`
+    `microsoft/layoutlm-large-uncased/pytorch_model.bin`
+    `SCUT-DLVCLab/lilt-roberta-en-base/pytorch_model.bin`

-    "microsoft/layoutlm-large-uncased/pytorch_model.bin"

-
+    Note:
+        You can also choose the large versions of LayoutLMv2 and LayoutXLM but you need to organize the download
+        yourself.

+    Example:
+        ```python
         ModelDownloadManager.maybe_download_weights_and_configs("microsoft/layoutlm-base-uncased/pytorch_model.bin")
+        ```

     The corresponding cased models are currently not available, but this is only to keep the model selection small.

@@ -296,30 +334,31 @@ def train_hf_layoutlm(
     How does the model selection work?

     The base model is selected by the transferred config file and the weights. Depending on the dataset type
-    ("SEQUENCE_CLASSIFICATION" or "TOKEN_CLASSIFICATION")
-    top layer on the base model.
+    `("SEQUENCE_CLASSIFICATION" or "TOKEN_CLASSIFICATION")`, the complete model is then put together by placing a
+    suitable top layer on the base model.

-    :
-
-
+    Args:
+        path_config_json: Absolute path to HF config file, e.g.
+            `ModelCatalog.get_full_path_configs("microsoft/layoutlm-base-uncased/pytorch_model.bin")`
+        dataset_train: Dataset to use for training. Only datasets of type "SEQUENCE_CLASSIFICATION" or
             "TOKEN_CLASSIFICATION" are supported.
-
-
-    https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
+        path_weights: path to a checkpoint for further fine-tuning
+        config_overwrite: Pass a list of arguments if some configs from `TrainingArguments` should be replaced. Check
+            <https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments>
             for the full training default setting.
-
-
-
-
-
-
+        log_dir: Path to log dir. Will default to `train_log/layoutlm`
+        build_train_config: dataflow build setting. Again, use list convention setting, e.g. `['max_datapoints=1000']`
+        dataset_val: Dataset to use for validation. Dataset type must be the same as type of `dataset_train`
+        build_val_config: same as `build_train_config` but for validation
+        metric: A metric to choose for validation.
+        pipeline_component_name: A pipeline component name to use for validation (e.g. `LMSequenceClassifierService` or
             LMTokenClassifierService.
-
-    between Layoutlmv2 and
-
+        use_xlm_tokenizer: This is only necessary if you pass weights of LayoutXLM. The config cannot distinguish
+            between Layoutlmv2 and LayoutXLM, so you need to pass this info explicitly.
+        use_token_tag: Will only be used for `dataset_type="token_classification"`. If `use_token_tag=True`, will use
            labels from sub category `WordType.token_tag` (with `B,I,O` suffix), otherwise
            `WordType.token_class`.
-
+        segment_positions: Using bounding boxes of segment instead of words improves model accuracy significantly.
            Choose a single or a sequence of layout segments to use their bounding boxes. Note, that
            the layout segments need to have a child-relationship with words. If a word does not
            appear as child, it will use the word bounding box.
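Combining the download helpers named in the docstring with the `Args` above, a hypothetical fine-tuning call might look like this. The dataset objects are placeholders and the import paths are assumptions, not taken from this diff.

```python
from deepdoctection.extern.model import ModelCatalog, ModelDownloadManager  # assumed locations
from deepdoctection.train import train_hf_layoutlm  # assumed import location

weights = "microsoft/layoutlm-base-uncased/pytorch_model.bin"
path_weights = ModelDownloadManager.maybe_download_weights_and_configs(weights)
path_config = ModelCatalog.get_full_path_configs(weights)

my_dataset_train = ...  # TOKEN_CLASSIFICATION dataset (placeholder)
my_dataset_val = ...    # placeholder, must be the same dataset type

train_hf_layoutlm(
    path_config_json=path_config,
    dataset_train=my_dataset_train,
    path_weights=path_weights,
    config_overwrite=["max_steps=2000"],  # TrainingArguments override, list convention
    dataset_val=my_dataset_val,
    pipeline_component_name="LMTokenClassifierService",
    use_token_tag=False,  # use labels from WordType.token_class
)
```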
deepdoctection/train/tp_frcnn_train.py

@@ -16,7 +16,7 @@
 # limitations under the License.

 """
-
+Training Tensorpack's `GeneralizedRCNN`
 """

 import os

@@ -75,6 +75,9 @@ __all__ = ["train_faster_rcnn"]
 class LoadAugmentAddAnchors:
     """
     A helper class for default mapping `load_augment_add_anchors`.
+
+    Args:
+        config: An `AttrDict` configuration for TP FRCNN.
     """

     def __init__(self, config: AttrDict) -> None:
@@ -89,9 +92,15 @@ def load_augment_add_anchors(dp: JsonDict, config: AttrDict) -> Optional[JsonDict]:
     Transforming an image before entering the graph. This function bundles all the necessary steps to feed
     the network for training.

-    :
-
-
+    Args:
+        dp: A dict with `file_name`, `gt_boxes`, `gt_labels` and optional `image`.
+        config: An `AttrDict` with a TP frcnn config.
+
+    Returns:
+        A dict with all necessary keys for feeding the graph.
+
+    Note:
+        If `image` is not in `dp`, it will be loaded from `file_name`.
     """
     cfg = config
     if "image" not in dp:
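For illustration, here is a minimal datapoint carrying the keys named in the new docstring; the shapes, dtypes and box format are guesses, not taken from this diff.

```python
import numpy as np

# Hypothetical datapoint: if "image" is absent, it is loaded from "file_name".
dp = {
    "file_name": "/path/to/page_0.png",
    "gt_boxes": np.array([[10.0, 10.0, 120.0, 60.0]], dtype="float32"),  # one box (format assumed)
    "gt_labels": np.array([1], dtype="int64"),
}
```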
@@ -124,14 +133,20 @@ def get_train_dataflow(
     dataset: DatasetBase, config: AttrDict, use_multi_proc_for_train: bool, **build_train_kwargs: str
 ) -> DataFlow:
     """
-    Return a dataflow for training TP
+    Return a dataflow for training TP FRCNN. The returned dataflow depends on the dataset and the configuration of
     the model, as the augmentation is part of the data preparation.

-    :
-
-
-
-
+    Args:
+        dataset: A dataset for object detection.
+        config: An `AttrDict` with a TP FRCNN config.
+        use_multi_proc_for_train: If set to `True` will use multi processes for augmenting.
+        build_train_kwargs: Build configuration of the dataflow.
+
+    Returns:
+        A dataflow.
+
+    Note:
+        If `use_multi_proc_for_train` is `True`, multi-processing will be used for augmentation.
     """

     set_mp_spawn()

@@ -202,23 +217,35 @@ def train_faster_rcnn(
     Train Faster-RCNN from Scratch or fine-tune a model using Tensorpack's training API. Observe the training with
     Tensorpack callbacks and evaluate the training progress with a validation data set after certain training intervals.

-
-
-
-
-    :
-
-
-
-
-
-
-
-
-
-
-
-
+    Info:
+        Tensorpack provides a training API under TF1. Training runs under a TF2 installation if TF2 behavior is
+        deactivated.
+
+    Args:
+        path_config_yaml: Path to TP config file. Check the `deepdoctection.extern.tp.tpfrcnn.config.config` for various
+            settings.
+        dataset_train: The dataset to use for training.
+        path_weights: Path to a checkpoint, if you want to continue training or fine-tune. Will train from scratch if
+            nothing is passed.
+        config_overwrite: Pass a list of arguments if some configs from the .yaml file should be replaced. Use the list
+            convention, e.g. `['TRAIN.STEPS_PER_EPOCH=500', 'OUTPUT.RESULT_SCORE_THRESH=0.4']`.
+        log_dir: Path to log dir. Will default to `TRAIN.LOG_DIR`.
+        build_train_config: Dataflow build setting. Use list convention setting, e.g. `['max_datapoints=1000']`.
+        dataset_val: The dataset to use for validation.
+        build_val_config: Same as `build_train_config` but for validation.
+        metric_name: A metric name to choose for validation. Will use the default setting. If you want a custom metric
+            setting pass a metric explicitly.
+        metric: A metric to choose for validation.
+        pipeline_component_name: A pipeline component to use for validation.
+
+    Example:
+        ```python
+        train_faster_rcnn(
+            path_config_yaml="config.yaml",
+            dataset_train=my_train_dataset,
+            path_weights="weights.ckpt"
+        )
+        ```
     """

     assert disable_tfv2()  # TP works only in Graph mode
@@ -241,9 +268,10 @@ def train_faster_rcnn(
         config_overwrite.append(log_dir)

     config = set_config_by_yaml(path_config_yaml)
-
+    config.freeze(False)
     if config_overwrite:
         config.update_args(config_overwrite)
+    config.freeze(True)

     categories = dataset_train.dataflow.categories.get_categories(filtered=True)
     model_frcnn_config(config, categories, False)
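The `config.freeze(False)` / `config.freeze(True)` bracket around `update_args` exists because a frozen config rejects attribute writes. Here is a self-contained sketch of that semantics (a toy class, not deepdoctection's actual `AttrDict`):

```python
class FrozenConfig:
    """Toy config with freeze semantics, illustrating the thaw-update-refreeze pattern."""

    def __init__(self) -> None:
        object.__setattr__(self, "_frozen", False)

    def freeze(self, frozen: bool = True) -> None:
        object.__setattr__(self, "_frozen", frozen)

    def __setattr__(self, key: str, value: object) -> None:
        if self._frozen:
            raise AttributeError(f"Config is frozen, cannot set {key!r}")
        object.__setattr__(self, key, value)

cfg = FrozenConfig()
cfg.freeze(False)
cfg.LR = 0.01   # allowed while thawed
cfg.freeze(True)
# cfg.LR = 0.02 would now raise AttributeError
```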
deepdoctection/utils/concurrency.py

@@ -16,7 +16,7 @@
 # limitations under the License.

 """
-
+Functions for multi/threading purposes
 """

 import multiprocessing as mp

@@ -35,12 +35,17 @@ from .types import QueueType
 # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
 class StoppableThread(threading.Thread):
     """
-    A thread that has a
+    A thread that has a `stop` event.
+
+    This class extends `threading.Thread` and provides a mechanism to stop the thread gracefully.
     """

     def __init__(self, evt: Optional[threading.Event] = None) -> None:
         """
-
+        Initializes a `StoppableThread`.
+
+        Args:
+            evt: An optional `threading.Event`. If `None`, a new event will be created.
         """
         super().__init__()
         if evt is None:

@@ -48,17 +53,30 @@ class StoppableThread(threading.Thread):
             self._stop_evt = evt

     def stop(self) -> None:
-        """
+        """
+        Stop the thread.
+
+        Sets the internal stop event, signaling the thread to stop.
+        """
         self._stop_evt.set()

     def stopped(self) -> bool:
         """
-
+        Check whether the thread is stopped.
+
+        Returns:
+            Whether the thread is stopped or not.
         """
         return self._stop_evt.is_set()

     def queue_put_stoppable(self, q: QueueType, obj: Any) -> None:
-        """
+        """
+        Put `obj` to queue `q`, but will give up when the thread is stopped.
+
+        Args:
+            q: The queue to put the object into.
+            obj: The object to put into the queue.
+        """
         while not self.stopped():
             try:
                 q.put(obj, timeout=5)

@@ -67,7 +85,15 @@ class StoppableThread(threading.Thread):
                 pass

     def queue_get_stoppable(self, q: QueueType) -> Any:
-        """
+        """
+        Take an object from queue `q`, but will give up when the thread is stopped.
+
+        Args:
+            q: The queue to get the object from.
+
+        Returns:
+            The object taken from the queue.
+        """
         while not self.stopped():
             try:
                 return q.get(timeout=5)
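A usage sketch of `StoppableThread` as documented above, assuming the class is importable from `deepdoctection.utils.concurrency`; the worker logic is invented for illustration.

```python
import queue

from deepdoctection.utils.concurrency import StoppableThread  # assumed import

class DoublingWorker(StoppableThread):
    """Doubles integers from `in_q` and pushes results to `out_q` until stopped."""

    def __init__(self, in_q: queue.Queue, out_q: queue.Queue) -> None:
        super().__init__()
        self.in_q = in_q
        self.out_q = out_q

    def run(self) -> None:
        while not self.stopped():
            item = self.queue_get_stoppable(self.in_q)
            if self.stopped():
                break  # the get gave up because stop() was called
            self.queue_put_stoppable(self.out_q, item * 2)

in_q: queue.Queue = queue.Queue()
out_q: queue.Queue = queue.Queue()
worker = DoublingWorker(in_q, out_q)
worker.start()
in_q.put(21)
print(out_q.get())  # 42
worker.stop()       # sets the event; blocked queue calls give up within their timeout
worker.join()
```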
@@ -77,9 +103,14 @@ class StoppableThread(threading.Thread):

 @contextmanager
 def mask_sigint() -> Generator[Any, None, None]:
-    """
-
-
+    """
+    Context manager to mask `SIGINT`.
+
+    If called in the main thread, returns a context where `SIGINT` is ignored, and yields `True`. Otherwise, yields
+    `False`.
+
+    Yields:
+        `True` if called in the main thread, otherwise `False`.
     """
     if threading.current_thread() == threading.main_thread():
         sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)

@@ -91,9 +122,15 @@ def mask_sigint() -> Generator[Any, None, None]:

 def enable_death_signal(_warn: bool = True) -> None:
     """
-    Set the "death signal" of the current process
-
-
+    Set the "death signal" of the current process.
+
+    Ensures that the current process will be cleaned up if the parent process dies accidentally.
+
+    Args:
+        _warn: If `True`, logs a warning if `prctl` is not available.
+
+    Note:
+        Only works on Linux systems. Requires the `python-prctl` package.
     """
     if platform.system() != "Linux":
         return
@@ -118,11 +155,17 @@ def enable_death_signal(_warn: bool = True) -> None:
 @no_type_check
 def start_proc_mask_signal(proc):
     """
-    Start process(es) with SIGINT ignored.
+    Start process(es) with `SIGINT` ignored.
+
+    The signal mask is only applied when called from the main thread.

-    :
+    Note:
+        Starting a process with the 'fork' method is efficient but not safe and may cause deadlock or crash.
+        Use 'forkserver' or 'spawn' method instead if you run into such issues.
+        See <https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods> on how to set them.

-
+    Args:
+        proc: A `mp.Process` or a list of `mp.Process` instances.
     """
     if not isinstance(proc, list):
         proc = [proc]
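Finally, a hypothetical use of `start_proc_mask_signal` that follows the 'spawn' advice from the new note; the worker function is a placeholder and the import location is assumed.

```python
import multiprocessing as mp

from deepdoctection.utils.concurrency import start_proc_mask_signal  # assumed import

def work(n: int) -> None:
    print(f"worker got {n}")

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # safer than 'fork', per the docstring note
    procs = [ctx.Process(target=work, args=(i,)) for i in range(2)]
    start_proc_mask_signal(procs)  # starts them with SIGINT masked in the main thread
    for p in procs:
        p.join()
```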