onnxtr 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. onnxtr/contrib/__init__.py +1 -0
  2. onnxtr/contrib/artefacts.py +6 -8
  3. onnxtr/contrib/base.py +7 -16
  4. onnxtr/file_utils.py +1 -3
  5. onnxtr/io/elements.py +54 -60
  6. onnxtr/io/html.py +0 -2
  7. onnxtr/io/image.py +1 -4
  8. onnxtr/io/pdf.py +3 -5
  9. onnxtr/io/reader.py +4 -10
  10. onnxtr/models/_utils.py +10 -17
  11. onnxtr/models/builder.py +17 -30
  12. onnxtr/models/classification/models/mobilenet.py +7 -12
  13. onnxtr/models/classification/predictor/base.py +6 -7
  14. onnxtr/models/classification/zoo.py +25 -11
  15. onnxtr/models/detection/_utils/base.py +3 -7
  16. onnxtr/models/detection/core.py +2 -8
  17. onnxtr/models/detection/models/differentiable_binarization.py +10 -17
  18. onnxtr/models/detection/models/fast.py +10 -17
  19. onnxtr/models/detection/models/linknet.py +10 -17
  20. onnxtr/models/detection/postprocessor/base.py +3 -9
  21. onnxtr/models/detection/predictor/base.py +4 -5
  22. onnxtr/models/detection/zoo.py +20 -6
  23. onnxtr/models/engine.py +9 -9
  24. onnxtr/models/factory/hub.py +3 -7
  25. onnxtr/models/predictor/base.py +29 -30
  26. onnxtr/models/predictor/predictor.py +4 -5
  27. onnxtr/models/preprocessor/base.py +8 -12
  28. onnxtr/models/recognition/core.py +0 -1
  29. onnxtr/models/recognition/models/crnn.py +11 -23
  30. onnxtr/models/recognition/models/master.py +9 -15
  31. onnxtr/models/recognition/models/parseq.py +8 -12
  32. onnxtr/models/recognition/models/sar.py +8 -12
  33. onnxtr/models/recognition/models/vitstr.py +9 -15
  34. onnxtr/models/recognition/predictor/_utils.py +6 -9
  35. onnxtr/models/recognition/predictor/base.py +3 -3
  36. onnxtr/models/recognition/utils.py +2 -7
  37. onnxtr/models/recognition/zoo.py +19 -7
  38. onnxtr/models/zoo.py +7 -9
  39. onnxtr/transforms/base.py +17 -6
  40. onnxtr/utils/common_types.py +7 -8
  41. onnxtr/utils/data.py +7 -11
  42. onnxtr/utils/fonts.py +1 -6
  43. onnxtr/utils/geometry.py +18 -49
  44. onnxtr/utils/multithreading.py +3 -5
  45. onnxtr/utils/reconstitution.py +139 -38
  46. onnxtr/utils/repr.py +1 -2
  47. onnxtr/utils/visualization.py +12 -21
  48. onnxtr/utils/vocabs.py +1 -2
  49. onnxtr/version.py +1 -1
  50. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/METADATA +71 -41
  51. onnxtr-0.6.0.dist-info/RECORD +75 -0
  52. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
  53. onnxtr-0.5.0.dist-info/RECORD +0 -75
  54. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
  55. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
  56. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
@@ -4,7 +4,7 @@
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
- from typing import Any, Dict, Optional
7
+ from typing import Any
8
8
 
9
9
  import numpy as np
10
10
  from scipy.special import softmax
@@ -16,7 +16,7 @@ from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["SAR", "sar_resnet31"]
18
18
 
