inference_models-0.18.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. inference_models/__init__.py +36 -0
  2. inference_models/configuration.py +72 -0
  3. inference_models/constants.py +2 -0
  4. inference_models/entities.py +5 -0
  5. inference_models/errors.py +137 -0
  6. inference_models/logger.py +52 -0
  7. inference_models/model_pipelines/__init__.py +0 -0
  8. inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
  9. inference_models/model_pipelines/auto_loaders/core.py +120 -0
  10. inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
  11. inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
  12. inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
  13. inference_models/models/__init__.py +0 -0
  14. inference_models/models/auto_loaders/__init__.py +0 -0
  15. inference_models/models/auto_loaders/access_manager.py +168 -0
  16. inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
  17. inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
  18. inference_models/models/auto_loaders/constants.py +7 -0
  19. inference_models/models/auto_loaders/core.py +1341 -0
  20. inference_models/models/auto_loaders/dependency_models.py +52 -0
  21. inference_models/models/auto_loaders/entities.py +57 -0
  22. inference_models/models/auto_loaders/models_registry.py +497 -0
  23. inference_models/models/auto_loaders/presentation_utils.py +333 -0
  24. inference_models/models/auto_loaders/ranking.py +413 -0
  25. inference_models/models/auto_loaders/utils.py +31 -0
  26. inference_models/models/base/__init__.py +0 -0
  27. inference_models/models/base/classification.py +123 -0
  28. inference_models/models/base/depth_estimation.py +62 -0
  29. inference_models/models/base/documents_parsing.py +111 -0
  30. inference_models/models/base/embeddings.py +66 -0
  31. inference_models/models/base/instance_segmentation.py +87 -0
  32. inference_models/models/base/keypoints_detection.py +93 -0
  33. inference_models/models/base/object_detection.py +143 -0
  34. inference_models/models/base/semantic_segmentation.py +74 -0
  35. inference_models/models/base/types.py +5 -0
  36. inference_models/models/clip/__init__.py +0 -0
  37. inference_models/models/clip/clip_onnx.py +148 -0
  38. inference_models/models/clip/clip_pytorch.py +104 -0
  39. inference_models/models/clip/preprocessing.py +162 -0
  40. inference_models/models/common/__init__.py +0 -0
  41. inference_models/models/common/cuda.py +30 -0
  42. inference_models/models/common/model_packages.py +25 -0
  43. inference_models/models/common/onnx.py +379 -0
  44. inference_models/models/common/roboflow/__init__.py +0 -0
  45. inference_models/models/common/roboflow/model_packages.py +361 -0
  46. inference_models/models/common/roboflow/post_processing.py +436 -0
  47. inference_models/models/common/roboflow/pre_processing.py +1332 -0
  48. inference_models/models/common/torch.py +20 -0
  49. inference_models/models/common/trt.py +266 -0
  50. inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
  51. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
  52. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
  53. inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
  54. inference_models/models/depth_anything_v2/__init__.py +0 -0
  55. inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
  56. inference_models/models/dinov3/__init__.py +0 -0
  57. inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
  58. inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
  59. inference_models/models/doctr/__init__.py +0 -0
  60. inference_models/models/doctr/doctr_torch.py +304 -0
  61. inference_models/models/easy_ocr/__init__.py +0 -0
  62. inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
  63. inference_models/models/florence2/__init__.py +0 -0
  64. inference_models/models/florence2/florence2_hf.py +897 -0
  65. inference_models/models/grounding_dino/__init__.py +0 -0
  66. inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
  67. inference_models/models/l2cs/__init__.py +0 -0
  68. inference_models/models/l2cs/l2cs_onnx.py +216 -0
  69. inference_models/models/mediapipe_face_detection/__init__.py +0 -0
  70. inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
  71. inference_models/models/moondream2/__init__.py +0 -0
  72. inference_models/models/moondream2/moondream2_hf.py +281 -0
  73. inference_models/models/owlv2/__init__.py +0 -0
  74. inference_models/models/owlv2/cache.py +182 -0
  75. inference_models/models/owlv2/entities.py +112 -0
  76. inference_models/models/owlv2/owlv2_hf.py +695 -0
  77. inference_models/models/owlv2/reference_dataset.py +291 -0
  78. inference_models/models/paligemma/__init__.py +0 -0
  79. inference_models/models/paligemma/paligemma_hf.py +209 -0
  80. inference_models/models/perception_encoder/__init__.py +0 -0
  81. inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
  82. inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
  83. inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
  84. inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
  85. inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
  86. inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
  87. inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
  88. inference_models/models/qwen25vl/__init__.py +1 -0
  89. inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
  90. inference_models/models/resnet/__init__.py +0 -0
  91. inference_models/models/resnet/resnet_classification_onnx.py +330 -0
  92. inference_models/models/resnet/resnet_classification_torch.py +305 -0
  93. inference_models/models/resnet/resnet_classification_trt.py +369 -0
  94. inference_models/models/rfdetr/__init__.py +0 -0
  95. inference_models/models/rfdetr/backbone_builder.py +101 -0
  96. inference_models/models/rfdetr/class_remapping.py +41 -0
  97. inference_models/models/rfdetr/common.py +115 -0
  98. inference_models/models/rfdetr/default_labels.py +108 -0
  99. inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
  100. inference_models/models/rfdetr/misc.py +26 -0
  101. inference_models/models/rfdetr/ms_deform_attn.py +180 -0
  102. inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
  103. inference_models/models/rfdetr/position_encoding.py +166 -0
  104. inference_models/models/rfdetr/post_processor.py +83 -0
  105. inference_models/models/rfdetr/projector.py +373 -0
  106. inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
  107. inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
  108. inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
  109. inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
  110. inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
  111. inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
  112. inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
  113. inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
  114. inference_models/models/rfdetr/segmentation_head.py +273 -0
  115. inference_models/models/rfdetr/transformer.py +767 -0
  116. inference_models/models/roboflow_instant/__init__.py +0 -0
  117. inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
  118. inference_models/models/sam/__init__.py +0 -0
  119. inference_models/models/sam/cache.py +147 -0
  120. inference_models/models/sam/entities.py +25 -0
  121. inference_models/models/sam/sam_torch.py +675 -0
  122. inference_models/models/sam2/__init__.py +0 -0
  123. inference_models/models/sam2/cache.py +162 -0
  124. inference_models/models/sam2/entities.py +43 -0
  125. inference_models/models/sam2/sam2_torch.py +905 -0
  126. inference_models/models/sam2_rt/__init__.py +0 -0
  127. inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
  128. inference_models/models/smolvlm/__init__.py +0 -0
  129. inference_models/models/smolvlm/smolvlm_hf.py +245 -0
  130. inference_models/models/trocr/__init__.py +0 -0
  131. inference_models/models/trocr/trocr_hf.py +53 -0
  132. inference_models/models/vit/__init__.py +0 -0
  133. inference_models/models/vit/vit_classification_huggingface.py +319 -0
  134. inference_models/models/vit/vit_classification_onnx.py +326 -0
  135. inference_models/models/vit/vit_classification_trt.py +365 -0
  136. inference_models/models/yolact/__init__.py +1 -0
  137. inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
  138. inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
  139. inference_models/models/yolo_world/__init__.py +1 -0
  140. inference_models/models/yolonas/__init__.py +0 -0
  141. inference_models/models/yolonas/nms.py +44 -0
  142. inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
  143. inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
  144. inference_models/models/yolov10/__init__.py +0 -0
  145. inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
  146. inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
  147. inference_models/models/yolov11/__init__.py +0 -0
  148. inference_models/models/yolov11/yolov11_onnx.py +28 -0
  149. inference_models/models/yolov11/yolov11_torch_script.py +25 -0
  150. inference_models/models/yolov11/yolov11_trt.py +21 -0
  151. inference_models/models/yolov12/__init__.py +0 -0
  152. inference_models/models/yolov12/yolov12_onnx.py +7 -0
  153. inference_models/models/yolov12/yolov12_torch_script.py +7 -0
  154. inference_models/models/yolov12/yolov12_trt.py +7 -0
  155. inference_models/models/yolov5/__init__.py +0 -0
  156. inference_models/models/yolov5/nms.py +99 -0
  157. inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
  158. inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
  159. inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
  160. inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
  161. inference_models/models/yolov7/__init__.py +0 -0
  162. inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
  163. inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
  164. inference_models/models/yolov8/__init__.py +0 -0
  165. inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
  166. inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
  167. inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
  168. inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
  169. inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
  170. inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
  171. inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
  172. inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
  173. inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
  174. inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
  175. inference_models/models/yolov9/__init__.py +0 -0
  176. inference_models/models/yolov9/yolov9_onnx.py +7 -0
  177. inference_models/models/yolov9/yolov9_torch_script.py +7 -0
  178. inference_models/models/yolov9/yolov9_trt.py +7 -0
  179. inference_models/runtime_introspection/__init__.py +0 -0
  180. inference_models/runtime_introspection/core.py +410 -0
  181. inference_models/utils/__init__.py +0 -0
  182. inference_models/utils/download.py +608 -0
  183. inference_models/utils/environment.py +28 -0
  184. inference_models/utils/file_system.py +51 -0
  185. inference_models/utils/hashing.py +7 -0
  186. inference_models/utils/imports.py +48 -0
  187. inference_models/utils/onnx_introspection.py +17 -0
  188. inference_models/weights_providers/__init__.py +0 -0
  189. inference_models/weights_providers/core.py +20 -0
  190. inference_models/weights_providers/entities.py +159 -0
  191. inference_models/weights_providers/roboflow.py +601 -0
  192. inference_models-0.18.3.dist-info/METADATA +466 -0
  193. inference_models-0.18.3.dist-info/RECORD +195 -0
  194. inference_models-0.18.3.dist-info/WHEEL +5 -0
  195. inference_models-0.18.3.dist-info/top_level.txt +1 -0
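The two diffs shown below add the Torch-backed OCR models from this release. As a minimal usage sketch (the package directory paths are placeholders; only the classes and `from_pretrained` signatures visible in the diffs below are assumed):

import torch

from inference_models.models.doctr.doctr_torch import DocTR
from inference_models.models.easy_ocr.easy_ocr_torch import EasyOCRTorch

# Each model package is a local directory holding weights plus a config file.
doctr_model = DocTR.from_pretrained("/path/to/doctr_package", device=torch.device("cpu"))
easy_ocr_model = EasyOCRTorch.from_pretrained("/path/to/easy_ocr_package")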
inference_models/models/doctr/doctr_torch.py
@@ -0,0 +1,304 @@
+ from dataclasses import dataclass
+ from typing import Callable, List, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+ from doctr.io import Document
+ from doctr.models import detection_predictor, ocr_predictor, recognition_predictor
+
+ from inference_models import Detections
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat, ImageDimensions
+ from inference_models.errors import CorruptedModelPackageError, ModelRuntimeError
+ from inference_models.models.base.documents_parsing import StructuredOCRModel
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.utils.file_system import read_json
+
+ SUPPORTED_DETECTION_MODELS = {
+     "fast_base",
+     "fast_small",
+     "fast_tiny",
+     "db_resnet50",
+     "db_resnet34",
+     "db_mobilenet_v3_large",
+     "linknet_resnet18",
+     "linknet_resnet34",
+     "linknet_resnet50",
+ }
+ SUPPORTED_RECOGNITION_MODELS = {
+     "crnn_vgg16_bn",
+     "crnn_mobilenet_v3_small",
+     "crnn_mobilenet_v3_large",
+     "master",
+     "sar_resnet31",
+     "vitstr_small",
+     "vitstr_base",
+     "parseq",
+ }
+
+
+ class DocTR(StructuredOCRModel[List[np.ndarray], ImageDimensions, Document]):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         assume_straight_pages: bool = True,
+         preserve_aspect_ratio: bool = True,
+         detection_max_batch_size: int = 2,
+         recognition_max_batch_size: int = 128,
+         **kwargs,
+     ) -> "StructuredOCRModel":
+         model_package_content = get_model_package_contents(
+             model_package_dir=model_name_or_path,
+             elements=["detection_weights.pt", "recognition_weights.pt", "config.json"],
+         )
+         config = parse_model_config(config_path=model_package_content["config.json"])
+         if config.det_model not in SUPPORTED_DETECTION_MODELS:
+             raise CorruptedModelPackageError(
+                 message=f"{config.det_model} model denoted in configuration is not supported as a DocTR detection model.",
+                 help_url="https://todo",
+             )
+         if config.rec_model not in SUPPORTED_RECOGNITION_MODELS:
+             raise CorruptedModelPackageError(
+                 message=f"{config.rec_model} model denoted in configuration is not supported as a DocTR recognition model.",
+                 help_url="https://todo",
+             )
+         # Build detection / recognition predictors without pretrained weights,
+         # then load the weights shipped in the model package.
+         det_model = detection_predictor(
+             arch=config.det_model,
+             pretrained=False,
+             assume_straight_pages=assume_straight_pages,
+             preserve_aspect_ratio=preserve_aspect_ratio,
+             batch_size=detection_max_batch_size,
+         )
+         det_model.model.to(device)
+         detector_weights = torch.load(
+             model_package_content["detection_weights.pt"],
+             weights_only=True,
+             map_location=device,
+         )
+         det_model.model.load_state_dict(detector_weights)
+         rec_model = recognition_predictor(
+             arch=config.rec_model,
+             pretrained=False,
+             batch_size=recognition_max_batch_size,
+         )
+         rec_model.model.to(device)
+         rec_weights = torch.load(
+             model_package_content["recognition_weights.pt"],
+             weights_only=True,
+             map_location=device,
+         )
+         rec_model.model.load_state_dict(rec_weights)
+         model = ocr_predictor(
+             det_arch=det_model.model,
+             reco_arch=rec_model.model,
+         ).to(device=device)
+         return cls(model=model, device=device)
+
+     def __init__(
+         self,
+         model: Callable[[List[np.ndarray]], Document],
+         device: torch.device,
+     ):
+         self._model = model
+         self._device = device
+
+     @property
+     def class_names(self) -> List[str]:
+         return ["block", "line", "word"]
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[List[np.ndarray], List[ImageDimensions]]:
+         # Normalise every supported input variant to a list of BGR numpy arrays
+         # plus the original dimensions of each image.
+         if isinstance(images, np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             if input_color_format != "bgr":
+                 images = images[:, :, ::-1]
+             h, w = images.shape[:2]
+             return [images], [ImageDimensions(height=h, width=w)]
+         if isinstance(images, torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             if len(images.shape) == 3:
+                 images = torch.unsqueeze(images, dim=0)
+             if input_color_format != "bgr":
+                 images = images[:, [2, 1, 0], :, :]
+             result = []
+             dimensions = []
+             for image in images:
+                 np_image = image.permute(1, 2, 0).cpu().numpy()
+                 result.append(np_image)
+                 dimensions.append(
+                     ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                 )
+             return result, dimensions
+         if not isinstance(images, list):
+             raise ModelRuntimeError(
+                 message="Pre-processing supports only np.ndarray or torch.Tensor inputs, or lists of them.",
+                 help_url="https://todo",
+             )
+         if not len(images):
+             raise ModelRuntimeError(
+                 message="Detected empty input to the model.", help_url="https://todo"
+             )
+         if isinstance(images[0], np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             if input_color_format != "bgr":
+                 images = [i[:, :, ::-1] for i in images]
+             dimensions = [
+                 ImageDimensions(height=i.shape[0], width=i.shape[1]) for i in images
+             ]
+             return images, dimensions
+         if isinstance(images[0], torch.Tensor):
+             result = []
+             dimensions = []
+             input_color_format = input_color_format or "rgb"
+             for image in images:
+                 if input_color_format != "bgr":
+                     image = image[[2, 1, 0], :, :]
+                 np_image = image.permute(1, 2, 0).cpu().numpy()
+                 result.append(np_image)
+                 dimensions.append(
+                     ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                 )
+             return result, dimensions
+         raise ModelRuntimeError(
+             message=f"Detected unknown input batch element: {type(images[0])}",
+             help_url="https://todo",
+         )
+
+     def forward(
+         self,
+         pre_processed_images: List[np.ndarray],
+         **kwargs,
+     ) -> Document:
+         return self._model(pre_processed_images)
+
+     def post_process(
+         self,
+         model_results: Document,
+         pre_processing_meta: List[ImageDimensions],
+         **kwargs,
+     ) -> Tuple[List[str], List[Detections]]:
+         rendered_texts, all_detections = [], []
+         for result_page, original_dimensions in zip(
+             model_results.pages, pre_processing_meta
+         ):
+             detections = []
+             rendered_texts.append(result_page.render())
+             # Emit one detection per word (class_id=2), per line (class_id=1)
+             # and per block (class_id=0); line and block confidences are the
+             # mean of their words' confidences.
+             for block in result_page.blocks:
+                 block_elements_probs = []
+                 for line in block.lines:
+                     line_elements_probs = []
+                     for word in line.words:
+                         line_elements_probs.append(word.confidence)
+                         block_elements_probs.append(word.confidence)
+                         detections.append(
+                             {
+                                 "xyxy": [
+                                     word.geometry[0][0],
+                                     word.geometry[0][1],
+                                     word.geometry[1][0],
+                                     word.geometry[1][1],
+                                 ],
+                                 "class_id": 2,
+                                 "confidence": word.confidence,
+                                 "text": word.value,
+                             }
+                         )
+                     detections.append(
+                         {
+                             "xyxy": [
+                                 line.geometry[0][0],
+                                 line.geometry[0][1],
+                                 line.geometry[1][0],
+                                 line.geometry[1][1],
+                             ],
+                             "class_id": 1,
+                             "confidence": sum(line_elements_probs)
+                             / len(line_elements_probs),
+                             "text": line.render(),
+                         }
+                     )
+                 detections.append(
+                     {
+                         "xyxy": [
+                             block.geometry[0][0],
+                             block.geometry[0][1],
+                             block.geometry[1][0],
+                             block.geometry[1][1],
+                         ],
+                         "class_id": 0,
+                         "confidence": sum(block_elements_probs)
+                         / len(block_elements_probs),
+                         "text": block.render(),
+                     }
+                 )
+             # DocTR geometry is relative (0-1); scale it back to pixels.
+             dim_tensor = torch.tensor(
+                 [
+                     original_dimensions.width,
+                     original_dimensions.height,
+                     original_dimensions.width,
+                     original_dimensions.height,
+                 ],
+                 device=self._device,
+             )
+             xyxy = (
+                 (
+                     torch.tensor([e["xyxy"] for e in detections], device=self._device)
+                     * dim_tensor
+                 )
+                 .round()
+                 .int()
+             )
+             class_id = torch.tensor(
+                 [e["class_id"] for e in detections], device=self._device
+             )
+             confidence = torch.tensor(
+                 [e["confidence"] for e in detections], device=self._device
+             )
+             data = [{"text": e["text"]} for e in detections]
+             all_detections.append(
+                 Detections(
+                     xyxy=xyxy,
+                     class_id=class_id,
+                     confidence=confidence,
+                     bboxes_metadata=data,
+                 )
+             )
+         return rendered_texts, all_detections
+
+
+ @dataclass
+ class DocTRConfig:
+     det_model: str
+     rec_model: str
+
+
+ def parse_model_config(config_path: str) -> DocTRConfig:
+     try:
+         content = read_json(path=config_path)
+         if not content:
+             raise ValueError("file is empty.")
+         if not isinstance(content, dict):
+             raise ValueError("file is malformed (not a JSON dictionary)")
+         if "det_model" not in content or "rec_model" not in content:
+             raise ValueError(
+                 "file is malformed (missing `det_model` or `rec_model` key)"
+             )
+         return DocTRConfig(
+             det_model=content["det_model"],
+             rec_model=content["rec_model"],
+         )
+     except (IOError, OSError, ValueError) as error:
+         raise CorruptedModelPackageError(
+             message=f"Config file located under path {config_path} is malformed: "
+             f"{error}. If the package is hosted on the Roboflow platform - contact support. If you created "
+             f"the model package manually, please verify its consistency against the docs.",
+             help_url="https://todo",
+         ) from error
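A sketch of the pre_process -> forward -> post_process flow defined above. The image path is a placeholder, and reading it with OpenCV is an assumption for illustration only, not a requirement of the package:

import cv2  # assumption: any BGR image loader works here
import torch

from inference_models.models.doctr.doctr_torch import DocTR

model = DocTR.from_pretrained("/path/to/doctr_package", device=torch.device("cpu"))
image = cv2.imread("/path/to/document.jpg")  # BGR, the default pre_process assumption

pre_processed, dimensions = model.pre_process(image)
document = model.forward(pre_processed)        # doctr Document for the whole batch
texts, detections = model.post_process(document, dimensions)
print(texts[0])            # full rendered page text
print(detections[0].xyxy)  # word/line/block boxes scaled to pixel coordinates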
inference_models/models/easy_ocr/easy_ocr_torch.py
@@ -0,0 +1,222 @@
+ from typing import List, Optional, Tuple, Union
+
+ import easyocr
+ import numpy as np
+ import torch
+ from pydantic import BaseModel
+
+ from inference_models import Detections, StructuredOCRModel
+ from inference_models.configuration import DEFAULT_DEVICE
+ from inference_models.entities import ColorFormat, ImageDimensions
+ from inference_models.errors import CorruptedModelPackageError, ModelRuntimeError
+ from inference_models.models.common.model_packages import get_model_package_contents
+ from inference_models.utils.file_system import read_json
+
+ Point = Tuple[int, int]
+ Coordinates = Tuple[Point, Point, Point, Point]
+ DetectedText = str
+ Confidence = float
+ EasyOCRRawPrediction = Tuple[Coordinates, DetectedText, Confidence]
+
+
+ RECOGNIZED_DETECTORS = {"craft", "dbnet18", "dbnet50"}
+
+
+ class EasyOcrConfig(BaseModel):
+     lang_list: List[str]
+     detector_model_file_name: str
+     recognition_model_file_name: str
+     detect_network: str
+     recognition_network: str
+
+
+ class EasyOCRTorch(
+     StructuredOCRModel[
+         List[np.ndarray], ImageDimensions, List[List[EasyOCRRawPrediction]]
+     ]
+ ):
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name_or_path: str,
+         device: torch.device = DEFAULT_DEVICE,
+         **kwargs,
+     ) -> "StructuredOCRModel":
+         package_contents = get_model_package_contents(
+             model_package_dir=model_name_or_path, elements=["easy-ocr-config.json"]
+         )
+         config = parse_easy_ocr_config(
+             config_path=package_contents["easy-ocr-config.json"]
+         )
+         device_string = device.type
+         if device.type == "cuda" and device.index:
+             device_string = f"{device_string}:{device.index}"
+         try:
+             model = easyocr.Reader(
+                 config.lang_list,
+                 download_enabled=False,
+                 model_storage_directory=model_name_or_path,
+                 user_network_directory=model_name_or_path,
+                 detect_network=config.detect_network,
+                 recog_network=config.recognition_network,
+                 detector=True,
+                 recognizer=True,
+                 gpu=device_string,
+             )
+         except Exception as error:
+             raise CorruptedModelPackageError(
+                 message=f"EasyOCR model package is broken - could not initialize the EasyOCR reader. Error: {error}. "
+                 f"If you attempt to run `inference-models` locally - inspect the contents of the local directory to "
+                 f"verify the model package. If you run the model on the Roboflow platform - contact us.",
+                 help_url="https://todo",
+             ) from error
+         return cls(model=model, device=device)
+
+     def __init__(
+         self,
+         model: easyocr.Reader,
+         device: torch.device,
+     ):
+         self._model = model
+         self._device = device
+
+     @property
+     def class_names(self) -> List[str]:
+         return ["text-region"]
+
+     def pre_process(
+         self,
+         images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+         input_color_format: Optional[ColorFormat] = None,
+         **kwargs,
+     ) -> Tuple[List[np.ndarray], List[ImageDimensions]]:
+         # Normalise every supported input variant to a list of BGR numpy arrays
+         # plus the original dimensions of each image.
+         if isinstance(images, np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             if input_color_format != "bgr":
+                 images = images[:, :, ::-1]
+             h, w = images.shape[:2]
+             return [images], [ImageDimensions(height=h, width=w)]
+         if isinstance(images, torch.Tensor):
+             input_color_format = input_color_format or "rgb"
+             if len(images.shape) == 3:
+                 images = torch.unsqueeze(images, dim=0)
+             if input_color_format != "bgr":
+                 images = images[:, [2, 1, 0], :, :]
+             result = []
+             dimensions = []
+             for image in images:
+                 np_image = image.permute(1, 2, 0).cpu().numpy()
+                 result.append(np_image)
+                 dimensions.append(
+                     ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                 )
+             return result, dimensions
+         if not isinstance(images, list):
+             raise ModelRuntimeError(
+                 message="Pre-processing supports only np.ndarray or torch.Tensor inputs, or lists of them.",
+                 help_url="https://todo",
+             )
+         if not len(images):
+             raise ModelRuntimeError(
+                 message="Detected empty input to the model.", help_url="https://todo"
+             )
+         if isinstance(images[0], np.ndarray):
+             input_color_format = input_color_format or "bgr"
+             if input_color_format != "bgr":
+                 images = [i[:, :, ::-1] for i in images]
+             dimensions = [
+                 ImageDimensions(height=i.shape[0], width=i.shape[1]) for i in images
+             ]
+             return images, dimensions
+         if isinstance(images[0], torch.Tensor):
+             result = []
+             dimensions = []
+             input_color_format = input_color_format or "rgb"
+             for image in images:
+                 if input_color_format != "bgr":
+                     image = image[[2, 1, 0], :, :]
+                 np_image = image.permute(1, 2, 0).cpu().numpy()
+                 result.append(np_image)
+                 dimensions.append(
+                     ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                 )
+             return result, dimensions
+         raise ModelRuntimeError(
+             message=f"Detected unknown input batch element: {type(images[0])}",
+             help_url="https://todo",
+         )
+
+     def forward(
+         self, pre_processed_images: List[np.ndarray], **kwargs
+     ) -> List[List[EasyOCRRawPrediction]]:
+         all_results = []
+         for image in pre_processed_images:
+             image_results_raw = self._model.readtext(image)
+             # Cast numpy scalars returned by EasyOCR into plain Python numbers.
+             image_results_parsed = [
+                 (
+                     [
+                         [x.item() if not isinstance(x, (int, float)) else x for x in c]
+                         for c in res[0]
+                     ],
+                     res[1],
+                     res[2].item() if not isinstance(res[2], (int, float)) else res[2],
+                 )
+                 for res in image_results_raw
+             ]
+             all_results.append(image_results_parsed)
+         return all_results
+
+     def post_process(
+         self,
+         model_results: List[List[EasyOCRRawPrediction]],
+         pre_processing_meta: List[ImageDimensions],
+         confidence_threshold: float = 0.3,
+         text_regions_separator: str = " ",
+         **kwargs,
+     ) -> Tuple[List[str], List[Detections]]:
+         rendered_texts, all_detections = [], []
+         for single_image_result, original_dimensions in zip(
+             model_results, pre_processing_meta
+         ):
+             whole_image_text = []
+             xyxy = []
+             confidence = []
+             class_id = []
+             for box, text, text_confidence in single_image_result:
+                 if text_confidence < confidence_threshold:
+                     continue
+                 whole_image_text.append(text)
+                 # Collapse the quadrilateral into an axis-aligned xyxy box.
+                 min_x = min(p[0] for p in box)
+                 min_y = min(p[1] for p in box)
+                 max_x = max(p[0] for p in box)
+                 max_y = max(p[1] for p in box)
+                 box_xyxy = [min_x, min_y, max_x, max_y]
+                 xyxy.append(box_xyxy)
+                 confidence.append(float(text_confidence))
+                 class_id.append(0)
+             whole_image_text_joined = text_regions_separator.join(whole_image_text)
+             rendered_texts.append(whole_image_text_joined)
+             data = [{"text": text} for text in whole_image_text]
+             all_detections.append(
+                 Detections(
+                     xyxy=torch.tensor(xyxy, device=self._device),
+                     class_id=torch.tensor(class_id, device=self._device),
+                     confidence=torch.tensor(confidence, device=self._device),
+                     bboxes_metadata=data,
+                 )
+             )
+         return rendered_texts, all_detections
+
+
+ def parse_easy_ocr_config(config_path: str) -> EasyOcrConfig:
+     try:
+         raw_config = read_json(config_path)
+         return EasyOcrConfig.model_validate(raw_config)
+     except Exception as error:
+         raise CorruptedModelPackageError(
+             message=f"EasyOCR model package is broken - could not parse model config file. Error: {error}. "
+             f"If you attempt to run `inference-models` locally - inspect the contents of the local directory to "
+             f"verify the model package - the config file is corrupted. If you run the model on the Roboflow "
+             f"platform - contact us.",
+             help_url="https://todo",
+         ) from error
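For reference, a hypothetical `easy-ocr-config.json` that satisfies the `EasyOcrConfig` schema above; the file names and network names below are illustrative only, not taken from a real model package:

import json

from inference_models.models.easy_ocr.easy_ocr_torch import parse_easy_ocr_config

example_config = {
    "lang_list": ["en"],
    "detector_model_file_name": "craft_mlt_25k.pth",
    "recognition_model_file_name": "english_g2.pth",
    "detect_network": "craft",
    "recognition_network": "english_g2",
}
with open("easy-ocr-config.json", "w") as f:
    json.dump(example_config, f)

config = parse_easy_ocr_config(config_path="easy-ocr-config.json")
print(config.lang_list, config.detect_network)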