onnxtr 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. {onnxtr-0.1.2 → onnxtr-0.2.0}/PKG-INFO +37 -11
  2. {onnxtr-0.1.2 → onnxtr-0.2.0}/README.md +36 -10
  3. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/classification/models/mobilenet.py +15 -4
  4. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/classification/predictor/base.py +1 -0
  5. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/classification/zoo.py +10 -7
  6. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/models/differentiable_binarization.py +21 -6
  7. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/models/fast.py +13 -6
  8. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/models/linknet.py +21 -6
  9. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/zoo.py +7 -3
  10. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/engine.py +2 -2
  11. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/predictor/base.py +5 -1
  12. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/models/crnn.py +21 -6
  13. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/models/master.py +7 -2
  14. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/models/parseq.py +8 -2
  15. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/models/sar.py +9 -2
  16. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/models/vitstr.py +17 -6
  17. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/zoo.py +7 -4
  18. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/zoo.py +6 -0
  19. onnxtr-0.2.0/onnxtr/version.py +1 -0
  20. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr.egg-info/PKG-INFO +37 -11
  21. {onnxtr-0.1.2 → onnxtr-0.2.0}/pyproject.toml +2 -1
  22. {onnxtr-0.1.2 → onnxtr-0.2.0}/setup.py +1 -1
  23. onnxtr-0.1.2/onnxtr/version.py +0 -1
  24. {onnxtr-0.1.2 → onnxtr-0.2.0}/LICENSE +0 -0
  25. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/__init__.py +0 -0
  26. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/contrib/__init__.py +0 -0
  27. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/contrib/artefacts.py +0 -0
  28. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/contrib/base.py +0 -0
  29. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/file_utils.py +0 -0
  30. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/io/__init__.py +0 -0
  31. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/io/elements.py +0 -0
  32. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/io/html.py +0 -0
  33. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/io/image.py +0 -0
  34. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/io/pdf.py +0 -0
  35. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/io/reader.py +0 -0
  36. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/__init__.py +0 -0
  37. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/_utils.py +0 -0
  38. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/builder.py +0 -0
  39. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/classification/__init__.py +0 -0
  40. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/classification/models/__init__.py +0 -0
  41. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/classification/predictor/__init__.py +0 -0
  42. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/__init__.py +0 -0
  43. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/core.py +0 -0
  44. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/models/__init__.py +0 -0
  45. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/postprocessor/__init__.py +0 -0
  46. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/postprocessor/base.py +0 -0
  47. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/predictor/__init__.py +0 -0
  48. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/detection/predictor/base.py +0 -0
  49. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/predictor/__init__.py +0 -0
  50. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/predictor/predictor.py +0 -0
  51. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/preprocessor/__init__.py +0 -0
  52. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/preprocessor/base.py +0 -0
  53. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/__init__.py +0 -0
  54. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/core.py +0 -0
  55. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/models/__init__.py +0 -0
  56. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/predictor/__init__.py +0 -0
  57. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/predictor/_utils.py +0 -0
  58. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/predictor/base.py +0 -0
  59. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/models/recognition/utils.py +0 -0
  60. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/transforms/__init__.py +0 -0
  61. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/transforms/base.py +0 -0
  62. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/__init__.py +0 -0
  63. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/common_types.py +0 -0
  64. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/data.py +0 -0
  65. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/fonts.py +0 -0
  66. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/geometry.py +0 -0
  67. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/multithreading.py +0 -0
  68. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/reconstitution.py +0 -0
  69. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/repr.py +0 -0
  70. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/visualization.py +0 -0
  71. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr/utils/vocabs.py +0 -0
  72. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr.egg-info/SOURCES.txt +0 -0
  73. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr.egg-info/dependency_links.txt +0 -0
  74. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr.egg-info/requires.txt +0 -0
  75. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr.egg-info/top_level.txt +0 -0
  76. {onnxtr-0.1.2 → onnxtr-0.2.0}/onnxtr.egg-info/zip-safe +0 -0
  77. {onnxtr-0.1.2 → onnxtr-0.2.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: onnxtr
3
- Version: 0.1.2
3
+ Version: 0.2.0
4
4
  Summary: Onnx Text Recognition (OnnxTR): docTR Onnx-Wrapper for high-performance OCR on documents.
5
5
  Author-email: Felix Dittrich <felixdittrich92@gmail.com>
6
6
  Maintainer: Felix Dittrich
@@ -275,7 +275,7 @@ Requires-Dist: pre-commit>=2.17.0; extra == "dev"
275
275
  [![codecov](https://codecov.io/gh/felixdittrich92/OnnxTR/graph/badge.svg?token=WVFRCQBOLI)](https://codecov.io/gh/felixdittrich92/OnnxTR)
276
276
  [![Codacy Badge](https://app.codacy.com/project/badge/Grade/4fff4d764bb14fb8b4f4afeb9587231b)](https://app.codacy.com/gh/felixdittrich92/OnnxTR/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
277
277
  [![CodeFactor](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr/badge)](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr)
278
- [![Pypi](https://img.shields.io/badge/pypi-v0.1.1-blue.svg)](https://pypi.org/project/OnnxTR/)
278
+ [![Pypi](https://img.shields.io/badge/pypi-v0.2.0-blue.svg)](https://pypi.org/project/OnnxTR/)
279
279
 
280
280
  > :warning: Please note that this is a wrapper around the [doctr](https://github.com/mindee/doctr) library to provide a Onnx pipeline for docTR. For feature requests, which are not directly related to the Onnx pipeline, please refer to the base project.
281
281
 
@@ -284,8 +284,9 @@ Requires-Dist: pre-commit>=2.17.0; extra == "dev"
284
284
  What you can expect from this repository:
285
285
 
286
286
  - efficient ways to parse textual information (localize and identify each word) from your documents
287
- - a Onnx pipeline for docTR, a wrapper around the [doctr](https://github.com/mindee/doctr) library
287
+ - a Onnx pipeline for docTR, a wrapper around the [doctr](https://github.com/mindee/doctr) library - no PyTorch or TensorFlow dependencies
288
288
  - more lightweight package with faster inference latency and less required resources
289
+ - 8-Bit quantized models for faster inference on CPU
289
290
 
290
291
  ![OCR_example](https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/ocr.png)
291
292
 
@@ -358,6 +359,9 @@ model = ocr_predictor(
358
359
  resolve_lines=True, # whether words should be automatically grouped into lines (default: True)
359
360
  resolve_blocks=True, # whether lines should be automatically grouped into blocks (default: True)
360
361
  paragraph_break=0.035, # relative length of the minimum space separating paragraphs (default: 0.035)
362
+ # OnnxTR specific parameters
363
+ # NOTE: 8-Bit quantized models are not available for FAST detection models and can in general lead to poorer accuracy
364
+ load_in_8_bit=False, # set to `True` to load 8-bit quantized models instead of the full precision ones (default: False)
361
365
  )
362
366
  # PDF
363
367
  doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
@@ -438,9 +442,9 @@ predictor.list_archs()
438
442
  'linknet_resnet18',
439
443
  'linknet_resnet34',
440
444
  'linknet_resnet50',
441
- 'fast_tiny',
442
- 'fast_small',
443
- 'fast_base'
445
+ 'fast_tiny', # No 8-bit support
446
+ 'fast_small', # No 8-bit support
447
+ 'fast_base' # No 8-bit support
444
448
  ],
445
449
  'recognition archs':
446
450
  [
@@ -469,14 +473,36 @@ NOTE:
469
473
 
470
474
  ### Benchmarks
471
475
 
472
- The benchmarks was measured on a `i7-14700K Intel CPU`.
476
+ The CPU benchmarks were measured on an `i7-14700K Intel CPU`.
473
477
 
474
- MORE BENCHMARKS COMING SOON
478
+ The GPU benchmarks were measured on an `RTX 4080 Nvidia GPU`.
475
479
 
476
- |Dataset |docTR (CPU) - v0.8.1 |OnnxTR (CPU) - v0.1.1 |
480
+ Benchmarking was performed on the FUNSD and CORD datasets.
481
+
482
+ docTR / OnnxTR models used for the benchmarks are `fast_base` (full precision) | `db_resnet50` (8-bit variant) for detection and `crnn_vgg16_bn` for recognition.
483
+
484
+ For comparison, the smallest combination in OnnxTR (docTR), `db_mobilenet_v3_large` and `crnn_mobilenet_v3_small`, takes `~0.17s / Page` on the FUNSD dataset and `~0.12s / Page` on the CORD dataset in **full precision**.
485
+
486
+ - CPU benchmarks:
487
+
488
+ |Library |FUNSD (199 pages) |CORD (900 pages) |
489
+ |--------------------------------|-------------------------------|-------------------------------|
490
+ |docTR (CPU) - v0.8.1 | ~1.29s / Page | ~0.60s / Page |
491
+ |**OnnxTR (CPU)** - v0.1.2 | ~0.57s / Page | **~0.25s / Page** |
492
+ |**OnnxTR (CPU) 8-bit** - v0.1.2 | **~0.38s / Page** | **~0.14s / Page** |
493
+ |EasyOCR (CPU) - v1.7.1 | ~1.96s / Page | ~1.75s / Page |
494
+ |**PyTesseract (CPU)** - v0.3.10 | **~0.50s / Page** | ~0.52s / Page |
495
+ |Surya (line) (CPU) - v0.4.4 | ~48.76s / Page | ~35.49s / Page |
496
+
497
+ - GPU benchmarks:
498
+
499
+ |Library |FUNSD (199 pages) |CORD (900 pages) |
477
500
  |--------------------------------|-------------------------------|-------------------------------|
478
- |FUNSD (199 pages) | ~1.29s / Page | ~0.57s / Page |
479
- |CORD (900 pages) | ~0.60s / Page | ~0.25s / Page |
501
+ |docTR (GPU) - v0.8.1 | ~0.07s / Page | ~0.05s / Page |
502
+ |**docTR (GPU) float16** - v0.8.1| **~0.06s / Page** | **~0.03s / Page** |
503
+ |OnnxTR (GPU) - v0.1.2 | **~0.06s / Page** | ~0.04s / Page |
504
+ |EasyOCR (GPU) - v1.7.1 | ~0.31s / Page | ~0.19s / Page |
505
+ |Surya (GPU) float16 - v0.4.4 | ~3.70s / Page | ~2.81s / Page |
480
506
 
481
507
  ## Citation
482
508
 
@@ -7,7 +7,7 @@
7
7
  [![codecov](https://codecov.io/gh/felixdittrich92/OnnxTR/graph/badge.svg?token=WVFRCQBOLI)](https://codecov.io/gh/felixdittrich92/OnnxTR)
8
8
  [![Codacy Badge](https://app.codacy.com/project/badge/Grade/4fff4d764bb14fb8b4f4afeb9587231b)](https://app.codacy.com/gh/felixdittrich92/OnnxTR/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
9
9
  [![CodeFactor](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr/badge)](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr)
10
- [![Pypi](https://img.shields.io/badge/pypi-v0.1.1-blue.svg)](https://pypi.org/project/OnnxTR/)
10
+ [![Pypi](https://img.shields.io/badge/pypi-v0.2.0-blue.svg)](https://pypi.org/project/OnnxTR/)
11
11
 
12
12
  > :warning: Please note that this is a wrapper around the [doctr](https://github.com/mindee/doctr) library to provide a Onnx pipeline for docTR. For feature requests, which are not directly related to the Onnx pipeline, please refer to the base project.
13
13
 
@@ -16,8 +16,9 @@
16
16
  What you can expect from this repository:
17
17
 
18
18
  - efficient ways to parse textual information (localize and identify each word) from your documents
19
- - a Onnx pipeline for docTR, a wrapper around the [doctr](https://github.com/mindee/doctr) library
19
+ - a Onnx pipeline for docTR, a wrapper around the [doctr](https://github.com/mindee/doctr) library - no PyTorch or TensorFlow dependencies
20
20
  - more lightweight package with faster inference latency and less required resources
21
+ - 8-Bit quantized models for faster inference on CPU
21
22
 
22
23
  ![OCR_example](https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/ocr.png)
23
24
 
@@ -90,6 +91,9 @@ model = ocr_predictor(
90
91
  resolve_lines=True, # whether words should be automatically grouped into lines (default: True)
91
92
  resolve_blocks=True, # whether lines should be automatically grouped into blocks (default: True)
92
93
  paragraph_break=0.035, # relative length of the minimum space separating paragraphs (default: 0.035)
94
+ # OnnxTR specific parameters
95
+ # NOTE: 8-Bit quantized models are not available for FAST detection models and can in general lead to poorer accuracy
96
+ load_in_8_bit=False, # set to `True` to load 8-bit quantized models instead of the full precision ones (default: False)
93
97
  )
94
98
  # PDF
95
99
  doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
@@ -170,9 +174,9 @@ predictor.list_archs()
170
174
  'linknet_resnet18',
171
175
  'linknet_resnet34',
172
176
  'linknet_resnet50',
173
- 'fast_tiny',
174
- 'fast_small',
175
- 'fast_base'
177
+ 'fast_tiny', # No 8-bit support
178
+ 'fast_small', # No 8-bit support
179
+ 'fast_base' # No 8-bit support
176
180
  ],
177
181
  'recognition archs':
178
182
  [
@@ -201,14 +205,36 @@ NOTE:
201
205
 
202
206
  ### Benchmarks
203
207
 
204
- The benchmarks was measured on a `i7-14700K Intel CPU`.
208
+ The CPU benchmarks were measured on an `i7-14700K Intel CPU`.
205
209
 
206
- MORE BENCHMARKS COMING SOON
210
+ The GPU benchmarks were measured on an `RTX 4080 Nvidia GPU`.
207
211
 
208
- |Dataset |docTR (CPU) - v0.8.1 |OnnxTR (CPU) - v0.1.1 |
212
+ Benchmarking was performed on the FUNSD and CORD datasets.
213
+
214
+ docTR / OnnxTR models used for the benchmarks are `fast_base` (full precision) | `db_resnet50` (8-bit variant) for detection and `crnn_vgg16_bn` for recognition.
215
+
216
+ For comparison, the smallest combination in OnnxTR (docTR), `db_mobilenet_v3_large` and `crnn_mobilenet_v3_small`, takes `~0.17s / Page` on the FUNSD dataset and `~0.12s / Page` on the CORD dataset in **full precision**.
217
+
218
+ - CPU benchmarks:
219
+
220
+ |Library |FUNSD (199 pages) |CORD (900 pages) |
221
+ |--------------------------------|-------------------------------|-------------------------------|
222
+ |docTR (CPU) - v0.8.1 | ~1.29s / Page | ~0.60s / Page |
223
+ |**OnnxTR (CPU)** - v0.1.2 | ~0.57s / Page | **~0.25s / Page** |
224
+ |**OnnxTR (CPU) 8-bit** - v0.1.2 | **~0.38s / Page** | **~0.14s / Page** |
225
+ |EasyOCR (CPU) - v1.7.1 | ~1.96s / Page | ~1.75s / Page |
226
+ |**PyTesseract (CPU)** - v0.3.10 | **~0.50s / Page** | ~0.52s / Page |
227
+ |Surya (line) (CPU) - v0.4.4 | ~48.76s / Page | ~35.49s / Page |
228
+
229
+ - GPU benchmarks:
230
+
231
+ |Library |FUNSD (199 pages) |CORD (900 pages) |
209
232
  |--------------------------------|-------------------------------|-------------------------------|
210
- |FUNSD (199 pages) | ~1.29s / Page | ~0.57s / Page |
211
- |CORD (900 pages) | ~0.60s / Page | ~0.25s / Page |
233
+ |docTR (GPU) - v0.8.1 | ~0.07s / Page | ~0.05s / Page |
234
+ |**docTR (GPU) float16** - v0.8.1| **~0.06s / Page** | **~0.03s / Page** |
235
+ |OnnxTR (GPU) - v0.1.2 | **~0.06s / Page** | ~0.04s / Page |
236
+ |EasyOCR (GPU) - v1.7.1 | ~0.31s / Page | ~0.19s / Page |
237
+ |Surya (GPU) float16 - v0.4.4 | ~3.70s / Page | ~2.81s / Page |
212
238
 
213
239
  ## Citation
214
240
 
@@ -24,6 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
24
24
  "input_shape": (3, 256, 256),
25
25
  "classes": [0, -90, 180, 90],
26
26
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_crop_orientation-5620cf7e.onnx",
27
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/mobilenet_v3_small_crop_orientation_static_8_bit-4cfaa621.onnx",
27
28
  },
28
29
  "mobilenet_v3_small_page_orientation": {
29
30
  "mean": (0.694, 0.695, 0.693),
@@ -31,6 +32,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
31
32
  "input_shape": (3, 512, 512),
32
33
  "classes": [0, -90, 180, 90],
33
34
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_page_orientation-d3f76d79.onnx",
35
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/mobilenet_v3_small_page_orientation_static_8_bit-3e5ef3dc.onnx",
34
36
  },
35
37
  }
36
38
 
@@ -64,14 +66,19 @@ class MobileNetV3(Engine):
64
66
  def _mobilenet_v3(
65
67
  arch: str,
66
68
  model_path: str,
69
+ load_in_8_bit: bool = False,
67
70
  **kwargs: Any,
68
71
  ) -> MobileNetV3:
72
+ # Patch the url
73
+ model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
69
74
  _cfg = deepcopy(default_cfgs[arch])
70
75
  return MobileNetV3(model_path, cfg=_cfg, **kwargs)
71
76
 
72
77
 
73
78
  def mobilenet_v3_small_crop_orientation(
74
- model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"], **kwargs: Any
79
+ model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
80
+ load_in_8_bit: bool = False,
81
+ **kwargs: Any,
75
82
  ) -> MobileNetV3:
76
83
  """MobileNetV3-Small architecture as described in
77
84
  `"Searching for MobileNetV3",
@@ -86,17 +93,20 @@ def mobilenet_v3_small_crop_orientation(
86
93
  Args:
87
94
  ----
88
95
  model_path: path to onnx model file, defaults to url in default_cfgs
96
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
89
97
  **kwargs: keyword arguments of the MobileNetV3 architecture
90
98
 
91
99
  Returns:
92
100
  -------
93
101
  MobileNetV3
94
102
  """
95
- return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, **kwargs)
103
+ return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, **kwargs)
96
104
 
97
105
 
98
106
  def mobilenet_v3_small_page_orientation(
99
- model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"], **kwargs: Any
107
+ model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
108
+ load_in_8_bit: bool = False,
109
+ **kwargs: Any,
100
110
  ) -> MobileNetV3:
101
111
  """MobileNetV3-Small architecture as described in
102
112
  `"Searching for MobileNetV3",
@@ -111,10 +121,11 @@ def mobilenet_v3_small_page_orientation(
111
121
  Args:
112
122
  ----
113
123
  model_path: path to onnx model file, defaults to url in default_cfgs
124
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
114
125
  **kwargs: keyword arguments of the MobileNetV3 architecture
115
126
 
116
127
  Returns:
117
128
  -------
118
129
  MobileNetV3
119
130
  """
120
- return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, **kwargs)
131
+ return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, **kwargs)
@@ -22,6 +22,7 @@ class OrientationPredictor(NestedObject):
22
22
  ----
23
23
  pre_processor: transform inputs for easier batched model inference
24
24
  model: core classification architecture (backbone + classification head)
25
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
25
26
  """
26
27
 
27
28
  _children_names: List[str] = ["pre_processor", "model"]
@@ -14,24 +14,25 @@ __all__ = ["crop_orientation_predictor", "page_orientation_predictor"]
14
14
  ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]
15
15
 
16
16
 
17
- def _orientation_predictor(arch: str, **kwargs: Any) -> OrientationPredictor:
17
+ def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any) -> OrientationPredictor:
18
18
  if arch not in ORIENTATION_ARCHS:
19
19
  raise ValueError(f"unknown architecture '{arch}'")
20
20
 
21
21
  # Load directly classifier from backbone
22
- _model = classification.__dict__[arch]()
22
+ _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit)
23
23
  kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
24
24
  kwargs["std"] = kwargs.get("std", _model.cfg["std"])
25
25
  kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
26
26
  input_shape = _model.cfg["input_shape"][1:]
27
27
  predictor = OrientationPredictor(
28
- PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
28
+ PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs),
29
+ _model,
29
30
  )
30
31
  return predictor
31
32
 
32
33
 
33
34
  def crop_orientation_predictor(
34
- arch: Any = "mobilenet_v3_small_crop_orientation", **kwargs: Any
35
+ arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, **kwargs: Any
35
36
  ) -> OrientationPredictor:
36
37
  """Crop orientation classification architecture.
37
38
 
@@ -44,17 +45,18 @@ def crop_orientation_predictor(
44
45
  Args:
45
46
  ----
46
47
  arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
48
+ load_in_8_bit: load the 8-bit quantized version of the model
47
49
  **kwargs: keyword arguments to be passed to the OrientationPredictor
48
50
 
49
51
  Returns:
50
52
  -------
51
53
  OrientationPredictor
52
54
  """
53
- return _orientation_predictor(arch, **kwargs)
55
+ return _orientation_predictor(arch, load_in_8_bit, **kwargs)
54
56
 
55
57
 
56
58
  def page_orientation_predictor(
57
- arch: Any = "mobilenet_v3_small_page_orientation", **kwargs: Any
59
+ arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, **kwargs: Any
58
60
  ) -> OrientationPredictor:
59
61
  """Page orientation classification architecture.
60
62
 
@@ -67,10 +69,11 @@ def page_orientation_predictor(
67
69
  Args:
68
70
  ----
69
71
  arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
72
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
70
73
  **kwargs: keyword arguments to be passed to the OrientationPredictor
71
74
 
72
75
  Returns:
73
76
  -------
74
77
  OrientationPredictor
75
78
  """
76
- return _orientation_predictor(arch, **kwargs)
79
+ return _orientation_predictor(arch, load_in_8_bit, **kwargs)
@@ -20,18 +20,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
20
20
  "mean": (0.798, 0.785, 0.772),
21
21
  "std": (0.264, 0.2749, 0.287),
22
22
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet50-69ba0015.onnx",
23
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet50_static_8_bit-09a6104f.onnx",
23
24
  },
24
25
  "db_resnet34": {
25
26
  "input_shape": (3, 1024, 1024),
26
27
  "mean": (0.798, 0.785, 0.772),
27
28
  "std": (0.264, 0.2749, 0.287),
28
29
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet34-b4873198.onnx",
30
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet34_static_8_bit-027e2c7f.onnx",
29
31
  },
30
32
  "db_mobilenet_v3_large": {
31
33
  "input_shape": (3, 1024, 1024),
32
34
  "mean": (0.798, 0.785, 0.772),
33
35
  "std": (0.264, 0.2749, 0.287),
34
36
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
37
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_mobilenet_v3_large_static_8_bit-51659bb9.onnx",
35
38
  },
36
39
  }
37
40
 
@@ -87,13 +90,18 @@ class DBNet(Engine):
87
90
  def _dbnet(
88
91
  arch: str,
89
92
  model_path: str,
93
+ load_in_8_bit: bool = False,
90
94
  **kwargs: Any,
91
95
  ) -> DBNet:
96
+ # Patch the url
97
+ model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
92
98
  # Build the model
93
99
  return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)
94
100
 
95
101
 
96
- def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs: Any) -> DBNet:
102
+ def db_resnet34(
103
+ model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
104
+ ) -> DBNet:
97
105
  """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
98
106
  <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
99
107
 
@@ -106,16 +114,19 @@ def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs:
106
114
  Args:
107
115
  ----
108
116
  model_path: path to onnx model file, defaults to url in default_cfgs
117
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
109
118
  **kwargs: keyword arguments of the DBNet architecture
110
119
 
111
120
  Returns:
112
121
  -------
113
122
  text detection architecture
114
123
  """
115
- return _dbnet("db_resnet34", model_path, **kwargs)
124
+ return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs)
116
125
 
117
126
 
118
- def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs: Any) -> DBNet:
127
+ def db_resnet50(
128
+ model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
129
+ ) -> DBNet:
119
130
  """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
120
131
  <https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
121
132
 
@@ -128,16 +139,19 @@ def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs:
128
139
  Args:
129
140
  ----
130
141
  model_path: path to onnx model file, defaults to url in default_cfgs
142
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
131
143
  **kwargs: keyword arguments of the DBNet architecture
132
144
 
133
145
  Returns:
134
146
  -------
135
147
  text detection architecture
136
148
  """
137
- return _dbnet("db_resnet50", model_path, **kwargs)
149
+ return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs)
138
150
 
139
151
 
140
- def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], **kwargs: Any) -> DBNet:
152
+ def db_mobilenet_v3_large(
153
+ model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
154
+ ) -> DBNet:
141
155
  """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
142
156
  <https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
143
157
 
@@ -150,10 +164,11 @@ def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"
150
164
  Args:
151
165
  ----
152
166
  model_path: path to onnx model file, defaults to url in default_cfgs
167
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
153
168
  **kwargs: keyword arguments of the DBNet architecture
154
169
 
155
170
  Returns:
156
171
  -------
157
172
  text detection architecture
158
173
  """
159
- return _dbnet("db_mobilenet_v3_large", model_path, **kwargs)
174
+ return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
@@ -3,6 +3,7 @@
3
3
  # This program is licensed under the Apache License 2.0.
4
4
  # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
5
 
6
+ import logging
6
7
  from typing import Any, Dict, Optional
7
8
 
8
9
  import numpy as np
@@ -88,13 +89,16 @@ class FAST(Engine):
88
89
  def _fast(
89
90
  arch: str,
90
91
  model_path: str,
92
+ load_in_8_bit: bool = False,
91
93
  **kwargs: Any,
92
94
  ) -> FAST:
95
+ if load_in_8_bit:
96
+ logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...")
93
97
  # Build the model
94
98
  return FAST(model_path, cfg=default_cfgs[arch], **kwargs)
95
99
 
96
100
 
97
- def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any) -> FAST:
101
+ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
98
102
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
99
103
  <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
100
104
 
@@ -107,16 +111,17 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any)
107
111
  Args:
108
112
  ----
109
113
  model_path: path to onnx model file, defaults to url in default_cfgs
114
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
110
115
  **kwargs: keyword arguments of the DBNet architecture
111
116
 
112
117
  Returns:
113
118
  -------
114
119
  text detection architecture
115
120
  """
116
- return _fast("fast_tiny", model_path, **kwargs)
121
+ return _fast("fast_tiny", model_path, load_in_8_bit, **kwargs)
117
122
 
118
123
 
119
- def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: Any) -> FAST:
124
+ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
120
125
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
121
126
  <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
122
127
 
@@ -129,16 +134,17 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: An
129
134
  Args:
130
135
  ----
131
136
  model_path: path to onnx model file, defaults to url in default_cfgs
137
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
132
138
  **kwargs: keyword arguments of the DBNet architecture
133
139
 
134
140
  Returns:
135
141
  -------
136
142
  text detection architecture
137
143
  """
138
- return _fast("fast_small", model_path, **kwargs)
144
+ return _fast("fast_small", model_path, load_in_8_bit, **kwargs)
139
145
 
140
146
 
141
- def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any) -> FAST:
147
+ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
142
148
  """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
143
149
  <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
144
150
 
@@ -151,10 +157,11 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any)
151
157
  Args:
152
158
  ----
153
159
  model_path: path to onnx model file, defaults to url in default_cfgs
160
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
154
161
  **kwargs: keyword arguments of the DBNet architecture
155
162
 
156
163
  Returns:
157
164
  -------
158
165
  text detection architecture
159
166
  """
160
- return _fast("fast_base", model_path, **kwargs)
167
+ return _fast("fast_base", model_path, load_in_8_bit, **kwargs)
@@ -20,18 +20,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
20
20
  "mean": (0.798, 0.785, 0.772),
21
21
  "std": (0.264, 0.2749, 0.287),
22
22
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet18-e0e0b9dc.onnx",
23
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet18_static_8_bit-3b3a37dd.onnx",
23
24
  },
24
25
  "linknet_resnet34": {
25
26
  "input_shape": (3, 1024, 1024),
26
27
  "mean": (0.798, 0.785, 0.772),
27
28
  "std": (0.264, 0.2749, 0.287),
28
29
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet34-93e39a39.onnx",
30
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet34_static_8_bit-2824329d.onnx",
29
31
  },
30
32
  "linknet_resnet50": {
31
33
  "input_shape": (3, 1024, 1024),
32
34
  "mean": (0.798, 0.785, 0.772),
33
35
  "std": (0.264, 0.2749, 0.287),
34
36
  "url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/linknet_resnet50-15d8c4ec.onnx",
37
+ "url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/linknet_resnet50_static_8_bit-65d6b0b8.onnx",
35
38
  },
36
39
  }
37
40
 
@@ -88,13 +91,18 @@ class LinkNet(Engine):
88
91
  def _linknet(
89
92
  arch: str,
90
93
  model_path: str,
94
+ load_in_8_bit: bool = False,
91
95
  **kwargs: Any,
92
96
  ) -> LinkNet:
97
+ # Patch the url
98
+ model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
93
99
  # Build the model
94
100
  return LinkNet(model_path, cfg=default_cfgs[arch], **kwargs)
95
101
 
96
102
 
97
- def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"]["url"], **kwargs: Any) -> LinkNet:
103
+ def linknet_resnet18(
104
+ model_path: str = default_cfgs["linknet_resnet18"]["url"], load_in_8_bit: bool = False, **kwargs: Any
105
+ ) -> LinkNet:
98
106
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
99
107
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
100
108
 
@@ -107,16 +115,19 @@ def linknet_resnet18(model_path: str = default_cfgs["linknet_resnet18"]["url"],
107
115
  Args:
108
116
  ----
109
117
  model_path: path to onnx model file, defaults to url in default_cfgs
118
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
110
119
  **kwargs: keyword arguments of the LinkNet architecture
111
120
 
112
121
  Returns:
113
122
  -------
114
123
  text detection architecture
115
124
  """
116
- return _linknet("linknet_resnet18", model_path, **kwargs)
125
+ return _linknet("linknet_resnet18", model_path, load_in_8_bit, **kwargs)
117
126
 
118
127
 
119
- def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"]["url"], **kwargs: Any) -> LinkNet:
128
+ def linknet_resnet34(
129
+ model_path: str = default_cfgs["linknet_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
130
+ ) -> LinkNet:
120
131
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
121
132
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
122
133
 
@@ -129,16 +140,19 @@ def linknet_resnet34(model_path: str = default_cfgs["linknet_resnet34"]["url"],
129
140
  Args:
130
141
  ----
131
142
  model_path: path to onnx model file, defaults to url in default_cfgs
143
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
132
144
  **kwargs: keyword arguments of the LinkNet architecture
133
145
 
134
146
  Returns:
135
147
  -------
136
148
  text detection architecture
137
149
  """
138
- return _linknet("linknet_resnet34", model_path, **kwargs)
150
+ return _linknet("linknet_resnet34", model_path, load_in_8_bit, **kwargs)
139
151
 
140
152
 
141
- def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"]["url"], **kwargs: Any) -> LinkNet:
153
+ def linknet_resnet50(
154
+ model_path: str = default_cfgs["linknet_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
155
+ ) -> LinkNet:
142
156
  """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
143
157
  <https://arxiv.org/pdf/1707.03718.pdf>`_.
144
158
 
@@ -151,10 +165,11 @@ def linknet_resnet50(model_path: str = default_cfgs["linknet_resnet50"]["url"],
151
165
  Args:
152
166
  ----
153
167
  model_path: path to onnx model file, defaults to url in default_cfgs
168
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
154
169
  **kwargs: keyword arguments of the LinkNet architecture
155
170
 
156
171
  Returns:
157
172
  -------
158
173
  text detection architecture
159
174
  """
160
- return _linknet("linknet_resnet50", model_path, **kwargs)
175
+ return _linknet("linknet_resnet50", model_path, load_in_8_bit, **kwargs)
@@ -24,12 +24,14 @@ ARCHS = [
24
24
  ]
25
25
 
26
26
 
27
- def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
27
+ def _predictor(
28
+ arch: Any, assume_straight_pages: bool = True, load_in_8_bit: bool = False, **kwargs: Any
29
+ ) -> DetectionPredictor:
28
30
  if isinstance(arch, str):
29
31
  if arch not in ARCHS:
30
32
  raise ValueError(f"unknown architecture '{arch}'")
31
33
 
32
- _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages)
34
+ _model = detection.__dict__[arch](assume_straight_pages=assume_straight_pages, load_in_8_bit=load_in_8_bit)
33
35
  else:
34
36
  if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
35
37
  raise ValueError(f"unknown architecture: {type(arch)}")
@@ -50,6 +52,7 @@ def _predictor(arch: Any, assume_straight_pages: bool = True, **kwargs: Any) ->
50
52
  def detection_predictor(
51
53
  arch: Any = "fast_base",
52
54
  assume_straight_pages: bool = True,
55
+ load_in_8_bit: bool = False,
53
56
  **kwargs: Any,
54
57
  ) -> DetectionPredictor:
55
58
  """Text detection architecture.
@@ -64,10 +67,11 @@ def detection_predictor(
64
67
  ----
65
68
  arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
66
69
  assume_straight_pages: If True, fit straight boxes to the page
70
+ load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
67
71
  **kwargs: optional keyword arguments passed to the architecture
68
72
 
69
73
  Returns:
70
74
  -------
71
75
  Detection predictor
72
76
  """
73
- return _predictor(arch, assume_straight_pages, **kwargs)
77
+ return _predictor(arch, assume_straight_pages, load_in_8_bit, **kwargs)
@@ -43,8 +43,8 @@ class Engine:
43
43
  inputs = np.broadcast_to(inputs, (self.fixed_batch_size, *inputs.shape))
44
44
  # combine the results
45
45
  logits = np.concatenate(
46
- [self.runtime.run(self.output_name, {"input": batch})[0] for batch in inputs], axis=0
46
+ [self.runtime.run(self.output_name, {self.runtime_inputs.name: batch})[0] for batch in inputs], axis=0
47
47
  )
48
48
  else:
49
- logits = self.runtime.run(self.output_name, {"input": inputs})[0]
49
+ logits = self.runtime.run(self.output_name, {self.runtime_inputs.name: inputs})[0]
50
50
  return shape_translate(logits, format="BHWC")