inference_models-0.18.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py
@@ -0,0 +1,206 @@
import threading
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from inference_models import InstanceDetections, InstanceSegmentationModel
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.entities import ColorFormat
from inference_models.errors import (
    EnvironmentConfigurationError,
    MissingDependencyError,
)
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.common.onnx import (
    run_session_with_batch_size_limit,
    set_execution_provider_defaults,
)
from inference_models.models.common.roboflow.model_packages import (
    InferenceConfig,
    PreProcessingMetadata,
    ResizeMode,
    parse_class_names_file,
    parse_inference_config,
)
from inference_models.models.common.roboflow.pre_processing import (
    pre_process_network_input,
)
from inference_models.models.rfdetr.class_remapping import (
    ClassesReMapping,
    prepare_class_remapping,
)
from inference_models.models.rfdetr.common import (
    post_process_instance_segmentation_results,
)
from inference_models.utils.onnx_introspection import (
    get_selected_onnx_execution_providers,
)

try:
    import onnxruntime
except ImportError as import_error:
    raise MissingDependencyError(
        message=f"Could not import RFDetr model with ONNX backend - this error means that some additional dependencies "
        f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
        f"program, make sure the following extras of the package are installed: \n"
        f"\t* `onnx-cpu` - when you wish to use the library with CPU support only\n"
        f"\t* `onnx-cu12` - for running on GPU with Cuda 12 installed\n"
        f"\t* `onnx-cu118` - for running on GPU with Cuda 11.8 installed\n"
        f"\t* `onnx-jp6-cu126` - for running on Jetson with Jetpack 6\n"
        f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
        f"You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error


class RFDetrForInstanceSegmentationOnnx(
    InstanceSegmentationModel[
        torch.Tensor,
        PreProcessingMetadata,
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ]
):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        onnx_execution_providers: Optional[List[Union[str, tuple]]] = None,
        default_onnx_trt_options: bool = True,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "RFDetrForInstanceSegmentationOnnx":
        if onnx_execution_providers is None:
            onnx_execution_providers = get_selected_onnx_execution_providers()
        if not onnx_execution_providers:
            raise EnvironmentConfigurationError(
                message=f"Could not initialize model - selected backend is ONNX which requires execution provider to "
                f"be specified - explicitly in `from_pretrained(...)` method or via env variable "
                f"`ONNXRUNTIME_EXECUTION_PROVIDERS`. If you run the model locally - adjust your setup, otherwise "
                f"contact the platform support.",
                help_url="https://todo",
            )
        onnx_execution_providers = set_execution_provider_defaults(
            providers=onnx_execution_providers,
            model_package_path=model_name_or_path,
            device=device,
            default_onnx_trt_options=default_onnx_trt_options,
        )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "weights.onnx",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        classes_re_mapping = None
        if inference_config.class_names_operations:
            class_names, classes_re_mapping = prepare_class_remapping(
                class_names=class_names,
                class_names_operations=inference_config.class_names_operations,
                device=device,
            )
        session = onnxruntime.InferenceSession(
            path_or_bytes=model_package_content["weights.onnx"],
            providers=onnx_execution_providers,
        )
        input_batch_size = session.get_inputs()[0].shape[0]
        if isinstance(input_batch_size, str):
            input_batch_size = None
        input_name = session.get_inputs()[0].name
        return cls(
            session=session,
            input_name=input_name,
            class_names=class_names,
            classes_re_mapping=classes_re_mapping,
            inference_config=inference_config,
            device=device,
            input_batch_size=input_batch_size,
        )

    def __init__(
        self,
        session: onnxruntime.InferenceSession,
        input_name: str,
        class_names: List[str],
        classes_re_mapping: Optional[ClassesReMapping],
        inference_config: InferenceConfig,
        device: torch.device,
        input_batch_size: Optional[int],
    ):
        self._session = session
        self._input_name = input_name
        self._inference_config = inference_config
        self._class_names = class_names
        self._classes_re_mapping = classes_re_mapping
        self._device = device
        self._min_batch_size = input_batch_size
        self._max_batch_size = (
            input_batch_size
            if input_batch_size is not None
            else inference_config.forward_pass.max_dynamic_batch_size
        )
        self._session_thread_lock = threading.Lock()

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )

    def forward(
        self, pre_processed_images: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        with self._session_thread_lock:
            bboxes, logits, masks = run_session_with_batch_size_limit(
                session=self._session,
                inputs={self._input_name: pre_processed_images},
                min_batch_size=self._min_batch_size,
                max_batch_size=self._max_batch_size,
            )
            return bboxes, logits, masks

    def post_process(
        self,
        model_results: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        pre_processing_meta: List[PreProcessingMetadata],
        threshold: float = 0.5,
        **kwargs,
    ) -> List[InstanceDetections]:
        bboxes, logits, masks = model_results
        return post_process_instance_segmentation_results(
            bboxes=bboxes,
            logits=logits,
            masks=masks,
            pre_processing_meta=pre_processing_meta,
            threshold=threshold,
            classes_re_mapping=self._classes_re_mapping,
        )
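
For orientation, a minimal usage sketch of the class above - the model package directory, the dummy image, and the choice of CPUExecutionProvider are illustrative assumptions; the class, method names, and expected package files (class_names.txt, inference_config.json, weights.onnx) are taken from the source shown here:

import numpy as np

from inference_models.models.rfdetr.rfdetr_instance_segmentation_onnx import (
    RFDetrForInstanceSegmentationOnnx,
)

# "./rfdetr-seg-package" is a hypothetical local model package directory that
# would contain class_names.txt, inference_config.json and weights.onnx.
model = RFDetrForInstanceSegmentationOnnx.from_pretrained(
    "./rfdetr-seg-package",
    onnx_execution_providers=["CPUExecutionProvider"],
)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
tensors, metadata = model.pre_process(image)
raw_outputs = model.forward(tensors)  # (bboxes, logits, masks) tensors
detections = model.post_process(raw_outputs, metadata, threshold=0.5)
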
inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py
@@ -0,0 +1,373 @@
import os.path
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from inference_models import InstanceDetections, InstanceSegmentationModel
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.entities import ColorFormat
from inference_models.errors import (
    CorruptedModelPackageError,
    ModelLoadingError,
    ModelRuntimeError,
)
from inference_models.logger import LOGGER
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.common.roboflow.model_packages import (
    ColorMode,
    DivisiblePadding,
    InferenceConfig,
    NetworkInputDefinition,
    PreProcessingMetadata,
    ResizeMode,
    TrainingInputSize,
    parse_class_names_file,
    parse_inference_config,
)
from inference_models.models.common.roboflow.pre_processing import (
    pre_process_network_input,
)
from inference_models.models.rfdetr.class_remapping import (
    ClassesReMapping,
    prepare_class_remapping,
)
from inference_models.models.rfdetr.common import (
    parse_model_type,
    post_process_instance_segmentation_results,
)
from inference_models.models.rfdetr.default_labels import resolve_labels
from inference_models.models.rfdetr.post_processor import PostProcess
from inference_models.models.rfdetr.rfdetr_base_pytorch import (
    LWDETR,
    RFDETRSegPreviewConfig,
    build_model,
)

try:
    torch.set_float32_matmul_precision("high")
except:
    pass

CONFIG_FOR_MODEL_TYPE = {
    "rfdetr-seg-preview": RFDETRSegPreviewConfig,
}


class RFDetrForInstanceSegmentationTorch(
    InstanceSegmentationModel[
        torch.Tensor,
        PreProcessingMetadata,
        dict,
    ]
):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        model_type: Optional[str] = None,
        labels: Optional[Union[str, List[str]]] = None,
        resolution: Optional[int] = None,
        **kwargs,
    ) -> "RFDetrForInstanceSegmentationTorch":
        if os.path.isfile(model_name_or_path):
            return cls.from_checkpoint_file(
                checkpoint_path=model_name_or_path,
                model_type=model_type,
                labels=labels,
                resolution=resolution,
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "model_type.json",
                "weights.pth",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        classes_re_mapping = None
        if inference_config.class_names_operations:
            class_names, classes_re_mapping = prepare_class_remapping(
                class_names=class_names,
                class_names_operations=inference_config.class_names_operations,
                device=device,
            )
        weights_dict = torch.load(
            model_package_content["weights.pth"],
            map_location=device,
            weights_only=False,
        )["model"]
        model_type = parse_model_type(
            config_path=model_package_content["model_type.json"]
        )
        if model_type not in CONFIG_FOR_MODEL_TYPE:
            raise CorruptedModelPackageError(
                message=f"Model package describes model_type as '{model_type}' which is not supported. "
                f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
                help_url="https://todo",
            )
        model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
        checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
        model_config.num_classes = checkpoint_num_classes - 1
        model_config.resolution = (
            inference_config.network_input.training_input_size.height
        )
        model = build_model(config=model_config)
        model.load_state_dict(weights_dict)
        model = model.eval().to(device)
        post_processor = PostProcess()
        return cls(
            model=model,
            class_names=class_names,
            classes_re_mapping=classes_re_mapping,
            device=device,
            inference_config=inference_config,
            post_processor=post_processor,
            resolution=model_config.resolution,
        )

    @classmethod
    def from_checkpoint_file(
        cls,
        checkpoint_path: str,
        model_type: Optional[str] = "rfdetr-seg-preview",
        labels: Optional[Union[str, List[str]]] = None,
        resolution: Optional[int] = None,
        device: torch.device = DEFAULT_DEVICE,
    ):
        if model_type is None:
            raise ModelLoadingError(
                message="While loading RFDetr model (using torch backend) could not determine `model_type`. "
                "If you used `RFDetrForInstanceSegmentationTorch` directly imported in your code, please pass "
                f"one of the values: {CONFIG_FOR_MODEL_TYPE.keys()} as the parameter. If you see this "
                f"error, while using `AutoModel.from_pretrained(...)` or thrown from managed Roboflow service, "
                f"this is a bug - raise the issue: https://github.com/roboflow/inference/issue providing "
                f"full context.",
                help_url="https://todo",
            )
        weights_dict = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=False,
        )["model"]
        if model_type not in CONFIG_FOR_MODEL_TYPE:
            raise ModelLoadingError(
                message=f"Model package describes model_type as '{model_type}' which is not supported. "
                f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
                help_url="https://todo",
            )
        model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
        divisibility = model_config.num_windows * model_config.patch_size
        if resolution is not None:
            if resolution < 0 or resolution % divisibility != 0:
                raise ModelLoadingError(
                    message=f"Attempted to load RFDetr model (using torch backend) with `resolution` parameter which "
                    f"is invalid - the model requires a positive value divisible by 56. Make sure you used "
                    f"proper value, corresponding to the one used to train the model.",
                    help_url="https://todo",
                )
            model_config.resolution = resolution
        inference_config = InferenceConfig(
            network_input=NetworkInputDefinition(
                training_input_size=TrainingInputSize(
                    height=model_config.resolution,
                    width=model_config.resolution,
                ),
                dynamic_spatial_size_supported=True,
                dynamic_spatial_size_mode=DivisiblePadding(
                    type="pad-to-be-divisible",
                    value=divisibility,
                ),
                color_mode=ColorMode.BGR,
                resize_mode=ResizeMode.STRETCH_TO,
                input_channels=3,
                scaling_factor=255,
                normalization=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            )
        )
        checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
        model_config.num_classes = checkpoint_num_classes - 1
        model = build_model(config=model_config)
        if labels is None:
            class_names = [f"class_{i}" for i in range(checkpoint_num_classes)]
        elif isinstance(labels, str):
            class_names = resolve_labels(labels=labels)
        else:
            class_names = labels
        if checkpoint_num_classes != len(class_names):
            raise ModelLoadingError(
                message=f"Checkpoint pointed to load RFDetr defines {checkpoint_num_classes} output classes, but "
                f"loaded labels define {len(class_names)} classes - fix the value of `labels` parameter.",
                help_url="https://todo",
            )
        model.load_state_dict(weights_dict)
        model = model.eval().to(device)
        post_processor = PostProcess()
        return cls(
            model=model,
            class_names=class_names,
            classes_re_mapping=None,
            device=device,
            inference_config=inference_config,
            post_processor=post_processor,
            resolution=model_config.resolution,
        )

    def __init__(
        self,
        model: LWDETR,
        inference_config: InferenceConfig,
        class_names: List[str],
        classes_re_mapping: Optional[ClassesReMapping],
        device: torch.device,
        post_processor: PostProcess,
        resolution: int,
    ):
        self._model = model
        self._inference_config = inference_config
        self._class_names = class_names
        self._classes_re_mapping = classes_re_mapping
        self._post_processor = post_processor
        self._device = device
        self._resolution = resolution
        self._has_warned_about_not_being_optimized_for_inference = False
        self._inference_model: Optional[LWDETR] = None
        self._optimized_has_been_compiled = False
        self._optimized_batch_size = None
        self._optimized_dtype = None

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def optimize_for_inference(
        self,
        compile: bool = True,
        batch_size: int = 1,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.remove_optimized_model()
        self._inference_model = deepcopy(self._model)
        self._inference_model.eval()
        self._inference_model.export()
        self._inference_model = self._inference_model.to(dtype=dtype)
        self._optimized_dtype = dtype
        if compile:
            self._inference_model = torch.jit.trace(
                self._inference_model,
                torch.randn(
                    batch_size,
                    3,
                    self._resolution,
                    self._resolution,
                    device=self._device,
                    dtype=dtype,
                ),
            )
            self._optimized_has_been_compiled = True
            self._optimized_batch_size = batch_size

    def remove_optimized_model(self) -> None:
        self._has_warned_about_not_being_optimized_for_inference = False
        self._inference_model = None
        self._optimized_has_been_compiled = False
        self._optimized_batch_size = None

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )

    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> dict:
        if (
            self._inference_model is None
            and not self._has_warned_about_not_being_optimized_for_inference
        ):
            LOGGER.warning(
                "Model is not optimized for inference. "
                "Latency may be higher than expected. "
                "You can optimize the model for inference by calling model.optimize_for_inference()."
            )
            self._has_warned_about_not_being_optimized_for_inference = True
        if self._inference_model is not None:
            if (self._resolution, self._resolution) != tuple(
                pre_processed_images.shape[2:]
            ):
                raise ModelRuntimeError(
                    message=f"Resolution mismatch. Model was optimized for resolution {self._resolution}, "
                    f"but got {tuple(pre_processed_images.shape[2:])}. "
                    "You can explicitly remove the optimized model by calling model.remove_optimized_model().",
                    help_url="https://todo",
                )
            if self._optimized_has_been_compiled:
                if self._optimized_batch_size != pre_processed_images.shape[0]:
                    raise ModelRuntimeError(
                        message="Batch size mismatch. Optimized model was compiled for batch size "
                        f"{self._optimized_batch_size}, but got {pre_processed_images.shape[0]}. "
                        "You can explicitly remove the optimized model by calling model.remove_optimized_model(). "
                        "Alternatively, you can recompile the optimized model for a different batch size "
                        "by calling model.optimize_for_inference(batch_size=<new_batch_size>).",
                        help_url="https://todo",
                    )
        with torch.inference_mode():
            if self._inference_model:
                predictions = self._inference_model(
                    pre_processed_images.to(dtype=self._optimized_dtype)
                )
            else:
                predictions = self._model(pre_processed_images)
            if isinstance(predictions, tuple):
                predictions = {
                    "pred_logits": predictions[1],
                    "pred_boxes": predictions[0],
                    "pred_masks": predictions[2],
                }
            return predictions

    def post_process(
        self,
        model_results: dict,
        pre_processing_meta: List[PreProcessingMetadata],
        threshold: float = 0.5,
        **kwargs,
    ) -> List[InstanceDetections]:
        bboxes, logits, masks = (
            model_results["pred_boxes"],
            model_results["pred_logits"],
            model_results["pred_masks"],
        )
        return post_process_instance_segmentation_results(
            bboxes=bboxes,
            logits=logits,
            masks=masks,
            pre_processing_meta=pre_processing_meta,
            threshold=threshold,
            classes_re_mapping=self._classes_re_mapping,
        )
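
A similar sketch for the torch backend above - the checkpoint path is an illustrative assumption; with labels left as None the code generates placeholder class names ("class_<i>"), and optimize_for_inference() is the optional step that the warning in forward() refers to:

import numpy as np

from inference_models.models.rfdetr.rfdetr_instance_segmentation_pytorch import (
    RFDetrForInstanceSegmentationTorch,
)

# "./rfdetr_seg_checkpoint.pth" is a hypothetical local checkpoint file; loading
# from a file requires model_type (defaults to "rfdetr-seg-preview").
model = RFDetrForInstanceSegmentationTorch.from_pretrained(
    "./rfdetr_seg_checkpoint.pth",
    model_type="rfdetr-seg-preview",
)
# model.optimize_for_inference()  # optional: prepares (and by default traces) a copy of the model

image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
tensors, metadata = model.pre_process(image)
predictions = model.forward(tensors)  # dict with pred_boxes / pred_logits / pred_masks
detections = model.post_process(predictions, metadata, threshold=0.5)
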