inference_models-0.18.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py

@@ -0,0 +1,470 @@

```python
import os.path
from copy import deepcopy
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from inference_models import Detections, ObjectDetectionModel
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.entities import ColorFormat
from inference_models.errors import (
    CorruptedModelPackageError,
    ModelLoadingError,
    ModelRuntimeError,
)
from inference_models.logger import LOGGER
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.common.roboflow.model_packages import (
    ColorMode,
    DivisiblePadding,
    InferenceConfig,
    NetworkInputDefinition,
    PreProcessingMetadata,
    ResizeMode,
    TrainingInputSize,
    parse_class_names_file,
    parse_inference_config,
)
from inference_models.models.common.roboflow.pre_processing import (
    pre_process_network_input,
)
from inference_models.models.rfdetr.class_remapping import (
    ClassesReMapping,
    prepare_class_remapping,
)
from inference_models.models.rfdetr.common import parse_model_type
from inference_models.models.rfdetr.default_labels import resolve_labels
from inference_models.models.rfdetr.post_processor import PostProcess
from inference_models.models.rfdetr.rfdetr_base_pytorch import (
    LWDETR,
    RFDETRBaseConfig,
    RFDETRLargeConfig,
    RFDETRMediumConfig,
    RFDETRNanoConfig,
    RFDETRSmallConfig,
    build_model,
)

try:
    torch.set_float32_matmul_precision("high")
except:
    pass

CONFIG_FOR_MODEL_TYPE = {
    "rfdetr-nano": RFDETRNanoConfig,
    "rfdetr-small": RFDETRSmallConfig,
    "rfdetr-medium": RFDETRMediumConfig,
    "rfdetr-base": RFDETRBaseConfig,
    "rfdetr-large": RFDETRLargeConfig,
}

RESIZE_MODES_TO_REVERT_PADDING = {
    ResizeMode.LETTERBOX,
    ResizeMode.LETTERBOX_REFLECT_EDGES,
}


class RFDetrForObjectDetectionTorch(
    ObjectDetectionModel[torch.Tensor, PreProcessingMetadata, dict]
):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        model_type: Optional[str] = None,
        labels: Optional[Union[str, List[str]]] = None,
        resolution: Optional[int] = None,
        **kwargs,
    ) -> "RFDetrForObjectDetectionTorch":
        if os.path.isfile(model_name_or_path):
            return cls.from_checkpoint_file(
                checkpoint_path=model_name_or_path,
                model_type=model_type,
                labels=labels,
                resolution=resolution,
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "model_type.json",
                "weights.pth",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        classes_re_mapping = None
        if inference_config.class_names_operations:
            class_names, classes_re_mapping = prepare_class_remapping(
                class_names=class_names,
                class_names_operations=inference_config.class_names_operations,
                device=device,
            )
        weights_dict = torch.load(
            model_package_content["weights.pth"],
            map_location=device,
            weights_only=False,
        )["model"]
        model_type = parse_model_type(
            config_path=model_package_content["model_type.json"]
        )
        if model_type not in CONFIG_FOR_MODEL_TYPE:
            raise CorruptedModelPackageError(
                message=f"Model package describes model_type as '{model_type}' which is not supported. "
                f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
                help_url="https://todo",
            )
        model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
        checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
        model_config.num_classes = checkpoint_num_classes - 1
        model_config.resolution = (
            inference_config.network_input.training_input_size.height
        )
        model = build_model(config=model_config)
        model.load_state_dict(weights_dict)
        model = model.eval().to(device)
        post_processor = PostProcess()
        return cls(
            model=model,
            class_names=class_names,
            classes_re_mapping=classes_re_mapping,
            device=device,
            inference_config=inference_config,
            post_processor=post_processor,
            resolution=model_config.resolution,
        )

    @classmethod
    def from_checkpoint_file(
        cls,
        checkpoint_path: str,
        model_type: Optional[str] = None,
        labels: Optional[Union[str, List[str]]] = None,
        resolution: Optional[int] = None,
        device: torch.device = DEFAULT_DEVICE,
    ):
        if model_type is None:
            raise ModelLoadingError(
                message="While loading RFDetr model (using torch backend) could not determine `model_type`. "
                "If you used `RFDetrForObjectDetectionTorch` directly imported in your code, please pass "
                f"one of the value: {CONFIG_FOR_MODEL_TYPE.keys()} as the parameter. If you see this "
                f"error, while using `AutoModel.from_pretrained(...)` or thrown from managed Roboflow service, "
                f"this is a bug - raise the issue: https://github.com/roboflow/inference/issue providing "
                f"full context.",
                help_url="https://todo",
            )
        weights_dict = torch.load(
            checkpoint_path,
            map_location=device,
            weights_only=False,
        )["model"]
        if model_type not in CONFIG_FOR_MODEL_TYPE:
            raise ModelLoadingError(
                message=f"Model package describes model_type as '{model_type}' which is not supported. "
                f"Supported model types: {list(CONFIG_FOR_MODEL_TYPE.keys())}.",
                help_url="https://todo",
            )
        model_config = CONFIG_FOR_MODEL_TYPE[model_type](device=device)
        divisibility = model_config.num_windows * model_config.patch_size
        if resolution is not None:
            if resolution < 0 or resolution % divisibility != 0:
                raise ModelLoadingError(
                    message=f"Attempted to load RFDetr model (using torch backend) with `resolution` parameter which "
                    f"is invalid - the model required positive value divisible by 56. Make sure you used "
                    f"proper value, corresponding to the one used to train the model.",
                    help_url="https://todo",
                )
            model_config.resolution = resolution
        inference_config = InferenceConfig(
            network_input=NetworkInputDefinition(
                training_input_size=TrainingInputSize(
                    height=model_config.resolution,
                    width=model_config.resolution,
                ),
                dynamic_spatial_size_supported=True,
                dynamic_spatial_size_mode=DivisiblePadding(
                    type="pad-to-be-divisible",
                    value=divisibility,
                ),
                color_mode=ColorMode.BGR,
                resize_mode=ResizeMode.STRETCH_TO,
                input_channels=3,
                scaling_factor=255,
                normalization=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            )
        )
        checkpoint_num_classes = weights_dict["class_embed.bias"].shape[0]
        model_config.num_classes = checkpoint_num_classes - 1
        model = build_model(config=model_config)
        if labels is None:
            class_names = [f"class_{i}" for i in range(checkpoint_num_classes)]
        elif isinstance(labels, str):
            class_names = resolve_labels(labels=labels)
        else:
            class_names = labels
        if checkpoint_num_classes != len(class_names):
            raise ModelLoadingError(
                message=f"Checkpoint pointed to load RFDetr defines {checkpoint_num_classes} output classes, but "
                f"loaded labels define {len(class_names)} classes - fix the value of `labels` parameter.",
                help_url="https://todo",
            )
        model.load_state_dict(weights_dict)
        model = model.eval().to(device)
        post_processor = PostProcess()
        return cls(
            model=model,
            class_names=class_names,
            classes_re_mapping=None,
            device=device,
            inference_config=inference_config,
            post_processor=post_processor,
            resolution=model_config.resolution,
        )

    def __init__(
        self,
        model: LWDETR,
        inference_config: InferenceConfig,
        class_names: List[str],
        classes_re_mapping: Optional[ClassesReMapping],
        device: torch.device,
        post_processor: PostProcess,
        resolution: int,
    ):
        self._model = model
        self._inference_config = inference_config
        self._class_names = class_names
        self._classes_re_mapping = classes_re_mapping
        self._post_processor = post_processor
        self._device = device
        self._resolution = resolution
        self._has_warned_about_not_being_optimized_for_inference = False
        self._inference_model: Optional[LWDETR] = None
        self._optimized_has_been_compiled = False
        self._optimized_batch_size = None
        self._optimized_dtype = None

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def optimize_for_inference(
        self,
        compile: bool = True,
        batch_size: int = 1,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.remove_optimized_model()
        self._inference_model = deepcopy(self._model)
        self._inference_model.eval()
        self._inference_model.export()
        self._inference_model = self._inference_model.to(dtype=dtype)
        self._optimized_dtype = dtype
        if compile:
            self._inference_model = torch.jit.trace(
                self._inference_model,
                torch.randn(
                    batch_size,
                    3,
                    self._resolution,
                    self._resolution,
                    device=self._device,
                    dtype=dtype,
                ),
            )
            self._optimized_has_been_compiled = True
            self._optimized_batch_size = batch_size

    def remove_optimized_model(self) -> None:
        self._has_warned_about_not_being_optimized_for_inference = False
        self._inference_model = None
        self._optimized_has_been_compiled = False
        self._optimized_batch_size = None

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        image_size: Optional[Tuple[int, int]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
            image_size_wh=image_size,
        )

    def forward(self, pre_processed_images: torch.Tensor, **kwargs) -> dict:
        if (
            self._inference_model is None
            and not self._has_warned_about_not_being_optimized_for_inference
        ):
            LOGGER.warning(
                "Model is not optimized for inference. "
                "Latency may be higher than expected. "
                "You can optimize the model for inference by calling model.optimize_for_inference()."
            )
            self._has_warned_about_not_being_optimized_for_inference = True
        if self._inference_model is not None:
            if (self._resolution, self._resolution) != tuple(
                pre_processed_images.shape[2:]
            ):
                raise ModelRuntimeError(
                    message=f"Resolution mismatch. Model was optimized for resolution {self._resolution}, "
                    f"but got {tuple(pre_processed_images.shape[2:])}. "
                    "You can explicitly remove the optimized model by calling model.remove_optimized_model().",
                    help_url="https://todo",
                )
            if self._optimized_has_been_compiled:
                if self._optimized_batch_size != pre_processed_images.shape[0]:
                    raise ModelRuntimeError(
                        message="Batch size mismatch. Optimized model was compiled for batch size "
                        f"{self._optimized_batch_size}, but got {pre_processed_images.shape[0]}. "
                        "You can explicitly remove the optimized model by calling model.remove_optimized_model(). "
                        "Alternatively, you can recompile the optimized model for a different batch size "
                        "by calling model.optimize_for_inference(batch_size=<new_batch_size>).",
                        help_url="https://todo",
                    )
        with torch.inference_mode():
            if self._inference_model:
                predictions = self._inference_model(
                    pre_processed_images.to(dtype=self._optimized_dtype)
                )
            else:
                predictions = self._model(pre_processed_images)
            if isinstance(predictions, tuple):
                predictions = {
                    "pred_logits": predictions[1],
                    "pred_boxes": predictions[0],
                }
            return predictions

    def post_process(
        self,
        model_results: dict,
        pre_processing_meta: List[PreProcessingMetadata],
        threshold: float = 0.5,
        **kwargs,
    ) -> List[Detections]:
        if (
            self._inference_config.network_input.resize_mode
            in RESIZE_MODES_TO_REVERT_PADDING
        ):
            un_padding_results = []
            for out_box_tensor, image_metadata in zip(
                model_results["pred_boxes"], pre_processing_meta
            ):
                box_center_offsets = torch.as_tensor(  # bboxes in format cxcywh now, so only cx, cy to be pushed
                    [
                        image_metadata.pad_left / image_metadata.inference_size.width,
                        image_metadata.pad_top / image_metadata.inference_size.height,
                        0.0,
                        0.0,
                    ],
                    dtype=out_box_tensor.dtype,
                    device=out_box_tensor.device,
                )
                ox_padding = (
                    image_metadata.pad_left + image_metadata.pad_right
                ) / image_metadata.inference_size.width
                oy_padding = (
                    image_metadata.pad_top + image_metadata.pad_bottom
                ) / image_metadata.inference_size.height
                box_wh_offsets = torch.as_tensor(  # bboxes in format cxcywh now, so only cx, cy to be pushed
                    [
                        1.0 - ox_padding,
                        1.0 - oy_padding,
                        1.0 - ox_padding,
                        1.0 - oy_padding,
                    ],
                    dtype=out_box_tensor.dtype,
                    device=out_box_tensor.device,
                )
                out_box_tensor = (out_box_tensor - box_center_offsets) / box_wh_offsets
                un_padding_results.append(out_box_tensor)
            model_results["pred_boxes"] = torch.stack(un_padding_results, dim=0)
        if self._inference_config.network_input.resize_mode is ResizeMode.CENTER_CROP:
            orig_sizes = [
                (
                    round(e.inference_size.height / e.scale_height),
                    round(e.inference_size.width / e.scale_width),
                )
                for e in pre_processing_meta
            ]
        else:
            orig_sizes = [
                (e.size_after_pre_processing.height, e.size_after_pre_processing.width)
                for e in pre_processing_meta
            ]
        target_sizes = torch.tensor(orig_sizes, device=self._device)
        results = self._post_processor(model_results, target_sizes=target_sizes)
        detections_list = []
        for image_result, image_metadata in zip(results, pre_processing_meta):
            scores = image_result["scores"]
            labels = image_result["labels"]
            boxes = image_result["boxes"]
            if self._classes_re_mapping is not None:
                remapping_mask = torch.isin(
                    labels, self._classes_re_mapping.remaining_class_ids
                )
                scores = scores[remapping_mask]
                labels = self._classes_re_mapping.class_mapping[labels[remapping_mask]]
                boxes = boxes[remapping_mask]
            keep = scores > threshold
            scores = scores[keep]
            labels = labels[keep]
            boxes = boxes[keep]
            if (
                self._inference_config.network_input.resize_mode
                is ResizeMode.CENTER_CROP
            ):
                offsets = torch.as_tensor(
                    [
                        image_metadata.pad_left,
                        image_metadata.pad_top,
                        image_metadata.pad_left,
                        image_metadata.pad_top,
                    ],
                    dtype=boxes.dtype,
                    device=boxes.device,
                )
                boxes[:, :4].sub_(offsets)
            if (
                image_metadata.static_crop_offset.offset_x != 0
                or image_metadata.static_crop_offset.offset_y != 0
            ):
                static_crop_offsets = torch.as_tensor(
                    [
                        image_metadata.static_crop_offset.offset_x,
                        image_metadata.static_crop_offset.offset_y,
                        image_metadata.static_crop_offset.offset_x,
                        image_metadata.static_crop_offset.offset_y,
                    ],
                    dtype=boxes.dtype,
                    device=boxes.device,
                )
                boxes[:, :4].add_(static_crop_offsets)
            detections = Detections(
                xyxy=boxes.round().int(),
                confidence=scores,
                class_id=labels.int(),
            )
            detections_list.append(detections)
        return detections_list
```
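For orientation, the class above follows a load / pre_process / forward / post_process cycle, with an optional `optimize_for_inference()` step that traces a fixed-shape copy of the model. Below is a minimal usage sketch, not taken from the package: the model package directory path and the input array are placeholders, and the attribute access on the returned `Detections` objects assumes the fields used in its constructor (`xyxy`, `confidence`, `class_id`) are exposed as attributes.

```python
import numpy as np
import torch

from inference_models.models.rfdetr.rfdetr_object_detection_pytorch import (
    RFDetrForObjectDetectionTorch,
)

# Hypothetical local RFDetr model package directory containing class_names.txt,
# inference_config.json, model_type.json and weights.pth (placeholder path).
MODEL_PACKAGE_DIR = "/path/to/rfdetr-model-package"

# Placeholder image; in practice this would be a real BGR frame as a NumPy array.
image = np.zeros((720, 1280, 3), dtype=np.uint8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RFDetrForObjectDetectionTorch.from_pretrained(MODEL_PACKAGE_DIR, device=device)

# Optional: trace the model for batch size 1 to reduce latency; inputs must then
# match the compiled batch size and resolution.
model.optimize_for_inference(batch_size=1)

pre_processed, metadata = model.pre_process(image)
raw_predictions = model.forward(pre_processed)
detections = model.post_process(raw_predictions, metadata, threshold=0.5)

for dets in detections:  # one Detections object per input image
    print(dets.xyxy, dets.confidence, dets.class_id)
```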