deepdoctection 0.31-py3-none-any.whl → 0.33-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release: this version of deepdoctection might be problematic.
- deepdoctection/__init__.py +16 -29
- deepdoctection/analyzer/dd.py +70 -59
- deepdoctection/configs/conf_dd_one.yaml +34 -31
- deepdoctection/dataflow/common.py +9 -5
- deepdoctection/dataflow/custom.py +5 -5
- deepdoctection/dataflow/custom_serialize.py +75 -18
- deepdoctection/dataflow/parallel_map.py +3 -3
- deepdoctection/dataflow/serialize.py +4 -4
- deepdoctection/dataflow/stats.py +3 -3
- deepdoctection/datapoint/annotation.py +41 -56
- deepdoctection/datapoint/box.py +9 -8
- deepdoctection/datapoint/convert.py +6 -6
- deepdoctection/datapoint/image.py +56 -44
- deepdoctection/datapoint/view.py +245 -150
- deepdoctection/datasets/__init__.py +1 -4
- deepdoctection/datasets/adapter.py +35 -26
- deepdoctection/datasets/base.py +14 -12
- deepdoctection/datasets/dataflow_builder.py +3 -3
- deepdoctection/datasets/info.py +24 -26
- deepdoctection/datasets/instances/doclaynet.py +51 -51
- deepdoctection/datasets/instances/fintabnet.py +46 -46
- deepdoctection/datasets/instances/funsd.py +25 -24
- deepdoctection/datasets/instances/iiitar13k.py +13 -10
- deepdoctection/datasets/instances/layouttest.py +4 -3
- deepdoctection/datasets/instances/publaynet.py +5 -5
- deepdoctection/datasets/instances/pubtables1m.py +24 -21
- deepdoctection/datasets/instances/pubtabnet.py +32 -30
- deepdoctection/datasets/instances/rvlcdip.py +30 -30
- deepdoctection/datasets/instances/xfund.py +26 -26
- deepdoctection/datasets/save.py +6 -6
- deepdoctection/eval/__init__.py +1 -4
- deepdoctection/eval/accmetric.py +32 -33
- deepdoctection/eval/base.py +8 -9
- deepdoctection/eval/cocometric.py +15 -13
- deepdoctection/eval/eval.py +41 -37
- deepdoctection/eval/tedsmetric.py +30 -23
- deepdoctection/eval/tp_eval_callback.py +16 -19
- deepdoctection/extern/__init__.py +2 -7
- deepdoctection/extern/base.py +339 -134
- deepdoctection/extern/d2detect.py +85 -113
- deepdoctection/extern/deskew.py +14 -11
- deepdoctection/extern/doctrocr.py +141 -130
- deepdoctection/extern/fastlang.py +27 -18
- deepdoctection/extern/hfdetr.py +71 -62
- deepdoctection/extern/hflayoutlm.py +504 -211
- deepdoctection/extern/hflm.py +230 -0
- deepdoctection/extern/model.py +488 -302
- deepdoctection/extern/pdftext.py +23 -19
- deepdoctection/extern/pt/__init__.py +1 -3
- deepdoctection/extern/pt/nms.py +6 -2
- deepdoctection/extern/pt/ptutils.py +29 -19
- deepdoctection/extern/tessocr.py +39 -38
- deepdoctection/extern/texocr.py +18 -18
- deepdoctection/extern/tp/tfutils.py +57 -9
- deepdoctection/extern/tp/tpcompat.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
- deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
- deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
- deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
- deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
- deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
- deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
- deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
- deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
- deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
- deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
- deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
- deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
- deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
- deepdoctection/extern/tpdetect.py +45 -53
- deepdoctection/mapper/__init__.py +3 -8
- deepdoctection/mapper/cats.py +27 -29
- deepdoctection/mapper/cocostruct.py +10 -10
- deepdoctection/mapper/d2struct.py +27 -26
- deepdoctection/mapper/hfstruct.py +13 -8
- deepdoctection/mapper/laylmstruct.py +178 -37
- deepdoctection/mapper/maputils.py +12 -11
- deepdoctection/mapper/match.py +2 -2
- deepdoctection/mapper/misc.py +11 -9
- deepdoctection/mapper/pascalstruct.py +4 -4
- deepdoctection/mapper/prodigystruct.py +5 -5
- deepdoctection/mapper/pubstruct.py +84 -92
- deepdoctection/mapper/tpstruct.py +5 -5
- deepdoctection/mapper/xfundstruct.py +33 -33
- deepdoctection/pipe/__init__.py +1 -1
- deepdoctection/pipe/anngen.py +12 -14
- deepdoctection/pipe/base.py +52 -106
- deepdoctection/pipe/common.py +72 -59
- deepdoctection/pipe/concurrency.py +16 -11
- deepdoctection/pipe/doctectionpipe.py +24 -21
- deepdoctection/pipe/language.py +20 -25
- deepdoctection/pipe/layout.py +20 -16
- deepdoctection/pipe/lm.py +75 -105
- deepdoctection/pipe/order.py +194 -89
- deepdoctection/pipe/refine.py +111 -124
- deepdoctection/pipe/segment.py +156 -161
- deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
- deepdoctection/pipe/text.py +37 -36
- deepdoctection/pipe/transform.py +19 -16
- deepdoctection/train/__init__.py +6 -12
- deepdoctection/train/d2_frcnn_train.py +48 -41
- deepdoctection/train/hf_detr_train.py +41 -30
- deepdoctection/train/hf_layoutlm_train.py +153 -135
- deepdoctection/train/tp_frcnn_train.py +32 -31
- deepdoctection/utils/concurrency.py +1 -1
- deepdoctection/utils/context.py +13 -6
- deepdoctection/utils/develop.py +4 -4
- deepdoctection/utils/env_info.py +87 -125
- deepdoctection/utils/file_utils.py +6 -11
- deepdoctection/utils/fs.py +22 -18
- deepdoctection/utils/identifier.py +2 -2
- deepdoctection/utils/logger.py +16 -15
- deepdoctection/utils/metacfg.py +7 -7
- deepdoctection/utils/mocks.py +93 -0
- deepdoctection/utils/pdf_utils.py +11 -11
- deepdoctection/utils/settings.py +185 -181
- deepdoctection/utils/tqdm.py +1 -1
- deepdoctection/utils/transform.py +14 -9
- deepdoctection/utils/types.py +104 -0
- deepdoctection/utils/utils.py +7 -7
- deepdoctection/utils/viz.py +74 -72
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
- deepdoctection-0.33.dist-info/RECORD +146 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
- deepdoctection/utils/detection_types.py +0 -68
- deepdoctection-0.31.dist-info/RECORD +0 -144
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
- {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
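
Two structural changes in 0.33 run through nearly every hunk below: the new `deepdoctection/utils/types.py` module (replacing `detection_types.py`) introduces a `PathLikeOrStr` alias that the training entry points now accept instead of plain `str` paths, and path values are normalized with `os.fspath` before being handed to string-based configuration. The alias itself is not rendered in this diff; a minimal sketch, assuming the conventional `Union[str, os.PathLike]` shape, looks like this:

```python
# Minimal sketch: deepdoctection/utils/types.py itself is not rendered in this
# diff, so the exact alias shown here is an assumption about its shape.
import os
from pathlib import Path
from typing import Union

PathLikeOrStr = Union[str, os.PathLike]


def to_config_value(path: PathLikeOrStr) -> str:
    # os.fspath collapses str and PathLike inputs to a plain string, which is
    # what string-based option lists such as cfg.merge_from_list expect
    return os.fspath(path)


print(to_config_value(Path("train_log") / "frcnn"))  # train_log/frcnn
```

Accepting `os.PathLike` lets callers pass `pathlib.Path` objects directly, while `os.fspath` keeps downstream consumers that expect plain strings working unchanged.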
deepdoctection/train/d2_frcnn_train.py +48 -41

```diff
@@ -18,19 +18,14 @@
 """
 Module for training Detectron2 `GeneralizedRCNN`
 """
-
+from __future__ import annotations
 
 import copy
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
+import os
+from pathlib import Path
+from typing import Any, Mapping, Optional, Sequence, Type, Union
 
-from detectron2.config import CfgNode, get_cfg
-from detectron2.data import DatasetMapper, build_detection_train_loader
-from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
-from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
-from detectron2.utils import comm
-from detectron2.utils.events import EventWriter, get_event_storage
-from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
-from torch.utils.data import DataLoader, IterableDataset
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -39,22 +34,35 @@ from ..eval.base import MetricBase
 from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.d2detect import D2FrcnnDetector
-from ..extern.pt.ptutils import get_num_gpu
 from ..mapper.d2struct import image_to_d2_frcnn_training
-from ..pipe.base import PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.error import DependencyError
 from ..utils.file_utils import get_wandb_requirement, wandb_available
 from ..utils.logger import LoggingRecord, logger
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict
 
-if wandb_available():
+with try_import() as d2_import_guard:
+    from detectron2.config import CfgNode, get_cfg
+    from detectron2.data import DatasetMapper, build_detection_train_loader
+    from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
+    from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
+    from detectron2.utils import comm
+    from detectron2.utils.events import EventWriter, get_event_storage
+    from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
+
+with try_import() as pt_import_guard:
+    from torch import cuda
+    from torch.utils.data import DataLoader, IterableDataset
+
+with try_import() as wb_import_guard:
     import wandb
 
 
 def _set_config(
-    path_config_yaml: str,
-    conf_list: List[str],
+    path_config_yaml: PathLikeOrStr,
+    conf_list: list[str],
     dataset_train: DatasetBase,
     dataset_val: Optional[DatasetBase],
     metric_name: Optional[str],
@@ -69,7 +77,7 @@ def _set_config(
     cfg.WANDB.USE_WANDB = False
     cfg.WANDB.PROJECT = None
    cfg.WANDB.REPO = "deepdoctection"
-    cfg.merge_from_file(path_config_yaml)
+    cfg.merge_from_file(path_config_yaml.as_posix() if isinstance(path_config_yaml, Path) else path_config_yaml)
     cfg.merge_from_list(conf_list)
 
     cfg.TEST.DO_EVAL = (
@@ -84,7 +92,7 @@ def _set_config(
     return cfg
 
 
-def _update_for_eval(config_overwrite: List[str]) -> List[str]:
+def _update_for_eval(config_overwrite: list[str]) -> list[str]:
     ret = [item for item in config_overwrite if not "WANDB" in item]
     return ret
 
@@ -98,7 +106,7 @@ class WandbWriter(EventWriter):
         self,
         project: str,
         repo: str,
-        config: Optional[Union[Dict[str, Any], CfgNode]] = None,
+        config: Optional[Union[dict[str, Any], CfgNode]] = None,
         window_size: int = 20,
         **kwargs: Any,
     ):
@@ -112,7 +120,7 @@ class WandbWriter(EventWriter):
             config = {}
         self._window_size = window_size
         self._run = wandb.init(project=project, config=config, **kwargs) if not wandb.run else wandb.run
-        self._run._label(repo=repo)
+        self._run._label(repo=repo)
 
     def write(self) -> None:
         storage = get_event_storage()
@@ -121,10 +129,10 @@ class WandbWriter(EventWriter):
         for key, (val, _) in storage.latest_with_smoothing_hint(self._window_size).items():
             log_dict[key] = val
 
-        self._run.log(log_dict)
+        self._run.log(log_dict)
 
     def close(self) -> None:
-        self._run.finish()
+        self._run.finish()
 
 
 class D2Trainer(DefaultTrainer):
@@ -140,7 +148,7 @@ class D2Trainer(DefaultTrainer):
         self.build_val_dict: Mapping[str, str] = {}
         super().__init__(cfg)
 
-    def build_hooks(self) -> List[HookBase]:
+    def build_hooks(self) -> list[HookBase]:
         """
         Overwritten from DefaultTrainer. This ensures that the EvalHook is being called before the writer and
         all metrics are being written to JSON, Tensorboard etc.
@@ -192,7 +200,7 @@ class D2Trainer(DefaultTrainer):
 
         return ret
 
-    def build_writers(self) -> List[EventWriter]:
+    def build_writers(self) -> list[EventWriter]:
         """
         Build a list of writers to be using `default_writers()`.
         If you'd like a different list of writers, you can overwrite it in
@@ -221,7 +229,7 @@ class D2Trainer(DefaultTrainer):
             dataset=self.dataset, mapper=self.mapper, total_batch_size=cfg.SOLVER.IMS_PER_BATCH
         )
 
-    def eval_with_dd_evaluator(self, **build_eval_kwargs: str) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
+    def eval_with_dd_evaluator(self, **build_eval_kwargs: str) -> Union[list[dict[str, Any]], dict[str, Any]]:
         """
         Running the Evaluator. This method will be called from the `EvalHook`
 
@@ -238,7 +246,7 @@ class D2Trainer(DefaultTrainer):
     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: PredictorPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[MetricBase], MetricBase],
         build_val_dict: Optional[Mapping[str, str]] = None,
     ) -> None:
@@ -259,16 +267,14 @@ class D2Trainer(DefaultTrainer):
             dataset_val,
             pipeline_component,
             metric,
-            num_threads=get_num_gpu() * 2,
+            num_threads=cuda.device_count() * 2,
             run=run,
         )
         if build_val_dict:
             self.build_val_dict = build_val_dict
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-
-            assert isinstance(comp.predictor, D2FrcnnDetector)
-            comp.predictor.d2_predictor = None
+            comp.clear_predictor()
 
     @classmethod
     def build_evaluator(cls, cfg, dataset_name):  # type: ignore
@@ -276,11 +282,11 @@ class D2Trainer(DefaultTrainer):
 
 
 def train_d2_faster_rcnn(
-    path_config_yaml: str,
+    path_config_yaml: PathLikeOrStr,
     dataset_train: Union[str, DatasetBase],
-    path_weights: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/frcnn",
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/frcnn",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -335,15 +341,15 @@ def train_d2_faster_rcnn(
     :param pipeline_component_name: A pipeline component name to use for validation.
     """
 
-    assert get_num_gpu() > 0, "Has to train with GPU!"
+    assert cuda.device_count() > 0, "Has to train with GPU!"
 
-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
@@ -353,9 +359,9 @@ def train_d2_faster_rcnn(
         config_overwrite = []
     conf_list = [
         "MODEL.WEIGHTS",
-        path_weights,
+        os.fspath(path_weights),
         "OUTPUT_DIR",
-        log_dir,
+        os.fspath(log_dir),
     ]
     for conf in config_overwrite:
         key, val = conf.split("=", maxsplit=1)
@@ -371,11 +377,13 @@ def train_d2_faster_rcnn(
     if metric_name is not None:
         metric = metric_registry.get(metric_name)
 
-    dataset = DatasetAdapter(dataset_train, True, image_to_d2_frcnn_training(False), True, number_repetitions=-1, **build_train_dict)
+    dataset = DatasetAdapter(
+        dataset_train, True, image_to_d2_frcnn_training(False), True, number_repetitions=-1, **build_train_dict
+    )
     augment_list = [ResizeShortestEdge(cfg.INPUT.MIN_SIZE_TRAIN, cfg.INPUT.MAX_SIZE_TRAIN), RandomFlip()]
     mapper = DatasetMapper(is_train=True, augmentations=augment_list, image_format="BGR")
 
-    logger.info(LoggingRecord(f"Config: \n {str(cfg)}", cfg.to_dict()))
+    logger.info(LoggingRecord(f"Config: \n {str(cfg)}", dict(cfg)))
 
     trainer = D2Trainer(cfg, dataset, mapper)
     trainer.resume_or_load()
@@ -386,7 +394,6 @@ def train_d2_faster_rcnn(
     detector = D2FrcnnDetector(path_config_yaml, path_weights, categories, config_overwrite, cfg.MODEL.DEVICE)
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     pipeline_component = pipeline_component_cls(detector)
-    assert isinstance(pipeline_component, PredictorPipelineComponent)
 
     if metric_name is not None:
         metric = metric_registry.get(metric_name)
```
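
The import hunks above replace eager top-level imports of Detectron2, Torch, fvcore and wandb with `try_import()` guards from the `lazy-imports` package; the same pattern appears in `hf_detr_train.py` below for Torch and transformers. The effect is that importing the training module no longer fails on machines without an optional dependency; the recorded `ImportError` only surfaces when the guarded names are actually used. A simplified stand-in for the guard (not the actual `lazy_imports` implementation) shows the mechanics:

```python
# Simplified stand-in for lazy_imports.try_import, for illustration only;
# the real package ships its own guard object with a comparable contract.
from typing import Optional


class ImportGuard:
    """Context manager that swallows an ImportError and replays it on demand."""

    def __init__(self) -> None:
        self._error: Optional[ImportError] = None

    def __enter__(self) -> "ImportGuard":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        if isinstance(exc, ImportError):
            self._error = exc  # remember the failure instead of raising now
            return True  # suppress it, so the surrounding module still imports
        return False

    def check(self) -> None:
        if self._error is not None:
            raise self._error  # surface the original error at first real use


with ImportGuard() as d2_import_guard:
    from detectron2.config import get_cfg  # resolves only if detectron2 is installed

# Code paths that actually need detectron2 would call d2_import_guard.check()
# first; everything else can import this module on a detectron2-free machine.
```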
deepdoctection/train/hf_detr_train.py +41 -30

```diff
@@ -19,20 +19,13 @@
 Module for training Hugging Face Detr implementation. Note, that this scripts only trains Tabletransformer like Detr
 models that are a slightly different from the plain Detr model that are provided by the transformer library.
 """
+from __future__ import annotations
 
 import copy
-from typing import Any, Dict, List, Optional, Sequence, Type, Union
-
-from torch.nn import Module
-from torch.utils.data import Dataset
-from transformers import (
-    AutoFeatureExtractor,
-    IntervalStrategy,
-    PretrainedConfig,
-    PreTrainedModel,
-    TableTransformerForObjectDetection,
-)
-from transformers.trainer import Trainer, TrainingArguments
+import os
+from typing import Any, Optional, Sequence, Type, Union
+
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -42,11 +35,27 @@ from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.hfdetr import HFDetrDerivedDetector
 from ..mapper.hfstruct import DetrDataCollator, image_to_hf_detr_training
-from ..pipe.base import PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.logger import LoggingRecord, logger
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict
 
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
+
+with try_import() as hf_import_guard:
+    from transformers import (
+        AutoFeatureExtractor,
+        IntervalStrategy,
+        PretrainedConfig,
+        PreTrainedModel,
+        TableTransformerForObjectDetection,
+        Trainer,
+        TrainingArguments,
+    )
+
 
 class DetrDerivedTrainer(Trainer):
     """
@@ -61,19 +70,19 @@ class DetrDerivedTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
         args: TrainingArguments,
         data_collator: DetrDataCollator,
         train_dataset: Dataset[Any],
     ):
         self.evaluator: Optional[Evaluator] = None
-        self.build_eval_kwargs: Optional[Dict[str, Any]] = None
+        self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset)
 
     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: PredictorPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[MetricBase], MetricBase],
         **build_eval_kwargs: Union[str, int],
     ) -> None:
@@ -90,17 +99,15 @@ class DetrDerivedTrainer(Trainer):
         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1)
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-
-            assert isinstance(comp.predictor, HFDetrDerivedDetector)
-            comp.predictor.hf_detr_predictor = None
+            comp.clear_predictor()
         self.build_eval_kwargs = build_eval_kwargs
 
     def evaluate(
         self,
         eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
         metric_key_prefix: str = "eval",  # pylint: disable=W0613
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
@@ -122,12 +129,12 @@ class DetrDerivedTrainer(Trainer):
 
 
 def train_hf_detr(
-    path_config_json: str,
+    path_config_json: PathLikeOrStr,
     dataset_train: Union[str, DatasetBase],
-    path_weights: str,
+    path_weights: PathLikeOrStr,
     path_feature_extractor_config_json: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/detr",
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/detr",
     build_train_config: Optional[Sequence[str]] = None,
     dataset_val: Optional[DatasetBase] = None,
     build_val_config: Optional[Sequence[str]] = None,
@@ -162,13 +169,13 @@ def train_hf_detr(
     :param pipeline_component_name: A pipeline component name to use for validation
     """
 
-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
     if build_train_config is not None:
         build_train_dict = string_to_dict(",".join(build_train_config))
     if "split" not in build_train_dict:
         build_train_dict["split"] = "train"
 
-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
     if build_val_config is not None:
         build_val_dict = string_to_dict(",".join(build_val_config))
     if "split" not in build_val_dict:
@@ -184,12 +191,17 @@ def train_hf_detr(
     categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(name_as_key=True, filtered=True)
 
     dataset = DatasetAdapter(
-        dataset_train, True, image_to_hf_detr_training(category_names=categories), True, number_repetitions=-1, **build_train_dict
+        dataset_train,
+        True,
+        image_to_hf_detr_training(category_names=categories),
+        True,
+        number_repetitions=-1,
+        **build_train_dict,
     )
 
     number_samples = len(dataset)
     conf_dict = {
-        "output_dir": log_dir,
+        "output_dir": os.fspath(log_dir),
         "remove_unused_columns": False,
         "per_device_train_batch_size": 2,
         "max_steps": number_samples,
@@ -249,7 +261,6 @@ def train_hf_detr(
     )
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     pipeline_component = pipeline_component_cls(detector)
-    assert isinstance(pipeline_component, PredictorPipelineComponent)
 
     if metric_name is not None:
         metric = metric_registry.get(metric_name)
```
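
Both trainers also drop the framework-specific teardown in `setup_evaluator` (an `assert isinstance(comp.predictor, ...)` followed by resetting `d2_predictor` or `hf_detr_predictor` by hand) in favour of a single `comp.clear_predictor()` call, and the matching `assert isinstance(pipeline_component, PredictorPipelineComponent)` checks disappear. The real method lives on `PipelineComponent` in `deepdoctection/pipe/base.py`, which is not part of the rendered hunks; a hypothetical sketch of the contract:

```python
# Hypothetical sketch of the cleanup contract the diff switches to; the actual
# clear_predictor implementation in deepdoctection/pipe/base.py is not shown here.
from typing import Optional


class PredictorSketch:
    """Stand-in for a wrapped framework model (Detectron2, Detr, ...)."""

    def __init__(self) -> None:
        self.model: Optional[object] = object()  # placeholder for loaded weights

    def clear_model(self) -> None:
        self.model = None  # release the checkpoint before the evaluator reloads it


class PipelineComponentSketch:
    def __init__(self, predictor: PredictorSketch) -> None:
        self.predictor = predictor

    def clear_predictor(self) -> None:
        # every component resets its own predictor, so trainer code no longer
        # needs framework-specific isinstance checks or attribute names
        self.predictor.clear_model()


comp = PipelineComponentSketch(PredictorSketch())
comp.clear_predictor()
assert comp.predictor.model is None
```

Moving the reset behind a polymorphic method is what lets both trainers share the same evaluator wiring regardless of which detector backs the pipeline component.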
|