deepdoctection 0.32-py3-none-any.whl → 0.34-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of deepdoctection has been flagged as possibly problematic by the registry.

Files changed (111)
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
deepdoctection/train/hf_detr_train.py

@@ -22,7 +22,8 @@ models that are a slightly different from the plain Detr model that are provided
 from __future__ import annotations

 import copy
-from typing import Any, Dict, List, Optional, Sequence, Type, Union
+import os
+from typing import Any, Optional, Sequence, Type, Union

 from lazy_imports import try_import

@@ -34,9 +35,10 @@ from ..eval.eval import Evaluator
 from ..eval.registry import metric_registry
 from ..extern.hfdetr import HFDetrDerivedDetector
 from ..mapper.hfstruct import DetrDataCollator, image_to_hf_detr_training
-from ..pipe.base import PredictorPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.logger import LoggingRecord, logger
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict

 with try_import() as pt_import_guard:
@@ -74,13 +76,13 @@ class DetrDerivedTrainer(Trainer):
         train_dataset: Dataset[Any],
     ):
         self.evaluator: Optional[Evaluator] = None
-        self.build_eval_kwargs: Optional[Dict[str, Any]] = None
+        self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset)

     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: PredictorPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[MetricBase], MetricBase],
         **build_eval_kwargs: Union[str, int],
     ) -> None:
@@ -97,17 +99,15 @@ class DetrDerivedTrainer(Trainer):
         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1)
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-            assert isinstance(comp, PredictorPipelineComponent)
-            assert isinstance(comp.predictor, HFDetrDerivedDetector)
-            comp.predictor.hf_detr_predictor = None
+            comp.clear_predictor()
         self.build_eval_kwargs = build_eval_kwargs

     def evaluate(
         self,
         eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
         metric_key_prefix: str = "eval",  # pylint: disable=W0613
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
@@ -129,12 +129,12 @@ class DetrDerivedTrainer(Trainer):


def train_hf_detr(
-    path_config_json: str,
+    path_config_json: PathLikeOrStr,
    dataset_train: Union[str, DatasetBase],
-    path_weights: str,
+    path_weights: PathLikeOrStr,
    path_feature_extractor_config_json: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/detr",
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/detr",
    build_train_config: Optional[Sequence[str]] = None,
    dataset_val: Optional[DatasetBase] = None,
    build_val_config: Optional[Sequence[str]] = None,
@@ -169,13 +169,13 @@ def train_hf_detr(
    :param pipeline_component_name: A pipeline component name to use for validation
    """

-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
    if build_train_config is not None:
        build_train_dict = string_to_dict(",".join(build_train_config))
    if "split" not in build_train_dict:
        build_train_dict["split"] = "train"

-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
    if build_val_config is not None:
        build_val_dict = string_to_dict(",".join(build_val_config))
    if "split" not in build_val_dict:
@@ -191,12 +191,17 @@ def train_hf_detr(
    categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(name_as_key=True, filtered=True)

    dataset = DatasetAdapter(
-        dataset_train, True, image_to_hf_detr_training(category_names=categories), True, **build_train_dict
+        dataset_train,
+        True,
+        image_to_hf_detr_training(category_names=categories),
+        True,
+        number_repetitions=-1,
+        **build_train_dict,
    )

    number_samples = len(dataset)
    conf_dict = {
-        "output_dir": log_dir,
+        "output_dir": os.fspath(log_dir),
        "remove_unused_columns": False,
        "per_device_train_batch_size": 2,
        "max_steps": number_samples,
@@ -256,7 +261,6 @@ def train_hf_detr(
        )
        pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
        pipeline_component = pipeline_component_cls(detector)
-        assert isinstance(pipeline_component, PredictorPipelineComponent)

        if metric_name is not None:
            metric = metric_registry.get(metric_name)
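With `PathLikeOrStr` in the signature and `os.fspath(log_dir)` in the config, `train_hf_detr` now accepts `os.PathLike` objects as well as strings. A call sketch under that assumption; the config, weights and dataset names are placeholders, not files shipped with the wheel:

```python
# Sketch only: pathlib.Path objects can be passed directly in 0.34; log_dir is
# normalized with os.fspath internally. All file and dataset names below are
# placeholders.
from pathlib import Path

from deepdoctection.train.hf_detr_train import train_hf_detr

train_hf_detr(
    path_config_json=Path("configs/detr/config.json"),  # placeholder
    dataset_train="fintabnet",  # placeholder dataset name (str or DatasetBase)
    path_weights=Path("weights/detr/pytorch_model.bin"),  # placeholder
    path_feature_extractor_config_json="configs/detr/preprocessor_config.json",  # placeholder
    log_dir=Path("train_log/detr"),
)
```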
deepdoctection/train/hf_layoutlm_train.py

@@ -24,7 +24,7 @@ import copy
 import json
 import os
 import pprint
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Optional, Sequence, Type, Union

 from lazy_imports import try_import

@@ -47,12 +47,13 @@ from ..extern.hflayoutlm import (
 from ..extern.hflm import HFLmSequenceClassifier
 from ..extern.pt.ptutils import get_torch_device
 from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features, image_to_raw_lm_features
-from ..pipe.base import LanguageModelPipelineComponent
+from ..pipe.base import PipelineComponent
 from ..pipe.registry import pipeline_component_registry
 from ..utils.error import DependencyError
 from ..utils.file_utils import wandb_available
 from ..utils.logger import LoggingRecord, logger
 from ..utils.settings import DatasetType, LayoutType, WordType
+from ..utils.types import PathLikeOrStr
 from ..utils.utils import string_to_dict

 with try_import() as pt_import_guard:
@@ -82,7 +83,7 @@ with try_import() as wb_import_guard:
     import wandb


-def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> Tuple[Any, Any, Any]:
+def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> tuple[Any, Any, Any]:
    """
    Get the model architecture, model wrapper and config class for a given model type and dataset type.

@@ -91,47 +92,47 @@ def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetTy
    :return: Tuple of model architecture, model wrapper and config class
    """
    return {
-        ("layoutlm", DatasetType.sequence_classification): (
+        ("layoutlm", DatasetType.SEQUENCE_CLASSIFICATION): (
            LayoutLMForSequenceClassification,
            HFLayoutLmSequenceClassifier,
            PretrainedConfig,
        ),
-        ("layoutlm", DatasetType.token_classification): (
+        ("layoutlm", DatasetType.TOKEN_CLASSIFICATION): (
            LayoutLMForTokenClassification,
            HFLayoutLmTokenClassifier,
            PretrainedConfig,
        ),
-        ("layoutlmv2", DatasetType.sequence_classification): (
+        ("layoutlmv2", DatasetType.SEQUENCE_CLASSIFICATION): (
            LayoutLMv2ForSequenceClassification,
            HFLayoutLmv2SequenceClassifier,
            LayoutLMv2Config,
        ),
-        ("layoutlmv2", DatasetType.token_classification): (
+        ("layoutlmv2", DatasetType.TOKEN_CLASSIFICATION): (
            LayoutLMv2ForTokenClassification,
            HFLayoutLmv2TokenClassifier,
            LayoutLMv2Config,
        ),
-        ("layoutlmv3", DatasetType.sequence_classification): (
+        ("layoutlmv3", DatasetType.SEQUENCE_CLASSIFICATION): (
            LayoutLMv3ForSequenceClassification,
            HFLayoutLmv3SequenceClassifier,
            LayoutLMv3Config,
        ),
-        ("layoutlmv3", DatasetType.token_classification): (
+        ("layoutlmv3", DatasetType.TOKEN_CLASSIFICATION): (
            LayoutLMv3ForTokenClassification,
            HFLayoutLmv3TokenClassifier,
            LayoutLMv3Config,
        ),
-        ("lilt", DatasetType.token_classification): (
+        ("lilt", DatasetType.TOKEN_CLASSIFICATION): (
            LiltForTokenClassification,
            HFLiltTokenClassifier,
            PretrainedConfig,
        ),
-        ("lilt", DatasetType.sequence_classification): (
+        ("lilt", DatasetType.SEQUENCE_CLASSIFICATION): (
            LiltForSequenceClassification,
            HFLiltSequenceClassifier,
            PretrainedConfig,
        ),
-        ("xlm-roberta", DatasetType.sequence_classification): (
+        ("xlm-roberta", DatasetType.SEQUENCE_CLASSIFICATION): (
            XLMRobertaForSequenceClassification,
            HFLmSequenceClassifier,
            PretrainedConfig,
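The mapping keys pick up the upper-cased enum members that come with the rewritten `deepdoctection/utils/settings.py` (entry 99 in the file list). A minimal sketch of the corresponding lookup in 0.34:

```python
# Sketch only: DatasetType members are upper case in 0.34, so lookups into the
# mapping above change accordingly.
from deepdoctection.train.hf_layoutlm_train import get_model_architectures_and_configs
from deepdoctection.utils.settings import DatasetType

# 0.32: DatasetType.sequence_classification
# 0.34:
model_cls, model_wrapper_cls, config_cls = get_model_architectures_and_configs(
    model_type="layoutlm", dataset_type=DatasetType.SEQUENCE_CLASSIFICATION
)
```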
@@ -163,13 +164,13 @@ class LayoutLMTrainer(Trainer):
         train_dataset: Dataset[Any],
     ):
         self.evaluator: Optional[Evaluator] = None
-        self.build_eval_kwargs: Optional[Dict[str, Any]] = None
+        self.build_eval_kwargs: Optional[dict[str, Any]] = None
         super().__init__(model, args, data_collator, train_dataset)

     def setup_evaluator(
         self,
         dataset_val: DatasetBase,
-        pipeline_component: LanguageModelPipelineComponent,
+        pipeline_component: PipelineComponent,
         metric: Union[Type[ClassificationMetric], ClassificationMetric],
         run: Optional[wandb.sdk.wandb_run.Run] = None,
         **build_eval_kwargs: Union[str, int],
@@ -188,15 +189,15 @@ class LayoutLMTrainer(Trainer):
         self.evaluator = Evaluator(dataset_val, pipeline_component, metric, num_threads=1, run=run)
         assert self.evaluator.pipe_component
         for comp in self.evaluator.pipe_component.pipe_components:
-            comp.language_model.model = None  # type: ignore
+            comp.clear_predictor()
         self.build_eval_kwargs = build_eval_kwargs

     def evaluate(
         self,
         eval_dataset: Optional[Dataset[Any]] = None,  # pylint: disable=W0613
-        ignore_keys: Optional[List[str]] = None,  # pylint: disable=W0613
+        ignore_keys: Optional[list[str]] = None,  # pylint: disable=W0613
         metric_key_prefix: str = "eval",  # pylint: disable=W0613
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         """
         Overwritten method from `Trainer`. Arguments will not be used.
         """
@@ -220,8 +221,8 @@ class LayoutLMTrainer(Trainer):


def _get_model_class_and_tokenizer(
-    path_config_json: str, dataset_type: DatasetType, use_xlm_tokenizer: bool
-) -> Tuple[Any, Any, Any, Any, Any]:
+    path_config_json: PathLikeOrStr, dataset_type: DatasetType, use_xlm_tokenizer: bool
+) -> tuple[Any, Any, Any, Any, Any]:
    with open(path_config_json, "r", encoding="UTF-8") as file:
        config_json = json.load(file)

@@ -244,11 +245,11 @@ def get_image_to_raw_features_mapping(input_str: str) -> Any:


def train_hf_layoutlm(
-    path_config_json: str,
+    path_config_json: PathLikeOrStr,
    dataset_train: Union[str, DatasetBase],
-    path_weights: str,
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/layoutlm",
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/layoutlm",
    build_train_config: Optional[Sequence[str]] = None,
    dataset_val: Optional[DatasetBase] = None,
    build_val_config: Optional[Sequence[str]] = None,
@@ -323,13 +324,13 @@ def train_hf_layoutlm(
        appear as child, it will use the word bounding box.
    """

-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
    if build_train_config is not None:
        build_train_dict = string_to_dict(",".join(build_train_config))
    if "split" not in build_train_dict:
        build_train_dict["split"] = "train"

-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
    if build_val_config is not None:
        build_val_dict = string_to_dict(",".join(build_val_config))
    if "split" not in build_val_dict:
@@ -343,25 +344,25 @@ def train_hf_layoutlm(

    # We wrap our dataset into a torch dataset
    dataset_type = dataset_train.dataset_info.type
-    if dataset_type == DatasetType.sequence_classification:
+    if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
        categories_dict_name_as_key = dataset_train.dataflow.categories.get_categories(as_dict=True, name_as_key=True)
-    elif dataset_type == DatasetType.token_classification:
+    elif dataset_type == DatasetType.TOKEN_CLASSIFICATION:
        if use_token_tag:
            categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                categories=LayoutType.word,
-                sub_categories={LayoutType.word: [WordType.token_tag]},
+                categories=LayoutType.WORD,
+                sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]},
                keys=False,
                values_as_dict=True,
                name_as_key=True,
-            )[LayoutType.word][WordType.token_tag]
+            )[LayoutType.WORD][WordType.TOKEN_TAG]
        else:
            categories_dict_name_as_key = dataset_train.dataflow.categories.get_sub_categories(
-                categories=LayoutType.word,
-                sub_categories={LayoutType.word: [WordType.token_class]},
+                categories=LayoutType.WORD,
+                sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]},
                keys=False,
                values_as_dict=True,
                name_as_key=True,
-            )[LayoutType.word][WordType.token_class]
+            )[LayoutType.WORD][WordType.TOKEN_CLASS]
    else:
        raise UserWarning("Dataset type not supported for training")

@@ -372,13 +373,14 @@ def train_hf_layoutlm(
    image_to_raw_features_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
    if segment_positions:
        image_to_raw_features_kwargs["segment_positions"] = segment_positions  # type: ignore
-    image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
+    image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_image_to_features_mapping())

    dataset = DatasetAdapter(
        dataset_train,
        True,
        image_to_raw_features_func(**image_to_raw_features_kwargs),
        use_token_tag,
+        number_repetitions=-1,
        **build_train_dict,
    )

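Besides the new `number_repetitions=-1` argument to `DatasetAdapter` (by its name, an unlimited repetition of the underlying dataflow), the hunk renames the hook queried on the model wrapper class. A sketch of the rename, assuming the wrapper classes listed above all expose the new classmethod, as the hunk implies for `model_wrapper_cls`:

```python
# Sketch only: the classmethod providing defaults for the feature mapping was
# renamed in 0.34. HFLayoutLmTokenClassifier is one of the wrapper classes from
# the mapping above; we assume it exposes the renamed hook.
from deepdoctection.extern.hflayoutlm import HFLayoutLmTokenClassifier

# 0.32: HFLayoutLmTokenClassifier.default_kwargs_for_input_mapping()
defaults = HFLayoutLmTokenClassifier.default_kwargs_for_image_to_features_mapping()
print(defaults)
```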
@@ -388,7 +390,7 @@ def train_hf_layoutlm(
    # Need to set remove_unused_columns to False, as the DataCollator for column removal will remove some raw features
    # that are necessary for the tokenizer.
    conf_dict = {
-        "output_dir": log_dir,
+        "output_dir": os.fspath(log_dir),
        "remove_unused_columns": False,
        "per_device_train_batch_size": 8,
        "max_steps": number_samples,
@@ -429,16 +431,16 @@ def train_hf_layoutlm(
    )

    use_wandb = conf_dict.pop("use_wandb")
-    wandb_project = conf_dict.pop("wandb_project")
-    wandb_repo = conf_dict.pop("wandb_repo")
+    wandb_project = str(conf_dict.pop("wandb_project"))
+    wandb_repo = str(conf_dict.pop("wandb_repo"))

    # Initialize Wandb, if necessary
    run = None
    if use_wandb:
        if not wandb_available():
            raise DependencyError("WandB must be installed separately")
-        run = wandb.init(project=wandb_project, config=conf_dict)  # type: ignore
-        run._label(repo=wandb_repo)  # type: ignore # pylint: disable=W0212
+        run = wandb.init(project=wandb_project, config=conf_dict)
+        run._label(repo=wandb_repo)  # pylint: disable=W0212
    else:
        os.environ["WANDB_DISABLED"] = "True"

@@ -474,19 +476,19 @@ def train_hf_layoutlm(

    if arguments.evaluation_strategy in (IntervalStrategy.STEPS,):
        assert metric is not None  # silence mypy
-        if dataset_type == DatasetType.sequence_classification:
+        if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
            categories = dataset_val.dataflow.categories.get_categories(filtered=True)  # type: ignore
        else:
            if use_token_tag:
                categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                    categories=LayoutType.word, sub_categories={LayoutType.word: [WordType.token_tag]}, keys=False
-                )[LayoutType.word][WordType.token_tag]
-                metric.set_categories(category_names=LayoutType.word, sub_category_names={"word": ["token_tag"]})
+                    categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_TAG]}, keys=False
+                )[LayoutType.WORD][WordType.TOKEN_TAG]
+                metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_tag"]})
            else:
                categories = dataset_val.dataflow.categories.get_sub_categories(  # type: ignore
-                    categories=LayoutType.word, sub_categories={LayoutType.word: [WordType.token_class]}, keys=False
-                )[LayoutType.word][WordType.token_class]
-                metric.set_categories(category_names=LayoutType.word, sub_category_names={"word": ["token_class"]})
+                    categories=LayoutType.WORD, sub_categories={LayoutType.WORD: [WordType.TOKEN_CLASS]}, keys=False
+                )[LayoutType.WORD][WordType.TOKEN_CLASS]
+                metric.set_categories(category_names=LayoutType.WORD, sub_category_names={"word": ["token_class"]})
        dd_model = model_wrapper_cls(
            path_config_json=path_config_json,
            path_weights=path_weights,
@@ -495,7 +497,7 @@ def train_hf_layoutlm(
            use_xlm_tokenizer=use_xlm_tokenizer,
        )
        pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
-        if dataset_type == DatasetType.sequence_classification:
+        if dataset_type == DatasetType.SEQUENCE_CLASSIFICATION:
            pipeline_component = pipeline_component_cls(tokenizer_fast, dd_model)
        else:
            pipeline_component = pipeline_component_cls(
@@ -504,7 +506,6 @@ def train_hf_layoutlm(
                use_other_as_default_category=True,
                sliding_window_stride=sliding_window_stride,
            )
-        assert isinstance(pipeline_component, LanguageModelPipelineComponent)

        trainer.setup_evaluator(dataset_val, pipeline_component, metric, run, **build_val_dict)  # type: ignore

deepdoctection/train/tp_frcnn_train.py

@@ -20,7 +20,7 @@ Module for training Tensorpack `GeneralizedRCNN`
 """

 import os
-from typing import Dict, List, Optional, Sequence, Type, Union
+from typing import Optional, Sequence, Type, Union

 from lazy_imports import try_import

@@ -40,14 +40,13 @@ from ..extern.tp.tpfrcnn.preproc import anchors_and_labels, augment
 from ..extern.tpdetect import TPFrcnnDetector
 from ..mapper.maputils import LabelSummarizer
 from ..mapper.tpstruct import image_to_tp_frcnn_training
-from ..pipe.base import PredictorPipelineComponent
 from ..pipe.registry import pipeline_component_registry
-from ..utils.detection_types import JsonDict
 from ..utils.file_utils import set_mp_spawn
 from ..utils.fs import get_load_image_func
 from ..utils.logger import log_once
 from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.tqdm import get_tqdm
+from ..utils.types import JsonDict, PathLikeOrStr
 from ..utils.utils import string_to_dict

 with try_import() as tp_import_guard:
@@ -185,11 +184,11 @@ def get_train_dataflow(


def train_faster_rcnn(
-    path_config_yaml: str,
+    path_config_yaml: PathLikeOrStr,
    dataset_train: DatasetBase,
-    path_weights: str = "",
-    config_overwrite: Optional[List[str]] = None,
-    log_dir: str = "train_log/frcnn",
+    path_weights: PathLikeOrStr,
+    config_overwrite: Optional[list[str]] = None,
+    log_dir: PathLikeOrStr = "train_log/frcnn",
    build_train_config: Optional[Sequence[str]] = None,
    dataset_val: Optional[DatasetBase] = None,
    build_val_config: Optional[Sequence[str]] = None,
@@ -224,13 +223,13 @@ def train_faster_rcnn(

    assert disable_tfv2()  # TP works only in Graph mode

-    build_train_dict: Dict[str, str] = {}
+    build_train_dict: dict[str, str] = {}
    if build_train_config is not None:
        build_train_dict = string_to_dict(",".join(build_train_config))
    if "split" not in build_train_dict:
        build_train_dict["split"] = "train"

-    build_val_dict: Dict[str, str] = {}
+    build_val_dict: dict[str, str] = {}
    if build_val_config is not None:
        build_val_dict = string_to_dict(",".join(build_val_config))
    if "split" not in build_val_dict:
@@ -238,7 +237,7 @@ def train_faster_rcnn(

    config_overwrite = [] if config_overwrite is None else config_overwrite

-    log_dir = "TRAIN.LOG_DIR=" + log_dir
+    log_dir = "TRAIN.LOG_DIR=" + os.fspath(log_dir)
    config_overwrite.append(log_dir)

    config = set_config_by_yaml(path_config_yaml)
@@ -299,7 +298,6 @@ def train_faster_rcnn(
        )  # only a wrapper for the predictor itself. Will be replaced in Callback
        pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
        pipeline_component = pipeline_component_cls(detector)
-        assert isinstance(pipeline_component, PredictorPipelineComponent)
        category_names = list(categories.values())
        callbacks.extend(
            [
@@ -310,6 +308,7 @@ def train_faster_rcnn(
                    metric,  # type: ignore
                    pipeline_component,
                    *model.get_inference_tensor_names(),  # type: ignore
+                    cfg=detector.model.cfg,
                    **build_val_dict
                )
            ]
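`train_faster_rcnn` also moves to `PathLikeOrStr`, and `path_weights` loses its empty-string default, so it must now be passed explicitly. A call sketch under these assumptions; file names are placeholders and `get_dataset` is the dataset-registry helper:

```python
# Sketch only: train_faster_rcnn in 0.34 accepts os.PathLike paths and requires
# path_weights explicitly. File and dataset names are placeholders.
from pathlib import Path

import deepdoctection as dd
from deepdoctection.train.tp_frcnn_train import train_faster_rcnn

train_faster_rcnn(
    path_config_yaml=Path("configs/tp/conf_frcnn_layout.yaml"),  # placeholder
    dataset_train=dd.get_dataset("publaynet"),  # registry helper, placeholder dataset
    path_weights=Path("weights/layout/model-1620000"),  # placeholder, now required
    log_dir=Path("train_log/frcnn"),  # passed through os.fspath internally
)
```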
deepdoctection/utils/concurrency.py

@@ -28,8 +28,8 @@ import threading
 from contextlib import contextmanager
 from typing import Any, Generator, Optional, no_type_check

-from .detection_types import QueueType
 from .logger import log_once
+from .types import QueueType


 # taken from https://github.com/tensorpack/dataflow/blob/master/dataflow/utils/concurrency.py
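This one-line move is the pattern behind many hunks in the release: `deepdoctection/utils/detection_types.py` is gone (entry 108 in the file list) and `deepdoctection/utils/types.py` takes its place (entry 102). A sketch of how imports migrate; only aliases that actually appear in this diff are listed:

```python
# Sketch only: type aliases now live in deepdoctection.utils.types. Every name
# imported here appears in the hunks of this diff.

# 0.32:
# from deepdoctection.utils.detection_types import ImageType, JsonDict, QueueType

# 0.34:
from deepdoctection.utils.types import (
    B64,            # takes over the former `bytes` slot in save_tmp_file
    B64Str,         # takes over the former `str` slot in save_tmp_file
    JsonDict,
    PathLikeOrStr,  # str or os.PathLike, used for path arguments
    PixelValues,    # replaces the former ImageType alias
    QueueType,
)
```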
deepdoctection/utils/context.py

@@ -26,12 +26,12 @@ from glob import iglob
 from os import path, remove
 from tempfile import NamedTemporaryFile
 from time import perf_counter as timer
-from typing import Any, Generator, Iterator, Optional, Tuple, Union
+from typing import Any, Generator, Iterator, Optional, Union

 import numpy as np

-from .detection_types import ImageType
 from .logger import LoggingRecord, logger
+from .types import B64, B64Str, PixelValues
 from .viz import viz_handler

 __all__ = ["timeout_manager", "save_tmp_file", "timed_operation"]
@@ -72,7 +72,7 @@ def timeout_manager(proc, seconds: Optional[int] = None) -> Iterator[str]: # ty


 @contextmanager
-def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[Tuple[str, str]]:
+def save_tmp_file(image: Union[B64Str, PixelValues, B64], prefix: str) -> Iterator[tuple[str, str]]:
    """
    Save image temporarily and handle the clean-up once not necessary anymore

@@ -112,13 +112,20 @@ def save_tmp_file(image: Union[str, ImageType, bytes], prefix: str) -> Iterator[
 @contextmanager
 def timed_operation(message: str, log_start: bool = False) -> Generator[Any, None, None]:
    """
-    Contextmanager with a timer. Can therefore be used in a with statement.
+    Contextmanager with a timer.

-    :param message: a log to print
+    ... code-block:: python
+
+        with timed_operation(message="Your stdout message", log_start=True):
+
+            with open("log.txt", "a") as file:
+                ...
+
+
+    :param message: a log to stdout
    :param log_start: whether to print also the beginning
    """

-    assert len(message)
    if log_start:
        logger.info(LoggingRecord(f"start task: {message} ..."))
    start = timer()
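The rewritten docstring doubles as usage documentation, and the removed `assert len(message)` means an empty message no longer raises. A standalone version of the documented pattern, with an arbitrary message:

```python
# Sketch only: timing a block of work with the context manager from
# deepdoctection.utils.context, mirroring the example added to the docstring.
from deepdoctection.utils.context import timed_operation

with timed_operation(message="appending to the log file", log_start=True):
    with open("log.txt", "a", encoding="UTF-8") as file:
        file.write("done\n")
```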
deepdoctection/utils/develop.py

@@ -26,19 +26,19 @@ import functools
 import inspect
 from collections import defaultdict
 from datetime import datetime
-from typing import Callable, List, Optional
+from typing import Callable, Optional

-from .detection_types import T
 from .logger import LoggingRecord, logger
+from .types import T

-__all__: List[str] = ["deprecated"]
+__all__: list[str] = ["deprecated"]

 # Copy and paste from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py

 _DEPRECATED_LOG_NUM = defaultdict(int)  # type: ignore


-def log_deprecated(name: str = "", text: str = "", eos: str = "", max_num_warnings: Optional[int] = None) -> None:
+def log_deprecated(name: str, text: str, eos: str = "", max_num_warnings: Optional[int] = None) -> None:
    """
    Log deprecation warning.
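`name` and `text` lose their empty-string defaults, so every caller must now pass both. A small sketch; the deprecated symbol and the replacement hint are made up for illustration:

```python
# Sketch only: log_deprecated in 0.34 requires name and text. The symbol name
# and message below are hypothetical.
from deepdoctection.utils.develop import log_deprecated

log_deprecated(
    name="my_old_helper",
    text="use my_new_helper instead",
    eos="2025-01-01",       # optional end-of-support note
    max_num_warnings=1,     # optional cap on repeated warnings
)
```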