onnxtr 0.3.1__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. {onnxtr-0.3.1 → onnxtr-0.4.0}/PKG-INFO +66 -2
  2. {onnxtr-0.3.1 → onnxtr-0.4.0}/README.md +64 -1
  3. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/__init__.py +1 -0
  4. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/classification/models/mobilenet.py +1 -0
  5. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/models/differentiable_binarization.py +2 -0
  6. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/models/fast.py +1 -0
  7. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/models/linknet.py +1 -0
  8. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/engine.py +2 -0
  9. onnxtr-0.4.0/onnxtr/models/factory/__init__.py +1 -0
  10. onnxtr-0.4.0/onnxtr/models/factory/hub.py +224 -0
  11. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/crnn.py +2 -0
  12. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/master.py +1 -0
  13. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/parseq.py +2 -0
  14. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/sar.py +2 -0
  15. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/vitstr.py +1 -0
  16. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/transforms/base.py +33 -46
  17. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/fonts.py +5 -3
  18. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/vocabs.py +11 -4
  19. onnxtr-0.4.0/onnxtr/version.py +1 -0
  20. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/PKG-INFO +66 -2
  21. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/SOURCES.txt +2 -0
  22. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/requires.txt +1 -0
  23. {onnxtr-0.3.1 → onnxtr-0.4.0}/pyproject.toml +2 -0
  24. {onnxtr-0.3.1 → onnxtr-0.4.0}/setup.py +1 -1
  25. onnxtr-0.3.1/onnxtr/version.py +0 -1
  26. {onnxtr-0.3.1 → onnxtr-0.4.0}/LICENSE +0 -0
  27. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/__init__.py +0 -0
  28. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/contrib/__init__.py +0 -0
  29. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/contrib/artefacts.py +0 -0
  30. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/contrib/base.py +0 -0
  31. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/file_utils.py +0 -0
  32. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/io/__init__.py +0 -0
  33. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/io/elements.py +0 -0
  34. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/io/html.py +0 -0
  35. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/io/image.py +0 -0
  36. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/io/pdf.py +0 -0
  37. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/io/reader.py +0 -0
  38. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/_utils.py +0 -0
  39. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/builder.py +0 -0
  40. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/classification/__init__.py +0 -0
  41. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/classification/models/__init__.py +0 -0
  42. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/classification/predictor/__init__.py +0 -0
  43. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/classification/predictor/base.py +0 -0
  44. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/classification/zoo.py +0 -0
  45. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/__init__.py +0 -0
  46. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/_utils/__init__.py +0 -0
  47. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/_utils/base.py +0 -0
  48. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/core.py +0 -0
  49. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/models/__init__.py +0 -0
  50. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/postprocessor/__init__.py +0 -0
  51. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/postprocessor/base.py +0 -0
  52. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/predictor/__init__.py +0 -0
  53. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/predictor/base.py +0 -0
  54. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/zoo.py +0 -0
  55. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/predictor/__init__.py +0 -0
  56. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/predictor/base.py +0 -0
  57. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/predictor/predictor.py +0 -0
  58. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/preprocessor/__init__.py +0 -0
  59. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/preprocessor/base.py +6 -6
  60. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/__init__.py +0 -0
  61. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/core.py +0 -0
  62. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/__init__.py +0 -0
  63. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/predictor/__init__.py +0 -0
  64. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/predictor/_utils.py +0 -0
  65. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/predictor/base.py +0 -0
  66. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/utils.py +0 -0
  67. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/zoo.py +0 -0
  68. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/zoo.py +0 -0
  69. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/py.typed +0 -0
  70. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/transforms/__init__.py +0 -0
  71. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/__init__.py +0 -0
  72. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/common_types.py +0 -0
  73. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/data.py +0 -0
  74. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/geometry.py +0 -0
  75. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/multithreading.py +0 -0
  76. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/reconstitution.py +0 -0
  77. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/repr.py +0 -0
  78. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/visualization.py +0 -0
  79. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/dependency_links.txt +0 -0
  80. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/top_level.txt +0 -0
  81. {onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/zip-safe +0 -0
  82. {onnxtr-0.3.1 → onnxtr-0.4.0}/setup.cfg +0 -0
{onnxtr-0.3.1 → onnxtr-0.4.0}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: onnxtr
- Version: 0.3.1
+ Version: 0.4.0
  Summary: Onnx Text Recognition (OnnxTR): docTR Onnx-Wrapper for high-performance OCR on documents.
  Author-email: Felix Dittrich <felixdittrich92@gmail.com>
  Maintainer: Felix Dittrich
@@ -233,6 +233,7 @@ Requires-Dist: pyclipper<2.0.0,>=1.2.0
  Requires-Dist: shapely<3.0.0,>=1.6.0
  Requires-Dist: rapidfuzz<4.0.0,>=3.0.0
  Requires-Dist: langdetect<2.0.0,>=1.0.9
+ Requires-Dist: huggingface-hub<1.0.0,>=0.23.0
  Requires-Dist: Pillow>=9.2.0
  Requires-Dist: defusedxml>=0.7.0
  Requires-Dist: anyascii>=0.3.2
@@ -275,7 +276,7 @@ Requires-Dist: pre-commit>=2.17.0; extra == "dev"
  [![codecov](https://codecov.io/gh/felixdittrich92/OnnxTR/graph/badge.svg?token=WVFRCQBOLI)](https://codecov.io/gh/felixdittrich92/OnnxTR)
  [![Codacy Badge](https://app.codacy.com/project/badge/Grade/4fff4d764bb14fb8b4f4afeb9587231b)](https://app.codacy.com/gh/felixdittrich92/OnnxTR/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
  [![CodeFactor](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr/badge)](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr)
- [![Pypi](https://img.shields.io/badge/pypi-v0.3.1-blue.svg)](https://pypi.org/project/OnnxTR/)
+ [![Pypi](https://img.shields.io/badge/pypi-v0.3.2-blue.svg)](https://pypi.org/project/OnnxTR/)

  > :warning: Please note that this is a wrapper around the [doctr](https://github.com/mindee/doctr) library to provide a Onnx pipeline for docTR. For feature requests, which are not directly related to the Onnx pipeline, please refer to the base project.

@@ -449,6 +450,69 @@ det_model = linknet_resnet18("path_to_custom_model.onnx")
  model = ocr_predictor(det_arch=det_model, reco_arch=reco_model)
  ```

+ ## Loading models from HuggingFace Hub
+
+ You can also load models from the HuggingFace Hub:
+
+ ```python
+ from onnxtr.io import DocumentFile
+ from onnxtr.models import ocr_predictor, from_hub
+
+ img = DocumentFile.from_images(['<image_path>'])
+ # Load your model from the hub
+ model = from_hub('onnxtr/my-model')
+
+ # Pass it to the predictor
+ # If your model is a recognition model:
+ predictor = ocr_predictor(
+     det_arch='db_mobilenet_v3_large',
+     reco_arch=model
+ )
+
+ # If your model is a detection model:
+ predictor = ocr_predictor(
+     det_arch=model,
+     reco_arch='crnn_mobilenet_v3_small'
+ )
+
+ # Get your predictions
+ res = predictor(img)
+ ```
+
+ HF Hub search: [here](https://huggingface.co/models?search=onnxtr).
+
+ Collection: [here](https://huggingface.co/collections/Felix92/onnxtr-66bf213a9f88f7346c90e842)
+
+ Or push your own models to the hub:
+
+ ```python
+ from onnxtr.models import parseq, push_to_hf_hub, login_to_hub
+ from onnxtr.utils.vocabs import VOCABS
+
+ # Login to the hub
+ login_to_hub()
+
+ # Recogniton model
+ model = parseq("~/onnxtr-parseq-multilingual-v1.onnx", vocab=VOCABS["multilingual"])
+ push_to_hf_hub(
+     model,
+     model_name="onnxtr-parseq-multilingual-v1",
+     task="recognition",  # The task for which the model is intended [detection, recognition, classification]
+     arch="parseq",  # The name of the model architecture
+     override=False  # Set to `True` if you want to override an existing model / repository
+ )
+
+ # Detection model
+ model = linknet_resnet18("~/onnxtr-linknet-resnet18.onnx")
+ push_to_hf_hub(
+     model,
+     model_name="onnxtr-linknet-resnet18",
+     task="detection",
+     arch="linknet_resnet18",
+     override=True
+ )
+ ```
+
  ## Models architectures

  Credits where it's due: this repository provides ONNX models for the following architectures, converted from the docTR models:
{onnxtr-0.3.1 → onnxtr-0.4.0}/README.md

@@ -7,7 +7,7 @@
  [![codecov](https://codecov.io/gh/felixdittrich92/OnnxTR/graph/badge.svg?token=WVFRCQBOLI)](https://codecov.io/gh/felixdittrich92/OnnxTR)
  [![Codacy Badge](https://app.codacy.com/project/badge/Grade/4fff4d764bb14fb8b4f4afeb9587231b)](https://app.codacy.com/gh/felixdittrich92/OnnxTR/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
  [![CodeFactor](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr/badge)](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr)
- [![Pypi](https://img.shields.io/badge/pypi-v0.3.1-blue.svg)](https://pypi.org/project/OnnxTR/)
+ [![Pypi](https://img.shields.io/badge/pypi-v0.3.2-blue.svg)](https://pypi.org/project/OnnxTR/)

  > :warning: Please note that this is a wrapper around the [doctr](https://github.com/mindee/doctr) library to provide a Onnx pipeline for docTR. For feature requests, which are not directly related to the Onnx pipeline, please refer to the base project.

@@ -181,6 +181,69 @@ det_model = linknet_resnet18("path_to_custom_model.onnx")
  model = ocr_predictor(det_arch=det_model, reco_arch=reco_model)
  ```

+ ## Loading models from HuggingFace Hub
+
+ You can also load models from the HuggingFace Hub:
+
+ ```python
+ from onnxtr.io import DocumentFile
+ from onnxtr.models import ocr_predictor, from_hub
+
+ img = DocumentFile.from_images(['<image_path>'])
+ # Load your model from the hub
+ model = from_hub('onnxtr/my-model')
+
+ # Pass it to the predictor
+ # If your model is a recognition model:
+ predictor = ocr_predictor(
+     det_arch='db_mobilenet_v3_large',
+     reco_arch=model
+ )
+
+ # If your model is a detection model:
+ predictor = ocr_predictor(
+     det_arch=model,
+     reco_arch='crnn_mobilenet_v3_small'
+ )
+
+ # Get your predictions
+ res = predictor(img)
+ ```
+
+ HF Hub search: [here](https://huggingface.co/models?search=onnxtr).
+
+ Collection: [here](https://huggingface.co/collections/Felix92/onnxtr-66bf213a9f88f7346c90e842)
+
+ Or push your own models to the hub:
+
+ ```python
+ from onnxtr.models import parseq, push_to_hf_hub, login_to_hub
+ from onnxtr.utils.vocabs import VOCABS
+
+ # Login to the hub
+ login_to_hub()
+
+ # Recogniton model
+ model = parseq("~/onnxtr-parseq-multilingual-v1.onnx", vocab=VOCABS["multilingual"])
+ push_to_hf_hub(
+     model,
+     model_name="onnxtr-parseq-multilingual-v1",
+     task="recognition",  # The task for which the model is intended [detection, recognition, classification]
+     arch="parseq",  # The name of the model architecture
+     override=False  # Set to `True` if you want to override an existing model / repository
+ )
+
+ # Detection model
+ model = linknet_resnet18("~/onnxtr-linknet-resnet18.onnx")
+ push_to_hf_hub(
+     model,
+     model_name="onnxtr-linknet-resnet18",
+     task="detection",
+     arch="linknet_resnet18",
+     override=True
+ )
+ ```
+
  ## Models architectures

  Credits where it's due: this repository provides ONNX models for the following architectures, converted from the docTR models:
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/__init__.py

@@ -3,3 +3,4 @@ from .classification import *
  from .detection import *
  from .recognition import *
  from .zoo import *
+ from .factory import *
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/classification/models/mobilenet.py

@@ -56,6 +56,7 @@ class MobileNetV3(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.cfg = cfg

      def __call__(
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/models/differentiable_binarization.py

@@ -64,8 +64,10 @@ class DBNet(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.cfg = cfg
          self.assume_straight_pages = assume_straight_pages
+
          self.postprocessor = GeneralDetectionPostProcessor(
              assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
          )
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/models/fast.py

@@ -62,6 +62,7 @@ class FAST(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.cfg = cfg
          self.assume_straight_pages = assume_straight_pages

{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/detection/models/linknet.py

@@ -64,6 +64,7 @@ class LinkNet(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.cfg = cfg
          self.assume_straight_pages = assume_straight_pages

{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/engine.py

@@ -90,6 +90,8 @@ class Engine:
      def __init__(self, url: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any) -> None:
          engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
          archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
+         # Store model path for each model
+         self.model_path = archive_path
          self.session_options = engine_cfg.session_options
          self.providers = engine_cfg.providers
          self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
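The new `model_path` attribute exists so the Hub export added below can copy the resolved `.onnx` file into a repository (see `_save_model_and_config_for_hf_hub` in `factory/hub.py`). A small sketch of what it exposes, assuming the zoo constructor can be called with its default downloadable weights:

```python
from onnxtr.models import linknet_resnet18  # constructor shown elsewhere in this README diff

# Instantiating a model resolves its .onnx file (download or local path) and
# keeps that location on the instance via the new model_path attribute.
det_model = linknet_resnet18()   # assumes the default pretrained weights are fetched
print(det_model.model_path)      # e.g. a path under the local cache's "models" subdir
```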
onnxtr-0.4.0/onnxtr/models/factory/__init__.py

@@ -0,0 +1 @@
+ from .hub import *
onnxtr-0.4.0/onnxtr/models/factory/hub.py

@@ -0,0 +1,224 @@
+ # Copyright (C) 2021-2024, Mindee | Felix Dittrich.
+
+ # This program is licensed under the Apache License 2.0.
+ # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+ # Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py
+
+ import json
+ import logging
+ import os
+ import shutil
+ import subprocess
+ import textwrap
+ from pathlib import Path
+ from typing import Any, Optional
+
+ from huggingface_hub import (
+     HfApi,
+     Repository,
+     get_token,
+     get_token_permission,
+     hf_hub_download,
+     login,
+ )
+
+ from onnxtr import models
+ from onnxtr.models.engine import EngineConfig
+
+ __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
+
+
+ AVAILABLE_ARCHS = {
+     "classification": models.classification.zoo.ORIENTATION_ARCHS,
+     "detection": models.detection.zoo.ARCHS,
+     "recognition": models.recognition.zoo.ARCHS,
+ }
+
+
+ def login_to_hub() -> None:  # pragma: no cover
+     """Login to huggingface hub"""
+     access_token = get_token()
+     if access_token is not None and get_token_permission(access_token):
+         logging.info("Huggingface Hub token found and valid")
+         login(token=access_token, write_permission=True)
+     else:
+         login()
+     # check if git lfs is installed
+     try:
+         subprocess.call(["git", "lfs", "version"])
+     except FileNotFoundError:
+         raise OSError(
+             "Looks like you do not have git-lfs installed, please install. \
+             You can install from https://git-lfs.github.com/. \
+             Then run `git lfs install` (you only have to do this once)."
+         )
+
+
+ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
+     """Save model and config to disk for pushing to huggingface hub
+
+     Args:
+     ----
+         model: Onnx model to be saved
+         save_dir: directory to save model and config
+         arch: architecture name
+         task: task name
+     """
+     save_directory = Path(save_dir)
+     shutil.copy2(model.model_path, save_directory / "model.onnx")
+
+     config_path = save_directory / "config.json"
+
+     # add model configuration
+     model_config = model.cfg
+     model_config["arch"] = arch
+     model_config["task"] = task
+
+     with config_path.open("w") as f:
+         json.dump(model_config, f, indent=2, ensure_ascii=False)
+
+
+ def push_to_hf_hub(
+     model: Any, model_name: str, task: str, override: bool = False, **kwargs
+ ) -> None:  # pragma: no cover
+     """Save model and its configuration on HF hub
+
+     >>> from onnxtr.models import login_to_hub, push_to_hf_hub
+     >>> from onnxtr.models.recognition import crnn_mobilenet_v3_small
+     >>> login_to_hub()
+     >>> model = crnn_mobilenet_v3_small()
+     >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
+
+     Args:
+     ----
+         model: Onnx model to be saved
+         model_name: name of the model which is also the repository name
+         task: task name
+         override: whether to override the existing model / repo on HF hub
+         **kwargs: keyword arguments for push_to_hf_hub
+     """
+     run_config = kwargs.get("run_config", None)
+     arch = kwargs.get("arch", None)
+
+     if run_config is None and arch is None:
+         raise ValueError("run_config or arch must be specified")
+     if task not in ["classification", "detection", "recognition"]:
+         raise ValueError("task must be one of classification, detection, recognition")
+
+     # default readme
+     readme = textwrap.dedent(
+         f"""
+     ---
+     language:
+     - en
+     - fr
+     license: apache-2.0
+     ---
+
+     <p align="center">
+     <img src="https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/logo.jpg" width="40%">
+     </p>
+
+     **Optical Character Recognition made seamless & accessible to anyone, powered by Onnxruntime**
+
+     ## Task: {task}
+
+     https://github.com/felixdittrich92/OnnxTR
+
+     ### Example usage:
+
+     ```python
+     >>> from onnxtr.io import DocumentFile
+     >>> from onnxtr.models import ocr_predictor, from_hub
+
+     >>> img = DocumentFile.from_images(['<image_path>'])
+     >>> # Load your model from the hub
+     >>> model = from_hub('onnxtr/my-model')
+
+     >>> # Pass it to the predictor
+     >>> # If your model is a recognition model:
+     >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
+     >>>                           reco_arch=model)
+
+     >>> # If your model is a detection model:
+     >>> predictor = ocr_predictor(det_arch=model,
+     >>>                           reco_arch='crnn_mobilenet_v3_small')
+
+     >>> # Get your predictions
+     >>> res = predictor(img)
+     ```
+     """
+     )
+
+     # add run configuration to readme if available
+     if run_config is not None:
+         arch = run_config.arch
+         readme += textwrap.dedent(
+             f"""### Run Configuration
+             \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
+         )
+
+     if arch not in AVAILABLE_ARCHS[task]:
+         raise ValueError(
+             f"Architecture: {arch} for task: {task} not found.\
+             \nAvailable architectures: {AVAILABLE_ARCHS}"
+         )
+
+     commit_message = f"Add {model_name} model"
+
+     local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
+     repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=override)
+     repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)
+
+     with repo.commit(commit_message):
+         _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
+         readme_path = Path(repo.local_dir) / "README.md"
+         readme_path.write_text(readme)
+
+     repo.git_push()
+
+
+ def from_hub(repo_id: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any):
+     """Instantiate & load a pretrained model from HF hub.
+
+     >>> from onnxtr.models import from_hub
+     >>> model = from_hub("onnxtr/my-model")
+
+     Args:
+     ----
+         repo_id: HuggingFace model hub repo
+         engine_cfg: configuration for the inference engine (optional)
+         kwargs: kwargs of `hf_hub_download`
+
+     Returns:
+     -------
+         Model loaded with the checkpoint
+     """
+     # Get the config
+     with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
+         cfg = json.load(f)
+     model_path = hf_hub_download(repo_id, filename="model.onnx", **kwargs)
+
+     arch = cfg["arch"]
+     task = cfg["task"]
+     cfg.pop("arch")
+     cfg.pop("task")
+
+     if task == "classification":
+         model = models.classification.__dict__[arch](model_path, classes=cfg["classes"], engine_cfg=engine_cfg)
+     elif task == "detection":
+         model = models.detection.__dict__[arch](model_path, engine_cfg=engine_cfg)
+     elif task == "recognition":
+         model = models.recognition.__dict__[arch](
+             model_path, input_shape=cfg["input_shape"], vocab=cfg["vocab"], engine_cfg=engine_cfg
+         )
+
+     # convert all values which are lists to tuples
+     for key, value in cfg.items():
+         if isinstance(value, list):
+             cfg[key] = tuple(value)
+     # update model cfg
+     model.cfg = cfg
+
+     return model
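For orientation, the `config.json` written by `_save_model_and_config_for_hf_hub` is just the model's `cfg` dict plus the `arch` and `task` keys, and `from_hub` reads those two keys back (plus `vocab`/`input_shape` for recognition, or `classes` for classification). A purely illustrative example of such a payload for a recognition model, shown as a Python dict (the non-required values are invented, not taken from a published repo):

```python
# Hypothetical config.json contents for a recognition model.
# "arch" and "task" are guaranteed by the export helper above; the other keys
# mirror whatever the model's cfg contains (values here are made up).
example_config = {
    "arch": "crnn_mobilenet_v3_small",  # added by _save_model_and_config_for_hf_hub
    "task": "recognition",              # added by _save_model_and_config_for_hf_hub
    "vocab": "0123456789abcdefghijklmnopqrstuvwxyz",  # read back by from_hub
    "input_shape": [3, 32, 128],                      # read back by from_hub
}
```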
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/crnn.py

@@ -129,8 +129,10 @@ class CRNN(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.vocab = vocab
          self.cfg = cfg
+
          self.postprocessor = CRNNPostProcessor(self.vocab)

      def __call__(
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/master.py

@@ -53,6 +53,7 @@ class MASTER(Engine):

          self.vocab = vocab
          self.cfg = cfg
+
          self.postprocessor = MASTERPostProcessor(vocab=self.vocab)

      def __call__(
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/parseq.py

@@ -49,8 +49,10 @@ class PARSeq(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.vocab = vocab
          self.cfg = cfg
+
          self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)

      def __call__(
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/sar.py

@@ -49,8 +49,10 @@ class SAR(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.vocab = vocab
          self.cfg = cfg
+
          self.postprocessor = SARPostProcessor(self.vocab)

      def __call__(
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/recognition/models/vitstr.py

@@ -57,6 +57,7 @@ class ViTSTR(Engine):
          **kwargs: Any,
      ) -> None:
          super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)
+
          self.vocab = vocab
          self.cfg = cfg

{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/transforms/base.py

@@ -5,8 +5,8 @@

  from typing import Tuple, Union

- import cv2
  import numpy as np
+ from PIL import Image, ImageOps

  __all__ = ["Resize", "Normalize"]

@@ -17,64 +17,51 @@ class Resize:
      def __init__(
          self,
          size: Union[int, Tuple[int, int]],
-         interpolation=cv2.INTER_LINEAR,
+         interpolation=Image.Resampling.BILINEAR,
          preserve_aspect_ratio: bool = False,
          symmetric_pad: bool = False,
      ) -> None:
-         super().__init__()
-         self.size = size
+         self.size = size if isinstance(size, tuple) else (size, size)
          self.interpolation = interpolation
          self.preserve_aspect_ratio = preserve_aspect_ratio
          self.symmetric_pad = symmetric_pad
          self.output_size = size if isinstance(size, tuple) else (size, size)

-         if not isinstance(self.size, (int, tuple, list)):
-             raise AssertionError("size should be either a tuple, a list or an int")
+         if not isinstance(self.size, (tuple, int)):
+             raise AssertionError("size should be either a tuple or an int")

-     def __call__(
-         self,
-         img: np.ndarray,
-     ) -> np.ndarray:
-         if img.ndim == 3:
-             h, w = img.shape[0:2]
-         else:
-             h, w = img.shape[1:3]
-         sh, sw = self.size if isinstance(self.size, tuple) else (self.size, self.size)
+     def __call__(self, img: np.ndarray) -> np.ndarray:
+         img = (img * 255).astype(np.uint8) if img.dtype != np.uint8 else img
+         h, w = img.shape[:2] if img.ndim == 3 else img.shape[1:3]
+         sh, sw = self.size

-         # Calculate aspect ratio of the image
-         aspect = w / h
+         if not self.preserve_aspect_ratio:
+             return np.array(Image.fromarray(img).resize((sw, sh), resample=self.interpolation))

-         # Compute scaling and padding sizes
-         if self.preserve_aspect_ratio:
-             if aspect > 1:  # Horizontal image
-                 new_w = sw
-                 new_h = int(sw / aspect)
-             elif aspect < 1:  # Vertical image
-                 new_h = sh
-                 new_w = int(sh * aspect)
-             else:  # Square image
-                 new_h, new_w = sh, sw
-
-             img_resized = cv2.resize(img, (new_w, new_h), interpolation=self.interpolation)
-
-             # Calculate padding
-             pad_top = max((sh - new_h) // 2, 0)
-             pad_bottom = max(sh - new_h - pad_top, 0)
-             pad_left = max((sw - new_w) // 2, 0)
-             pad_right = max(sw - new_w - pad_left, 0)
-
-             # Pad the image
-             img_resized = cv2.copyMakeBorder(  # type: ignore[call-overload]
-                 img_resized, pad_top, pad_bottom, pad_left, pad_right, borderType=cv2.BORDER_CONSTANT, value=0
-             )
-
-             # Ensure the image matches the target size by resizing it again if needed
-             img_resized = cv2.resize(img_resized, (sw, sh), interpolation=self.interpolation)
+         actual_ratio = h / w
+         target_ratio = sh / sw
+
+         if target_ratio == actual_ratio:
+             return np.array(Image.fromarray(img).resize((sw, sh), resample=self.interpolation))
+
+         if actual_ratio > target_ratio:
+             tmp_size = (int(sh / actual_ratio), sh)
          else:
-             # Resize the image without preserving aspect ratio
-             img_resized = cv2.resize(img, (sw, sh), interpolation=self.interpolation)
+             tmp_size = (sw, int(sw * actual_ratio))
+
+         img_resized = Image.fromarray(img).resize(tmp_size, resample=self.interpolation)
+         pad_left = pad_top = 0
+         pad_right = sw - img_resized.width
+         pad_bottom = sh - img_resized.height
+
+         if self.symmetric_pad:
+             pad_left = pad_right // 2
+             pad_right -= pad_left
+             pad_top = pad_bottom // 2
+             pad_bottom -= pad_top

-         return img_resized
+         img_resized = ImageOps.expand(img_resized, (pad_left, pad_top, pad_right, pad_bottom))
+         return np.array(img_resized)

      def __repr__(self) -> str:
          interpolate_str = self.interpolation
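The reworked `Resize` keeps its old call signature but now runs on uint8 data through PIL and handles the padding itself. A small usage sketch, assuming `Resize` is re-exported from `onnxtr.transforms` as the `__all__` above suggests:

```python
import numpy as np
from onnxtr.transforms import Resize  # assumed import path, re-exported from transforms.base

# A 2:1 landscape image; float values in [0, 1] are converted to uint8 internally
img = np.random.rand(256, 512, 3).astype(np.float32)

# Plain resize: aspect ratio is discarded
plain = Resize((1024, 1024))(img)
assert plain.shape[:2] == (1024, 1024)

# Aspect-preserving resize: the content is scaled to 1024x512, then zero-padded;
# symmetric_pad=True splits the padding evenly between top and bottom
padded = Resize((1024, 1024), preserve_aspect_ratio=True, symmetric_pad=True)(img)
assert padded.shape[:2] == (1024, 1024)
```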
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/fonts.py

@@ -5,14 +5,16 @@

  import logging
  import platform
- from typing import Optional
+ from typing import Optional, Union

  from PIL import ImageFont

  __all__ = ["get_font"]


- def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
+ def get_font(
+     font_family: Optional[str] = None, font_size: int = 13
+ ) -> Union[ImageFont.FreeTypeFont, ImageFont.ImageFont]:
      """Resolves a compatible ImageFont for the system

      Args:
@@ -29,7 +31,7 @@ def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont:
      try:
          font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size)
      except OSError:  # pragma: no cover
-         font = ImageFont.load_default()
+         font = ImageFont.load_default()  # type: ignore[assignment]
          logging.warning(
              "unable to load recommended font family. Loading default PIL font,"
              "font size issues may be expected."
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/utils/vocabs.py

@@ -17,9 +17,14 @@ VOCABS: Dict[str, str] = {
      "ancient_greek": "αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ",
      "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
      "persian_letters": "پچڢڤگ",
-     "hindi_digits": "٠١٢٣٤٥٦٧٨٩",
+     "arabic_digits": "٠١٢٣٤٥٦٧٨٩",
      "arabic_diacritics": "ًٌٍَُِّْ",
      "arabic_punctuation": "؟؛«»—",
+     "hindi_letters": "अआइईउऊऋॠऌॡएऐओऔअंअःकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह",
+     "hindi_digits": "०१२३४५६७८९",
+     "hindi_punctuation": "।,?!:्ॐ॰॥॰",
+     "bangla_letters": "অআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃেৈোৌ্ৎংঃঁ",
+     "bangla_digits": "০১২৩৪৫৬৭৮৯",
  }

  VOCABS["latin"] = VOCABS["digits"] + VOCABS["ascii_letters"] + VOCABS["punctuation"]
@@ -32,7 +37,7 @@ VOCABS["italian"] = VOCABS["english"] + "àèéìíîòóùúÀÈÉÌÍÎÒÓÙ
  VOCABS["german"] = VOCABS["english"] + "äöüßÄÖÜẞ"
  VOCABS["arabic"] = (
      VOCABS["digits"]
-     + VOCABS["hindi_digits"]
+     + VOCABS["arabic_digits"]
      + VOCABS["arabic_letters"]
      + VOCABS["persian_letters"]
      + VOCABS["arabic_diacritics"]
@@ -48,10 +53,12 @@ VOCABS["finnish"] = VOCABS["english"] + "äöÄÖ"
  VOCABS["swedish"] = VOCABS["english"] + "åäöÅÄÖ"
  VOCABS["vietnamese"] = (
      VOCABS["english"]
-     + "áàảạãăắằẳẵặâấầẩẫậéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ"
-     + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
+     + "áàảạãăắằẳẵặâấầẩẫậđéèẻẽẹêếềểễệóòỏõọôốồổộỗơớờởợỡúùủũụưứừửữựiíìỉĩịýỳỷỹỵ"
+     + "ÁÀẢẠÃĂẮẰẲẴẶÂẤẦẨẪẬĐÉÈẺẼẸÊẾỀỂỄỆÓÒỎÕỌÔỐỒỔỘỖƠỚỜỞỢỠÚÙỦŨỤƯỨỪỬỮỰIÍÌỈĨỊÝỲỶỸỴ"
  )
  VOCABS["hebrew"] = VOCABS["english"] + "אבגדהוזחטיכלמנסעפצקרשת" + "₪"
+ VOCABS["hindi"] = VOCABS["hindi_letters"] + VOCABS["hindi_digits"] + VOCABS["hindi_punctuation"]
+ VOCABS["bangla"] = VOCABS["bangla_letters"] + VOCABS["bangla_digits"]
  VOCABS["multilingual"] = "".join(
      dict.fromkeys(
          VOCABS["french"]
onnxtr-0.4.0/onnxtr/version.py

@@ -0,0 +1 @@
+ __version__ = 'v0.4.0'
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: onnxtr
- Version: 0.3.1
+ Version: 0.4.0
  Summary: Onnx Text Recognition (OnnxTR): docTR Onnx-Wrapper for high-performance OCR on documents.
  Author-email: Felix Dittrich <felixdittrich92@gmail.com>
  Maintainer: Felix Dittrich
@@ -233,6 +233,7 @@ Requires-Dist: pyclipper<2.0.0,>=1.2.0
  Requires-Dist: shapely<3.0.0,>=1.6.0
  Requires-Dist: rapidfuzz<4.0.0,>=3.0.0
  Requires-Dist: langdetect<2.0.0,>=1.0.9
+ Requires-Dist: huggingface-hub<1.0.0,>=0.23.0
  Requires-Dist: Pillow>=9.2.0
  Requires-Dist: defusedxml>=0.7.0
  Requires-Dist: anyascii>=0.3.2
@@ -275,7 +276,7 @@ Requires-Dist: pre-commit>=2.17.0; extra == "dev"
  [![codecov](https://codecov.io/gh/felixdittrich92/OnnxTR/graph/badge.svg?token=WVFRCQBOLI)](https://codecov.io/gh/felixdittrich92/OnnxTR)
  [![Codacy Badge](https://app.codacy.com/project/badge/Grade/4fff4d764bb14fb8b4f4afeb9587231b)](https://app.codacy.com/gh/felixdittrich92/OnnxTR/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
  [![CodeFactor](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr/badge)](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr)
- [![Pypi](https://img.shields.io/badge/pypi-v0.3.1-blue.svg)](https://pypi.org/project/OnnxTR/)
+ [![Pypi](https://img.shields.io/badge/pypi-v0.3.2-blue.svg)](https://pypi.org/project/OnnxTR/)

  > :warning: Please note that this is a wrapper around the [doctr](https://github.com/mindee/doctr) library to provide a Onnx pipeline for docTR. For feature requests, which are not directly related to the Onnx pipeline, please refer to the base project.

@@ -449,6 +450,69 @@ det_model = linknet_resnet18("path_to_custom_model.onnx")
  model = ocr_predictor(det_arch=det_model, reco_arch=reco_model)
  ```

+ ## Loading models from HuggingFace Hub
+
+ You can also load models from the HuggingFace Hub:
+
+ ```python
+ from onnxtr.io import DocumentFile
+ from onnxtr.models import ocr_predictor, from_hub
+
+ img = DocumentFile.from_images(['<image_path>'])
+ # Load your model from the hub
+ model = from_hub('onnxtr/my-model')
+
+ # Pass it to the predictor
+ # If your model is a recognition model:
+ predictor = ocr_predictor(
+     det_arch='db_mobilenet_v3_large',
+     reco_arch=model
+ )
+
+ # If your model is a detection model:
+ predictor = ocr_predictor(
+     det_arch=model,
+     reco_arch='crnn_mobilenet_v3_small'
+ )
+
+ # Get your predictions
+ res = predictor(img)
+ ```
+
+ HF Hub search: [here](https://huggingface.co/models?search=onnxtr).
+
+ Collection: [here](https://huggingface.co/collections/Felix92/onnxtr-66bf213a9f88f7346c90e842)
+
+ Or push your own models to the hub:
+
+ ```python
+ from onnxtr.models import parseq, push_to_hf_hub, login_to_hub
+ from onnxtr.utils.vocabs import VOCABS
+
+ # Login to the hub
+ login_to_hub()
+
+ # Recogniton model
+ model = parseq("~/onnxtr-parseq-multilingual-v1.onnx", vocab=VOCABS["multilingual"])
+ push_to_hf_hub(
+     model,
+     model_name="onnxtr-parseq-multilingual-v1",
+     task="recognition",  # The task for which the model is intended [detection, recognition, classification]
+     arch="parseq",  # The name of the model architecture
+     override=False  # Set to `True` if you want to override an existing model / repository
+ )
+
+ # Detection model
+ model = linknet_resnet18("~/onnxtr-linknet-resnet18.onnx")
+ push_to_hf_hub(
+     model,
+     model_name="onnxtr-linknet-resnet18",
+     task="detection",
+     arch="linknet_resnet18",
+     override=True
+ )
+ ```
+
  ## Models architectures

  Credits where it's due: this repository provides ONNX models for the following architectures, converted from the docTR models:
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/SOURCES.txt

@@ -45,6 +45,8 @@ onnxtr/models/detection/postprocessor/__init__.py
  onnxtr/models/detection/postprocessor/base.py
  onnxtr/models/detection/predictor/__init__.py
  onnxtr/models/detection/predictor/base.py
+ onnxtr/models/factory/__init__.py
+ onnxtr/models/factory/hub.py
  onnxtr/models/predictor/__init__.py
  onnxtr/models/predictor/base.py
  onnxtr/models/predictor/predictor.py
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr.egg-info/requires.txt

@@ -6,6 +6,7 @@ pyclipper<2.0.0,>=1.2.0
  shapely<3.0.0,>=1.6.0
  rapidfuzz<4.0.0,>=3.0.0
  langdetect<2.0.0,>=1.0.9
+ huggingface-hub<1.0.0,>=0.23.0
  Pillow>=9.2.0
  defusedxml>=0.7.0
  anyascii>=0.3.2
{onnxtr-0.3.1 → onnxtr-0.4.0}/pyproject.toml

@@ -39,6 +39,7 @@ dependencies = [
      "shapely>=1.6.0,<3.0.0",
      "rapidfuzz>=3.0.0,<4.0.0",
      "langdetect>=1.0.9,<2.0.0",
+     "huggingface-hub>=0.23.0,<1.0.0",
      "Pillow>=9.2.0",
      "defusedxml>=0.7.0",
      "anyascii>=0.3.2",
@@ -126,6 +127,7 @@ module = [
      "weasyprint.*",
      "pypdfium2.*",
      "langdetect.*",
+     "huggingface_hub.*",
      "rapidfuzz.*",
      "anyascii.*",
      "tqdm.*",
{onnxtr-0.3.1 → onnxtr-0.4.0}/setup.py

@@ -9,7 +9,7 @@ from pathlib import Path
  from setuptools import setup

  PKG_NAME = "onnxtr"
- VERSION = os.getenv("BUILD_VERSION", "0.3.1a0")
+ VERSION = os.getenv("BUILD_VERSION", "0.4.0a0")


  if __name__ == "__main__":
onnxtr-0.3.1/onnxtr/version.py

@@ -1 +0,0 @@
- __version__ = 'v0.3.1'
{onnxtr-0.3.1 → onnxtr-0.4.0}/onnxtr/models/preprocessor/base.py

@@ -67,11 +67,12 @@ class PreProcessor(NestedObject):
          if x.dtype not in (np.uint8, np.float32):
              raise TypeError("unsupported data type for numpy.ndarray")
          x = shape_translate(x, "HWC")
+
+         # Resizing
+         x = self.resize(x)
          # Data type & 255 division
          if x.dtype == np.uint8:
              x = x.astype(np.float32) / 255.0
-         # Resizing
-         x = self.resize(x)

          return x

@@ -95,13 +96,12 @@
              raise TypeError("unsupported data type for numpy.ndarray")
          x = shape_translate(x, "BHWC")

-         # Data type & 255 division
-         if x.dtype == np.uint8:
-             x = x.astype(np.float32) / 255.0
          # Resizing
          if (x.shape[1], x.shape[2]) != self.resize.output_size:
              x = np.array([self.resize(sample) for sample in x])
-
+         # Data type & 255 division
+         if x.dtype == np.uint8:
+             x = x.astype(np.float32) / 255.0
          batches = [x]

      elif isinstance(x, list) and all(isinstance(sample, np.ndarray) for sample in x):
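The swapped order follows from the `Resize` rework above: the PIL-backed transform operates on uint8 data (it converts float inputs back to uint8 anyway), so resizing first and dividing by 255 afterwards avoids an extra float-to-uint8 round trip. A rough sketch of the resulting batch path, assuming the same `Resize` import as in the earlier example:

```python
import numpy as np
from onnxtr.transforms import Resize  # assumed import path, as in the sketch above

resize = Resize((1024, 1024), preserve_aspect_ratio=True, symmetric_pad=True)

# uint8 samples are resized first, then cast and scaled to float32 in [0, 1]
samples = np.random.randint(0, 256, size=(4, 720, 1280, 3), dtype=np.uint8)
batch = np.array([resize(sample) for sample in samples])
batch = batch.astype(np.float32) / 255.0
```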