deepdoctection 0.31__py3-none-any.whl → 0.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (91)
  1. deepdoctection/__init__.py +35 -28
  2. deepdoctection/analyzer/dd.py +30 -24
  3. deepdoctection/configs/conf_dd_one.yaml +34 -31
  4. deepdoctection/datapoint/annotation.py +2 -1
  5. deepdoctection/datapoint/box.py +2 -1
  6. deepdoctection/datapoint/image.py +13 -7
  7. deepdoctection/datapoint/view.py +95 -24
  8. deepdoctection/datasets/__init__.py +1 -4
  9. deepdoctection/datasets/adapter.py +5 -2
  10. deepdoctection/datasets/base.py +5 -3
  11. deepdoctection/datasets/info.py +2 -2
  12. deepdoctection/datasets/instances/doclaynet.py +3 -2
  13. deepdoctection/datasets/instances/fintabnet.py +2 -1
  14. deepdoctection/datasets/instances/funsd.py +2 -1
  15. deepdoctection/datasets/instances/iiitar13k.py +5 -2
  16. deepdoctection/datasets/instances/layouttest.py +2 -1
  17. deepdoctection/datasets/instances/publaynet.py +2 -2
  18. deepdoctection/datasets/instances/pubtables1m.py +6 -3
  19. deepdoctection/datasets/instances/pubtabnet.py +2 -1
  20. deepdoctection/datasets/instances/rvlcdip.py +2 -1
  21. deepdoctection/datasets/instances/xfund.py +2 -1
  22. deepdoctection/eval/__init__.py +1 -4
  23. deepdoctection/eval/cocometric.py +2 -1
  24. deepdoctection/eval/eval.py +17 -13
  25. deepdoctection/eval/tedsmetric.py +14 -11
  26. deepdoctection/eval/tp_eval_callback.py +9 -3
  27. deepdoctection/extern/__init__.py +2 -7
  28. deepdoctection/extern/d2detect.py +24 -32
  29. deepdoctection/extern/deskew.py +4 -2
  30. deepdoctection/extern/doctrocr.py +75 -81
  31. deepdoctection/extern/fastlang.py +4 -2
  32. deepdoctection/extern/hfdetr.py +22 -28
  33. deepdoctection/extern/hflayoutlm.py +335 -103
  34. deepdoctection/extern/hflm.py +225 -0
  35. deepdoctection/extern/model.py +56 -47
  36. deepdoctection/extern/pdftext.py +8 -4
  37. deepdoctection/extern/pt/__init__.py +1 -3
  38. deepdoctection/extern/pt/nms.py +6 -2
  39. deepdoctection/extern/pt/ptutils.py +27 -19
  40. deepdoctection/extern/texocr.py +4 -2
  41. deepdoctection/extern/tp/tfutils.py +43 -9
  42. deepdoctection/extern/tp/tpcompat.py +10 -7
  43. deepdoctection/extern/tp/tpfrcnn/__init__.py +20 -0
  44. deepdoctection/extern/tp/tpfrcnn/common.py +7 -3
  45. deepdoctection/extern/tp/tpfrcnn/config/__init__.py +20 -0
  46. deepdoctection/extern/tp/tpfrcnn/config/config.py +9 -6
  47. deepdoctection/extern/tp/tpfrcnn/modeling/__init__.py +20 -0
  48. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +17 -7
  49. deepdoctection/extern/tp/tpfrcnn/modeling/generalized_rcnn.py +12 -6
  50. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +9 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/model_cascade.py +8 -5
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +16 -11
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +17 -10
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +14 -8
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_rpn.py +15 -10
  56. deepdoctection/extern/tp/tpfrcnn/predict.py +9 -4
  57. deepdoctection/extern/tp/tpfrcnn/preproc.py +7 -3
  58. deepdoctection/extern/tp/tpfrcnn/utils/__init__.py +20 -0
  59. deepdoctection/extern/tp/tpfrcnn/utils/box_ops.py +10 -2
  60. deepdoctection/extern/tpdetect.py +5 -8
  61. deepdoctection/mapper/__init__.py +3 -8
  62. deepdoctection/mapper/d2struct.py +8 -6
  63. deepdoctection/mapper/hfstruct.py +6 -1
  64. deepdoctection/mapper/laylmstruct.py +163 -20
  65. deepdoctection/mapper/maputils.py +3 -1
  66. deepdoctection/mapper/misc.py +6 -3
  67. deepdoctection/mapper/tpstruct.py +2 -2
  68. deepdoctection/pipe/__init__.py +1 -1
  69. deepdoctection/pipe/common.py +11 -9
  70. deepdoctection/pipe/concurrency.py +2 -1
  71. deepdoctection/pipe/layout.py +3 -1
  72. deepdoctection/pipe/lm.py +32 -64
  73. deepdoctection/pipe/order.py +142 -35
  74. deepdoctection/pipe/refine.py +8 -14
  75. deepdoctection/pipe/{cell.py → sub_layout.py} +1 -1
  76. deepdoctection/train/__init__.py +6 -12
  77. deepdoctection/train/d2_frcnn_train.py +21 -16
  78. deepdoctection/train/hf_detr_train.py +18 -11
  79. deepdoctection/train/hf_layoutlm_train.py +118 -101
  80. deepdoctection/train/tp_frcnn_train.py +21 -19
  81. deepdoctection/utils/env_info.py +41 -117
  82. deepdoctection/utils/logger.py +1 -0
  83. deepdoctection/utils/mocks.py +93 -0
  84. deepdoctection/utils/settings.py +1 -0
  85. deepdoctection/utils/viz.py +4 -3
  86. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/METADATA +27 -18
  87. deepdoctection-0.32.dist-info/RECORD +146 -0
  88. deepdoctection-0.31.dist-info/RECORD +0 -144
  89. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/LICENSE +0 -0
  90. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/WHEEL +0 -0
  91. {deepdoctection-0.31.dist-info → deepdoctection-0.32.dist-info}/top_level.txt +0 -0
deepdoctection/train/hf_layoutlm_train.py

@@ -18,32 +18,15 @@
 """
 Module for training Huggingface implementation of LayoutLm
 """
+from __future__ import annotations
 
 import copy
 import json
 import os
 import pprint
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
-
-from torch.nn import Module
-from torch.utils.data import Dataset
-from transformers import (
-    IntervalStrategy,
-    LayoutLMForSequenceClassification,
-    LayoutLMForTokenClassification,
-    LayoutLMTokenizerFast,
-    LayoutLMv2Config,
-    LayoutLMv2ForSequenceClassification,
-    LayoutLMv2ForTokenClassification,
-    LayoutLMv3Config,
-    LayoutLMv3ForSequenceClassification,
-    LayoutLMv3ForTokenClassification,
-    PretrainedConfig,
-    PreTrainedModel,
-    RobertaTokenizerFast,
-    XLMRobertaTokenizerFast,
-)
-from transformers.trainer import Trainer, TrainingArguments
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+
+from lazy_imports import try_import
 
 from ..datasets.adapter import DatasetAdapter
 from ..datasets.base import DatasetBase
@@ -57,79 +40,108 @@ from ..extern.hflayoutlm import (
     HFLayoutLmv2TokenClassifier,
     HFLayoutLmv3SequenceClassifier,
     HFLayoutLmv3TokenClassifier,
+    HFLiltSequenceClassifier,
+    HFLiltTokenClassifier,
+    get_tokenizer_from_model_class,
 )
-from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features
+from ..extern.hflm import HFLmSequenceClassifier
+from ..extern.pt.ptutils import get_torch_device
+from ..mapper.laylmstruct import LayoutLMDataCollator, image_to_raw_layoutlm_features, image_to_raw_lm_features
 from ..pipe.base import LanguageModelPipelineComponent
-from ..pipe.lm import get_tokenizer_from_architecture
 from ..pipe.registry import pipeline_component_registry
-from ..utils.env_info import get_device
 from ..utils.error import DependencyError
 from ..utils.file_utils import wandb_available
 from ..utils.logger import LoggingRecord, logger
-from ..utils.settings import DatasetType, LayoutType, ObjectTypes, WordType
+from ..utils.settings import DatasetType, LayoutType, WordType
 from ..utils.utils import string_to_dict
 
-if wandb_available():
-    import wandb
-
-_ARCHITECTURES_TO_MODEL_CLASS = {
-    "LayoutLMForTokenClassification": (LayoutLMForTokenClassification, HFLayoutLmTokenClassifier, PretrainedConfig),
-    "LayoutLMForSequenceClassification": (
-        LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    "LayoutLMv2ForTokenClassification": (
-        LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    "LayoutLMv2ForSequenceClassification": (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
-        LayoutLMv2Config,
-    ),
-}
-
+with try_import() as pt_import_guard:
+    from torch import nn
+    from torch.utils.data import Dataset
 
-_MODEL_TYPE_AND_TASK_TO_MODEL_CLASS: Mapping[Tuple[str, ObjectTypes], Any] = {
-    ("layoutlm", DatasetType.sequence_classification): (
+with try_import() as tr_import_guard:
+    from transformers import (
+        IntervalStrategy,
         LayoutLMForSequenceClassification,
-        HFLayoutLmSequenceClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlm", DatasetType.token_classification): (
         LayoutLMForTokenClassification,
-        HFLayoutLmTokenClassifier,
-        PretrainedConfig,
-    ),
-    ("layoutlmv2", DatasetType.sequence_classification): (
-        LayoutLMv2ForSequenceClassification,
-        HFLayoutLmv2SequenceClassifier,
         LayoutLMv2Config,
-    ),
-    ("layoutlmv2", DatasetType.token_classification): (
+        LayoutLMv2ForSequenceClassification,
         LayoutLMv2ForTokenClassification,
-        HFLayoutLmv2TokenClassifier,
-        LayoutLMv2Config,
-    ),
-    ("layoutlmv3", DatasetType.sequence_classification): (
-        LayoutLMv3ForSequenceClassification,
-        HFLayoutLmv3SequenceClassifier,
         LayoutLMv3Config,
-    ),
-    ("layoutlmv3", DatasetType.token_classification): (
+        LayoutLMv3ForSequenceClassification,
         LayoutLMv3ForTokenClassification,
-        HFLayoutLmv3TokenClassifier,
-        LayoutLMv3Config,
-    ),
-}
-_MODEL_TYPE_TO_TOKENIZER = {
-    ("layoutlm", False): LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased"),
-    ("layoutlmv2", False): LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased"),
-    ("layoutlmv2", True): XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base", add_prefix_space=True),
-    ("layoutlmv3", False): RobertaTokenizerFast.from_pretrained("roberta-base", add_prefix_space=True),
-}
+        LiltForSequenceClassification,
+        LiltForTokenClassification,
+        PretrainedConfig,
+        PreTrainedModel,
+        XLMRobertaForSequenceClassification,
+    )
+    from transformers.trainer import Trainer, TrainingArguments
+
+with try_import() as wb_import_guard:
+    import wandb
+
+
+def get_model_architectures_and_configs(model_type: str, dataset_type: DatasetType) -> Tuple[Any, Any, Any]:
+    """
+    Get the model architecture, model wrapper and config class for a given model type and dataset type.
+
+    :param model_type: The model type
+    :param dataset_type: The dataset type
+    :return: Tuple of model architecture, model wrapper and config class
+    """
+    return {
+        ("layoutlm", DatasetType.sequence_classification): (
+            LayoutLMForSequenceClassification,
+            HFLayoutLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlm", DatasetType.token_classification): (
+            LayoutLMForTokenClassification,
+            HFLayoutLmTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("layoutlmv2", DatasetType.sequence_classification): (
+            LayoutLMv2ForSequenceClassification,
+            HFLayoutLmv2SequenceClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv2", DatasetType.token_classification): (
+            LayoutLMv2ForTokenClassification,
+            HFLayoutLmv2TokenClassifier,
+            LayoutLMv2Config,
+        ),
+        ("layoutlmv3", DatasetType.sequence_classification): (
+            LayoutLMv3ForSequenceClassification,
+            HFLayoutLmv3SequenceClassifier,
+            LayoutLMv3Config,
+        ),
+        ("layoutlmv3", DatasetType.token_classification): (
+            LayoutLMv3ForTokenClassification,
+            HFLayoutLmv3TokenClassifier,
+            LayoutLMv3Config,
+        ),
+        ("lilt", DatasetType.token_classification): (
+            LiltForTokenClassification,
+            HFLiltTokenClassifier,
+            PretrainedConfig,
+        ),
+        ("lilt", DatasetType.sequence_classification): (
+            LiltForSequenceClassification,
+            HFLiltSequenceClassifier,
+            PretrainedConfig,
+        ),
+        ("xlm-roberta", DatasetType.sequence_classification): (
+            XLMRobertaForSequenceClassification,
+            HFLmSequenceClassifier,
+            PretrainedConfig,
+        ),
+    }[(model_type, dataset_type)]
+
+
+def maybe_remove_bounding_box_features(model_type: str) -> bool:
+    """Listing of models that do not need bounding box features."""
+    return {"xlm-roberta": True}.get(model_type, False)
 
 
 class LayoutLMTrainer(Trainer):
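
Note: the hunk above shows the pattern that recurs throughout this release: hard top-level imports of torch, transformers and wandb move into `try_import()` blocks from the `lazy-imports` package, so importing deepdoctection no longer fails when an optional backend is missing. A minimal sketch of how such a guard behaves (the `optional_pkg` module is a hypothetical stand-in):

    from lazy_imports import try_import

    # an ImportError raised inside the block is swallowed and stored on the guard
    with try_import() as import_guard:
        import optional_pkg

    def feature_needing_optional_pkg():
        import_guard.check()  # only now raises an informative ImportError if the import failed
        return optional_pkg   # safe to use once check() has passed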
@@ -145,7 +157,7 @@ class LayoutLMTrainer(Trainer):
 
     def __init__(
         self,
-        model: Union[PreTrainedModel, Module],
+        model: Union[PreTrainedModel, nn.Module],
        args: TrainingArguments,
         data_collator: LayoutLMDataCollator,
         train_dataset: Dataset[Any],
@@ -159,7 +171,7 @@ class LayoutLMTrainer(Trainer):
         dataset_val: DatasetBase,
         pipeline_component: LanguageModelPipelineComponent,
         metric: Union[Type[ClassificationMetric], ClassificationMetric],
-        run: Optional["wandb.sdk.wandb_run.Run"] = None,
+        run: Optional[wandb.sdk.wandb_run.Run] = None,
         **build_eval_kwargs: Union[str, int],
     ) -> None:
         """
@@ -208,26 +220,27 @@ class LayoutLMTrainer(Trainer):
 
 
 def _get_model_class_and_tokenizer(
-    path_config_json: str, dataset_type: ObjectTypes, use_xlm_tokenizer: bool
-) -> Tuple[Any, Any, Any, Any]:
+    path_config_json: str, dataset_type: DatasetType, use_xlm_tokenizer: bool
+) -> Tuple[Any, Any, Any, Any, Any]:
     with open(path_config_json, "r", encoding="UTF-8") as file:
         config_json = json.load(file)
 
-    model_type = config_json.get("model_type")
-
-    if architectures := config_json.get("architectures"):
-        model_cls, model_wrapper_cls, config_cls = _ARCHITECTURES_TO_MODEL_CLASS[architectures[0]]
-        tokenizer_fast = get_tokenizer_from_architecture(architectures[0], use_xlm_tokenizer)
-    elif model_type:
-        model_cls, model_wrapper_cls, config_cls = _MODEL_TYPE_AND_TASK_TO_MODEL_CLASS[(model_type, dataset_type)]
-        tokenizer_fast = _MODEL_TYPE_TO_TOKENIZER[(model_type, use_xlm_tokenizer)]
+    if model_type := config_json.get("model_type"):
+        model_cls, model_wrapper_cls, config_cls = get_model_architectures_and_configs(model_type, dataset_type)
+        remove_box_features = maybe_remove_bounding_box_features(model_type)
     else:
-        raise KeyError("model_type and architectures not available in configs")
+        raise KeyError("model_type not available in configs. It seems that the config is not valid")
 
-    if not model_cls:
-        raise UserWarning("model not eligible to run with this framework")
+    tokenizer_fast = get_tokenizer_from_model_class(model_cls.__name__, use_xlm_tokenizer)
+    return config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features
 
-    return config_cls, model_cls, model_wrapper_cls, tokenizer_fast
+
+def get_image_to_raw_features_mapping(input_str: str) -> Any:
+    """Replacing eval functions"""
+    return {
+        "image_to_raw_layoutlm_features": image_to_raw_layoutlm_features,
+        "image_to_raw_lm_features": image_to_raw_lm_features,
+    }[input_str]
 
 
 def train_hf_layoutlm(
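
Note: model selection is now a plain dictionary lookup keyed on `(model_type, dataset_type)`, so an unsupported combination fails fast with a `KeyError` instead of silently picking a wrong class. A sketch of a call, using only pairs registered above:

    # assumes transformers and deepdoctection are installed
    model_cls, wrapper_cls, config_cls = get_model_architectures_and_configs(
        "lilt", DatasetType.token_classification
    )
    # model_cls -> LiltForTokenClassification, wrapper_cls -> HFLiltTokenClassifier
    # an unregistered pair, e.g. ("xlm-roberta", DatasetType.token_classification), raises KeyError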
@@ -352,17 +365,19 @@ def train_hf_layoutlm(
     else:
         raise UserWarning("Dataset type not supported for training")
 
-    config_cls, model_cls, model_wrapper_cls, tokenizer_fast = _get_model_class_and_tokenizer(
+    config_cls, model_cls, model_wrapper_cls, tokenizer_fast, remove_box_features = _get_model_class_and_tokenizer(
         path_config_json, dataset_type, use_xlm_tokenizer
     )
-    image_to_raw_layoutlm_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
+    image_to_raw_features_func = get_image_to_raw_features_mapping(model_wrapper_cls.image_to_raw_features_mapping())
+    image_to_raw_features_kwargs = {"dataset_type": dataset_type, "use_token_tag": use_token_tag}
     if segment_positions:
-        image_to_raw_layoutlm_kwargs["segment_positions"] = segment_positions  # type: ignore
-    image_to_raw_layoutlm_kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
+        image_to_raw_features_kwargs["segment_positions"] = segment_positions  # type: ignore
+    image_to_raw_features_kwargs.update(model_wrapper_cls.default_kwargs_for_input_mapping())
+
     dataset = DatasetAdapter(
         dataset_train,
         True,
-        image_to_raw_layoutlm_features(**image_to_raw_layoutlm_kwargs),
+        image_to_raw_features_func(**image_to_raw_features_kwargs),
         use_token_tag,
         **build_train_dict,
     )
@@ -453,6 +468,7 @@
         return_tensors="pt",
         sliding_window_stride=sliding_window_stride,  # type: ignore
         max_batch_size=max_batch_size,  # type: ignore
+        remove_bounding_box_features=remove_box_features,
     )
     trainer = LayoutLMTrainer(model, arguments, data_collator, dataset)
 
@@ -475,7 +491,8 @@
         path_config_json=path_config_json,
         path_weights=path_weights,
         categories=categories,
-        device=get_device(),
+        device=get_torch_device(),
+        use_xlm_tokenizer=use_xlm_tokenizer,
     )
     pipeline_component_cls = pipeline_component_registry.get(pipeline_component_name)
     if dataset_type == DatasetType.sequence_classification:
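
Note: the new `remove_bounding_box_features` collator flag exists because the freshly supported `xlm-roberta` path is text-only (see `maybe_remove_bounding_box_features` above). The idea, sketched outside the actual `LayoutLMDataCollator` (a hypothetical helper, feature names illustrative):

    # not the library's collator, just the concept behind the flag
    def drop_box_features(batch: dict, remove_bounding_box_features: bool) -> dict:
        if remove_bounding_box_features:
            batch.pop("bbox", None)  # a plain LM consumes input_ids/attention_mask only
        return batch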
deepdoctection/train/tp_frcnn_train.py

@@ -22,25 +22,7 @@ Module for training Tensorpack `GeneralizedRCNN`
 import os
 from typing import Dict, List, Optional, Sequence, Type, Union
 
-# pylint: disable=import-error
-from tensorpack.callbacks import (
-    EstimatedTimeLeft,
-    GPUMemoryTracker,
-    GPUUtilizationTracker,
-    HostMemoryTracker,
-    ModelSaver,
-    PeriodicCallback,
-    ScheduledHyperParamSetter,
-    SessionRunTimeout,
-    ThroughputTracker,
-)
-
-# todo: check how dataflow import is directly possible without having AssertionError
-from tensorpack.dataflow import ProxyDataFlow, imgaug
-from tensorpack.input_source import QueueInput
-from tensorpack.tfutils import SmartInit
-from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
-from tensorpack.utils import logger
+from lazy_imports import try_import
 
 from ..dataflow.base import DataFlow
 from ..dataflow.common import MapData
@@ -68,6 +50,26 @@ from ..utils.metacfg import AttrDict, set_config_by_yaml
 from ..utils.tqdm import get_tqdm
 from ..utils.utils import string_to_dict
 
+with try_import() as tp_import_guard:
+    # todo: check how dataflow import is directly possible without having an AssertionError
+    # pylint: disable=import-error
+    from tensorpack.callbacks import (
+        EstimatedTimeLeft,
+        GPUMemoryTracker,
+        GPUUtilizationTracker,
+        HostMemoryTracker,
+        ModelSaver,
+        PeriodicCallback,
+        ScheduledHyperParamSetter,
+        SessionRunTimeout,
+        ThroughputTracker,
+    )
+    from tensorpack.dataflow import ProxyDataFlow, imgaug
+    from tensorpack.input_source import QueueInput
+    from tensorpack.tfutils import SmartInit
+    from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
+    from tensorpack.utils import logger
+
 __all__ = ["train_faster_rcnn"]
 
 
deepdoctection/utils/env_info.py

@@ -46,16 +46,16 @@ can store an (absolute) path to a `.jsonl` file.
 
 """
 
-import ast
 import importlib
 import os
 import re
 import subprocess
 import sys
 from collections import defaultdict
-from typing import List, Literal, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
+from packaging import version
 from tabulate import tabulate
 
 from .file_utils import (
@@ -68,6 +68,7 @@ from .file_utils import (
     fasttext_available,
     get_poppler_version,
     get_tesseract_version,
+    get_tf_version,
     jdeskew_available,
     lxml_available,
     opencv_available,
@@ -84,13 +85,9 @@ from .file_utils import (
     transformers_available,
     wandb_available,
 )
-from .logger import LoggingRecord, logger
 
 __all__ = [
-    "collect_torch_env",
     "collect_env_info",
-    "get_device",
-    "auto_select_lib_and_device",
     "auto_select_viz_library",
 ]
 
@@ -270,7 +267,22 @@ def tf_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     if tf_available():
         import tensorflow as tf  # type: ignore # pylint: disable=E0401
 
+        os.environ["TENSORFLOW_AVAILABLE"] = "1"
+
         data.append(("Tensorflow", tf.__version__))
+        if version.parse(get_tf_version()) > version.parse("2.4.1"):
+            os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+        try:
+            import tensorflow.python.util.deprecation as deprecation  # type: ignore # pylint: disable=E0401,R0402
+
+            deprecation._PRINT_DEPRECATION_WARNINGS = False  # pylint: disable=W0212
+        except Exception:  # pylint: disable=W0703
+            try:
+                from tensorflow.python.util import deprecation  # type: ignore # pylint: disable=E0401
+
+                deprecation._PRINT_DEPRECATION_WARNINGS = False  # pylint: disable=W0212
+            except Exception:  # pylint: disable=W0703
+                pass
     else:
         data.append(("Tensorflow", "None"))
     return data
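
Note: the gate added above needs `packaging.version` because TF version strings do not compare correctly as plain strings:

    from packaging import version

    assert version.parse("2.10.0") > version.parse("2.4.1")  # numeric segment compare
    assert not ("2.10.0" > "2.4.1")  # lexicographic string compare gets this wrong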
@@ -279,12 +291,18 @@ def tf_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
 
     try:
         for key, value in list(build_info.build_info.items()):
-            if key == "cuda_version":
+            if key == "is_cuda_build":
+                data.append(("TF compiled with CUDA", value))
+                if value and len(tf.config.list_physical_devices('GPU')):
+                    os.environ["USE_CUDA"] = "1"
+            elif key == "cuda_version":
                 data.append(("TF built with CUDA", value))
             elif key == "cudnn_version":
                 data.append(("TF built with CUDNN", value))
             elif key == "cuda_compute_capabilities":
                 data.append(("TF compute capabilities", ",".join([k.replace("compute_", "") for k in value])))
+            elif key == "is_rocm_build":
+                data.append(("TF compiled with ROCM", value))
         return data
     except AttributeError:
         pass
@@ -306,6 +324,13 @@ def pt_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
 
     if pytorch_available():
         import torch
+
+        os.environ["PYTORCH_AVAILABLE"] = "1"
+
+    else:
+        data.append(("PyTorch", "None"))
+        return []
+
     has_gpu = torch.cuda.is_available()  # true for both CUDA & ROCM
     has_mps = torch.backends.mps.is_available()
 
@@ -331,12 +356,9 @@ def pt_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
     data.append(("PyTorch debug build", str(torch.version.debug)))
 
-    if not has_gpu:
-        has_gpu_text = "No: torch.cuda.is_available() == False"
-    else:
-        has_gpu_text = "Yes"
-    data.append(("GPU available", has_gpu_text))
     if has_gpu:
+        os.environ["USE_CUDA"] = "1"
+        has_gpu_text = "Yes"
         devices = defaultdict(list)
         for k in range(torch.cuda.device_count()):
             cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k)))
@@ -362,6 +384,10 @@ def pt_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
         cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
         if cuda_arch_list:
             data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
+    else:
+        has_gpu_text = "No: torch.cuda.is_available() == False"
+
+    data.append(("GPU available", has_gpu_text))
 
     mps_build = "No: torch.backends.mps.is_built() == False"
     if not has_mps:
@@ -369,9 +395,11 @@ def pt_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     else:
         has_mps_text = "Yes"
         mps_build = str(torch.backends.mps.is_built())
+        if mps_build == "True":
+            os.environ["USE_MPS"] = "1"
 
     data.append(("MPS available", has_mps_text))
-    data.append(("MPS available", mps_build))
+    data.append(("MPS built", mps_build))
 
     try:
         import torchvision  # type: ignore
@@ -452,110 +480,6 @@ def collect_env_info() -> str:
     return env_str
 
 
-def set_env(name: str, value: str) -> None:
-    """
-    Set an environment variable if it is not already set.
-
-    :param name: The name of the environment variable
-    :param value: The value of the environment variable
-    """
-
-    if os.environ.get(name):
-        return
-    os.environ[name] = value
-    return
-
-
-def auto_select_lib_and_device() -> None:
-    """
-    Select the DL library and subsequently the device.
-    This will set environment variable `USE_TENSORFLOW`, `USE_PYTORCH` and `USE_CUDA`
-
-    If TF is available, use TF unless a GPU is not available, in which case choose PT. If CUDA is not available and PT
-    is not installed raise ImportError.
-    """
-
-    # USE_TF and USE_TORCH are env variables that steer DL library selection for Doctr.
-    if tf_available() and tensorpack_available():
-        from tensorpack.utils.gpu import get_num_gpu  # pylint: disable=E0401
-
-        if get_num_gpu() >= 1:
-            set_env("USE_TENSORFLOW", "True")
-            set_env("USE_PYTORCH", "False")
-            set_env("USE_CUDA", "True")
-            set_env("USE_MPS", "False")
-            set_env("USE_TF", "TRUE")
-            set_env("USE_TORCH", "False")
-            return
-        if pytorch_available():
-            set_env("USE_TENSORFLOW", "False")
-            set_env("USE_PYTORCH", "True")
-            set_env("USE_CUDA", "False")
-            set_env("USE_TF", "False")
-            set_env("USE_TORCH", "TRUE")
-            return
-        logger.warning(
-            LoggingRecord("You have Tensorflow installed but no GPU is available. All Tensorflow models require a GPU.")
-        )
-    if tf_available():
-        set_env("USE_TENSORFLOW", "False")
-        set_env("USE_PYTORCH", "False")
-        set_env("USE_CUDA", "False")
-        set_env("USE_TF", "AUTO")
-        set_env("USE_TORCH", "AUTO")
-        return
-
-    if pytorch_available():
-        import torch
-
-        if torch.cuda.is_available():
-            set_env("USE_TENSORFLOW", "False")
-            set_env("USE_PYTORCH", "True")
-            set_env("USE_CUDA", "True")
-            set_env("USE_TF", "False")
-            set_env("USE_TORCH", "TRUE")
-            return
-        if torch.backends.mps.is_available():
-            set_env("USE_TENSORFLOW", "False")
-            set_env("USE_PYTORCH", "True")
-            set_env("USE_CUDA", "False")
-            set_env("USE_MPS", "True")
-            set_env("USE_TF", "False")
-            set_env("USE_TORCH", "TRUE")
-            return
-        set_env("USE_TENSORFLOW", "False")
-        set_env("USE_PYTORCH", "True")
-        set_env("USE_CUDA", "False")
-        set_env("USE_MPS", "False")
-        set_env("USE_TF", "AUTO")
-        set_env("USE_TORCH", "AUTO")
-        return
-    logger.warning(
-        LoggingRecord(
-            "Neither Tensorflow or Pytorch are available. You will not be able to use any Deep Learning "
-            "model from the library."
-        )
-    )
-
-
-def get_device(ignore_cpu: bool = True) -> Literal["cuda", "mps", "cpu"]:
-    """
-    Device checks for running PyTorch with CUDA, MPS or optionall CPU.
-    If nothing can be found and if `disable_cpu` is deactivated it will raise a `ValueError`
-
-    :param ignore_cpu: Will not consider `cpu` as valid return value
-    :return: Either cuda or mps
-    """
-
-    if ast.literal_eval(os.environ.get("USE_CUDA", "True")):
-        return "cuda"
-    if ast.literal_eval(os.environ.get("USE_MPS", "True")):
-        return "mps"
-    if not ignore_cpu:
-        return "cpu"
-    raise RuntimeWarning("Could not find either GPU nor MPS")
-
-
 def auto_select_viz_library() -> None:
     """Setting PIL as default image library if cv2 is not installed"""
 
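
Note: with `get_device` and `auto_select_lib_and_device` removed, callers switch to `get_torch_device` from `deepdoctection/extern/pt/ptutils.py` (see the hf_layoutlm_train.py hunk above), while `pt_info`/`tf_info` now set `USE_CUDA`/`USE_MPS` as a side effect. A minimal sketch of what such a helper boils down to when probing torch directly (names are illustrative, not the library's implementation):

    import torch

    def pick_torch_device(allow_cpu: bool = False) -> torch.device:
        if torch.cuda.is_available():  # covers CUDA and ROCm builds
            return torch.device("cuda")
        if torch.backends.mps.is_available():  # Apple silicon
            return torch.device("mps")
        if allow_cpu:
            return torch.device("cpu")
        raise RuntimeError("neither CUDA nor MPS is available")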
deepdoctection/utils/logger.py

@@ -134,6 +134,7 @@ class FileFormatter(logging.Formatter):
 _LOG_DIR = None
 _CONFIG_DICT: Dict[str, Any] = {
     "version": 1,
+    "disable_existing_loggers": False,
     "filters": {"customfilter": {"()": lambda: CustomFilter()}},  # pylint: disable=W0108
     "formatters": {
         "streamformatter": {"()": lambda: StreamFormatter(datefmt="%m%d %H:%M.%S")},