deepdoctection 0.32__py3-none-any.whl → 0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of deepdoctection might be problematic. Click here for more details.

Files changed (111)
  1. deepdoctection/__init__.py +8 -25
  2. deepdoctection/analyzer/dd.py +84 -71
  3. deepdoctection/dataflow/common.py +9 -5
  4. deepdoctection/dataflow/custom.py +5 -5
  5. deepdoctection/dataflow/custom_serialize.py +75 -18
  6. deepdoctection/dataflow/parallel_map.py +3 -3
  7. deepdoctection/dataflow/serialize.py +4 -4
  8. deepdoctection/dataflow/stats.py +3 -3
  9. deepdoctection/datapoint/annotation.py +78 -56
  10. deepdoctection/datapoint/box.py +7 -7
  11. deepdoctection/datapoint/convert.py +6 -6
  12. deepdoctection/datapoint/image.py +157 -75
  13. deepdoctection/datapoint/view.py +175 -151
  14. deepdoctection/datasets/adapter.py +30 -24
  15. deepdoctection/datasets/base.py +10 -10
  16. deepdoctection/datasets/dataflow_builder.py +3 -3
  17. deepdoctection/datasets/info.py +23 -25
  18. deepdoctection/datasets/instances/doclaynet.py +48 -49
  19. deepdoctection/datasets/instances/fintabnet.py +44 -45
  20. deepdoctection/datasets/instances/funsd.py +23 -23
  21. deepdoctection/datasets/instances/iiitar13k.py +8 -8
  22. deepdoctection/datasets/instances/layouttest.py +2 -2
  23. deepdoctection/datasets/instances/publaynet.py +3 -3
  24. deepdoctection/datasets/instances/pubtables1m.py +18 -18
  25. deepdoctection/datasets/instances/pubtabnet.py +30 -29
  26. deepdoctection/datasets/instances/rvlcdip.py +28 -29
  27. deepdoctection/datasets/instances/xfund.py +51 -30
  28. deepdoctection/datasets/save.py +6 -6
  29. deepdoctection/eval/accmetric.py +32 -33
  30. deepdoctection/eval/base.py +8 -9
  31. deepdoctection/eval/cocometric.py +13 -12
  32. deepdoctection/eval/eval.py +32 -26
  33. deepdoctection/eval/tedsmetric.py +16 -12
  34. deepdoctection/eval/tp_eval_callback.py +7 -16
  35. deepdoctection/extern/base.py +339 -134
  36. deepdoctection/extern/d2detect.py +69 -89
  37. deepdoctection/extern/deskew.py +11 -10
  38. deepdoctection/extern/doctrocr.py +81 -64
  39. deepdoctection/extern/fastlang.py +23 -16
  40. deepdoctection/extern/hfdetr.py +53 -38
  41. deepdoctection/extern/hflayoutlm.py +216 -155
  42. deepdoctection/extern/hflm.py +35 -30
  43. deepdoctection/extern/model.py +433 -255
  44. deepdoctection/extern/pdftext.py +15 -15
  45. deepdoctection/extern/pt/ptutils.py +4 -2
  46. deepdoctection/extern/tessocr.py +39 -38
  47. deepdoctection/extern/texocr.py +14 -16
  48. deepdoctection/extern/tp/tfutils.py +16 -2
  49. deepdoctection/extern/tp/tpcompat.py +11 -7
  50. deepdoctection/extern/tp/tpfrcnn/config/config.py +4 -4
  51. deepdoctection/extern/tp/tpfrcnn/modeling/backbone.py +1 -1
  52. deepdoctection/extern/tp/tpfrcnn/modeling/model_box.py +5 -5
  53. deepdoctection/extern/tp/tpfrcnn/modeling/model_fpn.py +6 -6
  54. deepdoctection/extern/tp/tpfrcnn/modeling/model_frcnn.py +4 -4
  55. deepdoctection/extern/tp/tpfrcnn/modeling/model_mrcnn.py +5 -3
  56. deepdoctection/extern/tp/tpfrcnn/preproc.py +5 -5
  57. deepdoctection/extern/tpdetect.py +40 -45
  58. deepdoctection/mapper/cats.py +36 -40
  59. deepdoctection/mapper/cocostruct.py +16 -12
  60. deepdoctection/mapper/d2struct.py +22 -22
  61. deepdoctection/mapper/hfstruct.py +7 -7
  62. deepdoctection/mapper/laylmstruct.py +22 -24
  63. deepdoctection/mapper/maputils.py +9 -10
  64. deepdoctection/mapper/match.py +33 -2
  65. deepdoctection/mapper/misc.py +6 -7
  66. deepdoctection/mapper/pascalstruct.py +4 -4
  67. deepdoctection/mapper/prodigystruct.py +6 -6
  68. deepdoctection/mapper/pubstruct.py +84 -92
  69. deepdoctection/mapper/tpstruct.py +3 -3
  70. deepdoctection/mapper/xfundstruct.py +33 -33
  71. deepdoctection/pipe/anngen.py +39 -14
  72. deepdoctection/pipe/base.py +68 -99
  73. deepdoctection/pipe/common.py +181 -85
  74. deepdoctection/pipe/concurrency.py +14 -10
  75. deepdoctection/pipe/doctectionpipe.py +24 -21
  76. deepdoctection/pipe/language.py +20 -25
  77. deepdoctection/pipe/layout.py +18 -16
  78. deepdoctection/pipe/lm.py +49 -47
  79. deepdoctection/pipe/order.py +63 -65
  80. deepdoctection/pipe/refine.py +102 -109
  81. deepdoctection/pipe/segment.py +157 -162
  82. deepdoctection/pipe/sub_layout.py +50 -40
  83. deepdoctection/pipe/text.py +37 -36
  84. deepdoctection/pipe/transform.py +19 -16
  85. deepdoctection/train/d2_frcnn_train.py +27 -25
  86. deepdoctection/train/hf_detr_train.py +22 -18
  87. deepdoctection/train/hf_layoutlm_train.py +49 -48
  88. deepdoctection/train/tp_frcnn_train.py +10 -11
  89. deepdoctection/utils/concurrency.py +1 -1
  90. deepdoctection/utils/context.py +13 -6
  91. deepdoctection/utils/develop.py +4 -4
  92. deepdoctection/utils/env_info.py +52 -14
  93. deepdoctection/utils/file_utils.py +6 -11
  94. deepdoctection/utils/fs.py +41 -14
  95. deepdoctection/utils/identifier.py +2 -2
  96. deepdoctection/utils/logger.py +15 -15
  97. deepdoctection/utils/metacfg.py +7 -7
  98. deepdoctection/utils/pdf_utils.py +39 -14
  99. deepdoctection/utils/settings.py +188 -182
  100. deepdoctection/utils/tqdm.py +1 -1
  101. deepdoctection/utils/transform.py +14 -9
  102. deepdoctection/utils/types.py +104 -0
  103. deepdoctection/utils/utils.py +7 -7
  104. deepdoctection/utils/viz.py +70 -69
  105. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/METADATA +7 -4
  106. deepdoctection-0.34.dist-info/RECORD +146 -0
  107. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/WHEEL +1 -1
  108. deepdoctection/utils/detection_types.py +0 -68
  109. deepdoctection-0.32.dist-info/RECORD +0 -146
  110. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/LICENSE +0 -0
  111. {deepdoctection-0.32.dist-info → deepdoctection-0.34.dist-info}/top_level.txt +0 -0
