inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/resnet/resnet_classification_torch.py
@@ -0,0 +1,305 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import timm
import torch
from torch import nn

from inference_models import (
    ClassificationModel,
    ClassificationPrediction,
    MultiLabelClassificationModel,
    MultiLabelClassificationPrediction,
)
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.entities import ColorFormat
from inference_models.errors import CorruptedModelPackageError
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.common.roboflow.model_packages import (
    InferenceConfig,
    ResizeMode,
    parse_class_names_file,
    parse_inference_config,
)
from inference_models.models.common.roboflow.pre_processing import (
    pre_process_network_input,
)


class ResNetClassifier(nn.Module):

    def __init__(self, backbone: nn.Module, softmax_fused: bool):
        super().__init__()
        self._backbone = backbone
        self._softmax_fused = softmax_fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        results = self._backbone(x)
        if not self._softmax_fused:
            # apply softmax only when it is not already fused into the exported head
            results = torch.nn.functional.softmax(results, dim=-1)
        return results


class ResNetForClassificationTorch(ClassificationModel[torch.Tensor, torch.Tensor]):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "ResNetForClassificationTorch":
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "weights.pth",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        if inference_config.model_initialization is None:
            raise CorruptedModelPackageError(
                message="Expected model initialization parameters were not provided in the inference config.",
                help_url="https://todo",
            )
        num_classes = inference_config.model_initialization.get("num_classes")
        model_name = inference_config.model_initialization.get("model_name")
        if not isinstance(num_classes, int):
            raise CorruptedModelPackageError(
                message="Model initialization parameter `num_classes` was not provided or is in an invalid format.",
                help_url="https://todo",
            )
        if not isinstance(model_name, str):
            raise CorruptedModelPackageError(
                message="Model initialization parameter `model_name` was not provided or is in an invalid format.",
                help_url="https://todo",
            )
        if inference_config.post_processing.type != "softmax":
            raise CorruptedModelPackageError(
                message="Expected softmax to be the post-processing type.",
                help_url="https://todo",
            )
        backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=num_classes,
        ).to(device)
        # load the trained weights onto the target device
        state_dict = torch.load(
            model_package_content["weights.pth"],
            weights_only=True,
            map_location=device,
        )
        backbone.load_state_dict(state_dict)
        model = ResNetClassifier(
            backbone=backbone,
            softmax_fused=inference_config.post_processing.fused,
        ).to(device)
        return cls(
            model=model.eval(),
            inference_config=inference_config,
            class_names=class_names,
            device=device,
        )

    def __init__(
        self,
        model: ResNetClassifier,
        inference_config: InferenceConfig,
        class_names: List[str],
        device: torch.device,
    ):
        self._model = model
        self._inference_config = inference_config
        self._class_names = class_names
        self._device = device

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )[0]

    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
        with torch.inference_mode():
            return self._model(pre_processed_images)

    def post_process(
        self,
        model_results: torch.Tensor,
        **kwargs,
    ) -> ClassificationPrediction:
        return ClassificationPrediction(
            class_id=model_results.argmax(dim=-1),
            confidence=model_results,
        )


class ResNetMultiLabelClassifier(nn.Module):

    def __init__(self, backbone: nn.Module, sigmoid_fused: bool):
        super().__init__()
        self._backbone = backbone
        self._sigmoid_fused = sigmoid_fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        results = self._backbone(x)
        if not self._sigmoid_fused:
            # apply sigmoid only when it is not already fused into the exported head
            results = torch.nn.functional.sigmoid(results)
        return results


class ResNetForMultiLabelClassificationTorch(
    MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        **kwargs,
    ) -> "ResNetForMultiLabelClassificationTorch":
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "weights.pth",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        if inference_config.model_initialization is None:
            raise CorruptedModelPackageError(
                message="Expected model initialization parameters were not provided in the inference config.",
                help_url="https://todo",
            )
        num_classes = inference_config.model_initialization.get("num_classes")
        model_name = inference_config.model_initialization.get("model_name")
        if not isinstance(num_classes, int):
            raise CorruptedModelPackageError(
                message="Model initialization parameter `num_classes` was not provided or is in an invalid format.",
                help_url="https://todo",
            )
        if not isinstance(model_name, str):
            raise CorruptedModelPackageError(
                message="Model initialization parameter `model_name` was not provided or is in an invalid format.",
                help_url="https://todo",
            )
        if inference_config.post_processing.type != "sigmoid":
            raise CorruptedModelPackageError(
                message="Expected sigmoid to be the post-processing type.",
                help_url="https://todo",
            )
        backbone = timm.create_model(
            model_name,
            pretrained=False,
            num_classes=num_classes,
        ).to(device)
        state_dict = torch.load(
            model_package_content["weights.pth"],
            weights_only=True,
            map_location=device,
        )
        backbone.load_state_dict(state_dict)
        model = ResNetMultiLabelClassifier(
            backbone=backbone,
            sigmoid_fused=inference_config.post_processing.fused,
        ).to(device)
        return cls(
            model=model.eval(),
            inference_config=inference_config,
            class_names=class_names,
            device=device,
        )

    def __init__(
        self,
        model: ResNetMultiLabelClassifier,
        inference_config: InferenceConfig,
        class_names: List[str],
        device: torch.device,
    ):
        self._model = model
        self._inference_config = inference_config
        self._class_names = class_names
        self._device = device

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )[0]

    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> torch.Tensor:
        with torch.inference_mode():
            return self._model(pre_processed_images)

    def post_process(
        self,
        model_results: torch.Tensor,
        confidence: float = 0.5,
        **kwargs,
    ) -> List[MultiLabelClassificationPrediction]:
        results = []
        for batch_element_confidence in model_results:
            # keep the classes whose confidence clears the threshold
            predicted_classes = torch.argwhere(
                batch_element_confidence >= confidence
            ).squeeze(dim=-1)
            results.append(
                MultiLabelClassificationPrediction(
                    class_ids=predicted_classes,
                    confidence=batch_element_confidence,
                )
            )
        return results
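
For orientation, a minimal usage sketch of the Torch classifier above. The package directory (`./resnet_package`) and the random stand-in image are illustrative assumptions, not part of the released wheel; the staging mirrors the `pre_process`/`forward`/`post_process` methods defined in this file.

import numpy as np

from inference_models.models.resnet.resnet_classification_torch import (
    ResNetForClassificationTorch,
)

# "./resnet_package" is a hypothetical local model package directory that would
# contain class_names.txt, inference_config.json and weights.pth
model = ResNetForClassificationTorch.from_pretrained("./resnet_package")
image = np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)  # stand-in HWC image
prediction = model.post_process(model.forward(model.pre_process(image)))
print(prediction.class_id, model.class_names)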

inference_models/models/resnet/resnet_classification_trt.py
@@ -0,0 +1,369 @@
from threading import Lock
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from inference_models import (
    ClassificationModel,
    ClassificationPrediction,
    MultiLabelClassificationModel,
    MultiLabelClassificationPrediction,
)
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.entities import ColorFormat
from inference_models.errors import (
    CorruptedModelPackageError,
    MissingDependencyError,
    ModelRuntimeError,
)
from inference_models.models.base.types import PreprocessedInputs
from inference_models.models.common.cuda import (
    use_cuda_context,
    use_primary_cuda_context,
)
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.common.roboflow.model_packages import (
    InferenceConfig,
    ResizeMode,
    TRTConfig,
    parse_class_names_file,
    parse_inference_config,
    parse_trt_config,
)
from inference_models.models.common.roboflow.pre_processing import (
    pre_process_network_input,
)
from inference_models.models.common.trt import (
    get_engine_inputs_and_outputs,
    infer_from_trt_engine,
    load_model,
)

try:
    import tensorrt as trt
except ImportError as import_error:
    raise MissingDependencyError(
        message="Could not import ResNet model with TRT backend - this error means that some additional "
        "dependencies are not installed in the environment. If you run the `inference-models` library directly "
        "in your Python program, make sure the following extras of the package are installed: `trt10` - "
        "installation can only succeed on Linux and Windows machines with CUDA 12 installed. Jetson devices "
        "should have TRT 10.x installed for all builds with Jetpack 6. "
        "If you see this error using Roboflow infrastructure, make sure the service you use supports the model. "
        "You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error

try:
    import pycuda.driver as cuda
except ImportError as import_error:
    raise MissingDependencyError(
        message="TODO", help_url="https://todo"
    ) from import_error


class ResNetForClassificationTRT(ClassificationModel[torch.Tensor, torch.Tensor]):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        engine_host_code_allowed: bool = False,
        **kwargs,
    ) -> "ResNetForClassificationTRT":
        if device.type != "cuda":
            raise ModelRuntimeError(
                message=f"TRT engine only runs on a CUDA device - {device} device detected.",
                help_url="https://todo",
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "trt_config.json",
                "engine.plan",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        if inference_config.post_processing.type != "softmax":
            raise CorruptedModelPackageError(
                message="Expected softmax to be the post-processing type.",
                help_url="https://todo",
            )
        trt_config = parse_trt_config(
            config_path=model_package_content["trt_config.json"]
        )
        # deserialize the engine within the device's primary CUDA context
        cuda.init()
        cuda_device = cuda.Device(device.index or 0)
        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
            engine = load_model(
                model_path=model_package_content["engine.plan"],
                engine_host_code_allowed=engine_host_code_allowed,
            )
            execution_context = engine.create_execution_context()
            inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
            if len(inputs) != 1:
                raise CorruptedModelPackageError(
                    message=f"Implementation assumes a single model input, found: {len(inputs)}.",
                    help_url="https://todo",
                )
            if len(outputs) != 1:
                raise CorruptedModelPackageError(
                    message=f"Implementation assumes a single model output, found: {len(outputs)}.",
                    help_url="https://todo",
                )
            return cls(
                engine=engine,
                input_name=inputs[0],
                output_name=outputs[0],
                class_names=class_names,
                inference_config=inference_config,
                trt_config=trt_config,
                device=device,
                cuda_context=cuda_context,
                execution_context=execution_context,
            )

    def __init__(
        self,
        engine: trt.ICudaEngine,
        input_name: str,
        output_name: str,
        class_names: List[str],
        inference_config: InferenceConfig,
        trt_config: TRTConfig,
        device: torch.device,
        cuda_context: cuda.Context,
        execution_context: trt.IExecutionContext,
    ):
        self._engine = engine
        self._input_name = input_name
        self._output_names = [output_name]
        self._class_names = class_names
        self._inference_config = inference_config
        self._trt_config = trt_config
        self._device = device
        self._cuda_context = cuda_context
        self._execution_context = execution_context
        self._lock = Lock()

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )[0]

    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> torch.Tensor:
        # a TRT execution context is not thread-safe - serialize inference calls
        with self._lock:
            with use_cuda_context(context=self._cuda_context):
                return infer_from_trt_engine(
                    pre_processed_images=pre_processed_images,
                    trt_config=self._trt_config,
                    engine=self._engine,
                    context=self._execution_context,
                    device=self._device,
                    input_name=self._input_name,
                    outputs=self._output_names,
                )[0]

    def post_process(
        self,
        model_results: torch.Tensor,
        **kwargs,
    ) -> ClassificationPrediction:
        if self._inference_config.post_processing.fused:
            confidence = model_results
        else:
            confidence = torch.nn.functional.softmax(model_results, dim=-1)
        return ClassificationPrediction(
            class_id=confidence.argmax(dim=-1),
            confidence=confidence,
        )


class ResNetForMultiLabelClassificationTRT(
    MultiLabelClassificationModel[torch.Tensor, torch.Tensor]
):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        engine_host_code_allowed: bool = False,
        **kwargs,
    ) -> "ResNetForMultiLabelClassificationTRT":
        if device.type != "cuda":
            raise ModelRuntimeError(
                message=f"TRT engine only runs on a CUDA device - {device} device detected.",
                help_url="https://todo",
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "trt_config.json",
                "engine.plan",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        if inference_config.post_processing.type != "sigmoid":
            raise CorruptedModelPackageError(
                message="Expected sigmoid to be the post-processing type.",
                help_url="https://todo",
            )
        trt_config = parse_trt_config(
            config_path=model_package_content["trt_config.json"]
        )
        # deserialize the engine within the device's primary CUDA context
        cuda.init()
        cuda_device = cuda.Device(device.index or 0)
        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
            engine = load_model(
                model_path=model_package_content["engine.plan"],
                engine_host_code_allowed=engine_host_code_allowed,
            )
            execution_context = engine.create_execution_context()
            inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
            if len(inputs) != 1:
                raise CorruptedModelPackageError(
                    message=f"Implementation assumes a single model input, found: {len(inputs)}.",
                    help_url="https://todo",
                )
            if len(outputs) != 1:
                raise CorruptedModelPackageError(
                    message=f"Implementation assumes a single model output, found: {len(outputs)}.",
                    help_url="https://todo",
                )
            return cls(
                engine=engine,
                input_name=inputs[0],
                output_name=outputs[0],
                class_names=class_names,
                inference_config=inference_config,
                trt_config=trt_config,
                device=device,
                cuda_context=cuda_context,
                execution_context=execution_context,
            )

    def __init__(
        self,
        engine: trt.ICudaEngine,
        input_name: str,
        output_name: str,
        class_names: List[str],
        inference_config: InferenceConfig,
        trt_config: TRTConfig,
        device: torch.device,
        cuda_context: cuda.Context,
        execution_context: trt.IExecutionContext,
    ):
        self._engine = engine
        self._input_name = input_name
        self._output_names = [output_name]
        self._class_names = class_names
        self._inference_config = inference_config
        self._trt_config = trt_config
        self._device = device
        self._cuda_context = cuda_context
        self._execution_context = execution_context
        self._lock = Lock()

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> torch.Tensor:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )[0]

    def forward(
        self, pre_processed_images: PreprocessedInputs, **kwargs
    ) -> torch.Tensor:
        # a TRT execution context is not thread-safe - serialize inference calls
        with self._lock:
            with use_cuda_context(context=self._cuda_context):
                return infer_from_trt_engine(
                    pre_processed_images=pre_processed_images,
                    trt_config=self._trt_config,
                    engine=self._engine,
                    context=self._execution_context,
                    device=self._device,
                    input_name=self._input_name,
                    outputs=self._output_names,
                )[0]

    def post_process(
        self,
        model_results: torch.Tensor,
        confidence: float = 0.5,
        **kwargs,
    ) -> List[MultiLabelClassificationPrediction]:
        if not self._inference_config.post_processing.fused:
            # sigmoid was not fused into the engine - apply it here
            model_results = torch.nn.functional.sigmoid(model_results)
        results = []
        for batch_element_confidence in model_results:
            predicted_classes = torch.argwhere(
                batch_element_confidence >= confidence
            ).squeeze(dim=-1)
            results.append(
                MultiLabelClassificationPrediction(
                    class_ids=predicted_classes,
                    confidence=batch_element_confidence,
                )
            )
        return results
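
A similar hedged sketch for the TRT variant above. Here `from_pretrained` rejects non-CUDA devices, and the package directory (again an illustrative assumption) must additionally contain `engine.plan` and `trt_config.json`.

import numpy as np
import torch

from inference_models.models.resnet.resnet_classification_trt import (
    ResNetForClassificationTRT,
)

# hypothetical local package directory with engine.plan, trt_config.json,
# inference_config.json and class_names.txt
model = ResNetForClassificationTRT.from_pretrained(
    "./resnet_trt_package",
    device=torch.device("cuda:0"),
)
image = np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)  # stand-in HWC image
prediction = model.post_process(model.forward(model.pre_process(image)))
print(prediction.class_id)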