deepdoctection-0.42.0-py3-none-any.whl → deepdoctection-0.43-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (124)
  1. deepdoctection/__init__.py +2 -1
  2. deepdoctection/analyzer/__init__.py +2 -1
  3. deepdoctection/analyzer/config.py +904 -0
  4. deepdoctection/analyzer/dd.py +36 -62
  5. deepdoctection/analyzer/factory.py +311 -141
  6. deepdoctection/configs/conf_dd_one.yaml +100 -44
  7. deepdoctection/configs/profiles.jsonl +32 -0
  8. deepdoctection/dataflow/__init__.py +9 -6
  9. deepdoctection/dataflow/base.py +33 -15
  10. deepdoctection/dataflow/common.py +96 -75
  11. deepdoctection/dataflow/custom.py +36 -29
  12. deepdoctection/dataflow/custom_serialize.py +135 -91
  13. deepdoctection/dataflow/parallel_map.py +33 -31
  14. deepdoctection/dataflow/serialize.py +15 -10
  15. deepdoctection/dataflow/stats.py +41 -28
  16. deepdoctection/datapoint/__init__.py +4 -6
  17. deepdoctection/datapoint/annotation.py +104 -66
  18. deepdoctection/datapoint/box.py +190 -130
  19. deepdoctection/datapoint/convert.py +66 -39
  20. deepdoctection/datapoint/image.py +151 -95
  21. deepdoctection/datapoint/view.py +383 -236
  22. deepdoctection/datasets/__init__.py +2 -6
  23. deepdoctection/datasets/adapter.py +11 -11
  24. deepdoctection/datasets/base.py +118 -81
  25. deepdoctection/datasets/dataflow_builder.py +18 -12
  26. deepdoctection/datasets/info.py +76 -57
  27. deepdoctection/datasets/instances/__init__.py +6 -2
  28. deepdoctection/datasets/instances/doclaynet.py +17 -14
  29. deepdoctection/datasets/instances/fintabnet.py +16 -22
  30. deepdoctection/datasets/instances/funsd.py +11 -6
  31. deepdoctection/datasets/instances/iiitar13k.py +9 -9
  32. deepdoctection/datasets/instances/layouttest.py +9 -9
  33. deepdoctection/datasets/instances/publaynet.py +9 -9
  34. deepdoctection/datasets/instances/pubtables1m.py +13 -13
  35. deepdoctection/datasets/instances/pubtabnet.py +13 -15
  36. deepdoctection/datasets/instances/rvlcdip.py +8 -8
  37. deepdoctection/datasets/instances/xfund.py +11 -9
  38. deepdoctection/datasets/registry.py +18 -11
  39. deepdoctection/datasets/save.py +12 -11
  40. deepdoctection/eval/__init__.py +3 -2
  41. deepdoctection/eval/accmetric.py +72 -52
  42. deepdoctection/eval/base.py +29 -10
  43. deepdoctection/eval/cocometric.py +14 -12
  44. deepdoctection/eval/eval.py +56 -41
  45. deepdoctection/eval/registry.py +6 -3
  46. deepdoctection/eval/tedsmetric.py +24 -9
  47. deepdoctection/eval/tp_eval_callback.py +13 -12
  48. deepdoctection/extern/__init__.py +1 -1
  49. deepdoctection/extern/base.py +176 -97
  50. deepdoctection/extern/d2detect.py +127 -92
  51. deepdoctection/extern/deskew.py +19 -10
  52. deepdoctection/extern/doctrocr.py +157 -106
  53. deepdoctection/extern/fastlang.py +25 -17
  54. deepdoctection/extern/hfdetr.py +137 -60
  55. deepdoctection/extern/hflayoutlm.py +329 -248
  56. deepdoctection/extern/hflm.py +67 -33
  57. deepdoctection/extern/model.py +108 -762
  58. deepdoctection/extern/pdftext.py +37 -12
  59. deepdoctection/extern/pt/nms.py +15 -1
  60. deepdoctection/extern/pt/ptutils.py +13 -9
  61. deepdoctection/extern/tessocr.py +87 -54
  62. deepdoctection/extern/texocr.py +29 -14
  63. deepdoctection/extern/tp/tfutils.py +36 -8
  64. deepdoctection/extern/tp/tpcompat.py +54 -16
  65. deepdoctection/extern/tp/tpfrcnn/config/config.py +20 -4
  66. deepdoctection/extern/tpdetect.py +4 -2
  67. deepdoctection/mapper/__init__.py +1 -1
  68. deepdoctection/mapper/cats.py +117 -76
  69. deepdoctection/mapper/cocostruct.py +35 -17
  70. deepdoctection/mapper/d2struct.py +56 -29
  71. deepdoctection/mapper/hfstruct.py +32 -19
  72. deepdoctection/mapper/laylmstruct.py +221 -185
  73. deepdoctection/mapper/maputils.py +71 -35
  74. deepdoctection/mapper/match.py +76 -62
  75. deepdoctection/mapper/misc.py +68 -44
  76. deepdoctection/mapper/pascalstruct.py +13 -12
  77. deepdoctection/mapper/prodigystruct.py +33 -19
  78. deepdoctection/mapper/pubstruct.py +42 -32
  79. deepdoctection/mapper/tpstruct.py +39 -19
  80. deepdoctection/mapper/xfundstruct.py +20 -13
  81. deepdoctection/pipe/__init__.py +1 -2
  82. deepdoctection/pipe/anngen.py +104 -62
  83. deepdoctection/pipe/base.py +226 -107
  84. deepdoctection/pipe/common.py +206 -123
  85. deepdoctection/pipe/concurrency.py +74 -47
  86. deepdoctection/pipe/doctectionpipe.py +108 -47
  87. deepdoctection/pipe/language.py +41 -24
  88. deepdoctection/pipe/layout.py +45 -18
  89. deepdoctection/pipe/lm.py +146 -78
  90. deepdoctection/pipe/order.py +196 -113
  91. deepdoctection/pipe/refine.py +111 -63
  92. deepdoctection/pipe/registry.py +1 -1
  93. deepdoctection/pipe/segment.py +213 -142
  94. deepdoctection/pipe/sub_layout.py +76 -46
  95. deepdoctection/pipe/text.py +52 -33
  96. deepdoctection/pipe/transform.py +8 -6
  97. deepdoctection/train/d2_frcnn_train.py +87 -69
  98. deepdoctection/train/hf_detr_train.py +72 -40
  99. deepdoctection/train/hf_layoutlm_train.py +85 -46
  100. deepdoctection/train/tp_frcnn_train.py +56 -28
  101. deepdoctection/utils/concurrency.py +59 -16
  102. deepdoctection/utils/context.py +40 -19
  103. deepdoctection/utils/develop.py +25 -17
  104. deepdoctection/utils/env_info.py +85 -36
  105. deepdoctection/utils/error.py +16 -10
  106. deepdoctection/utils/file_utils.py +246 -62
  107. deepdoctection/utils/fs.py +162 -43
  108. deepdoctection/utils/identifier.py +29 -16
  109. deepdoctection/utils/logger.py +49 -32
  110. deepdoctection/utils/metacfg.py +83 -21
  111. deepdoctection/utils/pdf_utils.py +119 -62
  112. deepdoctection/utils/settings.py +24 -10
  113. deepdoctection/utils/tqdm.py +10 -5
  114. deepdoctection/utils/transform.py +182 -46
  115. deepdoctection/utils/utils.py +61 -28
  116. deepdoctection/utils/viz.py +150 -104
  117. deepdoctection-0.43.dist-info/METADATA +376 -0
  118. deepdoctection-0.43.dist-info/RECORD +149 -0
  119. {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/WHEEL +1 -1
  120. deepdoctection/analyzer/_config.py +0 -146
  121. deepdoctection-0.42.0.dist-info/METADATA +0 -431
  122. deepdoctection-0.42.0.dist-info/RECORD +0 -148
  123. {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/licenses/LICENSE +0 -0
  124. {deepdoctection-0.42.0.dist-info → deepdoctection-0.43.dist-info}/top_level.txt +0 -0
@@ -16,8 +16,7 @@
16
16
  # limitations under the License.
17
17
 
18
18
  """
19
- Module for training Hugging Face Detr implementation. Note, that this scripts only trans Tabletransformer like Detr
20
- models that are a slightly different from the plain Detr model that are provided by the transformer library.
19
+ Fine-tuning Hugging Face Detr implementation.
21
20
  """
22
21
  from __future__ import annotations
23
22
 
@@ -50,6 +49,7 @@ with try_import() as pt_import_guard:
50
49
  with try_import() as hf_import_guard:
51
50
  from transformers import (
52
51
  AutoFeatureExtractor,
52
+ DeformableDetrForObjectDetection,
53
53
  IntervalStrategy,
54
54
  PretrainedConfig,
55
55
  PreTrainedModel,
@@ -65,12 +65,11 @@ with try_import() as wb_import_guard:
65
65
  class DetrDerivedTrainer(Trainer):
66
66
  """
67
67
  Huggingface Trainer for training Transformer models with a custom evaluate method in order
68
- to use dd Evaluator. Train setting is not defined in the trainer itself but in config setting as
69
- defined in `TrainingArguments`. Please check the Transformer documentation
68
+ to use dd Evaluator.
70
69
 
71
- <https://huggingface.co/docs/transformers/main_classes/trainer>
72
-
73
- for custom training setting.
70
+ Train setting is not defined in the trainer itself but in config setting as defined in `TrainingArguments`.
71
+ Please check the Transformer documentation: https://huggingface.co/docs/transformers/main_classes/trainer for
72
+ custom training setting.
74
73
  """
75
74
 
76
75
  def __init__(
@@ -81,6 +80,16 @@ class DetrDerivedTrainer(Trainer):
81
80
  train_dataset: DatasetAdapter,
82
81
  eval_dataset: Optional[DatasetBase] = None,
83
82
  ):
83
+ """
84
+ Initializes `DetrDerivedTrainer`.
85
+
86
+ Args:
87
+ model: Model to be trained, either `PreTrainedModel` or `nn.Module`.
88
+ args: Training arguments.
89
+ data_collator: Data collator for Detr.
90
+ train_dataset: Training dataset.
91
+ eval_dataset: Optional evaluation dataset.
92
+ """
84
93
  self.evaluator: Optional[Evaluator] = None
85
94
  self.build_eval_kwargs: Optional[dict[str, Any]] = None
86
95
  super().__init__(model, args, data_collator, train_dataset, eval_dataset=eval_dataset)
@@ -94,14 +103,16 @@ class DetrDerivedTrainer(Trainer):
94
103
  **build_eval_kwargs: Union[str, int],
95
104
  ) -> None:
96
105
  """
97
- Setup of evaluator before starting training. During training, predictors will be replaced by current
98
- checkpoints.
99
-
100
- :param dataset_val: dataset on which to run evaluation
101
- :param pipeline_component: pipeline component to plug into the evaluator
102
- :param metric: A metric class
103
- :param run: WandB run
104
- :param build_eval_kwargs:
106
+ Setup of evaluator before starting training.
107
+
108
+ During training, predictors will be replaced by current checkpoints.
109
+
110
+ Args:
111
+ dataset_val: Dataset on which to run evaluation.
112
+ pipeline_component: Pipeline component to plug into the evaluator.
113
+ metric: A metric class.
114
+ run: WandB run.
115
+ **build_eval_kwargs: Additional keyword arguments for evaluation.
105
116
  """
106
117
 
107
118
  self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)
@@ -152,29 +163,32 @@ def train_hf_detr(
152
163
  ) -> None:
153
164
  """
154
165
  Train Tabletransformer from scratch or fine-tune using an adaptation of the transformer trainer.
166
+
155
167
  Allowing experiments by using different config settings.
156
168
 
157
- :param path_config_json: path to a Tabletransformer config file
158
- :param dataset_train: dataset to use for training
159
- :param path_weights: path to a checkpoint, if you want to resume training or fine-tune. Will train from scratch if
160
- an empty string is passed
161
- :param path_feature_extractor_config_json: path to a feature extractor config file. In many situations you can use
162
- the standard config file:
163
-
164
- ModelCatalog.
165
- get_full_path_preprocessor_configs
166
- ("microsoft/table-transformer-detection/pytorch_model.bin")
167
-
168
- :param config_overwrite: Pass a list of arguments if some configs from the .json file are supposed to be replaced.
169
- Use the list convention, e.g. ['per_device_train_batch_size=4']
170
- :param log_dir: Will default to 'train_log/detr'
171
- :param build_train_config: dataflow build setting. Again, use list convention setting, e.g. ['max_datapoints=1000']
172
- :param dataset_val: the dataset to use for validation
173
- :param build_val_config: same as `build_train_config` but for dataflow validation
174
- :param metric_name: A metric name to choose for validation. Will use the default setting. If you want a custom
175
- metric setting, pass a metric explicitly.
176
- :param metric: A metric to choose for validation
177
- :param pipeline_component_name: A pipeline component name to use for validation
169
+ Args:
170
+ path_config_json: Path to a Tabletransformer config file.
171
+ dataset_train: Dataset to use for training.
172
+ path_weights: Path to a checkpoint, if you want to resume training or fine-tune. Will train from scratch if an
173
+ empty string is passed.
174
+ path_feature_extractor_config_json: Path to a feature extractor config file. In many situations you can use the
175
+ standard config file:
176
+ Example:
177
+ ```python
178
+ ModelCatalog.get_full_path_preprocessor_configs
179
+ ("microsoft/table-transformer-detection/pytorch_model.bin")
180
+ ```
181
+
182
+ config_overwrite: Pass a list of arguments if some configs from the .json file are supposed to be replaced.
183
+ Use the list convention, e.g. `['per_device_train_batch_size=4']`.
184
+ log_dir: Will default to `train_log/detr`.
185
+ build_train_config: Dataflow build setting. Again, use list convention setting, e.g. `['max_datapoints=1000']`.
186
+ dataset_val: The dataset to use for validation.
187
+ build_val_config: Same as `build_train_config` but for dataflow validation.
188
+ metric_name: A metric name to choose for validation. Will use the default setting.
189
+ If you want a custom metric setting, pass a metric explicitly.
190
+ metric: A metric to choose for validation.
191
+ pipeline_component_name: A pipeline component name to use for validation.
178
192
  """
179
193
 
180
194
  build_train_dict: dict[str, str] = {}
@@ -275,11 +289,29 @@ def train_hf_detr(
275
289
  config.use_timm_backbone = True
276
290
 
277
291
  if path_weights != "":
278
- model = TableTransformerForObjectDetection.from_pretrained(
279
- pretrained_model_name_or_path=path_weights, config=config, ignore_mismatched_sizes=True
280
- )
292
+ if "TableTransformerForObjectDetection" in config.architectures:
293
+ model = TableTransformerForObjectDetection.from_pretrained(
294
+ pretrained_model_name_or_path=path_weights, config=config, ignore_mismatched_sizes=True
295
+ )
296
+ elif "DeformableDetrForObjectDetection" in config.architectures:
297
+ return DeformableDetrForObjectDetection.from_pretrained(
298
+ pretrained_model_name_or_path=os.fspath(path_weights), config=config
299
+ )
300
+ else:
301
+ raise ValueError(
302
+ f"Model architecture {config.architectures} not eligible. Please use either "
303
+ "TableTransformerForObjectDetection or DeformableDetrForObjectDetection."
304
+ )
281
305
  else:
282
- model = TableTransformerForObjectDetection(config)
306
+ if "TableTransformerForObjectDetection" in config.architectures:
307
+ model = TableTransformerForObjectDetection(config)
308
+ elif "DeformableDetrForObjectDetection" in config.architectures:
309
+ model = DeformableDetrForObjectDetection(config)
310
+ else:
311
+ raise ValueError(
312
+ f"Model architecture {config.architectures} not eligible. Please use either "
313
+ "TableTransformerForObjectDetection or DeformableDetrForObjectDetection."
314
+ )
283
315
 
284
316
  feature_extractor = AutoFeatureExtractor.from_pretrained(
285
317
  pretrained_model_name_or_path=path_feature_extractor_config_json
@@ -16,7 +16,10 @@
16
16
  # limitations under the License.
17
17
 
18
18
  """
19
- Module for training Huggingface implementation of LayoutLm
19
+ Fine-tuning Huggingface implementation of LayoutLm.
20
+
21
+ This module provides functions and classes for fine-tuning LayoutLM models for sequence or token classification using
22
+ the Huggingface Trainer and custom evaluation. It supports LayoutLM, LayoutLMv2, LayoutLMv3, and LayoutXLM models.
20
23
  """
21
24
  from __future__ import annotations
22
25
 
@@ -85,11 +88,14 @@ with try_import() as wb_import_guard:
85
88
 
86
89
  def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> tuple[Any, Any, Any]:
87
90
  """
88
- Get the model architecture, model wrapper and config class for a given model type and dataset type.
91
+ Gets the model architecture, model wrapper, and config class for a given `model_type` and `dataset_type`.
92
+
93
+ Args:
94
+ model_type: The model type.
95
+ dataset_type: The dataset type.
89
96
 
90
- :param model_type: The model type
91
- :param dataset_type: The dataset type
92
- :return: Tuple of model architecture, model wrapper and config class
97
+ Returns:
98
+ Tuple of model architecture, model wrapper, and config class.
93
99
  """
94
100
  return {
95
101
  ("layoutlm", DatasetType.SEQUENCE_CLASSIFICATION): (
@@ -141,19 +147,28 @@ def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetTy
141
147
 
142
148
 
143
149
  def maybe_remove_bounding_box_features(model_type: str) -> bool:
144
- """Listing of models that do not need bounding box features."""
150
+ """
151
+ Lists models that do not need bounding box features.
152
+
153
+ Args:
154
+ model_type: The model type.
155
+
156
+ Returns:
157
+ Whether the model does not need bounding box features.
158
+ """
145
159
  return {"xlm-roberta": True}.get(model_type, False)
146
160
 
147
161
 
148
162
  class LayoutLMTrainer(Trainer):
149
163
  """
150
- Huggingface Trainer for training Transformer models with a custom evaluate method in order
151
- to use dd Evaluator. Train setting is not defined in the trainer itself but in config setting as
152
- defined in `TrainingArguments`. Please check the Transformer documentation
164
+ Huggingface Trainer for training Transformer models with a custom evaluate method to use the Deepdoctection
165
+ Evaluator.
153
166
 
154
- <https://huggingface.co/docs/transformers/main_classes/trainer>
167
+ Train settings are not defined in the trainer itself but in the config setting as defined in `TrainingArguments`.
168
+ Please check the Transformer documentation for custom training settings.
155
169
 
156
- for custom training setting.
170
+ Info:
171
+ https://huggingface.co/docs/transformers/main_classes/trainer
157
172
  """
158
173
 
159
174
  def __init__(
@@ -164,6 +179,16 @@ class LayoutLMTrainer(Trainer):
164
179
  train_dataset: DatasetAdapter,
165
180
  eval_dataset: Optional[DatasetBase] = None,
166
181
  ):
182
+ """
183
+ Initializes the `LayoutLMTrainer`.
184
+
185
+ Args:
186
+ model: The model to train.
187
+ args: Training arguments.
188
+ data_collator: Data collator for batching.
189
+ train_dataset: Training dataset.
190
+ eval_dataset: Optional evaluation dataset.
191
+ """
167
192
  self.evaluator: Optional[Evaluator] = None
168
193
  self.build_eval_kwargs: Optional[dict[str, Any]] = None
169
194
  super().__init__(model, args, data_collator, train_dataset, eval_dataset=eval_dataset)
@@ -177,14 +202,15 @@ class LayoutLMTrainer(Trainer):
177
202
  **build_eval_kwargs: Union[str, int],
178
203
  ) -> None:
179
204
  """
180
- Setup of evaluator before starting training. During training, predictors will be replaced by current
205
+ Sets up the evaluator before starting training. During training, predictors will be replaced by current
181
206
  checkpoints.
182
207
 
183
- :param dataset_val: dataset on which to run evaluation
184
- :param pipeline_component: pipeline component to plug into the evaluator
185
- :param metric: A metric class
186
- :param run: WandB run
187
- :param build_eval_kwargs:
208
+ Args:
209
+ dataset_val: Dataset on which to run evaluation.
210
+ pipeline_component: Pipeline component to plug into the evaluator.
211
+ metric: A metric class.
212
+ run: WandB run.
213
+ **build_eval_kwargs: Additional keyword arguments for evaluation.
188
214
  """
189
215
 
190
216
  self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)
@@ -201,6 +227,14 @@ class LayoutLMTrainer(Trainer):
201
227
  ) -> dict[str, float]:
202
228
  """
203
229
  Overwritten method from `Trainer`. Arguments will not be used.
230
+
231
+ Args:
232
+ eval_dataset: Not used.
233
+ ignore_keys: Not used.
234
+ metric_key_prefix: Not used.
235
+
236
+ Returns:
237
+ Evaluation scores as a dictionary.
204
238
  """
205
239
  if self.evaluator is None:
206
240
  raise ValueError("Evaluator not set up. Please use `setup_evaluator` before running evaluation")
@@ -266,28 +300,32 @@ def train_hf_layoutlm(
266
300
  LayoutXLM. Training similar but different models like LILT <https://arxiv.org/abs/2202.13669> can be done by
267
301
  changing a few lines of code regarding the selection of the tokenizer.
268
302
 
269
- The theoretical foundation can be taken from
270
-
271
- <https://arxiv.org/abs/1912.13318>
303
+ Info:
304
+ The theoretical foundation can be taken from <https://arxiv.org/abs/1912.13318>.
272
305
 
273
- This is not the pre-training script.
306
+ This is not the pre-training script.
274
307
 
275
308
  In order to remain within the framework of this library, the base and uncased LayoutLM model must be downloaded
276
309
  from the HF-hub in a first step for fine-tuning. Models are available for this, which are registered in the
277
310
  ModelCatalog. It is possible to choose one of the following options:
278
311
 
279
- "microsoft/layoutlm-base-uncased/pytorch_model.bin"
280
- "microsoft/layoutlmv2-base-uncased/pytorch_model.bin"
281
- "microsoft/layoutxlm-base/pytorch_model.bin"
282
- "microsoft/layoutlmv3-base/pytorch_model.bin"
283
312
 
284
- and
313
+ `microsoft/layoutlm-base-uncased/pytorch_model.bin`
314
+ `microsoft/layoutlmv2-base-uncased/pytorch_model.bin`
315
+ `microsoft/layoutxlm-base/pytorch_model.bin`
316
+ `microsoft/layoutlmv3-base/pytorch_model.bin`
317
+ `microsoft/layoutlm-large-uncased/pytorch_model.bin`
318
+ `SCUT-DLVCLab/lilt-roberta-en-base/pytorch_model.bin`
285
319
 
286
- "microsoft/layoutlm-large-uncased/pytorch_model.bin"
287
320
 
288
- (You can also choose the large versions of LayoutLMv2 and LayoutXLM but you need to organize the download yourself.)
321
+ Note:
322
+ You can also choose the large versions of LayoutLMv2 and LayoutXLM but you need to organize the download
323
+ yourself.
289
324
 
325
+ Example:
326
+ ```python
290
327
  ModelDownloadManager.maybe_download_weights_and_configs("microsoft/layoutlm-base-uncased/pytorch_model.bin")
328
+ ```
291
329
 
292
330
  The corresponding cased models are currently not available, but this is only to keep the model selection small.
293
331
 
@@ -296,30 +334,31 @@ def train_hf_layoutlm(
296
334
  How does the model selection work?
297
335
 
298
336
  The base model is selected by the transferred config file and the weights. Depending on the dataset type
299
- ("SEQUENCE_CLASSIFICATION" or "TOKEN_CLASSIFICATION"), the complete model is then put together by placing a suitable
300
- top layer on the base model.
337
+ `("SEQUENCE_CLASSIFICATION" or "TOKEN_CLASSIFICATION")`, the complete model is then put together by placing a
338
+ suitable top layer on the base model.
301
339
 
302
- :param path_config_json: Absolute path to HF config file, e.g.
303
- ModelCatalog.get_full_path_configs("microsoft/layoutlm-base-uncased/pytorch_model.bin")
304
- :param dataset_train: Dataset to use for training. Only datasets of type "SEQUENCE_CLASSIFICATION" or
340
+ Args:
341
+ path_config_json: Absolute path to HF config file, e.g.
342
+ `ModelCatalog.get_full_path_configs("microsoft/layoutlm-base-uncased/pytorch_model.bin")`
343
+ dataset_train: Dataset to use for training. Only datasets of type "SEQUENCE_CLASSIFICATION" or
305
344
  "TOKEN_CLASSIFICATION" are supported.
306
- :param path_weights: path to a checkpoint for further fine-tuning
307
- :param config_overwrite: Pass a list of arguments if some configs from `TrainingArguments` should be replaced. Check
308
- https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
345
+ path_weights: path to a checkpoint for further fine-tuning
346
+ config_overwrite: Pass a list of arguments if some configs from `TrainingArguments` should be replaced. Check
347
+ <https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments>
309
348
  for the full training default setting.
310
- :param log_dir: Path to log dir. Will default to `train_log/layoutlm`
311
- :param build_train_config: dataflow build setting. Again, use list convention setting, e.g. ['max_datapoints=1000']
312
- :param dataset_val: Dataset to use for validation. Dataset type must be the same as type of `dataset_train`
313
- :param build_val_config: same as `build_train_config` but for validation
314
- :param metric: A metric to choose for validation.
315
- :param pipeline_component_name: A pipeline component name to use for validation (e.g. LMSequenceClassifierService or
349
+ log_dir: Path to log dir. Will default to `train_log/layoutlm`
350
+ build_train_config: dataflow build setting. Again, use list convention setting, e.g. `['max_datapoints=1000']`
351
+ dataset_val: Dataset to use for validation. Dataset type must be the same as type of `dataset_train`
352
+ build_val_config: same as `build_train_config` but for validation
353
+ metric: A metric to choose for validation.
354
+ pipeline_component_name: A pipeline component name to use for validation (e.g. `LMSequenceClassifierService` or
316
355
  LMTokenClassifierService.
317
- :param use_xlm_tokenizer: This is only necessary if you pass weights of layoutxlm. The config cannot distinguish
318
- between Layoutlmv2 and Layoutxlm, so you need to pass this info explicitly.
319
- :param use_token_tag: Will only be used for dataset_type="token_classification". If use_token_tag=True, will use
356
+ use_xlm_tokenizer: This is only necessary if you pass weights of LayoutXLM. The config cannot distinguish
357
+ between Layoutlmv2 and LayoutXLM, so you need to pass this info explicitly.
358
+ use_token_tag: Will only be used for `dataset_type="token_classification"`. If `use_token_tag=True`, will use
320
359
  labels from sub category `WordType.token_tag` (with `B,I,O` suffix), otherwise
321
360
  `WordType.token_class`.
322
- :param segment_positions: Using bounding boxes of segment instead of words improves model accuracy significantly.
361
+ segment_positions: Using bounding boxes of segment instead of words improves model accuracy significantly.
323
362
  Choose a single or a sequence of layout segments to use their bounding boxes. Note, that
324
363
  the layout segments need to have a child-relationship with words. If a word does not
325
364
  appear as child, it will use the word bounding box.
@@ -16,7 +16,7 @@
16
16
  # limitations under the License.
17
17
 
18
18
  """
19
- Module for training Tensorpack `GeneralizedRCNN`
19
+ Training Tensorpack's `GeneralizedRCNN`
20
20
  """
21
21
 
22
22
  import os
@@ -75,6 +75,9 @@ __all__ = ["train_faster_rcnn"]
75
75
  class LoadAugmentAddAnchors:
76
76
  """
77
77
  A helper class for default mapping `load_augment_add_anchors`.
78
+
79
+ Args:
80
+ config: An `AttrDict` configuration for TP FRCNN.
78
81
  """
79
82
 
80
83
  def __init__(self, config: AttrDict) -> None:
@@ -89,9 +92,15 @@ def load_augment_add_anchors(dp: JsonDict, config: AttrDict) -> Optional[JsonDic
89
92
  Transforming an image before entering the graph. This function bundles all the necessary steps to feed
90
93
  the network for training.
91
94
 
92
- :param dp: A dict with 'file_name', 'gt_boxes', 'gt_labels' and optional 'image'
93
- :param config: An `AttrDict` with a TP frcnn config
94
- :return: An dict with all necessary keys for feeding the graph
95
+ Args:
96
+ dp: A dict with `file_name`, `gt_boxes`, `gt_labels` and optional `image`.
97
+ config: An `AttrDict` with a TP frcnn config.
98
+
99
+ Returns:
100
+ A dict with all necessary keys for feeding the graph.
101
+
102
+ Note:
103
+ If `image` is not in `dp`, it will be loaded from `file_name`.
95
104
  """
96
105
  cfg = config
97
106
  if "image" not in dp:
@@ -124,14 +133,20 @@ def get_train_dataflow(
124
133
  dataset: DatasetBase, config: AttrDict, use_multi_proc_for_train: bool, **build_train_kwargs: str
125
134
  ) -> DataFlow:
126
135
  """
127
- Return a dataflow for training TP Frcnn. The returned dataflow depends on the dataset and the configuration of
136
+ Return a dataflow for training TP FRCNN. The returned dataflow depends on the dataset and the configuration of
128
137
  the model, as the augmentation is part of the data preparation.
129
138
 
130
- :param dataset: A dataset for object detection
131
- :param config: An `AttrDict` with a TP Frcnn config
132
- :param use_multi_proc_for_train: If set to `True` will use multi processes for augmenting
133
- :param build_train_kwargs: build configuration of the dataflow.
134
- :return: A dataflow
139
+ Args:
140
+ dataset: A dataset for object detection.
141
+ config: An `AttrDict` with a TP FRCNN config.
142
+ use_multi_proc_for_train: If set to `True` will use multi processes for augmenting.
143
+ build_train_kwargs: Build configuration of the dataflow.
144
+
145
+ Returns:
146
+ A dataflow.
147
+
148
+ Note:
149
+ If `use_multi_proc_for_train` is `True`, multi-processing will be used for augmentation.
135
150
  """
136
151
 
137
152
  set_mp_spawn()
@@ -202,23 +217,35 @@ def train_faster_rcnn(
202
217
  Train Faster-RCNN from Scratch or fine-tune a model using Tensorpack's training API. Observe the training with
203
218
  Tensorpack callbacks and evaluate the training progress with a validation data set after certain training intervals.
204
219
 
205
- Tensorpack provides a training API under TF1. Training runs under a TF2 installation if TF2 behavior is deactivated.
206
-
207
- :param path_config_yaml: path to TP config file. Check the
208
- [deepdoctection.extern.tp.tpfrcnn.config.config][] for various settings.
209
- :param dataset_train: the dataset to use for training.
210
- :param path_weights: path to a checkpoint, if you want to continue training or fine-tune. Will train from scratch if
211
- nothing is passed.
212
- :param config_overwrite: Pass a list of arguments if some configs from the .yaml file should be replaced. Use the
213
- list convention, e.g. ['TRAIN.STEPS_PER_EPOCH=500', 'OUTPUT.RESULT_SCORE_THRESH=0.4']
214
- :param log_dir: Path to log dir. Will default to TRAIN.LOG_DIR
215
- :param build_train_config: dataflow build setting. Again, use list convention setting, e.g. ['max_datapoints=1000']
216
- :param dataset_val: the dataset to use for validation.
217
- :param build_val_config: same as 'build_train_config' but for validation
218
- :param metric_name: A metric name to choose for validation. Will use the default setting. If you want a custom
219
- metric setting pass a metric explicitly.
220
- :param metric: A metric to choose for validation.
221
- :param pipeline_component_name: A pipeline component to use for validation.
220
+ Info:
221
+ Tensorpack provides a training API under TF1. Training runs under a TF2 installation if TF2 behavior is
222
+ deactivated.
223
+
224
+ Args:
225
+ path_config_yaml: Path to TP config file. Check the `deepdoctection.extern.tp.tpfrcnn.config.config` for various
226
+ settings.
227
+ dataset_train: The dataset to use for training.
228
+ path_weights: Path to a checkpoint, if you want to continue training or fine-tune. Will train from scratch if
229
+ nothing is passed.
230
+ config_overwrite: Pass a list of arguments if some configs from the .yaml file should be replaced. Use the list
231
+ convention, e.g. `[`TRAIN.STEPS_PER_EPOCH=500`, `OUTPUT.RESULT_SCORE_THRESH=0.4`]`.
232
+ log_dir: Path to log dir. Will default to `TRAIN.LOG_DIR`.
233
+ build_train_config: Dataflow build setting. Use list convention setting, e.g. `[`max_datapoints=1000`]`.
234
+ dataset_val: The dataset to use for validation.
235
+ build_val_config: Same as `build_train_config` but for validation.
236
+ metric_name: A metric name to choose for validation. Will use the default setting. If you want a custom metric
237
+ setting pass a metric explicitly.
238
+ metric: A metric to choose for validation.
239
+ pipeline_component_name: A pipeline component to use for validation.
240
+
241
+ Example:
242
+ ```python
243
+ train_faster_rcnn(
244
+ path_config_yaml="config.yaml",
245
+ dataset_train=my_train_dataset,
246
+ path_weights="weights.ckpt"
247
+ )
248
+ ```
222
249
  """
223
250
 
224
251
  assert disable_tfv2() # TP works only in Graph mode
@@ -241,9 +268,10 @@ def train_faster_rcnn(
241
268
  config_overwrite.append(log_dir)
242
269
 
243
270
  config = set_config_by_yaml(path_config_yaml)
244
-
271
+ config.freeze(False)
245
272
  if config_overwrite:
246
273
  config.update_args(config_overwrite)
274
+ config.freeze(True)
247
275
 
248
276
  categories = dataset_train.dataflow.categories.get_categories(filtered=True)
249
277
  model_frcnn_config(config, categories, False)
@@ -16,7 +16,7 @@
16
16
  # limitations under the License.
17
17
 
18
18
  """
19
- Some utility functions for multi threading purposes
19
+ Utility functions for multi-threading purposes
20
20
  """
21
21
 
22
22
  import multiprocessing as mp
@@ -35,12 +35,17 @@ from .types import QueueType
35
35
  # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
36
36
  class StoppableThread(threading.Thread):
37
37
  """
38
- A thread that has a 'stop' event.
38
+ A thread that has a `stop` event.
39
+
40
+ This class extends `threading.Thread` and provides a mechanism to stop the thread gracefully.
39
41
  """
40
42
 
41
43
  def __init__(self, evt: Optional[threading.Event] = None) -> None:
42
44
  """
43
- :param evt: if None, will create one.
45
+ Initializes a `StoppableThread`.
46
+
47
+ Args:
48
+ evt: An optional `threading.Event`. If `None`, a new event will be created.
44
49
  """
45
50
  super().__init__()
46
51
  if evt is None:
@@ -48,17 +53,30 @@ class StoppableThread(threading.Thread):
48
53
  self._stop_evt = evt
49
54
 
50
55
  def stop(self) -> None:
51
- """Stop the thread"""
56
+ """
57
+ Stop the thread.
58
+
59
+ Sets the internal stop event, signaling the thread to stop.
60
+ """
52
61
  self._stop_evt.set()
53
62
 
54
63
  def stopped(self) -> bool:
55
64
  """
56
- :param bool: whether the thread is stopped or not
65
+ Check whether the thread is stopped.
66
+
67
+ Returns:
68
+ Whether the thread is stopped or not.
57
69
  """
58
70
  return self._stop_evt.is_set()
59
71
 
60
72
  def queue_put_stoppable(self, q: QueueType, obj: Any) -> None:
61
- """Put obj to queue, but will give up when the thread is stopped"""
73
+ """
74
+ Put `obj` to queue `q`, but will give up when the thread is stopped.
75
+
76
+ Args:
77
+ q: The queue to put the object into.
78
+ obj: The object to put into the queue.
79
+ """
62
80
  while not self.stopped():
63
81
  try:
64
82
  q.put(obj, timeout=5)
@@ -67,7 +85,15 @@ class StoppableThread(threading.Thread):
67
85
  pass
68
86
 
69
87
  def queue_get_stoppable(self, q: QueueType) -> Any:
70
- """Take obj from queue, but will give up when the thread is stopped"""
88
+ """
89
+ Take an object from queue `q`, but will give up when the thread is stopped.
90
+
91
+ Args:
92
+ q: The queue to get the object from.
93
+
94
+ Returns:
95
+ The object taken from the queue.
96
+ """
71
97
  while not self.stopped():
72
98
  try:
73
99
  return q.get(timeout=5)
@@ -77,9 +103,14 @@ class StoppableThread(threading.Thread):
77
103
 
78
104
  @contextmanager
79
105
  def mask_sigint() -> Generator[Any, None, None]:
80
- """[Any,None,None
81
- :return: If called in main thread, returns a context where ``SIGINT`` is ignored, and yield True.
82
- Otherwise, yield False.
106
+ """
107
+ Context manager to mask `SIGINT`.
108
+
109
+ If called in the main thread, returns a context where `SIGINT` is ignored, and yields `True`. Otherwise, yields
110
+ `False`.
111
+
112
+ Yields:
113
+ `True` if called in the main thread, otherwise `False`.
83
114
  """
84
115
  if threading.current_thread() == threading.main_thread():
85
116
  sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
@@ -91,9 +122,15 @@ def mask_sigint() -> Generator[Any, None, None]:
91
122
 
92
123
  def enable_death_signal(_warn: bool = True) -> None:
93
124
  """
94
- Set the "death signal" of the current process, so that
95
- the current process will be cleaned with guarantee
96
- in case the parent dies accidentally.
125
+ Set the "death signal" of the current process.
126
+
127
+ Ensures that the current process will be cleaned up if the parent process dies accidentally.
128
+
129
+ Args:
130
+ _warn: If `True`, logs a warning if `prctl` is not available.
131
+
132
+ Note:
133
+ Only works on Linux systems. Requires the `python-prctl` package.
97
134
  """
98
135
  if platform.system() != "Linux":
99
136
  return
@@ -118,11 +155,17 @@ def enable_death_signal(_warn: bool = True) -> None:
118
155
  @no_type_check
119
156
  def start_proc_mask_signal(proc):
120
157
  """
121
- Start process(es) with SIGINT ignored.
158
+ Start process(es) with `SIGINT` ignored.
159
+
160
+ The signal mask is only applied when called from the main thread.
122
161
 
123
- :param proc: (mp.Process or list)
162
+ Note:
163
+ Starting a process with the 'fork' method is efficient but not safe and may cause deadlocks or crashes.
164
+ Use 'forkserver' or 'spawn' method instead if you run into such issues.
165
+ See <https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods> on how to set them.
124
166
 
125
- The signal mask is only applied when called from main thread.
167
+ Args:
168
+ proc: A `mp.Process` or a list of `mp.Process` instances.
126
169
  """
127
170
  if not isinstance(proc, list):
128
171
  proc = [proc]