@@ -52,7 +52,7 @@ import re
52
52
  import subprocess
53
53
  import sys
54
54
  from collections import defaultdict
55
- from typing import List, Optional, Tuple
55
+ from typing import Optional
56
56
 
57
57
  import numpy as np
58
58
  from packaging import version
@@ -85,14 +85,15 @@ from .file_utils import (
85
85
  transformers_available,
86
86
  wandb_available,
87
87
  )
88
+ from .logger import LoggingRecord, logger
89
+ from .types import KeyValEnvInfos, PathLikeOrStr
88
90
 
89
- __all__ = [
90
- "collect_env_info",
91
- "auto_select_viz_library",
92
- ]
91
+ __all__ = ["collect_env_info", "auto_select_viz_library", "ENV_VARS_TRUE"]
93
92
 
94
93
  # pylint: disable=import-outside-toplevel
95
94
 
95
+ ENV_VARS_TRUE: set[str] = {"1", "True", "TRUE", "true", "yes"}
96
+
96
97
 
97
98
  def collect_torch_env() -> str:
98
99
  """Wrapper for torch.utils.collect_env.get_pretty_env_info"""
@@ -107,7 +108,7 @@ def collect_torch_env() -> str:
107
108
  return get_pretty_env_info()
108
109
 
109
110
 