19
- default_cfgs: Dict[str, Dict[str, Any]] = {
19
+ default_cfgs: dict[str, dict[str, Any]] = {
20
20
  "sar_resnet31": {
21
21
  "mean": (0.694, 0.695, 0.693),
22
22
  "std": (0.299, 0.296, 0.301),
@@ -32,7 +32,6 @@ class SAR(Engine):
32
32
  """SAR Onnx loader
33
33
 
34
34
  Args:
35
- ----
36
35
  model_path: path to onnx model file
37
36
  vocab: vocabulary used for encoding
38
37
  engine_cfg: configuration for the inference engine
@@ -44,8 +43,8 @@ class SAR(Engine):
44
43
  self,
45
44
  model_path: str,
46
45
  vocab: str,
47
- engine_cfg: Optional[EngineConfig] = None,
48
- cfg: Optional[Dict[str, Any]] = None,
46
+ engine_cfg: EngineConfig | None = None,
47
+ cfg: dict[str, Any] | None = None,
49
48
  **kwargs: Any,
50
49
  ) -> None:
51
50
  super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -59,10 +58,10 @@ class SAR(Engine):
59
58
  self,
60
59
  x: np.ndarray,
61
60
  return_model_output: bool = False,
62
- ) -> Dict[str, Any]:
61
+ ) -> dict[str, Any]:
63
62
  logits = self.run(x)
64
63
 
65
- out: Dict[str, Any] = {}
64
+ out: dict[str, Any] = {}
66
65
  if return_model_output:
67
66
  out["out_map"] = logits
68
67
 
@@ -75,7 +74,6 @@ class SARPostProcessor(RecognitionPostProcessor):
75
74
  """Post processor for SAR architectures
76
75
 
77
76
  Args:
78
- ----
79
77
  embedding: string containing the ordered sequence of supported characters
80
78
  """
81
79
 
@@ -105,7 +103,7 @@ def _sar(
105
103
  arch: str,
106
104
  model_path: str,
107
105
  load_in_8_bit: bool = False,
108
- engine_cfg: Optional[EngineConfig] = None,
106
+ engine_cfg: EngineConfig | None = None,
109
107
  **kwargs: Any,
110
108
  ) -> SAR:
111
109
  # Patch the config
@@ -124,7 +122,7 @@ def _sar(
124
122
  def sar_resnet31(
125
123
  model_path: str = default_cfgs["sar_resnet31"]["url"],
126
124
  load_in_8_bit: bool = False,
127
- engine_cfg: Optional[EngineConfig] = None,
125
+ engine_cfg: EngineConfig | None = None,
128
126
  **kwargs: Any,
129
127
  ) -> SAR:
130
128
  """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong
@@ -137,14 +135,12 @@ def sar_resnet31(
137
135
  >>> out = model(input_tensor)
138
136
 
139
137
  Args:
140
- ----
141
138
  model_path: path to onnx model file, defaults to url in default_cfgs
142
139
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
143
140
  engine_cfg: configuration for the inference engine
144
141
  **kwargs: keyword arguments of the SAR architecture
145
142
 
146
143
  Returns:
147
- -------
148
144
  text recognition architecture
149
145
  """
150
146
  return _sar("sar_resnet31", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -4,7 +4,7 @@
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from copy import deepcopy
7
- from typing import Any, Dict, Optional
7
+ from typing import Any
8
8
 
9
9
  import numpy as np
10
10
  from scipy.special import softmax
@@ -16,7 +16,7 @@ from ..core import RecognitionPostProcessor
16
16
 
17
17
  __all__ = ["ViTSTR", "vitstr_small", "vitstr_base"]
18
18
 
19
- default_cfgs: Dict[str, Dict[str, Any]] = {
19
+ default_cfgs: dict[str, dict[str, Any]] = {
20
20
  "vitstr_small": {
21
21
  "mean": (0.694, 0.695, 0.693),
22
22
  "std": (0.299, 0.296, 0.301),
@@ -40,7 +40,6 @@ class ViTSTR(Engine):
40
40
  """ViTSTR Onnx loader
41
41
 
42
42
  Args:
43
- ----
44
43
  model_path: path to onnx model file
45
44
  vocab: vocabulary used for encoding
46
45
  engine_cfg: configuration for the inference engine
@@ -52,8 +51,8 @@ class ViTSTR(Engine):
52
51
  self,
53
52
  model_path: str,
54
53
  vocab: str,
55
- engine_cfg: Optional[EngineConfig] = None,
56
- cfg: Optional[Dict[str, Any]] = None,
54
+ engine_cfg: EngineConfig | None = None,
55
+ cfg: dict[str, Any] | None = None,
57
56
  **kwargs: Any,
58
57
  ) -> None:
59
58
  super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -67,10 +66,10 @@ class ViTSTR(Engine):
67
66
  self,
68
67
  x: np.ndarray,
69
68
  return_model_output: bool = False,
70
- ) -> Dict[str, Any]:
69
+ ) -> dict[str, Any]:
71
70
  logits = self.run(x)
72
71
 
73
- out: Dict[str, Any] = {}
72
+ out: dict[str, Any] = {}
74
73
  if return_model_output:
75
74
  out["out_map"] = logits
76
75
 
@@ -83,7 +82,6 @@ class ViTSTRPostProcessor(RecognitionPostProcessor):
83
82
  """Post processor for ViTSTR architecture
84
83
 
85
84
  Args:
86
- ----
87
85
  vocab: string containing the ordered sequence of supported characters
88
86
  """
89
87
 
@@ -115,7 +113,7 @@ def _vitstr(
115
113
  arch: str,
116
114
  model_path: str,
117
115
  load_in_8_bit: bool = False,
118
- engine_cfg: Optional[EngineConfig] = None,
116
+ engine_cfg: EngineConfig | None = None,
119
117
  **kwargs: Any,
120
118
  ) -> ViTSTR:
121
119
  # Patch the config
@@ -134,7 +132,7 @@ def _vitstr(
134
132
  def vitstr_small(
135
133
  model_path: str = default_cfgs["vitstr_small"]["url"],
136
134
  load_in_8_bit: bool = False,
137
- engine_cfg: Optional[EngineConfig] = None,
135
+ engine_cfg: EngineConfig | None = None,
138
136
  **kwargs: Any,
139
137
  ) -> ViTSTR:
140
138
  """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
@@ -147,14 +145,12 @@ def vitstr_small(
147
145
  >>> out = model(input_tensor)
148
146
 
149
147
  Args:
150
- ----
151
148
  model_path: path to onnx model file, defaults to url in default_cfgs
152
149
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
153
150
  engine_cfg: configuration for the inference engine
154
151
  **kwargs: keyword arguments of the ViTSTR architecture
155
152
 
156
153
  Returns:
157
- -------
158
154
  text recognition architecture
159
155
  """
160
156
  return _vitstr("vitstr_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -163,7 +159,7 @@ def vitstr_small(
163
159
  def vitstr_base(
164
160
  model_path: str = default_cfgs["vitstr_base"]["url"],
165
161
  load_in_8_bit: bool = False,
166
- engine_cfg: Optional[EngineConfig] = None,
162
+ engine_cfg: EngineConfig | None = None,
167
163
  **kwargs: Any,
168
164
  ) -> ViTSTR:
169
165
  """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition"
@@ -176,14 +172,12 @@ def vitstr_base(
176
172
  >>> out = model(input_tensor)
177
173
 
178
174
  Args:
179
- ----
180
175
  model_path: path to onnx model file, defaults to url in default_cfgs
181
176
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
182
177
  engine_cfg: configuration for the inference engine
183
178
  **kwargs: keyword arguments of the ViTSTR architecture
184
179
 
185
180
  Returns:
186
- -------
187
181
  text recognition architecture
188
182
  """
189
183
  return _vitstr("vitstr_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -3,7 +3,6 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List, Tuple, Union
7
6
 
8
7
  import numpy as np
9
8
 
@@ -13,16 +12,15 @@ __all__ = ["split_crops", "remap_preds"]
13
12
 
14
13
 
15
14
  def split_crops(
16
- crops: List[np.ndarray],
15
+ crops: list[np.ndarray],
17
16
  max_ratio: float,
18
17
  target_ratio: int,
19
18
  dilation: float,
20
19
  channels_last: bool = True,
21
- ) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]:
20
+ ) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
22
21
  """Chunk crops horizontally to match a given aspect ratio
23
22
 
24
23
  Args:
25
- ----
26
24
  crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
27
25
  max_ratio: the maximum aspect ratio that won't trigger the chunk
28
26
  target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
@@ -30,12 +28,11 @@ def split_crops(
30
28
  channels_last: whether the numpy array has dimensions in channels last order
31
29
 
32
30
  Returns:
33
- -------
34
31
  a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
35
32
  """
36
33
  _remap_required = False
37
- crop_map: List[Union[int, Tuple[int, int]]] = []
38
- new_crops: List[np.ndarray] = []
34
+ crop_map: list[int | tuple[int, int]] = []
35
+ new_crops: list[np.ndarray] = []
39
36
  for crop in crops:
40
37
  h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
41
38
  aspect_ratio = w / h
@@ -71,8 +68,8 @@ def split_crops(
71
68
 
72
69
 
73
70
  def remap_preds(
74
- preds: List[Tuple[str, float]], crop_map: List[Union[int, Tuple[int, int]]], dilation: float
75
- ) -> List[Tuple[str, float]]:
71
+ preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
72
+ ) -> list[tuple[str, float]]:
76
73
  remapped_out = []
77
74
  for _idx in crop_map:
78
75
  # Crop hasn't been split
@@ -3,7 +3,8 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Sequence, Tuple
6
+ from collections.abc import Sequence
7
+ from typing import Any
7
8
 
8
9
  import numpy as np
9
10
 
@@ -19,7 +20,6 @@ class RecognitionPredictor(NestedObject):
19
20
  """Implements an object able to identify character sequences in images
20
21
 
21
22
  Args:
22
- ----
23
23
  pre_processor: transform inputs for easier batched model inference
24
24
  model: core recognition architecture
25
25
  split_wide_crops: wether to use crop splitting for high aspect ratio crops
@@ -43,7 +43,7 @@ class RecognitionPredictor(NestedObject):
43
43
  self,
44
44
  crops: Sequence[np.ndarray],
45
45
  **kwargs: Any,
46
- ) -> List[Tuple[str, float]]:
46
+ ) -> list[tuple[str, float]]:
47
47
  if len(crops) == 0:
48
48
  return []
49
49
  # Dimension check
@@ -3,7 +3,6 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import List
7
6
 
8
7
  from rapidfuzz.distance import Levenshtein
9
8
 
@@ -14,14 +13,12 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
14
13
  """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters.
15
14
 
16
15
  Args:
17
- ----
18
16
  a: first char seq, suffix should be similar to b's prefix.
19
17
  b: second char seq, prefix should be similar to a's suffix.
20
18
  dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
21
19
  only used when the mother sequence is splitted on a character repetition
22
20
 
23
21
  Returns:
24
- -------
25
22
  A merged character sequence.
26
23
 
27
24
  Example::
@@ -61,17 +58,15 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str:
61
58
  return a[:-1] + b[index - 1 :]
62
59
 
63
60
 
64
- def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
61
+ def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str:
65
62
  """Recursively merges consecutive string sequences with overlapping characters.
66
63
 
67
64
  Args:
68
- ----
69
65
  seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
70
66
  dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
71
67
  only used when the mother sequence is splitted on a character repetition
72
68
 
73
69
  Returns:
74
- -------
75
70
  A merged character sequence
76
71
 
77
72
  Example::
@@ -80,7 +75,7 @@ def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
80
75
  'abcdefghijkl'
81
76
  """
82
77
 
83
- def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
78
+ def _recursive_merge(a: str, seq_list: list[str], dil_factor: float) -> str:
84
79
  # Recursive version of compute_overlap
85
80
  if len(seq_list) == 1:
86
81
  return merge_strings(a, seq_list[0], dil_factor)
@@ -3,7 +3,7 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, List, Optional
6
+ from typing import Any
7
7
 
8
8
  from .. import recognition
9
9
  from ..engine import EngineConfig
@@ -13,7 +13,7 @@ from .predictor import RecognitionPredictor
13
13
  __all__ = ["recognition_predictor"]
14
14
 
15
15
 
16
- ARCHS: List[str] = [
16
+ ARCHS: list[str] = [
17
17
  "crnn_vgg16_bn",
18
18
  "crnn_mobilenet_v3_small",
19
19
  "crnn_mobilenet_v3_large",
@@ -26,7 +26,7 @@ ARCHS: List[str] = [
26
26
 
27
27
 
28
28
  def _predictor(
29
- arch: Any, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any
29
+ arch: Any, load_in_8_bit: bool = False, engine_cfg: EngineConfig | None = None, **kwargs: Any
30
30
  ) -> RecognitionPredictor:
31
31
  if isinstance(arch, str):
32
32
  if arch not in ARCHS:
@@ -50,7 +50,12 @@ def _predictor(
50
50
 
51
51
 
52
52
  def recognition_predictor(
53
- arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any
53
+ arch: Any = "crnn_vgg16_bn",
54
+ symmetric_pad: bool = False,
55
+ batch_size: int = 128,
56
+ load_in_8_bit: bool = False,
57
+ engine_cfg: EngineConfig | None = None,
58
+ **kwargs: Any,
54
59
  ) -> RecognitionPredictor:
55
60
  """Text recognition architecture.
56
61
 
@@ -62,14 +67,21 @@ def recognition_predictor(
62
67
  >>> out = model([input_page])
63
68
 
64
69
  Args:
65
- ----
66
70
  arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
71
+ symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
72
+ batch_size: number of samples the model processes in parallel
67
73
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
68
74
  engine_cfg: configuration of inference engine
69
75
  **kwargs: optional parameters to be passed to the architecture
70
76
 
71
77
  Returns:
72
- -------
73
78
  Recognition predictor
74
79
  """
75
- return _predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
80
+ return _predictor(
81
+ arch=arch,
82
+ symmetric_pad=symmetric_pad,
83
+ batch_size=batch_size,
84
+ load_in_8_bit=load_in_8_bit,
85
+ engine_cfg=engine_cfg,
86
+ **kwargs,
87
+ )
onnxtr/models/zoo.py CHANGED
@@ -3,7 +3,7 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Any, Optional
6
+ from typing import Any
7
7
 
8
8
  from .detection.zoo import detection_predictor
9
9
  from .engine import EngineConfig
@@ -25,9 +25,9 @@ def _predictor(
25
25
  straighten_pages: bool = False,
26
26
  detect_language: bool = False,
27
27
  load_in_8_bit: bool = False,
28
- det_engine_cfg: Optional[EngineConfig] = None,
29
- reco_engine_cfg: Optional[EngineConfig] = None,
30
- clf_engine_cfg: Optional[EngineConfig] = None,
28
+ det_engine_cfg: EngineConfig | None = None,
29
+ reco_engine_cfg: EngineConfig | None = None,
30
+ clf_engine_cfg: EngineConfig | None = None,
31
31
  **kwargs,
32
32
  ) -> OCRPredictor:
33
33
  # Detection
@@ -74,9 +74,9 @@ def ocr_predictor(
74
74
  straighten_pages: bool = False,
75
75
  detect_language: bool = False,
76
76
  load_in_8_bit: bool = False,
77
- det_engine_cfg: Optional[EngineConfig] = None,
78
- reco_engine_cfg: Optional[EngineConfig] = None,
79
- clf_engine_cfg: Optional[EngineConfig] = None,
77
+ det_engine_cfg: EngineConfig | None = None,
78
+ reco_engine_cfg: EngineConfig | None = None,
79
+ clf_engine_cfg: EngineConfig | None = None,
80
80
  **kwargs: Any,
81
81
  ) -> OCRPredictor:
82
82
  """End-to-end OCR architecture using one model for localization, and another for text recognition.
@@ -88,7 +88,6 @@ def ocr_predictor(
88
88
  >>> out = model([input_page])
89
89
 
90
90
  Args:
91
- ----
92
91
  det_arch: name of the detection architecture or the model itself to use
93
92
  (e.g. 'db_resnet50', 'db_mobilenet_v3_large')
94
93
  reco_arch: name of the recognition architecture or the model itself to use
@@ -115,7 +114,6 @@ def ocr_predictor(
115
114
  kwargs: keyword args of `OCRPredictor`
116
115
 
117
116
  Returns:
118
- -------
119
117
  OCR predictor
120
118
  """
121
119
  return _predictor(
onnxtr/transforms/base.py CHANGED
@@ -3,7 +3,6 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
- from typing import Tuple, Union
7
6
 
8
7
  import numpy as np
9
8
  from PIL import Image, ImageOps
@@ -12,11 +11,18 @@ __all__ = ["Resize", "Normalize"]
12
11
 
13
12
 
14
13
  class Resize:
15
- """Resize the input image to the given size"""
14
+ """Resize the input image to the given size
15
+
16
+ Args:
17
+ size: the target size of the image
18
+ interpolation: the interpolation method to use
19
+ preserve_aspect_ratio: whether to preserve the aspect ratio of the image
20
+ symmetric_pad: whether to symmetrically pad the image
21
+ """
16
22
 
17
23
  def __init__(
18
24
  self,
19
- size: Union[int, Tuple[int, int]],
25
+ size: int | tuple[int, int],
20
26
  interpolation=Image.Resampling.BILINEAR,
21
27
  preserve_aspect_ratio: bool = False,
22
28
  symmetric_pad: bool = False,
@@ -72,12 +78,17 @@ class Resize:
72
78
 
73
79
 
74
80
  class Normalize:
75
- """Normalize the input image"""
81
+ """Normalize the input image
82
+
83
+ Args:
84
+ mean: mean values to subtract
85
+ std: standard deviation values to divide
86
+ """
76
87
 
77
88
  def __init__(
78
89
  self,
79
- mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406),
80
- std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225),
90
+ mean: float | tuple[float, float, float] = (0.485, 0.456, 0.406),
91
+ std: float | tuple[float, float, float] = (0.229, 0.224, 0.225),
81
92
  ) -> None:
82
93
  self.mean = mean
83
94
  self.std = std
@@ -4,15 +4,14 @@
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
6
  from pathlib import Path
7
- from typing import List, Tuple, Union
8
7
 
9
8
  __all__ = ["Point2D", "BoundingBox", "Polygon4P", "Polygon", "Bbox"]
10
9
 
11
10
 
12
- Point2D = Tuple[float, float]
13
- BoundingBox = Tuple[Point2D, Point2D]
14
- Polygon4P = Tuple[Point2D, Point2D, Point2D, Point2D]
15
- Polygon = List[Point2D]
16
- AbstractPath = Union[str, Path]
17
- AbstractFile = Union[AbstractPath, bytes]
18
- Bbox = Tuple[float, float, float, float]
11
+ Point2D = tuple[float, float]
12
+ BoundingBox = tuple[Point2D, Point2D]
13
+ Polygon4P = tuple[Point2D, Point2D, Point2D, Point2D]
14
+ Polygon = list[Point2D]
15
+ AbstractPath = str | Path
16
+ AbstractFile = AbstractPath | bytes
17
+ Bbox = tuple[float, float, float, float]
onnxtr/utils/data.py CHANGED
@@ -13,7 +13,6 @@ import urllib
13
13
  import urllib.error
14
14
  import urllib.request
15
15
  from pathlib import Path
16
- from typing import Optional, Union
17
16
 
18
17
  from tqdm.auto import tqdm
19
18
 
@@ -25,7 +24,7 @@ HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
25
24
  USER_AGENT = "felixdittrich92/OnnxTR"
26
25
 
27
26
 
28
- def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -> None:
27
+ def _urlretrieve(url: str, filename: Path | str, chunk_size: int = 1024) -> None:
29
28
  with open(filename, "wb") as fh:
30
29
  with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
31
30
  with tqdm(total=response.length) as pbar:
@@ -36,7 +35,7 @@ def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -
36
35
  fh.write(chunk)
37
36
 
38
37
 
39
- def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool:
38
+ def _check_integrity(file_path: str | Path, hash_prefix: str) -> bool:
40
39
  with open(file_path, "rb") as f:
41
40
  sha_hash = hashlib.sha256(f.read()).hexdigest()
42
41
 
@@ -45,10 +44,10 @@ def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool:
45
44
 
46
45
  def download_from_url(
47
46
  url: str,
48
- file_name: Optional[str] = None,
49
- hash_prefix: Optional[str] = None,
50
- cache_dir: Optional[str] = None,
51
- cache_subdir: Optional[str] = None,
47
+ file_name: str | None = None,
48
+ hash_prefix: str | None = None,
49
+ cache_dir: str | None = None,
50
+ cache_subdir: str | None = None,
52
51
  ) -> Path:
53
52
  """Download a file using its URL
54
53
 
@@ -56,7 +55,6 @@ def download_from_url(
56
55
  >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip")
57
56
 
58
57
  Args:
59
- ----
60
58
  url: the URL of the file to download
61
59
  file_name: optional name of the file once downloaded
62
60
  hash_prefix: optional expected SHA256 hash of the file
@@ -64,11 +62,9 @@ def download_from_url(
64
62
  cache_subdir: subfolder to use in the cache
65
63
 
66
64
  Returns:
67
- -------
68
65
  the location of the downloaded file
69
66
 
70
67
  Note:
71
- ----
72
68
  You can change cache directory location by using `ONNXTR_CACHE_DIR` environment variable.
73
69
  """
74
70
  if not isinstance(file_name, str):
@@ -112,7 +108,7 @@ def download_from_url(
112
108
  except (urllib.error.URLError, IOError) as e: # pragma: no cover
113
109
  if url[:5] == "https":
114
110
  url = url.replace("https:", "http:")
115
- print("Failed download. Trying https -> http instead." f" Downloading {url} to {file_path}")
111
+ print(f"Failed download. Trying https -> http instead. Downloading {url} to {file_path}")
116
112
  _urlretrieve(url, file_path)
117
113
  else:
118
114
  raise e
onnxtr/utils/fonts.py CHANGED
@@ -5,25 +5,20 @@
5
5
 
6
6
  import logging
7
7
  import platform
8
- from typing import Optional, Union
9
8
 
10
9
  from PIL import ImageFont
11
10
 
12
11
  __all__ = ["get_font"]
13
12
 
14
13
 
15
- def get_font(
16
- font_family: Optional[str] = None, font_size: int = 13
17
- ) -> Union[ImageFont.FreeTypeFont, ImageFont.ImageFont]:
14
+ def get_font(font_family: str | None = None, font_size: int = 13) -> ImageFont.FreeTypeFont | ImageFont.ImageFont:
18
15
  """Resolves a compatible ImageFont for the system
19
16
 
20
17
  Args:
21
- ----
22
18
  font_family: the font family to use
23
19
  font_size: the size of the font upon rendering
24
20
 
25
21
  Returns:
26
- -------
27
22
  the Pillow font
28
23
  """
29
24
  # Font selection