onnxtr-0.5.0-py3-none-any.whl → onnxtr-0.6.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. onnxtr/contrib/__init__.py +1 -0
  2. onnxtr/contrib/artefacts.py +6 -8
  3. onnxtr/contrib/base.py +7 -16
  4. onnxtr/file_utils.py +1 -3
  5. onnxtr/io/elements.py +54 -60
  6. onnxtr/io/html.py +0 -2
  7. onnxtr/io/image.py +1 -4
  8. onnxtr/io/pdf.py +3 -5
  9. onnxtr/io/reader.py +4 -10
  10. onnxtr/models/_utils.py +10 -17
  11. onnxtr/models/builder.py +17 -30
  12. onnxtr/models/classification/models/mobilenet.py +7 -12
  13. onnxtr/models/classification/predictor/base.py +6 -7
  14. onnxtr/models/classification/zoo.py +25 -11
  15. onnxtr/models/detection/_utils/base.py +3 -7
  16. onnxtr/models/detection/core.py +2 -8
  17. onnxtr/models/detection/models/differentiable_binarization.py +10 -17
  18. onnxtr/models/detection/models/fast.py +10 -17
  19. onnxtr/models/detection/models/linknet.py +10 -17
  20. onnxtr/models/detection/postprocessor/base.py +3 -9
  21. onnxtr/models/detection/predictor/base.py +4 -5
  22. onnxtr/models/detection/zoo.py +20 -6
  23. onnxtr/models/engine.py +9 -9
  24. onnxtr/models/factory/hub.py +3 -7
  25. onnxtr/models/predictor/base.py +29 -30
  26. onnxtr/models/predictor/predictor.py +4 -5
  27. onnxtr/models/preprocessor/base.py +8 -12
  28. onnxtr/models/recognition/core.py +0 -1
  29. onnxtr/models/recognition/models/crnn.py +11 -23
  30. onnxtr/models/recognition/models/master.py +9 -15
  31. onnxtr/models/recognition/models/parseq.py +8 -12
  32. onnxtr/models/recognition/models/sar.py +8 -12
  33. onnxtr/models/recognition/models/vitstr.py +9 -15
  34. onnxtr/models/recognition/predictor/_utils.py +6 -9
  35. onnxtr/models/recognition/predictor/base.py +3 -3
  36. onnxtr/models/recognition/utils.py +2 -7
  37. onnxtr/models/recognition/zoo.py +19 -7
  38. onnxtr/models/zoo.py +7 -9
  39. onnxtr/transforms/base.py +17 -6
  40. onnxtr/utils/common_types.py +7 -8
  41. onnxtr/utils/data.py +7 -11
  42. onnxtr/utils/fonts.py +1 -6
  43. onnxtr/utils/geometry.py +18 -49
  44. onnxtr/utils/multithreading.py +3 -5
  45. onnxtr/utils/reconstitution.py +139 -38
  46. onnxtr/utils/repr.py +1 -2
  47. onnxtr/utils/visualization.py +12 -21
  48. onnxtr/utils/vocabs.py +1 -2
  49. onnxtr/version.py +1 -1
  50. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/METADATA +71 -41
  51. onnxtr-0.6.0.dist-info/RECORD +75 -0
  52. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/WHEEL +1 -1
  53. onnxtr-0.5.0.dist-info/RECORD +0 -75
  54. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/LICENSE +0 -0
  55. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/top_level.txt +0 -0
  56. {onnxtr-0.5.0.dist-info → onnxtr-0.6.0.dist-info}/zip-safe +0 -0
onnxtr/models/detection/models/differentiable_binarization.py CHANGED
@@ -3,7 +3,7 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, Optional
+ from typing import Any

  import numpy as np
  from scipy.special import expit
@@ -14,7 +14,7 @@ from ..postprocessor.base import GeneralDetectionPostProcessor
  __all__ = ["DBNet", "db_resnet50", "db_resnet34", "db_mobilenet_v3_large"]


- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "db_resnet50": {
  "input_shape": (3, 1024, 1024),
  "mean": (0.798, 0.785, 0.772),
@@ -43,7 +43,6 @@ class DBNet(Engine):
  """DBNet Onnx loader

  Args:
- ----
  model_path: path or url to onnx model file
  engine_cfg: configuration for the inference engine
  bin_thresh: threshold for binarization of the output feature map
@@ -56,11 +55,11 @@ class DBNet(Engine):
  def __init__(
  self,
  model_path: str,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  bin_thresh: float = 0.3,
  box_thresh: float = 0.1,
  assume_straight_pages: bool = True,
- cfg: Optional[Dict[str, Any]] = None,
+ cfg: dict[str, Any] | None = None,
  **kwargs: Any,
  ) -> None:
  super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -77,10 +76,10 @@ class DBNet(Engine):
  x: np.ndarray,
  return_model_output: bool = False,
  **kwargs: Any,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
  logits = self.run(x)

- out: Dict[str, Any] = {}
+ out: dict[str, Any] = {}

  prob_map = expit(logits)
  if return_model_output:
@@ -95,7 +94,7 @@ def _dbnet(
  arch: str,
  model_path: str,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> DBNet:
  # Patch the url
@@ -107,7 +106,7 @@ def _dbnet(
  def db_resnet34(
  model_path: str = default_cfgs["db_resnet34"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> DBNet:
  """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -120,14 +119,12 @@ def db_resnet34(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the DBNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _dbnet("db_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -136,7 +133,7 @@ def db_resnet34(
  def db_resnet50(
  model_path: str = default_cfgs["db_resnet50"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> DBNet:
  """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -149,14 +146,12 @@ def db_resnet50(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the DBNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _dbnet("db_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -165,7 +160,7 @@ def db_resnet50(
  def db_mobilenet_v3_large(
  model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> DBNet:
  """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
@@ -178,14 +173,12 @@ def db_mobilenet_v3_large(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the DBNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, engine_cfg, **kwargs)
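The dominant change in these model files is mechanical: `typing.Dict`/`Optional` annotations are replaced by built-in generics and `|` unions. A minimal sketch of the before/after style (not code from the package); built-in generics in annotations need Python 3.9+ (PEP 585) and `X | None` needs 3.10+ (PEP 604), unless `from __future__ import annotations` is in effect:

```python
from __future__ import annotations  # lets older interpreters accept the new syntax

from typing import Any


def load(cfg: dict[str, Any] | None = None) -> dict[str, Any]:
    # New 0.6.0 style: builtin generics (PEP 585) + union syntax (PEP 604).
    # Old 0.5.0 equivalent: def load(cfg: Optional[Dict[str, Any]] = None) -> Dict[str, Any]
    return cfg or {}
```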
onnxtr/models/detection/models/fast.py CHANGED
@@ -4,7 +4,7 @@
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

  import logging
- from typing import Any, Dict, Optional
+ from typing import Any

  import numpy as np
  from scipy.special import expit
@@ -15,7 +15,7 @@ from ..postprocessor.base import GeneralDetectionPostProcessor
  __all__ = ["FAST", "fast_tiny", "fast_small", "fast_base"]


- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "fast_tiny": {
  "input_shape": (3, 1024, 1024),
  "mean": (0.798, 0.785, 0.772),
@@ -41,7 +41,6 @@ class FAST(Engine):
  """FAST Onnx loader

  Args:
- ----
  model_path: path or url to onnx model file
  engine_cfg: configuration for the inference engine
  bin_thresh: threshold for binarization of the output feature map
@@ -54,11 +53,11 @@ class FAST(Engine):
  def __init__(
  self,
  model_path: str,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  bin_thresh: float = 0.1,
  box_thresh: float = 0.1,
  assume_straight_pages: bool = True,
- cfg: Optional[Dict[str, Any]] = None,
+ cfg: dict[str, Any] | None = None,
  **kwargs: Any,
  ) -> None:
  super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -75,10 +74,10 @@ class FAST(Engine):
  x: np.ndarray,
  return_model_output: bool = False,
  **kwargs: Any,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
  logits = self.run(x)

- out: Dict[str, Any] = {}
+ out: dict[str, Any] = {}

  prob_map = expit(logits)
  if return_model_output:
@@ -93,7 +92,7 @@ def _fast(
  arch: str,
  model_path: str,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> FAST:
  if load_in_8_bit:
@@ -105,7 +104,7 @@ def _fast(
  def fast_tiny(
  model_path: str = default_cfgs["fast_tiny"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> FAST:
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -118,14 +117,12 @@ def fast_tiny(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the DBNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _fast("fast_tiny", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -134,7 +131,7 @@ def fast_tiny(
  def fast_small(
  model_path: str = default_cfgs["fast_small"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> FAST:
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -147,14 +144,12 @@ def fast_small(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the DBNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _fast("fast_small", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -163,7 +158,7 @@ def fast_small(
  def fast_base(
  model_path: str = default_cfgs["fast_base"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> FAST:
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
@@ -176,14 +171,12 @@ def fast_base(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the DBNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _fast("fast_base", model_path, load_in_8_bit, engine_cfg, **kwargs)
onnxtr/models/detection/models/linknet.py CHANGED
@@ -3,7 +3,7 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, Optional
+ from typing import Any

  import numpy as np
  from scipy.special import expit
@@ -14,7 +14,7 @@ from ..postprocessor.base import GeneralDetectionPostProcessor
  __all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]


- default_cfgs: Dict[str, Dict[str, Any]] = {
+ default_cfgs: dict[str, dict[str, Any]] = {
  "linknet_resnet18": {
  "input_shape": (3, 1024, 1024),
  "mean": (0.798, 0.785, 0.772),
@@ -43,7 +43,6 @@ class LinkNet(Engine):
  """LinkNet Onnx loader

  Args:
- ----
  model_path: path or url to onnx model file
  engine_cfg: configuration for the inference engine
  bin_thresh: threshold for binarization of the output feature map
@@ -56,11 +55,11 @@ class LinkNet(Engine):
  def __init__(
  self,
  model_path: str,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  bin_thresh: float = 0.1,
  box_thresh: float = 0.1,
  assume_straight_pages: bool = True,
- cfg: Optional[Dict[str, Any]] = None,
+ cfg: dict[str, Any] | None = None,
  **kwargs: Any,
  ) -> None:
  super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
@@ -77,10 +76,10 @@ class LinkNet(Engine):
  x: np.ndarray,
  return_model_output: bool = False,
  **kwargs: Any,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
  logits = self.run(x)

- out: Dict[str, Any] = {}
+ out: dict[str, Any] = {}

  prob_map = expit(logits)
  if return_model_output:
@@ -95,7 +94,7 @@ def _linknet(
  arch: str,
  model_path: str,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> LinkNet:
  # Patch the url
@@ -107,7 +106,7 @@ def _linknet(
  def linknet_resnet18(
  model_path: str = default_cfgs["linknet_resnet18"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> LinkNet:
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -120,14 +119,12 @@ def linknet_resnet18(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the LinkNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _linknet("linknet_resnet18", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -136,7 +133,7 @@ def linknet_resnet18(
  def linknet_resnet34(
  model_path: str = default_cfgs["linknet_resnet34"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> LinkNet:
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -149,14 +146,12 @@ def linknet_resnet34(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the LinkNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _linknet("linknet_resnet34", model_path, load_in_8_bit, engine_cfg, **kwargs)
@@ -165,7 +160,7 @@ def linknet_resnet34(
  def linknet_resnet50(
  model_path: str = default_cfgs["linknet_resnet50"]["url"],
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> LinkNet:
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
@@ -178,14 +173,12 @@ def linknet_resnet50(
  >>> out = model(input_tensor)

  Args:
- ----
  model_path: path to onnx model file, defaults to url in default_cfgs
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: keyword arguments of the LinkNet architecture

  Returns:
- -------
  text detection architecture
  """
  return _linknet("linknet_resnet50", model_path, load_in_8_bit, engine_cfg, **kwargs)
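For context, a hedged usage sketch of the loaders changed above. The import path and the dummy input are assumptions based on the file list and the `(3, 1024, 1024)` `input_shape` in `default_cfgs`, not code taken from the diff:

```python
import numpy as np

from onnxtr.models.detection import linknet_resnet18  # import path assumed

# Loads the default ONNX weights from the URL in default_cfgs (needs network access).
model = linknet_resnet18()  # engine_cfg defaults to None -> a default EngineConfig
# Dummy batch matching the (3, 1024, 1024) input_shape from default_cfgs above.
dummy = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
out = model(dummy, return_model_output=True)  # dict[str, Any] incl. the raw prob map
```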
onnxtr/models/detection/postprocessor/base.py CHANGED
@@ -5,7 +5,6 @@

  # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization

- from typing import List, Union

  import cv2
  import numpy as np
@@ -21,7 +20,6 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
  """Implements a post processor for FAST model.

  Args:
- ----
  bin_thresh: threshold used to binzarized p_map at inference time
  box_thresh: minimal objectness score to consider a box
  assume_straight_pages: whether the inputs were expected to have horizontal text elements
@@ -43,11 +41,9 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
  """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon

  Args:
- ----
  points: The first parameter.

  Returns:
- -------
  a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
  """
  if not self.assume_straight_pages:
@@ -92,24 +88,22 @@ class GeneralDetectionPostProcessor(DetectionPostProcessor):
  """Compute boxes from a bitmap/pred_map: find connected components then filter boxes

  Args:
- ----
  pred: Pred map from differentiable linknet output
  bitmap: Bitmap map computed from pred (binarized)
  angle_tol: Comparison tolerance of the angle with the median angle across the page
  ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop

  Returns:
- -------
  np tensor boxes for the bitmap, each box is a 6-element list
- containing x, y, w, h, alpha, score for the box
+ containing x, y, w, h, alpha, score for the box
  """
  height, width = bitmap.shape[:2]
- boxes: List[Union[np.ndarray, List[float]]] = []
+ boxes: list[np.ndarray | list[float]] = []
  # get contours from connected components on the bitmap
  contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  for contour in contours:
  # Check whether smallest enclosing bounding box is not too small
- if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
+ if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
  continue
  # Compute objectness
  if self.assume_straight_pages:
onnxtr/models/detection/predictor/base.py CHANGED
@@ -3,7 +3,7 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, List, Tuple, Union
+ from typing import Any

  import numpy as np

@@ -18,12 +18,11 @@ class DetectionPredictor(NestedObject):
  """Implements an object able to localize text elements in a document

  Args:
- ----
  pre_processor: transform inputs for easier batched model inference
  model: core detection architecture
  """

- _children_names: List[str] = ["pre_processor", "model"]
+ _children_names: list[str] = ["pre_processor", "model"]

  def __init__(
  self,
@@ -35,10 +34,10 @@ class DetectionPredictor(NestedObject):

  def __call__(
  self,
- pages: List[np.ndarray],
+ pages: list[np.ndarray],
  return_maps: bool = False,
  **kwargs: Any,
- ) -> Union[List[np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]:
+ ) -> list[np.ndarray] | tuple[list[np.ndarray], list[np.ndarray]]:
  # Extract parameters from the preprocessor
  preserve_aspect_ratio = self.pre_processor.resize.preserve_aspect_ratio
  symmetric_pad = self.pre_processor.resize.symmetric_pad
onnxtr/models/detection/zoo.py CHANGED
@@ -3,7 +3,7 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Optional
+ from typing import Any

  from .. import detection
  from ..engine import EngineConfig
@@ -29,7 +29,7 @@ def _predictor(
  arch: Any,
  assume_straight_pages: bool = True,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> DetectionPredictor:
  if isinstance(arch, str):
@@ -59,8 +59,11 @@ def _predictor(
  def detection_predictor(
  arch: Any = "fast_base",
  assume_straight_pages: bool = True,
+ preserve_aspect_ratio: bool = True,
+ symmetric_pad: bool = True,
+ batch_size: int = 2,
  load_in_8_bit: bool = False,
- engine_cfg: Optional[EngineConfig] = None,
+ engine_cfg: EngineConfig | None = None,
  **kwargs: Any,
  ) -> DetectionPredictor:
  """Text detection architecture.
@@ -72,15 +75,26 @@ def detection_predictor(
  >>> out = model([input_page])

  Args:
- ----
  arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
  assume_straight_pages: If True, fit straight boxes to the page
+ preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
+ running the detection model on it
+ symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
+ batch_size: number of samples the model processes in parallel
  load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
  engine_cfg: configuration for the inference engine
  **kwargs: optional keyword arguments passed to the architecture

  Returns:
- -------
  Detection predictor
  """
- return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
+ return _predictor(
+ arch=arch,
+ assume_straight_pages=assume_straight_pages,
+ preserve_aspect_ratio=preserve_aspect_ratio,
+ symmetric_pad=symmetric_pad,
+ batch_size=batch_size,
+ load_in_8_bit=load_in_8_bit,
+ engine_cfg=engine_cfg,
+ **kwargs,
+ )
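The visible API change in this file: `detection_predictor` now exposes `preserve_aspect_ratio`, `symmetric_pad`, and `batch_size` as first-class keyword arguments and forwards everything to `_predictor` by name. A sketch of a call against the new signature (the top-level import path is an assumption):

```python
import numpy as np

from onnxtr.models import detection_predictor  # import path assumed

predictor = detection_predictor(
    arch="fast_base",            # default architecture per this diff
    assume_straight_pages=True,
    preserve_aspect_ratio=True,  # new explicit kwarg in 0.6.0
    symmetric_pad=True,          # new explicit kwarg in 0.6.0
    batch_size=2,                # new explicit kwarg in 0.6.0
)
out = predictor([np.zeros((1024, 1024, 3), dtype=np.uint8)])
```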
onnxtr/models/engine.py CHANGED
@@ -3,7 +3,8 @@
  # This program is licensed under the Apache License 2.0.
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

- from typing import Any, Dict, List, Optional, Tuple, Union
+ import logging
+ from typing import Any

  import numpy as np
  from onnxruntime import (
@@ -25,22 +26,22 @@ class EngineConfig:
  """Implements a configuration class for the engine of a model

  Args:
- ----
  providers: list of providers to use for inference ref.: https://onnxruntime.ai/docs/execution-providers/
  session_options: configuration for the inference session ref.: https://onnxruntime.ai/docs/api/python/api_summary.html#sessionoptions
  """

  def __init__(
  self,
- providers: Optional[Union[List[Tuple[str, Dict[str, Any]]], List[str]]] = None,
- session_options: Optional[SessionOptions] = None,
+ providers: list[tuple[str, dict[str, Any]]] | list[str] | None = None,
+ session_options: SessionOptions | None = None,
  ):
  self._providers = providers or self._init_providers()
  self._session_options = session_options or self._init_sess_opts()

- def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]:
+ def _init_providers(self) -> list[tuple[str, dict[str, Any]]]:
  providers: Any = [("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})]
  available_providers = get_available_providers()
+ logging.info(f"Available providers: {available_providers}")
  if "CUDAExecutionProvider" in available_providers and get_device() == "GPU":  # pragma: no cover
  providers.insert(
  0,
@@ -66,7 +67,7 @@ class EngineConfig:
  return session_options

  @property
- def providers(self) -> Union[List[Tuple[str, Dict[str, Any]]], List[str]]:
+ def providers(self) -> list[tuple[str, dict[str, Any]]] | list[str]:
  return self._providers

  @property
@@ -81,13 +82,12 @@ class Engine:
  """Implements an abstract class for the engine of a model

  Args:
- ----
  url: the url to use to download a model if needed
  engine_cfg: the configuration of the engine
  **kwargs: additional arguments to be passed to `download_from_url`
  """

- def __init__(self, url: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any) -> None:
+ def __init__(self, url: str, engine_cfg: EngineConfig | None = None, **kwargs: Any) -> None:
  engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
  archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
  # Store model path for each model
@@ -97,7 +97,7 @@ class Engine:
  self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
  self.runtime_inputs = self.runtime.get_inputs()[0]
  self.tf_exported = int(self.runtime_inputs.shape[-1]) == 3
- self.fixed_batch_size: Union[int, str] = self.runtime_inputs.shape[
+ self.fixed_batch_size: int | str = self.runtime_inputs.shape[
  0
  ]  # mostly possible with tensorflow exported models
  self.output_name = [output.name for output in self.runtime.get_outputs()]
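A hedged sketch of building an `EngineConfig` with the provider and session types shown in this diff; the `onnxtr.models` import location is an assumption, and omitting both arguments falls back to `_init_providers()`, which now logs the available providers:

```python
from onnxruntime import SessionOptions

from onnxtr.models import EngineConfig  # import path assumed

cfg = EngineConfig(
    # Same provider tuple shape as _init_providers() above: (name, options dict).
    providers=[("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})],
    session_options=SessionOptions(),
)
```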
onnxtr/models/factory/hub.py CHANGED
@@ -12,7 +12,7 @@ import shutil
  import subprocess
  import textwrap
  from pathlib import Path
- from typing import Any, Optional
+ from typing import Any

  from huggingface_hub import (
  HfApi,
@@ -59,7 +59,6 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
  """Save model and config to disk for pushing to huggingface hub

  Args:
- ----
  model: Onnx model to be saved
  save_dir: directory to save model and config
  arch: architecture name
@@ -91,7 +90,6 @@ def push_to_hf_hub(
  >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

  Args:
- ----
  model: Onnx model to be saved
  model_name: name of the model which is also the repository name
  task: task name
@@ -179,20 +177,18 @@ def push_to_hf_hub(
  repo.git_push()


- def from_hub(repo_id: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any):
+ def from_hub(repo_id: str, engine_cfg: EngineConfig | None = None, **kwargs: Any):
  """Instantiate & load a pretrained model from HF hub.

  >>> from onnxtr.models import from_hub
  >>> model = from_hub("onnxtr/my-model")

  Args:
- ----
  repo_id: HuggingFace model hub repo
  engine_cfg: configuration for the inference engine (optional)
- kwargs: kwargs of `hf_hub_download`
+ **kwargs: kwargs of `hf_hub_download`

  Returns:
- -------
  Model loaded with the checkpoint
  """
  # Get the config
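Combining the docstring example above with the new `engine_cfg` parameter gives a sketch like the following; `"onnxtr/my-model"` is the docstring's placeholder repo id, not a real repository:

```python
from onnxtr.models import EngineConfig, from_hub  # EngineConfig location assumed

# Sketch only: from_hub with the 0.6.0 signature (engine_cfg: EngineConfig | None).
model = from_hub("onnxtr/my-model", engine_cfg=EngineConfig())
```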