deepdoctection 0.30-py3-none-any.whl → 0.32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic.

Files changed (120)
  1. deepdoctection/__init__.py +38 -29
  2. deepdoctection/analyzer/dd.py +36 -29
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/base.py +0 -19
  5. deepdoctection/dataflow/custom.py +4 -3
  6. deepdoctection/dataflow/custom_serialize.py +14 -5
  7. deepdoctection/dataflow/parallel_map.py +12 -11
  8. deepdoctection/dataflow/serialize.py +5 -4
  9. deepdoctection/datapoint/annotation.py +35 -13
  10. deepdoctection/datapoint/box.py +3 -5
  11. deepdoctection/datapoint/convert.py +3 -1
  12. deepdoctection/datapoint/image.py +79 -36
  13. deepdoctection/datapoint/view.py +152 -49
  14. deepdoctection/datasets/__init__.py +1 -4
  15. deepdoctection/datasets/adapter.py +6 -3
  16. deepdoctection/datasets/base.py +86 -11
  17. deepdoctection/datasets/dataflow_builder.py +1 -1
  18. deepdoctection/datasets/info.py +4 -4
  19. deepdoctection/datasets/instances/doclaynet.py +3 -2
  20. deepdoctection/datasets/instances/fintabnet.py +2 -1
  21. deepdoctection/datasets/instances/funsd.py +2 -1
  22. deepdoctection/datasets/instances/iiitar13k.py +5 -2
  23. deepdoctection/datasets/instances/layouttest.py +4 -8
  24. deepdoctection/datasets/instances/publaynet.py +2 -2
  25. deepdoctection/datasets/instances/pubtables1m.py +6 -3
  26. deepdoctection/datasets/instances/pubtabnet.py +2 -1
  27. deepdoctection/datasets/instances/rvlcdip.py +2 -1
  28. deepdoctection/datasets/instances/xfund.py +2 -1
  29. deepdoctection/eval/__init__.py +1 -4
  30. deepdoctection/eval/accmetric.py +1 -1
  31. deepdoctection/eval/base.py +5 -4
  32. deepdoctection/eval/cocometric.py +2 -1
  33. deepdoctection/eval/eval.py +19 -15
  34. deepdoctection/eval/tedsmetric.py +14 -11
  35. deepdoctection/eval/tp_eval_callback.py +14 -7
  36. deepdoctection/extern/__init__.py +2 -7
  37. deepdoctection/extern/base.py +39 -13
  38. deepdoctection/extern/d2detect.py +182 -90
  39. deepdoctection/extern/deskew.py +36 -9
  40. deepdoctection/extern/doctrocr.py +265 -83
  41. deepdoctection/extern/fastlang.py +49 -9
  42. deepdoctection/extern/hfdetr.py +106 -55
  43. deepdoctection/extern/hflayoutlm.py +441 -122
  44. deepdoctection/extern/hflm.py +225 -0
  45. deepdoctection/extern/model.py +56 -47
  46. deepdoctection/extern/pdftext.py +10 -5
  47. deepdoctection/extern/pt/__init__.py +1 -3
  48. deepdoctection/extern/pt/nms.py +6 -2
  49. deepdoctection/extern/pt/ptutils.py +27 -18
  50. deepdoctection/extern/tessocr.py +134 -22
  51. deepdoctection/extern/texocr.py +6 -2
  52. deepdoctection/extern/tp/tfutils.py +43 -9
  53. deepdoctection/extern/tp/tpcompat.py +14 -11
  54. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  55. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  56. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
  58. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
  60. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  61. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
  62. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  67. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  68. deepdoctection/extern/tp/tpfrcnn/preproc.py +8 -9
  69. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  70. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  71. deepdoctection/extern/tpdetect.py +54 -30
  72. deepdoctection/mapper/__init__.py +3 -8
  73. deepdoctection/mapper/d2struct.py +9 -7
  74. deepdoctection/mapper/hfstruct.py +7 -2
  75. deepdoctection/mapper/laylmstruct.py +164 -21
  76. deepdoctection/mapper/maputils.py +16 -3
  77. deepdoctection/mapper/misc.py +6 -3
  78. deepdoctection/mapper/prodigystruct.py +1 -1
  79. deepdoctection/mapper/pubstruct.py +10 -10
  80. deepdoctection/mapper/tpstruct.py +3 -3
  81. deepdoctection/pipe/__init__.py +1 -1
  82. deepdoctection/pipe/anngen.py +35 -8
  83. deepdoctection/pipe/base.py +53 -19
  84. deepdoctection/pipe/common.py +23 -13
  85. deepdoctection/pipe/concurrency.py +2 -1
  86. deepdoctection/pipe/doctectionpipe.py +2 -2
  87. deepdoctection/pipe/language.py +3 -2
  88. deepdoctection/pipe/layout.py +6 -3
  89. deepdoctection/pipe/lm.py +34 -66
  90. deepdoctection/pipe/order.py +142 -35
  91. deepdoctection/pipe/refine.py +26 -24
  92. deepdoctection/pipe/segment.py +21 -16
  93. deepdoctection/pipe/{cell.py → sub_layout.py} +30 -9
  94. deepdoctection/pipe/text.py +14 -8
  95. deepdoctection/pipe/transform.py +16 -9
  96. deepdoctection/train/__init__.py +6 -12
  97. deepdoctection/train/d2_frcnn_train.py +36 -28
  98. deepdoctection/train/hf_detr_train.py +26 -17
  99. deepdoctection/train/hf_layoutlm_train.py +133 -111
  100. deepdoctection/train/tp_frcnn_train.py +21 -19
  101. deepdoctection/utils/__init__.py +3 -0
  102. deepdoctection/utils/concurrency.py +1 -1
  103. deepdoctection/utils/context.py +2 -2
  104. deepdoctection/utils/env_info.py +41 -84
  105. deepdoctection/utils/error.py +84 -0
  106. deepdoctection/utils/file_utils.py +4 -15
  107. deepdoctection/utils/fs.py +7 -7
  108. deepdoctection/utils/logger.py +1 -0
  109. deepdoctection/utils/mocks.py +93 -0
  110. deepdoctection/utils/pdf_utils.py +5 -4
  111. deepdoctection/utils/settings.py +6 -1
  112. deepdoctection/utils/transform.py +1 -1
  113. deepdoctection/utils/utils.py +0 -6
  114. deepdoctection/utils/viz.py +48 -5
  115. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/METADATA +57 -73
  116. deepdoctection-0.32.dist-info/RECORD +146 -0
  117. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/WHEEL +1 -1
  118. deepdoctection-0.30.dist-info/RECORD +0 -143
  119. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
  120. {deepdoctection-0.30.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
deepdoctection/train/d2_frcnn_train.py

@@ -18,19 +18,12 @@
 """
 Module for training Detectron2 `GeneralizedRCNN`
 """
-
+from __future__ import annotations
 
 import copy
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Type, Union
 
-from detectron2.config import CfgNode, get_cfg
-from detectron2.data import DatasetMapper, build_detection_train_loader
-from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
-from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
-from detectron2.utils import comm
-from detectron2.utils.events import EventWriter, get_event_storage
-from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
-from torch.utils.data import DataLoader, IterableDataset
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -39,15 +32,28 @@ from ..eval.base import MetricBase
 from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.d2detect import D2FrcnnDetector
-from ..extern.pt.ptutils import get_num_gpu
 from ..mapper.d2struct import image_to_d2_frcnn_training
 from ..pipe.base import PredictorPipelineComponent
 from ..pipe.registry import pipeline_component_registry
+from ..utils.error import DependencyError
 from ..utils.file_utils import get_wandb_requirement, wandb_available
 from ..utils.logger import LoggingRecord, logger
 from ..utils.utils import string_to_dict
 
-if wandb_available():
+with try_import() as d2_import_guard:
+    from detectron2.config import CfgNode, get_cfg
+    from detectron2.data import DatasetMapper, build_detection_train_loader
+    from detectron2.data.transforms import RandomFlip, ResizeShortestEdge
+    from detectron2.engine import DefaultTrainer, HookBase, default_writers, hooks
+    from detectron2.utils import comm
+    from detectron2.utils.events import EventWriter, get_event_storage
+    from fvcore.nn.precise_bn import get_bn_modules  # type: ignore
+
+with try_import() as pt_import_guard:
+    from torch import cuda
+    from torch.utils.data import DataLoader, IterableDataset
+
+with try_import() as wb_import_guard:
     import wandb
 
 
@@ -111,7 +117,7 @@ class WandbWriter(EventWriter):
         config = {}
         self._window_size = window_size
         self._run = wandb.init(project=project, config=config, **kwargs) if not wandb.run else wandb.run
-        self._run._label(repo=repo)  # type:ignore
+        self._run._label(repo=repo)
 
     def write(self) -> None:
         storage = get_event_storage()
@@ -120,10 +126,10 @@ class WandbWriter(EventWriter):
         for key, (val, _) in storage.latest_with_smoothing_hint(self._window_size).items():
             log_dict[key] = val
 
-        self._run.log(log_dict)  # type:ignore
+        self._run.log(log_dict)
 
     def close(self) -> None:
-        self._run.finish()  # type:ignore
+        self._run.finish()
 
 
 class D2Trainer(DefaultTrainer):
@@ -153,16 +159,18 @@ class D2Trainer(DefaultTrainer):
         ret = [
             hooks.IterationTimer(),
             hooks.LRScheduler(),
-            hooks.PreciseBN(
-                # Run at the same freq as (but before) evaluation.
-                cfg.TEST.EVAL_PERIOD,
-                self.model,  # pylint: disable=E1101
-                # Build a new data loader to not affect training
-                self.build_train_loader(cfg),
-                cfg.TEST.PRECISE_BN.NUM_ITER,
-            )
-            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)  # pylint: disable=E1101
-            else None,
+            (
+                hooks.PreciseBN(
+                    # Run at the same freq as (but before) evaluation.
+                    cfg.TEST.EVAL_PERIOD,
+                    self.model,  # pylint: disable=E1101
+                    # Build a new data loader to not affect training
+                    self.build_train_loader(cfg),
+                    cfg.TEST.PRECISE_BN.NUM_ITER,
+                )
+                if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)  # pylint: disable=E1101
+                else None
+            ),
         ]
 
         # Do PreciseBN before checkpointer, because it updates the model and need to
@@ -201,7 +209,7 @@ class D2Trainer(DefaultTrainer):
         if self.cfg.WANDB.USE_WANDB:
             _, _wandb_available, err_msg = get_wandb_requirement()
             if not _wandb_available:
-                raise ImportError(err_msg)
+                raise DependencyError(err_msg)
             if self.cfg.WANDB.PROJECT is None:
                 raise ValueError("When using W&B, you must specify a project, i.e. WANDB.PROJECT")
             writers_list.append(WandbWriter(self.cfg.WANDB.PROJECT, self.cfg.WANDB.REPO, self.cfg))
@@ -256,7 +264,7 @@ class D2Trainer(DefaultTrainer):
             dataset_val,
             pipeline_component,
             metric,
-            num_threads=get_num_gpu() * 2,
+            num_threads=cuda.device_count() * 2,
             run=run,
         )
         if build_val_dict:
@@ -269,7 +277,7 @@ class D2Trainer(DefaultTrainer):
 
     @classmethod
     def build_evaluator(cls, cfg, dataset_name):  # type: ignore
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 def train_d2_faster_rcnn(
@@ -332,7 +340,7 @@ def train_d2_faster_rcnn(
     :param pipeline_component_name: A pipeline component name to use for validation.
     """
 
-    assert get_num_gpu() > 0, "Has to train with GPU!"
+    assert cuda.device_count() > 0, "Has to train with GPU!"
 
     build_train_dict: Dict[str, str] = {}
     if build_train_config is not None:
deepdoctection/train/hf_detr_train.py

@@ -19,20 +19,12 @@
 Module for training Hugging Face Detr implementation. Note, that this scripts only trans Tabletransformer like Detr
 models that are a slightly different from the plain Detr model that are provided by the transformer library.
 """
+from __future__ import annotations
 
 import copy
 from typing import Any, Dict, List, Optional, Sequence, Type, Union
 
-from torch.nn import Module
-from torch.utils.data import Dataset
-from transformers import (
-    AutoFeatureExtractor,
-    IntervalStrategy,
-    PretrainedConfig,
-    PreTrainedModel,
-    TableTransformerForObjectDetection,
-)
-from transformers.trainer import Trainer, TrainingArguments
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -47,6 +39,21 @@ from ..pipe.registry import pipeline_component_registry
 from ..utils.logger import LoggingRecord, logger
 from ..utils.utils import string_to_dict
 
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
+
+with try_import() as hf_import_guard:
+    from transformers import (
+        AutoFeatureExtractor,
+        IntervalStrategy,
+        PretrainedConfig,
+        PreTrainedModel,
+        TableTransformerForObjectDetection,
+        Trainer,
+        TrainingArguments,
+    )
+
 
 class DetrDerivedTrainer(Trainer):
     """
@@ -61,7 +68,7 @@ class DetrDerivedTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
         args: TrainingArguments,
         data_collator: DetrDataCollator,
         train_dataset: Dataset[Any],
@@ -97,9 +104,9 @@
 
     def evaluate(
         self,
-        eval_dataset: Optional[Dataset[Any]] = None,
-        ignore_keys: Optional[List[str]] = None,
-        metric_key_prefix: str = "eval",
+        eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        metric_key_prefix: str = "eval",  # pylint: disable=W0613
     ) -> Dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
@@ -193,9 +200,11 @@ def train_hf_detr(
         "remove_unused_columns": False,
         "per_device_train_batch_size": 2,
         "max_steps": number_samples,
-        "evaluation_strategy": "steps"
-        if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
-        else "no",
+        "evaluation_strategy": (
+            "steps"
+            if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
+            else "no"
+        ),
         "eval_steps": 5000,
     }
 
deepdoctection/train/hf_layoutlm_train.py

@@ -18,32 +18,15 @@
 """
 Module for training Huggingface implementation of LayoutLm
 """
+from __future__ import annotations
 
 import copy
 import json
 import os
 import pprint
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
-
-from torch.nn import Module
-from torch.utils.data import Dataset
-from transformers import (
-    IntervalStrategy,
-    LayoutLMForSequenceClassification,
-    LayoutLMForTokenClassification,
-    LayoutLMTokenizerFast,
-    LayoutLMv2Config,
-    LayoutLMv2ForSequenceClassification,
-    LayoutLMv2ForTokenClassification,
-    LayoutLMv3Config,
-    LayoutLMv3ForSequenceClassification,
-    LayoutLMv3ForTokenClassification,
-    PretrainedConfig,
-    PreTrainedModel,
-    RobertaTokenizerFast,
-    XLMRobertaTokenizerFast,
-)
-from transformers.trainer import Trainer, TrainingArguments
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -57,78 +40,108 @@ from ..extern.hflayoutlm import (
     HFLayoutLmv2TokenClassifier,
     HFLayoutLmv3SequenceClassifier,
     HFLayoutLmv3TokenClassifier,
+    HFLiltSequenceClassifier,
+    HFLiltTokenClassifier,
+    get_tokenizer_from_model_class,
 )
-from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features
+from ..extern.hflm import HFLmSequenceClassifier
+from ..extern.pt.ptutils import get_torch_device
+from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features, image_to_raw_lm_features
 from ..pipe.base import LanguageModelPipelineComponent
-from ..pipe.lm import get_tokenizer_from_architecture
 from ..pipe.registry import pipeline_component_registry
-from ..utils.env_info import get_device
+from ..utils.error import DependencyError
 from ..utils.file_utils import wandb_available
 from ..utils.logger import LoggingRecord, logger
-from ..utils.settings import DatasetType, LayoutType, ObjectTypes, WordType
+from ..utils.settings import DatasetType, LayoutType, WordType
 from ..utils.utils import string_to_dict
 
-if wandb_available():
-    import wandb
-
-_ARCHITECTURES_TO_MODEL_CLASS = {
-    "LayoutLMForTokenClassification": (LayoutLMForTokenClassification, HFLayoutLmTokenClassifier, PretrainedConfig),
-    "LayoutLMForSequenceClassification": (
-        LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    "LayoutLMv2ForTokenClassification": (
-        LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    "LayoutLMv2ForSequenceClassification": (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
-        LayoutLMv2Config,
-    ),
-}
-
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
 
-_MODEL_TYPE_AND_TASK_TO_MODEL_CLASS: Mapping[Tuple[str, ObjectTypes], Any] = {
-    ("layoutlm", DatasetType.sequence_classification): (
+with try_import() as tr_import_guard:
+    from transformers import (
+        IntervalStrategy,
         LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlm", DatasetType.token_classification): (
         LayoutLMForTokenClassification,
-        HFLayoutLmTokenClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlmv2", DatasetType.sequence_classification): (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
         LayoutLMv2Config,
-    ),
-    ("layoutlmv2", DatasetType.token_classification): (
+        LayoutLMv2ForSequenceClassification,
         LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    ("layoutlmv3", DatasetType.sequence_classification): (
-        LayoutLMv3ForSequenceClassification,
-        HFLayoutLmv3SequenceClassifier,
         LayoutLMv3Config,
-    ),
-    ("layoutlmv3", DatasetType.token_classification): (
+        LayoutLMv3ForSequenceClassification,
         LayoutLMv3ForTokenClassification,
-        HFLayoutLmv3TokenClassifier,
-        LayoutLMv3Config,
-    ),
-}
-_MODEL_TYPE_TO_TOKENIZER = {
-    ("layoutlm", False): LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased"),
-    ("layoutlmv2", False): LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased"),
-    ("layoutlmv2", True): XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base", add_prefix_space=True),
-    ("layoutlmv3", False): RobertaTokenizerFast.from_pretrained("roberta-base", add_prefix_space=True),
-}
+        LiltForSequenceClassification,
+        LiltForTokenClassification,
+        PretrainedConfig,
+        PreTrainedModel,
+        XLMRobertaForSequenceClassification,
+    )
+    from transformers.trainer import Trainer, TrainingArguments
+
+with try_import() as wb_import_guard:
+    import wandb
+
+
+def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> Tuple[Any, Any, Any]:
+    """
+    Get the model architecture, model wrapper and config class for a given model type and dataset type.
+
+    :param model_type: The model type
+    :param dataset_type: The dataset type
+    :return: Tuple of model architecture, model wrapper and config class
+    """
+    return {
+        ("layoutlm", DatasetType.sequence_classification): (
+            LayoutLMForSequenceClassification,
+            HFLayoutLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlm", DatasetType.token_classification): (
+            LayoutLMForTokenClassification,
+            HFLayoutLmTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlmv2", DatasetType.sequence_classification): (
+            LayoutLMv2ForSequenceClassification,
+            HFLayoutLmv2SequenceClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv2", DatasetType.token_classification): (
+            LayoutLMv2ForTokenClassification,
+            HFLayoutLmv2TokenClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv3", DatasetType.sequence_classification): (
+            LayoutLMv3ForSequenceClassification,
+            HFLayoutLmv3SequenceClassifier,
+            LayoutLMv3Config,
+        ),
+        ("layoutlmv3", DatasetType.token_classification): (
+            LayoutLMv3ForTokenClassification,
+            HFLayoutLmv3TokenClassifier,
+            LayoutLMv3Config,
+        ),
+        ("lilt", DatasetType.token_classification): (
+            LiltForTokenClassification,
+            HFLiltTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("lilt", DatasetType.sequence_classification): (
+            LiltForSequenceClassification,
+            HFLiltSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("xlm-roberta", DatasetType.sequence_classification): (
+            XLMRobertaForSequenceClassification,
+            HFLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+    }[(model_type, dataset_type)]
+
+
+def maybe_remove_bounding_box_features(model_type: str) -> bool:
+    """Listing of models that do not need bounding box features."""
+    return {"xlm-roberta": True}.get(model_type, False)
 
 
 class LayoutLMTrainer(Trainer):
@@ -144,7 +157,7 @@ class LayoutLMTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
         args: TrainingArguments,
         data_collator: LayoutLMDataCollator,
         train_dataset: Dataset[Any],
@@ -158,7 +171,7 @@ class LayoutLMTrainer(Trainer):
         dataset_val: DatasetBase,
         pipeline_component: LanguageModelPipelineComponent,
         metric: Union[Type[ClassificationMetric], ClassificationMetric],
-        run: Optional["wandb.sdk.wandb_run.Run"] = None,
+        run: Optional[wandb.sdk.wandb_run.Run] = None,
         **build_eval_kwargs: Union[str, int],
     ) -> None:
         """
@@ -180,15 +193,17 @@ class LayoutLMTrainer(Trainer):
 
     def evaluate(
         self,
-        eval_dataset: Optional[Dataset[Any]] = None,
-        ignore_keys: Optional[List[str]] = None,
-        metric_key_prefix: str = "eval",
+        eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        metric_key_prefix: str = "eval",  # pylint: disable=W0613
     ) -> Dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
-        assert self.evaluator is not None
-        assert self.evaluator.pipe_component is not None
+        if self.evaluator is None:
+            raise ValueError("Evaluator not set up. Please use `setup_evaluator` before running evaluation")
+        if self.evaluator.pipe_component is None:
+            raise ValueError("Pipeline component not set up. Please use `setup_evaluator` before running evaluation")
 
         # memory metrics - must set up as early as possible
         self._memory_tracker.start()
@@ -205,26 +220,27 @@
 
 
 def _get_model_class_and_tokenizer(
-    path_config_json: str, dataset_type: ObjectTypes, use_xlm_tokenizer: bool
-) -> Tuple[Any, Any, Any, Any]:
+    path_config_json: str, dataset_type: DatasetType, use_xlm_tokenizer: bool
+) -> Tuple[Any, Any, Any, Any, Any]:
     with open(path_config_json, "r", encoding="UTF-8") as file:
         config_json = json.load(file)
 
-    model_type = config_json.get("model_type")
-
-    if architectures := config_json.get("architectures"):
-        model_cls, model_wrapper_cls, config_cls = _ARCHITECTURES_TO_MODEL_CLASS[architectures[0]]
-        tokenizer_fast = get_tokenizer_from_architecture(architectures[0], use_xlm_tokenizer)
-    elif model_type:
-        model_cls, model_wrapper_cls, config_cls = _MODEL_TYPE_AND_TASK_TO_MODEL_CLASS[(model_type, dataset_type)]
-        tokenizer_fast = _MODEL_TYPE_TO_TOKENIZER[(model_type, use_xlm_tokenizer)]
+    if model_type := config_json.get("model_type"):
+        model_cls, model_wrapper_cls, config_cls = get_model_architectures_and_configs(model_type, dataset_type)
+        remove_box_features = maybe_remove_bounding_box_features(model_type)
     else:
-        raise KeyError("model_type and architectures not available in configs")
+        raise KeyError("model_type not available in configs. It seems that the config is not valid")
+
+    tokenizer_fast = get_tokenizer_from_model_class(model_cls.__name__, use_xlm_tokenizer)
+    return config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features
 
-    if not model_cls:
-        raise ValueError("model not eligible to run with this framework")
 
-    return config_cls, model_cls, model_wrapper_cls, tokenizer_fast
+def get_image_to_raw_features_mapping(input_str: str) -> Any:
+    """Replacing eval functions"""
+    return {
+        "image_to_raw_layoutlm_features": image_to_raw_layoutlm_features,
+        "image_to_raw_lm_features": image_to_raw_lm_features,
+    }[input_str]
 
 
 def train_hf_layoutlm(
@@ -347,19 +363,21 @@ def train_hf_layoutlm(
             name_as_key=True,
         )[LayoutType.word][WordType.token_class]
     else:
-        raise ValueError("Dataset type not supported for training")
+        raise UserWarning("Dataset type not supported for training")
 
-    config_cls, model_cls, model_wrapper_cls, tokenizer_fast = _get_model_class_and_tokenizer(
+    config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features = _get_model_class_and_tokenizer(
         path_config_json, dataset_type, use_xlm_tokenizer
     )
-    image_to_raw_layoutlm_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
+    image_to_raw_features_func = get_image_to_raw_features_mapping(model_wrapper_cls.image_to_raw_features_mapping())
+    image_to_raw_features_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
     if segment_positions:
-        image_to_raw_layoutlm_kwargs["segment_positions"] = segment_positions  # type: ignore
-    image_to_raw_layoutlm_kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
+        image_to_raw_features_kwargs["segment_positions"] = segment_positions  # type: ignore
+    image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
+
     dataset = DatasetAdapter(
         dataset_train,
         True,
-        image_to_raw_layoutlm_features(**image_to_raw_layoutlm_kwargs),
+        image_to_raw_features_func(**image_to_raw_features_kwargs),
         use_token_tag,
         **build_train_dict,
     )
@@ -374,9 +392,11 @@ def train_hf_layoutlm(
         "remove_unused_columns": False,
         "per_device_train_batch_size": 8,
         "max_steps": number_samples,
-        "evaluation_strategy": "steps"
-        if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
-        else "no",
+        "evaluation_strategy": (
+            "steps"
+            if (dataset_val is not None and metric is not None and pipeline_component_name is not None)
+            else "no"
+        ),
         "eval_steps": 100,
         "use_wandb": False,
         "wandb_project": None,
@@ -416,7 +436,7 @@ def train_hf_layoutlm(
     run = None
     if use_wandb:
        if not wandb_available():
-            raise ModuleNotFoundError("WandB must be installed separately")
+            raise DependencyError("WandB must be installed separately")
        run = wandb.init(project=wandb_project, config=conf_dict)  # type: ignore
        run._label(repo=wandb_repo)  # type: ignore # pylint: disable=W0212
     else:
@@ -448,6 +468,7 @@ def train_hf_layoutlm(
         return_tensors="pt",
         sliding_window_stride=sliding_window_stride,  # type: ignore
         max_batch_size=max_batch_size,  # type: ignore
+        remove_bounding_box_features=remove_box_features,
     )
     trainer = LayoutLMTrainer(model, arguments, data_collator, dataset)
 
@@ -470,7 +491,8 @@ def train_hf_layoutlm(
         path_config_json=path_config_json,
         path_weights=path_weights,
         categories=categories,
-        device=get_device(),
+        device=get_torch_device(),
+        use_xlm_tokenizer=use_xlm_tokenizer,
     )
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     if dataset_type == DatasetType.sequence_classification:
deepdoctection/train/tp_frcnn_train.py

@@ -22,25 +22,7 @@ Module for training Tensorpack `GeneralizedRCNN`
 import os
 from typing import Dict, List, Optional, Sequence, Type, Union
 
-# pylint: disable=import-error
-from tensorpack.callbacks import (
-    EstimatedTimeLeft,
-    GPUMemoryTracker,
-    GPUUtilizationTracker,
-    HostMemoryTracker,
-    ModelSaver,
-    PeriodicCallback,
-    ScheduledHyperParamSetter,
-    SessionRunTimeout,
-    ThroughputTracker,
-)
-
-# todo: check how dataflow import is directly possible without having AssertionError
-from tensorpack.dataflow import ProxyDataFlow, imgaug
-from tensorpack.input_source import QueueInput
-from tensorpack.tfutils import SmartInit
-from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
-from tensorpack.utils import logger
+from lazy_imports import try_import
 
 from ..dataflow.base import DataFlow
 from ..dataflow.common import MapData
@@ -68,6 +50,26 @@ from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.tqdm import get_tqdm
 from ..utils.utils import string_to_dict
 
+with try_import() as tp_import_guard:
+    # todo: check how dataflow import is directly possible without having an AssertionError
+    # pylint: disable=import-error
+    from tensorpack.callbacks import (
+        EstimatedTimeLeft,
+        GPUMemoryTracker,
+        GPUUtilizationTracker,
+        HostMemoryTracker,
+        ModelSaver,
+        PeriodicCallback,
+        ScheduledHyperParamSetter,
+        SessionRunTimeout,
+        ThroughputTracker,
+    )
+    from tensorpack.dataflow import ProxyDataFlow, imgaug
+    from tensorpack.input_source import QueueInput
+    from tensorpack.tfutils import SmartInit
+    from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
+    from tensorpack.utils import logger
+
 __all__ = ["train_faster_rcnn"]
 
 
deepdoctection/utils/__init__.py

@@ -6,7 +6,10 @@ Init file for utils package
 """
 from typing import Optional, Tuple, Union, no_type_check
 
+from .concurrency import *
 from .context import *
+from .env_info import *
+from .error import *
 from .file_utils import *
 from .fs import *
 from .identifier import *
deepdoctection/utils/concurrency.py

@@ -109,7 +109,7 @@ def enable_death_signal(_warn: bool = True) -> None:
             prctl, "set_pdeathsig"
         ), "prctl.set_pdeathsig does not exist! Note that you need to install 'python-prctl' instead of 'prctl'."
         # is SIGHUP a good choice?
-        prctl.set_pdeathsig(signal.SIGHUP)
+        prctl.set_pdeathsig(signal.SIGHUP)  # pylint: disable=E1101
 
 
 # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
deepdoctection/utils/context.py

@@ -61,7 +61,7 @@ def timeout_manager(proc, seconds: Optional[int] = None) -> Iterator[str]:  # ty
         proc.terminate()
         proc.kill()
         proc.returncode = -1
-        raise RuntimeError("Tesseract process timeout")  # pylint: disable=W0707
+        raise RuntimeError(f"timeout for process id: {proc.pid}")  # pylint: disable=W0707
     finally:
         if proc.stdin is not None:
             proc.stdin.close()
@@ -88,7 +88,7 @@ def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[
         yield file.name, path.realpath(path.normpath(path.normcase(image)))
         return
     if isinstance(image, (np.ndarray, np.generic)):
-        input_file_name = file.name + ".PNG"
+        input_file_name = file.name + "_input.PNG"
         viz_handler.write_image(input_file_name, image)
         yield file.name, input_file_name
     if isinstance(image, bytes):
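
Note on the recurring change in the hunks above: the training modules stop importing their deep-learning backends (detectron2, torch, transformers, tensorpack, wandb) at module level and instead wrap them in `lazy_imports.try_import` guards, while hard `ImportError`/`ModuleNotFoundError` raises become the package's new `DependencyError`. A minimal sketch of the guard pattern, assuming the `lazy_imports` package and an optional `torch` install; the `gpu_count` helper is illustrative, and the `check()` call reflects one reading of the lazy_imports guard API rather than anything shown in this diff:

from lazy_imports import try_import

# The context manager swallows a failed import and records the error
# instead of raising while the module is being loaded.
with try_import() as pt_import_guard:
    import torch  # optional dependency, may be absent

def gpu_count() -> int:
    # Re-raise the recorded ImportError only when the optional
    # dependency is actually needed at call time.
    pt_import_guard.check()
    return torch.cuda.device_count()

The effect is that a module such as d2_frcnn_train.py can be imported even when a particular backend is not installed, and the missing dependency only surfaces when code that needs it runs.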