110
- def collect_installed_dependencies(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
111
+ def collect_installed_dependencies(data: KeyValEnvInfos) -> KeyValEnvInfos:
111
112
  """Collect installed dependencies for all third party libraries.
112
113
 
113
114
  :param data: A list of tuples to dump all collected package information such as the name and the version
@@ -175,7 +176,7 @@ def collect_installed_dependencies(data: List[Tuple[str, str]]) -> List[Tuple[st
175
176
  data.append(("Pycocotools", "None"))
176
177
 
177
178
  if scipy_available():
178
- import scipy # type: ignore
179
+ import scipy
179
180
 
180
181
  data.append(("Scipy", scipy.__version__))
181
182
  else:
@@ -232,7 +233,7 @@ def collect_installed_dependencies(data: List[Tuple[str, str]]) -> List[Tuple[st
232
233
  return data
233
234
 
234
235
 
235
- def detect_compute_compatibility(cuda_home: Optional[str], so_file: Optional[str]) -> str:
236
+ def detect_compute_compatibility(cuda_home: Optional[PathLikeOrStr], so_file: Optional[PathLikeOrStr]) -> str:
236
237
  """
237
238
  Detect the compute compatibility of a CUDA library.
238
239
 
@@ -258,7 +259,7 @@ def detect_compute_compatibility(cuda_home: Optional[str], so_file: Optional[str
258
259
 
259
260
 
260
261
  # Copied from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/tfutils/collect_env.py
261
- def tf_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
262
+ def tf_info(data: KeyValEnvInfos) -> KeyValEnvInfos:
262
263
  """Returns a list of (key, value) pairs containing tensorflow information.
263
264
 
264
265
  :param data: A list of tuples to dump all collected package information such as the name and the version
@@ -273,12 +274,12 @@ def tf_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
273
274
  if version.parse(get_tf_version()) > version.parse("2.4.1"):
274
275
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
275
276
  try:
276
- import tensorflow.python.util.deprecation as deprecation # type: ignore # pylint: disable=E0401,R0402
277
+ import tensorflow.python.util.deprecation as deprecation # type: ignore # pylint: disable=E0401,R0402,E0611
277
278
 
278
279
  deprecation._PRINT_DEPRECATION_WARNINGS = False # pylint: disable=W0212
279
280
  except Exception: # pylint: disable=W0703
280
281
  try:
281
- from tensorflow.python.util import deprecation # type: ignore # pylint: disable=E0401
282
+ from tensorflow.python.util import deprecation # type: ignore # pylint: disable=E0401,E0611
282
283
 
283
284
  deprecation._PRINT_DEPRECATION_WARNINGS = False # pylint: disable=W0212
284
285
  except Exception: # pylint: disable=W0703
@@ -287,13 +288,13 @@ def tf_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
287
288
  data.append(("Tensorflow", "None"))
288
289
  return data
289
290
 
290
- from tensorflow.python.platform import build_info # type: ignore # pylint: disable=E0401
291
+ from tensorflow.python.platform import build_info # type: ignore # pylint: disable=E0401,E0611
291
292
 
292
293
  try:
293
294
  for key, value in list(build_info.build_info.items()):
294
295
  if key == "is_cuda_build":
295
296
  data.append(("TF compiled with CUDA", value))
296
- if value and len(tf.config.list_physical_devices('GPU')):
297
+ if value and len(tf.config.list_physical_devices("GPU")):
297
298
  os.environ["USE_CUDA"] = "1"
298
299
  elif key == "cuda_version":
299
300
  data.append(("TF built with CUDA", value))
@@ -315,7 +316,7 @@ def tf_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
315
316
 
316
317
 
317
318
  # Heavily inspired by https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/collect_env.py
318
- def pt_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
319
+ def pt_info(data: KeyValEnvInfos) -> KeyValEnvInfos:
319
320
  """Returns a list of (key, value) pairs containing Pytorch information.
320
321
 
321
322
  :param data: A list of tuples to dump all collected package information such as the name and the version
@@ -423,6 +424,42 @@ def pt_info(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
423
424
  return data
424
425
 
425
426
 
427
+ def set_dl_env_vars() -> None:
428
+ """Set the environment variables that steer the selection of the DL framework.
429
+ If both PyTorch and TensorFlow are available, PyTorch will be selected by default.
430
+ It is possible that for testing purposes, e.g. on Colab you can find yourself with a pre-installed Tensorflow
431
+ version. If you want to enforce PyTorch you must set:
432
+
433
+ os.environ["DD_USE_TORCH"] = "1"
434
+ os.environ["USE_TORCH"] = "1" # necessary if you make use of DocTr's OCR engine
435
+ os.environ["DD_USE_TF"] = "0"
436
+ os.environ["USE_TF"] = "0" # it's better to explicitly disable Tensorflow
437
+
438
+
439
+ """
440
+
441
+ if os.environ.get("PYTORCH_AVAILABLE") and os.environ.get("DD_USE_TORCH") is None:
442
+ os.environ["DD_USE_TORCH"] = "1"
443
+ os.environ["USE_TORCH"] = "1"
444
+ if os.environ.get("TENSORFLOW_AVAILABLE") and os.environ.get("DD_USE_TF") is None:
445
+ os.environ["DD_USE_TF"] = "1"
446
+ os.environ["USE_TF"] = "1"
447
+
448
+ if os.environ.get("DD_USE_TORCH", "0") in ENV_VARS_TRUE and os.environ.get("DD_USE_TF", "0") in ENV_VARS_TRUE:
449
+ logger.warning(
450
+ "Both DD_USE_TORCH and DD_USE_TF are set. Defaulting to PyTorch. If you want a different "
451
+ "behaviour, set DD_USE_TORCH to None before importing deepdoctection."
452
+ )
453
+ os.environ["DD_USE_TF"] = "0"
454
+ os.environ["USE_TF"] = "0"
455
+
456
+ if (
457
+ os.environ.get("PYTORCH_AVAILABLE") not in ENV_VARS_TRUE
458
+ and os.environ.get("TENSORFLOW_AVAILABLE") not in ENV_VARS_TRUE
459
+ ):
460
+ logger.warning(LoggingRecord(msg="Neither Tensorflow or Pytorch are available."))
461
+
462
+
426
463
  def collect_env_info() -> str:
427
464
  """
428
465
 
@@ -469,6 +506,7 @@ def collect_env_info() -> str:
469
506
 
470
507
  data = pt_info(data)
471
508
  data = tf_info(data)
509
+ set_dl_env_vars()
472
510
 
473
511
  data = collect_installed_dependencies(data)
474
512
 
@@ -16,15 +16,15 @@ import sys
16
16
  from os import environ, path
17
17
  from shutil import which
18
18
  from types import ModuleType
19
- from typing import Any, Tuple, Union, no_type_check
19
+ from typing import Any, Union, no_type_check
20
20
 
21
21
  import importlib_metadata
22
22
  from packaging import version
23
23
 
24
- from .detection_types import Requirement
25
24
  from .error import DependencyError
26
25
  from .logger import LoggingRecord, logger
27
26
  from .metacfg import AttrDict
27
+ from .types import PathLikeOrStr, Requirement
28
28
 
29
29
  _GENERIC_ERR_MSG = "Please check the required version either in the docs or in the setup file"
30
30
 
@@ -52,7 +52,7 @@ def get_tf_version() -> str:
52
52
  """
53
53
  tf_version = "0.0"
54
54
  if tf_available():
55
- candidates: Tuple[str, ...] = (
55
+ candidates: tuple[str, ...] = (
56
56
  "tensorflow",
57
57
  "tensorflow-cpu",
58
58
  "tensorflow-gpu",
@@ -250,31 +250,26 @@ def get_detectron2_requirement() -> Requirement:
250
250
  # Tesseract related dependencies
251
251
  _TESS_AVAILABLE = which("tesseract") is not None
252
252
  # Tesseract installation path
253
- _TESS_PATH = "tesseract"
253
+ _TESS_PATH: PathLikeOrStr = "tesseract"
254
254
  _TESS_ERR_MSG = (
255
255
  "Tesseract >=4.0 must be installed. Please follow the official installation instructions. "
256
256
  "https://tesseract-ocr.github.io/tessdoc/Installation.html"
257
257
  )
258
258
 
259
259
 
260
- def set_tesseract_path(tesseract_path: str) -> None:
260
+ def set_tesseract_path(tesseract_path: PathLikeOrStr) -> None:
261
261
  """Set the Tesseract path. If you have tesseract installed in Anaconda,
262
262
  you can use this function to set tesseract path.
263
263
 
264
264
  :param tesseract_path: Tesseract installation path.
265
265
  """
266
- if tesseract_path is None:
267
- raise TypeError("tesseract_path cannot be None")
268
266
 
269
267
  global _TESS_AVAILABLE # pylint: disable=W0603
270
268
  global _TESS_PATH # pylint: disable=W0603
271
269
 
272
270
  tesseract_flag = which(tesseract_path)
273
271
 
274
- if tesseract_flag is None:
275
- _TESS_AVAILABLE = False
276
- else:
277
- _TESS_AVAILABLE = True
272
+ _TESS_AVAILABLE = False if tesseract_flag is not None else True # pylint: disable=W0603,R1719
278
273
 
279
274
  _TESS_PATH = tesseract_path
280
275
 
@@ -25,15 +25,16 @@ import os
25
25
  from base64 import b64encode
26
26
  from io import BytesIO
27
27
  from pathlib import Path
28
+ from shutil import copyfile
28
29
  from typing import Callable, Literal, Optional, Protocol, Union, overload
29
30
  from urllib.request import urlretrieve
30
31
 
31
- from .detection_types import ImageType, JsonDict, Pathlike
32
32
  from .develop import deprecated
33
33
  from .logger import LoggingRecord, logger
34
34
  from .pdf_utils import get_pdf_file_reader, get_pdf_file_writer
35
35
  from .settings import CONFIGS, DATASET_DIR, MODEL_DIR, PATH
36
36
  from .tqdm import get_tqdm
37
+ from .types import B64, B64Str, JsonDict, PathLikeOrStr, PixelValues
37
38
  from .utils import is_file_extension
38
39
  from .viz import viz_handler
39
40
 
@@ -50,6 +51,7 @@ __all__ = [
50
51
  "get_configs_dir_path",
51
52
  "get_weights_dir_path",
52
53
  "get_dataset_dir_path",
54
+ "maybe_copy_config_to_cache",
53
55
  ]
54
56
 
55
57
 
@@ -66,7 +68,7 @@ def sizeof_fmt(num: float, suffix: str = "B") -> str:
66
68
 
67
69
  # Copyright (c) Tensorpack Contributors
68
70
  # Licensed under the Apache License, Version 2.0 (the "License")
69
- def mkdir_p(dir_name: Pathlike) -> None:
71
+ def mkdir_p(dir_name: PathLikeOrStr) -> None:
70
72
  """
71
73
  Like "mkdir -p", make a dir recursively, but do nothing if the dir exists
72
74
 
@@ -84,7 +86,9 @@ def mkdir_p(dir_name: Pathlike) -> None:
84
86
 
85
87
  # Copyright (c) Tensorpack Contributors
86
88
  # Licensed under the Apache License, Version 2.0 (the "License")
87
- def download(url: str, directory: Pathlike, file_name: Optional[str] = None, expect_size: Optional[int] = None) -> str:
89
+ def download(
90
+ url: str, directory: PathLikeOrStr, file_name: Optional[str] = None, expect_size: Optional[int] = None
91
+ ) -> str:
88
92
  """
89
93
  Download URL to a directory. Will figure out the filename automatically from URL, if not given.
90
94
  """
@@ -133,16 +137,18 @@ def download(url: str, directory: Pathlike, file_name: Optional[str] = None, exp
133
137
 
134
138
 
135
139
  @overload
136
- def load_image_from_file(path: Pathlike, type_id: Literal["np"] = "np") -> Optional[ImageType]:
140
+ def load_image_from_file(path: PathLikeOrStr, type_id: Literal["np"] = "np") -> Optional[PixelValues]:
137
141
  ...
138
142
 
139
143
 
140
144
  @overload
141
- def load_image_from_file(path: Pathlike, type_id: Literal["b64"]) -> Optional[str]:
145
+ def load_image_from_file(path: PathLikeOrStr, type_id: Literal["b64"]) -> Optional[B64Str]:
142
146
  ...
143
147
 
144
148
 
145
- def load_image_from_file(path: Pathlike, type_id: Literal["np", "b64"] = "np") -> Optional[Union[str, ImageType]]:
149
+ def load_image_from_file(
150
+ path: PathLikeOrStr, type_id: Literal["np", "b64"] = "np"
151
+ ) -> Optional[Union[B64Str, PixelValues]]:
146
152
  """
147
153
  Loads an image from path and passes back an encoded base64 string, a numpy array or None if file is not found
148
154
  or a conversion error occurs.
@@ -151,7 +157,7 @@ def load_image_from_file(path: Pathlike, type_id: Literal["np", "b64"] = "np") -
151
157
  :param type_id: "np" or "b64".
152
158
  :return: image of desired representation
153
159
  """
154
- image: Optional[Union[str, ImageType]] = None
160
+ image: Optional[Union[str, PixelValues]] = None
155
161
  path = path.as_posix() if isinstance(path, Path) else path
156
162
 
157
163
  assert is_file_extension(path, [".png", ".jpeg", ".jpg", ".tif"]), f"image type not allowed: {path}"
@@ -169,7 +175,7 @@ def load_image_from_file(path: Pathlike, type_id: Literal["np", "b64"] = "np") -
169
175
  return image
170
176
 
171
177
 
172
- def load_bytes_from_pdf_file(path: Pathlike, page_number: int = 0) -> bytes:
178
+ def load_bytes_from_pdf_file(path: PathLikeOrStr, page_number: int = 0) -> B64:
173
179
  """
174
180
  Loads a pdf file with one single page and passes back a bytes' representation of this file. Can be converted into
175
181
  a numpy or directly passed to the attr: image of Image.
@@ -194,13 +200,13 @@ class LoadImageFunc(Protocol):
194
200
  Protocol for typing load_image_from_file
195
201
  """
196
202
 
197
- def __call__(self, path: Pathlike) -> Optional[ImageType]:
203
+ def __call__(self, path: PathLikeOrStr) -> Optional[PixelValues]:
198
204
  ...
199
205
 
200
206
 
201
207
  def get_load_image_func(
202
- path: Pathlike,
203
- ) -> Union[LoadImageFunc, Callable[[Pathlike], bytes]]:
208
+ path: PathLikeOrStr,
209
+ ) -> Union[LoadImageFunc, Callable[[PathLikeOrStr], B64]]:
204
210
  """
205
211
  Return the loading function according to its file extension.
206
212
 
@@ -219,7 +225,7 @@ def get_load_image_func(
219
225
  )
220
226
 
221
227
 
222
- def maybe_path_or_pdf(path: Pathlike) -> int:
228
+ def maybe_path_or_pdf(path: PathLikeOrStr) -> int:
223
229
  """
224
230
  Checks if the path points to a directory or a pdf document. Returns 1 if the path points to a directory, 2
225
231
  if the path points to a pdf doc or 0, if none of the previous is true.
@@ -238,7 +244,7 @@ def maybe_path_or_pdf(path: Pathlike) -> int:
238
244
  return 0
239
245
 
240
246
 
241
- def load_json(path_ann: Pathlike) -> JsonDict:
247
+ def load_json(path_ann: PathLikeOrStr) -> JsonDict:
242
248
  """
243
249
  Loading json file
244
250
 
@@ -278,8 +284,29 @@ def get_dataset_dir_path() -> Path:
278
284
  return DATASET_DIR
279
285
 
280
286
 
287
+ def maybe_copy_config_to_cache(
288
+ package_path: PathLikeOrStr, configs_dir_path: PathLikeOrStr, file_name: str, force_copy: bool = True
289
+ ) -> str:
290
+ """
291
+ Initial copying of various files
292
+ :param package_path: base path to directory of source file `file_name`
293
+ :param configs_dir_path: base path to target directory
294
+ :param file_name: file to copy
295
+ :param force_copy: If file is already in target directory, will re-copy the file
296
+
297
+ :return: path to the copied file_name
298
+ """
299
+
300
+ absolute_path_source = os.path.join(package_path, file_name)
301
+ absolute_path = os.path.join(configs_dir_path, os.path.join(os.path.split(file_name)[1]))
302
+ mkdir_p(os.path.split(absolute_path)[0])
303
+ if not os.path.isfile(absolute_path) or force_copy:
304
+ copyfile(absolute_path_source, absolute_path)
305
+ return absolute_path
306
+
307
+
281
308
  @deprecated("Use pathlib operations instead", "2022-06-08")
282
- def sub_path(anchor_dir: str, *paths: str) -> str:
309
+ def sub_path(anchor_dir: PathLikeOrStr, *paths: PathLikeOrStr) -> PathLikeOrStr:
283
310
  """
284
311
  Generate a path from the anchor directory and various paths args.
285
312
 
@@ -21,7 +21,7 @@ Methods for generating and checking uuids
21
21
  import hashlib
22
22
  import uuid
23
23
 
24
- from .detection_types import Pathlike
24
+ from .types import PathLikeOrStr
25
25
 
26
26
  __all__ = ["is_uuid_like", "get_uuid_from_str", "get_uuid"]
27
27
 
@@ -65,7 +65,7 @@ def get_uuid(*inputs: str) -> str:
65
65
  return get_uuid_from_str(str_input)
66
66
 
67
67
 
68
- def get_md5_hash(path: Pathlike, buffer_size: int = 65536) -> str:
68
+ def get_md5_hash(path: PathLikeOrStr, buffer_size: int = 65536) -> str:
69
69
  """
70
70
  Calculate a md5 hash for a given file
71
71
 
@@ -25,7 +25,6 @@ Log levels can be set via the environment variable `LOG_LEVEL` (default: INFO).
25
25
  `STD_OUT_VERBOSE` will print a verbose message to the terminal (default: False).
26
26
  """
27
27
 
28
- import ast
29
28
  import errno
30
29
  import functools
31
30
  import json
@@ -37,21 +36,23 @@ import sys
37
36
  from dataclasses import dataclass, field
38
37
  from datetime import datetime
39
38
  from pathlib import Path
40
- from typing import Any, Dict, Optional, Union, no_type_check
39
+ from typing import Any, Optional, Union, no_type_check
41
40
 
42
41
  from termcolor import colored
43
42
 
44
- from .detection_types import Pathlike
43
+ from .types import PathLikeOrStr
45
44
 
46
45
  __all__ = ["logger", "set_logger_dir", "auto_set_dir", "get_logger_dir"]
47
46
 
47
+ ENV_VARS_TRUE: set[str] = {"1", "True", "TRUE", "true", "yes"}
48
+
48
49
 
49
50
  @dataclass
50
51
  class LoggingRecord:
51
52
  """LoggingRecord to pass to the logger in order to distinguish from third party libraries."""
52
53
 
53
54
  msg: str
54
- log_dict: Optional[Dict[Union[int, str], Any]] = field(default=None)
55
+ log_dict: Optional[dict[Union[int, str], Any]] = field(default=None)
55
56
 
56
57
  def __post_init__(self) -> None:
57
58
  """log_dict will be added to the log record as a dict."""
@@ -66,7 +67,7 @@ class LoggingRecord:
66
67
  class CustomFilter(logging.Filter):
67
68
  """A custom filter"""
68
69
 
69
- filter_third_party_lib = ast.literal_eval(os.environ.get("FILTER_THIRD_PARTY_LIB", "False"))
70
+ filter_third_party_lib = os.environ.get("FILTER_THIRD_PARTY_LIB", "False") in ENV_VARS_TRUE
70
71
 
71
72
  def filter(self, record: logging.LogRecord) -> bool:
72
73
  if self.filter_third_party_lib:
@@ -79,7 +80,7 @@ class CustomFilter(logging.Filter):
79
80
  class StreamFormatter(logging.Formatter):
80
81
  """A custom formatter to produce unified LogRecords"""
81
82
 
82
- std_out_verbose = ast.literal_eval(os.environ.get("STD_OUT_VERBOSE", "False"))
83
+ std_out_verbose = os.environ.get("STD_OUT_VERBOSE", "False") in ENV_VARS_TRUE
83
84
 
84
85
  @no_type_check
85
86
  def format(self, record: logging.LogRecord) -> str:
@@ -109,7 +110,7 @@ class StreamFormatter(logging.Formatter):
109
110
  class FileFormatter(logging.Formatter):
110
111
  """A custom formatter to produce a loggings in json format"""
111
112
 
112
- filter_third_party_lib = ast.literal_eval(os.environ.get("FILTER_THIRD_PARTY_LIB", "False"))
113
+ filter_third_party_lib = os.environ.get("FILTER_THIRD_PARTY_LIB", "False") in ENV_VARS_TRUE
113
114
 
114
115
  @no_type_check
115
116
  def format(self, record: logging.LogRecord) -> str:
@@ -132,7 +133,7 @@ class FileFormatter(logging.Formatter):
132
133
 
133
134
 
134
135
  _LOG_DIR = None
135
- _CONFIG_DICT: Dict[str, Any] = {
136
+ _CONFIG_DICT: dict[str, Any] = {
136
137
  "version": 1,
137
138
  "disable_existing_loggers": False,
138
139
  "filters": {"customfilter": {"()": lambda: CustomFilter()}}, # pylint: disable=W0108
@@ -145,7 +146,7 @@ _CONFIG_DICT: Dict[str, Any] = {
145
146
  "root": {
146
147
  "handlers": ["streamhandler"],
147
148
  "level": os.environ.get("LOG_LEVEL", "INFO"),
148
- "propagate": ast.literal_eval(os.environ.get("LOG_PROPAGATE", "False")),
149
+ "propagate": os.environ.get("LOG_PROPAGATE", "False") in ENV_VARS_TRUE,
149
150
  },
150
151
  }
151
152
 
@@ -171,9 +172,8 @@ def _get_time_str() -> str:
171
172
  return datetime.now().strftime("%m%d-%H%M%S")
172
173
 
173
174
 
174
- def _set_file(path: Pathlike) -> None:
175
- if isinstance(path, Path):
176
- path = path.as_posix()
175
+ def _set_file(path: PathLikeOrStr) -> None:
176
+ path = os.fspath(path)
177
177
  global _FILE_HANDLER # pylint: disable=W0603
178
178
  if os.path.isfile(path):
179
179
  backup_name = path + "." + _get_time_str()
@@ -188,7 +188,7 @@ def _set_file(path: Pathlike) -> None:
188
188
  logger.info("Argv: %s ", sys.argv)
189
189
 
190
190
 
191
- def set_logger_dir(dir_name: Pathlike, action: Optional[str] = None) -> None:
191
+ def set_logger_dir(dir_name: PathLikeOrStr, action: Optional[str] = None) -> None:
192
192
  """
193
193
  Set the directory for global logging.
194
194
 
@@ -213,7 +213,7 @@ def set_logger_dir(dir_name: Pathlike, action: Optional[str] = None) -> None:
213
213
  logger.removeHandler(_FILE_HANDLER)
214
214
  del _FILE_HANDLER
215
215
 
216
- def dir_nonempty(directory: str) -> int:
216
+ def dir_nonempty(directory: PathLikeOrStr) -> int:
217
217
  return os.path.isdir(directory) and len([x for x in os.listdir(directory) if x[0] != "."])
218
218
 
219
219
  if dir_nonempty(dir_name):
@@ -267,7 +267,7 @@ def auto_set_dir(action: Optional[str] = None, name: Optional[str] = None) -> No
267
267
  set_logger_dir(auto_dir_name, action=action)
268
268
 
269
269
 
270
- def get_logger_dir() -> Optional[str]:
270
+ def get_logger_dir() -> Optional[PathLikeOrStr]:
271
271
  """
272
272
  The logger directory, or None if not set.
273
273
  The directory is used for general logging, tensorboard events, checkpoints, etc.
@@ -20,11 +20,11 @@ Class AttrDict for maintaining configs and some functions for generating and sav
20
20
  """
21
21
 
22
22
  import pprint
23
- from typing import Any, Dict, List
23
+ from typing import Any
24
24
 
25
25
  import yaml
26
26
 
27
- from .detection_types import Pathlike
27
+ from .types import PathLikeOrStr
28
28
 
29
29
 
30
30
  # Copyright (c) Tensorpack Contributors
@@ -67,13 +67,13 @@ class AttrDict:
67
67
 
68
68
  __repr__ = __str__
69
69
 
70
- def to_dict(self) -> Dict[str, Any]:
70
+ def to_dict(self) -> dict[str, Any]:
71
71
  """Convert to a nested dict."""
72
72
  return {
73
73
  k: v.to_dict() if isinstance(v, AttrDict) else v for k, v in self.__dict__.items() if not k.startswith("_")
74
74
  }
75
75
 
76
- def from_dict(self, d: Dict[str, Any]) -> None: # pylint: disable=C0103
76
+ def from_dict(self, d: dict[str, Any]) -> None: # pylint: disable=C0103
77
77
  """
78
78
  Generate an instance from a dict
79
79
  """
@@ -86,7 +86,7 @@ class AttrDict:
86
86
  else:
87
87
  setattr(self, k, v)
88
88
 
89
- def update_args(self, args: List[str]) -> None:
89
+ def update_args(self, args: list[str]) -> None:
90
90
  """
91
91
  Update from command line args.
92
92
  """
@@ -122,7 +122,7 @@ class AttrDict:
122
122
  raise NotImplementedError()
123
123
 
124
124
 
125
- def set_config_by_yaml(path_yaml: Pathlike) -> AttrDict:
125
+ def set_config_by_yaml(path_yaml: PathLikeOrStr) -> AttrDict:
126
126
  """
127
127
  Use to initialize the config class for tensorpack faster rcnn
128
128
 
@@ -139,7 +139,7 @@ def set_config_by_yaml(path_yaml: Pathlike) -> AttrDict:
139
139
  return config
140
140
 
141
141
 
142
- def save_config_to_yaml(config: AttrDict, path_yaml: Pathlike) -> None:
142
+ def save_config_to_yaml(config: AttrDict, path_yaml: PathLikeOrStr) -> None:
143
143
  """
144
144
  :param config: The configuration instance as an AttrDict
145
145
  :param path_yaml: Save the config class for tensorpack faster rcnn