inference_models-0.18.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference_models/__init__.py +36 -0
- inference_models/configuration.py +72 -0
- inference_models/constants.py +2 -0
- inference_models/entities.py +5 -0
- inference_models/errors.py +137 -0
- inference_models/logger.py +52 -0
- inference_models/model_pipelines/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/__init__.py +0 -0
- inference_models/model_pipelines/auto_loaders/core.py +120 -0
- inference_models/model_pipelines/auto_loaders/pipelines_registry.py +36 -0
- inference_models/model_pipelines/face_and_gaze_detection/__init__.py +0 -0
- inference_models/model_pipelines/face_and_gaze_detection/mediapipe_l2cs.py +200 -0
- inference_models/models/__init__.py +0 -0
- inference_models/models/auto_loaders/__init__.py +0 -0
- inference_models/models/auto_loaders/access_manager.py +168 -0
- inference_models/models/auto_loaders/auto_negotiation.py +1329 -0
- inference_models/models/auto_loaders/auto_resolution_cache.py +129 -0
- inference_models/models/auto_loaders/constants.py +7 -0
- inference_models/models/auto_loaders/core.py +1341 -0
- inference_models/models/auto_loaders/dependency_models.py +52 -0
- inference_models/models/auto_loaders/entities.py +57 -0
- inference_models/models/auto_loaders/models_registry.py +497 -0
- inference_models/models/auto_loaders/presentation_utils.py +333 -0
- inference_models/models/auto_loaders/ranking.py +413 -0
- inference_models/models/auto_loaders/utils.py +31 -0
- inference_models/models/base/__init__.py +0 -0
- inference_models/models/base/classification.py +123 -0
- inference_models/models/base/depth_estimation.py +62 -0
- inference_models/models/base/documents_parsing.py +111 -0
- inference_models/models/base/embeddings.py +66 -0
- inference_models/models/base/instance_segmentation.py +87 -0
- inference_models/models/base/keypoints_detection.py +93 -0
- inference_models/models/base/object_detection.py +143 -0
- inference_models/models/base/semantic_segmentation.py +74 -0
- inference_models/models/base/types.py +5 -0
- inference_models/models/clip/__init__.py +0 -0
- inference_models/models/clip/clip_onnx.py +148 -0
- inference_models/models/clip/clip_pytorch.py +104 -0
- inference_models/models/clip/preprocessing.py +162 -0
- inference_models/models/common/__init__.py +0 -0
- inference_models/models/common/cuda.py +30 -0
- inference_models/models/common/model_packages.py +25 -0
- inference_models/models/common/onnx.py +379 -0
- inference_models/models/common/roboflow/__init__.py +0 -0
- inference_models/models/common/roboflow/model_packages.py +361 -0
- inference_models/models/common/roboflow/post_processing.py +436 -0
- inference_models/models/common/roboflow/pre_processing.py +1332 -0
- inference_models/models/common/torch.py +20 -0
- inference_models/models/common/trt.py +266 -0
- inference_models/models/deep_lab_v3_plus/__init__.py +0 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_onnx.py +282 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_torch.py +264 -0
- inference_models/models/deep_lab_v3_plus/deep_lab_v3_plus_segmentation_trt.py +313 -0
- inference_models/models/depth_anything_v2/__init__.py +0 -0
- inference_models/models/depth_anything_v2/depth_anything_v2_hf.py +77 -0
- inference_models/models/dinov3/__init__.py +0 -0
- inference_models/models/dinov3/dinov3_classification_onnx.py +348 -0
- inference_models/models/dinov3/dinov3_classification_torch.py +323 -0
- inference_models/models/doctr/__init__.py +0 -0
- inference_models/models/doctr/doctr_torch.py +304 -0
- inference_models/models/easy_ocr/__init__.py +0 -0
- inference_models/models/easy_ocr/easy_ocr_torch.py +222 -0
- inference_models/models/florence2/__init__.py +0 -0
- inference_models/models/florence2/florence2_hf.py +897 -0
- inference_models/models/grounding_dino/__init__.py +0 -0
- inference_models/models/grounding_dino/grounding_dino_torch.py +227 -0
- inference_models/models/l2cs/__init__.py +0 -0
- inference_models/models/l2cs/l2cs_onnx.py +216 -0
- inference_models/models/mediapipe_face_detection/__init__.py +0 -0
- inference_models/models/mediapipe_face_detection/face_detection.py +203 -0
- inference_models/models/moondream2/__init__.py +0 -0
- inference_models/models/moondream2/moondream2_hf.py +281 -0
- inference_models/models/owlv2/__init__.py +0 -0
- inference_models/models/owlv2/cache.py +182 -0
- inference_models/models/owlv2/entities.py +112 -0
- inference_models/models/owlv2/owlv2_hf.py +695 -0
- inference_models/models/owlv2/reference_dataset.py +291 -0
- inference_models/models/paligemma/__init__.py +0 -0
- inference_models/models/paligemma/paligemma_hf.py +209 -0
- inference_models/models/perception_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/perception_encoder_pytorch.py +197 -0
- inference_models/models/perception_encoder/vision_encoder/__init__.py +0 -0
- inference_models/models/perception_encoder/vision_encoder/config.py +160 -0
- inference_models/models/perception_encoder/vision_encoder/pe.py +742 -0
- inference_models/models/perception_encoder/vision_encoder/rope.py +344 -0
- inference_models/models/perception_encoder/vision_encoder/tokenizer.py +342 -0
- inference_models/models/perception_encoder/vision_encoder/transforms.py +33 -0
- inference_models/models/qwen25vl/__init__.py +1 -0
- inference_models/models/qwen25vl/qwen25vl_hf.py +285 -0
- inference_models/models/resnet/__init__.py +0 -0
- inference_models/models/resnet/resnet_classification_onnx.py +330 -0
- inference_models/models/resnet/resnet_classification_torch.py +305 -0
- inference_models/models/resnet/resnet_classification_trt.py +369 -0
- inference_models/models/rfdetr/__init__.py +0 -0
- inference_models/models/rfdetr/backbone_builder.py +101 -0
- inference_models/models/rfdetr/class_remapping.py +41 -0
- inference_models/models/rfdetr/common.py +115 -0
- inference_models/models/rfdetr/default_labels.py +108 -0
- inference_models/models/rfdetr/dinov2_with_windowed_attn.py +1330 -0
- inference_models/models/rfdetr/misc.py +26 -0
- inference_models/models/rfdetr/ms_deform_attn.py +180 -0
- inference_models/models/rfdetr/ms_deform_attn_func.py +60 -0
- inference_models/models/rfdetr/position_encoding.py +166 -0
- inference_models/models/rfdetr/post_processor.py +83 -0
- inference_models/models/rfdetr/projector.py +373 -0
- inference_models/models/rfdetr/rfdetr_backbone_pytorch.py +394 -0
- inference_models/models/rfdetr/rfdetr_base_pytorch.py +807 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_onnx.py +206 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_pytorch.py +373 -0
- inference_models/models/rfdetr/rfdetr_instance_segmentation_trt.py +227 -0
- inference_models/models/rfdetr/rfdetr_object_detection_onnx.py +244 -0
- inference_models/models/rfdetr/rfdetr_object_detection_pytorch.py +470 -0
- inference_models/models/rfdetr/rfdetr_object_detection_trt.py +270 -0
- inference_models/models/rfdetr/segmentation_head.py +273 -0
- inference_models/models/rfdetr/transformer.py +767 -0
- inference_models/models/roboflow_instant/__init__.py +0 -0
- inference_models/models/roboflow_instant/roboflow_instant_hf.py +141 -0
- inference_models/models/sam/__init__.py +0 -0
- inference_models/models/sam/cache.py +147 -0
- inference_models/models/sam/entities.py +25 -0
- inference_models/models/sam/sam_torch.py +675 -0
- inference_models/models/sam2/__init__.py +0 -0
- inference_models/models/sam2/cache.py +162 -0
- inference_models/models/sam2/entities.py +43 -0
- inference_models/models/sam2/sam2_torch.py +905 -0
- inference_models/models/sam2_rt/__init__.py +0 -0
- inference_models/models/sam2_rt/sam2_pytorch.py +119 -0
- inference_models/models/smolvlm/__init__.py +0 -0
- inference_models/models/smolvlm/smolvlm_hf.py +245 -0
- inference_models/models/trocr/__init__.py +0 -0
- inference_models/models/trocr/trocr_hf.py +53 -0
- inference_models/models/vit/__init__.py +0 -0
- inference_models/models/vit/vit_classification_huggingface.py +319 -0
- inference_models/models/vit/vit_classification_onnx.py +326 -0
- inference_models/models/vit/vit_classification_trt.py +365 -0
- inference_models/models/yolact/__init__.py +1 -0
- inference_models/models/yolact/yolact_instance_segmentation_onnx.py +336 -0
- inference_models/models/yolact/yolact_instance_segmentation_trt.py +361 -0
- inference_models/models/yolo_world/__init__.py +1 -0
- inference_models/models/yolonas/__init__.py +0 -0
- inference_models/models/yolonas/nms.py +44 -0
- inference_models/models/yolonas/yolonas_object_detection_onnx.py +204 -0
- inference_models/models/yolonas/yolonas_object_detection_trt.py +230 -0
- inference_models/models/yolov10/__init__.py +0 -0
- inference_models/models/yolov10/yolov10_object_detection_onnx.py +187 -0
- inference_models/models/yolov10/yolov10_object_detection_trt.py +215 -0
- inference_models/models/yolov11/__init__.py +0 -0
- inference_models/models/yolov11/yolov11_onnx.py +28 -0
- inference_models/models/yolov11/yolov11_torch_script.py +25 -0
- inference_models/models/yolov11/yolov11_trt.py +21 -0
- inference_models/models/yolov12/__init__.py +0 -0
- inference_models/models/yolov12/yolov12_onnx.py +7 -0
- inference_models/models/yolov12/yolov12_torch_script.py +7 -0
- inference_models/models/yolov12/yolov12_trt.py +7 -0
- inference_models/models/yolov5/__init__.py +0 -0
- inference_models/models/yolov5/nms.py +99 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_onnx.py +225 -0
- inference_models/models/yolov5/yolov5_instance_segmentation_trt.py +255 -0
- inference_models/models/yolov5/yolov5_object_detection_onnx.py +192 -0
- inference_models/models/yolov5/yolov5_object_detection_trt.py +218 -0
- inference_models/models/yolov7/__init__.py +0 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_onnx.py +226 -0
- inference_models/models/yolov7/yolov7_instance_segmentation_trt.py +253 -0
- inference_models/models/yolov8/__init__.py +0 -0
- inference_models/models/yolov8/yolov8_classification_onnx.py +181 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_onnx.py +239 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_torch_script.py +201 -0
- inference_models/models/yolov8/yolov8_instance_segmentation_trt.py +268 -0
- inference_models/models/yolov8/yolov8_key_points_detection_onnx.py +263 -0
- inference_models/models/yolov8/yolov8_key_points_detection_torch_script.py +218 -0
- inference_models/models/yolov8/yolov8_key_points_detection_trt.py +287 -0
- inference_models/models/yolov8/yolov8_object_detection_onnx.py +213 -0
- inference_models/models/yolov8/yolov8_object_detection_torch_script.py +166 -0
- inference_models/models/yolov8/yolov8_object_detection_trt.py +231 -0
- inference_models/models/yolov9/__init__.py +0 -0
- inference_models/models/yolov9/yolov9_onnx.py +7 -0
- inference_models/models/yolov9/yolov9_torch_script.py +7 -0
- inference_models/models/yolov9/yolov9_trt.py +7 -0
- inference_models/runtime_introspection/__init__.py +0 -0
- inference_models/runtime_introspection/core.py +410 -0
- inference_models/utils/__init__.py +0 -0
- inference_models/utils/download.py +608 -0
- inference_models/utils/environment.py +28 -0
- inference_models/utils/file_system.py +51 -0
- inference_models/utils/hashing.py +7 -0
- inference_models/utils/imports.py +48 -0
- inference_models/utils/onnx_introspection.py +17 -0
- inference_models/weights_providers/__init__.py +0 -0
- inference_models/weights_providers/core.py +20 -0
- inference_models/weights_providers/entities.py +159 -0
- inference_models/weights_providers/roboflow.py +601 -0
- inference_models-0.18.3.dist-info/METADATA +466 -0
- inference_models-0.18.3.dist-info/RECORD +195 -0
- inference_models-0.18.3.dist-info/WHEEL +5 -0
- inference_models-0.18.3.dist-info/top_level.txt +1 -0

inference_models/models/rfdetr/backbone_builder.py
@@ -0,0 +1,101 @@
from typing import Callable

import torch
from torch import nn

from inference_models.models.rfdetr.misc import NestedTensor
from inference_models.models.rfdetr.position_encoding import build_position_encoding
from inference_models.models.rfdetr.rfdetr_backbone_pytorch import Backbone


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self._export = False

    def forward(self, tensor_list: NestedTensor):
        """ """
        x = self[0](tensor_list)
        pos = []
        for x_ in x:
            pos.append(self[1](x_, align_dim_orders=False).to(x_.tensors.dtype))
        return x, pos

    def export(self):
        self._export = True
        self._forward_origin = self.forward
        self.forward = self.forward_export
        for name, m in self.named_modules():
            if (
                hasattr(m, "export")
                and isinstance(m.export, Callable)
                and hasattr(m, "_export")
                and not m._export
            ):
                m.export()

    def forward_export(self, inputs: torch.Tensor):
        feats, masks = self[0](inputs)
        poss = []
        for feat, mask in zip(feats, masks):
            poss.append(self[1](mask, align_dim_orders=False).to(feat.dtype))
        return feats, None, poss


def build_backbone(
    encoder,
    vit_encoder_num_layers,
    pretrained_encoder,
    window_block_indexes,
    drop_path,
    out_channels,
    out_feature_indexes,
    projector_scale,
    use_cls_token,
    hidden_dim,
    position_embedding,
    freeze_encoder,
    layer_norm,
    target_shape,
    rms_norm,
    backbone_lora,
    force_no_pretrain,
    gradient_checkpointing,
    load_dinov2_weights,
    patch_size,
    num_windows,
    positional_encoding_size,
):
    """
    Useful args:
        - encoder: encoder name
        - lr_encoder:
        - dilation
        - use_checkpoint: for swin only for now

    """
    position_embedding = build_position_encoding(hidden_dim, position_embedding)

    backbone = Backbone(
        encoder,
        pretrained_encoder,
        window_block_indexes=window_block_indexes,
        drop_path=drop_path,
        out_channels=out_channels,
        out_feature_indexes=out_feature_indexes,
        projector_scale=projector_scale,
        use_cls_token=use_cls_token,
        layer_norm=layer_norm,
        freeze_encoder=freeze_encoder,
        target_shape=target_shape,
        rms_norm=rms_norm,
        backbone_lora=backbone_lora,
        gradient_checkpointing=gradient_checkpointing,
        load_dinov2_weights=load_dinov2_weights,
        patch_size=patch_size,
        num_windows=num_windows,
        positional_encoding_size=positional_encoding_size,
    )

    model = Joiner(backbone, position_embedding)
    return model
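
Editor's sketch (not part of the package diff): `Joiner.export` switches the whole module tree into export mode by monkey-patching `forward` on every nested module that follows the `_export` flag convention. Below is a minimal, self-contained illustration of that pattern; `ExportableBlock` is a hypothetical toy module, not from the package.

import torch
from torch import nn
from typing import Callable


class ExportableBlock(nn.Module):
    # Toy module following the same convention as Joiner: an `_export`
    # flag plus an `export()` method that swaps `forward` for an
    # export-friendly variant.
    def __init__(self):
        super().__init__()
        self._export = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    def forward_export(self, x: torch.Tensor) -> torch.Tensor:
        return x + 2

    def export(self):
        self._export = True
        self.forward = self.forward_export


root = nn.Sequential(ExportableBlock(), nn.Sequential(ExportableBlock()))

# Same traversal as Joiner.export(): flip every nested module that exposes
# the export() convention and has not been flipped yet.
for _, m in root.named_modules():
    if (
        hasattr(m, "export")
        and isinstance(m.export, Callable)
        and hasattr(m, "_export")
        and not m._export
    ):
        m.export()

print(root(torch.zeros(1)))  # tensor([4.]) -- both blocks now add 2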
inference_models/models/rfdetr/class_remapping.py
@@ -0,0 +1,41 @@
from collections import namedtuple
from typing import List, Tuple, Union

import torch

from inference_models.models.common.roboflow.model_packages import ClassNameRemoval

ClassesReMapping = namedtuple(
    "ClassesReMapping", ["remaining_class_ids", "class_mapping"]
)


def prepare_class_remapping(
    class_names: List[str],
    class_names_operations: List[
        Union[ClassNameRemoval]
    ],  # be ready for different elements of union type
    device: torch.device,
) -> Tuple[List[str], ClassesReMapping]:
    removed_classes = {
        o.class_name for o in class_names_operations if isinstance(o, ClassNameRemoval)
    }
    removed_class_ids = set()
    remaining_class_ids = []
    result_classes = []
    class_mapping = []
    for class_id, class_name in enumerate(class_names):
        if class_name in removed_classes:
            removed_class_ids.add(class_id)
            class_mapping.append(-1)
            continue
        remaining_class_ids.append(class_id)
        class_mapping.append(class_id - len(removed_class_ids))
        result_classes.append(class_name)
    classes_re_mapping = ClassesReMapping(
        remaining_class_ids=torch.tensor(
            remaining_class_ids, dtype=torch.int64, device=device
        ),
        class_mapping=torch.tensor(class_mapping, dtype=torch.int64, device=device),
    )
    return result_classes, classes_re_mapping
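
Editor's sketch (not part of the package diff): a usage example for `prepare_class_remapping`, assuming the package is installed and that `ClassNameRemoval` takes the class name to drop as a `class_name` argument (the function only reads that attribute, so the constructor signature is an assumption here). The mapping tensor translates original class ids to the compacted id space, with -1 marking removed classes.

import torch

from inference_models.models.common.roboflow.model_packages import ClassNameRemoval
from inference_models.models.rfdetr.class_remapping import prepare_class_remapping

class_names = ["cat", "dog", "bird"]
operations = [ClassNameRemoval(class_name="dog")]  # hypothetical construction

result_classes, remapping = prepare_class_remapping(
    class_names=class_names,
    class_names_operations=operations,
    device=torch.device("cpu"),
)

print(result_classes)                 # ["cat", "bird"]
print(remapping.remaining_class_ids)  # tensor([0, 2]) -- original ids kept
print(remapping.class_mapping)        # tensor([ 0, -1,  1]) -- old id -> new id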
inference_models/models/rfdetr/common.py
@@ -0,0 +1,115 @@
from typing import List, Optional, Tuple

import torch
from torchvision.transforms import functional

from inference_models import InstanceDetections
from inference_models.entities import ImageDimensions
from inference_models.errors import CorruptedModelPackageError
from inference_models.models.common.roboflow.model_packages import (
    PreProcessingMetadata,
    StaticCropOffset,
)
from inference_models.models.common.roboflow.post_processing import (
    align_instance_segmentation_results,
)
from inference_models.models.rfdetr.class_remapping import ClassesReMapping
from inference_models.utils.file_system import read_json


def parse_model_type(config_path: str) -> str:
    try:
        parsed_config = read_json(path=config_path)
        if not isinstance(parsed_config, dict):
            raise ValueError(
                f"decoded value is {type(parsed_config)}, but dictionary expected"
            )
        if "model_type" not in parsed_config or not isinstance(
            parsed_config["model_type"], str
        ):
            raise ValueError(
                "could not find required entries in config - either "
                "'model_type' field is missing or not a string"
            )
        return parsed_config["model_type"]
    except (IOError, OSError, ValueError) as error:
        raise CorruptedModelPackageError(
            message=f"Model type config file is malformed: "
            f"{error}. In case that the package is "
            f"hosted on the Roboflow platform - contact support. If you created model package manually, please "
            f"verify its consistency in docs.",
            help_url="https://todo",
        ) from error


def post_process_instance_segmentation_results(
    bboxes: torch.Tensor,
    logits: torch.Tensor,
    masks: torch.Tensor,
    pre_processing_meta: List[PreProcessingMetadata],
    threshold: float,
    classes_re_mapping: Optional[ClassesReMapping],
) -> List[InstanceDetections]:
    logits_sigmoid = torch.nn.functional.sigmoid(logits)
    results = []
    device = bboxes.device
    for image_bboxes, image_logits, image_masks, image_meta in zip(
        bboxes, logits_sigmoid, masks, pre_processing_meta
    ):
        confidence, top_classes = image_logits.max(dim=1)
        confidence_mask = confidence > threshold
        confidence = confidence[confidence_mask]
        top_classes = top_classes[confidence_mask]
        selected_boxes = image_bboxes[confidence_mask]
        selected_masks = image_masks[confidence_mask]
        confidence, sorted_indices = torch.sort(confidence, descending=True)
        top_classes = top_classes[sorted_indices]
        selected_boxes = selected_boxes[sorted_indices]
        selected_masks = selected_masks[sorted_indices]
        if classes_re_mapping is not None:
            remapping_mask = torch.isin(
                top_classes, classes_re_mapping.remaining_class_ids
            )
            top_classes = classes_re_mapping.class_mapping[top_classes[remapping_mask]]
            selected_boxes = selected_boxes[remapping_mask]
            confidence = confidence[remapping_mask]
        cxcy = selected_boxes[:, :2]
        wh = selected_boxes[:, 2:]
        xy_min = cxcy - 0.5 * wh
        xy_max = cxcy + 0.5 * wh
        selected_boxes_xyxy_pct = torch.cat([xy_min, xy_max], dim=-1)
        inference_size_hwhw = torch.tensor(
            [
                image_meta.inference_size.height,
                image_meta.inference_size.width,
                image_meta.inference_size.height,
                image_meta.inference_size.width,
            ],
            device=device,
        )
        padding = (
            image_meta.pad_left,
            image_meta.pad_top,
            image_meta.pad_right,
            image_meta.pad_bottom,
        )
        selected_boxes_xyxy = selected_boxes_xyxy_pct * inference_size_hwhw
        aligned_boxes, aligned_masks = align_instance_segmentation_results(
            image_bboxes=selected_boxes_xyxy,
            masks=selected_masks,
            padding=padding,
            scale_height=image_meta.scale_height,
            scale_width=image_meta.scale_width,
            original_size=image_meta.original_size,
            size_after_pre_processing=image_meta.size_after_pre_processing,
            inference_size=image_meta.inference_size,
            static_crop_offset=image_meta.static_crop_offset,
        )
        detections = InstanceDetections(
            xyxy=aligned_boxes.round().int(),
            confidence=confidence,
            class_id=top_classes.int(),
            mask=aligned_masks,
        )
        results.append(detections)
    return results
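
Editor's sketch (not part of the package diff): the post-processing above turns the model's normalized (cx, cy, w, h) boxes into absolute corner coordinates before alignment. A self-contained illustration of that arithmetic follows; the (h, w, h, w) scaling order mirrors `inference_size_hwhw` in the file, and at the square inference resolutions these models typically run at (560 is used here as an assumed example) it coincides with (w, h, w, h).

import torch

boxes_cxcywh = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # one centered box
inference_h, inference_w = 560, 560  # hypothetical square inference size

# Center/size to corners, still in normalized [0, 1] coordinates.
cxcy = boxes_cxcywh[:, :2]
wh = boxes_cxcywh[:, 2:]
xyxy_pct = torch.cat([cxcy - 0.5 * wh, cxcy + 0.5 * wh], dim=-1)

# Scale to pixels of the inference resolution, as in the hunk above.
hwhw = torch.tensor([inference_h, inference_w, inference_h, inference_w])
xyxy_px = xyxy_pct * hwhw

print(xyxy_px)  # tensor([[224., 168., 336., 392.]])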
inference_models/models/rfdetr/default_labels.py
@@ -0,0 +1,108 @@
from copy import deepcopy
from typing import List

from inference_models.errors import ModelLoadingError

COCO_LABELS = [
    "background",
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "background",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "background",
    "backpack",
    "umbrella",
    "background",
    "background",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "background",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "background",
    "dining table",
    "background",
    "background",
    "toilet",
    "background",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "background",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]


def resolve_labels(labels: str) -> List[str]:
    if labels != "coco":
        raise ModelLoadingError(
            message=f"While loading RFDetr model, `labels` parameter was set to `{labels}` which is invalid. "
            f"Supported set of labels: `coco`.",
            help_url="https://todo",
        )
    return deepcopy(COCO_LABELS)
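
Editor's note (not part of the package diff): the repeated "background" entries fill the unused category ids of the original 91-id COCO annotation space, so the table holds 91 slots but only 80 real classes. A quick check, assuming `COCO_LABELS` from the file above is in scope:

# 91 slots in the COCO category-id space; 11 "background" placeholders
# for unused ids leave the 80 actual COCO classes.
real_classes = [label for label in COCO_LABELS if label != "background"]
assert len(COCO_LABELS) == 91
assert len(real_classes) == 80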