deepdoctection 0.31__py3-none-any.whl → 0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of deepdoctection might be problematic.

Files changed (131)
  1. deepdoctection/__init__.py +16 -29
  2. deepdoctection/analyzer/dd.py +70 -59
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/dataflow/common.py +9 -5
  5. deepdoctection/dataflow/custom.py +5 -5
  6. deepdoctection/dataflow/custom_serialize.py +75 -18
  7. deepdoctection/dataflow/parallel_map.py +3 -3
  8. deepdoctection/dataflow/serialize.py +4 -4
  9. deepdoctection/dataflow/stats.py +3 -3
  10. deepdoctection/datapoint/annotation.py +41 -56
  11. deepdoctection/datapoint/box.py +9 -8
  12. deepdoctection/datapoint/convert.py +6 -6
  13. deepdoctection/datapoint/image.py +56 -44
  14. deepdoctection/datapoint/view.py +245 -150
  15. deepdoctection/datasets/__init__.py +1 -4
  16. deepdoctection/datasets/adapter.py +35 -26
  17. deepdoctection/datasets/base.py +14 -12
  18. deepdoctection/datasets/dataflow_builder.py +3 -3
  19. deepdoctection/datasets/info.py +24 -26
  20. deepdoctection/datasets/instances/doclaynet.py +51 -51
  21. deepdoctection/datasets/instances/fintabnet.py +46 -46
  22. deepdoctection/datasets/instances/funsd.py +25 -24
  23. deepdoctection/datasets/instances/iiitar13k.py +13 -10
  24. deepdoctection/datasets/instances/layouttest.py +4 -3
  25. deepdoctection/datasets/instances/publaynet.py +5 -5
  26. deepdoctection/datasets/instances/pubtables1m.py +24 -21
  27. deepdoctection/datasets/instances/pubtabnet.py +32 -30
  28. deepdoctection/datasets/instances/rvlcdip.py +30 -30
  29. deepdoctection/datasets/instances/xfund.py +26 -26
  30. deepdoctection/datasets/save.py +6 -6
  31. deepdoctection/eval/__init__.py +1 -4
  32. deepdoctection/eval/accmetric.py +32 -33
  33. deepdoctection/eval/base.py +8 -9
  34. deepdoctection/eval/cocometric.py +15 -13
  35. deepdoctection/eval/eval.py +41 -37
  36. deepdoctection/eval/tedsmetric.py +30 -23
  37. deepdoctection/eval/tp_eval_callback.py +16 -19
  38. deepdoctection/extern/__init__.py +2 -7
  39. deepdoctection/extern/base.py +339 -134
  40. deepdoctection/extern/d2detect.py +85 -113
  41. deepdoctection/extern/deskew.py +14 -11
  42. deepdoctection/extern/doctrocr.py +141 -130
  43. deepdoctection/extern/fastlang.py +27 -18
  44. deepdoctection/extern/hfdetr.py +71 -62
  45. deepdoctection/extern/hflayoutlm.py +504 -211
  46. deepdoctection/extern/hflm.py +230 -0
  47. deepdoctection/extern/model.py +488 -302
  48. deepdoctection/extern/pdftext.py +23 -19
  49. deepdoctection/extern/pt/__init__.py +1 -3
  50. deepdoctection/extern/pt/nms.py +6 -2
  51. deepdoctection/extern/pt/ptutils.py +29 -19
  52. deepdoctection/extern/tessocr.py +39 -38
  53. deepdoctection/extern/texocr.py +18 -18
  54. deepdoctection/extern/tp/tfutils.py +57 -9
  55. deepdoctection/extern/tp/tpcompat.py +21 -14
  56. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  57. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/config/config.py +13 -10
  60. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  61. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +18 -8
  62. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  63. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +14 -9
  64. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  65. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +22 -17
  66. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +21 -14
  67. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +19 -11
  68. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  69. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  70. deepdoctection/extern/tp/tpfrcnn/preproc.py +12 -8
  71. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  72. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  73. deepdoctection/extern/tpdetect.py +45 -53
  74. deepdoctection/mapper/__init__.py +3 -8
  75. deepdoctection/mapper/cats.py +27 -29
  76. deepdoctection/mapper/cocostruct.py +10 -10
  77. deepdoctection/mapper/d2struct.py +27 -26
  78. deepdoctection/mapper/hfstruct.py +13 -8
  79. deepdoctection/mapper/laylmstruct.py +178 -37
  80. deepdoctection/mapper/maputils.py +12 -11
  81. deepdoctection/mapper/match.py +2 -2
  82. deepdoctection/mapper/misc.py +11 -9
  83. deepdoctection/mapper/pascalstruct.py +4 -4
  84. deepdoctection/mapper/prodigystruct.py +5 -5
  85. deepdoctection/mapper/pubstruct.py +84 -92
  86. deepdoctection/mapper/tpstruct.py +5 -5
  87. deepdoctection/mapper/xfundstruct.py +33 -33
  88. deepdoctection/pipe/__init__.py +1 -1
  89. deepdoctection/pipe/anngen.py +12 -14
  90. deepdoctection/pipe/base.py +52 -106
  91. deepdoctection/pipe/common.py +72 -59
  92. deepdoctection/pipe/concurrency.py +16 -11
  93. deepdoctection/pipe/doctectionpipe.py +24 -21
  94. deepdoctection/pipe/language.py +20 -25
  95. deepdoctection/pipe/layout.py +20 -16
  96. deepdoctection/pipe/lm.py +75 -105
  97. deepdoctection/pipe/order.py +194 -89
  98. deepdoctection/pipe/refine.py +111 -124
  99. deepdoctection/pipe/segment.py +156 -161
  100. deepdoctection/pipe/{cell.py → sub_layout.py} +50 -40
  101. deepdoctection/pipe/text.py +37 -36
  102. deepdoctection/pipe/transform.py +19 -16
  103. deepdoctection/train/__init__.py +6 -12
  104. deepdoctection/train/d2_frcnn_train.py +48 -41
  105. deepdoctection/train/hf_detr_train.py +41 -30
  106. deepdoctection/train/hf_layoutlm_train.py +153 -135
  107. deepdoctection/train/tp_frcnn_train.py +32 -31
  108. deepdoctection/utils/concurrency.py +1 -1
  109. deepdoctection/utils/context.py +13 -6
  110. deepdoctection/utils/develop.py +4 -4
  111. deepdoctection/utils/env_info.py +87 -125
  112. deepdoctection/utils/file_utils.py +6 -11
  113. deepdoctection/utils/fs.py +22 -18
  114. deepdoctection/utils/identifier.py +2 -2
  115. deepdoctection/utils/logger.py +16 -15
  116. deepdoctection/utils/metacfg.py +7 -7
  117. deepdoctection/utils/mocks.py +93 -0
  118. deepdoctection/utils/pdf_utils.py +11 -11
  119. deepdoctection/utils/settings.py +185 -181
  120. deepdoctection/utils/tqdm.py +1 -1
  121. deepdoctection/utils/transform.py +14 -9
  122. deepdoctection/utils/types.py +104 -0
  123. deepdoctection/utils/utils.py +7 -7
  124. deepdoctection/utils/viz.py +74 -72
  125. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/METADATA +30 -21
  126. deepdoctection-0.33.dist-info/RECORD +146 -0
  127. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/WHEEL +1 -1
  128. deepdoctection/utils/detection_types.py +0 -68
  129. deepdoctection-0.31.dist-info/RECORD +0 -144
  130. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/LICENSE +0 -0
  131. {deepdoctection-0.31.dist-info → deepdoctection-0.33.dist-info}/top_level.txt +0 -0
deepdoctection/train/hf_layoutlm_train.py

@@ -18,32 +18,15 @@
  """
  Module for training Huggingface implementation of LayoutLm
  """
+ from __future__ import annotations

  import copy
  import json
  import os
  import pprint
- from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
-
- from torch.nn import Module
- from torch.utils.data import Dataset
- from transformers import (
-     IntervalStrategy,
-     LayoutLMForSequenceClassification,
-     LayoutLMForTokenClassification,
-     LayoutLMTokenizerFast,
-     LayoutLMv2Config,
-     LayoutLMv2ForSequenceClassification,
-     LayoutLMv2ForTokenClassification,
-     LayoutLMv3Config,
-     LayoutLMv3ForSequenceClassification,
-     LayoutLMv3ForTokenClassification,
-     PretrainedConfig,
-     PreTrainedModel,
-     RobertaTokenizerFast,
-     XLMRobertaTokenizerFast,
- )
- from transformers.trainer import Trainer, TrainingArguments
+ from typing import Any, Optional, Sequence, Type, Union
+
+ from lazy_imports import try_import

  from ..datasets.adapter import DatasetAdapter
  from ..datasets.base import DatasetBase
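The import block above is the heart of the 0.33 refactoring: torch, transformers and wandb are no longer imported eagerly but are wrapped further down in `lazy_imports.try_import` guards, so importing the training module no longer fails when an optional backend is missing. A minimal sketch of that pattern, assuming the `lazy-imports` package and using torch as a stand-in for any optional dependency:

    # Sketch only: the lazy-import pattern adopted in 0.33.
    from lazy_imports import try_import

    with try_import() as torch_import_guard:
        import torch  # importing this module no longer raises if torch is absent

    def build_linear_layer():
        # The guard re-raises the captured ImportError only when the optional
        # dependency is actually needed.
        torch_import_guard.check()
        return torch.nn.Linear(4, 2)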
@@ -57,79 +40,109 @@ from ..extern.hflayoutlm import (
      HFLayoutLmv2TokenClassifier,
      HFLayoutLmv3SequenceClassifier,
      HFLayoutLmv3TokenClassifier,
+     HFLiltSequenceClassifier,
+     HFLiltTokenClassifier,
+     get_tokenizer_from_model_class,
  )
- from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features
- from ..pipe.base import LanguageModelPipelineComponent
- from ..pipe.lm import get_tokenizer_from_architecture
+ from ..extern.hflm import HFLmSequenceClassifier
+ from ..extern.pt.ptutils import get_torch_device
+ from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features, image_to_raw_lm_features
+ from ..pipe.base import PipelineComponent
  from ..pipe.registry import pipeline_component_registry
- from ..utils.env_info import get_device
  from ..utils.error import DependencyError
  from ..utils.file_utils import wandb_available
  from ..utils.logger import LoggingRecord, logger
- from ..utils.settings import DatasetType, LayoutType, ObjectTypes, WordType
+ from ..utils.settings import DatasetType, LayoutType, WordType
+ from ..utils.types import PathLikeOrStr
  from ..utils.utils import string_to_dict

- if wandb_available():
-     import wandb
-
- _ARCHITECTURES_TO_MODEL_CLASS = {
-     "LayoutLMForTokenClassification": (LayoutLMForTokenClassification, HFLayoutLmTokenClassifier, PretrainedConfig),
-     "LayoutLMForSequenceClassification": (
-         LayoutLMForSequenceClassification,
-         HFLayoutLmSequenceClassifier,
-         PretrainedConfig,
-     ),
-     "LayoutLMv2ForTokenClassification": (
-         LayoutLMv2ForTokenClassification,
-         HFLayoutLmv2TokenClassifier,
-         LayoutLMv2Config,
-     ),
-     "LayoutLMv2ForSequenceClassification": (
-         LayoutLMv2ForSequenceClassification,
-         HFLayoutLmv2SequenceClassifier,
-         LayoutLMv2Config,
-     ),
- }
-
+ with try_import() as pt_import_guard:
+     from torch import nn
+     from torch.utils.data import Dataset

- _MODEL_TYPE_AND_TASK_TO_MODEL_CLASS: Mapping[Tuple[str, ObjectTypes], Any] = {
-     ("layoutlm", DatasetType.sequence_classification): (
+ with try_import() as tr_import_guard:
+     from transformers import (
+         IntervalStrategy,
          LayoutLMForSequenceClassification,
-         HFLayoutLmSequenceClassifier,
-         PretrainedConfig,
-     ),
-     ("layoutlm", DatasetType.token_classification): (
          LayoutLMForTokenClassification,
-         HFLayoutLmTokenClassifier,
-         PretrainedConfig,
-     ),
-     ("layoutlmv2", DatasetType.sequence_classification): (
-         LayoutLMv2ForSequenceClassification,
-         HFLayoutLmv2SequenceClassifier,
          LayoutLMv2Config,
-     ),
-     ("layoutlmv2", DatasetType.token_classification): (
+         LayoutLMv2ForSequenceClassification,
          LayoutLMv2ForTokenClassification,
-         HFLayoutLmv2TokenClassifier,
-         LayoutLMv2Config,
-     ),
-     ("layoutlmv3", DatasetType.sequence_classification): (
-         LayoutLMv3ForSequenceClassification,
-         HFLayoutLmv3SequenceClassifier,
          LayoutLMv3Config,
-     ),
-     ("layoutlmv3", DatasetType.token_classification): (
+         LayoutLMv3ForSequenceClassification,
          LayoutLMv3ForTokenClassification,
-         HFLayoutLmv3TokenClassifier,
-         LayoutLMv3Config,
-     ),
- }
- _MODEL_TYPE_TO_TOKENIZER = {
-     ("layoutlm", False): LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased"),
-     ("layoutlmv2", False): LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased"),
-     ("layoutlmv2", True): XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base", add_prefix_space=True),
-     ("layoutlmv3", False): RobertaTokenizerFast.from_pretrained("roberta-base", add_prefix_space=True),
- }
+         LiltForSequenceClassification,
+         LiltForTokenClassification,
+         PretrainedConfig,
+         PreTrainedModel,
+         XLMRobertaForSequenceClassification,
+     )
+     from transformers.trainer import Trainer, TrainingArguments
+
+ with try_import() as wb_import_guard:
+     import wandb
+
+
+ def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> tuple[Any, Any, Any]:
+     """
+     Get the model architecture, model wrapper and config class for a given model type and dataset type.
+
+     :param model_type: The model type
+     :param dataset_type: The dataset type
+     :return: Tuple of model architecture, model wrapper and config class
+     """
+     return {
+         ("layoutlm", DatasetType.SEQUENCE_CLASSIFICATION): (
+             LayoutLMForSequenceClassification,
+             HFLayoutLmSequenceClassifier,
+             PretrainedConfig,
+         ),
+         ("layoutlm", DatasetType.TOKEN_CLASSIFICATION): (
+             LayoutLMForTokenClassification,
+             HFLayoutLmTokenClassifier,
+             PretrainedConfig,
+         ),
+         ("layoutlmv2", DatasetType.SEQUENCE_CLASSIFICATION): (
+             LayoutLMv2ForSequenceClassification,
+             HFLayoutLmv2SequenceClassifier,
+             LayoutLMv2Config,
+         ),
+         ("layoutlmv2", DatasetType.TOKEN_CLASSIFICATION): (
+             LayoutLMv2ForTokenClassification,
+             HFLayoutLmv2TokenClassifier,
+             LayoutLMv2Config,
+         ),
+         ("layoutlmv3", DatasetType.SEQUENCE_CLASSIFICATION): (
+             LayoutLMv3ForSequenceClassification,
+             HFLayoutLmv3SequenceClassifier,
+             LayoutLMv3Config,
+         ),
+         ("layoutlmv3", DatasetType.TOKEN_CLASSIFICATION): (
+             LayoutLMv3ForTokenClassification,
+             HFLayoutLmv3TokenClassifier,
+             LayoutLMv3Config,
+         ),
+         ("lilt", DatasetType.TOKEN_CLASSIFICATION): (
+             LiltForTokenClassification,
+             HFLiltTokenClassifier,
+             PretrainedConfig,
+         ),
+         ("lilt", DatasetType.SEQUENCE_CLASSIFICATION): (
+             LiltForSequenceClassification,
+             HFLiltSequenceClassifier,
+             PretrainedConfig,
+         ),
+         ("xlm-roberta", DatasetType.SEQUENCE_CLASSIFICATION): (
+             XLMRobertaForSequenceClassification,
+             HFLmSequenceClassifier,
+             PretrainedConfig,
+         ),
+     }[(model_type, dataset_type)]
+
+
+ def maybe_remove_bounding_box_features(model_type: str) -> bool:
+     """Listing of models that do not need bounding box features."""
+     return {"xlm-roberta": True}.get(model_type, False)


  class LayoutLMTrainer(Trainer):
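The two module-level lookup tables from 0.31 are folded into a single function keyed by `(model_type, dataset_type)`, with LiLT and plain XLM-RoBERTa added as new entries. A hedged sketch of how the trainer can resolve its classes from a HuggingFace `config.json` using the function above; the file path is a placeholder and the import locations are assumed from the file list:

    # Illustrative only: resolve model, wrapper and config classes from a config.json.
    import json

    from deepdoctection.train.hf_layoutlm_train import get_model_architectures_and_configs
    from deepdoctection.utils.settings import DatasetType

    with open("/models/layoutlm-base-uncased/config.json", encoding="UTF-8") as file:  # hypothetical path
        model_type = json.load(file)["model_type"]  # e.g. "layoutlm", "lilt" or "xlm-roberta"

    model_cls, wrapper_cls, config_cls = get_model_architectures_and_configs(
        model_type, DatasetType.TOKEN_CLASSIFICATION
    )
    # ("layoutlm", TOKEN_CLASSIFICATION) yields (LayoutLMForTokenClassification,
    # HFLayoutLmTokenClassifier, PretrainedConfig); an unknown pair raises a KeyError.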
@@ -145,21 +158,21 @@ class LayoutLMTrainer(Trainer):

      def __init__(
          self,
-         model: Union[PreTrainedModel, Module],
+         model: Union[PreTrainedModel, nn.Module],
          args: TrainingArguments,
          data_collator: LayoutLMDataCollator,
          train_dataset: Dataset[Any],
      ):
          self.evaluator: Optional[Evaluator] = None
-         self.build_eval_kwargs: Optional[Dict[str, Any]] = None
+         self.build_eval_kwargs: Optional[dict[str, Any]] = None
          super().__init__(model, args, data_collator, train_dataset)

      def setup_evaluator(
          self,
          dataset_val: DatasetBase,
-         pipeline_component: LanguageModelPipelineComponent,
+         pipeline_component: PipelineComponent,
          metric: Union[Type[ClassificationMetric], ClassificationMetric],
-         run: Optional["wandb.sdk.wandb_run.Run"] = None,
+         run: Optional[wandb.sdk.wandb_run.Run] = None,
          **build_eval_kwargs: Union[str, int],
      ) -> None:
          """
@@ -176,15 +189,15 @@ class LayoutLMTrainer(Trainer):
          self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)
          assert self.evaluator.pipe_component
          for comp in self.evaluator.pipe_component.pipe_components:
-             comp.language_model.model = None  # type: ignore
+             comp.clear_predictor()
          self.build_eval_kwargs = build_eval_kwargs

      def evaluate(
          self,
          eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-         ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+         ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
          metric_key_prefix: str = "eval",  # pylint: disable=W0613
-     ) -> Dict[str, float]:
+     ) -> dict[str, float]:
          """
          Overwritten method from `Trainer`. Arguments will not be used.
          """
@@ -208,34 +221,35 @@


  def _get_model_class_and_tokenizer(
-     path_config_json: str, dataset_type: ObjectTypes, use_xlm_tokenizer: bool
- ) -> Tuple[Any, Any, Any, Any]:
+     path_config_json: PathLikeOrStr, dataset_type: DatasetType, use_xlm_tokenizer: bool
+ ) -> tuple[Any, Any, Any, Any, Any]:
      with open(path_config_json, "r", encoding="UTF-8") as file:
          config_json = json.load(file)

-     model_type = config_json.get("model_type")
-
-     if architectures := config_json.get("architectures"):
-         model_cls, model_wrapper_cls, config_cls = _ARCHITECTURES_TO_MODEL_CLASS[architectures[0]]
-         tokenizer_fast = get_tokenizer_from_architecture(architectures[0], use_xlm_tokenizer)
-     elif model_type:
-         model_cls, model_wrapper_cls, config_cls = _MODEL_TYPE_AND_TASK_TO_MODEL_CLASS[(model_type, dataset_type)]
-         tokenizer_fast = _MODEL_TYPE_TO_TOKENIZER[(model_type, use_xlm_tokenizer)]
+     if model_type := config_json.get("model_type"):
+         model_cls, model_wrapper_cls, config_cls = get_model_architectures_and_configs(model_type, dataset_type)
+         remove_box_features = maybe_remove_bounding_box_features(model_type)
      else:
-         raise KeyError("model_type and architectures not available in configs")
+         raise KeyError("model_type not available in configs. It seems that the config is not valid")

-     if not model_cls:
-         raise UserWarning("model not eligible to run with this framework")
+     tokenizer_fast = get_tokenizer_from_model_class(model_cls.__name__, use_xlm_tokenizer)
+     return config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features

-     return config_cls, model_cls, model_wrapper_cls, tokenizer_fast
+
+ def get_image_to_raw_features_mapping(input_str: str) -> Any:
+     """Replacing eval functions"""
+     return {
+         "image_to_raw_layoutlm_features": image_to_raw_layoutlm_features,
+         "image_to_raw_lm_features": image_to_raw_lm_features,
+     }[input_str]


  def train_hf_layoutlm(
-     path_config_json: str,
+     path_config_json: PathLikeOrStr,
      dataset_train: Union[str, DatasetBase],
-     path_weights: str,
-     config_overwrite: Optional[List[str]] = None,
-     log_dir: str = "train_log/layoutlm",
+     path_weights: PathLikeOrStr,
+     config_overwrite: Optional[list[str]] = None,
+     log_dir: PathLikeOrStr = "train_log/layoutlm",
      build_train_config: Optional[Sequence[str]] = None,
      dataset_val: Optional[DatasetBase] = None,
      build_val_config: Optional[Sequence[str]] = None,
@@ -310,13 +324,13 @@ def train_hf_layoutlm(
      appear as child, it will use the word bounding box.
      """

-     build_train_dict: Dict[str, str] = {}
+     build_train_dict: dict[str, str] = {}
      if build_train_config is not None:
          build_train_dict = string_to_dict(",".join(build_train_config))
          if "split" not in build_train_dict:
              build_train_dict["split"] = "train"

-     build_val_dict: Dict[str, str] = {}
+     build_val_dict: dict[str, str] = {}
      if build_val_config is not None:
          build_val_dict = string_to_dict(",".join(build_val_config))
          if "split" not in build_val_dict:
@@ -330,40 +344,43 @@ def train_hf_layoutlm(

      # We wrap our dataset into a torch dataset
      dataset_type = dataset_train.dataset_info.type
-     if dataset_type == DatasetType.sequence_classification:
+     if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
          categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(as_dict=True, name_as_key=True)
-     elif dataset_type == DatasetType.token_classification:
+     elif dataset_type == DatasetType.TOKEN_CLASSIFICATION:
          if use_token_tag:
              categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                 categories=LayoutType.word,
-                 sub_categories={LayoutType.word: [WordType.token_tag]},
+                 categories=LayoutType.WORD,
+                 sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]},
                  keys=False,
                  values_as_dict=True,
                  name_as_key=True,
-             )[LayoutType.word][WordType.token_tag]
+             )[LayoutType.WORD][WordType.TOKEN_TAG]
          else:
              categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                 categories=LayoutType.word,
-                 sub_categories={LayoutType.word: [WordType.token_class]},
+                 categories=LayoutType.WORD,
+                 sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]},
                  keys=False,
                  values_as_dict=True,
                  name_as_key=True,
-             )[LayoutType.word][WordType.token_class]
+             )[LayoutType.WORD][WordType.TOKEN_CLASS]
      else:
          raise UserWarning("Dataset type not supported for training")

-     config_cls, model_cls, model_wrapper_cls, tokenizer_fast = _get_model_class_and_tokenizer(
+     config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features = _get_model_class_and_tokenizer(
          path_config_json, dataset_type, use_xlm_tokenizer
      )
-     image_to_raw_layoutlm_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
+     image_to_raw_features_func = get_image_to_raw_features_mapping(model_wrapper_cls.image_to_raw_features_mapping())
+     image_to_raw_features_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
      if segment_positions:
-         image_to_raw_layoutlm_kwargs["segment_positions"] = segment_positions  # type: ignore
-     image_to_raw_layoutlm_kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
+         image_to_raw_features_kwargs["segment_positions"] = segment_positions  # type: ignore
+     image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_image_to_features_mapping())
+
      dataset = DatasetAdapter(
          dataset_train,
          True,
-         image_to_raw_layoutlm_features(**image_to_raw_layoutlm_kwargs),
+         image_to_raw_features_func(**image_to_raw_features_kwargs),
          use_token_tag,
+         number_repetitions=-1,
          **build_train_dict,
      )

@@ -373,7 +390,7 @@ def train_hf_layoutlm(
      # Need to set remove_unused_columns to False, as the DataCollator for column removal will remove some raw features
      # that are necessary for the tokenizer.
      conf_dict = {
-         "output_dir": log_dir,
+         "output_dir": os.fspath(log_dir),
          "remove_unused_columns": False,
          "per_device_train_batch_size": 8,
          "max_steps": number_samples,
@@ -414,16 +431,16 @@ def train_hf_layoutlm(
      )

      use_wandb = conf_dict.pop("use_wandb")
-     wandb_project = conf_dict.pop("wandb_project")
-     wandb_repo = conf_dict.pop("wandb_repo")
+     wandb_project = str(conf_dict.pop("wandb_project"))
+     wandb_repo = str(conf_dict.pop("wandb_repo"))

      # Initialize Wandb, if necessary
      run = None
      if use_wandb:
          if not wandb_available():
              raise DependencyError("WandB must be installed separately")
-         run = wandb.init(project=wandb_project, config=conf_dict)  # type: ignore
-         run._label(repo=wandb_repo)  # type: ignore  # pylint: disable=W0212
+         run = wandb.init(project=wandb_project, config=conf_dict)
+         run._label(repo=wandb_repo)  # pylint: disable=W0212
      else:
          os.environ["WANDB_DISABLED"] = "True"

@@ -453,32 +470,34 @@ def train_hf_layoutlm(
          return_tensors="pt",
          sliding_window_stride=sliding_window_stride,  # type: ignore
          max_batch_size=max_batch_size,  # type: ignore
+         remove_bounding_box_features=remove_box_features,
      )
      trainer = LayoutLMTrainer(model, arguments, data_collator, dataset)

      if arguments.evaluation_strategy in (IntervalStrategy.STEPS,):
          assert metric is not None  # silence mypy
-         if dataset_type == DatasetType.sequence_classification:
+         if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
              categories = dataset_val.dataflow.categories.get_categories(filtered=True)  # type: ignore
          else:
              if use_token_tag:
                  categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                     categories=LayoutType.word, sub_categories={LayoutType.word: [WordType.token_tag]}, keys=False
-                 )[LayoutType.word][WordType.token_tag]
-                 metric.set_categories(category_names=LayoutType.word, sub_category_names={"word": ["token_tag"]})
+                     categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]}, keys=False
+                 )[LayoutType.WORD][WordType.TOKEN_TAG]
+                 metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_tag"]})
              else:
                  categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                     categories=LayoutType.word, sub_categories={LayoutType.word: [WordType.token_class]}, keys=False
-                 )[LayoutType.word][WordType.token_class]
-                 metric.set_categories(category_names=LayoutType.word, sub_category_names={"word": ["token_class"]})
+                     categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]}, keys=False
+                 )[LayoutType.WORD][WordType.TOKEN_CLASS]
+                 metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_class"]})
          dd_model = model_wrapper_cls(
              path_config_json=path_config_json,
              path_weights=path_weights,
              categories=categories,
-             device=get_device(),
+             device=get_torch_device(),
+             use_xlm_tokenizer=use_xlm_tokenizer,
          )
          pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
-         if dataset_type == DatasetType.sequence_classification:
+         if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
              pipeline_component = pipeline_component_cls(tokenizer_fast, dd_model)
          else:
              pipeline_component = pipeline_component_cls(
@@ -487,7 +506,6 @@ def train_hf_layoutlm(
                  use_other_as_default_category=True,
                  sliding_window_stride=sliding_window_stride,
              )
-         assert isinstance(pipeline_component, LanguageModelPipelineComponent)

          trainer.setup_evaluator(dataset_val, pipeline_component, metric, run, **build_val_dict)  # type: ignore
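The signature changes above (PathLikeOrStr paths, built-in generics, the extra `remove_box_features` return value) do not alter the calling convention of the entry point. A hedged sketch of a call that uses only parameters visible in this diff; the import path is assumed and all paths, dataset and overwrite values are placeholders:

    # Placeholder values throughout; import path assumed, not taken from the diff.
    from deepdoctection.train import train_hf_layoutlm

    train_hf_layoutlm(
        path_config_json="/models/layoutlm-base-uncased/config.json",
        dataset_train="funsd",                            # registered dataset name or a DatasetBase
        path_weights="/models/layoutlm-base-uncased/pytorch_model.bin",
        config_overwrite=["max_steps=2000"],              # illustrative key=value overwrite
        log_dir="train_log/layoutlm",
        build_train_config=["split=train"],
    )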
deepdoctection/train/tp_frcnn_train.py

@@ -20,27 +20,9 @@ Module for training Tensorpack `GeneralizedRCNN`
  """

  import os
- from typing import Dict, List, Optional, Sequence, Type, Union
-
- # pylint: disable=import-error
- from tensorpack.callbacks import (
-     EstimatedTimeLeft,
-     GPUMemoryTracker,
-     GPUUtilizationTracker,
-     HostMemoryTracker,
-     ModelSaver,
-     PeriodicCallback,
-     ScheduledHyperParamSetter,
-     SessionRunTimeout,
-     ThroughputTracker,
- )
-
- # todo: check how dataflow import is directly possible without having AssertionError
- from tensorpack.dataflow import ProxyDataFlow, imgaug
- from tensorpack.input_source import QueueInput
- from tensorpack.tfutils import SmartInit
- from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
- from tensorpack.utils import logger
+ from typing import Optional, Sequence, Type, Union
+
+ from lazy_imports import try_import

  from ..dataflow.base import DataFlow
  from ..dataflow.common import MapData
@@ -58,16 +40,35 @@ from ..extern.tp.tpfrcnn.preproc import anchors_and_labels, augment
  from ..extern.tpdetect import TPFrcnnDetector
  from ..mapper.maputils import LabelSummarizer
  from ..mapper.tpstruct import image_to_tp_frcnn_training
- from ..pipe.base import PredictorPipelineComponent
  from ..pipe.registry import pipeline_component_registry
- from ..utils.detection_types import JsonDict
  from ..utils.file_utils import set_mp_spawn
  from ..utils.fs import get_load_image_func
  from ..utils.logger import log_once
  from ..utils.metacfg import AttrDict, set_config_by_yaml
  from ..utils.tqdm import get_tqdm
+ from ..utils.types import JsonDict, PathLikeOrStr
  from ..utils.utils import string_to_dict

+ with try_import() as tp_import_guard:
+     # todo: check how dataflow import is directly possible without having an AssertionError
+     # pylint: disable=import-error
+     from tensorpack.callbacks import (
+         EstimatedTimeLeft,
+         GPUMemoryTracker,
+         GPUUtilizationTracker,
+         HostMemoryTracker,
+         ModelSaver,
+         PeriodicCallback,
+         ScheduledHyperParamSetter,
+         SessionRunTimeout,
+         ThroughputTracker,
+     )
+     from tensorpack.dataflow import ProxyDataFlow, imgaug
+     from tensorpack.input_source import QueueInput
+     from tensorpack.tfutils import SmartInit
+     from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
+     from tensorpack.utils import logger
+
  __all__ = ["train_faster_rcnn"]


@@ -183,11 +184,11 @@ def get_train_dataflow(


  def train_faster_rcnn(
-     path_config_yaml: str,
+     path_config_yaml: PathLikeOrStr,
      dataset_train: DatasetBase,
-     path_weights: str = "",
-     config_overwrite: Optional[List[str]] = None,
-     log_dir: str = "train_log/frcnn",
+     path_weights: PathLikeOrStr,
+     config_overwrite: Optional[list[str]] = None,
+     log_dir: PathLikeOrStr = "train_log/frcnn",
      build_train_config: Optional[Sequence[str]] = None,
      dataset_val: Optional[DatasetBase] = None,
      build_val_config: Optional[Sequence[str]] = None,
@@ -222,13 +223,13 @@ def train_faster_rcnn(

      assert disable_tfv2()  # TP works only in Graph mode

-     build_train_dict: Dict[str, str] = {}
+     build_train_dict: dict[str, str] = {}
      if build_train_config is not None:
          build_train_dict = string_to_dict(",".join(build_train_config))
          if "split" not in build_train_dict:
              build_train_dict["split"] = "train"

-     build_val_dict: Dict[str, str] = {}
+     build_val_dict: dict[str, str] = {}
      if build_val_config is not None:
          build_val_dict = string_to_dict(",".join(build_val_config))
          if "split" not in build_val_dict:
@@ -236,7 +237,7 @@ def train_faster_rcnn(

      config_overwrite = [] if config_overwrite is None else config_overwrite

-     log_dir = "TRAIN.LOG_DIR=" + log_dir
+     log_dir = "TRAIN.LOG_DIR=" + os.fspath(log_dir)
      config_overwrite.append(log_dir)

      config = set_config_by_yaml(path_config_yaml)
@@ -297,7 +298,6 @@ def train_faster_rcnn(
      )  # only a wrapper for the predictor itself. Will be replaced in Callback
      pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
      pipeline_component = pipeline_component_cls(detector)
-     assert isinstance(pipeline_component, PredictorPipelineComponent)
      category_names = list(categories.values())
      callbacks.extend(
          [
@@ -308,6 +308,7 @@ def train_faster_rcnn(
                  metric,  # type: ignore
                  pipeline_component,
                  *model.get_inference_tensor_names(),  # type: ignore
+                 cfg=detector.model.cfg,
                  **build_val_dict
              )
          ]
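Note that `path_weights` loses its empty-string default in 0.33, so a checkpoint must now be passed explicitly. A hedged sketch with placeholder paths, again restricted to parameters visible above; the import paths are assumptions:

    # Placeholder paths; path_weights is now a required argument.
    from deepdoctection.datasets import get_dataset      # assumed registry helper
    from deepdoctection.train import train_faster_rcnn   # assumed public import path

    train_faster_rcnn(
        path_config_yaml="/configs/tp/conf_frcnn_layout.yaml",      # hypothetical config
        dataset_train=get_dataset("publaynet"),
        path_weights="/weights/model-1240000.data-00000-of-00001",  # hypothetical checkpoint
        log_dir="train_log/frcnn",
        build_train_config=["split=train"],
    )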
deepdoctection/utils/concurrency.py

@@ -28,8 +28,8 @@ import threading
  from contextlib import contextmanager
  from typing import Any, Generator, Optional, no_type_check

- from .detection_types import QueueType
  from .logger import log_once
+ from .types import QueueType


  # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
deepdoctection/utils/context.py

@@ -26,12 +26,12 @@ from glob import iglob
  from os import path, remove
  from tempfile import NamedTemporaryFile
  from time import perf_counter as timer
- from typing import Any, Generator, Iterator, Optional, Tuple, Union
+ from typing import Any, Generator, Iterator, Optional, Union

  import numpy as np

- from .detection_types import ImageType
  from .logger import LoggingRecord, logger
+ from .types import B64, B64Str, PixelValues
  from .viz import viz_handler

  __all__ = ["timeout_manager", "save_tmp_file", "timed_operation"]
@@ -72,7 +72,7 @@ def timeout_manager(proc, seconds: Optional[int] = None) -> Iterator[str]:  # ty


  @contextmanager
- def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[Tuple[str, str]]:
+ def save_tmp_file(image: Union[B64Str, PixelValues, B64], prefix: str) -> Iterator[tuple[str, str]]:
      """
      Save image temporarily and handle the clean-up once not necessary anymore

@@ -112,13 +112,20 @@ def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[
  @contextmanager
  def timed_operation(message: str, log_start: bool = False) -> Generator[Any, None, None]:
      """
-     Contextmanager with a timer. Can therefore be used in a with statement.
+     Contextmanager with a timer.

-     :param message: a log to print
+     ... code-block:: python
+
+         with timed_operation(message="Your stdout message", log_start=True):
+
+             with open("log.txt", "a") as file:
+                 ...
+
+
+     :param message: a log to stdout
      :param log_start: whether to print also the beginning
      """

-     assert len(message)
      if log_start:
          logger.info(LoggingRecord(f"start task: {message} ..."))
      start = timer()
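The docstring example added above translates directly into a runnable snippet; a minimal sketch using the module's public import path:

    # Minimal usage of timed_operation: elapsed time is logged when the block exits,
    # and the start is logged as well because log_start=True.
    from deepdoctection.utils.context import timed_operation

    with timed_operation(message="counting", log_start=True):
        total = sum(range(10_000))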
deepdoctection/utils/develop.py

@@ -26,19 +26,19 @@ import functools
  import inspect
  from collections import defaultdict
  from datetime import datetime
- from typing import Callable, List, Optional
+ from typing import Callable, Optional

- from .detection_types import T
  from .logger import LoggingRecord, logger
+ from .types import T

- __all__: List[str] = ["deprecated"]
+ __all__: list[str] = ["deprecated"]

  # Copy and paste from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py

  _DEPRECATED_LOG_NUM = defaultdict(int)  # type: ignore


- def log_deprecated(name: str = "", text: str = "", eos: str = "", max_num_warnings: Optional[int] = None) -> None:
+ def log_deprecated(name: str, text: str, eos: str = "", max_num_warnings: Optional[int] = None) -> None:
      """
      Log deprecation warning.