inference-models 0.18.3 (inference_models-0.18.3-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0

inference_models/models/doctr/doctr_torch.py
@@ -0,0 +1,304 @@
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from doctr.io import Document
+from doctr.models import detection_predictor, ocr_predictor, recognition_predictor
+
+from inference_models import Detections
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat, ImageDimensions
+from inference_models.errors import CorruptedModelPackageError, ModelRuntimeError
+from inference_models.models.base.documents_parsing import StructuredOCRModel
+from inference_models.models.common.model_packages import get_model_package_contents
+from inference_models.utils.file_system import read_json
+
+SUPPORTED_DETECTION_MODELS = {
+    "fast_base",
+    "fast_small",
+    "fast_tiny",
+    "db_resnet50",
+    "db_resnet34",
+    "db_mobilenet_v3_large",
+    "linknet_resnet18",
+    "linknet_resnet34",
+    "linknet_resnet50",
+}
+SUPPORTED_RECOGNITION_MODELS = {
+    "crnn_vgg16_bn",
+    "crnn_mobilenet_v3_small",
+    "crnn_mobilenet_v3_large",
+    "master",
+    "sar_resnet31",
+    "vitstr_small",
+    "vitstr_base",
+    "parseq",
+}
+
+
+class DocTR(StructuredOCRModel[List[np.ndarray], ImageDimensions, Document]):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        assume_straight_pages: bool = True,
+        preserve_aspect_ratio: bool = True,
+        detection_max_batch_size: int = 2,
+        recognition_max_batch_size: int = 128,
+        **kwargs,
+    ) -> "StructuredOCRModel":
+        model_package_content = get_model_package_contents(
+            model_package_dir=model_name_or_path,
+            elements=["detection_weights.pt", "recognition_weights.pt", "config.json"],
+        )
+        config = parse_model_config(config_path=model_package_content["config.json"])
+        if config.det_model not in SUPPORTED_DETECTION_MODELS:
+            raise CorruptedModelPackageError(
+                message=f"{config.det_model} model denoted in configuration not supported as DocTR detection model.",
+                help_url="https://todo",
+            )
+        if config.rec_model not in SUPPORTED_RECOGNITION_MODELS:
+            raise CorruptedModelPackageError(
+                message=f"{config.rec_model} model denoted in configuration not supported as DocTR recognition model.",
+                help_url="https://todo",
+            )
+        det_model = detection_predictor(
+            arch=config.det_model,
+            pretrained=False,
+            assume_straight_pages=assume_straight_pages,
+            preserve_aspect_ratio=preserve_aspect_ratio,
+            batch_size=detection_max_batch_size,
+        )
+        det_model.model.to(device)
+        detector_weights = torch.load(
+            model_package_content["detection_weights.pt"],
+            weights_only=True,
+            map_location=device,
+        )
+        det_model.model.load_state_dict(detector_weights)
+        rec_model = recognition_predictor(
+            arch=config.rec_model,
+            pretrained=False,
+            batch_size=recognition_max_batch_size,
+        )
+        rec_model.model.to(device)
+        rec_weights = torch.load(
+            model_package_content["recognition_weights.pt"],
+            weights_only=True,
+            map_location=device,
+        )
+        rec_model.model.load_state_dict(rec_weights)
+        model = ocr_predictor(
+            det_arch=det_model.model,
+            reco_arch=rec_model.model,
+        ).to(device=device)
+        return cls(model=model, device=device)
+
+    def __init__(
+        self,
+        model: Callable[[List[np.ndarray]], Document],
+        device: torch.device,
+    ):
+        self._model = model
+        self._device = device
+
+    @property
+    def class_names(self) -> List[str]:
+        return ["block", "line", "word"]
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        **kwargs,
+    ) -> Tuple[List[np.ndarray], List[ImageDimensions]]:
+        if isinstance(images, np.ndarray):
+            input_color_format = input_color_format or "bgr"
+            if input_color_format != "bgr":
+                images = images[:, :, ::-1]
+            h, w = images.shape[:2]
+            return [images], [ImageDimensions(height=h, width=w)]
+        if isinstance(images, torch.Tensor):
+            input_color_format = input_color_format or "rgb"
+            if len(images.shape) == 3:
+                images = torch.unsqueeze(images, dim=0)
+            if input_color_format != "bgr":
+                images = images[:, [2, 1, 0], :, :]
+            result = []
+            dimensions = []
+            for image in images:
+                np_image = image.permute(1, 2, 0).cpu().numpy()
+                result.append(np_image)
+                dimensions.append(
+                    ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                )
+            return result, dimensions
+        if not isinstance(images, list):
+            raise ModelRuntimeError(
+                message="Pre-processing supports only np.array or torch.Tensor or list of above.",
+                help_url="https://todo",
+            )
+        if not len(images):
+            raise ModelRuntimeError(
+                message="Detected empty input to the model", help_url="https://todo"
+            )
+        if isinstance(images[0], np.ndarray):
+            input_color_format = input_color_format or "bgr"
+            if input_color_format != "bgr":
+                images = [i[:, :, ::-1] for i in images]
+            dimensions = [
+                ImageDimensions(height=i.shape[0], width=i.shape[1]) for i in images
+            ]
+            return images, dimensions
+        if isinstance(images[0], torch.Tensor):
+            result = []
+            dimensions = []
+            input_color_format = input_color_format or "rgb"
+            for image in images:
+                if input_color_format != "bgr":
+                    image = image[[2, 1, 0], :, :]
+                np_image = image.permute(1, 2, 0).cpu().numpy()
+                result.append(np_image)
+                dimensions.append(
+                    ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                )
+            return result, dimensions
+        raise ModelRuntimeError(
+            message=f"Detected unknown input batch element: {type(images[0])}",
+            help_url="https://todo",
+        )
+
+    def forward(
+        self,
+        pre_processed_images: List[np.ndarray],
+        **kwargs,
+    ) -> Document:
+        return self._model(pre_processed_images)
+
+    def post_process(
+        self,
+        model_results: Document,
+        pre_processing_meta: List[ImageDimensions],
+        **kwargs,
+    ) -> Tuple[List[str], List[Detections]]:
+        rendered_texts, all_detections = [], []
+        for result_page, original_dimensions in zip(
+            model_results.pages, pre_processing_meta
+        ):
+            detections = []
+            rendered_texts.append(result_page.render())
+            for block in result_page.blocks:
+                block_elements_probs = []
+                for line in block.lines:
+                    line_elements_probs = []
+                    for word in line.words:
+                        line_elements_probs.append(word.confidence)
+                        block_elements_probs.append(word.confidence)
+                        detections.append(
+                            {
+                                "xyxy": [
+                                    word.geometry[0][0],
+                                    word.geometry[0][1],
+                                    word.geometry[1][0],
+                                    word.geometry[1][1],
+                                ],
+                                "class_id": 2,
+                                "confidence": word.confidence,
+                                "text": word.value,
+                            }
+                        )
+                    detections.append(
+                        {
+                            "xyxy": [
+                                line.geometry[0][0],
+                                line.geometry[0][1],
+                                line.geometry[1][0],
+                                line.geometry[1][1],
+                            ],
+                            "class_id": 1,
+                            "confidence": sum(line_elements_probs)
+                            / len(line_elements_probs),
+                            "text": line.render(),
+                        }
+                    )
+                detections.append(
+                    {
+                        "xyxy": [
+                            block.geometry[0][0],
+                            block.geometry[0][1],
+                            block.geometry[1][0],
+                            block.geometry[1][1],
+                        ],
+                        "class_id": 0,
+                        "confidence": sum(block_elements_probs)
+                        / len(block_elements_probs),
+                        "text": block.render(),
+                    }
+                )
+            dim_tensor = torch.tensor(
+                [
+                    original_dimensions.width,
+                    original_dimensions.height,
+                    original_dimensions.width,
+                    original_dimensions.height,
+                ],
+                device=self._device,
+            )
+            xyxy = (
+                (
+                    torch.tensor([e["xyxy"] for e in detections], device=self._device)
+                    * dim_tensor
+                )
+                .round()
+                .int()
+            )
+            class_id = torch.tensor(
+                [e["class_id"] for e in detections], device=self._device
+            )
+            confidence = torch.tensor(
+                [e["confidence"] for e in detections], device=self._device
+            )
+            data = [{"text": e["text"]} for e in detections]
+            all_detections.append(
+                Detections(
+                    xyxy=xyxy,
+                    class_id=class_id,
+                    confidence=confidence,
+                    bboxes_metadata=data,
+                )
+            )
+        return rendered_texts, all_detections
+
+
+@dataclass
+class DocTRConfig:
+    det_model: str
+    rec_model: str
+
+
+def parse_model_config(config_path: str) -> DocTRConfig:
+    try:
+        content = read_json(path=config_path)
+        if not content:
+            raise ValueError("file is empty.")
+        if not isinstance(content, dict):
+            raise ValueError("file is malformed (not a JSON dictionary)")
+        if "det_model" not in content or "rec_model" not in content:
+            raise ValueError(
+                "file is malformed (lack of `det_model` or `rec_model` key)"
+            )
+        return DocTRConfig(
+            det_model=content["det_model"],
+            rec_model=content["rec_model"],
+        )
+    except (IOError, OSError, ValueError) as error:
+        raise CorruptedModelPackageError(
+            message=f"Config file located under path {config_path} is malformed: "
+            f"{error}. In case that the package is "
+            f"hosted on the Roboflow platform - contact support. If you created model package manually, please "
+            f"verify its consistency in docs.",
+            help_url="https://todo",
+        ) from error
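
The DocTR integration above expects a local model package directory containing detection_weights.pt, recognition_weights.pt and a config.json whose det_model / rec_model values come from the supported sets. Below is a minimal usage sketch, not part of the package: the directory path, image file and chosen architectures ("db_resnet50", "crnn_vgg16_bn") are hypothetical assumptions, and it simply chains the pre_process / forward / post_process methods shown in the diff.

# A minimal sketch, assuming a prepared DocTR model package directory.
# config.json is expected (per parse_model_config above) to look like:
#   {"det_model": "db_resnet50", "rec_model": "crnn_vgg16_bn"}
import cv2  # assumed available for image loading; any BGR np.ndarray works
import torch

from inference_models.models.doctr.doctr_torch import DocTR

package_dir = "/path/to/doctr_model_package"  # hypothetical path holding the three files above

model = DocTR.from_pretrained(package_dir, device=torch.device("cpu"))

image = cv2.imread("document.jpg")  # BGR np.ndarray - the default input_color_format
pre_processed, dimensions = model.pre_process(image)
document = model.forward(pre_processed)
rendered_texts, detections = model.post_process(document, dimensions)

print(rendered_texts[0])       # full rendered page text
print(detections[0].class_id)  # 0=block, 1=line, 2=word, per class_names above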
File without changes

inference_models/models/easy_ocr/easy_ocr_torch.py
@@ -0,0 +1,222 @@
+from typing import List, Optional, Tuple, Union
+
+import easyocr
+import numpy as np
+import torch
+from pydantic import BaseModel
+
+from inference_models import Detections, StructuredOCRModel
+from inference_models.configuration import DEFAULT_DEVICE
+from inference_models.entities import ColorFormat, ImageDimensions
+from inference_models.errors import CorruptedModelPackageError, ModelRuntimeError
+from inference_models.models.common.model_packages import get_model_package_contents
+from inference_models.utils.file_system import read_json
+
+Point = Tuple[int, int]
+Coordinates = Tuple[Point, Point, Point, Point]
+DetectedText = str
+Confidence = float
+EasyOCRRawPrediction = Tuple[Coordinates, DetectedText, Confidence]
+
+
+RECOGNIZED_DETECTORS = {"craft", "dbnet18", "dbnet50"}
+
+
+class EasyOcrConfig(BaseModel):
+    lang_list: List[str]
+    detector_model_file_name: str
+    recognition_model_file_name: str
+    detect_network: str
+    recognition_network: str
+
+
+class EasyOCRTorch(
+    StructuredOCRModel[List[np.ndarray], ImageDimensions, EasyOCRRawPrediction]
+):
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: str,
+        device: torch.device = DEFAULT_DEVICE,
+        **kwargs,
+    ) -> "StructuredOCRModel":
+        package_contents = get_model_package_contents(
+            model_package_dir=model_name_or_path, elements=["easy-ocr-config.json"]
+        )
+        config = parse_easy_ocr_config(
+            config_path=package_contents["easy-ocr-config.json"]
+        )
+        device_string = device.type
+        if device.type == "cuda" and device.index:
+            device_string = f"{device_string}:{device.index}"
+        try:
+            model = easyocr.Reader(
+                config.lang_list,
+                download_enabled=False,
+                model_storage_directory=model_name_or_path,
+                user_network_directory=model_name_or_path,
+                detect_network=config.detect_network,
+                recog_network=config.recognition_network,
+                detector=True,
+                recognizer=True,
+                gpu=device_string,
+            )
+        except Exception as error:
+            raise CorruptedModelPackageError(
+                message=f"EasyOCR model package is broken - could not parse model config file. Error: {error}"
+                f"If you attempt to run `inference-models` locally - inspect the contents of local directory to check "
+                f"model package - config file is corrupted. If you run the model on Roboflow platform - "
+                f"contact us.",
+                help_url="https://todo",
+            ) from error
+        return cls(model=model, device=device)
+
+    def __init__(
+        self,
+        model: easyocr.Reader,
+        device: torch.device,
+    ):
+        self._model = model
+        self._device = device
+
+    @property
+    def class_names(self) -> List[str]:
+        return ["text-region"]
+
+    def pre_process(
+        self,
+        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
+        input_color_format: Optional[ColorFormat] = None,
+        **kwargs,
+    ) -> Tuple[List[np.ndarray], List[ImageDimensions]]:
+        if isinstance(images, np.ndarray):
+            input_color_format = input_color_format or "bgr"
+            if input_color_format != "bgr":
+                images = images[:, :, ::-1]
+            h, w = images.shape[:2]
+            return [images], [ImageDimensions(height=h, width=w)]
+        if isinstance(images, torch.Tensor):
+            input_color_format = input_color_format or "rgb"
+            if len(images.shape) == 3:
+                images = torch.unsqueeze(images, dim=0)
+            if input_color_format != "bgr":
+                images = images[:, [2, 1, 0], :, :]
+            result = []
+            dimensions = []
+            for image in images:
+                np_image = image.permute(1, 2, 0).cpu().numpy()
+                result.append(np_image)
+                dimensions.append(
+                    ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                )
+            return result, dimensions
+        if not isinstance(images, list):
+            raise ModelRuntimeError(
+                message="Pre-processing supports only np.array or torch.Tensor or list of above.",
+                help_url="https://todo",
+            )
+        if not len(images):
+            raise ModelRuntimeError(
+                message="Detected empty input to the model", help_url="https://todo"
+            )
+        if isinstance(images[0], np.ndarray):
+            input_color_format = input_color_format or "bgr"
+            if input_color_format != "bgr":
+                images = [i[:, :, ::-1] for i in images]
+            dimensions = [
+                ImageDimensions(height=i.shape[0], width=i.shape[1]) for i in images
+            ]
+            return images, dimensions
+        if isinstance(images[0], torch.Tensor):
+            result = []
+            dimensions = []
+            input_color_format = input_color_format or "rgb"
+            for image in images:
+                if input_color_format != "bgr":
+                    image = image[[2, 1, 0], :, :]
+                np_image = image.permute(1, 2, 0).cpu().numpy()
+                result.append(np_image)
+                dimensions.append(
+                    ImageDimensions(height=np_image.shape[0], width=np_image.shape[1])
+                )
+            return result, dimensions
+        raise ModelRuntimeError(
+            message=f"Detected unknown input batch element: {type(images[0])}",
+            help_url="https://todo",
+        )
+
+    def forward(
+        self, pre_processed_images: List[np.ndarray], **kwargs
+    ) -> List[EasyOCRRawPrediction]:
+        all_results = []
+        for image in pre_processed_images:
+            image_results_raw = self._model.readtext(image)
+            image_results_parsed = [
+                (
+                    [
+                        [x.item() if not isinstance(x, (int, float)) else x for x in c]
+                        for c in res[0]
+                    ],
+                    res[1],
+                    res[2].item() if not isinstance(res[2], (int, float)) else res[2],
+                )
+                for res in image_results_raw
+            ]
+            all_results.append(image_results_parsed)
+        return all_results
+
+    def post_process(
+        self,
+        model_results: List[EasyOCRRawPrediction],
+        pre_processing_meta: List[ImageDimensions],
+        confidence_threshold: float = 0.3,
+        text_regions_separator: str = " ",
+        **kwargs,
+    ) -> Tuple[List[str], List[Detections]]:
+        rendered_texts, all_detections = [], []
+        for single_image_result, original_dimensions in zip(
+            model_results, pre_processing_meta
+        ):
+            whole_image_text = []
+            xyxy = []
+            confidence = []
+            class_id = []
+            for box, text, text_confidence in single_image_result:
+                if text_confidence < confidence_threshold:
+                    continue
+                whole_image_text.append(text)
+                min_x = min(p[0] for p in box)
+                min_y = min(p[1] for p in box)
+                max_x = max(p[0] for p in box)
+                max_y = max(p[1] for p in box)
+                box_xyxy = [min_x, min_y, max_x, max_y]
+                xyxy.append(box_xyxy)
+                confidence.append(float(text_confidence))
+                class_id.append(0)
+            while_image_text_joined = text_regions_separator.join(whole_image_text)
+            rendered_texts.append(while_image_text_joined)
+            data = [{"text": text} for text in whole_image_text]
+            all_detections.append(
+                Detections(
+                    xyxy=torch.tensor(xyxy, device=self._device),
+                    class_id=torch.tensor(class_id, device=self._device),
+                    confidence=torch.tensor(confidence, device=self._device),
+                    bboxes_metadata=data,
+                )
+            )
+        return rendered_texts, all_detections
+
+
+def parse_easy_ocr_config(config_path: str) -> EasyOcrConfig:
+    try:
+        raw_config = read_json(config_path)
+        return EasyOcrConfig.model_validate(raw_config)
+    except Exception as error:
+        raise CorruptedModelPackageError(
+            message=f"EasyOCR model package is broken - could not parse model config file. Error: {error}"
+            f"If you attempt to run `inference-models` locally - inspect the contents of local directory to check "
+            f"model package - config file is corrupted. If you run the model on Roboflow platform - "
+            f"contact us.",
+            help_url="https://todo",
+        ) from error

File without changes
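
For the EasyOCR integration, parse_easy_ocr_config validates the package's easy-ocr-config.json against the EasyOcrConfig schema above. The sketch below gives a hypothetical example of that configuration plus a minimal inference loop; the language list, weight file names and paths are illustrative assumptions, not values shipped with this package.

# A minimal sketch, assuming a prepared EasyOCR model package directory.
# easy-ocr-config.json must provide every EasyOcrConfig field, for example:
#   {
#       "lang_list": ["en"],
#       "detector_model_file_name": "craft_mlt_25k.pth",
#       "recognition_model_file_name": "english_g2.pth",
#       "detect_network": "craft",
#       "recognition_network": "english_g2"
#   }
# (the referenced weight files must already be present in the package directory)
import cv2  # assumed available for image loading; any BGR np.ndarray works
import torch

from inference_models.models.easy_ocr.easy_ocr_torch import EasyOCRTorch

package_dir = "/path/to/easyocr_model_package"  # hypothetical local package directory

model = EasyOCRTorch.from_pretrained(package_dir, device=torch.device("cpu"))

image = cv2.imread("sign.jpg")  # BGR np.ndarray - the default input_color_format
pre_processed, dimensions = model.pre_process(image)
raw_predictions = model.forward(pre_processed)
texts, detections = model.post_process(
    raw_predictions, dimensions, confidence_threshold=0.3
)
print(texts[0])  # detected text regions joined with text_regions_separator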