inference-models 0.18.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,270 @@
import threading
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from inference_models import Detections, ObjectDetectionModel
from inference_models.configuration import DEFAULT_DEVICE
from inference_models.entities import ColorFormat
from inference_models.errors import (
    CorruptedModelPackageError,
    MissingDependencyError,
    ModelRuntimeError,
)
from inference_models.models.common.cuda import (
    use_cuda_context,
    use_primary_cuda_context,
)
from inference_models.models.common.model_packages import get_model_package_contents
from inference_models.models.common.roboflow.model_packages import (
    InferenceConfig,
    PreProcessingMetadata,
    ResizeMode,
    TRTConfig,
    parse_class_names_file,
    parse_inference_config,
    parse_trt_config,
)
from inference_models.models.common.roboflow.post_processing import (
    rescale_image_detections,
)
from inference_models.models.common.roboflow.pre_processing import (
    pre_process_network_input,
)
from inference_models.models.common.trt import (
    get_engine_inputs_and_outputs,
    infer_from_trt_engine,
    load_model,
)
from inference_models.models.rfdetr.class_remapping import (
    ClassesReMapping,
    prepare_class_remapping,
)

try:
    import tensorrt as trt
except ImportError as import_error:
    raise MissingDependencyError(
        message=f"Could not import RFDetr model with TRT backend - this error means that some additional dependencies "
        f"are not installed in the environment. If you run the `inference-models` library directly in your Python "
        f"program, make sure the following extras of the package are installed: `trt10` - installation can only "
        f"succeed for Linux and Windows machines with Cuda 12 installed. Jetson devices should have TRT 10.x "
        f"installed for all builds with Jetpack 6. "
        f"If you see this error using Roboflow infrastructure, make sure the service you use does support the model. "
        f"You can also contact Roboflow to get support.",
        help_url="https://todo",
    ) from import_error

try:
    import pycuda.driver as cuda
except ImportError as import_error:
    raise MissingDependencyError(
        message="TODO", help_url="https://todo"
    ) from import_error


class RFDetrForObjectDetectionTRT(
    (
        ObjectDetectionModel[
            torch.Tensor, PreProcessingMetadata, Tuple[torch.Tensor, torch.Tensor]
        ]
    )
):

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        device: torch.device = DEFAULT_DEVICE,
        engine_host_code_allowed: bool = False,
        **kwargs,
    ) -> "RFDetrForObjectDetectionTRT":
        if device.type != "cuda":
            raise ModelRuntimeError(
                message=f"TRT engine only runs on CUDA device - {device} device detected.",
                help_url="https://todo",
            )
        model_package_content = get_model_package_contents(
            model_package_dir=model_name_or_path,
            elements=[
                "class_names.txt",
                "inference_config.json",
                "trt_config.json",
                "engine.plan",
            ],
        )
        class_names = parse_class_names_file(
            class_names_path=model_package_content["class_names.txt"]
        )
        inference_config = parse_inference_config(
            config_path=model_package_content["inference_config.json"],
            allowed_resize_modes={
                ResizeMode.STRETCH_TO,
                ResizeMode.LETTERBOX,
                ResizeMode.CENTER_CROP,
                ResizeMode.LETTERBOX_REFLECT_EDGES,
            },
        )
        classes_re_mapping = None
        if inference_config.class_names_operations:
            class_names, classes_re_mapping = prepare_class_remapping(
                class_names=class_names,
                class_names_operations=inference_config.class_names_operations,
                device=device,
            )
        trt_config = parse_trt_config(
            config_path=model_package_content["trt_config.json"]
        )
        cuda.init()
        cuda_device = cuda.Device(device.index or 0)
        with use_primary_cuda_context(cuda_device=cuda_device) as cuda_context:
            engine = load_model(
                model_path=model_package_content["engine.plan"],
                engine_host_code_allowed=engine_host_code_allowed,
            )
            execution_context = engine.create_execution_context()
        inputs, outputs = get_engine_inputs_and_outputs(engine=engine)
        if len(inputs) != 1:
            raise CorruptedModelPackageError(
                message=f"Implementation assumes a single model input, found: {len(inputs)}.",
                help_url="https://todo",
            )
        if len(outputs) != 2:
            raise CorruptedModelPackageError(
                message=f"Implementation assumes 2 model outputs, found: {len(outputs)}.",
                help_url="https://todo",
            )
        if "dets" not in outputs or "labels" not in outputs:
            raise CorruptedModelPackageError(
                message=f"Expected model outputs to be named `dets` and `labels`, but found: {outputs}.",
                help_url="https://todo",
            )
        return cls(
            engine=engine,
            input_name=inputs[0],
            output_names=["dets", "labels"],
            class_names=class_names,
            classes_re_mapping=classes_re_mapping,
            inference_config=inference_config,
            trt_config=trt_config,
            device=device,
            cuda_context=cuda_context,
            execution_context=execution_context,
        )

    def __init__(
        self,
        engine: trt.ICudaEngine,
        input_name: str,
        output_names: List[str],
        class_names: List[str],
        classes_re_mapping: Optional[ClassesReMapping],
        inference_config: InferenceConfig,
        trt_config: TRTConfig,
        device: torch.device,
        cuda_context: cuda.Context,
        execution_context: trt.IExecutionContext,
    ):
        self._engine = engine
        self._input_name = input_name
        self._output_names = output_names
        self._inference_config = inference_config
        self._class_names = class_names
        self._classes_re_mapping = classes_re_mapping
        self._device = device
        self._cuda_context = cuda_context
        self._execution_context = execution_context
        self._trt_config = trt_config
        self._lock = threading.Lock()

    @property
    def class_names(self) -> List[str]:
        return self._class_names

    def pre_process(
        self,
        images: Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]],
        input_color_format: Optional[ColorFormat] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, List[PreProcessingMetadata]]:
        return pre_process_network_input(
            images=images,
            image_pre_processing=self._inference_config.image_pre_processing,
            network_input=self._inference_config.network_input,
            target_device=self._device,
            input_color_format=input_color_format,
        )

    def forward(
        self, pre_processed_images: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        with self._lock:
            with use_cuda_context(context=self._cuda_context):
                detections, labels = infer_from_trt_engine(
                    pre_processed_images=pre_processed_images,
                    trt_config=self._trt_config,
                    engine=self._engine,
                    context=self._execution_context,
                    device=self._device,
                    input_name=self._input_name,
                    outputs=self._output_names,
                )
                return detections, labels

    def post_process(
        self,
        model_results: Tuple[torch.Tensor, torch.Tensor],
        pre_processing_meta: List[PreProcessingMetadata],
        threshold: float = 0.5,
        **kwargs,
    ) -> List[Detections]:
        bboxes, logits = model_results
        logits_sigmoid = torch.nn.functional.sigmoid(logits)
        results = []
        for image_bboxes, image_logits, image_meta in zip(
            bboxes, logits_sigmoid, pre_processing_meta
        ):
            confidence, top_classes = image_logits.max(dim=1)
            confidence_mask = confidence > threshold
            confidence = confidence[confidence_mask]
            top_classes = top_classes[confidence_mask]
            selected_boxes = image_bboxes[confidence_mask]
            confidence, sorted_indices = torch.sort(confidence, descending=True)
            top_classes = top_classes[sorted_indices]
            selected_boxes = selected_boxes[sorted_indices]
            if self._classes_re_mapping is not None:
                remapping_mask = torch.isin(
                    top_classes, self._classes_re_mapping.remaining_class_ids
                )
                top_classes = self._classes_re_mapping.class_mapping[
                    top_classes[remapping_mask]
                ]
                selected_boxes = selected_boxes[remapping_mask]
                confidence = confidence[remapping_mask]
            cxcy = selected_boxes[:, :2]
            wh = selected_boxes[:, 2:]
            xy_min = cxcy - 0.5 * wh
            xy_max = cxcy + 0.5 * wh
            selected_boxes_xyxy_pct = torch.cat([xy_min, xy_max], dim=-1)
            inference_size_hwhw = torch.tensor(
                [
                    image_meta.inference_size.height,
                    image_meta.inference_size.width,
                    image_meta.inference_size.height,
                    image_meta.inference_size.width,
                ],
                device=self._device,
            )
            selected_boxes_xyxy = selected_boxes_xyxy_pct * inference_size_hwhw
            selected_boxes_xyxy = rescale_image_detections(
                image_detections=selected_boxes_xyxy,
                image_metadata=image_meta,
            )
            detections = Detections(
                xyxy=selected_boxes_xyxy.round().int(),
                confidence=confidence,
                class_id=top_classes.int(),
            )
            results.append(detections)
        return results
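
The class above follows a three-step pre_process / forward / post_process contract. The snippet below is a minimal, hypothetical usage sketch based only on the signatures shown in this hunk; the package directory, the placeholder image, the threshold value, and the import location (assumed from the file list above to be inference_models/models/rfdetr/rfdetr_object_detection_trt.py) are illustrative assumptions, not content of the released package.

# Hypothetical usage sketch - requires a CUDA device and a local RF-DETR TRT model package
# containing class_names.txt, inference_config.json, trt_config.json and engine.plan.
import numpy as np
import torch

from inference_models.models.rfdetr.rfdetr_object_detection_trt import (  # assumed module path
    RFDetrForObjectDetectionTRT,
)

model = RFDetrForObjectDetectionTRT.from_pretrained(
    "/path/to/model/package",  # hypothetical package directory
    device=torch.device("cuda:0"),
)
image = np.zeros((720, 1280, 3), dtype=np.uint8)  # placeholder frame
tensor_batch, metadata = model.pre_process(image)
raw_outputs = model.forward(tensor_batch)
detections = model.post_process(raw_outputs, metadata, threshold=0.5)
for image_detections in detections:
    print(image_detections.xyxy, image_detections.class_id, image_detections.confidence)
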
@@ -0,0 +1,273 @@
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------


from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthwiseConvBlock(nn.Module):
    r"""Simplified ConvNeXt block without the MLP subnet"""

    def __init__(self, dim, layer_scale_init_value=0):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=3, padding=1, groups=dim
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        return x + input


class MLPBlock(nn.Module):
    def __init__(self, dim, layer_scale_init_value=0):
        super().__init__()
        self.norm_in = nn.LayerNorm(dim)
        self.layers = nn.ModuleList(
            [
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim),
            ]
        )
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )

    def forward(self, x):
        input = x
        x = self.norm_in(x)
        for layer in self.layers:
            x = layer(x)
        if self.gamma is not None:
            x = self.gamma * x
        return x + input


class SegmentationHead(nn.Module):
    def __init__(
        self,
        in_dim,
        num_blocks: int,
        bottleneck_ratio: int = 1,
        downsample_ratio: int = 4,
    ):
        super().__init__()

        self.downsample_ratio = downsample_ratio
        self.interaction_dim = (
            in_dim // bottleneck_ratio if bottleneck_ratio is not None else in_dim
        )
        self.blocks = nn.ModuleList(
            [DepthwiseConvBlock(in_dim) for _ in range(num_blocks)]
        )
        self.spatial_features_proj = (
            nn.Identity()
            if bottleneck_ratio is None
            else nn.Conv2d(in_dim, self.interaction_dim, kernel_size=1)
        )

        self.query_features_block = MLPBlock(in_dim)
        self.query_features_proj = (
            nn.Identity()
            if bottleneck_ratio is None
            else nn.Linear(in_dim, self.interaction_dim)
        )

        self.bias = nn.Parameter(torch.zeros(1), requires_grad=True)

        self._export = False

    def export(self):
        self._export = True
        self._forward_origin = self.forward
        self.forward = self.forward_export
        for name, m in self.named_modules():
            if (
                hasattr(m, "export")
                and isinstance(m.export, Callable)
                and hasattr(m, "_export")
                and not m._export
            ):
                m.export()

    def forward(
        self,
        spatial_features: torch.Tensor,
        query_features: list[torch.Tensor],
        image_size: tuple[int, int],
        skip_blocks: bool = False,
    ) -> list[torch.Tensor]:
        # spatial features: (B, C, H, W)
        # query features: [(B, N, C)] for each decoder layer
        # output: (B, N, H*r, W*r)
        target_size = (
            image_size[0] // self.downsample_ratio,
            image_size[1] // self.downsample_ratio,
        )
        spatial_features = F.interpolate(
            spatial_features, size=target_size, mode="bilinear", align_corners=False
        )

        mask_logits = []
        if not skip_blocks:
            for block, qf in zip(self.blocks, query_features):
                spatial_features = block(spatial_features)
                spatial_features_proj = self.spatial_features_proj(spatial_features)
                qf = self.query_features_proj(self.query_features_block(qf))
                mask_logits.append(
                    torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf)
                    + self.bias
                )
        else:
            assert (
                len(query_features) == 1
            ), "skip_blocks is only supported for length 1 query features"
            qf = self.query_features_proj(self.query_features_block(query_features[0]))
            mask_logits.append(
                torch.einsum("bchw,bnc->bnhw", spatial_features, qf) + self.bias
            )

        return mask_logits

    def forward_export(
        self,
        spatial_features: torch.Tensor,
        query_features: list[torch.Tensor],
        image_size: tuple[int, int],
        skip_blocks: bool = False,
    ) -> list[torch.Tensor]:
        assert (
            len(query_features) == 1
        ), "at export time, segmentation head expects exactly one query feature"

        target_size = (
            image_size[0] // self.downsample_ratio,
            image_size[1] // self.downsample_ratio,
        )
        spatial_features = F.interpolate(
            spatial_features, size=target_size, mode="bilinear", align_corners=False
        )

        if not skip_blocks:
            for block in self.blocks:
                spatial_features = block(spatial_features)

        spatial_features_proj = self.spatial_features_proj(spatial_features)

        qf = self.query_features_proj(self.query_features_block(query_features[0]))
        return [torch.einsum("bchw,bnc->bnhw", spatial_features_proj, qf) + self.bias]


def point_sample(input, point_coords, **kwargs):
    """
    A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
    [0, 1] x [0, 1] square.
    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
        [0, 1] x [0, 1] normalized point coordinates.
    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
    """
    add_dim = False
    if point_coords.dim() == 3:
        add_dim = True
        point_coords = point_coords.unsqueeze(2)
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output


def get_uncertain_point_coords_with_randomness(
    coarse_logits,
    uncertainty_func,
    num_points,
    oversample_ratio=3,
    importance_sample_ratio=0.75,
):
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
    are calculated for each point using 'uncertainty_func' function that takes point's logit
    prediction as input.
    See PointRend paper for details.
    Args:
        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a Tensor of
            shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points.
    """
    assert oversample_ratio >= 1
    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
    num_boxes = coarse_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to incorrect results.
    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
    # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
    # However, if we calculate uncertainties for the coarse predictions first,
    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
    point_uncertainties = uncertainty_func(point_logits)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_sampled * torch.arange(
        num_boxes, dtype=torch.long, device=coarse_logits.device
    )
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        num_boxes, num_uncertain_points, 2
    )
    if num_random_points > 0:
        point_coords = torch.cat(
            [
                point_coords,
                torch.rand(
                    num_boxes, num_random_points, 2, device=coarse_logits.device
                ),
            ],
            dim=1,
        )
    return point_coords
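
To make the query/feature interaction and the point-sampling utilities in this file concrete, here are two small hypothetical sketches. They use only the classes and functions defined above; every size, ratio, and the -abs(logits) uncertainty function are illustrative assumptions (the latter taken from the inline comments), not values shipped with the package.

# Illustrative sketch: shape flow through SegmentationHead (all sizes are arbitrary).
import torch

head = SegmentationHead(in_dim=256, num_blocks=2)  # hypothetical configuration
spatial_features = torch.randn(1, 256, 32, 32)  # (B, C, H, W) backbone features
query_features = [torch.randn(1, 300, 256) for _ in range(2)]  # one (B, N, C) tensor per decoder layer
mask_logits = head(spatial_features, query_features, image_size=(512, 512))
print([m.shape for m in mask_logits])  # two tensors of shape (1, 300, 128, 128), i.e. image_size / downsample_ratio

# Illustrative sketch: uncertainty-driven point sampling over coarse mask logits,
# assuming point_sample and get_uncertain_point_coords_with_randomness from above are in scope.
coarse_logits = torch.randn(2, 1, 28, 28)  # (N, 1, Hmask, Wmask) class-agnostic predictions

def uncertainty_func(logits):
    # Most uncertain where a logit is closest to 0 (the -abs(logits) example from the comments above).
    return -torch.abs(logits)

point_coords = get_uncertain_point_coords_with_randomness(
    coarse_logits,
    uncertainty_func,
    num_points=196,
    oversample_ratio=3,
    importance_sample_ratio=0.75,
)
print(point_coords.shape)  # torch.Size([2, 196, 2])
point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
print(point_logits.shape)  # torch.Size([2, 1, 196])
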